In [109]:
import tkinter as tk
from tkinter import filedialog, messagebox, ttk
import pyaudio
import wave
import sys
import os
import pickle
import torch
import librosa
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split, Subset
import torchvision.models as models
from transformers import HubertModel, AutoProcessor
import torchaudio

In [41]:
class LayerFeatureDataset(Dataset):
    """Dataset of pre-extracted features and labels loaded from a pickle file.

    The pickle must hold a mapping with a ``'features'`` entry and a
    ``'labels'`` entry. Both are converted to tensors once, at construction,
    and exposed as ``self.X`` (float32) and ``self.y`` (int64).
    """

    def __init__(self, pkl_path):
        """Load ``'features'`` and ``'labels'`` from the pickle at ``pkl_path``.

        Args:
            pkl_path: Path to a pickle file containing a dict with the keys
                ``'features'`` and ``'labels'``.
        """
        # NOTE(security): unpickling can execute arbitrary code. Only load
        # files that you produced yourself or otherwise trust.
        # The context manager guarantees the file handle is closed.
        with open(pkl_path, 'rb') as f:
            data = pickle.load(f)
        self.X = torch.tensor(data['features'], dtype=torch.float32)
        self.y = torch.tensor(data['labels'], dtype=torch.long)

    def __len__(self):
        """Return the number of labels (one per sample)."""
        return len(self.y)

    def __getitem__(self, i):
        """Return the ``(features, label)`` pair stored at index ``i``."""
        return self.X[i], self.y[i]

class LayerWeightedAggregator(nn.Module):
    """Collapse dimension 1 of a tensor with learned, softmax-normalised weights."""

    def __init__(self, num_layers):
        """Create one learnable weight per layer, each starting at ``1 / num_layers``."""
        super().__init__()
        uniform_init = torch.ones(num_layers) / num_layers
        self.w = nn.Parameter(uniform_init)

    def forward(self, x):
        """Return the weighted sum of ``x`` over dimension 1.

        The raw parameters are passed through a softmax first, so the
        effective weights are positive and sum to one.
        """
        layer_weights = self.w.softmax(dim=0)
        # Shape (1, num_layers, 1): broadcasts over the axes either side of
        # the layer axis of ``x``.
        weighted = x * layer_weights.view(1, -1, 1)
        return weighted.sum(dim=1)

class ResNetClassifier(nn.Module):
    """ResNet-18 classifier over layer-aggregated feature vectors.

    The input is first reduced over its layer axis by a
    ``LayerWeightedAggregator``; each resulting vector of length
    ``hidden_dim`` is then laid out as a single-channel ``H x W`` image
    (``H == W == sqrt(hidden_dim)``) and classified by a ResNet-18 whose first
    convolution takes one channel and whose final layer emits ``num_classes``
    logits.
    """

    def __init__(self, num_layers, hidden_dim, num_classes, pretrained=True):
        """Build the aggregator and the adapted ResNet-18.

        Args:
            num_layers: Number of layers the aggregator weighs.
            hidden_dim: Length of each feature vector; must be a perfect square.
            num_classes: Number of output logits.
            pretrained: Start the backbone from ImageNet weights (default,
                matching the previous behaviour). Pass ``False`` to skip the
                weight download, e.g. when a checkpoint is loaded right after.
        """
        super().__init__()
        self.agg = LayerWeightedAggregator(num_layers)
        H = W = int(np.sqrt(hidden_dim))
        assert H*W == hidden_dim, "hidden_dim must be square"
        self.H, self.W = H, W

        self.resnet = self._build_backbone(pretrained)

        # Swap the 3-channel stem for a 1-channel convolution with the same
        # geometry; this layer is freshly initialised.
        stem = self.resnet.conv1
        self.resnet.conv1 = nn.Conv2d(1,
            stem.out_channels,
            kernel_size=stem.kernel_size,
            stride=stem.stride,
            padding=stem.padding,
            bias=False)
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, num_classes)

    @staticmethod
    def _build_backbone(pretrained):
        """Create ResNet-18 using whichever weights API this torchvision offers."""
        weights_enum = getattr(models, "ResNet18_Weights", None)
        if weights_enum is None:
            # Older torchvision releases only know the boolean flag.
            return models.resnet18(pretrained=pretrained)
        # ``IMAGENET1K_V1`` is the checkpoint the legacy ``pretrained=True``
        # flag resolves to, so the initial weights are unchanged.
        return models.resnet18(weights=weights_enum.IMAGENET1K_V1 if pretrained else None)

    def forward(self, x):
        """Return class logits for ``x``.

        ``x`` is expected to carry the layer axis at dimension 1 and the
        feature axis last (assumed ``(batch, num_layers, hidden_dim)`` —
        confirm against the feature extractor).
        """
        x = self.agg(x)
        b = x.size(0)
        # reshape (rather than view) also accepts non-contiguous input.
        x = x.reshape(b, 1, self.H, self.W)
        return self.resnet(x)


