In [2]:
# Standard library
import datetime
import json
import os
from collections import Counter
import random

# Move two directories up from the notebook's start directory so that the
# relative paths used by later cells (e.g. "data/...") resolve.
# The guard keeps the cell idempotent: os.chdir("../..") is relative, so
# without it every re-run of this cell would climb two more levels.
if "_PROJECT_ROOT" not in globals():
    os.chdir("../..")
    _PROJECT_ROOT = os.getcwd()
print(os.getcwd())

# Numerical computing
import numpy as np

# PyTorch
import torch
import torch.nn as nn
import torchaudio
from torch.utils.data import DataLoader
import torchaudio.transforms as T
import torchaudio.functional as F

# Datasets
from datasets import load_dataset

# PyTorch Lightning
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.model_summary import ModelSummary

# Metrics
from sklearn.utils.class_weight import compute_class_weight
from torchmetrics.classification import (
    MulticlassF1Score,
    MulticlassPrecision,
    MulticlassRecall,
)

# Visualization
import matplotlib.pyplot as plt

# ==== BACKBONE ====
# Hugging Face Transformers
from transformers import (
    Wav2Vec2FeatureExtractor,
    Wav2Vec2Model,
    ViTFeatureExtractor, 
    ViTModel
)

# MAE-AST Library
from s3prl.nn.upstream import S3PRLUpstream

# Fix the length of the input audio to the same length
def _match_length_force(self, xs, target_max_len):
    xs_max_len = xs.size(1)
    if xs_max_len > target_max_len:
        xs = xs[:, :target_max_len, :]
    elif xs_max_len < target_max_len:
        pad_len = target_max_len - xs_max_len
        xs = torch.cat(
            (xs, xs[:, -1:, :].repeat(1, pad_len, 1)),
            dim=1
        )
    return xs

# Monkey-patch s3prl: replace the class-level `_match_length` of S3PRLUpstream
# with `_match_length_force`, which truncates or repeat-pads to the target length.
S3PRLUpstream._match_length = _match_length_force

/home/incantator/Documents/mbari-mae


In [3]:
# loading the dataset
# The three CSV annotation files under `data_dir` are loaded into a single
# dataset dict keyed by split name ("train" / "validation" / "test").
data_dir = "data/watkins"
annotations_file_train = os.path.join(data_dir, "annotations.train.csv")
annotations_file_valid = os.path.join(data_dir, "annotations.valid.csv")
annotations_file_test = os.path.join(data_dir, "annotations.test.csv")

ds = load_dataset(
    "csv",
    data_files={"train": annotations_file_train,
                "validation": annotations_file_valid,
                "test": annotations_file_test},
)

# Print the size and per-class example count of every split.
for split_name in ["train", "validation", "test"]:
    split_dataset = ds[split_name]
    labels = split_dataset["label"]
    total = len(labels)
    counts = Counter(labels)

    print(f"{split_name.capitalize()} dataset: {total} examples, {len(counts)} classes")

    # A label feature exposing `.names` (a ClassLabel column) stores integer
    # ids, so the Counter is keyed by id and must be looked up by `idx`;
    # looking it up by `name` would report 0 for every class.
    label_feature = split_dataset.features.get("label")
    class_names = getattr(label_feature, "names", None)
    if class_names is not None:
        for idx, name in enumerate(class_names):
            print(f"  {idx} ({name}): {counts.get(idx, 0)}")
    else:
        # Plain value column: the Counter keys are the raw label values.
        for label, count in counts.items():
            print(f"  {label}: {count}")

Train dataset: 1017 examples, 31 classes
  Clymene_Dolphin: 38
  Bottlenose_Dolphin: 15
  Spinner_Dolphin: 69
  Beluga,_White_Whale: 30
  Bearded_Seal: 22
  Minke_Whale: 10
  Humpback_Whale: 38
  Southern_Right_Whale: 15
  White-sided_Dolphin: 33
  Narwhal: 30
  White-beaked_Dolphin: 34
  Northern_Right_Whale: 32
  Frasers_Dolphin: 52
  Grampus,_Rissos_Dolphin: 40
  Harp_Seal: 28
  Atlantic_Spotted_Dolphin: 35
  Fin,_Finback_Whale: 30
  Ross_Seal: 30
  Rough-Toothed_Dolphin: 30
  Killer_Whale: 21
  Pantropical_Spotted_Dolphin: 40
  Short-Finned_Pacific_Pilot_Whale: 40
  Bowhead_Whale: 36
  False_Killer_Whale: 35
  Melon_Headed_Whale: 38
  Long-Finned_Pilot_Whale: 42
  Striped_Dolphin: 49
  Leopard_Seal: 6
  Walrus: 23
  Sperm_Whale: 45
  Common_Dolphin: 31
Validation dataset: 339 examples, 31 classes
  Clymene_Dolphin: 12
  Bottlenose_Dolphin: 5
  Spinner_Dolphin: 23
  Beluga,_White_Whale: 10
  Bearded_Seal: 7
  Minke_Whale: 4
  Humpback_Whale: 13
  Southern_Right_Whale: 5
  White-side

