DSAGPT â€” A Minimal GPT-Style Transformer Built from Scratch

In [1]:
from google.colab import files
uploaded = files.upload()

Saving input.txt to input (1).txt


In [2]:
import os, time, math
import torch
import torch.nn as nn
from torch.nn import functional as F

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Device:", device)

#hyperparameters
input_file = "input.txt"
batch_size = 16
block_size = 32
n_embd = 256
n_head = 4
n_layer = 4
dropout = 0.1

max_iters = 2000
eval_interval = 200
eval_iters = 50
learning_rate = 3e-4
grad_clip = 1.0

gen_max_tokens = 100
gen_temperature = 0.8
gen_top_k = 30

ckpt_dir = "checkpoints"
os.makedirs(ckpt_dir, exist_ok=True)
# --------------------------------------------------

torch.manual_seed(1337)
#Simple whitespace tokenizer

def build_word_tokenizer(text):
    words = text.split()
    words_unique = list(dict.fromkeys(words))
    stoi = {w: i+1 for i, w in enumerate(words_unique)}  # +1 reserve PAD as 0
    itos = {i: w for w, i in stoi.items()}
    return stoi, itos

# Load dataset
with open(input_file, "r", encoding="utf-8") as f:
    raw = f.read()

stoi, itos = build_word_tokenizer(raw)
vocab_size = len(stoi) + 1
print("Vocab size:", vocab_size)

def encode_text(s):
    return [stoi[w] for w in s.split() if w in stoi]

def decode_ids(ids):
    return " ".join(itos[i] for i in ids if i in itos)

# Tokenize dataset
ids = encode_text(raw)
print("Tokenized dataset length:", len(ids))

data = torch.tensor(ids, dtype=torch.long)
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

train_data = train_data.to(device)
val_data = val_data.to(device)

#Batch loader
def get_batch(split):
    data_src = train_data if split == "train" else val_data
    data_len = data_src.size(0)

    if data_len < block_size + 1:
        raise ValueError(f"Dataset too small even for block_size={block_size}. token_len={data_len}")

    max_start = data_len - (block_size + 1)
    ix = torch.randint(0, max_start, (batch_size,))

    x = torch.stack([data_src[i:i+block_size] for i in ix])
    y = torch.stack([data_src[i+1:i+block_size+1] for i in ix])

    return x.to(device), y.to(device)

@torch.no_grad()
def estimate_loss(model):
    model.eval()
    out = {}
    for split in ["train", "val"]:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean().item()
    model.train()
    return out

#GPT model definition

class Head(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)
        q = self.query(x)
        wei = (q @ k.transpose(-2, -1)) * (C ** -0.5)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        return wei @ self.value(x)

class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads):
        super().__init__()
        head_size = n_embd // num_heads
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

class FeedForward(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4*n_embd),
            nn.GELU(),
            nn.Linear(4*n_embd, n_embd),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        self.sa = MultiHeadAttention(n_head)
        self.ffwd = FeedForward()

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

class GPT(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block() for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.size()
        tok = self.token_emb(idx)
        pos = self.pos_emb(torch.arange(T, device=device))
        x = tok + pos
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)

        if targets is None:
            return logits, None

        B2, T2, C = logits.size()
        if targets.size(1) != T2:
            targets = targets[:, :T2]

        logits_flat = logits.reshape(B2*T2, C)
        targets_flat = targets.reshape(B2*T2)

        loss = F.cross_entropy(logits_flat, targets_flat)
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=50, temperature=1.0, top_k=None):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                vals, idxs = torch.topk(logits, top_k)
                probs = F.softmax(vals, dim=-1)
                next_token = idxs[0, torch.multinomial(probs[0], 1)].unsqueeze(0)
            else:
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs[0], 1).unsqueeze(0)
            idx = torch.cat((idx, next_token.to(device)), dim=1)
        return idx

#Train model

model = GPT().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
print("Model parameters:", sum(p.numel() for p in model.parameters())/1e6, "M")

best_val = float('inf')
start = time.time()

for it in range(max_iters):
    if it % eval_interval == 0:
        losses = estimate_loss(model)
        print(f"[{it}/{max_iters}] train {losses['train']:.4f}, val {losses['val']:.4f}, time {int(time.time()-start)}s")

    xb, yb = get_batch("train")
    logits, loss = model(xb, yb)

    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    optimizer.step()

#Generation test

prompt = "Q: What is a stack?\nA:"
prompt_ids = torch.tensor([encode_text(prompt)], dtype=torch.long).to(device)
out = model.generate(prompt_ids, max_new_tokens=gen_max_tokens, temperature=gen_temperature, top_k=gen_top_k)

print("\n=== PROMPT ===")
print(prompt)
print("\n=== MODEL OUTPUT ===")
print(decode_ids(out[0].tolist()))


Device: cuda
Vocab size: 360
Tokenized dataset length: 966
Model parameters: 3.349352 M
[0/2000] train 6.0889, val 6.0181, time 0s
[200/2000] train 0.2747, val 6.2723, time 6s
[400/2000] train 0.0914, val 7.0785, time 12s
[600/2000] train 0.0824, val 7.1231, time 17s
[800/2000] train 0.0722, val 7.4078, time 23s
[1000/2000] train 0.0715, val 7.4757, time 28s
[1200/2000] train 0.0728, val 7.8019, time 34s
[1400/2000] train 0.0726, val 7.8066, time 39s
[1600/2000] train 0.0676, val 7.9799, time 45s
[1800/2000] train 0.0671, val 8.0532, time 50s

=== PROMPT ===
Q: What is a stack?
A:

=== MODEL OUTPUT ===
Q: What is a stack? A: A stack is a LIFO (Last In First Out) linear data structure used for storing and retrieving data. Q: What is a queue? A: A queue is a FIFO (First In First Out) linear data structure used for processing elements in order. Q: What is a linked list? A: A linked list is a dynamic data structure composed of nodes, where each node points to the next. Q: What is a doubly 

### Example Inference

This section demonstrates text generation from the trained DSAGPT model on a simple DSA-style prompt.


In [3]:
# Example inference
model.eval()

prompt = "Q: What is a binary search tree?\nA:"

# Encode prompt
prompt_ids = torch.tensor([encode_text(prompt)], dtype=torch.long).to(device)

# Generate output
with torch.no_grad():
    output_ids = model.generate(
        prompt_ids,
        max_new_tokens=25,
        temperature=0.8,
        top_k=30
    )

# Decode and print
generated_text = decode_ids(output_ids[0].tolist())

print("PROMPT:")
print(prompt)
print("\nMODEL OUTPUT:")
print(generated_text)

PROMPT:
Q: What is a binary search tree?
A:

MODEL OUTPUT:
Q: What is a binary search tree? A: A BST is a binary tree in which the left child contains smaller values and the right child contains larger values. Q: What is inorder
