In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import time
from IPython.display import clear_output
from torch.utils.data import TensorDataset, DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(42)
random.seed(42)

device

device(type='cuda')

# **Utils**

In [2]:
def encrypt(text, shift, alphabet):
    new_alphabet = alphabet[shift % len(alphabet):] + alphabet[:shift % len(alphabet)]
    result = ''
    for char in text:
        if char in alphabet:
            result += new_alphabet[alphabet.index(char)]
        else:
            result += char
    return result

def decrypt(text, shift, alphabet):
    return encrypt(text, -shift, alphabet)

def generate_data(alphabet, max_len=10, shift=3):
    length = random.randint(5, max_len)
    text = ''.join(random.choice(alphabet) for _ in range(length))
    encrypted = encrypt(text, shift, alphabet=alphabet)
    return {'text': text, 'encrypted': encrypted, 'length': length}

def text_to_one_hot(text, alphabet):
    seq_len = len(text)
    one_hot = np.zeros((seq_len, len(alphabet)), dtype=np.float32)
    for i, char in enumerate(text):
        if char in alphabet:
            one_hot[i, alphabet.index(char)] = 1
    return torch.tensor(one_hot)

def dataset_to_tensors(dataset, alphabet):
    max_len = max(pair['length'] for pair in dataset)
    X = torch.zeros((len(dataset), max_len, len(alphabet)), dtype=torch.float32)
    y = torch.zeros((len(dataset), max_len, len(alphabet)), dtype=torch.float32)
    
    for i, pair in enumerate(dataset):
        X[i, :len(pair['encrypted'])] = text_to_one_hot(pair['encrypted'], alphabet)
        y[i, :len(pair['text'])] = text_to_one_hot(pair['text'], alphabet)
    
    return X, y

