In [None]:
import os
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import torchaudio
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torch.optim import AdamW
from transformers import ASTFeatureExtractor, ASTForAudioClassification
from sklearn.metrics import roc_curve, roc_auc_score, f1_score

# Select the compute device once; later cells move the model and batches to `device`.
cuda_ok = torch.cuda.is_available()
device = torch.device("cuda") if cuda_ok else torch.device("cpu")
print("CUDA available:", cuda_ok)
gpu_name = torch.cuda.get_device_name(0) if cuda_ok else "No GPU"
print("GPU name:", gpu_name)


In [None]:
# Label string -> integer class id. Class 1 is "bonafide", class 0 is "spoof".
LABEL_MAP = dict(bonafide=1, spoof=0)

# Sample rate in Hz. NOTE: the dataset and collate function in this cell pass the
# literal 16000 themselves rather than reading this variable.
sampling_rate = 16000

# Feature extractor applied to raw waveforms inside the collate function.
feature_extractor = ASTFeatureExtractor.from_pretrained("...")


class ASTDataset(Dataset):
    """Dataset over a space-separated protocol file plus a directory of FLAC audio.

    Protocol rows are read positionally (the file has no header): column 1 is the
    utterance id, whose audio lives at ``<audio_dir>/<id>.flac``, and column 4 is
    the label string, mapped to an integer through the module-level ``LABEL_MAP``.

    Args:
        csv_path: Path (or file-like object) of the protocol file.
        audio_dir: Directory that contains the ``.flac`` files.
        nrows: Optional cap on the number of protocol rows to read (``None`` = all).
        target_sr: Sample rate in Hz every waveform is resampled to.

    Raises:
        KeyError: At construction time, if a label string is not in ``LABEL_MAP``.
    """

    def __init__(self, csv_path, audio_dir, nrows=None, target_sr=16000):
        self.df = pd.read_csv(csv_path, sep=" ", header=None, nrows=nrows)
        self.audio_dir = audio_dir
        self.target_sr = target_sr
        # Precompute per-row file names and labels once so that __getitem__ does
        # not pay for a DataFrame row lookup on every sample.
        self.filenames = [f"{utt_id}.flac" for utt_id in self.df[1]]
        self.labels = [LABEL_MAP[label] for label in self.df[4]]
        # One Resample transform per source rate, built lazily on first use.
        self._resamplers = {}

    def __len__(self):
        """Number of protocol rows (= number of utterances)."""
        return len(self.df)

    def get_labels(self):
        """Return the list of integer labels, in dataset order."""
        return self.labels

    def __getitem__(self, idx):
        """Load one utterance.

        Returns:
            dict with ``"waveform"`` (1-D numpy array: mono, at ``target_sr``) and
            ``"labels"`` (0-dim ``torch.long`` tensor).
        """
        flac_path = os.path.join(self.audio_dir, self.filenames[idx])
        waveform, sr = torchaudio.load(flac_path)

        # Collapse the channel dimension to get a mono signal.
        waveform = waveform.mean(dim=0)

        # Resample only when the file's rate differs from the target rate.
        if sr != self.target_sr:
            resampler = self._resamplers.get(sr)
            if resampler is None:
                resampler = torchaudio.transforms.Resample(
                    orig_freq=sr, new_freq=self.target_sr
                )
                self._resamplers[sr] = resampler
            waveform = resampler(waveform)

        return {
            "waveform": waveform.numpy(),  # numpy so the collate function can feed the extractor
            "labels": torch.tensor(self.labels[idx], dtype=torch.long),
        }

def ast_collate_fn(batch):
    """Collate dataset items into a single batch for the model.

    Args:
        batch: Sequence of dicts, each with a ``"waveform"`` entry and a
            ``"labels"`` tensor.

    Returns:
        dict with ``"input_values"`` (taken from the module-level
        ``feature_extractor`` output) and ``"labels"`` (the per-item label
        tensors stacked along a new first dimension).
    """
    waveforms, labels = [], []
    for item in batch:
        waveforms.append(item["waveform"])
        labels.append(item["labels"])

    # The extractor receives the whole list of raw waveforms at once.
    features = feature_extractor(
        waveforms, sampling_rate=16000, padding=True, return_tensors="pt"
    )

    return dict(input_values=features["input_values"], labels=torch.stack(labels))


In [None]:

from torch.utils.data import WeightedRandomSampler

batch_size = 8

train_dataset = ASTDataset(
    csv_path="/root/autodl-fs/ASVspoof2019.LA.cm.train.trn.txt",
    audio_dir="/root/autodl-fs/ASVspoof2019_LA_train/flac",
)

