<a href="https://colab.research.google.com/github/armelyara/drgreen/blob/claude/drgreen-v2-01TfLAqRxjEF2BkLLt72vJrL/drgreen_v7_stratified.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Dr Green V7 - STRATIFIED Split (Critical Fix)

**V6 Problem: All validation samples were from ONE class (kinkeliba)**

This was caused by `image_dataset_from_directory` not doing stratified splits.

**V7 Fix:**
- Uses sklearn's `train_test_split` with `stratify` parameter
- Guarantees each class is represented in both train and validation
- Same architecture as V6 (MobileNetV2 + Focal Loss)

### Target: Proper evaluation across all 4 classes

## 1. Setup & Imports

In [None]:
# Install gdown for dataset download
!pip install -q gdown

# Core imports
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from pathlib import Path
import json
from datetime import datetime
from PIL import Image
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import train_test_split
import zipfile
import os
import gdown
import glob

# Report the TensorFlow build, then whether any GPU is visible to it.
print(f"TensorFlow version: {tf.__version__}")
gpu_devices = tf.config.list_physical_devices('GPU')
has_gpu = len(gpu_devices) > 0
print(f"GPU available: {has_gpu}")

# Check GPU
print(
    "GPU detected - training will be fast!"
    if has_gpu
    else "No GPU - training will be slow. Enable GPU in Runtime > Change runtime type"
)

In [None]:
# Download the dataset archive from Google Drive and unpack it under /content.
# Any failure falls back to printing manual-upload instructions together with
# the reason, so the problem is not hidden.
file_id = '1zI5KfTtuV0BlBQnNDNq4tBJuEkxLZZBD'
url = f'https://drive.google.com/uc?id={file_id}'
output = '/content/drgreen.zip'

print("Downloading dataset from Google Drive...")
try:
    downloaded = gdown.download(url, output, quiet=False)
    # NOTE(review): some gdown releases signal failure by returning None
    # instead of raising - confirm against the installed version.
    if downloaded is None:
        raise RuntimeError("gdown did not produce an output file")
    print("Dataset downloaded!")

    # Extract
    print("\nExtracting...")
    with zipfile.ZipFile(output, 'r') as zip_ref:
        zip_ref.extractall('/content')
    print("Dataset extracted!")
except Exception as exc:
    # `Exception` (not a bare `except:`) lets KeyboardInterrupt/SystemExit
    # propagate, so the cell can still be interrupted.
    print("\nAuto-download failed. Please upload manually:")
    print("from google.colab import files")
    print("uploaded = files.upload()")
    print(f"\nReason: {type(exc).__name__}: {exc}")

## 2. Configuration

In [None]:
# V7 configuration: same settings family as V6, now paired with a stratified split.
CONFIG = {
    # --- Paths ---
    'data_dir': 'rename',
    'model_save_dir': 'models',

    # --- Image parameters ---
    'img_height': 224,
    'img_width': 224,
    'batch_size': 16,

    # --- Training parameters ---
    'epochs': 100,
    'initial_lr': 0.0005,

    'validation_split': 0.2,
    'seed': 42,

    # --- Model parameters ---
    'base_model': 'MobileNetV2',
    'dropout_rate': 0.6,  # Slightly lower than V6
    'num_classes': 4,
    'dense_units': 64,

    # --- Regularization ---
    'l2_reg': 0.02,
    'label_smoothing': 0.15,

    # --- Focal Loss parameters ---
    'focal_gamma': 2.0,
    'focal_alpha': 0.25,

    # --- Callbacks ---
    'early_stopping_patience': 15,
    'reduce_lr_patience': 5,
    'reduce_lr_factor': 0.5,
    'min_lr': 1e-7,
}

PLANT_CLASSES = ['artemisia', 'carica', 'goyavier', 'kinkeliba']

# Make sure the output folder for checkpoints and logs exists.
save_dir = Path(CONFIG['model_save_dir'])
save_dir.mkdir(exist_ok=True)

