# opening and reading the file

In [1]:
with open('the-verdict.txt', 'r', encoding = 'utf-8') as f:
    raw_text = f.read()

print(len(raw_text))
print(raw_text[:99])

20479
I HAD always thought Jack Gisburn rather a cheap genius--though a good fellow enough--so it was no 


# creating tokenizer

### splitting the text where white spaces

In [2]:
import re

text = "Hello World, this is a test"
result = re.split(r'([,.:;?!"()\']|--|\s)', text)            #() around r'\s' here acts as container for white spaces.

print(result)

['Hello', ' ', 'World', ',', '', ' ', 'this', ' ', 'is', ' ', 'a', ' ', 'test']


### getting rid of white spaces

In [3]:
result = [item for item in result if item.strip()]
print(result)

['Hello', 'World', ',', 'this', 'is', 'a', 'test']


In [4]:
preprocessed = re.split(r'(\s|[.,:;?!\"()\[\]{}\-])', raw_text)
preprocessed = [item.strip() for item in preprocessed if item and item.strip()] #added item only for handling any null values by chance

In [5]:
preprocessed[:5]

['I', 'HAD', 'always', 'thought', 'Jack']

# creating token IDs

### creating a set of unique words

In [6]:
all_words = sorted(set(preprocessed))
vocab_size = len(all_words)

print(all_words[400:406])

['dragged', 'drawing', 'drawn', 'dress', 'drew', 'dropped']


### assigning token id to each token i.e. unique token

In [7]:
vocab = {token:integer for integer, token in enumerate(all_words)}

In [8]:
for token, token_id in enumerate(vocab.items()):
    print(token, token_id)
    if token == 15:
        break

0 ('!', 0)
1 ('"', 1)
2 ("'", 2)
3 ("'Are", 3)
4 ("'It's", 4)
5 ("'coming'", 5)
6 ("'done'", 6)
7 ("'subject", 7)
8 ("'technique'", 8)
9 ("'way", 9)
10 ('(', 10)
11 (')', 11)
12 (',', 12)
13 ('-', 13)
14 ('.', 14)
15 (':', 15)


In [9]:
class SimpleTokenizerV1:

    def __init__(self, vocab):
        """
        it takes self, and vocab(our dictionary of tokens and token ids)
        then it creates to two variables inside the object
        that are:- str_to_int and int_to_str, what we store in them is the following.
        self.str_to_int = vocab #basically storing the token:id pair here, used for encoding
        also, 
        self.int_to_str = {i:s for s, i in vocab.items()}
        """
        self.str_to_int = vocab
        self.int_to_str = {i:s for s,i in vocab.items()}

    
    def encode(self, text):
        """
        it takes self and text as parameters, it takes the text and separates it on the basis of tokens(smallest elements)
        strips it off any white spaces.
        when tokens are created, we assign token ids to the tokens
        we use re.split() for separating the elements / creating tokens
        we use re.strip() to remove white spaces from the text
        to assign token value we use dictionary, for which we can use str_to_int variable.
        we return the ids.
        """
        preprocessed = re.findall(r'<\|[^|]+?\|>|\w+|[^\w\s]', text)

        preprocessed = [
            item.strip() for item in preprocessed if item.strip()
        ]

        preprocessed = [
            item if item in self.str_to_int         #new code that checks if any word is from outside the vocab and assigns "<|unk|>"
            else "<|unk|>" for item in preprocessed
        ]
        
        ids = [self.str_to_int[s] for s in preprocessed]
        return ids

    def decode(self, ids):
        """
        we have encoded i.e. tokenized data, numerical data or token id.
        we need to create words back from it simply.
        for that we need int_to_str variable.
        also we need " ".join() to join them along, because they were split.
        where are the words -> [self.int_to_str[i] for i in ids]
        we used re.split() method to separate them, now we need   re.sub() method to join them.
        """
        text = " ".join([self.int_to_str[i] for i in ids])
        text = re.sub(r'\s+([.,:;?!\"()\[\]{}\-])', r'\1', text)
        return text


