<a href="https://colab.research.google.com/github/Indirajith-jithu/ArithmeticGPT/blob/Indirajith-jithu-patch-1/basic_math.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
 # IPython/Colab shell escape: `!` runs the rest of the line as a shell command.
 # datasets is pinned to 3.6.0 (presumably the version load_dataset was tested with).
 !pip install datasets==3.6.0

In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F



# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
from datasets import load_dataset

# Fetch the "arithmetic_1dc" configuration of the EleutherAI arithmetic dataset.
ds = load_dataset(path="EleutherAI/arithmetic", name="arithmetic_1dc")

In [None]:
df = ds['validation'].to_pandas()

# Take the segment that follows the first "is" in each context (up to any
# later "is"), cut it at the first "?", then append "=" and the completion.
after_is = df['context'].str.split("is", expand=True)[1]
expression = after_is.str.split("?", expand=True)[0]
data = expression + "=" + df['completion']

In [None]:
# Drop every literal space from each sample, then turn the Series into a plain list.
data = data.str.replace(" ", "", regex=False)
data = data.tolist()

In [None]:
PAD = "<PAD>"
EOD = "<EOD>"

# Character-level vocabulary: every distinct character in the cleaned samples,
# sorted for a stable ordering, followed by the two special tokens (so PAD and
# EOD always receive the two highest ids).
chars = sorted(set("".join(data))) + [PAD, EOD]
vocab_size = len(chars)

stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for ch, i in stoi.items()}

pad_id = stoi[PAD]
eod_id = stoi[EOD]


def encode(s):
    """Map each character of ``s`` to its token id.

    Does not append ``eod_id``; callers add it when they need it.

    Raises:
        KeyError: if ``s`` contains a character outside the vocabulary.
    """
    return [stoi[c] for c in s]


def decode(l):
    """Turn a sequence of token ids back into text, dropping ``<PAD>`` tokens.

    ``<EOD>`` is kept verbatim so the caller can see where a sample ended.

    Raises:
        KeyError: if an id is not in the vocabulary.
    """
    return ''.join(itos[i] for i in l if i != pad_id)


In [None]:
# Tokenize every sample and terminate it with <EOD>.
encoded = [encode(s) + [eod_id] for s in data]

# Right-pad all samples with <PAD> to the longest one so they stack into a matrix.
max_len = max(map(len, encoded))
padded = [seq + [pad_id] * (max_len - len(seq)) for seq in encoded]

dataset = torch.tensor(padded, dtype=torch.long)

# First 90% of rows for training, the remaining 10% for validation (no shuffling).
n = int(0.9 * len(dataset))
train_data, val_data = dataset[:n], dataset[n:]


In [None]:
def get_batch(split):
    """Sample a random batch of next-token (input, target) pairs.

    ``split == "train"`` draws from ``train_data``; any other value draws from
    ``val_data``. Targets are the inputs shifted left by one position. Both
    tensors are moved to ``device`` before returning.
    """
    source = train_data if split == "train" else val_data

    while True:
        rows = source[torch.randint(0, len(source), (batch_size,))]
        inputs, targets = rows[:, :-1], rows[:, 1:]

        # Resample if every target is padding: the loss ignores pad_id, so
        # such a batch would contain nothing to learn from.
        if (targets != pad_id).any():
            return inputs.to(device), targets.to(device)


In [None]:
# get_batch('train')

In [None]:
### model

