In [3]:
# Standard library
import copy
import datetime
import json
import os
import random
from collections import Counter

# Move from the notebook's folder to the directory two levels up, where the
# relative paths used later (e.g. "data/watkins", "model") are resolved.
# The flag keeps the cell idempotent: re-running it in the same kernel must not
# climb two more levels.
if not globals().get("_CWD_MOVED_TO_PROJECT_ROOT", False):
    os.chdir("../..")
    _CWD_MOVED_TO_PROJECT_ROOT = True
print(os.getcwd())

# Numerical computing
import numpy as np

# PyTorch
import torch
import torch.nn as nn
import torchaudio
from torch.utils.data import DataLoader
import torchaudio.transforms as T
import torchaudio.functional as F

# Hugging Face Transformers
from transformers import (
    Wav2Vec2FeatureExtractor,
    Wav2Vec2Model,
)

# Datasets
from datasets import load_dataset

# PyTorch Lightning
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.model_summary import ModelSummary

# Metrics
from sklearn.utils.class_weight import compute_class_weight
from torchmetrics.classification import (
    MulticlassF1Score,
    MulticlassPrecision,
    MulticlassRecall,
)

# Visualization
import matplotlib.pyplot as plt

# MAE-AST Library
from s3prl.nn.upstream import S3PRLUpstream

# Fix the length of the input audio to the same length
def _match_length_force(self, xs, target_max_len):
    xs_max_len = xs.size(1)
    if xs_max_len > target_max_len:
        xs = xs[:, :target_max_len, :]
    elif xs_max_len < target_max_len:
        pad_len = target_max_len - xs_max_len
        xs = torch.cat(
            (xs, xs[:, -1:, :].repeat(1, pad_len, 1)),
            dim=1
        )
    return xs

# Monkeypatch: replace S3PRLUpstream._match_length for every instance with the
# truncate-or-repeat-last-frame implementation defined above.
S3PRLUpstream._match_length = _match_length_force

/home/incantator/Documents/mbari-mae


In [4]:
# loading the dataset
# One annotation CSV per split, all located under data_dir.
data_dir = "data/watkins"
annotations_file_train = os.path.join(data_dir, "annotations.train.csv")
annotations_file_valid = os.path.join(data_dir, "annotations.valid.csv")
annotations_file_test = os.path.join(data_dir, "annotations.test.csv")

ds = load_dataset(
    "csv",
    data_files={"train": annotations_file_train,
                "validation": annotations_file_valid,
                "test": annotations_file_test},
)

# Print the size and per-class example count of every split.
for split_name in ["train", "validation", "test"]:
    split_dataset = ds[split_name]
    labels = split_dataset["label"]
    total = len(labels)
    counts = Counter(labels)

    print(f"{split_name.capitalize()} dataset: {total} examples, {len(counts)} classes")
    if "label" in split_dataset.features and hasattr(split_dataset.features["label"], "names"):
        # The feature exposes class names, so the stored labels are class
        # indices: the counter must be queried by index, not by name
        # (querying by name always yields 0).
        class_names = split_dataset.features["label"].names
        for idx, name in enumerate(class_names):
            print(f"  {idx} ({name}): {counts.get(idx, 0)}")
    else:
        for label, count in counts.items():
            print(f"  {label}: {count}")

Train dataset: 1017 examples, 31 classes
  Clymene_Dolphin: 38
  Bottlenose_Dolphin: 15
  Spinner_Dolphin: 69
  Beluga,_White_Whale: 30
  Bearded_Seal: 22
  Minke_Whale: 10
  Humpback_Whale: 38
  Southern_Right_Whale: 15
  White-sided_Dolphin: 33
  Narwhal: 30
  White-beaked_Dolphin: 34
  Northern_Right_Whale: 32
  Frasers_Dolphin: 52
  Grampus,_Rissos_Dolphin: 40
  Harp_Seal: 28
  Atlantic_Spotted_Dolphin: 35
  Fin,_Finback_Whale: 30
  Ross_Seal: 30
  Rough-Toothed_Dolphin: 30
  Killer_Whale: 21
  Pantropical_Spotted_Dolphin: 40
  Short-Finned_Pacific_Pilot_Whale: 40
  Bowhead_Whale: 36
  False_Killer_Whale: 35
  Melon_Headed_Whale: 38
  Long-Finned_Pilot_Whale: 42
  Striped_Dolphin: 49
  Leopard_Seal: 6
  Walrus: 23
  Sperm_Whale: 45
  Common_Dolphin: 31
Validation dataset: 339 examples, 31 classes
  Clymene_Dolphin: 12
  Bottlenose_Dolphin: 5
  Spinner_Dolphin: 23
  Beluga,_White_Whale: 10
  Bearded_Seal: 7
  Minke_Whale: 4
  Humpback_Whale: 13
  Southern_Right_Whale: 5
  White-side

In [5]:
# class weights calculation
# Encode the training labels as integers (sorted label order) and let
# scikit-learn derive 'balanced' per-class weights from them.
train_labels = ds["train"]["label"]
unique_labels = sorted(set(train_labels))
label_to_int = dict(zip(unique_labels, range(len(unique_labels))))
y_train = list(map(label_to_int.__getitem__, train_labels))

class_weights = compute_class_weight(
    class_weight='balanced',
    classes=np.arange(len(unique_labels)),
    y=y_train,
)

# One weight per class, so the number of weights is the number of classes.
num_classes = len(class_weights)

