In [11]:
from gensim.models import Word2Vec
from pathlib import Path
from sklearn.neighbors import NearestNeighbors
from tqdm import tqdm
from transformers import BertTokenizer, AutoModelForMaskedLM

import json
import torch

#### Load Pretrained Model

In [12]:
model_name = 'word2vec' # select between [bert, splade, word2vec]

bert_model_path = 'path/to/bert_model_folder'
splade_model_path = 'path/to/splade_model_folder'
word2vec_model_path = 'path/to/word2vec_model.bin'

# Instantiate the selected model. BERT and SPLADE are both loaded with a
# masked-LM head (SPLADE needs the vocabulary logits).
if model_name == 'bert':
    model = AutoModelForMaskedLM.from_pretrained(pretrained_model_name_or_path=bert_model_path)
elif model_name == 'splade':
    model = AutoModelForMaskedLM.from_pretrained(pretrained_model_name_or_path=splade_model_path)
elif model_name == 'word2vec':
    # NOTE(review): Word2Vec.load reads gensim's native save format; a `.bin` file
    # in the original word2vec C format would need
    # KeyedVectors.load_word2vec_format(..., binary=True) instead — confirm.
    model = Word2Vec.load(word2vec_model_path)
else:
    # Fail fast: otherwise `model` would be undefined, or silently stale from an
    # earlier run of this cell with a different model_name.
    raise ValueError(
        f"unknown model_name {model_name!r}; expected one of 'bert', 'splade', 'word2vec'"
    )

#### Generate Ingredient Embeddings

###### Helper Functions for BERT and SPLADE

In [13]:
def generate_dense_vector(text, tokenizer, model):
    """Return a dense embedding of ``text`` taken from the model's last hidden layer.

    The vector is the final-layer hidden state at sequence position 1, i.e. the
    token right after the leading special token (presumably [CLS] for a BERT
    tokenizer).
    NOTE(review): this assumes the ingredient is encoded as a single token; if
    WordPiece splits it, only the first sub-token's state is used — confirm.

    Args:
        text: the string to embed (here, one ingredient name).
        tokenizer: callable that returns model inputs as PyTorch tensors.
        model: transformers model that accepts ``output_hidden_states=True``.

    Returns:
        The selected hidden state as a (nested) Python list of floats.
    """
    encoded = tokenizer(text, return_tensors='pt')
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        final_layer = model(**encoded, output_hidden_states=True).hidden_states[-1]
    position_one_state = final_layer[:, 1, :]
    return position_one_state.squeeze().tolist()

def generate_sparse_vector(text, tokenizer, model):
    """Encode ``text`` as a SPLADE-style sparse vector over the model vocabulary.

    Each vocabulary weight is the maximum, over token positions that are not
    masked out by ``attention_mask``, of ``log(1 + ReLU(logit))``.

    Args:
        text: the string to encode (here, one ingredient name).
        tokenizer: callable returning PyTorch inputs that also expose
            ``.attention_mask`` (as a transformers BatchEncoding does).
        model: masked-LM model whose output exposes ``.logits``; the pooling
            below expects logits of shape (batch, seq_len, vocab_size).

    Returns:
        A list of floats with one weight per vocabulary entry (for a single text).
    """
    tokens = tokenizer(text, return_tensors='pt')
    # Inference only: without no_grad every call builds an autograd graph for the
    # full MLM head, wasting memory and time (generate_dense_vector does the same).
    with torch.no_grad():
        output = model(**tokens)
        # Zero out padding positions before max-pooling over the sequence axis.
        weights = torch.log(1 + torch.relu(output.logits)) * tokens.attention_mask.unsqueeze(-1)
        sparse_vector = torch.max(weights, dim=1)[0].squeeze()

    return sparse_vector.tolist()

###### Generate Embeddings

In [14]:
GENERATING_EMBEDDING_MESSAGE = f"generating ingredient embeddings using '{model_name}'"

vocab_file = 'path/to/bert-base-cased-vocab.txt' # bert-base-cased-vocab file path
used_ingredients_file = 'path/to/all_ingredients.json' # cleaned ingredients file path

with Path(used_ingredients_file).open() as f:
    used_ingredients = json.load(f)

# Parallel lists: ingredient_names[i] is the name whose vector is ingredient_embeddings[i].
ingredient_names = []
ingredient_embeddings = []

if model_name in ('bert', 'splade'):
    # Ingredient names are passed as never_split so the basic tokenizer keeps them whole.
    # NOTE(review): WordPiece may still split names that are absent from vocab_file — confirm.
    tokenizer = BertTokenizer(
        vocab_file=vocab_file,
        do_lower_case=False,
        max_len=128,
        never_split=used_ingredients,
    )
    embed = generate_dense_vector if model_name == 'bert' else generate_sparse_vector

    for ingredient in tqdm(used_ingredients, desc=GENERATING_EMBEDDING_MESSAGE):
        embedding = embed(text=ingredient, tokenizer=tokenizer, model=model)
        ingredient_names.append(ingredient)
        ingredient_embeddings.append(embedding)

