In [None]:
from datasets import load_dataset
import re

def clean(s):
    """Normalize one chunk of raw text for the chat corpus.

    Removes URLs (an ``http`` run up to the next whitespace), collapses every
    run of spaces/tabs to a single space, and trims leading/trailing
    whitespace. Newlines are preserved.

    Args:
        s: raw text.

    Returns:
        The cleaned string (possibly empty).
    """
    # Drop URLs first so the gap they leave is collapsed below as well.
    s = re.sub(r"http\S+", "", s)
    # A single .replace("  ", " ") only halves long runs of spaces; the regex
    # collapses runs of any length. Tabs are treated like spaces.
    s = re.sub(r"[ \t]+", " ", s)
    return s.strip()

print("Loading DailyDialog...")
dd = load_dataset("daily_dialog")["train"]

chat_lines = []

# Each dialog alternates speakers starting with "User"; a blank line
# separates consecutive dialogs in the output file.
for dialog in dd["dialog"]:
    # Clean first, then drop turns that end up empty (blank or URL-only),
    # so no bare "User: " / "Assistant: " lines are written.
    turns = [clean(t) for t in dialog]
    turns = [t for t in turns if t]
    speaker = "User"
    for t in turns:
        chat_lines.append(f"{speaker}: {t}")
        speaker = "Assistant" if speaker == "User" else "User"
    chat_lines.append("")

print("DailyDialog processed.")

print("Loading WikiText-2 (smaller, safe)...")
wiki = load_dataset("wikitext", "wikitext-2-raw-v1")["train"]["text"]
wiki_raw = clean("\n".join(wiki))
wiki_raw = wiki_raw[:200_000]  # limit size ~200k chars

# NOTE(review): WikiText is appended after every dialog, so a contiguous
# 90/10 token split downstream probably puts all of it in validation if it is
# under ~10% of the corpus (see the printed size) — confirm this is intended.
wiki_sents = re.split(r'(?<=[.!?])\s+', wiki_raw)
for s in wiki_sents:
    # clean() keeps newlines, so a "sentence" may span several source lines
    # (e.g. a heading followed by a paragraph). Collapse all internal
    # whitespace so each entry stays on a single "Assistant:" line.
    s = " ".join(s.split())
    if len(s) > 5:
        chat_lines.append(f"Assistant: {s}")
chat_lines.append("")

final_text = "\n".join(chat_lines)

with open("datacorpus.txt", "w", encoding="utf-8") as f:
    f.write(final_text)

print("Saved datacorpus.txt successfully!")
print("Size (MB):", len(final_text) / 1e6)


## Model Implementation

In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F

class Head(nn.Module):
    """Single causal self-attention head.

    Projects the input into key/query/value spaces of width ``head_size`` and
    lets every position attend only to itself and earlier positions.

    Args:
        head_size: width of the key/query/value projections.
        n_embd: width of the incoming embeddings.
        block_size: longest sequence the causal mask supports.
    """

    def __init__(self, head_size, n_embd, block_size):
        super().__init__()

        # Attribute names and registration order are unchanged so state_dict
        # keys (and seeded initialisation) match saved checkpoints.
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)

        # Lower-triangular ones: row i is 1 only for columns <= i, which is
        # used to stop positions from attending to future tokens.
        causal = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer('tril', causal)

        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        """Causal attention for one head.

        Args:
            x: (B, T, n_embd) input; T must not exceed ``block_size``.

        Returns:
            (B, T, head_size) attention-weighted sum of the value projections.
        """
        _, seq_len, _ = x.shape

        queries = self.query(x)   # (B, T, head_size)
        keys = self.key(x)        # (B, T, head_size)
        scale = keys.shape[-1] ** 0.5

        # (B, T, hs) @ (B, hs, T) -> (B, T, T) affinities, scaled by sqrt(hs)
        scores = queries @ keys.transpose(-2, -1) / scale

        # Entries above the diagonal become -inf so softmax gives them 0 weight.
        future = self.tril[:seq_len, :seq_len] == 0
        scores = scores.masked_fill(future, float('-inf'))

        weights = self.dropout(F.softmax(scores, dim=-1))
        return weights @ self.value(x)


