In [None]:
import os
import pandas as pd
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import time
import argparse
import pickle
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torchvision import transforms, models
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_absolute_error, r2_score, mean_squared_error

def parse_args(argv=None):
    """Build the command-line interface and parse the training options.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]`` exactly as before. Pass an explicit list when calling
            from a notebook, where ``sys.argv`` holds the kernel's own flags
            (e.g. ``parse_args(['--model_name', 'DeepConvNet'])``).

    Returns:
        argparse.Namespace with model_name, base_dir, output_dir, epochs,
        batch_size, lr, num_workers, no_plots and seed.
    """
    parser = argparse.ArgumentParser(description="Train nutrition estimation models.")
    parser.add_argument('--model_name', type=str, required=True,
                        choices=['SimpleConvNet', 'DeepConvNet', 'MobileNetLike', 'ResNetFromScratch', 'ResNetPretrained'],
                        help="Name of the model to train.")
    # Let the environment supply the dataset root so the personal absolute path
    # below is only a last-resort fallback; --base_dir still overrides both.
    parser.add_argument('--base_dir', type=str,
                        default=os.environ.get(
                            "NUTRITION5K_BASE_DIR",
                            "/users/eleves-b/2023/georgii.kuznetsov/CNN_nutrition/nutrition5k"),
                        help="Base directory for the dataset (defaults to $NUTRITION5K_BASE_DIR if set).")
    parser.add_argument('--output_dir', type=str, default=".",
                        help="Directory to save model checkpoints and history.")
    parser.add_argument('--epochs', type=int, default=100, help="Number of training epochs.")
    parser.add_argument('--batch_size', type=int, default=32, help="Batch size for training.")
    parser.add_argument('--lr', type=float, default=1e-3, help="Learning rate.")
    parser.add_argument('--num_workers', type=int, default=0, help="Number of workers for DataLoader.")
    parser.add_argument('--no_plots', action='store_true', help="Disable plotting (for headless execution).")
    parser.add_argument('--seed', type=int, default=42, help="Random seed for reproducibility.")

    # NOTE: there is no separate --pretrained flag; per the original design note,
    # 'ResNetPretrained' is meant to select pretrained weights and
    # 'ResNetFromScratch' random init (the mapping lives where models are built).

    return parser.parse_args(argv)

# Module-level constants used by the dataset, loss and plotting code below.
# File name of the RGB frame inside each dish's imagery folder.
RGB_IMAGE_FILENAME = "rgb.png"
# Regression targets, in model-output order (nutrient amounts per 100 g of dish).
TARGET_COLUMNS = [
    'calories_per_100g',
    'fat_per_100g',
    'carbs_per_100g',
    'protein_per_100g',
]

In [None]:
# %%
# This cell would typically be run after parsing args in a script context
# For notebook execution, you might define args manually for testing:
# class Args:
#     model_name = 'DeepConvNet'
#     base_dir = "/users/eleves-b/2023/georgii.kuznetsov/CNN_nutrition/nutrition5k"
#     output_dir = "."
#     epochs = 10 # For quick test
#     batch_size = 32
#     lr = 1e-3
#     num_workers = 0
#     no_plots = False
#     seed = 42
# args = Args() # Uncomment for notebook testing

# If running as a script, args will be populated by parse_args()
# args = parse_args() # This line will be in the __main__ block

# Run on CUDA when PyTorch can see a GPU, otherwise fall back to the CPU.
# Apple-silicon users who prefer MPS can instead use, e.g.:
#   torch.device("mps" if torch.backends.mps.is_available() else "cpu")
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# To be initialized in main script flow after args are parsed
# LOCAL_BASE_DIR = args.base_dir 
# IMAGERY_DIR = os.path.join(LOCAL_BASE_DIR, "imagery/realsense_overhead")
# METADATA_FILE_CAFE1 = os.path.join(LOCAL_BASE_DIR, "metadata/dish_metadata_cafe1.csv")
# METADATA_FILE_CAFE2 = os.path.join(LOCAL_BASE_DIR, "metadata/dish_metadata_cafe2.csv")

# BATCH_SIZE = args.batch_size
# LEARNING_RATE = args.lr
# NUM_EPOCHS = args.epochs

# Ensure output directory exists
# if not os.path.exists(args.output_dir):
#    os.makedirs(args.output_dir)

## Simple convolutional network

In [None]:
class SimpleConvNet(nn.Module):
    """Simple CNN trained from scratch.

    Four conv(3x3)-BN-ReLU-maxpool(2x2) stages (3 -> 32 -> 64 -> 128 -> 256
    channels) followed by a three-layer MLP head with dropout.

    Args:
        num_outputs: Number of regression targets predicted per image.
        input_size: Side length of the square input images. The head's flattened
            feature size depends on it, so it must match the resolution produced
            by the data transforms (224 by default, as before).

    Raises:
        ValueError: if ``input_size`` is too small to survive four 2x poolings.
    """
    def __init__(self, num_outputs=4, input_size=224):
        super(SimpleConvNet, self).__init__()

        # Convolutional layers: padding=1 keeps the spatial size, only pooling shrinks it.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(256)

        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.5)

        # Each of the four max-pools floors the side length by 2, so the final
        # feature map is input_size // 16 per side (224 -> 112 -> 56 -> 28 -> 14).
        feature_side = input_size // 16
        if feature_side < 1:
            raise ValueError(f"input_size must be at least 16, got {input_size}")
        self.fc1 = nn.Linear(256 * feature_side * feature_side, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, num_outputs)

    def forward(self, x):
        """Map a (batch, 3, input_size, input_size) tensor to (batch, num_outputs)."""
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        x = self.pool(F.relu(self.bn3(self.conv3(x))))
        x = self.pool(F.relu(self.bn4(self.conv4(x))))

        x = torch.flatten(x, 1)
        x = self.dropout(F.relu(self.fc1(x)))
        x = self.dropout(F.relu(self.fc2(x)))
        x = self.fc3(x)

        return x

## Deep convolutional network

In [None]:
class _ResidualBlock(nn.Module):
    """Basic residual block computing ``relu(body(x) + shortcut(x))``.

    ``body`` is two 3x3 conv + BN layers with a ReLU in between; the first conv
    applies ``stride``. ``shortcut`` is a 1x1 conv + BN projection whenever the
    stride or channel count changes (identity otherwise), so both branches
    produce tensors of the same shape and can be summed.
    """
    def __init__(self, in_channels, out_channels, stride=2):
        super(_ResidualBlock, self).__init__()
        self.body = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
        )
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        return F.relu(self.body(x) + self.shortcut(x))


class DeepConvNet(nn.Module):
    """Deeper CNN with residual connections.

    Stem: 7x7 stride-2 conv + BN + ReLU + 3x3 stride-2 max-pool (64 channels).
    Three residual stages (64 -> 128 -> 256 -> 512 channels) each halve the
    spatial size; global average pooling and a linear layer give the outputs.

    Args:
        num_outputs: Number of regression targets.
    """
    def __init__(self, num_outputs=4):
        super(DeepConvNet, self).__init__()

        # Initial conv
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.pool1 = nn.MaxPool2d(3, stride=2, padding=1)

        # Residual blocks
        self.res_block1 = self._make_residual_block(64, 128)
        self.res_block2 = self._make_residual_block(128, 256)
        self.res_block3 = self._make_residual_block(256, 512)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_outputs)

    def _make_residual_block(self, in_channels, out_channels):
        # A genuine skip path plus a post-sum ReLU: a bare conv-BN stack would be
        # neither residual nor separated from the next stage by a nonlinearity.
        return _ResidualBlock(in_channels, out_channels, stride=2)

    def forward(self, x):
        """Map an image batch (batch, 3, H, W) to (batch, num_outputs)."""
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))

        x = self.res_block1(x)
        x = self.res_block2(x)
        x = self.res_block3(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x

## Mobile Network

In [None]:
class MobileNetLike(nn.Module):
    """Compact CNN in the spirit of MobileNet.

    A strided 3x3 stem convolution is followed by four depthwise-separable
    stages, each with stride 2. Each stage is a depthwise 3x3 + BN + ReLU,
    then a pointwise 1x1 + BN + ReLU. Global average pooling and a linear
    regression head complete the model.
    """

    @staticmethod
    def _separable_stage(c_in, c_out, stride=1):
        # groups=c_in makes the 3x3 conv depthwise (one filter per input channel);
        # the following 1x1 conv mixes channels.
        layers = [
            nn.Conv2d(c_in, c_in, 3, stride=stride, padding=1, groups=c_in),
            nn.BatchNorm2d(c_in),
            nn.ReLU(inplace=True),
            nn.Conv2d(c_in, c_out, 1),
            nn.BatchNorm2d(c_out),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*layers)

    def __init__(self, num_outputs=4):
        super().__init__()
        # Stem
        self.conv1 = nn.Conv2d(3, 32, 3, stride=2, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        # Separable stages (created in the same order so init/state_dict match)
        self.dw_conv2 = self._separable_stage(32, 64, stride=2)
        self.dw_conv3 = self._separable_stage(64, 128, stride=2)
        self.dw_conv4 = self._separable_stage(128, 256, stride=2)
        self.dw_conv5 = self._separable_stage(256, 512, stride=2)
        # Head
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_outputs)

    def forward(self, x):
        """Map an image batch (batch, 3, H, W) to (batch, num_outputs)."""
        features = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.dw_conv2, self.dw_conv3, self.dw_conv4, self.dw_conv5):
            features = stage(features)
        pooled = self.avgpool(features)
        return self.fc(pooled.view(pooled.size(0), -1))

