In [1]:
import ast
import json
import os
from collections import Counter, defaultdict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from datasets import load_dataset
from scipy.stats import norm, multivariate_normal
from tqdm import tqdm

# Seed every RNG the notebook draws from. Weight initialisation, dropout
# masks, the reparameterisation noise and DataLoader shuffling all come from
# PyTorch, so seeding NumPy alone does not make training reproducible.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

# Prefer the GPU when one is visible; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


## Define Tag Categories

Load the tag taxonomy (category → list of tags) from `../data/concepts_to_tags.json` and build a reverse lookup from each tag to its category.

In [2]:
# Load the category -> tags taxonomy. The context manager closes the file
# handle deterministically instead of leaving it to the garbage collector.
with open("../data/concepts_to_tags.json", "r", encoding="utf-8") as taxonomy_file:
    TAXONOMY = json.load(taxonomy_file)

CATEGORIES = list(TAXONOMY.keys())

# Reverse map for easy lookup (tag -> category).
# If a tag is listed under several categories, the last category wins.
TAG_TO_CATEGORY = {}
for cat, tags in TAXONOMY.items():
    for tag in tags:
        TAG_TO_CATEGORY[tag] = cat


In [3]:
# Assign every tag a column in the multi-hot vector. Tags of one category
# occupy a contiguous block of columns, laid out in CATEGORIES order.
tag_to_idx = {}   # tag -> column index
idx_to_tag = {}   # column index -> (category, tag)
cat_ranges = {}   # category -> (first column, one past the last column)

current_idx = 0
for cat in CATEGORIES:
    members = TAXONOMY[cat]
    block_end = current_idx + len(members)
    for position, tag in zip(range(current_idx, block_end), members):
        tag_to_idx[tag] = position
        idx_to_tag[position] = (cat, tag)
    cat_ranges[cat] = (current_idx, block_end)
    current_idx = block_end

TOTAL_INPUT_DIM = current_idx
print(f"Total Input Dimension: {TOTAL_INPUT_DIM}")

Total Input Dimension: 200


## Prepare dataset

In [4]:
def process_data_multilabel(df: pd.DataFrame, tag_index=None, input_dim=None) -> np.ndarray:
    """
    Creates a Multi-Hot vector for every song.
    Example: [0, 1, 0, 1, 1, ...] where 1 means the tag is present.

    Args:
        df: Frame with an 'aspect_list' column whose values are string-encoded
            Python list literals of tag strings (parsed with ast.literal_eval).
        tag_index: Optional mapping from lower-case tag to vector position.
            Defaults to the notebook-level `tag_to_idx`.
        input_dim: Optional length of each output vector. Defaults to the
            notebook-level `TOTAL_INPUT_DIM`.

    Returns:
        float32 array of shape (n_kept, input_dim). Rows with no tag present
        in `tag_index` are dropped, so n_kept may be smaller than len(df).
        An input with no matching rows yields an empty (0, input_dim) array.
    """
    if tag_index is None:
        tag_index = tag_to_idx
    if input_dim is None:
        input_dim = TOTAL_INPUT_DIM

    processed_data = []

    # Only one column is needed, so iterate it directly rather than paying
    # for a full Series per row with iterrows().
    for raw in df['aspect_list']:
        raw_tags = [t.lower() for t in ast.literal_eval(raw)]

        # Positions of the tags that belong to the taxonomy.
        hits = [tag_index[tag] for tag in raw_tags if tag in tag_index]

        # Only keep records that have at least one valid tag
        if not hits:
            continue

        vector = np.zeros(input_dim, dtype=np.float32)
        vector[hits] = 1.0
        processed_data.append(vector)

    # The explicit dtype and reshape keep the result 2-D float32 even when
    # nothing matched (np.array([]) alone would be a float64 array of shape (0,)).
    return np.array(processed_data, dtype=np.float32).reshape(-1, input_dim)

In [5]:
# Fetch the MusicCaps "train" split and convert it to a DataFrame.
musiccaps_train = load_dataset("google/MusicCaps", split="train")
df = musiccaps_train.to_pandas()

In [6]:
# Encode each row as a multi-hot tag vector; rows without any known tag are dropped.
data = process_data_multilabel(df)
print("Processed data shape: {}".format(data.shape))

