In [None]:
# ==============================
# CCROP Cannabis Leaf Disease AI - Enhanced Professional Version
# ==============================

# 1. Install dependencies
!apt-get install -y unzip
!pip install torch torchvision torchaudio matplotlib pandas scikit-learn opencv-python kaggle seaborn pillow --quiet

# ------------------------------
# 2. Import libraries
# ------------------------------
import os
import zipfile
import json
import time
from datetime import datetime
from pathlib import Path

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoader, random_split, WeightedRandomSampler
import numpy as np
import cv2
import torch.nn.functional as F
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score, precision_recall_fscore_support
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from PIL import Image

# ------------------------------
# 3. Configuration Class
# ------------------------------
class CCROPConfig:
    """Central place for every tunable setting of the CCROP pipeline.

    Every setting is a class attribute, so it reads the same from the class
    or from an instance; ``setup_directories`` creates the working folders.
    """

    # -- Kaggle API credentials ---------------------------------------------
    # Leave both as None to rely on an uploaded kaggle.json instead.
    KAGGLE_USERNAME = None
    KAGGLE_KEY = None

    # -- Dataset ------------------------------------------------------------
    DATASET_MAIN = "engineeringubu/leaf-manifestation-diseases-of-cannabis"
    DATASET_FALLBACK = "vipoooool/new-plant-diseases-dataset"  # tried if MAIN fails
    ROOT_DIR = "./dataset"

    # -- Model --------------------------------------------------------------
    MODEL_ARCH = "resnet18"  # resnet18 | resnet34 | resnet50 | efficientnet_b0
    INPUT_SIZE = 224         # side length of the square network input, in pixels
    NUM_WORKERS = 2          # DataLoader worker processes

    # -- Optimisation -------------------------------------------------------
    BATCH_SIZE = 16
    LEARNING_RATE = 1e-4
    WEIGHT_DECAY = 1e-4
    EPOCHS = 15
    EARLY_STOPPING_PATIENCE = 5  # epochs without a new best validation loss

    # -- Train / validation / test fractions --------------------------------
    TRAIN_SPLIT = 0.8
    VAL_SPLIT = 0.1
    TEST_SPLIT = 0.1

    # -- Output folders -----------------------------------------------------
    CHECKPOINT_DIR = "./checkpoints"
    LOGS_DIR = "./logs"
    RESULTS_DIR = "./results"

    @classmethod
    def setup_directories(cls):
        """Create the dataset, checkpoint, log and result folders if missing."""
        wanted = (cls.ROOT_DIR, cls.CHECKPOINT_DIR, cls.LOGS_DIR, cls.RESULTS_DIR)
        for folder in wanted:
            os.makedirs(folder, exist_ok=True)

# One shared configuration object; the working folders are created right away
# so later cells can write checkpoints, logs and results without extra checks.
config = CCROPConfig()
CCROPConfig.setup_directories()

# ------------------------------
# 4. Kaggle Authentication with kaggle.json Support
# ------------------------------
def setup_kaggle_credentials(config):
    """Make Kaggle API credentials available to the ``kaggle`` CLI.

    Sources are tried in this order:
      1. ``./kaggle.json`` - copied to ``~/.kaggle/kaggle.json`` with mode 0600;
      2. an already-installed ``~/.kaggle/kaggle.json``;
      3. ``config.KAGGLE_USERNAME`` / ``config.KAGGLE_KEY``, exported as the
         ``KAGGLE_USERNAME`` / ``KAGGLE_KEY`` environment variables.

    Args:
        config: object exposing ``KAGGLE_USERNAME`` and ``KAGGLE_KEY``.

    Returns:
        True when one of the sources was used; otherwise prints setup
        instructions and returns False.
    """
    import shutil

    local_token = Path("kaggle.json")
    home_dir = Path.home() / ".kaggle"
    home_token = home_dir / "kaggle.json"

    if local_token.exists():
        print("✓ Found kaggle.json in current directory")
        home_dir.mkdir(parents=True, exist_ok=True)
        shutil.copy(local_token, home_token)
        # The Kaggle client expects the token file to be private to the user.
        home_token.chmod(0o600)
        print(f"✓ Kaggle credentials installed to {home_token}")
        return True

    if home_token.exists():
        print(f"✓ Using existing kaggle.json from {home_token}")
        return True

    if config.KAGGLE_USERNAME and config.KAGGLE_KEY:
        print("ℹ️  Using credentials from CCROPConfig")
        os.environ['KAGGLE_USERNAME'] = config.KAGGLE_USERNAME
        os.environ['KAGGLE_KEY'] = config.KAGGLE_KEY
        return True

    # Nothing usable was found: explain both ways of providing credentials.
    rule = "=" * 70
    instructions = (
        "\n" + rule,
        "⚠️  KAGGLE AUTHENTICATION REQUIRED",
        rule,
        "\nPlease provide Kaggle credentials using ONE of these methods:\n",
        "METHOD 1 (Recommended): Upload kaggle.json file",
        "  1. Download kaggle.json from https://www.kaggle.com/settings",
        "  2. Upload it to this Colab notebook using the file browser",
        "  3. Re-run this cell\n",
        "METHOD 2: Set credentials in CCROPConfig class",
        "  1. Edit the CCROPConfig class above",
        "  2. Set KAGGLE_USERNAME and KAGGLE_KEY",
        "  3. Re-run this cell\n",
        rule,
    )
    for line in instructions:
        print(line)
    return False

# Stop immediately when no credential source is available at all.
if not setup_kaggle_credentials(config):
    raise RuntimeError("Kaggle authentication failed. Please provide credentials.")

# Prove the credentials work with a cheap API call whose output is discarded.
print("\nVerifying Kaggle API access...")
auth_status = os.system("kaggle datasets list -s cannabis > /dev/null 2>&1")
if auth_status == 0:
    print("✓ Kaggle API authenticated successfully")
else:
    for hint in (
        "❌ Kaggle authentication failed",
        "\nTroubleshooting:",
        "  • Ensure kaggle.json is uploaded to the notebook directory",
        "  • Check that your Kaggle API token is valid",
        "  • Verify credentials at https://www.kaggle.com/settings",
    ):
        print(hint)
    raise RuntimeError("Cannot proceed without valid Kaggle credentials")

