In [1]:
import os
import pickle

import torch
import pandas as pd
import numpy as np
# The src star imports are kept ahead of the explicit third-party imports below so
# that, if a src module re-exports one of those names, the explicit import still wins.
from src.train import *
from src.processing import *
from src.models import *
from src.inference import *
from sklearn.model_selection import train_test_split
from transformers import BertTokenizer, BertModel
from transformers import DistilBertModel, DistilBertTokenizer
from torch.utils.data import Dataset, DataLoader

# Fix the numpy and torch RNG seeds so weight init, DataLoader shuffling and any
# numpy-based sampling are repeatable across runs.
np.random.seed(42)
torch.manual_seed(42)

# Run on the GPU when CUDA is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

# Return cached-but-unused CUDA allocator memory (a no-op when CUDA was never initialized).
torch.cuda.empty_cache()

Using device: cuda


In [3]:
# Training hyperparameters.
epochs = 7
batch_size = 32
lr = 0.0001

# Text backbone for training: the MiniLM sentence-transformer checkpoint, loaded as a BertModel.
vocab = Vocabulary()
tokenizer = BertTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
model = BertModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')

# Ratings, per-movie descriptions and metadata from the raw CSV archive.
ratings_df, movie_descriptions, movies_metadata = create_ratings_df(
    number_of_movies=7500,
    links_path='CLIP4Rec/archive/links.csv',
    movies_metadata_path='CLIP4Rec/archive/movies_metadata.csv',
    ratings_path='CLIP4Rec/archive/ratings.csv',
)
sequences = get_sequences(ratings_df)
# NOTE(review): the vocabulary is built from *all* sequences, before the train/val split.
vocab.build_vocab(sequences)

train_sentences, val_sentences = train_test_split(sequences, test_size=0.2, random_state=42)
# Only the training call requests encode_descriptions=True; the validation dataset
# reuses the description encoding produced here.
train_data, film_descriptions_encoded = prepare_dataset(
    train_sentences, movie_descriptions, tokenizer, vocab, encode_descriptions=True,
)
val_data = prepare_dataset(val_sentences, movie_descriptions, tokenizer, vocab)

train_dataset = FilmRecommendationDataset(train_data, film_descriptions_encoded)
val_dataset = FilmRecommendationDataset(val_data, film_descriptions_encoded)

# Shuffle only the training batches; validation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

  movies_metadata = pd.read_csv(movies_metadata_path)
100%|██████████| 865083/865083 [00:02<00:00, 294762.88it/s]
100%|██████████| 7315/7315 [00:07<00:00, 939.63it/s]
100%|██████████| 216271/216271 [00:00<00:00, 423374.40it/s]


In [8]:
# Film (interaction-sequence) tower and text tower, both emitting 384-wide vectors.
film_encoder = SASFilmEncoder(
    item_num=len(vocab.word_to_index),
    # NOTE(review): `seq_len` is not assigned in any visible cell; presumably it arrives
    # via one of the `from src.* import *` lines — confirm and define it explicitly.
    seq_len=seq_len,
    embed_dim=384,
    device=device,
)
# 384 presumably equals the backbone's hidden size, hence no extra projection layer — TODO confirm.
text_encoder = TextEncoder(model, output_dim=384, add_fc_layer=False)

In [7]:
# Jointly train both encoders. Metrics are reported every `iter_verbose` batches;
# `folder` is presumably where train_clip writes checkpoints — confirm in src.train.
train_clip(
    film_encoder,
    text_encoder,
    train_loader,
    val_loader,
    epochs=epochs,
    lr=lr,
    device=device,
    iter_verbose=1000,
    folder='CLIP4Rec/artifacts',
)

  7%|▋         | 1000/13517 [05:29<1:08:50,  3.03it/s]

Epoch 1, Batch 1000
Accuracy: 0.0077
Agreggated loss: 2.9500
Classification loss: 7.3598
Contrastive loss: 1.2889


 15%|█▍        | 2000/13517 [10:59<1:02:44,  3.06it/s]

Epoch 1, Batch 2000
Accuracy: 0.0161
Agreggated loss: 2.9035
Classification loss: 7.1817
Contrastive loss: 1.2326


 22%|██▏       | 3000/13517 [16:28<57:43,  3.04it/s]  

