In [None]:
#MAIN

import os
import pandas as pd
import torch
import torch.nn as nn
import torch.utils.data as data
from collections import deque
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torchvision.io import decode_image
from tqdm.notebook import tqdm
import torchvision.models as models
from torchvision.models import ResNet50_Weights
from sklearn.metrics import mean_absolute_error
import torch.optim as optim
import timm
import matplotlib.pyplot as plt
import numpy as np
import sys
import pickle
import h5py

from google.colab import auth, drive

# Authenticate the Colab user and (re)mount Google Drive so the dataset is reachable.
auth.authenticate_user()
drive.mount('/content/drive', force_remount=True)

# Dataset locations on the mounted Drive.
img_folder = "/content/drive/MyDrive/kevin/nutrition5k_cleaned/images/"
csv_file = "/content/drive/MyDrive/kevin/nutrition5k_cleaned/labels.csv"

# Record library versions so results can be reproduced later.
for lib_label, lib_version in (
    ('Pytorch version: ', torch.__version__),
    ('Pandas Version: ', pd.__version__),
    ('Numpy Version: ', np.__version__),
    ('Timm Version: ', timm.__version__),
):
    print(lib_label, lib_version)

# Quick look at the label file.
df = pd.read_csv(csv_file)
print(df.head())
print(df.columns)

os.listdir('/content')

Mounted at /content/drive
Pytorch version:  2.8.0+cu126
Pandas Version:  2.2.2
Numpy Version:  2.0.2
Timm Version:  1.0.20
   id  original_dish_id   calories
0 NaN        1556572657  41.399998
1 NaN        1556573514   6.440000
2 NaN        1556575014  71.299995
3 NaN        1556575083  27.520000
4 NaN        1556575124   4.480000
Index(['id', 'original_dish_id', 'calories'], dtype='object')


['.config', 'drive', 'sample_data']

In [None]:
# Ensure Google Drive is mounted at /content/drive; the training cell reads data and checkpoints from it.
import google.colab.drive as drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Running this cell trains the model and saves the best checkpoint to a .pt file that the front-end app loads.


import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from sklearn.metrics import mean_absolute_error
from tqdm.notebook import tqdm
import os
import pandas as pd
from torchvision.io import decode_image
from torchvision.transforms import Compose, Resize, ConvertImageDtype, Normalize, Grayscale
import torch.utils.data as data

# === CNN Model for Multi-Input Regression ===
class MultiInputSnapCalCNN(nn.Module):
    """Three-branch CNN that regresses a calorie value from RGB, heat and depth images.

    Inputs: ``x_rgb`` [B, 3, H, W], ``x_heat`` [B, 1, H, W], ``x_depth`` [B, 1, H, W].
    Output: [B, 1] calorie prediction.

    Heat and depth share the ``features_mono`` stem; each modality then has its own
    two-stage continuation. Every stage is Conv3x3 -> BatchNorm -> ReLU -> MaxPool(2),
    so the four stages on each path shrink the spatial size by 16 (448 -> 28). Each
    branch is global-average-pooled, and the 256 + 128 + 128 features are concatenated
    and fed to the MLP head.

    The submodule attribute names and layer order define the checkpoint's state_dict
    keys and the parameter order that saved optimizer state relies on. Keep them unchanged.
    """

    def __init__(self):
        super().__init__()

        def conv_stages(*widths):
            # widths like (in, mid, out) -> one Conv/BN/ReLU/MaxPool stage per consecutive pair
            layers = []
            for c_in, c_out in zip(widths[:-1], widths[1:]):
                layers += [
                    nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                    nn.BatchNorm2d(c_out),
                    nn.ReLU(),
                    nn.MaxPool2d(2),
                ]
            return nn.Sequential(*layers)

        # Creation order matches the original so parameter order and RNG-driven init are identical.
        self.features_rgb = conv_stages(3, 32, 64)           # 448 -> 112
        self.features_mono = conv_stages(1, 16, 32)          # shared heat/depth stem, 448 -> 112
        self.features_rgb_cont = conv_stages(64, 128, 256)   # 112 -> 28
        self.features_heat_cont = conv_stages(32, 64, 128)   # 112 -> 28
        self.features_depth_cont = conv_stages(32, 64, 128)  # 112 -> 28

        fused_width = 256 + 128 + 128  # channels per branch after global average pooling
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(fused_width, 256),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, 1),  # calories
        )

        # Global average pooling applied to each branch before fusion.
        self.gap = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x_rgb, x_heat, x_depth):
        branches = (
            (x_rgb, self.features_rgb, self.features_rgb_cont),
            (x_heat, self.features_mono, self.features_heat_cont),
            (x_depth, self.features_mono, self.features_depth_cont),
        )
        # Each branch: stem -> continuation -> GAP -> [B, C]
        pooled = [
            torch.flatten(self.gap(tail(stem(x))), start_dim=1)
            for x, stem, tail in branches
        ]
        return self.classifier(torch.cat(pooled, dim=1))


