In [7]:
# Standard library
import datetime
import json
import os
import random
from collections import Counter

os.chdir("../..")
print(os.getcwd())

# Numerical computing
import numpy as np

# PyTorch
import torch
import torch.nn as nn
import torchaudio
from torch.utils.data import DataLoader

# Hugging Face Transformers
from transformers import (
    HubertModel,
    Wav2Vec2FeatureExtractor,
    Wav2Vec2Model,
)

# Datasets
from datasets import load_dataset

# PyTorch Lightning
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping

# Metrics
from sklearn.utils.class_weight import compute_class_weight
from torchmetrics.classification import (
    MulticlassF1Score,
    MulticlassPrecision,
    MulticlassRecall,
)

# Visualization
import matplotlib.pyplot as plt

/home/incantator/Documents/mbari-mae


In [8]:
# Loading the dataset
data_dir = "data/watkins"
annotations_file_train = os.path.join(data_dir, "annotations.train.csv")
annotations_file_valid = os.path.join(data_dir, "annotations.valid.csv")
annotations_file_test = os.path.join(data_dir, "annotations.test.csv")

ds = load_dataset(
    "csv",
    data_files={"train": annotations_file_train,
                "validation": annotations_file_valid,
                "test": annotations_file_test},
)

for split_name in ["train", "validation", "test"]:
    split_dataset = ds[split_name]
    labels = split_dataset["label"]
    total = len(labels)
    counts = Counter(labels)

    print(f"{split_name.capitalize()} dataset: {total} examples, {len(counts)} classes")
    if "label" in split_dataset.features and hasattr(split_dataset.features["label"], "names"):
        class_names = split_dataset.features["label"].names
        for idx, name in enumerate(class_names):
            print(f"  {idx} ({name}): {counts.get(name, 0)}")
    else:
        for label, count in counts.items():
            print(f"  {label}: {count}")

Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

Train dataset: 1017 examples, 31 classes
  Clymene_Dolphin: 38
  Bottlenose_Dolphin: 15
  Spinner_Dolphin: 69
  Beluga,_White_Whale: 30
  Bearded_Seal: 22
  Minke_Whale: 10
  Humpback_Whale: 38
  Southern_Right_Whale: 15
  White-sided_Dolphin: 33
  Narwhal: 30
  White-beaked_Dolphin: 34
  Northern_Right_Whale: 32
  Frasers_Dolphin: 52
  Grampus,_Rissos_Dolphin: 40
  Harp_Seal: 28
  Atlantic_Spotted_Dolphin: 35
  Fin,_Finback_Whale: 30
  Ross_Seal: 30
  Rough-Toothed_Dolphin: 30
  Killer_Whale: 21
  Pantropical_Spotted_Dolphin: 40
  Short-Finned_Pacific_Pilot_Whale: 40
  Bowhead_Whale: 36
  False_Killer_Whale: 35
  Melon_Headed_Whale: 38
  Long-Finned_Pilot_Whale: 42
  Striped_Dolphin: 49
  Leopard_Seal: 6
  Walrus: 23
  Sperm_Whale: 45
  Common_Dolphin: 31
Validation dataset: 339 examples, 31 classes
  Clymene_Dolphin: 12
  Bottlenose_Dolphin: 5
  Spinner_Dolphin: 23
  Beluga,_White_Whale: 10
  Bearded_Seal: 7
  Minke_Whale: 4
  Humpback_Whale: 13
  Southern_Right_Whale: 5
  White-side

In [13]:
# Class weights calculation
train_labels = ds["train"]["label"]
unique_labels = sorted(set(train_labels))
label_to_int = {label: idx for idx, label in enumerate(unique_labels)}
y_train = [label_to_int[lbl] for lbl in train_labels]

class_weights = compute_class_weight(
    class_weight='balanced',
    classes=np.arange(len(unique_labels)),
    y=y_train
)

num_classes = len(class_weights)

