# 1. Set up 

In [1]:
# Cell 1

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
# Use AutoProcessor for Wav2Vec2-BERT - it bundles feature_extractor and tokenizer (if needed)
from transformers import AutoModelForAudioClassification, AutoProcessor

from torch.optim import AdamW
import pandas as pd
import numpy as np
import os
import sys
import ast # For parsing string representations of lists/arrays
import logging
import time
from sklearn.metrics import hamming_loss, jaccard_score, f1_score # Add more as needed
from tqdm.notebook import tqdm # Use notebook version of tqdm
import librosa # Needed for loading raw audio now



# --- Project Setup ---
# Detect if running in notebook or script to adjust path

cwd = os.getcwd()
PROJECT_ROOT = os.path.abspath(os.path.join(cwd, '../../')) # NOTE: remember to change if change the directory structure



print(f"PROJECT_ROOT detected as: {PROJECT_ROOT}")
if PROJECT_ROOT not in sys.path:
    print(f"Adding {PROJECT_ROOT} to sys.path")
    sys.path.append(PROJECT_ROOT)

# --- Config and Utils ---
try:
    import config # Import your configuration file
    # Optionally import utils if needed, e.g., for get_audio_path if not defined here
    # import src.utils as utils
except ModuleNotFoundError:
     print("ERROR: Cannot import config or utils. Make sure PROJECT_ROOT is correct and src is importable.")
     # Or add src to path: sys.path.insert(0, os.path.join(PROJECT_ROOT, 'src'))
     # import config
     # import utils


# --- Setup Logging ---
for handler in logging.root.handlers[:]: logging.root.removeHandler(handler) # Clear previous
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s',
                    handlers=[logging.StreamHandler(sys.stdout)])

print("Imports and basic setup complete.")

  from .autonotebook import tqdm as notebook_tqdm


PROJECT_ROOT detected as: /workspace/musicClaGen
Adding /workspace/musicClaGen to sys.path
/workspace/musicClaGen
Imports and basic setup complete.


# 2. Config 

In [2]:
# Cell 2
# --- Load Config ---
# Ensure config.py has the correct paths in the PATHS dict
manifest_path = config.PATHS.get('SMALL_MULTILABEL_PATH', os.path.join(config.PATHS['PROCESSED_DATA_DIR'], 'small_subset_multihot.csv')) # Use .get for safety
genre_list_path = config.PATHS.get('GENRE_LIST_PATH', os.path.join(config.PATHS['PROCESSED_DATA_DIR'], 'unified_genres.txt'))
model_save_dir = config.PATHS['MODELS_DIR']

# Ensure config.py has MODEL_PARAMS dict with model_checkpoint
model_checkpoint = config.MODEL_PARAMS['model_checkpoint'] # e.g., "facebook/w2v-bert-2.0" - VERIFY!
learning_rate = config.MODEL_PARAMS['learning_rate']
batch_size = config.MODEL_PARAMS['batch_size'] # Use the small BS for notebook test
num_epochs_debug = 1 # <<<--- RUN ONLY 1 EPOCH FOR DEBUGGING ---<<<
weight_decay = config.MODEL_PARAMS['weight_decay']
gradient_accumulation_steps = config.MODEL_PARAMS['gradient_accumulation_steps']

# --- Load unified genre list ---
try:
    with open(genre_list_path, 'r') as f:
        unified_genres = [line.strip() for line in f if line.strip()]
    num_labels = len(unified_genres) # should be the number of labels defined in the unified_genres.txt file, in this case it should be 22.
    logging.info(f"Loaded {num_labels} unified genres from {genre_list_path}")
    if num_labels == 0: raise ValueError("Genre list is empty!")
except Exception as e:
    logging.error(f"Failed to load or process unified genre list: {e}", exc_info=True)
    raise SystemExit("Cannot proceed without genre list.")

# --- Setup Device ---
device = torch.device(config.DEVICE if torch.cuda.is_available() else "cpu")
logging.info(f"Using device: {device}")
if not torch.cuda.is_available() and config.DEVICE=="cuda":
     logging.warning("CUDA selected but not available, falling back to CPU.")

# --- Create Save Directory ---
os.makedirs(model_save_dir, exist_ok=True)

2025-05-04 15:10:22,028 - INFO - Loaded 22 unified genres from /workspace/musicClaGen/data/processed/unified_genres.txt
2025-05-04 15:10:22,030 - INFO - Using device: cuda


In [3]:
print(manifest_path)

/workspace/musicClaGen/data/processed/small_subset_multihot.csv


In [4]:
# Cell: Comprehensive Training Logger Setup

import os
import json
import time
import matplotlib.pyplot as plt
from datetime import datetime
import logging

class TrainingLogger:
    """Comprehensive training logger and metrics tracker"""
    
    def __init__(self, output_dir, model_name, config=None):
        """
        Initialize logger with output directory and model name
        
        Args:
            output_dir: Base directory for saving logs and checkpoints
            model_name: Name of the model being trained
            config: Dictionary of configuration parameters
        """
        # Create timestamp for unique run identification
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.run_name = f"{model_name}_{timestamp}"
        
        # Create output directories
        self.base_dir = os.path.join(output_dir, self.run_name)
        self.checkpoint_dir = os.path.join(self.base_dir, "checkpoints")
        self.log_dir = os.path.join(self.base_dir, "logs")
        self.plot_dir = os.path.join(self.base_dir, "plots")
        
        os.makedirs(self.base_dir, exist_ok=True)
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        os.makedirs(self.log_dir, exist_ok=True)
        os.makedirs(self.plot_dir, exist_ok=True)
        
        # Initialize metrics storage
        self.metrics = {
            "train_loss": [],
            "val_loss": [],
            "hamming_loss": [],
            "jaccard_samples": [],
            "f1_micro": [],
            "f1_macro": [],
            "learning_rate": [],
            "epochs": [],
            "steps": [],
            "best_metrics": {},
            "training_time": 0
        }
        
        # Save configuration
        self.config = config
        if config:
            with open(os.path.join(self.base_dir, "config.json"), 'w') as f:
                json.dump(config, f, indent=2)
        
        # Setup file logger
        self.setup_file_logger()
        
        # Log initialization
        logging.info(f"Initialized training run: {self.run_name}")
        logging.info(f"Output directory: {self.base_dir}")
    
    def setup_file_logger(self):
        """Setup file logging"""
        log_file = os.path.join(self.log_dir, "training.log")
        file_handler = logging.FileHandler(log_file)
        file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
        
        # Add file handler to root logger
        root_logger = logging.getLogger()
        root_logger.addHandler(file_handler)
    
    def log_epoch(self, epoch, train_loss, val_metrics, learning_rate, step):
        """Log metrics for an epoch"""
        self.metrics["epochs"].append(epoch)
        self.metrics["train_loss"].append(train_loss)
        self.metrics["val_loss"].append(val_metrics.get("eval_loss", 0))
        self.metrics["hamming_loss"].append(val_metrics.get("hamming_loss", 0))
        self.metrics["jaccard_samples"].append(val_metrics.get("jaccard_samples", 0))
        self.metrics["f1_micro"].append(val_metrics.get("f1_micro", 0))
        self.metrics["f1_macro"].append(val_metrics.get("f1_macro", 0))
        self.metrics["learning_rate"].append(learning_rate)
        self.metrics["steps"].append(step)
        
        # Log to file
        logging.info(f"Epoch {epoch} metrics:")
        logging.info(f"  Train Loss: {train_loss:.4f}")
        for name, value in val_metrics.items():
            if isinstance(value, (int, float)):
                logging.info(f"  {name.replace('_', ' ').title()}: {value:.4f}")
        
        # Save metrics after each epoch
        self.save_metrics()
        
        # Generate plots
        self.generate_plots()
    
    def update_best_metrics(self, epoch, step, val_metrics, model_path):
        """Update best metrics if current results are better"""
        current_metric = val_metrics.get("hamming_loss", float('inf'))
        
        # For hamming loss, lower is better
        if not self.metrics["best_metrics"] or current_metric < self.metrics["best_metrics"].get("hamming_loss", float('inf')):
            self.metrics["best_metrics"] = {
                "epoch": epoch,
                "step": step,
                "model_checkpoint": model_path,
                **{k: v for k, v in val_metrics.items() if isinstance(v, (int, float))}
            }
            logging.info(f"New best model at epoch {epoch}, step {step} with hamming_loss: {current_metric:.4f}")
            return True
        return False
    
    def save_metrics(self):
        """Save metrics to JSON file"""
        metrics_file = os.path.join(self.log_dir, "metrics.json")
        with open(metrics_file, 'w') as f:
            json.dump(self.metrics, f, indent=2)
    
    def save_trainer_state(self, epoch, step, optimizer_state=None, scheduler_state=None):
        """Save trainer state"""
        trainer_state = {
            "epoch": epoch,
            "step": step,
            "best_metrics": self.metrics["best_metrics"],
            "timestamp": datetime.now().isoformat(),
            "training_time": self.metrics["training_time"]
        }
        
        # Add optimizer and scheduler states if provided
        if optimizer_state:
            optimizer_state_file = os.path.join(self.log_dir, "optimizer_state.pt")
            torch.save(optimizer_state, optimizer_state_file)
            trainer_state["optimizer_state_file"] = optimizer_state_file
            
        if scheduler_state:
            scheduler_state_file = os.path.join(self.log_dir, "scheduler_state.pt")
            torch.save(scheduler_state, scheduler_state_file)
            trainer_state["scheduler_state_file"] = scheduler_state_file
            
        state_file = os.path.join(self.base_dir, "trainer_state.json")
        with open(state_file, 'w') as f:
            json.dump(trainer_state, f, indent=2)
    
    def generate_plots(self):
        """Generate and save plots of training metrics"""
        # Loss plot
        plt.figure(figsize=(10, 6))
        plt.plot(self.metrics["epochs"], self.metrics["train_loss"], label="Train Loss")
        plt.plot(self.metrics["epochs"], self.metrics["val_loss"], label="Validation Loss")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.title("Training and Validation Loss")
        plt.legend()
        plt.grid(True)
        plt.savefig(os.path.join(self.plot_dir, "loss_plot.png"))
        plt.close()
        
        # Metrics plot
        plt.figure(figsize=(12, 8))
        plt.subplot(2, 2, 1)
        plt.plot(self.metrics["epochs"], self.metrics["hamming_loss"])
        plt.xlabel("Epoch")
        plt.ylabel("Hamming Loss")
        plt.title("Hamming Loss")
        plt.grid(True)
        
        plt.subplot(2, 2, 2)
        plt.plot(self.metrics["epochs"], self.metrics["jaccard_samples"])
        plt.xlabel("Epoch")
        plt.ylabel("Jaccard Score")
        plt.title("Jaccard Score (Samples)")
        plt.grid(True)
        
        plt.subplot(2, 2, 3)
        plt.plot(self.metrics["epochs"], self.metrics["f1_micro"])
        plt.xlabel("Epoch")
        plt.ylabel("F1 Score")
        plt.title("F1 Score (Micro)")
        plt.grid(True)
        
        plt.subplot(2, 2, 4)
        plt.plot(self.metrics["epochs"], self.metrics["f1_macro"])
        plt.xlabel("Epoch")
        plt.ylabel("F1 Score")
        plt.title("F1 Score (Macro)")
        plt.grid(True)
        
        plt.tight_layout()
        plt.savefig(os.path.join(self.plot_dir, "metrics_plot.png"))
        plt.close()
        
        # Learning rate plot
        if len(self.metrics["steps"]) > 0:
            plt.figure(figsize=(10, 6))
            plt.plot(self.metrics["steps"], self.metrics["learning_rate"])
            plt.xlabel("Step")
            plt.ylabel("Learning Rate")
            plt.title("Learning Rate Schedule")
            plt.grid(True)
            plt.savefig(os.path.join(self.plot_dir, "lr_plot.png"))
            plt.close()
    
    def save_model_checkpoint(self, model, epoch, step, optimizer=None, scheduler=None, is_best=False):
        """Save model checkpoint"""
        checkpoint_name = f"checkpoint-{epoch}"
        checkpoint_path = os.path.join(self.checkpoint_dir, checkpoint_name)
        os.makedirs(checkpoint_path, exist_ok=True)
        
        # Save model state dict
        model_path = os.path.join(checkpoint_path, "model.pth")
        torch.save(model.state_dict(), model_path)
        
        # Save optimizer and scheduler if provided
        if optimizer:
            optimizer_path = os.path.join(checkpoint_path, "optimizer.pth")
            torch.save(optimizer.state_dict(), optimizer_path)
        
        if scheduler:
            scheduler_path = os.path.join(checkpoint_path, "scheduler.pth")
            torch.save(scheduler.state_dict(), scheduler_path)
        
        # Save model config
        if hasattr(model, 'config'):
            config_path = os.path.join(checkpoint_path, "config.json")
            with open(config_path, 'w') as f:
                json.dump(model.config.to_dict(), f, indent=2)
        
        # Create a symbolic link or copy for the best model
        if is_best:
            best_path = os.path.join(self.base_dir, "best_model")
            # If we can use symlinks
            try:
                if os.path.exists(best_path):
                    if os.path.islink(best_path):
                        os.unlink(best_path)
                    else:
                        os.rmdir(best_path)
                os.symlink(checkpoint_path, best_path)
                logging.info(f"Created symbolic link to best model at {best_path}")
            except (OSError, NotImplementedError):
                # Fallback: copy the model file
                best_model_path = os.path.join(self.base_dir, "best_model.pth")
                torch.save(model.state_dict(), best_model_path)
                logging.info(f"Saved copy of best model to {best_model_path}")
            
        logging.info(f"Saved model checkpoint to {checkpoint_path}")
        return checkpoint_path
    
    def finish_training(self, total_time):
        """Log end of training and final metrics"""
        self.metrics["training_time"] = total_time
        logging.info(f"Training completed in {total_time:.2f} seconds")
        logging.info(f"Best model: {self.metrics['best_metrics'].get('model_checkpoint', 'None')}")
        
        # Save final metrics
        self.save_metrics()
        
        # Generate final plots
        self.generate_plots()
        
        # Save final trainer state
        self.save_trainer_state(
            epoch=self.metrics["epochs"][-1] if self.metrics["epochs"] else 0,
            step=self.metrics["steps"][-1] if self.metrics["steps"] else 0
        )

