In [18]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import roc_curve
import soundfile as sf
from tqdm import tqdm
import matplotlib.pyplot as plt
import librosa
import random
from torch.distributions import Categorical
import wandb
import warnings
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility
def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch random number generators.

    When CUDA is available the CUDA generator is seeded as well and cuDNN is
    asked to use deterministic algorithms.

    Args:
        seed (int): Seed value shared by every generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    # Nothing GPU-specific to configure on a CPU-only machine.
    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

# Seed every RNG once, before any dataset shuffling or model construction below.
set_seed()

# Device configuration
# `device` is a module-level global; use the GPU when one is visible, else the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Feature extraction functions
def extract_mfcc(audio, sr=16000, n_mfcc=20):
    """Return MFCCs stacked with their first- and second-order deltas.

    Args:
        audio: Audio samples handed to librosa as ``y``.
        sr (int): Sampling rate of ``audio``.
        n_mfcc (int): Number of MFCC coefficients to compute.

    Returns:
        The MFCC, delta and delta-delta matrices concatenated along axis 0.
    """
    base = librosa.feature.mfcc(y=audio, sr=sr, n_mfcc=n_mfcc)
    stacked = [
        base,
        librosa.feature.delta(base),
        librosa.feature.delta(base, order=2),
    ]
    return np.concatenate(stacked, axis=0)

def extract_spec(audio, sr=16000, n_fft=512, hop_length=256):
    """Return an 80-band mel spectrogram of ``audio`` converted to decibels.

    Args:
        audio: Audio samples handed to librosa as ``y``.
        sr (int): Sampling rate of ``audio``.
        n_fft (int): FFT window size.
        hop_length (int): Hop between successive frames, in samples.
    """
    mel_power = librosa.feature.melspectrogram(
        y=audio,
        sr=sr,
        n_fft=n_fft,
        hop_length=hop_length,
        n_mels=80,
    )
    return librosa.power_to_db(mel_power)

def extract_cqt(audio, sr=16000, hop_length=256):
    """Return the magnitude of the constant-Q transform of ``audio``.

    Args:
        audio: Audio samples handed to librosa as ``y``.
        sr (int): Sampling rate of ``audio``.
        hop_length (int): Hop between successive frames, in samples.
    """
    transform = librosa.cqt(y=audio, sr=sr, hop_length=hop_length)
    magnitude = np.abs(transform)
    return magnitude

# ASVSpoof Dataset Class
class ASVSpoofDataset(Dataset):
    """Dataset of ASVspoof-style utterances returned as ``(features, label)`` pairs.

    Every protocol line with at least five whitespace-separated fields
    contributes one sample: field 1 is the file id (``<file_id>.flac`` under
    ``root_dir``) and field 4 is the label text. Label 0 means bonafide and
    label 1 means spoof.
    """

    # Rows produced by each extractor with the defaults used in this file:
    # 20 MFCCs plus delta and delta-delta, 80 mel bands, and librosa's default
    # of 84 CQT bins. Used to shape the fallback sample so it matches real
    # samples of the same feature type and can be batched with them.
    _FEATURE_DIMS = {'mfcc': 60, 'spec': 80, 'cqt': 84}

    def __init__(self, root_dir, protocol_file, feature_type='mfcc', max_len=None, is_train=True):
        """
        Args:
            root_dir (string): Directory with all the audio files.
            protocol_file (string): Path to the protocol file.
            feature_type (string): Type of features to extract ('mfcc', 'spec', 'cqt').
            max_len (int): Maximum length of features sequence.
            is_train (bool): Whether this is for training or testing.

        Raises:
            ValueError: If ``feature_type`` is not one of the supported types.
        """
        # Fail fast: previously an unknown type only surfaced inside
        # __getitem__, where it was swallowed and turned into all-zero samples.
        if feature_type not in self._FEATURE_DIMS:
            raise ValueError(f"Unknown feature type: {feature_type}")

        self.root_dir = root_dir
        self.feature_type = feature_type
        self.max_len = max_len
        self.is_train = is_train

        # List of (file_id, label) tuples.
        self.data = []

        print(f"Reading protocol file: {protocol_file}")
        try:
            with open(protocol_file, 'r') as f:
                lines = f.readlines()

            # Use tqdm for loading progress
            for line in tqdm(lines, desc=f"Loading {'training' if is_train else 'evaluation'} protocol"):
                parts = line.strip().split()
                if len(parts) >= 5:
                    file_id = parts[1]
                    label = 0 if parts[4] == 'bonafide' else 1  # 0 for bonafide, 1 for spoof
                    self.data.append((file_id, label))

            # Count number of bonafide and spoof samples
            bonafide_count = sum(1 for _, label in self.data if label == 0)
            spoof_count = len(self.data) - bonafide_count

            print(f"Dataset loaded: {len(self.data)} samples ({bonafide_count} bonafide, {spoof_count} spoof)")

            if is_train and len(self.data) > 5000:
                # Subsample for faster NAS
                print("Subsampling training data for faster NAS...")
                np.random.shuffle(self.data)
                # Keep at most 2500 samples of each class
                bonafide_samples = [item for item in self.data if item[1] == 0][:2500]
                spoof_samples = [item for item in self.data if item[1] == 1][:2500]
                self.data = bonafide_samples + spoof_samples
                np.random.shuffle(self.data)
                print(f"Subsampled to {len(self.data)} samples")

        except Exception as e:
            # Best effort: report the problem and expose an empty dataset.
            print(f"Error loading protocol file: {e}")
            self.data = []

    def __len__(self):
        return len(self.data)

    def _extract(self, audio, sr):
        """Run the extractor selected by ``self.feature_type``."""
        if self.feature_type == 'mfcc':
            return extract_mfcc(audio, sr)
        if self.feature_type == 'spec':
            return extract_spec(audio, sr)
        if self.feature_type == 'cqt':
            return extract_cqt(audio, sr)
        raise ValueError(f"Unknown feature type: {self.feature_type}")

    def _fit_length(self, features):
        """Crop or zero-pad ``features`` along axis 1 to ``self.max_len`` frames."""
        if self.max_len is None:
            return features

        seq_len = features.shape[1]
        if seq_len > self.max_len:
            if self.is_train:
                # randint's upper bound is exclusive; the +1 makes the last
                # possible window reachable.
                start = np.random.randint(0, seq_len - self.max_len + 1)
            else:
                start = 0
            return features[:, start:start + self.max_len]
        if seq_len < self.max_len:
            # Pad with zeros
            pad_width = ((0, 0), (0, self.max_len - seq_len))
            return np.pad(features, pad_width, mode='constant')
        return features

    def __getitem__(self, idx):
        """Return ``(features, label)`` for sample ``idx``.

        ``features`` is a float tensor of globally normalised features and
        ``label`` is a 0-dim long tensor. If loading or feature extraction
        fails, the error is printed and an all-zero feature tensor is returned
        instead.
        """
        file_id, label = self.data[idx]
        audio_path = os.path.join(self.root_dir, f"{file_id}.flac")
        label_tensor = torch.tensor(label, dtype=torch.long)

        try:
            audio, sr = sf.read(audio_path)
            if audio.ndim > 1:
                # Multi-channel files are read as (frames, channels); average
                # the channels so the extractors receive a 1-D signal.
                audio = audio.mean(axis=1)

            features = self._extract(audio, sr)

            # Normalize features
            features = (features - np.mean(features)) / (np.std(features) + 1e-8)

            features = self._fit_length(features)
            return torch.FloatTensor(features), label_tensor

        except Exception as e:
            print(f"Error loading {audio_path}: {e}")
            # Return a dummy sample whose row count matches the feature type.
            n_rows = self._FEATURE_DIMS.get(self.feature_type, 60)
            n_frames = 100 if self.max_len is None else self.max_len
            return torch.zeros(n_rows, n_frames, dtype=torch.float32), label_tensor

Using device: cuda


In [19]:
# Neural Architecture Search operations

# Base operation class
class Operation(nn.Module):
    """Abstract parent of every candidate operation in the search space.

    It only records ``channels`` and ``stride``; concrete operations must
    override ``forward``.
    """
    def __init__(self, channels, stride=1):
        super().__init__()
        self.stride = stride
        self.channels = channels

    def forward(self, x):
        # Subclasses provide the actual computation.
        raise NotImplementedError

