In [1]:
import json
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import lightning as pl
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from collections import Counter, defaultdict
from tqdm import tqdm
from datasets import load_dataset
from scipy.stats import norm, multivariate_normal
import ast

seed = 42
np.random.seed(seed)
pl.seed_everything(seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

  from pkg_resources import DistributionNotFound, get_distribution
Seed set to 42


Using device: cuda


## Define Tag Categories

Define all possible tags for each category based on the dataset.

In [2]:
TAXONOMY = json.load(open("../data/concepts_to_tags.json", "r"))

CATEGORIES = list(TAXONOMY.keys())

# Reverse map for easy lookup (tag -> category)
TAG_TO_CATEGORY = {}
for cat, tags in TAXONOMY.items():
    for tag in tags:
        TAG_TO_CATEGORY[tag] = cat


In [3]:
tag_to_idx = {}
idx_to_tag = {}
cat_ranges = {} # Stores start/end index for each category

current_idx = 0
for cat in CATEGORIES:
    start = current_idx
    for tag in TAXONOMY[cat]:
        tag_to_idx[tag] = current_idx
        idx_to_tag[current_idx] = (cat, tag)
        current_idx += 1
    cat_ranges[cat] = (start, current_idx)

TOTAL_INPUT_DIM = current_idx
print(f"Total Input Dimension: {TOTAL_INPUT_DIM}")

Total Input Dimension: 385


## Prepare dataset

In [4]:
def process_data_multilabel(df: pd.DataFrame) -> np.ndarray:
    """
    Creates a Multi-Hot vector for every song.
    Example: [0, 1, 0, 1, 1, ...] where 1 means the tag is present.
    """
    processed_data = []

    for _, row in df.iterrows():
        raw_tags = ast.literal_eval(row['aspect_list'])
        raw_tags = [t.lower() for t in raw_tags]
            
        # Create Zero Vector
        vector = np.zeros(TOTAL_INPUT_DIM, dtype=np.float32)
        has_data = False
        
        for tag in raw_tags:
            if tag in tag_to_idx:
                idx = tag_to_idx[tag]
                vector[idx] = 1.0
                has_data = True
        
        # Only keep records that have at least one valid tag
        if has_data:
            processed_data.append(vector)
            
    return np.array(processed_data)

In [5]:
df = load_dataset("google/MusicCaps", split="train").to_pandas()

In [6]:
data = process_data_multilabel(df)
print(f"Processed data shape: {data.shape}")

Processed data shape: (5161, 385)


## VAE

In [7]:
class MultiLabelVAE(nn.Module):
    def __init__(self, input_dim, latent_dim=32, hidden_dim=128):
        super(MultiLabelVAE, self).__init__()
        
        # Encoder
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc2_logvar = nn.Linear(hidden_dim, latent_dim)
        
        # Decoder
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim) 
        
        # Dropout for the "Denoising" part (applied to input)
        self.input_dropout = nn.Dropout(p=0.3) 

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc2_mu(h1), self.fc2_logvar(h1)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, temperature=1.0):
        h3 = F.relu(self.fc3(z))
        logits = self.fc4(h3)
        return torch.sigmoid(logits / temperature)

    def forward(self, x, temperature=1.0):
        # Apply dropout to inputs during training -> forces model to learn correlations
        x_noisy = self.input_dropout(x)
        mu, logvar = self.encode(x_noisy)
        z = self.reparameterize(mu, logvar)
        recon = self.decode(z, temperature=temperature)
        return recon, mu, logvar

In [8]:
# Hyperparameters
input_dim = TOTAL_INPUT_DIM
latent_dim = 64
hidden_dim = 256
batch_size = 64
num_epochs = 500
learning_rate = 5e-4

