In [39]:
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import torchaudio
import numpy as np
from tqdm.auto import tqdm
import tqdm as notebook_tqdm
import wandb
import kagglehub
import os
from kaggle_secrets import UserSecretsClient

import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format='retina'

import seaborn as sns
sns.set_style("whitegrid") # 

LCNN model

In [45]:
# LCNN model

import torchaudio
from torch.utils.data import Dataset
import os
from torchaudio import transforms as T
from torch.nn import Sequential

class mfm(nn.Module):
    """Max-Feature-Map (MFM) layer.

    A convolution (``type == 1``) or a linear layer (any other ``type``)
    produces ``2 * out_channels`` features; the result is cut into two halves
    along dimension 1 and the element-wise maximum of the halves is returned,
    so the output carries ``out_channels`` features.

    Args:
        in_channels: number of input channels (conv) or input features (linear).
        out_channels: number of channels/features after the max operation.
        kernel_size, stride, padding: forwarded to ``nn.Conv2d``; ignored for
            the linear variant.
        type: ``1`` selects the convolutional variant, anything else the
            linear one.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, type=1):
        super().__init__()
        self.out_channels = out_channels

        # The underlying layer is twice as wide as the requested output so
        # that the two competing halves can be compared in forward().
        doubled_width = 2 * out_channels
        if type == 1:
            self.filter = nn.Conv2d(
                in_channels,
                doubled_width,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
            )
        else:
            self.filter = nn.Linear(in_channels, doubled_width)

    def forward(self, x):
        """Apply the layer and keep the element-wise max of its two halves."""
        first_half, second_half = torch.split(self.filter(x), self.out_channels, 1)
        return torch.maximum(first_half, second_half)


class LCNN(nn.Module):
    """Light CNN classifier built from MFM blocks.

    The feature extractor consists of MFM convolutions, batch norms and four
    ``MaxPool2d(2, 2)`` layers; the head is an MFM linear layer followed by
    batch norm, dropout and a final linear layer producing class logits.

    Args:
        input_shape: ``(height, width)`` of the single-channel input map.
            The default ``(863, 600)`` reproduces the original hard-coded
            flattened size ``32 * 53 * 37``.
        num_classes: number of output logits (default 2).

    Raises:
        ValueError: if ``input_shape`` is too small to survive the four
            pooling stages.
    """

    # Number of MaxPool2d(2, 2) stages in ``features``.
    _NUM_POOLS = 4

    def __init__(self, input_shape=(863, 600), num_classes=2):
        super().__init__()
        self.features = nn.Sequential(
            mfm(1, 32, 5, 1, 2), # 2

            nn.MaxPool2d(2, 2),

            mfm(32, 32, 1, 1, 0), # 5
            nn.BatchNorm2d(32),
            mfm(32, 48), # 8

            nn.MaxPool2d(2, 2),
            nn.BatchNorm2d(48),

            mfm(48, 48, 1, 1, 0), # 12
            nn.BatchNorm2d(48),
            mfm(48, 64), # 15

            nn.MaxPool2d(2, 2),

            mfm(64, 64, 1, 1, 0), # 18
            nn.BatchNorm2d(64),
            mfm(64, 32), # 21
            nn.BatchNorm2d(32),
            mfm(32, 32, 1, 1, 0), # 24
            nn.BatchNorm2d(32),
            mfm(32, 32), # 27

            nn.MaxPool2d(2, 2)
        )

        # Every convolution above keeps the spatial size (kernel 5 / pad 2,
        # kernel 3 / pad 1, kernel 1 / pad 0), so only the pooling layers
        # shrink the map: each one floor-halves both dimensions.
        height, width = input_shape
        for _ in range(self._NUM_POOLS):
            height, width = height // 2, width // 2
        if height < 1 or width < 1:
            raise ValueError(
                f"input_shape {tuple(input_shape)} is too small for "
                f"{self._NUM_POOLS} pooling stages"
            )
        flat_features = 32 * height * width

        self.fc = nn.Sequential(
            mfm(flat_features, 80, type=0),  # FC-29 + MFM-30
            nn.BatchNorm1d(80),
            nn.Dropout(p=0.3),
            nn.Linear(80, num_classes)  # FC-32
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal init for conv/linear weights, identity init for batch norms."""
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        """Return class logits for a batch shaped ``(batch, 1, height, width)``."""
        x = self.features(x)
        # Flatten everything except the batch dimension for the linear head.
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x


Dataset

In [41]:
# dataset v2

import os
import logging
import random
from typing import List

import torch
from torch.utils.data import Dataset
import torchaudio
import torchaudio.transforms as T
import torch.nn.functional as F

