# Import and Install Dependencies

In [None]:
!pip install torch torchaudio transformers datasets scikit-learn soundfile torchvision audiomentations

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import torchaudio
from transformers import HubertModel
import os
import glob
import random
import numpy as np
from tqdm import tqdm
from sklearn.metrics import classification_report, confusion_matrix, roc_curve, auc
from sklearn.preprocessing import label_binarize
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns


# Declare Values

In [None]:
class Config:
    """Single settings object for the notebook.

    Groups data locations, hardware choices, the training schedule, audio
    preprocessing constants, model options and optimizer values. Every value
    is set in ``__init__``; nothing is read from the environment except CUDA
    availability.
    """

    def __init__(self):
        self._set_data_paths()
        self._set_hardware()
        self._set_training()
        self._set_audio()
        self._set_model()
        self._set_optimization()

    def _set_data_paths(self):
        # Training sets as {run name: root folder}. All entries are currently
        # commented out, so the mapping is empty until one is enabled.
        self.datasets = {
            # "fake92": "/kaggle/input/imbalanceddataset/fake92",
            # "real90": "/kaggle/input/imbalanceddataset/real90",
            # "balanced": "/kaggle/input/fakes-and-reals/audio_train/audio_train"
        }
        self.test_path = "/kaggle/input/fakes-and-reals/audio_test/audio_test"

    def _set_hardware(self):
        # Prefer the GPU whenever torch reports one.
        device_name = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device_name)
        self.num_workers = 4
        self.pin_memory = True

    def _set_training(self):
        self.num_epochs = 20
        self.batch_size = 8
        self.train_val_split = 0.8  # fraction of samples used for training

    def _set_audio(self):
        self.sample_rate = 16000
        self.max_length = 16000  # clip length in samples after trim/pad
        self.label_mapping = {"real": 0, "fake": 1}  # folder name -> class id

    def _set_model(self):
        self.unfreeze_layers = [-1]  # encoder layer indices left trainable
        self.use_augmentation = True

    def _set_optimization(self):
        self.base_lr = 1e-5
        self.classifier_lr_multiplier = 5
        self.lr_decay_per_epoch = 0.95
        self.min_lr = 1e-7
        self.huber_weight_decay = 1e-5
        self.classifier_weight_decay = 1e-4
        self.gradient_clip = 1.0


config = Config()


In [None]:
# Empty container with one slot for metrics and one for curves.
# NOTE(review): nothing else in the visible notebook reads or fills it — confirm it is still needed.
results = {
    'metrics': {},
    'curves': {}
}

# Prepare Audio Dataset

