In [2]:
"""
MINIMAL FIX VERSION
Strategy: Keep the original working system, just tweak key sensitivity
Don't break what's already working!
"""

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from difflib import SequenceMatcher
import hashlib

# ============ String Processor (UNCHANGED) ============
class StringProcessor:
    """Character-level tokenizer that produces fixed-length id sequences.

    Each character of the alphabet gets its index in that string as its id. Two
    special tokens follow: ``<PAD>`` (id ``len(alphabet)``) and ``<END>``
    (id ``len(alphabet) + 1``). Characters outside the alphabet encode as id 0,
    which is the space character.
    """

    def __init__(self, max_len=64):
        self.max_len = max_len
        alphabet = ' abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789.,!?-'
        self.vocab = dict(zip(alphabet, range(len(alphabet))))
        self.vocab['<PAD>'] = len(alphabet)
        self.vocab['<END>'] = len(alphabet) + 1
        self.inv_vocab = {idx: token for token, idx in self.vocab.items()}
        self.vocab_size = len(self.vocab)

    def encode(self, text):
        """Encode ``text`` into a LongTensor of exactly ``max_len`` ids.

        The layout is up to ``max_len - 1`` character ids, one ``<END>``, then
        ``<PAD>`` filler.
        """
        end_id = self.vocab['<END>']
        pad_id = self.vocab['<PAD>']
        ids = [self.vocab.get(ch, 0) for ch in text[:self.max_len - 1]] + [end_id]
        ids += [pad_id] * (self.max_len - len(ids))
        return torch.LongTensor(ids)

    def decode(self, ids):
        """Turn ids back into text.

        Decoding stops at the first ``<END>`` and skips ``<PAD>``. Unknown ids
        decode to the empty string.
        """
        pieces = []
        for token_id in ids:
            symbol = self.inv_vocab.get(int(token_id), '')
            if symbol == '<END>':
                break
            if symbol == '<PAD>':
                continue
            pieces.append(symbol)
        return ''.join(pieces)

    def batch_encode(self, texts):
        """Encode every text and stack the results into a (len(texts), max_len) LongTensor."""
        return torch.stack(list(map(self.encode, texts)))


# ============ Autoencoder (UNCHANGED) ============
class CryptoAutoencoder(nn.Module):
    """Position-wise token autoencoder: embedding -> encoder MLP -> decoder MLP.

    ``encoder`` maps each token embedding to an ``embed_dim`` vector; ``decoder``
    maps such a vector to ``vocab_size`` logits. Both MLPs share the same layer
    layout, so ``state_dict`` keys (``encoder.0.weight`` ... ``decoder.7.bias``)
    are fixed by ``_mlp``.
    """

    def __init__(self, vocab_size, embed_dim=128, hidden_dim=256):
        super().__init__()
        # Assumes the StringProcessor id layout, where <PAD> is the second-to-last id.
        pad_id = vocab_size - 2
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_id)
        self.encoder = self._mlp(embed_dim, hidden_dim, embed_dim)
        self.decoder = self._mlp(embed_dim, hidden_dim, vocab_size)

    @staticmethod
    def _mlp(in_dim, hidden_dim, out_dim):
        """Build a Linear-LN-ReLU-Dropout(0.1)-Linear-LN-ReLU-Linear stack."""
        return nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, tokens, return_embeddings=False):
        """Return decoder logits, or the encoder output if ``return_embeddings`` is set."""
        encoded = self.encoder(self.embedding(tokens))
        return encoded if return_embeddings else self.decoder(encoded)


# ============ Key Generator (UNCHANGED) ============
class KeyGenerator:
    """Produce 1-D float32 key tensors for the key-dependent encryption layer."""

    @staticmethod
    def generate_from_message(message, key_size=128):
        """Deterministically derive a key of length ``key_size`` from ``message``.

        The first 32 bits of SHA-256(message) seed a private legacy
        ``np.random.RandomState``. This yields exactly the values the previous
        ``np.random.seed(seed); np.random.randn(key_size)`` produced, but without
        reseeding NumPy's global RNG, which other code draws from.

        NOTE(review): the key is a function of the plaintext and carries at most
        32 bits of entropy. Anyone who can guess the message can recompute it.
        """
        digest = hashlib.sha256(message.encode()).hexdigest()
        seed = int(digest[:8], 16)
        rng = np.random.RandomState(seed)
        return torch.FloatTensor(rng.randn(key_size))

    @staticmethod
    def generate_random(key_size=128):
        """Draw a standard-normal key of length ``key_size`` from torch's global RNG."""
        return torch.randn(key_size)


# ============ TWEAKED: Slightly Stronger Key Layer ============
class KeyDependentEncryption(nn.Module):
    """Key-conditioned layer placed between the autoencoder's encoder and decoder.

    ``encrypt(embeddings, key)`` works in three steps:
      1. ``key_features = key_transform(key)``. It ends in tanh, so each entry is
         in (-1, 1). The features are broadcast to every sequence position.
      2. ``mixed = encrypt_mix([embeddings, key_features])``.
      3. ``encrypted = mixed * (1 + key_scale * key_features)``.

    ``decrypt(encrypted, key)`` divides the step-3 factor (plus ``eps``) back out
    and so returns ``mixed``. It does NOT undo ``encrypt_mix``; the downstream
    decoder has to learn to read ``mixed`` directly.

    With ``key_scale=2.0`` the factor lies in (-1, 3) and can be near zero, so
    decrypting with a wrong key may produce very large values.

    Args:
        embed_dim: size of the embeddings AND of the key. ``key_transform`` starts
            with ``nn.Linear(embed_dim, ...)``.
        key_scale: weight of the key features in the multiplicative factor
            (default 2.0, matching the previous hard-coded value).
        eps: constant added to the divisor in ``decrypt`` (default 1e-8).
    """

    def __init__(self, embed_dim=128, key_scale=2.0, eps=1e-8):
        super().__init__()
        self.key_scale = key_scale
        self.eps = eps
        self.key_transform = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 2),
            nn.Tanh(),
            nn.Linear(embed_dim * 2, embed_dim),
            nn.Tanh()
        )

        self.encrypt_mix = nn.Sequential(
            nn.Linear(embed_dim * 2, embed_dim),
            nn.Tanh()
        )

    def _key_features(self, key, batch_size, seq_len):
        """Return key features broadcast to shape (batch_size, seq_len, embed_dim).

        A 1-D ``key`` is shared by every batch row; a 2-D key is one key per row.
        """
        if key.dim() == 1:
            key = key.unsqueeze(0).expand(batch_size, -1)
        features = self.key_transform(key)
        return features.unsqueeze(1).expand(-1, seq_len, -1)

    def encrypt(self, embeddings, key):
        """Mix ``embeddings`` with the key and apply the key-dependent scale.

        Args:
            embeddings: (batch, seq_len, embed_dim) tensor.
            key: (embed_dim,) key shared by the batch, or (batch, embed_dim).

        Returns:
            Tensor of shape (batch, seq_len, embed_dim).
        """
        key_features = self._key_features(key, embeddings.size(0), embeddings.size(1))
        mixed = self.encrypt_mix(torch.cat([embeddings, key_features], dim=-1))
        return mixed * (1 + key_features * self.key_scale)

    def decrypt(self, encrypted, key):
        """Remove the key-dependent scale applied by ``encrypt``.

        The result is the ``encrypt_mix`` output, not the original embeddings.
        """
        key_features = self._key_features(key, encrypted.size(0), encrypted.size(1))
        return encrypted / (1 + key_features * self.key_scale + self.eps)


