In [None]:
# Mount Google Drive so the dataset and checkpoint folders under /content/drive are reachable.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Install required libraries (Colab shell escape).
# NOTE(review): versions are unpinned, so results can drift as packages update;
# consider pinning versions (and `%pip`, which targets the kernel's environment).
!pip install torch torchvision transformers pillow pandas numpy matplotlib scikit-learn tqdm

# Import necessary libraries
import os
import pandas as pd
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
import json
import pickle

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from transformers import DistilBertTokenizer, DistilBertModel


# Check if GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Seed every RNG the notebook relies on (example sampling with np.random,
# DataLoader shuffling, dropout and random augmentations via torch) so a
# Restart & Run All reproduces the same results.
import random

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

In [None]:
# Define paths. Everything lives under one Drive folder, so each location is
# derived from BASE_PATH instead of repeating the prefix.
BASE_PATH = '/content/drive/MyDrive/coursework'
METADATA_PATH = os.path.join(BASE_PATH, 'nutrition5k_metadata', 'nutrition5k_dataset', 'metadata')
IMAGES_PATH = os.path.join(BASE_PATH, 'nutrition5k_dataset', 'imagery', 'realsense_overhead')
CHECKPOINT_PATH = os.path.join(BASE_PATH, 'model_checkpoints')

# Create the checkpoint directory (no-op if it already exists)
os.makedirs(CHECKPOINT_PATH, exist_ok=True)

for config_line in (
    "Paths configured:",
    f"Metadata: {METADATA_PATH}",
    f"Images: {IMAGES_PATH}",
    f"Checkpoints: {CHECKPOINT_PATH}",
):
    print(config_line)

In [None]:
# Load all metadata files. The dish files are re-read without a header in the
# next cell; ingredients_metadata.csv has a real header row and is used as-is.
_csv_read_opts = dict(on_bad_lines='skip', engine='python')
cafe1_df = pd.read_csv(os.path.join(METADATA_PATH, 'dish_metadata_cafe1.csv'), **_csv_read_opts)
cafe2_df = pd.read_csv(os.path.join(METADATA_PATH, 'dish_metadata_cafe2.csv'), **_csv_read_opts)
ingredients_df = pd.read_csv(os.path.join(METADATA_PATH, 'ingredients_metadata.csv'), **_csv_read_opts)

# Combine cafe metadata
metadata_df = pd.concat([cafe1_df, cafe2_df], ignore_index=True)

for shape_label, frame in (
    ("Cafe1 shape", cafe1_df),
    ("Cafe2 shape", cafe2_df),
    ("Combined metadata shape", metadata_df),
    ("Ingredients shape", ingredients_df),
):
    print(f"{shape_label}: {frame.shape}")

In [None]:
# The dish metadata files have no header row, and their rows are ragged: six
# dish-level fields followed by a variable number of per-ingredient fields.
# With header=None pandas sizes the table from the FIRST line, so every row
# longer than that was treated as a bad line and silently dropped by
# on_bad_lines='skip'. Size the table from the widest row instead so every
# dish is kept (shorter rows are padded with NaN).
import csv


def _max_field_count(path):
    """Return the largest number of comma-separated fields on any line of `path`."""
    with open(path, newline='', encoding='utf-8') as fh:
        return max((len(fields) for fields in csv.reader(fh)), default=0)


def _read_ragged_csv(path):
    """Read a header-less CSV whose rows differ in length, keeping every row.

    Columns are labelled 0..n-1, exactly as with header=None.
    """
    n_fields = _max_field_count(path)
    return pd.read_csv(path, header=None, names=list(range(n_fields)), engine='python')


cafe1_df = _read_ragged_csv(os.path.join(METADATA_PATH, 'dish_metadata_cafe1.csv'))
cafe2_df = _read_ragged_csv(os.path.join(METADATA_PATH, 'dish_metadata_cafe2.csv'))

# Combine cafe metadata (columns are aligned by position label; missing cells become NaN)
metadata_df = pd.concat([cafe1_df, cafe2_df], ignore_index=True)

print(f"Metadata shape: {metadata_df.shape}")
print(f"\nFirst 5 rows:")
print(metadata_df.head())

# The first column should be dish_id, second should be total_calories
print(f"\n=== First few dish IDs and calories ===")
print(metadata_df[[0, 1]].head(10))

# Check data types
print(f"\nColumn 0 (dish_id) sample: {metadata_df.iloc[0, 0]}")
print(f"Column 1 (calories) sample: {metadata_df.iloc[0, 1]}")

print(f"\n=== Ingredients (already loaded correctly) ===")
print(f"Shape: {ingredients_df.shape}")
print(ingredients_df.head())

In [None]:
# Give the six dish-level columns readable names; the remaining ragged
# per-ingredient columns keep positional names col_6, col_7, ...
dish_level_cols = ['dish_id', 'total_calories', 'total_mass', 'total_fat', 'total_carb', 'total_protein']
extra_cols = [f'col_{i}' for i in range(len(dish_level_cols), metadata_df.shape[1])]
metadata_df.columns = dish_level_cols + extra_cols

print("=== METADATA STATISTICS ===")
print(f"Total dishes: {len(metadata_df)}")
print(f"\nCalories statistics:")
print(metadata_df['total_calories'].describe())

print(f"\nMass statistics (grams):")
print(metadata_df['total_mass'].describe())

# Visualize calorie distribution, calories vs. mass, and macronutrient spread
fig, (ax_hist, ax_scatter, ax_box) = plt.subplots(1, 3, figsize=(15, 5))

ax_hist.hist(metadata_df['total_calories'], bins=50, edgecolor='black')
ax_hist.set(xlabel='Total Calories', ylabel='Frequency', title='Distribution of Calories')

ax_scatter.scatter(metadata_df['total_mass'], metadata_df['total_calories'], alpha=0.5)
ax_scatter.set(xlabel='Total Mass (g)', ylabel='Total Calories', title='Calories vs Mass')

macro_cols = ['total_fat', 'total_carb', 'total_protein']
ax_box.boxplot([metadata_df[col] for col in macro_cols])
ax_box.set_xticks([1, 2, 3])
ax_box.set_xticklabels(['Fat', 'Carbs', 'Protein'])
ax_box.set(ylabel='Grams', title='Macronutrient Distribution')

