In [None]:
import sys

# Put the parent directory on the import path so that the `src.vae.model`
# import in the next cell resolves (presumably the repository root -- confirm).
sys.path.append("..")

In [None]:
import json
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.manifold import TSNE
from collections import Counter, defaultdict
from tqdm import tqdm
from datasets import load_dataset
from scipy.stats import norm, multivariate_normal
import ast

from src.vae.model import BetaVAE

seed = 42
# Seed every RNG this notebook can draw from. NumPy drives the latent sampling
# below; torch is seeded as well so that any sampling done inside the model
# (e.g. in `reparameterize`, if it draws noise -- confirm) is reproducible too.
np.random.seed(seed)
torch.manual_seed(seed)

sns.set_style('whitegrid')
# Increase font sizes for readability
sns.set_context('notebook', font_scale=1.4)
plt.rcParams.update({
    'figure.figsize': (18, 8),
    'font.size': 32,
    'axes.titlesize': 28,
    'axes.labelsize': 25,
    'xtick.labelsize': 23,
    'ytick.labelsize': 23,
    'legend.fontsize': 21
})

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## Define Taxonomy

In [None]:
# Load the category -> tags taxonomy. The context manager closes the file
# handle deterministically, and the explicit encoding avoids depending on the
# platform default when the tag names contain non-ASCII characters.
with open("../data/concepts_to_tags.json", "r", encoding="utf-8") as taxonomy_file:
    TAXONOMY = json.load(taxonomy_file)

CATEGORIES = list(TAXONOMY.keys())

# Reverse map for easy lookup (tag -> category).
# If a tag is listed under several categories, the last category wins.
TAG_TO_CATEGORY = {tag: cat for cat, tags in TAXONOMY.items() for tag in tags}


In [None]:
# Assign every tag a position in the multi-hot input vector, category by
# category, and remember which half-open index range each category occupies.
tag_to_idx = {}
idx_to_tag = {}
cat_ranges = {}  # category -> (start, end) with `end` exclusive

current_idx = 0
for cat in CATEGORIES:
    category_tags = TAXONOMY[cat]
    start = current_idx
    for offset, tag in enumerate(category_tags):
        position = start + offset
        tag_to_idx[tag] = position
        idx_to_tag[position] = (cat, tag)
    current_idx = start + len(category_tags)
    cat_ranges[cat] = (start, current_idx)

# The running index after the last category is the total vector length.
TOTAL_INPUT_DIM = current_idx
print(f"Total Input Dimension: {TOTAL_INPUT_DIM}")

## Sample tags using VAE

### Load MTG-Jamendo supporting dataset

In [None]:
df = pd.read_csv("../data/mtg_jamendo/autotagging_top50tags_processed_cleaned.csv")

# These columns are stored in the CSV as Python-literal strings; parse each
# cell back into a Python object (they are used as lists of tags below).
for list_column in ('aspect_list', 'instrument_tags', 'genre_tags', 'mood_tags'):
    df[list_column] = df[list_column].apply(ast.literal_eval)

df

### Load VAE and visualize latent space

In [None]:
# Hyperparameters handed to the BetaVAE constructor (they presumably have to
# match the ones the checkpoint was trained with -- confirm).
input_dim = TOTAL_INPUT_DIM
latent_dim, hidden_dim = 128, 512
dropout_p = 0.25
use_batch_norm = False
beta = 0.25

# Build the network on the target device, then restore the saved weights.
model = BetaVAE(input_dim, latent_dim, hidden_dim, dropout_p, use_batch_norm, beta)
model = model.to(device)
state_dict = torch.load("../models/vae_final.pth", map_location=device)
model.load_state_dict(state_dict)
model.eval()

print("Best model loaded successfully")