# ============ Eve (UNCHANGED) ============
class EveAttacker(nn.Module):
    """Adversary that predicts token logits from encrypted embeddings alone (no key).

    The layer order inside ``attack_network`` fixes its ``state_dict`` keys
    (``attack_network.0.weight`` ... ``attack_network.11.bias``).
    """

    def __init__(self, vocab_size, embed_dim=128):
        super().__init__()
        wide = embed_dim * 2
        layers = [
            nn.Linear(embed_dim, wide), nn.LayerNorm(wide), nn.GELU(), nn.Dropout(0.3),
            nn.Linear(wide, wide), nn.LayerNorm(wide), nn.GELU(), nn.Dropout(0.3),
            nn.Linear(wide, embed_dim), nn.LayerNorm(embed_dim), nn.GELU(),
            nn.Linear(embed_dim, vocab_size),
        ]
        self.attack_network = nn.Sequential(*layers)

    def forward(self, encrypted_embeddings):
        """Return per-position vocabulary logits for ``encrypted_embeddings``."""
        logits = self.attack_network(encrypted_embeddings)
        return logits


# ============ CRITICAL FIX: Gentler Phase 2 Training ============
class NeuralCryptoSystem:
    """Alice/Bob autoencoder with a key-dependent layer, trained against an Eve attacker.

    * ``autoencoder``: embedding + encoder (Alice side) and ``decoder`` (Bob side).
    * ``crypto_layer``: key-conditioned transform between encoder and decoder.
    * ``eve``: attacker trained to recover tokens from ciphertext without the key.

    ``opt_main`` updates the autoencoder and crypto layer; ``opt_eve`` updates Eve.

    The inference helpers (``encrypt_message``, ``decrypt_message``, ``eve_attack``)
    switch all networks to eval mode so Dropout is disabled, and then restore the
    previous mode. Keys always have ``embed_dim`` entries, because
    ``KeyDependentEncryption(embed_dim)`` consumes them.

    NOTE(review): ``vocab_size`` is passed separately from ``self.processor``. It
    must equal ``StringProcessor().vocab_size`` or logits and token ids will not
    line up.
    """

    def __init__(self, vocab_size, embed_dim=128, device='cuda'):
        self.device = device
        self.embed_dim = embed_dim
        self.processor = StringProcessor()

        self.autoencoder = CryptoAutoencoder(vocab_size, embed_dim).to(device)
        self.crypto_layer = KeyDependentEncryption(embed_dim).to(device)
        self.eve = EveAttacker(vocab_size, embed_dim).to(device)

        self.opt_main = optim.Adam(
            list(self.autoencoder.parameters()) + list(self.crypto_layer.parameters()),
            lr=0.001
        )
        self.opt_eve = optim.Adam(self.eve.parameters(), lr=0.0005)

        # Averaged over every position, including <END> and <PAD> tokens.
        self.criterion = nn.CrossEntropyLoss()

    # ---------------- train / eval mode management ----------------
    def _networks(self):
        """Return the three networks in a fixed order."""
        return (self.autoencoder, self.crypto_layer, self.eve)

    def _train_mode(self):
        """Put every network in training mode (Dropout active)."""
        for net in self._networks():
            net.train()

    def _enter_eval_mode(self):
        """Put every network in eval mode; return the previous ``training`` flags."""
        previous = [net.training for net in self._networks()]
        for net in self._networks():
            net.eval()
        return previous

    def _restore_modes(self, previous):
        """Restore the ``training`` flags returned by ``_enter_eval_mode``."""
        for net, was_training in zip(self._networks(), previous):
            net.train(was_training)

    # ---------------- training ----------------
    def train_phase1_reconstruction(self, messages, epochs=200):
        """Phase 1: train the autoencoder to reconstruct messages, with no encryption.

        Each optimizer step uses one message. Progress is printed every 20 epochs.

        Args:
            messages: non-empty sequence of strings.
            epochs: number of passes over ``messages``.

        Returns:
            bool: True if mean token accuracy over the last epoch exceeds 95%
            (False when ``epochs`` is 0). Accuracy also counts <END>/<PAD>
            positions, so it overstates character accuracy on short messages.

        Raises:
            ValueError: if ``messages`` is empty.
        """
        if len(messages) == 0:
            raise ValueError("messages must be non-empty")

        self._train_mode()
        print("\n" + "="*70)
        print("PHASE 1: RECONSTRUCTION TRAINING (NO ENCRYPTION)")
        print("="*70)

        total_acc = 0.0  # defined even when epochs == 0
        for epoch in range(epochs):
            total_loss = 0
            total_acc = 0

            for msg in messages:
                tokens = self.processor.encode(msg).unsqueeze(0).to(self.device)

                self.opt_main.zero_grad()
                logits = self.autoencoder(tokens)

                loss = self.criterion(logits.view(-1, logits.size(-1)), tokens.view(-1))
                loss.backward()
                self.opt_main.step()

                pred = torch.argmax(logits, dim=-1)
                acc = (pred == tokens).float().mean().item()

                total_loss += loss.item()
                total_acc += acc

            if epoch % 20 == 0:
                avg_acc = total_acc / len(messages)
                print(f"Epoch {epoch:3d} | Loss: {total_loss/len(messages):.4f} | Acc: {avg_acc*100:.1f}%")

        final_acc = total_acc / len(messages)
        print(f"\n✓ Phase 1 Complete! Accuracy: {final_acc*100:.1f}%")
        return final_acc > 0.95

    def train_phase2_with_encryption(self, messages, epochs=100,
                                     adversarial_start=30, adversarial_weight=0.1):
        """Phase 2: train through the crypto layer while Eve learns to attack.

        Each epoch samples up to 8 messages (with replacement) and a fresh random
        key per message, then:
          1. updates Alice/Bob on reconstruction loss through encrypt -> decrypt
             (gradients clipped to norm 1.0 per network);
          2. updates Eve on the detached ciphertext;
          3. after ``adversarial_start`` epochs, updates Alice/Bob to increase
             Eve's loss, weighted by ``adversarial_weight``.

        Args:
            messages: sequence of strings to sample from.
            epochs: number of sampled batches.
            adversarial_start: step 3 runs only when ``epoch > adversarial_start``.
            adversarial_weight: multiplier on the (negated) Eve loss in step 3.
        """
        self._train_mode()
        print("\n" + "="*70)
        print("PHASE 2: GENTLE ENCRYPTION TRAINING")
        print("="*70)

        batch_size = min(8, len(messages))
        for epoch in range(epochs):
            batch_msgs = np.random.choice(messages, batch_size, replace=True)
            tokens = self.processor.batch_encode(batch_msgs).to(self.device)

            keys = torch.stack(
                [KeyGenerator.generate_random(self.embed_dim) for _ in batch_msgs]
            ).to(self.device)

            # === Train Alice+Bob ===
            self.opt_main.zero_grad()

            embeddings = self.autoencoder(tokens, return_embeddings=True)
            encrypted = self.crypto_layer.encrypt(embeddings, keys)
            decrypted = self.crypto_layer.decrypt(encrypted, keys)
            logits = self.autoencoder.decoder(decrypted)

            loss_reconstruction = self.criterion(logits.view(-1, logits.size(-1)), tokens.view(-1))
            loss_reconstruction.backward()
            torch.nn.utils.clip_grad_norm_(self.autoencoder.parameters(), 1.0)
            torch.nn.utils.clip_grad_norm_(self.crypto_layer.parameters(), 1.0)
            self.opt_main.step()

            # === Train Eve on ciphertext only (no gradient into Alice) ===
            self.opt_eve.zero_grad()

            with torch.no_grad():
                embeddings = self.autoencoder(tokens, return_embeddings=True)
                encrypted = self.crypto_layer.encrypt(embeddings, keys)

            eve_logits = self.eve(encrypted)
            loss_eve = self.criterion(eve_logits.view(-1, eve_logits.size(-1)), tokens.view(-1))
            loss_eve.backward()
            self.opt_eve.step()

            # === Small adversarial push, delayed until Bob is stable ===
            if epoch > adversarial_start:
                self.opt_main.zero_grad()

                embeddings = self.autoencoder(tokens, return_embeddings=True)
                encrypted = self.crypto_layer.encrypt(embeddings, keys)
                eve_attack = self.eve(encrypted)

                loss_adversarial = -self.criterion(eve_attack.view(-1, eve_attack.size(-1)), tokens.view(-1))
                # Eve's parameters also receive gradients here; opt_eve.zero_grad()
                # clears them before Eve's next update.
                (loss_adversarial * adversarial_weight).backward()
                self.opt_main.step()

            if epoch % 10 == 0:
                bob_pred = torch.argmax(logits, dim=-1)
                eve_pred = torch.argmax(eve_logits, dim=-1)

                bob_acc = (bob_pred == tokens).float().mean().item()
                eve_acc = (eve_pred == tokens).float().mean().item()
                ratio = loss_eve.item() / (loss_reconstruction.item() + 1e-8)

                print(f"Epoch {epoch:3d} | Bob: {loss_reconstruction.item():.3f} ({bob_acc*100:.1f}%) | "
                      f"Eve: {loss_eve.item():.3f} ({eve_acc*100:.1f}%) | Ratio: {ratio:.2f}x")

        print("\n✓ Phase 2 Complete!")

    # ---------------- inference ----------------
    def encrypt_message(self, message, key=None):
        """Encrypt one string with dropout disabled.

        Args:
            message: plaintext string (truncated/padded by ``self.processor``).
            key: optional 1-D key with ``embed_dim`` entries. If omitted, a key is
                derived from ``message`` via ``KeyGenerator.generate_from_message``.

        Returns:
            (encrypted, key): ciphertext of shape (1, max_len, embed_dim) and the
            key that was used, moved to ``self.device``.
        """
        if key is None:
            key = KeyGenerator.generate_from_message(message, self.embed_dim)

        tokens = self.processor.encode(message).unsqueeze(0).to(self.device)
        key = key.to(self.device)

        previous = self._enter_eval_mode()
        try:
            with torch.no_grad():
                embeddings = self.autoencoder(tokens, return_embeddings=True)
                encrypted = self.crypto_layer.encrypt(embeddings, key)
        finally:
            self._restore_modes(previous)

        return encrypted, key

    def decrypt_message(self, encrypted, key):
        """Decrypt ``encrypted`` with ``key`` and decode the first batch row to text."""
        key = key.to(self.device)

        previous = self._enter_eval_mode()
        try:
            with torch.no_grad():
                decrypted = self.crypto_layer.decrypt(encrypted, key)
                logits = self.autoencoder.decoder(decrypted)
                tokens = torch.argmax(logits, dim=-1)
                message = self.processor.decode(tokens[0])
        finally:
            self._restore_modes(previous)

        return message

    def eve_attack(self, encrypted):
        """Return Eve's decoded guess for the first batch row of ``encrypted``."""
        previous = self._enter_eval_mode()
        try:
            with torch.no_grad():
                logits = self.eve(encrypted)
                tokens = torch.argmax(logits, dim=-1)
                message = self.processor.decode(tokens[0])
        finally:
            self._restore_modes(previous)

        return message

    def evaluate(self, test_messages, max_examples=10):
        """Print Bob / Eve / wrong-key similarity for up to ``max_examples`` messages.

        Similarity is ``difflib.SequenceMatcher.ratio()`` against the original.
        Key sensitivity is ``1 -`` the mean similarity over 3 random wrong keys.
        """
        print("\n" + "="*70)
        print("FINAL EVALUATION")
        print("="*70)

        sample = test_messages[:max_examples]
        if len(sample) == 0:
            print("No test messages to evaluate.")
            print("="*70)
            return

        bob_sims = []
        eve_sims = []
        key_sens = []

        for msg in sample:
            encrypted, correct_key = self.encrypt_message(msg)

            bob_msg = self.decrypt_message(encrypted, correct_key)
            eve_msg = self.eve_attack(encrypted)

            wrong_sims = []
            for _ in range(3):
                wrong_key = KeyGenerator.generate_random(self.embed_dim)
                wrong_msg = self.decrypt_message(encrypted, wrong_key)
                wrong_sims.append(SequenceMatcher(None, msg, wrong_msg).ratio())

            bob_sim = SequenceMatcher(None, msg, bob_msg).ratio()
            eve_sim = SequenceMatcher(None, msg, eve_msg).ratio()
            avg_wrong = np.mean(wrong_sims)

            bob_sims.append(bob_sim)
            eve_sims.append(eve_sim)
            key_sens.append(1 - avg_wrong)

            print(f"\nOriginal:  '{msg}'")
            print(f"Bob:       '{bob_msg}' ({bob_sim*100:.1f}%)")
            print(f"Eve:       '{eve_msg}' ({eve_sim*100:.1f}%)")
            print(f"Wrong key: Avg {avg_wrong*100:.1f}% similarity")

        avg_bob = np.mean(bob_sims)
        avg_eve = np.mean(eve_sims)
        avg_key_sens = np.mean(key_sens)
        security_ratio = avg_bob / max(avg_eve, 0.01)

        print("\n" + "="*70)
        print("FINAL METRICS")
        print("="*70)
        print(f"Bob Similarity:    {avg_bob*100:.1f}% {'✓' if avg_bob > 0.90 else '✗'}")
        print(f"Eve Similarity:    {avg_eve*100:.1f}% {'✓' if avg_eve < 0.30 else '⚠️'}")
        print(f"Key Sensitivity:   {avg_key_sens*100:.1f}% {'✓' if avg_key_sens > 0.50 else '⚠️'}")
        print(f"Security Ratio:    {security_ratio:.2f}x {'✓' if security_ratio > 3.0 else '⚠️'}")

        if avg_bob > 0.90:
            print("\n✓ Bob: EXCELLENT decryption!")
        if avg_eve < 0.20:
            print("✓ Eve: CANNOT break encryption!")
        if avg_key_sens > 0.60:
            print("✓ Keys: Good sensitivity!")

        if avg_bob > 0.90 and security_ratio > 3:
            print("\n🎉 SUCCESS! System works well!")

        print("="*70)


