In [2]:
import zipfile
import pandas as pd
import movie_utils
from tqdm import tqdm
import numpy as np

%load_ext autoreload
%autoreload 2

In [51]:
# data from: https://grouplens.org/datasets/movielens/

files = {}

with zipfile.ZipFile("ml-32m.zip", 'r') as zip:
    zip_contents = zip.namelist()
    for file_name in zip_contents:
        if file_name.endswith('.csv'):
            print("Downloading " + file_name + "...")
            with zip.open(file_name) as file:
                df = pd.read_csv(file)
                files[file_name[7:-4]] = df

Downloading ml-32m/tags.csv...
Downloading ml-32m/links.csv...
Downloading ml-32m/ratings.csv...


ParserError: Error tokenizing data. C error: Calling read(nbytes) on source failed. Try engine='python'.

In [None]:
files["links"] = files["links"][["movieId", "tmdbId"]]
files["links"].head()

Unnamed: 0,movieId,tmdbId
0,1,862.0
1,2,8844.0
2,3,15602.0
3,4,31357.0
4,5,11862.0


In [None]:
files["ratings"] = files["ratings"][["userId", "movieId", "rating"]]
files["ratings"].head()

Unnamed: 0,userId,movieId,rating
0,1,17,4.0
1,1,25,1.0
2,1,29,2.0
3,1,30,5.0
4,1,32,5.0


In [None]:
files["movies"] = files["movies"].merge(files["links"], left_on="movieId", right_on="movieId", how="inner")
files["movies"].head()

Unnamed: 0,movieId,title,genres,tmdbId
0,1,Toy Story (1995),Adventure|Animation|Children|Comedy|Fantasy,862.0
1,2,Jumanji (1995),Adventure|Children|Fantasy,8844.0
2,3,Grumpier Old Men (1995),Comedy|Romance,15602.0
3,4,Waiting to Exhale (1995),Comedy|Drama|Romance,31357.0
4,5,Father of the Bride Part II (1995),Comedy,11862.0


In [None]:
movies = {}

for movie in files["movies"].iterrows():
    movie = movie[1]
    movies[movie.movieId] = movie_utils.Movie(movie.movieId, movie.tmdbId, movie.title, movie.genres)

In [None]:
user = files["ratings"][files["ratings"]["userId"] == 2024][:5]
for movie in user.iterrows():
    if movie[1].movieId in movies.keys():
        print("-----------------------------------------------------------------------------------------------")
        print(movies[movie[1].movieId])
        print("Rating: " + str(movie[1].rating))

-----------------------------------------------------------------------------------------------
Clerks 	1994
Comedy
Rating: 1.0
-----------------------------------------------------------------------------------------------
Exotica 	1994
Drama
Rating: 2.0
-----------------------------------------------------------------------------------------------
Red Firecracker, Green Firecracker (Pao Da Shuang Deng) 	1994
Drama
Rating: 1.0
-----------------------------------------------------------------------------------------------
Maltese Falcon, The 	1941
Film-Noir | Mystery
Rating: 5.0
-----------------------------------------------------------------------------------------------
Gone with the Wind 	1939
Drama | Romance | War
Rating: 5.0


In [None]:
popular_movies = files["ratings"][["movieId", "userId"]].groupby("movieId").count()
popular_movies = popular_movies.sort_values(by=["userId"], ascending=False)
popular_movies = popular_movies.index.values.tolist()

num_movies = 1000
top_popular_movies = set(popular_movies[:num_movies])

In [None]:
total = len(files["ratings"])
ratings = np.empty((total, 3), dtype=np.float32)

i = 0
for rating in tqdm(files["ratings"].iterrows(), total=total):
    if rating[1].movieId in top_popular_movies:
        ratings[i, 0] = rating[1].userId
        ratings[i, 1] = rating[1].movieId
        ratings[i, 2] = rating[1].rating
        i += 1

ratings = ratings[:i]
np.savez_compressed("ratings.npz", ratings=ratings)

100%|██████████| 32000204/32000204 [31:23<00:00, 16992.68it/s]


In [3]:
ratings = np.load("ratings.npz")["ratings"]

In [4]:
ratings_df = pd.DataFrame(ratings, columns=["userId", "movieId", "rating"], index=[i for i in range(len(ratings))])
ratings_df["userId"] = ratings_df["userId"].astype(int)
ratings_df["movieId"] = ratings_df["movieId"].astype(int)
ratings_df.head()

Unnamed: 0,userId,movieId,rating
0,1,17,4.0
1,1,25,1.0
2,1,29,2.0
3,1,32,5.0
4,1,34,2.0


In [5]:
ratings_df = ratings_df.pivot_table(index="userId", columns="movieId", values="rating")
ratings_df

