In [None]:
import torch
import torch.nn as nn

class RawNet2(nn.Module):
    """
    A lightweight RawNet2-style model for binary classification (bonafide/spoof).

    This is a simplified CNN + GRU stack, not a full reproduction of the
    published RawNet2 architecture. The pipeline is:

    1. Two ``Conv1d -> ReLU -> MaxPool1d(2)`` stages over the raw input. Each
       convolution keeps the length (kernel 3, padding 1) and each pool halves
       it (floor).
    2. A single-layer GRU over the resulting time axis.
    3. A linear layer mapping the GRU's final hidden state to class logits.

    Which logit index means "bonafide" vs "spoof" is decided by the training
    code, not here. NOTE(review): confirm against the label encoding used by
    the caller.

    Args:
        in_channels: Number of input channels (1 for a mono waveform).
        conv1_channels: Output channels of the first convolution.
        conv2_channels: Output channels of the second convolution; this is
            also the GRU input size.
        hidden_size: GRU hidden-state size.
        num_classes: Number of output logits.

    The defaults reproduce the original fixed configuration. Existing
    ``RawNet2()`` calls and saved ``state_dict`` keys/shapes are unchanged.
    """

    def __init__(
        self,
        in_channels: int = 1,
        conv1_channels: int = 16,
        conv2_channels: int = 32,
        hidden_size: int = 64,
        num_classes: int = 2,
    ) -> None:
        super().__init__()
        # Layer order is kept identical to the original so that state_dict
        # keys (cnn.0.*, cnn.3.*) stay compatible with existing checkpoints.
        self.cnn = nn.Sequential(
            nn.Conv1d(in_channels, conv1_channels, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool1d(2),
            nn.Conv1d(conv1_channels, conv2_channels, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool1d(2),
        )
        self.gru = nn.GRU(input_size=conv2_channels, hidden_size=hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute class logits for a batch of raw inputs.

        Args:
            x: Tensor of shape ``(batch, in_channels, samples)``. When
                ``in_channels == 1``, a 2-D tensor of shape
                ``(batch, samples)`` is also accepted and treated as
                single-channel input. ``samples`` must be at least 4 so that
                both pooling stages leave a non-empty time axis.

        Returns:
            Logits of shape ``(batch, num_classes)``.
        """
        if x.dim() == 2:
            # (batch, samples) -> (batch, 1, samples). Without this, Conv1d
            # would interpret a 2-D tensor as one unbatched example and the
            # permute below would fail.
            x = x.unsqueeze(1)
        x = self.cnn(x)          # (batch, conv2_channels, samples // 4)
        x = x.permute(0, 2, 1)   # (batch, time, features) for the batch_first GRU
        _, h = self.gru(x)       # h: (num_layers, batch, hidden_size)
        # h[-1] is the last layer's hidden state after the final time step.
        return self.fc(h[-1])
