# Import Torch functions and tokenizers

In [1]:
import torch
from torch import Tensor, nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torchtext.vocab import build_vocab_from_iterator
from torch.utils.data import dataset
from torch.utils.tensorboard import SummaryWriter

import regex as re
import os
import time
from tqdm import tqdm
import copy
import math

from model import TransformerModel
from utils import preProcessText, getTokenizer,try_gpu , word_piece_decoder, word_piece_encoder
from config import getConfig

import pickle

In [2]:
model_config, app_config = getConfig(small = False)
print(model_config)
print(app_config)

bptt=model_config["bptt"]
device = try_gpu(0)

{'emsize': 300, 'd_hid': 1024, 'nlayers': 6, 'nhead': 6, 'dropout': 0.2, 'bptt': 64}
{'logs': 'tensorboard_logs', 'epochs': 25}


# Preprocessing Text

In [3]:
file_path = 'data/preprocessed_word_piece.txt'
if not os.path.exists(file_path):
    with open('data/ne_dedup.txt', 'r', encoding='utf-8') as f:
        text = f.read()
        print("Preprocessing file")
        text = preProcessText(text,tokenizer_type = 'word_piece')
    with open(file_path, 'w', encoding='utf-8') as f:
        f.write(text)
else:
    print(f"Reading file  : {file_path}")
    with open(file_path, 'r', encoding='utf-8') as f:
        text = f.read()

Preprocessing file


In [5]:
len(text.split('\n'))

341961

In [6]:
train_split = 300_000

train_iter_first = text.split('\n')[:train_split]
test_iter = text.split('\n')[train_split:]

In [7]:
tokenizer,vocab = getTokenizer(tokenizer_type = 'word_piece')

In [8]:
# vocab = tokenizer.get_vocab()
len(vocab)

30000

In [9]:

# Try the encoder and decoder
l = tokenizer.encode(word_piece_encoder('महानायक राजेश हमाल अहिले चलचित्र क्षेत्रमा पातलिए ।')).tokens
l_ = tokenizer.encode(word_piece_encoder('महानायक राजेश हमाल अहिले चलचित्र क्षेत्रमा पातलिए ।')).ids

print("Encoded $tring: ",l)
print("Decoded $tring:",word_piece_decoder(tokenizer.decode(l_)))


Encoded $tring:  ['महान', '##ा', '##यक', 'राजlश', 'हमाल', 'अहिलl', 'चलचितaर', 'कaषlतaरमा', 'पात', '##लिए', '।']
Decoded $tring: महानायक राजेश हमाल अहिले चलचित्र क्षेत्रमा पातलिए ।


In [10]:

# Try the encoder and decoder
l = tokenizer.encode(word_piece_encoder('हातमा त्रिशुल जटा मुकुट शुशोभीत ब्रम्हा उत्पति हुनु ।')).tokens
l_ = tokenizer.encode(word_piece_encoder('हातमा त्रिशुल जटा मुकुट शुशोभीत ब्रम्हा उत्पति हुनु ।')).ids

print("Encoded $tring: ",l)
print("Encoded id$: ",l_)
print("Decoded $tring:",word_piece_decoder(tokenizer.decode(l_)))


Encoded $tring:  ['हातमा', '[UNK]', 'ज', '##टा', '[UNK]', '[UNK]', 'बaरमaहा', 'उतa', '##पति', '[UNK]', '।']
Encoded id$:  [4308, 1, 42, 307, 1, 1, 27723, 475, 826, 1, 77]
Decoded $tring: हातमा जटा ब्रम्हा उत्पति ।


#  some utility functions

In [11]:
def split_list(l):
    splitted_list = []
    z = 0
    for i,idx in enumerate(l):
        if idx == 220:
            splitted_list.append(l[z:i])
            z = i+1
    if z <= len(l)-1:
        splitted_list.append(l[z:])
    return splitted_list

def splits_to_token(splited_list):
    strings = [tokenizer.decode(l) for l in splited_list]
    
    return strings

# print(tokenizer.encode(' ').ids)

word_piece_decoder(tokenizer.decode([978,
 261,
 264,
624,
 261,
 263]))

k = split_list(l_)

In [12]:
splits_to_token(k)

['हातमा जटा बaरमaहा उतaपति ।']

In [14]:
def data_process(raw_text_iter: dataset.IterableDataset) -> Tensor:
    """Converts raw text into a flat Tensor."""
    data = [torch.tensor(tokenizer.encode(word_piece_encoder(item)).ids, dtype=torch.long)
            for item in raw_text_iter]
    return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))


