In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import string
import random
import math
from tqdm.notebook import tqdm

print("✅ Setting up configuration...")

# --- Vocabulary: every character of string.printable (100 chars); index i <-> VOCAB[i].
VOCAB = string.printable
VOCAB_SIZE = len(VOCAB)
CHAR_TO_IDX = {ch: i for i, ch in enumerate(VOCAB)}
IDX_TO_CHAR = {i: ch for i, ch in enumerate(VOCAB)}

# --- Model architecture
MAX_LEN = 256          # longest sequence (in characters) the model / positional encoding handles
D_MODEL = 512
NHEAD = 8
NUM_ENCODER_LAYERS = 6
NUM_DECODER_LAYERS = 6
DIM_FEEDFORWARD = 2048
DROPOUT = 0.1

# --- Training
LEARNING_RATE = 1e-4
BATCH_SIZE = 32
EPOCHS = 10
STEPS_PER_EPOCH = 500
MARGIN = 20.0          # margin used by the hinge terms of the training loss
MODEL_SAVE_PATH = "energy_transformer_vigenere.pth"

# --- Inference (gradient descent on a continuous plaintext guess)
INFERENCE_STEPS = 250
INFERENCE_LR = 0.1

# --- Reproducibility: seed every RNG the notebook draws from
# (synthetic data uses `random`; weights / dropout / inference init use torch).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"▶️ Using device: {device}")
if not torch.cuda.is_available():
    print("⚠️ WARNING: GPU not found. Training will be extremely slow.")

✅ Setting up configuration...
▶️ Using device: cuda


In [None]:
# Cell 2: Vigenère helpers, synthetic training-data generation, and the model.
print("✅ Defining data generation functions and model architecture...")

def vigenere_encrypt(plaintext, key):
    """Encrypt ``plaintext`` with a Vigenère cipher over ``VOCAB``.

    Output character i is ``VOCAB[(idx(plaintext[i]) + idx(key[i % len(key)])) % VOCAB_SIZE]``.
    The key is therefore repeated cyclically when it is shorter than the
    plaintext, which is what makes it a Vigenère cipher. A key at least as
    long as the plaintext is used as-is.

    Characters not in ``VOCAB`` (in either argument) are treated as index 0,
    i.e. as ``VOCAB[0]``, so such characters do not round-trip.

    Args:
        plaintext: text to encrypt.
        key: key text, repeated as needed.

    Returns:
        The ciphertext, the same length as ``plaintext``, or ``""`` if ``key`` is empty.
    """
    if not key:
        return ""
    key_len = len(key)
    ciphertext = []
    for i, p_char in enumerate(plaintext):
        p_idx = CHAR_TO_IDX.get(p_char, 0)
        k_idx = CHAR_TO_IDX.get(key[i % key_len], 0)
        # The result is always in [0, VOCAB_SIZE), so direct indexing is safe.
        ciphertext.append(IDX_TO_CHAR[(p_idx + k_idx) % VOCAB_SIZE])
    return "".join(ciphertext)

def vigenere_decrypt(ciphertext, key):
    """Decrypt a Vigenère ``ciphertext`` over ``VOCAB``.

    Output character i is ``VOCAB[(idx(ciphertext[i]) - idx(key[i % len(key)])) % VOCAB_SIZE]``.
    This undoes the shift ``(p + k) % VOCAB_SIZE`` applied during encryption.
    The key is repeated cyclically when it is shorter than the ciphertext.

    Characters not in ``VOCAB`` (in either argument) are treated as index 0.

    Args:
        ciphertext: text to decrypt.
        key: key text, repeated as needed.

    Returns:
        The plaintext, the same length as ``ciphertext``, or ``""`` if ``key`` is empty.
    """
    if not key:
        return ""
    key_len = len(key)
    plaintext = []
    for i, c_char in enumerate(ciphertext):
        c_idx = CHAR_TO_IDX.get(c_char, 0)
        k_idx = CHAR_TO_IDX.get(key[i % key_len], 0)
        # Python's % with a positive modulus is already non-negative.
        plaintext.append(IDX_TO_CHAR[(c_idx - k_idx) % VOCAB_SIZE])
    return "".join(plaintext)

