In [5]:
import pandas as pd
import numpy as np
from math import sqrt

In [6]:
# Path of the ratings CSV file (relative path, resolved against the current working directory).
RATING_DATA_PATH = './data/ratings.csv'

# Print NumPy floating-point values with 2 digits of precision.
np.set_printoptions(precision=2)

In [7]:
def distance(user_1 : np.array, user_2 : np.array) -> float:
    """Return the Euclidean distance between two rating vectors.

    Positions where the element-wise difference is NaN (at least one of the
    two values is NaN) contribute nothing, because the squared differences
    are accumulated with ``np.nansum``.

    Args:
        user_1: Ratings of the first user.
        user_2: Ratings of the second user.

    Returns:
        The square root of the NaN-ignoring sum of squared differences.
    """
    gap = user_1 - user_2
    squared_total = np.nansum(np.square(gap))
    return sqrt(squared_total)

In [8]:
def filter_users_without_movie(rating_data : np.array, movie_id : int) -> np.array:
    """Exclude the users (rows) who have not rated the ``movie_id``-th movie.

    A rating counts as missing when the value in column ``movie_id`` is NaN.

    Args:
        rating_data: 2-D rating array indexed as ``[row, movie_id]``.
        movie_id: Column index of the movie to check.

    Returns:
        The rows of ``rating_data`` whose ``movie_id`` column is not NaN,
        in their original order.
    """
    movie_ratings = rating_data[:, movie_id]
    has_rating = np.logical_not(np.isnan(movie_ratings))
    return rating_data[has_rating]

In [9]:
def fill_nan_with_user_mean(rating_data : np.array) -> np.array:
    """Return a copy of ``rating_data`` with every NaN replaced by its row mean.

    The per-row mean is computed with ``np.nanmean`` along axis 1, so missing
    entries do not take part in the average. The input is not modified; all
    writes go to a copy.

    Args:
        rating_data: Rating array with NaN marking missing values.

    Returns:
        A new array in which each NaN holds the mean of its row.
    """
    result = np.copy(rating_data)
    user_means = np.nanmean(result, axis=1)

    # Index tuple of every missing entry; its first component is the row index.
    gaps = np.nonzero(np.isnan(result))
    result[gaps] = user_means.take(gaps[0])

    return result

In [15]:
def get_k_neighbors(user_id: int, rating_data: np.ndarray, k: int) -> np.ndarray:
    """Return the rating rows of the ``k`` users closest to ``user_id``.

    Closeness is the Euclidean distance between rating rows, where positions
    whose difference is NaN are ignored (the same metric as ``distance``,
    computed for all rows at once instead of in a Python loop).

    Args:
        user_id: Row index of the query user in ``rating_data``. Negative
            indices count from the end, as in normal NumPy indexing.
        rating_data: 2-D array with one row per user.
        k: Maximum number of neighbors to return.

    Returns:
        A new float64 array of shape ``(min(k, n_users - 1), n_columns)``
        holding the neighbor rows ordered from nearest to farthest. Rows at
        equal distance keep their original relative order. The query user's
        own row is never part of the result.

    Raises:
        IndexError: If ``user_id`` is not a valid row index.
        ValueError: If ``k`` is negative.
    """
    ratings = np.asarray(rating_data, dtype=np.float64)
    n_users = ratings.shape[0]

    if not -n_users <= user_id < n_users:
        raise IndexError(
            f"user_id {user_id} is out of bounds for {n_users} users"
        )
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")

    # Normalize negative indices so the self-exclusion below always matches.
    user_idx = user_id % n_users

    # Distance from the query user to every row (NaN differences are skipped).
    squared_diffs = (ratings - ratings[user_idx]) ** 2
    distances = np.sqrt(np.nansum(squared_diffs, axis=1))

    # Stable sort makes the ordering of equidistant users deterministic.
    nearest_first = np.argsort(distances, kind="stable")

    # Drop the query user explicitly so it can never be returned, even when
    # k is larger than the number of other users.
    nearest_first = nearest_first[nearest_first != user_idx]

    return ratings[nearest_first[:k]]

In [16]:
# Load the ratings CSV, using its 'user_id' column as the index, as a NumPy array.
rating_data = pd.read_csv(RATING_DATA_PATH, index_col='user_id').values
# Keep only the rows whose value in column 3 is not NaN (users who rated that movie).
filtered_data = filter_users_without_movie(rating_data, 3)
# Replace every remaining NaN with the mean of its row.
filled_data = fill_nan_with_user_mean(filtered_data)
# Rating rows of the 5 nearest neighbors of row 0, nearest first.
# NOTE: row 0 here is the first row of the filtered data, which is not
# necessarily the first user of the original CSV.
user_0_neighbors = get_k_neighbors(0, filled_data, 5)
user_0_neighbors

array([[3.18, 3.18, 3.18, 5.  , 3.18, 3.18, 2.  , 2.  , 2.  , 3.18, 3.  ,
        4.  , 2.  , 5.  , 4.  , 3.18, 3.18, 3.18, 4.  , 2.  ],
       [3.36, 5.  , 3.36, 5.  , 3.  , 3.36, 3.36, 3.  , 2.  , 4.  , 2.  ,
        3.36, 4.  , 4.  , 5.  , 4.  , 2.  , 3.36, 1.  , 3.  ],
       [2.71, 2.71, 2.  , 5.  , 2.71, 2.71, 2.71, 2.71, 2.71, 2.71, 2.71,
        2.71, 1.  , 2.71, 2.71, 2.71, 3.  , 1.  , 5.  , 2.  ],
       [2.78, 5.  , 1.  , 4.  , 2.78, 2.78, 2.78, 3.  , 1.  , 2.78, 1.  ,
        2.78, 2.78, 4.  , 2.78, 2.78, 2.  , 2.78, 2.78, 4.  ],
       [3.  , 3.  , 3.  , 5.  , 4.  , 3.  , 3.  , 4.  , 5.  , 3.  , 3.  ,
        1.  , 2.  , 1.  , 1.  , 3.  , 3.  , 3.  , 4.  , 3.  ]])