In [12]:
import os
import time
from datetime import datetime
import logging
import soundfile as sf
from multiprocessing import Pool
import uuid

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

def validate_wav(file_path):
    """Validate if a WAV file is readable and non-empty."""
    try:
        data, sr = sf.read(file_path)
        if len(data) == 0:
            raise ValueError("Empty audio file")
        return file_path, True
    except Exception as e:
        logging.error(f"Invalid file {file_path}: {str(e)}")
        return file_path, False

def remove_corrupted_files(directory):
    """Remove corrupted WAV files from the directory and return count of valid files."""
    try:
        # Get list of WAV files
        files = [os.path.join(directory, f) for f in os.listdir(directory) if f.endswith('.wav')]
        if not files:
            logging.info(f"No WAV files found in {directory}")
            return 0

        # Validate files in parallel
        with Pool() as pool:
            results = pool.map(validate_wav, files)

        # Separate valid and corrupted files
        valid_files = [f for f, valid in results if valid]
        corrupted_files = [f for f, valid in results if not valid]

        # Remove corrupted files
        for corrupted_file in corrupted_files:
            try:
                os.remove(corrupted_file)
                logging.info(f"Removed corrupted file: {corrupted_file}")
            except OSError as e:
                logging.error(f"Error removing {corrupted_file}: {str(e)}")

        return len(valid_files)

    except Exception as e:
        logging.error(f"Error processing directory {directory}: {str(e)}")
        return 0

def rename_files_with_timestamp(directory):
    """Rename all WAV files in the directory with unique timestamp-based names."""
    try:
        # Get list of WAV files
        files = [f for f in os.listdir(directory) if f.endswith('.wav')]
        if not files:
            logging.info(f"No WAV files to rename in {directory}")
            return

        # Rename files
        for idx, filename in enumerate(files):
            # Generate unique timestamp-based name with UUID for additional uniqueness
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f')
            unique_id = str(uuid.uuid4())[:8]  # Short UUID segment
            new_filename = f"audio_{timestamp}_{unique_id}.wav"

            # Define old and new file paths
            old_file = os.path.join(directory, filename)
            new_file = os.path.join(directory, new_filename)

            # Rename file
            try:
                os.rename(old_file, new_file)
                logging.info(f"Renamed {filename} to {new_filename}")
                # Small delay to ensure unique timestamps
                time.sleep(0.001)
            except OSError as e:
                logging.error(f"Error renaming {filename}: {str(e)}")
                continue

    except Exception as e:
        logging.error(f"Error processing directory {directory}: {str(e)}")

def process_directories(directories):
    """Process each directory: remove corrupted files, rename valid files, and count valid files."""
    valid_file_counts = {}
    
    for directory in directories:
        logging.info(f"Processing directory: {directory}")
        
        # Step 1: Remove corrupted files and get count of valid files
        valid_count = remove_corrupted_files(directory)
        valid_file_counts[directory] = valid_count
        
        # Step 2: Rename valid files
        rename_files_with_timestamp(directory)
        
    # Print valid file counts
    logging.info("Valid file counts per directory:")
    for directory, count in valid_file_counts.items():
        logging.info(f"{directory}: {count} valid files")

# if __name__ == "__main__":
#     directories = [
#         "/teamspace/studios/this_studio/audio_detect/dataset/split_data/train/real",
#         "/teamspace/studios/this_studio/audio_detect/dataset/split_data/train/fake",
#         "/teamspace/studios/this_studio/audio_detect/dataset/split_data/val/real",
#         "/teamspace/studios/this_studio/audio_detect/dataset/split_data/val/fake",
#         "/teamspace/studios/this_studio/audio_detect/dataset/split_data/test/real",
#         "/teamspace/studios/this_studio/audio_detect/dataset/split_data/test/fake",
#     ]
    
#     process_directories(directories)

In [13]:
import os
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import numpy as np

