# Improved Food Recognition & Weight Estimation Model

This notebook implements a multi-task learning model for food recognition and weight estimation with advanced training utilities:
- Early stopping
- Top-N model saving
- Resume from checkpoint capability
- Adaptive learning rate scheduling
- Training visualization

The model uses EfficientNet-B0 as the backbone with separate heads for classification and regression.

## 1. Import Required Libraries

In [None]:
# Standard imports
import os
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.metrics import accuracy_score, mean_absolute_error

# Torch imports
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim.lr_scheduler import StepLR, CosineAnnealingLR
from torch.optim import Adam
import torchvision.transforms as transforms
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights

try:
    from one_cycle_lr import OneCycleLR as CustomOneCycleLR
except ImportError:
    CustomOneCycleLR = None

# Pick the compute device once for the whole notebook: CUDA when PyTorch
# reports a usable GPU, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

## 2. Model Architecture

This model architecture combines a shared EfficientNet-B0 backbone with separate heads for food classification and weight estimation.

In [None]:
class MultiTaskNet(nn.Module):
    """Shared EfficientNet-B0 trunk feeding a classification head and a regression head.

    ``forward(x)`` returns ``(class_logits, weight_pred)``: the logits come from a
    linear layer with ``num_classes`` outputs, and the weight prediction is a
    single linear output with its trailing dimension squeezed away.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Pretrained ImageNet backbone; its own classifier is replaced by an
        # identity so the pooled features are exposed to our heads.
        trunk = efficientnet_b0(weights=EfficientNet_B0_Weights.DEFAULT)
        feature_dim = trunk.classifier[1].in_features
        trunk.classifier = nn.Identity()
        self.backbone = trunk

        # Task-specific heads on top of the shared features.
        self.classifier = nn.Linear(feature_dim, num_classes)
        self.regressor = nn.Linear(feature_dim, 1)

    def forward(self, x):
        shared = self.backbone(x)
        logits = self.classifier(shared)
        weight = self.regressor(shared).squeeze(1)
        return logits, weight

## 3. Dataset Implementation

The `FoodDataset` class handles loading and preprocessing food images along with their classification labels and weight information.

In [None]:
class FoodDataset(Dataset):
    """Dataset yielding ``(image_tensor, label_idx, weight)`` for each dataframe row.

    Each row must provide ``image_name``, ``label_idx`` and ``weight`` columns.
    Images are looked up in ``image_dir`` in three ways:
      1. by their exact name;
      2. with common extensions swapped in;
      3. by a case-insensitive match on the file stem.
    Missing or unreadable images are replaced by a transformed black placeholder
    so a single bad file does not abort training.
    """

    # Extensions tried (in order) when the listed file name does not exist.
    _IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.JPG', '.JPEG', '.PNG')
    # Sentinel stored in path_cache for images known to be missing/corrupt.
    _PLACEHOLDER = "PLACEHOLDER"

    def __init__(self, dataframe, image_dir, transform=None):
        self.df = dataframe.reset_index(drop=True)
        self.image_dir = image_dir
        self.transform = transform if transform else transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225])
        ])

        # Cache of image_name -> resolved path (or the placeholder sentinel).
        self.path_cache = {}
        # Lazily built lowercase-stem -> filename index used for
        # case-insensitive lookups; built once so a miss does not re-list
        # the directory every time.
        self._stem_index = None

    def __len__(self):
        return len(self.df)

    def _sample(self, image, row):
        """Bundle an image tensor with the row's label and weight tensors."""
        return (image,
                torch.tensor(row['label_idx'], dtype=torch.long),
                torch.tensor(row['weight'], dtype=torch.float32))

    def _try_load_image(self, path, row):
        """Load and transform the image at *path*; return None if that fails."""
        try:
            # Context manager closes the underlying file handle once the
            # pixels have been converted.
            with Image.open(path) as img:
                image = img.convert('RGB')
            return self._sample(self.transform(image), row)
        except Exception:
            return None

    def _create_placeholder(self, row):
        """Create a transformed black image for missing/corrupt files."""
        image = Image.new('RGB', (224, 224), color='black')
        return self._sample(self.transform(image), row)

    def _get_stem_index(self):
        """Return a map of lowercase file stem -> first matching filename in image_dir.

        The directory listing is taken once; files added afterwards are not seen.
        """
        index = getattr(self, '_stem_index', None)
        if index is None:
            index = {}
            try:
                for file in os.listdir(self.image_dir):
                    index.setdefault(os.path.splitext(file)[0].lower(), file)
            except OSError as e:
                print(f"Error during file search: {e}")
            self._stem_index = index
        return index

    def _find_image_path(self, img_name):
        """Resolve *img_name* to an existing path, or return None if not found."""
        stem, _ = os.path.splitext(img_name)

        # 1. Exact name as listed.
        direct = os.path.join(self.image_dir, img_name)
        if os.path.exists(direct):
            return direct

        # 2. Same stem with a common image extension.
        for ext in self._IMAGE_EXTENSIONS:
            candidate = os.path.join(self.image_dir, stem + ext)
            if os.path.exists(candidate):
                return candidate

        # 3. Case-insensitive stem match.
        match = self._get_stem_index().get(stem.lower())
        return os.path.join(self.image_dir, match) if match is not None else None

    def __getitem__(self, idx):
        # Row lookup is deliberately outside the try: an out-of-range index
        # must surface as IndexError (it also terminates plain iteration).
        row = self.df.iloc[idx]
        img_name = row['image_name']

        try:
            cached_path = self.path_cache.get(img_name)
            if cached_path == self._PLACEHOLDER:
                return self._create_placeholder(row)
            if cached_path is not None:
                result = self._try_load_image(cached_path, row)
                if result is not None:
                    return result
                # Path no longer valid, clear from cache
                del self.path_cache[img_name]

            img_path = self._find_image_path(img_name)
            if img_path:
                result = self._try_load_image(img_path, row)
                if result is not None:
                    self.path_cache[img_name] = img_path
                    return result

            # Image wasn't found or couldn't be loaded
            print(f"Warning: Image {img_name} not found or corrupted, using placeholder")
            self.path_cache[img_name] = self._PLACEHOLDER
            return self._create_placeholder(row)

        except Exception as e:
            print(f"Unexpected error for index {idx}: {e}")
            return self._create_placeholder(row)

