In [1]:
# Cell 1

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
# Use AutoProcessor for Wav2Vec2-BERT - it bundles feature_extractor and tokenizer (if needed)
from transformers import AutoModelForAudioClassification, AutoProcessor

from torch.optim import AdamW
import pandas as pd
import numpy as np
import os
import sys
import ast # For parsing string representations of lists/arrays
import logging
import time
from sklearn.metrics import hamming_loss, jaccard_score, f1_score # Add more as needed
from tqdm.notebook import tqdm # Use notebook version of tqdm
import librosa # Needed for loading raw audio now



# --- Project Setup ---
# Detect if running in notebook or script to adjust path

cwd = os.getcwd()
PROJECT_ROOT = os.path.abspath(os.path.join(cwd, '../../')) # NOTE: remember to change if change the directory structure



print(f"PROJECT_ROOT detected as: {PROJECT_ROOT}")
if PROJECT_ROOT not in sys.path:
    print(f"Adding {PROJECT_ROOT} to sys.path")
    sys.path.append(PROJECT_ROOT)

# --- Config and Utils ---
try:
    import config # Import your configuration file
    # Optionally import utils if needed, e.g., for get_audio_path if not defined here
    # import src.utils as utils
except ModuleNotFoundError:
     print("ERROR: Cannot import config or utils. Make sure PROJECT_ROOT is correct and src is importable.")
     # Or add src to path: sys.path.insert(0, os.path.join(PROJECT_ROOT, 'src'))
     # import config
     # import utils


# --- Setup Logging ---
for handler in logging.root.handlers[:]: logging.root.removeHandler(handler) # Clear previous
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s',
                    handlers=[logging.StreamHandler(sys.stdout)])

print("Imports and basic setup complete.")

  from .autonotebook import tqdm as notebook_tqdm


PROJECT_ROOT detected as: /home/zhuoyuan/CSprojects/musicClaGen
Adding /home/zhuoyuan/CSprojects/musicClaGen to sys.path
/home/zhuoyuan/CSprojects/musicClaGen
Imports and basic setup complete.


In [2]:
# Cell 2 

In [3]:
# Cell 2
# --- Load Config ---
# Ensure config.py has the correct paths in the PATHS dict
manifest_path = config.PATHS.get('SMALL_MULTILABEL_PATH', os.path.join(config.PATHS['PROCESSED_DATA_DIR'], 'small_subset_multihot.csv')) # Use .get for safety
genre_list_path = config.PATHS.get('GENRE_LIST_PATH', os.path.join(config.PATHS['PROCESSED_DATA_DIR'], 'unified_genres.txt'))
model_save_dir = config.PATHS['MODELS_DIR']

# Ensure config.py has MODEL_PARAMS dict with model_checkpoint
model_checkpoint = config.MODEL_PARAMS['model_checkpoint'] # e.g., "facebook/w2v-bert-2.0" - VERIFY!
learning_rate = config.MODEL_PARAMS['learning_rate']
batch_size = config.MODEL_PARAMS['batch_size'] # Use the small BS for notebook test
num_epochs_debug = 1 # <<<--- RUN ONLY 1 EPOCH FOR DEBUGGING ---<<<
weight_decay = config.MODEL_PARAMS['weight_decay']
gradient_accumulation_steps = config.MODEL_PARAMS['gradient_accumulation_steps']

# --- Load unified genre list ---
try:
    with open(genre_list_path, 'r') as f:
        unified_genres = [line.strip() for line in f if line.strip()]
    num_labels = len(unified_genres) # should be the number of labels defined in the unified_genres.txt file, in this case it should be 22.
    logging.info(f"Loaded {num_labels} unified genres from {genre_list_path}")
    if num_labels == 0: raise ValueError("Genre list is empty!")
except Exception as e:
    logging.error(f"Failed to load or process unified genre list: {e}", exc_info=True)
    raise SystemExit("Cannot proceed without genre list.")

# --- Setup Device ---
device = torch.device(config.DEVICE if torch.cuda.is_available() else "cpu")
logging.info(f"Using device: {device}")
if not torch.cuda.is_available() and config.DEVICE=="cuda":
     logging.warning("CUDA selected but not available, falling back to CPU.")

# --- Create Save Directory ---
os.makedirs(model_save_dir, exist_ok=True)

2025-05-02 22:09:34,766 - INFO - Loaded 22 unified genres from /home/zhuoyuan/CSprojects/musicClaGen/data/processed/unified_genres.txt
2025-05-02 22:09:34,767 - INFO - Using device: cuda


In [4]:
print(manifest_path)

/home/zhuoyuan/CSprojects/musicClaGen/data/processed/small_subset_multihot.csv


# Cell 3

In [5]:
# Cell 3: Dataset Class Definition (Raw Audio Version)



# Define(recollect)the regex parser from preprocess.py if needed, 
# otherwise use ast.literal_eval---

import re

def parse_numpy_array_string(array_str):
    """
    Parse strings like '[np.float32(1.0), np.float32(0.0), ...]' into a list of integers.
    This is needed because ast.literal_eval cannot handle 'np.float32()' in the string.
    """
    if not isinstance(array_str, str):
        return []
    
    try:
        # Extract all the float values using regular expressions
        float_matches = re.findall(r'np\.float32\((\d+\.\d+)\)', array_str)
        
        # Convert matches to integers (1.0 -> 1, 0.0 -> 0)
        values = []
        for match in float_matches:
            value = float(match)
            # Convert to integer if it's 0.0 or 1.0
            if value == 1.0:
                values.append(1)
            elif value == 0.0:
                values.append(0)
            else:
                values.append(value)  # Keep as float if not 0 or 1
                
        return values
    except Exception as e:
        logging.warning(f"Error parsing array string: {e}")
        return []