# ============ Dataset ============
# Small in-memory corpus of 54 short English sentences. The __main__ cell below
# uses it for Phase 1, Phase 2 AND evaluation.
# NOTE(review): there is no held-out split, so evaluation measures memorisation of
# these exact strings rather than generalisation.
LARGE_DATASET = [
    "Hello World!", "This is a test.", "Secret message here.",
    "Encryption works!", "Neural crypto system.", "Testing ABC 123.",
    "Quick brown fox.", "The lazy dog jumps.",
    "Machine learning is powerful.", "Deep neural networks.", "Artificial intelligence evolves.",
    "Natural language processing.", "Computer vision tasks.", "Reinforcement learning agent.",
    "Gradient descent optimizer.", "Backpropagation algorithm.", "Model accuracy improves.",
    "Training loss decreases.", "Validation metrics good.", "Test results excellent.",
    "Good morning everyone.", "How are you today?", "See you tomorrow.",
    "Thank you very much.", "Great job well done.", "Nice work keep going.",
    "Data science project.", "Python programming fun.", "Code quality matters.",
    "Documentation complete.", "Production ready now.", "System performance optimal.",
    "Hi there!", "Goodbye!", "Yes indeed.", "No problem.", "Of course.",
    "Absolutely right.", "Definitely true.", "Maybe later.", "Not now.", "Soon enough.",
    "The sun rises early.", "Birds sing beautifully.", "Rivers flow downstream.",
    "Mountains stand tall.", "Oceans are deep.", "Stars shine bright.",
    "Music sounds wonderful.", "Books tell stories.", "Art inspires people.",
    "Science explains nature.", "Math solves problems.", "History teaches lessons.",
]


