In [14]:
import os
import pandas as pd
import gradio as gr

import torch

from utils import utils

In [10]:
import warnings
warnings.filterwarnings("ignore")

In [12]:
model_path = os.path.join(".", 'models', 'binaries')

# Checkpoints are ordered by filename and the last one is used as the final model.
# NOTE(review): plain string sorting only tracks training order if the filenames
# sort that way (e.g. zero-padded epoch numbers) — confirm the naming scheme.
models = sorted(name for name in os.listdir(model_path) if name.endswith(".pth"))
if not models:
    # Fail with an actionable message instead of an opaque IndexError on models[-1].
    raise FileNotFoundError(f"No '.pth' checkpoint found in {model_path!r}")

final_model = models[-1]
final_model_path = os.path.join(model_path, final_model)

del models

In [13]:
# Load the trained checkpoint: a dict holding the model dimensions, its
# state_dict, and the userLabel_to_Id / movieLabel_to_Id mapping objects.
# torch.load unpickles the file, so only load checkpoints you produced yourself.
# NOTE(review): PyTorch releases that default to weights_only=True may reject the
# non-tensor mapping objects — pass weights_only=False explicitly if loading fails.
# map_location="cpu" lets a checkpoint saved from a GPU be opened on a CPU-only host.
embedding_state = torch.load(final_model_path, map_location="cpu")

n_users = embedding_state['n_users']
n_movies = embedding_state['n_movies']
n_factors = embedding_state['n_factors']

# Learned movie-side parameters, pulled straight out of the saved state_dict.
movie_bias = embedding_state['state_dict']['movie_bias.weight']
movie_factors = embedding_state['state_dict']['movie_factors.weight']

# Mappings between the model's contiguous labels and the dataset ids.
userLabel_to_Id = embedding_state["userLabel_to_Id"]
movieLabel_to_Id = embedding_state["movieLabel_to_Id"]

del embedding_state, final_model_path

In [15]:
resource_path = os.path.join(".", 'data', 'ml-25m')
movies_csv = os.path.join(resource_path, 'movies.csv')
links_csv = os.path.join(resource_path, 'links.csv')

movies = pd.read_csv(movies_csv)
# Genres are stored as one pipe-delimited string per movie; expand them into lists.
movies['genres'] = movies['genres'].str.split('|')

# Keep the IMDb/TMDb ids as strings so their leading zeros (e.g. "0114709") survive.
links = pd.read_csv(links_csv, dtype={'movieId': int, 'imdbId': str, 'tmdbId': str})
# Join explicitly on movieId rather than on whichever column names happen to be shared.
movies = movies.merge(links, on='movieId')
del links

In [16]:
def get_movie_id(label):
    """Translate one model movie label into its dataset ``movieId``.

    ``movieLabel_to_Id`` (loaded from the checkpoint) is called with a
    one-element list of labels; element 0 of its result is returned via
    ``.item(0)`` (presumably a NumPy array, so a plain Python scalar comes back).
    """
    mapped_ids = movieLabel_to_Id([label])
    return mapped_ids.item(0)

# Rebuild `movies` so that its rows and `label` column follow the model's label
# space (0 .. n_movies-1). Indexing by movieId once turns each lookup into a
# single keyed access, instead of three boolean scans of the whole frame per label.
movies_by_id = movies.set_index('movieId')
if not movies_by_id.index.is_unique:
    # A duplicated movieId would make .loc return several rows for one label.
    raise ValueError("movies/links merge produced duplicate movieId rows")

movie_records = []
for label in range(n_movies):
    movie_id = get_movie_id(label)
    if movie_id not in movies_by_id.index:
        raise KeyError(f"movieId {movie_id} (label {label}) not found in movies.csv")
    row = movies_by_id.loc[movie_id]
    movie_records.append({
        'label': label,
        'title': row['title'],
        'genres': row['genres'],
        'imdbId': row['imdbId'],
    })

movies = pd.DataFrame(movie_records, columns=['label', 'title', 'genres', 'imdbId'])
del movie_records, movies_by_id
movies

