In [None]:
import os
import json
import ast
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from scipy import stats
from datasets import load_dataset
from pathlib import Path

# Global plot styling used by every figure in this notebook.
sns.set_style('whitegrid')
sns.set_context('notebook', font_scale=1.4)

# Explicit figure/font sizes for readability; applied after set_context so
# these values are the ones left in rcParams.
rc_overrides = {
    'figure.figsize': (18, 8),
    'font.size': 32,
    'axes.titlesize': 28,
    'axes.labelsize': 25,
    'xtick.labelsize': 23,
    'ytick.labelsize': 23,
    'legend.fontsize': 21,
}
plt.rcParams.update(rc_overrides)

## Load MusicCaps dataset

In [None]:
# Load MusicCaps and convert its 'train' split to a DataFrame.
ds = load_dataset("google/MusicCaps")
df_train = ds['train'].to_pandas()

# Parse every 'aspect_list' entry with ast.literal_eval and keep the parsed
# value in a separate column.
df_train['aspect_list_transformed'] = df_train['aspect_list'].map(ast.literal_eval)

### Analyse tag counts for concept extraction

In [None]:
# Tally how often each tag occurs across all rows.
tag_counts = {}
for aspects in df_train['aspect_list_transformed']:
    for tag in aspects:
        tag_counts[tag] = tag_counts.get(tag, 0) + 1

# One row per tag, most frequent first.
tag_counts_df = pd.DataFrame(
    list(tag_counts.items()),
    columns=['Tag', 'Count'],
).sort_values(by='Count', ascending=False)

In [None]:
# Bar chart of the 20 most frequent tags.
_tag_counts_df = tag_counts_df.head(20)
plt.figure(figsize=(18, 10))
# seaborn 0.13 deprecates passing `palette` without `hue`, so the x variable is
# also used as hue. dodge=False keeps a single full-width bar per tag.
ax = sns.barplot(
    data=_tag_counts_df,
    x='Tag',
    y='Count',
    hue='Tag',
    palette='viridis',
    dodge=False,
)
# Every bar is already labelled on the x axis; drop the legend if seaborn drew one.
legend = ax.get_legend()
if legend is not None:
    legend.remove()
plt.xticks(rotation=45)
plt.title('Tag Frequency in MusicCaps Dataset')
plt.xlabel('Tags')
plt.ylabel('Count')
plt.tight_layout()
plt.show()

In [None]:
# Write the full tag-frequency table to disk without the index column.
tag_counts_df.to_csv(
    path_or_buf="../data/musiccaps_tag_frequencies.csv",
    index=False,
)

## Prepare concept dataset

In [None]:
def extract_tags(song_tags, concept_tags):
    """Return the unique tags of a song that also appear in ``concept_tags``.

    Args:
        song_tags: Iterable of (hashable) tags attached to one song.
        concept_tags: Iterable of (hashable) tags that make up one concept.

    Returns:
        A list without duplicates, ordered by first occurrence in
        ``song_tags``. The order is deterministic, so repeated runs produce
        the same output.
    """
    # Set membership replaces the nested loop over both tag lists.
    wanted = set(concept_tags)
    # dict.fromkeys removes duplicates while keeping first-seen order
    # (a plain set() would yield an arbitrary order).
    return list(dict.fromkeys(tag for tag in song_tags if tag in wanted))

In [None]:
# Load the concept -> tag-list mapping. The context manager closes the file
# handle once parsing is done; JSON is read explicitly as UTF-8.
with open("../data/concepts_to_tags.json", "r", encoding="utf-8") as concepts_file:
    concepts = json.load(concepts_file)

In [None]:
# Delete every occurrence of the phrase 'low quality' (any letter case) from the captions.
df_train['caption'] = df_train['caption'].str.replace(
    pat='low quality',
    repl='',
    case=False,
)

In [None]:
# Add one '<concept>_tags' column per concept, holding the result of
# extract_tags for each row's parsed aspect list.
for concept, tags in concepts.items():
    column_name = f'{concept}_tags'
    df_train[column_name] = df_train['aspect_list_transformed'].map(
        lambda song_tags, wanted=tags: extract_tags(song_tags, wanted)
    )
df_train

In [None]:
# select rows with tags in at least 3 categories
def count_nonempty_tags(row):
    """Count how many of the four concept tag columns are non-empty in ``row``.

    ``row`` is indexed by column name; a column counts when its value is truthy.
    """
    concept_columns = ('tempo_tags', 'genre_tags', 'mood_tags', 'instrument_tags')
    return sum(1 for column in concept_columns if row[column])
# Keep only rows with tags in at least three concept categories, then renumber the index.
nonempty_category_counts = df_train.apply(count_nonempty_tags, axis=1)
df_train = df_train[nonempty_category_counts >= 3].reset_index(drop=True)
df_train

In [None]:
# Reduce the frame to the caption plus the per-concept tag lists and turn
# every list into a comma-separated string.
concept_columns = ["tempo_tags", "genre_tags", "mood_tags", "instrument_tags"]

# .copy() gives an independent frame, so the column assignments below do not
# raise pandas' SettingWithCopyWarning.
df_train = df_train[["caption", *concept_columns]].copy()

# 'aspect_list' is every concept's tags concatenated in concept_columns order.
df_train["aspect_list"] = [
    ', '.join(tag for tags in row_lists for tag in tags)
    for row_lists in zip(*(df_train[col] for col in concept_columns))
]

