<a href="https://colab.research.google.com/github/KevinTheRainmaker/Recommendation_Algorithms/blob/main/colab/fastcampus/Recommender_using_NCF.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [62]:
import os
import pandas as pd
import numpy as np
from math import sqrt
from tqdm.notebook import tqdm

from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

## Dataset

- ratings.csv > train/test

In [63]:
path = '/content/drive/MyDrive/data/movielens'
ratings_df = pd.read_csv(os.path.join(path, 'ratings.csv'), encoding='utf-8')

print(ratings_df.shape)
print(ratings_df.head())

(100836, 4)
   userId  movieId  rating  timestamp
0       1        1     4.0  964982703
1       1        3     4.0  964981247
2       1        6     4.0  964982224
3       1       47     5.0  964983815
4       1       50     5.0  964982931


In [64]:
train_df, test_df = train_test_split(ratings_df, test_size = 0.2)

print(train_df.shape)
print(test_df.shape)

(80668, 4)
(20168, 4)


In [65]:
train_df.head()

Unnamed: 0,userId,movieId,rating,timestamp
30515,212,183897,5.0,1532361617
94964,599,112552,3.0,1498589282
38948,268,1273,5.0,940183103
19154,123,79132,4.0,1447269091
55589,368,1590,2.0,975830196


## Sparse Matrix 만들기
- train_df의 sparse mateix를 만드는 과정 (2가지)


### 1. 연산량이 많아 시간이 오래 걸리는 방법


In [66]:
user_ids = sorted(list(set(train_df['userId'].values)))
movie_ids = sorted(list(set(train_df['movieId'].values)))

print(f' Number of Users: {len(user_ids)}', '\n', f'Number of Movies: {len(movie_ids)}')

 Number of Users: 610 
 Number of Movies: 9017


In [67]:
# Empty sparse matrix를 DataFrame 형태로 준비

sparse_matrix = pd.DataFrame(index = movie_ids, columns = user_ids)

sparse_matrix

Unnamed: 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,...,571,572,573,574,575,576,577,578,579,580,581,582,583,584,585,586,587,588,589,590,591,592,593,594,595,596,597,598,599,600,601,602,603,604,605,606,607,608,609,610
1,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
2,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
3,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
4,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
5,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
193573,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
193579,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
193581,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
193583,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,


In [68]:
# train_df의 movieId를 기준으로 groupby
grouped = train_df.groupby('movieId')
grouped['userId'].count() # movie별 평점을 준 user 수

movieId
1         175
2          92
3          45
4           5
5          40
         ... 
193573      1
193579      1
193581      1
193583      1
193585      1
Name: userId, Length: 9017, dtype: int64

In [69]:
# sneak-pick the example
idx = 0
for movieId, group in grouped:
  if idx < 1:
    print(group)
    group_copied = group.transpose()
    group_copied.loc['userId'] = pd.to_numeric(group_copied.loc['userId'])
    print(group_copied)
    group_copied.columns = group_copied.loc['userId']
    print(group_copied)
    group_copied = group_copied.drop(['userId','movieId']).rename(index={'rating':movieId})
    print(group_copied)
    print(group_copied.columns)
    print(group_copied.index)
  else:
    continue
  idx += 1

       userId  movieId  rating   timestamp
17904     112        1     3.0  1442535639
18536     119        1     3.5  1435942468
79119     490        1     3.5  1328229305
41019     277        1     4.0   861812794
20336     135        1     4.0  1009691859
...       ...      ...     ...         ...
95101     600        1     2.5  1237764347
58096     381        1     3.5  1164383653
26092     182        1     4.0  1063289621
874         7        1     4.5  1106635946
38236     263        1     4.0   940384199

[175 rows x 4 columns]
                  17904         18536  ...         874          38236
userId     1.120000e+02  1.190000e+02  ...  7.000000e+00        263.0
movieId    1.000000e+00  1.000000e+00  ...  1.000000e+00          1.0
rating     3.000000e+00  3.500000e+00  ...  4.500000e+00          4.0
timestamp  1.442536e+09  1.435942e+09  ...  1.106636e+09  940384199.0

[4 rows x 175 columns]
userId            112.0         119.0  ...         7.0          263.0
userId     1.120

In [70]:
for movieId, group in tqdm(grouped):
  group_copied = group.transpose()
  group_copied.loc['userId'] = pd.to_numeric(group_copied.loc['userId'])
  group_copied.columns = group_copied.loc['userId']
  group_copied = group_copied.drop(['userId','movieId']).rename(index={'rating':movieId})

  sparse_matrix.update(group_copied)

  0%|          | 0/9017 [00:00<?, ?it/s]

In [71]:
sparse_matrix

Unnamed: 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,...,571,572,573,574,575,576,577,578,579,580,581,582,583,584,585,586,587,588,589,590,591,592,593,594,595,596,597,598,599,600,601,602,603,604,605,606,607,608,609,610
1,4,,,,4,,4.5,,,,,,,,2.5,,,3.5,,,3.5,,,,,,3,,,,5,3,3,,,,,,,,...,,4,5,,,,,,4,3,,,,5,,,5,,,4,,,,,,,,,3,2.5,4,,4,3,4,2.5,4,,,5
2,,,,,,4,,4,,,,,,,,,,,3,3,3.5,,,,,,4,,,,,,,,,,,,,,...,,,4.5,,,,,,,,,,,,,4,,,,2.5,,4,,4,,,,,2.5,4,,4,,5,3.5,,,2,,
3,4,,,,,,,,,,,,,,,,,,3,,,,,,,,,,,,,3,,,,,,,,,...,,,,,,,,,,,,,,,,,,3,,3,,,,4,,,,,1.5,,,,,,,,,2,,
4,,,,,,,,,,,,,,3,,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1.5,,,,,,,,,,
5,,,,,,5,,,,,,,,,,,,,,,,,,,,,,,,,3,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,2,,,,,,,,,,,,,,3,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
193573,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
193579,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
193581,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
193583,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,


