In [1]:
import pandas as pd
import numpy as np
from pathlib import Path
from collections import Counter, defaultdict
import itertools
import random
from datasets import load_dataset

## Load Existing MusicCaps Data

Load the existing tags dataset to analyze tag co-occurrence patterns.

In [2]:
ds = load_dataset("google/MusicCaps")
df = ds['train'].to_pandas()
df['aspect_list_transformed'] = df['aspect_list'].apply(lambda x: x.strip("[]").replace("'", ""))
df['aspect_list_transformed'] = df['aspect_list_transformed'].apply(lambda x: x.split(', '))

README.md: 0.00B [00:00, ?B/s]

musiccaps-public.csv:   0%|          | 0.00/2.94M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/5521 [00:00<?, ? examples/s]

## Define Tag Categories

Define all possible tags for each category based on the dataset.

In [4]:
tempo_tags = [
    "fast tempo", "medium tempo", "slow tempo", "moderate tempo", "uptempo",
    "medium fast tempo", "slower tempo", "medium to uptempo", "mid-tempo",
    "quick tempo", "accelerated tempo", "steady tempo", "rapid tempo",
    "slow music", "very fast tempo", "slow to medium tempo", "medium-to-high pitch singing",
    "steady drumming rhythm", "dance rhythm", "various tempos", "tempo changes",
    "fast paced", "slow song", "mid tempo", "steady beat", "pulsating beats",
    "groovy rhythm", "4 on the floor kick pattern", "normal tempo", "fast beat"
]

genre_tags = [
    "rock", "pop", "jazz", "classical", "folk", "blues", "hip hop", "reggae",
    "metal", "country", "r&b", "edm", "trance", "techno", "dance music",
    "electronic dance music", "gospel", "ambient", "soul", "funk",
    "alternative rock", "ballad", "hip-hop", "techno pop", "world music",
    "disco", "trap", "punk rock", "latin pop", "house", "bluegrass",
    "indie rock", "new age", "grunge", "industrial", "dubstep",
    "carnatic music", "bossa nova", "baroque music", "surf rock",
    "ska", "lo-fi", "symphonic", "orchestral", "fusion music", "raga",
    "bollywood music", "afrobeat", "folk song", "christian rock", "soundtrack"
]

mood_tags = [
    "emotional", "passionate", "happy", "melancholic", "relaxing", "calming",
    "upbeat", "exciting", "mellow", "sentimental", "soothing", "joyful",
    "intense", "peaceful", "dreamy", "romantic mood", "ominous", "suspenseful",
    "haunting", "energetic", "chill", "cheerful", "nostalgic", "fun",
    "cool", "ethereal", "sad", "spooky", "hopeful", "playful",
    "mystical", "dark", "solemn", "festive", "inspirational", "sentimental",
    "powerful", "serene", "mysterious", "emphatic", "tranquil", "passionate singing",
    "ominous music", "romantic", "meditative", "joyous", "heartfelt", "uplifting",
    "enthusiastic", "melancholy", "emotional voice", "soothing melody", "heavenly", 
    "fearful", "vibrant", "soulful", "excited", "energetic drums", "charming"
]

instrument_tags = [
    "piano", "drums", "guitar", "bass guitar", "electric guitar", "acoustic guitar",
    "flute", "violin", "cello", "trumpet", "saxophone", "tambourine",
    "synth", "harmonica", "organ", "harp", "clarinet", "string section",
    "percussion", "banjo", "trombone", "didgeridoo", "mandolin", "tabla",
    "ukulele", "accordion", "xylophone", "viola", "timpani", "congas",
    "bongo", "triangle", "oboe", "bagpipes", "steel drums", "marimba",
    "dj mixer", "drum machine", "brass section", "horn", "sitar",
    "strings", "keyboard", "double bass", "synth bass", "guitar solo",
    "electric piano", "acoustic piano", "woodwind", "cymbals", "bells",
    "vibraphone", "hand claps", "snare", "hi-hat", "kick drum", 
    "conga", "tabla percussion", "theremin", "church organ", "trumpets",
    "bass drum", "djembe", "steel guitar", "harpsichord", "choir"
]

In [5]:
def extract_tags(song_tags, concept_tags):
    res = []
    for c_tag in concept_tags:
        for s_tag in song_tags:
            if c_tag in s_tag:
                res.append(s_tag)
    return list(set(res))

In [6]:
concepts = {
    "tempo": tempo_tags,
    "genre": genre_tags,
    "mood": mood_tags,
    "instrument": instrument_tags
}

for concept, tags in concepts.items():
    df[concept + '_tags'] = df['aspect_list_transformed'].apply(
        lambda x: extract_tags(x, tags)
    )

