## Prepare the data

In [None]:
import pandas as pd
import numpy as np

In [97]:
df = pd.read_csv('../data/wiki_movie_plots_deduped.csv')
df.head()

Unnamed: 0,Release Year,Title,Origin/Ethnicity,Director,Cast,Genre,Wiki Page,Plot
0,1901,Kansas Saloon Smashers,American,Unknown,,unknown,https://en.wikipedia.org/wiki/Kansas_Saloon_Sm...,"A bartender is working at a saloon, serving dr..."
1,1901,Love by the Light of the Moon,American,Unknown,,unknown,https://en.wikipedia.org/wiki/Love_by_the_Ligh...,"The moon, painted with a smiling face hangs ov..."
2,1901,The Martyred Presidents,American,Unknown,,unknown,https://en.wikipedia.org/wiki/The_Martyred_Pre...,"The film, just over a minute long, is composed..."
3,1901,"Terrible Teddy, the Grizzly King",American,Unknown,,unknown,"https://en.wikipedia.org/wiki/Terrible_Teddy,_...",Lasting just 61 seconds and consisting of two ...
4,1902,Jack and the Beanstalk,American,"George S. Fleming, Edwin S. Porter",,unknown,https://en.wikipedia.org/wiki/Jack_and_the_Bea...,The earliest known adaptation of the classic f...


In [98]:
df = df.drop(['Director'], axis=1)

In [99]:
df = df.drop(['Cast'], axis=1)

In [100]:
df = df.drop(['Wiki Page'], axis=1)

In [101]:
df = df.drop(['Origin/Ethnicity'], axis=1)

In [102]:
df = df.rename(columns={"Release Year": "release_year", "Title": "title", "Genre": "genre", "Plot": "plot"})
df.head()

Unnamed: 0,release_year,title,genre,plot
0,1901,Kansas Saloon Smashers,unknown,"A bartender is working at a saloon, serving dr..."
1,1901,Love by the Light of the Moon,unknown,"The moon, painted with a smiling face hangs ov..."
2,1901,The Martyred Presidents,unknown,"The film, just over a minute long, is composed..."
3,1901,"Terrible Teddy, the Grizzly King",unknown,Lasting just 61 seconds and consisting of two ...
4,1902,Jack and the Beanstalk,unknown,The earliest known adaptation of the classic f...


In [103]:
category_counts = df['genre'].value_counts()
category_counts

unknown                          6083
drama                            5964
comedy                           4379
horror                           1167
action                           1098
                                 ... 
cbc-tv miniseries                   1
bio-drama                           1
national film board docudrama       1
cult drama                          1
horror romantic comedy              1
Name: genre, Length: 2265, dtype: int64

In [104]:
category_counts = df['release_year'].value_counts()
category_counts

2013    1021
2014     929
2012     874
2011     858
2010     825
        ... 
1906       3
1905       2
1903       2
1904       1
1902       1
Name: release_year, Length: 117, dtype: int64

In [105]:
# add new column so it can be used for the research

df['year_genre'] = list(zip(df['release_year'], df['genre']))
df.head()

Unnamed: 0,release_year,title,genre,plot,year_genre
0,1901,Kansas Saloon Smashers,unknown,"A bartender is working at a saloon, serving dr...","(1901, unknown)"
1,1901,Love by the Light of the Moon,unknown,"The moon, painted with a smiling face hangs ov...","(1901, unknown)"
2,1901,The Martyred Presidents,unknown,"The film, just over a minute long, is composed...","(1901, unknown)"
3,1901,"Terrible Teddy, the Grizzly King",unknown,Lasting just 61 seconds and consisting of two ...,"(1901, unknown)"
4,1902,Jack and the Beanstalk,unknown,The earliest known adaptation of the classic f...,"(1902, unknown)"


In [9]:
# text encoding steps

from transformers import BertTokenizer, BertModel

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained("bert-base-uncased")

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [107]:
from tqdm import tqdm

plots = list(df['plot'])  # A plot is a description of a movie

batch_size = 100  # i choose a batch of a size equal to 100 to avoid RAM insufissance
num_texts = len(plots)
encoded_texts = []
embeddings = []
for i in tqdm(range(0, num_texts, batch_size)):
    batch_texts = plots[i:i+batch_size]
    encoded_batch = [tokenizer.encode(text, max_length=125, padding='max_length', truncation=True, add_special_tokens=True) for text in batch_texts]
    encoded_texts.extend(encoded_batch)

100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 23.43it/s]


In [109]:
batch_size = 10 #i choose a batch size equal to 10 to avoid GPU insufissance


# Convert the encoded texts to tensors
input_ids = torch.tensor(encoded_texts)

embeddings = []
# Generate embeddings for the query and move it to the device
model.eval()
with torch.no_grad():
  for i in tqdm(range(0, num_texts, batch_size)):
      input_batch = input_ids[i:i+batch_size]
      outputs = model(input_batch)
      batch_embeddings = outputs.last_hidden_state.mean(dim=1)
      embeddings.extend(batch_embeddings)