logging.info("TrainingLogger class defined.")

2025-05-04 15:10:22,547 - INFO - TrainingLogger class defined.


# 3. Dataset Class Definition + Data Collator

In [5]:
# # Cell 3: Dataset Class Definition (Raw Audio Version) This cell uses the regex parser to parse the multi_hot_label string back into a list of integers.



# # Define(recollect)the regex parser from preprocess.py if needed, 
# # otherwise use ast.literal_eval--- 
# # NOTE: After changing usage.ipynb 05/03/2025, should fall back to ast.literal_eval now. Clean code later

# import re

# def parse_numpy_array_string(array_str):
#     """
#     Parse strings like '[np.float32(1.0), np.float32(0.0), ...]' into a list of integers.
#     This is needed because ast.literal_eval cannot handle 'np.float32()' in the string.
#     """
#     if not isinstance(array_str, str):
#         return []
    
#     try:
#         # Extract all the float values using regular expressions
#         float_matches = re.findall(r'np\.float32\((\d+\.\d+)\)', array_str)
        
#         # Convert matches to integers (1.0 -> 1, 0.0 -> 0)
#         values = []
#         for match in float_matches:
#             value = float(match)
#             # Convert to integer if it's 0.0 or 1.0
#             if value == 1.0:
#                 values.append(1)
#             elif value == 0.0:
#                 values.append(0)
#             else:
#                 values.append(value)  # Keep as float if not 0 or 1
                
#         return values
#     except Exception as e:
#         logging.warning(f"Error parsing array string: {e}")
#         return []

# class FMARawAudioDataset(Dataset):
#     """
#     Loads raw audio waveforms and labels from manifest, uses Hugging Face
#     feature extractor (like ASTFeatureExtractor or Wav2Vec2Processor) on the fly.
#     """
#     def __init__(self, manifest_path, feature_extractor):
#         """
#         Args:
#             manifest_path (str): Path to the final manifest CSV file.
#             feature_extractor: Initialized Hugging Face AutoFeatureExtractor or AutoProcessor.
#         """
#         logging.info(f"Initializing FMARawAudioDataset from: {manifest_path}")
#         if feature_extractor is None:
#              raise ValueError("FMARawAudioDataset requires a feature_extractor/processor instance.")

#         self.feature_extractor = feature_extractor
#         # Get target sampling rate directly from the extractor/processor
#         try:
#              # Works for Wav2Vec2Processor, ASTFeatureExtractor, etc.
#              self.target_sr = self.feature_extractor.sampling_rate
#              logging.info(f"Target sampling rate set from feature extractor: {self.target_sr} Hz")
#         except AttributeError:
#              logging.warning("Could not get sampling_rate from feature_extractor, using config.")
#              # Fallback to config if needed, but ensuring match is crucial
#              self.target_sr = config.PREPROCESSING_PARAMS['sample_rate']


#         logging.info(f"Loading manifest from: {manifest_path}")
#         try:
#             self.manifest = pd.read_csv(manifest_path)
#             # Ensure index is set if needed elsewhere, or use default range index
#             if 'track_id' in self.manifest.columns:
#                  self.manifest = self.manifest.set_index('track_id', drop=False)

#             # --- Parse the 'multi_hot_label' string back into a list ---
#             # Here: if we decide to use raw audio, we use regex parser; 
#             #       if we decide to use mel spectrogram, we use ast.literal_eval

#             # Choose the correct parser based on how labels were saved in the CSV
#             # If saved as '[1.0, 0.0,...]' use ast.literal_eval
#             # label_parser = ast.literal_eval
#             # If saved as '[np.float32(1.0)...]' uncomment and use regex parser
#             label_parser = parse_numpy_array_string

#             self.manifest['multi_hot_label'] = self.manifest['multi_hot_label'].apply(label_parser)
#             logging.info(f"Loaded and parsed manifest with {len(self.manifest)} entries.")
#             # Check the first parsed label
#             logging.info(f"Example parsed label (first entry): {self.manifest['multi_hot_label'].iloc[0]}")

#         except Exception as e:
#             logging.error(f"Error loading or parsing manifest {manifest_path}: {e}", exc_info=True)
#             raise

#     def __len__(self):
#         """Returns the total number of samples in the dataset."""
#         return len(self.manifest)

#     def __getitem__(self, idx):
#         """
#         Loads raw audio for index idx, processes it with the feature extractor,
#         and returns the processed inputs and labels.
#         """
#         if torch.is_tensor(idx): idx = idx.tolist() # Handle tensor indices

#         # Get the row data from the manifest
#         row = self.manifest.iloc[idx]
#         track_id = row.get('track_id', self.manifest.index[idx]) # Get track_id safely
#         label_vector = row['multi_hot_label'] # Already parsed list/array

#         # Construct absolute audio path if necessary
#         audio_path = row['audio_path']

#         #NOTE: originally, the mel-spectrogram's path is relative  but the raw audio's path is absolute, so we need to make sure the audio_path is absolute
#         # So we are check if the audio_path is absolute or relative in case we load the wrong data, if it's relative, we need to join it with the PROJECT_ROOT
#         if not os.path.isabs(audio_path):
#              # Assumes path in manifest is relative to PROJECT_ROOT
#              audio_path = os.path.join(config.PROJECT_ROOT, audio_path)

#         try:
#             # --- 1. Load RAW Audio Waveform ---
#             # Load full 30s clip at the TARGET sample rate required by the processor
#             waveform, loaded_sr = librosa.load(
#                 audio_path,
#                 sr=self.target_sr, # Use processor's sampling rate
#                 duration=30.0     # Load the full 30 seconds
#             )
#             # Ensure minimum length if needed (though duration should handle it)
#             min_samples = int(0.1 * self.target_sr) # Example: require at least 0.1s
#             if len(waveform) < min_samples:
#                  raise ValueError(f"Audio signal for track {track_id} too short after loading.")

#             # --- 2. Apply Feature Extractor ---
#             # Pass the raw waveform numpy array
#             # The extractor handles normalization, padding/truncation, tensor conversion
            
#             max_length = 5000

#             inputs = self.feature_extractor(
#                 waveform,
#                 sampling_rate=self.target_sr,
#                 return_tensors="pt",
#                 return_attention_mask=True # Request attention mask
#             )

#             # --- 3. Prepare Outputs ---
#             # Squeeze unnecessary batch dimension added by the extractor
#             # Key name ('input_values', 'input_features') depends on the specific extractor
#             feature_tensor = inputs.get('input_values', inputs.get('input_features'))
#             if feature_tensor is None:
#                 raise KeyError("Expected 'input_values' or 'input_features' key from feature_extractor output.")
#             feature_tensor = feature_tensor.squeeze(0) # Remove batch dim -> [Channels?, Freq?, Time] or [SeqLen, Dim]

#             attention_mask = inputs.get('attention_mask', None)
#             if attention_mask is not None:
#                  attention_mask = attention_mask.squeeze(0)

#             # Convert label list/array to float tensor for BCE loss
#             label_tensor = torch.tensor(label_vector, dtype=torch.float32)

#             # Return dictionary matching model's expected input names
#             model_input_dict = {"labels": label_tensor}
#             # Use the key the feature extractor provided
#             if 'input_values' in inputs:
#                  model_input_dict['input_values'] = feature_tensor
#             elif 'input_features' in inputs:
#                  model_input_dict['input_features'] = feature_tensor

#             if attention_mask is not None:
#                  model_input_dict['attention_mask'] = attention_mask

#             return model_input_dict

#         except FileNotFoundError:
#              logging.error(f"Audio file not found for track {track_id} at {audio_path}")
#              raise # Or implement skipping logic with collate_fn
#         except Exception as e:
#             logging.error(f"Error loading/processing track {track_id} at {audio_path}: {e}", exc_info=True)
#             raise # Or implement skipping logic


# print("FMARawAudioDataset class defined.")

In [6]:
# Cell 3: Dataset Class Definition (Raw Audio Version - This cell uses the ast.literal_eval parser to parse the multi_hot_label string back into a list of integers.)

import torch
from torch.utils.data import Dataset
import pandas as pd
import numpy as np
import os
import ast # For parsing label string '[1.0, 0.0,...]'
import re  # Keep import for the commented out function below
import logging
import librosa
# Ensure config is imported from a previous cell or uncomment:
# import config

# --- Optional: Keep custom parser commented out for reference ---
# # Define(recollect)the regex parser from preprocess.py if needed,
# # otherwise use ast.literal_eval---
# # NOTE: After changing usage.ipynb 05/03/2025, should fall back to ast.literal_eval now. Clean code later
# def parse_numpy_array_string(array_str):
#     """
#     Parse strings like '[np.float32(1.0), np.float32(0.0), ...]' into a list of integers.
#     This is needed because ast.literal_eval cannot handle 'np.float32()' in the string.
#     """
#     if not isinstance(array_str, str): return []
#     try:
#         # Match digits, optionally followed by a decimal and more digits
#         float_matches = re.findall(r'np\.float32\(([\d\.]+)\)', array_str)
#         values = []
#         for match_str in float_matches:
#             value = float(match_str) # Convert string match to float
#             values.append(1.0 if value == 1.0 else 0.0) # Store as float 0.0 or 1.0
#         return values
#     except Exception as e:
#         logging.warning(f"Error parsing array string: {e}")
#         return []
# --- End commented out parser ---