In [10]:
tokenizer = SimpleTokenizerV1(vocab)

text = "done subject"
ids = tokenizer.encode(text)
print(ids)

[396, 985]


In [11]:
tokenizer = SimpleTokenizerV1(vocab)

ids = [396, 985]
text = tokenizer.decode(ids)
print(text)

done subject


# New Tokens- Special Context Tokens

In [12]:
print(len(all_words))
all_words.extend(["<|endoftext|>", "<|unk|>"])
vocab = {token : token_id for token_id, token in enumerate(all_words)}

1172


In [13]:
len(all_words)

1174

In [14]:
print(list(vocab.items())[-2])
print(list(vocab.items())[-1])

('<|endoftext|>', 1172)
('<|unk|>', 1173)


In [15]:
tokenizer = SimpleTokenizerV1(vocab)
text1 = "do you like tea GurSanjjam?"
text2 = "do you like tea?"
com_text = "<|endoftext|>".join((text1, text2))
print(com_text)
ids = tokenizer.encode(com_text)
print(ids)

do you like tea GurSanjjam?<|endoftext|>do you like tea?
[391, 1166, 663, 1013, 1173, 17, 1172, 391, 1166, 663, 1013, 17]


In [16]:
text = tokenizer.decode(ids)

In [17]:
text

'do you like tea <|unk|>? <|endoftext|> do you like tea?'

In [18]:
import tiktoken
tokenizer = tiktoken.get_encoding('gpt2')

# Creating input target pairs

In [19]:
with open('the-verdict.txt', 'r', encoding = 'utf-8') as f:
    raw_text = f.read()

enc_text = tokenizer.encode(raw_text)
print(len(enc_text))

5145


In [20]:
context_size = 4
x = enc_text[:context_size]
y = enc_text[1:context_size+1]

print(f"x: {x}")
print(f"y:     {y}")

x: [40, 367, 2885, 1464]
y:     [367, 2885, 1464, 1807]


In [21]:
for i in range(1,context_size+1):
    context = enc_text[:i]
    desired = enc_text[i]

    print(f"{tokenizer.decode(context)} ----> {tokenizer.decode([desired])}")

I ---->  H
I H ----> AD
I HAD ---->  always
I HAD always ---->  thought


# Implementing a data loader

In [22]:
from torch.utils.data import Dataset, DataLoader

In [23]:
class GPTDatasetV1(Dataset):
    def __init__(self, text, tokenizer, context_size, stride):
        self.input_ids = []
        self.target_ids = []

        token_ids = tokenizer.encode(text, allowed_special ={"<|endoftext|>"})

        for i in range(0, len(token_ids) - context_size, stride):
            input_chunk = token_ids[i: i+context_size]
            target_chunk = token_ids[i+1: i+context_size+1]
            self.input_ids.append(torch.tensor(input_chunk))
            self.target_ids.append(torch.tensor(target_chunk))

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return self.input_ids[idx], self.target_ids[idx]

In [24]:
def create_dataloader_v1(text, batch_size = 4, context_size = 256, 
                         stride = 128, shuffle = True, drop_last = True, 
                         num_workers = 0):

    tokenizer = tiktoken.get_encoding('gpt2')

    dataset = GPTDatasetV1(text, tokenizer, context_size, stride)

    #create dataloader
    dataloader = DataLoader(
        dataset,
        batch_size = batch_size,
        shuffle = shuffle,
        drop_last = drop_last,
        num_workers = num_workers
    )

    return dataloader

In [25]:
with open('the-verdict.txt', 'r', encoding = 'utf-8') as f:
    raw_text = f.read()

In [26]:
import torch
dataloader = create_dataloader_v1(
    raw_text, batch_size = 1, context_size = 4, stride = 1, shuffle = False
)

data_iter = iter(dataloader)
first_batch = next(data_iter)
print(first_batch)

