In [24]:
import re

In [25]:
with open("txt_data.txt") as f:
    text_corpus=f.read();

In [26]:
preprocessed=re.split(r'([,.?!"()\']|--|\s)',text_corpus)

In [27]:
preprocessed = [item for item in preprocessed if item.strip()]



In [28]:
allwords = sorted(list(set(preprocessed)))
allwords.extend(["<|endoftext|>","<|unk|>"])
vocabsize = len(allwords)
print(vocabsize)

6179


In [29]:
vocabulary={token:idx for idx,token in enumerate(allwords)}

In [30]:
class SimpleTokenizer:
    def __init__(self, vocab):
        self.str_to_int=vocab;
        self.int_to_str={i:s for s,i in vocab.items()}
    def encode(self,text):
        preprocessed=re.split(r'([,.?!"()\']|--|\s)',text)
        preprocessed = [item for item in preprocessed if item.strip()]
        preprocessed = [item if item in self.str_to_int.keys() else "<|unk|>" for item in preprocessed]
        ids=[self.str_to_int[x] for x in preprocessed]
        return ids;
    def decode(self,ids):

        text=" ".join(self.int_to_str[i] for i in ids)

        text = re.sub(r'\s+([,.?!"()\'])',r'\1', text)
        return text


In [31]:

tokenizer = SimpleTokenizer(vocabulary)


In [32]:
sample_text="A goal welcome in the mountainside";

In [33]:
tokenizer.encode(sample_text)

[1130, 3588, 6030, 3789, 5679, 6178]

In [34]:
tokenizer.decode(tokenizer.encode(sample_text))

'A goal welcome in the <|unk|>'

In [35]:
#use better encoder

In [36]:
from importlib.metadata import version
import tiktoken

In [37]:
print(version("tiktoken"))

0.8.0


In [38]:
# tokenizer = tiktoken.get_encoding("gpt2")

In [39]:
# enc_text=tiktoken.encode(text_corpus)


In [40]:
enc_text=tokenizer.encode(text_corpus)

In [41]:
sample_enc=enc_text[550:]

In [42]:
context_size = 4

In [43]:
x,y=sample_enc[:context_size],sample_enc[1:context_size+1]

In [44]:
print(f"x: {x}")
print(f"y:        {y}")

x: [2677, 5860, 2244, 3568]
y:        [5860, 2244, 3568, 2114]


In [45]:
for i in range(1, context_size+1):
    context = sample_enc[:i]
    desired = sample_enc[i]
    print(tokenizer.decode(context),"---->", tokenizer.decode([desired]))

complex ----> understanding
complex understanding ----> and
complex understanding and ----> generation
complex understanding and generation ----> abilities


In [46]:
# creating a dataset

In [47]:
import torch
from torch.utils.data import Dataset, DataLoader

class GPTDataset(Dataset):
    def __init__(self,txt, tokenizer, max_length,stride):
        self.tokenizer=tokenizer;
        self.input_ids=[]
        self.target_ids=[]
        token_ids=self.tokenizer.encode(txt);

        for i in range(0,len(token_ids)-max_length,stride):
            X_chunk=token_ids[i:i+max_length];
            y_chunk=token_ids[i+1:i+max_length+1];

            self.input_ids.append(torch.tensor((X_chunk)))
            self.target_ids.append(torch.tensor((y_chunk)))

            
            
    def __len__(self):
        return len(self.input_ids)
    def __getitem__(self,idx):
        return self.input_ids[idx],self.target_ids[idx];
        

In [48]:
dataset= GPTDataset(text_corpus,tokenizer,5,1)

In [49]:
import torch
from torch.utils.data import Dataset, DataLoader

In [50]:
def create_dataloader(txt,vocab,tokenizer_class,batch_size=4,max_length=256,stride=128,shuffle=True,drop_last=True,num_workers=0):
    tokenizer=tokenizer_class(vocab);
    dataset=GPTDataset(txt,tokenizer,max_length,stride)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=num_workers
    )
    return dataloader;


In [61]:
data_loader=create_dataloader(text_corpus,vocabulary,SimpleTokenizer)

In [66]:
dataloader = create_dataloader(
    text_corpus, vocabulary, SimpleTokenizer,batch_size=4, max_length=4, stride=16, shuffle=False
)



In [67]:
output_dim=256
token_embedding_layer=torch.nn.Embedding(len(vocabulary),output_dim)

In [68]:
token_embedding_layer

Embedding(6177, 256)

