# Building a Multi-Output Model to Predict Macronutrient Profile

In this notebook we are going to improve on the current model to predict calores, and instead we will predict all 3 macronutrients. Here is how it will work:

- The user will input a recipe name.
- Then a list of suggested ingredients will appear, and the user can optionally edit this by either adding/subtracting ingredients or modifying quantities
    - This will be obtained for now by just running the cosine similarity with the recipe names in the data, and taking the ingredients from the top 5 or so results.
    - Quantities may be ignored if model can't handle numeric inputs, but LLMs can if going that route.
    - Being able to edit this could be a paid feature, i.e. changing quantities or adding/subtracting ingredients.
- The model will then take as input, the recipe name concatenated with the ingredients list. Perhaps with quantities as well if using an LLM.
    - Consider adding dietery type/preference here as well.
- Finally, a multi-output regression model will be used to predict the macronutrients, and we can take ratios by dividing each macro by the total calorie count. The aim here is that during training, the model will learn to optimize its parameters to minimize the error between the predicted and ground truth ratios derived from the normalized macronutrient values.

In [1]:
#keeping all imports at the top
import pandas as pd
import ast

for module_name in ['pandas',]:
    module = __import__(module_name)
    print(f"{module_name}: {module.__version__}")

pandas: 2.2.0


## EDA and Preprocessing

In [2]:
df = pd.read_csv('../recipes.csv')
df.columns

Index(['uri', 'label', 'image', 'source', 'url', 'shareAs', 'yield',
       'dietLabels', 'healthLabels', 'cautions', 'ingredientLines',
       'ingredients', 'calories', 'totalWeight', 'totalTime', 'cuisineType',
       'mealType', 'dishType', 'totalNutrients', 'totalDaily', 'digest',
       'tags'],
      dtype='object')

In [3]:
df['ingredientLines']

