# Digital Audio Forensics using AI - Detecting Deepfake Audio

## Environment Setup

In [None]:
#Following libraries are needed for thie notebook

!pip install torch
!pip install torchaudio
!pip install librosa
!pip install numpy
!pip install matplotlib
!pip install tqdm
!pip install IPython
!pip install torchvision


## Importing Libraries

In [None]:
import os

import IPython
import librosa
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.optim as optim
import torchaudio
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix, roc_curve, precision_recall_curve, average_precision_score
from torch import flatten
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset, TensorDataset, random_split, ConcatDataset
from torchaudio import transforms as AT
from torchvision import transforms
from tqdm.notebook import tqdm

## Dataset used:

**The following data set has been used:**

[The Fake-or-Real (FoR) Dataset](https://www.kaggle.com/datasets/mohammedabdeldayem/the-fake-or-real-dataset)

A private dataset, generated with the ElevenLabs text-to-speech model, has also been used.


## Dataset creation class:

In [None]:
"""Create a DataSet object from the Fake or Real Dataset. Return the MFCC of audio file along with
the corresponding label. 0:Fake and 1:Real"""

# Mel-spectrogram settings forwarded to the MFCC transform through its melkwargs.
n_fft = 2048
# win_length is set here but is not forwarded to the MFCC transform in this cell.
win_length = None
hop_length = 512
n_mels = 256

class CustomDataset(Dataset):
    """Audio dataset yielding ``(features, label)`` pairs from a class-per-folder tree.

    ``root_dir`` must contain one sub-directory per class. Sub-directories are
    enumerated in sorted order, so with ``fake`` and ``real`` folders the labels
    are 0: fake and 1: real regardless of the order the file system lists them.

    Args:
        root_dir: Directory holding one sub-directory per class.
        transform: Callable applied to the mono, resampled waveform (e.g. an
            MFCC transform). When ``None`` the waveform itself is returned.
        target_sample_rate: Sample rate every clip is resampled to.
        target_time_length: Size of the last axis of the transformed features;
            shorter outputs are zero-padded and longer ones are cropped.
    """

    def __init__(self, root_dir, transform=None, target_sample_rate=16000, target_time_length=972):
        super().__init__()
        self.root_dir = root_dir
        self.transform = transform
        self.target_sample_rate = target_sample_rate
        self.target_time_length = target_time_length
        # os.listdir gives no ordering guarantee; sorting pins the class indices.
        # Stray files next to the class folders are ignored.
        self.classes = sorted(
            name for name in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, name))
        )
        self.file_paths = []
        self.labels = []

        # class_idx: {0: fake, 1: real} for the fake/real folder layout
        for class_idx, class_name in enumerate(self.classes):
            class_dir = os.path.join(root_dir, class_name)
            for file_name in tqdm(sorted(os.listdir(class_dir))):
                self.file_paths.append(os.path.join(class_dir, file_name))
                self.labels.append(class_idx)

    def __len__(self):
        return len(self.file_paths)

    def mix_down_if_necessary(self, signal):
        """Average the channels (axis 0) into one when there is more than one."""
        if signal.shape[0] > 1:
            signal = torch.mean(signal, dim=0, keepdim=True)
        return signal

    def _fit_time_length(self, features):
        """Zero-pad or crop the last axis of ``features`` to ``target_time_length``."""
        missing = self.target_time_length - features.shape[-1]
        if missing > 0:
            return torch.nn.functional.pad(features, (0, missing), mode='constant')
        return features[..., :self.target_time_length]

    def _load_item(self, idx):
        """Load, mix down, resample and transform the clip stored at ``idx``."""
        waveform, sample_rate = torchaudio.load(self.file_paths[idx])
        waveform = self.mix_down_if_necessary(waveform)
        if sample_rate != self.target_sample_rate:
            waveform = AT.Resample(sample_rate, self.target_sample_rate)(waveform)

        label = self.labels[idx]
        if self.transform is None:
            return waveform, label
        return self._fit_time_length(self.transform(waveform)), label

    def __getitem__(self, idx):
        """Return ``(features, label)`` for ``idx``.

        A clip that cannot be loaded is replaced by the next loadable clip
        (wrapping around at the end), together with that clip's own label.

        Raises:
            IndexError: ``idx`` is outside the dataset.
            RuntimeError: no clip in the dataset could be loaded; the last
                loading error is attached as the cause.
        """
        total = len(self.file_paths)
        if not -total <= idx < total:
            raise IndexError(f"index {idx} is out of range for a dataset of size {total}")

        last_error = None
        # Bounded best-effort: try each clip at most once instead of recursing without limit.
        for offset in range(total):
            candidate = (idx + offset) % total
            try:
                return self._load_item(candidate)
            except Exception as error:
                last_error = error
        raise RuntimeError(f"no audio file under {self.root_dir!r} could be loaded") from last_error