### 2. 간단하고 효율적인 방법 - Unstack (계층적 인덱싱)

In [72]:
sparse_matrix = train_df.groupby('movieId').apply(lambda x: pd.Series(x['rating'].values, index = x['userId'])).unstack()
sparse_matrix.index.name = 'movieId'

sparse_matrix

userId,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,...,571,572,573,574,575,576,577,578,579,580,581,582,583,584,585,586,587,588,589,590,591,592,593,594,595,596,597,598,599,600,601,602,603,604,605,606,607,608,609,610
movieId,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1,Unnamed: 38_level_1,Unnamed: 39_level_1,Unnamed: 40_level_1,Unnamed: 41_level_1,Unnamed: 42_level_1,Unnamed: 43_level_1,Unnamed: 44_level_1,Unnamed: 45_level_1,Unnamed: 46_level_1,Unnamed: 47_level_1,Unnamed: 48_level_1,Unnamed: 49_level_1,Unnamed: 50_level_1,Unnamed: 51_level_1,Unnamed: 52_level_1,Unnamed: 53_level_1,Unnamed: 54_level_1,Unnamed: 55_level_1,Unnamed: 56_level_1,Unnamed: 57_level_1,Unnamed: 58_level_1,Unnamed: 59_level_1,Unnamed: 60_level_1,Unnamed: 61_level_1,Unnamed: 62_level_1,Unnamed: 63_level_1,Unnamed: 64_level_1,Unnamed: 65_level_1,Unnamed: 66_level_1,Unnamed: 67_level_1,Unnamed: 68_level_1,Unnamed: 69_level_1,Unnamed: 70_level_1,Unnamed: 71_level_1,Unnamed: 72_level_1,Unnamed: 73_level_1,Unnamed: 74_level_1,Unnamed: 75_level_1,Unnamed: 76_level_1,Unnamed: 77_level_1,Unnamed: 78_level_1,Unnamed: 79_level_1,Unnamed: 80_level_1,Unnamed: 81_level_1
1,4.0,,,,4.0,,4.5,,,,,,,,2.5,,,3.5,,,3.5,,,,,,3.0,,,,5.0,3.0,3.0,,,,,,,,...,,4.0,5.0,,,,,,4.0,3.0,,,,5.0,,,5.0,,,4.0,,,,,,,,,3.0,2.5,4.0,,4.0,3.0,4.0,2.5,4.0,,,5.0
2,,,,,,4.0,,4.0,,,,,,,,,,,3.0,3.0,3.5,,,,,,4.0,,,,,,,,,,,,,,...,,,4.5,,,,,,,,,,,,,4.0,,,,2.5,,4.0,,4.0,,,,,2.5,4.0,,4.0,,5.0,3.5,,,2.0,,
3,4.0,,,,,,,,,,,,,,,,,,3.0,,,,,,,,,,,,,3.0,,,,,,,,,...,,,,,,,,,,,,,,,,,,3.0,,3.0,,,,4.0,,,,,1.5,,,,,,,,,2.0,,
4,,,,,,,,,,,,,,3.0,,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1.5,,,,,,,,,,
5,,,,,,5.0,,,,,,,,,,,,,,,,,,,,,,,,,3.0,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,2.0,,,,,,,,,,,,,,3.0,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
193573,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
193579,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
193581,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
193583,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,


## Cosine Similarity 활용


In [73]:
from sklearn.metrics.pairwise import cosine_similarity

def cossim_matrix(a, b):
  cossim_values = cosine_similarity(a.values, b.values)
  cossim_df = pd.DataFrame(data = cossim_values, columns = a.index.values, index = a.index)

  return cossim_df

## Neighborhood-based Collaborative Filtering 추천 점수 계산

### Item-based

In [74]:
item_sparse_matrix = sparse_matrix.fillna(0) # 0이 아닌 다른 값으로 채우는 아이디어도 사용 가능

In [75]:
item_cossim_df = cossim_matrix(item_sparse_matrix, item_sparse_matrix)
item_cossim_df