In [None]:
def get_latent_representations(model, data):
    """Encode every sample in `data` to a latent vector, one sample at a time.

    Args:
        model: Object exposing `encode(x) -> (mu, logvar)` and
            `reparameterize(mu, logvar) -> z`.
        data: Iterable of per-sample feature vectors.

    Returns:
        NumPy array stacking one squeezed latent vector per input sample.
    """
    def _encode_single(sample):
        # Add a batch dimension of 1 and move the sample to the model's device.
        batch = torch.FloatTensor(sample).unsqueeze(0).to(device)
        mu, logvar = model.encode(batch)
        z = model.reparameterize(mu, logvar)
        return z.cpu().numpy().squeeze()

    encoded = []
    with torch.no_grad():
        for sample in tqdm(data, desc="Encoding data to latent space"):
            encoded.append(_encode_single(sample))
    return np.array(encoded)

# Prepare data: one multi-hot vector per track, in DataFrame row order.
def _multi_hot(tags):
    """Return a float32 vector with 1.0 at the index of every known tag."""
    vector = np.zeros(TOTAL_INPUT_DIM, dtype=np.float32)
    for tag in tags:
        # Tags missing from the taxonomy are skipped.
        if tag in tag_to_idx:
            vector[tag_to_idx[tag]] = 1.0
    return vector

data = np.array([_multi_hot(tags) for tags in df['aspect_list']])

latents = get_latent_representations(model, data)
print("Latent representations obtained.")

# Project the latent vectors down to 2-D for plotting.
tsne = TSNE(n_components=2, random_state=seed)
latents_2d = tsne.fit_transform(latents)
print("t-SNE transformation completed.")

plt.figure(figsize=(12, 10))
plt.style.use('petroff10')
# One marker shape per highlighted genre; colours come from the style's cycle.
genre_styles = {
    'rock': {'marker': 'o'},
    'pop': {'marker': 's'},
    'jazz': {'marker': '^'},
    'classical': {'marker': 'D'}
}

for genre, style in genre_styles.items():
    # Select rows by *position* with a boolean mask. `latents_2d` is aligned
    # with the row order of `df`, so this stays correct even if `df` does not
    # carry a default 0..N-1 index, and it avoids re-walking the frame with
    # iterrows() once per genre.
    genre_mask = df['genre_tags'].apply(
        lambda tags, genre=genre: genre in tags
    ).to_numpy(dtype=bool)

    if genre_mask.any():
        plt.scatter(
            x=latents_2d[genre_mask, 0],
            y=latents_2d[genre_mask, 1],
            alpha=0.8,
            marker=style['marker'],
            s=60,
            edgecolors='black',
            linewidths=0.3,
            label=genre
        )

plt.legend(loc='best', framealpha=0.8)
plt.title("t-SNE of VAE Latent Space Colored by Genre", fontweight='bold', pad=20)
# Axis labels are intentionally left blank.
plt.xlabel("")
plt.ylabel("")
plt.grid(alpha=0.3)
plt.savefig("../docs/assets/tsne_genre_clusters.pdf", bbox_inches='tight')
plt.show()

### Inference

In [None]:
def generate_tags_with_threshold(model, latent_vector, seed_tags=None, threshold=0.5, temp=1.0):
    """Decode a single latent vector into the tags whose probability clears `threshold`.

    Args:
        model: Object exposing `eval()` and `decode(tensor)`. The decoder output
            is divided by `temp` and passed through a sigmoid here.
        latent_vector: One latent vector, either flat (`(latent_dim,)`) or with
            leading singleton dimensions (e.g. `(1, latent_dim)`). It is always
            normalised to shape `(1, latent_dim)` before decoding, so the
            decoder sees the same 2-D input regardless of how the caller
            sliced it.
        seed_tags: Optional iterable of tags whose probability is forced to 1.0.
            Tags that are not in `tag_to_idx` are ignored.
        threshold: Minimum probability for a tag to be returned.
        temp: Temperature dividing the decoder output before the sigmoid.

    Returns:
        List of tag names, ordered by their index in the input vector.
    """
    model.eval()
    with torch.no_grad():
        # reshape(1, -1) instead of unsqueeze(0): a (1, latent_dim) slice must
        # not become a 3-D (1, 1, latent_dim) tensor.
        latent_tensor = (
            torch.as_tensor(latent_vector, dtype=torch.float32)
            .reshape(1, -1)
            .to(device)
        )
        logits = model.decode(latent_tensor)
        # reshape(-1) always yields a 1-D probability vector.
        probs = torch.sigmoid(logits / temp).reshape(-1).cpu().numpy()

    if seed_tags:
        for tag in seed_tags:
            idx = tag_to_idx.get(tag)
            if idx is not None:
                probs[idx] = 1.0  # Force seed tags to be present

    selected = np.flatnonzero(probs >= threshold)
    return [idx_to_tag[int(idx)][1] for idx in selected]

