In [1]:
import ast
import json
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from collections import Counter, defaultdict
from datasets import load_dataset
from scipy.stats import norm, multivariate_normal

seed = 42
np.random.seed(seed)

ModuleNotFoundError: No module named 'tqdm'

## Define Tag Categories

Define all possible tags for each category based on the dataset.

In [2]:
KNOWN_TAGS = json.load(open("../data/concepts_to_tags.json", "r"))

# Reverse map for easy lookup (tag -> category)
TAG_TO_CATEGORY = {}
for cat, tags in KNOWN_TAGS.items():
    for tag in tags:
        TAG_TO_CATEGORY[tag] = cat


## Load MTG Jaemdo dataset as a baseline

In [26]:
out_buffer = []

with open("../data/mtg_jamendo/autotagging_top50tags.tsv", "r") as f:
    for line in f.readlines():
        strings = line.strip().split('\t')
        track_id = strings[0]
        tags = strings[5:]  # Assuming tags start from the 6th column
        out_buffer.append({
            "id": track_id,
            "tags": tags
        })
with open("../data/mtg_jamendo/autotagging_top50tags_processed.csv", "w") as f:
    f.write("id,tags\n")
    for item in out_buffer:
        f.write(f"{item['id']},\"{';'.join(item['tags'])}\"\n")

In [27]:
mtg_df = pd.read_csv("../data/mtg_jamendo/autotagging_top50tags_processed.csv", converters={
    'tags': lambda x: x.split(';')
})
mtg_df

Unnamed: 0,id,tags
0,TRACK_ID,[TAGS]
1,track_0000215,[genre---metal]
2,track_0000216,[genre---metal]
3,track_0000219,[genre---metal]
4,track_0000223,[genre---metal]
...,...,...
54376,track_1422056,"[genre---soundtrack, instrument---computer]"
54377,track_1422057,"[genre---soundtrack, instrument---computer]"
54378,track_1422058,"[genre---soundtrack, instrument---computer]"
54379,track_1422059,"[genre---soundtrack, instrument---computer]"


In [28]:
# Parse tags into categories
def parse_tags(tags):
    return pd.Series({
        'genre_tags': [t for t in tags if TAG_TO_CATEGORY[t] == 'genre'],
        'mood_tags': [t for t in tags if TAG_TO_CATEGORY[t] == 'mood'],
        'instrument_tags': [t for t in tags if TAG_TO_CATEGORY[t] == 'instrument']
    })

def clean_tags(tags):
    _tags = ast.literal_eval(str(tags))
    _tags = [t.split('---')[-1].strip() for t in _tags]
    return [t.lower() for t in _tags if t.lower() in TAG_TO_CATEGORY]

mtg_df[['genre_tags', 'mood_tags', 'instrument_tags']] = mtg_df['tags'].apply(clean_tags).apply(parse_tags)
mtg_df['aspect_list'] = mtg_df.apply(lambda row: list(set(
    row['genre_tags'] + row['mood_tags'] + row['instrument_tags']
)), axis=1)
mtg_df

Unnamed: 0,id,tags,genre_tags,mood_tags,instrument_tags,aspect_list
0,TRACK_ID,[TAGS],[],[],[],[]
1,track_0000215,[genre---metal],[metal],[],[],[metal]
2,track_0000216,[genre---metal],[metal],[],[],[metal]
3,track_0000219,[genre---metal],[metal],[],[],[metal]
4,track_0000223,[genre---metal],[metal],[],[],[metal]
...,...,...,...,...,...,...
54376,track_1422056,"[genre---soundtrack, instrument---computer]",[],[],[],[]
54377,track_1422057,"[genre---soundtrack, instrument---computer]",[],[],[],[]
54378,track_1422058,"[genre---soundtrack, instrument---computer]",[],[],[],[]
54379,track_1422059,"[genre---soundtrack, instrument---computer]",[],[],[],[]


In [None]:
mtg_df = mtg_df.where((mtg_df['genre_tags'].map(len) > 0) & (mtg_df['instrument_tags'].map(len) > 0)).dropna()
mtg_df