movieId,1,2,3,5,6,7,10,11,16,17,...,168252,171763,174055,176371,177765,187593,195159,202439,204698,207313
userId,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1,,,,,,,,,,4.0,...,,,,,,,,,,
2,,,,,,,,,,,...,,,,,,,,,,
3,,3.5,,,,,4.0,4.0,,5.0,...,,,,,,,,,,
4,,,,,,,,,,,...,,,,,,,,,,
5,,,,,,,4.0,,,,...,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
200944,4.0,,,,,,,,,,...,,,,,,,,,,
200945,,,,,,,,,,,...,,,0.5,,,,,,,
200946,,,,,4.0,,5.0,5.0,,4.0,...,,,,,,,,,,
200947,4.0,,,,,,,,,,...,,,,,,,,,,


$$X = \argmin_{M} ||M||_* \ \text{s.t.} \ M_{i,j} = X_{i,j} \ \forall i,j \in \Omega$$

In [13]:
def isvt(X, r=100, max_iter=1000, epsilon=1e-3):
    # Create a mask for the known entries in X
    mask = X > 0  
    prevM = np.zeros(X.shape)
    
    for _ in tqdm(range(max_iter)):
        currM = np.copy(prevM)
        
        # Fill in known values in currM from X
        currM[mask] = X[mask]
        
        # Perform Singular Value Decomposition
        U, S, Vt = np.linalg.svd(currM, full_matrices=False)
        
        # Threshold the singular values
        S[r:] = 0  # Keep only the largest r singular values
        currM = U @ np.diag(S) @ Vt
        
        # Check for convergence
        if np.linalg.norm(currM - prevM, ord='fro') < epsilon:
            break
        
        prevM = currM
    
    return currM

M = isvt(ratings_df[:2000].to_numpy())

100%|██████████| 1000/1000 [08:18<00:00,  2.01it/s]


In [18]:
M_df = pd.DataFrame(M, columns=ratings_df.columns, index=ratings_df.index[:2000])
M_df

movieId,1,2,3,5,6,7,10,11,16,17,...,168252,171763,174055,176371,177765,187593,195159,202439,204698,207313
userId,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1,2.164781,-0.736521,-0.986958,-0.519609,0.313017,-2.986201,2.676430,-2.345338,2.831801,3.996643,...,-1.300362,0.707777,-3.160997,-1.167746,0.427714,0.027859,0.304869,-4.409649,-4.619329,-0.149587
2,0.518351,2.010121,1.683100,1.863158,-1.274878,1.905920,2.452264,3.593043,-0.993920,3.013841,...,0.085609,-1.475960,-0.989313,0.415424,1.156267,-0.406539,-1.336204,-0.811925,-1.152012,1.480641
3,1.065025,3.440601,0.962399,-0.213119,-0.163261,2.329854,3.991673,4.000907,-1.655516,4.973500,...,-0.760822,3.284715,0.439682,-2.541574,-0.566239,0.771754,-0.236175,0.686113,0.214525,0.278901
4,-1.842804,-0.551203,-0.031548,-1.439202,-0.168703,-0.036295,0.404704,0.466012,1.098388,-1.985857,...,0.817266,2.071612,0.693172,-0.281403,0.465751,1.539132,0.057819,-0.655612,0.046354,-0.832026
5,2.011834,2.726890,-1.262735,0.148570,0.881450,-1.634832,4.003888,0.952687,0.419154,-1.470957,...,-0.728689,1.207003,0.254113,0.922615,-0.458205,0.100580,-0.147686,0.288957,0.872301,1.506501
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1996,0.649980,0.267547,1.823791,0.636039,3.998799,3.723862,-1.052397,1.669182,2.009943,-0.030105,...,1.675939,1.182616,0.158811,1.440807,0.238469,2.641301,1.694531,2.099258,1.168780,1.580974
1997,3.482977,2.482660,2.273181,4.209023,3.285669,-3.199558,-3.389736,-1.410767,0.368105,3.974215,...,-1.422370,-2.460344,0.812725,1.326985,-1.701539,-3.766004,-0.842246,-0.618657,-2.308247,0.432442
1998,1.773803,1.473892,-0.241790,-0.676113,1.761199,0.616230,-0.242179,1.062913,-0.411605,-0.954007,...,-0.170257,0.254678,0.551301,0.002127,1.078059,0.369173,0.467259,0.937266,-0.452304,0.507646
1999,3.041702,-0.369703,0.349384,1.479712,2.002457,-0.763222,-0.284395,0.660517,-0.315693,-1.350113,...,1.159091,1.180774,2.076171,-0.181055,-0.144716,-0.644902,2.031424,3.879508,1.460747,1.233817
