In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

## Character-level tokens based on ASCII codes

In [6]:
def encoder(input):
    """Turn a string into a list of integer tokens, one per character.

    Each token is the character's code point as returned by ``ord``; for
    ASCII text this is the ASCII code.

    NOTE: the parameter is named ``input`` (shadowing the builtin inside the
    function only) so that keyword calls ``encoder(input=...)`` keep working.
    """
    return [ord(char) for char in input]


def decoder(input):
    """Turn an iterable of integer tokens back into a string.

    This is the inverse of ``encoder``: every token is mapped through ``chr``
    and the resulting characters are concatenated.
    """
    return ''.join(chr(char) for char in input)


# Round-trip a short sample to show that decoding restores the original text.
sample_input = "hi there!"
token = encoder(sample_input)
print(f"input         : {sample_input}\n"
      f"token         : {token}\n"
      f"decoded_token : {decoder(token)}")

input         : hi there!
token         : [104, 105, 32, 116, 104, 101, 114, 101, 33]
decoded_token : hi there!


## reading the input and tokenizing

In [7]:
# Read the whole corpus into memory. The encoding is passed explicitly so the
# decoded characters (and therefore the token ids) do not depend on the
# platform's default locale.
with open("data.txt", "r", encoding="utf-8") as f:
    data = f.read()

# Convert every character into a token and store all tokens in one 1-D tensor.
token_data = torch.tensor(encoder(data))

# Split the data into train and val: the first 90% of the tokens are used for
# training, the remainder for validation.
train_len = int(0.9 * len(token_data))
train_data = token_data[:train_len]
val_data = token_data[train_len:]

print(f"train data length : {train_data.numel()}")
print(f"val data length : {val_data.numel()}")

train data length : 1003853
val data length : 111540


## Create the input and target, with the batch dim

In [8]:
torch.manual_seed(1234)
def get_batch(batch, block_len, data):
    """Sample `batch` random windows of `block_len` consecutive items from `data`.

    Returns a pair ``(input, target)`` of stacked windows. Each target window
    starts one position after its input window, so for 1-D `data`
    ``target[:, j]`` is the item that follows ``input[:, j]``.
    """
    # Draw all window starts in a single call. The exclusive upper bound keeps
    # the one-step-shifted target window inside `data`.
    starts = torch.randint(0, len(data) - block_len, (batch,)).tolist()

    def windows(shift):
        # Stack one window per start, each moved `shift` positions to the right.
        return torch.stack([data[s + shift:s + shift + block_len] for s in starts])

    return windows(0), windows(1)

In [9]:
# Draw one small batch from the training split and show how each target row is
# its input row moved one token ahead.
batch, block_len = 4, 8
ipt, tgt = get_batch(batch, block_len, train_data)
print(ipt.shape)
print(f"input : \n{ipt.numpy()}\noutput: \n{tgt.numpy()}")

torch.Size([4, 8])
input : 
[[108 100  32 109  97 114 114 105]
 [101  32 109 101 110  32  99  97]
 [ 65 110  32 121 101 116  44  32]
 [116  39 115 121  32 119 111 114]]
output: 
[[100  32 109  97 114 114 105  97]
 [ 32 109 101 110  32  99  97 110]
 [110  32 121 101 116  44  32 102]
 [ 39 115 121  32 119 111 114 116]]


In [None]:
class Embedding(nn.Module):
    """Token embedding plus a learned per-position embedding.

    The token table has shape (vocab_len, embedding_dim). The positional
    parameter is stored as (embedding_dim, context_len), i.e. one column per
    position.
    """

    def __init__(self, embedding_dim=64, vocab_len=128, context_len=50):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_len, embedding_dim)  # (vocab, c)
        positional_init = torch.randn((embedding_dim, context_len))  # (c, context_len)
        self.time_embedding = nn.Parameter(positional_init)

    def forward(self, x):
        """Map token ids of shape (b, t) to embeddings of shape (b, t, c)."""
        n_seq, n_pos = x.shape
        positions = torch.arange(n_pos, device=x.device)
        # Pick the columns for positions 0..t-1, then put time in front of channel.
        pos_embed = self.time_embedding[:, positions].transpose(0, 1)  # (t, c)
        pos_embed = pos_embed.unsqueeze(0).expand(n_seq, -1, -1)  # (b, t, c)
        tok_embed = self.token_embedding(x)  # (b, t, c)
        return pos_embed + tok_embed  # (b, t, c)
    
