# AASIST: Audio Anti-Spoofing Detection with Custom Dataset

This Jupyter Notebook trains and evaluates the AASIST model for audio spoofing detection using a custom dataset. The dataset is organized with `train`, `val`, and `test` folders, each containing `real` and `fake` subfolders with `.flac` audio files. Each audio clip is processed to a fixed length of 2 seconds (32,000 samples at 16kHz).

## Setup Instructions
- Ensure the dataset is available at `./dataset/` (or update `database_path` in the config).
- The dataset should have the following structure:
  ```
  dataset/
  ├── train/
  │   ├── real/
  │   │   ├── audio1.flac
  │   │   └── ...
  │   └── fake/
  │       ├── audio1.flac
  │       └── ...
  ├── val/
  │   ├── real/
  │   │   ├── audio1.flac
  │   │   └── ...
  │   └── fake/
  │       ├── audio1.flac
  │       └── ...
  ├── test/
  │   ├── real/
  │   │   ├── audio1.flac
  │   │   └── ...
  │   └── fake/
  │       ├── audio1.flac
  │       └── ...
  ```
- Install required dependencies (see next cell).
- Modify `output_dir` or other paths as needed.
- Run cells sequentially to train or evaluate the model.

## Dependencies
The following cell installs required Python packages.

In [1]:
!pip install soundfile torchcontrib

Collecting torchcontrib
  Downloading torchcontrib-0.0.2.tar.gz (11 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: torchcontrib
  Building wheel for torchcontrib (setup.py) ... [?25l[?25hdone
  Created wheel for torchcontrib: filename=torchcontrib-0.0.2-py3-none-any.whl size=7516 sha256=84361530486d758e706887165c7fe4c5a1e148463c894a5527819e3c12386936
  Stored in directory: /root/.cache/pip/wheels/f1/87/f6/b3c995670297d282da49c39ea210c39fc8089c27f453bc1c42
Successfully built torchcontrib
Installing collected packages: torchcontrib
Successfully installed torchcontrib-0.0.2


## Imports
Import all necessary libraries.

In [2]:
import os
import random
import sys
import json
from pathlib import Path
from shutil import copy
from typing import Dict, List, Union
import warnings

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchcontrib.optim import SWA
from tqdm.notebook import tqdm
import librosa
import soundfile as sf

warnings.filterwarnings("ignore", category=FutureWarning)

2025-12-09 06:01:17.629184: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1765260077.809635      19 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1765260077.862303      19 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


## Configuration
Define the configuration for the experiment. This is based on the provided JSON config, updated for the custom dataset and 2-second input length.

In [3]:
config = {
    "database_path": "/kaggle/input/asv-19-aa/ASV19",  # Path to the main dataset folder
    "model_path": "/kaggle/working/exp_result/LA_custom_ep50_bs12/weights/best.pth",  # Path to save/load model weights
    "batch_size": 12,
    "num_epochs": 50,
    "target_sr": 16000,
    "loss": "CCE",
    "track": "LA",  # Kept for compatibility, though not used for custom dataset
    "eval_all_best": "True",
    "eval_output": "eval_scores_using_best_dev_model.txt",
    "cudnn_deterministic_toggle": "True",
    "cudnn_benchmark_toggle": "False",
    "model_config": {
        "architecture": "AASIST",
        "nb_samp": 32000,  # 2 seconds at 16kHz
        "first_conv": 128,
        "filts": [70, [1, 32], [32, 32], [32, 64], [64, 64]],
        "gat_dims": [64, 32],
        "pool_ratios": [0.5, 0.7, 0.5, 0.5],
        "temperatures": [2.0, 2.0, 100.0, 100.0]
    },
    "optim_config": {
        "optimizer": "adam",
        "amsgrad": "False",
        "base_lr": 0.0001,
        "lr_min": 0.000005,
        "betas": [0.9, 0.999],
        "weight_decay": 0.0001,
        "scheduler": "cosine",
        "epochs": 50 
    }
}

## Utility Functions
Functions for optimization, scheduling, and reproducibility (from `utils.py`).

In [4]:
# def str_to_bool(val):
#     """Convert a string representation of truth to true (1) or false (0)."""
#     val = val.lower()
#     if val in ('y', 'yes', 't', 'true', 'on', '1'):
#         return True
#     if val in ('n', 'no', 'f', 'false', 'off', '0'):
#         return False
#     raise ValueError('invalid truth value {}'.format(val))

# def cosine_annealing(step, total_steps, lr_max, lr_min):
#     """Cosine Annealing for learning rate decay scheduler"""
#     return lr_min + (lr_max - lr_min) * 0.5 * (1 + np.cos(step / total_steps * np.pi))

# def keras_decay(step, decay=0.0001):
#     """Learning rate decay in Keras-style"""
#     return 1. / (1. + decay * step)

# class SGDRScheduler(torch.optim.lr_scheduler._LRScheduler):
#     """SGD with restarts scheduler"""
#     def __init__(self, optimizer, T0, T_mul, eta_min, last_epoch=-1):
#         self.Ti = T0
#         self.T_mul = T_mul
#         self.eta_min = eta_min
#         self.last_restart = 0
#         super().__init__(optimizer, last_epoch)

#     def get_lr(self):
#         T_cur = self.last_epoch - self.last_restart
#         if T_cur >= self.Ti:
#             self.last_restart = self.last_epoch
#             self.Ti = self.Ti * self.T_mul
#             T_cur = 0
#         return [self.eta_min + (base_lr - self.eta_min) * (1 + np.cos(np.pi * T_cur / self.Ti)) / 2 for base_lr in self.base_lrs]

# def _get_optimizer(model_parameters, optim_config):
#     """Defines optimizer according to the given config"""
#     optimizer_name = optim_config['optimizer']
#     if optimizer_name == 'sgd':
#         optimizer = torch.optim.SGD(model_parameters, lr=optim_config['base_lr'], momentum=optim_config['momentum'],
#                                     weight_decay=optim_config['weight_decay'], nesterov=optim_config['nesterov'])
#     elif optimizer_name == 'adam':
#         optimizer = torch.optim.Adam(model_parameters, lr=optim_config['base_lr'], betas=optim_config['betas'],
#                                      weight_decay=optim_config['weight_decay'], amsgrad=str_to_bool(optim_config['amsgrad']))
#     else:
#         print('Un-known optimizer', optimizer_name)
#         sys.exit()
#     return optimizer

# def _get_scheduler(optimizer, optim_config):
#     """Defines learning rate scheduler according to the given config"""
#     if optim_config['scheduler'] == 'multistep':
#         scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=optim_config['milestones'], gamma=optim_config['lr_decay'])
#     elif optim_config['scheduler'] == 'sgdr':
#         scheduler = SGDRScheduler(optimizer, optim_config['T0'],
#                                   optim_config['Tmult'], optim_config['lr_min'])
#     elif optim_config['scheduler'] == 'cosine':
#         total_steps = optim_config['epochs'] * optim_config['steps_per_epoch']
#         scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: cosine_annealing(step, total_steps, 1, optim_config['lr_min'] / optim_config['base_lr']))
#     elif optim_config['scheduler'] == 'keras_decay':
#         scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: keras_decay(step))
#     else:
#         scheduler = None
#     return scheduler

# def create_optimizer(model_parameters, optim_config):
#     """Defines an optimizer and a scheduler"""
#     optimizer = _get_optimizer(model_parameters, optim_config)
#     scheduler = _get_scheduler(optimizer, optim_config)
#     return optimizer, scheduler

# def seed_worker(worker_id):
#     """Used in generating seed for the worker of torch.utils.data.Dataloader"""
#     worker_seed = torch.initial_seed() % 2**32
#     np.random.seed(worker_seed)
#     random.seed(worker_seed)

# def set_seed(seed, config=None):
#     """Set initial seed for reproduction"""
#     if config is None:
#         raise ValueError("config should not be None")
#     random.seed(seed)
#     np.random.seed(seed)
#     torch.manual_seed(seed)
#     if torch.cuda.is_available():
#         torch.cuda.manual_seed_all(seed)
#         torch.backends.cudnn.deterministic = str_to_bool(config["cudnn_deterministic_toggle"])
#         torch.backends.cudnn.benchmark = str_to_bool(config["cudnn_benchmark_toggle"])

In [5]:
import torch
import torch.nn as nn
import numpy as np
import random
import sys

def str_to_bool(val):
    """Convert a string representation of truth to true (1) or false (0)."""
    val = val.lower()
    if val in ('y', 'yes', 't', 'true', 'on', '1'):
        return True
    if val in ('n', 'no', 'f', 'false', 'off', '0'):
        return False
    raise ValueError('invalid truth value {}'.format(val))

def cosine_annealing(step, total_steps, lr_max, lr_min):
    """Cosine Annealing for learning rate decay scheduler"""
    return lr_min + (lr_max - lr_min) * 0.5 * (1 + np.cos(step / total_steps * np.pi))

def keras_decay(step, decay=0.0001):
    """Learning rate decay in Keras-style"""
    return 1. / (1. + decay * step)

class SGDRScheduler(torch.optim.lr_scheduler._LRScheduler):
    """SGD with restarts scheduler"""
    def __init__(self, optimizer, T0, T_mul, eta_min, last_epoch=-1):
        self.Ti = T0
        self.T_mul = T_mul
        self.eta_min = eta_min
        self.last_restart = 0
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        T_cur = self.last_epoch - self.last_restart
        if T_cur >= self.Ti:
            self.last_restart = self.last_epoch
            self.Ti = self.Ti * self.T_mul
            T_cur = 0
        return [self.eta_min + (base_lr - self.eta_min) * (1 + np.cos(np.pi * T_cur / self.Ti)) / 2 for base_lr in self.base_lrs]

def _get_optimizer(model_parameters, optim_config):
    """Defines optimizer according to the given config"""
    optimizer_name = optim_config['optimizer']
    if optimizer_name == 'sgd':
        optimizer = torch.optim.SGD(model_parameters, lr=optim_config['base_lr'], momentum=optim_config['momentum'],
                                    weight_decay=optim_config['weight_decay'], nesterov=optim_config['nesterov'])
    elif optimizer_name == 'adam':
        optimizer = torch.optim.Adam(model_parameters, lr=optim_config['base_lr'], betas=optim_config['betas'],
                                     weight_decay=optim_config['weight_decay'], amsgrad=str_to_bool(optim_config['amsgrad']))
    else:
        print('Un-known optimizer', optimizer_name)
        sys.exit()
    return optimizer

def _get_scheduler(optimizer, optim_config):
    """Defines learning rate scheduler according to the given config"""
    if optim_config['scheduler'] == 'multistep':
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=optim_config['milestones'], gamma=optim_config['lr_decay'])
    elif optim_config['scheduler'] == 'sgdr':
        scheduler = SGDRScheduler(optimizer, optim_config['T0'],
                                  optim_config['Tmult'], optim_config['lr_min'])
    elif optim_config['scheduler'] == 'cosine':
        total_steps = optim_config['epochs'] * optim_config['steps_per_epoch']
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: cosine_annealing(step, total_steps, 1, optim_config['lr_min'] / optim_config['base_lr']))
    elif optim_config['scheduler'] == 'keras_decay':
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: keras_decay(step))
    else:
        scheduler = None
    return scheduler

def create_optimizer(model_parameters, optim_config):
    """Defines an optimizer and a scheduler"""
    optimizer = _get_optimizer(model_parameters, optim_config)
    scheduler = _get_scheduler(optimizer, optim_config)
    return optimizer, scheduler

def seed_worker(worker_id):
    """Used in generating seed for the worker of torch.utils.data.Dataloader"""
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

def set_seed(seed, config=None):
    """Set initial seed for reproduction"""
    if config is None:
        raise ValueError("config should not be None")
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = str_to_bool(config["cudnn_deterministic_toggle"])
        torch.backends.cudnn.benchmark = str_to_bool(config["cudnn_benchmark_toggle"])

def prepare_model(model):
    """Moves model to CUDA and wraps in DataParallel if multiple GPUs available"""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    if torch.cuda.device_count() > 1:
        print(f"[INFO] Using {torch.cuda.device_count()} GPUs with DataParallel")
        model = torch.nn.DataParallel(model)

    return model, device


In [6]:
print(torch.cuda.device_count())  # Should return 2
print(torch.cuda.get_device_name(0))
print(torch.cuda.get_device_name(1))


2
Tesla T4
Tesla T4


## Data Utilities
Functions and classes for loading and preprocessing the custom dataset (from `data_utils.py`, adapted for `.flac` files and 2-second inputs).