# Define device
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Dataset locations on the mounted Google Drive (repeated so this cell is self-contained).
img_folder = "/content/drive/MyDrive/kevin/nutrition5k_cleaned/images/"
csv_file = "/content/drive/MyDrive/kevin/nutrition5k_cleaned/labels.csv"

# RGB pipeline: resize to the 448x448 input used throughout this cell, keep float32,
# then standardise with ImageNet channel statistics.
transform_rgb = Compose([
    Resize((448, 448)),
    ConvertImageDtype(torch.float32),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Single-channel (heat / depth) pipeline: same resize, no normalisation applied.
mono_transform = Compose([
    Resize((448, 448)),
    ConvertImageDtype(torch.float32),
])


class NutritionDataset(data.Dataset):
    """Dataset yielding ``(rgb, heat, depth, calories, dish_id)`` samples.

    Each CSV row must provide ``original_dish_id`` and ``calories``. For dish ``<id>``
    the images are read from ``img_dir``:

    * ``<id>_rgb.png`` if present, otherwise the colour channels of ``<id>_rgbd.png`` -> RGB input
    * ``<id>_gray.png`` -> single-channel "heat" input
    * ``<id>_rgbd.png`` -> single-channel "depth" input (4th channel when there are more than 3)

    Rows with a missing id or label, non-positive calories, or missing ``_rgbd.png`` /
    ``_gray.png`` files are dropped up front. If a sample still fails to load, a
    zero-filled placeholder with label 0 and dish_id 0 is returned so iteration continues.

    Args:
        csv_file: path to the label CSV.
        img_dir: directory containing the PNG files.
        transform_rgb: optional transform applied to the [3, H, W] float RGB tensor.
        mono_transform: optional transform applied to each [1, H, W] heat/depth tensor.
    """

    def __init__(self, csv_file, img_dir, transform_rgb=None, mono_transform=None):
        labels = pd.read_csv(csv_file)

        # Drop rows with NaN in 'original_dish_id' or 'calories', then keep positive calories only
        labels = labels.dropna(subset=['original_dish_id', 'calories'])
        labels = labels[labels['calories'] > 0]

        self.img_dir = img_dir
        self.transform_rgb = transform_rgb
        self.mono_transform = mono_transform

        # Keep only dishes whose _rgbd.png and _gray.png both exist
        initial_count = len(labels)
        has_files = labels['original_dish_id'].apply(
            lambda x: os.path.exists(os.path.join(self.img_dir, f"{int(x)}_rgbd.png")) and
                      os.path.exists(os.path.join(self.img_dir, f"{int(x)}_gray.png"))
        ).astype(bool)  # explicit bool so an empty frame is still treated as a row mask
        self.labels = labels[has_files].reset_index(drop=True)

        filtered_count = len(self.labels)
        print(f"Labels shape after filtering (NaN, positive calories, and _rgbd.png + _gray.png existence): {self.labels.shape}")
        print(f"Filtered out {initial_count - filtered_count} samples due to missing files during initial check.")

    def __len__(self):
        return len(self.labels)

    @staticmethod
    def _to_unit_float(image):
        """Convert a decoded image tensor to float32 in [0, 1] using its dtype's full range."""
        if image.is_floating_point():
            return image.float()
        # uint8 -> /255 as before; 16-bit PNGs -> /65535 instead of exceeding 1.0
        return image.float() / torch.iinfo(image.dtype).max

    @staticmethod
    def _as_rgb(image):
        """Return a 3-channel image: drop alpha/extra channels, replicate a gray channel."""
        if image.shape[0] < 3:
            # Gray or gray+alpha: replicate the gray channel to RGB
            return image[:1].repeat(3, 1, 1)
        return image[:3]

    @staticmethod
    def _heat_to_mono(image):
        """Reduce a decoded heat image to a single channel."""
        if image.shape[0] >= 3:
            # RGB or RGBA: luminance of the colour channels; alpha is ignored
            return Grayscale(num_output_channels=1)(image[:3])
        # Gray passes through unchanged; gray+alpha keeps only the gray channel
        return image[:1]

    @staticmethod
    def _placeholder():
        """Zero-filled sample returned when a dish cannot be loaded."""
        return (
            torch.zeros(3, 448, 448, dtype=torch.float32),
            torch.zeros(1, 448, 448, dtype=torch.float32),
            torch.zeros(1, 448, 448, dtype=torch.float32),
            torch.tensor([0.0], dtype=torch.float32),
            0,  # plain int, same type as real dish ids, so default_collate can batch mixed samples
        )

    def __getitem__(self, idx):
        row = self.labels.iloc[idx]
        dish_id = int(row['original_dish_id'])

        img_path_rgbd = os.path.join(self.img_dir, f"{dish_id}_rgbd.png")
        img_path_gray = os.path.join(self.img_dir, f"{dish_id}_gray.png")
        img_path_rgb = os.path.join(self.img_dir, f"{dish_id}_rgb.png")  # optional separate RGB file

        try:
            # RGB: prefer a separate _rgb.png, otherwise the colour channels of _rgbd.png.
            # The decoded _rgbd.png is kept so the depth step does not decode it a second time.
            image_rgbd_full = None
            try:
                if os.path.exists(img_path_rgb):
                    image_rgb = self._to_unit_float(decode_image(img_path_rgb))
                else:
                    image_rgbd_full = self._to_unit_float(decode_image(img_path_rgbd))
                    image_rgb = image_rgbd_full
                image_rgb = self._as_rgb(image_rgb)
            except Exception as e:
                print(f"Warning: Error loading or processing RGB image for Dish ID {dish_id} from {img_path_rgb} or {img_path_rgbd}: {e}")
                return self._placeholder()

            try:
                # The _gray.png image is used as the heat modality
                image_heat = self._heat_to_mono(self._to_unit_float(decode_image(img_path_gray)))
            except Exception as e:
                print(f"Warning: Error loading or processing Heat image for Dish ID {dish_id} from {img_path_gray}: {e}")
                return self._placeholder()

            try:
                if image_rgbd_full is not None:
                    image_depth = image_rgbd_full
                else:
                    image_depth = self._to_unit_float(decode_image(img_path_rgbd))
            except Exception as e:
                print(f"Warning: Error loading or processing Depth image for Dish ID {dish_id} from {img_path_rgbd}: {e}")
                return self._placeholder()

            # Reduce depth to one channel
            depth_channels = image_depth.shape[0]
            if depth_channels == 3:
                # 3-channel image: assume depth is encoded as intensity
                image_depth = Grayscale(num_output_channels=1)(image_depth)
            elif depth_channels > 3:
                # RGBD: assume depth is the 4th channel
                image_depth = image_depth[3:4, :, :]
            elif depth_channels != 1:
                print(f"Warning: Unexpected number of channels ({image_depth.shape[0]}) for depth image for Dish ID {dish_id}.")
                return self._placeholder()

            if self.transform_rgb:
                image_rgb = self.transform_rgb(image_rgb)
            if self.mono_transform:
                image_heat = self.mono_transform(image_heat)
                image_depth = self.mono_transform(image_depth)

            label = torch.tensor([row['calories']], dtype=torch.float32)
            return image_rgb, image_heat, image_depth, label, dish_id

        except Exception as e:
            print(f"Warning: An unexpected error occurred for Dish ID {dish_id}: {e}")
            return self._placeholder()


# Instantiate dataset
dataset = NutritionDataset(csv_file, img_folder, transform_rgb=transform_rgb, mono_transform=mono_transform)
print(f"Dataset size after filtering: {len(dataset)}")

if len(dataset) == 0:
    # Training cannot proceed without data. Fail here with a clear message instead of
    # hitting a NameError (or silently reusing a stale train_loader) further down.
    raise RuntimeError(f"No usable samples found in {csv_file} / {img_folder}")

# Seeded 80/20 split: every session gets the same partition, so a resumed checkpoint's
# best_val_loss stays comparable and validation dishes never leak into training.
# NOTE: checkpoints saved before the split was seeded were validated on a different split.
SPLIT_SEED = 42
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = data.random_split(
    dataset, [train_size, val_size], generator=torch.Generator().manual_seed(SPLIT_SEED)
)

# Small batches and num_workers=0 keep memory use and debugging simple
train_loader = data.DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=0)
val_loader = data.DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=0)