Unnamed: 0,id,tags,genre_tags,mood_tags,instrument_tags,aspect_list
607,track_0007391,"[genre---electronic, genre---pop, instrument--...","[electronic, pop]",[emotional],"[bass, drums, guitar, keyboard]","[drums, bass, guitar, electronic, emotional, p..."
1015,track_0015161,"[genre---instrumentalpop, genre---pop, genre--...","[pop, rock]",[emotional],"[bass, drums]","[drums, bass, rock, emotional, pop]"
1020,track_0015166,"[genre---dance, genre---electronic, genre---po...","[dance, electronic, pop, techno]",[emotional],[bass],"[bass, electronic, dance, techno, emotional, pop]"
1021,track_0015167,"[genre---chillout, genre---easylistening, genr...","[electronic, pop]",[emotional],"[bass, violin]","[bass, electronic, emotional, pop, violin]"
1023,track_0015169,"[genre---electronic, genre---instrumentalpop, ...","[electronic, pop]",[emotional],"[bass, drums]","[drums, bass, electronic, emotional, pop]"
...,...,...,...,...,...,...
54313,track_1420702,"[genre---dance, genre---easylistening, genre--...",[dance],"[funk, happy]","[bass, drums, keyboard]","[drums, bass, dance, funk, keyboard, happy]"
54314,track_1420704,"[genre---dance, genre---easylistening, instrum...",[dance],[happy],"[bass, drums, keyboard]","[drums, bass, dance, keyboard, happy]"
54315,track_1420705,"[genre---dance, genre---easylistening, instrum...",[dance],[happy],"[bass, drums, keyboard]","[drums, bass, dance, keyboard, happy]"
54316,track_1420706,"[genre---dance, genre---easylistening, instrum...",[dance],[happy],"[bass, drums, keyboard]","[drums, bass, dance, keyboard, happy]"


In [30]:
mtg_df.to_csv("../data/mtg_jamendo/autotagging_top50tags_processed_cleaned.csv", index=False)

## Compare with training dataset

In [None]:
df = pd.read_csv("../data/mtg_jamendo/autotagging_top50tags_processed_cleaned.csv")
df['aspect_list'] = df['aspect_list'].apply(ast.literal_eval)
df['instrument_tags'] = df['instrument_tags'].apply(ast.literal_eval)
df['genre_tags'] = df['genre_tags'].apply(ast.literal_eval)
df['mood_tags'] = df['mood_tags'].apply(ast.literal_eval)
df

In [None]:
df_train = pd.read_csv("../data/musiccaps_tags_to_description_dataset.csv")
df_train = df_train.reset_index(drop=True)
df_train = df_train.fillna('')
df_train

In [None]:

# Convert string representations to lists
df['aspect_list'] = df['aspect_list'].apply(ast.literal_eval)
df['instrument_tags'] = df['instrument_tags'].apply(ast.literal_eval)
df['genre_tags'] = df['genre_tags'].apply(ast.literal_eval)
df['mood_tags'] = df['mood_tags'].apply(ast.literal_eval)

df_train['aspect_list'] = df_train['aspect_list'].apply(lambda x: x.split(', ') if isinstance(x, str) else [])
df_train['instrument_tags'] = df_train['instrument_tags'].apply(lambda x: x.split(', ') if isinstance(x, str) else [])
df_train['genre_tags'] = df_train['genre_tags'].apply(lambda x: x.split(', ') if isinstance(x, str) else [])
df_train['mood_tags'] = df_train['mood_tags'].apply(lambda x: x.split(', ') if isinstance(x, str) else [])


## Tag Distribution Comparison


In [None]:

# Compare dataset sizes and basic statistics
import seaborn as sns
import matplotlib.pyplot as plt

print("=" * 60)
print("DATASET OVERVIEW")
print("=" * 60)
print(f"\nMusicCaps Training Dataset:")
print(f"  Total samples: {len(df_train)}")
print(f"  Total samples with captions: {len(df_train)}")

print(f"\nMTG Jamendo Dataset:")
print(f"  Total samples: {len(df)}")

