In [None]:
import os
import random
import matplotlib.pyplot as plt
import numpy as np
from tqdm.auto import tqdm
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import Dataset, DataLoader
from transformers import RobertaTokenizer, get_cosine_schedule_with_warmup

### Tokenizer

In [None]:
# Load the pretrained RoBERTa tokenizer and inspect which special tokens it defines.
tokenizer = RobertaTokenizer.from_pretrained("FacebookAI/roberta-base")

print("Special Tokens in Roberta Tokenizer")
special_tokens = tokenizer.special_tokens_map
print(special_tokens)

# Map each special-token string to its vocabulary id by encoding the token on its
# own (no automatic <s>/</s> wrapping) and keeping the first id produced.
special_token_idx = {
    token: tokenizer.encode(token, add_special_tokens=False)[0]
    for token in special_tokens.values()
}

print(f"Special Token Index: {special_token_idx}")

# Partition the vocabulary ids into special and non-special ids.
all_tokens_idx = list(range(tokenizer.vocab_size))
all_special_tokens_idx = sorted(special_token_idx.values())
_special_id_lookup = set(all_special_tokens_idx)
all_non_special_tokens_idx = [idx for idx in all_tokens_idx if idx not in _special_id_lookup]

### Prepare Dataset

In [None]:
path_to_data = "./data/harry_potter_txt/"

# os.listdir returns entries in arbitrary order; sorting keeps the corpus (and
# therefore the dataset indices used by the train/validation split) stable across runs.
text_files = sorted(os.listdir(path_to_data))

cleaned_books = []
for book in text_files:
    with open(os.path.join(path_to_data, book), "r") as f:
        # Skip every line containing the substring "Page" (presumably page markers;
        # note this also drops any narrative line that happens to contain that word).
        lines = [line for line in f.readlines() if "Page" not in line]

    text = " ".join(lines).replace("\n", "")
    # Splitting on a single space yields empty strings for runs of spaces;
    # dropping them collapses each run to one space.
    text = " ".join(word for word in text.split(" ") if len(word) > 0)
    cleaned_books.append(text)

# Join the books with a space so the last word of one book is not fused with the
# first word of the next one.
all_text = " ".join(cleaned_books)
all_text = all_text.split(".")

# Group every 5 consecutive "."-separated pieces into one training chunk.
all_text_chunked = [".".join(all_text[i:i+5]) for i in range(0, len(all_text), 5)]

# tokenizer.encode wraps each chunk with the tokenizer's special tokens.
tokenized_text = [tokenizer.encode(text) for text in all_text_chunked]

In [None]:
class MaskedLMLoader(Dataset):
    """Dataset of pre-tokenized sequences with random MLM masking applied on access.

    Each item is a `(masked_tokens, labels)` pair of 1-D tensors of equal length.
    `labels` holds the original token id at every position selected for prediction
    and -100 everywhere else.

    Relies on the module-level `tokenizer`, `special_token_idx` and
    `all_non_special_tokens_idx` defined earlier in the notebook.

    Args:
        tokenized_data: indexable collection of token-id sequences.
        max_seq_len: sequences longer than this are randomly cropped to this length.
        masking_ratio: probability that a non-special token is selected for prediction.
    """

    def __init__(self, tokenized_data, max_seq_len=100, masking_ratio=0.15):
        self.data = tokenized_data
        self.mask_ratio = masking_ratio
        self.max_seq_len = max_seq_len

    def __len__(self):
        """Number of sequences in the dataset."""
        return len(self.data)

    def _random_mask_text(self, tokens):
        """Select positions for prediction and corrupt `tokens` in place.

        Of the selected positions, ~80% are replaced by `<mask>`, ~10% by a random
        non-special token, and ~10% are left unchanged.

        Returns:
            (tokens, labels): the corrupted token tensor and the label tensor.
        """
        # One uniform draw per position; a position is selected when its draw
        # falls below the masking ratio.
        random_masking = torch.rand(*tokens.shape)

        # Force the draw of special tokens to 1 so they can never be selected.
        special_tokens = torch.tensor(
            tokenizer.get_special_tokens_mask(tokens.tolist(), already_has_special_tokens=True)
        )
        random_masking[special_tokens == 1] = 1

        random_masking = random_masking < self.mask_ratio

        # Labels must be captured before any token is overwritten below.
        labels = torch.full_like(tokens, -100)
        labels[random_masking] = tokens[random_masking]

        selected_idx = random_masking.nonzero(as_tuple=True)[0]

        # 80% of the selected positions become <mask>.
        use_mask_token = torch.rand(len(selected_idx)) < 0.8
        selected_idx_for_masking = selected_idx[use_mask_token]
        remaining_idx = selected_idx[~use_mask_token]

        # The remaining 20% is split evenly: half get a random token, half stay as-is.
        use_random_token = torch.rand(len(remaining_idx)) < 0.5
        selected_idx_for_random_filling = remaining_idx[use_random_token]

        if len(selected_idx_for_masking) > 0:
            tokens[selected_idx_for_masking] = special_token_idx["<mask>"]

        if len(selected_idx_for_random_filling) > 0:
            randomly_selected_tokens = torch.tensor(
                random.sample(all_non_special_tokens_idx, len(selected_idx_for_random_filling))
            )
            tokens[selected_idx_for_random_filling] = randomly_selected_tokens

        return tokens, labels

    def __getitem__(self, idx):
        """Return a freshly masked `(tokens, labels)` pair for sequence `idx`."""
        # torch.tensor builds a new tensor, so the in-place masking never touches self.data.
        data = torch.tensor(self.data[idx])

        if len(data) > self.max_seq_len:
            # randint is inclusive on both ends, so the last full window
            # (start = len - max_seq_len) can be chosen too.
            rand_start_idx = random.randint(0, len(data) - self.max_seq_len)
            data = data[rand_start_idx:rand_start_idx + self.max_seq_len]

        return self._random_mask_text(data)

