# Neural Cryptography

## **Competitive Training with Proper Balance**

This notebook implements a rebalanced system with proper competitive dynamics:

### The Rebalancing Fixes:
1. **Stronger Bob**: Larger architecture (512 hidden, 8 layers) for >0.8 accuracy
2. **Weaker Eve**: Smaller architecture (24 hidden, 1 layer) to hover around 0.5
3. **Alice Metrics**: Track Alice's entropy regulation and key dependency performance
4. **Enhanced Shocks**: More frequent, varied shock events for better dynamics

**EXPECTED RESULT: Eve → 0.5, Bob → >0.8, Alice optimized**

The current setup does not train properly yet: Eve still outperforms Bob. The three networks need to be rebalanced against each other for better learning.

## 1. Setup and Dependencies

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torch.optim.lr_scheduler import StepLR, CosineAnnealingLR
from collections import defaultdict

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import random
import os
import json
from datetime import datetime
import math
from scipy import stats
import warnings
warnings.filterwarnings('ignore')  # keep notebook output free of library warnings

# Seed every RNG the notebook draws from (Python, NumPy, PyTorch) for repeatability.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

_cuda_available = torch.cuda.is_available()
device = torch.device('cuda' if _cuda_available else 'cpu')
print(f"🎯 Using device: {device}")
if _cuda_available:
    print(f"🚀 GPU: {torch.cuda.get_device_name(0)}")
    # Let cuDNN autotune kernels; note this can trade away bitwise reproducibility.
    torch.backends.cudnn.benchmark = True

print("\nREBALANCED CRYPTO System Loaded!")

## 2. Rebalanced Training Configuration

In [None]:
class RebalancedConfig:
    """Hyperparameters for the Alice / Bob / Eve neural-cryptography experiment.

    Everything is a plain class attribute; the rest of the notebook reads the
    values through the module-level ``config`` instance created below.

    NOTE(review): ENTROPY_WEIGHT, NOISE_STRENGTH and EVE_TRAINING_RATIO are not
    read anywhere else in this notebook. The training loop takes Eve's update
    period from its own curriculum, not from EVE_TRAINING_RATIO.
    """

    # ---- Problem size ----------------------------------------------------
    MESSAGE_LENGTH = 32  # plaintext (and ciphertext) bits
    KEY_LENGTH = 32      # shared-key bits
    BATCH_SIZE = 128

    # ---- Architectures: Bob is the largest network, Eve the smallest --------
    ALICE_HIDDEN_SIZE = 256
    ALICE_NUM_LAYERS = 6
    BOB_HIDDEN_SIZE = 512
    BOB_NUM_LAYERS = 8
    EVE_HIDDEN_SIZE = 24
    EVE_NUM_LAYERS = 1   # NOTE: despite being far shallower, Eve still outperforms Bob
    DROPOUT = 0.4        # shared by all three networks

    # ---- Optimisation -----------------------------------------------------
    LEARNING_RATE_ALICE = 5e-3
    LEARNING_RATE_BOB = 5e-3
    LEARNING_RATE_EVE = 5e-4
    NUM_EPOCHS = 2000        # may need to be raised substantially
    EVE_TRAINING_RATIO = 5   # intent: train Eve on only 1 of every N epochs

    # ---- Loss weights (currently very aggressive) ------------------------
    RECONSTRUCTION_WEIGHT = 3.0
    ENTROPY_WEIGHT = 10.0
    ADVERSARIAL_WEIGHT = 2.0
    NOISE_STRENGTH = 0.3

    # ---- Curriculum: entropy target switches at each milestone epoch ------
    TARGET_ENTROPY_INITIAL = 0.85
    TARGET_ENTROPY_MID = 0.92
    TARGET_ENTROPY_LATE = 0.98
    CURRICULUM_MILESTONES = [100, 300]

    # ---- Crypto regularisation -------------------------------------------
    KEY_DEPENDENCY_MARGIN = 0.35   # raising this may help
    KEY_DEPENDENCY_WEIGHT = 3.0
    WRONG_KEY_CONFUSION_WEIGHT = 2.0
    ENTROPY_REG_WEIGHT = 5.0