## 4. Data Preparation Functions

This function prepares the data for training: it loads the CSV file, builds the label-to-index mapping, splits the records into training and validation sets, and wraps each split in a DataLoader.

In [None]:
def prepare_data(csv_path, images_dir, batch_size=16, num_workers=0, seed=42):
    """Load the label CSV and build train/validation DataLoaders.

    Args:
        csv_path: Semicolon-separated CSV with ``labels``, ``image_name`` and
            ``weight`` columns (the latter two are read by ``FoodDataset``).
        images_dir: Directory containing the images.
        batch_size: Batch size for both loaders.
        num_workers: DataLoader worker processes.
        seed: Seed for the 80/20 train/val split. A fixed seed keeps the split
            identical across runs, so resuming from a checkpoint does not
            validate on images that were trained on. ``None`` gives a fresh
            random split each call.

    Returns:
        (train_dataloader, val_dataloader, label_to_idx)
    """
    from torch.utils.data import Subset

    df = pd.read_csv(csv_path, sep=';', quotechar='"')
    print(f"Successfully loaded {len(df)} records from {csv_path}")

    # Label -> index in order of first appearance in the CSV.
    label_to_idx = {label: idx for idx, label in enumerate(df['labels'].unique())}
    df['label_idx'] = df['labels'].map(label_to_idx)

    print(f"Number of classes: {len(label_to_idx)}")
    print(df.head())

    train_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    val_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    # One dataset object per split. Subsets of a single shared dataset would
    # share one `transform` attribute, and setting the validation transform
    # would silently disable training augmentation.
    train_base = FoodDataset(df, images_dir, transform=train_transform)
    val_base = FoodDataset(df, images_dir, transform=val_transform)

    # 80/20 split over row indices (same slicing scheme as random_split).
    train_size = int(0.8 * len(df))
    val_size = len(df) - train_size
    generator = torch.Generator().manual_seed(seed) if seed is not None else None
    permutation = torch.randperm(len(df), generator=generator).tolist()
    train_dataset = Subset(train_base, permutation[:train_size])
    val_dataset = Subset(val_base, permutation[train_size:])

    # Pinned memory only helps host->GPU copies; on CPU it just warns.
    pin = torch.cuda.is_available()

    train_dataloader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=pin
    )

    val_dataloader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=pin
    )

    print(f"Training on {train_size} samples, validating on {val_size} samples")

    return train_dataloader, val_dataloader, label_to_idx