In [None]:
class AudioDataset(Dataset):
    """Fixed-length mono waveforms read from ``<root_dir>/<label_name>/`` folders.

    One sub-folder is expected per key of ``config.label_mapping``; its files
    receive the mapped integer label. Files that fail validation are skipped
    at construction time, and the remaining samples are shuffled with a fixed
    seed so the order is reproducible.

    Args:
        root_dir: Directory containing one folder per label name.
        config: Object providing ``label_mapping``, ``sample_rate`` and
            ``max_length``.
        augment: Stored on the instance; no method of this class applies
            augmentation.
    """

    # Seed for the construction-time shuffle.
    _SHUFFLE_SEED = 42

    def __init__(self, root_dir, config, augment=False):
        self.config = config
        self.file_list = []
        self.labels = []
        self.augment = augment
        self._load_data(root_dir)

    def _is_valid_audio(self, file_path):
        """Return True when the file is non-empty, loadable, mono/stereo and
        at least 100 samples long; log the reason and return False otherwise."""
        try:
            if os.path.getsize(file_path) == 0:
                print(f"Empty file: {file_path}")
                return False

            # Decoding the file is the only reliable check that it is readable.
            waveform, sr = torchaudio.load(file_path)
            if waveform.nelement() == 0:
                return False
            if waveform.shape[0] not in [1, 2]:  # Mono or stereo
                return False
            if waveform.shape[1] < 100:  # Minimum 100 samples
                print(f"Short audio: {file_path} ({waveform.shape[1]} samples)")
                return False
            return True
        except Exception as e:
            # Best effort: any decoding problem simply excludes the file.
            print(f"Error loading {file_path}: {str(e)}")
            return False

    def _load_data(self, root_dir):
        """Collect valid files per label folder, then shuffle deterministically."""
        for label_name, label in self.config.label_mapping.items():
            folder = os.path.join(root_dir, label_name)

            if not os.path.exists(folder):
                print(f"Warning: Missing directory {folder}")
                continue

            files = glob.glob(os.path.join(folder, "*.*"))
            print(f"Found {len(files)} files in {folder}")

            for file in files:
                if self._is_valid_audio(file):
                    self.file_list.append(file)
                    self.labels.append(label)
                else:
                    print(f"Warning: Skipping invalid file: {file}")

        # A private RNG gives the same order as seeding the global one, but
        # without resetting the process-wide random state for everyone else.
        combined = list(zip(self.file_list, self.labels))
        random.Random(self._SHUFFLE_SEED).shuffle(combined)
        # Always lists, whether or not any sample was found.
        self.file_list = [path for path, _ in combined]
        self.labels = [label for _, label in combined]

    def __getitem__(self, idx):
        """Return ``(waveform, label)`` with waveform shaped ``(1, max_length)``.

        If the file cannot be loaded or processed, a silent waveform is
        returned together with the sample's own label (not a hard-coded
        class), so label statistics stay truthful.
        """
        path = self.file_list[idx]
        label = self.labels[idx]
        try:
            waveform, sr = torchaudio.load(path)

            # Bring every clip to the configured sample rate.
            if sr != self.config.sample_rate:
                resampler = torchaudio.transforms.Resample(sr, self.config.sample_rate)
                waveform = resampler(waveform)

            return self._process_waveform(waveform), label

        except Exception as e:
            print(f"Error loading {path}: {str(e)}")
            return torch.zeros((1, self.config.max_length)), label

    def _process_waveform(self, waveform):
        """Down-mix to mono and trim or zero-pad to ``config.max_length`` samples."""
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        if waveform.shape[1] > self.config.max_length:
            waveform = waveform[:, :self.config.max_length]
        else:
            pad_amount = self.config.max_length - waveform.shape[1]
            waveform = torch.nn.functional.pad(waveform, (0, pad_amount))

        return waveform

    def __len__(self):
        return len(self.file_list)


In [None]:
def collate_fn(batch):
    """Merge ``(waveform, label)`` samples into one batch.

    Returns a tuple of the waveforms stacked along a new leading dimension
    and the labels as a 1-D tensor, both in the order the samples arrived.
    """
    clips = [clip for clip, _ in batch]
    targets = [target for _, target in batch]
    return torch.stack(clips), torch.tensor(targets)

# HuBERT Classifier