[tensor([[  40,  367, 2885, 1464]]), tensor([[ 367, 2885, 1464, 1807]])]


# Implementing a Positional Encoder

### we are assuming that the input is the output of the bpe or byte pair encoder.

In [36]:
import torch.nn as nn

In [37]:
vocab_size = 1172
output_dim = 256

token_embedding_layer = nn.Embedding(vocab_size, output_dim)

# simple self attention mechanism

In [38]:
import torch

inputs =  torch.tensor([
    [0.43, 0.15, 0.89], #Your x1
    [0.55, 0.87, 0.66], #journey x2
    [0.57, 0.85, 0.64], #starts x3
    [0.22, 0.58, 0.33], #with x4
    [0.77, 0.25, 0.10], #one x5
    [0.05, 0.80, 0.55]  #step x6
])

In [39]:
'''query = inputs[1] #2nd input token is the query

attn_scores_2 = torch.empty(inputs.shape[0])
for i, x_i in enumerate(inputs):
    attn_scores_2[i] = torch.dot(x_i, query) #dot product

print(attn_scores_2)'''

'query = inputs[1] #2nd input token is the query\n\nattn_scores_2 = torch.empty(inputs.shape[0])\nfor i, x_i in enumerate(inputs):\n    attn_scores_2[i] = torch.dot(x_i, query) #dot product\n\nprint(attn_scores_2)'

In [40]:
# normalisation
'''
attn_weights_2 = torch.softmax(attn_scores_2, dim = 0)
print("Attention Weights: ", attn_weights_2)'''

'\nattn_weights_2 = torch.softmax(attn_scores_2, dim = 0)\nprint("Attention Weights: ", attn_weights_2)'

In [41]:
# context vector

'''context_vec_2 = torch.zeros(query.shape)
for i, x_i in enumerate(inputs):
    context_vec_2 += attn_weights_2[i]*x_i

print(context_vec_2)'''

'context_vec_2 = torch.zeros(query.shape)\nfor i, x_i in enumerate(inputs):\n    context_vec_2 += attn_weights_2[i]*x_i\n\nprint(context_vec_2)'

In [42]:
attn_scores = inputs @ inputs.T    
attn_weights = torch.softmax(attn_scores, dim =-1) #dim -1 means previous dimension i.e. dim of attn_scores
context_vector = attn_weights @ inputs
print(context_vector)

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])


### Self Attention with trainable weights

In [74]:
import torch

inputs =  torch.tensor([
    [0.43, 0.15, 0.89], #Your x1
    [0.55, 0.87, 0.66], #journey x2
    [0.57, 0.85, 0.64], #starts x3
    [0.22, 0.58, 0.33], #with x4
    [0.77, 0.25, 0.10], #one x5
    [0.05, 0.80, 0.55]  #step x6
])

In [75]:
x_2 = inputs[1]
d_in = inputs.shape[1]
d_out = 2

In [76]:
#assigning random values to each parameter as a start
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad = False)
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad = False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad = False)

In [77]:
print(W_query)

Parameter containing:
tensor([[0.4411, 0.6821],
        [0.1414, 0.6950],
        [0.1004, 0.3306]])


In [78]:
#creating Q,K,V matrices
queries = inputs @ W_query
key = inputs @ W_key
value = inputs @ W_value

In [79]:
print(queries)

tensor([[0.3003, 0.6917],
        [0.4319, 1.1979],
        [0.4359, 1.1911],
        [0.2122, 0.6622],
        [0.3851, 0.7320],
        [0.1904, 0.7719]])


In [80]:
# since we now have the matrices, we can calculate attention scores for each query
# for which we dot product the eg. query[1] @ key matrix
attn_score = queries @ key.T
print(attn_score)