In [6]:
# model definition
class WMMDClassifier(pl.LightningModule):
    """Audio classifier: pretrained backbone, mean pooling over time, MLP head.

    Supported backbones are the wav2vec2 checkpoints listed in
    ``_WAV2VEC_BACKBONES`` and ``"mae-ast"`` (an S3PRL upstream whose encoder
    is truncated to its first four layers and whose decoder parts are removed).

    Args:
        num_classes: Number of output classes.
        lr: Learning rate handed to AdamW.
        backbone: Backbone identifier (see above).
        ckpt_path: Optional checkpoint loaded into the MAE-AST upstream; only
            read when ``backbone`` is ``"mae-ast"``.
        finetune: Value assigned to ``requires_grad`` of every backbone parameter.
        class_weights: Optional per-class weights for the cross-entropy loss.

    Raises:
        ValueError: If ``backbone`` is not a supported identifier.
    """

    _WAV2VEC_BACKBONES = (
        "facebook/wav2vec2-base",
        "patrickvonplaten/tiny-wav2vec2-no-tokenizer",
    )

    def __init__(
        self,
        num_classes: int,
        lr: float = 1e-3,
        backbone: str = "facebook/wav2vec2-base",
        ckpt_path: str = "",
        finetune: bool = False,
        class_weights=None,
    ):
        super().__init__()
        self.save_hyperparameters()

        if backbone in self._WAV2VEC_BACKBONES:
            self.backbone      = Wav2Vec2Model.from_pretrained(backbone)
            self.embedding_dim = self.backbone.config.hidden_size

        elif backbone.lower() == "mae-ast":
            up_kwargs = {"name": "mae_ast_patch"}
            s3 = S3PRLUpstream(**up_kwargs)

            # Keep only the first four encoder layers and drop the parts that
            # are not used for classification.
            enc = s3.upstream.model.encoder
            enc.layers = nn.ModuleList(list(enc.layers)[:4])
            s3.upstream.model.dec_sine_pos_embed = None
            s3.upstream.model.decoder = None
            s3.upstream.model.final_proj_reconstruction = None
            s3.upstream.model.final_proj_classification  = None

            # Keep the wrapper's per-layer bookkeeping consistent with the
            # truncated encoder.
            new_n = len(enc.layers)
            s3._num_layers       = new_n
            s3._hidden_sizes     = s3._hidden_sizes[:new_n]
            s3._downsample_rates = s3._downsample_rates[:new_n]

            self.backbone      = s3
            self.embedding_dim = s3.hidden_sizes[-1]

            # Load the checkpoint for mae ast. This happens before the
            # front-end modules below are swapped in, so the key/shape
            # matching runs against the upstream's original modules.
            if ckpt_path:
                self.load_mae_ckpt(ckpt_path)

            mel_transform = T.MelSpectrogram(
                sample_rate=2000, n_fft=1024, win_length=512,
                hop_length=20, n_mels=128,
            )

            class DBWithDeltas(nn.Module):
                """Convert a spectrogram to dB and append first/second order deltas."""

                def __init__(self):
                    super().__init__()

                def forward(self, spec):
                    spec_db = F.amplitude_to_DB(
                        spec,
                        multiplier=10.0,
                        amin=1e-10,
                        db_multiplier=0
                    )
                    t = spec_db.transpose(0, 1)
                    d1 = F.compute_deltas(t.transpose(0,1))
                    d2 = F.compute_deltas(d1)
                    return torch.cat([
                        t,
                        d1.transpose(0,1),
                        d2.transpose(0,1)
                    ], dim=1)

            s3.upstream.model.fbank   = mel_transform
            s3.upstream.model.db_norm = DBWithDeltas()

        else:
            raise ValueError(f"Unsupported backbone '{backbone}'")

        # Gradient checkpointing is best-effort: backbones without the method
        # are skipped, and a failure is reported instead of silently swallowed.
        enable_ckpt = getattr(self.backbone, "gradient_checkpointing_enable", None)
        if callable(enable_ckpt):
            try:
                enable_ckpt()
            except Exception as err:
                print(f"Gradient checkpointing not enabled: {err}")

        for param in self.backbone.parameters():
            param.requires_grad = finetune

        self.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(self.embedding_dim, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(1024, num_classes),
        )

        if class_weights is not None:
            cw = torch.tensor(class_weights, dtype=torch.float)
            self.criterion = nn.CrossEntropyLoss(weight=cw)
        else:
            self.criterion = nn.CrossEntropyLoss()

        # One macro-averaged metric object per (stage, metric) pair.
        metrics_kwargs = dict(num_classes=num_classes, average='macro')
        self.train_precision = MulticlassPrecision(**metrics_kwargs)
        self.train_recall = MulticlassRecall(**metrics_kwargs)
        self.train_f1 = MulticlassF1Score(**metrics_kwargs)
        self.val_precision = MulticlassPrecision(**metrics_kwargs)
        self.val_recall = MulticlassRecall(**metrics_kwargs)
        self.val_f1 = MulticlassF1Score(**metrics_kwargs)
        self.test_precision = MulticlassPrecision(**metrics_kwargs)
        self.test_recall = MulticlassRecall(**metrics_kwargs)
        self.test_f1 = MulticlassF1Score(**metrics_kwargs)

    def forward(self, x):
        """Return class logits for a batch of inputs.

        Raises:
            ValueError: If the stored backbone name is not supported.
        """
        bname = self.hparams.backbone.lower()
        if bname in self._WAV2VEC_BACKBONES:
            hidden = self.backbone(x).last_hidden_state
        elif bname == "mae-ast":
            if x.dim() == 3:
                x = x.squeeze(-1)
            # Every item is reported with the full (padded) length.
            wav_lens = torch.full((x.size(0),), x.size(1),
                                dtype=torch.long, device=x.device)
            all_hs, _ = self.backbone(x, wav_lens)
            hidden   = all_hs[-1]
        else:
            raise ValueError(f"Unsupported backbone in forward(): '{self.hparams.backbone}'")

        # Mean-pool over dim 1, then classify.
        emb = hidden.mean(dim=1)
        return self.classifier(emb)

    def _shared_step(self, batch, prefix):
        """Run one forward pass, log the metrics under ``prefix`` and return the loss."""
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        preds = logits.argmax(dim=1)
        self.log_batch_metrics(loss, preds, y, prefix=prefix)
        return loss

    def training_step(self, batch, batch_idx):
        """Compute, log and return the training loss for one batch."""
        return self._shared_step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        """Compute and log validation metrics for one batch."""
        self._shared_step(batch, 'val')

    def test_step(self, batch, batch_idx):
        """Compute and log test metrics for one batch."""
        self._shared_step(batch, 'test')

    def configure_optimizers(self):
        """Return an AdamW optimizer over all module parameters."""
        return torch.optim.AdamW(self.parameters(), lr=self.hparams.lr)

    def log_batch_metrics(self, loss, preds, targets, prefix):
        """Log loss, accuracy, precision, recall and F1 as ``{prefix}_<name>``.

        NOTE(review): precision/recall/F1 are logged as per-batch values, so
        the epoch figure is an aggregate of batch-level macro scores —
        confirm this is intended.
        """
        self.log(f'{prefix}_loss', loss, prog_bar=True, on_epoch=True)
        acc = (preds == targets).float().mean()
        self.log(f'{prefix}_acc', acc, prog_bar=True, on_epoch=True)
        for metric_name in ('precision', 'recall', 'f1'):
            value = getattr(self, f'{prefix}_{metric_name}')(preds, targets)
            self.log(f'{prefix}_{metric_name}', value, on_epoch=True)

    def on_train_end(self):
        """Save artifacts when a ``save_dir`` attribute has been set on the model.

        ``save_dir`` is used as the base directory passed to ``save_model``.
        """
        save_dir = getattr(self, 'save_dir', None)
        if save_dir:
            self.save_model(save_dir)

    def _serializable_hparams(self):
        """Return the hyperparameters with arrays and tensors converted to plain Python."""
        serializable = {}
        for k, v in dict(self.hparams).items():
            if isinstance(v, np.ndarray):
                serializable[k] = v.tolist()
            elif isinstance(v, torch.Tensor):
                serializable[k] = v.cpu().item() if v.ndim == 0 else v.cpu().tolist()
            else:
                serializable[k] = v
        return serializable

    def save_model(self, base_dir: str = 'model'):
        """Write the weights and a text summary to a timestamped folder.

        Layout: ``<base_dir>/<backbone>_<flag>/<timestamp>/<timestamp>.pt`` and
        ``<timestamp>.txt``. The ``.pt`` payload holds ``state_dict``,
        ``hparams`` (arrays/tensors stored as plain lists/numbers) and, when
        set on the model, ``test_results``, ``finish_time`` and
        ``epochs_trained``.

        Args:
            base_dir: Root directory for saved artifacts.
        """
        bn = self.hparams.backbone.replace('/', '_')
        cw = getattr(self.hparams, 'class_weights', None)
        balance_flag = 'imbalance' if cw is not None else 'balance'
        timestamp = getattr(self, 'finish_time', datetime.datetime.now().strftime("%Y%m%d_%H%M%S"))
        folder = os.path.join(base_dir, f"{bn}_{balance_flag}", timestamp)
        os.makedirs(folder, exist_ok=True)

        serializable_hparams = self._serializable_hparams()

        ckpt_path = os.path.join(folder, f"{timestamp}.pt")
        payload = {
            'state_dict': self.state_dict(),
            'hparams': serializable_hparams
        }
        for attr in ('test_results', 'finish_time', 'epochs_trained'):
            if hasattr(self, attr):
                payload[attr] = getattr(self, attr)
        torch.save(payload, ckpt_path)

        stats_path = os.path.join(folder, f"{timestamp}.txt")

        serializable_results = {}
        if hasattr(self, 'test_results'):
            for k, v in self.test_results.items():
                serializable_results[k] = v.cpu().item() if isinstance(v, torch.Tensor) else v

        with open(stats_path, 'w') as f:
            f.write(f"Model architecture:\n{self}\n\n")
            f.write("Hyperparameters:\n")
            f.write(json.dumps(serializable_hparams, indent=4))
            f.write("\n\n")
            if serializable_results:
                f.write("Test results:\n")
                f.write(json.dumps(serializable_results, indent=4))
                f.write("\n\n")
            if hasattr(self, 'epochs_trained'):
                f.write(f"Epochs trained: {self.epochs_trained}\n")

        # Remembered so callers can store further artifacts next to the model.
        self._last_save_dir = folder
        self._last_timestamp = timestamp
        print(f"Artifacts saved to {folder}/")

    def load_mae_ckpt(self, ckpt_source: str):
        """
        Load MAE-AST checkpoint into the truncated upstream model
        and verify exact weight equality for each loaded parameter.

        Only entries whose key exists in the upstream model with an identical
        shape are loaded.

        Raises:
            RuntimeError: If a loaded tensor differs from the checkpoint value.
        """
        # SECURITY: torch.load unpickles the file; only use trusted checkpoints.
        # map_location keeps the load independent of the device it was saved on.
        loaded = torch.load(ckpt_source, map_location="cpu")
        state_dict = loaded.get('model', loaded)

        up = self.backbone.upstream.model
        up_state = up.state_dict()

        to_load = {k: v for k, v in state_dict.items()
                   if k in up_state and v.shape == up_state[k].shape}

        missing, unexpected = up.load_state_dict(to_load, strict=False)

        # Verify against a state dict read after loading.
        reloaded = up.state_dict()
        for k, v in to_load.items():
            if not torch.equal(reloaded[k].cpu(), v):
                raise RuntimeError(f"Weight mismatch at '{k}' after loading checkpoint")

        print(f"Successfully loaded {len(to_load)} parameters; missing: {len(missing)}, unexpected: {len(unexpected)}")

    @classmethod
    def load_model(cls, load_dir: str, map_location=None):
        """
        Load a model checkpoint and hyperparameters from a directory.

        Two layouts are understood:
          * the one written by ``save_model``: ``<load_dir>/<name>.pt`` where
            ``<name>`` is the final component of ``load_dir``;
          * the legacy one: ``hparams.json`` plus ``<ClassName>.ckpt``.

        ``ckpt_path`` is cleared before the model is rebuilt because the full
        state dict is restored right afterwards.

        Returns:
            model (WMMDClassifier): Loaded model
        """
        run_name = os.path.basename(os.path.normpath(load_dir))
        payload_path = os.path.join(load_dir, f"{run_name}.pt")

        if os.path.isfile(payload_path):
            # SECURITY: torch.load unpickles the file; only use trusted checkpoints.
            state = torch.load(payload_path, map_location=map_location)
            hparams = dict(state['hparams'])
        else:
            hparams_path = os.path.join(load_dir, 'hparams.json')
            with open(hparams_path, 'r') as f:
                hparams = json.load(f)
            ckpt_path = os.path.join(load_dir, f'{cls.__name__}.ckpt')
            state = torch.load(ckpt_path, map_location=map_location)

        hparams['ckpt_path'] = ""
        model = cls(**hparams)
        model.load_state_dict(state['state_dict'])
        return model

