In [1]:
def get_config():
    """Return the notebook's settings as a freshly built dictionary.

    The dictionary holds output locations (TensorBoard logs, checkpoint,
    tokenizer file), optimisation settings, model dimensions and the
    masking probability. Keys are inserted in a fixed order.
    """
    config = {}

    # Where TensorBoard event files are written.
    config["logs"] = "/content/drive/MyDrive/Colab Notebooks/T-MLM/T-MLM-logs"

    # Optimisation settings.
    config.update(batch_size=4, num_epochs=30, lr=3e-5)

    # Model dimensions.
    config.update(seq_len=512, d_model=768, n_layers=12, head=12, d_ff=3072, dropout=0.1)

    # Masking probability for the MLM objective.
    config["masking_prob"] = 0.15

    # Persisted artefacts.
    config["model_file_path"] = "/content/drive/MyDrive/Colab Notebooks/T-MLM/T-MLM.pt"
    config["tokenizer_file"] = "/content/drive/MyDrive/Colab Notebooks/T-MLM/tokenizer.json"

    return config

In [2]:
from pathlib import Path
from tokenizers import Tokenizer, models, trainers, pre_tokenizers

def get_all_sentences(ds, field):
    """Lazily yield ``record[field]`` for every record in ``ds``, in order."""
    yield from (record[field] for record in ds)

def build_or_get_tokenizer(config, ds):
    """Load the tokenizer from ``config['tokenizer_file']`` or train and save one.

    Args:
        config: Mapping with a ``'tokenizer_file'`` entry (path of the tokenizer JSON).
        ds: Iterable of records; the ``"text"`` entry of each record is used as
            training material when the tokenizer has to be trained.

    Returns:
        The loaded or newly trained ``Tokenizer``.
    """
    tokenizer_path = Path(config['tokenizer_file'])

    # Reuse a previously saved tokenizer when one is available.
    if tokenizer_path.exists():
        return Tokenizer.from_file(str(tokenizer_path))

    # Make sure the destination folder exists *before* training, so a missing
    # directory cannot make the final save fail and throw the training away.
    tokenizer_path.parent.mkdir(parents=True, exist_ok=True)

    tokenizer = Tokenizer(models.BPE(unk_token="[UNK]"))
    tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()
    trainer = trainers.BpeTrainer(special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]", "[MASK]"], min_frequency=1)
    tokenizer.train_from_iterator(get_all_sentences(ds, "text"), trainer=trainer)
    tokenizer.save(str(tokenizer_path))

    return tokenizer


In [3]:
import torch
from torch.utils.data import Dataset
import random

