In [1]:
# Standard library
import datetime
import json
import os
from collections import Counter
import random

# Numerical computing
import numpy as np

# PyTorch
import torch
import torch.nn as nn
import torchaudio
from torch.utils.data import DataLoader
import torchaudio.transforms as T
import torchaudio.functional as F

# Datasets
from datasets import load_dataset

# PyTorch Lightning
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.model_summary import ModelSummary

# Metrics
from sklearn.utils.class_weight import compute_class_weight
from torchmetrics.classification import (
    MulticlassF1Score,
    MulticlassPrecision,
    MulticlassRecall,
)

# Visualization
import matplotlib.pyplot as plt

# ==== BACKBONE ====
# Hugging Face Transformers
from transformers import (
    Wav2Vec2FeatureExtractor,
    Wav2Vec2Model,
    ViTFeatureExtractor, 
    ViTModel
)

# MAE-AST Library
from s3prl.nn.upstream import S3PRLUpstream

# Fix the length of the input audio to the same length
def _match_length_force(self, xs, target_max_len):
    xs_max_len = xs.size(1)
    if xs_max_len > target_max_len:
        xs = xs[:, :target_max_len, :]
    elif xs_max_len < target_max_len:
        pad_len = target_max_len - xs_max_len
        xs = torch.cat(
            (xs, xs[:, -1:, :].repeat(1, pad_len, 1)),
            dim=1
        )
    return xs

S3PRLUpstream._match_length = _match_length_force

  from .autonotebook import tqdm as notebook_tqdm
  torchaudio.set_audio_backend("sox_io")
ESPnet is not installed, cannot use espnet_hubert upstream


In [2]:
# loading the dataset
data_dir = "data/watkins"
annotations_file_train = os.path.join(data_dir, "annotations.train.csv")
annotations_file_valid = os.path.join(data_dir, "annotations.valid.csv")
annotations_file_test = os.path.join(data_dir, "annotations.test.csv")

ds = load_dataset(
    "csv",
    data_files={"train": annotations_file_train,
                "validation": annotations_file_valid,
                "test": annotations_file_test},
)

for split_name in ["train", "validation", "test"]:
    split_dataset = ds[split_name]
    labels = split_dataset["label"]
    total = len(labels)
    counts = Counter(labels)

    print(f"{split_name.capitalize()} dataset: {total} examples, {len(counts)} classes")
    if "label" in split_dataset.features and hasattr(split_dataset.features["label"], "names"):
        class_names = split_dataset.features["label"].names
        for idx, name in enumerate(class_names):
            print(f"  {idx} ({name}): {counts.get(name, 0)}")
    else:
        for label, count in counts.items():
            print(f"  {label}: {count}")

Train dataset: 1017 examples, 31 classes
  Clymene_Dolphin: 38
  Bottlenose_Dolphin: 15
  Spinner_Dolphin: 69
  Beluga,_White_Whale: 30
  Bearded_Seal: 22
  Minke_Whale: 10
  Humpback_Whale: 38
  Southern_Right_Whale: 15
  White-sided_Dolphin: 33
  Narwhal: 30
  White-beaked_Dolphin: 34
  Northern_Right_Whale: 32
  Frasers_Dolphin: 52
  Grampus,_Rissos_Dolphin: 40
  Harp_Seal: 28
  Atlantic_Spotted_Dolphin: 35
  Fin,_Finback_Whale: 30
  Ross_Seal: 30
  Rough-Toothed_Dolphin: 30
  Killer_Whale: 21
  Pantropical_Spotted_Dolphin: 40
  Short-Finned_Pacific_Pilot_Whale: 40
  Bowhead_Whale: 36
  False_Killer_Whale: 35
  Melon_Headed_Whale: 38
  Long-Finned_Pilot_Whale: 42
  Striped_Dolphin: 49
  Leopard_Seal: 6
  Walrus: 23
  Sperm_Whale: 45
  Common_Dolphin: 31
Validation dataset: 339 examples, 31 classes
  Clymene_Dolphin: 12
  Bottlenose_Dolphin: 5
  Spinner_Dolphin: 23
  Beluga,_White_Whale: 10
  Bearded_Seal: 7
  Minke_Whale: 4
  Humpback_Whale: 13
  Southern_Right_Whale: 5
  White-side

In [3]:
# class weights calculation
train_labels = ds["train"]["label"]
unique_labels = sorted(set(train_labels))
label_to_int = {label: idx for idx, label in enumerate(unique_labels)}
y_train = [label_to_int[lbl] for lbl in train_labels]

class_weights = compute_class_weight(
    class_weight='balanced',
    classes=np.arange(len(unique_labels)),
    y=y_train
)

num_classes = len(class_weights)