In [7]:
# sanity check: build the tiny wav2vec2 variant and print its module summary.
# num_classes comes from the class-weight computation, so the classifier head
# and the loss weights always have matching sizes.
model = WMMDClassifier(
    num_classes=num_classes, lr=1e-3,
    backbone="patrickvonplaten/tiny-wav2vec2-no-tokenizer", finetune=True,
    class_weights=class_weights,
    ckpt_path=""
)

print(ModelSummary(model, max_depth=1))

config.json: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/829k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/812k [00:00<?, ?B/s]

ValueError: Due to a serious vulnerability issue in `torch.load`, even with `weights_only=True`, we now require users to upgrade torch to at least v2.6 in order to use the function. This version restriction does not apply when loading files with safetensors.
See the vulnerability report here https://nvd.nist.gov/vuln/detail/CVE-2025-32434

 97%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████     | 640M/662M [21:53<00:44, 518kB/s]

In [8]:
# sanity check: build the MAE-AST variant from the local checkpoint and print
# its module summary. num_classes comes from the class-weight computation, so
# the classifier head and the loss weights always have matching sizes.
model = WMMDClassifier(
    num_classes=num_classes, lr=1e-3,
    backbone="mae-ast", finetune=True,
    class_weights=class_weights,
    ckpt_path="4Enc_1Dec-61epoch-0.103loss.pt"
)

print(ModelSummary(model, max_depth=1))

Downloading: https://www.cs.utexas.edu/~harwath/model_checkpoints/mae_ast/chunk_patch_75_12LayerEncoder.pt
Destination: /home/incantator/.cache/s3prl/download/5e5fe701120580b1447c34e87c2f1bf9ac58e3ae2df707ea5e2fd85eeab0e0a9.chunk_patch_75_12LayerEncoder.pt
  4%|█████▎                                                                                                                                              | 23.8M/662M [00:45<20:31, 543kB/s]
urllib.Request method failed. Trying using another method...
Downloading: https://www.cs.utexas.edu/~harwath/model_checkpoints/mae_ast/chunk_patch_75_12LayerEncoder.pt
Destination: /home/incantator/.cache/s3prl/download/5e5fe701120580b1447c34e87c2f1bf9ac58e3ae2df707ea5e2fd85eeab0e0a9.chunk_patch_75_12LayerEncoder.pt
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 662M/662M [22:28<00:00, 515kB/s]
  checkpoint = torch.load(ckpt, map_location="c

Successfully loaded 76 parameters; missing: 0, unexpected: 0
   | Name            | Type                | Params | Mode 
-----------------------------------------------------------------
0  | backbone        | S3PRLUpstream       | 28.6 M | train
1  | classifier      | Sequential          | 6.3 M  | train
2  | criterion       | CrossEntropyLoss    | 0      | train
3  | train_precision | MulticlassPrecision | 0      | train
4  | train_recall    | MulticlassRecall    | 0      | train
5  | train_f1        | MulticlassF1Score   | 0      | train
6  | val_precision   | MulticlassPrecision | 0      | train
7  | val_recall      | MulticlassRecall    | 0      | train
8  | val_f1          | MulticlassF1Score   | 0      | train
9  | test_precision  | MulticlassPrecision | 0      | train
10 | test_recall     | MulticlassRecall    | 0      | train
11 | test_f1         | MulticlassF1Score   | 0      | train
-----------------------------------------------------------------
34.9 M    Trainable params


In [7]:
def WMMD_Collate(batch, min_len: int = 5_000, max_len: int = 25_000):
    """Collate ``(waveform, label)`` pairs into equally long, stacked tensors.

    The common length is the longest waveform in the batch, clamped to
    ``[min_len, max_len]``. Longer waveforms are cropped at a random offset;
    shorter ones are zero-padded at the end.

    Args:
        batch: Iterable of ``(waveform, label)`` pairs; the length of a
            waveform is taken from ``waveform.shape[0]``.
        min_len: Lower bound for the common length, in samples.
        max_len: Upper bound for the common length, in samples.

    Returns:
        ``(batch_waveforms, batch_labels)``: the stacked waveforms and a
        ``torch.long`` tensor of labels.

    Raises:
        ValueError: If ``min_len`` is greater than ``max_len``.
    """
    if min_len > max_len:
        raise ValueError(f"min_len ({min_len}) must not exceed max_len ({max_len})")

    waveforms, labels = zip(*batch)

    raw_max    = max(w.shape[0] for w in waveforms)
    target_len = min(max(raw_max, min_len), max_len)

    padded_waveforms = []
    for w in waveforms:
        length = w.shape[0]

        if length > target_len:
            # Random crop. The offset comes from the `random` module, so it is
            # only reproducible when that module is seeded.
            start = random.randint(0, length - target_len)
            w = w[start : start + target_len]
        elif length < target_len:
            # Right-pad with zeros.
            w = torch.nn.functional.pad(w, (0, target_len - length))

        padded_waveforms.append(w)

    batch_waveforms = torch.stack(padded_waveforms, dim=0)
    batch_labels    = torch.tensor(labels, dtype=torch.long)
    return batch_waveforms, batch_labels

class WMMDSoundDataset(torch.utils.data.Dataset):
    """Dataset yielding ``(model_input, label_index)`` pairs from audio files.

    Each item is loaded with torchaudio, resampled to ``target_sr`` when its
    sampling rate differs, downmixed to a single channel, peak-normalised and,
    for the wav2vec2 backbones, passed through the matching feature extractor.
    """

    _WAV2VEC_BACKBONES = (
        "facebook/wav2vec2-base",
        "patrickvonplaten/tiny-wav2vec2-no-tokenizer",
    )

    def __init__(self, dataset, backbone: str, target_sr: int = 2000):
        """
        dataset: list of dicts with keys 'path' & 'label'
        backbone: 'facebook/wav2vec2-base',
                  'patrickvonplaten/tiny-wav2vec2-no-tokenizer' or 'mae-ast'
                  (compared case-insensitively)
        target_sr: sampling rate (e.g. 2000)

        Raises ValueError for any other backbone.
        """
        self.dataset = dataset
        self.backbone = backbone.lower()
        self.target_sr = target_sr
        # One Resample transform per source sampling rate, created on demand.
        self.resampler_cache = {}

        if self.backbone in self._WAV2VEC_BACKBONES:
            self.processor = Wav2Vec2FeatureExtractor.from_pretrained(
                self.backbone, return_attention_mask=False, sampling_rate=target_sr
            )
        elif self.backbone == "mae-ast":
            self.processor = None
        else:
            raise ValueError(f"Unsupported backbone '{backbone}'")

        # Label indices follow the sorted order of the labels found in `dataset`.
        labels = sorted({item['label'] for item in dataset})
        self.label_to_int = {lbl: i for i, lbl in enumerate(labels)}

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Load item ``idx`` and return ``(input, label_index)``."""
        item = self.dataset[idx]
        audio_path = item["path"]
        waveform, orig_sr = torchaudio.load(audio_path)

        if orig_sr != self.target_sr:
            if orig_sr not in self.resampler_cache:
                self.resampler_cache[orig_sr] = torchaudio.transforms.Resample(orig_sr, self.target_sr)
            waveform = self.resampler_cache[orig_sr](waveform)

        # Downmix to one channel. For single-channel audio this equals the
        # former squeeze(0); for multi-channel audio squeeze(0) left a 2-D
        # tensor whose first dimension (channels) was then read as the length.
        wav_1d = waveform.mean(dim=0) if waveform.dim() > 1 else waveform
        # Peak-normalise; the epsilon avoids dividing by zero on silent clips.
        wav_1d = wav_1d / (wav_1d.abs().max() + 1e-6)

        if self.backbone in self._WAV2VEC_BACKBONES:
            arr = wav_1d.numpy()
            feats = self.processor(arr, sampling_rate=self.target_sr, return_tensors="pt")
            inp = feats.input_values.squeeze(0)
        elif self.backbone == "mae-ast":
            inp = wav_1d
        else:
            raise ValueError(f"Unsupported backbone '{self.backbone}'")

        lbl = self.label_to_int[item['label']]
        return inp, lbl

class WMMDDataModule(pl.LightningDataModule):
    """Lightning data module wrapping the train/validation/test splits.

    Args:
        dataset_dict: Mapping with the keys "train", "validation" and "test".
        backbone: Backbone identifier forwarded to ``WMMDSoundDataset``.
        batch_size: Batch size used by every dataloader.
        num_workers: Worker process count used by every dataloader.
    """

    def __init__(
        self,
        dataset_dict,
        backbone: str,
        batch_size: int = 2,
        num_workers: int = 1
    ):
        super().__init__()
        self.dataset_dict = dataset_dict
        self.backbone = backbone
        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self, stage=None):
        """Build the three split datasets and make them share one label mapping.

        Raises:
            ValueError: If validation or test contain a label absent from train.
        """
        self.train_ds = WMMDSoundDataset(self.dataset_dict["train"], backbone=self.backbone)
        self.val_ds   = WMMDSoundDataset(self.dataset_dict["validation"], backbone=self.backbone)
        self.test_ds  = WMMDSoundDataset(self.dataset_dict["test"], backbone=self.backbone)

        # Each dataset derives its label->index map from its own split; a split
        # that lacks a class would therefore shift every later index. Use the
        # training map everywhere so indices mean the same thing in all splits.
        shared_map = self.train_ds.label_to_int
        for split_name, split_ds in (("validation", self.val_ds), ("test", self.test_ds)):
            unknown = sorted(set(split_ds.label_to_int) - set(shared_map))
            if unknown:
                raise ValueError(
                    f"Labels in the {split_name} split are missing from train: {unknown}"
                )
            split_ds.label_to_int = shared_map

    def _make_loader(self, dataset, shuffle: bool):
        """Create a DataLoader with the settings shared by all splits."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=True,
            collate_fn=WMMD_Collate
        )

    def train_dataloader(self):
        """Return the shuffled training dataloader."""
        return self._make_loader(self.train_ds, shuffle=True)

    def val_dataloader(self):
        """Return the validation dataloader (not shuffled)."""
        return self._make_loader(self.val_ds, shuffle=False)

    def test_dataloader(self):
        """Return the test dataloader (not shuffled)."""
        return self._make_loader(self.test_ds, shuffle=False)