for ax in (ax_hist, ax_scatter, ax_box):
    ax.grid(True, alpha=0.3)

fig.tight_layout()
plt.show()

print(f"\n=== INGREDIENTS ===")
print(f"Total unique ingredients: {len(ingredients_df)}")
print(f"\nTop 10 ingredients by calorie density:")
top_dense = ingredients_df.nlargest(10, 'cal/g')
print(top_dense[['ingr', 'cal/g', 'fat(g)', 'carb(g)', 'protein(g)']])

In [None]:
# Get all dish folders. os.listdir() returns entries in arbitrary order (it can
# differ between Drive mounts / sessions), so sort them: everything built from
# this list, including the seeded train/val/test split, must not depend on
# that order, or a resumed session can train on previous test dishes.
dish_folders = sorted(f for f in os.listdir(IMAGES_PATH) if f.startswith('dish_'))
print(f"Total dish folders: {len(dish_folders)}")
if not dish_folders:
    raise FileNotFoundError(f"No 'dish_*' folders found under {IMAGES_PATH}")

# Check a sample dish folder
sample_dish = dish_folders[0]
sample_path = os.path.join(IMAGES_PATH, sample_dish)
print(f"\nSample dish folder: {sample_dish}")
print(f"Files in sample folder:")
files_in_sample = os.listdir(sample_path)
for f in files_in_sample:
    print(f"  - {f}")

# Load and display sample images (first RGB .png of up to six dishes)
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

for ax, dish_folder in zip(axes, dish_folders):
    dish_path = os.path.join(IMAGES_PATH, dish_folder)
    rgb_files = [f for f in os.listdir(dish_path) if 'rgb' in f.lower() and f.endswith('.png')]
    if rgb_files:
        ax.imshow(Image.open(os.path.join(dish_path, rgb_files[0])))
        ax.set_title(f"{dish_folder[:15]}...")

# Hide axis frames on every panel, including panels left empty because a
# dish had no RGB image or there were fewer than six dishes.
for ax in axes:
    ax.axis('off')

plt.tight_layout()
plt.show()

# Check image dimensions
sample_img_path = os.path.join(IMAGES_PATH, dish_folders[0], 'rgb.png')
if os.path.exists(sample_img_path):
    sample_img = Image.open(sample_img_path)
    print(f"\nSample image dimensions: {sample_img.size}")
    print(f"Image mode: {sample_img.mode}")

In [None]:
# Match images with metadata and ingredients.
# Look up metadata rows by dish_id through a dict built once (first occurrence
# wins, as the old per-dish boolean filter + iloc[0] did) instead of scanning
# the whole table for every folder.
first_row_pos = {}
for pos, meta_dish_id in enumerate(metadata_df['dish_id']):
    first_row_pos.setdefault(meta_dish_id, pos)

N_DISH_LEVEL_COLS = 6  # dish_id + 5 totals; per-ingredient fields follow


def _ingredient_names(row):
    """Return the ingredient names listed in one dish-metadata row.

    The old fixed stride of 6 misaligned from the second ingredient on.
    NOTE(review): per the Nutrition5k README each ingredient record is
    (ingr_id, ingr_name, grams, calories, fat, carb, protein). Rather than
    relying on the stride, locate each ingredient id (a string starting with
    'ingr_') and take the field right after it as the name.
    """
    values = row.iloc[N_DISH_LEVEL_COLS:].tolist()
    names = []
    for j, value in enumerate(values[:-1]):
        if isinstance(value, str) and value.startswith('ingr_'):
            name = values[j + 1]
            if pd.notna(name) and str(name) != "":
                names.append(str(name))
    return names


data_list = []

for dish_folder in tqdm(dish_folders, desc="Creating dataset"):
    # Folder names are the dish ids themselves (e.g. dish_1561662216)
    dish_id = dish_folder

    dish_path = os.path.join(IMAGES_PATH, dish_folder)
    rgb_files = [f for f in os.listdir(dish_path) if 'rgb' in f.lower() and f.endswith('.png')]
    if not rgb_files or dish_id not in first_row_pos:
        continue

    image_path = os.path.join(dish_path, rgb_files[0])
    row = metadata_df.iloc[first_row_pos[dish_id]]

    calories = row['total_calories']
    mass = row['total_mass']
    fat = row['total_fat']
    carb = row['total_carb']
    protein = row['total_protein']

    ingredients_list = _ingredient_names(row)
    ingredient_text = ", ".join(ingredients_list) if ingredients_list else "unknown ingredients"

    # Keep only dishes with a positive calorie label
    if pd.notna(calories) and float(calories) > 0:
        data_list.append({
            'dish_id': dish_id,
            'image_path': image_path,
            'text': ingredient_text,
            'calories': float(calories),
            'mass': float(mass) if pd.notna(mass) else 0,
            'fat': float(fat) if pd.notna(fat) else 0,
            'carb': float(carb) if pd.notna(carb) else 0,
            'protein': float(protein) if pd.notna(protein) else 0
        })

if not data_list:
    raise ValueError("No dish folder matched a metadata row with positive calories")

# Sort by dish_id so the frame, and therefore the seeded train/val/test split,
# does not depend on directory-listing order.
df = pd.DataFrame(data_list).sort_values('dish_id').reset_index(drop=True)

print(f"=== DATASET CREATED ===")
print(f"Total samples: {len(df)}")
print(f"\nCalories statistics:")
print(df['calories'].describe())
print(f"\nSample data:")
print(df[['dish_id', 'text', 'calories']].head(10))

# Save dataset
df.to_csv(os.path.join(BASE_PATH, 'prepared_dataset.csv'), index=False)
print(f"\nDataset saved!")

In [None]:
from sklearn.model_selection import train_test_split

# Split dataset 70/15/15: first hold out 30%, then halve that hold-out into
# validation and test. Both calls share random_state=42.
train_df, temp_df = train_test_split(df, test_size=0.3, random_state=42)
val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=42)

