In [None]:
from collections import defaultdict
import os
import wget
import zipfile
from tqdm.notebook import tqdm

In [None]:
# Remote location of the MovieLens-1M archive and the local paths derived from it.
url = "https://files.grouplens.org/datasets/movielens/ml-1m.zip"
zip_path = "ml-1m.zip"
raw_dir = "ml-1m"
items_path = os.path.join(raw_dir, "movies.dat")
ratings_path = os.path.join(raw_dir, "ratings.dat")

# Download the raw archive only when it is not already on disk. The target is
# passed explicitly so the file always lands at `zip_path` -- the exact path
# checked here and opened below -- instead of a name wget derives on its own.
if not os.path.exists(zip_path):
    wget.download(url, out=zip_path)

# Unpack into the current working directory; skipped when `raw_dir` already exists.
# Presumably the archive holds a top-level `ml-1m/` folder, since the paths
# above are built relative to `raw_dir`.
if not os.path.exists(raw_dir):
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall()

In [None]:
# --- output files -----------------------------------------------------------
out_path = "ml-1m.txt"                                  # user/item interaction pairs
id_to_title_map_path = "ml-1m-titles.txt"               # item id -> quoted title
train_item_freq_path = "ml-1m-train_item_freq.txt"      # item id -> training-split count

# --- raw file format --------------------------------------------------------
# SEP splits the fields of a record in the raw .dat files.
# INTERNAL_SEP is presumably the delimiter inside a multi-valued field such as
# the genre list -- it is not referenced by the visible cells; confirm usage.
SEP, INTERNAL_SEP = "::", "|"

# --- filtering thresholds ---------------------------------------------------
ITEM_FREQ_MIN = 5              # items with fewer reviews than this are dropped
REVIEWS_REMOVE_LESS_THAN = 5   # users with fewer reviews than this are dropped


In [None]:
# Load the item catalogue: every line is split on SEP into (item_id, title, genres)
# and only the id -> title association is kept.
items = dict()
with open(items_path, "r", encoding='ISO-8859-1') as movies_file:
    # Unpacking into exactly three names makes a malformed line fail loudly
    # (ValueError) instead of being silently mis-parsed.
    items.update(
        (item_id, title)
        for item_id, title, _genres in (record.split(SEP) for record in movies_file)
    )

print(f"Found {len(items)} items")

# Load the review data: collect, per user, the (item_id, timestamp) pairs of the
# items they rated, and count how often each item was rated.
reviews = defaultdict(list)   # user_id -> [(item_id, timestamp), ...]
item_freq = defaultdict(int)  # item_id -> number of reviews
skipped = 0                   # reviews whose item is missing from `items`
with open(ratings_path, "r", encoding="utf-8") as ratings_file:
    for record in ratings_file:
        user_id, item_id, _rating, raw_timestamp = record.split(SEP)
        if item_id not in items:
            # No metadata for this item: count the review as skipped and move on.
            skipped += 1
            continue
        # int() tolerates the trailing newline left on the last field.
        reviews[user_id].append((item_id, int(raw_timestamp)))
        item_freq[item_id] += 1

print(f"Found {len(reviews)} users")
print(f"Found {sum(item_freq.values())} reviews")
print(f"Skipepd {skipped} item reviews without metadata")
      
# Keep frequency entries only for items reviewed at least ITEM_FREQ_MIN times;
# membership in this dict is the "item is frequent enough" test used below.
item_freq = {k: v for k, v in item_freq.items() if v >= ITEM_FREQ_MIN}

# Prune users and reviews in a single pass:
#   1. drop users with fewer than REVIEWS_REMOVE_LESS_THAN reviews
#      (counted before any item filtering);
#   2. for the remaining users, drop reviews of items that failed the
#      frequency filter above;
#   3. drop users that are left with no reviews at all.
removed_users_less_than = 0
removed_users_item_less_than = 0
removed_items = 0
updated_items = set()  # ids of every item that survives in at least one user's history
for user_id in list(reviews.keys()):  # iterate over a copy: entries are deleted in the loop
    if len(reviews[user_id]) < REVIEWS_REMOVE_LESS_THAN:
        del reviews[user_id]
        removed_users_less_than += 1
        continue

    len_before = len(reviews[user_id])
    reviews[user_id] = [item for item in reviews[user_id] if item[0] in item_freq]
    updated_items.update(t[0] for t in reviews[user_id])
    removed_items += len_before - len(reviews[user_id])
    if not reviews[user_id]:
        del reviews[user_id]
        removed_users_item_less_than += 1

