In [None]:
!pip install datasets

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
import psutil
import gc
from model import build_transformer
from dataset import BilingualDataset, causal_mask
from config import get_config, get_weights_file_path

# Huggingface datasets and tokenizers
from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace

import warnings
from tqdm import tqdm
import os
from pathlib import Path
from torch.utils.tensorboard import SummaryWriter

def print_memory_usage():
    """Print the current process's resident set size (RSS), in megabytes."""
    rss_bytes = psutil.Process().memory_info().rss
    rss_mb = rss_bytes / 1024 / 1024
    print(f"RAM Usage: {rss_mb:.2f} MB")

def greedy_decode(model, source, source_mask, tokenizer_src, tokenizer_tgt, max_len, device):
    """Translate ``source`` by repeatedly appending the single most probable next token.

    Args:
        model: object exposing ``encode``, ``decode`` and ``project`` (as called below).
        source: encoder input ids. Only one sample is supported, because the decoder
            input is built with shape (1, 1) and the chosen token is read with ``.item()``.
        source_mask: mask forwarded to ``model.encode`` and ``model.decode``.
        tokenizer_src: unused here; kept for interface compatibility.
        tokenizer_tgt: supplies the ids of ``[SOS]`` and ``[EOS]``.
        max_len: decoding stops once the output holds this many tokens.
        device: device the decoder-side tensors are moved to.

    Returns:
        1-D tensor of token ids. It starts with ``[SOS]`` and ends with ``[EOS]`` if
        ``[EOS]`` was produced before ``max_len`` was reached.
    """
    start_id = tokenizer_tgt.token_to_id('[SOS]')
    end_id = tokenizer_tgt.token_to_id('[EOS]')

    # The encoder output does not depend on the generated prefix, so compute it once.
    memory = model.encode(source, source_mask)
    generated = torch.empty(1, 1).fill_(start_id).type_as(source).to(device)

    while generated.size(1) != max_len:
        # Causal mask so each position only attends to the prefix generated so far.
        step_mask = causal_mask(generated.size(1)).type_as(source_mask).to(device)
        hidden = model.decode(memory, source_mask, generated, step_mask)

        # Score only the last position, then take the argmax (greedy choice).
        logits = model.project(hidden[:, -1])
        best = torch.max(logits, dim=1)[1]

        next_token = torch.empty(1, 1).type_as(source).fill_(best.item()).to(device)
        generated = torch.cat([generated, next_token], dim=1)

        if best == end_id:
            break

    return generated.squeeze(0)

def run_validation(model, validation_ds, tokenizer_src, tokenizer_tgt, max_len, device, print_msg, global_step, writer, num_examples=2):
    """Print greedy translations for the first ``num_examples`` validation samples.

    Each sample is decoded with ``greedy_decode``. The source, reference target and
    prediction are then emitted through ``print_msg``. Decoding runs under
    ``torch.no_grad()`` with the model in eval mode. The model's previous
    train/eval mode is restored on exit, even if decoding raises.

    Args:
        model: model passed to ``greedy_decode``; must support ``.training``,
            ``.eval()`` and ``.train(mode)``.
        validation_ds: iterable of batches with ``'encoder_input'``,
            ``'encoder_mask'``, ``'src_text'`` and ``'tgt_text'`` entries.
        tokenizer_src: forwarded to ``greedy_decode``.
        tokenizer_tgt: used by ``greedy_decode`` and to turn predicted ids into text.
        max_len: maximum decoded length, forwarded to ``greedy_decode``.
        device: device the encoder tensors are moved to.
        print_msg: callable used for every output line (e.g. ``tqdm.write``).
        global_step: currently unused; kept so callers need not change.
        writer: currently unused; kept so callers need not change.
        num_examples: number of samples to print. A value < 1 never matches the
            counter, so every batch in ``validation_ds`` is decoded.

    Raises:
        AssertionError: if a batch's ``encoder_input`` does not have batch size 1
            (only when assertions are enabled).
    """
    # Remember the caller's mode so validation does not leave the model in eval mode.
    was_training = model.training
    model.eval()
    separator = '-' * 80  # fixed console width

    try:
        with torch.no_grad():
            for count, batch in enumerate(validation_ds, start=1):
                encoder_input = batch['encoder_input'].to(device)
                encoder_mask = batch['encoder_mask'].to(device)

                # greedy_decode builds a (1, 1) decoder input, so only batch size 1 works.
                assert encoder_input.size(0) == 1, "Batch size must be 1 for validation"

                model_out = greedy_decode(model, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, max_len, device)

                source_text = batch['src_text'][0]
                target_text = batch['tgt_text'][0]
                model_out_text = tokenizer_tgt.decode(model_out.detach().cpu().numpy())

                # Print the source, target and model output
                print_msg(separator)
                print_msg(f"{f'SOURCE: ':>12}{source_text}")
                print_msg(f"{f'TARGET: ':>12}{target_text}")
                print_msg(f"{f'PREDICTED: ':>12}{model_out_text}")

                if count == num_examples:
                    print_msg(separator)
                    break
    finally:
        model.train(was_training)