# MFCC extractor shared by the datasets: 40 coefficients from an HTK-scale mel spectrogram
mfcc_transform = AT.MFCC(
    n_mfcc=40,
    sample_rate=16000,
    melkwargs=dict(
        n_fft=n_fft,
        n_mels=n_mels,
        hop_length=hop_length,
        mel_scale="htk",
    ),
)


## Model Class:

In [None]:
# Neural network model class
class ShallowCNN(nn.Module):
    """Four conv + max-pool stages followed by two linear layers.

    ``forward`` receives a batch without a channel axis and inserts one itself.
    ``fc1`` expects 15104 flattened features, which is what the conv/pool stack
    yields for a 40 x 972 input (128 channels x 2 x 59).
    """

    def __init__(self, in_features = 1, out_dim=1, **kwargs):
        super().__init__()
        # Convolutional feature extractor
        self.conv1 = nn.Conv2d(in_features, 32, kernel_size=4, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 48, kernel_size=5, stride=1, padding=1)
        self.conv3 = nn.Conv2d(48, 64, kernel_size=4, stride=1, padding=1)
        self.conv4 = nn.Conv2d(64, 128, kernel_size=(2, 4), stride=1, padding=1)
        # One 2x2 pooling layer reused after every convolution
        self.pool = nn.MaxPool2d(2, 2)
        # Classifier head
        self.fc1 = nn.Linear(15104, 128)
        self.fc2 = nn.Linear(128, out_dim)

    def forward(self, x: torch.Tensor):
        """Return one raw (pre-sigmoid) output row per sample."""
        features = x.unsqueeze(1)  # add the channel axis expected by Conv2d
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            features = self.pool(F.relu(conv(features)))
        hidden = F.relu(self.fc1(flatten(features, 1)))
        return self.fc2(hidden)

## Creating Dataset objects:

In [None]:
# Fake-or-Real "for-2sec" variant
dataset_fake_or_real_for_2sec_testing = CustomDataset('/kaggle/input/the-fake-or-real-dataset/for-2sec/for-2seconds/testing', mfcc_transform)
dataset_fake_or_real_for_2sec_training = CustomDataset('/kaggle/input/the-fake-or-real-dataset/for-2sec/for-2seconds/training', mfcc_transform)
dataset_fake_or_real_for_2sec_validation = CustomDataset('/kaggle/input/the-fake-or-real-dataset/for-2sec/for-2seconds/validation', mfcc_transform)

# Fake-or-Real "for-norm" variant
dataset_fake_or_real_for_norm_testing = CustomDataset('/kaggle/input/the-fake-or-real-dataset/for-norm/for-norm/testing', mfcc_transform)
dataset_fake_or_real_for_norm_training = CustomDataset('/kaggle/input/the-fake-or-real-dataset/for-norm/for-norm/training', mfcc_transform)
dataset_fake_or_real_for_norm_validation = CustomDataset('/kaggle/input/the-fake-or-real-dataset/for-norm/for-norm/validation', mfcc_transform)

