# Deepfake Audio Detection Demo

This notebook demonstrates a lightweight version of a hybrid deepfake audio detection system using Retrieval-Augmented Detection (RADD), GANs, and VAEs. It’s designed for a quick demo, processing only 10 real and 10 fake `.flac` files from the ASVspoof2019 LA dataset, training briefly, and running real-time detection for 10 seconds.

## Prerequisites
- **Dataset**: ASVspoof2019 LA (`ASVspoof2019_LA_train/flac/` and `ASVspoof2019_LA_cm_protocols/ASVspoof2019.LA.cm.train.trn.txt`) in the project directory.
- **Environment**: Python 3.9 with dependencies installed (`librosa`, `numpy`, `faiss-cpu`, `transformers`, `torch`, `tensorflow`, `pyaudio`).
- **Hardware**: Microphone for real-time detection.


### Step 1: Import Libraries and Define Constants

In [None]:
# Cell 1: Import Libraries and Define Constants
import os
import librosa
import numpy as np
import glob
import faiss
from transformers import Wav2Vec2Processor, Wav2Vec2Model
import torch
from tensorflow.keras import layers, models
import tensorflow as tf
import pyaudio
import mysql.connector
import hashlib
import logging
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.model_selection import train_test_split
from tensorflow.keras.callbacks import EarlyStopping

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Constants
SAMPLE_RATE = 16000
CHUNK_SIZE = 2048  # Increased to match n_fft
LATENT_DIM = 100
SPECTROGRAM_SHAPE = (1025, 87)  # 1-second audio at 16kHz with n_fft=2048
DATA_DIR = 'flac'
DB_CONFIG = {
    "host": "mysql-1af4031e-audiodeepfake.h.aivencloud.com",
    "port": 12094,
    "user": "avnadmin",
    "password": "AVNS_qweb09E825R-UKRWQBR",
    "database": "defaultdb"
}
DEMO_FILES = 100  #for demonstration

### Step 2: Load and Preprocess Audio Data
Load 100 real (bonafide) and 100 fake (spoof) .flac files using the protocol file to classify them.

In [11]:
# Cell 2: Load and Preprocess Audio Data
def load_and_preprocess_audio(directory, label, max_files=DEMO_FILES, protocol_file=None):
    audio_files = glob.glob(os.path.join(directory, '*.flac'))
    spectrograms, labels = [], []
    if protocol_file and os.path.exists(protocol_file):
        with open(protocol_file, 'r') as f:
            protocol = {line.split()[1]: line.split()[4] for line in f.readlines()[1:]}
    else:
        logging.warning("Protocol file not found; assuming all files match the label.")
        protocol = None
    
    file_count = 0
    for file in audio_files:
        if file_count >= max_files:
            break
        try:
            filename = os.path.basename(file).replace('.flac', '')
            if protocol and filename not in protocol:
                continue
            is_real = 0 if protocol and protocol[filename] == 'bonafide' else 1
            if protocol and is_real != label:
                continue
            y, sr = librosa.load(file, sr=SAMPLE_RATE)
            S = librosa.stft(y, n_fft=2048, hop_length=512)
            S_db = librosa.amplitude_to_db(np.abs(S), ref=np.max)
            if S_db.shape != SPECTROGRAM_SHAPE:
                S_db = librosa.util.fix_length(S_db, size=SPECTROGRAM_SHAPE[1], axis=1)[:SPECTROGRAM_SHAPE[0], :]
            spectrograms.append(S_db)
            labels.append(label)
            file_count += 1
        except Exception as e:
            logging.error(f"Error processing {file}: {e}")
    return np.array(spectrograms), np.array(labels)

logging.info("Loading dataset for demo...")
protocol_path = 'ASVspoof2019.LA.cm.train.trn.txt'
real_spectrograms, real_labels = load_and_preprocess_audio(DATA_DIR, 0, protocol_file=protocol_path)
fake_spectrograms, fake_labels = load_and_preprocess_audio(DATA_DIR, 1, protocol_file=protocol_path)
all_spectrograms = np.concatenate([real_spectrograms, fake_spectrograms])
all_labels = np.concatenate([real_labels, fake_labels])
logging.info(f"Loaded {len(real_spectrograms)} real and {len(fake_spectrograms)} fake samples.")
logging.info(f"Real samples: {np.sum(all_labels == 0)}, Fake samples: {np.sum(all_labels == 1)}")