In [None]:
# Quick smoke test: alternate between two seed-tag sets and cycle through
# three temperatures, decoding a fresh random latent vector each time.
demo_seed_sets = (['rock', 'guitar'], ['classical', 'piano'])
demo_temperatures = (0.7, 1.0, 1.2)

for i in range(5):
    random_latent = np.random.normal(0, 1, latent_dim)
    seed_tags = demo_seed_sets[i % 2]
    temp = demo_temperatures[i % 3]
    generated_tags = generate_tags_with_threshold(
        model, random_latent, seed_tags=seed_tags, threshold=0.52, temp=temp
    )
    print(f"Seed Tags: {seed_tags}, Temp: {temp} -> Generated Tags: {generated_tags}")

In [None]:
# Draw one standard-normal latent vector per dataset row, cast to float32.
# NOTE: this rebinds `latents`, which earlier held the encoded latent
# representations used for the t-SNE plot.
N_SAMPLES_TO_GENERATE = len(df)
latents = np.random.normal(0, 1, (N_SAMPLES_TO_GENERATE, latent_dim)).astype(np.float32)

In [None]:
threshold = 0.52
# NOTE(review): 1.0 is listed three times, so that temperature is sampled three
# times as often as the others -- confirm this weighting is intended.
temperatures = [0.25, 0.5, 0.75, 1.0, 1.0, 1.0, 1.25, 1.5, 1.75, 2.0]
N_SAMPLES_TO_GENERATE = len(df)

results = []

# Pass 1 -- seeded generation: for every temperature and every track, decode
# the track's pre-sampled latent vector while forcing a few of the track's own
# tags to be present.
for temp in tqdm(temperatures, desc="Temperatures"):
    for idx in tqdm(range(N_SAMPLES_TO_GENERATE), desc="Samples", leave=False):
        row = df.iloc[idx]
        seed_tags = []

        for category in ['genre', 'instrument', 'mood']:
            # NOTE(review): `> 1` means a category with exactly one tag never
            # contributes a seed; `> 0` may have been intended -- confirm.
            if len(row[f"{category}_tags"]) > 1:
                seed_tags.append(np.random.choice(row[f"{category}_tags"]))

        latent_vector = latents[idx:idx+1]  # Keep batch dimension

        generated_tags = generate_tags_with_threshold(
            model, latent_vector, seed_tags=seed_tags, threshold=threshold, temp=temp
        )

        results.append({
            'id': row['id'],
            'aspect_list': generated_tags,
            'original_aspect_list': row['aspect_list'],
            'temperature': temp,
        })

# Pass 2 -- unseeded generation for variety. The temperature grid is iterated
# explicitly here: previously this pass reused the `temp` variable left over
# from the loop above, so every unseeded sample was decoded (and labelled) with
# the last temperature only. Fresh latent vectors are drawn for each
# temperature so that repeated temperatures do not yield byte-identical rows.
for temp in tqdm(temperatures, desc="Temperatures (unseeded)"):
    unseeded_latents = np.random.normal(
        0, 1, (N_SAMPLES_TO_GENERATE, latent_dim)
    ).astype(np.float32)

    for idx in tqdm(range(N_SAMPLES_TO_GENERATE), desc="Samples", leave=False):
        z = unseeded_latents[idx:idx+1]  # Keep batch dimension

        generated_tags = generate_tags_with_threshold(
            model, z, seed_tags=[], threshold=threshold, temp=temp
        )

        results.append({
            'id': df.iloc[idx]['id'],
            'aspect_list': generated_tags,
            'original_aspect_list': [],
            'temperature': temp,
        })