In [94]:
pkl_path = "../Features/Opt_Features/ASV_Opt_Dataset.pkl"
dataset = LayerFeatureDataset(pkl_path)
processor = AutoProcessor.from_pretrained("facebook/hubert-large-ls960-ft")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Derive the classifier dimensions from the stored features BEFORE building
# the model; on a fresh kernel these names do not exist any earlier.
num_layers  = dataset.X.size(1)
hidden_dim  = dataset.X.size(2)
num_classes = len(torch.unique(dataset.y))

model = ResNetClassifier(num_layers, hidden_dim, num_classes)
# The classifier stays on the CPU, so map the checkpoint there as well;
# otherwise a checkpoint saved from a CUDA device cannot be loaded on a
# CPU-only machine.
model.load_state_dict(torch.load("model.pth", map_location="cpu"))
model.eval()
hub = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft", output_hidden_states=True).to(device)
hub.eval()
print("Model loaded successfully")

# NOTE(security): unpickling can execute arbitrary code; only load trusted files.
with open("ASV_sample.pkl", "rb") as f:
    data = pickle.load(f)

print("Sampled Files:", data["files"])
print("Sampled Labels:", data["labels"])

audio_files = data["files"]
audio_labels = data["labels"]

Model loaded successfully
Sampled Files: ['../ASV/Artifact/LA_E_9464037_clipping.wav', '../ASV/Artifact/LA_E_8206846_filter.wav', '../ASV/Artifact/LA_E_4967785_noise.wav', '../ASV/Artifact/LA_E_6782766_reverb.wav', '../ASV/Data/LA_E_5324584.flac', '../ASV/Artifact/LA_E_5590452_filter.wav', '../ASV/Artifact/LA_E_8964992_clipping.wav', '../ASV/Artifact/LA_E_5986045_compression.wav', '../ASV/Artifact/LA_E_9835790_reverb.wav', '../ASV/Artifact/LA_E_9715080_filter.wav']
Sampled Labels: [1, 0, 0, 0, 1, 1, 1, 1, 0, 0]


In [107]:
def load_audio(filename, target_sr=16000):
    """Load an audio file as a mono, 1-D waveform at ``target_sr`` on ``device``.

    Args:
        filename: Path of the audio file, read with ``torchaudio.load``.
        target_sr: Sample rate (Hz) of the returned waveform. Defaults to
            16000, the rate the feature extractor is called with.

    Returns:
        A tensor on ``device``. Multi-channel input is averaged down to one
        channel, and the channel axis is squeezed away.
    """
    waveform, sample_rate = torchaudio.load(filename)
    # Resample while the data is still on the CPU: a freshly constructed
    # Resample module keeps its kernel on the CPU, so feeding it a CUDA tensor
    # would fail with a device mismatch.
    if sample_rate != target_sr:
        waveform = torchaudio.transforms.Resample(sample_rate, target_sr)(waveform)
    if waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    return waveform.squeeze(0).to(device)