# ============ Main ============
if __name__ == "__main__":
    # Seed torch (all devices) and NumPy's global RNG so a fresh-kernel run is
    # repeatable (modulo nondeterministic CUDA kernels).
    SEED = 42
    torch.manual_seed(SEED)
    np.random.seed(SEED)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")
    print(f"Dataset: {len(LARGE_DATASET)} messages\n")

    processor = StringProcessor()
    system = NeuralCryptoSystem(processor.vocab_size, embed_dim=128, device=device)

    success = system.train_phase1_reconstruction(LARGE_DATASET, epochs=200)

    if success:
        system.train_phase2_with_encryption(LARGE_DATASET, epochs=100)
        system.evaluate(LARGE_DATASET)
    else:
        # Previously this case skipped Phase 2 and evaluation without any message.
        print("\n✗ Phase 1 accuracy did not exceed 95% - skipping Phase 2 and evaluation.")

Device: cuda
Dataset: 54 messages


PHASE 1: RECONSTRUCTION TRAINING (NO ENCRYPTION)
Epoch   0 | Loss: 0.6427 | Acc: 89.6%
Epoch  20 | Loss: 0.0003 | Acc: 100.0%
Epoch  40 | Loss: 0.0001 | Acc: 100.0%
Epoch  60 | Loss: 0.0000 | Acc: 100.0%
Epoch  80 | Loss: 0.0000 | Acc: 100.0%
Epoch 100 | Loss: 0.0000 | Acc: 100.0%
Epoch 120 | Loss: 0.0000 | Acc: 100.0%
Epoch 140 | Loss: 0.0000 | Acc: 100.0%
Epoch 160 | Loss: 0.0000 | Acc: 100.0%
Epoch 180 | Loss: 0.0000 | Acc: 100.0%

✓ Phase 1 Complete! Accuracy: 100.0%

PHASE 2: GENTLE ENCRYPTION TRAINING
Epoch   0 | Bob: 3.323 (53.5%) | Eve: 3.973 (11.1%) | Ratio: 1.20x
Epoch  10 | Bob: 0.020 (99.8%) | Eve: 2.016 (71.3%) | Ratio: 99.43x
Epoch  20 | Bob: 0.011 (99.8%) | Eve: 1.599 (75.4%) | Ratio: 146.77x
Epoch  30 | Bob: 0.001 (100.0%) | Eve: 1.112 (88.9%) | Ratio: 919.27x
Epoch  40 | Bob: 0.015 (99.6%) | Eve: 1.170 (80.7%) | Ratio: 75.55x
Epoch  50 | Bob: 0.014 (99.8%) | Eve: 1.440 (69.5%) | Ratio: 102.75x
Epoch  60 | Bob: 0.005 (100.0%) | Eve: 1

In [3]:
"""
NEURAL CRYPTO SYSTEM WITH LARGE DATASET SUPPORT
Supports multiple open-source datasets from Hugging Face
"""

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from difflib import SequenceMatcher
import hashlib

# ============ String Processor ============
class StringProcessor:
    """Character-level tokenizer that produces fixed-length id sequences.

    Each character of the alphabet (which includes single and double quotes)
    gets its index in that string as its id. Two special tokens follow: ``<PAD>``
    (id ``len(alphabet)``) and ``<END>`` (id ``len(alphabet) + 1``). Characters
    outside the alphabet encode as id 0, which is the space character.
    """

    def __init__(self, max_len=64):
        self.max_len = max_len
        alphabet = ' abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789.,!?-\'"'
        self.vocab = dict(zip(alphabet, range(len(alphabet))))
        self.vocab['<PAD>'] = len(alphabet)
        self.vocab['<END>'] = len(alphabet) + 1
        self.inv_vocab = {idx: token for token, idx in self.vocab.items()}
        self.vocab_size = len(self.vocab)

    def encode(self, text):
        """Encode ``text`` into a LongTensor of exactly ``max_len`` ids.

        The layout is up to ``max_len - 1`` character ids, one ``<END>``, then
        ``<PAD>`` filler.
        """
        end_id = self.vocab['<END>']
        pad_id = self.vocab['<PAD>']
        ids = [self.vocab.get(ch, 0) for ch in text[:self.max_len - 1]] + [end_id]
        ids += [pad_id] * (self.max_len - len(ids))
        return torch.LongTensor(ids)

    def decode(self, ids):
        """Turn ids back into text.

        Decoding stops at the first ``<END>`` and skips ``<PAD>``. Unknown ids
        decode to the empty string.
        """
        pieces = []
        for token_id in ids:
            symbol = self.inv_vocab.get(int(token_id), '')
            if symbol == '<END>':
                break
            if symbol == '<PAD>':
                continue
            pieces.append(symbol)
        return ''.join(pieces)

    def batch_encode(self, texts):
        """Encode every text and stack the results into a (len(texts), max_len) LongTensor."""
        return torch.stack(list(map(self.encode, texts)))


