Import dependencies

In [None]:
import os
import pickle

import torch as t
import transformers

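Global devices: outputs use the CPU, and models run on the GPU (CUDA) when one is available, otherwise on the CPU.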

In [None]:
# Device used for outputs: always the CPU.
output_device = t.device('cpu')
# Device the models run on: CUDA when present, CPU otherwise.
model_run_device = t.device('cuda' if t.cuda.is_available() else 'cpu')

Set up trainers. Adjust batch size based on VRAM availability.

In [None]:
def trainer_inator(file_name, tokens):
    """Build a Hugging Face ``Trainer`` for full fine-tuning of ``t5-base``.

    Args:
        file_name: Short dataset tag (e.g. ``"cnn"``); it is embedded in the
            checkpoint directory name so each dataset gets its own folder.
        tokens: Mapping whose ``'train'`` and ``'test'`` entries are passed to
            the trainer as the training and evaluation datasets.

    Returns:
        A ``transformers.Trainer`` whose model has been moved to the
        notebook-wide ``model_run_device``.
    """
    # Derive the precision/device flags from the same global that decides
    # where the model lives. Hard-coding fp16=True / use_cpu=False makes
    # TrainingArguments reject the configuration on a machine without CUDA,
    # even though model_run_device already falls back to the CPU.
    on_gpu = model_run_device.type == 'cuda'

    model = transformers.T5ForConditionalGeneration.from_pretrained('t5-base')
    training_args = transformers.TrainingArguments(
        output_dir=f'./checkpoints/full-fine-tuning-{file_name}-t5-base',
        eval_strategy="epoch",
        learning_rate=1e-4,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        num_train_epochs=3,
        weight_decay=0.01,
        logging_dir="./logs",
        log_level="info",
        save_total_limit=1,
        overwrite_output_dir=True,
        disable_tqdm=False,
        use_cpu=not on_gpu,
        fp16=on_gpu
    )
    trainer = transformers.Trainer(
        model=model,
        args=training_args,
        train_dataset=tokens['train'],
        eval_dataset=tokens['test']
    )
    # Explicitly place the model on the notebook-wide device.
    trainer.model.to(model_run_device)
    return trainer

Load each of the tokenized datasets (CNN, SAMSum, Mixed) produced by preprocessing and call `trainer_inator` to build a trainer for each.

In [None]:
# Load the pickled CNN token data written by preprocessing, then build its trainer.
with open('./preprocessing/cnn_tokens.pickle', mode="rb") as handle:
    cnn_tokens = pickle.load(handle)

cnn_FFT_trainer = trainer_inator("cnn", cnn_tokens)

In [None]:
# Load the pickled SAMSum token data written by preprocessing, then build its trainer.
with open('./preprocessing/samsum_tokens.pickle', mode="rb") as handle:
    samsum_tokens = pickle.load(handle)

samsum_FFT_trainer = trainer_inator("samsum", samsum_tokens)

In [None]:
# Load the pickled Mixed token data written by preprocessing, then build its trainer.
with open('./preprocessing/mixed_tokens.pickle', mode="rb") as handle:
    mixed_tokens = pickle.load(handle)

mixed_FFT_trainer = trainer_inator("mixed", mixed_tokens)

Train the models!!!

In [None]:
# Create the output folder before the long training run: open(..., "wb") does
# not create missing directories, and failing after training wastes the run.
os.makedirs("./models", exist_ok=True)

cnn_FFT_trainer.train()

# NOTE(review): this pickles the entire Trainer object; that can be fragile
# across library versions. Consider Trainer.save_model() if downstream code
# only needs the weights. Confirm what the consumers of this file expect.
with open("./models/cnn_FFT_trainer.pickle", "wb") as file:
    pickle.dump(cnn_FFT_trainer, file)

In [None]:
# Create the output folder before the long training run: open(..., "wb") does
# not create missing directories, and failing after training wastes the run.
os.makedirs("./models", exist_ok=True)

samsum_FFT_trainer.train()

# NOTE(review): this pickles the entire Trainer object; that can be fragile
# across library versions. Consider Trainer.save_model() if downstream code
# only needs the weights. Confirm what the consumers of this file expect.
with open("./models/samsum_FFT_trainer.pickle", "wb") as file:
    pickle.dump(samsum_FFT_trainer, file)

In [None]:
# Create the output folder before the long training run: open(..., "wb") does
# not create missing directories, and failing after training wastes the run.
os.makedirs("./models", exist_ok=True)

mixed_FFT_trainer.train()

# NOTE(review): this pickles the entire Trainer object; that can be fragile
# across library versions. Consider Trainer.save_model() if downstream code
# only needs the weights. Confirm what the consumers of this file expect.
with open("./models/mixed_FFT_trainer.pickle", "wb") as file:
    pickle.dump(mixed_FFT_trainer, file)

Also save them.