"""
End-to-end TensorFlow pipeline for multilabel tag classification from spectrogram JPGs.
Assumptions:
- You have a CSV with columns: musicbrainz_recording_id, artist, track, album, musicbrainz_artist, final_tags
- Each spectrogram image is named <musicbrainz_recording_id>.jpg and stored in a single folder (set IMAGE_DIR to point at it)
- There are 8 possible tags (the script computes the unique tags; pass expected_n_tags=8 to build_tag_binarizer to assert that exactly 8 are found)


What this file provides:
1. CSV + filesystem intersection: keep only common ids
2. Parser for the final_tags column (tolerant of both double- and single-quoted list literals)
3. Build tf.data.Dataset reading JPEGs, preprocessing, augmentation
4. CNN model definition (Keras) for multilabel classification -> sigmoid outputs for each tag
5. Training, validation split, callbacks, saving
6. Inference functions that return probabilities per tag


To run: adjust the paths and hyperparameters in the USER CONFIG section near the top of the file.
"""

In [3]:
import os
import ast
import json
import random
from pathlib import Path
from typing import List, Tuple, Dict


import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MultiLabelBinarizer

In [11]:
# ------------------------------- USER CONFIG -------------------------------
CSV_PATH = "data/tracks_metadata_202510071738.csv" # metadata CSV; read_csv_and_filter loads its musicbrainz_recording_id and final_tags columns
IMAGE_DIR = "data/spectrogram/" # directory containing one <mbid><IMAGE_EXT> image per recording
IMAGE_EXT = ".jpg" # image file extension appended to each recording id to form the filename
IMAGE_SIZE = (1000, 400) # model input size; NOTE(review): (height, width) vs (width, height) order is not visible here — confirm at the point of use
BATCH_SIZE = 32
AUTOTUNE = tf.data.AUTOTUNE # lets tf.data choose parallelism / prefetch levels at runtime
EPOCHS = 40
SEED = 42 # presumably intended for all random splits/shuffles — TODO confirm every RNG consumer uses it
MODEL_OUTPUT = "models/spectrogram_multilabel.h5" # presumably where the trained model is saved (HDF5, judging by the extension)
# ---------------------------------------------------------------------------
# ---------------------------------------------------------------------------

In [25]:
# ------------------------------ UTIL FUNCTIONS -----------------------------


def _parse_final_tags_cell(cell) -> List[str]:
    """Parse one raw ``final_tags`` cell into a list of tag strings.

    Handles list literals quoted either way (``'["pop", "blues-r&b-soul"]'`` or
    ``"['pop','rock']"``), real list/tuple objects, and missing values (which
    become ``[]``). When ``ast.literal_eval`` does not yield a list/tuple, it
    falls back to stripping the outer brackets and splitting on commas. Every
    tag is converted to ``str`` and whitespace-stripped. Empty tags are dropped
    so they never become a spurious class downstream.
    """
    # Check containers before pd.isna: pd.isna on a list returns an array,
    # whose truth value is ambiguous for multi-element lists.
    if isinstance(cell, (list, tuple)):
        raw = list(cell)
    elif pd.isna(cell):
        return []
    else:
        text = str(cell)
        try:
            # literal_eval only evaluates Python literals, so it is safe on CSV text.
            parsed = ast.literal_eval(text)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            # These are the errors literal_eval documents for malformed input.
            parsed = None
        if isinstance(parsed, (list, tuple)):
            raw = list(parsed)
        else:
            # Fallback: strip brackets, split on commas, drop surrounding quotes.
            body = text.strip().lstrip('[').rstrip(']')
            raw = [part.strip().strip('"').strip("'") for part in body.split(',') if part.strip()]
    tags = (str(tag).strip() for tag in raw)
    return [tag for tag in tags if tag]


def read_csv_and_filter(csv_path: str, image_dir: str, id_col: str = "musicbrainz_recording_id",
                        image_ext: str = None) -> pd.DataFrame:
    """Read the metadata CSV, parse its tags, and keep only rows whose image file exists.

    Args:
        csv_path: Path to the metadata CSV.
        image_dir: Directory holding one image per recording, named ``<id><image_ext>``.
        id_col: Column holding the recording id used to build each image filename.
        image_ext: Image filename extension. When omitted, the module-level
            ``IMAGE_EXT`` is used (looked up at call time).

    Returns:
        A new DataFrame with a fresh 0..n-1 index. It contains ``id_col`` and
        ``final_tags`` plus the added columns ``parsed_tags`` (list of tag
        strings), ``image_path`` and ``image_exists`` (always True after
        filtering).

    Raises:
        ValueError: if the CSV lacks ``id_col`` or ``final_tags``.
    """
    ext = IMAGE_EXT if image_ext is None else image_ext

    # Inspect the header first so a missing column reaches the explicit errors
    # below. Passing a missing name via usecols would make pandas raise its own
    # generic error before these checks could run.
    available = pd.read_csv(csv_path, nrows=0).columns.tolist()
    if id_col not in available:
        raise ValueError(f"CSV must contain column '{id_col}'. Found: {available}")
    if 'final_tags' not in available:
        raise ValueError("CSV must contain 'final_tags' column")

    # Load only the needed columns, honouring id_col (previously hard-coded).
    df = pd.read_csv(csv_path, usecols=[id_col, 'final_tags'])

    df['parsed_tags'] = df['final_tags'].apply(_parse_final_tags_cell)
    df['image_path'] = df[id_col].apply(lambda rec_id: os.path.join(image_dir, f"{rec_id}{ext}"))
    # astype(bool) keeps the mask boolean even when the CSV has no data rows.
    df['image_exists'] = df['image_path'].apply(os.path.exists).astype(bool)

    # Keep only rows whose image exists; reset_index returns a new frame with positions 0..n-1.
    filtered = df[df['image_exists']].reset_index(drop=True)

    total_images = sum(1 for _ in Path(image_dir).glob(f'*{ext}'))
    print(f"CSV rows: {len(df)}, images in folder: {total_images}, kept after intersection: {len(filtered)}")

    return filtered