def get_index_from_path(path_of_dir : str, path_of_protocol : str) -> list[dict]:
    """Build a dataset index from a whitespace-separated protocol file.

    For every non-empty line the second field is taken as the audio file stem
    (``.flac`` is appended and the result is joined with ``path_of_dir``) and
    the last field as the class key: ``'bonafide'`` maps to label 1, anything
    else to label 0.

    Args:
        path_of_dir: directory that contains the ``.flac`` files.
        path_of_protocol: path to the protocol text file.

    Returns:
        A list of ``{"path": str, "label": int}`` dicts, in file order.
    """
    _index = []
    with open(path_of_protocol, 'r') as f:
        for line in f:
            parts = line.split()
            if not parts:
                # Blank lines (e.g. a trailing newline at the end of the file)
                # carry no record; previously they raised IndexError.
                continue
            _index.append(
                {
                    "path": os.path.join(path_of_dir, parts[1] + '.flac'),
                    "label": 1 if parts[-1] == 'bonafide' else 0,
                }
            )
    return _index


class BaseDataset(Dataset):
    """Audio dataset that yields fixed-width log-power spectrograms.

    Each index entry is a dict with a ``"path"`` (audio file) and a
    ``"label"``. Items are loaded with torchaudio, converted to a power
    spectrogram, mapped to decibels and padded/cropped along the time axis to
    ``max_frames`` frames.
    """

    # Fill value used to right-pad spectrograms shorter than ``max_frames``.
    PAD_VALUE_DB = -80

    def __init__(
        self, index, limit=None, shuffle_index=False, instance_transforms=None
    ):
        """
        Args:
            index (list[dict]): list, containing dict for each element of
                the dataset. Every dict must have the keys 'path' and 'label'.
            limit (int | None): if not None, keep only the first 'limit'
                elements (after the optional shuffle).
            shuffle_index (bool): if True, shuffle a copy of the index with a
                fixed seed.
            instance_transforms: stored on the instance.
                NOTE(review): nothing in this class applies it — confirm
                whether it is still needed.
        """
        self._assert_index_is_valid(index)

        index = self._shuffle_and_limit_index(index, limit, shuffle_index)
        self._index: List[dict] = index

        self.instance_transforms = instance_transforms
        self.max_frames = 600

        # center=False: frames are not padded at the signal edges, so the
        # waveform must contain at least win_length samples.
        self.fft_transform = T.Spectrogram(
            n_fft=1724,
            hop_length=130,
            win_length=1724,
            window_fn=torch.blackman_window,
            power=2,
            center=False
        )
        self.db_transform = T.AmplitudeToDB(stype='power')

    def __getitem__(self, ind):
        """
        Load one element of the index and turn it into model input.

        Args:
            ind (int): index in the self._index list.
        Returns:
            tuple: ``(log_spectrogram, label)`` where ``log_spectrogram`` is
                the output of ``preprocess_data`` and ``label`` is the value
                stored under 'label' in the index entry.
        """
        data_dict = self._index[ind]
        # The sample rate returned by the loader is not needed here.
        waveform, _ = self._load_object(data_dict["path"])
        log_spectrogram = self.preprocess_data(waveform)

        return log_spectrogram, data_dict["label"]

    def __len__(self):
        """
        Get length of the dataset (length of the index).
        """
        return len(self._index)

    def _load_object(self, path):
        """
        Load an audio file from disk.

        Args:
            path (str): path to the audio file.
        Returns:
            tuple: ``(waveform, sample_rate)`` as returned by
                ``torchaudio.load``.
        """
        wave, sr = torchaudio.load(path)
        return wave, sr

    def preprocess_data(self, waveform):
        """
        Convert a waveform into a fixed-width log-power spectrogram.

        Multi-channel input is averaged down to one channel first, so that
        the time axis is always the one being padded/cropped. Spectrograms
        with fewer than ``max_frames`` frames are right-padded with
        ``PAD_VALUE_DB``; longer ones are cropped to the first ``max_frames``
        frames.

        Args:
            waveform (Tensor): audio samples, typically ``(channels, time)``.
        Returns:
            Tensor: spectrogram with a leading singleton dimension added;
                for mono input this is ``(1, freq_bins, max_frames)``.
        """
        if waveform.dim() == 2 and waveform.size(0) > 1:
            # Without the down-mix, squeeze(0) below would be a no-op and the
            # frequency axis would be padded/cropped instead of time.
            waveform = waveform.mean(dim=0, keepdim=True)

        power_spectrum = self.fft_transform(waveform)
        log_spectrogram = self.db_transform(power_spectrum).squeeze(0)

        num_frames = log_spectrogram.shape[-1]
        if num_frames < self.max_frames:
            # F.pad with a 2-tuple pads the last (time) dimension on the right.
            log_spectrogram = F.pad(
                log_spectrogram,
                (0, self.max_frames - num_frames),
                mode='constant',
                value=self.PAD_VALUE_DB
            )
        else:
            log_spectrogram = log_spectrogram[..., :self.max_frames]

        return log_spectrogram.unsqueeze(0)

    @staticmethod
    def _filter_records_from_dataset(
        index: list,
    ) -> list:
        """
        Filter some of the elements from the dataset depending on
        some condition.

        No filter is implemented, so the index is returned unchanged. The
        method is not called by this class; if used, it should be called in
        the __init__ before shuffling and limiting.

        Args:
            index (list[dict]): list, containing dict for each element of
                the dataset. The dict has required metadata information,
                such as label and object path.
        Returns:
            index (list[dict]): the (currently unfiltered) index.
        """
        # Placeholder for filter logic; returning the index keeps the
        # documented contract instead of handing back None.
        return index

    @staticmethod
    def _assert_index_is_valid(index):
        """
        Check the structure of the index and ensure it satisfies the desired
        conditions.

        Args:
            index (list[dict]): list, containing dict for each element of
                the dataset. The dict has required metadata information,
                such as label and object path.
        Raises:
            AssertionError: if an entry lacks 'path' or 'label'.
        """
        for entry in index:
            assert "path" in entry, (
                "Each dataset item should include field 'path'" " - path to audio file."
            )
            assert "label" in entry, (
                "Each dataset item should include field 'label'"
                " - object ground-truth label."
            )

    @staticmethod
    def _sort_index(index, key="KEY_FOR_SORTING"):
        """
        Sort index by the value stored under ``key`` in every entry.

        The method is not called by this class; if used, it should be called
        in the __init__ before shuffling and limiting and after filtering.

        Args:
            index (list[dict]): list, containing dict for each element of
                the dataset.
            key (str): dict key to sort by. The default is the template
                placeholder; pass a key that exists in the entries.
        Returns:
            index (list[dict]): new sorted list.
        Raises:
            KeyError: if an entry does not contain ``key``.
        """
        return sorted(index, key=lambda x: x[key])

    @staticmethod
    def _shuffle_and_limit_index(index, limit, shuffle_index):
        """
        Shuffle elements in index and limit the total number of elements.

        The caller's list is never modified and the global ``random`` state
        is left untouched: shuffling happens on a copy with a private
        generator seeded with 42 (same permutation as seeding the global
        generator with 42).

        Args:
            index (list[dict]): list, containing dict for each element of
                the dataset. The dict has required metadata information,
                such as label and object path.
            limit (int | None): if not None, limit the total number of elements
                in the dataset to 'limit' elements.
            shuffle_index (bool): if True, shuffle the index.
        Returns:
            index (list[dict]): the shuffled and/or truncated index.
        """
        if shuffle_index:
            index = list(index)
            random.Random(42).shuffle(index)

        if limit is not None:
            index = index[:limit]
        return index