## ResNet-34 backbone (random init or ImageNet-pretrained)

In [None]:
class ResNetFromScratch(nn.Module):
    """ResNet-34 backbone with a small MLP regression head.

    Despite the class name, ``use_pretrained=True`` initialises the backbone
    from torchvision's ImageNet-1k weights; the default (False) uses random
    initialisation.

    Args:
        num_outputs: Number of regression targets.
        use_pretrained: Whether to load ImageNet weights into the backbone.
    """
    def __init__(self, num_outputs=4, use_pretrained=False):
        super(ResNetFromScratch, self).__init__()
        # torchvision >= 0.13 replaced ``pretrained=`` with ``weights=`` and warns
        # on the old flag. IMAGENET1K_V1 is what ``pretrained=True`` mapped to;
        # fall back to the legacy flag only on versions without the enum.
        weights_enum = getattr(models, "ResNet34_Weights", None)
        if weights_enum is not None:
            weights = weights_enum.IMAGENET1K_V1 if use_pretrained else None
            self.backbone = models.resnet34(weights=weights)
        else:
            self.backbone = models.resnet34(pretrained=use_pretrained)
        num_features = self.backbone.fc.in_features

        # Replace the ImageNet classifier with a two-layer regression head.
        self.backbone.fc = nn.Sequential(
            nn.Linear(num_features, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_outputs)
        )

    def forward(self, x):
        """Return the backbone + regression-head output, shape (batch, num_outputs)."""
        return self.backbone(x)

## Loss

In [None]:
class MultiTaskLoss(nn.Module):
    """Task-weighted mean absolute error.

    Computes ``mean(|predictions - targets| * task_weights)``, where the weights
    broadcast along the last dimension (one weight per target column).

    Args:
        task_weights: Optional per-task weights (sequence or tensor). Defaults to
            all ones, one per entry of ``TARGET_COLUMNS``.
    """
    def __init__(self, task_weights=None):
        super(MultiTaskLoss, self).__init__()
        if task_weights is None:
            weights = torch.ones(len(TARGET_COLUMNS))
        else:
            # as_tensor accepts lists and tensors alike (torch.tensor warns on a
            # tensor input); clone/detach so the caller's object is never aliased.
            weights = torch.as_tensor(task_weights, dtype=torch.float32).clone().detach()
        # A buffer follows ``module.to(device)``; persistent=False keeps it out of
        # state_dict so saved checkpoints are unchanged.
        self.register_buffer("task_weights", weights, persistent=False)

    def forward(self, predictions, targets):
        """Return the scalar weighted MAE of ``predictions`` vs ``targets``."""
        losses = torch.abs(predictions - targets)  # element-wise L1 (MAE base)

        # Safety net if the module was not moved with the model; unlike mutating
        # self.task_weights here, this leaves module state untouched.
        weights = self.task_weights.to(predictions.device)

        return (losses * weights).mean()

## Parsing Data

In [None]:
# %%
def parse_nutrition_csv(file_path):
    """Parse a dish-metadata CSV into one row per dish.

    Only lines whose first field starts with ``dish_`` are used. Fields 1-5 are
    read as total calories, weight (mass), fat, carbs and protein; any further
    fields on the line are ignored. Lines that are too short or have
    non-numeric values are skipped. Dishes with zero weight are also skipped,
    since their per-100g values would divide by zero.

    Returns:
        pandas.DataFrame with columns dish_id, calories, weight, fat, carbs,
        protein, calories_per_100g, fat_per_100g, carbs_per_100g and
        protein_per_100g. The columns exist even when no dish was parsed, so
        downstream code such as ``df['dish_id']`` does not raise KeyError.
    """
    columns = [
        'dish_id', 'calories', 'weight', 'fat', 'carbs', 'protein',
        'calories_per_100g', 'fat_per_100g', 'carbs_per_100g', 'protein_per_100g',
    ]
    dishes = []

    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            parts = line.strip().split(',')
            if not parts[0].startswith('dish_'):
                continue

            try:
                dish_id = parts[0]
                dish_calories = float(parts[1])
                dish_weight = float(parts[2])
                dish_fat = float(parts[3])
                dish_carbs = float(parts[4])
                dish_protein = float(parts[5])
            except (IndexError, ValueError):
                # Short or non-numeric line: skip it.
                continue

            if dish_weight == 0:
                continue

            dishes.append({
                'dish_id': dish_id,
                'calories': dish_calories,
                'weight': dish_weight,
                'fat': dish_fat,
                'carbs': dish_carbs,
                'protein': dish_protein,
                'calories_per_100g': (dish_calories / dish_weight) * 100,
                'fat_per_100g': (dish_fat / dish_weight) * 100,
                'carbs_per_100g': (dish_carbs / dish_weight) * 100,
                'protein_per_100g': (dish_protein / dish_weight) * 100
            })

    return pd.DataFrame(dishes, columns=columns)

class NutritionDataset(Dataset):
    """Map-style dataset yielding ``(image, label_tensor)`` pairs for dishes.

    Args:
        dish_ids: Dish folder names under ``imagery_dir``. Stored as a list so
            positional indexing works even for a pandas Series whose index is no
            longer 0..n-1 (e.g. after filtering or a train/test split).
        labels: Per-dish target values aligned with ``dish_ids``. Stored as a
            NumPy array for the same positional-indexing reason; a DataFrame is
            accepted too.
        imagery_dir: Directory holding one sub-folder per dish.
        transform: Optional callable applied to the PIL RGB image.

    Raises:
        ValueError: if ``dish_ids`` and ``labels`` differ in length.
    """
    def __init__(self, dish_ids, labels, imagery_dir, transform=None):
        self.dish_ids = list(dish_ids)
        self.labels = np.asarray(labels)
        if len(self.dish_ids) != len(self.labels):
            raise ValueError(
                f"dish_ids ({len(self.dish_ids)}) and labels ({len(self.labels)}) must have the same length"
            )
        self.imagery_dir = imagery_dir
        self.transform = transform

    def __len__(self):
        return len(self.dish_ids)

    def __getitem__(self, idx):
        dish_id = self.dish_ids[idx]

        img_path = os.path.join(self.imagery_dir, dish_id, RGB_IMAGE_FILENAME)
        # The context manager closes the file handle; convert() returns an
        # independent in-memory image, so it stays valid after the block.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        label = torch.tensor(self.labels[idx], dtype=torch.float32)

        return image, label

# Preprocessing pipelines shared by every model. The mean/std pair below is the
# widely used ImageNet RGB channel statistics.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Training: resize to 256x256, take a random 224x224 crop, randomly flip
# horizontally and apply mild colour jitter, then convert and normalise.
train_transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
        transforms.ToTensor(),
        transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ]
)

# Evaluation: deterministic resize straight to 224x224, then convert and normalise.
# NOTE(review): training crops 224 px out of a 256 px resize, so content appears
# ~14% larger than here; Resize(256) + CenterCrop(224) would match the training
# scale — confirm which behaviour is intended.
val_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ]
)

In [None]:
def train_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Module to train (switched to train mode).
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Callable ``(outputs, labels) -> scalar loss tensor``.
        optimizer: Optimiser over ``model``'s parameters.
        device: Device the batches are moved to.

    Returns:
        Mean of the per-batch loss values, or NaN if ``loader`` yielded no
        batches (instead of raising ZeroDivisionError).
    """
    # NOTE(review): train_epoch is defined again in a later cell, which shadows
    # this definition after Run All; keep a single copy.
    model.train()
    total_loss = 0.0
    num_batches = 0

    # Context manager closes the progress bar even if a batch raises.
    with tqdm(loader, desc='Training', leave=False) as pbar:
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            total_loss += batch_loss
            num_batches += 1

            # Running mean in O(1) per batch instead of re-averaging the full history.
            pbar.set_postfix({
                'loss': f'{batch_loss:.4f}',
                'avg_loss': f'{total_loss / num_batches:.4f}'
            })

    if num_batches == 0:
        return float('nan')
    return total_loss / num_batches

def validate(model, loader, criterion, device, target_columns_list):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: Module to evaluate (switched to eval mode).
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Callable ``(outputs, labels) -> scalar loss tensor``.
        device: Device the batches are moved to.
        target_columns_list: Names for the output columns, in model-output order.

    Returns:
        ``(mean_batch_loss, percentage_errors, predictions, labels)``:
        ``percentage_errors`` maps each column name to MAE as a percentage of
        that column's mean true value (NaN when the mean is 0). With an empty
        loader, the loss and all percentages are NaN and both arrays have zero
        rows, instead of raising.
    """
    # NOTE(review): validate is defined again in a later cell, which shadows
    # this definition after Run All; keep a single copy.
    model.eval()
    total_loss = 0.0
    num_batches = 0
    all_predictions = []
    all_labels = []

    with torch.no_grad(), tqdm(loader, desc='Validating', leave=False) as pbar:
        for images, labels_batch in pbar:
            images, labels_batch = images.to(device), labels_batch.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels_batch)

            batch_loss = loss.item()
            total_loss += batch_loss
            num_batches += 1

            all_predictions.append(outputs.cpu().numpy())
            all_labels.append(labels_batch.cpu().numpy())

            pbar.set_postfix({'loss': f'{batch_loss:.4f}'})

    if not all_predictions:
        # np.concatenate([]) would raise; report "undefined" metrics instead.
        empty = np.empty((0, len(target_columns_list)), dtype=np.float32)
        nan_errors = {col: float('nan') for col in target_columns_list}
        return float('nan'), nan_errors, empty, empty.copy()

    predictions_np = np.concatenate(all_predictions)
    labels_np = np.concatenate(all_labels)

    percentage_errors = {}
    for i, col in enumerate(target_columns_list):
        true_vals = labels_np[:, i]
        mae = mean_absolute_error(true_vals, predictions_np[:, i])
        mean_val = true_vals.mean()
        # A relative error is undefined when the targets average to zero.
        percentage_errors[col] = (mae / mean_val) * 100 if mean_val != 0 else float('nan')

    return total_loss / num_batches, percentage_errors, predictions_np, labels_np