# Integer label of every training row, in dataset order.
labels = train_dataset.get_labels()

# Class counts are derived from the loaded labels (they used to be hard-coded as
# 22800 / 2580), so the weights stay correct if the protocol file or `nrows` changes.
count_0 = labels.count(0)
count_1 = labels.count(1)
total = count_0 + count_1

# Inverse-frequency weight per class; a class with no samples gets no entry
# (and is never looked up, since no sample carries that label).
class_weights = {
    cls: total / count
    for cls, count in ((0, count_0), (1, count_1))
    if count > 0
}

# One sampling weight per training row.
sample_weights = [class_weights[label] for label in labels]

# Draw with replacement. An epoch is twice the size of the larger class, which
# equals the previous hard-coded 45600 for counts of 22800 / 2580.
sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=2 * max(count_0, count_1),
    replacement=True,
)

# `sampler` and `shuffle` are mutually exclusive in DataLoader, so no shuffle here.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    sampler=sampler,
    collate_fn=ast_collate_fn,
)

# Sanity check: show the shape and labels of the first batch only.
for batch in train_dataloader:
    print(batch["input_values"].shape)
    print(batch["labels"])
    break



In [None]:
dev_dataset = ASTDataset(
    csv_path="/root/autodl-fs/ASVspoof2019.LA.cm.dev.trl.txt",
    audio_dir="/root/autodl-fs/ASVspoof2019_LA_dev/flac",
)

# Evaluation data is iterated in protocol order (no shuffling): the metrics computed
# later do not depend on order, and a fixed order keeps runs reproducible and lets
# per-sample scores be aligned with the protocol rows.
dev_dataloader = DataLoader(dev_dataset, batch_size=8, shuffle=False, collate_fn=ast_collate_fn)

# Sanity check: show the shape and labels of the first batch only.
for batch in dev_dataloader:
    print(batch["input_values"].shape)
    print(batch["labels"])
    break


In [None]:
eval_dataset = ASTDataset(
    csv_path="/root/autodl-fs/ASVspoof2019.LA.cm.eval.trl.txt",
    audio_dir="/root/autodl-tmp/ASVspoof2019_LA_eval/flac",
)

# Evaluation data is iterated in protocol order (no shuffling): the metrics computed
# later do not depend on order, and a fixed order keeps runs reproducible and lets
# per-sample scores be aligned with the protocol rows.
eval_dataloader = DataLoader(eval_dataset, batch_size=8, shuffle=False, collate_fn=ast_collate_fn)

# Sanity check: show the shape and labels of the first batch only.
for batch in eval_dataloader:
    print(batch["input_values"].shape)
    print(batch["labels"])
    break

In [None]:
# Load the pretrained AST checkpoint with a 2-class head. The checkpoint's head has
# a different shape, so mismatched weights are skipped and newly initialized.
model_path = "..."
model = ASTForAudioClassification.from_pretrained(
    model_path, num_labels=2, ignore_mismatched_sizes=True
)
model = model.to(device)

# Unweighted cross-entropy over the two classes.
criterion = nn.CrossEntropyLoss()

# AdamW plus a step schedule: the learning rate is multiplied by 0.5 every
# 2 calls to scheduler.step().
optimizer = AdamW(model.parameters(), lr=5e-5, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)