# ============ Dataset Loader ============
class DatasetLoader:
    """Load short text samples for training; fall back to a built-in list on failure."""

    # name -> (positional args for datasets.load_dataset, text field, label for the log line)
    _HF_SOURCES = {
        'imdb': (('imdb',), 'text', 'movie reviews from IMDB'),
        'ag_news': (('ag_news',), 'text', 'news articles from AG News'),
        'yelp': (('yelp_review_full',), 'text', 'restaurant reviews from Yelp'),
        'sst2': (('glue', 'sst2'), 'sentence', 'sentences from SST-2'),
        'tweets': (('tweet_eval', 'sentiment'), 'text', 'tweets'),
        'news': (('Fraser/news-category-dataset',), 'headline', 'news headlines'),
    }

    @staticmethod
    def load_dataset(dataset_name='imdb', max_samples=10000, max_len=64):
        """
        Load up to ``max_samples`` texts from a Hugging Face dataset (train split).

        Available datasets:
        - 'imdb': Movie reviews (Hugging Face)
        - 'ag_news': News articles (Hugging Face)
        - 'yelp': Restaurant reviews (Hugging Face)
        - 'sst2': Sentiment analysis (Hugging Face)
        - 'tweets': Twitter sentiment (Hugging Face)
        - 'wikitext': Wikipedia articles (Hugging Face)
        - 'news': News headlines (Hugging Face)

        Non-wikitext sources take the FIRST ``max_samples`` rows in stored order and
        truncate each to ``max_len`` characters.
        NOTE(review): splits stored grouped by label (IMDB's train split, as far as
        I know) will therefore not be class-balanced; consider
        ``dataset.shuffle(seed=...)`` if that matters.

        Returns:
            np.ndarray of stripped strings with 5 <= len <= max_len. Returns
            ``get_default_dataset()`` if ``datasets`` is missing, the name is
            unknown, loading fails, or nothing survives cleaning.
        """
        print(f"\nLoading dataset: {dataset_name}")
        print(f"Max samples: {max_samples}")

        try:
            from datasets import load_dataset as hf_load_dataset

            if dataset_name == 'wikitext':
                dataset = hf_load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
                texts = DatasetLoader._wikitext_sentences(dataset, max_samples, max_len)
                print(f"✓ Loaded {len(texts)} Wikipedia sentences")
            elif dataset_name in DatasetLoader._HF_SOURCES:
                hf_args, field, label = DatasetLoader._HF_SOURCES[dataset_name]
                dataset = hf_load_dataset(*hf_args, split='train')
                n_rows = min(max_samples, len(dataset))
                texts = [item[field][:max_len] for item in dataset.select(range(n_rows))]
                print(f"✓ Loaded {len(texts)} {label}")
            else:
                print(f"⚠️ Unknown dataset: {dataset_name}")
                return DatasetLoader.get_default_dataset()

            cleaned_texts = DatasetLoader._clean_texts(texts, max_len)
            print(f"✓ After cleaning: {len(cleaned_texts)} valid texts")
            if not cleaned_texts:
                # np.array([]) would be a float array and break the training loops.
                print("⚠️ No usable texts; using default dataset instead...\n")
                return DatasetLoader.get_default_dataset()
            return np.array(cleaned_texts)

        except ImportError:
            print("⚠️ Hugging Face datasets not installed!")
            print("Install it with: pip install datasets")
            print("\nUsing default dataset instead...\n")
            return DatasetLoader.get_default_dataset()

        except Exception as e:
            # Deliberate best-effort: any loading failure falls back to the built-in list.
            print(f"⚠️ Error loading dataset: {e}")
            print("Using default dataset instead...\n")
            return DatasetLoader.get_default_dataset()

    @staticmethod
    def _clean_texts(texts, max_len, max_samples=None):
        """Strip each text and keep those with 5 <= len <= max_len.

        Stops after ``max_samples`` kept texts if it is given.
        """
        cleaned = []
        for text in texts:
            text = text.strip()
            if 5 <= len(text) <= max_len:
                cleaned.append(text)
                if max_samples is not None and len(cleaned) >= max_samples:
                    break
        return cleaned

    @staticmethod
    def _wikitext_sentences(dataset, max_samples, max_len):
        """Split WikiText lines longer than 10 chars on '. '.

        Keeps sentences of 10..max_len characters and stops once ``max_samples``
        have been collected.
        """
        texts = []
        for item in dataset:
            text = item['text'].strip()
            if len(text) <= 10:
                continue
            for sent in text.split('. '):
                if 10 <= len(sent) <= max_len:
                    texts.append(sent)
                    if len(texts) >= max_samples:
                        return texts
        return texts

    @staticmethod
    def get_default_dataset():
        """Fallback to default dataset if loading fails"""
        return np.array([
            "Hello World!", "This is a test.", "Secret message here.",
            "Encryption works!", "Neural crypto system.", "Testing ABC 123.",
            "Quick brown fox.", "The lazy dog jumps.",
            "Machine learning is powerful.", "Deep neural networks.",
            "Artificial intelligence evolves.", "Natural language processing.",
            "Computer vision tasks.", "Reinforcement learning agent.",
            "Gradient descent optimizer.", "Backpropagation algorithm.",
            "Model accuracy improves.", "Training loss decreases.",
            "Validation metrics good.", "Test results excellent.",
            "Good morning everyone.", "How are you today?",
            "See you tomorrow.", "Thank you very much.",
            "Great job well done.", "Nice work keep going.",
            "Data science project.", "Python programming fun.",
            "Code quality matters.", "Documentation complete.",
            "Production ready now.", "System performance optimal.",
        ])

    @staticmethod
    def download_text_file(url, max_samples=10000, max_len=64, timeout=30):
        """Download a text file and return its lines with 5 <= len <= max_len.

        Args:
            url: address of a plain-text file.
            max_samples: stop after this many kept lines.
            max_len: maximum kept line length, after stripping.
            timeout: seconds passed to ``requests.get``, so a dead server cannot
                hang the notebook.

        Returns:
            np.ndarray of strings, or ``get_default_dataset()`` on any error,
            including an HTTP error status or an empty result.
        """
        try:
            import requests
            print(f"\nDownloading from: {url}")
            response = requests.get(url, timeout=timeout)
            # Without this, an HTTP error page would be parsed as training text.
            response.raise_for_status()

            texts = DatasetLoader._clean_texts(response.text.split('\n'), max_len, max_samples)
            print(f"✓ Loaded {len(texts)} lines from URL")
            if not texts:
                print("⚠️ No usable lines; using default dataset instead...")
                return DatasetLoader.get_default_dataset()
            return np.array(texts)

        except Exception as e:
            print(f"⚠️ Error downloading: {e}")
            return DatasetLoader.get_default_dataset()


# ============ Autoencoder ============
class CryptoAutoencoder(nn.Module):
    """Position-wise token autoencoder: embedding -> encoder MLP -> decoder MLP.

    ``encoder`` maps each token embedding to an ``embed_dim`` vector; ``decoder``
    maps such a vector to ``vocab_size`` logits. Both MLPs share the same layer
    layout, so ``state_dict`` keys (``encoder.0.weight`` ... ``decoder.7.bias``)
    are fixed by ``_mlp``.
    """

    def __init__(self, vocab_size, embed_dim=128, hidden_dim=256):
        super().__init__()
        # Assumes the StringProcessor id layout, where <PAD> is the second-to-last id.
        pad_id = vocab_size - 2
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_id)
        self.encoder = self._mlp(embed_dim, hidden_dim, embed_dim)
        self.decoder = self._mlp(embed_dim, hidden_dim, vocab_size)

    @staticmethod
    def _mlp(in_dim, hidden_dim, out_dim):
        """Build a Linear-LN-ReLU-Dropout(0.1)-Linear-LN-ReLU-Linear stack."""
        return nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, tokens, return_embeddings=False):
        """Return decoder logits, or the encoder output if ``return_embeddings`` is set."""
        encoded = self.encoder(self.embedding(tokens))
        return encoded if return_embeddings else self.decoder(encoded)


# ============ Key Generator ============
class KeyGenerator:
    """Produce 1-D float32 key tensors for the key-dependent encryption layer."""

    @staticmethod
    def generate_from_message(message, key_size=128):
        """Deterministically derive a key of length ``key_size`` from ``message``.

        The first 32 bits of SHA-256(message) seed a private legacy
        ``np.random.RandomState``. This yields exactly the values the previous
        ``np.random.seed(seed); np.random.randn(key_size)`` produced, but without
        reseeding NumPy's global RNG, which other code draws from.

        NOTE(review): the key is a function of the plaintext and carries at most
        32 bits of entropy. Anyone who can guess the message can recompute it.
        """
        digest = hashlib.sha256(message.encode()).hexdigest()
        seed = int(digest[:8], 16)
        rng = np.random.RandomState(seed)
        return torch.FloatTensor(rng.randn(key_size))

    @staticmethod
    def generate_random(key_size=128):
        """Draw a standard-normal key of length ``key_size`` from torch's global RNG."""
        return torch.randn(key_size)


