In [None]:
! pip install pycryptodome
import os, math, random, time
import numpy as np
from tqdm import tqdm
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from Crypto.Cipher import AES, DES
from Crypto.Util.Padding import pad
from torch.cuda.amp import GradScaler
import matplotlib.pyplot as plt

Collecting pycryptodome
  Downloading pycryptodome-3.23.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.4 kB)
Downloading pycryptodome-3.23.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.3/2.3 MB[0m [31m32.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pycryptodome
Successfully installed pycryptodome-3.23.0


### Settings

In [None]:
# settings
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Number of ciphertexts generated per cipher for every dataset part.
samples_per_class = 50000
# Number of output classes; labels come from enumerate(CIPHER_NAMES), so keep this equal to its length.
n_classes = 4
# Mini-batch size and worker-process count handed to every DataLoader.
batch_size = 256
num_workers = 4
# Number of bins in the per-sample byte histogram fed to the model (also the model's hist_dim).
hist_bins = 64
# Dropout probability used in the transformer layers and in the energy head.
dropout = 0.2

# parameters
# Fixed sequence length: ciphertexts are truncated or zero-padded to this many bytes.
max_len = 512             # Increased from 256 to 512
# Transformer embedding width.
d_model = 256
# Number of stacked encoder layers.
nlayers = 6               # Increased from 4 to 6
# Attention heads per encoder layer.
nhead = 8

# Semi-static data generation: training runs over `num_parts` dataset parts,
# each generated (or loaded from its cache file) and trained on for `epochs_per_part` epochs.
num_parts = 4
epochs_per_part = 25
# Total epochs if early stopping never triggers (not referenced by the other visible cells).
num_epochs = num_parts * epochs_per_part

# Patience settings
# Early stopping: consecutive epochs without a new best validation accuracy before training stops.
patience = 10
# ReduceLROnPlateau: epochs without validation-accuracy improvement before the LR is halved.
scheduler_patience = 4

### Cipher Implementations

In [None]:
# cipher implementations
def random_plaintext(min_len=128, max_len=256):
    """Return OS-random bytes whose length is drawn from [min_len, max_len] (both inclusive)."""
    n_bytes = random.randint(min_len, max_len)
    return os.urandom(n_bytes)

def random_key_for(cipher_name):
    """Return a fresh random key suited to the named cipher.

    Args:
        cipher_name: one of 'AES', 'DES', 'Speck' or 'Vigenere'.

    Returns:
        bytes: 16 random bytes for AES, 8 for DES and Speck, and for Vigenere
        3-12 bytes each in 1..255 (a zero byte would leave plaintext unshifted).

    Raises:
        ValueError: if ``cipher_name`` is not a supported cipher. Previously the
            function fell through and returned ``None``, which only failed later
            inside the encryption routine.
    """
    if cipher_name == 'AES':
        return os.urandom(16)
    if cipher_name in ('DES', 'Speck'):
        return os.urandom(8)
    if cipher_name == 'Vigenere':
        key_len = random.randint(3, 12)
        return bytes(random.randint(1, 255) for _ in range(key_len))
    raise ValueError(f"Unknown cipher name: {cipher_name!r}")

def aes_encrypt(pt, key):
    """Encrypt ``pt`` with AES in CBC mode and return the ciphertext bytes.

    The key is cut or zero-padded to exactly 16 bytes and the plaintext is
    padded to the AES block size with ``pad``. No IV is passed to ``AES.new``
    (pycryptodome is expected to pick one itself); it is not part of the
    returned value.
    """
    aes_key = key[:16].ljust(16, b'\x00')
    padded_pt = pad(pt, AES.block_size)
    return AES.new(aes_key, AES.MODE_CBC).encrypt(padded_pt)

def des_encrypt(pt, key):
    """Encrypt ``pt`` with DES in CBC mode and return the ciphertext bytes.

    The key is cut or zero-padded to exactly 8 bytes and the plaintext is
    padded to the DES block size with ``pad``. No IV is passed to ``DES.new``
    (pycryptodome is expected to pick one itself); it is not part of the
    returned value.
    """
    des_key = key[:8].ljust(8, b'\x00')
    padded_pt = pad(pt, DES.block_size)
    return DES.new(des_key, DES.MODE_CBC).encrypt(padded_pt)

def _speck_rotations(word_bits):
    """Return Speck's (alpha, beta) rotation amounts for the given word size."""
    # The Speck specification uses (7, 2) for 16-bit words and (8, 3) otherwise.
    return (7, 2) if word_bits == 16 else (8, 3)


def _speck_round_keys(key_words, rounds, word_bits):
    """Expand master key words (most-significant word first) into `rounds` round keys."""
    key_words = list(key_words)
    if len(key_words) < 2:
        raise ValueError("Speck key schedule needs at least two key words")
    if rounds < 1:
        raise ValueError("rounds must be at least 1")
    mask = (1 << word_bits) - 1
    alpha, beta = _speck_rotations(word_bits)
    # Written key order is (l[m-2], ..., l[0], k[0]); flip the l words so l[0] is first.
    k = key_words[-1] & mask
    l = [w & mask for w in reversed(key_words[:-1])]
    round_keys = [k]
    for i in range(rounds - 1):
        rotated = ((l[i] >> alpha) | (l[i] << (word_bits - alpha))) & mask
        l.append(((k + rotated) & mask) ^ i)
        k = (((k << beta) | (k >> (word_bits - beta))) & mask) ^ l[-1]
        round_keys.append(k)
    return round_keys


def _speck_apply_rounds(x, y, round_keys, word_bits):
    """Run the Speck round function once per round key on the word pair (x, y)."""
    mask = (1 << word_bits) - 1
    alpha, beta = _speck_rotations(word_bits)
    x &= mask
    y &= mask
    for rk in round_keys:
        # x = (ROR(x, alpha) + y) ^ rk ; y = ROL(y, beta) ^ x
        x = ((((x >> alpha) | (x << (word_bits - alpha))) & mask) + y) & mask
        x ^= rk
        y = (((y << beta) | (y >> (word_bits - beta))) & mask) ^ x
    return x, y


def speck_encrypt_block(plain_block, key_words, rounds=22, word_bits=16):
    """Encrypt one Speck block.

    With the defaults (four 16-bit key words, 22 rounds) this is Speck32/64.
    The previous version ignored ``rounds``, used the key words directly as
    round keys (so only ``len(key_words)`` rounds ran) and rotated by 8/3
    instead of the 7/2 that 16-bit words require.

    Args:
        plain_block: pair ``(x, y)`` of integer words; x is the high word.
        key_words: master key words, most-significant first, i.e.
            ``(l[m-2], ..., l[0], k[0])`` in the notation of the Speck paper.
        rounds: number of rounds to run (and round keys to derive).
        word_bits: width of each word in bits.

    Returns:
        tuple: the encrypted ``(x, y)`` word pair.

    Raises:
        ValueError: if fewer than two key words are given or ``rounds < 1``.
    """
    x, y = plain_block
    round_keys = _speck_round_keys(key_words, rounds, word_bits)
    return _speck_apply_rounds(x, y, round_keys, word_bits)


def speck_encrypt(pt, key, block_size=4):
    """Encrypt ``pt`` block by block with Speck (each block independently, ECB style).

    The word size is half of ``block_size``; the key is split into big-endian
    words of that size (a short trailing word is zero-padded) and the plaintext
    is zero-padded to a whole number of blocks. With ``block_size=4`` and an
    8-byte key this is Speck32/64. Other sizes reuse the same 22-round count,
    which is not the standard round count for those variants.

    Args:
        pt: plaintext bytes.
        key: key bytes; must yield at least two key words.
        block_size: block size in bytes; must be even and at least 2.

    Returns:
        bytes: ciphertext, ``len(pt)`` rounded up to a multiple of ``block_size``.

    Raises:
        ValueError: for an odd or too-small ``block_size`` or a too-short key.
    """
    if block_size < 2 or block_size % 2:
        raise ValueError("block_size must be an even number of bytes >= 2")
    word_bytes = block_size // 2
    word_bits = 8 * word_bytes
    key_words = [
        int.from_bytes(bytes(key[i:i + word_bytes]).ljust(word_bytes, b'\x00'), 'big')
        for i in range(0, len(key), word_bytes)
    ]
    # Expand the key once per message instead of once per block.
    round_keys = _speck_round_keys(key_words, 22, word_bits)
    pt = pt + bytes((-len(pt)) % block_size)
    out = bytearray()
    for i in range(0, len(pt), block_size):
        x = int.from_bytes(pt[i:i + word_bytes], 'big')
        y = int.from_bytes(pt[i + word_bytes:i + block_size], 'big')
        x, y = _speck_apply_rounds(x, y, round_keys, word_bits)
        out.extend(x.to_bytes(word_bytes, 'big'))
        out.extend(y.to_bytes(word_bytes, 'big'))
    return bytes(out)

def vigenere_encrypt(pt, key):
    """Byte-wise Vigenere: add the repeating key to the plaintext modulo 256."""
    shifted = bytearray()
    for pos, byte in enumerate(pt):
        # The key wraps around, so position `pos` uses key byte `pos mod len(key)`.
        shifted.append((byte + key[pos % len(key)]) % 256)
    return bytes(shifted)

# Supported ciphers. build_dataset_part enumerates this list, so each name's
# position here is the integer class label stored for its samples.
CIPHER_NAMES = ['AES','DES','Speck','Vigenere']

def generate_example_by_name(name):
    """Encrypt a fresh random plaintext under a fresh random key with the named cipher.

    Args:
        name: 'AES', 'DES', 'Speck' or 'Vigenere'.

    Returns:
        The ciphertext bytes produced by the matching ``*_encrypt`` function.

    Raises:
        ValueError: if ``name`` is none of the supported ciphers.
    """
    plaintext = random_plaintext()
    secret = random_key_for(name)
    if name == 'AES':
        return aes_encrypt(plaintext, secret)
    if name == 'DES':
        return des_encrypt(plaintext, secret)
    if name == 'Speck':
        return speck_encrypt(plaintext, secret, block_size=4)
    if name == 'Vigenere':
        return vigenere_encrypt(plaintext, secret)
    raise ValueError(name)

### Dataset Generation and Dataloaders

In [None]:
# dataset generation
def _byte_histogram(byte_values, hist_bins):
    """Return the normalised histogram of a uint8 array using `hist_bins` equal-width bins."""
    factor = 256 // hist_bins
    # Bin index of each byte; indices past `hist_bins` (only possible when
    # hist_bins does not divide 256) are dropped.
    counts = np.bincount(byte_values // factor, minlength=hist_bins)[:hist_bins]
    hist = counts.astype(np.float32)
    hist /= (hist.sum() + 1e-12)  # epsilon keeps an empty ciphertext from dividing by zero
    return hist


def build_dataset_part(samples_per_class, max_len, hist_bins, filename):
    """Build (or load from cache) one labelled ciphertext dataset part.

    For every cipher in CIPHER_NAMES, ``samples_per_class`` ciphertexts are
    generated, truncated or zero-padded to ``max_len`` bytes, and paired with
    a normalised byte histogram. The histogram is computed on the real
    ciphertext bytes only; it used to include the zero padding, which
    swamped bin 0 and mostly encoded the ciphertext length.

    Args:
        samples_per_class: number of samples to generate per cipher.
        max_len: fixed length of every stored byte sequence.
        hist_bins: number of histogram bins (1..256).
        filename: cache file. It is loaded when present and its shapes match
            ``max_len``/``hist_bins``; otherwise the data is generated and
            saved there.

    Returns:
        tuple: ``(X, y, H)`` where X is a long tensor of shape (N, max_len),
        y a long tensor of shape (N,) and H a float32 tensor of shape
        (N, hist_bins), with N = samples_per_class * len(CIPHER_NAMES) when
        freshly generated.

    Raises:
        ValueError: if ``hist_bins`` is outside 1..256.
    """
    if not 1 <= hist_bins <= 256:
        raise ValueError(f"hist_bins must be between 1 and 256, got {hist_bins}")

    if os.path.exists(filename):
        print(f"Loading existing dataset from {filename}")
        # NOTE: torch.load unpickles the file; only load cache files you created yourself.
        data = torch.load(filename)
        X, y, H = data["X"], data["y"], data["H"]
        if X.shape[1] == max_len and H.shape[1] == hist_bins:
            return X, y, H
        # A cache written with other settings would only fail later inside the model.
        print(f"Cached dataset in {filename} does not match max_len={max_len}, hist_bins={hist_bins}; regenerating")

    X_list, y_list, H_list = [], [], []
    for label, name in enumerate(CIPHER_NAMES):
        print(f"Generating {samples_per_class} samples for {name} for file {filename}...")
        for _ in tqdm(range(samples_per_class)):
            ct = generate_example_by_name(name)[:max_len]
            raw = np.frombuffer(ct, dtype=np.uint8)
            # Histogram feature to help with identifying Vigenere (unpadded bytes only).
            hist = _byte_histogram(raw, hist_bins)
            # Right-pad with zero bytes up to the fixed sequence length.
            arr = np.zeros(max_len, dtype=np.uint8)
            arr[:len(raw)] = raw
            X_list.append(arr)
            y_list.append(label)
            H_list.append(hist)

    X = torch.tensor(np.stack(X_list), dtype=torch.long)
    y = torch.tensor(np.array(y_list), dtype=torch.long)
    H = torch.tensor(np.stack(H_list), dtype=torch.float32)
    torch.save({"X": X, "y": y, "H": H}, filename)
    print("Saved dataset to", filename)
    return X, y, H

# PyTorch Dataset and Dataloader Creation
class CipherDataset(Dataset):
    """Index-aligned view over ciphertext tokens ``X``, labels ``y`` and histograms ``H``."""

    def __init__(self, X, y, H):
        self.X = X
        self.y = y
        self.H = H

    def __len__(self):
        """Number of samples, taken from ``X``."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the ``(tokens, label, histogram)`` triple stored at ``idx``."""
        tokens = self.X[idx]
        label = self.y[idx]
        histogram = self.H[idx]
        return tokens, label, histogram

def create_dataloaders(X, y, H, batch_size, num_workers):
    """Randomly split the samples 80/10/10 and wrap each split in a DataLoader.

    Returns:
        tuple: ``(train_dl, val_dl, test_dl)``; only the training loader shuffles.
    """
    full_ds = CipherDataset(X, y, H)
    size = len(full_ds)
    train_size = int(0.8 * size)
    val_size = int(0.1 * size)
    # Whatever rounding leaves over goes to the test split.
    split_sizes = [train_size, val_size, size - train_size - val_size]
    train_ds, val_ds, test_ds = torch.utils.data.random_split(full_ds, split_sizes)

    loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers)
    train_dl = DataLoader(train_ds, shuffle=True, **loader_kwargs)
    val_dl = DataLoader(val_ds, shuffle=False, **loader_kwargs)
    test_dl = DataLoader(test_ds, shuffle=False, **loader_kwargs)
    return train_dl, val_dl, test_dl

### Model Definition

In [None]:
#model
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to a (batch, seq, d_model) tensor.

    Even feature columns hold sines and odd columns cosines of
    ``position * 10000^(-2i/d_model)``. The table is stored as the
    non-trainable buffer ``pe`` of shape (1, max_len, d_model).

    Args:
        d_model: feature width; odd values are supported (the last sine
            column then has no cosine partner).
        max_len: longest sequence length that can be encoded.
    """
    def __init__(self, d_model, max_len=512):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1).float()
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = pos * div
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer cosine column than sine columns.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encoding of its first ``x.size(1)`` positions.

        Raises:
            ValueError: if the sequence is longer than the precomputed table.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds positional table length {self.pe.size(1)}"
            )
        return x + self.pe[:, :seq_len]

class EnergyTransformer(nn.Module):
    """Transformer encoder that scores every (input, class) pair with an energy.

    A learned [CLS] vector is prepended to the embedded token sequence, the
    sequence is encoded, and the [CLS] output is concatenated with each class
    embedding and the sample's byte histogram before an MLP maps it to one
    scalar energy per class. Lower energy means a better match, so
    ``predict_logits`` returns the negated energies.

    Args:
        n_tokens: size of the token vocabulary.
        d_model: embedding / encoder width.
        nhead: attention heads per encoder layer.
        nlayers: number of encoder layers.
        dim_feedforward: hidden width of each encoder layer's feed-forward block.
        n_classes: number of classes (rows in the label embedding).
        label_emb_dim: width of each class embedding.
        hist_dim: length of the histogram vector passed to ``forward``.
        dropout: dropout probability for the encoder and the energy MLP.
        max_seq_len: longest token sequence (without [CLS]) the positional
            table must cover. When omitted, the module-level ``max_len``
            setting is used, which is what this class always did implicitly.
    """
    def __init__(self, n_tokens=256, d_model=256, nhead=8, nlayers=6, dim_feedforward=512, n_classes=4, label_emb_dim=64, hist_dim=64, dropout=0.2, max_seq_len=None):
        super().__init__()
        # Backward compatible fallback to the notebook-level `max_len` global.
        seq_cap = max_len if max_seq_len is None else max_seq_len
        self.embed = nn.Embedding(n_tokens, d_model)
        self.cls_token = nn.Parameter(torch.randn(1,1,d_model))
        # +1 leaves room for the prepended [CLS] position.
        self.pos = PositionalEncoding(d_model, seq_cap + 1)
        layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, activation='gelu', dropout=dropout, batch_first=True)
        self.trans = nn.TransformerEncoder(layer, nlayers)
        self.label_emb = nn.Embedding(n_classes, label_emb_dim)
        input_dim = d_model + label_emb_dim + hist_dim
        self.energy_net = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(512, 128),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(128, 1)
        )

    def forward(self, x, hist):
        """Return the energies of shape (B, n_classes).

        Args:
            x: integer token ids of shape (B, L).
            hist: per-sample histogram features of shape (B, hist_dim).
        """
        B = x.size(0)
        x = self.embed(x)
        cls = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls, x], dim=1)
        x = self.pos(x)
        x = self.trans(x)
        cls_rep = x[:, 0, :]
        # Embeddings of every class, so each sample is scored against all labels at once.
        labels = self.label_emb(torch.arange(self.label_emb.num_embeddings, device=cls_rep.device))
        rep_exp = cls_rep.unsqueeze(1).expand(-1, labels.size(0), -1)
        lab_exp = labels.unsqueeze(0).expand(B, -1, -1)
        hist_exp = hist.unsqueeze(1).expand(-1, labels.size(0), -1)
        cat = torch.cat([rep_exp, lab_exp, hist_exp], dim=-1)
        energies = self.energy_net(cat).squeeze(-1)
        return energies

    def predict_logits(self, x, hist):
        """Return classification logits, i.e. the negated energies from ``forward``."""
        return -self.forward(x, hist)

### Training Helpers

In [None]:
# Training helpers
def train_epoch(model, dl, opt, scaler, device):
    """Train ``model`` for one pass over ``dl`` and return ``(mean_loss, accuracy)``.

    Args:
        model: module exposing ``predict_logits(tokens, hist)``.
        dl: iterable of ``(tokens, labels, hist)`` batches.
        opt: optimizer updating the model parameters.
        scaler: gradient scaler used around backward/step.
        device: device the batches are moved to.

    Returns:
        tuple: per-sample mean cross-entropy loss and accuracy over the epoch.

    Raises:
        ValueError: if ``dl`` yields no samples.
    """
    model.train()
    # Mixed precision only applies on CUDA. Enabling it conditionally through
    # torch.autocast replaces the deprecated torch.cuda.amp.autocast() and
    # avoids its "CUDA is not available" warning on CPU-only runs.
    use_amp = torch.device(device).type == 'cuda'
    total, correct, loss_sum = 0, 0, 0
    for xb, yb, hist in tqdm(dl, desc="Training"):
        xb, yb, hist = xb.to(device), yb.to(device), hist.to(device)
        opt.zero_grad()
        with torch.autocast(device_type='cuda', enabled=use_amp):
            logits = model.predict_logits(xb, hist)
            loss = F.cross_entropy(logits, yb)
        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()
        # Weight by batch size so the epoch mean is exact with a short last batch.
        loss_sum += loss.item() * xb.size(0)
        correct += (logits.argmax(1) == yb).sum().item()
        total += xb.size(0)
    if total == 0:
        raise ValueError("train_epoch received an empty dataloader")
    return loss_sum/total, correct/total

@torch.no_grad()
def eval_model(model, dl, device):
    """Evaluate ``model`` on ``dl`` without gradients; return ``(mean_loss, accuracy)``.

    Each batch is a ``(tokens, labels, hist)`` triple. The loss is the
    per-sample mean cross-entropy over the whole loader.
    """
    model.eval()
    n_seen, n_correct, weighted_loss = 0, 0, 0
    for batch in dl:
        tokens, targets, hists = (part.to(device) for part in batch)
        scores = model.predict_logits(tokens, hists)
        batch_n = tokens.size(0)
        # Weight each batch's mean loss by its size to get an exact overall mean.
        weighted_loss += F.cross_entropy(scores, targets).item() * batch_n
        n_correct += (scores.argmax(1) == targets).sum().item()
        n_seen += batch_n
    return weighted_loss / n_seen, n_correct / n_seen

### Main Training Execution

In [None]:
# Main Training Execution
# Trains one model across `num_parts` dataset parts, `epochs_per_part` epochs each,
# keeping the checkpoint with the best validation accuracy and stopping early
# after `patience` epochs without improvement.
model = EnergyTransformer(
    n_tokens=257, # 256 byte values + 1 spare index (padding currently reuses byte 0)
    d_model=d_model, nhead=nhead, nlayers=nlayers,
    dim_feedforward=512, n_classes=n_classes, label_emb_dim=64,
    hist_dim=hist_bins, dropout=dropout
).to(device)

opt = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
# `verbose=True` is deprecated in the LR schedulers (and rejected by newer PyTorch),
# so learning-rate reductions are logged manually in the loop below.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode='max', factor=0.5, patience=scheduler_patience)
# Gradient scaling is only meaningful for CUDA mixed precision.
scaler = GradScaler(enabled=(device.type == 'cuda'))

best_val = 0
patience_counter = 0
train_losses, val_accs = [], []
global_epoch = 0
stop_training = False

for part in range(1, num_parts + 1):
    if stop_training:
        break

    print(f"\n{'='*20} TRAINING PART {part}/{num_parts} {'='*20}")

    dataset_file = f"cipher_dataset_{max_len}len_part_{part}.pt"
    X, y, H = build_dataset_part(samples_per_class, max_len, hist_bins, dataset_file)
    train_dl, val_dl, test_dl = create_dataloaders(X, y, H, batch_size, num_workers)

    for epoch in range(1, epochs_per_part + 1):
        global_epoch += 1
        print(f"\n--- Part {part}, Epoch {epoch}/{epochs_per_part} (Global Epoch {global_epoch}) ---")

        tr_loss, tr_acc = train_epoch(model, train_dl, opt, scaler, device)
        val_loss, val_acc = eval_model(model, val_dl, device)

        lr_before = opt.param_groups[0]['lr']
        scheduler.step(val_acc)
        lr_after = opt.param_groups[0]['lr']
        if lr_after != lr_before:
            print(f"Reducing learning rate: {lr_before:.2e} -> {lr_after:.2e}")

        train_losses.append(tr_loss)
        val_accs.append(val_acc)

        # Update the best score before logging so "Best" includes this epoch.
        improved = val_acc > best_val
        if improved:
            best_val = val_acc
            torch.save(model.state_dict(), "best_energy_transformer.pt")
            patience_counter = 0
        else:
            patience_counter += 1

        print(f"Global Epoch {global_epoch:02d} | Train Loss {tr_loss:.4f} | Train Acc {tr_acc:.3f} | Val Acc {val_acc:.3f} | Best {best_val:.3f}")

        if not improved and patience_counter >= patience:
            print(f"Early stopping triggered after {global_epoch} epochs.")
            stop_training = True
            break

### Final Test and Plotting

In [None]:
# -----------------------
# Final Test + Plot
# -----------------------
# Reload the best checkpoint and score it on the test split left over from
# the last dataset part that was trained on.
print("\n--- Final Evaluation on Test Set ---")
best_state = torch.load("best_energy_transformer.pt")
model.load_state_dict(best_state)
test_loss, test_acc = eval_model(model, test_dl, device)
print(f"Test Accuracy: {test_acc:.4f} | Test Loss: {test_loss:.4f}")

# Both curves share the same x axis and grid styling.
epoch_axis = range(1, global_epoch + 1)
grid_style = dict(which='both', linestyle='--', linewidth=0.5)

# Plotting Training Loss
fig1, ax1 = plt.subplots(figsize=(10,6))
ax1.plot(epoch_axis, train_losses, color='tab:red', label='Train Loss')
ax1.set_title("Training Loss over Global Epochs")
ax1.set_xlabel("Global Epoch")
ax1.set_ylabel("Train Loss", color='tab:red')
ax1.tick_params(axis='y', labelcolor='tab:red')
ax1.grid(True, **grid_style)
fig1.tight_layout()
plt.show()

# Plotting Validation Accuracy
fig2, ax2 = plt.subplots(figsize=(10,6))
ax2.plot(epoch_axis, val_accs, color='tab:blue', label='Val Accuracy')
ax2.set_title("Validation Accuracy over Global Epochs")
ax2.set_xlabel("Global Epoch")
ax2.set_ylabel("Val Accuracy", color='tab:blue')
ax2.tick_params(axis='y', labelcolor='tab:blue')
ax2.grid(True, **grid_style)
fig2.tight_layout()
plt.show()