tensor([[0.7198, 0.9995, 0.9943, 0.5280, 0.6225, 0.6148],
        [1.2187, 1.6865, 1.6783, 0.8890, 1.0586, 1.0314],
        [1.2138, 1.6801, 1.6719, 0.8858, 1.0539, 1.0279],
        [0.6653, 0.9189, 0.9146, 0.4838, 0.5793, 0.5601],
        [0.7828, 1.0915, 1.0856, 0.5781, 0.6735, 0.6761],
        [0.7576, 1.0424, 1.0378, 0.5475, 0.6627, 0.6313]])


In [81]:
#we now calc attn_weights by appyling softmax on attn_score scaled down to underoot(dim_keys)
dim_keys = key.shape[-1]
attn_weights = torch.softmax(attn_score/dim_keys**0.5, dim = -1)
print(attn_weights, dim_keys)

tensor([[0.1621, 0.1976, 0.1969, 0.1416, 0.1513, 0.1505],
        [0.1578, 0.2197, 0.2184, 0.1250, 0.1409, 0.1382],
        [0.1578, 0.2195, 0.2182, 0.1252, 0.1410, 0.1384],
        [0.1629, 0.1949, 0.1943, 0.1433, 0.1533, 0.1512],
        [0.1613, 0.2006, 0.1998, 0.1395, 0.1493, 0.1495],
        [0.1625, 0.1988, 0.1981, 0.1401, 0.1520, 0.1486]]) 2


# Implementing a simple python class for self attention 

In [82]:
import torch

inputs =  torch.tensor([
    [0.43, 0.15, 0.89], #Your x1
    [0.55, 0.87, 0.66], #journey x2
    [0.57, 0.85, 0.64], #starts x3
    [0.22, 0.58, 0.33], #with x4
    [0.77, 0.25, 0.10], #one x5
    [0.05, 0.80, 0.55]  #step x6
])

In [83]:
import torch.nn as nn

class SelfAttentionV2(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out)
        self.W_key = nn.Linear(d_in, d_out)
        self.W_value = nn.Linear(d_in, d_out)

    def forward(self, x):
        query = self.W_query(x)
        key = self.W_key(x)
        value = self.W_value(x)

        attn_score = query @ key.transpose(-2,-1)
        attn_weight = torch.softmax(attn_score / key.shape[-1]**0.5, dim = -1)

        context_vec = attn_weight @ value
        return context_vec

In [84]:
sa_v2 = SelfAttentionV2(3,6)
print(sa_v2(inputs))

tensor([[ 0.3684,  0.0559, -0.1986,  0.1970, -0.0378, -0.5136],
        [ 0.3655,  0.0521, -0.1989,  0.1917, -0.0344, -0.5117],
        [ 0.3654,  0.0520, -0.1989,  0.1916, -0.0344, -0.5118],
        [ 0.3661,  0.0530, -0.2030,  0.1943, -0.0347, -0.5167],
        [ 0.3643,  0.0508, -0.2010,  0.1917, -0.0342, -0.5163],
        [ 0.3668,  0.0540, -0.2028,  0.1949, -0.0350, -0.5154]],
       grad_fn=<MmBackward0>)


# hiding future words with masked attention

In [85]:
context_length = attn_score.shape[0]
lower_triangular = torch.tril(torch.ones(context_length, context_length))

In [86]:
print(lower_triangular)

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


In [87]:
masked_simple = attn_weights * mask_simple
print(masked_simple)

tensor([[0.1621, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1578, 0.2197, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1578, 0.2195, 0.2182, 0.0000, 0.0000, 0.0000],
        [0.1629, 0.1949, 0.1943, 0.1433, 0.0000, 0.0000],
        [0.1613, 0.2006, 0.1998, 0.1395, 0.1493, 0.0000],
        [0.1625, 0.1988, 0.1981, 0.1401, 0.1520, 0.1486]])


In [88]:
row_sum = masked_simple.sum(dim = 1, keepdim = True)
print(row_sum)

tensor([[0.1621],
        [0.3775],
        [0.5955],
        [0.6955],
        [0.8505],
        [1.0000]])