def extract_features(waveform):
    """Turn a 16 kHz waveform into time-averaged HuBERT hidden-state features.

    Args:
        waveform: Audio samples as a torch tensor (any device) or anything the
            processor accepts directly (NumPy array / list of floats).

    Returns:
        A float32 CPU tensor with one row per selected hidden state; each row
        is that hidden state averaged over the time axis.
    """
    # The Hugging Face feature extractor documents NumPy arrays / float lists
    # as input, so hand it host memory explicitly instead of relying on its
    # handling of (possibly CUDA) torch tensors.
    if isinstance(waveform, torch.Tensor):
        waveform = waveform.detach().cpu().numpy()

    with torch.no_grad():
        inputs = processor(waveform, sampling_rate=16000, return_tensors="pt").to(device)
        outputs = hub(**inputs)
        # Entries 3 to 6 of the hidden-states tuple. Entry 0 is the embedding
        # output, so these are the outputs of encoder layers 2-5.
        # NOTE(review): must match the slice used to build the training features.
        hidden_states = outputs.hidden_states[2:6]
        # Mean over time (dim=1), then drop the batch axis of size one.
        pooled = [hs.mean(dim=1).squeeze(0) for hs in hidden_states]
        features = torch.stack(pooled, dim=0).to(device="cpu", dtype=torch.float32)
    return features

# Rebind Python's stderr to the null device for the rest of the session.
# NOTE(review): this discards everything written through sys.stderr from here
# on (warnings and uncaught tracebacks included), and the handle is never
# closed. Presumably meant to hide audio-backend start-up noise from PyAudio —
# TODO confirm; consider scoping the redirection to playback only.
sys.stderr = open(os.devnull, 'w')
def play_audio(file_path):
    """Play a WAV file through the default output device, blocking until done.

    The file is opened with the standard-library ``wave`` module, so only
    files that module can parse are playable. Any failure (unreadable file,
    no audio device, ...) is reported on stdout instead of being raised.

    Args:
        file_path: Path of the WAV file to play.
    """
    try:
        chunk = 1024  # frames handed to the output stream per write
        # Nested with / try-finally so the file, the stream and the PyAudio
        # instance are all released even if playback fails half-way.
        with wave.open(file_path, 'rb') as wf:
            p = pyaudio.PyAudio()
            try:
                # Open stream
                stream = p.open(format=p.get_format_from_width(wf.getsampwidth()),
                                channels=wf.getnchannels(),
                                rate=wf.getframerate(),
                                output=True)
                try:
                    # Read and play until the file is exhausted
                    data = wf.readframes(chunk)
                    while data:
                        stream.write(data)
                        data = wf.readframes(chunk)
                    stream.stop_stream()
                finally:
                    stream.close()
            finally:
                p.terminate()

    except Exception as e:
        print(f"Error playing audio: {e}")