# Run banner summarising the key settings.
banner = "=" * 60
print(f"\n{banner}")
print("DR GREEN V7 - STRATIFIED SPLIT")
print(banner)
print("\nBase Model: {} (FROZEN)".format(CONFIG['base_model']))
print("Image Size: {}x{}".format(CONFIG['img_height'], CONFIG['img_width']))
print("Batch Size: {}".format(CONFIG['batch_size']))
print("\nCRITICAL FIX: Using sklearn stratified split")
print("This ensures all classes are in both train and validation!")

## 3. Focal Loss Implementation

In [None]:
class FocalLoss(tf.keras.losses.Loss):
    """Focal loss with optional label smoothing for multi-class probabilities.

    The cross-entropy of each sample is scaled by ``alpha * (1 - p_t) ** gamma``,
    where ``p_t`` is the sum of ``y_true * y_pred`` over the class axis, so
    samples the model already scores highly contribute less.

    Args:
        gamma: Exponent applied to ``1 - p_t``.
        alpha: Constant multiplier applied to every sample's loss.
        label_smoothing: Fraction of the target mass spread uniformly over
            all classes before the loss is computed.
        **kwargs: Forwarded to ``tf.keras.losses.Loss``.
    """

    def __init__(self, gamma=2.0, alpha=0.25, label_smoothing=0.0, **kwargs):
        super().__init__(**kwargs)
        self.gamma = gamma
        self.alpha = alpha
        self.label_smoothing = label_smoothing

    def call(self, y_true, y_pred):
        """Return one focal-loss value per sample.

        Args:
            y_true: Targets whose last axis indexes the classes
                (presumably one-hot floats - confirm against the caller).
            y_pred: Predicted class probabilities; the logarithm is taken
                directly, so these must not be logits.

        Returns:
            The loss reduced over the class axis only. The batch axis is left
            intact on purpose: Keras applies ``sample_weight`` /
            ``class_weight`` per sample and then performs the configured
            reduction. Averaging here would collapse the batch to a scalar
            and turn the per-sample weights into a batch-level rescaling.
        """
        num_classes = tf.cast(tf.shape(y_true)[-1], tf.float32)
        y_true = y_true * (1.0 - self.label_smoothing) + (self.label_smoothing / num_classes)

        # Clip so that log() never sees exactly 0 or 1.
        eps = tf.keras.backend.epsilon()
        y_pred = tf.clip_by_value(y_pred, eps, 1 - eps)

        cross_entropy = -y_true * tf.math.log(y_pred)
        p_t = tf.reduce_sum(y_true * y_pred, axis=-1)
        focal_weight = tf.pow(1 - p_t, self.gamma)
        return self.alpha * focal_weight * tf.reduce_sum(cross_entropy, axis=-1)

    def get_config(self):
        """Return the constructor arguments so the loss can be re-created on load."""
        config = super().get_config()
        config.update({
            'gamma': self.gamma,
            'alpha': self.alpha,
            'label_smoothing': self.label_smoothing
        })
        return config

print("Focal Loss implemented")

## 4. STRATIFIED Dataset Loading (Critical Fix)

In [None]:
# Discover the class folders and collect every image path with its class index.
data_dir = Path(CONFIG['data_dir'])

# Hidden folders (e.g. the `.ipynb_checkpoints` that Jupyter/Colab can create)
# are not plant classes and must not receive a label index.
class_names = sorted(
    d.name for d in data_dir.iterdir()
    if d.is_dir() and not d.name.startswith('.')
)
print(f"Classes found: {class_names}")

# The labels are later one-hot encoded with CONFIG['num_classes'] and the model
# has that many outputs, so a mismatch would silently corrupt the targets.
if len(class_names) != CONFIG['num_classes']:
    raise ValueError(
        f"Found {len(class_names)} class folders in '{data_dir}' "
        f"but CONFIG['num_classes'] is {CONFIG['num_classes']}"
    )

