# 🌱 Cross-Attention Plant Disease Classification

This notebook implements a Cross-Attention Fusion method for multimodal plant disease classification using images and text descriptions.

## 📋 Features
- **Multimodal Architecture**: Combines image and text features using cross-attention
- **Configurable Class Subsets**: Run experiments on specific plant diseases
- **Hyperparameter Tuning**: Systematic optimization of model parameters
- **Multiple Fusion Strategies**: Compare different multimodal fusion approaches
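
For intuition about the cross-attention idea, here is a single-head NumPy sketch in which image patch features act as queries over text token features. It is not the code in `cross_attention_model.py`, which is not shown here. The shapes, the random projection matrices, and the function name `cross_attention` are illustrative assumptions.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max for numerical stability before exponentiating
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def cross_attention(image_feats, text_feats, w_q, w_k, w_v):
    """Single-head cross-attention: image tokens (queries) attend to text tokens (keys/values).

    image_feats: (num_patches, d), text_feats: (num_tokens, d)
    w_q, w_k, w_v: (d, d_head) projection matrices
    Returns the attended features (num_patches, d_head) and the attention weights.
    """
    q = image_feats @ w_q
    k = text_feats @ w_k
    v = text_feats @ w_v
    scores = q @ k.T / np.sqrt(k.shape[-1])  # (num_patches, num_tokens)
    weights = softmax(scores, axis=-1)       # each image token's weights over text tokens sum to 1
    return weights @ v, weights

rng = np.random.default_rng(0)
d, d_head = 16, 8
image_feats = rng.normal(size=(49, d))  # e.g. a 7x7 grid of image patch features
text_feats = rng.normal(size=(12, d))   # e.g. 12 text token embeddings
w_q, w_k, w_v = (rng.normal(size=(d, d_head)) for _ in range(3))

attended, weights = cross_attention(image_feats, text_feats, w_q, w_k, w_v)
print(attended.shape, weights.shape)  # (49, 8) (49, 12)
```

Multi-head attention repeats this with several smaller projections and concatenates the results, which is why the feature dimension usually has to be divisible by the number of heads.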

## 🚀 Quick Start
1. Run the setup cell to install dependencies
2. Upload or download your dataset
3. Configure your experiment
4. Train the model
5. Analyze results

## 📦 Setup and Installation

In [None]:
# Install required packages
!pip install -q transformers torch torchvision torchaudio
!pip install -q gdown pyyaml matplotlib seaborn plotly
!pip install -q scikit-learn opencv-python albumentations
!pip install -q dataclasses-json tqdm

print("✅ Package installation finished (scroll up to check for pip errors)")

In [None]:
# Upload your project files to Colab
from google.colab import files
import os

print("📁 Upload your project files:")
print("Required files:")
print("  - cross_attention_model.py")
print("  - config_manager.py")
print("  - dataset_loader.py")
print("  - trainer.py")
print("  - colab_dataset_setup.py")
print("  - plantwild_prompts.json")
print("\n⚠️ Make sure to upload ALL files before proceeding!")

# Uncomment the next line to upload files interactively
# uploaded = files.upload()
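
Before moving on, you can confirm the uploads landed in the working directory. This is a minimal sketch that assumes the files sit in the current directory; `missing_project_files` is a hypothetical helper, not part of the project code.

```python
import os

# The files listed in the upload cell above
REQUIRED_FILES = [
    "cross_attention_model.py",
    "config_manager.py",
    "dataset_loader.py",
    "trainer.py",
    "colab_dataset_setup.py",
    "plantwild_prompts.json",
]

def missing_project_files(directory="."):
    """Return the required project files that are not present in `directory`."""
    return [name for name in REQUIRED_FILES
            if not os.path.isfile(os.path.join(directory, name))]

missing = missing_project_files()
if missing:
    print(f"⚠️ Missing files: {', '.join(missing)}")
else:
    print("✅ All required project files found")
```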

In [None]:
# Alternative: Clone from GitHub repository
# Replace 'your-username/your-repo' with your actual repository

# !git clone https://github.com/your-username/your-repo.git
# %cd your-repo

# Or download specific files from GitHub raw URLs
import urllib.request

files_to_download = [
    # Replace with your actual raw GitHub URLs
    # ('cross_attention_model.py', 'https://raw.githubusercontent.com/user/repo/main/cross_attention_model.py'),
    # ('config_manager.py', 'https://raw.githubusercontent.com/user/repo/main/config_manager.py'),
    # Add other files...
]

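# Download every (filename, url) pair listed above; does nothing while the list is empty
for filename, url in files_to_download:
    print(f"Downloading {filename}...")
    urllib.request.urlretrieve(url, filename)

if files_to_download:
    print(f"✅ Downloaded {len(files_to_download)} project files")
else:
    print("ℹ️ files_to_download is empty: add raw URLs above, or upload the files instead")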

## 📊 Dataset Setup

Choose one of the following methods to load your dataset:

In [None]:
# Import dataset setup utilities
from colab_dataset_setup import (
    setup_colab_environment,
    load_from_google_drive,
    load_from_url,
    create_sample_for_testing,
    ColabDatasetManager
)

# Set up the environment
dataset_manager = setup_colab_environment()

In [None]:
# Method 1: Load from Google Drive
# 1. Upload your dataset to Google Drive
# 2. Get the file ID from the sharing link
# 3. Replace 'YOUR_FILE_ID' with the actual ID

# GOOGLE_DRIVE_FILE_ID = "YOUR_FILE_ID"  # Replace with your file ID
# dataset_manager = load_from_google_drive(GOOGLE_DRIVE_FILE_ID)

print("🔗 To use Google Drive:")
print("1. Upload dataset.zip to Google Drive")
print("2. Right-click → Share → Copy link (set access to 'Anyone with the link')")
print("3. Copy the file ID from the URL")
print("4. Replace YOUR_FILE_ID above and uncomment the lines")
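
Step 3 can be scripted. The sketch below handles the two common Drive sharing-link shapes (`/file/d/<id>/view` and `?id=<id>`); it is not an exhaustive list, and `extract_drive_file_id` is a hypothetical helper. The ID in the example is a made-up placeholder.

```python
import re
from urllib.parse import parse_qs, urlparse

def extract_drive_file_id(share_url):
    """Pull the file ID out of a Google Drive sharing link, or return None."""
    # Form 1: https://drive.google.com/file/d/<id>/view?usp=sharing
    match = re.search(r"/file/d/([A-Za-z0-9_-]+)", share_url)
    if match:
        return match.group(1)
    # Form 2: https://drive.google.com/open?id=<id>
    query_ids = parse_qs(urlparse(share_url).query).get("id")
    return query_ids[0] if query_ids else None

link = "https://drive.google.com/file/d/1AbC_dEf-123/view?usp=sharing"
print(extract_drive_file_id(link))  # 1AbC_dEf-123
```

Pass the result as `GOOGLE_DRIVE_FILE_ID`.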

In [None]:
# Method 2: Load from direct URL
# If you have a direct download link to your dataset

# DATASET_URL = "https://example.com/dataset.zip"  # Replace with your URL
# dataset_manager = load_from_url(DATASET_URL)

print("🌐 To use direct URL:")
print("1. Host your dataset.zip on a service with direct download")
print("2. Get the direct download URL")
print("3. Replace the URL above and uncomment the lines")
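
To see what a direct-URL loader has to do, here is a minimal standard-library sketch. `load_from_url` in `colab_dataset_setup.py` may work differently; `download_and_extract` is a hypothetical stand-in that assumes the URL serves a `.zip` archive.

```python
import os
import tempfile
import urllib.request
import zipfile

def download_and_extract(url, dest_dir="/content/plant_disease_data"):
    """Download a zip archive from `url`, extract it into `dest_dir`, and return its member names."""
    os.makedirs(dest_dir, exist_ok=True)
    fd, archive_path = tempfile.mkstemp(suffix=".zip")
    os.close(fd)
    try:
        # Passing a filename makes urlretrieve copy the data there, even for file:// URLs
        urllib.request.urlretrieve(url, archive_path)
        with zipfile.ZipFile(archive_path) as archive:
            archive.extractall(dest_dir)
            names = archive.namelist()
    finally:
        os.remove(archive_path)  # always clean up the temporary download
    return names
```

`extractall` writes whatever paths the archive contains, so only use archives from sources you trust.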

In [None]:
# Method 3: Manual upload (for smaller datasets)
import os
import zipfile
from google.colab import files

def manual_dataset_upload(extract_dir="/content/plant_disease_data/"):
    print("📤 Upload your dataset.zip file:")
    uploaded = files.upload()

    zip_names = [name for name in uploaded if name.endswith('.zip')]
    if not zip_names:
        print("❌ No .zip file was uploaded.")
        return dataset_manager

    filename = zip_names[0]  # only the first archive is extracted
    print(f"📦 Extracting {filename}...")
    with zipfile.ZipFile(filename, 'r') as zip_ref:
        zip_ref.extractall(extract_dir)

    # Remove the zip file after extraction
    os.remove(filename)
    print("✅ Dataset extracted successfully!")

    # Verify the dataset (dataset_manager comes from the setup cell above)
    dataset_manager.print_dataset_info()
    return dataset_manager

# Uncomment the next line to upload manually
# dataset_manager = manual_dataset_upload()

In [None]:
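# Method 4: Create sample dataset for testing
# This creates a small dataset for testing your code.
# ⚠️ It reassigns dataset_manager, so skip this cell if you loaded a real dataset with Methods 1-3.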

dataset_manager = create_sample_for_testing()

print("\n⚠️ NOTE: This is a sample dataset for testing only!")
print("Replace with your actual dataset for real experiments.")

In [None]:
# Verify dataset is properly loaded
is_valid, info = dataset_manager.verify_dataset()
dataset_manager.print_dataset_info()

if is_valid:
    print("\n✅ Dataset is ready for training!")
else:
    print("\n❌ Dataset setup incomplete. Please check the steps above.")
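
The checks performed by `verify_dataset()` are defined in `colab_dataset_setup.py`. For an independent look at class balance, the sketch below counts images per class folder. It assumes a class-per-folder layout (`<root>/<class_name>/<image>`), which may not match what `dataset_loader.py` expects; `count_images_per_class` is hypothetical.

```python
import os

IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp")

def count_images_per_class(root):
    """Count image files in each immediate subdirectory of `root` (one folder per class)."""
    counts = {}
    for class_name in sorted(os.listdir(root)):
        class_dir = os.path.join(root, class_name)
        if os.path.isdir(class_dir):  # skip stray files at the top level
            counts[class_name] = sum(
                1 for f in os.listdir(class_dir) if f.lower().endswith(IMAGE_EXTENSIONS)
            )
    return counts

data_root = "/content/plant_disease_data"
if os.path.isdir(data_root):
    for name, n in count_images_per_class(data_root).items():
        flag = "  ⚠️ empty" if n == 0 else ""
        print(f"{name:40s} {n:6d}{flag}")
```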

## ⚙️ Configuration

In [None]:
# Import configuration modules
from config_manager import ModelConfig, ConfigManager, ClassSubsetManager
import yaml

# Create configuration for your experiment
config = ModelConfig(
    # Dataset configuration
    dataset_path="/content/plant_disease_data",  # Colab path
    selected_classes=[],  # Empty = use all classes, or specify subset
    
    # Model architecture
    image_backbone="resnet50",
    text_encoder="bert-base-uncased",
    feature_dim=768,
    fusion_type="adaptive",  # concat, adaptive, bilinear, attention
    num_cross_attention_layers=4,
    num_attention_heads=12,
    
    # Training parameters
    batch_size=32,
    learning_rate=1e-4,
    num_epochs=30,  # Reduce for faster testing
    
    # Colab-specific paths
    output_dir="/content/outputs",
    checkpoint_dir="/content/checkpoints",
    
    # Enable mixed precision for faster training on Colab
    mixed_precision=True
)

# Save configuration
config_manager = ConfigManager()
config_manager.save_config(config, "colab_config.yaml")

print("✅ Configuration created and saved to colab_config.yaml")
print(f"📊 Dataset path: {config.dataset_path}")
print(f"🏗️ Model: {config.image_backbone} + {config.text_encoder}")
print(f"🔄 Fusion: {config.fusion_type}")
print(f"⏱️ Epochs: {config.num_epochs}")
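
`ModelConfig` lives in `config_manager.py`, which is not shown, so whether it validates its fields is unknown. The hypothetical dataclass below sketches two checks worth having: deriving `num_classes` from `selected_classes`, and requiring `feature_dim` to be divisible by `num_attention_heads` (the usual requirement for multi-head attention, e.g. PyTorch's `nn.MultiheadAttention`).

```python
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class ExperimentSettings:
    """Hypothetical stand-in for ModelConfig showing two consistency checks."""
    selected_classes: List[str] = field(default_factory=list)
    num_classes: Optional[int] = None
    feature_dim: int = 768
    num_attention_heads: int = 12

    def __post_init__(self):
        # Keep num_classes in sync with an explicit class subset
        if self.selected_classes:
            if self.num_classes not in (None, len(self.selected_classes)):
                raise ValueError("num_classes does not match len(selected_classes)")
            self.num_classes = len(self.selected_classes)
        # Multi-head attention splits feature_dim evenly across heads
        if self.feature_dim % self.num_attention_heads != 0:
            raise ValueError(
                f"feature_dim={self.feature_dim} is not divisible by "
                f"num_attention_heads={self.num_attention_heads}"
            )

settings = ExperimentSettings(selected_classes=["apple_scab", "apple_rust", "apple_healthy"])
print(settings.num_classes)  # 3
```

`__post_init__` runs only at construction, so assigning fields afterwards bypasses these checks; building a fresh config for each experiment is the safer pattern.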

In [None]:
# Configure class subsets (optional)
class_manager = ClassSubsetManager()

# Option 1: Use predefined subsets
available_subsets = class_manager.get_predefined_subsets()
print("📋 Available predefined subsets:")
for name, classes in available_subsets.items():
    print(f"  {name}: {len(classes)} classes")

# Option 2: Create custom subset
# Example: Apple diseases only
apple_diseases = class_manager.create_subset_by_category(['apple'])
print(f"\n🍎 Apple diseases: {len(apple_diseases)} classes")
print(f"Classes: {apple_diseases[:3]}...")  # Show first 3

# Option 3: Random subset for quick testing
quick_test_classes = class_manager.create_random_subset(5)
print(f"\n🎲 Quick test subset: {len(quick_test_classes)} classes")
print(f"Classes: {quick_test_classes}")

# Update config with selected subset (uncomment to use)
# config.selected_classes = quick_test_classes
# config.num_classes = len(quick_test_classes)
# print(f"\n✅ Updated config to use {len(quick_test_classes)} classes")
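
How `ClassSubsetManager` builds its subsets, and whether `create_random_subset` is seeded, is not shown. That matters because later cells reuse `quick_test_classes`, so an unseeded rerun could change which classes every experiment sees. A hypothetical sketch of both operations, assuming class names of the form `<plant>_<condition>`:

```python
import random

def subset_by_category(class_names, categories):
    """Keep classes whose name starts with one of the given plant categories."""
    prefixes = tuple(f"{c.lower()}_" for c in categories)
    return [name for name in class_names if name.lower().startswith(prefixes)]

def random_subset(class_names, k, seed=0):
    """Pick k distinct classes reproducibly; the same seed gives the same subset."""
    rng = random.Random(seed)
    return sorted(rng.sample(list(class_names), k))

all_classes = ["apple_scab", "apple_rust", "tomato_blight", "tomato_healthy", "grape_rot"]
print(subset_by_category(all_classes, ["apple"]))  # ['apple_scab', 'apple_rust']
print(random_subset(all_classes, 3, seed=42))
```

Whichever subset you choose, set `num_classes = len(subset)` together with `selected_classes`.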

## 🚀 Training

In [None]:
# Quick test with minimal configuration
from trainer import Trainer

# Create a quick test configuration
quick_config = ModelConfig(
    dataset_path="/content/plant_disease_data",
    selected_classes=quick_test_classes[:3],  # Use only 3 classes
    num_classes=3,
    batch_size=16,  # Smaller batch for Colab memory
    learning_rate=1e-4,
    num_epochs=5,   # Just 5 epochs for quick test
    eval_frequency=2,
    output_dir="/content/quick_test_outputs",
    checkpoint_dir="/content/quick_test_checkpoints",
    mixed_precision=True
)

print("🧪 Running quick test with 3 classes and 5 epochs...")
print(f"Classes: {quick_config.selected_classes}")

# Create and run trainer
trainer = Trainer(quick_config)
test_results = trainer.train()

print("\n✅ Quick test completed!")
print(f"Final test accuracy: {test_results['accuracy']:.4f}")
print(f"Final test F1-score: {test_results['macro_f1']:.4f}")

In [None]:
# Full training with your configuration
from trainer import Trainer

print("🚀 Starting full training...")
print(f"Configuration: {config.num_classes} classes, {config.num_epochs} epochs")
print(f"Model: {config.image_backbone} + {config.fusion_type} fusion")

# Create trainer
trainer = Trainer(config)

# Start training
try:
    test_results = trainer.train()
    
    print("\n🎉 Training completed successfully!")
    print("Final Results:")
    print(f"  Accuracy: {test_results['accuracy']:.4f}")
    print(f"  Precision: {test_results['macro_precision']:.4f}")
    print(f"  Recall: {test_results['macro_recall']:.4f}")
    print(f"  F1-Score: {test_results['macro_f1']:.4f}")
    
except Exception as e:
    print(f"❌ Training failed: {e}")
    print("If the GPU ran out of memory, reduce batch_size; num_epochs affects run time, not memory")
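
The trainer reports accuracy and macro-averaged precision, recall, and F1. Macro averaging computes each metric per class and takes the unweighted mean, so rare diseases count as much as common ones. The pure-Python sketch below follows that definition, treating undefined ratios as 0 (comparable to scikit-learn's `average="macro"` with `zero_division=0`); whether `trainer.py` computes its numbers exactly this way is not shown.

```python
from collections import Counter

def classification_metrics(y_true, y_pred):
    """Accuracy plus macro-averaged precision, recall, and F1 (unweighted mean over classes)."""
    labels = sorted(set(y_true) | set(y_pred))
    tp, fp, fn = Counter(), Counter(), Counter()
    for t, p in zip(y_true, y_pred):
        if t == p:
            tp[t] += 1
        else:
            fp[p] += 1  # predicted p, but it was wrong
            fn[t] += 1  # true class t was missed
    precisions, recalls, f1s = [], [], []
    for c in labels:
        prec = tp[c] / (tp[c] + fp[c]) if tp[c] + fp[c] else 0.0
        rec = tp[c] / (tp[c] + fn[c]) if tp[c] + fn[c] else 0.0
        f1 = 2 * prec * rec / (prec + rec) if prec + rec else 0.0
        precisions.append(prec)
        recalls.append(rec)
        f1s.append(f1)
    n = len(labels)
    return {
        "accuracy": sum(tp.values()) / len(y_true),
        "macro_precision": sum(precisions) / n,
        "macro_recall": sum(recalls) / n,
        "macro_f1": sum(f1s) / n,
    }

metrics = classification_metrics([0, 0, 1, 1, 2, 2], [0, 1, 1, 1, 2, 0])
print({k: round(v, 4) for k, v in metrics.items()})
# {'accuracy': 0.6667, 'macro_precision': 0.7222, 'macro_recall': 0.6667, 'macro_f1': 0.6556}
```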

## 🧪 Experiments

In [None]:
# Compare different fusion methods
fusion_methods = ['concat', 'adaptive', 'bilinear', 'attention']
fusion_results = []

# Use a small subset for quick comparison
test_classes = quick_test_classes[:3]

print("🔄 Comparing fusion methods...")
for fusion_type in fusion_methods:
    print(f"\n--- Testing {fusion_type} fusion ---")
    
    # Create configuration for this fusion method
    fusion_config = ModelConfig(
        dataset_path="/content/plant_disease_data",
        selected_classes=test_classes,
        num_classes=len(test_classes),
        fusion_type=fusion_type,
        batch_size=16,
        num_epochs=10,
        output_dir=f"/content/fusion_{fusion_type}",
        checkpoint_dir=f"/content/fusion_{fusion_type}_checkpoints",
        mixed_precision=True
    )
    
    try:
        trainer = Trainer(fusion_config)
        results = trainer.train()
        
        fusion_results.append({
            'fusion_type': fusion_type,
            'accuracy': results['accuracy'],
            'f1_score': results['macro_f1']
        })
        
        print(f"✅ {fusion_type}: Acc={results['accuracy']:.4f}, F1={results['macro_f1']:.4f}")
        
    except Exception as e:
        print(f"❌ {fusion_type} failed: {e}")
        fusion_results.append({
            'fusion_type': fusion_type,
            'accuracy': 0.0,
            'f1_score': 0.0
        })

# Display comparison results
print("\n📊 Fusion Method Comparison:")
print("=" * 50)
for result in fusion_results:
    print(f"{result['fusion_type']:10s}: Acc={result['accuracy']:.4f}, F1={result['f1_score']:.4f}")

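# Find best fusion method among runs that completed (failed runs were recorded as 0.0).
# Note: choosing the winner by test-set accuracy makes that accuracy an optimistic
# estimate; prefer a validation metric for selection if Trainer exposes one.
successful_fusion = [r for r in fusion_results if r['accuracy'] > 0]
if not successful_fusion:
    print("⚠️ All fusion runs failed; falling back to the first method")
best_fusion = max(successful_fusion or fusion_results, key=lambda x: x['accuracy'])
print(f"\n🏆 Best fusion method: {best_fusion['fusion_type']} (Acc={best_fusion['accuracy']:.4f})")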
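
The four `fusion_type` options are implemented in `cross_attention_model.py`, which is not shown. For intuition only, the NumPy sketch below gives one common textbook form of each operator on pooled image and text vectors. The formulas, weights, and function names are assumptions; the project's versions (for example, how `adaptive` gates the modalities) may differ.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8
img = rng.normal(size=d)  # pooled image feature
txt = rng.normal(size=d)  # pooled text feature

def concat_fusion(x, y):
    # Stack both vectors; a downstream linear layer mixes them
    return np.concatenate([x, y])  # shape (2d,)

def adaptive_fusion(x, y, w_gate):
    # A learned per-dimension gate g in (0, 1) decides how much of each modality to keep
    g = 1.0 / (1.0 + np.exp(-(w_gate @ np.concatenate([x, y]))))
    return g * x + (1.0 - g) * y  # shape (d,)

def bilinear_fusion(x, y, w_bil):
    # Output unit k computes x^T W_k y, capturing pairwise feature interactions
    return np.einsum("i,kij,j->k", x, w_bil, y)  # shape (k,)

def attention_fusion(x, y, w_score):
    # Score each modality, softmax the scores, and take the weighted sum
    feats = np.stack([x, y])  # (2, d)
    scores = feats @ w_score  # (2,)
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()
    return weights @ feats  # shape (d,)

w_gate = rng.normal(size=(d, 2 * d))
w_bil = rng.normal(size=(d, d, d))
w_score = rng.normal(size=d)
for name, out in [
    ("concat", concat_fusion(img, txt)),
    ("adaptive", adaptive_fusion(img, txt, w_gate)),
    ("bilinear", bilinear_fusion(img, txt, w_bil)),
    ("attention", attention_fusion(img, txt, w_score)),
]:
    print(f"{name:10s} -> shape {out.shape}")
```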

In [None]:
# Simple random hyperparameter search
import random

random.seed(42)  # make the sampled configurations reproducible across reruns

# Define hyperparameter ranges
hp_ranges = {
    'learning_rate': [1e-5, 5e-5, 1e-4, 5e-4],
    'batch_size': [16, 32],
    'feature_dim': [512, 768],
    'num_attention_heads': [8, 12]
}

num_hp_experiments = 3  # Run 3 hyperparameter configurations
hp_results = []

print(f"🔍 Running {num_hp_experiments} hyperparameter experiments...")

for i in range(num_hp_experiments):
    print(f"\n--- HP Experiment {i+1}/{num_hp_experiments} ---")
    
    # Random hyperparameter selection. Resample until feature_dim is divisible by
    # num_attention_heads: standard multi-head attention (e.g. torch.nn.MultiheadAttention)
    # splits the features evenly across heads, so 512 dims with 12 heads is invalid.
    while True:
        hp_config = {name: random.choice(values) for name, values in hp_ranges.items()}
        if hp_config['feature_dim'] % hp_config['num_attention_heads'] == 0:
            break
    
    print(f"Hyperparameters: {hp_config}")
    
    # Create configuration
    exp_config = ModelConfig(
        dataset_path="/content/plant_disease_data",
        selected_classes=test_classes,
        num_classes=len(test_classes),
        **hp_config,
        num_epochs=8,  # Few epochs to keep the HP search fast
        fusion_type=best_fusion['fusion_type'],  # Use best fusion from previous experiment
        output_dir=f"/content/hp_exp_{i}",
        checkpoint_dir=f"/content/hp_exp_{i}_checkpoints",
        mixed_precision=True
    )
    
    try:
        trainer = Trainer(exp_config)
        results = trainer.train()
        
        hp_results.append({
            'config': hp_config,
            'accuracy': results['accuracy'],
            'f1_score': results['macro_f1']
        })
        
        print(f"✅ Acc={results['accuracy']:.4f}, F1={results['macro_f1']:.4f}")
        
    except Exception as e:
        print(f"❌ Failed: {e}")
        hp_results.append({
            'config': hp_config,
            'accuracy': 0.0,
            'f1_score': 0.0
        })

# Display HP search results
print("\n🎯 Hyperparameter Search Results:")
print("=" * 70)
for i, result in enumerate(hp_results):
    print(f"Experiment {i+1}:")
    print(f"  Config: {result['config']}")
    print(f"  Results: Acc={result['accuracy']:.4f}, F1={result['f1_score']:.4f}")
    print()

# Find best hyperparameters
best_hp = max(hp_results, key=lambda x: x['accuracy'])
print(f"🏆 Best hyperparameters: {best_hp['config']}")
print(f"   Best accuracy: {best_hp['accuracy']:.4f}")
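
Drawing each hyperparameter independently can produce the same configuration twice and, unless filtered, a `feature_dim`/`num_attention_heads` pair that multi-head attention rejects. An alternative sketch: enumerate the full grid, drop the invalid pairs, and sample without replacement. `valid_hp_grid` is a hypothetical helper; the ranges repeat the ones defined above.

```python
import itertools
import math
import random

hp_ranges = {
    'learning_rate': [1e-5, 5e-5, 1e-4, 5e-4],
    'batch_size': [16, 32],
    'feature_dim': [512, 768],
    'num_attention_heads': [8, 12],
}

def valid_hp_grid(ranges):
    """Enumerate every combination, dropping ones where heads do not divide feature_dim."""
    keys = list(ranges)
    grid = [dict(zip(keys, values)) for values in itertools.product(*ranges.values())]
    return [hp for hp in grid if hp['feature_dim'] % hp['num_attention_heads'] == 0]

grid = valid_hp_grid(hp_ranges)
total = math.prod(len(v) for v in hp_ranges.values())
print(f"{len(grid)} valid configurations out of {total}")  # 24 valid configurations out of 32

# Three distinct configurations, reproducible via the seed
configs_to_try = random.Random(0).sample(grid, k=3)
for hp in configs_to_try:
    print(hp)
```

Each sampled dict can be unpacked into `ModelConfig` with `**hp`, the same way `hp_config` is used in the search loop.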

## 📈 Results and Visualization

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Visualize fusion method comparison
if fusion_results:
    methods = [r['fusion_type'] for r in fusion_results]
    accuracies = [r['accuracy'] for r in fusion_results]
    f1_scores = [r['f1_score'] for r in fusion_results]
    
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
    
    # Accuracy comparison
    ax1.bar(methods, accuracies, color='skyblue', alpha=0.7)
    ax1.set_title('Fusion Methods - Accuracy Comparison')
    ax1.set_ylabel('Accuracy')
    ax1.set_ylim(0, 1)
    for i, v in enumerate(accuracies):
        ax1.text(i, v + 0.01, f'{v:.3f}', ha='center')
    
    # F1-score comparison
    ax2.bar(methods, f1_scores, color='lightcoral', alpha=0.7)
    ax2.set_title('Fusion Methods - F1-Score Comparison')
    ax2.set_ylabel('F1-Score')
    ax2.set_ylim(0, 1)
    for i, v in enumerate(f1_scores):
        ax2.text(i, v + 0.01, f'{v:.3f}', ha='center')
    
    plt.tight_layout()
    plt.savefig('/content/fusion_comparison.png', dpi=300, bbox_inches='tight')
    plt.show()

# Visualize hyperparameter search results
if hp_results:
    learning_rates = [r['config']['learning_rate'] for r in hp_results]
    accuracies = [r['accuracy'] for r in hp_results]
    
    plt.figure(figsize=(8, 6))
    plt.scatter(learning_rates, accuracies, s=100, alpha=0.7, c='green')
    plt.xscale('log')
    plt.xlabel('Learning Rate')
    plt.ylabel('Accuracy')
    plt.title('Hyperparameter Search: Learning Rate vs Accuracy')
    plt.grid(True, alpha=0.3)
    
    for i, (lr, acc) in enumerate(zip(learning_rates, accuracies)):
        plt.annotate(f'{acc:.3f}', (lr, acc), xytext=(5, 5), textcoords='offset points')
    
    plt.savefig('/content/hyperparameter_search.png', dpi=300, bbox_inches='tight')
    plt.show()

print("📊 Visualizations saved to /content/")

In [None]:
# Download results and model checkpoints
from google.colab import files
import glob
import zipfile
import os

def create_results_zip():
    """Create a zip file with the config, plots, checkpoints, and output logs."""

    zip_filename = "cross_attention_results.zip"

    def add_tree(zipf, directory, extensions=None):
        # Add every file under `directory` (optionally filtered by extension),
        # stored under its path relative to /content; missing directories are skipped
        for root, _dirs, filenames in os.walk(directory):
            for name in filenames:
                if extensions is None or name.endswith(extensions):
                    file_path = os.path.join(root, name)
                    zipf.write(file_path, os.path.relpath(file_path, "/content"))

    with zipfile.ZipFile(zip_filename, 'w', zipfile.ZIP_DEFLATED) as zipf:
        # Add configuration files
        if os.path.exists("colab_config.yaml"):
            zipf.write("colab_config.yaml")

        # Add result visualizations
        for img_file in ["/content/fusion_comparison.png", "/content/hyperparameter_search.png"]:
            if os.path.exists(img_file):
                zipf.write(img_file, os.path.basename(img_file))

        # Add model checkpoints from the main and quick-test runs
        for checkpoint_dir in ["/content/checkpoints", "/content/quick_test_checkpoints"]:
            add_tree(zipf, checkpoint_dir)

        # Add output logs, including those from the fusion and HP sweep runs
        output_dirs = ["/content/outputs", "/content/quick_test_outputs"]
        output_dirs += sorted(glob.glob("/content/fusion_*")) + sorted(glob.glob("/content/hp_exp_*"))
        for output_dir in output_dirs:
            add_tree(zipf, output_dir, extensions=('.json', '.txt', '.csv', '.png'))

    return zip_filename

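# Create and download results
print("📦 Creating results package...")
results_zip = create_results_zip()

# The zip file always exists at this point, so check whether anything was added to it
with zipfile.ZipFile(results_zip) as zipf:
    num_entries = len(zipf.namelist())

if num_entries:
    print(f"✅ Results package created: {results_zip} ({num_entries} files)")
    print(f"📥 Downloading {results_zip}...")
    files.download(results_zip)
else:
    print("❌ No results to download")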

# Also create a summary report
summary_report = "experiment_summary.txt"
with open(summary_report, 'w') as f:
    f.write("Cross-Attention Plant Disease Classification - Experiment Summary\n")
    f.write("=" * 70 + "\n\n")
    
    f.write(f"Dataset: {config.dataset_path}\n")
    f.write(f"Classes: {config.num_classes}\n")
    f.write(f"Selected Classes: {config.selected_classes}\n\n")
    
    if fusion_results:
        f.write("Fusion Method Comparison:\n")
        for result in fusion_results:
            f.write(f"  {result['fusion_type']:10s}: Acc={result['accuracy']:.4f}, F1={result['f1_score']:.4f}\n")
        f.write(f"  Best: {best_fusion['fusion_type']} (Acc={best_fusion['accuracy']:.4f})\n\n")
    
    if hp_results:
        f.write("Hyperparameter Search Results:\n")
        for i, result in enumerate(hp_results):
            f.write(f"  Config {i+1}: {result['config']}\n")
            f.write(f"    Results: Acc={result['accuracy']:.4f}, F1={result['f1_score']:.4f}\n")
        f.write(f"  Best: {best_hp['config']} (Acc={best_hp['accuracy']:.4f})\n")

files.download(summary_report)
print(f"📄 Downloaded summary report: {summary_report}")

## 🎯 Next Steps

1. **Scale Up**: Use your full dataset with all classes
2. **Extended Training**: Increase epochs for better convergence
3. **Advanced Hyperparameter Search**: Use more configurations
4. **Model Ensemble**: Combine multiple fusion methods
5. **Attention Visualization**: Enable attention map saving
6. **Comparison with MVPDR**: Implement baseline comparison

## 📝 Tips for Google Colab

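- **GPU Runtime**: Make sure you're using a GPU runtime (Runtime → Change runtime type → choose a GPU as the hardware accelerator)
- **Memory Management**: Reduce batch_size if you get out-of-memory (OOM) errors
- **Session Limits**: Free Colab runtimes last at most 12 hours and can disconnect sooner when idle, so save checkpoints frequently
- **Data Persistence**: Mount Google Drive to save results permanently; files under /content are lost when the runtime resets
- **Mixed Precision**: Keep mixed_precision=True on GPU runtimes; on GPUs with tensor cores (such as the T4) it usually speeds up training and reduces memory use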
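
If you would rather not retune batch_size by hand after an OOM error, a retry loop can halve it automatically. This is a sketch around a hypothetical `train_fn(batch_size)` callable; in practice it might build a `ModelConfig` with that batch size and run `Trainer(...).train()`. It relies on PyTorch's CUDA OOM error being a `RuntimeError` whose message contains "out of memory". A real version should also release the failed trainer and call `torch.cuda.empty_cache()` before retrying (not shown).

```python
def train_with_batch_backoff(train_fn, batch_size, min_batch_size=4):
    """Run train_fn(batch_size), halving batch_size after each out-of-memory error.

    Returns (result, batch_size_used). Errors that are not OOM are re-raised.
    """
    last_tried = batch_size
    while batch_size >= min_batch_size:
        last_tried = batch_size
        try:
            return train_fn(batch_size), batch_size
        except RuntimeError as err:
            # PyTorch's CUDA OOM error subclasses RuntimeError and says "out of memory"
            if "out of memory" not in str(err):
                raise
            print(f"OOM at batch_size={batch_size}; retrying with {batch_size // 2}")
            batch_size //= 2
    raise RuntimeError(f"Still out of memory at every batch size down to {last_tried}")

# Hypothetical training function that only fits in memory at batch_size <= 8
def fake_train(bs):
    if bs > 8:
        raise RuntimeError("CUDA out of memory. Tried to allocate 2.00 GiB")
    return {"accuracy": 0.9}

result, used = train_with_batch_backoff(fake_train, 32)
print(used)  # 8
```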

## 🔗 Useful Commands

```python
# Check GPU availability
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"GPU count: {torch.cuda.device_count()}")

# Monitor GPU memory
!nvidia-smi

# Mount Google Drive for persistent storage
from google.colab import drive
drive.mount('/content/drive')
```