class FMARawAudioDataset(Dataset):
    """
    Loads raw audio waveforms and labels from manifest, uses Hugging Face
    feature extractor (like ASTFeatureExtractor or Wav2Vec2Processor) on the fly.
    """
    def __init__(self, manifest_path, feature_extractor):
        """
        Args:
            manifest_path (str): Path to the final manifest CSV file.
            feature_extractor: Initialized Hugging Face AutoFeatureExtractor or AutoProcessor.
        """
        logging.info(f"Initializing FMARawAudioDataset from: {manifest_path}")
        if feature_extractor is None:
             raise ValueError("FMARawAudioDataset requires a feature_extractor/processor instance.")

        self.feature_extractor = feature_extractor
        # Get target sampling rate directly from the extractor/processor
        try:
             # Works for Wav2Vec2Processor, ASTFeatureExtractor, etc.
             self.target_sr = self.feature_extractor.sampling_rate
             logging.info(f"Target sampling rate set from feature extractor: {self.target_sr} Hz")
        except AttributeError:
             logging.warning("Could not get sampling_rate from feature_extractor, using config.")
             # Fallback to config if needed, but ensuring match is crucial
             self.target_sr = config.PREPROCESSING_PARAMS['sample_rate']


        logging.info(f"Loading manifest from: {manifest_path}")
        try:
            self.manifest = pd.read_csv(manifest_path)
            # Ensure index is set if needed elsewhere, or use default range index
            if 'track_id' in self.manifest.columns:
                 self.manifest = self.manifest.set_index('track_id', drop=False)

            # --- Parse the 'multi_hot_label' string back into a list ---
            # Here: if we decide to use raw audio, we use regex parser; 
            #       if we decide to use mel spectrogram, we use ast.literal_eval

            # Choose the correct parser based on how labels were saved in the CSV
            # If saved as '[1.0, 0.0,...]' use ast.literal_eval
            # label_parser = ast.literal_eval
            # If saved as '[np.float32(1.0)...]' uncomment and use regex parser
            label_parser = parse_numpy_array_string

            self.manifest['multi_hot_label'] = self.manifest['multi_hot_label'].apply(label_parser)
            logging.info(f"Loaded and parsed manifest with {len(self.manifest)} entries.")
            # Check the first parsed label
            logging.info(f"Example parsed label (first entry): {self.manifest['multi_hot_label'].iloc[0]}")

        except Exception as e:
            logging.error(f"Error loading or parsing manifest {manifest_path}: {e}", exc_info=True)
            raise

    def __len__(self):
        """Returns the total number of samples in the dataset."""
        return len(self.manifest)

    def __getitem__(self, idx):
        """
        Loads raw audio for index idx, processes it with the feature extractor,
        and returns the processed inputs and labels.
        """
        if torch.is_tensor(idx): idx = idx.tolist() # Handle tensor indices

        # Get the row data from the manifest
        row = self.manifest.iloc[idx]
        track_id = row.get('track_id', self.manifest.index[idx]) # Get track_id safely
        label_vector = row['multi_hot_label'] # Already parsed list/array

        # Construct absolute audio path if necessary
        audio_path = row['audio_path']

        #NOTE: originally, the mel-spectrogram's path is relative  but the raw audio's path is absolute, so we need to make sure the audio_path is absolute
        # So we are check if the audio_path is absolute or relative in case we load the wrong data, if it's relative, we need to join it with the PROJECT_ROOT
        if not os.path.isabs(audio_path):
             # Assumes path in manifest is relative to PROJECT_ROOT
             audio_path = os.path.join(config.PROJECT_ROOT, audio_path)

        try:
            # --- 1. Load RAW Audio Waveform ---
            # Load full 30s clip at the TARGET sample rate required by the processor
            waveform, loaded_sr = librosa.load(
                audio_path,
                sr=self.target_sr, # Use processor's sampling rate
                duration=30.0     # Load the full 30 seconds
            )
            # Ensure minimum length if needed (though duration should handle it)
            min_samples = int(0.1 * self.target_sr) # Example: require at least 0.1s
            if len(waveform) < min_samples:
                 raise ValueError(f"Audio signal for track {track_id} too short after loading.")

            # --- 2. Apply Feature Extractor ---
            # Pass the raw waveform numpy array
            # The extractor handles normalization, padding/truncation, tensor conversion
            
            max_length = 5000

            inputs = self.feature_extractor(
                waveform,
                sampling_rate=self.target_sr,
                return_tensors="pt",
                return_attention_mask=True # Request attention mask
            )

            # --- 3. Prepare Outputs ---
            # Squeeze unnecessary batch dimension added by the extractor
            # Key name ('input_values', 'input_features') depends on the specific extractor
            feature_tensor = inputs.get('input_values', inputs.get('input_features'))
            if feature_tensor is None:
                raise KeyError("Expected 'input_values' or 'input_features' key from feature_extractor output.")
            feature_tensor = feature_tensor.squeeze(0) # Remove batch dim -> [Channels?, Freq?, Time] or [SeqLen, Dim]

            attention_mask = inputs.get('attention_mask', None)
            if attention_mask is not None:
                 attention_mask = attention_mask.squeeze(0)

            # Convert label list/array to float tensor for BCE loss
            label_tensor = torch.tensor(label_vector, dtype=torch.float32)

            # Return dictionary matching model's expected input names
            model_input_dict = {"labels": label_tensor}
            # Use the key the feature extractor provided
            if 'input_values' in inputs:
                 model_input_dict['input_values'] = feature_tensor
            elif 'input_features' in inputs:
                 model_input_dict['input_features'] = feature_tensor

            if attention_mask is not None:
                 model_input_dict['attention_mask'] = attention_mask

            return model_input_dict

        except FileNotFoundError:
             logging.error(f"Audio file not found for track {track_id} at {audio_path}")
             raise # Or implement skipping logic with collate_fn
        except Exception as e:
            logging.error(f"Error loading/processing track {track_id} at {audio_path}: {e}", exc_info=True)
            raise # Or implement skipping logic