# ------------------------------
# 5. Enhanced Dataset Handler
# ------------------------------
class DatasetManager:
    """Download, extract and locate the image-classification dataset.

    ``config`` must provide ``DATASET_MAIN`` and ``DATASET_FALLBACK`` (Kaggle
    dataset slugs) and ``ROOT_DIR`` (download / extraction folder).
    """

    # Lower-case suffixes counted as images when judging whether a folder
    # holds one sub-folder per class.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm',
                        '.tif', '.tiff', '.webp')

    # Folder names never treated as classes nor descended into: the macOS
    # resource-fork folder that many ZIP archives carry.
    _IGNORED_DIRS = ('__MACOSX',)

    def __init__(self, config):
        self.config = config

    def _kaggle_download(self, dataset):
        """Download *dataset* into ROOT_DIR with the kaggle CLI.

        Runs without a shell (argument list) and, on failure, prints the
        CLI's own error message instead of discarding it.

        Returns:
            True if the CLI exited with status 0, else False.
        """
        import subprocess

        cmd = ["kaggle", "datasets", "download", "-d", dataset,
               "-p", str(self.config.ROOT_DIR)]
        try:
            proc = subprocess.run(cmd, capture_output=True, text=True)
        except OSError as exc:  # e.g. the kaggle CLI is not on PATH
            print(f"   Could not run kaggle CLI: {exc}")
            return False
        if proc.returncode != 0:
            detail = (proc.stderr or proc.stdout or "").strip()
            if detail:
                print(f"   kaggle: {detail}")
            return False
        return True

    def download_dataset(self):
        """Download DATASET_MAIN, falling back to DATASET_FALLBACK.

        Raises:
            RuntimeError: if neither download succeeds.
        """
        print(f"📥 Downloading dataset: {self.config.DATASET_MAIN}")
        ok = self._kaggle_download(self.config.DATASET_MAIN)

        if not ok:
            print(f"⚠️  Primary dataset failed. Trying fallback: {self.config.DATASET_FALLBACK}")
            ok = self._kaggle_download(self.config.DATASET_FALLBACK)

        if not ok:
            raise RuntimeError("Failed to download dataset. Check Kaggle credentials and dataset availability.")

        print("✓ Dataset downloaded successfully")

    def extract_dataset(self):
        """Extract the first ZIP archive (in sorted order) found in ROOT_DIR.

        Raises:
            FileNotFoundError: if ROOT_DIR contains no ``*.zip`` file.
        """
        # Sorted so the choice does not depend on filesystem listing order.
        zip_files = sorted(Path(self.config.ROOT_DIR).glob("*.zip"))

        if not zip_files:
            raise FileNotFoundError("No ZIP file found in dataset directory")

        zip_file = zip_files[0]
        print(f"📦 Extracting: {zip_file.name}")

        with zipfile.ZipFile(zip_file, 'r') as zip_ref:
            zip_ref.extractall(self.config.ROOT_DIR)

        print("✓ Dataset extracted successfully")

    def find_dataset_path(self):
        """Locate the folder whose sub-folders are the image classes.

        Walks ROOT_DIR top-down in sorted order, skipping hidden and
        ``__MACOSX`` folders.

        First pass: return the first folder that has at least two class
        sub-folders, each directly containing image files. This way a
        container holding e.g. ``train/`` and ``valid/`` is descended into
        rather than mistaken for a two-class dataset.

        Fallback pass, used only if the first finds nothing: return the first
        folder with more than one sub-folder (the original heuristic).

        In both passes, a ``color`` sub-folder that itself qualifies is
        preferred.

        Raises:
            FileNotFoundError: if no suitable folder exists.
        """
        def class_subdirs(path):
            try:
                names = os.listdir(path)
            except OSError:
                return []
            return [os.path.join(path, d) for d in names
                    if not d.startswith('.') and d not in self._IGNORED_DIRS
                    and os.path.isdir(os.path.join(path, d))]

        def has_multiple_class_folders(path):
            return len(class_subdirs(path)) > 1

        def contains_images(path):
            try:
                names = os.listdir(path)
            except OSError:
                return False
            # Hidden names are ignored so macOS "._x.jpg" stubs do not count.
            return any(not n.startswith('.')
                       and os.path.splitext(n)[1].lower() in self.IMAGE_EXTENSIONS
                       for n in names)

        def has_image_class_folders(path):
            subdirs = class_subdirs(path)
            return len(subdirs) > 1 and all(contains_images(d) for d in subdirs)

        def search(is_dataset_root):
            for root, dirs, _ in os.walk(self.config.ROOT_DIR):
                # Sorted for determinism; prune hidden / macOS metadata folders.
                dirs[:] = sorted(d for d in dirs
                                 if not d.startswith('.') and d not in self._IGNORED_DIRS)
                if is_dataset_root(root):
                    # Check for 'color' subfolder (Plant Diseases dataset)
                    color_dir = os.path.join(root, "color")
                    if os.path.isdir(color_dir) and is_dataset_root(color_dir):
                        return color_dir
                    return root
            return None

        found = search(has_image_class_folders) or search(has_multiple_class_folders)
        if found is None:
            raise FileNotFoundError("Unable to locate valid dataset folder with multiple class directories")
        return found

    def get_class_info(self, dataset_path):
        """Return the sorted class-folder names and a class -> stress mapping.

        Stress values are spread evenly over 0-100 in alphabetical class
        order; a single class gets 50.0.

        NOTE(review): a class's alphabetical position says nothing about how
        stressed the plant is; an explicit, domain-defined mapping is advisable.

        Raises:
            ValueError: if ``dataset_path`` has no non-hidden sub-folders.
        """
        classes = sorted([d for d in os.listdir(dataset_path)
                         if os.path.isdir(os.path.join(dataset_path, d)) and not d.startswith('.')])

        if len(classes) == 0:
            raise ValueError("No class folders found in dataset")

        # Create stress mapping (0-100 scale)
        if len(classes) == 1:
            stress_mapping = {classes[0]: 50.0}
        else:
            stress_mapping = {cls: idx * 100.0 / (len(classes) - 1)
                              for idx, cls in enumerate(classes)}

        return classes, stress_mapping

# Initialize dataset manager
dm = DatasetManager(config)

# Download and/or extract only what is missing. An archive that is present but
# was never extracted (e.g. an earlier run stopped right after downloading) is
# extracted now instead of being treated as an existing dataset.
dataset_root = Path(config.ROOT_DIR)
has_archive = any(dataset_root.glob("*.zip"))
has_extracted = any(dataset_root.glob("*/*"))  # something exists inside a sub-folder

if not has_archive and not has_extracted:
    dm.download_dataset()
    dm.extract_dataset()
elif has_archive and not has_extracted:
    print("ℹ️  Dataset archive found but not extracted, extracting")
    dm.extract_dataset()
else:
    print("ℹ️  Dataset already exists, skipping download")

# Find dataset path and get classes
DATASET_PATH = dm.find_dataset_path()
CLASSES, STRESS_MAPPING = dm.get_class_info(DATASET_PATH)

print(f"✓ Dataset path: {DATASET_PATH}")
print(f"✓ Classes detected: {len(CLASSES)}")
print(f"  Sample classes: {CLASSES[:5]}")

# ------------------------------
# 6. Enhanced Data Augmentation
# ------------------------------
# Square resize target shared by both pipelines.
_resize_to = (config.INPUT_SIZE, config.INPUT_SIZE)
# ImageNet channel statistics, matching the pretrained backbones.
_imagenet_mean = [0.485, 0.456, 0.406]
_imagenet_std = [0.229, 0.224, 0.225]

# Training: random flips, rotation, colour jitter, random-resized crop and
# small translations, followed by tensor conversion and normalisation.
train_transform = transforms.Compose([
    transforms.Resize(_resize_to),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.2),
    transforms.RandomRotation(30),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
    transforms.RandomResizedCrop(config.INPUT_SIZE, scale=(0.7, 1.0)),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    transforms.ToTensor(),
    transforms.Normalize(_imagenet_mean, _imagenet_std),
])

# Validation / test: deterministic resize and normalisation only.
val_test_transform = transforms.Compose([
    transforms.Resize(_resize_to),
    transforms.ToTensor(),
    transforms.Normalize(_imagenet_mean, _imagenet_std),
])