Unnamed: 0_level_0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,34,36,38,39,41,42,43,44,...,184931,184997,185029,185033,185135,185435,185473,185585,186587,187031,187541,187593,187595,187717,188189,188301,188675,188751,188797,188833,189043,189111,189333,189381,189547,189713,190183,190207,190209,190213,190215,191005,193565,193567,193571,193573,193579,193581,193583,193585
movieId,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1,Unnamed: 38_level_1,Unnamed: 39_level_1,Unnamed: 40_level_1,Unnamed: 41_level_1,Unnamed: 42_level_1,Unnamed: 43_level_1,Unnamed: 44_level_1,Unnamed: 45_level_1,Unnamed: 46_level_1,Unnamed: 47_level_1,Unnamed: 48_level_1,Unnamed: 49_level_1,Unnamed: 50_level_1,Unnamed: 51_level_1,Unnamed: 52_level_1,Unnamed: 53_level_1,Unnamed: 54_level_1,Unnamed: 55_level_1,Unnamed: 56_level_1,Unnamed: 57_level_1,Unnamed: 58_level_1,Unnamed: 59_level_1,Unnamed: 60_level_1,Unnamed: 61_level_1,Unnamed: 62_level_1,Unnamed: 63_level_1,Unnamed: 64_level_1,Unnamed: 65_level_1,Unnamed: 66_level_1,Unnamed: 67_level_1,Unnamed: 68_level_1,Unnamed: 69_level_1,Unnamed: 70_level_1,Unnamed: 71_level_1,Unnamed: 72_level_1,Unnamed: 73_level_1,Unnamed: 74_level_1,Unnamed: 75_level_1,Unnamed: 76_level_1,Unnamed: 77_level_1,Unnamed: 78_level_1,Unnamed: 79_level_1,Unnamed: 80_level_1,Unnamed: 81_level_1
1,1.000000,0.356204,0.258644,0.051685,0.285916,0.298757,0.232935,0.113923,0.203869,0.319414,0.274204,0.209344,0.069130,0.140127,0.054982,0.235879,0.268742,0.166156,0.309659,0.137701,0.261251,0.213009,0.140108,0.185738,0.265818,0.154338,0.086183,0.090115,0.112044,0.073297,0.155728,0.403401,0.387761,0.232824,0.130024,0.269731,0.095629,0.081655,0.071161,0.161762,...,0.075552,0.0,0.048085,0.075552,0.0,0.000000,0.075552,0.058041,0.075552,0.040043,0.052461,0.036144,0.062243,0.028332,0.028332,0.085545,0.000000,0.0,0.075552,0.000000,0.0,0.0,0.000000,0.000000,0.0,0.028332,0.056664,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,0.356204,1.000000,0.263372,0.082033,0.245457,0.239511,0.220680,0.114164,0.042964,0.305756,0.264038,0.144051,0.107763,0.117917,0.209696,0.217603,0.158971,0.186561,0.413840,0.130783,0.242969,0.187446,0.122403,0.172323,0.156860,0.163100,0.089564,0.099554,0.080534,0.028565,0.243564,0.301026,0.339713,0.100177,0.151559,0.349095,0.145205,0.032919,0.103279,0.230335,...,0.000000,0.0,0.084327,0.000000,0.0,0.117774,0.000000,0.065973,0.000000,0.000000,0.000000,0.171534,0.113821,0.000000,0.000000,0.091966,0.103053,0.0,0.000000,0.103053,0.0,0.0,0.000000,0.103053,0.0,0.000000,0.000000,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,0.258644,0.263372,1.000000,0.000000,0.353171,0.241337,0.255900,0.282738,0.205649,0.194419,0.156695,0.181288,0.000000,0.174004,0.094878,0.205935,0.181728,0.216029,0.227122,0.103435,0.244612,0.106924,0.178234,0.274930,0.196407,0.115314,0.156762,0.123700,0.083605,0.000000,0.155389,0.264884,0.214173,0.195459,0.098621,0.223697,0.121760,0.069141,0.077472,0.152575,...,0.000000,0.0,0.000000,0.000000,0.0,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.078381,0.000000,0.000000,0.000000,0.000000,0.0,0.000000,0.000000,0.0,0.0,0.000000,0.000000,0.0,0.000000,0.000000,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,0.051685,0.082033,0.000000,1.000000,0.087317,0.000000,0.215731,0.000000,0.000000,0.082116,0.082446,0.000000,0.000000,0.156786,0.000000,0.000000,0.154445,0.000000,0.062904,0.000000,0.073358,0.066827,0.000000,0.000000,0.097415,0.000000,0.000000,0.082562,0.000000,0.000000,0.185953,0.075023,0.076579,0.147998,0.000000,0.085974,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.0,0.000000,0.000000,0.0,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.0,0.000000,0.000000,0.0,0.0,0.000000,0.000000,0.0,0.000000,0.000000,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
5,0.285916,0.245457,0.353171,0.087317,1.000000,0.253858,0.394374,0.240036,0.216803,0.214573,0.245731,0.194428,0.144973,0.230450,0.148802,0.138232,0.267693,0.188939,0.207004,0.056490,0.161736,0.109140,0.192735,0.237529,0.217959,0.148996,0.240981,0.021237,0.028132,0.000000,0.235807,0.241127,0.208234,0.203031,0.108844,0.198912,0.160893,0.000000,0.114003,0.092269,...,0.000000,0.0,0.000000,0.000000,0.0,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.043253,0.000000,0.000000,0.000000,0.000000,0.0,0.000000,0.000000,0.0,0.0,0.000000,0.000000,0.0,0.000000,0.000000,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
193573,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.0,0.000000,0.000000,0.0,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.0,0.000000,0.000000,0.0,0.0,0.752577,0.000000,0.0,0.000000,0.000000,0.0,0.0,0.0,0.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
193579,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.0,0.000000,0.000000,0.0,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.0,0.000000,0.000000,0.0,0.0,0.752577,0.000000,0.0,0.000000,0.000000,0.0,0.0,0.0,0.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
193581,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.0,0.000000,0.000000,0.0,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.0,0.000000,0.000000,0.0,0.0,0.752577,0.000000,0.0,0.000000,0.000000,0.0,0.0,0.0,0.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
193583,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.0,0.000000,0.000000,0.0,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.0,0.000000,0.000000,0.0,0.0,0.752577,0.000000,0.0,0.000000,0.000000,0.0,0.0,0.0,0.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0


In [76]:
# train_df에 포함된 userId를 계산에 반영
userId_grouped = train_df.groupby('userId')

# index: userId, columns: total movieId
item_prediction_result_df = pd.DataFrame(index=list(userId_grouped.indices.keys()), columns=item_sparse_matrix.index)
item_prediction_result_df