# Extensions are compared case-insensitively so each file is counted exactly
# once, whatever its capitalisation and whatever the filesystem.
IMAGE_SUFFIXES = {'.jpg', '.jpeg', '.png'}

# Collect all image paths and labels
all_image_paths = []
all_labels = []

for class_idx, class_name in enumerate(class_names):
    class_dir = data_dir / class_name
    # Sorting removes the dependence on filesystem listing order, which is
    # what makes the seeded train/validation split reproducible.
    image_files = sorted(
        p for p in class_dir.iterdir()
        if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES
    )
    all_image_paths.extend(str(p) for p in image_files)
    all_labels.extend([class_idx] * len(image_files))

all_image_paths = np.array(all_image_paths)
all_labels = np.array(all_labels)

print(f"\nTotal images: {len(all_image_paths)}")
for i, name in enumerate(class_names):
    count = (all_labels == i).sum()
    print(f"  {name}: {count} images")

In [None]:
# STRATIFIED SPLIT using sklearn: `stratify` keeps the class proportions of
# `all_labels` in both resulting subsets.
split_options = dict(
    test_size=CONFIG['validation_split'],
    random_state=CONFIG['seed'],
    stratify=all_labels,  # CRITICAL: This ensures balanced split!
)
train_paths, val_paths, train_labels, val_labels = train_test_split(
    all_image_paths, all_labels, **split_options
)

rule = "=" * 60
print(f"\n{rule}")
print("STRATIFIED SPLIT RESULTS")
print(rule)

# Per-class counts and percentages for each subset.
for split_title, split_paths, split_labels in (
    ("Training set", train_paths, train_labels),
    ("Validation set", val_paths, val_labels),
):
    print(f"\n{split_title}: {len(split_paths)} images")
    n_total = len(split_labels)
    for class_idx, name in enumerate(class_names):
        n_class = (split_labels == class_idx).sum()
        share = n_class / n_total * 100
        print(f"  {name}: {n_class} images ({share:.1f}%)")

print("\nAll classes are now properly represented in validation!")

In [None]:
# Build tf.data pipelines from the stratified path/label arrays.
def load_and_preprocess_image(path, label):
    """Read the file at `path`, decode it to 3 channels and resize it.

    The target size comes from CONFIG['img_height'] / CONFIG['img_width'].
    The label is passed through untouched.
    """
    target_size = [CONFIG['img_height'], CONFIG['img_width']]
    raw_bytes = tf.io.read_file(path)
    decoded = tf.image.decode_jpeg(raw_bytes, channels=3)
    return tf.image.resize(decoded, target_size), label


def _to_one_hot(image, label):
    """Replace the integer label by a one-hot vector of CONFIG['num_classes'] entries."""
    return image, tf.one_hot(label, CONFIG['num_classes'])


