In [None]:
# # RawNet2 Implementation for Audio Deepfake Detection
# Part 2 of Momenta Assessment
# Integrates provided RawNet model and ASVDataset for ASVspoof 5

# ## 1. Setup
# Install dependencies and import libraries

!pip install torch torchvision torchaudio librosa numpy matplotlib tensorboardX pyyaml soundfile joblib
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import numpy as np
import math
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import yaml
from tensorboardX import SummaryWriter
import matplotlib.pyplot as plt
import random
import collections
import soundfile as sf
from joblib import Parallel, delayed

# ## 2. Model Definition
# RawNet (RawNet2 variant) from provided code

class SincConv(nn.Module):
    """Fixed bank of Hamming-windowed sinc band-pass filters applied with ``F.conv1d``.

    The band edges are spaced uniformly on the mel scale between 0 Hz and
    ``sample_rate / 2``. The filters have no trainable parameters, so the whole
    bank is computed once at construction time and stored in ``self.band_pass``
    (a plain tensor attribute, deliberately not a parameter or buffer so that
    ``state_dict`` contents stay unchanged).

    Args:
        device: Kept as ``self.device`` for interface compatibility. The
            filters follow the device of the input passed to ``forward``.
        out_channels: Number of band-pass filters produced.
        kernel_size: Filter length in samples; even values are bumped to the
            next odd number so the filter is symmetric around its centre tap.
        in_channels: Must be 1.
        sample_rate: Sampling rate in Hz used to normalise band edges.
        stride, padding, dilation: Forwarded to ``F.conv1d``.
        bias: Must be False.
        groups: Must be 1.
        freq_scale: Only ``'Mel'`` is implemented.

    Raises:
        ValueError: For unsupported ``in_channels``, ``bias``, ``groups`` or
            ``freq_scale`` values.
    """

    @staticmethod
    def to_mel(hz):
        """Convert a frequency (or array of frequencies) in Hz to mel."""
        return 2595 * np.log10(1 + hz / 700)

    @staticmethod
    def to_hz(mel):
        """Convert a mel value (or array of mel values) back to Hz."""
        return 700 * (10 ** (mel / 2595) - 1)

    def __init__(self, device, out_channels, kernel_size, in_channels=1, sample_rate=16000,
                 stride=1, padding=0, dilation=1, bias=False, groups=1, freq_scale='Mel'):
        super(SincConv, self).__init__()
        if in_channels != 1:
            raise ValueError(f"SincConv only supports one input channel (here, in_channels = {in_channels})")

        # One extra band edge is needed to delimit ``out_channels`` bands.
        self.out_channels = out_channels + 1
        self.kernel_size = kernel_size if kernel_size % 2 else kernel_size + 1
        self.sample_rate = sample_rate
        self.device = device
        self.stride = stride
        self.padding = padding
        self.dilation = dilation

        if bias:
            raise ValueError('SincConv does not support bias.')
        if groups > 1:
            raise ValueError('SincConv does not support groups.')
        if freq_scale != 'Mel':
            # Previously this left ``self.freq`` undefined and only failed
            # later, inside ``forward``, with an AttributeError.
            raise ValueError(f"SincConv only supports freq_scale='Mel' (got {freq_scale!r})")

        NFFT = 512
        f = int(self.sample_rate / 2) * np.linspace(0, 1, int(NFFT / 2) + 1)
        fmel = self.to_mel(f)
        filbandwidthsmel = np.linspace(np.min(fmel), np.max(fmel), self.out_channels + 2)
        filbandwidthsf = self.to_hz(filbandwidthsmel)
        self.freq = filbandwidthsf[:self.out_channels]

        # Symmetric sample offsets around the filter centre.
        self.hsupp = torch.arange(-(self.kernel_size - 1) / 2, (self.kernel_size - 1) / 2 + 1)
        self.band_pass = self._build_band_pass()

    def _build_band_pass(self):
        """Return the ``(n_filters, kernel_size)`` bank of windowed band-pass filters."""
        band_pass = torch.zeros(self.out_channels - 1, self.kernel_size)
        support = self.hsupp.numpy()
        window = Tensor(np.hamming(self.kernel_size))
        for i, (fmin, fmax) in enumerate(zip(self.freq[:-1], self.freq[1:])):
            # A band-pass response is the difference of two low-pass sincs.
            hHigh = (2 * fmax / self.sample_rate) * np.sinc(2 * fmax * support / self.sample_rate)
            hLow = (2 * fmin / self.sample_rate) * np.sinc(2 * fmin * support / self.sample_rate)
            band_pass[i, :] = window * Tensor(hHigh - hLow)
        return band_pass

    def forward(self, x):
        """Convolve ``x`` (passed straight to ``F.conv1d`` with one input channel) with the filter bank."""
        # Follow the input's device/dtype instead of the device string given
        # at construction, so the layer keeps working if the model is moved.
        filters = self.band_pass.to(device=x.device, dtype=x.dtype)
        filters = filters.view(self.out_channels - 1, 1, self.kernel_size)
        return F.conv1d(x, filters, stride=self.stride, padding=self.padding, dilation=self.dilation)