In [None]:
# %%
def train_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Module to train (switched to train mode).
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Callable ``(outputs, labels) -> scalar loss tensor``.
        optimizer: Optimiser over ``model``'s parameters.
        device: Device the batches are moved to.

    Returns:
        Mean of the per-batch loss values, or NaN if ``loader`` yielded no
        batches (instead of raising ZeroDivisionError).
    """
    # NOTE(review): an earlier cell also defines train_epoch; this later
    # definition is the one in effect after Run All. Keep a single copy.
    model.train()
    total_loss = 0.0
    num_batches = 0

    # Context manager closes the progress bar even if a batch raises.
    with tqdm(loader, desc='Training', leave=False) as pbar:
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            total_loss += batch_loss
            num_batches += 1

            # Running mean in O(1) per batch instead of re-averaging the full history.
            pbar.set_postfix({
                'loss': f'{batch_loss:.4f}',
                'avg_loss': f'{total_loss / num_batches:.4f}'
            })

    if num_batches == 0:
        return float('nan')
    return total_loss / num_batches

def validate(model, loader, criterion, device, target_columns_list):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: Module to evaluate (switched to eval mode).
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Callable ``(outputs, labels) -> scalar loss tensor``.
        device: Device the batches are moved to.
        target_columns_list: Names for the output columns, in model-output order.

    Returns:
        ``(mean_batch_loss, percentage_errors, predictions, labels)``:
        ``percentage_errors`` maps each column name to MAE as a percentage of
        that column's mean true value (NaN when the mean is 0). With an empty
        loader, the loss and all percentages are NaN and both arrays have zero
        rows, instead of raising.
    """
    # NOTE(review): an earlier cell also defines validate; this later
    # definition is the one in effect after Run All. Keep a single copy.
    model.eval()
    total_loss = 0.0
    num_batches = 0
    all_predictions = []
    all_labels = []

    with torch.no_grad(), tqdm(loader, desc='Validating', leave=False) as pbar:
        for images, labels_batch in pbar:
            images, labels_batch = images.to(device), labels_batch.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels_batch)

            batch_loss = loss.item()
            total_loss += batch_loss
            num_batches += 1

            all_predictions.append(outputs.cpu().numpy())
            all_labels.append(labels_batch.cpu().numpy())

            pbar.set_postfix({'loss': f'{batch_loss:.4f}'})

    if not all_predictions:
        # np.concatenate([]) would raise; report "undefined" metrics instead.
        empty = np.empty((0, len(target_columns_list)), dtype=np.float32)
        nan_errors = {col: float('nan') for col in target_columns_list}
        return float('nan'), nan_errors, empty, empty.copy()

    predictions_np = np.concatenate(all_predictions)
    labels_np = np.concatenate(all_labels)

    percentage_errors = {}
    for i, col in enumerate(target_columns_list):
        true_vals = labels_np[:, i]
        mae = mean_absolute_error(true_vals, predictions_np[:, i])
        mean_val = true_vals.mean()
        # A relative error is undefined when the targets average to zero.
        percentage_errors[col] = (mae / mean_val) * 100 if mean_val != 0 else float('nan')

    return total_loss / num_batches, percentage_errors, predictions_np, labels_np

In [None]:
# %%
def calculate_metrics_per_100g(results_df_local):
    """Compute regression metrics for each per-100g nutrient in a results table.

    For each nutrient (calories, fat, carbs, protein), the frame is expected to
    contain ``<nutrient>_per_100g_true`` and ``<nutrient>_per_100g_pred``
    columns. Rows where either value is NaN are ignored. A nutrient with
    missing columns or no usable rows is skipped with a printed warning.

    Returns:
        DataFrame with one row per nutrient and the columns: nutrient, MAE,
        RMSE, R², Percentage Error (MAE as % of the mean true value; NaN if
        that mean is 0), Mean True and Mean Pred.
    """
    rows = []
    for prefix in ('calories', 'fat', 'carbs', 'protein'):
        true_key = f'{prefix}_per_100g_true'
        pred_key = f'{prefix}_per_100g_pred'

        available = results_df_local.columns
        if true_key not in available or pred_key not in available:
            print(f"Warning: Columns for {prefix} not found in results_df. Skipping metrics for it.")
            continue

        y_true_all = results_df_local[true_key].values
        y_pred_all = results_df_local[pred_key].values
        if len(y_true_all) == 0 or len(y_pred_all) == 0:
            print(f"Warning: No data for {prefix} to calculate metrics. Skipping.")
            continue

        # Keep only rows where both the true and predicted values are present.
        keep = ~(np.isnan(y_true_all) | np.isnan(y_pred_all))
        y_true = y_true_all[keep]
        y_pred = y_pred_all[keep]
        if len(y_true) == 0:
            print(f"Warning: No valid (non-NaN) data for {prefix} to calculate metrics. Skipping.")
            continue

        mae = mean_absolute_error(y_true, y_pred)
        mean_true = np.mean(y_true)
        rows.append({
            'nutrient': f'{prefix}_per_100g',
            'MAE': mae,
            'RMSE': np.sqrt(mean_squared_error(y_true, y_pred)),
            'R²': r2_score(y_true, y_pred),
            'Percentage Error': (mae / mean_true) * 100 if mean_true != 0 else float('nan'),
            'Mean True': mean_true,
            'Mean Pred': np.mean(y_pred),
        })

    return pd.DataFrame(rows)

def plot_training_history(history, model_name_str, output_dir_str, target_columns_list):
    """Save a two-panel figure of loss curves and per-target validation % errors.

    Args:
        history: Mapping with 'train_loss' and 'val_loss' sequences plus a
            'percentage_errors' entry passed to ``pd.DataFrame``. The latter is
            presumably a per-epoch list of {column: error} dicts; verify this
            against the caller.
        model_name_str: Model name used in titles and the output file name.
        output_dir_str: Directory the PNG is written to.
        target_columns_list: Columns to draw in the percentage-error panel; any
            column missing from the history is reported and skipped.

    Side effects: writes ``training_history_<model>.png`` and prints its path.
    """
    out_path = os.path.join(output_dir_str, f'training_history_{model_name_str}.png')
    fig, (loss_ax, err_ax) = plt.subplots(1, 2, figsize=(15, 5))

    # Left panel: loss curves.
    for key, label in (('train_loss', 'Train Loss'), ('val_loss', 'Val Loss')):
        loss_ax.plot(history[key], label=label)
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.set_title(f'Training and Validation Loss ({model_name_str})')
    loss_ax.legend()

    # Right panel: validation percentage error per target.
    errors_df = pd.DataFrame(history['percentage_errors'])
    for col in target_columns_list:
        if col not in errors_df.columns:
            print(f"Warning: Column {col} not found in history['percentage_errors'] for plotting.")
            continue
        err_ax.plot(errors_df[col], label=col.replace('_per_100g',''))

    err_ax.set_xlabel('Epoch')
    err_ax.set_ylabel('Percentage Error (%)')
    err_ax.set_title(f'Validation Percentage Errors ({model_name_str})')
    err_ax.legend()

    fig.tight_layout()
    fig.savefig(out_path)
    plt.close(fig)  # free the figure's memory
    print(f"Training history plot saved to {out_path}")