In [7]:
def pad(x: np.ndarray, max_len: int = 32000) -> np.ndarray:
    x_len = x.shape[0]
    if x_len >= max_len:
        return x[:max_len]
    num_repeats = (max_len // x_len) + 1
    padded_x = np.tile(x, num_repeats)[:max_len]
    return padded_x

def pad_random(x: np.ndarray, max_len: int = 32000) -> np.ndarray:
    x_len = x.shape[0]
    if x_len >= max_len:
        if x_len == max_len:
            return x  # No need to select a random segment if lengths are equal
        stt = np.random.randint(0, x_len - max_len + 1)  # Ensure valid range
        return x[stt:stt + max_len]
    num_repeats = (max_len // x_len) + 1
    padded_x = np.tile(x, num_repeats)[:max_len]
    return padded_x

class Dataset_Custom(Dataset):
    def __init__(self, base_dir, split='train', target_sr=16000):
        """Custom dataset for loading .flac files from train/val/test folders.
        Args:
            base_dir (str): Path to the main dataset folder (containing train/val/test).
            split (str): One of 'train', 'val', or 'test'.
            target_sr (int): Target sample rate (default: 16000).
        """
        self.base_dir = Path(base_dir) / split
        self.target_sr = target_sr
        self.cut = 32000  # 2 seconds at 16kHz

        # Initialize file lists and labels
        self.file_list = []
        self.labels = {}

        # Load real files (label: 1)
        real_dir = self.base_dir / 'bonafide'
        if real_dir.exists():
            for file in real_dir.glob('*.flac'):
                key = file.stem
                self.file_list.append((key, 'bonafide'))
                self.labels[key] = 1

        # Load fake files (label: 0)
        fake_dir = self.base_dir / 'spoof'
        if fake_dir.exists():
            for file in fake_dir.glob('*.flac'):
                key = file.stem
                self.file_list.append((key, 'spoof'))
                self.labels[key] = 0

        if not self.file_list:
            raise ValueError(f"No .flac files found in {self.base_dir}")

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, index):
        max_retries = 5
        retries = 0
        while retries < max_retries:
            key, folder = self.file_list[index]
            file_path = self.base_dir / folder / f"{key}.flac"
            try:
                X, _ = sf.read(str(file_path))
                if len(X.shape) > 1:  # If multi-channel, take first channel
                    X = X[:, 0]
                X_pad = pad_random(X, self.cut)
                if np.max(np.abs(X_pad)) == 0:  # Handle zero arrays
                    raise ValueError(f"Audio file {file_path} is silent or invalid")
                X_pad = X_pad / np.max(np.abs(X_pad))  # Normalize
                x_inp = Tensor(X_pad)
                y = self.labels[key]
                return x_inp, y
            except Exception as e:
                print(f"Error loading {file_path}: {e}")
                retries += 1
                index = np.random.randint(len(self.file_list))  # Try a new random index
        raise RuntimeError(f"Failed to load valid data after {max_retries} attempts")
        
    def get_key(self, index):
        key, _ = self.file_list[index]
        return key

## AASIST Model
The AASIST model architecture (from `AASIST.py`), which processes raw audio flaceforms.

- Classic AASIST

In [8]:
class GraphAttentionLayer(nn.Module):
    def __init__(self, in_dim, out_dim, **kwargs):
        super().__init__()
        self.att_proj = nn.Linear(in_dim, out_dim)
        self.att_weight = self._init_new_params(out_dim, 1)
        self.proj_with_att = nn.Linear(in_dim, out_dim)
        self.proj_without_att = nn.Linear(in_dim, out_dim)
        self.bn = nn.BatchNorm1d(out_dim)
        self.input_drop = nn.Dropout(p=0.2)
        self.act = nn.SELU(inplace=True)
        self.temp = 1.
        if "temperature" in kwargs:
            self.temp = kwargs["temperature"]

    def forward(self, x):
        x = self.input_drop(x)
        att_map = self._derive_att_map(x)
        x = self._project(x, att_map)
        x = self._apply_BN(x)
        x = self.act(x)
        return x

    def _pairwise_mul_nodes(self, x):
        nb_nodes = x.size(1)
        x = x.unsqueeze(2).expand(-1, -1, nb_nodes, -1)
        x_mirror = x.transpose(1, 2)
        return x * x_mirror

    def _derive_att_map(self, x):
        att_map = self._pairwise_mul_nodes(x)
        att_map = torch.tanh(self.att_proj(att_map))
        att_map = torch.matmul(att_map, self.att_weight)
        att_map = att_map / self.temp
        att_map = F.softmax(att_map, dim=-2)
        return att_map

    def _project(self, x, att_map):
        x1 = self.proj_with_att(torch.matmul(att_map.squeeze(-1), x))
        x2 = self.proj_without_att(x)
        return x1 + x2

    def _apply_BN(self, x):
        org_size = x.size()
        x = x.view(-1, org_size[-1])
        x = self.bn(x)
        x = x.view(org_size)
        return x

    def _init_new_params(self, *size):
        out = nn.Parameter(torch.FloatTensor(*size))
        nn.init.xavier_normal_(out)
        return out

class HtrgGraphAttentionLayer(nn.Module):
    def __init__(self, in_dim, out_dim, **kwargs):
        super().__init__()
        self.proj_type1 = nn.Linear(in_dim, in_dim)
        self.proj_type2 = nn.Linear(in_dim, in_dim)
        self.att_proj = nn.Linear(in_dim, out_dim)
        self.att_projM = nn.Linear(in_dim, out_dim)
        self.att_weight11 = self._init_new_params(out_dim, 1)
        self.att_weight22 = self._init_new_params(out_dim, 1)
        self.att_weight12 = self._init_new_params(out_dim, 1)
        self.att_weightM = self._init_new_params(out_dim, 1)
        self.proj_with_att = nn.Linear(in_dim, out_dim)
        self.proj_without_att = nn.Linear(in_dim, out_dim)
        self.proj_with_attM = nn.Linear(in_dim, out_dim)
        self.proj_without_attM = nn.Linear(in_dim, out_dim)
        self.bn = nn.BatchNorm1d(out_dim)
        self.input_drop = nn.Dropout(p=0.2)
        self.act = nn.SELU(inplace=True)
        self.temp = 1.
        if "temperature" in kwargs:
            self.temp = kwargs["temperature"]

    def forward(self, x1, x2, master=None):
        num_type1 = x1.size(1)
        num_type2 = x2.size(1)
        x1 = self.proj_type1(x1)
        x2 = self.proj_type2(x2)
        x = torch.cat([x1, x2], dim=1)
        if master is None:
            master = torch.mean(x, dim=1, keepdim=True)
        x = self.input_drop(x)
        att_map = self._derive_att_map(x, num_type1, num_type2)
        master = self._update_master(x, master)
        x = self._project(x, att_map)
        x = self._apply_BN(x)
        x = self.act(x)
        x1 = x.narrow(1, 0, num_type1)
        x2 = x.narrow(1, num_type1, num_type2)
        return x1, x2, master

    def _update_master(self, x, master):
        att_map = self._derive_att_map_master(x, master)
        master = self._project_master(x, master, att_map)
        return master

    def _pairwise_mul_nodes(self, x):
        nb_nodes = x.size(1)
        x = x.unsqueeze(2).expand(-1, -1, nb_nodes, -1)
        x_mirror = x.transpose(1, 2)
        return x * x_mirror

    def _derive_att_map_master(self, x, master):
        att_map = x * master
        att_map = torch.tanh(self.att_projM(att_map))
        att_map = torch.matmul(att_map, self.att_weightM)
        att_map = att_map / self.temp
        att_map = F.softmax(att_map, dim=-2)
        return att_map

    def _derive_att_map(self, x, num_type1, num_type2):
        att_map = self._pairwise_mul_nodes(x)
        att_map = torch.tanh(self.att_proj(att_map))
        att_board = torch.zeros_like(att_map[:, :, :, 0]).unsqueeze(-1)
        att_board[:, :num_type1, :num_type1, :] = torch.matmul(att_map[:, :num_type1, :num_type1, :], self.att_weight11)
        att_board[:, num_type1:, num_type1:, :] = torch.matmul(att_map[:, num_type1:, num_type1:, :], self.att_weight22)
        att_board[:, :num_type1, num_type1:, :] = torch.matmul(att_map[:, :num_type1, num_type1:, :], self.att_weight12)
        att_board[:, num_type1:, :num_type1, :] = torch.matmul(att_map[:, num_type1:, :num_type1, :], self.att_weight12)
        att_map = att_board
        att_map = att_map / self.temp
        att_map = F.softmax(att_map, dim=-2)
        return att_map

    def _project(self, x, att_map):
        x1 = self.proj_with_att(torch.matmul(att_map.squeeze(-1), x))
        x2 = self.proj_without_att(x)
        return x1 + x2

    def _project_master(self, x, master, att_map):
        x1 = self.proj_with_attM(torch.matmul(att_map.squeeze(-1).unsqueeze(1), x))
        x2 = self.proj_without_attM(master)
        return x1 + x2

    def _apply_BN(self, x):
        org_size = x.size()
        x = x.view(-1, org_size[-1])
        x = self.bn(x)
        x = x.view(org_size)
        return x

    def _init_new_params(self, *size):
        out = nn.Parameter(torch.FloatTensor(*size))
        nn.init.xavier_normal_(out)
        return out

class GraphPool(nn.Module):
    def __init__(self, k: float, in_dim: int, p: Union[float, int]):
        super().__init__()
        self.k = k
        self.sigmoid = nn.Sigmoid()
        self.proj = nn.Linear(in_dim, 1)
        self.drop = nn.Dropout(p=p) if p > 0 else nn.Identity()
        self.in_dim = in_dim

    def forward(self, h):
        Z = self.drop(h)
        weights = self.proj(Z)
        scores = self.sigmoid(weights)
        new_h = self.top_k_graph(scores, h, self.k)
        return new_h

    def top_k_graph(self, scores, h, k):
        _, n_nodes, n_feat = h.size()
        n_nodes = max(int(n_nodes * k), 1)
        _, idx = torch.topk(scores, n_nodes, dim=1)
        idx = idx.expand(-1, -1, n_feat)
        h = h * scores
        h = torch.gather(h, 1, idx)
        return h

class CONV(nn.Module):
    @staticmethod
    def to_mel(hz):
        return 2595 * np.log10(1 + hz / 700)

    @staticmethod
    def to_hz(mel):
        return 700 * (10**(mel / 2595) - 1)

    def __init__(self, out_channels, kernel_size, sample_rate=16000, in_channels=1, stride=1, padding=0, dilation=1, bias=False, groups=1, mask=False):
        super().__init__()
        if in_channels != 1:
            msg = "SincConv only support one input channel (here, in_channels = {%i})" % (in_channels)
            raise ValueError(msg)
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.sample_rate = sample_rate
        if kernel_size % 2 == 0:
            self.kernel_size = self.kernel_size + 1
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.mask = mask
        if bias:
            raise ValueError('SincConv does not support bias.')
        if groups > 1:
            raise ValueError('SincConv does not support groups.')
        NFFT = 512
        f = int(self.sample_rate / 2) * np.linspace(0, 1, int(NFFT / 2) + 1)
        fmel = self.to_mel(f)
        fmelmax = np.max(fmel)
        fmelmin = np.min(fmel)
        filbandwidthsmel = np.linspace(fmelmin, fmelmax, self.out_channels + 1)
        filbandwidthsf = self.to_hz(filbandwidthsmel)
        self.mel = filbandwidthsf
        self.hsupp = torch.arange(-(self.kernel_size - 1) / 2, (self.kernel_size - 1) / 2 + 1)
        self.band_pass = torch.zeros(self.out_channels, self.kernel_size)
        for i in range(len(self.mel) - 1):
            fmin = self.mel[i]
            fmax = self.mel[i + 1]
            hHigh = (2*fmax/self.sample_rate) * np.sinc(2*fmax*self.hsupp/self.sample_rate)
            hLow = (2*fmin/self.sample_rate) * np.sinc(2*fmin*self.hsupp/self.sample_rate)
            hideal = hHigh - hLow
            self.band_pass[i, :] = Tensor(np.hamming(self.kernel_size)) * Tensor(hideal)

    def forward(self, x, mask=False):
        band_pass_filter = self.band_pass.clone().to(x.device)
        if mask:
            A = np.random.uniform(0, 20)
            A = int(A)
            A0 = random.randint(0, band_pass_filter.shape[0] - A)
            band_pass_filter[A0:A0 + A, :] = 0
        else:
            band_pass_filter = band_pass_filter
        self.filters = (band_pass_filter).view(self.out_channels, 1, self.kernel_size)
        return F.conv1d(x, self.filters, stride=self.stride, padding=self.padding, dilation=self.dilation, bias=None, groups=1)

class Residual_block(nn.Module):
    def __init__(self, nb_filts, first=False):
        super().__init__()
        self.first = first
        if not self.first:
            self.bn1 = nn.BatchNorm2d(num_features=nb_filts[0])
        self.conv1 = nn.Conv2d(in_channels=nb_filts[0], out_channels=nb_filts[1], kernel_size=(2, 3), padding=(1, 1), stride=1)
        self.selu = nn.SELU(inplace=False)
        self.bn2 = nn.BatchNorm2d(num_features=nb_filts[1])
        self.conv2 = nn.Conv2d(in_channels=nb_filts[1], out_channels=nb_filts[1], kernel_size=(2, 3), padding=(0, 1), stride=1)
        if nb_filts[0] != nb_filts[1]:
            self.downsample = True
            self.conv_downsample = nn.Conv2d(in_channels=nb_filts[0], out_channels=nb_filts[1], padding=(0, 1), kernel_size=(1, 3), stride=1)
        else:
            self.downsample = False
        self.mp = nn.MaxPool2d((1, 3))

    def forward(self, x):
        identity = x
        if not self.first:
            out = self.bn1(x)
            out = self.selu(out)
        else:
            out = x
        out = self.conv1(x)
        out = self.bn2(out)
        out = self.selu(out)
        out = self.conv2(out)
        if self.downsample:
            identity = self.conv_downsample(identity)
        out += identity
        out = self.mp(out)
        return out

class Model(nn.Module):
    def __init__(self, d_args):
        super().__init__()
        self.d_args = d_args
        filts = d_args["filts"]
        gat_dims = d_args["gat_dims"]
        pool_ratios = d_args["pool_ratios"]
        temperatures = d_args["temperatures"]
        self.conv_time = CONV(out_channels=filts[0], kernel_size=d_args["first_conv"], in_channels=1)
        self.first_bn = nn.BatchNorm2d(num_features=1)
        self.drop = nn.Dropout(0.5, inplace=True)
        self.drop_way = nn.Dropout(0.2, inplace=True)
        self.selu = nn.SELU(inplace=False)
        self.encoder = nn.Sequential(
            nn.Sequential(Residual_block(nb_filts=filts[1], first=True)),
            nn.Sequential(Residual_block(nb_filts=filts[2])),
            nn.Sequential(Residual_block(nb_filts=filts[3])),
            nn.Sequential(Residual_block(nb_filts=filts[4])),
            nn.Sequential(Residual_block(nb_filts=filts[4])),
            nn.Sequential(Residual_block(nb_filts=filts[4])))
        self.pos_S = nn.Parameter(torch.randn(1, 23, filts[-1][-1]))
        self.master1 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))
        self.master2 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))
        self.GAT_layer_S = GraphAttentionLayer(filts[-1][-1], gat_dims[0], temperature=temperatures[0])
        self.GAT_layer_T = GraphAttentionLayer(filts[-1][-1], gat_dims[0], temperature=temperatures[1])
        self.HtrgGAT_layer_ST11 = HtrgGraphAttentionLayer(gat_dims[0], gat_dims[1], temperature=temperatures[2])
        self.HtrgGAT_layer_ST12 = HtrgGraphAttentionLayer(gat_dims[1], gat_dims[1], temperature=temperatures[2])
        self.HtrgGAT_layer_ST21 = HtrgGraphAttentionLayer(gat_dims[0], gat_dims[1], temperature=temperatures[2])
        self.HtrgGAT_layer_ST22 = HtrgGraphAttentionLayer(gat_dims[1], gat_dims[1], temperature=temperatures[2])
        self.pool_S = GraphPool(pool_ratios[0], gat_dims[0], 0.3)
        self.pool_T = GraphPool(pool_ratios[1], gat_dims[0], 0.3)
        self.pool_hS1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
        self.pool_hT1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
        self.pool_hS2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
        self.pool_hT2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
        self.out_layer = nn.Linear(5 * gat_dims[1], 2)

    def forward(self, x, Freq_aug=False):
        x = x.unsqueeze(1)
        x = self.conv_time(x, mask=Freq_aug)
        x = x.unsqueeze(dim=1)
        x = F.max_pool2d(torch.abs(x), (3, 3))
        x = self.first_bn(x)
        x = self.selu(x)
        e = self.encoder(x)
        e_S, _ = torch.max(torch.abs(e), dim=3)
        e_S = e_S.transpose(1, 2) + self.pos_S
        gat_S = self.GAT_layer_S(e_S)
        out_S = self.pool_S(gat_S)
        e_T, _ = torch.max(torch.abs(e), dim=2)
        e_T = e_T.transpose(1, 2)
        gat_T = self.GAT_layer_T(e_T)
        out_T = self.pool_T(gat_T)
        master1 = self.master1.expand(x.size(0), -1, -1)
        master2 = self.master2.expand(x.size(0), -1, -1)
        out_T1, out_S1, master1 = self.HtrgGAT_layer_ST11(out_T, out_S, master=self.master1)
        out_S1 = self.pool_hS1(out_S1)
        out_T1 = self.pool_hT1(out_T1)
        out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST12(out_T1, out_S1, master=master1)
        out_T1 = out_T1 + out_T_aug
        out_S1 = out_S1 + out_S_aug
        master1 = master1 + master_aug
        out_T2, out_S2, master2 = self.HtrgGAT_layer_ST21(out_T, out_S, master=self.master2)
        out_S2 = self.pool_hS2(out_S2)
        out_T2 = self.pool_hT2(out_T2)
        out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST22(out_T2, out_S2, master=master2)
        out_T2 = out_T2 + out_T_aug
        out_S2 = out_S2 + out_S_aug
        master2 = master2 + master_aug
        out_T1 = self.drop_way(out_T1)
        out_T2 = self.drop_way(out_T2)
        out_S1 = self.drop_way(out_S1)
        out_S2 = self.drop_way(out_S2)
        master1 = self.drop_way(master1)
        master2 = self.drop_way(master2)
        out_T = torch.max(out_T1, out_T2)
        out_S = torch.max(out_S1, out_S2)
        master = torch.max(master1, master2)
        T_max, _ = torch.max(torch.abs(out_T), dim=1)
        T_avg = torch.mean(out_T, dim=1)
        S_max, _ = torch.max(torch.abs(out_S), dim=1)
        S_avg = torch.mean(out_S, dim=1)
        last_hidden = torch.cat([T_max, T_avg, S_max, S_avg, master.squeeze(1)], dim=1)
        last_hidden = self.drop(last_hidden)
        output = self.out_layer(last_hidden)
        return output

