<a href="https://colab.research.google.com/github/KaustubhSathe/ai-papers/blob/master/nlp/attention-is-all-you-need/implementation/model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [89]:
!pip install datasets



In [90]:
import torch
import math
import torch.nn as nn

In [91]:
class InputEmbeddings(nn.Module):
    """Token-id embedding table whose output is scaled by sqrt(d_model).

    Args:
        d_model: width of each embedding vector (512 in the paper).
        vocab_size: number of distinct token ids the table can look up.
    """

    def __init__(self, d_model: int, vocab_size: int):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        # Scale the looked-up vectors so their magnitude is comparable to the
        # positional encodings added afterwards.
        scale = math.sqrt(self.d_model)
        embedded = self.embedding(x)
        return embedded * scale

In [92]:
class PositionalEncoding(nn.Module):
    """Add the fixed sinusoidal position encodings of "Attention Is All You Need".

        PE[pos, 2k]   = sin(pos / 10000^(2k / d_model))
        PE[pos, 2k+1] = cos(pos / 10000^(2k / d_model))

    Args:
        d_model: embedding dimension (512 in the paper).
        seq_len: maximum sequence length the table is precomputed for.
        dropout: dropout rate applied after the encodings are added.
    """

    def __init__(self, d_model: int, seq_len: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)

        # Computed in float64 for accuracy at large positions, stored as float32.
        position = torch.arange(seq_len, dtype=torch.float64).unsqueeze(1)  # (seq_len, 1)
        # 1 / 10000^(2k/d_model), evaluated as exp(-ln(10000) * 2k / d_model)
        # so the large power is never formed explicitly.
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float64) * (-math.log(10000.0) / d_model)
        )
        pe = torch.zeros(seq_len, d_model)  # (seq_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than even column.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        pe = pe.unsqueeze(0)  # (1, seq_len, d_model)
        # NOTE: buffers are saved in the state dict, so a checkpoint restores the
        # table it was trained with.
        self.register_buffer('pe', pe)

    def forward(self, x):
        # Add the encodings for the first x.size(1) positions; the table is a
        # fixed buffer and never receives gradients.
        x = x + (self.pe[:, :x.size(1), :]).requires_grad_(False)
        return self.dropout(x)

In [93]:
class LayerNormalization(nn.Module):
    """Normalise over the last dimension, then apply a learnable gain and bias.

    Args:
        eps: added to the standard deviation to avoid division by zero.
        features: length of the learnable ``alpha``/``bias`` vectors. The
            default of 1 keeps a single scalar gain/bias shared by every
            feature; pass ``d_model`` for a per-feature affine transform.
    """

    def __init__(self, eps: float = 1e-6, features: int = 1) -> None:
        super().__init__()
        self.eps = eps
        self.alpha = nn.Parameter(torch.ones(features))  # multiplicative gain
        self.bias = nn.Parameter(torch.zeros(features))  # additive bias

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        # NOTE: Tensor.std uses the unbiased (Bessel-corrected) estimator by
        # default, and eps is added to the std rather than the variance; both
        # differ slightly from torch.nn.LayerNorm.
        std = x.std(dim=-1, keepdim=True)
        return self.alpha * (x - mean) / (std + self.eps) + self.bias


In [94]:
class FeedForwardBlock(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear.

    Args:
        d_model: input/output width (512 in the paper).
        d_ff: hidden width (2048 in the paper).
        dropout: dropout rate applied to the hidden activations.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float) -> None:
        super().__init__()
        self.linear_1 = nn.Linear(d_model, d_ff)  # expand: d_model -> d_ff
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff, d_model)  # project back: d_ff -> d_model

    def forward(self, x):
        hidden = torch.relu(self.linear_1(x))
        hidden = self.dropout(hidden)
        return self.linear_2(hidden)

