In [19]:
import torch
from torch_geometric.data import Data
from itertools import combinations

def create_ingredient_graph(dishes, num_ingredients):
    """
    Build a weighted ingredient co-occurrence graph.

    For every dish, each unordered pair of its entries adds 1 to the weight
    between the two ingredient indices. Weights are stored symmetrically, so
    each connected pair appears as two directed edges.

    Note: pairs are taken by position within a dish, so an index repeated
    inside one dish pairs with itself and produces a self-loop.

    Parameters:
    - dishes: Iterable of dishes, each dish is a sequence of ingredient indices.
    - num_ingredients: Total number of ingredients (number of graph nodes).

    Returns:
    - graph: PyTorch Geometric Data object with identity-matrix node features
      in `x`, edges in `edge_index` (2, num_edges) and co-occurrence counts in
      `edge_attr` (num_edges,).
    """
    # Dense (num_ingredients x num_ingredients) table of pair counts.
    cooccurrence = torch.zeros((num_ingredients, num_ingredients))

    for ingredients in dishes:
        for first, second in combinations(ingredients, 2):
            # Record the pair in both directions to keep the table symmetric.
            cooccurrence[first, second] += 1
            cooccurrence[second, first] += 1

    # Every non-zero cell becomes one directed edge; its value is the weight.
    edge_index = torch.nonzero(cooccurrence).t()
    sources, targets = edge_index
    weights = cooccurrence[sources, targets]

    # One-hot node features: row k identifies ingredient k.
    return Data(
        x=torch.eye(num_ingredients),
        edge_index=edge_index,
        edge_attr=weights,
    )

import pandas as pd

train_ingr_id = '../data/train_labels_ingr_id.csv'
val_ingr_id = '../data/val_labels_ingr_id.csv'
# The test labels are intentionally left out of the co-occurrence statistics.

# Pool the training and validation labels into a single frame.
train_df = pd.read_csv(train_ingr_id)
val_df = pd.read_csv(val_ingr_id)
train_df = pd.concat([train_df, val_df], ignore_index=True)
df = train_df
print(len(df))

# Column layout: the first column is the dish id, the last column is the image
# file name, and every column in between is a 0/1 indicator for one ingredient.
#   id, ingredient 1, ingredient 2, ..., ingredient N, img_indx
#   x,  0,            1,            ..., 0,            x.jpeg
ingredient_flags = df.iloc[:, 1:-1]
num_ingredients = ingredient_flags.shape[1]

# For each dish, collect the 0-based positions (within the ingredient columns)
# whose indicator equals 1, e.g. ingr_list = [[1, 3], [1, 4], [2, 3], ...].
# The comparison is done on the whole frame at once instead of row by row.
ingr_list = [
    present.nonzero()[0].tolist()
    for present in ingredient_flags.eq(1).to_numpy()
]

# Show the first dish's raw row next to its extracted ingredient indices.
print(df.iloc[0])
print(ingr_list[0])

2958
id                       dish_1562699612
brown rice                           1.0
quinoa                               1.0
olive oil                            1.0
carrot                               0.0
                            ...         
chilaquiles                          0.0
pasta salad                          0.0
balsamic vinegar                     0.0
toast                                0.0
img_indx            dish_1562699612.jpeg
Name: 0, Length: 201, dtype: object
[0, 1, 2]


In [20]:
import numpy as np
# Symmetric co-occurrence counts: adj_matrix[a, b] is how many times
# ingredients a and b were paired within a dish of ingr_list.
adj_matrix = np.zeros((num_ingredients, num_ingredients), dtype=np.int32)

# Gather every within-dish pair first; the reshape keeps the array 2-D even
# when no dish contributes a pair.
pair_array = np.array(
    [pair for dish in ingr_list for pair in combinations(dish, 2)],
    dtype=np.intp,
).reshape(-1, 2)
first, second = pair_array[:, 0], pair_array[:, 1]

# np.add.at accumulates repeated index pairs (plain fancy-index `+= 1` would
# count each distinct pair only once). Both directions are incremented.
np.add.at(adj_matrix, (first, second), 1)
np.add.at(adj_matrix, (second, first), 1)

adj_matrix

array([[  0,  12, 201, ...,   0,   0,   0],
       [ 12,   0, 124, ...,   0,   0,   0],
       [201, 124,   0, ...,   1,   1,   0],
       ...,
       [  0,   0,   1, ...,   0,   0,   0],
       [  0,   0,   1, ...,   0,   0,   0],
       [  0,   0,   0, ...,   0,   0,   0]], dtype=int32)

In [21]:
# Rank ingredient pairs by the cosine similarity of their BERT embeddings and
# print the most similar ones.
import torch
import pandas as pd

# Only the first NUM_ANCHOR_INGREDIENTS ingredients are used as the first
# member of a pair; each is compared against every ingredient that follows it.
NUM_ANCHOR_INGREDIENTS = 10
# Number of highest-similarity pairs kept in `top_100` and printed.
TOP_K_PAIRS = 100

# load the ingredient embeddings
# weights_only=True limits unpickling to tensors and plain containers, so the
# files cannot execute arbitrary code on load.
ingredient_embeddings = torch.load("./ingredient_embeddings_gat_v2.pt", weights_only=True)
bert_embeddings = torch.load("./ingredient_embeddings_bert.pt", weights_only=True)

# inspect the embeddings
print(ingredient_embeddings)
print(bert_embeddings)

# Ingredient names are the label columns between the first (id) column and the
# last column of the test label file. Note that this rebinds `df`.
ing_id = "../data/test_labels_ingr_id.csv"

df = pd.read_csv(ing_id)
ingr_name = df.columns[1:-1].to_list()

# dim=0 because each comparison is between two 1-D embedding vectors.
cos = torch.nn.CosineSimilarity(dim=0)