Epoch 1, Batch 3000
Accuracy: 0.0199
Agreggated loss: 2.8762
Classification loss: 7.0619
Contrastive loss: 1.2048


 30%|██▉       | 4000/13517 [21:57<51:58,  3.05it/s]

Epoch 1, Batch 4000
Accuracy: 0.0230
Agreggated loss: 2.8570
Classification loss: 6.9722
Contrastive loss: 1.1873


 37%|███▋      | 5000/13517 [27:26<46:51,  3.03it/s]

Epoch 1, Batch 5000
Accuracy: 0.0242
Agreggated loss: 2.8423
Classification loss: 6.9006
Contrastive loss: 1.1750


 44%|████▍     | 6000/13517 [32:57<41:26,  3.02it/s]  

Epoch 1, Batch 6000
Accuracy: 0.0249
Agreggated loss: 2.8308
Classification loss: 6.8434
Contrastive loss: 1.1659


 52%|█████▏    | 7000/13517 [38:27<35:49,  3.03it/s]

Epoch 1, Batch 7000
Accuracy: 0.0256
Agreggated loss: 2.8214
Classification loss: 6.7953
Contrastive loss: 1.1586


 59%|█████▉    | 8000/13517 [43:59<30:14,  3.04it/s]

Epoch 1, Batch 8000
Accuracy: 0.0275
Agreggated loss: 2.8136
Classification loss: 6.7548
Contrastive loss: 1.1530


 67%|██████▋   | 9000/13517 [49:32<24:59,  3.01it/s]

Epoch 1, Batch 9000
Accuracy: 0.0268
Agreggated loss: 2.8068
Classification loss: 6.7195
Contrastive loss: 1.1483


 74%|███████▍  | 10000/13517 [55:07<19:20,  3.03it/s]

Epoch 1, Batch 10000
Accuracy: 0.0288
Agreggated loss: 2.8010
Classification loss: 6.6888
Contrastive loss: 1.1443


 81%|████████▏ | 11000/13517 [1:00:36<13:48,  3.04it/s]

Epoch 1, Batch 11000
Accuracy: 0.0291
Agreggated loss: 2.7959
Classification loss: 6.6623
Contrastive loss: 1.1408


 89%|████████▉ | 12000/13517 [1:06:06<08:17,  3.05it/s]

Epoch 1, Batch 12000
Accuracy: 0.0307
Agreggated loss: 2.7911
Classification loss: 6.6360
Contrastive loss: 1.1378


 96%|█████████▌| 13000/13517 [1:11:35<02:51,  3.02it/s]

Epoch 1, Batch 13000
Accuracy: 0.0295
Agreggated loss: 2.7869
Classification loss: 6.6135
Contrastive loss: 1.1352


100%|██████████| 13517/13517 [1:14:25<00:00,  3.03it/s]


Epoch 1, Batch 13517
Accuracy: 0.0319
Agreggated loss: 2.7849
Classification loss: 6.6025
Contrastive loss: 1.1339


100%|██████████| 3380/3380 [15:00<00:00,  3.75it/s]


Epoch 1: Val Loss: 2.6989, Val Accuracy: 0.0313
Val Classification loss: 6.3123
Val Contrastive loss: 1.0340


  7%|▋         | 1000/13517 [05:29<1:08:23,  3.05it/s]

Epoch 2, Batch 1000
Accuracy: 0.0344
Agreggated loss: 2.7205
Classification loss: 6.2296
Contrastive loss: 1.1019


 15%|█▍        | 2000/13517 [11:00<1:04:59,  2.95it/s]

Epoch 2, Batch 2000
Accuracy: 0.0350
Agreggated loss: 2.7192
Classification loss: 6.2238
Contrastive loss: 1.1010


 22%|██▏       | 3000/13517 [16:31<59:40,  2.94it/s]  

Epoch 2, Batch 3000
Accuracy: 0.0342
Agreggated loss: 2.7189
Classification loss: 6.2240
Contrastive loss: 1.1004


 30%|██▉       | 4000/13517 [22:07<52:21,  3.03it/s]  

