In [109]:
import torch
from torch_geometric.data import Data
import torch.nn.functional as F
import pandas as pd
from collections import defaultdict
from itertools import combinations, permutations
from torch_geometric.nn import GCNConv, GATConv, BatchNorm

In [110]:
df = pd.read_csv("../utils/data/train_labels_ingr_id.csv")
df = pd.DataFrame.drop(df, labels="img_indx", axis=1)
df

Unnamed: 0,id,brown rice,quinoa,olive oil,carrot,watermelon,raspberries,berries,cantaloupe,pineapple,...,pepperoni,orange with peel,mozzarella cheese,baby carrots,banana with peel,wheat bread,chilaquiles,pasta salad,balsamic vinegar,toast
0,dish_1562699612,1.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,dish_1558722322,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,dish_1561406762,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,dish_1562007739,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,dish_1562689548,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2750,dish_1559933793,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2751,dish_1563389153,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2752,dish_1558725253,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2753,dish_1561749348,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [111]:
num_rows, num_cols = df.shape
print(f"Number of rows: {num_rows}, Number of columns: {num_cols - 1}")
zero_columns = df.columns[(df == 0).all()]
print(zero_columns.tolist())

Number of rows: 2755, Number of columns: 199
['bread', 'tuna salad']


In [112]:
co_occurrence = defaultdict(int) # dict of tuples of ingredients and their co-occurrence, key: (ing1, ing2), value: count
ing_cols = df.columns.drop("id") # length should be 199
for _, r in df.iterrows():
    mask = r[ing_cols] == 1 # mask will be true for all the ingredients present in the matrix, false otherwise
    ing_present = ing_cols[mask].tolist() # ing_present is a list of ingredient present in the above matrix
    for first, second in permutations(ing_present, 2):
        food_pair = tuple([first, second])
        co_occurrence[food_pair] += 1

In [113]:
all_ingredients = set()
for ing1, ing2 in co_occurrence.keys():
    all_ingredients.add(ing1)
    all_ingredients.add(ing2)

all_ingredients = list(all_ingredients) # list of all the ingredients
print(f"Number of ingredients: {len(all_ingredients)}")

ingredient_to_index = {ing: i for i, ing in enumerate(all_ingredients)}
print(ingredient_to_index)
print("co-oc", co_occurrence)
edge_list = []
for (ing1, ing2), weight in co_occurrence.items():
    if weight >= 8:
        edge_list.append((ingredient_to_index[ing1], ingredient_to_index[ing2], weight))
print(edge_list) # list of tuples of ingredients and their co-occurrence
edges_df = pd.DataFrame(edge_list, columns=["source", "target", "weight"]) # bidirectional edges
edges_df