elif model_name == 'word2vec':
    skipped_ingredients = []
    for ingredient in tqdm(used_ingredients, desc=GENERATING_EMBEDDING_MESSAGE):
        try:
            embedding = model.wv[ingredient]
        except KeyError:
            # Not in the word2vec vocabulary: no vector, so it is left out of the search.
            skipped_ingredients.append(ingredient)
            continue
        ingredient_names.append(ingredient)
        ingredient_embeddings.append(embedding)
    print(f'skipped {len(skipped_ingredients)} ingredients missing from the word2vec vocabulary')

# Preview a few embeddings; slicing avoids IndexError when fewer are available.
preview_top_n = 5
for name, embedding in zip(ingredient_names[:preview_top_n], ingredient_embeddings[:preview_top_n]):
    print(f'{name}: {embedding}')

generating ingredient embeddings using 'word2vec': 100%|██████████| 7006/7006 [00:00<00:00, 437606.76it/s]

salt: [ 5.290445    0.913867   -1.0863012   0.8623626  -0.4674062   1.4850291
 -4.449398    0.19326593  0.79840225 -3.7696817  -8.937332    5.396243
 -1.5825946   3.3089976   0.90930027  0.60248166  1.1192898   2.095074
 -0.53546184  1.9280357   6.0514913  -0.2811899  -6.5449524   1.2315501
 -2.2797735   2.9715595  -5.8430834   1.6902156   2.910913    5.693375
 -0.930088   -2.8069067  -0.13160013 -6.9677563  -0.65470976 -1.4642285
 -6.432981    0.97264045 -2.9176748  -7.265186   -4.413367   -4.1855135
  0.36481628 -4.518372   -3.1976647  -1.7128602   0.5032585   5.3611813
 -3.8773909  -0.6008178   3.7304778  -3.2806964   3.2925324  -5.261976
  1.495242   -0.7740897   3.0662649  -3.4892735  -2.044665    5.276503
 -0.08387772  1.8251109   1.8781763  -6.6783423  -3.9808152   0.18981819
  1.9181858  -1.3713095  -3.5619245  -1.8910278   2.3206096   1.0558263
 -0.85360146 -0.86032116 -0.766023    4.740972   -4.9475346  -0.9059874
 -2.213351    1.6304389  -2.3757186  -7.375565    3.9031565   




#### Search Similarity Between Ingredients

In [None]:
result_export_path = 'path/to/export/result.json'

n_neighbors = 10
assert len(ingredient_names) == len(ingredient_embeddings), 'names and embeddings must be parallel'

neighbors = NearestNeighbors(n_neighbors=n_neighbors, n_jobs=-1)
neighbors.fit(ingredient_embeddings)

# Request one extra neighbour because each ingredient's nearest neighbour is itself
# (it is filtered out by name below). Cap at the number of fitted samples, because
# kneighbors raises if asked for more neighbours than samples.
n_query = min(n_neighbors + 1, len(ingredient_embeddings))

# A single batched query for all ingredients instead of one kneighbors call per row.
distances, indices = neighbors.kneighbors(
    ingredient_embeddings,
    n_neighbors=n_query,
    return_distance=True,
)

results = {}
for i, (row_distances, row_indices) in enumerate(
    tqdm(zip(distances, indices), total=len(indices), desc='finding closest ingredients...')
):
    name = ingredient_names[i]
    # Map neighbour indices to (substitute name, distance). Distances are cast to a
    # plain Python float so json.dump never hits a numpy scalar type.
    substitutes_and_scores = [
        (ingredient_names[idx], float(dist))
        for dist, idx in zip(row_distances, row_indices)
        if ingredient_names[idx] != name
    ]
    results[name] = substitutes_and_scores[:n_neighbors]

# Preview a few results; slicing avoids IndexError when fewer are available.
preview_top_n = 5
for name in ingredient_names[:preview_top_n]:
    print(f'{name}: {results[name]}')

# save first-stage results
with Path(result_export_path).open('w') as file:
    json.dump(results, file, indent=2)

finding closest ingredients...:   0%|          | 0/6988 [00:00<?, ?it/s]

finding closest ingredients...: 100%|██████████| 6988/6988 [00:13<00:00, 537.14it/s]


salt: [('kosher_salt', 19.65060806274414), ('salt_black', 21.303600311279297), ('sea_salt', 21.42864227294922), ('garlic_salt', 22.803667068481445), ('coarse_salt', 25.142929077148438), ('butter_salt', 25.47328758239746), ('rosemary_salt', 25.72801399230957), ('seasoning_salt', 25.76896095275879), ('pepper_salt', 26.266878128051758), ('onion_salt', 26.509910583496094)]
butter: [('margarine', 17.67350196838379), ('unsalted_butter', 20.01671600341797), ('oleo', 20.376522064208984), ('softened_butter', 20.44024658203125), ('shortening', 20.800329208374023), ('butter_oil', 21.83688735961914), ('coconut_oil', 22.075387954711914), ('melted_butter', 22.215559005737305), ('crisco', 23.806915283203125), ('lard', 23.95610809326172)]
sugar: [('granulated_sugar', 18.866891860961914), ('brown_sugar', 18.93625259399414), ('white_sugar', 22.164108276367188), ('confectioner_sugar', 24.371112823486328), ('honey', 24.801908493041992), ('caster_sugar', 25.418663024902344), ('powdered_sugar', 25.850288391

#### Fuse the Two Results and Filter Thai Ingredients

In [16]:
def reciprocal_rank_fusion(list1, list2, *more_lists, k=0):
    """Fuse ranked candidate lists with averaged reciprocal rank fusion.

    Each list is ranked best-first. Its elements are sequences whose first entry
    is the candidate name, e.g. ``[name, distance]`` pairs as loaded from the
    first-stage JSON results; any further entries are ignored. A candidate at
    1-based position ``rank`` contributes ``1 / (k + rank)``. Contributions are
    summed across lists and divided by the number of lists.

    Args:
        list1, list2: the two ranked lists to fuse.
        *more_lists: optional additional ranked lists to fuse as well.
        k: non-negative smoothing constant added to every rank. The default 0
            keeps plain ``1 / rank`` scoring; the original RRF formulation uses 60.

    Returns:
        A list of ``(name, score)`` tuples sorted by descending score. Ties keep
        first-seen order (list1 before list2, and so on).

    Raises:
        ValueError: if ``k`` is negative.
    """
    if k < 0:
        raise ValueError(f'k must be non-negative, got {k!r}')

    ranked_lists = (list1, list2, *more_lists)
    reciprocal_ranks = {}  # insertion order matters: it breaks score ties below

    for ranked in ranked_lists:
        for rank, item in enumerate(ranked, start=1):
            name = item[0]
            reciprocal_ranks[name] = reciprocal_ranks.get(name, 0) + 1 / (k + rank)

    n_lists = len(ranked_lists)
    averaged = {name: score / n_lists for name, score in reciprocal_ranks.items()}

    # sorted() is stable, so equal scores keep first-seen order.
    return sorted(averaged.items(), key=lambda x: x[1], reverse=True)

In [19]:
# Two DIFFERENT first-stage result files (e.g. from two different models);
# pointing both at the same file would just fuse a result with itself.
result_1_path = 'path/to/result_1.json'
result_2_path = 'path/to/result_2.json'
merged_result_export_path = 'path/to/export/merged_result.json'

thai_ingredients_path = 'path/to/thai_ingredients.json'

with Path(result_1_path).open() as f:
    result_1 = json.load(f)

with Path(result_2_path).open() as f:
    result_2 = json.load(f)

with Path(thai_ingredients_path).open() as file:
    thai_ingredients = json.load(file)
thai_ingredients_set = set(thai_ingredients)

top_k = 5
subtitute_pairs = set()
for key in tqdm(result_1.keys(), desc='merging results...'):
    # only ingredients present in both results can be fused
    if key not in result_2:
        continue
    # fuse the 2 results into a single ranking
    merged_list = reciprocal_rank_fusion(result_1[key], result_2[key])
    # Keep only Thai ingredients among the top_k fused candidates. Names are
    # compared with underscores replaced by spaces (presumably the Thai list uses
    # spaces — confirm).
    for item, _score in merged_list[:top_k]:
        substitute = item.replace('_', ' ')
        if substitute in thai_ingredients_set:
            subtitute_pairs.add((key.replace('_', ' '), substitute))

subtitute_pairs = sorted(subtitute_pairs)

# Preview a few pairs; slicing avoids IndexError when fewer are available.
preview_top_n = 5
for pair in subtitute_pairs[:preview_top_n]:
    print(f'{pair}')

# save the final substitution results
with Path(merged_result_export_path).open('w') as file:
    json.dump(subtitute_pairs, file)

merging results...: 100%|██████████| 7006/7006 [00:00<00:00, 75720.66it/s]

('accompaniment', 'orange')
('accompaniment', 'shallot')
('ada', 'cassava')
('alcoholic_beverage', 'nightshade')
('allspice', 'cardamom')



