In [1]:
import json
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from collections import Counter, defaultdict
from tqdm import tqdm
from datasets import load_dataset
from scipy.stats import norm, multivariate_normal
import ast

seed = 42
np.random.seed(seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


## Define Tag Categories

Define all possible tags for each category based on the dataset.

In [2]:
TAXONOMY = json.load(open("../data/concepts_to_tags.json", "r"))

CATEGORIES = list(TAXONOMY.keys())

# Reverse map for easy lookup (tag -> category)
TAG_TO_CATEGORY = {}
for cat, tags in TAXONOMY.items():
    for tag in tags:
        TAG_TO_CATEGORY[tag] = cat


In [3]:
tag_to_idx = {}
idx_to_tag = {}
cat_ranges = {} # Stores start/end index for each category

current_idx = 0
for cat in CATEGORIES:
    start = current_idx
    for tag in TAXONOMY[cat]:
        tag_to_idx[tag] = current_idx
        idx_to_tag[current_idx] = (cat, tag)
        current_idx += 1
    cat_ranges[cat] = (start, current_idx)

TOTAL_INPUT_DIM = current_idx
print(f"Total Input Dimension: {TOTAL_INPUT_DIM}")

Total Input Dimension: 200


## Prepare dataset

In [4]:
def process_data_multilabel(df: pd.DataFrame) -> np.ndarray:
    """
    Creates a Multi-Hot vector for every song.
    Example: [0, 1, 0, 1, 1, ...] where 1 means the tag is present.
    """
    processed_data = []

    for _, row in df.iterrows():
        raw_tags = ast.literal_eval(row['aspect_list'])
        raw_tags = [t.lower() for t in raw_tags]
            
        # Create Zero Vector
        vector = np.zeros(TOTAL_INPUT_DIM, dtype=np.float32)
        has_data = False
        
        for tag in raw_tags:
            if tag in tag_to_idx:
                idx = tag_to_idx[tag]
                vector[idx] = 1.0
                has_data = True
        
        # Only keep records that have at least one valid tag
        if has_data:
            processed_data.append(vector)
            
    return np.array(processed_data)

In [5]:
df = load_dataset("google/MusicCaps", split="train").to_pandas()

In [6]:
data = process_data_multilabel(df)
print(f"Processed data shape: {data.shape}")

Processed data shape: (5046, 200)


## VAE

In [7]:
class MultiLabelVAE(nn.Module):
    def __init__(self, input_dim, latent_dim=32, hidden_dim=128):
        super(MultiLabelVAE, self).__init__()
        
        # Encoder
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc2_logvar = nn.Linear(hidden_dim, latent_dim)
        
        # Decoder
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim) 
        
        # Dropout for the "Denoising" part (applied to input)
        self.input_dropout = nn.Dropout(p=0.3) 

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc2_mu(h1), self.fc2_logvar(h1)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, temperature=1.0):
        h3 = F.relu(self.fc3(z))
        logits = self.fc4(h3)
        return torch.sigmoid(logits / temperature)

    def forward(self, x, temperature=1.0):
        # Apply dropout to inputs during training -> forces model to learn correlations
        x_noisy = self.input_dropout(x)
        mu, logvar = self.encode(x_noisy)
        z = self.reparameterize(mu, logvar)
        recon = self.decode(z, temperature=temperature)
        return recon, mu, logvar

In [None]:
# Hyperparameter optimization setup
input_dim = TOTAL_INPUT_DIM

# Define hyperparameter combinations
hyperparameter_space = {
    'vae_small': {
        'latent_dim': 32,
        'hidden_dim': 128,
        'learning_rate': 5e-3,
        'epochs': 500
    },
    'vae_medium': {
        'latent_dim': 64,
        'hidden_dim': 256,
        'learning_rate': 1e-3,
        'epochs': 750
    },
    'vae_large': {
        'latent_dim': 128,
        'hidden_dim': 512,
        'learning_rate': 5e-4,
        'epochs': 1000
    }
}

batch_size = 32

In [24]:
def vae_loss(recon_x, x, mu, logvar):
    bce_loss_fn = nn.BCELoss(reduction='sum')
    BCE = bce_loss_fn(recon_x, x)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