In [110]:
class FakeAudioApp:
    """Tkinter window for playing and classifying one of the sampled audio files.

    The user picks a 1-based sample number; "Play" plays the corresponding
    entry of the module-level ``audio_files`` list and "Classify" runs it
    through ``load_audio`` -> ``extract_features`` -> ``model`` and shows the
    predicted label. Both callbacks run synchronously inside the Tk callback,
    so the window does not process events until they return.
    """

    def __init__(self, master):
        """Build all widgets inside ``master`` (the Tk root window)."""
        self.master = master
        master.title("Fake Audio Detector")
        master.geometry("700x450")              # Larger window
        master.configure(bg="#f0f0f0")          # Light gray background

        # Use a ttk style
        style = ttk.Style(master)
        style.theme_use('clam')
        style.configure('TLabel', background='#f0f0f0', font=('Helvetica', 12))
        style.configure('Header.TLabel', font=('Helvetica', 16, 'bold'))
        style.configure('TButton', font=('Helvetica', 12), padding=6)

        # Header
        header = ttk.Label(master, text="Fake Audio Detector", style='Header.TLabel')
        header.pack(pady=(20, 10))

        # Frame for controls
        controls = ttk.Frame(master, padding=20, style='TFrame')
        controls.pack(fill='x', expand=False)

        # Offer exactly as many choices as there are sample files, so the
        # menu can never point past the end of the list.
        sample_count = len(audio_files)
        ttk.Label(controls, text=f"Choose a number (1-{sample_count}):").grid(row=0, column=0, sticky='w')
        self.number_var = tk.StringVar(value="1")
        number_menu = ttk.Combobox(controls, textvariable=self.number_var,
                                   values=[str(i) for i in range(1, sample_count + 1)],
                                   state='readonly', width=5, font=('Helvetica',12))
        number_menu.grid(row=0, column=1, padx=(10,0), sticky='w')

        btn_frame = ttk.Frame(master, padding=(20,10))
        btn_frame.pack(fill='x', expand=False)

        self.play_btn = ttk.Button(btn_frame, text="► Play", command=self.play_selected, width=15)
        self.play_btn.grid(row=0, column=0, padx=10)

        self.classify_btn = ttk.Button(btn_frame, text="✔ Classify", command=self.classify_selected, width=15)
        self.classify_btn.grid(row=0, column=1, padx=10)

        sep = ttk.Separator(master, orient='horizontal')
        sep.pack(fill='x', pady=10)

        result_frame = ttk.Frame(master, padding=20)
        result_frame.pack(fill='both', expand=True)

        ttk.Label(result_frame, text="Result:", font=('Helvetica', 14)).grid(row=0, column=0, sticky='nw')
        self.result_lbl = ttk.Label(result_frame, text="–", font=('Helvetica', 24, 'bold'), foreground='#007acc')
        self.result_lbl.grid(row=1, column=0, sticky='n')

        for child in controls.winfo_children():
            child.grid_configure(pady=5)
        btn_frame.grid_columnconfigure(0, weight=1)
        btn_frame.grid_columnconfigure(1, weight=1)

        # Path of the most recently played / classified sample.
        self.selected_file = None

    def _selected_number(self):
        """Return the 1-based sample number currently chosen in the combobox."""
        return int(self.number_var.get())

    def _file_for(self, number):
        """Return the path in ``audio_files`` for a 1-based sample ``number``.

        Raises:
            IndexError: If ``number`` is outside ``1..len(audio_files)``. Checked
                explicitly because ``audio_files[number - 1]`` would silently
                wrap around to the last file for ``number == 0``.
        """
        if not 1 <= number <= len(audio_files):
            raise IndexError(f"selection {number} is outside 1-{len(audio_files)}")
        return audio_files[number - 1]

    def play_selected(self):
        """Play the selected sample; failures are printed, not raised."""
        try:
            self.selected_file = self._file_for(self._selected_number())

            play_audio(self.selected_file)
        except Exception as e:
            print("Error", f"An error occurred during playback:\n{e}")

    def classify_selected(self):
        """Classify the selected sample and show the label in the result area.

        Failures are printed and the result area is set to ``"Error"`` so a
        stale result from a previous sample is never left on screen.
        """
        try:
            selected_number = self._selected_number()
            print("Selected number:", selected_number)
            self.selected_file = self._file_for(selected_number)
            print("Selected file:", self.selected_file)

            waveform = load_audio(self.selected_file)
            if waveform is None:
                return

            features = extract_features(waveform)
            # Add the batch axis explicitly instead of relying on broadcasting
            # inside the model to create it.
            if features.dim() == 2:
                features = features.unsqueeze(0)

            with torch.no_grad():
                outputs = model(features)
                preds = torch.argmax(outputs, dim=1).item()
                # NOTE(review): class index 1 is shown as "Fake" — confirm this
                # matches the label encoding used for training.
                label = "Fake" if preds == 1 else "Real"

            self.result_lbl.config(text=f"Result: {label}")

        except Exception as e:
            print("Error", f"An error occurred during classification:\n{e}")
            self.result_lbl.config(text="Error")

In [111]:
if __name__ == "__main__":
    # Create the top-level Tk window, attach the detector UI to it, and hand
    # control to the Tk event loop (blocks until the window is closed).
    root = tk.Tk()
    app = FakeAudioApp(master=root)
    root.mainloop()

Selected number: 1
Selected file: ../ASV/Artifact/LA_E_9464037_clipping.wav
Selected number: 7
Selected file: ../ASV/Artifact/LA_E_8964992_clipping.wav
Selected number: 10
Selected file: ../ASV/Artifact/LA_E_9715080_filter.wav
Selected number: 2
Selected file: ../ASV/Artifact/LA_E_8206846_filter.wav
Selected number: 7
Selected file: ../ASV/Artifact/LA_E_8964992_clipping.wav