2025-04-05 22:31:13,237 - INFO - Loading dataset for demo...
2025-04-05 22:31:14,927 - INFO - Loaded 100 real and 100 fake samples.
2025-04-05 22:31:14,928 - INFO - Real samples: 100, Fake samples: 100


### Step 3: Set Up Retrieval-Augmented Detection (RADD)
Extract features with Wav2Vec2 and index them with FAISS for similarity retrieval.

In [12]:
# Cell 3: Set Up Retrieval-Augmented Detection (RADD)
processor = Wav2Vec2Processor.from_pretrained('facebook/wav2vec2-base')
model = Wav2Vec2Model.from_pretrained('facebook/wav2vec2-base')

def extract_features(audio):
    inputs = processor(audio, return_tensors='pt', sampling_rate=SAMPLE_RATE, padding=True)
    with torch.no_grad():
        features = model(inputs.input_values).last_hidden_state
    return features.mean(dim=1).squeeze().numpy()

audio_samples = [librosa.istft(s) for s in all_spectrograms]
features = np.array([extract_features(audio) for audio in audio_samples])
index = faiss.IndexFlatL2(features.shape[1])
index.add(features)

def retrieve_similar(new_audio, k=5):
    new_features = extract_features(new_audio).reshape(1, -1)
    distances, indices = index.search(new_features, k)
    return distances, indices

Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


### Step 4: Train GAN for Synthetic Deepfakes
Train a GAN with 50 epochs to generate synthetic deepfake samples.

In [13]:
# Cell 4: Train GAN for Synthetic Deepfakes
def build_generator():
    model = models.Sequential([
        layers.Input(shape=(LATENT_DIM,)),
        layers.Dense(256),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Dense(512),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Dense(np.prod(SPECTROGRAM_SHAPE), activation='tanh'),
        layers.Reshape(SPECTROGRAM_SHAPE)
    ])
    return model

def build_discriminator():
    model = models.Sequential([
        layers.Input(shape=SPECTROGRAM_SHAPE),
        layers.Flatten(),
        layers.Dense(512),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Dense(256),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Dense(1, activation='sigmoid')
    ])
    return model

generator = build_generator()
discriminator = build_discriminator()
discriminator.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0002), loss='binary_crossentropy')
discriminator.trainable = False
gan_input = layers.Input(shape=(LATENT_DIM,))
gan_output = discriminator(generator(gan_input))
gan = models.Model(gan_input, gan_output)
gan.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0002), loss='binary_crossentropy')

def train_gan(epochs=50, batch_size=8):
    for epoch in range(epochs):
        noise = np.random.normal(0, 1, (batch_size, LATENT_DIM))
        fake_spectrograms = generator.predict(noise, verbose=0)
        real_idx = np.random.randint(0, all_spectrograms.shape[0], batch_size)
        real_spectrograms = all_spectrograms[real_idx]
        X = np.concatenate([real_spectrograms, fake_spectrograms])
        y = np.array([0.9] * batch_size + [0.1] * batch_size)  # Label smoothing
        d_loss = discriminator.train_on_batch(X, y)
        noise = np.random.normal(0, 1, (batch_size, LATENT_DIM))
        y_gen = np.ones(batch_size) * 0.9  # Smooth generator labels
        g_loss = gan.train_on_batch(noise, y_gen)
        if epoch % 10 == 0:
            logging.info(f"GAN Epoch {epoch}: D Loss: {d_loss}, G Loss: {g_loss}")

logging.info("Training GAN for demo...")
train_gan()

