In [16]:
%cd /content/
!git clone https://github.com/clovaai/aasist.git

/content
fatal: destination path 'aasist' already exists and is not an empty directory.


In [17]:
%cd /content/aasist/
!pip install -r requirements.txt

/content/aasist


In [18]:
!python ./download_dataset.py

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: LA/ASVspoof2019_LA_eval/flac/LA_E_7787040.flac  
  inflating: LA/ASVspoof2019_LA_eval/flac/LA_E_2924301.flac  
  inflating: LA/ASVspoof2019_LA_eval/flac/LA_E_9249366.flac  
  inflating: LA/ASVspoof2019_LA_eval/flac/LA_E_3442936.flac  
  inflating: LA/ASVspoof2019_LA_eval/flac/LA_E_7772915.flac  
  inflating: LA/ASVspoof2019_LA_eval/flac/LA_E_5569336.flac  
  inflating: LA/ASVspoof2019_LA_eval/flac/LA_E_7773607.flac  
  inflating: LA/ASVspoof2019_LA_eval/flac/LA_E_7813281.flac  
  inflating: LA/ASVspoof2019_LA_eval/flac/LA_E_9705954.flac  
  inflating: LA/ASVspoof2019_LA_eval/flac/LA_E_2427464.flac  
  inflating: LA/ASVspoof2019_LA_eval/flac/LA_E_1000273.flac  
  inflating: LA/ASVspoof2019_LA_eval/flac/LA_E_5263550.flac  
  ...  
  inflating: LA/ASVspoof2019_LA_eval/flac/LA_E_4492957.flac  
  inflating: LA/ASVspoof2019_LA_eval/flac/LA_E_6105590.flac  
  inflating: LA/ASVspoof2019_LA_eval/flac/LA_E_9008117.flac

In [1]:
# Notebook parameters – edit these paths before running
config_path     = "/content/aasist/config/AASIST.conf"
checkpoint_path = "/content/aasist/models/weights/AASIST.pth"
audio_dir       = "/content/aasist/LA/ASVspoof2019_LA_dev/flac"
n_iter          = 20
batch_size      = 1

In [2]:
# 1) Install required packages (uncomment if needed)
!pip install -q torch torchvision onnx tqdm model_compression_toolkit soundfile scipy


In [3]:
%cd /content/aasist/
# 2) Imports
import os
import json
import torch
import soundfile as sf
import numpy as np
from scipy.signal import resample
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

import model_compression_toolkit as mct
from AASIST import Model
from utils import str_to_bool

/content/aasist


In [4]:
# 3) Audio loading & dataset
def load_audio(path: str, target_sr: int = 16000) -> np.ndarray:
    wav, sr = sf.read(path)
    if sr != target_sr:
        num = int(len(wav) * target_sr / sr)
        wav = resample(wav, num)
    return wav.astype(np.float32)

class AudioFolderDataset(Dataset):
    """All .flac files under audio_dir."""
    def __init__(self, root_dir, target_sr=16000):
        self.paths = sorted([
            os.path.join(root_dir, fn)
            for fn in os.listdir(root_dir)
            if fn.lower().endswith('.flac')
        ])
        self.target_sr = target_sr

    def __len__(self): return len(self.paths)

    def __getitem__(self, idx):
        wav = load_audio(self.paths[idx], self.target_sr)
        return torch.from_numpy(wav), self.paths[idx]

def representative_dataset_gen(dataset, batch_size=1, n_iter=20):
    """Yield lists of numpy arrays for MCT."""
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    it = iter(loader)
    for _ in range(n_iter):
        try:
            batch, _ = next(it)
        except StopIteration:
            it = iter(loader)
            batch, _ = next(it)
        yield [batch.cpu().numpy()]

def evaluate(model, dataloader, device):
    """Run inference and print per‐file predictions."""
    model.to(device).eval()
    with torch.no_grad():
        for x, path in tqdm(dataloader, desc="Eval"):
            x = x.to(device)
            _, logits = model(x, Freq_aug=False)
            probs = torch.softmax(logits, dim=-1)
            pred = probs.argmax(dim=-1).item()
            label = {0: 'Spoof', 1: 'Bona fide'}.get(pred, 'Unknown')
            print(f"{os.path.basename(path[0])}: {label} ({probs[0,pred]:.3f})")

In [5]:
# 4) Load config, instantiate & load checkpoint
with open(config_path, 'r') as f:
    cfg = json.load(f)
d_args = cfg['model_config']

if not isinstance(d_args['filts'][0], int):
    raise ValueError("d_args['filts'][0] must be int")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device:", device)

model = Model(d_args).to(device)
ckpt = torch.load(checkpoint_path, map_location=device)
sd = ckpt.get('state_dict', ckpt)
res = model.load_state_dict(sd, strict=False)
if res.unexpected_keys: print("Unexpected keys:", res.unexpected_keys)
if res.missing_keys:    print("Missing keys:",    res.missing_keys)

model.eval()
torch.set_grad_enabled(False)

Device: cpu


<torch.autograd.grad_mode.set_grad_enabled at 0x7ed01e8b9310>

In [6]:
# 5) Prepare datasets
dataset     = AudioFolderDataset(audio_dir, target_sr=16000)
rep_gen     = lambda: representative_dataset_gen(dataset, batch_size, n_iter)
eval_loader = DataLoader(dataset, batch_size=1, shuffle=False)

In [7]:
# 6) GPTQ config & target platform
gptq_config = mct.gptq.get_pytorch_gptq_config(n_epochs=50)
tpc = mct.get_target_platform_capabilities(
    'pytorch',
    'tflite',
    target_platform_version="v1"
)

In [8]:
# 7) Run quantization
print("Starting GPTQ quantization…")
q_model, q_info = mct.gptq.pytorch_gradient_post_training_quantization(
    model,
    rep_gen,
    gptq_config=gptq_config,
    target_platform_capabilities=tpc
)
print("Quantization done.")

CRITICAL:Model Compression Toolkit:Error parsing model with torch.fx
fx error: symbolically traced variables cannot be used as inputs to control flow


Starting GPTQ quantization…


Exception: Error parsing model with torch.fx
fx error: symbolically traced variables cannot be used as inputs to control flow

In [None]:
# 8) Evaluate original vs. quantized
print("\n■ Floating-point model:")
evaluate(model,     eval_loader, device)
print("\n■ Quantized model:")
evaluate(q_model, eval_loader, device)

In [None]:
# 9) Export quantized model to ONNX
print("\nExporting to qmodel.onnx …")
mct.exporter.pytorch_export_model(
    q_model,
    save_model_path='qmodel.onnx',
    repr_dataset=rep_gen
)
print("Export complete.")