# Plant Disease Detection - Complete Model Training & Evaluation

This notebook trains all models (CNN, ViT, and Hybrid) and provides a comprehensive comparison of their results.

## 1. Setup Environment

In [None]:
# Report which accelerator (if any) PyTorch can see before training starts
import torch

has_cuda = torch.cuda.is_available()
print(f"CUDA available: {has_cuda}")
if not has_cuda:
    print("⚠️  WARNING: No GPU detected. Training will be very slow.")
else:
    # Device index 0 is the only GPU inspected here
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {gpu_props.total_memory / 1e9:.1f} GB")

In [None]:
# Clone repository if not already done
import os
if not os.path.exists('VIT-PLANT-VILLAGE'):
    !git clone https://github.com/Lexo7/VIT-PLANT-VILLAGE.git
    %cd VIT-PLANT-VILLAGE
else:
    %cd VIT-PLANT-VILLAGE

In [None]:
# Install dependencies
!pip install -r requirements.txt
!pip install -e .

## 2. Download and Prepare Dataset

In [None]:
# Download PlantVillage dataset
import os
from google.colab import files

# Method 1: Use Kaggle API (upload kaggle.json first)
print("📁 Please upload your kaggle.json file when prompted...")
uploaded = files.upload()

# Setup Kaggle
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

# Download dataset
!kaggle datasets download -d vipoooool/new-plant-diseases-dataset
!unzip -q new-plant-diseases-dataset.zip -d data/plant_village/raw/

In [None]:
# Alternative method: upload the dataset zip manually if the Kaggle download does not work.
# Uncomment the lines below to use it.
# print("📁 Upload your dataset zip file...")
# uploaded = files.upload()
# !unzip -q *.zip -d data/plant_village/raw/

In [None]:
# Preprocess data
# Runs the repository's preprocessing script with its default arguments.
# NOTE(review): presumably it reads the archive extracted under
# data/plant_village/raw/ — confirm the expected input layout in the script.
!python data/data_preprocessing.py

## 3. Training Configuration

In [None]:
# Training parameters
EPOCHS = 50  # Adjust based on your needs
BATCH_SIZE = 16  # Adjust based on GPU memory
IMAGE_SIZE = 224
LEARNING_RATE = 0.001

# Create results directory
!mkdir -p results/model_weights
!mkdir -p results/logs
!mkdir -p results/evaluation

print(f"Training configuration:")
print(f"  Epochs: {EPOCHS}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Image size: {IMAGE_SIZE}")
print(f"  Learning rate: {LEARNING_RATE}")

## 4. Train CNN Models

In [None]:
# Train ResNet50
print("🚀 Training ResNet50...")
!python training/train_cnn.py \
    --model resnet50 \
    --epochs {EPOCHS} \
    --batch_size {BATCH_SIZE} \
    --lr {LEARNING_RATE} \
    --image_size {IMAGE_SIZE}

In [None]:
# Train EfficientNet-B0
print("🚀 Training EfficientNet-B0...")
!python training/train_cnn.py \
    --model efficientnet_b0 \
    --epochs {EPOCHS} \
    --batch_size {BATCH_SIZE} \
    --lr {LEARNING_RATE} \
    --image_size {IMAGE_SIZE}

## 5. Train Vision Transformer Models

In [None]:
# Train ViT Base
print("🚀 Training Vision Transformer Base...")
!python training/train_vit.py \
    --model vit_base_patch16_224 \
    --epochs {EPOCHS} \
    --batch_size {BATCH_SIZE} \
    --lr 3e-4 \
    --image_size {IMAGE_SIZE}

In [None]:
# Train ViT Small
print("🚀 Training Vision Transformer Small...")
!python training/train_vit.py \
    --model vit_small_patch16_224 \
    --epochs {EPOCHS} \
    --batch_size {BATCH_SIZE} \
    --lr 3e-4 \
    --image_size {IMAGE_SIZE}

In [None]:
# Train Custom ViT
print("🚀 Training Custom Vision Transformer...")
!python training/train_vit.py \
    --model custom_vit \
    --epochs {EPOCHS} \
    --batch_size {BATCH_SIZE} \
    --lr 3e-4 \
    --image_size {IMAGE_SIZE}