In [None]:
class Head(nn.Module):
    """One causal self-attention head with padding-aware masking.

    Reads the notebook globals ``n_embd``, ``block_size`` and ``dropout`` at
    construction time, so the hyperparameter cell must run first.
    """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)

        # Lower-triangular ones: position t may only attend to positions <= t.
        self.register_buffer(
            "tril", torch.tril(torch.ones(block_size, block_size))
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attn_mask):
        """Attend over ``x`` (B, T, n_embd) and return (B, T, head_size).

        ``attn_mask`` is a (B, T) tensor in which 0 marks padding positions.
        """
        T = x.shape[1]

        k = self.key(x)
        q = self.query(x)

        # Scaled dot-product scores, shape (B, T, T).
        wei = q @ k.transpose(-2, -1) * (k.shape[-1] ** -0.5)

        # One combined mask: future positions (causal), padded keys, and padded
        # queries. A padded query's whole row is masked; the finite -1e4 (rather
        # than -inf) keeps softmax from producing NaN on such rows.
        causal = self.tril[:T, :T] == 0
        pad_key = attn_mask[:, None, :] == 0
        pad_query = attn_mask[:, :, None] == 0
        wei = wei.masked_fill(causal | pad_key | pad_query, -1e4)

        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)

        v = self.value(x)
        return wei @ v


In [None]:
class MultiHeadAttention(nn.Module):
    """Run several ``Head``s in parallel, concatenate them, and project back to n_embd.

    Reads the notebook globals ``n_embd`` and ``dropout`` at construction time.
    """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList(Head(head_size) for _ in range(num_heads))
        concat_width = num_heads * head_size
        self.proj = nn.Linear(concat_width, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attn_mask):
        per_head = [head(x, attn_mask) for head in self.heads]
        projected = self.proj(torch.cat(per_head, dim=-1))
        return self.dropout(projected)


In [None]:
class FeedForward(nn.Module):
    """Position-wise MLP: expand to 4*n_embd, ReLU, project back, then dropout.

    Reads the notebook global ``dropout`` at construction time.
    """

    def __init__(self, n_embd):
        super().__init__()
        hidden = 4 * n_embd
        layers = [
            nn.Linear(n_embd, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_embd),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


In [None]:
class Block(nn.Module):
    """Pre-norm transformer block: x + attn(LN(x)), then x + ffwd(LN(x))."""

    def __init__(self, n_embd, n_head):
        super().__init__()
        self.sa = MultiHeadAttention(n_head, n_embd // n_head)
        self.ffwd = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x, attn_mask):
        attended = self.sa(self.ln1(x), attn_mask)
        x = x + attended
        return x + self.ffwd(self.ln2(x))


In [None]:
class GPTLanguageModel(nn.Module):
    """Small decoder-only transformer over the character vocabulary.

    Hyperparameters (``vocab_size``, ``n_embd``, ``block_size``, ``n_head``,
    ``n_layer``) are read from notebook globals when the model is constructed.
    """

    def __init__(self):
        super().__init__()

        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)

        self.blocks = nn.ModuleList(
            [Block(n_embd, n_head) for _ in range(n_layer)]
        )

        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        """Run the model on a batch of token ids.

        Args:
            idx: (B, T) tensor of token ids; ``pad_id`` positions are masked
                out of attention.
            targets: optional (B, T) tensor of next-token ids; ``pad_id``
                entries are ignored by the loss.

        Returns:
            ``(logits, loss)``. ``logits`` is (B, T, vocab_size) when
            ``targets`` is None and flattened to (B*T, vocab_size) otherwise;
            ``loss`` is None when no targets are given.

        Raises:
            ValueError: if T exceeds the number of learned positions.
        """
        T = idx.shape[1]
        max_positions = self.position_embedding_table.num_embeddings
        if T > max_positions:
            # Without this check the position lookup fails with an opaque
            # index error (on CUDA typically a device-side assert).
            raise ValueError(
                f"sequence length {T} exceeds block_size {max_positions}; "
                "crop the context before calling the model"
            )

        # 1 for real tokens, 0 for padding; consumed by every attention head.
        attn_mask = (idx != pad_id).long()

        tok_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(
            torch.arange(T, device=idx.device)
        )

        x = tok_emb + pos_emb

        for block in self.blocks:
            x = block(x, attn_mask)

        x = self.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            logits = logits.reshape(-1, vocab_size)
            targets = targets.reshape(-1)
            loss = F.cross_entropy(
                logits, targets, ignore_index=pad_id
            )

        return logits, loss