In [4]:
# model definition
class WMMDClassifier(pl.LightningModule):
    def __init__(
        self,
        num_classes: int,
        lr: float = 1e-3,
        backbone: str = "facebook/wav2vec2-base",
        ckpt_path: str = "",
        finetune: bool = False,
        class_weights=None,
    ):
        super().__init__()
        self.save_hyperparameters()

        # ========== WAV2VEC2 ==========
        if backbone == "facebook/wav2vec2-base":
            self.backbone     = Wav2Vec2Model.from_pretrained(backbone)
            self.embedding_dim = self.backbone.config.hidden_size
        
        # ========== TINY WAV2VEC2 ==========
        elif backbone == "patrickvonplaten/tiny-wav2vec2-no-tokenizer":
            self.backbone     = Wav2Vec2Model.from_pretrained(backbone)
            self.embedding_dim = self.backbone.config.hidden_size
        
        # ========== MAE-AST ==========
        elif backbone == "mae-ast":
            up_kwargs = {"name": "mae_ast_patch"}
            s3 = S3PRLUpstream(**up_kwargs)

            enc = s3.upstream.model.encoder
            enc.layers = nn.ModuleList(list(enc.layers)[:4])
            s3.upstream.model.dec_sine_pos_embed = None
            s3.upstream.model.decoder = None
            s3.upstream.model.final_proj_reconstruction = None
            s3.upstream.model.final_proj_classification  = None

            new_n = len(enc.layers)
            s3._num_layers       = new_n
            s3._hidden_sizes     = s3._hidden_sizes[:new_n]
            s3._downsample_rates = s3._downsample_rates[:new_n]

            self.backbone      = s3
            self.embedding_dim = s3.hidden_sizes[-1]

            # Load the checkpoint for mae ast
            if ckpt_path:
                self.load_mae_ckpt(ckpt_path)
            
            mel_transform = T.MelSpectrogram(
                sample_rate=2000, 
                n_fft=1024, 
                win_length=512,
                hop_length=20, 
                n_mels=128,
            )

            class DBWithDeltas(nn.Module):
                def __init__(self):
                    super().__init__()

                def forward(self, spec):
                    spec_db = F.amplitude_to_DB(
                        spec,
                        multiplier=10.0,
                        amin=1e-10,
                        db_multiplier=0
                    )
                    t = spec_db.transpose(0, 1)
                    d1 = F.compute_deltas(t.transpose(0,1))
                    d2 = F.compute_deltas(d1)
                    return torch.cat([
                        t,
                        d1.transpose(0,1),
                        d2.transpose(0,1)
                    ], dim=1)
            
            s3.upstream.model.fbank   = mel_transform
            s3.upstream.model.db_norm = DBWithDeltas()
        
        # ========== CNN-SPECTROGRAM ==========
        elif backbone == "cnn-spect":
            self.mel_spec = torchaudio.transforms.MelSpectrogram(
                sample_rate=2000,
                n_fft=1024,
                hop_length=20,
                win_length=512,
                n_mels=128
            )
            self.to_db = torchaudio.transforms.AmplitudeToDB(top_db=80)

            self.backbone = nn.Sequential(
                nn.Conv2d(3, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(),
                nn.Conv2d(32,32,3,padding=1), nn.BatchNorm2d(32), nn.ReLU(),
                nn.MaxPool2d(2),

                nn.Conv2d(32,64,3,padding=1), nn.BatchNorm2d(64), nn.ReLU(),
                nn.Conv2d(64,64,3,padding=1), nn.BatchNorm2d(64), nn.ReLU(),
                nn.MaxPool2d(2),

                nn.Conv2d(64,128,3,padding=1), nn.BatchNorm2d(128), nn.ReLU(),
                nn.Conv2d(128,128,3,padding=1), nn.BatchNorm2d(128), nn.ReLU(),
                nn.MaxPool2d(2),

                nn.Conv2d(128,256,3,padding=1), nn.BatchNorm2d(256), nn.ReLU(),
                nn.Conv2d(256,256,3,padding=1), nn.BatchNorm2d(256), nn.ReLU(),
                nn.MaxPool2d(2),

                nn.Conv2d(256,512,3,padding=1), nn.BatchNorm2d(512), nn.ReLU(),
                nn.Conv2d(512,512,3,padding=1), nn.BatchNorm2d(512), nn.ReLU(),
                nn.AdaptiveAvgPool2d((1,1)),
                nn.Flatten()
            )
            self.embedding_dim = 512
        
        # ========== VIT-IMAGENET ==========
        elif backbone == "vit-imagenet":
            self.mel_spec = torchaudio.transforms.MelSpectrogram(
                sample_rate=2000,
                n_fft=1024,
                hop_length=20,
                win_length=512,
                n_mels=128
            )
            self.to_db = torchaudio.transforms.AmplitudeToDB(top_db=80)

            self.feature_extractor = ViTFeatureExtractor.from_pretrained(
                "google/vit-base-patch16-224-in21k"
            )

            self.backbone = ViTModel.from_pretrained(
                "google/vit-base-patch16-224-in21k",
                add_pooling_layer=False
            )
            self.embedding_dim = self.backbone.config.hidden_size

        else:
            raise ValueError(f"Unsupported backbone '{backbone}'")

        try:
            self.backbone.gradient_checkpointing_enable()
        except Exception:
            pass

        for param in self.backbone.parameters():
            param.requires_grad = finetune

        self.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(self.embedding_dim, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(1024, num_classes),
        )

        if class_weights is not None:
            cw = torch.tensor(class_weights, dtype=torch.float)
            self.criterion = nn.CrossEntropyLoss(weight=cw)
        else:
            self.criterion = nn.CrossEntropyLoss()

        metrics_kwargs = dict(num_classes=num_classes, average='macro')
        self.train_precision = MulticlassPrecision(**metrics_kwargs)
        self.train_recall = MulticlassRecall(**metrics_kwargs)
        self.train_f1 = MulticlassF1Score(**metrics_kwargs)
        self.val_precision = MulticlassPrecision(**metrics_kwargs)
        self.val_recall = MulticlassRecall(**metrics_kwargs)
        self.val_f1 = MulticlassF1Score(**metrics_kwargs)
        self.test_precision = MulticlassPrecision(**metrics_kwargs)
        self.test_recall = MulticlassRecall(**metrics_kwargs)
        self.test_f1 = MulticlassF1Score(**metrics_kwargs)

    def forward(self, x):
        bname = self.hparams.backbone.lower()
        if bname in {"facebook/wav2vec2-base",
             "patrickvonplaten/tiny-wav2vec2-no-tokenizer"}:
            hidden = self.backbone(x).last_hidden_state

        elif bname == "mae-ast":
            if x.dim() == 3:
                x = x.squeeze(-1)
            wav_lens = torch.full((x.size(0),), x.size(1),
                                dtype=torch.long, device=x.device)
            hidden = self.backbone(x, wav_lens)[0][-1]

        elif bname == "cnn-spect":
            spec = self._safe_logmelspec(x, self.mel_spec, self.to_db)  # (B,1,F,L) in [0,1]
            hidden = self.backbone(spec)

        elif bname == "vit-imagenet":
            spec = self._safe_logmelspec(x, self.mel_spec, self.to_db)  # (B,1,F,L) in [0,1]

            img = torch.nn.functional.interpolate(
                spec, size=(224, 224), mode="bilinear", align_corners=False
            )

            pixel_values = self.feature_extractor(images=img, return_tensors="pt").pixel_values.to(img.device)
            outputs = self.backbone(pixel_values=pixel_values)
            hidden = outputs.last_hidden_state.mean(dim=1)

        else:
            raise ValueError(f"Unsupported backbone '{self.hparams.backbone}'")

        emb = hidden if hidden.dim() == 2 else hidden.mean(dim=1)
        return self.classifier(emb)


    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        preds = logits.argmax(dim=1)
        self.log_batch_metrics(loss, preds, y, prefix='train')
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        preds = logits.argmax(dim=1)
        self.log_batch_metrics(loss, preds, y, prefix='val')

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        preds = logits.argmax(dim=1)
        self.log_batch_metrics(loss, preds, y, prefix='test')

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.hparams.lr)

    def log_batch_metrics(self, loss, preds, targets, prefix):
        self.log(f'{prefix}_loss', loss, prog_bar=True, on_epoch=True)
        acc = (preds == targets).float().mean()
        self.log(f'{prefix}_acc', acc, prog_bar=True, on_epoch=True)
        precision = getattr(self, f'{prefix}_precision')(preds, targets)
        recall = getattr(self, f'{prefix}_recall')(preds, targets)
        f1 = getattr(self, f'{prefix}_f1')(preds, targets)
        self.log(f'{prefix}_precision', precision, on_epoch=True)
        self.log(f'{prefix}_recall', recall, on_epoch=True)
        self.log(f'{prefix}_f1', f1, on_epoch=True)

    def on_train_end(self):
        save_dir = getattr(self, 'save_dir', None)
        if save_dir:
            self.save_model(save_dir)

    def save_model(self):
        base_dir = 'model'
        bn = self.hparams.backbone.replace('/', '_')
        cw = getattr(self.hparams, 'class_weights', None)
        balance_flag = 'imbalance' if cw is not None else 'balance'
        timestamp = getattr(self, 'finish_time', datetime.datetime.now().strftime("%Y%m%d_%H%M%S"))
        folder = os.path.join(base_dir, f"{bn}_{balance_flag}", timestamp)
        os.makedirs(folder, exist_ok=True)

        ckpt_path = os.path.join(folder, f"{timestamp}.pt")
        payload = {
            'state_dict': self.state_dict(),
            'hparams': dict(self.hparams)
        }
        for attr in ('test_results', 'finish_time', 'epochs_trained'):
            if hasattr(self, attr):
                payload[attr] = getattr(self, attr)
        torch.save(payload, ckpt_path)

        stats_path = os.path.join(folder, f"{timestamp}.txt")
        raw_hparams = dict(self.hparams)
        serializable_hparams = {}
        for k, v in raw_hparams.items():
            if isinstance(v, np.ndarray):
                serializable_hparams[k] = v.tolist()
            elif isinstance(v, torch.Tensor):
                serializable_hparams[k] = v.cpu().item() if v.ndim == 0 else v.cpu().tolist()
            else:
                serializable_hparams[k] = v

        serializable_results = {}
        if hasattr(self, 'test_results'):
            for k, v in self.test_results.items():
                serializable_results[k] = v.cpu().item() if isinstance(v, torch.Tensor) else v

        with open(stats_path, 'w') as f:
            f.write(f"Model architecture:\n{self}\n\n")
            f.write("Hyperparameters:\n")
            f.write(json.dumps(serializable_hparams, indent=4))
            f.write("\n\n")
            if serializable_results:
                f.write("Test results:\n")
                f.write(json.dumps(serializable_results, indent=4))
                f.write("\n\n")
            if hasattr(self, 'epochs_trained'):
                f.write(f"Epochs trained: {self.epochs_trained}\n")

        self._last_save_dir = folder
        self._last_timestamp = timestamp
        print(f"Artifacts saved to {folder}/")
    
    def load_mae_ckpt(self, ckpt_source: str):
        """
        Load MAE-AST checkpoint into the truncated upstream model
        and verify exact weight equality for each loaded parameter.
        """
        loaded = torch.load(ckpt_source)
        state_dict = loaded.get('model', loaded)

        up = self.backbone.upstream.model
        up_state = up.state_dict()

        to_load = {k: v for k, v in state_dict.items()
                   if k in up_state and v.shape == up_state[k].shape}

        missing, unexpected = up.load_state_dict(to_load, strict=False)

        for k, v in to_load.items():
            if not torch.equal(up_state[k], v):
                raise RuntimeError(f"Weight mismatch at '{k}' after loading checkpoint")

        print(f"Successfully loaded {len(to_load)} parameters; missing: {len(missing)}, unexpected: {len(unexpected)}")
    
    @staticmethod
    def _safe_logmelspec(waveform, mel_spec, to_db, db_range=80.0):
        if waveform.dim() == 2:
            waveform = waveform.unsqueeze(1)
        spec = mel_spec(waveform) + 1e-10
        spec_db = to_db(spec)
        spec_db = torch.nan_to_num(spec_db, neginf=-db_range)
        t = spec_db
        d1 = F.compute_deltas(t.squeeze(1)).unsqueeze(1)
        d2 = F.compute_deltas(d1.squeeze(1)).unsqueeze(1)
        spec3 = torch.cat([t, d1, d2], dim=1)
        spec3 = ((spec3 + db_range)/db_range).clamp(0.0, 1.0)
        return spec3

    @classmethod
    def load_model(cls, load_dir: str, map_location=None):
        """
        Load a model checkpoint and hyperparameters from a directory.

        Returns:
            model (MammalClassifier): Loaded model
        """
        hparams_path = os.path.join(load_dir, 'hparams.json')
        with open(hparams_path, 'r') as f:
            hparams = json.load(f)

        model = cls(**hparams)
        ckpt_path = os.path.join(load_dir, f'{cls.__name__}.ckpt')
        state = torch.load(ckpt_path, map_location=map_location)
        model.load_state_dict(state['state_dict'])
        return model