movieId,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,34,36,38,39,41,42,43,44,...,184931,184997,185029,185033,185135,185435,185473,185585,186587,187031,187541,187593,187595,187717,188189,188301,188675,188751,188797,188833,189043,189111,189333,189381,189547,189713,190183,190207,190209,190213,190215,191005,193565,193567,193571,193573,193579,193581,193583,193585
1,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
2,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
3,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
4,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
5,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
606,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
607,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
608,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
609,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,


In [77]:
for userId, group in tqdm(userId_grouped):
  # user가 rating한 movieId * 전체 movieId
  user_sim = item_cossim_df.loc[group['movieId']]

  # user가 rating한 movieId * 1
  user_rating = group['rating']

  # 전체 movieId * 1
  sim_sum = user_sim.sum(axis = 0)

  # userId의 전체 rating predictions
  pred_ratings = np.matmul(user_sim.T.to_numpy(), user_rating) / (sim_sum + 1)
  item_prediction_result_df.loc[userId] = pred_ratings

  0%|          | 0/610 [00:00<?, ?it/s]

In [78]:
item_prediction_result_df.head(10)

movieId,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,34,36,38,39,41,42,43,44,...,184931,184997,185029,185033,185135,185435,185473,185585,186587,187031,187541,187593,187595,187717,188189,188301,188675,188751,188797,188833,189043,189111,189333,189381,189547,189713,190183,190207,190209,190213,190215,191005,193565,193567,193571,193573,193579,193581,193583,193585
1,4.32992,4.29179,4.28919,3.64766,4.14181,4.31105,4.16272,4.2078,3.7878,4.27895,4.22555,4.13018,3.93975,3.92599,4.03956,4.31785,4.21498,4.24895,4.27264,4.18117,4.26741,4.21221,4.13756,4.23733,4.23424,4.17231,4.04424,3.8488,4.24423,3.92746,4.21512,4.32012,4.29671,4.12238,4.12977,4.24818,4.13904,4.09257,4.09878,4.19169,...,3.87174,3.25049,3.97122,3.87174,3.05371,3.76506,3.87174,4.02677,3.87174,3.38568,2.34364,4.08202,4.22639,3.36079,3.36079,4.07117,3.86541,2.93155,3.87174,3.86541,0.691272,0.691272,1.48763,3.86541,1.59966,3.36079,2.99982,0.691272,0.691272,0.691272,0.691272,0.419039,0.419039,0.419039,0.419039,0.419039,0.419039,0.419039,0.419039,0.419039
2,3.1684,3.12065,2.56872,0.484224,2.52962,3.04264,2.23454,2.18612,1.23657,2.89337,2.35578,2.24278,1.07938,0.758591,1.73998,3.16577,2.35823,2.79701,3.04824,2.71834,2.4903,2.27628,2.08238,2.62447,2.2844,2.43057,1.66461,1.35489,2.02984,0.272026,2.486,3.12028,2.88508,1.94403,2.31415,2.79317,1.73194,1.87267,1.88939,2.80505,...,1.92247,1.95651,2.71039,1.92247,2.53362,2.65163,1.92247,2.58961,1.92247,1.30653,1.60334,3.02558,2.84736,0.880975,0.880975,2.23268,2.19906,2.11586,1.92247,2.19906,0.191715,0.191715,1.11623,2.19906,0.374292,0.880975,1.436,0.191715,0.191715,0.191715,0.191715,1.25282,1.25282,1.25282,1.25282,1.25282,1.25282,1.25282,1.25282,1.25282
3,1.42816,1.30741,1.46397,0.230648,1.02848,1.59921,1.0519,0.931664,0.999505,1.43301,1.10924,1.3726,0.227519,0.435764,0.738918,1.17468,0.831898,1.16678,1.25327,1.12842,1.23708,1.16165,0.790596,1.17251,1.06324,0.930928,0.610863,0.29247,1.19459,1.14356,1.19561,1.5984,1.44093,0.998397,0.838927,1.12574,0.883757,1.23427,0.694154,1.56666,...,1.04399,0.700188,0.769108,1.04399,0.0,0.108114,1.04399,1.20742,1.04399,0.712029,0.0265147,0.831125,1.0768,0.0,0.0,1.07329,0.764034,0.0429988,1.04399,0.764034,0.0,0.0,0.0,0.764034,0.0,0.0,0.107577,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,3.44369,3.37624,3.35247,3.36092,3.26156,3.37072,3.34597,3.16056,2.92225,3.32361,3.35812,3.16651,2.76695,3.29657,3.1435,3.33636,3.42294,3.33799,3.34255,3.17814,3.4441,3.27246,3.16949,3.27882,3.38048,3.30558,3.19285,3.23816,3.43234,3.52403,3.29482,3.42258,3.42041,3.4482,3.14403,3.40019,3.26303,2.99599,3.36968,3.2645,...,3.02535,2.62626,3.06117,3.02535,1.63581,2.77777,3.02535,3.06515,3.02535,2.55581,1.6576,2.95656,3.18063,3.16044,3.16044,3.05052,3.00986,2.21062,3.02535,3.00986,0.859219,0.859219,0.679781,3.00986,0.460633,3.16044,1.93141,0.859219,0.859219,0.859219,0.859219,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
5,3.29697,3.15652,3.04231,2.62688,3.01145,3.16102,3.12043,2.88946,2.52415,3.1251,3.18329,2.43697,2.59128,3.00056,2.86761,3.15518,3.14436,2.90536,3.09381,2.73913,3.19765,3.02655,2.85253,2.88981,3.17188,3.01644,2.64946,2.58461,3.02163,1.83854,3.0557,3.22671,3.28871,3.28271,2.38213,3.20382,2.99449,2.22363,2.98901,2.99448,...,1.92534,1.81485,2.17943,1.92534,0.869301,1.83998,1.92534,2.17901,1.92534,1.4931,0.787378,2.1802,2.42128,1.86004,1.86004,2.06657,1.38666,0.613038,1.92534,1.38666,0.794259,0.794259,0.113065,1.38666,0.0,1.86004,0.513646,0.794259,0.794259,0.794259,0.794259,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
6,3.58166,3.61047,3.50327,3.41615,3.50925,3.53985,3.51339,3.48373,3.25128,3.60221,3.59034,3.33572,3.42878,3.43161,3.47203,3.5408,3.55322,3.40224,3.5686,3.3959,3.56956,3.57649,3.55574,3.46989,3.49757,3.49809,3.42804,3.31909,3.38605,3.0236,3.52077,3.56093,3.64367,3.50992,3.26009,3.6093,3.47923,3.23749,3.41993,3.53615,...,2.14231,2.14423,2.91327,2.14231,1.19909,2.87569,2.14231,2.77099,2.14231,1.49325,0.971503,3.14613,3.31962,1.78648,1.78648,3.10079,2.83447,1.07408,2.14231,2.83447,0.369246,0.369246,0.282661,2.83447,0.214548,1.78648,1.25745,0.369246,0.369246,0.369246,0.369246,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
7,3.333,3.22236,3.16904,2.51125,3.22974,3.24676,3.19241,2.94452,2.78382,3.29935,3.22106,3.13031,2.391,2.92817,2.86941,3.19507,3.15713,3.07977,3.19346,3.13541,3.22754,3.09799,3.05293,3.0698,3.24684,3.09832,2.88717,2.83724,3.1745,2.56381,3.14874,3.26658,3.23123,3.23543,3.15312,3.21424,3.0582,2.76945,2.97257,3.184,...,2.54898,2.57664,2.83323,2.54898,2.55667,2.73868,2.54898,2.99462,2.54898,2.43484,1.8643,2.99757,2.95358,2.58919,2.58919,2.95599,2.66982,2.45619,2.54898,2.66982,0.586965,0.586965,1.26462,2.66982,2.07552,2.58919,1.70783,0.586965,0.586965,0.586965,0.586965,0.715045,0.715045,0.715045,0.715045,0.715045,0.715045,0.715045,0.715045,0.715045
8,3.31754,3.3331,3.11629,2.64528,3.07496,3.2714,3.12754,2.71211,2.46938,3.237,3.31202,2.58957,2.53466,2.91389,2.90033,3.23361,3.18265,2.91577,3.24241,2.90792,3.30931,3.1515,2.95835,3.04976,3.18265,3.04812,2.63435,2.57804,2.95639,1.95232,3.11045,3.31512,3.30064,3.139,2.33639,3.27403,3.03322,2.35486,2.83178,3.07466,...,1.61655,1.5947,2.29374,1.61655,0.683251,2.16523,1.61655,2.14302,1.61655,1.06974,0.89844,2.50751,2.37608,1.66194,1.66194,2.18021,1.81487,0.869418,1.61655,1.81487,0.661513,0.661513,0.119241,1.81487,0.0,1.66194,0.554915,0.661513,0.661513,0.661513,0.661513,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
9,3.30484,3.06001,2.83319,1.7678,2.65701,3.04183,2.6542,2.32958,1.56263,3.05278,2.79899,2.16297,1.45075,2.17148,2.15027,3.11765,2.88666,2.8429,2.9826,2.42994,2.95267,2.70786,2.20732,2.65769,2.95973,2.61607,2.25209,1.89165,2.98933,1.67393,2.68129,3.21536,3.09892,2.59347,2.61444,2.95157,2.48526,1.67371,2.37196,2.63346,...,1.27423,1.60527,2.00398,1.27423,2.07598,1.89869,1.27423,1.80806,1.27423,1.21474,1.02706,2.46055,2.62991,1.1972,1.1972,2.13421,1.65232,0.826632,1.27423,1.65232,0.458989,0.458989,0.514487,1.65232,0.964163,1.1972,0.999102,0.458989,0.458989,0.458989,0.458989,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
10,3.15111,3.15772,2.91205,1.6121,2.95244,3.04969,2.81823,2.82582,1.6406,3.11654,2.84395,2.73597,2.66321,1.92069,2.66321,3.09853,2.93681,2.94291,3.09593,2.95616,2.86691,2.82545,2.79082,2.83949,2.75101,2.87417,2.8793,2.98063,2.7175,1.34024,3.05876,3.06831,3.05616,2.43116,3.04076,3.09337,2.39966,2.09784,2.59459,3.00931,...,2.12665,2.5379,2.8413,2.12665,3.00059,2.90881,2.12665,2.98827,2.12665,2.1115,2.30714,3.28907,3.14657,1.44945,1.44945,2.81322,2.52951,2.08113,2.12665,2.52951,0.769388,0.769388,1.47624,2.52951,1.58372,1.44945,1.74231,0.769388,0.769388,0.769388,0.769388,1.42192,1.42192,1.42192,1.42192,1.42192,1.42192,1.42192,1.42192,1.42192