In [89]:
normalised_masked = masked_simple / row_sum
print(normalised_masked)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4181, 0.5819, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2650, 0.3685, 0.3664, 0.0000, 0.0000, 0.0000],
        [0.2343, 0.2803, 0.2794, 0.2060, 0.0000, 0.0000],
        [0.1896, 0.2359, 0.2349, 0.1641, 0.1755, 0.0000],
        [0.1625, 0.1988, 0.1981, 0.1401, 0.1520, 0.1486]])


## creating a better matrix, preventing data leakage

In [96]:
upper_triangular = torch.triu(torch.ones(context_length, context_length), diagonal = 1) #diagonal shifted by 1
masked = attn_scores.masked_fill(upper_triangular.bool(), -torch.inf)

In [97]:
print(masked)

tensor([[0.9995,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.9544, 1.4950,   -inf,   -inf,   -inf,   -inf],
        [0.9422, 1.4754, 1.4570,   -inf,   -inf,   -inf],
        [0.4753, 0.8434, 0.8296, 0.4937,   -inf,   -inf],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654,   -inf],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


In [98]:
attn_weights = torch.softmax(masked / key.shape[-1]**0.5, dim = 1)

In [100]:
print(attn_weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4056, 0.5944, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2566, 0.3741, 0.3693, 0.0000, 0.0000, 0.0000],
        [0.2176, 0.2823, 0.2796, 0.2205, 0.0000, 0.0000],
        [0.1826, 0.2178, 0.2191, 0.1689, 0.2115, 0.0000],
        [0.1473, 0.2033, 0.1996, 0.1500, 0.1160, 0.1839]])


## CAUSAL ATTENTION CLASS

In [18]:
import torch.nn as nn
class CausalAttention(nn.Module):

    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias = False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias = qkv_bias)   #takes input_dimension -> converts Linear matrix into out_dimension
        self.W_key = nn.Linear(d_in, d_out, bias = qkv_bias)     #here bias is qkv_bias but qkv_bias is set to false
        self.W_value = nn.Linear(d_in, d_out, bias = qkv_bias)
        self.dropout = nn.Dropout(dropout)                       #dropout parameter value decides what % of hidden nodes
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal = -1))

    def forward(self, x):
        batch_size, num_tokens, input_dim = x.shape

        #creating weighted matrices of Q,K,V
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        #calculating attention scores
        attention_scores = queries @ keys.transpose(1,2)     #we dont want to transpose the batch size
        attention_scores.masked_fill_( 
            self.mask[:num_tokens, :num_tokens].bool(), -torch.inf    #add !(∞)
        )

        #calculating attention weights
        attention_weights = torch.softmax(
            attention_scores / keys.shape[-1]**0.5, dim = -1    #shape[-1] is input_dimension, dim -1 means->shape it as last shape of mat
        )
        attention_weights = self.dropout(attention_weights)

        #context vector
        context_vector = attention_weights @ values
        return context_vector

In [19]:
batch = torch.stack((inputs, inputs), dim = 0)

In [20]:
context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, 0)
context_vecs = ca(batch)
print(context_vecs)

tensor([[[    nan,     nan],
         [    nan,     nan],
         [-0.4519,  0.2216],
         [-0.5857,  0.0086],
         [-0.6281, -0.0602],
         [-0.5678, -0.0850]],

        [[    nan,     nan],
         [    nan,     nan],
         [-0.4519,  0.2216],
         [-0.5857,  0.0086],
         [-0.6281, -0.0602],
         [-0.5678, -0.0850]]], grad_fn=<UnsafeViewBackward0>)


# linear multi head attention ineffiecient technique

In [44]:
import torch
import torch.nn as nn
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias = False):
        super().__init__()
        self.heads = nn.ModuleList(
            [CausalAttention(d_in, d_out, context_length, dropout, qkv_bias = False) 
            for _ in range(num_heads)]   #creating as many instances as the number of heads
        )

    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim = -1)    #concatenating the heads along the columns