print("FMARawAudioDataset class defined.")

FMARawAudioDataset class defined.


In [6]:
print(model_checkpoint)

facebook/w2v-bert-2.0


In [7]:
# Cell 3.5: Define Data Collator for Padding (Corrected Padding Logic)

import torch
from dataclasses import dataclass
from typing import Dict, List, Optional, Union
# from transformers.feature_extraction_utils import BatchFeature # Not strictly needed here

@dataclass
class DataCollatorAudio:
    """
    Data collator that dynamically pads the inputs received Feature Extractor.
    Correctly handles padding for [SequenceLength, FeatureDim] tensors.
    """
    padding_value: float = 0.0 # Standard padding for features/audio

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # features is a list of dicts like [{'input_values': tensor1, 'labels': label1, 'attention_mask': mask1}, ...]

        # --- Pad 'input_values' (or 'input_features') ---
        input_key = 'input_values' if 'input_values' in features[0] else 'input_features'
        input_features = [d[input_key] for d in features]

        # Determine max sequence length *in this batch* (assuming shape [SeqLen, FeatureDim])
        # Add check for empty list
        if not input_features:
             return {}
        max_len = max(feat.shape[0] for feat in input_features) # <<<--- Get length of FIRST dimension

        # Pad each feature tensor to max_len along the sequence dimension (first dim)
        padded_features = []
        for feat in input_features:
            # feat shape is [SeqLen, FeatureDim]
            num_frames = feat.shape[0]
            num_features = feat.shape[1] # Should be consistent (e.g., 160)
            pad_width = max_len - num_frames

            # Pad argument format for 2D tensor: (pad_left_dim1, pad_right_dim1, pad_left_dim0, pad_right_dim0)
            # We only want to pad the end of the sequence dimension (dim 0)
            # (0, 0) means no padding on left/right of feature dim (dim 1)
            # (0, pad_width) means 0 padding before seq dim (dim 0), pad_width padding after
            padded_feat = torch.nn.functional.pad(feat, (0, 0, 0, pad_width), mode='constant', value=self.padding_value)
            # Verify shape after padding
            # print(f"Original shape: {feat.shape}, Padded shape: {padded_feat.shape}, Target max_len: {max_len}")
            padded_features.append(padded_feat)

        # Stack the padded features into a batch tensor
        # Now all tensors in padded_features should have shape [max_len, FeatureDim]
        try:
             batch_input_features = torch.stack(padded_features) # Shape: [BatchSize, max_len, FeatureDim]
        except RuntimeError as e:
             logging.error(f"RuntimeError during torch.stack. Shapes in batch might still differ or be incompatible.")
             # Print shapes for debugging
             for i, p_feat in enumerate(padded_features): logging.error(f" Padded shape {i}: {p_feat.shape}")
             raise e


        # --- Prepare Batch Dictionary ---
        batch = {"input_values": batch_input_features}

        # --- Pad 'attention_mask' if present ---
        # Attention mask usually has shape [SeqLen]
        if "attention_mask" in features[0] and features[0]["attention_mask"] is not None:
            attention_masks = [d["attention_mask"] for d in features]
            padded_masks = []
            for mask in attention_masks:
                 pad_width = max_len - mask.shape[-1] # Pad last dimension (the sequence length)
                 # Pad argument format for 1D tensor: (pad_left, pad_right)
                 padded_mask = torch.nn.functional.pad(mask, (0, pad_width), mode='constant', value=0) # Pad attention mask with 0
                 padded_masks.append(padded_mask)
            batch["attention_mask"] = torch.stack(padded_masks) # Shape: [BatchSize, max_len]

        # --- Stack Labels ---
        labels = [d["labels"] for d in features]
        batch["labels"] = torch.stack(labels) # Shape: [BatchSize, num_labels]

        return batch

# Create an instance of the collator (do this in Cell 4)
# data_collator = DataCollatorAudio()
# print("DataCollatorAudio defined.")

# Cell 4

In [8]:
# Cell 4: Load Feature Extractor, Create DataLoaders with Custom Collator

from transformers import AutoFeatureExtractor # Use the correct class
from torch.utils.data import DataLoader, Subset # Ensure Subset is imported
# Ensure FMARawAudioDataset and DataCollatorAudio are defined in previous cells

# --- Load Feature Extractor ---
# (Using model_checkpoint defined in Cell 2)
logging.info(f"Loading feature extractor for: {model_checkpoint}")
try:
    # Load the feature extractor associated with Wav2Vec2-BERT
    feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)
    logging.info("Feature extractor loaded successfully.")
    # Log the expected sample rate
    processor_sr = feature_extractor.sampling_rate
    print(f"Feature extractor expects sample rate: {processor_sr}")
    # Ensure config matches extractor's expected rate
    if config.PREPROCESSING_PARAMS['sample_rate'] != processor_sr:
         logging.warning(f"Config sample rate ({config.PREPROCESSING_PARAMS['sample_rate']}) differs from feature extractor ({processor_sr}). Ensure audio loading uses {processor_sr} Hz.")
         # Update config value if necessary, or ensure Dataset uses processor_sr
         # config.PREPROCESSING_PARAMS['sample_rate'] = processor_sr # Be careful modifying config dynamically