# Fake-or-Real "for-original" variant
dataset_fake_or_real_for_original_testing = CustomDataset('/kaggle/input/the-fake-or-real-dataset/for-original/for-original/testing', mfcc_transform)
dataset_fake_or_real_for_original_training = CustomDataset('/kaggle/input/the-fake-or-real-dataset/for-original/for-original/training', mfcc_transform)
dataset_fake_or_real_for_original_validation = CustomDataset('/kaggle/input/the-fake-or-real-dataset/for-original/for-original/validation', mfcc_transform)

# Fake-or-Real "for-rerec" variant
dataset_fake_or_real_for_rerec_testing = CustomDataset('/kaggle/input/the-fake-or-real-dataset/for-rerec/for-rerecorded/testing', mfcc_transform)
dataset_fake_or_real_for_rerec_training = CustomDataset('/kaggle/input/the-fake-or-real-dataset/for-rerec/for-rerecorded/training', mfcc_transform)
dataset_fake_or_real_for_rerec_validation = CustomDataset('/kaggle/input/the-fake-or-real-dataset/for-rerec/for-rerecorded/validation', mfcc_transform)


# Private dataset (train and test splits only)
dataset_private_train = CustomDataset('/kaggle/input/dataset/trainData', mfcc_transform)
dataset_private_test = CustomDataset('/kaggle/input/dataset/testData', mfcc_transform)

# Prototype-phase data
prototype_phase_data = CustomDataset('/kaggle/input/privatedata/Prototype Assignment', mfcc_transform)

# Training split: every Fake-or-Real variant plus the private training data
dataset_train = ConcatDataset((
    dataset_fake_or_real_for_2sec_training,
    dataset_fake_or_real_for_norm_training,
    dataset_fake_or_real_for_original_training,
    dataset_fake_or_real_for_rerec_training,
    dataset_private_train,
))

# Validation split: Fake-or-Real variants only
dataset_val = ConcatDataset((
    dataset_fake_or_real_for_2sec_validation,
    dataset_fake_or_real_for_norm_validation,
    dataset_fake_or_real_for_original_validation,
    dataset_fake_or_real_for_rerec_validation,
))

# Test split: every Fake-or-Real variant plus the private test data
dataset_test = ConcatDataset((
    dataset_fake_or_real_for_2sec_testing,
    dataset_fake_or_real_for_norm_testing,
    dataset_fake_or_real_for_original_testing,
    dataset_fake_or_real_for_rerec_testing,
    dataset_private_test,
))

## Creating iterable Dataloader object from datasets:

