In [11]:
from typing import List, Dict
import pandas as pd


def get_dataframe_from_dataset(
    rating_files: List[str]
    ) -> pd.DataFrame:
    """This function constructs a single csv file that represents all of the
    movie ratings made by the users from a series of files. The generated final
    csv file has the following columns:
        movie_id - user_id - score - date
    rating_files: all the documents where data has to be extracted from.
    dest_path: path (and file) where the final document has to be saved.
    show_progress (default False): wether to print the progress to terminal.
    progress_step (default 100): print progress every this many lines."""
    ratings: List[Dict[str, int]] = []
    for current_file_idx, rating_file in enumerate(rating_files):
        with open(rating_file) as ratings_data:
            for ln_idx, ln in enumerate(ratings_data):
                if ":" in ln:
                    movie_id = int(ln.split(":")[0])
                else:
                    splt_ln = ln.split(",")
                    user_id, score = int(splt_ln[0]), int(splt_ln[1])
                    ratings.append({
                        "user_id": user_id,
                        "movie_id": movie_id,
                        "score": score
                    })
    return pd.DataFrame(ratings)

In [13]:
dataset = get_dataframe_from_dataset(['../data/toy_dataset/raw/ratings.txt'])

In [14]:
dataset

Unnamed: 0,user_id,movie_id,score
0,7952,0,4
1,3413,0,1
2,7974,0,3
3,6213,0,1
4,8259,0,2
...,...,...,...
248634,5341,499,1
248635,5349,499,5
248636,6614,499,5
248637,1,499,2


In [15]:
users = dataset['user_id'].unique()
len(users)

10000

In [16]:
import pandas as pd
import math 
from functools import reduce
import time

# user 사이의 weight 값 계산하는거.
def user(x,y):
    i = single_rate[x]
    j = single_rate[y]
    ij = len(set(movie_lists[x]).intersection(set(movie_lists[y])))
    if ij == 0 : return 0 - math.log(i) - math.log(j) + math.log(d)
    else: return math.log(ij) - math.log(i) - math.log(j) + math.log(d)


def construct_user_graph(dataset, threshold=3):
    global user_lists
    global single_rate
    global d
    global movies
    global movie_lists
    
    
    movie_lists = []
    user_lists = []
    users = dataset['user_id'].unique()
#     movies = dataset['movie_id'].unique()

    for i in users:
        movie_lists_per_users = dataset.loc[(dataset['user_id']== i) & (dataset['score']>=threshold), 'movie_id'].values
        movie_lists.append(movie_lists_per_users)
        
    
    single_rate = [len(dataset[(dataset['user_id']==i) & (dataset['score'] >= threshold)]) for i in users]
    d = reduce(lambda x,y: x+y, single_rate)

    user_comb = [[x, y] for x in users for y in users if y > x]
            
    user_graph = pd.DataFrame(user_comb, columns=['x','y'])
    user_graph['weight'] = user_graph.apply(lambda row: user(row['x'],row['y']), axis=1)

    return user_graph

In [17]:
import time

start = time.time()
user_graph = construct_user_graph(dataset, threshold=3)
print(time.time() - start)

# sort 하기 전

745.8226537704468


In [18]:
user_graph

Unnamed: 0,x,y,weight
0,7952,7974,5.881533
1,7952,8259,6.015064
2,7952,9976,6.089172
3,7952,9250,6.089172
4,7952,7990,6.015064
...,...,...,...
49994995,2324,8812,6.328722
49994996,2324,3033,6.439948
49994997,2324,9994,6.788254
49994998,2324,7708,6.565111


In [20]:
import seaborn as sns
import matplotlib.pyplot as plt

In [21]:
# sns.histplot(user_graph['weight'])
# plt.show()

In [22]:
threshold =3
users = dataset['user_id'].unique()

In [23]:
single_rate = [len(dataset[(dataset['user_id']==i) & (dataset['score'] >= threshold)]) for i in users]
d = reduce(lambda x,y: x+y, single_rate)
print(d)
print(single_rate)

149077
[11, 17, 12, 14, 15, 10, 14, 20, 10, 11, 11, 7, 10, 12, 16, 12, 18, 17, 10, 21, 12, 14, 19, 22, 12, 17, 14, 14, 12, 15, 19, 15, 12, 11, 13, 13, 14, 17, 13, 14, 18, 27, 21, 11, 16, 12, 16, 11, 12, 14, 19, 11, 16, 20, 14, 11, 15, 16, 15, 18, 12, 11, 19, 21, 19, 19, 15, 22, 11, 12, 16, 20, 15, 16, 18, 15, 17, 19, 25, 14, 13, 9, 15, 15, 15, 24, 14, 16, 20, 13, 6, 12, 20, 15, 22, 13, 17, 12, 11, 12, 20, 9, 23, 16, 13, 15, 19, 18, 14, 15, 17, 21, 14, 13, 17, 22, 17, 13, 13, 16, 16, 21, 12, 18, 10, 21, 13, 14, 14, 18, 13, 21, 12, 13, 14, 15, 13, 11, 19, 25, 15, 10, 15, 10, 18, 21, 18, 9, 19, 18, 18, 12, 10, 14, 14, 18, 23, 14, 10, 17, 20, 16, 18, 12, 16, 10, 15, 13, 15, 14, 7, 22, 16, 21, 17, 8, 10, 18, 19, 14, 14, 12, 6, 17, 13, 13, 16, 16, 12, 16, 10, 19, 19, 16, 19, 12, 8, 16, 12, 19, 11, 17, 17, 15, 9, 21, 14, 12, 22, 11, 15, 21, 17, 22, 22, 15, 18, 11, 14, 23, 17, 16, 14, 7, 20, 18, 17, 15, 15, 21, 16, 14, 15, 24, 19, 22, 21, 20, 15, 20, 12, 15, 14, 15, 15, 18, 11, 13, 22, 14, 11,

In [25]:
user_graph.to_csv('../data/toy_dataset/raw/user_graph.csv', index=False)