Compute EER

In [None]:
def compute_det_curve(target_scores, nontarget_scores):
    """Compute the points of a detection error trade-off (DET) curve.

    Args:
        target_scores (np.ndarray): scores of target trials.
        nontarget_scores (np.ndarray): scores of non-target trials.

    Returns:
        tuple: ``(frr, far, thresholds)`` — false rejection rates, false
            acceptance rates and the matching thresholds. Each array has
            ``target_scores.size + nontarget_scores.size + 1`` entries.
    """
    num_target = target_scores.size
    num_nontarget = nontarget_scores.size
    total = num_target + num_nontarget

    pooled = np.concatenate((target_scores, nontarget_scores))
    is_target = np.concatenate((np.ones(num_target), np.zeros(num_nontarget)))

    # Stable sort so that ties keep their original relative order.
    order = np.argsort(pooled, kind='mergesort')
    sorted_scores = pooled[order]
    is_target = is_target[order]

    # Targets at or below each sorted score, and non-targets strictly above it.
    targets_at_or_below = np.cumsum(is_target)
    ranks = np.arange(1, total + 1)
    nontargets_above = num_nontarget - (ranks - targets_at_or_below)

    # Prepend the operating point "accept everything": FRR = 0, FAR = 1.
    frr = np.concatenate((np.atleast_1d(0), targets_at_or_below / num_target))
    far = np.concatenate((np.atleast_1d(1), nontargets_above / num_nontarget))

    # The extra leading threshold sits just under the smallest score.
    lowest = np.atleast_1d(sorted_scores[0] - 0.001)
    thresholds = np.concatenate((lowest, sorted_scores))

    return frr, far, thresholds