In [None]:
# Shuffled mini-batches of 32 clips, loaded by 4 worker processes per split
train_loader = DataLoader(dataset_train, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(dataset_val, batch_size=32, shuffle=True, num_workers=4)
test_loader = DataLoader(dataset_test, batch_size=32, shuffle=True, num_workers=4)

## Graphs of different Audio features:

In [None]:
# Plot the raw audio waveform
def plot_waveform(waveform, label, sr):
    """Draw the amplitude of ``waveform`` against its sample index.

    ``label`` picks the title (1 -> real, anything else -> fake). ``sr`` is
    accepted but not used by this function.
    """
    samples = waveform.detach().numpy()
    if label == 1:
        title = 'Real Audio Waveform'
    else:
        title = 'Fake Audio Waveform'

    plt.figure(figsize=(6, 2))
    plt.xlabel('Time')
    plt.ylabel('Amplitude')
    plt.plot(samples)
    plt.title(title)
    plt.show()

# Plot the mel spectrogram of an audio clip
def plot_spectrogram(waveform,label, sr):
    """Show the mel spectrogram of ``waveform`` in decibels.

    ``sr`` is forwarded to librosa; ``label`` picks the title
    (1 -> real, anything else -> fake).
    """
    mel_settings = dict(n_fft=2048, hop_length=512, win_length=2048, n_mels=128)
    samples = waveform.detach().numpy()

    mel_power = librosa.feature.melspectrogram(y=samples, sr=sr, **mel_settings)
    mel_db = librosa.power_to_db(mel_power)

    if label == 1:
        title = 'Real Audio Spectrogram'
    else:
        title = 'Fake Audio Spectrogram'

    plt.figure(figsize=(10, 4))
    plt.imshow(mel_db, cmap='viridis', aspect='auto', origin='lower')
    plt.colorbar(format='%+2.0f dB')
    plt.title(title)
    plt.xlabel('Time')
    plt.ylabel('Frequency')
    plt.show()

# Plot the MFCC of an audio clip
def plot_mfcc(waveform,label,sr):
    """Show the MFCC matrix of ``waveform`` as an image.

    Args:
        waveform: Audio tensor, either 1-D ``(time,)`` or ``(channel, time)``;
            for the latter only the first channel is drawn.
        label: 1 titles the plot as real audio, anything else as fake audio.
        sr: Sample rate of ``waveform``; used to build the MFCC transform.

    The time axis is zero-padded or cropped to 972 frames. The coefficients
    are plotted as returned by the transform: they are not power values, so
    no power-to-dB conversion is applied.
    """
    n_fft = 2048
    hop_length = 512
    n_mels = 256
    target_frames = 972
    transform = AT.MFCC(
        sample_rate=sr,
        n_mfcc=40,
        melkwargs={
            "n_fft": n_fft,
            "n_mels": n_mels,
            "hop_length": hop_length,
            "mel_scale": "htk",
        },
    )
    mfcc = transform(waveform)

    # The time axis is always the last one, whatever the input rank.
    missing = target_frames - mfcc.shape[-1]
    if missing > 0:
        mfcc = torch.nn.functional.pad(mfcc, (0, missing), mode='constant')
    else:
        mfcc = mfcc[..., :target_frames]
    if mfcc.dim() == 3:
        mfcc = mfcc[0]  # imshow needs a 2-D (coefficient, frame) array
    mfcc = mfcc.detach().numpy()

    plt.figure(figsize=(10, 4))
    plt.imshow(mfcc, cmap='viridis', aspect='auto', origin='lower')
    plt.colorbar(format='%+2.0f')
    plt.title('Real Audio MFCC' if label == 1 else 'Fake Audio MFCC')
    plt.xlabel('Time')
    plt.ylabel('MFCC')
    plt.show()

def plot_graphs(file_path,label): #0 for fake and 1 for real
    """Load an audio file and plot its waveform, mel spectrogram and MFCC.

    The clip is mixed down to mono and resampled to 16 kHz first, so the
    sample rate handed to the plotting functions matches the audio.

    Args:
        file_path: Path of the audio file to load with torchaudio.
        label: 0 for fake audio, 1 for real audio; only affects plot titles.
    """
    target_sample_rate = 16000
    waveform, sample_rate = torchaudio.load(file_path)

    # Convert stereo (or any multi-channel) audio to mono
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    if sample_rate != target_sample_rate:
        waveform = AT.Resample(sample_rate, target_sample_rate)(waveform)

    # Call the plotting functions on the single remaining channel
    plot_waveform(waveform[0], label, target_sample_rate)
    plot_spectrogram(waveform[0], label, target_sample_rate)
    plot_mfcc(waveform[0], label, target_sample_rate)

# Example: plot all three views for one clip from the "fake" folder (label 0 = fake)
file_path = "/kaggle/input/privatedata/Prototype Assignment/fake/Fake_1.wav"
plot_graphs(file_path,0)

## Training setup:

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Clear the CUDA cache before starting the training
torch.cuda.empty_cache()

# Single-output CNN trained with a logit-based binary cross-entropy loss
model = ShallowCNN(in_features=1, out_dim=1)
model = model.to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

num_epochs = 5

In [None]:
# Train for num_epochs epochs, validating after each one
for epoch in range(num_epochs):
    # ---- Training pass ----
    model.train()
    train_loss = 0.0
    loop = tqdm(enumerate(train_loader), total=len(train_loader))
    for batch_idx, (batch_mfccs, batch_labels) in loop:
        loop.set_description(f'Epoch {epoch + 1} / {num_epochs}')

        # Labels become float column vectors so they line up with the model output
        batch_mfccs = batch_mfccs.to(device)
        batch_labels = batch_labels.unsqueeze(1).type(torch.float32).to(device)
        batch_mfccs = batch_mfccs.squeeze(1)

        # Reset gradients, run forward/backward, then update the weights
        optimizer.zero_grad()
        outputs = model(batch_mfccs)
        loss = criterion(outputs, batch_labels)
        loss.backward()
        optimizer.step()

        # Running sum of the per-batch training loss
        train_loss = train_loss + loss.item()

    # ---- Validation pass ----
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (batch_mfccs, batch_labels) in enumerate(val_loader):
            batch_mfccs = batch_mfccs.to(device)
            batch_labels = batch_labels.unsqueeze(1).type(torch.float32).to(device)
            batch_mfccs = batch_mfccs.squeeze(1)

            outputs = model(batch_mfccs)
            loss = criterion(outputs, batch_labels)
            val_loss = val_loss + loss.item()

            # Adding 0.5 and truncating thresholds the probability at 0.5
            batch_pred = (torch.sigmoid(outputs) + 0.5).int()
            total = total + batch_labels.shape[0]
            correct = correct + (batch_pred == batch_labels).sum().item()

    # ---- Epoch summary: mean loss per batch and validation accuracy ----
    train_loss = train_loss / len(train_loader)
    val_loss = val_loss / len(val_loader)
    val_accuracy = correct / total
    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.2%}")