In [8]:
# callback for logging metrics
class MetricsLogger(pl.Callback):
    """Callback that keeps a per-epoch history of the logged train/val metrics.

    Values are read from ``trainer.callback_metrics`` and appended to the
    ``train_*`` / ``val_*`` lists (losses, accs, precisions, recalls, f1s).
    """

    # history-list suffix -> suffix of the logged metric key
    _METRIC_KEYS = {
        'losses': 'loss',
        'accs': 'acc',
        'precisions': 'precision',
        'recalls': 'recall',
        'f1s': 'f1',
    }

    def __init__(self):
        super().__init__()
        self.train_losses = []
        self.val_losses = []
        self.train_accs = []
        self.val_accs = []
        self.train_precisions = []
        self.val_precisions = []
        self.train_recalls = []
        self.val_recalls = []
        self.train_f1s = []
        self.val_f1s = []

    def _record(self, stage, metrics):
        """Append the current ``{stage}_<metric>`` values to the matching lists."""
        for list_suffix, key_suffix in self._METRIC_KEYS.items():
            history = getattr(self, f'{stage}_{list_suffix}')
            history.append(metrics[f'{stage}_{key_suffix}'].item())

    def on_train_epoch_end(self, trainer, pl_module):
        """Store the training metrics of the epoch that just ended."""
        self._record('train', trainer.callback_metrics)

    def on_validation_epoch_end(self, trainer, pl_module):
        """Store the validation metrics of the epoch that just ended."""
        # The pre-training sanity check also triggers this hook; recording it
        # would leave the validation history one entry longer than the
        # training history.
        if trainer.sanity_checking:
            return
        self._record('val', trainer.callback_metrics)

# callback for early stopping
# Watches the validation loss (lower is better) with a patience of 10 and a
# minimum required change of 0.01; verbose output is enabled.
early_stopping = EarlyStopping(
    monitor='val_loss', mode='min',
    min_delta=0.01, patience=10,
    verbose=True,
)

In [9]:
# model configurations
# One dict per training run; entries that are commented out are skipped.
model_configs = [
    # {"num_classes": num_classes, "lr": 1e-3, "backbone": "facebook/wav2vec2-base",  "finetune": False, "class_weights": class_weights, "max_epochs": 500, "ckpt_path": ""},
    {
        "num_classes": num_classes,
        "lr": 1e-6,
        "backbone": "mae-ast",
        "finetune": True,
        "class_weights": class_weights,
        "max_epochs": 500,
        "ckpt_path": "4Enc_1Dec-61epoch-0.103loss.pt",
    },
    # {"num_classes": num_classes, "lr": 1e-3, "backbone": "patrickvonplaten/tiny-wav2vec2-no-tokenizer",  "finetune": True, "class_weights": class_weights, "max_epochs": 500, "ckpt_path": ""},
]

In [10]:
# training loops
# For every configuration: train, test, save the model artifacts and write one
# train-vs-validation curve per metric next to them.
for cfg in model_configs:
    dm = WMMDDataModule(
        dataset_dict=ds,
        backbone=cfg["backbone"],
        batch_size=2,
        num_workers=0
    )

    model = WMMDClassifier(
        num_classes=cfg['num_classes'], lr=cfg['lr'],
        backbone=cfg['backbone'], finetune=cfg['finetune'],
        class_weights=cfg['class_weights'],
        ckpt_path=cfg['ckpt_path']
    )
    metrics_cb = MetricsLogger()
    # EarlyStopping keeps its best score and wait counter on the instance, so
    # every run gets its own copy of the untouched template; otherwise a later
    # configuration would start from the previous run's state.
    callbacks = [metrics_cb, copy.deepcopy(early_stopping)]

    trainer = pl.Trainer(
        max_epochs=cfg['max_epochs'],
        accelerator='gpu', devices=1,
        precision='16-mixed', accumulate_grad_batches=2,
        check_val_every_n_epoch=1,
        num_sanity_val_steps=0,
        enable_progress_bar=True, log_every_n_steps=1,
        callbacks=callbacks
    )

    trainer.fit(model, dm)
    test_res = trainer.test(model, dm)[0]

    # Attach run information to the model so save_model() stores it.
    model.test_results = test_res
    model.finish_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    # NOTE(review): confirm whether trainer.current_epoch is already the
    # number of completed epochs at this point; if so, the +1 overcounts.
    model.epochs_trained = trainer.current_epoch + 1
    model.save_model()

    metrics_map = {
        'accuracy':   ('train_accs',      'val_accs'),
        'precision':  ('train_precisions','val_precisions'),
        'recall':     ('train_recalls',   'val_recalls'),
        'f1_score':   ('train_f1s',       'val_f1s'),
    }

    for metric_name, (train_attr, val_attr) in metrics_map.items():
        train_vals = getattr(metrics_cb, train_attr)
        val_vals = getattr(metrics_cb, val_attr)
        pretty_name = metric_name.replace('_', ' ').title()

        # Each curve gets an x-axis of its own length, so plotting cannot fail
        # when the two histories differ in length.
        fig, ax = plt.subplots()
        ax.plot(range(1, len(train_vals) + 1), train_vals, label=f'train_{metric_name}')
        ax.plot(range(1, len(val_vals) + 1), val_vals, label=f'val_{metric_name}')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(pretty_name)
        ax.set_title(f"{pretty_name} over Epochs {model._last_timestamp}")
        ax.grid(True)
        ax.legend(loc='best')

        plot_file = os.path.join(model._last_save_dir, f"{model._last_timestamp}_{metric_name}.png")
        fig.savefig(plot_file)
        plt.close(fig)

    print(f"Completed {cfg['backbone']} ({'FT' if cfg['finetune'] else 'Frozen'}), artifacts in {model._last_save_dir}")


Using 16bit Automatic Mixed Precision (AMP)
💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA RTX 2000 Ada Generation Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Successfully loaded 76 parameters; missing: 0, unexpected: 0



   | Name            | Type                | Params | Mode 
-----------------------------------------------------------------
0  | backbone        | S3PRLUpstream       | 28.6 M | train
1  | classifier      | Sequential          | 6.3 M  | train
2  | criterion       | CrossEntropyLoss    | 0      | train
3  | train_precision | MulticlassPrecision | 0      | train
4  | train_recall    | MulticlassRecall    | 0      | train
5  | train_f1        | MulticlassF1Score   | 0      | train
6  | val_precision   | MulticlassPrecision | 0      | train
7  | val_recall      | MulticlassRecall    | 0      | train
8  | val_f1          | MulticlassF1Score   | 0      | train
9  | test_precision  | MulticlassPrecision | 0      | train
10 | test_recall     | MulticlassRecall    | 0      | train
11 | test_f1         | MulticlassF1Score   | 0      | train
-----------------------------------------------------------------
34.9 M    Trainable params
0         Non-trainable params
34.9 M    Total params
139.50

Epoch 0: 100%|██████████| 509/509 [00:30<00:00, 16.43it/s, v_num=33, train_loss_step=3.290, train_acc_step=0.000, val_loss=3.340, val_acc=0.0826, train_loss_epoch=3.450, train_acc_epoch=0.0433]

