In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import random

#replace with pytorch lightning
import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

from sklearn.preprocessing import MinMaxScaler    
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report
import string

In [None]:
# Use the first CUDA GPU when one is visible, otherwise the CPU.
# torch.device(...) already returns a device object; no second conversion needed.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed every RNG this notebook draws from (Python `random`, NumPy, PyTorch) so
# weight initialisation, DataLoader shuffling and token sampling are reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

print(device)

In [None]:
def causal_attention_mask(batch_size, n_dest, n_src, dtype):
    """
    Build the boolean mask that stops each query from attending to future keys.

    ``True`` marks a (query, key) pair that must be hidden, which is the boolean
    ``attn_mask`` convention of ``nn.MultiheadAttention``. Query row ``i`` may
    see key column ``j`` only when ``j <= i + (n_src - n_dest)``. When
    ``n_src == n_dest`` this is the usual lower triangle including the diagonal.

    Args:
        batch_size: Unused; kept so existing call sites stay valid.
        n_dest: Number of query positions (rows of the mask).
        n_src: Number of key positions (columns of the mask).
        dtype: Unused; the result is always ``torch.bool``.

    Returns:
        A ``(n_dest, n_src)`` bool tensor that is ``True`` strictly above the
        (possibly shifted) diagonal.
    """
    # torch.arange replaces the deprecated torch.range (inclusive end, float
    # output, fails for empty sizes). Shifting both index vectors from 1-based
    # to 0-based leaves the comparison unchanged.
    query_idx = torch.arange(n_dest)[:, None]
    key_idx = torch.arange(n_src)
    allowed = query_idx >= key_idx - n_src + n_dest
    return ~allowed

In [None]:
# Sanity check: for a 5x5 mask, True should fill the strict upper triangle (future keys).
causal_attention_mask(batch_size=2, n_dest=5, n_src=5, dtype=torch.bool)

In [None]:
def padding_mask(input, pad_id=0):
    """
    Mark the padding positions in a tensor of token ids.

    Args:
        input: Tensor of token ids (any shape).
        pad_id: Id that denotes padding. Defaults to 0, the value the
            ``padding`` step appends to short ``input_ids``.

    Returns:
        A bool tensor shaped like ``input`` that is ``True`` where
        ``input == pad_id``. This matches the ``key_padding_mask`` convention
        of ``nn.MultiheadAttention`` (``True`` = ignore this key).
    """
    # NOTE(review): with the GPT-2 tokenizer, id 0 is the real token "!", so a
    # literal "!" is masked as padding too. Switch to a dedicated pad id (e.g.
    # tokenizer.eos_token_id) here and in the padding step if that matters.
    return torch.eq(input, pad_id)

In [None]:
# Two right-padded rows; the result should be True exactly on the trailing zeros.
example_ids = torch.tensor([[1, 2, 3, 0, 0, 0], [2, 0, 0, 0, 0, 0]])
padding_mask(example_ids)

In [None]:
class TransformerBlock(nn.Module):
    """Causal multi-head self-attention followed by a residual add.

    Args:
        embed_dim: Width of the token representations.
        num_heads: Number of attention heads.
        batch_first: Forwarded to ``nn.MultiheadAttention``. True means inputs
            are ``(batch, seq, embed_dim)``; False means ``(seq, batch, embed_dim)``.
        rate: Not used by this block (no dropout is applied); kept so existing
            constructor calls keep working.
    """

    def __init__(self, embed_dim, num_heads, batch_first, rate=0.1):
        super(TransformerBlock, self).__init__()
        self.attention = nn.MultiheadAttention(embed_dim, num_heads, batch_first=batch_first)

    def forward(self, inputs, pad_mask):
        """Run causal self-attention over ``inputs`` and add it back residually.

        Args:
            inputs: Token representations in the layout selected by ``batch_first``.
            pad_mask: Bool key padding mask ``(batch, seq)``; ``True`` = ignore that key.

        Returns:
            ``(inputs + attention_output, attention_weights)``. The weights are
            per head because ``average_attn_weights=False``.
        """
        # Which axis holds the sequence depends on how the attention layer was built.
        if self.attention.batch_first:
            batch_size, seq_len = inputs.size(0), inputs.size(1)
        else:
            seq_len, batch_size = inputs.size(0), inputs.size(1)
        # Tensor.to is not in-place, so the result must be reassigned (the
        # original discarded it). Follow the activations' device rather than
        # the notebook-global `device`.
        pad_mask = pad_mask.to(inputs.device)
        causal_mask = causal_attention_mask(batch_size, seq_len, seq_len, torch.bool).to(inputs.device)
        attention_output, attention_weights = self.attention(
            inputs,
            inputs,
            inputs,
            key_padding_mask=pad_mask,
            attn_mask=causal_mask,
            need_weights=True,
            average_attn_weights=False,
        )
        return inputs + attention_output, attention_weights