In [4]:
# class weights calculation
# Training labels are mapped to integer ids following the sorted order of the
# distinct label values; 'balanced' weights are then computed over those ids.
train_labels = ds["train"]["label"]
unique_labels = sorted(set(train_labels))
label_to_int = dict(zip(unique_labels, range(len(unique_labels))))
y_train = list(map(label_to_int.__getitem__, train_labels))

class_weights = compute_class_weight(
    class_weight='balanced',
    classes=np.arange(len(unique_labels)),
    y=y_train,
)

# One weight is produced per class id, so this is the number of classes.
num_classes = len(class_weights)

In [5]:
# model definition
class WMMDClassifier(pl.LightningModule):
    """Audio classifier: a selectable backbone followed by a small MLP head.

    Supported ``backbone`` values are ``"facebook/wav2vec2-base"``,
    ``"patrickvonplaten/tiny-wav2vec2-no-tokenizer"``, ``"mae-ast"``,
    ``"cnn-spect"`` and ``"vit-imagenet"``. Backbone outputs that are not
    already 2-D are mean-pooled over dim 1 before entering the head.
    """

    def __init__(
        self,
        num_classes: int,
        lr: float = 1e-3,
        backbone: str = "facebook/wav2vec2-base",
        ckpt_path: str = "",
        finetune: bool = False,
        class_weights=None,
    ):
        """Build the backbone, the classification head, the loss and the metrics.

        Args:
            num_classes: Number of output classes.
            lr: Learning rate handed to AdamW.
            backbone: Backbone identifier (see the class docstring).
            ckpt_path: Checkpoint loaded into the MAE-AST upstream when
                non-empty; only read by the ``"mae-ast"`` branch.
            finetune: Value assigned to ``requires_grad`` of every backbone
                parameter (``False`` freezes the backbone).
            class_weights: Optional per-class weights for the cross-entropy
                loss (array-like or tensor).

        Raises:
            ValueError: If ``backbone`` is not one of the supported names.
        """
        super().__init__()
        self.save_hyperparameters()

        # ========== WAV2VEC2 ==========
        if backbone == "facebook/wav2vec2-base":
            self.backbone     = Wav2Vec2Model.from_pretrained(backbone)
            self.embedding_dim = self.backbone.config.hidden_size

        # ========== TINY WAV2VEC2 ==========
        elif backbone == "patrickvonplaten/tiny-wav2vec2-no-tokenizer":
            self.backbone     = Wav2Vec2Model.from_pretrained(backbone)
            self.embedding_dim = self.backbone.config.hidden_size

        # ========== MAE-AST ==========
        elif backbone == "mae-ast":
            up_kwargs = {"name": "mae_ast_patch"}
            s3 = S3PRLUpstream(**up_kwargs)

            # Keep only the first four encoder layers and drop the parts of
            # the upstream model that are not used for classification here.
            enc = s3.upstream.model.encoder
            enc.layers = nn.ModuleList(list(enc.layers)[:4])
            s3.upstream.model.dec_sine_pos_embed = None
            s3.upstream.model.decoder = None
            s3.upstream.model.final_proj_reconstruction = None
            s3.upstream.model.final_proj_classification  = None

            # Keep the wrapper's bookkeeping in sync with the truncated encoder.
            new_n = len(enc.layers)
            s3._num_layers       = new_n
            s3._hidden_sizes     = s3._hidden_sizes[:new_n]
            s3._downsample_rates = s3._downsample_rates[:new_n]

            self.backbone      = s3
            self.embedding_dim = s3.hidden_sizes[-1]

            # Load the checkpoint for mae ast
            if ckpt_path:
                self.load_mae_ckpt(ckpt_path)

            mel_transform = T.MelSpectrogram(
                sample_rate=2000,
                n_fft=1024,
                win_length=512,
                hop_length=20,
                n_mels=128,
            )

            class DBWithDeltas(nn.Module):
                """Convert a spectrogram to dB and append first/second-order deltas."""

                def __init__(self):
                    super().__init__()

                def forward(self, spec):
                    spec_db = F.amplitude_to_DB(
                        spec,
                        multiplier=10.0,
                        amin=1e-10,
                        db_multiplier=0
                    )
                    # `t` is the dB spectrogram with its first two axes swapped;
                    # deltas are computed on the un-swapped layout and swapped
                    # afterwards so all three pieces line up for the concat.
                    t = spec_db.transpose(0, 1)
                    d1 = F.compute_deltas(t.transpose(0,1))
                    d2 = F.compute_deltas(d1)
                    return torch.cat([
                        t,
                        d1.transpose(0,1),
                        d2.transpose(0,1)
                    ], dim=1)

            # Replace the upstream front-end with the 2 kHz mel + delta pipeline.
            s3.upstream.model.fbank   = mel_transform
            s3.upstream.model.db_norm = DBWithDeltas()

        # ========== CNN-SPECTROGRAM ==========
        elif backbone == "cnn-spect":
            self.mel_spec = torchaudio.transforms.MelSpectrogram(
                sample_rate=2000,
                n_fft=1024,
                hop_length=20,
                win_length=512,
                n_mels=128
            )
            self.to_db = torchaudio.transforms.AmplitudeToDB(top_db=80)

            # Five conv stages (two 3x3 convs each); the first four end in a
            # 2x2 max-pool, the last in global average pooling + flatten.
            self.backbone = nn.Sequential(
                nn.Conv2d(3, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(),
                nn.Conv2d(32,32,3,padding=1), nn.BatchNorm2d(32), nn.ReLU(),
                nn.MaxPool2d(2),

                nn.Conv2d(32,64,3,padding=1), nn.BatchNorm2d(64), nn.ReLU(),
                nn.Conv2d(64,64,3,padding=1), nn.BatchNorm2d(64), nn.ReLU(),
                nn.MaxPool2d(2),

                nn.Conv2d(64,128,3,padding=1), nn.BatchNorm2d(128), nn.ReLU(),
                nn.Conv2d(128,128,3,padding=1), nn.BatchNorm2d(128), nn.ReLU(),
                nn.MaxPool2d(2),

                nn.Conv2d(128,256,3,padding=1), nn.BatchNorm2d(256), nn.ReLU(),
                nn.Conv2d(256,256,3,padding=1), nn.BatchNorm2d(256), nn.ReLU(),
                nn.MaxPool2d(2),

                nn.Conv2d(256,512,3,padding=1), nn.BatchNorm2d(512), nn.ReLU(),
                nn.Conv2d(512,512,3,padding=1), nn.BatchNorm2d(512), nn.ReLU(),
                nn.AdaptiveAvgPool2d((1,1)),
                nn.Flatten()
            )
            self.embedding_dim = 512

        # ========== VIT-IMAGENET ==========
        elif backbone == "vit-imagenet":
            self.mel_spec = torchaudio.transforms.MelSpectrogram(
                sample_rate=2000,
                n_fft=1024,
                hop_length=20,
                win_length=512,
                n_mels=128
            )
            self.to_db = torchaudio.transforms.AmplitudeToDB(top_db=80)

            self.feature_extractor = ViTFeatureExtractor.from_pretrained(
                "google/vit-base-patch16-224-in21k"
            )

            self.backbone = ViTModel.from_pretrained(
                "google/vit-base-patch16-224-in21k",
                add_pooling_layer=False
            )
            self.embedding_dim = self.backbone.config.hidden_size

        else:
            raise ValueError(f"Unsupported backbone '{backbone}'")

        # Gradient checkpointing is optional. Backbones that do not expose
        # `gradient_checkpointing_enable` are skipped; a backbone that exposes
        # it but fails is reported instead of being silently ignored.
        enable_ckpt = getattr(self.backbone, "gradient_checkpointing_enable", None)
        if callable(enable_ckpt):
            try:
                enable_ckpt()
            except Exception as exc:  # best-effort: training works without it
                import warnings
                warnings.warn(
                    f"Could not enable gradient checkpointing for '{backbone}': {exc}"
                )

        for param in self.backbone.parameters():
            param.requires_grad = finetune

        self.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(self.embedding_dim, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(1024, num_classes),
        )

        if class_weights is not None:
            # torch.tensor() on an existing tensor emits a copy-construct
            # warning, so tensors are cloned explicitly instead.
            if isinstance(class_weights, torch.Tensor):
                cw = class_weights.detach().clone().to(torch.float)
            else:
                cw = torch.tensor(class_weights, dtype=torch.float)
            self.criterion = nn.CrossEntropyLoss(weight=cw)
        else:
            self.criterion = nn.CrossEntropyLoss()

        # Separate metric objects per stage so their accumulated state never mixes.
        metrics_kwargs = dict(num_classes=num_classes, average='macro')
        self.train_precision = MulticlassPrecision(**metrics_kwargs)
        self.train_recall = MulticlassRecall(**metrics_kwargs)
        self.train_f1 = MulticlassF1Score(**metrics_kwargs)
        self.val_precision = MulticlassPrecision(**metrics_kwargs)
        self.val_recall = MulticlassRecall(**metrics_kwargs)
        self.val_f1 = MulticlassF1Score(**metrics_kwargs)
        self.test_precision = MulticlassPrecision(**metrics_kwargs)
        self.test_recall = MulticlassRecall(**metrics_kwargs)
        self.test_f1 = MulticlassF1Score(**metrics_kwargs)

    def forward(self, x):
        """Return class logits for a batch ``x`` of inputs.

        Raises:
            ValueError: If the stored backbone name is not supported.
        """
        bname = self.hparams.backbone.lower()
        if bname in {"facebook/wav2vec2-base",
             "patrickvonplaten/tiny-wav2vec2-no-tokenizer"}:
            hidden = self.backbone(x).last_hidden_state

        elif bname == "mae-ast":
            if x.dim() == 3:
                x = x.squeeze(-1)
            # Every item is reported as full length (x.size(1)).
            wav_lens = torch.full((x.size(0),), x.size(1),
                                dtype=torch.long, device=x.device)
            # First element of the upstream output, last entry of it
            # (presumably the last layer's hidden states — confirm against s3prl).
            hidden = self.backbone(x, wav_lens)[0][-1]

        elif bname == "cnn-spect":
            spec = self._safe_logmelspec(x, self.mel_spec, self.to_db)  # (B,3,F,L) in [0,1] for 2-D input
            hidden = self.backbone(spec)

        elif bname == "vit-imagenet":
            spec = self._safe_logmelspec(x, self.mel_spec, self.to_db)  # (B,3,F,L) in [0,1] for 2-D input

            img = torch.nn.functional.interpolate(
                spec, size=(224, 224), mode="bilinear", align_corners=False
            )

            # NOTE(review): the feature extractor presumably runs outside the
            # autograd graph and off-device, hence the explicit `.to(img.device)`
            # — confirm its rescale/normalize settings suit inputs already in [0,1].
            pixel_values = self.feature_extractor(images=img, return_tensors="pt").pixel_values.to(img.device)
            outputs = self.backbone(pixel_values=pixel_values)
            hidden = outputs.last_hidden_state.mean(dim=1)

        else:
            raise ValueError(f"Unsupported backbone '{self.hparams.backbone}'")

        # Pool over dim 1 unless the backbone already produced one vector per item.
        emb = hidden if hidden.dim() == 2 else hidden.mean(dim=1)
        return self.classifier(emb)


    def training_step(self, batch, batch_idx):
        """Compute and log the training loss/metrics for one batch; return the loss."""
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        preds = logits.argmax(dim=1)
        self.log_batch_metrics(loss, preds, y, prefix='train')
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss/metrics for one batch."""
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        preds = logits.argmax(dim=1)
        self.log_batch_metrics(loss, preds, y, prefix='val')

    def test_step(self, batch, batch_idx):
        """Compute and log the test loss/metrics for one batch."""
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        preds = logits.argmax(dim=1)
        self.log_batch_metrics(loss, preds, y, prefix='test')

    def configure_optimizers(self):
        """Return an AdamW optimizer over all parameters with the configured lr."""
        return torch.optim.AdamW(self.parameters(), lr=self.hparams.lr)

    def log_batch_metrics(self, loss, preds, targets, prefix):
        """Log loss, accuracy and macro precision/recall/F1 under ``prefix``.

        Args:
            loss: Loss tensor of the current batch.
            preds: Predicted class indices.
            targets: Ground-truth class indices.
            prefix: Stage name (``'train'``, ``'val'`` or ``'test'``); selects
                both the metric objects and the logged key names.
        """
        self.log(f'{prefix}_loss', loss, prog_bar=True, on_epoch=True)
        acc = (preds == targets).float().mean()
        self.log(f'{prefix}_acc', acc, prog_bar=True, on_epoch=True)

        # Update each metric with this batch, then log the Metric *object*.
        # Lightning then computes the epoch value from the accumulated state;
        # logging the per-batch result would average macro scores of tiny
        # batches, which is not the macro score of the epoch.
        for metric_name in ('precision', 'recall', 'f1'):
            metric = getattr(self, f'{prefix}_{metric_name}')
            metric(preds, targets)
            self.log(f'{prefix}_{metric_name}', metric, on_epoch=True)

    def on_train_end(self):
        """Save artifacts at the end of training when ``self.save_dir`` is set."""
        save_dir = getattr(self, 'save_dir', None)
        if save_dir:
            self.save_model(save_dir)

    def _serializable_hparams(self):
        """Return the hyper-parameters with arrays/tensors converted to plain Python."""
        converted = {}
        for k, v in dict(self.hparams).items():
            if isinstance(v, np.ndarray):
                converted[k] = v.tolist()
            elif isinstance(v, torch.Tensor):
                converted[k] = v.cpu().item() if v.ndim == 0 else v.cpu().tolist()
            else:
                converted[k] = v
        return converted

    def save_model(self, base_dir: str = 'model'):
        """Write the weights and a text report into a timestamped folder.

        The folder is ``<base_dir>/<backbone>_<flag>/<timestamp>/`` and
        receives ``<timestamp>.pt`` (state dict, hyper-parameters and, when
        present on the instance, ``test_results`` / ``finish_time`` /
        ``epochs_trained``) and ``<timestamp>.txt`` (human-readable summary).

        Args:
            base_dir: Root directory for saved artifacts. Defaults to
                ``'model'``.
        """
        bn = self.hparams.backbone.replace('/', '_')
        cw = getattr(self.hparams, 'class_weights', None)
        balance_flag = 'imbalance' if cw is not None else 'balance'
        # Prefer the caller-provided finish time so file names match the run.
        timestamp = getattr(self, 'finish_time', datetime.datetime.now().strftime("%Y%m%d_%H%M%S"))
        folder = os.path.join(base_dir, f"{bn}_{balance_flag}", timestamp)
        os.makedirs(folder, exist_ok=True)

        # Plain-Python hyper-parameters are used for both the JSON report and
        # the checkpoint, so `load_model` can rebuild the model from the .pt file.
        serializable_hparams = self._serializable_hparams()

        ckpt_path = os.path.join(folder, f"{timestamp}.pt")
        payload = {
            'state_dict': self.state_dict(),
            'hparams': serializable_hparams
        }
        for attr in ('test_results', 'finish_time', 'epochs_trained'):
            if hasattr(self, attr):
                payload[attr] = getattr(self, attr)
        torch.save(payload, ckpt_path)

        stats_path = os.path.join(folder, f"{timestamp}.txt")

        serializable_results = {}
        if hasattr(self, 'test_results'):
            for k, v in self.test_results.items():
                serializable_results[k] = v.cpu().item() if isinstance(v, torch.Tensor) else v

        with open(stats_path, 'w') as f:
            f.write(f"Model architecture:\n{self}\n\n")
            f.write("Hyperparameters:\n")
            f.write(json.dumps(serializable_hparams, indent=4))
            f.write("\n\n")
            if serializable_results:
                f.write("Test results:\n")
                f.write(json.dumps(serializable_results, indent=4))
                f.write("\n\n")
            if hasattr(self, 'epochs_trained'):
                f.write(f"Epochs trained: {self.epochs_trained}\n")

        # Remembered so callers can drop further artifacts next to the checkpoint.
        self._last_save_dir = folder
        self._last_timestamp = timestamp
        print(f"Artifacts saved to {folder}/")

    def load_mae_ckpt(self, ckpt_source: str):
        """
        Load MAE-AST checkpoint into the truncated upstream model
        and verify exact weight equality for each loaded parameter.

        Only entries whose key exists in the upstream model with an identical
        shape are loaded; everything else in the checkpoint is ignored.

        Raises:
            RuntimeError: If a loaded tensor differs from the checkpoint value.
        """
        # Load onto CPU so a checkpoint written on a GPU machine can be read
        # anywhere; load_state_dict copies the values into the module's tensors.
        loaded = torch.load(ckpt_source, map_location='cpu')
        state_dict = loaded.get('model', loaded)

        up = self.backbone.upstream.model
        up_state = up.state_dict()

        to_load = {k: v for k, v in state_dict.items()
                   if k in up_state and v.shape == up_state[k].shape}

        missing, unexpected = up.load_state_dict(to_load, strict=False)

        # Re-read the state after loading and compare against the checkpoint.
        current = up.state_dict()
        for k, v in to_load.items():
            if not torch.equal(current[k], v.to(current[k].device)):
                raise RuntimeError(f"Weight mismatch at '{k}' after loading checkpoint")

        print(f"Successfully loaded {len(to_load)} parameters; missing: {len(missing)}, unexpected: {len(unexpected)}")

    @staticmethod
    def _safe_logmelspec(waveform, mel_spec, to_db, db_range=80.0):
        """Turn waveforms into a 3-channel (dB, delta, delta-delta) map in [0, 1].

        Args:
            waveform: Input audio; a 2-D input gets a singleton dim inserted at
                position 1 before the mel transform.
            mel_spec: Callable producing the mel spectrogram.
            to_db: Callable converting the spectrogram to decibels.
            db_range: Range used to replace ``-inf`` and to rescale the output.

        Returns:
            Tensor with the dB spectrogram and its first/second-order deltas
            concatenated on dim 1, shifted and scaled by ``db_range`` and
            clamped to ``[0, 1]``.
        """
        if waveform.dim() == 2:
            waveform = waveform.unsqueeze(1)
        # The small offset keeps the dB conversion away from log(0).
        spec = mel_spec(waveform) + 1e-10
        spec_db = to_db(spec)
        spec_db = torch.nan_to_num(spec_db, neginf=-db_range)
        t = spec_db
        d1 = F.compute_deltas(t.squeeze(1)).unsqueeze(1)
        d2 = F.compute_deltas(d1.squeeze(1)).unsqueeze(1)
        spec3 = torch.cat([t, d1, d2], dim=1)
        spec3 = ((spec3 + db_range)/db_range).clamp(0.0, 1.0)
        return spec3

    @classmethod
    def load_model(cls, load_dir: str, map_location=None):
        """
        Load a model checkpoint and hyperparameters from a directory.

        Two layouts are understood:

        * ``hparams.json`` plus ``<ClassName>.ckpt`` (the layout this method
          originally expected);
        * a ``*.pt`` file as written by ``save_model``, holding both
          ``'hparams'`` and ``'state_dict'``. When several are present the
          lexicographically last file name is used.

        The stored hyper-parameters are passed to the constructor unchanged,
        so a non-empty ``ckpt_path`` is read again for the ``"mae-ast"``
        backbone before the saved weights overwrite it.

        Args:
            load_dir: Directory containing the saved artifacts.
            map_location: Forwarded to ``torch.load``.

        Returns:
            model (WMMDClassifier): Loaded model

        Raises:
            FileNotFoundError: If neither layout is found in ``load_dir``.
        """
        hparams_path = os.path.join(load_dir, 'hparams.json')
        legacy_ckpt = os.path.join(load_dir, f'{cls.__name__}.ckpt')

        if os.path.isfile(hparams_path) and os.path.isfile(legacy_ckpt):
            with open(hparams_path, 'r') as f:
                hparams = json.load(f)
            state = torch.load(legacy_ckpt, map_location=map_location)
        else:
            candidates = sorted(
                name for name in os.listdir(load_dir) if name.endswith('.pt')
            )
            if not candidates:
                raise FileNotFoundError(
                    f"No 'hparams.json' + '{cls.__name__}.ckpt' pair and no '*.pt' "
                    f"checkpoint found in '{load_dir}'"
                )
            state = torch.load(os.path.join(load_dir, candidates[-1]),
                               map_location=map_location)
            hparams = state['hparams']

        model = cls(**hparams)
        model.load_state_dict(state['state_dict'])
        return model

In [6]:
# sanity check
# Instantiate the CNN variant and print a one-level module summary.
# `num_classes` comes from the class-weight cell, so the head size always
# matches the length of `class_weights` (a hard-coded count could drift).
model = WMMDClassifier(
        num_classes=num_classes, lr=1e-3,
        backbone="cnn-spect", finetune=True,
        class_weights=class_weights,
        ckpt_path=""
    )

print(ModelSummary(model, max_depth=1))

   | Name            | Type                | Params | Mode 
-----------------------------------------------------------------
0  | mel_spec        | MelSpectrogram      | 0      | train
1  | to_db           | AmplitudeToDB       | 0      | train
2  | backbone        | Sequential          | 4.7 M  | train
3  | classifier      | Sequential          | 557 K  | train
4  | criterion       | CrossEntropyLoss    | 0      | train
5  | train_precision | MulticlassPrecision | 0      | train
6  | train_recall    | MulticlassRecall    | 0      | train
7  | train_f1        | MulticlassF1Score   | 0      | train
8  | val_precision   | MulticlassPrecision | 0      | train
9  | val_recall      | MulticlassRecall    | 0      | train
10 | val_f1          | MulticlassF1Score   | 0      | train
11 | test_precision  | MulticlassPrecision | 0      | train
12 | test_recall     | MulticlassRecall    | 0      | train
13 | test_f1         | MulticlassF1Score   | 0      | train
----------------------------------

In [7]:
def WMMD_Collate(batch):
    """Collate ``(waveform, label)`` pairs into equal-length batch tensors.

    The common length is the longest waveform in the batch, clipped to the
    range [5_000, 25_000] samples. Longer waveforms are cropped at a random
    offset; shorter ones are zero-padded on the right.

    Returns:
        Tuple ``(waveforms, labels)``: the stacked waveforms and a
        ``torch.long`` tensor of labels.
    """
    min_len, max_len = 5_000, 25_000

    waveforms, labels = zip(*batch)
    longest = max(wav.shape[0] for wav in waveforms)
    target_len = min(max(longest, min_len), max_len)

    fixed = []
    for wav in waveforms:
        n_samples = wav.shape[0]
        if n_samples > target_len:
            # Random crop window of exactly `target_len` samples.
            offset = random.randint(0, n_samples - target_len)
            wav = wav[offset : offset + target_len]
        elif n_samples < target_len:
            wav = torch.nn.functional.pad(wav, (0, target_len - n_samples))
        fixed.append(wav)

    return torch.stack(fixed, dim=0), torch.tensor(labels, dtype=torch.long)

class WMMDSoundDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding ``(model_input, label_index)`` per audio file."""

    def __init__(self, dataset, backbone: str, target_sr: int = 2000, label_to_int=None):
        """
        dataset: indexable collection of records with keys 'path' & 'label'
        backbone: 'facebook/wav2vec2-base',
                  'patrickvonplaten/tiny-wav2vec2-no-tokenizer',
                  'mae-ast', 'cnn-spect' or 'vit-imagenet'
        target_sr: sampling rate (e.g. 2000)
        label_to_int: optional label -> class-index mapping to use instead of
                      one derived from this dataset's own labels; pass the
                      training mapping to keep indices identical across splits

        Raises ValueError for an unsupported backbone.
        """
        self.dataset = dataset
        self.backbone = backbone
        self.target_sr = target_sr
        # One Resample transform per source rate, created on first use.
        self.resampler_cache = {}

        if self.backbone in {"facebook/wav2vec2-base", "patrickvonplaten/tiny-wav2vec2-no-tokenizer"}:
            self.processor = Wav2Vec2FeatureExtractor.from_pretrained(
                self.backbone, return_attention_mask=False, sampling_rate=self.target_sr
            )
        elif self.backbone in {"mae-ast", "cnn-spect", "vit-imagenet"}:
            self.processor = None
        else:
            raise ValueError(f"Unsupported backbone '{self.backbone}'")

        if label_to_int is None:
            # Default: index labels by the sorted order of the distinct values.
            labels = sorted({item['label'] for item in dataset})
            label_to_int = {lbl: i for i, lbl in enumerate(labels)}
        self.label_to_int = dict(label_to_int)

    def __len__(self):
        """Return the number of records."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Load, resample and peak-normalize one recording.

        Returns:
            Tuple ``(inp, lbl)``: the 1-D model input and the integer class index.
        """
        item = self.dataset[idx]
        audio_path = item["path"]
        waveform, orig_sr = torchaudio.load(audio_path)

        # Mix multi-channel recordings down to a single channel. Without this
        # `squeeze(0)` below leaves such audio 2-D, and the collate function
        # would then read the channel count as the length.
        if waveform.dim() > 1 and waveform.size(0) > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        if orig_sr != self.target_sr:
            if orig_sr not in self.resampler_cache:
                self.resampler_cache[orig_sr] = torchaudio.transforms.Resample(orig_sr, self.target_sr)
            waveform = self.resampler_cache[orig_sr](waveform)

        # Peak-normalize; the epsilon guards against division by zero on silence.
        waveform = waveform / (waveform.abs().max() + 1e-6)
        wav_1d = waveform.squeeze(0)

        if self.backbone in {"facebook/wav2vec2-base", "patrickvonplaten/tiny-wav2vec2-no-tokenizer"}:
            arr = wav_1d.numpy()
            feats = self.processor(arr, sampling_rate=self.target_sr, return_tensors="pt")
            inp = feats.input_values.squeeze(0)
        elif self.backbone in {"cnn-spect", "vit-imagenet", "mae-ast"}:
            inp = wav_1d
        else:
            raise ValueError(f"Unsupported backbone '{self.backbone}'")

        lbl = self.label_to_int[item['label']]
        return inp, lbl

class WMMDDataModule(pl.LightningDataModule):
    """Wraps the train/validation/test splits in ``WMMDSoundDataset`` loaders."""

    def __init__(
        self,
        dataset_dict,
        backbone: str,
        batch_size: int = 2,
        num_workers: int = 1
    ):
        """
        Args:
            dataset_dict: Mapping with "train", "validation" and "test" splits.
            backbone: Backbone name forwarded to every ``WMMDSoundDataset``.
            batch_size: Batch size used by all three loaders.
            num_workers: Worker processes used by all three loaders.
        """
        super().__init__()
        self.dataset_dict = dataset_dict
        self.backbone = backbone
        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self, stage=None):
        """Create the three datasets and align their label indices.

        Raises:
            ValueError: If the validation or test split contains a label that
                does not occur in the train split.
        """
        self.train_ds = WMMDSoundDataset(self.dataset_dict["train"], backbone=self.backbone)
        self.val_ds   = WMMDSoundDataset(self.dataset_dict["validation"], backbone=self.backbone)
        self.test_ds  = WMMDSoundDataset(self.dataset_dict["test"], backbone=self.backbone)

        # Each dataset derives its label -> index map from its own split, so a
        # class missing from one split would shift every later index there.
        # The train mapping is therefore made authoritative for all splits.
        train_map = self.train_ds.label_to_int
        for split_name, split_ds in (("validation", self.val_ds), ("test", self.test_ds)):
            unseen = sorted(set(split_ds.label_to_int) - set(train_map))
            if unseen:
                raise ValueError(
                    f"Labels in the {split_name} split are missing from the train split: {unseen}"
                )
            split_ds.label_to_int = train_map

    def _make_loader(self, dataset, shuffle: bool):
        """Build a DataLoader with the settings shared by all three splits."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=True,
            collate_fn=WMMD_Collate
        )

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return self._make_loader(self.train_ds, shuffle=True)

    def val_dataloader(self):
        """Unshuffled loader over the validation split."""
        return self._make_loader(self.val_ds, shuffle=False)

    def test_dataloader(self):
        """Unshuffled loader over the test split."""
        return self._make_loader(self.test_ds, shuffle=False)

In [8]:
# callback for logging metrics
class MetricsLogger(pl.Callback):
    """Collects the epoch-level train/validation metrics into plain lists.

    After training, ``train_*`` and ``val_*`` lists hold one float per epoch,
    read from ``trainer.callback_metrics``.
    """

    # (key suffix in callback_metrics, suffix of the list attribute)
    _TRACKED = (
        ('loss', 'losses'),
        ('acc', 'accs'),
        ('precision', 'precisions'),
        ('recall', 'recalls'),
        ('f1', 'f1s'),
    )

    def __init__(self):
        super().__init__()
        self.train_losses = []
        self.val_losses = []
        self.train_accs = []
        self.val_accs = []
        self.train_precisions = []
        self.val_precisions = []
        self.train_recalls = []
        self.val_recalls = []
        self.train_f1s = []
        self.val_f1s = []

    def _record(self, trainer, split):
        """Append the current ``<split>_*`` metrics to the matching lists."""
        m = trainer.callback_metrics
        for key, list_suffix in self._TRACKED:
            getattr(self, f'{split}_{list_suffix}').append(m[f'{split}_{key}'].item())

    def on_train_epoch_end(self, trainer, pl_module):
        """Record the training metrics of the epoch that just finished."""
        self._record(trainer, 'train')

    def on_validation_epoch_end(self, trainer, pl_module):
        """Record the validation metrics, ignoring Lightning's sanity check."""
        # The pre-training sanity check also fires this hook; recording it
        # would leave the validation history one entry ahead of the train one.
        if trainer.sanity_checking:
            return
        self._record(trainer, 'val')