class MLMDataset(Dataset):
    """Dataset producing fixed-length masked / original token id pairs.

    Each item is ``[SOS] + token ids + [EOS]`` followed by ``[PAD]`` up to
    ``config['seq_len']``. The ``"tokenized_text"`` entry has some ordinary
    tokens replaced by ``[MASK]``; ``"original_tokens"`` is the untouched copy.

    Args:
        config: Mapping with ``'seq_len'`` and, optionally, the masking
            probability under ``'mask_probability'`` or ``'masking_prob'``
            (defaults to 0.15 when neither is present).
        ds: Indexable collection of records with a ``'text'`` entry.
        tokenizer: Object providing ``token_to_id(str)`` and ``encode(str).ids``.
    """

    def __init__(self, config, ds, tokenizer):
        super().__init__()
        self.seq_len = config['seq_len']
        # get_config() spells the key 'masking_prob'; the previously read
        # 'mask_probability' spelling still takes precedence when present.
        self.mask_probability = config.get('mask_probability', config.get('masking_prob', 0.15))

        self.ds = ds
        self.tokenizer = tokenizer

        self.sos_token = torch.tensor([tokenizer.token_to_id("[SOS]")], dtype=torch.int64)
        self.eos_token = torch.tensor([tokenizer.token_to_id("[EOS]")], dtype=torch.int64)
        self.pad_token = torch.tensor([tokenizer.token_to_id("[PAD]")], dtype=torch.int64)
        self.mask_token = torch.tensor(tokenizer.token_to_id("[MASK]"), dtype=torch.int64)

    def __len__(self):
        """Number of records in the underlying dataset."""
        return len(self.ds)

    def __getitem__(self, idx):
        """Return the masked and original token ids for record ``idx``.

        Raises:
            ValueError: If the encoded text plus [SOS]/[EOS] exceeds ``seq_len``.
        """
        src_text = self.ds[idx]['text']
        input_tokens = self.tokenizer.encode(src_text).ids
        # Two slots are reserved for [SOS] and [EOS].
        num_padding_tokens = self.seq_len - len(input_tokens) - 2

        if num_padding_tokens < 0:
            raise ValueError("Sentence is too long")

        tokenized_text = torch.cat(
            [
                self.sos_token,
                torch.tensor(input_tokens, dtype=torch.int64),
                self.eos_token,
                torch.full((num_padding_tokens,), self.pad_token.item(), dtype=torch.int64),
            ],
            dim=0,
        )

        assert tokenized_text.size(0) == self.seq_len

        # Clone before masking: apply_mlm modifies its argument in place.
        original_tokens = tokenized_text.clone()
        tokenized_text = self.apply_mlm(tokenized_text)

        return {
            "tokenized_text": tokenized_text,  # (seq_len)
            "original_tokens": original_tokens,  # (seq_len)
        }

    def apply_mlm(self, tokens):
        """Replace ordinary tokens by [MASK] with probability ``mask_probability``.

        [SOS], [EOS] and [PAD] positions are never replaced, so the padding
        mask derived from the returned ids stays valid. ``tokens`` is modified
        in place and also returned.
        """
        is_special = (tokens == self.sos_token) | (tokens == self.eos_token) | (tokens == self.pad_token)
        for i in torch.nonzero(~is_special).flatten().tolist():
            if random.random() < self.mask_probability:
                tokens[i] = self.mask_token
        return tokens

def causal_mask(size):
    """Return a ``(1, size, size)`` boolean mask that is True on and below the diagonal."""
    # Ones strictly above the diagonal mark the positions to hide.
    hidden = torch.ones((1, size, size)).triu(diagonal=1).type(torch.int)
    return hidden == 0


In [4]:
import torch
import torch.nn as nn
import math

class LayerNormalization(nn.Module):
    """Normalise the last dimension, then apply a learned scale and shift.

    Computes ``alpha * (x - mean) / (std + eps) + bias`` where ``mean`` and
    ``std`` are taken over the last dimension with ``x.mean`` / ``x.std``.
    """

    def __init__(self, features: int, eps:float=10**-6) -> None:
        super().__init__()
        self.eps = eps
        # Learned per-feature scale (starts at 1) and shift (starts at 0).
        self.alpha = nn.Parameter(torch.ones(features))
        self.bias = nn.Parameter(torch.zeros(features))

    def forward(self, x):
        centered = x - x.mean(dim=-1, keepdim=True)
        # eps is added to the standard deviation to avoid dividing by zero.
        spread = x.std(dim=-1, keepdim=True) + self.eps
        return self.alpha * centered / spread + self.bias

class FeedForwardBlock(nn.Module):
    """Two linear layers with ReLU and dropout in between: d_model -> d_ff -> d_model."""

    def __init__(self, d_model: int, d_ff: int, dropout: float) -> None:
        super().__init__()
        self.linear_1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        hidden = torch.relu(self.linear_1(x))
        hidden = self.dropout(hidden)
        return self.linear_2(hidden)