n_total = len(df)
print(f"=== DATASET SPLIT ===")
print(f"Total samples: {n_total}")
for split_name, part in (('Train', train_df), ('Validation', val_df), ('Test', test_df)):
    print(f"{split_name} samples: {len(part)} ({len(part)/n_total*100:.1f}%)")

# Check calorie distribution in each split
print(f"\n=== Calorie Distribution ===")
for tag, part in (('Train', train_df), ('Val', val_df), ('Test', test_df)):
    print(f"{tag:<5} - Mean: {part['calories'].mean():.2f}, Std: {part['calories'].std():.2f}")

# Visualize splits: one histogram panel per split
panel_specs = (
    (train_df, 'Train', None, 'Train Set Distribution'),
    (val_df, 'Val', 'orange', 'Validation Set Distribution'),
    (test_df, 'Test', 'green', 'Test Set Distribution'),
)
fig, split_axes = plt.subplots(1, 3, figsize=(15, 4))
for ax, (part, legend_label, bar_color, panel_title) in zip(split_axes, panel_specs):
    ax.hist(part['calories'], bins=30, alpha=0.7, label=legend_label, color=bar_color, edgecolor='black')
    ax.set(xlabel='Calories', ylabel='Frequency', title=panel_title)
    ax.legend()
    ax.grid(True, alpha=0.3)

fig.tight_layout()
plt.show()

# Save splits
for file_name, part in (('train_data.csv', train_df), ('val_data.csv', val_df), ('test_data.csv', test_df)):
    part.to_csv(os.path.join(BASE_PATH, file_name), index=False)

print(f"\nSplits saved to drive!")

In [None]:
# Custom Dataset class for images + text
class FoodCalorieDataset(Dataset):
    """Per-dish samples of (image, ingredient text, calorie target, dish id).

    Args:
        dataframe: table with 'image_path', 'text', 'calories' and 'dish_id'
            columns; its index is reset so positions 0..n-1 are valid keys.
        transform: optional callable applied to the RGB PIL image.
        max_text_length: number of characters kept from the lower-cased text.
    """

    def __init__(self, dataframe, transform=None, max_text_length=100):
        self.data = dataframe.reset_index(drop=True)
        self.transform = transform
        self.max_text_length = max_text_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data.loc[idx]

        picture = Image.open(record['image_path']).convert('RGB')
        if self.transform:
            picture = self.transform(picture)

        # Simple text normalisation: lower-case, then clip to max_text_length chars
        caption = str(record['text']).lower()[:self.max_text_length]

        return {
            'image': picture,
            'text': caption,
            'calories': torch.tensor(record['calories'], dtype=torch.float32),
            'dish_id': record['dish_id'],
        }

# Standard ImageNet normalisation statistics
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training adds light augmentation; evaluation only resizes and normalises.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Create datasets; validation and test share the deterministic transform
train_dataset = FoodCalorieDataset(train_df, transform=train_transform)
val_dataset, test_dataset = (FoodCalorieDataset(part, transform=val_transform) for part in (val_df, test_df))

print(f"=== DATASETS CREATED ===")
for ds_label, ds in (("Train", train_dataset), ("Val", val_dataset), ("Test", test_dataset)):
    print(f"{ds_label} dataset size: {len(ds)}")

# Smoke-test loading one sample
sample = train_dataset[0]
print(f"\n=== SAMPLE DATA ===")
print(f"Image shape: {sample['image'].shape}")
print(f"Text: {sample['text'][:100]}...")
print(f"Calories: {sample['calories'].item():.2f}")
print(f"Dish ID: {sample['dish_id']}")

In [None]:
# Create data loaders
BATCH_SIZE = 16


def _make_loader(dataset, shuffle):
    """Build a DataLoader with the notebook's shared batch/worker settings."""
    return DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=shuffle, num_workers=2, pin_memory=True)


train_loader = _make_loader(train_dataset, shuffle=True)
val_loader = _make_loader(val_dataset, shuffle=False)
test_loader = _make_loader(test_dataset, shuffle=False)

print(f"=== DATA LOADERS CREATED ===")
for loader_label, loader in (("Train", train_loader), ("Val", val_loader), ("Test", test_loader)):
    print(f"{loader_label} batches: {len(loader)}")

# Smoke-test loading one batch
print(f"\n=== SAMPLE BATCH ===")
for batch in train_loader:
    batch_texts = batch['text']
    print(f"Batch image shape: {batch['image'].shape}")
    print(f"Batch calories shape: {batch['calories'].shape}")
    print(f"Number of texts: {len(batch_texts)}")
    print(f"Sample text: {batch_texts[0][:80]}...")
    print(f"Sample calories: {batch['calories'][0].item():.2f}")
    break

In [None]:
# Load the pre-trained, uncased DistilBERT tokenizer used to encode ingredient text
bert_tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

print(f"‚úÖ BERT tokenizer loaded!\nVocabulary size: {bert_tokenizer.vocab_size}")

