In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from difflib import SequenceMatcher

class ImprovedCryptoTrainer:
    """Enhanced training with focus on security"""

    def __init__(self, vocab_size=98, embed_dim=128, alice_layers=4,
                 bob_layers=4, eve_layers=6, max_len=128,
                 device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.device = device
        self.max_len = max_len

        from v3_part1 import StringProcessor, KeyManager
        from v3_part2_alice import AliceEncryptor
        from v3_part3_bob_eve import BobDecryptor, EveAttacker

        self.processor = StringProcessor(max_length=max_len)
        self.key_manager = KeyManager(device=device)

        self.alice_wrapper = AliceEncryptor(vocab_size, embed_dim, alice_layers, max_len, device)
        self.bob_wrapper = BobDecryptor(vocab_size, embed_dim, bob_layers, max_len, device)
        self.eve_wrapper = EveAttacker(vocab_size, embed_dim, eve_layers, max_len, device)

        self.alice = self.alice_wrapper.alice
        self.bob = self.bob_wrapper.bob
        self.eve = self.eve_wrapper.eve

        # Optimizers with better settings
        self.opt_alice = optim.AdamW(self.alice.parameters(), lr=0.001, weight_decay=0.01)
        self.opt_bob = optim.AdamW(self.bob.parameters(), lr=0.001, weight_decay=0.01)
        self.opt_eve = optim.Adam(self.eve.parameters(), lr=0.0003)

        # Schedulers
        self.sched_alice = optim.lr_scheduler.CosineAnnealingWarmRestarts(self.opt_alice, T_0=20, T_mult=2)
        self.sched_bob = optim.lr_scheduler.CosineAnnealingWarmRestarts(self.opt_bob, T_0=20, T_mult=2)
        self.sched_eve = optim.lr_scheduler.ReduceLROnPlateau(self.opt_eve, patience=10, factor=0.5)

        # Loss
        self.ce_loss = nn.CrossEntropyLoss(ignore_index=0)

        self.current_stage = 1

    def train_stage1_reconstruction(self, messages, num_epochs=150):
        """Stage 1: Perfect reconstruction"""
        print("\n" + "="*70)
        print("STAGE 1: TRAINING ALICE+BOB FOR RECONSTRUCTION")
        print("="*70)

        history = {'loss': [], 'accuracy': []}
        best_acc = 0
        patience_counter = 0

        for epoch in range(num_epochs):
            batch = np.random.choice(messages, size=min(8, len(messages)), replace=False).tolist()
            tokens = self.processor.batch_encode(batch).to(self.device)
            keys = self.key_manager.generate_keys_for_batch(len(batch), tokens.size(1))

            self.opt_alice.zero_grad()
            self.opt_bob.zero_grad()

            self.alice.train()
            self.bob.train()

            encrypted = self.alice(tokens, keys['key_tensors'])
            logits = self.bob(encrypted, keys['key_tensors'])

            loss = self.ce_loss(logits.reshape(-1, logits.size(-1)), tokens.reshape(-1))

            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.alice.parameters(), 0.5)
            torch.nn.utils.clip_grad_norm_(self.bob.parameters(), 0.5)

            self.opt_alice.step()
            self.opt_bob.step()
            self.sched_alice.step()
            self.sched_bob.step()

            with torch.no_grad():
                pred_tokens = torch.argmax(logits, dim=-1)
                mask = tokens != 0
                correct = ((pred_tokens == tokens) & mask).sum().item()
                total = mask.sum().item()
                accuracy = correct / total if total > 0 else 0

            history['loss'].append(loss.item())
            history['accuracy'].append(accuracy)

            if epoch % 10 == 0:
                print(f"Epoch {epoch:3d} | Loss: {loss.item():.4f} | Accuracy: {accuracy*100:.1f}%")

            # Early stopping
            if accuracy > best_acc:
                best_acc = accuracy
                patience_counter = 0
            else:
                patience_counter += 1

            if best_acc > 0.95 and patience_counter > 20:
                print(f"Early stopping at epoch {epoch}")
                break

        final_acc = np.mean(history['accuracy'][-10:])
        print(f"\nStage 1 Complete! Final accuracy: {final_acc*100:.1f}%")

        if final_acc > 0.90:
            print("SUCCESS: Bob can decrypt accurately!")
            self.current_stage = 2
        else:
            print("WARNING: Bob accuracy < 90%")

        return history

    def train_stage2_adversary(self, messages, num_epochs=50):
        """Stage 2: Train Eve"""
        print("\n" + "="*70)
        print("STAGE 2: TRAINING EVE (ADVERSARY)")
        print("="*70)

        for param in self.alice.parameters():
            param.requires_grad = False
        for param in self.bob.parameters():
            param.requires_grad = False

        history = {'loss': [], 'accuracy': []}

        for epoch in range(num_epochs):
            batch = np.random.choice(messages, size=min(6, len(messages)), replace=False).tolist()
            tokens = self.processor.batch_encode(batch).to(self.device)
            keys = self.key_manager.generate_keys_for_batch(len(batch), tokens.size(1))

            with torch.no_grad():
                encrypted = self.alice(tokens, keys['key_tensors'])

            self.opt_eve.zero_grad()
            self.eve.train()

            eve_logits = self.eve(encrypted)
            loss = self.ce_loss(eve_logits.reshape(-1, eve_logits.size(-1)), tokens.reshape(-1))

            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.eve.parameters(), 1.0)
            self.opt_eve.step()

            with torch.no_grad():
                pred_tokens = torch.argmax(eve_logits, dim=-1)
                mask = tokens != 0
                correct = ((pred_tokens == tokens) & mask).sum().item()
                total = mask.sum().item()
                accuracy = correct / total if total > 0 else 0

            history['loss'].append(loss.item())
            history['accuracy'].append(accuracy)

            if epoch % 10 == 0:
                print(f"Epoch {epoch:3d} | Loss: {loss.item():.4f} | Eve Accuracy: {accuracy*100:.1f}%")

            self.sched_eve.step(loss)

        for param in self.alice.parameters():
            param.requires_grad = True
        for param in self.bob.parameters():
            param.requires_grad = True

        print(f"\nStage 2 Complete!")
        self.current_stage = 3

        return history

    def train_stage3_adversarial(self, messages, num_epochs=100):
        """Stage 3: Adversarial fine-tuning with security focus"""
        print("\n" + "="*70)
        print("STAGE 3: ADVERSARIAL FINE-TUNING")
        print("="*70)

        history = {
            'bob_loss': [], 'eve_loss': [],
            'bob_acc': [], 'eve_acc': [], 'ratio': []
        }

        for epoch in range(num_epochs):
            batch = np.random.choice(messages, size=min(6, len(messages)), replace=False).tolist()
            tokens = self.processor.batch_encode(batch).to(self.device)
            keys = self.key_manager.generate_keys_for_batch(len(batch), tokens.size(1))

            # === Train Alice+Bob (reconstruction) ===
            self.opt_alice.zero_grad()
            self.opt_bob.zero_grad()

            self.alice.train()
            self.bob.train()

            encrypted = self.alice(tokens, keys['key_tensors'])
            bob_logits = self.bob(encrypted, keys['key_tensors'])
            bob_loss = self.ce_loss(bob_logits.reshape(-1, bob_logits.size(-1)), tokens.reshape(-1))

            bob_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.alice.parameters(), 0.5)
            torch.nn.utils.clip_grad_norm_(self.bob.parameters(), 0.5)
            self.opt_alice.step()
            self.opt_bob.step()

            # === Train Eve (attack) ===
            self.opt_eve.zero_grad()
            self.eve.train()

            with torch.no_grad():
                encrypted_for_eve = self.alice(tokens, keys['key_tensors'])

            eve_logits = self.eve(encrypted_for_eve)
            eve_loss = self.ce_loss(eve_logits.reshape(-1, eve_logits.size(-1)), tokens.reshape(-1))

            eve_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.eve.parameters(), 1.0)
            self.opt_eve.step()

            # === Adversarial training for Alice (fool Eve) ===
            self.opt_alice.zero_grad()

            encrypted_adv = self.alice(tokens, keys['key_tensors'])
            eve_logits_adv = self.eve(encrypted_adv)

            # Alice wants to MAXIMIZE Eve's loss (make Eve confused)
            adv_loss = -self.ce_loss(eve_logits_adv.reshape(-1, eve_logits_adv.size(-1)), tokens.reshape(-1))

            # Weight adversarial loss
            (adv_loss * 0.3).backward()
            torch.nn.utils.clip_grad_norm_(self.alice.parameters(), 0.5)
            self.opt_alice.step()

            # Calculate metrics
            with torch.no_grad():
                bob_pred = torch.argmax(bob_logits, dim=-1)
                eve_pred = torch.argmax(eve_logits, dim=-1)
                mask = tokens != 0

                bob_acc = ((bob_pred == tokens) & mask).sum().item() / mask.sum().item()
                eve_acc = ((eve_pred == tokens) & mask).sum().item() / mask.sum().item()
                ratio = eve_loss.item() / (bob_loss.item() + 1e-8)

            history['bob_loss'].append(bob_loss.item())
            history['eve_loss'].append(eve_loss.item())
            history['bob_acc'].append(bob_acc)
            history['eve_acc'].append(eve_acc)
            history['ratio'].append(ratio)

            if epoch % 10 == 0:
                print(f"Epoch {epoch:3d} | Bob: {bob_loss.item():.3f} ({bob_acc*100:.1f}%) | "
                      f"Eve: {eve_loss.item():.3f} ({eve_acc*100:.1f}%) | Ratio: {ratio:.2f}x")

            self.sched_alice.step()
            self.sched_bob.step()
            self.sched_eve.step(eve_loss)

        print(f"\nStage 3 Complete!")
        return history

    def full_training(self, messages, stage1_epochs=150, stage2_epochs=50, stage3_epochs=100):
        """Run all 3 stages"""
        print("\n" + "="*70)
        print("FULL STAGED TRAINING")
        print(f"Dataset: {len(messages)} messages")
        print("="*70)

        h1 = self.train_stage1_reconstruction(messages, stage1_epochs)
        h2 = self.train_stage2_adversary(messages, stage2_epochs)
        h3 = self.train_stage3_adversarial(messages, stage3_epochs)

        return {'stage1': h1, 'stage2': h2, 'stage3': h3}