def plot_predictions_vs_actual(results_df_local, model_name_str, output_dir_str, target_columns_list):
    """Save a grid of predicted-vs-true scatter plots, one panel per target.

    For each ``<target>`` in ``target_columns_list``, the frame must contain
    ``<target>_true`` and ``<target>_pred`` columns. Each panel shows:
    - the identity line;
    - a least-squares trend line, when there are at least 2 valid points;
    - the R² of the predictions in the title.
    Panels for missing columns or all-NaN data are blanked and given an
    explanatory title.

    Side effects: writes ``predictions_vs_actual_<model>.png`` to
    ``output_dir_str`` and prints its path.
    """
    num_nutrients = len(target_columns_list)
    if num_nutrients == 0:
        # plt.subplots(0, ...) raises; there is nothing to draw.
        print("Warning: No target columns given. Skipping predictions vs actual plot.")
        return

    ncols = 2
    nrows = (num_nutrients + ncols - 1) // ncols
    fig, axes = plt.subplots(nrows, ncols, figsize=(7 * ncols, 6 * nrows))
    axes = np.atleast_1d(axes).flatten()

    for ax, nutrient_full_name in zip(axes, target_columns_list):  # e.g. 'calories_per_100g'
        true_col = f'{nutrient_full_name}_true'
        pred_col = f'{nutrient_full_name}_pred'
        short_name = nutrient_full_name.replace('_per_100g', '').capitalize()

        if true_col not in results_df_local.columns or pred_col not in results_df_local.columns:
            print(f"Warning: Columns '{true_col}' or '{pred_col}' not found. Skipping plot for '{nutrient_full_name}'.")
            ax.set_title(f"{short_name} - Data Missing")
            ax.axis('off')
            continue

        x_data = results_df_local[true_col].values
        y_data = results_df_local[pred_col].values

        valid = ~(np.isnan(x_data) | np.isnan(y_data))
        x_plot, y_plot = x_data[valid], y_data[valid]

        if len(x_plot) == 0:
            ax.set_title(f"{short_name} - No Valid Data")
            ax.axis('off')
            continue

        ax.scatter(x_plot, y_plot, alpha=0.5, s=30, edgecolors='k', linewidth=0.5)
        min_val = min(x_plot.min(), y_plot.min())
        max_val = max(x_plot.max(), y_plot.max())
        ax.plot([min_val, max_val], [min_val, max_val], 'r--', lw=2, label='Perfect Prediction')

        if len(x_plot) > 1:
            slope, intercept = np.polyfit(x_plot, y_plot, 1)
            trend_x = np.array([x_plot.min(), x_plot.max()])
            # ':+.2f' prints the intercept's own sign, avoiding labels like "y=0.90x+-3.10".
            ax.plot(trend_x, slope * trend_x + intercept, "b-", alpha=0.8,
                    label=f'Trend: y={slope:.2f}x{intercept:+.2f}')
            r2 = r2_score(x_plot, y_plot)  # (y_true, y_pred) order
            r2_text = f"R² = {r2:.3f}"
        else:
            r2_text = "R² = N/A"
            ax.plot([], [], "b-", alpha=0.8, label='Trend: N/A')

        display_nutrient_name = nutrient_full_name.replace('_per_100g', '').replace('_', ' ').capitalize()
        ax.set_xlabel(f'True {display_nutrient_name} (per 100g)')
        ax.set_ylabel(f'Predicted {display_nutrient_name} (per 100g)')
        ax.set_title(f'{display_nutrient_name} (per 100g)\n{r2_text}')
        ax.legend()
        ax.grid(True, alpha=0.3)

    # Remove the spare panel when the number of targets is odd.
    for spare_ax in axes[num_nutrients:]:
        fig.delaxes(spare_ax)

    out_path = os.path.join(output_dir_str, f'predictions_vs_actual_{model_name_str}.png')
    fig.tight_layout(rect=[0, 0, 1, 0.96])  # leave room at the top for the suptitle
    fig.suptitle(f'Predictions vs True Values ({model_name_str}, per 100g)', fontsize=16)
    fig.savefig(out_path)
    plt.close(fig)
    print(f"Predictions vs actual plot saved to {out_path}")


def show_sample_predictions_with_images(results_df_local, imagery_dir_str, model_name_str, output_dir_str, target_columns_list, n_samples=6):
    """Save a grid of randomly chosen dish images annotated with predictions.

    The frame must contain ``dish_id`` plus ``<target>_pred`` and
    ``<target>_true`` columns for every target. Each panel shows:
    - the dish's RGB image, or a placeholder if it is missing or unreadable;
    - a "Pred" box listing every target in ``target_columns_list``;
    - an "Actual (Err%)" box listing the same targets.
    Samples are drawn with NumPy's global RNG, so the choice follows any
    earlier ``np.random.seed`` call.

    Side effects: writes ``sample_predictions_<model>.png`` to
    ``output_dir_str`` and prints its path.
    """
    if len(results_df_local) == 0:
        print("No results to show samples from.")
        return

    # Never request more samples than exist; non-positive requests show nothing.
    n_samples = min(n_samples, len(results_df_local))
    if n_samples <= 0:
        print("Not enough samples to display.")
        return

    sample_indices = np.random.choice(len(results_df_local), n_samples, replace=False)

    ncols = 3
    nrows = (n_samples + ncols - 1) // ncols
    fig, axes = plt.subplots(nrows, ncols, figsize=(5 * ncols, 5 * nrows))
    axes = np.atleast_1d(axes).flatten()

    for ax, sample_df_idx in zip(axes, sample_indices):
        row = results_df_local.iloc[sample_df_idx]
        dish_id = row['dish_id']

        img_path = os.path.join(imagery_dir_str, dish_id, RGB_IMAGE_FILENAME)
        try:
            # Close the file promptly; copy() keeps the decoded pixels after the block.
            with Image.open(img_path) as img:
                ax.imshow(img.copy())
        except FileNotFoundError:
            ax.text(0.5, 0.5, "Image not found", ha='center', va='center')
            print(f"Warning: Image not found for {dish_id} at {img_path}")
        except OSError as exc:
            # A corrupt or unsupported file should not abort the whole figure.
            ax.text(0.5, 0.5, "Image unreadable", ha='center', va='center')
            print(f"Warning: Could not read image for {dish_id} at {img_path}: {exc}")

        ax.axis('off')

        pred_text = "Pred:\n"
        true_text = "Actual (Err%):\n"

        for nutrient in target_columns_list:  # e.g. 'calories_per_100g'
            pred_val = row[f'{nutrient}_pred']
            true_val = row[f'{nutrient}_true']

            # Short label for display, e.g. 'Calo', 'Prot'.
            short_nutrient_name = nutrient.replace('_per_100g', '').capitalize()[:4]

            if true_val != 0:
                error = abs(pred_val - true_val) / abs(true_val) * 100
                error_str = f"({error:.0f}%)"
            else:
                error_str = "(N/A)"

            pred_text += f"{short_nutrient_name}: {pred_val:.0f}\n"
            true_text += f"{short_nutrient_name}: {true_val:.0f} {error_str}\n"

        ax.text(0.02, 0.02, pred_text, transform=ax.transAxes,
                verticalalignment='bottom', bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.7), fontsize=8)
        ax.text(0.98, 0.02, true_text, transform=ax.transAxes,
                verticalalignment='bottom', horizontalalignment='right',
                bbox=dict(boxstyle='round', facecolor='lightgreen', alpha=0.7), fontsize=8)

        # Only add an ellipsis when the id was actually truncated.
        shown_id = dish_id if len(dish_id) <= 15 else f"{dish_id[:15]}..."
        ax.set_title(f"Dish: {shown_id}", fontsize=9)

    # Hide unused subplots.
    for spare_ax in axes[n_samples:]:
        fig.delaxes(spare_ax)

    out_path = os.path.join(output_dir_str, f'sample_predictions_{model_name_str}.png')
    fig.tight_layout(rect=[0, 0, 1, 0.96])  # leave room at the top for the suptitle
    fig.suptitle(f'Sample Predictions ({model_name_str})', fontsize=14)
    fig.savefig(out_path)
    plt.close(fig)
    print(f"Sample predictions plot saved to {out_path}")

## Training and evaluation pipeline