Unnamed: 0,label,title,genres,imdbId
0,0,Toy Story (1995),"[Adventure, Animation, Children, Comedy, Fantasy]",0114709
1,1,Jumanji (1995),"[Adventure, Children, Fantasy]",0113497
2,2,Heat (1995),"[Action, Crime, Thriller]",0113277
3,3,GoldenEye (1995),"[Action, Adventure, Thriller]",0113189
4,4,Ace Ventura: When Nature Calls (1995),[Comedy],0112281
...,...,...,...,...
195,195,Up (2009),"[Adventure, Animation, Children, Drama]",1049413
196,196,Avatar (2009),"[Action, Adventure, Sci-Fi, IMAX]",0499549
197,197,Inception (2010),"[Action, Crime, Drama, Mystery, Sci-Fi, Thrill...",1375666
198,198,Django Unchained (2012),"[Action, Drama, Western]",1853728


In [17]:
def get_movie_name(label):
    """Return the title of the movie whose model label equals ``label``.

    Looks the label up in the module-level ``movies`` frame; ``.item()`` raises
    ValueError unless exactly one row matches.
    """
    matching_titles = movies.loc[movies['label'] == label, 'title']
    return matching_titles.item()

In [56]:
# Hard-coded sample watch list that feeds the "My Movies" panel.
# NOTE(review): entries look like (movie label, rating) pairs; 6.8 is above the
# 0.5-5.0 MovieLens rating scale — confirm the intended values.
user_movies = [
    (30, 4.8),
    (2, 6.8),
]

In [60]:
def disp_movies(movies, poster_path="assets/posters/Shawshank_Redemption.jpg"):
    """Build the "My Movies" panel: a heading plus one image per listed movie.

    Parameters
    ----------
    movies : iterable of (label, rating) pairs
        Presumably model movie labels paired with the user's rating, as in
        ``user_movies`` — TODO confirm. Each label is resolved to a title with
        ``get_movie_name``.
    poster_path : str
        Image shown for every entry. There is no per-movie poster lookup yet,
        so the same placeholder poster is reused.

    Returns
    -------
    gr.Blocks
        The Blocks object wrapping the panel.
    """
    with gr.Blocks() as block:
        gr.Markdown('<h2 style="text-align: center;">My Movies</h2>')

        # Caption each image with that entry's own title and rating instead of a
        # single hard-coded movie.
        for label, rating in movies:
            gr.Image(
                value=poster_path,
                label=f"{get_movie_name(label)} | {rating}",
            )

    return block

def add_movie(title=None):
    """Click handler for the "Add Movie" button.

    The button is wired with ``inputs=choice``, so the dropdown's current value
    (a title from ``movie_list``) is passed as the single argument. The handler
    therefore accepts one optional argument.

    TODO: resolve ``title`` to a model label and add it to the user's list.
    Currently a no-op that returns None.
    """
    return None

In [61]:
# Alphabetised movie titles offered in the "Select Movie" dropdown.
# NOTE(review): titles are assumed unique; duplicates would be indistinguishable in the UI.
movie_list = sorted(movies['title'].tolist())

In [62]:
# Assemble the UI: a movie picker (dropdown + "Add Movie" button) followed by
# the "My Movies" panel rendered from `user_movies`.
with gr.Blocks(title="Neural Collaborative Filtering") as app:
    gr.Markdown('<h1 style="text-align: center;">Movie Recommendation System</h1>')
    gr.Markdown('<p style="text-align: center;">Project by: Alex Prateek Shankar (2248302) and Sneh Shah (2248318)</p>')

    # gr.Column is Gradio's layout container for stacking components vertically.
    with gr.Column():
        choice = gr.Dropdown(
            choices=movie_list,
            label="Select Movie",
        )
        # NOTE(review): `label` is not a documented gr.Button option in newer
        # Gradio releases — confirm it is accepted by the installed version.
        add_btn = gr.Button(
            "Add Movie",
            label="Click here to add movies to your list",
        )

    user = disp_movies(user_movies)
    # The dropdown's current value is passed to add_movie as its only argument.
    add_btn.click(fn=add_movie, inputs=choice, outputs=user)

app.launch()

Running on local URL:  http://127.0.0.1:7878

To create a public link, set `share=True` in `launch()`.