0        ['1 pound green beans, trimmed', '1 tablespoon...
1        ['1 1/2 lb green beans, stem ends trimmed', '1...
2        ['1 stick (8 tbsp.) unsalted, cultured butter'...
3        ['2 teaspoons walnut oil', '1 pound green bean...
4        ['1 pound green beans, trimmed', '2 teaspoons ...
                               ...                        
13267    ['* 2tablespoons olive oil', '* 1 large red be...
13268    ['Two 6-ounce cans white meat tuna packed in w...
13269    ['16 ounces low-sodium chunk light tuna, drain...
13270    ['1 can (3 ounces) tuna, drained', '1 tablespo...
13271    ['1 can (3 ounces) chunk light tuna in water, ...
Name: ingredientLines, Length: 13272, dtype: object

Note the ingredientLines column is all we need if we concatenate ingredients with the recipe name, but since we want to adjust the ingredients and quantity for the input, we will need to get the information from the ingredients column manually instead.

In [4]:
ingredients_col = df['ingredients'].apply(ast.literal_eval)

In [5]:
ingredients_col[0]

[{'text': '1 pound green beans, trimmed',
  'quantity': 1.0,
  'measure': 'pound',
  'food': 'green beans',
  'weight': 453.59237,
  'foodCategory': 'vegetables',
  'foodId': 'food_aceucvpau4a8v6atkx5eabxyoqdn',
  'image': 'https://www.edamam.com/food-img/891/89135f10639878a2360e6a33c9af3d91.jpg'},
 {'text': '1 tablespoon butter, (optional)',
  'quantity': 1.0,
  'measure': 'tablespoon',
  'food': 'butter',
  'weight': 14.2,
  'foodCategory': 'Dairy',
  'foodId': 'food_awz3iefajbk1fwahq9logahmgltj',
  'image': 'https://www.edamam.com/food-img/713/71397239b670d88c04faa8d05035cab4.jpg'},
 {'text': 'Coarse salt and ground pepper',
  'quantity': 0.0,
  'measure': None,
  'food': 'Coarse salt',
  'weight': 2.80675422,
  'foodCategory': 'Condiments and sauces',
  'foodId': 'food_a1vgrj1bs8rd1majvmd9ubz8ttkg',
  'image': 'https://www.edamam.com/food-img/694/6943ea510918c6025795e8dc6e6eaaeb.jpg'},
 {'text': 'Coarse salt and ground pepper',
  'quantity': 0.0,
  'measure': None,
  'food': 'groun

In [6]:
def get_ingredient_aspect(row, aspect):
    lst = []
    for j in range(len(row)):
        lst.append(row[j][aspect])
    return lst

food_ingredients = ingredients_col.apply(lambda row: get_ingredient_aspect(row, 'food')).rename('foodItem')
quantity_ingredients = ingredients_col.apply(lambda row: get_ingredient_aspect(row, 'quantity')).rename('quantity')
measure_ingredients = ingredients_col.apply(lambda row: get_ingredient_aspect(row, 'measure')).rename('measurementUnit') # will need this to understand quantity

The `healthLabel` is the dietary type/preference like vegan, pescaterian, etc. We will only select few options though as it is a multilabel column, and the user can only select 1 for now, from: ['Mediterranean', 'Vegetarian', 'Vegan', 'Red-Meat-Free', 'Paleo', 'Pescatarian']. When reducing the healthLabels column from multilabel to categorical, we need to define a priority order, and if none of these are there then the dish is balanced, so we will add that as an option. In the future, some analysis on this column should be done to improve the priority order, rather than relying on domain knowledge. Alternatively, and option to select multiple could be implemented instead.

In [7]:
health_labels = df['healthLabels'].apply(ast.literal_eval)

Let's take a look at all unique values of health labels. If we had more data it might be worth it to just keep all of these health labels. 

In [8]:
unique_health_labels = []
for lst in health_labels:
    for health in lst:
        unique_health_labels.append(health)

print(set(unique_health_labels))


{'Sesame-Free', 'Tree-Nut-Free', 'Soy-Free', 'Kidney-Friendly', 'Egg-Free', 'Mollusk-Free', 'Low Sugar', 'Alcohol-Cocktail', 'Vegan', 'Paleo', 'Gluten-Free', 'Peanut-Free', 'Crustacean-Free', 'Keto-Friendly', 'Dairy-Free', 'Shellfish-Free', 'Lupine-Free', 'Sugar-Conscious', 'Kosher', 'Immuno-Supportive', 'Wheat-Free', 'Sulfite-Free', 'Mustard-Free', 'Red-Meat-Free', 'Pork-Free', 'Vegetarian', 'Fish-Free', 'Low Potassium', 'DASH', 'FODMAP-Free', 'Celery-Free', 'Pescatarian', 'Mediterranean', 'Alcohol-Free', 'No oil added'}


In [9]:
priority_order = ['Vegan', 'Vegetarian', 'Pescatarian', 'Paleo', 'Red-Meat-Free', 'Mediterranean']

In [10]:
def replace_with_priority(labels):
    for label in priority_order:
        if label in labels:
            return label
    return 'Balanced'  # Handle case where no label matches priority_order, in which case the diet is balanced

# Apply function to the multilabels series
priority_health_labels = health_labels.apply(replace_with_priority)

### Target Variable

Now let's get the macros and calories can be calcualted from that.

In [11]:
df.columns

Index(['uri', 'label', 'image', 'source', 'url', 'shareAs', 'yield',
       'dietLabels', 'healthLabels', 'cautions', 'ingredientLines',
       'ingredients', 'calories', 'totalWeight', 'totalTime', 'cuisineType',
       'mealType', 'dishType', 'totalNutrients', 'totalDaily', 'digest',
       'tags'],
      dtype='object')

In [12]:
nutrients = df['totalNutrients'].apply(ast.literal_eval)

In [13]:
nutrients[0].keys()

dict_keys(['ENERC_KCAL', 'FAT', 'FASAT', 'FATRN', 'FAMS', 'FAPU', 'CHOCDF', 'CHOCDF.net', 'FIBTG', 'SUGAR', 'PROCNT', 'CHOLE', 'NA', 'CA', 'MG', 'K', 'FE', 'ZN', 'P', 'VITA_RAE', 'VITC', 'THIA', 'RIBF', 'NIA', 'VITB6A', 'FOLDFE', 'FOLFD', 'FOLAC', 'VITB12', 'VITD', 'TOCPHA', 'VITK1', 'WATER'])

In [14]:
for nutrient in nutrients[1].keys():
    print(nutrients[1][nutrient])

    #just want to look at these more closely, but we don't need to look at the micronutrients
    if nutrients[0][nutrient]['label'] == 'Cholesterol':
        break

{'label': 'Energy', 'quantity': 331.96545205, 'unit': 'kcal'}
{'label': 'Fat', 'quantity': 15.008954821, 'unit': 'g'}
{'label': 'Saturated', 'quantity': 2.2059442775, 'unit': 'g'}
{'label': 'Trans', 'quantity': 0.0, 'unit': 'g'}
{'label': 'Monounsaturated', 'quantity': 9.9235888555, 'unit': 'g'}
{'label': 'Polyunsaturated', 'quantity': 2.19255406715, 'unit': 'g'}
{'label': 'Carbs', 'quantity': 47.8064322835, 'unit': 'g'}
{'label': 'Carbohydrates (net)', 'quantity': 29.2874412985, 'unit': 'g'}
{'label': 'Fiber', 'quantity': 18.518990985000002, 'unit': 'g'}
{'label': 'Sugars', 'quantity': 22.359966893000003, 'unit': 'g'}
{'label': 'Protein', 'quantity': 12.551760556500001, 'unit': 'g'}
{'label': 'Cholesterol', 'quantity': 0.0, 'unit': 'mg'}


Calories is fat x 9 + protein x 4 + carbs x 4 + fiber x 2

In [15]:
fat = 2.2059442775 + 9.9235888555 + 2.19255406715 #adding up all the different types of fat doesn't result in the total fat for some reason
(15.008954821*9 + 29.2874412985*4 + 12.551760556500001*4) + 18.518990985000002*2

339.47538277900003

In [16]:
#net carbs + fiber
18.518990985000002 + 29.2874412985

47.8064322835

For some reason when you add this up the calories isn't an exact match with the recorded calories in the calories column (which is the same as the Energy label here). And carbs is just net carbs + fiber, and since fiber is insoluble we won't be counting/predicting it. Also the discrepency doesn't seem to be just from counting or not counting Fiber. Thus, we will just use the calculation for now instead, counting only net carbs, protein, and fat.

In [17]:
def get_macros(nutrients_row):
    macros_dct = {}

    for nutrient in nutrients_row.keys():
        if nutrients_row[nutrient]['label'] == 'Fat':
            macros_dct['fat'] = nutrients_row[nutrient]['quantity']
        elif nutrients_row[nutrient]['label'] == 'Protein':
            macros_dct['protein'] = nutrients_row[nutrient]['quantity']
        elif nutrients_row[nutrient]['label'] == 'Carbohydrates (net)':
            macros_dct['carbs'] = nutrients_row[nutrient]['quantity']

    return macros_dct

In [18]:
macros_df = pd.DataFrame(list(nutrients.apply(lambda row: get_macros(row))))

#macros_df['calories'] = 9*macros_df['fat'] + 4*macros_df['protein'] + 4*macros_df['carbs']
macros_df.head(3)

Unnamed: 0,fat,carbs,protein
0,12.559853,19.920021,8.567392
1,15.008955,29.287441,12.551761
2,93.626455,29.120751,13.416711


In [19]:
recipe_name = df['label'].rename('recipeName')

relevant_cols_df = pd.concat([priority_health_labels, recipe_name, food_ingredients, quantity_ingredients, measure_ingredients, macros_df], axis=1)
relevant_cols_df.head()

Unnamed: 0,healthLabels,recipeName,foodItem,quantity,measurementUnit,fat,carbs,protein
0,Vegetarian,Green Beans,"[green beans, butter, Coarse salt, ground pepper]","[1.0, 1.0, 0.0, 0.0]","[pound, tablespoon, None, None]",12.559853,19.920021,8.567392
1,Vegan,Sauteed Green Beans,"[green beans, olive oil, green-bean]","[1.5, 1.0, 1.0]","[pound, tablespoon, <unit>]",15.008955,29.287441,12.551761
2,Vegetarian,Caramelized Green Beans,"[butter, green beans]","[8.0, 1.5]","[tablespoon, pound]",93.626455,29.120751,13.416711
3,Vegan,Sautéed Fresh Green Beans,"[walnut oil, green beans]","[2.0, 1.0]","[teaspoon, pound]",9.997903,19.368394,8.30074
4,Vegan,Fancy Green Beans,"[green beans, vegan margarine, sesame seeds, S...","[1.0, 2.0, 2.0, 0.0, 0.0]","[pound, teaspoon, teaspoon, None, None]",11.359097,20.608817,9.509045


In [20]:
ex_row = relevant_cols_df.iloc[1]

for i in range(len(ex_row['foodItem'])):
    quantity = ex_row['quantity'][i]
    unit = ex_row['measurementUnit'][i]
    food = ex_row['foodItem'][i]

    if quantity > 0:
        print(quantity, unit, food)
    else:
        print(food)


1.5 pound green beans
1.0 tablespoon olive oil
1.0 <unit> green-bean


In [21]:
idxs = []
for i, units in enumerate(relevant_cols_df['measurementUnit']):
    for u in units:
        if u == '<unit>':
            idxs.append(i)
            break

In [22]:
relevant_cols_df.iloc[idxs,:]

Unnamed: 0,healthLabels,recipeName,foodItem,quantity,measurementUnit,fat,carbs,protein
1,Vegan,Sauteed Green Beans,"[green beans, olive oil, green-bean]","[1.5, 1.0, 1.0]","[pound, tablespoon, <unit>]",15.008955,29.287441,12.551761
10,Vegan,Frenched Green Beans,"[green beans, salt, black pepper, Sherry vineg...","[2.0, 0.75, 0.25, 2.0, 1.5, 1.0]","[pound, teaspoon, teaspoon, teaspoon, tablespo...",22.281541,39.256213,16.777531
28,Vegan,Lemon Green Beans recipes,"[lemon, salt, green beans, olive oil]","[1.0, 0.0, 1.0, 1.5]","[<unit>, None, pound, tablespoon]",21.499903,24.845194,9.224740
29,Balanced,Southern Green Beans,"[green beans, bacon, onion, red wine vinegar, ...","[1.25, 2.0, 1.0, 2.0, 2.0]","[pound, slice, <unit>, tablespoon, teaspoon]",22.835379,38.038953,19.103845
32,Vegetarian,Crispy Green Beans,"[flour, egg, buttermilk, panko breadcrumbs, ca...","[0.6666666666666666, 1.0, 0.6666666666666666, ...","[cup, <unit>, cup, cup, tablespoon, None, None...",95.584178,170.094762,39.403952
...,...,...,...,...,...,...,...,...
13265,Pescatarian,Tuna melt pizza baguettes,"[baguette, red pepper, green pepper, sweetcorn...","[2.0, 1.0, 1.0, 198.0, 225.0, 100.0, 1.0]","[<unit>, <unit>, <unit>, gram, gram, gram, tab...",52.192541,343.546831,151.500834
13267,Pescatarian,Red Pepper Farro With Tuna,"[olive oil, red bell pepper, farro, Salt, Pepp...","[2.0, 1.0, 1.0, 0.0, 0.0, 10.0]","[tablespoon, <unit>, cup, None, None, ounce]",34.505177,110.728103,82.228005
13269,Pescatarian,Savory Tuna Sandwich,"[light tuna, black olives, capers, white vineg...","[16.0, 6.0, 2.0, 0.25, 1.0, 8.0, 2.0, 4.0]","[ounce, <unit>, teaspoon, cup, tablespoon, sli...",30.103793,164.929173,127.987906
13270,Pescatarian,Tiny Tuna Melts recipes,"[tuna, mayonnaise, lemon juice, coarse salt, g...","[3.0, 1.0, 2.0, 0.0, 0.0, 16.0, 0.25]","[ounce, tablespoon, teaspoon, None, None, <uni...",31.519834,24.483712,25.609022


In [23]:
df['ingredientLines'][13265]

"['2 part-baked baguette', '1 red pepper, diced', '1 green pepper, diced', '198g can sweetcorn, drained', '225g jar tuna', '100g cheddar, grated', '1 tbsp tomato purée']"

In [24]:
relevant_cols_df.iloc[13265]['measurementUnit']

['<unit>', '<unit>', '<unit>', 'gram', 'gram', 'gram', 'tablespoon']

These <unit> values are essentially missing data, even though it is in the ingredientLines column. We could cross reference to fill them in, but since we are going to concatenate all these foodItem, quantity, and measurementUnit columns to get the ingredientLines anyway, let's just use it directly. Then, to allow users to adjust the quantity, we will have to go through the ingredients and pull out the quantity instead, and put it in it's own column.

In [25]:
ingredientLines = df['ingredientLines'].apply(ast.literal_eval)

In [26]:
relevant_df = pd.concat([priority_health_labels, recipe_name, ingredientLines, macros_df], axis=1)
relevant_df.head()

Unnamed: 0,healthLabels,recipeName,ingredientLines,fat,carbs,protein
0,Vegetarian,Green Beans,"[1 pound green beans, trimmed, 1 tablespoon bu...",12.559853,19.920021,8.567392
1,Vegan,Sauteed Green Beans,"[1 1/2 lb green beans, stem ends trimmed, 1 ta...",15.008955,29.287441,12.551761
2,Vegetarian,Caramelized Green Beans,"[1 stick (8 tbsp.) unsalted, cultured butter, ...",93.626455,29.120751,13.416711
3,Vegan,Sautéed Fresh Green Beans,"[2 teaspoons walnut oil, 1 pound green beans, ...",9.997903,19.368394,8.30074
4,Vegan,Fancy Green Beans,"[1 pound green beans, trimmed, 2 teaspoons veg...",11.359097,20.608817,9.509045


In [27]:
#we are going to join/concatenate the list of ingredients in the ingredientLines column with a comma, 
#so let's get rid of the existing commas in the individual ingredients now
def comma_to_bracket(ingredient_list):
    """
    Input: ingredient_list (str): a list of strings, like ingredients of a recipe.
    Output: recipe (str): commas in individual elements from input string are removed, then they are all joined together with a comma, so commas seperate each ingredient now.
    """
    processed_ingredients = []
    for ingredient in ingredient_list:
        parts = ingredient.split(',', 1)  # Split at the first comma
        if len(parts) > 1:  # Check if there is a comma
            # Check if the part after the comma is already in brackets
            if '(' not in parts[1] and ')' not in parts[1]:
                parts[1] = f'({parts[1].strip()})'  # Put it in brackets
        processed_ingredients.append(' '.join(parts))

    # Join the processed strings with a comma and space now that we removed the commas in the individual strings
    recipe = ', '.join(processed_ingredients)

    return recipe

In [28]:
relevant_df['ingredientLines'] = relevant_df['ingredientLines'].apply(comma_to_bracket)

In [29]:
relevant_df.head(3)

Unnamed: 0,healthLabels,recipeName,ingredientLines,fat,carbs,protein
0,Vegetarian,Green Beans,"1 pound green beans (trimmed), 1 tablespoon bu...",12.559853,19.920021,8.567392
1,Vegan,Sauteed Green Beans,"1 1/2 lb green beans (stem ends trimmed), 1 ta...",15.008955,29.287441,12.551761
2,Vegetarian,Caramelized Green Beans,"1 stick (8 tbsp.) unsalted (cultured butter), ...",93.626455,29.120751,13.416711


## Modeling

Now we can build a model to predict the macros. But first we need to preprocess the input by concatenating the strings into just one.

In [68]:
import torch
from datasets import Dataset
from transformers import BertModel, BertConfig
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split

import torch, random, datasets
from transformers.file_utils import is_tf_available, is_torch_available
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments

from sklearn.metrics import mean_squared_error, r2_score, mean_squared_error, mean_absolute_error
import numpy as np

In [31]:
X = relevant_df['healthLabels'] + ' ' + relevant_df['recipeName'] + ' ' + relevant_df['ingredientLines']
X = X.rename('fullRecipeInput')
y = relevant_df[['fat', 'carbs', 'protein']]

dataset = Dataset.from_pandas(pd.concat([X, y], axis=1))

In [32]:
dataset.train_test_split(test_size=0.2)

DatasetDict({
    train: Dataset({
        features: ['fullRecipeInput', 'fat', 'carbs', 'protein'],
        num_rows: 10617
    })
    test: Dataset({
        features: ['fullRecipeInput', 'fat', 'carbs', 'protein'],
        num_rows: 2655
    })
})

In [70]:
X_train, X_test, y_train, y_test = train_test_split(X.tolist(), y['fat'], test_size=0.25)

In [71]:
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

max_length = 100

# Encode the text
train_encodings = tokenizer(X_train, truncation=True, padding=True, max_length=max_length)
valid_encodings = tokenizer(X_test, truncation=True, padding=True, max_length=max_length)

In [73]:
class MakeTorchData(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {k: torch.tensor(v[idx]) for k, v in self.encodings.items()}
        label = {k: torch.tensor(v[idx], dtype=torch.float32) for k, v in self.labels.items()}
        item.update(label)
        return item

    def __len__(self):
        return len(self.labels)

# convert our tokenized data into a torch Dataset
train_dataset = MakeTorchData(train_encodings, y_train)
valid_dataset = MakeTorchData(valid_encodings, y_test)

In [77]:
from transformers import BertForSequenceClassification, BertConfig
import torch

# Load pre-trained BERT model configuration
config = BertConfig.from_pretrained('bert-base-uncased')
# Modify the number of labels to 3 (for the 3 regression targets)
config.num_labels = 3

# Instantiate the model with modified configuration
model = BertForSequenceClassification(config).to("cuda")

In [79]:
class CustomLoss(nn.Module):
    def __init__(self, weights):
        super(CustomLoss, self).__init__()
        self.weights = weights

    def forward(self, predictions, targets):
        loss = 0.0
        for i in range(len(predictions)):
            loss += self.weights[i] * torch.mean((predictions[i] - targets[:, i])**2)
        return loss

loss_fn = CustomLoss(weights=[1.0, 1.0, 1.0])

# Define TrainingArguments
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=1,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=20,
    weight_decay=0.01,
    learning_rate=2e-5,
    logging_dir='./logs',
    save_total_limit=10,
    load_best_model_at_end=True,
    metric_for_best_model='rmse',
    evaluation_strategy="epoch",
    save_strategy="epoch",
)

# Create Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    compute_metrics=loss_fn,
)

# Train the model
trainer.train()

# Evaluate the model
results = trainer.evaluate()
print(results)

  0%|          | 0/312 [00:00<?, ?it/s]

KeyError: 7450

In [74]:

model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased', num_labels = 3).to("cuda")

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [61]:
def compute_metrics_for_regression(eval_pred):
    logits, labels = eval_pred
    labels = labels.reshape(-1, 1)

    mse = mean_squared_error(labels, logits)
    rmse = mean_squared_error(labels, logits, squared=False)
    mae = mean_absolute_error(labels, logits)
    r2 = r2_score(labels, logits)
    smape = 1/len(labels) * np.sum(2 * np.abs(logits-labels) / (np.abs(labels) + np.abs(logits))*100)

    return {"mse": mse, "rmse": rmse, "mae": mae, "r2": r2, "smape": smape}

In [75]:
training_args = TrainingArguments(
    output_dir ='./results',          
    num_train_epochs = 1,     
    per_device_train_batch_size = 32,   
    per_device_eval_batch_size = 20,   
    weight_decay = 0.01,               
    learning_rate = 2e-5,
    logging_dir = './logs',            
    save_total_limit = 10,
    load_best_model_at_end = True,     
    metric_for_best_model = 'rmse',    
    evaluation_strategy = "epoch",
    save_strategy = "epoch",
) 

# Call the Trainer
trainer = Trainer(
    model = model,                         
    args = training_args,                  
    train_dataset = train_dataset,         
    eval_dataset = valid_dataset,          
    compute_metrics = compute_metrics_for_regression,     
)

# Train the model
trainer.train()

# Call the summary
trainer.evaluate()

  0%|          | 0/312 [00:00<?, ?it/s]

KeyError: 7450

In [33]:
class MultiOutputRegressionModel(nn.Module):
    def __init__(self, num_outputs):
        super(MultiOutputRegressionModel, self).__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.dropout = nn.Dropout(0.1)
        self.fc1 = nn.Linear(768, 256)
        self.fc2 = nn.Linear(256, num_outputs)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        pooled_output = self.dropout(pooled_output)
        x = torch.relu(self.fc1(pooled_output))
        x = self.fc2(x)
        return x
    
    

In [34]:
num_outputs = 3  # Number of output dimensions
model = MultiOutputRegressionModel(num_outputs=num_outputs)

# Define loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=2e-5)