2025-04-05 22:32:02,347 - INFO - Training GAN for demo...
2025-04-05 22:32:03,283 - INFO - GAN Epoch 0: D Loss: 1.2657653093338013, G Loss: 0.6925024390220642
2025-04-05 22:32:05,454 - INFO - GAN Epoch 10: D Loss: 2.823486328125, G Loss: 0.4016060531139374
2025-04-05 22:32:07,557 - INFO - GAN Epoch 20: D Loss: 2.843430995941162, G Loss: 0.37007731199264526
2025-04-05 22:32:09,650 - INFO - GAN Epoch 30: D Loss: 2.7539262771606445, G Loss: 0.35752567648887634
2025-04-05 22:32:11,786 - INFO - GAN Epoch 40: D Loss: 2.7970542907714844, G Loss: 0.3510686457157135


### Step 5: Train VAE for Data Augmentation
Train a VAE with 10 epochs to augment real audio data.

In [14]:
# Cell 5: Train VAE for Data Augmentation
import tensorflow.keras.backend as K

class VAE(models.Model):
    def __init__(self, spectrogram_shape, latent_dim, **kwargs):
        super(VAE, self).__init__(**kwargs)
        self.spectrogram_shape = spectrogram_shape
        self.latent_dim = latent_dim

        # Encoder
        self.encoder_inputs = layers.Input(shape=spectrogram_shape)
        x = layers.Flatten()(self.encoder_inputs)
        x = layers.Dense(512, activation='relu')(x)
        self.z_mean = layers.Dense(latent_dim, name='z_mean')(x)
        self.z_log_var = layers.Dense(latent_dim, name='z_log_var')(x)
        self.z = layers.Lambda(self._sampling, output_shape=(latent_dim,), name='z')([self.z_mean, self.z_log_var])
        self.encoder = models.Model(self.encoder_inputs, [self.z_mean, self.z_log_var, self.z], name='encoder')

        # Decoder
        self.decoder_inputs = layers.Input(shape=(latent_dim,))
        x = layers.Dense(512, activation='relu')(self.decoder_inputs)
        x = layers.Dense(np.prod(spectrogram_shape), activation='tanh')(x)
        self.decoder_outputs = layers.Reshape(spectrogram_shape)(x)
        self.decoder = models.Model(self.decoder_inputs, self.decoder_outputs, name='decoder')

        # VAE outputs
        self.outputs = self.decoder(self.encoder(self.encoder_inputs)[2])

    def _sampling(self, args):
        z_mean, z_log_var = args
        batch = K.shape(z_mean)[0]
        dim = K.shape(z_mean)[1]
        epsilon = K.random_normal(shape=(batch, dim))
        return z_mean + K.exp(0.5 * z_log_var) * epsilon

    def call(self, inputs, training=None):
        z_mean, z_log_var, z = self.encoder(inputs)
        reconstructed = self.decoder(z)
        if training:
            reconstruction_loss = K.mean(K.square(inputs - reconstructed))
            kl_loss = -0.5 * K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var))
            self.add_loss(reconstruction_loss + kl_loss)
        return reconstructed

def build_vae():
    vae = VAE(spectrogram_shape=SPECTROGRAM_SHAPE, latent_dim=LATENT_DIM)
    vae.compile(optimizer='adam')
    encoder = vae.encoder
    decoder = vae.decoder
    return vae, encoder, decoder

vae, encoder, decoder = build_vae()
logging.info("Training VAE for demo...")
vae.fit(all_spectrograms[all_labels == 0], epochs=10, batch_size=8, verbose=0)

noise = np.random.normal(0, 1, (len(real_spectrograms), LATENT_DIM))
augmented_spectrograms = decoder.predict(noise, verbose=0)

2025-04-05 22:32:13,899 - INFO - Training VAE for demo...


### Step 6: Caching Mechanism
Cache retrieval results for efficiency.

In [15]:
# Cell 6: Caching Mechanism
cache = {}
def get_cached_retrieval(audio_hash, audio):
    if audio_hash not in cache:
        distances, indices = retrieve_similar(audio)
        cache[audio_hash] = (distances, indices)
    return cache[audio_hash]

### Step 7: Model Integration, Training, and Continuous Learning
Combine all data, define and train the detector with 50 epochs, and set up continuous learning.