In [None]:
# %%
def run_training_and_evaluation(args):
    """Train one nutrition-regression model, then evaluate its best checkpoint.

    Pipeline:
      1. Seed numpy/torch and check that the dataset layout under
         ``args.base_dir`` exists.
      2. Load both cafe metadata CSVs, keep dishes that have an RGB image and
         finite targets and weight, then split 80/20 into train/val.
      3. Build the network named by ``args.model_name`` and train it with an
         L1 loss, Adam and ReduceLROnPlateau. A checkpoint is saved whenever
         the validation loss improves.
      4. Reload the best checkpoint and predict on the validation set.
         Per-100g values are converted to absolute amounts using the
         ground-truth dish weight. Results go to a CSV and, unless
         ``args.no_plots``, to figures.

    Args:
        args: Namespace from ``parse_args()``. Reads ``seed``, ``base_dir``,
            ``output_dir``, ``model_name``, ``batch_size``, ``lr``,
            ``epochs``, ``num_workers`` and ``no_plots``.

    Side effects:
        Writes ``best_nutrition_model_<model>.pth``,
        ``training_history_<model>.pkl`` and
        ``evaluation_results_<model>.csv`` into ``args.output_dir``. The
        plotting helpers it calls receive the same directory.

    Raises:
        AssertionError: If the base, imagery or metadata paths are missing.
        Any exception raised during training is re-raised after the history
        pickle has been written. A KeyboardInterrupt stops training early and
        evaluation still runs.
    """
    # Set seed for reproducibility
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    # --- Path and Parameter Setup ---
    LOCAL_BASE_DIR = args.base_dir
    IMAGERY_DIR = os.path.join(LOCAL_BASE_DIR, "imagery/realsense_overhead")
    METADATA_FILE_CAFE1 = os.path.join(LOCAL_BASE_DIR, "metadata/dish_metadata_cafe1.csv")
    METADATA_FILE_CAFE2 = os.path.join(LOCAL_BASE_DIR, "metadata/dish_metadata_cafe2.csv")

    assert os.path.exists(LOCAL_BASE_DIR), f"Base directory not found: {LOCAL_BASE_DIR}"
    assert os.path.exists(IMAGERY_DIR), f"Imagery directory not found: {IMAGERY_DIR}"
    assert os.path.exists(METADATA_FILE_CAFE1), f"Metadata cafe1 not found: {METADATA_FILE_CAFE1}"
    assert os.path.exists(METADATA_FILE_CAFE2), f"Metadata cafe2 not found: {METADATA_FILE_CAFE2}"

    # exist_ok=True already covers the "directory exists" case.
    os.makedirs(args.output_dir, exist_ok=True)

    # --- Data Loading and Preprocessing ---
    print("Loading and preprocessing data...")
    dish_df_cafe1 = parse_nutrition_csv(METADATA_FILE_CAFE1)
    dish_df_cafe2 = parse_nutrition_csv(METADATA_FILE_CAFE2)
    dish_metadata_df = pd.concat([dish_df_cafe1, dish_df_cafe2], ignore_index=True)

    # Only dishes whose imagery folder actually contains the RGB frame are usable.
    available_dishes = [d for d in os.listdir(IMAGERY_DIR)
                        if os.path.isdir(os.path.join(IMAGERY_DIR, d)) and
                        os.path.exists(os.path.join(IMAGERY_DIR, d, RGB_IMAGE_FILENAME))]
    filtered_metadata = dish_metadata_df[dish_metadata_df['dish_id'].isin(available_dishes)]
    filtered_metadata = filtered_metadata.replace([np.inf, -np.inf], np.nan)
    # Weight must be present too: evaluation uses it to convert per-100g to absolute values.
    filtered_metadata = filtered_metadata.dropna(subset=TARGET_COLUMNS + ['weight'])

    print(f"Found {len(filtered_metadata)} dishes with metadata, images, and valid targets.")

    dish_ids = filtered_metadata['dish_id'].tolist()
    labels = filtered_metadata[TARGET_COLUMNS].values.astype(np.float32)

    train_ids, val_ids, train_labels, val_labels = train_test_split(
        dish_ids, labels, test_size=0.2, random_state=args.seed
    )

    train_dataset = NutritionDataset(train_ids, train_labels, IMAGERY_DIR, train_transform)
    val_dataset = NutritionDataset(val_ids, val_labels, IMAGERY_DIR, val_transform)

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
    # shuffle=False is load-bearing: the evaluation below recovers dish ids
    # from batch positions in val_ids.
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
    print(f"Training samples: {len(train_dataset)}, Validation samples: {len(val_dataset)}")

    # --- Model Initialization ---
    num_outputs = len(TARGET_COLUMNS)
    # Factories rather than instances, so only the requested network is
    # constructed (including the pretrained ResNet variant only when asked).
    # Because fewer modules draw from the torch RNG, weights and shuffling for
    # a given seed differ from runs that built every model up front.
    model_factories = {
        'SimpleConvNet': lambda: SimpleConvNet(num_outputs=num_outputs),
        'DeepConvNet': lambda: DeepConvNet(num_outputs=num_outputs),
        'MobileNetLike': lambda: MobileNetLike(num_outputs=num_outputs),
        'ResNetFromScratch': lambda: ResNetFromScratch(num_outputs=num_outputs, use_pretrained=False),
        'ResNetPretrained': lambda: ResNetFromScratch(num_outputs=num_outputs, use_pretrained=True),
    }
    model = model_factories[args.model_name]().to(DEVICE)

    print(f"\nSelected model: {args.model_name}")
    print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
    print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

    criterion = nn.L1Loss()  # MAE loss, as in original training
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    # `verbose` is deprecated in recent PyTorch; its default (off) is what we want.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10, factor=0.5)

    # --- Training Loop ---
    MODEL_SAVE_PATH = os.path.join(args.output_dir, f'best_nutrition_model_{args.model_name}.pth')
    history_path = os.path.join(args.output_dir, f'training_history_{args.model_name}.pkl')

    best_val_loss = float('inf')
    history = {'train_loss': [], 'val_loss': [], 'percentage_errors': [], 'lr': []}

    print("\n" + "=" * 60)
    print(f"STARTING TRAINING: {args.model_name}")
    print(f"Epochs: {args.epochs}, Batch Size: {args.batch_size}, LR: {args.lr}, Device: {DEVICE}")
    print("=" * 60 + "\n")

    try:
        for epoch in range(args.epochs):
            epoch_start_time = time.time()
            current_lr = optimizer.param_groups[0]['lr']

            print(f"\nEPOCH {epoch+1}/{args.epochs} | LR: {current_lr:.6f}")

            train_loss = train_epoch(model, train_loader, criterion, optimizer, DEVICE)
            val_loss, percentage_errors, _, _ = validate(model, val_loader, criterion, DEVICE, TARGET_COLUMNS)

            epoch_time = time.time() - epoch_start_time
            scheduler.step(val_loss)

            history['train_loss'].append(train_loss)
            history['val_loss'].append(val_loss)
            history['percentage_errors'].append(percentage_errors)
            history['lr'].append(current_lr)

            print(f"Epoch {epoch+1} Summary: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Time: {epoch_time:.2f}s")
            # nanmean: a nutrient's percentage error may be NaN.
            avg_percentage_error = np.nanmean(list(percentage_errors.values()))
            print(f"  Avg Val %Error: {avg_percentage_error:.2f}%")
            for nutrient, error in percentage_errors.items():
                print(f"    {nutrient}: {error:.2f}%")

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'best_val_loss': best_val_loss,
                    'model_name': args.model_name,
                    'target_columns': TARGET_COLUMNS,
                }, MODEL_SAVE_PATH)  # History is saved separately in `finally`
                print(f"  ✓ New best model saved to {MODEL_SAVE_PATH}")

    except KeyboardInterrupt:
        # Deliberately swallowed: fall through to evaluate the best checkpoint so far.
        print("\n\nTraining interrupted by user.")
    except Exception as e:
        print(f"\n\nError during training: {e}")
        raise
    finally:
        # Always persist whatever history was collected, even on failure.
        with open(history_path, 'wb') as f:
            pickle.dump(history, f)
        print(f"\nTraining history saved to: {history_path}")
        print(f"Best validation loss for {args.model_name}: {best_val_loss:.4f}")

    # --- Evaluation and Plotting (if not disabled) ---
    print("\n" + "="*60)
    print(f"POST-TRAINING EVALUATION: {args.model_name}")
    print("="*60 + "\n")

    if os.path.exists(MODEL_SAVE_PATH):
        print(f"Loading best model from {MODEL_SAVE_PATH} for evaluation...")
        # weights_only=False: the checkpoint dict also holds non-tensor metadata.
        checkpoint = torch.load(MODEL_SAVE_PATH, map_location=DEVICE, weights_only=False)
        # The trained network already has the checkpoint's architecture, so the
        # best weights are loaded back into it in place.
        model_eval = model
        model_eval.load_state_dict(checkpoint['model_state_dict'])
        model_eval.eval()

        all_predictions_eval = []
        all_labels_eval = []
        all_dish_ids_eval = []  # For matching with weights

        # val_loader does not shuffle, so consecutive batches cover val_ids in order.
        current_val_idx = 0
        with torch.no_grad():
            for images, labels_batch in tqdm(val_loader, desc="Evaluating best model"):
                images = images.to(DEVICE)
                outputs = model_eval(images)
                all_predictions_eval.append(outputs.cpu().numpy())
                all_labels_eval.append(labels_batch.cpu().numpy())

                batch_size_actual = images.size(0)
                all_dish_ids_eval.extend(val_ids[current_val_idx : current_val_idx + batch_size_actual])
                current_val_idx += batch_size_actual

        predictions_np_eval = np.concatenate(all_predictions_eval)
        labels_np_eval = np.concatenate(all_labels_eval)

        # dish_id -> weight, built once. The first row wins for duplicated ids,
        # matching a `.loc[...].iloc[0]` lookup.
        weight_by_dish = (
            filtered_metadata.drop_duplicates(subset='dish_id', keep='first')
            .set_index('dish_id')['weight']
            .to_dict()
        )

        # One row per validation dish: id, weight, and <target>_pred/<target>_true columns.
        eval_results_list = []
        for i, dish_id in enumerate(all_dish_ids_eval):
            row = {'dish_id': dish_id, 'weight': weight_by_dish[dish_id]}
            for j, col_name in enumerate(TARGET_COLUMNS):
                row[f'{col_name}_pred'] = predictions_np_eval[i, j]
                row[f'{col_name}_true'] = labels_np_eval[i, j]
            eval_results_list.append(row)

        results_df_eval = pd.DataFrame(eval_results_list)

        # Absolute amount = per-100g value * dish weight / 100
        for nutrient_tc in TARGET_COLUMNS:  # e.g. 'calories_per_100g'
            base_nutrient_name = nutrient_tc.replace('_per_100g', '')  # 'calories'
            results_df_eval[f'{base_nutrient_name}_abs_pred'] = results_df_eval[f'{nutrient_tc}_pred'] * results_df_eval['weight'] / 100
            results_df_eval[f'{base_nutrient_name}_abs_true'] = results_df_eval[f'{nutrient_tc}_true'] * results_df_eval['weight'] / 100

        results_df_eval.to_csv(os.path.join(args.output_dir, f'evaluation_results_{args.model_name}.csv'), index=False)
        print(f"Evaluation results saved to {os.path.join(args.output_dir, f'evaluation_results_{args.model_name}.csv')}")

        metrics_df_eval = calculate_metrics_per_100g(results_df_eval)
        print("\n" + "="*80)
        print(f"MODEL PERFORMANCE METRICS - {args.model_name} (Per 100g on Val Set)")
        print("="*80)
        print(metrics_df_eval.to_string(index=False, float_format='%.3f'))

        if not args.no_plots:
            print("\nGenerating plots...")
            plot_training_history(history, args.model_name, args.output_dir, TARGET_COLUMNS)
            plot_predictions_vs_actual(results_df_eval, args.model_name, args.output_dir, TARGET_COLUMNS)
            show_sample_predictions_with_images(results_df_eval, IMAGERY_DIR, args.model_name, args.output_dir, TARGET_COLUMNS)
        else:
            print("\nPlotting disabled.")
    else:
        print(f"No best model found at {MODEL_SAVE_PATH} to evaluate.")

    print(f"\nFinished run for model: {args.model_name}")

