# Enigma with a constant setting on the German news dataset

In [None]:
# add the project's src directories to the import path
import sys
import os

# Append the candidate relative locations of the project's `src` directory
# to the module search path, in this order.
sys.path.extend(
    [
        "./enigma-transformed/src",
        "./src",
        "../src",
        "../../src",
    ]
)

if __name__ == "__main__":
    # Name the log directory after the SLURM job id when the SLURM_JOB_ID
    # environment variable is set; otherwise fall back to "debug".
    # A plain lookup with a default replaces the former bare `except:`,
    # which would also have swallowed unrelated errors.
    job_id = os.environ.get("SLURM_JOB_ID", "debug")
    logdir = f"logs/slurm_{job_id}"
    # Create the directory up front; an already existing one is fine.
    os.makedirs(logdir, exist_ok=True)


## Setup and hyperparameters

In [None]:
from utils import calculate_batch_size

# --- Data (forwarded to load_dataset / the dataset wrappers below) ---
dataset_size = 100000
dataset_min_len = 200
dataset_max_len = 200
dataset_exclude_len = 50  # don't train and eval on sentences shorter than this
seed = 39  # reproducible

# When False, the final evaluation runs on the dev split instead of test.
evaluate_on_test = True
device = 'cuda:0'  # device handed to the evaluation pipeline

# --- Optimisation (forwarded to Seq2SeqTrainingArguments) ---
train_epochs = 40
lr = 2e-3
warmup_ratio = .2

target_batch_size = 160
# Backward-compatible alias for the original, misspelt name.
tartget_batch_size = target_batch_size
# The two results are used as the per-device batch size and the number of
# gradient-accumulation steps.
# NOTE(review): presumably batch_size * grad_acc_steps approximates
# target_batch_size -- confirm in utils.calculate_batch_size.
batch_size, grad_acc_steps = calculate_batch_size(target_batch_size, dataset_max_len)




## Data

In [None]:
# 0. (optional) get data and preprocess it
import os
import src.utils
from src.preprocessing import preprocess_file

data_path = 'news.2012.de.shuffled.deduped'
if not os.path.exists(data_path):
    # `import src.utils` binds only the name `src`, and no import above binds
    # a bare `utils`, so the helper must be reached through the package
    # (the former `utils.download_newscrawl` raised NameError).
    # NOTE(review): assumes the download produces `data_path` -- confirm.
    src.utils.download_newscrawl(2012, 'de')
    # preprocess_file(data_path)

In [None]:
import ByT5Dataset
import torch.utils.data
from preprocessing import load_dataset, preprocess_text

# Load the raw sentences and normalise each one.
dataset = load_dataset(dataset_size, dataset_min_len, dataset_max_len, data_path, seed, dataset_exclude_len)
dataset = [preprocess_text(text) for text in dataset]

# 80/10/10 train/dev/test split. Sizes are derived from the number of examples
# actually loaded (not from the requested `dataset_size`) and the train split
# takes the remainder, so the lengths always sum to len(dataset) as
# random_split requires. With exactly `dataset_size` examples this yields the
# same sizes as before.
n_examples = len(dataset)
n_dev = round(0.1 * n_examples)
n_test = round(0.1 * n_examples)
n_train = n_examples - n_dev - n_test

# Seeded generator keeps the split reproducible.
generator1 = torch.Generator().manual_seed(seed)
train_ex, dev_ex, test_ex = torch.utils.data.random_split(
    dataset,
    [n_train, n_dev, n_test],
    generator=generator1,
)

# NOTE(review): the original calls read `ByT5Dataset.(...)` -- the class name
# was missing, which is a syntax error. `ByT5ConstEnigmaDataset` is assumed
# from the notebook title ("Enigma with a constant setting"); confirm the
# exact class name against the ByT5Dataset module.
train = ByT5Dataset.ByT5ConstEnigmaDataset(train_ex, max_length=dataset_max_len)
dev = ByT5Dataset.ByT5ConstEnigmaDataset(dev_ex, max_length=dataset_max_len)
test = ByT5Dataset.ByT5ConstEnigmaDataset(test_ex, max_length=dataset_max_len)




## Model architecture

In [None]:


# Seq2seq model initialised from the pretrained "google/byt5-small"
# checkpoint, paired with a ByT5 tokenizer built with its default settings.
from transformers import ByT5Tokenizer, AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained("google/byt5-small")
tokenizer = ByT5Tokenizer()




## Training setup

In [None]:


from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
from transformers import (
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

# Collator that batches examples for the seq2seq model.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

training_args = Seq2SeqTrainingArguments(
    # Output locations
    output_dir=f"{logdir}/output",
    logging_dir=logdir,
    push_to_hub=False,
    # Optimisation schedule
    num_train_epochs=train_epochs,
    learning_rate=lr,
    warmup_ratio=warmup_ratio,
    # Batching: gradients are accumulated to simulate a higher batch size
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    gradient_accumulation_steps=grad_acc_steps,
    # Evaluation and checkpointing
    evaluation_strategy="epoch",
    predict_with_generate=True,
    save_steps=10000,
    save_total_limit=0,
)




## Training

In [None]:
# Wire model, data and arguments together; dev is used for the per-epoch
# evaluation configured in the training arguments.
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=tokenizer,
    data_collator=data_collator,
    args=training_args,
    train_dataset=train,
    eval_dataset=dev,
)

# Run training, then store the resulting model under the log directory.
trainer.train()
trainer.save_model(f"{logdir}/model")


## Evaluation

In [None]:

# The evaluation below always reads `test`; when evaluating on the test split
# is disabled, point `test` at the dev split instead.
if not evaluate_on_test:
    test = dev

In [None]:
from utils import levensthein_distance, print_avg_median_mode_error
from transformers import pipeline, logging
# Only show errors from transformers while generating.
logging.set_verbosity(logging.ERROR)


# Edit distance between generated and expected text, one entry per example.
error_counts = []
translate = pipeline("translation", model=model, tokenizer=tokenizer, device=device)
# Loop-invariant generation limit, computed once.
generation_max_length = (dataset_max_len + 1) * 2
for index in range(len(test)):
    # Fetch the example once: every `test[index]` goes through the dataset's
    # item lookup again, so repeated indexing is wasted work and could pair
    # an input with a different expected output if the lookup is not
    # deterministic.
    example = test[index]
    source_text = example["input_text"]
    expected = example["output_text"]

    generated = translate(source_text, max_length=generation_max_length)[0]["translation_text"]
    error_count = levensthein_distance(generated, expected)
    error_counts.append(error_count)

    if error_count > 0:
        print(f"Example {index}, error count {error_count}")
        print("In :", source_text)
        print("Gen:", generated)
        print("Exp:", expected)
    else:
        print(f"Example {index} OK")
    print("-----------------------")

print(f"{error_counts=}")
print_avg_median_mode_error(error_counts)