# Multi-Output MLP for Side-Channel Analysis

This notebook implements the MLP_MO (Multi-Output Multilayer Perceptron) architecture for non-profiled side-channel attack on AES-128.

**Goal:** Recover the AES-128 secret key from power traces by training a single neural network to predict all 256 possible key bytes simultaneously.

**Memory Constraint:** Max 25GB RAM - uses lazy loading from HDF5 file.


## Block 1: Setup & Configuration


In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import h5py
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Constants
BATCH_SIZE = 1000
LEARNING_RATE = 0.001
NUM_EPOCHS = 30
NUM_BRANCHES = 256  # One branch for each key hypothesis (0-255)
TRACE_LENGTH = 700
DATASET_PATH = Path('dataset/ascadv2-extracted.h5')

print(f"Batch size: {BATCH_SIZE}")
print(f"Learning rate: {LEARNING_RATE}")
print(f"Number of epochs: {NUM_EPOCHS}")


Using device: cuda
Batch size: 1000
Learning rate: 0.001
Number of epochs: 30


## Block 2: Verify H5 File Structure

First, let's inspect the HDF5 file to understand its structure.


In [2]:
# Verify H5 file structure
if DATASET_PATH.exists():
    with h5py.File(DATASET_PATH, 'r') as f:
        print("Top-level keys:", list(f.keys()))
        if 'traces' in f:
            print(f"Traces shape: {f['traces'].shape}")
            print(f"Traces dtype: {f['traces'].dtype}")
        if 'metadata' in f:
            print("Metadata keys:", list(f['metadata'].keys()))
            if 'plaintext' in f['metadata']:
                print(f"Plaintext shape: {f['metadata']['plaintext'].shape}")
                print(f"Plaintext dtype: {f['metadata']['plaintext'].dtype}")
            if 'key' in f['metadata']:
                print(f"Key shape: {f['metadata']['key'].shape}")
                print(f"Key value: {f['metadata']['key'][:]}")
        # Check for alternative key names
        if 'inputs' in f:
            print("Found 'inputs' key")
            print("Inputs keys:", list(f['inputs'].keys()) if hasattr(f['inputs'], 'keys') else "Not a group")
else:
    print(f"Warning: Dataset file not found at {DATASET_PATH}")
    print("Please ensure the HDF5 file is placed in the dataset/ directory")


Top-level keys: ['Attack_traces', 'Profiling_traces']


## Block 3: Memory-Efficient Data Loader

Implement a custom Dataset class that uses lazy loading to avoid loading the entire dataset into memory.


In [3]:
class ASCADDataset(Dataset):
    """
    Memory-efficient dataset for ASCAD power traces.
    Uses lazy loading - only reads data from disk when requested.
    """
    def __init__(self, h5_path, trace_key='traces', plaintext_key='metadata/plaintext'):
        """
        Args:
            h5_path: Path to the HDF5 file
            trace_key: Key for traces in HDF5 file (default: 'traces')
            plaintext_key: Key for plaintexts in HDF5 file (default: 'metadata/plaintext')
        """
        self.h5_path = h5_path
        self.trace_key = trace_key
        self.plaintext_key = plaintext_key
        
        # Open file to get length (but keep it open for lazy loading)
        with h5py.File(h5_path, 'r') as f:
            # Try different possible key structures
            if trace_key in f:
                self.length = len(f[trace_key])
            elif 'inputs' in f and 'traces' in f['inputs']:
                self.trace_key = 'inputs/traces'
                self.length = len(f['inputs']['traces'])
            else:
                raise KeyError(f"Could not find traces. Available keys: {list(f.keys())}")
            
            # Get plaintext key
            if plaintext_key in f:
                pass  # Key exists
            elif 'inputs' in f and 'plaintext' in f['inputs']:
                self.plaintext_key = 'inputs/plaintext'
            elif 'metadata' in f and 'plaintext' in f['metadata']:
                self.plaintext_key = 'metadata/plaintext'
            else:
                # Try to find any plaintext-like key
                with h5py.File(h5_path, 'r') as check_f:
                    def find_key(name, obj):
                        if 'plaintext' in name.lower():
                            print(f"Found potential plaintext key: {name}")
                    check_f.visititems(find_key)
                raise KeyError(f"Could not find plaintext. Checked: {plaintext_key}")
    
    def __len__(self):
        return self.length
    
    def __getitem__(self, idx):
        """
        Read a single trace and plaintext from disk.
        Returns:
            trace: torch.Tensor of shape (700,) dtype float32
            plaintext: torch.Tensor scalar dtype long
        """
        # Open file for each access (h5py handles this efficiently)
        with h5py.File(self.h5_path, 'r') as f:
            # Read trace - shape (700,)
            trace = f[self.trace_key][idx]
            # Read plaintext - scalar
            plaintext = f[self.plaintext_key][idx]
        
        # Convert to torch tensors with appropriate dtypes
        trace_tensor = torch.from_numpy(trace).float()  # (700,) float32
        plaintext_tensor = torch.tensor(plaintext, dtype=torch.long)  # scalar long
        
        return trace_tensor, plaintext_tensor

