In [12]:
!pip3 install -U transformers



In [23]:
import torch
print(torch.cuda.is_available())

from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

import torchvision
import torchvision.transforms as transforms
from torchvision import models

import numpy as np
import matplotlib.pyplot as plt
import librosa
import os
from tqdm import tqdm
from transformers import Wav2Vec2Model, Wav2Vec2Processor
import soundfile as sf

True


# Load audio

In [6]:
class WaveformDataset(Dataset):
    """In-memory dataset of fixed-length raw waveforms with binary real/fake labels.

    Every ``.wav`` file under ``<root>/<split dir>/real`` is loaded with label 1
    and every one under ``<root>/<split dir>/fake`` with label 0. Each clip is
    converted to mono and zero-padded or truncated to
    ``int(fixed_seconds * sample_rate)`` samples, where ``sample_rate`` is the
    file's own rate as reported by ``soundfile``.

    Args:
        loader_type: "train", "validation" or "test" (Fake-or-Real rerecorded
            splits) or "ITWFull" (In-the-Wild release).
        root: directory containing the ``data/`` tree; defaults to the current
            working directory.
        fixed_seconds: clip length in seconds after padding/truncation.

    Raises:
        ValueError: if ``loader_type`` is not a known split.

    Items are ``(waveform, label)``: ``waveform`` is a float32 array of shape
    ``(1, fixed_length)``; ``label`` is 1 for real and 0 for fake.

    NOTE(review): clip length depends on each file's sample rate, while
    ``collate_fn`` tells the processor 16 kHz -- confirm all files are 16 kHz.
    """

    # split name -> data directory, relative to ``root``
    SPLIT_DIRS = {
        "train": "data/for-rerecorded/training",
        "validation": "data/for-rerecorded/validation",
        "test": "data/for-rerecorded/testing",
        "ITWFull": "data/release_in_the_wild/",
    }
    SUFFIX = ".wav"

    def __init__(self, loader_type, root=None, fixed_seconds=2):
        if loader_type not in self.SPLIT_DIRS:
            raise ValueError(
                f"Unknown loader_type {loader_type!r}; expected one of {sorted(self.SPLIT_DIRS)}"
            )
        root = os.getcwd() if root is None else root
        data_root = os.path.join(root, self.SPLIT_DIRS[loader_type])
        real_folder = os.path.join(data_root, "real")
        fake_folder = os.path.join(data_root, "fake")
        self.fixed_seconds = fixed_seconds

        self.real_files = self._list_wavs(real_folder)
        print(f"Real examples for raw waveform {loader_type}: {len(self.real_files)}")

        self.fake_files = self._list_wavs(fake_folder)
        print(f"Fake examples for raw waveform {loader_type}: {len(self.fake_files)}")

        # Load the raw waveforms (loader adapted from the RawNet implementation).
        # Each label reads its OWN file list: real -> 1, fake -> 0.
        self.data = [(self._load_clip(path), 1) for path in self.real_files]
        self.data += [(self._load_clip(path), 0) for path in self.fake_files]

    @classmethod
    def _list_wavs(cls, folder):
        """Return the sorted full paths of regular files in ``folder`` ending in ``.wav``."""
        # sorted() makes dataset order independent of filesystem listing order
        paths = (os.path.join(folder, name) for name in sorted(os.listdir(folder)))
        return [path for path in paths if path.endswith(cls.SUFFIX) and os.path.isfile(path)]

    def _load_clip(self, path):
        """Read one audio file and return it as a float32 mono array of shape (1, fixed_length)."""
        X, sr = sf.read(path)
        X = np.asarray(X, dtype=np.float64)
        if X.ndim > 1:
            # soundfile returns (frames, channels) for multi-channel audio; average
            # the channels rather than interleaving them in the reshape below.
            X = X.mean(axis=1)
        X = X.reshape(1, -1)

        fixed_length = int(self.fixed_seconds * sr)
        n_samples = X.shape[1]
        if n_samples > fixed_length:
            X = X[:, :fixed_length]
        elif n_samples < fixed_length:
            X = np.pad(X, ((0, 0), (0, fixed_length - n_samples)), mode="constant")
        return X.astype(np.float32)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(waveform, label)`` pair at ``idx``."""
        return self.data[idx]

# Define model

In [31]:
class Wav2VecClassifier(nn.Module):
    """Binary classifier head on top of a pretrained wav2vec 2.0 encoder.

    The encoder's last hidden state is mean-pooled over time and mapped to a
    single logit (pair with ``BCEWithLogitsLoss``; apply ``sigmoid`` for a score).

    Args:
        hidden_dim: width of the encoder hidden state fed to the head
            (768 for the default base checkpoint).
        model_name: Hugging Face checkpoint the encoder is loaded from.
        dropout: dropout probability inside the classification head.
    """

    def __init__(self, hidden_dim=768, model_name="facebook/wav2vec2-base-960h", dropout=0.2):
        super().__init__()
        self.wav2vec = Wav2Vec2Model.from_pretrained(model_name, use_safetensors=False)
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, 1),
        )

    def forward(self, input_values, attention_mask=None):
        """Return one logit per example, shape ``(batch, 1)``.

        ``attention_mask`` is optional: processors for base wav2vec2 checkpoints
        do not produce one, so callers passing ``**processor_output`` must work
        without it.
        """
        outputs = self.wav2vec(input_values=input_values, attention_mask=attention_mask)
        # average over the time axis; padded frames (if any) are included in the mean
        pooled = outputs.last_hidden_state.mean(dim=1)
        return self.classifier(pooled)

# Training Procedures

In [32]:
# Processor for the same checkpoint the model's encoder is loaded from;
# collate_fn calls it to pad each batch into model-ready tensors.
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")

def collate_fn(batch, sampling_rate=16000):
    """Collate ``(waveform, label)`` pairs into processor inputs and a label tensor.

    Args:
        batch: sequence of ``(waveform, label)`` items; each waveform is a
            ``(1, n)`` array as produced by ``WaveformDataset``.
        sampling_rate: rate passed to the processor (16 kHz for the default
            wav2vec2 checkpoint).

    Returns:
        ``(inputs, labels)``, where ``inputs`` is the processor's padded output
        (a mapping of ``forward`` keyword arguments) and ``labels`` is a float
        tensor of shape ``(batch, 1)``.
    """
    waveforms, labels = zip(*batch)
    # Hand the processor 1-D float32 numpy arrays: its batch detection looks for
    # numpy arrays / lists, and torch tensors are not recognised as a batch.
    speech = [np.asarray(w, dtype=np.float32).reshape(-1) for w in waveforms]

    inputs = processor(speech, sampling_rate=sampling_rate, return_tensors="pt", padding=True)
    labels = torch.tensor(labels, dtype=torch.float).unsqueeze(1)
    return inputs, labels
    

In [33]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed so head initialisation, dropout and shuffling are reproducible.
SEED = 42
torch.manual_seed(SEED)

model = Wav2VecClassifier().to(device)

# Training hyper-parameters. `epochs` is the maximum training budget (early
# stopping may end sooner) and also the cosine schedule's T_max, so the learning
# rate decays monotonically instead of rising again after T_max epochs.
epochs = 30
batch_size = 32
weight_decay = 5e-4
learning_rate = 1e-4
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

FoR_train_dataset = WaveformDataset("train")
FoR_val_dataset = WaveformDataset("validation")
FoR_test_dataset = WaveformDataset("test")
# Only the training loader needs shuffling; evaluation metrics do not depend on order.
FoR_train_loader = DataLoader(FoR_train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
FoR_val_loader = DataLoader(FoR_val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
FoR_test_loader = DataLoader(FoR_test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

NameError: name 'init_empty_weights' is not defined

In [None]:
def _roc_points(labels, scores):
    """Return ``(fpr, tpr)`` arrays evaluated at every distinct score threshold.

    Label 1 is the positive class. A sample counts as predicted positive when
    its score is >= the threshold. Thresholds run from highest to lowest, with
    a leading (0, 0) point for "nothing predicted positive". Both classes must
    be present.
    """
    labels = np.asarray(labels, dtype=np.float64).ravel()
    scores = np.asarray(scores, dtype=np.float64).ravel()
    order = np.argsort(-scores, kind="mergesort")
    labels, scores = labels[order], scores[order]
    # last index of each run of tied scores -> one ROC point per distinct threshold
    cut = np.r_[np.flatnonzero(np.diff(scores)), scores.size - 1]
    cum_pos = np.cumsum(labels)[cut]
    tps = np.r_[0.0, cum_pos]
    fps = np.r_[0.0, (cut + 1) - cum_pos]
    return fps / fps[-1], tps / tps[-1]


def compute_EER(model, loader, device=None):
    """Compute the equal error rate of ``model`` on ``loader``.

    Scores are ``sigmoid(logit)`` with label 1 (real) as the positive class.
    The EER is taken at the ROC point where the false-positive and
    false-negative rates are closest, averaging the two rates there.

    Args:
        model: module called as ``model(**inputs)`` returning one logit per example.
        loader: iterable of ``(inputs, labels)`` batches as produced by
            ``collate_fn`` (``inputs`` is a mapping of tensors).
        device: where to move the inputs; defaults to the device of the
            model's first parameter.

    Returns:
        The EER as a float in [0, 1].

    Raises:
        ValueError: if the loader does not contain both label 0 and label 1.
    """
    if device is None:
        device = next(model.parameters()).device
    model.eval()
    all_scores = []
    all_labels = []

    with torch.no_grad():
        for inputs, labels in loader:
            inputs = {key: value.to(device) for key, value in inputs.items()}
            scores = torch.sigmoid(model(**inputs))
            all_scores.append(scores.reshape(-1).cpu().numpy())
            all_labels.append(labels.reshape(-1).cpu().numpy())

    labels_arr = np.concatenate(all_labels) if all_labels else np.empty(0)
    scores_arr = np.concatenate(all_scores) if all_scores else np.empty(0)
    is_pos = (labels_arr == 1).astype(np.float64)
    n_pos = int(is_pos.sum())
    if n_pos == 0 or n_pos == is_pos.size:
        raise ValueError("EER needs at least one real (1) and one fake (0) example")

    fpr, tpr = _roc_points(is_pos, scores_arr)
    fnr = 1 - tpr
    # threshold where FPR and FNR are closest
    eer_idx = np.nanargmin(np.abs(fpr - fnr))
    return (fpr[eer_idx] + fnr[eer_idx]) / 2

In [None]:
def train(loader):
    """Run one training epoch over ``loader`` and return the mean batch loss.

    Uses the notebook-level ``model``, ``optimizer``, ``criterion`` and ``device``.
    Each batch is ``(inputs, labels)`` as produced by ``collate_fn``: a mapping
    of ``forward`` keyword arguments and a ``(batch, 1)`` float label tensor.
    """
    model.train()
    total_loss = 0.0

    for inputs, labels in loader:
        # the processor output is a mapping of forward() kwargs; move each tensor
        inputs = {key: value.to(device) for key, value in inputs.items()}
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = model(**inputs)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(loader)

In [None]:
def validate(loader):
    """Evaluate the notebook-level ``model`` on ``loader`` without gradient tracking.

    Uses the notebook-level ``criterion`` and ``device``. Batches are
    ``(inputs, labels)`` as produced by ``collate_fn``.

    Returns:
        ``(mean batch loss, accuracy)``. A logit > 0 counts as predicting
        label 1 (real).
    """
    model.eval()
    total_loss = 0.0
    n_correct = 0
    n_total = 0

    with torch.no_grad():
        for inputs, labels in loader:
            inputs = {key: value.to(device) for key, value in inputs.items()}
            labels = labels.to(device)

            logits = model(**inputs)
            total_loss += criterion(logits, labels).item()

            preds = (logits > 0).to(labels.dtype)
            n_correct += (preds == labels).sum().item()
            # count elements, matching the element-wise correct count above
            n_total += labels.numel()

    return total_loss / len(loader), n_correct / n_total
            

In [None]:
# reference paper uses patience = 5
# Early stopping on validation loss; the reference paper uses patience = 5.
patience = 5
# Upper bound on epochs. NOTE: the scheduler's T_max comes from the setup cell;
# if it is smaller than this, the cosine learning rate rises again after T_max.
epochs = 30
best_validation_loss = float("inf")
best_state = None
fail_count = 0

training_losses = []
val_losses = []
test_losses = []

for epoch in tqdm(range(epochs)):
    training_loss = train(FoR_train_loader)
    print(f"[Epoch {epoch}] Training Loss: {training_loss}")
    training_losses.append(training_loss)

    validation_loss, val_accuracy = validate(FoR_val_loader)
    print(f"[Epoch {epoch}] Validation Loss: {validation_loss} Accuracy: {val_accuracy}")
    val_losses.append(validation_loss)

    # Test metrics are logged for monitoring only; they must not drive model selection.
    test_loss, test_accuracy = validate(FoR_test_loader)
    print(f"[DEBUG Epoch {epoch}] Test Loss: {test_loss} Accuracy: {test_accuracy}")
    test_losses.append(test_loss)

    if validation_loss < best_validation_loss:
        best_validation_loss = validation_loss
        fail_count = 0
        # snapshot the best weights on CPU so we can roll back after early stopping
        best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
    else:
        # one more epoch without improvement
        fail_count += 1

    if fail_count >= patience:
        print(f"Triggering early breaking on epoch {epoch}")
        break

    scheduler.step()

# Restore the weights from the epoch with the lowest validation loss rather than
# keeping the last (post-plateau) epoch's weights.
if best_state is not None:
    model.load_state_dict(best_state)