In [29]:
import json
import numpy as np
import pandas as pd
import pickle
import copy
from tqdm import tqdm
import csv

In [1]:
dataset = 'ML20M'

#### Show statistics of dataset

In [30]:
rating_df = pd.read_csv('./ratings.csv', sep=',', names=["userId", "itemId", "rating", "timestamp"], skiprows=1)
rating_df.drop(columns=['timestamp'], inplace=True)
# Print the number of users, items and interactions
print("Dataset statistics: ")
print(f"> No. of users: {len(rating_df['userId'].unique())}")
print(f"> No. of items: {len(rating_df['itemId'].unique())}")
print(f"> No. of interactions: {rating_df.shape[0]}")

Dataset statistics: 
> No. of users: 138493
> No. of items: 26744
> No. of interactions: 20000263


#### Data filter

In [31]:
# Drop ratings less than 4
for i in tqdm(range(len(rating_df))):
    if rating_df.at[i, 'rating'] >= 4:
        rating_df.at[i, 'rating'] = 1
    else: 
        rating_df.at[i, 'rating'] = 0
rating_df.drop(rating_df.index[rating_df['rating'] == 0], axis=0, inplace=True)
# Drop the column of 'rating' and duplicate records
rating_df.drop('rating', axis=1, inplace=True)
rating_df.drop_duplicates(subset =['userId', 'itemId'], keep = 'first', inplace = True)

100%|██████████| 20000263/20000263 [05:39<00:00, 58958.31it/s]


In [32]:
# Copy rating_df to rdf
rdf = copy.copy(rating_df)
# Calculate the total number of interactions of every user and item
rdf['user_freq'] = rdf.groupby('userId')['userId'].transform('count')
rdf['item_freq'] = rdf.groupby('itemId')['itemId'].transform('count')
print(rdf)

          userId  itemId  user_freq  item_freq
6              1     151         88       6305
7              1     223         88      15538
8              1     253         88      12782
9              1     260         88      42612
10             1     293         88      19079
...          ...     ...        ...        ...
20000256  138493   66762        301         41
20000257  138493   68319        301       1070
20000258  138493   68954        301       6621
20000259  138493   69526        301        591
20000261  138493   70286        301       5229

[9995410 rows x 4 columns]


In [33]:
# Thresholds for user and item
user_threshold = 50
item_threshold = 25

In [34]:
# Remove users and items where their interactions less than threshold
while (rdf['user_freq'].min() < user_threshold or rdf['item_freq'].min() < item_threshold) :
    rdf.drop(rdf.index[rdf['user_freq'] < user_threshold], inplace=True)
    rdf['item_freq'] = rdf.groupby('itemId')['itemId'].transform('count')
    rdf.drop(rdf.index[rdf['item_freq'] < item_threshold], inplace=True)
    rdf['user_freq'] = rdf.groupby('userId')['userId'].transform('count')
    rdf['item_freq'] = rdf.groupby('itemId')['itemId'].transform('count')   

In [35]:
# Show the number of users, items and the sparsity after preprocessing
usercnt = len(rdf['userId'].unique())
itemcnt = len(rdf['itemId'].unique())
print("total user: ", usercnt)
print("total item: ", itemcnt)
print('sparsity: ' + str(len(rdf) * 1.0 / (usercnt * itemcnt)))
# Drop the column of 'user_freq' and 'item_freq'
rdf.drop('user_freq', axis=1, inplace=True)
rdf.drop('item_freq', axis=1, inplace=True)
rdf.reset_index(drop=True, inplace=True)

total user:  55845
total item:  8783
sparsity: 0.0163016511143061


#### Renumber users and items

In [36]:
user_dic = dict()
item_dic = dict()

user_idx = 0
item_idx = 0

for row in tqdm(rdf.iterrows(), total=rdf.shape[0]):
  if row[1][0] not in user_dic.keys():
    user_dic[row[1][0]] = user_idx
    user_idx += 1
  # add a new book id with an index
  if row[1][1] not in item_dic.keys():
    item_dic[row[1][1]] = item_idx
    item_idx += 1

100%|██████████| 7995742/7995742 [03:08<00:00, 42347.61it/s]


In [37]:
header = ['userId', 'itemId']
with open(f'../../mod_data/{dataset}/{dataset}.csv', 'w', encoding='utf-8') as fp:
    writer = csv.writer(fp)
    writer.writerow(header)
    for row in tqdm(rdf.iterrows(), total=rdf.shape[0]):
        try:
            writer.writerow([user_dic[row[1][0]], item_dic[row[1][1]]])
        except KeyError as e:
            print(e, row[0])

100%|██████████| 7995742/7995742 [03:17<00:00, 40457.84it/s]