Epoch 2, Batch 4000
Accuracy: 0.0353
Agreggated loss: 2.7179
Classification loss: 6.2196
Contrastive loss: 1.0994


 37%|███▋      | 5000/13517 [27:37<46:53,  3.03it/s]

Epoch 2, Batch 5000
Accuracy: 0.0355
Agreggated loss: 2.7171
Classification loss: 6.2163
Contrastive loss: 1.0988


 44%|████▍     | 6000/13517 [33:07<41:04,  3.05it/s]

Epoch 2, Batch 6000
Accuracy: 0.0353
Agreggated loss: 2.7164
Classification loss: 6.2137
Contrastive loss: 1.0980


 52%|█████▏    | 7000/13517 [38:36<35:47,  3.03it/s]

Epoch 2, Batch 7000
Accuracy: 0.0369
Agreggated loss: 2.7155
Classification loss: 6.2094
Contrastive loss: 1.0974


 59%|█████▉    | 8000/13517 [44:05<30:23,  3.03it/s]

Epoch 2, Batch 8000
Accuracy: 0.0383
Agreggated loss: 2.7147
Classification loss: 6.2052
Contrastive loss: 1.0969


 67%|██████▋   | 9000/13517 [49:39<25:54,  2.91it/s]

Epoch 2, Batch 9000
Accuracy: 0.0373
Agreggated loss: 2.7138
Classification loss: 6.2012
Contrastive loss: 1.0963


 74%|███████▍  | 10000/13517 [55:17<19:24,  3.02it/s]

Epoch 2, Batch 10000
Accuracy: 0.0383
Agreggated loss: 2.7132
Classification loss: 6.1978
Contrastive loss: 1.0959


 81%|████████▏ | 11000/13517 [1:00:56<14:26,  2.90it/s]

Epoch 2, Batch 11000
Accuracy: 0.0377
Agreggated loss: 2.7124
Classification loss: 6.1939
Contrastive loss: 1.0955


 89%|████████▉ | 12000/13517 [1:06:32<08:38,  2.92it/s]

Epoch 2, Batch 12000
Accuracy: 0.0393
Agreggated loss: 2.7115
Classification loss: 6.1895
Contrastive loss: 1.0949


 96%|█████████▌| 13000/13517 [1:12:11<02:52,  3.00it/s]

Epoch 2, Batch 13000
Accuracy: 0.0374
Agreggated loss: 2.7107
Classification loss: 6.1855
Contrastive loss: 1.0945


100%|██████████| 13517/13517 [1:15:02<00:00,  3.00it/s]


Epoch 2, Batch 13517
Accuracy: 0.0376
Agreggated loss: 2.7104
Classification loss: 6.1838
Contrastive loss: 1.0942


100%|██████████| 3380/3380 [15:00<00:00,  3.76it/s]


Epoch 2: Val Loss: 2.6702, Val Accuracy: 0.0380
Val Classification loss: 6.1515
Val Contrastive loss: 1.0208


  7%|▋         | 1000/13517 [05:30<1:08:28,  3.05it/s]

Epoch 3, Batch 1000
Accuracy: 0.0434
Agreggated loss: 2.6789
Classification loss: 5.9857
Contrastive loss: 1.0869


 15%|█▍        | 2000/13517 [10:58<1:02:49,  3.06it/s]

Epoch 3, Batch 2000
Accuracy: 0.0442
Agreggated loss: 2.6788
Classification loss: 5.9881
Contrastive loss: 1.0860


 22%|██▏       | 3000/13517 [16:29<58:01,  3.02it/s]  

Epoch 3, Batch 3000
Accuracy: 0.0427
Agreggated loss: 2.6796
Classification loss: 5.9938
Contrastive loss: 1.0860


 30%|██▉       | 4000/13517 [21:59<52:13,  3.04it/s]

Epoch 3, Batch 4000
Accuracy: 0.0442
Agreggated loss: 2.6798
Classification loss: 5.9963
Contrastive loss: 1.0856


 37%|███▋      | 5000/13517 [27:29<46:31,  3.05it/s]

Epoch 3, Batch 5000
Accuracy: 0.0439
Agreggated loss: 2.6799
Classification loss: 5.9990
Contrastive loss: 1.0850


 44%|████▍     | 6000/13517 [33:07<42:46,  2.93it/s]