class Residual_block(nn.Module):
    """Pre-activation 1-D residual block followed by max-pooling (kernel 3).

    Args:
        nb_filts: Indexable pair ``[in_channels, out_channels]``.
        first: When True the leading batch-norm layer is omitted.
    """

    def __init__(self, nb_filts, first=False):
        super().__init__()
        in_ch = nb_filts[0]
        out_ch = nb_filts[1]

        self.first = first
        if not first:
            self.bn1 = nn.BatchNorm1d(num_features=in_ch)
        self.lrelu = nn.LeakyReLU(negative_slope=0.3)
        self.conv1 = nn.Conv1d(in_ch, out_ch, kernel_size=3, padding=1, stride=1)
        self.bn2 = nn.BatchNorm1d(out_ch)
        self.conv2 = nn.Conv1d(out_ch, out_ch, kernel_size=3, padding=1, stride=1)

        # A 1x1 convolution matches channel counts on the skip path when the
        # block changes width.
        self.downsample = in_ch != out_ch
        if self.downsample:
            self.conv_downsample = nn.Conv1d(in_ch, out_ch, kernel_size=1, stride=1)
        self.mp = nn.MaxPool1d(3)

    def forward(self, x):
        """Return the pooled sum of the residual branch and the skip path."""
        branch = x if self.first else self.bn1(x)
        branch = self.conv1(self.lrelu(branch))
        branch = self.conv2(self.lrelu(self.bn2(branch)))

        skip = self.conv_downsample(x) if self.downsample else x
        return self.mp(branch + skip)