In [None]:
class HuBERTClassifier(nn.Module):
    """Pretrained HuBERT encoder followed by a small MLP classification head.

    Args:
        config: Object providing ``unfreeze_layers`` (encoder layer indices
            that stay trainable; negative values count from the end) and
            ``label_mapping`` (its size sets the number of output classes).
    """

    def __init__(self, config):
        super().__init__()
        # Keep the indices on the instance so freezing does not depend on a
        # module-level ``config`` that may differ from the one passed in.
        self.unfreeze_layers = list(config.unfreeze_layers)
        self.hubert = HubertModel.from_pretrained("facebook/hubert-base-ls960")
        self._freeze_layers()

        self.classifier = nn.Sequential(
            nn.Linear(768, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, len(config.label_mapping))
        )

    def _freeze_layers(self):
        """Freeze every encoder layer except those listed in ``self.unfreeze_layers``.

        Negative indices are resolved like Python list indices, so ``-1``
        selects the last layer. Only ``hubert.encoder.layers`` is touched;
        other HuBERT submodules keep the ``requires_grad`` they were loaded with.

        Raises:
            ValueError: If an index does not refer to an existing layer.
        """
        layers = self.hubert.encoder.layers
        num_layers = len(layers)

        trainable = set()
        for index in self.unfreeze_layers:
            # enumerate() only yields 0..n-1, so negatives must be resolved first.
            resolved = index + num_layers if index < 0 else index
            if not 0 <= resolved < num_layers:
                raise ValueError(
                    f"unfreeze_layers index {index} is out of range for {num_layers} encoder layers"
                )
            trainable.add(resolved)

        for idx, layer in enumerate(layers):
            layer.requires_grad_(idx in trainable)

    def forward(self, x):
        """Return class logits for a batch of waveforms.

        A 3-D input has its dimension 1 squeezed before it is handed to
        HuBERT; the hidden states are then mean-pooled over dimension 1 and
        passed to the classifier head.
        """
        if x.dim() == 3:  # Handle channel dimension
            x = x.squeeze(1)
        outputs = self.hubert(x)
        features = outputs.last_hidden_state.mean(dim=1)
        return self.classifier(features)

# Load Dataset

In [None]:
# Held-out test split; the training loop scores every trained model on this loader.
test_dataset = AudioDataset(config.test_path, config)
test_loader = DataLoader(
    test_dataset,
    batch_size=config.batch_size,
    collate_fn=collate_fn,
    pin_memory=config.pin_memory,
)

In [None]:
# Add to top imports
import csv
from datetime import datetime
import json

class DiskMetricWriter:
    """Writes training and test metrics to files under one output directory.

    Every file name has the form ``<dataset>_<timestamp>_<kind><ext>``, where
    the timestamp is taken once when the writer is constructed.

    Args:
        output_dir: Directory for all files; created if missing. Defaults to
            the Kaggle working directory used by this notebook.
    """

    def __init__(self, output_dir="/kaggle/working/metrics"):
        self.output_dir = output_dir
        self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        os.makedirs(self.output_dir, exist_ok=True)

    def _get_path(self, dataset_name, metric_type, extension=".csv"):
        """Build the output path for one dataset / metric kind / file extension."""
        return f"{self.output_dir}/{dataset_name}_{self.timestamp}_{metric_type}{extension}"

    def write_epoch_metrics(self, dataset_name, epoch, train_loss, val_loss, train_acc, val_acc):
        """Append one epoch row to the training CSV, writing the header first if needed."""
        path = self._get_path(dataset_name, "training")
        # An existing but empty file still needs its header row.
        write_header = not os.path.exists(path) or os.path.getsize(path) == 0

        with open(path, 'a', newline='', encoding='utf-8') as f:
            writer = csv.writer(f)
            if write_header:
                writer.writerow(['epoch', 'train_loss', 'val_loss', 'train_acc', 'val_acc'])
            writer.writerow([epoch, train_loss, val_loss, train_acc, val_acc])

    def write_final_metrics(self, dataset_name, test_loss, test_acc, auc_score):
        """Write (overwriting) the ``metric,value`` CSV with loss, accuracy and AUC."""
        path = self._get_path(dataset_name, "final")
        with open(path, 'w', newline='', encoding='utf-8') as f:
            writer = csv.writer(f)
            writer.writerow(['metric', 'value'])
            writer.writerow(['test_loss', test_loss])
            writer.writerow(['test_accuracy', test_acc])
            writer.writerow(['auc', auc_score])

    def write_confusion_matrix(self, dataset_name, cm, classes):
        """Write the confusion matrix as a CSV with class names as row and column headers."""
        path = self._get_path(dataset_name, "confusion")
        with open(path, 'w', newline='', encoding='utf-8') as f:
            writer = csv.writer(f)
            writer.writerow([''] + list(classes))
            for i, row in enumerate(cm):
                writer.writerow([classes[i]] + list(row))

    def write_classification_report(self, dataset_name, report):
        """Write the already-formatted report text to a ``.txt`` file."""
        # Ask for the extension directly instead of str.replace on the whole
        # path, which would also rewrite ".csv" inside directory or dataset names.
        path = self._get_path(dataset_name, "classification", ".txt")
        with open(path, 'w', encoding='utf-8') as f:
            f.write(report)

    def save_test_predictions(self, dataset_name, probs, labels):
        """Save predicted probabilities and true labels as two ``.npy`` files."""
        np.save(self._get_path(dataset_name, "probs", ".npy"), probs)
        np.save(self._get_path(dataset_name, "labels", ".npy"), labels)