class CryptoEvaluator:
    @staticmethod
    def evaluate(trainer, test_messages):
        print("\n" + "="*70)
        print("FINAL EVALUATION")
        print("="*70)

        results = {'bob_sim': [], 'eve_sim': [], 'examples': []}

        for i, original in enumerate(test_messages[:10]):
            tokens = trainer.processor.encode(original).unsqueeze(0).to(trainer.device)
            keys = trainer.key_manager.generate_keys_for_batch(1, tokens.size(1))

            with torch.no_grad():
                encrypted = trainer.alice(tokens, keys['key_tensors'])
                bob_logits = trainer.bob(encrypted, keys['key_tensors'])
                eve_logits = trainer.eve(encrypted)

                bob_tokens = torch.argmax(bob_logits, dim=-1)
                eve_tokens = torch.argmax(eve_logits, dim=-1)

                bob_msg = trainer.processor.decode(bob_tokens[0])
                eve_msg = trainer.processor.decode(eve_tokens[0])

            bob_sim = SequenceMatcher(None, original, bob_msg).ratio()
            eve_sim = SequenceMatcher(None, original, eve_msg).ratio()

            results['bob_sim'].append(bob_sim)
            results['eve_sim'].append(eve_sim)
            results['examples'].append((original, bob_msg, eve_msg))

            if i < 5:
                print(f"\nExample {i+1}:")
                print(f"  Original: '{original}'")
                print(f"  Bob:      '{bob_msg}' ({bob_sim*100:.1f}%)")
                print(f"  Eve:      '{eve_msg}' ({eve_sim*100:.1f}%)")

        avg_bob = np.mean(results['bob_sim'])
        avg_eve = np.mean(results['eve_sim'])

        print("\n" + "="*70)
        print(f"Bob Similarity: {avg_bob*100:.1f}% {'✓' if avg_bob > 0.9 else '✗'}")
        print(f"Eve Similarity: {avg_eve*100:.1f}% {'✓' if avg_eve < 0.3 else '✗'}")
        print(f"Security Ratio: {avg_bob/max(avg_eve,0.01):.2f}x")
        print("="*70)

        return results