class RawNet(nn.Module):
    """RawNet2-style classifier operating on raw waveforms.

    Pipeline: sinc filter bank -> max-pool/BN/SELU -> six residual blocks,
    each followed by a sigmoid channel-attention rescaling -> BN/SELU -> GRU
    -> two linear layers producing ``nb_classes`` logits.

    Args:
        d_args: Configuration mapping. Keys read here: ``'filts'``
            (``[sinc_channels, [in, out], [in, out], ...]``), ``'first_conv'``,
            ``'in_channels'``, ``'gru_node'``, ``'nb_gru_layer'``,
            ``'nb_fc_node'`` and ``'nb_classes'``. The mapping is NOT modified.
        device: Passed to ``SincConv`` and stored as ``self.device``.
    """

    def __init__(self, d_args, device):
        super(RawNet, self).__init__()
        self.device = device

        filts = d_args['filts']
        # Work on copies: the original implementation overwrote
        # d_args['filts'][2][0] in place, so building a second model from the
        # same config silently produced a different (incompatible) block2.
        narrow = list(filts[1])
        widen = list(filts[2])
        wide = [filts[2][1], filts[2][1]]

        self.Sinc_conv = SincConv(device=self.device, out_channels=filts[0],
                                  kernel_size=d_args['first_conv'], in_channels=d_args['in_channels'])
        self.first_bn = nn.BatchNorm1d(filts[0])
        self.selu = nn.SELU(inplace=True)
        self.block0 = nn.Sequential(Residual_block(nb_filts=narrow, first=True))
        self.block1 = nn.Sequential(Residual_block(nb_filts=narrow))
        self.block2 = nn.Sequential(Residual_block(nb_filts=widen))
        self.block3 = nn.Sequential(Residual_block(nb_filts=wide))
        self.block4 = nn.Sequential(Residual_block(nb_filts=wide))
        self.block5 = nn.Sequential(Residual_block(nb_filts=wide))
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.fc_attention0 = self._make_attention_fc(narrow[-1], narrow[-1])
        self.fc_attention1 = self._make_attention_fc(narrow[-1], narrow[-1])
        self.fc_attention2 = self._make_attention_fc(wide[-1], wide[-1])
        self.fc_attention3 = self._make_attention_fc(wide[-1], wide[-1])
        self.fc_attention4 = self._make_attention_fc(wide[-1], wide[-1])
        self.fc_attention5 = self._make_attention_fc(wide[-1], wide[-1])
        self.bn_before_gru = nn.BatchNorm1d(wide[-1])
        self.gru = nn.GRU(wide[-1], d_args['gru_node'], num_layers=d_args['nb_gru_layer'], batch_first=True)
        self.fc1_gru = nn.Linear(d_args['gru_node'], d_args['nb_fc_node'])
        self.fc2_gru = nn.Linear(d_args['nb_fc_node'], d_args['nb_classes'], bias=True)
        self.sig = nn.Sigmoid()

    def forward(self, x, y=None, is_test=False):
        """Classify a batch of waveforms.

        Args:
            x: 2-D tensor ``(batch, samples)``; it is reshaped to
                ``(batch, 1, samples)`` before the sinc layer.
            y: Unused; accepted so callers can pass labels positionally.
            is_test: When True, return softmax probabilities instead of logits.

        Returns:
            Tensor of shape ``(batch, nb_classes)``.
        """
        nb_samp, len_seq = x.shape[0], x.shape[1]
        x = x.view(nb_samp, 1, len_seq)
        x = self.Sinc_conv(x)
        x = F.max_pool1d(torch.abs(x), 3)
        x = self.first_bn(x)
        x = self.selu(x)

        stages = (
            (self.block0, self.fc_attention0),
            (self.block1, self.fc_attention1),
            (self.block2, self.fc_attention2),
            (self.block3, self.fc_attention3),
            (self.block4, self.fc_attention4),
            (self.block5, self.fc_attention5),
        )
        for block, fc_attention in stages:
            x = self._apply_attention(block(x), fc_attention)

        x = self.bn_before_gru(x)
        x = self.selu(x)
        # GRU is batch_first: (batch, channels, time) -> (batch, time, channels).
        x = x.permute(0, 2, 1)
        self.gru.flatten_parameters()
        x, _ = self.gru(x)
        # Keep only the output at the final time step.
        x = x[:, -1, :]
        x = self.fc1_gru(x)
        x = self.fc2_gru(x)

        return x if not is_test else F.softmax(x, dim=1)

    def _apply_attention(self, feats, fc_attention):
        """Rescale ``feats`` per channel: ``feats * gate + gate`` with a sigmoid gate."""
        pooled = self.avgpool(feats).view(feats.size(0), -1)
        scores = fc_attention(pooled)
        gate = self.sig(scores).view(scores.size(0), scores.size(1), -1)
        return feats * gate + gate

    def _make_attention_fc(self, in_features, l_out_features):
        """Build the single linear layer used to compute attention scores."""
        return nn.Sequential(nn.Linear(in_features, l_out_features))

# ## 3. Dataset Definition
# ASVDataset from data_utils_LA.py with augmentations

# Metadata record for one utterance listed in a protocol file.
ASVFile = collections.namedtuple(
    'ASVFile',
    ('speaker_id', 'file_name', 'path', 'sys_id', 'key'),
)