class InputEmbeddings(nn.Module):
    """Token embedding lookup whose output is multiplied by ``sqrt(d_model)``."""

    def __init__(self, d_model: int, vocab_size: int) -> None:
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        scale = math.sqrt(self.d_model)
        vectors = self.embedding(x)
        return vectors * scale

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to embeddings, then apply dropout.

    Even feature columns hold ``sin(position * freq)`` and odd columns hold
    ``cos(position * freq)``. The table is precomputed for ``seq_len``
    positions and stored as the non-trainable buffer ``pe`` of shape
    ``(1, seq_len, d_model)``. Both even and odd ``d_model`` are supported.

    Args:
        d_model: Size of the feature dimension.
        seq_len: Number of positions to precompute.
        dropout: Dropout probability applied to the sum.
    """

    def __init__(self, d_model: int, seq_len: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(seq_len, d_model)
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # one column per even feature index
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even columns,
        # so drop the surplus frequency instead of failing on a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # Only the first x.shape[1] positions are needed. `pe` is a buffer, so
        # it carries no gradient.
        x = x + self.pe[:, :x.shape[1], :]
        return self.dropout(x)

class ResidualConnection(nn.Module):
    """Residual wrapper computing ``x + dropout(sublayer(norm(x)))``."""

    def __init__(self, features: int, dropout: float) -> None:
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = LayerNormalization(features)

    def forward(self, x, sublayer):
        # The sublayer sees the normalised input; the skip path adds the raw input.
        normalized = self.norm(x)
        branch = self.dropout(sublayer(normalized))
        return x + branch

class MultiHeadAttentionBlock(nn.Module):
    """Scaled dot-product attention over ``h`` heads of width ``d_model // h``.

    The attention weights of the most recent forward pass are kept on the
    instance as ``attention_scores``.
    """

    def __init__(self, d_model: int, h: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model
        self.h = h
        assert d_model % h == 0, "d_model is not divisible by h"

        self.d_k = d_model // h
        # Bias-free projections for queries, keys, values and the output.
        self.w_q = nn.Linear(d_model, d_model, bias=False)
        self.w_k = nn.Linear(d_model, d_model, bias=False)
        self.w_v = nn.Linear(d_model, d_model, bias=False)
        self.w_o = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def attention(query, key, value, mask, dropout: nn.Dropout):
        """Return ``(weighted values, attention weights)``.

        Positions where ``mask == 0`` are filled with -1e9 before the softmax.
        ``mask`` and ``dropout`` may each be ``None`` to skip that step.
        """
        scale = math.sqrt(query.shape[-1])
        weights = torch.matmul(query, key.transpose(-2, -1)) / scale
        if mask is not None:
            weights.masked_fill_(mask == 0, -1e9)
        weights = weights.softmax(dim=-1)
        if dropout is not None:
            weights = dropout(weights)
        return torch.matmul(weights, value), weights

    def _split_heads(self, projected):
        """Reshape (batch, seq, d_model) into (batch, h, seq, d_k)."""
        batch, length = projected.shape[0], projected.shape[1]
        return projected.view(batch, length, self.h, self.d_k).transpose(1, 2)

    def forward(self, q, k, v, mask):
        query = self._split_heads(self.w_q(q))
        key = self._split_heads(self.w_k(k))
        value = self._split_heads(self.w_v(v))

        context, self.attention_scores = MultiHeadAttentionBlock.attention(
            query, key, value, mask, self.dropout
        )

        # Merge the heads back into a single feature dimension.
        merged = context.transpose(1, 2).contiguous().view(context.shape[0], -1, self.h * self.d_k)

        return self.w_o(merged)

class EncoderBlock(nn.Module):
    """Encoder layer: self-attention, then feed-forward, each inside a residual connection."""

    def __init__(self, features: int, self_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        # Index 0 wraps self-attention, index 1 wraps the feed-forward block.
        self.residual_connections = nn.ModuleList([ResidualConnection(features, dropout) for _ in range(2)])

    def forward(self, x, src_mask):
        def attend(normed):
            # Queries, keys and values all come from the same input.
            return self.self_attention_block(normed, normed, normed, src_mask)

        attention_residual, feed_forward_residual = self.residual_connections
        x = attention_residual(x, attend)
        return feed_forward_residual(x, self.feed_forward_block)

class Encoder(nn.Module):
    """Apply the given layers in sequence, then a final layer normalisation."""

    def __init__(self, features: int, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization(features)

    def forward(self, x, mask):
        hidden = x
        for block in self.layers:
            # Every layer receives the same mask.
            hidden = block(hidden, mask)
        return self.norm(hidden)

class DecoderBlock(nn.Module):
    """Decoder layer: self-attention, cross-attention over the encoder output, then feed-forward.

    Each of the three sub-layers is wrapped in its own residual connection.
    """

    def __init__(self, features: int, self_attention_block: MultiHeadAttentionBlock, cross_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.cross_attention_block = cross_attention_block
        self.feed_forward_block = feed_forward_block
        # Index 0: self-attention, 1: cross-attention, 2: feed-forward.
        self.residual_connections = nn.ModuleList([ResidualConnection(features, dropout) for _ in range(3)])

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        def self_attend(normed):
            return self.self_attention_block(normed, normed, normed, tgt_mask)

        def cross_attend(normed):
            # Queries come from the decoder; keys and values from the encoder output.
            return self.cross_attention_block(normed, encoder_output, encoder_output, src_mask)

        self_residual, cross_residual, feed_forward_residual = self.residual_connections
        x = self_residual(x, self_attend)
        x = cross_residual(x, cross_attend)
        return feed_forward_residual(x, self.feed_forward_block)

class Decoder(nn.Module):
    """Apply the given decoder layers in sequence, then a final layer normalisation."""

    def __init__(self, features: int, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization(features)

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        hidden = x
        for block in self.layers:
            # Every layer receives the same encoder output and masks.
            hidden = block(hidden, encoder_output, src_mask, tgt_mask)
        return self.norm(hidden)

class ProjectionLayer(nn.Module):
    """Linear map from the model dimension to one score per vocabulary entry.

    Args:
        d_model: Size of the incoming feature dimension.
        vocab_size: Number of output scores per position.
    """

    def __init__(self, d_model, vocab_size) -> None:
        super().__init__()
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, x) -> torch.Tensor:
        """Return ``self.proj(x)``; the last dimension becomes ``vocab_size``."""
        return self.proj(x)

class Transformer(nn.Module):
    """Container tying together embeddings, positional encodings, encoder, decoder and projection.

    There is no ``forward``; callers use ``encode``, ``decode`` and
    ``project`` explicitly.
    """

    def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: InputEmbeddings, tgt_embed: InputEmbeddings, src_pos: PositionalEncoding, tgt_pos: PositionalEncoding, projection_layer: ProjectionLayer) -> None:
        super().__init__()
        # Sub-modules are registered in the same order as the constructor arguments.
        self.encoder, self.decoder = encoder, decoder
        self.src_embed, self.tgt_embed = src_embed, tgt_embed
        self.src_pos, self.tgt_pos = src_pos, tgt_pos
        self.projection_layer = projection_layer

    def encode(self, src, src_mask):
        """Embed ``src``, add positions and run the encoder."""
        embedded = self.src_pos(self.src_embed(src))
        return self.encoder(embedded, src_mask)

    def decode(self, encoder_output: torch.Tensor, src_mask: torch.Tensor, tgt: torch.Tensor, tgt_mask: torch.Tensor):
        """Embed ``tgt``, add positions and run the decoder against ``encoder_output``."""
        embedded = self.tgt_pos(self.tgt_embed(tgt))
        return self.decoder(embedded, encoder_output, src_mask, tgt_mask)

    def project(self, x):
        """Apply the projection layer to ``x``."""
        return self.projection_layer(x)

def build_transformer(src_vocab_size: int, tgt_vocab_size: int, src_seq_len: int, tgt_seq_len: int, d_model, N, h, dropout: float, d_ff: int) -> Transformer:
    """Assemble a ``Transformer`` with ``N`` encoder and ``N`` decoder blocks.

    Every parameter with more than one dimension is re-initialised with
    Xavier-uniform before the model is returned.
    """

    def make_attention():
        return MultiHeadAttentionBlock(d_model, h, dropout)

    def make_feed_forward():
        return FeedForwardBlock(d_model, d_ff, dropout)

    # Embeddings and positional encodings for both sides.
    src_embed = InputEmbeddings(d_model, src_vocab_size)
    tgt_embed = InputEmbeddings(d_model, tgt_vocab_size)
    src_pos = PositionalEncoding(d_model, src_seq_len, dropout)
    tgt_pos = PositionalEncoding(d_model, tgt_seq_len, dropout)

    # Arguments are evaluated left to right, so each block creates its
    # attention module(s) before its feed-forward module.
    encoder_blocks = [
        EncoderBlock(d_model, make_attention(), make_feed_forward(), dropout)
        for _ in range(N)
    ]
    decoder_blocks = [
        DecoderBlock(d_model, make_attention(), make_attention(), make_feed_forward(), dropout)
        for _ in range(N)
    ]

    encoder = Encoder(d_model, nn.ModuleList(encoder_blocks))
    decoder = Decoder(d_model, nn.ModuleList(decoder_blocks))
    projection_layer = ProjectionLayer(d_model, tgt_vocab_size)

    transformer = Transformer(encoder, decoder, src_embed, tgt_embed, src_pos, tgt_pos, projection_layer)

    # Only matrices (dim > 1) are re-initialised; vectors keep their defaults.
    for weight in (p for p in transformer.parameters() if p.dim() > 1):
        nn.init.xavier_uniform_(weight)

    return transformer

In [None]:
import json
import sys
import numpy as np
import torch
from pathlib import Path
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

def get_weights_file_path(config):
    """Return ``config['model_file_path']`` if that path exists on disk, else ``None``.

    A missing key is treated as the empty string.
    """
    checkpoint_path = config.get('model_file_path', '')
    if Path(checkpoint_path).exists():
        return checkpoint_path
    return None

def get_model(config, vocab_size):
    """Build a transformer sized by ``config``.

    The same vocabulary size and sequence length are used for the source
    and target sides.
    """
    seq_len = config["seq_len"]
    return build_transformer(
        vocab_size,
        vocab_size,
        seq_len,
        seq_len,
        d_model=config['d_model'],
        N=config['n_layers'],
        h=config['head'],
        dropout=config['dropout'],
        d_ff=config['d_ff'],
    )

def create_masks(tokenized_text, pad_token_id):
    """Derive the encoder mask, decoder input and decoder mask from a batch of ids.

    Returns:
        ``(encoder_mask, decoder_input, decoder_mask)`` where ``encoder_mask``
        is True at non-padding positions (with two singleton dimensions
        inserted after the first), ``decoder_input`` is a copy of
        ``tokenized_text`` and ``decoder_mask`` is the causal mask for its
        length, placed on the same device.
    """
    decoder_input = tokenized_text.clone()
    target_len = decoder_input.size(1)

    not_padding = tokenized_text != pad_token_id
    encoder_mask = not_padding.unsqueeze(1).unsqueeze(2)
    decoder_mask = causal_mask(target_len).to(tokenized_text.device)

    return encoder_mask, decoder_input, decoder_mask


def train_model(config):
    """Train the transformer on the MLM dataset and checkpoint after every epoch.

    Steps: pick the device, load the JSON dataset, build or load the
    tokenizer, build the dataset/loader/model, optionally resume from the
    checkpoint at ``config['model_file_path']``, then run epochs up to
    ``config['num_epochs']``. The average epoch loss is written to
    TensorBoard under ``config['logs']`` and printed.

    Args:
        config: Mapping as returned by ``get_config()``; this function reads
            ``batch_size``, ``lr``, ``logs``, ``num_epochs`` and
            ``model_file_path`` directly and passes the mapping on to the
            tokenizer, dataset and model builders.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("Using device:", device)
    if device == 'cuda':
        print(f"Device name: {torch.cuda.get_device_name(0)}")
        print(f"Device memory: {torch.cuda.get_device_properties(0).total_memory / 1024 ** 3} GB")
    device = torch.device(device)

    # Load raw dataset (the location is hard-coded here, not read from config)
    with open('/content/drive/MyDrive/Colab Notebooks/T-MLM/dataset.json', 'r', encoding='utf-8') as f:
        raw_ds = json.load(f)

    # Build or get tokenizer (BPE)
    tokenizer = build_or_get_tokenizer(config, raw_ds)

    # Masked Language Model (MLM) Dataset
    train_ds = MLMDataset(config, raw_ds, tokenizer)

    # Create data loader
    data_loader = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True)

    # Get model
    model = get_model(config, tokenizer.get_vocab_size()).to(device)

    # Define loss function and optimizer
    # Targets equal to the [PAD] id are excluded from the loss.
    pad_token_id = tokenizer.token_to_id("[PAD]")
    criterion = CrossEntropyLoss(ignore_index=pad_token_id)
    optimizer = Adam(model.parameters(), lr=config['lr'])

    # TensorBoard
    writer = SummaryWriter(log_dir=config['logs'])

    # Resume from an existing checkpoint when one is found on disk.
    initial_epoch = 0
    global_step = 0
    model_filename = get_weights_file_path(config)
    if model_filename:
        print(f'Preloading model {model_filename}')
        state = torch.load(model_filename, map_location=device)
        model.load_state_dict(state['model_state_dict'])
        # The stored epoch was completed, so continue with the next one.
        initial_epoch = state['epoch'] + 1
        optimizer.load_state_dict(state['optimizer_state_dict'])
        global_step = state['global_step']
    else:
        print('No model to preload, starting from scratch')

    # Training loop
    model.train()
    for epoch in range(initial_epoch, config['num_epochs']):
        total_loss = 0
        num_batches = len(data_loader)
        batch_iterator = tqdm(data_loader, desc=f"Epoch {epoch + 1}/{config['num_epochs']}")
        for batch in batch_iterator:
            tokenized_text = batch['tokenized_text'].to(device)
            original_tokens = batch['original_tokens'].to(device)

            # Create masks; decoder_input is a copy of the masked encoder input
            encoder_mask, decoder_input, decoder_mask = create_masks(tokenized_text, pad_token_id)

            # Forward pass
            optimizer.zero_grad()
            encoder_output = model.encode(tokenized_text, encoder_mask)
            decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask)
            proj_output = model.project(decoder_output)

            # Compute loss: flatten predictions and targets to one row per position.
            # NOTE(review): every non-[PAD] target contributes, not only the
            # positions that were replaced by [MASK] - confirm this is intended.
            loss = criterion(proj_output.view(-1, proj_output.size(-1)), original_tokens.view(-1))
            total_loss += loss.item()

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            global_step += 1
            batch_iterator.set_postfix({'Loss': loss.item()})


        # Log the loss (mean over the batches of this epoch, indexed by epoch)
        avg_loss = total_loss / num_batches
        writer.add_scalar('train_loss', avg_loss, epoch)
        writer.flush()

        print(f"Avg Loss: {avg_loss}")

        # Save the model state; the same file is overwritten every epoch
        model_save_path = config['model_file_path']
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'global_step': global_step
        }, model_save_path)

    writer.close()