In [2]:
if __name__ == "__main__":
    # Extended dataset
    messages = np.array([
        "Hello World!",
        "This is a test.",
        "Secret message here.",
        "Encryption works!",
        "Neural crypto system.",
        "Testing ABC 123.",
        "Quick brown fox.",
        "The lazy dog jumps.",
        "Secure communication.",
        "Privacy matters most.",
        "Data protection now.",
        "Hidden information."
    ])

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    trainer = ImprovedCryptoTrainer(
        vocab_size=98,
        embed_dim=128,
        alice_layers=4,
        bob_layers=4,
        eve_layers=6,
        max_len=64,
        device=device
    )

    # Train with better settings
    history = trainer.full_training(
        messages,
        stage1_epochs=150,
        stage2_epochs=50,
        stage3_epochs=100
    )

    # Evaluate
    results = CryptoEvaluator.evaluate(trainer, messages)

    print("\n✓ Training complete!")
    print(f"\nKey Improvements:")
    print("  - 8 chaotic keys (1-DEC + tent maps)")
    print("  - 4 key-mixing layers in Alice")
    print("  - 6-layer transformer for Eve")
    print("  - Adversarial training to fool Eve")
    print("  - Better optimization (AdamW + cosine annealing)")


FULL STAGED TRAINING
Dataset: 12 messages