def _make_dataset(paths, labels):
    """Turn parallel arrays of paths and labels into an (image, one-hot) dataset."""
    ds = tf.data.Dataset.from_tensor_slices((paths, labels))
    ds = ds.map(load_and_preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
    return ds.map(_to_one_hot)


# Create datasets
train_ds = _make_dataset(train_paths, train_labels)
val_ds = _make_dataset(val_paths, val_labels)

print("Datasets created from stratified split")

## 5. Data Augmentation & Pipeline

In [None]:
# Random image transforms; only the training pipeline uses them.
augmentation_layers = [
    tf.keras.layers.RandomFlip("horizontal_and_vertical"),
    tf.keras.layers.RandomRotation(0.3),
    tf.keras.layers.RandomZoom(0.2),
    tf.keras.layers.RandomBrightness(0.2),
    tf.keras.layers.RandomContrast(0.2),
    tf.keras.layers.RandomTranslation(0.15, 0.15),
]
data_augmentation = tf.keras.Sequential(augmentation_layers, name="data_augmentation")

# Input preprocessing function that ships with MobileNetV2.
preprocess_input = tf.keras.applications.mobilenet_v2.preprocess_input

AUTOTUNE = tf.data.AUTOTUNE


def _augment(image, label):
    """Apply the random augmentation stack in training mode."""
    return data_augmentation(image, training=True), label


def _apply_preprocess(image, label):
    """Run the MobileNetV2 input preprocessing on the image."""
    return preprocess_input(image), label


# Training: augment -> preprocess -> shuffle -> batch -> prefetch.
train_ds = (
    train_ds
    .map(_augment, num_parallel_calls=AUTOTUNE)
    .map(_apply_preprocess, num_parallel_calls=AUTOTUNE)
    .shuffle(1000)
    .batch(CONFIG['batch_size'])
    .prefetch(AUTOTUNE)
)

# Validation: preprocess -> batch -> prefetch (no augmentation, no shuffle).
val_ds = (
    val_ds
    .map(_apply_preprocess, num_parallel_calls=AUTOTUNE)
    .batch(CONFIG['batch_size'])
    .prefetch(AUTOTUNE)
)

print("Data pipeline configured")

In [None]:
# Class weights: inverse class frequency in the training split, raised to a
# moderate power (1.3).
total_train = len(train_labels)
n_classes = len(class_names)
per_class_counts = [(train_labels == idx).sum() for idx in range(n_classes)]

class_weights = {
    idx: (total_train / (n_classes * n_samples)) ** 1.3
    for idx, n_samples in enumerate(per_class_counts)
}

print("Class weights:")
for idx in class_weights:
    print(f"  {class_names[idx]}: {class_weights[idx]:.3f}")

## 6. Build Model

In [None]:
def build_model():
    """Build the classifier: frozen MobileNetV2 backbone plus a small dense head.

    Returns:
        A ``(model, base_model)`` tuple: the full Keras model and the
        MobileNetV2 backbone it is built on.
    """
    input_shape = (CONFIG['img_height'], CONFIG['img_width'], 3)
    inputs = tf.keras.Input(shape=input_shape)

    backbone = tf.keras.applications.MobileNetV2(
        include_top=False,
        weights='imagenet',
        input_tensor=inputs,
        pooling='avg'
    )
    backbone.trainable = False

    # Classification head, applied in order on top of the pooled features.
    head_layers = [
        tf.keras.layers.Dropout(CONFIG['dropout_rate']),
        tf.keras.layers.Dense(
            CONFIG['dense_units'],
            activation='relu',
            kernel_regularizer=tf.keras.regularizers.l2(CONFIG['l2_reg']),
            kernel_initializer='he_normal'
        ),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Dropout(CONFIG['dropout_rate'] * 0.5),
        tf.keras.layers.Dense(
            CONFIG['num_classes'],
            activation='softmax',
            kernel_regularizer=tf.keras.regularizers.l2(CONFIG['l2_reg'])
        ),
    ]
    features = backbone.output
    for layer in head_layers:
        features = layer(features)

    classifier = tf.keras.Model(inputs, features, name='DrGreen_V7_Stratified')
    return classifier, backbone

model, base_model = build_model()

# Parameter summary: total vs. trainable weights.
print("\nModel Architecture:")
print(f"  Base: {CONFIG['base_model']} (FROZEN)")
print(f"  Total parameters: {model.count_params():,}")
trainable = sum(tf.size(var).numpy() for var in model.trainable_variables)
print(f"  Trainable parameters: {trainable:,}")

## 7. Compile & Train

In [None]:
# Learning rate schedule: cosine decay over the whole planned training run.
# The training dataset is batched without dropping the final partial batch, so
# the number of optimizer steps per epoch is the ceiling of samples / batch
# size. Floor division would under-count the steps and yields 0 (hence
# decay_steps=0) when there are fewer samples than one batch.
steps_per_epoch = max(1, int(np.ceil(len(train_labels) / CONFIG['batch_size'])))
total_steps = steps_per_epoch * CONFIG['epochs']

lr_schedule = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=CONFIG['initial_lr'],
    decay_steps=total_steps,
    alpha=0.01
)