config = RebalancedConfig()
print("\n".join([
    " REBALANCED TRAINING Configuration:",
    f"  Loss Weights: ENTROPY_REG={config.ENTROPY_REG_WEIGHT}, Adv={config.ADVERSARIAL_WEIGHT}, KeyDep={config.KEY_DEPENDENCY_WEIGHT}",
    f"  Eve Config: Size={config.EVE_HIDDEN_SIZE}, LR={config.LEARNING_RATE_EVE}, Ratio={config.EVE_TRAINING_RATIO} (WEAKENED)",
    f"  Bob Config: Size={config.BOB_HIDDEN_SIZE}, Layers={config.BOB_NUM_LAYERS} (STRENGTHENED)",
    f"  Key Sensitivity Target: {config.KEY_DEPENDENCY_MARGIN} (higher avalanche effect)",
]))

## 3. Rebalanced Network Architectures

In [None]:
class StochasticAlice(nn.Module):
    """Encryptor network: (plaintext, key) -> per-bit ciphertext probabilities.

    Layout: Linear(message_length + key_length -> hidden) + GELU, then
    ``num_layers`` blocks of Linear(hidden -> hidden) -> GELU -> Dropout, then
    Linear(hidden -> message_length). The final logits are squashed into
    [0, 1]; turning them into ciphertext bits is left to the caller.
    """

    def __init__(self, message_length, key_length, hidden_size, num_layers, dropout):
        super().__init__()
        # Modules are created head -> body -> tail so parameter initialisation
        # consumes the RNG in layer order.
        head = [nn.Linear(message_length + key_length, hidden_size), nn.GELU()]
        body = [
            module
            for _ in range(num_layers)
            for module in (nn.Linear(hidden_size, hidden_size), nn.GELU(), nn.Dropout(dropout))
        ]
        tail = [nn.Linear(hidden_size, message_length)]
        self.network = nn.Sequential(*head, *body, *tail)

    def forward(self, plaintext, key):
        """Return probabilities for each ciphertext bit.

        Assumes 2-D ``(batch, bits)`` inputs; plaintext and key are joined on dim 1.
        """
        logits = self.network(torch.cat((plaintext, key), dim=1))
        # 0.5 + 0.5 * tanh(x) maps logits into [0, 1] (equal to sigmoid(2x)).
        return 0.5 * torch.tanh(logits) + 0.5

class RobustBob(nn.Module):
    """Decryptor network: (ciphertext, key) -> per-bit plaintext probabilities.

    Layout: Linear(message_length + key_length -> hidden) + GELU, then
    ``num_layers`` blocks of Linear(hidden -> hidden) -> GELU -> Dropout, then
    Linear(hidden -> message_length) followed by a sigmoid.
    """

    def __init__(self, message_length, key_length, hidden_size, num_layers, dropout):
        super().__init__()
        stack = [nn.Linear(message_length + key_length, hidden_size), nn.GELU()]
        for _ in range(num_layers):
            stack.append(nn.Linear(hidden_size, hidden_size))
            stack.append(nn.GELU())
            stack.append(nn.Dropout(dropout))
        stack.append(nn.Linear(hidden_size, message_length))
        self.network = nn.Sequential(*stack)

    def forward(self, ciphertext, key):
        """Return the probability that each plaintext bit is 1.

        Assumes 2-D ``(batch, bits)`` inputs; ciphertext and key are joined on dim 1.
        """
        features = torch.cat((ciphertext, key), dim=1)
        return torch.sigmoid(self.network(features))

class WeakEve(nn.Module):
    """Eavesdropper network: maps a ciphertext to a single raw logit.

    It is intended to be weak, so that its distinguishing accuracy hovers
    around 0.5. It is not there yet: see the note in the config.
    Layout: Linear(message_length -> hidden) + ReLU, ``num_layers`` blocks of
    Linear(hidden -> hidden) + ReLU, then Dropout and Linear(hidden -> 1).
    """

    def __init__(self, message_length, hidden_size, num_layers, dropout):
        super().__init__()
        blocks = [nn.Linear(message_length, hidden_size), nn.ReLU()]
        for _ in range(num_layers):
            blocks += [nn.Linear(hidden_size, hidden_size), nn.ReLU()]
        blocks += [nn.Dropout(dropout), nn.Linear(hidden_size, 1)]
        self.network = nn.Sequential(*blocks)

        # Re-initialise every Linear weight with a very small Xavier-uniform
        # gain (0.05); biases keep PyTorch's default initialisation.
        for layer in self.network:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight, gain=0.05)

    def forward(self, ciphertext):
        """Return one unnormalised score (logit) per input row."""
        return self.network(ciphertext)

