In [1]:
import os
import torch
import torchaudio
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchaudio.transforms import MelSpectrogram, AmplitudeToDB
from tqdm.notebook import tqdm

In [2]:
class ASVspoofDataset(Dataset):
    def __init__(self, data_dir, protocol_file, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.data = []

        with open(protocol_file, 'r') as f:
            for line in f:
                parts = line.strip().split()
                filename = parts[1]
                label = 1 if parts[-1] == 'spoof' else 0
                self.data.append((filename, label))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        filename, label = self.data[idx]
        file_path = os.path.join(self.data_dir, filename + '.flac')

        try:
            waveform, sr = torchaudio.load(file_path)
        except:
            # Corrupt or unreadable file, return dummy tensor
            return torch.zeros(1, 128, 128), label

        if self.transform:
            features = self.transform(waveform)
        else:
            features = waveform

        return features, label

In [3]:
mel_transform = nn.Sequential(
    MelSpectrogram(sample_rate=16000, n_fft=400, hop_length=160, n_mels=128),
    AmplitudeToDB()
)




In [4]:
def collate_fn(batch):
    specs, labels = zip(*batch)
    # Resize or pad all to 128x128
    processed = []
    for spec in specs:
        if spec.shape[-1] < 128:
            pad_amt = 128 - spec.shape[-1]
            spec = F.pad(spec, (0, pad_amt))
        elif spec.shape[-1] > 128:
            spec = spec[:, :, :128]
        processed.append(spec)
    specs_tensor = torch.stack(processed)
    labels_tensor = torch.tensor(labels)
    return specs_tensor, labels_tensor


In [5]:
class ConvAttentionModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)

        self.attn = nn.Sequential(
            nn.Conv2d(64, 32, 1),
            nn.ReLU(),
            nn.Conv2d(32, 1, 1),
            nn.Softmax(dim=-1)
        )

        self.fc = nn.Sequential(
            nn.Linear(64 * 64 * 64, 128),
            nn.ReLU(),
            nn.Linear(128, 2)
        )

    def forward(self, x):
        x = F.relu(self.conv1(x))  # (B, 32, 128, 128)
        x = self.pool(F.relu(self.conv2(x)))  # (B, 64, 64, 64)

        attn_weights = self.attn(x)  # (B, 1, 64, 64)
        x = x * attn_weights  # Apply attention
        x = x.view(x.size(0), -1)  # Flatten
        return self.fc(x)

In [6]:
def train_model(model, train_loader, val_loader, epochs=10, lr=0.001):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        model.train()
        total_loss, correct = 0.0, 0
        for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            correct += (outputs.argmax(dim=1) == labels).sum().item()

        acc = correct / len(train_loader.dataset)
        print(f"Train Loss: {total_loss:.4f}, Accuracy: {acc:.4f}")

        evaluate_model(model, val_loader)

In [7]:
def evaluate_model(model, val_loader):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    correct = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            correct += (outputs.argmax(dim=1) == labels).sum().item()
    acc = correct / len(val_loader.dataset)
    print(f"Validation Accuracy: {acc:.4f}")

In [None]:
from torch.utils.data import Subset

data_root = "./PA/ASVspoof2019_PA_train/flac"
protocol_train = "./PA/ASVspoof2019_PA_cm_protocols/ASVspoof2019.PA.cm.train.trn.txt"
protocol_dev = "./PA/ASVspoof2019_PA_cm_protocols/ASVspoof2019.PA.cm.dev.trl.txt"
data_dev_root = "./PA/ASVspoof2019_PA_dev/flac"

train_dataset = ASVspoofDataset(data_root, protocol_train, transform=mel_transform)
dev_dataset = ASVspoofDataset(data_dev_root, protocol_dev, transform=mel_transform)


#i am using a subset because of resource constraints, if you wish to train on the entire dataset, remove the train_subset and dev_subset, and uncomment the train_loader and val_loader that are commented, remove the other ones.
train_subset = Subset(train_dataset, indices=range(5000))
dev_subset = Subset(dev_dataset, indices=range(1000))


# train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=collate_fn)
# val_loader = DataLoader(dev_dataset, batch_size=16, shuffle=False, collate_fn=collate_fn)

train_loader = DataLoader(train_subset, batch_size=16, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(dev_subset, batch_size=16, shuffle=False, collate_fn=collate_fn)

In [14]:
model = ConvAttentionModel()
train_model(model, train_loader, val_loader, epochs=3, lr=0.001)


Epoch 1:   0%|          | 0/313 [00:00<?, ?it/s]

Train Loss: 1.0432, Accuracy: 0.9994
Validation Accuracy: 1.0000


Epoch 2:   0%|          | 0/313 [00:00<?, ?it/s]

Train Loss: 0.0000, Accuracy: 1.0000
Validation Accuracy: 1.0000


Epoch 3:   0%|          | 0/313 [00:00<?, ?it/s]

Train Loss: 0.0000, Accuracy: 1.0000
Validation Accuracy: 1.0000
