In [None]:
import os
import tqdm
import glob
import torch
import random
import librosa
import torchaudio
import numpy as np
from torch import nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torchaudio import transforms
from collections import namedtuple
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split
from torch.utils.data import DataLoader, Dataset,TensorDataset
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score

In [None]:
from transformers import (
                            AutoModelForAudioClassification,
                            EarlyStoppingCallback,
                            AutoFeatureExtractor,
                            TrainingArguments,
                            Trainer
)



from transformers import Wav2Vec2Processor, AutoProcessor
from transformers import AutoProcessor, Wav2Vec2Model,Wav2Vec2FeatureExtractor

In [None]:
# Select the compute device: CUDA is preferred, then Apple MPS, then the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"
print(f"Using {DEVICE} device")


Data exploration

In [None]:
# Lightweight record describing one audio file.
Info = namedtuple("Info", ["length", "sample_rate", "channels"])


def get_audio_info(path: str):
    """Read an audio file's metadata without decoding it.

    Args:
        path: Path of the audio file handed to ``torchaudio.info``.

    Returns:
        Info: ``(length, sample_rate, channels)`` where ``length`` is the
        frame count reported for the file.
    """
    metadata = torchaudio.info(path)
    if not hasattr(metadata, 'num_frames'):
        # No ``num_frames`` attribute: the result is indexable and its first
        # item carries ``length`` / ``rate`` / ``channels``.
        # NOTE(review): presumably the return format of an older torchaudio
        # release -- confirm.
        signal_info = metadata[0]
        frames_per_channel = signal_info.length // signal_info.channels
        return Info(frames_per_channel, signal_info.rate, signal_info.channels)
    return Info(metadata.num_frames, metadata.sample_rate, metadata.num_channels)


def get_total_dataset_length(base_dir: str) -> None:
    """Print duration, sample-rate and channel statistics for a folder of MP3s.

    Every ``*.mp3`` file below ``base_dir`` (searched recursively) is probed
    with ``get_audio_info``. Files that cannot be probed are counted and
    listed at the end instead of aborting the scan.

    Args:
        base_dir: Root directory to search.

    Returns:
        None. The report is written to stdout.
    """
    durations = []
    sample_rates = []
    channels = []
    failures = []  # (path, exception) for every file that could not be probed

    file_paths = glob.glob(os.path.join(base_dir, '**', '*.mp3'), recursive=True)
    for file_path in tqdm.tqdm(file_paths):
        try:
            audio_info = get_audio_info(file_path)
            # Compute the duration before recording anything so a failure
            # here cannot leave the three lists with different lengths.
            duration = audio_info.length / audio_info.sample_rate
        except Exception as error:
            failures.append((file_path, error))
            continue
        durations.append(duration)
        sample_rates.append(audio_info.sample_rate)
        channels.append(audio_info.channels)

    print()
    print("-"*50)
    if durations:
        print(f"Min audio length (in seconds): {min(durations)} | Max audio length (in seconds): {max(durations)}")
        print(f"Mean audio length (in seconds): {np.mean(durations)} | Median: {np.median(durations)} | Std: {np.std(durations)}")
        print(f"Total amount of data (in minutes): {np.sum(durations)/60}")
    else:
        # min()/max() raise on an empty sequence; report the situation instead.
        print(f"No readable .mp3 files found under {base_dir}")
    print('-'*50)
    print(f"Different sample rates in audios in the dataset: {set(sample_rates)}")
    print(f"Different number of channels in the audios in the dataset: {set(channels)}")
    print(f"Falhas na contagem {len(failures)}/{len(file_paths)}")
    for failed_path, error in failures:
        print(f"{failed_path} ({error!r})")

In [None]:
#get_total_dataset_length("../audios/")

DATA

In [None]:
# Load a single clip from the labelled "fake" training folder.
# NOTE(review): the module-level `audio` / `sr` are not read by any later cell
# visible here -- presumably a leftover sanity check; confirm before removing.
audio,sr = torchaudio.load('../audios/train/fake/ff79bf6b-f3e5-480d-b7c7-a36c0a1ddf46.mp3')

