In [1]:
from pathlib import Path

import pandas as pd
import sentencepiece as spm
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import GPT2Tokenizer, BertTokenizer

from tweaked_model2 import Model, get_preprocessed_example

In [2]:
# WMT14 German-English training pairs, read from a local CSV export.
# lineterminator='\n' tells the parser to split rows only on '\n' -- presumably because
# some fields contain stray '\r' characters (NOTE(review): confirm against the data).
TRAIN_CSV_PATH = 'datasets/wmt14_translate_de-en_train.csv'
train_df = pd.read_csv(TRAIN_CSV_PATH, lineterminator='\n')

In [None]:
# Tokenizer + model setup.
# GPT-2's BPE tokenizer is used with two extra special tokens for sequence boundaries.
# (A SentencePiece tokenizer, 'saved/sp/combined_tokenizer.model', was also tried;
# GPT-2 seemed to do better.)
SEED = 0
torch.manual_seed(SEED)  # make weight init / dropout reproducible across runs

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
num_added_tokens = tokenizer.add_tokens(['[START]', '[END]'], special_tokens=True)
# len(tokenizer) = base vocab + every added token. `tokenizer.vocab_size + num_added_tokens`
# undercounts if the tokens were already present (add_tokens then returns 0 and
# `vocab_size` excludes added tokens), which would make the embedding table too small.
vocab_size = len(tokenizer)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

model = Model(embed_dim=512,
              num_blocks=6,
              num_heads=8,
              ff_dim=2048,
              dropout_rate=0.1,
              vocab_size=vocab_size).to(device)

In [None]:
# Training: one example per step, periodic checkpoint / LR-schedule / CSV log, and a
# "best model" checkpoint chosen by the average loss over a longer window.

# NOTE(review): with the GPT-2 tokenizer, id 0 is an ordinary vocabulary token (GPT-2 has
# no padding token), so ignore_index=0 also drops that token from the loss. Kept as 0 in
# case `get_preprocessed_example` pads with 0 -- confirm, and switch to the real pad id
# (or the default -100) if it does not pad.
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.98), eps=1e-9)
# `verbose=True` is deprecated in recent PyTorch; LR reductions are printed in the loop instead.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, min_lr=1e-6)

# To resume from a checkpoint:
# model.load_state_dict(torch.load('saved/tweaked_model/13/model2.pth'))
model.train()

save_dir = Path('saved/tweaked_working_model/x')
save_dir.mkdir(parents=True, exist_ok=True)  # torch.save / to_csv fail if the directory is missing

num_epochs = 1
scheduler_interval = 1000    # steps between checkpoint, scheduler step, and CSV log
best_loss_interval = 10000   # window (steps) averaged when deciding whether to save best_model.pth
losses = []                  # per-step training loss; row k of the CSV is global step k
lr_decrease_steps = []       # global steps at which the scheduler lowered the LR
best_loss = float('inf')     # best windowed average loss seen so far


def _write_loss_log(csv_path, loss_history, lr_steps, best_avg_loss):
    """Write the training history to `csv_path`.

    Columns are padded with None to a common length so they fit one DataFrame:
    `loss` (one row per global step), `lr_decrease_steps` (global step indices), and
    `best_model_saves_avg_loss` (current best windowed average in row 0, None until a
    best model has been saved).

    NOTE: rewrites the full history on every call, so cost grows with training length.
    """
    n_rows = max(len(loss_history), len(lr_steps), 1)

    def _pad(values):
        return list(values) + [None] * (n_rows - len(values))

    best_column = [best_avg_loss if best_avg_loss < float('inf') else None]
    pd.DataFrame({
        'loss': _pad(loss_history),
        'lr_decrease_steps': _pad(lr_steps),
        'best_model_saves_avg_loss': _pad(best_column),
    }).to_csv(csv_path, index=False)


for epoch in range(num_epochs):
    for i in range(train_df.shape[0]):
        optimizer.zero_grad()

        src, tgt = get_preprocessed_example(i=i, train_df=train_df, sp=tokenizer, device=device)
        # Next-token objective: feed tgt[:-1], score against tgt shifted by one (tgt[1:]).
        logits = model(src, tgt[:-1])
        loss = criterion(logits, tgt[1:])
        loss.backward()

        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # gradient clipping
        optimizer.step()

        losses.append(loss.item())
        # Global step across epochs (i restarts each epoch); equals the CSV row of this loss.
        step = len(losses) - 1

        # Best-model check runs first so the CSV written below reports the current best.
        if step != 0 and step % best_loss_interval == 0:
            window_avg = sum(losses[-best_loss_interval:]) / best_loss_interval
            if window_avg < best_loss:
                best_loss = window_avg
                torch.save(model.state_dict(), save_dir / 'best_model.pth')
                print(f'best avg loss over {best_loss_interval:,} steps:', best_loss)

        if step != 0 and step % scheduler_interval == 0:
            avg_interval_loss = sum(losses[-scheduler_interval:]) / scheduler_interval
            print('step:', step)
            print('avg loss:', avg_interval_loss)
            print('\n')
            torch.save(model.state_dict(), save_dir / 'model.pth')

            # Step the scheduler before logging so a decrease at this step is in this CSV.
            previous_lr = optimizer.param_groups[0]['lr']
            scheduler.step(avg_interval_loss)
            current_lr = optimizer.param_groups[0]['lr']
            if current_lr < previous_lr:
                lr_decrease_steps.append(step)
                print(f'lr reduced: {previous_lr:.2e} -> {current_lr:.2e}')

            _write_loss_log(save_dir / 'losses_and_lrs.csv', losses, lr_decrease_steps, best_loss)

    print('epoch', epoch, "complete")