# The same similarities keyed two ways: by ingredient names (for printing) and
# by ingredient indices (for lookups in index-based structures).
ingr_name_dict = {}
ingr_id_dict = {}

for i in range(NUM_ANCHOR_INGREDIENTS):
    for j in range(i + 1, len(bert_embeddings)):
        sim = cos(bert_embeddings[i], bert_embeddings[j]).item()
        ingr_name_dict[(ingr_name[i], ingr_name[j])] = sim
        ingr_id_dict[(i, j)] = sim

# Most similar pairs first.
sorted_sim_dict = sorted(ingr_name_dict.items(), key=lambda x: x[1], reverse=True)
sorted_id_dict = sorted(ingr_id_dict.items(), key=lambda x: x[1], reverse=True)

top_100 = sorted_id_dict[:TOP_K_PAIRS]
# Slicing (rather than indexing 0..99) stays safe when fewer pairs exist.
for entry in sorted_sim_dict[:TOP_K_PAIRS]:
    print(entry)

tensor([[-3.5067e-02,  4.0401e-02,  6.2585e-02,  ...,  8.7442e-02,
         -3.4544e-02,  2.6393e-02],
        [ 2.8923e-03,  3.5581e-02,  4.6746e-02,  ...,  6.9635e-02,
         -3.1150e-02, -2.4894e-03],
        [-3.8004e-02,  5.1686e-02,  1.3883e-01,  ...,  1.2125e-01,
         -1.7026e-02,  8.5565e-03],
        ...,
        [ 2.5114e+00, -5.5451e-02, -4.2354e-01,  ...,  3.6905e-02,
          1.6226e+00,  7.9568e-01],
        [-2.0718e-02, -1.0674e+00,  1.5850e+00,  ...,  1.0442e+00,
          1.6812e+00,  8.0268e-01],
        [ 2.1651e-01,  4.2835e-01, -1.3640e+00,  ...,  6.5846e-02,
          1.0917e+00, -3.6186e-01]])
tensor([[-0.7719,  0.4169, -0.7402,  ...,  0.2603,  0.3389,  0.4507],
        [-0.3830, -0.0660, -0.2323,  ..., -0.1751,  0.4549,  0.2395],
        [-0.4955,  0.4092, -1.1320,  ..., -0.4718, -0.1319,  0.4905],
        ...,
        [-0.4084, -0.0712, -0.4128,  ..., -0.2269, -0.0887,  0.3565],
        [-0.8203,  0.4091, -0.4271,  ..., -0.3878,  0.0322,  0.7871],
     

  ingredient_embeddings = torch.load("./ingredient_embeddings_gat_v2.pt")
  bert_embeddings = torch.load("./ingredient_embeddings_bert.pt")


In [22]:
# Among the top pairs by embedding similarity, list those that never co-occur
# in any dish (zero entry in adj_matrix), together with their similarity and
# their rank within top_100, then print how many such pairs there are.
# NOTE(review): ingr_name and adj_matrix are built from different label files;
# this assumes both files list the ingredient columns in the same order — confirm.
never_appear = 0

# enumerate supplies the rank directly instead of searching the list for it.
for rank, ((first_id, second_id), similarity) in enumerate(top_100):
    if adj_matrix[first_id, second_id] == 0:
        never_appear += 1
        print(ingr_name[first_id], ingr_name[second_id], similarity, rank)

print(never_appear)

quinoa jicama 0.9840413331985474 0
berries corn 0.9838709831237793 1
berries wine 0.9818021655082703 3
berries sugar 0.9809611439704895 5
berries tuna 0.979519784450531 7
berries squash 0.9789713025093079 9
berries chicken 0.977954089641571 11
berries butter 0.9775931239128113 14
berries syrup 0.975771963596344 17
quinoa pesto 0.9757385849952698 18
berries chili 0.9755913615226746 19
berries cookies 0.9731311202049255 23
carrot milk 0.9730768799781799 24
berries basil 0.9722170233726501 26
berries apple 0.9709581136703491 30
berries rosemary 0.9709539413452148 31
berries white wine 0.970842182636261 32
berries brown sugar 0.9703260660171509 33
berries hominy 0.9684455394744873 35
berries ginger 0.9684008955955505 36
carrot cookies 0.9680901169776917 38
berries mushroom 0.9663424491882324 47
berries steak 0.9659769535064697 49
berries tofu 0.9654885530471802 51
berries sandwiches 0.9639341235160828 56
berries orange with peel 0.9638068675994873 57
berries pesto 0.9634993076324463 59
ber

In [23]:
# Cosine similarity between each of the first 10 ingredients and every
# ingredient that follows it, computed on `ingredient_embeddings` and keyed by
# index pair.
ingr_id_dict = {
    (i, j): cos(ingredient_embeddings[i], ingredient_embeddings[j]).item()
    for i in range(10)
    for j in range(i + 1, len(ingredient_embeddings))
}
# The same similarities keyed by ingredient names.
ingr_name_dict = {
    (ingr_name[i], ingr_name[j]): similarity
    for (i, j), similarity in ingr_id_dict.items()
}

# Most similar pairs first.
sorted_sim_dict = sorted(ingr_name_dict.items(), key=lambda entry: entry[1], reverse=True)
sorted_id_dict = sorted(ingr_id_dict.items(), key=lambda entry: entry[1], reverse=True)

top_100 = sorted_id_dict[:100]

# Count how many of the top 100 pairs co-occur fewer than 10 times according
# to adj_matrix (despite its name, `never_appear` here is a "rarely appear"
# count, not a zero-co-occurrence count).
never_appear = sum(
    1 for (first_id, second_id), _ in top_100
    if adj_matrix[first_id, second_id] < 10
)

print(never_appear)

5