In [None]:
class TokenAndPositionEmbedding(nn.Module):
    """Learned token embeddings plus learned absolute position embeddings.

    Args:
        maxlen: Number of position embeddings, i.e. the longest sequence
            ``forward`` accepts.
        vocab_size: Number of token ids.
        embed_dim: Embedding width.
    """

    def __init__(self, maxlen, vocab_size, embed_dim):
        super(TokenAndPositionEmbedding, self).__init__()
        # max_norm=1: looked-up token vectors with norm > 1 are renormalised
        # (nn.Embedding updates its weight in place when max_norm is set).
        self.token_emb = nn.Embedding(vocab_size, embed_dim, max_norm=1)
        self.pos_emb = nn.Embedding(maxlen, embed_dim)

    def forward(self, x):
        """Embed a batch of token ids.

        Args:
            x: Integer token-id tensor with the sequence on the last dimension.

        Returns:
            ``(embeddings, pad_mask)``. ``embeddings`` has shape
            ``x.shape + (embed_dim,)`` and ``pad_mask = padding_mask(x)``.

        Raises:
            ValueError: if the sequence is longer than the position table.
        """
        seq_len = x.size(-1)
        if seq_len > self.pos_emb.num_embeddings:
            # Fail loudly instead of an index error / CUDA device-side assert.
            raise ValueError(
                f"sequence length {seq_len} exceeds the "
                f"{self.pos_emb.num_embeddings} available position embeddings"
            )
        pad_mask = padding_mask(x)
        # torch.arange on x's device replaces deprecated torch.range + global `device`.
        positions = self.pos_emb(torch.arange(seq_len, device=x.device))
        return self.token_emb(x) + positions, pad_mask

In [None]:
vocab_size = 50257  # GPT-2 tokenizer vocabulary size (28996 was the bert-base-cased value)
maxlen = 60  # Max sequence size (tokens per row after truncation/padding)
embed_dim = 128  # Embedding size for each token
num_heads = 8  # Number of attention heads
feed_forward_dim = 128  # Output width of MLP1/MLP2; must equal embed_dim for the residual adds


class Model(nn.Module):
    """Two-layer causal language model.

    Pipeline: embeddings -> [attention block -> residual linear] x 2 -> vocab logits.
    The "MLP" layers are single linear maps with residual adds; there is no
    non-linearity or normalisation between them.
    """

    def __init__(self):
        super(Model, self).__init__()
        if feed_forward_dim != embed_dim:
            raise ValueError("feed_forward_dim must equal embed_dim: MLP outputs are added residually")
        self.embedding_layer = TokenAndPositionEmbedding(maxlen, vocab_size, embed_dim)
        self.transformer_block1 = TransformerBlock(embed_dim, num_heads, True)
        self.transformer_block2 = TransformerBlock(embed_dim, num_heads, True)
        # Explicit in_features instead of nn.LazyLinear. The parameters exist
        # right after construction, so the optimizer and load_state_dict see
        # real tensors without a dry-run forward pass. Shapes and state-dict
        # keys are unchanged.
        self.MLP1 = nn.Linear(embed_dim, feed_forward_dim)
        self.MLP2 = nn.Linear(embed_dim, feed_forward_dim)
        self.outputs = nn.Linear(embed_dim, vocab_size)

    def forward(self, x):
        """Compute next-token logits.

        Args:
            x: ``(batch, seq)`` token ids, with ``seq <= maxlen``.

        Returns:
            ``(logits, attention)``. ``logits`` has shape
            ``(batch, seq, vocab_size)``; ``attention`` holds the per-head
            weights of the second attention block only.
        """
        x, pad_mask = self.embedding_layer(x)
        x, _ = self.transformer_block1(x, pad_mask)
        x = x + self.MLP1(x)
        x, attention = self.transformer_block2(x, pad_mask)
        x = x + self.MLP2(x)
        return self.outputs(x), attention


In [None]:
# from datasets import load_dataset

# dataset = load_dataset("wiki_bio")

In [None]:
import os
import re

import pandas as pd
from datasets import load_dataset