def compute_eer(bonafide_scores, other_scores):
    """Return the equal error rate (EER) and the threshold where it occurs.

    The EER is taken at the DET-curve point where the false rejection and
    false acceptance rates are closest, as the mean of the two rates.

    Args:
        bonafide_scores (np.ndarray): scores of bonafide (target) trials.
        other_scores (np.ndarray): scores of all other (non-target) trials.

    Returns:
        tuple: ``(eer, threshold)``.
    """
    frr, far, thresholds = compute_det_curve(bonafide_scores, other_scores)
    # Index of the operating point where the two error rates are closest.
    crossing = np.argmin(np.abs(frr - far))
    eer = np.mean((frr[crossing], far[crossing]))
    return eer, thresholds[crossing]

Training pipeline


In [None]:
import torch
import numpy as np

def train_one_epoch(model, dataloader, criterion, optimizer, scheduler, device):
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: module to train; switched to train mode.
        dataloader: iterable of ``(input, label)`` batches.
        criterion: loss called as ``criterion(output, label)``.
        optimizer: optimizer stepped once per batch.
        scheduler: accepted for signature compatibility; it is NOT stepped
            here (no per-batch learning-rate update is performed).
        device: device the batches are moved to.

    Returns:
        float: mean of the per-batch losses.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    # Some layers (dropout, batch norm) behave differently in train mode.
    model.train()

    total_loss = 0.0
    num_batches = 0
    for image, label in tqdm(dataloader, total=len(dataloader)):
        image, label = image.to(device), label.to(device)

        # Clear gradients first so stale ones can never leak into this step.
        optimizer.zero_grad()
        output = model(image)
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()

        # .item() detaches the scalar loss and moves it to the CPU.
        total_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        raise ValueError("train_one_epoch received an empty dataloader")
    return total_loss / num_batches

def evaluate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader``.

    Args:
        model: module to evaluate; switched to eval mode.
        dataloader: iterable of ``(input, label)`` batches.
        criterion: loss called as ``criterion(output, label)``.
        device: device the batches are moved to.

    Returns:
        tuple: ``(avg_loss, accuracy, eer)`` where ``avg_loss`` is the mean of
            the per-batch losses, ``accuracy`` is in percent over all samples
            and ``eer`` is the equal error rate in percent. ``eer`` is NaN
            when the data contains no label-1 or no label-0 samples.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model.eval()

    total_loss = 0.0
    num_batches = 0
    num_correct = 0
    num_samples = 0

    # Scores and labels of every sample are kept for the EER computation.
    score_chunks = []
    label_chunks = []

    with torch.no_grad():
        for image, label in dataloader:
            image, label = image.to(device), label.to(device)

            output = model(image)
            loss = criterion(output, label)

            # Probability of class 1, used as the bonafide score below.
            scores = torch.softmax(output, dim=1)[:, 1]
            score_chunks.append(scores.cpu().numpy())
            label_chunks.append(label.cpu().numpy())

            num_correct += (output.argmax(-1) == label).sum().item()
            num_samples += output.shape[0]
            total_loss += loss.item()
            num_batches += 1

    if num_batches == 0:
        raise ValueError("evaluate received an empty dataloader")

    avg_loss = total_loss / num_batches
    accuracy = 100 * num_correct / num_samples

    all_scores = np.concatenate(score_chunks)
    all_labels = np.concatenate(label_chunks)

    bonafide_scores = all_scores[all_labels == 1]
    other_scores = all_scores[all_labels == 0]

    if bonafide_scores.size == 0 or other_scores.size == 0:
        # The error rates divide by the class sizes, so the EER is undefined
        # when one of the classes is missing.
        eer = float("nan")
    else:
        eer, _ = compute_eer(bonafide_scores, other_scores)

    return avg_loss, accuracy, eer * 100