## 5. Learning Rate Scheduler Setup

In [None]:
def get_scheduler(optimizer, lr_strategy, num_epochs, steps_per_epoch, lr, min_lr=1e-6,
                  step_size=5, gamma=0.75):
    """Build the learning-rate scheduler for *lr_strategy*.

    Args:
        optimizer: Optimizer whose LR is scheduled.
        lr_strategy: 'one_cycle' (stepped per batch by the caller), 'cosine' or
            'step' (stepped per epoch), or None for no scheduler.
        num_epochs: Total epochs; the cosine period and one-cycle length.
        steps_per_epoch: Batches per epoch (one-cycle only).
        lr: Peak LR for one-cycle.
        min_lr: Floor of the cosine schedule.
        step_size: Epochs between decays for 'step'.
        gamma: Multiplicative decay for 'step'.

    Returns:
        A scheduler, or None when ``lr_strategy`` is None or unrecognized. An
        unrecognized name prints a warning, so a typo does not silently train
        at a constant LR.
    """
    if lr_strategy == 'one_cycle':
        if CustomOneCycleLR is not None:
            return CustomOneCycleLR(
                optimizer,
                max_lr=lr,
                total_epochs=num_epochs,
                steps_per_epoch=steps_per_epoch
            )
        return torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=lr,
            steps_per_epoch=steps_per_epoch,
            epochs=num_epochs
        )
    if lr_strategy == 'cosine':
        return CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=min_lr)
    if lr_strategy == 'step':
        return StepLR(optimizer, step_size=step_size, gamma=gamma)
    if lr_strategy is not None:
        print(f"Warning: unknown lr_strategy {lr_strategy!r}; training with a constant learning rate")
    return None

## 6. Training and Validation Functions

In [None]:
def _train_one_epoch(model, dataloader, optimizer, scheduler, device, criterion_class, criterion_weight, lr_strategy):
    model.train()
    running_loss = 0.0
    for images, labels, weights in dataloader:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        weights = weights.to(device, non_blocking=True)
        optimizer.zero_grad()
        outputs_class, outputs_weight = model(images)
        loss_class = criterion_class(outputs_class, labels)
        loss_weight = criterion_weight(outputs_weight, weights)
        total_loss = 0.7 * loss_class + 0.3 * loss_weight
        total_loss.backward()
        optimizer.step()
        if scheduler is not None and lr_strategy == 'one_cycle':
            scheduler.step()
        running_loss += total_loss.item()
    return running_loss / len(dataloader)

def _validate_one_epoch(model, dataloader, device, criterion_class, criterion_weight):
    model.eval()
    running_loss = 0.0
    all_preds, all_labels = [], []
    all_weight_preds, all_weight_true = [], []
    with torch.no_grad():
        for images, labels, weights in dataloader:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            weights = weights.to(device, non_blocking=True)
            outputs_class, outputs_weight = model(images)
            loss_class = criterion_class(outputs_class, labels)
            loss_weight = criterion_weight(outputs_weight, weights)
            total_loss = loss_class + loss_weight
            running_loss += total_loss.item()
            _, predicted = torch.max(outputs_class, 1)
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_weight_preds.extend(outputs_weight.cpu().numpy())
            all_weight_true.extend(weights.cpu().numpy())
    avg_loss = running_loss / len(dataloader)
    accuracy = accuracy_score(all_labels, all_preds)
    mae = mean_absolute_error(all_weight_true, all_weight_preds)
    return avg_loss, accuracy, mae

## 7. Model Saving Functions

In [None]:
def save_model(model, optimizer, avg_val_loss, val_accuracy, val_mae, label_to_idx, model_path, epoch, score):
    """Write a training checkpoint to *model_path* and return that path.

    Metric values are stored as builtin ``float``/``int`` rather than numpy
    scalars, so the checkpoint loads with ``torch.load(..., weights_only=True)``
    (the default in recent PyTorch releases).
    """
    def _builtin(value):
        # numpy scalars expose .item(); builtin values are returned unchanged.
        item = getattr(value, 'item', None)
        return item() if callable(item) else value

    labels = None
    if label_to_idx is not None:
        labels = {_builtin(k): int(v) for k, v in label_to_idx.items()}

    torch.save({
        'epoch': int(epoch),
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'val_loss': float(avg_val_loss),
        'val_accuracy': float(val_accuracy),
        'val_mae': float(val_mae),
        'composite_score': float(score),
        'label_to_idx': labels
    }, model_path)
    return model_path

