<a href="https://colab.research.google.com/github/YashashGaurav/poetai/blob/master/experiments/PoetAI_345M_submission.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Instance Checks

In [None]:
# Show the GPU attached to this runtime (model, driver, memory use) before training.
! nvidia-smi

Tue Apr 26 04:38:28 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   37C    P8     9W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

# Load Dependencies

In [None]:
# Install the third-party packages imported later in this notebook.
# NOTE(review): versions are unpinned; the recorded output shows transformers 4.18.0 being resolved.
! pip install transformers[sentencepiece]
# einops provides `rearrange`, used by the attention module.
! pip install einops
# python-Levenshtein provides the edit distance used by the rhyme losses.
! pip install python-Levenshtein
! pip install neptune-client
# deep-phonemizer provides `dp.phonemizer.Phonemizer`.
! pip install deep-phonemizer

Collecting transformers[sentencepiece]
  Downloading transformers-4.18.0-py3-none-any.whl (4.0 MB)
[K     |████████████████████████████████| 4.0 MB 34.5 MB/s 
Collecting sacremoses
  Downloading sacremoses-0.0.49-py3-none-any.whl (895 kB)
[K     |████████████████████████████████| 895 kB 76.3 MB/s 
Collecting tokenizers!=0.11.3,<0.13,>=0.11.1
  Downloading tokenizers-0.12.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.6 MB)
[K     |████████████████████████████████| 6.6 MB 91.7 MB/s 
Collecting pyyaml>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 92.8 MB/s 
[?25hCollecting huggingface-hub<1.0,>=0.1.0
  Downloading huggingface_hub-0.5.1-py3-none-any.whl (77 kB)
[K     |████████████████████████████████| 77 kB 8.7 MB/s 
[?25hCollecting sentencepiece!=0.1.92,>=0.1.91
  Downloading sentencepiece-0.1.96-cp37-cp37m-manylinux_2_17_x86_6

Loading deep phonemizer's model as dependency

In [None]:
# Download the DeepPhonemizer checkpoint into the current working directory.
# NOTE(review): a later cell loads it from '/content/en_us_cmudict_ipa_forward.pt', so this
# presumably assumes the working directory is /content - confirm when running outside Colab.
! curl https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/DeepPhonemizer/en_us_cmudict_ipa_forward.pt --output en_us_cmudict_ipa_forward.pt

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 62.5M  100 62.5M    0     0  9513k      0  0:00:06  0:00:06 --:--:-- 14.3M


# Codebase

In [None]:
import numpy as np
import pandas as pd 

import random
import time
import datetime
import json
import os

import torch
from transformers import (GPT2Tokenizer, 
                          GPT2LMHeadModel, 
                          GPT2Config, 
                          AdamW, 
                          get_linear_schedule_with_warmup)

from torch.utils.data import (Dataset, 
                              random_split,
                              DataLoader,
                              RandomSampler,
                              SequentialSampler)

from einops import rearrange
import math
import torch.nn as nn
import pdb

from Levenshtein import distance as levenshtein_distance

import nltk
from functools import lru_cache
import itertools
from itertools import product as iterprod

import neptune.new as neptune

In [None]:
# Notebook-wide settings; currently only the folder that holds the input CSV.
args = dict(path_to_data_folder='/content/Project/data/')

In [None]:
# Read the limerick CSV (its first column becomes the index) and replace missing cells with ''.
poem_stanza_df = (
    pd.read_csv(
        os.path.join(args['path_to_data_folder'], 'limericks_clean_with_@and#.csv'),
        index_col=0,
    )
    .fillna('')
)

In [None]:
# Preview the first ten rows of the loaded corpus.
poem_stanza_df.head(10)

Unnamed: 0,limerick
0,capn jack was washed over the side@\nhis crew ...
1,as a soup bisque is best when served hot@\nmad...
2,simply add to the grasp of a rhesus@\nthe anti...
3,abeds where you sleep in the night@\nunless yo...
4,a smiling young fellow from spain@\nfell aslee...
5,the man who becomes alcoholic@\nis not on a pe...
6,its in castles that monarchs reside@\nthick st...
7,configuration is called absolute@\nwhen a mole...
8,according to my recollection@\nthe buoy was me...
9,can you cure my addiction please doc@\ni drink...


In [None]:
# Seed applied to the Python, NumPy and PyTorch RNGs in a later cell.
RANDOM_SEED = 73
# Number of limericks per batch in both dataloaders.
BATCH_SIZE = 1
# Token length used for tokenizer.model_max_length, dataset padding and generation max_length.
MAX_LEN = 64

The code below for loading the dataset, creating the dataloaders, running the GPT-2 model, and generating samples is adapted from "Generating an Edgar Allan Poe-Styled Poem Using GPT-2": https://scottmduda.medium.com/generating-an-edgar-allen-poe-styled-poem-using-gpt-2-289801ded82c

In [None]:
# Load the pretrained GPT-2 medium tokenizer and set its nominal max length to MAX_LEN.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
tokenizer.model_max_length = MAX_LEN
# Register the newline as its own (non-special) token.
# NOTE(review): later helpers hard-code ids 50257 (newline) and 50259 (<EOS>); those ids
# presumably follow from the order of the add_tokens / add_special_tokens calls here - confirm.
tokenizer.add_tokens('\n')

# Register begin-of-sequence, end-of-sequence and padding markers as special tokens.
special_tokens_dict = {'bos_token': '<BOS>', 'eos_token': '<EOS>', 'pad_token': '<PAD>'}
num_added_tokens = tokenizer.add_special_tokens(special_tokens_dict)

Downloading:   0%|          | 0.00/0.99M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/718 [00:00<?, ?B/s]

In [None]:
# Character length of every limerick, longest first (used to eyeball a size cut-off).
len_poem = sorted((len(poem) for poem in poem_stanza_df['limerick']), reverse=True)

In [None]:
# Keep only limericks shorter than 256 characters.
is_short_enough = poem_stanza_df['limerick'].apply(lambda text: len(text) < 256)
poem_stanza_df_size_limit = poem_stanza_df[is_short_enough]

In [None]:
# Renumber the surviving rows from 0 and keep only the text column.
poem_stanza_df_size_limit = poem_stanza_df_size_limit.reset_index(drop=True)[['limerick']]

In [None]:
# Preview the last ten rows of the length-filtered corpus.
poem_stanza_df_size_limit.tail(10)

Unnamed: 0,limerick
82897,almug the very same thing@\nas algum two words...
82898,cutis vera its part of the skin@\nthat covers ...
82899,a prisoner locked in a cell@\nfor a pet has a ...
82900,in biblical studies id dabble@\nand thats wher...
82901,darwins theory some doctrine still mocks@\nman...
82902,the storys in front of our noses@\nin the bulr...
82903,understanding the bible is hard@\ntake the cas...
82904,diverticula making you sick you@\nmay need a r...
82905,un ballo in maschera what@\nis the opera about...
82906,i said joe daddy thinks that youre drony@\nand...


In [None]:
# Confirm the tokenizer's max length was overridden with MAX_LEN.
print(tokenizer.model_max_length)

64


In [None]:
# Total vocabulary size including the tokens added above. len(tokenizer) counts the
# added tokens, so no hand-maintained "+ 4" is needed if the set of added tokens changes.
vocab_size = len(tokenizer)
print(vocab_size)

50261


In [None]:
class PoemDataset(Dataset):
    """Pre-tokenised poems, stored as (input_ids, attention_mask) tensor pairs.

    Each text in ``data`` is wrapped as ``'<BOS>' + text + '<EOS>'`` and passed to
    ``tokenizer`` with ``padding='max_length'`` and ``truncation=False``.

    Args:
        data: iterable of poem strings.
        tokenizer: callable returning a mapping with 'input_ids' and 'attention_mask'.
        gpt2_type: accepted but not used by this class.
        max_length: forwarded to the tokenizer as ``max_length``.

    NOTE(review): with truncation disabled, texts that encode to more than
    ``max_length`` tokens are presumably stored unshortened - confirm.
    """

    def __init__(self, data, tokenizer, gpt2_type='gpt2', max_length=MAX_LEN):
        self.tokenizer = tokenizer
        self.input_ids = []
        self.attn_masks = []

        for poem in data:
            wrapped_poem = '<BOS>' + poem + '<EOS>'
            encoded = tokenizer(wrapped_poem,
                                padding='max_length',
                                max_length=max_length,
                                truncation=False)

            ids_tensor = torch.tensor(encoded['input_ids'])
            mask_tensor = torch.tensor(encoded['attention_mask'])
            self.input_ids.append(ids_tensor)
            self.attn_masks.append(mask_tensor)

    def __len__(self):
        """Number of encoded poems."""
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return the (input_ids, attention_mask) pair stored at ``idx``."""
        ids_tensor = self.input_ids[idx]
        mask_tensor = self.attn_masks[idx]
        return ids_tensor, mask_tensor

In [None]:
# Tokenise every length-filtered limerick up front; each item is an (input_ids, attention_mask) pair.
poem_stanza_dataset = PoemDataset(poem_stanza_df_size_limit['limerick'].values, tokenizer, max_length=MAX_LEN)

In [None]:
def train_val_split(split, dataset):
    """Return (train_size, val_size) for splitting ``dataset`` by the fraction ``split``.

    The training size is ``int(split * len(dataset))`` (rounded towards zero); the
    validation size is whatever remains, so the two always add up to ``len(dataset)``.
    """
    total = len(dataset)
    n_train = int(split * total)
    return n_train, total - n_train

In [None]:
# A split fraction of 1 assigns every sample to training, so the validation size is 0.
poem_stanza_train_size, poem_stanza_val_size = train_val_split(1, poem_stanza_dataset)

# random_split is imported from torch.utils.data
poem_stanza_train_dataset, poem_stanza_val_dataset = random_split(poem_stanza_dataset, [poem_stanza_train_size, poem_stanza_val_size])

In [None]:
# Seed every RNG in use (CUDA, Python, NumPy, PyTorch CPU) with RANDOM_SEED.
# NOTE(review): in file order this cell comes after the random_split cell, so that
# split is presumably not covered by the seed - confirm the intended execution order.
torch.cuda.manual_seed_all(RANDOM_SEED)
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

<torch._C.Generator at 0x7f50a4842f50>

In [None]:
# Training batches are drawn in random order.
poem_stanza_train_dataloader = DataLoader(poem_stanza_train_dataset,
                              sampler=RandomSampler(poem_stanza_train_dataset),
                              batch_size=BATCH_SIZE)

# Validation batches are read in dataset order.
poem_stanza_val_dataloader = DataLoader(poem_stanza_val_dataset,
                            sampler=SequentialSampler(poem_stanza_val_dataset),
                            batch_size=BATCH_SIZE)

In [None]:
# helper function for logging time
def format_time(elapsed):
    """Render a duration in seconds as text, rounded to whole seconds (e.g. '1:01:01')."""
    whole_seconds = int(round(elapsed))
    duration = datetime.timedelta(seconds=whole_seconds)
    return str(duration)

# Pick the GPU when one is available; this module-level `device` is also read by the
# attention/LSTM head and the loss helpers defined further down.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# create text generation seed prompt: the encoded '<BOS>' marker with a leading batch dimension
prompt = "<BOS>"
generated = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)
generated = generated.to(device)

The code below, which calculates the context loss using a self-attention LSTM, is adapted from "SP-GPT2: Semantics Improvement in Vietnamese
Poetry Generation": https://github.com/fsoft-ailab/Poem-Generator (https://arxiv.org/pdf/2110.15723v1.pdf)

In [None]:
class ScaleDotProductAttention(nn.Module):
    """Scaled dot-product attention over 4-D (batch, head, length, dim) tensors."""

    def __init__(self):
        super(ScaleDotProductAttention, self).__init__()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, q, k, v, mask=None, e=1e-12):
        """Attend ``q`` over ``k`` and return the weighted ``v`` plus the weights.

        Args:
            q, k, v: tensors indexed as (batch, head, length, dim).
            mask: optional tensor broadcastable to the (batch, head, length, length)
                scores; positions where ``mask == 0`` receive no attention.
            e: kept only for backward compatibility; no longer used.

        Returns:
            (output, attention) where ``attention`` sums to 1 over its last axis.
        """
        d_tensor = k.size(-1)

        # Similarity of every query position with every key position, scaled by sqrt(dim).
        score = torch.einsum("bhid,bhjd->bhij", q, k) / math.sqrt(d_tensor)

        if mask is not None:
            # Masked positions must get a very large negative score so the softmax
            # drives them to ~0. The previous fill value (-1e-12) was effectively zero
            # and left masked positions fully attended.
            score = score.masked_fill(mask == 0, torch.finfo(score.dtype).min)

        score = self.softmax(score)

        v = score @ v

        return v, score

class MultiHeadAttention(nn.Module):
    """Multi-head self-attention in which every head works at the full ``d_model`` width.

    The query/key/value projections map ``d_model`` to ``d_model * n_head`` features,
    which are split into ``n_head`` heads; the heads are then concatenated and
    projected back to ``d_model``.
    """

    def __init__(self, d_model, n_head):
        super().__init__()
        self.n_head = n_head
        self.attention = ScaleDotProductAttention()
        widened = d_model * n_head
        self.w_q = nn.Linear(d_model, widened)
        self.w_k = nn.Linear(d_model, widened)
        self.w_v = nn.Linear(d_model, widened)
        self.w_concat = nn.Linear(widened, d_model)

    def _split_heads(self, projected):
        """Reshape (batch, length, head*dim) into (batch, head, length, dim)."""
        return rearrange(projected, 'b n (h d) -> b h n d', h=self.n_head)

    def forward(self, x, mask=None):
        """Self-attend over ``x`` and return a tensor with ``d_model`` features per position."""
        queries = self._split_heads(self.w_q(x))
        keys = self._split_heads(self.w_k(x))
        values = self._split_heads(self.w_v(x))

        attended, _weights = self.attention(queries, keys, values, mask=mask)

        # Concatenate the heads again and project back to d_model features.
        merged = rearrange(attended, 'b h n d -> b n (h d)')
        return self.w_concat(merged)

class SelfAttentionLstm(nn.Module):
    """Self-attention followed by an LSTM; summarises a sequence by its last LSTM output.

    Args:
        input_size: feature size of each input position (also the attention d_model).
        hidden_size: LSTM hidden size, i.e. the size of the returned vector.
        num_layers: number of stacked LSTM layers.
        n_head: number of attention heads. This value is now honoured; previously it
            was ignored and 4 heads were always used.
    """

    def __init__(self, input_size, hidden_size, num_layers,n_head):
        super(SelfAttentionLstm, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.multi_attention = MultiHeadAttention(d_model=input_size, n_head=n_head)
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)

    def forward(self, x):
        """Return the LSTM output at the final time step for each batch row."""
        x = self.multi_attention(x)

        # Zero initial states created on the same device/dtype as the input, so the
        # module does not depend on a module-level `device` variable.
        state_shape = (self.num_layers, x.size(0), self.hidden_size)
        h0 = torch.zeros(state_shape, device=x.device, dtype=x.dtype)
        c0 = torch.zeros(state_shape, device=x.device, dtype=x.dtype)

        out, _ = self.lstm(x, (h0, c0))

        # batch_first=True, so axis 1 is time: keep only the last step.
        out = out[: ,-1, : ]
        return out

In [None]:
# Line encoder used by limerick_context_loss to turn a span of GPT-2 hidden states into one vector.
# NOTE(review): the optimizer in this file is built from model.parameters() only, so these
# weights do not appear to be updated during training - confirm that is intended.
head_gpt = SelfAttentionLstm(input_size=1024,hidden_size=800, num_layers=2,n_head=4).to(device)

In [None]:
def get_idx_five_line(lm_logits, newline_token_id=50257, eos_token_id=50259):
    """Locate the [start, end) token spans of the five lines in a predicted limerick.

    Only the first sequence in ``lm_logits`` is inspected. The predicted token at each
    position is the argmax over axis 2. Position 0 opens the first line; every
    transition into or out of a run of newline tokens is recorded, and the first
    end-of-sequence token closes the last line and stops the scan.

    Args:
        lm_logits: logits indexed as (batch, position, vocabulary).
        newline_token_id: id of the line-break token (default matches this notebook's tokenizer).
        eos_token_id: id of the end-of-sequence token (default matches this notebook's tokenizer).

    Returns:
        A list of five ``[start, end]`` pairs when exactly ten boundaries were found,
        otherwise an empty list.
    """
    token = torch.argmax(lm_logits, dim=2)[0].tolist()
    index_token = [0]

    for i in range(1, len(token)):
        if token[i] == eos_token_id:
            index_token.append(i)
            break
        # A boundary is any position where "is a newline" differs from the previous position.
        if (token[i] == newline_token_id) != (token[i - 1] == newline_token_id):
            index_token.append(i)

    if len(index_token) != 10:
        return []

    return [[index_token[i], index_token[i + 1]] for i in range(0, 10, 2)]

In [None]:
def limerick_context_loss(lm_logits, embedding, default_loss):
    """Context loss for one sample: how far lines 1-4 are from line 5 in encoder space.

    A batch axis is added to ``lm_logits`` and ``embedding``. The five line spans come
    from ``get_idx_five_line``; each span of ``embedding`` is encoded by ``head_gpt``
    and the result is the sum of the MSE between each of the first four line vectors
    and the fifth.

    Args:
        lm_logits: logits of one sample, without a batch axis.
        embedding: hidden states of the same sample, without a batch axis.
        default_loss: accepted but not used.

    Returns:
        The summed MSE tensor, or the integer -1 when five lines were not found.
    """
    line_spans = get_idx_five_line(torch.unsqueeze(lm_logits, 0))
    embedding = torch.unsqueeze(embedding, 0)

    if len(line_spans) != 5:
        return -1

    mse = nn.MSELoss().to(device)

    # embedding[batch_size, token sequence, embedding]; each span is [start, end).
    line_vectors = [head_gpt(embedding[:, span[0]: span[1], :]) for span in line_spans]
    last_line_vector = line_vectors[4]

    total_lost = mse(line_vectors[0], last_line_vector)
    for line_vector in line_vectors[1:4]:
        total_lost = total_lost + mse(line_vector, last_line_vector)

    return total_lost

Rhyming loss calculation

In [None]:
from dp.phonemizer import Phonemizer

# Load the DeepPhonemizer checkpoint downloaded earlier; `phonemizer` is used by the rhyme losses.
phonemizer = Phonemizer.from_checkpoint('/content/en_us_cmudict_ipa_forward.pt')
# Sample call on one sentence; the result is not stored.
phonemizer('Phonemizing an English text is imposimpable!', lang='en_us')

vowels = ['a', 'e', 'i', 'o', 'u']

#made changes to avoid two letter words ending with None
def break_word(word):
  """Return the tail of ``word`` starting at its last vowel that is not the final letter.

  The final character is never used as the split point. If no earlier vowel exists
  (including for an empty or single-character word) the whole word is returned.

  The final letter is excluded by position. The earlier character comparison
  (``c != word[-1]``) also skipped every vowel equal to the last letter, so
  "banana" came back whole instead of "ana".
  """
  # Walk from the end; i is the distance from the last character.
  for i, c in enumerate(word[::-1]):
    if i > 0 and c in vowels:
      return word[len(word)-i-1:]
  return word

def rhyming_pair_loss(word_1, word_2):
    """Penalty for two words that are supposed to rhyme.

    The word endings from ``break_word`` are phonemized and compared with the
    Levenshtein distance ``d``; the result is ``2 * (sigmoid(d) - 0.5)``, which is 0
    when the phoneme strings are identical and approaches 1 as they diverge.
    """
    phonemes_1 = phonemizer(break_word(word_1), lang='en_us')
    phonemes_2 = phonemizer(break_word(word_2), lang='en_us')
    edit_distance = levenshtein_distance(phonemes_1, phonemes_2)

    squashed = 1/(1+np.exp(-edit_distance))
    return 2*(squashed-0.5)

#added eps to avoid divide by zero error
def non_rhyming_pair_loss(word_1, word_2, eps = 1e-9):
    """Penalty for two words that are supposed NOT to rhyme.

    Uses the same phoneme Levenshtein distance ``d`` as ``rhyming_pair_loss`` but
    feeds its reciprocal through the sigmoid: ``2 * (sigmoid(1 / (d + eps)) - 0.5)``.
    The penalty is therefore largest when the endings sound identical and shrinks as
    they diverge. ``eps`` keeps the reciprocal finite when ``d`` is 0.
    """
    phonemes_1 = phonemizer(break_word(word_1), lang='en_us')
    phonemes_2 = phonemizer(break_word(word_2), lang='en_us')
    edit_distance = levenshtein_distance(phonemes_1, phonemes_2)

    closeness = 1/(edit_distance+eps)
    squashed = 1/(1+np.exp(-closeness))
    return 2*(squashed-0.5)

def get_line_last_token_id(lm_logits):
    """Collect the predicted token id that ends each line of the first sequence.

    The predicted token per position is the argmax over axis 2. Whenever a newline
    token (id 50257) follows a non-newline token, the preceding token id is recorded;
    at the first <EOS> token (id 50259) the preceding token id is recorded and the
    scan stops.

    Returns:
        List of token ids, one per detected line ending (may be shorter or longer than 5).
    """
    predicted = torch.argmax(lm_logits, dim=2)[0].tolist()
    last_word_token_ids = []

    # Walk over consecutive (previous, current) pairs of predicted ids.
    for previous, current in zip(predicted, predicted[1:]):
        if current == 50259:  # <EOS> token
            last_word_token_ids.append(previous)
            break
        if current == 50257 and previous != 50257:  # \n token
            last_word_token_ids.append(previous)

    return last_word_token_ids

"""
Converted each pair's distance into an array and comparing now with a tensor of zeros
in MSE Loss. 
"""
def limerick_rhyme_loss(lm_logits, embedding, default_loss):
    """Rhyme loss for one sample, following the AABBA limerick scheme.

    The last predicted token of each of the five lines is converted to its token
    string. Pairs that should rhyme (1-2, 1-5, 2-5, 3-4) are scored with
    ``rhyming_pair_loss`` and pairs that should not (1-3, 2-3, 1-4, 2-4, 3-5, 4-5)
    with ``non_rhyming_pair_loss``. Each group of scores is compared against zeros
    with MSE and the two results are added.

    The scores are plain NumPy values computed from argmax token ids, so the returned
    tensor is not connected to ``lm_logits`` through autograd.

    Args:
        lm_logits: logits of one sample, without a batch axis.
        embedding: accepted but not used.
        default_loss: accepted but not used.

    Returns:
        The summed MSE tensor, or the integer -1 when five line endings were not found.
    """
    line_last_token_id = get_line_last_token_id(torch.unsqueeze(lm_logits, 0))

    mse = nn.MSELoss().to(device)

    # in case lines generated are not 5
    if len(line_last_token_id) != 5:
        return -1

    one, two, three, four, five = [
        tokenizer.convert_ids_to_tokens(token_id) for token_id in line_last_token_id
    ]

    # Rhymes
    rhyming_pairs = ((one, two), (one, five), (two, five), (three, four))
    rhyming_scores = torch.as_tensor(
        np.array([rhyming_pair_loss(first, second) for first, second in rhyming_pairs])
    )
    rhyme_loss = mse(rhyming_scores, torch.as_tensor(np.zeros(4)))

    # Non-Rhyme
    non_rhyming_pairs = ((one, three), (two, three), (one, four),
                         (two, four), (three, five), (four, five))
    non_rhyming_scores = torch.as_tensor(
        np.array([non_rhyming_pair_loss(first, second) for first, second in non_rhyming_pairs])
    )
    non_rhyme_loss = mse(non_rhyming_scores, torch.as_tensor(np.zeros(6)))

    return rhyme_loss + non_rhyme_loss

2022-04-26 04:40:46,976.976 DEBUG phonemizer:  Initializing phonemizer with model step 710000


Custom loss aggregation

In [None]:
def get_custom_loss(logits, embeddings, loss):
    """Combine the language-model loss with the context and rhyme losses of a batch.

    Per-sample losses come from ``limerick_context_loss`` and ``limerick_rhyme_loss``,
    both of which return -1 when a sample does not contain five detectable lines.

    * Context: if any sample is invalid the context loss is reported as -1 and the
      total is doubled as a penalty; otherwise the per-sample values are summed and added.
    * Rhyme: valid per-sample values are summed and added; if no sample is valid the
      rhyme loss is reported as -1 and nothing is added.

    The -1 sentinels are checked per sample instead of being summed into the running
    totals, so a failing sample can no longer be hidden by (or subtract from) the
    other samples when the batch size is larger than 1. For a batch of one the result
    matches the previous behaviour.

    Args:
        logits: model logits indexed as (batch, position, vocabulary).
        embeddings: hidden states indexed by batch on axis 0.
        loss: language-model loss tensor. It is updated in place and returned.

    Returns:
        The combined loss tensor.
    """
    batch_count = logits.shape[0]

    context_loss = 0
    context_valid = True
    for i in range(batch_count):
        sample_context = limerick_context_loss(logits[i], embeddings[i], loss)
        if sample_context < 0:
            # Sentinel from the helper: stop at the first malformed sample.
            context_valid = False
            context_loss = -1
            break
        context_loss += sample_context

    rhyme_loss = 0
    rhyme_valid_samples = 0
    for i in range(batch_count):
        sample_rhyme = limerick_rhyme_loss(logits[i], embeddings[i], loss)
        if sample_rhyme < 0:
            continue
        rhyme_loss += sample_rhyme
        rhyme_valid_samples += 1
    rhyme_valid = rhyme_valid_samples > 0 or batch_count == 0
    if not rhyme_valid:
        rhyme_loss = -1

    total_loss = loss

    if rhyme_valid:
        total_loss += rhyme_loss
    if context_valid:
        total_loss += context_loss
    else:
        total_loss *= 2

    print(f'Total Loss: {total_loss} | Context Loss: {context_loss} | Rhyme Loss: {rhyme_loss} ')

    return total_loss

## Training artefacts definition

In [None]:
# Folder that receives checkpoints and the generation log.
training_storage_path = '/content/Project/trainings/'
# Sample limericks are generated and logged every this many iterations.
iteration_step_to_log_generation = 250 
# A checkpoint is written every this many iterations.
iteration_step_to_log_checkpoint = 10000
# Iteration count the training loop and the loggers compare against.
total_iterations = 80000

# hyperparameters
learning_rate = 1e-4
# Epsilon passed to the AdamW optimizer.
eps = 1e-8
# Number of warm-up steps passed to the linear learning-rate schedule.
warmup_steps = 10000

In [None]:
# GPT Configuration
# from_pretrained is a classmethod, so it is called on the class directly. The earlier
# form built a GPT2Config(vocab_size=..., n_positions=...) instance first, but those
# arguments were discarded by from_pretrained and never reached the model.
configuration = GPT2Config.from_pretrained('gpt2-medium', output_hidden_states=True)

# Model Definition
model = GPT2LMHeadModel.from_pretrained('gpt2-medium', config=configuration)
# Grow the embedding matrix to cover the tokens added to the tokenizer.
model.resize_token_embeddings(len(tokenizer))

# Optimizer
optimizer = AdamW(model.parameters(), lr=learning_rate, eps=eps)
# Steps in one pass over the training data (not used by the scheduler below).
total_steps = len(poem_stanza_train_dataloader) * 1

# Scheduler
# NOTE(review): num_training_steps is fixed at 100000 while total_iterations is 80000,
# so the learning rate does not reach zero by the end of training - confirm this is intended.
lr_scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=warmup_steps,
                                            num_training_steps=100000)

model = model.to(device)

Downloading:   0%|          | 0.00/1.42G [00:00<?, ?B/s]



In [None]:
# Request synchronous CUDA kernel launches (useful for locating the failing op when debugging).
# NOTE(review): this is set after tensors were already moved to the GPU earlier in the
# file; it may need to be set before CUDA is first used to take effect - confirm.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

def validate(iteration, model, optimizer, lr_scheduler):
    """Generate and log five sample limericks at logging iterations.

    Runs only when ``iteration`` is a multiple of ``iteration_step_to_log_generation``
    or equals ``total_iterations``. Sampling starts from the module-level ``generated``
    prompt and the results are handed to ``log_generation``.

    The model is switched to eval mode while sampling so that dropout is disabled,
    and its previous train/eval mode is restored afterwards. Before this change the
    samples were drawn with the model still in training mode.

    Args:
        iteration: current training iteration.
        model: model to sample from.
        optimizer: accepted but not used.
        lr_scheduler: accepted but not used.
    """
    is_logging_step = (iteration % iteration_step_to_log_generation == 0
                       or iteration == total_iterations)
    if not is_logging_step:
        return

    was_training = model.training
    model.eval()
    try:
        sample_outputs = model.generate(generated,
                                        do_sample=True,
                                        top_k=50,
                                        max_length=MAX_LEN,
                                        top_p=0.95,
                                        num_return_sequences=5)
    finally:
        # Put the model back in whatever mode the caller had it in.
        model.train(was_training)

    log_generation(sample_outputs, iteration)

## Logging

In [None]:
# Model Logging setup

from os import listdir
from os.path import isfile, join

def log_checkpoint(iteration, model, optimizer, lr_scheduler, metric=None):
    """Save a training checkpoint under ``training_storage_path`` at checkpoint iterations.

    Runs only when ``iteration`` is a multiple of ``iteration_step_to_log_checkpoint``
    or equals ``total_iterations``. The saved state holds ``iteration + 1`` (the
    iteration to resume from) plus the model, optimizer and scheduler state dicts.

    * ``metric is None``: always overwrite ``poet_ai_checkpoint.h5``.
    * ``metric`` given (lower is better): save ``<metric>_checkpoint.h5`` when there
      is no earlier metric checkpoint or when ``metric`` beats the best earlier one,
      and then delete that earlier best file.

    Changes from the earlier metric branch: the first metric checkpoint is now saved
    (it was skipped whenever the folder had no metric files), files whose prefix is
    not a number (such as ``poet_ai_checkpoint.h5``) are ignored instead of raising
    ValueError, and the old best file is removed by its real name rather than a
    re-formatted float.
    """
    is_checkpoint_step = (iteration % iteration_step_to_log_checkpoint == 0
                          or iteration == total_iterations)
    if not is_checkpoint_step:
        return

    state = {
        'iteration': iteration + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'lr_scheduler_state_dict': lr_scheduler.state_dict()
    }

    check_point_dir = training_storage_path
    os.makedirs(check_point_dir, exist_ok=True)

    if metric is None:
        torch.save(state, os.path.join(check_point_dir, "poet_ai_checkpoint.h5"))
        return

    # considering minimization effort
    suffix = "_checkpoint.h5"
    previous_by_metric = {}  # metric value -> file name as found on disk
    for file_name in os.listdir(check_point_dir):
        if not file_name.endswith(suffix):
            continue
        if not os.path.isfile(os.path.join(check_point_dir, file_name)):
            continue
        try:
            previous_by_metric[float(file_name[:-len(suffix)])] = file_name
        except ValueError:
            # Not a metric-named checkpoint (e.g. poet_ai_checkpoint.h5).
            continue

    best_previous = min(previous_by_metric) if previous_by_metric else None
    if best_previous is None or metric < best_previous:
        checkpoint_file_path = os.path.join(check_point_dir, f"{metric}{suffix}")
        torch.save(state, checkpoint_file_path)
        if best_previous is not None:
            old_best_path = os.path.join(check_point_dir, previous_by_metric[best_previous])
            if old_best_path != checkpoint_file_path:
                os.remove(old_best_path)
        

In [None]:
def log_generation(sample_outputs, iter_no):
    """Append decoded sample limericks to ``generation_log.txt`` and echo them to stdout.

    The log file lives under ``training_storage_path``, which is created if missing.
    Each sample is decoded with special tokens skipped and written as ``"<index>: <text>"``.

    Args:
        sample_outputs: iterable of generated token-id sequences.
        iter_no: iteration number written in the section header.
    """
    check_point_dir = training_storage_path
    if not os.path.exists(check_point_dir):
        os.makedirs(check_point_dir)

    log_path = os.path.join(check_point_dir, 'generation_log.txt')
    with open(log_path, 'a') as log_file:
        log_file.write(f"-- Iteration {iter_no} -- \n\n")
        print(f"\n\n -- Iteration {iter_no} --")

        for sample_index, sample_output in enumerate(sample_outputs):
            decoded = tokenizer.decode(sample_output, skip_special_tokens=True)
            log_limerick = "{}: {}\n\n".format(sample_index, decoded)
            log_file.write(log_limerick)
            print(log_limerick)


## Loading Model

In [None]:
# checkpoint_file_path = ""

In [None]:
# # loading models back from repos:
# # assumes model, optimizer and lr_scheduler are already defined.
# def load_logged_model(model, optimizer, lr_scheduler):

#     new_start_iteration = 0
#     if os.path.isfile(checkpoint_file_path):
#         print("=> loading checkpoint '{}'".format(checkpoint_file_path))
#         checkpoint = torch.load(checkpoint_file_path)
#         new_start_iteration = checkpoint['iteration']
#         model.load_state_dict(checkpoint['model_state_dict'])
#         optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
#         lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state_dict'])
#         print("=> loaded checkpoint '{}' (new iteration {})"
#                   .format(checkpoint_file_path, checkpoint['iteration']))
#     else:
#         print("=> no checkpoint found at '{}'".format(checkpoint_file_path))

#     return model, optimizer, new_start_iteration, lr_scheduler

# model, optimizer, train_iterations, lr_scheduler = load_logged_model(model, optimizer, lr_scheduler)

# # assumes everything is on cuda; else use model.to(device)
# model = model.cuda()
# # now individually transfer the optimizer parts...
# for state in optimizer.state.values():
#     for k, v in state.items():
#         if isinstance(v, torch.Tensor):
#             state[k] = v.cuda()

=> loading checkpoint '/content/Project/trainings/POET-12/poet_ai_checkpoint.h5'
=> loaded checkpoint '/content/Project/trainings/POET-12/poet_ai_checkpoint.h5' (new iteration 1251)


# Training

In [None]:
# Iteration counter for the training loop; starts at 1 for a fresh run.
train_iterations = 1

In [None]:
outputs = None

# Train for iterations 1..total_iterations inclusive. One iteration is one batch; the
# outer loop restarts the dataloader whenever an epoch ends before the budget is used.
while (train_iterations <= total_iterations):

    print(f'Iteration {train_iterations} of {total_iterations}')

    total_train_loss = 0
    model.train()

    for step, batch in enumerate(poem_stanza_train_dataloader):

        # Stop inside the epoch as soon as the budget is exhausted. Previously the
        # limit was only checked between epochs, so training ran past total_iterations.
        if train_iterations > total_iterations:
            break

        b_input_ids = batch[0].to(device)
        b_labels = batch[0].to(device)
        b_masks = batch[1].to(device)

        model.zero_grad()

        # One forward pass gives the LM loss, the logits and the hidden states.
        # hidden_states[-1] is the final transformer output, so the second forward
        # pass through model.transformer is no longer needed, and the embeddings now
        # come from the same pass (same dropout mask) as the logits.
        outputs = model(b_input_ids,
                        labels=b_labels,
                        attention_mask=b_masks,
                        token_type_ids=None,
                        output_hidden_states=True)
        embeddings = outputs.hidden_states[-1]

        loss = outputs[0]
        loss = get_custom_loss(outputs.logits, embeddings, loss)

        batch_loss = loss.item()
        total_train_loss += batch_loss

        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        print(f'Iteration number: {train_iterations}')

        validate(iteration=train_iterations, model=model, optimizer=optimizer, lr_scheduler=lr_scheduler)

        log_checkpoint(train_iterations, model, optimizer, lr_scheduler)

        train_iterations += 1
        

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Total Loss: 3.0213677883148193 | Context Loss: 0.0006736295763403177 | Rhyme Loss: 0.0 
Iteration number: 80579
Total Loss: 2.5904297828674316 | Context Loss: 0.0005638297880068421 | Rhyme Loss: 0.0 
Iteration number: 80580
Total Loss: 2.2920279502868652 | Context Loss: 0.0006786675658077002 | Rhyme Loss: 0.0 
Iteration number: 80581
Total Loss: 2.3744921684265137 | Context Loss: 0.0008914298377931118 | Rhyme Loss: 0.0 
Iteration number: 80582
Total Loss: 2.6634607315063477 | Context Loss: 0.0006932662799954414 | Rhyme Loss: 0.0 
Iteration number: 80583
Total Loss: 2.0970983505249023 | Context Loss: 0.0007316371193155646 | Rhyme Loss: 0.0 
Iteration number: 80584
Total Loss: 2.4497735500335693 | Context Loss: 0.0008394810138270259 | Rhyme Loss: 0.0 
Iteration number: 80585
Total Loss: 3.2959063053131104 | Context Loss: 0.0006519206799566746 | Rhyme Loss: 0.0 
Iteration number: 80586
Total Loss: 2.378627300262451 | Context