In [None]:
# ============================================
# CELL 1: Check GPU and Datasets
# ============================================
# Environment sanity check: report CUDA availability (device name and memory)
# and list every dataset mounted under /kaggle/input with its file count.
import os
import sys
from pathlib import Path

print("="*70)
print("üîç CHECKING KAGGLE ENVIRONMENT")
print("="*70)

# --- GPU ---
import torch
cuda_ok = torch.cuda.is_available()
print(f"\nüéÆ GPU Available: {cuda_ok}")
if not cuda_ok:
    print("‚ö†Ô∏è  NO GPU! Go to Settings ‚Üí Accelerator ‚Üí GPU")
else:
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
    gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"   Memory: {gpu_mem_gb:.1f} GB")

# --- Attached datasets ---
print("\nüì¶ Checking Datasets:")
kaggle_input = '/kaggle/input'
if not os.path.exists(kaggle_input):
    print("‚ùå Not on Kaggle!")
else:
    datasets = os.listdir(kaggle_input)
    for ds in datasets:
        # Count every file at any depth inside this dataset's folder.
        dataset_dir = os.path.join(kaggle_input, ds)
        n_files = sum(len(filenames) for _, _, filenames in os.walk(dataset_dir))
        print(f"  ‚úÖ {ds}: {n_files:,} files")
    print(f"\nüìä Total datasets: {len(datasets)}")
    if len(datasets) < 3:
        print("\n‚ö†Ô∏è  Add more datasets! Click '+ Add Data' on the right")

print("\n" + "="*70)

In [None]:
# ============================================
# CELL 2: Install Packages & Setup
# ============================================
!pip install timm -q

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
import json
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms
import timm

# ‚ö° Speed optimizations
torch.backends.cudnn.benchmark = True
from torch.cuda.amp import GradScaler, autocast
scaler = GradScaler()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"‚úÖ Setup complete! Using: {device}")

In [None]:
# ============================================
# CELL 3: Find and Load Datasets
# ============================================
def find_image_classes(base_path, max_depth=5, min_files=50):
    """Collect image files grouped by the name of the directory holding them.

    Walks ``base_path`` in sorted order, so repeated runs return the same
    paths in the same order. Only directories at most ``max_depth`` levels
    below it are considered. Every directory that directly contains more than
    ``min_files`` images becomes a class named after that directory.
    Same-named directories in different places (e.g. ``train/Healthy`` and
    ``test/Healthy``) are merged into one class.

    Args:
        base_path: Root directory to search (a trailing separator is fine).
        max_depth: Deepest directory level, counted in path components below
            ``base_path`` (``base_path`` itself is level 0), whose files count.
        min_files: A directory must hold strictly more than this many images
            to be treated as a class.

    Returns:
        dict mapping class (directory) name -> list of image file paths.
    """
    # Extensions are compared lower-cased, so '.JPG', '.Jpg', '.PNG', ... all match.
    image_extensions = {'.jpg', '.jpeg', '.png'}
    class_data = {}

    for root, dirs, files in os.walk(base_path):
        # relpath makes the depth independent of a trailing separator on base_path.
        rel = os.path.relpath(root, base_path)
        depth = 0 if rel == os.curdir else rel.count(os.sep) + 1
        if depth > max_depth:
            dirs[:] = []
            continue
        if depth == max_depth:
            # Nothing deeper would be used, so stop os.walk from descending.
            dirs[:] = []
        else:
            # In-place sort fixes os.walk's visiting order -> reproducible output.
            dirs.sort()

        image_files = sorted(f for f in files if Path(f).suffix.lower() in image_extensions)
        if len(image_files) > min_files:
            class_name = Path(root).name
            class_data.setdefault(class_name, []).extend(
                os.path.join(root, f) for f in image_files
            )

    return class_data

# Auto-detect datasets.
# Each mounted dataset gets a crop label from keywords in its folder name.
# Datasets that end up with the same label (including the catch-all 'other')
# are MERGED into one {class_name: [paths]} mapping. Plain assignment would
# let a later dataset overwrite an earlier one and silently drop its images.
all_classes = {}
print("="*70)
print("üìä ANALYZING DATASETS")
print("="*70)

# Sorted so the resulting class/path order (and thus the seeded split) is reproducible.
for dataset_name in sorted(os.listdir('/kaggle/input')):
    path = f'/kaggle/input/{dataset_name}'
    print(f"\nüîç Analyzing {dataset_name}...")
    classes = find_image_classes(path)
    if classes:
        # Determine crop type from dataset name; 'plant'/'village' both mean PlantVillage.
        crop = 'other'
        for c in ['rice', 'cotton', 'wheat', 'mango', 'plant', 'village']:
            if c in dataset_name.lower():
                crop = c if c not in ['plant', 'village'] else 'plantvillage'
                break
        crop_classes = all_classes.setdefault(crop, {})
        for class_name, img_paths in classes.items():
            crop_classes.setdefault(class_name, []).extend(img_paths)
        print(f"   Found: {len(classes)} classes, {sum(len(v) for v in classes.values()):,} images")