print("Training finished.")



## Saving the model weights in a file:

In [None]:
# Save only the model's state_dict (its parameters) to model_path, then report the path
model_path = 'FinalDeliverableModel.pth'
torch.save(model.state_dict(), model_path)
print(f"Model saved to '{model_path}'.")

## Testing the model:

In [None]:
# Evaluate the trained model on the test split
model.eval()
val_loss = 0.0  # accumulates the test loss; name kept from the validation loop
correct = 0
total = 0
with torch.no_grad():
    # Passing total lets tqdm show real progress instead of a bare counter
    for batch_idx, (batch_mfccs, batch_labels) in tqdm(enumerate(test_loader), total=len(test_loader)):
        batch_mfccs, batch_labels = batch_mfccs.to(device), batch_labels.unsqueeze(1).type(torch.float32).to(device)

        # Forward pass
        batch_mfccs = batch_mfccs.squeeze(1)
        outputs = model(batch_mfccs)

        # Compute loss
        loss = criterion(outputs, batch_labels)

        # Accumulate test loss
        val_loss += loss.item()
        # Adding 0.5 and truncating thresholds the probability at 0.5
        batch_pred = (torch.sigmoid(outputs) + 0.5).int()
        # Compute accuracy
        total += batch_labels.size(0)
        correct += (batch_pred == batch_labels).sum().item()

# Report the mean loss per batch, the same normalisation the training cell uses
if len(test_loader) > 0:
    val_loss /= len(test_loader)
print(f"Loss is {val_loss}")
print(correct)
print(total)
if total > 0:
    print(f"Accuracy is {correct/total}")
else:
    print("Accuracy is undefined: the test loader produced no samples.")

## Testing the model on prototype phase data:

In [None]:
# One clip per batch, in dataset order, for the prototype-phase data
prototype_phase_test_loader = DataLoader(prototype_phase_data, batch_size=1)

def compute_eer(y_true, y_scores):
    """Return ``(eer, eer_threshold)`` read off the ROC curve of ``y_scores``.

    The chosen ROC point is the one where the false-negative rate
    (``1 - tpr``) is closest to the false-positive rate; the false-positive
    rate at that point is reported as the equal error rate.
    """
    fpr, tpr, thresholds = roc_curve(y_true, y_scores)
    rate_gap = np.absolute((1 - tpr) - fpr)
    crossing = np.nanargmin(rate_gap)
    return fpr[crossing], thresholds[crossing]