class AudioDeepfakeDataset(Dataset):
    def __init__(self, data_dirs, sample_rate=16000, max_length=4.0):
        """
        Custom Dataset for audio deepfake detection.
        
        Args:
            data_dirs (list): List of directories containing audio files (real and fake).
            sample_rate (int): Target sample rate (16000 Hz for wav2vec2).
            max_length (float): Maximum audio length in seconds (4.0 seconds).
        """
        self.data_dirs = data_dirs
        self.sample_rate = sample_rate
        self.max_length = max_length
        self.max_samples = int(max_length * sample_rate)
        
        # Collect all audio files and their labels
        self.audio_files = []
        self.labels = []
        
        for data_dir in data_dirs:
            data_dir = Path(data_dir)
            # Label: 0 for real, 1 for fake
            label = 0 if 'real' in data_dir.name.lower() else 1
            for audio_file in data_dir.glob('*.wav'):
                self.audio_files.append(str(audio_file))
                self.labels.append(label)
        
        assert len(self.audio_files) > 0, "No audio files found in the provided directories."
    
    def __len__(self):
        return len(self.audio_files)
    
    def __getitem__(self, idx):
        # Load audio file
        audio_path = self.audio_files[idx]
        label = self.labels[idx]
        
        # Load waveform using torchaudio
        waveform, orig_sample_rate = torchaudio.load(audio_path)
        
        # Resample to 16kHz if needed
        if orig_sample_rate != self.sample_rate:
            resampler = torchaudio.transforms.Resample(orig_sample_rate, self.sample_rate)
            waveform = resampler(waveform)
        
        # Convert to mono if stereo
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        
        # Normalize waveform (zero mean, unit variance)
        waveform = (waveform - waveform.mean()) / (waveform.std() + 1e-6)
        
        # Pad or truncate to max_length
        num_samples = waveform.shape[1]
        if num_samples > self.max_samples:
            waveform = waveform[:, :self.max_samples]
        elif num_samples < self.max_samples:
            padding = torch.zeros(1, self.max_samples - num_samples)
            waveform = torch.cat([waveform, padding], dim=1)
        
        return waveform.squeeze(0), label

def collate_fn(batch):
    """
    Custom collate function to handle variable-length audio in batches.
    
    Args:
        batch: List of (waveform, label) tuples.
    
    Returns:
        waveforms: Padded waveforms as a tensor.
        labels: Tensor of labels.
    """
    waveforms, labels = zip(*batch)
    
    # Stack waveforms and labels
    waveforms = torch.stack([wf for wf in waveforms])
    labels = torch.tensor(labels, dtype=torch.long)
    
    return waveforms, labels

def get_dataloaders(train_dirs, val_dirs, test_dirs, batch_size=16, num_workers=4):
    """
    Create dataloaders for train, validation, and test sets.
    
    Args:
        train_dirs (list): Directories for training data.
        val_dirs (list): Directories for validation data.
        test_dirs (list): Directories for test data.
        batch_size (int): Batch size for dataloaders.
        num_workers (int): Number of workers for data loading.
    
    Returns:
        train_loader, val_loader, test_loader: PyTorch DataLoaders.
    """
    # Initialize datasets
    train_dataset = AudioDeepfakeDataset(train_dirs)
    val_dataset = AudioDeepfakeDataset(val_dirs)
    test_dataset = AudioDeepfakeDataset(test_dirs)
    
    # Initialize dataloaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=True
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=True
    )
    
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=True
    )
    
    return train_loader, val_loader, test_loader

# Example usage
if __name__ == "__main__":
    train_dirs = [
        "/teamspace/studios/this_studio/audio_detect/dataset/split_data/train/real",
        "/teamspace/studios/this_studio/audio_detect/dataset/split_data/train/fake"
    ]
    val_dirs = [
        "/teamspace/studios/this_studio/audio_detect/dataset/split_data/val/real",
        "/teamspace/studios/this_studio/audio_detect/dataset/split_data/val/fake"
    ]
    test_dirs = [
        "/teamspace/studios/this_studio/audio_detect/dataset/split_data/test/real",
        "/teamspace/studios/this_studio/audio_detect/dataset/split_data/test/fake"
    ]
    
    # Get dataloaders
    train_loader, val_loader, test_loader = get_dataloaders(train_dirs, val_dirs, test_dirs, batch_size=16)
    
    # Test a batch
    for waveforms, labels in train_loader:
        print(f"Waveforms shape: {waveforms.shape}")  # Expected: [batch_size, 64000]
        print(f"Labels shape: {labels.shape}")        # Expected: [batch_size]
        break



Waveforms shape: torch.Size([16, 64000])
Labels shape: torch.Size([16])


In [None]:
import os
import random
from IPython.display import Audio, display

def display_random_audio(directory):
    """Display one random WAV file from the given directory."""
    try:
        # Get list of WAV files
        files = [f for f in os.listdir(directory) if f.endswith('.wav')]
        if not files:
            print(f"No WAV files found in {directory}")
            return
        
        # Select random file
        random_file = random.choice(files)
        file_path = os.path.join(directory, random_file)
        
        # Display audio
        print(f"Playing: {file_path}")
        display(Audio(file_path))
        
    except Exception as e:
        print(f"Error playing audio from {directory}: {str(e)}")