In [None]:
for model_name, params in hyperparameter_space.items():
    print(f"Training model: {model_name} with params: {params}")
    
    latent_dim = params['latent_dim']
    hidden_dim = params['hidden_dim']
    learning_rate = params['learning_rate']
    num_epochs = params['epochs']
    
    model = MultiLabelVAE(input_dim, latent_dim, hidden_dim).to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    
    data_tensor = torch.tensor(data).to(device)
    dataset = torch.utils.data.TensorDataset(data_tensor)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0
        for batch in dataloader:
            inputs = batch[0]
            optimizer.zero_grad()
            recon_batch, mu, logvar = model(inputs)
            loss = vae_loss(recon_batch, inputs, mu, logvar)
            loss.backward()
            total_loss += loss.item()
            optimizer.step()
        
        avg_loss = total_loss / len(data_tensor)
        if (epoch + 1) % 50 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")
    
    # Save the trained model
    torch.save(model.state_dict(), f"../models/{model_name}.pth")
    print(f"Model {model_name} saved.\n")

Training model: vae_small with params: {'latent_dim': 32, 'hidden_dim': 128, 'learning_rate': 0.005, 'epochs': 500}
Epoch [50/500], Loss: 14.6774
Epoch [100/500], Loss: 14.5252
Epoch [150/500], Loss: 14.4066
Epoch [200/500], Loss: 14.4271
Epoch [250/500], Loss: 14.4566
Epoch [300/500], Loss: 14.4598
Epoch [350/500], Loss: 14.4606
Epoch [400/500], Loss: 14.4501
Epoch [450/500], Loss: 14.4999
Epoch [500/500], Loss: 14.4179
Model vae_small saved.

Training model: vae_medium with params: {'latent_dim': 64, 'hidden_dim': 256, 'learning_rate': 0.001, 'epochs': 750}
Epoch [50/750], Loss: 14.8965
Epoch [100/750], Loss: 14.4378
Epoch [150/750], Loss: 14.3522
Epoch [200/750], Loss: 14.1601
Epoch [250/750], Loss: 14.0742
Epoch [300/750], Loss: 14.0289
Epoch [350/750], Loss: 13.9840
Epoch [400/750], Loss: 14.0265
Epoch [450/750], Loss: 13.9347
Epoch [500/750], Loss: 13.9971
Epoch [550/750], Loss: 13.9352
Epoch [600/750], Loss: 13.8855
Epoch [650/750], Loss: 13.8598
Epoch [700/750], Loss: 13.8357
E

## Generate tags

In [26]:
df = pd.read_csv("../data/mtg_jamendo/autotagging_top50tags_processed_cleaned.csv")
df['aspect_list'] = df['aspect_list'].apply(ast.literal_eval)
df['instrument_tags'] = df['instrument_tags'].apply(ast.literal_eval)
df['genre_tags'] = df['genre_tags'].apply(ast.literal_eval)
df['mood_tags'] = df['mood_tags'].apply(ast.literal_eval)
df

Unnamed: 0,id,tags,genre_tags,mood_tags,instrument_tags,aspect_list
0,track_0007391,"['genre---electronic', 'genre---pop', 'instrum...","[electronic, pop]",[emotional],"[bass, drums, guitar, keyboard]","[drums, bass, guitar, electronic, emotional, p..."
1,track_0015161,"['genre---instrumentalpop', 'genre---pop', 'ge...","[pop, rock]",[emotional],"[bass, drums]","[drums, bass, rock, emotional, pop]"
2,track_0015166,"['genre---dance', 'genre---electronic', 'genre...","[dance, electronic, pop, techno]",[emotional],[bass],"[bass, electronic, dance, techno, emotional, pop]"
3,track_0015167,"['genre---chillout', 'genre---easylistening', ...","[electronic, pop]",[emotional],"[bass, violin]","[bass, electronic, emotional, pop, violin]"
4,track_0015169,"['genre---electronic', 'genre---instrumentalpo...","[electronic, pop]",[emotional],"[bass, drums]","[drums, bass, electronic, emotional, pop]"
...,...,...,...,...,...,...
2036,track_1420702,"['genre---dance', 'genre---easylistening', 'ge...",[dance],"[funk, happy]","[bass, drums, keyboard]","[drums, bass, dance, funk, keyboard, happy]"
2037,track_1420704,"['genre---dance', 'genre---easylistening', 'in...",[dance],[happy],"[bass, drums, keyboard]","[drums, bass, dance, keyboard, happy]"
2038,track_1420705,"['genre---dance', 'genre---easylistening', 'in...",[dance],[happy],"[bass, drums, keyboard]","[drums, bass, dance, keyboard, happy]"
2039,track_1420706,"['genre---dance', 'genre---easylistening', 'in...",[dance],[happy],"[bass, drums, keyboard]","[drums, bass, dance, keyboard, happy]"


