In [None]:
import torch
import torch.nn as nn

from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace
from pathlib import Path
from model import build_transformer

from torch.utils.data import Dataset,DataLoader,random_split
from dataset import BilingualDataset,casual_mask

from torch.utils.tensorboard import SummaryWriter

from config import get_config,get_weights_file_path
from tqdm import tqdm
import random

In [1]:

# creating the tokenizer

def get_all_sentences(ds,lang):
    '''This function takes in the dataset and lang and what it does is iterate the dataset and returns all sentences in one particular language'''
    for item in ds:
        yield item['translation'][lang]

def build_tokenizer(config,ds,lang):
    '''This functions builds and saves the tokenizer if it does not exist, if it exist it just fetches the tokenizer. It returns the tokenizer'''
    tokenizer_path=Path(config['tokenizer_file'].format(lang))
    if not Path.exists(tokenizer_path):
        tokenizer=Tokenizer(WordLevel(unk_token='[UNK]'))
        tokenizer.pre_tokenizer=Whitespace()
        trainer=WordLevelTrainer(special_tokens=["[UNK]","[PAD]","[SOS]","[EOS]"],min_frequency=2)
        tokenizer.train_from_iterator(get_all_sentences(ds,lang),trainer=trainer)
        tokenizer.save(str(tokenizer_path))
    else:
        tokenizer=Tokenizer.from_file(str(tokenizer_path))
    return tokenizer

In [2]:
#loading data\
def get_ds(config):
    '''This function loads the data from hugging face making use of load_dataset. It then builds the tokenizer using the loaded dataset. After which datas
    t is split into training 90% and validation 10% using torch random_split function'''
    ds_raw=load_dataset('cous_books', f'{config['lang_src']}.{config['lang_tgt']}', split='train')

    tokenizer_src=build_tokenizer(config, ds_raw, config['lang_src'])
    tokenizer_tgt=build_tokenizer(config,ds_raw,config['lang_tgt'])

    train_ds_size=int(0.9*len(ds_raw))
    val_ds_size=len(ds_raw)-train_ds_size
    train_ds_raw,val_ds_raw=random.split(ds_raw,[train_ds_size,val_ds_size])


    train_ds=BilingualDataset(train_ds_raw,tokenizer_src,tokenizer_tgt,config['lang_src'],config['lang_tgt'],config['seq_len'])
    val_ds=BilingualDataset(val_ds_raw,tokenizer_src,tokenizer_tgt, config[['lang_src'],config['lang_tgt'],config['seq_len']])

    max_len_src=0
    max_len_tgt=0

    for item in ds_raw:
        src_ids=tokenizer_src.encode(item['translation'][config['lang_src']]).ids
        tgt_ids=tokenizer_src.encode(item['translation'][config['lang_src']]).ids

        max_len_src=max(max_len_src, len(src_ids))
        max_len_tgt=max(max_len_tgt,len(tgt_ids))

    print(f" Maximum length of source sentence is {max_len_src}")
    print(f"Maximum length of target id is {max_len_tgt}")
    #creating dataloaders

    train_dataloader=DataLoader(train_ds,batch_size=config['batch_size'],shuffle=True)
    val_dataloader=DataLoader(val_ds,batch_size=1,shuffle=False)

    return train_dataloader,val_dataloader,tokenizer_src,tokenizer_tgt






SyntaxError: f-string: unmatched '[' (3376763086.py, line 5)

In [3]:
#defining the model
def get_model(config, vocab_src_len,vocab_tgt_len):
    model=build_transformer(vocab_src_len,vocab_tgt_len,config['seq_len'],config['seq_len'],config['d_model'])
    return model


In [None]:
def greedy_decode(model,source,source_mask,tokenizer_src,tokenizer_tgt,max_len,device):
    sos_ids=tokenizer_src.token_to_id('[SOS]')
    eos_ids=tokenizer_src.token_to_id('[EOS]')

    encoder_output=model.encode(source,source_mask)

    decoder_input=torch.empty(1,1).fill_(sos_ids).type_as(source).to(device)
    while True:
        if decoder_input.size(1)==max_len:
            break
        decoder_mask=casual_mask(decoder_input.size(1)).type_as(source).to(device)

        out=model.decode(encoder_output,source_mask,decoder_input,decoder_mask)

        prob=model.project(out[1,-1])
        _,next_word=torch.max(prob,dim=1)

        decoder_input=torch.cat([decoder_input, torch.empty(1,1)].type_as(source).fill_(next_word.item().to(device)),dim=1)

        if next_word==eos_ids:
            break
    return decoder_input.squeeze(0)