class ASVDataset(Dataset):
    """In-memory dataset of utterances described by an ASVspoof2019-LA style protocol file.

    On first use every audio file listed in the protocol is read, optionally
    transformed, and the result is cached on disk with ``torch.save``; later
    constructions with the same split/feature name reload that cache.

    Args:
        database_path: Root directory containing
            ``ASVspoof2019_LA_<split>/flac``.
        protocols_path: Directory containing the
            ``ASVspoof2019.LA.cm.<split>.txt`` protocol files.
        transform: Optional callable applied once to every waveform before
            caching.
        is_train: Selects the train split (otherwise dev) when ``is_eval`` is
            False.
        sample_size: If truthy, draw this many items at random (with
            replacement) from the loaded data.
        is_logical: Stored as ``self.is_logical``; not otherwise used here.
        feature_name: Required; becomes part of the cache file name.
        is_eval: Selects the eval split and the eval system-id table.
        eval_part: Unused; kept for interface compatibility.

    Each item is ``(waveform, label, meta)`` where ``label`` is ``1.0`` for
    bonafide and ``0.0`` otherwise, and ``meta`` is the ``ASVFile`` record.
    Random augmentation is applied to the training split only, so that dev and
    eval results are repeatable.
    """

    def __init__(self, database_path=None, protocols_path=None, transform=None,
                 is_train=True, sample_size=None, is_logical=True, feature_name=None,
                 is_eval=False, eval_part=0):
        track = 'LA'
        assert feature_name is not None, 'must provide feature name'
        self.track = track
        self.is_logical = is_logical
        self.prefix = 'ASVspoof2019_{}'.format(track)
        self.sysid_dict = {
            '-': 0, 'A01': 1, 'A02': 2, 'A03': 3, 'A04': 4, 'A05': 5, 'A06': 6
        } if not is_eval else {
            '-': 0, 'A07': 1, 'A08': 2, 'A09': 3, 'A10': 4, 'A11': 5, 'A12': 6,
            'A13': 7, 'A14': 8, 'A15': 9, 'A16': 10, 'A17': 11, 'A18': 12, 'A19': 13
        }
        self.data_root_dir = database_path
        self.is_eval = is_eval
        self.sysid_dict_inv = {v: k for k, v in self.sysid_dict.items()}
        self.data_root = protocols_path
        self.dset_name = 'eval' if is_eval else 'train' if is_train else 'dev'
        self.protocols_fname = 'eval.trl' if is_eval else 'train.trn' if is_train else 'dev.trl'
        self.protocols_dir = os.path.join(self.data_root)
        self.files_dir = os.path.join(self.data_root_dir, f'{self.prefix}_{self.dset_name}', 'flac')
        self.protocols_fname = os.path.join(self.protocols_dir, f'ASVspoof2019.{track}.cm.{self.protocols_fname}.txt')
        self.cache_fname = f'cache_{self.dset_name}_{track}_{feature_name}.npy'
        self.transform = transform
        # Perturbing dev/eval samples would make reported accuracy noisy and
        # non-repeatable, so augmentation is restricted to the training split.
        self.augment = is_train and not is_eval

        if os.path.exists(self.cache_fname):
            self.data_x, self.data_y, self.data_sysid, self.files_meta = self._load_cache(self.cache_fname)
            print(f'Dataset loaded from cache {self.cache_fname}')
        else:
            self.files_meta = self.parse_protocols_file(self.protocols_fname)
            data = list(map(self.read_file, self.files_meta))
            self.data_x, self.data_y, self.data_sysid = map(list, zip(*data))
            if self.transform:
                self.data_x = Parallel(n_jobs=4, prefer='threads')(delayed(self.transform)(x) for x in self.data_x)
            torch.save((self.data_x, self.data_y, self.data_sysid, self.files_meta), self.cache_fname)

        if sample_size:
            select_idx = np.random.choice(len(self.files_meta), size=(sample_size,), replace=True).astype(np.int32)
            self.files_meta = [self.files_meta[x] for x in select_idx]
            self.data_x = [self.data_x[x] for x in select_idx]
            self.data_y = [self.data_y[x] for x in select_idx]
            self.data_sysid = [self.data_sysid[x] for x in select_idx]

        self.length = len(self.data_x)

    @staticmethod
    def _load_cache(cache_fname):
        """Load a cache file previously written by this class with ``torch.save``.

        The cache is a pickle holding ``ASVFile`` tuples (and possibly NumPy
        arrays), which ``torch.load`` rejects when ``weights_only=True``.
        SECURITY: unpickling executes arbitrary code, so only load cache files
        that you created yourself.
        """
        try:
            return torch.load(cache_fname, weights_only=False)
        except TypeError:
            # Fallback for torch versions whose torch.load has no
            # ``weights_only`` keyword.
            return torch.load(cache_fname)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return ``(waveform, label, meta)`` for ``idx``.

        The stored waveform is never modified. When augmentation is enabled, a
        perturbed copy is returned with the same container type (tensor or
        ndarray) and dtype as the stored sample.
        """
        sample = self.data_x[idx]
        label = self.data_y[idx]
        meta = self.files_meta[idx]
        if not self.augment:
            return sample, label, meta

        as_tensor = torch.is_tensor(sample)
        wave = sample.detach().cpu().numpy() if as_tensor else np.asarray(sample)

        # Channel augmentation: average the signal with a 10-sample shifted copy.
        if np.random.rand() > 0.5:
            wave = np.stack([wave, np.roll(wave, 10)]).mean(axis=0)
        # Compression augmentation: additive Gaussian noise (std 0.01).
        # Out-of-place on purpose: an in-place ``+=`` here used to write the
        # noise back into the cached sample, accumulating it across epochs.
        if np.random.rand() > 0.5:
            noise = np.random.randn(*wave.shape) * 0.01
            wave = wave + noise.astype(wave.dtype)

        if as_tensor:
            return torch.from_numpy(np.ascontiguousarray(wave)), label, meta
        return wave, label, meta

    def read_file(self, meta):
        """Read one audio file and return ``(samples, label_as_float, sys_id)``."""
        data_x, _sample_rate = sf.read(meta.path)
        return data_x, float(meta.key), meta.sys_id

    def _parse_line(self, line):
        """Turn one protocol line into an ``ASVFile`` record.

        Fields used: token 0 = speaker id, token 1 = file name, token 3 =
        system id, token 4 = key (``'bonafide'`` maps to 1, anything else to 0).
        """
        # split() with no argument tolerates tabs and repeated spaces.
        tokens = line.split()
        return ASVFile(speaker_id=tokens[0], file_name=tokens[1],
                       path=os.path.join(self.files_dir, tokens[1] + '.flac'),
                       sys_id=self.sysid_dict[tokens[3]], key=int(tokens[4] == 'bonafide'))

    def parse_protocols_file(self, protocols_fname):
        """Parse every non-blank line of ``protocols_fname`` into ``ASVFile`` records."""
        with open(protocols_fname) as protocol_file:
            return [self._parse_line(line) for line in protocol_file if line.strip()]

# ## 4. Training and Evaluation Functions
# From provided training script

def pad(x, max_len=64600):
    """Force a waveform to exactly ``max_len`` samples.

    Longer inputs are truncated; shorter inputs are repeated end-to-end and
    then cut to ``max_len``.
    """
    n_samples = x.shape[0]
    if n_samples < max_len:
        # One more repetition than the integer ratio guarantees enough samples.
        reps = int(max_len / n_samples) + 1
        tiled = np.tile(x, (1, reps))
        return tiled[0, :max_len]
    return x[:max_len]

def evaluate_accuracy(data_loader, model, device):
    """Return the classification accuracy of ``model`` over ``data_loader`` in percent.

    Args:
        data_loader: Iterable yielding ``(batch_x, batch_y, batch_meta)``.
        model: Module called as ``model(batch_x, batch_y)`` and returning
            per-class scores of shape ``(batch, n_classes)``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Accuracy in the range 0-100, or ``0.0`` if the loader yields no samples.

    The model is switched to eval mode and left there; gradients are disabled
    for the duration of the loop so no autograd graph is built.
    """
    num_correct = 0.0
    num_total = 0.0
    model.eval()
    with torch.no_grad():
        for batch_x, batch_y, batch_meta in data_loader:
            num_total += batch_x.size(0)
            batch_x = batch_x.to(device)
            batch_y = batch_y.view(-1).type(torch.int64).to(device)
            batch_out = model(batch_x, batch_y)
            _, batch_pred = batch_out.max(dim=1)
            num_correct += (batch_pred == batch_y).sum(dim=0).item()
    if num_total == 0:
        # Avoid ZeroDivisionError on an empty loader.
        return 0.0
    return 100 * (num_correct / num_total)

def train_epoch(data_loader, model, lr, optim, device):
    """Run one optimisation pass over ``data_loader``.

    Uses class-weighted cross-entropy (weights 1.0 / 9.0 for classes 0 / 1).
    ``lr`` is accepted but not used here; the step size comes from ``optim``.

    Returns:
        ``(mean_loss, accuracy_percent)`` where the mean loss is weighted by
        batch size.
    """
    model.train()
    class_weights = torch.tensor([1.0, 9.0], dtype=torch.float32).to(device)
    criterion = nn.CrossEntropyLoss(weight=class_weights)

    running_loss = 0
    num_correct = 0.0
    num_total = 0.0

    for batch_x, batch_y, batch_meta in data_loader:
        batch_size = batch_x.size(0)
        inputs = batch_x.to(device)
        targets = batch_y.view(-1).type(torch.int64).to(device)

        logits = model(inputs, targets)
        batch_loss = criterion(logits, targets)

        # Book-keeping for the epoch statistics.
        predictions = logits.argmax(dim=1)
        num_total += batch_size
        num_correct += (predictions == targets).sum(dim=0).item()
        running_loss += (batch_loss.item() * batch_size)

        # Parameter update.
        optim.zero_grad()
        batch_loss.backward()
        optim.step()

    running_loss /= num_total
    train_accuracy = (num_correct / num_total) * 100
    return running_loss, train_accuracy

# ## 5. Main Execution
# Model config, data loading, training, and evaluation

# Model configuration (example from RawNet2 literature)
d_args = {
    'filts': [20, [20, 20], [20, 128], [128, 128]],
    'first_conv': 251,
    'in_channels': 1,
    'gru_node': 1024,
    'nb_gru_layer': 1,
    'nb_fc_node': 1024,
    'nb_classes': 2
}

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = RawNet(d_args, device).to(device)

# Data loading (replace paths with your ASVspoof 5 locations)
database_path = '/path/to/ASVspoof5/'
protocols_path = '/path/to/ASVspoof5/protocols/'
# Each waveform is padded/cropped to a fixed length, then converted to a tensor.
transform = transforms.Compose([lambda x: pad(x), lambda x: Tensor(x)])
train_set = ASVDataset(database_path=database_path, protocols_path=protocols_path,
                       transform=transform, is_train=True, sample_size=1000,
                       is_logical=True, feature_name='Raw_audio')
train_loader = DataLoader(train_set, batch_size=16, shuffle=True)
dev_set = ASVDataset(database_path=database_path, protocols_path=protocols_path,
                     transform=transform, is_train=False, sample_size=1000,
                     is_logical=True, feature_name='Raw_audio')
dev_loader = DataLoader(dev_set, batch_size=16, shuffle=False)

# Training
learning_rate = 0.0001  # single source of truth for the optimizer and train_epoch
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0.0001)
num_epochs = 5  # Short run; note that no pretrained checkpoint is loaded above
writer = SummaryWriter('logs/rawnet2_momenta')
losses = []
try:
    for epoch in range(num_epochs):
        running_loss, train_accuracy = train_epoch(train_loader, model, learning_rate, optimizer, device)
        valid_accuracy = evaluate_accuracy(dev_loader, model, device)
        losses.append(running_loss)
        writer.add_scalar('train_accuracy', train_accuracy, epoch)
        writer.add_scalar('valid_accuracy', valid_accuracy, epoch)
        writer.add_scalar('loss', running_loss, epoch)
        print(f'Epoch {epoch+1}/{num_epochs} - Loss: {running_loss:.4f} - Train Acc: {train_accuracy:.2f}% - Dev Acc: {valid_accuracy:.2f}%')
finally:
    # Flush and release the event file even if training is interrupted.
    writer.close()

plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss Curve')
plt.show()

# Save model
# makedirs(exist_ok=True) avoids the check-then-create race of exists()+mkdir().
os.makedirs('models', exist_ok=True)
# The file name follows num_epochs so it stays accurate if that value changes.
torch.save(model.state_dict(), f'models/rawnet2_epoch_{num_epochs}.pth')
