## 0. Imports

In [None]:
import utils.dataset_functions as df
import utils.user_features as uf
import utils.two_towers as ttn
import pandas as pd
import ast
import torch
import numpy as np
from threading import Thread
from pathlib import Path
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm as progress_bar


# Make sure both dataset folders exist before any later cell reads or writes them.
data_dir = Path("dataset") / "unprocessed"
for folder in (Path("dataset") / "processed", data_dir):
    folder.mkdir(parents=True, exist_ok=True)

## 1. Download and write locally to CSVs

In [None]:
# Pull every interaction table of the 'flat' / '50m' Yambda dataset through
# df.download_df (expected to store each one locally — see utils.dataset_functions).
dataset = df.YambdaDataset('flat', '50m')
dataset_types = ["likes", "listens", "dislikes", "unlikes", "undislikes"]
for interaction_type in dataset_types:
    df.download_df(dataset=dataset, dataset_type=interaction_type)

# The audio embeddings are exported only when their CSV is not on disk yet.
embeddings_csv = data_dir / "embeddings.csv"
if not embeddings_csv.exists():
    embeddings = dataset.audio_embeddings().to_pandas()
    embeddings.to_csv(embeddings_csv, index=False)
    del embeddings  # release the frame; the next section re-reads the CSV when needed

## 2. Load Dataframes

In [None]:
# User like/dislike interactions and their reversals, one CSV per event type.
interaction_cols = ['uid', 'timestamp', 'item_id']
likes = pd.read_csv(data_dir / "likes.csv", usecols=interaction_cols)
dislikes = pd.read_csv(data_dir / "dislikes.csv", usecols=interaction_cols)
unlikes = pd.read_csv(data_dir / "unlikes.csv", usecols=interaction_cols)
# Fixed: this used to read dislikes.csv again, making `undislikes` an exact copy of `dislikes`.
undislikes = pd.read_csv(data_dir / "undislikes.csv", usecols=interaction_cols)

merged_csv = Path("dataset") / "processed" / "merged.csv"

if merged_csv.exists():
    # Cache hit: the listens/embeddings merge was already computed and saved by an earlier run.
    user_item_data = pd.read_csv(merged_csv, index_col=False)
    user_item_data["normalized_embed"] = user_item_data["normalized_embed"].apply(df.parse_embedding)
    # NOTE(review): `labels` is used exactly as read from the CSV here (no literal_eval),
    # while the branch below holds the objects returned by
    # uf.get_song_label_and_user_interacton — confirm downstream code accepts both forms.

else:
    # Cache miss: build the merged frame from the raw CSVs.

    # User listen interactions
    listens = pd.read_csv(data_dir / "listens.csv", index_col=False)
    # DataFrame.drop returns a new frame; the result was previously discarded,
    # so 'is_organic' was never actually removed.
    listens = listens.drop(columns=['is_organic'])

    # Due to computational limitations, keep only users with between 500 and 5000 listen rows.
    listens = listens.groupby('uid').filter(lambda x: 500 <= len(x) <= 5000)

    # Embeddings
    embeddings = pd.read_csv(data_dir / 'embeddings.csv', usecols=['item_id', 'normalized_embed'], index_col=False)
    embeddings["normalized_embed"] = embeddings["normalized_embed"].apply(df.parse_embedding)

    # Inner join: listens whose item has no embedding are dropped.
    user_item_data = pd.merge(listens, embeddings, on='item_id', how='inner')

    # Save memory: only the merged frame is needed from here on.
    del listens
    del embeddings

    # Row-wise labelling; each call yields two values that are expanded into the
    # "labels" and "net_interactions" columns.
    user_item_data[["labels", "net_interactions"]] = user_item_data.apply(
        uf.get_song_label_and_user_interacton,
        axis=1,
        args=(likes, user_item_data, unlikes, dislikes, undislikes),
        result_type="expand",
    )

    # Save the processed dataset so later runs take the cached branch above.
    user_item_data.to_csv(merged_csv, index=False)


user_item_data

## 3. Create and save user features
We do this separately for the train, validation, and test splits.

In [None]:
users = user_item_data['uid'].unique()