## Preprocess generated dataset

In [None]:
# Collect the generated entries (one dict per generated sample) into a DataFrame.
res_df = pd.DataFrame(results)
res_df

In [None]:
# Deduplicate and sort the generated tags of every sample.
res_df['aspect_list'] = res_df['aspect_list'].apply(lambda tags: sorted(set(tags)))

# Split the tags into one list column per taxonomy category.
for category in ('instrument', 'genre', 'mood', 'tempo'):
    res_df[f'{category}_tags'] = res_df['aspect_list'].apply(
        lambda tags, category=category: [
            tag for tag in tags if TAG_TO_CATEGORY.get(tag) == category
        ]
    )

# Remove samples with tags in fewer than 3 categories
category_columns = ['instrument_tags', 'genre_tags', 'mood_tags', 'tempo_tags']

def _covers_enough_categories(row):
    """True when at least three of the category columns are non-empty."""
    non_empty = sum(1 for column in category_columns if len(row[column]) > 0)
    return non_empty >= 3

keep_mask = res_df.apply(_covers_enough_categories, axis=1)
res_df = res_df[keep_mask].reset_index(drop=True)
res_df

In [None]:
# Surrogate key derived from the track id, the original tags and the temperature
import hashlib
def generate_surrogate_key(track_id: str, original_tags: str, temperature: float) -> str:
    """Return the hex MD5 digest of "<track_id>_<original_tags>_<temperature>".

    The three values are interpolated with their default string formatting,
    so the same inputs always map to the same 32-character key.
    """
    digest = hashlib.md5()
    digest.update(f"{track_id}_{original_tags}_{temperature}".encode())
    return digest.hexdigest()

# Base key: hash of (track id, original tags, temperature).
base_keys = res_df.apply(
    lambda row: generate_surrogate_key(row['id'], row['original_aspect_list'], row['temperature']),
    axis=1,
)

# The same (track, original tags, temperature) triple can occur several times
# (the temperature grid repeats values, and unseeded samples share an empty
# original tag list), so the base key alone is not unique. Keep the base key
# for the first occurrence and re-hash it with an occurrence counter for every
# later one.
occurrence = base_keys.groupby(base_keys).cumcount()
unique_keys = [
    key if n == 0 else hashlib.md5(f"{key}_{n}".encode()).hexdigest()
    for key, n in zip(base_keys, occurrence)
]

# Replace the track id with the surrogate key without mutating in place.
res_df = res_df.drop(columns=['id']).assign(id=unique_keys)
assert res_df['id'].is_unique, "surrogate keys must be unique"
res_df

## Push to Hugging Face Hub

In [None]:
from sklearn.model_selection import train_test_split

# 90% train; the remaining 10% hold-out is halved into validation and test.
df_train, df_holdout = train_test_split(res_df, test_size=0.1, random_state=42)
df_valid, df_test = train_test_split(df_holdout, test_size=0.5, random_state=42)

In [None]:
from pathlib import Path

# Create output directory
output_dir = Path("../data/vae-tags-dataset")
output_dir.mkdir(parents=True, exist_ok=True)

# Write each split, plus the concatenation of all three, as its own CSV file.
all_df = pd.concat([df_train, df_valid, df_test])
frames_by_name = {
    "train": df_train,
    "validation": df_valid,
    "test": df_test,
    "all": all_df,
}
for split_name, split_df in frames_by_name.items():
    split_df.to_csv(output_dir / f"{split_name}.csv", index=False)

In [None]:
# Map each split name to the CSV written for it, reload the files as a
# dataset and upload it to the Hub as a private repository.
data_files = {
    split_name: str(output_dir / f"{split_name}.csv")
    for split_name in ("train", "validation", "test")
}
dataset = load_dataset("csv", data_files=data_files)
dataset.push_to_hub("bsienkiewicz/vae-tags-dataset", private=True)