def save_best_model(model, optimizer, avg_val_loss, val_accuracy, val_mae, label_to_idx, model_save_dir, epoch):
    """
    Overwrite ``best_model.pth`` when this epoch beats the checkpoint on disk.

    Models are ranked by the composite score ``accuracy - MAE / 100``. The MAE
    is divided by 100 to bring it roughly into the accuracy's 0-1 range.
    A secondary rule also saves the model when its accuracy exceeds the stored
    accuracy by more than 5% (relative), even if the composite score is not
    higher. The stored checkpoint may come from an earlier run in the same
    directory. If no checkpoint exists, or it cannot be loaded, the current
    model is saved.

    Returns:
        (saved, current_score): whether the file was written, and this
        epoch's composite score.
    """
    best_path = os.path.join(model_save_dir, "best_model.pth")
    current_score = val_accuracy - (val_mae / 100)
    stored_score = -float('inf')

    if not os.path.exists(best_path):
        should_save = True
        print("No previous model found, saving first model")
    else:
        try:
            stored = torch.load(best_path, map_location='cpu')
            stored_acc = stored.get('val_accuracy', 0)
            stored_loss = stored.get('val_loss', float('inf'))
            stored_mae = stored.get('val_mae', float('inf'))
            if epoch == 0:
                print(f"Found existing model with val_loss={stored_loss:.4f}, val_accuracy={stored_acc:.4f}, val_mae={stored_mae:.2f}g")

            stored_score = stored_acc - (stored_mae / 100)
            if current_score > stored_score:
                should_save = True
                print(f"Better combined score: {current_score:.4f} vs {stored_score:.4f}")
            else:
                should_save = False
                print(f"Worse combined score: {current_score:.4f} vs {stored_score:.4f} - not saving")
                # Secondary rule: a >5% relative accuracy gain wins on its own.
                # Note this is not a "scores are close" check.
                if val_accuracy > stored_acc * 1.05:
                    should_save = True
                    print(f"Significantly better accuracy: {val_accuracy:.4f} vs {stored_acc:.4f}")
        except Exception as e:
            print(f"Warning: Could not load previous checkpoint for comparison: {e}")
            should_save = True

    if should_save:
        save_model(model, optimizer, avg_val_loss, val_accuracy, val_mae, label_to_idx, best_path, epoch, current_score)
        print(f"Model saved to {best_path}")
    else:
        print(f"Model not saved (composite score {current_score:.4f} <= previous {stored_score:.4f})")

    return should_save, current_score

def manage_topn_models(model, optimizer, avg_val_loss, val_accuracy, val_mae, label_to_idx, 
                      model_save_dir, epoch, current_score, top_n_models, save_top_n):
    """
    Keep the ``save_top_n`` best checkpoints (``model_epoch_<n>.pth``) on disk.

    ``top_n_models`` is a list of ``(score, epoch, path)`` tuples sorted
    best-first. It is updated in place and also returned. When this epoch's
    score qualifies, its checkpoint is written. If that pushes the list past
    ``save_top_n``, the lowest-scoring file is deleted. A non-positive
    ``save_top_n`` disables top-N saving.
    """
    if save_top_n <= 0:
        # Nothing to keep. Also avoids indexing an empty list below.
        return top_n_models

    model_path = os.path.join(model_save_dir, f"model_epoch_{epoch+1}.pth")

    # Qualify if there is still room, or if this beats the current worst.
    if len(top_n_models) < save_top_n or current_score > top_n_models[-1][0]:
        save_model(model, optimizer, avg_val_loss, val_accuracy, val_mae, label_to_idx, model_path, epoch, current_score)

        top_n_models.append((current_score, epoch, model_path))
        top_n_models.sort(reverse=True)  # highest score first

        # Evict the worst entry once we exceed N.
        if len(top_n_models) > save_top_n:
            _, _, old_path = top_n_models.pop()
            if os.path.exists(old_path):
                try:
                    os.remove(old_path)
                    print(f"Removed model with lower score: {old_path}")
                except Exception as e:
                    print(f"Error removing old model file: {e}")

        print(f"Saved as a top-{min(len(top_n_models), save_top_n)} model with score {current_score:.4f}")

    return top_n_models

## 8. Early Stopping and Checkpoint Management

