## 0. Imports

In [None]:
import utils.dataset_functions as df
import utils.user_features as uf
import utils.two_towers as ttn
import pandas as pd
import torch
from threading import Thread
from pathlib import Path
from tqdm.notebook import tqdm as progress_bar


# Make sure both dataset folders exist before any cell reads or writes them.
# `data_dir` (the raw, unprocessed folder) is reused by the cells below.
dataset_root = Path("dataset")
data_dir = dataset_root / "unprocessed"
for folder in (dataset_root / "processed", data_dir):
    folder.mkdir(parents=True, exist_ok=True)

## 1. Download and write locally to CSVs

In [None]:
# Fetch every interaction table of the dataset through the project helper
# (presumably it writes one CSV per type into the unprocessed folder -- confirm
# against utils.dataset_functions.download_df).
dataset_types = ["likes", "listens", "dislikes", "unlikes", "undislikes"]
dataset = df.YambdaDataset('flat', '50m')
for dataset_type in dataset_types:
    df.download_df(dataset=dataset, dataset_type=dataset_type)


# The audio embeddings are exported only once; an existing CSV is left untouched.
embeddings_csv = data_dir / "embeddings.csv"
if not embeddings_csv.exists():
    # Chain the calls so the large embedding frame is never bound to a name.
    dataset.audio_embeddings().to_pandas().to_csv(embeddings_csv, index=False)

## 2. Load Dataframes

In [None]:
# User feedback events (like / dislike / unlike / undislike), one file per type.
feedback_cols = ['uid', 'timestamp', 'item_id']
likes = pd.read_csv(data_dir / "likes.csv", usecols=feedback_cols)
dislikes = pd.read_csv(data_dir / "dislikes.csv", usecols=feedback_cols)
unlikes = pd.read_csv(data_dir / "unlikes.csv", usecols=feedback_cols)
# Fix: this used to re-read dislikes.csv, so `undislikes` was a copy of
# `dislikes` (and the save step below then overwrote undislikes.csv with it).
undislikes = pd.read_csv(data_dir / "undislikes.csv", usecols=feedback_cols)

merged_path = Path("dataset") / "processed" / "merged.csv"

if merged_path.exists():
    # A previous run already built the merged table: reuse it. Embeddings are
    # stored as text in the CSV and are parsed back by the project helper.
    user_item_data = pd.read_csv(merged_path, index_col=False)
    user_item_data["normalized_embed"] = user_item_data["normalized_embed"].apply(df.parse_embedding)

else:
    # User listen interactions
    listens = pd.read_csv(data_dir / "listens.csv", index_col=False)
    # Keep organic listens only. Fix: the filter result used to be discarded,
    # so non-organic rows were silently kept.
    listens = listens[listens["is_organic"] == 1].drop(columns="is_organic")

    # Due to computational limitations, only users with between 100 and 5000
    # listen events (inclusive) are kept.
    listens_per_user = listens["uid"].map(listens["uid"].value_counts())
    listens = listens[listens_per_user.between(100, 5000)]

    # Only keep songs that appear at least `cut_off` times in the remaining listens.
    cut_off = 500
    item_counts = listens["item_id"].value_counts()
    popular_items = item_counts[item_counts >= cut_off].index
    print(len(popular_items))
    listens = listens[listens["item_id"].isin(popular_items)]

    # Embeddings
    embeddings = pd.read_csv(data_dir / 'embeddings.csv', usecols=['item_id', 'normalized_embed'], index_col=False)
    embeddings["normalized_embed"] = embeddings["normalized_embed"].apply(df.parse_embedding)

    # Merge the song embeddings and user listens dataset; the inner join drops
    # listens of songs that have no embedding.
    user_item_data = pd.merge(listens, embeddings, on='item_id', how='inner')

    # save memory
    del listens
    del embeddings
    del popular_items

    # Dense, zero-based ids in order of first appearance in the merged table.
    valid_users = user_item_data['uid'].unique()
    valid_items = user_item_data['item_id'].unique()
    uid_map = {uid: new_id for new_id, uid in enumerate(valid_users)}
    item_id_map = {sid: new_id for new_id, sid in enumerate(valid_items)}

    # Every key is present in the maps, so `map` gives the same result as the
    # former dict-based `replace` while being far cheaper on large mappings.
    user_item_data['uid'] = user_item_data['uid'].map(uid_map)
    user_item_data['item_id'] = user_item_data['item_id'].map(item_id_map)

    def _remap_feedback(events):
        """Keep feedback rows about known users and songs and remap their ids.

        Rows whose uid is not in the merged table are dropped as well: the old
        `replace` left such uids untouched, so a raw uid could collide with the
        new dense id of a different user.
        """
        known = events['uid'].isin(valid_users) & events['item_id'].isin(valid_items)
        events = events[known].copy()  # copy: avoid assigning into a slice
        events['uid'] = events['uid'].map(uid_map)
        events['item_id'] = events['item_id'].map(item_id_map)
        return events

    likes = _remap_feedback(likes)
    dislikes = _remap_feedback(dislikes)
    unlikes = _remap_feedback(unlikes)
    undislikes = _remap_feedback(undislikes)

    # Save our processed dataset.
    # NOTE(review): the remapped feedback overwrites the raw files in the
    # unprocessed folder, which is what the cached branch above relies on.
    # If merged.csv is deleted without re-downloading the raw files, the ids
    # would be remapped a second time -- confirm this is acceptable.
    user_item_data.to_csv(merged_path, index=False)
    likes.to_csv(Path("dataset") / "unprocessed" / "likes.csv", index=False)
    dislikes.to_csv(Path("dataset") / "unprocessed" / "dislikes.csv", index=False)
    unlikes.to_csv(Path("dataset") / "unprocessed" / "unlikes.csv", index=False)
    undislikes.to_csv(Path("dataset") / "unprocessed" / "undislikes.csv", index=False)