In [None]:
args = parse_args()

# NOTE(review): parse_args() defines a required --model_name and presumably
# parses sys.argv (argparse default). Inside a live Jupyter kernel sys.argv
# holds the kernel's own flags, so this cell is expected to run only when the
# notebook is executed as a script. Otherwise set sys.argv explicitly first.
# TODO: confirm intended usage.

# With --no_plots nothing is drawn, so backend selection is skipped.
if not args.no_plots:
    # Probe whether the active backend can create a figure. A fresh probe
    # figure is closed immediately. It does not leak or render as an empty
    # figure in cell output, and the user's current figure is left untouched.
    try:
        _backend_probe_fig = plt.figure()
        plt.close(_backend_probe_fig)
    except Exception:  # backend cannot create figures (e.g. no display available)
        import matplotlib
        matplotlib.use('Agg')  # Use 'Agg' for non-interactive plotting to files
        print("Matplotlib backend set to 'Agg' for non-interactive plotting.")
        # Rebind pyplot after switching the backend
        import matplotlib.pyplot as plt


run_training_and_evaluation(args)

In [None]:
# Cell 10: Verbose training loop with model-specific checkpointing.
#
# Interactive variant of the loop inside run_training_and_evaluation(). It
# reads globals that must be defined by earlier cells: MODEL_NAME, DEVICE,
# BATCH_SIZE, LEARNING_RATE, NUM_EPOCHS, train_dataset, val_dataset,
# train_loader, val_loader, model, criterion, optimizer, scheduler and
# TARGET_COLUMNS. That function keeps its own versions of these as locals.
# TODO: confirm these globals exist on a fresh kernel.
best_val_loss = float('inf')
history = {'train_loss': [], 'val_loss': [], 'percentage_errors': [], 'lr': []}
# Pre-bound so the exception handlers can report progress even if the
# failure happens before the first iteration.
epoch = -1

# Print training configuration
print("=" * 60)
print("TRAINING CONFIGURATION")
print("=" * 60)
print(f"Model: {MODEL_NAME}")
print(f"Device: {DEVICE}")
print(f"Batch Size: {BATCH_SIZE}")
print(f"Learning Rate: {LEARNING_RATE}")
print(f"Number of Epochs: {NUM_EPOCHS}")
print(f"Training Samples: {len(train_dataset)}")
print(f"Validation Samples: {len(val_dataset)}")
print(f"Target Columns: {TARGET_COLUMNS}")
print(f"Model Parameters: {sum(p.numel() for p in model.parameters()):,}")
print("=" * 60)
print()

# Model-specific save path (relative to the current working directory)
MODEL_SAVE_PATH = f'best_nutrition_model_{MODEL_NAME}.pth'

try:
    for epoch in range(NUM_EPOCHS):
        epoch_start_time = time.time()

        # LR in effect for this epoch (the scheduler may lower it afterwards)
        current_lr = optimizer.param_groups[0]['lr']

        print(f"\nEPOCH {epoch+1}/{NUM_EPOCHS} | LR: {current_lr:.6f}")
        print("-" * 60)

        # Train
        print("Training phase:")
        train_loss = train_epoch(model, train_loader, criterion, optimizer, DEVICE)

        # Validate (same call shape as run_training_and_evaluation)
        print("\nValidation phase:")
        val_loss, percentage_errors, predictions, labels = validate(
            model, val_loader, criterion, DEVICE, TARGET_COLUMNS
        )

        epoch_time = time.time() - epoch_start_time

        # Step the plateau scheduler and detect an LR reduction
        old_lr = optimizer.param_groups[0]['lr']
        scheduler.step(val_loss)
        new_lr = optimizer.param_groups[0]['lr']

        # Save history
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['percentage_errors'].append(percentage_errors)
        history['lr'].append(current_lr)

        # Print detailed results
        print(f"\n{'='*60}")
        print(f"EPOCH {epoch+1} RESULTS:")
        print(f"{'='*60}")
        print(f"Train Loss: {train_loss:.4f}")
        print(f"Val Loss: {val_loss:.4f}")
        print(f"Epoch Time: {epoch_time:.2f} seconds")

        # Relative change versus the previous epoch (positive = loss went down)
        if epoch > 0:
            train_improvement = (history['train_loss'][-2] - train_loss) / history['train_loss'][-2] * 100
            val_improvement = (history['val_loss'][-2] - val_loss) / history['val_loss'][-2] * 100
            print(f"Train Loss Change: {train_improvement:+.2f}%")
            print(f"Val Loss Change: {val_improvement:+.2f}%")

        print("\nPERCENTAGE ERRORS BY NUTRIENT (per 100g):")
        print("-" * 40)
        for nutrient, error in percentage_errors.items():
            # Show trend if we have history
            if len(history['percentage_errors']) > 1:
                prev_error = history['percentage_errors'][-2][nutrient]
                change = error - prev_error
                print(f"  {nutrient:20s}: {error:6.2f}% ({change:+.2f}%)")
            else:
                print(f"  {nutrient:20s}: {error:6.2f}%")

        # nanmean: individual nutrient errors may be NaN
        avg_percentage_error = np.nanmean(list(percentage_errors.values()))
        print(f"  {'Average':20s}: {avg_percentage_error:6.2f}%")

        # Save best model
        if val_loss < best_val_loss:
            improvement_pct = (best_val_loss - val_loss) / best_val_loss * 100 if best_val_loss != float('inf') else 100
            best_val_loss = val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'best_val_loss': best_val_loss,
                'model_name': MODEL_NAME,
                'target_columns': TARGET_COLUMNS,
                'history': history
            }, MODEL_SAVE_PATH)
            print(f"\n✓ NEW BEST MODEL SAVED to {MODEL_SAVE_PATH}! (Improvement: {improvement_pct:.2f}%)")
        else:
            epochs_since_best = epoch - history['val_loss'].index(min(history['val_loss']))
            print(f"\n  No improvement for {epochs_since_best} epoch(s)")

        if old_lr != new_lr:
            print(f"\n⚡ Learning rate reduced: {old_lr:.6f} → {new_lr:.6f}")

        print("=" * 60)

except KeyboardInterrupt:
    print("\n\n⚠️  Training interrupted by user!")
    # Count epochs whose results were recorded, not the loop index.
    print(f"Completed {len(history['train_loss'])}/{NUM_EPOCHS} epochs")

except Exception as e:
    print(f"\n\n❌ Error during training: {e}")
    if epoch >= 0:
        print(f"Failed at epoch: {epoch+1}")
    else:
        print("Failed before the first epoch started")
    raise

finally:
    print("\n" + "="*60)
    print(f"TRAINING SUMMARY - {MODEL_NAME}")
    print("="*60)
    if history['train_loss']:
        print(f"Final Train Loss: {history['train_loss'][-1]:.4f}")
        print(f"Final Val Loss: {history['val_loss'][-1]:.4f}")
        print(f"Best Val Loss: {best_val_loss:.4f}")
        print(f"Total Epochs Completed: {len(history['train_loss'])}")
        print(f"Model saved as: {MODEL_SAVE_PATH}")

        # Final percentage errors
        if history['percentage_errors']:
            print("\nFinal Percentage Errors (per 100g):")
            final_errors = history['percentage_errors'][-1]
            for nutrient, error in final_errors.items():
                print(f"  {nutrient}: {error:.2f}%")

            avg_error = np.nanmean(list(final_errors.values()))
            print(f"\nAverage Percentage Error: {avg_error:.2f}%")

    # Persist history even after an interrupt or failure (pickle is imported at the top).
    history_path = f'training_history_{MODEL_NAME}.pkl'
    with open(history_path, 'wb') as f:
        pickle.dump(history, f)
    print(f"\nTraining history saved to: {history_path}")

In [None]:
# Cell 10-continue: learning curves for the run recorded in `history`.
# Left panel: train/val loss per epoch. Right panel: one line per nutrient's
# validation percentage error.
fig, (ax_loss, ax_pct) = plt.subplots(1, 2, figsize=(15, 5))

for history_key, curve_label in (('train_loss', 'Train Loss'), ('val_loss', 'Val Loss')):
    ax_loss.plot(history[history_key], label=curve_label)
ax_loss.set(xlabel='Epoch', ylabel='Loss', title='Training and Validation Loss')
ax_loss.legend()

# Each history entry is a {nutrient: error} dict; a frame turns them into columns.
pct_errors_by_epoch = pd.DataFrame(history['percentage_errors'])
for nutrient_col in pct_errors_by_epoch.columns:
    ax_pct.plot(pct_errors_by_epoch[nutrient_col], label=nutrient_col)
