In [1]:
import os
import pickle
import sys

sys.path.append('../..')

import numpy as np
import polars as pl
from dotenv import load_dotenv
from replay.preprocessing.filters import MinCountFilter

from source.filters import ConsecutiveDuplicatesFilter

In [None]:
# Load environment configuration via dotenv before any environment variable is read;
# presumably this populates ZVUK_DATA_DIR from a local .env file — confirm.
load_dotenv()

True

In [3]:
ZVUK_DATA_DIR = os.getenv('ZVUK_DATA_DIR')
!ls {ZVUK_DATA_DIR}

zvuk-interactions.parquet  zvuk-track_artist_embedding.parquet


In [4]:
# Read the raw interaction log and preview its size and first rows
interactions_path = os.path.join(ZVUK_DATA_DIR, 'zvuk-interactions.parquet')
interactions = pl.read_parquet(interactions_path)

print(f'{interactions.shape = }')
interactions.head()

interactions.shape = (244673551, 5)


user_id,session_id,datetime,track_id,play_duration
i32,i32,datetime[ms],i32,f32
1938823,2200336,2023-03-18 11:43:57.856,1242546,23.999992
1938823,2200336,2023-03-18 11:43:58.114,145031,2.0
1938823,2200336,2023-03-18 11:45:12.932,145031,34.0
1938823,2200336,2023-03-18 11:45:27.808,1501837,18.0
1938823,2200336,2023-03-18 11:45:42.891,1402020,2.0


In [5]:
# Switch to the column names used throughout the rest of the notebook
column_mapping = {'datetime': 'timestamp', 'track_id': 'item_id'}
interactions = interactions.rename(column_mapping)
interactions.head(1)

user_id,session_id,timestamp,item_id,play_duration
i32,i32,datetime[ms],i32,f32
1938823,2200336,2023-03-18 11:43:57.856,1242546,23.999992


In [6]:
# Only plays strictly longer than MIN_DURATION count as positive interactions;
# the duration column is no longer needed once the filter is applied
MIN_DURATION = 60  # seconds

is_positive = pl.col('play_duration').gt(MIN_DURATION)
interactions = interactions.filter(is_positive).drop('play_duration')
interactions.height

128229767

In [7]:
# Replace datetimes with integer UNIX timestamps in milliseconds
unix_timestamp = pl.col('timestamp').dt.epoch(time_unit='ms')
interactions = interactions.with_columns(unix_timestamp)
interactions.head()

user_id,session_id,timestamp,item_id
i32,i32,i64,i32
4669961,2070579,1676423707362,1024676
4669961,2070579,1676423707367,1353716
4669961,2070579,1676423707371,667979
4669961,2070579,1676423707374,554027
4669961,2070579,1676423707377,305326


In [8]:
# Number of distinct users and items remaining after filtering
interactions.select(pl.col(['user_id', 'item_id']).n_unique())

user_id,item_id
u32,u32
350781,1220940


In [9]:
# Keep a seeded random subset of NUM_USERS users
NUM_USERS = 10_000

# Sorting before sampling makes the seeded draw independent of row order
all_users = interactions.get_column('user_id').unique().sort()
sampled_users = all_users.sample(NUM_USERS, seed=42)
interactions = interactions.filter(pl.col('user_id').is_in(sampled_users))

print(f'{len(interactions) = }')
interactions.select(pl.col(['user_id', 'item_id']).n_unique())

len(interactions) = 3687800


user_id,item_id
u32,u32
10000,253038


In [10]:
# Drop consecutive duplicates; the exact rule is defined by
# ConsecutiveDuplicatesFilter in source.filters
deduplicator = ConsecutiveDuplicatesFilter()
interactions = deduplicator(interactions)
len(interactions)

3409280

In [11]:
# Global temporal split: interactions strictly after the timestamp quantile
# form the test subset, everything else stays in train/val
TIME_THRESHOLD_QUANTILE = 0.9
TIME_THRESHOLD = interactions.get_column('timestamp').quantile(TIME_THRESHOLD_QUANTILE)

is_test = pl.col('timestamp').gt(TIME_THRESHOLD)

test_interactions = interactions.filter(is_test)
train_val_interactions = interactions.filter(is_test.not_())

len(train_val_interactions), len(test_interactions)

(3068352, 340928)

In [12]:
# Iterative N-core filtering of train/val data: alternate the user-side and
# item-side MinCountFilter until a full pass no longer changes the frame shape
N = 3
previous_shape = None