# Entry point of the training cell: build the configuration and start
# (or resume) training with it.
if __name__ == '__main__':
    config = get_config()
    train_model(config)


In [None]:
%load_ext tensorboard

import tensorflow as tf
import tensorboard

log_dir = "./content/T-MLM-logs"
%tensorboard --logdir {log_dir}


In [None]:
import torch

def create_masks(tokenized_text, pad_token_id):
    """Build the attention masks and decoder input for a batch of token ids.

    Same helper as the one defined in the training cell.

    Returns:
        ``(encoder_mask, decoder_input, decoder_mask)``: a padding mask that
        is True where the id differs from ``pad_token_id`` (two singleton
        dimensions inserted after the first), a copy of ``tokenized_text``,
        and the causal mask for its length on the same device.
    """
    # Indexing with None inserts the two singleton dimensions.
    encoder_mask = (tokenized_text != pad_token_id)[:, None, None]
    decoder_input = tokenized_text.clone()
    decoder_mask = causal_mask(decoder_input.size(1)).to(tokenized_text.device)
    return encoder_mask, decoder_input, decoder_mask

def mlm_inference(model, tokenizer, input_text, mask_token, device):
    """Fill every masked position of ``input_text`` with the model's top prediction.

    The text is encoded as ``[SOS] + ids + [EOS]`` (a batch of one, so no
    padding is needed), run through encode/decode/project, and the argmax
    prediction replaces each position whose id equals ``mask_token``. All
    other positions keep the input token.

    Args:
        model: Model exposing ``encode``, ``decode`` and ``project``.
        tokenizer: Tokenizer providing ``encode``, ``decode`` and ``token_to_id``.
        input_text: Text to complete; masked slots are written as ``[MASK]``.
        mask_token: Integer id of the ``[MASK]`` token.
        device: Device on which the model lives.

    Returns:
        The decoded text with masked positions replaced by predictions.

    Raises:
        ValueError: If the encoded input is longer than the positional
            encoding table of the model (``model.src_pos.seq_len``), when
            that attribute is available.
    """
    model.eval()

    sos_id = tokenizer.token_to_id("[SOS]")
    eos_id = tokenizer.token_to_id("[EOS]")
    pad_token_id = tokenizer.token_to_id("[PAD]")

    # Tokenize input text and frame it like the training examples.
    input_tokens = tokenizer.encode(input_text).ids
    token_ids = [sos_id] + list(input_tokens) + [eos_id]

    # Reject inputs the positional encoding cannot cover.
    max_len = getattr(getattr(model, "src_pos", None), "seq_len", None)
    if max_len is not None and len(token_ids) > max_len:
        raise ValueError("Sentence is too long")

    # Add the batch dimension the model expects: (1, seq_len).
    tokenized_text = torch.tensor([token_ids], dtype=torch.int64, device=device)

    # Create the attention masks.
    encoder_mask, decoder_input, decoder_mask = create_masks(tokenized_text, pad_token_id)

    with torch.no_grad():
        encoder_output = model.encode(tokenized_text, encoder_mask)
        decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask)
        proj_output = model.project(decoder_output)

        # Most likely vocabulary entry at every position.
        predicted_tokens = proj_output.argmax(dim=-1)

    # Keep the input tokens; substitute predictions only where [MASK] was.
    filled = torch.where(tokenized_text == mask_token, predicted_tokens, tokenized_text)

    # Drop the [SOS]/[EOS] frame before decoding.
    return tokenizer.decode(filled[0, 1:-1].tolist())