# ------------------------------
# 7. Dataset Loading with Class Balancing
# ------------------------------
# Two ImageFolder views over the same files: one applies training augmentation,
# the other deterministic eval preprocessing. random_split returns Subsets that
# all share ONE underlying dataset, so assigning `val_dataset.dataset.transform`
# (as done previously) also switched the training subset to the eval transform
# and silently disabled augmentation.
full_dataset = datasets.ImageFolder(root=DATASET_PATH, transform=train_transform)
eval_dataset = datasets.ImageFolder(root=DATASET_PATH, transform=val_test_transform)

# Model outputs, CLASSES and STRESS_MAPPING are all indexed by this order, so a
# mismatch (e.g. a stray hidden folder picked up by ImageFolder) must fail early.
if full_dataset.classes != CLASSES:
    raise ValueError(
        f"ImageFolder classes {full_dataset.classes} differ from detected classes "
        f"{CLASSES}; remove stray folders from {DATASET_PATH}"
    )

# Class-balancing weights from the labels ImageFolder has already indexed,
# instead of decoding and augmenting every image just to read its label.
targets = torch.as_tensor(full_dataset.targets)
class_counts = np.bincount(full_dataset.targets, minlength=len(full_dataset.classes))
class_weights = 1.0 / torch.Tensor(class_counts)
sample_weights = class_weights[targets]  # per-sample weight = 1 / size of its class

# Split dataset (the test split receives whatever train/val leave over)
train_size = int(config.TRAIN_SPLIT * len(full_dataset))
val_size = int(config.VAL_SPLIT * len(full_dataset))
test_size = len(full_dataset) - train_size - val_size

train_dataset, val_split, test_split = random_split(
    full_dataset, [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(42)
)

# Same val/test indices, but served through the eval-transform view.
val_dataset = torch.utils.data.Subset(eval_dataset, val_split.indices)
test_dataset = torch.utils.data.Subset(eval_dataset, test_split.indices)

# Create weighted sampler for balanced training (training indices only)
train_indices = train_dataset.indices
train_sample_weights = sample_weights[train_indices]
sampler = WeightedRandomSampler(train_sample_weights, len(train_sample_weights))

# Data loaders
train_loader = DataLoader(train_dataset, batch_size=config.BATCH_SIZE,
                          sampler=sampler, num_workers=config.NUM_WORKERS)
val_loader = DataLoader(val_dataset, batch_size=config.BATCH_SIZE,
                        shuffle=False, num_workers=config.NUM_WORKERS)
test_loader = DataLoader(test_dataset, batch_size=config.BATCH_SIZE,
                         shuffle=False, num_workers=config.NUM_WORKERS)

print(f"✓ Dataset split: Train={train_size}, Val={val_size}, Test={test_size}")

# ------------------------------
# 8. Enhanced Model Architecture
# ------------------------------
class CCROPModel:
    """Factory for the supported torchvision classification backbones."""

    @staticmethod
    def create_model(arch, num_classes, pretrained=True):
        """Build ``arch`` with a freshly initialised ``num_classes``-way head.

        Args:
            arch: "resnet18", "resnet34", "resnet50" or "efficientnet_b0".
            num_classes: number of output logits of the new head.
            pretrained: load ImageNet weights (``IMAGENET1K_V1``, the same
                weights the deprecated ``pretrained=True`` flag selected).

        Returns:
            The torch ``nn.Module``.

        Raises:
            ValueError: if ``arch`` is not supported.
        """
        builders = {
            "resnet18": models.resnet18,
            "resnet34": models.resnet34,
            "resnet50": models.resnet50,
            "efficientnet_b0": models.efficientnet_b0,
        }
        if arch not in builders:
            raise ValueError(f"Unsupported architecture: {arch}")

        # `weights=` replaces the deprecated `pretrained=` keyword.
        model = builders[arch](weights="IMAGENET1K_V1" if pretrained else None)

        # Swap the final Linear layer for one with `num_classes` outputs.
        if arch == "efficientnet_b0":
            model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
        else:
            model.fc = nn.Linear(model.fc.in_features, num_classes)

        return model

# Build the configured backbone with a len(CLASSES)-way classification head.
model = CCROPModel.create_model(config.MODEL_ARCH, len(CLASSES))

# Full fine-tuning: gradients on for every parameter.
model.requires_grad_(True)

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model = model.to(device)

param_total = sum(p.numel() for p in model.parameters())
param_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"✓ Model: {config.MODEL_ARCH} on {device}")
print(f"  Total parameters: {param_total:,}")
print(f"  Trainable parameters: {param_trainable:,}")

# ------------------------------
# 9. Training Components
# ------------------------------
# Unweighted cross-entropy: class imbalance is handled by the weighted sampler.
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(
    model.parameters(),
    lr=config.LEARNING_RATE,
    weight_decay=config.WEIGHT_DECAY,
)
# Halve the LR once the monitored loss has not improved for more than 3 epochs.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=3,
)

# ------------------------------
# 10. Enhanced Training Loop with Metrics Tracking
# ------------------------------
class MetricsTracker:
    """Record per-epoch metrics, track the best validation loss for early
    stopping, and plot or persist the collected history."""

    def __init__(self):
        # One list per metric, all aligned by epoch position.
        self.history = {name: [] for name in (
            'train_loss', 'val_loss',
            'train_acc', 'val_acc',
            'learning_rates', 'epochs',
        )}
        self.best_val_loss = float('inf')
        self.best_epoch = 0
        self.patience_counter = 0  # epochs since the last improvement

    def update(self, epoch, train_loss, val_loss, train_acc, val_acc, lr):
        """Append one epoch of metrics.

        Returns:
            True when ``val_loss`` is a new best (strictly lower), else False.
        """
        new_values = {
            'epochs': epoch,
            'train_loss': train_loss,
            'val_loss': val_loss,
            'train_acc': train_acc,
            'val_acc': val_acc,
            'learning_rates': lr,
        }
        for name, value in new_values.items():
            self.history[name].append(value)

        if not val_loss < self.best_val_loss:
            self.patience_counter += 1
            return False

        self.best_val_loss = val_loss
        self.best_epoch = epoch
        self.patience_counter = 0
        return True

    def should_stop(self, patience):
        """True once ``patience`` consecutive epochs brought no improvement."""
        return self.patience_counter >= patience

    def plot_history(self, save_path=None):
        """Draw loss, accuracy and learning-rate curves side by side.

        The figure is saved at 300 dpi to *save_path* (if given) and then shown.
        """
        hist = self.history
        x = hist['epochs']
        fig, (loss_ax, acc_ax, lr_ax) = plt.subplots(1, 3, figsize=(18, 5))

        paired_panels = (
            (loss_ax, 'train_loss', 'Train Loss', 'val_loss', 'Val Loss',
             'Loss', 'Training and Validation Loss'),
            (acc_ax, 'train_acc', 'Train Acc', 'val_acc', 'Val Acc',
             'Accuracy (%)', 'Training and Validation Accuracy'),
        )
        for ax, train_key, train_label, val_key, val_label, ylabel, title in paired_panels:
            ax.plot(x, hist[train_key], label=train_label)
            ax.plot(x, hist[val_key], label=val_label)
            ax.set_xlabel('Epoch')
            ax.set_ylabel(ylabel)
            ax.set_title(title)
            ax.legend()
            ax.grid(True)

        lr_ax.plot(x, hist['learning_rates'])
        lr_ax.set_xlabel('Epoch')
        lr_ax.set_ylabel('Learning Rate')
        lr_ax.set_title('Learning Rate Schedule')
        lr_ax.set_yscale('log')
        lr_ax.grid(True)

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"✓ Training history saved to {save_path}")

        plt.show()

    def save_history(self, path):
        """Write the history dict to *path* as indented JSON."""
        with open(path, 'w') as fh:
            json.dump(self.history, fh, indent=2)