while previous_shape != train_val_interactions.shape:
    previous_shape = train_val_interactions.shape

    # Users first, then items: same order on every pass
    for groupby_column in ('user_id', 'item_id'):
        min_count_filter = MinCountFilter(N, groupby_column=groupby_column)
        train_val_interactions = min_count_filter.transform(train_val_interactions)

len(train_val_interactions)

2898290

In [13]:
# Give every test user their history: append the (filtered) train/val
# interactions of users that occur in the test subset
test_users = test_interactions.get_column('user_id').unique()
test_user_history = train_val_interactions.filter(pl.col('user_id').is_in(test_users))

test_interactions = pl.concat([test_interactions, test_user_history])
len(test_interactions)

2648693

In [14]:
# User-level split: a seeded random fraction of users goes to train,
# the remaining users form the validation subset
TRAIN_USERS_FRACTION = 0.9

train_val_users = train_val_interactions.get_column('user_id').unique().sort()
TRAIN_USERS = train_val_users.sample(fraction=TRAIN_USERS_FRACTION, seed=42)

is_train_user = pl.col('user_id').is_in(TRAIN_USERS)

train_interactions = train_val_interactions.filter(is_train_user)
val_interactions = train_val_interactions.filter(is_train_user.not_())

len(train_interactions), len(val_interactions)

(2621480, 276810)

In [15]:
# Items present in the training subset are "warm"
WARM_ITEMS = train_interactions.get_column('item_id').unique().sort()
print(f'{len(WARM_ITEMS) = }')

# Index 0 is reserved for padding, so numbering starts at 1
item2index_warm = {item: index for index, item in enumerate(WARM_ITEMS, start=1)}

encode_warm = pl.col('item_id').replace_strict(item2index_warm)
train_interactions = train_interactions.with_columns(encode_warm)

# Validation keeps warm items only and is encoded with the same mapping
val_interactions = val_interactions.filter(pl.col('item_id').is_in(WARM_ITEMS))
val_interactions = val_interactions.with_columns(encode_warm)

len(WARM_ITEMS) = 107448


In [16]:
# Flag test interactions whose item does not occur in the training subset.
# The negation is applied before the alias so the expression reads unambiguously.
is_cold = pl.col('item_id').is_in(WARM_ITEMS).not_().alias('is_cold')
test_interactions = test_interactions.with_columns(is_cold)

COLD_ITEMS = test_interactions.filter(pl.col('is_cold')).get_column('item_id').unique().sort()
print(f'{len(COLD_ITEMS) = }')

# Cold indices continue right after the largest warm index
bias = max(item2index_warm.values()) + 1
item2index_cold = {item: index for index, item in enumerate(COLD_ITEMS, start=bias)}

# Encode test items with the union of the warm and cold mappings
test_interactions = test_interactions.with_columns(
    pl.col('item_id').replace_strict({**item2index_warm, **item2index_cold})
)

len(COLD_ITEMS) = 23637


In [17]:
# Number each user's interactions from the end in timestamp order,
# so that position 1 marks the user's latest interaction
position_from_end = (
    pl.col('user_id')
    .cum_count(reverse=True)
    .over('user_id', order_by='timestamp')
    .alias('position')
)
test_interactions = test_interactions.with_columns(position_from_end)

# The latest interaction of every user is held out as ground truth
is_last = pl.col('position') == 1

ground_truth = test_interactions.filter(is_last)
test_interactions = test_interactions.filter(~is_last)

In [18]:
# Row counts of the train, validation, test and ground-truth subsets
tuple(
    len(subset)
    for subset in (train_interactions, val_interactions, test_interactions, ground_truth)
)

(2621480, 263305, 2644360, 4333)

In [19]:
# Sanity checks: no train/val interaction lies past the time threshold,
# and every ground-truth interaction lies strictly after it
for subset in (train_interactions, val_interactions):
    assert (subset.get_column('timestamp') <= TIME_THRESHOLD).all()

assert (ground_truth.get_column('timestamp') > TIME_THRESHOLD).all()

In [20]:
# Save processed data
OUTPUT_DIR = '../../data/zvuk'

# exist_ok=True creates the directory tree idempotently and avoids the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(os.path.join(OUTPUT_DIR, 'processed'), exist_ok=True)