In [None]:
def handle_early_stopping(current_score, best_score, best_epoch, no_improve_count, patience, optimizer, lr_strategy, scheduler, epoch):
    """
    Update early-stopping state and halve the LR when training stagnates.

    While the score has not improved for at least ``patience // 2`` epochs,
    and the strategy is an epoch-based scheduler (not 'one_cycle', and a
    scheduler exists), every param group's LR is halved. This happens on each
    such epoch. Training should stop once ``no_improve_count >= patience``.

    Returns: best_score, best_epoch, no_improve_count, should_stop
    """
    should_stop = False

    if current_score > best_score:
        best_score = current_score
        best_epoch = epoch
        no_improve_count = 0
    else:
        no_improve_count += 1

    # Adaptive LR reduction for non-one-cycle strategies
    if no_improve_count > 0 and lr_strategy != 'one_cycle' and scheduler is not None:
        if no_improve_count >= patience // 2:
            # Halve each group's own LR so per-group LRs keep their ratios.
            for param_group in optimizer.param_groups:
                param_group['lr'] = param_group['lr'] * 0.5
            print(f"Reducing learning rate to {optimizer.param_groups[0]['lr']} due to stagnation")

    print(f"Current best: epoch {best_epoch+1} with score {best_score:.4f}, no improvement for {no_improve_count} epochs")

    if no_improve_count >= patience:
        print(f"Early stopping at epoch {epoch+1} after {patience} epochs without improvement")
        print(f"Best performance was at epoch {best_epoch+1}")
        should_stop = True

    return best_score, best_epoch, no_improve_count, should_stop

def resume_from_checkpoint(model, optimizer, model_save_dir, device):
    """
    Restore model and optimizer state from ``<model_save_dir>/best_model.pth``.

    Returns:
        The 0-based epoch to continue from (the saved epoch + 1), or 0 when the
        file is missing or cannot be loaded.
    """
    checkpoint_path = os.path.join(model_save_dir, "best_model.pth")
    if not os.path.exists(checkpoint_path):
        print("No valid checkpoint found. Starting from epoch 1.")
        return 0

    try:
        state = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(state['model_state_dict'])
        optimizer.load_state_dict(state['optimizer_state_dict'])
        next_epoch = state.get('epoch', 0) + 1  # continue after the saved epoch
        print(f"Resuming from epoch {next_epoch+1}, checkpoint accuracy: {state.get('val_accuracy', 0):.4f}")
    except Exception as e:
        print(f"Failed to load checkpoint for resuming: {e}")
        print("No valid checkpoint found. Starting from epoch 1.")
        return 0

    return next_epoch

## 9. Training Visualization

In [None]:
def plot_training_history(logs, plots_dir):
    """Save loss, accuracy/MAE and composite-score plots as PNGs in *plots_dir*.

    *logs* must contain equal-length lists under "epochs", "train_loss",
    "val_loss", "val_accuracy", "weight_mae" and "scores". Files written:
    loss_history.png, accuracy_mae_history.png, score_history.png.
    """
    epochs = logs["epochs"]

    # Loss curves
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(epochs, logs["train_loss"], label="Train Loss")
    ax.plot(epochs, logs["val_loss"], label="Validation Loss")
    ax.set_title("Training and Validation Loss")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.legend()
    ax.grid(True)
    fig.savefig(os.path.join(plots_dir, "loss_history.png"))
    plt.close(fig)

    # Accuracy (left axis) and MAE (right axis) over the same epochs
    fig, ax1 = plt.subplots(figsize=(10, 6))
    acc_color, mae_color = 'tab:blue', 'tab:red'
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Accuracy', color=acc_color)
    acc_line, = ax1.plot(epochs, logs["val_accuracy"], color=acc_color, label="Accuracy")
    ax1.tick_params(axis='y', labelcolor=acc_color)
    ax1.grid(True)  # grid follows the primary (accuracy) axis

    ax2 = ax1.twinx()
    ax2.set_ylabel('MAE (g)', color=mae_color)
    mae_line, = ax2.plot(epochs, logs["weight_mae"], color=mae_color, label="MAE")
    ax2.tick_params(axis='y', labelcolor=mae_color)

    ax1.set_title("Validation Accuracy and Weight MAE")
    # One legend covering both axes' lines.
    ax1.legend(handles=[acc_line, mae_line], loc='best')
    # Lay out after the title exists so it is not clipped.
    fig.tight_layout()
    fig.savefig(os.path.join(plots_dir, "accuracy_mae_history.png"))
    plt.close(fig)

    # Composite score
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(epochs, logs["scores"], marker='o')
    ax.set_title("Model Performance Score")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Score (Accuracy - MAE/100)")
    ax.grid(True)
    fig.savefig(os.path.join(plots_dir, "score_history.png"))
    plt.close(fig)

