In [1]:
# !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
# !pip install matplotlib

In [1]:
with open('input.txt', 'r') as file:
    text = file.read()

print(text[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [2]:
import torch 

chars = sorted(list(set(text)))
vocab_size = len(chars)
print(vocab_size)


65


In [3]:
char_to_idx = { ch:i for i, ch in enumerate(chars) }
idx_to_char = { i:ch for i, ch in enumerate(chars) }

def encode(s):
    return [char_to_idx[ch] for ch in s]

def decode(l):
    return ''.join([idx_to_char[i] for i in l])

In [4]:
decode(encode("hii"))

'hii'

In [5]:
def load_data():
    data = torch.tensor(encode(text), dtype=torch.long)
    n = int(0.9*len(data))
    train_data = data[:n]
    val_data = data[n:]
    return train_data, val_data

train_data, val_data = load_data()
print(train_data[:100])

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59])


In [6]:
device = 'cuda:3' if torch.cuda.is_available() else 'cpu'
# Add after device definition (around in[6])
print(f"CUDA Available: {torch.cuda.is_available()}")
print(f"Current device: {device}")
print(f"Default dtype: {torch.get_default_dtype()}")

# To enable BF16:
torch.set_default_dtype(torch.bfloat16)

CUDA Available: True
Current device: cuda:3
Default dtype: torch.float32


In [7]:


def load_batch(batch_size, block_size):
    random_indices = torch.randint(len(train_data) - block_size, (batch_size,))
    x = torch.stack([train_data[i:i+block_size] for i in random_indices])
    y = torch.stack([train_data[i+1:i+block_size+1] for i in random_indices])
    return x.to(device), y.to(device)

def prefetch_batches(num_batches, batch_size, block_size):
    batches = []
    xb, yb = load_batch(batch_size=batch_size*num_batches, block_size=block_size)
    # Reshape to split into num_batches
    xb = xb.view(num_batches, batch_size, -1) 
    yb = yb.view(num_batches, batch_size, -1)
    batches = list(zip(xb, yb))
    return batches

xb, yb = load_batch(batch_size=32, block_size=8)
print(xb.shape)
print(yb.shape)
print(decode(xb[0].tolist()))
print(decode(yb[0].tolist()))

torch.Size([32, 8])
torch.Size([32, 8])
YORK:
Wh
ORK:
Wha