# Build the dataset with default settings and print the first masked sample
# together with its labels as a quick sanity check.
mlm = MaskedLMLoader(tokenized_text)

for sample in mlm:
    masked_tokens, labels = sample
    print(masked_tokens, labels, sep="\n")
    break

In [None]:
def collate_fn(batch):
    """Pad variable-length `(tokens, labels)` samples to a common length and stack them.

    Tokens are right-padded with the `<pad>` id (read from the module-level
    `special_token_idx`), labels are right-padded with -100.

    Args:
        batch: non-empty list of `(tokens, labels)` pairs of 1-D tensors.

    Returns:
        dict with
          "input_ids":      (batch, max_len) padded token ids,
          "labels":         (batch, max_len) padded labels,
          "attention_mask": (batch, max_len) bool tensor, True at padded positions.

    Raises:
        ValueError: if `batch` is empty.
    """
    if len(batch) == 0:
        raise ValueError("collate_fn received an empty batch")

    pad_id = special_token_idx["<pad>"]
    max_seq_len = max(len(tokens) for tokens, _ in batch)
    first_tokens, first_labels = batch[0]

    # Pre-fill with the padding values, then copy each sample into the front of its row.
    token_samples = first_tokens.new_full((len(batch), max_seq_len), pad_id)
    label_samples = first_labels.new_full((len(batch), max_seq_len), -100)
    # The mask is derived from sequence lengths (not from comparing ids with <pad>)
    # and is bool for every row.
    padding_masks = torch.ones((len(batch), max_seq_len), dtype=torch.bool, device=first_tokens.device)

    for row, (tokens, labels) in enumerate(batch):
        seq_len = len(tokens)
        token_samples[row, :seq_len] = tokens
        label_samples[row, :seq_len] = labels
        padding_masks[row, :seq_len] = False

    assert token_samples.shape == label_samples.shape == padding_masks.shape

    batch = {"input_ids": token_samples,
             "labels": label_samples,
             "attention_mask": padding_masks}

    return batch
    
# Batch the demo dataset and print every field of the first collated batch.
dataloader = DataLoader(mlm, batch_size=16, collate_fn=collate_fn)

for batch in dataloader:
    for key in ("input_ids", "labels", "attention_mask"):
        print(batch[key])
    break

In [None]:
class SelfAttentionEncoder(nn.Module):
    """Multi-head self-attention with a fused QKV projection.

    Args:
        embed_dim: model width; must be divisible by `num_heads`.
        num_heads: number of attention heads.
        attn_p: dropout probability applied to the attention weights.
        proj_p: dropout probability applied after the output projection.
    """

    def __init__(self, embed_dim=768, num_heads=12, attn_p=0, proj_p=0):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.attn_p = attn_p
        self.attn_drop = nn.Dropout(attn_p)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(proj_p)

    def forward(self, x, attention_mask=None):
        """Apply self-attention to `x` of shape (batch, seq_len, embed_dim).

        `attention_mask`, when given, is a bool tensor that is True at key
        positions to be ignored; it is broadcast over heads and query positions.
        """
        batch_size, seq_len, embed_dim = x.shape

        # (batch, seq, 3*embed) -> (3, batch, heads, seq, head_dim)
        projected = self.qkv(x).reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
        q, k, v = projected.permute(2, 0, 3, 1, 4).unbind(0)

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        if attention_mask is not None:
            # Insert two singleton axes after the batch axis so the mask
            # broadcasts against (batch, heads, query, key).
            scores = scores.masked_fill(attention_mask[:, None, None], float('-inf'))

        weights = self.attn_drop(torch.softmax(scores, dim=-1))
        context = torch.matmul(weights, v)

        # (batch, heads, seq, head_dim) -> (batch, seq, embed_dim)
        merged = context.transpose(1, 2).reshape(batch_size, seq_len, embed_dim)
        return self.proj_drop(self.proj(merged))