100%|█████████████████████████████████████████████| 1/1 [00:01<00:00,  1.38s/it]


In [111]:
list_embeddings = []

for output in embeddings:
    list_embeddings.append(output.detach().numpy())

In [114]:
# add the encoded text into the data

df['encoded_text'] = list_embeddings

In [None]:
np.save('../data/wiki_movie_plots_deduped_encoded.npy', df)

## Load encoded dataset

In [3]:
# Load the .npy file
data = np.load('../data/wiki_movie_plots_deduped_encoded.npy', allow_pickle=True)

# Convert the data back to a dataframe
df = pd.DataFrame(data)

In [4]:
df.columns = ["release_year", "title", "genre", "plots", "year_genre", "encoded_text"]

In [5]:
df.head()

Unnamed: 0,release_year,title,genre,plots,year_genre,encoded_text
0,1901,Kansas Saloon Smashers,unknown,"A bartender is working at a saloon, serving dr...","(1901, unknown)","[-0.06952393, 0.21627466, 0.19784085, -0.29264..."
1,1901,Love by the Light of the Moon,unknown,"The moon, painted with a smiling face hangs ov...","(1901, unknown)","[-0.089487724, -0.17200117, 0.5052035, -0.0962..."
2,1901,The Martyred Presidents,unknown,"The film, just over a minute long, is composed...","(1901, unknown)","[-0.13208406, -0.05472481, 0.245642, -0.036767..."
3,1901,"Terrible Teddy, the Grizzly King",unknown,Lasting just 61 seconds and consisting of two ...,"(1901, unknown)","[-0.1028006, 0.3138553, 0.08195516, -0.1730907..."
4,1902,Jack and the Beanstalk,unknown,The earliest known adaptation of the classic f...,"(1902, unknown)","[-0.23109181, 0.056910552, 0.29452863, -0.2051..."


In [39]:
import numpy as np

from scipy.spatial.distance import cdist


def calculte_simalarity(filtered_df, query, k):
    
    desired_columns = ['title', 'release_year', 'genre']
    
    if query!= None:
    
        encoded_input = tokenizer.encode(query,  max_length=125, padding='max_length', truncation=True, return_tensors='pt')
        output = model(encoded_input)
        embedding_query = output.last_hidden_state.mean(dim=1).detach().numpy()
        
        # Convert text embeddings to a numpy array
        texts_encoded = list(filtered_df['encoded_text'])
        text_embeddings = np.array(texts_encoded)
        
        similarities = 1 - cdist(text_embeddings, embedding_query, metric='cosine').flatten()
        filtered_df['similarities'] = similarities
        filtered_df = filtered_df.nlargest(k, 'similarities')
    
        return filtered_df[desired_columns]
    
    else: 
        return filtered_df[desired_columns]
    

In [40]:
def filter_data(data = df, k=5, genre = None, release_year = None, query = None):
    
    list_years = list(data['release_year'].unique())
    list_genre = list(data['genre'].unique())
    
    if release_year != None:  
        if release_year in list_years:
            if genre != None:
                if genre in list_genre:
                    ### case one 1 : both genre and release_year present, we search (release_year, genre)
                    filtered_df = data[data['year_genre'] == (release_year, genre)]
                    return calculte_simalarity(filtered_df, query,  k)
                else:
                    raise Exception("Genre not found")
            else:
                ### case one 3 : genre absent and release_year present, we search (release_year, all)
                filtered_df = data[data['release_year'] == release_year]
                return calculte_simalarity(filtered_df, query,  k)   
        else:
            raise Exception("Year not found")   
    else:
        if genre != None:
            if genre in list_genre:
                ### case one 2 : genre present and release_year absent, we search (all, genre)
                filtered_df = data[data['genre'] == genre]
                return calculte_simalarity(filtered_df, query,  k)
            else:
                raise Exception("Genre not found")
        else:
            ### case one 4 : genre absent and release_year absent, we search (all, all)
            filtered_df = data
            return calculte_simalarity(filtered_df, query,  k) 


In [41]:
new_data = filter_data(data = df, k=4, query = "a movie about two people who have a toxic relationship")

In [42]:
import json

suggestions_dict = {}

for index, row in new_data.iterrows():
    suggestion_key = f"movie {index + 1}"
    suggestion_value = {
        "Title": row['title'],
        "year": row['release_year'],
        "genre": row['genre']
    }
    suggestions_dict[suggestion_key] = suggestion_value

json_data = json.dumps(suggestions_dict, indent=2)
suggestions_dict

{'movie 26': {'Title': 'The Lure of the Gown',
  'year': 1909,
  'genre': 'unknown'},
 'movie 25205': {'Title': 'Haathkadi',
  'year': 1982,
  'genre': 'family, thriller, drama'},
 'movie 17609': {'Title': 'Deadly', 'year': 1991, 'genre': 'crime'},
 'movie 7607': {'Title': 'Heaven and Earth Magic',
  'year': 1962,
  'genre': 'animated'}}