class Head(nn.Module):
    """A single scaled dot-product self-attention head.

    With ``mode="encoder"`` every position attends to every position. With any
    other mode a lower-triangular mask restricts each position to itself and
    the positions before it.
    """

    def __init__(self, context_len=50, embedding_dim=64, attention_dim=8, mode="encoder"):
        super().__init__()

        def projection():
            return nn.Linear(embedding_dim, attention_dim, bias=False)

        # Layers are created in the order query, value, key; the order decides
        # which random numbers initialise which layer.
        self.query = projection()
        self.value = projection()
        self.key = projection()
        self.mode = mode
        lower_triangle = torch.tril(torch.ones(context_len, context_len))
        self.register_buffer('tril', lower_triangle)

    def forward(self, x):
        """Attend over x of shape (b, t, c) and return a tensor of shape (b, t, a)."""
        _, seq_len, _ = x.shape
        queries, keys, values = self.query(x), self.key(x), self.value(x)  # each (b, t, a)

        # (b, t, a) @ (b, a, t) -> (b, t, t), divided by sqrt(a)
        scale = queries.shape[-1] ** 0.5
        scores = torch.bmm(queries, keys.transpose(1, 2)) / scale
        if self.mode != "encoder":
            # Entries above the diagonal get -inf and so receive zero weight after softmax.
            hidden = self.tril[:seq_len, :seq_len] == 0
            scores = scores.masked_fill(hidden, float('-inf'))
        weights = F.softmax(scores, dim=-1)  # (b, t, t), each row sums to 1
        return torch.bmm(weights, values)  # (b, t, t) @ (b, t, a) -> (b, t, a)
    
class SelfAttention(nn.Module):
    """Pre-norm multi-head self-attention block with a residual connection.

    The input is layer-normalised, passed through `heads` independent `Head`
    modules, the head outputs are concatenated on the channel dimension,
    projected back to `embedding_dim`, and added to the un-normalised input.

    Args:
        embedding_dim: channel size c of the input and the output.
        heads: number of attention heads; must divide `embedding_dim`.
        mode: forwarded to every `Head` ("encoder" is unmasked, anything else
            applies the causal mask).
        context_len: size of the causal mask built by every `Head`.

    Raises:
        ValueError: if `embedding_dim` is not divisible by `heads`.
    """

    def __init__(self, embedding_dim=64, heads=8, mode="decoder", context_len=50):
        super().__init__()
        # The concatenated head outputs must be exactly embedding_dim wide,
        # otherwise the projection below cannot be applied.
        if embedding_dim % heads != 0:
            raise ValueError(
                f"embedding_dim ({embedding_dim}) must be divisible by heads ({heads})"
            )
        self.attention_dim = embedding_dim // heads
        self.heads = nn.ModuleList([Head(context_len=context_len,
                                         embedding_dim=embedding_dim,
                                         attention_dim=self.attention_dim,
                                         mode=mode) for _ in range(heads)])
        self.projection_weight = nn.Linear(embedding_dim, embedding_dim)

    def forward(self, x):
        """Apply attention to x of shape (b, t, c) and return a tensor of the same shape."""
        b, t, c = x.shape
        # Normalise each token over its channels only. Normalising over (t, c)
        # would mix statistics across positions, which lets later tokens
        # influence earlier ones even when the heads are causally masked.
        x_ln = F.layer_norm(x, (c,))
        output = [head(x_ln) for head in self.heads]
        output = torch.cat(output, dim=2)  # concatenate heads on the channel dim: (b, t, c)
        x_p = self.projection_weight(output)
        return x + x_p  # residual connection
    
class MLP(nn.Module):
    """Pre-norm position-wise feed-forward block with a residual connection.

    Each token is layer-normalised over its channels, expanded to
    `embedding_dim * mpl_multipler` features, passed through GELU, projected
    back to `embedding_dim`, and added to the un-normalised input.

    Args:
        embedding_dim: channel size c of the input and the output.
        mpl_multipler: width of the hidden layer as a multiple of `embedding_dim`.
    """

    def __init__(self, embedding_dim=64, mpl_multipler=4):
        super().__init__()
        hidden_dim = embedding_dim * mpl_multipler
        self.mlp_weight = nn.Linear(embedding_dim, hidden_dim)
        self.mpl_projection = nn.Linear(hidden_dim, embedding_dim)
        self.gelu1 = nn.GELU()

    def forward(self, x):
        """Transform x of shape (..., c) and return a tensor of the same shape."""
        c = x.shape[-1]
        # Normalise each token over its channels only. Normalising over (t, c)
        # would make every position's output depend on all other positions and
        # on the sequence length.
        x_ln = F.layer_norm(x, (c,))
        output = self.gelu1(self.mlp_weight(x_ln))
        output = self.mpl_projection(output)
        return x + output  # residual connection

In [65]:
# Build one instance of each stage with its default sizes, in pipeline order.
em, a, mlp = Embedding(), SelfAttention(), MLP()

In [66]:
# Random token ids in [0, 10) arranged as 4 sequences of 5 tokens.
# NOTE: this module-level name shadows the builtin `input` for the rest of the notebook.
input = torch.randint(low=0, high=10, size=(4, 5))

In [67]:
# Smoke test: push the token ids through embedding -> attention -> MLP.
em_in = em(input)
print(em_in.shape)
a_in = a(em_in)
# Both blocks add their result to their input, so each must return a tensor
# of exactly the shape it received.
assert a_in.shape == em_in.shape
mlp_in = mlp(a_in)
assert mlp_in.shape == a_in.shape
mlp_in.shape

torch.Size([4, 5, 64])


torch.Size([4, 5, 64])