# === Checkpointing, model and optimisation setup ===
# Best-model checkpoints are written to, and resumed from, this Drive folder.
checkpoint_dir = "/content/drive/MyDrive/nutrition_model_checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

best_val_loss = float('inf')  # replaced by the checkpoint's value if training is resumed
best_model_path = os.path.join(checkpoint_dir, "best_multi_input_model.pt")  # distinct name for the multi-input model
start_epoch = 0  # replaced by the checkpoint's epoch + 1 if training is resumed

model = MultiInputSnapCalCNN().to(device)

# Regression loss on raw calorie labels; Adam with the LR halved every 5 epochs
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)


# === Optionally Resume from Best Checkpoint ===
# Resume from the best checkpoint if it exists and resume_from_best is True
resume_from_best = True  # Set to False if you want to start fresh
if resume_from_best and os.path.exists(best_model_path):
    import copy

    print(f"Attempting to load checkpoint from {best_model_path}")

    # Snapshot the freshly initialised states. A failed load (possibly after some
    # tensors were already copied) can then be rolled back, so "from scratch" really is.
    fresh_model_state = copy.deepcopy(model.state_dict())
    fresh_optimizer_state = copy.deepcopy(optimizer.state_dict())
    fresh_scheduler_state = copy.deepcopy(scheduler.state_dict())

    try:
        # map_location lets a checkpoint written on a GPU runtime load on a CPU-only one
        checkpoint = torch.load(best_model_path, map_location=device)

        model.load_state_dict(checkpoint['model_state_dict'])

        # Optimizer / scheduler state and bookkeeping are optional in the checkpoint
        if 'optimizer_state_dict' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if 'scheduler_state_dict' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        if 'best_val_loss' in checkpoint:
            best_val_loss = checkpoint['best_val_loss']
        if 'epoch' in checkpoint:
            start_epoch = checkpoint['epoch'] + 1  # Start from the next epoch

        print(f"✅ Loaded checkpoint from {best_model_path}")
        print(f"Resuming from epoch {start_epoch} with best validation loss {best_val_loss:.3f}")

    except Exception as e:
        print(f"❌ Error loading checkpoint: {e}")
        print("Starting training from scratch.")
        # Undo any partial load
        model.load_state_dict(fresh_model_state)
        optimizer.load_state_dict(fresh_optimizer_state)
        scheduler.load_state_dict(fresh_scheduler_state)
        best_val_loss = float('inf')
        start_epoch = 0