# IMDB review folders. All four (train AND test) end up in a single "train"
# split below. NOTE(review): the original test reviews are therefore used for
# training; confirm that is intended.
directories = [
    "/kaggle/input/aclimdb-v1/aclImdb/train/pos",
    "/kaggle/input/aclimdb-v1/aclImdb/train/neg",
    "/kaggle/input/aclimdb-v1/aclImdb/test/pos",
    "/kaggle/input/aclimdb-v1/aclImdb/test/neg",
]

# sorted() makes row order independent of the filesystem's os.listdir order,
# and `directory` avoids shadowing the built-in `dir`.
filenames = [
    os.path.join(directory, name)
    for directory in directories
    for name in sorted(os.listdir(directory))
]

# The "text" loader (by default) yields one example per line of each file.
dataset = load_dataset("text", data_files=filenames)

def processing(s):
    """Normalise one review for tokenisation.

    Lower-cases the text, turns ``<br />`` tags into spaces, and puts a space in
    front of every ASCII punctuation character so it is split from the word.

    Args:
        s: Example dict with a ``'text'`` string; updated in place.

    Returns:
        The same dict, as expected by ``datasets.map``.
    """
    text = s['text'].lower()
    text = text.replace("<br />", " ")
    # re.escape is needed: string.punctuation contains a backslash followed by
    # "]". Dropped raw into [...], that pair escapes the "]" and silently
    # leaves the backslash out of the character set.
    s['text'] = re.sub(f"([{re.escape(string.punctuation)}])", r" \1", text)
    return s

# Apply `processing` to every row: lower-case, replace "<br />" with spaces and
# split punctuation off words. `dataset` is rebound, so re-running this cell
# processes already-processed text a second time.
dataset=dataset.map(processing)



In [None]:
# from transformers import AutoTokenizer
# tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

# GPT-2 tokenizer. Its vocabulary must agree with `vocab_size`, which sizes the
# model's token embedding and output layer.
from transformers import GPT2Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

In [None]:
def _tokenize(example):
    # Encode one row, truncating to at most `maxlen` tokens.
    return tokenizer(example["text"], truncation=True, max_length=maxlen)


dataset = dataset.map(_tokenize)

In [None]:
def padding(s, pad_id=0, max_len=None):
    """Force ``s['input_ids']`` to exactly ``max_len`` tokens.

    Short rows are right-padded with ``pad_id``; over-long rows are truncated,
    so every row has the same length and can be stacked into one tensor.

    Args:
        s: Example dict with an ``'input_ids'`` list; updated in place.
        pad_id: Id appended to short rows. Defaults to 0, the id that
            ``padding_mask`` treats as padding by default.
        max_len: Target length. ``None`` means the notebook-level ``maxlen``.

    Returns:
        The same dict, as expected by ``datasets.map``.
    """
    if max_len is None:
        max_len = maxlen
    ids = s['input_ids']
    if len(ids) < max_len:
        s['input_ids'] = ids + [pad_id] * (max_len - len(ids))
    elif len(ids) > max_len:
        s['input_ids'] = ids[:max_len]
    return s
                                           
# Bring every row's input_ids to `maxlen` tokens with `padding`.
# NOTE(review): only input_ids is padded; other columns produced by the
# tokenizer keep their original length (only input_ids is used downstream).
dataset=dataset.map(padding)


In [None]:
# Display a summary of the processed dataset.
dataset

In [None]:
# Build the model and move it to `device`. Module.to returns the module itself,
# so `model` is the moved object.
model = Model().to(device)
# To resume training: model.load_state_dict(torch.load("/kaggle/input/weights3/transformer_weights-3.pth"))
model

<!-- # model= Model()
# model.to(device) -->

