# AI Nutritional Agent

This notebook attempts to implement the AI nutrition recommendation model proposed in the paper *AI nutrition recommendation using a deep generative model and ChatGPT* by *Ilias Papastratis, Dimitrios Konstantinidis, Petros Daras & Kosmas Dimitropoulos*.

# Imports

In [1]:
import numpy as np
import pandas as pd

import torch
from torch import nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from torchinfo import summary

In [2]:
# Use whatever torch.accelerator reports as the current accelerator when one is
# available; otherwise fall back to the plain string "cpu". The chosen value is
# printed so the notebook output records which device was picked.
device = torch.accelerator.current_accelerator() if torch.accelerator.is_available() else "cpu"
print(f"Using {device} device")

Using mps device


#### Data Loader

In [3]:
class MealPlanningDataset(Dataset):
    """Map-style dataset of user profiles, meal plans and nutritional targets read from a CSV.

    Each sample is a 5-tuple of tensors:
    ``(X_features, Y_meals, target_EI, min_macros, max_macros)`` with
    8, 6, 1, 4 and 4 values respectively (one per column listed below).
    """

    DEFAULT_CSV = "../datasets/synthetic_nutrition_data.csv"

    FEATURE_COLUMNS = ["weight", "height", "BMI", "BMR", "PAL", "has_CVD", "has_T2D", "has_iron_def"]
    MEAL_COLUMNS = ['meal_1', 'meal_2', 'meal_3', 'meal_4', 'meal_5', 'meal_6']
    ENERGY_COLUMNS = ['target_EI']
    MIN_MACRO_COLUMNS = ['min_prot', 'min_carb', 'min_fat', 'min_sfa']
    MAX_MACRO_COLUMNS = ['max_prot', 'max_carb', 'max_fat', 'max_sfa']

    def __init__(self, csv_file=DEFAULT_CSV):
        """Load the CSV and split it into per-role NumPy arrays.

        Args:
            csv_file: Path to the CSV file. Any falsy value (``None``, ``""``)
                falls back to ``DEFAULT_CSV``.

        Raises:
            KeyError: If the CSV lacks one of the required columns.
        """
        if not csv_file:
            csv_file = self.DEFAULT_CSV
        data: pd.DataFrame = pd.read_csv(csv_file)

        # Fail early with a readable message instead of pandas' generic KeyError.
        required = (
            self.FEATURE_COLUMNS
            + self.MEAL_COLUMNS
            + self.ENERGY_COLUMNS
            + self.MIN_MACRO_COLUMNS
            + self.MAX_MACRO_COLUMNS
        )
        missing = [column for column in required if column not in data.columns]
        if missing:
            raise KeyError(f"{csv_file} is missing required column(s): {missing}")

        self.X: np.ndarray = data[self.FEATURE_COLUMNS].values.astype('float32')

        # Explicit int64: the 'long' alias follows the platform's C long (32-bit on
        # Windows), whereas class-index targets for cross-entropy must be 64-bit.
        self.Y_meals: np.ndarray = data[self.MEAL_COLUMNS].values.astype(np.int64)

        self.target_EI: np.ndarray = data[self.ENERGY_COLUMNS].values.astype('float32')

        self.min_macros: np.ndarray = data[self.MIN_MACRO_COLUMNS].values.astype('float32')
        self.max_macros: np.ndarray = data[self.MAX_MACRO_COLUMNS].values.astype('float32')

    def __len__(self):
        """Return the number of rows (samples) loaded from the CSV."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return sample ``idx`` as five freshly copied tensors."""
        return (
            torch.tensor(self.X[idx]),           # X_features
            torch.tensor(self.Y_meals[idx]),     # Y_meals
            torch.tensor(self.target_EI[idx]),   # target_EIs
            torch.tensor(self.min_macros[idx]),  # min_macros
            torch.tensor(self.max_macros[idx])   # max_macros
        )

In [4]:
# Build the dataset from its default CSV path and wrap it in a DataLoader that
# reshuffles the samples and yields mini-batches of 64.
dataset = MealPlanningDataset()
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