# Interaction subsets, with item IDs already encoded as indices
train_interactions.write_parquet(os.path.join(OUTPUT_DIR, 'processed/train_interactions.parquet'))
val_interactions.write_parquet(os.path.join(OUTPUT_DIR, 'processed/val_interactions.parquet'))
test_interactions.write_parquet(os.path.join(OUTPUT_DIR, 'processed/test_interactions.parquet'))
ground_truth.write_parquet(os.path.join(OUTPUT_DIR, 'processed/ground_truth.parquet'))

# Raw item ID -> index mappings, stored separately for warm and cold items
with open(os.path.join(OUTPUT_DIR, 'processed/item2index_warm.pkl'), mode='wb') as file:
    pickle.dump(item2index_warm, file)

with open(os.path.join(OUTPUT_DIR, 'processed/item2index_cold.pkl'), mode='wb') as file:
    pickle.dump(item2index_cold, file)

In [21]:
# Read the track metadata and preview its size and first rows
metadata_path = os.path.join(ZVUK_DATA_DIR, 'zvuk-track_artist_embedding.parquet')
metadata = pl.read_parquet(metadata_path)

print(f'{metadata.shape = }')
metadata.head()

metadata.shape = (2199876, 4)


track_id,artist_id,cluster_id,vector
i32,i32,i32,list[f32]
1153192,123804,81,"[-0.133804, 0.030982, … -0.035863]"
1153192,53861,81,"[-0.133804, 0.030982, … -0.035863]"
1407103,119273,5,"[-0.202322, 0.02799, … 0.023902]"
767976,237822,4,"[-0.187276, -0.007839, … -0.026898]"
767976,125521,4,"[-0.187276, -0.007839, … -0.026898]"


In [22]:
# Restrict metadata to the encoded (warm + cold) items, map raw IDs to indices,
# drop rows duplicated on (item_id, vector) and order rows by item index
encoded_items = pl.concat((WARM_ITEMS, COLD_ITEMS))
encode_items = pl.col('item_id').replace_strict({**item2index_warm, **item2index_cold})

metadata = (
    metadata.rename({'track_id': 'item_id'})
    .filter(pl.col('item_id').is_in(encoded_items))
    .with_columns(encode_items)
    .unique(['item_id', 'vector'])
    .sort('item_id')
)

print(f'{metadata.shape = }')
metadata.head()

metadata.shape = (131085, 4)


item_id,artist_id,cluster_id,vector
i64,i32,i32,list[f32]
1,88862,93,"[-0.121946, 0.004496, … -0.01959]"
2,150708,40,"[-0.170344, 0.026975, … -0.020984]"
3,250819,68,"[-0.178707, -0.047905, … 0.038794]"
4,227608,42,"[-0.161459, 0.009666, … -0.043699]"
5,232146,69,"[-0.188571, 0.000846, … 0.029521]"


In [23]:
# Every encoded item must have exactly one metadata row.
assert len(metadata) == len(item2index_warm) + len(item2index_cold)

# A matching row count alone does not rule out one item being missing while
# another appears twice with different vectors. metadata is sorted by item_id,
# so requiring the indices to be exactly 1..N guarantees that row k holds the
# vector of item index k + 1, which the positional warm/cold split relies on.
expected_item_ids = list(range(1, len(item2index_warm) + len(item2index_cold) + 1))
assert metadata.get_column('item_id').to_list() == expected_item_ids

In [24]:
# Stack the per-item vectors into one 2-D array; rows follow the metadata order
item_vectors = metadata.get_column('vector').to_list()
item_embeddings = np.vstack(item_vectors)
item_embeddings.shape

(131085, 128)

In [25]:
# metadata is sorted by encoded item index and warm indices precede cold ones,
# so the first len(WARM_ITEMS) rows are warm and the remaining rows are cold
num_warm_items = len(WARM_ITEMS)

warm_item_embeddings = item_embeddings[:num_warm_items]
cold_item_embeddings = item_embeddings[num_warm_items:]

In [26]:
# Save item embeddings
# exist_ok=True creates the directory idempotently and avoids the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(os.path.join(OUTPUT_DIR, 'item_embeddings'), exist_ok=True)

# Warm and cold embeddings are stored as separate arrays
np.save(os.path.join(OUTPUT_DIR, 'item_embeddings/embeddings_warm.npy'), warm_item_embeddings)
np.save(os.path.join(OUTPUT_DIR, 'item_embeddings/embeddings_cold.npy'), cold_item_embeddings)