In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import random, string, math, os, re
import pandas as pd
from tqdm import tqdm
from typing import Tuple

# Configuration

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
CHARS = string.ascii_uppercase
# Extra token/class id (26) used for padding and for every non A–Z character.
PAD_LABEL = len(CHARS)
A2I = {c: i for i, c in enumerate(CHARS)}
I2A = {i: c for i, c in enumerate(CHARS)}
I2A[PAD_LABEL] = "_"
VOCAB_SIZE = len(CHARS)
K_MAX = 16  # max key length

# Model and training config
EPOCHS = 100
BATCH_SIZE = 128
LR = 5e-4
NUM_ENCODER_BLOCKS = 6
EMBED_DIM = 256
NHEAD = 8
HIDDEN_DIM = 512
N_SAMPLES = 10000 # Training samples
N_TEST_CASES = 500 # Evaluation samples
CHECKPOINT_DIR = "checkpoints"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Seed every RNG the notebook uses (data generation uses `random`, model init/shuffling
# use torch) so that Restart & Run All reproduces the same datasets and results.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

print(f"Device: {DEVICE}")

# ============================================================
# ENGLISH CORPUS LOADING
# ============================================================
def get_english_corpus(url="https://www.gutenberg.org/cache/epub/1661/pg1661.txt",
                       max_words=500000, timeout=20):
    """Return an English corpus as one uppercase string of A–Z words separated by spaces.

    Downloads ``url`` (default: a Project Gutenberg plain-text book), upper-cases it,
    replaces every character outside A–Z and space with a space, and keeps the first
    ``max_words`` words. If ``requests`` is unavailable, the request fails or returns
    an HTTP error status, or the download contains no words, a small built-in word
    list repeated 1000 times is returned instead.

    Args:
        url: plain-text source to download.
        max_words: maximum number of words kept from the download.
        timeout: request timeout in seconds.

    Returns:
        The corpus string (never empty).
    """
    try:
        import requests
        print("📖 Downloading English corpus...")
        response = requests.get(url, timeout=timeout)
        # Without this, an HTTP error page (404/503 HTML) would silently become the corpus.
        response.raise_for_status()
        text = re.sub(r'[^A-Z ]', ' ', response.text.upper())
        words = text.split()
        if not words:
            raise ValueError("downloaded corpus contains no words")
        return ' '.join(words[:max_words])
    except Exception as e:
        # Best-effort download: fall back, but report why instead of swallowing the error.
        print(f"⚠️ Corpus download failed ({e}), using built-in fallback.")
        words = ['HELLO', 'WORLD', 'TEXT', 'MODEL', 'ENCRYPTION', 'TRAINING',
                 'VIGENERE', 'CIPHER', 'KEY', 'MESSAGE', 'DECRYPTION', 'ENGLISH']
        return ' '.join(words * 1000)

ENGLISH_CORPUS = get_english_corpus()
ENGLISH_WORDS = ENGLISH_CORPUS.split()
# random_text() samples from this list and random.choices raises IndexError on an empty
# population, so fail here with a clear message instead of deep inside data generation.
assert ENGLISH_WORDS, "English corpus is empty - cannot generate plaintext"

def random_text(n=50):
    """Return ``n`` words drawn with replacement from ENGLISH_WORDS, joined by single spaces.

    Used as synthetic plaintext; the words come from the loaded English corpus.
    """
    sampled_words = random.choices(ENGLISH_WORDS, k=n)
    return " ".join(sampled_words)

# ============================================================
# VIGENÈRE ENCRYPTION & DATA GENERATION
# ============================================================
def vigenere_encrypt(plaintext, key):
    """Vigenère-encrypt ``plaintext`` with ``key``.

    Each A–Z letter is shifted forward by the next key letter's index (mod 26).
    Any other character (e.g. a space) is copied through unchanged and does NOT
    advance the position in the key.

    Raises:
        KeyError: if ``key`` contains a character outside A–Z.
    """
    shifts = [A2I[letter] for letter in key]
    period = len(key)
    pieces = []
    letter_pos = 0  # counts only enciphered letters, so non-letters don't consume key
    for symbol in plaintext:
        if symbol not in A2I:
            pieces.append(symbol)
            continue
        pieces.append(I2A[(A2I[symbol] + shifts[letter_pos % period]) % 26])
        letter_pos += 1
    return "".join(pieces)

def generate_dataset(n_samples):
    """Build ``n_samples`` synthetic ``(ciphertext, key)`` pairs.

    Each plaintext is 10–30 random corpus words. Each key is 2–K_MAX random
    uppercase letters, and the plaintext is enciphered with vigenere_encrypt.
    """
    def _make_pair():
        # Draw order (plaintext length, words, key length, key letters) is fixed so a
        # seeded RNG always yields the same dataset.
        plain = random_text(random.randint(10, 30))
        key = ''.join(random.choices(CHARS, k=random.randint(2, K_MAX)))
        return vigenere_encrypt(plain, key), key

    return [_make_pair() for _ in range(n_samples)]

In [None]:
# ============================================================
# DATASET AND DATALOADER UTILITIES
# ============================================================
class VigenereDataset(Dataset):
    """Map-style dataset turning ``(ciphertext, key)`` string pairs into index tensors.

    Each item is ``(cipher_idx, key_idx, key_len)``:
      * ``cipher_idx``: 1-D LongTensor of A2I indices for the first ``max_len``
        ciphertext characters; characters outside A–Z (e.g. spaces) become PAD_LABEL.
      * ``key_idx``: 1-D LongTensor of length K_MAX holding the key's indices
        followed by PAD_LABEL filler.
      * ``key_len``: 0-d LongTensor holding ``len(key)``.
    """

    def __init__(self, pairs, max_len=500):
        self.pairs = pairs
        self.max_len = max_len

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        cipher, key = self.pairs[idx]
        if len(key) > K_MAX:
            # Negative padding would silently give a key tensor longer than K_MAX (breaking
            # torch.stack in collate_fn) and a length outside the K_MAX+1 length classes.
            raise ValueError(f"key length {len(key)} exceeds K_MAX={K_MAX}")
        c_enc = [A2I.get(ch, PAD_LABEL) for ch in cipher[:self.max_len]]
        k_enc = [A2I.get(ch, PAD_LABEL) for ch in key] + [PAD_LABEL] * (K_MAX - len(key))
        # Explicit dtype: torch.tensor([]) would be float32 for an empty ciphertext, which
        # nn.Embedding rejects and which cannot be stacked with the other LongTensors.
        return (
            torch.tensor(c_enc, dtype=torch.long),
            torch.tensor(k_enc, dtype=torch.long),
            torch.tensor(len(key), dtype=torch.long),
        )

def collate_fn(batch):
    """Collate VigenereDataset items into batch tensors.

    ``batch`` is a sequence of ``(cipher_idx, key_idx, key_len)`` items. Ciphertexts
    are right-padded with PAD_LABEL to the longest one in the batch. Returns
    ``(ciphers, keys, key_lens)``.
    """
    cipher_seqs, key_seqs, lengths = zip(*batch)
    longest = max(map(len, cipher_seqs))
    padded = []
    for seq in cipher_seqs:
        padded.append(F.pad(seq, (0, longest - len(seq)), value=PAD_LABEL))
    return torch.stack(padded), torch.stack(key_seqs), torch.tensor(lengths)

# ============================================================
# MODEL DEFINITION
# ============================================================
class KeyRecoveryModel(nn.Module):
    """Transformer encoder that predicts a Vigenère key's length and its letters.

    Input ``x`` is a (batch, seq) LongTensor of token ids in ``[0, vocab_size]``, where
    PAD_LABEL marks both batch padding and non-letter characters (as produced by
    VigenereDataset / collate_fn).

    Returns:
        len_pred:  (batch, K_MAX + 1) logits over key lengths 0..K_MAX.
        key_preds: (batch, K_MAX, vocab_size + 1) logits for each key slot; the extra
                   class is PAD_LABEL, used for slots past the end of the key.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_blocks, nhead):
        super().__init__()
        self.embed = nn.Embedding(vocab_size + 1, embed_dim, padding_idx=PAD_LABEL)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=nhead, dim_feedforward=hidden_dim, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_blocks)
        self.fc_len = nn.Linear(embed_dim, K_MAX + 1)
        # One learned query per key slot, cross-attending over the encoded ciphertext, so
        # each slot can read different positions. Feeding every slot the same pooled vector
        # forces identical logits in all K_MAX slots (only single-letter keys possible).
        self.key_queries = nn.Parameter(torch.randn(K_MAX, embed_dim) * 0.02)
        self.key_attn = nn.MultiheadAttention(embed_dim, nhead, batch_first=True)
        self.fc_key = nn.Linear(embed_dim, vocab_size + 1)

    @staticmethod
    def _sinusoidal_positions(seq_len, dim, device):
        """Fixed (seq_len, dim) sine/cosine position table; works for any sequence length."""
        position = torch.arange(seq_len, device=device, dtype=torch.float32).unsqueeze(1)
        freqs = torch.exp(
            torch.arange(0, dim, 2, device=device, dtype=torch.float32) * (-math.log(10000.0) / dim)
        )
        table = torch.zeros(seq_len, dim, device=device)
        table[:, 0::2] = torch.sin(position * freqs)
        table[:, 1::2] = torch.cos(position * freqs[: dim // 2])
        return table

    def forward(self, x):
        pad_mask = x == PAD_LABEL
        # A row that is entirely padding would make attention softmax over nothing (NaN);
        # leave such rows unmasked instead.
        pad_mask = pad_mask & ~pad_mask.all(dim=1, keepdim=True)

        h = self.embed(x)
        # Without positions, encoder + mean pooling is permutation-invariant: it would see
        # the ciphertext as a bag of letters and could not infer the key's periodic phase.
        h = h + self._sinusoidal_positions(x.size(1), h.size(-1), x.device).to(h.dtype)
        enc = self.encoder(h, src_key_padding_mask=pad_mask)

        # Mean over real tokens only, so batch padding does not dilute the summary vector.
        keep = (~pad_mask).unsqueeze(-1).to(enc.dtype)
        pooled = (enc * keep).sum(dim=1) / keep.sum(dim=1).clamp(min=1.0)
        len_pred = self.fc_len(pooled)

        queries = self.key_queries.unsqueeze(0).expand(x.size(0), -1, -1)
        slots, _ = self.key_attn(queries, enc, enc, key_padding_mask=pad_mask, need_weights=False)
        key_preds = self.fc_key(slots)
        return len_pred, key_preds

In [None]:
# ============================================================
# TRAINING FUNCTION
# ============================================================
def train_model():
    """Train a fresh KeyRecoveryModel on N_SAMPLES synthetic Vigenère pairs.

    Optimises AdamW with a one-cycle LR schedule for EPOCHS epochs. The loss is the
    cross-entropy of the key-length head plus the cross-entropy over every key slot
    (PAD_LABEL slots included, so the key head also learns to emit padding past the
    key). A state_dict checkpoint is written to CHECKPOINT_DIR every 10 epochs and
    after the final epoch.

    Returns:
        The trained model (on DEVICE, left in train mode).

    Raises:
        ValueError: if N_SAMPLES yields no training data.
    """
    print(f"\n--- Starting Training on {DEVICE} ({N_SAMPLES} samples, {EPOCHS} epochs) ---")
    dataset = generate_dataset(N_SAMPLES)
    if not dataset:
        # OneCycleLR would otherwise fail with an opaque total_steps error.
        raise ValueError("N_SAMPLES must be positive to train")
    loader = DataLoader(VigenereDataset(dataset), batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)

    model = KeyRecoveryModel(VOCAB_SIZE, EMBED_DIM, HIDDEN_DIM, NUM_ENCODER_BLOCKS, NHEAD).to(DEVICE)
    opt = torch.optim.AdamW(model.parameters(), lr=LR)
    # Stepped once per batch, so total_steps covers every batch of every epoch.
    scheduler = torch.optim.lr_scheduler.OneCycleLR(opt, max_lr=LR, total_steps=len(loader) * EPOCHS)
    ce_loss = nn.CrossEntropyLoss()
    checkpoint_every = 10

    for epoch in range(1, EPOCHS + 1):
        model.train()
        total_loss = 0.0
        for x, yk, yl in tqdm(loader, desc=f"Epoch {epoch}/{EPOCHS}"):
            x, yk, yl = x.to(DEVICE), yk.to(DEVICE), yl.to(DEVICE)
            opt.zero_grad()

            pred_len, pred_key = model(x)

            loss_len = ce_loss(pred_len, yl)
            # Flatten (batch, K_MAX, classes) -> (batch * K_MAX, classes); the class count
            # is read from the logits instead of being hard-coded.
            loss_key = ce_loss(pred_key.reshape(-1, pred_key.size(-1)), yk.reshape(-1))
            loss = loss_len + loss_key

            loss.backward()
            opt.step()
            scheduler.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(loader)
        print(f"Epoch {epoch}/{EPOCHS} | Avg Loss={avg_loss:.4f}")

        # Periodic checkpoints, plus the last epoch so the finished model is always saved.
        if epoch % checkpoint_every == 0 or epoch == EPOCHS:
            torch.save(model.state_dict(), f"{CHECKPOINT_DIR}/model_epoch_{epoch}.pth")

    print("✅ Training Complete")
    return model

# Execute Training
# Trains from scratch on every run (EPOCHS epochs over N_SAMPLES generated samples);
# nothing is loaded from CHECKPOINT_DIR, so re-running this cell repeats the full training.
trained_model = train_model()

In [None]:
# ============================================================
# INFERENCE AND EVALUATION FUNCTIONS
# ============================================================
def predict_key(model, cipher):
    """Predict the Vigenère key for a single ciphertext string.

    The text is upper-cased first, because the model was trained on A–Z only and
    lowercase letters would otherwise all be encoded as PAD_LABEL. Every other
    character is encoded as PAD_LABEL, matching VigenereDataset. The predicted
    length selects how many key-slot predictions are kept, and slots predicted as
    PAD_LABEL are dropped.

    Args:
        model: a KeyRecoveryModel; it is switched to eval mode by this call.
        cipher: ciphertext string.

    Returns:
        The predicted key as an uppercase string ("" for an empty ciphertext).
    """
    model.eval()
    if not cipher:
        # Nothing to encode; an empty sequence cannot be pooled meaningfully.
        return ""
    tokens = [A2I.get(ch, PAD_LABEL) for ch in cipher.upper()]
    with torch.no_grad():
        # Explicit long dtype keeps the input valid for nn.Embedding.
        x = torch.tensor(tokens, dtype=torch.long, device=DEVICE).unsqueeze(0)
        pred_len, pred_key = model(x)
        key_len = pred_len.argmax(dim=1).item()
        key_chars = pred_key.squeeze(0).argmax(dim=1).tolist()[:key_len]
    return ''.join(I2A[i] for i in key_chars if i != PAD_LABEL)

def evaluate_and_generate_csv(model, n_samples=N_TEST_CASES):
    """
    Evaluate ``model`` on ``n_samples`` freshly generated Vigenère pairs and write a CSV.

    Metrics:
      * Key Length Accuracy: fraction of samples whose argmax length equals the true length.
      * Full Key Recovery Accuracy: fraction where the length is correct AND the first
        ``predicted_key_length`` slot predictions spell the true key.

    One row per sample is written to ``vigenere_model_evaluation_results.csv`` with the
    ciphertext, predicted/true key length and key, and boolean match columns. Both
    accuracies are NaN when ``n_samples`` is 0.

    Returns:
        The per-sample results as a pandas DataFrame.
    """
    print(f"\n--- Evaluation on {n_samples} Test Cases ---")

    # Generate Test Data and DataLoader.
    test_dataset_pairs = generate_dataset(n_samples)
    # shuffle=False keeps loader order aligned with test_dataset_pairs, so each row can report
    # the real ciphertext (spaces intact, untruncated) instead of the padded model input.
    test_loader = DataLoader(VigenereDataset(test_dataset_pairs), batch_size=BATCH_SIZE,
                             shuffle=False, collate_fn=collate_fn)

    model.eval()
    total_samples = 0
    correct_key_len_predictions = 0
    correct_full_key_predictions = 0
    results = []

    with torch.no_grad():
        for x, yk, yl in tqdm(test_loader, desc="Evaluating"):
            pred_len, pred_key = model(x.to(DEVICE))
            # Convert once per batch rather than syncing per sample.
            predicted_len_indices = pred_len.argmax(dim=1).cpu().tolist()
            predicted_key_indices = pred_key.argmax(dim=2).cpu().tolist()
            true_key_indices = yk.tolist()
            true_len_indices = yl.tolist()

            for pred_len_i, pred_key_i, true_len, true_key_i in zip(
                predicted_len_indices, predicted_key_indices, true_len_indices, true_key_indices
            ):
                cipher_text = test_dataset_pairs[total_samples][0]
                total_samples += 1

                # --- Key Length Accuracy ---
                len_is_correct = pred_len_i == true_len
                correct_key_len_predictions += int(len_is_correct)

                # --- Key Content and Full Key Recovery ---
                pred_key_str = ''.join(I2A[idx] for idx in pred_key_i[:pred_len_i])
                true_key_str = ''.join(I2A[idx] for idx in true_key_i[:true_len])

                full_key_is_correct = len_is_correct and (pred_key_str == true_key_str)
                correct_full_key_predictions += int(full_key_is_correct)

                results.append({
                    'ciphertext': cipher_text,
                    'predicted_key_length': pred_len_i,
                    'correct_key_length': true_len,
                    'predicted_key': pred_key_str,
                    'correct_key': true_key_str,
                    # Boolean columns for easy filtering.
                    'length_match': len_is_correct,
                    'key_match': full_key_is_correct
                })

    # Accuracy is undefined for an empty test set; report NaN rather than crash.
    len_accuracy = correct_key_len_predictions / total_samples if total_samples else float("nan")
    key_accuracy = correct_full_key_predictions / total_samples if total_samples else float("nan")

    df = pd.DataFrame(results)
    filename = "vigenere_model_evaluation_results.csv"
    df.to_csv(filename, index=False)

    print("\n✅ Evaluation and CSV Generation Complete")
    print("---------------------------------")
    print(f"Results saved to: {filename}")
    print(f"Total Test Samples: {total_samples}")
    print(f"Key Length Accuracy: {len_accuracy * 100:.2f}%")
    print(f"Full Key Recovery Accuracy: {key_accuracy * 100:.2f}%")
    print("---------------------------------")

    return df

# --- EXECUTION ---

# 1. Evaluate the model and generate CSV
evaluation_df = evaluate_and_generate_csv(trained_model)

# 2. Example Test: run the trained model on one fixed ciphertext string.
test_cipher = "KCCPKBGUFDPHQTYAVINRRTMVGRKDNBVFDETDGIL"
key_guess = predict_key(trained_model, test_cipher)
print(f"\n🔑 Prediction for Example Cipher: {test_cipher}")
print(f"Predicted key: {key_guess}")