def batchify(data: Tensor, bsz: int) -> Tensor:
    """Divides the data into bsz separate sequences, removing extra elements
    that wouldn't cleanly fit.
    Args:
        data: Tensor, shape [N]
        bsz: int, batch size
    Returns:
        Tensor of shape [N // bsz, bsz]
    """
    seq_len = data.size(0) // bsz
    data = data[:seq_len * bsz]
#     data = data.view(bsz, seq_len).t().contiguous()
    data = data.view(bsz,seq_len).t()
#     return data.to(device)
    return data


seq_length = bptt
import math


def get_batch(source: Tensor, i: int):
    """
    Args:
        source: Tensor, shape [full_seq_len, batch_size]
        i: int
    Returns:
        tuple (data, target), where data has shape [seq_len, batch_size] and
        target has shape [seq_len * batch_size]
    """
    seq_len = min(seq_length, len(source) - 1 - i)
    data = source[i:i+seq_len]
    target = source[i+1:i+1+seq_len].reshape(-1)
    #target = source[i+1:i+1+seq_len]
    return data, target

In [15]:
#Train and Test Split
train_data = data_process(train_iter_first)
test_data = data_process(test_iter)

In [22]:
print(train_data.shape, test_data.shape)

torch.Size([74348215]) torch.Size([11128127])


In [19]:
train_data[:3]

tensor([27894,  3615,  2007])

In [21]:
torch.cuda.empty_cache() 
torch.cuda.memory_allocated() 

0

# Model Definition

In [23]:
batched_train_data = batchify(train_data, bptt).to(device)  # shape [seq_len, batch_size]
batched_test_data = batchify(test_data, bptt).to(device)

In [24]:
def get_model(model_config, ntokens):
    emsize = model_config["emsize"]
    d_hid = model_config["d_hid"]
    nlayers = model_config["nlayers"]
    nhead = model_config["nhead"]
    dropout = model_config["dropout"]
    model = TransformerModel(ntokens, emsize,nhead, d_hid, nlayers, dropout)
    return model

In [25]:
ntokens = len(vocab)
print(f"No. of tokens : {ntokens}")
model = get_model(model_config, ntokens).to(device)
torch.cuda.memory_allocated()

No. of tokens : 30000




787643392

In [27]:
seq_length = bptt

# Hyper-Parameter Tuning

In [28]:
criterion = nn.CrossEntropyLoss()
lr = 1  # learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)
softmax = nn.Softmax(dim=2)
#softmax = nn.LogSoftmax(dim=2)

  from .autonotebook import tqdm as notebook_tqdm


In [29]:
def generate_square_subsequent_mask(sz: int) -> Tensor:
    """Generates an upper-triangular matrix of -inf, with zeros on diag."""
    return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)

In [30]:
def train(model: nn.Module) -> None:
    global epoch
    global global_step
    model.train()  # turn on train mode
    total_loss = 0.
    src_mask = generate_square_subsequent_mask(bptt).to(device)

    num_batches = len(batched_train_data) // bptt
    progress_bar = tqdm(enumerate(range(0, batched_train_data.size(0) - 1, bptt)), total=num_batches, desc=f'Epoch {epoch}', ncols=80)
    for batch_idx, i in progress_bar:
        ### batch_idx -> (1, 2, 3, 4, ...)
        ### i -> (0, bptt, 2*bptt, ....)
        data, targets = get_batch(batched_train_data, i)
        batch_size = data.size(0)
        if batch_size != bptt:  # only on last batch
            src_mask = src_mask[:batch_size, :batch_size]
        output = model(data, src_mask)
        loss = criterion(output.view(-1, ntokens), targets)
        
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        total_loss += loss.item()

        ## calculate the postfix description for the progress bar
        cur_loss = total_loss / (batch_idx + 1)
        ppl = math.exp(cur_loss)
        
        progress_bar.set_postfix({"loss": cur_loss, "ppl" : ppl}, refresh=True)
        
        writer.add_scalar('loss/train loss', cur_loss, global_step)
        writer.flush()
        writer.add_scalar('ppl/train perplexity', ppl, global_step)
        writer.flush()
        global_step += 1