# Split on user ids, so every user ends up in exactly one of train/val/test.
__train_val_set, test_set = train_test_split(
    users,
    test_size=0.10,    # 10 % of all users are test data
    random_state=42,   # reproducible shuffling
    shuffle=True
)

train_set, val_set = train_test_split(
    __train_val_set,
    test_size=0.22,    # 0.22 of the remaining 90 % ~= 20 % of all users for validation
    random_state=42,
    shuffle=True
)


print("train:", len(train_set), "users")
print("val  :", len(val_set), "users")
print("test :", len(test_set), "users")


# It is HIGHLY recommended to use more than 1 thread for the training set.
# The training users are divided as evenly as possible over `num_threads` chunks.
num_threads = 1
k, m = divmod(len(train_set), num_threads)
train_split = [train_set[i*k + min(i, m) : (i+1)*k + min(i+1, m)] for i in range(num_threads)]

processed_dir = Path("dataset") / "processed"


def _feature_thread(user_ids, split_name):
    """Create (without starting) a worker that extracts and saves features for `user_ids`."""
    return Thread(
        target=uf.extract_and_save_features,
        args=(user_ids, user_item_data, split_name, processed_dir / split_name,
              likes, user_item_data, unlikes, dislikes, undislikes),
    )


# Fixed: one worker per training chunk. Previously only train_split[0] was processed,
# so raising num_threads silently dropped every other chunk of training users.
# NOTE(review): with num_threads > 1 several workers write into the same 'train'
# directory; this assumes extract_and_save_features uses per-user file names — confirm.
threads = [_feature_thread(chunk, 'train') for chunk in train_split]
threads.append(_feature_thread(val_set, 'val'))
threads.append(_feature_thread(test_set, 'test'))

for worker in threads:
    worker.start()

# Block until every split has been written.
for worker in threads:
    worker.join()

In [None]:
# The raw frames are no longer needed past this point; unbind them so their memory can be reclaimed.
del user_item_data, likes, dislikes, unlikes, undislikes

### 3.1: Merge the separate user files

In [None]:
# Entries read from every per-user file, in the order they are written to the merged file.
merged_keys = (
    "user_feats",
    "label_specific_feats",
    "user_ids",
    "song_embeds",
    "labels",
    "interactions",
)


def _as_tensor(value):
    """Return `value` as one tensor, concatenating along dim 0 if it is a list/tuple of tensors."""
    if isinstance(value, (list, tuple)):
        return torch.cat(value, dim=0)
    return value


for data_type in ['train', 'val', 'test']:
    combined_file = Path("dataset") / "processed" / f'{data_type}.pt'
    if combined_file.exists():
        continue  # already merged by a previous run

    parts_dir = Path("dataset") / "processed" / data_type
    # Sorted so the row order of the merged tensors does not depend on directory listing order.
    part_files = sorted(parts_dir.glob("*.pt"))
    if not part_files:
        # torch.cat on an empty list fails with an opaque message; fail early and clearly instead.
        raise FileNotFoundError(f"no .pt files found in {parts_dir}; run the feature extraction cell first")

    collected = {key: [] for key in merged_keys}
    for file in progress_bar(part_files, desc=f"loading and merging {data_type} files"):
        loaded = torch.load(file, map_location="cpu")
        for key in merged_keys:
            collected[key].append(_as_tensor(loaded[key]))

    # Concatenate each entry across all files along dim 0 and store everything in one file.
    torch.save(
        {key: torch.cat(parts, dim=0) for key, parts in collected.items()},
        combined_file,
    )

In [None]:
# Sanity check: print the shape of each tensor in the first training batch (batch_size=1).
train_set = df.load_tensor_dataloader("train", Path("dataset") / "processed", batch_size=1)
for batch in train_set:
    user_feats, label_specific_feats, song_embeds, labels, interactions = batch
    for tensor in (user_feats, label_specific_feats, song_embeds, labels, interactions):
        print(tensor.shape)
    break  # one batch is enough

## 4. Training the model

### 4.1 set the parameters

In [1]:
import utils.dataset_functions as df
import utils.user_features as uf
import utils.two_towers as ttn
import pandas as pd
import ast
import torch
import numpy as np
from threading import Thread
from pathlib import Path
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm as progress_bar


