In [None]:

import os
import random
import warnings
from collections import defaultdict
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.model_selection import train_test_split
import itertools

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.cuda.amp import GradScaler, autocast
from torchvision import datasets, transforms
import timm

# Albumentations performs its "newer version available" check while it is
# being imported, so the opt-out variable must already be in the environment
# before the import statement below runs.
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'

# Albumentations (and the OpenCV loader it relies on here) is optional; the
# flag tells the rest of the file which transform/loader path to take.
try:
    import albumentations as A
    from albumentations.pytorch import ToTensorV2
    import cv2
except ImportError:
    print("  Albumentations not found. Using basic transforms.")
    USE_ALBUMENTATIONS = False
else:
    USE_ALBUMENTATIONS = True

# Suppress warnings
warnings.filterwarnings('ignore')

print(" Plant Disease Detection Starting...")
print(f" Device: {'CUDA' if torch.cuda.is_available() else 'CPU'}")
print(f" PyTorch version: {torch.__version__}")

# ========================== CONFIGURATION ==========================
class Config:
    """Static settings for the run: dataset search paths and hyper-parameters."""

    # Candidate dataset roots, probed in this order.
    # CHANGE THESE TO MATCH YOUR DATA.
    DATA_PATHS = [
        '/home/siham/Bureau/xavier/PlantVillage',
        '/home/siham/Bureau/xavier/PlantVillage/color',
        './plantvillage',
        './plantvillage/color',
        './PlantVillage',
        '/content/plantvillage',
        '/content/plantvillage/color',
        './data',
        './dataset',
    ]

    # Reproducibility and hardware.
    SEED = 42
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    NUM_WORKERS = 2

    # Optimisation schedule.
    NUM_EPOCHS = 5  # kept short: quick demo
    BASE_LR = 1e-4

    # Input pipeline.
    BATCH_SIZE = 8  # small batch for compatibility
    IMG_SIZE = 224  # side length images are resized to


# Single shared settings object read by the functions in this file.
config = Config()

def seed_everything(seed=42):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Integer seed handed to every generator.

    The CUDA generators are seeded as well, but only when CUDA is available.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

# Seed the Python, NumPy and PyTorch RNGs once, as soon as the file is run,
# using the seed from the shared configuration.
seed_everything(config.SEED)

# ========================== FIND DATASET ==========================
def find_dataset(data_paths=None):
    """Return the first candidate directory that looks like a class-per-folder dataset.

    A directory qualifies when it contains at least two non-hidden
    subdirectories and the first three of them (in ``os.listdir`` order)
    together hold more than 10 ``.jpg``/``.jpeg``/``.png`` files.

    Args:
        data_paths: Optional iterable of directories to probe, in priority
            order. Defaults to ``config.DATA_PATHS``.

    Returns:
        The first qualifying path, or ``None`` (after printing setup
        instructions) when no candidate qualifies.
    """
    candidates = config.DATA_PATHS if data_paths is None else list(data_paths)
    image_suffixes = ('.jpg', '.jpeg', '.png')

    print("\n🔍 Looking for dataset...")

    for path in candidates:
        # A missing path or a plain file cannot hold class folders.
        if not os.path.isdir(path):
            continue

        try:
            subdirs = [
                entry for entry in os.listdir(path)
                if not entry.startswith('.')
                and os.path.isdir(os.path.join(path, entry))
            ]
            if len(subdirs) < 2:  # need at least 2 classes
                continue

            # Only the first three class folders are sampled; that is enough
            # to tell an image dataset from an unrelated directory.
            total_images = 0
            for subdir in subdirs[:3]:
                total_images += sum(
                    name.lower().endswith(image_suffixes)
                    for name in os.listdir(os.path.join(path, subdir))
                )
        except OSError:
            # Unreadable directory (permissions, entry removed mid-scan, ...):
            # move on to the next candidate. Interrupts such as Ctrl-C are
            # deliberately not caught here.
            continue

        if total_images > 10:  # found a valid dataset
            print(f"✅ Found dataset at: {path}")
            print(f"   Classes found: {len(subdirs)}")
            print(f"   Sample classes: {subdirs[:5]}")
            return path

    print("❌ No dataset found! Please:")
    print("1. Download PlantVillage dataset")
    print("2. Extract it to one of these locations:")
    for path in candidates:
        print(f"   - {path}")
    print("3. Make sure it has subfolders for each disease class")
    return None