In [14]:
# Model definition
class WMMDClassifier(pl.LightningModule):
    def __init__(
        self,
        num_classes: int,
        lr: float = 1e-4,
        backbone: str = "facebook/wav2vec2-base",
        ckpt_path: str = "",
        finetune: bool = False,
        class_weights=None,
    ):
        super().__init__()
        self.save_hyperparameters()

        if backbone == "facebook/wav2vec2-base":
            self.backbone     = Wav2Vec2Model.from_pretrained(backbone)
            self.embedding_dim = self.backbone.config.hidden_size

        else:
            raise ValueError(f"Unsupported backbone '{backbone}'")

        try:
            self.backbone.gradient_checkpointing_enable()
        except Exception:
            pass

        for param in self.backbone.parameters():
            param.requires_grad = finetune

        self.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(self.embedding_dim, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(1024, num_classes),
        )

        if class_weights is not None:
            cw = torch.tensor(class_weights, dtype=torch.float)
            self.criterion = nn.CrossEntropyLoss(weight=cw)
        else:
            self.criterion = nn.CrossEntropyLoss()

        metrics_kwargs = dict(num_classes=num_classes, average='macro')
        self.train_precision = MulticlassPrecision(**metrics_kwargs)
        self.train_recall = MulticlassRecall(**metrics_kwargs)
        self.train_f1 = MulticlassF1Score(**metrics_kwargs)
        self.val_precision = MulticlassPrecision(**metrics_kwargs)
        self.val_recall = MulticlassRecall(**metrics_kwargs)
        self.val_f1 = MulticlassF1Score(**metrics_kwargs)
        self.test_precision = MulticlassPrecision(**metrics_kwargs)
        self.test_recall = MulticlassRecall(**metrics_kwargs)
        self.test_f1 = MulticlassF1Score(**metrics_kwargs)

    def forward(self, x):
        bname = self.hparams.backbone.lower()
        if bname == "facebook/wav2vec2-base":
            out = self.backbone(x)
            hidden = out.last_hidden_state
        else:
            raise ValueError(f"Unsupported backbone in forward(): '{self.hparams.backbone}'")

        emb = hidden.mean(dim=1)
        return self.classifier(emb)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        preds = logits.argmax(dim=1)
        self.log_batch_metrics(loss, preds, y, prefix='train')
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        preds = logits.argmax(dim=1)
        self.log_batch_metrics(loss, preds, y, prefix='val')

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        preds = logits.argmax(dim=1)
        self.log_batch_metrics(loss, preds, y, prefix='test')

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.hparams.lr)

    def log_batch_metrics(self, loss, preds, targets, prefix):
        self.log(f'{prefix}_loss', loss, prog_bar=True, on_epoch=True)
        acc = (preds == targets).float().mean()
        self.log(f'{prefix}_acc', acc, prog_bar=True, on_epoch=True)
        precision = getattr(self, f'{prefix}_precision')(preds, targets)
        recall = getattr(self, f'{prefix}_recall')(preds, targets)
        f1 = getattr(self, f'{prefix}_f1')(preds, targets)
        self.log(f'{prefix}_precision', precision, on_epoch=True)
        self.log(f'{prefix}_recall', recall, on_epoch=True)
        self.log(f'{prefix}_f1', f1, on_epoch=True)

    # 2. Save the model state at the end of training
    def on_train_end(self):
        save_dir = getattr(self, 'save_dir', None)
        if save_dir:
            self.save_model(save_dir)

    def save_model(self):
        # New folder structure: /model/<model_name>_<balance_flag>/<timestamp>/
        base_dir = 'model'
        bn = self.hparams.backbone.replace('/', '_')
        # Determine balance flag based on class_weights
        cw = getattr(self.hparams, 'class_weights', None)
        balance_flag = 'imbalance' if cw is not None else 'balance'
        timestamp = getattr(self, 'finish_time', datetime.datetime.now().strftime("%Y%m%d_%H%M%S"))
        folder = os.path.join(base_dir, f"{bn}_{balance_flag}", timestamp)
        os.makedirs(folder, exist_ok=True)

        # 1) checkpoint
        ckpt_path = os.path.join(folder, f"{timestamp}.pt")
        payload = {
            'state_dict': self.state_dict(),
            'hparams': dict(self.hparams)
        }
        for attr in ('test_results', 'finish_time', 'epochs_trained'):
            if hasattr(self, attr):
                payload[attr] = getattr(self, attr)
        torch.save(payload, ckpt_path)

        # 2) human-readable stats
        stats_path = os.path.join(folder, f"{timestamp}.txt")
        raw_hparams = dict(self.hparams)
        serializable_hparams = {}
        for k, v in raw_hparams.items():
            if isinstance(v, np.ndarray):
                serializable_hparams[k] = v.tolist()
            elif isinstance(v, torch.Tensor):
                serializable_hparams[k] = v.cpu().item() if v.ndim == 0 else v.cpu().tolist()
            else:
                serializable_hparams[k] = v

        serializable_results = {}
        if hasattr(self, 'test_results'):
            for k, v in self.test_results.items():
                serializable_results[k] = v.cpu().item() if isinstance(v, torch.Tensor) else v

        with open(stats_path, 'w') as f:
            f.write(f"Model architecture:\n{self}\n\n")
            f.write("Hyperparameters:\n")
            f.write(json.dumps(serializable_hparams, indent=4))
            f.write("\n\n")
            if serializable_results:
                f.write("Test results:\n")
                f.write(json.dumps(serializable_results, indent=4))
                f.write("\n\n")
            if hasattr(self, 'epochs_trained'):
                f.write(f"Epochs trained: {self.epochs_trained}\n")

        # store for downstream use
        self._last_save_dir = folder
        self._last_timestamp = timestamp
        print(f"Artifacts saved to {folder}/")

    # 3. Load back the model for inference or continuation
    @classmethod
    def load_model(cls, load_dir: str, map_location=None):
        """
        Load a model checkpoint and hyperparameters from a directory.

        Returns:
            model (MammalClassifier): Loaded model
        """
        # Load hyperparameters
        hparams_path = os.path.join(load_dir, 'hparams.json')
        with open(hparams_path, 'r') as f:
            hparams = json.load(f)

        # Instantiate model
        model = cls(**hparams)
        # Load checkpoint
        ckpt_path = os.path.join(load_dir, f'{cls.__name__}.ckpt')
        state = torch.load(ckpt_path, map_location=map_location)
        model.load_state_dict(state['state_dict'])
        return model