# Same folder setup as at the top of the notebook, so this section also runs on a fresh kernel.
data_dir = Path("dataset") / "unprocessed"
for folder in (Path("dataset") / "processed", data_dir):
    folder.mkdir(parents=True, exist_ok=True)
# --- Model input sizes ---
# user_dim counts the user features only; the "+ 129" label-specific features are switched off.
user_dim = 5
item_dim = 128
aug_dim = 32

# --- Sizes of the tower feed-forward networks ---
hidden_dim = 64
embed_dim = 32

# --- Loss weights: lambda1 scales loss_u, lambda2 scales loss_v ---
lambda1 = 1
lambda2 = 1

# --- Training settings ---
num_epochs = 50
learning_rate = 1e-3
batch_size = 128
patience = 5

# Use the GPU whenever CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print("Training models on:", device)

Training models on: cuda


### 4.2 Training our models

In [2]:
# Train/validation loaders read from dataset/processed; the trailing 0 is the argument
# that the training loop passes as `label_id`.
processed_dir = Path("dataset") / "processed"
train_set = df.load_tensor_dataloader("train", processed_dir, batch_size, 0)
val_set = df.load_tensor_dataloader("val", processed_dir, batch_size, 0)


In [None]:
models = ["interactions_model", "multiple_listens_model", "pct_100_model", "pct_80_model"]
processed_dir = Path("dataset") / "processed"

# Train one two-tower model per entry of models[1:], each on the loaders for its own label_id.
# NOTE(review): enumerate() restarts at 0 on the sliced list, so "multiple_listens_model"
# is trained with label_id 0 and "interactions_model" is skipped — confirm this pairing
# matches the label layout produced by the feature extraction.
for label_id, model_name in enumerate(models[1:]):
    train_set, val_set = [
        df.load_tensor_dataloader(split, processed_dir, batch_size, label_id)
        for split in ("train", "val")
    ]

    model = ttn.DualAugmentedTwoTower(model_name, user_dim, item_dim, hidden_dim, aug_dim)
    optimiser = torch.optim.Adam(model.parameters(), lr=learning_rate)
    ttn.train_model(
        model,
        train_set,
        val_set,
        optimiser,
        patience,
        num_epochs=num_epochs,
        device=device,
    )



Epoch: [9/50]
early stopped


In [6]:
test_set = df.load_tensor_dataloader("test", Path("dataset")/"processed", batch_size, 0)

# Scores at or above this value are counted as a positive prediction.
decision_threshold = 0.95

# NOTE(review): `model` is whatever the training cell left bound (the last model of its
# loop), while the loader above is built with label id 0 — confirm they belong together.
model.eval()
total_correct = 0
total_samples = 0
total_loss = 0

with torch.no_grad():
    for user_features, song_embedding, labels, interactions in test_set:
        # Move to device
        user_features = user_features.to(device)
        song_embedding = song_embedding.to(device)
        labels = labels.to(device)

        # Forward pass
        scores, pu, pv = model(user_features, song_embedding)

        # Loss, weighted by batch size so that the final value is a per-sample average.
        # NOTE(review): this presumes model.loss returns a batch mean — confirm.
        loss = model.loss(scores, pu, pv, labels, lambda1, lambda2)
        total_loss += loss.item() * labels.size(0)

        # ----- ACCURACY -----
        # Binary predictions from the thresholded scores.
        preds = (scores >= decision_threshold).long()
        # Flatten both sides before comparing: squeeze() alone drops every singleton
        # dimension, so a (B,) vs (B, 1) pair would broadcast to a (B, B) comparison
        # and inflate the count.
        correct = (preds.reshape(-1) == labels.long().reshape(-1)).sum().item()

        total_correct += correct
        total_samples += labels.size(0)

if total_samples == 0:
    # Avoid an uninformative ZeroDivisionError below.
    raise ValueError("the test loader produced no samples")

# Final metrics
avg_loss = total_loss / total_samples
accuracy = total_correct / total_samples

print(f"Test Loss: {avg_loss:.4f}")
print(f"Test Accuracy: {accuracy:.4f}")


Test Loss: 0.3037
Test Accuracy: 0.8443