### User-based

In [79]:
user_sparse_matrix = sparse_matrix.fillna(0).T

In [80]:
user_cossim_df = cossim_matrix(user_sparse_matrix, user_sparse_matrix)
user_cossim_df

Unnamed: 0_level_0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,...,571,572,573,574,575,576,577,578,579,580,581,582,583,584,585,586,587,588,589,590,591,592,593,594,595,596,597,598,599,600,601,602,603,604,605,606,607,608,609,610
userId,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1,Unnamed: 38_level_1,Unnamed: 39_level_1,Unnamed: 40_level_1,Unnamed: 41_level_1,Unnamed: 42_level_1,Unnamed: 43_level_1,Unnamed: 44_level_1,Unnamed: 45_level_1,Unnamed: 46_level_1,Unnamed: 47_level_1,Unnamed: 48_level_1,Unnamed: 49_level_1,Unnamed: 50_level_1,Unnamed: 51_level_1,Unnamed: 52_level_1,Unnamed: 53_level_1,Unnamed: 54_level_1,Unnamed: 55_level_1,Unnamed: 56_level_1,Unnamed: 57_level_1,Unnamed: 58_level_1,Unnamed: 59_level_1,Unnamed: 60_level_1,Unnamed: 61_level_1,Unnamed: 62_level_1,Unnamed: 63_level_1,Unnamed: 64_level_1,Unnamed: 65_level_1,Unnamed: 66_level_1,Unnamed: 67_level_1,Unnamed: 68_level_1,Unnamed: 69_level_1,Unnamed: 70_level_1,Unnamed: 71_level_1,Unnamed: 72_level_1,Unnamed: 73_level_1,Unnamed: 74_level_1,Unnamed: 75_level_1,Unnamed: 76_level_1,Unnamed: 77_level_1,Unnamed: 78_level_1,Unnamed: 79_level_1,Unnamed: 80_level_1,Unnamed: 81_level_1
1,1.000000,0.034857,0.055309,0.194390,0.083175,0.083208,0.112971,0.084667,0.051237,0.011950,0.047249,0.010386,0.032461,0.069802,0.132847,0.108030,0.204495,0.172505,0.249567,0.143729,0.127975,0.052194,0.083810,0.130117,0.076680,0.019520,0.209908,0.170889,0.118550,0.062829,0.159041,0.093267,0.118504,0.069706,0.077260,0.067386,0.033086,0.088154,0.267092,0.065831,...,0.092401,0.079005,0.202313,0.051579,0.018294,0.039362,0.204364,0.000000,0.092823,0.172888,0.026653,0.049938,0.045918,0.049240,0.094104,0.111643,0.114386,0.110874,0.077203,0.270338,0.091691,0.124369,0.158362,0.121675,0.065360,0.100213,0.283914,0.016027,0.224128,0.226709,0.072302,0.116583,0.166690,0.066705,0.128476,0.134676,0.225536,0.220481,0.038276,0.110058
2,0.034857,1.000000,0.000000,0.004772,0.021056,0.032824,0.008290,0.035723,0.000000,0.075552,0.060331,0.000000,0.000000,0.021391,0.093872,0.069096,0.088130,0.128164,0.016408,0.018292,0.060040,0.144061,0.017538,0.120364,0.151436,0.000000,0.000000,0.031753,0.104834,0.127749,0.000000,0.023818,0.041071,0.034294,0.000000,0.068622,0.040135,0.028516,0.000000,0.022182,...,0.000000,0.042218,0.159203,0.043907,0.000000,0.000000,0.019636,0.000000,0.000000,0.027982,0.109927,0.181966,0.021737,0.000000,0.109949,0.084108,0.000000,0.072668,0.034367,0.039215,0.046344,0.039249,0.079988,0.017030,0.000000,0.101049,0.015398,0.000000,0.083770,0.030061,0.136265,0.021186,0.007047,0.000000,0.000000,0.021553,0.016556,0.050458,0.034553,0.076118
3,0.055309,0.000000,1.000000,0.002901,0.000000,0.003592,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.018262,0.038240,0.003404,0.025192,0.016160,0.009731,0.003341,0.004016,0.003998,0.000000,0.000000,0.000000,0.003027,0.026880,0.003412,0.000000,0.000000,0.017376,0.005862,0.000000,0.000000,0.000000,0.000000,0.000000,0.004560,0.000000,...,0.000000,0.005544,0.000000,0.000000,0.000000,0.035828,0.042974,0.000000,0.002916,0.037850,0.005307,0.000000,0.000000,0.000000,0.000000,0.000000,0.002563,0.000000,0.000000,0.015462,0.000000,0.000000,0.001397,0.000000,0.000000,0.017692,0.023215,0.000000,0.034376,0.005204,0.002868,0.002318,0.017136,0.000000,0.013861,0.009751,0.019928,0.024822,0.000000,0.017417
4,0.194390,0.004772,0.002901,1.000000,0.085367,0.060679,0.086793,0.068982,0.014170,0.002908,0.031773,0.034303,0.041635,0.025928,0.035050,0.123387,0.095425,0.100947,0.157395,0.077603,0.052826,0.008974,0.071469,0.078880,0.044197,0.059380,0.097798,0.122451,0.046619,0.034949,0.114881,0.142758,0.182264,0.025478,0.043940,0.086552,0.045292,0.075087,0.259689,0.046172,...,0.022931,0.143144,0.091783,0.022022,0.066780,0.000000,0.140341,0.000000,0.082206,0.127921,0.031622,0.037889,0.047698,0.034808,0.063236,0.065959,0.124211,0.029981,0.024132,0.159938,0.074962,0.049830,0.152046,0.070648,0.101119,0.081193,0.247649,0.020175,0.156585,0.179267,0.086746,0.109451,0.251135,0.056384,0.069884,0.154908,0.129538,0.113305,0.034660,0.080890
5,0.083175,0.021056,0.000000,0.085367,1.000000,0.186807,0.088868,0.316256,0.000000,0.012249,0.213644,0.073692,0.021198,0.226853,0.085413,0.063279,0.132520,0.082416,0.087157,0.086874,0.050675,0.022847,0.061255,0.065970,0.054175,0.131014,0.091820,0.103427,0.045930,0.087713,0.083355,0.173960,0.217371,0.019212,0.302113,0.026802,0.249827,0.356580,0.080699,0.217242,...,0.024481,0.100914,0.099465,0.330398,0.000000,0.000000,0.073879,0.000000,0.151883,0.087090,0.026834,0.000000,0.078176,0.324824,0.000000,0.039438,0.114674,0.231518,0.224384,0.118065,0.000000,0.352898,0.186448,0.145272,0.050136,0.105673,0.096557,0.000000,0.088699,0.126579,0.066335,0.400914,0.059351,0.225565,0.118145,0.094739,0.173132,0.098357,0.207910,0.055896
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
606,0.134676,0.021553,0.009751,0.154908,0.094739,0.085087,0.142701,0.079756,0.062645,0.077093,0.049480,0.037107,0.063764,0.080611,0.128789,0.148320,0.165816,0.173162,0.136884,0.102316,0.121364,0.091795,0.133096,0.086467,0.053534,0.042480,0.083763,0.169587,0.077042,0.051096,0.061986,0.097943,0.155912,0.093313,0.048967,0.085868,0.014001,0.089974,0.096951,0.079718,...,0.059806,0.156005,0.109242,0.033759,0.041202,0.024086,0.115636,0.022160,0.062839,0.223151,0.123651,0.066647,0.041886,0.060650,0.065441,0.075954,0.136745,0.081306,0.083383,0.243438,0.071859,0.071812,0.169237,0.107907,0.091602,0.140472,0.180539,0.042917,0.240520,0.233715,0.145069,0.099308,0.229610,0.055161,0.134493,1.000000,0.139929,0.184873,0.050920,0.148169
607,0.225536,0.016556,0.019928,0.129538,0.173132,0.131458,0.173107,0.130547,0.014749,0.011282,0.241735,0.031007,0.011112,0.154481,0.156550,0.078363,0.203251,0.143684,0.216508,0.069832,0.106464,0.024431,0.053504,0.120339,0.066450,0.146792,0.129410,0.154041,0.067140,0.105950,0.132416,0.151079,0.149626,0.105742,0.048927,0.073757,0.205597,0.151845,0.246916,0.110013,...,0.140128,0.166185,0.209162,0.161599,0.046339,0.006410,0.232697,0.000000,0.130783,0.157952,0.070047,0.036461,0.011821,0.166900,0.010674,0.107540,0.130204,0.192743,0.179416,0.232491,0.072582,0.190816,0.165589,0.213936,0.014191,0.129457,0.197480,0.003500,0.194938,0.207607,0.071822,0.199090,0.176991,0.103188,0.114070,0.139929,1.000000,0.229346,0.109358,0.126082
608,0.220481,0.050458,0.024822,0.113305,0.098357,0.139637,0.234982,0.137146,0.077805,0.056784,0.119112,0.033497,0.082025,0.114233,0.188187,0.110166,0.156495,0.238423,0.283562,0.216120,0.187927,0.163746,0.134803,0.124814,0.032587,0.062216,0.122180,0.264685,0.066133,0.062810,0.078090,0.117691,0.145583,0.214195,0.077622,0.063932,0.120102,0.120922,0.193284,0.116731,...,0.072551,0.109452,0.278531,0.113124,0.063814,0.060737,0.125401,0.000000,0.122349,0.325837,0.140761,0.040780,0.081044,0.134688,0.141961,0.150192,0.097555,0.126778,0.073746,0.311168,0.133439,0.200707,0.199644,0.186916,0.014880,0.221967,0.186108,0.026227,0.337042,0.299991,0.096468,0.137315,0.170640,0.128936,0.153072,0.184873,0.229346,1.000000,0.095832,0.258657
609,0.038276,0.034553,0.000000,0.034660,0.207910,0.174348,0.057862,0.313814,0.000000,0.025844,0.195150,0.000000,0.000000,0.230640,0.100859,0.068145,0.134855,0.073705,0.102239,0.000000,0.045736,0.074982,0.020900,0.095883,0.032004,0.382952,0.050853,0.057526,0.079616,0.072363,0.033498,0.084342,0.126467,0.081444,0.238631,0.043981,0.314303,0.273799,0.068105,0.126886,...,0.040172,0.105799,0.109613,0.293014,0.060445,0.000000,0.060171,0.000000,0.063519,0.065105,0.044034,0.000000,0.039472,0.323846,0.037127,0.053877,0.081331,0.316065,0.236371,0.077519,0.037870,0.307368,0.119080,0.083114,0.000000,0.047893,0.061515,0.000000,0.089097,0.068698,0.000000,0.272679,0.067181,0.168249,0.054219,0.050920,0.109358,0.095832,1.000000,0.030525


