## Install and Import Dependencies

In [None]:
!pip install torch torchaudio transformers datasets scikit-learn soundfile torchvision
!pip install -U ray
!pip install -U "flwr[simulation]==1.15.2"
!pip install audiomentations optuna opacus ipython

In [None]:
import os
import glob
import random
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import logging
import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio
from torch.utils.data import Dataset, DataLoader, Subset
import flwr as fl
from transformers import HubertModel, HubertConfig
import ray

## Setup and Values

In [None]:
class Config:
    """Experiment settings shared by the dataset, the model and the federated clients.

    Attributes are plain values; nothing is validated here. A module-level
    ``config`` instance is created right after the class definition.
    """

    def __init__(self):
        # Dataset roots. `main()` reads both dirs["train"] and dirs["test"], so
        # exactly one "train" entry must stay active. The alternatives are kept
        # as comments so the training split can be switched by hand.
        self.dirs = {
            # "train": "/kaggle/input/imbalanceddataset/real90",
            # "train": "/kaggle/input/imbalanceddataset/fake92",  # use it for fake92
            "train": "/kaggle/input/fakes-and-reals/audio_train/audio_train",  # balanced (default)
            "test": "/kaggle/input/fakes-and-reals/audio_test/audio_test",
        }
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Federated schedule
        self.num_clients = 4
        self.num_rounds = 20
        self.epochs_per_round = 1
        self.batch_size = 8

        # Audio pre-processing: clips are resampled to `sample_rate` and
        # trimmed/padded to `max_length` samples.
        self.sample_rate = 16000
        self.max_length = 16000

        # Folder name -> integer label; upper- and lower-case folder names are both accepted.
        self.label_mapping = {"real": 0, "fake": 1, "REAL":0, "FAKE":1}

        self.unfreeze_layers = [-1]  # Last transformer layer

        # Optimiser schedule: the per-round rate is base_lr * lr_decay**round,
        # floored at min_lr; the classifier head uses a multiple of the base rate.
        self.base_lr = 1e-5
        self.classifier_lr_multiplier = 10
        self.lr_decay = 0.95
        self.min_lr = 1e-7

config = Config()

In [None]:
from audiomentations import Compose, AddGaussianNoise, TimeStretch, PitchShift, Shift
import torchaudio

class AudioAugmenter:
    """Callable that runs a waveform tensor through a random audiomentations chain."""

    def __init__(self, sample_rate=16000):
        self.sample_rate = sample_rate
        # Every transform is applied independently with probability 0.2.
        transforms = [
            AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.005, p=0.2),
            TimeStretch(min_rate=0.95, max_rate=1.05, p=0.2),  # conservative stretching
            PitchShift(min_semitones=-1, max_semitones=1, p=0.2),  # narrow semitone range
            Shift(min_shift=-0.1, max_shift=0.1, p=0.2),
        ]
        self.augment = Compose(transforms)

    def __call__(self, waveform):
        """Augment ``waveform`` and return it as a float tensor of shape [1, time].

        A 2-D input is averaged over its first axis to obtain a mono signal; a
        1-D input is used as is. Any other rank raises ``ValueError``.
        """
        samples = waveform.numpy()

        # Reject anything that is neither [time] nor [channels, time].
        if samples.ndim not in (1, 2):
            raise ValueError(f"Unexpected waveform shape: {samples.shape}")

        # Down-mix multi-channel audio by averaging the channels.
        if samples.ndim == 2:
            samples = np.mean(samples, axis=0)

        augmented = self.augment(samples=samples, sample_rate=self.sample_rate)

        # Restore the leading channel axis expected downstream.
        mono = torch.from_numpy(augmented)
        return mono.unsqueeze(0).float()

## Loading Dataset

> **The Audiomentations library is used to add noise and other impairments to our audio samples. To train on the original audio samples, comment out the lines marked `#Impairments` in the code below.**