# Metrics Saving and Visualisation

In [None]:
class TestVisualizer:
    """Collects per-dataset test results and renders a CSV summary and PNG figures.

    Args:
        output_dir: Directory for the CSV and PNG files; created if missing.
            Defaults to the Kaggle working directory used by this notebook.
    """

    _SUMMARY_COLUMNS = ['dataset', 'accuracy', 'loss', 'auc']

    def __init__(self, output_dir="/kaggle/working/test_metrics"):
        self.output_dir = output_dir
        os.makedirs(self.output_dir, exist_ok=True)
        self.metrics = []

    def add_metrics(self, dataset_name, test_metrics, fpr, tpr, auc_score, cm):
        """Record one dataset's results.

        ``test_metrics`` must provide ``'accuracy'`` and ``'loss'``; ``fpr`` /
        ``tpr`` are the ROC curve points and ``cm`` the confusion matrix.
        """
        self.metrics.append({
            'dataset': dataset_name,
            'accuracy': test_metrics['accuracy'],
            'loss': test_metrics['loss'],
            'auc': auc_score,
            'fpr': fpr,
            'tpr': tpr,
            'cm': cm
        })

    def save_metrics_to_csv(self):
        """Write one row per recorded dataset to ``test_metrics_summary.csv``."""
        rows = [{key: m[key] for key in self._SUMMARY_COLUMNS} for m in self.metrics]
        # Explicit columns keep the header row even when nothing was recorded.
        df = pd.DataFrame(rows, columns=self._SUMMARY_COLUMNS)
        df.to_csv(f"{self.output_dir}/test_metrics_summary.csv", index=False)

    @staticmethod
    def _draw_confusion_matrix(ax, entry):
        """Draw one recorded confusion matrix as an annotated heatmap on ``ax``."""
        sns.heatmap(entry['cm'], annot=True, fmt='d',
                    xticklabels=['Real', 'Fake'],
                    yticklabels=['Real', 'Fake'],
                    ax=ax)
        ax.set_title(f'Confusion Matrix - {entry["dataset"]}')
        ax.set_ylabel('True Label')
        ax.set_xlabel('Predicted Label')

    def plot_all(self):
        """Save one confusion-matrix PNG per dataset plus a 2x2 summary figure.

        The summary shows ROC curves, accuracy bars, loss bars and, in the
        last cell, the confusion matrices side by side.
        """
        names = [m['dataset'] for m in self.metrics]
        positions = list(range(len(self.metrics)))

        # Each per-dataset file gets its own figure. Clearing a shared figure
        # between datasets would also erase the summary panels.
        for m in self.metrics:
            cm_fig, cm_ax = plt.subplots(figsize=(6, 5))
            self._draw_confusion_matrix(cm_ax, m)
            cm_fig.tight_layout()
            cm_fig.savefig(f"{self.output_dir}/confusion_matrix_{m['dataset']}.png")
            plt.close(cm_fig)

        fig = plt.figure(figsize=(20, 15))
        grid = fig.add_gridspec(2, 2)

        # 1. AUC-ROC Curves
        roc_ax = fig.add_subplot(grid[0, 0])
        for m in self.metrics:
            roc_ax.plot(m['fpr'], m['tpr'], label=f"{m['dataset']} (AUC = {m['auc']:.2f})")
        roc_ax.plot([0, 1], [0, 1], 'k--')
        roc_ax.set_xlabel('False Positive Rate')
        roc_ax.set_ylabel('True Positive Rate')
        roc_ax.set_title('ROC Curves Comparison')
        if self.metrics:  # a legend without labelled curves only triggers a warning
            roc_ax.legend()

        # 2. Accuracy Comparison
        acc_ax = fig.add_subplot(grid[0, 1])
        acc_ax.bar(positions, [m['accuracy'] for m in self.metrics])
        acc_ax.set_xticks(positions)
        acc_ax.set_xticklabels(names)
        acc_ax.set_ylim(0, 1)
        acc_ax.set_ylabel('Accuracy')
        acc_ax.set_title('Test Accuracy Comparison')

        # 3. Loss Comparison
        loss_ax = fig.add_subplot(grid[1, 0])
        loss_ax.bar(positions, [m['loss'] for m in self.metrics])
        loss_ax.set_xticks(positions)
        loss_ax.set_xticklabels(names)
        loss_ax.set_ylabel('Loss')
        loss_ax.set_title('Test Loss Comparison')

        # 4. Confusion Matrices: split the last cell into one column per dataset.
        cm_grid = grid[1, 1].subgridspec(1, max(len(self.metrics), 1))
        for col, m in enumerate(self.metrics):
            self._draw_confusion_matrix(fig.add_subplot(cm_grid[0, col]), m)

        fig.tight_layout()
        fig.savefig(f"{self.output_dir}/test_performance_summary.png")
        plt.close(fig)