Metric val_loss improved. New best score: 3.339


Epoch 1: 100%|██████████| 509/509 [00:29<00:00, 17.05it/s, v_num=33, train_loss_step=2.670, train_acc_step=1.000, val_loss=3.270, val_acc=0.142, train_loss_epoch=3.390, train_acc_epoch=0.0659] 

Metric val_loss improved by 0.070 >= min_delta = 0.01. New best score: 3.269


Epoch 2: 100%|██████████| 509/509 [00:36<00:00, 14.04it/s, v_num=33, train_loss_step=2.690, train_acc_step=0.000, val_loss=3.210, val_acc=0.159, train_loss_epoch=3.310, train_acc_epoch=0.106] 

Metric val_loss improved by 0.064 >= min_delta = 0.01. New best score: 3.205


Epoch 3: 100%|██████████| 509/509 [00:47<00:00, 10.71it/s, v_num=33, train_loss_step=3.540, train_acc_step=0.000, val_loss=3.160, val_acc=0.183, train_loss_epoch=3.280, train_acc_epoch=0.134]

Metric val_loss improved by 0.049 >= min_delta = 0.01. New best score: 3.156


Epoch 4: 100%|██████████| 509/509 [00:30<00:00, 16.78it/s, v_num=33, train_loss_step=3.070, train_acc_step=0.000, val_loss=3.110, val_acc=0.189, train_loss_epoch=3.230, train_acc_epoch=0.135]

Metric val_loss improved by 0.048 >= min_delta = 0.01. New best score: 3.109


Epoch 5: 100%|██████████| 509/509 [00:41<00:00, 12.34it/s, v_num=33, train_loss_step=1.560, train_acc_step=1.000, val_loss=3.070, val_acc=0.195, train_loss_epoch=3.190, train_acc_epoch=0.163]

Metric val_loss improved by 0.043 >= min_delta = 0.01. New best score: 3.065


Epoch 6: 100%|██████████| 509/509 [00:50<00:00, 10.16it/s, v_num=33, train_loss_step=2.770, train_acc_step=0.000, val_loss=3.020, val_acc=0.195, train_loss_epoch=3.140, train_acc_epoch=0.166]

Metric val_loss improved by 0.042 >= min_delta = 0.01. New best score: 3.023


Epoch 7: 100%|██████████| 509/509 [00:51<00:00,  9.81it/s, v_num=33, train_loss_step=3.510, train_acc_step=0.000, val_loss=2.980, val_acc=0.201, train_loss_epoch=3.100, train_acc_epoch=0.166]

Metric val_loss improved by 0.043 >= min_delta = 0.01. New best score: 2.981


Epoch 8: 100%|██████████| 509/509 [00:51<00:00,  9.89it/s, v_num=33, train_loss_step=2.960, train_acc_step=0.000, val_loss=2.940, val_acc=0.206, train_loss_epoch=3.080, train_acc_epoch=0.185]

Metric val_loss improved by 0.041 >= min_delta = 0.01. New best score: 2.940


Epoch 9: 100%|██████████| 509/509 [00:50<00:00, 10.15it/s, v_num=33, train_loss_step=3.230, train_acc_step=0.000, val_loss=2.910, val_acc=0.245, train_loss_epoch=3.040, train_acc_epoch=0.200]

Metric val_loss improved by 0.035 >= min_delta = 0.01. New best score: 2.905


Epoch 10: 100%|██████████| 509/509 [00:36<00:00, 13.97it/s, v_num=33, train_loss_step=2.270, train_acc_step=1.000, val_loss=2.860, val_acc=0.257, train_loss_epoch=3.010, train_acc_epoch=0.180]

Metric val_loss improved by 0.044 >= min_delta = 0.01. New best score: 2.861


Epoch 11: 100%|██████████| 509/509 [00:47<00:00, 10.63it/s, v_num=33, train_loss_step=3.370, train_acc_step=0.000, val_loss=2.830, val_acc=0.242, train_loss_epoch=2.970, train_acc_epoch=0.206]

Metric val_loss improved by 0.028 >= min_delta = 0.01. New best score: 2.833


Epoch 12: 100%|██████████| 509/509 [00:47<00:00, 10.76it/s, v_num=33, train_loss_step=3.410, train_acc_step=0.000, val_loss=2.800, val_acc=0.257, train_loss_epoch=2.950, train_acc_epoch=0.215]

Metric val_loss improved by 0.037 >= min_delta = 0.01. New best score: 2.796


Epoch 13: 100%|██████████| 509/509 [00:46<00:00, 11.02it/s, v_num=33, train_loss_step=2.600, train_acc_step=0.000, val_loss=2.770, val_acc=0.257, train_loss_epoch=2.920, train_acc_epoch=0.210]

Metric val_loss improved by 0.022 >= min_delta = 0.01. New best score: 2.774


Epoch 14: 100%|██████████| 509/509 [00:36<00:00, 13.93it/s, v_num=33, train_loss_step=2.330, train_acc_step=1.000, val_loss=2.740, val_acc=0.257, train_loss_epoch=2.880, train_acc_epoch=0.229]

Metric val_loss improved by 0.032 >= min_delta = 0.01. New best score: 2.742


Epoch 15: 100%|██████████| 509/509 [00:29<00:00, 17.31it/s, v_num=33, train_loss_step=3.950, train_acc_step=0.000, val_loss=2.710, val_acc=0.263, train_loss_epoch=2.850, train_acc_epoch=0.232]

Metric val_loss improved by 0.030 >= min_delta = 0.01. New best score: 2.712


Epoch 16: 100%|██████████| 509/509 [00:28<00:00, 17.85it/s, v_num=33, train_loss_step=2.350, train_acc_step=1.000, val_loss=2.690, val_acc=0.263, train_loss_epoch=2.820, train_acc_epoch=0.221]

Metric val_loss improved by 0.023 >= min_delta = 0.01. New best score: 2.690


Epoch 17: 100%|██████████| 509/509 [00:28<00:00, 17.58it/s, v_num=33, train_loss_step=2.210, train_acc_step=1.000, val_loss=2.670, val_acc=0.265, train_loss_epoch=2.790, train_acc_epoch=0.248]

Metric val_loss improved by 0.022 >= min_delta = 0.01. New best score: 2.668


Epoch 18: 100%|██████████| 509/509 [00:29<00:00, 17.28it/s, v_num=33, train_loss_step=0.301, train_acc_step=1.000, val_loss=2.650, val_acc=0.257, train_loss_epoch=2.770, train_acc_epoch=0.257]

Metric val_loss improved by 0.019 >= min_delta = 0.01. New best score: 2.649


Epoch 19: 100%|██████████| 509/509 [00:28<00:00, 17.65it/s, v_num=33, train_loss_step=3.380, train_acc_step=0.000, val_loss=2.620, val_acc=0.289, train_loss_epoch=2.740, train_acc_epoch=0.252]

Metric val_loss improved by 0.032 >= min_delta = 0.01. New best score: 2.617


Epoch 20: 100%|██████████| 509/509 [00:39<00:00, 12.91it/s, v_num=33, train_loss_step=2.110, train_acc_step=0.000, val_loss=2.590, val_acc=0.292, train_loss_epoch=2.710, train_acc_epoch=0.279]

Metric val_loss improved by 0.024 >= min_delta = 0.01. New best score: 2.593


Epoch 21: 100%|██████████| 509/509 [00:53<00:00,  9.52it/s, v_num=33, train_loss_step=3.240, train_acc_step=0.000, val_loss=2.580, val_acc=0.301, train_loss_epoch=2.710, train_acc_epoch=0.264] 

Metric val_loss improved by 0.016 >= min_delta = 0.01. New best score: 2.578


Epoch 22: 100%|██████████| 509/509 [00:29<00:00, 17.43it/s, v_num=33, train_loss_step=3.100, train_acc_step=0.000, val_loss=2.550, val_acc=0.316, train_loss_epoch=2.680, train_acc_epoch=0.275]

Metric val_loss improved by 0.028 >= min_delta = 0.01. New best score: 2.549


Epoch 23: 100%|██████████| 509/509 [00:28<00:00, 17.57it/s, v_num=33, train_loss_step=3.500, train_acc_step=0.000, val_loss=2.530, val_acc=0.316, train_loss_epoch=2.650, train_acc_epoch=0.286]

Metric val_loss improved by 0.020 >= min_delta = 0.01. New best score: 2.529


Epoch 24: 100%|██████████| 509/509 [00:56<00:00,  9.07it/s, v_num=33, train_loss_step=3.210, train_acc_step=0.000, val_loss=2.510, val_acc=0.304, train_loss_epoch=2.630, train_acc_epoch=0.286]

Metric val_loss improved by 0.016 >= min_delta = 0.01. New best score: 2.513


