In [2]:
import os
import torch
import torchaudio
import numpy as np
from torch import nn
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
from transformers import Wav2Vec2Processor, Wav2Vec2ForSequenceClassification
from sklearn.metrics import accuracy_score, f1_score
from tqdm import tqdm
import matplotlib.pyplot as plt
from torch.cuda.amp import GradScaler, autocast

class AudioDeepfakeDataset(Dataset):
    """Fixed-length mono waveforms labelled real (0) or fake (1).

    Every ``*.wav`` file directly inside each directory of ``data_dirs`` is one
    sample. A directory whose last path component contains ``"real"``
    (case-insensitive) is labelled 0; every other directory is labelled 1.
    """

    def __init__(self, data_dirs, sample_rate=16000, max_length=4.0):
        """Index the audio files found in ``data_dirs``.

        Args:
            data_dirs: Iterable of directory paths, scanned non-recursively.
            sample_rate: Rate in Hz that every clip is resampled to.
            max_length: Clip duration in seconds; clips are cropped or
                zero-padded to ``int(max_length * sample_rate)`` samples.

        Raises:
            AssertionError: If no ``*.wav`` file is found in any directory.
        """
        self.data_dirs = data_dirs
        self.sample_rate = sample_rate
        self.max_length = max_length
        self.max_samples = int(max_length * sample_rate)

        self.audio_files = []
        self.labels = []
        # Resample transforms keyed by source rate, so one is built per
        # distinct rate instead of once per loaded item.
        self._resamplers = {}

        for data_dir in data_dirs:
            data_dir = Path(data_dir)
            label = 0 if 'real' in data_dir.name.lower() else 1
            # glob() yields entries in filesystem order; sorting keeps the
            # index -> file mapping stable across machines and runs.
            for audio_file in sorted(data_dir.glob('*.wav')):
                self.audio_files.append(str(audio_file))
                self.labels.append(label)

        assert len(self.audio_files) > 0, "No audio files found in the provided directories."

    def __len__(self):
        """Return the number of indexed audio files."""
        return len(self.audio_files)

    def _get_resampler(self, orig_sample_rate):
        """Return the cached transform converting ``orig_sample_rate`` to ``self.sample_rate``."""
        resampler = self._resamplers.get(orig_sample_rate)
        if resampler is None:
            resampler = torchaudio.transforms.Resample(orig_sample_rate, self.sample_rate)
            self._resamplers[orig_sample_rate] = resampler
        return resampler

    @staticmethod
    def _normalize(waveform):
        """Scale ``waveform`` to zero mean and (approximately) unit variance."""
        # std() of fewer than two values is NaN, which would turn the whole
        # sample into NaNs; such a clip carries no signal, so emit zeros.
        if waveform.numel() < 2:
            return torch.zeros_like(waveform)
        return (waveform - waveform.mean()) / (waveform.std() + 1e-6)

    def _fit_length(self, waveform):
        """Crop or right-pad a ``(1, num_samples)`` waveform to ``self.max_samples``."""
        num_samples = waveform.shape[1]
        if num_samples > self.max_samples:
            return waveform[:, :self.max_samples]
        if num_samples < self.max_samples:
            padding = waveform.new_zeros((1, self.max_samples - num_samples))
            return torch.cat([waveform, padding], dim=1)
        return waveform

    def __getitem__(self, idx):
        """Load, resample, mono-mix, normalise and length-fix one clip.

        Returns:
            Tuple ``(waveform, label)``; ``waveform`` has the leading channel
            dimension squeezed away and ``label`` is the int assigned in
            ``__init__``.
        """
        audio_path = self.audio_files[idx]
        label = self.labels[idx]

        # assumes torchaudio.load returns (channels, samples) -- TODO confirm for the backend in use
        waveform, orig_sample_rate = torchaudio.load(audio_path)

        if orig_sample_rate != self.sample_rate:
            waveform = self._get_resampler(orig_sample_rate)(waveform)

        # Average the channels of multi-channel audio into a single one.
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        # Statistics are taken over the whole clip, before cropping/padding.
        waveform = self._normalize(waveform)
        waveform = self._fit_length(waveform)

        return waveform.squeeze(0), label