Some weights of ASTForAudioClassification were not initialized from the model checkpoint at /root/autodl-fs/ast-audioset and are newly initialized because the shapes did not match:
- classifier.dense.bias: found shape torch.Size([527]) in the checkpoint and torch.Size([2]) in the model instantiated
- classifier.dense.weight: found shape torch.Size([527, 768]) in the checkpoint and torch.Size([2, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# Fine-tuning strategy. Exactly ONE mode is applied per run of this cell:
#   "linear_probe" - train only parameters whose name contains "classifier"
#   "partial"      - freeze the first N_FROZEN_ENCODER_LAYERS encoder layers,
#                    train everything else (embeddings and head included)
#   "full"         - train every parameter
# Previously the linear-probing and partial-tuning snippets ran back to back, so the
# second silently overwrote the first. "partial" is kept as the default because it is
# the net effect the old cell had when executed top to bottom.
TUNING_MODE = "partial"
N_FROZEN_ENCODER_LAYERS = 6


def set_tuning_mode(model, mode, n_frozen_layers=6):
    """Set ``requires_grad`` on every parameter of ``model`` according to ``mode``.

    Args:
        model: Module whose parameters are updated in place.
        mode: ``"linear_probe"``, ``"partial"`` or ``"full"`` (see the list above).
        n_frozen_layers: For ``"partial"``, how many leading encoder layers
            (``encoder.layer.0`` ... ``encoder.layer.<n-1>``) to freeze.

    Raises:
        ValueError: If ``mode`` is not one of the supported names.
    """
    if mode == "linear_probe":
        for name, param in model.named_parameters():
            param.requires_grad = "classifier" in name
    elif mode == "partial":
        # The trailing dot keeps "encoder.layer.1." from also matching layer 10, 11, ...
        frozen_markers = tuple(f"encoder.layer.{i}." for i in range(n_frozen_layers))
        for name, param in model.named_parameters():
            param.requires_grad = not any(marker in name for marker in frozen_markers)
    elif mode == "full":
        for param in model.parameters():
            param.requires_grad = True
    else:
        raise ValueError(
            f"Unknown tuning mode {mode!r}; expected 'linear_probe', 'partial' or 'full'"
        )


# NOTE: the flags live on this `model` instance. If `model` is re-created with
# from_pretrained afterwards, every parameter is trainable again and this cell
# must be re-run.
set_tuning_mode(model, TUNING_MODE, N_FROZEN_ENCODER_LAYERS)
for name, param in model.named_parameters():
    print(name, param.requires_grad)

audio_spectrogram_transformer.embeddings.cls_token False
audio_spectrogram_transformer.embeddings.distillation_token False
audio_spectrogram_transformer.embeddings.position_embeddings False
audio_spectrogram_transformer.embeddings.patch_embeddings.projection.weight False
audio_spectrogram_transformer.embeddings.patch_embeddings.projection.bias False
audio_spectrogram_transformer.encoder.layer.0.attention.attention.query.weight False
audio_spectrogram_transformer.encoder.layer.0.attention.attention.query.bias False
audio_spectrogram_transformer.encoder.layer.0.attention.attention.key.weight False
audio_spectrogram_transformer.encoder.layer.0.attention.attention.key.bias False
audio_spectrogram_transformer.encoder.layer.0.attention.attention.value.weight False
audio_spectrogram_transformer.encoder.layer.0.attention.attention.value.bias False
audio_spectrogram_transformer.encoder.layer.0.attention.output.dense.weight False
audio_spectrogram_transformer.encoder.layer.0.attention.output.den

In [None]:
import math

# NOTE(review): this builds a fresh `model` from the pretrained checkpoint, so any
# requires_grad flags set on a previously created instance do not carry over.
# Confirm whether the freezing step is meant to run after this cell.
model_path = "/root/autodl-fs/ast-audioset"
model = ASTForAudioClassification.from_pretrained(
    model_path, num_labels=2, ignore_mismatched_sizes=True
)
model = model.to(device)

# AdamW plus a step schedule: the learning rate is multiplied by 0.5 every
# 2 calls to scheduler.step().
optimizer = AdamW(model.parameters(), lr=5e-5, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)

# Class-weighted loss: compute the per-class weights from the class counts.
count_0 = 22800  # spoof
count_1 = 2580   # bonafide
total = count_0 + count_1

# Inverse square-root frequency weights; a softer alternative to plain inverse
# frequency (total / count).
weight_0, weight_1 = (1.0 / math.sqrt(n) for n in (count_0, count_1))

# Weight tensor indexed by class id (index 0 = spoof, index 1 = bonafide).
class_weights = torch.tensor([weight_0, weight_1], dtype=torch.float).to(device)

# Weighted cross-entropy; this is the `criterion` the train/evaluate functions read.
criterion = nn.CrossEntropyLoss(weight=class_weights)


In [None]:
def compute_eer(y_true, y_score):
    """Approximate the equal error rate (EER) from labels and scores.

    The ROC curve is computed with class 1 as the positive class. The EER is taken
    at the ROC point where the false-negative and false-positive rates are closest,
    as the mean of the two rates at that point (no interpolation between points).

    Args:
        y_true: Ground-truth binary labels.
        y_score: Scores for the positive class (higher = more positive).

    Returns:
        The approximate EER as a float.
    """
    false_pos, true_pos, _ = roc_curve(y_true, y_score, pos_label=1)
    false_neg = 1 - true_pos
    gap = np.abs(false_neg - false_pos)
    closest = np.nanargmin(gap)  # NaN entries are ignored when picking the point
    return (false_pos[closest] + false_neg[closest]) / 2

def train(model, dataloader):
    """Run one training epoch and report epoch-level metrics.

    Relies on the module-level ``device``, ``criterion``, ``optimizer`` and
    ``scheduler``. The scheduler is stepped once, after the last batch.

    Args:
        model: Model called as ``model(input_values=...)``; its output must expose
            ``.logits``.
        dataloader: Iterable of batches with ``"input_values"`` and ``"labels"``.

    Returns:
        Tuple ``(mean_loss, accuracy, eer, auc, f1)`` accumulated over the epoch.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0, 0, 0
    targets_all, scores_all, preds_all = [], [], []

    for batch in tqdm(dataloader, desc="Training"):
        inputs = batch["input_values"].to(device)
        targets = batch["labels"].to(device)

        logits = model(input_values=inputs).logits
        batch_loss = criterion(logits, targets)

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        # Weight the batch loss by batch size so the epoch value is a per-sample mean.
        loss_sum += batch_loss.item() * inputs.size(0)

        hard_preds = logits.argmax(dim=1)
        n_correct += (hard_preds == targets).sum().item()
        n_seen += targets.size(0)

        # Probability assigned to class 1 (bonafide) is the score used for EER / AUC.
        bonafide_prob = torch.softmax(logits, dim=1)[:, 1]
        targets_all.extend(targets.cpu().numpy())
        scores_all.extend(bonafide_prob.detach().cpu().numpy())
        preds_all.extend(hard_preds.detach().cpu().numpy())

    scheduler.step()

    accuracy = n_correct / n_seen
    eer = compute_eer(targets_all, scores_all)
    auc = roc_auc_score(targets_all, scores_all)
    f1 = f1_score(targets_all, preds_all)
    mean_loss = loss_sum / n_seen
    return mean_loss, accuracy, eer, auc, f1

def evaluate(model, dataloader):
    """Evaluate the model on a dataloader without updating any weights.

    Relies on the module-level ``device`` and ``criterion``. Gradients are disabled
    for the whole pass.

    Args:
        model: Model called as ``model(input_values=...)``; its output must expose
            ``.logits``.
        dataloader: Iterable of batches with ``"input_values"`` and ``"labels"``.

    Returns:
        Tuple ``(mean_loss, accuracy, eer, auc, f1)`` over all batches.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0, 0, 0
    targets_all, scores_all, preds_all = [], [], []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            inputs = batch["input_values"].to(device)
            targets = batch["labels"].to(device)

            logits = model(input_values=inputs).logits
            batch_loss = criterion(logits, targets)

            # Weight the batch loss by batch size so the result is a per-sample mean.
            loss_sum += batch_loss.item() * inputs.size(0)

            hard_preds = logits.argmax(dim=1)
            n_correct += (hard_preds == targets).sum().item()
            n_seen += targets.size(0)

            # Probability assigned to class 1 (bonafide) is the score used for EER / AUC.
            bonafide_prob = torch.softmax(logits, dim=1)[:, 1]
            targets_all.extend(targets.cpu().numpy())
            scores_all.extend(bonafide_prob.cpu().numpy())
            preds_all.extend(hard_preds.detach().cpu().numpy())

    accuracy = n_correct / n_seen
    eer = compute_eer(targets_all, scores_all)
    auc = roc_auc_score(targets_all, scores_all)
    f1 = f1_score(targets_all, preds_all)
    mean_loss = loss_sum / n_seen
    return mean_loss, accuracy, eer, auc, f1


In [None]:
# Training example: resume from a saved epoch and train up to `epochs`.
save_dir = "/root/autodl-tmp/ast_ws_lp/"
start_epoch = 15  # epochs already completed; set to 0 to train from the current weights
epochs = 20

os.makedirs(save_dir, exist_ok=True)

# Companion file holding optimizer and scheduler state for the most recent epoch.
# The per-epoch `epoch_N.pt` files stay plain model state_dicts, as before.
trainer_state_path = os.path.join(save_dir, "trainer_state.pt")

model_path = os.path.join(save_dir, f"epoch_{start_epoch}.pt")
if start_epoch > 0:
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
    model.load_state_dict(torch.load(model_path, map_location=device))

    # Restore optimizer moments and the LR schedule when a matching state exists.
    # Without it the optimizer restarts cold and the schedule restarts at its base
    # learning rate, which is what resuming from weights alone used to do.
    trainer_state = None
    if os.path.exists(trainer_state_path):
        saved_state = torch.load(trainer_state_path, map_location=device)
        if saved_state.get("epoch") == start_epoch:
            trainer_state = saved_state
    if trainer_state is not None:
        optimizer.load_state_dict(trainer_state["optimizer"])
        scheduler.load_state_dict(trainer_state["scheduler"])
        print(f"Restored optimizer/scheduler state from {trainer_state_path}")
    else:
        print(
            f"Warning: no optimizer/scheduler state for epoch {start_epoch}; "
            "optimizer and learning-rate schedule start from their initial values."
        )

for epoch in range(start_epoch, epochs):
    print(f"\nEpoch {epoch+1}/{epochs}")
    train_loss, train_acc, train_eer, train_auc, train_f1 = train(model, train_dataloader)
    print(f"Train Loss: {train_loss:.6f} | Train Acc: {train_acc:.6f} | Train EER: {train_eer:.6f} | Train AUC: {train_auc:.6f} | Train F1: {train_f1:.6f}" )

    # Save the weights before evaluating so an interrupted evaluation loses nothing.
    save_path = os.path.join(save_dir, f"epoch_{epoch+1}.pt")
    torch.save(model.state_dict(), save_path)
    print(f"Model saved to {save_path}")

    # Overwritten every epoch: only the latest optimizer/scheduler state is kept.
    torch.save(
        {
            "epoch": epoch + 1,
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
        },
        trainer_state_path,
    )

    dev_loss, dev_acc, dev_eer, dev_auc, dev_f1 = evaluate(model, dev_dataloader)
    print(f"Dev   Loss: {dev_loss:.6f} | Dev   Acc: {dev_acc:.6f} | Dev EER: {dev_eer:.6f} | Dev AUC: {dev_auc:.6f} | Dev F1: {dev_f1:.6f}")




Epoch 16/20


Training: 100%|██████████| 5700/5700 [15:45<00:00,  6.03it/s]


Train Loss: 0.067523 | Train Acc: 0.980768 | Train EER: 0.020526 | Train AUC: 0.997945 | Train F1: 0.980893
Model saved to /root/autodl-tmp/ast_ws_lp/epoch_16.pt


Evaluating: 100%|██████████| 3106/3106 [09:20<00:00,  5.54it/s]


Dev   Loss: 0.084035 | Dev   Acc: 0.967839 | Dev EER: 0.042542 | Dev AUC: 0.992590 | Dev F1: 0.856938

Epoch 17/20


Training: 100%|██████████| 5700/5700 [17:27<00:00,  5.44it/s]


Train Loss: 0.062891 | Train Acc: 0.981798 | Train EER: 0.017873 | Train AUC: 0.998304 | Train F1: 0.981857
Model saved to /root/autodl-tmp/ast_ws_lp/epoch_17.pt


Evaluating: 100%|██████████| 3106/3106 [09:29<00:00,  5.45it/s]


Dev   Loss: 0.074224 | Dev   Acc: 0.971824 | Dev EER: 0.041236 | Dev AUC: 0.992836 | Dev F1: 0.870466

Epoch 18/20


Training: 100%|██████████| 5700/5700 [17:25<00:00,  5.45it/s]


Train Loss: 0.062177 | Train Acc: 0.982083 | Train EER: 0.017259 | Train AUC: 0.998260 | Train F1: 0.982330
Model saved to /root/autodl-tmp/ast_ws_lp/epoch_18.pt


Evaluating: 100%|██████████| 3106/3106 [09:31<00:00,  5.44it/s]


Dev   Loss: 0.076568 | Dev   Acc: 0.970576 | Dev EER: 0.040440 | Dev AUC: 0.993001 | Dev F1: 0.866338

Epoch 19/20


Training: 100%|██████████| 5700/5700 [17:17<00:00,  5.49it/s]


Train Loss: 0.060431 | Train Acc: 0.983048 | Train EER: 0.016075 | Train AUC: 0.998370 | Train F1: 0.983317
Model saved to /root/autodl-tmp/ast_ws_lp/epoch_19.pt


Evaluating: 100%|██████████| 3106/3106 [05:21<00:00,  9.67it/s]


Dev   Loss: 0.083999 | Dev   Acc: 0.967678 | Dev EER: 0.041146 | Dev AUC: 0.992900 | Dev F1: 0.856633

Epoch 20/20


Training: 100%|██████████| 5700/5700 [09:46<00:00,  9.71it/s]


Train Loss: 0.061695 | Train Acc: 0.981952 | Train EER: 0.017873 | Train AUC: 0.998300 | Train F1: 0.982169
Model saved to /root/autodl-tmp/ast_ws_lp/epoch_20.pt


Evaluating: 100%|██████████| 3106/3106 [05:21<00:00,  9.67it/s]


Dev   Loss: 0.080307 | Dev   Acc: 0.968926 | Dev EER: 0.040748 | Dev AUC: 0.993125 | Dev F1: 0.861101