## 10. Main Training Function

In [None]:
def train_model(model, train_dataloader, val_dataloader, device, num_epochs, model_save_dir, 
              lr_strategy='one_cycle', lr=1e-4, patience=5, save_top_n=3, resume=False):
    """
    Train the model with early stopping, top-N model saving and checkpoint resuming.

    Each epoch runs one training pass and one validation pass, then:
      - logs the metrics;
      - updates ``best_model.pth`` and the top-N ``model_epoch_*.pth`` files;
      - applies the early-stopping / LR-reduction rules.
    Afterwards ``models_summary.json`` and training plots are written to
    ``model_save_dir``.

    Args:
        model: Model to train
        train_dataloader: DataLoader for training data
        val_dataloader: DataLoader for validation data
        device: Device to use for training
        num_epochs: Number of epochs to train for
        model_save_dir: Directory to save the model
        lr_strategy: Learning rate strategy ('one_cycle', 'cosine', 'step')
        lr: Learning rate. On resume this is the base LR the schedule restarts
            from; the schedule is then advanced to the resumed epoch.
        patience: Number of epochs to wait for improvement before early stopping
        save_top_n: Number of best models to save
        resume: Whether to resume training from a checkpoint

    Returns:
        training_logs: dict of per-epoch lists under "epochs", "train_loss",
        "val_loss", "val_accuracy", "weight_mae" and "scores".
    """
    criterion_class = nn.CrossEntropyLoss()
    criterion_weight = nn.MSELoss()
    optimizer = Adam(model.parameters(), lr=lr, weight_decay=1e-5)

    start_epoch = 0
    if resume:
        start_epoch = resume_from_checkpoint(model, optimizer, model_save_dir, device)
        if start_epoch > 0:
            # The restored optimizer carries the LR (and any 'initial_lr') from
            # when the checkpoint was taken. Reset so the new scheduler is
            # built from the base LR and can be fast-forwarded consistently.
            for group in optimizer.param_groups:
                group['lr'] = lr
                group.pop('initial_lr', None)

    steps_per_epoch = len(train_dataloader)
    scheduler = get_scheduler(optimizer, lr_strategy, num_epochs, steps_per_epoch, lr)
    if scheduler is not None and 0 < start_epoch < num_epochs:
        _fast_forward_scheduler(scheduler, lr_strategy, start_epoch, steps_per_epoch)

    print(f"Using device: {device}")
    print(f"Optimizer: Adam, lr={lr}")
    print(f"Scheduler: {lr_strategy}")
    print(f"Early stopping patience: {patience}, save top {save_top_n} models")
    print(f"Starting training from epoch {start_epoch+1} to {num_epochs}...")

    best_score = float('-inf')
    best_epoch = 0
    no_improve_count = 0
    top_n_models = []  # (score, epoch, path) tuples, best first

    training_logs = {
        "epochs": [],
        "train_loss": [],
        "val_loss": [],
        "val_accuracy": [],
        "weight_mae": [],
        "scores": []
    }

    # The label mapping cannot change between epochs; compute it once.
    label_to_idx = _label_mapping_from_loader(train_dataloader)

    # 1-based number of the last epoch run; stays at start_epoch if the loop
    # body never executes (e.g. resuming at or past num_epochs).
    last_epoch_number = start_epoch

    for epoch in range(start_epoch, num_epochs):
        last_epoch_number = epoch + 1

        avg_train_loss = _train_one_epoch(
            model, train_dataloader, optimizer, scheduler,
            device, criterion_class, criterion_weight, lr_strategy
        )

        avg_val_loss, val_accuracy, val_mae = _validate_one_epoch(
            model, val_dataloader, device, criterion_class, criterion_weight
        )

        # Composite score: higher accuracy and lower MAE are both better.
        current_score = val_accuracy - (val_mae / 100)

        print(f"Epoch {epoch+1}/{num_epochs}: Train Loss = {avg_train_loss:.4f}, "
              f"Val Loss = {avg_val_loss:.4f}, "
              f"Val Acc = {val_accuracy:.4f}, Weight MAE = {val_mae:.2f}g, "
              f"Score = {current_score:.4f}")

        training_logs["epochs"].append(epoch + 1)
        training_logs["train_loss"].append(avg_train_loss)
        training_logs["val_loss"].append(avg_val_loss)
        training_logs["val_accuracy"].append(val_accuracy)
        training_logs["weight_mae"].append(val_mae)
        training_logs["scores"].append(current_score)

        save_best_model(
            model, optimizer, avg_val_loss, val_accuracy, val_mae,
            label_to_idx, model_save_dir, epoch
        )

        top_n_models = manage_topn_models(
            model, optimizer, avg_val_loss, val_accuracy, val_mae,
            label_to_idx, model_save_dir, epoch, current_score,
            top_n_models, save_top_n
        )

        best_score, best_epoch, no_improve_count, should_stop = handle_early_stopping(
            current_score, best_score, best_epoch, no_improve_count,
            patience, optimizer, lr_strategy, scheduler, epoch
        )

        if should_stop:
            break

        # Epoch-based schedulers step once per epoch; one-cycle steps per batch.
        if scheduler is not None and lr_strategy != 'one_cycle':
            scheduler.step()

    print("\nTop models summary:")
    for i, (score, ep, path) in enumerate(top_n_models):
        print(f"{i+1}. Epoch {ep+1}: score={score:.4f}, saved at {path}")

    models_summary = {
        "best_model": {
            "epoch": best_epoch + 1,
            "score": float(best_score)
        },
        "top_models": [
            {"epoch": e+1, "score": float(s), "path": p} for s, e, p in top_n_models
        ],
        "training_completed": True,
        "early_stopped": no_improve_count >= patience,
        "total_epochs_trained": last_epoch_number
    }

    with open(os.path.join(model_save_dir, "models_summary.json"), "w") as f:
        json.dump(models_summary, f, indent=4)

    plots_dir = os.path.join(model_save_dir, 'plots')
    os.makedirs(plots_dir, exist_ok=True)
    plot_training_history(training_logs, plots_dir)

    return training_logs