def collate_fn(batch):
    """Merge ``(waveform, label)`` samples into a pair of batch tensors.

    Args:
        batch: Sequence of ``(waveform, label)`` pairs whose waveforms all
            have the same shape.

    Returns:
        Tuple ``(waveforms, labels)``: the waveforms stacked along a new
        leading dimension, and the labels as a ``torch.long`` tensor.
    """
    clips, targets = zip(*batch)
    stacked_clips = torch.stack(clips)
    target_tensor = torch.tensor(targets, dtype=torch.long)
    return stacked_clips, target_tensor

def get_dataloaders(train_dirs, val_dirs, test_dirs, batch_size=16, num_workers=8):
    """Build the train, validation and test loaders.

    Args:
        train_dirs: Directories passed to the training ``AudioDeepfakeDataset``.
        val_dirs: Directories passed to the validation dataset.
        test_dirs: Directories passed to the test dataset.
        batch_size: Batch size shared by the three loaders.
        num_workers: Worker-process count shared by the three loaders.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``. Only the training
        loader shuffles.
    """
    # All three datasets are created first, so a missing split fails before
    # any loader is constructed.
    datasets = [AudioDeepfakeDataset(dirs) for dirs in (train_dirs, val_dirs, test_dirs)]

    loader_options = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=True,
    )
    shuffle_flags = (True, False, False)

    loaders = [
        DataLoader(dataset, shuffle=shuffle, **loader_options)
        for dataset, shuffle in zip(datasets, shuffle_flags)
    ]
    return tuple(loaders)

class AudioDeepfakeModel(nn.Module):
    """Wav2Vec2 sequence classifier whose convolutional feature encoder is frozen."""

    def __init__(self, model_name="facebook/wav2vec2-base", num_labels=2):
        """Load the pretrained checkpoint and freeze its feature encoder.

        Args:
            model_name: Checkpoint name or path given to ``from_pretrained``.
            num_labels: Number of output classes of the classification head.
        """
        super().__init__()
        self.wav2vec2 = Wav2Vec2ForSequenceClassification.from_pretrained(
            model_name,
            num_labels=num_labels
        )
        # Freeze through the library hook rather than by hand: per the
        # transformers documentation it disables gradient computation for the
        # feature encoder so its parameters are not updated during training.
        # A one-off .eval() call here would be reverted by the first
        # model.train() call, so it is not used.
        self.wav2vec2.freeze_feature_encoder()

    def forward(self, input_values, labels=None):
        """Run the wrapped classifier.

        Args:
            input_values: Batch of raw waveforms passed to the wrapped model.
            labels: Optional class indices; forwarded so the wrapped model can
                also return a loss.

        Returns:
            The wrapped model's output object (the training loop reads its
            ``.loss`` and ``.logits``).
        """
        return self.wav2vec2(input_values, labels=labels)

def compute_metrics(labels, preds):
    """Return accuracy and binary F1 for the given labels and predictions.

    Args:
        labels: Ground-truth class indices.
        preds: Predicted class indices, aligned with ``labels``.

    Returns:
        Dict with the keys ``"accuracy"`` and ``"f1"``.
    """
    return {
        "accuracy": accuracy_score(labels, preds),
        "f1": f1_score(labels, preds, average='binary'),
    }