In [81]:
movieId_grouped = train_df.groupby('movieId')
user_prediction_result_df = pd.DataFrame(index=list(movieId_grouped.indices.keys()), columns = user_sparse_matrix.index)

user_prediction_result_df

userId,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,...,571,572,573,574,575,576,577,578,579,580,581,582,583,584,585,586,587,588,589,590,591,592,593,594,595,596,597,598,599,600,601,602,603,604,605,606,607,608,609,610
1,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
2,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
3,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
4,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
5,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
193573,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
193579,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
193581,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
193583,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,...,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,


In [82]:
for movieId, group in tqdm(movieId_grouped):
  user_sim = user_cossim_df.loc[group['userId']]
  user_rating = group['rating']
  sim_sum = user_sim.sum(axis = 0)
  pred_ratings = np.matmul(user_sim.T.to_numpy(), user_rating) / (sim_sum + 1)
  user_prediction_result_df.loc[movieId] = pred_ratings

  0%|          | 0/9017 [00:00<?, ?it/s]

In [83]:
print(item_prediction_result_df.shape)

(610, 9017)


In [84]:
# 전체 userId가 모든 movieId에 매긴 평점
print(item_prediction_result_df.head())
print(user_prediction_result_df.head())

movieId   1        2        3       ...    193581    193583    193585
1        4.32992  4.29179  4.28919  ...  0.419039  0.419039  0.419039
2         3.1684  3.12065  2.56872  ...   1.25282   1.25282   1.25282
3        1.42816  1.30741  1.46397  ...         0         0         0
4        3.44369  3.37624  3.35247  ...         0         0         0
5        3.29697  3.15652  3.04231  ...         0         0         0