- Bi-Mamba AASIST

In [9]:
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# import numpy as np
# import random
# from typing import Union


# # --- Mamba Blocks ---
# class MambaBlock(nn.Module):
#     def __init__(self, d_model, d_state=64, d_conv=3, d_expand=2):
#         super().__init__()
#         self.inner_dim = d_model * d_expand
#         self.linear_proj = nn.Linear(d_model, self.inner_dim * 2)
#         self.conv = nn.Conv1d(self.inner_dim, self.inner_dim, kernel_size=d_conv, padding=d_conv // 2, groups=d_expand)
#         self.linear_out = nn.Linear(self.inner_dim, d_model)
#         self.norm = nn.LayerNorm(d_model)

#     def forward(self, x):
#         x_proj = self.linear_proj(x)
#         u, v = x_proj.chunk(2, dim=-1)
#         u = u.transpose(1, 2)
#         u = self.conv(u)
#         u = u.transpose(1, 2)
#         x = self.linear_out(F.silu(u * v))
#         return self.norm(x)

# class BidirectionalMamba(nn.Module):
#     def __init__(self, d_model):
#         super().__init__()
#         self.mamba_forward = MambaBlock(d_model)
#         self.mamba_backward = MambaBlock(d_model)
#         self.fuse = nn.Sequential(
#             nn.Linear(2 * d_model, d_model),
#             nn.SiLU(),
#             nn.Linear(d_model, d_model)
#         )

#     def forward(self, x):
#         x_fwd = self.mamba_forward(x)
#         x_bwd = self.mamba_backward(torch.flip(x, dims=[1]))
#         x_bwd = torch.flip(x_bwd, dims=[1])
#         x = torch.cat([x_fwd, x_bwd], dim=-1)
#         return self.fuse(x)

# # class HtrgMambaBlock(nn.Module):
# #     def __init__(self, in_dim, out_dim):
# #         super().__init__()
# #         self.mamba_1 = BidirectionalMamba(in_dim)
# #         self.mamba_2 = BidirectionalMamba(in_dim)
# #         self.master_update = nn.Sequential(
# #             nn.Linear(in_dim, out_dim),
# #             nn.SiLU(),
# #             nn.Linear(out_dim, out_dim)
# #         )
# #         self.out_proj_1 = nn.Linear(in_dim, out_dim)
# #         self.out_proj_2 = nn.Linear(in_dim, out_dim)
# #         self.norm = nn.LayerNorm(out_dim)

# #     def forward(self, x1, x2, master):
# #         x1_mamba = self.mamba_1(x1)
# #         x2_mamba = self.mamba_2(x2)
# #         x1_out = self.out_proj_1(x1_mamba)
# #         x2_out = self.out_proj_2(x2_mamba)
# #         self.master_proj = nn.Linear(64, out_dim)
# #         master_update = self.master_update((x1_out.mean(1) + x2_out.mean(1)) / 2 + master.squeeze(1))
# #         return self.norm(x1_out), self.norm(x2_out), master_update.unsqueeze(1)
# class HtrgMambaBlock(nn.Module):
#     def __init__(self, in_dim, out_dim, master_dim=None):
#         super().__init__()
#         self.in_dim = in_dim
#         self.out_dim = out_dim
#         self.master_dim = master_dim if master_dim is not None else out_dim

#         self.mamba_1 = BidirectionalMamba(in_dim)
#         self.mamba_2 = BidirectionalMamba(in_dim)

#         self.out_proj_1 = nn.Linear(in_dim, out_dim)
#         self.out_proj_2 = nn.Linear(in_dim, out_dim)

#         # Always project master to out_dim for guaranteed shape match
#         self.master_proj = nn.Linear(self.master_dim, out_dim)

#         self.master_update = nn.Sequential(
#             nn.Linear(out_dim, out_dim),
#             nn.SiLU(),
#             nn.Linear(out_dim, out_dim)
#         )

#         self.norm = nn.LayerNorm(out_dim)

#     def forward(self, x1, x2, master):
#         # x1, x2: [B, T, in_dim]; master: [B, 1, master_dim]
#         x1_mamba = self.mamba_1(x1)
#         x2_mamba = self.mamba_2(x2)

#         x1_out = self.out_proj_1(x1_mamba)  # [B, T, out_dim]
#         x2_out = self.out_proj_2(x2_mamba)

#         branch_avg = (x1_out.mean(1) + x2_out.mean(1)) / 2  # [B, out_dim]
#         master_proj = self.master_proj(master.squeeze(1))   # [B, out_dim]

#         fused = branch_avg + master_proj                    # ✅ Both are [B, out_dim]
#         master_update = self.master_update(fused)           # [B, out_dim]

#         return self.norm(x1_out), self.norm(x2_out), master_update.unsqueeze(1)  # [B, T, out_dim], [B, T, out_dim], [B, 1, out_dim]



# # --- GraphPool ---
# class GraphPool(nn.Module):
#     def __init__(self, k: float, in_dim: int, p: Union[float, int]):
#         super().__init__()
#         self.k = k
#         self.sigmoid = nn.Sigmoid()
#         self.proj = nn.Linear(in_dim, 1)
#         self.drop = nn.Dropout(p=p) if p > 0 else nn.Identity()
#         self.in_dim = in_dim

#     def forward(self, h):
#         Z = self.drop(h)
#         weights = self.proj(Z)
#         scores = self.sigmoid(weights)
#         return self.top_k_graph(scores, h, self.k)

#     def top_k_graph(self, scores, h, k):
#         _, n_nodes, n_feat = h.size()
#         n_nodes = max(int(n_nodes * k), 1)
#         _, idx = torch.topk(scores, n_nodes, dim=1)
#         idx = idx.expand(-1, -1, n_feat)
#         h = h * scores
#         h = torch.gather(h, 1, idx)
#         return h


# # --- CONV: SincConv Filter for Raw Audio ---
# class CONV(nn.Module):
#     @staticmethod
#     def to_mel(hz):
#         return 2595 * np.log10(1 + hz / 700)

#     @staticmethod
#     def to_hz(mel):
#         return 700 * (10**(mel / 2595) - 1)

#     def __init__(self, out_channels, kernel_size, sample_rate=16000, in_channels=1, stride=1, padding=0, dilation=1, bias=False, groups=1, mask=False):
#         super().__init__()
#         if in_channels != 1:
#             raise ValueError("SincConv only supports one input channel.")
#         self.out_channels = out_channels
#         self.kernel_size = kernel_size + 1 if kernel_size % 2 == 0 else kernel_size
#         self.sample_rate = sample_rate
#         self.stride = stride
#         self.padding = padding
#         self.dilation = dilation
#         self.mask = mask
#         self.hsupp = torch.arange(-(self.kernel_size - 1) / 2, (self.kernel_size - 1) / 2 + 1)
#         NFFT = 512
#         f = int(sample_rate / 2) * np.linspace(0, 1, int(NFFT / 2) + 1)
#         fmel = self.to_mel(f)
#         filbandwidthsmel = np.linspace(np.min(fmel), np.max(fmel), out_channels + 1)
#         filbandwidthsf = self.to_hz(filbandwidthsmel)
#         self.mel = filbandwidthsf
#         self.band_pass = torch.zeros(out_channels, self.kernel_size)
#         for i in range(len(self.mel) - 1):
#             fmin = self.mel[i]
#             fmax = self.mel[i + 1]
#             hHigh = (2 * fmax / sample_rate) * np.sinc(2 * fmax * self.hsupp / sample_rate)
#             hLow = (2 * fmin / sample_rate) * np.sinc(2 * fmin * self.hsupp / sample_rate)
#             hideal = hHigh - hLow
#             self.band_pass[i, :] = torch.Tensor(np.hamming(self.kernel_size)) * torch.Tensor(hideal)

#     def forward(self, x, mask=False):
#         band_pass_filter = self.band_pass.clone().to(x.device)
#         if mask:
#             A = int(np.random.uniform(0, 20))
#             A0 = random.randint(0, band_pass_filter.shape[0] - A)
#             band_pass_filter[A0:A0 + A, :] = 0
#         filters = band_pass_filter.view(self.out_channels, 1, self.kernel_size)
#         return F.conv1d(x, filters, stride=self.stride, padding=self.padding, dilation=self.dilation, bias=None)