In [None]:
@torch.no_grad()
def estimate_loss():
    """Return the mean loss over ``eval_iters`` random batches for each split.

    The global ``model`` is switched to eval mode for the measurement and back
    to train mode afterwards. Result maps "train"/"val" to 0-dim tensors.
    """
    model.eval()
    means = {}
    for split in ("train", "val"):
        batch_losses = [model(*get_batch(split))[1].item() for _ in range(eval_iters)]
        means[split] = torch.tensor(batch_losses).mean()
    model.train()
    return means


In [None]:
# Quick check of the per-head width: 96 embedding dims over 8 heads gives 12
# (the n_embd and n_head values set in the hyperparameter cell).
96 / 8

In [None]:


# training
max_iters      = 6500
eval_interval  = 500
learning_rate  = 1e-3
eval_iters     = 200
batch_size     = 64

# model
n_embd   = 96                   # must be divisible by n_head
n_head   = 8
n_layer  = 8
dropout  = 0.1                  # lower dropout = better memorization

# Block splits n_embd into n_head heads of n_embd // n_head dims each; a
# remainder would silently shrink the concatenated attention width, so fail fast.
if n_embd % n_head != 0:
    raise ValueError(f"n_embd ({n_embd}) must be divisible by n_head ({n_head})")

# context: longest encoded sample (equation, "=", result and <EOD>)
block_size = max_len

# device (re-derived so this cell is self-contained)
device = 'cuda' if torch.cuda.is_available() else 'cpu'



In [None]:
# Build the network, move its parameters to the selected device (Module.to
# works in place), and attach an AdamW optimizer.
model = GPTLanguageModel()
model.to(device)
optimizer = torch.optim.AdamW(params=model.parameters(), lr=learning_rate)

In [None]:


# Training loop: report train/val loss every eval_interval steps, otherwise
# take one AdamW step on a fresh random training batch. The loop variable is
# `step` rather than `iter` so the builtin iter() is not shadowed in the
# notebook namespace.
for step in range(max_iters):
    if step % eval_interval == 0:
        losses = estimate_loss()
        print(f"{step}: train {losses['train']:.4f}, val {losses['val']:.4f}")

    xb, yb = get_batch("train")
    _, loss = model(xb, yb)

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# The in-loop schedule never evaluates after the last update, so report the
# finished model explicitly.
losses = estimate_loss()
print(f"{max_iters}: train {losses['train']:.4f}, val {losses['val']:.4f}")


In [None]:
@torch.no_grad()
def generate(
    model,
    idx,
    max_new_tokens,
    pre_temp=1.0,
    post_temp=0.01,
    *,
    temperature=None,
):
    """Autoregressively extend ``idx`` by sampling from ``model``.

    Sequences that do not yet contain "=" are sampled at ``pre_temp``; once a
    sequence contains "=" it is sampled at ``post_temp`` (the low default makes
    the answer digits close to greedy).

    Args:
        model: a ``GPTLanguageModel``-style module returning ``(logits, loss)``.
        idx: (B, T) tensor of prompt token ids.
        max_new_tokens: maximum number of tokens to append.
        pre_temp: temperature before "=" appears; must be > 0.
        post_temp: temperature after "=" appears; must be > 0.
        temperature: if given, used for both phases. Kept for compatibility
            with the earlier single-temperature version of this function.

    Returns:
        (B, T + k) tensor with k <= max_new_tokens. A sequence that has emitted
        ``<EOD>`` is filled with ``pad_id`` (dropped by ``decode``) while the
        others keep generating; generation stops once every sequence has ended.

    Raises:
        ValueError: if a temperature is not positive.
    """
    if temperature is not None:
        pre_temp = post_temp = temperature
    if pre_temp <= 0 or post_temp <= 0:
        raise ValueError("temperatures must be > 0")

    was_training = model.training
    model.eval()
    eq_id = stoi["="]
    pre = torch.tensor(pre_temp, device=idx.device)
    post = torch.tensor(post_temp, device=idx.device)
    finished = torch.zeros(idx.shape[0], dtype=torch.bool, device=idx.device)

    try:
        for _ in range(max_new_tokens):
            # The model only has block_size learned positions.
            idx_cond = idx[:, -block_size:]
            logits, _ = model(idx_cond)
            logits = logits[:, -1, :]  # (B, vocab)

            # Per-sample temperature depending on whether "=" was seen yet.
            has_equal = (idx == eq_id).any(dim=1)  # (B,)
            temps = torch.where(has_equal, post, pre)
            probs = F.softmax(logits / temps.unsqueeze(1), dim=-1)

            idx_next = torch.multinomial(probs, 1)
            # Sequences that already ended only receive padding.
            idx_next = idx_next.masked_fill(finished.unsqueeze(1), pad_id)
            idx = torch.cat([idx, idx_next], dim=1)

            finished |= idx_next.squeeze(1) == eod_id
            if finished.all():
                break
    finally:
        # Do not leave the model stuck in eval mode (dropout off) afterwards.
        model.train(was_training)

    return idx