# === Training Loop ===
num_epochs = 15  # Total number of epochs (absolute, counting epochs completed before a resume)
epochs_to_run = num_epochs - start_epoch  # Remaining epochs

if epochs_to_run <= 0:
    print("Training already completed for the specified number of epochs.")
else:
    # NOTE: only epochs that improve best_val_loss are checkpointed, and the resume step
    # reloads that best checkpoint, so non-improving epochs are not persisted across sessions.
    for epoch in tqdm(range(start_epoch, num_epochs), desc="Training Progress"):
        # --- Training ---
        model.train()
        train_loss_sum = 0.0
        train_abs_err_sum = 0.0
        train_count = 0

        for images_rgb, images_heat, images_depth, labels, _ in tqdm(train_loader, leave=False, desc=f"Epoch {epoch+1} Training"):
            images_rgb = images_rgb.to(device)
            images_heat = images_heat.to(device)
            images_depth = images_depth.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            # Model output and labels are both [batch, 1]
            outputs = model(images_rgb, images_heat, images_depth)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Accumulate per-sample sums so a short final batch is weighted correctly
            batch_size = labels.size(0)
            train_loss_sum += loss.item() * batch_size
            train_abs_err_sum += (outputs.detach() - labels).abs().sum().item()
            train_count += batch_size

        avg_train_loss = train_loss_sum / max(train_count, 1)
        avg_train_mae = train_abs_err_sum / max(train_count, 1)

        # --- Validation ---
        model.eval()
        val_loss_sum = 0.0
        val_abs_err_sum = 0.0
        val_rel_err_sum = 0.0
        val_count = 0

        with torch.no_grad():
            for images_rgb, images_heat, images_depth, labels, _ in tqdm(val_loader, leave=False, desc=f"Epoch {epoch+1} Validation"):
                images_rgb = images_rgb.to(device)
                images_heat = images_heat.to(device)
                images_depth = images_depth.to(device)
                labels = labels.to(device)

                outputs = model(images_rgb, images_heat, images_depth)
                loss = criterion(outputs, labels)

                batch_size = labels.size(0)
                abs_err = (outputs - labels).abs()
                val_loss_sum += loss.item() * batch_size
                val_abs_err_sum += abs_err.sum().item()
                # The epsilon avoids division by zero for label-0 placeholder samples
                # (such samples will still dominate the MAPE)
                val_rel_err_sum += (abs_err / (labels + 1e-6)).sum().item()
                val_count += batch_size

        avg_val_loss = val_loss_sum / max(val_count, 1)
        avg_val_mae = val_abs_err_sum / max(val_count, 1)
        avg_val_mape = 100.0 * val_rel_err_sum / max(val_count, 1)  # fraction -> percent

        print(
            f"\nEpoch [{epoch+1}/{num_epochs}] "
            f"| Train Loss (MSE): {avg_train_loss:.3f} | Train MAE: {avg_train_mae:.2f} kcal "
            f"|| Val Loss (MSE): {avg_val_loss:.3f} | Val MAE: {avg_val_mae:.2f} kcal | Val MAPE: {avg_val_mape:.2f}%"
        )

        # --- Save best checkpoint ---
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'best_val_loss': best_val_loss,
            }
            torch.save(checkpoint, best_model_path)
            print(f"✅ Best model checkpoint saved to {best_model_path} (Val Loss: {best_val_loss:.3f})")

        scheduler.step()