# # --- Residual CNN Block ---
# class Residual_block(nn.Module):
#     def __init__(self, nb_filts, first=False):
#         super().__init__()
#         self.first = first
#         self.bn1 = nn.BatchNorm2d(nb_filts[0]) if not first else nn.Identity()
#         self.conv1 = nn.Conv2d(nb_filts[0], nb_filts[1], kernel_size=(2, 3), padding=(1, 1))
#         self.bn2 = nn.BatchNorm2d(nb_filts[1])
#         self.selu = nn.SELU(inplace=False)
#         self.conv2 = nn.Conv2d(nb_filts[1], nb_filts[1], kernel_size=(2, 3), padding=(0, 1))
#         self.downsample = (nb_filts[0] != nb_filts[1])
#         self.conv_downsample = nn.Conv2d(nb_filts[0], nb_filts[1], kernel_size=(1, 3), padding=(0, 1)) if self.downsample else nn.Identity()
#         self.mp = nn.MaxPool2d((1, 3))

#     def forward(self, x):
#         identity = x
#         out = self.bn1(x)
#         out = self.selu(out)
#         out = self.conv1(out)
#         out = self.bn2(out)
#         out = self.selu(out)
#         out = self.conv2(out)
#         if self.downsample:
#             identity = self.conv_downsample(identity)
#         out += identity
#         return self.mp(out)


# # --- Full BiMambaAASIST Model ---
# class Model(nn.Module):
#     def __init__(self, d_args):
#         super().__init__()
#         filts = d_args["filts"]
#         gat_dims = d_args["gat_dims"]
#         pool_ratios = d_args["pool_ratios"]
#         self.conv_time = CONV(out_channels=filts[0], kernel_size=d_args["first_conv"], in_channels=1)
#         self.first_bn = nn.BatchNorm2d(1)
#         self.drop = nn.Dropout(0.5, inplace=True)
#         self.drop_way = nn.Dropout(0.2, inplace=True)
#         self.selu = nn.SELU(inplace=False)
#         self.encoder = nn.Sequential(*[Residual_block(nb_filts=f, first=(i == 0)) for i, f in enumerate(filts[1:])])
#         self.pos_S = nn.Parameter(torch.randn(1, 23, filts[-1][-1]))
#         self.master1 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))
#         self.master2 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))
#         self.GAT_layer_S = BidirectionalMamba(filts[-1][-1])
#         self.GAT_layer_T = BidirectionalMamba(filts[-1][-1])
#         self.HtrgGAT_layer_ST11 = HtrgMambaBlock(gat_dims[0], gat_dims[1], master_dim=gat_dims[0])
#         self.HtrgGAT_layer_ST12 = HtrgMambaBlock(gat_dims[1], gat_dims[1], master_dim=gat_dims[1])
#         self.HtrgGAT_layer_ST21 = HtrgMambaBlock(gat_dims[0], gat_dims[1], master_dim=gat_dims[0])
#         self.HtrgGAT_layer_ST22 = HtrgMambaBlock(gat_dims[1], gat_dims[1], master_dim=gat_dims[1])
#         self.pool_S = GraphPool(pool_ratios[0], gat_dims[0], 0.3)
#         self.pool_T = GraphPool(pool_ratios[1], gat_dims[0], 0.3)
#         self.pool_hS1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
#         self.pool_hT1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
#         self.pool_hS2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
#         self.pool_hT2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
#         self.out_layer = nn.Linear(5 * gat_dims[1], 2)

#     def forward(self, x, Freq_aug=False):
#         x = x.unsqueeze(1)
#         x = self.conv_time(x, mask=Freq_aug)
#         x = x.unsqueeze(1)
#         x = F.max_pool2d(torch.abs(x), (3, 3))
#         x = self.first_bn(x)
#         x = self.selu(x)
#         e = self.encoder(x)
#         e_S, _ = torch.max(torch.abs(e), dim=3)
#         e_S = e_S.transpose(1, 2) + self.pos_S
#         gat_S = self.GAT_layer_S(e_S)
#         out_S = self.pool_S(gat_S)
#         e_T, _ = torch.max(torch.abs(e), dim=2)
#         e_T = e_T.transpose(1, 2)
#         gat_T = self.GAT_layer_T(e_T)
#         out_T = self.pool_T(gat_T)
#         master1 = self.master1.expand(x.size(0), -1, -1)
#         master2 = self.master2.expand(x.size(0), -1, -1)
#         out_T1, out_S1, master1 = self.HtrgGAT_layer_ST11(out_T, out_S, master1)
#         out_S1 = self.pool_hS1(out_S1)
#         out_T1 = self.pool_hT1(out_T1)
#         out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST12(out_T1, out_S1, master1)
#         out_T1 = out_T1 + out_T_aug
#         out_S1 = out_S1 + out_S_aug
#         master1 = master1 + master_aug
#         out_T2, out_S2, master2 = self.HtrgGAT_layer_ST21(out_T, out_S, master2)
#         out_S2 = self.pool_hS2(out_S2)
#         out_T2 = self.pool_hT2(out_T2)
#         out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST22(out_T2, out_S2, master2)
#         out_T2 = out_T2 + out_T_aug
#         out_S2 = out_S2 + out_S_aug
#         master2 = master2 + master_aug
#         out_T = torch.max(self.drop_way(out_T1), self.drop_way(out_T2))
#         out_S = torch.max(self.drop_way(out_S1), self.drop_way(out_S2))
#         master = torch.max(self.drop_way(master1), self.drop_way(master2))
#         T_max, _ = torch.max(torch.abs(out_T), dim=1)
#         T_avg = torch.mean(out_T, dim=1)
#         S_max, _ = torch.max(torch.abs(out_S), dim=1)
#         S_avg = torch.mean(out_S, dim=1)
#         last_hidden = torch.cat([T_max, T_avg, S_max, S_avg, master.squeeze(1)], dim=1)
#         last_hidden = self.drop(last_hidden)
#         return self.out_layer(last_hidden)


- FAN-AASIST

In [10]:
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# import numpy as np
# import random

# # ---------- FAN Blocks ----------

# class FANBlock(nn.Module):
#     def __init__(self, in_dim, out_dim):
#         super().__init__()
#         self.linear_p = nn.Linear(in_dim, out_dim // 4)
#         self.linear_g = nn.Linear(in_dim, out_dim - (out_dim // 2))
#         self.act = nn.GELU()
#         self.gate = nn.Parameter(torch.randn(1))

#     def forward(self, x):
#         p = self.linear_p(x)
#         g = self.act(self.linear_g(x))
#         gate = torch.sigmoid(self.gate)
#         return torch.cat([gate * torch.cos(p), gate * torch.sin(p), (1 - gate) * g], dim=-1)

# # class HtrgFANBlock(nn.Module):
# #     def __init__(self, in_dim, out_dim):
# #         super().__init__()
# #         self.fan1 = FANBlock(in_dim, out_dim)
# #         self.fan2 = FANBlock(in_dim, out_dim)
# #         self.master_proj = nn.Linear(out_dim, out_dim)
# #         self.bn = nn.BatchNorm1d(out_dim)
# #         self.act = nn.SELU(inplace=True)

# class HtrgFANBlock(nn.Module):
#     def __init__(self, in_dim, out_dim):
#         super().__init__()
#         self.fan1 = FANBlock(in_dim, out_dim)
#         self.fan2 = FANBlock(in_dim, out_dim)
#         self.master_proj = nn.Linear(in_dim, out_dim)  
#         self.bn = nn.BatchNorm1d(out_dim)
#         self.act = nn.SELU(inplace=True)
        
#     def forward(self, x1, x2, master=None):
#         x1 = self.fan1(x1)
#         x2 = self.fan2(x2)
#         x = torch.cat([x1, x2], dim=1)

#         if master is None:
#             master = torch.mean(x, dim=1, keepdim=True)
#         master = self.master_proj(master)
#         out = self._apply_BN(x)
#         return x1, x2, master

#     def _apply_BN(self, x):
#         B, N, C = x.shape
#         x = self.bn(x.view(-1, C))
#         return self.act(x.view(B, N, C))

# # ---------- GraphPool and SincConv ----------

# class GraphPool(nn.Module):
#     def __init__(self, k: float, in_dim: int, p: float = 0.3):
#         super().__init__()
#         self.k = k
#         self.sigmoid = nn.Sigmoid()
#         self.proj = nn.Linear(in_dim, 1)
#         self.drop = nn.Dropout(p=p)
#         self.in_dim = in_dim

#     def forward(self, h):
#         Z = self.drop(h)
#         scores = self.sigmoid(self.proj(Z))
#         return self.top_k_graph(scores, h, self.k)

#     def top_k_graph(self, scores, h, k):
#         B, N, C = h.shape
#         K = max(int(N * k), 1)
#         _, idx = torch.topk(scores, K, dim=1)
#         idx = idx.expand(-1, -1, C)
#         h = h * scores
#         return torch.gather(h, 1, idx)

# class CONV(nn.Module):
#     def __init__(self, out_channels, kernel_size, sample_rate=16000, in_channels=1):
#         super().__init__()
#         if in_channels != 1:
#             raise ValueError("SincConv only supports 1 input channel")
#         if kernel_size % 2 == 0:
#             kernel_size += 1
#         self.out_channels = out_channels
#         self.kernel_size = kernel_size
#         self.sample_rate = sample_rate
#         self.hsupp = torch.arange(-(kernel_size - 1) / 2, (kernel_size - 1) / 2 + 1)
#         self.band_pass = self._create_filters()

#     def _create_filters(self):
#         NFFT = 512
#         f = np.linspace(0, self.sample_rate / 2, int(NFFT / 2) + 1)
#         mel = 2595 * np.log10(1 + f / 700)
#         mel_min, mel_max = mel.min(), mel.max()
#         mel_bands = np.linspace(mel_min, mel_max, self.out_channels + 1)
#         hz_bands = 700 * (10**(mel_bands / 2595) - 1)

#         filters = torch.zeros(self.out_channels, self.kernel_size)
#         for i in range(self.out_channels):
#             fmin, fmax = hz_bands[i], hz_bands[i + 1]
#             hHigh = 2*fmax/self.sample_rate * np.sinc(2*fmax*self.hsupp/self.sample_rate)
#             hLow = 2*fmin/self.sample_rate * np.sinc(2*fmin*self.hsupp/self.sample_rate)
#             filters[i] = torch.tensor(np.hamming(self.kernel_size) * (hHigh - hLow))
#         return filters

#     def forward(self, x, mask=False):
#         filt = self.band_pass.clone().to(x.device)
#         if mask:
#             A = random.randint(0, 20)
#             A0 = random.randint(0, filt.shape[0] - A)
#             filt[A0:A0+A] = 0
#         filters = filt.view(self.out_channels, 1, self.kernel_size)
#         return F.conv1d(x, filters, stride=1, padding=0)

# # ---------- Residual Block ----------

# class Residual_block(nn.Module):
#     def __init__(self, nb_filts, first=False):
#         super().__init__()
#         self.first = first
#         self.conv1 = nn.Conv2d(nb_filts[0], nb_filts[1], (2, 3), padding=(1, 1))
#         self.bn2 = nn.BatchNorm2d(nb_filts[1])
#         self.selu = nn.SELU(inplace=True)
#         self.conv2 = nn.Conv2d(nb_filts[1], nb_filts[1], (2, 3), padding=(0, 1))
#         self.mp = nn.MaxPool2d((1, 3))
#         self.downsample = nb_filts[0] != nb_filts[1]
#         if self.downsample:
#             self.conv_downsample = nn.Conv2d(nb_filts[0], nb_filts[1], (1, 3), padding=(0, 1))

#     def forward(self, x):
#         identity = x
#         out = self.conv1(x if self.first else self.selu(x))
#         out = self.selu(self.bn2(out))
#         out = self.conv2(out)
#         if self.downsample:
#             identity = self.conv_downsample(identity)
#         return self.mp(out + identity)

# # ---------- FAN-AASIST Model ----------

# class Model(nn.Module):
#     def __init__(self, d_args):
#         super().__init__()
#         filts, gat_dims = d_args["filts"], d_args["gat_dims"]
#         pool_ratios = d_args["pool_ratios"]
#         self.conv_time = CONV(filts[0], d_args["first_conv"])
#         self.first_bn = nn.BatchNorm2d(1)
#         self.selu = nn.SELU(inplace=True)
#         self.encoder = nn.Sequential(
#             *[Residual_block(f, first=(i==0)) for i, f in enumerate(filts[1:])]
#         )
#         self.pos_S = nn.Parameter(torch.randn(1, 23, filts[-1][-1]))
#         self.master1 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))
#         self.master2 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))
#         self.GAT_layer_S = FANBlock(filts[-1][-1], gat_dims[0])
#         self.GAT_layer_T = FANBlock(filts[-1][-1], gat_dims[0])
#         self.HtrgGAT_layers = nn.ModuleList([
#             HtrgFANBlock(gat_dims[0], gat_dims[1]),
#             HtrgFANBlock(gat_dims[1], gat_dims[1]),
#             HtrgFANBlock(gat_dims[0], gat_dims[1]),
#             HtrgFANBlock(gat_dims[1], gat_dims[1])
#         ])
#         self.pool_S = GraphPool(pool_ratios[0], gat_dims[0])
#         self.pool_T = GraphPool(pool_ratios[1], gat_dims[0])
#         self.pool_h = nn.ModuleList([
#             GraphPool(pool_ratios[2], gat_dims[1]),
#             GraphPool(pool_ratios[2], gat_dims[1]),
#             GraphPool(pool_ratios[2], gat_dims[1]),
#             GraphPool(pool_ratios[2], gat_dims[1])
#         ])
#         self.out_layer = nn.Linear(5 * gat_dims[1], 2)
#         self.drop = nn.Dropout(0.5)
#         self.drop_way = nn.Dropout(0.2)