In [5]:
# sanity check
model = WMMDClassifier(
        num_classes=31, lr=1e-3,
        backbone="cnn-spect", finetune=True,
        class_weights=class_weights,
        ckpt_path=""
    )

print(ModelSummary(model, max_depth=1))

   | Name            | Type                | Params | Mode 
-----------------------------------------------------------------
0  | mel_spec        | MelSpectrogram      | 0      | train
1  | to_db           | AmplitudeToDB       | 0      | train
2  | backbone        | Sequential          | 4.7 M  | train
3  | classifier      | Sequential          | 557 K  | train
4  | criterion       | CrossEntropyLoss    | 0      | train
5  | train_precision | MulticlassPrecision | 0      | train
6  | train_recall    | MulticlassRecall    | 0      | train
7  | train_f1        | MulticlassF1Score   | 0      | train
8  | val_precision   | MulticlassPrecision | 0      | train
9  | val_recall      | MulticlassRecall    | 0      | train
10 | val_f1          | MulticlassF1Score   | 0      | train
11 | test_precision  | MulticlassPrecision | 0      | train
12 | test_recall     | MulticlassRecall    | 0      | train
13 | test_f1         | MulticlassF1Score   | 0      | train
----------------------------------

In [6]:
def WMMD_Collate(batch):
    waveforms, labels = zip(*batch)

    lengths    = [w.shape[0] for w in waveforms]
    raw_max    = max(lengths)

    min_len    = 5_000
    max_len    = 25_000
    target_len = min(max(raw_max, min_len), max_len)

    padded_waveforms = []
    for w in waveforms:
        L = w.shape[0]

        if L > target_len:
            start = random.randint(0, L - target_len)
            w2    = w[start : start + target_len]

        elif L < target_len:
            pad_amt = target_len - L
            w2      = torch.nn.functional.pad(w, (0, pad_amt))

        else:
            w2 = w

        padded_waveforms.append(w2)

    batch_waveforms = torch.stack(padded_waveforms, dim=0)
    batch_labels    = torch.tensor(labels, dtype=torch.long)
    return batch_waveforms, batch_labels

class WMMDSoundDataset(torch.utils.data.Dataset):
    def __init__(self, dataset, backbone: str, target_sr: int = 2000):
        """
        dataset: list of dicts with keys 'path' & 'label'
        backbone: 'facebook/wav2vec2-base' or 'mae-ast'
        target_sr: sampling rate (e.g. 2000)
        """
        self.dataset = dataset
        self.backbone = backbone
        self.target_sr = target_sr
        self.resampler_cache = {}

        if self.backbone in {"facebook/wav2vec2-base", "patrickvonplaten/tiny-wav2vec2-no-tokenizer"}:
            self.processor = Wav2Vec2FeatureExtractor.from_pretrained(
                self.backbone, return_attention_mask=False, sampling_rate=self.target_sr
            )
        elif self.backbone in {"mae-ast", "cnn-spect", "vit-imagenet"}:
            self.processor = None
        else:
            raise ValueError(f"Unsupported backbone '{self.backbone}'")

        labels = sorted({item['label'] for item in dataset})
        self.label_to_int = {lbl: i for i, lbl in enumerate(labels)}

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        audio_path = item["path"]
        waveform, orig_sr = torchaudio.load(audio_path)

        if orig_sr != self.target_sr:
            if orig_sr not in self.resampler_cache:
                self.resampler_cache[orig_sr] = torchaudio.transforms.Resample(orig_sr, self.target_sr)
            waveform = self.resampler_cache[orig_sr](waveform)

        waveform = waveform / (waveform.abs().max() + 1e-6)
        wav_1d = waveform.squeeze(0)  
        
        if self.backbone in {"facebook/wav2vec2-base", "patrickvonplaten/tiny-wav2vec2-no-tokenizer"}:
            arr = wav_1d.numpy()
            feats = self.processor(arr, sampling_rate=self.target_sr, return_tensors="pt")
            inp = feats.input_values.squeeze(0)
        elif self.backbone in {"cnn-spect", "vit-imagenet", "mae-ast"}:
            inp = wav_1d            
        else:
            raise ValueError(f"Unsupported backbone '{self.backbone}'")

        lbl = self.label_to_int[item['label']]
        return inp, lbl