Processed data shape: (5046, 200)


## VAE

In [8]:
class MultiLabelVAE(nn.Module):
    """Denoising variational autoencoder with a single hidden layer per side.

    The encoder maps an input vector to the mean and log-variance of a
    diagonal Gaussian over the latent space; the decoder maps a latent sample
    back to per-feature probabilities through a sigmoid.

    Args:
        input_dim: Width of the input (and reconstructed) vector.
        latent_dim: Width of the latent code.
        hidden_dim: Width of the hidden layer in both encoder and decoder.
        dropout_p: Probability of zeroing each input feature in `forward`
            while the module is in training mode.
    """

    def __init__(self, input_dim, latent_dim=32, hidden_dim=128, dropout_p=0.3):
        super().__init__()

        # Encoder
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc2_logvar = nn.Linear(hidden_dim, latent_dim)

        # Decoder
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)

        # Dropout for the "Denoising" part (applied to input).
        # Active only in train() mode; eval() turns it into the identity.
        self.input_dropout = nn.Dropout(p=dropout_p)

    def encode(self, x):
        """Return (mu, logvar) of the approximate posterior for input `x`."""
        h1 = F.relu(self.fc1(x))
        return self.fc2_mu(h1), self.fc2_logvar(h1)

    def reparameterize(self, mu, logvar):
        """Draw z = mu + eps * std with eps ~ N(0, I), keeping gradients to mu/logvar."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, temperature=1.0):
        """Map latent codes `z` to per-feature probabilities in [0, 1].

        The logits are divided by `temperature` before the sigmoid, so values
        above 1 flatten the probabilities towards 0.5 and values below 1
        sharpen them.

        Raises:
            ValueError: If `temperature` is a plain number that is not > 0.
        """
        # A zero temperature divides by zero and a negative one flips every
        # probability; reject both instead of silently returning garbage.
        # Tensor-valued temperatures are passed through unchecked.
        if isinstance(temperature, (int, float)) and temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")
        h3 = F.relu(self.fc3(z))
        logits = self.fc4(h3)
        return torch.sigmoid(logits / temperature)

    def forward(self, x, temperature=1.0):
        """Run corrupt -> encode -> sample -> decode.

        Returns:
            (recon, mu, logvar): reconstruction probabilities with the shape of
            `x`, plus the posterior mean and log-variance.
        """
        # Apply dropout to inputs during training -> forces model to learn correlations
        x_noisy = self.input_dropout(x)
        mu, logvar = self.encode(x_noisy)
        z = self.reparameterize(mu, logvar)
        recon = self.decode(z, temperature=temperature)
        return recon, mu, logvar

In [9]:
# Hyperparameter optimization setup
# The model input is the full multi-hot tag vector.
input_dim = TOTAL_INPUT_DIM

# Candidate configurations keyed by model name. Only the layer widths differ;
# learning rate and epoch count are the same for all three.
hyperparameter_space = {
    name: {
        'latent_dim': latent,
        'hidden_dim': hidden,
        'learning_rate': 5e-3,
        'epochs': 300,
    }
    for name, latent, hidden in (
        ('vae_small', 32, 128),
        ('vae_medium', 64, 256),
        ('vae_large', 128, 512),
    )
}

# Mini-batch size used by the training cells.
batch_size = 32

In [10]:
def vae_loss(recon_x, x, mu, logvar):
    """Return the summed VAE objective: reconstruction BCE plus KL divergence.

    `recon_x` holds probabilities (not logits) and `x` the targets. The KL term
    is the closed form between N(mu, exp(logvar)) and a standard normal.
    Both terms are summed over all elements, so the result is not averaged
    over the batch.
    """
    reconstruction = F.binary_cross_entropy(recon_x, x, reduction='sum')
    kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return reconstruction + kl_divergence

In [None]:
# Train one VAE per entry in hyperparameter_space and save its weights.

# torch.save does not create missing parent directories. Create the output
# folder before any training starts so a missing folder cannot discard a
# finished run at save time.
os.makedirs("../models", exist_ok=True)

# The training tensor is the same for every configuration, so build it once.
data_tensor = torch.tensor(data).to(device)
dataset = torch.utils.data.TensorDataset(data_tensor)

for model_name, params in hyperparameter_space.items():
    print(f"Training model: {model_name} with params: {params}")

    latent_dim = params['latent_dim']
    hidden_dim = params['hidden_dim']
    learning_rate = params['learning_rate']
    num_epochs = params['epochs']

    model = MultiLabelVAE(input_dim, latent_dim, hidden_dim).to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

    model.train()
    for epoch in range(num_epochs):
        total_loss = 0
        for batch in dataloader:
            inputs = batch[0]  # TensorDataset yields 1-tuples
            optimizer.zero_grad()
            recon_batch, mu, logvar = model(inputs)
            loss = vae_loss(recon_batch, inputs, mu, logvar)
            loss.backward()
            total_loss += loss.item()
            optimizer.step()

        # vae_loss sums over the batch, so dividing by the number of samples
        # gives the mean per-sample loss for the epoch.
        avg_loss = total_loss / len(data_tensor)
        if (epoch + 1) % 50 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")

    # Save the trained model
    torch.save(model.state_dict(), f"../models/{model_name}.pth")
    print(f"Model {model_name} saved.\n")

Training model: vae_small with params: {'latent_dim': 32, 'hidden_dim': 128, 'learning_rate': 0.005, 'epochs': 300}
Epoch [50/300], Loss: 14.7081
Epoch [50/300], Loss: 14.7081
Epoch [100/300], Loss: 14.5597
Epoch [100/300], Loss: 14.5597
Epoch [150/300], Loss: 14.5648
Epoch [150/300], Loss: 14.5648
Epoch [200/300], Loss: 14.4779
Epoch [200/300], Loss: 14.4779
Epoch [250/300], Loss: 14.5191
Epoch [250/300], Loss: 14.5191
Epoch [300/300], Loss: 14.4475
Model vae_small saved.

Training model: vae_medium with params: {'latent_dim': 64, 'hidden_dim': 256, 'learning_rate': 0.005, 'epochs': 300}
Epoch [300/300], Loss: 14.4475
Model vae_small saved.

Training model: vae_medium with params: {'latent_dim': 64, 'hidden_dim': 256, 'learning_rate': 0.005, 'epochs': 300}
Epoch [50/300], Loss: 14.9363
Epoch [50/300], Loss: 14.9363
Epoch [100/300], Loss: 14.7020
Epoch [100/300], Loss: 14.7020
Epoch [150/300], Loss: 14.6514
Epoch [150/300], Loss: 14.6514
Epoch [200/300], Loss: 14.6524
Epoch [200/300], 

In [11]:
# Final run: the small architecture trained with a lower learning rate than
# the sweep (5e-4 instead of 5e-3).
latent_dim = 32
hidden_dim = 128
learning_rate = 5e-4
num_epochs = 250
model_name = "vae_final"

# torch.save does not create missing parent directories. Create the output
# folder before training so a missing folder cannot discard the run at save time.
os.makedirs("../models", exist_ok=True)

model = MultiLabelVAE(input_dim, latent_dim, hidden_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

data_tensor = torch.tensor(data).to(device)
dataset = torch.utils.data.TensorDataset(data_tensor)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

model.train()
for epoch in range(num_epochs):
    total_loss = 0
    for batch in dataloader:
        inputs = batch[0]  # TensorDataset yields 1-tuples
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(inputs)
        loss = vae_loss(recon_batch, inputs, mu, logvar)
        loss.backward()
        total_loss += loss.item()
        optimizer.step()

    # vae_loss sums over the batch, so dividing by the number of samples
    # gives the mean per-sample loss for the epoch.
    avg_loss = total_loss / len(data_tensor)
    if (epoch + 1) % 50 == 0:
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")

# Save the trained model
torch.save(model.state_dict(), f"../models/{model_name}.pth")
print(f"Model {model_name} saved.\n")

Epoch [50/250], Loss: 15.1943
Epoch [100/250], Loss: 14.5649
Epoch [150/250], Loss: 14.3618
Epoch [200/250], Loss: 14.2517
Epoch [250/250], Loss: 14.1880
Model vae_final saved.

