In [None]:
import os
import torch
import torchaudio
import numpy as np
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
from transformers import Wav2Vec2Processor, Wav2Vec2ForSequenceClassification
from sklearn.metrics import confusion_matrix, classification_report
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

class AudioDeepfakeDataset(Dataset):
    """Binary real/fake audio dataset built from directories of ``.wav`` files.

    Every ``*.wav`` file directly inside each directory (non-recursive) is
    included. A directory whose name contains ``"real"`` (case-insensitive)
    yields label 0; any other directory yields label 1.

    Each item is loaded, resampled to ``sample_rate`` if the file uses a
    different rate, averaged down to mono, standardized to zero mean / unit
    variance, and then truncated or zero-padded to ``int(max_length *
    sample_rate)`` samples.

    Args:
        data_dirs: Iterable of directory paths (``str`` or ``Path``).
        sample_rate: Target sample rate in Hz.
        max_length: Clip length in seconds.

    Raises:
        ValueError: If no ``.wav`` files are found in any of ``data_dirs``.
    """

    def __init__(self, data_dirs, sample_rate=16000, max_length=4.0):
        self.data_dirs = data_dirs
        self.sample_rate = sample_rate
        self.max_length = max_length
        self.max_samples = int(max_length * sample_rate)

        self.audio_files = []
        self.labels = []

        for data_dir in data_dirs:
            data_dir = Path(data_dir)
            label = 0 if 'real' in data_dir.name.lower() else 1
            # Path.glob yields files in filesystem order; sort so the dataset
            # order (and therefore evaluation output) is reproducible.
            for audio_file in sorted(data_dir.glob('*.wav')):
                self.audio_files.append(str(audio_file))
                self.labels.append(label)

        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if not self.audio_files:
            raise ValueError("No audio files found in the provided directories.")

        # Resample transforms keyed by source sample rate. Constructing one
        # precomputes its filter kernel, so reuse it across items.
        self._resamplers = {}

    def __len__(self):
        """Return the number of audio files found."""
        return len(self.audio_files)

    def _get_resampler(self, orig_sample_rate):
        """Return a cached Resample transform from ``orig_sample_rate`` to ``self.sample_rate``."""
        resampler = self._resamplers.get(orig_sample_rate)
        if resampler is None:
            resampler = torchaudio.transforms.Resample(orig_sample_rate, self.sample_rate)
            self._resamplers[orig_sample_rate] = resampler
        return resampler

    def __getitem__(self, idx):
        """Return ``(waveform, label)`` for item ``idx``.

        ``waveform`` is a 1-D tensor of exactly ``self.max_samples`` samples.
        ``label`` is the int 0 (real) or 1 (fake).
        """
        audio_path = self.audio_files[idx]
        label = self.labels[idx]

        waveform, orig_sample_rate = torchaudio.load(audio_path)

        if orig_sample_rate != self.sample_rate:
            waveform = self._get_resampler(orig_sample_rate)(waveform)

        # Average all channels into a single mono channel.
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        # Per-clip standardization; the epsilon guards against silent clips (std == 0).
        waveform = (waveform - waveform.mean()) / (waveform.std() + 1e-6)

        num_samples = waveform.shape[1]
        if num_samples > self.max_samples:
            waveform = waveform[:, :self.max_samples]
        elif num_samples < self.max_samples:
            padding = torch.zeros(1, self.max_samples - num_samples)
            waveform = torch.cat([waveform, padding], dim=1)

        return waveform.squeeze(0), label

def collate_fn(batch):
    """Stack a list of ``(waveform, label)`` pairs into batched tensors.

    Args:
        batch: Sequence of ``(waveform, label)`` pairs. All waveforms must share
            the same shape; ``AudioDeepfakeDataset`` pads/truncates to a fixed
            length for this reason.

    Returns:
        ``(waveforms, labels)``. ``waveforms`` stacks the inputs along a new
        leading batch dimension, and ``labels`` is a ``torch.long`` tensor.
    """
    waveforms, labels = zip(*batch)
    # torch.stack accepts any sequence of tensors; no intermediate list copy needed.
    return torch.stack(waveforms), torch.tensor(labels, dtype=torch.long)