In [None]:
#inference mode
def run_validation(model,validation_ds,tokenizer_src,tokenizer_tgt,max_len,device,print_msg,global_state,writer,num_examples):
    model.eval()
    count=0
    source_texts=[]
    expected=[]
    predicted=[]
    console_width=50

    with torch.no_grad():
        for batch in validation_ds:
            count+=1
            encoder_input=batch['encoder_input'].to(device)
            encoder_mask=batch['encoder_mask'].to(device)

            assert encoder_input.size(0)==1

            model_out=greedy_decode(model,encoder_input,encoder_mask,tokenizer_src,tokenizer_tgt,max_len,device)

            source_text=batch['src_text'][0]
            target_text=batch['tgt_text'][0]
            model_out_text=tokenizer_tgt.decode(model_out.detach().cpu().numpy)

            source_text.append(source_text)
            expected.append(target_text)
            predicted.append(model_out_text)

            print_msg('-' console_width)
            print_msg(f'SOURCE: {source_text}')
            print_msg(f"TARGET:{target_text}")
            print_msg(f"Predicted:{model_out_text}")

            if count==num_examples:
                break

            



In [None]:
def train_model(config):
    device='cuda ' if torch.cuda.is_available else 'cpu'
    print(device)

    Path(config['model_folder']).mkdir(parents=True,exist_ok=True)
    train_dataloader,val_dataloader,tokenizer_src,tokenizer_tgt=get_ds(config)
    model=get_model(config, tokenizer_src.get_vocab_size(),tokenizer_tgt.get_vocab_size()).to(device)

    writer=SummaryWriter(config['experiment_name'])

    optimizer=torch.optim.Adam(model.parameters(),lr=config['lr'],eps=1e-9)

    initial_epoch,global_step=0,0

    if config['preload']:
        model_filename=get_weights_file_path(config,config['preload'])
        state=torch.load(model_filename)
        initial_epoch=state['epoch']+1
        optimizer.load_state_dict(state['optimizer_state_dict'])
        global_step=state['global_step']

    loss_fn=nn.CrossEntropyLoss(ignore_index=tokenizer_src.token_to_id('[PAD]'),label_smoothing=0.1).to(device)

    for epoch in range(initial_epoch, config['num_epochs']):
        
        batch_iterator=tqdm(train_dataloader, desc='processing epochs')

        for batch in batch_iterator:
            model.train()
            encoder_input=batch['encoder_input'].to(device)
            decoder_input=batch['decoder_input'].to(device)
            encoder_mask=batch['encoder_mask'].to(device)
            decoder_mask=batch['encoder_mask'].to(device)

            encoder_output=model.encode(encoder_input,encoder_mask)
            decoder_output=model.decode(encoder_output,encoder_mask,decoder_input,decoder_mask)
            proj_output=model.project(decoder_output)

            label=batch['label'].to(device)

            loss=loss_fn(proj_output.view(-1, tokenizer_tgt.get_vocab_size()), label.view(-1))
            batch_iterator.set_postfix({f"loss: f{loss.item():.3f}"})

            writer.add_scaler('train loss',loss.item(),global_step)
            writer.flush()

            loss.backward()

            optimizer.step()
            optimizer.zero_grad()

            global_step+=1

        run_validation(model,val_dataloader,tokenizer_src,tokenizer_tgt,config['seq_len'],device,lambda msg :batch_iterator.write(msg),global_step,writer)


        model_filename=get_weights_file_path(config, f"{epoch:02d}")
        torch.save({
            'epoch':epoch,
            'model_state_dict':model.state_dict(),
            'optimizer_state_dict':optimizer.state_dict(),
            'global_step':global_step

        },model_filename)


if __name__ == '__main__':
    config=get_config()
    train_model()
