In [28]:
import pandas as pd
import os
import random
import torch
import numpy as np

In [29]:
# Seed every RNG-bearing library the notebook imports so a fresh "Restart & Run All"
# is reproducible, not just Python's `random`.
SEED = 2024
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Load original dataset

### Selecting datasets

In [30]:
# Directory holding the raw Amazon `ratings_*.csv` files. Override with the
# AMAZON_RAW_ROOT environment variable instead of editing an absolute path.
# os.path.join(..., '') guarantees a trailing separator, because the loader may build
# paths by plain string concatenation (root + name + '.csv').
root = os.path.join(os.environ.get('AMAZON_RAW_ROOT', '/data/mjyin/dataset/AmazonRaw/'), '')
output_path = './debug3/'
# Raw file stems to load; the position in this list becomes the integer `domain` id.
dataset_name_list = [
    'ratings_Sports_and_Outdoors',
    'ratings_Toys_and_Games',
]
# Short output sub-directory name for each entry of `dataset_name_list` (same order).
output_dataset_name_list = [
    'sport',
    'toy',
]
# Later cells index these two lists in parallel, so they must line up.
assert len(dataset_name_list) == len(output_dataset_name_list), \
    'dataset_name_list and output_dataset_name_list must have the same length'

### Load

In [31]:
# Load each raw ratings file, keep ratings >= 3, and tag every row with its domain index
# (the file's position in `dataset_name_list`).
dataset_list = []
for idx, dataset_name in enumerate(dataset_name_list):
    path = os.path.join(root, dataset_name + '.csv')
    # Files are read without a header row; columns are named explicitly below.
    dataset = pd.read_csv(path, header=None)
    dataset.columns = ['user_id', 'item_id', 'rating', 'timestamp']
    # .assign returns a new frame, so the domain column is not written into a
    # filtered slice (avoids the SettingWithCopy pattern).
    dataset = dataset[dataset['rating'] >= 3].assign(domain=idx)
    dataset_list.append(dataset)

# Filter

### Filter by number of interactions (iterative k-core, applied to each domain separately)

In [32]:
# Minimum number of interactions a user / item must keep inside a single domain.
user_threshold = 5
item_threshold = 5


def _k_core(frame):
    """Alternately drop rare users and rare items until a full pass removes nothing.

    Dropping items can push a user back under `user_threshold` (and vice versa), hence
    the fixed-point loop.
    """
    while True:
        rows_before = len(frame)
        per_user = frame['user_id'].map(frame['user_id'].value_counts())
        frame = frame[per_user >= user_threshold]
        per_item = frame['item_id'].map(frame['item_id'].value_counts())
        frame = frame[per_item >= item_threshold]
        if rows_before == len(frame):
            return frame


# Filtering is done per domain, independently of the other domains.
filtered_dataset_list = []
for dataset in dataset_list:
    filtered_dataset = _k_core(dataset.copy())
    print('done!')
    filtered_dataset_list.append(filtered_dataset)

done!
done!


# Map raw user/item ids to shared integer ids across all domains

In [33]:
# Stack every filtered domain so raw user/item ids can be numbered over one shared vocabulary.
all_filtered_dataset = pd.concat(filtered_dataset_list)
all_user, all_item = all_filtered_dataset['user_id'], all_filtered_dataset['item_id']

In [34]:
# pd.factorize lists unique tokens in order of first appearance; each token's position in
# that list is its new integer id. The returned integer codes are not read by any later
# cell shown here, since the mapping dicts below carry the same numbering.
user_id, user_token = pd.factorize(all_user)
item_id, item_token = pd.factorize(all_item)
user_mapping_dict = dict(zip(user_token, range(len(user_token))))
item_mapping_dict = dict(zip(item_token, range(len(item_token))))
for token_index in (user_token, item_token):
    print(token_index.shape)

(43422,)
(25091,)


In [35]:
# Replace raw user/item ids with the shared integer ids. Series.map with a dict is a
# vectorised lookup; any value missing from the dict becomes NaN. Check for that before
# overwriting the columns, so an accidental re-run of this cell (on ids that are already
# integers) fails loudly and leaves the frame intact.
mapped_user = all_filtered_dataset['user_id'].map(user_mapping_dict)
mapped_item = all_filtered_dataset['item_id'].map(item_mapping_dict)
assert mapped_user.notna().all() and mapped_item.notna().all(), \
    'unmapped user/item ids (has this cell already been run?)'
all_filtered_dataset['user_id'] = mapped_user
all_filtered_dataset['item_id'] = mapped_item

In [36]:
# Split the re-numbered interactions back into one frame per domain, in domain-index order.
mapped_dataset_list = []
for domain_idx in range(len(filtered_dataset_list)):
    in_domain = all_filtered_dataset['domain'] == domain_idx
    mapped_dataset_list.append(all_filtered_dataset[in_domain])

# Generate sequences

### Mixed cross-domain sequences with per-domain val/test hold-out

In [37]:
max_seq_len = 20  # cap on interactions kept per domain; mixed length is max_seq_len * num_domain
num_domain = len(mapped_dataset_list)
# NOTE(review): the original note reads "mask is set to the last position". Presumably
# -1 indexes the last row of the embedding table, reserved as a pad/mask token; confirm in the model.
PAD = -1