Epoch 25: 100%|██████████| 509/509 [01:08<00:00,  7.48it/s, v_num=33, train_loss_step=2.490, train_acc_step=1.000, val_loss=2.490, val_acc=0.319, train_loss_epoch=2.620, train_acc_epoch=0.274]

Metric val_loss improved by 0.025 >= min_delta = 0.01. New best score: 2.488


Epoch 26: 100%|██████████| 509/509 [00:59<00:00,  8.51it/s, v_num=33, train_loss_step=2.320, train_acc_step=0.000, val_loss=2.450, val_acc=0.330, train_loss_epoch=2.580, train_acc_epoch=0.300]

Metric val_loss improved by 0.034 >= min_delta = 0.01. New best score: 2.455


Epoch 27: 100%|██████████| 509/509 [00:59<00:00,  8.53it/s, v_num=33, train_loss_step=2.750, train_acc_step=0.000, val_loss=2.440, val_acc=0.324, train_loss_epoch=2.560, train_acc_epoch=0.305]

Metric val_loss improved by 0.014 >= min_delta = 0.01. New best score: 2.441


Epoch 28: 100%|██████████| 509/509 [00:54<00:00,  9.30it/s, v_num=33, train_loss_step=2.590, train_acc_step=0.000, val_loss=2.410, val_acc=0.357, train_loss_epoch=2.530, train_acc_epoch=0.331] 

Metric val_loss improved by 0.026 >= min_delta = 0.01. New best score: 2.415


Epoch 30: 100%|██████████| 509/509 [00:48<00:00, 10.39it/s, v_num=33, train_loss_step=1.200, train_acc_step=1.000, val_loss=2.370, val_acc=0.345, train_loss_epoch=2.500, train_acc_epoch=0.320]

Metric val_loss improved by 0.048 >= min_delta = 0.01. New best score: 2.367


Epoch 31: 100%|██████████| 509/509 [00:47<00:00, 10.62it/s, v_num=33, train_loss_step=2.850, train_acc_step=0.000, val_loss=2.350, val_acc=0.357, train_loss_epoch=2.480, train_acc_epoch=0.330]

Metric val_loss improved by 0.022 >= min_delta = 0.01. New best score: 2.345


Epoch 32: 100%|██████████| 509/509 [00:49<00:00, 10.33it/s, v_num=33, train_loss_step=1.210, train_acc_step=1.000, val_loss=2.320, val_acc=0.369, train_loss_epoch=2.420, train_acc_epoch=0.343]

Metric val_loss improved by 0.021 >= min_delta = 0.01. New best score: 2.324


Epoch 33: 100%|██████████| 509/509 [00:49<00:00, 10.32it/s, v_num=33, train_loss_step=3.420, train_acc_step=0.000, val_loss=2.300, val_acc=0.366, train_loss_epoch=2.420, train_acc_epoch=0.330]

Metric val_loss improved by 0.020 >= min_delta = 0.01. New best score: 2.304


Epoch 34: 100%|██████████| 509/509 [00:47<00:00, 10.63it/s, v_num=33, train_loss_step=0.137, train_acc_step=1.000, val_loss=2.280, val_acc=0.381, train_loss_epoch=2.380, train_acc_epoch=0.366]

Metric val_loss improved by 0.027 >= min_delta = 0.01. New best score: 2.277


Epoch 35: 100%|██████████| 509/509 [00:35<00:00, 14.53it/s, v_num=33, train_loss_step=3.760, train_acc_step=0.000, val_loss=2.260, val_acc=0.375, train_loss_epoch=2.340, train_acc_epoch=0.369]

Metric val_loss improved by 0.019 >= min_delta = 0.01. New best score: 2.258


Epoch 36: 100%|██████████| 509/509 [00:30<00:00, 16.66it/s, v_num=33, train_loss_step=2.520, train_acc_step=0.000, val_loss=2.230, val_acc=0.383, train_loss_epoch=2.350, train_acc_epoch=0.367] 

Metric val_loss improved by 0.024 >= min_delta = 0.01. New best score: 2.234


Epoch 37: 100%|██████████| 509/509 [00:29<00:00, 17.10it/s, v_num=33, train_loss_step=2.780, train_acc_step=0.000, val_loss=2.210, val_acc=0.375, train_loss_epoch=2.300, train_acc_epoch=0.378]

Metric val_loss improved by 0.020 >= min_delta = 0.01. New best score: 2.214


Epoch 38: 100%|██████████| 509/509 [00:31<00:00, 16.34it/s, v_num=33, train_loss_step=2.730, train_acc_step=0.000, val_loss=2.200, val_acc=0.378, train_loss_epoch=2.310, train_acc_epoch=0.373]

Metric val_loss improved by 0.017 >= min_delta = 0.01. New best score: 2.198


Epoch 39: 100%|██████████| 509/509 [00:31<00:00, 16.40it/s, v_num=33, train_loss_step=0.0928, train_acc_step=1.000, val_loss=2.170, val_acc=0.378, train_loss_epoch=2.250, train_acc_epoch=0.380]

Metric val_loss improved by 0.027 >= min_delta = 0.01. New best score: 2.170


Epoch 40: 100%|██████████| 509/509 [00:32<00:00, 15.54it/s, v_num=33, train_loss_step=3.060, train_acc_step=0.000, val_loss=2.140, val_acc=0.392, train_loss_epoch=2.210, train_acc_epoch=0.393] 

Metric val_loss improved by 0.027 >= min_delta = 0.01. New best score: 2.143


Epoch 41: 100%|██████████| 509/509 [00:29<00:00, 17.35it/s, v_num=33, train_loss_step=1.520, train_acc_step=1.000, val_loss=2.120, val_acc=0.410, train_loss_epoch=2.210, train_acc_epoch=0.398] 

Metric val_loss improved by 0.019 >= min_delta = 0.01. New best score: 2.124


Epoch 42: 100%|██████████| 509/509 [00:29<00:00, 17.40it/s, v_num=33, train_loss_step=2.020, train_acc_step=1.000, val_loss=2.100, val_acc=0.419, train_loss_epoch=2.180, train_acc_epoch=0.408]

Metric val_loss improved by 0.019 >= min_delta = 0.01. New best score: 2.105


Epoch 43: 100%|██████████| 509/509 [00:30<00:00, 16.92it/s, v_num=33, train_loss_step=1.870, train_acc_step=0.000, val_loss=2.090, val_acc=0.419, train_loss_epoch=2.150, train_acc_epoch=0.418]

Metric val_loss improved by 0.020 >= min_delta = 0.01. New best score: 2.085


Epoch 44: 100%|██████████| 509/509 [00:52<00:00,  9.63it/s, v_num=33, train_loss_step=2.560, train_acc_step=0.000, val_loss=2.060, val_acc=0.425, train_loss_epoch=2.150, train_acc_epoch=0.425]

Metric val_loss improved by 0.021 >= min_delta = 0.01. New best score: 2.064


Epoch 45: 100%|██████████| 509/509 [00:29<00:00, 17.07it/s, v_num=33, train_loss_step=0.438, train_acc_step=1.000, val_loss=2.040, val_acc=0.434, train_loss_epoch=2.100, train_acc_epoch=0.432] 

Metric val_loss improved by 0.021 >= min_delta = 0.01. New best score: 2.043


Epoch 46: 100%|██████████| 509/509 [00:31<00:00, 16.23it/s, v_num=33, train_loss_step=3.930, train_acc_step=0.000, val_loss=2.030, val_acc=0.445, train_loss_epoch=2.090, train_acc_epoch=0.431] 

Metric val_loss improved by 0.013 >= min_delta = 0.01. New best score: 2.030


Epoch 47: 100%|██████████| 509/509 [00:29<00:00, 17.18it/s, v_num=33, train_loss_step=2.510, train_acc_step=0.000, val_loss=2.000, val_acc=0.457, train_loss_epoch=2.040, train_acc_epoch=0.437]

Metric val_loss improved by 0.027 >= min_delta = 0.01. New best score: 2.002


Epoch 48: 100%|██████████| 509/509 [00:30<00:00, 16.76it/s, v_num=33, train_loss_step=0.618, train_acc_step=1.000, val_loss=1.980, val_acc=0.463, train_loss_epoch=2.040, train_acc_epoch=0.461] 

Metric val_loss improved by 0.026 >= min_delta = 0.01. New best score: 1.976


Epoch 50: 100%|██████████| 509/509 [00:30<00:00, 16.74it/s, v_num=33, train_loss_step=2.870, train_acc_step=0.000, val_loss=1.950, val_acc=0.463, train_loss_epoch=2.000, train_acc_epoch=0.462]

Metric val_loss improved by 0.023 >= min_delta = 0.01. New best score: 1.954


Epoch 51: 100%|██████████| 509/509 [00:28<00:00, 17.76it/s, v_num=33, train_loss_step=0.714, train_acc_step=1.000, val_loss=1.910, val_acc=0.478, train_loss_epoch=1.940, train_acc_epoch=0.453]

Metric val_loss improved by 0.048 >= min_delta = 0.01. New best score: 1.905