In [15]:
def WMMD_Collate(batch):
    waveforms, labels = zip(*batch)
    max_len = max(w.shape[0] for w in waveforms)
    min_len = 400
    padded_len = max(max_len, min_len)

    padded_waveforms = []
    for waveform in waveforms:
        padding_needed = padded_len - waveform.shape[0]
        padded_waveform = torch.nn.functional.pad(waveform, (0, padding_needed))
        padded_waveforms.append(padded_waveform)

    padded = torch.stack(padded_waveforms, dim=0)
    labels = torch.tensor(labels, dtype=torch.long)
    return padded, labels

class WMMDSoundDataset(torch.utils.data.Dataset):
    def __init__(self, dataset, backbone: str, target_sr: int = 2000):
        """
        dataset: list of dicts with keys 'path' & 'label'
        backbone: 'facebook/wav2vec2-base' or 'mae-ast'
        target_sr: sampling rate (e.g. 2000)
        """
        self.dataset = dataset
        self.backbone = backbone.lower()
        self.target_sr = target_sr
        self.resampler_cache = {}

        if self.backbone == "facebook/wav2vec2-base":
            self.processor = Wav2Vec2FeatureExtractor.from_pretrained(
                "facebook/wav2vec2-base", return_attention_mask=False, sampling_rate=target_sr
            )
        elif self.backbone == "mae-ast":
            self.processor = None
        else:
            raise ValueError(f"Unsupported backbone '{backbone}'")

        labels = sorted({item['label'] for item in dataset})
        self.label_to_int = {lbl: i for i, lbl in enumerate(labels)}

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        audio_path = item["path"]
        waveform, orig_sr = torchaudio.load(audio_path)

        if orig_sr != self.target_sr:
            if orig_sr not in self.resampler_cache:
                self.resampler_cache[orig_sr] = torchaudio.transforms.Resample(orig_sr, self.target_sr)
            waveform = self.resampler_cache[orig_sr](waveform)

        waveform = waveform / (waveform.abs().max() + 1e-6)
        wav_1d = waveform.squeeze(0)  
        
        if self.backbone == "facebook/wav2vec2-base":
            arr = wav_1d.numpy()
            feats = self.processor(arr, sampling_rate=self.target_sr, return_tensors="pt")
            inp = feats.input_values.squeeze(0)
        else:
            raise ValueError(f"Unsupported backbone '{self.backbone}'")

        lbl = self.label_to_int[item['label']]
        return inp, lbl

class WMMDDataModule(pl.LightningDataModule):
    def __init__(
        self,
        dataset_dict,
        backbone: str,
        batch_size: int = 2,
        num_workers: int = 1
    ):
        super().__init__()
        self.dataset_dict = dataset_dict
        self.backbone = backbone   # ← store it here
        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self, stage=None):
        # pass backbone into each split
        self.train_ds = WMMDSoundDataset(self.dataset_dict["train"], backbone=self.backbone)
        self.val_ds   = WMMDSoundDataset(self.dataset_dict["validation"], backbone=self.backbone)
        self.test_ds  = WMMDSoundDataset(self.dataset_dict["test"], backbone=self.backbone)

    def train_dataloader(self):
        return DataLoader(
            self.train_ds,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,
            collate_fn=WMMD_Collate
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_ds,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
            collate_fn=WMMD_Collate
        )

    def test_dataloader(self):
        # Use the test split for testing
        return DataLoader(
            self.test_ds,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
            collate_fn=WMMD_Collate
        )