class MultiHeadAttention(nn.Module):
    """Runs ``num_heads`` causal attention heads side by side.

    Every head sees the full input; their (B, T, head_size) outputs are
    concatenated along the feature axis and projected back to ``n_embd``.
    """

    def __init__(self, num_heads, head_size, n_embd, block_size):
        super().__init__()
        head_list = [Head(head_size, n_embd, block_size) for _ in range(num_heads)]
        self.heads = nn.ModuleList(head_list)
        concat_width = num_heads * head_size
        self.proj = nn.Linear(concat_width, n_embd)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        """Return the projected, dropout-regularised mix of all head outputs."""
        per_head = [head(x) for head in self.heads]
        merged = torch.cat(per_head, dim=-1)   # (B, T, num_heads * head_size)
        projected = self.proj(merged)
        return self.dropout(projected)


class FeedForward(nn.Module):
    """Position-wise MLP: expand to 4 * n_embd, ReLU, project back, dropout.

    Args:
        n_embd: input and output width.
    """

    def __init__(self, n_embd):
        super().__init__()
        hidden = 4 * n_embd  # 4x expansion of the residual width
        layers = [
            nn.Linear(n_embd, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_embd),
            nn.Dropout(0.1),
        ]
        # Kept as one nn.Sequential named `net` so state_dict keys
        # (net.0.*, net.2.*) stay compatible with saved checkpoints.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to the last dimension; output has the same shape as x."""
        out = self.net(x)
        return out

