# BeejX Leaf Disease Model Training
**Professional Training Pipeline | Optimized for Google Colab**

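This notebook runs the BeejX professional training script. Run the cells from top to bottom. It handles:
1. Environment Setup (getting `project_upload.zip` into the workspace, installing dependencies, and applying stability patches)
2. Dataset Organization
3. MobileNetV2 Training
4. TFLite Conversion for Android (and downloading the exported model and labels)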

In [None]:
# @title 0.0 [OPTIONAL] Mount Google Drive (For Large Files)
# Run this ONLY if you uploaded 'project_upload.zip' to Google Drive.
import os
from google.colab import drive

if not os.path.exists('project_upload.zip'):
    print("Mounting Google Drive...")
    drive.mount('/content/drive')
    
    # Copy the file to current workspace with progress bar
    print("Copying project_upload.zip from Drive... (This shows progress now)")
    !rsync -ah --progress "/content/drive/MyDrive/project_upload.zip" .
    print("Copy Complete!")
else:
    print("project_upload.zip already exists. Skipping Drive mount.")

In [None]:
# @title 0.1 Initialize & Unzip
import os
# Check if we need to unzip
if os.path.exists('project_upload.zip'):
    print("Unzipping project... (This takes a few minutes for 6GB)")
    !unzip -q project_upload.zip
    print("Unzip Complete!")
else:
    print("project_upload.zip' not found. Please Upload it or Run Step 0.0 to copy from Drive.")

In [None]:
# @title 1. Setup Environment
# Install dependencies from the professional requirements file
!pip install -r requirements.txt
!pip install pyyaml

In [None]:
# @title 1.1 [AUTO-FIX] Apply Stability Patches (Fix OOM & Bugs)
# Run this to FIX the Memory Crash and AttributeError.
import os

# 1. Fix src/scripts/loader.py (Disable Cache to save RAM)
# The generated loader streams batches from disk (no .cache()), so the whole dataset
# never has to fit in Colab's RAM.
loader_code = """
import tensorflow as tf
import os
import glob
from typing import Tuple, List, Optional
import yaml

def load_config(config_path: str = "configs/config.yaml") -> dict:
    # Load the YAML training configuration.
    with open(config_path, 'r') as f:
        return yaml.safe_load(f)

class BeejXDataLoader:
    # Builds train/validation tf.data pipelines from a folder-per-class image directory.
    def __init__(self, config: dict):
        self.config = config
        self.img_size = tuple(config['model']['input_shape'][:2])
        self.batch_size = config['training']['batch_size']

        # Random augmentations; applied to the training split only.
        self.augment_layers = tf.keras.Sequential([
            tf.keras.layers.RandomFlip("horizontal"),
            tf.keras.layers.RandomRotation(config['augmentation']['rotation_range']),
            tf.keras.layers.RandomZoom(config['augmentation']['zoom_range']),
        ])

    def get_local_dataset(self) -> Optional[Tuple[tf.data.Dataset, tf.data.Dataset, List[str]]]:
        # Returns (train_ds, val_ds, class_names), or None if data_dir has no class folders.
        data_dir = self.config['data']['local_data_dir']
        classes = sorted([d for d in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, d))])
        if not classes:
            print("No class folders found in data dir.")
            return None

        print(f"Auto-detected {len(classes)} classes: {classes}")

        # Both splits use the same seed and validation_split so they partition the
        # same shuffled file list without overlapping.
        train_ds = tf.keras.utils.image_dataset_from_directory(
            data_dir, validation_split=self.config['training']['validation_split'],
            subset="training", seed=123, image_size=self.img_size,
            batch_size=self.batch_size, labels='inferred', label_mode='int'
        )

        val_ds = tf.keras.utils.image_dataset_from_directory(
            data_dir, validation_split=self.config['training']['validation_split'],
            subset="validation", seed=123, image_size=self.img_size,
            batch_size=self.batch_size, labels='inferred', label_mode='int'
        )

        class_names = train_ds.class_names

        train_ds = train_ds.map(lambda x, y: (self.augment_layers(x, training=True), y), num_parallel_calls=tf.data.AUTOTUNE)
        # Rescale pixel values by 1/255 for both splits.
        train_ds = train_ds.map(lambda x, y: (x/255.0, y), num_parallel_calls=tf.data.AUTOTUNE)
        val_ds = val_ds.map(lambda x, y: (x/255.0, y), num_parallel_calls=tf.data.AUTOTUNE)

        # REMOVED .cache() to prevent RAM explosion. Streaming from disk is safer.
        train_ds = train_ds.prefetch(buffer_size=tf.data.AUTOTUNE)
        val_ds = val_ds.prefetch(buffer_size=tf.data.AUTOTUNE)

        return train_ds, val_ds, class_names

    def get_combined_dataset(self):
        # Only the local directory is used as a data source.
        return self.get_local_dataset()
"""

# Make sure the target package directory exists so the write cannot fail with
# FileNotFoundError.
os.makedirs('src/scripts', exist_ok=True)
with open('src/scripts/loader.py', 'w', encoding='utf-8') as f:
    f.write(loader_code)
print("Fixed src/scripts/loader.py (Disabled Cache)")

# 2. Fix src/train.py (Memory Warning Fix + Import Fix)
# Class weights are estimated from on-disk file counts rather than by iterating the
# dataset, which would have to decode every image.
train_code = """
import os
import tensorflow as tf
import numpy as np
import yaml
import glob
from scripts.loader import BeejXDataLoader
from models.mobilenet import build_mobilenet_model

# File extensions image_dataset_from_directory loads; used when counting images per class.
IMAGE_EXTENSIONS = ('.bmp', '.gif', '.jpeg', '.jpg', '.png')

def load_config(path='configs/config.yaml'):
    # Load the YAML training configuration.
    with open(path, 'r') as f:
        return yaml.safe_load(f)

def count_images(class_dir):
    # Count image files under class_dir, recursing into subfolders like
    # image_dataset_from_directory does (and ignoring non-image files).
    return sum(
        1
        for _root, _dirs, fnames in os.walk(class_dir)
        for fname in fnames
        if fname.lower().endswith(IMAGE_EXTENSIONS)
    )

def main():
    print("="*50)
    print("BeejX Leaf Disease Model - Professional Training Pipeline")
    print("="*50)
    config = load_config()
    loader = BeejXDataLoader(config)
    datasets = loader.get_combined_dataset()

    if datasets is None:
        return

    train_ds, val_ds, class_names = datasets
    num_classes = len(class_names)
    print(f"Found {num_classes} classes: {class_names}")

    model = build_mobilenet_model(num_classes, config)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=config['training']['learning_rate']),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )

    print("\\nCalculating class weights (Optimized for Colab)...")
    data_dir = config['data']['local_data_dir']
    # Estimate each class's training-split size as (file count * training fraction).
    # This is an approximation: the real split shuffles all files together, so
    # per-class proportions vary slightly.
    train_fraction = 1.0 - config['training']['validation_split']
    y_train_list = []
    try:
        for i, name in enumerate(class_names):
            count = count_images(os.path.join(data_dir, name))
            y_train_list.extend([i] * int(count * train_fraction))

        from sklearn.utils import class_weight
        y_train = np.array(y_train_list)
        present = np.unique(y_train)
        weights = class_weight.compute_class_weight(
            class_weight='balanced',
            classes=present,
            y=y_train
        )
        # Key each weight by its real class index. enumerate(weights) would shift every
        # index after a class with no training samples. Every class also gets a key,
        # so Keras receives a complete 0..num_classes-1 mapping.
        class_weights = {i: 1.0 for i in range(num_classes)}
        class_weights.update({int(c): float(w) for c, w in zip(present, weights)})
        print(f"Computed Class Weights (Fast): {class_weights}")
    except Exception as e:
        print(f"Warning during weighting: {e}. Using equal weights.")
        class_weights = None

    # Create the output directory BEFORE training: ModelCheckpoint writes into it
    # during fit().
    export_dir = config['paths']['output_dir']
    os.makedirs(export_dir, exist_ok=True)

    print("\\nStarting training...")
    checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
        os.path.join(export_dir, "best_model.keras"),
        save_best_only=True, monitor='val_accuracy'
    )
    early_stopping_cb = tf.keras.callbacks.EarlyStopping(
        monitor='val_loss', patience=3, restore_best_weights=True
    )

    model.fit(
        train_ds, epochs=config['training']['epochs'],
        validation_data=val_ds, class_weight=class_weights,
        callbacks=[checkpoint_cb, early_stopping_cb]
    )

    # NOTE: the TFLite export uses the in-memory model (weights possibly restored by
    # EarlyStopping on val_loss). That can differ from best_model.keras, which is
    # selected on val_accuracy.
    print("\\nExporting to TFLite...")
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    tflite_model = converter.convert()
    with open(os.path.join(export_dir, "model.tflite"), "wb") as f:
        f.write(tflite_model)
    # One class name per line, in label-index order.
    with open(os.path.join(export_dir, "labels.txt"), "w") as f:
        for name in class_names:
            f.write(name + "\\n")
    print("SUCCESS! Model saved.")

if __name__ == "__main__":
    main()
"""

os.makedirs('src', exist_ok=True)
with open('src/train.py', 'w', encoding='utf-8') as f:
    f.write(train_code)
print("Fixed src/train.py")


In [None]:
# @title 2. Verify GPU
import tensorflow as tf
# Query the visible GPUs once and reuse the result for both the report and the warning.
gpu_devices = tf.config.list_physical_devices('GPU')
has_gpu = len(gpu_devices) > 0

print(f"TensorFlow Version: {tf.__version__}")
print(f"GPU Available: {has_gpu}")

# Training without a GPU is much slower; point the user at the runtime setting.
if not has_gpu:
    print("WARNING: You are running on CPU. Enable GPU in Runtime > Change runtime type.")

In [None]:
# @title 2.1 Prepare Data
# Just upload your 'data' folder to the Colab file explorer.
# src/scripts/organize.py will then organize it automatically.
import subprocess
import sys

# Run the script in a child process using the kernel's own interpreter.
# Its combined stdout/stderr is streamed line by line (`-u` turns off the child's
# output buffering). The exit code is checked, so a failure stops the notebook here
# instead of printing a misleading success message.
with subprocess.Popen(
    [sys.executable, "-u", "src/scripts/organize.py"],
    stdout=subprocess.PIPE,
    stderr=subprocess.STDOUT,
    text=True,
) as organize_proc:
    for line in organize_proc.stdout:
        print(line, end="")

# Popen's context manager waits for the process, so returncode is final here.
if organize_proc.returncode != 0:
    raise RuntimeError(
        f"src/scripts/organize.py failed (exit code {organize_proc.returncode}); data was not organized."
    )
print("Data organized successfully!")

In [None]:
# @title 3. Start Training Pipeline
# This runs the professional script 'src/train.py' as a separate Python process.
# It uses the settings in 'configs/config.yaml'.
# The command's exit code is not checked here. Step 4 reports "Model not found"
# if no exported model exists afterwards.

!python src/train.py

In [None]:
# @title 4. Download Model for Android
from google.colab import files
import os
import yaml

# train.py writes model.tflite and labels.txt into config['paths']['output_dir'].
# Read the same setting so this cell follows any change to it. Fall back to
# 'exports' if the config cannot be read.
try:
    with open('configs/config.yaml', 'r') as f:
        export_dir = yaml.safe_load(f)['paths']['output_dir']
except (OSError, KeyError, TypeError, yaml.YAMLError):
    export_dir = 'exports'

model_path = os.path.join(export_dir, 'model.tflite')
labels_path = os.path.join(export_dir, 'labels.txt')

if os.path.exists(model_path):
    files.download(model_path)
    if os.path.exists(labels_path):
        files.download(labels_path)
    else:
        print(f"Labels file not found: {labels_path}")
else:
    print("Model not found. Did training complete successfully?")