# Salt Crystal Purity Classification - MobileNetV2

This notebook trains a **MobileNetV2** model to classify salt crystals as **pure** or **impure** using transfer learning.

## Model Architecture: MobileNetV2

MobileNetV2 is a lightweight convolutional neural network designed for mobile and embedded vision applications.

### Key Features:
- **Inverted Residual Blocks**: Expand-depthwise-project structure
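- **Linear Bottlenecks**: No ReLU on the narrow projection layers, which prevents information loss there
- **Depthwise Separable Convolutions**: Reduces computation by roughly 8-9x compared with standard 3x3 convolutions
- **Parameters**: ~3.4 million with the ImageNet classifier head, ~2.3 million for the `include_top=False` base used here (vs ~25.6M for ResNet50)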
- **Input Size**: 224x224x3
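The 8-9x figure comes from counting multiply-accumulates (MACs). A standard k×k convolution costs H·W·C_in·C_out·k², while a depthwise k×k convolution followed by a 1×1 pointwise convolution costs H·W·C_in·k² + H·W·C_in·C_out. Their ratio is C_out·k² / (k² + C_out), which approaches k² = 9 for 3×3 kernels. A minimal sketch of that arithmetic, where the layer sizes are illustrative and not taken from the actual network:

```python
def conv_macs(h, w, c_in, c_out, k):
    """Multiply-accumulates for a standard k x k convolution (stride 1, 'same' padding)."""
    return h * w * c_in * c_out * k * k


def separable_conv_macs(h, w, c_in, c_out, k):
    """Depthwise k x k convolution followed by a 1 x 1 pointwise convolution."""
    depthwise = h * w * c_in * k * k
    pointwise = h * w * c_in * c_out
    return depthwise + pointwise


# Illustrative layer: 14x14 feature map, 256 -> 256 channels, 3x3 kernel
standard = conv_macs(14, 14, 256, 256, 3)
separable = separable_conv_macs(14, 14, 256, 256, 3)
print(f"standard:  {standard:,} MACs")
print(f"separable: {separable:,} MACs")
print(f"reduction: {standard / separable:.2f}x")  # 2304 / 265, about 8.69x
```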

### Comparison Context
This model is trained for **image classification** (one label per image) to compare against **YOLOv8 object detection** which provides:
- Bounding box localization
- Multiple object detection per image
- Per-crystal confidence scores

## Before Starting
1. Go to **Runtime > Change runtime type**
2. Select **T4 GPU** (or any available GPU)
3. Click **Save**

---
## Step 1: Check GPU & Install Dependencies

In [None]:
# Check GPU availability
!nvidia-smi

import tensorflow as tf
print(f"\nTensorFlow version: {tf.__version__}")
print(f"GPU Available: {tf.config.list_physical_devices('GPU')}")

In [None]:
# Install/upgrade required packages
!pip install -q pillow scikit-learn matplotlib seaborn

import os
import json
import time
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from sklearn.metrics import classification_report, confusion_matrix

print("Dependencies installed successfully!")

---
## Step 2: Mount Google Drive & Load Dataset

The dataset is stored in Google Drive at `MyDrive/salt-crystal/data.zip` (same as YOLOv8 training).
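Because Step 3 reads `images/`, `labels/` and `classes.txt` directly under the extraction folder, it can help to confirm that the archive has that top-level layout before extracting it. Here is a minimal standard-library sketch. The helper name `missing_yolo_entries` is hypothetical, and the in-memory archive only stands in for `zip_path`:

```python
import io
import zipfile

REQUIRED_DIRS = ("images/", "labels/")  # folders read in Step 3
REQUIRED_FILES = ("classes.txt",)


def missing_yolo_entries(zip_source):
    """Return the required top-level entries that the archive lacks.

    zip_source may be a path or a file-like object.
    """
    with zipfile.ZipFile(zip_source) as zf:
        names = zf.namelist()
    missing = [d for d in REQUIRED_DIRS if not any(n.startswith(d) for n in names)]
    missing += [f for f in REQUIRED_FILES if f not in names]
    return missing


# Usage with an in-memory archive; in the notebook, pass zip_path instead
buf = io.BytesIO()
with zipfile.ZipFile(buf, "w") as zf:
    zf.writestr("images/a.jpg", b"")
    zf.writestr("classes.txt", "impure\npure\n")
print(missing_yolo_entries(buf))  # ['labels/']
```

An empty list means the archive matches the paths used later. A single wrapping folder (e.g. `data/images/`) would show up here as missing entries.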

In [None]:
from google.colab import drive
import zipfile

# Mount Google Drive
print("Mounting Google Drive...")
drive.mount('/content/drive')

# Path to your dataset in Google Drive
zip_path = '/content/drive/MyDrive/salt-crystal/data.zip'

# Verify the file exists; stop here instead of failing later inside zipfile
if not os.path.exists(zip_path):
    raise FileNotFoundError(
        f"Dataset not found at {zip_path}. Please check the path and try again."
    )
print(f"\nDataset found: {zip_path}")

# Extract the dataset
print("\nExtracting dataset...")
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall('/content/dataset_yolo')

print("Dataset extracted successfully!")
print("\nExtracted contents:")
!ls -la /content/dataset_yolo

---
## Step 3: Convert YOLO Detection Format to Classification Format

### Problem:
- YOLO format: Images with bounding box annotations (multiple crystals per image)
- Classification format: Individual crystal images in class folders

### Solution:
Crop each annotated crystal from the source images and organize into `pure/` and `impure/` folders.
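Each YOLO label line stores `class_id x_center y_center width height`. The four geometry values are normalized to 0-1 by the image width and height. The sketch below is a standalone version of the conversion the cropping function in this step performs, turning one line into clamped pixel corners. `yolo_line_to_box` is a name introduced here for illustration:

```python
def yolo_line_to_box(line, img_width, img_height):
    """Parse one YOLO label line into (class_id, x1, y1, x2, y2) in pixels.

    Coordinates are clamped to the image and truncated to ints, as in the cropping step.
    """
    parts = line.split()
    class_id = int(parts[0])
    x_c, y_c, w, h = (float(p) for p in parts[1:5])
    x_c, w = x_c * img_width, w * img_width
    y_c, h = y_c * img_height, h * img_height
    x1 = max(0, int(x_c - w / 2))
    y1 = max(0, int(y_c - h / 2))
    x2 = min(img_width, int(x_c + w / 2))
    y2 = min(img_height, int(y_c + h / 2))
    return class_id, x1, y1, x2, y2


# A crystal centred in a 1000x500 image, 20% of the width and 40% of the height
print(yolo_line_to_box("1 0.5 0.5 0.2 0.4", 1000, 500))  # (1, 400, 150, 600, 350)
```

Boxes that overhang the image edge are clipped rather than rejected, so a crystal cut off by the frame still yields a (smaller) crop.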

In [None]:
import os
import shutil
from PIL import Image
import random

# Source paths (YOLO format)
SOURCE_IMAGES = '/content/dataset_yolo/images'
SOURCE_LABELS = '/content/dataset_yolo/labels'
CLASSES_FILE = '/content/dataset_yolo/classes.txt'

# Target paths (Classification format)
TARGET_DIR = '/content/dataset_classification'

# Read class names
with open(CLASSES_FILE, 'r') as f:
    classes = [line.strip() for line in f.readlines() if line.strip()]

print(f"Classes: {classes}")
print(f"Class 0: {classes[0]}")
print(f"Class 1: {classes[1]}")

# Create target directories
for split in ['train', 'valid']:
    for cls in classes:
        os.makedirs(f'{TARGET_DIR}/{split}/{cls}', exist_ok=True)

print(f"\nTarget directory structure created at {TARGET_DIR}")

In [None]:
def crop_crystals_from_yolo(image_dir, label_dir, output_dir, classes, target_size=224):
    """
    Crop individual crystals from images using YOLO annotations.
    
    YOLO format: class_id x_center y_center width height (normalized 0-1)
    """
    stats = {cls: 0 for cls in classes}
    
    image_files = [f for f in os.listdir(image_dir) 
                   if f.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp'))]
    
    for img_file in image_files:
        # Load image
        img_path = os.path.join(image_dir, img_file)
        img = Image.open(img_path).convert('RGB')
        img_width, img_height = img.size
        
        # Find corresponding label file
        label_file = os.path.splitext(img_file)[0] + '.txt'
        label_path = os.path.join(label_dir, label_file)
        
        if not os.path.exists(label_path):
            continue
        
        # Read annotations
        with open(label_path, 'r') as f:
            lines = f.readlines()
        
        for idx, line in enumerate(lines):
            parts = line.strip().split()
            if len(parts) < 5:
                continue
            
            class_id = int(parts[0])
            if not 0 <= class_id < len(classes):
                continue  # skip IDs with no entry in classes.txt
            x_center = float(parts[1]) * img_width
            y_center = float(parts[2]) * img_height
            width = float(parts[3]) * img_width
            height = float(parts[4]) * img_height
            
            # Calculate bounding box coordinates
            x1 = max(0, int(x_center - width / 2))
            y1 = max(0, int(y_center - height / 2))
            x2 = min(img_width, int(x_center + width / 2))
            y2 = min(img_height, int(y_center + height / 2))
            
            # Skip very small crops
            if (x2 - x1) < 10 or (y2 - y1) < 10:
                continue
            
            # Crop and resize
            crop = img.crop((x1, y1, x2, y2))
            crop = crop.resize((target_size, target_size), Image.Resampling.LANCZOS)
            
            # Save cropped crystal
            class_name = classes[class_id]
            output_filename = f"{os.path.splitext(img_file)[0]}_crop{idx}.jpg"
            output_path = os.path.join(output_dir, class_name, output_filename)
            crop.save(output_path, 'JPEG', quality=95)
            
            stats[class_name] += 1
    
    return stats

print("Crystal cropping function defined.")

In [None]:
# First, split the source images into train/valid
# (sorted first: os.listdir() order is arbitrary, so sorting makes the seeded shuffle reproducible)
image_files = sorted(f for f in os.listdir(SOURCE_IMAGES)
                     if f.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp')))

random.seed(42)  # Same seed as YOLOv8 for consistency
random.shuffle(image_files)

split_idx = int(len(image_files) * 0.9)
train_files = set(image_files[:split_idx])
valid_files = set(image_files[split_idx:])

print(f"Total images: {len(image_files)}")
print(f"Train images: {len(train_files)}")
print(f"Valid images: {len(valid_files)}")

# Create temporary directories for split
os.makedirs('/content/temp_train/images', exist_ok=True)
os.makedirs('/content/temp_train/labels', exist_ok=True)
os.makedirs('/content/temp_valid/images', exist_ok=True)
os.makedirs('/content/temp_valid/labels', exist_ok=True)

# Copy files to temp directories
for img_file in train_files:
    shutil.copy(os.path.join(SOURCE_IMAGES, img_file), '/content/temp_train/images/')
    label_file = os.path.splitext(img_file)[0] + '.txt'
    label_path = os.path.join(SOURCE_LABELS, label_file)
    if os.path.exists(label_path):
        shutil.copy(label_path, '/content/temp_train/labels/')

for img_file in valid_files:
    shutil.copy(os.path.join(SOURCE_IMAGES, img_file), '/content/temp_valid/images/')
    label_file = os.path.splitext(img_file)[0] + '.txt'
    label_path = os.path.join(SOURCE_LABELS, label_file)
    if os.path.exists(label_path):
        shutil.copy(label_path, '/content/temp_valid/labels/')

print("\nFiles split into train/valid directories.")

In [None]:
# Crop crystals for training set
print("Cropping training crystals...")
train_stats = crop_crystals_from_yolo(
    '/content/temp_train/images',
    '/content/temp_train/labels',
    f'{TARGET_DIR}/train',
    classes
)
print(f"Training set: {train_stats}")

# Crop crystals for validation set
print("\nCropping validation crystals...")
valid_stats = crop_crystals_from_yolo(
    '/content/temp_valid/images',
    '/content/temp_valid/labels',
    f'{TARGET_DIR}/valid',
    classes
)
print(f"Validation set: {valid_stats}")

# Clean up temp directories
shutil.rmtree('/content/temp_train')
shutil.rmtree('/content/temp_valid')

print("\n" + "="*50)
print("DATASET PREPARATION COMPLETE")
print("="*50)
print(f"\nTraining samples:")
for cls in classes:
    count = len(os.listdir(f'{TARGET_DIR}/train/{cls}'))
    print(f"  {cls}: {count} images")

print(f"\nValidation samples:")
for cls in classes:
    count = len(os.listdir(f'{TARGET_DIR}/valid/{cls}'))
    print(f"  {cls}: {count} images")

In [None]:
# Visualize some cropped samples
import matplotlib.pyplot as plt
from PIL import Image
import os

n_cols = 5
fig, axes = plt.subplots(len(classes), n_cols, figsize=(15, 3 * len(classes)), squeeze=False)

for row, cls in enumerate(classes):
    cls_dir = f'{TARGET_DIR}/train/{cls}'
    samples = sorted(os.listdir(cls_dir))[:n_cols]
    
    for col in range(n_cols):
        axes[row, col].axis('off')  # also hides empty slots when a class has fewer than 5 crops
        if col < len(samples):
            img = Image.open(os.path.join(cls_dir, samples[col]))
            axes[row, col].imshow(img)
            axes[row, col].set_title(f'{cls}')

plt.suptitle('Sample Cropped Crystals (224x224)', fontsize=14)
plt.tight_layout()
plt.show()

---
## Step 4: Create Data Generators with Augmentation
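Pixel scaling has to match what the pretrained weights were trained on. The Keras documentation states that `tf.keras.applications.mobilenet_v2.preprocess_input` scales pixels to the [-1, 1] range. By contrast, `rescale=1./255` maps them to [0, 1]. The sketch below writes out both formulas in plain Python; it does not call Keras:

```python
def rescale_0_1(pixel):
    """ImageDataGenerator(rescale=1./255): maps 0..255 to 0..1."""
    return pixel * (1.0 / 255)


def mobilenet_v2_scale(pixel):
    """Same arithmetic as mobilenet_v2.preprocess_input ('tf' mode): maps 0..255 to -1..1."""
    return pixel / 127.5 - 1.0


for p in (0, 127.5, 255):
    print(p, round(rescale_0_1(p), 3), mobilenet_v2_scale(p))
# 0 0.0 -1.0
# 127.5 0.5 0.0
# 255 1.0 1.0
```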

In [None]:
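from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input

# Training data augmentation
# preprocess_input scales pixels to [-1, 1], the input range the ImageNet-pretrained
# MobileNetV2 weights expect (rescale=1./255 would produce [0, 1] instead)
train_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    vertical_flip=True,
    fill_mode='nearest'
)

# Validation data (no augmentation, same preprocessing as training)
valid_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)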

# Create generators
IMG_SIZE = 224
BATCH_SIZE = 32

train_generator = train_datagen.flow_from_directory(
    f'{TARGET_DIR}/train',
    target_size=(IMG_SIZE, IMG_SIZE),
    batch_size=BATCH_SIZE,
    class_mode='binary',
    shuffle=True,
    seed=42
)

valid_generator = valid_datagen.flow_from_directory(
    f'{TARGET_DIR}/valid',
    target_size=(IMG_SIZE, IMG_SIZE),
    batch_size=BATCH_SIZE,
    class_mode='binary',
    shuffle=False
)

print(f"\nClass indices: {train_generator.class_indices}")
print(f"Training samples: {train_generator.samples}")
print(f"Validation samples: {valid_generator.samples}")

---
## Step 5: Build MobileNetV2 Model with Transfer Learning

In [None]:
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout
from tensorflow.keras.optimizers import Adam

# Load MobileNetV2 with pretrained ImageNet weights (without top layer)
base_model = MobileNetV2(
    weights='imagenet',
    include_top=False,
    input_shape=(IMG_SIZE, IMG_SIZE, 3)
)

# Freeze base model layers
base_model.trainable = False

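# Build the complete model with the functional API, calling the base with
# training=False so its BatchNormalization layers stay in inference mode,
# even after some base layers are unfrozen for fine-tuning in Step 7
inputs = tf.keras.Input(shape=(IMG_SIZE, IMG_SIZE, 3))
x = base_model(inputs, training=False)
x = GlobalAveragePooling2D()(x)
x = Dense(128, activation='relu')(x)
x = Dropout(0.5)(x)
outputs = Dense(1, activation='sigmoid')(x)  # Binary classification
model = Model(inputs, outputs)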

# Compile the model
model.compile(
    optimizer=Adam(learning_rate=0.001),
    loss='binary_crossentropy',
    metrics=['accuracy']
)

print("\nMobileNetV2 Model Architecture:")
print("="*50)
model.summary()

# Count parameters
total_params = model.count_params()
trainable_params = sum([tf.keras.backend.count_params(w) for w in model.trainable_weights])
non_trainable_params = total_params - trainable_params

print(f"\nTotal parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Non-trainable parameters: {non_trainable_params:,}")

---
## Step 6: Training Phase 1 - Frozen Base (Feature Extraction)

In [None]:
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau

# Callbacks
callbacks = [
    EarlyStopping(
        monitor='val_loss',
        patience=10,
        restore_best_weights=True,
        verbose=1
    ),
    ModelCheckpoint(
        '/content/mobilenet_best.keras',
        monitor='val_accuracy',
        save_best_only=True,
        verbose=1
    ),
    ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.2,
        patience=5,
        min_lr=1e-7,
        verbose=1
    )
]

print("Training Phase 1: Feature Extraction (Frozen Base)")
print("="*50)

# Train with frozen base
history_frozen = model.fit(
    train_generator,
    epochs=20,
    validation_data=valid_generator,
    callbacks=callbacks,
    verbose=1
)

print("\nPhase 1 Training Complete!")

---
## Step 7: Training Phase 2 - Fine-Tuning

In [None]:
# Unfreeze the top layers of the base model for fine-tuning
base_model.trainable = True

# Freeze all layers except the last 30
for layer in base_model.layers[:-30]:
    layer.trainable = False

# Recompile with lower learning rate for fine-tuning
model.compile(
    optimizer=Adam(learning_rate=1e-5),  # Lower LR for fine-tuning
    loss='binary_crossentropy',
    metrics=['accuracy']
)

# Count updated parameters
trainable_params = sum([tf.keras.backend.count_params(w) for w in model.trainable_weights])
print(f"Trainable parameters after unfreezing: {trainable_params:,}")

print("\nTraining Phase 2: Fine-Tuning")
print("="*50)

# Continue training with unfrozen layers
history_finetuned = model.fit(
    train_generator,
    epochs=20,
    validation_data=valid_generator,
    callbacks=callbacks,
    verbose=1
)

print("\nPhase 2 Fine-Tuning Complete!")

---
## Step 8: Plot Training History

In [None]:
import matplotlib.pyplot as plt

# Combine histories
acc = history_frozen.history['accuracy'] + history_finetuned.history['accuracy']
val_acc = history_frozen.history['val_accuracy'] + history_finetuned.history['val_accuracy']
loss = history_frozen.history['loss'] + history_finetuned.history['loss']
val_loss = history_frozen.history['val_loss'] + history_finetuned.history['val_loss']

epochs_range = range(1, len(acc) + 1)
phase1_end = len(history_frozen.history['accuracy']) + 0.5  # boundary between last frozen and first fine-tuning epoch

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Accuracy plot
ax1.plot(epochs_range, acc, 'b-', label='Training Accuracy')
ax1.plot(epochs_range, val_acc, 'r-', label='Validation Accuracy')
ax1.axvline(x=phase1_end, color='g', linestyle='--', label='Fine-tuning Start')
ax1.set_title('MobileNetV2 - Training and Validation Accuracy')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Accuracy')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Loss plot
ax2.plot(epochs_range, loss, 'b-', label='Training Loss')
ax2.plot(epochs_range, val_loss, 'r-', label='Validation Loss')
ax2.axvline(x=phase1_end, color='g', linestyle='--', label='Fine-tuning Start')
ax2.set_title('MobileNetV2 - Training and Validation Loss')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Loss')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('/content/mobilenet_training_history.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"\nFinal Training Accuracy: {acc[-1]:.4f}")
print(f"Final Validation Accuracy: {val_acc[-1]:.4f}")

---
## Step 9: Evaluate Model Performance

In [None]:
from sklearn.metrics import classification_report, confusion_matrix
import numpy as np

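# Load the best checkpoint saved by ModelCheckpoint (monitored on val_accuracy).
# This validation split also drove checkpoint selection and early stopping,
# so the scores below are likely somewhat optimistic.
model = tf.keras.models.load_model('/content/mobilenet_best.keras')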

# Get predictions
valid_generator.reset()
predictions = model.predict(valid_generator, verbose=1)
predicted_classes = (predictions > 0.5).astype(int).flatten()
true_classes = valid_generator.classes

# Class names
class_names = list(valid_generator.class_indices.keys())

print("\n" + "="*50)
print("MOBILENETV2 CLASSIFICATION REPORT")
print("="*50)
print(classification_report(true_classes, predicted_classes, target_names=class_names))

# Calculate metrics (precision, recall and F1 treat label 1, i.e. class_names[1], as the positive class)
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

accuracy = accuracy_score(true_classes, predicted_classes)
precision = precision_score(true_classes, predicted_classes)
recall = recall_score(true_classes, predicted_classes)
f1 = f1_score(true_classes, predicted_classes)

print("\nSummary Metrics:")
print(f"  Accuracy:  {accuracy:.4f}")
print(f"  Precision: {precision:.4f}")
print(f"  Recall:    {recall:.4f}")
print(f"  F1-Score:  {f1:.4f}")

In [None]:
# Plot confusion matrix
import seaborn as sns

cm = confusion_matrix(true_classes, predicted_classes)

plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names)
plt.title('MobileNetV2 - Confusion Matrix')
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.savefig('/content/mobilenet_confusion_matrix.png', dpi=150, bbox_inches='tight')
plt.show()
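In a binary problem, scikit-learn's `confusion_matrix` puts true labels on rows and predictions on columns. As a result, `cm.ravel()` yields `tn, fp, fn, tp`, with label 1 as the positive class. The Step 9 summary metrics can be recomputed from those four counts. In the sketch below, the counts are illustrative and `binary_metrics` is a hypothetical helper:

```python
def binary_metrics(tn, fp, fn, tp):
    """Accuracy, precision, recall and F1 with label 1 as the positive class."""
    total = tn + fp + fn + tp
    accuracy = (tp + tn) / total
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * tp / (2 * tp + fp + fn) if tp else 0.0
    return accuracy, precision, recall, f1


# Illustrative counts; in the notebook: tn, fp, fn, tp = cm.ravel()
acc, prec, rec, f1 = binary_metrics(tn=45, fp=10, fn=5, tp=40)
print(f"accuracy={acc:.3f} precision={prec:.3f} recall={rec:.3f} f1={f1:.3f}")
# accuracy=0.850 precision=0.800 recall=0.889 f1=0.842
```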

---
## Step 10: Measure Inference Speed

In [None]:
import time
import numpy as np

# Create a batch of random test images (content is irrelevant for timing)
NUM_TEST_IMAGES = 100
test_images = np.random.rand(NUM_TEST_IMAGES, IMG_SIZE, IMG_SIZE, 3).astype(np.float32)

# Warm-up run (keeps one-time setup such as graph tracing out of the measurement)
_ = model.predict(test_images[:10], verbose=0)

# Measure inference time
num_runs = 5
times = []

for _ in range(num_runs):
    start_time = time.perf_counter()  # monotonic, high-resolution timer
    _ = model.predict(test_images, verbose=0)
    elapsed = time.perf_counter() - start_time
    times.append(elapsed)

avg_time = np.mean(times)
# Amortized over batched prediction (model.predict uses batches of 32 by default)
time_per_image = (avg_time / NUM_TEST_IMAGES) * 1000  # ms per image

print("\n" + "="*50)
print("INFERENCE SPEED (MobileNetV2)")
print("="*50)
print(f"Average batch time ({NUM_TEST_IMAGES} images): {avg_time:.4f} seconds")
print(f"Time per image (amortized): {time_per_image:.2f} ms")
print(f"Theoretical FPS (batched throughput): {1000/time_per_image:.1f}")

# Get model file size
model.save('/content/mobilenet_final.keras')
model_size_mb = os.path.getsize('/content/mobilenet_final.keras') / (1024 * 1024)
print(f"\nModel file size: {model_size_mb:.2f} MB")

---
## Step 11: Save Results for Comparison

In [None]:
import json

# Compile all results
mobilenet_results = {
    'model_name': 'MobileNetV2',
    'model_type': 'classification',
    'input_size': IMG_SIZE,
    'metrics': {
        'accuracy': float(accuracy),
        'precision': float(precision),
        'recall': float(recall),
        'f1_score': float(f1)
    },
    'performance': {
        'inference_time_ms': float(time_per_image),
        'theoretical_fps': float(1000/time_per_image),
        'model_size_mb': float(model_size_mb)
    },
    'architecture': {
        'total_parameters': int(model.count_params()),
        'trainable_parameters': int(trainable_params),
        'base_model': 'MobileNetV2 (ImageNet pretrained)',
        'custom_layers': ['GlobalAveragePooling2D', 'Dense(128)', 'Dropout(0.5)', 'Dense(1, sigmoid)']
    },
    'training': {
        'epochs_phase1': len(history_frozen.history['accuracy']),
        'epochs_phase2': len(history_finetuned.history['accuracy']),
        'final_train_accuracy': float(acc[-1]),
        'final_val_accuracy': float(val_acc[-1])
    },
    'dataset': {
        'train_samples': train_generator.samples,
        'valid_samples': valid_generator.samples,
        'classes': class_names
    },
    'limitations': [
        'Cannot localize crystals (no bounding boxes)',
        'Cannot count individual crystals per image',
        'Cannot provide per-crystal confidence scores',
        'Cannot enable ROI-based filtering',
        'Cannot calculate whiteness per crystal'
    ]
}

# Save to JSON
with open('/content/mobilenet_results.json', 'w') as f:
    json.dump(mobilenet_results, f, indent=2)

print("Results saved to mobilenet_results.json")
print("\n" + "="*50)
print(json.dumps(mobilenet_results, indent=2))

---
## Step 12: Limitations Analysis - Why Classification is Not Enough

### What MobileNetV2 Classification Provides:
- **Single label per image**: "pure" or "impure"
- **Overall confidence score**: 0-100%

### What MobileNetV2 Classification CANNOT Provide:

| Feature | Classification | Detection (YOLOv8) |
|---------|---------------|--------------------|
| Crystal localization | ❌ No | ✅ Bounding boxes |
| Count crystals per image | ❌ No | ✅ Yes |
| Per-crystal confidence | ❌ No | ✅ Yes |
| Per-crystal whiteness | ❌ No | ✅ Yes |
| ROI filtering | ❌ No | ✅ Yes |
| Purity percentage | ❌ No | ✅ Yes (count-based) |
| Real-time multi-object | ❌ No | ✅ Yes |

### Why This Matters for Salt Crystal Purity Detection:

1. **Counting Requirement**: The system needs to count how many pure vs impure crystals exist to calculate purity percentage.

2. **Localization Requirement**: Bounding boxes enable whiteness calculation for each crystal region.

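3. **Batch Processing**: Per-image crystal counts can be aggregated across a production batch, so purity statistics can be tracked over time (batch-level metrics).

4. **Quality Control**: Per-crystal confidence scores allow low-confidence detections to be discarded before counting, so uncertain crystals do not skew the purity figure.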
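Once per-crystal detections exist, the count-based purity percentage and confidence filtering reduce to simple arithmetic. The sketch below assumes each detection is a `(label, confidence)` pair. That structure is hypothetical, not YOLOv8's actual result object, and the example values are the three crystals from the comparison diagram in the next cell:

```python
def purity_percentage(detections, min_confidence=0.0):
    """Percentage of kept detections labelled 'pure'.

    detections: iterable of (label, confidence) pairs, label in {'pure', 'impure'}.
    Detections below min_confidence are discarded before counting.
    """
    kept = [label for label, conf in detections if conf >= min_confidence]
    if not kept:
        return None  # nothing confident enough to count
    return 100.0 * kept.count("pure") / len(kept)


detections = [("pure", 0.95), ("impure", 0.87), ("pure", 0.92)]
print(f"{purity_percentage(detections):.1f}%")                      # 66.7%
print(f"{purity_percentage(detections, min_confidence=0.9):.1f}%")  # 100.0%
```

The second call shows why the threshold matters. Dropping the single 0.87 detection moves the purity figure from 66.7% to 100%, and an image-level classifier has no equivalent per-crystal knob.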

In [None]:
# Visual demonstration of limitation
print("\n" + "="*60)
print("CLASSIFICATION vs DETECTION OUTPUT COMPARISON")
print("="*60)

print("""
┌─────────────────────────────────────────────────────────────┐
│  SAME IMAGE - DIFFERENT MODEL OUTPUTS                       │
├─────────────────────────────┬───────────────────────────────┤
│   YOLOv8 DETECTION          │   MobileNetV2 CLASSIFICATION  │
├─────────────────────────────┼───────────────────────────────┤
│  ┌────┐  ┌────┐  ┌────┐     │                               │
│  │pure│  │imp │  │pure│     │   Prediction: "impure"        │
│  │95% │  │87% │  │92% │     │   Confidence: 67%             │
│  └────┘  └────┘  └────┘     │                               │
│                             │   (No location information)   │
│  3 crystals detected        │   (No individual counts)      │
│  2 pure, 1 impure           │   (No per-crystal metrics)    │
│  Purity: 66.7%              │                               │
│                             │                               │
│  ✅ Localization            │   ❌ No localization          │
│  ✅ Counting                │   ❌ No counting              │
│  ✅ Per-crystal confidence  │   ❌ Only image-level         │
│  ✅ Whiteness calculation   │   ❌ Cannot calculate         │
│  ✅ ROI filtering possible  │   ❌ Not possible             │
└─────────────────────────────┴───────────────────────────────┘
""")

print("\nCONCLUSION:")
print("MobileNetV2 classification achieves good accuracy for individual")
print("crystal classification, but it CANNOT fulfill the requirements")
print("of the salt crystal purity detection system which needs:")
print("  - Crystal counting per image")
print("  - Bounding box coordinates for whiteness calculation")
print("  - Per-crystal confidence scores")
print("  - ROI-based filtering capability")

---
## Step 13: Download Files

In [None]:
from google.colab import files

# Download results JSON
print("Downloading mobilenet_results.json...")
files.download('/content/mobilenet_results.json')

# Download training history plot
print("Downloading training history plot...")
files.download('/content/mobilenet_training_history.png')

# Download confusion matrix
print("Downloading confusion matrix...")
files.download('/content/mobilenet_confusion_matrix.png')

In [None]:
# Optional: Download trained model
from google.colab import files

print("Downloading trained MobileNetV2 model...")
files.download('/content/mobilenet_best.keras')

---
## Summary

### MobileNetV2 Training Complete!

This notebook trained a MobileNetV2 classification model on salt crystal images. The results are saved for comparison with YOLOv8 and ResNet50.

### Key Findings:

1. **Accuracy**: MobileNetV2 can achieve good classification accuracy on individual crystal crops.

2. **Speed**: Efficient inference due to lightweight architecture.

3. **Limitation**: Classification cannot provide the localization and counting capabilities required for the salt crystal purity detection system.

### Next Steps:
1. Run ResNet50 training notebook for comparison
2. Use the Model Comparison notebook to visualize all results
3. Generate final academic justification for YOLOv8 selection