#     def forward(self, x, Freq_aug=False):
#         x = x.unsqueeze(1)
#         x = self.conv_time(x, mask=Freq_aug).unsqueeze(1)
#         x = F.max_pool2d(torch.abs(x), (3, 3))
#         x = self.selu(self.first_bn(x))
#         e = self.encoder(x)
#         e_S = torch.max(torch.abs(e), dim=3)[0].transpose(1, 2) + self.pos_S
#         gat_S = self.GAT_layer_S(e_S)
#         out_S = self.pool_S(gat_S)
#         e_T = torch.max(torch.abs(e), dim=2)[0].transpose(1, 2)
#         gat_T = self.GAT_layer_T(e_T)
#         out_T = self.pool_T(gat_T)

#         master1 = self.master1.expand(x.size(0), -1, -1)
#         out_T1, out_S1, master1 = self.HtrgGAT_layers[0](out_T, out_S, master1)
#         out_S1, out_T1 = self.pool_h[0](out_S1), self.pool_h[1](out_T1)
#         T_aug, S_aug, m_aug = self.HtrgGAT_layers[1](out_T1, out_S1, master1)
#         out_T1, out_S1, master1 = out_T1 + T_aug, out_S1 + S_aug, master1 + m_aug

#         master2 = self.master2.expand(x.size(0), -1, -1)
#         out_T2, out_S2, master2 = self.HtrgGAT_layers[2](out_T, out_S, master2)
#         out_S2, out_T2 = self.pool_h[2](out_S2), self.pool_h[3](out_T2)
#         T_aug, S_aug, m_aug = self.HtrgGAT_layers[3](out_T2, out_S2, master2)
#         out_T2, out_S2, master2 = out_T2 + T_aug, out_S2 + S_aug, master2 + m_aug

#         out_T = torch.max(self.drop_way(out_T1), self.drop_way(out_T2))
#         out_S = torch.max(self.drop_way(out_S1), self.drop_way(out_S2))
#         master = torch.max(self.drop_way(master1), self.drop_way(master2))
#         T_max, T_avg = out_T.abs().max(dim=1)[0], out_T.mean(dim=1)
#         S_max, S_avg = out_S.abs().max(dim=1)[0], out_S.mean(dim=1)
#         last_hidden = self.drop(torch.cat([T_max, T_avg, S_max, S_avg, master.squeeze(1)], dim=1))
#         return self.out_layer(last_hidden)


## Evaluation Function
Function to compute Equal Error Rate (EER) using a placeholder `compute_eer` function.

In [11]:
# Placeholder for compute_eer (replace with actual implementation if available)
def compute_eer(target_scores, nontarget_scores):
    """Placeholder for EER computation. Replace with actual implementation."""
    from sklearn.metrics import roc_curve
    import numpy as np
    scores = np.concatenate([target_scores, nontarget_scores])
    labels = np.concatenate([np.ones(len(target_scores)), np.zeros(len(nontarget_scores))])
    fpr, tpr, thresholds = roc_curve(labels, scores)
    fnr = 1 - tpr
    eer = fpr[np.argmin(np.abs(fpr - fnr))]
    return eer, thresholds[np.argmin(np.abs(fpr - fnr))]

def evaluate(loader, model, device: torch.device):
    """Evaluate the model on the given loader, then return EER."""
    model.eval()
    target_scores = []
    nontarget_scores = []
    with torch.no_grad():
        for batch_x, batch_y in tqdm(loader, total=len(loader)):
            batch_x = batch_x.to(device)
            batch_out = model(batch_x)
            batch_score = batch_out[:, 1].data.cpu().numpy().ravel()
            batch_y = batch_y.data.cpu().numpy().ravel()
            for i in range(len(batch_y)):
                if batch_y[i] == 1:
                    target_scores.append(batch_score[i])
                else:
                    nontarget_scores.append(batch_score[i])
    eer, _ = compute_eer(target_scores, nontarget_scores)
    return eer

## DataLoader
Function to create DataLoaders for the custom dataset.

In [12]:
def get_loader(database_path: str, seed: int, config: dict) -> List[torch.utils.data.DataLoader]:
    """Create DataLoaders for train, validation, and test sets."""
    target_sr = float(config["target_sr"])
    train_set = Dataset_Custom(base_dir=database_path, split='train', target_sr=target_sr)
    val_set = Dataset_Custom(base_dir=database_path, split='dev', target_sr=target_sr)
    test_set = Dataset_Custom(base_dir=database_path, split='test', target_sr=target_sr)

    print(f"No. training files: {len(train_set)}")
    print(f"No. validation files: {len(val_set)}")
    print(f"No. test files: {len(test_set)}")

    gen = torch.Generator()
    gen.manual_seed(seed)
    trn_loader = DataLoader(train_set, batch_size=config["batch_size"], shuffle=True, drop_last=True,
                            pin_memory=True, worker_init_fn=seed_worker, generator=gen, num_workers=4)
    val_loader = DataLoader(val_set, batch_size=config["batch_size"], shuffle=False, drop_last=False,
                            pin_memory=True, num_workers=4)
    test_loader = DataLoader(test_set, batch_size=config["batch_size"], shuffle=False, drop_last=False,
                             pin_memory=True, num_workers=4)
    return trn_loader, val_loader, test_loader

## Training Function
Function to train the model for one epoch.

In [13]:
def train_epoch(trn_loader: DataLoader, model, optim: Union[torch.optim.SGD, torch.optim.Adam],
                device: torch.device, scheduler: torch.optim.lr_scheduler, config: dict):
    """Train the model for one epoch."""
    running_loss = 0
    num_total = 0.0
    model.train()
    weight = torch.FloatTensor([0.1, 0.9]).to(device)
    criterion = nn.CrossEntropyLoss(weight=weight)
    pbar = tqdm(trn_loader, total=len(trn_loader))
    for batch_x, batch_y in pbar:
        batch_size = batch_x.size(0)
        num_total += batch_size
        batch_x = batch_x.to(device)
        batch_y = batch_y.view(-1).type(torch.int64).to(device)
        batch_out = model(batch_x)
        batch_loss = criterion(batch_out, batch_y)
        running_loss += batch_loss.item() * batch_size
        pbar.set_description(f"loss: {batch_loss.item():.5f}, running loss: {running_loss / num_total:.5f}")
        optim.zero_grad()
        batch_loss.backward()
        optim.step()
        if config["optim_config"]["scheduler"] in ["cosine", "keras_decay"]:
            scheduler.step()
        elif scheduler is None:
            pass
        else:
            raise ValueError(f"scheduler error, got: {scheduler}")
    running_loss /= num_total
    return running_loss
    

## Main Training and Evaluation Loop
Set up the model, train, and evaluate it.

In [14]:
# # Set random seed
# seed = 1234
# set_seed(seed, config)

# # Set device
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# print(f"Device: {device}")
# if device == "cpu":
#     print("Warning: Running on CPU, which may be slow!")

# # Define model
# model = Model(config["model_config"]).to(device)
# nb_params = sum([param.view(-1).size()[0] for param in model.parameters()])
# print(f"No. model params: {nb_params}")

# # Create DataLoaders
# output_dir = Path("./exp_result")
# model_tag = f"LA_custom_ep{config['num_epochs']}_bs{config['batch_size']}"
# model_tag = output_dir / model_tag
# model_save_path = model_tag / "weights"
# eval_score_path = model_tag / config["eval_output"]
# writer = SummaryWriter(model_tag)
# os.makedirs(model_save_path, exist_ok=True)

# trn_loader, val_loader, test_loader = get_loader(config["database_path"], seed, config)

# # Get optimizer and scheduler
# config["optim_config"]["steps_per_epoch"] = len(trn_loader)
# optimizer, scheduler = create_optimizer(model.parameters(), config["optim_config"])
# optimizer_swa = SWA(optimizer)

# # Training loop
# best_val_eer = 1.0
# best_test_eer = 100.0
# n_swa_update = 0
# f_log = open(model_tag / "metric_log.txt", "a")
# f_log.write("=" * 5 + "\n")

# for epoch in range(config["num_epochs"]):
#     print(f"Start training epoch {epoch:03d}")
#     running_loss = train_epoch(trn_loader, model, optimizer, device, scheduler, config)
#     train_eer = evaluate(trn_loader, model, device)
#     val_eer = evaluate(val_loader, model, device)
#     test_eer = evaluate(test_loader, model, device)

#     print(f"DONE.\nLoss: {running_loss:.5f}, train_eer: {train_eer*100:.2f}%, val_eer: {val_eer*100:.2f}%, test_eer: {test_eer*100:.2f}%")
#     writer.add_scalar("loss", running_loss, epoch)
#     writer.add_scalar("train_eer", train_eer, epoch)
#     writer.add_scalar("val_eer", val_eer, epoch)
#     writer.add_scalar("test_eer", test_eer, epoch)

#     if val_eer <= best_val_eer:
#         print(f"Best model found at epoch {epoch}")
#         best_val_eer = val_eer
#         torch.save(model.state_dict(), model_save_path / f"epoch_{epoch}_{val_eer:.3f}.pth")
#         if str_to_bool(config["eval_all_best"]):
#             test_eer = evaluate(test_loader, model, device)
#             log_text = f"epoch{epoch:03d}, "
#             if test_eer < best_test_eer:
#                 log_text += f"best eer, {test_eer:.4f}%"
#                 best_test_eer = test_eer
#                 torch.save(model.state_dict(), model_save_path / "best.pth")
#             if len(log_text) > 0:
#                 print(log_text)
#                 f_log.write(log_text + "\n")
#         print(f"Saving epoch {epoch} for SWA")
#         optimizer_swa.update_swa()
#         n_swa_update += 1
#     writer.add_scalar("best_val_eer", best_val_eer, epoch)
# # Final evaluation
# print("Start final evaluation")
# if n_swa_update > 0:
#     optimizer_swa.swap_swa_sgd()
#     optimizer_swa.bn_update(trn_loader, model, device=device)
# test_eer = evaluate(test_loader, model, device)
# f_log.write(f"EER: {test_eer*100:.3f}%\n")
# f_log.close()

# torch.save(model.state_dict(), model_save_path / "swa.pth")
# if test_eer <= best_test_eer:
#     best_test_eer = test_eer
#     torch.save(model.state_dict(), model_save_path / "best.pth")

# print(f"Experiment finished. Best EER: {best_test_eer*100:.3f}%")

In [15]:
# from pathlib import Path
# from torch.utils.tensorboard import SummaryWriter
# import os

# # --- Utility: Enable Multi-GPU ---
# def prepare_model(model):
#     """Moves model to CUDA and wraps in DataParallel if multiple GPUs available"""
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     model = model.to(device)

#     if torch.cuda.device_count() > 1:
#         print(f"[INFO] Using {torch.cuda.device_count()} GPUs with DataParallel")
#         model = torch.nn.DataParallel(model)

#     return model, device

# # Set random seed
# seed = 1234
# set_seed(seed, config)

# # Define model and move to device
# model = Model(config["model_config"])
# model, device = prepare_model(model)

# nb_params = sum([param.view(-1).size(0) for param in model.parameters()])
# print(f"No. model params: {nb_params}")
# print(f"Device: {device}")
# if str(device) == "cpu":
#     print("Warning: Running on CPU, which may be slow!")

# # Create directories
# output_dir = Path("./exp_result")
# model_tag = f"LA_custom_ep{config['num_epochs']}_bs{config['batch_size']}"
# model_tag = output_dir / model_tag
# model_save_path = model_tag / "weights"
# eval_score_path = model_tag / config["eval_output"]
# writer = SummaryWriter(model_tag)
# os.makedirs(model_save_path, exist_ok=True)

# # DataLoaders
# trn_loader, val_loader, test_loader = get_loader(config["database_path"], seed, config)

# # Optimizer and Scheduler
# config["optim_config"]["steps_per_epoch"] = len(trn_loader)
# optimizer, scheduler = create_optimizer(model.parameters(), config["optim_config"])
# optimizer_swa = SWA(optimizer)

# # Training loop
# best_val_eer = 1.0
# best_test_eer = 100.0
# n_swa_update = 0
# f_log = open(model_tag / "metric_log.txt", "a")
# f_log.write("=" * 5 + "\n")

# for epoch in range(config["num_epochs"]):
#     print(f"Start training epoch {epoch:03d}")
#     running_loss = train_epoch(trn_loader, model, optimizer, device, scheduler, config)

#     train_eer = evaluate(trn_loader, model, device)
#     val_eer = evaluate(val_loader, model, device)
#     test_eer = evaluate(test_loader, model, device)

#     print(f"DONE.\nLoss: {running_loss:.5f}, train_eer: {train_eer*100:.2f}%, val_eer: {val_eer*100:.2f}%, test_eer: {test_eer*100:.2f}%")
#     writer.add_scalar("loss", running_loss, epoch)
#     writer.add_scalar("train_eer", train_eer, epoch)
#     writer.add_scalar("val_eer", val_eer, epoch)
#     writer.add_scalar("test_eer", test_eer, epoch)

#     if val_eer <= best_val_eer:
#         print(f"Best model found at epoch {epoch}")
#         best_val_eer = val_eer

#         # Save model correctly depending on DataParallel
#         state_dict = model.module.state_dict() if isinstance(model, torch.nn.DataParallel) else model.state_dict()
#         torch.save(state_dict, model_save_path / f"epoch_{epoch}_{val_eer:.3f}.pth")