# Module-level visualizer; the training loop feeds it through add_metrics()
# and a later cell calls save_metrics_to_csv() / plot_all() on it.
test_visualizer = TestVisualizer()


In [None]:
# Shared metric writer; its timestamp is fixed here at construction and
# becomes part of every file name it writes afterwards.
metric_writer = DiskMetricWriter()

# Evaluation

In [None]:
def evaluate_model(model, loader, device, criterion=None):
    """Run the model over a loader in eval mode and collect predictions.

    Args:
        model: Callable module mapping a batch of inputs to class logits.
        loader: Iterable of ``(inputs, labels)`` tensor batches.
        device: Device the batches are moved to before the forward pass.
        criterion: Optional loss taking ``(outputs, labels)``; when given, the
            mean of the per-batch losses is reported.

    Returns:
        dict with numpy arrays ``'labels'``, ``'preds'`` (argmax class),
        ``'probs'`` (softmax over dim 1) and ``'accuracy'``; plus ``'loss'``
        when a criterion was supplied. With no batches, accuracy and loss are
        NaN instead of raising.
    """
    model.eval()
    # Compare against None explicitly: truthiness of a module may depend on
    # a __len__ it happens to define.
    has_criterion = criterion is not None

    label_chunks = []
    pred_chunks = []
    prob_chunks = []
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)

            if has_criterion:
                total_loss += criterion(outputs, labels).item()
            num_batches += 1

            probs = torch.softmax(outputs, dim=1)
            _, preds = torch.max(outputs, 1)

            label_chunks.append(labels.cpu().numpy())
            pred_chunks.append(preds.cpu().numpy())
            prob_chunks.append(probs.cpu().numpy())

    def _merge(chunks):
        # np.concatenate rejects an empty list, so fall back to an empty array.
        return np.concatenate(chunks) if chunks else np.array([])

    metrics = {
        'labels': _merge(label_chunks),
        'preds': _merge(pred_chunks),
        'probs': _merge(prob_chunks)
    }

    # Accuracy needs only labels and predictions, so it is always reported.
    if metrics['labels'].size:
        metrics['accuracy'] = np.mean(metrics['labels'] == metrics['preds'])
    else:
        metrics['accuracy'] = float('nan')

    if has_criterion:
        metrics['loss'] = total_loss / num_batches if num_batches else float('nan')

    return metrics


