# In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from utils.dataset import ASVspoofDataset
from models.rawnet2 import RawNet2

# Input locations (update as needed): the audio directory and the protocol
# file that are handed to ASVspoofDataset below.
AUDIO_DIR = 'data/LA/ASVspoof2019_LA_train/flac'
PROTOCOL_FILE = 'data/LA/ASVspoof2019_LA_cm_protocols/ASVspoof2019.LA.cm.train.trn.txt'

# Dataset and Dataloader: serve the dataset in shuffled mini-batches of 16.
dataset = ASVspoofDataset(PROTOCOL_FILE, AUDIO_DIR)
dataloader = DataLoader(
    dataset=dataset,
    shuffle=True,
    batch_size=16,
)

# Model, Loss, Optimizer
# Run on the GPU when torch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

model = RawNet2()
model.to(device)

# Cross-entropy objective, optimized with Adam at a learning rate of 1e-3.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

# Training loop
# Runs EPOCHS passes over `dataloader`, taking one optimizer step per batch and
# printing the running training accuracy at the end of every epoch.
EPOCHS = 5
for epoch in range(EPOCHS):
    model.train()
    total = 0
    # The hit counter lives on `device` so that counting correct predictions
    # does not force a host/device synchronization on every batch; it is read
    # back once per epoch instead.
    correct = torch.zeros((), dtype=torch.long, device=device)
    for x, y in dataloader:
        # presumably x is a batch of inputs and y the integer class labels
        # produced by ASVspoofDataset -- confirm against the dataset class
        x, y = x.to(device), y.to(device)
        output = model(x)
        loss = loss_fn(output, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total += y.size(0)
        # Predicted class = index of the largest score along dim 1.
        correct += (output.argmax(1) == y).sum()

    if total == 0:
        # Without this guard the accuracy computation below fails with a bare
        # ZeroDivisionError that gives no hint about the real problem.
        raise RuntimeError(
            "dataloader yielded no samples; check PROTOCOL_FILE and AUDIO_DIR"
        )

    acc = 100 * correct.item() / total
    print(f"Epoch {epoch + 1}/{EPOCHS} | Accuracy: {acc:.2f}%")