[5 rows x 9017 columns]
userId       1         2          3    ...       608      609      610
1        3.77897   3.32357     2.4039  ...   3.78951  3.73595  3.73468
2        3.07692   2.65969    1.57707  ...   3.14277  3.27099  3.09372
3        2.80768   1.78114     1.2822  ...   2.73084  2.77165  2.51431
4       0.749143  0.190966  0.0402651  ...  0.830998  1.03403  0.57445
5        2.46791   1.61012   0.713347  ...   2.59087  2.65842  2.29286

[5 rows x 610 columns]


## RMSE로 Evaluate

In [85]:
user_prediction_result_df = user_prediction_result_df.T

In [86]:
def evaluate(test_df, prediction_result_df):
  groups_with_movie_ids = test_df.groupby(by='movieId')
  groups_with_user_ids = test_df.groupby(by='userId')
  intersection_movie_ids = sorted(list(set(list(prediction_result_df.columns)).intersection(set(list(groups_with_movie_ids.indices.keys())))))
  intersection_user_ids = sorted(list(set(list(prediction_result_df.index)).intersection(set(groups_with_user_ids.indices.keys()))))

  print(len(intersection_movie_ids))
  print(len(intersection_user_ids))

  compressed_prediction_df = prediction_result_df.loc[intersection_user_ids][intersection_movie_ids]

  # test_df에 대해서 RMSE 계산
  grouped = test_df.groupby(by='userId')
  result_df = pd.DataFrame(columns=['rmse'])
  for userId, group in tqdm(grouped):
      if userId in intersection_user_ids:
          pred_ratings = compressed_prediction_df.loc[userId][compressed_prediction_df.loc[userId].index.intersection(list(group['movieId'].values))]
          pred_ratings = pred_ratings.to_frame(name='rating').reset_index().rename(columns={'index':'movieId','rating':'pred_rating'})
          actual_ratings = group[['rating', 'movieId']].rename(columns={'rating':'actual_rating'})

          final_df = pd.merge(actual_ratings, pred_ratings, how='inner', on=['movieId'])
          final_df = final_df.round(4) # 반올림

          # if not final_df.empty:
          #     rmse = sqrt(mean_squared_error(final_df['rating_actual'], final_df['rating_pred']))
          #     result_df.loc[userId] = rmse
          #     # print(userId, rmse)
    
  return final_df