class WMMDDataModule(pl.LightningDataModule):
    def __init__(
        self,
        dataset_dict,
        backbone: str,
        batch_size: int = 2,
        num_workers: int = 1
    ):
        super().__init__()
        self.dataset_dict = dataset_dict
        self.backbone = backbone
        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self, stage=None):
        self.train_ds = WMMDSoundDataset(self.dataset_dict["train"], backbone=self.backbone)
        self.val_ds   = WMMDSoundDataset(self.dataset_dict["validation"], backbone=self.backbone)
        self.test_ds  = WMMDSoundDataset(self.dataset_dict["test"], backbone=self.backbone)

    def train_dataloader(self):
        return DataLoader(
            self.train_ds,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,
            collate_fn=WMMD_Collate
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_ds,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
            collate_fn=WMMD_Collate
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_ds,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
            collate_fn=WMMD_Collate
        )

In [7]:
# callback for logging metrics
class MetricsLogger(pl.Callback):
    def __init__(self):
        self.train_losses = []
        self.val_losses = []
        self.train_accs = []
        self.val_accs = []
        self.train_precisions = []
        self.val_precisions = []
        self.train_recalls = []
        self.val_recalls = []
        self.train_f1s = []
        self.val_f1s = []

    def on_train_epoch_end(self, trainer, pl_module):
        m = trainer.callback_metrics
        self.train_losses.append(m['train_loss'].item())
        self.train_accs.append(m['train_acc'].item())
        self.train_precisions.append(m['train_precision'].item())
        self.train_recalls.append(m['train_recall'].item())
        self.train_f1s.append(m['train_f1'].item())

    def on_validation_epoch_end(self, trainer, pl_module):
        m = trainer.callback_metrics
        self.val_losses.append(m['val_loss'].item())
        self.val_accs.append(m['val_acc'].item())
        self.val_precisions.append(m['val_precision'].item())
        self.val_recalls.append(m['val_recall'].item())
        self.val_f1s.append(m['val_f1'].item())

In [15]:
model_configs = [
    # Wav2Vec 2.0 (fine-tune)
    {
        "num_classes": num_classes,
        "lr": 1e-6,
        "backbone": "facebook/wav2vec2-base",
        "finetune": True,
        "class_weights": class_weights,
        "max_epochs": 500,
        "ckpt_path": "",
        "batch_size": 2,
    },
    # # Tiny‐Wav2Vec 2.0 (fine-tune)
    # {
    #     "num_classes": num_classes,
    #     "lr": 1e-6,
    #     "backbone": "patrickvonplaten/tiny-wav2vec2-no-tokenizer",
    #     "finetune": True,
    #     "class_weights": class_weights,
    #     "max_epochs": 500,
    #     "ckpt_path": "",
    #     "batch_size": 2,
    # },
    # # CNN-Spectrogram (train from scratch)
    # {
    #     "num_classes": num_classes,
    #     "lr": 1e-3,
    #     "backbone": "cnn-spect",
    #     "finetune": True,
    #     "class_weights": class_weights,
    #     "max_epochs": 500,
    #     "ckpt_path": "",
    #     "batch_size": 2,
    # },
    # ViT-ImageNet (fine-tune)
    {
        "num_classes": num_classes,
        "lr": 1e-5,
        "backbone": "vit-imagenet",
        "finetune": True,
        "class_weights": class_weights,
        "max_epochs": 500,
        "ckpt_path": "",
        "batch_size": 2,
    },
    # MAE-AST (fine-tune)
    # {
    #     "num_classes": num_classes,
    #     "lr": 1e-6,
    #     "backbone": "mae-ast",
    #     "finetune": True,
    #     "class_weights": class_weights,
    #     "max_epochs": 500,
    #     "ckpt_path": "4Enc_1Dec-61epoch-0.103loss.pt",
    #     "batch_size": 2,
    # },
]


In [16]:
# training Loops
for cfg in model_configs:
    dm = WMMDDataModule(
        dataset_dict=ds,
        backbone=cfg["backbone"],
        batch_size = cfg.get("batch_size", 2),
        num_workers=0
    )

    model = WMMDClassifier(
        num_classes=cfg['num_classes'], lr=cfg['lr'],
        backbone=cfg['backbone'], finetune=cfg['finetune'],
        class_weights=cfg['class_weights'], 
        ckpt_path=cfg['ckpt_path']
    )
    metrics_cb = MetricsLogger()
    # callback for early stopping
    early_stopping = EarlyStopping(
        monitor='val_loss',
        mode='min',
        patience=20,
        min_delta=0.01,
        verbose=True
    )
    callbacks = [metrics_cb, early_stopping]

    trainer = pl.Trainer(
        max_epochs=cfg['max_epochs'],
        accelerator='gpu', devices=1,
        precision='16-mixed', 
        accumulate_grad_batches=2,
        check_val_every_n_epoch=1,
        num_sanity_val_steps=0,
        enable_progress_bar=True,
        log_every_n_steps=1,
        callbacks=callbacks
    )

    trainer.fit(model, dm)
    test_res = trainer.test(model, dm)[0]

    model.test_results = test_res
    model.finish_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    model.epochs_trained = trainer.current_epoch + 1
    model.save_model()

    metrics_map = {
        'accuracy':   ('train_accs',      'val_accs'),
        'precision':  ('train_precisions','val_precisions'),
        'recall':     ('train_recalls',   'val_recalls'),
        'f1_score':   ('train_f1s',       'val_f1s'),
    }
    
    for metric_name, (train_attr, val_attr) in metrics_map.items():
        train_vals = getattr(metrics_cb, train_attr)
        val_vals = getattr(metrics_cb, val_attr)
        epochs = list(range(1, len(train_vals) + 1))

        plt.figure()
        plt.plot(epochs, train_vals, label=f'train_{metric_name}')
        plt.plot(epochs, val_vals,   label=f'val_{metric_name}')
        plt.xlabel('Epoch')
        plt.ylabel(metric_name.replace('_', ' ').title())
        plt.title(f"{metric_name.replace('_', ' ').title()} over Epochs {model._last_timestamp}")
        plt.grid(True)
        plt.legend(loc='best')

        plot_file = os.path.join(model._last_save_dir, f"{model._last_timestamp}_{metric_name}.png")
        plt.savefig(plot_file)
        plt.close()

    print(f"Completed {cfg['backbone']} ({'FT' if cfg['finetune'] else 'Frozen'}), artifacts in {model._last_save_dir}")


Using 16bit Automatic Mixed Precision (AMP)
💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name            | Type                | Params | Mode 
-----------------------------------------------------------------
0  | backbone        | Wav2Vec2Model       | 94.4 M | eval 
1  | classifier      | Sequential          | 819 K  | train
2  | criterion       | CrossEntropyLoss    | 0      | train
3  | train_precision | MulticlassPrecision | 0      | train
4  | train_recall    | MulticlassRecall    | 0      | train
5  | train_f1        | MulticlassF1Score   | 0      | train
6  | val_precision   | MulticlassPrecision | 0      | train
7  | val_recall      | MulticlassRecall    | 0

Epoch 0: 100%|██████████| 509/509 [01:11<00:00,  7.15it/s, v_num=62, train_loss_step=3.460, train_acc_step=0.000, val_loss=3.430, val_acc=0.0472, train_loss_epoch=3.440, train_acc_epoch=0.0472]

Metric val_loss improved. New best score: 3.429


Epoch 1: 100%|██████████| 509/509 [01:14<00:00,  6.84it/s, v_num=62, train_loss_step=3.310, train_acc_step=0.000, val_loss=3.410, val_acc=0.0944, train_loss_epoch=3.420, train_acc_epoch=0.060] 

Metric val_loss improved by 0.014 >= min_delta = 0.01. New best score: 3.414