In [None]:
# Add new cell after tag statistics analysis

# Calculate conditional probabilities (co-occurrence)
def build_conditional_probs(df, category_cols):
    """Build conditional probability matrix for tags.
    
    P(tag_j | tag_i) = count(tag_i AND tag_j) / count(tag_i)
    """
    conditional_probs = defaultdict(lambda: defaultdict(float))
    tag_counts = defaultdict(int)
    
    for category in category_cols:
        for _, row in df.iterrows():
            tags = row[category]
            if not tags:
                continue
                
            # Count co-occurrences
            for tag_i in tags:
                tag_counts[tag_i] += 1
                for tag_j in tags:
                    if tag_i != tag_j:
                        conditional_probs[tag_i][tag_j] += 1
    
    # Normalize to probabilities
    for tag_i in conditional_probs:
        total = tag_counts[tag_i]
        for tag_j in conditional_probs[tag_i]:
            conditional_probs[tag_i][tag_j] /= total
    
    return dict(conditional_probs), dict(tag_counts)

# Build conditional probability matrices
print("Building conditional probability matrices...")
category_cols = ['instrument_tags', 'mood_tags', 'genre_tags', 'tempo_tags']
conditional_probs, tag_counts = build_conditional_probs(df, category_cols)

print(f"Built conditional probabilities for {len(conditional_probs)} tags")

# Example: show what tags commonly appear with "piano"
if "piano" in conditional_probs:
    print("\nTop tags that appear with 'piano':")
    piano_cooccur = sorted(conditional_probs["piano"].items(), 
                          key=lambda x: x[1], reverse=True)[:10]
    for tag, prob in piano_cooccur:
        print(f"  {tag}: {prob:.3f}")

Building conditional probability matrices...
Built conditional probabilities for 3995 tags

Top tags that appear with 'piano':
  acoustic drums: 0.238
  bass guitar: 0.146
  strings: 0.113
  e-guitar: 0.079
  acoustic guitar: 0.059
  percussion: 0.059
  saxophone: 0.054
  electric guitar: 0.050
  electronic drums: 0.046
  drums: 0.046
  violin: 0.046
  trumpet: 0.042
  flute: 0.042
  trumpets: 0.038
  digital drums: 0.033
  simple percussion: 0.033
  double bass: 0.029
  no percussion: 0.025
  organ: 0.025
  keyboard: 0.025
  brass section: 0.025
  trombone: 0.025
  guitar: 0.025
  synth: 0.021
  french horn: 0.021
  latin percussion: 0.021
  clarinet: 0.021
  choir: 0.017
  synth strings: 0.017
  cello: 0.017


## Load MTG Jaemdo dataset as a baseline

In [20]:
out_buffer = []

with open("autotagging_top50tags.tsv", "r") as f:
    for line in f.readlines():
        strings = line.strip().split('\t')
        track_id = strings[0]
        tags = strings[5:]  # Assuming tags start from the 6th column
        out_buffer.append({
            "id": track_id,
            "tags": tags
        })
with open("autotagging_top50tags_processed.csv", "w") as f:
    f.write("id,tags\n")
    for item in out_buffer:
        f.write(f"{item['id']},\"{';'.join(item['tags'])}\"\n")

In [25]:
mtg_df = pd.read_csv("autotagging_top50tags_processed.csv", converters={
    'tags': lambda x: x.split(';')
})
mtg_df

Unnamed: 0,id,tags
0,TRACK_ID,[TAGS]
1,track_0000215,[genre---metal]
2,track_0000216,[genre---metal]
3,track_0000219,[genre---metal]
4,track_0000223,[genre---metal]
...,...,...
54376,track_1422056,"[genre---soundtrack, instrument---computer]"
54377,track_1422057,"[genre---soundtrack, instrument---computer]"
54378,track_1422058,"[genre---soundtrack, instrument---computer]"
54379,track_1422059,"[genre---soundtrack, instrument---computer]"


In [26]:
# Parse tags into categories
def parse_tags(tags):
    genre = extract_tags(tags, genre_tags)
    mood = extract_tags(tags, mood_tags)
    instrument = extract_tags(tags, instrument_tags)
    return pd.Series({
        'genre_tags': genre,
        'mood_tags': mood,
        'instrument_tags': instrument
    })

mtg_df[['genre_tags', 'mood_tags', 'instrument_tags']] = mtg_df['tags'].apply(parse_tags)

In [27]:
mtg_df = mtg_df.where((mtg_df['mood_tags'].map(len) > 0) & (mtg_df['genre_tags'].map(len) > 0) & (mtg_df['instrument_tags'].map(len) > 0)).dropna()