# Test the dataset
if DATASET_PATH.exists():
    try:
        dataset = ASCADDataset(DATASET_PATH)
        print(f"Dataset length: {len(dataset)}")
        
        # Test loading a single sample
        trace, plaintext = dataset[0]
        print(f"Sample trace shape: {trace.shape}, dtype: {trace.dtype}")
        print(f"Sample plaintext: {plaintext.item()}, dtype: {plaintext.dtype}")
        
        # Create DataLoader
        dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
        
        # Test a batch
        traces_batch, plaintexts_batch = next(iter(dataloader))
        print(f"\nBatch traces shape: {traces_batch.shape}")
        print(f"Batch plaintexts shape: {plaintexts_batch.shape}")
        print(f"Batch traces dtype: {traces_batch.dtype}")
        print(f"Batch plaintexts dtype: {plaintexts_batch.dtype}")
    except Exception as e:
        print(f"Error creating dataset: {e}")
        print("You may need to adjust the key names based on your HDF5 file structure")
else:
    print("Dataset file not found. Please add the HDF5 file to continue.")


Error creating dataset: "Could not find traces. Available keys: ['Attack_traces', 'Profiling_traces']"
You may need to adjust the key names based on your HDF5 file structure


## Block 4: AES S-box Definition

Define the AES S-box lookup table for label generation.


In [4]:
# AES S-box (Substitution Box) - 256 byte lookup table
AES_SBOX = np.array([
    0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76,
    0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0,
    0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15,
    0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75,
    0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84,
    0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf,
    0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8,
    0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2,
    0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73,
    0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb,
    0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79,
    0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08,
    0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a,
    0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e,
    0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf,
    0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16
], dtype=np.uint8)

print(f"AES S-box shape: {AES_SBOX.shape}")
print(f"AES S-box dtype: {AES_SBOX.dtype}")
print(f"First 16 values: {AES_SBOX[:16]}")


AES S-box shape: (256,)
AES S-box dtype: uint8
First 16 values: [ 99 124 119 123 242 107 111 197  48   1 103  43 254 215 171 118]


## Block 5: Label Generation Function

Generate labels for all 256 key hypotheses on-the-fly during training.


In [5]:
def get_all_labels(plaintexts, sbox):
    """
    Generate labels for all 256 key hypotheses.
    
    For each key hypothesis k (0-255), compute:
    label = LSB(Sbox[plaintext XOR k])
    
    Args:
        plaintexts: (batch_size,) tensor of uint8 plaintext bytes
        sbox: (256,) numpy array containing AES S-box lookup table
    
    Returns:
        labels: (batch_size, 256) tensor of labels (0 or 1) - LSB of Sbox(plaintext XOR k)
    """
    batch_size = plaintexts.shape[0]
    labels = torch.zeros(batch_size, 256, dtype=torch.long, device=plaintexts.device)
    
    # Convert plaintexts to numpy for efficient indexing
    plaintexts_np = plaintexts.cpu().numpy() if isinstance(plaintexts, torch.Tensor) else plaintexts
    
    # For each key hypothesis k
    for k in range(256):
        # Compute: intermediate = Sbox[plaintext XOR k]
        intermediate = sbox[plaintexts_np ^ k]
        # Extract LSB: label = intermediate & 1
        labels[:, k] = torch.tensor(intermediate & 1, dtype=torch.long, device=plaintexts.device)
    
    return labels

# Test label generation
if DATASET_PATH.exists():
    try:
        test_plaintexts = torch.tensor([0x00, 0x01, 0xFF], dtype=torch.long)
        test_labels = get_all_labels(test_plaintexts, AES_SBOX)
        print(f"Test plaintexts: {[hex(p.item()) for p in test_plaintexts]}")
        print(f"Test labels shape: {test_labels.shape}")
        print(f"Labels for first plaintext (0x00), first 8 key hypotheses: {test_labels[0, :8].tolist()}")
        print(f"Labels for first plaintext (0x00), key hypothesis 0: {test_labels[0, 0].item()}")
        print(f"Expected: LSB(Sbox[0x00 XOR 0x00]) = LSB(Sbox[0x00]) = LSB(0x63) = 1")
        print(f"Got: {test_labels[0, 0].item()}")
    except Exception as e:
        print(f"Error testing label generation: {e}")