class FMARawAudioDataset(Dataset):
    """
    Loads raw audio waveforms and labels from manifest, uses Hugging Face
    feature extractor (like ASTFeatureExtractor or Wav2Vec2Processor/AutoFeatureExtractor) on the fly.
    Assumes padding/truncation will be handled by a collate function.
    """
    def __init__(self, manifest_path, feature_extractor):
        """
        Args:
            manifest_path (str): Path to the final manifest CSV file (e.g., small_subset_multihot.csv).
            feature_extractor: Initialized Hugging Face AutoFeatureExtractor instance.
        """
        # Ensure num_labels is available globally or passed if needed for verification
        global num_labels
        if 'num_labels' not in globals():
             logging.error("Global variable 'num_labels' not found. Load it first (e.g., from Cell 2).")
             # Alternative: pass num_labels as an argument to __init__

        logging.info(f"Initializing FMARawAudioDataset from: {manifest_path}")
        if feature_extractor is None:
             raise ValueError("FMARawAudioDataset requires a feature_extractor instance.")

        self.feature_extractor = feature_extractor
        try:
             self.target_sr = self.feature_extractor.sampling_rate
             logging.info(f"Target sampling rate set from feature extractor: {self.target_sr} Hz")
        except AttributeError:
             logging.error("Could not get sampling_rate from feature_extractor.", exc_info=True)
             raise

        logging.info(f"Loading manifest from: {manifest_path}")
        try:
            self.manifest = pd.read_csv(manifest_path)
            # Set index to track_id AFTER loading, keep column too if needed elsewhere
            if 'track_id' in self.manifest.columns:
                 self.manifest = self.manifest.set_index('track_id', drop=False) # Keep column if row.get('track_id'...) is used
            else:
                 logging.warning("Manifest CSV does not contain 'track_id' column. Using DataFrame index.")
                 # Make sure index IS the track_id
                 if not pd.api.types.is_integer_dtype(self.manifest.index):
                      logging.warning("Manifest index is not integer type. Ensure it matches track IDs.")


            # --- Parse the 'multi_hot_label' string back into a list ---
            # NOTE: After changing usage.ipynb 05/03/2025, should fall back to ast.literal_eval now. Clean code later
            # Use ast.literal_eval assuming labels were saved as standard list strings '[1.0, 0.0,...]'
            logging.info("Attempting to parse 'multi_hot_label' column using ast.literal_eval...")
            label_parser = ast.literal_eval # <<<--- Using ast.literal_eval
            # label_parser = parse_numpy_array_string # Keep commented out as requested

            label_col_name = 'multi_hot_label'
            if label_col_name not in self.manifest.columns:
                 raise KeyError(f"Column '{label_col_name}' not found in manifest CSV at {manifest_path}")

            self.manifest[label_col_name] = self.manifest[label_col_name].apply(label_parser)

            # --- Verification step ---
            first_label = self.manifest[label_col_name].iloc[0] # Use iloc[0] here to get FIRST row for checking
            if not isinstance(first_label, list):
                 raise TypeError(f"Parsed label is not a list, check parser/CSV format. Got type: {type(first_label)}")
            # Check length against num_labels loaded in Cell 2
            if len(first_label) != num_labels:
                 logging.error(f"FATAL: Parsed label length ({len(first_label)}) does not match expected num_labels ({num_labels}). Check parsing or unified_genres.txt.")
                 raise ValueError("Parsed label length mismatch.")
            logging.info(f"Example parsed label verified (type {type(first_label)}, length {len(first_label)}): {str(first_label)[:100]}...")
            # --- End Verification ---

            logging.info(f"Loaded and parsed manifest with {len(self.manifest)} entries.")

        except FileNotFoundError:
             logging.error(f"Manifest file not found: {manifest_path}", exc_info=True)
             raise
        except Exception as e:
            logging.error(f"Error loading or parsing manifest {manifest_path}: {e}", exc_info=True)
            raise

    def __len__(self):
        """Returns the total number of samples in the dataset."""
        return len(self.manifest)

    def __getitem__(self, idx):
        """
        Loads raw audio for index 'idx' (which is the track_id/index label),
        processes it with the feature extractor,
        and returns the processed inputs and labels.
        """
        if torch.is_tensor(idx): idx = idx.tolist() # Handle tensor indices

        # --- Use idx directly as track_id BEFORE main try block ---
        track_id = idx
        # ---------------------------------------------------------

        try:
            # --- Get the row data using .loc with the track_id ---
            row = self.manifest.loc[track_id] # Use .loc with the index label (track_id)
            # ------------------------------------------------------

            # --- Get required data from the row ---
            multi_hot_label = row['multi_hot_label']
            audio_path = row['audio_path']
            # ---------------------------------------

            # Construct absolute audio path if necessary (keep your NOTE)
            # NOTE: originally, the mel-spectrogram's path is relative  but the raw audio's path is absolute, so we need to make sure the audio_path is absolute
            # So we are check if the audio_path is absolute or relative in case we load the wrong data, if it's relative, we need to join it with the PROJECT_ROOT
            if not os.path.isabs(audio_path):
                audio_path = os.path.join(config.PROJECT_ROOT, audio_path)

            # --- 1. Load RAW Audio Waveform ---
            waveform, loaded_sr = librosa.load(
                audio_path,
                sr=self.target_sr, # Use extractor's sampling rate
                duration=30.0      # Load the full 30 seconds
            )
            min_samples = int(0.1 * self.target_sr)
            if len(waveform) < min_samples:
                 logging.warning(f"Audio signal for track {track_id} too short, returning None.")
                 return None # Requires collate_fn to handle None

            # --- 2. Apply Feature Extractor ---
            # Let the Data Collator handle padding/truncation later
            inputs = self.feature_extractor(
                waveform,
                sampling_rate=self.target_sr,
                return_tensors="pt",
                # REMOVED padding/truncation/max_length args
                return_attention_mask=True # Keep requesting mask
            )

            # --- 3. Prepare Outputs ---
            feature_tensor = inputs.get('input_values', inputs.get('input_features'))
            if feature_tensor is None:
                raise KeyError(f"Expected 'input_values' or 'input_features' key from feature_extractor output. Got keys: {inputs.keys()}")
            feature_tensor = feature_tensor.squeeze(0)

            attention_mask = inputs.get('attention_mask', None)
            if attention_mask is not None:
                attention_mask = attention_mask.squeeze(0)

            # Convert label list to float tensor
            label_tensor = torch.tensor(multi_hot_label, dtype=torch.float32)

            # Return dictionary
            model_input_dict = {"labels": label_tensor}
            input_key = 'input_values' if 'input_values' in inputs else 'input_features'
            model_input_dict[input_key] = feature_tensor
            if attention_mask is not None:
                model_input_dict['attention_mask'] = attention_mask

            return model_input_dict

        except KeyError:
             # This might catch if track_id wasn't found by .loc (handled above),
             # or if column names like 'multi_hot_label', 'audio_path' are wrong in CSV
             logging.error(f"KeyError accessing data for track {track_id}. Check manifest columns.", exc_info=True)
             return None
        except FileNotFoundError:
             logging.error(f"Audio file not found for track {track_id} at {audio_path}")
             return None
        except Exception as e:
            # Use the track_id obtained safely before the try block
            logging.error(f"Error loading/processing track {track_id}: {e}", exc_info=True)
            return None # Return None on generic error

print("FMARawAudioDataset class defined (using raw audio, feature extractor, ast.literal_eval for labels, .loc access).")

FMARawAudioDataset class defined (using raw audio, feature extractor, ast.literal_eval for labels, .loc access).


In [7]:
print(model_checkpoint)

facebook/w2v-bert-2.0


In [8]:
# Cell 3.5: Define Data Collator for Padding (Handles None values)

import torch
from dataclasses import dataclass
from typing import Dict, List, Optional, Union
import logging # Add logging

@dataclass
class DataCollatorAudio:
    """
    Data collator that dynamically pads the inputs received Feature Extractor.
    Handles None values returned by the Dataset on error.
    """
    padding_value: float = 0.0 # Standard padding for features/audio

    def __call__(self, features: List[Optional[Dict[str, Union[List[int], torch.Tensor]]]]) -> Dict[str, torch.Tensor]:
        # features is a list of dicts OR None values from __getitem__

        # --- Filter out None entries ---
        valid_features = [f for f in features if f is not None]
        if not valid_features:
             # If all samples in the batch failed, return an empty dictionary
             # The training loop should ideally handle this (e.g., skip batch)
             logging.warning("Collate function received empty batch after filtering Nones.")
             return {}
        # -----------------------------

        # --- Determine keys and pad based on valid features ---
        input_key = 'input_values' if 'input_values' in valid_features[0] else 'input_features'
        input_features = [d[input_key] for d in valid_features]

        # Determine sequence length dimension based on the FIRST valid tensor
        seq_len_dim = -1
        if len(input_features[0].shape) == 2:
            seq_len_dim = 0 if input_features[0].shape[0] > input_features[0].shape[1] else -1
        elif len(input_features[0].shape) == 1:
             seq_len_dim = 0
        else:
             logging.warning(f"Unexpected tensor shape {input_features[0].shape}, assuming seq len is last dim.")

        max_len = max(feat.shape[seq_len_dim] for feat in input_features)

        # Pad each feature tensor to max_len
        padded_features = []
        for feat in input_features:
            pad_width = max_len - feat.shape[seq_len_dim]
            if seq_len_dim == 0 and len(feat.shape)==2: padding = (0, 0, 0, pad_width) # Pad SeqLen dim (dim 0)
            else: padding = (0, pad_width) # Pad last dim (SeqLen)

            padded_feat = torch.nn.functional.pad(feat, padding, mode='constant', value=self.padding_value)
            padded_features.append(padded_feat)

        # Stack the padded features
        batch_input_features = torch.stack(padded_features)
        batch = {input_key: batch_input_features} # Use the correct key

        # Pad 'attention_mask' if present
        if "attention_mask" in valid_features[0] and valid_features[0]["attention_mask"] is not None:
            attention_masks = [d["attention_mask"] for d in valid_features]
            # Assuming mask is 1D [SeqLen] or 2D [1, SeqLen] etc. - pad last dim
            max_mask_len = max(m.shape[-1] for m in attention_masks)
            padded_masks = []
            for mask in attention_masks:
                 pad_width = max_mask_len - mask.shape[-1]
                 padded_mask = torch.nn.functional.pad(mask, (0, pad_width), mode='constant', value=0)
                 padded_masks.append(padded_mask)
            batch["attention_mask"] = torch.stack(padded_masks)

        # Stack Labels
        labels = [d["labels"] for d in valid_features]
        batch["labels"] = torch.stack(labels)

        return batch

# Create an instance of the collator (do this in Cell 4)
# data_collator = DataCollatorAudio()
# print("DataCollatorAudio defined.")

# 4: Load Feature Extractor, Create DataLoaders with Custom Collator

In [9]:
# Cell 4: Load Feature Extractor, Create DataLoaders with Custom Collator

from transformers import AutoFeatureExtractor # Use the correct class

# Ensure FMARawAudioDataset and DataCollatorAudio are defined in previous cells