def generate_random_text(length):
    """Return ``length`` characters, each drawn independently and uniformly from ``VOCAB``.

    A non-positive ``length`` yields the empty string.
    """
    chars = []
    for _ in range(length):
        chars.append(random.choice(VOCAB))
    return "".join(chars)

def corrupt_plaintext(plaintext, corruption_rate=0.2):
    """Return a copy of ``plaintext`` with a fraction of its characters replaced.

    ``int(len(plaintext) * corruption_rate)`` distinct positions are chosen
    uniformly at random. Each chosen position is replaced by a random ``VOCAB``
    character that differs from the one already there, so every chosen position
    is guaranteed to change.

    Args:
        plaintext: text to corrupt.
        corruption_rate: fraction of positions to replace, in ``[0, 1]``.

    Returns:
        The corrupted text, the same length as ``plaintext``.

    Raises:
        ValueError: if ``corruption_rate`` is outside ``[0, 1]``.
    """
    if not 0.0 <= corruption_rate <= 1.0:
        raise ValueError(f"corruption_rate must be in [0, 1], got {corruption_rate}")
    pt_list = list(plaintext)
    num_corruptions = int(len(pt_list) * corruption_rate)
    for idx in random.sample(range(len(pt_list)), num_corruptions):
        original = pt_list[idx]
        alternatives = [ch for ch in VOCAB if ch != original]
        if alternatives:  # empty only if VOCAB consists solely of `original`
            pt_list[idx] = random.choice(alternatives)
    return "".join(pt_list)

