<a href="https://colab.research.google.com/github/Sangharsh1215/Transformer/blob/main/train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [20]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split

In [21]:
from dataset import BillingualDataset, casual_mask

In [22]:
from model import build_transformer

In [23]:
from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace
from pathlib import Path

In [24]:
def get_all_sentences(ds, lang):
  """Lazily yield the ``lang`` text of every entry in ``ds``.

  Args:
    ds: Iterable of items, each indexable as ``item['translation'][lang]``.
    lang: Key selecting which side of the translation pair to yield.

  Yields:
    The value stored under ``item['translation'][lang]`` for each item,
    in iteration order.
  """
  yield from (entry['translation'][lang] for entry in ds)

In [25]:
def get_tokenizer(config, ds, lang):
  """Return the word-level tokenizer for ``lang``, training it if no saved file exists.

  The tokenizer file path is ``config['tokenizer_file'].format(lang)``. If that
  file exists it is loaded; otherwise a new ``WordLevel`` tokenizer is trained
  on the ``lang`` sentences of ``ds``, saved to that path, and returned.

  Args:
    config: Mapping providing ``'tokenizer_file'``, a format string with one
      positional placeholder that receives ``lang``.
    ds: Iterable of items indexable as ``item['translation'][lang]``; only
      consumed when a tokenizer has to be trained.
    lang: Language key used both for the file name and to pick sentences.

  Returns:
    The loaded or newly trained ``Tokenizer``.
  """
  tokenizer_path = Path(config['tokenizer_file'].format(lang))

  # Fast path: reuse a previously trained tokenizer.
  if tokenizer_path.exists():
    return Tokenizer.from_file(str(tokenizer_path))

  tokenizer = Tokenizer(WordLevel(unk_token='[UNK]'))
  tokenizer.pre_tokenizer = Whitespace()
  trainer = WordLevelTrainer(
      special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]"],
      min_frequency=2,
  )
  tokenizer.train_from_iterator(get_all_sentences(ds, lang), trainer=trainer)

  # Make sure the destination directory exists so saving cannot fail on a
  # missing folder.
  tokenizer_path.parent.mkdir(parents=True, exist_ok=True)
  tokenizer.save(str(tokenizer_path))

  # The original only returned from the "load" branch, so a freshly trained
  # tokenizer was silently dropped and the caller received None.
  return tokenizer

In [26]:
def get_ds(config):
  """Load the translation dataset and build train/validation dataloaders.

  Steps:
    1. Load the ``train`` split of ``opus_books`` for the configured
       ``lang_src``-``lang_tgt`` pair.
    2. Get (load or train) one tokenizer per language via ``get_tokenizer``.
    3. Randomly split the raw data 90% / 10% into train and validation.
    4. Wrap both parts in ``BillingualDataset`` and then in ``DataLoader``.
    5. Print the longest tokenized source and target sentence lengths.

  Args:
    config: Mapping providing ``'lang_src'``, ``'lang_tgt'``, ``'seq_len'``,
      ``'batch_size'`` and whatever ``get_tokenizer`` reads from it.

  Returns:
    Tuple ``(train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt)``.
  """
  lang_src = config['lang_src']
  lang_tgt = config['lang_tgt']

  # The keyword accepted by load_dataset is `split` (singular).
  ds_raw = load_dataset('opus_books', f'{lang_src}-{lang_tgt}', split='train')

  tokenizer_src = get_tokenizer(config, ds_raw, lang_src)
  tokenizer_tgt = get_tokenizer(config, ds_raw, lang_tgt)

  # 90% train / 10% validation. The validation size is the remainder so the
  # two lengths always sum to len(ds_raw), as random_split requires.
  train_ds_size = int(0.9 * len(ds_raw))
  val_ds_size = len(ds_raw) - train_ds_size
  train_ds_raw, val_ds_raw = random_split(ds_raw, [train_ds_size, val_ds_size])

  train_ds = BillingualDataset(train_ds_raw, tokenizer_src, tokenizer_tgt, lang_src, lang_tgt, config['seq_len'])
  val_ds = BillingualDataset(val_ds_raw, tokenizer_src, tokenizer_tgt, lang_src, lang_tgt, config['seq_len'])

  # Report the longest tokenized sentence on each side; each language is
  # measured with its own tokenizer.
  max_len_src = 0
  max_len_tgt = 0
  for item in ds_raw:
    pair = item['translation']
    src_ids = tokenizer_src.encode(pair[lang_src]).ids
    tgt_ids = tokenizer_tgt.encode(pair[lang_tgt]).ids
    max_len_src = max(max_len_src, len(src_ids))
    max_len_tgt = max(max_len_tgt, len(tgt_ids))

  print(f'max len of source sent: {max_len_src}')
  print(f'max len of target sent: {max_len_tgt}')

  train_dataloader = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True)
  # Validation is served one sentence pair at a time.
  val_dataloader = DataLoader(val_ds, batch_size=1, shuffle=True)

  return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt


In [27]:
def get_model(config, vocab_src_len, vocab_tgt_len):
    """Construct the transformer through ``build_transformer``.

    Args:
        config: Mapping providing ``'seq_len'`` and ``'d_model'``.
        vocab_src_len: Passed as the first positional argument.
        vocab_tgt_len: Passed as the second positional argument.

    Returns:
        Whatever ``build_transformer`` returns.
    """
    # The same configured length fills the third and fourth arguments
    # (presumably source and target sequence lengths; confirm in model.py).
    seq_len = config['seq_len']
    return build_transformer(vocab_src_len, vocab_tgt_len, seq_len, seq_len, config['d_model'])

In [29]:
from torch.utils.tensorboard import SummaryWriter

In [30]:
from config import get_weights_file_path, get_config


In [None]:
from tqdm import tqdm

In [None]:
def train_model(config):
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 
    print(f'using device {device}')

    Path(config['nodel_folder']).mkdir(parents=True, exist_ok=True)

    train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
    model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)
    writer = Summarywriter(config['experiment_name'])

    optimizer = torch.optim.Adam(model.parameters(),lr=config['lr'], eps=9)

    initial_epoch = 0
    global_step = 0
    if config['preload']:
        model_filename = get_weights_file_path(config, config['preload'])
        print(f'Preloading model {model_filename}')
        state = torch.load(model_filename)
        initial_epoch = state['epoch'] + 1
        optimizer.load_state_dict(state['optimizer_state_dict'])
        global_step = state['global_step']


    loss_ffn = nn.CrossEntropyLoss(ignore_index = tokenizer_src.token_to_id('[PAD]'), label_smoothing = 0.1).to(device)
    for epoch in range(initial_epoch, config['num_epochs']):
        model.train()
        batch_iterator = tqdm(train_dataloader, desc = f'Processing epoch {epoch:02d}')
        for batch in batch_iterator:

            encoder_input = batch['encoder_input'].to(device)
            decoder_input = batch['dncoder_input'].to(device)
            encoder_mask = batch['encoder_mask'].to(device)
            decoder_input = batch['decoder_mask'].to(device)



            encoder_output = model.encode(encoder_input,encoder_mask)