# ========================== TRANSFORMS ==========================
def get_transforms():
    """Build the ``(train_transform, val_transform)`` preprocessing pipelines.

    Albumentations pipelines are returned when ``USE_ALBUMENTATIONS`` is true,
    torchvision pipelines otherwise. In both cases the training pipeline adds
    random flips, rotation and colour jitter, while the validation pipeline
    only resizes, normalises and converts to a tensor.
    """
    side = config.IMG_SIZE
    # Per-channel normalisation statistics (presumably the usual ImageNet
    # values expected by the pretrained backbone).
    norm_mean = [0.485, 0.456, 0.406]
    norm_std = [0.229, 0.224, 0.225]

    if not USE_ALBUMENTATIONS:
        augmenting = transforms.Compose([
            transforms.Resize((side, side)),
            transforms.RandomHorizontalFlip(0.5),
            transforms.RandomVerticalFlip(0.3),
            transforms.RandomRotation(30),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            transforms.ToTensor(),
            transforms.Normalize(norm_mean, norm_std),
        ])
        deterministic = transforms.Compose([
            transforms.Resize((side, side)),
            transforms.ToTensor(),
            transforms.Normalize(norm_mean, norm_std),
        ])
        return augmenting, deterministic

    augmenting = A.Compose([
        A.Resize(side, side),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.3),
        A.RandomRotate90(p=0.5),
        A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.5),
        A.Normalize(mean=norm_mean, std=norm_std),
        ToTensorV2(),
    ])
    deterministic = A.Compose([
        A.Resize(side, side),
        A.Normalize(mean=norm_mean, std=norm_std),
        ToTensorV2(),
    ])
    return augmenting, deterministic

# ========================== CUSTOM DATASET ==========================
class PlantDataset(Dataset):
    """Dataset over parallel lists of image file paths and labels.

    Images are decoded on access. When ``use_albumentations`` is set and the
    module-level ``USE_ALBUMENTATIONS`` flag is true, files are read with
    OpenCV as RGB arrays and the transform is called as
    ``transform(image=...)``; otherwise files are read with PIL and the
    transform is called positionally. A file that cannot be read is replaced
    by a black 224x224 RGB image so one bad file does not abort an epoch.

    Args:
        image_paths: Sequence of image file paths.
        labels: Sequence of labels, aligned index-by-index with ``image_paths``.
        transform: Optional callable applied to each decoded image.
        use_albumentations: Request the OpenCV/Albumentations code path.
    """

    def __init__(self, image_paths, labels, transform=None, use_albumentations=False):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform
        self.use_albumentations = use_albumentations

    def __len__(self):
        return len(self.image_paths)

    @staticmethod
    def _read_rgb_array(image_path):
        """Read ``image_path`` with OpenCV as an RGB array, or a black placeholder."""
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread reports failure by returning None rather than raising.
            return np.zeros((224, 224, 3), dtype=np.uint8)
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    @staticmethod
    def _read_pil_image(image_path):
        """Read ``image_path`` with PIL as an RGB image, or a black placeholder."""
        from PIL import Image
        try:
            return Image.open(image_path).convert('RGB')
        except Exception:
            # Best-effort fallback for missing or corrupt files. Catching
            # Exception (not a bare except) lets KeyboardInterrupt and
            # SystemExit still stop the run.
            return Image.new('RGB', (224, 224), color='black')

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx``, with the transform applied if set."""
        image_path = self.image_paths[idx]
        label = self.labels[idx]

        if self.use_albumentations and USE_ALBUMENTATIONS:
            image = self._read_rgb_array(image_path)
            if self.transform:
                # Albumentations takes keyword input and returns a dict.
                image = self.transform(image=image)['image']
        else:
            image = self._read_pil_image(image_path)
            if self.transform:
                image = self.transform(image)

        return image, label

# ========================== MODEL ==========================
def create_model(num_classes):
    """Build a pretrained ``efficientnet_b0`` from timm with ``num_classes`` outputs.

    Args:
        num_classes: Number of output classes passed to ``timm.create_model``.

    Returns:
        The model object created by timm.
    """
    print(f"🏗️  Creating EfficientNet-B0 model for {num_classes} classes...")
    return timm.create_model(
        'efficientnet_b0',
        pretrained=True,
        num_classes=num_classes,
    )

# ========================== TRAINING FUNCTIONS ==========================
def train_epoch(model, train_loader, criterion, optimizer, device, epoch):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: Module to train; switched to training mode here.
        train_loader: Sized iterable yielding ``(images, labels)`` batches.
        criterion: Loss callable. It is assumed to return the mean loss over
            the batch (the default reduction of ``nn.CrossEntropyLoss``).
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.
        epoch: Epoch number, used only in the progress output.

    Returns:
        ``(epoch_loss, epoch_acc)``: the mean loss per sample and the fraction
        of correctly classified samples. Both are ``0.0`` when the loader
        yields no samples.
    """
    model.train()
    loss_sum = 0.0  # sum of per-sample losses seen so far
    correct = 0
    total = 0
    num_batches = len(train_loader)

    print(f"\n Training Epoch {epoch}...")

    for batch_idx, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)
        batch_size = labels.size(0)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        # Weight each batch mean by its size: averaging batch means directly
        # would over-weight a smaller final batch.
        loss_sum += batch_loss * batch_size
        total += batch_size
        correct += outputs.argmax(dim=1).eq(labels).sum().item()

        if batch_idx % 10 == 0 and total:
            print(f'  Batch {batch_idx:3d}/{num_batches} | '
                  f'Loss: {batch_loss:.4f} | '
                  f'Acc: {100.*correct/total:.1f}%')

    # Guard against an empty loader instead of dividing by zero.
    epoch_loss = loss_sum / total if total else 0.0
    epoch_acc = correct / total if total else 0.0
    print(f"📊 Train Results: Loss={epoch_loss:.4f}, Accuracy={epoch_acc:.4f}")
    return epoch_loss, epoch_acc