In [None]:
class TextGenerator(nn.Module):
    """Sample text from the notebook-level ``model`` with top-k sampling.

    Used as a manual end-of-epoch callback:
    1. Feed the current context (right-padded with 0, or its last ``maxlen``
       tokens once it is longer) to the model.
    2. Keep the ``top_k`` largest logits at the last real position.
    3. Sample one token from their softmax, append it, and repeat.

    Reads the globals ``model``, ``tokenizer``, ``maxlen`` and ``device``.

    Arguments:
        max_tokens: Integer, the number of tokens to be generated after prompt.
        start_tokens: List of integers, the token ids of the starting prompt.
        top_k: Integer, sample from the `top_k` token predictions.
        print_every: Integer, generate/print only on every `print_every`-th epoch.
    """

    def __init__(self, max_tokens, start_tokens, top_k=10, print_every=1):
        # Run nn.Module's initialiser so this is a well-formed Module (it was skipped).
        super().__init__()
        self.max_tokens = max_tokens
        self.start_tokens = start_tokens
        self.print_every = print_every
        self.k = top_k

    def sample_from(self, logits):
        """Sample one token id from the ``top_k`` highest logits.

        Args:
            logits: 1-D tensor of vocabulary logits for one position.

        Returns:
            The sampled token id (a NumPy int32 scalar).
        """
        top_logits, indices = torch.topk(logits, k=self.k, sorted=True)
        # detach() so this also works outside torch.no_grad().
        indices = np.asarray(indices.detach().cpu()).astype("int32")
        probs = torch.softmax(top_logits.detach().cpu(), dim=0)
        # A float32 array lets np.random.choice apply a float32-sized tolerance
        # when it checks that the probabilities sum to 1.
        probs = np.asarray(probs).astype("float32")
        return np.random.choice(indices, p=probs)

    @staticmethod
    def _context_window(tokens, window):
        """Return ``(ids, sample_index)``: exactly ``window`` input ids and the
        position whose logits predict the next token."""
        if len(tokens) >= window:
            # Slide over the most recent tokens. Taking tokens[:window] (as
            # before) froze the context once the prompt exceeded the window.
            return list(tokens[-window:]), window - 1
        return list(tokens) + [0] * (window - len(tokens)), len(tokens) - 1

    def detokenize(self, number):
        """Decode a single token id with the global ``tokenizer``."""
        return tokenizer.decode(number)

    def on_epoch_end(self, epoch, logs=None):
        """Generate ``max_tokens`` tokens after the prompt and print the text.

        Args:
            epoch: 0-based epoch index; generation runs when
                ``(epoch + 1) % print_every == 0``.
            logs: Unused; kept for callback-style compatibility.

        Returns:
            ``(attention_scores, txt)`` for the last generation step, or ``None``
            when this epoch is skipped. ``attention_scores`` is ``[]`` if
            ``max_tokens`` is 0.

        Raises:
            ValueError: if ``start_tokens`` is empty.
        """
        if (epoch + 1) % self.print_every != 0:
            return
        if not self.start_tokens:
            raise ValueError("start_tokens must contain at least one token")
        context = list(self.start_tokens)
        tokens_generated = []
        attention_scores = []
        # `<` (not `<=`) so exactly max_tokens tokens are produced.
        while len(tokens_generated) < self.max_tokens:
            window, sample_index = self._context_window(context, maxlen)
            data = torch.tensor([window], dtype=torch.long, device=device)
            y, attention_scores = model(data)
            sample_token = int(self.sample_from(y[0][sample_index]))
            tokens_generated.append(sample_token)
            context.append(sample_token)
        # Each id is decoded separately and space-joined, so sub-word pieces
        # show up as separate whitespace-delimited items. NOTE(review): GPT-2
        # pieces already carry their leading space; use
        # tokenizer.decode(all_ids) instead for readable prose.
        txt = " ".join(
            self.detokenize(token) for token in self.start_tokens + tokens_generated
        )
        print(f"generated text:\n{txt}\n")
        return attention_scores, txt

# How many tokens to sample after the prompt at the end of each epoch.
num_tokens_generated = 40

# Prompt used for the per-epoch generation samples.
# Alternative IMDB-style prompt:
# "Brilliant over-acting by Lesley Ann Warren. Best dramatic hobo lady I have ever seen, and love scenes in clothes warehouse are second to none. The dramatic hobo"
start_prompt = "Mr and Mrs Dursley, of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much. They were the last people you'd expect to be involved in anything strange or mysterious, because they just didn't hold with such nonsense. Mr"

# GPT-2 token ids of the prompt.
start_tokens = tokenizer(start_prompt)['input_ids']

In [None]:
# Alias Hugging Face's Dataset: a bare `from datasets import Dataset` would
# shadow torch.utils.data.Dataset imported in the first cell.
from datasets import Dataset as HFDataset

# Keep only the padded token ids (column "id"), served as torch tensors.
train_dataset = HFDataset.from_dict({"id": dataset['train']['input_ids']}).with_format("torch")

In [None]:
# Fixed batch used for the per-epoch "test" loss.
# NOTE(review): these rows are drawn from train_dataset, so the plotted loss is
# not a held-out loss. Split the data before training if a true test loss is needed.
# The old loop checked `count > 100` after appending and so collected 102 rows;
# this draws exactly N_EVAL_SAMPLES distinct rows without iterating a DataLoader.
N_EVAL_SAMPLES = 100
eval_indices = random.sample(range(len(train_dataset)), min(N_EVAL_SAMPLES, len(train_dataset)))
TEST = [train_dataset[i]['id'] for i in eval_indices]