# Convolutional Block
class ConvBlock(Operation):
    """Conv1d followed by BatchNorm1d and ReLU, keeping the channel count.

    Padding is ``kernel_size // 2``.
    """
    def __init__(self, channels, kernel_size, stride=1):
        super().__init__(channels, stride)
        half_kernel = kernel_size // 2
        self.conv = nn.Conv1d(
            channels, channels, kernel_size, stride=stride, padding=half_kernel
        )
        self.bn = nn.BatchNorm1d(channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        return self.relu(out)

# LSTM Block
class LSTM(Operation):
    """Single-layer LSTM applied along the time axis of a ``[B, C, T]`` tensor.

    With ``stride > 1`` the output is subsampled along time so that its length
    matches the strided convolution operations it is summed with in MixedOp.
    """
    def __init__(self, channels, stride=1):
        super().__init__(channels, stride)
        self.lstm = nn.LSTM(channels, channels, batch_first=True)
        # NOTE(review): input_proj is never used in forward(). It is kept so
        # parameter names and initialisation order stay compatible with
        # existing checkpoints; consider removing it once that is not needed.
        self.input_proj = nn.Conv1d(in_channels=channels, out_channels=channels, kernel_size=1)

    def forward(self, x):
        # [B, C, T] -> [B, T, C] for the batch_first LSTM.
        out, _ = self.lstm(x.permute(0, 2, 1))

        # Back to [B, C, T].
        out = out.permute(0, 2, 1)

        if self.stride > 1:
            # Keep every stride-th frame: ceil(T / stride) frames, the same
            # length the strided conv ops in this file produce.
            out = out[:, :, ::self.stride]
        return out

# Dilated Convolution
class Dilated(Operation):
    """Kernel-3 Conv1d with dilation 2, followed by BatchNorm1d and ReLU."""
    def __init__(self, channels, stride=1):
        super().__init__(channels, stride)
        self.conv = nn.Conv1d(
            channels,
            channels,
            3,
            stride=stride,
            padding=2,
            dilation=2,
        )
        self.bn = nn.BatchNorm1d(channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        return self.relu(out)

# Skip Connection
class SkipConnect(Operation):
    """Identity operation on a ``[B, C, T]`` tensor.

    With ``stride > 1`` the input is subsampled along time so that its length
    matches the strided convolution operations it is summed with in MixedOp.
    """
    def __init__(self, channels, stride=1):
        super().__init__(channels, stride)

    def forward(self, x):
        if self.stride > 1:
            # Keep every stride-th frame: ceil(T / stride) frames, the same
            # length the strided conv ops in this file produce.
            return x[:, :, ::self.stride]
        return x

# Self-Attention Block
class Attention(Operation):
    """Scaled dot-product self-attention over the time axis of ``[B, C, T]``.

    Queries, keys and values come from 1x1 convolutions. With ``stride > 1``
    only every stride-th query position is evaluated, so the output length
    matches the strided convolution operations it is summed with in MixedOp.
    """
    def __init__(self, channels, stride=1):
        super().__init__(channels, stride)
        self.query = nn.Conv1d(channels, channels, 1)
        self.key = nn.Conv1d(channels, channels, 1)
        self.value = nn.Conv1d(channels, channels, 1)
        # A plain float instead of a tensor pinned to the global `device`:
        # the module then works on whatever device its inputs live on and
        # keeps working after `model.to(...)`.
        self.scale = float(channels) ** 0.5

    def forward(self, x):
        # x shape: [B, C, T]
        q = self.query(x).permute(0, 2, 1)  # [B, T, C]
        k = self.key(x)                     # [B, C, T], already the transposed keys
        v = self.value(x).permute(0, 2, 1)  # [B, T, C]

        if self.stride > 1:
            # Each output frame depends only on its own query, so subsampling
            # the queries equals subsampling the full attention output.
            q = q[:, ::self.stride]

        # Self-attention
        scores = torch.matmul(q, k) / self.scale
        attention = F.softmax(scores, dim=-1)
        context = torch.matmul(attention, v)

        # Reshape back
        return context.permute(0, 2, 1)  # [B, C, T']

# Separable Convolution
class SeparableConv(Operation):
    """Depthwise kernel-3 conv, then pointwise 1x1 conv, BatchNorm1d and ReLU."""
    def __init__(self, channels, stride=1):
        super().__init__(channels, stride)
        # groups=channels gives one filter per input channel.
        self.depthwise = nn.Conv1d(
            channels, channels, 3, stride=stride, padding=1, groups=channels
        )
        self.pointwise = nn.Conv1d(channels, channels, 1)
        self.bn = nn.BatchNorm1d(channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        out = self.depthwise(x)
        out = self.pointwise(out)
        out = self.bn(out)
        return self.relu(out)

# Squeeze-and-Excitation Block
class SqueezeExcitation(Operation):
    """Squeeze-and-excitation channel gating for ``[B, C, T]`` tensors.

    The input is average-pooled over time, passed through a two-layer 1x1-conv
    bottleneck (``channels // reduction`` hidden channels, at least 1) and a
    sigmoid, and the resulting per-channel gate rescales the input. With
    ``stride > 1`` the gated output is subsampled along time so that its length
    matches the strided convolution operations it is summed with in MixedOp.
    """
    def __init__(self, channels, stride=1, reduction=16):
        super().__init__(channels, stride)
        hidden = max(channels // reduction, 1)
        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        self.fc1 = nn.Conv1d(channels, hidden, kernel_size=1)
        self.relu = nn.ReLU()
        self.fc2 = nn.Conv1d(hidden, channels, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # x shape: [B, C, T]; the gate is [B, C, 1] and is computed from all frames.
        scale = self.avg_pool(x)
        scale = self.fc1(scale)
        scale = self.relu(scale)
        scale = self.fc2(scale)
        scale = self.sigmoid(scale)

        if self.stride > 1:
            # Keep every stride-th frame: ceil(T / stride) frames, the same
            # length the strided conv ops in this file produce.
            x = x[:, :, ::self.stride]
        return x * scale

# Frequency-Aware Convolution (Audio-specific)
class FrequencyAwareConv(Operation):
    """Split the channels into bands and convolve each band with its own kernel size.

    Band ``i`` uses kernel size ``3 + 2 * i`` with matching padding. When
    ``channels`` is not a multiple of ``bands`` the leftover channels are added
    to the last band, and that band's convolution is built with the enlarged
    width (previously it was built for ``band_size`` channels and failed at
    run time). The concatenated band outputs go through BatchNorm1d and ReLU.
    """
    def __init__(self, channels, stride=1, bands=4):
        super().__init__(channels, stride)
        # Ensure band_size is at least 1
        self.band_size = max(channels // bands, 1)
        self.bands = min(bands, channels)

        # Channel count of every band; the last band absorbs the remainder.
        self.split_sizes = [self.band_size] * self.bands
        self.split_sizes[-1] += channels - self.band_size * self.bands

        # Create different kernel sizes for frequency bands
        self.convs = nn.ModuleList([
            nn.Conv1d(size, size, 3 + i * 2, padding=(3 + i * 2) // 2, stride=stride)
            for i, size in enumerate(self.split_sizes)
        ])

        self.bn = nn.BatchNorm1d(channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        # Split along channel dimension into bands
        x_bands = torch.split(x, self.split_sizes, dim=1)

        # Process each band with its own convolution, then concatenate results
        out = torch.cat(
            [conv(band) for conv, band in zip(self.convs, x_bands)], dim=1
        )

        return self.relu(self.bn(out))

# Gated Convolution
class GatedConv(Operation):
    """Two parallel kernel-3 convolutions: features multiplied by a sigmoid gate.

    The gated product is batch-normalised before it is returned.
    """
    def __init__(self, channels, stride=1):
        super().__init__(channels, stride)
        conv_kwargs = dict(stride=stride, padding=1)
        self.conv_features = nn.Conv1d(channels, channels, 3, **conv_kwargs)
        self.conv_gate = nn.Conv1d(channels, channels, 3, **conv_kwargs)
        self.bn = nn.BatchNorm1d(channels)

    def forward(self, x):
        content = self.conv_features(x)
        gate_logits = self.conv_gate(x)
        gated = content * torch.sigmoid(gate_logits)
        return self.bn(gated)

# Mixed Operation (Weighted sum of operations)
class MixedOp(nn.Module):
    """Weighted sum of all ten candidate operations applied to the same input."""
    def __init__(self, channels, stride=1):
        super().__init__()
        candidates = [
            # Original operations
            ConvBlock(channels, 3, stride),
            ConvBlock(channels, 5, stride),
            LSTM(channels, stride),
            Dilated(channels, stride),
            SkipConnect(channels, stride),
            Attention(channels, stride),
            # New operations
            SeparableConv(channels, stride),
            SqueezeExcitation(channels, stride),
            FrequencyAwareConv(channels, stride),
            GatedConv(channels, stride),
        ]
        self.ops = nn.ModuleList(candidates)

    def forward(self, x, weights):
        """Return ``sum(weights[k] * ops[k](x))`` over the paired weights and ops."""
        mixed = 0
        for weight, op in zip(weights, self.ops):
            mixed = mixed + weight * op(x)
        return mixed

# Cell structure
class Cell(nn.Module):
    """Densely connected cell of MixedOp edges.

    Node ``i`` (0-based) receives one edge from the cell input and one from
    each earlier node, i.e. ``i + 1`` edges. All node outputs are concatenated
    along the channel axis and projected back to ``channels`` with a 1x1 conv.
    """
    def __init__(self, channels, num_nodes=4):
        super().__init__()
        self.channels = channels
        self.num_nodes = num_nodes

        # 1 + 2 + ... + num_nodes edges, stored node by node.
        edge_total = sum(node + 1 for node in range(num_nodes))
        self.edges = nn.ModuleList([MixedOp(channels) for _ in range(edge_total)])

        # Output projection
        self.project = nn.Conv1d(channels * num_nodes, channels, 1)

    def forward(self, x, weights):
        """
        Run the cell.

        Args:
            x: Input tensor [B, C, T]
            weights: One weight tensor per edge, in the same order as ``self.edges``
        """
        states = [x]
        base = 0

        for node in range(self.num_nodes):
            fan_in = node + 1
            # One contribution from the input and from every earlier node.
            contributions = [
                self.edges[base + src](states[src], weights[base + src])
                for src in range(fan_in)
            ]
            states.append(sum(contributions))
            base += fan_in

        # Concatenate every node output (the raw input is excluded).
        merged = torch.cat(states[1:], dim=1)
        return self.project(merged)

In [20]:
# Complete model with hybrid PPO-DARTS support
class DeepfakeDetectionModel(nn.Module):
    """Stem conv, a stack of searchable cells, and a 2-class linear head.

    Architecture weights come either from the caller (PPO mode) or from the
    internal ``_alphas`` parameter (DARTS mode, continuous or discrete).
    """
    def __init__(self, input_channels, num_cells=3, num_nodes=4, num_ops=10):
        super().__init__()
        self.input_channels = input_channels
        self.num_cells = num_cells
        self.num_nodes = num_nodes
        self.num_ops = num_ops

        # Initial projection
        stem_layers = [
            nn.Conv1d(input_channels, 64, 3, padding=1),
            nn.BatchNorm1d(64),
            nn.ReLU(),
        ]
        self.stem = nn.Sequential(*stem_layers)

        # Cells
        self.cells = nn.ModuleList()
        for _ in range(num_cells):
            self.cells.append(Cell(64, num_nodes))

        # Classification head
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.classifier = nn.Linear(64, 2)  # Binary classification

        # Architectural parameters (alphas) for DARTS
        self._initialize_alphas()

        # Number of scalar weights a PPO controller has to supply
        per_cell = sum(range(1, num_nodes + 1))
        self.total_weights = num_cells * per_cell * num_ops

    def _initialize_alphas(self):
        """Create ``_alphas`` with one row per edge and one column per operation."""
        per_cell = sum(range(1, self.num_nodes + 1))
        shape = (self.num_cells * per_cell, self.num_ops)
        self._alphas = nn.Parameter(torch.zeros(*shape))
        # Small random values around zero
        nn.init.normal_(self._alphas, mean=0, std=0.001)

    def alphas(self):
        """Return the architecture parameters as a one-element list for an optimizer."""
        return [self._alphas]

    def weights(self):
        """Return every parameter whose name does not contain ``_alphas``."""
        selected = []
        for name, param in self.named_parameters():
            if '_alphas' in name:
                continue
            selected.append(param)
        return selected

    def forward(self, x, architecture_weights=None, discrete=False):
        """
        Forward pass with three modes:
        - PPO mode: external ``architecture_weights`` are given
        - DARTS mode: internal alphas, softmax-relaxed
        - Evaluation mode: internal alphas, one-hot on the arg-max operation
        """
        # Anything whose dim 1 is not the channel count is treated as [B, T, C].
        if x.shape[1] != self.input_channels:
            x = x.permute(0, 2, 1)

        x = self.stem(x)

        per_cell = sum(range(1, self.num_nodes + 1))

        if architecture_weights is not None:
            # PPO mode: cut the flat vector into num_ops-sized chunks and
            # turn each chunk into a probability distribution.
            chunk_count = len(architecture_weights) // self.num_ops
            edge_weights = [
                F.softmax(
                    architecture_weights[k * self.num_ops:k * self.num_ops + self.num_ops],
                    dim=0,
                )
                for k in range(chunk_count)
            ]
        elif discrete:
            # Evaluation mode: one-hot vector on the strongest operation per edge.
            winners = torch.argmax(self._alphas, dim=1)
            edge_weights = []
            for edge_idx, winner in enumerate(winners):
                one_hot = torch.zeros_like(self._alphas[edge_idx])
                one_hot[winner] = 1.0
                edge_weights.append(one_hot)
        else:
            # DARTS mode: continuous relaxation via a per-edge softmax.
            edge_weights = [
                F.softmax(self._alphas[edge_idx], dim=0)
                for edge_idx in range(self._alphas.size(0))
            ]

        # Each cell consumes its own consecutive slice of edge weights.
        for cell_idx, cell in enumerate(self.cells):
            lo = cell_idx * per_cell
            x = cell(x, edge_weights[lo:lo + per_cell])

        # Classification
        pooled = self.pool(x).squeeze(-1)
        return self.classifier(pooled)

In [None]:
import os
import numpy as np
import torch
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import networkx as nx
from matplotlib.patches import FancyArrowPatch

def visualize_architecture(architecture, num_cells=3, num_nodes=4, num_ops=10,
                           save_path='architecture_visualization.png',
                           save_to_wandb=False, title="Hybrid PPO-DARTS Architecture"):
    """
    Visualize the architecture discovered by the hybrid PPO-DARTS approach.

    Each cell is drawn with one input node, ``num_nodes`` intermediate nodes
    and one output node. Intermediate node ``i`` receives one edge from the
    input and from every earlier intermediate node, so a cell has
    ``1 + 2 + ... + num_nodes`` searchable edges, consumed in that order from
    ``architecture``. This is the same layout and edge order as ``Cell``.
    The original drawing loop used two inputs and more edges per cell than
    ``edges_per_cell``, so cells read each other's entries. The gray,
    unlabeled edges into the output node stand for the concatenation of all
    intermediate nodes.

    Args:
        architecture: Tensor containing the selected operations (1-D, for PPO)
            or operation weights (2-D ``[num_edges, num_ops]``, for DARTS)
        num_cells: Number of cells in the architecture
        num_nodes: Number of intermediate nodes in each cell
        num_ops: Number of possible operations for each edge (not used by the
            drawing; kept for interface compatibility)
        save_path: Path to save the visualization
        save_to_wandb: Whether to save the visualization to wandb
        title: Title of the visualization

    Returns:
        Path to the saved visualization
    """
    # Operation names in the same order as the operations in MixedOp
    operation_names = [
        'Conv 3x3', 'Conv 5x5', 'LSTM', 'Dilated Conv', 'Skip Connect',
        'Self Attention', 'Separable Conv', 'Squeeze-Excitation', 'Frequency-Aware', 'Gated Conv'
    ]

    # Define colors for different operations
    colors = list(mcolors.TABLEAU_COLORS.values())

    edges_per_cell = sum(range(1, num_nodes + 1))

    # PPO gives a 1-D tensor of indices, DARTS a 2-D tensor of weights
    is_ppo = architecture.dim() == 1
    available_edges = architecture.size(0)

    input_node, output_node = 0, num_nodes + 1

    # The node layout is the same for every cell: all nodes on one line,
    # input on the left and output on the right.
    pos = {node: np.array([2 * node / (num_nodes + 1) - 1, 0.0])
           for node in range(num_nodes + 2)}

    # Create a figure with multiple subplots (one per cell)
    fig, axes = plt.subplots(1, num_cells, figsize=(6 * num_cells, 6), constrained_layout=True)
    if num_cells == 1:
        axes = [axes]

    fig.suptitle(title, fontsize=20, y=1.05)

    for cell_idx in range(num_cells):
        ax = axes[cell_idx]
        G = nx.DiGraph()

        # Label for cell type
        cell_type = "Normal Cell" if cell_idx % 2 == 0 else "Expand Cell"
        ax.set_title(f"Cell {cell_idx+1}: {cell_type}", fontsize=16)

        # Add nodes to the graph
        G.add_node(input_node, label="Input")
        for node in range(1, num_nodes + 1):
            G.add_node(node, label=f"Node {node}")
        G.add_node(output_node, label="Output")

        # Add the searchable edges in the order Cell consumes them
        edge_offset = cell_idx * edges_per_cell
        edge_count = 0
        for node in range(1, num_nodes + 1):
            for src in range(node):
                edge_idx = edge_offset + edge_count
                edge_count += 1
                if edge_idx >= available_edges:
                    # Architecture tensor is shorter than expected; leave the edge out.
                    continue

                if is_ppo:
                    op_idx = int(architecture[edge_idx].item())
                else:
                    op_idx = torch.argmax(architecture[edge_idx]).item()

                if op_idx < len(operation_names):
                    op_name = operation_names[op_idx]
                else:
                    op_name = f"Op {op_idx}"
                # Wrap around so an index beyond the palette cannot raise.
                G.add_edge(src, node, label=op_name, color=colors[op_idx % len(colors)])

        # Every intermediate node feeds the output
        for node in range(1, num_nodes + 1):
            G.add_edge(node, output_node, label='', color='gray')

        # Draw nodes and their labels
        nx.draw_networkx_nodes(G, pos, node_size=1200,
                               node_color='lightblue', alpha=0.8, ax=ax)
        nx.draw_networkx_labels(G, pos, labels=nx.get_node_attributes(G, 'label'),
                                font_size=10, font_family='sans-serif', ax=ax)

        # Draw edges with custom arrows
        for u, v, data in G.edges(data=True):
            color = data.get('color', 'gray')
            label = data.get('label', '')

            # Create a curved arrow
            arrow = FancyArrowPatch(pos[u], pos[v], connectionstyle="arc3,rad=0.2",
                                    arrowstyle="-|>", color=color, lw=1.5, alpha=0.8)
            ax.add_patch(arrow)

            if not label:
                continue

            # Put the operation name near the midpoint of the edge
            x = (pos[u][0] + pos[v][0]) / 2
            y = (pos[u][1] + pos[v][1]) / 2
            offset = 0.1 if pos[u][1] < pos[v][1] else -0.1
            ax.text(x, y + offset, label, fontsize=8, ha='center', va='center',
                    bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', pad=1))

        # Remove axis ticks and frame
        ax.set_xticks([])
        ax.set_yticks([])
        for side in ('top', 'right', 'bottom', 'left'):
            ax.spines[side].set_visible(False)

    # Save this figure explicitly and always release it, even if saving fails
    try:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
    finally:
        plt.close(fig)

    # Log to wandb if requested
    if save_to_wandb:
        try:
            import wandb
            if wandb.run is not None:
                wandb.log({"architecture_visualization": wandb.Image(save_path)})
        except ImportError:
            print("Warning: wandb not installed, skipping wandb logging")

    return save_path


def visualize_darts_architecture(architecture_weights, num_cells=3, num_nodes=4, num_ops=10, 
                                save_path='darts_architecture.png', save_to_wandb=False):
    """
    Render a DARTS architecture by delegating to ``visualize_architecture``.

    Args:
        architecture_weights: Tensor of shape [num_edges, num_ops] containing operation weights
        num_cells: Number of cells in the architecture
        num_nodes: Number of nodes in each cell
        num_ops: Number of operations
        save_path: Path to save the visualization
        save_to_wandb: Whether to save the visualization to wandb

    Returns:
        Whatever ``visualize_architecture`` returns.
    """
    render_options = dict(
        num_cells=num_cells,
        num_nodes=num_nodes,
        num_ops=num_ops,
        save_path=save_path,
        save_to_wandb=save_to_wandb,
        title="DARTS Architecture (Discrete)",
    )
    return visualize_architecture(architecture_weights, **render_options)


def visualize_ppo_architecture(architecture_indices, num_cells=3, num_nodes=4, num_ops=10, 
                              save_path='ppo_architecture.png', save_to_wandb=False):
    """
    Render a PPO architecture by delegating to ``visualize_architecture``.

    Args:
        architecture_indices: Tensor containing operation indices for each edge
        num_cells: Number of cells in the architecture
        num_nodes: Number of nodes in each cell
        num_ops: Number of operations
        save_path: Path to save the visualization
        save_to_wandb: Whether to save the visualization to wandb

    Returns:
        Whatever ``visualize_architecture`` returns.
    """
    render_options = dict(
        num_cells=num_cells,
        num_nodes=num_nodes,
        num_ops=num_ops,
        save_path=save_path,
        save_to_wandb=save_to_wandb,
        title="PPO-Generated Architecture",
    )
    return visualize_architecture(architecture_indices, **render_options)


def compare_architectures(ppo_architecture, darts_architecture, num_cells=3, num_nodes=4, num_ops=10,
                         save_path='architecture_comparison.png', save_to_wandb=False):
    """
    Create a visualization comparing PPO and DARTS architectures.

    Both architectures are rendered to temporary PNG files, which are then
    stacked vertically (PPO on top, DARTS below) under a title band and saved
    to ``save_path``.

    Args:
        ppo_architecture: Architecture tensor from PPO
        darts_architecture: Architecture tensor from DARTS
        num_cells: Number of cells
        num_nodes: Number of nodes per cell
        num_ops: Number of operations
        save_path: Path to save the combined image
        save_to_wandb: Whether to log the combined image to wandb

    Returns:
        Path to the saved comparison image
    """
    import tempfile
    from PIL import Image, ImageDraw, ImageFont

    title_band = 50  # Extra space for title, in pixels

    # Render into a private temporary directory: nothing is left behind if a
    # step fails, and files in the working directory cannot be clobbered.
    with tempfile.TemporaryDirectory() as tmp_dir:
        ppo_save_path = os.path.join(tmp_dir, 'ppo_temp.png')
        darts_save_path = os.path.join(tmp_dir, 'darts_temp.png')

        # Generate the individual visualizations
        visualize_ppo_architecture(ppo_architecture, num_cells, num_nodes, num_ops, ppo_save_path)
        visualize_darts_architecture(darts_architecture, num_cells, num_nodes, num_ops, darts_save_path)

        # Keep the images open only while they are being pasted, so the files
        # are released before the temporary directory is removed.
        with Image.open(ppo_save_path) as ppo_img, Image.open(darts_save_path) as darts_img:
            combined_width = max(ppo_img.width, darts_img.width)
            combined_height = ppo_img.height + darts_img.height + title_band
            combined_img = Image.new('RGB', (combined_width, combined_height), color='white')

            # Paste the individual images
            combined_img.paste(ppo_img, (0, title_band))
            combined_img.paste(darts_img, (0, title_band + ppo_img.height))

    # Add title
    draw = ImageDraw.Draw(combined_img)
    try:
        font = ImageFont.truetype("arial.ttf", 36)
    except OSError:
        # Font file not available on this system; use PIL's built-in font.
        font = ImageFont.load_default()

    draw.text((combined_width//2 - 200, 10), "Architecture Comparison", fill="black", font=font)

    # Save the combined image
    combined_img.save(save_path)

    # Log to wandb if requested
    if save_to_wandb:
        try:
            import wandb
            if wandb.run is not None:
                wandb.log({"architecture_comparison": wandb.Image(save_path)})
        except ImportError:
            print("Warning: wandb not installed, skipping wandb logging")

    return save_path


def analyze_architecture_statistics(best_architecture, is_ppo=True, num_cells=3, num_nodes=4, num_ops=10):
    """
    Analyze the statistics of the discovered architecture.

    Only the first ``num_cells * edges_per_cell`` entries of
    ``best_architecture`` are counted, where ``edges_per_cell`` is
    ``1 + 2 + ... + num_nodes``. If no edge is counted at all, every
    percentage is reported as 0.0 instead of dividing by zero.

    Args:
        best_architecture: The architecture tensor (PPO indices or DARTS weights)
        is_ppo: Whether the architecture is from PPO (True) or DARTS (False)
        num_cells: Number of cells
        num_nodes: Number of nodes per cell
        num_ops: Number of operations (not used by the analysis; kept for
            interface compatibility)

    Returns:
        Dictionary with the keys "operation_counts", "operation_percentages",
        "patterns" and "by_cell"
    """
    operation_names = [
        'Conv 3x3', 'Conv 5x5', 'LSTM', 'Dilated Conv', 'Skip Connect',
        'Self Attention', 'Separable Conv', 'Squeeze-Excitation', 'Frequency-Aware', 'Gated Conv'
    ]

    # Calculate edges per cell
    edges_per_cell = sum(range(1, num_nodes+1))
    total_edges = num_cells * edges_per_cell

    # Resolve the chosen operation index of every edge once up front
    considered = best_architecture[:total_edges]
    if is_ppo:
        # PPO architecture: each entry is already an operation index
        chosen = [int(entry.item()) for entry in considered]
    else:
        # DARTS architecture: pick the operation with the largest weight
        chosen = [torch.argmax(row).item() for row in considered]

    # Count operations overall and per cell
    op_counts = {op: 0 for op in operation_names}
    op_counts_by_cell = []
    for cell_idx in range(num_cells):
        cell_op_counts = {op: 0 for op in operation_names}
        edge_offset = cell_idx * edges_per_cell

        for op_idx in chosen[edge_offset:edge_offset + edges_per_cell]:
            op_name = operation_names[op_idx]
            op_counts[op_name] += 1
            cell_op_counts[op_name] += 1

        op_counts_by_cell.append(cell_op_counts)

    # Calculate percentages (guarding against an empty architecture)
    total_ops = sum(op_counts.values())
    op_percentages = {
        op: (count / total_ops * 100 if total_ops else 0.0)
        for op, count in op_counts.items()
    }

    # Find most and least common operations (ties go to the first in list order)
    most_common_op = max(op_counts.items(), key=lambda x: x[1])[0]
    least_common_op = min(op_counts.items(), key=lambda x: x[1])[0]

    # Analyze patterns
    patterns = {
        "most_common_op": most_common_op,
        "most_common_percentage": op_percentages[most_common_op],
        "least_common_op": least_common_op,
        "least_common_percentage": op_percentages[least_common_op],
        "cell_specific_patterns": [],
    }

    # Check for cell-specific patterns
    for i, cell_counts in enumerate(op_counts_by_cell):
        cell_total = sum(cell_counts.values())
        if cell_total > 0:
            cell_most_common = max(cell_counts.items(), key=lambda x: x[1])[0]
            cell_percent = cell_counts[cell_most_common] / cell_total * 100
            if cell_percent > 40:  # If an operation dominates a cell
                patterns["cell_specific_patterns"].append(
                    f"Cell {i+1} ({['Normal', 'Expand'][i%2]}) uses {cell_most_common} for {cell_percent:.1f}% of connections"
                )

    # Return combined statistics
    statistics = {
        "operation_counts": op_counts,
        "operation_percentages": op_percentages,
        "patterns": patterns,
        "by_cell": op_counts_by_cell
    }

    return statistics


def plot_architecture_statistics(statistics, save_path='architecture_stats.png', save_to_wandb=False):
    """
    Create visualizations of architecture statistics.

    Draws a two-panel figure: the left panel is a bar chart of the overall
    operation counts (sorted by descending count), the right panel is a
    grouped bar chart with one group per cell and one bar per operation
    that is used at least once.

    Args:
        statistics: Output from analyze_architecture_statistics. Only the
            "operation_counts" (dict op -> count) and "by_cell" (list of
            dict op -> count, one per cell) entries are read.
        save_path: Path to save the visualization
        save_to_wandb: Whether to log to wandb

    Returns:
        save_path, unchanged, after the figure has been written to it.
    """
    # Create a figure with subplots
    fig, axes = plt.subplots(1, 2, figsize=(16, 8))

    # ---- Left panel: overall operation counts ----
    op_counts = statistics["operation_counts"]
    sorted_ops = sorted(op_counts.items(), key=lambda x: x[1], reverse=True)

    # zip(*[]) cannot be unpacked into two names, so only draw bars when
    # there is at least one operation to show.
    if sorted_ops:
        ops, counts = zip(*sorted_ops)
        axes[0].bar(ops, counts, color='skyblue')

        # Annotate bars with counts
        for i, v in enumerate(counts):
            axes[0].text(i, v + 0.1, str(v), ha='center')

    axes[0].set_title('Operation Counts', fontsize=16)
    axes[0].set_ylabel('Count')
    axes[0].tick_params(axis='x', rotation=45)

    # ---- Right panel: operation counts per cell ----
    cell_data = statistics["by_cell"]
    cell_names = [f"Cell {i+1}\n({'Normal' if i%2==0 else 'Expand'})" for i in range(len(cell_data))]

    # Only operations that occur at least once in some cell get a bar
    all_ops = {op for cell in cell_data for op, count in cell.items() if count > 0}

    # Sort by overall frequency; the name is a tie-breaker so that bar order
    # (and therefore colours) does not depend on set iteration order.
    sorted_all_ops = sorted(all_ops, key=lambda op: (-op_counts.get(op, 0), op))

    x = np.arange(len(cell_names))

    # Guard against an empty operation set: the bar width is 0.8 / n_ops.
    if sorted_all_ops:
        n_ops = len(sorted_all_ops)
        width = 0.8 / n_ops

        for i, op in enumerate(sorted_all_ops):
            values = [cell.get(op, 0) for cell in cell_data]
            # Centre the group of bars on the cell's tick position
            offset = i * width - (n_ops - 1) * width / 2
            axes[1].bar(x + offset, values, width, label=op)

        axes[1].legend(loc='upper right', bbox_to_anchor=(1.15, 1))

    axes[1].set_title('Operations by Cell Type', fontsize=16)
    axes[1].set_xticks(x)
    axes[1].set_xticklabels(cell_names)
    axes[1].set_ylabel('Count')

    fig.tight_layout()
    fig.savefig(save_path, dpi=300, bbox_inches='tight')

    # Log to wandb if requested
    if save_to_wandb:
        try:
            import wandb
            if wandb.run is not None:
                wandb.log({"architecture_statistics": wandb.Image(save_path)})
        except ImportError:
            print("Warning: wandb not installed, skipping wandb logging")

    # Close this specific figure (not merely "the current one") to free it
    plt.close(fig)
    return save_path


# # Example usage:
# if __name__ == "__main__":
#     # Example of using these functions with your model:
    
#     # 1. Load your best architecture from checkpoint
#     checkpoint_path = 'best_hybrid_model.pth'
    
#     if os.path.exists(checkpoint_path):
#         checkpoint = torch.load(checkpoint_path)
        
#         # Check if it's a PPO or DARTS architecture
#         if 'ppo_architecture' in checkpoint:
#             # PPO architecture
#             architecture = checkpoint['ppo_architecture']
#             is_ppo = True
#             best_mode = 'PPO'
#         elif 'darts_alphas' in checkpoint:
#             # DARTS architecture
#             architecture = checkpoint['darts_alphas']
#             is_ppo = False
#             best_mode = 'DARTS'
#         else:
#             # Example random architecture for demonstration
#             print("No architecture found in checkpoint, generating random example")
#             architecture = torch.randint(0, 10, (30,))
#             is_ppo = True
#             best_mode = 'Random'
            
#         # Visualize the architecture
#         if is_ppo:
#             save_path = visualize_ppo_architecture(architecture, save_to_wandb=True)
#         else:
#             save_path = visualize_darts_architecture(architecture, save_to_wandb=True)
            
#         print(f"Architecture visualization saved to {save_path}")
        
#         # Analyze architecture statistics
#         stats = analyze_architecture_statistics(architecture, is_ppo=is_ppo)
#         stats_path = plot_architecture_statistics(stats, save_to_wandb=True)
        
#         print(f"Architecture statistics saved to {stats_path}")
        
#         # Print key findings
#         patterns = stats["patterns"]
#         print(f"\nArchitecture Analysis ({best_mode} architecture):")
#         print(f"Most common operation: {patterns['most_common_op']} ({patterns['most_common_percentage']:.1f}%)")
#         print(f"Least common operation: {patterns['least_common_op']} ({patterns['least_common_percentage']:.1f}%)")
        
#         if patterns["cell_specific_patterns"]:
#             print("\nCell-specific patterns:")
#             for pattern in patterns["cell_specific_patterns"]:
#                 print(f"- {pattern}")
#     else:
#         print(f"Checkpoint file {checkpoint_path} not found.")
#         print("To visualize your architecture, run this script after training or modify the path.")

In [21]:
# PPO Controller for architecture search
class PPOController(nn.Module):
    """Actor-critic controller used to sample architecture decisions.

    Holds two independent three-layer Tanh MLPs over the same state input:
    ``actor`` outputs ``action_dim`` logits (turned into probabilities in
    ``forward``) and ``critic`` outputs a single value estimate.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=64):
        super().__init__()

        # Print dimensions for debugging
        print(f"Initializing PPO controller with state_dim={state_dim}, action_dim={action_dim}")

        def build_mlp(out_features):
            # Both heads share this layout; only the output width differs.
            return nn.Sequential(
                nn.Linear(state_dim, hidden_dim),
                nn.Tanh(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.Tanh(),
                nn.Linear(hidden_dim, out_features),
            )

        # Policy head first, then value head (keeps parameter creation order)
        self.actor = build_mlp(action_dim)
        self.critic = build_mlp(1)

    def forward(self, state):
        """Return ``(action_probs, value)`` for ``state``."""
        policy_logits = self.actor(state)
        probabilities = F.softmax(policy_logits, dim=-1)
        state_value = self.critic(state)
        return probabilities, state_value

    def act(self, state):
        """Sample an action; return it and its log-probability, both detached."""
        probabilities, _ = self.forward(state)
        policy = Categorical(probabilities)
        sampled = policy.sample()
        sampled_log_prob = policy.log_prob(sampled)
        return sampled.detach(), sampled_log_prob.detach()

# PPO Training function
def train_ppo(controller, optimizer, memories, clip_ratio=0.2, epochs=10, entropy_coef=0.01):
    """Train the PPO controller on collected experiences.

    Uses the clipped surrogate objective with the (optionally normalized)
    rewards acting directly as advantages; the critic output is regressed
    onto those same rewards and is not used as a baseline here.

    Args:
        controller: Module whose call returns ``(action_probs, value)`` for a
            batch of states of shape ``(N, state_dim)``.
        optimizer: Optimizer over the controller's parameters.
        memories: List of dicts with tensor entries ``'state'``, ``'action'``,
            ``'log_prob'`` and ``'reward'``; one action/log_prob/reward
            element per experience.
        clip_ratio: PPO clipping range for the probability ratio.
        epochs: Number of optimisation passes over the whole batch.
        entropy_coef: Weight of the entropy bonus.

    Returns:
        Tuple ``(actor_loss, critic_loss, entropy_loss)`` averaged over the
        epochs, where ``entropy_loss`` is the negative mean policy entropy.
        All zeros when there is nothing to train on.
    """
    if not memories or epochs <= 0:
        return 0.0, 0.0, 0.0

    # Unpack memories into flat per-experience tensors
    actions = torch.cat([m['action'] for m in memories]).reshape(-1)
    old_log_probs = torch.cat([m['log_prob'] for m in memories]).reshape(-1).detach()
    rewards = torch.cat([m['reward'] for m in memories]).reshape(-1).detach()
    # One row per experience, whatever rank the stored states had
    states = torch.cat([m['state'] for m in memories]).reshape(actions.shape[0], -1)

    # Normalize rewards for stable training. std() of a single element is
    # NaN, and a zero std means all rewards are equal: skip in both cases.
    if rewards.numel() > 1 and rewards.std() > 0:
        rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-8)

    # Store metrics for logging
    metrics = {
        'actor_loss': 0.0,
        'critic_loss': 0.0,
        'entropy_loss': 0.0
    }

    # Train for multiple epochs
    for _ in range(epochs):
        # Evaluate the current policy on the whole batch in one forward pass
        action_probs, values = controller(states)
        dist = Categorical(action_probs)
        log_probs = dist.log_prob(actions)
        entropy = dist.entropy()
        values = values.reshape(-1)

        # Compute ratio and surrogate loss
        ratio = torch.exp(log_probs - old_log_probs)
        surr1 = ratio * rewards
        surr2 = torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio) * rewards

        # PPO losses
        actor_loss = -torch.min(surr1, surr2).mean()
        critic_loss = F.mse_loss(values, rewards)
        entropy_loss = -entropy.mean()

        # Total loss. entropy_loss is already the *negative* entropy, so it
        # must be added: minimising it maximises entropy (exploration bonus).
        loss = actor_loss + 0.5 * critic_loss + entropy_coef * entropy_loss

        # Update controller
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Update metrics (running average over epochs)
        metrics['actor_loss'] += actor_loss.item() / epochs
        metrics['critic_loss'] += critic_loss.item() / epochs
        metrics['entropy_loss'] += entropy_loss.item() / epochs

    return metrics['actor_loss'], metrics['critic_loss'], metrics['entropy_loss']

In [None]:
# Evaluation helper function
def evaluate_architecture(model, val_loader, device, architecture_weights=None, discrete=False):
    """Compute the equal error rate (EER) of an architecture on ``val_loader``.

    The model is switched to eval mode. When ``architecture_weights`` is
    given it is passed to the model as a second positional argument;
    otherwise the model is called with ``discrete=discrete``. Batches that
    raise are reported and skipped. Returns 0.5 when no scores were
    collected, when fewer than two classes are present, or when the EER
    computation itself fails.
    """
    model.eval()
    labels = []
    class1_scores = []
    use_external_weights = architecture_weights is not None

    with torch.no_grad():
        for batch_inputs, batch_targets in val_loader:
            try:
                batch_inputs, batch_targets = batch_inputs.to(device), batch_targets.to(device)
                # Replace NaNs so a single bad sample does not poison the batch
                if torch.isnan(batch_inputs).any():
                    batch_inputs = torch.nan_to_num(batch_inputs, nan=0.0)

                if use_external_weights:
                    # PPO mode with external weights
                    logits = model(batch_inputs, architecture_weights)
                else:
                    # DARTS mode with internal alphas
                    logits = model(batch_inputs, discrete=discrete)

                # Probability assigned to class index 1
                batch_scores = F.softmax(logits, dim=1)[:, 1].cpu().numpy()
                labels.extend(batch_targets.cpu().numpy())
                class1_scores.extend(batch_scores)
            except Exception as e:
                print(f"Error in evaluation: {e}")

    # EER: operating point where false-positive and false-negative rates meet
    eer = 0.5
    try:
        have_data = len(labels) > 0 and len(class1_scores) > 0
        if have_data and len(np.unique(labels)) >= 2:
            fpr, tpr, thresholds = roc_curve(labels, class1_scores, pos_label=1)
            fnr = 1 - tpr
            crossing = np.nanargmin(np.absolute(fnr - fpr))
            eer = (fpr[crossing] + fnr[crossing]) / 2
    except Exception as e:
        print(f"Error calculating EER: {e}")
        eer = 0.5

    return eer

# Hybrid PPO-DARTS search function
def search_architecture_hybrid(train_loader, val_loader, device, input_channels=60, num_cells=3,
                              num_nodes=4, num_ops=10, epochs=30, ppo_updates=5,
                              project_name="deepfake-nas-hybrid"):
    """Hybrid approach combining PPO for exploration with DARTS for optimization.

    Each epoch is randomly run in one of two modes. The probability of a
    PPO epoch is ``exploration_ratio * (1 - epoch / epochs)``, so it decays
    linearly to zero:

    * PPO epoch: the controller samples ``total_edges * num_ops`` actions,
      the resulting vector is passed to the model as external architecture
      weights, the model weights are trained for one pass over
      ``train_loader`` and the architecture is scored on ``val_loader``.
    * DARTS epoch: model weights are trained on ``train_loader`` using the
      model's internal alphas, then the alphas are updated on a random
      ~20% subset of ``val_loader`` (at most 50 batches), and the
      discretised architecture is scored on ``val_loader``.

    The best (lowest-EER) architecture is checkpointed to
    ``best_hybrid_model.pth`` whenever an epoch strictly improves on it.

    Args:
        train_loader: Iterable of ``(inputs, targets)`` training batches.
        val_loader: Iterable of ``(inputs, targets)`` validation batches.
        device: Torch device the model and batches are moved to.
        input_channels, num_cells, num_nodes, num_ops: Forwarded to
            ``DeepfakeDetectionModel``.
        epochs: Number of search epochs.
        ppo_updates: The controller is trained only on PPO epochs whose
            index is a multiple of this value.
        project_name: wandb project name.

    Returns:
        Tuple ``(model, best_architecture, best_val_eer)``.
        ``best_architecture`` is a 1-D tensor of sampled actions when found
        by PPO, a detached copy of ``model._alphas`` when found by DARTS,
        or ``None`` if no epoch improved on the initial EER of 1.0.
    """
    # Initialize wandb
    wandb.init(project=project_name, name=f"Hybrid_PPO_DARTS_cells{num_cells}_nodes{num_nodes}")

    # try/finally so the wandb run is closed even if the search raises
    try:
        # Log hyperparameters
        config = {
            "input_channels": input_channels,
            "num_cells": num_cells,
            "num_nodes": num_nodes,
            "num_ops": num_ops,
            "epochs": epochs,
            "w_lr": 0.001,        # Weight learning rate
            "alpha_lr": 0.0003,   # Architecture parameter learning rate
            "ppo_lr": 0.0005,     # PPO controller learning rate
            "ppo_updates": ppo_updates,
            "exploration_ratio": 0.3,  # Ratio of epochs to use PPO exploration
            "visualization_enabled": True  # Enable visualization
        }
        wandb.config.update(config)

        # Initialize model with expanded operation set
        model = DeepfakeDetectionModel(input_channels, num_cells, num_nodes, num_ops).to(device)

        # Calculate edges for PPO controller
        edges_per_cell = sum(range(1, num_nodes+1))
        total_edges = num_cells * edges_per_cell

        # Initialize PPO controller for exploration
        state_dim = 1  # Single value for validation performance
        action_dim = num_ops  # Number of operations per edge
        controller = PPOController(state_dim, action_dim).to(device)
        controller_optimizer = optim.Adam(controller.parameters(), lr=config["ppo_lr"])

        # Setup optimizers for DARTS
        w_optimizer = optim.Adam(model.weights(), lr=config["w_lr"], weight_decay=3e-4)
        w_scheduler = optim.lr_scheduler.CosineAnnealingLR(w_optimizer, epochs)
        alpha_optimizer = optim.Adam(model.alphas(), lr=config["alpha_lr"], betas=(0.5, 0.999), weight_decay=1e-3)

        # Metrics tracking
        best_val_eer = 1.0
        best_architecture = None
        best_mode = None

        with tqdm(total=epochs, desc="Hybrid Search Progress", position=0, leave=True) as epoch_pbar:
            for epoch in range(epochs):
                # Determine exploration mode for this epoch
                # More exploration in early stages, more exploitation later
                use_ppo = (random.random() < config["exploration_ratio"] * (1 - epoch/epochs))

                # Set only when this epoch strictly beats the best EER so far;
                # drives checkpointing below.
                improved = False

                # PPO exploration phase
                if use_ppo:
                    model.train()
                    train_loss = 0.0
                    batch_count = 0

                    # For PPO
                    memories = []
                    current_architecture = []

                    # Use current best validation EER as state. It does not
                    # change while sampling, so it is built once.
                    state = torch.FloatTensor([min(best_val_eer, 0.5) * 2]).to(device)

                    # Sample architecture using PPO: num_ops actions per edge
                    # NOTE(review): each entry is a sampled action index in
                    # [0, num_ops); how the model interprets this vector is
                    # defined in DeepfakeDetectionModel - confirm there.
                    for _ in range(total_edges):
                        for _ in range(num_ops):
                            action, log_prob = controller.act(state)
                            current_architecture.append(action.item())

                            # Store experience for PPO
                            memories.append({
                                'state': state.clone(),
                                'action': action.unsqueeze(0),
                                'log_prob': log_prob.unsqueeze(0),
                                'reward': torch.zeros(1).to(device)  # Updated later
                            })

                    # Convert architecture to tensor for PPO mode
                    architecture_weights = torch.FloatTensor(current_architecture).to(device)

                    # Train model with PPO-generated architecture
                    for inputs, targets in train_loader:
                        batch_count += 1
                        if batch_count % 10 == 0:
                            print(f"\rPPO Training batch {batch_count}/{len(train_loader)}", end="")

                        inputs, targets = inputs.to(device), targets.to(device)
                        if torch.isnan(inputs).any():
                            inputs = torch.nan_to_num(inputs, nan=0.0)

                        # Update weights
                        w_optimizer.zero_grad()
                        outputs = model(inputs, architecture_weights)  # Use PPO architecture
                        loss = F.cross_entropy(outputs, targets)
                        loss.backward()
                        torch.nn.utils.clip_grad_norm_(model.weights(), max_norm=1.0)
                        w_optimizer.step()

                        train_loss += loss.item()

                    print()  # Line break after training

                    # Evaluate architecture from PPO
                    val_eer = evaluate_architecture(model, val_loader, device, architecture_weights)

                    # Reward is the EER improvement over the previous best
                    # (zero when there is none); shared by every sampled action.
                    reward = best_val_eer - val_eer if val_eer < best_val_eer else 0
                    for memory in memories:
                        memory['reward'] = torch.FloatTensor([reward]).to(device)

                    # Update best architecture if improved
                    if val_eer < best_val_eer:
                        improved = True
                        best_val_eer = val_eer
                        best_architecture = architecture_weights.clone()
                        best_mode = 'ppo'
                        print(f"\nNew best architecture found via PPO! EER: {best_val_eer:.4f}")

                    # Update PPO controller
                    if epoch % config["ppo_updates"] == 0 and memories:
                        actor_loss, critic_loss, entropy_loss = train_ppo(
                            controller, controller_optimizer, memories)

                        # Log PPO metrics
                        wandb.log({
                            "actor_loss": actor_loss,
                            "critic_loss": critic_loss,
                            "entropy_loss": entropy_loss
                        })

                # DARTS optimization phase
                else:
                    model.train()
                    train_loss = 0.0
                    batch_count = 0

                    # Phase 1: Train model weights using DARTS approach
                    for inputs, targets in train_loader:
                        batch_count += 1
                        if batch_count % 10 == 0:
                            print(f"\rDARTS Weight Training batch {batch_count}/{len(train_loader)}", end="")

                        inputs, targets = inputs.to(device), targets.to(device)
                        if torch.isnan(inputs).any():
                            inputs = torch.nan_to_num(inputs, nan=0.0)

                        # Update weights with internal alphas
                        w_optimizer.zero_grad()
                        outputs = model(inputs)  # Use alphas without external weights
                        loss = F.cross_entropy(outputs, targets)
                        loss.backward()
                        torch.nn.utils.clip_grad_norm_(model.weights(), max_norm=1.0)
                        w_optimizer.step()

                        train_loss += loss.item()

                    print()  # Line break after training

                    # Phase 2: Update architecture parameters on validation set
                    model.train()  # Keep in train mode for alpha updates
                    val_batch_count = 0

                    for inputs, targets in val_loader:
                        # Use a subset of validation data
                        if random.random() > 0.2:  # Sample ~20% for alpha updates
                            continue

                        val_batch_count += 1
                        if val_batch_count > 50:  # Limit validation batches for speed
                            break

                        inputs, targets = inputs.to(device), targets.to(device)
                        if torch.isnan(inputs).any():
                            inputs = torch.nan_to_num(inputs, nan=0.0)

                        # Update alphas
                        alpha_optimizer.zero_grad()
                        outputs = model(inputs)  # Use internal alphas
                        loss = F.cross_entropy(outputs, targets)
                        loss.backward()
                        alpha_optimizer.step()

                    # Evaluate DARTS architecture
                    val_eer = evaluate_architecture(model, val_loader, device, discrete=True)

                    # Update best architecture if improved
                    if val_eer < best_val_eer:
                        improved = True
                        best_val_eer = val_eer
                        # Save alphas as best architecture
                        best_architecture = model._alphas.detach().clone()
                        best_mode = 'darts'
                        print(f"\nNew best architecture found via DARTS! EER: {best_val_eer:.4f}")

                # Update learning rate for weights
                w_scheduler.step()

                # Update epoch progress bar
                epoch_pbar.update(1)
                epoch_pbar.set_postfix({
                    "Mode": "PPO" if use_ppo else "DARTS",
                    "Val EER": f"{val_eer:.4f}",
                    "Best EER": f"{best_val_eer:.4f}"
                })

                # Log metrics to wandb
                avg_train_loss = train_loss / max(batch_count, 1)
                wandb.log({
                    "epoch": epoch + 1,
                    "train_loss": avg_train_loss,
                    "val_eer": val_eer,
                    "best_val_eer": best_val_eer,
                    "mode": "PPO" if use_ppo else "DARTS",
                    "learning_rate": w_optimizer.param_groups[0]['lr']
                })

                # Save checkpoint for the best model. Only on a strict
                # improvement: on a tie, best_architecture may come from an
                # earlier epoch (possibly the other mode) and would not match
                # the current weights or the key it is stored under.
                if improved:
                    checkpoint_path = 'best_hybrid_model.pth'
                    if use_ppo:
                        # Save PPO-generated architecture
                        torch.save({
                            'model_state_dict': model.state_dict(),
                            'ppo_architecture': best_architecture,
                            'eer': best_val_eer,
                            'epoch': epoch + 1,
                            'mode': 'PPO'
                        }, checkpoint_path)
                        # Add visualization during training
                        if epoch % 5 == 0:  # Only create visualizations periodically to save time
                            visualize_ppo_architecture(
                                best_architecture,
                                num_cells=num_cells,
                                num_nodes=num_nodes,
                                num_ops=num_ops,
                                save_path=f"ppo_arch_epoch_{epoch}.png",
                                save_to_wandb=True
                            )
                    else:
                        # Save DARTS-generated architecture (detached snapshot
                        # of the alphas taken when the best EER was measured)
                        torch.save({
                            'model_state_dict': model.state_dict(),
                            'darts_alphas': best_architecture,
                            'eer': best_val_eer,
                            'epoch': epoch + 1,
                            'mode': 'DARTS'
                        }, checkpoint_path)
                        # Add visualization during training
                        if epoch % 5 == 0:  # Only create visualizations periodically
                            visualize_darts_architecture(
                                best_architecture,
                                num_cells=num_cells,
                                num_nodes=num_nodes,
                                num_ops=num_ops,
                                save_path=f"darts_arch_epoch_{epoch}.png",
                                save_to_wandb=True
                            )

                    # Log best model to wandb
                    wandb.save(checkpoint_path)
    finally:
        # Finish wandb run
        wandb.finish()

    # Return the best architecture (either from PPO or DARTS)
    return model, best_architecture, best_val_eer

In [23]:
def evaluate_model(model, architecture, test_loader, device, log_to_wandb=True):
    """Evaluate the model with the best architecture.

    Runs the model over ``test_loader`` without gradients, accumulating
    cross-entropy loss, accuracy and class-1 softmax scores, then computes
    the equal error rate (EER), saves a ROC plot to ``roc_curve.png`` and
    optionally logs metrics plus a confusion matrix
    (``confusion_matrix.png``) to wandb.

    Args:
        model: Network to evaluate. Called as ``model(inputs, architecture)``
            for a 1-D architecture tensor, otherwise as
            ``model(inputs, discrete=True)`` after ``model._alphas`` has been
            overwritten with ``architecture``.
        architecture: Tensor; a 1-D tensor is treated as a PPO architecture,
            any other rank as DARTS alphas.
        test_loader: Iterable of ``(inputs, targets)`` batches.
        device: Torch device the batches are moved to.
        log_to_wandb: Log results when a wandb run is active.

    Returns:
        Tuple ``(eer, eer_threshold)``; both are 0.5 when the EER could not
        be computed.
    """
    model.eval()
    all_targets = []
    all_scores = []
    test_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    # Determine if architecture is from PPO or DARTS
    is_ppo_arch = architecture.dim() == 1

    # DARTS architecture: install the alphas once, before the loop
    if not is_ppo_arch:
        model._alphas.data = architecture.data

    # Create a progress bar for evaluation
    eval_pbar = tqdm(test_loader, desc="Evaluating")

    with torch.no_grad():
        for inputs, targets in eval_pbar:
            inputs, targets = inputs.to(device), targets.to(device)

            # Choose correct evaluation mode
            if is_ppo_arch:
                outputs = model(inputs, architecture)
            else:
                outputs = model(inputs, discrete=True)

            # Compute loss
            loss = F.cross_entropy(outputs, targets)
            test_loss += loss.item()
            num_batches += 1

            # Compute accuracy
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            # Get scores for EER calculation
            scores = F.softmax(outputs, dim=1)[:, 1].cpu().numpy()
            all_targets.extend(targets.cpu().numpy())
            all_scores.extend(scores)

            # Update progress bar
            eval_pbar.set_postfix({
                "loss": f"{loss.item():.4f}",
                "accuracy": f"{100.0 * correct / total:.2f}%"
            })

    # Calculate test accuracy / loss (guarded against an empty loader)
    test_accuracy = 100.0 * correct / total if total else 0.0
    avg_test_loss = test_loss / num_batches if num_batches else 0.0

    # Calculate EER. fpr/tpr stay None if roc_curve itself fails, so the
    # plotting code below can tell whether there is a curve to draw.
    fpr = tpr = None
    try:
        fpr, tpr, thresholds = roc_curve(all_targets, all_scores, pos_label=1)
        fnr = 1 - tpr
        idx = np.nanargmin(np.absolute(fnr - fpr))
        eer = (fpr[idx] + fnr[idx]) / 2
        eer_threshold = thresholds[idx]
    except Exception as e:
        print(f"Error calculating EER: {e}")
        eer = 0.5
        eer_threshold = 0.5

    # Plot ROC curve (only the chance diagonal when no curve is available)
    roc_fig = plt.figure(figsize=(10, 8))
    if fpr is not None and tpr is not None:
        plt.plot(fpr, tpr, label=f'ROC Curve (EER = {eer:.4f})')
    plt.plot([0, 1], [0, 1], 'k--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC Curve for Deepfake Detection')
    if fpr is not None and tpr is not None:
        plt.legend(loc="lower right")

    # Save ROC curve, then release the figure
    roc_curve_path = 'roc_curve.png'
    plt.savefig(roc_curve_path)
    plt.close(roc_fig)

    # Print results
    print(f"\nTest Results - Loss: {avg_test_loss:.4f}, Accuracy: {test_accuracy:.2f}%, EER: {eer:.4f}")

    # Log to wandb if requested
    if log_to_wandb and wandb.run is not None:
        wandb.log({
            "test_loss": avg_test_loss,
            "test_accuracy": test_accuracy,
            "test_eer": eer,
            "eer_threshold": eer_threshold,
            "roc_curve": wandb.Image(roc_curve_path)
        })

        # Confusion matrix at the EER threshold: rows are the true label,
        # columns the predicted label (labels are used as 0/1 indices).
        cm = np.zeros((2, 2))
        for target, score in zip(all_targets, all_scores):
            pred_class = 1 if score > eer_threshold else 0
            cm[int(target)][pred_class] += 1

        # Normalize each row; the floor avoids 0/0 for an absent class
        cm_norm = cm / np.maximum(cm.sum(axis=1, keepdims=True), 1e-8)

        # Plot confusion matrix
        cm_fig = plt.figure(figsize=(8, 6))
        plt.imshow(cm_norm, cmap='Blues')
        plt.colorbar()
        plt.title('Normalized Confusion Matrix')
        plt.xlabel('Predicted')
        plt.ylabel('True')
        plt.xticks([0, 1], ['Bonafide', 'Spoof'])
        plt.yticks([0, 1], ['Bonafide', 'Spoof'])

        # Add text annotations
        for i in range(2):
            for j in range(2):
                plt.text(j, i, f'{cm[i, j]:.0f}\n({cm_norm[i, j]:.2f})',
                         ha='center', va='center',
                         color='white' if cm_norm[i, j] > 0.5 else 'black')

        cm_path = 'confusion_matrix.png'
        plt.savefig(cm_path)
        plt.close(cm_fig)
        wandb.log({"confusion_matrix": wandb.Image(cm_path)})

    return eer, eer_threshold

In [None]:
# ===========================
# Architecture Visualization
# ===========================
print("Visualizing discovered architectures...")

# The helpers called below (visualize_ppo_architecture,
# visualize_darts_architecture, analyze_architecture_statistics,
# plot_architecture_statistics) must already be defined before this cell runs.
if wandb.run is not None:
    print("Generating visualizations for best architecture...")

    # Load best architecture from checkpoint
    checkpoint_path = 'best_hybrid_model.pth'

    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path)

        # Determine architecture type from the key the checkpoint was saved under
        if 'ppo_architecture' in checkpoint:
            architecture = checkpoint['ppo_architecture']
            is_ppo = True
            best_mode = 'PPO'
        elif 'darts_alphas' in checkpoint:
            architecture = checkpoint['darts_alphas']
            is_ppo = False
            best_mode = 'DARTS'
        else:
            # Neither key present: nothing to visualize. Without this branch
            # the names used below would be undefined (NameError).
            architecture = None
            print(f"No architecture found in checkpoint {checkpoint_path}; skipping visualization.")

        if architecture is not None:
            # Create visualizations
            if is_ppo:
                save_path = visualize_ppo_architecture(architecture, save_to_wandb=True)
            else:
                save_path = visualize_darts_architecture(architecture, save_to_wandb=True)

            # Analyze architecture statistics
            stats = analyze_architecture_statistics(architecture, is_ppo=is_ppo)
            plot_architecture_statistics(stats, save_to_wandb=True)

            # Print key findings
            patterns = stats["patterns"]
            print(f"Architecture Analysis ({best_mode}):")
            print(f"Most common operation: {patterns['most_common_op']} ({patterns['most_common_percentage']:.1f}%)")

            if patterns["cell_specific_patterns"]:
                print("Cell-specific patterns detected!")
    else:
        print(f"Checkpoint file {checkpoint_path} not found.")

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import roc_curve
import soundfile as sf
from tqdm import tqdm
import matplotlib.pyplot as plt
import librosa
import random
from torch.distributions import Categorical
import wandb
import warnings
warnings.filterwarnings('ignore')

def setup_progress_monitoring():
    """Setup enhanced progress monitoring.

    Starts a daemon thread that repeatedly prints one status line with the
    elapsed time, this process's CPU percentage, its resident memory in MB
    and the CUDA memory currently allocated by torch in MB. Each cycle
    samples CPU usage for 1 second and then sleeps for 5 seconds. The
    thread stops itself (with a message) on the first unexpected error.
    """
    import threading
    import time
    import psutil
    import os

    def monitor_resources():
        process = psutil.Process(os.getpid())
        start_time = time.time()
        while True:
            try:
                # CPU usage (blocks for the 1 s sampling interval)
                cpu_percent = process.cpu_percent(interval=1)
                # Memory usage
                memory_info = process.memory_info()
                memory_mb = memory_info.rss / (1024 * 1024)
                # GPU memory if available
                gpu_memory_mb = 0
                try:
                    if torch.cuda.is_available():
                        gpu_memory_mb = torch.cuda.memory_allocated() / (1024 * 1024)
                except Exception:
                    # Best effort: a failing CUDA query must not stop the
                    # monitor; the GPU figure simply stays at 0.
                    pass

                elapsed = time.time() - start_time
                hours, remainder = divmod(elapsed, 3600)
                minutes, seconds = divmod(remainder, 60)

                print(f"\r[{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}] "
                      f"CPU: {cpu_percent:.1f}% | RAM: {memory_mb:.1f} MB | "
                      f"GPU: {gpu_memory_mb:.1f} MB", end="", flush=True)

                time.sleep(5)  # Update every 5 seconds
            except Exception as exc:
                # Monitoring is auxiliary: report why it stopped instead of
                # dying silently, then end the thread.
                print(f"\nResource monitoring stopped: {exc}")
                break

    # Start monitoring in a background thread; daemon=True so it never
    # keeps the interpreter alive on exit.
    monitor_thread = threading.Thread(target=monitor_resources, daemon=True)
    monitor_thread.start()
    print("Resource monitoring started...")

# Set random seeds for reproducibility
def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch RNGs with ``seed``.

    When CUDA is available the CUDA generator is seeded as well and cuDNN
    is switched to deterministic mode.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

# Private helpers used by main()
def _build_loader(dataset, batch_size, shuffle, num_workers):
    """Wrap *dataset* in a DataLoader with pinned memory enabled."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,
    )


def _write_summary(output_dir, summary):
    """Write ``summary.json`` and ``summary.txt`` into *output_dir*.

    Args:
        output_dir: Existing directory that receives both files.
        summary: The summary dictionary assembled by ``main``.

    Returns:
        Tuple ``(text_path, json_path)`` of the files written.
    """
    import json
    import os

    json_path = os.path.join(output_dir, "summary.json")
    text_path = os.path.join(output_dir, "summary.txt")
    arch = summary["model_architecture"]

    # Save summary as JSON
    with open(json_path, "w") as f:
        json.dump(summary, f, indent=4)

    # Also save as text for easy reading
    with open(text_path, "w") as f:
        f.write("ASVspoof 2019 Deepfake Detection Summary\n")
        f.write("=" * 50 + "\n\n")
        f.write(f"Experiment name: {summary['experiment_name']}\n")
        f.write(f"Search method: {summary['search_method'].upper()}\n")
        f.write(f"Feature type: {summary['feature_type']}\n\n")
        f.write("Performance metrics:\n")
        f.write(f"  Best validation EER: {summary['best_validation_eer']:.4f}\n")
        f.write(f"  Test EER: {summary['test_eer']:.4f}\n")
        f.write(f"  EER threshold: {summary['eer_threshold']:.4f}\n\n")
        f.write("Model architecture:\n")
        f.write(f"  Number of cells: {arch['num_cells']}\n")
        f.write(f"  Number of nodes per cell: {arch['num_nodes']}\n")
        f.write(f"  Number of operations: {arch['num_operations']}\n")

    return text_path, json_path


# Main function follows
def main():
    """Run the ASVspoof2019 LA architecture-search experiment end to end.

    Steps, in order: start resource monitoring and seed the RNGs, log in to
    Weights & Biases and open a run for the search, build the train / dev /
    eval datasets and loaders, run ``search_architecture_hybrid``, save the
    best architecture, open a second W&B run for the final evaluation,
    evaluate on the eval split, render the architecture diagram, and write
    JSON and text summaries into a per-experiment directory under
    ``/kaggle/working``.

    The W&B API key is read from the ``WANDB_API_KEY`` environment variable
    rather than being embedded in the source. Whichever W&B run is active is
    finished even when one of the steps raises.
    """
    import os
    import shutil
    import time

    # Start the resource monitoring first, then set the random seed
    setup_progress_monitoring()
    set_seed()

    # The timestamp makes every experiment name (and output directory) unique
    experiment_name = f"ASVspoof2019_NAS_{int(time.time())}"

    print("Starting Deepfake Audio Detection with NAS")
    print("=" * 80)

    # Paths and parameters (based on the provided dataset structure)
    base_dir = "/kaggle/input/asvspoof-dataset-2019"
    la_dir = os.path.join(base_dir, "LA")
    protocol_dir = os.path.join(la_dir, "ASVspoof2019_LA_cm_protocols")

    data_dir_train = os.path.join(la_dir, "ASVspoof2019_LA_train", "flac")
    data_dir_dev = os.path.join(la_dir, "ASVspoof2019_LA_dev", "flac")
    data_dir_eval = os.path.join(la_dir, "ASVspoof2019_LA_eval", "flac")

    train_protocol = os.path.join(protocol_dir, "ASVspoof2019.LA.cm.train.trn.txt")
    dev_protocol = os.path.join(protocol_dir, "ASVspoof2019.LA.cm.dev.trl.txt")
    eval_protocol = os.path.join(protocol_dir, "ASVspoof2019.LA.cm.eval.trl.txt")

    # Log dataset information
    print(f"Train data directory: {data_dir_train}")
    print(f"Train protocol file: {train_protocol}")
    print(f"Dev data directory: {data_dir_dev}")
    print(f"Dev protocol file: {dev_protocol}")
    print(f"Eval data directory: {data_dir_eval}")
    print(f"Eval protocol file: {eval_protocol}")

    # Experiment configuration
    feature_type = 'mfcc'  # Options: 'mfcc', 'spec', 'cqt'
    max_seq_len = 400
    batch_size_train = 32
    batch_size_eval = 64
    num_workers = 4

    # Select search method: 'hybrid' for PPO+DARTS
    search_method = 'hybrid'

    # Credentials come from the environment, never from the source file.
    # With key=None wandb uses its own credential lookup.
    wandb.login(key=os.environ.get("WANDB_API_KEY") or None)

    wandb.init(
        project="ASVspoof2019-NAS",
        name=f"{experiment_name}_{search_method}",
        config={
            "feature_type": feature_type,
            "max_sequence_length": max_seq_len,
            "batch_size_train": batch_size_train,
            "batch_size_eval": batch_size_eval,
            "num_workers": num_workers,
            "dataset": "ASVspoof2019 LA",
            "search_method": search_method
        }
    )

    try:
        # Create datasets
        print("Creating datasets...")

        print("Loading training dataset...")
        train_dataset = ASVSpoofDataset(
            root_dir=data_dir_train,
            protocol_file=train_protocol,
            feature_type=feature_type,
            max_len=max_seq_len,
            is_train=True
        )

        print("Loading validation dataset...")
        dev_dataset = ASVSpoofDataset(
            root_dir=data_dir_dev,
            protocol_file=dev_protocol,
            feature_type=feature_type,
            max_len=max_seq_len,
            is_train=False
        )

        print("Loading evaluation dataset...")
        eval_dataset = ASVSpoofDataset(
            root_dir=data_dir_eval,
            protocol_file=eval_protocol,
            feature_type=feature_type,
            max_len=max_seq_len,
            is_train=False
        )

        # Log dataset sizes
        print(f"Training dataset size: {len(train_dataset)} samples")
        print(f"Validation dataset size: {len(dev_dataset)} samples")
        print(f"Evaluation dataset size: {len(eval_dataset)} samples")
        wandb.log({
            "train_dataset_size": len(train_dataset),
            "val_dataset_size": len(dev_dataset),
            "eval_dataset_size": len(eval_dataset)
        })

        # Create data loaders (only the training loader shuffles)
        train_loader = _build_loader(train_dataset, batch_size_train, True, num_workers)
        dev_loader = _build_loader(dev_dataset, batch_size_eval, False, num_workers)
        eval_loader = _build_loader(eval_dataset, batch_size_eval, False, num_workers)

        # Architecture search parameters
        nas_config = {
            "input_channels": 60,  # 20 MFCCs x 3 (static, delta, delta-delta)
            "num_cells": 3,
            "num_nodes": 4,
            "num_ops": 10,  # Expanded to 10 operations
            "epochs": 30,
            "ppo_updates": 5,
            "project_name": "ASVspoof2019-NAS"
        }

        # Log NAS configuration
        print("\nNeural Architecture Search Configuration:")
        for key, value in nas_config.items():
            print(f"  {key}: {value}")

        # Create output directory for results
        output_dir = f"/kaggle/working/results_{experiment_name}_{search_method}"
        os.makedirs(output_dir, exist_ok=True)
        print(f"\nResults will be saved to {output_dir}")

        # Perform architecture search using hybrid PPO-DARTS approach
        print(f"\nStarting Neural Architecture Search using {search_method.upper()} method...")

        model, best_architecture, best_val_eer = search_architecture_hybrid(
            train_loader=train_loader,
            val_loader=dev_loader,
            device=device,
            input_channels=nas_config["input_channels"],
            num_cells=nas_config["num_cells"],
            num_nodes=nas_config["num_nodes"],
            num_ops=nas_config["num_ops"],
            epochs=nas_config["epochs"],
            ppo_updates=nas_config["ppo_updates"],
            project_name=nas_config["project_name"] + "-hybrid"
        )

        # Save best architecture
        torch.save(best_architecture, os.path.join(output_dir, "best_architecture.pt"))

        # Initialize a new wandb run for final evaluation
        wandb.finish()  # Finish the NAS run
        wandb.init(
            project="ASVspoof2019-NAS",
            name=f"{experiment_name}_{search_method}_final_evaluation",
            config={
                "feature_type": feature_type,
                "best_val_eer": best_val_eer,
                "search_method": search_method
            }
        )

        # Evaluate on the evaluation set
        print("\nPerforming final evaluation on test set...")
        test_eer, eer_threshold = evaluate_model(model, best_architecture, eval_loader, device)

        # Log final metrics
        wandb.log({
            "final_test_eer": test_eer,
            "eer_threshold": eer_threshold
        })

        # Visualize the architecture with annotations
        print("\nVisualizing the best architecture...")
        # A 1-D architecture tensor is treated as a PPO architecture,
        # anything else as a DARTS architecture.
        is_ppo_arch = best_architecture.dim() == 1
        visualize = visualize_architecture if is_ppo_arch else visualize_darts_architecture
        fig_path = visualize(best_architecture,
                             num_cells=nas_config["num_cells"],
                             num_nodes=nas_config["num_nodes"],
                             num_ops=nas_config["num_ops"],
                             save_to_wandb=True)

        # Save architecture visualization
        shutil.copy(fig_path, os.path.join(output_dir, "architecture_visualization.png"))

        # Create a summary report. Metrics are cast to built-in floats so
        # json.dump also works if they come back as NumPy or torch scalars.
        summary = {
            "experiment_name": experiment_name,
            "search_method": search_method,
            "feature_type": feature_type,
            "best_validation_eer": float(best_val_eer),
            "test_eer": float(test_eer),
            "eer_threshold": float(eer_threshold),
            "model_architecture": {
                "num_cells": nas_config["num_cells"],
                "num_nodes": nas_config["num_nodes"],
                "num_operations": nas_config["num_ops"]
            }
        }
        text_path, json_path = _write_summary(output_dir, summary)

        # Log summary to wandb
        wandb.save(text_path)
        wandb.save(json_path)

        print("\nExperiment completed!")
        print(f"Final Test EER: {test_eer:.4f}, EER Threshold: {eer_threshold:.4f}")
        print(f"All results saved to {output_dir}")
        print("=" * 80)
    finally:
        # Finish wandb, also on failure, so no run is left open
        wandb.finish()

# Run the experiment only when this code is executed directly
# (not when it is imported as a module).
if __name__ == "__main__":
    main()

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mrnparikh[0m ([33mrnparikh-carnegie-mellon-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Resource monitoring started...
Starting Deepfake Audio Detection with NAS
Train data directory: /kaggle/input/asvspoof-dataset-2019/LA/ASVspoof2019_LA_train/flac
Train protocol file: /kaggle/input/asvspoof-dataset-2019/LA/ASVspoof2019_LA_cm_protocols/ASVspoof2019.LA.cm.train.trn.txt
Dev data directory: /kaggle/input/asvspoof-dataset-2019/LA/ASVspoof2019_LA_dev/flac
Dev protocol file: /kaggle/input/asvspoof-dataset-2019/LA/ASVspoof2019_LA_cm_protocols/ASVspoof2019.LA.cm.dev.trl.txt
Eval data directory: /kaggle/input/asvspoof-dataset-2019/LA/ASVspoof2019_LA_eval/flac
Eval protocol file: /kaggle/input/asvspoof-dataset-2019/LA/ASVspoof2019_LA_cm_protocols/ASVspoof2019.LA.cm.eval.trl.txt


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


[00:02:37] CPU: 1.0% | RAM: 638.5 MB | GPU: 0.0 MBB

[00:00:07] CPU: 42.0% | RAM: 639.7 MB | GPU: 0.0 MBCreating datasets...
Loading training dataset...
Reading protocol file: /kaggle/input/asvspoof-dataset-2019/LA/ASVspoof2019_LA_cm_protocols/ASVspoof2019.LA.cm.train.trn.txt


Loading training protocol: 100%|██████████| 25380/25380 [00:00<00:00, 1475357.03it/s]


Dataset loaded: 25380 samples (2580 bonafide, 22800 spoof)
Subsampling training data for faster NAS...
Subsampled to 5000 samples
Loading validation dataset...
Reading protocol file: /kaggle/input/asvspoof-dataset-2019/LA/ASVspoof2019_LA_cm_protocols/ASVspoof2019.LA.cm.dev.trl.txt


Loading evaluation protocol: 100%|██████████| 24844/24844 [00:00<00:00, 1630494.74it/s]


Dataset loaded: 24844 samples (2548 bonafide, 22296 spoof)
Loading evaluation dataset...
Reading protocol file: /kaggle/input/asvspoof-dataset-2019/LA/ASVspoof2019_LA_cm_protocols/ASVspoof2019.LA.cm.eval.trl.txt


Loading evaluation protocol: 100%|██████████| 71237/71237 [00:00<00:00, 1495518.46it/s]


Dataset loaded: 71237 samples (7355 bonafide, 63882 spoof)
Training dataset size: 5000 samples
Validation dataset size: 24844 samples
Evaluation dataset size: 71237 samples

Neural Architecture Search Configuration:
  input_channels: 60
  num_cells: 3
  num_nodes: 4
  num_ops: 10
  epochs: 30
  ppo_updates: 5
  project_name: ASVspoof2019-NAS

Results will be saved to /kaggle/working/results_ASVspoof2019_NAS_1745021052_hybrid

Starting Neural Architecture Search using HYBRID method...
Initializing PPO controller with state_dim=1, action_dim=10
[00:03:13] CPU: 83.7% | RAM: 881.6 MB | GPU: 15.9 MB

Hybrid Search Progress:   0%|          | 0/30 [00:00<?, ?it/s]

[00:04:13] CPU: 101.0% | RAM: 1431.5 MB | GPU: 1156.5 MB