## 6. Train Hybrid Models

In [None]:
# Train Hybrid CNN-ViT
print("🚀 Training Hybrid CNN-ViT...")
!python training/train_hybrid.py \
    --model hybrid_cnn_vit \
    --cnn_backbone resnet50 \
    --vit_model vit_base_patch16_224 \
    --epochs {EPOCHS} \
    --batch_size {BATCH_SIZE} \
    --lr 1e-4 \
    --image_size {IMAGE_SIZE}

In [None]:
# Train Parallel CNN-ViT
print("🚀 Training Parallel CNN-ViT...")
!python training/train_hybrid.py \
    --model parallel_cnn_vit \
    --cnn_backbone resnet50 \
    --vit_model vit_base_patch16_224 \
    --fusion_method concat \
    --epochs {EPOCHS} \
    --batch_size {BATCH_SIZE} \
    --lr 1e-4 \
    --image_size {IMAGE_SIZE}

In [None]:
# Train Attention-Fused CNN-ViT
print("🚀 Training Attention-Fused CNN-ViT...")
!python training/train_hybrid.py \
    --model attention_fused_cnn_vit \
    --cnn_backbone resnet50 \
    --vit_model vit_base_patch16_224 \
    --epochs {EPOCHS} \
    --batch_size {BATCH_SIZE} \
    --lr 1e-4 \
    --image_size {IMAGE_SIZE}

## 7. Evaluate All Models

In [None]:
# Evaluate all trained models
# Runs the repository's evaluation script with its default arguments.
# NOTE(review): the comparison cell further down reads
# results/evaluation/*_summary.json — presumably written by this script; confirm.
!python evaluation/evaluate_models.py

## 8. Visualize Results

In [None]:
# Load TensorBoard for training visualization
# Load the notebook extension, then open TensorBoard on results/logs
# (the log directory created by the configuration cell).
%load_ext tensorboard
%tensorboard --logdir results/logs

In [None]:
# Compare model results
import json
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import glob
from pathlib import Path

summary_suffix = '_summary.json'

# Load all model results.
# The matches are sorted because glob returns them in arbitrary order.
result_files = sorted(glob.glob(f'results/evaluation/*{summary_suffix}'))
if not result_files:
    # Without this guard the sort below fails with an opaque KeyError: 'accuracy'
    raise FileNotFoundError(
        "No '*_summary.json' files found in results/evaluation/ - "
        "run the evaluation cell before comparing models."
    )

results = []
for summary_path in result_files:
    with open(summary_path, 'r') as f:
        data = json.load(f)
    # The model name is the file name with only its trailing suffix removed
    data['model'] = Path(summary_path).name[:-len(summary_suffix)]
    results.append(data)

# Create results DataFrame, best accuracy first.
# A stable sort keeps tied models in file-name order.
df_results = pd.DataFrame(results)
df_results = df_results.sort_values('accuracy', ascending=False, kind='mergesort')

print("📊 Model Performance Comparison:")
print(df_results[['model', 'accuracy', 'macro_avg_f1', 'weighted_avg_f1']].to_string(index=False))

In [None]:
# Plot model comparison
def classify_model_type(model_name):
    """Map a model name to 'CNN', 'ViT' or 'Hybrid' based on substrings of the name.

    Hybrid markers are tested first: hybrid names such as
    'attention_fused_cnn_vit' also contain 'vit' (and may mention the CNN
    backbone), so they would otherwise be filed under the wrong family.
    Names that match nothing fall back to 'Hybrid'.
    """
    if any(marker in model_name for marker in ('hybrid', 'parallel', 'fused', 'cnn_vit')):
        return 'Hybrid'
    if 'resnet' in model_name or 'efficientnet' in model_name:
        return 'CNN'
    if 'vit' in model_name:
        return 'ViT'
    return 'Hybrid'


def draw_bars(ax, palette, title, **barplot_kwargs):
    """Draw one seaborn bar chart on `ax`, coloured per bar, without a legend.

    The caller passes `hue` equal to the categorical variable: seaborn 0.13
    deprecates `palette` without `hue`. `dodge=False` keeps full-width bars,
    and any legend is removed because it would only repeat the axis labels.
    """
    sns.barplot(ax=ax, palette=palette, dodge=False, **barplot_kwargs)
    legend = ax.get_legend()
    if legend is not None:
        legend.remove()
    ax.set_title(title)