In [9]:
# Model, Optimizer, Loss Function
model = MultiLabelVAE(input_dim, latent_dim, hidden_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
bce_loss_fn = nn.BCELoss(reduction='sum')

In [10]:
def vae_loss(recon_x, x, mu, logvar):
    BCE = bce_loss_fn(recon_x, x)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

In [11]:
# Prepare DataLoader
dataset = torch.utils.data.TensorDataset(torch.tensor(data))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [12]:
model.train()
for epoch in tqdm(range(num_epochs), desc="Training VAE"):
    total_loss = 0
    for batch in dataloader:
        inputs = batch[0].to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(inputs)
        loss = vae_loss(recon_batch, inputs, mu, logvar)
        loss.backward()
        total_loss += loss.item()
        optimizer.step()
    avg_loss = total_loss / len(dataloader.dataset)
    if (epoch + 1) % 10 == 0:
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")
# Save the trained model
torch.save(model.state_dict(), "../models/multilabel_vae.pth")

Training VAE:   2%|▏         | 11/500 [00:01<01:12,  6.72it/s]

Epoch [10/500], Loss: 21.7369


Training VAE:   4%|▍         | 21/500 [00:03<01:09,  6.90it/s]

Epoch [20/500], Loss: 21.0426


Training VAE:   6%|▌         | 31/500 [00:04<01:07,  6.91it/s]

Epoch [30/500], Loss: 20.3683


Training VAE:   8%|▊         | 41/500 [00:06<01:06,  6.86it/s]

Epoch [40/500], Loss: 19.8552


Training VAE:  10%|█         | 51/500 [00:07<01:04,  6.98it/s]

Epoch [50/500], Loss: 19.6638


Training VAE:  12%|█▏        | 61/500 [00:09<01:02,  7.01it/s]

Epoch [60/500], Loss: 19.4317


Training VAE:  14%|█▍        | 71/500 [00:10<01:01,  6.95it/s]

Epoch [70/500], Loss: 19.3387


Training VAE:  16%|█▌        | 81/500 [00:11<00:59,  6.99it/s]

Epoch [80/500], Loss: 19.1012


Training VAE:  18%|█▊        | 91/500 [00:13<00:58,  6.96it/s]

Epoch [90/500], Loss: 18.8847


Training VAE:  20%|██        | 101/500 [00:14<00:57,  6.93it/s]

Epoch [100/500], Loss: 18.7474


Training VAE:  22%|██▏       | 111/500 [00:16<00:55,  6.98it/s]

Epoch [110/500], Loss: 18.5129


Training VAE:  24%|██▍       | 121/500 [00:17<00:54,  7.00it/s]

Epoch [120/500], Loss: 18.3633


Training VAE:  26%|██▌       | 131/500 [00:19<00:53,  6.94it/s]

Epoch [130/500], Loss: 18.2089


Training VAE:  28%|██▊       | 141/500 [00:20<00:51,  6.99it/s]

Epoch [140/500], Loss: 18.1076


Training VAE:  30%|███       | 151/500 [00:21<00:49,  7.02it/s]

Epoch [150/500], Loss: 18.0161


Training VAE:  32%|███▏      | 161/500 [00:23<00:48,  7.00it/s]

Epoch [160/500], Loss: 17.9602


Training VAE:  34%|███▍      | 171/500 [00:24<00:47,  6.89it/s]

Epoch [170/500], Loss: 17.8939


Training VAE:  36%|███▌      | 181/500 [00:26<00:46,  6.90it/s]

Epoch [180/500], Loss: 17.8195


Training VAE:  38%|███▊      | 191/500 [00:27<00:44,  6.98it/s]

Epoch [190/500], Loss: 17.7991


Training VAE:  40%|████      | 201/500 [00:29<00:42,  6.96it/s]

Epoch [200/500], Loss: 17.7380


Training VAE:  42%|████▏     | 211/500 [00:30<00:42,  6.80it/s]

Epoch [210/500], Loss: 17.6187


Training VAE:  44%|████▍     | 221/500 [00:32<00:40,  6.96it/s]

Epoch [220/500], Loss: 17.6065


Training VAE:  46%|████▌     | 231/500 [00:33<00:38,  7.06it/s]

Epoch [230/500], Loss: 17.5537


Training VAE:  48%|████▊     | 241/500 [00:34<00:36,  7.06it/s]

Epoch [240/500], Loss: 17.4980


Training VAE:  50%|█████     | 251/500 [00:36<00:36,  6.89it/s]

Epoch [250/500], Loss: 17.5589


Training VAE:  52%|█████▏    | 261/500 [00:37<00:33,  7.18it/s]

Epoch [260/500], Loss: 17.4694


Training VAE:  54%|█████▍    | 271/500 [00:39<00:32,  7.00it/s]

Epoch [270/500], Loss: 17.4238


Training VAE:  56%|█████▌    | 281/500 [00:40<00:31,  6.95it/s]

Epoch [280/500], Loss: 17.4454


Training VAE:  58%|█████▊    | 291/500 [00:42<00:29,  7.01it/s]

Epoch [290/500], Loss: 17.4125


Training VAE:  60%|██████    | 301/500 [00:43<00:27,  7.17it/s]

Epoch [300/500], Loss: 17.4223


Training VAE:  62%|██████▏   | 311/500 [00:44<00:27,  6.98it/s]

Epoch [310/500], Loss: 17.3028


Training VAE:  64%|██████▍   | 321/500 [00:46<00:25,  7.03it/s]

Epoch [320/500], Loss: 17.3377


Training VAE:  66%|██████▌   | 331/500 [00:47<00:24,  7.01it/s]

Epoch [330/500], Loss: 17.2577


Training VAE:  68%|██████▊   | 341/500 [00:49<00:23,  6.91it/s]

Epoch [340/500], Loss: 17.2829


Training VAE:  70%|███████   | 351/500 [00:50<00:21,  6.92it/s]

Epoch [350/500], Loss: 17.3658


Training VAE:  72%|███████▏  | 361/500 [00:52<00:19,  7.01it/s]

Epoch [360/500], Loss: 17.3475


Training VAE:  74%|███████▍  | 371/500 [00:53<00:18,  7.00it/s]

Epoch [370/500], Loss: 17.2063


Training VAE:  76%|███████▌  | 381/500 [00:54<00:16,  7.02it/s]

Epoch [380/500], Loss: 17.1884


Training VAE:  78%|███████▊  | 391/500 [00:56<00:15,  7.02it/s]

Epoch [390/500], Loss: 17.2610


Training VAE:  80%|████████  | 401/500 [00:57<00:14,  7.00it/s]

Epoch [400/500], Loss: 17.2695


Training VAE:  82%|████████▏ | 411/500 [00:59<00:12,  6.92it/s]

Epoch [410/500], Loss: 17.1772


Training VAE:  84%|████████▍ | 421/500 [01:00<00:11,  7.02it/s]

Epoch [420/500], Loss: 17.2150


Training VAE:  86%|████████▌ | 431/500 [01:02<00:09,  7.00it/s]

Epoch [430/500], Loss: 17.1695


Training VAE:  88%|████████▊ | 441/500 [01:03<00:08,  6.93it/s]

Epoch [440/500], Loss: 17.1379


Training VAE:  90%|█████████ | 451/500 [01:04<00:07,  6.91it/s]

Epoch [450/500], Loss: 17.0366


Training VAE:  92%|█████████▏| 461/500 [01:06<00:05,  6.91it/s]

Epoch [460/500], Loss: 17.0595


Training VAE:  94%|█████████▍| 471/500 [01:07<00:04,  6.96it/s]

Epoch [470/500], Loss: 17.0253


Training VAE:  96%|█████████▌| 481/500 [01:09<00:02,  6.97it/s]

Epoch [480/500], Loss: 17.1156


Training VAE:  98%|█████████▊| 491/500 [01:10<00:01,  6.95it/s]

Epoch [490/500], Loss: 17.1447


Training VAE: 100%|██████████| 500/500 [01:12<00:00,  6.94it/s]

Epoch [500/500], Loss: 17.1305





## Generate tags

In [13]:
df = pd.read_csv("../data/mtg_jamendo/autotagging_top50tags_processed_cleaned.csv")
df['aspect_list'] = df['aspect_list'].apply(ast.literal_eval)
df['instrument_tags'] = df['instrument_tags'].apply(ast.literal_eval)
df['genre_tags'] = df['genre_tags'].apply(ast.literal_eval)
df['mood_tags'] = df['mood_tags'].apply(ast.literal_eval)
df

Unnamed: 0,id,tags,genre_tags,mood_tags,instrument_tags,aspect_list
0,track_0007391,"['genre---electronic', 'genre---pop', 'instrum...","[electronic, pop]",[emotional],"[bass, drums, guitar, keyboard]","[drums, bass, guitar, electronic, emotional, p..."
1,track_0015161,"['genre---instrumentalpop', 'genre---pop', 'ge...","[pop, rock]",[emotional],"[bass, drums]","[drums, bass, rock, emotional, pop]"
2,track_0015166,"['genre---dance', 'genre---electronic', 'genre...","[dance, electronic, pop, techno]",[emotional],[bass],"[bass, electronic, dance, techno, emotional, pop]"
3,track_0015167,"['genre---chillout', 'genre---easylistening', ...","[electronic, pop]",[emotional],"[bass, violin]","[bass, electronic, emotional, pop, violin]"
4,track_0015169,"['genre---electronic', 'genre---instrumentalpo...","[electronic, pop]",[emotional],"[bass, drums]","[drums, bass, electronic, emotional, pop]"
...,...,...,...,...,...,...
2036,track_1420702,"['genre---dance', 'genre---easylistening', 'ge...",[dance],"[funk, happy]","[bass, drums, keyboard]","[drums, bass, dance, funk, keyboard, happy]"
2037,track_1420704,"['genre---dance', 'genre---easylistening', 'in...",[dance],[happy],"[bass, drums, keyboard]","[drums, bass, dance, keyboard, happy]"
2038,track_1420705,"['genre---dance', 'genre---easylistening', 'in...",[dance],[happy],"[bass, drums, keyboard]","[drums, bass, dance, keyboard, happy]"
2039,track_1420706,"['genre---dance', 'genre---easylistening', 'in...",[dance],[happy],"[bass, drums, keyboard]","[drums, bass, dance, keyboard, happy]"


In [14]:
model = MultiLabelVAE(input_dim, latent_dim, hidden_dim).to(device)
model.load_state_dict(torch.load("../models/multilabel_vae.pth", map_location=device))
model.eval()

MultiLabelVAE(
  (fc1): Linear(in_features=385, out_features=256, bias=True)
  (fc2_mu): Linear(in_features=256, out_features=64, bias=True)
  (fc2_logvar): Linear(in_features=256, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=256, bias=True)
  (fc4): Linear(in_features=256, out_features=385, bias=True)
  (input_dropout): Dropout(p=0.3, inplace=False)
)

In [15]:
def generate_tags(model, seed_tags, requests, temperature=1.0):
    """
    seeds: List of tags we ALREADY have (e.g. ['rock', 'guitar'])
    requests: Dict of how many tags we want per category (e.g. {'instrument': 2, 'mood': 1})
    """
    model.eval()
    device = next(model.parameters()).device
    
    # 1. Build the Input Vector from Seeds
    input_vec = torch.zeros(1, TOTAL_INPUT_DIM).to(device)
    
    
    # Fill in the knowns
    for tag in seed_tags:
        if tag in tag_to_idx:
            input_vec[0, tag_to_idx[tag]] = 1.0
        else:
            print(f"Warning: Seed tag '{tag}' not in taxonomy.")

    with torch.no_grad():
        # 2. Encode to get the Latent Vibe (z)
        # Note: We don't use dropout here; we want the model to use all clues we gave it.
        mu, logvar = model.encode(input_vec)
        z = model.reparameterize(mu, logvar)
        
        # 3. Decode to get probabilities for EVERYTHING
        # Output shape: [1, Total_Dim] (Values 0.0 to 1.0)
        probs = model.decode(z, temperature=temperature)[0] 
        
        # 4. Extract Top-K for requested categories
        results = {}
        
        for category, count in requests.items():
            if count <= 0:
                results[f"generated_{category}_tags"] = []
                continue

            start, end = cat_ranges[category]
            
            # Slice the probabilities relevant to this category
            cat_probs = probs[start:end]
            
            # Get Top K indices for this slice
            # We ask for count + len(seeds) just in case the model predicts the seed tag again
            top_k_vals, top_k_indices = torch.topk(cat_probs, k=count + 5)
            
            # Convert slice-indices back to global-indices, then to strings
            found_tags = []
            for i in range(len(top_k_indices)):
                local_idx = top_k_indices[i].item()
                global_idx = start + local_idx
                tag_name = idx_to_tag[global_idx][1]
                
                # Don't return tags we already provided as seeds
                if tag_name not in seed_tags:
                    found_tags.append(tag_name)
                
                if len(found_tags) == count:
                    break
            
            results[f"generated_{category}_tags"] = found_tags
            
    return results

In [16]:
def generate_df(idx: int, tags_per_category: dict[str, int], temperature=1.0):
    row = df.iloc[idx]
    seed_tags = []
    for category in ['genre', 'instrument', 'mood']:
        if len(row[f"{category}_tags"]) > 1:
            seed_tags.append(np.random.choice(row[f"{category}_tags"]))
            tags_per_category[category] = tags_per_category.get(category, 1) - 1
        
    generated_tags = generate_tags(model, seed_tags, tags_per_category)
    _generated_tags = []
    for gtags in generated_tags.values():
        _generated_tags.extend(gtags)

    res_entry = {
        'id': row['id'],
        'original_aspect_list': row['aspect_list'],
        'aspect_list': seed_tags + _generated_tags,
        **generated_tags
    }
    return pd.DataFrame([res_entry])

In [17]:
CATEGORIES = [
    "tempo",
    "genre",
    "mood",
    "instrument"
]
N_CATEGORIES = len(CATEGORIES)
N_SAMPLES_TO_GENERATE = len(df)

# --- 1. SYNTHETIC DATA GENERATION (Replace with your actual data) ---
# We simulate a dataset where tag counts are discrete and correlated.
# Max counts are defined for simulation purposes.
MAX_COUNTS = {
    "tempo": 6,
    "genre": 7, 
    "mood": 8, 
    "instrument": 8, 
}
MEANS = {
    "tempo": 1.21,
    "genre": 1.44, 
    "mood": 1.66, 
    "instrument": 2.49,
}
VARIANCES = {
    "tempo": 0.57,
    "genre": 0.86,
    "mood": 1.06,
    "instrument": 1.37,
}

In [18]:
def generate_synthetic_correlated_data(n_records):
    """
    Creates synthetic discrete count data that serves as the 'real' dataset.
    This step is highly important: it determines the statistics (R and ECDFs)
    that the Copula will try to match.
    """
    print("--- 1. Generating Synthetic Data ---")

    # Define the desired correlation matrix (e.g., high correlation between Genre and Instrument)
    # This represents your calculated correlation matrix R.
    correlation_matrix = np.array([
        [1.0, 0.22, 0.087, 0.05],  # Tempo
        [0.22, 1.0, 0.11, 0.01],  # Genre
        [0.087, 0.11, 1.0, -0.05],  # Mood
        [0.05, 0.01, -0.05, 1.0]   # Instrument
    ])

    # Generate correlated continuous data (Multivariate Normal)
    mean = np.zeros(N_CATEGORIES)
    z_continuous = multivariate_normal.rvs(mean=mean, cov=correlation_matrix, size=n_records)

    data = np.zeros((n_records, N_CATEGORIES), dtype=int)
    
    # Transform continuous data into discrete counts based on desired marginals
    # (using inverse CDF of an arbitrary discrete distribution for simulation)
    # This simulates your real-world data having specific tag count distributions
    for i, cat in enumerate(CATEGORIES):
        max_c = MAX_COUNTS[cat]
        # Simulate log normal-like distribution for counts
        mu = np.log(MEANS[cat]**2 / np.sqrt(MEANS[cat]**2 + VARIANCES[cat]))
        sigma = np.sqrt(np.log(1 + VARIANCES[cat] / MEANS[cat]**2))
        # Create discrete probability distribution
        x = np.arange(1, max_c + 1)
        p = (1 / (x * sigma * np.sqrt(2 * np.pi)))
        p *= np.exp(- (np.log(x) - mu)**2 / (2 * sigma**2))
        p /= p.sum()  # Normalize to sum to 1
        
        # Convert continuous z (uniform quantile) to discrete count (inverse CDF)
        uniform_quantiles = norm.cdf(z_continuous[:, i])
        
        # Quantile mapping for a simple discrete distribution
        counts = np.digitize(uniform_quantiles, np.cumsum(p[:-1])) + 1
        data[:, i] = np.clip(counts, 1, max_c)

    print(f"Synthetic Data Shape: {data.shape}")
    print(f"Calculated Correlation of Synthetic Data:\n{np.corrcoef(data.T).round(2)}")
    return data, correlation_matrix

In [19]:
data, _ = generate_synthetic_correlated_data(N_SAMPLES_TO_GENERATE)
print(data)

--- 1. Generating Synthetic Data ---
Synthetic Data Shape: (2041, 4)
Calculated Correlation of Synthetic Data:
[[ 1.    0.15 -0.    0.06]
 [ 0.15  1.    0.06 -0.01]
 [-0.    0.06  1.   -0.  ]
 [ 0.06 -0.01 -0.    1.  ]]
[[1 1 2 3]
 [1 1 3 3]
 [2 1 1 2]
 ...
 [1 1 4 2]
 [1 1 2 2]
 [1 1 2 4]]


In [20]:
temperatures = [0.5, 0.8, 1.0, 1.25, 1.5]

res_df = pd.DataFrame()

for temp in tqdm(temperatures):
    for idx in tqdm(range(N_SAMPLES_TO_GENERATE), leave=False):
        num_tags_for_category = {
            "tempo": data[idx, 0],
            "genre": data[idx, 1],
            "mood": data[idx, 2],
            "instrument": data[idx, 3],
        }
        temp_df = generate_df(idx, tags_per_category=num_tags_for_category, temperature=temp)
        temp_df['temperature'] = temp
        res_df = pd.concat([res_df, temp_df], ignore_index=True)
res_df

100%|██████████| 5/5 [00:14<00:00,  2.87s/it]


Unnamed: 0,id,original_aspect_list,aspect_list,generated_tempo_tags,generated_genre_tags,generated_mood_tags,generated_instrument_tags,temperature
0,track_0007391,"[drums, bass, guitar, electronic, emotional, p...","[pop, guitar, uptempo, uplifting energy, uplif...",[uptempo],[],"[uplifting energy, uplifting mood]","[e-bass, acoustic drums]",0.5
1,track_0015161,"[drums, bass, rock, emotional, pop]","[rock, drums, uptempo, energetic, fun, uplifti...",[uptempo],[],"[energetic, fun, uplifting]","[flat male vocal, e-guitar]",0.5
2,track_0015166,"[bass, electronic, dance, techno, emotional, pop]","[techno, fast tempo, uptempo, energetic, elect...","[fast tempo, uptempo]",[],[energetic],"[electronic drums, groovy bass]",0.5
3,track_0015167,"[bass, electronic, emotional, pop, violin]","[electronic, bass, slow to medium tempo, calming]",[slow to medium tempo],[],[calming],[],0.5
4,track_0015169,"[drums, bass, electronic, emotional, pop]","[pop, bass, slower tempo, uptempo, medium temp...","[slower tempo, uptempo, medium tempo]",[],[melancholic],[e-bass],0.5
...,...,...,...,...,...,...,...,...
10200,track_1420702,"[drums, bass, dance, funk, keyboard, happy]","[bass, funk, moderate tempo, dance song, male ...",[moderate tempo],[dance song],[],"[male voice, percussion]",1.5
10201,track_1420704,"[drums, bass, dance, keyboard, happy]","[keyboard, medium tempo, pop song, dreamy, emo...",[medium tempo],[pop song],"[dreamy, emotional]",[piano],1.5
10202,track_1420705,"[drums, bass, dance, keyboard, happy]","[keyboard, moderate tempo, blues, playful, hap...",[moderate tempo],[blues],"[playful, happy mood, funk, energetic]",[electric guitar],1.5
10203,track_1420706,"[drums, bass, dance, keyboard, happy]","[keyboard, slow tempo, comedy, dreamy, spiritu...",[slow tempo],[comedy],"[dreamy, spiritual]",[female vocal],1.5


In [26]:
# Sort aspect list column and deduplicate tag combinations
res_df['aspect_list'] = res_df['aspect_list'].apply(lambda x: sorted(list(set(x))))
res_df = res_df.drop_duplicates(subset=['aspect_list']).reset_index(drop=True)
res_df

Unnamed: 0,original_aspect_list,aspect_list,generated_tempo_tags,generated_genre_tags,generated_mood_tags,generated_instrument_tags,temperature,id
0,"[drums, bass, guitar, electronic, emotional, p...","[acoustic drums, e-bass, guitar, pop, upliftin...",[uptempo],[],"[uplifting energy, uplifting mood]","[e-bass, acoustic drums]",0.5,aa379a3c8de830b15b69dddfe2f66298
1,"[drums, bass, rock, emotional, pop]","[drums, e-guitar, energetic, flat male vocal, ...",[uptempo],[],"[energetic, fun, uplifting]","[flat male vocal, e-guitar]",0.5,e0e21d59d73008353f38b0ffa71567b1
2,"[bass, electronic, dance, techno, emotional, pop]","[electronic drums, energetic, fast tempo, groo...","[fast tempo, uptempo]",[],[energetic],"[electronic drums, groovy bass]",0.5,31dec79b9d5cfc73d5ec85563035fae2
3,"[bass, electronic, emotional, pop, violin]","[bass, calming, electronic, slow to medium tempo]",[slow to medium tempo],[],[calming],[],0.5,b3dd4f8feaa730001c841359fee9070e
4,"[drums, bass, electronic, emotional, pop]","[bass, e-bass, medium tempo, melancholic, pop,...","[slower tempo, uptempo, medium tempo]",[],[melancholic],[e-bass],0.5,9a7103effe784c71d063483887cbc5f1
...,...,...,...,...,...,...,...,...
10034,"[drums, bass, dance, funk, keyboard, happy]","[bass, dance song, funk, male voice, moderate ...",[moderate tempo],[dance song],[],"[male voice, percussion]",1.5,005e5ab6eb9dfabc9abce44db0b2fcc6
10035,"[drums, bass, dance, keyboard, happy]","[dreamy, emotional, keyboard, medium tempo, pi...",[medium tempo],[pop song],"[dreamy, emotional]",[piano],1.5,046e4a2e721f47f170241c469117ffee
10036,"[drums, bass, dance, keyboard, happy]","[blues, electric guitar, energetic, funk, happ...",[moderate tempo],[blues],"[playful, happy mood, funk, energetic]",[electric guitar],1.5,f6c27a7ca592434c2aa4f5141e79b6a0
10037,"[drums, bass, dance, keyboard, happy]","[comedy, dreamy, female vocal, keyboard, slow ...",[slow tempo],[comedy],"[dreamy, spiritual]",[female vocal],1.5,632d96ac71bb8f5b7622196154193aec


In [27]:
# Add surrogate key based on track_id and temperature
import hashlib
def generate_surrogate_key(track_id: str, temperature: float) -> str:
    key_str = f"{track_id}_{temperature}"
    return hashlib.md5(key_str.encode()).hexdigest()

res_df['surrogate_key'] = res_df.apply(lambda row: generate_surrogate_key(row['id'], row['temperature']), axis=1)
res_df.drop(columns=['id'], inplace=True)
res_df.rename(columns={'surrogate_key': 'id'}, inplace=True)
res_df

Unnamed: 0,original_aspect_list,aspect_list,generated_tempo_tags,generated_genre_tags,generated_mood_tags,generated_instrument_tags,temperature,id
0,"[drums, bass, guitar, electronic, emotional, p...","[acoustic drums, e-bass, guitar, pop, upliftin...",[uptempo],[],"[uplifting energy, uplifting mood]","[e-bass, acoustic drums]",0.5,bdf8f2bfda071e3eb158a91c0374c904
1,"[drums, bass, rock, emotional, pop]","[drums, e-guitar, energetic, flat male vocal, ...",[uptempo],[],"[energetic, fun, uplifting]","[flat male vocal, e-guitar]",0.5,0d2ccfab6bd435f54ebe93f101fff91f
2,"[bass, electronic, dance, techno, emotional, pop]","[electronic drums, energetic, fast tempo, groo...","[fast tempo, uptempo]",[],[energetic],"[electronic drums, groovy bass]",0.5,97d69e5a5e61500f2aeefa19f228758e
3,"[bass, electronic, emotional, pop, violin]","[bass, calming, electronic, slow to medium tempo]",[slow to medium tempo],[],[calming],[],0.5,64287deef63223e9a2e251e442eab8e6
4,"[drums, bass, electronic, emotional, pop]","[bass, e-bass, medium tempo, melancholic, pop,...","[slower tempo, uptempo, medium tempo]",[],[melancholic],[e-bass],0.5,2d1efa32fc3f2c65b3534579af2da055
...,...,...,...,...,...,...,...,...
10034,"[drums, bass, dance, funk, keyboard, happy]","[bass, dance song, funk, male voice, moderate ...",[moderate tempo],[dance song],[],"[male voice, percussion]",1.5,1d4d33255cec0afbef62a1d1301ac2b5
10035,"[drums, bass, dance, keyboard, happy]","[dreamy, emotional, keyboard, medium tempo, pi...",[medium tempo],[pop song],"[dreamy, emotional]",[piano],1.5,c8b3d9f648e8746bc2444f1d65aa6c15
10036,"[drums, bass, dance, keyboard, happy]","[blues, electric guitar, energetic, funk, happ...",[moderate tempo],[blues],"[playful, happy mood, funk, energetic]",[electric guitar],1.5,8ad2ac66452e90e50e42538874a32dfe
10037,"[drums, bass, dance, keyboard, happy]","[comedy, dreamy, female vocal, keyboard, slow ...",[slow tempo],[comedy],"[dreamy, spiritual]",[female vocal],1.5,223d4c13fdc6a07d0c77b5e61e08e4bc


## Push to Hugginface Hub

In [28]:
from sklearn.model_selection import train_test_split

df_train, df_valid = train_test_split(res_df, test_size=0.1, random_state=42)
df_valid, df_test = train_test_split(df_valid, test_size=0.5, random_state=42)

In [29]:
from pathlib import Path

# Create output directory
output_dir = Path("../data/vae_mtg_tags")
output_dir.mkdir(parents=True, exist_ok=True)

df_train.to_csv(output_dir / "train.csv", index=False)
df_valid.to_csv(output_dir / "validation.csv", index=False)
df_test.to_csv(output_dir / "test.csv", index=False)
all_df = pd.concat([df_train, df_valid, df_test])
all_df.to_csv(output_dir / "all.csv", index=False)

In [30]:
data_files = {
    "train": str(output_dir / "train.csv"),
    "validation": str(output_dir / "validation.csv"),
    "test": str(output_dir / "test.csv")
}
dataset = load_dataset("csv", data_files=data_files)
dataset.push_to_hub("bsienkiewicz/mtg_vae_tags_dataset", private=True)

Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ? shards/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ? shards/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ? shards/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

README.md:   0%|          | 0.00/813 [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/datasets/bsienkiewicz/mtg_vae_tags_dataset/commit/8cd79382506d89963bb2a60e576e17c431d1c8d5', commit_message='Upload dataset', commit_description='', oid='8cd79382506d89963bb2a60e576e17c431d1c8d5', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/bsienkiewicz/mtg_vae_tags_dataset', endpoint='https://huggingface.co', repo_type='dataset', repo_id='bsienkiewicz/mtg_vae_tags_dataset'), pr_revision=None, pr_num=None)