# Shape check for the attention layer: two sequences of length 5, where the
# last two positions of the first sequence are marked as padding.
rand_x = torch.randn(2, 5, 16)
padding = torch.tensor([[False] * 3 + [True] * 2,
                        [False] * 5])

a = SelfAttentionEncoder(embed_dim=16, num_heads=4)
out = a(rand_x, padding)
print(out.shape)

In [None]:
class MLP(nn.Module):
    """Two-layer feed-forward network: Linear -> activation -> dropout -> Linear -> dropout.

    Args:
        in_features: size of the input's last dimension.
        hidden_features: width of the hidden layer.
        out_features: size of the output's last dimension.
        act_layer: activation module class, instantiated with no arguments.
        mlp_p: dropout probability used after the activation and after the second layer.
    """

    def __init__(self, in_features, hidden_features, out_features, act_layer=nn.GELU, mlp_p=0):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.drop1 = nn.Dropout(mlp_p)
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop2 = nn.Dropout(mlp_p)

    def forward(self, x):
        """Apply the feed-forward network along the last dimension of `x`."""
        hidden = self.drop1(self.act(self.fc1(x)))
        return self.drop2(self.fc2(hidden))


class Block(nn.Module):
    """Pre-norm transformer encoder block: self-attention and MLP, each with a residual connection.

    Args:
        embed_dim: model width.
        num_heads: number of attention heads.
        mlp_ratio: hidden width of the MLP as a multiple of `embed_dim`.
        proj_p: dropout after the attention output projection.
        attn_p: dropout on the attention weights.
        mlp_p: dropout inside the MLP.
        act_layer: activation module class used by the MLP.
        norm_layer: normalization module class, called as `norm_layer(embed_dim, eps=1e-6)`.
    """

    def __init__(self,
                 embed_dim=768,
                 num_heads=12,
                 mlp_ratio=4,
                 proj_p=0.,
                 attn_p=0.,
                 mlp_p=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm):
        super().__init__()

        # Attention sub-layer and its normalization.
        self.norm1 = norm_layer(embed_dim, eps=1e-6)
        self.attn = SelfAttentionEncoder(
            embed_dim=embed_dim, num_heads=num_heads, attn_p=attn_p, proj_p=proj_p
        )

        # Feed-forward sub-layer and its normalization.
        self.norm2 = norm_layer(embed_dim, eps=1e-6)
        hidden_width = int(embed_dim * mlp_ratio)
        self.mlp = MLP(
            in_features=embed_dim,
            hidden_features=hidden_width,
            out_features=embed_dim,
            act_layer=act_layer,
            mlp_p=mlp_p,
        )

    def forward(self, x, attention_mask=None):
        """Run the block on `x`; `attention_mask` is forwarded unchanged to the attention layer."""
        attn_out = self.attn(self.norm1(x), attention_mask)
        x = x + attn_out
        mlp_out = self.mlp(self.norm2(x))
        return x + mlp_out