def _label_mapping_from_loader(dataloader):
    """Recover the label -> index mapping from the dataframe behind *dataloader*.

    Unwraps a ``Subset`` if present. Prefers the dataframe's own ``label_idx``
    column, so the stored mapping is exactly the one used for training.
    """
    dataset = dataloader.dataset
    base = getattr(dataset, 'dataset', dataset)
    label_df = getattr(base, 'df', None)
    if label_df is None:
        return {}
    if 'label_idx' in label_df.columns:
        return dict(zip(label_df['labels'], label_df['label_idx']))
    return {label: idx for idx, label in enumerate(label_df['labels'].unique())}


def _fast_forward_scheduler(scheduler, lr_strategy, start_epoch, steps_per_epoch):
    """Advance *scheduler* to where it would be after *start_epoch* completed epochs."""
    import warnings

    n_steps = start_epoch * steps_per_epoch if lr_strategy == 'one_cycle' else start_epoch
    with warnings.catch_warnings():
        # Stepping before any optimizer.step() is intentional here.
        warnings.simplefilter("ignore", UserWarning)
        for _ in range(n_steps):
            scheduler.step()

## 11. Training Configuration (Edit as Needed)

In [None]:
# Set your training parameters here
# You can adjust these parameters as needed for your specific setup

# Colab/Kaggle specific paths
# For Colab, you might need to adapt these based on where you upload your data
# For Kaggle, these should match your notebook environment

# When running locally, you might use something like:
#csv_path = "/Users/chalkiasantonios/Desktop/master-thesis/csvfiles/labels.csv"
#images_dir = "/Users/chalkiasantonios/Desktop/master-thesis/image_set_2"
#model_save_dir = "/Users/chalkiasantonios/Desktop/master-thesis/models"

# For Google Colab, you might use:
import os

# `google.colab` only exists inside Colab. Importing it unconditionally raises
# ModuleNotFoundError on Kaggle and locally, so the import is optional.
try:
    from google.colab import drive
except ImportError:
    drive = None

# Mount Google Drive (uncomment if needed; Colab only)
# drive.mount('/content/drive')

# Define paths (edit as needed for your environment)
# Kaggle paths
USE_KAGGLE = True  # Set to False if using Colab or local environment