In [None]:
class ExtractFeatures():
    """Turns raw waveforms into embeddings using a pretrained Wav2Vec2 model.

    Args:
        model_path: Checkpoint name or path given to both
            ``AutoProcessor.from_pretrained`` and
            ``Wav2Vec2Model.from_pretrained``.
    """

    def __init__(self, model_path):
        self.model_path = model_path
        self.processor = AutoProcessor.from_pretrained(model_path)
        backbone = Wav2Vec2Model.from_pretrained(model_path)
        self.model = backbone.to(DEVICE)

    def w2v_extract_features(self,speech_array, target_sample_rate):
        """Return the model's last hidden state for ``speech_array``.

        Args:
            speech_array: Waveform passed to the processor.
            target_sample_rate: Sampling rate reported to the processor.

        Returns:
            ``last_hidden_state`` of the model output, computed without
            gradient tracking. The tensor lives on ``DEVICE``.
        """
        encoded = self.processor(
            speech_array,
            sampling_rate=target_sample_rate,
            return_tensors='pt',
            padding=True,
        )
        # Drop the leading dimension of the processor output when it has size 1
        # (presumably the extra batch axis it adds -- TODO confirm).
        waveform = encoded.input_values.squeeze(dim=0).to(DEVICE)
        with torch.no_grad():
            outputs = self.model(waveform)
        return outputs.last_hidden_state

In [None]:
class AudioDataset(Dataset):
    """Dataset of ``.mp3`` clips converted to fixed-length features.

    Each item is loaded, mixed down to mono, resampled to
    ``TARGET_SAMPLE_RATE``, zero-padded or truncated to ``clip_seconds`` and
    then converted to features according to ``mode``.

    Args:
        data_dir: With ``is_label=True``, a directory containing ``real`` and
            ``fake`` sub-folders; otherwise a flat directory of clips.
        is_label: ``True`` for labelled audio, ``False`` for unlabelled audio.
            Both cases share the same feature pipeline, so any change to it
            applies to both sets.
        TARGET_SAMPLE_RATE: Sample rate every clip is resampled to.
        mode: ``'mel_spec'`` (mel spectrogram) or ``'w2v_feature'``
            (embeddings from ``feature_extractor``).
        feature_extractor: Object exposing
            ``w2v_extract_features(audio, sample_rate)``; required when
            ``mode='w2v_feature'``.
        clip_seconds: Length every clip is padded/truncated to, in seconds.

    Items are ``(features, label)`` when ``is_label`` is true (label 0 for
    ``real``, 1 for ``fake``) and ``(features, file_name)`` otherwise.
    """

    SUPPORTED_MODES = ('mel_spec', 'w2v_feature')

    def __init__(self, data_dir, is_label, TARGET_SAMPLE_RATE = 16000, mode = 'mel_spec', feature_extractor = None, clip_seconds = 4):
        self.data_dir = data_dir
        self.classes = ["real", "fake"]
        self.audio_files = []
        self.labels = []
        self.is_label = is_label
        self.TARGET_SAMPLE_RATE = TARGET_SAMPLE_RATE
        self.mode = mode
        self.feature_extractor = feature_extractor
        self.clip_seconds = clip_seconds
        if self.is_label:
            for class_idx, class_name in enumerate(self.classes):
                class_files = self._list_mp3_files(os.path.join(data_dir, class_name))
                self.audio_files.extend(class_files)
                self.labels.extend([class_idx] * len(class_files))
        else:
            self.audio_files.extend(self._list_mp3_files(self.data_dir))

        self.mel_spectrogram = torchaudio.transforms.MelSpectrogram(
            sample_rate=TARGET_SAMPLE_RATE, n_fft=1024, hop_length=512, n_mels=64
        )

    @staticmethod
    def _list_mp3_files(directory):
        """Return the ``.mp3`` paths directly inside ``directory``, sorted by name.

        ``os.listdir`` returns entries in arbitrary order; sorting keeps the
        sample order (and any seeded split built on it) stable across runs.
        """
        return [
            os.path.join(directory, file_name)
            for file_name in sorted(os.listdir(directory))
            if file_name.endswith(".mp3")
        ]

    def __len__(self):
        return len(self.audio_files)

    def __getitem__(self, idx):
        # Fail loudly on a misconfigured dataset instead of returning None,
        # which would only surface later as an obscure collate error.
        if self.mode not in self.SUPPORTED_MODES:
            raise ValueError(
                f"Unsupported feature mode {self.mode!r}; expected one of {self.SUPPORTED_MODES}"
            )
        if self.mode == 'w2v_feature' and self.feature_extractor is None:
            raise ValueError("mode='w2v_feature' requires a feature_extractor")

        audio_file = self.audio_files[idx]

        # Load audio
        audio, sr = torchaudio.load(audio_file)
        # Convert to mono by averaging the channels
        if audio.shape[0] > 1:
            audio = torch.mean(audio, dim=0, keepdim=True)

        if sr != self.TARGET_SAMPLE_RATE:
            audio = torchaudio.transforms.Resample(sr, self.TARGET_SAMPLE_RATE)(audio)

        # Pad (with trailing zeros) or truncate the audio to a fixed length
        fixed_length = int(self.TARGET_SAMPLE_RATE * self.clip_seconds)
        if audio.shape[1] < fixed_length:
            audio = torch.nn.functional.pad(audio, (0, fixed_length - audio.shape[1]))
        else:
            audio = audio[:, :fixed_length]

        if self.mode == 'mel_spec':
            features = self.mel_spectrogram(audio)
        else:
            features = self.feature_extractor.w2v_extract_features(audio, self.TARGET_SAMPLE_RATE)

        if self.is_label:
            return features, self.labels[idx]
        # Unlabelled items carry the file name; it is needed to write the
        # test submission.
        return features, os.path.basename(audio_file)