except Exception as e:
    logging.error(f"Could not load feature extractor for {model_checkpoint}. Cannot proceed. Error: {e}", exc_info=True)
    raise SystemExit # Stop execution if extractor fails

# --- Create Full Dataset ---
# Ensure FMARawAudioDataset __init__ accepts feature_extractor
try:
    # Pass the loaded feature_extractor instance
    full_dataset = FMARawAudioDataset(manifest_path, feature_extractor=feature_extractor)
    manifest_df = full_dataset.manifest
except Exception as e:
     logging.error("Failed to instantiate FMARawAudioDataset.", exc_info=True)
     raise SystemExit

# --- Create SMALLER DEBUG Datasets ---
logging.info("Creating DEBUG DataLoaders with small subsets and custom collator...")
try:
    # Get indices for the splits from the manifest
    train_indices = manifest_df[manifest_df['split'] == 'training'].index[:16].tolist() # Small subset for debug
    val_indices = manifest_df[manifest_df['split'] == 'validation'].index[:8].tolist()  # Small subset for debug

    # Create Subset instances
    debug_train_dataset = Subset(full_dataset, train_indices)
    debug_val_dataset = Subset(full_dataset, val_indices)

    # --- Create Data Collator Instance ---
    # (Assumes DataCollatorAudio class is defined in Cell 3.5)
    data_collator = DataCollatorAudio()
    print("DataCollatorAudio instance created.")

    # --- Create DataLoaders using the custom collate_fn ---
    debug_train_dataloader = DataLoader(
        debug_train_dataset,
        batch_size=batch_size, # Use small batch_size from config
        shuffle=True,
        collate_fn=data_collator # Apply custom padding at batch level
        # num_workers=4, # Optional: Add workers later for performance
        # pin_memory=True # Optional: Add if using GPU
    )
    debug_val_dataloader = DataLoader(
        debug_val_dataset,
        batch_size=batch_size, # Use small batch_size from config
        shuffle=False, # No need to shuffle validation data
        collate_fn=data_collator # Apply custom padding at batch level
        # num_workers=4,
        # pin_memory=True
    )
    logging.info(f"DEBUG Dataset sizes: Train={len(debug_train_dataset)}, Val={len(debug_val_dataset)}")
    logging.info("DEBUG DataLoaders with custom collator created.")
except Exception as e:
    logging.error(f"Failed to create DEBUG datasets/dataloaders: {e}", exc_info=True)
    raise SystemExit