def inference_function(prototype_phase_test_loader):
    """Evaluate the global ``model`` on a loader, then print and plot its metrics.

    Args:
        prototype_phase_test_loader: Sized loader yielding ``(features, labels)``
            batches with integer 0/1 labels (0 = fake, 1 = real).

    Prints accuracy, precision, recall, F1, ROC AUC, EER and the confusion
    matrix, then shows the confusion-matrix heatmap, the ROC curve and the
    precision-recall curve. Relies on the module-level ``model`` and ``device``.

    Raises:
        ValueError: the loader produced no samples.
    """
    model.eval()
    all_labels = []
    all_predictions = []
    all_scores = []

    with torch.no_grad():
        progress = tqdm(prototype_phase_test_loader, total=len(prototype_phase_test_loader))
        for batch_mfccs, batch_labels in progress:
            # Forward pass
            batch_mfccs = batch_mfccs.to(device).squeeze(1)
            scores = torch.sigmoid(model(batch_mfccs))
            # Adding 0.5 and truncating thresholds the probability at 0.5
            predicted = (scores + 0.5).int()

            # Flatten to plain 1-D lists so scikit-learn gets vectors, not (N, 1) columns
            all_labels.extend(batch_labels.reshape(-1).tolist())
            all_predictions.extend(predicted.reshape(-1).cpu().tolist())
            all_scores.extend(scores.reshape(-1).cpu().tolist())

    print(len(all_scores))
    print(len(all_labels))
    if len(all_labels) == 0 or len(all_scores) == 0:
        raise ValueError("No labels or scores were collected. Check the data loading and model inference steps.")

    # Convert lists to numpy arrays
    all_labels = np.array(all_labels)
    all_predictions = np.array(all_predictions)
    all_scores = np.array(all_scores)

    # Threshold-based metrics use the hard predictions; ranking metrics use the scores
    accuracy = accuracy_score(all_labels, all_predictions)
    precision = precision_score(all_labels, all_predictions)
    recall = recall_score(all_labels, all_predictions)
    f1 = f1_score(all_labels, all_predictions)
    roc_auc = roc_auc_score(all_labels, all_scores)
    # Fix the class order so the matrix is 2x2 and matches the Fake/Real tick labels
    conf_matrix = confusion_matrix(all_labels, all_predictions, labels=[0, 1])
    eer, eer_threshold = compute_eer(all_labels, all_scores)

    print(f'Accuracy: {accuracy:.4f}')
    print(f'Precision: {precision:.4f}')
    print(f'Recall: {recall:.4f}')
    print(f'F1 Score: {f1:.4f}')
    print(f'ROC AUC: {roc_auc:.4f}')
    print(f'EER: {eer:.4f}')
    print('Confusion Matrix:')
    print(conf_matrix)

    # Plot confusion matrix
    plt.figure(figsize=(8, 6))
    sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', xticklabels=['Fake', 'Real'], yticklabels=['Fake', 'Real'])
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.title('Confusion Matrix')
    plt.show()

    # Plot ROC curve
    fpr, tpr, _ = roc_curve(all_labels, all_scores)
    plt.figure(figsize=(8, 6))
    plt.plot(fpr, tpr, marker='.')
    plt.plot([0, 1], [0, 1], linestyle='--')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC Curve')
    plt.show()

    # Plot precision-recall curve (separate names keep the scalar metrics above intact)
    pr_precision, pr_recall, _ = precision_recall_curve(all_labels, all_scores)
    average_precision = average_precision_score(all_labels, all_scores)

    plt.figure(figsize=(8, 6))
    plt.plot(pr_recall, pr_precision, marker='.')
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title(f'Precision-Recall Curve (AP = {average_precision:.2f})')
    plt.show()