In [69]:
data_iter=iter(dataloader)
input,target=next(data_iter)

In [70]:
token_embd=token_embedding_layer(input)

In [71]:
context_length=4;
output_dim=256
pos_embedding_layer=torch.nn.Embedding(context_length,output_dim)

In [72]:
output_pos=pos_embedding_layer(torch.arange(context_length))

In [73]:
(token_embd+output_pos).shape

torch.Size([4, 4, 256])

In [207]:
torch.manual_seed(123)

<torch._C.Generator at 0x107db5470>

In [212]:
input= torch.rand(6,6)

In [213]:
input

tensor([[0.0772, 0.3565, 0.1479, 0.5331, 0.4066, 0.2318],
        [0.4545, 0.9737, 0.4606, 0.5159, 0.4220, 0.5786],
        [0.9455, 0.8057, 0.6775, 0.6087, 0.6179, 0.6932],
        [0.4354, 0.0353, 0.1908, 0.9268, 0.5299, 0.0950],
        [0.5789, 0.9131, 0.0275, 0.1634, 0.3009, 0.5201],
        [0.3834, 0.4451, 0.0126, 0.7341, 0.9389, 0.8056]])

In [129]:
query=input[0]

In [126]:
for x,y in enumerate(input):
    print(y@query)

tensor(2.2392)
tensor(1.4211)
tensor(0.9672)
tensor(1.6398)
tensor(1.9152)
tensor(0.8563)
tensor(1.2123)
tensor(1.3347)
tensor(1.2462)
tensor(1.2325)


In [109]:
# query=query.unsqueeze(0)

In [130]:
query.shape

torch.Size([6])

In [131]:
input.shape

torch.Size([10, 6])

In [132]:
dot=query @ input.T

In [133]:
dot

tensor([2.2392, 1.4211, 0.9672, 1.6398, 1.9152, 0.8563, 1.2123, 1.3347, 1.2462,
        1.2325])

In [134]:
norm=torch.softmax(dot,dim=0)

In [135]:
norm.sum()

tensor(1.)

In [76]:
pair_wise_attn= input @ input.T

In [97]:
norm_attn=torch.softmax(pair_wise_attn,dim=1)

In [99]:
norm_attn.sum(dim=1)

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000])

In [151]:
wk=nn.Parameter(torch.rand(6,3),requires_grad=True);
wq=nn.Parameter(torch.rand(6,3),requires_grad=True);
wv=nn.Parameter(torch.rand(6,3),requires_grad=True);
    

In [155]:
x=input[1]

In [158]:
x

tensor([0.8089, 0.5460, 0.8383, 0.5359, 0.1298, 0.7184])

In [162]:
wk

Parameter containing:
tensor([[0.3632, 0.1988, 0.5198],
        [0.9128, 0.9011, 0.5084],
        [0.9404, 0.5893, 0.9334],
        [0.4826, 0.9069, 0.3541],
        [0.6769, 0.4524, 0.4210],
        [0.2534, 0.0338, 0.2954]], requires_grad=True)

In [174]:
(wk[:,0]*x).sum()

tensor(2.1092, grad_fn=<SumBackward0>)

In [172]:
k=x @ wk

In [175]:
query= x@ wq
print(query)

tensor([2.1875, 2.0642, 1.4362], grad_fn=<SqueezeBackward4>)


In [176]:
keys= input @ wk
queries= input @ wq
values= input @ wv

In [178]:
keys[0]

tensor([1.7492, 1.3924, 1.3594], grad_fn=<SelectBackward0>)

In [185]:
k2=keys[1]
q2=queries[1]
attn_kq= query.dot(k2)
print(attn_kq)

tensor(10.9380, grad_fn=<DotBackward0>)


In [190]:
attn_scors= q2 @ keys.T

In [191]:
attn_scors

tensor([ 8.6529, 10.9380, 10.8418, 12.5827, 11.4919,  8.3091,  7.9694,  7.6344,
         7.8425,  5.9916], grad_fn=<SqueezeBackward4>)

In [229]:
import torch.nn as nn
class AttentionBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.context_len=6;
        self.wk=nn.Linear(6,4);
        self.wq=nn.Linear(6,4);
        self.wv=nn.Linear(6,4);
        self.dropout=torch.nn.dropout(.5)
    
    def forward(self,x):
        q= self.wq(x);
        k=self.wk(x);
        v= self.wv(x);
        attn_score= q @ k.T
        
        attn_weights= torch.softmax(attn_score/k.shape[-1]**0.5,dim=-1)
        context_vec= attn_weights @ v;
        return context_vec
        