# Training

In [None]:
# Train a fresh model on each configured dataset, then score it once on the shared test set.
for dataset_name, train_path in config.datasets.items():
    print(f"\n{'='*40}")
    print(f"Training on {dataset_name}")
    print(f"{'='*40}")

    # Load the data before allocating a model, so an unusable dataset is
    # skipped cheaply instead of failing later with a division by zero.
    full_dataset = AudioDataset(train_path, config, augment=True)
    if len(full_dataset) == 0:
        print(f"Warning: no valid audio found for {dataset_name}, skipping")
        continue

    train_size = int(config.train_val_split * len(full_dataset))
    train_dataset, val_dataset = random_split(
        full_dataset, [train_size, len(full_dataset) - train_size]
    )

    train_loader = DataLoader(train_dataset, batch_size=config.batch_size,
                              shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size,
                            collate_fn=collate_fn)

    # Fresh model and optimizer per dataset; the classifier head gets a
    # larger learning rate than the HuBERT backbone.
    model = HuBERTClassifier(config).to(config.device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam([
        {'params': model.hubert.parameters(), 'lr': config.base_lr},
        {'params': model.classifier.parameters(),
         'lr': config.base_lr * config.classifier_lr_multiplier}
    ])

    best_val_acc = 0.0

    for epoch in range(config.num_epochs):
        model.train()
        epoch_loss = 0.0
        correct = 0
        total = 0

        # Training
        for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
            inputs, labels = inputs.to(config.device), labels.to(config.device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), config.gradient_clip)
            optimizer.step()

            epoch_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        # Validation
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(config.device), labels.to(config.device)
                outputs = model(inputs)
                val_loss += criterion(outputs, labels).item()

                _, predicted = torch.max(outputs.data, 1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()

        # Per-epoch averages; a tiny dataset can leave a split empty, so
        # guard every denominator.
        train_loss = epoch_loss / max(len(train_loader), 1)
        train_acc = correct / total if total else 0.0
        val_loss = val_loss / max(len(val_loader), 1)
        val_acc = val_correct / val_total if val_total else 0.0

        # Write to disk immediately so partial runs still leave a curve behind
        metric_writer.write_epoch_metrics(
            dataset_name=dataset_name,
            epoch=epoch+1,
            train_loss=train_loss,
            val_loss=val_loss,
            train_acc=train_acc,
            val_acc=val_acc
        )

        # Exponential learning-rate decay with a floor, applied to every group
        for param_group in optimizer.param_groups:
            param_group['lr'] = max(param_group['lr'] * config.lr_decay_per_epoch,
                                    config.min_lr)

        # Keep the checkpoint with the best validation accuracy so far
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), f"best_model_{dataset_name}.pth")

        print(f"Epoch {epoch+1}/{config.num_epochs}")
        print(f"Train Loss: {train_loss:.4f} | Acc: {train_acc:.2%}")
        print(f"Val Loss: {val_loss:.4f} | Acc: {val_acc:.2%}\n")

    # Final test evaluation, run once and reused for every output below.
    # NOTE(review): this scores the final-epoch weights, not the saved
    # best_model_*.pth checkpoint — confirm that is intended.
    test_metrics = evaluate_model(model, test_loader, config.device, criterion)

    metric_writer.save_test_predictions(
        dataset_name=dataset_name,
        probs=test_metrics['probs'],
        labels=test_metrics['labels']
    )

    # ROC/AUC use the probability of class 1; the confusion matrix uses the argmax class
    fpr, tpr, _ = roc_curve(test_metrics['labels'], test_metrics['probs'][:, 1])
    auc_score = auc(fpr, tpr)
    cm = confusion_matrix(test_metrics['labels'], np.argmax(test_metrics['probs'], axis=1))

    # Store metrics for visualization
    test_visualizer.add_metrics(dataset_name, test_metrics, fpr, tpr, auc_score, cm)

    metric_writer.write_final_metrics(
        dataset_name=dataset_name,
        test_loss=test_metrics['loss'],
        test_acc=test_metrics['accuracy'],
        auc_score=auc_score
    )

    # Release model and data before the next dataset is loaded
    del model, criterion, optimizer, train_dataset, val_dataset, full_dataset
    torch.cuda.empty_cache()