In [29]:
latent_dim = 128
hidden_dim = 512

model = MultiLabelVAE(input_dim, latent_dim, hidden_dim).to(device)
model.load_state_dict(torch.load("../models/vae_large.pth", map_location=device))
model.eval()

print("✓ Best model loaded successfully")

✓ Best model loaded successfully


In [30]:
def generate_tags(model, seed_tags, requests, temperature=1.0):
    """
    seeds: List of tags we ALREADY have (e.g. ['rock', 'guitar'])
    requests: Dict of how many tags we want per category (e.g. {'instrument': 2, 'mood': 1})
    """
    model.eval()
    device = next(model.parameters()).device
    
    # 1. Build the Input Vector from Seeds
    input_vec = torch.zeros(1, TOTAL_INPUT_DIM).to(device)
    
    
    # Fill in the knowns
    for tag in seed_tags:
        if tag in tag_to_idx:
            input_vec[0, tag_to_idx[tag]] = 1.0
        else:
            print(f"Warning: Seed tag '{tag}' not in taxonomy.")

    with torch.no_grad():
        # 2. Encode to get the Latent Vibe (z)
        # Note: We don't use dropout here; we want the model to use all clues we gave it.
        mu, logvar = model.encode(input_vec)
        z = model.reparameterize(mu, logvar)
        
        # 3. Decode to get probabilities for EVERYTHING
        # Output shape: [1, Total_Dim] (Values 0.0 to 1.0)
        probs = model.decode(z, temperature=temperature)[0] 
        
        # 4. Extract Top-K for requested categories
        results = {}
        
        for category, count in requests.items():
            if count <= 0:
                results[f"generated_{category}_tags"] = []
                continue

            start, end = cat_ranges[category]
            
            # Slice the probabilities relevant to this category
            cat_probs = probs[start:end]
            
            # Get Top K indices for this slice
            # We ask for count + len(seeds) just in case the model predicts the seed tag again
            top_k_vals, top_k_indices = torch.topk(cat_probs, k=count + 5)
            
            # Convert slice-indices back to global-indices, then to strings
            found_tags = []
            for i in range(len(top_k_indices)):
                local_idx = top_k_indices[i].item()
                global_idx = start + local_idx
                tag_name = idx_to_tag[global_idx][1]
                
                # Don't return tags we already provided as seeds
                if tag_name not in seed_tags:
                    found_tags.append(tag_name)
                
                if len(found_tags) == count:
                    break
            
            results[f"generated_{category}_tags"] = found_tags
            
    return results

In [31]:
def generate_df(idx: int, tags_per_category: dict[str, int], temperature=1.0):
    row = df.iloc[idx]
    seed_tags = []
    for category in ['genre', 'instrument', 'mood']:
        if len(row[f"{category}_tags"]) > 1:
            seed_tags.append(np.random.choice(row[f"{category}_tags"]))
            tags_per_category[category] = tags_per_category.get(category, 1) - 1
        
    generated_tags = generate_tags(model, seed_tags, tags_per_category)
    _generated_tags = []
    for gtags in generated_tags.values():
        _generated_tags.extend(gtags)

    res_entry = {
        'id': row['id'],
        'original_aspect_list': row['aspect_list'],
        'aspect_list': seed_tags + _generated_tags,
        **generated_tags
    }
    return pd.DataFrame([res_entry])

In [32]:
CATEGORIES = [
    "tempo",
    "genre",
    "mood",
    "instrument"
]
N_CATEGORIES = len(CATEGORIES)
N_SAMPLES_TO_GENERATE = len(df)

# --- 1. SYNTHETIC DATA GENERATION (Replace with your actual data) ---
# We simulate a dataset where tag counts are discrete and correlated.
# Max counts are defined for simulation purposes.
MAX_COUNTS = {
    "tempo": 6,
    "genre": 6, 
    "mood": 7, 
    "instrument": 6, 
}
MEANS = {
    "tempo": 1.18,
    "genre": 1.29, 
    "mood": 1.52, 
    "instrument": 2.08,
}
VARIANCES = {
    "tempo": 0.54,
    "genre": 0.65,
    "mood": 0.91,
    "instrument": 1.14,
}