In [None]:
class Roberta(nn.Module):
    """RoBERTa-style encoder with a token-level vocabulary head for masked language modelling.

    A learned CLS embedding is prepended to every sequence internally and stripped
    again before the head, so the output has one row of logits per input token.

    Args:
        max_seq_len: maximum number of input tokens (excluding the internal CLS slot).
        vocab_size: size of the token vocabulary.
        embed_dim: model width.
        depth: number of transformer blocks.
        num_heads: attention heads per block.
        mlp_ratio: MLP hidden width as a multiple of `embed_dim`.
        attn_p, mlp_p, proj_p, pos_p: dropout probabilities for attention weights,
            MLP, attention projection and the embedded input respectively.
        act_layer: activation module class for the MLPs.
        norm_layer: normalization module class.
    """

    def __init__(self,
                 max_seq_len = 512,
                 vocab_size = tokenizer.vocab_size,
                 embed_dim = 768,
                 depth = 12,
                 num_heads = 12,
                 mlp_ratio = 4,
                 attn_p = 0.,
                 mlp_p = 0.,
                 proj_p = 0.,
                 pos_p = 0.,
                 act_layer = nn.GELU,
                 norm_layer = nn.LayerNorm):
        super(Roberta,self).__init__()

        self.max_seq_len = max_seq_len
        self.embeddings = nn.Embedding(vocab_size,embed_dim)
        self.cls_token = nn.Parameter(torch.zeros(1,1,embed_dim))
        # +1 position for the prepended CLS slot.
        self.pos_embed = nn.Embedding(max_seq_len+1,embed_dim)
        self.pos_drop = nn.Dropout(pos_p)

        self.blocks = nn.ModuleList(
            [
                Block(
                    embed_dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    proj_p=proj_p,
                    attn_p=attn_p,
                    mlp_p=mlp_p,
                    act_layer=act_layer,
                    norm_layer=norm_layer
                )
                for _ in range(depth)
            ]
        )
        self.norm = norm_layer(embed_dim)
        self.head = nn.Linear(embed_dim,vocab_size)

    def forward(self, x, attention_mask=None):
        """Compute per-token vocabulary logits.

        Args:
            x: (batch, seq_len) tensor of token ids.
            attention_mask: optional (batch, seq_len) bool tensor, True at padded
                positions that attention should ignore.

        Returns:
            (batch, seq_len, vocab_size) logits; if the input is longer than
            `max_seq_len` only the last `max_seq_len` tokens are kept.
        """
        device = x.device

        batch_size, seq_len = x.shape

        if seq_len > self.max_seq_len:
            # Keep the most recent tokens, and keep the mask and length in step with them.
            x = x[:, -self.max_seq_len:]
            if attention_mask is not None:
                attention_mask = attention_mask[:, -self.max_seq_len:]
            seq_len = self.max_seq_len

        # One position per token plus one for CLS. torch.arange is end-exclusive,
        # unlike the deprecated end-inclusive torch.range.
        avail_idx = torch.arange(seq_len + 1, dtype=torch.long, device=device)

        tok_emb = self.embeddings(x)

        cls_token = self.cls_token.expand(batch_size, -1, -1)
        tok_emb = torch.cat((cls_token, tok_emb), dim=1)

        pos_emb = self.pos_embed(avail_idx)
        x = self.pos_drop(tok_emb + pos_emb)

        if attention_mask is not None:
            # The prepended CLS slot is never padding, so extend the mask with a
            # leading all-False column to match the new sequence length.
            cls_mask = attention_mask.new_zeros((batch_size, 1))
            attention_mask = torch.cat((cls_mask, attention_mask), dim=1)

        for block in self.blocks:
            x = block(x, attention_mask)

        # The blocks are pre-norm, so the residual stream is normalized once
        # more before the output head.
        x = self.norm(x)

        # Drop the CLS slot so the logits line up with the input tokens.
        x = x[:, 1:]
        return self.head(x)


# Sample token ids and padding mask at module level (not class attributes) so the
# smoke-test cell that follows can use them.
rand_x = torch.randint(0,10,(2,5))
padding = torch.tensor([[False,False,False,True,True],
                        [False,False,False,False,False]])
   

In [None]:
# Smoke test: a default-sized model should return one row of vocabulary logits
# per input token.
model = Roberta()
out = model(rand_x, attention_mask=padding)
print(out.shape)

In [None]:
# Training configuration.
iterations = 15000
max_len = 100
evaluate_interval = 100
embedding_dim = 384
depth = 4
num_heads = 4
lr = 0.0005
batch_size = 64

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

model = Roberta(max_seq_len=max_len, 
                embed_dim=embedding_dim, 
                depth=depth, 
                num_heads=num_heads, 
                attn_p=0.1, 
                mlp_p=0.1, 
                proj_p=0.1, 
                pos_p=0.1)

model = model.to(DEVICE)
optimizer = optim.AdamW(model.parameters(), lr=lr)
# CrossEntropyLoss ignores targets equal to -100 by default, which is the value
# the dataset and collate_fn use for positions that are not predicted.
loss_fn = nn.CrossEntropyLoss()

dataset = MaskedLMLoader(tokenized_text, max_seq_len=max_len)
num_train_samples = int(0.95 * len(dataset))
trainset, testset = torch.utils.data.random_split(dataset, [num_train_samples, len(dataset) - num_train_samples])
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