Number of ingredients: 193
{'cantaloupe': 0, 'mozzarella cheese': 1, 'flour': 2, 'white wine': 3, 'tomatoes': 4, 'chicken': 5, 'watermelon': 6, 'asparagus': 7, 'pesto': 8, 'waffles': 9, 'cauliflower': 10, 'shallots': 11, 'sun dried tomatoes': 12, 'eggs': 13, 'almonds': 14, 'bok choy': 15, 'blue cheese': 16, 'brown rice': 17, 'ketchup': 18, 'chard': 19, 'kale': 20, 'blackberries': 21, 'pepperoni': 22, 'lentils': 23, 'chickpeas': 24, 'cereal': 25, 'cod': 26, 'cheese': 27, 'chicken breast': 28, 'snow peas': 29, 'lemon juice': 30, 'rosemary': 31, 'basil': 32, 'rice noodles': 33, 'walnuts': 34, 'steak': 35, 'chilaquiles': 36, 'avocado': 37, 'fried rice': 38, 'brussels sprouts': 39, 'pears': 40, 'raspberries': 41, 'granola': 42, 'carrot': 43, 'chia seeds': 44, 'sweet potato': 45, 'chicken thighs': 46, 'black beans': 47, 'soy sauce': 48, 'cookies': 49, 'hominy': 50, 'bulgur': 51, 'mayonnaise': 52, 'chayote squash': 53, 'grapes': 54, 'country rice': 55, 'cherry tomatoes': 56, 'sour cream': 57,

Unnamed: 0,source,target,weight
0,17,98,11
1,17,115,187
2,98,17,11
3,98,115,119
4,115,17,187
...,...,...,...
5859,178,134,9
5860,139,114,8
5861,114,139,8
5862,154,102,8


In [114]:
weight_min = edges_df['weight'].min()
weight_max = edges_df['weight'].max()
edges_df['weight_normalized'] = (edges_df['weight'] - weight_min) / (weight_max - weight_min)

print(edges_df)


      source  target  weight  weight_normalized
0         17      98      11           0.002613
1         17     115     187           0.155923
2         98      17      11           0.002613
3         98     115     119           0.096690
4        115      17     187           0.155923
...      ...     ...     ...                ...
5859     178     134       9           0.000871
5860     139     114       8           0.000000
5861     114     139       8           0.000000
5862     154     102       8           0.000000
5863     102     154       8           0.000000

[5864 rows x 4 columns]


In [115]:
edge_index = torch.tensor([edges_df["source"].values, edges_df["target"].values], dtype=torch.long)
edge_weight = torch.tensor(edges_df["weight_normalized"].values, dtype=torch.float)
edge_index.shape, len(edge_weight)

(torch.Size([2, 5864]), 5864)

In [116]:
ing_num_total = len(all_ingredients)
x = torch.eye(ing_num_total)
data = Data(x=x, edge_index=edge_index, edge_attr=edge_weight)
data

Data(x=[193, 193], edge_index=[2, 5864], edge_attr=[5864])

In [117]:
class GCN(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, output_dim)

    def forward(self, x, edge_index, edge_weight=None):
        x = self.conv1(x, edge_index, edge_weight)
        x = F.relu(x)
        x = self.conv2(x, edge_index, edge_weight)
        return x

class GAT(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, edge_dim):
        super(GAT, self).__init__()
        self.gat1 = GATConv(input_dim, hidden_dim, edge_dim=edge_dim, concat=True)
        self.dropout = torch.nn.Dropout(0.3)
        self.gat2 = GATConv(hidden_dim, output_dim, edge_dim=edge_dim, concat=False)
        self.batch_norm = BatchNorm(hidden_dim)
        self.batch_norm2 = BatchNorm(output_dim)
        self.residual = input_dim == output_dim

    def forward(self, x, edge_index, edge_weight=None):
        x1 = self.gat1(x, edge_index, edge_weight)
        x1 = self.batch_norm(x1)
        x1 = F.relu(x1)
        x1 = self.dropout(x1)
        x2 = self.gat2(x1, edge_index, edge_weight)
        x2 = self.batch_norm2(x2)
        if self.residual:
            x2 = x2 + x
        return x2

def graph_embedding_training(model, graph, epochs=100, lr=0.0005):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = torch.nn.MSELoss()

    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()

        out = model(graph.x, graph.edge_index, graph.edge_attr)

        edge_pred = torch.sigmoid((out[graph.edge_index[0]] * out[graph.edge_index[1]]).sum(dim=1))

        # Compute loss
        loss = loss_fn(edge_pred, graph.edge_attr)
        loss.backward()
        optimizer.step()

        print(f"Epoch {epoch+1}, Loss: {loss.item()}")
    return model

In [118]:
edge_dim = data.edge_attr.size(1) if data.edge_attr is not None and data.edge_attr.dim() > 1 else 1
model = GAT(ing_num_total, 128, 128, edge_dim)
model = graph_embedding_training(model, data, epochs=75, lr=0.01)

Epoch 1, Loss: 0.9397634863853455
Epoch 2, Loss: 0.6781497001647949
Epoch 3, Loss: 0.5462038516998291
Epoch 4, Loss: 0.5927304029464722
Epoch 5, Loss: 0.5665622353553772
Epoch 6, Loss: 0.5500088930130005
Epoch 7, Loss: 0.5644076466560364
Epoch 8, Loss: 0.5493362545967102
Epoch 9, Loss: 0.5495089292526245
Epoch 10, Loss: 0.5585336089134216
Epoch 11, Loss: 0.5578573942184448
Epoch 12, Loss: 0.5516193509101868
Epoch 13, Loss: 0.6358314752578735
Epoch 14, Loss: 0.6194373369216919
Epoch 15, Loss: 0.5415550470352173
Epoch 16, Loss: 0.5511419773101807
Epoch 17, Loss: 0.5536244511604309
Epoch 18, Loss: 0.5490732192993164
Epoch 19, Loss: 0.5530809760093689
Epoch 20, Loss: 0.5476882457733154
Epoch 21, Loss: 0.5511658191680908
Epoch 22, Loss: 0.5552700161933899
Epoch 23, Loss: 0.5562862157821655
Epoch 24, Loss: 0.582330584526062
Epoch 25, Loss: 0.5829263925552368
Epoch 26, Loss: 0.5782021880149841
Epoch 27, Loss: 0.5742967128753662
Epoch 28, Loss: 0.6128886342048645
Epoch 29, Loss: 0.555380225181

In [119]:
model.eval()
with torch.no_grad():
    embeddings = model(data.x, data.edge_index)

print(f"Embeddings shape: {embeddings.shape}")

Embeddings shape: torch.Size([193, 128])