if __name__ == "__main__":
    directories = [
        "/teamspace/studios/this_studio/audio_detect/dataset/split_data/train/real",
        "/teamspace/studios/this_studio/audio_detect/dataset/split_data/train/fake",
        "/teamspace/studios/this_studio/audio_detect/dataset/split_data/val/real",
        "/teamspace/studios/this_studio/audio_detect/dataset/split_data/val/fake",
        "/teamspace/studios/this_studio/audio_detect/dataset/split_data/test/real",
        "/teamspace/studios/this_studio/audio_detect/dataset/split_data/test/fake"
    ]
    
    # Display one random audio file from each directory
    for directory in directories:
        display_random_audio(directory)

In [20]:
import os
import torch
import torchaudio
import numpy as np
from torch import nn
from torch.utils.data import DataLoader
from pathlib import Path
from transformers import Wav2Vec2Processor, Wav2Vec2ForSequenceClassification
from sklearn.metrics import accuracy_score, f1_score
from tqdm import tqdm
import matplotlib.pyplot as plt

# Section 1: Define Model Class
class AudioDeepfakeModel(nn.Module):
    def __init__(self, model_name="facebook/wav2vec2-base", num_labels=2):
        super(AudioDeepfakeModel, self).__init__()
        self.wav2vec2 = Wav2Vec2ForSequenceClassification.from_pretrained(
            model_name,
            num_labels=num_labels
        )
        self.wav2vec2.wav2vec2.feature_extractor.eval()
        for param in self.wav2vec2.wav2vec2.feature_extractor.parameters():
            param.requires_grad = False
    
    def forward(self, input_values, labels=None):
        outputs = self.wav2vec2(input_values, labels=labels)
        return outputs

# Section 2: Define Metrics Function
def compute_metrics(labels, preds):
    accuracy = accuracy_score(labels, preds)
    f1 = f1_score(labels, preds, average='binary')
    return {"accuracy": accuracy, "f1": f1}

# Section 3: Setup DataLoaders
train_dirs = [
    "/teamspace/studios/this_studio/audio_detect/dataset/split_data/train/real",
    "/teamspace/studios/this_studio/audio_detect/dataset/split_data/train/fake"
]
val_dirs = [
    "/teamspace/studios/this_studio/audio_detect/dataset/split_data/val/real",
    "/teamspace/studios/this_studio/audio_detect/dataset/split_data/val/fake"
]
test_dirs = [
    "/teamspace/studios/this_studio/audio_detect/dataset/split_data/test/real",
    "/teamspace/studios/this_studio/audio_detect/dataset/split_data/test/fake"
]

train_loader, val_loader, test_loader = get_dataloaders(
    train_dirs,
    val_dirs,
    test_dirs,
    batch_size=8,
    num_workers=4
)

# Section 4: Training Function with Progress Bar and Metrics Tracking
def train_model(model, train_loader, val_loader, output_dir, num_epochs=3):
    try:
        processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)
        
        optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5, weight_decay=0.01)
        scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.1, total_iters=500)
        
        train_losses, val_losses = [], []
        train_accuracies, val_accuracies = [], []
        best_f1 = 0
        
        for epoch in range(num_epochs):
            # Training
            model.train()
            train_loss, train_correct, train_total = 0, 0, 0
            train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]")
            
            for batch in train_pbar:
                try:
                    waveforms, labels = batch
                    waveforms, labels = waveforms.to(device), labels.to(device)
                    
                    optimizer.zero_grad()
                    outputs = model(waveforms, labels=labels)
                    loss = outputs.loss
                    loss.backward()
                    optimizer.step()
                    scheduler.step()
                    
                    train_loss += loss.item()
                    preds = outputs.logits.argmax(dim=-1)
                    train_correct += (preds == labels).sum().item()
                    train_total += labels.size(0)
                    
                    train_pbar.set_postfix({
                        "loss": f"{train_loss/train_total:.4f}",
                        "acc": f"{train_correct/train_total:.4f}"
                    })
                except Exception as e:
                    print(f"Error in training batch: {e}")
                    continue
            
            train_loss /= len(train_loader)
            train_accuracy = train_correct / train_total
            train_losses.append(train_loss)
            train_accuracies.append(train_accuracy)
            
            # Validation
            model.eval()
            val_loss, val_correct, val_total = 0, 0, 0
            val_preds, val_labels = [], []
            val_pbar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Val]")
            
            with torch.no_grad():
                for batch in val_pbar:
                    try:
                        waveforms, labels = batch
                        waveforms, labels = waveforms.to(device), labels.to(device)
                        
                        outputs = model(waveforms, labels=labels)
                        loss = outputs.loss
                        val_loss += loss.item()
                        preds = outputs.logits.argmax(dim=-1)
                        val_correct += (preds == labels).sum().item()
                        val_total += labels.size(0)
                        val_preds.extend(preds.cpu().numpy())
                        val_labels.extend(labels.cpu().numpy())
                        
                        val_pbar.set_postfix({
                            "loss": f"{val_loss/val_total:.4f}",
                            "acc": f"{val_correct/val_total:.4f}"
                        })
                    except Exception as e:
                        print(f"Error in validation batch: {e}")
                        continue
            
            val_loss /= len(val_loader)
            val_accuracy = val_correct / val_total
            val_losses.append(val_loss)
            val_accuracies.append(val_accuracy)
            
            # Compute F1-score
            metrics = compute_metrics(val_labels, val_preds)
            val_f1 = metrics["f1"]
            
            # Save best model
            if epoch == 0 or val_f1 > best_f1:
                best_f1 = val_f1
                model.save_pretrained(os.path.join(output_dir, "best_model"))
                processor.save_pretrained(os.path.join(output_dir, "best_model"))
            
            print(f"Epoch {epoch+1}: Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.4f}, "
                  f"Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.4f}, Val F1: {val_f1:.4f}")
        
        return train_losses, train_accuracies, val_losses, val_accuracies
    
    except Exception as e:
        print(f"Error in train_model: {e}")
        return None, None, None, None