In [45]:
inputs =  torch.tensor([
    [0.43, 0.15, 0.89], #Your x1
    [0.55, 0.87, 0.66], #journey x2
    [0.57, 0.85, 0.64], #starts x3
    [0.22, 0.58, 0.33], #with x4
    [0.77, 0.25, 0.10], #one x5
    [0.05, 0.80, 0.55]  #step x6
])
batch = torch.stack((inputs, inputs), dim = 0)
print(batch)

tensor([[[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]],

        [[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]]])


In [46]:
torch.manual_seed(123)
context_length = batch.shape[1]
d_in, d_out = 3 , 2
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads = 2)
context_vec = mha(batch)
print(context_vec)
print("context vector shape = ", context_vec.shape)

tensor([[[    nan,     nan,     nan,     nan],
         [    nan,     nan,     nan,     nan],
         [-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5857,  0.0086,  0.5836,  0.3149],
         [-0.6281, -0.0602,  0.6129,  0.3716],
         [-0.5678, -0.0850,  0.5487,  0.3626]],

        [[    nan,     nan,     nan,     nan],
         [    nan,     nan,     nan,     nan],
         [-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5857,  0.0086,  0.5836,  0.3149],
         [-0.6281, -0.0602,  0.6129,  0.3716],
         [-0.5678, -0.0850,  0.5487,  0.3626]]], grad_fn=<CatBackward0>)
context vector shape =  torch.Size([2, 6, 4])


In [47]:
torch.manual_seed(123)
context_length = batch.shape[1]
d_in, d_out = 3 , 2
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads = 3)
context_vec = mha(batch)
print(context_vec)
print("context vector shape = ", context_vec.shape)

tensor([[[    nan,     nan,     nan,     nan,     nan,     nan],
         [    nan,     nan,     nan,     nan,     nan,     nan],
         [-0.4519,  0.2216,  0.4772,  0.1063,  0.4566,  0.2729],
         [-0.5857,  0.0086,  0.5836,  0.3149,  0.5852,  0.3024],
         [-0.6281, -0.0602,  0.6129,  0.3716,  0.6278,  0.3108],
         [-0.5678, -0.0850,  0.5487,  0.3626,  0.5687,  0.2784]],

        [[    nan,     nan,     nan,     nan,     nan,     nan],
         [    nan,     nan,     nan,     nan,     nan,     nan],
         [-0.4519,  0.2216,  0.4772,  0.1063,  0.4566,  0.2729],
         [-0.5857,  0.0086,  0.5836,  0.3149,  0.5852,  0.3024],
         [-0.6281, -0.0602,  0.6129,  0.3716,  0.6278,  0.3108],
         [-0.5678, -0.0850,  0.5487,  0.3626,  0.5687,  0.2784]]],
       grad_fn=<CatBackward0>)
context vector shape =  torch.Size([2, 6, 6])


# creating a multi head attention class and causal attention into a single class

In [50]:
# reshapes the projected tensors and combines the result after computing attention
import torch
import torch.nn as nn

In [None]:
class MultiHeadAttention(nn.Module):
    def __init__(d_in, d_out, context_length, dropout, num_heads, bias = False):
        super().__init__()
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out / num_heads   #reduce projection dim to match the desired output_dim


        self.W_query = nn.Linear(d_in, d_out, bias = qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias = qkv_bias)
        self.W_vaue = nn.Linear(d_in, d_out, bias = qkv_bias)
    
    def forward(self, x):
        b, num_tokens, d_in = x.shape

        keys = self.W_keys(x)
        queries = self.W_queries(x)
        values = self.W_values(x)

        keys = keys.view(b, num_tokens, num_heads, head_dim)
        queries = keys.queries(b, num_tokens, num_heads, head_dim)
        values = keys.values(b, num_tokens, num_heads, head_dim)

        keys = keys.transpose(1,2)
        queries = queries.transpose(1,2)
        values = values.tranpose(1,2)