In [2]:
import torch
import torch.utils.data
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, Trainer, TrainingArguments, AutoConfig, T5Model, T5Tokenizer
import random

In [9]:
# Corpus length statistics: total characters, mean line length, and the fraction
# of lines longer than LONG_LINE_THRESHOLD characters.
# Note: len(line) includes the trailing newline character.
LONG_LINE_THRESHOLD = 200  # characters; compare with max_length=200 used when building the dataset

sumlen = 0
lines = 0
lines_longer_than_threshold = 0
# Read explicitly as UTF-8 so the counts do not depend on the machine's locale.
with open("news.2012.en.shuffled.deduped", "r", encoding="utf-8") as myfile:
    for line in myfile:
        sumlen += len(line)
        lines += 1
        if len(line) > LONG_LINE_THRESHOLD:
            lines_longer_than_threshold += 1

print(sumlen)
# Guard the ratios so an empty file prints 0.0 instead of raising ZeroDivisionError.
print(sumlen / lines if lines else 0.0)
print(lines_longer_than_threshold / lines if lines else 0.0)

1736552179
120.5360691179489
0.1335448846105712


In [17]:
import torch
from torch.utils.data import Dataset
from typing import List, Tuple
from transformers import PreTrainedTokenizerFast

class EnigmaDataset(Dataset):
    """In-memory dataset of (encrypted token ids, plaintext token ids) pairs.

    Each line of ``data_file`` is stripped and passed through ``encrypt_function``.
    Both the ciphertext and the original line are then encoded with ``tokenizer``,
    padded/truncated to ``max_length`` ids.

    Args:
        data_file: path to a UTF-8 text file with one example per line.
        encrypt_function: callable mapping a plaintext string to a ciphertext string.
        tokenizer: Hugging Face tokenizer exposing
            ``encode(text, max_length=..., padding=..., truncation=...)``.
            The annotation is only indicative: the ByT5Tokenizer used in this
            notebook is a slow (non-"Fast") tokenizer.
        max_length: length every encoded sequence is padded/truncated to.
        max_size: maximum number of lines to read; ``0`` gives an empty dataset,
            ``None`` reads the whole file.
    """

    def __init__(self, data_file: str, encrypt_function, tokenizer: "PreTrainedTokenizerFast", max_length: int = 512, max_size: "int | None" = 10000) -> None:
        from itertools import islice

        self.data: List[Tuple[List[int], List[int]]] = []
        self.encrypt_function = encrypt_function
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Explicit UTF-8: ByT5 tokenizes raw bytes, so a locale-dependent decode
        # would silently change the training data on some machines.
        with open(data_file, "r", encoding="utf-8") as file:
            # islice stops after exactly max_size lines (the previous counter was
            # checked after appending, so max_size=0 still produced one example).
            for line in islice(file, max_size):
                text = line.strip()
                encrypted_text = self.encrypt_function(text)
                tokenized_encrypted_text = self._encode(encrypted_text)
                tokenized_text = self._encode(text)
                self.data.append((tokenized_encrypted_text, tokenized_text))

    def _encode(self, text: str) -> List[int]:
        """Encode ``text`` with padding and truncation to ``self.max_length``."""
        return self.tokenizer.encode(text, max_length=self.max_length, padding='max_length', truncation=True)

    def __len__(self) -> int:
        """Number of stored (encrypted, plaintext) pairs."""
        return len(self.data)

    def __getitem__(self, index: int) -> Tuple[List[int], List[int]]:
        """Return the ``(encrypted_ids, plaintext_ids)`` pair at ``index``."""
        return self.data[index]

In [18]:
from typing import List, Dict
from transformers.data.data_collator import DataCollator

class CustomDataCollator(object):
    """Stack pre-padded ``(input_ids, labels)`` pairs into a seq2seq training batch.

    Items must already share one length (EnigmaDataset pads with
    ``padding='max_length'``). Besides stacking, the collator:

    * builds an ``attention_mask`` so the encoder ignores pad positions, and
    * replaces pad ids in ``labels`` with ``label_pad_id`` (-100 by default).
      Hugging Face seq2seq models skip -100 in the loss; leaving the pad id
      there trains the model to predict padding.

    Args:
        pad_token_id: id used for padding by the tokenizer (ByT5 uses 0; pass
            ``tokenizer.pad_token_id`` to be explicit).
        label_pad_id: value written into padded label positions.
    """

    def __init__(self, pad_token_id: int = 0, label_pad_id: int = -100) -> None:
        self.pad_token_id = pad_token_id
        self.label_pad_id = label_pad_id

    def __call__(self, batch: List[Tuple[List[int], List[int]]]) -> Dict[str, torch.Tensor]:
        """Return ``input_ids``, ``attention_mask`` and masked ``labels`` as long tensors."""
        input_ids, labels = zip(*batch)
        input_ids = torch.tensor(input_ids, dtype=torch.long)
        labels = torch.tensor(labels, dtype=torch.long)
        attention_mask = (input_ids != self.pad_token_id).long()
        labels = labels.masked_fill(labels == self.pad_token_id, self.label_pad_id)
        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}


In [13]:
from enigma.machine import EnigmaMachine