def get_model(config, vocab_size):
    """Build a transformer from ``config``; source and target share vocabulary size and length."""
    architecture = dict(
        d_model=config['d_model'],
        N=config['n_layers'],
        h=config['head'],
        dropout=config['dropout'],
        d_ff=config['d_ff'],
    )
    return build_transformer(vocab_size, vocab_size, config["seq_len"], config["seq_len"], **architecture)

def get_weights_file_path(config):
    """Return the configured checkpoint path when it exists on disk, otherwise ``None``."""
    candidate = config.get('model_file_path', '')
    found = Path(candidate).exists()
    return candidate if found else None

# Interactive inference: load the tokenizer and (if present) the trained
# weights, then complete every line typed by the user until "exit" is entered.
if __name__ == '__main__':
    config = get_config()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = Tokenizer.from_file(config['tokenizer_file'])
    model = get_model(config, tokenizer.get_vocab_size()).to(device)

    model_filename = get_weights_file_path(config)
    if model_filename:
        print(f'Preloading model {model_filename}')
        state = torch.load(model_filename, map_location=device)
        model.load_state_dict(state['model_state_dict'])
    else:
        print('No model to preload, starting from scratch')

    # The id does not change between prompts, so look it up once.
    mask_token_id = tokenizer.token_to_id('[MASK]')

    # Example input: "The universe is [MASK] and mysterious"
    while True:
        input_text = input()
        # Stop before running the model on the sentinel itself.
        if input_text == "exit":
            break

        inferred_output = mlm_inference(model, tokenizer, input_text, mask_token_id, device)

        print(f"Original text: {input_text}")
        print(f"Predicted text: {inferred_output}")