# Flatten each per-concept list in the same way.
for col in concept_columns:
    df_train[col] = df_train[col].map(', '.join)

df_train = df_train[["caption", "aspect_list", *concept_columns]]
df_train

In [None]:
# Write the caption/tag table to disk without the index column.
df_train.to_csv(
    path_or_buf="../data/musiccaps_tags_to_description_dataset.csv",
    index=False,
)

## Analyse dataset

In [None]:
# Reload the exported CSV, renumber the index and replace missing values
# with empty strings.
df_train = (
    pd.read_csv("../data/musiccaps_tags_to_description_dataset.csv")
    .reset_index(drop=True)
    .fillna('')
)
df_train

In [None]:
# Display basic statistics
tag_columns = ["tempo_tags", "genre_tags", "mood_tags", "instrument_tags", "aspect_list"]
for col in tag_columns:
    # ''.split(', ') returns [''] (length 1), so an empty cell has to be
    # counted as 0 tags explicitly.
    df_train[col + '_count'] = df_train[col].apply(
        lambda x: len(x.split(', ')) if x else 0
    )
display(df_train[[col + '_count' for col in tag_columns]].describe(percentiles=[0.25, 0.5, 0.75, 0.9, 0.95, 0.99]))

In [None]:
# calculate tag len distribution
# Empty strings are counted as 0 tags; ''.split(', ') would otherwise report 1.
tag_len_counts = (
    df_train['aspect_list']
    .map(lambda x: len(x.split(', ')) if x else 0)
    .value_counts()
    .sort_index()
)
fig, ax = plt.subplots()
ax.bar(tag_len_counts.index, tag_len_counts.values)
ax.set_xlabel("Number of Tags")
ax.set_ylabel("Number of Samples")
ax.set_title("Distribution of Tag Counts in Training dataset", fontweight='bold')
fig.savefig("../docs/assets/tag_count_distribution_musiccaps.pdf", bbox_inches='tight')
plt.show()

In [None]:
# Plot distribution of number of tags per category
tag_columns = ["tempo_tags", "genre_tags", "mood_tags", "instrument_tags"]
titles = ['Tempo Tags', 'Genre Tags', 'Mood Tags', 'Instrument Tags']
plt.figure(figsize=(20, 12))
for i, (col, title) in enumerate(zip(tag_columns, titles), 1):
    plt.subplot(2, 2, i)
    # An empty cell means "no tags": ''.split(', ') would return [''] and be
    # miscounted as one tag. A dedicated name also avoids overwriting the
    # tag_counts dictionary built earlier in the notebook.
    per_sample_counts = df_train[col].apply(lambda x: len(x.split(', ')) if x else 0)

    # Unit-width bins starting at 0 so samples without tags get their own bar;
    # tick labels are placed at the bin centres.
    n, bins, patches = plt.hist(
        per_sample_counts,
        bins=range(0, per_sample_counts.max() + 2),
        edgecolor='black',
    )
    bin_centers = (bins[:-1] + bins[1:]) / 2
    plt.xticks(bin_centers, [int(x) for x in bin_centers], rotation=0)

    plt.title(title, fontweight='bold')
    plt.xlabel('Number of Tags')
    plt.ylabel('Frequency')
plt.suptitle('Distribution of Number of Tags per Category\nin Training Dataset', fontweight='bold')
# Leave room at the top of the figure for the two-line suptitle.
plt.tight_layout(rect=[0, 0.01, 1, 0.95])
plt.savefig("../docs/assets/tag_count_distribution_per_category.pdf")
plt.show()

In [None]:
# Histogram of caption lengths, measured in whitespace-separated words.
df_train['caption_length'] = df_train['caption'].map(lambda caption: len(caption.split()))

fig, ax = plt.subplots(figsize=(10, 6))
sns.histplot(df_train['caption_length'], bins=30, kde=True, color='skyblue', ax=ax)
ax.set_title('Distribution of Caption Lengths\nin Training dataset', fontweight='bold')
ax.set_xlabel('Number of Words')
ax.set_ylabel('Frequency')
fig.tight_layout()
fig.savefig("../docs/assets/caption_length_distribution_musiccaps.pdf", bbox_inches='tight')
plt.show()

## Save final dataset

In [None]:
from sklearn.model_selection import train_test_split

# Hold out 10% of the rows, then halve the hold-out into validation and test.
# Both splits use the same fixed seed.
df_train, df_holdout = train_test_split(df_train, test_size=0.1, random_state=42)
df_valid, df_test = train_test_split(df_holdout, test_size=0.5, random_state=42)

In [None]:
# Make sure the output directory exists.
output_dir = Path("../data/distilled-musiccaps")
output_dir.mkdir(parents=True, exist_ok=True)

# Write one CSV per split ...
splits = {"train": df_train, "validation": df_valid, "test": df_test}
for split_name, split_df in splits.items():
    split_df.to_csv(output_dir / f"{split_name}.csv", index=False)

# ... plus one file with all splits concatenated (train, validation, test order).
all_df = pd.concat(list(splits.values()))
all_df.to_csv(output_dir / "all.csv", index=False)

In [None]:
# Map each split name to its CSV file, load the files as a dataset and push
# it to the Hub as a private repository.
data_files = {
    split: str(output_dir / f"{split}.csv")
    for split in ("train", "validation", "test")
}
dataset = load_dataset("csv", data_files=data_files)
dataset.push_to_hub("bsienkiewicz/distilled-musiccaps", private=True)