In [None]:
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.nn import functional as F
import math


@dataclass
class Config:
    block_size: int = 64
    vocab_size: int = 65
    n_channel : int = 16
    n_embd    : int = 16
    n_layer   : int = 3
    n_head    : int = 8
    batch_size: int = 32

    device        : str   = 'cuda' if torch.cuda.is_available() else 'cpu'
    learning_rate : int   = 3e-2


In [None]:
class AutoEncoder(nn.Module):

    def __init__(self, dim, red= 4):
        super().__init__()
        self.activation = nn.GELU(approximate='tanh')
        self.fc_in      = nn.Linear(dim, dim // red, bias= False)
        self.fc_hidden  = nn.Linear(dim // red, dim // red, bias= False)
        self.fc_out     = nn.Linear(dim // red, dim, bias= False)

    def forward(self, x):

      # Encoder
      encoded = self.activation(self.fc_in(x))
      encoded = self.activation(self.fc_hidden(encoded))

      # Decoder
      decoded = self.fc_out(encoded)
      return decoded


In [59]:
class Head(nn.Module):

    def __init__(self, config):
        super().__init__()

        self.key    = AutoEncoder(config.n_embd) # key   matrix
        self.query  = AutoEncoder(config.n_embd) # query matrix
        self.value  = AutoEncoder(config.n_embd) # value matrix

        self.register_buffer('tril', torch.tril(torch.ones(config.block_size,config.block_size)))

    def forward(self, x):

        _,t,_ = x.shape # to mask incase if the sequence changes

        k = self.key(x)   # batch wise matrix multiplication
        q = self.query(x) # batch wise matrix multiplication
        v = self.value(x) # batch wise matrix multiplication

        scores = q @ k.transpose(-2, -1) # dot product to get attn scores
        scores = scores * (1.0 / math.sqrt(k.size(-1))) # scaled down

        mask = self.tril[:t, :t]                              # causal mask (B, T, T)
        scores = scores.masked_fill(mask == 0, float("-inf")) # mask scores
        scores = F.softmax(scores, dim=-1)                    # softmax normalize
        out  = scores @ v                                     # attention matrix multiplication

        return out


In [60]:
class AETransformer(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.head = Head(config)                # auto encoder attention head
        self.ln_1 = nn.LayerNorm(config.n_embd) # layer normalization

        self.w_emb = nn.Embedding(config.vocab_size, config.n_embd)             # token embedding
        self.p_emb = nn.Embedding(config.block_size, config.n_embd)             # position embedding

        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)  # output layer
        self.to(config.device)

    def forward(self, idx, targets=None):

        _,t = idx.shape # time dimesnion length for optimisation and positoin

        pos = torch.arange(0, t, dtype=torch.long, device=idx.device)
        out = self.w_emb(idx) + self.p_emb(pos) # pos + embedding

        out = out + self.head(self.ln_1(out))  #add (residual connection) + layer norm + mimicked attentioon
        logits = self.lm_head(out)

        if targets is None:
            loss = None
        else:
          B,T,C     = logits.shape
          logits    = logits.view(B*T, C)
          targets = targets.view(B*T)

          loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens,block_size):

        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):

          idx_cond = idx[:, -block_size:]

          logits , _ = self(idx)

          logits = logits[:, -1, :] # becomes (B, C)

          probs = F.softmax(logits, dim=-1)
          idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
          idx = torch.cat((idx, idx_next), dim=1) # (B, T+1z)
        return idx.squeeze().tolist()


In [61]:
config = Config()
mT = AETransformer(config=config)
optimizer = torch.optim.AdamW(mT.parameters(), lr=config.learning_rate)


In [62]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2025-06-29 01:22:31--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.111.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt.5’


2025-06-29 01:22:31 (21.5 MB/s) - ‘input.txt.5’ saved [1115394/1115394]



In [63]:
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()


In [64]:
chars = sorted(list(set(text)))
vocab_size = 65

stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])

data = torch.tensor(encode(text), dtype=torch.long)
print(len(data))
n = int(0.9*len(data))

train_data = data[:n]
val_data = data[n:]

1115394


In [65]:
def get_batch(split, config):
    block_size = config.block_size
    batch_size = config.batch_size
    device = config.device
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x,y = x.to(device=device), y.to(device=device)
    return x, y



In [66]:
for epochs in range(5000):

    x, y  = get_batch('train', config)


    optimizer.zero_grad()           # clear old gradients
    _, loss = mT(x, y)              # forward pass
    loss.backward()                 # back-propagate
    optimizer.step()                # update weights

    if epochs % 500 == 0:
        context = torch.randint(high=len(chars), size=(1, 1))
        print(f"Epoch: {epochs} | Loss: {loss.item():.4f}")
        generated_text = decode(mT.generate(context, config.block_size, config.block_size))
        print(f"Sample: {generated_text}")


Epoch: 0 | Loss: 4.4705
Sample: Ogx!qgfsviDvI.&Rg,w-naK'Khyf3.WHQm,MB! 
lIVrKNyRdWW'r!CXRWWjBIRg:
Epoch: 500 | Loss: 2.3885
Sample: ncen illenere te
duregied-ld-?
HARCLOLLUES:
To nge theee yorenere
Epoch: 1000 | Loss: 2.3629
Sample: ; wer:
Nad worealobe yound leind Orofe fot, nory onthexthe woowth
Epoch: 1500 | Loss: 2.3333
Sample: bis: bich dibys mukle, alif thany, thomid itstive thig, thanth yo
Epoch: 2000 | Loss: 2.3587
Sample: Y:
On ton, I Ooth ath thind; wa we foo.
Whoeny ghat mys aligrkis 
Epoch: 2500 | Loss: 2.3835
Sample: 3:


Whe de he's aduny,
The keand sos ard drovequge, byond zes to
Epoch: 3000 | Loss: 2.3471
Sample: SODUSBUSAR:
Haspr'd llet, fase pralolll wis, ye. Bar.

IOLEO:
Bu 
Epoch: 3500 | Loss: 2.3441
Sample: QULI:
Wice me hang thingh doth
Fink Rou, my beinth vakes.

GIUCAN
Epoch: 4000 | Loss: 2.2731
Sample: VI her hin me no me the, y, cout porsthar ood ur tow's tontheant 
Epoch: 4500 | Loss: 2.2413
Sample: :
Fad! hon anow fry  as, ianf ay et; ankey ainth weap
USTIUS