def train_model(model, train_loader, val_loader, output_dir, num_epochs=15, patience=5, accum_steps=2):
    """Fine-tune ``model`` with mixed precision, gradient accumulation and early stopping.

    The checkpoint with the best validation F1 (and the processor) is saved to
    ``<output_dir>/best_model``. Training stops early once the validation loss
    has not improved for ``patience`` consecutive epochs.

    Args:
        model: Module called as ``model(waveforms, labels=labels)`` returning an
            object with ``.loss`` and ``.logits``; must expose
            ``.wav2vec2.save_pretrained``.
        train_loader: Sized iterable of ``(waveforms, labels)`` training batches.
        val_loader: Sized iterable of ``(waveforms, labels)`` validation batches.
        output_dir: Directory under which ``best_model`` is written.
        num_epochs: Maximum number of epochs.
        patience: Epochs without validation-loss improvement tolerated before stopping.
        accum_steps: Number of batches accumulated per optimizer step.

    Returns:
        Tuple ``(train_losses, train_accuracies, val_losses, val_accuracies)``,
        each a list with one entry per completed epoch.
    """
    processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # Autocast and loss scaling are only switched on when running on a GPU.
    use_amp = device.type == "cuda"

    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.1, total_iters=500)
    # torch.amp.* replaces the deprecated torch.cuda.amp.* entry points.
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)
    best_model_dir = os.path.join(output_dir, "best_model")

    train_losses, val_losses = [], []
    train_accuracies, val_accuracies = [], []
    best_f1 = 0
    best_val_loss = float('inf')
    epochs_no_improve = 0

    for epoch in range(num_epochs):
        model.train()
        train_loss, train_correct, train_total = 0.0, 0, 0
        train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]")

        optimizer.zero_grad()
        for i, (waveforms, labels) in enumerate(train_pbar):
            waveforms = waveforms.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            with torch.autocast(device_type=device.type, enabled=use_amp):
                outputs = model(waveforms, labels=labels)
                # Scale down so the accumulated gradient matches one large batch.
                loss = outputs.loss / accum_steps

            scaler.scale(loss).backward()

            # Step every accum_steps batches and on the final (possibly short) group.
            if (i + 1) % accum_steps == 0 or (i + 1) == len(train_loader):
                scale_before = scaler.get_scale()
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                # The scaler backs its scale off when it skipped the step due to
                # inf/NaN gradients; only advance the LR schedule when the
                # optimizer really stepped.
                if scaler.get_scale() >= scale_before:
                    scheduler.step()

            train_loss += loss.item() * accum_steps
            preds = outputs.logits.argmax(dim=-1)
            train_correct += (preds == labels).sum().item()
            train_total += labels.size(0)

            # train_loss is a sum of per-batch means, so average over batches
            # (not samples) to show the same quantity as the epoch summary.
            train_pbar.set_postfix({
                "loss": f"{train_loss/(i + 1):.4f}",
                "acc": f"{train_correct/train_total:.4f}"
            })

        train_loss /= len(train_loader)
        train_accuracy = train_correct / train_total
        train_losses.append(train_loss)
        train_accuracies.append(train_accuracy)

        model.eval()
        val_loss, val_correct, val_total = 0.0, 0, 0
        val_preds, val_labels = [], []
        val_pbar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Val]")

        with torch.no_grad():
            for batch_count, (waveforms, labels) in enumerate(val_pbar, start=1):
                waveforms = waveforms.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)

                with torch.autocast(device_type=device.type, enabled=use_amp):
                    outputs = model(waveforms, labels=labels)
                    loss = outputs.loss

                val_loss += loss.item()
                preds = outputs.logits.argmax(dim=-1)
                val_correct += (preds == labels).sum().item()
                val_total += labels.size(0)
                val_preds.extend(preds.cpu().numpy())
                val_labels.extend(labels.cpu().numpy())

                val_pbar.set_postfix({
                    "loss": f"{val_loss/batch_count:.4f}",
                    "acc": f"{val_correct/val_total:.4f}"
                })

        val_loss /= len(val_loader)
        val_accuracy = val_correct / val_total
        val_losses.append(val_loss)
        val_accuracies.append(val_accuracy)

        metrics = compute_metrics(val_labels, val_preds)
        val_f1 = metrics["f1"]

        # Always save after the first epoch so a checkpoint exists even if F1 is 0.
        if epoch == 0 or val_f1 > best_f1:
            best_f1 = val_f1
            model.wav2vec2.save_pretrained(best_model_dir)
            processor.save_pretrained(best_model_dir)

        # Report before the early-stopping check so the final epoch is logged too.
        print(f"Epoch {epoch+1}: Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.4f}, "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.4f}, Val F1: {val_f1:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print(f"Early stopping triggered after {epoch+1} epochs.")
                break

    return train_losses, train_accuracies, val_losses, val_accuracies

def plot_metrics(train_losses, train_accuracies, val_losses, val_accuracies, output_dir):
    """Save side-by-side loss and accuracy curves to ``metrics_plot.png``.

    Args:
        train_losses: Per-epoch training losses.
        train_accuracies: Per-epoch training accuracies.
        val_losses: Per-epoch validation losses.
        val_accuracies: Per-epoch validation accuracies.
        output_dir: Directory the PNG file is written into.
    """
    epochs = range(1, len(train_losses) + 1)

    # (title, y-label, train series, train legend, val series, val legend)
    panels = (
        ('Training and Validation Loss', 'Loss',
         train_losses, 'Train Loss', val_losses, 'Val Loss'),
        ('Training and Validation Accuracy', 'Accuracy',
         train_accuracies, 'Train Accuracy', val_accuracies, 'Val Accuracy'),
    )

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    for ax, (title, y_label, train_series, train_legend, val_series, val_legend) in zip(axes, panels):
        ax.plot(epochs, train_series, 'b-', label=train_legend)
        ax.plot(epochs, val_series, 'r-', label=val_legend)
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.legend()
        ax.grid(True)

    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, 'metrics_plot.png'))
    plt.close(fig)

