In [2]:
import os
import sys
import json
import datetime
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from tqdm import tqdm
from collections import defaultdict, Counter

%matplotlib inline

In [23]:
# Size summary of the current user/item mappings and interaction histories.
# NOTE(review): user_mapping, item_mapping, user_history and item_history are
# built by the preprocessing cell that appears further down (In [20]); on a
# fresh kernel that cell has to be executed before this one.
print('Users count:', len(user_mapping))
print('Items count:', len(item_mapping))
# Every history is a list of events, so the total is the sum of list lengths.
print('Actions count:', sum(map(len, user_history.values())))
print('Max user history len:', np.max([len(events) for events in user_history.values()]))
print('Max item history len:', np.max([len(events) for events in item_history.values()]))

Users count: 22363
Items count: 12101
Actions count: 198502
Max user history len: 204
Max item history len: 431


In [20]:
# Load the raw ratings table; only the UserId, ProductId and Timestamp columns
# are read below.
data = pd.read_csv('../data/beauty/ratings_Beauty.csv')

# raw user id -> list of {'item_id', 'timestamp'} events, in file order
user_history = defaultdict(list)
# raw item id -> list of {'user_id', 'timestamp'} events, in file order
item_history = defaultdict(list)

# Walk the three needed columns in lockstep instead of DataFrame.iterrows():
# iterrows() materialises a Series per row (slow on millions of rows) and does
# not preserve column dtypes. Passing `total` lets tqdm render a real progress
# bar instead of a bare iteration counter.
interactions = zip(data['UserId'], data['ProductId'], data['Timestamp'])

for user_raw_id, item_raw_id, raw_timestamp in tqdm(interactions, total=len(data)):
    # Normalise to a plain Python int regardless of the column's numpy dtype.
    interaction_timestamp = int(raw_timestamp)

    user_history[user_raw_id].append({'item_id': item_raw_id, 'timestamp': interaction_timestamp})
    item_history[item_raw_id].append({'user_id': user_raw_id, 'timestamp': interaction_timestamp})


# Iterative 5-core filtering: repeatedly keep only users and items that still
# have at least 5 interactions, pruning each side's histories against the
# other side's survivors, until the number of survivors stops changing.
is_changed = True
good_users = set()
good_items = set()

while is_changed:
    # Survivor counts from the previous pass (both 0 before the first pass).
    sizes_before = (len(good_users), len(good_items))

    good_users = {uid for uid, events in user_history.items() if len(events) >= 5}
    good_items = {iid for iid, events in item_history.items() if len(events) >= 5}

    # Drop interactions whose counterpart did not survive this pass. Every
    # key is kept, including entities that are no longer "good".
    user_history = {
        uid: [event for event in events if event['item_id'] in good_items]
        for uid, events in user_history.items()
    }
    item_history = {
        iid: [event for event in events if event['user_id'] in good_users]
        for iid, events in item_history.items()
    }

    # Histories only ever shrink, so unchanged counts mean unchanged sets.
    is_changed = sizes_before != (len(good_users), len(good_items))


# Re-index the surviving users and items with dense 1-based integer ids and
# rebuild both history tables with those ids, each sorted by timestamp.
user_mapping = {}  # raw user id -> dense id (1, 2, ...), in order of first appearance
item_mapping = {}  # raw item id -> dense id (1, 2, ...), in order of first appearance
tmp_user_history = defaultdict(list)
tmp_item_history = defaultdict(list)

# Minimum number of interactions a user/item needs to be kept. The original
# cell declared this value but compared against a hard-coded `> 4` instead;
# both length checks below now use the constant.
min_history_len_threshold = 5
# The original (misspelt) name stays bound in case other cells refer to it.
min_history_len_thresold = min_history_len_threshold

for user_id, history in tqdm(user_history.items()):
    if user_id not in good_users:
        continue

    processed_history = []
    for interaction in history:
        item_id = interaction['item_id']
        if item_id not in good_items:
            continue

        # setdefault hands out the next free id the first time an item is
        # seen and returns the existing id afterwards.
        processed_item_id = item_mapping.setdefault(item_id, len(item_mapping) + 1)
        processed_history.append({'item_id': processed_item_id, 'timestamp': interaction['timestamp']})

    if len(processed_history) >= min_history_len_threshold:
        processed_user_id = user_mapping.setdefault(user_id, len(user_mapping) + 1)
        # sorted() is stable, so events with equal timestamps keep file order.
        tmp_user_history[processed_user_id] = sorted(processed_history, key=lambda x: x['timestamp'])