# Adam on the cosine schedule, focal loss, and top-1 / top-2 accuracy metrics.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=lr_schedule),
    loss=FocalLoss(
        gamma=CONFIG['focal_gamma'],
        alpha=CONFIG['focal_alpha'],
        label_smoothing=CONFIG['label_smoothing']
    ),
    metrics=[
        tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
        tf.keras.metrics.TopKCategoricalAccuracy(k=2, name='top2_accuracy')
    ]
)

print("Model compiled with Focal Loss")

In [None]:
# Training callbacks, all driven by validation accuracy.
#
# ReduceLROnPlateau is intentionally not included: the optimizer is compiled
# with a CosineDecay LearningRateSchedule, and a schedule-based optimizer does
# not allow its learning rate to be overwritten. The callback would therefore
# fail the first time validation accuracy plateaued. The cosine schedule
# already lowers the learning rate over the run.
callbacks = [
    # Stop once val_accuracy has not improved for `early_stopping_patience`
    # epochs and roll back to the best weights seen.
    tf.keras.callbacks.EarlyStopping(
        monitor='val_accuracy',
        patience=CONFIG['early_stopping_patience'],
        restore_best_weights=True,
        mode='max',
        verbose=1
    ),
    # Keep only the checkpoint with the highest val_accuracy on disk.
    tf.keras.callbacks.ModelCheckpoint(
        filepath=f"{CONFIG['model_save_dir']}/best_model_v7.keras",
        monitor='val_accuracy',
        save_best_only=True,
        mode='max',
        verbose=1
    ),
    # Per-epoch metrics written as CSV next to the checkpoints.
    tf.keras.callbacks.CSVLogger(
        f"{CONFIG['model_save_dir']}/training_log_v7.csv"
    )
]

print("Callbacks configured")

In [None]:
# Announce the run, then train with class weighting and the configured callbacks.
rule = "=" * 60
print(f"\n{rule}")
print("STARTING TRAINING")
print(rule)
for caption, value in (
    ("Training samples", len(train_labels)),
    ("Validation samples", len(val_labels)),
    ("Batch size", CONFIG['batch_size']),
    ("Steps per epoch", steps_per_epoch),
):
    print(f"{caption}: {value}")
print(f"{rule}\n")

fit_options = dict(
    validation_data=val_ds,
    epochs=CONFIG['epochs'],
    callbacks=callbacks,
    class_weight=class_weights,
    verbose=1,
)
history = model.fit(train_ds, **fit_options)

print("\nTraining completed!")

## 8. Visualize Results

In [None]:
# Plot training history: one panel per tracked metric, train vs. validation.
curves = history.history
panels = (
    # (history key, panel title, y-axis label)
    ('accuracy', 'Model Accuracy', 'Accuracy'),
    ('loss', 'Model Loss', 'Loss'),
    ('top2_accuracy', 'Top-2 Accuracy', 'Top-2 Accuracy'),
)

fig, axes = plt.subplots(1, 3, figsize=(18, 5))
for ax, (metric, title, y_label) in zip(axes, panels):
    ax.plot(curves[metric], label='Train')
    ax.plot(curves[f'val_{metric}'], label='Validation')
    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(y_label)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Summary numbers. "Final" means the last epoch that ran; "best" is the
# highest validation accuracy in the history and the 1-based epoch it occurred in.
train_acc_curve = curves['accuracy']
val_acc_curve = curves['val_accuracy']

final_train_acc = train_acc_curve[-1]
final_val_acc = val_acc_curve[-1]
best_val_acc = max(val_acc_curve)
best_epoch = val_acc_curve.index(best_val_acc) + 1
overfitting_gap = abs(final_train_acc - final_val_acc)