In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class AttentionHead(nn.Module):
    def __init__(self, n_embd, head_size, block_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.tril = torch.tril(torch.ones_like(torch.zeros(block_size, block_size))).to(device)

    def forward(self, x):
        B,T,C = x.shape
        k = self.key(x)
        q = self.query(x)
        v = self.value(x)
        weight = q @ k.transpose(-2, -1) * (C**-0.5) # (B, T, C) @ (B, C, T) -> (B, T, T)
        # put upper triangular part of weight to -inf
        weight = weight.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        weight = F.softmax(weight, dim=-1)
        v = weight @ v
        return v





In [9]:
class MHA(nn.Module):
    def __init__(self, n_embd, n_heads, block_size):
        super().__init__()
        head_size = n_embd // n_heads
        self.heads = nn.ModuleList([AttentionHead(n_embd, head_size, block_size) for _ in range(n_heads)])

    def forward(self, x):
        return torch.cat([h(x) for h in self.heads], dim=-1)
    




In [10]:
class FeedForward(nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
        )

    def forward(self, x):
        return self.net(x)


In [11]:
class Block(nn.Module):
    def __init__(self, n_embd, n_heads, block_size):
        super().__init__()
        self.sa = MHA(n_embd, n_heads, block_size)
        self.ffwd = FeedForward(n_embd)
        # Keep LayerNorm in fp32
        self.ln1 = nn.LayerNorm(n_embd).to(torch.float32)
        self.ln2 = nn.LayerNorm(n_embd).to(torch.float32)

    def forward(self, x):
        # Convert to fp32 for LayerNorm, then back to bf16
        x = self.ln1(x.to(torch.float32)).to(torch.bfloat16)
        x = x + self.sa(x)
        x = self.ln2(x.to(torch.float32)).to(torch.bfloat16)
        x = x + self.ffwd(x)
        return x

In [12]:
class GPT(nn.Module):
    def __init__(self, n_embd, n_heads, n_blocks, block_size):
        super().__init__()
        self.n_embd = n_embd
        self.n_heads = n_heads
        self.n_blocks = n_blocks
        self.block_size = block_size
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_heads, block_size) for _ in range(n_blocks)])
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(T).to(device))
        x = tok_emb + pos_emb
        x = self.blocks(x)
        logits = self.lm_head(x)
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, max_new_tokens, T = 1.0):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.block_size:]
            logits, loss = self(idx_cond)
            logits = logits[:, -1, :]
            if T > 0.01:
                logits = logits / T
                probs = F.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)
            else: # greedy sampling
                idx_next = logits.argmax(dim=-1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

    # class method
    @classmethod
    def from_checkpoint(cls, filepath, n_embd, n_heads, n_blocks, block_size):
        """Load model from a checkpoint file"""
        model = cls(n_embd=n_embd, n_heads=n_heads, n_blocks=n_blocks, block_size=block_size)
        model.load_state_dict(torch.load(filepath))
        return model


In [13]:
def estimate_loss(model, batch_size, block_size):
    out = {}
    model.eval()
    for split in ['train', 'val']:
        X, Y = load_batch(batch_size, block_size)
        logits, loss = model(X, Y)
        out[split] = loss.item()
    model.train()
    return out


In [15]:
block_size = 128
n_blocks = 8
n_heads = 8
n_embd = 128


# try inference
# model = GPT(n_embd=n_embd, n_heads=n_heads, n_blocks=n_blocks, block_size=block_size).to(device)
model = GPT.from_checkpoint("model_19500.pth", n_embd, n_heads, n_blocks, block_size).to(device)
# model.compile()


  model.load_state_dict(torch.load(filepath))


Hello, how are you? How are you doing?
Is the my deeds deliver'd up from the diadem.
O God! if you be so?

KING RICHARD III:
O, nothing pi


In [19]:
message = "Hello, how are you? How are you doing?"
idx = torch.tensor(encode(message), dtype=torch.long).unsqueeze(0).to(device)
print(decode(model.generate(idx, max_new_tokens=300, T=0.3)[0].tolist()))

Hello, how are you? How are you doing?
If thou depart to come the law of them?

LUCIO:
I warrant thee, for I have heard him for his country
And manners that the sea mock'd the babe,
And the deceived of my strength weak a far
And see him show me on him.

FRIAR LAURENCE:
I will be king, and there be still.

CAMILLO:
Sir, my lord, this is 


In [16]:
lr_max = 1e-3  # Peak learning rate
lr_min = 1e-4  # Minimum learning rate
total_iters = 20000
warmup_iters = 0  # Number of warmup iterations
lr_decay_iters = total_iters  # Total number of iterations for lr decay
anneal_iters = 500

def get_lr(it):
    # Linear warmup for warmup_iters steps
    if it < warmup_iters:
        return lr_max * it / warmup_iters
    # Cosine learning rate decay
    if it > lr_decay_iters:
        return lr_min
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # Cosine decay
    return lr_min + coeff * (lr_max - lr_min)

def trapezoid_lr(it):
    if it < warmup_iters:
        return lr_max * it / warmup_iters
    if it > total_iters - anneal_iters:
        return lr_max - (it - (total_iters - anneal_iters)) * (lr_max - lr_min) / anneal_iters
    return lr_max


# Modify training loop
import math

In [17]:
# train
import time
from torch.amp import autocast
import matplotlib.pyplot as plt

lr = 1e-4

batch_size = 128
eval_interval = 100
prefetch_size = 10
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

# Lists to store metrics for plotting
losses = []
learning_rates = []

start_time = time.time()
batches = prefetch_batches(prefetch_size, batch_size, block_size)
for i in range(total_iters):

    lr = trapezoid_lr(i)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    learning_rates.append(lr)

    batch_idx = i % prefetch_size
    if batch_idx == 0:
        batches = prefetch_batches(prefetch_size, batch_size, block_size)
    xb, yb = batches[batch_idx]
    with autocast('cuda', dtype=torch.bfloat16):
        logits, loss = model(xb, yb)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
    
    if i % eval_interval == 0:
        elapsed = time.time() - start_time
        print(f"iter {i}: loss {loss.item():.4f} (elapsed: {elapsed:.2f}s) lr {lr:.4f}")
        print("example output:", decode(model.generate(idx, max_new_tokens=100)[0].tolist()))
        out = estimate_loss(model, batch_size, block_size)
        print(f"train loss: {out['train']:.4f}, val loss: {out['val']:.4f}")
    
    # save a checkpoint before annealing
    if i == total_iters - anneal_iters:
        torch.save(model.state_dict(), f"model_{i}.pth")
# Plot loss and learning rate
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8))

ax1.plot(losses)
ax1.set_title('Training Loss')
ax1.set_xlabel('Iteration')
ax1.set_ylabel('Loss')

ax2.plot(learning_rates)
ax2.set_title('Learning Rate Schedule')
ax2.set_xlabel('Iteration') 
ax2.set_ylabel('Learning Rate')

plt.tight_layout()
plt.show()


iter 0: loss 1.1986 (elapsed: 50.62s) lr 0.0010
example output: Hello, how are you? How are you doing?
son that shun's musil'd too blowshin, though noney
tones ouncious plotshion, a roblion
ttone, and t


In [35]:
message = "To be or not"
idx = torch.tensor(encode(message), dtype=torch.long).unsqueeze(0).to(device)
print (decode(model.generate(idx, max_new_tokens=300)[0].tolist()))

To be or not unnatural chance: the can bite
yet: but yet gone to all the struck in exile;
Condemns thee humour fly: where was he move
homely--as now my mother, my mouse will gaze him,
Mumusic, O see me bid prince disobed,
And weak straight-westeed were substitute
From thyself to sweat but a suit: how some stone