ax_pct.set(xlabel='Epoch', ylabel='Percentage Error (%)', title='Percentage Errors by Nutrient')
ax_pct.legend()

plt.tight_layout()
plt.show()

In [None]:
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm

# Cell 11: Evaluate the best checkpoint on the validation set (per-100g targets).
#
# Requires these names from earlier cells: MODEL_NAME, DEVICE, model,
# val_loader, val_ids, filtered_metadata and TARGET_COLUMNS.
# val_loader must iterate val_ids in order (shuffle=False). Dish ids are
# recovered from each batch's position, not from the batch itself.

checkpoint_path = f'best_nutrition_model_{MODEL_NAME}.pth'

# The checkpoint dict also stores non-tensor data (epoch, history, ...).
# PyTorch 2.6+ defaults to weights_only=True, which would refuse to unpickle
# it, so weights_only=False is passed explicitly. Only do this for
# checkpoints you produced yourself (Cell 10).
print(f"Loading checkpoint from {checkpoint_path}...")
checkpoint = torch.load(checkpoint_path, weights_only=False, map_location=DEVICE)

# Cell 10 stores the weights under 'model_state_dict'.
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()  # Set model to evaluation mode

print(f"Model {MODEL_NAME} loaded successfully and set to evaluation mode.")

print(f"Evaluating {MODEL_NAME} on validation set...")
# dish_id -> weight, built once instead of filtering the whole frame per
# dish. The first row wins for duplicated ids.
weight_by_dish = (
    filtered_metadata.drop_duplicates(subset='dish_id', keep='first')
    .set_index('dish_id')['weight']
    .to_dict()
)

all_predictions = []
all_labels = []
all_dish_ids = []
all_weights = []

with torch.no_grad():
    for i, batch_data in enumerate(tqdm(val_loader, desc='Predicting')):
        # Accept either (images, labels) tuples or dict batches.
        if isinstance(batch_data, (list, tuple)):
            images, labels_batch = batch_data
        elif isinstance(batch_data, dict):
            images = batch_data['image']  # Adjust key if necessary
            labels_batch = batch_data['labels']  # Adjust key if necessary
        else:
            raise TypeError(f"Unexpected batch data type: {type(batch_data)}")

        images = images.to(DEVICE)
        outputs = model(images)

        all_predictions.append(outputs.cpu().numpy())
        all_labels.append(labels_batch.cpu().numpy())

        # Slice val_ids by batch position (valid only because val_loader does not shuffle).
        batch_start_index = i * val_loader.batch_size
        batch_end_index = min(batch_start_index + val_loader.batch_size, len(val_ids))
        current_batch_dish_ids = val_ids[batch_start_index:batch_end_index]
        all_dish_ids.extend(current_batch_dish_ids)

        # Ground-truth weights, used to convert back to absolute amounts.
        for dish_id in current_batch_dish_ids:
            if dish_id in weight_by_dish:
                all_weights.append(weight_by_dish[dish_id])
            else:
                print(f"Warning: No weight found for dish_id {dish_id}. Appending NaN or a default.")
                all_weights.append(np.nan)

predictions_np = np.concatenate(all_predictions)
labels_np = np.concatenate(all_labels)

# Align id/weight lists with the number of predictions actually produced.
num_samples_processed = len(predictions_np)
all_dish_ids = all_dish_ids[:num_samples_processed]
all_weights = all_weights[:num_samples_processed]

# Per-100g columns are keyed by TARGET_COLUMNS, so output column j always
# gets the name of target j, whatever order TARGET_COLUMNS uses.
results_columns = {'dish_id': all_dish_ids, 'weight': all_weights}
for j, col_name in enumerate(TARGET_COLUMNS):
    results_columns[f'{col_name}_pred'] = predictions_np[:, j]
    results_columns[f'{col_name}_true'] = labels_np[:, j]
results_df = pd.DataFrame(results_columns)

# Absolute amount = per-100g value * dish weight / 100
for col_name in TARGET_COLUMNS:
    nutrient = col_name.replace('_per_100g', '')
    results_df[f'{nutrient}_abs_pred'] = results_df[f'{col_name}_pred'] * results_df['weight'] / 100
    results_df[f'{nutrient}_abs_true'] = results_df[f'{col_name}_true'] * results_df['weight'] / 100

print(f"\nPredictions completed for {len(results_df)} samples")
if results_df['weight'].isnull().any():
    print("Warning: Some weights were NaN. Absolute nutrient calculations might be affected for those rows.")

In [None]:
# Cell 12: Metrics calculation for per-100g predictions
def calculate_metrics_per_100g(results_df, nutrients=('calories', 'fat', 'carbs', 'protein')):
    """Summarise per-100g regression quality for each nutrient.

    Args:
        results_df: DataFrame with ``<nutrient>_per_100g_true`` and
            ``<nutrient>_per_100g_pred`` columns for every entry of
            ``nutrients``.
        nutrients: Base nutrient names to report, in output-row order.
            Defaults to the four nutrients this notebook predicts.

    Returns:
        DataFrame with one row per nutrient and these columns:

        - ``nutrient``
        - ``MAE``
        - ``RMSE``
        - ``R²``
        - ``Percentage Error``: MAE as a percentage of the mean true value;
          NaN when that mean is 0.
        - ``Mean True``
        - ``Mean Pred``
    """
    metrics_list = []

    for nutrient in nutrients:
        per_100g_true = results_df[f'{nutrient}_per_100g_true'].values
        per_100g_pred = results_df[f'{nutrient}_per_100g_pred'].values

        mae = mean_absolute_error(per_100g_true, per_100g_pred)
        rmse = np.sqrt(mean_squared_error(per_100g_true, per_100g_pred))
        r2 = r2_score(per_100g_true, per_100g_pred)

        mean_true = np.mean(per_100g_true)
        # The ratio is undefined for a zero mean. NaN avoids reporting a
        # misleading "0%", which would read as a perfect prediction.
        percentage_error = (mae / mean_true) * 100 if mean_true != 0 else np.nan

        metrics_list.append({
            'nutrient': f'{nutrient}_per_100g',
            'MAE': mae,
            'RMSE': rmse,
            'R²': r2,
            'Percentage Error': percentage_error,
            'Mean True': mean_true,
            'Mean Pred': np.mean(per_100g_pred)
        })

    return pd.DataFrame(metrics_list)

# Calculate metrics
metrics_df = calculate_metrics_per_100g(results_df)

print("\n" + "="*80)
print(f"MODEL PERFORMANCE METRICS - {MODEL_NAME} (Per 100g)")
print("="*80)
print(metrics_df.to_string(index=False, float_format='%.3f'))

# Also show how this translates to absolute predictions
print("\n" + "="*80)
print("EFFECTIVE ABSOLUTE PREDICTION ERRORS (using ground truth weight)")
print("="*80)

for nutrient in ['calories', 'fat', 'carbs', 'protein']:
    abs_true = results_df[f'{nutrient}_abs_true'].to_numpy(dtype=float)
    abs_pred = results_df[f'{nutrient}_abs_pred'].to_numpy(dtype=float)

    # Rows whose dish weight was missing (Cell 11 stores NaN) have NaN
    # absolute values. sklearn metrics reject NaN input, so only rows where
    # both values are finite are scored.
    valid = np.isfinite(abs_true) & np.isfinite(abs_pred)
    if not valid.any():
        print(f"{nutrient:10s}: N/A (no rows with a known weight)")
        continue

    mae = mean_absolute_error(abs_true[valid], abs_pred[valid])
    mean_true = np.mean(abs_true[valid])
    if mean_true == 0:
        print(f"{nutrient:10s}: N/A (mean true value is 0)")
        continue

    percentage_error = (mae / mean_true) * 100
    print(f"{nutrient:10s}: {percentage_error:.2f}% error")

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics import r2_score

# Cell 13: Visualize predictions vs actual values (per 100g)
#
# Expects `results_df` (built in Cell 11) with `<prefix>_true` and
# `<prefix>_pred` columns for each prefix below.

# Column prefixes to plot. They MUST match the per-100g columns in results_df
# (e.g. 'calories_per_100g_true' / 'calories_per_100g_pred').
nutrients_to_plot_keys = [
    'calories_per_100g',
    'fat_per_100g',
    'carbs_per_100g',
    'protein_per_100g',
]

num_nutrients_to_plot = len(nutrients_to_plot_keys)

# Fixed 2x3 grid (room for up to 6 nutrients); unused panels are removed below.
fig, axes = plt.subplots(2, 3, figsize=(18, 12))
axes = axes.flatten()  # 1-D for easy iteration