Test plaintexts: ['0x0', '0x1', '0xff']
Test labels shape: torch.Size([3, 256])
Labels for first plaintext (0x00), first 8 key hypotheses: [1, 0, 1, 1, 0, 1, 1, 1]
Labels for first plaintext (0x00), key hypothesis 0: 1
Expected: LSB(Sbox[0x00 XOR 0x00]) = LSB(Sbox[0x00]) = LSB(0x63) = 1
Got: 1


## Block 6: Multi-Output MLP Model Architecture

Implement the model with a shared layer and 256 independent branches.


In [6]:
class MultiOutputMLP(nn.Module):
    """
    Multi-Output MLP for Side-Channel Analysis.
    
    Architecture:
    - Shared Layer: Linear(700, 200) → ReLU
    - 256 Branches: Each branch takes shared output and applies:
      Linear(200, 20) → ReLU → Linear(20, 10) → ReLU → Linear(10, 2)
    """
    def __init__(self, input_dim=700, shared_hidden=200, branch_hidden1=20, branch_hidden2=10, num_branches=256):
        super(MultiOutputMLP, self).__init__()
        
        self.input_dim = input_dim
        self.num_branches = num_branches
        
        # Shared layer: processes input traces
        self.shared_layer = nn.Sequential(
            nn.Linear(input_dim, shared_hidden),
            nn.ReLU()
        )
        
        # 256 independent branches (one for each key hypothesis)
        self.branches = nn.ModuleList([
            nn.Sequential(
                nn.Linear(shared_hidden, branch_hidden1),  # Branch hidden layer 1
                nn.ReLU(),
                nn.Linear(branch_hidden1, branch_hidden2),  # Branch hidden layer 2
                nn.ReLU(),
                nn.Linear(branch_hidden2, 2)  # Output 2 classes (LSB 0 or 1)
            ) for _ in range(num_branches)
        ])
        
        # Initialize weights using He uniform initialization
        self._initialize_weights()
    
    def _initialize_weights(self):
        """Initialize weights using He uniform initialization."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
    
    def forward(self, x):
        """
        Forward pass.
        
        Args:
            x: Input tensor of shape (batch_size, 700)
        
        Returns:
            output: Tensor of shape (batch_size, 256, 2)
        """
        batch_size = x.shape[0]
        
        # Shared layer: (batch_size, 700) → (batch_size, 200)
        shared_output = self.shared_layer(x)
        
        # Process through each branch and stack outputs
        branch_outputs = []
        for branch in self.branches:
            # Each branch: (batch_size, 200) → (batch_size, 2)
            branch_out = branch(shared_output)
            branch_outputs.append(branch_out)
        
        # Stack: (batch_size, 256, 2)
        output = torch.stack(branch_outputs, dim=1)
        
        return output

# Test model
model = MultiOutputMLP(input_dim=TRACE_LENGTH, num_branches=NUM_BRANCHES).to(device)
print(f"Model created on device: {device}")

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# Test forward pass
if DATASET_PATH.exists():
    try:
        test_input = torch.randn(BATCH_SIZE, TRACE_LENGTH).to(device)
        test_output = model(test_input)
        print(f"\nTest input shape: {test_input.shape}")
        print(f"Test output shape: {test_output.shape}")
        print(f"Expected output shape: ({BATCH_SIZE}, {NUM_BRANCHES}, 2)")
        assert test_output.shape == (BATCH_SIZE, NUM_BRANCHES, 2), f"Output shape mismatch: {test_output.shape}"
        print("✓ Model forward pass test passed!")
    except Exception as e:
        print(f"Error testing model: {e}")
else:
    print("Dataset not found - skipping forward pass test")


Model created on device: cuda
Total parameters: 1,228,712
Trainable parameters: 1,228,712

Test input shape: torch.Size([1000, 700])
Test output shape: torch.Size([1000, 256, 2])
Expected output shape: (1000, 256, 2)
✓ Model forward pass test passed!


## Block 7: Training Loop

Train the model with multi-loss computation across all 256 branches.


In [7]:
def train_model(model, dataloader, num_epochs, device, correct_key=None):
    """
    Training loop for Multi-Output MLP.
    
    Args:
        model: MultiOutputMLP model
        dataloader: DataLoader for training data
        num_epochs: Number of training epochs
        device: torch device
        correct_key: Correct key byte value (0-255) for accuracy tracking. If None, will try to load from dataset.
    
    Returns:
        correct_key_acc_history: List of correct key accuracies per epoch
        wrong_key_acc_history: List of best wrong key accuracies per epoch
    """
    # Initialize optimizer and loss function
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    criterion = nn.CrossEntropyLoss()
    
    # Try to get correct key from dataset if not provided
    if correct_key is None and DATASET_PATH.exists():
        try:
            with h5py.File(DATASET_PATH, 'r') as f:
                if 'metadata' in f and 'key' in f['metadata']:
                    key_array = f['metadata']['key'][:]
                    # Assuming we're attacking the first key byte
                    correct_key = int(key_array[0]) if len(key_array) > 0 else None
                    print(f"Found correct key in dataset: {key_array}")
                    print(f"Using first byte as correct_key: {correct_key}")
                else:
                    print("Warning: Could not find correct key in dataset. Accuracy tracking will be limited.")
        except Exception as e:
            print(f"Warning: Could not load correct key from dataset: {e}")
    
    # History for plotting
    correct_key_acc_history = []
    wrong_key_acc_history = []
    loss_history = []
    
    model.train()
    
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        epoch_correct_key_correct = 0
        epoch_correct_key_total = 0
        epoch_wrong_key_correct = 0
        epoch_wrong_key_total = 0
        
        num_batches = 0
        
        for batch_idx, (traces, plaintexts) in enumerate(dataloader):
            # Move to device
            traces = traces.to(device)  # (batch_size, 700)
            plaintexts = plaintexts.to(device)  # (batch_size,)
            
            # Get actual batch size (may vary on last batch)
            actual_batch_size = traces.shape[0]
            
            # Forward pass
            output = model(traces)  # (batch_size, 256, 2)
            
            # Generate labels for all 256 key hypotheses
            labels = get_all_labels(plaintexts, AES_SBOX)  # (batch_size, 256)
            
            # Compute loss: sum of CrossEntropyLoss across all 256 branches
            total_loss = 0
            for k in range(256):
                branch_output = output[:, k, :]  # (batch_size, 2)
                branch_labels = labels[:, k]  # (batch_size,)
                total_loss += criterion(branch_output, branch_labels)
            
            # Backward pass
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
            
            # Track loss
            epoch_loss += total_loss.item()
            
            # Calculate accuracy
            # Get predictions: argmax over the 2 classes for each branch
            preds = output.argmax(dim=2)  # (batch_size, 256)
            
            # Calculate accuracy for correct key (if known)
            if correct_key is not None:
                correct_key_preds = preds[:, correct_key]  # (batch_size,)
                correct_key_labels = labels[:, correct_key]  # (batch_size,)
                correct_key_correct = (correct_key_preds == correct_key_labels).sum().item()
                epoch_correct_key_correct += correct_key_correct
                epoch_correct_key_total += actual_batch_size
            
            # Calculate accuracy for wrong keys (all keys except correct one)
            if correct_key is not None:
                wrong_key_mask = torch.ones(256, dtype=torch.bool, device=device)
                wrong_key_mask[correct_key] = False
                wrong_key_preds = preds[:, wrong_key_mask]  # (batch_size, 255)
                wrong_key_labels = labels[:, wrong_key_mask]  # (batch_size, 255)
                wrong_key_correct = (wrong_key_preds == wrong_key_labels).sum().item()
                epoch_wrong_key_correct += wrong_key_correct
                epoch_wrong_key_total += wrong_key_preds.numel()
            else:
                # If correct key unknown, track all keys
                all_correct = (preds == labels).sum().item()
                epoch_wrong_key_correct += all_correct
                epoch_wrong_key_total += preds.numel()
            
            num_batches += 1
            
            # Print progress every 10 batches
            if (batch_idx + 1) % 10 == 0:
                batch_loss = total_loss.item()
                if correct_key is not None:
                    batch_correct_acc = correct_key_correct / actual_batch_size
                    batch_wrong_acc = wrong_key_correct / (255 * actual_batch_size)
                    print(f"  Batch {batch_idx + 1}/{len(dataloader)}: Loss={batch_loss:.4f}, "
                          f"Correct Key Acc={batch_correct_acc:.4f}, Wrong Key Acc={batch_wrong_acc:.4f}")
        
        # Calculate epoch averages
        avg_loss = epoch_loss / num_batches
        loss_history.append(avg_loss)
        
        if correct_key is not None:
            correct_key_acc = epoch_correct_key_correct / epoch_correct_key_total if epoch_correct_key_total > 0 else 0.0
            wrong_key_acc = epoch_wrong_key_correct / epoch_wrong_key_total if epoch_wrong_key_total > 0 else 0.0
            correct_key_acc_history.append(correct_key_acc)
            wrong_key_acc_history.append(wrong_key_acc)
            
            print(f"Epoch {epoch + 1}/{num_epochs}: "
                  f"Loss={avg_loss:.4f}, "
                  f"Correct Key Acc={correct_key_acc:.4f}, "
                  f"Wrong Key Acc={wrong_key_acc:.4f}")
        else:
            all_acc = epoch_wrong_key_correct / epoch_wrong_key_total if epoch_wrong_key_total > 0 else 0.0
            wrong_key_acc_history.append(all_acc)
            print(f"Epoch {epoch + 1}/{num_epochs}: Loss={avg_loss:.4f}, All Key Acc={all_acc:.4f}")
    
    return correct_key_acc_history, wrong_key_acc_history, loss_history

print("Training function defined. Ready to train!")


Training function defined. Ready to train!


In [8]:
# Initialize model
model = MultiOutputMLP(input_dim=TRACE_LENGTH, num_branches=NUM_BRANCHES).to(device)

# Create dataset and dataloader
if DATASET_PATH.exists():
    dataset = ASCADDataset(DATASET_PATH)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
    
    # Get correct key for tracking (optional)
    correct_key = None
    try:
        with h5py.File(DATASET_PATH, 'r') as f:
            if 'metadata' in f and 'key' in f['metadata']:
                key_array = f['metadata']['key'][:]
                correct_key = int(key_array[0]) if len(key_array) > 0 else None
                print(f"Correct key (first byte): {correct_key}")
    except:
        pass
    
    # Train model
    print("Starting training...")
    correct_key_acc_history, wrong_key_acc_history, loss_history = train_model(
        model, dataloader, NUM_EPOCHS, device, correct_key=correct_key
    )
    
    print("\nTraining completed!")
else:
    print("Dataset file not found. Please add the HDF5 file to the dataset/ directory.")
    correct_key_acc_history = []
    wrong_key_acc_history = []
    loss_history = []


KeyError: "Could not find traces. Available keys: ['Attack_traces', 'Profiling_traces']"

## Block 9: Visualization

Plot the attack success rate: correct key accuracy vs wrong key accuracy.


In [None]:
# Plot training curves
if len(correct_key_acc_history) > 0 and len(wrong_key_acc_history) > 0:
    plt.figure(figsize=(12, 5))
    
    # Plot 1: Accuracy comparison
    plt.subplot(1, 2, 1)
    epochs = range(1, len(correct_key_acc_history) + 1)
    plt.plot(epochs, correct_key_acc_history, 'r-', label='Correct Key Accuracy', linewidth=2)
    plt.plot(epochs, wrong_key_acc_history, 'b-', label='Wrong Key Accuracy', linewidth=2)
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.title('Attack Success Rate: Correct Key vs Wrong Keys')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.ylim([0.0, 1.0])
    
    # Add horizontal line at 0.5 (random guessing)
    plt.axhline(y=0.5, color='gray', linestyle='--', alpha=0.5, label='Random (50%)')
    
    # Plot 2: Loss
    plt.subplot(1, 2, 2)
    plt.plot(epochs, loss_history, 'g-', label='Training Loss', linewidth=2)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Loss')
    plt.legend()
    plt.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    # Print final statistics
    print(f"\nFinal Statistics:")
    print(f"Correct Key Accuracy: {correct_key_acc_history[-1]:.4f} ({correct_key_acc_history[-1]*100:.2f}%)")
    print(f"Wrong Key Accuracy: {wrong_key_acc_history[-1]:.4f} ({wrong_key_acc_history[-1]*100:.2f}%)")
    print(f"Separation: {correct_key_acc_history[-1] - wrong_key_acc_history[-1]:.4f}")
    
    if correct_key_acc_history[-1] > 0.63 and wrong_key_acc_history[-1] < 0.55:
        print("✓ Attack successful! Correct key accuracy significantly higher than wrong keys.")
    elif correct_key_acc_history[-1] > wrong_key_acc_history[-1] + 0.1:
        print("✓ Attack showing progress! Correct key accuracy is separating from wrong keys.")
    else:
        print("⚠ Attack may need more training or hyperparameter tuning.")
else:
    print("No training history available. Please run training first.")


No training history available. Please run training first.