Epoch 53: 100%|██████████| 509/509 [00:29<00:00, 17.12it/s, v_num=33, train_loss_step=2.190, train_acc_step=0.000, val_loss=1.880, val_acc=0.472, train_loss_epoch=1.920, train_acc_epoch=0.467]

Metric val_loss improved by 0.021 >= min_delta = 0.01. New best score: 1.884


Epoch 55: 100%|██████████| 509/509 [00:27<00:00, 18.58it/s, v_num=33, train_loss_step=2.880, train_acc_step=0.000, val_loss=1.830, val_acc=0.501, train_loss_epoch=1.850, train_acc_epoch=0.488] 

Metric val_loss improved by 0.051 >= min_delta = 0.01. New best score: 1.833


Epoch 56: 100%|██████████| 509/509 [00:28<00:00, 17.63it/s, v_num=33, train_loss_step=1.370, train_acc_step=1.000, val_loss=1.820, val_acc=0.501, train_loss_epoch=1.860, train_acc_epoch=0.477] 

Metric val_loss improved by 0.012 >= min_delta = 0.01. New best score: 1.822


Epoch 57: 100%|██████████| 509/509 [00:29<00:00, 17.05it/s, v_num=33, train_loss_step=2.020, train_acc_step=0.000, val_loss=1.800, val_acc=0.516, train_loss_epoch=1.830, train_acc_epoch=0.488]

Metric val_loss improved by 0.024 >= min_delta = 0.01. New best score: 1.797


Epoch 58: 100%|██████████| 509/509 [00:28<00:00, 17.56it/s, v_num=33, train_loss_step=3.910, train_acc_step=0.000, val_loss=1.770, val_acc=0.522, train_loss_epoch=1.810, train_acc_epoch=0.505]

Metric val_loss improved by 0.030 >= min_delta = 0.01. New best score: 1.767


Epoch 60: 100%|██████████| 509/509 [00:29<00:00, 17.02it/s, v_num=33, train_loss_step=2.760, train_acc_step=0.000, val_loss=1.730, val_acc=0.528, train_loss_epoch=1.750, train_acc_epoch=0.511] 

Metric val_loss improved by 0.039 >= min_delta = 0.01. New best score: 1.728


Epoch 62: 100%|██████████| 509/509 [00:29<00:00, 17.15it/s, v_num=33, train_loss_step=0.627, train_acc_step=1.000, val_loss=1.710, val_acc=0.528, train_loss_epoch=1.680, train_acc_epoch=0.543] 

Metric val_loss improved by 0.016 >= min_delta = 0.01. New best score: 1.712


Epoch 63: 100%|██████████| 509/509 [00:28<00:00, 17.83it/s, v_num=33, train_loss_step=3.270, train_acc_step=0.000, val_loss=1.680, val_acc=0.534, train_loss_epoch=1.680, train_acc_epoch=0.527]

Metric val_loss improved by 0.028 >= min_delta = 0.01. New best score: 1.685


Epoch 65: 100%|██████████| 509/509 [00:29<00:00, 17.08it/s, v_num=33, train_loss_step=2.070, train_acc_step=1.000, val_loss=1.650, val_acc=0.543, train_loss_epoch=1.650, train_acc_epoch=0.551] 

Metric val_loss improved by 0.033 >= min_delta = 0.01. New best score: 1.652


Epoch 66: 100%|██████████| 509/509 [00:30<00:00, 16.95it/s, v_num=33, train_loss_step=2.150, train_acc_step=0.000, val_loss=1.620, val_acc=0.558, train_loss_epoch=1.640, train_acc_epoch=0.550]

Metric val_loss improved by 0.029 >= min_delta = 0.01. New best score: 1.624


Epoch 68: 100%|██████████| 509/509 [00:29<00:00, 17.38it/s, v_num=33, train_loss_step=1.350, train_acc_step=1.000, val_loss=1.610, val_acc=0.558, train_loss_epoch=1.600, train_acc_epoch=0.566]

Metric val_loss improved by 0.016 >= min_delta = 0.01. New best score: 1.607


Epoch 70: 100%|██████████| 509/509 [00:55<00:00,  9.18it/s, v_num=33, train_loss_step=1.400, train_acc_step=0.000, val_loss=1.570, val_acc=0.575, train_loss_epoch=1.540, train_acc_epoch=0.572] 

Metric val_loss improved by 0.039 >= min_delta = 0.01. New best score: 1.568


Epoch 71: 100%|██████████| 509/509 [01:11<00:00,  7.10it/s, v_num=33, train_loss_step=0.778, train_acc_step=1.000, val_loss=1.550, val_acc=0.572, train_loss_epoch=1.510, train_acc_epoch=0.589] 

Metric val_loss improved by 0.015 >= min_delta = 0.01. New best score: 1.553


Epoch 74: 100%|██████████| 509/509 [00:49<00:00, 10.20it/s, v_num=33, train_loss_step=0.735, train_acc_step=1.000, val_loss=1.530, val_acc=0.572, train_loss_epoch=1.490, train_acc_epoch=0.592] 

Metric val_loss improved by 0.021 >= min_delta = 0.01. New best score: 1.533


Epoch 75: 100%|██████████| 509/509 [00:45<00:00, 11.08it/s, v_num=33, train_loss_step=0.112, train_acc_step=1.000, val_loss=1.520, val_acc=0.584, train_loss_epoch=1.460, train_acc_epoch=0.591]

Metric val_loss improved by 0.012 >= min_delta = 0.01. New best score: 1.521


Epoch 76: 100%|██████████| 509/509 [00:29<00:00, 17.22it/s, v_num=33, train_loss_step=0.294, train_acc_step=1.000, val_loss=1.490, val_acc=0.578, train_loss_epoch=1.440, train_acc_epoch=0.592]

Metric val_loss improved by 0.031 >= min_delta = 0.01. New best score: 1.490


Epoch 77: 100%|██████████| 509/509 [00:29<00:00, 17.48it/s, v_num=33, train_loss_step=0.585, train_acc_step=1.000, val_loss=1.470, val_acc=0.593, train_loss_epoch=1.410, train_acc_epoch=0.605]

Metric val_loss improved by 0.019 >= min_delta = 0.01. New best score: 1.471


Epoch 80: 100%|██████████| 509/509 [00:29<00:00, 17.21it/s, v_num=33, train_loss_step=2.460, train_acc_step=0.000, val_loss=1.430, val_acc=0.593, train_loss_epoch=1.360, train_acc_epoch=0.624] 

Metric val_loss improved by 0.043 >= min_delta = 0.01. New best score: 1.428


Epoch 83: 100%|██████████| 509/509 [00:30<00:00, 16.71it/s, v_num=33, train_loss_step=0.442, train_acc_step=1.000, val_loss=1.410, val_acc=0.599, train_loss_epoch=1.330, train_acc_epoch=0.641] 

Metric val_loss improved by 0.014 >= min_delta = 0.01. New best score: 1.414


Epoch 84: 100%|██████████| 509/509 [00:28<00:00, 17.61it/s, v_num=33, train_loss_step=1.440, train_acc_step=1.000, val_loss=1.400, val_acc=0.599, train_loss_epoch=1.290, train_acc_epoch=0.654] 

Metric val_loss improved by 0.016 >= min_delta = 0.01. New best score: 1.398


Epoch 87: 100%|██████████| 509/509 [00:29<00:00, 17.06it/s, v_num=33, train_loss_step=0.717, train_acc_step=1.000, val_loss=1.370, val_acc=0.590, train_loss_epoch=1.230, train_acc_epoch=0.661] 

Metric val_loss improved by 0.025 >= min_delta = 0.01. New best score: 1.374


Epoch 89: 100%|██████████| 509/509 [00:29<00:00, 17.35it/s, v_num=33, train_loss_step=0.851, train_acc_step=1.000, val_loss=1.340, val_acc=0.619, train_loss_epoch=1.210, train_acc_epoch=0.646] 

Metric val_loss improved by 0.032 >= min_delta = 0.01. New best score: 1.342


Epoch 90: 100%|██████████| 509/509 [00:40<00:00, 12.71it/s, v_num=33, train_loss_step=2.330, train_acc_step=0.000, val_loss=1.320, val_acc=0.619, train_loss_epoch=1.220, train_acc_epoch=0.660] 

Metric val_loss improved by 0.020 >= min_delta = 0.01. New best score: 1.322


Epoch 91: 100%|██████████| 509/509 [00:49<00:00, 10.38it/s, v_num=33, train_loss_step=0.953, train_acc_step=1.000, val_loss=1.300, val_acc=0.625, train_loss_epoch=1.180, train_acc_epoch=0.677]

Metric val_loss improved by 0.024 >= min_delta = 0.01. New best score: 1.298


Epoch 95: 100%|██████████| 509/509 [00:30<00:00, 16.85it/s, v_num=33, train_loss_step=1.310, train_acc_step=1.000, val_loss=1.280, val_acc=0.640, train_loss_epoch=1.130, train_acc_epoch=0.667] 