# Section 5: Plotting Function
def plot_metrics(train_losses, train_accuracies, val_losses, val_accuracies, output_dir):
    epochs = range(1, len(train_losses) + 1)
    
    plt.figure(figsize=(12, 5))
    
    plt.subplot(1, 2, 1)
    plt.plot(epochs, train_losses, 'b-', label='Train Loss')
    plt.plot(epochs, val_losses, 'r-', label='Val Loss')
    plt.title('Training and Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)
    
    plt.subplot(1, 2, 2)
    plt.plot(epochs, train_accuracies, 'b-', label='Train Accuracy')
    plt.plot(epochs, val_accuracies, 'r-', label='Val Accuracy')
    plt.title('Training and Validation Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.grid(True)
    
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'metrics_plot.png'))
    plt.close()

# Section 6: Inference Function
def predict_audio_clip(model_path, audio_path, sample_rate=16000):
    processor = Wav2Vec2Processor.from_pretrained(model_path)
    model = AudioDeepfakeModel.from_pretrained(model_path)
    model.eval()
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    
    waveform, orig_sample_rate = torchaudio.load(audio_path)
    if orig_sample_rate != sample_rate:
        resampler = torchaudio.transforms.Resample(orig_sample_rate, sample_rate)
        waveform = resampler(waveform)
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    waveform = (waveform - waveform.mean()) / (waveform.std() + 1e-6)
    
    max_samples = int(4.0 * sample_rate)
    if waveform.shape[1] > max_samples:
        waveform = waveform[:, :max_samples]
    elif waveform.shape[1] < max_samples:
        padding = torch.zeros(1, max_samples - waveform.shape[1])
        waveform = torch.cat([waveform, padding], dim=1)
    
    inputs = waveform.squeeze(0).to(device)
    with torch.no_grad():
        outputs = model(inputs.unsqueeze(0))
    logits = outputs.logits
    probs = torch.softmax(logits, dim=-1)
    prediction = torch.argmax(logits, dim=-1).item()
    
    return {
        "prediction": "Real" if prediction == 0 else "Fake",
        "prob_real": probs[0, 0].item(),
        "prob_fake": probs[0, 1].item()
    }



In [None]:
# Section 7: Initialize Model and Train
output_dir = "./deepfake_model_checkpoints"
os.makedirs(output_dir, exist_ok=True)

model = AudioDeepfakeModel(model_name="facebook/wav2vec2-base", num_labels=2)
train_losses, train_accuracies, val_losses, val_accuracies = train_model(
    model, train_loader, val_loader, output_dir, num_epochs=3
)


Some weights of Wav2Vec2ForSequenceClassification were not initialized from the model checkpoint at facebook/wav2vec2-base and are newly initialized: ['classifier.bias', 'classifier.weight', 'projector.bias', 'projector.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Epoch 1/3 [Train]:   0%|          | 10/4002 [03:01<20:08:15, 18.16s/it, loss=0.0869, acc=0.4875]


In [None]:
# Section 8: Plot Metrics
if train_losses:
    plot_metrics(train_losses, train_accuracies, val_losses, val_accuracies, output_dir)

# Section 9: Test Inference on Example Audio Clip
example_audio_path = "/teamspace/studios/this_studio/audio_detect/dataset/split_data/test/real/example.wav"
result = predict_audio_clip(os.path.join(output_dir, "best_model"), example_audio_path)
print("Inference Result:", result)