In [16]:
#callback for logging metrics
class MetricsLogger(pl.Callback):
    def __init__(self):
        self.train_losses = []
        self.val_losses = []
        self.train_accs = []
        self.val_accs = []
        self.train_precisions = []
        self.val_precisions = []
        self.train_recalls = []
        self.val_recalls = []
        self.train_f1s = []
        self.val_f1s = []

    def on_train_epoch_end(self, trainer, pl_module):
        m = trainer.callback_metrics
        self.train_losses.append(m['train_loss'].item())
        self.train_accs.append(m['train_acc'].item())
        self.train_precisions.append(m['train_precision'].item())
        self.train_recalls.append(m['train_recall'].item())
        self.train_f1s.append(m['train_f1'].item())

    def on_validation_epoch_end(self, trainer, pl_module):
        m = trainer.callback_metrics
        self.val_losses.append(m['val_loss'].item())
        self.val_accs.append(m['val_acc'].item())
        self.val_precisions.append(m['val_precision'].item())
        self.val_recalls.append(m['val_recall'].item())
        self.val_f1s.append(m['val_f1'].item())

# callback for early stopping
early_stopping = EarlyStopping(
    monitor='val_loss',
    mode='min',
    patience=10,
    min_delta=0.01,
    verbose=True
)

In [17]:
# Model configurations
model_configs = [
    {"num_classes": num_classes, "lr": 1e-3, "backbone": "facebook/wav2vec2-base",  "finetune": False, "class_weights": class_weights, "max_epochs": 3},
    # {"num_classes": num_classes, "lr": 1e-3, "backbone": "mae-ast",  "finetune": False, "class_weights": class_weights, "max_epochs": 3, "ckpt_path": "4Enc_1Dec-61epoch-0.103loss.pt"},
]

In [18]:
# Training Loops
for cfg in model_configs:
    dm = WMMDDataModule(
        dataset_dict=ds,
        backbone=cfg["backbone"],
        batch_size=2,
        num_workers=0
    )

    model = WMMDClassifier(
        num_classes=cfg['num_classes'], lr=cfg['lr'],
        backbone=cfg['backbone'], finetune=cfg['finetune'],
        class_weights=cfg['class_weights']
    )
    metrics_cb = MetricsLogger()
    callbacks = [metrics_cb, early_stopping]

    trainer = pl.Trainer(
        max_epochs=cfg['max_epochs'],
        accelerator='gpu', devices=1,
        precision='16-mixed', accumulate_grad_batches=2,
        check_val_every_n_epoch=1,
        num_sanity_val_steps=0,
        enable_progress_bar=True, log_every_n_steps=1,
        callbacks=callbacks
    )

    # Train and test
    trainer.fit(model, dm)
    test_res = trainer.test(model, dm)[0]

    # Attach metadata and save
    model.test_results = test_res
    model.finish_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    model.epochs_trained = trainer.current_epoch + 1
    model.save_model()

    # Plot train/val curves
    metrics_map = {
        'accuracy':   ('train_accs',      'val_accs'),
        'precision':  ('train_precisions','val_precisions'),
        'recall':     ('train_recalls',   'val_recalls'),
        'f1_score':   ('train_f1s',       'val_f1s'),
    }
    
    for metric_name, (train_attr, val_attr) in metrics_map.items():
        train_vals = getattr(metrics_cb, train_attr)
        val_vals = getattr(metrics_cb, val_attr)
        epochs = list(range(1, len(train_vals) + 1))

        plt.figure()
        plt.plot(epochs, train_vals, label=f'train_{metric_name}')
        plt.plot(epochs, val_vals,   label=f'val_{metric_name}')
        plt.xlabel('Epoch')
        plt.ylabel(metric_name.replace('_', ' ').title())
        plt.title(f"{metric_name.replace('_', ' ').title()} over Epochs {model._last_timestamp}")
        plt.grid(True)
        plt.legend(loc='best')

        plot_file = os.path.join(model._last_save_dir, f"{model._last_timestamp}_{metric_name}.png")
        plt.savefig(plot_file)
        plt.close()

    print(f"Completed {cfg['backbone']} ({'FT' if cfg['finetune'] else 'Frozen'}), artifacts in {model._last_save_dir}")


config.json: 0.00B [00:00, ?B/s]



pytorch_model.bin:   0%|          | 0.00/380M [00:00<?, ?B/s]

ValueError: Due to a serious vulnerability issue in `torch.load`, even with `weights_only=True`, we now require users to upgrade torch to at least v2.6 in order to use the function. This version restriction does not apply when loading files with safetensors.
See the vulnerability report here https://nvd.nist.gov/vuln/detail/CVE-2025-32434

model.safetensors:   0%|          | 0.00/380M [00:00<?, ?B/s]

In [19]:
torch.load("4Enc_1Dec-61epoch-0.103loss.pt")

  torch.load("4Enc_1Dec-61epoch-0.103loss.pt")


FileNotFoundError: [Errno 2] No such file or directory: '4Enc_1Dec-61epoch-0.103loss.pt'