In [1]:
import ast
import json
from pathlib import Path

import lightning as pl
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from datasets import load_dataset

# Seed NumPy directly and let Lightning seed the remaining RNGs it manages,
# so repeated runs of the notebook start from the same random state.
seed = 42
np.random.seed(seed)
pl.seed_everything(seed)

# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

  from pkg_resources import DistributionNotFound, get_distribution
Seed set to 42


Using device: cuda


## Define Tag Categories

Load the tag taxonomy (category → list of tags) from `../data/concepts_to_tags.json` and build a reverse tag → category lookup.

In [2]:
# Load the category -> tags mapping. The context manager closes the file handle
# deterministically (a bare open() inside json.load() leaves that to the garbage
# collector), and the explicit encoding avoids depending on the platform locale.
with open("../data/concepts_to_tags.json", "r", encoding="utf-8") as taxonomy_file:
    TAXONOMY = json.load(taxonomy_file)

# Category names, in the order they appear in the JSON file.
CATEGORIES = list(TAXONOMY.keys())

# Reverse map for easy lookup (tag -> category).
# If a tag is listed under several categories, the last category wins.
TAG_TO_CATEGORY = {tag: cat for cat, tags in TAXONOMY.items() for tag in tags}


In [3]:
# Lay every tag out on one flat axis, category by category, so that each
# category owns a contiguous slice of the multi-hot vector.
tag_to_idx = {}   # tag -> position on the flat axis
idx_to_tag = {}   # position -> (category, tag)
cat_ranges = {}   # category -> (first position, one past the last position)

current_idx = 0
for cat in CATEGORIES:
    start = current_idx
    for offset, tag in enumerate(TAXONOMY[cat]):
        position = start + offset
        tag_to_idx[tag] = position
        idx_to_tag[position] = (cat, tag)
    # Advance past this category's slice (unchanged when the category is empty).
    current_idx = start + len(TAXONOMY[cat])
    cat_ranges[cat] = (start, current_idx)

TOTAL_INPUT_DIM = current_idx
print(f"Total Input Dimension: {TOTAL_INPUT_DIM}")

Total Input Dimension: 400


## Prepare dataset

In [10]:
def process_data_multilabel(df: pd.DataFrame) -> np.ndarray:
    """
    Creates a Multi-Hot vector for every song.
    Example: [0, 1, 0, 1, 1, ...] where 1 means the tag is present.

    Args:
        df: Frame with an 'aspect_list' column. Each cell is either the string
            form of a Python list of tags (parsed with ``ast.literal_eval``) or
            an already-parsed sequence of tag strings.

    Returns:
        float32 array of shape (n_kept, TOTAL_INPUT_DIM). Rows whose tags all
        fall outside ``tag_to_idx`` are dropped, so n_kept <= len(df). When no
        row survives the result is still 2-D, with shape (0, TOTAL_INPUT_DIM).
    """
    processed_data = []

    # Only one column is needed, so iterate it directly instead of iterrows().
    for raw_value in df['aspect_list']:
        # Parse only when the cell is still a string; literal_eval rejects
        # anything that is already a list/array.
        raw_tags = ast.literal_eval(raw_value) if isinstance(raw_value, str) else raw_value

        # Tags are lower-cased before lookup, then reduced to known positions.
        lowered = (t.lower() for t in raw_tags)
        hit_positions = [tag_to_idx[t] for t in lowered if t in tag_to_idx]

        # Only keep records that have at least one valid tag
        if not hit_positions:
            continue

        vector = np.zeros(TOTAL_INPUT_DIM, dtype=np.float32)
        vector[hit_positions] = 1.0
        processed_data.append(vector)

    if not processed_data:
        # np.array([]) would be 1-D float64; keep the documented shape and dtype.
        return np.zeros((0, TOTAL_INPUT_DIM), dtype=np.float32)

    return np.stack(processed_data)

In [8]:
# Fetch the "train" split of google/MusicCaps and convert it to a pandas frame.
df = (
    load_dataset("google/MusicCaps", split="train")
    .to_pandas()
)

In [11]:
# Encode every row of the frame as a multi-hot tag vector; rows without any
# tag known to tag_to_idx are dropped by process_data_multilabel.
data = process_data_multilabel(df)
# Sanity check: the second dimension should equal TOTAL_INPUT_DIM.
print(f"Processed data shape: {data.shape}")

Processed data shape: (5163, 400)


## VAE