In [33]:
def generate_synthetic_correlated_data(n_records):
    """
    Creates synthetic discrete count data that serves as the 'real' dataset.
    This step is highly important: it determines the statistics (R and ECDFs)
    that the Copula will try to match.
    """
    print("--- 1. Generating Synthetic Data ---")

    # Define the desired correlation matrix (e.g., high correlation between Genre and Instrument)
    # This represents your calculated correlation matrix R.
    correlation_matrix = np.array([
        [1.0, 0.17, 0.081, -0.035],  # Tempo
        [0.17, 1.0, 0.55, -0.045],  # Genre
        [0.081, 0.55, 1.0, -0.05],  # Mood
        [-0.035, -0.045, -0.077, 1.0]   # Instrument
    ])

    # Generate correlated continuous data (Multivariate Normal)
    mean = np.zeros(N_CATEGORIES)
    z_continuous = multivariate_normal.rvs(mean=mean, cov=correlation_matrix, size=n_records)

    data = np.zeros((n_records, N_CATEGORIES), dtype=int)
    
    # Transform continuous data into discrete counts based on desired marginals
    # (using inverse CDF of an arbitrary discrete distribution for simulation)
    # This simulates your real-world data having specific tag count distributions
    for i, cat in enumerate(CATEGORIES):
        max_c = MAX_COUNTS[cat]
        # Simulate log normal-like distribution for counts
        mu = np.log(MEANS[cat]**2 / np.sqrt(MEANS[cat]**2 + VARIANCES[cat]))
        sigma = np.sqrt(np.log(1 + VARIANCES[cat] / MEANS[cat]**2))
        # Create discrete probability distribution
        x = np.arange(1, max_c + 1)
        p = (1 / (x * sigma * np.sqrt(2 * np.pi)))
        p *= np.exp(- (np.log(x) - mu)**2 / (2 * sigma**2))
        p /= p.sum()  # Normalize to sum to 1
        
        # Convert continuous z (uniform quantile) to discrete count (inverse CDF)
        uniform_quantiles = norm.cdf(z_continuous[:, i])
        
        # Quantile mapping for a simple discrete distribution
        counts = np.digitize(uniform_quantiles, np.cumsum(p[:-1])) + 1
        data[:, i] = np.clip(counts, 1, max_c)

    print(f"Synthetic Data Shape: {data.shape}")
    print(f"Calculated Correlation of Synthetic Data:\n{np.corrcoef(data.T).round(2)}")
    return data, correlation_matrix

In [34]:
data, _ = generate_synthetic_correlated_data(N_SAMPLES_TO_GENERATE)
print(data)

--- 1. Generating Synthetic Data ---
Synthetic Data Shape: (2041, 4)
Calculated Correlation of Synthetic Data:
[[ 1.    0.09  0.07 -0.04]
 [ 0.09  1.    0.41 -0.04]
 [ 0.07  0.41  1.   -0.06]
 [-0.04 -0.04 -0.06  1.  ]]
[[1 1 1 2]
 [1 2 1 2]
 [2 1 2 1]
 ...
 [1 1 2 3]
 [1 1 2 3]
 [1 1 1 1]]


  out = random_state.multivariate_normal(mean, cov, size)


In [35]:
temperatures = [0.8, 1.0, 1.25, 1.5, 1.75]

res_df = pd.DataFrame()

for temp in tqdm(temperatures):
    for idx in tqdm(range(N_SAMPLES_TO_GENERATE), leave=False):
        num_tags_for_category = {
            "tempo": data[idx, 0],
            "genre": data[idx, 1],
            "mood": data[idx, 2],
            "instrument": data[idx, 3],
        }
        temp_df = generate_df(idx, tags_per_category=num_tags_for_category, temperature=temp)
        temp_df['temperature'] = temp
        res_df = pd.concat([res_df, temp_df], ignore_index=True)
res_df

100%|██████████| 5/5 [00:18<00:00,  3.61s/it]