def get_all_sentences(ds, lang):
    """Lazily yield the ``lang`` text of each ``item['translation']`` in ``ds``."""
    yield from (item['translation'][lang] for item in ds)

def get_or_build_tokenizer(config, ds, lang):
    """Load the word-level tokenizer for ``lang``, training and saving it first if absent.

    Args:
        config: dict whose ``'tokenizer_file'`` is a format string. Its ``{}``
            placeholder is filled with ``lang`` to form the tokenizer's file path.
        ds: iterable of items providing ``item['translation'][lang]``, used as the
            training corpus when no saved tokenizer exists.
        lang: language key selecting the text side of each translation pair.

    Returns:
        A ``tokenizers.Tokenizer``, either loaded from disk or freshly trained.
    """
    tokenizer_path = Path(config['tokenizer_file'].format(lang))

    if tokenizer_path.exists():
        return Tokenizer.from_file(str(tokenizer_path))

    tokenizer = Tokenizer(WordLevel(unk_token='[UNK]'))
    tokenizer.pre_tokenizer = Whitespace()
    # min_frequency=2: words seen only once are left out of the vocabulary and
    # encode as the unk_token.
    trainer = WordLevelTrainer(special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]"], min_frequency=2)
    tokenizer.train_from_iterator(get_all_sentences(ds, lang), trainer=trainer)

    # Saving into a missing directory fails, so create the parent folder first.
    tokenizer_path.parent.mkdir(parents=True, exist_ok=True)
    tokenizer.save(str(tokenizer_path))
    return tokenizer

def _print_length_stats(header, lengths):
    """Print max, 95th-percentile and mean of ``lengths`` under ``header``."""
    print(header)
    if not lengths:
        # max() and the mean would raise on an empty dataset; report it instead.
        print("  No sequences to analyze")
        return
    ordered = sorted(lengths)
    print(f"  Max length: {ordered[-1]}")
    # int(n * 0.95) is always < n, so this index stays in range.
    print(f"  95th percentile: {ordered[int(len(ordered) * 0.95)]}")
    print(f"  Mean length: {sum(ordered) / len(ordered):.1f}")


def analyze_sequence_lengths(ds_raw, tokenizer_src, tokenizer_tgt, config):
    """Print token-length statistics for both sides of a parallel dataset.

    Used to help with sequence-length debugging and monitoring. Each item is
    tokenized with the matching tokenizer. The max, 95th percentile and mean of
    the resulting id counts are printed for the source and the target language.
    An empty dataset is reported instead of raising.

    Args:
        ds_raw: iterable of items providing ``item['translation'][lang]``.
        tokenizer_src: tokenizer whose ``encode(text).ids`` is used for the source side.
        tokenizer_tgt: tokenizer whose ``encode(text).ids`` is used for the target side.
        config: dict providing the ``'lang_src'`` and ``'lang_tgt'`` keys.
    """
    lang_src = config['lang_src']
    lang_tgt = config['lang_tgt']

    src_lengths = []
    tgt_lengths = []
    for item in ds_raw:
        pair = item['translation']
        src_lengths.append(len(tokenizer_src.encode(pair[lang_src]).ids))
        tgt_lengths.append(len(tokenizer_tgt.encode(pair[lang_tgt]).ids))

    print("\nSequence Length Analysis:")
    _print_length_stats("Source sequences:", src_lengths)
    _print_length_stats("\nTarget sequences:", tgt_lengths)