In [12]:
class MultiLabelVAE(nn.Module):
    """Denoising variational autoencoder over multi-hot tag vectors.

    The encoder maps an ``input_dim``-wide vector to the mean and log-variance
    of a diagonal Gaussian in ``latent_dim`` dimensions; the decoder maps a
    latent sample back to ``input_dim`` independent probabilities in [0, 1].

    Args:
        input_dim: Width of the multi-hot input / reconstructed output.
        latent_dim: Size of the latent code.
        hidden_dim: Width of the single hidden layer on each side.
        dropout_p: Probability of zeroing each input element in ``forward``
            while the module is in training mode. Defaults to 0.3.
    """

    def __init__(self, input_dim, latent_dim=32, hidden_dim=128, dropout_p=0.3):
        super().__init__()

        # Encoder
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc2_logvar = nn.Linear(hidden_dim, latent_dim)

        # Decoder
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)

        # Dropout for the "Denoising" part (applied to input).
        # It holds no parameters, so checkpoints are unaffected by dropout_p.
        self.input_dropout = nn.Dropout(p=dropout_p)

    def encode(self, x):
        """Return ``(mu, logvar)`` of the approximate posterior for ``x``."""
        h1 = F.relu(self.fc1(x))
        return self.fc2_mu(h1), self.fc2_logvar(h1)

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        Sampling happens on every call, regardless of train/eval mode.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes to per-tag probabilities (sigmoid outputs)."""
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        """Return ``(recon, mu, logvar)`` for a batch of multi-hot vectors."""
        # Apply dropout to inputs during training -> forces model to learn correlations
        x_noisy = self.input_dropout(x)
        mu, logvar = self.encode(x_noisy)
        z = self.reparameterize(mu, logvar)
        recon = self.decode(z)
        return recon, mu, logvar

In [20]:
# Hyperparameters

# Network shape
input_dim = TOTAL_INPUT_DIM
hidden_dim = 128
latent_dim = 32

# Optimisation schedule
learning_rate = 1e-3
batch_size = 64
num_epochs = 300

In [21]:
# Model, Optimizer, Loss Function
# The network is moved to the selected device before the optimizer captures its parameters.
model = MultiLabelVAE(
    input_dim=input_dim,
    latent_dim=latent_dim,
    hidden_dim=hidden_dim,
).to(device)
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)
# Summed (not averaged) binary cross-entropy over every element of the batch.
bce_loss_fn = nn.BCELoss(reduction='sum')

In [22]:
def vae_loss(recon_x, x, mu, logvar, beta=1.0):
    """Summed VAE objective: reconstruction BCE plus (weighted) KL divergence.

    Args:
        recon_x: Reconstructed probabilities in [0, 1], same shape as ``x``.
        x: Target multi-hot tensor.
        mu: Posterior means.
        logvar: Posterior log-variances.
        beta: Multiplier on the KL term. The default of 1.0 gives the plain
            ``BCE + KLD`` sum.

    Returns:
        Scalar tensor; both terms are summed over every element of the batch
        (no averaging), so the value scales with batch size.
    """
    # Functional form keeps the loss self-contained instead of relying on a
    # module-level loss object created in a different cell.
    BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
    # Closed-form KL(N(mu, sigma^2) || N(0, I)) summed over batch and latent dims.
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + beta * KLD

In [23]:
# Prepare DataLoader
# The whole multi-hot matrix is wrapped as a single-tensor dataset, so each
# batch arrives as a 1-tuple whose only element is the input tensor.
dataset = torch.utils.data.TensorDataset(torch.tensor(data))
dataloader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
)

In [24]:
# Resolve the checkpoint location up front and make sure its directory exists:
# a missing ../models folder should fail (or be created) now, not after the
# entire training run has finished.
checkpoint_path = Path("../models/multilabel_vae.pth")
checkpoint_path.parent.mkdir(parents=True, exist_ok=True)

model.train()
for epoch in range(num_epochs):
    total_loss = 0.0
    for batch in dataloader:
        # TensorDataset yields 1-tuples; element 0 is the multi-hot batch.
        inputs = batch[0].to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(inputs)
        loss = vae_loss(recon_batch, inputs, mu, logvar)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # The loss is summed per batch, so dividing by the dataset size gives a
    # per-sample average for the epoch.
    avg_loss = total_loss / len(dataloader.dataset)
    print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}")

# Save the trained model
torch.save(model.state_dict(), checkpoint_path)

Epoch 1, Loss: 99.8781
Epoch 2, Loss: 29.4878
Epoch 3, Loss: 27.2239
Epoch 4, Loss: 26.2956
Epoch 5, Loss: 25.1655
Epoch 6, Loss: 24.1259
Epoch 7, Loss: 23.2148
Epoch 8, Loss: 22.5671
Epoch 9, Loss: 22.1267
Epoch 10, Loss: 21.9945
Epoch 11, Loss: 21.8269
Epoch 12, Loss: 21.8291
Epoch 13, Loss: 21.7650
Epoch 14, Loss: 21.6979
Epoch 15, Loss: 21.6262
Epoch 16, Loss: 21.5047
Epoch 17, Loss: 21.4424
Epoch 18, Loss: 21.3595
Epoch 19, Loss: 21.2979
Epoch 20, Loss: 21.2361
Epoch 21, Loss: 21.1878
Epoch 22, Loss: 21.0612
Epoch 23, Loss: 21.0080
Epoch 24, Loss: 20.8816
Epoch 25, Loss: 20.8552
Epoch 26, Loss: 20.7075
Epoch 27, Loss: 20.5731
Epoch 28, Loss: 20.5440
Epoch 29, Loss: 20.4309
Epoch 30, Loss: 20.3341
Epoch 31, Loss: 20.2750
Epoch 32, Loss: 20.2374
Epoch 33, Loss: 20.1659
Epoch 34, Loss: 20.1663
Epoch 35, Loss: 20.0613
Epoch 36, Loss: 20.1469
Epoch 37, Loss: 20.0320
Epoch 38, Loss: 20.0033
Epoch 39, Loss: 19.9030
Epoch 40, Loss: 19.8433
Epoch 41, Loss: 19.8725
Epoch 42, Loss: 19.7929
E

## Generate tags

In [25]:
# Rebuild the architecture with the same dimensions used for training, then
# restore the saved weights onto the current device.
model = MultiLabelVAE(input_dim, latent_dim, hidden_dim)
model = model.to(device)
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
state_dict = torch.load("../models/multilabel_vae.pth", map_location=device)
model.load_state_dict(state_dict)
# Switch to inference mode; as the last expression this also displays the module.
model.eval()

MultiLabelVAE(
  (fc1): Linear(in_features=400, out_features=128, bias=True)
  (fc2_mu): Linear(in_features=128, out_features=32, bias=True)
  (fc2_logvar): Linear(in_features=128, out_features=32, bias=True)
  (fc3): Linear(in_features=32, out_features=128, bias=True)
  (fc4): Linear(in_features=128, out_features=400, bias=True)
  (input_dropout): Dropout(p=0.3, inplace=False)
)

In [26]:
def generate_tags(model, seed_tags, requests, sample=True):
    """
    Suggest additional tags per category, conditioned on tags we already have.

    Args:
        model: Trained VAE exposing ``encode``, ``reparameterize`` and ``decode``.
        seed_tags: List of tags we ALREADY have (e.g. ['rock', 'guitar']).
            Tags missing from ``tag_to_idx`` are reported and ignored.
        requests: Dict of how many tags we want per category
            (e.g. {'instrument': 2, 'mood': 1}).
        sample: When True (default) the latent code is drawn with
            ``model.reparameterize``, so repeated calls may differ. When False
            the posterior mean is decoded instead, which is deterministic.

    Returns:
        Dict mapping each requested category to a list of at most ``count``
        tag names, highest probability first, never containing a seed tag.
        The list is shorter than ``count`` when the category does not hold
        enough non-seed tags.

    Raises:
        KeyError: If a requested category is not present in ``cat_ranges``.
    """
    model.eval()
    device = next(model.parameters()).device

    # 1. Build the Input Vector from Seeds
    input_vec = torch.zeros(1, TOTAL_INPUT_DIM, device=device)

    print(f"\n--- Context: {seed_tags} ---")

    # Fill in the knowns
    for tag in seed_tags:
        if tag in tag_to_idx:
            input_vec[0, tag_to_idx[tag]] = 1.0
        else:
            print(f"Warning: Seed tag '{tag}' not in taxonomy.")

    # Set membership keeps the per-candidate seed check O(1).
    seed_set = set(seed_tags)
    results = {}

    with torch.no_grad():
        # 2. Encode to get the latent code (z)
        # Note: encode() is called directly, so the input dropout in forward()
        # is not involved; the model sees every seed we gave it.
        mu, logvar = model.encode(input_vec)
        z = model.reparameterize(mu, logvar) if sample else mu

        # 3. Decode to get probabilities for EVERYTHING
        # Output shape: [1, Total_Dim] (Values 0.0 to 1.0)
        probs = model.decode(z)[0]

        # 4. Extract Top-K for requested categories
        for category, count in requests.items():
            if category not in cat_ranges:
                raise KeyError(
                    f"Unknown category '{category}'; expected one of {sorted(cat_ranges)}"
                )
            start, end = cat_ranges[category]

            # Slice the probabilities relevant to this category
            cat_probs = probs[start:end]

            # Ask for enough candidates to survive removal of every seed tag,
            # but never more than the category holds (topk raises otherwise).
            k = min(end - start, max(count, 0) + len(seed_set))
            _, top_k_indices = torch.topk(cat_probs, k=k)

            # Convert slice-indices back to global-indices, then to strings
            found_tags = []
            for local_idx in top_k_indices.tolist():
                # Check the quota first so a request for 0 tags returns none.
                if len(found_tags) >= count:
                    break
                tag_name = idx_to_tag[start + local_idx][1]

                # Don't return tags we already provided as seeds
                if tag_name not in seed_set:
                    found_tags.append(tag_name)

            results[category] = found_tags
            print(f"Generated {count} {category}(s): {found_tags}")

    return results

In [27]:
# Tags we already know about the track.
my_seeds = ["rock", "guitar", "happy"]
# How many extra suggestions to request from each category.
my_wants = dict(
    instrument=2,  # two MORE instruments
    tempo=1,       # a single tempo
)

generate_tags(model=model, seed_tags=my_seeds, requests=my_wants)


--- Context: ['rock', 'guitar', 'happy'] ---
Generated 2 instrument(s): ['electric guitar', 'flat male vocal']
Generated 1 tempo(s): ['fast tempo']


{'instrument': ['electric guitar', 'flat male vocal'], 'tempo': ['fast tempo']}