rule = "=" * 60
print(f"\n{rule}")
print("FINAL METRICS")
print(rule)
print(f"Final Train Accuracy: {final_train_acc*100:.2f}%")
print(f"Final Val Accuracy:   {final_val_acc*100:.2f}%")
print(f"Best Val Accuracy:    {best_val_acc*100:.2f}%")
print(f"Best Epoch:           {best_epoch}")
print(f"Overfitting Gap:      {overfitting_gap*100:.2f}%")

## 9. Detailed Evaluation

In [None]:
# Reload the checkpoint written by ModelCheckpoint; the custom loss class must
# be supplied so the saved model can be deserialized.
best_model_path = f"{CONFIG['model_save_dir']}/best_model_v7.keras"
best_model = tf.keras.models.load_model(
    best_model_path,
    custom_objects={'FocalLoss': FocalLoss},
)
print("Best model loaded")

# Evaluate on the validation pipeline. `results` holds the loss followed by
# the compiled metrics (accuracy, then top-2 accuracy).
results = best_model.evaluate(val_ds, verbose=1)
val_acc_pct = results[1] * 100
val_top2_pct = results[2] * 100
print(f"\nValidation Accuracy: {val_acc_pct:.2f}%")
print(f"Top-2 Accuracy: {val_top2_pct:.2f}%")

In [None]:
# Get predictions
print("Generating predictions...")

# Evaluation dataset without augmentation. It is neither shuffled nor mapped
# non-deterministically, so its elements come out in the order of
# `val_paths` / `val_labels`.
val_ds_eval = (
    tf.data.Dataset.from_tensor_slices((val_paths, val_labels))
    .map(load_and_preprocess_image, num_parallel_calls=AUTOTUNE)
    .map(lambda x, y: (preprocess_input(x), y), num_parallel_calls=AUTOTUNE)
    .batch(CONFIG['batch_size'])
)

# One predict() call over the whole dataset instead of one call per batch in a
# Python loop: predict() carries per-call setup overhead, and when it is given
# a dataset of (inputs, targets) pairs it uses only the inputs.
predictions = best_model.predict(val_ds_eval, verbose=0)
y_pred = np.argmax(predictions, axis=1)

# Because the dataset order is preserved, the ground truth is `val_labels`.
y_true = np.asarray(val_labels)

accuracy = np.mean(y_true == y_pred)
print(f"\nAccuracy: {accuracy*100:.2f}%")

In [None]:
# Prediction distribution: how often each class is predicted, with a warning
# for any class that receives more than half of all predictions.
rule = "=" * 60
print(f"\n{rule}")
print("PREDICTION DISTRIBUTION")
print(rule)

pred_counts = dict.fromkeys(class_names, 0)
for predicted_idx in y_pred:
    pred_counts[class_names[predicted_idx]] += 1

n_predictions = len(y_pred)
pred_shares = {
    name: n_hits / n_predictions * 100
    for name, n_hits in pred_counts.items()
}
collapse_detected = any(share > 50 for share in pred_shares.values())

for name, share in pred_shares.items():
    suffix = " - WARNING" if share > 50 else ""
    print(f"  {name}: {pred_counts[name]} ({share:.1f}%){suffix}")

print("\nClass collapse detected!" if collapse_detected else "\nPredictions are balanced!")

In [None]:
# Confusion matrix
# The label ids are passed explicitly so the matrix always has one row/column
# per entry of `class_names`, even if a class is absent from both `y_true` and
# `y_pred`; otherwise the tick labels and `target_names` could go out of sync.
label_ids = list(range(len(class_names)))
cm = confusion_matrix(y_true, y_pred, labels=label_ids)

plt.figure(figsize=(10, 8))
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=class_names,
    yticklabels=class_names,
    square=True
)
plt.title('Confusion Matrix - Dr Green V7 (Stratified)', fontsize=14)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.tight_layout()
plt.show()

# Classification report
print("\n" + "="*60)
print("CLASSIFICATION REPORT")
print("="*60)
print(classification_report(
    y_true,
    y_pred,
    labels=label_ids,
    target_names=class_names,
    digits=4
))