#         if str_to_bool(config["eval_all_best"]):
#             test_eer = evaluate(test_loader, model, device)
#             log_text = f"epoch{epoch:03d}, "
#             if test_eer < best_test_eer:
#                 log_text += f"best eer, {test_eer:.4f}%"
#                 best_test_eer = test_eer
#                 torch.save(state_dict, model_save_path / "best.pth")
#             if len(log_text) > 0:
#                 print(log_text)
#                 f_log.write(log_text + "\n")

#         print(f"Saving epoch {epoch} for SWA")
#         optimizer_swa.update_swa()
#         n_swa_update += 1

#     writer.add_scalar("best_val_eer", best_val_eer, epoch)

# # Final Evaluation
# print("Start final evaluation")
# if n_swa_update > 0:
#     optimizer_swa.swap_swa_sgd()
#     optimizer_swa.bn_update(trn_loader, model, device=device)

# test_eer = evaluate(test_loader, model, device)
# f_log.write(f"EER: {test_eer*100:.3f}%\n")
# f_log.close()

# # Save final SWA model
# state_dict = model.module.state_dict() if isinstance(model, torch.nn.DataParallel) else model.state_dict()
# torch.save(state_dict, model_save_path / "swa.pth")

# if test_eer <= best_test_eer:
#     best_test_eer = test_eer
#     torch.save(state_dict, model_save_path / "best.pth")

# print(f"Experiment finished. Best EER: {best_test_eer*100:.3f}%")


In [16]:
from pathlib import Path
from torch.utils.tensorboard import SummaryWriter
import torch
import os

# --- Utility: Enable Multi-GPU ---
def prepare_model(model):
    """Moves model to CUDA and wraps in DataParallel if multiple GPUs available"""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    if torch.cuda.device_count() > 1:
        print(f"[INFO] Using {torch.cuda.device_count()} GPUs with DataParallel")
        model = torch.nn.DataParallel(model)

    return model, device


# Set random seed
seed = 1234
set_seed(seed, config)

# Define model and move to device
model = Model(config["model_config"])
model, device = prepare_model(model)

nb_params = sum([param.view(-1).size(0) for param in model.parameters()])
print(f"No. model params: {nb_params}")
print(f"Device: {device}")
if str(device) == "cpu":
    print("Warning: Running on CPU, which may be slow!")

# Create directories
output_dir = Path("./exp_result")
model_tag = f"LA_custom_ep{config['num_epochs']}_bs{config['batch_size']}"
model_tag = output_dir / model_tag
model_save_path = model_tag / "weights"
eval_score_path = model_tag / config["eval_output"]
writer = SummaryWriter(model_tag)
os.makedirs(model_save_path, exist_ok=True)

# DataLoaders
trn_loader, val_loader, _ = get_loader(config["database_path"], seed, config)

# Optimizer and Scheduler
config["optim_config"]["steps_per_epoch"] = len(trn_loader)
optimizer, scheduler = create_optimizer(model.parameters(), config["optim_config"])
optimizer_swa = SWA(optimizer)

# Training loop
best_val_eer = 1.0
n_swa_update = 0
f_log = open(model_tag / "metric_log.txt", "a")
f_log.write("=" * 5 + "\n")

for epoch in range(config["num_epochs"]):
    print(f"Start training epoch {epoch:03d}")
    running_loss = train_epoch(trn_loader, model, optimizer, device, scheduler, config)

    # Only train + val EER during training
    # train_eer = evaluate(trn_loader, model, device)
    val_eer = evaluate(val_loader, model, device)

    print(f"DONE.\nLoss: {running_loss:.5f}, "
          # f"train_eer: {train_eer*100:.2f}%, "
          f"val_eer: {val_eer*100:.2f}%")

    writer.add_scalar("loss", running_loss, epoch)
    # writer.add_scalar("train_eer", train_eer, epoch)
    writer.add_scalar("val_eer", val_eer, epoch)

    # Save best models based only on val EER
    if val_eer <= best_val_eer:
        print(f"Best model found at epoch {epoch}")
        best_val_eer = val_eer

        # Save model (supporting DataParallel)
        state_dict = model.module.state_dict() if isinstance(model, torch.nn.DataParallel) else model.state_dict()
        torch.save(state_dict, model_save_path / f"epoch_{epoch}_{val_eer:.3f}.pth")
        torch.save(state_dict, model_save_path / "best.pth")

        # Update SWA
        print(f"Saving epoch {epoch} for SWA")
        optimizer_swa.update_swa()
        n_swa_update += 1

    writer.add_scalar("best_val_eer", best_val_eer, epoch)

# -------- Final SWA Save (No Test EER) --------
print("Start final SWA update")
if n_swa_update > 0:
    optimizer_swa.swap_swa_sgd()
    optimizer_swa.bn_update(trn_loader, model, device=device)

# Save final SWA model
state_dict = model.module.state_dict() if isinstance(model, torch.nn.DataParallel) else model.state_dict()
torch.save(state_dict, model_save_path / "swa.pth")

f_log.write(f"Best Val EER: {best_val_eer*100:.3f}%\n")
f_log.close()

print(f"Experiment finished. Best Val EER: {best_val_eer*100:.3f}%")


[INFO] Using 2 GPUs with DataParallel
No. model params: 297866
Device: cuda
No. training files: 25380
No. validation files: 24844
No. test files: 71237
Start training epoch 000


  0%|          | 0/2115 [00:00<?, ?it/s]

  0%|          | 0/2071 [00:00<?, ?it/s]

DONE.
Loss: 0.70944, val_eer: 23.35%
Best model found at epoch 0
Saving epoch 0 for SWA
Start training epoch 001


  0%|          | 0/2115 [00:00<?, ?it/s]

  0%|          | 0/2071 [00:00<?, ?it/s]

DONE.
Loss: 0.49058, val_eer: 16.92%
Best model found at epoch 1
Saving epoch 1 for SWA
Start training epoch 002


  0%|          | 0/2115 [00:00<?, ?it/s]

  0%|          | 0/2071 [00:00<?, ?it/s]

DONE.
Loss: 0.38511, val_eer: 12.23%
Best model found at epoch 2
Saving epoch 2 for SWA
Start training epoch 003


  0%|          | 0/2115 [00:00<?, ?it/s]

  0%|          | 0/2071 [00:00<?, ?it/s]

DONE.
Loss: 0.30983, val_eer: 8.30%
Best model found at epoch 3
Saving epoch 3 for SWA
Start training epoch 004


  0%|          | 0/2115 [00:00<?, ?it/s]

  0%|          | 0/2071 [00:00<?, ?it/s]

DONE.
Loss: 0.25078, val_eer: 6.71%
Best model found at epoch 4
Saving epoch 4 for SWA
Start training epoch 005


  0%|          | 0/2115 [00:00<?, ?it/s]

  0%|          | 0/2071 [00:00<?, ?it/s]

DONE.
Loss: 0.20322, val_eer: 5.24%
Best model found at epoch 5
Saving epoch 5 for SWA
Start training epoch 006


  0%|          | 0/2115 [00:00<?, ?it/s]

  0%|          | 0/2071 [00:00<?, ?it/s]

DONE.
Loss: 0.19361, val_eer: 5.89%
Start training epoch 007


  0%|          | 0/2115 [00:00<?, ?it/s]

  0%|          | 0/2071 [00:00<?, ?it/s]

DONE.
Loss: 0.16057, val_eer: 5.80%
Start training epoch 008


  0%|          | 0/2115 [00:00<?, ?it/s]

  0%|          | 0/2071 [00:00<?, ?it/s]

DONE.
Loss: 0.14473, val_eer: 3.65%
Best model found at epoch 8
Saving epoch 8 for SWA
Start training epoch 009


  0%|          | 0/2115 [00:00<?, ?it/s]

  0%|          | 0/2071 [00:00<?, ?it/s]

DONE.
Loss: 0.13268, val_eer: 3.65%
Best model found at epoch 9
Saving epoch 9 for SWA
Start training epoch 010


  0%|          | 0/2115 [00:00<?, ?it/s]

  0%|          | 0/2071 [00:00<?, ?it/s]

DONE.
Loss: 0.11986, val_eer: 3.75%
Start training epoch 011


  0%|          | 0/2115 [00:00<?, ?it/s]

  0%|          | 0/2071 [00:00<?, ?it/s]

DONE.
Loss: 0.11364, val_eer: 3.84%
Start training epoch 012


  0%|          | 0/2115 [00:00<?, ?it/s]

  0%|          | 0/2071 [00:00<?, ?it/s]

DONE.
Loss: 0.10852, val_eer: 3.89%
Start training epoch 013


  0%|          | 0/2115 [00:00<?, ?it/s]

  0%|          | 0/2071 [00:00<?, ?it/s]

DONE.
Loss: 0.09187, val_eer: 2.76%
Best model found at epoch 13
Saving epoch 13 for SWA
Start training epoch 014


  0%|          | 0/2115 [00:00<?, ?it/s]

  0%|          | 0/2071 [00:00<?, ?it/s]

DONE.
Loss: 0.08823, val_eer: 2.69%
Best model found at epoch 14
Saving epoch 14 for SWA
Start training epoch 015


  0%|          | 0/2115 [00:00<?, ?it/s]

  0%|          | 0/2071 [00:00<?, ?it/s]

DONE.
Loss: 0.08581, val_eer: 2.83%
Start training epoch 016


  0%|          | 0/2115 [00:00<?, ?it/s]

  0%|          | 0/2071 [00:00<?, ?it/s]

DONE.
Loss: 0.07639, val_eer: 2.73%
Start training epoch 017


  0%|          | 0/2115 [00:00<?, ?it/s]

  0%|          | 0/2071 [00:00<?, ?it/s]

DONE.
Loss: 0.07894, val_eer: 3.04%
Start training epoch 018


  0%|          | 0/2115 [00:00<?, ?it/s]

  0%|          | 0/2071 [00:00<?, ?it/s]

DONE.
Loss: 0.06847, val_eer: 2.69%
Best model found at epoch 18
Saving epoch 18 for SWA
Start training epoch 019


  0%|          | 0/2115 [00:00<?, ?it/s]

  0%|          | 0/2071 [00:00<?, ?it/s]

DONE.
Loss: 0.06635, val_eer: 2.61%
Best model found at epoch 19
Saving epoch 19 for SWA
Start training epoch 020


  0%|          | 0/2115 [00:00<?, ?it/s]

  0%|          | 0/2071 [00:00<?, ?it/s]

DONE.
Loss: 0.06388, val_eer: 2.83%
Start training epoch 021


  0%|          | 0/2115 [00:00<?, ?it/s]

  0%|          | 0/2071 [00:00<?, ?it/s]

DONE.
Loss: 0.05191, val_eer: 2.14%
Best model found at epoch 21
Saving epoch 21 for SWA
Start training epoch 022


  0%|          | 0/2115 [00:00<?, ?it/s]

  0%|          | 0/2071 [00:00<?, ?it/s]

DONE.
Loss: 0.05120, val_eer: 2.09%
Best model found at epoch 22
Saving epoch 22 for SWA
Start training epoch 023


  0%|          | 0/2115 [00:00<?, ?it/s]

  0%|          | 0/2071 [00:00<?, ?it/s]

DONE.
Loss: 0.05354, val_eer: 2.24%
Start training epoch 024


  0%|          | 0/2115 [00:00<?, ?it/s]

  0%|          | 0/2071 [00:00<?, ?it/s]

DONE.
Loss: 0.05210, val_eer: 1.59%
Best model found at epoch 24
Saving epoch 24 for SWA
Start training epoch 025


  0%|          | 0/2115 [00:00<?, ?it/s]

  0%|          | 0/2071 [00:00<?, ?it/s]

DONE.
Loss: 0.04640, val_eer: 1.46%
Best model found at epoch 25
Saving epoch 25 for SWA
Start training epoch 026


  0%|          | 0/2115 [00:00<?, ?it/s]

  0%|          | 0/2071 [00:00<?, ?it/s]

DONE.
Loss: 0.04481, val_eer: 2.02%
Start training epoch 027


  0%|          | 0/2115 [00:00<?, ?it/s]

  0%|          | 0/2071 [00:00<?, ?it/s]

DONE.
Loss: 0.03905, val_eer: 1.58%
Start training epoch 028


  0%|          | 0/2115 [00:00<?, ?it/s]

  0%|          | 0/2071 [00:00<?, ?it/s]

DONE.
Loss: 0.03379, val_eer: 1.40%
Best model found at epoch 28
Saving epoch 28 for SWA
Start training epoch 029


  0%|          | 0/2115 [00:00<?, ?it/s]

  0%|          | 0/2071 [00:00<?, ?it/s]

DONE.
Loss: 0.03134, val_eer: 1.42%
Start training epoch 030


  0%|          | 0/2115 [00:00<?, ?it/s]

  0%|          | 0/2071 [00:00<?, ?it/s]

DONE.
Loss: 0.03763, val_eer: 1.27%
Best model found at epoch 30
Saving epoch 30 for SWA
Start training epoch 031


  0%|          | 0/2115 [00:00<?, ?it/s]

  0%|          | 0/2071 [00:00<?, ?it/s]

DONE.
Loss: 0.02488, val_eer: 1.78%
Start training epoch 032


  0%|          | 0/2115 [00:00<?, ?it/s]

  0%|          | 0/2071 [00:00<?, ?it/s]

DONE.
Loss: 0.03208, val_eer: 1.25%
Best model found at epoch 32
Saving epoch 32 for SWA
Start training epoch 033


  0%|          | 0/2115 [00:00<?, ?it/s]

  0%|          | 0/2071 [00:00<?, ?it/s]