for i, nutrient_key_prefix in enumerate(nutrients_to_plot_keys):
    ax = axes[i]

    true_col = f'{nutrient_key_prefix}_true'
    pred_col = f'{nutrient_key_prefix}_pred'
    short_name = nutrient_key_prefix.replace('_per_100g', '').capitalize()

    # Defensive check: ensure columns exist in the DataFrame
    if true_col not in results_df.columns or pred_col not in results_df.columns:
        print(f"Warning: Columns '{true_col}' or '{pred_col}' not found in results_df. Skipping plot for '{nutrient_key_prefix}'.")
        ax.set_title(f"{short_name} - Data Missing")
        ax.axis('off')
        continue

    x_data = results_df[true_col].values
    y_data = results_df[pred_col].values

    # Drop pairs with a NaN on either side before any statistics or plotting
    valid_indices = ~(np.isnan(x_data) | np.isnan(y_data))
    x_plot = x_data[valid_indices]
    y_plot = y_data[valid_indices]

    if len(x_plot) == 0:
        print(f"Warning: No valid (non-NaN) data points for '{nutrient_key_prefix}'. Skipping plot.")
        ax.set_title(f"{short_name} - No Valid Data")
        ax.axis('off')
        continue

    ax.scatter(x_plot, y_plot, alpha=0.5, s=50, edgecolors='k', linewidth=0.5)

    # y = x reference line over the combined data range. x_plot is non-empty
    # here; a degenerate range is widened so the line stays visible.
    min_val = min(x_plot.min(), y_plot.min())
    max_val = max(x_plot.max(), y_plot.max())
    if min_val == max_val:
        min_val -= 0.5
        max_val += 0.5
    ax.plot([min_val, max_val], [min_val, max_val], 'r--', lw=2, label='Perfect Prediction')

    # A degree-1 fit and R² need at least two points
    if len(x_plot) > 1:
        z = np.polyfit(x_plot, y_plot, 1)
        p = np.poly1d(z)
        trend_line_x_points = np.array([x_plot.min(), x_plot.max()])
        # `:+.2f` renders the intercept's own sign ("x-3.10", not "x+-3.10")
        ax.plot(trend_line_x_points, p(trend_line_x_points), "b-", alpha=0.8,
                label=f'Trend: y={z[0]:.2f}x{z[1]:+.2f}')

        r2 = r2_score(x_plot, y_plot)  # (y_true, y_pred)
        r2_text = f"R² = {r2:.3f}"
    else:
        ax.plot([], [], "b-", alpha=0.8, label='Trend: N/A (too few points)')  # legend placeholder
        r2_text = "R² = N/A (too few points)"

    display_nutrient_name = nutrient_key_prefix.replace('_per_100g', '').replace('_', ' ').capitalize()

    ax.set_xlabel(f'True {display_nutrient_name} (per 100g)')
    ax.set_ylabel(f'Predicted {display_nutrient_name} (per 100g)')
    ax.set_title(f'{display_nutrient_name} (per 100g)\n{r2_text}')
    ax.legend()
    ax.grid(True, alpha=0.3)

# Hide any unused subplots
for j in range(num_nutrients_to_plot, len(axes)):
    fig.delaxes(axes[j])

# Leave headroom at the top for the suptitle
plt.tight_layout(rect=[0, 0, 1, 0.97])  # rect=[left, bottom, right, top]
plt.suptitle('Predictions vs True Values (per 100g)', fontsize=16)
plt.show()

In [None]:
# Cell 14: Error distribution analysis
#
# For each target column, a histogram of the relative error
# (pred - true) / true * 100 on the validation rows of `results_df`,
# after removing outliers outside the 1.5*IQR fences.
fig, axes = plt.subplots(2, 3, figsize=(18, 10))
axes = axes.flatten()

for i, nutrient in enumerate(TARGET_COLUMNS):
    ax = axes[i]
    true_col = f'{nutrient}_true'
    pred_col = f'{nutrient}_pred'

    errors = results_df[pred_col] - results_df[true_col]
    # A zero ground truth yields ±inf (or NaN for 0/0). Those rows are
    # dropped so the quantiles, fences and histogram range stay finite.
    relative_errors = (
        (errors / results_df[true_col]) * 100
    ).replace([np.inf, -np.inf], np.nan).dropna()

    # Remove outliers for better visualization
    q1 = relative_errors.quantile(0.25)
    q3 = relative_errors.quantile(0.75)
    iqr = q3 - q1
    lower_bound = q1 - 1.5 * iqr
    upper_bound = q3 + 1.5 * iqr
    filtered_errors = relative_errors[(relative_errors >= lower_bound) & (relative_errors <= upper_bound)]

    display_name = nutrient.replace('_per_100g', '').capitalize()

    ax.hist(filtered_errors, bins=30, alpha=0.7, color='skyblue', edgecolor='black')
    ax.axvline(x=0, color='red', linestyle='--', linewidth=2, label='Zero Error')
    # A mean line only makes sense when at least one error survived filtering
    if not filtered_errors.empty:
        mean_error = filtered_errors.mean()
        ax.axvline(x=mean_error, color='green', linestyle='-', linewidth=2,
                   label=f'Mean: {mean_error:.1f}%')

    ax.set_xlabel('Relative Error (%)')
    ax.set_ylabel('Frequency')
    ax.set_title(f'{display_name} Error Distribution')
    ax.legend()
    ax.grid(True, alpha=0.3)

# Remove panels beyond the number of target columns (same as Cell 13)
for unused_ax in axes[len(TARGET_COLUMNS):]:
    fig.delaxes(unused_ax)

plt.tight_layout()
plt.show()

In [None]:
# Cell 15: Show sample predictions with images
def _format_prediction_texts(row, nutrients):
    """Build the "Predicted" / "Actual" caption strings for one results row.

    Args:
        row: A mapping (e.g. a ``results_df`` row) with ``'<nutrient>_pred'`` and
            ``'<nutrient>_true'`` entries for every nutrient in ``nutrients``.
        nutrients: Iterable of nutrient column prefixes, in display order.

    Returns:
        ``(pred_text, true_text)`` as multi-line strings. Each "Actual" line also
        shows the absolute relative error, or ``n/a`` when the true value is 0
        (relative error is undefined there).
    """
    pred_lines = ["Predicted:"]
    true_lines = ["Actual:"]
    for nutrient in nutrients:
        pred_val = row[f'{nutrient}_pred']
        true_val = row[f'{nutrient}_true']
        if true_val == 0:
            error_str = "n/a"
        else:
            error_str = f"{abs(pred_val - true_val) / abs(true_val) * 100:.1f}%"

        # Calories are shown as a whole number without a unit suffix; weight is a
        # whole number of grams; other nutrients use one decimal place in grams.
        if nutrient == 'calories':
            label, fmt, unit = "Cal", ".0f", ""
        elif nutrient == 'weight':
            label, fmt, unit = "Weight", ".0f", "g"
        else:
            label, fmt, unit = nutrient.capitalize(), ".1f", "g"

        pred_lines.append(f"{label}: {pred_val:{fmt}}{unit}")
        true_lines.append(f"{label}: {true_val:{fmt}}{unit} ({error_str})")
    return "\n".join(pred_lines), "\n".join(true_lines)


def show_predictions_with_images(n_samples=6, seed=None):
    """Plot random dishes from ``results_df`` with predicted vs. true captions.

    Reads the notebook-level ``results_df``, ``TARGET_COLUMNS`` and ``IMAGERY_DIR``
    defined in earlier cells. Each panel shows ``IMAGERY_DIR/<dish_id>/rgb.png``
    with a "Predicted" box on the left and an "Actual" box (with relative
    error) on the right.

    Args:
        n_samples: Number of dishes to show. Clamped to the number of rows in
            ``results_df``. The grid has 3 columns and at least 2 rows (as
            before), growing when more than six samples are requested.
        seed: Optional seed for a private RNG. ``None`` (default) draws from
            NumPy's global RNG, so a prior ``np.random.seed`` still applies.
    """
    # Sampling without replacement fails if asked for more rows than exist
    n_shown = max(0, min(n_samples, len(results_df)))
    rng = np.random if seed is None else np.random.default_rng(seed)
    sample_indices = rng.choice(len(results_df), n_shown, replace=False)

    n_cols = 3
    n_rows = max(2, -(-n_shown // n_cols))  # ceil division; never smaller than 2x3
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows))
    axes = axes.flatten()

    for idx, ax in enumerate(axes):
        ax.axis('off')
        if idx >= n_shown:
            continue

        row = results_df.iloc[sample_indices[idx]]
        dish_id = row['dish_id']

        # Load and display image
        img_path = os.path.join(IMAGERY_DIR, str(dish_id), "rgb.png")
        try:
            # copy() loads the pixels into memory, so the file handle can be
            # closed by the context manager instead of leaking one per panel.
            with Image.open(img_path) as img:
                ax.imshow(img.copy())
        except FileNotFoundError:
            # Keep drawing the rest of the figure if a single image is missing
            ax.text(0.5, 0.5, f"Image not found:\n{img_path}", ha='center',
                    va='center', transform=ax.transAxes, fontsize=8, wrap=True)

        pred_text, true_text = _format_prediction_texts(row, TARGET_COLUMNS)

        # Add text to image
        ax.text(0.02, 0.98, pred_text, transform=ax.transAxes,
                verticalalignment='top',
                bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.8),
                fontsize=9)
        ax.text(0.98, 0.98, true_text, transform=ax.transAxes,
                verticalalignment='top', horizontalalignment='right',
                bbox=dict(boxstyle='round', facecolor='lightgreen', alpha=0.8),
                fontsize=9)

        ax.set_title(f"Dish: {dish_id}", fontsize=10)

    fig.tight_layout()
    fig.suptitle('Sample Predictions vs Ground Truth', fontsize=14, y=1.02)
    plt.show()

# Render six randomly chosen dishes with their predicted and true nutrition values
show_predictions_with_images(n_samples=6)