"""
End-to-end TensorFlow pipeline for multilabel tag classification from spectrogram JPGs.
Assumptions:
- You have a CSV with columns: musicbrainz_recording_id, artist, track, album, musicbrainz_artist, final_tags
- Each spectrogram image is named <musicbrainz_recording_id>.jpg in an images/ folder (set IMAGE_DIR to point to your directory)
- There are 8 possible tags (the script computes the unique tags; pass expected_n_tags=8 to build_tag_binarizer to assert that there are exactly 8)


What this file provides:
1. CSV + filesystem intersection: keep only common ids
2. Parser for the final_tags column (accepts Python/JSON-style list literals, with a comma-split fallback for malformed cells)
3. Build tf.data.Dataset reading JPEGs, preprocessing, augmentation
4. CNN model definition (Keras) for multilabel classification -> sigmoid outputs for each tag
5. Training, validation split, callbacks, saving
6. Inference functions that return probabilities per tag


Run: adjust paths and hyperparameters near the top of the file.
"""

In [51]:
import ast
import json
import os
import random
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MultiLabelBinarizer

In [52]:
# ------------------------------- USER CONFIG -------------------------------
# Data locations
CSV_PATH = "data/tracks_metadata_202510071738.csv"  # metadata CSV (needs id + final_tags columns)
IMAGE_DIR = "data/spectrogram/"  # directory containing <mbid>.jpg
IMAGE_EXT = ".jpg"  # spectrogram file extension
MODEL_OUTPUT = "models/spectrogram_multilabel.h5"  # where the best checkpoint is written

# Model input and training hyperparameters
IMAGE_SIZE = (250, 100)  # (height, width) target size handed to tf.image.resize
BATCH_SIZE = 32
EPOCHS = 40
SEED = 42
AUTOTUNE = tf.data.AUTOTUNE
# ---------------------------------------------------------------------------

In [53]:
# ------------------------------ UTIL FUNCTIONS -----------------------------


def read_csv_and_filter(
    csv_path: str,
    image_dir: str,
    id_col: str = "musicbrainz_recording_id",
    image_ext: Optional[str] = None,
) -> pd.DataFrame:
    """Read the metadata CSV, parse its tag column and keep rows whose image file exists.

    Only ``id_col`` and ``final_tags`` are loaded from the CSV. Each ``final_tags``
    cell is parsed into a list of tag strings. Each id is mapped to the path
    ``os.path.join(image_dir, f"{id}{image_ext}")``.

    Args:
        csv_path: Path to the CSV file.
        image_dir: Directory containing one image per recording id.
        id_col: Column whose values name the image files.
        image_ext: Image file extension including the dot. ``None`` (the default)
            uses the module-level ``IMAGE_EXT``.

    Returns:
        The rows whose image file exists, re-indexed from 0. Three columns are added:
        ``parsed_tags`` (list of str), ``image_path`` (str) and ``image_exists`` (bool).

    Raises:
        ValueError: If ``id_col`` or ``final_tags`` is not a column of the CSV.
    """
    ext = IMAGE_EXT if image_ext is None else image_ext

    # Validate against the header before loading, so a missing column gets a clear
    # message. A list-valued `usecols` naming an unknown column would make pandas
    # raise its own, less specific, error first.
    header = pd.read_csv(csv_path, nrows=0).columns.tolist()
    if id_col not in header:
        raise ValueError(f"CSV must contain column '{id_col}'. Found: {header}")
    if 'final_tags' not in header:
        raise ValueError("CSV must contain 'final_tags' column")
    df = pd.read_csv(csv_path, usecols=[id_col, 'final_tags'])

    def parse_final_tags_cell(cell) -> List[str]:
        """Turn one ``final_tags`` cell into a list of stripped, non-empty tag strings.

        Accepts list literals such as '["pop", "blues-r&b-soul"]' or "['pop','rock']".
        For anything that is not a list literal, it strips the brackets and splits on
        commas. Missing values give [].
        """
        # Check for real sequences before pd.isna: pd.isna(list) returns an array,
        # and the truth value of an array is ambiguous.
        if isinstance(cell, (list, tuple)):
            items = cell
        elif pd.isna(cell):
            return []
        else:
            text = str(cell).strip()
            try:
                parsed = ast.literal_eval(text)  # safe: evaluates literals only
            except (ValueError, SyntaxError, TypeError):
                parsed = None
            if isinstance(parsed, (list, tuple)):
                items = parsed
            else:
                # Fallback for cells that are not valid list literals, e.g. "[pop, rock]"
                inner = text.lstrip('[').rstrip(']')
                items = [part.strip().strip('"').strip("'") for part in inner.split(',')]
        # Normalise every path the same way: strip whitespace and drop empty tags.
        return [tag for tag in (str(x).strip() for x in items) if tag]

    df['parsed_tags'] = df['final_tags'].apply(parse_final_tags_cell)

    df['image_path'] = df[id_col].apply(lambda mbid: os.path.join(image_dir, f"{mbid}{ext}"))
    # astype(bool) keeps the mask boolean even when the frame is empty. An empty object
    # Series would otherwise be treated as a (column) label selection.
    df['image_exists'] = df['image_path'].apply(os.path.exists).astype(bool)

    filtered = df.loc[df['image_exists']].reset_index(drop=True)

    total_csv = len(df)
    total_images = sum(1 for _ in Path(image_dir).glob(f'*{ext}'))
    kept = len(filtered)
    print(f"CSV rows: {total_csv}, images in folder: {total_images}, kept after intersection: {kept}")

    return filtered

In [54]:
# Display the configured spectrogram directory as a quick sanity check.
IMAGE_DIR

'data/spectrogram/'

In [55]:
# Load the metadata CSV, parse the tag column and keep only rows whose spectrogram exists on disk.
df = read_csv_and_filter(CSV_PATH, IMAGE_DIR)
# Show 5 random rows. No random_state is passed, so the sample differs between runs.
df.sample(5)

CSV rows: 1096, images in folder: 1083, kept after intersection: 1083


Unnamed: 0,musicbrainz_recording_id,final_tags,parsed_tags,image_path,image_exists
195,35181f72-868e-4298-b516-ac6c4c75652f,"[""pop""]",[pop],data/spectrogram/35181f72-868e-4298-b516-ac6c4...,True
52,15e2fda3-b76a-4d7d-94a9-a429d336352f,"[""rock-metal-psychedelic""]",[rock-metal-psychedelic],data/spectrogram/15e2fda3-b76a-4d7d-94a9-a429d...,True
530,838db018-12ee-4d53-828f-304769f1933d,"[""hip_hop-rap""]",[hip_hop-rap],data/spectrogram/838db018-12ee-4d53-828f-30476...,True
872,8079d38b-efcf-401e-817c-cb4f293c2e89,"[""rock-metal-psychedelic"", ""pop"", ""blues-r&b-s...","[rock-metal-psychedelic, pop, blues-r&b-soul, ...",data/spectrogram/8079d38b-efcf-401e-817c-cb4f2...,True
523,815ac72b-5477-421c-a685-8008886af46f,"[""pop"", ""rock-metal-psychedelic""]","[pop, rock-metal-psychedelic]",data/spectrogram/815ac72b-5477-421c-a685-80088...,True


In [56]:
# Sanity check that the parsed_tags column holds Python lists rather than raw strings.
# NOTE(review): the hard-coded row 805 assumes df has at least 806 rows (1083 in the run shown).
type(df.iloc[805].parsed_tags)

list

In [57]:
def build_tag_binarizer(df: pd.DataFrame, tag_col: str = 'parsed_tags', expected_n_tags: int = None) -> Tuple[MultiLabelBinarizer, List[str]]:
    """Fit a dense MultiLabelBinarizer on the tag lists stored in ``df[tag_col]``.

    Args:
        df: Frame whose ``tag_col`` column holds one iterable of tags per row.
        tag_col: Name of that column.
        expected_n_tags: Optional number of distinct tags. When given, the number
            actually found is asserted to match it.

    Returns:
        ``(binarizer, classes)``, where ``classes`` lists the binarizer's classes in
        the column order used by ``binarizer.transform``.

    Raises:
        AssertionError: If ``expected_n_tags`` is given and does not match.
    """
    binarizer = MultiLabelBinarizer(sparse_output=False).fit(df[tag_col])
    tag_names = list(binarizer.classes_)
    print(f"Found {len(tag_names)} unique tags: {tag_names}")
    if expected_n_tags is None:
        return binarizer, tag_names
    assert len(tag_names) == expected_n_tags, f"Expected {expected_n_tags} tags but found {len(tag_names)}"
    return binarizer, tag_names

In [58]:
# Fit the tag vocabulary. Pass expected_n_tags=8 instead of None for a hard check on the tag count.
mlb, classes = build_tag_binarizer(
    df,
    tag_col='parsed_tags',
    expected_n_tags=None,
)
n_labels = len(mlb.classes_)  # one sigmoid output per tag

Found 8 unique tags: ['blues-r&b-soul', 'electronic-funk-disco-dance', 'folk-classical-country-jazz', 'hip_hop-rap', 'opera-musical-theater-soundtrack-vocal-a_cappella', 'others', 'pop', 'rock-metal-psychedelic']


In [59]:
# Display the tag vocabulary (the column order of the label matrix) and its size.
classes, n_labels

(['blues-r&b-soul',
  'electronic-funk-disco-dance',
  'folk-classical-country-jazz',
  'hip_hop-rap',
  'opera-musical-theater-soundtrack-vocal-a_cappella',
  'others',
  'pop',
  'rock-metal-psychedelic'],
 8)

In [60]:
# Binarize labels: one row per df row, one 0/1 column per tag, in the order of `classes`.
labels = mlb.transform(df['parsed_tags'])
labels

array([[1, 0, 0, ..., 0, 1, 0],
       [1, 0, 0, ..., 0, 0, 0],
       [0, 1, 0, ..., 0, 1, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 1],
       [0, 1, 0, ..., 0, 1, 0],
       [1, 0, 0, ..., 0, 1, 0]])

#### TRAIN TEST SPLIT

In [61]:
# train/val/test split (plain random split, no stratification):
#   1) hold out 15% of all rows as the test set;
#   2) hold out 15% of the remaining rows as the validation set (about 12.75% of the total).
train_paths, test_paths, train_labels, test_labels = train_test_split(
    df['image_path'].tolist(),
    labels,
    test_size=0.15,
    random_state=SEED,
    stratify=None,
)
train_paths, val_paths, train_labels, val_labels = train_test_split(
    train_paths,
    train_labels,
    test_size=0.15,
    random_state=SEED,
    stratify=None,
)

In [62]:
# Peek at the first five training paths and their binarized label rows.
train_paths[:5], train_labels[:5]

(['data/spectrogram/ca1f8b74-6976-42b0-a415-1f2934a752cd.jpg',
  'data/spectrogram/eb623fb9-ed07-4f13-a8aa-446f014c5fee.jpg',
  'data/spectrogram/036892d0-355c-4436-bb06-47e9f235e4b2.jpg',
  'data/spectrogram/39973f6c-4d2f-4683-947c-10f74f909cfe.jpg',
  'data/spectrogram/e838057c-04e2-4875-94e7-75a74123cbe9.jpg'],
 array([[1, 0, 1, 0, 0, 0, 1, 1],
        [0, 0, 1, 0, 0, 0, 1, 0],
        [1, 1, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 1, 0, 1, 0],
        [0, 0, 0, 0, 1, 0, 1, 0]]))

#### Tensorflow compatible dataset

In [63]:
# -------------------------- TENSORFLOW DATA PIPELINE -----------------------


def load_and_preprocess_image(path: tf.Tensor, label: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
    """Read one JPEG and return ``(image, label)``, with the image as float32 scaled by 1/255.

    The JPEG is decoded to a single channel, resized to IMAGE_SIZE and replicated
    to three identical channels. Any colour in the source file is therefore
    discarded. The label is passed through unchanged.
    """
    raw_bytes = tf.io.read_file(path)
    gray = tf.image.decode_jpeg(raw_bytes, channels=1)  # collapse to one luminance channel
    resized = tf.image.resize(gray, IMAGE_SIZE)
    rgb = tf.image.grayscale_to_rgb(resized)  # 3 channels to match the model's input shape
    return tf.cast(rgb, tf.float32) / 255.0, label




def augment_image(image: tf.Tensor, label: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
    """Apply light random augmentation to one training image; the label passes through.

    Assumes ``image`` is the [0, 1] float output of load_and_preprocess_image.
    Augmentations:
      * brightness jitter (+/-0.08) and contrast jitter (x0.9 to x1.1), clipped back to [0, 1];
      * with probability 0.3, a circular shift along axis 1 (width) of up to +/-10 px;
        inside that branch, with probability 0.3, a further shift along axis 0
        (height) of up to +/-5 px.
    No flips are applied: the original author treats the horizontal axis as time.
    """
    image = tf.image.random_brightness(image, max_delta=0.08)
    image = tf.image.random_contrast(image, lower=0.9, upper=1.1)
    # Brightness/contrast jitter can push pixels outside [0, 1]. Clip so augmented
    # training inputs stay in the same range as val/test/inference inputs.
    image = tf.clip_by_value(image, 0.0, 1.0)
    # The Python `if` on a tensor relies on AutoGraph conversion when this function
    # is traced by Dataset.map.
    if tf.random.uniform(()) > 0.7:
        # maxval of tf.random.uniform is exclusive: use 11 / 6 for symmetric shift ranges.
        image = tf.roll(image, shift=tf.random.uniform((), -10, 11, dtype=tf.int32), axis=1)
        if tf.random.uniform(()) > 0.7:
            image = tf.roll(image, shift=tf.random.uniform((), -5, 6, dtype=tf.int32), axis=0)
    return image, label




def make_tf_dataset(paths: List[str], labels: np.ndarray, training: bool = True) -> tf.data.Dataset:
    """Build a batched, prefetched ``tf.data`` pipeline of (image, label) pairs.

    Training pipelines are shuffled over the full file list (seeded with SEED)
    before decoding, and augmented after decoding. Evaluation pipelines are
    neither shuffled nor augmented.
    """
    dataset = tf.data.Dataset.from_tensor_slices((paths, labels))
    if not training:
        decoded = dataset.map(load_and_preprocess_image, num_parallel_calls=AUTOTUNE)
        return decoded.batch(BATCH_SIZE).prefetch(AUTOTUNE)

    # Shuffle the cheap (path, label) pairs first, then decode and augment.
    augmented = (
        dataset.shuffle(buffer_size=len(paths), seed=SEED)
        .map(load_and_preprocess_image, num_parallel_calls=AUTOTUNE)
        .map(augment_image, num_parallel_calls=AUTOTUNE)
    )
    return augmented.batch(BATCH_SIZE).prefetch(AUTOTUNE)

In [64]:
# Training pipeline: shuffled, augmented, batched and prefetched.
train_ds = make_tf_dataset(train_paths, np.array(train_labels), training=True)
train_ds

<_PrefetchDataset element_spec=(TensorSpec(shape=(None, 250, 100, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 8), dtype=tf.int64, name=None))>

In [65]:

# Evaluation pipelines: no shuffling and no augmentation (training=False).
val_ds, test_ds = (
    make_tf_dataset(split_paths, np.array(split_labels), training=False)
    for split_paths, split_labels in ((val_paths, val_labels), (test_paths, test_labels))
)

#### BUILD MODEL

In [66]:
def build_model(input_shape=(224,224,3), n_labels=8, dropout_rate=0.5) -> tf.keras.Model:
    """Build a small CNN that outputs one independent sigmoid probability per tag.

    Architecture:
      * three Conv2D(3x3, relu) -> BatchNorm -> MaxPool(2) stages with 32/64/128 filters;
      * Conv2D(256) -> BatchNorm -> GlobalAveragePooling;
      * Dropout(dropout_rate) -> Dense(256, relu) -> BatchNorm -> Dropout(0.3);
      * Dense(n_labels, sigmoid).

    This is a simple bespoke backbone. A pretrained one (EfficientNet, MobileNetV2, ...)
    can replace it for better performance.
    """
    layers = tf.keras.layers
    inputs = tf.keras.Input(shape=input_shape)

    # Feature extractor: conv stages with doubling filter counts, each halving H and W.
    x = inputs
    for filters in (32, 64, 128):
        x = layers.Conv2D(filters, 3, padding='same', activation='relu')(x)
        x = layers.BatchNormalization()(x)
        x = layers.MaxPool2D(2)(x)

    # Last conv stage is not pooled spatially; GAP collapses H and W instead.
    x = layers.Conv2D(256, 3, padding='same', activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.GlobalAveragePooling2D()(x)

    # Classification head
    x = layers.Dropout(dropout_rate)(x)
    x = layers.Dense(256, activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.3)(x)
    outputs = layers.Dense(n_labels, activation='sigmoid')(x)  # multilabel: sigmoid, not softmax

    return tf.keras.Model(inputs=inputs, outputs=outputs)


In [67]:
# IMAGE_SIZE is (height, width); append the 3 channels produced by load_and_preprocess_image.
model = build_model(input_shape=IMAGE_SIZE + (3,), n_labels=n_labels)
model.summary()

In [68]:
def compile_and_train(
    model: tf.keras.Model,
    train_ds: tf.data.Dataset,
    val_ds: tf.data.Dataset,
    epochs: int = EPOCHS,
    model_output: str = MODEL_OUTPUT,
):
    """Compile ``model`` for multilabel classification and fit it.

    Compilation: Adam (lr=1e-4), binary cross-entropy, multi-label ROC AUC and
    binary accuracy. Callbacks, all monitoring ``val_loss``:
      * ReduceLROnPlateau: halve the LR after 3 stagnant epochs, down to 1e-6;
      * EarlyStopping: patience 7, with restore_best_weights=True;
      * ModelCheckpoint: save to ``model_output`` only when val_loss improves.

    Returns:
        The History object returned by ``model.fit``.
    """
    checkpoint_dir = os.path.dirname(model_output)
    if checkpoint_dir:  # dirname is '' for a bare filename, and os.makedirs('') would raise
        os.makedirs(checkpoint_dir, exist_ok=True)

    lr_on_plateau = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3, min_lr=1e-6)
    early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=7, restore_best_weights=True)
    best_checkpoint = tf.keras.callbacks.ModelCheckpoint(model_output, monitor='val_loss', save_best_only=True)

    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
        loss='binary_crossentropy',
        metrics=[tf.keras.metrics.AUC(curve='ROC', multi_label=True), 'binary_accuracy'],
    )

    return model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=epochs,
        callbacks=[lr_on_plateau, early_stop, best_checkpoint],
    )


# ------------------------------ INFERENCE ---------------------------------


def predict_probabilities(model: tf.keras.Model, image_paths: List[str], mlb: MultiLabelBinarizer) -> List[Dict[str, float]]:
    """Run ``model`` on image files and return per-tag probabilities.

    Images go through ``load_and_preprocess_image``, the same function the
    training/validation/test datasets use. Inference therefore sees identical
    preprocessing: single-channel decode, resize to IMAGE_SIZE, replication to
    3 channels and scaling by 1/255.
    (Previously, inference decoded with channels=3. For colour JPEGs that fed the
    model different inputs than it was trained on.)

    Args:
        model: Trained model with one sigmoid output per tag.
        image_paths: Paths (str or path-like) to JPEG spectrograms.
        mlb: Fitted binarizer. ``mlb.classes_`` gives the tag order of the model outputs.

    Returns:
        One ``{tag: probability}`` dict per input path, in input order.
        An empty input gives an empty list.
    """
    if len(image_paths) == 0:
        # tf.constant([]) would be a float32 tensor, which tf.io.read_file cannot accept.
        return []

    paths = tf.constant([str(p) for p in image_paths])
    ds = (
        tf.data.Dataset.from_tensor_slices(paths)
        # load_and_preprocess_image expects (path, label); pass a dummy label and keep the image.
        .map(lambda p: load_and_preprocess_image(p, 0)[0], num_parallel_calls=AUTOTUNE)
        .batch(BATCH_SIZE)
    )
    preds = model.predict(ds)
    return [
        {label: float(prob) for label, prob in zip(mlb.classes_, row)}
        for row in preds
    ]

In [None]:
history = compile_and_train(model, train_ds, val_ds, epochs=EPOCHS, model_output=MODEL_OUTPUT)


# Evaluate on the held-out test set
print('\nEvaluating on test set...')
res = model.evaluate(test_ds)
print(res)


# Save the tag order next to the model: output unit i corresponds to classes[i].
out_dir = os.path.dirname(MODEL_OUTPUT)
classes_path = os.path.join(out_dir, 'classes.json')
if out_dir:  # dirname is '' for a bare filename, and os.makedirs('') raises
    os.makedirs(out_dir, exist_ok=True)
with open(classes_path, 'w', encoding='utf-8') as f:
    json.dump({'classes': classes}, f)


# MODEL_OUTPUT holds the best-val_loss checkpoint written by ModelCheckpoint during training.
print(f"Model saved to {MODEL_OUTPUT} and classes saved to {classes_path}")


# Example inference usage (first 5 test images)
sample_paths = test_paths[:5]
probs = predict_probabilities(model, sample_paths, mlb)
for p, path in zip(probs, sample_paths):
    print(path, p)

Epoch 1/40
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 746ms/step - auc_2: 0.5107 - binary_accuracy: 0.5053 - loss: 0.9102



[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m25s[0m 800ms/step - auc_2: 0.5096 - binary_accuracy: 0.5115 - loss: 0.9066 - val_auc_2: 0.4751 - val_binary_accuracy: 0.5571 - val_loss: 0.6838 - learning_rate: 1.0000e-04
Epoch 2/40
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 732ms/step - auc_2: 0.5539 - binary_accuracy: 0.5466 - loss: 0.8557

In [None]:


# --------------------------- NOTES & TIPS ---------------------------------
# - If you have imbalanced labels, consider computing per-class positive weights and using a custom weighted loss.
# - For better accuracy, replace the backbone with a pretrained model (EfficientNetB0, MobileNetV2) and fine-tune.
# Example: use tf.keras.applications.EfficientNetB0(include_top=False, input_shape=..., weights='imagenet') then add GAP + Dense(sigmoid).
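# - The pipeline decodes JPEGs with channels=1 and repeats that channel to 3 (grayscale_to_rgb), which suits mono spectrograms and ImageNet backbones. This discards colour: if your spectrograms use a colour map, decode with channels=3 instead, in training and inference alike.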
# - Tune augmentations carefully; some augmentations (horizontal flip) may break time-frequency relationships and are not recommended.
# - To use class weights: compute pos_weight = (N - pos) / pos per class and use them in a custom loss.
# -------