Using device: cpu
Labels shape after filtering (NaN, positive calories, and _rgbd.png + _gray.png existence): (3260, 3)
Filtered out 3 samples due to missing files during initial check.
Dataset size after filtering: 3260
Attempting to load checkpoint from /content/drive/MyDrive/nutrition_model_checkpoints/best_multi_input_model.pt
✅ Loaded checkpoint from /content/drive/MyDrive/nutrition_model_checkpoints/best_multi_input_model.pt
Resuming from epoch 4 with best validation loss 16747.749


Training Progress:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 5 Training:   0%|          | 0/326 [00:00<?, ?it/s]



Epoch 5 Validation:   0%|          | 0/82 [00:00<?, ?it/s]


Epoch [5/19] | Train Loss (MSE): 22785.022 | Train MAE: 107.45 kcal || Val Loss (MSE): 53381.977 | Val MAE: 132.72 kcal | Val MAPE: 0.63%


Epoch 6 Training:   0%|          | 0/326 [00:00<?, ?it/s]



Epoch 6 Validation:   0%|          | 0/82 [00:00<?, ?it/s]


Epoch [6/19] | Train Loss (MSE): 21004.613 | Train MAE: 103.76 kcal || Val Loss (MSE): 60530.929 | Val MAE: 151.53 kcal | Val MAPE: 0.70%


Epoch 7 Training:   0%|          | 0/326 [00:00<?, ?it/s]



Epoch 7 Validation:   0%|          | 0/82 [00:00<?, ?it/s]


Epoch [7/19] | Train Loss (MSE): 20797.930 | Train MAE: 102.09 kcal || Val Loss (MSE): 74414.758 | Val MAE: 177.91 kcal | Val MAPE: 0.76%


Epoch 8 Training:   0%|          | 0/326 [00:00<?, ?it/s]



Epoch 8 Validation:   0%|          | 0/82 [00:00<?, ?it/s]


Epoch [8/19] | Train Loss (MSE): 20482.376 | Train MAE: 100.10 kcal || Val Loss (MSE): 37778.299 | Val MAE: 93.38 kcal | Val MAPE: 0.94%


Epoch 9 Training:   0%|          | 0/326 [00:00<?, ?it/s]



Epoch 9 Validation:   0%|          | 0/82 [00:00<?, ?it/s]


Epoch [9/19] | Train Loss (MSE): 18659.162 | Train MAE: 97.48 kcal || Val Loss (MSE): 45732.007 | Val MAE: 112.46 kcal | Val MAPE: 0.53%


Epoch 10 Training:   0%|          | 0/326 [00:00<?, ?it/s]



Epoch 10 Validation:   0%|          | 0/82 [00:00<?, ?it/s]


Epoch [10/19] | Train Loss (MSE): 18212.121 | Train MAE: 94.58 kcal || Val Loss (MSE): 40370.802 | Val MAE: 99.72 kcal | Val MAPE: 0.52%


Epoch 11 Training:   0%|          | 0/326 [00:00<?, ?it/s]



Epoch 11 Validation:   0%|          | 0/82 [00:00<?, ?it/s]


Epoch [11/19] | Train Loss (MSE): 17684.164 | Train MAE: 94.37 kcal || Val Loss (MSE): 56787.999 | Val MAE: 141.82 kcal | Val MAPE: 0.65%


Epoch 12 Training:   0%|          | 0/326 [00:00<?, ?it/s]



Epoch 12 Validation:   0%|          | 0/82 [00:00<?, ?it/s]