print("Rebalanced architectures defined: Strong Bob, Weak Eve!")

## 4. Crypto-Regularized Loss Functions

In [None]:
def reconstruction_loss(decrypted, original):
    """Mean absolute error between decrypted outputs and the true plaintext.

    For hard 0/1 predictions this equals the bit error rate, so
    ``1 - reconstruction_loss(hard_preds, original)`` is a bit accuracy.
    """
    per_element_mean = F.l1_loss(decrypted, original, reduction='mean')
    return per_element_mean

def adversarial_loss(eve_scores, is_real):
    """Binary cross-entropy between Eve's raw logits and a constant label.

    ``is_real`` truthy -> every target is 1; otherwise every target is 0.
    """
    label = 1.0 if is_real else 0.0
    targets = torch.full_like(eve_scores, label)
    return F.binary_cross_entropy_with_logits(eve_scores, targets)

def bernoulli_entropy_from_probs(probs):
    """Element-wise Bernoulli entropy in bits, H(p) = -p*log2(p) - (1-p)*log2(1-p).

    Probabilities are clamped to [1e-6, 1 - 1e-6] so log2 never sees 0.
    """
    p = probs.clamp(1e-6, 1 - 1e-6)
    q = 1 - p
    return -(p * torch.log2(p) + q * torch.log2(q))

def entropy_reg_loss(probs, target_entropy):
    """Penalise deviation of the *marginal* per-bit entropy from a target.

    For each bit position j, the marginal Bernoulli parameter is the batch mean
    p̄_j = mean_i probs[i, j]. The loss is the mean squared error between
    H(p̄_j) and ``target_entropy``.

    Why the marginal and not every sample's own probability: pushing each
    H(probs[i, j]) towards ~1 forces every probability towards 0.5. That
    injects Bernoulli noise and caps what any decoder can recover. At a
    per-sample entropy of 0.98 (p ≈ 0.5 ± 0.085), the best bit accuracy is
    roughly 0.58, and |p(k) - p(k')| can be at most about 0.17. Regulating the
    marginal lets Alice be (nearly) deterministic per sample, so Bob can
    decrypt, while each ciphertext bit stays balanced across messages and keys.

    Args:
        probs: probabilities in [0, 1]; dimension 0 is averaged over
            (assumed to be the batch dimension — confirm for new callers).
        target_entropy: desired entropy in bits.

    Returns:
        Scalar loss tensor.
    """
    marginal = probs.mean(dim=0, keepdim=True)
    H = bernoulli_entropy_from_probs(marginal)
    return F.mse_loss(H, torch.full_like(H, target_entropy))

def key_dependency_loss(alice, messages, keys, margin, device):
    """Hinge loss requiring Alice's output to change when the key changes.

    Encrypts the same messages with the true keys and with keys shuffled
    across the batch, then returns ``relu(margin - mean|p_orig - p_shuf|)``.
    The result is 0 once the mean absolute difference reaches ``margin``.

    Both encryptions run with ``alice`` temporarily in eval mode. In training
    mode every forward pass samples a fresh dropout mask, so two calls differ
    even for identical keys, and that noise alone could satisfy the margin
    without any real key dependence. Eval mode only disables dropout; gradients
    still flow. The module's previous training/eval mode is restored afterwards.

    Args:
        alice: callable module ``alice(messages, keys) -> probabilities``.
        messages: plaintext batch; dimension 0 is the batch.
        keys: key batch aligned with ``messages``.
        margin: required mean absolute difference.
        device: device on which the shuffle permutation is created.

    Returns:
        Scalar loss tensor.
    """
    perm = torch.randperm(messages.size(0), device=device)
    keys_shuf = keys[perm]

    was_training = alice.training
    alice.eval()
    try:
        probs_orig = alice(messages, keys)
        probs_shuf = alice(messages, keys_shuf)
    finally:
        alice.train(was_training)

    dist = torch.mean(torch.abs(probs_orig - probs_shuf))
    return F.relu(margin - dist)

