##### Prerequisites

In [None]:
%%capture 

!pip install torch==1.12.1+cu113
!pip install transformers==4.21.0
!pip install datasets==2.9.0
!pip install wandb==0.13.10

#### Imports 

In [None]:
from transformers import AutoModelForCausalLM
from transformers import TrainingArguments
from transformers import AutoTokenizer
from datasets import load_from_disk
from transformers import Trainer
import transformers 
import datasets 
import logging
import torch
import wandb
import os

In [None]:
# Ask PyTorch to release any cached CUDA memory before the model is loaded further down.
torch.cuda.empty_cache()

##### Set up logging

In [None]:
# Configure the notebook-wide 'sagemaker' logger: accept everything from DEBUG
# upwards and emit it through a plain StreamHandler.
logger = logging.getLogger('sagemaker')
logger.setLevel(logging.DEBUG)
# logging.getLogger returns the same logger object on every call, so re-running
# this cell would otherwise attach another StreamHandler and every message
# would be printed once per attached handler.
if not any(type(handler) is logging.StreamHandler for handler in logger.handlers):
    logger.addHandler(logging.StreamHandler())

##### Log versions of dependencies 

In [None]:
# Record the versions of the key libraries so a run can be reproduced later.
version_messages = (
    f'[Using transformers version: {transformers.__version__}]',
    f'[Using datasets version: {datasets.__version__}]',
    f'[Using torch version: {torch.__version__}]',
    f'[Using wandb version: {wandb.__version__}]',
)
for version_message in version_messages:
    logger.info(version_message)

##### Set up wandb logging

In [None]:
# Authenticate with Weights & Biases without embedding the API key in the notebook.
# wandb.login() picks the key up from the WANDB_API_KEY environment variable or
# from previously stored credentials, and otherwise prompts for it interactively.
# NOTE(security): an API key used to be hardcoded in this cell; treat that key as
# leaked and revoke/rotate it in the W&B account settings.
wandb.login()

In [None]:
# Publish this notebook's absolute path through the WANDB_NOTEBOOK_NAME
# environment variable (resolved against the current working directory).
path = os.path.abspath('01-finetune.ipynb')
os.environ.update({'WANDB_NOTEBOOK_NAME': path})

#### Load dataset

In [None]:
%%time 

# Load the tokenized dataset from disk. The path is relative to the notebook's
# working directory and presumably points at the output of the 01-prepare step.
dataset = load_from_disk('./../01-prepare/data/tokenized')
# Log the dataset's repr as a quick sanity check of what was loaded.
logger.info(dataset)

In [None]:
def custom_data_collator(batch):
    """Collate a list of tokenized examples into stacked ``LongTensor`` batches.

    Args:
        batch: Sequence of examples, each a mapping that provides
            ``'input_ids'``, ``'token_type_ids'`` and ``'labels'``. For a given
            key all examples must have the same length, because the values are
            stacked as-is without any padding.

    Returns:
        dict with the keys ``'input_ids'``, ``'token_type_ids'`` and
        ``'labels'``; each value is a tensor whose first dimension indexes the
        examples of ``batch``.

    Note:
        Only these three keys are returned; no ``'attention_mask'`` is built.
    """
    # batch size for data collation = per_device_train_batch_size * number of GPUs
    def stack_field(key):
        # One LongTensor per example, stacked along a new leading dimension.
        return torch.stack([torch.LongTensor(example[key]) for example in batch])

    return {key: stack_field(key) for key in ('input_ids', 'token_type_ids', 'labels')}

#### Load GPT-Neo Tokenizer 

In [None]:
# Load the tokenizer of the same GPT-Neo 125M checkpoint that the model is loaded from.
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neo-125M')
# Log the tokenizer's repr as loaded, i.e. before any special tokens are added.
logger.info(tokenizer)

In [None]:
# Extra tokens to register with the tokenizer: a beginning-of-sequence marker
# plus the additional speaker / pad / mask tokens.
special_tokens = dict(
    bos_token='<|startoftext|>',
    additional_special_tokens=['<|speaker-1|>', '<|speaker-2|>', '<|pad|>', '<|mask|>'],
)

In [None]:
# Register the custom special tokens with the tokenizer; the return value is
# not needed and is discarded by binding it to `_`.
_ = tokenizer.add_special_tokens(special_tokens)
# Snapshot the vocabulary after the additions; its length is what the model's
# token embeddings are resized to once the model is loaded.
vocab = tokenizer.get_vocab()

In [None]:
# Log the tokenizer again so the effect of adding the special tokens can be inspected.
logger.info(tokenizer)

#### Load GPT-Neo model

In [None]:
%%time

model = AutoModelForCausalLM.from_pretrained('EleutherAI/gpt-neo-125M')
model.resize_token_embeddings(len(vocab))
device = torch.device('cuda')
model.to(device)
logger.info(next(model.parameters()).device)

#### Set up training config

In [None]:
# Tunable values that are handed to TrainingArguments in the next cell.
TRAIN_EPOCHS = 2  # num_train_epochs
TRAIN_BATCH_SIZE = 4  # per_device_train_batch_size
EVAL_BATCH_SIZE = 4  # per_device_eval_batch_size
LOGGING_STEPS = 64  # logging_steps
SAVE_STEPS = 10240  # Reduce it to a smaller value like 512 if you want to save checkpoints
SAVE_TOTAL_LIMIT = 2  # save_total_limit

In [None]:
# Training configuration handed to the Trainer. Keyword arguments are grouped
# by concern; their values come from the constants defined in the previous cell.
training_args = TrainingArguments(
    # Where checkpoints and logs are written.
    output_dir='./model',
    overwrite_output_dir=True,
    logging_dir='logs',
    # Optimisation.
    num_train_epochs=TRAIN_EPOCHS,
    optim='adamw_torch',
    warmup_steps=10,
    weight_decay=0.1,
    # Batch sizes.
    per_device_train_batch_size=TRAIN_BATCH_SIZE,
    per_device_eval_batch_size=EVAL_BATCH_SIZE,
    # Evaluation, checkpointing and reporting cadence.
    evaluation_strategy='epoch',
    save_strategy='steps',
    save_steps=SAVE_STEPS,
    save_total_limit=SAVE_TOTAL_LIMIT,
    logging_steps=LOGGING_STEPS,
    report_to='wandb',
)

#### Train

In [None]:
# Bring together the model, the training configuration, the train/validation
# splits of the dataset and the custom collator.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset['train'],
    eval_dataset=dataset['validation'],
    data_collator=custom_data_collator,
)

In [None]:
%%time 

# Run training with the Trainer built in the previous cell; %%time reports how long the cell took.
trainer.train()

#### Save model 