Epoch 3, Batch 6000
Accuracy: 0.0459
Agreggated loss: 2.6799
Classification loss: 6.0002
Contrastive loss: 1.0848


 52%|█████▏    | 7000/13517 [38:47<37:11,  2.92it/s]

Epoch 3, Batch 7000
Accuracy: 0.0452
Agreggated loss: 2.6797
Classification loss: 5.9996
Contrastive loss: 1.0844


 59%|█████▉    | 8000/13517 [44:21<30:12,  3.04it/s]

Epoch 3, Batch 8000
Accuracy: 0.0444
Agreggated loss: 2.6796
Classification loss: 6.0004
Contrastive loss: 1.0839


 67%|██████▋   | 9000/13517 [49:56<25:47,  2.92it/s]

Epoch 3, Batch 9000
Accuracy: 0.0455
Agreggated loss: 2.6793
Classification loss: 6.0001
Contrastive loss: 1.0835


 74%|███████▍  | 10000/13517 [55:35<20:03,  2.92it/s]

Epoch 3, Batch 10000
Accuracy: 0.0435
Agreggated loss: 2.6794
Classification loss: 6.0016
Contrastive loss: 1.0832


 81%|████████▏ | 11000/13517 [1:01:14<14:13,  2.95it/s]

Epoch 3, Batch 11000
Accuracy: 0.0453
Agreggated loss: 2.6793
Classification loss: 6.0020
Contrastive loss: 1.0828


 89%|████████▉ | 12000/13517 [1:06:54<08:40,  2.91it/s]

Epoch 3, Batch 12000
Accuracy: 0.0440
Agreggated loss: 2.6791
Classification loss: 6.0021
Contrastive loss: 1.0825


 96%|█████████▌| 13000/13517 [1:12:32<02:49,  3.06it/s]

Epoch 3, Batch 13000
Accuracy: 0.0442
Agreggated loss: 2.6790
Classification loss: 6.0022
Contrastive loss: 1.0822


100%|██████████| 13517/13517 [1:15:22<00:00,  2.99it/s]


Epoch 3, Batch 13517
Accuracy: 0.0443
Agreggated loss: 2.6788
Classification loss: 6.0019
Contrastive loss: 1.0820


100%|██████████| 3380/3380 [14:57<00:00,  3.76it/s]


Epoch 3: Val Loss: 2.6577, Val Accuracy: 0.0424
Val Classification loss: 6.0962
Val Contrastive loss: 1.0116


  7%|▋         | 1000/13517 [05:29<1:08:31,  3.04it/s]

Epoch 4, Batch 1000
Accuracy: 0.0514
Agreggated loss: 2.6477
Classification loss: 5.8080
Contrastive loss: 1.0756


 15%|█▍        | 2000/13517 [10:57<1:03:02,  3.04it/s]

Epoch 4, Batch 2000
Accuracy: 0.0518
Agreggated loss: 2.6473
Classification loss: 5.8072
Contrastive loss: 1.0750


 22%|██▏       | 3000/13517 [16:25<57:39,  3.04it/s]  

Epoch 4, Batch 3000
Accuracy: 0.0514
Agreggated loss: 2.6483
Classification loss: 5.8135
Contrastive loss: 1.0752


 30%|██▉       | 4000/13517 [21:53<51:58,  3.05it/s]

Epoch 4, Batch 4000
Accuracy: 0.0508
Agreggated loss: 2.6490
Classification loss: 5.8192
Contrastive loss: 1.0750


 37%|███▋      | 5000/13517 [27:22<46:53,  3.03it/s]

Epoch 4, Batch 5000
Accuracy: 0.0511
Agreggated loss: 2.6493
Classification loss: 5.8222
Contrastive loss: 1.0747


 44%|████▍     | 6000/13517 [32:50<40:58,  3.06it/s]

Epoch 4, Batch 6000
Accuracy: 0.0505
Agreggated loss: 2.6497
Classification loss: 5.8261
Contrastive loss: 1.0744


 52%|█████▏    | 7000/13517 [38:19<35:32,  3.06it/s]  

Epoch 4, Batch 7000
Accuracy: 0.0503
Agreggated loss: 2.6505
Classification loss: 5.8311
Contrastive loss: 1.0744


 59%|█████▉    | 8000/13517 [43:54<30:22,  3.03it/s]