Epoch 2: 100%|██████████| 509/509 [01:15<00:00,  6.75it/s, v_num=62, train_loss_step=3.450, train_acc_step=0.000, val_loss=3.400, val_acc=0.124, train_loss_epoch=3.410, train_acc_epoch=0.0678]

Metric val_loss improved by 0.010 >= min_delta = 0.01. New best score: 3.404


Epoch 3: 100%|██████████| 509/509 [01:07<00:00,  7.50it/s, v_num=62, train_loss_step=3.300, train_acc_step=0.000, val_loss=3.380, val_acc=0.159, train_loss_epoch=3.390, train_acc_epoch=0.107] 

Metric val_loss improved by 0.026 >= min_delta = 0.01. New best score: 3.378


Epoch 4: 100%|██████████| 509/509 [00:46<00:00, 10.93it/s, v_num=62, train_loss_step=3.520, train_acc_step=0.000, val_loss=3.360, val_acc=0.142, train_loss_epoch=3.380, train_acc_epoch=0.116]

Metric val_loss improved by 0.015 >= min_delta = 0.01. New best score: 3.362


Epoch 5: 100%|██████████| 509/509 [01:17<00:00,  6.55it/s, v_num=62, train_loss_step=3.210, train_acc_step=0.000, val_loss=3.350, val_acc=0.162, train_loss_epoch=3.350, train_acc_epoch=0.128]

Metric val_loss improved by 0.013 >= min_delta = 0.01. New best score: 3.349


Epoch 6: 100%|██████████| 509/509 [01:27<00:00,  5.79it/s, v_num=62, train_loss_step=3.230, train_acc_step=0.000, val_loss=3.330, val_acc=0.171, train_loss_epoch=3.340, train_acc_epoch=0.124]

Metric val_loss improved by 0.023 >= min_delta = 0.01. New best score: 3.326


Epoch 7: 100%|██████████| 509/509 [01:17<00:00,  6.61it/s, v_num=62, train_loss_step=3.420, train_acc_step=0.000, val_loss=3.300, val_acc=0.180, train_loss_epoch=3.310, train_acc_epoch=0.173]

Metric val_loss improved by 0.024 >= min_delta = 0.01. New best score: 3.303


Epoch 8: 100%|██████████| 509/509 [00:48<00:00, 10.58it/s, v_num=62, train_loss_step=3.450, train_acc_step=0.000, val_loss=3.280, val_acc=0.201, train_loss_epoch=3.280, train_acc_epoch=0.172]

Metric val_loss improved by 0.025 >= min_delta = 0.01. New best score: 3.277


Epoch 9: 100%|██████████| 509/509 [00:54<00:00,  9.36it/s, v_num=62, train_loss_step=3.380, train_acc_step=0.000, val_loss=3.260, val_acc=0.218, train_loss_epoch=3.260, train_acc_epoch=0.179]

Metric val_loss improved by 0.020 >= min_delta = 0.01. New best score: 3.257


Epoch 11: 100%|██████████| 509/509 [00:51<00:00,  9.83it/s, v_num=62, train_loss_step=2.800, train_acc_step=1.000, val_loss=3.210, val_acc=0.218, train_loss_epoch=3.210, train_acc_epoch=0.231]

Metric val_loss improved by 0.049 >= min_delta = 0.01. New best score: 3.208


Epoch 12: 100%|██████████| 509/509 [00:49<00:00, 10.25it/s, v_num=62, train_loss_step=3.160, train_acc_step=0.000, val_loss=3.190, val_acc=0.251, train_loss_epoch=3.180, train_acc_epoch=0.249]

Metric val_loss improved by 0.020 >= min_delta = 0.01. New best score: 3.188


Epoch 13: 100%|██████████| 509/509 [00:48<00:00, 10.50it/s, v_num=62, train_loss_step=2.850, train_acc_step=1.000, val_loss=3.160, val_acc=0.268, train_loss_epoch=3.170, train_acc_epoch=0.237]

Metric val_loss improved by 0.026 >= min_delta = 0.01. New best score: 3.162


Epoch 14: 100%|██████████| 509/509 [00:47<00:00, 10.62it/s, v_num=62, train_loss_step=3.220, train_acc_step=0.000, val_loss=3.150, val_acc=0.260, train_loss_epoch=3.140, train_acc_epoch=0.266]

Metric val_loss improved by 0.013 >= min_delta = 0.01. New best score: 3.149


Epoch 15: 100%|██████████| 509/509 [00:50<00:00, 10.07it/s, v_num=62, train_loss_step=2.930, train_acc_step=1.000, val_loss=3.130, val_acc=0.271, train_loss_epoch=3.110, train_acc_epoch=0.296]

Metric val_loss improved by 0.021 >= min_delta = 0.01. New best score: 3.128


Epoch 16: 100%|██████████| 509/509 [00:48<00:00, 10.41it/s, v_num=62, train_loss_step=3.210, train_acc_step=0.000, val_loss=3.100, val_acc=0.295, train_loss_epoch=3.080, train_acc_epoch=0.293]

Metric val_loss improved by 0.029 >= min_delta = 0.01. New best score: 3.098


Epoch 17: 100%|██████████| 509/509 [00:51<00:00,  9.87it/s, v_num=62, train_loss_step=3.200, train_acc_step=0.000, val_loss=3.080, val_acc=0.313, train_loss_epoch=3.050, train_acc_epoch=0.316]

Metric val_loss improved by 0.021 >= min_delta = 0.01. New best score: 3.078


Epoch 18: 100%|██████████| 509/509 [00:50<00:00, 10.12it/s, v_num=62, train_loss_step=3.380, train_acc_step=0.000, val_loss=3.060, val_acc=0.263, train_loss_epoch=3.020, train_acc_epoch=0.320]

Metric val_loss improved by 0.020 >= min_delta = 0.01. New best score: 3.058


Epoch 19: 100%|██████████| 509/509 [00:50<00:00, 10.00it/s, v_num=62, train_loss_step=2.690, train_acc_step=1.000, val_loss=3.030, val_acc=0.324, train_loss_epoch=3.000, train_acc_epoch=0.331]

Metric val_loss improved by 0.025 >= min_delta = 0.01. New best score: 3.033


Epoch 20: 100%|██████████| 509/509 [01:03<00:00,  8.00it/s, v_num=62, train_loss_step=3.210, train_acc_step=0.000, val_loss=3.010, val_acc=0.292, train_loss_epoch=2.960, train_acc_epoch=0.343]

Metric val_loss improved by 0.025 >= min_delta = 0.01. New best score: 3.008


Epoch 21: 100%|██████████| 509/509 [00:55<00:00,  9.11it/s, v_num=62, train_loss_step=2.690, train_acc_step=0.000, val_loss=2.990, val_acc=0.351, train_loss_epoch=2.940, train_acc_epoch=0.344]

Metric val_loss improved by 0.015 >= min_delta = 0.01. New best score: 2.993


Epoch 22: 100%|██████████| 509/509 [00:54<00:00,  9.32it/s, v_num=62, train_loss_step=2.220, train_acc_step=1.000, val_loss=2.960, val_acc=0.339, train_loss_epoch=2.910, train_acc_epoch=0.357]

Metric val_loss improved by 0.036 >= min_delta = 0.01. New best score: 2.956


Epoch 23: 100%|██████████| 509/509 [00:51<00:00,  9.97it/s, v_num=62, train_loss_step=2.450, train_acc_step=1.000, val_loss=2.940, val_acc=0.322, train_loss_epoch=2.890, train_acc_epoch=0.373]

