In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F

## token based on ascii

In [8]:
encoder = lambda input: [ord(char) for char in input]
decoder = lambda input: ''.join([chr(char) for char in input])
sample_input = "hi there!"
token = encoder(sample_input)
print(f"input         : {sample_input}\n"
      f"token         : {token}\n"
      f"decoded_token : {decoder(token)}")

input         : hi there!
token         : [104, 105, 32, 116, 104, 101, 114, 101, 33]
decoded_token : hi there!


## reading the input and tokenizing

In [9]:
with open("data.txt", "r") as f:
    data = f.read()

#converting the word into token and creating a tensor of it
token_data = torch.tensor(encoder(data))

#spliting data into train and val
train_len = int(0.9 * len(token_data))
train_data = token_data[:train_len]
val_data = token_data[train_len:]

print(f"train data length : {train_data.numel()}")
print(f"val data length : {val_data.numel()}")

train data length : 1003853
val data length : 111540


## create the input and traget, with the batch dim

In [10]:
torch.manual_seed(1234)
def get_batch(batch, block_len, data):
    random_pos = torch.randint(0, len(data)-block_len, (batch,))
    input = torch.stack([data[i:i+block_len] for i in random_pos])
    target = torch.stack([data[i+1:i+block_len+1] for i in random_pos])
    return input, target

In [11]:
batch = 4
block_len = 8
ipt, tgt = get_batch(batch, block_len, train_data)
print(ipt.shape)
print(f"input : \n{ipt.numpy()}\n"
      f"output: \n{tgt.numpy()}")

torch.Size([4, 8])
input : 
[[108 100  32 109  97 114 114 105]
 [101  32 109 101 110  32  99  97]
 [ 65 110  32 121 101 116  44  32]
 [116  39 115 121  32 119 111 114]]
output: 
[[100  32 109  97 114 114 105  97]
 [ 32 109 101 110  32  99  97 110]
 [110  32 121 101 116  44  32 102]
 [ 39 115 121  32 119 111 114 116]]


# Model Architecture

In [24]:
class Embedding(nn.Module):
    def __init__(self, embedding_dim=64, n_vocal=128, context_len=50):
        super().__init__()
        self.token_embedding = nn.Embedding(n_vocal, embedding_dim) # (vocab_size, embedding_dim)
        self.time_embedding = nn.Parameter(torch.randn((embedding_dim, context_len))) # (embedding_dim, context_len)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        batch, context_len = x.shape
        time_range = torch.arange(context_len, device=x.device)
        time_embed = self.time_embedding[:, time_range].unsqueeze(0).expand(batch, -1, -1)
        token_embed = self.token_embedding(x)
        teken_embed = self.dropout(token_embed)
        return time_embed.permute(0, 2, 1) + token_embed  # (batch, context_len, embedding_dim)

class Head(nn.Module):
    def __init__(self, context_len=50, embedding_dim=64, attention_dim=8, mode="encoder"):
        super().__init__()
        self.query = nn.Linear(embedding_dim, attention_dim, bias=False)
        self.value = nn.Linear(embedding_dim, attention_dim, bias=False)
        self.key = nn.Linear(embedding_dim, attention_dim, bias=False)
        self.mode = mode
        self.dropout = nn.Dropout(0.5)
        self.register_buffer('tril', torch.tril(torch.ones(context_len, context_len)))

    def forward(self, x):
        b, t, c = x.shape  
        q = self.query(x)  
        k = self.key(x)    
        v = self.value(x)  
        d_k = q.shape[-1]
        scaled_scores = torch.bmm(q, k.permute(0, 2, 1)) / (d_k ** 0.5)

        if self.mode == "encoder":
            attn_weights = scaled_scores
        else:
            attn_weights = scaled_scores.masked_fill(self.tril[:t, :t] == 0, float('-inf'))
        
        attn_weights = F.softmax(attn_weights, dim=-1)  
        attn_weights = self.dropout(attn_weights)
        attention_output = torch.bmm(attn_weights, v)  
        return attention_output  

class SelfAttention(nn.Module):
    def __init__(self, context_l=50, embedding_dim=64, heads=8, mode="decoder"):
        super().__init__()
        self.attention_dim = embedding_dim // heads
        self.heads = nn.ModuleList([
            Head(context_len=context_l, embedding_dim=embedding_dim, attention_dim=self.attention_dim, mode=mode)
            for _ in range(heads)
        ])
        self.projection_weight = nn.Linear(embedding_dim, embedding_dim)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        b, t, c = x.shape
        x_ln = F.layer_norm(x, (c,))  # Correct normalization
        output = [head(x_ln) for head in self.heads]
        output = torch.cat(output, dim=2)
        x_p = self.projection_weight(output)
        x_p = self.dropout(x_p)
        return x + x_p

class MLP(nn.Module):
    def __init__(self, embedding_dim=64, mlp_multiplier=4):
        super().__init__()
        self.mlp_weight = nn.Linear(embedding_dim, embedding_dim * mlp_multiplier)
        self.mlp_projection = nn.Linear(embedding_dim * mlp_multiplier, embedding_dim)
        self.gelu1 = nn.GELU()
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        b, t, c = x.shape
        x_ln = F.layer_norm(x, (c,))
        output = self.gelu1(self.mlp_weight(x_ln))
        output = self.dropout(output)
        output = self.mlp_projection(output)
        return x + output

