In [1]:
import json
import logging
import os
import pathlib as Path
import random
import numpy as np
import polars as pl
import pickle as pkl
import torch
from tqdm import tqdm
from torch.utils.data import DataLoader

# from src_dataloadre import Train

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from legacy_dataloeders import TrainDataset, collate_fn, EvalDatasetGTS

In [3]:
domains = ['mega', 'zvuk']
max_seq_len = 200

data_path = Path.Path('/home/jovyan/Samra/fl-recsys-adaptation/data/mega_zvuk-overlap50-minuser10-positives/')


In [4]:
def read_domain_data(data_path: Path):

    train_df = (
        pl.scan_parquet(data_path / 'train.parquet')
        .group_by('uid')
        .agg(pl.col('item_id'), pl.col('timestamp'))
        .collect()
    )

    val_df = (
        pl.scan_parquet(data_path / 'val.parquet')
        .group_by('uid')
        .agg(pl.col('item_id'), pl.col('timestamp'))
        .collect()
    )
    test_df = (
        pl.scan_parquet(data_path / 'test.parquet')
        .group_by('uid')
        .agg(pl.col('item_id'), pl.col('timestamp'))
        .collect()
    ).with_columns(pl.col('uid').cast(pl.Int32))
    train_val_df = (pl.concat([
        pl.scan_parquet(data_path / 'train.parquet'),
        pl.scan_parquet(data_path / 'val.parquet'),
        ])
        .group_by('uid')
        .agg(pl.col('item_id'), pl.col('timestamp'))
        .collect()
    ).with_columns(pl.col('uid').cast(pl.Int32))

    with open(data_path / 'item_id_to_idx.pkl', 'rb') as f:
        item_id_to_idx = pkl.load(f)

    num_items = len(item_id_to_idx)

    return train_df, val_df, train_val_df, test_df, num_items


In [5]:
train_dfs, val_dfs, train_val_dfs, test_dfs, num_items = [], [], [], [], []
for domain in domains: 
    train_df, val_df, train_val_df, test_df, num_item = read_domain_data(data_path / domain)
    train_dfs.append(train_df)
    test_dfs.append(test_df)
    train_val_dfs.append(train_val_df)
    val_dfs.append(val_df)
    num_items.append(num_item)


In [6]:
all_users = np.union1d(train_dfs[0]['uid'].unique().to_numpy(), train_dfs[1]['uid'].unique().to_numpy())
print(len(all_users))

143928


In [7]:
uid_2_index = dict(zip(all_users, np.arange(len(all_users))))

In [8]:
train_dfs = [df.with_columns(pl.col('uid').replace_strict(uid_2_index)) for df in train_dfs]
test_dfs = [df.with_columns(pl.col('uid').replace_strict(uid_2_index)) for df in test_dfs]
val_dfs = [df.with_columns(pl.col('uid').replace_strict(uid_2_index)) for df in val_dfs]

In [10]:
for i, domain in enumerate(domains):
    os.makedirs(f'data/{domain}', exist_ok=True)

    train_dataset = TrainDataset(dataset=train_dfs[i], num_items=num_items[i],
                                    max_seq_len=max_seq_len, num_neg_items=1)

    train_dataloader = DataLoader(
        dataset=train_dataset,
        batch_size=1,
        collate_fn=collate_fn,
        drop_last=True,
        shuffle=True,
        num_workers=2,
        prefetch_factor=4,
    )
    with open(f"data/{domain}/train_data.txt", "w") as f:
        for batch in tqdm(train_dataloader):
            user_id = int(batch["user.ids"][0])
            items = batch["item.ids"].tolist()

            f.write(str(user_id))
            for item in items:
                f.write(f"\t{item}")
            f.write("\n")



  0%|          | 0/95952 [00:00<?, ?it/s]

100%|██████████| 95952/95952 [04:15<00:00, 375.21it/s]
100%|██████████| 95952/95952 [04:15<00:00, 376.13it/s]


In [11]:
def get_eval_dataloader(train_df, eval_df, max_seq_len, batch_size, seed=42, eval_mode='random'):
    
    eval_df = train_df.join(eval_df, on='uid', how='inner', suffix='_valid').select(
        pl.col('uid'), pl.col('item_id').alias('item_id_train'), pl.col('item_id_valid')
    ).sort('uid')
    eval_dataset = EvalDatasetGTS(dataset=eval_df, max_seq_len=max_seq_len, seed=seed, mode=eval_mode)

    eval_dataloader = DataLoader(
        dataset=eval_dataset,
        batch_size=batch_size,
        collate_fn=collate_fn,
        drop_last=False,
        num_workers=8,
        shuffle=False,
    )

    return eval_df, eval_dataloader

In [12]:
for i, domain in enumerate(domains):
    os.makedirs(f'data/{domain}', exist_ok=True)

    eval_df, eval_dataloader = get_eval_dataloader(train_dfs[i], val_dfs[i], max_seq_len, 1, eval_mode='last')
    with open(f"data/{domain}/valid_data.txt", "w") as f:
        for batch in tqdm(eval_dataloader):
            user_id = int(batch["user.ids"][0])
            items = batch["item.ids"].tolist()

            f.write(str(user_id))
            for item in items:
                f.write(f"\t{item}")
            f.write("\n")


  1%|▏         | 387/28600 [00:01<01:20, 350.89it/s]

100%|██████████| 28600/28600 [01:18<00:00, 365.16it/s]
100%|██████████| 44587/44587 [02:03<00:00, 360.84it/s]


In [27]:
for i, domain in enumerate(domains):
    os.makedirs(f'data/{domain}', exist_ok=True)
    with open(f"data/{domain}/test_data.txt", "w") as f:
        for seed in [42, 41, 43]:
            eval_df, eval_dataloader = get_eval_dataloader(train_val_dfs[i], test_dfs[i], max_seq_len, batch_size=1, eval_mode='random', seed=seed)
            for batch in tqdm(eval_dataloader):
                user_id = int(batch["user.ids"][0])
                items = batch["item.ids"].tolist()

                f.write(str(user_id))
                for item in items:
                    f.write(f"\t{item}")
                f.write("\n")



  0%|          | 0/24799 [00:00<?, ?it/s]

100%|██████████| 24799/24799 [01:08<00:00, 360.24it/s]
100%|██████████| 24799/24799 [01:07<00:00, 368.57it/s]
100%|██████████| 24799/24799 [01:07<00:00, 369.28it/s]
100%|██████████| 42625/42625 [01:58<00:00, 359.53it/s]
100%|██████████| 42625/42625 [01:57<00:00, 362.92it/s]
100%|██████████| 42625/42625 [01:59<00:00, 357.85it/s]


In [10]:
for i, domain in enumerate(domains):
    os.makedirs(f'data/{domain}', exist_ok=True)
    with open(f"data/{domain}/num_items.txt", "w") as f:
        f.write(f'{num_items[i]}')
    with open(f"data/{domain}/item_users.txt", "w") as f:
        f.write(f'{len(train_dfs[i])}')
    
        