# ============ Key-Dependent Encryption ============
class KeyDependentEncryption(nn.Module):
    """Key-conditioned layer placed between the autoencoder's encoder and decoder.

    ``encrypt(embeddings, key)`` works in three steps:
      1. ``key_features = key_transform(key)``. It ends in tanh, so each entry is
         in (-1, 1). The features are broadcast to every sequence position.
      2. ``mixed = encrypt_mix([embeddings, key_features])``.
      3. ``encrypted = mixed * (1 + key_scale * key_features)``.

    ``decrypt(encrypted, key)`` divides the step-3 factor (plus ``eps``) back out
    and so returns ``mixed``. It does NOT undo ``encrypt_mix``; the downstream
    decoder has to learn to read ``mixed`` directly.

    With ``key_scale=2.0`` the factor lies in (-1, 3) and can be near zero, so
    decrypting with a wrong key may produce very large values.

    Args:
        embed_dim: size of the embeddings AND of the key. ``key_transform`` starts
            with ``nn.Linear(embed_dim, ...)``.
        key_scale: weight of the key features in the multiplicative factor
            (default 2.0, matching the previous hard-coded value).
        eps: constant added to the divisor in ``decrypt`` (default 1e-8).
    """

    def __init__(self, embed_dim=128, key_scale=2.0, eps=1e-8):
        super().__init__()
        self.key_scale = key_scale
        self.eps = eps
        self.key_transform = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 2),
            nn.Tanh(),
            nn.Linear(embed_dim * 2, embed_dim),
            nn.Tanh()
        )

        self.encrypt_mix = nn.Sequential(
            nn.Linear(embed_dim * 2, embed_dim),
            nn.Tanh()
        )

    def _key_features(self, key, batch_size, seq_len):
        """Return key features broadcast to shape (batch_size, seq_len, embed_dim).

        A 1-D ``key`` is shared by every batch row; a 2-D key is one key per row.
        """
        if key.dim() == 1:
            key = key.unsqueeze(0).expand(batch_size, -1)
        features = self.key_transform(key)
        return features.unsqueeze(1).expand(-1, seq_len, -1)

    def encrypt(self, embeddings, key):
        """Mix ``embeddings`` with the key and apply the key-dependent scale.

        Args:
            embeddings: (batch, seq_len, embed_dim) tensor.
            key: (embed_dim,) key shared by the batch, or (batch, embed_dim).

        Returns:
            Tensor of shape (batch, seq_len, embed_dim).
        """
        key_features = self._key_features(key, embeddings.size(0), embeddings.size(1))
        mixed = self.encrypt_mix(torch.cat([embeddings, key_features], dim=-1))
        return mixed * (1 + key_features * self.key_scale)

    def decrypt(self, encrypted, key):
        """Remove the key-dependent scale applied by ``encrypt``.

        The result is the ``encrypt_mix`` output, not the original embeddings.
        """
        key_features = self._key_features(key, encrypted.size(0), encrypted.size(1))
        return encrypted / (1 + key_features * self.key_scale + self.eps)


# ============ Eve Attacker ============
class EveAttacker(nn.Module):
    """Adversary that predicts token logits from encrypted embeddings alone (no key).

    The layer order inside ``attack_network`` fixes its ``state_dict`` keys
    (``attack_network.0.weight`` ... ``attack_network.11.bias``).
    """

    def __init__(self, vocab_size, embed_dim=128):
        super().__init__()
        wide = embed_dim * 2
        layers = [
            nn.Linear(embed_dim, wide), nn.LayerNorm(wide), nn.GELU(), nn.Dropout(0.3),
            nn.Linear(wide, wide), nn.LayerNorm(wide), nn.GELU(), nn.Dropout(0.3),
            nn.Linear(wide, embed_dim), nn.LayerNorm(embed_dim), nn.GELU(),
            nn.Linear(embed_dim, vocab_size),
        ]
        self.attack_network = nn.Sequential(*layers)

    def forward(self, encrypted_embeddings):
        """Return per-position vocabulary logits for ``encrypted_embeddings``."""
        logits = self.attack_network(encrypted_embeddings)
        return logits