# Root of the pre-split dataset (<root>/<split>/<real|fake>/*.wav). Set the
# AUDIO_DETECT_DATA_ROOT environment variable to run against another copy
# without editing this cell; the default is the original studio location.
DATA_ROOT = Path(os.environ.get(
    "AUDIO_DETECT_DATA_ROOT",
    "/teamspace/studios/this_studio/audio_detect/dataset/split_data",
))

# One [real_dir, fake_dir] list per split, kept as strings.
train_dirs, val_dirs, test_dirs = (
    [str(DATA_ROOT / split / kind) for kind in ("real", "fake")]
    for split in ("train", "val", "test")
)

train_loader, val_loader, test_loader = get_dataloaders(
    train_dirs,
    val_dirs,
    test_dirs,
    batch_size=32,
    num_workers=8
)



In [3]:
# Seed before the model is built so the newly initialised classification head
# is repeatable across kernel restarts.
SEED = 42
torch.manual_seed(SEED)

model = AudioDeepfakeModel(model_name="facebook/wav2vec2-base", num_labels=2)

# Checkpoints and the metrics plot are written here.
output_dir = "saved_model"
os.makedirs(output_dir, exist_ok=True)

train_losses, train_accuracies, val_losses, val_accuracies = train_model(
    model, train_loader, val_loader, output_dir, num_epochs=15, patience=5, accum_steps=2
)

plot_metrics(train_losses, train_accuracies, val_losses, val_accuracies, output_dir)