def _accuracy_status(class_acc):
    """Map a per-class accuracy to its tag: OK (>= 0.60), LOW (>= 0.40), else BAD."""
    if class_acc >= 0.60:
        return "OK"
    if class_acc >= 0.40:
        return "LOW"
    return "BAD"


# Per-class accuracy (classes without validation samples are skipped)
print("\nPer-class accuracy:")
for i, class_name in enumerate(class_names):
    class_mask = y_true == i
    if class_mask.sum() > 0:
        class_acc = (y_pred[class_mask] == i).mean()
        status = _accuracy_status(class_acc)
        print(f"  [{status}] {class_name}: {class_acc*100:.2f}% ({class_mask.sum()} samples)")

## 10. Save Model

In [None]:
# Export the best model as Keras and TFLite files plus a JSON metadata record,
# all sharing one timestamped name inside the model save directory.
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
model_name = f"drgreen_v7_stratified_{timestamp}"
artifact_stem = f"{CONFIG['model_save_dir']}/{model_name}"

print(f"Saving model: {model_name}")

# Keras format
keras_path = f"{artifact_stem}.keras"
best_model.save(keras_path)
print(f"  Keras: {keras_path}")

# TFLite, converted with the default optimization set
converter = tf.lite.TFLiteConverter.from_keras_model(best_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

tflite_path = f"{artifact_stem}.tflite"
with open(tflite_path, 'wb') as tflite_file:
    tflite_file.write(tflite_model)
print(f"  TFLite: {tflite_path}")

# Metadata
performance = {
    'val_accuracy': float(results[1]),
    'val_top2_accuracy': float(results[2]),
    'best_val_accuracy': float(best_val_acc),
    'overfitting_gap': float(overfitting_gap)
}
metadata = {
    'model_name': model_name,
    'version': '7.0-stratified',
    'created_at': timestamp,
    'class_names': class_names,
    'performance': performance
}

metadata_path = f"{artifact_stem}_metadata.json"
with open(metadata_path, 'w') as metadata_file:
    json.dump(metadata, metadata_file, indent=2)
print(f"  Metadata: {metadata_path}")

# File sizes
bytes_per_mb = 1024 * 1024
keras_size = os.path.getsize(keras_path) / bytes_per_mb
tflite_size = os.path.getsize(tflite_path) / bytes_per_mb
print(f"\nFile sizes: Keras={keras_size:.2f}MB, TFLite={tflite_size:.2f}MB")

## 11. Final Summary

In [None]:
# Closing report: architecture, headline metrics, prediction spread and verdict.
rule = "=" * 60
print(f"\n{rule}")
print("DR GREEN V7 - TRAINING COMPLETE")
print(rule)

print(f"\nArchitecture: {CONFIG['base_model']} (frozen)")
print(f"Parameters: {model.count_params():,} ({trainable:,} trainable)")

print("\nPerformance:")
for caption, fraction in (
    ("Validation Accuracy", results[1]),
    ("Top-2 Accuracy", results[2]),
    ("Overfitting Gap", overfitting_gap),
):
    print(f"  {caption}: {fraction*100:.2f}%")

print("\nPrediction Distribution:")
n_predictions = len(y_pred)
for name, n_hits in pred_counts.items():
    share = n_hits / n_predictions * 100
    print(f"  {name}: {share:.1f}%")

print(f"\nSaved: {model_name}.keras, {model_name}.tflite")

# The deployment verdict depends only on the collapse check.
if collapse_detected:
    print("\nNeed further optimization")
else:
    print("\nREADY FOR DEPLOYMENT!")

## 12. Download Models

In [None]:
from google.colab import files

# Hand each exported artifact to the Colab download helper, in save order.
print("Downloading models...")
for artifact_path in (keras_path, tflite_path, metadata_path):
    files.download(artifact_path)
print("Downloads initiated!")