def generate_training_sample(min_key_len=3, max_key_len=16):
    """Build one contrastive training example for the energy model.

    The plaintext is uniform random text of 64..MAX_LEN characters. It is
    encrypted with a *periodic* Vigenère key: a random key word of length
    ``min_key_len..max_key_len``, repeated to the message length.

    A fresh random key as long as the message would be a one-time pad. The
    ciphertext would then reveal nothing about the plaintext, and positives
    and negatives would be indistinguishable.

    Negatives:
      * ``negative_plaintext_corrupted``: the true plaintext passed through
        ``corrupt_plaintext`` (default rate), which breaks the periodicity of
        ``ciphertext - candidate`` at the corrupted positions.
      * ``negative_plaintext_wrong_key``: the ciphertext decrypted with a
        non-periodic random key of full length.

    NOTE(review): because the plaintext is random, a decryption under *any*
    periodic key looks as plausible as the true one. Recovering the actual
    plaintext would need natural-language training text.

    Args:
        min_key_len: shortest key-word length (>= 1).
        max_key_len: longest key-word length (>= min_key_len).

    Returns:
        dict with keys "ciphertext", "positive_plaintext",
        "negative_plaintext_corrupted" and "negative_plaintext_wrong_key".

    Raises:
        ValueError: if the key-length bounds are invalid.
    """
    if not 1 <= min_key_len <= max_key_len:
        raise ValueError(
            f"need 1 <= min_key_len <= max_key_len, got {min_key_len}, {max_key_len}"
        )
    length = random.randint(64, MAX_LEN)
    plaintext = generate_random_text(length)

    period = random.randint(min_key_len, max_key_len)
    key_word = generate_random_text(period)
    # Repeat the short key word across the whole message (classic Vigenère).
    key = (key_word * (length // period + 1))[:length]
    ciphertext = vigenere_encrypt(plaintext, key)

    neg_plaintext_corrupted = corrupt_plaintext(plaintext)
    wrong_key = generate_random_text(length)
    while wrong_key == key:
        wrong_key = generate_random_text(length)
    neg_plaintext_wrong_key = vigenere_decrypt(ciphertext, wrong_key)

    return {
        "ciphertext": ciphertext,
        "positive_plaintext": plaintext,
        "negative_plaintext_corrupted": neg_plaintext_corrupted,
        "negative_plaintext_wrong_key": neg_plaintext_wrong_key
    }

def get_batch(batch_size):
    """Infinite generator: each ``next()`` yields a fresh list of ``batch_size``
    sample dicts produced by ``generate_training_sample``."""
    while True:
        samples = []
        for _ in range(batch_size):
            samples.append(generate_training_sample())
        yield samples

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings, then apply dropout.

    The encodings follow the sin/cos scheme of "Attention Is All You Need":
    even channels get ``sin(pos * w)`` and odd channels get ``cos(pos * w)``,
    with ``w = 10000^(-2i/d_model)``.

    Inputs are batch-first ``(batch, seq_len, d_model)`` with ``seq_len <= max_len``.
    The buffer ``pe`` has shape ``(1, max_len, d_model)`` and broadcasts over batch.
    Odd ``d_model`` is supported; the odd channels use the first
    ``d_model // 2`` frequencies.
    """

    def __init__(self, d_model, dropout=DROPOUT, max_len=MAX_LEN):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        # There are only d_model // 2 odd channels; the slice makes odd d_model work.
        pe[0, :, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``dropout(x + pe[:, :seq_len])`` for batch-first ``x``.

        Raises:
            ValueError: if ``x.size(1)`` exceeds the ``max_len`` used at construction.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds positional-encoding max_len {self.pe.size(1)}"
            )
        x = x + self.pe[:, :seq_len, :]
        return self.dropout(x)

class EnergyTransformer(nn.Module):
    """Encoder-decoder Transformer that maps a (ciphertext, plaintext) pair to a scalar energy.

    The ciphertext goes through the encoder and the plaintext candidate through
    the decoder; the decoder output is mean-pooled over positions and projected
    to one number per example. Intended usage: lower energy means a more
    plausible plaintext for the given ciphertext.

    Token tensors are batch-first ``(batch, seq_len)`` index tensors
    (``batch_first=True`` below).
    """

    def __init__(self):
        super().__init__()
        # Submodules are created in a fixed order; this keeps state_dict keys
        # and the sequence of random weight initialisations stable.
        self.d_model = D_MODEL
        self.embedding = nn.Embedding(VOCAB_SIZE, D_MODEL)
        self.pos_encoder = PositionalEncoding(D_MODEL)
        self.transformer = nn.Transformer(
            d_model=D_MODEL,
            nhead=NHEAD,
            num_encoder_layers=NUM_ENCODER_LAYERS,
            num_decoder_layers=NUM_DECODER_LAYERS,
            dim_feedforward=DIM_FEEDFORWARD,
            dropout=DROPOUT,
            batch_first=True,
        )
        self.energy_head = nn.Linear(D_MODEL, 1)

    def _embed_tokens(self, tokens):
        """Embed token indices, scale by sqrt(d_model), then add positional encoding."""
        scaled = self.embedding(tokens) * math.sqrt(self.d_model)
        return self.pos_encoder(scaled)

    def forward(self, ciphertext_tokens, plaintext_tokens):
        """Return a ``(batch,)`` tensor of energies for the given token batches."""
        src = self._embed_tokens(ciphertext_tokens)
        tgt = self._embed_tokens(plaintext_tokens)
        hidden = self.transformer(src=src, tgt=tgt)
        per_example = self.energy_head(hidden.mean(dim=1))  # (batch, 1)
        return per_example.squeeze(-1)

print("▶️ Model and data functions are ready.")

✅ Defining data generation functions and model architecture...
▶️ Model and data functions are ready.


In [None]:
print("✅ Preparing for training...")

model = EnergyTransformer().to(device)
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
# Infinite generator yielding freshly synthesised batches (lists of sample dicts).
data_loader = get_batch(BATCH_SIZE)

def texts_to_tensors(texts):
    """Encode a list of strings as a ``(len(texts), MAX_LEN)`` long tensor on ``device``.

    Each text is mapped through ``CHAR_TO_IDX`` (unknown characters become 0),
    truncated to ``MAX_LEN`` and right-padded with 0.

    NOTE(review): index 0 is also the real character ``VOCAB[0]`` ('0'), so
    padding is indistinguishable from that character and no padding mask is
    produced.

    The tensor is filled on the CPU and moved to ``device`` once, rather than
    issuing one host-to-device copy per row.
    """
    batch = torch.zeros(len(texts), MAX_LEN, dtype=torch.long)
    for row, text in enumerate(texts):
        token_ids = [CHAR_TO_IDX.get(c, 0) for c in text[:MAX_LEN]]
        if token_ids:
            batch[row, :len(token_ids)] = torch.tensor(token_ids, dtype=torch.long)
    return batch.to(device)

model.train()
print("🚀 Starting training...")

for epoch in range(EPOCHS):
    total_loss = 0.0
    progress_bar = tqdm(range(STEPS_PER_EPOCH), desc=f"Epoch {epoch+1}/{EPOCHS}")

    for step in progress_bar:
        batch = next(data_loader)

        ct = texts_to_tensors([item['ciphertext'] for item in batch])
        pt_pos = texts_to_tensors([item['positive_plaintext'] for item in batch])
        pt_neg1 = texts_to_tensors([item['negative_plaintext_corrupted'] for item in batch])
        pt_neg2 = texts_to_tensors([item['negative_plaintext_wrong_key'] for item in batch])

        optimizer.zero_grad()

        energy_positive = model(ct, pt_pos)
        energy_negative1 = model(ct, pt_neg1)
        energy_negative2 = model(ct, pt_neg2)

        # Margin-ranking hinge: for the same ciphertext, each negative must sit at
        # least MARGIN above the positive. This loss is bounded below by 0. A bare
        # `energy_positive.mean()` term has no lower bound and can be driven to -inf.
        loss_neg1 = torch.relu(MARGIN + energy_positive - energy_negative1)
        loss_neg2 = torch.relu(MARGIN + energy_positive - energy_negative2)
        loss = loss_neg1.mean() + loss_neg2.mean()

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        progress_bar.set_postfix(loss=f"{loss.item():.4f}", avg_loss=f"{total_loss / (step + 1):.4f}")

    avg_loss = total_loss / STEPS_PER_EPOCH
    print(f"✅ Epoch {epoch+1} finished. Average Loss: {avg_loss:.4f}")

    # Checkpoint after every epoch (overwrites the previous file).
    torch.save(model.state_dict(), MODEL_SAVE_PATH)
    print(f"💾 Model saved to {MODEL_SAVE_PATH}")

print("🎉 Training complete!")

In [None]:
print("✅ Preparing for inference...")

inference_model = EnergyTransformer().to(device)
try:
    # A state_dict holds only tensors/containers, so weights_only=True suffices
    # and avoids unpickling arbitrary objects.
    state_dict = torch.load(MODEL_SAVE_PATH, map_location=device, weights_only=True)
    inference_model.load_state_dict(state_dict)
    inference_model.eval()
    print("▶️ Trained model loaded successfully.")
except FileNotFoundError:
    print(f"❌ Error: Model file not found at {MODEL_SAVE_PATH}. Please train the model first.")
    # Discard the untrained model so later `isinstance(inference_model, nn.Module)`
    # checks fail, instead of running the demo with random weights.
    inference_model = None

def text_to_tensor(text):
    """Encode ``text`` as a ``(1, len(text))`` long tensor of VOCAB indices on ``device``.

    Unknown characters map to index 0. No truncation or padding is applied.
    """
    lookup = CHAR_TO_IDX.get
    token_ids = [lookup(ch, 0) for ch in text]
    encoded = torch.tensor(token_ids, dtype=torch.long, device=device)
    return encoded.unsqueeze(0)

def embeddings_to_text(embeddings, model_embedding_layer):
    """Decode continuous vectors to text by nearest (Euclidean) vocabulary embedding.

    Args:
        embeddings: tensor whose leading dim is squeezed away; assumes batch
            size 1, leaving ``(seq_len, d_model)``.
        model_embedding_layer: module whose ``weight`` rows are the vocabulary
            vectors (e.g. the model's ``nn.Embedding``).

    Returns:
        One character per position. Indices missing from ``IDX_TO_CHAR`` become '?'.
    """
    table = model_embedding_layer.weight.detach()
    queries = embeddings.squeeze(0).detach()
    nearest = torch.cdist(queries, table, p=2).argmin(dim=1)
    return "".join(IDX_TO_CHAR.get(i, '?') for i in nearest.tolist())

def crack_cipher(ciphertext, model):
    """Search for the plaintext that minimises ``model``'s energy for ``ciphertext``.

    Steps:
      1. Truncate the ciphertext to ``MAX_LEN`` characters and encode it once,
         without gradients.
      2. Hold the plaintext guess as a continuous tensor in the *unscaled*
         token-embedding space, the space ``embeddings_to_text`` decodes against.
      3. Optimise the guess with AdamW for ``INFERENCE_STEPS`` steps, then snap
         it to the nearest vocabulary characters.

    The guess is multiplied by ``sqrt(d_model)`` before positional encoding, the
    same scaling the model applies to real token embeddings. The decoder
    therefore sees inputs on its training scale, and decoding compares like
    with like.

    Only the guess is optimised. Gradients come from ``torch.autograd.grad``
    w.r.t. the guess alone, so nothing accumulates in the model parameters'
    ``.grad``. The model runs in eval mode during the search, and its previous
    mode is restored afterwards.

    Args:
        ciphertext: text to crack.
        model: a trained ``EnergyTransformer``.

    Returns:
        The decoded plaintext guess, the same length as the truncated
        ciphertext ("" for empty input).
    """
    print(f"\n--- Cracking Ciphertext ---")
    print(f"Input ({len(ciphertext)} chars): {ciphertext[:80]}...")

    ct_len = min(len(ciphertext), MAX_LEN)
    ciphertext = ciphertext[:ct_len]
    if ct_len == 0:
        # Nothing to decode; mean-pooling over zero positions would give NaN.
        print("Result (0 chars): ...")
        return ""

    embed_scale = math.sqrt(model.d_model)
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            ct_tensor = text_to_tensor(ciphertext)
            ct_embed = model.pos_encoder(model.embedding(ct_tensor) * embed_scale)
            encoder_output = model.transformer.encoder(ct_embed)

        pt_guess_embed = torch.randn(1, ct_len, model.d_model, device=device, requires_grad=True)
        guess_optimizer = torch.optim.AdamW([pt_guess_embed], lr=INFERENCE_LR)

        for _ in tqdm(range(INFERENCE_STEPS), desc="🔍 Searching for plaintext"):
            pt_guess_with_pos = model.pos_encoder(pt_guess_embed * embed_scale)
            decoder_output = model.transformer.decoder(tgt=pt_guess_with_pos, memory=encoder_output)
            energy = model.energy_head(decoder_output.mean(dim=1)).sum()
            (grad,) = torch.autograd.grad(energy, [pt_guess_embed])
            pt_guess_embed.grad = grad
            guess_optimizer.step()
    finally:
        model.train(was_training)

    final_plaintext = embeddings_to_text(pt_guess_embed, model.embedding)
    print(f"Result ({len(final_plaintext)} chars): {final_plaintext[:80]}...")
    return final_plaintext

if 'inference_model' in locals() and isinstance(inference_model, nn.Module):
    plaintext_secret = "This is a final project for my cybersecurity class. The goal is to identify ciphers. We are using an energy-based transformer model to see if it can learn the patterns of a Vigenere cipher."
    # A Vigenère key is a short key word repeated across the message. A random key
    # as long as the message would be a one-time pad. That is
    # information-theoretically unbreakable, so no model could recover the text.
    key_word = generate_random_text(random.randint(3, 16))
    repeats = len(plaintext_secret) // len(key_word) + 1
    key_secret = (key_word * repeats)[:len(plaintext_secret)]
    ciphertext_to_crack = vigenere_encrypt(plaintext_secret, key_secret)

    cracked_text = crack_cipher(ciphertext_to_crack, inference_model)

    print("\n--- 📊 Evaluation ---")
    print(f"Original Plaintext : {plaintext_secret}")
    print(f"Cracked Plaintext  : {cracked_text}")

    # Positions beyond the cracked text's length (e.g. after MAX_LEN truncation) count as wrong.
    correct_chars = sum(a == b for a, b in zip(plaintext_secret, cracked_text))
    accuracy = 100.0 * correct_chars / len(plaintext_secret) if plaintext_secret else 0.0
    print(f"Character-level Accuracy: {accuracy:.2f}%")
else:
    print("⚠️ Skipping demo: no trained model is loaded.")