In [9]:
# One dict per training run. The training loop reads:
#   num_classes, lr, backbone, finetune, class_weights, ckpt_path
#       -> forwarded to WMMDClassifier
#   max_epochs -> Trainer epoch limit
#   batch_size -> WMMDDataModule batch size
# Commented-out entries are alternative backbones that are currently disabled.
model_configs = [
    # Wav2Vec 2.0 (fine-tune)
    {
        "num_classes": num_classes,
        "lr": 1e-6,
        "backbone": "facebook/wav2vec2-base",
        "finetune": True,
        "class_weights": class_weights,
        "max_epochs": 500,
        "ckpt_path": "",
        "batch_size": 2,
    },
    # # Tiny‐Wav2Vec 2.0 (fine-tune)
    # {
    #     "num_classes": num_classes,
    #     "lr": 1e-6,
    #     "backbone": "patrickvonplaten/tiny-wav2vec2-no-tokenizer",
    #     "finetune": True,
    #     "class_weights": class_weights,
    #     "max_epochs": 500,
    #     "ckpt_path": "",
    #     "batch_size": 2,
    # },
    # # CNN-Spectrogram (train from scratch)
    # {
    #     "num_classes": num_classes,
    #     "lr": 1e-3,
    #     "backbone": "cnn-spect",
    #     "finetune": True,
    #     "class_weights": class_weights,
    #     "max_epochs": 500,
    #     "ckpt_path": "",
    #     "batch_size": 2,
    # },
    # ViT-ImageNet (fine-tune)
    {
        "num_classes": num_classes,
        "lr": 1e-5,
        "backbone": "vit-imagenet",
        "finetune": True,
        "class_weights": class_weights,
        "max_epochs": 500,
        "ckpt_path": "",
        "batch_size": 2,
    },
    # MAE-AST (fine-tune)
    # {
    #     "num_classes": num_classes,
    #     "lr": 1e-6,
    #     "backbone": "mae-ast",
    #     "finetune": True,
    #     "class_weights": class_weights,
    #     "max_epochs": 500,
    #     "ckpt_path": "4Enc_1Dec-61epoch-0.103loss.pt",
    #     "batch_size": 2,
    # },
]