for item_id, history in tqdm(item_history.items()):
    if item_id not in good_items:
        continue

    processed_history = []
    for interaction in history:
        user_id = interaction['user_id']
        if user_id not in good_users:
            continue

        processed_user_id = user_mapping.setdefault(user_id, len(user_mapping) + 1)
        processed_history.append({'user_id': processed_user_id, 'timestamp': interaction['timestamp']})

    if len(processed_history) >= min_history_len_threshold:
        processed_item_id = item_mapping.setdefault(item_id, len(item_mapping) + 1)
        tmp_item_history[processed_item_id] = sorted(processed_history, key=lambda x: x['timestamp'])

# From here on the histories are keyed by the dense ids, not the raw ids.
user_history = tmp_user_history
item_history = tmp_item_history

# Size summary of the mappings and histories after filtering and re-indexing.
print('Users count:', len(user_mapping))
print('Items count:', len(item_mapping))
# Every history is a list of events, so the total is the sum of list lengths.
print('Actions count:', sum(map(len, user_history.values())))
print('Avg user history len:', np.mean([len(events) for events in user_history.values()]))
print('Avg item history len:', np.mean([len(events) for events in item_history.values()]))

2023070it [01:45, 19210.72it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1210271/1210271 [00:01<00:00, 741475.41it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 249274/249274 [00:00<00:00, 490480.57it/s]


Users count: 22363
Items count: 12101
Actions count: 198502
Avg user history len: 8.876358270357287
Avg item history len: 16.403768283612923


In [21]:
# Dump one line per user: "<user_id> <value> <value> ...", with the values in
# the order they are stored in user_history. The first file holds item ids,
# the second holds the matching timestamps in the same order.
with open('../data/beauty/Beauty.txt', 'w') as f:
    for user_id, history in user_history.items():
        item_ids = ' '.join(str(event['item_id']) for event in history)
        f.write(f'{user_id} {item_ids}\n')

with open('../data/beauty/Beauty_timestamps.txt', 'w') as f:
    for user_id, history in user_history.items():
        timestamps = ' '.join(str(event['timestamp']) for event in history)
        f.write(f'{user_id} {timestamps}\n')

In [79]:
# Report the number of mapped users, mapped items and interaction records.
# NOTE(review): `lines` is not defined in the visible part of the notebook —
# presumably a sequence of interaction records from another cell; confirm.
users_total = len(user_mapping)
print('Users count:', users_total)
items_total = len(item_mapping)
print('Items count:', items_total)
interactions_total = len(lines)
print('Interactions count:', interactions_total)

Users count: 52374
Items count: 121291
Interactions count: 371345


In [53]:
from polara import get_movielens_data


In [54]:
# Load MovieLens ratings through polara, asking for the timestamp column as
# well (include_time=True); later cells split on mldata['timestamp'].
# NOTE(review): presumably the ML-1M dataset, given the ~1M total rows reported
# by the shape cell further down — confirm.
mldata = get_movielens_data(include_time=True)

In [60]:
# Chronological split boundaries: the 80th and 90th percentiles of the rating
# timestamps. 'nearest' interpolation makes each boundary an actually observed
# timestamp rather than a value interpolated between two observations.
timestamps = mldata['timestamp']

fst_timepoint = timestamps.quantile(q=0.80, interpolation='nearest')
snd_timepoint = timestamps.quantile(q=0.90, interpolation='nearest')

print(fst_timepoint, snd_timepoint)

975768738 978133367


In [61]:
# Partition the ratings by time into three disjoint, exhaustive ranges:
#   train: timestamp <  fst_timepoint
#   val:   fst_timepoint <= timestamp < snd_timepoint
#   test:  timestamp >= snd_timepoint
ts = mldata['timestamp']

train_data = mldata[ts < fst_timepoint]
val_data = mldata[(ts >= fst_timepoint) & (ts < snd_timepoint)]
test_data = mldata[ts >= snd_timepoint]

In [62]:
# (rows, columns) of each split, in train / val / test order.
tuple(split.shape for split in (train_data, val_data, test_data))

((800164, 4), (100023, 4), (100022, 4))

In [63]:
# Preview the first rows of the training split (head() defaults to 5 rows).
train_data.head()

Unnamed: 0,userid,movieid,rating,timestamp
94507,635,1251,4,975768620
94513,635,3948,4,975768294
94518,635,1270,4,975768106
94519,635,1279,5,975768520
94522,635,1286,4,975768106