user_item_data

## 3. Create and save user features
We do this in train/val/test splits

In [None]:
users = user_item_data['uid'].unique()

# Folder holding one `<uid>.pt` file per user. Creating it here is harmless if
# the feature helper already does so.
users_dir = Path("dataset") / "processed" / "users"
users_dir.mkdir(parents=True, exist_ok=True)

# Users that already have a feature file are skipped, so an interrupted run
# can be resumed.
todo_users = [u for u in users if not (users_dir / f"{u}.pt").exists()]

# It is HIGHLY recommended to use more than 1 thread.
# The remaining users are split as evenly as possible: the first `m` chunks
# get one extra user each.
num_threads = 3
k, m = divmod(len(todo_users), num_threads)
user_split = [todo_users[i*k + min(i, m) : (i+1)*k + min(i+1, m)] for i in range(num_threads)]

# Multithread it to make it somewhat time manageable. One worker is created per
# non-empty chunk, so changing `num_threads` is now enough to change the
# parallelism (previously exactly three threads were hard-coded).
threads = [
    Thread(
        target=uf.extract_and_save_features,
        args=(chunk, users_dir, user_item_data, likes, unlikes, dislikes, undislikes),
    )
    for chunk in user_split
    if chunk
]

for worker in threads:
    worker.start()

# Wait for every worker before moving on to the merge step.
for worker in threads:
    worker.join()


In [None]:
# The per-user feature files are on disk now; drop the large dataframes so
# their memory can be reclaimed before the merge step.
del user_item_data, likes, dislikes, unlikes, undislikes

### 3.1: merge the separate user files

In [None]:
# Merge the per-user feature files into train / val / test tensors.
users_dir = Path("dataset") / "processed" / "users"

# Sorted so the row order of the merged tensors is reproducible; raw glob
# order depends on the filesystem.
files = sorted(users_dir.glob("*.pt"))
if not files:
    # torch.cat on an empty list fails with an opaque message; fail early instead.
    raise FileNotFoundError(f"No user feature files (*.pt) found in {users_dir}")

# Tensors stored in every user file. Per the original notes, all of them have
# the user's N samples along dim 0 (user_feats [N, F], user_ids [N],
# song_ids [N], song_embeds [N, E], labels [N, L], interactions [N]).
tensor_keys = ("user_feats", "user_ids", "song_ids", "song_embeds", "labels", "interactions")
split_names = ("train", "val", "test")

# parts[split][key] collects one slice per user and is concatenated at the end.
parts = {split: {key: [] for key in tensor_keys} for split in split_names}

for file in progress_bar(files, desc="Processing users"):
    data = torch.load(file, map_location="cpu")

    N = data["user_feats"].shape[0]

    # Positional 70 / 15 / 15 split of each user's own rows.
    # NOTE(review): this is a chronological split only if the rows were saved
    # in time order -- confirm in utils.user_features.
    train_end = int(N * 0.70)
    val_end = int(N * 0.85)
    row_ranges = {
        "train": slice(None, train_end),
        "val": slice(train_end, val_end),
        "test": slice(val_end, None),
    }

    for split, rows in row_ranges.items():
        for key in tensor_keys:
            parts[split][key].append(data[key][rows])

# Final merge: concatenate the per-user slices of every tensor along dim 0.
train, val, test = (
    {key: torch.cat(parts[split][key], dim=0) for key in tensor_keys}
    for split in split_names
)

print("Train size:", train["user_feats"].shape[0])
print("Val size:  ", val["user_feats"].shape[0])
print("Test size: ", test["user_feats"].shape[0])

torch.save(train, Path("dataset") / "processed" / "train.pt")
torch.save(val, Path("dataset") / "processed" / "val.pt")
torch.save(test, Path("dataset") / "processed" / "test.pt")

## 4. Training the model

### 4.1 set the parameters

In [None]:
# Sizes handed to the two-tower model constructor in 4.2
# (presumably: tower output size, hidden layer width, id-embedding size -- confirm
# against utils.two_towers).
output_dim, hidden_dim, id_dim = 16, 256, 16


# Optimisation settings used by the training cell.
num_epochs = 50
learning_rate = 1e-3
batch_size = 256
patience = 5

# Train on the GPU when one is visible to torch, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print("Training models on:", device)


### 4.2 Training our models

In [None]:
# One model is trained per label type; the position of the name in this list
# is passed to the data loader as `label_id`.
# NOTE(review): "continueous_label" is misspelled but is handed to the model
# constructor as-is, so it is left unchanged here -- confirm how
# utils.two_towers uses the name before renaming it.
models = ["binary_label", "continueous_label"]

for label_id, model_name in enumerate(models):
    # Fresh loaders for every model, built for the current label.
    train_set = df.load_tensor_dataloader("train", Path("dataset")/"processed", batch_size, label_id)
    val_set = df.load_tensor_dataloader("val", Path("dataset")/"processed", batch_size, label_id)

    model = ttn.DualAugmentedTwoTower(model_name, hidden_dim, output_dim, id_dim)
    optimiser = torch.optim.Adam(model.parameters(), lr=learning_rate)
    # Fix: use the `num_epochs` setting from section 4.1 instead of a
    # hard-coded 50, so changing the setting actually takes effect.
    ttn.train_model(model, train_set, val_set, optimiser, patience, num_epochs=num_epochs, device=device)