In [None]:
# Clean the tag names
def clean_tag(tag):
    return tag.split('---')[-1].strip()
mtg_df['genre_tags'] = mtg_df['genre_tags'].apply(lambda tags: [clean_tag(t) for t in tags])
mtg_df['mood_tags'] = mtg_df['mood_tags'].apply(lambda tags: [clean_tag(t) for t in tags])
mtg_df['instrument_tags'] = mtg_df['instrument_tags'].apply(lambda tags: [clean_tag(t) for t in tags])
mtg_df['tags'] = mtg_df.apply(lambda row: row['genre_tags'] + row['mood_tags'] + row['instrument_tags'], axis=1)
mtg_df

In [31]:
mtg_df.to_csv("autotagging_top50tags_processed_cleaned.csv", index=False)

## Create new dataset with shuffled tags

In [32]:
def sample_with_temperature(items, probs, temperature=1.0):
    """Apply temperature to probability distribution and sample."""
    if not items:
        return None
    
    prob_array = np.array([probs.get(item, 1e-10) for item in items])
    prob_array = prob_array ** (1 / temperature)
    prob_array = prob_array / prob_array.sum()
    
    return np.random.choice(items, p=prob_array)

In [48]:
def generate_sample_causal(
    existing_row,
    variety_factor=0.5,
    base_temp=1.0,
):
    """Generate a sample with causal (conditional) tag selection.
    
    Args:
        variety_factor: Controls randomness (0 = deterministic, 1 = very random)
        base_temp: Base temperature for sampling
    """
    temp = base_temp + variety_factor * 2.0
    
    selected_instruments = []
    for instrument in existing_row['instrument_tags']:
        sampled_tag = sample_with_temperature(
            list(conditional_probs.get(instrument, {}).keys()),
            conditional_probs.get(instrument, {}),
            temperature=temp
        )
        if sampled_tag:
            selected_instruments.append(sampled_tag)
    selected_instruments += existing_row['instrument_tags']
    
    selected_moods = []
    for mood in existing_row['mood_tags']:
        sampled_tag = sample_with_temperature(
            list(conditional_probs.get(mood, {}).keys()),
            conditional_probs.get(mood, {}),
            temperature=temp
        )
        if sampled_tag:
            selected_moods.append(sampled_tag)
    selected_moods += existing_row['mood_tags']

    selected_genres = []
    for genre in existing_row['genre_tags']:
        sampled_tag = sample_with_temperature(
            list(conditional_probs.get(genre, {}).keys()),
            conditional_probs.get(genre, {}),
            temperature=temp
        )
        if sampled_tag:
            selected_genres.append(sampled_tag)
    selected_genres += existing_row['genre_tags']

    # Select tempo tags (not in MTG dataset)
    selected_tempos = []
    sampled_tag = sample_with_temperature(
        tempo_tags,
        {tag: 1.0 for tag in tempo_tags},
        temperature=temp
    )
    selected_tempos.append(sampled_tag)
    for _ in range(random.randint(1, 3)):
        sampled_tag = sample_with_temperature(
            list(conditional_probs.get(sampled_tag, {}).keys()),
            conditional_probs.get(sampled_tag, {}),
            temperature=temp
        )
        if sampled_tag and sampled_tag not in selected_tempos:
            selected_tempos.append(sampled_tag)

    selected_tags = selected_instruments + selected_moods + selected_genres + selected_tempos
    
    return {
        'instrument_tags': ', '.join(selected_instruments),
        'mood_tags': ', '.join(selected_moods),
        'genre_tags': ', '.join(selected_genres),
        'tempo_tags': ', '.join(selected_tempos),
        'aspect_list': ', '.join(selected_tags)
    }

In [52]:
def generate_dataset_causal_tags(df, temperature=1.0):
    """Generate a new dataset with causal tags."""
    new_data = []
    for _, row in df.iterrows():
        new_row = generate_sample_causal(row, base_temp=temperature)
        new_row['id'] = row['id']
        new_data.append(new_row)
    return pd.DataFrame(new_data)

In [53]:
new_mtg_causal_df = generate_dataset_causal_tags(mtg_df, temperature=1.0)
new_mtg_causal_df