DataLoader

In [None]:
class Data:
    """Holds the train / validation / test datasets and builds loaders for them.

    Args:
        batch_size: Batch size used by every loader.
        dataset_train: Training dataset; any dataset supporting ``len()``.
        dataset_test: Test dataset.
        do_split: When true, ``dataset_train`` is split 80/20 into train and
            validation subsets using a generator seeded with 42.

    Attributes:
        modes: Names of the available splits.
        dataloaders: Maps each split name to its *dataset*; the actual
            ``DataLoader`` is created on demand by ``get_loader``.
    """

    def __init__(self, batch_size,dataset_train,dataset_test, do_split):
        self.modes = ['train','test']
        self.dataloaders = {}
        self.batch_size = batch_size
        self.do_split = do_split
        if self.do_split:
            self.modes = ['train','validation','test']
            generator = torch.Generator().manual_seed(42)
            # len(dataset) rather than a dataset-specific attribute, so any
            # sized dataset (e.g. a TensorDataset) can be split.
            n_samples = len(dataset_train)
            train_size = int(n_samples*0.8)
            val_size = n_samples - train_size
            train_set, val_set = random_split(dataset_train, [train_size, val_size], generator=generator)

            self.dataloaders['train'] = train_set
            self.dataloaders['validation'] = val_set
        else:
            self.dataloaders['train'] = dataset_train

        self.dataloaders['test'] = dataset_test


    def get_loader(self, mode):
        """Return a new ``DataLoader`` for ``mode``; only 'train' is shuffled.

        Raises:
            KeyError: If ``mode`` is not one of the available splits.
        """
        if mode not in self.dataloaders:
            raise KeyError(f"Unknown split {mode!r}; available splits: {self.modes}")
        return DataLoader(self.dataloaders[mode], batch_size=self.batch_size, shuffle=(mode == 'train'))


Evaluator

In [None]:
class Evaluator:
    """Owns the loss criterion (binary cross-entropy) used to score predictions."""

    def __init__(self):
        # BCELoss operates on probabilities, not on raw logits.
        self.loss_fn = nn.BCELoss()

    def get_loss(self, y, y_hat):
        """Return the loss of predictions ``y_hat`` against targets ``y``.

        The criterion takes the prediction first, so the arguments are
        swapped when forwarding them.
        """
        criterion = self.loss_fn
        return criterion(y_hat, y)