class Block(nn.Module):
    """Pre-norm transformer block: attention then feed-forward, each residual.

    Args:
        n_embd: residual-stream width.
        num_heads: attention heads; each gets ``n_embd // num_heads`` features.
        block_size: longest sequence the causal mask supports.
    """

    def __init__(self, n_embd, num_heads, block_size):
        super().__init__()
        # Registration order (sa, ffn, ln1, ln2) is unchanged so seeded init
        # and state_dict keys match existing checkpoints.
        self.sa = MultiHeadAttention(num_heads, n_embd // num_heads, n_embd, block_size)
        self.ffn = FeedForward(n_embd)
        self.ln1, self.ln2 = nn.LayerNorm(n_embd), nn.LayerNorm(n_embd)

    def forward(self, x):
        """Apply LayerNorm before each sub-layer and add its result back to x."""
        attended = x + self.sa(self.ln1(x))               # communication across positions
        return attended + self.ffn(self.ln2(attended))    # per-position computation


class GPTLM(nn.Module):
    """Decoder-only (GPT-style) language model.

    Token + learned position embeddings feed a stack of pre-norm ``Block``s,
    followed by a final LayerNorm and a linear head whose weight is tied to the
    token embedding.

    Args:
        vocab_size: number of token ids.
        n_embd: embedding / residual-stream width.
        block_size: maximum context length (size of the position table).
        n_layers: number of transformer blocks.
        n_heads: attention heads per block.
        dropout: dropout probability for the embeddings and, after
            construction, for every ``nn.Dropout`` inside the blocks.
    """

    def __init__(self, vocab_size, n_embd, block_size, n_layers, n_heads, dropout=0.1):
        super().__init__()
        self.block_size = block_size
        self.n_embd = n_embd

        # embeddings
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)

        # transformer blocks
        self.blocks = nn.ModuleList([
            Block(n_embd=n_embd, num_heads=n_heads, block_size=block_size)
            for _ in range(n_layers)
        ])

        # final norm and linear head
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)

        # weight tying: the output projection shares the token-embedding matrix
        self.lm_head.weight = self.token_embedding_table.weight

        self.dropout = nn.Dropout(dropout)  # also validates 0 <= dropout <= 1

        # The sub-layers hard-code nn.Dropout(0.1); override their rate so the
        # `dropout` argument governs the whole model, not just the embeddings.
        for module in self.modules():
            if isinstance(module, nn.Dropout):
                module.p = dropout

        # nn.Embedding defaults to N(0, 1); because lm_head is tied to it that
        # yields very large initial logits. Use GPT-2 style N(0, 0.02) instead.
        # (Only affects fresh models; load_state_dict overwrites these values.)
        self.apply(self._init_weights)

    @staticmethod
    def _init_weights(module):
        """N(0, 0.02) weights for Linear/Embedding layers, zero Linear biases."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            nn.init.zeros_(module.bias)

    def forward(self, idx, targets=None):
        """
        idx: (B, T) token indices, T <= block_size
        targets: (B, T) token indices (next tokens), optional

        Returns:
          - if targets is None: logits (B, T, vocab_size)
          - else: (logits_flat, loss) with logits_flat of shape (B*T, vocab_size)
        """
        B, T = idx.shape
        assert T <= self.block_size, "input length T must be <= block_size"

        # token + position embeddings
        tok_emb = self.token_embedding_table(idx)          # (B, T, n_embd)
        pos = torch.arange(T, device=idx.device)           # (T,)
        pos_emb = self.position_embedding_table(pos)       # (T, n_embd)
        x = tok_emb + pos_emb                              # broadcasting -> (B, T, n_embd)
        x = self.dropout(x)

        # transformer blocks
        for block in self.blocks:
            x = block(x)

        # final norm, linear head to vocab
        x = self.ln_f(x)                                   # (B, T, n_embd)
        logits = self.lm_head(x)                           # (B, T, vocab_size)

        if targets is None:
            return logits

        # compute loss (flatten B*T); reshape also accepts non-contiguous targets
        B, T, C = logits.shape
        logits_flat = logits.reshape(B * T, C)
        targets_flat = targets.reshape(B * T)
        loss = F.cross_entropy(logits_flat, targets_flat)
        return logits_flat, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """Autoregressively append ``max_new_tokens`` sampled tokens to ``idx``.

        Args:
            idx: (B, T) initial context of token ids.
            max_new_tokens: number of tokens to generate.
            temperature: logits are divided by this before sampling;
                ``0`` selects greedy (argmax) decoding.
            top_k: if a positive int, sample only among the ``top_k`` most
                likely tokens; ``None`` or ``<= 0`` disables the filter.

        Returns:
            (B, T + max_new_tokens) tensor of token ids.

        Sampling runs in eval mode (dropout off); the previous train/eval mode
        is restored afterwards.
        """
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # only the last block_size tokens fit the position table
                idx_cond = idx[:, -self.block_size:]

                logits = self(idx_cond)[:, -1, :]      # (B, vocab) — last time step

                if temperature == 0:
                    next_token = logits.argmax(dim=-1, keepdim=True)  # (B, 1)
                else:
                    if temperature != 1.0:
                        logits = logits / temperature

                    if top_k is not None and top_k > 0:
                        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                        min_topk = v[:, -1].unsqueeze(-1)
                        logits = torch.where(logits < min_topk, torch.full_like(logits, -1e10), logits)

                    probs = F.softmax(logits, dim=-1)   # (B, vocab)
                    next_token = torch.multinomial(probs, num_samples=1)  # (B, 1)

                idx = torch.cat((idx, next_token), dim=1)
        finally:
            self.train(was_training)
        return idx


## Save checkpoint (run after training — needs `model` and `optimizer`)

In [None]:
# Save model + optimizer state so training can be resumed later.
# Requires `model` and `optimizer`, which are created further down, so run
# this only after training. The stored "step" is the next step to run: after
# the training loop that is `step + 1`; otherwise keep the step we resumed
# from. (Previously a hard-coded 3000, which discarded real progress.)
if "step" in globals():
    resume_step = step + 1
else:
    resume_step = start_step

torch.save({
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "step": resume_step,
}, "checkpoint.pth")

print("Checkpoint saved!")


## Load tokenizer, dataset, and model

In [2]:
import torch
from tokenizers import Tokenizer

# Run on the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# BPE tokenizer written by the tokenizer-training cell.
tokenizer = Tokenizer.from_file("bpe_tokenizer.json")
vocab_size = tokenizer.get_vocab_size()

# Shared by get_batch() and the model's context window.
batch_size = 64
block_size = 256

with open("datacorpus.txt", "r", encoding="utf-8") as f:
    text = f.read()

# Whole corpus as one flat 1-D tensor of token ids.
ids = tokenizer.encode(text).ids
data = torch.tensor(ids, dtype=torch.long)

# Contiguous split: first 90% for training, last 10% for validation.
# NOTE(review): the corpus-building cell appends WikiText after all the
# dialogs, so the validation tail likely contains that text and it is never
# trained on — verify this is intended.
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]

def get_batch(split):
    """Sample a random batch of (input, next-token target) windows.

    Args:
        split: "train" selects ``train_data``; any other value selects ``val_data``.

    Returns:
        (x, y): LongTensors of shape (batch_size, block_size) on ``device``;
        ``y`` is the same window shifted one position ahead (next-token targets).

    Raises:
        ValueError: if the chosen split is shorter than block_size + 1 tokens.
    """
    d = train_data if split == "train" else val_data
    # A window needs block_size + 1 tokens (inputs plus the shifted target),
    # so valid starts are 0 .. len(d) - block_size - 1 inclusive. randint's
    # upper bound is exclusive, hence len(d) - block_size (the old "- 1"
    # excluded the last valid window).
    max_start = len(d) - block_size
    if max_start <= 0:
        raise ValueError(
            f"{split!r} split has {len(d)} tokens; need at least {block_size + 1}"
        )
    ix = torch.randint(0, max_start, (batch_size,))
    x = torch.stack([d[i:i + block_size] for i in ix]).to(device)
    y = torch.stack([d[i + 1:i + block_size + 1] for i in ix]).to(device)
    return x, y

# 8 blocks x 8 heads on a 768-wide residual stream; context length and
# vocabulary size come from the data-loading code above.
model_config = dict(
    vocab_size=vocab_size,
    n_embd=768,
    block_size=block_size,
    n_layers=8,
    n_heads=8,
    dropout=0.1,
)
model = GPTLM(**model_config).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

# Module.parameters() yields each Parameter once, so the tied
# embedding / lm_head matrix is counted a single time.
param_count = sum(p.numel() for p in model.parameters())
print("Model parameters:", param_count / 1e6, "M")


Device: cuda
Model parameters: 63.026688 M


## Load checkpoint

In [4]:
checkpoint = torch.load("checkpoint.pth", map_location=device)

# Restore both halves of the training state from the saved dict.
for state_key, target in (("model", model), ("optimizer", optimizer)):
    target.load_state_dict(checkpoint[state_key])

start_step = checkpoint["step"]

print("Checkpoint loaded at step:", start_step)

Checkpoint loaded at step: 39800


## Continue training

In [None]:
max_steps = 55000
save_every = 200  # logging + checkpoint interval, in steps

# Resume from the most recent checkpoint on disk.
checkpoint = torch.load("checkpoint.pth", map_location=device)

model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])

# "step" in the checkpoint is the next step to run.
start_step = checkpoint["step"]

print("Checkpoint loaded at step:", start_step)


def save_checkpoint(next_step):
    """Write model/optimizer state; ``next_step`` is where training resumes."""
    torch.save({
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "step": next_step,
    }, "checkpoint.pth")
    print("Checkpoint saved.")


# The chat/generation cells call model.eval(); re-enable dropout for training.
model.train()

for step in range(start_step, max_steps):
    xb, yb = get_batch("train")
    _, loss = model(xb, yb)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step % save_every == 0:
        print(step, loss.item())
        # `step` has already been applied; storing it unchanged would make a
        # resumed run repeat it, so store the following step.
        save_checkpoint(step + 1)

# Keep the work done since the last periodic save.
if start_step < max_steps:
    save_checkpoint(max_steps)

## Generate sample output

In [None]:
# Sample in eval mode so dropout does not perturb generation.
model.eval()

# `<bos>` never occurs in the training data: the corpus contains no "<bos>"
# marker and the tokenizer has no post-processor adding special tokens, so
# its input embedding is untrained. Prime generation with the chat format the
# model was actually trained on instead.
prompt_ids = tokenizer.encode("User:").ids
context = torch.tensor([prompt_ids], dtype=torch.long).to(device)

out = model.generate(context, max_new_tokens=300, temperature=0.8, top_k=50)
print(tokenizer.decode(out[0].tolist()))


## Train the BPE tokenizer on the text corpus (initial setup — run once, before the cell that loads `bpe_tokenizer.json`)

In [None]:
import torch
from tokenizers import Tokenizer, models, trainers, pre_tokenizers

with open("datacorpus.txt", "r", encoding="utf-8") as f:
    text = f.read()

# Wire "<unk>" into the BPE model itself. Listing it among the trainer's
# special tokens only reserves an id; without `unk_token` on the model,
# characters missing from the vocabulary are silently skipped when encoding.
tokenizer = Tokenizer(models.BPE(unk_token="<unk>"))

# NOTE(review): Whitespace pre-tokenization discards all whitespace (newlines
# included) and no decoder is configured, so decode() is expected to join
# tokens with single spaces (e.g. "Assistant:" -> "Assistant :"). No
# post-processor is set either, so <bos>/<eos> are never inserted by encode().
# A byte-level pre-tokenizer/decoder would round-trip text exactly, but it
# changes the vocabulary and would invalidate existing checkpoints.
tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()

trainer = trainers.BpeTrainer(
    vocab_size=8000,
    min_frequency=2,
    special_tokens=["<pad>", "<unk>", "<bos>", "<eos>"]
)

tokenizer.train_from_iterator([text], trainer)
tokenizer.save("bpe_tokenizer.json")

# Reload from disk so later code uses exactly what was serialized.
tokenizer = Tokenizer.from_file("bpe_tokenizer.json")

# Same contiguous 90/10 split as the data-loading cell.
ids = tokenizer.encode(text).ids
data = torch.tensor(ids, dtype=torch.long)

n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]
print("done")

In [None]:
from tokenizers import Tokenizer
tokenizer = Tokenizer.from_file("bpe_tokenizer.json")

def encode(text):
    """Tokenize ``text`` into a (1, T) LongTensor placed on ``device``."""
    token_ids = tokenizer.encode(text).ids
    batch = torch.tensor([token_ids], dtype=torch.long)
    return batch.to(device)

def decode(ids):
    """Convert a sequence of token ids back into a string."""
    text = tokenizer.decode(ids)
    return text

def chat_turn(history, user_message, max_new_tokens=150, temperature=0.8, top_k=50):
    """Generate one assistant reply and append the exchange to the transcript.

    Args:
        history: transcript so far, as lines starting with "User:" / "Assistant:".
        user_message: the new user utterance.
        max_new_tokens, temperature, top_k: forwarded to ``model.generate``.

    Returns:
        (assistant_reply, updated_history)
    """
    import re  # local: this cell may run in a kernel without the data-prep cell

    model.eval()
    history += f"User: {user_message}\nAssistant:"
    x = encode(history)
    y = model.generate(x, max_new_tokens=max_new_tokens, temperature=temperature, top_k=top_k)

    # Decode only the newly generated tokens. Searching the whole decoded
    # transcript for "Assistant:" picks the FIRST (oldest) turn, and with the
    # Whitespace pre-tokenizer the marker presumably decodes as "Assistant :".
    new_ids = y[0, x.size(1):].tolist()
    assistant_reply = decode(new_ids)

    # The model tends to keep writing the dialogue; keep only the text before
    # it starts the next speaker turn.
    assistant_reply = re.split(r"\b(?:User|Assistant)\s*:", assistant_reply, maxsplit=1)[0].strip()

    history += " " + assistant_reply + "\n"
    return assistant_reply, history

history = ""

print("Chat started. Type 'exit', 'quit' or 'stop' to end.\n")

while True:
    try:
        user_input = input("User: ").strip()
    except (EOFError, KeyboardInterrupt):
        # stdin closed or the cell was interrupted: end the session cleanly.
        print("\nEnding chat.")
        break
    if user_input.lower() in ["exit", "quit", "stop"]:
        print("Ending chat.")
        break
    if not user_input:
        continue  # don't prompt the model with an empty message
    reply, history = chat_turn(history, user_input)
    print("Assistant:", reply)