In [None]:
class AudioDataset(Dataset):
    """Labelled audio clips read from ``root_dir/<label_name>/`` folders.

    Every item is a ``(waveform, label)`` pair where ``waveform`` is a 1-D
    mono tensor of exactly ``config.max_length`` samples at
    ``config.sample_rate`` and ``label`` is the integer taken from
    ``config.label_mapping``.

    Args:
        root_dir: Directory containing one sub-folder per key of
            ``config.label_mapping``. Missing sub-folders are skipped with a warning.
        config: Settings object providing ``label_mapping``, ``sample_rate``
            and ``max_length``.
        augment: When true, each loaded clip is passed through
            ``AudioAugmenter`` before trimming/padding.
    """

    def __init__(self, root_dir, config, augment=False):
        self.config = config
        self.file_list = []
        self.labels = []
        self.augment = augment #Impairments
        self._load_data(root_dir)
        self.augmenter = AudioAugmenter(sample_rate=config.sample_rate) if augment else None

    def _load_data(self, root_dir):
        """Collect valid audio files and their labels, then shuffle them deterministically."""
        # Use the config given to this dataset rather than the module-level
        # one, so a dataset built with a different config stays self-consistent.
        for label_name, label in self.config.label_mapping.items():
            folder = os.path.join(root_dir, label_name)
            if not os.path.exists(folder):
                print(f"Warning: Missing directory: {folder}")
                continue

            files = glob.glob(os.path.join(folder, "*.*"))
            print(f"Found {len(files)} files in {folder}")

            for file in files:
                if self._is_valid_audio(file):
                    self.file_list.append(file)
                    self.labels.append(label)
                else:
                    print(f"Warning: Skipping invalid file: {file}")

        # Shuffle with seed for reproducibility (note: this reseeds the global
        # `random` generator).
        random.seed(42)
        combined = list(zip(self.file_list, self.labels))
        random.shuffle(combined)
        # zip(*...) yields tuples; the empty case keeps two empty lists.
        self.file_list, self.labels = zip(*combined) if combined else ([], [])

    def _is_valid_audio(self, file_path):
        """Return True if the file is non-empty, loadable, mono/stereo and >= 100 samples."""
        try:
            # Check file size
            if os.path.getsize(file_path) == 0:
                print(f"Empty file: {file_path}")
                return False

            # Try loading the file
            waveform, sr = torchaudio.load(file_path)
            if waveform.nelement() == 0:
                return False
            if waveform.shape[0] not in [1, 2]:  # Mono or stereo
                return False
            if waveform.shape[1] < 100:  # Minimum 100 samples
                print(f"Short audio: {file_path} ({waveform.shape[1]} samples)")
                return False
            return True
        except Exception as e:
            print(f"Error loading {file_path}: {str(e)}")
            return False

    def __getitem__(self, idx):
        """Load, resample, optionally augment and length-normalise one clip.

        Returns:
            ``(waveform, label)`` with ``waveform`` of shape ``[max_length]``.
            If anything fails, the error is printed and an all-zero waveform
            of the same shape with label 0 is returned instead.
        """
        try:
            waveform, sr = torchaudio.load(self.file_list[idx])

            # Resample if necessary
            if sr != self.config.sample_rate:
                resampler = torchaudio.transforms.Resample(sr, self.config.sample_rate)
                waveform = resampler(waveform)

            if self.augment and self.augmenter: #Impairments
                waveform = self.augmenter(waveform) #Impairments

            # Convert to mono and process
            waveform = self._process_waveform(waveform)
            label = self.labels[idx]

            return waveform.squeeze(0), label
        except Exception as e:
            print(f"Error processing {self.file_list[idx]}: {str(e)}")
            # The dummy item must have the same 1-D shape as a real item;
            # otherwise stacking a batch that mixes both raises a size mismatch.
            return torch.zeros(self.config.max_length), 0

    def _process_waveform(self, waveform):
        """Return ``waveform`` as a mono tensor of shape ``[1, max_length]``."""
        # Convert to 2D if needed
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)  # [1, time]
        elif waveform.dim() > 2:
            waveform = waveform.view(-1, waveform.size(-1))  # Flatten to 2D

        # Convert to mono
        if waveform.size(0) > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        # Trim/pad to exact length
        if waveform.size(1) > self.config.max_length:
            waveform = waveform[:, :self.config.max_length]
        else:
            pad_amount = self.config.max_length - waveform.size(1)
            waveform = torch.nn.functional.pad(waveform, (0, pad_amount))

        return waveform  # Guaranteed [1, max_length]

    def __len__(self):
        """Returns the total number of samples in the dataset"""
        return len(self.file_list)

## HuBERT Classification