DONE.
Loss: 0.02876, val_eer: 1.27%
Start training epoch 034


  0%|          | 0/2115 [00:00<?, ?it/s]

  0%|          | 0/2071 [00:00<?, ?it/s]

DONE.
Loss: 0.02708, val_eer: 1.19%
Best model found at epoch 34
Saving epoch 34 for SWA
Start training epoch 035


  0%|          | 0/2115 [00:00<?, ?it/s]

  0%|          | 0/2071 [00:00<?, ?it/s]

DONE.
Loss: 0.02652, val_eer: 1.14%
Best model found at epoch 35
Saving epoch 35 for SWA
Start training epoch 036


  0%|          | 0/2115 [00:00<?, ?it/s]

  0%|          | 0/2071 [00:00<?, ?it/s]

DONE.
Loss: 0.02577, val_eer: 2.05%
Start training epoch 037


  0%|          | 0/2115 [00:00<?, ?it/s]

  0%|          | 0/2071 [00:00<?, ?it/s]

DONE.
Loss: 0.02370, val_eer: 1.01%
Best model found at epoch 37
Saving epoch 37 for SWA
Start training epoch 038


  0%|          | 0/2115 [00:00<?, ?it/s]

  0%|          | 0/2071 [00:00<?, ?it/s]

DONE.
Loss: 0.02389, val_eer: 1.09%
Start training epoch 039


  0%|          | 0/2115 [00:00<?, ?it/s]

  0%|          | 0/2071 [00:00<?, ?it/s]

DONE.
Loss: 0.02078, val_eer: 0.96%
Best model found at epoch 39
Saving epoch 39 for SWA
Start training epoch 040


  0%|          | 0/2115 [00:00<?, ?it/s]

  0%|          | 0/2071 [00:00<?, ?it/s]

DONE.
Loss: 0.02329, val_eer: 1.01%
Start training epoch 041


  0%|          | 0/2115 [00:00<?, ?it/s]

  0%|          | 0/2071 [00:00<?, ?it/s]

DONE.
Loss: 0.01946, val_eer: 0.91%
Best model found at epoch 41
Saving epoch 41 for SWA
Start training epoch 042


  0%|          | 0/2115 [00:00<?, ?it/s]

  0%|          | 0/2071 [00:00<?, ?it/s]

DONE.
Loss: 0.02159, val_eer: 1.14%
Start training epoch 043


  0%|          | 0/2115 [00:00<?, ?it/s]

  0%|          | 0/2071 [00:00<?, ?it/s]

DONE.
Loss: 0.01557, val_eer: 1.09%
Start training epoch 044


  0%|          | 0/2115 [00:00<?, ?it/s]

  0%|          | 0/2071 [00:00<?, ?it/s]

DONE.
Loss: 0.01936, val_eer: 1.04%
Start training epoch 045


  0%|          | 0/2115 [00:00<?, ?it/s]

  0%|          | 0/2071 [00:00<?, ?it/s]

DONE.
Loss: 0.01796, val_eer: 1.11%
Start training epoch 046


  0%|          | 0/2115 [00:00<?, ?it/s]

  0%|          | 0/2071 [00:00<?, ?it/s]

DONE.
Loss: 0.02006, val_eer: 1.12%
Start training epoch 047


  0%|          | 0/2115 [00:00<?, ?it/s]

  0%|          | 0/2071 [00:00<?, ?it/s]

DONE.
Loss: 0.01626, val_eer: 1.19%
Start training epoch 048


  0%|          | 0/2115 [00:00<?, ?it/s]

  0%|          | 0/2071 [00:00<?, ?it/s]

DONE.
Loss: 0.01563, val_eer: 1.05%
Start training epoch 049


  0%|          | 0/2115 [00:00<?, ?it/s]

  0%|          | 0/2071 [00:00<?, ?it/s]

DONE.
Loss: 0.01651, val_eer: 1.13%
Start final SWA update
Experiment finished. Best Val EER: 0.910%


## Evaluate Pretrained Model
Optionally evaluate a pretrained model by loading weights.

In [17]:
# Set to True to evaluate a pretrained model
evaluate_pretrained = True

if evaluate_pretrained:
    model.load_state_dict(torch.load(config["model_path"], map_location=device))
    print(f"Model loaded: {config['model_path']}")
    print("Evaluating on test set...")
    test_eer = evaluate(test_loader, model, device)
    print(f"Test EER: {test_eer*100:.6f}%") 