# --- Load Feature Extractor ---
# (Using model_checkpoint defined in Cell 2)
logging.info(f"Loading feature extractor for: {model_checkpoint}")
try:
    # Load the feature extractor associated with Wav2Vec2-BERT
    feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)
    logging.info("Feature extractor loaded successfully.")
    # Log the expected sample rate
    processor_sr = feature_extractor.sampling_rate
    print(f"Feature extractor expects sample rate: {processor_sr}")
    # Ensure config matches extractor's expected rate
    if config.PREPROCESSING_PARAMS['sample_rate'] != processor_sr:
         logging.warning(f"Config sample rate ({config.PREPROCESSING_PARAMS['sample_rate']}) differs from feature extractor ({processor_sr}). Ensure audio loading uses {processor_sr} Hz.")
         # Update config value if necessary, or ensure Dataset uses processor_sr
         # config.PREPROCESSING_PARAMS['sample_rate'] = processor_sr # Be careful modifying config dynamically

except Exception as e:
    logging.error(f"Could not load feature extractor for {model_checkpoint}. Cannot proceed. Error: {e}", exc_info=True)
    raise SystemExit # Stop execution if extractor fails

# --- Create Full Dataset ---
# Ensure FMARawAudioDataset __init__ accepts feature_extractor
try:
    # Pass the loaded feature_extractor instance
    full_dataset = FMARawAudioDataset(manifest_path, feature_extractor=feature_extractor)
    manifest_df = full_dataset.manifest
except Exception as e:
     logging.error("Failed to instantiate FMARawAudioDataset.", exc_info=True)
     raise SystemExit

# --- Create SMALLER DEBUG Datasets ---
logging.info("Creating DEBUG DataLoaders with small subsets and custom collator...")
try:
    # Get indices for the splits from the manifest
    train_indices = manifest_df[manifest_df['split'] == 'training'].index[:16].tolist() # Small subset for debug
    val_indices = manifest_df[manifest_df['split'] == 'validation'].index[:8].tolist()  # Small subset for debug

    # Create Subset instances
    debug_train_dataset = Subset(full_dataset, train_indices)
    debug_val_dataset = Subset(full_dataset, val_indices)

    # --- Create Data Collator Instance ---
    # (Assumes DataCollatorAudio class is defined in Cell 3.5)
    data_collator = DataCollatorAudio()
    print("DataCollatorAudio instance created.")

    # --- Create DataLoaders using the custom collate_fn ---
    debug_train_dataloader = DataLoader(
        debug_train_dataset,
        batch_size=batch_size, # Use small batch_size from config
        shuffle=True,
        collate_fn=data_collator # Apply custom padding at batch level
        # num_workers=4, # Optional: Add workers later for performance
        # pin_memory=True # Optional: Add if using GPU
    )
    debug_val_dataloader = DataLoader(
        debug_val_dataset,
        batch_size=batch_size, # Use small batch_size from config
        shuffle=False, # No need to shuffle validation data
        collate_fn=data_collator # Apply custom padding at batch level
        # num_workers=4,
        # pin_memory=True
    )
    logging.info(f"DEBUG Dataset sizes: Train={len(debug_train_dataset)}, Val={len(debug_val_dataset)}")
    logging.info("DEBUG DataLoaders with custom collator created.")
except Exception as e:
    logging.error(f"Failed to create DEBUG datasets/dataloaders: {e}", exc_info=True)
    raise SystemExit