def train_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``.

    Args:
        model: the network; it is switched to training mode.
        loader: iterable of (images, labels) batches.
        criterion: loss function; assumed to use mean reduction (the
            CrossEntropyLoss default), so each batch loss is re-weighted by its
            batch size.
        optimizer: optimizer stepping ``model``'s parameters.
        device: device inputs and labels are moved to.

    Returns:
        (mean_loss, accuracy_percent). The loss is averaged over samples, not
        batches, so a short final batch is not over-weighted.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    total = 0

    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        loss_sum += loss.item() * batch_size
        _, predicted = torch.max(outputs, 1)
        total += batch_size
        correct += (predicted == labels).sum().item()

    if total == 0:
        raise ValueError("train_epoch received an empty loader")

    return loss_sum / total, 100 * correct / total

def validate_epoch(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without updating it.

    Args:
        model: the network; it is switched to eval mode.
        loader: iterable of (images, labels) batches.
        criterion: loss function; assumed to use mean reduction (the
            CrossEntropyLoss default), so each batch loss is re-weighted by its
            batch size.
        device: device inputs and labels are moved to.

    Returns:
        (mean_loss, accuracy_percent), with the loss averaged over samples.

    Raises:
        ValueError: if ``loader`` yields no samples (e.g. an empty split).
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            batch_size = labels.size(0)
            loss_sum += loss.item() * batch_size
            _, predicted = torch.max(outputs, 1)
            total += batch_size
            correct += (predicted == labels).sum().item()

    if total == 0:
        raise ValueError("validate_epoch received an empty loader")

    return loss_sum / total, 100 * correct / total

# Training loop
metrics = MetricsTracker()
start_time = time.time()

# Settings stored with each checkpoint. Every setting is a *class* attribute of
# CCROPConfig, so `vars(config)` on the instance was always an empty dict.
# Kaggle credentials are deliberately kept out of checkpoint files.
config_snapshot = {
    name: value
    for name, value in vars(type(config)).items()
    if name.isupper() and not name.startswith("KAGGLE_")
}
best_checkpoint_path = os.path.join(config.CHECKPOINT_DIR, "best_model.pth")
saved_best = False  # becomes True once THIS run has written best_checkpoint_path

print("\n" + "="*60)
print("🚀 Starting Training")
print("="*60)

try:
    for epoch in range(config.EPOCHS):
        epoch_start = time.time()

        # Train
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)

        # Validate
        val_loss, val_acc = validate_epoch(model, val_loader, criterion, device)

        # ReduceLROnPlateau watches the validation loss
        scheduler.step(val_loss)
        current_lr = optimizer.param_groups[0]['lr']

        # Track metrics (True when val_loss is a new best)
        is_best = metrics.update(epoch + 1, train_loss, val_loss, train_acc, val_acc, current_lr)

        # Print progress
        epoch_time = time.time() - epoch_start
        print(f"Epoch [{epoch+1}/{config.EPOCHS}] ({epoch_time:.1f}s)")
        print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
        print(f"  Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")
        print(f"  LR: {current_lr:.6f} {'✓ BEST' if is_best else ''}")

        # Checkpoint only improvements
        if is_best:
            checkpoint_path = best_checkpoint_path
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': val_loss,
                'val_acc': val_acc,
                'classes': CLASSES,
                'stress_mapping': STRESS_MAPPING,
                'config': config_snapshot
            }, best_checkpoint_path)
            saved_best = True

        # Early stopping
        if metrics.should_stop(config.EARLY_STOPPING_PATIENCE):
            print(f"\n⚠️  Early stopping triggered at epoch {epoch+1}")
            break

    training_time = time.time() - start_time
    print(f"\n✓ Training completed in {training_time/60:.1f} minutes")
    print(f"  Best validation loss: {metrics.best_val_loss:.4f} at epoch {metrics.best_epoch}")

except KeyboardInterrupt:
    print("\n⚠️  Training interrupted by user")
    checkpoint_path = os.path.join(config.CHECKPOINT_DIR, "interrupted_model.pth")
    torch.save(model.state_dict(), checkpoint_path)
    print(f"✓ Model saved to {checkpoint_path}")

# Plot and save training history
history_plot_path = os.path.join(config.RESULTS_DIR, "training_history.png")
history_json_path = os.path.join(config.RESULTS_DIR, "training_history.json")
metrics.plot_history(save_path=history_plot_path)
metrics.save_history(history_json_path)

# Continue with the best epoch's weights rather than the last epoch's: with
# early stopping the final weights can be several epochs past the best
# validation loss. Only a checkpoint written by this run is restored.
if saved_best:
    best_checkpoint = torch.load(best_checkpoint_path, map_location=device)
    model.load_state_dict(best_checkpoint['model_state_dict'])
    print(f"✓ Restored best model weights (epoch {best_checkpoint['epoch']}) for evaluation")