class CaesarNet(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(CaesarNet, self).__init__()
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.linear = nn.Linear(hidden_size, output_size)
    
    def forward(self, sentence):
        o, _ = self.rnn(sentence)
        return self.linear(o)

    def train_model(self, train_loader, alphabet_size, lr, epochs=10):
        self.to(device)
        
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(self.parameters(), lr=lr, weight_decay=1e-4)
        start = time.time()
        
        for epoch in range(epochs):
            self.train()
            train_loss, train_acc, train_iter_num = 0., 0., 0.
            
            for X_batch, y_batch in train_loader:
                X_batch = X_batch.to(device)
                y_batch = y_batch.to(device)
                
                optimizer.zero_grad()
                output = self(X_batch)
                y_indices = torch.argmax(y_batch, dim=-1)
                
                loss = criterion(output.view(-1, alphabet_size), y_indices.view(-1))
                loss.backward()
                optimizer.step()
                
                train_loss += loss.item()
                batch_acc = (output.argmax(dim=-1) == y_indices).float().mean().item()
                train_acc += batch_acc
                train_iter_num += 1
            
            clear_output(wait=True)
            print(
                f"Epoch: {epoch+1}, loss: {train_loss/train_iter_num:.4f}, acc: "
                f"{train_acc/train_iter_num:.4f}, "
                f"{time.time() - start:.2f} sec."
            )
        
        return self

    def predict(self, text, alphabet):
        self.eval()
        best_shift = 0
        best_score = -1.
        best_pred_text = ""
        
        for shift in range(len(alphabet)):
            encrypted = encrypt(text, shift, alphabet)
            X_test = text_to_one_hot(encrypted, alphabet)[None, :]
            with torch.no_grad():
                X_test = X_test.to(device)
                pred = self(X_test)
            pred_text = ''.join(alphabet[torch.argmax(char).item()] for char in pred[0][:len(text)])
            
            matches = sum(1 for c1, c2 in zip(text, pred_text) if c1 == c2)
            score = matches / len(text) if len(text) > 0 else 0.
            
            if score > best_score:
                best_score = score
                best_shift = shift
                best_pred_text = pred_text
        
        encrypted = encrypt(text, best_shift, alphabet)
        return encrypted, best_pred_text, best_shift

# **Creation Alphabet**

In [3]:
alphabet = 'abcdefghijklmnopqrstuvwxyz' + 'abcdefghijklmnopqrstuvwxyz'.upper() + 'абвгдеёжзийклмнопрстуфхцчшщъыьэюя' + 'абвгдеёжзийклмнопрстуфхцчшщъыьэюя'.upper() + r'0123456789,. ?|/\[]()";:'

# **Variables**

In [4]:
alphabet_size = len(alphabet)
input_size = output_size = alphabet_size
hidden_size = 128
batch_size = 32
learning_rate = .001
num_epochs = 10
shift = random.randint(-3555456654575, 3555456654575) % alphabet_size

In [5]:
shift

88

# **Creation initial dataset**

In [6]:
dataset_init = [generate_data(alphabet, max_len=10, shift=shift) for _ in range(10000)]
X, y = dataset_to_tensors(dataset_init, alphabet)

In [7]:
dataset = TensorDataset(X, y)
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# **Creation Model**

In [8]:
model = CaesarNet(input_size, hidden_size, output_size)
model

CaesarNet(
  (rnn): RNN(142, 128, batch_first=True)
  (linear): Linear(in_features=128, out_features=142, bias=True)
)

# **Training model**

In [9]:
model = model.train_model(train_loader, alphabet_size, learning_rate, num_epochs)

Epoch: 10, loss: 0.0201, acc: 1.0000, 4.46 sec.


# **Experiment 1: with bruteforce shift**

## **Testing Results**

In [10]:
text = "Привет мир, это я DJ 23"
encrypted, predicted, learning_shift = model.predict(text, alphabet)
print(f"Original: {text}")
print(f"Encrypted: {encrypted}")
print(f"Predicted: {predicted}")
print(f"Learning Shift: {learning_shift}")
print(f"Initial Shift: {shift}")

Original: Привет мир, это я DJ 23
Encrypted: VphadrчlhpхчCrnчEчЯ5чно
Predicted: Привет мир, это я DJ 23
Learning Shift: 88
Initial Shift: 88


In [11]:
text = dataset_init[9999]['text']
encrypted, predicted, learning_shift = model.predict(text, alphabet)
print(f"Original: {text}")
print(f"Encrypted: {encrypted}")
print(f"Predicted: {predicted}")
print(f"Learning Shift: {learning_shift}")
print(f"Initial Shift: {shift}")

Original: cV8бVG
Encrypted: Е[у:[2
Predicted: cV8бVG
Learning Shift: 88
Initial Shift: 88


# **Experiment 2: with trainable shift**

In [12]:
class CaesarNetTrainShift(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(CaesarNetTrainShift, self).__init__()
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.linear = nn.Linear(hidden_size, output_size)
        self.shift = nn.Parameter(torch.tensor(float(random.randint(0, input_size - 1))))
        self.alphabet_size = input_size

    def forward(self, sentence):
        normalized_shift = torch.sigmoid(self.shift) * self.alphabet_size
        shift = torch.round(normalized_shift).long() % self.alphabet_size
        shifted_sentence = torch.roll(sentence, shifts=shift.item(), dims=-1)
        o, _ = self.rnn(shifted_sentence)
        return self.linear(o)

    def train_model(self, train_loader, alphabet_size, lr, epochs=10):
        self.to(device)
        
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam([
        {'params': [self.shift], 'lr': .01},
        {'params': [p for p in self.parameters() if p is not self.shift], 'lr': lr}
    ], weight_decay=1e-4)
        start = time.time()
        
        for epoch in range(epochs):
            self.train()
            train_loss, train_acc, train_iter_num = 0., 0., 0.
            
            for X_batch, y_batch in train_loader:
                X_batch = X_batch.to(device)
                y_batch = y_batch.to(device)
                
                optimizer.zero_grad()
                output = self(X_batch)
                y_indices = torch.argmax(y_batch, dim=-1)
                
                loss = criterion(output.view(-1, alphabet_size), y_indices.view(-1)) + .01 * torch.abs(self.shift)
                loss.backward()
                optimizer.step()
                
                train_loss += loss.item()
                batch_acc = (output.argmax(dim=-1) == y_indices).float().mean().item()
                train_acc += batch_acc
                train_iter_num += 1
            
            clear_output(wait=True)
            print(
                f"Epoch: {epoch+1}, loss: {train_loss/train_iter_num:.4f}, acc: "
                f"{train_acc/train_iter_num:.4f},"
                f"{time.time() - start:.2f} sec."
            )
        
        return self

    def predict(self, text, alphabet):
        self.eval()

        normalized_shift = torch.sigmoid(self.shift) * len(alphabet)
        shift = int(round(normalized_shift.item())) % len(alphabet)
        encrypted = encrypt(text=text, shift=shift, alphabet=alphabet)
        X_test = text_to_one_hot(encrypted, alphabet)[None, :]
        with torch.no_grad():
            X_test = X_test.to(device)
            pred = self(X_test)
        pred_text = ''.join(alphabet[torch.argmax(char).item()] for char in pred[0])
        return encrypted, pred_text, shift

In [13]:
model1 = CaesarNetTrainShift(input_size, hidden_size, output_size)
model1

CaesarNetTrainShift(
  (rnn): RNN(142, 128, batch_first=True)
  (linear): Linear(in_features=128, out_features=142, bias=True)
)

In [14]:
model1 = model1.train_model(train_loader, alphabet_size, learning_rate, num_epochs)

Epoch: 10, loss: 0.0202, acc: 1.0000,5.10 sec.


In [15]:
model1.shift

Parameter containing:
tensor(0.0005, device='cuda:0', requires_grad=True)

In [16]:
text = "Привет мир, это я DJ 23"
encrypted, predicted, learning_shift = model1.predict(text, alphabet)
print(f"Original: {text}")
print(f"Encrypted: {encrypted}")
print(f"Predicted: {predicted}")
print(f"Learning Shift: {learning_shift}")
print(f"Initial Shift: {shift}")

Original: Привет мир, это я DJ 23
Encrypted: E;|7,aж]|;ежla)жnжОФжXY
Predicted: яаSLOвЫWSаЩЫмвYЫоЫmsЫСТ
Learning Shift: 71
Initial Shift: 88
