In [1]:
import torch
import json
import sys

sys.path.append("D:/Work/JupytorWorkspace/AASIST")

from aasist.models.AASIST import Model

with open("D:/Work/JupytorWorkspace/AASIST/aasist/config/AASIST.conf", "r") as f:
    full_config = json.load(f)

model_config = full_config["model_config"]
model_path = full_config["model_path"]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Model(model_config)

state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict)
model = model.to(device).eval()

print("AASIST model loaded on", device)

AASIST model loaded on cuda


In [2]:
import os
import librosa
import torch
from torch.utils.data import Dataset

class ASVspoof2021LAEvalDataset(Dataset):
    def __init__(self, protocol_path, flac_dir, nb_samples=64600):
        self.nb_samples = nb_samples
        self.flac_dir = flac_dir
        self.entries = []

        with open(protocol_path, 'r') as f:
            for line in f:
                parts = line.strip().split()
                if len(parts) >= 1:
                    utt_id = parts[0]
                    flac_path = os.path.join(self.flac_dir, f"{utt_id}.flac")
                    if os.path.exists(flac_path):
                        self.entries.append(utt_id)
                    else:
                        print(f"Missing file: {flac_path}")

    def __len__(self):
        return len(self.entries)

    def __getitem__(self, index):
        utt_id = self.entries[index]
        flac_path = os.path.join(self.flac_dir, f"{utt_id}.flac")

        try:
            waveform, sr = librosa.load(flac_path, sr=16000)
            waveform = torch.tensor(waveform, dtype=torch.float32)

            if waveform.dim() > 1:
                waveform = waveform.mean(dim=1)

            if waveform.size(0) < self.nb_samples:
                padding = self.nb_samples - waveform.size(0)
                waveform = torch.nn.functional.pad(waveform, (0, padding))
            else:
                waveform = waveform[:self.nb_samples]

            waveform = waveform.unsqueeze(0)  # [1, T]
            return waveform, utt_id

        except Exception as e:
            print(f"Error loading {flac_path}: {e}")
            return None

In [3]:
from torch.utils.data import DataLoader
from tqdm import tqdm
import json
import warnings
warnings.filterwarnings("ignore")

protocol_path = r"C:\Users\DaysPC\Documents\Datasets\ASVspoof2021\ASVspoof2021_LA_eval\ASVspoof2021_LA_eval\ASVspoof2021.LA.cm.eval.trl.txt"
flac_dir = r"C:\Users\DaysPC\Documents\Datasets\ASVspoof2021\ASVspoof2021_LA_eval\ASVspoof2021_LA_eval\flac"

dataset = ASVspoof2021LAEvalDataset(protocol_path, flac_dir)
loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)

results = {}

for batch in tqdm(loader, desc="Evaluating", unit="file"):
    if batch is None or batch[0] is None:
        continue

    waveform, utt_id = batch
    waveform = waveform.squeeze(0).to(device)

    with torch.no_grad():
        out = model(waveform)
        score = out[0].mean().item()
        results[utt_id[0]] = score

with open("aasist_eval_results.json", "w") as f:
    json.dump(results, f, indent=2)

print(f"Saved {len(results)} scores to aasist_eval_results.json")

Evaluating: 100%|██████████| 181566/181566 [5:27:08<00:00,  9.25file/s]  


Saved 181566 scores to aasist_eval_results.json