In [31]:
def evaluate(model: nn.Module, eval_data: Tensor) -> float:
    model.eval()  # turn on evaluation mode
    total_loss = 0.
    src_mask = generate_square_subsequent_mask(bptt).to(device)

    num_batches = len(eval_data) // bptt
    with torch.no_grad():
        progress_bar = tqdm(enumerate(range(0, eval_data.size(0) - 1, bptt)), total=num_batches, desc=f'Validation {epoch}', ncols=80)
        for batch_idx, i in progress_bar:
            data, targets = get_batch(eval_data, i)
            batch_size = data.size(0)
            if batch_size != bptt:
                src_mask = src_mask[:batch_size, :batch_size]
            output = model(data, src_mask)
            output_softmax = softmax(output)
            output_softmax_permuted = output_softmax.permute(1, 0, 2)
            indices = torch.argmax(output_softmax_permuted, dim=2)
            target_indices = targets.t()
            output_flat = output.view(-1, ntokens)
            total_loss += batch_size * criterion(output_flat, targets).item()
    
    eval_loss = total_loss / (len(eval_data) - 1)
    eval_ppl = math.exp(eval_loss)

    writer.add_scalar('loss/val loss', eval_loss, global_step)
    writer.flush()
    writer.add_scalar('ppl/val perplexity', eval_ppl, global_step)
    writer.flush()

    return eval_loss

# Training Data

In [32]:
best_model_path = 'models/best_model_wp.pt'

In [36]:
# Loop over epochs. Save the model if the validation loss is the best
# we've seen so far. Adjust the learning rate after each epoch.
best_val_loss = float('inf')
initial_epoch = 0
epochs = app_config["epochs"]
global_step = 0
best_model = None

# preload the model if exists to train more epochs

if os.path.exists(best_model_path):
    print(f"Preloading model {best_model_path}")
    state = torch.load(best_model_path)
    
    initial_epoch = state['epoch'] + 1
    model.load_state_dict(state['model_state_dict'])
    optimizer.load_state_dict(state['optimizer_state_dict'])
    global_step = state['global_step']
    best_val_loss = state['best_val_loss']
    
    print(initial_epoch, global_step, best_val_loss)

# initializing the tensorbaord log writer
writer = SummaryWriter(app_config["logs"])


for epoch in range(initial_epoch, epochs):
    train(model)
    eval_loss = evaluate(model, batched_test_data)

    # save the model if validation loss decreases

    if eval_loss < best_val_loss:
        print(f"eval perplexity : {math.exp(eval_loss)}")
        print("saving the model")
        best_val_loss = eval_loss
        best_model = copy.deepcopy(model)

        directory_path = 'models'
        # Create the directory if it doesn't exist
        if not os.path.exists(directory_path):
            os.makedirs(directory_path)
        torch.save({
                'epoch': epoch,
                'model_state_dict': best_model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'global_step': global_step, 
                'best_val_loss' : best_val_loss,
            }, os.path.join(directory_path, 'best_model_wp.pt'))

Epoch 0: 18152it [28:43, 10.54it/s, loss=6.2, ppl=494]                          
Validation 0: 2717it [01:40, 27.14it/s]                                         


eval perplexity : 255.52019156968703
saving the model


Epoch 1: 18152it [28:44, 10.53it/s, loss=5.46, ppl=234]                         
Validation 1: 2717it [01:40, 27.06it/s]                                         


eval perplexity : 181.48158158143409
saving the model


Epoch 2: 18152it [28:33, 10.59it/s, loss=5.22, ppl=185]                         
Validation 2: 2717it [01:40, 26.94it/s]                                         


eval perplexity : 154.73720022563174
saving the model


Epoch 3: 18152it [28:33, 10.59it/s, loss=5.09, ppl=163]                         
Validation 3: 2717it [01:40, 27.06it/s]                                         


eval perplexity : 140.20334244818596
saving the model


Epoch 4: 18152it [28:30, 10.61it/s, loss=5.01, ppl=149]                         
Validation 4: 2717it [01:39, 27.44it/s]                                         


eval perplexity : 131.45470255401509
saving the model


Epoch 5: 18152it [28:08, 10.75it/s, loss=4.95, ppl=141]                         
Validation 5: 2717it [01:38, 27.46it/s]                                         


eval perplexity : 124.55780365059415
saving the model


Epoch 6: 18152it [28:08, 10.75it/s, loss=4.9, ppl=134]                          
Validation 6: 2717it [01:39, 27.42it/s]                                         


eval perplexity : 119.95044059916711
saving the model


Epoch 7: 18152it [28:40, 10.55it/s, loss=4.86, ppl=129]                         
Validation 7: 2717it [01:42, 26.63it/s]                                         


eval perplexity : 116.32955185860645
saving the model


Epoch 8:   0%|           | 80/18151 [00:07<29:22, 10.25it/s, loss=4.83, ppl=126]


KeyboardInterrupt: 