In [1]:
import openai
from dotenv import dotenv_values
import pandas as pd
import numpy as np
from tenacity import retry, wait_random_exponential, stop_after_attempt
import pickle
from nomic import atlas
from openai.embeddings_utils import distances_from_embeddings, indices_of_nearest_neighbors_from_distances

In [2]:
# Secrets are read from a local .env file so the key never appears in the
# notebook itself; the OpenAI key is stored there under the name "API".
config = dotenv_values(dotenv_path=".env")
openai.api_key = config["API"]

In [3]:
# Data set: Wikipedia Movie Plots from Kaggle
# (https://www.kaggle.com/datasets/jrobischon/wikipedia-movie-plots).
# The CSV is expected in the current working directory.
dataset_path = "./wiki_movie_plots.csv"
source_df = pd.read_csv(filepath_or_buffer=dataset_path)

In [4]:
# Keep only American films, newest first, capped at 5000 rows.
movies = (
    source_df.loc[source_df["Origin/Ethnicity"] == "American"]
    .sort_values(by="Release Year", ascending=False)
    .head(5000)
)

In [5]:
#get embedding function
@retry(wait=wait_random_exponential(min=1, max=10), stop=stop_after_attempt(6))
def get_embedding(text, model="text-embedding-ada-002"):
    """Request the embedding vector for ``text`` from the OpenAI API.

    Newlines in ``text`` are replaced with single spaces before the request.
    The ``retry`` decorator re-invokes the call on failure, waiting a random
    exponential interval (configured with min=1, max=10) between tries and
    stopping after 6 attempts.

    Args:
        text: The string to embed.
        model: Name of the OpenAI embedding model to use.

    Returns:
        The ``"embedding"`` entry of the first item in the response's
        ``"data"`` list.
    """
    single_line_text = " ".join(text.split("\n"))
    response = openai.Embedding.create(input=single_line_text, model=model)
    first_item = response["data"][0]
    return first_item["embedding"]

In [6]:
# Persistent cache of embeddings, kept on disk so that repeated notebook runs
# do not pay (in money or time) for the same API request twice.
embedding_cache_path = "movie_embeddings_cache.pkl"
try:
    embedding_cache = pd.read_pickle(embedding_cache_path)
except FileNotFoundError:
    # No cache file yet: start from an empty mapping.
    embedding_cache = dict()
# Write the cache back immediately so the file exists from this point on.
with open(embedding_cache_path, mode="wb") as embedding_cache_file:
    embedding_cache_file.write(pickle.dumps(embedding_cache))

In [7]:
#define a function to retrieve embeddings from the cache if present, 
#otherwise request via API

def embedding_from_cache_or_API(
    string,
    model="text-embedding-ada-002",
    embedding_cache=None
):
    """Return the embedding for ``string``, using the cache when possible.

    Cache entries are keyed by the ``(string, model)`` tuple. On a cache miss
    the embedding is fetched with ``get_embedding``, stored in the cache, and
    the whole cache is re-pickled to ``embedding_cache_path`` so the result
    survives a kernel restart.

    Args:
        string: The text to embed.
        model: Name of the embedding model; part of the cache key.
        embedding_cache: Mapping to read from and write to. When omitted
            (``None``), the notebook-level ``embedding_cache`` is used.

    Returns:
        The cached or freshly fetched embedding for ``(string, model)``.
    """
    if embedding_cache is None:
        # Look the shared cache up at call time rather than binding it as a
        # default at definition time: if the cache cell is re-run, a default
        # bound at `def` time would keep pointing at the old dict while the
        # rest of the notebook reads the new one.
        embedding_cache = globals()["embedding_cache"]

    cache_key = (string, model)
    if cache_key not in embedding_cache:
        embedding_cache[cache_key] = get_embedding(string, model)
        print("I have just got embeddings from openai for you!")
        # Persist after every new embedding so an interrupted run keeps
        # everything fetched so far.
        with open(embedding_cache_path, "wb") as embedding_cache_file:
            pickle.dump(embedding_cache, embedding_cache_file)
    return embedding_cache[cache_key]

In [8]:
# Embed every movie plot in row order, so that plot_embeddings[i] is the
# embedding of the i-th row of `movies`.
plot_embeddings = list(map(embedding_from_cache_or_API, movies["Plot"].values))

In [9]:
# Visualise the results: upload the plot embeddings to Nomic Atlas, attaching
# each movie's title and genre as the per-point record.
atlas_map = atlas.map_embeddings(
    embeddings=np.asarray(plot_embeddings),
    data=movies.loc[:, ["Title", "Genre"]].to_dict("records"),
)

2023-08-28 13:29:10.695 | INFO     | nomic.project:_create_project:779 - Creating project `harsh-being` in organization `wojtczakmart`
2023-08-28 13:29:11.759 | INFO     | nomic.atlas:map_embeddings:107 - Uploading embeddings to Atlas.
4it [00:04,  1.03s/it]
2023-08-28 13:29:16.110 | INFO     | nomic.project:_add_data:1411 - Upload succeeded.
2023-08-28 13:29:16.113 | INFO     | nomic.atlas:map_embeddings:126 - Embedding upload succeeded.
2023-08-28 13:29:17.808 | INFO     | nomic.project:create_index:1121 - Created map `harsh-being` in project `harsh-being`: https://atlas.nomic.ai/map/69b3aa10-a6bf-4814-9452-09c306fc9fde/2bb1c4e0-3c29-4a3b-906f-fa44d0c782d9
2023-08-28 13:29:17.809 | INFO

In [10]:
#basic movie recommendations

def recommendations(
    movie_title,
    k_nearest_neighbours=3
):
    """Recommend the movies whose plot embeddings are closest to a given movie.

    Relies on the notebook-level ``movies`` frame and ``plot_embeddings``
    list, which are positionally aligned: ``plot_embeddings[i]`` was built
    from the plot in the i-th row of ``movies``.

    Args:
        movie_title: Exact title to look up in ``movies["Title"]``. If several
            rows share the title, the first one is used.
        k_nearest_neighbours: Maximum number of titles to return.

    Returns:
        A list of up to ``k_nearest_neighbours`` titles ordered from nearest
        to farthest (distance as computed by ``distances_from_embeddings``),
        excluding the queried movie itself. If the title is not present, the
        string ``"no movie in database"`` is returned instead.
    """
    titles = movies["Title"].values
    title_matches = (movies["Title"] == movie_title).values
    if not title_matches.any():
        return "no movie in database"

    # argmax on a boolean array gives the position of the first True, i.e.
    # the first row with this title. Working with row positions avoids
    # searching the lists of embeddings for an equal vector.
    movie_index = int(title_matches.argmax())
    movie_embedding = plot_embeddings[movie_index]

    distances = distances_from_embeddings(movie_embedding, plot_embeddings)
    indices = indices_of_nearest_neighbors_from_distances(distances)

    matching_movies = []
    for i in indices:
        if i == movie_index:
            # The queried movie is always its own nearest neighbour.
            continue
        if len(matching_movies) >= k_nearest_neighbours:
            break
        # Row i of `movies` produced plot_embeddings[i], so its title can be
        # read directly by position.
        matching_movies.append(titles[i])
    return matching_movies


In [12]:
# Example query: nearest plots to "Wonder Woman" (default number of neighbours).
recommendations(movie_title='Wonder Woman')

['Batman v Superman: Dawn of Justice',
 'Hercules',
 'Professor Marston and the Wonder Women']