In [95]:
class MultiHeadAttentionBlock(nn.Module):
    """Multi-head scaled dot-product attention.

    Args:
        d_model: model width (512 in the paper); must be divisible by ``num_heads``.
        num_heads: number of attention heads (8 in the paper).
        dropout: dropout rate applied to the attention weights.

    After each forward call the attention weights are kept in
    ``self.attention_scores`` with shape ``(batch, num_heads, q_len, k_len)``.
    """

    def __init__(self, d_model: int, num_heads: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        # Query/key/value input projections followed by the output projection.
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def attention(query, key, value, mask=None, dropout: nn.Dropout=None):
        """Return ``(softmax(Q K^T / sqrt(d_k)) V, attention weights)``.

        Scores where ``mask == 0`` are replaced by -1e9 before the softmax, so
        those positions receive (near-)zero weight.
        """
        scale = math.sqrt(query.shape[-1])
        scores = (query @ key.transpose(-2, -1)) / scale
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        weights = torch.softmax(scores, dim=-1)
        if dropout is not None:
            weights = dropout(weights)
        return weights @ value, weights

    def forward(self, q, k, v, mask=None):
        def split_heads(t):
            # (batch, seq, d_model) -> (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
            return t.view(t.shape[0], t.shape[1], self.num_heads, self.head_dim).transpose(1, 2)

        query = split_heads(self.w_q(q))
        key = split_heads(self.w_k(k))
        value = split_heads(self.w_v(v))

        x, self.attention_scores = MultiHeadAttentionBlock.attention(query, key, value, mask, self.dropout)

        # (batch, heads, seq, head_dim) -> (batch, seq, heads, head_dim) -> (batch, seq, d_model)
        batch_size = x.shape[0]
        merged = x.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        return self.w_o(merged)

In [96]:

class ResidualConnection(nn.Module):
    """Residual wrapper that normalises the sub-layer's input (pre-norm).

    Computes ``x + Dropout(sublayer(LayerNormalization(x)))``.

    Args:
        dropout: dropout rate applied to the sub-layer output.
    """

    def __init__(self, dropout: float) -> None:
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = LayerNormalization()

    def forward(self, x, sublayer):
        normed = self.norm(x)
        branch = self.dropout(sublayer(normed))
        return x + branch

In [97]:
class EncoderBlock(nn.Module):
    """One encoder layer: self-attention then feed-forward, each in a residual connection.

    Args:
        self_attention_block: attention module called as ``(q, k, v, mask)``.
        feed_forward_block: position-wise feed-forward module.
        dropout: dropout rate for both residual connections.
    """

    def __init__(self, self_attention_block: MultiHeadAttentionBlock,
                 feed_forward_block: FeedForwardBlock,
                 dropout: float) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connection = nn.ModuleList(
            [ResidualConnection(dropout), ResidualConnection(dropout)]
        )

    def forward(self, x, src_mask=None):
        attention_residual, feed_forward_residual = self.residual_connection
        # Queries, keys and values all come from the same (normalised) input.
        x = attention_residual(x, lambda h: self.self_attention_block(h, h, h, src_mask))
        return feed_forward_residual(x, self.feed_forward_block)

In [98]:

class Encoder(nn.Module):
    """Run a stack of encoder blocks, then apply a final LayerNormalization.

    Args:
        layers: the encoder blocks, applied in order; each is called as ``(x, mask)``.
    """

    def __init__(self, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization()

    def forward(self, x, mask=None):
        hidden = x
        for block in self.layers:
            hidden = block(hidden, mask)
        return self.norm(hidden)

In [99]:
class DecoderBlock(nn.Module):
    """One decoder layer: masked self-attention, cross-attention, then feed-forward.

    Each of the three sub-layers is wrapped in its own residual connection.

    Args:
        self_attention_block: attention over the decoder input (uses ``tgt_mask``).
        cross_attention_block: attention from decoder states to the encoder output
            (uses ``src_mask``).
        feed_forward_block: position-wise feed-forward module.
        dropout: dropout rate for the residual connections.
    """

    def __init__(self, self_attention_block: MultiHeadAttentionBlock,
                 cross_attention_block: MultiHeadAttentionBlock,
                 feed_forward_block: FeedForwardBlock,
                 dropout: float) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.cross_attention_block = cross_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connection = nn.ModuleList(
            [ResidualConnection(dropout), ResidualConnection(dropout), ResidualConnection(dropout)]
        )

    def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
        self_residual, cross_residual, feed_forward_residual = self.residual_connection
        x = self_residual(x, lambda h: self.self_attention_block(h, h, h, tgt_mask))
        # Queries from the decoder, keys/values from the encoder output.
        x = cross_residual(
            x,
            lambda h: self.cross_attention_block(h, encoder_output, encoder_output, src_mask),
        )
        return feed_forward_residual(x, self.feed_forward_block)

In [100]:

class Decoder(nn.Module):
    """Run a stack of decoder blocks, then apply a final LayerNormalization.

    Args:
        layers: the decoder blocks, applied in order; each is called as
            ``(x, encoder_output, src_mask, tgt_mask)``.
    """

    def __init__(self, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization()

    def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
        hidden = x
        for block in self.layers:
            hidden = block(hidden, encoder_output, src_mask, tgt_mask)
        return self.norm(hidden)

In [101]:
class ProjectionLayer(nn.Module):
    """Map decoder states to log-probabilities over the target vocabulary.

    Args:
        d_model: width of the decoder states.
        vocab_size: number of output classes (target vocabulary size).
    """

    def __init__(self, d_model: int, vocab_size: int) -> None:
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        logits = self.proj(x)
        return logits.log_softmax(dim=-1)

In [102]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer assembled from pre-built components.

    ``encode`` embeds and position-encodes the source and runs the encoder.
    ``decode`` does the same for the target and runs the decoder against the
    encoder output. ``project`` applies the projection layer to decoder states.
    """

    def __init__(self,
                 encoder: Encoder,
                 decoder: Decoder,
                 src_embed: InputEmbeddings,
                 tgt_embed: InputEmbeddings,
                 src_pos_embed: PositionalEncoding,
                 tgt_pos_embed: PositionalEncoding,
                 projection_layer: ProjectionLayer) -> None:
        super().__init__()
        # Registration order fixes parameters()/state_dict() ordering.
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.src_pos_embed = src_pos_embed
        self.tgt_pos_embed = tgt_pos_embed
        self.projection_layer = projection_layer

    def encode(self, src, src_mask):
        embedded = self.src_pos_embed(self.src_embed(src))
        return self.encoder(embedded, src_mask)

    def decode(self, encoder_output, src_mask, tgt, tgt_mask):
        embedded = self.tgt_pos_embed(self.tgt_embed(tgt))
        return self.decoder(embedded, encoder_output, src_mask, tgt_mask)

    def project(self, x):
        return self.projection_layer(x)

In [103]:
def build_transformer(src_vocab_size: int,
                      tgt_vocab_size: int,
                      src_seq_len: int,
                      tgt_seq_len: int,
                      d_model: int,
                      num_layers: int = 6,
                      num_heads: int = 8,
                      d_ff: int = 2048,
                      dropout: float = 0.1) -> Transformer:
    """Assemble a Transformer and Xavier-initialise every matrix-shaped parameter.

    Args:
        src_vocab_size / tgt_vocab_size: vocabulary sizes for the two embeddings;
            the target size is also the projection's output size.
        src_seq_len / tgt_seq_len: lengths of the positional-encoding tables.
        d_model: model width.
        num_layers: number of encoder blocks and of decoder blocks.
        num_heads: attention heads per attention block.
        d_ff: hidden width of each feed-forward block.
        dropout: dropout rate used throughout.

    Returns:
        The constructed ``Transformer``.
    """
    # Modules are created in a fixed order (embeddings, encoder, decoder,
    # projection); keep it that way so seeded runs initialise identically.
    src_embed = InputEmbeddings(d_model, src_vocab_size)
    tgt_embed = InputEmbeddings(d_model, tgt_vocab_size)

    src_pos_embed = PositionalEncoding(d_model, src_seq_len, dropout)
    tgt_pos_embed = PositionalEncoding(d_model, tgt_seq_len, dropout)

    encoder = Encoder(nn.ModuleList([
        EncoderBlock(
            MultiHeadAttentionBlock(d_model, num_heads, dropout),  # self-attention
            FeedForwardBlock(d_model, d_ff, dropout),
            dropout,
        )
        for _ in range(num_layers)
    ]))

    decoder = Decoder(nn.ModuleList([
        DecoderBlock(
            MultiHeadAttentionBlock(d_model, num_heads, dropout),  # masked self-attention
            MultiHeadAttentionBlock(d_model, num_heads, dropout),  # cross-attention
            FeedForwardBlock(d_model, d_ff, dropout),
            dropout,
        )
        for _ in range(num_layers)
    ]))

    projection_layer = ProjectionLayer(d_model, tgt_vocab_size)

    transformer = Transformer(
        encoder=encoder,
        decoder=decoder,
        src_embed=src_embed,
        tgt_embed=tgt_embed,
        src_pos_embed=src_pos_embed,
        tgt_pos_embed=tgt_pos_embed,
        projection_layer=projection_layer,
    )

    # Xavier-uniform for weight matrices and embedding tables; 1-D parameters
    # (biases, norm gains) keep their constructor initialisation.
    for weight in (p for p in transformer.parameters() if p.dim() > 1):
        nn.init.xavier_uniform_(weight)

    return transformer

In [104]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset


def causal_mask(size):
    """Return a ``(size, size)`` boolean mask that is True on and below the diagonal.

    Entry ``[i, j]`` is True exactly when ``j <= i``, so position ``i`` may only
    attend to itself and earlier positions.
    """
    return torch.tril(torch.ones(size, size)).bool()


class BilingualDataset(Dataset):
    """Turn raw translation pairs into fixed-length tensors for the Transformer.

    Each ``ds[idx]`` must be a record ``{'translation': {src_lang: str, tgt_lang: str}}``.
    Items are padded to ``seq_len``::

        encoder_input = [SOS] src_ids [EOS] [PAD] ...
        decoder_input = [SOS] tgt_ids [PAD] ...
        label         =       tgt_ids [EOS] [PAD] ...

    Special tokens on the source side come from ``tokenizer_src``. Those on the
    target side (decoder input, label, decoder mask) come from ``tokenizer_tgt``.

    Raises:
        ValueError: at construction if a tokenizer lacks ``[SOS]``/``[EOS]``/``[PAD]``;
            from ``__getitem__`` if a sentence does not fit in ``seq_len``.
    """

    def __init__(self, ds,
                 tokenizer_src,
                 tokenizer_tgt,
                 src_lang,
                 tgt_lang,
                 seq_len):
        self.ds = ds
        self.tokenizer_src = tokenizer_src
        self.tokenizer_tgt = tokenizer_tgt
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        self.seq_len = seq_len

        # Special-token ids as 1-element int64 tensors, ready for torch.cat.
        self.sos_token = self._special_token(tokenizer_src, "[SOS]")
        self.eos_token = self._special_token(tokenizer_src, "[EOS]")
        self.pad_token = self._special_token(tokenizer_src, "[PAD]")
        # The decoder input and label are encoded with tokenizer_tgt, so their
        # special tokens must come from the target vocabulary too.
        self.tgt_sos_token = self._special_token(tokenizer_tgt, "[SOS]")
        self.tgt_eos_token = self._special_token(tokenizer_tgt, "[EOS]")
        self.tgt_pad_token = self._special_token(tokenizer_tgt, "[PAD]")

    @staticmethod
    def _special_token(tokenizer, token):
        """Return *token*'s id in *tokenizer* as a 1-element int64 tensor."""
        token_id = tokenizer.token_to_id(token)
        # token_to_id presumably returns None for a token missing from the vocab.
        if token_id is None:
            raise ValueError(f"Tokenizer has no {token} token")
        return torch.tensor([token_id], dtype=torch.int64)

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        src_target_pair = self.ds[idx]
        src_text = src_target_pair['translation'][self.src_lang]
        tgt_text = src_target_pair['translation'][self.tgt_lang]

        enc_input_tokens = self.tokenizer_src.encode(src_text).ids
        dec_input_tokens = self.tokenizer_tgt.encode(tgt_text).ids

        # Encoder input carries [SOS] and [EOS]; decoder input only [SOS] and
        # the label only [EOS].
        enc_num_padding_tokens = self.seq_len - len(enc_input_tokens) - 2
        dec_num_padding_tokens = self.seq_len - len(dec_input_tokens) - 1

        if enc_num_padding_tokens < 0 or dec_num_padding_tokens < 0:
            raise ValueError("Sentence is too long")

        enc_ids = torch.tensor(enc_input_tokens, dtype=torch.int64)
        dec_ids = torch.tensor(dec_input_tokens, dtype=torch.int64)
        enc_padding = self.pad_token.repeat(enc_num_padding_tokens)
        dec_padding = self.tgt_pad_token.repeat(dec_num_padding_tokens)

        encoder_input = torch.cat([self.sos_token, enc_ids, self.eos_token, enc_padding])
        decoder_input = torch.cat([self.tgt_sos_token, dec_ids, dec_padding])
        label = torch.cat([dec_ids, self.tgt_eos_token, dec_padding])

        assert encoder_input.size(0) == self.seq_len
        assert decoder_input.size(0) == self.seq_len
        assert label.size(0) == self.seq_len

        return {
            "encoder_input": encoder_input,
            "decoder_input": decoder_input,
            # (1, 1, seq_len): broadcasts over heads and query positions.
            "encoder_mask": (encoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int(),
            # (1, seq_len, seq_len): non-padding keys AND no look-ahead.
            "decoder_mask": (decoder_input != self.tgt_pad_token).unsqueeze(0).unsqueeze(0).int() & causal_mask(decoder_input.size(0)),
            "label": label,
            "src_text": src_text,
            "tgt_text": tgt_text
        }

In [105]:
import torch
import torch.nn as nn
from pathlib import Path
from torch.utils.data import DataLoader

from datasets import load_dataset, Dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace



def get_or_build_tokenizer(config, dataset_name: "Dataset", language: str):
    """Load the word-level tokenizer for *language*, training and caching it if absent.

    Args:
        config: must contain ``'tokenizer_file'``, a format string with a
            ``{lang}`` field (e.g. ``"tokenizer_{lang}.json"``) giving the cache path.
        dataset_name: despite its name, the dataset itself. It must be an iterable
            of ``{'translation': {language: text}}`` records and is passed straight
            to ``get_all_sentences``. The name is kept for keyword callers.
        language: key selecting the sentence from each record's ``translation``.

    Returns:
        A ``tokenizers.Tokenizer`` with the special tokens ``[UNK]``, ``[SOS]``,
        ``[EOS]`` and ``[PAD]``.
    """
    tokenizer_file_path = Path(config['tokenizer_file'].format(lang=language))
    if tokenizer_file_path.exists():
        return Tokenizer.from_file(str(tokenizer_file_path))

    tokenizer = Tokenizer(WordLevel(unk_token='[UNK]'))
    tokenizer.pre_tokenizer = Whitespace()
    # Special tokens are listed first so they get the lowest ids; words seen
    # fewer than twice are left out of the vocabulary (min_frequency=2).
    trainer = WordLevelTrainer(special_tokens=[
        "[UNK]",
        "[SOS]",
        "[EOS]",
        "[PAD]"
    ], min_frequency=2)
    tokenizer.train_from_iterator(get_all_sentences(dataset_name, language), trainer=trainer)
    # The path template may point into a directory that does not exist yet.
    tokenizer_file_path.parent.mkdir(parents=True, exist_ok=True)
    tokenizer.save(str(tokenizer_file_path))
    return tokenizer


def get_all_sentences(dataset: Dataset, language: str):
    """Lazily yield the *language* side of every translation record in *dataset*.

    Records must look like ``{'translation': {language: text, ...}}``.
    """
    yield from (record['translation'][language] for record in dataset)


def get_ds(config: dict):
    """Build the opus_books dataloaders and tokenizers for the configured language pair.

    Loads the ``opus_books`` ``"{lang_src}-{lang_tgt}"`` train split and builds
    (or loads cached) word-level tokenizers for both languages. It then splits the
    data 90/10 into training and validation sets and wraps each in a
    ``BilingualDataset`` and a ``DataLoader``.

    Args:
        config: needs ``lang_src``, ``lang_tgt``, ``seq_len``, ``batch_size`` and
            ``tokenizer_file``. The optional ``seed`` key makes the
            train/validation split reproducible, e.g. across runs resumed with
            ``preload``. Without it the split is random on every call.

    Returns:
        ``(train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt)``.

    Raises:
        ValueError: if the longest sentence in either language does not fit in
            ``config['seq_len']`` once its special tokens are added.
    """
    ds_raw = load_dataset('opus_books',
                          f'{config["lang_src"]}-{config["lang_tgt"]}',
                          split='train'
                          )

    # Build tokenizers
    tokenizer_src = get_or_build_tokenizer(config, ds_raw, config['lang_src'])
    tokenizer_tgt = get_or_build_tokenizer(config, ds_raw, config['lang_tgt'])

    # Keep 90% for training, 10% for validation
    train_ds_size = int(0.9 * len(ds_raw))
    val_ds_size = len(ds_raw) - train_ds_size
    split_kwargs = {}
    if config.get("seed") is not None:
        split_kwargs["generator"] = torch.Generator().manual_seed(config["seed"])
    train_ds_raw, val_ds_raw = torch.utils.data.random_split(
        ds_raw, [train_ds_size, val_ds_size], **split_kwargs
    )

    train_ds = BilingualDataset(train_ds_raw,
                                tokenizer_src,
                                tokenizer_tgt,
                                config['lang_src'],
                                config['lang_tgt'],
                                config['seq_len']
                                )

    val_ds = BilingualDataset(val_ds_raw,
                              tokenizer_src,
                              tokenizer_tgt,
                              config['lang_src'],
                              config['lang_tgt'],
                              config['seq_len']
                              )

    # Measure token lengths on the raw sentences of both splits. This tokenises
    # each sentence once and avoids materialising every BilingualDataset item
    # (tensors plus a seq_len x seq_len causal mask).
    max_len_src = max(
        (len(tokenizer_src.encode(text).ids)
         for text in get_all_sentences(ds_raw, config['lang_src'])),
        default=0,
    )
    max_len_tgt = max(
        (len(tokenizer_tgt.encode(text).ids)
         for text in get_all_sentences(ds_raw, config['lang_tgt'])),
        default=0,
    )

    print(f"Max length of source sentence: {max_len_src}")
    print(f"Max length of target sentence: {max_len_tgt}")

    # Fail fast: encoder inputs add [SOS] + [EOS], decoder inputs/labels add one token.
    if max_len_src + 2 > config['seq_len'] or max_len_tgt + 1 > config['seq_len']:
        raise ValueError(
            f"seq_len={config['seq_len']} is too short for this dataset "
            f"(longest source: {max_len_src} tokens + 2 special, "
            f"longest target: {max_len_tgt} tokens + 1 special)"
        )

    train_dataloader = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True)
    val_dataloader = DataLoader(val_ds, batch_size=1, shuffle=True)

    return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt

In [106]:
def get_model(config, vocab_src_len, vocab_tgt_len):
    """Build a Transformer sized by ``config['seq_len']`` and ``config['d_model']``.

    The same ``seq_len`` is used for the source and target positional encodings;
    all other hyper-parameters take ``build_transformer``'s defaults.
    """
    seq_len = config['seq_len']
    return build_transformer(vocab_src_len, vocab_tgt_len, seq_len, seq_len, config['d_model'])

In [107]:
def get_config():
    """Return the default training hyper-parameters and file locations.

    ``tokenizer_file`` is a template whose ``{lang}`` field is filled per language.
    ``preload`` names a saved epoch to resume from, or is ``None`` to start fresh.
    """
    return dict(
        batch_size=8,
        num_epochs=10,
        lr=1e-4,
        seq_len=350,
        d_model=512,
        lang_src="en",
        lang_tgt="it",
        model_folder="models",
        model_basename="tmodel_",
        preload=None,
        tokenizer_file="tokenizer_{lang}.json",
        experiment_name="runs/tmodel",
    )


def get_weights_file_path(config, epoch: str):
    """Return the checkpoint path ``<model_folder>/<model_basename><epoch>.pt`` as a string."""
    filename = f"{config['model_basename']}{epoch}.pt"
    return str(Path('.').joinpath(config['model_folder'], filename))




In [108]:
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

def train_model(config):
    """Train a Transformer on the configured language pair.

    Builds the dataloaders and tokenizers with ``get_ds`` and the model with
    ``get_model``. If ``config['preload']`` is set, it resumes from that checkpoint.
    The training loss is logged to TensorBoard under ``config['experiment_name']``.
    At the end of every epoch a checkpoint (model, optimizer, epoch, global step)
    is written to ``get_weights_file_path(config, f"{epoch:02d}")``.

    Args:
        config: dict shaped like the one returned by ``get_config()``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    Path(config['model_folder']).mkdir(parents=True, exist_ok=True)

    # load datasets (the validation loader is not used yet)
    train_dataloader, _val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
    tgt_vocab_size = tokenizer_tgt.get_vocab_size()

    # load model
    model = get_model(config, tokenizer_src.get_vocab_size(), tgt_vocab_size).to(device)

    # tensorboard writer
    writer = SummaryWriter(config['experiment_name'])

    # optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], eps=1e-9)

    initial_epoch = 0
    global_step = 0

    if config['preload']:
        model_filename = get_weights_file_path(config, config['preload'])
        print(f"Preloading model from {model_filename}")
        # map_location lets a checkpoint saved on GPU be resumed on CPU and vice versa.
        state_dict = torch.load(model_filename, map_location=device)
        # Restore the weights as well as the optimizer; otherwise "resuming"
        # would continue from freshly initialised weights.
        model.load_state_dict(state_dict['model_state_dict'])
        optimizer.load_state_dict(state_dict['optimizer_state_dict'])
        initial_epoch = state_dict['epoch'] + 1
        global_step = state_dict['global_step']

    # Labels are padded with the *target* tokenizer's [PAD] id, so ignore that id.
    # NOTE: model.project already returns log-probabilities; CrossEntropyLoss
    # applies log_softmax again, which leaves log-probabilities unchanged.
    loss_fn = nn.CrossEntropyLoss(
        ignore_index=tokenizer_tgt.token_to_id("[PAD]"),
        label_smoothing=0.1).to(device)

    try:
        # training loop
        for epoch in range(initial_epoch, config['num_epochs']):
            model.train()
            batch_iterator = tqdm(train_dataloader, desc=f"Processing epoch {epoch:02d}", leave=False)
            for batch in batch_iterator:
                encoder_input = batch['encoder_input'].to(device)
                decoder_input = batch['decoder_input'].to(device)
                encoder_mask = batch['encoder_mask'].to(device)
                decoder_mask = batch['decoder_mask'].to(device)
                label = batch['label'].to(device)

                # forward pass
                encoder_output = model.encode(encoder_input, encoder_mask)
                decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask)
                proj_output = model.project(decoder_output)

                # flatten (batch, seq, vocab) / (batch, seq) for the loss
                loss = loss_fn(proj_output.view(-1, tgt_vocab_size), label.view(-1))
                batch_iterator.set_postfix({"loss": f"{loss.item(): 6.3f}"})

                # log the loss
                writer.add_scalar('train loss', loss.item(), global_step)
                writer.flush()

                # Clear the previous step's gradients; otherwise they accumulate
                # and every update would use the running sum over all batches.
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                global_step += 1

            # Checkpoint at the end of every epoch.
            model_filename = get_weights_file_path(config, f"{epoch:02d}")
            torch.save({
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "global_step": global_step
            }, model_filename)
    finally:
        writer.close()

In [None]:
# Entry point: train with the default configuration from get_config().
config = get_config()
train_model(config=config)

Using device: cuda