In [None]:
# Combine the sampled rows into one (num_samples, seq_len) tensor.
TEST = torch.stack(TEST, dim=0)

In [None]:
train_loader = DataLoader(train_dataset, batch_size=50, shuffle=True)
# Named `optimizer` so it does not shadow the `torch.optim as optim` module import.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
# ignore_index=0 drops padded targets (pad id 0) from the loss.
# NOTE(review): GPT-2 id 0 is also the token "!", which is ignored as well.
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=0)
loss_stats = {
    'test': [],
}


def lm_loss(logits, targets):
    """Next-token cross-entropy over the vocabulary.

    CrossEntropyLoss expects classes on dim 1. Passing ``(batch, seq, vocab)``
    logits directly (as before) made it softmax over sequence positions.
    Flatten to ``(batch*seq, vocab)`` with integer class targets instead; this
    also avoids building a huge one-hot tensor.
    """
    return loss_fn(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))


for epoch in tqdm(range(20)):
    model.train()
    for batch in tqdm(train_loader):
        optimizer.zero_grad()
        # Shift by one token: predict token t+1 from tokens <= t.
        input_ids = batch['id'][:, :-1].to(device)
        labels = batch['id'][:, 1:].to(device)
        outputs, attention_scores = model(input_ids)
        loss = lm_loss(outputs, labels)
        loss.backward()
        optimizer.step()

    model.eval()
    with torch.no_grad():
        TextGenerator(num_tokens_generated, start_tokens).on_epoch_end(epoch)
        test_input = TEST[:, :-1].to(device)
        test_output = TEST[:, 1:].to(device)
        outputs, attention_scores = model(test_input)
        loss_stats['test'].append(lm_loss(outputs, test_output).item())
        

In [None]:
# Long format (epochs, variable, value): one seaborn line per loss series.
test_loss_df = (
    pd.DataFrame(loss_stats)
    .rename_axis("epochs")
    .reset_index()
    .melt(id_vars="epochs")
)
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(20, 7))
sns.lineplot(data=test_loss_df, x="epochs", y="value", hue="variable", ax=axes)
axes.set_title('TestLoss')


In [None]:
import os

# os.makedirs(..., exist_ok=True) keeps the cell re-runnable; the shell
# `!mkdir` reported an error whenever the folder already existed.
SAVE_DIR = "/kaggle/working/saved_models"
os.makedirs(SAVE_DIR, exist_ok=True)
torch.save(model.state_dict(), os.path.join(SAVE_DIR, "transformer_weights.pth"))

In [None]:
with torch.no_grad():
    score, txt = TextGenerator(max_tokens=6, start_tokens=start_tokens).on_epoch_end(epoch=1)
# Attention weights of the final generation step, as a NumPy array for plotting.
score = score.cpu().numpy()

In [None]:
# Presumably (batch, heads, queries, keys) from the second attention block — confirm.
print(np.shape(score))

In [None]:
# GPT-2 token count vs. whitespace word count of the prompt. Both are shown as
# one tuple; before, only the last bare expression was displayed.
len(start_tokens), len(start_prompt.split())

In [None]:
# Show one whitespace-separated item of the generated text.
# NOTE(review): 32 is a hard-coded position, presumably chosen to line up with
# a query row in the attention heatmaps below — confirm.
txt.split()[32]

In [None]:
def plot_attention_head(score, head, batch_index=0, figsize=(100.0, 100.0)):
    """Draw one attention head as an annotated heatmap.

    Args:
        score: Array where ``score[batch_index][head]`` is a 2-D
            (queries x keys) matrix of attention weights.
        head: Index of the head to plot.
        batch_index: Batch row to plot (default 0).
        figsize: Figure size in inches; large so per-cell annotations of a
            ``maxlen x maxlen`` matrix stay legible.

    Returns:
        The matplotlib Figure.
    """
    weights = pd.DataFrame(score[batch_index][head])
    fig, ax = plt.subplots(figsize=figsize)
    sns.heatmap(weights, fmt=".3f", annot=True, linewidths=2, square=True, cmap='twilight', ax=ax)
    # Set labels after sns.heatmap: it overwrites axis labels with the (unnamed,
    # hence empty) DataFrame index/column names, which erased the old labels.
    ax.set_title(f"Attention scores (head {head})")
    ax.set_xlabel('Keys', size=maxlen)
    ax.set_ylabel('Queries', size=maxlen)
    return fig


# One figure per head (the notebook previously repeated this cell for heads 0-7).
for head in range(len(score[0])):
    plot_attention_head(score, head)
    plt.show()