fig, axes = plt.subplots(2, 2, figsize=(12, 8))

# Accuracy comparison
draw_bars(axes[0, 0], 'viridis', 'Model Accuracy Comparison',
          data=df_results, x='accuracy', y='model', hue='model')
axes[0, 0].set_xlabel('Accuracy')

# F1 Score comparison
draw_bars(axes[0, 1], 'plasma', 'Macro F1 Score Comparison',
          data=df_results, x='macro_avg_f1', y='model', hue='model')
axes[0, 1].set_xlabel('Macro F1 Score')

# Weighted F1 Score comparison
draw_bars(axes[1, 0], 'coolwarm', 'Weighted F1 Score Comparison',
          data=df_results, x='weighted_avg_f1', y='model', hue='model')
axes[1, 0].set_xlabel('Weighted F1 Score')

# Model type categorization
df_results['model_type'] = df_results['model'].map(classify_model_type)
type_performance = df_results.groupby('model_type')['accuracy'].mean()
draw_bars(axes[1, 1], 'Set2', 'Average Accuracy by Model Type',
          x=type_performance.index, y=type_performance.values,
          hue=type_performance.index)
axes[1, 1].set_ylabel('Average Accuracy')

fig.tight_layout()
fig.savefig('results/model_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

## 9. Generate Final Report

In [None]:
# Generate comprehensive report
# Reads df_results (sorted by accuracy, best first) and type_performance
# (mean accuracy per model type) produced by the comparison and plotting cells.

# Wording for the conclusion sentence; any type label other than 'ViT' or
# 'CNN' is reported as 'Hybrid models'.
family_names = {'ViT': 'Vision Transformers', 'CNN': 'CNNs'}
best_family = family_names.get(type_performance.idxmax(), 'Hybrid models')
best_model = df_results.iloc[0]

# LEARNING_RATE is only passed to the CNN training cells; the ViT and hybrid
# cells hard-code their own --lr, so the report states all three values.
report = f"""
# Plant Disease Detection - Final Results Report

## Training Configuration
- Epochs: {EPOCHS}
- Batch Size: {BATCH_SIZE}
- Image Size: {IMAGE_SIZE}x{IMAGE_SIZE}
- Learning Rate: {LEARNING_RATE} (CNN runs; the ViT cells pass 3e-4 and the hybrid cells pass 1e-4)

## Model Performance Summary

### Top 3 Performing Models:
"""

# Add top 3 models to report
for rank, (_, row) in enumerate(df_results.head(3).iterrows(), start=1):
    report += f"""
{rank}. **{row['model']}**
   - Accuracy: {row['accuracy']:.4f}
   - Macro F1: {row['macro_avg_f1']:.4f}
   - Weighted F1: {row['weighted_avg_f1']:.4f}
"""

report += """

### Model Type Analysis:
"""

# Add model type analysis
report += "".join(
    f"\n- {model_type}: {avg_acc:.4f} average accuracy"
    for model_type, avg_acc in type_performance.items()
)

report += f"""

## Conclusions

Best performing model: **{best_model['model']}** with {best_model['accuracy']:.4f} accuracy

The results show that {best_family} \nperform best on this plant disease detection task.

## Files Generated
- Model weights: `results/model_weights/`
- Training logs: `results/logs/`
- Evaluation results: `results/evaluation/`
- Confusion matrices and classification reports for each model
"""

# Save report (explicit encoding so the file does not depend on the locale)
with open('results/final_report.md', 'w', encoding='utf-8') as f:
    f.write(report)

print(report)
print("\n✅ Training and evaluation complete!")
print("📁 Check the 'results' folder for all outputs.")

In [None]:
# Download results (optional)
import shutil
from google.colab import files

# Zip the whole 'results' folder into plant_disease_results.zip in the
# current working directory
archive_base = 'plant_disease_results'
shutil.make_archive(archive_base, 'zip', 'results')

print("📦 Results packaged! Click to download:")
files.download('plant_disease_results.zip')