# Evaluate on the prototype-phase loader created at the top of this cell
inference_function(prototype_phase_test_loader)

## Inference on a single file:

In [None]:
'''For inference on a single file, make sure the weights file is available at the path given to torch.load below.
 Run the first two cells to import the libraries and the Model class cell to define the model.'''
# from google.colab import files

# Defined here too, so this cell does not depend on the training-setup cell having been run.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = ShallowCNN(in_features=1, out_dim=1).to(device)
# map_location lets weights that were saved on a GPU load on a CPU-only machine.
state_dict = torch.load('/kaggle/input/fyp_final_deliverable_model_weights/pytorch/model.pth/1/FinalDeliverableModel_bc200401244.pth', map_location=device)
model.load_state_dict(state_dict)
print("model loaded Successfuly!")
# Mel-spectrogram settings for the inference-time MFCC transform
n_fft, win_length, hop_length, n_mels = 2048, None, 512, 256

# MFCC transform that turns a waveform into an MFCC tensor
mfcc_transform = AT.MFCC(
    n_mfcc=40,
    sample_rate=16000,
    melkwargs=dict(
        n_fft=n_fft,
        n_mels=n_mels,
        hop_length=hop_length,
        mel_scale="htk",
    ),
)

#Function to preprocess audio data:
def preprocess_audio(file_path,transform = mfcc_transform, target_time_length = 972):
    """Load an audio file and turn it into fixed-length features.

    The clip is mixed down to mono and resampled to 16 kHz, then ``transform``
    is applied and the last (time) axis is zero-padded or cropped to
    ``target_time_length``.

    Args:
        file_path: Path of the audio file to load with torchaudio.
        transform: Callable applied to the waveform. When ``None`` the mono,
            resampled waveform is returned unchanged.
        target_time_length: Required size of the last axis of the features.

    Returns:
        The transformed features, or the waveform when ``transform`` is ``None``.
    """
    waveform, sample_rate = torchaudio.load(file_path)
    waveform = mix_down_if_necessary(waveform)
    if sample_rate != 16000:
        waveform = AT.Resample(sample_rate, 16000)(waveform)

    if transform is None:
        return waveform

    # Apply the MFCC transformation, then fit the time axis explicitly
    mfcc = transform(waveform)
    missing = target_time_length - mfcc.shape[-1]
    if missing > 0:
        return torch.nn.functional.pad(mfcc, (0, missing), mode='constant')
    return mfcc[..., :target_time_length]

def mix_down_if_necessary(signal): #converting from stereo to mono
    """Average all channels into one when ``signal`` has more than one.

    Axis 0 is treated as the channel axis and is kept with size 1. A signal
    whose first axis already has size 1 (or 0) is returned unchanged.
    """
    has_several_channels = signal.shape[0] > 1
    if not has_several_channels:
        return signal
    return signal.mean(dim=0, keepdim=True)

# Prompt user to upload audio file. NOTE: Below code is only for google colab.
# print("Upload an audio file:")
# uploaded = files.upload()
# file_path = next(iter(uploaded.keys()))
file_path = '/kaggle/input/privatedata/Prototype Assignment/fake/Fake_9.wav'
# Preprocess the provided file
mfccs_tensor = preprocess_audio(file_path)
# Move the features to wherever the model's weights live instead of assuming a GPU
mfccs_tensor = mfccs_tensor.to(next(model.parameters()).device)

# Perform inference
model.eval()
with torch.no_grad():
    output = model(mfccs_tensor)
    # Raw model output before the sigmoid
    print(output.item())

    # Adding 0.5 and truncating thresholds the probability at 0.5
    predicted = (torch.sigmoid(output) + 0.5).int()

# Display prediction (1 = real, 0 = AI-generated)
if predicted.item() == 1:
    print(f"The model predicts that file is real.")
else:
    print(f"The model predicts that file is AI-generated.")