# ------------------------------
# 11. Comprehensive Evaluation
# ------------------------------
def evaluate_model(model, loader, device, class_names, save_dir=None):
    """Evaluate ``model`` on ``loader``, print metrics and plot a confusion matrix.

    Args:
        model: classifier whose output index i corresponds to class_names[i].
        loader: iterable of (images, labels) batches.
        device: device the images are moved to.
        class_names: display names, one per model output.
        save_dir: folder for ``confusion_matrix.png``; defaults to
            ``config.RESULTS_DIR``.

    Returns:
        dict with 'accuracy', weighted 'precision' / 'recall' / 'f1', and the
        per-sample 'predictions', 'labels' and softmax 'probabilities'.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    if save_dir is None:
        save_dir = config.RESULTS_DIR

    model.eval()
    all_preds = []
    all_labels = []
    all_probs = []

    with torch.no_grad():
        for images, labels in loader:
            outputs = model(images.to(device))
            probs = F.softmax(outputs, dim=1)
            _, preds = torch.max(outputs, 1)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())

    if not all_labels:
        raise ValueError("evaluate_model received an empty loader")

    # Pin the label set to every class so the report and confusion matrix stay
    # aligned with class_names even when a class is absent from this split
    # (otherwise classification_report raises on a target_names size mismatch).
    label_ids = list(range(len(class_names)))

    # Calculate metrics
    accuracy = accuracy_score(all_labels, all_preds)
    precision, recall, f1, _ = precision_recall_fscore_support(
        all_labels, all_preds, labels=label_ids, average='weighted', zero_division=0
    )

    print("\n" + "="*60)
    print("📊 Evaluation Results")
    print("="*60)
    print(f"Overall Accuracy: {accuracy*100:.2f}%")
    print(f"Weighted Precision: {precision:.4f}")
    print(f"Weighted Recall: {recall:.4f}")
    print(f"Weighted F1-Score: {f1:.4f}")
    print("\n" + "-"*60)
    print("Classification Report:")
    print(classification_report(all_labels, all_preds, labels=label_ids,
                                target_names=class_names, zero_division=0))

    # Confusion matrix (always len(class_names) x len(class_names))
    cm = confusion_matrix(all_labels, all_preds, labels=label_ids)

    plt.figure(figsize=(12, 10))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.title('Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()

    cm_path = os.path.join(save_dir, "confusion_matrix.png")
    plt.savefig(cm_path, dpi=300, bbox_inches='tight')
    print(f"✓ Confusion matrix saved to {cm_path}")
    plt.show()

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'predictions': all_preds,
        'labels': all_labels,
        'probabilities': all_probs
    }

# Score the trained model on the held-out test split.
test_results = evaluate_model(model=model, loader=test_loader, device=device, class_names=CLASSES)

# ------------------------------
# 12. Stress Prediction System
# ------------------------------
class StressPredictor:
    """Turn a leaf image into a stress score plus class probabilities.

    The stress score is the probability-weighted sum of the per-class values
    in ``stress_mapping``.
    """

    def __init__(self, model, transform, stress_mapping, class_names, device):
        """
        Args:
            model: classifier; output index i must correspond to class_names[i].
                It is put in eval mode.
            transform: callable applied to a PIL RGB image; presumably returns
                a CHW tensor (a batch dimension is added before inference).
            stress_mapping: dict mapping class name -> stress value.
            class_names: class names in model-output order. This argument used
                to be ignored in favour of the sorted mapping keys; pass None to
                keep that behaviour.
            device: device the input tensor is moved to.

        Raises:
            ValueError: if a class name has no entry in ``stress_mapping``.
        """
        self.model = model
        self.transform = transform
        self.stress_mapping = stress_mapping
        if class_names is None:
            self.class_names = sorted(stress_mapping.keys())
        else:
            self.class_names = list(class_names)

        missing = [c for c in self.class_names if c not in stress_mapping]
        if missing:
            raise ValueError(f"No stress value for classes: {missing}")

        # Stress values aligned with model outputs, computed once.
        self._class_scores = np.array([stress_mapping[c] for c in self.class_names])
        self.device = device
        self.model.eval()

    def predict_from_path(self, img_path):
        """Predict stress for the image file at *img_path*.

        Raises:
            ValueError: if OpenCV cannot read the file.
        """
        image = cv2.imread(img_path)
        if image is None:
            raise ValueError(f"Cannot read image: {img_path}")
        return self.predict_from_array(image)

    def predict_from_array(self, image):
        """Predict stress from an OpenCV-style numpy image.

        Accepts BGR (H, W, 3), BGRA (H, W, 4) or grayscale (H, W) arrays.

        Returns:
            dict with 'stress_score', 'confidence' (top probability in %),
            'top_predictions' (up to 3 ``(class, percent)`` pairs) and
            'all_probabilities' (class -> probability).

        Raises:
            ValueError: if ``image`` is None or empty.
        """
        if image is None or getattr(image, "size", 0) == 0:
            raise ValueError("Empty image")

        # Convert to RGB whatever the channel layout
        if image.ndim == 2:
            image_rgb = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        elif image.ndim == 3 and image.shape[2] == 4:
            image_rgb = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
        else:
            image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        pil_image = Image.fromarray(image_rgb)

        # Transform and predict
        img_tensor = self.transform(pil_image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            outputs = self.model(img_tensor)
            probs = F.softmax(outputs, dim=1).cpu().numpy()[0]

        # Probability-weighted stress score
        stress_score = np.sum(probs * self._class_scores)

        # Get top predictions
        top_k = 3
        top_indices = np.argsort(probs)[::-1][:top_k]
        top_classes = [(self.class_names[i], probs[i]*100) for i in top_indices]

        return {
            'stress_score': stress_score,
            'confidence': probs.max() * 100,
            'top_predictions': top_classes,
            'all_probabilities': dict(zip(self.class_names, probs))
        }

    def predict_batch(self, image_paths):
        """Predict each path in turn. Failures are printed and skipped.

        Each successful result dict gains an 'image_path' key.
        """
        results = []
        for path in image_paths:
            try:
                result = self.predict_from_path(path)
                result['image_path'] = path
                results.append(result)
            except Exception as e:
                print(f"Error processing {path}: {e}")
        return results

# Inference wrapper using the deterministic eval preprocessing.
predictor = StressPredictor(
    model=model,
    transform=val_test_transform,
    stress_mapping=STRESS_MAPPING,
    class_names=CLASSES,
    device=device,
)

print("\n✓ Stress predictor initialized and ready")

# ------------------------------
# 13. Real-time Webcam Stress Detection
# ------------------------------
def webcam_stress_monitor(predictor, display_size=(640, 480), camera_index=0, snapshot_dir=None):
    """Show a live webcam feed annotated with the predicted leaf stress.

    Each frame is scored with ``predictor.predict_from_array``. It is shown
    with an info panel holding status, stress score, confidence, top class,
    FPS and a stress bar. Press 'q' to quit, or 's' to save the raw frame as
    a JPEG.

    Args:
        predictor: object whose ``predict_from_array(bgr_frame)`` returns the
            dict produced by ``StressPredictor``.
        display_size: (width, height) the frame is resized to for display.
        camera_index: OpenCV capture device index (default 0).
        snapshot_dir: folder for snapshots; defaults to ``config.RESULTS_DIR``.

    NOTE(review): ``cv2.imshow`` needs a local display; this is not expected
    to work inside hosted notebooks such as Colab - confirm for your setup.
    """
    if snapshot_dir is None:
        snapshot_dir = config.RESULTS_DIR

    cap = cv2.VideoCapture(camera_index)

    if not cap.isOpened():
        print("❌ Cannot open webcam")
        cap.release()
        return

    print("\n🎥 Starting webcam monitoring")
    print("Press 'q' to quit, 's' to save snapshot")

    snapshot_count = 0
    fps_history = []  # rolling window of the last 30 per-frame FPS values

    # try/finally: release the camera and close windows even when the loop is
    # interrupted (e.g. KeyboardInterrupt) or an unexpected error escapes.
    try:
        while True:
            frame_start = time.time()
            ret, frame = cap.read()

            if not ret:
                print("❌ Failed to grab frame")
                break

            # Resize for display
            display_frame = cv2.resize(frame, display_size)

            # Predict stress; a per-frame failure is drawn on the frame instead
            try:
                result = predictor.predict_from_array(frame)
                stress_score = result['stress_score']
                confidence = result['confidence']
                top_class, top_prob = result['top_predictions'][0]

                # Determine color based on stress level
                if stress_score < 33:
                    color = (0, 255, 0)  # Green - healthy
                    status = "HEALTHY"
                elif stress_score < 66:
                    color = (0, 165, 255)  # Orange - moderate
                    status = "MODERATE"
                else:
                    color = (0, 0, 255)  # Red - severe
                    status = "SEVERE"

                # Draw info panel
                panel_height = 120
                panel = np.zeros((panel_height, display_size[0], 3), dtype=np.uint8)

                cv2.putText(panel, f"Status: {status}", (10, 25),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)
                cv2.putText(panel, f"Stress Score: {stress_score:.1f}%", (10, 55),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
                cv2.putText(panel, f"Confidence: {confidence:.1f}%", (10, 80),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
                cv2.putText(panel, f"Top Class: {top_class[:30]}", (10, 105),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.4, (200, 200, 200), 1)

                # Calculate FPS (guard against a zero elapsed time)
                elapsed = max(time.time() - frame_start, 1e-6)
                fps_history.append(1.0 / elapsed)
                if len(fps_history) > 30:
                    fps_history.pop(0)
                avg_fps = np.mean(fps_history)

                cv2.putText(panel, f"FPS: {avg_fps:.1f}", (display_size[0]-120, 25),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

                # Draw stress bar
                bar_width = int((stress_score / 100) * (display_size[0] - 20))
                cv2.rectangle(panel, (10, panel_height-20),
                              (10 + bar_width, panel_height-10), color, -1)
                cv2.rectangle(panel, (10, panel_height-20),
                              (display_size[0]-10, panel_height-10), (100, 100, 100), 2)

                # Combine frame and panel
                combined = np.vstack([display_frame, panel])

            except Exception as e:
                combined = display_frame
                cv2.putText(combined, f"Error: {str(e)[:50]}", (10, 30),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)

            cv2.imshow("CCROP - Leaf Stress Monitor", combined)

            key = cv2.waitKey(1) & 0xFF
            if key == ord('q'):
                break
            if key == ord('s'):
                snapshot_path = os.path.join(snapshot_dir, f"snapshot_{snapshot_count:03d}.jpg")
                # cv2.imwrite returns False instead of raising on failure
                if cv2.imwrite(snapshot_path, frame):
                    print(f"✓ Snapshot saved: {snapshot_path}")
                    snapshot_count += 1
                else:
                    print(f"❌ Could not save snapshot: {snapshot_path}")
    finally:
        cap.release()
        cv2.destroyAllWindows()

    print("\n✓ Webcam monitoring stopped")

# ------------------------------
# 14. Batch Inference on Directory
# ------------------------------
def batch_inference(predictor, image_dir, output_csv=None):
    """Score every image directly inside *image_dir* and summarise the results.

    Files are matched on their suffix (.jpg, .jpeg, .png, .bmp, .tiff), case
    insensitively, from a single directory listing. Each file is therefore
    processed once and in sorted order. Separate upper/lower-case globbing
    listed files twice on case-insensitive filesystems and missed
    mixed-case suffixes.

    Args:
        predictor: object whose ``predict_from_path(str_path)`` returns the dict
            produced by ``StressPredictor``.
        image_dir: directory to scan (not recursive).
        output_csv: optional path of a CSV with the per-image results.

    Returns:
        DataFrame with filename, stress_score, confidence, top_class and
        top_class_prob for each successfully processed image. Returns None
        when no image was found or none could be processed.
    """
    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff'}
    image_dir_path = Path(image_dir)

    if image_dir_path.is_dir():
        image_paths = sorted(
            p for p in image_dir_path.iterdir()
            if p.is_file() and p.suffix.lower() in image_extensions
        )
    else:
        image_paths = []

    if not image_paths:
        print(f"❌ No images found in {image_dir}")
        return None

    print(f"\n📸 Processing {len(image_paths)} images...")

    results = []
    for i, img_path in enumerate(image_paths):
        try:
            result = predictor.predict_from_path(str(img_path))
        except Exception as e:  # best effort: report and continue with the rest
            print(f"  Error processing {img_path.name}: {e}")
            continue

        result['filename'] = img_path.name
        result['filepath'] = str(img_path)
        results.append(result)

        if (i + 1) % 10 == 0:
            print(f"  Processed {i+1}/{len(image_paths)} images")

    # Without this guard the summary below fails (no columns / division by zero).
    if not results:
        print(f"❌ None of the {len(image_paths)} images could be processed")
        return None

    # Create DataFrame
    df_results = pd.DataFrame([{
        'filename': r['filename'],
        'stress_score': r['stress_score'],
        'confidence': r['confidence'],
        'top_class': r['top_predictions'][0][0],
        'top_class_prob': r['top_predictions'][0][1]
    } for r in results])

    # Save to CSV if requested
    if output_csv:
        df_results.to_csv(output_csv, index=False)
        print(f"\n✓ Results saved to {output_csv}")

    n = len(results)
    scores = df_results['stress_score']

    # Print summary statistics
    print("\n" + "="*60)
    print("📊 Batch Inference Summary")
    print("="*60)
    print(f"Total images processed: {n}")
    print(f"Average stress score: {scores.mean():.2f}%")
    print(f"Std deviation: {scores.std():.2f}%")
    print(f"Min stress: {scores.min():.2f}%")
    print(f"Max stress: {scores.max():.2f}%")
    print(f"Average confidence: {df_results['confidence'].mean():.2f}%")

    # Stress distribution (same 33 / 66 thresholds as the other views)
    healthy = int((scores < 33).sum())
    moderate = int(((scores >= 33) & (scores < 66)).sum())
    severe = int((scores >= 66).sum())

    print("\nStress Distribution:")
    print(f"  Healthy (<33%): {healthy} ({healthy/n*100:.1f}%)")
    print(f"  Moderate (33-66%): {moderate} ({moderate/n*100:.1f}%)")
    print(f"  Severe (>66%): {severe} ({severe/n*100:.1f}%)")

    return df_results

# ------------------------------
# 15. Visualization Tools
# ------------------------------
def visualize_predictions(predictor, image_paths, save_path=None):
    """Plot a grid (up to 4 columns) of images titled with their predictions.

    Args:
        predictor: object whose ``predict_from_array(bgr_image)`` returns the
            dict produced by ``StressPredictor``.
        image_paths: iterable of image paths (str or Path); consumed once.
        save_path: optional file the figure is saved to (300 dpi).

    Images that cannot be read or scored show an error message in their cell.
    Nothing is plotted when ``image_paths`` is empty.
    """
    image_paths = list(image_paths)  # also accept generators
    n_images = len(image_paths)
    if n_images == 0:
        # cols would be 0 and the row computation would divide by zero
        print("❌ No images to visualize")
        return

    cols = min(4, n_images)
    rows = (n_images + cols - 1) // cols

    # squeeze=False always yields a 2-D array of Axes, whatever rows/cols are.
    fig, axes = plt.subplots(rows, cols, figsize=(cols*4, rows*4), squeeze=False)
    axes = axes.flatten()

    for ax, img_path in zip(axes, image_paths):
        try:
            # Load and predict
            image = cv2.imread(str(img_path))
            if image is None:
                raise ValueError(f"Cannot read image: {img_path}")
            image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            result = predictor.predict_from_array(image)

            # Display
            ax.imshow(image_rgb)
            ax.axis('off')

            # Create title with results
            stress = result['stress_score']
            conf = result['confidence']
            top_class = result['top_predictions'][0][0]

            if stress < 33:
                color = 'green'
                status = 'Healthy'
            elif stress < 66:
                color = 'orange'
                status = 'Moderate'
            else:
                color = 'red'
                status = 'Severe'

            title = f"{status}\nStress: {stress:.1f}% | Conf: {conf:.1f}%\n{top_class[:20]}"
            ax.set_title(title, fontsize=9, color=color, weight='bold')

        except Exception as e:  # best effort per image: show the error in its cell
            ax.text(0.5, 0.5, f"Error:\n{str(e)[:30]}",
                    ha='center', va='center', transform=ax.transAxes)
            ax.axis('off')

    # Hide extra subplots
    for ax in axes[n_images:]:
        ax.axis('off')

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"✓ Visualization saved to {save_path}")

    plt.show()

def _stress_categories(scores):
    """Label each stress score: < 33 'Healthy', < 66 'Moderate', otherwise 'Severe'."""
    return ['Healthy' if s < 33 else 'Moderate' if s < 66 else 'Severe' for s in scores]


def plot_stress_distribution(df_results, save_path=None):
    """Plot the distribution of ``df_results['stress_score']``.

    Left panel: a 20-bin histogram with the mean marked. Right panel: one box
    plot per severity category (Healthy < 33 <= Moderate < 66 <= Severe).
    ``df_results`` is not modified; the categories are computed locally.

    Args:
        df_results: DataFrame with a numeric ``'stress_score'`` column.
        save_path: optional file path; if given the figure is also saved
            there at 300 dpi.
    """
    scores = df_results['stress_score']
    mean_score = scores.mean()

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Histogram with the mean marked.
    axes[0].hist(scores, bins=20, color='skyblue', edgecolor='black')
    axes[0].axvline(mean_score, color='red',
                    linestyle='--', label=f"Mean: {mean_score:.1f}%")
    axes[0].set_xlabel('Stress Score (%)')
    axes[0].set_ylabel('Frequency')
    axes[0].set_title('Stress Score Distribution')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)

    # Box plot by category (grouped locally instead of adding a column to
    # the caller's DataFrame).
    category_order = ['Healthy', 'Moderate', 'Severe']
    colors = ['green', 'orange', 'red']
    grouped = {cat: [] for cat in category_order}
    for score, cat in zip(scores, _stress_categories(scores)):
        grouped[cat].append(score)
    box_data = [np.asarray(grouped[cat], dtype=float) for cat in category_order]

    # Tick labels are set on the axis rather than via boxplot(labels=...),
    # which Matplotlib 3.9 deprecated (renamed tick_labels); this form works
    # on old and new versions alike. Box positions default to 1..n.
    bp = axes[1].boxplot(box_data, patch_artist=True)
    axes[1].set_xticks(range(1, len(category_order) + 1))
    axes[1].set_xticklabels(category_order)
    for patch, color in zip(bp['boxes'], colors):
        patch.set_facecolor(color)
        patch.set_alpha(0.6)

    axes[1].set_ylabel('Stress Score (%)')
    axes[1].set_title('Stress Scores by Category')
    axes[1].grid(True, alpha=0.3)

    plt.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"✓ Distribution plot saved to {save_path}")

    plt.show()

# ------------------------------
# 16. Model Export and Deployment
# ------------------------------
def _config_as_dict(cfg):
    """Collect the public UPPER_CASE, non-callable settings of ``cfg``.

    Works for both the config class and an instance of it. ``vars()`` on an
    instance returns only attributes set on that instance, so settings
    declared on the class (e.g. ``CCROPConfig.MODEL_ARCH``) would be lost;
    ``dir()`` also walks the class.
    """
    settings = {}
    for name in dir(cfg):
        if name.startswith('_') or not name.isupper():
            continue
        value = getattr(cfg, name)
        if callable(value):
            continue
        settings[name] = value
    return settings


def export_model_for_deployment(model, save_dir, example_input_size=None):
    """Export ``model`` as .pth, TorchScript, ONNX plus a JSON metadata file.

    Files written to ``save_dir`` (created if needed):
      * ``ccrop_model.pth``         - state dict + classes, stress mapping, config
      * ``ccrop_model_scripted.pt`` - traced TorchScript module
      * ``ccrop_model.onnx``        - ONNX graph with a dynamic batch axis
      * ``model_metadata.json``     - model / preprocessing description

    TorchScript and ONNX export are best-effort: a failure prints a warning
    and the remaining artifacts are still written.

    Args:
        model: trained ``nn.Module``; it is switched to eval mode.
        save_dir: output directory.
        example_input_size: shape of the dummy input used for tracing and
            ONNX export. Defaults to ``(1, 3, config.INPUT_SIZE, config.INPUT_SIZE)``.

    Reads module-level ``CLASSES``, ``STRESS_MAPPING``, ``config``, ``device``,
    ``test_results`` and ``metrics``.
    """
    os.makedirs(save_dir, exist_ok=True)

    model.eval()

    if example_input_size is None:
        example_input_size = (1, 3, config.INPUT_SIZE, config.INPUT_SIZE)

    # 1. PyTorch model (.pth)
    torch_path = os.path.join(save_dir, "ccrop_model.pth")
    torch.save({
        'model_state_dict': model.state_dict(),
        'classes': CLASSES,
        'stress_mapping': STRESS_MAPPING,
        # Explicitly gather class-level settings; vars(config) would miss them.
        'config': _config_as_dict(config),
        'architecture': config.MODEL_ARCH
    }, torch_path)
    print(f"✓ PyTorch model saved: {torch_path}")

    example_input = torch.randn(example_input_size).to(device)

    # 2. TorchScript (for production deployment)
    try:
        traced_model = torch.jit.trace(model, example_input)
        torchscript_path = os.path.join(save_dir, "ccrop_model_scripted.pt")
        traced_model.save(torchscript_path)
        print(f"✓ TorchScript model saved: {torchscript_path}")
    except Exception as e:
        print(f"⚠️  TorchScript export failed: {e}")

    # 3. ONNX (for cross-platform deployment)
    try:
        onnx_path = os.path.join(save_dir, "ccrop_model.onnx")
        torch.onnx.export(
            model,
            example_input,
            onnx_path,
            export_params=True,
            opset_version=11,
            do_constant_folding=True,
            input_names=['input'],
            output_names=['output'],
            dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
        )
        print(f"✓ ONNX model saved: {onnx_path}")
    except Exception as e:
        print(f"⚠️  ONNX export failed: {e}")

    # 4. Save metadata
    val_acc_history = list(metrics.history['val_acc'])
    metadata = {
        'model_architecture': config.MODEL_ARCH,
        'num_classes': len(CLASSES),
        'classes': CLASSES,
        'stress_mapping': STRESS_MAPPING,
        'input_size': config.INPUT_SIZE,
        # Timestamp of this export call (key name kept for compatibility).
        'training_date': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
        # Best validation accuracy over all epochs, in percent; the held-out
        # test accuracy is stored separately under 'test_accuracy'.
        'best_val_accuracy': float(max(val_acc_history)) if val_acc_history else None,
        'test_accuracy': test_results['accuracy'] * 100,
        # ImageNet normalization statistics.
        'normalization_mean': [0.485, 0.456, 0.406],
        'normalization_std': [0.229, 0.224, 0.225]
    }

    metadata_path = os.path.join(save_dir, "model_metadata.json")
    with open(metadata_path, 'w', encoding='utf-8') as f:
        json.dump(metadata, f, indent=2)
    print(f"✓ Metadata saved: {metadata_path}")

    print(f"\n✓ All deployment files saved to: {save_dir}")

# Export model for deployment (writes .pth / TorchScript / ONNX / metadata
# under <RESULTS_DIR>/deployment).
deployment_dir = os.path.join(config.RESULTS_DIR, "deployment")
export_model_for_deployment(model, deployment_dir)

# ------------------------------
# 17. Interactive Usage Examples
# ------------------------------
# Printed cheat-sheet only; nothing below is executed. Example 6 passes
# map_location so a checkpoint saved from a GPU run also loads on a CPU-only
# machine, and switches the model to eval mode before inference.
print("\n" + "="*60)
print("🎯 CCROP System Ready - Usage Examples")
print("="*60)

print("""
# Example 1: Predict stress for a single image
result = predictor.predict_from_path('path/to/leaf_image.jpg')
print(f"Stress Score: {result['stress_score']:.1f}%")
print(f"Top Prediction: {result['top_predictions'][0]}")

