# BeanScan CNN Training on Kaggle

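This notebook is ready to run on Kaggle. Attach your datasets via the right-hand sidebar:
- (Optional) Your code as a Kaggle Dataset containing the `backend/ml` folder — the notebook also writes the `backend/ml` training modules inline, so this is usually not needed
- Your images/annotations dataset (images plus `*_annotations.json` files)
- (Optional) A weights dataset to resume training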

Then set the parameters below and run all cells (GPU recommended).


In [None]:
# === Parameters (edit these) ===
# Input locations: Kaggle mounts attached datasets/models under /kaggle/input.
DATASET_IMAGES: str = "/kaggle/input/your-images-dataset"   # images plus *_annotations.json files
DATASET_CODE: str = "/kaggle/input/beanscan-ml-code"        # provides backend/ml and backend/utils
DATASET_WEIGHTS: "str | None" = None  # e.g. "/kaggle/input/beanscan-weights"; None = start fresh
WEIGHTS_FILE: "str | None" = None     # e.g. "/kaggle/input/cnn-cnn-v1/cnn_best.pth" for an attached Kaggle Model

# Training options
USE_GPU: bool = True
NUM_EPOCHS: int = 20
BATCH_SIZE: int = 32
SAVE_INTERVAL: int = 5   # presumably epochs between checkpoints — confirm in the training cell
MODELS_DIR: str = "/kaggle/working/models"   # output directory (writable area on Kaggle)



In [None]:
# === Environment setup ===
# Install torch/torchvision compatible with Kaggle CUDA (usually 11.8)
%pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

# Optional: if you uploaded your repo requirements as a file, you can point to it
# Here we install typical dependencies used by your training code
%pip install -q numpy pandas matplotlib tqdm pillow opencv-python scikit-learn albumentations


In [None]:
# === Inline backend/ml sources so no code dataset is needed ===
import os

# The generated modules are written here; exist_ok makes re-running the cell harmless.
base_dir = "/kaggle/working/backend/ml"
os.makedirs(name=base_dir, exist_ok=True)
custom_models_src = r'''import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import mobilenet_v3_small, mobilenet_v3_large
from torchvision.models.detection import maskrcnn_resnet50_fpn
from torchvision.models.detection import fasterrcnn_mobilenet_v3_large_fpn
from torchvision.models.detection.backbone_utils import BackboneWithFPN
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
import torchvision.transforms as transforms
from typing import Dict, List, Tuple, Optional
import numpy as np

class MobileNetV3Backbone(nn.Module):
    """MobileNetV3-small feature extractor returning multi-scale feature maps.

    ``forward`` runs the input through torchvision's
    ``mobilenet_v3_small().features`` and returns, in order, the outputs of the
    layers listed in ``FEATURE_INDICES``. ``feature_channels[k]`` is the channel
    count of the k-th returned map.

    Args:
        pretrained: load torchvision's ImageNet (IMAGENET1K_V1) weights when True.
        width_mult: accepted for API compatibility; currently NOT applied.
    """

    # Indices into ``mobilenet_v3_small().features`` whose outputs are returned.
    FEATURE_INDICES = (2, 4, 6, 8, 10, 12)

    def __init__(self, pretrained: bool = True, width_mult: float = 1.0):
        super().__init__()
        # Local import: the weights enum replaces the deprecated ``pretrained=``
        # argument (same IMAGENET1K_V1 weights).
        from torchvision.models import MobileNet_V3_Small_Weights

        weights = MobileNet_V3_Small_Weights.IMAGENET1K_V1 if pretrained else None
        self.backbone = mobilenet_v3_small(weights=weights)

        # Only the convolutional trunk is used by forward().
        self.features = self.backbone.features

        # Output channels of features[2], [4], [6], [8], [10], [12] in
        # torchvision's mobilenet_v3_small.
        self.feature_channels = [24, 40, 40, 48, 96, 576]

    def forward(self, x):
        """Return the list of intermediate feature maps at ``FEATURE_INDICES``."""
        wanted = set(self.FEATURE_INDICES)
        features = []
        for i, layer in enumerate(self.features):
            x = layer(x)
            if i in wanted:
                features.append(x)
        return features

class BeanClassifierCNN(nn.Module):
    """Bean-type classifier: MobileNetV3-small features followed by an MLP head.

    Args:
        num_classes: number of output logits.
        pretrained: forwarded to ``MobileNetV3Backbone``.
        class_names: optional label per logit index. Defaults to the four bean
            types; padded with ``class_<i>`` so that every index has a name.
    """

    DEFAULT_CLASS_NAMES = ("Arabica", "Robusta", "Liberica", "Excelsa")

    def __init__(self, num_classes: int = 4, pretrained: bool = True,
                 class_names: Optional[List[str]] = None):
        super().__init__()
        self.backbone = MobileNetV3Backbone(pretrained=pretrained)

        # Classification head (increased dropout ~0.3 to mitigate overfitting).
        # 576 is the channel count of the backbone's last returned feature map.
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Dropout(0.3),
            nn.Linear(576, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes)
        )

        names = list(self.DEFAULT_CLASS_NAMES if class_names is None else class_names)
        # Without padding, predict() raises IndexError when num_classes > len(names).
        names.extend(f"class_{i}" for i in range(len(names), num_classes))
        self.class_names = names

    def forward(self, x):
        """Return class logits of shape (batch, num_classes)."""
        features = self.backbone(x)
        # Classify from the last (deepest) feature map only.
        return self.classifier(features[-1])

    def predict(self, x, threshold: float = 0.5):
        """Predict bean types for a batch.

        Returns one dict per sample with 'class' (or 'Unknown' when the top
        softmax probability is below ``threshold``), 'confidence' and the full
        'probabilities' list. The module's train/eval mode is restored afterwards.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                probabilities = F.softmax(self.forward(x), dim=1)
                confidence, predicted = torch.max(probabilities, 1)
                keep = (confidence >= threshold).tolist()
        finally:
            # Restore the caller's mode so predict() does not silently disable
            # dropout when it is called during training.
            self.train(was_training)

        predictions = []
        for kept, conf, idx, probs in zip(keep, confidence.tolist(),
                                          predicted.tolist(), probabilities.tolist()):
            predictions.append({
                'class': self.class_names[idx] if kept else 'Unknown',
                'confidence': conf,
                'probabilities': probs
            })
        return predictions

class DefectDetectorMaskRCNN(nn.Module):
    """Mask R-CNN defect detector on a MobileNetV3-small + FPN backbone.

    Args:
        num_classes: number of defect classes excluding background (one
            background class is added internally).
        pretrained: forwarded to ``MobileNetV3Backbone``.

    ``forward`` delegates to torchvision's ``MaskRCNN``: in training mode it
    returns a dict of losses, in eval mode a list of per-image detections.
    """

    # features[i] of mobilenet_v3_small fed into the FPN, and their output channels.
    RETURN_INDICES = (2, 4, 6, 8, 10, 12)
    IN_CHANNELS = [24, 40, 40, 48, 96, 576]

    def __init__(self, num_classes: int = 4, pretrained: bool = True):
        super().__init__()
        from torchvision.models.detection import MaskRCNN
        from torchvision.models.detection.anchor_utils import AnchorGenerator

        self.backbone = MobileNetV3Backbone(pretrained=pretrained)

        # BackboneWithFPN selects layers by child *name*, so it must wrap the
        # nn.Sequential trunk (children '0'..'12'), not MobileNetV3Backbone
        # itself (whose only children are 'backbone' and 'features').
        return_layers = {str(idx): str(level) for level, idx in enumerate(self.RETURN_INDICES)}
        self.fpn = BackboneWithFPN(
            self.backbone.features,
            return_layers=return_layers,
            in_channels_list=self.IN_CHANNELS,
            out_channels=256
        )

        # The FPN emits one map per returned layer plus an extra max-pooled
        # level (7 total), and the anchor generator needs one size tuple per map.
        anchor_sizes = ((32,), (64,), (64,), (64,), (128,), (128,), (256,))
        anchor_generator = AnchorGenerator(
            sizes=anchor_sizes,
            aspect_ratios=((0.5, 1.0, 2.0),) * len(anchor_sizes)
        )

        # Build Mask R-CNN directly on the custom backbone rather than creating
        # a ResNet-50 model (whose builder fetches ImageNet ResNet-50 weights by
        # default) and then swapping its backbone. The default box/mask
        # predictors are already sized for num_classes + 1.
        self.mask_rcnn = MaskRCNN(
            self.fpn,
            num_classes=num_classes + 1,  # +1 for background
            rpn_anchor_generator=anchor_generator
        )

        # Defect types (label i + 1 corresponds to defect_types[i])
        self.defect_types = ["Mold", "Insect_Damage", "Discoloration", "Physical_Damage"]

    def forward(self, images, targets=None):
        return self.mask_rcnn(images, targets)

class DefectDetectorFasterRCNN(nn.Module):
    """Box-only Faster R-CNN (MobileNetV3-large FPN) for bean defects.

    Args:
        num_classes: total class count INCLUDING background; values below 2
            are raised to 2.
        pretrained: forwarded to torchvision's model builder.
        class_names: foreground class names; a default six-name list is used
            when empty or None. ``self.class_names`` is prefixed with
            ``"__background__"``.
    """

    def __init__(self, num_classes: int = 7, pretrained: bool = True,
                 class_names: Optional[List[str]] = None):
        super().__init__()
        # Background is itself a class, so at least two are required.
        self.num_classes = num_classes if num_classes > 2 else 2

        detector = fasterrcnn_mobilenet_v3_large_fpn(pretrained=pretrained)
        # Swap the box head for one sized to our class count.
        box_in = detector.roi_heads.box_predictor.cls_score.in_features
        detector.roi_heads.box_predictor = FastRCNNPredictor(box_in, self.num_classes)
        self.model = detector

        if class_names:
            foreground = class_names
        else:
            foreground = [
                "insect_damage",
                "nugget",
                "quaker",
                "roasted-beans",
                "shell",
                "under_roast"
            ]
        self.class_names = ["__background__"] + foreground

    def forward(self, images, targets=None):
        return self.model(images, targets)

class ShelfLifeLSTM(nn.Module):
    """Bidirectional LSTM + self-attention regressor for shelf-life prediction.

    Args:
        input_size: features per time step.
        hidden_size: LSTM hidden size per direction.
        num_layers: stacked LSTM layers.
        dropout: dropout used between LSTM layers, in attention and in the head.
        num_heads: attention heads; ``2 * hidden_size`` must be divisible by it.

    ``forward(x)`` takes ``x`` as (batch, seq_len, input_size) and returns
    ``(prediction, hidden)``, where prediction is (batch, 1).
    """

    def __init__(self, input_size: int = 64, hidden_size: int = 128, num_layers: int = 2,
                 dropout: float = 0.2, num_heads: int = 8):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # LSTM layers (inter-layer dropout only makes sense with >1 layer)
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout if num_layers > 1 else 0,
            batch_first=True,
            bidirectional=True
        )

        # Attention mechanism over the bidirectional outputs (sequence-first layout)
        self.attention = nn.MultiheadAttention(
            embed_dim=hidden_size * 2,  # *2 for bidirectional
            num_heads=num_heads,
            dropout=dropout
        )

        # Prediction head
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size // 2, 1)  # Predict days until expiration
        )

        # Shelf life categories
        self.shelf_life_categories = ["Expired", "Critical", "Warning", "Good", "Excellent"]

    def forward(self, x, hidden=None):
        batch_size = x.size(0)

        if hidden is None:
            # new_zeros matches x's device AND dtype; float32-only state would
            # clash with float64/float16 inputs.
            state_shape = (self.num_layers * 2, batch_size, self.hidden_size)
            hidden = (x.new_zeros(state_shape), x.new_zeros(state_shape))

        lstm_out, hidden = self.lstm(x, hidden)

        # MultiheadAttention here expects (seq_len, batch, embed)
        lstm_out = lstm_out.transpose(0, 1)
        attn_out, _ = self.attention(lstm_out, lstm_out, lstm_out)
        attn_out = attn_out.transpose(0, 1)  # back to (batch, seq_len, hidden_size*2)

        # Average over time steps, then regress
        pooled = torch.mean(attn_out, dim=1)
        shelf_life = self.classifier(pooled)

        return shelf_life, hidden

class BeanScanEnsemble(nn.Module):
    """Ensemble combining the bean classifier, defect detector and shelf-life LSTM.

    ``forward`` returns a dict with 'bean_classification', 'defect_detection',
    'health_score' and (when ``defect_sequence`` is given)
    'shelf_life_prediction'. Defects and the health score refer to the first
    image of the batch.
    """

    def __init__(self, cnn_model: "BeanClassifierCNN",
                 defect_model: "DefectDetectorMaskRCNN",
                 lstm_model: "ShelfLifeLSTM"):
        super().__init__()
        self.cnn_model = cnn_model
        self.defect_model = defect_model
        self.lstm_model = lstm_model

    def forward(self, image, defect_sequence=None):
        """Complete bean analysis pipeline"""
        results = {}

        # 1. Bean type classification
        bean_type = self.cnn_model.predict(image)
        results['bean_classification'] = bean_type

        # 2. Defect detection
        defects = self._detect_defects(image)
        results['defect_detection'] = defects

        # 3. Shelf life prediction (if sequence provided)
        if defect_sequence is not None:
            results['shelf_life_prediction'] = self._predict_shelf_life(defect_sequence)

        # 4. Calculate overall health score
        results['health_score'] = self._calculate_health_score(bean_type, defects)

        return results

    def _detect_defects(self, image, score_threshold: float = 0.5):
        """Return defect dicts ('defect_type', 'confidence', 'bbox', 'area') for the first image.

        Uses ``defect_model.detect_defects`` when the model provides it;
        otherwise runs the detector in eval mode and converts its
        boxes/labels/scores output. 'area' is the bounding-box area in pixels.
        """
        detector = self.defect_model
        if hasattr(detector, 'detect_defects'):
            return detector.detect_defects(image)

        was_training = detector.training
        detector.eval()
        try:
            with torch.no_grad():
                outputs = detector(image)
        finally:
            detector.train(was_training)
        if not outputs:
            return []

        names = list(getattr(detector, 'defect_types', []))
        first = outputs[0]
        defects = []
        for box, label, score in zip(first['boxes'].tolist(),
                                     first['labels'].tolist(),
                                     first['scores'].tolist()):
            if score < score_threshold:
                continue
            x1, y1, x2, y2 = box
            idx = int(label) - 1  # label 0 is background
            defects.append({
                'defect_type': names[idx] if 0 <= idx < len(names) else f'class_{label}',
                'confidence': score,
                'bbox': box,
                'area': max(0.0, x2 - x1) * max(0.0, y2 - y1)
            })
        return defects

    def _predict_shelf_life(self, defect_sequence):
        """Return the LSTM prediction per sequence as a list of floats.

        Uses ``lstm_model.predict_shelf_life`` when the model provides it.
        """
        model = self.lstm_model
        if hasattr(model, 'predict_shelf_life'):
            return model.predict_shelf_life(defect_sequence)

        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                days, _ = model(defect_sequence)
        finally:
            model.train(was_training)
        return days.squeeze(-1).tolist()

    def _calculate_health_score(self, bean_type, defects):
        """Calculate overall bean health score"""
        # Base score from bean type confidence
        base_score = bean_type[0]['confidence'] if bean_type else 0.5

        # Penalty for defects
        defect_penalty = 0
        if defects:
            for defect in defects:
                # Higher penalty for more severe defects
                if defect['defect_type'] == 'Mold':
                    defect_penalty += 0.3
                elif defect['defect_type'] == 'Insect_Damage':
                    defect_penalty += 0.25
                elif defect['defect_type'] == 'Discoloration':
                    defect_penalty += 0.15
                elif defect['defect_type'] == 'Physical_Damage':
                    defect_penalty += 0.1

                # Additional penalty based on defect area
                defect_penalty += min(0.2, defect['area'] / 10000)  # Normalize area

        # Calculate final health score
        health_score = max(0.0, min(1.0, base_score - defect_penalty))

        return {
            'score': health_score,
            'percentage': health_score * 100,
            'grade': self._get_health_grade(health_score),
            'defect_count': len(defects) if defects else 0
        }

    def _get_health_grade(self, score):
        """Convert health score to letter grade"""
        if score >= 0.9:
            return 'A+'
        elif score >= 0.8:
            return 'A'
        elif score >= 0.7:
            return 'B+'
        elif score >= 0.6:
            return 'B'
        elif score >= 0.5:
            return 'C+'
        elif score >= 0.4:
            return 'C'
        elif score >= 0.3:
            return 'D'
        else:
            return 'F'

# Utility functions

def create_models(device: str = 'cpu'):
    """Build the bean classifier, Mask R-CNN detector, shelf-life LSTM and ensemble.

    Every model is moved to ``device``. Returns a dict keyed 'cnn',
    'defect_detector', 'lstm' and 'ensemble' (in that order).
    """
    target = torch.device(device)

    cnn = BeanClassifierCNN(num_classes=4, pretrained=True)
    defect_detector = DefectDetectorMaskRCNN(num_classes=4, pretrained=True)
    lstm = ShelfLifeLSTM(input_size=64, hidden_size=128, num_layers=2)

    for component in (cnn, defect_detector, lstm):
        component.to(target)

    # The ensemble wraps the same three module instances.
    ensemble = BeanScanEnsemble(cnn, defect_detector, lstm)
    ensemble.to(target)

    return dict(cnn=cnn, defect_detector=defect_detector, lstm=lstm, ensemble=ensemble)


def save_models(models: Dict, save_dir: str = './models'):
    """Write each model's state_dict to ``<save_dir>/<name>.pth``.

    The directory is created if it does not exist.
    """
    import os
    os.makedirs(save_dir, exist_ok=True)

    for name in models:
        target_path = os.path.join(save_dir, f'{name}.pth')
        torch.save(models[name].state_dict(), target_path)
        print(f"✅ Saved {name} model")


def load_models(device: str = 'cpu', model_dir: str = './models'):
    """Create all models via ``create_models`` and load any saved weights.

    For every model name, ``<model_dir>/<name>.pth`` is loaded as a state_dict
    when it exists; otherwise the freshly initialised weights are kept.

    NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    """
    import os  # this module does not import os at module level

    device = torch.device(device)

    # Create models
    models = create_models(device)

    # Load saved weights if available
    for name, model in models.items():
        model_path = os.path.join(model_dir, f'{name}.pth')
        if not os.path.exists(model_path):
            print(f"⚠️  No saved weights found for {name}, using initialized weights")
            continue
        model.load_state_dict(torch.load(model_path, map_location=device))
        print(f"✅ Loaded {name} model from {model_path}")

    return models
'''

train_models_src = r'''import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import numpy as np
import os
from PIL import Image
import json
from typing import Dict, List, Tuple
import matplotlib.pyplot as plt
from tqdm import tqdm

from custom_models import (
    BeanClassifierCNN, 
    DefectDetectorMaskRCNN, 
    ShelfLifeLSTM,
    BeanScanEnsemble,
    create_models
)

class BeanImageDataset(Dataset):
    """Dataset of bean images with bean-type, defect and health-score labels.

    Annotations are read from ``<data_dir>/<split>_annotations.json``, a list of
    dicts with 'image_id', 'bean_type', 'defects' and 'health_score'. When that
    file is missing, 100 synthetic annotations are generated; when an image file
    is missing, a solid brown 224x224 placeholder image is used instead.
    """

    BEAN_TYPES = ["Arabica", "Robusta", "Liberica", "Excelsa"]
    # Detector label = index + 1 (0 is reserved for background).
    DEFECT_TYPES = ["Mold", "Insect_Damage", "Discoloration", "Physical_Damage"]

    def __init__(self, data_dir: str, transform=None, split: str = 'train'):
        self.data_dir = data_dir
        self.transform = transform
        self.split = split

        # Load annotations
        self.annotations = self._load_annotations()

        # Default transform: resize to 224x224 and normalise with ImageNet statistics.
        if self.transform is None:
            self.transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
            ])

    def _load_annotations(self):
        """Load the split's annotation JSON, or synthetic annotations if it is absent."""
        annotations_file = os.path.join(self.data_dir, f'{self.split}_annotations.json')
        if not os.path.exists(annotations_file):
            # Fall back to dummy data so the pipeline can be smoke-tested.
            return self._create_dummy_annotations()
        with open(annotations_file, 'r', encoding='utf-8') as f:
            return json.load(f)

    def _create_dummy_annotations(self):
        """Create 100 dummy annotations (every 10th sample has a Mold defect)."""
        annotations = []
        for i in range(100):
            has_mold = i % 10 == 0
            annotations.append({
                'image_id': f'bean_{i:04d}.jpg',
                'bean_type': self.BEAN_TYPES[i % len(self.BEAN_TYPES)],
                'defects': [
                    {
                        'type': 'Mold' if has_mold else 'None',
                        'bbox': [10, 10, 100, 100] if has_mold else [0, 0, 0, 0],
                        'mask': np.zeros((224, 224)).tolist() if has_mold else []
                    }
                ],
                'health_score': max(0.1, 1.0 - (i % 10) * 0.1)
            })
        return annotations

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        annotation = self.annotations[idx]

        # Load image (create dummy if not exists)
        image_path = os.path.join(self.data_dir, annotation['image_id'])
        if os.path.exists(image_path):
            # The context manager closes the file; convert() returns a new, loaded image.
            with Image.open(image_path) as img:
                image = img.convert('RGB')
        else:
            image = Image.new('RGB', (224, 224), color=(139, 69, 19))  # Brown color

        # Apply transformations
        if self.transform:
            image = self.transform(image)

        return {
            'image': image,
            'bean_type_label': self._get_bean_type_label(annotation['bean_type']),
            'defect_labels': self._get_defect_labels(annotation['defects']),
            'health_score': annotation['health_score'],
            'image_id': annotation['image_id']
        }

    def _get_bean_type_label(self, bean_type: str):
        """Return the index of ``bean_type`` in BEAN_TYPES.

        NOTE(review): unknown types silently map to 0 ("Arabica"), which
        mislabels such samples; consider filtering them upstream.
        """
        return self.BEAN_TYPES.index(bean_type) if bean_type in self.BEAN_TYPES else 0

    def _get_defect_labels(self, defects: List):
        """Convert an annotation's defects into detector targets.

        Returns ``[]`` when no defect has a recognised type. Otherwise, returns a
        one-element list with a single target dict that stacks ALL recognised
        defects: 'boxes' (N, 4) float32, 'labels' (N,) int64 (1-based), and
        'masks' uint8 (assumes every mask is an HxW nested list of the same size).
        """
        valid = [defect for defect in defects if defect['type'] in self.DEFECT_TYPES]
        if not valid:
            return []
        return [{
            'boxes': torch.tensor([d['bbox'] for d in valid], dtype=torch.float32),
            'labels': torch.tensor([self.DEFECT_TYPES.index(d['type']) + 1 for d in valid],
                                   dtype=torch.long),
            'masks': torch.tensor([d['mask'] for d in valid], dtype=torch.uint8)
        }]

class ModelTrainer:
    """Trainer for the CNN classifier, defect detector and shelf-life LSTM.

    Models come from ``create_models(device)``. Periodic and final checkpoints,
    plus the training-history plot, are written under ``models_dir``.

    NOTE: optimizers are built in ``__init__`` from ``self.learning_rate``, so
    changing ``learning_rate`` afterwards has no effect on them. ``num_epochs``
    and ``save_interval`` are read when each ``train_*`` method runs. This class
    never reads ``batch_size``; data loaders are supplied by the caller.
    """

    def __init__(self, device: str = 'cpu', models_dir: str = './models'):
        self.device = torch.device(device)
        self.models_dir = models_dir
        os.makedirs(models_dir, exist_ok=True)

        # Initialize models
        self.models = create_models(device)

        # Training parameters
        self.learning_rate = 0.001
        self.batch_size = 16
        self.num_epochs = 50
        self.save_interval = 10

        # Loss functions (Mask R-CNN computes its own losses; defect_criterion is not used in training)
        self.cnn_criterion = nn.CrossEntropyLoss()
        self.defect_criterion = self._get_defect_loss()

        # Optimizers
        self.optimizers = self._create_optimizers()

        # Training history
        self.training_history = {
            'cnn_loss': [], 'cnn_acc': [],
            'defect_loss': [], 'defect_map': [],
            'lstm_loss': [], 'lstm_mae': []
        }

    def _get_defect_loss(self):
        """Get loss functions for defect detection"""
        return {
            'classification': nn.CrossEntropyLoss(),
            'bbox_regression': nn.SmoothL1Loss(),
            'mask_loss': nn.BCEWithLogitsLoss()
        }

    def _create_optimizers(self):
        """Create an Adam optimizer for each trainable model"""
        return {
            'cnn': optim.Adam(self.models['cnn'].parameters(), lr=self.learning_rate),
            'defect_detector': optim.Adam(self.models['defect_detector'].parameters(), lr=self.learning_rate),
            'lstm': optim.Adam(self.models['lstm'].parameters(), lr=self.learning_rate)
        }

    def train_cnn(self, train_loader: DataLoader, val_loader: DataLoader = None):
        """Train the CNN classifier; return (loss history, accuracy history).

        ``val_loader`` is accepted for API symmetry but is not used yet.
        """
        print("🚀 Training CNN Classifier...")

        model = self.models['cnn']
        optimizer = self.optimizers['cnn']
        model.train()

        for epoch in range(self.num_epochs):
            running_loss = 0.0
            correct = 0
            total = 0

            for batch in tqdm(train_loader, desc=f'Epoch {epoch+1}/{self.num_epochs}'):
                images = batch['image'].to(self.device)
                labels = batch['bean_type_label'].to(self.device)

                optimizer.zero_grad()
                outputs = model(images)
                loss = self.cnn_criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                predicted = outputs.detach().argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

            # max(..., 1) avoids ZeroDivisionError on an empty loader.
            epoch_loss = running_loss / max(len(train_loader), 1)
            epoch_acc = 100 * correct / max(total, 1)

            self.training_history['cnn_loss'].append(epoch_loss)
            self.training_history['cnn_acc'].append(epoch_acc)

            print(f'Epoch {epoch+1}: Loss={epoch_loss:.4f}, Accuracy={epoch_acc:.2f}%')

            if (epoch + 1) % self.save_interval == 0:
                self._save_model('cnn', epoch + 1)

        print("✅ CNN Training Complete!")
        return self.training_history['cnn_loss'], self.training_history['cnn_acc']

    def train_defect_detector(self, train_loader: DataLoader, val_loader: DataLoader = None):
        """Train the Mask R-CNN defect detector; return the loss history.

        Each batch must carry 'defect_labels' as a per-image list (use a
        collate_fn that does not stack them). ``val_loader`` is not used yet.
        """
        print("🚀 Training Defect Detector...")

        model = self.models['defect_detector']
        optimizer = self.optimizers['defect_detector']
        model.train()

        for epoch in range(self.num_epochs):
            running_loss = 0.0

            for batch in tqdm(train_loader, desc=f'Epoch {epoch+1}/{self.num_epochs}'):
                images = batch['image'].to(self.device)
                # Empty-target masks are sized like the input images.
                formatted_targets = self._format_targets(
                    batch['defect_labels'], mask_size=tuple(images.shape[-2:])
                )

                optimizer.zero_grad()
                # In training mode Mask R-CNN returns a dict of losses.
                loss_dict = model(images, formatted_targets)
                total_loss = sum(loss_dict.values())
                total_loss.backward()
                optimizer.step()

                running_loss += total_loss.item()

            epoch_loss = running_loss / max(len(train_loader), 1)
            self.training_history['defect_loss'].append(epoch_loss)

            print(f'Epoch {epoch+1}: Loss={epoch_loss:.4f}')

            if (epoch + 1) % self.save_interval == 0:
                self._save_model('defect_detector', epoch + 1)

        print("✅ Defect Detector Training Complete!")
        return self.training_history['defect_loss']

    def train_lstm(self, train_loader: DataLoader, val_loader: DataLoader = None):
        """Train the shelf-life LSTM; return (loss history, MAE history).

        TODO(review): input sequences are random placeholders, and the target is
        ``health_score * 30`` (treated as days); real defect-progression data is
        not wired in yet. ``val_loader`` is not used yet.
        """
        print("🚀 Training LSTM...")

        model = self.models['lstm']
        optimizer = self.optimizers['lstm']
        criterion = nn.MSELoss()
        seq_length = 10
        model.train()

        for epoch in range(self.num_epochs):
            running_loss = 0.0
            running_mae = 0.0

            for batch in tqdm(train_loader, desc=f'Epoch {epoch+1}/{self.num_epochs}'):
                batch_size = batch['image'].size(0)

                # Generate dummy defect sequences (64 features per step, matching the LSTM input)
                sequences = torch.randn(batch_size, seq_length, 64, device=self.device)
                # One health_score per sample (a (batch,) tensor after default
                # collation), giving one target per sample, converted to days.
                targets = torch.as_tensor(batch['health_score'], dtype=torch.float32,
                                          device=self.device) * 30

                optimizer.zero_grad()
                outputs, _ = model(sequences)
                # squeeze(-1) keeps the batch dimension even when batch_size == 1.
                predictions = outputs.squeeze(-1)
                loss = criterion(predictions, targets)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                running_mae += torch.mean(torch.abs(predictions.detach() - targets)).item()

            num_batches = max(len(train_loader), 1)
            epoch_loss = running_loss / num_batches
            epoch_mae = running_mae / num_batches

            self.training_history['lstm_loss'].append(epoch_loss)
            self.training_history['lstm_mae'].append(epoch_mae)

            print(f'Epoch {epoch+1}: Loss={epoch_loss:.4f}, MAE={epoch_mae:.2f}')

            if (epoch + 1) % self.save_interval == 0:
                self._save_model('lstm', epoch + 1)

        print("✅ LSTM Training Complete!")
        return self.training_history['lstm_loss'], self.training_history['lstm_mae']

    def _format_targets(self, targets: List, mask_size: Tuple[int, int] = (224, 224)):
        """Format per-image defect targets for Mask R-CNN training.

        Each entry of ``targets`` is a list of target dicts (or a single dict)
        with 'boxes', 'labels' and 'masks'. All dicts of an image are
        concatenated, so no defect is dropped. Images without defects get empty
        tensors, with masks of shape (0, *mask_size).
        """
        formatted = []
        for target in targets:
            parts = [target] if isinstance(target, dict) else list(target or [])
            if parts:
                formatted.append({
                    key: torch.cat([part[key] for part in parts]).to(self.device)
                    for key in ('boxes', 'labels', 'masks')
                })
            else:
                height, width = int(mask_size[0]), int(mask_size[1])
                formatted.append({
                    'boxes': torch.empty((0, 4), dtype=torch.float32, device=self.device),
                    'labels': torch.empty((0,), dtype=torch.long, device=self.device),
                    'masks': torch.empty((0, height, width), dtype=torch.uint8, device=self.device)
                })
        return formatted

    def _save_model(self, model_name: str, epoch: int):
        """Save a model checkpoint as ``<models_dir>/<name>_epoch_<epoch>.pth``"""
        save_path = os.path.join(self.models_dir, f'{model_name}_epoch_{epoch}.pth')
        torch.save(self.models[model_name].state_dict(), save_path)
        print(f"💾 Saved {model_name} checkpoint: {save_path}")

    def save_final_models(self):
        """Save every model as ``<models_dir>/<name>_final.pth``"""
        for name, model in self.models.items():
            save_path = os.path.join(self.models_dir, f'{name}_final.pth')
            torch.save(model.state_dict(), save_path)
            print(f"💾 Saved final {name} model: {save_path}")

    def plot_training_history(self):
        """Plot training curves and save them to ``<models_dir>/training_history.png``"""
        fig, axes = plt.subplots(2, 3, figsize=(15, 10))

        panels = [
            (axes[0, 0], 'cnn_loss', 'CNN Loss', 'Loss'),
            (axes[0, 1], 'cnn_acc', 'CNN Accuracy', 'Accuracy (%)'),
            (axes[0, 2], 'defect_loss', 'Defect Detector Loss', 'Loss'),
            (axes[1, 0], 'lstm_loss', 'LSTM Loss', 'Loss'),
            (axes[1, 1], 'lstm_mae', 'LSTM MAE', 'MAE'),
        ]
        for ax, key, title, ylabel in panels:
            ax.plot(self.training_history[key])
            ax.set_title(title)
            ax.set_xlabel('Epoch')
            ax.set_ylabel(ylabel)

        # Hide empty subplot
        axes[1, 2].set_visible(False)

        plt.tight_layout()
        plt.savefig(os.path.join(self.models_dir, 'training_history.png'))
        plt.show()
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)


def main():
    """Main training function"""
    print("🎯 BeanScan Deep Learning Model Training")
    print("=" * 50)
    
    # Initialize trainer
    trainer = ModelTrainer(device='cpu')  # Use 'cuda' if GPU available
    
    # Create dummy datasets (replace with real data)
    train_dataset = BeanImageDataset('./data', split='train')
    val_dataset = BeanImageDataset('./data', split='val')
    
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)
    
    print(f"📊 Dataset sizes: Train={len(train_dataset)}, Val={len(val_dataset)}")
    
    # Train all models
    try:
        # Train CNN
        cnn_loss, cnn_acc = trainer.train_cnn(train_loader, val_loader)
        
        # Train Defect Detector
        defect_loss = trainer.train_defect_detector(train_loader, val_loader)
        
        # Train LSTM
        lstm_loss, lstm_mae = trainer.train_lstm(train_loader, val_loader)
        
        # Save final models
        trainer.save_final_models()
        
        # Plot training history
        trainer.plot_training_history()
        
        print("\n🎉 All models trained successfully!")
        print("📁 Models saved in:", trainer.models_dir)
        
    except Exception as e:
        print(f"❌ Training failed: {e}")
        import traceback
        traceback.print_exc()

if __name__ == "__main__":
    main()
'''

with open(os.path.join(base_dir, 'custom_models.py'), 'w', encoding='utf-8') as f:
    f.write(custom_models_src)
with open(os.path.join(base_dir, 'train_models.py'), 'w', encoding='utf-8') as f:
    f.write(train_models_src)

import importlib
import sys

# Put the inline sources at the FRONT of sys.path so they take precedence over
# any same-named modules (e.g. from an attached code dataset).
for module_dir in ('/kaggle/working/backend', '/kaggle/working/backend/ml'):
    if module_dir in sys.path:
        sys.path.remove(module_dir)
    sys.path.insert(0, module_dir)

# Re-running this cell rewrites the files: drop cached modules and the path
# finder caches so the next import picks up the new source.
for stale_module in ('custom_models', 'train_models'):
    sys.modules.pop(stale_module, None)
importlib.invalidate_caches()

print('Wrote inline modules to', base_dir)


In [None]:
# === Write code files inline so we don't need a code dataset ===
# NOTE(review): this cell uses the same base_dir as the earlier inline cell and
# presumably overwrites the files it wrote -- keep the two copies in sync.
import os

base_dir = "/kaggle/working/backend/ml"
os.makedirs(name=base_dir, exist_ok=True)  # no-op when the directory already exists
custom_models_src = r'''
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import mobilenet_v3_small, mobilenet_v3_large
from torchvision.models.detection import maskrcnn_resnet50_fpn
from torchvision.models.detection import fasterrcnn_mobilenet_v3_large_fpn
from torchvision.models.detection.backbone_utils import BackboneWithFPN
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
import torchvision.transforms as transforms
from typing import Dict, List, Tuple, Optional
import numpy as np

class MobileNetV3Backbone(nn.Module):
    """Multi-scale feature extractor built on torchvision's MobileNetV3-Small.

    ``forward`` pushes the input through ``mobilenet_v3_small().features`` and
    returns the outputs of the layers listed in ``FEATURE_LAYERS`` (shallow to
    deep). ``feature_channels[i]`` is the channel count of the i-th returned map.

    Args:
        pretrained: load ImageNet weights for the MobileNetV3-Small trunk.
        width_mult: accepted for API compatibility; currently unused.
    """

    # Indices into ``mobilenet_v3_small().features`` whose outputs are returned.
    FEATURE_LAYERS = (2, 4, 6, 8, 10, 12)

    def __init__(self, pretrained: bool = True, width_mult: float = 1.0):
        super().__init__()
        # Load (optionally pretrained) MobileNetV3-Small
        self.backbone = mobilenet_v3_small(pretrained=pretrained)

        # Sequential of 13 stages (stem, 11 inverted-residual blocks, last conv)
        self.features = self.backbone.features

        # Output channels of features[2], [4], [6], [8], [10], [12] in
        # MobileNetV3-Small, i.e. exactly the maps returned by forward().
        # The previous list ([16, 24, 40, 48, 96, 576]) did not match them.
        self.feature_channels = [24, 40, 40, 48, 96, 576]

    def forward(self, x):
        """Return the list of intermediate feature maps selected by FEATURE_LAYERS."""
        taps = set(self.FEATURE_LAYERS)
        features = []
        for i, layer in enumerate(self.features):
            x = layer(x)
            if i in taps:
                features.append(x)
        return features

class BeanClassifierCNN(nn.Module):
    """Bean-variety classifier: MobileNetV3-Small trunk plus a small MLP head.

    Only the deepest backbone map (576 channels) feeds the head, which
    global-average-pools it and applies two dropout-regularised linear layers.
    """

    def __init__(self, num_classes: int = 4, pretrained: bool = True):
        super().__init__()
        self.backbone = MobileNetV3Backbone(pretrained=pretrained)

        # Classification head (increased dropout ~0.3 to mitigate overfitting)
        head_layers = [
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Dropout(0.3),
            nn.Linear(576, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

        # Human-readable names indexed by predicted class id
        self.class_names = ["Arabica", "Robusta", "Liberica", "Excelsa"]

    def forward(self, x):
        """Return raw class logits computed from the deepest backbone feature map."""
        deepest = self.backbone(x)[-1]
        return self.classifier(deepest)

    def predict(self, x, threshold: float = 0.5):
        """Classify a batch and report per-sample confidence.

        Switches the module to eval mode (and leaves it there). A sample whose
        top-1 softmax probability is below ``threshold`` is labelled 'Unknown'.

        Returns:
            One dict per batch row with keys 'class', 'confidence' and
            'probabilities' (the full softmax vector as a list).
        """
        self.eval()
        with torch.no_grad():
            probabilities = F.softmax(self.forward(x), dim=1)
            confidence, predicted = torch.max(probabilities, 1)

            results = []
            for conf, idx, probs in zip(confidence, predicted, probabilities):
                label = self.class_names[idx.item()] if conf >= threshold else 'Unknown'
                results.append({
                    'class': label,
                    'confidence': conf.item(),
                    'probabilities': probs.tolist()
                })
            return results

class DefectDetectorMaskRCNN(nn.Module):
    """Mask R-CNN defect detector on a MobileNetV3-Small FPN backbone.

    The ResNet-50 trunk of torchvision's ``maskrcnn_resnet50_fpn`` is replaced
    by a 4-level FPN built on ``MobileNetV3Backbone.features``. The box and mask
    predictors are sized for ``num_classes`` defect types plus background.
    """

    # features index -> FPN level name. In MobileNetV3-Small these layers sit at
    # strides /4, /8, /16, /32 (like ResNet-50's C2..C5). With the FPN's extra
    # max-pool level this gives 5 maps, matching the default 5-size RPN anchor
    # generator and the RoI poolers' featmap_names '0'..'3'.
    FPN_RETURN_LAYERS = {'1': '0', '3': '1', '8': '2', '12': '3'}
    # Output channels of features[1], [3], [8], [12]
    FPN_IN_CHANNELS = [16, 24, 48, 576]

    def __init__(self, num_classes: int = 4, pretrained: bool = True):
        super().__init__()
        # Create custom backbone with MobileNetV3
        self.backbone = MobileNetV3Backbone(pretrained=pretrained)

        # BackboneWithFPN wraps its input in IntermediateLayerGetter, which looks
        # up return_layers among the module's *direct children*. The wrapper only
        # has children 'backbone'/'features', so passing it raised ValueError.
        # The nn.Sequential `features` has children '0'..'12'.
        self.fpn = BackboneWithFPN(
            self.backbone.features,
            return_layers=dict(self.FPN_RETURN_LAYERS),
            in_channels_list=list(self.FPN_IN_CHANNELS),
            out_channels=256
        )

        # Build the Mask R-CNN scaffold only. pretrained_backbone=False avoids
        # downloading ResNet-50 weights that are discarded just below.
        self.mask_rcnn = maskrcnn_resnet50_fpn(
            pretrained=False,
            pretrained_backbone=False,
            num_classes=num_classes + 1  # +1 for background
        )

        # Replace backbone
        self.mask_rcnn.backbone = self.fpn

        # Customize box and mask predictors
        in_features = self.mask_rcnn.roi_heads.box_predictor.cls_score.in_features
        self.mask_rcnn.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes + 1)

        in_features_mask = self.mask_rcnn.roi_heads.mask_predictor.conv5_mask.in_channels
        hidden_layer = 256
        self.mask_rcnn.roi_heads.mask_predictor = MaskRCNNPredictor(
            in_features_mask, hidden_layer, num_classes + 1
        )

        # Defect types; label k (k >= 1) maps to defect_types[k - 1]
        self.defect_types = ["Mold", "Insect_Damage", "Discoloration", "Physical_Damage"]

    def forward(self, images, targets=None):
        """Delegate to the wrapped torchvision Mask R-CNN (losses in train mode, detections in eval)."""
        return self.mask_rcnn(images, targets)

    def detect_defects(self, image, confidence_threshold: float = 0.5):
        """Detect defects in a single image or a batch.

        Switches the module to eval mode. A 3-D ``image`` gets a batch
        dimension added. Detections scoring below ``confidence_threshold`` are
        dropped; results from all images in the batch go into one flat list.
        """
        self.eval()
        with torch.no_grad():
            # Prepare image
            if len(image.shape) == 3:
                image = image.unsqueeze(0)

            # Get predictions
            predictions = self.forward(image)

            # Process results
            defects = []
            for pred in predictions:
                boxes = pred['boxes']
                scores = pred['scores']
                masks = pred['masks']
                labels = pred['labels']

                for i in range(len(scores)):
                    if scores[i] >= confidence_threshold:
                        defect = {
                            'bbox': boxes[i].tolist(),
                            'confidence': scores[i].item(),
                            'mask': masks[i].squeeze().tolist(),
                            'defect_type': self.defect_types[labels[i].item() - 1],  # -1 for background
                            # Sum of mask values (torchvision returns per-pixel probabilities)
                            'area': torch.sum(masks[i]).item(),
                            'coordinates': {
                                'x1': boxes[i][0].item(),
                                'y1': boxes[i][1].item(),
                                'x2': boxes[i][2].item(),
                                'y2': boxes[i][3].item()
                            }
                        }
                        defects.append(defect)

            return defects

class DefectDetectorFasterRCNN(nn.Module):
    """Box-only defect detector built on torchvision's MobileNetV3-Large FPN Faster R-CNN.

    Args:
        num_classes: class count *including* background; values below 2 are raised to 2.
        pretrained: forwarded to ``fasterrcnn_mobilenet_v3_large_fpn`` before the
            box predictor is replaced.
        class_names: foreground label names; an empty/None value selects the
            six built-in bean-defect names.
    """

    def __init__(self, num_classes: int = 7, pretrained: bool = True,
                 class_names: Optional[List[str]] = None):
        super().__init__()
        # Background occupies label 0, so at least two classes are needed.
        self.num_classes = max(2, num_classes)
        self.model = fasterrcnn_mobilenet_v3_large_fpn(pretrained=pretrained)

        # Swap the stock box head for one sized to our class count.
        box_in = self.model.roi_heads.box_predictor.cls_score.in_features
        self.model.roi_heads.box_predictor = FastRCNNPredictor(box_in, self.num_classes)

        if not class_names:
            class_names = [
                "insect_damage",
                "nugget",
                "quaker",
                "roasted-beans",
                "shell",
                "under_roast"
            ]
        self.class_names = ["__background__"] + class_names

    def forward(self, images, targets=None):
        """Delegate to the wrapped torchvision Faster R-CNN."""
        return self.model(images, targets)

    def detect(self, image, confidence_threshold: float = 0.5):
        """Return detections scoring at least ``confidence_threshold``.

        Switches to eval mode; a 3-D ``image`` is given a batch dimension.
        Each detection is a dict with 'bbox', 'score' and 'label' (a name from
        ``class_names``); detections for all images are returned in one list.
        """
        self.eval()
        with torch.no_grad():
            batch = image.unsqueeze(0) if len(image.shape) == 3 else image
            found = []
            for pred in self.forward(batch):
                for box, score, label in zip(pred['boxes'], pred['scores'], pred['labels']):
                    if score >= confidence_threshold:
                        found.append({
                            'bbox': box.tolist(),
                            'score': score.item(),
                            'label': self.class_names[label.item()]
                        })
            return found

class ShelfLifeLSTM(nn.Module):
    """Shelf-life regressor: bidirectional LSTM -> self-attention -> MLP.

    ``forward`` takes ``x`` of shape ``(batch, seq_len, input_size)``
    (``batch_first=True``) and returns ``(shelf_life, hidden)``. ``shelf_life``
    has shape ``(batch, 1)``; ``predict_shelf_life`` treats it as days until
    expiration.
    """

    def __init__(self, input_size: int = 64, hidden_size: int = 128, num_layers: int = 2, dropout: float = 0.2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # LSTM layers (inter-layer dropout only makes sense with >1 layer)
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout if num_layers > 1 else 0,
            batch_first=True,
            bidirectional=True
        )

        # Attention over the LSTM outputs; expects (seq_len, batch, embed),
        # hence the transposes in forward(). hidden_size * 2 must divide by 8.
        self.attention = nn.MultiheadAttention(
            embed_dim=hidden_size * 2,  # *2 for bidirectional
            num_heads=8,
            dropout=dropout
        )

        # Prediction head
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size // 2, 1)  # Predict days until expiration
        )

        # Shelf life categories
        self.shelf_life_categories = ["Expired", "Critical", "Warning", "Good", "Excellent"]

    def forward(self, x, hidden=None):
        """Run LSTM + attention + mean pooling; return ``(shelf_life, hidden)``."""
        batch_size = x.size(0)

        # Zero initial state; first dim is num_layers * 2 because the LSTM is bidirectional.
        if hidden is None:
            h0 = torch.zeros(self.num_layers * 2, batch_size, self.hidden_size, device=x.device)
            c0 = torch.zeros(self.num_layers * 2, batch_size, self.hidden_size, device=x.device)
            hidden = (h0, c0)

        # LSTM forward pass
        lstm_out, hidden = self.lstm(x, hidden)

        # Apply attention
        lstm_out = lstm_out.transpose(0, 1)  # (seq_len, batch_size, hidden_size*2)
        attn_out, _ = self.attention(lstm_out, lstm_out, lstm_out)
        attn_out = attn_out.transpose(0, 1)  # (batch_size, seq_len, hidden_size*2)

        # Global average pooling over time
        pooled = torch.mean(attn_out, dim=1)  # (batch_size, hidden_size*2)

        # Predict shelf life
        shelf_life = self.classifier(pooled)

        return shelf_life, hidden

    def predict_shelf_life(self, defect_sequence, confidence_threshold: float = 0.7):
        """Predict and categorise shelf life for ONE sequence.

        Args:
            defect_sequence: list, numpy array or tensor of shape
                ``(seq_len, input_size)`` or ``(1, seq_len, input_size)``.
                It is converted to float32 on the model's device.
            confidence_threshold: categories whose fixed heuristic confidence
                is below this value are reported as 'Uncertain'.

        Returns:
            dict with 'predicted_days' (int, clamped at 0), 'category',
            'confidence' and 'raw_prediction'.

        NOTE(review): 'Excellent' carries confidence 0.6, so with the default
        threshold of 0.7 it is always reported as 'Uncertain' — confirm intent.
        """
        self.eval()
        with torch.no_grad():
            # Accept lists, numpy arrays and tensors, and place the input on the
            # model's device/dtype. Previously only lists were converted, and a
            # GPU model crashed on CPU input.
            device = next(self.parameters()).device
            defect_sequence = torch.as_tensor(defect_sequence, dtype=torch.float32, device=device)

            if defect_sequence.dim() == 2:
                defect_sequence = defect_sequence.unsqueeze(0)  # Add batch dimension

            # Get prediction (.item() requires a single-sequence batch)
            shelf_life_days, _ = self.forward(defect_sequence)
            predicted_days = shelf_life_days.item()

            # Categorize shelf life
            if predicted_days <= 0:
                category = "Expired"
                confidence = 1.0
            elif predicted_days <= 3:
                category = "Critical"
                confidence = 0.9
            elif predicted_days <= 7:
                category = "Warning"
                confidence = 0.8
            elif predicted_days <= 14:
                category = "Good"
                confidence = 0.7
            else:
                category = "Excellent"
                confidence = 0.6

            # Adjust confidence based on threshold
            if confidence < confidence_threshold:
                category = "Uncertain"

            return {
                'predicted_days': max(0, int(predicted_days)),
                'category': category,
                'confidence': confidence,
                'raw_prediction': predicted_days
            }

class BeanScanEnsemble(nn.Module):
    """Runs the classifier, the defect detector and (optionally) the LSTM on one input.

    The three sub-models are registered as child modules, so calling ``.to()``
    or ``state_dict()`` on the ensemble covers all of them.
    """

    def __init__(self, cnn_model: BeanClassifierCNN, 
                 defect_model: DefectDetectorMaskRCNN,
                 lstm_model: ShelfLifeLSTM):
        super().__init__()
        self.cnn_model = cnn_model
        self.defect_model = defect_model
        self.lstm_model = lstm_model

    def forward(self, image, defect_sequence=None):
        """Return a dict with classification, defects, optional shelf life and health score."""
        bean_type = self.cnn_model.predict(image)
        defects = self.defect_model.detect_defects(image)

        results = {
            'bean_classification': bean_type,
            'defect_detection': defects,
        }

        # Shelf life is only predicted when a defect sequence is supplied.
        if defect_sequence is not None:
            results['shelf_life_prediction'] = self.lstm_model.predict_shelf_life(defect_sequence)

        results['health_score'] = self._calculate_health_score(bean_type, defects)
        return results

    def _calculate_health_score(self, bean_type, defects):
        """Combine the top classifier confidence with per-defect penalties.

        Starts from the first prediction's confidence (0.5 if there are no
        predictions). Each defect subtracts a type-specific penalty (unknown
        types add none) plus ``min(0.2, area / 10000)``. The result is clamped
        to [0, 1].
        """
        base_score = bean_type[0]['confidence'] if bean_type else 0.5

        # More severe defect types carry larger penalties.
        type_penalties = {
            'Mold': 0.3,
            'Insect_Damage': 0.25,
            'Discoloration': 0.15,
            'Physical_Damage': 0.1,
        }

        defect_penalty = 0
        for defect in defects or ():
            kind = defect['defect_type']
            if kind in type_penalties:
                defect_penalty += type_penalties[kind]
            # Area-based penalty, capped per defect
            defect_penalty += min(0.2, defect['area'] / 10000)

        health_score = max(0.0, min(1.0, base_score - defect_penalty))

        return {
            'score': health_score,
            'percentage': health_score * 100,
            'grade': self._get_health_grade(health_score),
            'defect_count': len(defects) if defects else 0
        }

    def _get_health_grade(self, score):
        """Map a health score to a letter grade; the first cutoff not exceeding the score wins."""
        grade_cutoffs = (
            (0.9, 'A+'),
            (0.8, 'A'),
            (0.7, 'B+'),
            (0.6, 'B'),
            (0.5, 'C+'),
            (0.4, 'C'),
            (0.3, 'D'),
        )
        for cutoff, grade in grade_cutoffs:
            if score >= cutoff:
                return grade
        return 'F'

# Utility functions

def create_models(device: str = 'cpu'):
    """Instantiate every BeanScan model on *device*.

    Returns:
        dict with keys 'cnn', 'defect_detector', 'lstm' and 'ensemble'; the
        ensemble wraps the same three model instances.
    """
    target = torch.device(device)

    # nn.Module.to() returns the module itself, so build-and-move in one step.
    cnn = BeanClassifierCNN(num_classes=4, pretrained=True).to(target)
    defect_detector = DefectDetectorMaskRCNN(num_classes=4, pretrained=True).to(target)
    lstm = ShelfLifeLSTM(input_size=64, hidden_size=128, num_layers=2).to(target)

    ensemble = BeanScanEnsemble(cnn, defect_detector, lstm).to(target)

    return {
        'cnn': cnn,
        'defect_detector': defect_detector,
        'lstm': lstm,
        'ensemble': ensemble
    }


def save_models(models: Dict, save_dir: str = './models'):
    """Write each model's ``state_dict`` to ``<save_dir>/<name>.pth``.

    The directory is created if it does not exist; one confirmation line is
    printed per saved model.
    """
    import os
    os.makedirs(save_dir, exist_ok=True)

    for name in models:
        target_path = os.path.join(save_dir, f'{name}.pth')
        torch.save(models[name].state_dict(), target_path)
        print(f"✅ Saved {name} model")


def load_models(device: str = 'cpu', model_dir: str = './models'):
    """Build all models via ``create_models`` and load any saved weights.

    For each model name, ``<model_dir>/<name>.pth`` is loaded (mapped onto
    *device*) when it exists; otherwise the freshly initialised weights are kept
    and a warning line is printed.

    Returns:
        The dict produced by ``create_models``.
    """
    # This module has no top-level `import os` (save_models imports it locally
    # as well); without this line every call raised NameError.
    import os

    device = torch.device(device)

    # Create models
    models = create_models(device)

    # Load saved weights if available
    for name, model in models.items():
        model_path = os.path.join(model_dir, f'{name}.pth')
        if os.path.exists(model_path):
            model.load_state_dict(torch.load(model_path, map_location=device))
            print(f"✅ Loaded {name} model from {model_path}")
        else:
            print(f"⚠️  No saved weights found for {name}, using initialized weights")

    return models
'''

train_models_src = r''' 
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import numpy as np
import os
from PIL import Image
import json
from typing import Dict, List, Tuple
import matplotlib.pyplot as plt
from tqdm import tqdm

from custom_models import (
    BeanClassifierCNN, 
    DefectDetectorMaskRCNN, 
    ShelfLifeLSTM,
    BeanScanEnsemble,
    create_models
)

class BeanImageDataset(Dataset):
    """Bean images with bean-type, defect and health-score labels.

    Annotations come from ``<data_dir>/<split>_annotations.json``. If that file
    is missing, 100 synthetic annotations are generated. Any referenced image
    file that does not exist is replaced by a flat brown 224x224 image.
    """

    def __init__(self, data_dir: str, transform=None, split: str = 'train'):
        self.data_dir = data_dir
        self.transform = transform
        self.split = split

        self.annotations = self._load_annotations()

        # Default preprocessing: 224x224 resize plus ImageNet mean/std normalisation.
        if self.transform is None:
            self.transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                   std=[0.229, 0.224, 0.225])
            ])

    def _load_annotations(self):
        """Read the split's JSON annotations, or synthesise placeholders if absent."""
        annotations_file = os.path.join(self.data_dir, f'{self.split}_annotations.json')
        if not os.path.exists(annotations_file):
            return self._create_dummy_annotations()
        with open(annotations_file, 'r') as f:
            return json.load(f)

    def _create_dummy_annotations(self):
        """Generate 100 synthetic annotations; every 10th one carries a 'Mold' defect."""
        bean_types = ["Arabica", "Robusta", "Liberica", "Excelsa"]
        annotations = []

        for i in range(100):
            has_mold = i % 10 == 0
            annotations.append({
                'image_id': f'bean_{i:04d}.jpg',
                'bean_type': bean_types[i % len(bean_types)],
                'defects': [{
                    'type': 'Mold' if has_mold else 'None',
                    'bbox': [10, 10, 100, 100] if has_mold else [0, 0, 0, 0],
                    'mask': np.zeros((224, 224)).tolist() if has_mold else []
                }],
                'health_score': max(0.1, 1.0 - (i % 10) * 0.1)
            })

        return annotations

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        """Return a dict with the transformed image, labels, health score and image id."""
        annotation = self.annotations[idx]

        image_path = os.path.join(self.data_dir, annotation['image_id'])
        if os.path.exists(image_path):
            image = Image.open(image_path).convert('RGB')
        else:
            # Placeholder image in a brown colour
            image = Image.new('RGB', (224, 224), color=(139, 69, 19))

        if self.transform:
            image = self.transform(image)

        return {
            'image': image,
            'bean_type_label': self._get_bean_type_label(annotation['bean_type']),
            'defect_labels': self._get_defect_labels(annotation['defects']),
            'health_score': annotation['health_score'],
            'image_id': annotation['image_id']
        }

    def _get_bean_type_label(self, bean_type: str):
        """Return the class index for *bean_type*; unknown names map to 0 (Arabica)."""
        bean_types = ["Arabica", "Robusta", "Liberica", "Excelsa"]
        try:
            return bean_types.index(bean_type)
        except ValueError:
            return 0

    def _get_defect_labels(self, defects: List):
        """Turn recognised defects into per-defect target dicts (label 0 is background).

        Defects whose 'type' is not one of the four known types are skipped.
        """
        defect_types = ["Mold", "Insect_Damage", "Discoloration", "Physical_Damage"]
        labels = []

        for defect in defects:
            kind = defect['type']
            if kind not in defect_types:
                continue
            labels.append({
                'boxes': torch.tensor([defect['bbox']], dtype=torch.float32),
                'labels': torch.tensor([defect_types.index(kind) + 1], dtype=torch.long),
                'masks': torch.tensor([defect['mask']], dtype=torch.uint8)
            })

        return labels

class ModelTrainer:
    """Train the BeanScan CNN, Mask R-CNN defect detector and shelf-life LSTM.

    Models are built with ``create_models(device)``, and each gets its own Adam
    optimizer (lr ``learning_rate``). Every ``train_*`` method runs
    ``num_epochs`` epochs, appends per-epoch metrics to ``training_history``
    and writes a checkpoint to ``models_dir`` every ``save_interval`` epochs.

    Batches must be dicts with 'image', 'bean_type_label', 'defect_labels'
    (one list of defect dicts per image) and 'health_score'. The
    ``val_loader`` arguments are accepted but not used yet.
    """

    def __init__(self, device: str = 'cpu', models_dir: str = './models'):
        self.device = torch.device(device)
        self.models_dir = models_dir
        os.makedirs(models_dir, exist_ok=True)

        # Initialize models
        self.models = create_models(device)

        # Training parameters (batch_size is not used by this class; the caller
        # builds the DataLoaders)
        self.learning_rate = 0.001
        self.batch_size = 16
        self.num_epochs = 50
        self.save_interval = 10

        # Loss functions. defect_criterion is not used in training: the Mask
        # R-CNN returns its own loss dict in train mode.
        self.cnn_criterion = nn.CrossEntropyLoss()
        self.lstm_criterion = nn.MSELoss()
        self.defect_criterion = self._get_defect_loss()

        # Optimizers
        self.optimizers = self._create_optimizers()

        # Training history
        self.training_history = {
            'cnn_loss': [], 'cnn_acc': [],
            'defect_loss': [], 'defect_map': [],
            'lstm_loss': [], 'lstm_mae': []
        }

    def _get_defect_loss(self):
        """Return reference loss modules for defect detection (not used by train_defect_detector)."""
        return {
            'classification': nn.CrossEntropyLoss(),
            'bbox_regression': nn.SmoothL1Loss(),
            'mask_loss': nn.BCEWithLogitsLoss()
        }

    def _create_optimizers(self):
        """Create one Adam optimizer per trainable model."""
        return {
            'cnn': optim.Adam(self.models['cnn'].parameters(), lr=self.learning_rate),
            'defect_detector': optim.Adam(self.models['defect_detector'].parameters(), lr=self.learning_rate),
            'lstm': optim.Adam(self.models['lstm'].parameters(), lr=self.learning_rate)
        }

    def train_cnn(self, train_loader: DataLoader, val_loader: DataLoader = None):
        """Train the CNN classifier.

        Returns:
            (cnn_loss, cnn_acc): the full history lists from ``training_history``.
        """
        print("🚀 Training CNN Classifier...")

        model = self.models['cnn']
        optimizer = self.optimizers['cnn']
        model.train()

        for epoch in range(self.num_epochs):
            running_loss = 0.0
            correct = 0
            total = 0

            for batch in tqdm(train_loader, desc=f'Epoch {epoch+1}/{self.num_epochs}'):
                images = batch['image'].to(self.device)
                labels = batch['bean_type_label'].to(self.device)

                # Forward / backward pass
                optimizer.zero_grad()
                outputs = model(images)
                loss = self.cnn_criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                # Statistics
                running_loss += loss.item()
                predicted = outputs.detach().argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

            # max(..., 1) avoids ZeroDivisionError on an empty loader
            epoch_loss = running_loss / max(len(train_loader), 1)
            epoch_acc = 100 * correct / max(total, 1)

            self.training_history['cnn_loss'].append(epoch_loss)
            self.training_history['cnn_acc'].append(epoch_acc)

            print(f'Epoch {epoch+1}: Loss={epoch_loss:.4f}, Accuracy={epoch_acc:.2f}%')

            if (epoch + 1) % self.save_interval == 0:
                self._save_model('cnn', epoch + 1)

        print("✅ CNN Training Complete!")
        return self.training_history['cnn_loss'], self.training_history['cnn_acc']

    def train_defect_detector(self, train_loader: DataLoader, val_loader: DataLoader = None):
        """Train the Mask R-CNN defect detector on the sum of its loss dict.

        Returns:
            The full ``defect_loss`` history list.
        """
        print("🚀 Training Defect Detector...")

        model = self.models['defect_detector']
        optimizer = self.optimizers['defect_detector']
        model.train()

        for epoch in range(self.num_epochs):
            running_loss = 0.0

            for batch in tqdm(train_loader, desc=f'Epoch {epoch+1}/{self.num_epochs}'):
                images = batch['image'].to(self.device)
                formatted_targets = self._format_targets(batch['defect_labels'])

                # In train mode the model returns a dict of loss tensors
                optimizer.zero_grad()
                loss_dict = model(images, formatted_targets)
                total_loss = sum(loss_dict.values())

                total_loss.backward()
                optimizer.step()

                running_loss += total_loss.item()

            epoch_loss = running_loss / max(len(train_loader), 1)
            self.training_history['defect_loss'].append(epoch_loss)

            print(f'Epoch {epoch+1}: Loss={epoch_loss:.4f}')

            if (epoch + 1) % self.save_interval == 0:
                self._save_model('defect_detector', epoch + 1)

        print("✅ Defect Detector Training Complete!")
        return self.training_history['defect_loss']

    def train_lstm(self, train_loader: DataLoader, val_loader: DataLoader = None):
        """Train the shelf-life LSTM (placeholder inputs, health_score * 30 targets).

        Inputs are random ``(batch, 10, 64)`` sequences until real time-series
        data exists. The target is each image's ``health_score`` scaled to days.

        Returns:
            (lstm_loss, lstm_mae): the full history lists.
        """
        print("🚀 Training LSTM...")

        model = self.models['lstm']
        optimizer = self.optimizers['lstm']
        model.train()
        seq_length = 10

        for epoch in range(self.num_epochs):
            running_loss = 0.0
            running_mae = 0.0

            for batch in tqdm(train_loader, desc=f'Epoch {epoch+1}/{self.num_epochs}'):
                batch_size = batch['image'].size(0)

                # Dummy defect sequences (placeholder input)
                sequences = torch.randn(batch_size, seq_length, 64, device=self.device)
                # One target per image. The previous code wrapped the whole
                # batch tensor in a list once per sample, which torch.tensor()
                # rejects (ValueError).
                targets = torch.as_tensor(batch['health_score'], dtype=torch.float32).to(self.device) * 30

                optimizer.zero_grad()
                outputs, _ = model(sequences)
                # squeeze(-1) keeps shape (B,) even when B == 1
                predictions = outputs.squeeze(-1)
                loss = self.lstm_criterion(predictions, targets)

                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                running_mae += torch.mean(torch.abs(predictions.detach() - targets)).item()

            num_batches = max(len(train_loader), 1)
            epoch_loss = running_loss / num_batches
            epoch_mae = running_mae / num_batches

            self.training_history['lstm_loss'].append(epoch_loss)
            self.training_history['lstm_mae'].append(epoch_mae)

            print(f'Epoch {epoch+1}: Loss={epoch_loss:.4f}, MAE={epoch_mae:.2f}')

            if (epoch + 1) % self.save_interval == 0:
                self._save_model('lstm', epoch + 1)

        print("✅ LSTM Training Complete!")
        return self.training_history['lstm_loss'], self.training_history['lstm_mae']

    def _format_targets(self, targets: List):
        """Convert per-image defect lists into torchvision detection targets.

        ``targets`` must hold one entry per image: a list of dicts with
        'boxes' (1, 4), 'labels' (1,) and 'masks' (1, H, W). The collate_fn
        therefore has to keep ``defect_labels`` as a plain list. All defects of
        an image are concatenated; previously only the first was kept. Images
        without defects get empty tensors (masks assumed 224x224).
        """
        formatted = []
        for target in targets:
            if target:
                formatted.append({
                    key: torch.cat([defect[key] for defect in target]).to(self.device)
                    for key in ('boxes', 'labels', 'masks')
                })
            else:
                formatted.append({
                    'boxes': torch.empty((0, 4), dtype=torch.float32).to(self.device),
                    'labels': torch.empty((0,), dtype=torch.long).to(self.device),
                    'masks': torch.empty((0, 224, 224), dtype=torch.uint8).to(self.device)
                })
        return formatted

    def _save_model(self, model_name: str, epoch: int):
        """Save a ``<model_name>_epoch_<epoch>.pth`` state_dict checkpoint."""
        save_path = os.path.join(self.models_dir, f'{model_name}_epoch_{epoch}.pth')
        torch.save(self.models[model_name].state_dict(), save_path)
        print(f"💾 Saved {model_name} checkpoint: {save_path}")

    def save_final_models(self):
        """Save ``<name>_final.pth`` for every model, including the ensemble."""
        for name, model in self.models.items():
            save_path = os.path.join(self.models_dir, f'{name}_final.pth')
            torch.save(model.state_dict(), save_path)
            print(f"💾 Saved final {name} model: {save_path}")

    def plot_training_history(self):
        """Plot the loss/accuracy/MAE curves and save them to training_history.png."""
        fig, axes = plt.subplots(2, 3, figsize=(15, 10))

        # CNN metrics
        axes[0, 0].plot(self.training_history['cnn_loss'])
        axes[0, 0].set_title('CNN Loss')
        axes[0, 0].set_xlabel('Epoch')
        axes[0, 0].set_ylabel('Loss')

        axes[0, 1].plot(self.training_history['cnn_acc'])
        axes[0, 1].set_title('CNN Accuracy')
        axes[0, 1].set_xlabel('Epoch')
        axes[0, 1].set_ylabel('Accuracy (%)')

        # Defect detector metrics
        axes[0, 2].plot(self.training_history['defect_loss'])
        axes[0, 2].set_title('Defect Detector Loss')
        axes[0, 2].set_xlabel('Epoch')
        axes[0, 2].set_ylabel('Loss')

        # LSTM metrics
        axes[1, 0].plot(self.training_history['lstm_loss'])
        axes[1, 0].set_title('LSTM Loss')
        axes[1, 0].set_xlabel('Epoch')
        axes[1, 0].set_ylabel('Loss')

        axes[1, 1].plot(self.training_history['lstm_mae'])
        axes[1, 1].set_title('LSTM MAE')
        axes[1, 1].set_xlabel('Epoch')
        axes[1, 1].set_ylabel('MAE')

        # Hide empty subplot
        axes[1, 2].set_visible(False)

        plt.tight_layout()
        plt.savefig(os.path.join(self.models_dir, 'training_history.png'))
        plt.show()


def main():
    """Train the CNN, defect detector and LSTM on ``./data`` and save the results.

    Training runs on CUDA when ``torch.cuda.is_available()`` reports a GPU,
    otherwise on the CPU. Final weights and the training-history plot are
    written via the trainer (into ``trainer.models_dir``).

    Returns:
        int: 0 when every step succeeded, 1 when any step raised (the
        traceback is printed). The ``__main__`` guard uses this as the
        process exit status so a failed run is visible to shells/CI.
    """
    print("🎯 BeanScan Deep Learning Model Training")
    print("=" * 50)

    # Use the GPU when one is present instead of always training on the CPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")
    trainer = ModelTrainer(device=device)

    try:
        # Real train/val splits read from ./data. Built inside the try so a
        # missing or malformed dataset is reported like any other failure.
        train_dataset = BeanImageDataset('./data', split='train')
        val_dataset = BeanImageDataset('./data', split='val')

        train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

        print(f"📊 Dataset sizes: Train={len(train_dataset)}, Val={len(val_dataset)}")

        # Train each model in turn; the returned metrics are not used here.
        trainer.train_cnn(train_loader, val_loader)
        trainer.train_defect_detector(train_loader, val_loader)
        trainer.train_lstm(train_loader, val_loader)

        trainer.save_final_models()
        trainer.plot_training_history()
    except Exception as e:
        # Top-level boundary: report the failure and signal it via the return code.
        print(f"❌ Training failed: {e}")
        import traceback
        traceback.print_exc()
        return 1

    print("\n🎉 All models trained successfully!")
    print("📁 Models saved in:", trainer.models_dir)
    return 0


if __name__ == "__main__":
    # Propagate main()'s status so a failed run exits non-zero.
    raise SystemExit(main())
'''

# Materialize both embedded sources as real module files under base_dir.
for _module_file, _module_src in (
    ('custom_models.py', custom_models_src),
    ('train_models.py', train_models_src),
):
    with open(os.path.join(base_dir, _module_file), 'w', encoding='utf-8') as _fh:
        _fh.write(_module_src)

# Make the backend dir and the ml dir importable; skip entries already present
# so re-running this cell does not grow sys.path.
import sys
for _import_root in ('/kaggle/working/backend', '/kaggle/working/backend/ml'):
    if _import_root not in sys.path:
        sys.path.append(_import_root)

print('Wrote inline modules to', base_dir)


In [None]:
# === Prepare code and paths ===
import os, sys, shutil
os.makedirs(MODELS_DIR, exist_ok=True)

# If a code dataset is attached, its `backend` folder replaces the inline-written
# modules. Validate the source BEFORE deleting anything: removing the inline copy
# and then failing in copytree (no `backend` folder in the dataset) would leave
# no code at all. Also guard against DATASET_CODE being None/empty or pointing
# at the working dir itself (rmtree would delete the copy source).
_working_backend = '/kaggle/working/backend'
_code_backend = os.path.join(DATASET_CODE, 'backend') if DATASET_CODE else None

if (_code_backend and os.path.isdir(_code_backend)
        and os.path.abspath(_code_backend) != os.path.abspath(_working_backend)):
    if os.path.exists(_working_backend):
        shutil.rmtree(_working_backend)
    shutil.copytree(_code_backend, _working_backend)
    print('Copied code from', _code_backend)
elif os.path.isdir(_working_backend):
    print('No usable code dataset at DATASET_CODE; using modules already in', _working_backend)
else:
    print("WARNING: DATASET_CODE not found. Ensure you attached your code dataset.")

# `from backend.ml...` in the next cell resolves against /kaggle/working; the
# backend/ml and backend entries allow importing those modules by bare name.
# Skip entries already present so re-running the cell doesn't grow sys.path.
for _p in ('/kaggle/working', '/kaggle/working/backend/ml', '/kaggle/working/backend'):
    if _p not in sys.path:
        sys.path.append(_p)

print("Python version:")
print(sys.version)
print("Torch version:")
import torch
print(torch.__version__, 'CUDA:', torch.cuda.is_available())


In [None]:
# === Patch training script defaults for Kaggle ===
# Import the trainer and dataset classes, then apply the notebook parameters.
from backend.ml.train_models import ModelTrainer, BeanImageDataset
from torch.utils.data import DataLoader
import torch

# GPU only when it is both requested and actually present.
if USE_GPU and torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print('Using device:', device)

# Both splits are read from the same images dataset root.
train_dataset = BeanImageDataset(DATASET_IMAGES, split='train')
val_dataset = BeanImageDataset(DATASET_IMAGES, split='val')

# Options shared by both loaders; pinned memory only helps host->GPU copies.
loader_opts = dict(batch_size=BATCH_SIZE, num_workers=2, pin_memory=device == 'cuda')
train_loader = DataLoader(train_dataset, shuffle=True, **loader_opts)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_opts)

# Trainer writes its outputs to MODELS_DIR.
# NOTE(review): the attribute overrides below only take effect if ModelTrainer's
# training methods read them — confirm in train_models.py.
trainer = ModelTrainer(device=device, models_dir=MODELS_DIR)
trainer.num_epochs = NUM_EPOCHS
trainer.batch_size = BATCH_SIZE
trainer.save_interval = SAVE_INTERVAL



In [None]:
# === (Optional) Resume from pretrained weights ===
# If you attach a weights dataset OR a Kaggle Model, you can load weights here.
# Examples:
# - DATASET_WEIGHTS = "/kaggle/input/beanscan-weights" and it contains cnn_best.pth
# - WEIGHTS_FILE = "/kaggle/input/cnn-cnn-v1/cnn_best.pth" from the Models panel
import os, glob, re, torch


def _natural_key(path):
    """Sort key that compares digit runs numerically (cnn_epoch_10 after cnn_epoch_5)."""
    # re.split with a capturing group alternates text/digits, so odd indices are
    # always digit runs; this keeps int-vs-int and str-vs-str comparisons aligned.
    parts = re.split(r'(\d+)', os.path.basename(path))
    return [int(p) if i % 2 else p for i, p in enumerate(parts)]


def _load_weights(model, path):
    """Load a state dict from `path` into `model` non-strictly and report key mismatches."""
    state = torch.load(path, map_location=device)
    result = model.load_state_dict(state, strict=False)
    # strict=False silently skips mismatched keys; surface them so a wrong
    # checkpoint does not masquerade as a successful resume.
    if result.missing_keys or result.unexpected_keys:
        print(f"  WARNING: {len(result.missing_keys)} missing / "
              f"{len(result.unexpected_keys)} unexpected keys in {path}")


def try_load(model, pattern):
    """Load the last DATASET_WEIGHTS file matching `pattern` in natural order.

    With the file names this notebook produces, that order prefers
    ``*_final`` over ``*_epoch_N`` (highest N) over ``*_best``.
    Returns True if a file was loaded.
    """
    files = sorted(glob.glob(os.path.join(DATASET_WEIGHTS, pattern)), key=_natural_key)
    if not files:
        return False
    path = files[-1]
    print(f"Loading weights: {path}")
    _load_weights(model, path)
    return True


loaded = False
if WEIGHTS_FILE:
    if os.path.exists(WEIGHTS_FILE):
        print(f"Loading CNN weights file: {WEIGHTS_FILE}")
        _load_weights(trainer.models['cnn'], WEIGHTS_FILE)
        loaded = True
    else:
        print(f"WARNING: WEIGHTS_FILE not found: {WEIGHTS_FILE}")

if (not loaded) and DATASET_WEIGHTS and os.path.exists(DATASET_WEIGHTS):
    loaded = try_load(trainer.models['cnn'], 'cnn*.pth')
    try_load(trainer.models['defect_detector'], 'defect*rcnn*.pth')
    try_load(trainer.models['lstm'], 'lstm*.pth')

if not loaded:
    print("No CNN weights loaded; training CNN from scratch.")


In [None]:
# === Train (CNN only to start) ===
import os, torch

cnn_loss, cnn_acc = trainer.train_cnn(train_loader, val_loader)

# Save ONLY the CNN. trainer.save_final_models() writes every model in
# trainer.models, including defect_detector/lstm weights that were never trained
# here; the resume cell's 'lstm*.pth' pattern would later load those untrained
# weights back in.
cnn_final_path = os.path.join(trainer.models_dir, 'cnn_final.pth')
torch.save(trainer.models['cnn'].state_dict(), cnn_final_path)
print(f"💾 Saved final cnn model: {cnn_final_path}")

# Training history plot (panels for models not trained here stay empty).
trainer.plot_training_history()

print("Models saved to:", MODELS_DIR)


## Notes
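- Place `train_annotations.json` and `val_annotations.json` in the root of your images dataset (`DATASET_IMAGES`) so `BeanImageDataset` can find them.
- Each image must be resolvable by its `image_id` relative to that same root.
- Outputs are written to `MODELS_DIR` (default `/kaggle/working/models`). Files in `/kaggle/working` are only kept if you save a notebook version, so do that to persist the trained weights.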
