In [40]:
import openai
from dotenv import dotenv_values
import pandas as pd
import numpy as np
from tenacity import retry, wait_random_exponential, stop_after_attempt
import pickle

In [30]:
# Read key/value settings from the local .env file; the OpenAI client key
# is taken from its "API" entry (a missing entry raises KeyError here).
config = dotenv_values(".env")
openai.api_key = config["API"]

In [36]:
# Load the movie plots dataset from a CSV in the working directory.
# Source: https://www.kaggle.com/datasets/jrobischon/wikipedia-movie-plots
dataset_path = "./wiki_movie_plots.csv"
source_df = pd.read_csv(dataset_path)

In [32]:
# Keep American titles only, newest release year first, capped at 5000 rows.
movies = (
    source_df[source_df["Origin/Ethnicity"] == "American"]
    .sort_values("Release Year", ascending=False)
    .head(5000)
)

In [39]:
# Embedding helper: retried with randomized exponential backoff
# (wait bounds min=1, max=10) and abandoned after 6 attempts.
@retry(wait=wait_random_exponential(min=1, max=10), stop=stop_after_attempt(6))
def get_embedding(text, model="text-embedding-ada-002"):
    """Return the embedding of ``text`` produced by the given model.

    Newlines in ``text`` are replaced with spaces before the request is sent.

    Args:
        text: String to embed.
        model: Name of the embedding model passed to the API.

    Returns:
        The ``"embedding"`` entry of the first item in the response's
        ``"data"`` list.
    """
    single_line = text.replace("\n", " ")
    response = openai.Embedding.create(input=single_line, model=model)
    first_item = response["data"][0]
    return first_item["embedding"]

In [53]:
# Persistent embedding cache: reusing stored vectors saves API cost and time.
embedding_cache_path = "movie_embeddings_cache.pkl"
# NOTE(review): unpickling can execute arbitrary code; only load a cache
# file that this notebook created.
try:
    with open(embedding_cache_path, "rb") as embedding_cache_file:
        embedding_cache = pickle.load(embedding_cache_file)
except FileNotFoundError:
    # No cache on disk yet: start with an empty one.
    embedding_cache = {}
# Write the cache straight back so the file exists from this point on.
with open(embedding_cache_path, "wb") as embedding_cache_file:
    pickle.dump(embedding_cache, embedding_cache_file)

In [54]:
# Inspect the cache contents (an empty dict when no cache file existed yet).
embedding_cache

{}

In [55]:
# Retrieve an embedding from the cache if present; otherwise request it
# via the API, store it, and persist the cache to disk.
def embedding_from_cache_or_API(
    string,
    model="text-embedding-ada-002",
    embedding_cache=embedding_cache
):
    """Return the embedding for ``string``, consulting the cache first.

    Cache entries are keyed by the ``(string, model)`` tuple. On a miss the
    embedding is fetched with ``get_embedding``, stored in the cache, and the
    whole cache is re-pickled to ``embedding_cache_path`` so the result
    survives a kernel restart.

    Args:
        string: Text to embed.
        model: Embedding model name; part of the cache key.
        embedding_cache: Dict used as the cache. The default is the
            module-level ``embedding_cache`` object bound when this cell is
            run, so entries added here are visible through that name.

    Returns:
        The cached or freshly fetched embedding for ``(string, model)``.
    """
    cache_key = (string, model)
    if cache_key not in embedding_cache:
        embedding_cache[cache_key] = get_embedding(string, model)
        print("I have just got embeddings from openai for you!")
        with open(embedding_cache_path, "wb") as embedding_cache_file:
            pickle.dump(embedding_cache, embedding_cache_file)
    # Return on both paths: previously the return sat inside the ``if``,
    # so a cache hit fell through and yielded None.
    return embedding_cache[cache_key]

In [None]:
# Embed each movie plot via the cache-or-API helper, in row order.
plot_embeddings = list(map(embedding_from_cache_or_API, movies["Plot"].values))

In [58]:
# Show the embedding returned for the first plot.
plot_embeddings[0]

[-0.0008034348138608038,
 -0.01881888136267662,
 -0.02088065631687641,
 0.0013638499658554792,
 -0.003950581420212984,
 0.014902893453836441,
 0.004258463624864817,
 -0.004092414863407612,
 -0.024256985634565353,
 -0.025488514453172684,
 -0.0013119596987962723,
 0.0014131458010524511,
 0.011990118771791458,
 -0.01803014986217022,
 -0.011464296840131283,
 -0.0011917471420019865,
 0.02735656499862671,
 0.010807019658386707,
 0.011574995703995228,
 -0.02927996590733528,
 -0.011180629953742027,
 0.013000249862670898,
 0.006939462386071682,
 -0.012052386999130249,
 -0.009229554794728756,
 0.0018057824345305562,
 0.021890787407755852,
 -0.015719301998615265,
 0.015318016521632671,
 -0.01403113640844822,
 0.0027830495964735746,
 -0.001439955784007907,
 -0.00337459915317595,
 -0.012965655885636806,
 -0.039464302361011505,
 -0.030442308634519577,
 0.005355078727006912,
 -0.011000743135809898,
 0.017033854499459267,
 0.01400346215814352,
 0.005977762397378683,
 0.005313566420227289,
 -0.02124042