# Tag count statistics
tag_columns = ['aspect_list', 'genre_tags', 'mood_tags', 'instrument_tags']

print("\n" + "=" * 60)
print("TAG COUNT STATISTICS")
print("=" * 60)

for col in tag_columns:
    mtg_counts = df[col].apply(len)
    train_counts = df_train[col].apply(len)
    
    print(f"\n{col.upper()}:")
    print(f"  MusicCaps - Mean: {train_counts.mean():.2f}, Median: {train_counts.median():.0f}, Std: {train_counts.std():.2f}")
    print(f"  MTG Jamendo - Mean: {mtg_counts.mean():.2f}, Median: {mtg_counts.median():.0f}, Std: {mtg_counts.std():.2f}")


In [None]:

# Visualize tag count distributions
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
axes = axes.flatten()

for idx, col in enumerate(tag_columns):
    ax = axes[idx]
    
    mtg_counts = df[col].apply(len)
    train_counts = df_train[col].apply(len)
    
    # Create side-by-side histograms
    bins = np.arange(0, max(mtg_counts.max(), train_counts.max()) + 2)
    
    ax.hist(train_counts, bins=bins, alpha=0.6, label='MusicCaps', color='blue', edgecolor='black')
    ax.hist(mtg_counts, bins=bins, alpha=0.6, label='MTG Jamendo', color='orange', edgecolor='black')
    
    ax.set_xlabel('Number of Tags')
    ax.set_ylabel('Frequency')
    ax.set_title(f'Distribution: {col.replace("_", " ").title()}')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()


In [None]:

# Extract unique tags and compare vocabularies
def get_tag_vocabulary(series):
    all_tags = []
    for tags in series:
        all_tags.extend(tags)
    return Counter(all_tags)

print("\n" + "=" * 60)
print("TAG VOCABULARY COMPARISON")
print("=" * 60)

vocab_categories = ['genre_tags', 'mood_tags', 'instrument_tags']

for col in vocab_categories:
    mtg_vocab = get_tag_vocabulary(df[col])
    train_vocab = get_tag_vocabulary(df_train[col])
    
    mtg_tags = set(mtg_vocab.keys())
    train_tags = set(train_vocab.keys())
    
    common_tags = mtg_tags & train_tags
    mtg_only = mtg_tags - train_tags
    train_only = train_tags - mtg_tags
    
    print(f"\n{col.upper()}:")
    print(f"  MusicCaps vocabulary size: {len(train_tags)}")
    print(f"  MTG Jamendo vocabulary size: {len(mtg_tags)}")
    print(f"  Common tags: {len(common_tags)}")
    print(f"  MTG-only tags: {len(mtg_only)}")
    print(f"  MusicCaps-only tags: {len(train_only)}")
    
    if mtg_only:
        print(f"    MTG-only: {', '.join(sorted(list(mtg_only))[:10])}{'...' if len(mtg_only) > 10 else ''}")
    if train_only:
        print(f"    MusicCaps-only: {', '.join(sorted(list(train_only))[:10])}{'...' if len(train_only) > 10 else ''}")


In [None]:

# Compare top tags in each category
fig, axes = plt.subplots(1, 3, figsize=(16, 5))

for idx, col in enumerate(vocab_categories):
    ax = axes[idx]
    
    mtg_vocab = get_tag_vocabulary(df[col])
    train_vocab = get_tag_vocabulary(df_train[col])
    
    # Get top 15 tags
    mtg_top = dict(mtg_vocab.most_common(15))
    train_top = dict(train_vocab.most_common(15))
    
    all_tags = set(mtg_top.keys()) | set(train_top.keys())
    
    mtg_counts = [mtg_top.get(tag, 0) for tag in sorted(all_tags)]
    train_counts = [train_top.get(tag, 0) for tag in sorted(all_tags)]
    
    x = np.arange(len(all_tags))
    width = 0.35
    
    ax.bar(x - width/2, train_counts, width, label='MusicCaps', color='blue', alpha=0.7)
    ax.bar(x + width/2, mtg_counts, width, label='MTG Jamendo', color='orange', alpha=0.7)
    
    ax.set_xlabel('Tags')
    ax.set_ylabel('Frequency')
    ax.set_title(f'Top Tags: {col.replace("_", " ").title()}')
    ax.set_xticks(x)
    ax.set_xticklabels(sorted(all_tags), rotation=45, ha='right')
    ax.legend()
    ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()