Unnamed: 0,instrument_tags,mood_tags,genre_tags,tempo_tags,aspect_list,id
0,"synthpop/rock, synthesizer",chillout,"regional pop, ambient",very fast tempo,"synthpop/rock, synthesizer, chillout, regional...",track_0000946
1,"acoustic drums, synthesizer",chillout,"house, techno","various tempos, groovy rhythm, latin medium tempo","acoustic drums, synthesizer, chillout, house, ...",track_0000950
2,"digital choir sample, synthesizer",chillout,"techno music, ambient","uptempo, uptempo hi hats","digital choir sample, synthesizer, chillout, t...",track_0000953
3,"finger cymbals, synthesizer",chillout,"techno, ambient","slow tempo, steady drumming rhythm","finger cymbals, synthesizer, chillout, techno,...",track_0000954
4,"drums, synthesizer",chillout,"acoustic jazz drums, jazz","slow tempo, steady drumming rhythm, groovy dan...","drums, synthesizer, chillout, acoustic jazz dr...",track_0000955
...,...,...,...,...,...,...
4052,"e-guitar, piano chords, electric guitar, synth...","happy song, chillout, happy","disco music, house",normal tempo,"e-guitar, piano chords, electric guitar, synth...",track_1420700
4053,"piano/opera, electric guitar, organ legends, s...","playful music, happy","disco music, house","uptempo, fast paced, uptempo hi hats","piano/opera, electric guitar, organ legends, s...",track_1420701
4054,"e-guitar, guitar, guitar solo, synthesizer, dr...","synth funk, mellow bells melody, funk, happy","funky guitar playing, funk","fast paced, uptempo, fast paced percussion, tr...","e-guitar, guitar, guitar solo, synthesizer, dr...",track_1420702
4055,"synthpop/rock, electric guitar, snare drum, sy...","passionate girls vocals, chillout, happy","techno, house","uptempo, fast paced percussion","synthpop/rock, electric guitar, snare drum, sy...",track_1420707


In [51]:
# Analyze tag distributions in the new dataset

categories = ['instrument_tags', 'mood_tags', 'genre_tags', 'tempo_tags']

for category in categories:
    all_tags = list(itertools.chain.from_iterable(new_mtg_causal_df[category].apply(lambda x: x.split(', '))))
    tag_counter = Counter(all_tags)
    most_common_tags = tag_counter.most_common(10)
    print(f"Most common tags in category '{category}':")
    for tag, count in most_common_tags:
        print(f"  {tag}: {count}")
    print()

Most common tags in category 'instrument_tags':
  piano: 1838
  synthesizer: 1702
  drums: 1558
  electricguitar: 865
  keyboard: 852
  guitar: 832
  violin: 492
  electricpiano: 388
  strings: 328
  acousticguitar: 327

Most common tags in category 'mood_tags':
  chillout: 1354
  emotional: 778
  relaxing: 681
  happy: 664
  energetic: 641
  funk: 613
  fun: 78
  passionate: 74
  upbeat: 65
  enthusiastic: 61

Most common tags in category 'genre_tags':
  ambient: 1048
  soundtrack: 956
  pop: 815
  classical: 739
  funk: 626
  jazz: 562
  rock: 529
  house: 372
  folk: 235
  orchestral: 205

Most common tags in category 'tempo_tags':
  fast tempo: 797
  medium tempo: 793
  groovy rhythm: 558
  steady drumming rhythm: 510
  dance rhythm: 483
  uptempo: 425
  slow tempo: 341
  quick tempo: 296
  fast beat: 296
  medium fast tempo: 208



In [56]:
train_df = new_mtg_causal_df.sample(frac=0.8, random_state=42)
val_df = new_mtg_causal_df.drop(train_df.index)
test_df = val_df.sample(frac=0.5, random_state=42)
val_df = val_df.drop(test_df.index)

In [2]:
output_dir = Path("../data/mtg_causal_tags")
output_dir.mkdir(parents=True, exist_ok=True)

In [None]:
train_df.to_csv(output_dir / "train.csv", index=False)
val_df.to_csv(output_dir / "validation.csv", index=False)
test_df.to_csv(output_dir / "test.csv", index=False)

In [3]:
data_files = {
    "train": str(output_dir / "train.csv"),
    "validation": str(output_dir / "validation.csv"),
    "test": str(output_dir / "test.csv")
}
dataset = load_dataset("csv", data_files=data_files)
dataset.push_to_hub("bsienkiewicz/mtg_causal_tags_dataset", private=True)

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ? shards/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ? shards/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ? shards/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

CommitInfo(commit_url='https://huggingface.co/datasets/bsienkiewicz/mtg_causal_tags_dataset/commit/c83f15ed52a0c1aab0882ebc2c4fe4defc9d1d64', commit_message='Upload dataset', commit_description='', oid='c83f15ed52a0c1aab0882ebc2c4fe4defc9d1d64', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/bsienkiewicz/mtg_causal_tags_dataset', endpoint='https://huggingface.co', repo_type='dataset', repo_id='bsienkiewicz/mtg_causal_tags_dataset'), pr_revision=None, pr_num=None)