total_images = sum(len(imgs) for crop_map in all_classes.values() for imgs in crop_map.values())
print(f"\nüìà TOTAL: {sum(len(c) for c in all_classes.values())} classes, {total_images:,} images")

In [None]:
# ============================================
# CELL 4: Create Dataset Class
# ============================================
class PlantDiseaseDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs from image files.

    Args:
        image_paths: Sequence of image file paths.
        labels: Sequence of integer class indices, parallel to ``image_paths``.
        transform: Optional callable applied to the RGB PIL image (e.g. a
            torchvision ``Compose``). When ``None`` the PIL image is returned.

    Raises:
        ValueError: If ``image_paths`` and ``labels`` differ in length.
    """

    def __init__(self, image_paths, labels, transform=None):
        if len(image_paths) != len(labels):
            raise ValueError(
                f"image_paths and labels must have the same length "
                f"({len(image_paths)} != {len(labels)})"
            )
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # The context manager closes the underlying file deterministically
        # instead of leaving it to Pillow / garbage collection. convert()
        # returns a fully loaded copy that remains valid after the close.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image, self.labels[idx]

# Build unified dataset: flatten {crop: {class_name: [paths]}} into parallel
# path/label lists. Every class with at least 100 images receives the next
# integer label, and its display name "<crop>___<class_name>" is stored at
# that same index in class_names. Smaller classes are skipped entirely.
all_image_paths = []
all_labels = []
class_names = []

for crop, classes in all_classes.items():
    for class_name, paths in classes.items():
        if len(paths) >= 100:
            label = len(class_names)  # next free class index
            class_names.append(f"{crop}___{class_name}")
            all_image_paths += paths
            all_labels += [label] * len(paths)

num_classes = len(class_names)
print(f"‚úÖ Created dataset: {num_classes} classes, {len(all_image_paths):,} images")

In [None]:
# ============================================
# CELL 5: Data Loaders (Optimized)
# ============================================
from torch.utils.data import Subset

# Augmentation (random flip/rotation) is for the training split only.
# Validation and test images get the deterministic resize + normalize pipeline.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Two views over the same files that differ only in their transform.
# Previously every split, including val/test, went through train_transform,
# so evaluation ran on randomly augmented images.
full_dataset = PlantDiseaseDataset(all_image_paths, all_labels, train_transform)
eval_dataset = PlantDiseaseDataset(all_image_paths, all_labels, test_transform)

train_size = int(0.7 * len(full_dataset))
val_size = int(0.15 * len(full_dataset))
test_size = len(full_dataset) - train_size - val_size

# Split the *indices* once with the seeded generator, then apply the same
# index sets to the appropriate view. Train/val/test membership is the same
# as splitting a single dataset with seed 42.
train_idx, val_idx, test_idx = random_split(range(len(full_dataset)),
                                            [train_size, val_size, test_size],
                                            generator=torch.Generator().manual_seed(42))
train_ds = Subset(full_dataset, train_idx.indices)
val_ds = Subset(eval_dataset, val_idx.indices)
test_ds = Subset(eval_dataset, test_idx.indices)

# Optimized loaders: background workers + pinned host memory for faster H2D copies.
batch_size = 64
num_workers = 4
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                          num_workers=num_workers, pin_memory=True, prefetch_factor=2)
val_loader = DataLoader(val_ds, batch_size=128, shuffle=False,
                        num_workers=num_workers, pin_memory=True)
test_loader = DataLoader(test_ds, batch_size=128, shuffle=False,
                         num_workers=num_workers, pin_memory=True)

print(f"üìä Train: {len(train_ds):,} | Val: {len(val_ds):,} | Test: {len(test_ds):,}")
print(f"‚ö° Batch: {batch_size} | Workers: {num_workers} | Pin Memory: True")

In [None]:
# ============================================
# CELL 6: Create Model
# ============================================
# Pretrained timm EfficientNet-B4 with its classifier sized to num_classes.
# Training uses AdamW and a cosine schedule that decays the LR from 1e-4
# towards 1e-6 over 30 scheduler steps (the training cell steps once per epoch).
# NOTE(review): T_max=30 duplicates `epochs` in the training cell; keep them in sync.
model = timm.create_model('efficientnet_b4', pretrained=True, num_classes=num_classes).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30, eta_min=1e-6)

for summary_line in ("ü§ñ Model: EfficientNet-B4",
                     f"üìä Classes: {num_classes}",
                     "‚öôÔ∏è  Optimizer: AdamW (lr=0.0001)"):
    print(summary_line)

In [None]:
# ============================================
# CELL 7: Training Loop (Mixed Precision)
# ============================================
# Each epoch runs the following steps in order:
#   1. One AMP training pass.
#   2. Validation accuracy.
#   3. A checkpoint whenever validation accuracy improves.
#   4. One cosine-scheduler step.
# Per-epoch mean training loss and validation accuracy are collected in `history`.
import time

epochs = 30  # Reduced for faster training; keep in sync with the scheduler's T_max
best_val_acc = 0.0
history = {'train_loss': [], 'val_acc': []}

print("="*70)
print("üöÄ STARTING TRAINING (Mixed Precision + cuDNN)")
print("="*70)

for epoch in range(epochs):
    epoch_start = time.time()

    # ---- training pass ----
    model.train()
    running_loss = 0
    for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        with autocast():
            logits = model(images)
            loss = criterion(logits, labels)

        # Scaled backward guards fp16 gradients against underflow. step()
        # unscales and skips the update on inf/NaN; update() adapts the scale.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        running_loss += loss.item()

    # ---- validation accuracy ----
    model.eval()
    n_correct = n_seen = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device)
            labels = labels.to(device)
            with autocast():
                logits = model(images)
            predicted = logits.max(1)[1]
            n_correct += (predicted == labels).sum().item()
            n_seen += labels.size(0)

    val_acc = 100. * n_correct / n_seen
    avg_loss = running_loss / len(train_loader)
    history['train_loss'].append(avg_loss)
    history['val_acc'].append(val_acc)

    # Keep only the best-so-far weights (plus label metadata) on disk.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(
            dict(
                model_state_dict=model.state_dict(),
                class_names=class_names,
                num_classes=num_classes,
                val_acc=val_acc,
            ),
            'pakistan_model_best.pth',
        )

    scheduler.step()
    elapsed = time.time() - epoch_start
    remaining_epochs = epochs - epoch - 1
    eta = remaining_epochs * elapsed / 60  # naive: assumes every epoch takes as long as this one

    print(f"  Loss: {avg_loss:.4f} | Val Acc: {val_acc:.2f}% | "
          f"Best: {best_val_acc:.2f}% | Time: {elapsed:.0f}s | ETA: {eta:.1f}min")

print(f"\n‚úÖ Training complete! Best accuracy: {best_val_acc:.2f}%")

In [None]:
# ============================================
# CELL 8: Test Evaluation
# ============================================
# Reload the best checkpoint written by the training loop and measure top-1
# accuracy on the held-out test split. The pass runs in full precision
# (no autocast).
# map_location lets a checkpoint saved on GPU be restored in a CPU-only session.
checkpoint = torch.load('pakistan_model_best.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

correct, total = 0, 0
with torch.no_grad():
    for images, labels in tqdm(test_loader, desc="Testing"):
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        outputs = model(images)
        pred = outputs.argmax(dim=1)
        correct += (pred == labels).sum().item()
        total += labels.size(0)

if total == 0:
    raise RuntimeError("Test split is empty - nothing to evaluate.")

test_acc = 100. * correct / total
print(f"\nüéØ TEST ACCURACY: {test_acc:.2f}%")
print(f"‚úÖ {correct:,} / {total:,} correct")

In [None]:
# ============================================
# CELL 9: Save Final Model
# ============================================
# Persist run metadata next to the checkpoint written during training:
#   - class_names.json: index -> class-name list.
#   - model_info.json: adds the test/validation accuracies and the
#     architecture name.
# Files are written to the current working directory.
model_info = dict(
    class_names=class_names,
    num_classes=num_classes,
    test_accuracy=test_acc,
    best_val_accuracy=best_val_acc,
    model_architecture='efficientnet_b4',
)

for out_name, payload in (('class_names.json', class_names),
                          ('model_info.json', model_info)):
    with open(out_name, 'w') as fh:
        json.dump(payload, fh, indent=2)

print("üíæ Files saved to /kaggle/working/:")
for saved_name in ('pakistan_model_best.pth', 'class_names.json', 'model_info.json'):
    print(f"  ‚úì {saved_name}")
print("\nüì• Download from 'Output' tab on the right ‚Üí")

In [None]:
# ============================================
# CELL 10: Plot Training History
# ============================================
# Left panel: per-epoch mean training loss. Right panel: per-epoch
# validation accuracy, with the final test accuracy as a dashed reference
# line. Epochs are numbered from 1 to match the "Epoch i/N" training log.
epoch_axis = range(1, len(history['train_loss']) + 1)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

ax1.plot(epoch_axis, history['train_loss'])
ax1.set_title('Training Loss')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Mean cross-entropy loss')
ax1.grid(True, alpha=0.3)

ax2.plot(epoch_axis, history['val_acc'], label='Validation')
ax2.axhline(y=test_acc, color='r', linestyle='--', label=f'Test: {test_acc:.1f}%')
ax2.set_title('Validation Accuracy')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy (%)')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('training_history.png', dpi=150)
plt.show()
plt.close(fig)  # release the figure so re-running the cell doesn't accumulate figures

print("\nüéâ DONE! Download your model from the Output tab.")