
# Noisy Vigenère cipher (all possible settings, length 3) on the German news dataset

In [1]:
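# Intended for half precision, but fp16/bf16 are not enabled anywhere in this notebook as written.
# Project modules are imported from the local `src` package.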
import sys
import os
import src.ciphers as ciphers
import src.ByT5Dataset as ByT5Dataset
import src.utils as utils
import src.evaluation as evaluation
import src.preprocessing as preprocessing


if __name__ == "__main__":
    # Key the log directory on the SLURM job id so logs from different SLURM
    # jobs land in separate folders. Outside SLURM the variable is unset and
    # we fall back to "debug" (previously a bare `except:` that would also
    # have swallowed unrelated errors such as KeyboardInterrupt).
    job_id = os.environ.get("SLURM_JOB_ID", "debug")
    logdir = f"logs/slurm_{job_id}"
    # NOTE(review): logdir is only defined when run as __main__, yet the
    # training-args and training cells below use it unconditionally.
    os.makedirs(logdir, exist_ok=True)


## Setup and hyperparameters

In [None]:
from src.utils import calculate_batch_size

# Data: number of examples and their min/max length (as interpreted by
# load_dataset); min == max requests a single fixed length.
dataset_size = 100000
dataset_min_len = 200
dataset_max_len = 200
seed = 39  # shared by load_dataset and the train/dev/test split, for reproducibility
# NOTE(review): evaluate_on_test and device are not used by the code below.
evaluate_on_test = True
device = "cuda:0"

# Schedule (original run: 100k dataset, 40 epochs).
train_epochs = 40
lr = 2e-3
warmup_ratio = 0.2  # fraction of training steps used for LR warmup

# Effective batch size per optimizer step. calculate_batch_size splits it into
# a per-device batch size and a number of gradient-accumulation steps
# (presumably based on sequence length so a batch fits in memory — confirm).
target_batch_size = 192
batch_size, grad_acc_steps = calculate_batch_size(target_batch_size, dataset_max_len)
print(f"batch_size: {batch_size}, grad_acc_steps: {grad_acc_steps}")


## Data

In [None]:
# 0. (optional) get data and preprocess it
import os
import src.utils
from src.preprocessing import preprocess_file

# Fetch and preprocess the NewsCrawl corpus only if it is not on disk yet, so
# re-running this cell is cheap. Year and language are defined once so the
# download call and the expected file name cannot drift apart.
# NOTE(review): assumes download_newscrawl(year, lang) followed by
# preprocess_file(data_path) produces `data_path` — confirm in src.utils /
# src.preprocessing before changing the year or language.
newscrawl_year = 2012
newscrawl_lang = "de"
data_path = f"news.{newscrawl_year}.{newscrawl_lang}.shuffled.deduped"
if not os.path.exists(data_path):
    utils.download_newscrawl(newscrawl_year, newscrawl_lang)
    preprocess_file(data_path)

In [None]:
import src.ByT5Dataset
import torch.utils.data
from src.preprocessing import load_dataset, preprocess_text

dataset = load_dataset(dataset_size, dataset_min_len, dataset_max_len, data_path, seed)
dataset = [preprocess_text(text) for text in dataset]

# 80/10/10 train/dev/test split. Sizes are derived from the number of examples
# actually loaded (in case it differs from dataset_size), and the train part
# absorbs rounding. That way the three sizes always sum to len(dataset), which
# random_split requires; independently rounded parts can miss by one
# (e.g. 5 -> 4 + 0 + 0). For 100k examples this is exactly 80k/10k/10k.
n_examples = len(dataset)
n_dev = n_test = round(0.1 * n_examples)
n_train = n_examples - n_dev - n_test
generator1 = torch.Generator().manual_seed(seed)  # fixed seed => reproducible split
train_ex, dev_ex, test_ex = torch.utils.data.random_split(
    dataset,  # type: ignore
    [n_train, n_dev, n_test],
    generator=generator1,
)
dataset_class = ByT5Dataset.ByT5NoisyVignere3Dataset

train = dataset_class(train_ex, max_length=dataset_max_len)
dev = dataset_class(dev_ex, max_length=dataset_max_len)
test = dataset_class(test_ex, max_length=dataset_max_len)


## Model architecture

In [None]:
# We want a T5 architecture, but severely reduced in size
from transformers import ByT5Tokenizer, AutoModelForSeq2SeqLM
import torch
tokenizer = ByT5Tokenizer()

# Start from the pretrained "google/byt5-small" checkpoint with the library's
# default dtype. Reduced-precision loading (torch_dtype=torch.bfloat16 or
# model.half()) and an explicit model.to(device) are not enabled here;
# device placement is left to the Trainer.
model = AutoModelForSeq2SeqLM.from_pretrained("google/byt5-small")


## Training setup

In [None]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
from transformers import (
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)
from accelerate import Accelerator

# Pads inputs and labels per batch; with the model supplied, the collator can
# also prepare decoder inputs from the labels (per the transformers docs).
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

training_args = Seq2SeqTrainingArguments(
    # Output locations
    output_dir=f"{logdir}/output",
    logging_dir=logdir,
    push_to_hub=False,
    # Optimisation schedule
    num_train_epochs=train_epochs,
    learning_rate=lr,
    warmup_ratio=warmup_ratio,
    # Batching: small per-device batches plus gradient accumulation simulate
    # a larger effective batch size.
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    gradient_accumulation_steps=grad_acc_steps,
    # Evaluation and checkpointing. fp16 and save_total_limit are not set.
    # NOTE(review): newer transformers releases renamed evaluation_strategy to
    # eval_strategy — confirm against the installed version.
    evaluation_strategy="epoch",
    predict_with_generate=True,
    save_steps=500,
)


## Training

In [None]:
import wandb

import os

# The Trainer instantiates its WandbCallback when it is constructed, and the
# callback reads WANDB_LOG_MODEL at that point. The variable therefore has to
# be set *before* Seq2SeqTrainer(...) is built; setting it afterwards (as
# before) was silently ignored. "checkpoint" asks W&B to upload saved
# checkpoints as artifacts.
# NOTE(review): with save_steps=500 over 40 epochs this may upload many large
# checkpoints — confirm that is intended.
os.environ["WANDB_LOG_MODEL"] = "checkpoint"
wandb.init(project="vignere3_noisy_news_de")

# NOTE(review): newer transformers versions prefer processing_class= over
# tokenizer= — confirm against the installed version.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train,
    eval_dataset=dev,
    data_collator=data_collator,
    tokenizer=tokenizer,
)
print('training time')
trainer.train()
trainer.save_model(logdir + "/model")

wandb.finish()