In [1]:
import os
import sys
import json
import datetime
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from tqdm import tqdm
from collections import defaultdict, Counter

%matplotlib inline

In [2]:
# Load the raw ratings table. Column names are supplied explicitly, so the
# file is read without treating its first line as a header.
path_to_df = '../data/Beauty/ratings_Beauty.csv'
df = pd.read_csv(
    path_to_df,
    header=None,
    names=['raw_user_id', 'raw_item_id', 'rating', 'timestamp'],
)

In [None]:
# First five rows of the raw ratings table.
df.iloc[:5]

In [None]:
# Number of missing values in each column.
df.isna().sum()

In [None]:
# Max of the raw user id column (lexicographic if the ids are strings) and the
# shape of its array of distinct values, i.e. the number of distinct users.
df['raw_user_id'].max(), df['raw_user_id'].unique().shape

In [None]:
# Dense integer user ids in order of first appearance. factorize() yields
# 0-based codes (-1 for missing values), so shift by one to make them 1-based.
df['user_id'] = df['raw_user_id'].factorize()[0] + 1
df['user_id'].min(), df['user_id'].max(), df['user_id'].unique().shape

In [None]:
# Dense integer item ids in order of first appearance. factorize() yields
# 0-based codes (-1 for missing values), so shift by one to make them 1-based.
df['item_id'] = df['raw_item_id'].factorize()[0] + 1
df['item_id'].min(), df['item_id'].max(), df['item_id'].unique().shape

In [None]:
# Raw table again, now including the integer user_id / item_id columns.
df.head(n=5)

In [None]:
# Flatten the interaction table into plain-Python records for the
# dictionary-based processing below.
# Zipping the three needed columns avoids DataFrame.iterrows. iterrows builds
# a Series per row, which is very slow on large tables. It also leaves no
# loop variable named `_` behind, which would shadow IPython's output history.
# int() still converts every value to a built-in int, as before.
data = [
    {'user_id': int(user_id), 'item_id': int(item_id), 'timestamp': int(timestamp)}
    for user_id, item_id, timestamp in tqdm(
        zip(df['user_id'], df['item_id'], df['timestamp']),
        total=len(df),
    )
]

print(len(data))

In [None]:
# Index every interaction from both sides: user -> items and item -> users.
user_history = defaultdict(list)
item_history = defaultdict(list)

for event in tqdm(data):
    uid, iid, ts = event['user_id'], event['item_id'], event['timestamp']
    user_history[uid].append({'item_id': iid, 'timestamp': ts})
    item_history[iid].append({'user_id': uid, 'timestamp': ts})

# Iterative k-core filtering.
# Each round does two things:
#   - recompute the users/items that have at least `threshold` interactions;
#   - prune every history down to partners that are still in those sets.
# Histories only ever shrink, so the sets can only shrink too. An unchanged
# (users, items) count therefore means a fixed point, and the loop stops.
# Users/items that fall out of the core are not removed from the dicts here;
# only their histories are pruned.
threshold = 5
good_users = set()
good_items = set()
is_changed = True

while is_changed:
    old_state = (len(good_users), len(good_items))

    good_users = {uid for uid, events in user_history.items() if len(events) >= threshold}
    good_items = {iid for iid, events in item_history.items() if len(events) >= threshold}

    user_history = {
        uid: [e for e in events if e['item_id'] in good_items]
        for uid, events in user_history.items()
    }
    item_history = {
        iid: [e for e in events if e['user_id'] in good_users]
        for iid, events in item_history.items()
    }

    new_state = (len(good_users), len(good_items))
    is_changed = old_state != new_state
    print(old_state, new_state)

In [None]:
# Re-index the filtered interactions with dense, 1-based ids and keep each
# history sorted by timestamp.
#
# Item ids are assigned in order of first appearance while scanning user
# histories. Only users that still have at least `threshold` interactions get
# an id. After the k-core loop converges, two things hold:
#   - every user in a surviving item history is such a user;
#   - every surviving item appears in some user history.
# The item-side pass therefore looks ids up strictly. A KeyError there means
# the two history views disagree; the old code silently minted an id with no
# history instead.
user_mapping = {}
item_mapping = {}
tmp_user_history = defaultdict(list)
tmp_item_history = defaultdict(list)

for user_id, history in tqdm(user_history.items()):
    processed_history = []

    for filtered_item in history:
        item_id = filtered_item['item_id']
        item_timestamp = filtered_item['timestamp']

        # The first sighting of an item allocates the next free id.
        processed_item_id = item_mapping.setdefault(item_id, len(item_mapping) + 1)

        processed_history.append({'item_id': processed_item_id, 'timestamp': item_timestamp})

    # Users below the threshold (pruned out of the core) get no id.
    if len(processed_history) >= threshold:
        processed_user_id = user_mapping.setdefault(user_id, len(user_mapping) + 1)

        # sorted() is stable, so events with equal timestamps keep their order.
        tmp_user_history[processed_user_id] = sorted(processed_history, key=lambda x: x['timestamp'])


for item_id, history in tqdm(item_history.items()):
    # Every user here already received an id in the loop above; never create new ones.
    processed_history = [
        {'user_id': user_mapping[filtered_user['user_id']], 'timestamp': filtered_user['timestamp']}
        for filtered_user in history
    ]

    if len(processed_history) >= threshold:
        processed_item_id = item_mapping[item_id]

        tmp_item_history[processed_item_id] = sorted(processed_history, key=lambda x: x['timestamp'])

# Every allocated id must own a history; otherwise the two views disagree.
assert len(tmp_user_history) == len(user_mapping)
assert len(tmp_item_history) == len(item_mapping)

user_history = tmp_user_history
item_history = tmp_item_history

In [None]:
# Overall dataset statistics for the filtered, re-indexed data:
# - users/items are counted by their assigned ids;
# - actions are the total number of events across all user histories.
print('Users count:', len(user_mapping))
print('Items count:', len(item_mapping))
print('Actions count:', sum(len(events) for events in user_history.values()))
print('Avg user history len:', np.mean([len(events) for events in user_history.values()]))
print('Avg item history len:', np.mean([len(events) for events in item_history.values()]))

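## Leave-one-out split (last item for test, second-to-last item for validation, the remaining items for training)

The cell below writes each user's full, time-ordered item sequence to `all_data.txt`; it does not split the sequences itself.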

In [13]:
# One line per user: the user id followed by that user's item ids in
# chronological order, all space-separated.
with open('../data/Beauty/all_data.txt', 'w') as f:
    for uid, events in user_history.items():
        ordered_events = sorted(events, key=lambda event: event['timestamp'])
        line = ' '.join([str(uid), *(str(event['item_id']) for event in ordered_events)])
        f.write(line + '\n')

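## Timestamp-based split (80% for training, 10% for validation, and 10% for test)

The cut-off timestamps are global quantiles over all interactions, so these proportions apply to interactions overall rather than to each user. Events whose timestamp equals a cut-off fall into the later split.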

In [None]:
valid_portion = 0.1
test_portion = 0.1

# Pool every interaction timestamp across all users, sorted ascending.
all_events_timestamp = sorted(
    interaction['timestamp']
    for interactions in user_history.values()
    for interaction in interactions
)


def _timestamp_at(fraction):
    """Return the pooled timestamp at position int(len * fraction) of the sorted list."""
    return all_events_timestamp[int(len(all_events_timestamp) * fraction)]


# The first cut-off starts the validation window; the second starts the test window.
fst_threshold = _timestamp_at(1.0 - test_portion - valid_portion)
snd_threshold = _timestamp_at(1.0 - test_portion)

print(f'First train timestamp:\t{all_events_timestamp[0]}')
print(f'First valid timestamp:\t{fst_threshold}')
print(f'First test timestamp:\t{snd_threshold}')

In [15]:
# Build per-user samples for the timestamp split:
# - train: one sample per user holding every event before fst_threshold,
#   kept only if there are at least 5 such events;
# - valid/test: one sample per event in the respective window. Its context is
#   every earlier event of that user (from any window), and the sample is
#   kept only if that context has at least 5 events (cold-start filter).
train_samples = []
valid_samples = []
test_samples = []

for user_id, user_interactions in user_history.items():
    train_history = []
    history = []  # every event of this user seen so far, in order

    for user_interaction in user_interactions:
        event_ts = user_interaction['timestamp']

        if event_ts < fst_threshold:
            # Train event: must not precede the previous event of this user.
            assert not history or event_ts >= history[-1]['timestamp']
            train_history.append(user_interaction)
        else:
            if event_ts < snd_threshold:
                assert event_ts >= fst_threshold
                target = valid_samples
            else:
                assert event_ts >= snd_threshold
                target = test_samples

            if len(history) >= 5:
                target.append({
                    'user_id': user_id,
                    'history': list(history),
                    'next_interaction': user_interaction,
                })

        history.append(user_interaction)

    if len(train_history) >= 5:
        train_samples.append({'user_id': user_id, 'history': train_history})

In [None]:
# Number of (train, valid, test) samples.
tuple(map(len, (train_samples, valid_samples, test_samples)))

In [17]:
# train
with open('../data/Beauty/train.txt', 'w') as f:
    for train_sample in train_samples:
        f.write(' '.join([str(train_sample['user_id'])] + [
            str(user_interaction['item_id']) for user_interaction in sorted(train_sample['history'], key=lambda x: x['timestamp'])
        ]))
        f.write('\n')

# valid
with open('../data/Beauty/valid.txt', 'w') as f:
    for valid_sample in valid_samples:
        f.write(' '.join([str(valid_sample['user_id'])] + [
            str(user_interaction['item_id']) for user_interaction in sorted(valid_sample['history'], key=lambda x: x['timestamp'])
        ] + [str(valid_sample['next_interaction']['item_id'])]))
        f.write('\n')

# test
with open('../data/Beauty/test.txt', 'w') as f:
    for test_sample in test_samples:
        f.write(' '.join([str(test_sample['user_id'])] + [
            str(user_interaction['item_id']) for user_interaction in sorted(test_sample['history'], key=lambda x: x['timestamp'])
        ] + [str(test_sample['next_interaction']['item_id'])]))
        f.write('\n')

In [None]:
import torch

# Attach the precomputed item embeddings to the re-indexed item ids.
#
# `df` still contains items that the k-core filter removed, and those have no
# entry in item_mapping. Mapping them would raise KeyError, so keep only
# mapped items first. factorize() makes item_id <-> raw_item_id one-to-one, so
# the dedup keeps the first interaction row of each remaining item.
deduped_mapping = (
    df[df['item_id'].isin(list(item_mapping))]
    .drop_duplicates(subset=['item_id', 'raw_item_id'])
)

# NOTE(review): torch>=2.6 defaults to weights_only=True, which refuses to
# unpickle arbitrary objects such as the frame merged below. weights_only=False
# (available since torch 1.13) runs arbitrary pickle code, so only load this
# file from a trusted source.
embs = torch.load('../data/df_with_embs.pt', weights_only=False)

merged = pd.merge(deduped_mapping, embs, 'inner', left_on='raw_item_id', right_on='asin')
# Every remaining item_id is a key of item_mapping, so the dict lookup cannot miss.
merged['item_id'] = merged['item_id'].map(item_mapping)

# One embedding row per item (duplicate 'asin' rows in embs would break this).
assert merged['item_id'].is_unique
merged = merged.set_index('item_id')

torch.save(merged, '../data/Beauty/data_full.pt')