Unnamed: 0,id,original_aspect_list,aspect_list,generated_tempo_tags,generated_genre_tags,generated_mood_tags,generated_instrument_tags,temperature
0,track_0007391,"[drums, bass, guitar, electronic, emotional, p...","[pop, guitar, medium tempo, passionate, flat m...",[medium tempo],[],[passionate],[flat male vocal],0.80
1,track_0015161,"[drums, bass, rock, emotional, pop]","[rock, drums, moderate tempo, rapping, energet...",[moderate tempo],[rapping],[energetic],[electronic drums],0.80
2,track_0015166,"[bass, electronic, dance, techno, emotional, pop]","[techno, medium tempo, fast tempo, energetic, ...","[medium tempo, fast tempo]",[],"[energetic, passionate]",[male singer],0.80
3,track_0015167,"[bass, electronic, emotional, pop, violin]","[electronic, bass, moderate tempo, slow tempo,...","[moderate tempo, slow tempo]",[],[happy mood],"[male voice, no voices, percussion, no voice]",0.80
4,track_0015169,"[drums, bass, electronic, emotional, pop]","[pop, bass, slow tempo, medium to uptempo, dar...","[slow tempo, medium to uptempo]",[],"[dark, sad, emotional]",[digital drums],0.80
...,...,...,...,...,...,...,...,...
10200,track_1420702,"[drums, bass, dance, funk, keyboard, happy]","[bass, funk, medium to uptempo, rap, digital d...",[medium to uptempo],[rap],[],[digital drums],1.75
10201,track_1420704,"[drums, bass, dance, keyboard, happy]","[keyboard, medium tempo, classical, calming, a...",[medium tempo],[classical],[calming],[acoustic guitar],1.75
10202,track_1420705,"[drums, bass, dance, keyboard, happy]","[keyboard, groovy, rapping, energetic, aggress...",[groovy],[rapping],"[energetic, aggressive]","[electronic drums, male vocal]",1.75
10203,track_1420706,"[drums, bass, dance, keyboard, happy]","[keyboard, groovy, chill, energetic, funky, ba...",[groovy],[chill],"[energetic, funky]","[bass guitar, electric guitar]",1.75


In [36]:
# Sort aspect list column and deduplicate tag combinations
res_df['aspect_list'] = res_df['aspect_list'].apply(lambda x: sorted(list(set(x))))
res_df = res_df.drop_duplicates(subset=['aspect_list']).reset_index(drop=True)
res_df

Unnamed: 0,id,original_aspect_list,aspect_list,generated_tempo_tags,generated_genre_tags,generated_mood_tags,generated_instrument_tags,temperature
0,track_0007391,"[drums, bass, guitar, electronic, emotional, p...","[flat male vocal, guitar, medium tempo, passio...",[medium tempo],[],[passionate],[flat male vocal],0.80
1,track_0015161,"[drums, bass, rock, emotional, pop]","[drums, electronic drums, energetic, moderate ...",[moderate tempo],[rapping],[energetic],[electronic drums],0.80
2,track_0015166,"[bass, electronic, dance, techno, emotional, pop]","[energetic, fast tempo, male singer, medium te...","[medium tempo, fast tempo]",[],"[energetic, passionate]",[male singer],0.80
3,track_0015167,"[bass, electronic, emotional, pop, violin]","[bass, electronic, happy mood, male voice, mod...","[moderate tempo, slow tempo]",[],[happy mood],"[male voice, no voices, percussion, no voice]",0.80
4,track_0015169,"[drums, bass, electronic, emotional, pop]","[bass, dark, digital drums, emotional, medium ...","[slow tempo, medium to uptempo]",[],"[dark, sad, emotional]",[digital drums],0.80
...,...,...,...,...,...,...,...,...
9228,track_1420700,"[drums, bass, dance, keyboard, happy, house]","[bass, electronic drums, energetic, fast tempo...","[uptempo, fast tempo, upbeat]",[techno],[energetic],[electronic drums],1.75
9229,track_1420701,"[drums, bass, dance, keyboard, happy, house]","[bass, dance, electronic drums, energetic, gro...","[upbeat, groovy]",[],"[energetic, playful]","[electronic drums, male vocal]",1.75
9230,track_1420702,"[drums, bass, dance, funk, keyboard, happy]","[bass, digital drums, funk, medium to uptempo,...",[medium to uptempo],[rap],[],[digital drums],1.75
9231,track_1420705,"[drums, bass, dance, keyboard, happy]","[aggressive, electronic drums, energetic, groo...",[groovy],[rapping],"[energetic, aggressive]","[electronic drums, male vocal]",1.75


In [37]:
# Add surrogate key based on track_id and temperature
import hashlib
def generate_surrogate_key(track_id: str, temperature: float) -> str:
    key_str = f"{track_id}_{temperature}"
    return hashlib.md5(key_str.encode()).hexdigest()