In [None]:
import ast
import operator

# Binary operators the checker accepts. `**` is deliberately excluded: a
# generated chain such as "9**9**9" would tie up the kernel under eval().
_ARITH_BINOPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.FloorDiv: operator.floordiv,
    ast.Mod: operator.mod,
}
_ARITH_UNARYOPS = {ast.UAdd: operator.pos, ast.USub: operator.neg}


def _safe_arith_eval(expr):
    """Evaluate an integer arithmetic expression without calling eval().

    Raises:
        ValueError: if ``expr`` is not built only from int literals,
            parentheses and the operators in ``_ARITH_BINOPS``/``_ARITH_UNARYOPS``.
        ZeroDivisionError: on division by zero.
    """
    def walk(node):
        if isinstance(node, ast.Expression):
            return walk(node.body)
        if isinstance(node, ast.Constant) and type(node.value) is int:
            return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in _ARITH_BINOPS:
            return _ARITH_BINOPS[type(node.op)](walk(node.left), walk(node.right))
        if isinstance(node, ast.UnaryOp) and type(node.op) in _ARITH_UNARYOPS:
            return _ARITH_UNARYOPS[type(node.op)](walk(node.operand))
        raise ValueError(f"unsupported expression: {expr!r}")

    try:
        tree = ast.parse(expr, mode="eval")
    except SyntaxError as err:
        raise ValueError(f"not a valid expression: {expr!r}") from err
    return walk(tree)


# Let the model invent an equation starting from "(".
context = torch.tensor([encode("(")], device=device)
out = generate(
    model,
    context,
    max_new_tokens=30
)
res = decode(out[0].tolist())
print(res)
# Ground-truth value of the generated left-hand side, for comparison with the
# model's answer. The text is model output, so use the restricted evaluator
# and report malformed expressions instead of crashing the cell.
try:
    print(_safe_arith_eval(res.split("=")[0]))
except (ValueError, ZeroDivisionError) as err:
    print(f"could not evaluate generated expression: {err}")


In [None]:
# Left-hand side of the generated equation (the whole text if there is no "=").
inpu = res.partition("=")[0]

In [None]:
# Inspect the raw token ids of the first generated sequence.
out[0].tolist()

In [None]:
# Feed the recovered left-hand side back in and let the model complete it.
prompt_ids = encode(inpu)
context = torch.tensor([prompt_ids], device=device)
out = generate(model, context, max_new_tokens=30)
generated_ids = out[0].tolist()
res = decode(generated_ids)
print(res)

In [None]:
# Ask the model to complete a fixed, hand-written expression.
fixed_prompt = encode("(6+6)*3")
context = torch.tensor([fixed_prompt], device=device)
out = generate(model, context, max_new_tokens=30)
completion_ids = out[0].tolist()
res = decode(completion_ids)
print(res)

In [None]:
# Inspect the raw token ids produced for the fixed prompt.
out[0].tolist()