Metric val_loss improved by 0.015 >= min_delta = 0.01. New best score: 2.941


Epoch 24: 100%|██████████| 509/509 [00:52<00:00,  9.78it/s, v_num=62, train_loss_step=2.520, train_acc_step=1.000, val_loss=2.920, val_acc=0.348, train_loss_epoch=2.850, train_acc_epoch=0.383]

Metric val_loss improved by 0.025 >= min_delta = 0.01. New best score: 2.916


Epoch 25: 100%|██████████| 509/509 [00:53<00:00,  9.59it/s, v_num=62, train_loss_step=2.890, train_acc_step=0.000, val_loss=2.900, val_acc=0.342, train_loss_epoch=2.820, train_acc_epoch=0.416]

Metric val_loss improved by 0.020 >= min_delta = 0.01. New best score: 2.896


Epoch 26: 100%|██████████| 509/509 [00:50<00:00, 10.04it/s, v_num=62, train_loss_step=2.950, train_acc_step=0.000, val_loss=2.870, val_acc=0.360, train_loss_epoch=2.790, train_acc_epoch=0.410]

Metric val_loss improved by 0.022 >= min_delta = 0.01. New best score: 2.874


Epoch 27: 100%|██████████| 509/509 [00:53<00:00,  9.50it/s, v_num=62, train_loss_step=1.830, train_acc_step=1.000, val_loss=2.860, val_acc=0.366, train_loss_epoch=2.750, train_acc_epoch=0.413]

Metric val_loss improved by 0.017 >= min_delta = 0.01. New best score: 2.857


Epoch 28: 100%|██████████| 509/509 [01:08<00:00,  7.47it/s, v_num=62, train_loss_step=1.710, train_acc_step=1.000, val_loss=2.830, val_acc=0.366, train_loss_epoch=2.730, train_acc_epoch=0.433]

Metric val_loss improved by 0.027 >= min_delta = 0.01. New best score: 2.830


Epoch 29: 100%|██████████| 509/509 [01:15<00:00,  6.77it/s, v_num=62, train_loss_step=3.170, train_acc_step=0.000, val_loss=2.810, val_acc=0.363, train_loss_epoch=2.680, train_acc_epoch=0.446]

Metric val_loss improved by 0.023 >= min_delta = 0.01. New best score: 2.808


Epoch 30: 100%|██████████| 509/509 [01:21<00:00,  6.28it/s, v_num=62, train_loss_step=2.160, train_acc_step=1.000, val_loss=2.780, val_acc=0.392, train_loss_epoch=2.650, train_acc_epoch=0.469]

Metric val_loss improved by 0.025 >= min_delta = 0.01. New best score: 2.783


Epoch 31: 100%|██████████| 509/509 [01:32<00:00,  5.52it/s, v_num=62, train_loss_step=2.920, train_acc_step=0.000, val_loss=2.760, val_acc=0.381, train_loss_epoch=2.630, train_acc_epoch=0.485]

Metric val_loss improved by 0.022 >= min_delta = 0.01. New best score: 2.761


Epoch 32: 100%|██████████| 509/509 [01:15<00:00,  6.75it/s, v_num=62, train_loss_step=2.600, train_acc_step=0.000, val_loss=2.740, val_acc=0.386, train_loss_epoch=2.620, train_acc_epoch=0.488]

Metric val_loss improved by 0.018 >= min_delta = 0.01. New best score: 2.743


Epoch 33: 100%|██████████| 509/509 [01:33<00:00,  5.44it/s, v_num=62, train_loss_step=1.890, train_acc_step=1.000, val_loss=2.710, val_acc=0.395, train_loss_epoch=2.580, train_acc_epoch=0.478]

Metric val_loss improved by 0.031 >= min_delta = 0.01. New best score: 2.711


Epoch 34: 100%|██████████| 509/509 [01:24<00:00,  6.06it/s, v_num=62, train_loss_step=2.550, train_acc_step=1.000, val_loss=2.690, val_acc=0.389, train_loss_epoch=2.540, train_acc_epoch=0.504]

Metric val_loss improved by 0.023 >= min_delta = 0.01. New best score: 2.689


Epoch 35: 100%|██████████| 509/509 [01:17<00:00,  6.60it/s, v_num=62, train_loss_step=2.780, train_acc_step=0.000, val_loss=2.670, val_acc=0.398, train_loss_epoch=2.510, train_acc_epoch=0.507]

Metric val_loss improved by 0.015 >= min_delta = 0.01. New best score: 2.674


Epoch 36: 100%|██████████| 509/509 [01:15<00:00,  6.74it/s, v_num=62, train_loss_step=2.870, train_acc_step=0.000, val_loss=2.660, val_acc=0.386, train_loss_epoch=2.480, train_acc_epoch=0.537]

Metric val_loss improved by 0.011 >= min_delta = 0.01. New best score: 2.662


Epoch 37: 100%|██████████| 509/509 [01:14<00:00,  6.86it/s, v_num=62, train_loss_step=2.480, train_acc_step=1.000, val_loss=2.650, val_acc=0.381, train_loss_epoch=2.430, train_acc_epoch=0.531]

Metric val_loss improved by 0.017 >= min_delta = 0.01. New best score: 2.645


Epoch 38: 100%|██████████| 509/509 [01:16<00:00,  6.66it/s, v_num=62, train_loss_step=1.700, train_acc_step=1.000, val_loss=2.620, val_acc=0.407, train_loss_epoch=2.420, train_acc_epoch=0.553]

Metric val_loss improved by 0.027 >= min_delta = 0.01. New best score: 2.618


Epoch 39: 100%|██████████| 509/509 [01:15<00:00,  6.73it/s, v_num=62, train_loss_step=1.110, train_acc_step=1.000, val_loss=2.610, val_acc=0.410, train_loss_epoch=2.380, train_acc_epoch=0.572]

Metric val_loss improved by 0.012 >= min_delta = 0.01. New best score: 2.605


Epoch 40: 100%|██████████| 509/509 [01:09<00:00,  7.34it/s, v_num=62, train_loss_step=2.190, train_acc_step=1.000, val_loss=2.590, val_acc=0.404, train_loss_epoch=2.350, train_acc_epoch=0.564]

Metric val_loss improved by 0.019 >= min_delta = 0.01. New best score: 2.587


Epoch 41: 100%|██████████| 509/509 [01:18<00:00,  6.52it/s, v_num=62, train_loss_step=1.360, train_acc_step=1.000, val_loss=2.570, val_acc=0.410, train_loss_epoch=2.310, train_acc_epoch=0.568]

Metric val_loss improved by 0.015 >= min_delta = 0.01. New best score: 2.571


Epoch 42: 100%|██████████| 509/509 [01:19<00:00,  6.43it/s, v_num=62, train_loss_step=1.280, train_acc_step=1.000, val_loss=2.540, val_acc=0.419, train_loss_epoch=2.290, train_acc_epoch=0.577]

Metric val_loss improved by 0.033 >= min_delta = 0.01. New best score: 2.538


Epoch 43: 100%|██████████| 509/509 [01:16<00:00,  6.68it/s, v_num=62, train_loss_step=1.840, train_acc_step=1.000, val_loss=2.520, val_acc=0.413, train_loss_epoch=2.260, train_acc_epoch=0.599]

Metric val_loss improved by 0.015 >= min_delta = 0.01. New best score: 2.523


Epoch 44: 100%|██████████| 509/509 [01:22<00:00,  6.19it/s, v_num=62, train_loss_step=2.350, train_acc_step=0.000, val_loss=2.490, val_acc=0.416, train_loss_epoch=2.210, train_acc_epoch=0.618]