Some weights of Wav2Vec2ForSequenceClassification were not initialized from the model checkpoint at facebook/wav2vec2-base and are newly initialized: ['classifier.bias', 'classifier.weight', 'projector.bias', 'projector.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  scaler = GradScaler()
  with autocast():
Epoch 1/15 [Train]:   0%|          | 0/1001 [00:00<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 800.00 MiB. GPU 0 has a total capacity of 14.58 GiB of which 31.62 MiB is free. Process 5551 has 2.92 GiB memory in use. Process 55485 has 10.70 GiB memory in use. Process 130195 has 940.00 MiB memory in use. Of the allocated memory 768.82 MiB is allocated by PyTorch, and 51.18 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [4]:
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import torch
import torchaudio
import numpy as np
from torch import nn
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
from transformers import Wav2Vec2Processor, Wav2Vec2ForSequenceClassification
from sklearn.metrics import accuracy_score, f1_score
from tqdm import tqdm
import matplotlib.pyplot as plt
from torch.amp import GradScaler, autocast

class AudioDeepfakeDataset(Dataset):
    """Fixed-length mono waveforms labelled real (0) or fake (1).

    Every ``*.wav`` file directly inside each directory of ``data_dirs`` is one
    sample. A directory whose last path component contains ``"real"``
    (case-insensitive) is labelled 0; every other directory is labelled 1.
    """

    def __init__(self, data_dirs, sample_rate=16000, max_length=4.0):
        """Index the audio files found in ``data_dirs``.

        Args:
            data_dirs: Iterable of directory paths, scanned non-recursively.
            sample_rate: Rate in Hz that every clip is resampled to.
            max_length: Clip duration in seconds; clips are cropped or
                zero-padded to ``int(max_length * sample_rate)`` samples.

        Raises:
            AssertionError: If no ``*.wav`` file is found in any directory.
        """
        self.data_dirs = data_dirs
        self.sample_rate = sample_rate
        self.max_length = max_length
        self.max_samples = int(max_length * sample_rate)

        self.audio_files = []
        self.labels = []
        # Resample transforms keyed by source rate, so one is built per
        # distinct rate instead of once per loaded item.
        self._resamplers = {}

        for data_dir in data_dirs:
            data_dir = Path(data_dir)
            label = 0 if 'real' in data_dir.name.lower() else 1
            # glob() yields entries in filesystem order; sorting keeps the
            # index -> file mapping stable across machines and runs.
            for audio_file in sorted(data_dir.glob('*.wav')):
                self.audio_files.append(str(audio_file))
                self.labels.append(label)

        assert len(self.audio_files) > 0, "No audio files found in the provided directories."

    def __len__(self):
        """Return the number of indexed audio files."""
        return len(self.audio_files)

    def _get_resampler(self, orig_sample_rate):
        """Return the cached transform converting ``orig_sample_rate`` to ``self.sample_rate``."""
        resampler = self._resamplers.get(orig_sample_rate)
        if resampler is None:
            resampler = torchaudio.transforms.Resample(orig_sample_rate, self.sample_rate)
            self._resamplers[orig_sample_rate] = resampler
        return resampler

    @staticmethod
    def _normalize(waveform):
        """Scale ``waveform`` to zero mean and (approximately) unit variance."""
        # std() of fewer than two values is NaN, which would turn the whole
        # sample into NaNs; such a clip carries no signal, so emit zeros.
        if waveform.numel() < 2:
            return torch.zeros_like(waveform)
        return (waveform - waveform.mean()) / (waveform.std() + 1e-6)

    def _fit_length(self, waveform):
        """Crop or right-pad a ``(1, num_samples)`` waveform to ``self.max_samples``."""
        num_samples = waveform.shape[1]
        if num_samples > self.max_samples:
            return waveform[:, :self.max_samples]
        if num_samples < self.max_samples:
            padding = waveform.new_zeros((1, self.max_samples - num_samples))
            return torch.cat([waveform, padding], dim=1)
        return waveform

    def __getitem__(self, idx):
        """Load, resample, mono-mix, normalise and length-fix one clip.

        Returns:
            Tuple ``(waveform, label)``; ``waveform`` has the leading channel
            dimension squeezed away and ``label`` is the int assigned in
            ``__init__``.
        """
        audio_path = self.audio_files[idx]
        label = self.labels[idx]

        # assumes torchaudio.load returns (channels, samples) -- TODO confirm for the backend in use
        waveform, orig_sample_rate = torchaudio.load(audio_path)

        if orig_sample_rate != self.sample_rate:
            waveform = self._get_resampler(orig_sample_rate)(waveform)

        # Average the channels of multi-channel audio into a single one.
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        # Statistics are taken over the whole clip, before cropping/padding.
        waveform = self._normalize(waveform)
        waveform = self._fit_length(waveform)

        return waveform.squeeze(0), label

def collate_fn(batch):
    """Merge ``(waveform, label)`` samples into a pair of batch tensors.

    Args:
        batch: Sequence of ``(waveform, label)`` pairs whose waveforms all
            have the same shape.

    Returns:
        Tuple ``(waveforms, labels)``: the waveforms stacked along a new
        leading dimension, and the labels as a ``torch.long`` tensor.
    """
    clips, targets = zip(*batch)
    stacked_clips = torch.stack(clips)
    target_tensor = torch.tensor(targets, dtype=torch.long)
    return stacked_clips, target_tensor

def get_dataloaders(train_dirs, val_dirs, test_dirs, batch_size=16, num_workers=8):
    """Build the train, validation and test loaders.

    Args:
        train_dirs: Directories passed to the training ``AudioDeepfakeDataset``.
        val_dirs: Directories passed to the validation dataset.
        test_dirs: Directories passed to the test dataset.
        batch_size: Batch size shared by the three loaders.
        num_workers: Worker-process count shared by the three loaders.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``. Only the training
        loader shuffles.
    """
    # All three datasets are created first, so a missing split fails before
    # any loader is constructed.
    datasets = [AudioDeepfakeDataset(dirs) for dirs in (train_dirs, val_dirs, test_dirs)]

    loader_options = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=True,
    )
    shuffle_flags = (True, False, False)

    loaders = [
        DataLoader(dataset, shuffle=shuffle, **loader_options)
        for dataset, shuffle in zip(datasets, shuffle_flags)
    ]
    return tuple(loaders)

class AudioDeepfakeModel(nn.Module):
    """Wav2Vec2 sequence classifier whose convolutional feature encoder is frozen."""

    def __init__(self, model_name="facebook/wav2vec2-base", num_labels=2):
        """Load the pretrained checkpoint, enable checkpointing and freeze the feature encoder.

        Args:
            model_name: Checkpoint name or path given to ``from_pretrained``.
            num_labels: Number of output classes of the classification head.
        """
        super().__init__()
        self.wav2vec2 = Wav2Vec2ForSequenceClassification.from_pretrained(
            model_name,
            num_labels=num_labels
        )
        # Trade recomputation in the backward pass for lower activation memory.
        self.wav2vec2.gradient_checkpointing_enable()
        # Freeze through the library hook rather than by hand: per the
        # transformers documentation it disables gradient computation for the
        # feature encoder so its parameters are not updated during training.
        # A one-off .eval() call here would be reverted by the first
        # model.train() call, so it is not used.
        self.wav2vec2.freeze_feature_encoder()

    def forward(self, input_values, labels=None):
        """Run the wrapped classifier.

        Args:
            input_values: Batch of raw waveforms passed to the wrapped model.
            labels: Optional class indices; forwarded so the wrapped model can
                also return a loss.

        Returns:
            The wrapped model's output object (the training loop reads its
            ``.loss`` and ``.logits``).
        """
        return self.wav2vec2(input_values, labels=labels)

def compute_metrics(labels, preds):
    """Return accuracy and binary F1 for the given labels and predictions.

    Args:
        labels: Ground-truth class indices.
        preds: Predicted class indices, aligned with ``labels``.

    Returns:
        Dict with the keys ``"accuracy"`` and ``"f1"``.
    """
    return {
        "accuracy": accuracy_score(labels, preds),
        "f1": f1_score(labels, preds, average='binary'),
    }

def train_model(model, train_loader, val_loader, output_dir, num_epochs=15, patience=5, accum_steps=4):
    """Fine-tune ``model`` with mixed precision, gradient accumulation and early stopping.

    The checkpoint with the best validation F1 (and the processor) is saved to
    ``<output_dir>/best_model``. Training stops early once the validation loss
    has not improved for ``patience`` consecutive epochs.

    Args:
        model: Module called as ``model(waveforms, labels=labels)`` returning an
            object with ``.loss`` and ``.logits``; must expose
            ``.wav2vec2.save_pretrained``.
        train_loader: Sized iterable of ``(waveforms, labels)`` training batches.
        val_loader: Sized iterable of ``(waveforms, labels)`` validation batches.
        output_dir: Directory under which ``best_model`` is written.
        num_epochs: Maximum number of epochs.
        patience: Epochs without validation-loss improvement tolerated before stopping.
        accum_steps: Number of batches accumulated per optimizer step.

    Returns:
        Tuple ``(train_losses, train_accuracies, val_losses, val_accuracies)``,
        each a list with one entry per completed epoch.
    """
    processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # Autocast and loss scaling are only switched on when running on a GPU.
    use_amp = device.type == "cuda"

    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.1, total_iters=500)
    scaler = GradScaler('cuda', enabled=use_amp)
    best_model_dir = os.path.join(output_dir, "best_model")

    train_losses, val_losses = [], []
    train_accuracies, val_accuracies = [], []
    best_f1 = 0
    best_val_loss = float('inf')
    epochs_no_improve = 0

    for epoch in range(num_epochs):
        model.train()
        train_loss, train_correct, train_total = 0.0, 0, 0
        train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]")

        optimizer.zero_grad()
        for i, (waveforms, labels) in enumerate(train_pbar):
            waveforms = waveforms.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            with autocast(device.type, enabled=use_amp):
                outputs = model(waveforms, labels=labels)
                # Scale down so the accumulated gradient matches one large batch.
                loss = outputs.loss / accum_steps

            scaler.scale(loss).backward()

            # Step every accum_steps batches and on the final (possibly short) group.
            if (i + 1) % accum_steps == 0 or (i + 1) == len(train_loader):
                scale_before = scaler.get_scale()
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                # The scaler backs its scale off when it skipped the step due to
                # inf/NaN gradients; only advance the LR schedule when the
                # optimizer really stepped.
                if scaler.get_scale() >= scale_before:
                    scheduler.step()

            train_loss += loss.item() * accum_steps
            preds = outputs.logits.argmax(dim=-1)
            train_correct += (preds == labels).sum().item()
            train_total += labels.size(0)

            # train_loss is a sum of per-batch means, so average over batches
            # (not samples) to show the same quantity as the epoch summary.
            train_pbar.set_postfix({
                "loss": f"{train_loss/(i + 1):.4f}",
                "acc": f"{train_correct/train_total:.4f}"
            })

        train_loss /= len(train_loader)
        train_accuracy = train_correct / train_total
        train_losses.append(train_loss)
        train_accuracies.append(train_accuracy)

        model.eval()
        val_loss, val_correct, val_total = 0.0, 0, 0
        val_preds, val_labels = [], []
        val_pbar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Val]")

        with torch.no_grad():
            for batch_count, (waveforms, labels) in enumerate(val_pbar, start=1):
                waveforms = waveforms.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)

                with autocast(device.type, enabled=use_amp):
                    outputs = model(waveforms, labels=labels)
                    loss = outputs.loss

                val_loss += loss.item()
                preds = outputs.logits.argmax(dim=-1)
                val_correct += (preds == labels).sum().item()
                val_total += labels.size(0)
                val_preds.extend(preds.cpu().numpy())
                val_labels.extend(labels.cpu().numpy())

                val_pbar.set_postfix({
                    "loss": f"{val_loss/batch_count:.4f}",
                    "acc": f"{val_correct/val_total:.4f}"
                })

        val_loss /= len(val_loader)
        val_accuracy = val_correct / val_total
        val_losses.append(val_loss)
        val_accuracies.append(val_accuracy)

        metrics = compute_metrics(val_labels, val_preds)
        val_f1 = metrics["f1"]

        # Always save after the first epoch so a checkpoint exists even if F1 is 0.
        if epoch == 0 or val_f1 > best_f1:
            best_f1 = val_f1
            model.wav2vec2.save_pretrained(best_model_dir)
            processor.save_pretrained(best_model_dir)

        # Report before the early-stopping check so the final epoch is logged too.
        print(f"Epoch {epoch+1}: Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.4f}, "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.4f}, Val F1: {val_f1:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print(f"Early stopping triggered after {epoch+1} epochs.")
                break

    return train_losses, train_accuracies, val_losses, val_accuracies

def plot_metrics(train_losses, train_accuracies, val_losses, val_accuracies, output_dir):
    """Save side-by-side loss and accuracy curves to ``metrics_plot.png``.

    Args:
        train_losses: Per-epoch training losses.
        train_accuracies: Per-epoch training accuracies.
        val_losses: Per-epoch validation losses.
        val_accuracies: Per-epoch validation accuracies.
        output_dir: Directory the PNG file is written into.
    """
    epochs = range(1, len(train_losses) + 1)

    # (title, y-label, train series, train legend, val series, val legend)
    panels = (
        ('Training and Validation Loss', 'Loss',
         train_losses, 'Train Loss', val_losses, 'Val Loss'),
        ('Training and Validation Accuracy', 'Accuracy',
         train_accuracies, 'Train Accuracy', val_accuracies, 'Val Accuracy'),
    )

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    for ax, (title, y_label, train_series, train_legend, val_series, val_legend) in zip(axes, panels):
        ax.plot(epochs, train_series, 'b-', label=train_legend)
        ax.plot(epochs, val_series, 'r-', label=val_legend)
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.legend()
        ax.grid(True)

    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, 'metrics_plot.png'))
    plt.close(fig)