Epoch 4, Batch 8000
Accuracy: 0.0497
Agreggated loss: 2.6506
Classification loss: 5.8329
Contrastive loss: 1.0740


 67%|██████▋   | 9000/13517 [49:23<24:48,  3.03it/s]

Epoch 4, Batch 9000
Accuracy: 0.0515
Agreggated loss: 2.6509
Classification loss: 5.8358
Contrastive loss: 1.0738


 74%|███████▍  | 10000/13517 [54:52<19:30,  3.01it/s]

Epoch 4, Batch 10000
Accuracy: 0.0496
Agreggated loss: 2.6511
Classification loss: 5.8380
Contrastive loss: 1.0736


 81%|████████▏ | 11000/13517 [1:00:21<13:47,  3.04it/s]

Epoch 4, Batch 11000
Accuracy: 0.0502
Agreggated loss: 2.6513
Classification loss: 5.8404
Contrastive loss: 1.0732


 89%|████████▉ | 12000/13517 [1:05:50<08:16,  3.06it/s]

Epoch 4, Batch 12000
Accuracy: 0.0494
Agreggated loss: 2.6517
Classification loss: 5.8442
Contrastive loss: 1.0729


 96%|█████████▌| 13000/13517 [1:11:24<02:49,  3.05it/s]

Epoch 4, Batch 13000
Accuracy: 0.0499
Agreggated loss: 2.6518
Classification loss: 5.8460
Contrastive loss: 1.0727


100%|██████████| 13517/13517 [1:14:14<00:00,  3.03it/s]


Epoch 4, Batch 13517
Accuracy: 0.0495
Agreggated loss: 2.6518
Classification loss: 5.8468
Contrastive loss: 1.0725


100%|██████████| 3380/3380 [14:58<00:00,  3.76it/s]


Epoch 4: Val Loss: 2.6520, Val Accuracy: 0.0445
Val Classification loss: 6.0817
Val Contrastive loss: 1.0043


  7%|▋         | 1000/13517 [05:28<1:08:34,  3.04it/s]

Epoch 5, Batch 1000
Accuracy: 0.0620
Agreggated loss: 2.6144
Classification loss: 5.6090
Contrastive loss: 1.0681


 15%|█▍        | 2000/13517 [10:57<1:03:03,  3.04it/s]

Epoch 5, Batch 2000
Accuracy: 0.0611
Agreggated loss: 2.6172
Classification loss: 5.6280
Contrastive loss: 1.0681


 22%|██▏       | 3000/13517 [16:27<57:23,  3.05it/s]  

Epoch 5, Batch 3000
Accuracy: 0.0603
Agreggated loss: 2.6185
Classification loss: 5.6383
Contrastive loss: 1.0674


 30%|██▉       | 4000/13517 [21:56<52:04,  3.05it/s]  

Epoch 5, Batch 4000
Accuracy: 0.0576
Agreggated loss: 2.6200
Classification loss: 5.6486
Contrastive loss: 1.0673


 37%|███▋      | 5000/13517 [27:25<46:33,  3.05it/s]

Epoch 5, Batch 5000
Accuracy: 0.0591
Agreggated loss: 2.6206
Classification loss: 5.6543
Contrastive loss: 1.0669


 44%|████▍     | 6000/13517 [32:53<41:04,  3.05it/s]

Epoch 5, Batch 6000
Accuracy: 0.0585
Agreggated loss: 2.6214
Classification loss: 5.6599
Contrastive loss: 1.0668


 52%|█████▏    | 7000/13517 [38:21<35:39,  3.05it/s]

Epoch 5, Batch 7000
Accuracy: 0.0585
Agreggated loss: 2.6218
Classification loss: 5.6636
Contrastive loss: 1.0665


 59%|█████▉    | 8000/13517 [43:50<30:17,  3.04it/s]

Epoch 5, Batch 8000
Accuracy: 0.0555
Agreggated loss: 2.6228
Classification loss: 5.6707
Contrastive loss: 1.0663


 67%|██████▋   | 9000/13517 [49:18<24:39,  3.05it/s]