Epoch [12/19] | Train Loss (MSE): 17303.849 | Train MAE: 92.16 kcal || Val Loss (MSE): 41879.184 | Val MAE: 101.94 kcal | Val MAPE: 0.52%


Epoch 13 Training:   0%|          | 0/326 [00:00<?, ?it/s]



Epoch 13 Validation:   0%|          | 0/82 [00:00<?, ?it/s]


Epoch [13/19] | Train Loss (MSE): 16588.556 | Train MAE: 89.94 kcal || Val Loss (MSE): 47957.079 | Val MAE: 116.01 kcal | Val MAPE: 0.51%


Epoch 14 Training:   0%|          | 0/326 [00:00<?, ?it/s]



Epoch 14 Validation:   0%|          | 0/82 [00:00<?, ?it/s]


Epoch [14/19] | Train Loss (MSE): 16203.672 | Train MAE: 89.86 kcal || Val Loss (MSE): 48206.403 | Val MAE: 119.15 kcal | Val MAPE: 0.54%


Epoch 15 Training:   0%|          | 0/326 [00:00<?, ?it/s]



Epoch 15 Validation:   0%|          | 0/82 [00:00<?, ?it/s]


Epoch [15/19] | Train Loss (MSE): 15932.106 | Train MAE: 88.71 kcal || Val Loss (MSE): 51553.917 | Val MAE: 128.20 kcal | Val MAPE: 0.57%


In [None]:
#FRONT END

# Removed pip install matplotlib here
import streamlit as st
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import io
import pickle
import numpy as np


# NOTE(review): in a notebook kernel (outside `streamlit run`) this call likely has no
# visible effect; the generated app source below calls set_page_config itself.
st.set_page_config(
    page_title="SnapCal",
    layout="centered",
    initial_sidebar_state="auto",
)

