# Notebook 2: Model Training — Cutaneous Leishmaniasis Ulcer Classification

## Purpose
Train and validate a binary classification model for CL ulcer images using **MobileNetV2** transfer learning.

- **Class 0 (Sensitive):** CL ulcers showing healing / good treatment response
- **Class 1 (Poor):** CL ulcers showing poor treatment response

## Prerequisites
- Run `preprocessing.ipynb` first to generate `processed_data.zip`.
- Upload `processed_data.zip` to this notebook when prompted.

## Model Architecture
- **Base:** MobileNetV2 (pretrained on ImageNet, frozen)
- **Head:** GlobalAveragePooling → Dropout(0.5) → Dense(1, sigmoid)
- **Loss:** Binary Crossentropy
- **Optimizer:** Adam (lr=1e-4)

## Data Split
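- **70 : 15** Training / Validation ratio, applied to all loaded images (≈ 82.4% training / 17.6% validation). No test split is carved out here — test data is provided separately in Notebook 3.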

## Output
- Trained model saved as `model.h5`
- Training / validation accuracy and loss plots

---
**Target test accuracy: ~0.6** (clinically realistic for small CL datasets)

## 1. Import Libraries

In [None]:
import os
import shutil
import zipfile
import numpy as np
import matplotlib.pyplot as plt
import cv2

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models, callbacks
from tensorflow.keras.applications import MobileNetV2
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight

# Detect Google Colab: its `files` helper drives the upload/download steps.
try:
    from google.colab import files
except ImportError:
    IN_COLAB = False
    print("Not running in Google Colab. Manual upload will be skipped.")
else:
    IN_COLAB = True

# Reproducibility: seed TensorFlow and NumPy with one fixed value.
SEED = 42
tf.random.set_seed(SEED)
np.random.seed(SEED)

# Environment report
print(f"TensorFlow version: {tf.__version__}")
_gpu_devices = tf.config.list_physical_devices('GPU')
print(f"GPU available: {len(_gpu_devices) != 0}")
print("All libraries imported successfully.")

## 2. Upload Preprocessed Data (Manual Upload)

Upload the `processed_data.zip` file generated by **Notebook 1** (`preprocessing.ipynb`).  
This ZIP should contain:
```
processed_data/
  ├── sensitive/   ← Preprocessed healing CL ulcer images
  └── poor/        ← Preprocessed poor-response CL ulcer images
```

In [None]:
# ============================================================
# UPLOAD PREPROCESSED DATA
# ============================================================

DATA_DIR = 'processed_data'
CLASSES = ['sensitive', 'poor']

if IN_COLAB:
    print("="*50)
    print("  STEP: Upload processed_data.zip")
    print("="*50)
    print("Select the processed_data.zip from Notebook 1.\n")

    # upload() returns a dict keyed by the uploaded filenames; the loop below
    # opens each one by name, i.e. it relies on the file being present in the
    # working directory after the upload.
    uploaded = files.upload()

    if not uploaded:
        print("No file was uploaded.")

    for filename in uploaded:
        # Case-insensitive so that e.g. 'PROCESSED_DATA.ZIP' is accepted too.
        if filename.lower().endswith('.zip'):
            print(f"\nExtracting '{filename}'...")
            with zipfile.ZipFile(filename, 'r') as zip_ref:
                zip_ref.extractall('.')
            print("Extraction complete.")
        else:
            print(f"⚠️  '{filename}' is not a ZIP file.")
else:
    print("Not in Colab. Ensure 'processed_data/' folder exists.")

# --------------------------------------------------
# AUTO-DETECT DIRECTORY (handles different ZIP layouts)
# --------------------------------------------------
def find_data_dir(expected_name, required_subdirs):
    """Locate the extracted dataset folder, normalising common ZIP layouts.

    Three layouts are tried in order:
      1. ``expected_name/`` already contains every entry of ``required_subdirs``.
      2. The required sub-directories sit directly in the working directory;
         they are moved under a (newly created) ``expected_name/``.
      3. A nested directory below ``.`` contains all of them. Hidden
         directories and ``__MACOSX`` are not searched, and ``.`` itself is
         never returned by this search.

    Args:
        expected_name: Preferred dataset directory, relative to the cwd.
        required_subdirs: Names of the sub-directories that must all exist.

    Returns:
        Path of the directory that contains all required sub-directories.

    Raises:
        FileNotFoundError: If none of the three layouts matches.
    """
    # Layout 1: dataset already in its expected place.
    if os.path.isdir(expected_name) and all(
        os.path.isdir(os.path.join(expected_name, sub)) for sub in required_subdirs
    ):
        return expected_name

    # Layout 2: class folders were extracted at the root -> gather them.
    if all(os.path.isdir(sub) for sub in required_subdirs):
        os.makedirs(expected_name, exist_ok=True)
        moves = [(sub, os.path.join(expected_name, sub)) for sub in required_subdirs]
        for source, target in moves:
            # Never overwrite a same-named folder already inside expected_name.
            if not os.path.exists(target):
                shutil.move(source, target)
        return expected_name

    # Layout 3: search the tree for a nested folder holding every class folder.
    for current, subdirs, _unused_files in os.walk('.'):
        # Prune in place so os.walk does not descend into hidden/macOS metadata dirs.
        subdirs[:] = [
            name for name in subdirs
            if not (name.startswith('.') or name == '__MACOSX')
        ]
        if current == '.':
            continue
        if all(sub in subdirs for sub in required_subdirs):
            return current

    raise FileNotFoundError(
        f"Could not find directory with subdirectories {required_subdirs}.\n"
        f"Check your ZIP structure."
    )

# Resolve the actual dataset location; find_data_dir may move extracted
# class folders under DATA_DIR or return a nested path instead.
DATA_DIR = find_data_dir(DATA_DIR, CLASSES)
print(f"\n✅ Using data directory: '{DATA_DIR}/'")

## 3. Load and Prepare Data

Load preprocessed CL ulcer images, assign binary labels, and prepare for the model.

In [None]:
# ============================================================
# LOAD PREPROCESSED IMAGES AND ASSIGN LABELS
# Class 0: sensitive (healing CL ulcers)
# Class 1: poor (poor-response CL ulcers)
# ============================================================

IMG_SIZE = 224
VALID_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff'}
IGNORE_FILES = {'.ds_store', 'thumbs.db', 'desktop.ini'}  # compared lower-cased

# Class mapping: folder name -> binary label
CLASS_MAP = {
    'sensitive': 0,  # Healing / good treatment response
    'poor': 1        # Poor treatment response
}

images = []
labels = []

for class_name, label in CLASS_MAP.items():
    class_dir = os.path.join(DATA_DIR, class_name)

    if not os.path.isdir(class_dir):
        raise FileNotFoundError(
            f"Directory '{class_dir}' not found.\n"
            f"Ensure {DATA_DIR}/ has 'sensitive/' and 'poor/' subdirectories."
        )

    # Sorted so the load order (and hence the seeded split) is deterministic.
    image_files = sorted(
        f for f in os.listdir(class_dir)
        if os.path.splitext(f)[1].lower() in VALID_EXTENSIONS
        and not f.startswith('.')          # hidden files, e.g. macOS '._*' metadata
        and f.lower() not in IGNORE_FILES
    )

    print(f"Loading class '{class_name}' (label={label}): {len(image_files)} images")

    for fname in image_files:
        img_path = os.path.join(class_dir, fname)

        # Load as grayscale (preprocessed L-channel images are single-channel)
        img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)

        # cv2.imread signals an unreadable/undecodable file by returning None.
        if img is None:
            print(f"  ⚠️  Skipping unreadable file: {fname}")
            continue

        # Ensure correct dimensions
        if img.shape[:2] != (IMG_SIZE, IMG_SIZE):
            img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))

        # Normalize to [0, 1]
        img = img.astype(np.float32) / 255.0

        images.append(img)
        labels.append(label)

# Convert to numpy arrays
if not images:
    # DATA_DIR may have been auto-detected, so name the folder actually used.
    raise ValueError(f"No images were loaded! Check your {DATA_DIR}/ folder.")

X = np.array(images)
y = np.array(labels)

print(f"\nTotal loaded: {len(X)} images")
print(f"  Sensitive (label 0): {np.sum(y == 0)}")
print(f"  Poor (label 1):      {np.sum(y == 1)}")
print(f"  Image shape: {X[0].shape}")

In [None]:
# ============================================================
# PREPARE DATA FOR MOBILENETV2
# MobileNetV2 expects 3-channel (RGB) input, so the single
# grayscale L-channel is copied into all three channels. The
# CLAHE-enhanced texture information is kept unchanged while
# the array matches the pretrained model's input format.
# ============================================================

# (N, IMG_SIZE, IMG_SIZE) -> new trailing channel axis -> tiled 3x -> (N, IMG_SIZE, IMG_SIZE, 3)
X = np.repeat(X[..., np.newaxis], repeats=3, axis=-1)

_pixel_lo, _pixel_hi = X.min(), X.max()
print(f"Model input shape: {X.shape}")
print(f"Labels shape:      {y.shape}")
print(f"Pixel range:       [{_pixel_lo:.4f}, {_pixel_hi:.4f}]")

## 4. Train / Validation Split

Split the data in a **70 : 15** training-to-validation ratio.

The remaining 15% is reserved conceptually for testing, which is handled separately in Notebook 3 with its own unseen test data upload. To maximize usage of every available image, we perform a direct 70/15 ratio split on all loaded data:
- Training:   82.4% of loaded data  (= 70 / (70+15))
- Validation: 17.6% of loaded data  (= 15 / (70+15))

In [None]:
# ============================================================
# TRAIN / VALIDATION SPLIT
# 70% train : 15% validation ratio
# Split ratio: val_size = 15/(70+15) ≈ 0.176
# Stratified (when the class counts allow it) to maintain class balance
# ============================================================

VAL_RATIO = 15.0 / (70.0 + 15.0)  # ≈ 0.176

n_sensitive = int(np.sum(y == 0))
n_poor = int(np.sum(y == 1))
min_class_count = min(n_sensitive, n_poor)

# A stratified split needs at least 2 samples of every class, and a model
# cannot learn a class it never sees, so require >= 2 images per class.
if min_class_count < 2:
    raise ValueError(
        f"Each class must have at least 2 images. "
        f"Found: sensitive={n_sensitive}, poor={n_poor}"
    )

# train_test_split sizes the validation set as ceil(VAL_RATIO * N), and
# stratification needs room for one sample of each class there. With
# >= 4 images per class, N >= 8, so the validation set has >= 2 slots.
use_stratify = min_class_count >= 4

try:
    X_train, X_val, y_train, y_val = train_test_split(
        X, y,
        test_size=VAL_RATIO,
        random_state=SEED,
        stratify=y if use_stratify else None
    )
except ValueError as e:
    if not use_stratify:
        # The split was already non-stratified; retrying the same call cannot help.
        raise
    # Fallback: if stratified split fails due to tiny dataset, use non-stratified
    print(f"  ⚠️  Stratified split failed ({e}). Using non-stratified split.")
    X_train, X_val, y_train, y_val = train_test_split(
        X, y,
        test_size=VAL_RATIO,
        random_state=SEED,
        stratify=None
    )

print("Data split summary:")
print(f"  Training:   {len(X_train)} images ({len(X_train)/len(X)*100:.1f}%)")
print(f"    Sensitive: {np.sum(y_train == 0)}, Poor: {np.sum(y_train == 1)}")
print(f"  Validation: {len(X_val)} images ({len(X_val)/len(X)*100:.1f}%)")
print(f"    Sensitive: {np.sum(y_val == 0)}, Poor: {np.sum(y_val == 1)}")

## 5. Build Model — MobileNetV2 Transfer Learning

### Architecture
| Layer | Description |
|-------|-------------|
| MobileNetV2 | Pretrained on ImageNet, **frozen** (no fine-tuning) |
| GlobalAveragePooling2D | Reduces spatial dims to feature vector |
| Dropout(0.5) | Regularization for small dataset |
| Dense(1, sigmoid) | Binary output: P(poor-response) |

### Why freeze the base?
With a small medical dataset, fine-tuning all layers would very likely cause **severe overfitting**.
The frozen MobileNetV2 base provides robust, general-purpose features (edges, textures, shapes)
learned on ImageNet, which transfer well to ulcer morphology analysis.

In [None]:
# ============================================================
# BUILD MODEL: MobileNetV2 + Classification Head
# ============================================================

def build_model(input_shape=(224, 224, 3)):
    """
    Binary classifier using MobileNetV2 transfer learning.

    Architecture:
        Rescaling([0, 1] -> [-1, 1]) → MobileNetV2 (frozen)
        → GlobalAvgPool → Dropout(0.5) → Dense(1, sigmoid)

    The model accepts images scaled to [0, 1], as produced by the image
    loading cell (pixel / 255). The ImageNet-pretrained MobileNetV2 weights
    expect inputs in [-1, 1], which is what
    `keras.applications.mobilenet_v2.preprocess_input` produces. The
    Rescaling layer applies that mapping inside the model, so the saved
    model keeps accepting [0, 1] inputs and no separate preprocessing step
    can be forgotten at inference time.

    Args:
        input_shape: (height, width, channels) of the input images.

    Returns:
        Compiled Keras model whose output is P(class=poor).
    """
    # Load pretrained MobileNetV2 (without ImageNet classification head)
    base_model = MobileNetV2(
        input_shape=input_shape,
        include_top=False,
        weights='imagenet'
    )

    # Freeze all base layers — prevent weight updates
    base_model.trainable = False

    # Build classification head
    model = models.Sequential([
        keras.Input(shape=input_shape),
        # x * 2 - 1 maps [0, 1] to [-1, 1], the range MobileNetV2 was trained on
        layers.Rescaling(scale=2.0, offset=-1.0),
        base_model,
        layers.GlobalAveragePooling2D(),
        layers.Dropout(0.5),                    # Regularization for small dataset
        layers.Dense(1, activation='sigmoid'),  # P(class=poor)
    ])

    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=1e-4),
        loss='binary_crossentropy',
        metrics=['accuracy']
    )

    return model


# Match the image size used when loading the data.
model = build_model(input_shape=(IMG_SIZE, IMG_SIZE, 3))
model.summary()

## 6. Train the Model

- **Epochs:** 15 (with early stopping)
- **Batch size:** 16
- **Class weights:** Computed to handle class imbalance
- **Early stopping:** patience=5 on validation loss, restores best weights

In [None]:
# ============================================================
# TRAINING CONFIGURATION
# ============================================================

EPOCHS = 15
BATCH_SIZE = 16

# --- Class weights ---
# Medical datasets are often imbalanced. 'balanced' weights are
# n_samples / (n_classes * class_count), so the minority class gets the
# larger weight and its misclassifications are penalized more heavily.
unique_classes = np.unique(y_train)
if len(unique_classes) >= 2:
    cw_array = compute_class_weight('balanced', classes=unique_classes, y=y_train)
    # Plain Python int/float entries give a readable printout.
    class_weights = {int(c): float(w) for c, w in zip(unique_classes, cw_array)}
else:
    # Only one class in training set — no weighting possible
    print("⚠️  WARNING: Only one class in training data. Class weighting disabled.")
    class_weights = None

print(f"Class weights: {class_weights}")

# --- Early Stopping ---
# Stops training when validation loss has not improved for `patience` epochs.
# restore_best_weights asks Keras to roll back to the lowest-val_loss epoch.
early_stop = callbacks.EarlyStopping(
    monitor='val_loss',
    patience=5,
    restore_best_weights=True,
    verbose=1
)

print("\nTraining configuration:")
print(f"  Epochs:      {EPOCHS}")
print(f"  Batch size:  {BATCH_SIZE}")
print(f"  Train size:  {len(X_train)}")
print(f"  Val size:    {len(X_val)}")

In [None]:
# ============================================================
# TRAIN THE MODEL
# ============================================================

print("Starting model training...\n")

history = model.fit(
    X_train, y_train,
    validation_data=(X_val, y_val),
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    class_weight=class_weights,
    callbacks=[early_stop],
    verbose=1
)

# In tf.keras 2.x, EarlyStopping(restore_best_weights=True) restores the best
# weights only when it actually stops training early. If all EPOCHS run, the
# model would keep its last-epoch weights. Restoring explicitly guarantees
# that the model (and the saved model.h5) holds the lowest-val_loss weights
# regardless of Keras version. Where Keras already restored them, this is a
# harmless no-op. getattr guards against versions lacking these attributes.
best_weights = getattr(early_stop, 'best_weights', None)
if best_weights is not None:
    model.set_weights(best_weights)
    best_epoch_idx = getattr(early_stop, 'best_epoch', None)  # 0-based epoch index
    if best_epoch_idx is not None:
        print(f"\nRestored weights from epoch {best_epoch_idx + 1} (lowest val_loss).")
    else:
        print("\nRestored weights from the epoch with the lowest val_loss.")

print("\n✅ Training complete!")

## 7. Training Results

Display final metrics and plot training curves.

In [None]:
# ============================================================
# DISPLAY TRAINING METRICS
# ============================================================

hist = history.history

final_epoch = len(hist['accuracy'])
train_acc = hist['accuracy'][-1]
val_acc = hist['val_accuracy'][-1]
train_loss = hist['loss'][-1]
val_loss = hist['val_loss'][-1]

print(f"Training completed after {final_epoch} epoch(s)")
print(f"{'='*45}")
print(f"  Final Training Accuracy:    {train_acc:.4f}")
print(f"  Final Validation Accuracy:  {val_acc:.4f}")
print(f"  Final Training Loss:        {train_loss:.4f}")
print(f"  Final Validation Loss:      {val_loss:.4f}")
print(f"{'='*45}")

# EarlyStopping monitors val_loss with restore_best_weights=True, so the
# weights it aims to keep come from the lowest-val_loss epoch, not the last
# one. Report that epoch's metrics. argmin takes the first minimum, matching
# EarlyStopping's strict-improvement rule.
restored_epoch = int(np.argmin(hist['val_loss'])) + 1
print(f"  Lowest Validation Loss:     {hist['val_loss'][restored_epoch - 1]:.4f} (Epoch {restored_epoch})")
print(f"  Val Accuracy at that epoch: {hist['val_accuracy'][restored_epoch - 1]:.4f}")

# Peak accuracy is informative, but it is not necessarily the epoch kept.
best_val_acc = max(hist['val_accuracy'])
best_epoch = int(np.argmax(hist['val_accuracy'])) + 1
print(f"  Peak Validation Accuracy:   {best_val_acc:.4f} (Epoch {best_epoch})")

In [None]:
# ============================================================
# PLOT: Accuracy and Loss vs Epoch
# ============================================================

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

epochs_range = range(1, final_epoch + 1)

# Epoch with the lowest validation loss — the one EarlyStopping
# (monitor='val_loss', restore_best_weights=True) aims to keep.
best_loss_epoch = int(np.argmin(history.history['val_loss'])) + 1

# (history key, display name) for the accuracy and loss panels
panels = [('accuracy', 'Accuracy'), ('loss', 'Loss')]

for ax, (key, name) in zip(axes, panels):
    ax.plot(epochs_range, history.history[key],
            'b-o', label=f'Training {name}', linewidth=2, markersize=5)
    ax.plot(epochs_range, history.history[f'val_{key}'],
            'r-s', label=f'Validation {name}', linewidth=2, markersize=5)
    ax.axvline(best_loss_epoch, color='gray', linestyle='--', alpha=0.7,
               label=f'Lowest val loss (epoch {best_loss_epoch})')
    ax.set_title(f'{name} vs Epoch', fontsize=13, fontweight='bold')
    ax.set_xlabel('Epoch', fontsize=11)
    ax.set_ylabel(name, fontsize=11)
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)
    ax.set_xticks(list(epochs_range))

plt.suptitle('CL Ulcer Classification — MobileNetV2 Training History',
             fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

## 8. Save Trained Model

Save as `model.h5` for use in **Notebook 3** (`model_testing.ipynb`).

In [None]:
# ============================================================
# SAVE TRAINED MODEL
# ============================================================

MODEL_PATH = 'model.h5'

# The '.h5' extension selects the HDF5 format expected by Notebook 3.
model.save(MODEL_PATH)
print(f"✅ Model saved as: {MODEL_PATH}")

model_size_mb = os.path.getsize(MODEL_PATH) / (1024 * 1024)
print(f"   File size: {model_size_mb:.2f} MB")

# Download
if IN_COLAB:
    files.download(MODEL_PATH)
    print("\n📥 Download started.")
else:
    print("\nModel saved in working directory.")

print("\n" + "="*50)
print("  NOTEBOOK 2 COMPLETE")
print("="*50)
print("\nNext step:")
print("  1. Open model_testing.ipynb")
print("  2. Upload model.h5 when prompted")
print("  3. Upload your test dataset ZIP when prompted")