In [10]:
# training Loops
# For every entry of `model_configs`: train with early stopping, evaluate on
# the test split, save the model artifacts and write one PNG per tracked metric.

# Metric name -> (train list, validation list) attributes on MetricsLogger.
metrics_map = {
    'accuracy':   ('train_accs',      'val_accs'),
    'precision':  ('train_precisions','val_precisions'),
    'recall':     ('train_recalls',   'val_recalls'),
    'f1_score':   ('train_f1s',       'val_f1s'),
}

for cfg in model_configs:
    # Seed python/numpy/torch so a configuration gives the same result when
    # re-run; an optional "seed" key in the config overrides the default.
    pl.seed_everything(cfg.get("seed", 42), workers=True)

    dm = WMMDDataModule(
        dataset_dict=ds,
        backbone=cfg["backbone"],
        batch_size = cfg.get("batch_size", 2),
        num_workers=0
    )

    model = WMMDClassifier(
        num_classes=cfg['num_classes'], lr=cfg['lr'],
        backbone=cfg['backbone'], finetune=cfg['finetune'],
        class_weights=cfg['class_weights'],
        ckpt_path=cfg['ckpt_path']
    )
    metrics_cb = MetricsLogger()
    # callback for early stopping
    early_stopping = EarlyStopping(
        monitor='val_loss',
        mode='min',
        patience=20,
        min_delta=0.01,
        verbose=True
    )
    callbacks = [metrics_cb, early_stopping]

    trainer = pl.Trainer(
        max_epochs=cfg['max_epochs'],
        accelerator='gpu', devices=1,
        precision='16-mixed',
        accumulate_grad_batches=2,
        check_val_every_n_epoch=1,
        num_sanity_val_steps=0,
        enable_progress_bar=True,
        log_every_n_steps=1,
        callbacks=callbacks
    )

    trainer.fit(model, dm)
    test_res = trainer.test(model, dm)[0]

    model.test_results = test_res
    model.finish_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    # The logger appends one entry per finished training epoch, so its length
    # is the exact epoch count. `trainer.current_epoch` is already advanced
    # past the last epoch once fit() returns, so adding 1 over-counts.
    model.epochs_trained = len(metrics_cb.train_losses)
    model.save_model()

    for metric_name, (train_attr, val_attr) in metrics_map.items():
        train_vals = getattr(metrics_cb, train_attr)
        val_vals = getattr(metrics_cb, val_attr)
        pretty_name = metric_name.replace('_', ' ').title()

        # Each series gets its own x-axis so a length difference between the
        # train and validation histories cannot break plotting.
        fig, ax = plt.subplots()
        ax.plot(range(1, len(train_vals) + 1), train_vals, label=f'train_{metric_name}')
        ax.plot(range(1, len(val_vals) + 1), val_vals,   label=f'val_{metric_name}')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(pretty_name)
        ax.set_title(f"{pretty_name} over Epochs {model._last_timestamp}")
        ax.grid(True)
        ax.legend(loc='best')

        plot_file = os.path.join(model._last_save_dir, f"{model._last_timestamp}_{metric_name}.png")
        fig.savefig(plot_file)
        plt.close(fig)

    print(f"Completed {cfg['backbone']} ({'FT' if cfg['finetune'] else 'Frozen'}), artifacts in {model._last_save_dir}")




ValueError: Due to a serious vulnerability issue in `torch.load`, even with `weights_only=True`, we now require users to upgrade torch to at least v2.6 in order to use the function. This version restriction does not apply when loading files with safetensors.
See the vulnerability report here https://nvd.nist.gov/vuln/detail/CVE-2025-32434