In [16]:
# Cell 7: Model Integration, Training, and Continuous Learning
# Generate synthetic data
synthetic_spectrograms = generator.predict(np.random.normal(0, 1, (len(real_spectrograms), LATENT_DIM)), verbose=0)
all_data = np.concatenate([all_spectrograms, synthetic_spectrograms, augmented_spectrograms])
all_labels_extended = np.concatenate([all_labels, [1] * len(synthetic_spectrograms), [0] * len(augmented_spectrograms)])

# Split data with stratification
train_data, val_data, train_labels, val_labels = train_test_split(
    all_data, all_labels_extended, test_size=0.2, stratify=all_labels_extended, random_state=42
)
train_data = train_data[..., np.newaxis]
val_data = val_data[..., np.newaxis]

# Define and train detector
detector = models.Sequential([
    layers.Input(shape=SPECTROGRAM_SHAPE + (1,)),
    layers.Conv2D(32, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.Dropout(0.5),
    layers.Dense(1, activation='sigmoid')
])
detector.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

logging.info("Training detector for demo...")
early_stopping = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
detector.fit(train_data, train_labels, epochs=50, batch_size=8, validation_data=(val_data, val_labels), callbacks=[early_stopping], verbose=1)

# Continuous learning setup
conn = mysql.connector.connect(**DB_CONFIG)
cur = conn.cursor()
cur.execute('''
    CREATE TABLE IF NOT EXISTS samples (
        hash TEXT,
        spectrogram BLOB,
        label INT,
        PRIMARY KEY (hash(64))
    )
''')

def add_new_sample(spectrogram, label):
    spectrogram_blob = spectrogram.tobytes()
    hash_value = hashlib.md5(spectrogram_blob).hexdigest()
    cur.execute('INSERT INTO samples (hash, spectrogram, label) VALUES (%s, %s, %s)', 
                (hash_value, spectrogram_blob, label))
    conn.commit()

def retrain_model():
    cur.execute('SELECT spectrogram, label FROM samples')
    rows = cur.fetchall()
    new_spectrograms, new_labels = [], []
    for row in rows:
        spectrogram = np.frombuffer(row[0], dtype=np.float32).reshape(SPECTROGRAM_SHAPE)
        new_spectrograms.append(spectrogram)
        new_labels.append(row[1])
    if new_spectrograms:
        new_data = np.array(new_spectrograms)[..., np.newaxis]
        new_labels = np.array(new_labels)
        detector.fit(new_data, new_labels, epochs=2, batch_size=8, verbose=0)
        logging.info("Model retrained with new samples.")

2025-04-05 22:32:47,649 - INFO - Training detector for demo...


Epoch 1/50
[1m40/40[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 228ms/step - accuracy: 0.5422 - loss: 0.9036 - val_accuracy: 0.5000 - val_loss: 0.6932
Epoch 2/50
[1m40/40[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 229ms/step - accuracy: 0.4650 - loss: 0.6935 - val_accuracy: 0.5000 - val_loss: 0.6931
Epoch 3/50
[1m40/40[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 226ms/step - accuracy: 0.5134 - loss: 0.6933 - val_accuracy: 0.5000 - val_loss: 0.6932
Epoch 4/50
[1m40/40[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 222ms/step - accuracy: 0.5477 - loss: 0.6932 - val_accuracy: 0.5000 - val_loss: 0.6932
Epoch 5/50
[1m40/40[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 223ms/step - accuracy: 0.4075 - loss: 0.6936 - val_accuracy: 0.5000 - val_loss: 0.6931
Epoch 6/50
[1m40/40[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 225ms/step - accuracy: 0.4902 - loss: 0.6934 - val_accuracy: 0.5000 - val_loss: 0.6931
Epoch 7/50
[1m40/40[0m [

### Step 8: Real-Time Detection
Run real-time detection for 10 seconds using the microphone.

In [17]:
# Cell 8: Real-Time Detection
def real_time_detection():
    p = pyaudio.PyAudio()
    stream = p.open(format=pyaudio.paFloat32, channels=1, rate=SAMPLE_RATE, input=True, frames_per_buffer=CHUNK_SIZE)
    logging.info("Starting real-time detection (10 seconds for demo)...")
    try:
        for _ in range(int(10 * SAMPLE_RATE / CHUNK_SIZE)):
            try:
                data = stream.read(CHUNK_SIZE, exception_on_overflow=False)
                audio = np.frombuffer(data, dtype=np.float32)
                S = librosa.stft(audio, n_fft=2048, hop_length=512)
                S_db = librosa.amplitude_to_db(np.abs(S), ref=np.max)
                S_db = librosa.util.fix_length(S_db, size=SPECTROGRAM_SHAPE[1], axis=1)[:SPECTROGRAM_SHAPE[0], :]
                S_db = S_db[np.newaxis, ..., np.newaxis]
                prediction = detector.predict(S_db, verbose=0)[0][0]
                audio_hash = hashlib.md5(audio.tobytes()).hexdigest()
                distances, _ = get_cached_retrieval(audio_hash, audio)
                logging.info(f"Deepfake Probability: {prediction:.2f}, Retrieval Distance: {distances[0][0]:.2f}")
                if prediction > 0.7:  # Adjusted threshold
                    add_new_sample(S_db[0, ..., 0], 1)
            except Exception as e:
                logging.error(f"Error in stream: {e}")
    finally:
        stream.stop_stream()
        stream.close()
        p.terminate()

real_time_detection()

2025-04-05 22:34:54,366 - INFO - Starting real-time detection (10 seconds for demo)...




2025-04-05 22:34:54,703 - INFO - Deepfake Probability: 0.50, Retrieval Distance: 65.57
2025-04-05 22:34:54,827 - INFO - Deepfake Probability: 0.50, Retrieval Distance: 27.66
2025-04-05 22:34:54,964 - INFO - Deepfake Probability: 0.50, Retrieval Distance: 24.49
2025-04-05 22:34:55,108 - INFO - Deepfake Probability: 0.50, Retrieval Distance: 64.52
2025-04-05 22:34:55,247 - INFO - Deepfake Probability: 0.50, Retrieval Distance: 23.16
2025-04-05 22:34:55,384 - INFO - Deepfake Probability: 0.50, Retrieval Distance: 26.39
2025-04-05 22:34:55,524 - INFO - Deepfake Probability: 0.50, Retrieval Distance: 70.46
2025-04-05 22:34:55,657 - INFO - Deepfake Probability: 0.50, Retrieval Distance: 64.28
2025-04-05 22:34:55,800 - INFO - Deepfake Probability: 0.50, Retrieval Distance: 78.86
2025-04-05 22:34:55,928 - INFO - Deepfake Probability: 0.50, Retrieval Distance: 23.21
2025-04-05 22:34:56,074 - INFO - Deepfake Probability: 0.50, Retrieval Distance: 22.63
2025-04-05 22:34:56,214 - INFO - Deepfake P

## Step 9: Evaluation and Retraining
Evaluate the model and retrain with new samples.

In [18]:
# Cell 9: Evaluation and Retraining
logging.info(f"Validation real: {np.sum(val_labels == 0)}, Validation fake: {np.sum(val_labels == 1)}")
predictions = (detector.predict(val_data, verbose=0) > 0.5).astype(int)
logging.info(f"Accuracy: {accuracy_score(val_labels, predictions):.2f}")
logging.info(f"Precision: {precision_score(val_labels, predictions):.2f}")
logging.info(f"Recall: {recall_score(val_labels, predictions):.2f}")
logging.info(f"F1-Score: {f1_score(val_labels, predictions):.2f}")

# Test prediction on a known sample
test_sample = all_data[0][np.newaxis, ..., np.newaxis]
logging.info(f"Test prediction on first sample: {detector.predict(test_sample, verbose=0)[0][0]:.2f}")

retrain_model()

# Cleanup
conn.close()

2025-04-05 22:35:05,555 - INFO - Validation real: 40, Validation fake: 40
2025-04-05 22:35:05,941 - INFO - Accuracy: 0.50
2025-04-05 22:35:05,944 - INFO - Precision: 0.50
2025-04-05 22:35:05,946 - INFO - Recall: 1.00
2025-04-05 22:35:05,948 - INFO - F1-Score: 0.67
2025-04-05 22:35:06,014 - INFO - Test prediction on first sample: 0.50