print(f"Removed {removed_items} reviews of items that appear less than {ITEM_FREQ_MIN} in total")
print(f"Removed {removed_users_less_than} users with less than {REVIEWS_REMOVE_LESS_THAN} actions")
# These users lost every review to the item-frequency filter, so the threshold
# reported is ITEM_FREQ_MIN, not the per-user review threshold.
print(f"Removed {removed_users_item_less_than} users with only item count less than {ITEM_FREQ_MIN}")

# Recompute item frequencies from the filtered reviews; the pre-filter counts
# stay available under `original_item_freq`.
original_item_freq = item_freq
item_freq = defaultdict(int)
for rating_list in reviews.values():
    for item, _timestamp in rating_list:
        item_freq[item] += 1

# Freeze into a plain dict ordered by item id (ids are strings, so the order is lexicographic).
item_freq = {item: item_freq[item] for item in sorted(item_freq)}
print(f"Total of {sum(item_freq.values())} reviews")

# Restrict the title table and both frequency tables to the item ids that
# still occur in some user's history (`updated_items`).
new_items, new_item_freq, new_original_item_freq = {}, {}, {}
for kept_id in tqdm(updated_items):
    new_items[kept_id] = items[kept_id]
    new_item_freq[kept_id] = item_freq[kept_id]
    new_original_item_freq[kept_id] = original_item_freq[kept_id]

# Report before rebinding, while `items` still refers to the full catalogue.
print(f"Removed {len(items) - len(new_items)} items that are not been reviewd")
items, item_freq, original_item_freq = new_items, new_item_freq, new_original_item_freq


# Summary of what is left after filtering: item, review and user counts.
total_reviews = sum(map(len, reviews.values()))
print()
print("Items   Reviews   Users")
print(f"{len(items):<4}   {total_reviews:<7}   {len(reviews):<5}")

# Re-index users with consecutive 0-based integers, following the iteration
# order of `reviews`.
user_id_mapping = {
    original_user_id: new_id for new_id, original_user_id in enumerate(reviews)
}

# Re-index items the same way, following the iteration order of `items`.
item_id_mapping = {asin: new_id for new_id, asin in enumerate(items)}

# Per-split item counts, initialised to zero for every surviving item.
train_item_freq = dict.fromkeys(item_freq, 0)
val_item_freq = dict.fromkeys(item_freq, 0)
test_item_freq = dict.fromkeys(item_freq, 0)

for rating_list in reviews.values():
    # Item ids of this user in ascending timestamp order.
    chronological = [item for item, _ts in sorted(rating_list, key=lambda pair: pair[1])]
    if len(chronological) >= 3:
        # Second-to-last item counts towards validation, last item towards test.
        val_item_freq[chronological[-2]] += 1
        test_item_freq[chronological[-1]] += 1
        # NOTE(review): the earliest interaction is excluded from the training
        # counts as well -- presumably because it only ever serves as model
        # input and never as a prediction target; confirm this is intended.
        train_list = chronological[1:-2]
    else:
        # Too short to split: everything counts towards training.
        train_list = chronological
    for asin in train_list:
        train_item_freq[asin] += 1

# All output files are written as UTF-8 explicitly: item titles were read as
# ISO-8859-1 and contain non-ASCII characters, so relying on the platform
# default encoding would make the result machine-dependent.

# Interaction file: one "<user> <item>" line per review, each user's reviews in
# ascending timestamp order.
with open(out_path, "w", encoding="utf-8") as f:
    for user_id, rating_list in reviews.items():
        sorted_list = sorted(rating_list, key=lambda t: t[1])
        for item_id, timestamp in sorted_list:
            f.write(f"{user_id_mapping[user_id] + 1} {item_id_mapping[item_id] + 1}\n") # start user id from 1 to match original SASRec paper,reserve the 0 index for padding

# NOTE(review): the two files below write the raw 0-based item index, whereas
# the interaction file above writes index + 1. Confirm that downstream readers
# account for this offset before joining the files on item id.

# Item id -> quoted title.
with open(id_to_title_map_path, "w", encoding="utf-8") as f:
    for asin, title in items.items():
        f.write(f'{item_id_mapping[asin]} "{title}"\n')

# Item id -> number of occurrences counted for the training split.
with open(train_item_freq_path, "w", encoding="utf-8") as f:
    for asin, count in train_item_freq.items():
        f.write(f'{item_id_mapping[asin]} {count}\n')