def validate_epoch(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    Args:
        model: Module to evaluate; switched to evaluation mode here.
        val_loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss callable. It is assumed to return the mean loss over
            the batch (the default reduction of ``nn.CrossEntropyLoss``).
        device: Device the batches are moved to.

    Returns:
        ``(epoch_loss, epoch_acc, all_preds, all_labels)``: the mean loss per
        sample, the fraction of correct predictions, and the predicted and
        true labels in loader order. The two metrics are ``0.0`` when the
        loader yields no samples.
    """
    model.eval()
    loss_sum = 0.0  # sum of per-sample losses seen so far
    correct = 0
    total = 0
    all_preds = []
    all_labels = []

    print(" Validating...")

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            batch_size = labels.size(0)

            outputs = model(images)
            loss = criterion(outputs, labels)

            # Weight each batch mean by its size: averaging batch means
            # directly would over-weight a smaller final batch.
            loss_sum += loss.item() * batch_size
            predicted = outputs.argmax(dim=1)
            total += batch_size
            correct += predicted.eq(labels).sum().item()

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Guard against an empty loader instead of dividing by zero.
    epoch_loss = loss_sum / total if total else 0.0
    epoch_acc = correct / total if total else 0.0
    print(f"📊 Val Results: Loss={epoch_loss:.4f}, Accuracy={epoch_acc:.4f}")
    return epoch_loss, epoch_acc, all_preds, all_labels

# ========================== PLOTTING FUNCTIONS ==========================
def plot_training_curves(history):
    """Draw loss and accuracy curves side by side, save them and show them.

    Args:
        history: Mapping with the keys ``'train_loss'``, ``'val_loss'``,
            ``'train_acc'`` and ``'val_acc'``, each holding one value per epoch.

    The figure is written to ``training_curves.png`` in the working directory.
    """
    # One row per panel:
    # (train key, val key, train legend, val legend, title, y-axis label)
    panels = (
        ('train_loss', 'val_loss', 'Train Loss', 'Val Loss',
         'Training & Validation Loss', 'Loss'),
        ('train_acc', 'val_acc', 'Train Accuracy', 'Val Accuracy',
         'Training & Validation Accuracy', 'Accuracy'),
    )

    plt.figure(figsize=(12, 4))

    for position, panel in enumerate(panels, start=1):
        train_key, val_key, train_legend, val_legend, title, y_label = panel
        plt.subplot(1, 2, position)
        plt.plot(history[train_key], label=train_legend, color='blue')
        plt.plot(history[val_key], label=val_legend, color='red')
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        plt.legend()
        plt.grid(True)

    plt.tight_layout()
    plt.savefig('training_curves.png', dpi=300, bbox_inches='tight')
    plt.show()
    print("📈 Training curves saved as 'training_curves.png'")

def plot_confusion_matrix(y_true, y_pred, class_names):
    """Draw an annotated confusion matrix, save it and show it.

    Args:
        y_true: True labels, assumed to be integer indices into ``class_names``.
        y_pred: Predicted labels, same encoding as ``y_true``.
        class_names: Class names used as tick labels; position ``i`` names
            label ``i``.

    The figure is written to ``confusion_matrix.png`` in the working directory.
    """
    # Pin the label set so the matrix always has one row and column per class
    # name, even when a class is missing from both y_true and y_pred;
    # otherwise the tick labels would be shifted onto the wrong cells.
    cm = confusion_matrix(y_true, y_pred, labels=np.arange(len(class_names)))

    plt.figure(figsize=(10, 8))
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title('Confusion Matrix')
    plt.colorbar()

    tick_marks = np.arange(len(class_names))
    plt.xticks(tick_marks, class_names, rotation=45)
    plt.yticks(tick_marks, class_names)

    # Write each count into its cell, switching to white text on dark cells.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], 'd'),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.tight_layout()
    plt.savefig('confusion_matrix.png', dpi=300, bbox_inches='tight')
    plt.show()
    print("📊 Confusion matrix saved as 'confusion_matrix.png'")

# ========================== MAIN FUNCTION ==========================
def main():
    """Run the whole pipeline: locate the data, train, checkpoint and report.

    Steps, in order: find a dataset directory, enumerate it with
    ``ImageFolder``, make a stratified 80/20 split, build the loaders, train
    an EfficientNet-B0 for ``config.NUM_EPOCHS`` epochs with AdamW and cosine
    annealing, save the best-scoring weights to ``best_model.pth``, then plot
    the training curves, the confusion matrix and print a classification
    report.

    Returns:
        None. Returns early, after printing a message, when no dataset is
        found or it cannot be loaded.
    """
    print("=" * 60)
    print("🌱 PLANT DISEASE DETECTION - DEEP LEARNING PIPELINE")
    print("=" * 60)

    # Find dataset
    dataset_path = find_dataset()
    if dataset_path is None:
        print("\n Cannot proceed without dataset. Please check the instructions above.")
        return

    # Load dataset. ImageFolder is only used to enumerate (path, label) pairs;
    # the images themselves are decoded later by PlantDataset.
    print(f"\n Loading dataset from: {dataset_path}")
    try:
        full_dataset = datasets.ImageFolder(dataset_path)
    except Exception as e:
        # Top-level boundary: report the problem and stop instead of crashing.
        print(f"❌ Error loading dataset: {e}")
        return

    class_names = full_dataset.classes
    num_classes = len(class_names)
    all_image_paths = [path for path, _ in full_dataset.samples]
    all_labels = [label for _, label in full_dataset.samples]

    print(" Dataset loaded successfully!")
    print(f"    Total images: {len(full_dataset)}")
    print(f"     Number of classes: {num_classes}")
    print(f"    Classes: {class_names}")

    # Show class distribution
    class_counts = np.bincount(all_labels, minlength=num_classes)
    print("\n📈 Class Distribution:")
    for name, count in zip(class_names, class_counts):
        print(f"   {name}: {count} images")

    # Create train/validation split
    print("\n✂️  Creating train/validation split (80/20)...")
    train_paths, val_paths, train_labels, val_labels = train_test_split(
        all_image_paths, all_labels, test_size=0.2,
        stratify=all_labels, random_state=config.SEED
    )

    print(f"    Training images: {len(train_paths)}")
    print(f"    Validation images: {len(val_paths)}")

    # Get transforms
    train_transform, val_transform = get_transforms()
    print(f"    Using {'Albumentations' if USE_ALBUMENTATIONS else 'Torchvision'} transforms")

    # Create datasets
    train_dataset = PlantDataset(train_paths, train_labels, train_transform, USE_ALBUMENTATIONS)
    val_dataset = PlantDataset(val_paths, val_labels, val_transform, USE_ALBUMENTATIONS)

    # Create data loaders. Pinned host memory only helps host-to-GPU copies,
    # so it is requested only when training on CUDA.
    pin_memory = config.DEVICE == 'cuda'
    train_loader = DataLoader(
        train_dataset, batch_size=config.BATCH_SIZE,
        shuffle=True, num_workers=config.NUM_WORKERS, pin_memory=pin_memory
    )
    val_loader = DataLoader(
        val_dataset, batch_size=config.BATCH_SIZE,
        shuffle=False, num_workers=config.NUM_WORKERS, pin_memory=pin_memory
    )

    print(f"    Batch size: {config.BATCH_SIZE}")
    print(f"    Training batches: {len(train_loader)}")
    print(f"   Validation batches: {len(val_loader)}")

    # Create model
    model = create_model(num_classes).to(config.DEVICE)

    # Loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.BASE_LR, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config.NUM_EPOCHS)

    print("    Loss function: CrossEntropyLoss")
    print(f"    Optimizer: AdamW (lr={config.BASE_LR})")

    # Training loop
    print(f"\n Starting training for {config.NUM_EPOCHS} epochs...")
    print("=" * 60)

    history = defaultdict(list)
    best_val_acc = 0.0
    # Predictions and targets used for the final report. They follow the best
    # epoch so the report describes the weights saved in best_model.pth.
    report_preds = None
    report_targets = None

    for epoch in range(1, config.NUM_EPOCHS + 1):
        print(f"\n🔄 EPOCH {epoch}/{config.NUM_EPOCHS}")
        print("-" * 40)

        # Train
        train_loss, train_acc = train_epoch(
            model, train_loader, criterion, optimizer, config.DEVICE, epoch
        )

        # Validate. Separate names keep the split's val_labels untouched.
        val_loss, val_acc, epoch_preds, epoch_targets = validate_epoch(
            model, val_loader, criterion, config.DEVICE
        )

        # Update learning rate
        scheduler.step()

        # Save metrics
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

        is_best = val_acc > best_val_acc
        if is_best or report_preds is None:
            report_preds, report_targets = epoch_preds, epoch_targets

        # Save best model
        if is_best:
            best_val_acc = val_acc
            torch.save({
                'model_state_dict': model.state_dict(),
                'class_names': class_names,
                'num_classes': num_classes,
                'val_acc': val_acc,
                'epoch': epoch
            }, 'best_model.pth')
            print(f" New best model saved! Accuracy: {val_acc:.4f}")

        print(f" Epoch Summary: Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f} | Best: {best_val_acc:.4f}")

    # Final results
    print("\n" + "=" * 60)
    print(" TRAINING COMPLETED!")
    print("=" * 60)
    print(f" Best Validation Accuracy: {best_val_acc:.4f} ({best_val_acc*100:.1f}%)")
    print(" Best model saved as: best_model.pth")

    # Generate visualizations
    print("\n Generating visualizations...")

    # Plot training curves
    plot_training_curves(history)

    # With zero epochs there are no predictions to summarise.
    if report_preds is not None:
        # Plot confusion matrix
        plot_confusion_matrix(report_targets, report_preds, class_names)

        # Classification report
        print("\n Classification Report:")
        print(classification_report(report_targets, report_preds, target_names=class_names))

    print("\n All done! Check the generated plots and saved model.")
    print(" Files generated:")
    print("   - best_model.pth (trained model)")
    print("   - training_curves.png (loss & accuracy plots)")
    print("   - confusion_matrix.png (confusion matrix)")

# ========================== RUN THE PROGRAM ==========================

# Run the pipeline only when this file is executed directly (a script run or
# a notebook cell, where __name__ is "__main__"), so importing it to reuse
# its functions does not start a training run.
if __name__ == "__main__":
    main()

 Plant Disease Detection Starting...
 Device: CPU
 PyTorch version: 2.8.0+cpu
🌱 PLANT DISEASE DETECTION - DEEP LEARNING PIPELINE

🔍 Looking for dataset...
✅ Found dataset at: /home/siham/Bureau/xavier/PlantVillage
   Classes found: 15
   Sample classes: ['Tomato_Leaf_Mold', 'Tomato__Tomato_YellowLeaf__Curl_Virus', 'Tomato_Early_blight', 'Pepper__bell___Bacterial_spot', 'Tomato__Target_Spot']

 Loading dataset from: /home/siham/Bureau/xavier/PlantVillage
 Dataset loaded successfully!
    Total images: 20638
     Number of classes: 15
    Classes: ['Pepper__bell___Bacterial_spot', 'Pepper__bell___healthy', 'Potato___Early_blight', 'Potato___Late_blight', 'Potato___healthy', 'Tomato_Bacterial_spot', 'Tomato_Early_blight', 'Tomato_Late_blight', 'Tomato_Leaf_Mold', 'Tomato_Septoria_leaf_spot', 'Tomato_Spider_mites_Two_spotted_spider_mite', 'Tomato__Target_Spot', 'Tomato__Tomato_YellowLeaf__Curl_Virus', 'Tomato__Tomato_mosaic_virus', 'Tomato_healthy']

📈 Class Distribution:
   Pepper__bell__