def get_test_dataloader(test_dirs, batch_size=16, num_workers=8, sample_rate=16000, max_length=4.0):
    """Build a non-shuffled DataLoader over real/fake test directories.

    Args:
        test_dirs: Directories passed to ``AudioDeepfakeDataset``.
        batch_size: Samples per batch.
        num_workers: DataLoader worker processes.
        sample_rate: Target sample rate forwarded to the dataset.
        max_length: Clip length in seconds forwarded to the dataset.

    Returns:
        A ``DataLoader`` yielding ``(waveforms, labels)`` batches via ``collate_fn``.
    """
    test_dataset = AudioDeepfakeDataset(test_dirs, sample_rate=sample_rate, max_length=max_length)
    return DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=collate_fn,
        # Pinned host memory only speeds up host->GPU copies; requesting it
        # without CUDA just triggers a warning.
        pin_memory=torch.cuda.is_available(),
    )

def evaluate_model(model, test_loader, processor, output_dir):
    """Evaluate a binary sequence-classification model on a test DataLoader.

    Runs ``model`` in eval mode under ``torch.no_grad`` over every batch. It
    writes a classification report to ``<output_dir>/classification_report.txt``
    and prints loss, accuracy and the report.

    Args:
        model: Called as ``model(waveforms, labels=labels)``. Its output must
            expose ``.loss`` and ``.logits``.
        test_loader: Iterable of ``(waveforms, labels)`` batches.
        processor: Unused; kept so existing callers keep working.
        output_dir: Directory for the report; created if missing.

    Returns:
        ``(cm, test_labels, test_preds, test_loss, test_accuracy)``.
        ``cm`` is a 2x2 confusion matrix over labels ``[0, 1]`` (Real, Fake).
        ``test_loss`` is the loss averaged per sample.

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    class_labels = [0, 1]
    class_names = ["Real", "Fake"]

    loss_sum, test_correct, test_total = 0.0, 0, 0
    test_preds, test_labels = [], []

    with torch.no_grad():
        for waveforms, labels in tqdm(test_loader, desc="Evaluating on Test Data"):
            waveforms, labels = waveforms.to(device), labels.to(device)

            outputs = model(waveforms, labels=labels)
            batch_size = labels.size(0)
            # NOTE: assumes outputs.loss is a batch *mean*, as with the default
            # CrossEntropyLoss in HF sequence-classification heads. Weighting by
            # batch size keeps a short final batch from skewing the average.
            loss_sum += outputs.loss.item() * batch_size
            preds = outputs.logits.argmax(dim=-1)
            test_correct += (preds == labels).sum().item()
            test_total += batch_size
            test_preds.extend(preds.cpu().numpy())
            test_labels.extend(labels.cpu().numpy())

    if test_total == 0:
        raise ValueError("test_loader produced no samples; nothing to evaluate.")

    test_loss = loss_sum / test_total
    test_accuracy = test_correct / test_total

    # Pin labels=[0, 1] so the matrix is always 2x2. Without it, the report
    # raises when only one class appears in the labels/predictions, because
    # target_names would not match the number of classes.
    cm = confusion_matrix(test_labels, test_preds, labels=class_labels)
    class_report = classification_report(
        test_labels,
        test_preds,
        labels=class_labels,
        target_names=class_names,
        zero_division=0,
    )

    os.makedirs(output_dir, exist_ok=True)
    with open(os.path.join(output_dir, "classification_report.txt"), "w", encoding="utf-8") as f:
        f.write(class_report)

    print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}")
    print("\nClassification Report:\n", class_report)

    return cm, test_labels, test_preds, test_loss, test_accuracy

def plot_confusion_matrix(cm, output_dir, class_names=("Real", "Fake")):
    """Save ``cm`` as an annotated heatmap to ``<output_dir>/confusion_matrix.png``.

    Args:
        cm: Square matrix of integer counts. Rows are true labels and columns
            are predicted labels.
        output_dir: Destination directory; created if missing.
        class_names: Tick labels for both axes, in label-index order. The
            default is a tuple, not a list, to avoid a shared mutable default.
    """
    os.makedirs(output_dir, exist_ok=True)
    class_names = list(class_names)

    # Use an explicit figure/axes instead of implicit pyplot state, and always
    # close the figure so repeated calls do not leak memory.
    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        sns.heatmap(
            cm,
            annot=True,
            fmt="d",
            cmap="Blues",
            xticklabels=class_names,
            yticklabels=class_names,
            ax=ax,
        )
        ax.set_title("Confusion Matrix")
        ax.set_xlabel("Predicted")
        ax.set_ylabel("True")
        fig.tight_layout()
        fig.savefig(os.path.join(output_dir, "confusion_matrix.png"))
    finally:
        plt.close(fig)

def main():
    """Evaluate the saved best checkpoint on the held-out test split.

    Loads model and processor from ``saved_model/best_model`` and evaluates on
    the fixed real/fake test directories. The classification report and the
    confusion-matrix image are written into ``saved_model``.
    """
    output_dir = "saved_model"
    os.makedirs(output_dir, exist_ok=True)
    checkpoint = os.path.join(output_dir, "best_model")

    model = Wav2Vec2ForSequenceClassification.from_pretrained(checkpoint)
    processor = Wav2Vec2Processor.from_pretrained(checkpoint)

    test_loader = get_test_dataloader(
        [
            "/teamspace/studios/this_studio/audio_detect/dataset/split_data/test/real",
            "/teamspace/studios/this_studio/audio_detect/dataset/split_data/test/fake",
        ],
        batch_size=16,
        num_workers=8,
    )

    # Only the confusion matrix is needed here; the other metrics are
    # printed and saved inside evaluate_model.
    results = evaluate_model(model, test_loader, processor, output_dir)
    cm = results[0]

    plot_confusion_matrix(cm, output_dir)



In [None]:
# Entry point: runs the test-set evaluation `main` defined earlier in this file
# (a later cell redefines `main` for single-clip inference).
if __name__ == "__main__":
    main()

In [None]:
import torch
import torchaudio
import numpy as np
from transformers import Wav2Vec2Processor, Wav2Vec2ForSequenceClassification

def preprocess_audio(audio_path, sample_rate=16000, max_length=4.0):
    """Load one audio file and turn it into a fixed-length model input.

    Mirrors the per-item preprocessing of ``AudioDeepfakeDataset.__getitem__``:
      1. resample to ``sample_rate`` when the file uses a different rate;
      2. average all channels down to mono;
      3. standardize to zero mean / unit variance (1e-6 added to the std);
      4. cut or zero-pad to ``int(max_length * sample_rate)`` samples.

    Returns:
        1-D tensor of length ``int(max_length * sample_rate)``.
    """
    signal, file_rate = torchaudio.load(audio_path)

    if file_rate != sample_rate:
        signal = torchaudio.transforms.Resample(file_rate, sample_rate)(signal)

    if signal.shape[0] > 1:
        signal = torch.mean(signal, dim=0, keepdim=True)

    signal = (signal - signal.mean()) / (signal.std() + 1e-6)

    target_len = int(max_length * sample_rate)
    # Positive shortfall -> pad with zeros; negative -> too long, truncate.
    shortfall = target_len - signal.shape[1]
    if shortfall < 0:
        signal = signal[:, :target_len]
    elif shortfall > 0:
        signal = torch.cat([signal, torch.zeros(1, shortfall)], dim=1)

    return signal.squeeze(0)

def predict_audio(model, processor, audio_path, device):
    """Classify a single audio file.

    Args:
        model: Classifier whose output exposes ``.logits``. It should already
            be on ``device``; only the input tensor is moved here.
        processor: Unused; kept for interface compatibility.
        audio_path: Path to the audio file.
        device: Device that receives the input tensor.

    Returns:
        ``(predicted_label, confidence)``: the argmax class index and its
        softmax probability.
    """
    # Batch of one: shape (1, num_samples).
    batch = preprocess_audio(audio_path).unsqueeze(0).to(device)

    model.eval()
    with torch.no_grad():
        logits = model(batch).logits
        probs = torch.softmax(logits, dim=-1)
        label_idx = logits.argmax(dim=-1).item()
        score = probs[0, label_idx].item()

    return label_idx, score

def main(audio_path="path/to/your/audio_clip.wav", model_path="saved_model/best_model"):
    """Classify one audio clip as Real or Fake with the saved model.

    Args:
        audio_path: Audio file to classify. The default is a placeholder that
            must be replaced or overridden by the caller.
        model_path: Directory containing the saved model and processor.

    Raises:
        FileNotFoundError: If ``audio_path`` does not exist. This is checked
            before the (slow) model load so a bad path fails fast.
    """
    import os

    if not os.path.isfile(audio_path):
        raise FileNotFoundError(
            f"Audio file not found: {audio_path!r}. "
            "Pass audio_path=... or replace the placeholder default."
        )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = Wav2Vec2ForSequenceClassification.from_pretrained(model_path).to(device)
    processor = Wav2Vec2Processor.from_pretrained(model_path)

    label, confidence = predict_audio(model, processor, audio_path, device)

    # Label 0 is "Real", matching the dataset's labelling of "real" directories.
    class_name = "Real" if label == 0 else "Fake"
    print(f"Prediction: {class_name}")
    print(f"Confidence: {confidence:.4f}")



In [None]:
# Entry point: runs the single-clip inference `main` defined in this cell,
# which replaces the earlier evaluation `main`.
if __name__ == "__main__":
    main()