In [None]:

# Compare tag coverage and overlap
print("\n" + "=" * 60)
print("TAG COVERAGE ANALYSIS")
print("=" * 60)

for col in vocab_categories:
    mtg_vocab = get_tag_vocabulary(df[col])
    train_vocab = get_tag_vocabulary(df_train[col])
    
    mtg_tags = set(mtg_vocab.keys())
    train_tags = set(train_vocab.keys())
    
    common_tags = mtg_tags & train_tags
    
    # Calculate coverage: how much of each dataset's tags appear in the other
    mtg_coverage_in_train = len(common_tags) / len(mtg_tags) * 100 if mtg_tags else 0
    train_coverage_in_mtg = len(common_tags) / len(train_tags) * 100 if train_tags else 0
    
    print(f"\n{col.upper()}:")
    print(f"  MTG tags appearing in MusicCaps: {mtg_coverage_in_train:.1f}%")
    print(f"  MusicCaps tags appearing in MTG: {train_coverage_in_mtg:.1f}%")
    
    # Jaccard similarity
    jaccard = len(common_tags) / len(mtg_tags | train_tags)
    print(f"  Jaccard similarity: {jaccard:.3f}")


In [None]:

# Create correlation heatmap comparing tag counts across categories
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# MusicCaps correlation
tag_count_data_train = {}
for col in vocab_categories:
    tag_count_data_train[col.replace('_tags', '')] = df_train[col].apply(len)
tag_count_df_train = pd.DataFrame(tag_count_data_train)

sns.heatmap(tag_count_df_train.corr(), annot=True, fmt='.2f', cmap='coolwarm', 
            vmin=-1, vmax=1, ax=axes[0], cbar_kws={'label': 'Correlation'})
axes[0].set_title('MusicCaps: Tag Count Correlation')

# MTG Jamendo correlation
tag_count_data_mtg = {}
for col in vocab_categories:
    tag_count_data_mtg[col.replace('_tags', '')] = df[col].apply(len)
tag_count_df_mtg = pd.DataFrame(tag_count_data_mtg)

sns.heatmap(tag_count_df_mtg.corr(), annot=True, fmt='.2f', cmap='coolwarm', 
            vmin=-1, vmax=1, ax=axes[1], cbar_kws={'label': 'Correlation'})
axes[1].set_title('MTG Jamendo: Tag Count Correlation')

plt.tight_layout()
plt.show()


In [None]:

# Create a comprehensive summary table
print("\n" + "=" * 80)
print("COMPREHENSIVE DATASET COMPARISON SUMMARY")
print("=" * 80)

comparison_data = []

for col in vocab_categories:
    mtg_vocab = get_tag_vocabulary(df[col])
    train_vocab = get_tag_vocabulary(df_train[col])
    
    mtg_tags = set(mtg_vocab.keys())
    train_tags = set(train_vocab.keys())
    
    common_tags = mtg_tags & train_tags
    mtg_counts = df[col].apply(len)
    train_counts = df_train[col].apply(len)
    
    comparison_data.append({
        'Category': col.replace('_tags', '').title(),
        'MusicCaps Samples': len(df_train),
        'MTG Samples': len(df),
        'MC Vocab Size': len(train_tags),
        'MTG Vocab Size': len(mtg_tags),
        'Common Tags': len(common_tags),
        'MC Avg Tags': f"{train_counts.mean():.2f}",
        'MTG Avg Tags': f"{mtg_counts.mean():.2f}",
        'Jaccard Sim': f"{len(common_tags) / len(mtg_tags | train_tags):.3f}"
    })

comparison_df = pd.DataFrame(comparison_data)
print("\n" + comparison_df.to_string(index=False))