# Example 2: Start webcam monitoring
webcam_stress_monitor(predictor)

# Example 3: Batch process a directory
df = batch_inference(predictor, 'path/to/image_folder/',
                     output_csv='results.csv')

# Example 4: Visualize predictions
image_list = ['img1.jpg', 'img2.jpg', 'img3.jpg']
visualize_predictions(predictor, image_list,
                     save_path='predictions.png')

# Example 5: Plot stress distribution from batch results
plot_stress_distribution(df, save_path='distribution.png')

# Example 6: Load model for inference only
checkpoint = torch.load('checkpoints/best_model.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
predictor = StressPredictor(model, val_test_transform,
                           STRESS_MAPPING, CLASSES, device)
""")

# ------------------------------
# 18. Generate Summary Report
# ------------------------------
def _fmt_last(values, spec, suffix=""):
    """Format the last item of ``values`` with format ``spec`` followed by ``suffix``.

    Returns ``'N/A'`` for an empty sequence, so the report can still be
    written when no epoch completed (``values[-1]`` would raise IndexError).
    """
    if len(values) == 0:
        return "N/A"
    return f"{values[-1]:{spec}}{suffix}"


def generate_summary_report():
    """Write a plain-text training summary to ``<RESULTS_DIR>/training_report.txt`` and print it.

    The report covers model configuration, dataset split sizes, the first 10
    class names, training-history figures, test-set metrics and the paths of
    generated artifacts. It reads module-level state produced earlier in the
    pipeline: ``config``, ``device``, ``CLASSES``, ``metrics``,
    ``DATASET_PATH``, ``full_dataset``, ``train_size``, ``val_size``,
    ``test_size``, ``test_results``, ``history_json_path``,
    ``history_plot_path`` and ``deployment_dir``.
    """
    report_path = os.path.join(config.RESULTS_DIR, "training_report.txt")
    heavy_rule = "=" * 70
    light_rule = "-" * 70
    history = metrics.history

    lines = [
        heavy_rule,
        "CCROP CANNABIS LEAF DISEASE AI - TRAINING REPORT",
        heavy_rule,
        "",
        # Time the report is generated (the label is historical).
        f"Training Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
        f"Device: {device}",
        "",
        "MODEL CONFIGURATION",
        light_rule,
        f"Architecture: {config.MODEL_ARCH}",
        f"Input Size: {config.INPUT_SIZE}x{config.INPUT_SIZE}",
        f"Number of Classes: {len(CLASSES)}",
        f"Batch Size: {config.BATCH_SIZE}",
        f"Learning Rate: {config.LEARNING_RATE}",
        f"Epochs Trained: {len(history['epochs'])}",
        "",
        "DATASET INFORMATION",
        light_rule,
        f"Dataset Path: {DATASET_PATH}",
        f"Total Samples: {len(full_dataset)}",
        f"Training Samples: {train_size}",
        f"Validation Samples: {val_size}",
        f"Test Samples: {test_size}",
        "",
        "CLASSES",
        light_rule,
    ]
    # List only the first 10 classes to keep the report short.
    lines += [f"{i}. {cls}" for i, cls in enumerate(CLASSES[:10], start=1)]
    if len(CLASSES) > 10:
        lines.append(f"... and {len(CLASSES) - 10} more")
    lines += [
        "",
        "TRAINING RESULTS",
        light_rule,
        f"Best Validation Loss: {metrics.best_val_loss:.4f}",
        f"Best Epoch: {metrics.best_epoch}",
        f"Final Train Loss: {_fmt_last(history['train_loss'], '.4f')}",
        f"Final Val Loss: {_fmt_last(history['val_loss'], '.4f')}",
        f"Final Train Accuracy: {_fmt_last(history['train_acc'], '.2f', '%')}",
        f"Final Val Accuracy: {_fmt_last(history['val_acc'], '.2f', '%')}",
        "",
        "TEST SET EVALUATION",
        light_rule,
        f"Test Accuracy: {test_results['accuracy'] * 100:.2f}%",
        f"Test Precision: {test_results['precision']:.4f}",
        f"Test Recall: {test_results['recall']:.4f}",
        f"Test F1-Score: {test_results['f1']:.4f}",
        "",
        "FILES GENERATED",
        light_rule,
        "- Best Model: checkpoints/best_model.pth",
        f"- Training History: {history_json_path}",
        f"- Training Plot: {history_plot_path}",
        f"- Confusion Matrix: {config.RESULTS_DIR}/confusion_matrix.png",
        f"- Deployment Models: {deployment_dir}/",
        f"- This Report: {report_path}",
        "",
        heavy_rule,
        "Report generated successfully!",
        heavy_rule,
    ]
    report_text = "\n".join(lines) + "\n"

    with open(report_path, 'w', encoding='utf-8') as f:
        f.write(report_text)

    print(f"\n✓ Summary report saved to {report_path}")

    # Echo the same text to the console (no need to re-read the file).
    print("\n" + report_text)

# Write and echo the training report, then show the closing banner and the
# suggested next steps.
generate_summary_report()

print("\n" + "=" * 60,
      "✅ CCROP Training Pipeline Completed Successfully!",
      "=" * 60,
      sep="\n")
print("\nNext Steps:",
      "1. Review training metrics and confusion matrix",
      "2. Test predictions on new images",
      "3. Start webcam monitoring: webcam_stress_monitor(predictor)",
      "4. Process image batches: batch_inference(predictor, 'folder_path')",
      sep="\n")
print("5. Deploy model using files in:", deployment_dir)
print("\nAll results saved in:", config.RESULTS_DIR)



Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
unzip is already the newest version (6.0-26ubuntu3.2).
0 upgraded, 0 newly installed, 0 to remove and 38 not upgraded.
✓ Found kaggle.json in current directory
✓ Kaggle credentials installed to /root/.kaggle/kaggle.json

Verifying Kaggle API access...
✓ Kaggle API authenticated successfully
📥 Downloading dataset: engineeringubu/leaf-manifestation-diseases-of-cannabis
⚠️  Primary dataset failed. Trying fallback: vipoooool/new-plant-diseases-dataset
✓ Dataset downloaded successfully
📦 Extracting: new-plant-diseases-dataset.zip
✓ Dataset extracted successfully
✓ Dataset path: ./dataset
✓ Classes detected: 3
  Sample classes: ['New Plant Diseases Dataset(Augmented)', 'new plant diseases dataset(augmented)', 'test']
✓ Dataset split: Train=140613, Val=17576, Test=17578




Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


100%|██████████| 44.7M/44.7M [00:00<00:00, 162MB/s]


✓ Model: resnet18 on cpu
  Total parameters: 11,178,051
  Trainable parameters: 11,178,051

🚀 Starting Training
