In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Define model hyperparameters
BATCH_SIZE = 32  # Parallel sequences processed
CONTEXT_WINDOW = 8  # Max context length for predictions
EPOCHS = 5000
CHECKPOINT_INTERVAL = 500
LR = 3e-4
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
EVAL_ITERS = 200
EMBEDDING_DIM = 384
HEADS = 6
LAYERS = 6
DROPOUT_RATE = 0.2

torch.manual_seed(457)

# Load dataset
with open('stoic.txt', 'r', encoding='utf-8') as file:
    corpus = file.read()

# Character encoding setup
char_list = sorted(set(corpus))
VOCAB_SIZE = len(char_list)
char_to_index = {ch: i for i, ch in enumerate(char_list)}
index_to_char = {i: ch for i, ch in enumerate(char_list)}

encode_text = lambda s: [char_to_index[c] for c in s]
decode_text = lambda l: ''.join([index_to_char[i] for i in l])

# Train-validation split
data_tensor = torch.tensor(encode_text(corpus), dtype=torch.long)
split_idx = int(0.9 * len(data_tensor)) #we'll be training with the first 90% of the data and do validation with the rest
train_data, val_data = data_tensor[:split_idx], data_tensor[split_idx:]


# Function to generate mini-batches
def get_batch(mode):
    dataset = train_data if mode == 'train' else val_data
    idxs = torch.randint(len(dataset) - CONTEXT_WINDOW, (BATCH_SIZE,))
    x_batch = torch.stack([dataset[i:i + CONTEXT_WINDOW] for i in idxs])
    y_batch = torch.stack([dataset[i + 1:i + CONTEXT_WINDOW + 1] for i in idxs])
    return x_batch.to(DEVICE), y_batch.to(DEVICE)

@torch.no_grad()
def compute_loss():
    losses = {}
    model.eval()
    for mode in ['train', 'val']:
        batch_losses = torch.zeros(EVAL_ITERS)
        for i in range(EVAL_ITERS):
            x, y = get_batch(mode)
            _, loss = model(x, y)
            batch_losses[i] = loss.item()
        losses[mode] = batch_losses.mean()
    model.train()
    return losses

In [21]:
#singkle head of attention
class Head(nn.Module):
    def __init__(self, head_dim):
        super().__init__()
        self.key = nn.Linear(EMBEDDING_DIM, head_dim, bias=False)
        self.query = nn.Linear(EMBEDDING_DIM, head_dim, bias=False)
        self.value = nn.Linear(EMBEDDING_DIM, head_dim, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(CONTEXT_WINDOW, CONTEXT_WINDOW)))
        self.dropout = nn.Dropout(DROPOUT_RATE)

    def forward(self, x):
        #here B-> batch, T -> time-step, C -> channels
        B, T, C = x.shape
        k, q, v = self.key(x), self.query(x), self.value(x) # dimesnsion of this is (B,T,HeadSize)

        # compute attention scores ("affinities")
        attention_scores = (q @ k.transpose(-2, -1)) * k.shape[-1]**-0.5
        attention_scores = attention_scores.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        attention_probs = F.softmax(attention_scores, dim=-1)
        attention_probs = self.dropout(attention_probs)
        return attention_probs @ v



class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(head_size * num_heads, EMBEDDING_DIM)
        self.dropout = nn.Dropout(DROPOUT_RATE)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

class FeedForward(nn.Module):
    def __init__(self, EMBEDDING_DIM):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(EMBEDDING_DIM, 4 * EMBEDDING_DIM),
            nn.ReLU(),
            nn.Linear(4 * EMBEDDING_DIM, EMBEDDING_DIM),
            nn.Dropout(DROPOUT_RATE),
        )

    def forward(self, x):
        return self.net(x)

#communication followed by computation
class Block(nn.Module):
    def __init__(self, EMBEDDING_DIM, HEADS):
        super().__init__()
        head_size = EMBEDDING_DIM // HEADS
        self.mha = MultiHeadAttention(HEADS, head_size)
        self.ffwd = FeedForward(EMBEDDING_DIM)
        self.ln1 = nn.LayerNorm(EMBEDDING_DIM)
        self.ln2 = nn.LayerNorm(EMBEDDING_DIM)

    def forward(self, x):
        x = x + self.mha(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x


In [22]:
class LanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(VOCAB_SIZE, EMBEDDING_DIM)
        self.position_embedding_table = nn.Embedding(CONTEXT_WINDOW, EMBEDDING_DIM)
        self.blocks = nn.Sequential(*[Block(EMBEDDING_DIM, HEADS=HEADS) for _ in range(LAYERS)])
        self.ln_f = nn.LayerNorm(EMBEDDING_DIM)
        self.lm_head = nn.Linear(EMBEDDING_DIM, VOCAB_SIZE)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.shape

        tok_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(T, device=DEVICE))
        x = tok_emb + pos_emb
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -CONTEXT_WINDOW:]
            logits, loss = self(idx_cond)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

model = LanguageModel()
m = model.to(DEVICE)

print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

# We'll be using the AdamW optimizer for best results
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)

for epoch in range(EPOCHS):
    if epoch % CHECKPOINT_INTERVAL == 0 or epoch == EPOCHS - 1:
        losses = compute_loss()
        print(f"step {epoch}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    xb, yb = get_batch('train')

    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# generate from the model
context = torch.zeros((1, 1), dtype=torch.long, device=DEVICE)
print(decode_text(m.generate(context, max_new_tokens=500)[0].tolist()))

10.713691 M parameters
step 0: train loss 4.4460, val loss 4.4495
step 500: train loss 2.1563, val loss 2.2279
step 1000: train loss 2.0441, val loss 2.1356
step 1500: train loss 1.9811, val loss 2.0880
step 2000: train loss 1.9382, val loss 2.0526
step 2500: train loss 1.8820, val loss 2.0114
step 3000: train loss 1.8584, val loss 1.9786
step 3500: train loss 1.8271, val loss 1.9861
step 4000: train loss 1.8289, val loss 1.9594
step 4500: train loss 1.7941, val loss 1.9273
step 4999: train loss 1.7831, val loss 1.9317

have‟m? Are man will yhe he good;* what the subber the sorking see as sect not if are is of ke that peopion, then to one do, know Wish these contrantingure not‟re time elsels that gerthing twhigh to may inhored to things wayschd, and lityus
of the low
your? Aster in it a reper a soulf, ha8*
what therad. The elimmb up hen hast if the was diver stiveradly
to he this is to unn rehile is pleciad it an thou seerve coun chat that roough by did to she
charoed is resurgery to p

In [23]:
open('more_stoic.txt', 'w').write(decode_text(m.generate(context, max_new_tokens=10000)[0].tolist()))

10001

In [24]:
print(decode_text(m.generate(context, max_new_tokens=500)[0].tolist()))


hy them that the say comnd afteren he heap? To should urate demire the
perion I men to what the are other, bundly you abOut mone to gry you are in if tobutt (on the make this man if world with a you awaturelf or charef and all may eatx, and or they worther worde. Of alsef. But in refferenter a coelation: with there colsist the prenestus the such desire, or would that
is
also to life to they on bean ablamle to
to plisicestess teycarly nor withere the growife you from lige to ovherar?
which has si