In [87]:
evaluate(test_df, user_prediction_result_df)

4384
610


  0%|          | 0/610 [00:00<?, ?it/s]

Unnamed: 0,actual_rating,movieId,pred_rating
0,3.0,44555,3.32669
1,4.0,8366,2.65694
2,3.0,3744,1.98661
3,4.5,750,3.94206
4,4.5,1300,1.59748
...,...,...,...
208,3.5,7894,1.6818
209,3.5,60684,3.47371
210,3.5,3683,2.81933
211,3.5,5989,3.75249


In [88]:
evaluate(test_df, item_prediction_result_df)

4384
610


  0%|          | 0/610 [00:00<?, ?it/s]

Unnamed: 0,actual_rating,movieId,pred_rating
0,3.0,44555,3.87023
1,4.0,8366,3.72467
2,3.0,3744,3.59309
3,4.5,750,3.85124
4,4.5,1300,4.01945
...,...,...,...
208,3.5,7894,3.63077
209,3.5,60684,3.70714
210,3.5,3683,3.89082
211,3.5,5989,3.78607


In [89]:
result_df = evaluate(test_df, user_prediction_result_df)
print(result_df)
print(f"RMSE: {sqrt(mean_squared_error(result_df['actual_rating'].values, result_df['pred_rating'].values))}")

4384
610


  0%|          | 0/610 [00:00<?, ?it/s]

     actual_rating  movieId pred_rating
0              3.0    44555     3.32669
1              4.0     8366     2.65694
2              3.0     3744     1.98661
3              4.5      750     3.94206
4              4.5     1300     1.59748
..             ...      ...         ...
208            3.5     7894      1.6818
209            3.5    60684     3.47371
210            3.5     3683     2.81933
211            3.5     5989     3.75249
212            3.5   111663    0.320937

[213 rows x 3 columns]
RMSE: 1.7366604322166128


In [90]:
result_df = evaluate(test_df, item_prediction_result_df)
print(result_df)
print(f"RMSE: {sqrt(mean_squared_error(result_df['actual_rating'].values, result_df['pred_rating'].values))}")

4384
610


  0%|          | 0/610 [00:00<?, ?it/s]

     actual_rating  movieId pred_rating
0              3.0    44555     3.87023
1              4.0     8366     3.72467
2              3.0     3744     3.59309
3              4.5      750     3.85124
4              4.5     1300     4.01945
..             ...      ...         ...
208            3.5     7894     3.63077
209            3.5    60684     3.70714
210            3.5     3683     3.89082
211            3.5     5989     3.78607
212            3.5   111663     3.55169

[213 rows x 3 columns]
RMSE: 0.7643428860080879