Epoch 5, Batch 9000
Accuracy: 0.0568
Agreggated loss: 2.6233
Classification loss: 5.6747
Contrastive loss: 1.0661


 74%|███████▍  | 10000/13517 [54:52<19:33,  3.00it/s]

Epoch 5, Batch 10000
Accuracy: 0.0577
Agreggated loss: 2.6239
Classification loss: 5.6795
Contrastive loss: 1.0658


 81%|████████▏ | 11000/13517 [1:00:25<13:52,  3.02it/s]

Epoch 5, Batch 11000
Accuracy: 0.0561
Agreggated loss: 2.6244
Classification loss: 5.6838
Contrastive loss: 1.0656


 89%|████████▉ | 12000/13517 [1:05:54<08:18,  3.04it/s]

Epoch 5, Batch 12000
Accuracy: 0.0555
Agreggated loss: 2.6248
Classification loss: 5.6876
Contrastive loss: 1.0653


 96%|█████████▌| 13000/13517 [1:11:23<02:50,  3.04it/s]

Epoch 5, Batch 13000
Accuracy: 0.0562
Agreggated loss: 2.6254
Classification loss: 5.6919
Contrastive loss: 1.0651


100%|██████████| 13517/13517 [1:14:13<00:00,  3.04it/s]


Epoch 5, Batch 13517
Accuracy: 0.0556
Agreggated loss: 2.6255
Classification loss: 5.6933
Contrastive loss: 1.0649


100%|██████████| 3380/3380 [14:58<00:00,  3.76it/s]


Epoch 5: Val Loss: 2.6517, Val Accuracy: 0.0452
Val Classification loss: 6.1001
Val Contrastive loss: 0.9987


In [8]:
# Persist the datasets and preprocessing outputs so the inference section can run in
# a fresh kernel without repeating the multi-hour preprocessing/training above.
# NOTE(review): encoder weights are not saved here. The inference cells load
# `*_final_4.pth` checkpoints, presumably written by `train_clip(..., folder=...)` — confirm.
artifacts_dir = 'CLIP4Rec/artifacts'
os.makedirs(artifacts_dir, exist_ok=True)  # don't fail if the folder does not exist yet

torch.save(train_dataset, f'{artifacts_dir}/train_dataset.pt')
torch.save(val_dataset, f'{artifacts_dir}/val_dataset.pt')

# One pickle per object; the file stem is the dict key (insertion order = write order).
pickled_artifacts = {
    'ratings_df': ratings_df,
    'movie_descriptions': movie_descriptions,
    'sequences': sequences,
    'vocab': vocab,
    'film_descriptions_encoded': film_descriptions_encoded,
    'movies_metadata': movies_metadata,
}
for artifact_name, artifact in pickled_artifacts.items():
    with open(f'{artifacts_dir}/{artifact_name}.pickle', 'wb') as f:
        pickle.dump(artifact, f)

In [2]:
# Example titles, presumably intended as ad-hoc queries for the recommender.
# NOTE(review): `list_movies` is not referenced by any cell shown here.
# An alternative set tried earlier:
#   "Only Lovers Left Alive", "The Twilight Saga: Eclipse",
#   "Me Before You", "(500) Days of Summer"
list_movies = [
    "Minions",
    "Zootopia",
    "Shrek",
    "Kung Fu Panda",
]

In [3]:
# Retrieval / index settings. `dim` has to equal the width of the encoder outputs.
dim = 384
num_trees = 10
search_type = 'euclidean'

# Reload the artifacts pickled after training.
# NOTE(review): these paths are relative to the working directory and use 'artifacts/',
# while the save cell writes to 'CLIP4Rec/artifacts/'. This only lines up when the kernel
# runs from inside CLIP4Rec/. Consider a single shared ARTIFACTS_DIR constant.
vocab = pd.read_pickle('artifacts/vocab.pickle')
movies_metadata = pd.read_pickle('artifacts/movies_metadata.pickle')
film_descriptions_encoded = pd.read_pickle('artifacts/film_descriptions_encoded.pickle')

# NOTE(review): training used 'sentence-transformers/all-MiniLM-L6-v2' (BertModel) as the
# text backbone, but inference uses 'distilbert-base-uncased'. The `*_final_4.pth`
# checkpoints loaded later presumably come from a DistilBERT-based run not shown in this
# notebook. Confirm, or use the same backbone in both places.
bert_model = DistilBertModel.from_pretrained('distilbert-base-uncased')
bert_tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')



tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

In [4]:
# Build and save: embed every film with the trained encoders, index the embeddings
# with Annoy, write the indexes to disk and attach them to `inference`.
text_index_path = 'artifacts/text_index.ann'
film_index_path = 'artifacts/film_index.ann'
idx_to_movie_id_path = 'artifacts/idx_to_movieId.pickle'

inference = Inference(
    film_encoder_path='artifacts/film_encoder_weights_final_4.pth',
    text_encoder_path='artifacts/text_encoder_weights_final_4.pth',
    vocab=vocab,
    dim=dim,
    movies_metadata=movies_metadata,
    seq_len=seq_len,
    device=device,
    bert_model=bert_model,
    bert_tokenizer=bert_tokenizer,
)

film_embeddings, text_embeddings = inference.get_embeddings(film_descriptions_encoded, batch_size=32)

annoy_model = AnnoySearchEngine(
    dim=dim,
    num_trees=num_trees,
    search_type=search_type,
)
annoy_model.build_trees(film_embeddings, text_embeddings)
annoy_model.save_indexes(text_index_path, film_index_path, idx_to_movie_id_path)
# Pass the configured tree count used for building instead of a separate hard-coded 10.
inference.init_annoy_model(text_index_path, film_index_path, idx_to_movie_id_path, num_trees=num_trees)

100%|██████████| 229/229 [00:08<00:00, 28.22it/s]
100%|██████████| 7314/7314 [00:01<00:00, 7024.77it/s]


In [5]:
# Sanity check: use a known movie's plot overview as the free-text query.
query_title = "Kung Fu Panda"
# Boolean mask instead of a string-built .query(): titles containing quotes stay safe.
matching_overviews = movies_metadata.loc[movies_metadata['title'] == query_title, 'overview']
if matching_overviews.empty:
    raise KeyError(f"No movie titled {query_title!r} in movies_metadata")
overview = matching_overviews.values[0]
inference.search_text(overview, in_films=True)

['Kung Fu Panda',
 'Kung Fu Panda 2',
 'The Mermaid',
 'Kung Fu Dunk',
 'The Man with the Iron Fists 2',
 'Girls Against Boys',
 'Shanghai Knights',
 'Clean',
 'Rise: Blood Hunter',
 'Saving Mr. Wu']

In [11]:
# Load and init: rebuild the inference wrapper and attach the Annoy indexes already
# saved to disk, without recomputing embeddings.
inference = Inference(
    film_encoder_path='artifacts/film_encoder_weights_final_4.pth',
    text_encoder_path='artifacts/text_encoder_weights_final_4.pth',
    vocab=vocab,
    dim=dim,  # same configured width the index was built with, not a separate literal
    movies_metadata=movies_metadata,
    seq_len=seq_len,
    device=device,
    bert_model=bert_model,
    bert_tokenizer=bert_tokenizer,
)

inference.init_annoy_model(
    'artifacts/text_index.ann',
    'artifacts/film_index.ann',
    'artifacts/idx_to_movieId.pickle',
    num_trees=num_trees,
)

In [8]:
# Query the index with another known movie's plot overview.
query_title = "Megamind"
# Boolean mask instead of a string-built .query(): titles containing quotes stay safe.
matching_overviews = movies_metadata.loc[movies_metadata['title'] == query_title, 'overview']
if matching_overviews.empty:
    raise KeyError(f"No movie titled {query_title!r} in movies_metadata")
overview = matching_overviews.values[0]
inference.search_text(overview, in_films=True)

['Megamind',
 'Despicable Me 2',
 'Teenage Mutant Ninja Turtles: Out of the Shadows',
 'The Lego Movie',
 'Teen Titans: Trouble in Tokyo',
 'LEGO DC Comics Super Heroes: Justice League: Attack of the Legion of Doom!',
 'Superman/Shazam!: The Return of Black Adam',
 'The SpongeBob SquarePants Movie',
 'In the Name of the King 2: Two Worlds',
 'Fantastic 4: Rise of the Silver Surfer']