def wrong_key_confusion_loss(bob, ciphertext, keys, device):
    """Push Bob's output towards 0.5 when he decrypts with a mismatched key.

    Keys are shuffled across the batch (a random permutation, so a row may
    occasionally keep its own key). The loss is the MSE between Bob's outputs
    and 0.5.
    """
    shuffled_keys = keys[torch.randperm(ciphertext.size(0), device=device)]
    guesses = bob(ciphertext, shuffled_keys)
    return F.mse_loss(guesses, guesses.new_full(guesses.shape, 0.5))

def calculate_entropy(data_tensor):
    """Shannon entropy (bits) of the overall fraction of ones in ``data_tensor``.

    All elements are treated as draws of a single Bernoulli variable, so this
    only measures global bias: any tensor with half ones scores 1.0, however
    structured it is. Returns 0.0 when the mean is exactly 0 or 1.
    """
    ones = data_tensor.mean().item()
    zeros = 1 - ones
    if 0 in (ones, zeros):
        return 0.0
    return -(zeros * math.log2(zeros) + ones * math.log2(ones))

def generate_random_data(batch_size, message_length, key_length, device):
    """Draw uniform random 0/1 float tensors for messages and keys.

    Returns:
        ``(messages, keys)`` of shapes ``(batch_size, message_length)`` and
        ``(batch_size, key_length)``; messages are drawn first.
    """
    def _random_bits(width):
        return torch.randint(0, 2, size=(batch_size, width), device=device).float()

    messages = _random_bits(message_length)
    keys = _random_bits(key_length)
    return messages, keys

print("Crypto-regularized loss functions ready!")

## 5. Rebalanced Training Loop with Alice Metrics

In [None]:
# Build the three networks from the shared config and move them to `device`.
alice = StochasticAlice(
    config.MESSAGE_LENGTH, config.KEY_LENGTH, config.ALICE_HIDDEN_SIZE, 
    config.ALICE_NUM_LAYERS, config.DROPOUT
).to(device)

bob = RobustBob(
    config.MESSAGE_LENGTH, config.KEY_LENGTH, config.BOB_HIDDEN_SIZE, 
    config.BOB_NUM_LAYERS, config.DROPOUT
).to(device)

eve = WeakEve(
    config.MESSAGE_LENGTH, config.EVE_HIDDEN_SIZE, 
    config.EVE_NUM_LAYERS, config.DROPOUT
).to(device)

# Alice and Bob are optimised together (one backward pass over their joint
# loss), but each gets its own parameter group so config.LEARNING_RATE_BOB
# actually takes effect. A single group would silently apply
# LEARNING_RATE_ALICE to Bob as well.
alice_bob_optimizer = optim.Adam(
    [
        {'params': alice.parameters(), 'lr': config.LEARNING_RATE_ALICE},
        {'params': bob.parameters(), 'lr': config.LEARNING_RATE_BOB},
    ],
    weight_decay=1e-4,
)
eve_optimizer = optim.Adam(eve.parameters(), lr=config.LEARNING_RATE_EVE, weight_decay=1e-4)

# Cosine annealing with warm restarts: first cycle T_0 scheduler steps, each
# later cycle twice as long (T_mult=2). Applies per parameter group.
alice_bob_scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(alice_bob_optimizer, T_0=200, T_mult=2)
eve_scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(eve_optimizer, T_0=150, T_mult=2)

# Metric name -> list of logged values; appended to by train_rebalanced.
history = defaultdict(list)