2025-05-02 22:09:34,826 - INFO - Loading feature extractor for: facebook/w2v-bert-2.0
2025-05-02 22:09:34,968 - INFO - Feature extractor loaded successfully.
Feature extractor expects sample rate: 16000
2025-05-02 22:09:34,969 - INFO - Initializing FMARawAudioDataset from: /home/zhuoyuan/CSprojects/musicClaGen/data/processed/small_subset_multihot.csv
2025-05-02 22:09:34,969 - INFO - Target sampling rate set from feature extractor: 16000 Hz
2025-05-02 22:09:34,969 - INFO - Loading manifest from: /home/zhuoyuan/CSprojects/musicClaGen/data/processed/small_subset_multihot.csv
2025-05-02 22:09:35,036 - INFO - Loaded and parsed manifest with 8000 entries.
2025-05-02 22:09:35,037 - INFO - Example parsed label (first entry): [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
2025-05-02 22:09:35,037 - INFO - Creating DEBUG DataLoaders with small subsets and custom collator...
DataCollatorAudio instance created.
2025-05-02 22:09:35,041 - INFO - DEBUG Dataset sizes: Train=16, Val=

NameError: name 'inspect' is not defined

In [9]:
# Cell 5: Load Wav2Vec2-BERT Model and Modify Head

import torch.nn as nn # Ensure nn is imported
from transformers import AutoModelForAudioClassification

logging.info(f"Loading pre-trained Wav2Vec2-BERT model: {model_checkpoint}")
try:
    # Load the model configured for audio classification
    model = AutoModelForAudioClassification.from_pretrained(
        model_checkpoint,
        num_labels=num_labels,
        ignore_mismatched_sizes=True # Essential for replacing the head
    )
    logging.info("Model loaded initially.")

    # --- Explicit Head Replacement (Recommended) ---
    # Though I have defined num_labels = num_labels on previous step, I want to explicitly replace it again to ensure the head is correct.
    # If the above code is correct, the explicitly approach below might seem redundant but.
    
    # I MUST verify the correct attribute name for the classifier head for Wav2Vec2-BERT. 
    # Common names include 'classifier', 'projector','classification_head'. Use print(model) after loading to check.
    classifier_attr = 'classifier' # <<<--- VERIFY THIS ATTRIBUTE NAME ---<<<

    if hasattr(model, classifier_attr):
        original_classifier = getattr(model, classifier_attr)
        logging.info(f"Found classifier attribute '{classifier_attr}' of type {type(original_classifier)}")

        # Check if it's a simple Linear layer or potentially a sequence/projection
        if isinstance(original_classifier, nn.Linear):
            in_features = original_classifier.in_features
            logging.info(f"Replacing classifier head '{classifier_attr}'. Original out: {original_classifier.out_features}, New out: {num_labels}")
            setattr(model, classifier_attr, nn.Linear(in_features, num_labels))
            print(f"Successfully replaced classifier head '{classifier_attr}'.")
        # Add checks here if Wav2Vec2-BERT uses a different common head structure
        # elif isinstance(original_classifier, nn.Sequential): ... etc.
        else:
             logging.warning(f"Classifier head '{classifier_attr}' is not nn.Linear ({type(original_classifier)}). Attempting replacement might fail or need adjustment.")
             # If you know the structure (e.g., model.projector + model.classifier), adjust accordingly.
             # For now, we assume a direct replacement might work or the implicit loading handled it.

    else:
         logging.warning(f"Could not automatically find classifier attribute '{classifier_attr}'. Ensure head size ({num_labels}) was correctly set via 'num_labels' argument during loading or modify manually.")

    model.to(device)
    logging.info("Wav2Vec2-BERT Model loaded and moved to device.")
    # print(model) # Uncomment this line and run to inspect the model structure and find the classifier name

except Exception as e:
    logging.error(f"Failed to load model '{model_checkpoint}': {e}", exc_info=True)
    raise SystemExit # Stop if model loading fails

2025-05-02 22:09:35,051 - INFO - Loading pre-trained Wav2Vec2-BERT model: facebook/w2v-bert-2.0


Some weights of Wav2Vec2BertForSequenceClassification were not initialized from the model checkpoint at facebook/w2v-bert-2.0 and are newly initialized: ['classifier.bias', 'classifier.weight', 'projector.bias', 'projector.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


2025-05-02 22:09:35,830 - INFO - Model loaded initially.
2025-05-02 22:09:35,831 - INFO - Found classifier attribute 'classifier' of type <class 'torch.nn.modules.linear.Linear'>
2025-05-02 22:09:35,832 - INFO - Replacing classifier head 'classifier'. Original out: 22, New out: 22
Successfully replaced classifier head 'classifier'.
2025-05-02 22:09:39,498 - INFO - Wav2Vec2-BERT Model loaded and moved to device.


In [10]:
# verify the correct attribute name for the classifier head for Wav2Vec2-BERT.
# print(model)


In [11]:
# Optimizer, Loss, Metrics Functoin

In [12]:
# Cell 6: Define Optimizer, Loss Function, and Metrics Calculation

import torch.optim as optim
from sklearn.metrics import hamming_loss, jaccard_score, f1_score # Make sure these are imported

# --- Optimizer ---
optimizer = optim.AdamW(
    model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay
)
logging.info(f"Optimizer AdamW defined with LR={learning_rate}, Weight Decay={weight_decay}")

# --- Loss Function ---
# Use BCEWithLogitsLoss for multi-label classification (includes Sigmoid)
criterion = nn.BCEWithLogitsLoss().to(device)
logging.info("Loss function BCEWithLogitsLoss defined.")

# --- Metrics Function ---
def compute_metrics(eval_preds):
    """Calculates multi-label metrics from logits and labels."""
    logits, labels = eval_preds
    # Ensure inputs are numpy arrays on CPU
    logits_np = logits.detach().cpu().numpy() if isinstance(logits, torch.Tensor) else logits
    labels_np = labels.detach().cpu().numpy() if isinstance(labels, torch.Tensor) else labels

    # Apply sigmoid and threshold
    probs = 1 / (1 + np.exp(-logits_np)) # Manual sigmoid
    preds = (probs > 0.5).astype(int)
    labels_np = labels_np.astype(int) # Ensure labels are integers

    if labels_np.shape != preds.shape:
         logging.error(f"Shape mismatch in compute_metrics! Labels: {labels_np.shape}, Preds: {preds.shape}")
         # Return default metrics indicating failure
         return {'hamming_loss': 1.0, 'jaccard_samples': 0.0, 'f1_micro': 0.0, 'f1_macro': 0.0}

    metrics = {}
    try:
        metrics['hamming_loss'] = hamming_loss(labels_np, preds)
        # Use average='samples' for Jaccard in multi-label scenario
        metrics['jaccard_samples'] = jaccard_score(labels_np, preds, average='samples', zero_division=0)
        metrics['f1_micro'] = f1_score(labels_np, preds, average='micro', zero_division=0)
        metrics['f1_macro'] = f1_score(labels_np, preds, average='macro', zero_division=0)
        # Optional: Add Accuracy (subset accuracy)
        # metrics['accuracy'] = accuracy_score(labels_np, preds) # This is exact match accuracy
    except Exception as e:
         logging.error(f"Error calculating metrics: {e}")
         metrics = {'hamming_loss': 1.0, 'jaccard_samples': 0.0, 'f1_micro': 0.0, 'f1_macro': 0.0}

    # Log inside the main evaluate function now for better context
    # logging.info(f"Metrics: Hamming={metrics['hamming_loss']:.4f}, Jaccard(samples)={metrics['jaccard_samples']:.4f}, F1 Micro={metrics['f1_micro']:.4f}, F1 Macro={metrics['f1_macro']:.4f}")
    return metrics

print("Optimizer, Loss, and compute_metrics function defined.")

2025-05-02 22:09:39,525 - INFO - Optimizer AdamW defined with LR=5e-05, Weight Decay=0.01
2025-05-02 22:09:39,527 - INFO - Loss function BCEWithLogitsLoss defined.
Optimizer, Loss, and compute_metrics function defined.


In [13]:
# Cell 7: Define Training Function for One Epoch (Corrected Model Input)

def train_epoch(model, dataloader, criterion, optimizer, device, gradient_accumulation_steps):
    model.train()
    total_loss = 0
    num_samples = 0
    optimizer.zero_grad()

    progress_bar = tqdm(dataloader, desc="Training", leave=False)
    for step, batch in enumerate(progress_bar):
        try:
            # --- CORRECTED INPUT PREPARATION ---
            # Explicitly use the keys the model expects.
            # Assume the feature extractor output key is 'input_values'
            # Assume the model's forward method expects 'input_features'
            # You MUST verify 'input_features' is the correct key for Wav2Vec2BertForSequenceClassification
            expected_model_input_key = "input_features" # <<<--- VERIFY THIS KEY NAME

            if 'input_values' not in batch: # Check if extractor output key is different
                 raise KeyError("Batch dictionary missing 'input_values' from Dataset/Extractor.")

            # Build the dictionary for the model's forward pass
            model_inputs = {
                expected_model_input_key: batch['input_values'].to(device) # Map dataset output key to model input key
            }
            if 'attention_mask' in batch and batch['attention_mask'] is not None:
                 model_inputs['attention_mask'] = batch['attention_mask'].to(device)
            # --- END CORRECTION ---

            labels = batch['labels'].to(device)

            # Forward pass
            outputs = model(**model_inputs) # Pass the correctly named arguments
            logits = outputs.logits

            # Calculate loss
            loss = criterion(logits, labels)

            # ... (rest of loss scaling, backward, optimizer step remains the same) ...
            if torch.isnan(loss):
                logging.warning(f"NaN loss detected at step {step}. Skipping batch.")
                if (step + 1) % gradient_accumulation_steps != 0: model.zero_grad()
                continue
            scaled_loss = loss / gradient_accumulation_steps
            scaled_loss.backward()
            batch_size_actual = labels.size(0)
            total_loss += loss.item() * batch_size_actual
            num_samples += batch_size_actual
            if (step + 1) % gradient_accumulation_steps == 0 or (step + 1) == len(dataloader):
                optimizer.step()
                optimizer.zero_grad()
            progress_bar.set_postfix({'loss': f'{loss.item():.4f}'})

        except Exception as e:
             logging.error(f"Error during training step {step}, batch keys: {batch.keys()}. Error: {e}", exc_info=True)
             continue

    if (step + 1) % gradient_accumulation_steps != 0 and num_samples > 0: # Ensure step was defined
         optimizer.step(); optimizer.zero_grad()

    avg_loss = total_loss / num_samples if num_samples > 0 else 0
    print(f"\nAverage Training Loss for Epoch: {avg_loss:.4f}")
    return avg_loss

print("train_epoch function updated.")

train_epoch function defined with added logging.


In [14]:
# Cell 8: Define Evaluation Function (Corrected Model Input)

def evaluate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    all_logits = []
    all_labels = []
    num_samples = 0

    with torch.no_grad():
        for step, batch in enumerate(tqdm(dataloader, desc="Evaluating", leave=False)):
            try:
                # --- CORRECTED INPUT PREPARATION ---
                expected_model_input_key = "input_features" # <<<--- VERIFY THIS KEY NAME

                if 'input_values' not in batch:
                     raise KeyError("Batch dictionary missing 'input_values' from Dataset/Extractor.")

                model_inputs = {
                    expected_model_input_key: batch['input_values'].to(device)
                }
                if 'attention_mask' in batch and batch['attention_mask'] is not None:
                     model_inputs['attention_mask'] = batch['attention_mask'].to(device)
                # --- END CORRECTION ---

                labels = batch['labels'].to(device)

                # Forward pass
                outputs = model(**model_inputs) # Pass the correctly named arguments
                logits = outputs.logits

                # Calculate loss
                loss = criterion(logits, labels)
                total_loss += loss.item() * labels.size(0)
                num_samples += labels.size(0)

                all_logits.append(logits.cpu())
                all_labels.append(labels.cpu())
            except Exception as e:
                 logging.error(f"Error during evaluation step {step}, batch keys: {batch.keys()}. Error: {e}", exc_info=True)
                 continue # Skip batch

    if not all_logits or not all_labels or num_samples == 0:
        logging.warning("Evaluation yielded no results (all batches failed or empty dataloader?).")
        return {}

    avg_loss = total_loss / num_samples

    all_logits_cat = torch.cat(all_logits, dim=0)
    all_labels_cat = torch.cat(all_labels, dim=0)

    eval_preds = (all_logits_cat, all_labels_cat)
    metrics = compute_metrics(eval_preds)
    metrics['eval_loss'] = avg_loss

    print(f"\nValidation Loss: {avg_loss:.4f}")
    for name, value in metrics.items():
         if name != 'eval_loss': print(f"  Validation {name.replace('_', ' ').title()}: {value:.4f}")

    return metrics

print("evaluate function updated.")

evaluate function defined with added logging.


In [15]:
# Cell 9: Run ONE Epoch for Debugging

from tqdm import tqdm # Ensure tqdm is imported

# Ensure model, criterion, optimizer, dataloaders etc. are defined from previous cells
print(f"\n--- Starting Debug Training Run for {num_epochs_debug} epoch ---")
start_time = time.time()

# Make sure model and criterion are on the correct device
model.to(device)
criterion.to(device)

for epoch in range(num_epochs_debug): # num_epochs_debug was set to 1 in Cell 2
    print(f"\n--- Debug Epoch {epoch+1}/{num_epochs_debug} ---")

    # Run training step for one epoch on the debug training data
    train_loss = train_epoch(
        model,
        debug_train_dataloader, # Use the SMALL debug dataloader
        criterion,
        optimizer,
        device,
        gradient_accumulation_steps # Pass grad accum steps
    )

    # Run evaluation step on the debug validation data
    eval_metrics = evaluate(
        model,
        debug_val_dataloader, # Use the SMALL debug dataloader
        criterion,
        device
    )

    print(f"\nDebug Epoch {epoch+1} finished.")
    print(f"  Avg Train Loss: {train_loss:.4f}")
    if eval_metrics:
        # Print all collected metrics
        for name, value in eval_metrics.items():
            print(f"  Validation {name.replace('_', ' ').title()}: {value:.4f}")
    else:
        print("  Validation failed to produce metrics.")

    # Optional: Save model after this 1 epoch for inspection
    save_path = os.path.join(model_save_dir, f"wav2vec2bert_debug_epoch_{epoch+1}.pth") # <<<--- Corrected filename
    try:
         torch.save(model.state_dict(), save_path)
         logging.info(f"Saved debug model checkpoint to {save_path}")
    except Exception as e:
         logging.error(f"Failed to save debug model checkpoint: {e}", exc_info=True)

end_time = time.time()
print(f"\n--- Debug Run Finished in {end_time - start_time:.2f} seconds ---")


--- Starting Debug Training Run for 1 epoch ---

--- Debug Epoch 1/1 ---


Training:   0%|          | 0/8 [00:00<?, ?it/s]

2025-05-02 22:09:40,966 - INFO - Train Batch Input 'input_values' shape: torch.Size([2, 1498, 160])
2025-05-02 22:09:40,974 - INFO - Train Batch Input 'attention_mask' shape: torch.Size([2, 1498])
2025-05-02 22:09:40,975 - INFO - Train Batch Labels shape: torch.Size([2, 22])
2025-05-02 22:09:40,975 - ERROR - Error during training step 0, batch keys: dict_keys(['input_values', 'attention_mask', 'labels']). Error: Wav2Vec2BertForSequenceClassification.forward() got an unexpected keyword argument 'input_values'
Traceback (most recent call last):
  File "/tmp/ipykernel_578021/1632541273.py", line 21, in train_epoch
    outputs = model(**model_inputs)
  File "/home/zhuoyuan/miniconda3/envs/musicClaGen_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zhuoyuan/miniconda3/envs/musicClaGen_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forwar

Training:  12%|█▎        | 1/8 [00:01<00:09,  1.40s/it]

2025-05-02 22:09:41,208 - ERROR - Error during training step 1, batch keys: dict_keys(['input_values', 'attention_mask', 'labels']). Error: Wav2Vec2BertForSequenceClassification.forward() got an unexpected keyword argument 'input_values'
Traceback (most recent call last):
  File "/tmp/ipykernel_578021/1632541273.py", line 21, in train_epoch
    outputs = model(**model_inputs)
  File "/home/zhuoyuan/miniconda3/envs/musicClaGen_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zhuoyuan/miniconda3/envs/musicClaGen_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
TypeError: Wav2Vec2BertForSequenceClassification.forward() got an unexpected keyword argument 'input_values'


Training:  25%|██▌       | 2/8 [00:01<00:04,  1.40it/s]

2025-05-02 22:09:41,465 - ERROR - Error during training step 2, batch keys: dict_keys(['input_values', 'attention_mask', 'labels']). Error: Wav2Vec2BertForSequenceClassification.forward() got an unexpected keyword argument 'input_values'
Traceback (most recent call last):
  File "/tmp/ipykernel_578021/1632541273.py", line 21, in train_epoch
    outputs = model(**model_inputs)
  File "/home/zhuoyuan/miniconda3/envs/musicClaGen_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zhuoyuan/miniconda3/envs/musicClaGen_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
TypeError: Wav2Vec2BertForSequenceClassification.forward() got an unexpected keyword argument 'input_values'


Training:  38%|███▊      | 3/8 [00:01<00:02,  1.98it/s]

2025-05-02 22:09:41,672 - ERROR - Error during training step 3, batch keys: dict_keys(['input_values', 'attention_mask', 'labels']). Error: Wav2Vec2BertForSequenceClassification.forward() got an unexpected keyword argument 'input_values'
Traceback (most recent call last):
  File "/tmp/ipykernel_578021/1632541273.py", line 21, in train_epoch
    outputs = model(**model_inputs)
  File "/home/zhuoyuan/miniconda3/envs/musicClaGen_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zhuoyuan/miniconda3/envs/musicClaGen_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
TypeError: Wav2Vec2BertForSequenceClassification.forward() got an unexpected keyword argument 'input_values'


Training:  50%|█████     | 4/8 [00:02<00:01,  2.58it/s]

2025-05-02 22:09:41,883 - ERROR - Error during training step 4, batch keys: dict_keys(['input_values', 'attention_mask', 'labels']). Error: Wav2Vec2BertForSequenceClassification.forward() got an unexpected keyword argument 'input_values'
Traceback (most recent call last):
  File "/tmp/ipykernel_578021/1632541273.py", line 21, in train_epoch
    outputs = model(**model_inputs)
  File "/home/zhuoyuan/miniconda3/envs/musicClaGen_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zhuoyuan/miniconda3/envs/musicClaGen_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
TypeError: Wav2Vec2BertForSequenceClassification.forward() got an unexpected keyword argument 'input_values'


Training:  62%|██████▎   | 5/8 [00:02<00:00,  3.09it/s]

2025-05-02 22:09:42,102 - ERROR - Error during training step 5, batch keys: dict_keys(['input_values', 'attention_mask', 'labels']). Error: Wav2Vec2BertForSequenceClassification.forward() got an unexpected keyword argument 'input_values'
Traceback (most recent call last):
  File "/tmp/ipykernel_578021/1632541273.py", line 21, in train_epoch
    outputs = model(**model_inputs)
  File "/home/zhuoyuan/miniconda3/envs/musicClaGen_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zhuoyuan/miniconda3/envs/musicClaGen_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
TypeError: Wav2Vec2BertForSequenceClassification.forward() got an unexpected keyword argument 'input_values'


Training:  75%|███████▌  | 6/8 [00:02<00:00,  3.47it/s]

2025-05-02 22:09:42,314 - ERROR - Error during training step 6, batch keys: dict_keys(['input_values', 'attention_mask', 'labels']). Error: Wav2Vec2BertForSequenceClassification.forward() got an unexpected keyword argument 'input_values'
Traceback (most recent call last):
  File "/tmp/ipykernel_578021/1632541273.py", line 21, in train_epoch
    outputs = model(**model_inputs)
  File "/home/zhuoyuan/miniconda3/envs/musicClaGen_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zhuoyuan/miniconda3/envs/musicClaGen_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
TypeError: Wav2Vec2BertForSequenceClassification.forward() got an unexpected keyword argument 'input_values'


Training:  88%|████████▊ | 7/8 [00:02<00:00,  3.80it/s]

2025-05-02 22:09:42,519 - ERROR - Error during training step 7, batch keys: dict_keys(['input_values', 'attention_mask', 'labels']). Error: Wav2Vec2BertForSequenceClassification.forward() got an unexpected keyword argument 'input_values'
Traceback (most recent call last):
  File "/tmp/ipykernel_578021/1632541273.py", line 21, in train_epoch
    outputs = model(**model_inputs)
  File "/home/zhuoyuan/miniconda3/envs/musicClaGen_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zhuoyuan/miniconda3/envs/musicClaGen_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
TypeError: Wav2Vec2BertForSequenceClassification.forward() got an unexpected keyword argument 'input_values'


                                                       


Average Training Loss for Epoch: 0.0000 (Total Loss: 0, Samples: 0)


Evaluating:   0%|          | 0/4 [00:00<?, ?it/s]

2025-05-02 22:09:42,737 - INFO - Eval Batch Input 'input_values' shape: torch.Size([2, 1499, 160])
2025-05-02 22:09:42,738 - INFO - Eval Batch Input 'attention_mask' shape: torch.Size([2, 1499])
2025-05-02 22:09:42,738 - INFO - Eval Batch Labels shape: torch.Size([2, 22])
2025-05-02 22:09:42,739 - ERROR - Error during evaluation step 0, batch keys: dict_keys(['input_values', 'attention_mask', 'labels']). Error: Wav2Vec2BertForSequenceClassification.forward() got an unexpected keyword argument 'input_values'
Traceback (most recent call last):
  File "/tmp/ipykernel_578021/1040071444.py", line 22, in evaluate
    outputs = model(**model_inputs)
  File "/home/zhuoyuan/miniconda3/envs/musicClaGen_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zhuoyuan/miniconda3/envs/musicClaGen_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_ca

Evaluating:  25%|██▌       | 1/4 [00:00<00:00,  4.75it/s]

2025-05-02 22:09:42,949 - ERROR - Error during evaluation step 1, batch keys: dict_keys(['input_values', 'attention_mask', 'labels']). Error: Wav2Vec2BertForSequenceClassification.forward() got an unexpected keyword argument 'input_values'
Traceback (most recent call last):
  File "/tmp/ipykernel_578021/1040071444.py", line 22, in evaluate
    outputs = model(**model_inputs)
  File "/home/zhuoyuan/miniconda3/envs/musicClaGen_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zhuoyuan/miniconda3/envs/musicClaGen_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
TypeError: Wav2Vec2BertForSequenceClassification.forward() got an unexpected keyword argument 'input_values'


Evaluating:  50%|█████     | 2/4 [00:00<00:00,  4.75it/s]

2025-05-02 22:09:43,218 - ERROR - Error during evaluation step 2, batch keys: dict_keys(['input_values', 'attention_mask', 'labels']). Error: Wav2Vec2BertForSequenceClassification.forward() got an unexpected keyword argument 'input_values'
Traceback (most recent call last):
  File "/tmp/ipykernel_578021/1040071444.py", line 22, in evaluate
    outputs = model(**model_inputs)
  File "/home/zhuoyuan/miniconda3/envs/musicClaGen_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zhuoyuan/miniconda3/envs/musicClaGen_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
TypeError: Wav2Vec2BertForSequenceClassification.forward() got an unexpected keyword argument 'input_values'


Evaluating:  75%|███████▌  | 3/4 [00:00<00:00,  4.22it/s]

2025-05-02 22:09:43,448 - ERROR - Error during evaluation step 3, batch keys: dict_keys(['input_values', 'attention_mask', 'labels']). Error: Wav2Vec2BertForSequenceClassification.forward() got an unexpected keyword argument 'input_values'
Traceback (most recent call last):
  File "/tmp/ipykernel_578021/1040071444.py", line 22, in evaluate
    outputs = model(**model_inputs)
  File "/home/zhuoyuan/miniconda3/envs/musicClaGen_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zhuoyuan/miniconda3/envs/musicClaGen_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
TypeError: Wav2Vec2BertForSequenceClassification.forward() got an unexpected keyword argument 'input_values'


                                                         


Debug Epoch 1 finished.
  Avg Train Loss: 0.0000
  Validation failed to produce metrics.




2025-05-02 22:09:47,856 - INFO - Saved debug model checkpoint to /home/zhuoyuan/CSprojects/musicClaGen/models/wav2vec2bert_debug_epoch_1.pth

--- Debug Run Finished in 8.29 seconds ---