In [13]:
# Notebook display: echo the configured spectrogram directory.
IMAGE_DIR

'data/spectrogram/'

In [28]:
# Load the metadata, keep only tracks whose spectrogram exists on disk, and preview a few rows.
df = read_csv_and_filter(CSV_PATH, IMAGE_DIR)
# random_state=SEED makes the preview reproducible across re-runs. min() avoids a
# ValueError when fewer than 5 rows survive the image-existence filter.
df.sample(min(5, len(df)), random_state=SEED)

CSV rows: 1096, images in folder: 1083, kept after intersection: 1083


Unnamed: 0,musicbrainz_recording_id,final_tags,parsed_tags,image_path,image_exists
241,3ec43616-c0f4-4100-82e0-973a67705f6a,"[""pop""]",[pop],data/spectrogram/3ec43616-c0f4-4100-82e0-973a6...,True
456,6f8beb66-ba91-45b3-9019-2d36cff68e7b,"[""pop""]",[pop],data/spectrogram/6f8beb66-ba91-45b3-9019-2d36c...,True
805,c58cbe16-1cf3-4abe-8ecf-72efc3df47af,"[""blues-r&b-soul"", ""electronic-funk-disco-danc...","[blues-r&b-soul, electronic-funk-disco-dance, ...",data/spectrogram/c58cbe16-1cf3-4abe-8ecf-72efc...,True
60,18121d60-bc40-49a9-92c7-87babed045aa,"[""hip_hop-rap""]",[hip_hop-rap],data/spectrogram/18121d60-bc40-49a9-92c7-87bab...,True
532,060134b9-d7aa-432f-9d0a-2e30eccc02de,"[""folk-classical-country-jazz""]",[folk-classical-country-jazz],data/spectrogram/060134b9-d7aa-432f-9d0a-2e30e...,True


In [31]:
# Sanity check: parsed_tags cells hold Python lists, not their string representation.
# NOTE(review): .iloc is positional; 805 appears to be taken from a df.sample preview.
# Labels and positions coincide here only because read_csv_and_filter resets the index.
type(df.iloc[805].parsed_tags)

list

In [32]:
def build_tag_binarizer(df: pd.DataFrame, tag_col: str = 'parsed_tags', expected_n_tags: int = None) -> Tuple[MultiLabelBinarizer, List[str]]:
    """Fit a MultiLabelBinarizer on the per-row tag lists in ``df[tag_col]``.

    Args:
        df: DataFrame whose ``tag_col`` column holds one list of tag strings per
            row (e.g. the ``parsed_tags`` column built by ``read_csv_and_filter``).
        tag_col: Name of the column holding the tag lists.
        expected_n_tags: If given, the number of distinct tags that must be found.

    Returns:
        ``(mlb, classes)``: the fitted binarizer and ``list(mlb.classes_)``. The
        binarizer's transform output uses the same column order as ``classes``.

    Raises:
        AssertionError: if ``expected_n_tags`` is given and differs from the
            number of distinct tags found.
    """
    mlb = MultiLabelBinarizer(sparse_output=False)
    mlb.fit(df[tag_col])
    classes = list(mlb.classes_)
    print(f"Found {len(classes)} unique tags: {classes}")
    # Use an explicit raise rather than an `assert` statement. Python strips
    # `assert` under -O, which would silently disable this check. AssertionError
    # is kept so callers see the same exception type as before.
    if expected_n_tags is not None and len(classes) != expected_n_tags:
        raise AssertionError(f"Expected {expected_n_tags} tags but found {len(classes)}")
    return mlb, classes

In [33]:
# Fit the tag vocabulary on the filtered DataFrame. Pass expected_n_tags=8 to fail fast if the tag count differs.
mlb, classes = build_tag_binarizer(df, tag_col='parsed_tags', expected_n_tags=None) # set to 8 if you want hard assert
# Number of distinct tags; presumably the size of the model's sigmoid output layer.
n_labels = len(classes)

Found 8 unique tags: ['blues-r&b-soul', 'electronic-funk-disco-dance', 'folk-classical-country-jazz', 'hip_hop-rap', 'opera-musical-theater-soundtrack-vocal-a_cappella', 'others', 'pop', 'rock-metal-psychedelic']


In [37]:
# Notebook display: the tag vocabulary (binarizer column order) and its size.
classes, n_labels

(['blues-r&b-soul',
  'electronic-funk-disco-dance',
  'folk-classical-country-jazz',
  'hip_hop-rap',
  'opera-musical-theater-soundtrack-vocal-a_cappella',
  'others',
  'pop',
  'rock-metal-psychedelic'],
 8)