def train_rebalanced(num_epochs):
    """Train Alice/Bob against Eve for ``num_epochs`` iterations and log metrics.

    Each iteration draws a fresh batch of random messages and keys, then:

    1. Alice/Bob step. Alice's per-bit probabilities are sampled into hard 0/1
       ciphertext with ``torch.bernoulli``, which has no gradient. A
       straight-through estimator forwards the sampled bits unchanged but
       routes gradients to the probabilities. Without it, Bob's reconstruction
       loss and Eve's adversarial signal could never shape Alice; only the
       entropy and key-dependency terms would.
    2. Eve step, on epochs where ``epoch % eve_ratio == 0``. Eve learns to
       label Alice's ciphertext 1 and uniform random bits 0.
    3. Evaluation, every 25 epochs, on 1000 fresh samples. All three networks
       are temporarily put in eval mode (dropout off). Results are appended to
       the module-level ``history`` and printed.

    Every 50th epoch (after epoch 0) is a "shock": the entropy target is
    overridden, Eve trains that epoch, and Eve's loss is rescaled.

    Relies on the module-level ``alice``, ``bob``, ``eve``, their optimizers
    and schedulers, ``config``, ``device`` and ``history``.

    Args:
        num_epochs: number of training iterations (one batch each).

    Returns:
        The module-level ``history`` (lists of values keyed by metric name).
    """
    print("STARTING REBALANCED TRAINING")
    print("Target: Eve→0.5, Bob→>0.8, Alice optimized. Enhanced shocks every 50 epochs.")
    print("-"*80)

    shock_names = ['LOW-H', 'HIGH-H', 'MID-H']
    # Per shock type: (entropy-target override, multiplier on Eve's loss).
    shock_settings = [(0.6, 5.0), (0.99, 0.1), (0.8, 2.0)]

    def get_shock(epoch):
        """Return the shock type (0, 1 or 2) on every 50th epoch after 0, else None."""
        if epoch > 0 and epoch % 50 == 0:
            return (epoch // 50) % 3
        return None

    def get_curriculum(epoch):
        """Return ``(target_entropy, eve_ratio)``; Eve trains when ``epoch % eve_ratio == 0``."""
        shock = get_shock(epoch)
        if shock is not None:
            return shock_settings[shock][0], 1
        # NOTE(review): config.EVE_TRAINING_RATIO is not consulted; Eve's
        # update period comes only from this curriculum.
        if epoch < config.CURRICULUM_MILESTONES[0]:
            return config.TARGET_ENTROPY_INITIAL, 5
        if epoch < config.CURRICULUM_MILESTONES[1]:
            return config.TARGET_ENTROPY_MID, 3
        return config.TARGET_ENTROPY_LATE, 2

    nets = (alice, bob, eve)

    for epoch in range(num_epochs):
        target_H, eve_ratio = get_curriculum(epoch)
        shock = get_shock(epoch)

        # ---------------- Alice/Bob update ----------------
        alice_bob_optimizer.zero_grad()

        messages, keys = generate_random_data(config.BATCH_SIZE, config.MESSAGE_LENGTH, config.KEY_LENGTH, device)

        cipher_probs = alice(messages, keys)
        hard_cipher = torch.bernoulli(cipher_probs.detach())
        # Straight-through estimator: (p - p.detach()) is exactly 0 in the
        # forward pass, so Bob and Eve see the hard bits, while the backward
        # pass treats d(ciphertext)/d(cipher_probs) as 1.
        ciphertext = hard_cipher + (cipher_probs - cipher_probs.detach())

        decrypted_probs = bob(ciphertext, keys)
        eve_scores = eve(ciphertext)

        loss_recon = reconstruction_loss(decrypted_probs, messages)
        loss_entropy_reg = entropy_reg_loss(cipher_probs, target_H)
        loss_keydep = key_dependency_loss(alice, messages, keys, config.KEY_DEPENDENCY_MARGIN, device)
        loss_wrongkey = wrong_key_confusion_loss(bob, hard_cipher, keys, device)
        # Eve is trained to output 1 for Alice's ciphertext and 0 for random
        # bits. Alice wins by making Eve call her ciphertext random, so her
        # target is 0 (is_real=False). A target of 1 would reward ciphertext
        # that is *easier* to identify.
        loss_adv = adversarial_loss(eve_scores, is_real=False)

        total_loss = (config.RECONSTRUCTION_WEIGHT * loss_recon + 
                      config.ENTROPY_REG_WEIGHT * loss_entropy_reg + 
                      config.KEY_DEPENDENCY_WEIGHT * loss_keydep + 
                      config.WRONG_KEY_CONFUSION_WEIGHT * loss_wrongkey + 
                      config.ADVERSARIAL_WEIGHT * loss_adv)

        total_loss.backward()
        alice_bob_optimizer.step()
        alice_bob_scheduler.step()

        # ---------------- Eve update ----------------
        if epoch % eve_ratio == 0:
            # zero_grad also discards gradients Eve's parameters picked up from
            # total_loss.backward() above (Eve is not in alice_bob_optimizer).
            eve_optimizer.zero_grad()
            with torch.no_grad():
                messages_eve, keys_eve = generate_random_data(config.BATCH_SIZE, config.MESSAGE_LENGTH, config.KEY_LENGTH, device)
                probs_eve = alice(messages_eve, keys_eve)
                ciphertext_eve = torch.bernoulli(probs_eve)
                random_bits = torch.randint(0, 2, size=ciphertext_eve.shape, device=device).float()

            real_scores = eve(ciphertext_eve)
            fake_scores = eve(random_bits)
            loss_real = adversarial_loss(real_scores, True)
            loss_fake = adversarial_loss(fake_scores, False)
            eve_loss = (loss_real + loss_fake) / 2

            if shock is not None:
                eve_loss = eve_loss * shock_settings[shock][1]

            eve_loss.backward()
            eve_optimizer.step()
            eve_scheduler.step()

        # ---------------- Periodic evaluation ----------------
        if epoch % 25 == 0:
            prev_modes = [net.training for net in nets]
            for net in nets:
                net.eval()  # dropout off: measure the deterministic networks
            try:
                with torch.no_grad():
                    eval_messages, eval_keys = generate_random_data(1000, config.MESSAGE_LENGTH, config.KEY_LENGTH, device)
                    eval_probs = alice(eval_messages, eval_keys)
                    eval_cipher = torch.bernoulli(eval_probs)
                    eval_decrypted = bob(eval_cipher, eval_keys)
                    eval_random = torch.randint(0, 2, size=eval_cipher.shape, device=device).float()

                    # Fraction of plaintext bits Bob recovers correctly.
                    bob_accuracy = 1 - reconstruction_loss((eval_decrypted > 0.5).float(), eval_messages).item()
                    entropy = calculate_entropy(eval_cipher)

                    # Eve's balanced accuracy: ciphertext called real, random bits called fake.
                    real_preds = (torch.sigmoid(eve(eval_cipher)) > 0.5).float()
                    fake_preds = (torch.sigmoid(eve(eval_random)) < 0.5).float()
                    eve_accuracy = (real_preds.mean().item() + fake_preds.mean().item()) / 2

                    alice_consistency = torch.std(eval_probs, dim=0).mean().item()

                    perm = torch.randperm(eval_messages.size(0), device=device)
                    eval_probs_shuf = alice(eval_messages, eval_keys[perm])
                    key_sensitivity = torch.mean(torch.abs(eval_probs - eval_probs_shuf)).item()
            finally:
                for net, mode in zip(nets, prev_modes):
                    net.train(mode)

            alice_entropy_loss_val = loss_entropy_reg.item()
            alice_keydep_loss_val = loss_keydep.item()

            history['bob_accuracy'].append(bob_accuracy)
            history['entropy'].append(entropy)
            history['eve_accuracy'].append(eve_accuracy)
            history['alice_entropy_loss'].append(alice_entropy_loss_val)
            history['alice_keydep_loss'].append(alice_keydep_loss_val)
            history['alice_consistency'].append(alice_consistency)
            history['key_sensitivity'].append(key_sensitivity)
            history['total_loss'].append(total_loss.item())
            history['recon_loss'].append(loss_recon.item())
            history['entropy_reg_loss'].append(loss_entropy_reg.item())
            history['keydep_loss'].append(loss_keydep.item())
            history['wrongkey_loss'].append(loss_wrongkey.item())
            history['adv_loss'].append(loss_adv.item())
            history['epoch'].append(epoch)

            alice_lr = alice_bob_scheduler.get_last_lr()[0]
            eve_lr = eve_scheduler.get_last_lr()[0]
            shock_indicator = f" [SHOCK-{shock_names[shock]}!]" if shock is not None else ""

            print(f"Epoch {epoch:4d} | H={entropy:.4f} | Bob={bob_accuracy:.4f} | Eve={eve_accuracy:.4f} | KeySens={key_sensitivity:.3f}{shock_indicator}")
            print(f"    Alice: H-Loss={alice_entropy_loss_val:.4f} KeyDep-Loss={alice_keydep_loss_val:.4f} | LRs: A/B={alice_lr:.5f} E={eve_lr:.5f}")
            print(f"    Target-H={target_H:.3f} | EveRatio={eve_ratio} | Consistency={alice_consistency:.3f}")
            print()

    return history

print("Rebalanced training loop ready with Alice metrics!")

## 6. Launch Training & Enhanced Visualization

In [None]:
history = train_rebalanced(config.NUM_EPOCHS)

# The key-dependency term is a hinge, relu(margin - dist), so it is exactly 0
# whenever the margin is met. A log axis cannot show 0, so panels containing
# it use a symmetric-log axis, which is linear below this threshold.
_SYMLOG_LINTHRESH = 1e-3


def _finish_panel(ax, title, ylabel, *, ylim=None, yscale=None, legend=False, **title_kwargs):
    """Apply the title, axis labels, optional y-limits/scale, grid and legend to one panel."""
    ax.set_title(title, **title_kwargs)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    if ylim is not None:
        ax.set_ylim(*ylim)
    if yscale == 'symlog':
        ax.set_yscale('symlog', linthresh=_SYMLOG_LINTHRESH)
    elif yscale is not None:
        ax.set_yscale(yscale)
    ax.grid(True)
    if legend:
        ax.legend()


_epochs = history['epoch']

# ---------------- Figure 1: headline metrics ----------------
fig, axes = plt.subplots(2, 3, figsize=(20, 10))
fig.suptitle('Rebalanced Training Results - Strong Bob, Weak Eve, Optimized Alice', fontsize=16)

ax = axes[0, 0]
ax.plot(_epochs, history['entropy'], color='red', linewidth=3)
ax.axhline(y=1.0, color='red', linestyle='--', label='Perfect (1.0)')
_finish_panel(ax, 'Ciphertext Entropy', 'Entropy', ylim=(0, 1.1), legend=True, fontweight='bold')

ax = axes[0, 1]
ax.plot(_epochs, history['bob_accuracy'], color='green', linewidth=2)
ax.axhline(y=0.8, color='green', linestyle='--', label='Target (0.8)')
_finish_panel(ax, 'Strong Bob Accuracy', 'Accuracy', ylim=(0, 1.1), legend=True)

ax = axes[0, 2]
ax.plot(_epochs, history['eve_accuracy'], color='orange', linewidth=2)
ax.axhline(y=0.5, color='orange', linestyle='--', label='Target (0.5)')
_finish_panel(ax, 'Weak Eve Accuracy', 'Accuracy', ylim=(0, 1.1), legend=True)

ax = axes[1, 0]
ax.plot(_epochs, history['key_sensitivity'], color='purple', linewidth=2)
ax.axhline(y=config.KEY_DEPENDENCY_MARGIN, color='purple', linestyle='--', label=f'Target: {config.KEY_DEPENDENCY_MARGIN}')
_finish_panel(ax, 'Key Sensitivity (Avalanche)', 'Key Sensitivity', legend=True)

ax = axes[1, 1]
ax.plot(_epochs, history['alice_entropy_loss'], color='blue', linewidth=2)
_finish_panel(ax, 'Alice Entropy Regulation', 'Entropy Loss', yscale='log')

ax = axes[1, 2]
ax.plot(_epochs, history['alice_keydep_loss'], color='darkblue', linewidth=2)
_finish_panel(ax, 'Alice Key Dependency', 'Key Dep Loss', yscale='symlog')

plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.show()

# ---------------- Figure 2: loss breakdown and summary ----------------
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

ax = axes[0]
ax.plot(_epochs, history['recon_loss'], label='Reconstruction', alpha=0.8)
ax.plot(_epochs, history['entropy_reg_loss'], label='Entropy Reg', linewidth=2, alpha=0.8)
ax.plot(_epochs, history['keydep_loss'], label='Key Dependency', alpha=0.8)
ax.plot(_epochs, history['wrongkey_loss'], label='Wrong Key', alpha=0.8)
ax.plot(_epochs, history['adv_loss'], label='Adversarial', alpha=0.8)
_finish_panel(ax, 'Loss Components', 'Loss', yscale='symlog', legend=True)

ax = axes[1]
ax.plot(_epochs, history['alice_consistency'], color='cyan', linewidth=2)
_finish_panel(ax, 'Alice Output Consistency', 'Std Dev of Probs')

ax = axes[2]
ax.plot(_epochs, history['bob_accuracy'], label='Bob (Target >0.8)', color='green', linewidth=2)
ax.plot(_epochs, history['eve_accuracy'], label='Eve (Target ~0.5)', color='orange', linewidth=2)
ax.plot(_epochs, history['entropy'], label='Entropy (Target ~1.0)', color='red', linewidth=2)
ax.axhline(y=0.8, color='green', linestyle='--', alpha=0.5)
ax.axhline(y=0.5, color='orange', linestyle='--', alpha=0.5)
ax.axhline(y=1.0, color='red', linestyle='--', alpha=0.5)
_finish_panel(ax, 'Performance Summary', 'Performance', ylim=(0, 1.1), legend=True)

plt.tight_layout()
plt.show()

## 7. Final Rebalanced Analysis

In [None]:
print("\n" + "="*80)
print("FINAL REBALANCED TRAINING EVALUATION")
print("="*80)

with torch.no_grad():
    final_messages, final_keys = generate_random_data(5000, config.MESSAGE_LENGTH, config.KEY_LENGTH, device)
    final_probs = alice(final_messages, final_keys)
    final_cipher = torch.bernoulli(final_probs)
    final_decrypted = bob(final_cipher, final_keys)
    final_random = torch.randint(0, 2, size=final_cipher.shape, device=device).float()
    
    final_bob_acc = 1 - reconstruction_loss((final_decrypted > 0.5).float(), final_messages).item()
    final_entropy = calculate_entropy(final_cipher)
    
    final_real_preds = (torch.sigmoid(eve(final_cipher)) > 0.5).float()
    final_fake_preds = (torch.sigmoid(eve(final_random)) < 0.5).float()
    final_eve_acc = (final_real_preds.mean().item() + final_fake_preds.mean().item()) / 2
    
    perm = torch.randperm(final_messages.size(0), device=device)
    final_keys_shuf = final_keys[perm]
    final_probs_shuf = alice(final_messages, final_keys_shuf)
    final_key_sensitivity = torch.mean(torch.abs(final_probs - final_probs_shuf)).item()
    
    final_alice_consistency = torch.std(final_probs, dim=0).mean().item()

print(f"FINAL REBALANCED PERFORMANCE:")
print(f"  **Ciphertext Entropy**: {final_entropy:.4f} (Target: ~1.0)")
print(f"  **Bob's Accuracy**: {final_bob_acc:.4f} (Target: >0.8) {'✅ SUCCESS' if final_bob_acc > 0.8 else '❌ NEEDS WORK'}")
print(f"  **Eve's Accuracy**: {final_eve_acc:.4f} (Target: ~0.5) {'✅ SUCCESS' if abs(final_eve_acc - 0.5) < 0.1 else '❌ TOO STRONG'}")
print(f"  **Key Sensitivity**: {final_key_sensitivity:.4f} (Target: >{config.KEY_DEPENDENCY_MARGIN}) {'✅ SUCCESS' if final_key_sensitivity > config.KEY_DEPENDENCY_MARGIN else '❌ TOO LOW'}")
print(f"  **Alice Consistency**: {final_alice_consistency:.4f}")

print("\nREBALANCING VERDICT:")
success_count = 0
if final_entropy > 0.95:
    success_count += 1
if final_bob_acc > 0.8:
    success_count += 1
if abs(final_eve_acc - 0.5) < 0.1:
    success_count += 1
if final_key_sensitivity > config.KEY_DEPENDENCY_MARGIN:
    success_count += 1

if success_count >= 3:
    print("  **REBALANCING SUCCESS!** The system now has proper competitive dynamics!")
elif success_count >= 2:
    print("  **Good Progress.** Most targets achieved, fine-tuning may be needed.")
else:
    print("  **Needs More Rebalancing.** Consider further architectural or hyperparameter adjustments.")

print("\nKey Insights:")
print(f"   • Bob (stronger): {config.BOB_HIDDEN_SIZE} hidden, {config.BOB_NUM_LAYERS} layers")
print(f"   • Eve (weaker): {config.EVE_HIDDEN_SIZE} hidden, {config.EVE_NUM_LAYERS} layer, LR={config.LEARNING_RATE_EVE}")
print(f"   • Enhanced shocks every 50 epochs with varied intensity")
print(f"   • Alice metrics now tracked for crypto optimization")