# ============ Neural Crypto System ============
class NeuralCryptoSystem:
    """Autoencoder + key-dependent cipher (Alice/Bob) trained against an attacker (Eve).

    * ``autoencoder``  (CryptoAutoencoder) maps token ids to embeddings and back.
    * ``crypto_layer`` (KeyDependentEncryption) encrypts/decrypts those embeddings.
    * ``eve``          (EveAttacker) tries to recover token ids from ciphertext alone.

    ``opt_main`` updates the autoencoder and the cipher; ``opt_eve`` updates Eve only.
    The inference helpers (``encrypt_message``, ``decrypt_message``, ``eve_attack``)
    temporarily put every network in eval mode (disabling Dropout) and restore the
    previous train/eval flags afterwards.
    """

    def __init__(self, vocab_size, embed_dim=128, device='cuda'):
        """Build the three networks on ``device`` plus their optimizers.

        Args:
            vocab_size: number of token ids (``StringProcessor().vocab_size`` in main).
            embed_dim: embedding width shared by all networks; also the key length,
                since the cipher's key transform takes ``embed_dim`` inputs.
            device: torch device (string or ``torch.device``) for all networks.
        """
        self.device = device
        self.embed_dim = embed_dim
        self.processor = StringProcessor()

        self.autoencoder = CryptoAutoencoder(vocab_size, embed_dim).to(device)
        self.crypto_layer = KeyDependentEncryption(embed_dim).to(device)
        self.eve = EveAttacker(vocab_size, embed_dim).to(device)

        # Alice/Bob share one optimizer over autoencoder + cipher; Eve has her own.
        self.opt_main = optim.Adam(
            list(self.autoencoder.parameters()) + list(self.crypto_layer.parameters()),
            lr=0.001
        )
        self.opt_eve = optim.Adam(self.eve.parameters(), lr=0.0005)

        self.criterion = nn.CrossEntropyLoss()

    # ---------------- train/eval mode helpers ----------------
    def _networks(self):
        """The three networks, in a fixed order (used by the mode helpers)."""
        return (self.autoencoder, self.crypto_layer, self.eve)

    def _switch_mode(self, training):
        """Put every network in train (True) or eval (False) mode; return the previous flags."""
        previous = [net.training for net in self._networks()]
        for net in self._networks():
            net.train(training)
        return previous

    def _restore_mode(self, previous):
        """Restore train/eval flags previously returned by ``_switch_mode``."""
        for net, flag in zip(self._networks(), previous):
            net.train(flag)

    # ---------------- training ----------------
    def train_phase1_reconstruction(self, messages, epochs=200, batch_size=32):
        """Plain autoencoder training (no encryption) with shuffled mini-batches.

        Every epoch visits all of ``messages`` once. Loss and accuracy are averaged
        over every token position, including ``<PAD>`` positions.

        Args:
            messages: sequence of strings (list or numpy array).
            epochs: number of full passes over ``messages``.
            batch_size: mini-batch size.

        Returns:
            True if the mean token accuracy of the final epoch exceeds 95%.

        Raises:
            ValueError: if ``messages`` is empty or ``epochs`` < 1.
        """
        messages = np.asarray(messages)  # fancy indexing below needs an ndarray
        if len(messages) == 0:
            raise ValueError("train_phase1_reconstruction: 'messages' is empty")
        if epochs < 1:
            raise ValueError("train_phase1_reconstruction: 'epochs' must be >= 1")

        print("\n" + "="*70)
        print("PHASE 1: RECONSTRUCTION TRAINING (NO ENCRYPTION)")
        print("="*70)

        self.autoencoder.train()  # Dropout on, in case an earlier call left eval mode

        for epoch in range(epochs):
            # Shuffle data
            indices = np.random.permutation(len(messages))
            total_loss = 0
            total_acc = 0
            num_batches = 0

            # Mini-batch training
            for i in range(0, len(messages), batch_size):
                batch_indices = indices[i:i+batch_size]
                batch_msgs = messages[batch_indices]

                tokens = self.processor.batch_encode(batch_msgs).to(self.device)

                self.opt_main.zero_grad()
                logits = self.autoencoder(tokens)

                loss = self.criterion(logits.view(-1, logits.size(-1)), tokens.view(-1))
                loss.backward()
                self.opt_main.step()

                pred = torch.argmax(logits, dim=-1)
                acc = (pred == tokens).float().mean().item()

                total_loss += loss.item()
                total_acc += acc
                num_batches += 1

            if epoch % 20 == 0:
                avg_loss = total_loss / num_batches
                avg_acc = total_acc / num_batches
                print(f"Epoch {epoch:3d} | Loss: {avg_loss:.4f} | Acc: {avg_acc*100:.1f}%")

        final_acc = total_acc / num_batches
        print(f"\n✓ Phase 1 Complete! Accuracy: {final_acc*100:.1f}%")
        return final_acc > 0.95

    def train_phase2_with_encryption(self, messages, epochs=100, batch_size=8):
        """Jointly train the cipher (Alice/Bob) and the attacker (Eve).

        Each "epoch" is one optimisation step on one random batch of
        ``min(batch_size, len(messages))`` messages, each with a fresh random key:

        1. Alice/Bob minimise reconstruction loss through encrypt -> decrypt.
        2. Eve minimises her cross-entropy on ciphertext (computed under no_grad).
        3. After epoch 30, Alice/Bob also step on ``-0.1 * Eve's loss``.

        Losses, token accuracies and the Eve/Bob loss ratio are printed every 10 epochs.

        Args:
            messages: sequence of strings (list or numpy array).
            epochs: number of steps.
            batch_size: messages per step.

        Raises:
            ValueError: if ``messages`` is empty.
        """
        messages = np.asarray(messages)  # fancy indexing below needs an ndarray
        if len(messages) == 0:
            raise ValueError("train_phase2_with_encryption: 'messages' is empty")

        print("\n" + "="*70)
        print("PHASE 2: ENCRYPTION TRAINING")
        print("="*70)

        self._switch_mode(True)  # Dropout on for all networks during training

        for epoch in range(epochs):
            # Sample random batch
            batch_indices = np.random.choice(len(messages), min(batch_size, len(messages)), replace=False)
            batch_msgs = messages[batch_indices]
            tokens = self.processor.batch_encode(batch_msgs).to(self.device)

            # key length must match the cipher's key_transform input (embed_dim)
            keys = torch.stack([KeyGenerator.generate_random(self.embed_dim) for _ in batch_msgs]).to(self.device)

            # Train Alice+Bob
            self.opt_main.zero_grad()
            embeddings = self.autoencoder(tokens, return_embeddings=True)
            encrypted = self.crypto_layer.encrypt(embeddings, keys)
            decrypted = self.crypto_layer.decrypt(encrypted, keys)
            logits = self.autoencoder.decoder(decrypted)

            loss_reconstruction = self.criterion(logits.view(-1, logits.size(-1)), tokens.view(-1))
            loss_reconstruction.backward()
            torch.nn.utils.clip_grad_norm_(self.autoencoder.parameters(), 1.0)
            torch.nn.utils.clip_grad_norm_(self.crypto_layer.parameters(), 1.0)
            self.opt_main.step()

            # Train Eve (ciphertext is a constant for her: no grads into Alice/Bob)
            self.opt_eve.zero_grad()
            with torch.no_grad():
                embeddings = self.autoencoder(tokens, return_embeddings=True)
                encrypted = self.crypto_layer.encrypt(embeddings, keys)

            eve_logits = self.eve(encrypted)
            loss_eve = self.criterion(eve_logits.view(-1, eve_logits.size(-1)), tokens.view(-1))
            loss_eve.backward()
            self.opt_eve.step()

            # Adversarial training: Alice/Bob maximise Eve's loss (weight 0.1)
            if epoch > 30:
                self.opt_main.zero_grad()
                embeddings = self.autoencoder(tokens, return_embeddings=True)
                encrypted = self.crypto_layer.encrypt(embeddings, keys)
                eve_attack = self.eve(encrypted)

                loss_adversarial = -self.criterion(eve_attack.view(-1, eve_attack.size(-1)), tokens.view(-1))
                (loss_adversarial * 0.1).backward()
                self.opt_main.step()

            if epoch % 10 == 0:
                bob_pred = torch.argmax(logits, dim=-1)
                eve_pred = torch.argmax(eve_logits, dim=-1)

                bob_acc = (bob_pred == tokens).float().mean().item()
                eve_acc = (eve_pred == tokens).float().mean().item()
                ratio = loss_eve.item() / (loss_reconstruction.item() + 1e-8)

                print(f"Epoch {epoch:3d} | Bob: {loss_reconstruction.item():.3f} ({bob_acc*100:.1f}%) | "
                      f"Eve: {loss_eve.item():.3f} ({eve_acc*100:.1f}%) | Ratio: {ratio:.2f}x")

        print("\n✓ Phase 2 Complete!")

    # ---------------- inference ----------------
    def encrypt_message(self, message, key=None):
        """Encrypt one string and return ``(ciphertext, key)``.

        When ``key`` is None it is derived with
        ``KeyGenerator.generate_from_message(message, embed_dim)``.
        Runs without gradients and with every network in eval mode.
        """
        if key is None:
            key = KeyGenerator.generate_from_message(message, self.embed_dim)

        tokens = self.processor.encode(message).unsqueeze(0).to(self.device)
        key = key.to(self.device)

        previous = self._switch_mode(False)
        try:
            with torch.no_grad():
                embeddings = self.autoencoder(tokens, return_embeddings=True)
                encrypted = self.crypto_layer.encrypt(embeddings, key)
        finally:
            self._restore_mode(previous)

        return encrypted, key

    def decrypt_message(self, encrypted, key):
        """Decrypt ciphertext with ``key`` and decode the first batch row to a string.

        Runs without gradients and with every network in eval mode.
        """
        key = key.to(self.device)

        previous = self._switch_mode(False)
        try:
            with torch.no_grad():
                decrypted = self.crypto_layer.decrypt(encrypted, key)
                logits = self.autoencoder.decoder(decrypted)
                tokens = torch.argmax(logits, dim=-1)
                message = self.processor.decode(tokens[0])
        finally:
            self._restore_mode(previous)

        return message

    def eve_attack(self, encrypted):
        """Eve's best guess (first batch row) at the plaintext, without any key.

        Runs without gradients and with every network in eval mode.
        """
        previous = self._switch_mode(False)
        try:
            with torch.no_grad():
                logits = self.eve(encrypted)
                tokens = torch.argmax(logits, dim=-1)
                message = self.processor.decode(tokens[0])
        finally:
            self._restore_mode(previous)

        return message

    def evaluate(self, test_messages):
        """Print Bob / Eve / wrong-key similarity for up to 10 random messages, plus summary metrics.

        Similarity is ``difflib.SequenceMatcher.ratio`` against the original text.
        Key sensitivity is ``1 - mean similarity`` over 3 random wrong keys.
        """
        print("\n" + "="*70)
        print("FINAL EVALUATION")
        print("="*70)

        bob_sims = []
        eve_sims = []
        key_sens = []

        # Evaluate on random sample
        eval_msgs = np.random.choice(test_messages, min(10, len(test_messages)), replace=False)

        for msg in eval_msgs:
            encrypted, correct_key = self.encrypt_message(msg)

            bob_msg = self.decrypt_message(encrypted, correct_key)
            eve_msg = self.eve_attack(encrypted)

            wrong_sims = []
            for _ in range(3):
                wrong_key = KeyGenerator.generate_random(self.embed_dim)
                wrong_msg = self.decrypt_message(encrypted, wrong_key)
                wrong_sims.append(SequenceMatcher(None, msg, wrong_msg).ratio())

            bob_sim = SequenceMatcher(None, msg, bob_msg).ratio()
            eve_sim = SequenceMatcher(None, msg, eve_msg).ratio()
            avg_wrong = np.mean(wrong_sims)

            bob_sims.append(bob_sim)
            eve_sims.append(eve_sim)
            key_sens.append(1 - avg_wrong)

            print(f"\nOriginal:  '{msg[:50]}...'")
            print(f"Bob:       '{bob_msg[:50]}...' ({bob_sim*100:.1f}%)")
            print(f"Eve:       '{eve_msg[:50]}...' ({eve_sim*100:.1f}%)")
            print(f"Wrong key: Avg {avg_wrong*100:.1f}% similarity")

        avg_bob = np.mean(bob_sims)
        avg_eve = np.mean(eve_sims)
        avg_key_sens = np.mean(key_sens)
        security_ratio = avg_bob / max(avg_eve, 0.01)

        print("\n" + "="*70)
        print("FINAL METRICS")
        print("="*70)
        print(f"Bob Similarity:    {avg_bob*100:.1f}% {'✓' if avg_bob > 0.90 else '✗'}")
        print(f"Eve Similarity:    {avg_eve*100:.1f}% {'✓' if avg_eve < 0.30 else '⚠️'}")
        print(f"Key Sensitivity:   {avg_key_sens*100:.1f}% {'✓' if avg_key_sens > 0.50 else '⚠️'}")
        print(f"Security Ratio:    {security_ratio:.2f}x {'✓' if security_ratio > 3.0 else '⚠️'}")

        if avg_bob > 0.90:
            print("\n✓ Bob: EXCELLENT decryption!")
        if avg_eve < 0.20:
            print("✓ Eve: CANNOT break encryption!")
        if avg_key_sens > 0.60:
            print("✓ Keys: Good sensitivity!")

        if avg_bob > 0.90 and security_ratio > 3:
            print("\n🎉 SUCCESS! System works well!")

        print("="*70)