app_py_content = r"""
import streamlit as st
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import io
import pickle
import numpy as np
import os
from torchvision.transforms import Grayscale
import cv2 # Keep cv2 import for image processing
# Removed matplotlib.pyplot import

# Snapchat-inspired Styling
st.set_page_config(page_title="SnapCal", layout="centered", initial_sidebar_state="auto")

st.markdown('''
    <style>
    .stApp > header, .stApp > div > div:first-child {
        background-color: #FFFC00;
    }
    .stApp > div > div {
        background-color: #FBBC04;
        padding: 30px;
        border-radius: 15px;
        box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
    }
    .stButton>button {
        background-color: #00FF00 !important;
        color: white !important;
        font-size: 18px !important;
        padding: 12px 28px !important;
        border-radius: 8px !important;
        border: none !important;
        cursor: pointer !important;
        margin-top: 15px;
        margin-bottom: 15px;
    }
    .stButton>button:hover {
        background-color: #00E600 !important;
    }
    h1 {
        color: #1e3a8a;
        text-align: center;
        margin-bottom: 20px;
        font-size: 900px;
        font-weight: bold;
        text-shadow: 2px 2px 4px rgba(0, 0, 0, 0.2);
        letter-spacing: 2px;
        font-italic: italic;
    }
     h2, h3, h4, h5, h6 {
        color: #1e3a8a;
        margin-top: 20px;
        margin-bottom: 10px;
    }
    .stMarkdown, .stText, .stException {
        color: #000000 !important;
        font-size: 16px;
        line-height: 1.6;
    }
    .stFileUploader label, .stCameraInput label {
        font-size: 18px;
        color: #1e3a8a;
        margin-bottom: 10px;
        display: block;
    }
    .stFileUploader div[data-testid="stFileUploaderDropzone"], .stCameraInput div[data-testid="stCameraInputButton"] {
        border: 2px dashed #1e3a8a;
        padding: 20px;
        border-radius: 8px;
        background-color: #ffffcc;
        text-align: center;
    }
     .stFileUploader div[data-testid="stFileUploaderDropzone"] p, .stCameraInput div[data-testid="stCameraInputButton"] p {
        color: #1e3a8a;
        font-size: 16px;
    }
    .stCaption {
        text-align: center;
        font-style: italic;
        color: #555555;
        margin-top: 5px;
    }
    hr {
        border-top: 2px solid #1e3a8a;
        margin-top: 25px;
        margin-bottom: 25px;
    }
    </style>
    ''', unsafe_allow_html=True)

if 'show_camera' not in st.session_state:
    st.session_state.show_camera = False

class MultiInputSnapCalCNN(nn.Module):
    def __init__(self):
        super(MultiInputSnapCalCNN, self).__init__()

        self.features_rgb = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        self.features_mono = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        self.features_rgb_cont = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        self.features_heat_cont = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        self.features_depth_cont = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        total_flattened_features = 256 + 128 + 128
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(total_flattened_features, 256),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, 1)
        )
        self.gap = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x_rgb, x_heat, x_depth):
        features_rgb = self.features_rgb(x_rgb)
        features_rgb = self.features_rgb_cont(features_rgb)

        features_heat = self.features_mono(x_heat)
        features_heat = self.features_heat_cont(features_heat)

        features_depth = self.features_mono(x_depth)
        features_depth = self.features_depth_cont(features_depth)

        features_rgb = self.gap(features_rgb)
        features_heat = self.gap(features_heat)
        features_depth = self.gap(features_depth)

        features_rgb = features_rgb.view(features_rgb.size(0), -1)
        features_heat = features_heat.view(features_heat.size(0), -1)
        features_depth = features_depth.view(features_depth.size(0), -1)

        combined_features = torch.cat((features_rgb, features_heat, features_depth), dim=1)
        output = self.classifier(combined_features)
        return output

@st.cache_resource
def load_multi_input_model(filename):
    st.info(f"Attempting to load model state dictionary from {filename}")
    try:
        model = MultiInputSnapCalCNN()
        checkpoint = torch.load(filename, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint['model_state_dict'])
        st.success("MultiInputSnapCalCNN model state dictionary loaded successfully!")
        return model
    except FileNotFoundError:
        st.error(f"Model file not found at {filename}")
        return None
    except KeyError:
        st.error(f"Checkpoint file {filename} does not contain 'model_state_dict'.")
        return None
    except Exception as e:
        st.error(f"Could not load model state dictionary: {e}")
        return None

model_state_dict_path = "/content/drive/MyDrive/nutrition_model_checkpoints/best_multi_input_model.pt"

model = load_multi_input_model(model_state_dict_path)

if model is None:
     st.stop()

transform_rgb = transforms.Compose([
    transforms.Resize((448, 448)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

mono_transform = transforms.Compose([
    transforms.Resize((448, 448)),
    transforms.ToTensor(),
])

grayscale_transform = transforms.Grayscale(num_output_channels=1)

def generate_mono_placeholders(pil_image):
    '''Synthesize single-channel "heat" and "depth" stand-ins from an RGB image.

    The model takes three inputs, but only an RGB photo is available here, so
    both extra modalities are approximated from the grayscale version of
    `pil_image`:

    * heat  - grayscale image mapped through OpenCV's COLORMAP_HOT, then
              reduced back to a single luminance channel;
    * depth - inverted, min-max normalised Sobel gradient magnitude (ksize=5).

    Args:
        pil_image: a PIL image (any mode PIL can convert to 'L').

    Returns:
        (heat_pil, depth_pil): two mode-'L' PIL images with the same width and
        height as `pil_image`.
    '''
    import cv2  # local import; harmless if app.py also imports cv2 at module level

    gray_np = np.array(pil_image.convert('L'))

    # Heat approximation.  OpenCV produces colour images in BGR channel order,
    # while Image.fromarray assumes RGB; converting with COLOR_BGR2GRAY applies
    # the luma weights (0.299 R, 0.587 G, 0.114 B) to the correct channels.
    heatmap_bgr = cv2.applyColorMap(gray_np, cv2.COLORMAP_HOT)
    heat_pil = Image.fromarray(cv2.cvtColor(heatmap_bgr, cv2.COLOR_BGR2GRAY)).convert('L')

    # Depth approximation: strong edges become dark, flat regions light.
    grad_x = cv2.Sobel(gray_np, cv2.CV_64F, 1, 0, ksize=5)
    grad_y = cv2.Sobel(gray_np, cv2.CV_64F, 0, 1, ksize=5)
    gradient_magnitude = cv2.magnitude(grad_x, grad_y)
    # Rescale to the full 0..255 range as uint8 before inverting.
    gradient_magnitude = cv2.normalize(gradient_magnitude, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
    depth_np = 255 - gradient_magnitude
    depth_pil = Image.fromarray(depth_np).convert('L')

    return heat_pil, depth_pil


# ---- Page layout -----------------------------------------------------------
st.title('SnapCal')

st.write("Upload your meal!🍴 or take a picture to estimate its calories.")

# The camera widget stays hidden until the user opts in; once set, the flag
# lives in session_state and survives later reruns.
if st.button("Camera"):
    st.session_state.show_camera = True

uploaded_file = st.file_uploader("Choose an image...", type=["png", "jpg", "jpeg"])

# session_state has no "show_camera" entry until the button above assigns it,
# and attribute access on a missing key raises.  Read it with a False default
# so the very first run (before "Camera" is clicked) does not crash.
camera_image = None
if st.session_state.get("show_camera", False):
    camera_image = st.camera_input("Or take a picture!")


# An uploaded file takes precedence over a camera snapshot when both exist.
image = None
if uploaded_file is not None:
    image = Image.open(uploaded_file).convert("RGB")
elif camera_image is not None:
    image = Image.open(camera_image).convert("RGB")


# ---- Inference --------------------------------------------------------------
# Runs only when an image is available and the user presses the button.
if image is not None and st.button('Estimate Calories'):
    try:
        # Add a leading batch dimension of 1 for each model input.
        image_rgb_tensor = transform_rgb(image).unsqueeze(0)

        # Heat and depth inputs are synthesized from the RGB photo.
        heat_pil, depth_pil = generate_mono_placeholders(image)

        image_heat_tensor = mono_transform(heat_pil).unsqueeze(0)
        image_depth_tensor = mono_transform(depth_pil).unsqueeze(0)

        model.to('cpu')
        model.eval()

        with torch.no_grad():
            prediction = model(image_rgb_tensor, image_heat_tensor, image_depth_tensor)
            # .item() requires the prediction to contain exactly one element.
            estimated_calories = prediction.item()

        st.write("---")
        st.subheader("Processed Inputs to the Model")
        # NOTE(review): use_column_width is deprecated in recent Streamlit
        # releases; confirm against the installed version.
        col1, col2, col3 = st.columns(3)
        with col1:
            st.image(image, caption="Original RGB", use_column_width=True)
        with col2:
            st.image(heat_pil, caption="Generated Heat Placeholder", use_column_width=True)
        with col3:
            st.image(depth_pil, caption="Generated Depth Placeholder", use_column_width=True)
        st.write("---")

        st.success(f"Estimated Calories: {estimated_calories:.2f} kcal")
        st.warning("Note: Heat and Depth inputs shown above are generated from the RGB image using simple image processing techniques as placeholders. Accuracy may be limited compared to using actual multi-modal data.")

    except Exception as e:
        st.error(f"Error during prediction: {e}")

"""