res_df['surrogate_key'] = res_df.apply(lambda row: generate_surrogate_key(row['id'], row['temperature']), axis=1)
res_df.drop(columns=['id'], inplace=True)
res_df.rename(columns={'surrogate_key': 'id'}, inplace=True)
res_df

Unnamed: 0,original_aspect_list,aspect_list,generated_tempo_tags,generated_genre_tags,generated_mood_tags,generated_instrument_tags,temperature,id
0,"[drums, bass, guitar, electronic, emotional, p...","[flat male vocal, guitar, medium tempo, passio...",[medium tempo],[],[passionate],[flat male vocal],0.80,fcbb8aa0b93ef7b6da549497a1e4676f
1,"[drums, bass, rock, emotional, pop]","[drums, electronic drums, energetic, moderate ...",[moderate tempo],[rapping],[energetic],[electronic drums],0.80,831702271c06dba496b65c120cbe2fbc
2,"[bass, electronic, dance, techno, emotional, pop]","[energetic, fast tempo, male singer, medium te...","[medium tempo, fast tempo]",[],"[energetic, passionate]",[male singer],0.80,42b34a3f75ac2b295df782e4bf93b9a8
3,"[bass, electronic, emotional, pop, violin]","[bass, electronic, happy mood, male voice, mod...","[moderate tempo, slow tempo]",[],[happy mood],"[male voice, no voices, percussion, no voice]",0.80,e465878054dcfdb69c60b0f460e05c07
4,"[drums, bass, electronic, emotional, pop]","[bass, dark, digital drums, emotional, medium ...","[slow tempo, medium to uptempo]",[],"[dark, sad, emotional]",[digital drums],0.80,dc206d3ba40cf46b9b34846ced92dab0
...,...,...,...,...,...,...,...,...
9228,"[drums, bass, dance, keyboard, happy, house]","[bass, electronic drums, energetic, fast tempo...","[uptempo, fast tempo, upbeat]",[techno],[energetic],[electronic drums],1.75,6d79bbb03a0be5de8b8c20652c5e7188
9229,"[drums, bass, dance, keyboard, happy, house]","[bass, dance, electronic drums, energetic, gro...","[upbeat, groovy]",[],"[energetic, playful]","[electronic drums, male vocal]",1.75,c59d37138843e6f8f44b950ab63e632a
9230,"[drums, bass, dance, funk, keyboard, happy]","[bass, digital drums, funk, medium to uptempo,...",[medium to uptempo],[rap],[],[digital drums],1.75,2e4184f227dd05c3046300cbcd925dbb
9231,"[drums, bass, dance, keyboard, happy]","[aggressive, electronic drums, energetic, groo...",[groovy],[rapping],"[energetic, aggressive]","[electronic drums, male vocal]",1.75,0e148b77cb0c370cb090efbffc18bdf2


## Push to Hugginface Hub

In [38]:
from sklearn.model_selection import train_test_split

df_train, df_valid = train_test_split(res_df, test_size=0.1, random_state=42)
df_valid, df_test = train_test_split(df_valid, test_size=0.5, random_state=42)

In [39]:
from pathlib import Path

# Create output directory
output_dir = Path("../data/vae_mtg_tags")
output_dir.mkdir(parents=True, exist_ok=True)

df_train.to_csv(output_dir / "train.csv", index=False)
df_valid.to_csv(output_dir / "validation.csv", index=False)
df_test.to_csv(output_dir / "test.csv", index=False)
all_df = pd.concat([df_train, df_valid, df_test])
all_df.to_csv(output_dir / "all.csv", index=False)

In [40]:
data_files = {
    "train": str(output_dir / "train.csv"),
    "validation": str(output_dir / "validation.csv"),
    "test": str(output_dir / "test.csv")
}
dataset = load_dataset("csv", data_files=data_files)
dataset.push_to_hub("bsienkiewicz/mtg_vae_tags_dataset", private=True)

Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ? shards/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ? shards/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ? shards/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

README.md:   0%|          | 0.00/813 [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/datasets/bsienkiewicz/mtg_vae_tags_dataset/commit/c45157bfe397e1b876214aa2f562bbda2b3a4d0b', commit_message='Upload dataset', commit_description='', oid='c45157bfe397e1b876214aa2f562bbda2b3a4d0b', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/bsienkiewicz/mtg_vae_tags_dataset', endpoint='https://huggingface.co', repo_type='dataset', repo_id='bsienkiewicz/mtg_vae_tags_dataset'), pr_revision=None, pr_num=None)