Metric val_loss improved by 0.018 >= min_delta = 0.01. New best score: 1.280


Epoch 96: 100%|██████████| 509/509 [00:47<00:00, 10.82it/s, v_num=33, train_loss_step=2.080, train_acc_step=0.000, val_loss=1.270, val_acc=0.652, train_loss_epoch=1.120, train_acc_epoch=0.695] 

Metric val_loss improved by 0.013 >= min_delta = 0.01. New best score: 1.267


Epoch 98: 100%|██████████| 509/509 [00:55<00:00,  9.18it/s, v_num=33, train_loss_step=0.840, train_acc_step=1.000, val_loss=1.250, val_acc=0.646, train_loss_epoch=1.100, train_acc_epoch=0.684] 

Metric val_loss improved by 0.016 >= min_delta = 0.01. New best score: 1.251


Epoch 101: 100%|██████████| 509/509 [00:28<00:00, 17.55it/s, v_num=33, train_loss_step=0.958, train_acc_step=1.000, val_loss=1.230, val_acc=0.652, train_loss_epoch=1.030, train_acc_epoch=0.712] 

Metric val_loss improved by 0.017 >= min_delta = 0.01. New best score: 1.234


Epoch 102: 100%|██████████| 509/509 [00:29<00:00, 17.42it/s, v_num=33, train_loss_step=0.184, train_acc_step=1.000, val_loss=1.220, val_acc=0.664, train_loss_epoch=1.020, train_acc_epoch=0.713] 

Metric val_loss improved by 0.015 >= min_delta = 0.01. New best score: 1.219


Epoch 106: 100%|██████████| 509/509 [00:30<00:00, 16.71it/s, v_num=33, train_loss_step=0.115, train_acc_step=1.000, val_loss=1.200, val_acc=0.652, train_loss_epoch=0.978, train_acc_epoch=0.730] 

Metric val_loss improved by 0.018 >= min_delta = 0.01. New best score: 1.201


Epoch 108: 100%|██████████| 509/509 [00:32<00:00, 15.64it/s, v_num=33, train_loss_step=0.515, train_acc_step=1.000, val_loss=1.190, val_acc=0.661, train_loss_epoch=0.962, train_acc_epoch=0.735] 

Metric val_loss improved by 0.011 >= min_delta = 0.01. New best score: 1.190


Epoch 109: 100%|██████████| 509/509 [00:31<00:00, 16.11it/s, v_num=33, train_loss_step=0.913, train_acc_step=1.000, val_loss=1.170, val_acc=0.655, train_loss_epoch=0.953, train_acc_epoch=0.727] 

Metric val_loss improved by 0.015 >= min_delta = 0.01. New best score: 1.175


Epoch 111: 100%|██████████| 509/509 [00:32<00:00, 15.75it/s, v_num=33, train_loss_step=0.313, train_acc_step=1.000, val_loss=1.160, val_acc=0.684, train_loss_epoch=0.936, train_acc_epoch=0.744] 

Metric val_loss improved by 0.015 >= min_delta = 0.01. New best score: 1.160


Epoch 116: 100%|██████████| 509/509 [00:31<00:00, 16.25it/s, v_num=33, train_loss_step=0.230, train_acc_step=1.000, val_loss=1.140, val_acc=0.690, train_loss_epoch=0.853, train_acc_epoch=0.746]  

Metric val_loss improved by 0.019 >= min_delta = 0.01. New best score: 1.141


Epoch 117: 100%|██████████| 509/509 [00:36<00:00, 13.87it/s, v_num=33, train_loss_step=0.134, train_acc_step=1.000, val_loss=1.130, val_acc=0.702, train_loss_epoch=0.853, train_acc_epoch=0.759] 

Metric val_loss improved by 0.013 >= min_delta = 0.01. New best score: 1.128


Epoch 122: 100%|██████████| 509/509 [00:32<00:00, 15.45it/s, v_num=33, train_loss_step=0.227, train_acc_step=1.000, val_loss=1.110, val_acc=0.717, train_loss_epoch=0.830, train_acc_epoch=0.766] 

Metric val_loss improved by 0.019 >= min_delta = 0.01. New best score: 1.109


Epoch 125: 100%|██████████| 509/509 [00:34<00:00, 14.56it/s, v_num=33, train_loss_step=0.993, train_acc_step=1.000, val_loss=1.090, val_acc=0.714, train_loss_epoch=0.767, train_acc_epoch=0.794]  

Metric val_loss improved by 0.018 >= min_delta = 0.01. New best score: 1.090


Epoch 128: 100%|██████████| 509/509 [00:30<00:00, 16.57it/s, v_num=33, train_loss_step=0.416, train_acc_step=1.000, val_loss=1.080, val_acc=0.723, train_loss_epoch=0.760, train_acc_epoch=0.794] 

Metric val_loss improved by 0.015 >= min_delta = 0.01. New best score: 1.076


Epoch 133: 100%|██████████| 509/509 [00:37<00:00, 13.72it/s, v_num=33, train_loss_step=0.695, train_acc_step=1.000, val_loss=1.060, val_acc=0.732, train_loss_epoch=0.699, train_acc_epoch=0.809] 

Metric val_loss improved by 0.019 >= min_delta = 0.01. New best score: 1.056


Epoch 136: 100%|██████████| 509/509 [00:32<00:00, 15.53it/s, v_num=33, train_loss_step=0.0896, train_acc_step=1.000, val_loss=1.040, val_acc=0.735, train_loss_epoch=0.701, train_acc_epoch=0.805]

Metric val_loss improved by 0.014 >= min_delta = 0.01. New best score: 1.042


Epoch 140: 100%|██████████| 509/509 [00:34<00:00, 14.96it/s, v_num=33, train_loss_step=0.905, train_acc_step=1.000, val_loss=1.030, val_acc=0.720, train_loss_epoch=0.648, train_acc_epoch=0.818] 

Metric val_loss improved by 0.011 >= min_delta = 0.01. New best score: 1.031


Epoch 142: 100%|██████████| 509/509 [00:35<00:00, 14.20it/s, v_num=33, train_loss_step=0.574, train_acc_step=1.000, val_loss=1.010, val_acc=0.743, train_loss_epoch=0.642, train_acc_epoch=0.819] 

Metric val_loss improved by 0.016 >= min_delta = 0.01. New best score: 1.015


Epoch 147: 100%|██████████| 509/509 [00:30<00:00, 16.95it/s, v_num=33, train_loss_step=0.0457, train_acc_step=1.000, val_loss=1.000, val_acc=0.746, train_loss_epoch=0.616, train_acc_epoch=0.829]

Metric val_loss improved by 0.014 >= min_delta = 0.01. New best score: 1.001


Epoch 151: 100%|██████████| 509/509 [00:32<00:00, 15.55it/s, v_num=33, train_loss_step=0.0524, train_acc_step=1.000, val_loss=0.987, val_acc=0.743, train_loss_epoch=0.603, train_acc_epoch=0.839]

Metric val_loss improved by 0.014 >= min_delta = 0.01. New best score: 0.987


Epoch 157: 100%|██████████| 509/509 [00:31<00:00, 16.03it/s, v_num=33, train_loss_step=0.999, train_acc_step=1.000, val_loss=0.972, val_acc=0.752, train_loss_epoch=0.537, train_acc_epoch=0.852]  

Metric val_loss improved by 0.015 >= min_delta = 0.01. New best score: 0.972


Epoch 166: 100%|██████████| 509/509 [00:31<00:00, 15.97it/s, v_num=33, train_loss_step=3.010, train_acc_step=0.000, val_loss=0.961, val_acc=0.746, train_loss_epoch=0.488, train_acc_epoch=0.871]  

Metric val_loss improved by 0.011 >= min_delta = 0.01. New best score: 0.961


Epoch 176: 100%|██████████| 509/509 [00:34<00:00, 14.92it/s, v_num=33, train_loss_step=0.082, train_acc_step=1.000, val_loss=0.990, val_acc=0.735, train_loss_epoch=0.425, train_acc_epoch=0.892]  

Monitored metric val_loss did not improve in the last 10 records. Best score: 0.961. Signaling Trainer to stop.


Epoch 176: 100%|██████████| 509/509 [00:35<00:00, 14.15it/s, v_num=33, train_loss_step=0.082, train_acc_step=1.000, val_loss=0.990, val_acc=0.735, train_loss_epoch=0.425, train_acc_epoch=0.892]


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
c:\Users\yohanes.setiawan\AppData\Local\miniconda3\envs\wmmd_env\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:425: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=21` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 170/170 [00:10<00:00, 15.80it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.7109144330024719
         test_f1            0.6509345173835754
        test_loss           1.0668379068374634
     test_precision         0.7045230865478516
       test_recall          0.6248770952224731
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Artifacts saved to model\mae-ast_imbalance\20250702_170155/
Completed mae-ast (FT), artifacts in model\mae-ast_imbalance\20250702_170155
