# 🌱 Plant Disease Classification - Training on Google Colab

This notebook fine-tunes a pretrained ResNet-50 model on your AgriDetect dataset (expected accuracy: ~99.7%).

**Estimated Time**: 30-60 minutes on a free Colab GPU

**Steps**:
1. Upload your dataset zip file
2. Install dependencies
3. Train the model
4. Download the trained model

## Step 1: Check GPU Availability

In [None]:
import torch
# Training is intended to run on a GPU (see the time estimate above), so check
# for a CUDA device before doing anything else.
gpu_available = torch.cuda.is_available()
print(f"GPU Available: {gpu_available}")
if gpu_available:
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU Name: {gpu_props.name}")
    print(f"GPU Memory: {gpu_props.total_memory / 1e9:.2f} GB")
else:
    print("⚠️ No GPU detected! Go to Runtime > Change runtime type > Select GPU")

## Step 2: Install Dependencies

In [None]:
!pip install -q transformers datasets accelerate scikit-learn pillow

## Step 3: Upload Dataset

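Run the cell below, then use the **Choose Files** button it displays to upload `AgriDetect.v1i.folder-2.zip`. The archive should contain `train/`, `valid/` (or `validation/`) and `test/` folders, each with one sub-folder per class; the class names are taken from those sub-folder names.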

In [None]:
from google.colab import files
import zipfile
import os

# Folder the dataset is extracted into; must match DATASET_DIR in the Step 4 cell.
EXTRACT_DIR = "AgriDetect_new_model"

print("📤 Upload your AgriDetect.v1i.folder-2.zip file:")
uploaded = files.upload()  # {filename: file bytes}

# Pick the uploaded .zip explicitly: an empty upload (dialog cancelled) or an
# extra non-zip file should fail with a clear message, not an IndexError or
# a BadZipFile on the wrong file.
zip_candidates = [name for name in uploaded if name.lower().endswith(".zip")]
if not zip_candidates:
    raise ValueError(
        f"No .zip file was uploaded (got: {list(uploaded)}). "
        "Re-run this cell and select the dataset zip."
    )
zip_filename = zip_candidates[0]
print(f"\n📦 Extracting {zip_filename}...")

with zipfile.ZipFile(zip_filename, "r") as zip_ref:
    zip_ref.extractall(EXTRACT_DIR)

print("✅ Dataset extracted!")
print("\nDataset structure:")
for entry in sorted(os.listdir(EXTRACT_DIR)):
    print(f"   {entry}")

## Step 4: Prepare Data, Model, and Trainer

In [None]:
from datasets import load_dataset
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    TrainingArguments,
    Trainer
)
import torch
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

BANNER = "=" * 60
print(BANNER)
print("Plant Disease Classification - Colab Training")
print(BANNER)

# Configuration
DATASET_DIR = "AgriDetect_new_model"      # folder the Step 3 cell extracts the zip into
MODEL_NAME = "microsoft/resnet-50"        # pretrained checkpoint on the Hugging Face Hub
OUTPUT_DIR = "./plant-disease-model-v2"   # Trainer checkpoints and the final saved model
NUM_EPOCHS = 10
BATCH_SIZE = 32  # per-device batch size, used for both training and evaluation
LEARNING_RATE = 2e-5

print("\n📊 Configuration:")
print(f"   Dataset: {DATASET_DIR} (99.7% accuracy)")
print(f"   Base Model: {MODEL_NAME}")
print(f"   Epochs: {NUM_EPOCHS}")
print(f"   Batch Size: {BATCH_SIZE}")
print(f"   Learning Rate: {LEARNING_RATE}")

# Load dataset. The "imagefolder" builder takes class labels from sub-folder
# names; split folders such as train/ valid/ test/ become dataset splits.
print("\n📥 Loading dataset...")
dataset = load_dataset(
    "imagefolder",
    data_dir=DATASET_DIR,
    drop_labels=False
)

# The rest of the notebook indexes dataset["train"], ["validation"] and ["test"];
# fail here with an actionable message instead of a bare KeyError.
REQUIRED_SPLITS = ("train", "validation", "test")
missing_splits = [split for split in REQUIRED_SPLITS if split not in dataset]
if missing_splits:
    raise ValueError(
        f"Dataset in {DATASET_DIR!r} is missing split(s) {missing_splits}; "
        f"found {list(dataset.keys())}. Check the folder layout of the uploaded zip."
    )

print("✅ Dataset loaded!")
print(f"   Train: {len(dataset['train'])} images")
print(f"   Validation: {len(dataset['validation'])} images")
print(f"   Test: {len(dataset['test'])} images")

# Class names, indexed by the integer ids stored in the "label" column
labels = dataset["train"].features["label"].names
num_labels = len(labels)
print(f"\n🏷️  Classes ({num_labels}):")
for i, label in enumerate(labels):
    print(f"   {i}: {label}")

# Load the image processor that matches the pretrained checkpoint
print("\n🔧 Loading image processor...")
image_processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
print("✅ Image processor loaded!")

# Preprocessing function
def preprocess_images(examples, processor=None):
    """Convert a batch of images into model-ready pixel tensors.

    Intended for ``Dataset.map(..., batched=True)``.

    Args:
        examples: Batch dict with an ``"image"`` list (objects supporting
            ``.convert("RGB")``, i.e. PIL images) and a ``"label"`` list of
            class ids.
        processor: Image processor to apply. Defaults to the notebook-level
            ``image_processor`` loaded from ``MODEL_NAME``.

    Returns:
        Dict with the processor's outputs (e.g. ``"pixel_values"``), each
        keeping its leading batch dimension, plus ``"labels"`` (the batch's
        label list, unchanged).
    """
    if processor is None:
        processor = image_processor
    images = [img.convert("RGB") for img in examples["image"]]
    # No squeeze(): for a 1-image batch (the last batch of a split whose size
    # is 1 mod batch_size) squeezing would drop the batch dimension, making
    # `map` see 3 rows of pixel_values against 1 label and fail.
    inputs = dict(processor(images, return_tensors="pt"))
    inputs["labels"] = examples["label"]
    return inputs

# Apply preprocessing to every split. This rebinds `dataset`: all original
# columns (the raw "image" and "label") are removed, leaving only the columns
# returned by preprocess_images.
print("\n🔄 Preprocessing images...")
dataset = dataset.map(
    preprocess_images,
    batched=True,
    batch_size=32,
    remove_columns=dataset["train"].column_names
)
print("✅ Preprocessing complete!")

# Load the pretrained backbone with a new num_labels-way classification head.
# ignore_mismatched_sizes=True allows loading even though the checkpoint's
# original head has a different number of classes.
print("\n🤖 Loading model...")
model = AutoModelForImageClassification.from_pretrained(
    MODEL_NAME,
    num_labels=num_labels,
    id2label=dict(enumerate(labels)),
    label2id={label: i for i, label in enumerate(labels)},
    ignore_mismatched_sizes=True
)
print("✅ Model loaded!")

# Metrics function
def compute_metrics(eval_pred):
    """Compute accuracy plus support-weighted precision, recall and F1.

    Args:
        eval_pred: ``(logits, label_ids)`` pair passed by ``Trainer`` after
            each evaluation; logits are assumed to be (num_examples,
            num_labels) with class scores on the last axis.

    Returns:
        Dict with ``accuracy``, ``precision``, ``recall`` and ``f1``; Trainer
        reports them with an ``eval_`` prefix (e.g. ``eval_accuracy``).
    """
    # Named label_ids so the notebook-level `labels` (class names) isn't shadowed.
    logits, label_ids = eval_pred
    predictions = np.argmax(logits, axis=-1)

    accuracy = accuracy_score(label_ids, predictions)
    # zero_division=0: a class that is never predicted scores 0 precision
    # without emitting an UndefinedMetricWarning on every evaluation.
    precision, recall, f1, _ = precision_recall_fscore_support(
        label_ids, predictions, average="weighted", zero_division=0
    )

    return {
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }

# Training arguments
print("\n⚙️  Setting up training...")
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    learning_rate=LEARNING_RATE,
    warmup_steps=100,
    weight_decay=0.01,
    logging_dir=f"{OUTPUT_DIR}/logs",
    logging_steps=10,
    eval_strategy="epoch",  # requires transformers >= 4.41
    save_strategy="epoch",  # must match eval_strategy for load_best_model_at_end
    # Each epoch checkpoint stores full weights plus optimizer state; without a
    # cap, NUM_EPOCHS of them pile up on Colab's limited disk. With
    # load_best_model_at_end=True the best checkpoint is kept in addition.
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    # Mixed precision speeds up training on CUDA GPUs; only enabled when one is present.
    fp16=torch.cuda.is_available(),
    push_to_hub=False,
    report_to="none",
)

# Create trainer: trains on "train", evaluates/selects the best epoch on "validation"
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    compute_metrics=compute_metrics,
)

print("✅ Trainer ready!")

# Report the compute device available for training
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"\n💻 Device: {device}")
if device == "cuda":
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
else:
    print("   ⚠️  No GPU detected!")

# Setup is done; the actual training run happens in the Step 5 cell.
print("\n" + "=" * 60)
print("🚀 READY TO TRAIN")
print("=" * 60)
print("\nRun the next cell to start training (30-60 minutes with GPU)...\n")

## Step 5: Train the Model

In [None]:
# Train the model. Because load_best_model_at_end=True, the model held by the
# trainer has the weights of the best epoch (by validation accuracy) afterwards.
train_results = trainer.train()

print("\n" + "=" * 60)
print("✅ TRAINING COMPLETE!")
print("=" * 60)

## Step 6: Evaluate the Model

In [None]:
def print_eval_metrics(title, metrics):
    """Print accuracy/precision/recall/F1 from a ``Trainer.evaluate()`` result.

    Args:
        title: Split name shown in the header, e.g. ``"Validation"``.
        metrics: Result dict whose keys carry the default ``eval_`` prefix.
    """
    accuracy = metrics["eval_accuracy"]
    print(f"\n{title} Results:")
    print(f"   Accuracy:  {accuracy:.4f} ({accuracy*100:.2f}%)")
    print(f"   Precision: {metrics['eval_precision']:.4f}")
    print(f"   Recall:    {metrics['eval_recall']:.4f}")
    print(f"   F1 Score:  {metrics['eval_f1']:.4f}")


# Evaluate on the validation set (the Trainer's default eval_dataset)
print("\n📊 Evaluating on validation set...")
val_metrics = trainer.evaluate()
print_eval_metrics("Validation", val_metrics)

# Evaluate on the test split, which was not used for training or model
# selection. Keys keep the default "eval_" prefix, which Step 7 reads.
print("\n📊 Evaluating on test set...")
test_metrics = trainer.evaluate(dataset["test"])
print_eval_metrics("Test", test_metrics)

## Step 7: Save the Model

In [None]:
def format_metric_lines(metrics):
    """Return the four indented metric lines written to metrics.txt for one split.

    ``metrics`` is a ``Trainer.evaluate()`` result (keys carry the ``eval_`` prefix).
    """
    accuracy = metrics["eval_accuracy"]
    return [
        f"  Accuracy:  {accuracy:.4f} ({accuracy*100:.2f}%)",
        f"  Precision: {metrics['eval_precision']:.4f}",
        f"  Recall:    {metrics['eval_recall']:.4f}",
        f"  F1 Score:  {metrics['eval_f1']:.4f}",
    ]


# Save the model weights/config and the image processor config into OUTPUT_DIR
print(f"\n💾 Saving model to {OUTPUT_DIR}...")
trainer.save_model(OUTPUT_DIR)
image_processor.save_pretrained(OUTPUT_DIR)
print("✅ Model saved!")

# Save metrics to file: a header, the run configuration, then one section per split
report_lines = [
    "Plant Disease Classification - Training Results",
    "=" * 60,
    "",
    f"Dataset: {DATASET_DIR}",
    f"Model: {MODEL_NAME}",
    f"Epochs: {NUM_EPOCHS}",
    "",
    "Validation Metrics:",
    *format_metric_lines(val_metrics),
    "",
    "Test Metrics:",
    *format_metric_lines(test_metrics),
]
with open(f"{OUTPUT_DIR}/metrics.txt", "w", encoding="utf-8") as f:
    f.write("\n".join(report_lines) + "\n")

print("\n" + "=" * 60)
print("🎉 ALL DONE!")
print("=" * 60)
print(f"\nYour trained model is in: {OUTPUT_DIR}")
print(f"Metrics saved to: {OUTPUT_DIR}/metrics.txt")

## Step 8: Download the Trained Model

In [None]:
# Zip the model folder for download.
# Only the final artifacts are included: the Trainer also writes its
# checkpoint-* folders under OUTPUT_DIR, and each holds a full copy of the
# weights plus optimizer state that inference does not need.
import os
import zipfile

from google.colab import files

ARCHIVE_PATH = "plant-disease-model-v2.zip"

print("📦 Creating zip file for download...")
with zipfile.ZipFile(ARCHIVE_PATH, "w", compression=zipfile.ZIP_DEFLATED) as archive:
    for root, dirnames, filenames in os.walk(OUTPUT_DIR):
        # Prune checkpoint directories in place so os.walk never descends into them
        dirnames[:] = [d for d in dirnames if not d.startswith("checkpoint-")]
        for filename in filenames:
            file_path = os.path.join(root, filename)
            # Paths relative to OUTPUT_DIR: the zip root holds the model files directly
            archive.write(file_path, arcname=os.path.relpath(file_path, OUTPUT_DIR))
print("✅ Zip file created!")

# Download the model
print("\n⬇️ Downloading model...")
files.download(ARCHIVE_PATH)
print("✅ Download started! Check your browser downloads.")

## Step 9: Test a Prediction (Optional)

In [None]:
from PIL import Image
import matplotlib.pyplot as plt

from datasets import load_dataset

# Load a test image. `dataset` was replaced by its preprocessed version in
# Step 4 (the "image" column was removed), so reload the raw test split.
raw_test = load_dataset("imagefolder", data_dir=DATASET_DIR, split="test")
sample = raw_test[0]  # first test image
image = sample["image"].convert("RGB")  # already a PIL image, not a file path
true_label = raw_test.features["label"].int2str(sample["label"])

# Preprocess and move the tensors to the device the model's weights are on
inputs = image_processor(images=image, return_tensors="pt")
inputs = {k: v.to(model.device) for k, v in inputs.items()}

# Predict in inference mode, even if no evaluate() call ran after training
model.eval()
with torch.no_grad():
    logits = model(**inputs).logits
    probabilities = torch.nn.functional.softmax(logits, dim=-1)

# Get prediction (batch of one image -> row 0)
predicted_class_idx = logits.argmax(-1).item()
confidence = probabilities[0, predicted_class_idx].item()
predicted_label = model.config.id2label[predicted_class_idx]

# Display
fig, ax = plt.subplots(figsize=(8, 6))
ax.imshow(image)
ax.axis("off")
ax.set_title(
    f"Prediction: {predicted_label} (true: {true_label})\n"
    f"Confidence: {confidence*100:.2f}%",
    fontsize=14,
)
plt.show()

print(f"\n🎯 Prediction: {predicted_label}")
print(f"📊 Confidence: {confidence*100:.2f}%")
print(f"🏷️  True label: {true_label}")

---

## 🎉 Training Complete!

Your model has been trained and is ready to use!

**Next Steps:**
1. Download the model zip file (already started above)
2. Extract it on your local machine
3. Run the Streamlit app: `streamlit run streamlit_app_local.py`
4. Test with your own plant images!

**Expected Results:**
- Accuracy: ~99.7%
- Much better than the previous model's 60-76% accuracy!

---