RuntimeError: Error(s) in loading state_dict for DataParallel:
	Missing key(s) in state_dict: "module.pos_S", "module.master1", "module.master2", "module.first_bn.weight", "module.first_bn.bias", "module.first_bn.running_mean", "module.first_bn.running_var", "module.encoder.0.0.conv1.weight", "module.encoder.0.0.conv1.bias", "module.encoder.0.0.bn2.weight", "module.encoder.0.0.bn2.bias", "module.encoder.0.0.bn2.running_mean", "module.encoder.0.0.bn2.running_var", "module.encoder.0.0.conv2.weight", "module.encoder.0.0.conv2.bias", "module.encoder.0.0.conv_downsample.weight", "module.encoder.0.0.conv_downsample.bias", "module.encoder.1.0.bn1.weight", "module.encoder.1.0.bn1.bias", "module.encoder.1.0.bn1.running_mean", "module.encoder.1.0.bn1.running_var", "module.encoder.1.0.conv1.weight", "module.encoder.1.0.conv1.bias", "module.encoder.1.0.bn2.weight", "module.encoder.1.0.bn2.bias", "module.encoder.1.0.bn2.running_mean", "module.encoder.1.0.bn2.running_var", "module.encoder.1.0.conv2.weight", "module.encoder.1.0.conv2.bias", "module.encoder.2.0.bn1.weight", "module.encoder.2.0.bn1.bias", "module.encoder.2.0.bn1.running_mean", "module.encoder.2.0.bn1.running_var", "module.encoder.2.0.conv1.weight", "module.encoder.2.0.conv1.bias", "module.encoder.2.0.bn2.weight", "module.encoder.2.0.bn2.bias", "module.encoder.2.0.bn2.running_mean", "module.encoder.2.0.bn2.running_var", "module.encoder.2.0.conv2.weight", "module.encoder.2.0.conv2.bias", "module.encoder.2.0.conv_downsample.weight", "module.encoder.2.0.conv_downsample.bias", "module.encoder.3.0.bn1.weight", "module.encoder.3.0.bn1.bias", "module.encoder.3.0.bn1.running_mean", "module.encoder.3.0.bn1.running_var", "module.encoder.3.0.conv1.weight", "module.encoder.3.0.conv1.bias", "module.encoder.3.0.bn2.weight", "module.encoder.3.0.bn2.bias", "module.encoder.3.0.bn2.running_mean", "module.encoder.3.0.bn2.running_var", "module.encoder.3.0.conv2.weight", "module.encoder.3.0.conv2.bias", "module.encoder.4.0.bn1.weight", "module.encoder.4.0.bn1.bias", "module.encoder.4.0.bn1.running_mean", "module.encoder.4.0.bn1.running_var", "module.encoder.4.0.conv1.weight", "module.encoder.4.0.conv1.bias", "module.encoder.4.0.bn2.weight", "module.encoder.4.0.bn2.bias", "module.encoder.4.0.bn2.running_mean", "module.encoder.4.0.bn2.running_var", "module.encoder.4.0.conv2.weight", "module.encoder.4.0.conv2.bias", "module.encoder.5.0.bn1.weight", "module.encoder.5.0.bn1.bias", "module.encoder.5.0.bn1.running_mean", "module.encoder.5.0.bn1.running_var", "module.encoder.5.0.conv1.weight", "module.encoder.5.0.conv1.bias", "module.encoder.5.0.bn2.weight", "module.encoder.5.0.bn2.bias", "module.encoder.5.0.bn2.running_mean", "module.encoder.5.0.bn2.running_var", "module.encoder.5.0.conv2.weight", "module.encoder.5.0.conv2.bias", "module.GAT_layer_S.att_weight", "module.GAT_layer_S.att_proj.weight", "module.GAT_layer_S.att_proj.bias", "module.GAT_layer_S.proj_with_att.weight", "module.GAT_layer_S.proj_with_att.bias", "module.GAT_layer_S.proj_without_att.weight", "module.GAT_layer_S.proj_without_att.bias", "module.GAT_layer_S.bn.weight", "module.GAT_layer_S.bn.bias", "module.GAT_layer_S.bn.running_mean", "module.GAT_layer_S.bn.running_var", "module.GAT_layer_T.att_weight", "module.GAT_layer_T.att_proj.weight", "module.GAT_layer_T.att_proj.bias", "module.GAT_layer_T.proj_with_att.weight", "module.GAT_layer_T.proj_with_att.bias", "module.GAT_layer_T.proj_without_att.weight", "module.GAT_layer_T.proj_without_att.bias", "module.GAT_layer_T.bn.weight", "module.GAT_layer_T.bn.bias", "module.GAT_layer_T.bn.running_mean", "module.GAT_layer_T.bn.running_var", "module.HtrgGAT_layer_ST11.att_weight11", "module.HtrgGAT_layer_ST11.att_weight22", "module.HtrgGAT_layer_ST11.att_weight12", "module.HtrgGAT_layer_ST11.att_weightM", "module.HtrgGAT_layer_ST11.proj_type1.weight", "module.HtrgGAT_layer_ST11.proj_type1.bias", "module.HtrgGAT_layer_ST11.proj_type2.weight", "module.HtrgGAT_layer_ST11.proj_type2.bias", "module.HtrgGAT_layer_ST11.att_proj.weight", "module.HtrgGAT_layer_ST11.att_proj.bias", "module.HtrgGAT_layer_ST11.att_projM.weight", "module.HtrgGAT_layer_ST11.att_projM.bias", "module.HtrgGAT_layer_ST11.proj_with_att.weight", "module.HtrgGAT_layer_ST11.proj_with_att.bias", "module.HtrgGAT_layer_ST11.proj_without_att.weight", "module.HtrgGAT_layer_ST11.proj_without_att.bias", "module.HtrgGAT_layer_ST11.proj_with_attM.weight", "module.HtrgGAT_layer_ST11.proj_with_attM.bias", "module.HtrgGAT_layer_ST11.proj_without_attM.weight", "module.HtrgGAT_layer_ST11.proj_without_attM.bias", "module.HtrgGAT_layer_ST11.bn.weight", "module.HtrgGAT_layer_ST11.bn.bias", "module.HtrgGAT_layer_ST11.bn.running_mean", "module.HtrgGAT_layer_ST11.bn.running_var", "module.HtrgGAT_layer_ST12.att_weight11", "module.HtrgGAT_layer_ST12.att_weight22", "module.HtrgGAT_layer_ST12.att_weight12", "module.HtrgGAT_layer_ST12.att_weightM", "module.HtrgGAT_layer_ST12.proj_type1.weight", "module.HtrgGAT_layer_ST12.proj_type1.bias", "module.HtrgGAT_layer_ST12.proj_type2.weight", "module.HtrgGAT_layer_ST12.proj_type2.bias", "module.HtrgGAT_layer_ST12.att_proj.weight", "module.HtrgGAT_layer_ST12.att_proj.bias", "module.HtrgGAT_layer_ST12.att_projM.weight", "module.HtrgGAT_layer_ST12.att_projM.bias", "module.HtrgGAT_layer_ST12.proj_with_att.weight", "module.HtrgGAT_layer_ST12.proj_with_att.bias", "module.HtrgGAT_layer_ST12.proj_without_att.weight", "module.HtrgGAT_layer_ST12.proj_without_att.bias", "module.HtrgGAT_layer_ST12.proj_with_attM.weight", "module.HtrgGAT_layer_ST12.proj_with_attM.bias", "module.HtrgGAT_layer_ST12.proj_without_attM.weight", "module.HtrgGAT_layer_ST12.proj_without_attM.bias", "module.HtrgGAT_layer_ST12.bn.weight", "module.HtrgGAT_layer_ST12.bn.bias", "module.HtrgGAT_layer_ST12.bn.running_mean", "module.HtrgGAT_layer_ST12.bn.running_var", "module.HtrgGAT_layer_ST21.att_weight11", "module.HtrgGAT_layer_ST21.att_weight22", "module.HtrgGAT_layer_ST21.att_weight12", "module.HtrgGAT_layer_ST21.att_weightM", "module.HtrgGAT_layer_ST21.proj_type1.weight", "module.HtrgGAT_layer_ST21.proj_type1.bias", "module.HtrgGAT_layer_ST21.proj_type2.weight", "module.HtrgGAT_layer_ST21.proj_type2.bias", "module.HtrgGAT_layer_ST21.att_proj.weight", "module.HtrgGAT_layer_ST21.att_proj.bias", "module.HtrgGAT_layer_ST21.att_projM.weight", "module.HtrgGAT_layer_ST21.att_projM.bias", "module.HtrgGAT_layer_ST21.proj_with_att.weight", "module.HtrgGAT_layer_ST21.proj_with_att.bias", "module.HtrgGAT_layer_ST21.proj_without_att.weight", "module.HtrgGAT_layer_ST21.proj_without_att.bias", "module.HtrgGAT_layer_ST21.proj_with_attM.weight", "module.HtrgGAT_layer_ST21.proj_with_attM.bias", "module.HtrgGAT_layer_ST21.proj_without_attM.weight", "module.HtrgGAT_layer_ST21.proj_without_attM.bias", "module.HtrgGAT_layer_ST21.bn.weight", "module.HtrgGAT_layer_ST21.bn.bias", "module.HtrgGAT_layer_ST21.bn.running_mean", "module.HtrgGAT_layer_ST21.bn.running_var", "module.HtrgGAT_layer_ST22.att_weight11", "module.HtrgGAT_layer_ST22.att_weight22", "module.HtrgGAT_layer_ST22.att_weight12", "module.HtrgGAT_layer_ST22.att_weightM", "module.HtrgGAT_layer_ST22.proj_type1.weight", "module.HtrgGAT_layer_ST22.proj_type1.bias", "module.HtrgGAT_layer_ST22.proj_type2.weight", "module.HtrgGAT_layer_ST22.proj_type2.bias", "module.HtrgGAT_layer_ST22.att_proj.weight", "module.HtrgGAT_layer_ST22.att_proj.bias", "module.HtrgGAT_layer_ST22.att_projM.weight", "module.HtrgGAT_layer_ST22.att_projM.bias", "module.HtrgGAT_layer_ST22.proj_with_att.weight", "module.HtrgGAT_layer_ST22.proj_with_att.bias", "module.HtrgGAT_layer_ST22.proj_without_att.weight", "module.HtrgGAT_layer_ST22.proj_without_att.bias", "module.HtrgGAT_layer_ST22.proj_with_attM.weight", "module.HtrgGAT_layer_ST22.proj_with_attM.bias", "module.HtrgGAT_layer_ST22.proj_without_attM.weight", "module.HtrgGAT_layer_ST22.proj_without_attM.bias", "module.HtrgGAT_layer_ST22.bn.weight", "module.HtrgGAT_layer_ST22.bn.bias", "module.HtrgGAT_layer_ST22.bn.running_mean", "module.HtrgGAT_layer_ST22.bn.running_var", "module.pool_S.proj.weight", "module.pool_S.proj.bias", "module.pool_T.proj.weight", "module.pool_T.proj.bias", "module.pool_hS1.proj.weight", "module.pool_hS1.proj.bias", "module.pool_hT1.proj.weight", "module.pool_hT1.proj.bias", "module.pool_hS2.proj.weight", "module.pool_hS2.proj.bias", "module.pool_hT2.proj.weight", "module.pool_hT2.proj.bias", "module.out_layer.weight", "module.out_layer.bias". 
	Unexpected key(s) in state_dict: "pos_S", "master1", "master2", "first_bn.weight", "first_bn.bias", "first_bn.running_mean", "first_bn.running_var", "first_bn.num_batches_tracked", "encoder.0.0.conv1.weight", "encoder.0.0.conv1.bias", "encoder.0.0.bn2.weight", "encoder.0.0.bn2.bias", "encoder.0.0.bn2.running_mean", "encoder.0.0.bn2.running_var", "encoder.0.0.bn2.num_batches_tracked", "encoder.0.0.conv2.weight", "encoder.0.0.conv2.bias", "encoder.0.0.conv_downsample.weight", "encoder.0.0.conv_downsample.bias", "encoder.1.0.bn1.weight", "encoder.1.0.bn1.bias", "encoder.1.0.bn1.running_mean", "encoder.1.0.bn1.running_var", "encoder.1.0.bn1.num_batches_tracked", "encoder.1.0.conv1.weight", "encoder.1.0.conv1.bias", "encoder.1.0.bn2.weight", "encoder.1.0.bn2.bias", "encoder.1.0.bn2.running_mean", "encoder.1.0.bn2.running_var", "encoder.1.0.bn2.num_batches_tracked", "encoder.1.0.conv2.weight", "encoder.1.0.conv2.bias", "encoder.2.0.bn1.weight", "encoder.2.0.bn1.bias", "encoder.2.0.bn1.running_mean", "encoder.2.0.bn1.running_var", "encoder.2.0.bn1.num_batches_tracked", "encoder.2.0.conv1.weight", "encoder.2.0.conv1.bias", "encoder.2.0.bn2.weight", "encoder.2.0.bn2.bias", "encoder.2.0.bn2.running_mean", "encoder.2.0.bn2.running_var", "encoder.2.0.bn2.num_batches_tracked", "encoder.2.0.conv2.weight", "encoder.2.0.conv2.bias", "encoder.2.0.conv_downsample.weight", "encoder.2.0.conv_downsample.bias", "encoder.3.0.bn1.weight", "encoder.3.0.bn1.bias", "encoder.3.0.bn1.running_mean", "encoder.3.0.bn1.running_var", "encoder.3.0.bn1.num_batches_tracked", "encoder.3.0.conv1.weight", "encoder.3.0.conv1.bias", "encoder.3.0.bn2.weight", "encoder.3.0.bn2.bias", "encoder.3.0.bn2.running_mean", "encoder.3.0.bn2.running_var", "encoder.3.0.bn2.num_batches_tracked", "encoder.3.0.conv2.weight", "encoder.3.0.conv2.bias", "encoder.4.0.bn1.weight", "encoder.4.0.bn1.bias", "encoder.4.0.bn1.running_mean", "encoder.4.0.bn1.running_var", "encoder.4.0.bn1.num_batches_tracked", "encoder.4.0.conv1.weight", "encoder.4.0.conv1.bias", "encoder.4.0.bn2.weight", "encoder.4.0.bn2.bias", "encoder.4.0.bn2.running_mean", "encoder.4.0.bn2.running_var", "encoder.4.0.bn2.num_batches_tracked", "encoder.4.0.conv2.weight", "encoder.4.0.conv2.bias", "encoder.5.0.bn1.weight", "encoder.5.0.bn1.bias", "encoder.5.0.bn1.running_mean", "encoder.5.0.bn1.running_var", "encoder.5.0.bn1.num_batches_tracked", "encoder.5.0.conv1.weight", "encoder.5.0.conv1.bias", "encoder.5.0.bn2.weight", "encoder.5.0.bn2.bias", "encoder.5.0.bn2.running_mean", "encoder.5.0.bn2.running_var", "encoder.5.0.bn2.num_batches_tracked", "encoder.5.0.conv2.weight", "encoder.5.0.conv2.bias", "GAT_layer_S.att_weight", "GAT_layer_S.att_proj.weight", "GAT_layer_S.att_proj.bias", "GAT_layer_S.proj_with_att.weight", "GAT_layer_S.proj_with_att.bias", "GAT_layer_S.proj_without_att.weight", "GAT_layer_S.proj_without_att.bias", "GAT_layer_S.bn.weight", "GAT_layer_S.bn.bias", "GAT_layer_S.bn.running_mean", "GAT_layer_S.bn.running_var", "GAT_layer_S.bn.num_batches_tracked", "GAT_layer_T.att_weight", "GAT_layer_T.att_proj.weight", "GAT_layer_T.att_proj.bias", "GAT_layer_T.proj_with_att.weight", "GAT_layer_T.proj_with_att.bias", "GAT_layer_T.proj_without_att.weight", "GAT_layer_T.proj_without_att.bias", "GAT_layer_T.bn.weight", "GAT_layer_T.bn.bias", "GAT_layer_T.bn.running_mean", "GAT_layer_T.bn.running_var", "GAT_layer_T.bn.num_batches_tracked", "HtrgGAT_layer_ST11.att_weight11", "HtrgGAT_layer_ST11.att_weight22", "HtrgGAT_layer_ST11.att_weight12", "HtrgGAT_layer_ST11.att_weightM", "HtrgGAT_layer_ST11.proj_type1.weight", "HtrgGAT_layer_ST11.proj_type1.bias", "HtrgGAT_layer_ST11.proj_type2.weight", "HtrgGAT_layer_ST11.proj_type2.bias", "HtrgGAT_layer_ST11.att_proj.weight", "HtrgGAT_layer_ST11.att_proj.bias", "HtrgGAT_layer_ST11.att_projM.weight", "HtrgGAT_layer_ST11.att_projM.bias", "HtrgGAT_layer_ST11.proj_with_att.weight", "HtrgGAT_layer_ST11.proj_with_att.bias", "HtrgGAT_layer_ST11.proj_without_att.weight", "HtrgGAT_layer_ST11.proj_without_att.bias", "HtrgGAT_layer_ST11.proj_with_attM.weight", "HtrgGAT_layer_ST11.proj_with_attM.bias", "HtrgGAT_layer_ST11.proj_without_attM.weight", "HtrgGAT_layer_ST11.proj_without_attM.bias", "HtrgGAT_layer_ST11.bn.weight", "HtrgGAT_layer_ST11.bn.bias", "HtrgGAT_layer_ST11.bn.running_mean", "HtrgGAT_layer_ST11.bn.running_var", "HtrgGAT_layer_ST11.bn.num_batches_tracked", "HtrgGAT_layer_ST12.att_weight11", "HtrgGAT_layer_ST12.att_weight22", "HtrgGAT_layer_ST12.att_weight12", "HtrgGAT_layer_ST12.att_weightM", "HtrgGAT_layer_ST12.proj_type1.weight", "HtrgGAT_layer_ST12.proj_type1.bias", "HtrgGAT_layer_ST12.proj_type2.weight", "HtrgGAT_layer_ST12.proj_type2.bias", "HtrgGAT_layer_ST12.att_proj.weight", "HtrgGAT_layer_ST12.att_proj.bias", "HtrgGAT_layer_ST12.att_projM.weight", "HtrgGAT_layer_ST12.att_projM.bias", "HtrgGAT_layer_ST12.proj_with_att.weight", "HtrgGAT_layer_ST12.proj_with_att.bias", "HtrgGAT_layer_ST12.proj_without_att.weight", "HtrgGAT_layer_ST12.proj_without_att.bias", "HtrgGAT_layer_ST12.proj_with_attM.weight", "HtrgGAT_layer_ST12.proj_with_attM.bias", "HtrgGAT_layer_ST12.proj_without_attM.weight", "HtrgGAT_layer_ST12.proj_without_attM.bias", "HtrgGAT_layer_ST12.bn.weight", "HtrgGAT_layer_ST12.bn.bias", "HtrgGAT_layer_ST12.bn.running_mean", "HtrgGAT_layer_ST12.bn.running_var", "HtrgGAT_layer_ST12.bn.num_batches_tracked", "HtrgGAT_layer_ST21.att_weight11", "HtrgGAT_layer_ST21.att_weight22", "HtrgGAT_layer_ST21.att_weight12", "HtrgGAT_layer_ST21.att_weightM", "HtrgGAT_layer_ST21.proj_type1.weight", "HtrgGAT_layer_ST21.proj_type1.bias", "HtrgGAT_layer_ST21.proj_type2.weight", "HtrgGAT_layer_ST21.proj_type2.bias", "HtrgGAT_layer_ST21.att_proj.weight", "HtrgGAT_layer_ST21.att_proj.bias", "HtrgGAT_layer_ST21.att_projM.weight", "HtrgGAT_layer_ST21.att_projM.bias", "HtrgGAT_layer_ST21.proj_with_att.weight", "HtrgGAT_layer_ST21.proj_with_att.bias", "HtrgGAT_layer_ST21.proj_without_att.weight", "HtrgGAT_layer_ST21.proj_without_att.bias", "HtrgGAT_layer_ST21.proj_with_attM.weight", "HtrgGAT_layer_ST21.proj_with_attM.bias", "HtrgGAT_layer_ST21.proj_without_attM.weight", "HtrgGAT_layer_ST21.proj_without_attM.bias", "HtrgGAT_layer_ST21.bn.weight", "HtrgGAT_layer_ST21.bn.bias", "HtrgGAT_layer_ST21.bn.running_mean", "HtrgGAT_layer_ST21.bn.running_var", "HtrgGAT_layer_ST21.bn.num_batches_tracked", "HtrgGAT_layer_ST22.att_weight11", "HtrgGAT_layer_ST22.att_weight22", "HtrgGAT_layer_ST22.att_weight12", "HtrgGAT_layer_ST22.att_weightM", "HtrgGAT_layer_ST22.proj_type1.weight", "HtrgGAT_layer_ST22.proj_type1.bias", "HtrgGAT_layer_ST22.proj_type2.weight", "HtrgGAT_layer_ST22.proj_type2.bias", "HtrgGAT_layer_ST22.att_proj.weight", "HtrgGAT_layer_ST22.att_proj.bias", "HtrgGAT_layer_ST22.att_projM.weight", "HtrgGAT_layer_ST22.att_projM.bias", "HtrgGAT_layer_ST22.proj_with_att.weight", "HtrgGAT_layer_ST22.proj_with_att.bias", "HtrgGAT_layer_ST22.proj_without_att.weight", "HtrgGAT_layer_ST22.proj_without_att.bias", "HtrgGAT_layer_ST22.proj_with_attM.weight", "HtrgGAT_layer_ST22.proj_with_attM.bias", "HtrgGAT_layer_ST22.proj_without_attM.weight", "HtrgGAT_layer_ST22.proj_without_attM.bias", "HtrgGAT_layer_ST22.bn.weight", "HtrgGAT_layer_ST22.bn.bias", "HtrgGAT_layer_ST22.bn.running_mean", "HtrgGAT_layer_ST22.bn.running_var", "HtrgGAT_layer_ST22.bn.num_batches_tracked", "pool_S.proj.weight", "pool_S.proj.bias", "pool_T.proj.weight", "pool_T.proj.bias", "pool_hS1.proj.weight", "pool_hS1.proj.bias", "pool_hT1.proj.weight", "pool_hT1.proj.bias", "pool_hS2.proj.weight", "pool_hS2.proj.bias", "pool_hT2.proj.weight", "pool_hT2.proj.bias", "out_layer.weight", "out_layer.bias". 