Metric val_loss improved by 0.028 >= min_delta = 0.01. New best score: 2.495


Epoch 3:   5%|▍         | 24/509 [58:21<19:39:09,  0.01it/s, v_num=60, train_loss_step=3.420, train_acc_step=0.000, val_loss=3.440, val_acc=0.0295, train_loss_epoch=3.440, train_acc_epoch=0.0265]
Epoch 45: 100%|██████████| 509/509 [01:20<00:00,  6.31it/s, v_num=62, train_loss_step=2.390, train_acc_step=1.000, val_loss=2.480, val_acc=0.440, train_loss_epoch=2.190, train_acc_epoch=0.614]

Metric val_loss improved by 0.013 >= min_delta = 0.01. New best score: 2.482


Epoch 46: 100%|██████████| 509/509 [01:20<00:00,  6.33it/s, v_num=62, train_loss_step=1.980, train_acc_step=1.000, val_loss=2.470, val_acc=0.413, train_loss_epoch=2.140, train_acc_epoch=0.639]

Metric val_loss improved by 0.014 >= min_delta = 0.01. New best score: 2.468


Epoch 48: 100%|██████████| 509/509 [01:18<00:00,  6.52it/s, v_num=62, train_loss_step=1.120, train_acc_step=1.000, val_loss=2.440, val_acc=0.428, train_loss_epoch=2.100, train_acc_epoch=0.645]

Metric val_loss improved by 0.031 >= min_delta = 0.01. New best score: 2.437


Epoch 49: 100%|██████████| 509/509 [00:49<00:00, 10.32it/s, v_num=62, train_loss_step=1.370, train_acc_step=1.000, val_loss=2.410, val_acc=0.431, train_loss_epoch=2.080, train_acc_epoch=0.662]

Metric val_loss improved by 0.032 >= min_delta = 0.01. New best score: 2.405


Epoch 50: 100%|██████████| 509/509 [01:09<00:00,  7.32it/s, v_num=62, train_loss_step=1.010, train_acc_step=1.000, val_loss=2.380, val_acc=0.442, train_loss_epoch=2.030, train_acc_epoch=0.651]

Metric val_loss improved by 0.021 >= min_delta = 0.01. New best score: 2.384


Epoch 51: 100%|██████████| 509/509 [01:22<00:00,  6.17it/s, v_num=62, train_loss_step=2.010, train_acc_step=1.000, val_loss=2.370, val_acc=0.425, train_loss_epoch=2.000, train_acc_epoch=0.665]

Metric val_loss improved by 0.015 >= min_delta = 0.01. New best score: 2.369


Epoch 53: 100%|██████████| 509/509 [01:19<00:00,  6.38it/s, v_num=62, train_loss_step=1.670, train_acc_step=1.000, val_loss=2.350, val_acc=0.425, train_loss_epoch=1.930, train_acc_epoch=0.699]

Metric val_loss improved by 0.023 >= min_delta = 0.01. New best score: 2.346


Epoch 54: 100%|██████████| 509/509 [01:25<00:00,  5.96it/s, v_num=62, train_loss_step=0.700, train_acc_step=1.000, val_loss=2.310, val_acc=0.448, train_loss_epoch=1.900, train_acc_epoch=0.693]

Metric val_loss improved by 0.031 >= min_delta = 0.01. New best score: 2.314


Epoch 55: 100%|██████████| 509/509 [01:25<00:00,  5.94it/s, v_num=62, train_loss_step=1.800, train_acc_step=1.000, val_loss=2.300, val_acc=0.425, train_loss_epoch=1.860, train_acc_epoch=0.723]

Metric val_loss improved by 0.017 >= min_delta = 0.01. New best score: 2.297


Epoch 57: 100%|██████████| 509/509 [01:14<00:00,  6.83it/s, v_num=62, train_loss_step=1.290, train_acc_step=1.000, val_loss=2.270, val_acc=0.463, train_loss_epoch=1.810, train_acc_epoch=0.740]

Metric val_loss improved by 0.023 >= min_delta = 0.01. New best score: 2.274


Epoch 58: 100%|██████████| 509/509 [01:15<00:00,  6.70it/s, v_num=62, train_loss_step=0.668, train_acc_step=1.000, val_loss=2.240, val_acc=0.478, train_loss_epoch=1.760, train_acc_epoch=0.747]

Metric val_loss improved by 0.034 >= min_delta = 0.01. New best score: 2.240


Epoch 61: 100%|██████████| 509/509 [01:19<00:00,  6.42it/s, v_num=62, train_loss_step=1.250, train_acc_step=1.000, val_loss=2.220, val_acc=0.472, train_loss_epoch=1.680, train_acc_epoch=0.772]

Metric val_loss improved by 0.019 >= min_delta = 0.01. New best score: 2.221


Epoch 62: 100%|██████████| 509/509 [01:18<00:00,  6.52it/s, v_num=62, train_loss_step=1.020, train_acc_step=1.000, val_loss=2.210, val_acc=0.475, train_loss_epoch=1.680, train_acc_epoch=0.751]

Metric val_loss improved by 0.014 >= min_delta = 0.01. New best score: 2.208


Epoch 63: 100%|██████████| 509/509 [01:08<00:00,  7.39it/s, v_num=62, train_loss_step=0.580, train_acc_step=1.000, val_loss=2.190, val_acc=0.457, train_loss_epoch=1.610, train_acc_epoch=0.794]

Metric val_loss improved by 0.014 >= min_delta = 0.01. New best score: 2.194


Epoch 64: 100%|██████████| 509/509 [01:15<00:00,  6.73it/s, v_num=62, train_loss_step=0.919, train_acc_step=1.000, val_loss=2.160, val_acc=0.484, train_loss_epoch=1.550, train_acc_epoch=0.800]

Metric val_loss improved by 0.036 >= min_delta = 0.01. New best score: 2.157


Epoch 65: 100%|██████████| 509/509 [01:30<00:00,  5.61it/s, v_num=62, train_loss_step=0.704, train_acc_step=1.000, val_loss=2.140, val_acc=0.469, train_loss_epoch=1.540, train_acc_epoch=0.815]

Metric val_loss improved by 0.020 >= min_delta = 0.01. New best score: 2.138


Epoch 66: 100%|██████████| 509/509 [01:24<00:00,  6.04it/s, v_num=62, train_loss_step=0.752, train_acc_step=1.000, val_loss=2.110, val_acc=0.490, train_loss_epoch=1.510, train_acc_epoch=0.821]

Metric val_loss improved by 0.031 >= min_delta = 0.01. New best score: 2.107


Epoch 69: 100%|██████████| 509/509 [00:53<00:00,  9.47it/s, v_num=62, train_loss_step=0.968, train_acc_step=1.000, val_loss=2.070, val_acc=0.507, train_loss_epoch=1.410, train_acc_epoch=0.842]

Metric val_loss improved by 0.038 >= min_delta = 0.01. New best score: 2.069


Epoch 71: 100%|██████████| 509/509 [00:52<00:00,  9.69it/s, v_num=62, train_loss_step=1.480, train_acc_step=0.000, val_loss=2.020, val_acc=0.513, train_loss_epoch=1.360, train_acc_epoch=0.858]

Metric val_loss improved by 0.047 >= min_delta = 0.01. New best score: 2.022


Epoch 74: 100%|██████████| 509/509 [00:50<00:00, 10.03it/s, v_num=62, train_loss_step=0.812, train_acc_step=1.000, val_loss=2.000, val_acc=0.510, train_loss_epoch=1.250, train_acc_epoch=0.887]