In [None]:
# After all datasets are processed: write the per-dataset test summary CSV,
# then render the comparison figures from the same recorded metrics.
test_visualizer.save_metrics_to_csv()
test_visualizer.plot_all()

In [None]:
def _latest_metric_file(metric_dir, dataset, suffix):
    """Return the newest ``<metric_dir>/<dataset>_*_<suffix>`` path by ctime, or None."""
    import glob
    import os

    candidates = glob.glob(f"{metric_dir}/{dataset}_*_{suffix}")
    return max(candidates, key=os.path.getctime) if candidates else None


def _final_metric_value(df_final, metric_name):
    """Look up a single value in a two-column ``metric,value`` frame."""
    return df_final.loc[df_final['metric'] == metric_name, 'value'].values[0]


def _legend_if_labelled(ax):
    """Add a legend only when the axes holds at least one labelled artist."""
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend()


def plot_comparative_curves(metric_dir="/kaggle/working/metrics", datasets=None):
    """Render a 2x2 comparison figure from the metric files in ``metric_dir``.

    Panels: training/validation loss per epoch, training/validation accuracy
    per epoch, test accuracy and loss bars, and ROC curves. For each dataset
    the most recently created training CSV and final-metrics CSV are used.
    A dataset whose files are missing or unreadable is reported and skipped.
    The figure is saved as ``<metric_dir>/comparison_plot.png``.

    Args:
        metric_dir: Directory holding the metric files; created if missing.
        datasets: Dataset names to plot. Defaults to the keys of the
            module-level ``config.datasets``.
    """
    import os
    import pandas as pd
    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn.metrics import roc_curve

    # Create directory if not exists
    os.makedirs(metric_dir, exist_ok=True)

    if datasets is None:
        datasets = list(config.datasets.keys())

    # Read every CSV once up front; the panels below reuse the loaded frames.
    training_frames = {}
    final_frames = {}  # dataset -> (path of the final CSV, loaded frame)
    for dataset in datasets:
        try:
            train_file = _latest_metric_file(metric_dir, dataset, "training.csv")
            if train_file is None:
                print(f"No training files found for {dataset}, skipping")
            else:
                training_frames[dataset] = pd.read_csv(train_file)
        except Exception as e:
            print(f"Error loading training metrics for {dataset}: {str(e)}")

        try:
            final_file = _latest_metric_file(metric_dir, dataset, "final.csv")
            if final_file is not None:
                final_frames[dataset] = (final_file, pd.read_csv(final_file))
        except Exception as e:
            print(f"Error loading test metrics for {dataset}: {str(e)}")

    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    loss_ax, acc_ax, test_ax, roc_ax = axes.ravel()

    # 1. Loss Comparison
    for dataset, df in training_frames.items():
        try:
            loss_ax.plot(df['epoch'], df['train_loss'], label=f'{dataset} Train')
            loss_ax.plot(df['epoch'], df['val_loss'], '--', label=f'{dataset} Val')
        except Exception as e:
            print(f"Error plotting {dataset} loss: {str(e)}")
    loss_ax.set_title('Loss Comparison')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    _legend_if_labelled(loss_ax)

    # 2. Accuracy Comparison
    for dataset, df in training_frames.items():
        try:
            acc_ax.plot(df['epoch'], df['train_acc'], label=f'{dataset} Train')
            acc_ax.plot(df['epoch'], df['val_acc'], '--', label=f'{dataset} Val')
        except Exception as e:
            print(f"Error plotting {dataset} accuracy: {str(e)}")
    acc_ax.set_title('Accuracy Comparison')
    acc_ax.set_xlabel('Epoch')
    acc_ax.set_ylabel('Accuracy')
    _legend_if_labelled(acc_ax)

    # 3. Test Metrics Comparison
    metrics_list = []
    for dataset, (_, df_test) in final_frames.items():
        try:
            metrics_list.append({
                'dataset': dataset,
                'accuracy': _final_metric_value(df_test, 'test_accuracy'),
                'loss': _final_metric_value(df_test, 'test_loss')
            })
        except Exception as e:
            print(f"Error loading test metrics for {dataset}: {str(e)}")

    if metrics_list:
        metrics_df = pd.DataFrame(metrics_list)
        x = np.arange(len(metrics_df))
        width = 0.35

        test_ax.bar(x - width/2, metrics_df['accuracy'], width, label='Accuracy')
        test_ax.bar(x + width/2, metrics_df['loss'], width, label='Loss')
        test_ax.set_xticks(x)
        test_ax.set_xticklabels(metrics_df['dataset'])
        test_ax.set_title('Test Set Performance')
        test_ax.legend()

    # 4. ROC Comparison
    for dataset, (final_path, df_test) in final_frames.items():
        try:
            # Take predictions from the same run as the final-metrics file:
            # all files of one run share the "<dataset>_<timestamp>" prefix.
            run_prefix = final_path[:-len("_final.csv")]
            probs_path = f"{run_prefix}_probs.npy"
            labels_path = f"{run_prefix}_labels.npy"
            if not (os.path.exists(probs_path) and os.path.exists(labels_path)):
                continue

            auc_score = _final_metric_value(df_test, 'auc')
            probs = np.load(probs_path)
            labels = np.load(labels_path)

            fpr, tpr, _ = roc_curve(labels, probs[:, 1])
            roc_ax.plot(fpr, tpr, label=f'{dataset} (AUC = {auc_score:.2f})')
        except Exception as e:
            print(f"Error plotting ROC for {dataset}: {str(e)}")

    roc_ax.plot([0, 1], [0, 1], 'k--')
    roc_ax.set_title('ROC Curve Comparison')
    roc_ax.set_xlabel('False Positive Rate')
    roc_ax.set_ylabel('True Positive Rate')
    _legend_if_labelled(roc_ax)

    fig.tight_layout()
    fig.savefig(f'{metric_dir}/comparison_plot.png')
    plt.close(fig)
    print(f"Comparison plot saved to {metric_dir}/comparison_plot.png")


In [None]:
# Generate comparison plots from the metric files in the default metrics directory
plot_comparative_curves()



# Download Files

In [None]:
# zip metrics
# zip test_metrics

# download link for best_model_{dataset_name}.pth and all metrics

In [None]:
# Archive the directory DiskMetricWriter writes to by default
!zip -r metrics.zip /kaggle/working/metrics

In [None]:
!zip -r test_metrics.zip /kaggle/working/test-metrics

In [None]:
from IPython.display import FileLink
# Checkpoint written by the training loop for a dataset named "fake92";
# that entry is commented out in Config.datasets, so the file exists only once it is enabled.
FileLink(r'best_model_fake92.pth')

In [None]:
# Checkpoint for a dataset named "real90" (commented out in Config.datasets)
FileLink(r'best_model_real90.pth')

In [None]:
# Checkpoint for a dataset named "balanced" (commented out in Config.datasets)
FileLink(r'best_model_balanced.pth')

In [None]:
# Archive produced by the `zip -r metrics.zip` cell
FileLink(r'metrics.zip')

In [None]:
# Archive produced by the `zip -r test_metrics.zip` cell
FileLink(r'test_metrics.zip')