def train(model, train_dataloader, val_dataloader, test_dataloader, criterion, optimizer, scheduler, device, n_epochs):
    """Train for ``n_epochs`` epochs, validating after each, then test once.

    After every epoch the validation metrics are printed and logged to the
    active wandb run at step ``epoch + 1``, and ``scheduler`` is stepped once.
    Validation metrics are logged as ``val_avg_loss`` / ``val_accuracy`` /
    ``val_eer`` (they used to be logged under ``test_*`` / ``err`` names,
    all at the fixed step 8).

    Args:
        model: module to train.
        train_dataloader: batches used for optimisation.
        val_dataloader: batches evaluated after every epoch.
        test_dataloader: batches evaluated once after the last epoch.
        criterion: loss function.
        optimizer: optimizer.
        scheduler: learning-rate scheduler, stepped once per epoch.
        device: device used for computation.
        n_epochs (int): number of epochs.

    Returns:
        dict: per-epoch lists under ``train_avg_loss``, ``val_avg_loss``,
            ``val_accuracy`` and ``val_eer``, plus the final ``test_avg_loss``,
            ``test_accuracy`` and ``test_eer`` values.
    """
    history = {
        "train_avg_loss": [],
        "val_avg_loss": [],
        "val_accuracy": [],
        "val_eer": [],
    }

    for epoch in range(n_epochs):
        print(f"\n--- Epoch {epoch+1}/{n_epochs} ---")
        train_avg_loss = train_one_epoch(model, train_dataloader, criterion, optimizer, scheduler, device)
        val_avg_loss, val_accuracy, err_val = evaluate(model, val_dataloader, criterion, device)
        print("val_avg_loss:", val_avg_loss)
        print("val_accuracy:", val_accuracy)
        print("err_val:", err_val)
        scheduler.step()

        history["train_avg_loss"].append(train_avg_loss)
        history["val_avg_loss"].append(val_avg_loss)
        history["val_accuracy"].append(val_accuracy)
        history["val_eer"].append(err_val)

        # One wandb step per epoch; a constant step would write every epoch
        # into the same history row.
        wandb.log({
            "train_avg_loss": train_avg_loss,
            "val_avg_loss": val_avg_loss,
            "val_accuracy": val_accuracy,
            "val_eer": err_val
        }, step=epoch + 1)

    test_avg_loss, test_accuracy, err_t = evaluate(model, test_dataloader, criterion, device)
    print("test_avg_loss:", test_avg_loss)
    print("test_accuracy:", test_accuracy)
    print("test_val:", err_t)

    history["test_avg_loss"] = test_avg_loss
    history["test_accuracy"] = test_accuracy
    history["test_eer"] = err_t
    return history
    


Train

In [None]:
# --- Hyper-parameters (single source of truth, mirrored into the wandb config) ---
criterion = nn.CrossEntropyLoss()
NUM_EPOCHS = 8
LEARNING_RATE = 3e-4
BATCH_SIZE = 8
SEED = 42

# Seed torch so weight initialisation and DataLoader shuffling are repeatable.
torch.manual_seed(SEED)

# --- Data ---
DATA_ROOT = "/kaggle/input/asvpoof-2019-dataset/LA/LA"
PROTOCOL_DIR = os.path.join(DATA_ROOT, "ASVspoof2019_LA_cm_protocols")


def _build_dataset(split, protocol_file):
    """Create a BaseDataset for one partition ('train', 'dev' or 'eval')."""
    return BaseDataset(
        get_index_from_path(
            os.path.join(DATA_ROOT, f"ASVspoof2019_LA_{split}", "flac"),
            os.path.join(PROTOCOL_DIR, protocol_file),
        )
    )


train_dataset = _build_dataset("train", "ASVspoof2019.LA.cm.train.trn.txt")
# NOTE: `val_dataset` holds the *eval* partition; it is passed to train() as
# the final test set, while `dev_dataset` is used for per-epoch validation.
val_dataset = _build_dataset("eval", "ASVspoof2019.LA.cm.eval.trl.txt")
dev_dataset = _build_dataset("dev", "ASVspoof2019.LA.cm.dev.trl.txt")

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, pin_memory=True)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=True)
dev_dataloader = DataLoader(dev_dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=True)

wandb.login(key=UserSecretsClient().get_secret("wandb"))

# Run configuration
config = {
    "learning_rate": LEARNING_RATE,
    "batch_size": BATCH_SIZE,
    "epochs": NUM_EPOCHS,
    "architecture": "LCNN"
}

# Model, optimizer and learning-rate schedule
device = "cuda" if torch.cuda.is_available() else "cpu"
model = LCNN()
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=config["learning_rate"], weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS, eta_min=1e-4)

try:
    with wandb.init(project="lcnn", name="lcnn_run", config=config, reinit=True):
        try:
            # Argument order: train, validation (dev partition), test (eval partition).
            train(model, train_dataloader, dev_dataloader, val_dataloader,
                  criterion, optimizer, scheduler, device, config["epochs"])
        except Exception as e:
            # Send the alert while the run context is still open; once the
            # `with` block has exited there is no active run to alert on.
            wandb.alert(title="Training Failed", text=str(e))
            raise
finally:
    # Kept as a safety net for the case where the context did not close the run.
    wandb.finish()