Metric val_loss improved by 0.022 >= min_delta = 0.01. New best score: 1.999


Epoch 78: 100%|██████████| 509/509 [00:54<00:00,  9.41it/s, v_num=62, train_loss_step=1.960, train_acc_step=1.000, val_loss=1.940, val_acc=0.537, train_loss_epoch=1.160, train_acc_epoch=0.892]

Metric val_loss improved by 0.056 >= min_delta = 0.01. New best score: 1.943


Epoch 80: 100%|██████████| 509/509 [01:07<00:00,  7.58it/s, v_num=62, train_loss_step=0.357, train_acc_step=1.000, val_loss=1.930, val_acc=0.516, train_loss_epoch=1.090, train_acc_epoch=0.912]

Metric val_loss improved by 0.010 >= min_delta = 0.01. New best score: 1.933


Epoch 83: 100%|██████████| 509/509 [00:50<00:00, 10.10it/s, v_num=62, train_loss_step=0.421, train_acc_step=1.000, val_loss=1.910, val_acc=0.537, train_loss_epoch=1.020, train_acc_epoch=0.915]

Metric val_loss improved by 0.023 >= min_delta = 0.01. New best score: 1.911


Epoch 86: 100%|██████████| 509/509 [00:48<00:00, 10.59it/s, v_num=62, train_loss_step=0.506, train_acc_step=1.000, val_loss=1.880, val_acc=0.540, train_loss_epoch=0.940, train_acc_epoch=0.920]

Metric val_loss improved by 0.033 >= min_delta = 0.01. New best score: 1.878


Epoch 87: 100%|██████████| 509/509 [00:51<00:00,  9.94it/s, v_num=62, train_loss_step=0.525, train_acc_step=1.000, val_loss=1.850, val_acc=0.534, train_loss_epoch=0.923, train_acc_epoch=0.932]

Metric val_loss improved by 0.027 >= min_delta = 0.01. New best score: 1.850


Epoch 96: 100%|██████████| 509/509 [00:53<00:00,  9.47it/s, v_num=62, train_loss_step=0.333, train_acc_step=1.000, val_loss=1.840, val_acc=0.522, train_loss_epoch=0.688, train_acc_epoch=0.955]

Metric val_loss improved by 0.012 >= min_delta = 0.01. New best score: 1.839


Epoch 97: 100%|██████████| 509/509 [00:49<00:00, 10.25it/s, v_num=62, train_loss_step=0.193, train_acc_step=1.000, val_loss=1.810, val_acc=0.537, train_loss_epoch=0.689, train_acc_epoch=0.954]

Metric val_loss improved by 0.029 >= min_delta = 0.01. New best score: 1.810


Epoch 99: 100%|██████████| 509/509 [00:49<00:00, 10.27it/s, v_num=62, train_loss_step=0.716, train_acc_step=1.000, val_loss=1.790, val_acc=0.543, train_loss_epoch=0.647, train_acc_epoch=0.950]

Metric val_loss improved by 0.018 >= min_delta = 0.01. New best score: 1.792


Epoch 100: 100%|██████████| 509/509 [00:50<00:00, 10.10it/s, v_num=62, train_loss_step=0.724, train_acc_step=1.000, val_loss=1.760, val_acc=0.560, train_loss_epoch=0.628, train_acc_epoch=0.960]

Metric val_loss improved by 0.030 >= min_delta = 0.01. New best score: 1.762


Epoch 107: 100%|██████████| 509/509 [00:50<00:00, 10.12it/s, v_num=62, train_loss_step=0.286, train_acc_step=1.000, val_loss=1.730, val_acc=0.558, train_loss_epoch=0.518, train_acc_epoch=0.968] 

Metric val_loss improved by 0.035 >= min_delta = 0.01. New best score: 1.727


Epoch 127: 100%|██████████| 509/509 [01:01<00:00,  8.25it/s, v_num=62, train_loss_step=0.152, train_acc_step=1.000, val_loss=1.750, val_acc=0.555, train_loss_epoch=0.304, train_acc_epoch=0.974] 

Monitored metric val_loss did not improve in the last 20 records. Best score: 1.727. Signaling Trainer to stop.


Epoch 127: 100%|██████████| 509/509 [01:06<00:00,  7.68it/s, v_num=62, train_loss_step=0.152, train_acc_step=1.000, val_loss=1.750, val_acc=0.555, train_loss_epoch=0.304, train_acc_epoch=0.974]


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
c:\Users\yohanes.setiawan\AppData\Local\miniconda3\envs\wmmd_env\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:425: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=21` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 170/170 [00:10<00:00, 15.68it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.5427728891372681
         test_f1            0.4709932208061218
        test_loss           1.8143362998962402
     test_precision         0.5349065661430359
       test_recall          0.43903636932373047
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Artifacts saved to model\facebook_wav2vec2-base_imbalance\20250703_165539/
Completed facebook/wav2vec2-base (FT), artifacts in model\facebook_wav2vec2-base_imbalance\20250703_165539


Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTModel: ['pooler.dense.bias', 'pooler.dense.weight']
- This IS expected if you are initializing ViTModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Using 16bit Automatic Mixed Precision (AMP)
💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VI

Epoch 0:   0%|          | 0/509 [00:00<?, ?it/s] 

It looks like you are trying to rescale already rescaled images. If the input images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again.


Epoch 0: 100%|██████████| 509/509 [01:00<00:00,  8.40it/s, v_num=63, train_loss_step=3.280, train_acc_step=1.000, val_loss=3.420, val_acc=0.0678, train_loss_epoch=3.430, train_acc_epoch=0.0472]

Metric val_loss improved. New best score: 3.420


Epoch 1: 100%|██████████| 509/509 [00:52<00:00,  9.66it/s, v_num=63, train_loss_step=3.490, train_acc_step=0.000, val_loss=3.410, val_acc=0.0678, train_loss_epoch=3.420, train_acc_epoch=0.0393]

Metric val_loss improved by 0.012 >= min_delta = 0.01. New best score: 3.408


Epoch 3: 100%|██████████| 509/509 [00:55<00:00,  9.25it/s, v_num=63, train_loss_step=3.310, train_acc_step=0.000, val_loss=3.390, val_acc=0.0678, train_loss_epoch=3.420, train_acc_epoch=0.0521]

Metric val_loss improved by 0.016 >= min_delta = 0.01. New best score: 3.392


Epoch 8: 100%|██████████| 509/509 [00:54<00:00,  9.34it/s, v_num=63, train_loss_step=3.560, train_acc_step=0.000, val_loss=3.380, val_acc=0.0678, train_loss_epoch=3.410, train_acc_epoch=0.0639]

Metric val_loss improved by 0.010 >= min_delta = 0.01. New best score: 3.382


Epoch 17:  15%|█▍        | 75/509 [00:07<00:41, 10.51it/s, v_num=63, train_loss_step=3.180, train_acc_step=0.500, val_loss=3.380, val_acc=0.0678, train_loss_epoch=3.410, train_acc_epoch=0.0659] 


Detected KeyboardInterrupt, attempting graceful shutdown ...


Epoch 3:  44%|████▎     | 222/509 [2:42:50<3:30:30,  0.02it/s, v_num=61, train_loss_step=3.290, train_acc_step=0.000, val_loss=3.210, val_acc=0.159, train_loss_epoch=3.310, train_acc_epoch=0.106]


NameError: name 'exit' is not defined