# Multimodal Model with BERT: Image + Text -> Calories
class CaloriePredictionModelBERT(nn.Module):
    """Predict a dish's calories from its photo and its ingredient text.

    Image branch: ImageNet-pretrained ResNet-50 without its final fc layer
    (2048-d pooled features) -> Linear(2048, 512) -> ReLU.
    Text branch: DistilBERT first-token ([CLS]) hidden state (768-d) ->
    Linear(768, 256) -> ReLU. Only the last two transformer layers start out
    trainable.
    Fusion head: concatenated 768-d vector -> MLP -> one value per sample.
    """

    def __init__(self):
        super(CaloriePredictionModelBERT, self).__init__()

        # Image encoder: pretrained ResNet minus its classification layer.
        # `pretrained=True` is deprecated in torchvision; IMAGENET1K_V1 is the
        # weight set it selected.
        resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        self.image_encoder = nn.Sequential(*list(resnet.children())[:-1])
        self.image_fc = nn.Linear(2048, 512)

        # Text encoder: DistilBERT
        self.bert = DistilBertModel.from_pretrained('distilbert-base-uncased')

        # Freeze most BERT layers (train only last 2 transformer layers)
        for param in self.bert.parameters():
            param.requires_grad = False
        for param in self.bert.transformer.layer[-2:].parameters():
            param.requires_grad = True

        # BERT output projection
        self.text_fc = nn.Linear(768, 256)

        # Fusion and regression head
        self.fusion = nn.Sequential(
            nn.Linear(512 + 256, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 1)
        )

    def forward(self, images, input_ids, attention_mask):
        """Return calorie predictions of shape (batch,).

        Args:
            images: normalised image batch for the ResNet trunk (presumably
                (batch, 3, 224, 224) given the notebook's transforms; confirm).
            input_ids, attention_mask: tokenizer outputs for the same batch.
        """
        # Image features: pooled (batch, 2048, 1, 1) -> (batch, 2048) -> (batch, 512)
        img_features = self.image_encoder(images)
        img_features = img_features.view(img_features.size(0), -1)
        img_features = torch.relu(self.image_fc(img_features))

        # Text features from the first ([CLS]) token of the last hidden layer
        bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        text_features = bert_output.last_hidden_state[:, 0, :]
        text_features = torch.relu(self.text_fc(text_features))

        # Fuse features and regress
        combined = torch.cat([img_features, text_features], dim=1)
        output = self.fusion(combined)

        # squeeze(-1), not squeeze(): a batch of one must stay shape (1,)
        # instead of collapsing to a 0-d tensor, which mismatched the (1,)
        # target and made `list.extend(outputs.cpu().numpy())` fail.
        return output.squeeze(-1)

# Instantiate the model and place it on the selected device (Module.to returns the module)
model = CaloriePredictionModelBERT().to(device)

n_params_total = sum(p.numel() for p in model.parameters())
n_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"=== MODEL WITH BERT CREATED ===")
print(f"Device: {device}")
print(f"Total parameters: {n_params_total:,}")
print(f"Trainable parameters: {n_params_trainable:,}")

In [None]:
# Text processing helper for the BERT model
def process_text_batch_bert(text_list, max_length=128):
    """Tokenize a list of strings with the global DistilBERT tokenizer.

    Pads to the longest sequence in the batch, truncates to `max_length`
    tokens, and returns a dict with 'input_ids' and 'attention_mask' tensors
    already moved to the global `device`.
    """
    encoded = bert_tokenizer(
        text_list,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors='pt',
    )
    return {key: encoded[key].to(device) for key in ('input_ids', 'attention_mask')}

In [None]:
# Regression objective (mean squared error) and optimizer over all parameters;
# Adam only updates parameters that actually receive gradients.
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Halve the learning rate when validation loss has not improved for 3 epochs
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

# Training configuration
NUM_EPOCHS = 50
PATIENCE = 10  # early-stopping patience, in epochs
BEST_VAL_LOSS, PATIENCE_COUNTER = float('inf'), 0

# Checkpoint saving function
def save_checkpoint(epoch, model, optimizer, scheduler, train_loss, val_loss,
                    best_val_loss, filepath):
    """Write model/optimizer/scheduler state plus loss bookkeeping to `filepath`.

    The saved dict has the keys epoch, model_state_dict, optimizer_state_dict,
    scheduler_state_dict, train_loss, val_loss and best_val_loss.
    """
    state = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        scheduler_state_dict=scheduler.state_dict(),
        train_loss=train_loss,
        val_loss=val_loss,
        best_val_loss=best_val_loss,
    )
    torch.save(state, filepath)
    print(f"Checkpoint saved: {filepath}")

