# Stratified data splitting

## Problem statement

Each recording has metadata:
- pianist (20 classes)
- setting ∈ {solo, trio}
- album_id (group)
- composition_id (group)

And we need:
- 8/1/1 ratio
- Solo/trio proportion preserved
- Each pianist in every split
- No album leakage
- No composition leakage
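
One way to formalize these requirements, mirroring the `lp_group_split` function defined later in this notebook, is as a mixed-integer program. The symbols are introduced here for illustration only: $G$ is the set of groups (albums or compositions), $S = \{\text{train}, \text{val}, \text{test}\}$ with target fractions $f_s$, $n_g$ is the number of recordings in group $g$, $n_{g,\ell}$ the number of those with setting $\ell$, $N = \sum_g n_g$, $\pi_\ell$ the global share of setting $\ell$, and $G_p$ the groups containing pianist $p$.

```latex
\begin{aligned}
\min_{x,\,d,\,e}\quad
  & \sum_{s \in S} d_s + \sum_{s \in S} \sum_{\ell} e_{s,\ell} \\
\text{s.t.}\quad
  & \sum_{s \in S} x_{g,s} = 1
      && \forall g \in G \\
  & \Bigl| \sum_{g \in G} n_g \, x_{g,s} - f_s N \Bigr| \le d_s
      && \forall s \in S \\
  & \Bigl| \sum_{g \in G} n_{g,\ell} \, x_{g,s} - f_s \pi_\ell N \Bigr| \le e_{s,\ell}
      && \forall s \in S,\ \forall \ell \\
  & \sum_{g \in G_p} x_{g,s} \ge 1
      && \forall p,\ \forall s \in S \\
  & x_{g,s} \in \{0, 1\}, \quad d_s \ge 0, \quad e_{s,\ell} \ge 0
\end{aligned}
```

Because each group is assigned to exactly one split, no album (or composition) can appear in two splits. Note that the notebook produces one split per grouping key (one by album, one by composition) rather than enforcing both at once.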

In [42]:
import json
import re
import math
from pathlib import Path

import networkx as nx
import pandas as pd
import pulp
import unidecode
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

from deep_pianist_identification import utils

In [3]:
SIMILARITY_THRESH = 0.9


## Metadata curation

Given that the track names in JTD and PiJAMA are not resolved to their unique compositions (e.g., "A Night in Tunisia", "Night in Tunisia", "An Night in Tunisia (Alternative Take)" can all appear!), we use a human-in-the-loop pipeline to resolve these:
- Grab "raw" track names from JTD and PiJAMA, convert these to lowercase, and remove punctuation
- Clean up "junk" words and phrases with hardcoded regular expressions (e.g., 'Alternate Take', 'Interrupted', 'Live at the Maybeck Recital Hall')
    - Already, these two stages reduce the number of unique compositions from over 3000 to 2450
- Next, we map "contrafact" pieces onto their original composition, using a list of 450 composition pairs curated manually by a professional jazz musician
    - E.g., the Sonny Rollins tune "Oleo" is mapped onto the Gershwin composition "I Got Rhythm", which uses the same chords
    - We chose to do this as the performances in JTD are "solos only", meaning the melody (the only thing different between "Oleo" and "I Got Rhythm") is likely not present
- Using a SentenceBERT model pretrained on 1B sentences, we extract embeddings for every cleaned title
    - We then compute the pairwise cosine similarity between all titles, and use these scores to construct a graph: when the cosine similarity of two titles exceeds a threshold, we connect them with an edge.
    - We tune the threshold manually and set it to 0.9 following preliminary experiments
- After this process, we are left with 2348 unique compositions
    - We then manually correct any remaining misidentified compositions.
        - These are usually the results of spelling mistakes in the original track listings that were not resolved by the SentenceBERT clustering, e.g. Miserioso -> Misterioso (Monk tune), Sloliloquy -> Soliloquy.
    - This leads to a final canonic list of 2216 unique compositions

This process is significantly more advanced than the one followed previously by Edwards et al. (2023) to map compositions in PiJAMA to jazz standards, which used only a small number of hardcoded regular expressions. We release our fine-grained annotations of both the complete JTD and PiJAMA datasets to the community as part of this paper, and anticipate that they may find use in cover song and jazz standard identification work.

In [4]:
# load up all metadata JSON files
# rglob already searches recursively, so no leading "**/" is needed
js_files = list((Path(utils.get_project_root()) / "data").rglob("*.json"))


In [10]:
# read all JSONs, create metadata df
all_metadata = []
columns_to_keep = ["track_name", "album_name", "bandleader", "pianist", "recording_year", "mbz_id"]
for js_file in js_files:
    with open(js_file) as f:
        metadata = json.load(f)
    res = {k: v for k, v in metadata.items() if k in columns_to_keep}
    all_metadata.append(res)
df = pd.DataFrame(all_metadata)

### Remove "junk"

Start by removing junk words/phrases and fillers that never belong to the core title of a track

In [11]:
META_WORDS = {
    "take", "alt", "alternate", "version", "edit",
    "live", "remastered", "remaster", "mono", "stereo",
    "demo", "session", "rehearsal", "outtake",
    "bonus", "track", "disc", "cd",
    "instrumental", "vocal",
    "complete", "interrupted",
    "excerpt", "outro",
    "reprise", "medley", "intro", "recorded",
    "sequence", "first", "second"
}
ARTICLES = {"a", "an", "the"}


def remove_junk(title: str) -> str:
    """
    We need to remove 'junk' information from titles, like take information, 'remastered', etc.
    """
    # 1. Basic cleanup
    t = title.lower()
    t = unidecode.unidecode(t)
    # Normalize quotes/apostrophes
    t = t.replace("’", "'").replace("`", "'")
    # Replace "&" with "and"
    t = t.replace("&", "and").replace("+", "and")

    # 2. Remove parentheticals and brackets
    # (live at ...) [alt take] {demo}
    t = re.sub(r"[\(\[\{].*?[\)\]\}]", " ", t)

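    # If every remaining token is numeric (e.g. a title that is just a year),
    # return early so the date/number filters below do not erase the title.
    # Note this also triggers when nothing is left after removing parentheticals.
    if all(t_.isdigit() for t_ in t.split()):
        return t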

    # 3. Remove common recording annotations
    patterns = [
        # Take numbers
        r"\btake\s*\d+\b",
        r"\btk\s*\d+\b",
        # Alternate versions
        r"\balt(ernate)?\b",
        r"\bversion\b",
        r"\bedited?\b",
        # Live / venue
        r"\blive\b.*$",
        # r"\bat\s+[a-z0-9\s]+\b",
        # Dates
        r"\b\d{4}\b",
        r"\b\d{1,2}[/-]\d{1,2}[/-]\d{2,4}\b",
        # Disc / track numbers
        r"\b(cd|disc|track)\s*\d+\b",
        # Remaster info
        r"\b(remaster(ed)?|mono|stereo)\b",
        # Session info
        r"\b(session|rehearsal|outtake|demo)\b",
        # Arrangement info
        r"\b(instrumental|vocal)\b",
        # From soundtrack / album
        # r"\bfrom\s+.*$"
    ]
    for p in patterns:
        t = re.sub(p, " ", t)

    # 4. Remove punctuation
    t = re.sub(r"[^\w\s]", " ", t)

    # 5. Token cleanup
    tokens = t.split()
    cleaned = []
    for w in tokens:
        # Remove pure numbers
        if w.isdigit():
            continue
        # Remove articles
        # if w in ARTICLES:
        #     continue
        # Remove meta words
        if w in META_WORDS:
            continue
        cleaned.append(w)

    # 6. Rejoin and collapse spaces
    t = " ".join(cleaned)

    if t == "":
        t = re.sub(r"[^\w\s]", " ", title.lower())
        return re.sub(r"\s+", " ", t).strip()

    return re.sub(r"\s+", " ", t).strip()
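
To make two of the patterns used in `remove_junk` more concrete, here is a minimal standalone sketch that applies them in isolation. The helper names and example titles are illustrative and not part of the pipeline. It also shows a side effect worth keeping in mind: the "live" pattern drops everything from the word "live" to the end of the title.

```python
import re

# Same patterns as in remove_junk, isolated for illustration
PARENTHETICAL = re.compile(r"[\(\[\{].*?[\)\]\}]")
LIVE_SUFFIX = re.compile(r"\blive\b.*$")


def collapse(text: str) -> str:
    """Collapse runs of whitespace and trim the ends."""
    return re.sub(r"\s+", " ", text).strip()


def strip_parentheticals(title: str) -> str:
    """Drop anything in (), [] or {} (hypothetical helper)."""
    return collapse(PARENTHETICAL.sub(" ", title))


def strip_live_suffix(title: str) -> str:
    """Drop 'live' and everything after it (hypothetical helper)."""
    return collapse(LIVE_SUFFIX.sub(" ", title))


no_parens = strip_parentheticals("a night in tunisia (alternate take)")
no_brackets = strip_parentheticals("blue monk [live] {demo}")
truncated = strip_live_suffix("long live the king")
untouched = strip_live_suffix("alive and well")

print(no_parens)    # a night in tunisia
print(no_brackets)  # blue monk
print(truncated)    # long
print(untouched)    # alive and well
```

Titles that genuinely contain the word "live" are therefore truncated at this stage, which is one reason the later manual pass is useful.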


In [12]:
dejunked = {t: remove_junk(t) for t in df["track_name"].unique()}

In [13]:
# attach the dejunked title to every recording
df["dejunked_track_name"] = df["track_name"].map(dejunked)

In [15]:
print(df["dejunked_track_name"].nunique(), df["track_name"].nunique())

2450 3044


### Map contrafacts to original compositions

We have a list of jazz contrafacts saved in `references/jazz_contrafacts.csv`: we can map compositions using this

In [6]:
contrafacts_path = Path(utils.get_project_root()) / "references/jazz_contrafacts.csv"
contrafacts_df = pd.read_csv(contrafacts_path)

# apply dejunking
contrafacts_df["contrafact_dj"] = contrafacts_df["Contrafact"].apply(remove_junk)
contrafacts_df["original_dj"] = contrafacts_df["Original Song"].apply(remove_junk)

contrafacts_dj = contrafacts_df[["contrafact_dj", "original_dj"]]
contrafacts_mapping = {c: d for c, d in contrafacts_dj.to_numpy()}

In [7]:
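# replace contrafacts in the dejunked mapping with their original composition
for k, c in dejunked.items():
    if c in contrafacts_mapping:
        dejunked[k] = contrafacts_mapping[c]

# refresh the dataframe column so the contrafact mapping reaches later cells
df["dejunked_track_name"] = df["track_name"].map(dejunked)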
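
As a small illustration of the lookup above, the sketch below resolves a toy mapping using the "Oleo" example from earlier in this notebook. The dictionaries here are made up for the example and stand in for `dejunked` and `contrafacts_mapping`.

```python
# Toy stand-ins for the notebook's dictionaries (illustrative values only)
toy_contrafacts = {"oleo": "i got rhythm"}
toy_dejunked = {
    "Oleo (Take 2)": "oleo",
    "I Got Rhythm": "i got rhythm",
    "Misterioso": "misterioso",
}

# Titles that are known contrafacts are replaced by the original composition;
# everything else is passed through unchanged
toy_resolved = {
    raw: toy_contrafacts.get(clean, clean) for raw, clean in toy_dejunked.items()
}

print(toy_resolved["Oleo (Take 2)"])  # i got rhythm
print(toy_resolved["Misterioso"])     # misterioso
```

Building a new dictionary rather than updating in place keeps the pre-contrafact titles available for inspection, at the cost of holding two copies.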


### Compute similarity

Use a pretrained SentenceBERT model to compute the similarity between composition titles. We build a graph in which two titles are joined by an edge whenever their cosine similarity exceeds `SIMILARITY_THRESH`; each connected component is then treated as one composition.

In [9]:
bert = SentenceTransformer("all-MiniLM-L6-v2")
titles = df["dejunked_track_name"].tolist()
embeddings = bert.encode(titles, normalize_embeddings=True)





In [10]:
sim = cosine_similarity(embeddings)

G = nx.Graph()
for i in range(len(titles)):
    for j in range(i + 1, len(titles)):
        if sim[i, j] > SIMILARITY_THRESH:
            G.add_edge(i, j)

clusters = list(nx.connected_components(G))
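
One property of connected components is that merging is transitive: two titles can land in the same cluster without exceeding the threshold themselves, as long as a chain of sufficiently similar titles links them. The sketch below illustrates this with a small union-find on a hand-written similarity matrix. The scores are invented for the example, and the function is a dependency-free stand-in for the `networkx` call above, not the code used in the pipeline.

```python
def threshold_components(sim, thresh):
    """Group indices whose similarity graph (edges where sim > thresh) is connected.

    Only components with at least two members are returned, matching a graph
    that is built purely from edges.
    """
    n = len(sim)
    parent = list(range(n))

    def find(i):
        # Follow parent pointers to the root, halving the path as we go
        while parent[i] != i:
            parent[i] = parent[parent[i]]
            i = parent[i]
        return i

    for i in range(n):
        for j in range(i + 1, n):
            if sim[i][j] > thresh:
                parent[find(i)] = find(j)

    groups = {}
    for i in range(n):
        groups.setdefault(find(i), set()).add(i)
    return [members for members in groups.values() if len(members) > 1]


# Hypothetical scores: titles 0 and 2 are only linked through title 1
toy_sim = [
    [1.00, 0.95, 0.85, 0.10],
    [0.95, 1.00, 0.92, 0.12],
    [0.85, 0.92, 1.00, 0.11],
    [0.10, 0.12, 0.11, 1.00],
]
toy_clusters = threshold_components(toy_sim, thresh=0.9)
print(toy_clusters)  # [{0, 1, 2}]
```

Title 3 has no edge and so belongs to no cluster, which is why unclustered titles need a fallback when the cluster names are mapped back onto the dataframe.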

Now we can map compositions back onto their named entity

In [11]:
cluster_mapping = {}
for clust in clusters:
    resolved = df["dejunked_track_name"].iloc[list(clust)].to_list()
    # use the shortest title as the correct one
    smallest = min(resolved, key=len)
    for r in resolved:
        cluster_mapping[r] = smallest

df["clustered_track_name"] = df["dejunked_track_name"].map(cluster_mapping).fillna(df["dejunked_track_name"])
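
The "shortest title" rule can be sketched on its own as follows. One assumption is added here that is not in the cell above: ties in length are broken alphabetically, so that the chosen representative does not depend on the order in which cluster members are listed. The example titles are illustrative.

```python
def pick_representative(titles):
    """Shortest title wins; equal-length titles are ordered alphabetically.

    The alphabetical tie-break is an assumption made for this sketch.
    """
    return min(titles, key=lambda title: (len(title), title))


toy_cluster = ["a night in tunisia", "night in tunisia", "an night in tunisia"]
toy_mapping = {title: pick_representative(toy_cluster) for title in toy_cluster}

print(toy_mapping["an night in tunisia"])  # night in tunisia

# With equal lengths, the alphabetically first title is chosen
tied = pick_representative(["blue monk", "blue mink"])
print(tied)  # blue mink
```

The second example also shows the limit of this heuristic: a misspelled title can win, which the manual correction step below is meant to catch.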

### Human-in-the-loop

Now, we look through this dataframe and resolve any remaining track names by hand. The result is a dataframe that we have already saved as `references/canonical_track_names.csv`.

Finally, we iterate through this DF and update the `metadata.json` for each track

In [4]:
# We've already done this, so update the metadata with the canonic track names
canonic_df = pd.read_csv(Path(utils.get_project_root()) / "references/canonical_track_names.csv")
canonic_mapping = {k: v for (k, v) in canonic_df[["mbz_id", "final_track_name"]].to_numpy()}

In [18]:
for i, js in enumerate(js_files):
    with open(js, "r") as js_in:
        js_read = json.load(js_in)

    js_read["composition_canonic"] = canonic_mapping[js_read["mbz_id"]]

    with open(js, "w") as js_out:
        json.dump(js_read, js_out, indent=4, ensure_ascii=False)
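
Since the loop above rewrites files in place and indexes `canonic_mapping` directly, a missing ID would raise a `KeyError` partway through and leave only some files updated. A minimal pre-flight check, assuming nothing beyond plain Python, could look like the sketch below. The function name and the toy IDs are hypothetical.

```python
def find_unmapped(ids, mapping):
    """Return the IDs (in input order, without repeats) that have no entry in mapping."""
    seen, missing = set(), []
    for id_ in ids:
        if id_ not in mapping and id_ not in seen:
            seen.add(id_)
            missing.append(id_)
    return missing


# Toy data standing in for the mbz_id values and canonic_mapping
toy_canonic = {"id-1": "misterioso", "id-2": "soliloquy"}
toy_ids = ["id-1", "id-3", "id-2", "id-3"]

missing = find_unmapped(toy_ids, toy_canonic)
print(missing)  # ['id-3']
```

Running such a check over all `mbz_id` values first, and only writing once the result is empty, keeps the metadata files in a consistent state.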

## Stratified Splitting

First, load up the data to be split. We only want to keep the top 20 pianists, again.

In [5]:
DESIRED_PIANISTS = [
    "Abdullah Ibrahim",
    "Ahmad Jamal",
    "Bill Evans",
    "Brad Mehldau",
    "Cedar Walton",
    "Chick Corea",
    "Gene Harris",
    "Geri Allen",
    "Hank Jones",
    "John Hicks",
    "Junior Mance",
    "Keith Jarrett",
    "Kenny Barron",
    "Kenny Drew",
    "McCoy Tyner",
    "Oscar Peterson",
    "Stanley Cowell",
    "Teddy Wilson",
    "Thelonious Monk",
    "Tommy Flanagan",
]

In [6]:
split_in = []
split_cols = ["track_name", "album_name", "pianist", "composition_canonic"]

for js in js_files:
    with open(js, "r") as js_in:
        js_read = json.load(js_in)

    if js_read["pianist"] in DESIRED_PIANISTS:
        js_res = {k: v for k, v in js_read.items() if k in split_cols}

        # add the pianist name to the album name
        #  reason here is that two pianists could release an album called "Solo"
        #  but these are two unique albums!
        js_res["album_name"] = f"{js_res['pianist']} ({js_res['album_name']})"

        js_res["setting"] = "trio" if "data/raw/jtd/" in str(js) else "solo"

        split_in.append(js_res)

split_df = pd.DataFrame(split_in)

In [51]:
def lp_group_split(df, group_col='album_name', train_frac=0.8, val_frac=0.1, test_frac=0.1):
    """
    Group-aware split, solved as a mixed-integer linear program (MILP), that:
    - Keeps all recordings of a group (album/composition) together
    - Approximates the requested train/val/test fractions by number of recordings
    - Balances solo/trio proportions across splits (via the 'setting' column)
    - Requires every pianist to appear in every split
    - Caps the solve at 120 seconds and a 2% relative gap
    - Assigns split labels in a new column 'split'

    Parameters:
        df : pandas.DataFrame with 'pianist' and 'setting' columns
        group_col : str, column to group by ('album_name' or 'composition_canonic')
        train_frac, val_frac, test_frac : float, desired split fractions

    Returns:
        df : pandas.DataFrame with new column 'split'
    """
    groups = df[group_col].unique()
    total_records = len(df)
    targets = {'train': total_records * train_frac, 'val': total_records * val_frac, 'test': total_records * test_frac}

    stratify_column = "setting"

    # Records per group
    group_sizes = df.groupby(group_col).size().to_dict()

    # Stratify counts per group
    group_strat_counts = df.groupby(group_col)[stratify_column].value_counts().unstack(fill_value=0).to_dict(orient='index')
    strat_levels = df[stratify_column].unique()
    global_strat_frac = df[stratify_column].value_counts(normalize=True).to_dict()  # global fraction

    # Desired number of stratify types per split
    strat_targets = {s: {level: targets[s]*global_strat_frac[level] for level in strat_levels} for s in targets.keys()}

    # LP problem
    prob = pulp.LpProblem("GroupSplit", pulp.LpMinimize)

    # Decision variables: x[group, split] binary
    x = pulp.LpVariable.dicts("x", ((g,s) for g in groups for s in targets.keys()), cat='Binary')

    # Deviation variables for total records per split
    d_total = pulp.LpVariable.dicts("d_total", targets.keys(), lowBound=0)

    # Deviation variables for stratify counts per split
    d_strat = {s: {level: pulp.LpVariable(f"d_{s}_{level}", lowBound=0) for level in strat_levels} for s in targets.keys()}

    # 1. Each group assigned to exactly one split
    for g in groups:
        prob += pulp.lpSum([x[g,s] for s in targets.keys()]) == 1

    # 2. Total records deviation constraints
    for s in targets.keys():
        prob += pulp.lpSum([group_sizes[g]*x[g,s] for g in groups]) - targets[s] <= d_total[s]
        prob += targets[s] - pulp.lpSum([group_sizes[g]*x[g,s] for g in groups]) <= d_total[s]

    # 3. Stratification deviation constraints
    for s in targets.keys():
        for level in strat_levels:
            prob += pulp.lpSum([group_strat_counts[g].get(level,0)*x[g,s] for g in groups]) - strat_targets[s][level] <= d_strat[s][level]
            prob += strat_targets[s][level] - pulp.lpSum([group_strat_counts[g].get(level,0)*x[g,s] for g in groups]) <= d_strat[s][level]

    # 4. Pianist coverage constraints
    pianists = df['pianist'].unique()
    for s in targets.keys():
        for p in pianists:
            # groups that contain at least one recording by pianist p
            groups_with_p = [g for g in groups if p in df[df[group_col]==g]['pianist'].unique()]
            # enforce at least one of these groups assigned to split s
            prob += pulp.lpSum([x[g,s] for g in groups_with_p]) >= 1

    # 5. Objective function: minimize total deviation
    prob += pulp.lpSum([d_total[s] for s in targets.keys()]) + pulp.lpSum([d_strat[s][level] for s in targets.keys() for level in strat_levels])

    # 6. Solve
    solver = pulp.PULP_CBC_CMD(timeLimit=120, gapRel=0.02, msg=True)
    prob.solve(solver)

    # Extract assignment
    assignment = {s: [] for s in targets.keys()}
    for g in groups:
        for s in targets.keys():
            if pulp.value(x[g,s]) > 0.5:
                assignment[s].append(g)

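    # Assign split labels on a copy: the caller's dataframe is left untouched, so
    # the album-level and composition-level calls below return independent results
    df = df.copy()
    df['split'] = None
    for s, group_list in assignment.items():
        df.loc[df[group_col].isin(group_list), 'split'] = s

    return df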


album_df = lp_group_split(split_df, group_col='album_name')
comp_df = lp_group_split(split_df, group_col='composition_canonic')

Welcome to the CBC MILP Solver 
Version: 2.10.3 
Build Date: Dec 15 2019 

command line - /home/huw-cheston/.cache/pypoetry/virtualenvs/deep-pianist-identification-2Pg5O36G-py3.12/lib/python3.12/site-packages/pulp/apis/../solverdir/cbc/linux/i64/cbc /tmp/2f2fa1d655514f97812b38913cc8578f-pulp.mps -sec 120 -ratio 0.02 -timeMode elapsed -branch -printingOptions all -solution /tmp/2f2fa1d655514f97812b38913cc8578f-pulp.sol (default strategy 1)
At line 2 NAME          MODEL
At line 3 ROWS
At line 442 COLUMNS
At line 9098 RHS
At line 9536 BOUNDS
At line 10614 ENDATA
Problem MODEL has 437 rows, 1086 columns and 6492 elements
Coin0008I MODEL read with 0 errors
seconds was changed from 1e+100 to 120
ratioGap was changed from 0 to 0.02
Option for timeMode changed from cpu to elapsed
Continuous objective value is 0 - 0.00 seconds
Cgl0004I processed model has 437 rows, 1086 columns (1077 integer (1077 of which binary)) and 6492 elements
Cbc0038I Initial state - 8 integers unsatisfied sum - 2.54121
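
To build intuition for what the solver is minimizing, the sketch below brute-forces a heavily reduced version of the same objective on toy data: it only tracks the deviation from the target number of recordings per split, and leaves out the solo/trio and pianist-coverage terms. It is an illustration under those assumptions, not a replacement for the MILP, and the group names and sizes are invented.

```python
from itertools import product


def brute_force_split(group_sizes, fractions):
    """Try every assignment of groups to splits and keep the one with the
    smallest summed absolute deviation from the target recordings per split."""
    groups = list(group_sizes)
    splits = list(fractions)
    total = sum(group_sizes.values())
    targets = {s: total * f for s, f in fractions.items()}

    best_cost, best_assignment = float("inf"), None
    for labels in product(splits, repeat=len(groups)):
        counts = dict.fromkeys(splits, 0)
        for g, s in zip(groups, labels):
            counts[s] += group_sizes[g]
        cost = sum(abs(counts[s] - targets[s]) for s in splits)
        if cost < best_cost:
            best_cost, best_assignment = cost, dict(zip(groups, labels))
    return best_cost, best_assignment


# Five toy albums with 10 recordings in total
toy_sizes = {"album_a": 4, "album_b": 3, "album_c": 1, "album_d": 1, "album_e": 1}
toy_fractions = {"train": 0.8, "val": 0.1, "test": 0.1}

cost, assignment = brute_force_split(toy_sizes, toy_fractions)
train_size = sum(n for g, n in toy_sizes.items() if assignment[g] == "train")
print(cost, train_size)
```

Exhaustive search grows as 3 to the power of the number of groups, which is why the full problem is handed to a MILP solver with a time limit and a relative gap.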


### Sanity check

- All pianists should have at least 1 track in every split (train/val/test)
- Splits should be roughly 8/1/1
- Splits should be roughly proportional between trio/solo
- Composition split:
    - No leaks of composition between splits
- Album split:
    - No leaks of album between splits

In [36]:
# Each pianist should have at least one recording in every split
for pianist in DESIRED_PIANISTS:
    for (df_name, df_) in zip(["album", "comp"], [album_df, comp_df]):
        for split in ["train", "val", "test"]:
            got = df_[(df_["pianist"] == pianist) & (df_["split"] == split)]
            assert len(got) >= 1, f"Pianist {pianist} does not appear in split {split} for df {df_name}"

In [44]:
# Splits should be approximately 8/1/1
for (df_name, df_) in zip(["album", "comp"], [album_df, comp_df]):
    total_tracks = len(df_)
    for split, desired_proportion in zip(["train", "val", "test"], [0.8, 0.1, 0.1]):
        actual_proportion = len(df_[df_["split"] == split]) / total_tracks
        assert math.isclose(actual_proportion, desired_proportion, abs_tol=0.01)
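
The solo/trio requirement from the checklist can be checked in the same spirit. The helper below is a minimal sketch that works on plain `(split, setting)` pairs so it runs without any dependencies. The function name and the toy data are hypothetical, and any tolerance used when comparing against the global share is a judgment call rather than a value taken from this notebook.

```python
from collections import Counter


def setting_share_by_split(pairs, level="solo"):
    """Return {split: fraction of recordings in that split whose setting == level}."""
    totals, hits = Counter(), Counter()
    for split, setting in pairs:
        totals[split] += 1
        hits[split] += int(setting == level)
    return {split: hits[split] / totals[split] for split in totals}


# Toy data: every split is half solo, half trio
toy_pairs = (
    [("train", "solo")] * 4 + [("train", "trio")] * 4
    + [("val", "solo"), ("val", "trio")]
    + [("test", "solo"), ("test", "trio")]
)
shares = setting_share_by_split(toy_pairs)
print(shares)  # {'train': 0.5, 'val': 0.5, 'test': 0.5}
```

With the notebook's dataframes, the pairs could be produced with `zip(album_df["split"], album_df["setting"])`, and each share compared against the overall solo fraction in `split_df`.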


In [50]:
# No leaks of compositions/albums between splits
for column_name, df_ in zip(["album_name", "composition_canonic"], [album_df, comp_df]):
    for split_a in ["train", "val", "test"]:
        split_a_df = df_[df_["split"] == split_a]
        split_a_comps = set(split_a_df[column_name].unique().tolist())
        other_splits = [i for i in ["train", "val", "test"] if i != split_a]
        for split_b in other_splits:
            split_b_df = df_[df_["split"] == split_b]
            split_b_comps = set(split_b_df[column_name].unique().tolist())
            shared = set.intersection(split_a_comps, split_b_comps)
            assert len(shared) == 0

In [52]:
for (df_name, df_) in zip(["album", "comp"], [album_df, comp_df]):
    total_tracks = len(df_)
    print(total_tracks)

1629
1629