In [5]:
# Check shapes
# Sanity check: take the first batch only, print the size of each of its five
# tensors, then stop iterating (hence the `break`).
for X_features, Y_meals, target_EIs, min_macros, max_macros in dataloader:
    print("User information shape : ", X_features.size())
    print("User ground meal plan shape : ", Y_meals.size())
    print("Energy target shape : ", target_EIs.size())
    print("Minimum macronutriments shape : ", min_macros.size())
    print("Maximum macronutriments shape : ", max_macros.size())
    break

User information shape :  torch.Size([64, 8])
User ground meal plan shape :  torch.Size([64, 6])
Energy target shape :  torch.Size([64, 1])
Minimum macronutriments shape :  torch.Size([64, 4])
Maximum macronutriments shape :  torch.Size([64, 4])


## Encoder

#### Code

In [6]:
class Encoder(nn.Module):
    """MLP encoder producing the mean and log-variance of the latent distribution.

    Shapes:
        input  x      : [batch_size, input_dim]
        output mu     : [batch_size, latent_dim]
        output logvar : [batch_size, latent_dim]
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()

        # Shared trunk: input_dim -> hidden_dim -> hidden_dim.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)

        # Two parallel heads on top of the trunk, each hidden_dim -> latent_dim:
        # one for the latent mean, one for the latent log-variance.
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x):
        """Encode ``x`` and return the tuple ``(mu, logvar)``."""
        hidden = x
        # Each trunk layer is followed by a ReLU.
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))

        # The heads are purely linear (no activation).
        return self.fc_mu(hidden), self.fc_logvar(hidden)

#### Summary

In [7]:
# Example dimensions used only to build an encoder for the summary below.
input_dim = 8       # same count as the user-feature columns selected by MealPlanningDataset
hidden_units = 62   # passed positionally as the Encoder's `hidden_dim`
latent_dim = 16

encoder = Encoder(input_dim, hidden_units, latent_dim)

In [8]:
# Layer-by-layer summary of the encoder for an input of size (input_dim,),
# showing input shape, output shape and parameter count per layer.
summary(encoder, (input_dim,), col_names=["input_size", "output_size", "num_params"])

Layer (type:depth-idx)                   Input Shape               Output Shape              Param #
Encoder                                  [8]                       [16]                      --
├─Linear: 1-1                            [8]                       [62]                      558
├─Linear: 1-2                            [62]                      [62]                      3,906
├─Linear: 1-3                            [62]                      [16]                      1,008
├─Linear: 1-4                            [62]                      [16]                      1,008
Total params: 6,480
Trainable params: 6,480
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 0.31
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.03
Estimated Total Size (MB): 0.03

## Decoder

#### Code

In [9]:
class Decoder(nn.Module):
    """Recurrent decoder that unrolls a latent vector into a fixed-length meal sequence.

    A latent vector is projected to the hidden size and fed to a stack of two
    GRU cells for ``num_steps`` time steps. At every step three linear heads
    read the top hidden state and emit class logits, a scalar energy and a
    vector of macro values.

    Args:
        latent_dim: Size of the latent vector ``z``.
        hidden_units: Hidden size of both GRU cells.
        num_classes: Number of classes predicted at each step.
        macro_dim: Number of macro values predicted at each step.
        num_steps: Number of time steps to unroll (defaults to 6).

    Raises:
        ValueError: If ``num_steps`` is smaller than 1.
    """

    def __init__(self, latent_dim, hidden_units, num_classes, macro_dim, num_steps=6):
        super().__init__()
        if num_steps < 1:
            raise ValueError(f"num_steps must be >= 1, got {num_steps}")

        self.hidden_units = hidden_units
        self.macro_dim = macro_dim
        self.num_steps = num_steps

        # Projects latent vector [batch_size, latent_dim] to [batch_size, hidden_units].
        self.latent_to_hidden = nn.Linear(in_features=latent_dim, out_features=hidden_units)

        # Two stacked GRU cells; both take and return [batch_size, hidden_units].
        self.gru1 = nn.GRUCell(input_size=hidden_units, hidden_size=hidden_units)
        self.gru2 = nn.GRUCell(input_size=hidden_units, hidden_size=hidden_units)

        # Classifier head: hidden state -> class logits [batch_size, num_classes].
        self.classifier = nn.Linear(in_features=hidden_units, out_features=num_classes)
        # Energy head: hidden state -> scalar energy [batch_size, 1].
        self.energy_head = nn.Linear(in_features=hidden_units, out_features=1)
        # Macro head: hidden state -> macro outputs [batch_size, macro_dim].
        self.macro_head = nn.Linear(in_features=hidden_units, out_features=macro_dim)

    def forward(self, z):
        """
        Args:
            z (torch.Tensor): Latent vector of shape [batch_size, latent_dim]

        Returns:
            class_logits_seq (torch.Tensor): Sequence of class logits, shape [batch_size, T, num_classes]
            total_energy (torch.Tensor): Summed energy over T time steps, shape [batch_size, 1]
            total_macros (torch.Tensor): Accumulated macro outputs over T time steps, shape [batch_size, macro_dim]
            energies_tensor (torch.Tensor): Sequence of energy values, shape [batch_size, T, 1]

        where T is ``self.num_steps``.
        """
        batch_size = z.size(0)

        # Zero-initialised hidden states. `new_zeros` inherits z's dtype and device,
        # so the decoder also works after `.double()` / `.half()` or on an accelerator.
        h1 = z.new_zeros((batch_size, self.hidden_units))
        h2 = z.new_zeros((batch_size, self.hidden_units))

        # Running sum of the macro head outputs, shape [batch_size, macro_dim].
        total_macros = z.new_zeros((batch_size, self.macro_dim))

        class_logits_seq = []
        energies_list = []

        # Input at t=0 is the projected latent vector; afterwards it is the
        # previous step's top hidden state.
        step_input = self.latent_to_hidden(z)

        for _ in range(self.num_steps):
            h1 = self.gru1(step_input, h1)
            h2 = self.gru2(h1, h2)

            class_logits_seq.append(self.classifier(h2))  # [batch_size, num_classes]
            energies_list.append(self.energy_head(h2))    # [batch_size, 1]
            # Out-of-place add keeps the accumulation autograd-friendly.
            total_macros = total_macros + self.macro_head(h2)

            step_input = h2

        # Stack along a new time dimension: [batch_size, T, num_classes] and [batch_size, T, 1].
        class_logits_seq = torch.stack(class_logits_seq, dim=1)
        energies_tensor = torch.stack(energies_list, dim=1)
        # Sum energies over time, giving [batch_size, 1].
        total_energy = energies_tensor.sum(dim=1)

        return class_logits_seq, total_energy, total_macros, energies_tensor

#### Summary

In [10]:
# Example dimensions used only to build a decoder for the summary below.
# Note: this rebinds the module-level `latent_dim` and `hidden_units` that the
# encoder summary cell also assigns.
latent_dim = 16
hidden_units = 25
num_classes = 140
macro_dim = 5
batch_size = 50

decoder = Decoder(latent_dim, hidden_units, num_classes, macro_dim)

In [11]:
# Layer-by-layer summary of the decoder for a (batch_size, latent_dim) input,
# showing input shape, output shape and parameter count per layer.
summary(decoder, (batch_size, latent_dim), col_names=["input_size", "output_size", "num_params"])

Layer (type:depth-idx)                   Input Shape               Output Shape              Param #
Decoder                                  [50, 16]                  [50, 6, 140]              --
├─Linear: 1-1                            [50, 16]                  [50, 25]                  425
├─GRUCell: 1-2                           [50, 25]                  [50, 25]                  3,900
├─GRUCell: 1-3                           [50, 25]                  [50, 25]                  3,900
├─Linear: 1-4                            [50, 25]                  [50, 140]                 3,640
├─Linear: 1-5                            [50, 25]                  [50, 1]                   26
├─Linear: 1-6                            [50, 25]                  [50, 5]                   130
├─GRUCell: 1-7                           [50, 25]                  [50, 25]                  (recursive)
├─GRUCell: 1-8                           [50, 25]                  [50, 25]                  (recursive)
├─Line

## Pre-Training

In [12]:
def adjust_meal_quantity(energies, batch_target_EI):
    """Rescale per-step energies so that their sum over dim 1 equals the target.

    Args:
        energies: Per-step energy values; they are summed over ``dim=1``.
            (Presumably the [batch_size, T, 1] tensor returned by the decoder;
            this function is not called in the visible notebook, so confirm.)
        batch_target_EI: Target total energy, broadcastable against
            ``energies.sum(dim=1)``.

    Returns:
        Tuple ``(adjusted_energies, new_total_energy)``: the rescaled energies
        and their sum over ``dim=1``.

    Note:
        The relative gap is divided by the predicted total, so a predicted
        total of zero yields inf/nan values.
    """
    predicted_total = energies.sum(dim=1)

    # Relative gap between the target and the predicted total.
    relative_gap = (batch_target_EI - predicted_total) / predicted_total

    # Insert an axis at dim 1 so the per-sample factor broadcasts over the steps.
    scale = 1 + relative_gap.unsqueeze(1)
    adjusted_energies = energies * scale

    return adjusted_energies, adjusted_energies.sum(dim=1)

In [13]:
def compute_L_macro(batch_min_macros, batch_max_macros, pred_macros):
    """Mean absolute distance of the predicted macros to both bounds.

    For every element the penalty is ``|min - pred| + |max - pred|``; the
    result is the mean of that penalty over all elements.

    Args:
        batch_min_macros: Lower bounds for the macro values.
        batch_max_macros: Upper bounds for the macro values.
        pred_macros: Predicted macro values.

    Returns:
        Scalar tensor holding the mean penalty.
    """
    distance_to_lower = (batch_min_macros - pred_macros).abs()
    distance_to_upper = (batch_max_macros - pred_macros).abs()
    return (distance_to_lower + distance_to_upper).mean()

In [14]:
def compute_L_energy(pred_energy, batch_target_EI):
    """Mean squared error between the predicted energy and the target energy.

    Args:
        pred_energy: Predicted energy values.
        batch_target_EI: Target energy values.

    Returns:
        Scalar tensor with the mean of the squared differences.
    """
    return F.mse_loss(input=pred_energy, target=batch_target_EI, reduction="mean")

In [15]:
def compute_KLD(mu, logvar, batch_size):
    """KL-divergence term computed from ``mu`` and ``logvar``, divided by ``batch_size``.

    Computes ``-0.5 * sum(1 + logvar - mu^2 - exp(logvar))`` over every element
    of the inputs, then divides by ``batch_size``.

    Args:
        mu: Latent means.
        logvar: Latent log-variances (same shape as ``mu``).
        batch_size: Divisor applied to the summed term.

    Returns:
        Scalar tensor with the batch-normalised KL term.
    """
    per_element = 1 + logvar - mu.pow(2) - logvar.exp()
    return -0.5 * torch.sum(per_element) / batch_size

In [16]:
def compute_L_MC(class_logits, batch_Y):
    """Cross-entropy between predicted and true classes, averaged over time steps.

    Args:
        class_logits: Logits indexed as ``[:, t, :]`` per time step, i.e.
            [batch_size, T, num_classes].
        batch_Y: Class indices indexed as ``[:, t]`` per time step.

    Returns:
        The mean over the T steps of the per-step cross-entropy loss.
    """
    num_steps = class_logits.size(1)

    # One (batch-averaged) cross-entropy value per time step.
    step_losses = (
        F.cross_entropy(class_logits[:, step, :], batch_Y[:, step])
        for step in range(num_steps)
    )

    return sum(step_losses, 0.0) / num_steps

In [17]:
def train_one_epoch(encoder, decoder, dataloader, optimizer, device='cpu', print_freq=20):
    """Run one optimisation pass over ``dataloader`` and return the mean batch loss.

    For every batch: encode the features, sample a latent vector with the
    reparameterization trick, decode it, sum the four loss terms
    (macro, energy, KLD, meal classification), back-propagate and step the
    optimizer. Progress is printed every ``print_freq`` batches and once at
    the end of the epoch.

    Args:
        encoder: Module returning ``(mu, logvar)`` for a batch of features.
        decoder: Module returning ``(class_logits, total_energy, total_macros,
            energies_tensor)`` for a batch of latent vectors.
        dataloader: Iterable of 5-tuples
            ``(X_features, Y_meals, target_EIs, min_macros, max_macros)``.
        optimizer: Optimizer holding the parameters to update.
        device: Device every batch tensor is moved to.
        print_freq: Print a progress line every this many batches.

    Returns:
        float: Sum of the batch losses divided by ``len(dataloader)``.
    """
    for module in (encoder, decoder):
        module.train()

    total_loss = 0.0

    for batch_idx, batch in enumerate(dataloader):
        # Move the five batch tensors to the target device, in order.
        X_features, Y_meals, target_EIs, min_macros, max_macros = (
            tensor.to(device) for tensor in batch
        )

        optimizer.zero_grad()

        # Encode, then sample z = mu + std * epsilon with std = exp(0.5 * logvar).
        mu, logvar = encoder(X_features)
        epsilon = torch.randn_like(logvar)
        z = mu + torch.exp(0.5 * logvar) * epsilon

        # Decode; the per-step energies are not needed for the loss.
        class_logits, pred_energy, pred_macros, _ = decoder(z)

        # Individual loss terms.
        L_macro = compute_L_macro(min_macros, max_macros, pred_macros)
        L_energy = compute_L_energy(pred_energy, target_EIs)
        L_kld = compute_KLD(mu, logvar, X_features.size(0))
        L_mc = compute_L_MC(class_logits, Y_meals)

        # Unweighted sum of all terms.
        loss = L_macro + L_energy + L_kld + L_mc

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        is_report_step = (batch_idx + 1) % print_freq == 0
        if is_report_step:
            print(f"Batch {batch_idx+1}/{len(dataloader)} | "
                  f"Loss: {loss.item():.4f} | "
                  f"Macro: {L_macro.item():.4f}, "
                  f"Energy: {L_energy.item():.4f}, "
                  f"KLD: {L_kld.item():.4f}, "
                  f"MC: {L_mc.item():.4f}")

    avg_loss = total_loss / len(dataloader)
    print(f"Epoch complete. Average Loss: {avg_loss:.4f}\n")
    return avg_loss

In [18]:
def training(batch_size=64, 
             hidden_dim=256, 
             latent_dim=256, 
             hidden_units=512, 
             epochs=10, 
             csv_file=None, 
             device='cpu',
             num_classes=10,
             lr=1e-4):
    """
    Trains the VAE model for personalized meal planning.

    Args:
        batch_size: Mini-batch size used by the DataLoader.
        hidden_dim: Hidden size of the encoder.
        latent_dim: Size of the latent vector shared by encoder and decoder.
        hidden_units: Hidden size of the decoder's GRU cells.
        epochs: Number of passes over the dataset.
        csv_file: Path to the training CSV; ``None`` lets the dataset use its
            default path.
        device: Device the models and batches are moved to.
        num_classes: Number of meal classes the decoder predicts per step.
            Every label in the dataset must lie in ``[0, num_classes)``.
        lr: Learning rate passed to Adam.

    Returns:
        Tuple ``(encoder, decoder, loss_seq)``: the trained modules and the
        list of average losses, one entry per epoch.

    Raises:
        ValueError: If the dataset is empty, or if it contains a meal label
            outside ``[0, num_classes)``.
    """

    # 1) Dataset & DataLoader
    dataset = MealPlanningDataset(csv_file=csv_file)
    if len(dataset) == 0:
        raise ValueError("Cannot train on an empty dataset.")

    # Check the labels up front: an out-of-range class index would otherwise
    # only surface inside the cross-entropy call during the first batch.
    min_label = int(dataset.Y_meals.min())
    max_label = int(dataset.Y_meals.max())
    if min_label < 0 or max_label >= num_classes:
        raise ValueError(
            f"Meal labels span [{min_label}, {max_label}] but num_classes={num_classes}; "
            f"labels must lie in [0, {num_classes - 1}]."
        )

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # 2) Infer shape information from one batch
    sample_batch = next(iter(dataloader))
    X_sample, _, _, min_macro_sample, _ = sample_batch
    
    input_dim = X_sample.shape[1]
    macro_dim = min_macro_sample.shape[1]

    print(f"Model Dimensions: input_dim={input_dim}, num_classes={num_classes}, macro_dim={macro_dim}")

    # 3) Instantiate model parts & move to device
    encoder = Encoder(input_dim=input_dim, hidden_dim=hidden_dim, latent_dim=latent_dim)
    decoder = Decoder(latent_dim=latent_dim, hidden_units=hidden_units, 
                      num_classes=num_classes, macro_dim=macro_dim)
    encoder.to(device)
    decoder.to(device)

    # 4) Optimizer: a single Adam instance updates both modules jointly.
    optimizer = Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=lr)

    loss_seq = []

    # 5) Training loop
    for epoch in range(epochs):
        print(f"--> Epoch {epoch+1}/{epochs}")
        avg_loss = train_one_epoch(encoder, decoder, dataloader, optimizer, 
                                   device=device, print_freq=10)
        loss_seq.append(avg_loss)

    return encoder, decoder, loss_seq

## Training

In [19]:
# Larger configuration kept for reference (not executed in this cell):
# training(batch_size=64, hidden_dim=256, latent_dim=256, hidden_units=512, epochs=500)
# Small configuration that is actually run. No `device` argument is passed, so
# training() uses its default of 'cpu' rather than the `device` selected near the
# top of the notebook. The returned modules rebind the module-level `encoder`
# and `decoder` names that the summary cells created.
encoder, decoder, loss_Seq = training(batch_size=64, hidden_dim=15, latent_dim=15, hidden_units=20, epochs=100)

Model Dimensions: input_dim=8, num_classes=10, macro_dim=4
--> Epoch 1/100
Batch 10/157 | Loss: 7.0052 | Macro: 2.3391, Energy: 1.9749, KLD: 0.3825, MC: 2.3087
Batch 20/157 | Loss: 6.7997 | Macro: 2.3970, Energy: 1.6957, KLD: 0.3883, MC: 2.3187
Batch 30/157 | Loss: 6.4991 | Macro: 2.3731, Energy: 1.4506, KLD: 0.3561, MC: 2.3193
Batch 40/157 | Loss: 6.1045 | Macro: 2.1359, Energy: 1.2634, KLD: 0.3982, MC: 2.3070
Batch 50/157 | Loss: 6.2859 | Macro: 2.2175, Energy: 1.4108, KLD: 0.3450, MC: 2.3126
Batch 60/157 | Loss: 6.3655 | Macro: 2.3051, Energy: 1.4225, KLD: 0.3145, MC: 2.3234
Batch 70/157 | Loss: 6.3272 | Macro: 2.3238, Energy: 1.3818, KLD: 0.3098, MC: 2.3119
Batch 80/157 | Loss: 6.3582 | Macro: 2.2886, Energy: 1.4406, KLD: 0.3140, MC: 2.3151
Batch 90/157 | Loss: 5.7860 | Macro: 2.0452, Energy: 1.1164, KLD: 0.3239, MC: 2.3004
Batch 100/157 | Loss: 5.9643 | Macro: 2.0717, Energy: 1.2571, KLD: 0.3164, MC: 2.3191
Batch 110/157 | Loss: 5.5331 | Macro: 2.0447, Energy: 0.8708, KLD: 0.3035,

In [20]:
# Display the list of average losses (one per epoch) returned by training().
loss_Seq

[6.007252659767297,
 5.126427705120888,
 4.9175153203830595,
 4.843828359227271,
 4.795502683918947,
 4.766226657636606,
 4.745477047695476,
 4.7359903299125135,
 4.721611308444078,
 4.675681867417256,
 4.55835644606572,
 4.423006226302712,
 4.379484371015221,
 4.378286674523809,
 4.351438165470293,
 4.34689928467866,
 4.359684666250922,
 4.331250224143836,
 4.3535873844365405,
 4.307073333460814,
 4.329466286738207,
 4.326593544832461,
 4.338094554888975,
 4.334859164657106,
 4.312631602499895,
 4.352790750515688,
 4.310657628782236,
 4.320601868781314,
 4.320387099199234,
 4.3132142094290185,
 4.3029625415802,
 4.330903921916986,
 4.313914543504168,
 4.284041246790795,
 4.282333753670857,
 4.288132037326789,
 4.30348576861582,
 4.269912531421443,
 4.245430976721891,
 4.2362565508314,
 4.214054818365984,
 4.189792548015618,
 4.160332155835097,
 4.126319434232773,
 4.128948115998773,
 4.1239130269190305,
 4.100152503153321,
 4.097928657653226,
 4.086794288295089,
 4.085154343562521,
 4