# Checkpoint loading function
def load_checkpoint(filepath, model, optimizer, scheduler):
    """Restore model/optimizer/scheduler state from a `save_checkpoint` file.

    Returns:
        (epoch, best_val_loss) from the checkpoint, or (0, float('inf')) if
        `filepath` does not exist.
    """
    if not os.path.exists(filepath):
        print("No checkpoint found, starting from scratch")
        return 0, float('inf')

    # weights_only=False: this is our own trusted checkpoint holding optimizer
    # and scheduler state, matching the other loads in this notebook (newer
    # torch versions default to weights_only=True).
    checkpoint = torch.load(filepath, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    epoch = checkpoint['epoch']
    best_val_loss = checkpoint['best_val_loss']
    print(f"Checkpoint loaded from epoch {epoch}")
    return epoch, best_val_loss

setup_summary = [
    "=== TRAINING SETUP COMPLETE ===",
    "Loss function: MSE",
    "Optimizer: Adam (lr=0.001)",
    "Scheduler: ReduceLROnPlateau",
    f"Max epochs: {NUM_EPOCHS}",
    f"Early stopping patience: {PATIENCE}",
]
print("\n".join(setup_summary))

In [None]:
# Training function
def train_one_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over `train_loader`.

    Returns:
        Mean of the per-batch training losses (0.0 if the loader is empty).
    """
    model.train()
    running_loss = 0.0
    n_batches = 0

    for batch in tqdm(train_loader, desc="Training"):
        images = batch['image'].to(device)
        text_encoding = process_text_batch_bert(batch['text'])
        calories = batch['calories'].to(device)

        optimizer.zero_grad()

        # BERT needs input_ids and attention_mask. Flatten to (batch,) so the
        # prediction always matches the target's shape, even for a final
        # batch of size one.
        outputs = model(images, text_encoding['input_ids'], text_encoding['attention_mask']).view(-1)
        loss = criterion(outputs, calories)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        n_batches += 1

    return running_loss / max(n_batches, 1)

# Validation function
def validate(model, val_loader, criterion, device):
    """Evaluate `model` on `val_loader` without gradient tracking.

    Returns:
        (avg_loss, mae, predictions, targets):
        avg_loss is the criterion averaged over samples (each batch weighted
        by its size; assumes a mean-reduced criterion such as nn.MSELoss()),
        mae is the mean absolute error, and predictions / targets are 1-D
        numpy arrays in loader order.

    Raises:
        ValueError: if the loader yields no samples.
    """
    model.eval()
    weighted_loss = 0.0
    n_samples = 0
    prediction_chunks = []
    target_chunks = []

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validation"):
            images = batch['image'].to(device)
            text_encoding = process_text_batch_bert(batch['text'])
            calories = batch['calories'].to(device)

            # Flatten to (batch,): a batch of one must not become a 0-d
            # tensor, which cannot be iterated/concatenated below.
            outputs = model(images, text_encoding['input_ids'], text_encoding['attention_mask']).view(-1)
            loss = criterion(outputs, calories)

            batch_size = calories.size(0)
            weighted_loss += loss.item() * batch_size
            n_samples += batch_size

            prediction_chunks.append(outputs.cpu().numpy())
            target_chunks.append(calories.cpu().numpy())

    if n_samples == 0:
        raise ValueError("validate() received an empty loader")

    avg_loss = weighted_loss / n_samples
    all_predictions = np.concatenate(prediction_chunks)
    all_targets = np.concatenate(target_chunks)
    mae = np.mean(np.abs(all_predictions - all_targets))

    return avg_loss, mae, all_predictions, all_targets

print("=== TRAINING & VALIDATION FUNCTIONS DEFINED ===", "Ready to start training!", sep="\n")

In [None]:
# Smart checkpoint handler: offers to reuse an existing best BERT model
print("=" * 70)
print("CHECKING FOR EXISTING CHECKPOINTS")
print("=" * 70)

# Check if best BERT model exists
BERT_CHECKPOINT_FILE = os.path.join(CHECKPOINT_PATH, 'bert_latest_checkpoint.pth')
BERT_BEST_MODEL_FILE = os.path.join(CHECKPOINT_PATH, 'bert_best_model.pth')
if os.path.exists(BERT_BEST_MODEL_FILE):
    print(f"\n‚úÖ Found existing BERT model!")

    checkpoint = torch.load(BERT_BEST_MODEL_FILE, map_location=device, weights_only=False)

    # 'val_loss' is the validation MSE saved with that checkpoint (not an MAE)
    best_mae = checkpoint.get('val_loss', 'unknown')
    print(f"   Loaded model from epoch {checkpoint['epoch']}")
    print(f"   Best validation loss: {best_mae}")

    # Ask user what to do
    print(f"\nü§î What would you like to do?")
    print(f"   1. Continue fine-tuning (recommended)")
    print(f"   2. Retrain from scratch (will overwrite)")

    choice = input("\nEnter choice (1 or 2): ").strip()

    if choice == '1':
        # Apply the saved weights only when reusing them; option 2 must start
        # from the freshly initialised model rather than silently continuing
        # from the checkpoint.
        model.load_state_dict(checkpoint['model_state_dict'])
        print("\n‚úÖ Will use existing model for fine-tuning")
        SKIP_INITIAL_TRAINING = True
    else:
        print("\n‚ö†Ô∏è Will retrain from scratch (this will take time!)")
        SKIP_INITIAL_TRAINING = False
else:
    print(f"\n‚ùå No existing BERT model found")
    print(f"   Will train from scratch")
    SKIP_INITIAL_TRAINING = False

print("=" * 70)

In [None]:
# Initial BERT training run, skipped when the checkpoint handler chose to reuse
# an existing best model. BERT checkpoints use their own file names so older
# LSTM checkpoints stay safe.
if SKIP_INITIAL_TRAINING:
    print("‚è≠Ô∏è Skipping initial training - using loaded model")
else:
    # Start fresh with BERT
    start_epoch = 0
    BEST_VAL_LOSS = float('inf')
    PATIENCE_COUNTER = 0  # epochs since the last validation-loss improvement
    history = {
        'train_loss': [],
        'val_loss': [],
        'val_mae': []
    }

    for epoch in range(start_epoch, NUM_EPOCHS):
        print(f"Epoch [{epoch+1}/{NUM_EPOCHS}]")

        train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_mae, val_preds, val_targets = validate(model, val_loader, criterion, device)

        # Update learning rate scheduler
        scheduler.step(val_loss)

        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['val_mae'].append(val_mae)

        print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val MAE: {val_mae:.2f} cal")
        print(f"Current LR: {optimizer.param_groups[0]['lr']:.6f}")

        # Update best score / patience BEFORE saving, so the "latest"
        # checkpoint records the current best_val_loss (load_checkpoint reads
        # it back when resuming).
        improved = val_loss < BEST_VAL_LOSS
        if improved:
            BEST_VAL_LOSS = val_loss
            PATIENCE_COUNTER = 0
        else:
            PATIENCE_COUNTER += 1

        # Save latest checkpoint
        save_checkpoint(
            epoch=epoch+1,
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
            train_loss=train_loss,
            val_loss=val_loss,
            best_val_loss=BEST_VAL_LOSS,
            filepath=BERT_CHECKPOINT_FILE
        )

        if improved:
            save_checkpoint(
                epoch=epoch+1,
                model=model,
                optimizer=optimizer,
                scheduler=scheduler,
                train_loss=train_loss,
                val_loss=val_loss,
                best_val_loss=BEST_VAL_LOSS,
                filepath=BERT_BEST_MODEL_FILE
            )
            print(f"‚úì New best model saved! Val Loss: {BEST_VAL_LOSS:.4f}")
        else:
            print(f"No improvement. Patience: {PATIENCE_COUNTER}/{PATIENCE}")

        # Early stopping: PATIENCE was configured and reported but never enforced
        if PATIENCE_COUNTER >= PATIENCE:
            print(f"\nEarly stopping at epoch {epoch+1}")
            break

    print("\n=== TRAINING COMPLETED ===")
    print(f"Best Validation Loss: {BEST_VAL_LOSS:.4f}")

In [None]:
# Unfreeze every BERT layer for full fine-tuning, reporting parameter counts
# before and after.
print("=== UNFREEZING BERT LAYERS ===\n")

total_params = sum(p.numel() for p in model.parameters())
trainable_before = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Before unfreezing:")
print(f"  Trainable parameters: {trainable_before:,}")
print(f"  Total parameters: {total_params:,}")
print(f"  Frozen: {total_params - trainable_before:,}")

# requires_grad_ flips the flag on every parameter of the BERT submodule
model.bert.requires_grad_(True)

trainable_after = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\nAfter unfreezing:")
print(f"  Trainable parameters: {trainable_after:,}")
print(f"  Total parameters: {total_params:,}")
print(f"  Newly unfrozen: {trainable_after - trainable_before:,}")

print("\n‚úÖ BERT layers unfrozen and ready for fine-tuning!")

In [None]:
# Create new optimizer for fine-tuning BERT
print("=== SETTING UP FINE-TUNING OPTIMIZER ===\n")

# Use a much lower learning rate now that all of BERT is trainable
FINETUNE_LR = 0.00001  # 100x smaller than the initial lr of 0.001

optimizer_finetune = optim.Adam(model.parameters(), lr=FINETUNE_LR)

# Halve the LR when validation loss stalls for 3 epochs
scheduler_finetune = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_finetune, mode='min', factor=0.5, patience=3
)

print(f"Fine-tuning learning rate: {FINETUNE_LR}")
print(f"Optimizer: Adam")
print(f"Scheduler: ReduceLROnPlateau")

# Load the best BERT model before fine-tuning. load_state_dict copies values
# into the existing parameter tensors, so the optimizer created above still
# tracks the model's parameters.
best_bert_path = os.path.join(CHECKPOINT_PATH, 'bert_best_model.pth')
checkpoint = torch.load(best_bert_path, map_location=device, weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])

# Report what was actually loaded rather than a hard-coded metric from an earlier run
print(f"\n‚úÖ Loaded best BERT model (epoch {checkpoint['epoch']}, val loss {checkpoint['val_loss']:.4f})")
print(f"‚úÖ Ready to fine-tune with unfrozen BERT!")

In [None]:
# Fine-tuning training loop (all BERT layers trainable, low learning rate)
print("=== STARTING BERT FINE-TUNING ===\n")

# Setup for fine-tuning
FINETUNE_EPOCHS = 30
PATIENCE_FT = 10
PATIENCE_COUNTER_FT = 0  # epochs since the last validation-MAE improvement

# New checkpoint paths for fine-tuned model
BERT_FINETUNED_CHECKPOINT = os.path.join(CHECKPOINT_PATH, 'bert_finetuned_latest.pth')
BERT_FINETUNED_BEST = os.path.join(CHECKPOINT_PATH, 'bert_finetuned_best.pth')

# History for fine-tuning
finetune_history = {
    'train_loss': [],
    'val_loss': [],
    'val_mae': []
}

# Baseline: validation MAE of the model as it stands now (the best
# pre-fine-tuning BERT model). It is measured here instead of being
# hard-coded from an earlier run. A fine-tuned checkpoint is only saved when
# it beats this value.
_, BASELINE_VAL_MAE, _, _ = validate(model, val_loader, criterion, device)
BEST_FINETUNE_VAL_MAE = BASELINE_VAL_MAE
BEST_FINETUNE_VAL_LOSS = float('inf')  # val loss at the best-MAE epoch

print(f"Fine-tuning for {FINETUNE_EPOCHS} epochs")
print(f"Starting from best BERT model ({BASELINE_VAL_MAE:.2f} cal MAE)")
print(f"Learning rate: {FINETUNE_LR}")
print("\n" + "="*70 + "\n")

for epoch in range(FINETUNE_EPOCHS):
    print(f"Fine-tune Epoch [{epoch+1}/{FINETUNE_EPOCHS}]")

    train_loss = train_one_epoch(model, train_loader, criterion, optimizer_finetune, device)
    val_loss, val_mae, val_preds, val_targets = validate(model, val_loader, criterion, device)

    # Update scheduler
    scheduler_finetune.step(val_loss)

    finetune_history['train_loss'].append(train_loss)
    finetune_history['val_loss'].append(val_loss)
    finetune_history['val_mae'].append(val_mae)

    print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val MAE: {val_mae:.2f} cal")
    print(f"Current LR: {optimizer_finetune.param_groups[0]['lr']:.6f}")

    # Improvement is judged on validation MAE alone. The previous mix (MAE
    # threshold to save, MSE for "best", patience only bumped when MAE missed
    # the threshold) could leave the patience counter stuck and never stop.
    previous_best_mae = BEST_FINETUNE_VAL_MAE
    improved = val_mae < BEST_FINETUNE_VAL_MAE
    if improved:
        BEST_FINETUNE_VAL_MAE = val_mae
        BEST_FINETUNE_VAL_LOSS = val_loss
        PATIENCE_COUNTER_FT = 0
    else:
        PATIENCE_COUNTER_FT += 1

    # Save latest checkpoint
    save_checkpoint(
        epoch=epoch+1,
        model=model,
        optimizer=optimizer_finetune,
        scheduler=scheduler_finetune,
        train_loss=train_loss,
        val_loss=val_loss,
        best_val_loss=BEST_FINETUNE_VAL_LOSS,
        filepath=BERT_FINETUNED_CHECKPOINT
    )

    if improved:
        save_checkpoint(
            epoch=epoch+1,
            model=model,
            optimizer=optimizer_finetune,
            scheduler=scheduler_finetune,
            train_loss=train_loss,
            val_loss=val_loss,
            best_val_loss=BEST_FINETUNE_VAL_LOSS,
            filepath=BERT_FINETUNED_BEST
        )
        print(f"üéâ NEW BEST! Val MAE: {val_mae:.2f} cal (improved from {previous_best_mae:.2f})")
    else:
        print(f"No improvement. Patience: {PATIENCE_COUNTER_FT}/{PATIENCE_FT}")

    # Early stopping
    if PATIENCE_COUNTER_FT >= PATIENCE_FT:
        print(f"\nEarly stopping at epoch {epoch+1}")
        break

    print("-" * 70)

best_finetuned_mae = min(finetune_history['val_mae'])
print("\n" + "="*70)
print("=== FINE-TUNING COMPLETED ===")
print(f"Best Fine-tuned MAE: {best_finetuned_mae:.2f} cal")
print(f"Original BERT MAE: {BASELINE_VAL_MAE:.2f} cal")
print(f"Improvement: {BASELINE_VAL_MAE - best_finetuned_mae:.2f} cal")
print("="*70)

In [None]:
# Load the best available checkpoint for test evaluation. The fine-tuned
# checkpoint is only written when fine-tuning beat the pre-fine-tuning model's
# validation MAE, so prefer it when present; otherwise fall back to the best
# initially-trained BERT model.
# NOTE(review): a stale fine-tuned file from an earlier session would also be picked up.
finetuned_best_path = os.path.join(CHECKPOINT_PATH, 'bert_finetuned_best.pth')
initial_best_path = os.path.join(CHECKPOINT_PATH, 'bert_best_model.pth')
best_model_path = finetuned_best_path if os.path.exists(finetuned_best_path) else initial_best_path

# weights_only=False: trusted, self-produced checkpoint that also holds optimizer/scheduler state
checkpoint = torch.load(best_model_path, map_location=device, weights_only=False)

model.load_state_dict(checkpoint['model_state_dict'])
print(f"Best model loaded from epoch {checkpoint['epoch']} ({os.path.basename(best_model_path)})")
print(f"Best validation loss: {checkpoint['best_val_loss']:.4f}")

# Evaluate on test set
print("\n=== EVALUATING ON TEST SET ===")
test_loss, test_mae, test_preds, test_targets = validate(model, test_loader, criterion, device)

errors = test_targets - test_preds
# RMSE from per-sample errors; sqrt(test_loss) can differ when the loss is
# averaged per batch and the last batch is smaller.
test_rmse = np.sqrt(np.mean(errors ** 2))

print(f"\nTest Loss (MSE): {test_loss:.4f}")
print(f"Test MAE: {test_mae:.2f} calories")
print(f"Test RMSE: {test_rmse:.2f} calories")

# Additional metrics (targets are all > 0 because the dataset keeps only positive calories)
mape = np.mean(np.abs(errors / test_targets)) * 100
r2_score = 1 - (np.sum(errors**2) / np.sum((test_targets - np.mean(test_targets))**2))

print(f"Test MAPE: {mape:.2f}%")
print(f"Test R¬≤ Score: {r2_score:.4f}")

# Visualize predictions vs actual
plt.figure(figsize=(15, 5))

# Scatter plot
plt.subplot(1, 3, 1)
plt.scatter(test_targets, test_preds, alpha=0.5)
plt.plot([test_targets.min(), test_targets.max()],
         [test_targets.min(), test_targets.max()],
         'r--', linewidth=2, label='Perfect Prediction')
plt.xlabel('Actual Calories')
plt.ylabel('Predicted Calories')
plt.title('Predictions vs Actual (Test Set)')
plt.legend()
plt.grid(True, alpha=0.3)

# Error distribution
plt.subplot(1, 3, 2)
plt.hist(errors, bins=50, edgecolor='black')
plt.xlabel('Prediction Error (calories)')
plt.ylabel('Frequency')
plt.title('Distribution of Prediction Errors')
plt.axvline(x=0, color='r', linestyle='--', linewidth=2)
plt.grid(True, alpha=0.3)

# Absolute error distribution
plt.subplot(1, 3, 3)
abs_errors = np.abs(errors)
plt.hist(abs_errors, bins=50, edgecolor='black', color='orange')
plt.xlabel('Absolute Error (calories)')
plt.ylabel('Frequency')
plt.title('Absolute Error Distribution')
plt.axvline(x=test_mae, color='r', linestyle='--', linewidth=2, label=f'MAE: {test_mae:.1f}')
plt.legend()
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Get some example predictions: show random test dishes in a 3x3 grid with
# actual vs. predicted calories and the (truncated) ingredient text.
model.eval()
num_examples = 9

# Get random samples from test set
# NOTE(review): replace=False raises ValueError if the test set has fewer than
# num_examples rows, and the 3x3 grid assumes exactly 9 examples.
sample_indices = np.random.choice(len(test_dataset), num_examples, replace=False)

fig, axes = plt.subplots(3, 3, figsize=(18, 15))
axes = axes.flatten()

with torch.no_grad():
    for idx, sample_idx in enumerate(sample_indices):
        sample = test_dataset[sample_idx]

        # Prepare inputs (unsqueeze adds a batch dimension of 1 to the image)
        image = sample['image'].unsqueeze(0).to(device)
        text_encoding = process_text_batch_bert([sample['text']])

        # Get prediction (BERT needs input_ids AND attention_mask)
        prediction = model(image, text_encoding['input_ids'], text_encoding['attention_mask']).item()
        actual = sample['calories'].item()

        # Load and display the raw image file (not the normalised tensor).
        # test_dataset reset its index, so positional iloc on test_df lines up.
        img_path = test_df.iloc[sample_idx]['image_path']
        img = Image.open(img_path)

        axes[idx].imshow(img)
        axes[idx].axis('off')

        # Add prediction info; title colour encodes absolute error
        # (green < 50, orange < 100, red otherwise)
        error = abs(prediction - actual)
        title = f"Actual: {actual:.0f} cal\nPredicted: {prediction:.0f} cal\nError: {error:.0f} cal"
        color = 'green' if error < 50 else 'orange' if error < 100 else 'red'
        axes[idx].set_title(title, fontsize=10, color=color, weight='bold')

        # Show ingredients (truncated to 80 characters) under the image
        ingredients = sample['text'][:80] + "..." if len(sample['text']) > 80 else sample['text']
        axes[idx].text(0.5, -0.05, ingredients, transform=axes[idx].transAxes,
                      fontsize=8, ha='center', va='top', style='italic')

plt.tight_layout()
plt.show()

print("=== PREDICTION EXAMPLES ===")
print("üü¢ Green: Error < 50 cal (Good)")
print("üü† Orange: Error 50-100 cal (Acceptable)")
print("üî¥ Red: Error > 100 cal (Needs improvement)")

In [None]:
# Analyze performance across calorie ranges; each range is [min_cal, max_cal).
calorie_ranges = [
    (0, 200, 'Low (0-200)'),
    (200, 400, 'Medium (200-400)'),
    (400, 600, 'High (400-600)'),
    # Open-ended upper bound: the old cap of 1000 silently excluded every
    # dish of 1000+ calories despite the "600+" label.
    (600, np.inf, 'Very High (600+)')
]

print("=== PERFORMANCE BY CALORIE RANGE ===\n")

range_stats = []

for min_cal, max_cal, label in calorie_ranges:
    mask = (test_targets >= min_cal) & (test_targets < max_cal)
    if np.sum(mask) > 0:
        range_targets = test_targets[mask]
        range_preds = test_preds[mask]

        range_mae = np.mean(np.abs(range_targets - range_preds))
        range_mape = np.mean(np.abs((range_targets - range_preds) / range_targets)) * 100
        range_rmse = np.sqrt(np.mean((range_targets - range_preds)**2))

        print(f"{label}:")
        print(f"  Samples: {np.sum(mask)}")
        print(f"  MAE: {range_mae:.2f} cal")
        print(f"  RMSE: {range_rmse:.2f} cal")
        print(f"  MAPE: {range_mape:.2f}%")
        print()

        range_stats.append({
            'range': label,
            'samples': np.sum(mask),
            'mae': range_mae,
            'mape': range_mape
        })

# Visualize performance by range
range_df = pd.DataFrame(range_stats)

if range_df.empty:
    print("No test samples fell into any calorie range; nothing to plot.")
else:
    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.bar(range_df['range'], range_df['mae'], color='steelblue', edgecolor='black')
    plt.ylabel('MAE (calories)')
    plt.title('Mean Absolute Error by Calorie Range')
    plt.xticks(rotation=45, ha='right')
    plt.grid(True, alpha=0.3, axis='y')

    plt.subplot(1, 2, 2)
    plt.bar(range_df['range'], range_df['samples'], color='orange', edgecolor='black')
    plt.ylabel('Number of Samples')
    plt.title('Sample Distribution by Calorie Range')
    plt.xticks(rotation=45, ha='right')
    plt.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    plt.show()

In [None]:
# Function to predict calories for a new dish image
def predict_calories_bert(image_path, ingredients_text, model, transform, max_text_length=100):
    """Predict the calories of one dish from its photo and ingredient text.

    Args:
        image_path: path to the dish photo (converted to RGB before `transform`).
        ingredients_text: free-text ingredient list.
        model: model called as model(images, input_ids, attention_mask).
        transform: callable turning the PIL image into a tensor (a batch
            dimension is added here).
        max_text_length: characters of lower-cased text to keep. The default
            of 100 matches FoodCalorieDataset's default; None keeps it all.

    Returns:
        The predicted calories as a Python number.
    """
    model.eval()

    # Load and transform image; unsqueeze adds the batch dimension
    image = Image.open(image_path).convert('RGB')
    image_tensor = transform(image).unsqueeze(0).to(device)

    # Apply the same text preprocessing the model saw in training
    # (lower-case, clipped to max_text_length characters) before tokenising.
    text = str(ingredients_text).lower()[:max_text_length]
    text_encoding = process_text_batch_bert([text])

    with torch.no_grad():
        prediction = model(
            image_tensor,
            text_encoding['input_ids'],
            text_encoding['attention_mask']
        ).item()

    return prediction

# Example usage - Test on a random test image
sample_idx = np.random.randint(0, len(test_df))
sample_row = test_df.iloc[sample_idx]

predicted = predict_calories_bert(
    image_path=sample_row['image_path'],
    ingredients_text=sample_row['text'],
    model=model,
    transform=val_transform
)

# Only append an ellipsis when the ingredient text was actually truncated.
ingredients_preview = sample_row['text']
if len(ingredients_preview) > 100:
    ingredients_preview = ingredients_preview[:100] + '...'

# Display result; the context manager closes the image file once plotted.
with Image.open(sample_row['image_path']) as img:
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(img)
    ax.axis('off')
    ax.set_title(f"Predicted: {predicted:.0f} cal | Actual: {sample_row['calories']:.0f} cal\n"
                 f"Ingredients: {ingredients_preview}", fontsize=10)
    plt.show()

print("✅ Inference function ready!")
print("\nTo predict on a new image:")
# The hint must match the real signature defined above (no tokenizer argument).
print("  calories = predict_calories_bert(image_path, ingredients_text, model, val_transform)")

In [None]:
# ========================================
# PREDICT ON YOUR OWN IMAGE
# ========================================

import textwrap

from google.colab import files
from IPython.display import display

print("=" * 70)
print("UPLOAD YOUR OWN FOOD IMAGE")
print("=" * 70)

# Step 1: Upload your image. The returned dict is keyed by the uploaded
# filenames, which are opened as paths below.
print("\n📤 Click 'Choose Files' to upload your food image...")
uploaded = files.upload()

# Fail with a clear message instead of an opaque IndexError if nothing was uploaded.
if not uploaded:
    raise ValueError("No image was uploaded - re-run this cell and choose a file.")

# Get the uploaded file name (only the first file is used)
image_filename = next(iter(uploaded))
if len(uploaded) > 1:
    print(f"⚠️ {len(uploaded)} files uploaded; using only '{image_filename}'.")
print(f"\n✅ Image uploaded: {image_filename}")

# Step 2: Enter ingredients
print("\n" + "=" * 70)
ingredients_input = input("📝 Enter ingredients (comma-separated): ").strip()
# Example: rice, chicken, vegetables, olive oil
if not ingredients_input:
    print("⚠️ No ingredients entered; predicting with an empty ingredient text.")

# Step 3: Make prediction
print("\n🔮 Making prediction...")

predicted_calories = predict_calories_bert(
    image_path=image_filename,  # The uploaded file
    ingredients_text=ingredients_input,
    model=model,
    transform=val_transform
)

# Step 4: Display result. The title has no emoji because matplotlib's default
# font typically has no emoji glyphs (it would render a placeholder box and warn).
# Long ingredient lists are wrapped so they are not clipped off the figure.
with Image.open(image_filename) as img:
    fig, ax = plt.subplots(figsize=(10, 8))
    ax.imshow(img)
    ax.axis('off')
    ax.set_title(f"Predicted Calories: {predicted_calories:.0f} cal\n"
                 f"Ingredients: {textwrap.fill(ingredients_input, width=70)}",
                 fontsize=14, weight='bold', color='darkgreen')
    plt.show()

print("\n" + "=" * 70)
print("✅ PREDICTION RESULT")
print("=" * 70)
print(f"Ingredients: {ingredients_input}")
print(f"Predicted Calories: {predicted_calories:.0f} cal")
print("=" * 70)