if USE_KAGGLE:
    master_thesis_dir = "/kaggle/working"
    data_path = "/kaggle/input/data-set-labeld-weights"
    csv_path = os.path.join(data_path, "labels.csv")
    images_dir = os.path.join(data_path, "image_set_2")
else:
    # Local or Colab paths
    master_thesis_dir = "/Users/chalkiasantonios/Desktop/master-thesis"
    csv_path = os.path.join(master_thesis_dir, "csvfiles", "labels.csv")
    images_dir = os.path.join(master_thesis_dir, "image_set_2")

model_save_dir = os.path.join(master_thesis_dir, "models")
os.makedirs(model_save_dir, exist_ok=True)

# Training hyperparameters
num_epochs = 20
batch_size = 16
lr = 1e-4
lr_strategy = 'cosine'  # 'one_cycle', 'cosine', or 'step'
patience = 7
save_top_n = 3
resume = False
num_workers = 2  # Adjust based on your CPU cores

## 12. Execute Training

In [None]:
# Echo the run configuration before any expensive work starts.
print("\n".join([
    f"CSV path: {csv_path}",
    f"Images directory: {images_dir}",
    f"Model save directory: {model_save_dir}",
    f"Training parameters: epochs={num_epochs}, batch_size={batch_size}, lr={lr}",
    f"LR strategy: {lr_strategy}, patience: {patience}, save top {save_top_n} models",
]))

# Data loaders plus the label mapping that fixes the classifier width.
train_dataloader, val_dataloader, label_to_idx = prepare_data(
    csv_path, images_dir, batch_size=batch_size, num_workers=num_workers
)

num_classes = len(label_to_idx)
model = MultiTaskNet(num_classes).to(device)

training_logs = train_model(
    model, train_dataloader, val_dataloader, device, num_epochs, model_save_dir,
    lr_strategy=lr_strategy,
    lr=lr,
    patience=patience,
    save_top_n=save_top_n,
    resume=resume,
)

print("Training completed!")

## 13. Evaluation and Visualization of Results

In [None]:
# Visualize training results if available.
# IPython's Image class is imported under an alias: importing it as `Image`
# would shadow PIL's Image, which FoodDataset uses. Any later training run in
# this kernel would then fail to open every file and fall back to
# placeholders.
plots_dir = os.path.join(model_save_dir, 'plots')
if os.path.exists(plots_dir):
    from IPython.display import Image as IPyImage, display

    for caption, filename in (
        ("Training Loss History:", "loss_history.png"),
        ("\nAccuracy and MAE:", "accuracy_mae_history.png"),
        ("\nComposite Score:", "score_history.png"),
    ):
        plot_path = os.path.join(plots_dir, filename)
        print(caption)
        if os.path.exists(plot_path):
            display(IPyImage(plot_path))
        else:
            print(f"Plot not found: {plot_path}")
else:
    print("No training plots found. Run the training cell first.")

## 14. Load and Test the Best Model

In [None]:
def load_best_model(model_save_dir, device, num_classes):
    """Rebuild ``MultiTaskNet`` from ``<model_save_dir>/best_model.pth`` in eval mode.

    Returns:
        (model, label_to_idx) on success, where label_to_idx defaults to {}
        if the checkpoint lacks it. Returns (None, None) when the file is
        missing or loading fails.
    """
    checkpoint_path = os.path.join(model_save_dir, "best_model.pth")
    if not os.path.exists(checkpoint_path):
        print(f"No model found at {checkpoint_path}")
        return None, None

    net = MultiTaskNet(num_classes)

    try:
        ckpt = torch.load(checkpoint_path, map_location=device)
        net.load_state_dict(ckpt['model_state_dict'])
        net.to(device)
        net.eval()

        print(f"Loaded model from epoch {ckpt.get('epoch', 0) + 1}")
        print("Validation metrics:")
        print(f"- Accuracy: {ckpt.get('val_accuracy', 0):.4f}")
        print(f"- MAE: {ckpt.get('val_mae', 0):.2f}g")
        print(f"- Score: {ckpt.get('composite_score', 0):.4f}")
    except Exception as e:
        print(f"Error loading model: {e}")
        return None, None

    return net, ckpt.get('label_to_idx', {})

# Load the best checkpoint for testing and report whether it succeeded.
best_model, label_map = load_best_model(model_save_dir, device, num_classes)

print("\nBest model loaded successfully!" if best_model is not None
      else "\nFailed to load model. Make sure training has completed.")