STAGE 1: TRAINING ALICE+BOB FOR RECONSTRUCTION
Epoch   0 | Loss: 4.6324 | Accuracy: 0.0%
Epoch  10 | Loss: 2.8628 | Accuracy: 38.8%
Epoch  20 | Loss: 2.3237 | Accuracy: 56.2%
Epoch  30 | Loss: 1.3141 | Accuracy: 79.8%
Epoch  40 | Loss: 0.7859 | Accuracy: 82.8%
Epoch  50 | Loss: 0.6063 | Accuracy: 88.1%
Epoch  60 | Loss: 0.5052 | Accuracy: 89.7%
Epoch  70 | Loss: 0.3074 | Accuracy: 95.1%
Epoch  80 | Loss: 0.1452 | Accuracy: 99.4%
Epoch  90 | Loss: 0.0583 | Accuracy: 100.0%
Epoch 100 | Loss: 0.0481 | Accuracy: 100.0%
Early stopping at epoch 104

Stage 1 Complete! Final accuracy: 100.0%
SUCCESS: Bob can decrypt accurately!

STAGE 2: TRAINING EVE (ADVERSARY)
Epoch   0 | Loss: 4.6056 | Eve Accuracy: 0.0%


Consider using tensor.detach() first. (Triggered internally at /pytorch/torch/csrc/autograd/generated/python_variable_methods.cpp:835.)
  current = float(metrics)


Epoch  10 | Loss: 3.0916 | Eve Accuracy: 74.8%
Epoch  20 | Loss: 2.2587 | Eve Accuracy: 81.8%
Epoch  30 | Loss: 1.5987 | Eve Accuracy: 92.9%
Epoch  40 | Loss: 1.2161 | Eve Accuracy: 92.6%

Stage 2 Complete!

STAGE 3: ADVERSARIAL FINE-TUNING
Epoch   0 | Bob: 0.035 (100.0%) | Eve: 0.996 (89.8%) | Ratio: 28.07x
Epoch  10 | Bob: 0.042 (100.0%) | Eve: 0.763 (96.7%) | Ratio: 18.35x
Epoch  20 | Bob: 0.061 (100.0%) | Eve: 1.154 (83.8%) | Ratio: 19.06x
Epoch  30 | Bob: 0.053 (100.0%) | Eve: 0.828 (92.2%) | Ratio: 15.57x
Epoch  40 | Bob: 3.696 (9.0%) | Eve: 5.167 (4.1%) | Ratio: 1.40x
Epoch  50 | Bob: 3.377 (18.8%) | Eve: 4.463 (4.3%) | Ratio: 1.32x
Epoch  60 | Bob: 3.639 (17.7%) | Eve: 5.035 (0.0%) | Ratio: 1.38x
Epoch  70 | Bob: 1.927 (43.0%) | Eve: 5.530 (0.0%) | Ratio: 2.87x
Epoch  80 | Bob: 1.711 (50.4%) | Eve: 5.822 (0.0%) | Ratio: 3.40x
Epoch  90 | Bob: 1.203 (64.8%) | Eve: 5.866 (0.0%) | Ratio: 4.88x

Stage 3 Complete!

FINAL EVALUATION

Example 1:
  Original: 'Hello World!'
  Bob:      