2025-05-04 15:10:22,648 - INFO - Loading feature extractor for: facebook/w2v-bert-2.0
2025-05-04 15:10:22,779 - INFO - Feature extractor loaded successfully.
Feature extractor expects sample rate: 16000
2025-05-04 15:10:22,781 - INFO - Initializing FMARawAudioDataset from: /workspace/musicClaGen/data/processed/small_subset_multihot.csv
2025-05-04 15:10:22,783 - INFO - Target sampling rate set from feature extractor: 16000 Hz
2025-05-04 15:10:22,784 - INFO - Loading manifest from: /workspace/musicClaGen/data/processed/small_subset_multihot.csv
2025-05-04 15:10:22,821 - INFO - Attempting to parse 'multi_hot_label' column using ast.literal_eval...
2025-05-04 15:10:23,141 - INFO - Example parsed label verified (type <class 'list'>, length 22): [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...
2025-05-04 15:10:23,142 - INFO - Loaded and parsed manifest with 8000 entries.
2025-05-04 15:10:23,143 - INFO - Creating DEBUG DataLoaders with sm

# 5: Load Wav2Vec2-BERT Model and Modify Head

In [10]:
# Cell 5: Load Wav2Vec2-BERT Model and Modify Head

import torch.nn as nn # Ensure nn is imported
from transformers import AutoModelForAudioClassification

logging.info(f"Loading pre-trained Wav2Vec2-BERT model: {model_checkpoint}")
try:
    # Load the model configured for audio classification
    model = AutoModelForAudioClassification.from_pretrained(
        model_checkpoint,
        num_labels=num_labels,
        ignore_mismatched_sizes=True # Essential for replacing the head
    )
    logging.info("Model loaded initially.")

    # --- Explicit Head Replacement (Recommended) ---
    # Though I have defined num_labels = num_labels on previous step, I want to explicitly replace it again to ensure the head is correct.
    # If the above code is correct, the explicitly approach below might seem redundant but.
    
    # I MUST verify the correct attribute name for the classifier head for Wav2Vec2-BERT. 
    # Common names include 'classifier', 'projector','classification_head'. Use print(model) after loading to check.
    classifier_attr = 'classifier' # <<<--- VERIFY THIS ATTRIBUTE NAME ---<<<

    if hasattr(model, classifier_attr):
        original_classifier = getattr(model, classifier_attr)
        logging.info(f"Found classifier attribute '{classifier_attr}' of type {type(original_classifier)}")

        # Check if it's a simple Linear layer or potentially a sequence/projection
        if isinstance(original_classifier, nn.Linear):
            in_features = original_classifier.in_features
            logging.info(f"Replacing classifier head '{classifier_attr}'. Original out: {original_classifier.out_features}, New out: {num_labels}")
            setattr(model, classifier_attr, nn.Linear(in_features, num_labels))
            print(f"Successfully replaced classifier head '{classifier_attr}'.")
        # Add checks here if Wav2Vec2-BERT uses a different common head structure
        # elif isinstance(original_classifier, nn.Sequential): ... etc.
        else:
             logging.warning(f"Classifier head '{classifier_attr}' is not nn.Linear ({type(original_classifier)}). Attempting replacement might fail or need adjustment.")
             # If you know the structure (e.g., model.projector + model.classifier), adjust accordingly.
             # For now, we assume a direct replacement might work or the implicit loading handled it.

    else:
         logging.warning(f"Could not automatically find classifier attribute '{classifier_attr}'. Ensure head size ({num_labels}) was correctly set via 'num_labels' argument during loading or modify manually.")

    model.to(device)
    logging.info("Wav2Vec2-BERT Model loaded and moved to device.")
    # print(model) # Uncomment this line and run to inspect the model structure and find the classifier name

except Exception as e:
    logging.error(f"Failed to load model '{model_checkpoint}': {e}", exc_info=True)
    raise SystemExit # Stop if model loading fails

2025-05-04 15:10:23,168 - INFO - Loading pre-trained Wav2Vec2-BERT model: facebook/w2v-bert-2.0


Some weights of Wav2Vec2BertForSequenceClassification were not initialized from the model checkpoint at facebook/w2v-bert-2.0 and are newly initialized: ['classifier.bias', 'classifier.weight', 'projector.bias', 'projector.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


2025-05-04 15:10:25,708 - INFO - Model loaded initially.
2025-05-04 15:10:25,709 - INFO - Found classifier attribute 'classifier' of type <class 'torch.nn.modules.linear.Linear'>
2025-05-04 15:10:25,711 - INFO - Replacing classifier head 'classifier'. Original out: 22, New out: 22
Successfully replaced classifier head 'classifier'.
2025-05-04 15:10:26,725 - INFO - Wav2Vec2-BERT Model loaded and moved to device.


In [11]:
# verify the correct attribute name for the classifier head for Wav2Vec2-BERT.
# print(model)


In [12]:
# Cell: Save Model Architecture Details

# Create a function to capture model architecture details
def save_model_architecture(model, base_dir):
    """Save model architecture details"""
    architecture_info = {
        "model_type": model.__class__.__name__,
        "parameter_count": sum(p.numel() for p in model.parameters()),
        "trainable_parameter_count": sum(p.numel() for p in model.parameters() if p.requires_grad)
    }
    
    # Add config parameters if available
    if hasattr(model, 'config'):
        if hasattr(model.config, 'to_dict'):
            architecture_info["config"] = model.config.to_dict()
        else:
            # Try to convert config to dict
            try:
                architecture_info["config"] = vars(model.config)
            except:
                pass
    
    # Save to file
    os.makedirs(base_dir, exist_ok=True)
    with open(os.path.join(base_dir, "model_architecture.json"), 'w') as f:
        json.dump(architecture_info, f, indent=2)
    
    logging.info(f"Saved model architecture details to {os.path.join(base_dir, 'model_architecture.json')}")
    
    return architecture_info

# Save architecture info - safely handle whether logger exists or not
model_info_dir = os.path.join(model_save_dir, "model_info")
model_arch = save_model_architecture(model, model_info_dir)
logging.info(f"Model architecture saved to {model_info_dir}")

2025-05-04 15:10:26,763 - INFO - Saved model architecture details to /workspace/musicClaGen/models/model_info/model_architecture.json
2025-05-04 15:10:26,765 - INFO - Model architecture saved to /workspace/musicClaGen/models/model_info


# 6: Define Optimizer, Loss Function, and Metrics Calculation

In [13]:
# Cell 6: Define Optimizer, Loss Function, and Metrics Calculation

import torch.optim as optim
from sklearn.metrics import hamming_loss, jaccard_score, f1_score # Make sure these are imported

# --- Optimizer ---
optimizer = optim.AdamW(
    model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay
)
logging.info(f"Optimizer AdamW defined with LR={learning_rate}, Weight Decay={weight_decay}")

# --- Loss Function ---
# Use BCEWithLogitsLoss for multi-label classification (includes Sigmoid)
criterion = nn.BCEWithLogitsLoss().to(device)
logging.info("Loss function BCEWithLogitsLoss defined.")

# --- Metrics Function ---
def compute_metrics(eval_preds):
    """Calculates multi-label metrics from logits and labels."""
    logits, labels = eval_preds
    # Ensure inputs are numpy arrays on CPU
    logits_np = logits.detach().cpu().numpy() if isinstance(logits, torch.Tensor) else logits
    labels_np = labels.detach().cpu().numpy() if isinstance(labels, torch.Tensor) else labels

    # Apply sigmoid and threshold
    probs = 1 / (1 + np.exp(-logits_np)) # Manual sigmoid
    preds = (probs > 0.5).astype(int)
    labels_np = labels_np.astype(int) # Ensure labels are integers

    if labels_np.shape != preds.shape:
         logging.error(f"Shape mismatch in compute_metrics! Labels: {labels_np.shape}, Preds: {preds.shape}")
         # Return default metrics indicating failure
         return {'hamming_loss': 1.0, 'jaccard_samples': 0.0, 'f1_micro': 0.0, 'f1_macro': 0.0}

    metrics = {}
    try:
        metrics['hamming_loss'] = hamming_loss(labels_np, preds)
        # Use average='samples' for Jaccard in multi-label scenario
        metrics['jaccard_samples'] = jaccard_score(labels_np, preds, average='samples', zero_division=0)
        metrics['f1_micro'] = f1_score(labels_np, preds, average='micro', zero_division=0)
        metrics['f1_macro'] = f1_score(labels_np, preds, average='macro', zero_division=0)
        # Optional: Add Accuracy (subset accuracy)
        # metrics['accuracy'] = accuracy_score(labels_np, preds) # This is exact match accuracy
    except Exception as e:
         logging.error(f"Error calculating metrics: {e}")
         metrics = {'hamming_loss': 1.0, 'jaccard_samples': 0.0, 'f1_micro': 0.0, 'f1_macro': 0.0}

    # Log inside the main evaluate function now for better context
    # logging.info(f"Metrics: Hamming={metrics['hamming_loss']:.4f}, Jaccard(samples)={metrics['jaccard_samples']:.4f}, F1 Micro={metrics['f1_micro']:.4f}, F1 Macro={metrics['f1_macro']:.4f}")
    return metrics

print("Optimizer, Loss, and compute_metrics function defined.")

2025-05-04 15:10:26,785 - INFO - Optimizer AdamW defined with LR=5e-05, Weight Decay=0.01
2025-05-04 15:10:26,789 - INFO - Loss function BCEWithLogitsLoss defined.
Optimizer, Loss, and compute_metrics function defined.


# 7: Define Training Function for One Epoch

In [14]:
# Cell 7: Define Training Function for One Epoch (with AMP and Scheduler)

from torch.cuda.amp import autocast, GradScaler # Or from torch.amp import ...

# Ensure compute_metrics, torch, logging, tqdm etc. are imported

def train_epoch(model, dataloader, criterion, optimizer, device, gradient_accumulation_steps, scaler, scheduler=None): # <<< Added scheduler=None
    model.train()
    total_loss = 0
    num_samples = 0
    successful_steps = 0 # Counter for successful steps
    optimizer.zero_grad()

    progress_bar = tqdm(dataloader, desc="Training", leave=False)
    num_batches = len(dataloader) # Get total batches for scheduler check

    for step, batch in enumerate(progress_bar):
        if batch is None or not batch: continue

        try:
            expected_model_input_key = "input_features" # VERIFY THIS KEY NAME
            input_data_key = 'input_values' if 'input_values' in batch else 'input_features'
            model_inputs = {expected_model_input_key: batch[input_data_key].to(device)}
            if 'attention_mask' in batch and batch['attention_mask'] is not None: model_inputs['attention_mask'] = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            with autocast(device_type=device.type, enabled=(device.type=='cuda')): # Correct autocast usage
                outputs = model(**model_inputs)
                logits = outputs.logits
                loss = criterion(logits, labels)

            if torch.isnan(loss):
                logging.warning(f"NaN loss detected at step {step}. Skipping batch.")
                if (step + 1) % gradient_accumulation_steps != 0: optimizer.zero_grad()
                continue

            scaled_loss = loss / gradient_accumulation_steps
            scaler.scale(scaled_loss).backward()

            batch_size_actual = labels.size(0)
            total_loss += loss.item() * batch_size_actual
            num_samples += batch_size_actual

            if (step + 1) % gradient_accumulation_steps == 0 or (step + 1) == num_batches:
                scaler.step(optimizer)
                scaler.update()
                # --- Step the scheduler AFTER the optimizer step ---
                if scheduler:
                    scheduler.step() # <<<--- ADDED SCHEDULER STEP HERE
                # -------------------------------------------------
                optimizer.zero_grad()
                successful_steps +=1 # Count successful optimizer steps


            progress_bar.set_postfix({'loss': f'{loss.item():.4f}', 'lr': f'{optimizer.param_groups[0]["lr"]:.2e}'}) # Optionally show LR

        except Exception as e:
             logging.error(f"Error during training step {step}: {e}", exc_info=True)
             optimizer.zero_grad() # Zero grad on error too
             continue

    # Final optimizer step might not be needed if scheduler steps correctly, depends on exact logic.
    # Let's remove the extra step outside the loop for now.

    avg_loss = total_loss / num_samples if num_samples > 0 else 0
    print(f"\nCompleted training epoch. Successful optimizer steps: {successful_steps}")
    print(f"Average Training Loss for Epoch: {avg_loss:.4f}")
    return avg_loss

print("train_epoch function updated to accept scheduler.")

train_epoch function updated to accept scheduler.


# 8. Define Evaluation Function

In [15]:
# # Cell 8: Define Evaluation Function (Corrected Model Input)

# def evaluate(model, dataloader, criterion, device):
#     model.eval()
#     total_loss = 0
#     all_logits = []
#     all_labels = []
#     num_samples = 0

#     with torch.no_grad():
#         for step, batch in enumerate(tqdm(dataloader, desc="Evaluating", leave=False)):
#             try:
#                 # --- CORRECTED INPUT PREPARATION ---
#                 expected_model_input_key = "input_features" # <<<--- VERIFY THIS KEY NAME

#                 if 'input_values' not in batch:
#                      raise KeyError("Batch dictionary missing 'input_values' from Dataset/Extractor.")

#                 model_inputs = {
#                     expected_model_input_key: batch['input_values'].to(device)
#                 }
#                 if 'attention_mask' in batch and batch['attention_mask'] is not None:
#                      model_inputs['attention_mask'] = batch['attention_mask'].to(device)
#                 # --- END CORRECTION ---

#                 labels = batch['labels'].to(device)

#                 # Forward pass
#                 outputs = model(**model_inputs) # Pass the correctly named arguments
#                 logits = outputs.logits

#                 # Calculate loss
#                 loss = criterion(logits, labels)
#                 total_loss += loss.item() * labels.size(0)
#                 num_samples += labels.size(0)

#                 all_logits.append(logits.cpu())
#                 all_labels.append(labels.cpu())
#             except Exception as e:
#                  logging.error(f"Error during evaluation step {step}, batch keys: {batch.keys()}. Error: {e}", exc_info=True)
#                  continue # Skip batch

#     if not all_logits or not all_labels or num_samples == 0:
#         logging.warning("Evaluation yielded no results (all batches failed or empty dataloader?).")
#         return {}

#     avg_loss = total_loss / num_samples

#     all_logits_cat = torch.cat(all_logits, dim=0)
#     all_labels_cat = torch.cat(all_labels, dim=0)

#     eval_preds = (all_logits_cat, all_labels_cat)
#     metrics = compute_metrics(eval_preds)
#     metrics['eval_loss'] = avg_loss

#     print(f"\nValidation Loss: {avg_loss:.4f}")
#     for name, value in metrics.items():
#          if name != 'eval_loss': print(f"  Validation {name.replace('_', ' ').title()}: {value:.4f}")

#     return metrics

# print("evaluate function updated.")

In [16]:
# # Cell 8: Define Evaluation Function (with AMP)

# # Ensure compute_metrics function is defined in a previous cell
# # Ensure torch, logging, tqdm, np are imported

# def evaluate(model, dataloader, criterion, device):
#     model.eval() # Set model to evaluation mode
#     total_loss = 0
#     all_logits = []
#     all_labels = []
#     num_samples = 0

#     with torch.no_grad(): # Disable gradient calculations
#         for step, batch in enumerate(tqdm(dataloader, desc="Evaluating", leave=False)):
#             if batch is None or not batch: continue
#             try:
#                 # Prepare inputs
#                 expected_model_input_key = "input_features" # VERIFY THIS KEY NAME
#                 input_data_key = 'input_values' if 'input_values' in batch else 'input_features'

#                 model_inputs = {}
#                 if input_data_key in batch:
#                     model_inputs[expected_model_input_key] = batch[input_data_key].to(device)
#                 else:
#                     raise KeyError(f"Required input key not found in batch during evaluation.")

#                 if 'attention_mask' in batch and batch['attention_mask'] is not None:
#                      model_inputs['attention_mask'] = batch['attention_mask'].to(device)

#                 labels = batch['labels'].to(device)

#                 # --- Use autocast for forward pass during evaluation ---
#                 # Although not strictly needed for memory unless inputs are huge,
#                 # it ensures consistency with training pass calculations.
#                 with autocast(device_type=device.type):
#                     outputs = model(**model_inputs)
#                     logits = outputs.logits
#                     loss = criterion(logits, labels)
#                 # ----------------------------------------------------

#                 total_loss += loss.item() * labels.size(0)
#                 num_samples += labels.size(0)

#                 all_logits.append(logits.cpu()) # Store logits on CPU
#                 all_labels.append(labels.cpu()) # Store labels on CPU
#             except Exception as e:
#                  logging.error(f"Error during evaluation step {step}: {e}", exc_info=True)
#                  continue # Skip batch on error

#     if not all_logits or not all_labels or num_samples == 0:
#         logging.warning("Evaluation yielded no results.")
#         return {}

#     # Calculate average loss over processed samples
#     avg_loss = total_loss / num_samples

#     # Concatenate results from all batches
#     all_logits_cat = torch.cat(all_logits, dim=0)
#     all_labels_cat = torch.cat(all_labels, dim=0)

#     # Calculate metrics using the helper function
#     eval_preds = (all_logits_cat, all_labels_cat) # Pass tensors directly
#     metrics = compute_metrics(eval_preds)
#     metrics['eval_loss'] = avg_loss

#     # Log metrics
#     print(f"\nValidation Loss: {avg_loss:.4f}")
#     for name, value in metrics.items():
#          if name != 'eval_loss': print(f"  Validation {name.replace('_', ' ').title()}: {value:.4f}")

#     return metrics # Return dictionary of all metrics

# print("evaluate function defined with AMP (autocast only).")

In [20]:
# Cell 8: Enhanced Evaluation Function with Threshold Testing
from torch.cuda.amp import autocast  # Add this import for mixed precision

def evaluate_with_thresholds(model, dataloader, criterion, device, thresholds=[0.05, 0.1, 0.2, 0.3, 0.4, 0.5]):
    """Evaluate model with multiple thresholds to find the optimal one"""
    model.eval()
    total_loss = 0
    all_logits = []
    all_labels = []
    num_samples = 0

    with torch.no_grad():  # Disable gradient calculations
        for step, batch in enumerate(tqdm(dataloader, desc="Evaluating", leave=False)):
            if batch is None or not batch: continue
            try:
                # Prepare inputs - handle the parameter name mapping
                input_data = batch.get('input_values', batch.get('input_features')).to(device)
                attention_mask = batch.get('attention_mask', None)
                if attention_mask is not None:
                    attention_mask = attention_mask.to(device)
                labels = batch.get('labels').to(device)
                
                # Create model input dict with the correct parameter name
                model_inputs = {
                    # Use input_features instead of input_values for Wav2Vec2-BERT
                    'input_features': input_data,
                    'attention_mask': attention_mask if attention_mask is not None else None
                }
                
                # Forward pass with autocast
                with autocast(device_type=device.type):
                    outputs = model(**model_inputs)
                    logits = outputs.logits
                    loss = criterion(logits, labels)
                
                total_loss += loss.item() * labels.size(0)
                num_samples += labels.size(0)
                
                # Store logits and labels for metric calculation
                all_logits.append(logits.cpu())
                all_labels.append(labels.cpu())
                
            except Exception as e:
                logging.error(f"Error during evaluation step {step}: {e}", exc_info=True)
                continue
    
    if not all_logits or not all_labels or num_samples == 0:
        logging.warning("Evaluation yielded no results.")
        return {}
    
    # Calculate average loss
    avg_loss = total_loss / num_samples
    
    # Concatenate results from all batches
    all_logits_cat = torch.cat(all_logits, dim=0)
    all_labels_cat = torch.cat(all_labels, dim=0)
    
    # Convert to numpy
    logits_np = all_logits_cat.numpy()
    labels_np = all_labels_cat.numpy()
    
    # Apply sigmoid to get probabilities
    probs = 1 / (1 + np.exp(-logits_np))  # Manual sigmoid, same as in your compute_metrics
    
    # Add diagnostic information about predictions
    print(f"Prediction stats - Min: {probs.min():.4f}, Max: {probs.max():.4f}, Mean: {probs.mean():.4f}")
    print(f"Prediction histogram: {np.histogram(probs.flatten(), bins=10, range=(0,1))[0]}")
    
    # Display label distribution
    label_counts = np.sum(labels_np, axis=0)
    print(f"Label distribution: min={label_counts.min()}, max={label_counts.max()}, mean={label_counts.mean():.1f}")
    
    # Test multiple thresholds to find the optimal one
    metrics = {'eval_loss': avg_loss}
    threshold_metrics = {}
    best_f1 = 0
    best_threshold = 0.5  # Default
    
    for threshold in thresholds:
        # Apply threshold
        preds = (probs > threshold).astype(int)
        
        try:
            # Calculate metrics for this threshold
            ham = hamming_loss(labels_np, preds)
            jac = jaccard_score(labels_np, preds, average='samples', zero_division=0)
            f1_mic = f1_score(labels_np, preds, average='micro', zero_division=0)
            f1_mac = f1_score(labels_np, preds, average='macro', zero_division=0)
            
            threshold_metrics[threshold] = {
                'hamming_loss': ham,
                'jaccard_samples': jac,
                'f1_micro': f1_mic,
                'f1_macro': f1_mac
            }
            
            print(f"Threshold {threshold}: Hamming={ham:.4f}, F1-micro={f1_mic:.4f}, F1-macro={f1_mac:.4f}")
            
            # Track best threshold based on micro F1
            if f1_mic > best_f1:
                best_f1 = f1_mic
                best_threshold = threshold
                
        except Exception as e:
            logging.error(f"Error calculating metrics with threshold {threshold}: {e}")
    
    print(f"\nBest threshold: {best_threshold} (F1-micro: {best_f1:.4f})")
    
    # Use best threshold for final metrics
    best_preds = (probs > best_threshold).astype(int)
    metrics['hamming_loss'] = hamming_loss(labels_np, best_preds)
    metrics['jaccard_samples'] = jaccard_score(labels_np, best_preds, average='samples', zero_division=0)
    metrics['f1_micro'] = f1_score(labels_np, best_preds, average='micro', zero_division=0)
    metrics['f1_macro'] = f1_score(labels_np, best_preds, average='macro', zero_division=0)
    metrics['best_threshold'] = best_threshold
    
    # Don't include the full threshold_metrics dictionary in the returned metrics
    # This prevents formatting errors when printing the metrics
    # If you need this data, access it separately
    # metrics['threshold_metrics'] = threshold_metrics
    
    # Log metrics
    print(f"\nValidation Loss: {avg_loss:.4f}")
    for name, value in metrics.items():
        if name != 'eval_loss':  # Already printed eval_loss above
            try:
                print(f"  Validation {name.replace('_', ' ').title()}: {value:.4f}")
            except (TypeError, ValueError):
                print(f"  Validation {name.replace('_', ' ').title()}: {value}")
    
    return metrics

# 9. One Trial Epoch for Debugging Test Run

In [None]:
# Cell 9: Debug Training Run with Improved Evaluation and AMP

# Import torch amp components for mixed precision training
from torch.cuda.amp import autocast, GradScaler

def train_epoch(model, dataloader, criterion, optimizer, device, grad_accum_steps=1, scaler=None, scheduler=None):
    """Train model for one epoch using gradient accumulation for larger effective batches"""
    model.train()
    epoch_loss = 0.0
    num_batches_processed = 0
    optimizer_steps = 0
    
    # Zero the gradients at the beginning of epoch
    optimizer.zero_grad()
    
    # Process each batch
    for step, batch in enumerate(tqdm(dataloader, desc="Training", leave=False)):
        # Skip empty or malformed batches
        if batch is None or not batch:
            continue
            
        # Process batch with gradient accumulation for efficiency
        # Only backward + optimize every grad_accum_steps or at the last batch
        do_optimizer_step = ((step + 1) % grad_accum_steps == 0) or (step == len(dataloader) - 1)
        
        try:
            # Prepare batch inputs and move to device
            input_values = batch.get('input_values', batch.get('input_features')).to(device)
            attention_mask = batch.get('attention_mask', None)
            if attention_mask is not None:
                attention_mask = attention_mask.to(device)
            labels = batch.get('labels').to(device)
            
            # Get model parameter name requirements (for different HF models)
            model_inputs = {
                'input_features': input_values,  # Use 'input_features' for Wav2Vec2-BERT
                'attention_mask': attention_mask if attention_mask is not None else None
            }
            
            # --- Forward pass with autocast for mixed precision ---
            with autocast(device_type=device.type, enabled=(scaler is not None)):
                outputs = model(**model_inputs)
                loss = criterion(outputs.logits, labels)
                
                # Scale loss by gradient accumulation steps
                loss = loss / grad_accum_steps
            
            # --- Backward pass with scaler ---
            if scaler is not None:
                scaler.scale(loss).backward()
            else:
                loss.backward()
                
            # --- Optimizer step if needed ---
            if do_optimizer_step:
                if scaler is not None:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()
                
                # Apply scheduler step
                if scheduler is not None:
                    scheduler.step()
                
                # Zero the gradients
                optimizer.zero_grad()  
                optimizer_steps += 1
                
            # Accumulate loss statistics (use the pre-scaled loss for reporting)
            epoch_loss += loss.item() * grad_accum_steps
            num_batches_processed += 1
            
        except Exception as e:
            logging.error(f"Error in training batch {step}: {e}")
            continue  # Skip problematic batch
    
    # Calculate average loss for the epoch
    avg_loss = epoch_loss / num_batches_processed if num_batches_processed > 0 else float('inf')
    
    # Log results
    print(f"\nCompleted training epoch. Successful optimizer steps: {optimizer_steps}")
    print(f"Average Training Loss for Epoch: {avg_loss:.4f}")
    
    return avg_loss

# Check memory before initializing dataloaders
print(f"GPU memory before dataloader setup: {torch.cuda.memory_allocated() / 1e9:.2f} GB")

# --- Debug Training Run ---
print("\n--- Starting Debug Training Run for 1 epoch (with AMP) ---")
num_epochs = num_epochs_debug

# Initialize GradScaler for AMP
scaler = GradScaler(enabled=(device.type == 'cuda'))

# Save model checkpoint after each debug epoch
debug_checkpoint_path = os.path.join(model_save_dir, f"{model_checkpoint.replace('/', '_')}_debug_AMP")
start_time = time.time()

# Initialize any metrics tracking
epoch_metrics = []

# Run training loop for specified number of debug epochs
for epoch in range(num_epochs):
    logging.info(f"DEBUG Epoch {epoch+1}/{num_epochs}")
    
    # Run one training epoch (with gradient accumulation and AMP)
    train_loss = train_epoch(
        model,
        debug_train_dataloader,
        criterion,
        optimizer,
        device,
        gradient_accumulation_steps,
        scaler  # Pass the scaler object (created with new API)
    )
    
    # Run evaluation step with the new evaluation function
    eval_metrics = evaluate_with_thresholds(
        model,
        debug_val_dataloader,
        criterion,
        device
    )
    
    print(f"\nDebug Epoch {epoch+1} finished.")
    print(f"  Avg Train Loss: {train_loss:.4f}")
    
    # Modified printing to handle potential dictionaries in metrics
    if eval_metrics:
        for name, value in eval_metrics.items():
            if name not in ['threshold_metrics']:  # Skip dictionary values 
                try:
                    print(f"  Validation {name.replace('_', ' ').title()}: {value:.4f}")
                except (TypeError, ValueError):
                    print(f"  Validation {name.replace('_', ' ').title()}: {value}")
    else:
        print("  Validation failed to produce metrics.")
    
    # Save checkpoint for this epoch
    epoch_save_path = f"{debug_checkpoint_path}_epoch_{epoch+1}.pth"
    torch.save(model.state_dict(), epoch_save_path)
    logging.info(f"Saved debug model checkpoint to {epoch_save_path}")
    
    # Add metrics to tracking
    epoch_metrics.append({
        'epoch': epoch+1,
        'train_loss': train_loss,
        'eval_metrics': eval_metrics,
    })

# Report training metrics over all epochs
debug_duration = time.time() - start_time
print(f"\n--- Debug Run Finished in {debug_duration:.2f} seconds ---")

GPU memory before dataloader setup: 7.00 GB

--- Starting Debug Training Run for 1 epoch (with AMP) ---

--- Debug Epoch 1/1 ---


Training:   0%|          | 0/8 [00:00<?, ?it/s]

                                                                                 


Completed training epoch. Successful optimizer steps: 2
Average Training Loss for Epoch: 0.5593


                                                         

Prediction stats - Min: 0.4116, Max: 0.5376, Mean: 0.4470
Prediction histogram: [  0   0   0   0 168   8   0   0   0   0]
Label distribution: min=0.0, max=4.0, mean=0.5
Threshold 0.05: Hamming=0.9432, F1-micro=0.1075, F1-macro=0.0869
Threshold 0.1: Hamming=0.9432, F1-micro=0.1075, F1-macro=0.0869
Threshold 0.2: Hamming=0.9432, F1-micro=0.1075, F1-macro=0.0869
Threshold 0.3: Hamming=0.9432, F1-micro=0.1075, F1-macro=0.0869
Threshold 0.4: Hamming=0.9432, F1-micro=0.1075, F1-macro=0.0869
Threshold 0.5: Hamming=0.0795, F1-micro=0.2222, F1-macro=0.0182

Best threshold: 0.5 (F1-micro: 0.2222)

Validation Loss: 0.6024
  Validation Hamming Loss: 0.0795
  Validation Jaccard Samples: 0.2500
  Validation F1 Micro: 0.2222
  Validation F1 Macro: 0.0182

Debug Epoch 1 finished.
  Avg Train Loss: 0.5593
  Validation Eval Loss: 0.6024
  Validation Hamming Loss: 0.0795
  Validation Jaccard Samples: 0.2500
  Validation F1 Micro: 0.2222
  Validation F1 Macro: 0.0182
  Validation Best Threshold: 0.5000




TypeError: unsupported format string passed to dict.__format__

# Trial debug training run worked! Now let's try the full training run.

# 10. Set Up DataLoaders for FULL Splits & LR Scheduler

In [None]:
# Cell 10: Setup DataLoaders for FULL Splits & LR Scheduler

from transformers import get_linear_schedule_with_warmup # Import scheduler

# --- Ensure Feature Extractor is Loaded ---
# (Code from previous Cell 4 - necessary if kernel restarted)
logging.info(f"Loading feature extractor for: {model_checkpoint}")
try:
    feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)
    logging.info("Feature extractor loaded successfully.")
    target_sr = feature_extractor.sampling_rate
    print(f"Feature extractor expects sample rate: {target_sr}")
except Exception as e:
    logging.error(f"Could not load feature extractor. Error: {e}", exc_info=True)
    raise SystemExit

# --- Create Full Dataset instance ---
try:
    full_dataset = FMARawAudioDataset(manifest_path, feature_extractor=feature_extractor)
    manifest_df = full_dataset.manifest
except Exception as e:
     logging.error("Failed to instantiate FMARawAudioDataset.", exc_info=True)
     raise SystemExit

# --- Create FULL Datasets for Train/Val/Test ---
logging.info("Creating DataLoaders with FULL splits and custom collator...")
try:
    # Get indices for the splits from the manifest
    train_indices = manifest_df[manifest_df['split'] == 'training'].index.tolist()
    val_indices = manifest_df[manifest_df['split'] == 'validation'].index.tolist()
    test_indices = manifest_df[manifest_df['split'] == 'test'].index.tolist() # Get test indices too

    # Create Subset instances using the FULL index lists
    train_dataset = Subset(full_dataset, train_indices)
    val_dataset = Subset(full_dataset, val_indices)
    test_dataset = Subset(full_dataset, test_indices) # Create test dataset

    # --- Create Data Collator Instance ---
    data_collator = DataCollatorAudio()
    print("DataCollatorAudio instance created.")

    # --- Create DataLoaders ---
    # Use actual batch_size from config
    effective_batch_size = config.MODEL_PARAMS["batch_size"] * config.MODEL_PARAMS["gradient_accumulation_steps"]
    logging.info(f"Batch size: {config.MODEL_PARAMS['batch_size']}, Grad Accum Steps: {config.MODEL_PARAMS['gradient_accumulation_steps']}, Effective BS: {effective_batch_size}")

    # Use num_workers for faster loading (adjust based on instance cores)
    num_workers = 4 if os.name == 'posix' else 0
    pin_memory = True if device.type == 'cuda' else False

    train_dataloader = DataLoader(
        train_dataset, batch_size=config.MODEL_PARAMS["batch_size"], shuffle=True,
        collate_fn=data_collator, num_workers=num_workers, pin_memory=pin_memory, persistent_workers=(num_workers>0)
    )
    val_dataloader = DataLoader(
        val_dataset, batch_size=config.MODEL_PARAMS["batch_size"], shuffle=False,
        collate_fn=data_collator, num_workers=num_workers, pin_memory=pin_memory, persistent_workers=(num_workers>0)
    )
    test_dataloader = DataLoader(
        test_dataset, batch_size=config.MODEL_PARAMS["batch_size"], shuffle=False,
        collate_fn=data_collator, num_workers=num_workers, pin_memory=pin_memory, persistent_workers=(num_workers>0)
    )
    logging.info(f"FULL Dataset sizes: Train={len(train_dataset)}, Val={len(val_dataset)}, Test={len(test_dataset)}")
    logging.info("FULL DataLoaders with custom collator created.")

    # --- Setup LR Scheduler ---
    num_epochs = config.MODEL_PARAMS["epochs"]
    num_training_steps = (len(train_dataloader) // config.MODEL_PARAMS["gradient_accumulation_steps"]) * num_epochs
    scheduler = get_linear_schedule_with_warmup(
         optimizer, # Optimizer defined in Cell 6
         num_warmup_steps=0, # You can add warmup steps if desired (e.g., 10% of total steps)
         num_training_steps=num_training_steps
    )
    logging.info(f"LR Scheduler created. Total optimization steps: {num_training_steps}")

except Exception as e:
    logging.error(f"Failed to create datasets/dataloaders: {e}", exc_info=True)
    raise SystemExit

print("\nSetup for full training run complete.")

2025-05-04 04:27:37,443 - INFO - Loading feature extractor for: facebook/w2v-bert-2.0


2025-05-04 04:27:37,534 - INFO - Feature extractor loaded successfully.
Feature extractor expects sample rate: 16000
2025-05-04 04:27:37,536 - INFO - Initializing FMARawAudioDataset from: /workspace/musicClaGen/data/processed/small_subset_multihot.csv
2025-05-04 04:27:37,538 - INFO - Target sampling rate set from feature extractor: 16000 Hz
2025-05-04 04:27:37,539 - INFO - Loading manifest from: /workspace/musicClaGen/data/processed/small_subset_multihot.csv
2025-05-04 04:27:37,572 - INFO - Attempting to parse 'multi_hot_label' column using ast.literal_eval...
2025-05-04 04:27:37,893 - INFO - Example parsed label verified (type <class 'list'>, length 22): [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...
2025-05-04 04:27:37,895 - INFO - Loaded and parsed manifest with 8000 entries.
2025-05-04 04:27:37,897 - INFO - Creating DataLoaders with FULL splits and custom collator...
DataCollatorAudio instance created.
2025-05-04 04:27:37,904

# 11.Full training run 

In [None]:
# # Cell 11: Run Full Training Loop

# # Clear CUDA cache and force garbage collection
# import gc
# import torch
# torch.cuda.empty_cache()
# gc.collect()

# # Check memory usage before training
# print(f"GPU memory allocated before training: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
# print(f"GPU memory reserved before training: {torch.cuda.memory_reserved() / 1e9:.2f} GB")


# # Make sure model, criterion, optimizer, scheduler, dataloaders defined from previous cells
# num_epochs = config.MODEL_PARAMS["epochs"] # Get actual epochs from config
# gradient_accumulation_steps = config.MODEL_PARAMS["gradient_accumulation_steps"]
# metric_to_monitor = 'hamming_loss' # Metric to decide best model (lower is better)
# best_val_metric = float('inf')

# # --- Initialize GradScaler for AMP ---
# scaler = GradScaler(enabled=(device.type == 'cuda'))
# # ------------------------------------

# logging.info(f"--- Starting FULL Training for {num_epochs} epochs ---")
# start_time = time.time()

# # Make sure model and criterion are on the correct device
# model.to(device)
# criterion.to(device)

# for epoch in range(num_epochs):
#     epoch_start_time = time.time()
#     logging.info(f"\n--- Epoch {epoch+1}/{num_epochs} ---")

#     # Run training for one epoch
#     train_loss = train_epoch(
#         model, train_dataloader, criterion, optimizer, device,
#         gradient_accumulation_steps, scaler, scheduler # Pass scaler and scheduler
#     )

#     # Run evaluation on validation set
#     eval_metrics = evaluate(model, val_dataloader, criterion, device)

#     print(f"\nEpoch {epoch+1} finished.")
#     print(f"  Avg Train Loss: {train_loss:.4f}")

#     if not eval_metrics:
#         logging.warning(f"Epoch {epoch+1}: Evaluation failed, skipping checkpoint.")
#         continue

#     # Log all validation metrics
#     for name, value in eval_metrics.items():
#         print(f"  Validation {name.replace('_', ' ').title()}: {value:.4f}")

#     # Save model checkpoint if validation metric improved
#     current_val_metric = eval_metrics.get(metric_to_monitor, float('inf'))
#     if current_val_metric < best_val_metric:
#         best_val_metric = current_val_metric
#         # Use a consistent name for the best model checkpoint
#         save_path = os.path.join(model_save_dir, f"{model_checkpoint.replace('/', '_')}_finetuned_best.pth")
#         try:
#             torch.save(model.state_dict(), save_path)
#             logging.info(f"Validation metric improved ({metric_to_monitor}={current_val_metric:.4f}). Saved best model to {save_path}")
#         except Exception as e:
#             logging.error(f"Failed to save model checkpoint: {e}", exc_info=True)
#     else:
#          logging.info(f"Validation metric did not improve ({metric_to_monitor}={current_val_metric:.4f}). Best: {best_val_metric:.4f}")

#     epoch_duration = time.time() - epoch_start_time
#     logging.info(f"Epoch {epoch+1} finished in {epoch_duration / 60:.2f} minutes.")

# total_training_time = time.time() - start_time
# logging.info(f"--- Training Finished in {total_training_time / 60:.2f} minutes ---")


• Time: Took ~47 minutes for 1 epoch on the full `fma_small` training set (~6400 samples). This is a realistic time given the model size, 30s inputs, data loading, and AMP.

• Errors During Training: The log shows several errors during the training loop:

  • `ERROR - Error loading/processing track ...`
  
  `audioread.exceptions.NoBackendError`: This error occurred multiple times (tracks 133297, 99134, 98569, 98567, 98565, 108925). It indicates `librosa.load` failed. It first tries `soundfile` (which fails often with MP3s, sometimes due to file existence/permissions or internal errors), then falls back to `audioread`, which then fails because no suitable backend (like `ffmpeg`) was found or successfully used by `audioread`. This is despite installing `ffmpeg` earlier. It suggests `librosa`'s fallback mechanism isn't working reliably in this environment.

  • `[src/libmpg123/...]: warning: Cannot read next header...`, `error: dequantization failed!`, `error: part2_3_length ... too large...`, `error: Giving up resync...`: These are lower-level MP3 decoding errors from the `mpg123` library, likely called by `audioread` or another backend. They indicate corrupted or non-standard MP3 files.


When checked the documentation on fma github(https://github.com/mdeff/fma/wiki), these track IDs are flawed indeed, so everything is fine!

In [None]:
# Cell: Enhanced Training Loop with Logger and Selected Checkpoints

# Clear CUDA cache and force garbage collection
import gc
import torch
import shutil
torch.cuda.empty_cache()
gc.collect()

# Check memory usage before training
print(f"GPU memory allocated before training: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
print(f"GPU memory reserved before training: {torch.cuda.memory_reserved() / 1e9:.2f} GB")

# Create training configuration dictionary
training_config = {
    "model_checkpoint": model_checkpoint,
    "num_labels": num_labels,
    "learning_rate": learning_rate,
    "batch_size": batch_size,
    "num_epochs": config.MODEL_PARAMS["epochs"],
    "weight_decay": weight_decay,
    "gradient_accumulation_steps": gradient_accumulation_steps,
    "device": str(device),
    "dataset_info": {
        "train_size": len(train_dataset),
        "val_size": len(val_dataset),
        "test_size": len(test_dataset),
        "num_genres": num_labels,
        "genre_list": unified_genres
    },
    "optimizer": "AdamW",
    "scheduler": "linear_warmup_decay",
    "model_params": {
        "hidden_size": model.config.hidden_size if hasattr(model.config, "hidden_size") else "unknown",
        "num_hidden_layers": model.config.num_hidden_layers if hasattr(model.config, "num_hidden_layers") else "unknown",
        "num_attention_heads": model.config.num_attention_heads if hasattr(model.config, "num_attention_heads") else "unknown"
    }
}

# Modified TrainingLogger class with checkpoint management
class TrainingLoggerWithCleanup(TrainingLogger):
    def __init__(self, output_dir, model_name, config=None, max_checkpoints=5, save_frequency=2):
        super().__init__(output_dir, model_name, config)
        self.max_checkpoints = max_checkpoints
        self.save_frequency = save_frequency  # Save one checkpoint every N epochs
        self.saved_checkpoints = []
    
    def save_model_checkpoint(self, model, epoch, step, optimizer=None, scheduler=None, is_best=False):
        """Save model checkpoint with cleanup"""
        # Only save checkpoint if it's a multiple of save_frequency or it's the best model
        if epoch % self.save_frequency != 0 and not is_best:
            logging.info(f"Skipping checkpoint at epoch {epoch} (saving every {self.save_frequency} epochs)")
            return None
            
        # Save checkpoint as normal
        checkpoint_name = f"checkpoint-{epoch}"
        checkpoint_path = os.path.join(self.checkpoint_dir, checkpoint_name)
        os.makedirs(checkpoint_path, exist_ok=True)
        
        # Save model state dict
        model_path = os.path.join(checkpoint_path, "model.pth")
        torch.save(model.state_dict(), model_path)
        
        # Save optimizer and scheduler if provided
        if optimizer:
            optimizer_path = os.path.join(checkpoint_path, "optimizer.pth")
            torch.save(optimizer.state_dict(), optimizer_path)
        
        if scheduler:
            scheduler_path = os.path.join(checkpoint_path, "scheduler.pth")
            torch.save(scheduler.state_dict(), scheduler_path)
        
        # Save model config
        if hasattr(model, 'config'):
            config_path = os.path.join(checkpoint_path, "config.json")
            with open(config_path, 'w') as f:
                json.dump(model.config.to_dict(), f, indent=2)
        
        # Create a symbolic link or copy for the best model
        if is_best:
            best_path = os.path.join(self.base_dir, "best_model")
            # If we can use symlinks
            try:
                if os.path.exists(best_path):
                    if os.path.islink(best_path):
                        os.unlink(best_path)
                    else:
                        os.rmdir(best_path)
                os.symlink(checkpoint_path, best_path)
                logging.info(f"Created symbolic link to best model at {best_path}")
            except (OSError, NotImplementedError):
                # Fallback: copy the model file
                best_model_path = os.path.join(self.base_dir, "best_model.pth")
                torch.save(model.state_dict(), best_model_path)
                logging.info(f"Saved copy of best model to {best_model_path}")
        
        # Add to list of saved checkpoints
        self.saved_checkpoints.append(checkpoint_path)
        logging.info(f"Saved model checkpoint to {checkpoint_path}")
        
        # Clean up old checkpoints if we have too many
        self._cleanup_old_checkpoints()
        
        return checkpoint_path
    
    def _cleanup_old_checkpoints(self):
        """Remove oldest checkpoints to maintain only the last max_checkpoints"""
        # Always keep best checkpoint separate
        best_checkpoint = os.path.join(self.base_dir, "best_model")
        best_checkpoint_target = None
        if os.path.islink(best_checkpoint):
            best_checkpoint_target = os.path.realpath(best_checkpoint)
        
        # Skip cleanup if we don't have enough checkpoints yet
        if len(self.saved_checkpoints) <= self.max_checkpoints:
            return
        
        # Remove oldest checkpoints
        while len(self.saved_checkpoints) > self.max_checkpoints:
            oldest_checkpoint = self.saved_checkpoints.pop(0)
            
            # Don't delete if it's the best checkpoint
            if best_checkpoint_target and oldest_checkpoint == best_checkpoint_target:
                continue
                
            try:
                if os.path.exists(oldest_checkpoint):
                    shutil.rmtree(oldest_checkpoint)
                    logging.info(f"Removed old checkpoint: {oldest_checkpoint}")
            except Exception as e:
                logging.error(f"Error removing checkpoint {oldest_checkpoint}: {e}")

# Initialize the logger with cleanup
logger = TrainingLoggerWithCleanup(
    output_dir=model_save_dir,
    model_name=model_checkpoint.replace('/', '_'),
    config=training_config,
    max_checkpoints=5,     # Keep only the 5 most recent checkpoints
    save_frequency=2       # Save checkpoint every 2 epochs
)

# Copy model architecture to logger's directory if we already saved it elsewhere
if os.path.exists(os.path.join(model_save_dir, "model_info", "model_architecture.json")):
    import shutil
    os.makedirs(logger.base_dir, exist_ok=True)
    shutil.copy(
        os.path.join(model_save_dir, "model_info", "model_architecture.json"),
        os.path.join(logger.base_dir, "model_architecture.json")
    )
    logging.info(f"Copied model architecture to training run directory: {logger.base_dir}")

# Save processor for later use
processor_save_path = os.path.join(logger.base_dir, "processor")
os.makedirs(processor_save_path, exist_ok=True)
processor.save_pretrained(processor_save_path)
logging.info(f"Saved processor to {processor_save_path}")

# Make sure model, criterion are on the correct device
model.to(device)
criterion.to(device)

# Initialize variables for training loop
num_epochs = config.MODEL_PARAMS["epochs"]
gradient_accumulation_steps = config.MODEL_PARAMS["gradient_accumulation_steps"]
metric_to_monitor = 'hamming_loss'  # Metric to decide best model (lower is better)
best_val_metric = float('inf')
global_step = 0

# Initialize GradScaler for AMP
scaler = GradScaler(enabled=(device.type == 'cuda'))

logging.info(f"--- Starting FULL Training for {num_epochs} epochs ---")
start_time = time.time()

for epoch in range(num_epochs):
    epoch_start_time = time.time()
    logging.info(f"\n--- Epoch {epoch+1}/{num_epochs} ---")

    # Run training for one epoch
    train_loss = train_epoch(
        model, train_dataloader, criterion, optimizer, device,
        gradient_accumulation_steps, scaler, scheduler  # Pass scaler and scheduler
    )
    
    # Update global step (approximate)
    steps_per_epoch = len(train_dataloader) // gradient_accumulation_steps
    global_step += steps_per_epoch

    # Run evaluation on validation set (using enhanced function)
    eval_metrics = evaluate_with_thresholds(model, val_dataloader, criterion, device)

    print(f"\nEpoch {epoch+1} finished.")
    print(f"  Avg Train Loss: {train_loss:.4f}")

    if not eval_metrics:
        logging.warning(f"Epoch {epoch+1}: Evaluation failed, skipping checkpoint.")
        continue

    # Log epoch metrics to the logger
    logger.log_epoch(
        epoch=epoch+1,
        train_loss=train_loss,
        val_metrics=eval_metrics,
        learning_rate=scheduler.get_last_lr()[0] if scheduler else learning_rate,
        step=global_step
    )

    # Save checkpoint for this epoch (logger will handle save frequency)
    checkpoint_path = logger.save_model_checkpoint(
        model=model,
        epoch=epoch+1,
        step=global_step,
        optimizer=optimizer,
        scheduler=scheduler
    )

    # Check if this is the best model
    current_val_metric = eval_metrics.get(metric_to_monitor, float('inf'))
    is_best = False
    
    if current_val_metric < best_val_metric:
        is_best = True
        best_val_metric = current_val_metric
        logging.info(f"Validation metric improved ({metric_to_monitor}={current_val_metric:.4f})")
        
        # Update best model in logger
        logger.update_best_metrics(
            epoch=epoch+1,
            step=global_step,
            val_metrics=eval_metrics,
            model_path=checkpoint_path
        )
        
        # Save as best model (always save the best model regardless of save_frequency)
        logger.save_model_checkpoint(
            model=model,
            epoch=epoch+1,
            step=global_step,
            optimizer=optimizer,
            scheduler=scheduler,
            is_best=True
        )
    else:
        logging.info(f"Validation metric did not improve ({metric_to_monitor}={current_val_metric:.4f}). Best: {best_val_metric:.4f}")

    # Update trainer state after each epoch
    logger.save_trainer_state(
        epoch=epoch+1,
        step=global_step,
        optimizer_state=optimizer.state_dict(),
        scheduler_state=scheduler.state_dict() if scheduler else None
    )

    epoch_duration = time.time() - epoch_start_time
    logging.info(f"Epoch {epoch+1} finished in {epoch_duration / 60:.2f} minutes.")

total_training_time = time.time() - start_time
logger.finish_training(total_training_time)
logging.info(f"--- Training Finished in {total_training_time / 60:.2f} minutes ---")

# Print path to training logs and results
print(f"\nTraining logs and results saved to: {logger.base_dir}")
print(f"Best model checkpoint: {logger.metrics['best_metrics'].get('model_checkpoint', 'None')}")

# Show summary of training results
print("\nTraining Summary:")
print(f"  Total epochs: {num_epochs}")
print(f"  Best {metric_to_monitor}: {best_val_metric:.4f}")
print(f"  Final train loss: {train_loss:.4f}")
print(f"  Training time: {total_training_time / 60:.2f} minutes")

# 12. Evaluate Best Model

In [28]:
# Cell 12: Evaluate Best Model on Test Set (Robust Version)

import torch.nn as nn
import os
from transformers import AutoModelForAudioClassification
import logging
from torch.utils.data import DataLoader
import time

logging.info("\n--- Evaluating on Test Set using Best Model ---")

# Construct path to the best saved model
best_model_path = os.path.join(model_save_dir, f"{model_checkpoint.replace('/', '_')}_finetuned_best.pth")

if os.path.exists(best_model_path):
    try:
        logging.info(f"Loading best model from {best_model_path}")
        
        # Re-initialize model with correct structure
        model_reloaded = AutoModelForAudioClassification.from_pretrained(
            model_checkpoint,
            num_labels=num_labels,
            ignore_mismatched_sizes=True
        )
        
        # Load the saved state dict
        model_reloaded.load_state_dict(torch.load(best_model_path, map_location=device))
        model_reloaded.to(device)
        model_reloaded.eval()
        logging.info("Model successfully loaded and moved to device")
        
        # Create a safer test dataloader with no workers (avoid multiprocessing issues)
        safe_test_dataloader = DataLoader(
            test_dataset, 
            batch_size=config.MODEL_PARAMS["batch_size"],
            shuffle=False,
            collate_fn=data_collator,
            num_workers=0,  # Use main process only - no worker processes
            pin_memory=False  # Disable pinned memory to reduce memory usage
        )
        logging.info("Created safer test dataloader without worker processes")
        
        # Run evaluation with extra error handling
        logging.info("Starting evaluation on test set...")
        try:
            start_time = time.time()
            test_metrics = evaluate_with_thresholds(model_reloaded, safe_test_dataloader, criterion, device)
            eval_time = time.time() - start_time
            
            # Log test results
            logging.info(f"\n--- Final Test Set Results (completed in {eval_time:.2f}s) ---")
            if test_metrics:
                for metric_name, metric_value in test_metrics.items():
                    if metric_name != 'threshold_metrics':  # Skip nested dictionary
                        try:
                            logging.info(f"Test {metric_name.replace('_', ' ').title()}: {metric_value:.4f}")
                        except (TypeError, ValueError):
                            logging.info(f"Test {metric_name.replace('_', ' ').title()}: {metric_value}")
            else:
                logging.info("Test evaluation failed to produce metrics.")


        except RuntimeError as e:
            if "CUDA out of memory" in str(e):
                logging.error("CUDA out of memory during evaluation. Try reducing batch size.")
            elif "DataLoader worker" in str(e):
                logging.error(f"DataLoader worker error (should not happen with num_workers=0): {e}")
            else:
                logging.error(f"Runtime error during evaluation: {e}")
        except Exception as e:
            logging.error(f"Error during evaluation: {e}", exc_info=True)
            
    except Exception as e:
        logging.error(f"Failed to load model: {e}", exc_info=True)
else:
    logging.warning(f"Best model checkpoint not found at {best_model_path}. Skipping final test evaluation.")

2025-05-04 10:46:34,932 - INFO - 
--- Evaluating on Test Set using Best Model ---
2025-05-04 10:46:34,936 - INFO - Loading best model from /workspace/musicClaGen/models/facebook_w2v-bert-2.0_finetuned_best.pth


Some weights of Wav2Vec2BertForSequenceClassification were not initialized from the model checkpoint at facebook/w2v-bert-2.0 and are newly initialized: ['classifier.bias', 'classifier.weight', 'projector.bias', 'projector.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


2025-05-04 10:46:41,758 - INFO - Model successfully loaded and moved to device
2025-05-04 10:46:41,761 - INFO - Created safer test dataloader without worker processes
2025-05-04 10:46:41,762 - INFO - Starting evaluation on test set...


                                                             


Validation Loss: 0.1972
  Validation Hamming Loss: 0.0574
  Validation Jaccard Samples: 0.0000
  Validation F1 Micro: 0.0000
  Validation F1 Macro: 0.0000
2025-05-04 10:55:42,077 - INFO - 
--- Final Test Set Results (completed in 540.31s) ---
2025-05-04 10:55:42,078 - INFO - Test Hamming Loss: 0.0574
2025-05-04 10:55:42,080 - INFO - Test Jaccard Samples: 0.0000
2025-05-04 10:55:42,081 - INFO - Test F1 Micro: 0.0000
2025-05-04 10:55:42,082 - INFO - Test F1 Macro: 0.0000
2025-05-04 10:55:42,083 - INFO - Test Eval Loss: 0.1972