# ============ Main ============
if __name__ == "__main__":
    # Seed numpy (shuffling, batch sampling, evaluation sample) and torch (weight
    # init, dropout; torch.manual_seed also seeds CUDA) so re-runs are reproducible.
    # NOTE(review): any other RNG used inside DatasetLoader / KeyGenerator is not
    # covered here — confirm.
    SEED = 42
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    # ============ CHOOSE YOUR DATASET ============
    # Option 1: Load from Hugging Face (recommended)
    DATASET = DatasetLoader.load_dataset(
        dataset_name='imdb',      # Options: 'imdb', 'ag_news', 'yelp', 'sst2', 'tweets', 'wikitext', 'news'
        max_samples=10000,         # Number of samples to load
        max_len=64                 # Maximum text length
    )

    # Option 2: Load from URL (uncomment to use)
    # DATASET = DatasetLoader.download_text_file(
    #     url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt',
    #     max_samples=10000,
    #     max_len=64
    # )

    # Option 3: Use default dataset
    # DATASET = DatasetLoader.get_default_dataset()

    print(f"\nFinal dataset size: {len(DATASET)} messages\n")

    # Initialize system
    processor = StringProcessor()
    system = NeuralCryptoSystem(processor.vocab_size, embed_dim=128, device=device)

    # Train
    success = system.train_phase1_reconstruction(DATASET, epochs=100, batch_size=32)

    if success:
        system.train_phase2_with_encryption(DATASET, epochs=100, batch_size=8)
        system.evaluate(DATASET)
    else:
        # Make the early stop visible instead of ending silently.
        print("\n✗ Phase 1 accuracy did not exceed 95% - skipping Phase 2 and evaluation.")

Device: cuda

Loading dataset: imdb
Max samples: 10000


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]

plain_text/train-00000-of-00001.parquet:   0%|          | 0.00/21.0M [00:00<?, ?B/s]

plain_text/test-00000-of-00001.parquet:   0%|          | 0.00/20.5M [00:00<?, ?B/s]

plain_text/unsupervised-00000-of-00001.p(…):   0%|          | 0.00/42.0M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating unsupervised split:   0%|          | 0/50000 [00:00<?, ? examples/s]

✓ Loaded 10000 movie reviews from IMDB
✓ After cleaning: 10000 valid texts

Final dataset size: 10000 messages


PHASE 1: RECONSTRUCTION TRAINING (NO ENCRYPTION)
Epoch   0 | Loss: 0.1110 | Acc: 98.7%
Epoch  20 | Loss: 0.0000 | Acc: 100.0%
Epoch  40 | Loss: 0.0000 | Acc: 100.0%
Epoch  60 | Loss: 0.0000 | Acc: 100.0%
Epoch  80 | Loss: 0.0000 | Acc: 100.0%

✓ Phase 1 Complete! Accuracy: 100.0%

PHASE 2: ENCRYPTION TRAINING
Epoch   0 | Bob: 12.963 (0.8%) | Eve: 4.416 (0.4%) | Ratio: 0.34x
Epoch  10 | Bob: 0.016 (99.8%) | Eve: 2.984 (59.8%) | Ratio: 189.57x
Epoch  20 | Bob: 0.011 (99.6%) | Eve: 2.180 (75.4%) | Ratio: 192.19x
Epoch  30 | Bob: 0.002 (100.0%) | Eve: 1.837 (74.8%) | Ratio: 1020.23x
Epoch  40 | Bob: 0.005 (99.8%) | Eve: 1.624 (72.9%) | Ratio: 351.04x
Epoch  50 | Bob: 0.026 (99.4%) | Eve: 1.519 (72.5%) | Ratio: 58.21x
Epoch  60 | Bob: 0.056 (98.4%) | Eve: 2.225 (48.2%) | Ratio: 39.64x
Epoch  70 | Bob: 0.033 (99.6%) | Eve: 2.616 (40.4%) | Ratio: 78.47x
Epoch  80 | Bob: 0.010 (99.8

In [7]:
# Strip outputs and metadata from this notebook file in place.
# NOTE(review): the captured output of this cell is nbconvert's help text, which
# suggests the notebook was not processed — possibly `nn1.ipynb` is not found
# relative to the kernel's working directory (e.g. on Colab). Confirm the path.
!jupyter nbconvert --clear-output --ClearMetadataPreprocessor.enabled=True --inplace nn1.ipynb

This application is used to convert notebook files (*.ipynb)
        to various other formats.


Options
The options below are convenience aliases to configurable class-options,
as listed in the "Equivalent to" description-line of the aliases.
To see all configurable class-options for some <cmd>, use:
    <cmd> --help-all

--debug
    set log level to logging.DEBUG (maximize logging output)
    Equivalent to: [--Application.log_level=10]
--show-config
    Show the application's configuration (human-readable format)
    Equivalent to: [--Application.show_config=True]
--show-config-json
    Show the application's configuration (json format)
    Equivalent to: [--Application.show_config_json=True]
--generate-config
    generate default config file
    Equivalent to: [--JupyterApp.generate_config=True]
-y
    Answer yes to any questions instead of prompting.
    Equivalent to: [--JupyterApp.answer_yes=True]
--execute
    Execute the notebook prior to export.
    Equivalent to: [--ExecutePr