def encrypt_all_the_same(text):
    """Encipher ``text`` with one fixed Enigma key, always from the same start position.

    A new machine (rotors I II III, reflector B, ring settings 0 0 0, no
    plugboard) is built on every call, and its display is reset to ``'ABC'``.
    Every line is therefore enciphered from an identical initial state.

    Returns:
        The start display ``'ABC'`` followed by the machine's output for ``text``.

    NOTE(review): py-enigma's ``process_text`` appears to upper-case its input and
    substitute non-letters by default — confirm. If so, the case and punctuation of
    the plaintext labels cannot be recovered from the ciphertext.
    """
    indicator = 'ABC'
    enigma = EnigmaMachine.from_key_sheet(
        rotors='I II III',
        reflector='B',
        ring_settings=[0, 0, 0],
        plugboard_settings=None,
    )
    enigma.set_display(indicator)
    ciphertext = enigma.process_text(text)
    return f"{indicator}{ciphertext}"

In [19]:
from transformers import ByT5Tokenizer
data_file = "news.2012.en.shuffled.deduped"
MAX_LENGTH = 200  # ByT5 token ids per sequence (pad/truncate length)
MAX_SIZE = 10_000  # number of corpus lines to load (EnigmaDataset's default)

tokenizer = ByT5Tokenizer.from_pretrained("google/byt5-small")
enigma_dataset = EnigmaDataset(data_file, encrypt_all_the_same, tokenizer, max_length=MAX_LENGTH, max_size=MAX_SIZE)

# Sanity-check one example pair. Despite the names, both are padded token-id
# lists, not strings; decode them to inspect the text.
encrypted_text, original_text = enigma_dataset[0]
assert len(encrypted_text) == len(original_text) == MAX_LENGTH
(
    tokenizer.decode(encrypted_text, skip_special_tokens=True),
    tokenizer.decode(original_text, skip_special_tokens=True),
)

: 

In [None]:
from torch.utils.data import random_split
from transformers import T5ForConditionalGeneration, T5Config, Seq2SeqTrainingArguments, Seq2SeqTrainer, DataCollatorForSeq2Seq

SEED = 42

# Reproducible 80/10/10 split. The test size is the remainder, so the three
# lengths always sum to len(enigma_dataset); random_split raises otherwise
# (int() truncation of three fractions can lose items).
n_total = len(enigma_dataset)
n_train = int(n_total * 0.8)
n_val = int(n_total * 0.1)
n_test = n_total - n_train - n_val
train_dataset, val_dataset, test = random_split(
    enigma_dataset,
    [n_train, n_val, n_test],
    generator=torch.Generator().manual_seed(SEED),
)

config = T5Config.from_pretrained("google/byt5-small")
config.tie_word_embeddings = False

# NOTE(review): this builds a randomly initialised model from the byt5-small
# config (no pretrained weights). Use T5ForConditionalGeneration.from_pretrained
# if fine-tuning was intended. Seed first so the initialisation is reproducible.
torch.manual_seed(SEED)
model = T5ForConditionalGeneration(config)

# Create training arguments and data collator
training_args = Seq2SeqTrainingArguments(
    output_dir="byt5_output",
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    # Checkpoint once per epoch. With the default 10k-line corpus (8 000 train
    # examples, batch 8, 3 epochs), one device runs ~3 000 steps, so the
    # previous save_steps=10_000 never wrote a checkpoint.
    save_strategy="epoch",
    save_total_limit=2,
    evaluation_strategy="epoch",  # deprecated in favour of `eval_strategy` in newer transformers releases
    seed=SEED,
)

# Emits attention_mask and masks label padding (pad id 0 for ByT5) out of the loss.
data_collator = CustomDataCollator()

# Create trainer and train the model
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=data_collator,
    tokenizer=tokenizer,
)

trainer.train()
# Persist the final weights to output_dir so later cells and sessions can reload them.
trainer.save_model()



In [None]:
# Evaluate on the held-out test split. Trainer.evaluate takes `eval_dataset`
# (there is no `test_dataset` argument); the prefix keeps metric names distinct
# from the validation metrics logged during training.
test_metrics = trainer.evaluate(eval_dataset=test, metric_key_prefix="test")
print(test_metrics)

# Use the model the trainer just trained. Seq2SeqTrainer has no `generate`
# method, and a 3-epoch run on this split never writes "checkpoint-10000".
model = trainer.model
model.eval()

N_EXAMPLES = 10
for i in range(min(N_EXAMPLES, len(test))):
    tokenized_encrypted_text, tokenized_gold_label = test[i]
    input_ids = torch.tensor(tokenized_encrypted_text, dtype=torch.long, device=model.device).unsqueeze(0)
    # Mask encoder padding, consistent with how batches are built for training.
    attention_mask = (input_ids != tokenizer.pad_token_id).long()

    with torch.no_grad():
        # Give generation a length budget matching the labels; otherwise generate()
        # falls back to a short default max length and cuts the decryption off.
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=len(tokenized_gold_label),
        )

    predicted_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    gold_label = tokenizer.decode(tokenized_gold_label, skip_special_tokens=True)

    print(f"Example {i + 1}:")
    print(f"Predicted: {predicted_text}")
    print(f"Gold Label: {gold_label}")
    print("=" * 80)