class Logits(nn.Module):
    def __init__(self, n_vocal=128, embedding_dim=64):
        super().__init__()
        self.encode = nn.Linear(embedding_dim, n_vocal, bias=False)

    def forward(self, x):
        b, t, c = x.shape
        x_ln = F.layer_norm(x, (c,))  # Normalize across channels (embedding dim)
        logits = self.encode(x_ln)  # (b, t, n_vocal)
        return logits  # No softmax applied

class TransformerMini(nn.Module):
    def __init__(self, context_l=50, n_vocal=128, embedding_dim=64, attention_heads=8, mode="decoder"):
        super().__init__()
        self.embedding = Embedding(embedding_dim, n_vocal, context_l)
        self.self_attention = SelfAttention(context_l, embedding_dim, attention_heads, mode)
        self.mlp = MLP(embedding_dim)
        self.logits = Logits(n_vocal, embedding_dim)

    def forward(self, x):
        x = self.embedding(x)
        x = self.self_attention(x)
        x = self.mlp(x)
        x = self.logits(x)
        return x  # Output shape: (batch, context_len, n_vocal)

# training

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 1e-3
batch = 64
epoch = 6000

In [32]:
128//16

8

In [33]:
model = TransformerMini(context_l=256, embedding_dim=128, attention_heads=16).to(device=device)
model.train()

TransformerMini(
  (embedding): Embedding(
    (token_embedding): Embedding(128, 128)
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (self_attention): SelfAttention(
    (heads): ModuleList(
      (0-15): 16 x Head(
        (query): Linear(in_features=128, out_features=8, bias=False)
        (value): Linear(in_features=128, out_features=8, bias=False)
        (key): Linear(in_features=128, out_features=8, bias=False)
        (dropout): Dropout(p=0.5, inplace=False)
      )
    )
    (projection_weight): Linear(in_features=128, out_features=128, bias=True)
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (mlp): MLP(
    (mlp_weight): Linear(in_features=128, out_features=512, bias=True)
    (mlp_projection): Linear(in_features=512, out_features=128, bias=True)
    (gelu1): GELU(approximate='none')
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (logits): Logits(
    (encode): Linear(in_features=128, out_features=128, bias=False)
  )
)

In [34]:
optimizer = torch.optim.AdamW(model.parameters(), lr)
crtirion = nn.CrossEntropyLoss()

In [None]:
for i in range(epoch):
    model.train()
    xb, yb = get_batch(batch, 128, train_data)
    xb, yb = xb.to(device=device), yb.to(device=device)
    print(xb.shape)
    optimizer.zero_grad()
    output = model(xb)
    b, t, n = output.shape #(b, t, n)
    loss = crtirion(output.view(b*t, n), yb.view(b*t))
    loss.backward()
    optimizer.step()
    with torch.no_grad():
        vxb, vyb = get_batch(batch, 128, val_data)
        vxb, vyb = vxb.to(device=device), vyb.to(device=device)
        model.eval()
        output = model(vxb)
        b, t, n = output.shape #(b, t, n)
        vloss = crtirion(output.view(b*t, n), vyb.view(b*t))


    print(F"epoch : {i}, train_l : {loss.item()}, val_l : {vloss.item()}")



torch.Size([64, 128])
epoch : 0, train_l : 5.019121170043945, val_l : 5.017431735992432
torch.Size([64, 128])
epoch : 1, train_l : 4.999518394470215, val_l : 5.0017194747924805
torch.Size([64, 128])
epoch : 2, train_l : 4.988387584686279, val_l : 5.003236293792725
torch.Size([64, 128])
epoch : 3, train_l : 4.965498447418213, val_l : 4.990217208862305
torch.Size([64, 128])
epoch : 4, train_l : 4.960470199584961, val_l : 4.970786094665527
torch.Size([64, 128])
epoch : 5, train_l : 4.947319030761719, val_l : 4.966724872589111
torch.Size([64, 128])
epoch : 6, train_l : 4.930920124053955, val_l : 4.9585371017456055
torch.Size([64, 128])
epoch : 7, train_l : 4.932190895080566, val_l : 4.956240177154541
torch.Size([64, 128])
epoch : 8, train_l : 4.910511493682861, val_l : 4.944454669952393
torch.Size([64, 128])
epoch : 9, train_l : 4.89579963684082, val_l : 4.941092491149902
torch.Size([64, 128])
epoch : 10, train_l : 4.893231391906738, val_l : 4.935298919677734
torch.Size([64, 128])
epoch : 

KeyboardInterrupt: 

In [36]:
def generator(model:TransformerMini, token_len):
    token = torch.tensor([[10]]).to(device)
    for i in range(token_len):
        out = model(token)
        sm_out = F.softmax(out[:,-1,:], dim=-1)
        predict = torch.multinomial(sm_out, num_samples=1)
        token = torch.cat((token, predict), dim=1)[:, -50:]
        print(decoder(predict), end='')
        

In [38]:
generator(model.eval(), 400)


hee

L aee s dicY ce kare:.endteUo!daney,gid'.
A t
oAd aneod 
erad
Ioallcethe oseofo
d torly cinet in hUg Ee&ere d  m he%risahs p \. wanrlld id lleneer wo lc r ur en aan n agasor mrthte, w hay icorXvhoe f
d hensere pAxIfoiloae o 
qhe t
)
mad htheelarytdutenocerp.othfua{n b th h awin
(rineounitheaaufraneon allly,iis mcere mhar tononad illheroar yea
AasuK it as hothe
acI lliled Dgsrnrreaf lH