if len(trainloader) == 0:
    # Without this guard the `while train` loop below would spin forever.
    raise ValueError("Training set is empty; nothing to train on.")

scheduler = get_cosine_schedule_with_warmup(optimizer=optimizer, 
                                            num_warmup_steps=1500, 
                                            num_training_steps=iterations)


def masked_lm_loss(batch):
    """Move one collated batch to DEVICE and return the masked-LM cross-entropy loss."""
    inputs = batch["input_ids"].to(DEVICE)
    labels = batch["labels"].to(DEVICE)
    mask = batch["attention_mask"].to(DEVICE)

    prediction = model(inputs, mask)

    # Flatten (batch, seq, vocab) -> (batch*seq, vocab) and (batch, seq) -> (batch*seq).
    prediction = prediction.reshape(-1, prediction.shape[-1])
    labels = labels.reshape(-1)
    return loss_fn(prediction, labels)


train = True
completed_steps = 0
all_training_losses, all_validation_losses = [], []
training_losses, validation_losses = [], []
progress_bar = tqdm(range(iterations))

model.train()
while train:

    for batch in trainloader:

        loss = masked_lm_loss(batch)
        training_losses.append(loss.item())

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        
        completed_steps += 1
        progress_bar.update(1)
            
        if completed_steps % evaluate_interval == 0:
            # Disable dropout while measuring validation loss, then restore
            # training mode before the next optimizer step.
            model.eval()
            with torch.no_grad():
                for val_batch in testloader:
                    validation_losses.append(masked_lm_loss(val_batch).item())
            model.train()

            avg_training_loss = np.mean(training_losses)
            avg_eval_loss = np.mean(validation_losses)
            
            print(f"Iteration {completed_steps} Training Loss:", avg_training_loss, "Validation Loss:", avg_eval_loss)

            all_training_losses.append(avg_training_loss)
            all_validation_losses.append(avg_eval_loss)

            # Start fresh running averages for the next evaluation window.
            training_losses = []
            validation_losses = []

        if completed_steps >= iterations:
            train = False
            break

progress_bar.close()

In [None]:
# One loss value was recorded every `evaluate_interval` optimizer steps, so
# rescale the x-axis from "evaluation number" to the actual iteration count.
eval_steps = np.arange(1, len(all_training_losses) + 1) * evaluate_interval

plt.plot(eval_steps, all_training_losses, label="Training")
plt.plot(eval_steps, all_validation_losses, label="Validation")
plt.xlabel("Iteration")
plt.ylabel("Loss")
plt.legend()
plt.show()

In [None]:
def fill_in_the_gap(masked_sentence, k=3):
    """Return the model's top-`k` candidate tokens for the single `<mask>` in a sentence.

    Uses the module-level `model`, `tokenizer`, `DEVICE` and `special_token_idx`.
    The model is switched to eval mode and left in that mode.

    Args:
        masked_sentence: text containing exactly one `<mask>` token.
        k: number of candidates to return.

    Returns:
        List of `k` decoded token strings, most likely first.

    Raises:
        ValueError: if the encoded sentence does not contain exactly one `<mask>`.
    """
    model.eval()
    encoded_text = tokenizer.encode(masked_sentence)
    input_tokens = torch.tensor(encoded_text).to(DEVICE)

    mask_positions = (input_tokens == special_token_idx["<mask>"]).nonzero().flatten()
    if len(mask_positions) != 1:
        raise ValueError(
            f"masked_sentence must contain exactly one <mask> token, found {len(mask_positions)}"
        )
    masked_token_index = mask_positions.item()

    with torch.no_grad():
        output = model(input_tokens.unsqueeze(0), attention_mask=None)

    predicted_output = output[0, masked_token_index]

    top_k_predicted = torch.topk(predicted_output, k=k).indices

    # Decode each id on its own: decoding them as one string and splitting on
    # spaces would glue sub-word pieces together and return fewer than k items.
    predicted_words = [tokenizer.decode([token_id]).strip() for token_id in top_k_predicted.tolist()]

    return predicted_words

In [None]:
# Ask fill_in_the_gap for its candidates (default k=3) for the <mask> slot; as the
# cell's last expression, the returned list is displayed as the cell output.
sentence = "The Wizarding <mask> is a wonderful place to visit"
fill_in_the_gap(masked_sentence=sentence)

In [None]:
# Second probe sentence: the returned candidate list for the <mask> slot is shown
# as the cell output.
sentence = "Lord Voldemort is a <mask> wizard"
fill_in_the_gap(masked_sentence=sentence)