### MODELS DEFINED

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# =========================
# -------- RawTFNet --------
# =========================

class TFAttention(nn.Module):
    """Channel gate that rescales every channel by a learned weight in (0, 1).

    The input is averaged over its last axis, pushed through a two-layer
    bottleneck (``channels -> channels // 4 -> channels``) and squashed with
    a sigmoid; the resulting per-channel weights multiply the input.
    """

    def __init__(self, channels):
        super().__init__()
        bottleneck = channels // 4
        self.fc1 = nn.Linear(channels, bottleneck)
        self.fc2 = nn.Linear(bottleneck, channels)

    def forward(self, x):
        """Return ``x`` (indexed as batch, channels, time) scaled per channel."""
        pooled = x.mean(dim=2)                    # (B, C)
        hidden = F.relu(self.fc1(pooled))
        gate = torch.sigmoid(self.fc2(hidden))    # (B, C)
        return x * gate.unsqueeze(-1)             # broadcast over time


class TFResidualBlock(nn.Module):
    """Two-convolution residual block with channel attention and 3x max-pooling."""

    def __init__(self, in_ch, out_ch):
        super().__init__()

        # Main branch: conv -> BN -> act -> conv -> BN -> channel attention.
        self.conv1 = nn.Conv1d(in_ch, out_ch, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm1d(out_ch)

        self.conv2 = nn.Conv1d(out_ch, out_ch, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm1d(out_ch)

        self.att = TFAttention(out_ch)
        self.act = nn.LeakyReLU(0.3)

        # Shortcut branch: a 1x1 convolution is only needed when the channel
        # count changes; otherwise the input is passed through untouched.
        if in_ch != out_ch:
            self.skip = nn.Conv1d(in_ch, out_ch, kernel_size=1)
        else:
            self.skip = nn.Identity()
        self.pool = nn.MaxPool1d(3)

    def forward(self, x):
        """Apply the block to ``x``; the last axis shrinks by the pooling factor 3."""
        shortcut = self.skip(x)
        out = self.act(self.bn1(self.conv1(x)))
        out = self.att(self.bn2(self.conv2(out)))
        return self.pool(self.act(out + shortcut))


class RawTFNet(nn.Module):
    """Raw-waveform network: conv front end, three TF residual blocks, GRU head.

    ``forward`` receives a ``(batch, samples)`` tensor (a channel axis is
    inserted internally) and returns the ``(batch, 2)`` output of the final
    linear layer.
    """

    def __init__(self, base_channels=32):
        super().__init__()
        c1 = base_channels
        c2, c4, c8 = base_channels * 2, base_channels * 4, base_channels * 8

        self.frontend = nn.Sequential(
            nn.Conv1d(1, c1, kernel_size=251, padding=125),
            nn.BatchNorm1d(c1),
            nn.LeakyReLU(0.3),
            nn.MaxPool1d(3),
        )

        # Each block doubles the channel width and pools time by 3.
        self.block1 = TFResidualBlock(c1, c2)
        self.block2 = TFResidualBlock(c2, c4)
        self.block3 = TFResidualBlock(c4, c8)

        self.gru = nn.GRU(c8, 512, batch_first=True)
        self.fc1 = nn.Linear(512, 512)
        self.fc2 = nn.Linear(512, 2)

    def forward(self, x):
        feats = self.frontend(x.unsqueeze(1))          # (B, C, T')
        for block in (self.block1, self.block2, self.block3):
            feats = block(feats)

        # The GRU is batch-first, so time must be the middle axis.
        seq, _ = self.gru(feats.permute(0, 2, 1))

        # Only the last time step is kept; it is L2-normalised after fc1.
        embedding = F.normalize(self.fc1(seq[:, -1]), dim=1)
        return self.fc2(embedding)


# =========================
# -------- RawNet2 --------
# =========================
class FixedSincConv(nn.Module):
    """Fixed (non-trainable) band-pass sinc filterbank applied as a 1-D convolution.

    The band edges are spaced linearly between 0 Hz and ``sample_rate // 2``.
    Each kernel is the difference of two scaled sinc functions,
    ``sin(2*pi*f*n) / n``, multiplied by a Hamming window.

    The filterbank is constant, so it is built once at construction time and
    kept in a non-persistent buffer (the ``state_dict`` layout is therefore
    unchanged: it still only contains ``n`` and ``window``).

    Args:
        out_channels: Number of band-pass filters (output channels).
        kernel_size: Filter length in samples; even values are bumped to the
            next odd number.
        sample_rate: Sampling rate in Hz used to normalise the band edges.
    """

    def __init__(self, out_channels=128, kernel_size=129, sample_rate=16000):
        super().__init__()

        # An odd length gives the kernel a well-defined centre tap (n = 0).
        if kernel_size % 2 == 0:
            kernel_size += 1

        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.sample_rate = sample_rate

        # Linear-scale fixed filterbank (S3 setup in paper)
        low_hz = 0
        high_hz = sample_rate // 2
        hz = torch.linspace(low_hz, high_hz, out_channels + 1)

        self.low = hz[:-1]
        self.high = hz[1:]

        n = torch.arange(-(kernel_size // 2), kernel_size // 2 + 1)
        self.register_buffer("n", n)
        # NOTE(review): torch.hamming_window defaults to periodic=True, so the
        # window is not exactly symmetric about the centre tap. Left as-is.
        self.register_buffer("window", torch.hamming_window(kernel_size))

        # Built once; persistent=False keeps existing checkpoints loadable.
        self.register_buffer("filters", self._build_filters(), persistent=False)

    def _build_filters(self):
        """Return the ``(out_channels, 1, kernel_size)`` filter tensor."""
        n = self.n.to(torch.float32).unsqueeze(0)             # (1, K)
        f1 = (self.low / self.sample_rate).unsqueeze(1)       # (C, 1)
        f2 = (self.high / self.sample_rate).unsqueeze(1)      # (C, 1)

        centre = n == 0
        # Placeholder divisor at n == 0; that tap is overwritten below.
        safe_n = torch.where(centre, torch.ones_like(n), n)

        def scaled_sinc(f):
            # sin(2*pi*f*n) / n tends to 2*pi*f as n -> 0. Dividing by
            # (n + eps) instead would silently zero the centre tap.
            omega = 2 * math.pi * f
            return torch.where(centre, omega.expand(-1, n.shape[1]),
                               torch.sin(omega * n) / safe_n)

        band = (scaled_sinc(f2) - scaled_sinc(f1)) * self.window
        return band.view(self.out_channels, 1, self.kernel_size)

    def forward(self, x):
        """Filter ``x`` of shape ``(B, 1, T)``; returns ``(B, out_channels, T)``."""
        filters = self.filters.to(device=x.device, dtype=x.dtype)
        return F.conv1d(x, filters, stride=1, padding=self.kernel_size // 2)

class FMS(nn.Module):
    """Feature-map scaling: a per-channel sigmoid gate applied as ``x * s + s``."""

    def __init__(self, channels):
        super().__init__()
        self.fc = nn.Linear(channels, channels)

    def forward(self, x):
        """Scale and shift ``x`` (batch, channels, time) by a learned channel gate."""
        channel_mean = x.mean(dim=2)                                  # (B, C)
        scale = torch.sigmoid(self.fc(channel_mean)).unsqueeze(2)     # (B, C, 1)
        # The same gate is used multiplicatively and additively.
        return x * scale + scale


class RawNetResidualBlock(nn.Module):
    """Constant-width residual block: two conv/BN/LeakyReLU stages, 3x pool, FMS."""

    def __init__(self, channels):
        super().__init__()

        self.conv1 = nn.Conv1d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm1d(channels)
        self.conv2 = nn.Conv1d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm1d(channels)

        self.fms = FMS(channels)
        self.pool = nn.MaxPool1d(3)

    def forward(self, x):
        """Return the block output; the time axis shrinks by the pooling factor 3."""
        out = x
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            out = F.leaky_relu(bn(conv(out)), 0.3)

        # Identity shortcut, then downsample, then feature-map scaling.
        pooled = self.pool(out + x)
        return self.fms(pooled)


class RawNet2_AntiSpoof(nn.Module):
    """RawNet2-style classifier: fixed sinc front end, residual stacks, GRU head.

    ``forward`` takes a ``(batch, 1, samples)`` tensor (the sinc filters have a
    single input channel) and returns ``(batch, num_classes)`` outputs.
    """

    def __init__(self, num_classes=2):
        super().__init__()

        self.sinc = FixedSincConv(out_channels=128, kernel_size=129)
        self.bn0 = nn.BatchNorm1d(128)

        # Two narrow residual blocks ...
        self.block1 = nn.Sequential(*[RawNetResidualBlock(128) for _ in range(2)])

        # ... a 1x1 convolution widening 128 -> 512 channels ...
        self.conv_expand = nn.Conv1d(128, 512, kernel_size=1)

        # ... then four wide residual blocks.
        self.block2 = nn.Sequential(*[RawNetResidualBlock(512) for _ in range(4)])

        self.gru = nn.GRU(512, 1024, batch_first=True)
        self.fc1 = nn.Linear(1024, 1024)
        self.fc2 = nn.Linear(1024, num_classes)

    def forward(self, x):
        feats = F.leaky_relu(self.bn0(self.sinc(x)), 0.3)
        feats = self.block2(self.conv_expand(self.block1(feats)))

        # Batch-first GRU wants (B, T, C).
        seq = feats.permute(0, 2, 1)

        self.gru.flatten_parameters()
        seq, _ = self.gru(seq)

        # Classify from the last time step only.
        last_step = seq[:, -1, :]
        hidden = F.leaky_relu(self.fc1(last_step), 0.3)
        return self.fc2(hidden)
    
    
# =========================
# -------- AASIST --------
# =========================
class AASIST(nn.Module):
    """Conv front end, three conv + self-attention ``STBlock`` stages, GRU head.

    ``forward`` receives a ``(batch, samples)`` tensor and returns the
    ``(batch, 2)`` output of the final linear layer.
    """

    def __init__(self):
        super().__init__()

        frontend_layers = [
            nn.Conv1d(1, 64, kernel_size=251, padding=125),
            nn.BatchNorm1d(64),
            nn.LeakyReLU(0.3),
            nn.MaxPool1d(3),
        ]
        self.frontend = nn.Sequential(*frontend_layers)

        # Channel width doubles at every stage: 64 -> 128 -> 256 -> 512.
        self.block1 = STBlock(64, 128)
        self.block2 = STBlock(128, 256)
        self.block3 = STBlock(256, 512)

        self.gru = nn.GRU(512, 512, batch_first=True)
        self.fc1 = nn.Linear(512, 512)
        self.fc2 = nn.Linear(512, 2)

    def forward(self, x):
        feats = self.frontend(x.unsqueeze(1))       # add the channel axis first
        for stage in (self.block1, self.block2, self.block3):
            feats = stage(feats)

        # (B, C, T') -> (B, T', C) for the batch-first GRU.
        seq, _ = self.gru(feats.permute(0, 2, 1))

        # Last time step -> fc1 -> L2 normalisation -> fc2.
        embedding = F.normalize(self.fc1(seq[:, -1]), dim=1)
        return self.fc2(embedding)

class STBlock(nn.Module):
    """Conv stage with 4x temporal pooling followed by self-attention over time."""

    def __init__(self, in_ch, out_ch):
        super().__init__()

        conv_layers = [
            nn.Conv1d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm1d(out_ch),
            nn.LeakyReLU(0.3),
            # Extra downsampling BEFORE attention, so the attention layer
            # sees a sequence four times shorter.
            nn.MaxPool1d(4),
        ]
        self.conv = nn.Sequential(*conv_layers)

        self.attn = TemporalSelfAttention(out_ch, heads=2)

    def forward(self, x):
        feats = self.conv(x)                  # (B, C, T // 4)
        tokens = feats.transpose(1, 2)        # (B, T // 4, C)
        attended = self.attn(tokens)
        return attended.transpose(1, 2)       # back to (B, C, T // 4)


class TemporalSelfAttention(nn.Module):
    """Multi-head self-attention over the time axis of a ``(B, T, C)`` tensor.

    Args:
        dim: Embedding size ``C``; must be divisible by ``heads``.
        heads: Number of attention heads.
    """

    def __init__(self, dim, heads=4):
        super().__init__()
        self.attn = nn.MultiheadAttention(
            embed_dim=dim,
            num_heads=heads,
            batch_first=True
        )

    def forward(self, x):
        """Return the attended sequence, same shape as ``x``."""
        # x: (B, T, C)
        # The attention weights are never used, so do not ask for them:
        # need_weights=False skips building/averaging the (B, T, T) weight
        # matrix and lets PyTorch use its optimised attention kernel.
        out, _ = self.attn(x, x, x, need_weights=False)
        return out


# =========================
# -------- RawFormerL --------
# =========================
class RawFormerL(nn.Module):
    """Strided conv front end, 4-layer Transformer encoder, mean+std pooling head.

    ``forward`` receives a ``(batch, samples)`` tensor and returns the
    ``(batch, 2)`` output of the final linear layer.
    """

    def __init__(self):
        super().__init__()

        def conv_stage(in_ch, out_ch, kernel, stride):
            # conv ("same"-style padding of kernel // 2) -> BN -> LeakyReLU -> pool/3
            return [
                nn.Conv1d(in_ch, out_ch, kernel_size=kernel, stride=stride,
                          padding=kernel // 2),
                nn.BatchNorm1d(out_ch),
                nn.LeakyReLU(0.3),
                nn.MaxPool1d(3),
            ]

        self.frontend = nn.Sequential(
            *conv_stage(1, 128, kernel=251, stride=5),
            *conv_stage(128, 256, kernel=5, stride=2),
        )

        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=256,
                nhead=8,
                dim_feedforward=1024,
                dropout=0.1,
                batch_first=True,
            ),
            num_layers=4,
        )

        # Input width is 512 because mean and std (256 each) are concatenated.
        self.fc = nn.Sequential(
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256, 2),
        )

    def forward(self, x):
        tokens = self.frontend(x.unsqueeze(1))     # (B, 1, T) -> (B, 256, T')
        tokens = tokens.permute(0, 2, 1)           # (B, T', 256)

        encoded = self.transformer(tokens)

        # Statistics pooling over time: per-feature mean and std, concatenated.
        pooled = torch.cat((encoded.mean(dim=1), encoded.std(dim=1)), dim=1)  # (B, 512)

        return self.fc(pooled)


In [11]:
import torch
import numpy as np
import librosa

# Inference device: GPU when CUDA is available, CPU otherwise.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Audio / compressed-sensing settings.
SAMPLE_RATE = 16000                      # Hz, requested from librosa
SEGMENT_LEN = 64000                      # samples kept per clip (4 s at 16 kHz)
FRAME_SIZE = 160                         # samples per compressed-sensing frame
COMPRESSION_RATIOS = [0.5, 0.25, 0.75]   # M / N ratios to try

audio_path = "Test_Audio.flac"

# The sample rate returned by librosa is discarded.
audio, _ = librosa.load(audio_path, sr=SAMPLE_RATE)

# Force the clip to exactly SEGMENT_LEN samples: zero-pad the tail of a short
# clip, truncate a long one.
audio = (
    np.pad(audio, (0, SEGMENT_LEN - len(audio)))
    if len(audio) < SEGMENT_LEN
    else audio[:SEGMENT_LEN]
)

print("===== ORIGINAL AUDIO =====")
print("Original length:", len(audio))

# -----------------------------
# TRUE FRAME-WISE CS
# -----------------------------

def generate_cs_matrix(N, compression_ratio):
    """Draw a random Gaussian measurement matrix for compressed sensing.

    Args:
        N: Length of each signal frame (number of columns).
        compression_ratio: Fraction of ``N`` that is kept; the number of rows
            is ``M = int(N * compression_ratio)``.

    Returns:
        An ``(M, N)`` array of standard-normal samples scaled by
        ``1 / sqrt(M)``. Samples come from NumPy's global RNG, so seed it with
        ``np.random.seed`` beforehand when a reproducible matrix is needed.
    """
    M = int(N * compression_ratio)
    Phi = np.random.randn(M, N)
    Phi = Phi / np.sqrt(M)
    return Phi


def apply_cs(audio, frame_size, compression_ratio):
    """Compress ``audio`` frame by frame with one shared random matrix.

    The signal is cut into consecutive, non-overlapping frames of
    ``frame_size`` samples; trailing samples that do not fill a whole frame
    are dropped. Every frame is projected with the same ``(M, frame_size)``
    matrix from :func:`generate_cs_matrix`, and the projections are
    concatenated in frame order.

    Args:
        audio: 1-D sequence of samples.
        frame_size: Number of samples per frame (``N``).
        compression_ratio: Fraction of each frame that is kept (``M / N``).

    Returns:
        1-D array of length ``n_frames * M``. It is empty when ``audio`` is
        shorter than one frame (previously this raised ``ValueError`` from
        ``np.concatenate`` on an empty list).
    """
    Phi = generate_cs_matrix(frame_size, compression_ratio)

    samples = np.asarray(audio)
    n_frames = len(samples) // frame_size

    # One row per frame; a single matrix product replaces the per-frame loop.
    frames = samples[: n_frames * frame_size].reshape(n_frames, frame_size)
    return (frames @ Phi.T).ravel()


# Prepare original tensor (B, T)
original_tensor = torch.tensor(audio, dtype=torch.float32).unsqueeze(0).to(DEVICE)

print("Original tensor shape:", original_tensor.shape)

# -----------------------------
# LOAD MODELS
# -----------------------------

models = {
    "AASIST": AASIST().to(DEVICE),
    "RawNet2": RawNet2_AntiSpoof().to(DEVICE),
    "RawTFNet": RawTFNet(base_channels=32).to(DEVICE),
    "RawFormerL": RawFormerL().to(DEVICE),
}

# Inference mode for BatchNorm/Dropout; set once instead of in every loop.
for model in models.values():
    model.eval()

# -----------------------------
# TEST ORIGINAL INPUT
# -----------------------------

print("\n====== TESTING ORIGINAL INPUT ======")

# These are shape checks only, so gradients are never needed. Without
# no_grad() every forward pass would also record an autograd graph.
with torch.no_grad():
    for name, model in models.items():
        print(f"\nModel: {name}")

        if name == "RawNet2":
            inp = original_tensor.unsqueeze(1)  # RawNet2 expects (B,1,T)
        else:
            inp = original_tensor  # others expect (B,T)

        out = model(inp)
        print("Input shape:", inp.shape)
        print("Output shape:", out.shape)

# -----------------------------
# TEST COMPRESSED INPUTS
# -----------------------------

for ratio in COMPRESSION_RATIOS:

    print(f"\n\n====== TESTING CS RATIO = {ratio} ======")

    # A fresh random measurement matrix is drawn for every ratio.
    compressed_audio = apply_cs(audio, FRAME_SIZE, ratio)
    compressed_tensor = torch.tensor(
        compressed_audio, dtype=torch.float32
    ).unsqueeze(0).to(DEVICE)

    print("Compressed length:", len(compressed_audio))
    print("Compressed tensor shape:", compressed_tensor.shape)
    print("Reduction %:", 100 * (1 - len(compressed_audio)/len(audio)))

    with torch.no_grad():
        for name, model in models.items():
            print(f"\nModel: {name}")

            if name == "RawNet2":
                inp = compressed_tensor.unsqueeze(1)  # RawNet2 expects (B,1,T)
            else:
                inp = compressed_tensor

            out = model(inp)
            print("Input shape:", inp.shape)
            print("Output shape:", out.shape)

    print("-" * 60)


===== ORIGINAL AUDIO =====
Original length: 64000
Original tensor shape: torch.Size([1, 64000])


Model: AASIST
Input shape: torch.Size([1, 64000])
Output shape: torch.Size([1, 2])

Model: RawNet2
Input shape: torch.Size([1, 1, 64000])
Output shape: torch.Size([1, 2])

Model: RawTFNet
Input shape: torch.Size([1, 64000])
Output shape: torch.Size([1, 2])

Model: RawFormerL
Input shape: torch.Size([1, 64000])
Output shape: torch.Size([1, 2])


Compressed length: 32000
Compressed tensor shape: torch.Size([1, 32000])
Reduction %: 50.0

Model: AASIST
Input shape: torch.Size([1, 32000])
Output shape: torch.Size([1, 2])

Model: RawNet2
Input shape: torch.Size([1, 1, 32000])
Output shape: torch.Size([1, 2])

Model: RawTFNet
Input shape: torch.Size([1, 32000])
Output shape: torch.Size([1, 2])

Model: RawFormerL
Input shape: torch.Size([1, 32000])
Output shape: torch.Size([1, 2])
------------------------------------------------------------


Compressed length: 16000
Compressed tensor shape: torch