def get_ds(config):
    """Build the tokenizers and the train/validation DataLoaders for the language pair.

    The steps are:

    1. Load the raw corpus.
    2. Build (or load) one word-level tokenizer per language.
    3. Drop pairs whose combined whitespace word count exceeds 100.
    4. Split the remaining pairs 90 % / 10 % at random.
    5. Wrap both splits in ``BilingualDataset``.

    Args:
        config: dict read for ``'datasource'``, ``'lang_src'``, ``'lang_tgt'``,
            ``'seq_len'``, ``'batch_size'``, and (through
            ``get_or_build_tokenizer``) ``'tokenizer_file'``.

    Returns:
        ``(train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt)``.
        The validation loader uses ``batch_size=1``, which ``run_validation`` requires.
    """
    lang_src = config['lang_src']
    lang_tgt = config['lang_tgt']

    # The dataset only ships a 'train' split, so validation data is carved out below.
    ds_raw = load_dataset(f"{config['datasource']}", f"{lang_src}-{lang_tgt}", split='train')

    # Tokenizers are built on the full corpus, before filtering and splitting.
    tokenizer_src = get_or_build_tokenizer(config, ds_raw, lang_src)
    tokenizer_tgt = get_or_build_tokenizer(config, ds_raw, lang_tgt)

    print("\nAnalyzing sequence lengths before filtering...")
    analyze_sequence_lengths(ds_raw, tokenizer_src, tokenizer_tgt, config)

    total_samples = len(ds_raw)

    def filter_long_sequences(item):
        # Whitespace word counts (not tokenizer ids): keep pairs with at most
        # 100 words combined (the original comment attributes this cap to the paper).
        src_words = len(item['translation'][lang_src].split())
        tgt_words = len(item['translation'][lang_tgt].split())
        return src_words + tgt_words <= 100

    ds_raw = ds_raw.filter(filter_long_sequences)
    filtered_samples = total_samples - len(ds_raw)
    # Avoid dividing by zero on an empty corpus.
    filtered_pct = filtered_samples / total_samples * 100 if total_samples else 0.0
    print(f"\nFiltered {filtered_samples} samples ({filtered_pct:.2f}%) due to length")

    print("\nAnalyzing sequence lengths after filtering...")
    analyze_sequence_lengths(ds_raw, tokenizer_src, tokenizer_tgt, config)

    # Keep 90 % for training and 10 % for validation (random split).
    train_ds_size = int(0.9 * len(ds_raw))
    val_ds_size = len(ds_raw) - train_ds_size
    train_ds_raw, val_ds_raw = random_split(ds_raw, [train_ds_size, val_ds_size])

    train_ds = BilingualDataset(train_ds_raw, tokenizer_src, tokenizer_tgt, lang_src, lang_tgt, config['seq_len'])
    val_ds = BilingualDataset(val_ds_raw, tokenizer_src, tokenizer_tgt, lang_src, lang_tgt, config['seq_len'])

    # Longest tokenized sentence on each side; useful when choosing config['seq_len'].
    max_len_src = 0
    max_len_tgt = 0
    for item in ds_raw:
        src_ids = tokenizer_src.encode(item['translation'][lang_src]).ids
        tgt_ids = tokenizer_tgt.encode(item['translation'][lang_tgt]).ids
        max_len_src = max(max_len_src, len(src_ids))
        max_len_tgt = max(max_len_tgt, len(tgt_ids))

    print(f'Max length of source sentence: {max_len_src}')
    print(f'Max length of target sentence: {max_len_tgt}')

    # Samples are not length-sorted. Every sample is built with the fixed
    # config['seq_len'], presumably padded to it (TODO confirm in BilingualDataset),
    # so length bucketing would not save padding. Reshuffle training data every epoch.
    train_dataloader = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True)
    val_dataloader = DataLoader(val_ds, batch_size=1, shuffle=True)

    return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt

def get_model(config, vocab_src_len, vocab_tgt_len):
    """Instantiate the Transformer through ``build_transformer``.

    ``config['seq_len']`` is passed as both the third and fourth positional
    arguments (presumably the source and target sequence lengths — confirm against
    ``build_transformer``). The model width is ``config['d_model']``.
    """
    seq_len = config['seq_len']
    return build_transformer(
        vocab_src_len,
        vocab_tgt_len,
        seq_len,
        seq_len,
        d_model=config['d_model'],
    )