In [230]:
torch.manual_seed(123)

<torch._C.Generator at 0x107db5470>

In [231]:
attn_block=AttentionBlock()


In [232]:
output=attn_block(input)

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


In [236]:
batch=torch.stack([input,input],dim=0)
print(batch)

tensor([[[0.0772, 0.3565, 0.1479, 0.5331, 0.4066, 0.2318],
         [0.4545, 0.9737, 0.4606, 0.5159, 0.4220, 0.5786],
         [0.9455, 0.8057, 0.6775, 0.6087, 0.6179, 0.6932],
         [0.4354, 0.0353, 0.1908, 0.9268, 0.5299, 0.0950],
         [0.5789, 0.9131, 0.0275, 0.1634, 0.3009, 0.5201],
         [0.3834, 0.4451, 0.0126, 0.7341, 0.9389, 0.8056]],

        [[0.0772, 0.3565, 0.1479, 0.5331, 0.4066, 0.2318],
         [0.4545, 0.9737, 0.4606, 0.5159, 0.4220, 0.5786],
         [0.9455, 0.8057, 0.6775, 0.6087, 0.6179, 0.6932],
         [0.4354, 0.0353, 0.1908, 0.9268, 0.5299, 0.0950],
         [0.5789, 0.9131, 0.0275, 0.1634, 0.3009, 0.5201],
         [0.3834, 0.4451, 0.0126, 0.7341, 0.9389, 0.8056]]])


In [254]:
import torch.nn as nn
class MaskedAttentionBlock(nn.Module):
    def __init__(self, context_length):
        super().__init__()
        self.context_len=6;
        self.wk=nn.Linear(6,4);
        self.wq=nn.Linear(6,4);
        self.wv=nn.Linear(6,4);
        self.dropout=torch.nn.Dropout(.5)
        self.register_buffer('mask',torch.tril(torch.ones(context_length,context_length),diagonal=1))
    
    def forward(self,x):
        b,num_tokens,dim=x.shape                   
        q= self.wq(x);
        k=self.wk(x);
        v= self.wv(x);
        attn_scores= q @ k.transpose(1,2)
        attn_scores.masked_fill(self.mask.bool()[:num_tokens,:num_tokens], -torch.inf)
              
        attn_weights= torch.softmax(attn_scores/k.shape[-1]**0.5,dim=-1)
        attn_weights=self.dropout(attn_weights)     

        context_vec= attn_weights @ v;
        return context_vec
        

In [255]:
block=MaskedAttentionBlock(6)

In [280]:
class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self,context_length,heads):
        super().__init__();
        self.attention_heads=nn.ModuleList([MaskedAttentionBlock(context_length) for _ in range(heads)])
    def forward(self, x):
        return torch.cat([head(x) for head in self.attention_heads], dim=-1)

In [281]:
mthb=MultiHeadAttentionWrapper(6,2)

In [284]:
mthb(batch)

tensor([[[ 0.5483,  0.2711, -0.1395, -0.0766, -0.0293, -0.6766,  0.1983,
           0.0731],
         [ 0.7867,  0.3218, -0.1234, -0.0414,  0.1032, -0.4606,  0.2253,
           0.0929],
         [ 0.3118,  0.1745, -0.0995, -0.0499, -0.0153, -0.4542,  0.1040,
          -0.0485],
         [ 0.8555,  0.5030, -0.2244, -0.1346,  0.0665, -0.9245,  0.2987,
           0.1044],
         [ 0.5436,  0.3309, -0.1291, -0.0862,  0.0460, -1.0082,  0.3657,
           0.2002],
         [ 1.0914,  0.5998, -0.2678, -0.1622, -0.0172, -0.2138,  0.0924,
           0.1286]],

        [[ 0.9291,  0.3749, -0.0282, -0.0925,  0.0209, -1.1143,  0.3571,
           0.1531],
         [ 1.0296,  0.5232, -0.2772, -0.1126,  0.0343, -0.7808,  0.2338,
           0.1339],
         [ 0.8469,  0.5052, -0.2354, -0.1378, -0.0997, -0.5310,  0.1521,
           0.1548],
         [ 1.1542,  0.4226, -0.0142, -0.0581,  0.0658, -0.6776,  0.2974,
           0.1476],
         [ 0.5436,  0.3309, -0.1291, -0.0862,  0.0778, -0.5822,  0.1