# OpenPIR: An Open-Source Dataset for Predominant Instrument Classification

## Introduction

This notebook contains the full experiment code for our OpenPIR paper. OpenPIR is a new dataset for predominant instrument classification. Previously, IRMAS was the only sufficiently large dataset available for this task, but it has issues such as class imbalance and a mismatch between its single-label training data and its multi-label test data. We mainly reimplement the experimental methodology of Han et al. (https://arxiv.org/pdf/1605.09507). Using this experimental framework, we train the same model on our new dataset and test it.




## Imports: This section contains necessary imports and label dictionaries to map string labels to integers for training

In [None]:
# All import statements
!pip install librosa

# Standard library imports
from collections import defaultdict
from itertools import filterfalse
import glob
import json
import logging
import math
import os
import random

# Third-party imports
import librosa
import matplotlib.pyplot as plt
import numpy as np
import pickle
import seaborn as sns
from sklearn.metrics import f1_score, hamming_loss, classification_report, multilabel_confusion_matrix
from sklearn.model_selection import train_test_split
import torch
from torch import nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Sampler, Subset
import torchaudio
from torchvision import transforms
from tqdm import tqdm




# line wrap function for visual clarity
from IPython.display import HTML, display

def set_css():
  """Display a small stylesheet that makes `pre` blocks wrap (white-space: pre-wrap)."""
  wrap_style = HTML('''
  <style>
    pre {
        white-space: pre-wrap;
    }
  </style>
  ''')
  display(wrap_style)
# Run set_css before every cell execution so each cell's output gets the wrap style.
# NOTE(review): re-running this cell registers another callback each time.
get_ipython().events.register('pre_run_cell', set_css)

# mount drive (Google Colab only)
from google.colab import drive
drive.mount('/content/drive')

# Folder on the mounted drive that is listed below as a sanity check.
folder_path = '/content/drive/MyDrive/IRMAS-TrainingData'

os.listdir(folder_path)

# Maps instrument abbreviations to the integer class index used for the
# multi-hot label vectors (11 classes, acoustic and electric guitar separate).
instrument_labels = {'cel': 0, 'cla': 1, 'flu': 2, 'gac': 3, 'gel': 4, 'org': 5, 'pia': 6, 'sax': 7, 'tru': 8, 'vio': 9, 'voi': 10}

# Alternative 10-class mapping with a single 'guitar' class.
openmic_instrument_labels = {'cel': 0, 'cla': 1, 'flu': 2, 'guitar': 3, 'org': 4, 'pia': 5, 'sax': 6, 'tru': 7, 'vio': 8, 'voi': 9}

# instrument_labels = {'cello': 0, 'clarinet': 1, 'flute': 2, 'acoustic guitar': 3, 'electric guitar': 4, 'organ': 5, 'piano': 6, 'saxophone': 7, 'trumpet': 8, 'violin': 9, 'voice': 10}
# Maps the combined IRMAS genre tags to integer indices.
IRMAS_genre_labels = {'cla': 0, 'pop_roc': 1, 'cou_fol': 2, 'jaz_blu': 3, 'lat_sou': 4}

# {'pop_roc', 'jaz_blu', 'lat_sou', 'cou_fol', 'cla'}
# Maps individual genre names to the index used by one_hot_encode_genre.
# NOTE(review): 'disco' and 'jazz' both map to index 3, so the 11 names cover
# only indices 0-9 and the two genres are indistinguishable once encoded.
# Confirm whether this merge is intentional before relying on the genre vector.
lastfm_genre_labels = {
    'blues': 0,
    'classical': 1,
    'country': 2,
    'disco': 3,
    'jazz': 3,
    'folk': 4,
    'metal': 5,
    'pop': 6,
    'rock': 7,
    'latin': 8,
    'soul': 9
}

In [None]:
# @title Audio Feature Extraction and Preprocessing:

# This section extracts audio features (mel-spectrogram, MFCCs, etc.) from audio files, performs preprocessing (mono conversion, normalization, chunking), and saves the processed data.



# def chunk_audio(waveform, chunk_length=1.0, overlap=0.0, sr=22050):
#     """
#     1) If stereo, convert to mono by averaging channels.
#     2) Normalize amplitude to [-1,1].
#     3) Split waveform into overlapping chunks of length `chunk_length` seconds.
#     """
#     if waveform.ndim == 1:
#         waveform = waveform.unsqueeze(0)
#     # If stereo => average over channel dimension => [1, T]
#     if waveform.ndim == 2 and waveform.shape[0] > 1:
#         waveform = torch.mean(waveform, dim=0, keepdim=True)

#     # Normalize
#     max_val = waveform.abs().max()
#     if max_val > 0:
#         waveform = waveform / max_val

#     chunk_size = int(sr * chunk_length)
#     stride = int(chunk_size * (1 - overlap))
#     num_chunks = (waveform.shape[1] - chunk_size) // stride + 1
#     num_chunks = max(num_chunks, 1)  # at least 1 chunk

#     chunks = []
#     start = 0
#     for i in range(num_chunks):
#         end = start + chunk_size
#         if end >= waveform.shape[1]:
#             # If the last chunk is partial, slice from the back
#             chunk = waveform[:, -chunk_size:]
#             chunks.append(chunk)
#             break
#         else:
#             chunk = waveform[:, start:end]
#             chunks.append(chunk)
#         start += stride

#     return chunks


# def compute_librosa_feats(chunk, sr=22050):
#     """
#     Extract various librosa features from a single chunk.
#     `chunk` has shape [1, chunk_size].
#     """
#     # Remove channel dimension => shape [chunk_size]
#     waveform_np = chunk.squeeze(0).cpu().numpy()

#     zcr = librosa.feature.zero_crossing_rate(waveform_np, frame_length=1024, hop_length=512)[0]
#     centroid = librosa.feature.spectral_centroid(y=waveform_np, sr=sr, n_fft=1024, hop_length=512)[0]
#     rms = librosa.feature.rms(y=waveform_np, frame_length=1024, hop_length=512)[0]
#     rolloff = librosa.feature.spectral_rolloff(y=waveform_np, sr=sr, n_fft=1024, hop_length=512)[0]
#     bandwidth = librosa.feature.spectral_bandwidth(y=waveform_np, sr=sr, n_fft=1024, hop_length=512)[0]

#     mel_spec = librosa.feature.melspectrogram(
#         y=waveform_np, sr=sr, n_fft=1024, hop_length=512,
#         n_mels=128, power=2.0
#     )
#     mel_db = librosa.power_to_db(mel_spec)

#     mfcc = librosa.feature.mfcc(
#         y=waveform_np, sr=sr, n_mfcc=20,
#         n_fft=1024, hop_length=512
#     )

#     return {
#         "zcr": zcr,
#         "centroid": centroid,
#         "rms": rms,
#         "rolloff": rolloff,
#         "bandwidth": bandwidth,
#         "mel": mel_db,
#         "mfcc": mfcc
#     }


# def process_and_save(data, output_file, chunk_length=1.0,
#                      overlap=0.0, sr=22050):
#     """
#     - list_of_dicts: each entry has something like:
#         {
#           'filepath': '...',
#           'audio': tensor([...]),  # shape [2, T] or [1, T]
#           'instrument': 'electric guitar',
#           'genre': 'pop_roc',
#           ...
#         }
#     - We'll split each 'audio' into 1-sec chunks, compute features,
#       and produce multiple new dictionaries, one for each chunk.
#     - 'audio' in the new dict => the feature dictionary from compute_librosa_feats
#     - Add 'index' key => chunk number.
#     - Keep all other keys the same (filepath, instrument, etc.).
#     - Write the final list to `output_file` as a pickle.
#     """
#     list_of_dicts = torch.load(data, weights_only=False)
#     processed_list = []

#     for entry in tqdm(list_of_dicts, desc="Processing entries"):
#         audio_data = entry['audio']
#         # If 'audio' is numpy, convert to torch
#         if isinstance(audio_data, np.ndarray):
#             audio_data = torch.from_numpy(audio_data)

#         # Chunk the audio
#         chunks = chunk_audio(audio_data, chunk_length=chunk_length,
#                              overlap=overlap, sr=sr)

#         # For each chunk, compute features & produce a new dictionary
#         for i, ch in enumerate(chunks):
#             feats = compute_librosa_feats(ch, sr=sr)

#             # Copy original dict
#             new_dict = dict(entry)
#             # Replace 'audio' with the feature dictionary
#             new_dict["audio"] = feats
#             # Add chunk index
#             new_dict["index"] = i

#             processed_list.append(new_dict)

#     torch.save(processed_list, output_file)

#     print(f"Saved {len(processed_list)} dictionaries to {output_file}")


# # Example usage:
# # Suppose you have a list_of_dicts loaded from somewhere, e.g.:
# # list_of_dicts = [
# #   {'filepath': '/some/path.wav', 'audio': torch.randn(2, 44100), 'instrument': 'guitar', ...},
# #   {'filepath': '/some/other.wav', 'audio': torch.randn(2, 88200), 'instrument': 'voice', ...},
# #   ...
# # ]
# #
# process_and_save("/content/drive/MyDrive/Data/IRMAS/filtered_dataset_newsamples.pt", "/content/drive/MyDrive/Data/IRMAS/IRMASG_TEST_FILTERED_NEWSAMPLES.pkl", chunk_length=1.0, overlap=0.5)

Processing entries: 100%|██████████| 3174/3174 [40:05<00:00,  1.32it/s]


Saved 98866 dictionaries to /content/drive/MyDrive/Data/IRMAS/IRMASG_TEST_FILTERED_NEWSAMPLES.pkl


## Helper functions:

We also experimented with genre information but did not use it in the final experiments.

In [None]:

def one_hot_encode_genre(d, lastfm_genre_labels, length):
    """
    Build a multi-hot genre vector of size `length` for one data record.

    Genre names are gathered from the record `d` as follows:
      * if `d["openmic"]` equals True: from a non-empty `d["IRMASG_genres"]`;
      * otherwise: from a non-empty `d["genre_list"]`, plus `d["genre_lastfm"]`
        when that value is not None.
    If nothing was gathered, the combined tag in `d["genre"]` (e.g. "pop_roc")
    is expanded into its individual genre names; an unrecognised tag is used
    as a single genre name.

    Args:
        d (dict): Data record holding the genre fields described above.
        lastfm_genre_labels (dict): Genre name -> index into the output vector.
        length (int): Size of the returned vector.

    Returns:
        np.ndarray: Integer vector with 1 at the index of every gathered genre
        found in `lastfm_genre_labels`; names not in the mapping are ignored,
        so the result may be all zeros.
    """
    encoded = np.zeros(length, dtype=int)

    # Expansion of the combined IRMAS tags into individual genre names.
    combined_tag_genres = {
        'pop_roc': ('pop', 'rock'),
        'lat_sou': ('latin', 'soul'),
        'cou_fol': ('country', 'folk'),
        'jaz_blu': ('jazz', 'blues'),
        'cla': ('classical',),
    }

    collected = []
    if d.get("openmic") == True:
        if d.get("IRMASG_genres") and len(d["IRMASG_genres"]) > 0:
            collected.extend(d["IRMASG_genres"])
    else:
        if d.get("genre_list") and len(d["genre_list"]) > 0:
            collected.extend(d["genre_list"])
        lastfm_tag = d.get("genre_lastfm")
        if lastfm_tag is not None:
            collected.append(lastfm_tag)

    # No explicit genre information: fall back on the combined tag.
    if not collected:
        combined_tag = d["genre"]
        collected = list(combined_tag_genres.get(combined_tag, (combined_tag,)))

    # Names that are absent from the mapping are skipped silently.
    for name in collected:
        if name in lastfm_genre_labels:
            encoded[lastfm_genre_labels[name]] = 1

    return encoded

# # Example usage:
# lastfm_genre_labels = {
#     'blues': 0,
#     'classical': 1,
#     'country': 2,
#     'disco': 3,
#     'jazz': 3,
#     'folk': 4,
#     'metal': 5,
#     'pop': 6,
#     'rock': 7,
#     'latin': 8,
#     'soul': 9
# }

# # Data dictionary example:
# data_dict = {
#     "genre": "pop_roc",
#     "genre_list": [],        # or e.g., ["rock", "metal"]
#     "genre_lastfm": None     # or e.g., "rock"
# }

# one_hot_vector = one_hot_encode_genre(data_dict, lastfm_genre_labels, 10)
# print(one_hot_vector)
def multi_hot_encode_instruments(d, instrument_labels):
    """
    Build a multi-hot instrument vector.

    Args:
        d: The instrument annotation. May be
            * a single instrument name (str),
            * an iterable of instrument names,
            * a data record (dict), in which case its "instrument" field is
              used (falling back to "instruments"), or
            * None / empty, which yields an all-zero vector.
        instrument_labels (dict): Instrument name -> index into the output.

    Returns:
        np.ndarray: Integer vector of length `len(instrument_labels)` with 1
        at the index of every instrument found in `instrument_labels`. Names
        missing from the mapping are ignored.
    """
    encoded = np.zeros(len(instrument_labels), dtype=int)

    instruments = d
    if isinstance(instruments, dict):
        # A whole record was passed. Iterating it would walk its *keys*
        # ("audio", "filepath", ...) and silently produce an all-zero label,
        # so pull the instrument annotation out of the record instead.
        record = instruments
        instruments = record.get("instrument")
        if instruments is None:
            instruments = record.get("instruments")

    if isinstance(instruments, str):
        instruments = [instruments]  # a single name becomes a one-item list

    if instruments is None or len(instruments) == 0:
        return encoded  # nothing annotated -> all zeros

    for name in instruments:
        if name in instrument_labels:
            encoded[instrument_labels[name]] = 1
        # names missing from the mapping are skipped on purpose

    return encoded

def compute_mean_std(dataset):
    """
    Compute the global per-channel mean and standard deviation of a dataset.

    Each `dataset[i]` must return a 3-tuple whose first element is a feature
    tensor of shape [1, C, frames]; the other two elements are ignored.
    Statistics are pooled over all frames of all items.

    Sums are accumulated in float64: the variance is obtained as
    E[x^2] - E[x]^2, which cancels catastrophically in float32 when the values
    are large relative to their spread.

    Args:
        dataset: Indexable object supporting `len()` and integer indexing.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: `(mean, std)`, each of shape [C].
        The dtype matches the features' floating dtype (the default float
        dtype is used for non-floating features). The variance is clamped to
        at least 1e-8 before the square root.

    Raises:
        ValueError: If the dataset is empty.
    """
    num_items = len(dataset)
    if num_items == 0:
        raise ValueError("compute_mean_std() requires a non-empty dataset")

    channel_sum = None
    channel_sumsq = None
    frame_count = 0
    out_dtype = None

    for i in range(num_items):
        features, _, _ = dataset[i]  # features shape: [1, C, frames]
        features = features.squeeze(0)  # now shape: [C, frames]
        if out_dtype is None:
            out_dtype = features.dtype if features.is_floating_point() else torch.get_default_dtype()

        features64 = features.to(torch.float64)
        item_sum = features64.sum(dim=1)
        item_sumsq = (features64 ** 2).sum(dim=1)

        if channel_sum is None:
            channel_sum, channel_sumsq = item_sum, item_sumsq
        else:
            channel_sum += item_sum
            channel_sumsq += item_sumsq
        frame_count += features64.shape[1]

    mean = channel_sum / frame_count
    var = (channel_sumsq / frame_count) - (mean ** 2)
    std = torch.sqrt(var.clamp_min(1e-8))
    # Cast back so that normalising float32 features does not promote them to float64.
    return mean.to(out_dtype), std.to(out_dtype)


# function to initialize weights of model using glorot as mentioned in paper
def weights_init(m):
    """
    Initialise one module in place; intended for use with `model.apply(weights_init)`.

    * `nn.Conv2d` / `nn.Linear`: Glorot (Xavier) uniform weights, zero bias.
    * `Conv_2d` block: the same treatment for its `conv1` and then `conv2`.
    * `nn.BatchNorm2d`: weight set to ones, bias set to zeros.
    * Any other module is left untouched.
    """
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        glorot_layers = (m,)
    elif isinstance(m, Conv_2d):
        # Custom block: initialise both of its convolutions, in order.
        glorot_layers = (m.conv1, m.conv2)
    elif isinstance(m, nn.BatchNorm2d):
        init.ones_(m.weight)
        init.zeros_(m.bias)
        return
    else:
        return

    for layer in glorot_layers:
        init.xavier_uniform_(layer.weight)
        if layer.bias is not None:
            init.zeros_(layer.bias)


def set_seed(seed=42):
    """
    Seed Python's `random`, NumPy and PyTorch (CPU and all CUDA devices) with
    `seed`, and switch cuDNN to deterministic mode with benchmarking disabled.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False  # no algorithm auto-tuning, for run-to-run consistency

def seed_worker(worker_id):
    """
    DataLoader `worker_init_fn`: seed NumPy and Python's `random` from the
    current torch initial seed (reduced modulo 2**32). `worker_id` is unused.
    """
    derived_seed = torch.initial_seed() % (1 << 32)
    for reseed in (np.random.seed, random.seed):
        reseed(derived_seed)



## OpenPIRDataset:
This class defines a custom PyTorch Dataset for loading and preprocessing audio data from the OpenPIR dataset. It handles tasks like loading audio features, encoding labels, splitting into train/validation sets, and applying normalization. We precomputed different features to try implementing methodologies from other papers but ended up only using melspec.

**Key functionalities:**

* **Initialization (`__init__`)**:
  - Loads audio features and labels from a specified data file (`.pkl` or `.npy`).
  - Encodes instrument and genre labels using `multi_hot_encode_instruments` and `one_hot_encode_genre`.
  - Splits the data into train/validation sets based on the `split` argument.
  - Stores normalization parameters (mean and standard deviation) if provided.
* **Data Access (`__getitem__`)**:
  - Retrieves the audio features, instrument labels, and genre labels for a given index.
  - Converts the features to PyTorch tensors.
  - Applies normalization if mean and standard deviation are available.
  - Returns the processed data.
* **Dataset Length (`__len__`)**:
  - Returns the total number of samples in the dataset.

In [None]:
class OpenPIRDataset(Dataset):
    """
    Chunk-level dataset over precomputed feature records.

    Every record is a dict with at least the keys "audio" (a dict of feature
    arrays that contains "mel"), "filepath" and "instrument" (or
    "instruments"). Only the mel-spectrogram is returned as the model input.
    """

    def __init__(self, data_file, split='all', test_sample_half=False, seed=42, mean=None, std=None, openmic=True):
        """
        Initializes the OpenPIRDataset.

        Args:
            data_file (str): Path to the data file (.pkl or .npy).
            split (str, optional): 'train' or 'valid' selects the 85% / 15%
                part of a stratified split; any other value keeps every
                record. Defaults to 'all'.
            test_sample_half (bool, optional): Accepted for interface
                compatibility; currently unused.
            seed (int, optional): Random seed for splitting. Defaults to 42.
            mean (torch.Tensor, optional): Per-mel-bin mean for normalization. Defaults to None.
            std (torch.Tensor, optional): Per-mel-bin standard deviation for normalization. Defaults to None.
            openmic (bool, optional): Accepted for interface compatibility;
                currently unused (labels always use `instrument_labels`).

        Raises:
            ValueError: If `data_file` is neither a .pkl nor a .npy file.
        """
        # Load data based on file type.
        # SECURITY: both loaders unpickle arbitrary objects; only open trusted files.
        if data_file.endswith('.pkl'):
            with open(data_file, 'rb') as f:
                data_list = torch.load(f, weights_only=False)
        elif data_file.endswith('.npy'):
            data_list = np.load(data_file, allow_pickle=True)
        else:
            raise ValueError("Unsupported file format: must be .pkl or .npy")

        # Ensure consistent instrument key
        for d in data_list:
            if "instrument" not in d and "instruments" in d:
                d["instrument"] = d["instruments"]

        # Drop records without an instrument annotation and records whose
        # only annotation is ["other"].
        data_list = [
            d for d in data_list
            if "instrument" in d and d["instrument"] is not None and
            not (isinstance(d["instrument"], list) and len(d["instrument"]) == 1 and d["instrument"][0] == "other")
        ]

        # Extract features, instrument, genre, and audio indices
        all_features = [d["audio"] for d in data_list]
        # Encode the instrument annotation itself. Passing the whole record
        # would make the encoder iterate the record's keys and yield all-zero labels.
        all_instrument = [multi_hot_encode_instruments(d["instrument"], instrument_labels) for d in data_list]
        all_genre = [one_hot_encode_genre(d, lastfm_genre_labels, 10) for d in data_list]
        filepaths = [d["filepath"] for d in data_list]
        # First-seen order gives every source file the same integer id on
        # every run (a set would order the ids by string hash).
        unique_paths = list(dict.fromkeys(filepaths))
        path2idx = {p: i for i, p in enumerate(unique_paths)}
        audio_idx = [path2idx[p] for p in filepaths]

        # Split data if necessary
        if split in ['train', 'valid']:
            indices = np.arange(len(data_list))
            # NOTE(review): the split is made per record (chunk), so chunks of
            # one audio file can land in both train and valid.
            train_idx, val_idx = train_test_split(
                indices,
                test_size=0.15,
                random_state=seed,
                stratify=all_instrument  # stratify on the multi-hot label vectors
            )
            selected_indices = train_idx if split == 'train' else val_idx
        else:
            selected_indices = np.arange(len(data_list))

        # Store selected data
        self.features = [all_features[i] for i in selected_indices]
        self.instrument = [all_instrument[i] for i in selected_indices]
        self.genre = [all_genre[i] for i in selected_indices]
        self.audio_idx = [audio_idx[i] for i in selected_indices]
        self.split = split
        self.mean = mean
        self.std = std

    def __len__(self):
        """Number of records in the selected split."""
        return len(self.features)

    def __getitem__(self, idx):
        """
        Retrieves data for a given index.

        Args:
            idx (int): Index of the data sample.

        Returns:
            tuple: `(features, instrument_label, third)` where `features` is
            the float32 mel-spectrogram with a leading dimension of size 1
            (normalized when both mean and std were given),
            `instrument_label` is the multi-hot instrument vector, and
            `third` is the genre vector for the 'train'/'valid' splits or the
            integer id of the source audio file otherwise.
        """
        feats_dict = self.features[idx]
        main_label = self.instrument[idx]

        # Only the mel-spectrogram is used as model input.
        combined_features = torch.from_numpy(feats_dict['mel']).float().unsqueeze(0)

        # Apply normalization (statistics broadcast over the frame axis)
        if (self.mean is not None) and (self.std is not None):
            mean_ = self.mean.view(1, -1, 1)
            std_ = self.std.view(1, -1, 1)
            combined_features = (combined_features - mean_) / (std_ + 1e-8)

        # Return data based on split
        if self.split in ['train', 'valid']:
            return combined_features, main_label, self.genre[idx]
        return combined_features, main_label, self.audio_idx[idx]

## IRMASGTestDataset:

This class defines a custom PyTorch Dataset for loading and preprocessing audio data for testing, similar to `OpenPIRDataset`. It handles loading precomputed features, labels, and audio indices, and applies normalization if necessary. Note that this is the same as the original Test Dataset from IRMAS.

**Key functionalities:**

* **Initialization (`__init__`)**:
    - Loads precomputed features, labels, and audio indices from a specified data file (`.pkl` or `.npy`).
    - Stores normalization parameters (mean and standard deviation) if provided.
* **Data Access (`__getitem__`)**:
    - Retrieves the audio features, instrument labels, and audio index for a given index.
    - Converts the features to PyTorch tensors.
    - Applies normalization if mean and standard deviation are available.
    - Returns the processed data.
* **Dataset Length (`__len__`)**:
    - Returns the total number of samples in the dataset.

In [None]:
class IRMASGTestDataset(Dataset):
    """
    Dataset over a precomputed dictionary with the keys "features", "labels"
    and "audio_idx". Only the mel-spectrogram of each feature dict is returned
    as the model input.
    """

    # Default auxiliary label per main label (keys are the stringified label).
    DEFAULT_AUX_MAPPING = {
        '0': 1, '1': 1, '2': 1, '3': 0, '4': 0,
        '5': 0, '6': 0, '7': 1, '8': 1, '9': 1, '10': 2
    }
    _VALID_SPLITS = ('train', 'valid', 'test')

    def __init__(self, data_file, split='train', val_size=0.15, test_sample_half=False, seed=42, aux_mapping=None, mean=None, std=None):
        """
        data_file: path to the .pkl or .npy with keys ["features", "labels", "audio_idx"]
        split: 'train', 'valid', or 'test'. 'train'/'valid' select one side of
            a stratified split; 'test' keeps every sample.
        val_size: fraction of data to use for validation.
        test_sample_half: for split='test' only, keep the samples of a random
            half of the unique audio indices.
        seed: random seed for reproducibility.
        aux_mapping: optional dictionary str(label) -> auxiliary label;
            defaults to a copy of DEFAULT_AUX_MAPPING.
        mean, std: optional per-mel-bin tensors used for normalization.

        Raises:
            ValueError: for an unknown split or an unsupported file extension.
        """
        # Reject a bad split before doing any loading work.
        if split not in self._VALID_SPLITS:
            raise ValueError("Invalid split. Use 'train', 'valid', or 'test'.")

        # 1) Load the precomputed dictionary
        # SECURITY: both loaders unpickle arbitrary objects; only open trusted files.
        if data_file.endswith('.pkl'):
            with open(data_file, 'rb') as f:
                data_dict = pickle.load(f)
        elif data_file.endswith('.npy'):
            data_dict = np.load(data_file, allow_pickle=True).item()
        else:
            raise ValueError("Unsupported file format: must be .pkl or .npy")

        features = data_dict['features']
        labels = data_dict['labels']
        audio_idx = data_dict['audio_idx']

        # 2) Pick the indices belonging to the requested split
        if split == 'test':
            # The test split uses every sample, so no stratified split is
            # computed here (it could fail on labels that cannot be stratified).
            selected_indices = np.arange(len(features))

            if test_sample_half:
                # Find unique audio indices
                unique_audio_indices = np.unique(audio_idx)

                # Randomly select half of them.
                # NOTE: this reseeds NumPy's global RNG as a side effect.
                np.random.seed(seed)
                sampled_audio_indices = np.random.choice(
                    unique_audio_indices, size=len(unique_audio_indices) // 2, replace=False)

                # Set lookup avoids scanning the sampled array for every sample.
                sampled = set(sampled_audio_indices.tolist())
                selected_indices = [
                    i for i in selected_indices if audio_idx[i] in sampled]
        else:
            train_indices, val_indices = train_test_split(
                np.arange(len(features)), test_size=val_size, random_state=seed, stratify=labels
            )
            selected_indices = train_indices if split == 'train' else val_indices

        # 3) Apply the selected indices to filter data
        self.features = [features[i] for i in selected_indices]
        self.labels = [labels[i] for i in selected_indices]
        self.audio_idx = [audio_idx[i] for i in selected_indices]
        self.split = split
        # Copy the default so instances never share one mutable mapping.
        self.aux_mapping = dict(self.DEFAULT_AUX_MAPPING) if aux_mapping is None else aux_mapping

        self.mean = mean
        self.std = std

    def __len__(self):
        """Number of samples in the selected split."""
        return len(self.features)

    def __getitem__(self, idx):
        """
        Returns a 3-tuple:
          combined_features: float32 mel-spectrogram with a leading dimension
              of size 1, normalized when both mean and std were given
          label: torch.LongTensor built from the stored label
          third: for 'train'/'valid', the auxiliary label (LongTensor) looked
              up via aux_mapping[str(label)]; for 'test', the audio index
        """
        feats_dict = self.features[idx]
        label = self.labels[idx]

        # Only the mel-spectrogram is used as model input.
        combined_features = torch.from_numpy(feats_dict['mel']).type(torch.float32).unsqueeze(0)

        # Apply normalization if available (statistics broadcast over frames)
        if (self.mean is not None) and (self.std is not None):
            mean_ = self.mean.view(1, -1, 1)  # [1, C, 1]
            std_ = self.std.view(1, -1, 1)
            combined_features = (combined_features - mean_) / (std_ + 1e-8)

        label_t = torch.tensor(label, dtype=torch.long)
        if self.split in ['train', 'valid']:
            aux_label = self.aux_mapping[str(label)]
            aux_label_t = torch.tensor(aux_label, dtype=torch.long)
            return combined_features, label_t, aux_label_t
        return combined_features, label_t, self.audio_idx[idx]

## TestBatchSampler: Maintaining Audio Context During Testing

The `TestBatchSampler` is a custom PyTorch sampler crucial for testing
audio-based models. It ensures that all slices or chunks belonging to
the same audio file are processed together in a batch or in controlled
sub-batches. This is essential for preserving the temporal context of
the audio, which is often vital for accurate predictions in tasks like
instrument classification or music genre recognition. Essentially, it allows the model to process one full audio sample as one batch, so that we can know that all the 1-second snippets in this batch belong to the same audio, and hence aggregate evaluations.

**How it Works:**

1. **Initialization:** The sampler takes a dictionary `audio_to_indices`
   that maps each unique audio ID to a list of corresponding dataset indices
   representing its slices. It also takes the desired `batch_size`.

2. **Batch Generation:** During iteration, the sampler retrieves all indices
   associated with a single audio ID. It then yields either:
   - All indices at once if the `batch_size` can accommodate them.
   - Sub-batches of the specified size, ensuring that slices from the same
     audio are grouped.

3. **Batch Count:** The sampler calculates the total number of batches,
   considering potential sub-batches created due to `batch_size` constraints.

**Benefits:**

- **Preserves Audio Context:** By processing all slices of an audio together,
  the model can make more informed predictions based on the complete temporal
  information.
- **Controlled Batching:** The `batch_size` parameter allows flexibility in
  managing batch sizes while ensuring that audio context is maintained.
- **Improved Accuracy:** By preserving context, the sampler contributes to
  achieving more accurate results in audio-based tasks.


In [None]:


class TestBatchSampler(Sampler):
    """
    Batch sampler that never mixes slices of different audios in one batch.

    Audios are visited in the key order of `audio_to_indices` (captured at
    construction); the dataset indices of each audio are yielded in
    consecutive groups of at most `batch_size`.
    """

    def __init__(self, audio_to_indices, batch_size):
        """
        audio_to_indices: dict[int, list[int]]
            Maps audio_idx -> dataset indices of the slices of that audio.
        batch_size: int
            Upper bound on the number of indices per batch. Use a value at
            least as large as the longest index list to keep each audio in
            exactly one batch.
        """
        self.audio_to_indices = audio_to_indices
        self.audio_idxs = list(audio_to_indices.keys())
        self.batch_size = batch_size

    def __iter__(self):
        """Yield lists of dataset indices, each list drawn from a single audio."""
        for key in self.audio_idxs:
            members = self.audio_to_indices[key]
            starts = range(0, len(members), self.batch_size)
            yield from (members[lo:lo + self.batch_size] for lo in starts)

    def __len__(self):
        """Total number of batches: per audio, ceil(number of slices / batch_size)."""
        return sum(
            math.ceil(len(members) / self.batch_size)
            for members in self.audio_to_indices.values()
        )



## Convolutional Neural Network Components: Conv_2d and CNN

This section describes two key components of a Convolutional Neural Network (CNN)
used for audio classification: `Conv_2d` (a convolutional block) and `CNN`
(the overall network architecture). This architecture is the same as the one used in the Han et al. paper (https://arxiv.org/pdf/1605.09507).

### Conv_2d: A Convolutional Block

The `Conv_2d` class defines a fundamental building block for CNNs. It consists of:

1. **Two Convolutional Layers:** These layers extract features from the input audio data.
2. **LeakyReLU Activation:** Introduces non-linearity to the model, allowing it to learn
   more complex patterns.
3. **Max Pooling:** Reduces the spatial dimensions of the feature maps, decreasing the
   computational complexity and making the model more robust to small variations in the input.
4. **Dropout:** A regularization technique that helps prevent overfitting by randomly
   ignoring some neurons during training.

**Purpose:** `Conv_2d` blocks extract features, introduce non-linearity, and reduce
spatial dimensions while preventing overfitting. By stacking multiple `Conv_2d` blocks,
we can build deeper and more powerful CNNs.

### CNN: A Convolutional Neural Network for Audio Classification

The `CNN` class defines the overall architecture of the convolutional neural network.
It utilizes multiple `Conv_2d` blocks and other layers to process audio data for
classification tasks like instrument recognition or music genre classification.

**Key Components:**

1. **Multiple Conv_2d Blocks:** Extract hierarchical features from the audio data.
2. **Additional Convolutional Layers:** Further process the extracted features.
3. **Global Max Pooling:** Reduces the spatial dimensions to a single value per feature map.
4. **Fully Connected Layers:** Map the learned features to the desired output classes.
5. **Dropout:** Applied to the fully connected layers to prevent overfitting.

**Purpose:** The `CNN` model learns hierarchical representations of audio data,
starting with low-level features and progressing to more abstract features.
The fully connected layers then use these learned features to classify the audio
into different categories.

In [None]:

class Conv_2d(nn.Module):
    """
    Convolutional block: conv -> LeakyReLU -> conv -> LeakyReLU -> max-pool -> dropout.

    Args:
        input_channels: channels of the incoming feature map.
        output_channels: number of filters of both convolutions.
        shape: convolution kernel size.
        padding: zero padding applied by both convolutions.
        pooling: max-pool kernel size (3 shrinks each spatial dimension to 1/3).
        dropout: dropout probability applied after pooling.
    """

    def __init__(self, input_channels, output_channels, shape=3, padding=1, pooling=3, dropout=0.25):
        super().__init__()

        # First convolution changes the channel count; the second keeps it.
        self.conv1 = nn.Conv2d(input_channels, output_channels, kernel_size=shape, padding=padding)
        self.conv2 = nn.Conv2d(output_channels, output_channels, kernel_size=shape, padding=padding)

        # One LeakyReLU module is shared by both convolutions.
        self.l_relu = nn.LeakyReLU(negative_slope=0.33)
        self.maxpool = nn.MaxPool2d(kernel_size=pooling)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        for conv in (self.conv1, self.conv2):
            x = self.l_relu(conv(x))
        return self.dropout(self.maxpool(x))

class CNN(nn.Module):
    """Convolutional network producing one logit per instrument class.

    Architecture: three ``Conv_2d`` blocks (32, 64, 128 filters), two plain
    3x3 convolutions with 256 filters each followed by LeakyReLU(0.33),
    global max pooling to a 256-dim vector, a 1024-unit fully connected
    layer, dropout (p = 0.5) and the final classification layer.

    The model returns raw logits: no sigmoid is applied here because
    ``nn.BCEWithLogitsLoss`` applies it internally.

    Args:
        input_channels: Number of channels of the input tensor.
        num_classes: Number of output logits.

    Note:
        ``forward`` expects a batched 4-D input. Each ``Conv_2d`` block max-pools
        with a factor of 3 by default, so both spatial dimensions must be at
        least 27 for the three blocks to leave a non-empty feature map.
    """

    def __init__(self, input_channels=1, num_classes=11):
        super().__init__()

        self.layer1 = Conv_2d(input_channels, 32)
        self.layer2 = Conv_2d(32, 64)
        self.layer3 = Conv_2d(64, 128)

        self.layer4 = nn.Conv2d(128, 256, 3, padding=1)
        self.l_relu = nn.LeakyReLU(negative_slope=0.33)
        self.layer5 = nn.Conv2d(256, 256, 3, padding=1)

        # Collapse whatever spatial size remains to 1x1 per feature map.
        self.global_max_pool = nn.AdaptiveMaxPool2d((1, 1))

        # Fully connected head operating on the flattened 256-dim vector.
        self.fc = nn.Linear(256, 1024)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(1024, num_classes)

    # Used in previous experiment but not used in final paper experiment
    def set_num_classes(self, num_classes):
        """Reinitialize the final layer to accommodate a different number of classes.

        The new layer is sized from ``self.fc`` and placed on the same device
        and dtype as ``self.fc``, so the model stays usable when this is
        called after ``.to(device)`` or a dtype cast.
        """
        reference = self.fc.weight
        self.fc2 = nn.Linear(self.fc.out_features, num_classes).to(
            device=reference.device, dtype=reference.dtype)

    def forward(self, x):
        """Return logits of shape ``(batch, num_classes)`` for input ``x``."""
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.l_relu(self.layer4(x))
        x = self.l_relu(self.layer5(x))
        x = self.global_max_pool(x)

        # Flatten (batch, 256, 1, 1) -> (batch, 256) for the linear layers.
        x = x.view(x.size(0), -1)
        # NOTE(review): there is no activation between fc and fc2, only
        # dropout — confirm this is intended.
        x = self.fc(x)
        x = self.dropout(x)
        x = self.fc2(x)

        # nn.BCEWithLogitsLoss() already uses sigmoid internally, so the
        # model returns logits.
        return x


## Training / Testing
This script implements a systematic training pipeline for the CNN model. It was originally written to fine-tune a pretrained model, which is why it iterates over layer-freezing configurations and multiple seeds. For the paper we trained every model from scratch, so the pretrained-weight loading is commented out and only the `unfreeze_all` configuration is active; the other configurations can be re-enabled by uncommenting them. After training, the checkpoint with the lowest validation loss is evaluated on the IRMAS test dataset, and the results are saved.


In [None]:
#@title Training / Testing


# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# ----------------------------
# 1) Define layer freezing configurations
# ----------------------------
# Each entry provides:
#   "name"        - used in output folder names and plot labels,
#   "layers"      - top-level CNN attribute names whose parameters are frozen
#                   (requires_grad=False) by the training loop,
#   "description" - free text written to the console and to config.txt.
# The commented-out entries are freezing variants from the earlier
# fine-tuning experiments; only "unfreeze_all" (nothing frozen) is active.
layer_configs = [
    # {
    #     "name": "freeze_all_conv",
    #     "layers": ["layer1", "layer2", "layer3", "layer4", "layer5"],
    #     "description": "Freeze all convolutional layers"
    # },
    # {
    #     "name": "freeze_early_mid",
    #     "layers": ["layer1", "layer2", "layer3"],
    #     "description": "Freeze early and mid convolutional layers"
    # },
    # {
    #     "name": "freeze_early",
    #     "layers": ["layer1", "layer2"],
    #     "description": "Freeze only early convolutional layers"
    # },
    # {
    #     "name": "freeze_first",
    #     "layers": ["layer1"],
    #     "description": "Freeze only first convolutional layer"
    # },
    {
        "name": "unfreeze_all",
        "layers": [],
        "description": "Fine-tune all layers"
    }
]

# List of seeds to test; every configuration is trained once per seed.
seeds = [1, 10, 9000]  # Modify as needed

# ----------------------------
# 2) Path configuration
# ----------------------------
# Base directory for all outputs of this cell (per-run folders, per-seed
# summaries, overall results). Created up front if it does not exist.
finetune_base_dir = '/content/drive/MyDrive/Aiden/genre_auxiliary_experiment/IRMASv9_all_features'
os.makedirs(finetune_base_dir, exist_ok=True)

# Path to pretrained model - use the exact folder name.
# NOTE: in the training loop below the load_state_dict call is commented out,
# so these two values are only written to config.txt and no weights are loaded.
pretrain_base_dir = '/content/drive/MyDrive/Aiden/genre_auxiliary_experiment/pretrain_openmic'
pretrain_folder_name = "seed_4_openmic_only_100epochs"  # Change this to your pretrained model folder

# ----------------------------
# 3) Ensure worker seeds are consistent for reproducibility
# ----------------------------
def seed_worker(worker_id):
    """DataLoader ``worker_init_fn`` that seeds NumPy and ``random``.

    The seed is ``torch.initial_seed()`` reduced modulo 2**32. ``worker_id``
    is accepted because DataLoader passes it, but it is not used.
    """
    derived_seed = torch.initial_seed() % (1 << 32)
    for apply_seed in (np.random.seed, random.seed):
        apply_seed(derived_seed)


# ----------------------------
# 4) Initialize result tracking
# ----------------------------
# Cross-seed summary: configuration names and seeds up front, then one dict
# per metric that the experiment loop fills in, keyed by seed.
all_results = dict(
    configs=[cfg["name"] for cfg in layer_configs],
    seeds=seeds,
    f1_macro={},
    f1_micro={},
    val_accuracy={},
)

# ----------------------------
# 5) Main experiment loop
# ----------------------------
# For every seed: build datasets and loaders, then for every layer-freezing
# configuration train a CNN with early stopping, evaluate the best checkpoint
# on the test set (predictions aggregated per audio clip), and write logs,
# plots and metrics below finetune_base_dir.
for SEED in seeds:
    print(f"\n{'='*70}")
    print(f"RUNNING EXPERIMENTS WITH SEED {SEED}")
    print(f"{'='*70}")

    # Set seed for reproducibility (Python, NumPy, torch CPU and CUDA).
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(SEED)
        torch.cuda.manual_seed_all(SEED)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # Initialize tracking dictionary for this seed (one entry per configuration)
    seed_results = {
        "config_names": [],
        "val_accuracies": [],
        "f1_macro": [],
        "f1_micro": [],
        "early_stop_epoch": []
    }

    # Pretrained model path. It is only recorded in config.txt: loading the
    # weights is disabled further down.
    pretrained_model_path = os.path.join(pretrain_base_dir, pretrain_folder_name, "best_model.pth")

    # Create datasets with this seed
    train_dataset = OpenPIRDataset(
        data_file="/content/drive/MyDrive/Data/IRMAS/IRMASG_Precomputed_Features_v9.pkl",
        split='train', seed=SEED, openmic=False
    )

    valid_dataset = OpenPIRDataset(
        data_file="/content/drive/MyDrive/Data/IRMAS/IRMASG_Precomputed_Features_v9.pkl",
        split='valid', seed=SEED, openmic=False
    )

    # Normalization with training-set statistics is currently disabled:
    # # Compute normalization parameters from training data
    # train_mean, train_std = compute_mean_std(train_dataset)

    # # Apply normalization to datasets
    # train_dataset.mean = train_mean
    # train_dataset.std = train_std
    # valid_dataset.mean = train_mean
    # valid_dataset.std = train_std

    # Create test dataset (mean/std arguments disabled together with the above)
    test_dataset = IRMASGTestDataset(
        data_file="/content/drive/MyDrive/irmas_test_set_augmented_features.pkl",
        split='test', test_sample_half=False, seed=SEED,
        # mean=train_mean, std=train_std
    )

    # Build audio index mapping for test data: audio id -> dataset indices of its slices.
    # NOTE(review): assumes audio_idx hashes by value (e.g. a plain int) — confirm
    # against IRMASGTestDataset.__getitem__.
    audio_to_indices = defaultdict(list)
    for i in range(len(test_dataset)):
        _, _, audio_idx = test_dataset[i]
        audio_to_indices[audio_idx].append(i)

    # Size test batches to the largest slice group.
    max_slices = max(len(idxs) for idxs in audio_to_indices.values())
    batch_size = max_slices

    # Create data loaders; the seeded generator drives shuffling and worker seeds.
    g = torch.Generator()
    g.manual_seed(SEED)

    train_loader = DataLoader(
        train_dataset, batch_size=64,
        shuffle=True, num_workers=8,
        worker_init_fn=seed_worker, generator=g
    )

    valid_loader = DataLoader(
        valid_dataset, batch_size=64,
        shuffle=False, num_workers=8,
        worker_init_fn=seed_worker, generator=g
    )

    test_batch_sampler = TestBatchSampler(audio_to_indices, batch_size=batch_size)
    test_loader = DataLoader(
        test_dataset, batch_sampler=test_batch_sampler,
        num_workers=8, pin_memory=True
    )

    # Loop through each layer freezing configuration
    for config in layer_configs:
        config_name = config["name"]
        freeze_layers = config["layers"]

        print(f"\n{'-'*60}")
        print(f"Testing configuration: {config_name}")
        print(f"Description: {config['description']}")
        print(f"Freezing layers: {', '.join(freeze_layers) if freeze_layers else 'None'}")
        print(f"{'-'*60}")

        # Create experiment directory for this seed+config combination
        experiment_dir = os.path.join(finetune_base_dir, f'seed_{SEED}_{config_name}')
        os.makedirs(experiment_dir, exist_ok=True)

        # Log configuration details
        with open(os.path.join(experiment_dir, "config.txt"), "w") as f:
            f.write(f"Seed: {SEED}\n")
            f.write(f"Configuration: {config_name}\n")
            f.write(f"Description: {config['description']}\n")
            f.write(f"Frozen layers: {', '.join(freeze_layers) if freeze_layers else 'None'}\n")
            f.write(f"Pretrained model: {pretrained_model_path}\n")
            f.write(f"Pretrained from folder: {pretrain_folder_name}\n")

        # Build a freshly initialised model. Loading the pretrained weights is
        # disabled; re-enable the lines below to fine-tune instead.
        cnn = CNN(input_channels=1, num_classes=11).to(device)
        # cnn.load_state_dict(torch.load(pretrained_model_path, map_location=device))
        # cnn.set_num_classes(11)
        # cnn = cnn.to(device)

        # Apply freezing based on configuration
        for name, param in cnn.named_parameters():
            # Extract layer name (before the first dot)
            layer_name = name.split('.')[0]

            # Set requires_grad based on whether this layer should be frozen
            param.requires_grad = layer_name not in freeze_layers

            # Log the freezing status
            status = "frozen (fixed)" if layer_name in freeze_layers else "trainable"
            print(f"  {name}: {status}")

        # Configure optimizer (only for trainable parameters)
        optimizer = optim.Adam(
            filter(lambda p: p.requires_grad, cnn.parameters()),
            lr=0.001,
            weight_decay=1e-4
        )

        # Learning rate scheduler: halve the LR after 5 epochs without a lower
        # validation loss. The deprecated `verbose` argument is not passed
        # (newer PyTorch releases reject it); LR reductions are printed
        # explicitly after scheduler.step() instead.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5, patience=5, min_lr=1e-6
        )

        # Set up loss function (multi-label; expects raw logits)
        criterion_instr = torch.nn.BCEWithLogitsLoss()

        # Early stopping setup
        early_stop_patience = 15
        early_stop_counter = 0

        # Training metrics
        train_losses = []
        val_losses = []
        train_accuracies = []
        val_accuracies = []
        best_val_loss = float('inf')
        num_epochs = 100              # Maximum number of epochs
        stopped_epoch = num_epochs    # Overwritten when early stopping triggers

        # Main training loop for this configuration
        for epoch in tqdm(range(num_epochs), desc=f"Training {config_name}"):
            # Training phase
            cnn.train()
            running_loss = 0.0
            train_correct = 0
            train_total = 0

            for data, labels, _ in train_loader:
                data = data.to(device)
                labels = labels.to(device)

                # BCEWithLogitsLoss needs float targets
                BCE_labels = labels.float()

                optimizer.zero_grad()
                outputs = cnn(data)
                loss = criterion_instr(outputs, BCE_labels)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()

                # Apply sigmoid to convert logits to probabilities
                probabilities = torch.sigmoid(outputs)

                # Threshold probabilities at 0.5 to get predicted labels
                predicted_labels = (probabilities >= 0.5).int()

                # Multi-label accuracy: fraction of label slots predicted
                # correctly per sample, averaged over the batch
                correct_predictions = (predicted_labels == labels.int()).float()
                accuracy_per_sample = correct_predictions.mean(dim=1)  # accuracy per sample
                batch_accuracy = accuracy_per_sample.mean().item()

                # Weight by batch size so the epoch value is a per-sample mean
                train_correct += batch_accuracy * labels.size(0)
                train_total += labels.size(0)

            # Calculate epoch metrics
            avg_train_loss = running_loss / len(train_loader)
            train_accuracy = train_correct / train_total
            train_losses.append(avg_train_loss)
            train_accuracies.append(train_accuracy)

            # Validation phase
            cnn.eval()
            val_loss = 0.0
            val_correct = 0
            val_total = 0

            with torch.no_grad():
                for data, labels, _ in valid_loader:
                    data = data.to(device)
                    labels = labels.to(device)
                    BCE_labels = labels.float()

                    outputs = cnn(data)
                    loss = criterion_instr(outputs, BCE_labels)
                    val_loss += loss.item()

                    probabilities = torch.sigmoid(outputs)
                    predicted_labels = (probabilities >= 0.5).int()

                    correct_predictions = (predicted_labels == labels.int()).float()
                    accuracy_per_sample = correct_predictions.mean(dim=1)
                    batch_accuracy = accuracy_per_sample.mean().item()

                    val_correct += batch_accuracy * labels.size(0)
                    val_total += labels.size(0)

            # Calculate validation metrics
            avg_val_loss = val_loss / len(valid_loader)
            val_accuracy = val_correct / val_total
            val_losses.append(avg_val_loss)
            val_accuracies.append(val_accuracy)

            # Update learning rate scheduler and report any reduction
            lr_before_step = optimizer.param_groups[0]["lr"]
            scheduler.step(avg_val_loss)
            current_lr = optimizer.param_groups[0]["lr"]
            if current_lr < lr_before_step:
                print(f"Epoch {epoch+1}: reducing learning rate to {current_lr:.4e}")

            # Log epoch results
            with open(os.path.join(experiment_dir, "training_log.txt"), "a") as f:
                f.write(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {avg_train_loss:.4f}, "
                        f"Train Acc: {train_accuracy:.4f}, Val Loss: {avg_val_loss:.4f}, "
                        f"Val Acc: {val_accuracy:.4f}, LR: {current_lr:.6f}\n")

            # Save best model and check for early stopping
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                torch.save(cnn.state_dict(), os.path.join(experiment_dir, "best_model.pth"))
                early_stop_counter = 0
            else:
                early_stop_counter += 1

            # Check for early stopping
            if early_stop_counter >= early_stop_patience:
                print(f"Early stopping at epoch {epoch+1}")
                with open(os.path.join(experiment_dir, "training_log.txt"), "a") as f:
                    f.write(f"Early stopping triggered at epoch {epoch+1}\n")
                stopped_epoch = epoch + 1
                break

        # Save training curves
        plt.figure(figsize=(12, 5))
        plt.subplot(1, 2, 1)
        plt.plot(train_losses, label='Training Loss')
        plt.plot(val_losses, label='Validation Loss')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.title(f'Loss Curves - Seed {SEED}, {config_name}')
        plt.legend()

        plt.subplot(1, 2, 2)
        plt.plot(train_accuracies, label='Training Accuracy')
        plt.plot(val_accuracies, label='Validation Accuracy')
        plt.xlabel('Epochs')
        plt.ylabel('Accuracy')
        plt.title(f'Accuracy Curves - Seed {SEED}, {config_name}')
        plt.legend()

        plt.tight_layout()
        plt.savefig(os.path.join(experiment_dir, 'training_curves.png'))
        plt.close()

        # Test best model on the test set. map_location keeps the load working
        # when the checkpoint was written on a different device.
        best_model_path = os.path.join(experiment_dir, "best_model.pth")
        cnn.load_state_dict(torch.load(best_model_path, map_location=device))

        cnn.eval()
        final_predictions = []
        final_labels = []

        with open(os.path.join(experiment_dir, "predictions.txt"), "w") as pred_file:
            pred_file.write("Batch-wise Model Predictions vs. Actual Labels\n")
            pred_file.write("=" * 80 + "\n")

            with torch.no_grad():
                for data, labels, audio_idxs in test_loader:
                    data = data.to(device)
                    # The first slice's label and audio id stand for the whole batch.
                    # NOTE(review): assumes TestBatchSampler yields the slices of a
                    # single audio clip per batch — confirm.
                    true_label = labels[0].int().numpy()

                    outputs = cnn(data)
                    probabilities = torch.sigmoid(outputs)

                    # Sum across all slices for this audio
                    sum_prob = torch.sum(probabilities, dim=0)

                    # Normalize the sum by its largest class value
                    max_sum = torch.max(sum_prob)
                    if max_sum > 0:  # Avoid division by zero
                        normalized_sum_prob = sum_prob / max_sum
                    else:
                        normalized_sum_prob = sum_prob

                    # Apply threshold
                    predicted = (normalized_sum_prob.cpu() >= 0.5).int().numpy()

                    final_predictions.append(predicted)
                    final_labels.append(true_label)

                    # Write predictions to file
                    pred_file.write(f"Audio ID: {audio_idxs[0].item()}\n")
                    pred_file.write(f"Predicted: {predicted.tolist()}\n")
                    pred_file.write(f"Actual: {true_label.tolist()}\n")
                    pred_file.write("-" * 80 + "\n")

        # Calculate metrics
        final_predictions = np.array(final_predictions)
        final_labels = np.array(final_labels)

        f1_macro = f1_score(final_labels, final_predictions, average='macro')
        f1_micro = f1_score(final_labels, final_predictions, average='micro')
        hl = hamming_loss(final_labels, final_predictions)

        print(f"\nResults for seed {SEED}, config {config_name}:")
        print(f"F1 Macro: {f1_macro:.4f}")
        print(f"F1 Micro: {f1_micro:.4f}")
        print(f"Hamming Loss: {hl:.4f}")

        # Save results
        with open(os.path.join(experiment_dir, "results.txt"), "w") as f:
            f.write(f"Seed: {SEED}\n")
            f.write(f"Configuration: {config_name}\n")
            f.write(f"F1 Macro: {f1_macro:.4f}\n")
            f.write(f"F1 Micro: {f1_micro:.4f}\n")
            f.write(f"Hamming Loss: {hl:.4f}\n")
            f.write(f"Best Validation Loss: {best_val_loss:.4f}\n")
            f.write(f"Best Validation Accuracy: {max(val_accuracies):.4f}\n")
            f.write(f"Stopped at epoch: {stopped_epoch}/{num_epochs}\n")

        # Generate classification report
        report_str = classification_report(final_labels, final_predictions)
        with open(os.path.join(experiment_dir, "classification_report.txt"), "w") as f:
            f.write(f"F1 Macro: {f1_macro:.4f}\n")
            f.write(f"F1 Micro: {f1_micro:.4f}\n")
            f.write(f"Hamming Loss: {hl:.4f}\n\n")
            f.write(report_str)

        # Store results for this seed+config
        seed_results["config_names"].append(config_name)
        seed_results["val_accuracies"].append(max(val_accuracies))
        seed_results["f1_macro"].append(f1_macro)
        seed_results["f1_micro"].append(f1_micro)
        seed_results["early_stop_epoch"].append(stopped_epoch)

    # Create comparative visualization for this seed
    plt.figure(figsize=(14, 8))

    # Plot F1 scores
    plt.subplot(2, 1, 1)
    plt.bar(np.arange(len(seed_results["config_names"])) - 0.15,
            seed_results["f1_macro"], width=0.3, label="F1 Macro")
    plt.bar(np.arange(len(seed_results["config_names"])) + 0.15,
            seed_results["f1_micro"], width=0.3, label="F1 Micro")
    plt.xticks(range(len(seed_results["config_names"])), seed_results["config_names"], rotation=45)
    plt.ylabel("F1 Score")
    plt.title(f"F1 Scores Across Layer Freezing Configurations (Seed {SEED})")
    plt.legend()

    # Plot validation accuracies
    plt.subplot(2, 1, 2)
    plt.bar(range(len(seed_results["config_names"])), seed_results["val_accuracies"])
    plt.xticks(range(len(seed_results["config_names"])), seed_results["config_names"], rotation=45)
    plt.ylabel("Best Validation Accuracy")
    plt.title(f"Validation Accuracy Across Layer Freezing Configurations (Seed {SEED})")

    plt.tight_layout()
    plt.savefig(os.path.join(finetune_base_dir, f"seed_{SEED}_comparison.png"))
    plt.close()

    # Save seed results to json
    with open(os.path.join(finetune_base_dir, f"seed_{SEED}_results.json"), "w") as f:
        json.dump(seed_results, f, indent=4)

    # Store in all_results
    all_results["f1_macro"][SEED] = seed_results["f1_macro"]
    all_results["f1_micro"][SEED] = seed_results["f1_micro"]
    all_results["val_accuracy"][SEED] = seed_results["val_accuracies"]

# After all seeds, compute average performance and standard deviation for each configuration
def _mean_and_std_across_seeds(per_seed_scores):
    """Return ``(means, stds)`` lists with one entry per configuration.

    ``per_seed_scores`` maps each seed to a list holding one score per
    configuration, in ``layer_configs`` order. Statistics are taken across
    ``seeds`` with ``np.mean`` and ``np.std``.
    """
    means, stds = [], []
    for config_idx in range(len(layer_configs)):
        scores = [per_seed_scores[seed][config_idx] for seed in seeds]
        means.append(np.mean(scores))
        stds.append(np.std(scores))
    return means, stds


# After all seeds, compute average performance and standard deviation for each configuration
avg_f1_macro, std_f1_macro = _mean_and_std_across_seeds(all_results["f1_macro"])
avg_f1_micro, std_f1_micro = _mean_and_std_across_seeds(all_results["f1_micro"])
avg_val_acc, std_val_acc = _mean_and_std_across_seeds(all_results["val_accuracy"])

# Add averages and standard deviations to results
all_results.update({
    "avg_f1_macro": avg_f1_macro,
    "std_f1_macro": std_f1_macro,
    "avg_f1_micro": avg_f1_micro,
    "std_f1_micro": std_f1_micro,
    "avg_val_accuracy": avg_val_acc,
    "std_val_accuracy": std_val_acc,
})

# Save overall results
overall_json_path = os.path.join(finetune_base_dir, "overall_results.json")
with open(overall_json_path, "w") as f:
    json.dump(all_results, f, indent=4)

# Create final visualization with error bars: two bars per configuration
fig, ax = plt.subplots(figsize=(14, 8))
x = np.arange(len(layer_configs))
width = 0.35

ax.bar(x - width/2, avg_f1_macro, width,
       yerr=std_f1_macro, capsize=5, label="F1 Macro", color='blue', alpha=0.7)
ax.bar(x + width/2, avg_f1_micro, width,
       yerr=std_f1_micro, capsize=5, label="F1 Micro", color='green', alpha=0.7)

ax.set_xlabel("Layer Freezing Configuration")
ax.set_ylabel("F1 Score (Average over Seeds)")
ax.set_title("Performance Comparison Across Layer Freezing Strategies")
ax.set_xticks(x)
ax.set_xticklabels([cfg["name"] for cfg in layer_configs], rotation=45, ha="right")
ax.legend()
ax.grid(True, linestyle='--', alpha=0.3)
fig.tight_layout()
fig.savefig(os.path.join(finetune_base_dir, "overall_comparison.png"))
plt.close(fig)

print(f"\nAll experiments complete! Results saved to {finetune_base_dir}")



RUNNING EXPERIMENTS WITH SEED 1