In [None]:
class HuBERTClassifier(nn.Module):
    """Pre-trained HuBERT encoder followed by a small MLP classification head.

    Args:
        config: Settings object providing ``unfreeze_layers`` (encoder layer
            indices to keep trainable; negative values count from the end)
            and ``label_mapping`` (its length sets the number of output logits).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Load pre-trained HuBERT
        self.hubert = HubertModel.from_pretrained("facebook/hubert-base-ls960")
        self._freeze_layers()

        # Classification head
        # NOTE(review): the output width is the number of *keys* in
        # label_mapping. If the mapping holds several aliases for the same
        # label, this is larger than the number of distinct classes — confirm
        # whether len(set(label_mapping.values())) is what is wanted.
        self.classifier = nn.Sequential(
            nn.Linear(768, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, len(config.label_mapping))
        )

    def _freeze_layers(self):
        """Freeze every encoder layer except those listed in ``config.unfreeze_layers``.

        Negative indices are resolved relative to the number of encoder layers
        (``-1`` is the last layer), mirroring Python list indexing. Indices
        outside the valid range match no layer. Only ``hubert.encoder.layers``
        is touched; other HuBERT sub-modules keep their current
        ``requires_grad`` state.
        """
        layers = self.hubert.encoder.layers
        total_layers = len(layers)
        print(f"Total HuBERT layers: {total_layers}")

        # enumerate() only yields 0..total_layers-1, so a raw negative index
        # would never match and the layer would be frozen by mistake.
        trainable = {
            idx + total_layers if idx < 0 else idx
            for idx in self.config.unfreeze_layers
        }

        for i, layer in enumerate(layers):
            if i in trainable:
                print(f"Unfreezing layer {i}")
                continue
            for param in layer.parameters():
                param.requires_grad = False

    def forward(self, input_values):
        """Return class logits for a batch of waveforms.

        Accepts ``[batch, time]`` input, or ``[batch, 1, time]`` input whose
        singleton channel axis is squeezed away first.
        """
        # Convert to HuBERT-compatible 2D [batch, time]
        if input_values.dim() == 3:
            input_values = input_values.squeeze(1)  # Remove channel dim

        outputs = self.hubert(input_values)
        # Mean-pool the frame-level hidden states over the time axis.
        pooled_output = outputs.last_hidden_state.mean(dim=1)
        return self.classifier(pooled_output)

In [None]:
def collate_fn(batch):
    """Stack ``(waveform, label)`` pairs into batch tensors, dropping ``None`` entries.

    Returns a ``(waveforms, labels)`` tuple: the waveforms stacked along a new
    leading axis and the labels as a ``torch.long`` tensor. When nothing is
    left after filtering, an empty ``[0, 1, max_length]`` waveform tensor and
    an empty label tensor are returned.
    """
    usable = [item for item in batch if item is not None]

    if len(usable) == 0:
        empty_waveforms = torch.zeros((0, 1, config.max_length))
        empty_labels = torch.zeros(0, dtype=torch.long)
        return empty_waveforms, empty_labels

    # Transpose the list of pairs into a tuple of waveforms and a tuple of labels.
    waveform_items, label_items = zip(*usable)

    stacked = torch.stack(waveform_items)
    label_tensor = torch.tensor(label_items, dtype=torch.long)
    return stacked, label_tensor

## Metrics Saving

In [None]:
import json
from pathlib import Path
import matplotlib.pyplot as plt

class DiskMetricsCollector:
    """Persists per-round, per-client training metrics in a single JSON file.

    The file ``<base_dir>/all_metrics.json`` has the layout
    ``{"round_<n>": {"client_<id>": {"loss": [...], "accuracy": [...]}}}``.
    An existing file is reused, so metrics from earlier runs stay in it.

    Note: ``add_metrics`` is a plain read-modify-write with no locking.
    """

    def __init__(self, base_dir="/kaggle/working/metrics"):
        self.base_dir = Path(base_dir)
        self.base_dir.mkdir(parents=True, exist_ok=True)
        self.metrics_file = self.base_dir / "all_metrics.json"

        # Initialize empty metrics structure if file doesn't exist
        if not self.metrics_file.exists():
            with open(self.metrics_file, "w") as f:
                json.dump({}, f)

    def add_metrics(self, round_num, client_id, metrics):
        """Store ``metrics["loss"]`` and ``metrics["accuracy"]`` for one client and round.

        Any previous entry for the same round/client is overwritten; other
        keys in ``metrics`` are ignored.
        """
        # Load existing data
        with open(self.metrics_file, "r") as f:
            all_metrics = json.load(f)

        # Create the round entry if needed, then add the client metrics
        round_entry = all_metrics.setdefault(f"round_{round_num}", {})
        round_entry[f"client_{client_id}"] = {
            "loss": metrics["loss"],
            "accuracy": metrics["accuracy"]
        }

        # Save back to file
        with open(self.metrics_file, "w") as f:
            json.dump(all_metrics, f, indent=2)

    def plot_round(self, round_num):
        """Plot per-epoch loss and accuracy of every client for one round.

        Prints a message and returns without plotting when the round has no
        stored metrics.
        """
        with open(self.metrics_file, "r") as f:
            all_metrics = json.load(f)

        round_key = f"round_{round_num}"
        if round_key not in all_metrics:
            print(f"No metrics for round {round_num}")
            return

        fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 6))

        for client_key, metrics in all_metrics[round_key].items():
            # Keys look like "client_<id>"; split once so ids containing "_" survive.
            client_id = client_key.split("_", 1)[1]
            epochs = range(1, len(metrics["loss"]) + 1)
            loss_ax.plot(epochs, metrics["loss"], label=f'Client {client_id}')
            acc_ax.plot(epochs, metrics["accuracy"], label=f'Client {client_id}')

        loss_ax.set_title(f'Round {round_num} - Loss')
        loss_ax.set_xlabel('Epoch')
        loss_ax.set_ylabel('Loss')

        acc_ax.set_title(f'Round {round_num} - Accuracy')
        acc_ax.set_xlabel('Epoch')
        acc_ax.set_ylabel('Accuracy')

        # A legend is attached to each panel explicitly; a single plt.legend()
        # call would only label whichever axes happened to be current.
        for ax in (loss_ax, acc_ax):
            if ax.get_legend_handles_labels()[0]:
                ax.legend()

        fig.tight_layout()
        plt.show()
        plt.close(fig)

## Federated Setup

In [None]:
import gc
class FlowerClient(fl.client.NumPyClient):
    """Flower NumPy client that trains and evaluates a model on one data partition.

    Args:
        model: Module exposing ``hubert`` and ``classifier`` sub-modules (each
            gets its own optimiser parameter group).
        train_data: Dataset used by ``fit``; batched with ``collate_fn``,
            shuffled, incomplete last batch dropped.
        test_data: Dataset used by ``evaluate``; batched with ``collate_fn``.
        config: Settings object providing ``device``, ``batch_size``,
            ``epochs_per_round``, ``base_lr``, ``classifier_lr_multiplier``,
            ``lr_decay`` and ``min_lr``.
        client_id: Identifier used in error messages and stored metrics.
        metrics_collector: Object with an
            ``add_metrics(round_num, client_id, metrics)`` method.
    """

    def __init__(self, model, train_data, test_data, config, client_id,metrics_collector):
        self.client_id = client_id
        self.model = model.to(config.device)
        self.config = config
        self.server_round = 0
        self.train_loader = DataLoader(
            train_data,
            batch_size=config.batch_size,
            shuffle=True,
            collate_fn=collate_fn,
            drop_last=True
        )
        self.test_loader = DataLoader(
            test_data,
            batch_size=config.batch_size,
            collate_fn=collate_fn
        )
        self.optimizer = self._create_optimizer()
        self.criterion = nn.CrossEntropyLoss()
        self.metrics_collector = metrics_collector

    def _create_optimizer(self):
        """Build an Adam optimiser whose learning rates decay with ``self.server_round``.

        The decay factor is ``lr_decay ** server_round``, floored so that the
        encoder rate never drops below ``min_lr``. The classifier head uses
        ``classifier_lr_multiplier`` times the encoder rate.
        """
        decay_factor = max(
            self.config.lr_decay ** self.server_round,
            self.config.min_lr / self.config.base_lr
        )

        params = [
            {
                "params": self.model.hubert.parameters(),
                "lr": self.config.base_lr * decay_factor,
                "weight_decay": 1e-5
            },
            {
                "params": self.model.classifier.parameters(),
                "lr": (self.config.base_lr * self.config.classifier_lr_multiplier) * decay_factor,
                "weight_decay": 1e-4
            }
        ]

        return optim.Adam(
            params,
            eps=1e-7,
            amsgrad=True  # Improved convergence stability
        )

    def get_parameters(self, config):
        """Return every ``state_dict`` entry of the model as a NumPy array, in order."""
        return [val.cpu().numpy() for _, val in self.model.state_dict().items()]

    def set_parameters(self, parameters):
        """Load ``parameters`` (arrays in ``state_dict`` order) into the model.

        Raises:
            ValueError: If any received array contains NaN.
        """
        if any(np.isnan(p).any() for p in parameters):
            raise ValueError(f"Client {self.client_id} received NaN parameters")

        params_dict = zip(self.model.state_dict().keys(), parameters)
        state_dict = {
            k: torch.tensor(v).to(self.config.device)
            for k, v in params_dict
        }
        self.model.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config):
        """Train locally for ``epochs_per_round`` epochs starting from ``parameters``.

        Args:
            parameters: Model weights sent by the server.
            config: Per-round dict from the server; ``"server_round"`` (default
                1) drives the learning-rate decay and the metrics key.

        Returns:
            ``(parameters, num_examples, metrics)``. ``num_examples`` is the
            size of the training partition on success. It is 0 when the
            partition is empty, when the loss becomes NaN (the model is reset
            to the received weights first) or when training raises.
        """
        self.set_parameters(parameters)
        self.server_round = config.get("server_round", 1)
        self.optimizer = self._create_optimizer()  # Update optimizer for current round
        self.model.train()

        # Track per-epoch metrics
        epoch_losses = []
        epoch_accuracies = []

        try:
            if len(self.train_loader.dataset) == 0:
                print("Client has no training data!")
                return (
                    self.get_parameters({}),
                    0,
                    {"loss": 0.0, "accuracy": 0.0}
                )

            for epoch in range(self.config.epochs_per_round):
                epoch_loss = 0.0
                correct = 0
                total = 0
                batches_used = 0  # batches that actually contributed a loss
                progress_bar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}")

                for data, targets in progress_bar:
                    data, targets = data.to(self.config.device), targets.to(self.config.device)

                    self.optimizer.zero_grad()
                    outputs = self.model(data)

                    if torch.isnan(outputs).any():
                        print("NaN outputs detected, skipping batch")
                        continue
                    loss = self.criterion(outputs, targets)

                    if torch.isnan(loss):
                        print("NaN loss detected, resetting parameters")
                        self.set_parameters(parameters)  # Reset to server parameters
                        return self.get_parameters({}), 0, {}

                    loss.backward()

                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(),
                        max_norm=1.0,
                        norm_type=2.0
                    )

                    self.optimizer.step()

                    epoch_loss += loss.item()
                    preds = outputs.argmax(dim=1)
                    correct += (preds == targets).sum().item()
                    total += targets.size(0)
                    batches_used += 1

                    progress_bar.set_postfix(loss=loss.item())

                    del outputs, loss
                    torch.cuda.empty_cache()
                    gc.collect()

                # Average over the batches that were really used, so skipped
                # (NaN) batches do not drag the reported loss towards zero.
                if batches_used > 0 and total > 0:
                    avg_loss = epoch_loss / batches_used
                    accuracy = correct / total
                else:  # Handle empty/corrupted data cases
                    avg_loss = 0.0
                    accuracy = 0.0

                epoch_losses.append(avg_loss)
                epoch_accuracies.append(accuracy)

        except Exception as e:
            # Best effort: report the failure and hand back the current weights
            # with zero examples.
            print(f"Training error: {str(e)}")
            return self.get_parameters({}), 0, {}

        # Store metrics
        self.metrics_collector.add_metrics(
            self.server_round,  # Use the server_round from config
            self.client_id,
            {'loss': epoch_losses, 'accuracy': epoch_accuracies}
        )

        # Calculate averages
        avg_loss = sum(epoch_losses)/len(epoch_losses) if epoch_losses else 0.0
        avg_accuracy = sum(epoch_accuracies)/len(epoch_accuracies) if epoch_accuracies else 0.0

        return (
            self.get_parameters({}),
            len(self.train_loader.dataset),
            {"loss": float(avg_loss), "accuracy": float(avg_accuracy)}
        )

    def evaluate(self, parameters, config):
        """Evaluate ``parameters`` on the local test data.

        Returns:
            ``(loss, num_examples, metrics)`` where ``loss`` is the mean
            per-batch loss and ``metrics`` holds ``"loss"`` and
            ``"accuracy"``. An empty test set yields ``(0.0, 0, ...)``
            instead of a division by zero.
        """
        self.set_parameters(parameters)
        self.model.eval()

        num_examples = len(self.test_loader.dataset)
        num_batches = len(self.test_loader)
        if num_examples == 0 or num_batches == 0:
            return 0.0, 0, {"loss": 0.0, "accuracy": 0.0}

        total_loss = 0
        correct = 0
        with torch.no_grad():
            for data, targets in self.test_loader:
                data, targets = data.to(self.config.device), targets.to(self.config.device)
                outputs = self.model(data)
                total_loss += self.criterion(outputs, targets).item()
                preds = outputs.argmax(dim=1)
                correct += (preds == targets).sum().item()

        accuracy = correct / num_examples
        avg_loss = total_loss / num_batches
        return (
            float(avg_loss),
            num_examples,
            {
                "loss":float(avg_loss),
                "accuracy": float(accuracy)
            })

In [None]:
class CustomFedAvg(fl.server.strategy.FedAvg):
    """FedAvg variant that remembers the most recent successfully aggregated parameters.

    Args:
        metrics_collector: Stored on the instance; not used by this class itself.
        initial_parameters: Starting parameters; also used as the first value
            of ``current_parameters``.
        *args, **kwargs: Forwarded to ``FedAvg.__init__``.
    """

    def __init__(self, metrics_collector, initial_parameters=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.metrics_collector = metrics_collector
        self.current_parameters = initial_parameters  # Initialize with provided parameters
        self.initial_parameters = initial_parameters  # Store initial parameters

    def aggregate_fit(self, server_round, results, failures):
        """Run the parent aggregation and cache its parameters when it produced any.

        The parent's return value is passed through unchanged.
        """
        # Call parent aggregation first
        aggregated = super().aggregate_fit(server_round, results, failures)

        # The parent is expected to return a (parameters, metrics) pair whose
        # first element is None when nothing was aggregated, so the element —
        # not the pair — has to be checked; otherwise a failed round would
        # wipe the last good parameters.
        new_parameters = aggregated[0] if aggregated is not None else None

        if new_parameters is not None:
            # Store the aggregated parameters
            self.current_parameters = new_parameters
        elif self.current_parameters is None:
            # Fallback to initial parameters if no aggregation happened
            self.current_parameters = self.initial_parameters

        return aggregated

    def get_parameters(self, config):
        """Return the latest cached parameters, or the initial ones if none exist."""
        if self.current_parameters is not None:
            return self.current_parameters
        return self.initial_parameters

## Visualisations

In [None]:
from sklearn.metrics import confusion_matrix, roc_curve, auc

# Add visualization functions before main()
def plot_training_metrics(history):
    """Draw loss and accuracy curves per round from ``history.metrics_distributed``.

    Looks for the keys ``train_loss``/``test_loss`` and
    ``train_accuracy``/``test_accuracy``; missing keys are simply not drawn.
    Prints a message and returns when no distributed metrics exist.
    """
    if not history.metrics_distributed:
        print("No metrics available in history")
        return

    # (panel quantity, ((metric key, legend label), ...)) for each subplot, left to right.
    panels = (
        ('Loss', (("train_loss", 'Training Loss'), ("test_loss", 'Test Loss'))),
        ('Accuracy', (("train_accuracy", 'Training Accuracy'), ("test_accuracy", 'Test Accuracy'))),
    )

    plt.figure(figsize=(12, 5))

    for position, (quantity, series) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        for key, legend_label in series:
            if key in history.metrics_distributed:
                # Each entry is indexable; element 1 is the value to plot.
                values = [entry[1] for entry in history.metrics_distributed[key]]
                plt.plot(values, marker='o', label=legend_label)
        plt.title(f'{quantity} per Round')
        plt.xlabel('Round')
        plt.ylabel(quantity)
        plt.legend()

    plt.tight_layout()
    plt.show()


In [None]:
def central_evaluation(model, test_loader, device):
    """Run ``model`` over ``test_loader`` and collect labels, predictions and probabilities.

    Batches may be ``(data, targets)`` sequences or a bare data tensor (then
    nothing is recorded for that batch). A batch that raises is reported and
    skipped.

    Returns:
        ``(labels, predictions, probabilities)``: two lists and a NumPy array
        built from the per-sample softmax outputs.
    """
    model.eval()
    collected_labels, collected_preds, collected_probs = [], [], []

    with torch.no_grad():
        for batch in test_loader:
            try:
                # Unpack (data, targets) pairs; anything else is data only.
                has_targets = isinstance(batch, (list, tuple))
                data = batch[0] if has_targets else batch
                targets = batch[1] if has_targets else None

                if data is None:
                    continue

                inputs = data.to(device)
                if targets is not None:
                    targets = targets.to(device)

                logits = model(inputs)
                probs = torch.softmax(logits, dim=1)
                preds = torch.argmax(logits, dim=1)

                # Without ground truth there is nothing to record.
                if targets is None:
                    continue

                collected_labels.extend(targets.cpu().numpy())
                collected_preds.extend(preds.cpu().numpy())
                collected_probs.extend(probs.cpu().numpy())

            except Exception as e:
                print(f"Error processing batch: {str(e)}")
                continue

    return collected_labels, collected_preds, np.array(collected_probs)

In [None]:
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

def plot_confusion_matrix(y_true, y_pred, class_names, normalize=False, figsize=(8, 6)):
    """
    Plot a confusion matrix as an annotated heatmap and return the matrix.

    Parameters:
    y_true (array): True labels
    y_pred (array): Predicted labels
    class_names (list): Tick labels for both axes, in matrix order
    normalize (bool): If True, divide every row by its total so each row
        shows fractions of that true class; rows without any true samples
        stay all-zero
    figsize (tuple): Figure size

    Returns:
    The confusion matrix that was plotted (normalized if requested).
    """
    # Compute confusion matrix
    cm = confusion_matrix(y_true, y_pred)

    # Normalize if requested
    if normalize:
        row_totals = cm.sum(axis=1)[:, np.newaxis]
        # A label that only occurs in y_pred produces an all-zero row;
        # dividing by 1 there keeps the row at 0 instead of producing NaN.
        row_totals[row_totals == 0] = 1
        cm = cm.astype('float') / row_totals
        fmt = '.2%'
        title = 'Normalized Confusion Matrix'
    else:
        fmt = 'd'
        title = 'Confusion Matrix'

    # Create figure
    plt.figure(figsize=figsize)

    # Create heatmap (the colorbar is added separately below so it can be labelled)
    heatmap = sns.heatmap(
        cm,
        annot=True,
        fmt=fmt,
        cmap='Blues',
        xticklabels=class_names,
        yticklabels=class_names,
        cbar=False,
        linewidths=0.5,
        annot_kws={'size': 12}
    )

    # Add labels and title
    plt.title(title, fontsize=14, pad=20)
    plt.xlabel('Predicted Label', fontsize=12)
    plt.ylabel('True Label', fontsize=12)

    # Adjust tick labels
    plt.xticks(rotation=45, ha='right', fontsize=10)
    plt.yticks(rotation=0, fontsize=10)

    # Add colorbar
    plt.colorbar(heatmap.collections[0]).ax.set_ylabel('Counts' if not normalize else 'Percentage', 
                                                    rotation=270, labelpad=15)

    plt.tight_layout()
    plt.show()

    return cm

In [None]:
from sklearn.preprocessing import label_binarize
from sklearn.metrics import roc_curve, auc, RocCurveDisplay


def plot_roc_curve(y_true, y_probs, class_names):
    """Plot ROC curves with their AUC.

    With two class names a single curve is drawn from column 1 of
    ``y_probs``; otherwise one one-vs-rest curve per class is drawn, using
    the matching column of ``y_probs``.
    """
    n_classes = len(class_names)
    is_binary = n_classes == 2

    if is_binary:
        # Binary case: score every sample with the probability in column 1.
        positive_scores = y_probs[:, 1]
        fpr, tpr, _ = roc_curve(y_true, positive_scores)
        roc_auc = auc(fpr, tpr)

        plt.figure()
        plt.plot(fpr, tpr, color='darkorange', lw=2,
                 label=f'ROC curve (AUC = {roc_auc:.2f})')
    else:
        # Multi-class case: binarize the labels and draw one curve per class.
        y_true_bin = label_binarize(y_true, classes=np.arange(n_classes))

        plt.figure()
        for class_idx in range(n_classes):
            fpr, tpr, _ = roc_curve(y_true_bin[:, class_idx], y_probs[:, class_idx])
            roc_auc = auc(fpr, tpr)
            plt.plot(fpr, tpr, lw=2,
                     label=f'{class_names[class_idx]} (AUC = {roc_auc:.2f})')

    # Chance-level diagonal for reference.
    plt.plot([0, 1], [0, 1], 'k--', lw=2)
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.legend(loc="lower right")
    suffix = ' (Binary)' if is_binary else ' (OvR)'
    plt.title('ROC Curve' + suffix)
    plt.show()

In [None]:
# import numpy as np
# import matplotlib.pyplot as plt
# import torchaudio

# def plot_comparison(raw, aug, sample_rate=16000):
#     plt.figure(figsize=(18, 12))
    
#     # Time-align waveforms
#     min_length = min(len(raw), len(aug))
#     raw = raw[:min_length]
#     aug = aug[:min_length]
    
#     # Create window function
#     window = torch.hann_window(1024)
    
#     # Waveform plots with difference overlay
#     plt.subplot(3, 1, 1)
#     plt.plot(raw.numpy().squeeze(), 'b', alpha=0.6, label='Original')
#     plt.plot(aug.numpy().squeeze(), 'r', alpha=0.4, label='Augmented')
#     plt.plot(np.abs(raw.numpy().squeeze() - aug.numpy().squeeze()),'k', label='Difference')
#     plt.xlabel("Time (s) (Samples)")
#     plt.ylabel("Amplitude")
#     plt.title("Waveform Comparison")
#     plt.legend()
    
#     # Spectral difference using correct parameters
#     def create_spectrogram(waveform):
#         return torchaudio.functional.spectrogram(
#             waveform=waveform.unsqueeze(0),
#             pad=0,
#             window=window,
#             n_fft=1024,
#             hop_length=256,
#             win_length=1024,
#             power=2.0,
#             normalized=False
#         ).squeeze().log2().numpy()
    
#     S_raw = create_spectrogram(raw)
#     S_aug = create_spectrogram(aug)
    
#     plt.subplot(3, 2, 3)
#     plt.imshow(S_raw, aspect='auto', cmap='viridis', origin='lower')
#     plt.xlabel("Time (s)")
#     plt.ylabel("Frequency (Hz)")
#     plt.title("Original Spectrogram")
#     plt.colorbar()
    
#     plt.subplot(3, 2, 4)
#     plt.imshow(S_aug, aspect='auto', cmap='viridis', origin='lower')
#     plt.xlabel("Time (s)")
#     plt.ylabel("Frequency (Hz)")
#     plt.title("Augmented Spectrogram")
#     plt.colorbar()
    
#     plt.subplot(3, 2, 5)
#     plt.imshow(S_aug - S_raw, aspect='auto', cmap='coolwarm', origin='lower', vmin=-1, vmax=1)
#     plt.xlabel("Time (s)")
#     plt.ylabel("Frequency (Hz)")
#     plt.title("Spectral Difference")
#     plt.colorbar()
    
#     plt.subplot(3, 2, 6)
#     plt.specgram(raw.numpy().squeeze(), Fs=sample_rate, cmap='plasma', NFFT=1024, noverlap=512)
#     plt.xlabel("Time (s)")
#     plt.ylabel("Frequency (Hz)")
#     plt.title("Original Spectrogram")
    
#     plt.tight_layout()
#     plt.show()


In [None]:
# def get_sample_waveform(dataset, index=0):
#     """Helper to get raw and augmented versions of same sample"""
#     # Create non-augmented version
#     raw_dataset = AudioDataset(
#         root_dir=dataset.config.dirs["train"],
#         config=dataset.config,
#         augment=False
#     )
    
#     # Get raw waveform
#     raw_waveform, label = raw_dataset[index]
    
#     # Get augmented waveform (create new dataset with augment=True)
#     aug_dataset = AudioDataset(
#         root_dir=dataset.config.dirs["train"],
#         config=dataset.config,
#         augment=True
#     )
#     aug_waveform, _ = aug_dataset[index]
    
#     return raw_waveform, aug_waveform


In [None]:
# from IPython.display import Audio, display

# def play_comparison(raw, aug, sr=16000):
#     print("Original:")
#     display(Audio(raw.numpy().squeeze(), rate=sr))
#     print("Augmented:")
#     display(Audio(aug.numpy().squeeze(), rate=sr))

In [None]:
# Ray picks up RAY_* settings from the environment when it starts, so the
# memory-monitor override must be in place before ray.init(); assigning it
# afterwards does not reach the instance that is already running.
# A refresh interval of 0 disables Ray's memory monitor.
os.environ["RAY_memory_monitor_refresh_ms"]="0"
# os.environ["RAY_memory_usage_threshold"] = "0.98"

# Restart Ray so that re-running this cell always yields a fresh instance.
ray.shutdown()
ray.init()


In [None]:
def weighted_avg(metrics, metric_name):
    """Example-weighted mean of one metric across clients.

    Args:
        metrics: Iterable of ``(num_examples, metrics_dict)`` pairs.
        metric_name: Key to average; pairs whose dict lacks it are ignored.

    Returns:
        ``sum(value * num_examples) / sum(num_examples)`` over the pairs that
        contain ``metric_name``, or ``0.0`` when no pair contains it or the
        contributing example counts add up to zero.
    """
    weighted_sum = 0.0
    total_weight = 0
    for num_examples, m in metrics:
        if metric_name in m:
            weighted_sum += m[metric_name] * num_examples
            total_weight += num_examples

    # Clients can legitimately report 0 examples (e.g. an empty partition);
    # without this guard such a round would raise ZeroDivisionError.
    if total_weight == 0:
        return 0.0
    return weighted_sum / total_weight


## Main function

In [None]:
from flwr.server.client_manager import SimpleClientManager  # Import missing class
from torch.utils.data import random_split

def main():
    """Run the federated training simulation and the final central evaluation.

    Steps, in order:
      1. Build the training and test ``AudioDataset`` objects from ``config.dirs``.
      2. Split the (seed-shuffled) training set into ``config.num_clients``
         contiguous partitions, one per simulated client.
      3. Run ``fl.simulation.start_simulation`` with ``CustomFedAvg``.
      4. Plot per-round metrics and the overall training history.
      5. Load the strategy's final parameters into a fresh ``HuBERTClassifier``
         and evaluate it on the test set (confusion matrix, ROC curve,
         classification report).

    Returns:
        The history object returned by ``fl.simulation.start_simulation``.

    Raises:
        KeyError: if ``config.dirs`` has no ``"train"`` entry.
        ValueError: if there are fewer training samples than clients, if the
            strategy holds no parameters, or if the number of arrays it
            supplies matches neither the model's state dict nor its parameters.
        Exception: any other failure is printed and re-raised unchanged.
    """
    try:
        print("Starting federated learning setup")

        # Fail with an actionable message instead of a bare "KeyError: 'train'".
        if "train" not in config.dirs:
            raise KeyError(
                'config.dirs has no "train" entry - enable one of the training paths in Config'
            )

        # Load datasets
        train_dataset = AudioDataset(config.dirs["train"], config, augment = True)
        train_size = int(1* len(train_dataset))
        # A full-size random_split keeps every sample; its effect here is to
        # shuffle the order with a fixed seed before the per-client slicing below.
        train_subset, _ = random_split(
            train_dataset,
            [train_size, len(train_dataset) - train_size],
            generator=torch.Generator().manual_seed(42)
        )

        # if len(train_dataset) > 0:
        #     # Get first sample's raw and augmented versions
        #     raw_wave, aug_wave = get_sample_waveform(train_dataset, 2004)
        #     # Plot comparison
        #     plot_comparison(raw_wave, aug_wave)
        #     play_comparison(raw_wave, aug_wave, config.sample_rate)
        # else:
        #     print("Dataset is empty - check your data paths!")

        test_dataset = AudioDataset(config.dirs["test"], config)
        test_loader = DataLoader(test_dataset,
                                 batch_size=config.batch_size,
                                 collate_fn=collate_fn,
                                 shuffle=False)

        # Create metrics collector
        metrics_collector = DiskMetricsCollector()

        # Split into client partitions
        indices = list(range(len(train_subset)))
        if len(indices) < config.num_clients:
            raise ValueError(
                f"Need at least {config.num_clients} training samples to give every "
                f"client data, found {len(indices)}"
            )

        # Spread the remainder over the first clients so no sample is dropped
        # (plain integer division would discard up to num_clients - 1 samples).
        base_size, remainder = divmod(len(indices), config.num_clients)
        client_datasets = []
        start = 0
        for client_idx in range(config.num_clients):
            stop = start + base_size + (1 if client_idx < remainder else 0)
            client_datasets.append(Subset(train_subset, indices[start:stop]))
            start = stop

        print(f"Using {len(train_subset)}/{len(train_dataset)} training samples")
        print(f"Test samples: {len(test_dataset)}")

        # Define client creation function
        def client_fn(cid: str) -> FlowerClient:
            """Create client with fresh model instance"""
            client_id = int(cid)
            torch.cuda.empty_cache()
            model = HuBERTClassifier(config).to(config.device)  # Ensure model is on correct device
            numpy_client = FlowerClient(
                model=model,
                train_data=client_datasets[client_id],
                test_data=test_dataset,
                config=config,
                client_id=client_id,
                metrics_collector=metrics_collector
            )
            return numpy_client.to_client()

        # Initialize global model; parameters are serialized in state_dict order.
        initial_model = HuBERTClassifier(config).to(config.device)
        initial_params = fl.common.ndarrays_to_parameters([
            val.cpu().numpy() for _, val in initial_model.state_dict().items()
        ])

        # Configure strategy with proper aggregation
        strategy = CustomFedAvg(
            metrics_collector=metrics_collector,
            min_fit_clients=config.num_clients,
            min_available_clients=config.num_clients,
            initial_parameters=initial_params,
            fit_metrics_aggregation_fn=lambda metrics: {
                "train_loss": weighted_avg(metrics, "loss"),
                "train_accuracy": weighted_avg(metrics, "accuracy")
            },
            evaluate_metrics_aggregation_fn=lambda metrics: {
                "test_loss": weighted_avg(metrics, "loss"),
                "test_accuracy": weighted_avg(metrics, "accuracy")
            },

        )

        print("Starting federated training")
        history = fl.simulation.start_simulation(
            client_fn=client_fn,
            num_clients=config.num_clients,
            config=fl.server.ServerConfig(num_rounds=config.num_rounds),
            strategy=strategy,
            client_resources={
                "num_cpus": 0.8,
                "num_gpus": 0.25 if torch.cuda.is_available() else 0
            },
        )

        print("\nGenerating visualizations...")
        for round_num in range(1, config.num_rounds+1):
            metrics_collector.plot_round(round_num)

        # The history covers all rounds, so it is plotted once, after the
        # per-round figures rather than once per round.
        print("\nPlotting training metrics...")
        plot_training_metrics(history)

        print("\nPerforming central evaluation...")
        final_model = HuBERTClassifier(config).to(config.device)
        params_obj = strategy.get_parameters(None)

        # Evaluating untrained initial weights as the "final" model would be
        # misleading, so a missing result is an error rather than a fallback.
        if params_obj is None:
            raise ValueError("No parameters found in strategy")
        final_params = fl.common.parameters_to_ndarrays(params_obj)

        # Load parameters into the model. The initial parameters above were
        # serialized in state_dict order (parameters and buffers), so that
        # layout is tried first; a parameters-only layout is also accepted.
        sd = final_model.state_dict()
        state_keys = list(sd.keys())
        param_names = [name for name, _ in final_model.named_parameters()]
        if len(final_params) == len(state_keys):
            target_names = state_keys
        elif len(final_params) == len(param_names):
            target_names = param_names
        else:
            raise ValueError(
                f"Parameter count mismatch: Model has {len(state_keys)} state entries "
                f"({len(param_names)} parameters), strategy supplied {len(final_params)}"
            )

        for name, array in zip(target_names, final_params):
            sd[name] = torch.from_numpy(array).to(config.device)

        final_model.load_state_dict(sd, strict=True)
        final_model = final_model.to(config.device)

        # Device verification
        print("\n=== Device Verification ===")
        print(f"Model device: {next(final_model.parameters()).device}")
        sample_batch = next(iter(test_loader))[0]
        print(f"Data device: {sample_batch.device}")

        try:
            y_true, y_pred, y_probs = central_evaluation(final_model, test_loader, config.device)
        except Exception as e:
            print(f"Evaluation failed: {str(e)}")
            print("Model architecture:", final_model)
            print("Sample input shape:", sample_batch.shape)
            raise

        # Handle binary probabilities: expand a single positive-class column
        # into two columns [1 - p, p].
        print(f"\nProbability matrix shape: {y_probs.shape}")
        if y_probs.shape[1] == 1:
            y_probs = np.hstack([1 - y_probs, y_probs])

        class_names = ["Real", "Fake"]  # Define your classes
        plot_confusion_matrix(y_true, y_pred, class_names)
        plot_roc_curve(y_true, y_probs, class_names)

        # Imported locally because the notebook's import cell does not provide them.
        from sklearn.metrics import classification_report
        import pandas as pd

        print("\n=== Class-wise Performance Metrics ===")
        report = classification_report(y_true, y_pred, target_names=class_names, output_dict=True)
        report_df = pd.DataFrame(report).transpose()
        print("===full report===\n")
        print(report_df)
        print("\n===selected report===\n")
        print(report_df[["precision", "recall", "f1-score"]].round(2))

        return history

    except Exception as e:
        # Surface the failure in the cell output, then let it propagate.
        print(f"Critical failure: {str(e)}")
        raise


In [None]:
# Run the experiment when this cell is executed (in a notebook kernel
# __name__ is "__main__"), but not when the code is imported as a module.
if __name__ == "__main__":
    history = main()  # main() returns the simulation history; keep it for later cells
    # Additional analysis using history can happen here

In [None]:
# # Access metrics directly
# all_metrics = ray.get(metrics_collector.metrics.remote())

# # Plot specific round
# ray.get(metrics_collector.plot_round_metrics.remote(3))  # Plot round 3

# # Plot client 0's progress
# ray.get(metrics_collector.plot_client_progress.remote(0))

# # Load parameters from round 5
# params = torch.load("round_5_params.pth")

In [None]:
# ray.shutdown()

In [None]:
# m=DiskMetricsCollector()