# Write the Streamlit app source to app.py so the launcher cell can run it.
# Python decodes source files as UTF-8 by default, and the app text contains
# non-ASCII characters (e.g. the emoji in the upload prompt).  Writing with an
# explicit UTF-8 encoding avoids depending on the platform's locale encoding.
with open("app.py", "w", encoding="utf-8") as f:
    f.write(app_py_content)



In [None]:
#Initialize Front end with port and link

# Stop any potentially lingering streamlit processes more aggressively
!kill -9 $(lsof -i :8501 -t) >/dev/null 2>&1 || true
!kill -9 $(lsof -i :8502 -t) >/dev/null 2>&1 || true
!pkill -f streamlit
print("Attempted to kill all streamlit processes.")

# Add a small delay to ensure processes are stopped
import time
time.sleep(7) # Increased sleep time

# Install necessary libraries before running the app
!pip install pyngrok streamlit opencv-python

# Add another delay after installation
time.sleep(10)


# Restart the streamlit app on a new port
!streamlit run app.py --server.port 8503 > streamlit.log 2>&1 &

from pyngrok import ngrok
import os
from google.colab import userdata
import time

# Get your authtoken from Colab secrets
NGROK_AUTH_TOKEN = userdata.get('NGROK_AUTH_TOKEN')
if NGROK_AUTH_TOKEN is None:
    print("NGROK_AUTH_TOKEN not found in Colab secrets. Please add it.")
else:
    ngrok.set_auth_token(NGROK_AUTH_TOKEN)

    print("Waiting for Streamlit app to start...")
    # Increased delay to give Streamlit more time to start and ngrok to connect
    time.sleep(15)

    try:
        # Kill existing ngrok tunnels before connecting
        ngrok.kill()
        print("Killed existing ngrok tunnels before starting a new one.")
        # Connect to the new port 8503
        public_url = ngrok.connect(8503)
        print(f"Streamlit app URL: {public_url}")
    except Exception as e:
        print(f"Error connecting ngrok: {e}")
        print("Please check if Streamlit is running and accessible on port 8503.")

Attempted to kill all streamlit processes.








Waiting for Streamlit app to start...
Killed existing ngrok tunnels before starting a new one.
Streamlit app URL: NgrokTunnel: "https://incommunicative-funnily-melda.ngrok-free.dev" -> "http://localhost:8503"