# Root of the pre-split dataset (<root>/<split>/<real|fake>/*.wav). Set the
# AUDIO_DETECT_DATA_ROOT environment variable to run against another copy
# without editing this cell; the default is the original studio location.
DATA_ROOT = Path(os.environ.get(
    "AUDIO_DETECT_DATA_ROOT",
    "/teamspace/studios/this_studio/audio_detect/dataset/split_data",
))

# One [real_dir, fake_dir] list per split, kept as strings.
train_dirs, val_dirs, test_dirs = (
    [str(DATA_ROOT / split / kind) for kind in ("real", "fake")]
    for split in ("train", "val", "test")
)

train_loader, val_loader, test_loader = get_dataloaders(
    train_dirs,
    val_dirs,
    test_dirs,
    batch_size=16,
    num_workers=8
)



In [5]:
# Seed before the model is built so the newly initialised classification head
# is repeatable across kernel restarts.
SEED = 42
torch.manual_seed(SEED)

model = AudioDeepfakeModel(model_name="facebook/wav2vec2-base", num_labels=2)

# Checkpoints and the metrics plot are written here.
output_dir = "saved_model"
os.makedirs(output_dir, exist_ok=True)

train_losses, train_accuracies, val_losses, val_accuracies = train_model(
    model, train_loader, val_loader, output_dir, num_epochs=15, patience=5, accum_steps=4
)

plot_metrics(train_losses, train_accuracies, val_losses, val_accuracies, output_dir)

Some weights of Wav2Vec2ForSequenceClassification were not initialized from the model checkpoint at facebook/wav2vec2-base and are newly initialized: ['classifier.bias', 'classifier.weight', 'projector.bias', 'projector.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB. GPU 0 has a total capacity of 14.58 GiB of which 13.62 MiB is free. Process 5551 has 2.92 GiB memory in use. Process 55485 has 10.70 GiB memory in use. Process 130195 has 958.00 MiB memory in use. Of the allocated memory 808.05 MiB is allocated by PyTorch, and 29.95 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)