Model

In [None]:
class NeuralNetwork(nn.Module):
    """Fully connected binary classifier over flattened input features.

    Args:
        input_size: Shape of one un-batched input sample; its flattened size
            sets the width of the first linear layer.

    ``forward`` returns a ``(batch, 1)`` tensor of sigmoid probabilities.
    """

    def __init__(self, input_size):
        super(NeuralNetwork, self).__init__()
        self.input_size = input_size
        # NOTE(review): conv_layer is constructed but forward() does not call
        # it (the call is commented out). It is kept so the parameter set and
        # state_dict keys stay unchanged.
        self.conv_layer = nn.Sequential(
            nn.Conv2d(1, 2, kernel_size=3, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.flatten = nn.Flatten()
        # Flattened feature count of one sample, measured on a dummy input.
        dim  = self.flatten(torch.zeros(self.input_size).unsqueeze(0)).shape[1]

        # Without activations between them, stacked Linear layers collapse
        # into a single affine map, so a ReLU follows every hidden layer.
        # The Linear layers keep the names "0".."3" so their state_dict keys
        # match those of the previous activation-free stack.
        layer_widths = [dim, 128, 32, 4, 1]
        n_linear = len(layer_widths) - 1
        self.linear_relu_stack = nn.Sequential()
        for layer_idx in range(n_linear):
            self.linear_relu_stack.add_module(
                str(layer_idx),
                nn.Linear(layer_widths[layer_idx], layer_widths[layer_idx + 1]),
            )
            if layer_idx < n_linear - 1:
                self.linear_relu_stack.add_module(f"relu{layer_idx}", nn.ReLU())

    def forward(self, x):
        #x = self.conv_layer(x)
        x = self.flatten(x)
        x = self.linear_relu_stack(x)
        # Squash to (0, 1) so the output can be fed to BCELoss.
        x = torch.sigmoid(x)
        return x


Learner

In [None]:
class Learner:
    """Bundles the network with its optimizer and a few training helpers.

    Args:
        input_size: Forwarded to ``NeuralNetwork``.
    """

    def __init__(self, input_size):
        self.model = NeuralNetwork(input_size=input_size)
        self.model.to(DEVICE)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-4)
        # Number of non-bias parameters; used to normalise the L1 penalty.
        self.nweights = sum(
            param.numel()
            for param_name, param in self.model.named_parameters()
            if 'bias' not in param_name
        )

    def predict(self, x):
        """Run the model on ``x`` and return its output."""
        return self.model(x)

    def update(self, loss):
        """Perform one optimisation step for ``loss``."""
        optimizer = self.optimizer
        optimizer.zero_grad()  # drop gradients left over from the previous step
        loss.backward()
        optimizer.step()

    def l1_loss(self):
        """Return the mean absolute value of all non-bias parameters."""
        penalty = torch.tensor(0., requires_grad=True)
        for param_name, param in self.model.named_parameters():
            if 'bias' in param_name:
                continue
            penalty = penalty + param.abs().sum()
        return penalty / self.nweights

Metrics and best models

In [None]:
class Metrics():
    """Collects per-epoch losses and metrics and remembers the best weights.

    Attributes:
        metrics_save: Maps ``'<mode>_loss'`` and ``'<mode>_<metric>'`` to
            per-epoch lists, and ``'train_best_<metric>'`` to the best value
            seen so far.
        best_models_weigths: Maps ``'train_best_<metric>'`` to a snapshot of
            the weights that produced that best value.
        metrics_names: Metric names used to build the keys above.

    NOTE(review): best values are tracked for mode ``'train'`` only, so the
    "best" model is chosen by training metrics, not validation metrics --
    confirm this is intended.
    """
    #FOR TRAIN AND VALIDATION ONLY
    def __init__(self):
        self.metrics_save = {}
        self.best_models_weigths = {}
        self.metrics_names  = ['acuracy','recall','precision','f1-score']

    @staticmethod
    def _snapshot_weights(model_weigths):
        """Return a copy of ``model_weigths`` that later training cannot alter.

        ``Module.state_dict()`` hands out references to the live parameter
        tensors; storing them as-is would make every "best" entry silently
        track the latest weights. Non-mapping values are returned unchanged.
        """
        if model_weigths is None or not hasattr(model_weigths, 'items'):
            return model_weigths
        return {
            name: value.detach().clone() if torch.is_tensor(value) else value
            for name, value in model_weigths.items()
        }

    def _record_epoch(self, mode, loss, metrics_values, model_weigths):
        """Append one epoch's loss and metric values and update the bests.

        ``metrics_values`` must be ordered like ``self.metrics_names``.
        """
        self.metrics_save.setdefault(f'{mode}_loss', []).append(loss)

        snapshot = None  # taken lazily, at most once per epoch
        for metric_name, metric_value in zip(self.metrics_names, metrics_values):
            self.metrics_save.setdefault(f'{mode}_{metric_name}', []).append(metric_value)

            # Save best metrics and their respective weights
            if mode != 'train':
                continue
            best_key = f'{mode}_best_{metric_name}'
            if best_key in self.metrics_save and not metric_value > self.metrics_save[best_key]:
                continue
            if snapshot is None:
                snapshot = self._snapshot_weights(model_weigths)
            # Update the stored best value too; otherwise every later epoch
            # would still be compared against the first epoch's value.
            self.metrics_save[best_key] = metric_value
            self.best_models_weigths[best_key] = snapshot

    def calc_metrics(self,preds,labels,mode,loss, model_weigths=None, show=False):
        """Compute accuracy, recall, precision and F1 for one epoch and store them.

        Args:
            preds: Predicted binary labels (flat or column-shaped).
            labels: Ground-truth binary labels, same layout as ``preds``.
            mode: Split name used as key prefix (e.g. ``'train'``).
            loss: Epoch loss to append to the ``'<mode>_loss'`` history.
            model_weigths: Weights to snapshot if this epoch sets a new best.
            show: When true, print a one-line summary.
        """
        # Flatten so column-shaped inputs such as [[1], [0]] are accepted.
        y_pred = np.asarray(preds).ravel()
        y_true = np.asarray(labels).ravel()

        acc = accuracy_score(y_pred=y_pred, y_true=y_true)
        recall = recall_score(y_pred=y_pred, y_true=y_true)
        precision = precision_score(y_pred=y_pred, y_true=y_true)
        f1 = f1_score(y_pred=y_pred, y_true=y_true,average='binary')

        if show:
            print(f"{mode} -  Acuracy: {acc:.2f} - Recall {recall:.2f} - Precision {precision:.2f} - F1-Score {f1:.2f} - Loss {loss:.2f}")

        self._record_epoch(mode, loss, [acc, recall, precision, f1], model_weigths)


    def plot_metrics(self):
        """Plot train vs. validation curves for each metric on a 2x2 grid."""
        fig, axs = plt.subplots(2, 2, figsize=(10, 10))
        metrics = self.metrics_names.copy()
        modes = ['train','validation']
        for i,metric_name in enumerate(metrics):
            metric_train  = self.metrics_save[f'{modes[0]}_{metric_name}']
            metric_validation  = self.metrics_save[f'{modes[1]}_{metric_name}']

            ii = i % 2
            jj = i // 2
            axs[jj,ii].plot(metric_train, label=modes[0])
            axs[jj,ii].plot(metric_validation, label=modes[1])

            axs[jj,ii].set_title(f'{metric_name}: {modes[0]} x {modes[1]}')
            axs[jj,ii].set_xlabel('Epoch')
            axs[jj,ii].set_ylabel(f'{metric_name}')
            axs[jj,ii].legend()

        plt.tight_layout()

        # Show the plot
        plt.show()

    def plot_loss(self):
        """Plot train vs. validation loss per epoch on a logarithmic y axis."""
        modes = ['train','validation']
        metric_name = 'loss'
        metric_train  = self.metrics_save[f'{modes[0]}_{metric_name}']
        metric_validation  = self.metrics_save[f'{modes[1]}_{metric_name}']
        fig, axs = plt.subplots(1, 1, figsize=(12, 6))
        axs.set_title(f'Log scale {metric_name}: {modes[0]} x {modes[1]}')
        axs.plot(metric_train, label=modes[0])
        axs.plot(metric_validation, label=modes[1])
        # Plot the raw losses on a log axis, as the title says; exponentiating
        # them would do the opposite.
        axs.set_yscale('log')
        axs.set_xlabel('Epoch')
        axs.set_ylabel(f'{metric_name}')
        axs.legend()
        plt.show()

    def _find_best_key(self, metric):
        """Return the first stored key containing ``metric`` (case-insensitive), or None."""
        wanted = metric.lower()
        for key in self.best_models_weigths.keys():
            if wanted in key.lower():
                return key
        return None

    def get_best_model(self, metric):
        """Return the best weights stored for ``metric``, or None if there are none."""
        key = self._find_best_key(metric)
        if key is None:
            return None
        return self.best_models_weigths[key]

    def save_best_model(self,all_metrics, metric='F1-Score'):
        """Write stored best weights to ``'<key>.pt'`` files in the working directory.

        Args:
            all_metrics: When true, save the best weights of every metric.
            metric: Otherwise, save only the first entry whose key contains
                this name (matched case-insensitively).
        """
        if all_metrics:
            print("Saving all models")
            keys_to_save = list(self.best_models_weigths.keys())
        else:
            print(f"Saving best model for {metric}")
            key = self._find_best_key(metric)
            keys_to_save = [] if key is None else [key]

        for key in keys_to_save:
            # torch.save takes the object first and the destination second.
            torch.save(self.best_models_weigths[key], f'{key}.pt')
            print(f"Save model at: {key}.pt")


Trainer

In [None]:
class Trainer:
    """Drives training, validation and test-time prediction.

    NOTE(review): this class has the same name as ``Trainer`` imported from
    ``transformers`` at the top of the notebook and replaces it once this
    cell runs.

    Args:
        data: Provides ``get_loader(mode)`` and the list of splits ``modes``.
        learner: Owns the model (``predict`` / ``update``).
        evaluator: Provides ``get_loss(y, y_hat)``.
        metrics: Receives the per-epoch predictions and stores best weights.
    """

    def __init__(self, data: Data, learner: Learner, evaluator: Evaluator, metrics: Metrics):
        self.data = data
        self.learner = learner
        self.metrics = metrics
        self.evaluator = evaluator

    def one_epoch(self, mode):
        """Run one pass over the ``mode`` split and record its metrics.

        Weights are updated only when ``mode == 'train'``; for every other
        mode the model is put in eval mode and gradients are disabled.
        """
        is_training = mode == 'train'
        self.learner.model.train(is_training)

        dataloader = self.data.get_loader(mode)
        preds = []
        labels = []
        epoch_loss = 0

        # Gradient tracking is decided here, not left to the caller.
        with torch.set_grad_enabled(is_training):
            for (X, y) in tqdm.tqdm(dataloader):
                # Targets become float columns to match the model output.
                X, y = X.to(DEVICE), y.to(DEVICE).float().unsqueeze(1)

                y_hat = self.learner.predict(X)

                loss = self.evaluator.get_loss(y, y_hat)

                if is_training:
                    #L1_term  = self.learner.l1_loss()
                    #loss  = loss - (L1_term * 0.1)
                    self.learner.update(loss)

                epoch_loss += loss.item()

                labels.extend(y.int().flatten().tolist())
                # Probabilities above 0.5 count as the positive class.
                preds.extend((y_hat > 0.5).int().flatten().tolist())

        # max(..., 1) avoids a ZeroDivisionError on an empty loader.
        epoch_loss /= max(len(dataloader), 1)

        #preds,labels,mode,loss, model_weigths=None, show=False
        self.metrics.calc_metrics(preds=preds, labels=labels, mode=mode, loss=epoch_loss, model_weigths=self.learner.model.state_dict(), show=True)

    def test(self,mode,name_test, model_weigths=None):
        """Predict on the ``mode`` split and write ``'<name_test>.csv'``.

        Args:
            mode: Split to predict on; its items must be ``(features, id)``.
            name_test: Output file name without the ``.csv`` extension.
            model_weigths: State dict to load first. When ``None`` the
                model's current weights are used.
        """
        if model_weigths is not None:
            self.learner.model.load_state_dict(model_weigths)
        else:
            print("No weights supplied; predicting with the model's current weights")
        self.learner.model.train(False)
        dataloader = self.data.get_loader(mode)
        preds = []
        ids = []
        with torch.no_grad():
            for (X, x_id) in tqdm.tqdm(dataloader):
                X = X.to(DEVICE)
                y_hat = self.learner.predict(X)
                ids.extend(x_id)
                preds.extend((y_hat).float().tolist())

        # The context manager guarantees the file is flushed and closed.
        with open(f'{name_test}.csv','w') as file_test_submtion:
            file_test_submtion.write('id,fake_prob\n')
            for idx,pred in zip(ids,preds):
                file_test_submtion.write(f"{idx},{pred[0]}\n")
        print(f"Test submission for {name_test} saved at {name_test}.csv")


    def run(self, n_epochs: int):
        """Train for ``n_epochs``, validating after each epoch when possible."""
        # Without a train/validation split there is no 'validation' loader.
        has_validation = 'validation' in self.data.modes
        print("Starting training")
        for t in range(n_epochs):
            print(f"Epoch {t+1}\n-------------------------------")
            self.one_epoch(mode='train')

            if has_validation:
                self.one_epoch(mode='validation')
        print("Training done")

    def run_test(self,name_test):
        """Write test predictions using the weights with the best F1 score."""
        #Keep test at the end of training
        metric = 'f1-score'
        print(f"Generating test probs with {metric} best model:")
        with torch.no_grad():
            best_model = self.metrics.get_best_model(metric=metric)
            self.test(mode='test', name_test=name_test, model_weigths=best_model)


Run training

Instances

In [None]:
# Wav2Vec2 feature extractor built from the 'facebook/wav2vec2-base'
# checkpoint; the same instance is shared by the train and test datasets below.
feature_extractor = ExtractFeatures(model_path='facebook/wav2vec2-base')

In [None]:
# Datasets: labelled training clips and unlabelled test clips, both converted
# to wav2vec2 embeddings by the shared feature extractor.
# The audio root is relative to the notebook (the same '../audios' location
# the earlier cells use) instead of an absolute path inside one user's home.
AUDIO_ROOT = os.path.join('..', 'audios')
audio_train = AudioDataset(data_dir=os.path.join(AUDIO_ROOT, 'train'), is_label=True, mode='w2v_feature', feature_extractor=feature_extractor)
audio_teste = AudioDataset(data_dir=os.path.join(AUDIO_ROOT, 'test'), is_label=False, mode='w2v_feature', feature_extractor=feature_extractor)


In [None]:
# Data container: batches of 50; do_split=True splits the training set 80/20
# into train and validation subsets.
data =Data(batch_size=50, dataset_train=audio_train, dataset_test=audio_teste, do_split=True)


In [None]:
# Loss wrapper (binary cross-entropy).
evaluator = Evaluator()

In [None]:
# Model + optimizer. The input size is the feature shape of the first training
# sample, so this line loads that clip and runs the feature extractor on it.
learner = Learner(input_size=audio_train[0][0].shape)

In [None]:
# Accumulates per-epoch losses/metrics and the best-performing weights.
metrics = Metrics()

Training

In [None]:
# Train for 30 epochs (a validation pass follows each one), then write the
# test-set probabilities to 'test_w2v_features_extractor_dropout.csv' using
# the weights stored for the best f1-score.
trainer = Trainer(data=data, evaluator=evaluator, learner=learner, metrics=metrics)
trainer.run(n_epochs=30)
trainer.run_test(name_test='test_w2v_features_extractor_dropout')



In [None]:
# Train vs. validation curves for each tracked metric, one subplot per metric.
metrics.plot_metrics()

In [None]:
# Train vs. validation loss per epoch.
metrics.plot_loss()

In [None]:
# Save the best model stored for every tracked metric (not the metric values).
metrics.save_best_model(all_metrics=True)