def pad(_seq):
    """Right-pad a list with PAD up to the mixed-sequence length ``max_seq_len * num_domain``.

    Returns ``(padded_list, original_length)``. A list that is already at least that long
    is returned unchanged, because a list times a non-positive int is empty.
    """
    original_length = len(_seq)
    padding = [PAD] * (max_seq_len * num_domain - original_length)
    return _seq + padding, original_length

def _eval_sample(uid, position, seq, domain_seq):
    """Build one val/test record that predicts ``seq[position]`` from the mixed prefix before it.

    ``seq`` / ``domain_seq`` are the user's truncated mixed item / domain arrays. The record
    uses the same layout as the train records: ``[user_id, padded_item_prefix, target_item,
    prefix_len, label, padded_domain_prefix]``, with ``label`` fixed to 1.
    """
    pad_seq, seq_len = pad(seq[:position].tolist())
    pad_domain_seq, _ = pad(domain_seq[:position].tolist())
    return [uid, pad_seq, seq[position], seq_len, 1, pad_domain_seq]


# One chronologically ordered, cross-domain history per user.
all_mapped_dataset = pd.concat(mapped_dataset_list)
all_mapped_dataset = all_mapped_dataset.sort_values(by=['user_id', 'timestamp'])
item_id_group = all_mapped_dataset.groupby('user_id')['item_id'].apply(list)
domain_id_group = all_mapped_dataset.groupby('user_id')['domain'].apply(np.array)
train, val, test = [], [[] for _ in range(num_domain)], [[] for _ in range(num_domain)]

for user_id, user_seq in zip(item_id_group.index, item_id_group.tolist()):
    domain_seq = domain_id_group.loc[user_id]
    # Keep the last `max_seq_len` interactions of each domain, then restore chronological order.
    kept_positions = np.sort(np.concatenate(
        [np.where(domain_seq == d)[0][-max_seq_len:] for d in range(num_domain)]
    ))
    truncated_seq = np.array(user_seq)[kept_positions]
    truncated_domain_seq = domain_seq[kept_positions]

    # Per domain: the last interaction is the test target and the one before it is the val target.
    test_location, val_location = {}, {}
    for d in range(num_domain):
        positions = np.where(truncated_domain_seq == d)[0]
        # Both hold-outs need two interactions (positions[-2] would raise IndexError otherwise);
        # a domain with fewer stays entirely in train.
        if len(positions) >= 2:
            test_location[d] = positions[-1]
            val_location[d] = positions[-2]

    # dicts keep insertion order, i.e. ascending domain index.
    for d in val_location:
        test[d].append(_eval_sample(user_id, test_location[d], truncated_seq, truncated_domain_seq))
        val[d].append(_eval_sample(user_id, val_location[d], truncated_seq, truncated_domain_seq))

    # Train record: the full mixed history with every held-out position overwritten by PAD,
    # shifted by one so each input position predicts the next interaction.
    held_out = list(test_location.values()) + list(val_location.values())
    masked_seq = truncated_seq.copy()
    masked_seq[held_out] = PAD
    masked_domain_seq = truncated_domain_seq.copy()
    masked_domain_seq[held_out] = PAD

    pad_seq, seq_len = pad(masked_seq.tolist()[:-1])
    target_data, _ = pad(masked_seq.tolist()[1:])
    pad_domain_seq, _ = pad(masked_domain_seq.tolist()[:-1])
    # NOTE(review): label is 1 wherever the *input* domain is not PAD. Positions whose *target*
    # is PAD (a held-out item) still get label 1; confirm the training loss ignores target == PAD.
    label = [0 if dom == PAD else 1 for dom in pad_domain_seq]
    train.append([user_id, pad_seq, target_data, seq_len, label, pad_domain_seq])

# Persist the shared train set plus per-domain val/test sets. Nothing earlier in the
# notebook creates output_path or its per-domain sub-directories, and torch.save fails
# when the parent directory is missing, so create them first.
os.makedirs(output_path, exist_ok=True)
torch.save(train, os.path.join(output_path, 'train.pth'))
for idx in range(num_domain):
    domain_dir = os.path.join(output_path, output_dataset_name_list[idx])
    os.makedirs(domain_dir, exist_ok=True)
    torch.save(val[idx], os.path.join(domain_dir, 'val.pth'))
    torch.save(test[idx], os.path.join(domain_dir, 'test.pth'))

# Save per-domain interaction tables

In [38]:
# Write each domain's re-numbered interactions to <output_path>/<domain name>/inter.csv.
# Create the directory first, because to_csv does not create missing parent directories.
for d_name, d in zip(output_dataset_name_list, mapped_dataset_list):
    domain_dir = os.path.join(output_path, d_name)
    os.makedirs(domain_dir, exist_ok=True)
    d.to_csv(os.path.join(domain_dir, 'inter.csv'), sep=',', index=False)

In [39]:
# Distinct users and items in domain 0 (first entry of output_dataset_name_list).
first_domain = mapped_dataset_list[0]
for column in ('user_id', 'item_id'):
    print(first_domain[column].unique().shape)

(29277,)
(15396,)


In [40]:
# Distinct users and items in domain 1 (second entry of output_dataset_name_list).
second_domain = mapped_dataset_list[1]
for column in ('user_id', 'item_id'):
    print(second_domain[column].unique().shape)

(15528,)
(9696,)