def train_model(config):
    """Train the translation Transformer described by ``config``.

    Builds the data pipeline and the model, and optionally resumes from the
    checkpoint named by ``config['preload']``. It then trains for epochs
    ``[initial_epoch, config['num_epochs'])``:

    - The training loss is logged to TensorBoard every step.
    - Greedy-decoding samples are printed every 1000 steps.
    - A checkpoint (epoch, model and optimizer state, global_step) is saved after
      every epoch.

    Args:
        config: dict read for ``'model_folder'``, ``'experiment_name'``, ``'lr'``,
            ``'preload'``, ``'num_epochs'``, ``'seq_len'``, plus the keys used by
            ``get_ds``, ``get_model`` and ``get_weights_file_path``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device {device}')

    # Make sure the weights folder exists before any checkpoint is written.
    Path(config['model_folder']).mkdir(parents=True, exist_ok=True)

    print('Loading Dataset....')
    train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
    print(f"Train dataset size: {len(train_dataloader.dataset)}")
    print(f"Validation dataset size: {len(val_dataloader.dataset)}")

    print("Building model....")
    model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)
    print("Model built Successfully")

    # Tensorboard
    writer = SummaryWriter(config["experiment_name"])

    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], eps=1e-9)

    initial_epoch = 0
    global_step = 0
    if config['preload']:
        model_filename = get_weights_file_path(config, config['preload'])
        print(f'Preloading model {model_filename}')
        # map_location lets a checkpoint saved on GPU be resumed on a CPU-only machine.
        state = torch.load(model_filename, map_location=device)
        # Restore the weights too. Resuming a trained optimizer state on freshly
        # initialised weights would silently restart training from scratch.
        model.load_state_dict(state['model_state_dict'])
        optimizer.load_state_dict(state['optimizer_state_dict'])
        initial_epoch = state['epoch'] + 1
        global_step = state['global_step']

    # Labels are target-language ids, so padding is identified via the target tokenizer.
    loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer_tgt.token_to_id('[PAD]'), label_smoothing=0.1).to(device)
    vocab_tgt_size = tokenizer_tgt.get_vocab_size()

    validation_frequency = 1000  # run validation every 1000 optimizer steps

    print("Starting training...")
    try:
        for epoch in range(initial_epoch, config['num_epochs']):
            batch_iterator = tqdm(train_dataloader, desc=f'Processing epoch {epoch:02d}')
            for batch_idx, batch in enumerate(batch_iterator):
                try:
                    if batch_idx % 100 == 0:
                        print_memory_usage()
                        gc.collect()

                    # Ensure training mode on every step (validation switches to eval mode).
                    model.train()

                    encoder_input = batch['encoder_input'].to(device)
                    decoder_input = batch['decoder_input'].to(device)
                    encoder_mask = batch['encoder_mask'].to(device)
                    decoder_mask = batch['decoder_mask'].to(device)

                    encoder_output = model.encode(encoder_input, encoder_mask)
                    decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask)
                    proj_output = model.project(decoder_output)

                    label = batch['label'].to(device)

                    # Flatten to (N, vocab) logits vs (N,) labels for CrossEntropyLoss.
                    loss = loss_fn(proj_output.view(-1, vocab_tgt_size), label.view(-1))
                    batch_iterator.set_postfix({"loss": f"{loss.item():6.3f}"})

                    writer.add_scalar('train loss', loss.item(), global_step)
                    writer.flush()

                    loss.backward()
                    optimizer.step()
                    optimizer.zero_grad()

                    global_step += 1

                    if global_step % validation_frequency == 0:
                        run_validation(model, val_dataloader, tokenizer_src, tokenizer_tgt, config['seq_len'],
                                       device, lambda msg: batch_iterator.write(msg), global_step, writer)

                except Exception as e:
                    # Best-effort: skip the failing batch, but discard any gradients it
                    # had already accumulated so they do not leak into the next step.
                    optimizer.zero_grad()
                    print(f"Error in batch {batch_idx}: {str(e)}")
                    continue

            # Save the model at the end of every epoch
            model_filename = get_weights_file_path(config, f'{epoch:02d}')
            print(f"Saving model to {model_filename}")
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'global_step': global_step
            }, model_filename)
    finally:
        # Flush pending TensorBoard events and release the event file.
        writer.close()

def get_weights_file_path(config, epoch: str):
    """Return the checkpoint path ``<model_folder>/<model_basename><epoch>.pt`` as a string.

    The path is built relative to the current directory. This definition replaces
    the ``get_weights_file_path`` imported from ``config`` at the top of the file;
    calls made after this ``def`` has run resolve to this version.
    """
    folder = config['model_folder']
    checkpoint_name = f"{config['model_basename']}{epoch}.pt"
    return str(Path('.').joinpath(folder, checkpoint_name))

if __name__ == "__main__":
    # Script entry point: silence warning noise, read the configuration, start training.
    warnings.filterwarnings(action='ignore')
    config = get_config()
    train_model(config)