In [None]:
import torch
import math
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# Seed PyTorch's global RNG so weight initialisation, batch sampling
# (torch.randint in get_batch) and token sampling (torch.multinomial in
# generate) repeat across runs of this notebook.
torch.manual_seed(1337)

<torch._C.Generator at 0x79c88c7a27b0>

In [None]:
# Get the shakespeare karpathy dataset
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2025-04-24 19:41:24--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2025-04-24 19:41:24 (30.0 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [None]:
# Read the whole corpus into a single string. The encoding is pinned so the
# result does not depend on the platform's default locale encoding.
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [None]:
num_chars = len(text)  # total characters in the corpus
print("Num of chars:", num_chars)

Num of chars: 1115394


In [None]:
# Peek at the first 1,000 characters of the corpus.
text[0:1000]

"First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou are all resolved rather to die than to famish?\n\nAll:\nResolved. resolved.\n\nFirst Citizen:\nFirst, you know Caius Marcius is chief enemy to the people.\n\nAll:\nWe know't, we know't.\n\nFirst Citizen:\nLet us kill him, and we'll have corn at our own price.\nIs't a verdict?\n\nAll:\nNo more talking on't; let it be done: away, away!\n\nSecond Citizen:\nOne word, good citizens.\n\nFirst Citizen:\nWe are accounted poor citizens, the patricians good.\nWhat authority surfeits on would relieve us: if they\nwould yield us but the superfluity, while it were\nwholesome, we might guess they relieved us humanely;\nbut they think we are too dear: the leanness that\nafflicts us, the object of our misery, is as an\ninventory to particularise their abundance; our\nsufferance is a gain to them Let us revenge this with\nour pikes, ere we become rakes: for the gods know I\nspeak this in hunger 

In [None]:
# The vocabulary is every distinct character in the corpus, in sorted order.
chars = sorted(set(text))
vocab_size = len(chars)
print(''.join(chars))
print(vocab_size)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


In [None]:
# Character-level tokenizer: map each vocabulary character to an integer id and back.
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}


def encode(s):
  """Encode a string into a list of integer token ids.

  Raises KeyError if `s` contains a character that is not in the vocabulary.
  """
  return [stoi[c] for c in s]


def decode(l):
  """Decode an iterable of integer token ids back into a string."""
  return ''.join(itos[i] for i in l)

In [None]:
# Round-trip a sample word through the tokenizer as a quick sanity check.
hello_ids = encode("hello")
print(hello_ids)
print(decode(hello_ids))

[46, 43, 50, 50, 53]
hello


In [None]:
# Tokenize the full corpus once and keep it as a 1-D int64 tensor of token ids.
data = torch.tensor(encode(text), dtype=torch.int64)
(data.shape, data.dtype)

(torch.Size([1115394]), torch.int64)

In [None]:
# Hold out the last 10% of the corpus (a contiguous tail) for validation.
n = int(len(data) * 0.9)
train, val = data[:n], data[n:]
print(len(train), len(val))

1003854 111540


In [None]:
block_size = 8
# A context of 8 characters needs 9 tokens: the 8 inputs plus the one that
# follows them, because every prefix of the context is trained to predict
# the character that comes right after it.
train[:block_size + 1]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

In [None]:
# Show every (context -> target) training pair packed into one block of text:
# the prefix x[:t+1] is trained to predict y[t], the character that follows it.
x = train[:block_size]
y = train[1:block_size+1]
for t in range(block_size):
  context = x[:t+1]
  target = y[t]
  print(f"When input is {context}, the target is: {target}")

When input is tensor([18]), the target is: 8
When input is tensor([18, 47]), the target is: 8
When input is tensor([18, 47, 56]), the target is: 8
When input is tensor([18, 47, 56, 57]), the target is: 8
When input is tensor([18, 47, 56, 57, 58]), the target is: 8
When input is tensor([18, 47, 56, 57, 58,  1]), the target is: 8
When input is tensor([18, 47, 56, 57, 58,  1, 15]), the target is: 8
When input is tensor([18, 47, 56, 57, 58,  1, 15, 47]), the target is: 8


In [None]:
# Sample random training windows so several sequences are processed in parallel.
batch_size = 4 # how many independent sequences are stacked into one batch
block_size = 8 # maximum context length (in tokens) of each sequence

def get_batch(data, batch_size, block_size, device=None):
  """Sample a random batch of (input, target) windows from a 1-D token tensor.

  Args:
    data: 1-D tensor of token ids.
    batch_size: number of windows to sample.
    block_size: length of each window.
    device: where to put the result; defaults to 'cuda' when available, else 'cpu'.

  Returns:
    (x, y), each of shape (batch_size, block_size); y is x shifted by one
    token, i.e. y[b, t] is the token that follows x[b, t] in `data`.

  Raises:
    ValueError: if `data` is too short to hold a window plus its next token.
  """
  if len(data) <= block_size:
    raise ValueError(
        f"data has {len(data)} tokens; need more than block_size={block_size}")
  if device is None:
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
  # randint's upper bound is exclusive, so the largest start is
  # len(data) - block_size - 1 and the shifted target slice stays in range.
  starts = torch.randint(0, len(data) - block_size, (batch_size,))
  x = torch.stack([data[i:i + block_size] for i in starts])
  y = torch.stack([data[i + 1:i + block_size + 1] for i in starts])
  return x.to(device), y.to(device)

In [None]:
# GPT uses only the decoder half of the transformer, so there is no
# encoder-decoder cross-attention. The first piece is the token embedding.
class InputEmbedding(nn.Module):
  """Token-id -> dense vector lookup.

  Maps an integer tensor of shape (batch size, seq len) to a float32 tensor
  of shape (batch size, seq len, d_model).
  """
  def __init__(self, vocab_size, d_model):
    super().__init__()
    self.embedding = nn.Embedding(vocab_size, d_model)
    self.d_model = d_model

  def forward(self, x):
    vectors = self.embedding(x)
    return vectors.float()

In [None]:
class PositionalEncoding(nn.Module):
  """Fixed sinusoidal positional encoding added to token embeddings.

  PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
  PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

  The table is precomputed for `max_seq_len` positions and registered as a
  buffer, so it is not trained but follows `.to(device)`. Works for both even
  and odd `d_model`.
  """
  def __init__(self, d_model, max_seq_len):
    super().__init__()
    self.max_seq_len = max_seq_len
    encoding_matrix = torch.zeros(max_seq_len, d_model)
    position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1) # (max_seq_len, 1) column, broadcasts against div_term
    # One frequency per (sin, cos) pair of dimensions: 1 / 10000^(2i / d_model),
    # written as exp(-2i * log(10000) / d_model).
    div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float) * -(math.log(10000.0) / d_model))
    encoding_matrix[:, 0::2] = torch.sin(position * div_term) # even dims: start at 0, step 2
    # With an odd d_model there is one fewer odd dimension than even ones,
    # so only the first d_model // 2 frequencies are used for the cosines.
    encoding_matrix[:, 1::2] = torch.cos(position * div_term[: d_model // 2]) # odd dims: start at 1, step 2
    self.register_buffer("pos_encoding", encoding_matrix)

  def forward(self, x):
    """Add the encodings for positions 0..seq_len-1 to x.

    Assumes x is (batch size, seq len, d_model); raises ValueError when
    seq len exceeds the precomputed max_seq_len.
    """
    seq_len = x.size(1)
    if seq_len > self.pos_encoding.size(0):
      raise ValueError(
          f"sequence length {seq_len} exceeds max_seq_len {self.pos_encoding.size(0)}")
    return x + self.pos_encoding[:seq_len]

In [None]:
class MultiHeadAttention(nn.Module):
  """Causal (masked) multi-head self-attention with an optional key/value cache.

  Args:
    d_model: model/embedding dimension; must be divisible by n_heads.
    n_heads: number of attention heads; each head works on d_model // n_heads dims.
    max_seq_len: longest key sequence the causal mask supports.
    dropout_rate: dropout probability applied to the attention weights.
  """
  def __init__(self, d_model, n_heads, max_seq_len, dropout_rate):
    super().__init__()
    if d_model % n_heads != 0:
      raise ValueError(f"d_model ({d_model}) must be divisible by n_heads ({n_heads})")
    self.n_heads = n_heads
    self.d_head = d_model // n_heads
    self.d_model = d_model
    self.max_seq_len = max_seq_len
    self.dropout_rate = dropout_rate
    self.W_qkv = nn.Linear(d_model, 3*d_model) # one fused projection for q, k and v
    self.dropout = nn.Dropout(dropout_rate)
    self.W_o = nn.Linear(d_model, d_model)
    # Lower-triangular causal mask of shape (1, 1, max_seq_len, max_seq_len):
    # entry [i, j] is 1 iff key position j is at or before query position i.
    # Kept under the buffer name "bias" so existing state_dicts still load.
    # register_buffer returns None, so its result is not assigned to anything.
    self.register_buffer("bias", torch.tril(torch.ones(max_seq_len, max_seq_len))
                    .view(1, 1, max_seq_len, max_seq_len))

  def ScaledDotProductAttention(self, q, k, v, mask = None):
    """softmax(q k^T / sqrt(d_k)) v, with scores where mask == 0 set to -inf.

    q: (batch, heads, q_len, d_head); k, v: (batch, heads, k_len, d_head);
    mask must broadcast to (batch, heads, q_len, k_len).
    Returns (batch, heads, q_len, d_head).
    """
    d_k = k.size(-1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k) # (batch, heads, q_len, k_len)
    if mask is not None:
      scores = scores.masked_fill(mask == 0, float('-inf'))
    attn = self.dropout(F.softmax(scores, dim=-1)) # normalise over keys for each query
    return attn @ v

  def _split_heads(self, t):
    """(batch, seq, d_model) -> (batch, n_heads, seq, d_head)."""
    batch_size, seq_len, _ = t.shape
    return t.reshape(batch_size, seq_len, self.n_heads, self.d_head).transpose(1, 2)

  def forward(self, x, kv_cache=None):
    """Apply causal self-attention to x.

    Args:
      x: (batch size, seq len, d_model) input.
      kv_cache: optional dict {"k": ..., "v": ...} of per-head keys/values,
        each (batch, n_heads, past_len, d_head), as returned by an earlier
        call. It is only consumed for single-token inputs (incremental
        decoding); for longer inputs, or when it is empty, it is ignored.

    Returns:
      (output, kv_cache_out): output has shape (batch size, seq len, d_model);
      kv_cache_out holds the keys/values of every attended position.

    Raises:
      ValueError: if the number of key positions exceeds max_seq_len.
    """
    batch_size, seq_len_q, _ = x.shape

    q, k, v = torch.split(self.W_qkv(x), self.d_model, dim=-1)
    q, k, v = self._split_heads(q), self._split_heads(k), self._split_heads(v)

    # Incremental decoding: prepend cached keys/values of earlier positions
    # along the sequence axis. An empty dict is treated the same as no cache.
    if kv_cache and seq_len_q == 1:
      k = torch.cat((kv_cache['k'], k), dim=-2)
      v = torch.cat((kv_cache['v'], v), dim=-2)
    kv_cache_out = {"k": k, "v": v}

    seq_len_k = k.size(-2)
    if seq_len_k > self.max_seq_len:
      raise ValueError(
          f"attention length {seq_len_k} exceeds max_seq_len {self.max_seq_len}")
    # The queries are the last seq_len_q of the seq_len_k positions, so query i
    # uses row (seq_len_k - seq_len_q + i) of the causal mask.
    mask = self.bias[:, :, seq_len_k - seq_len_q:seq_len_k, :seq_len_k]

    attn = self.ScaledDotProductAttention(q, k, v, mask) # (batch, n_heads, seq_len_q, d_head)
    attn_concat = attn.transpose(1, 2).reshape(batch_size, seq_len_q, self.d_model) # merge heads back
    return self.W_o(attn_concat), kv_cache_out

In [None]:
class FeedForward(nn.Module):
  """Position-wise feed-forward network: Linear -> GELU -> Dropout -> Linear.

  Expands each position from d_model to d_ff features and projects back to
  d_model. Dropout is applied to the hidden activations only; any dropout on
  the output is left to the caller.
  """
  def __init__(self, d_model, d_ff, dropout_rate):
    super().__init__()
    self.linear1 = nn.Linear(d_model, d_ff)
    self.linear2 = nn.Linear(d_ff, d_model)
    self.dropout = nn.Dropout(dropout_rate)

  def forward(self, x):
    hidden = self.dropout(F.gelu(self.linear1(x)))
    return self.linear2(hidden)

In [None]:
class LayerNorm(nn.Module):
  """Layer normalisation over the last dimension with a learnable scale and shift.

  y = gamma * (x - mean) / sqrt(var + eps) + beta, where mean and the biased
  (divide-by-N) variance are computed over the last dimension of x. This is
  the same definition torch.nn.LayerNorm uses.

  Args:
    d_model: size of the last dimension of the inputs.
    eps: small constant added to the variance for numerical stability.
  """
  def __init__(self, d_model, eps=1e-6):
    super().__init__()
    self.gamma = nn.Parameter(torch.ones(d_model))
    self.beta = nn.Parameter(torch.zeros(d_model))
    self.eps = eps

  def forward(self, x):
    mean = x.mean(-1, keepdim=True) # keepdim so the statistics broadcast against x
    # Biased variance (divide by N). An unbiased N-1 estimate would leave the
    # output with variance (N-1)/N instead of 1 and is NaN when d_model == 1.
    var = (x - mean).pow(2).mean(-1, keepdim=True)
    return self.gamma * (x - mean) / torch.sqrt(var + self.eps) + self.beta

# In batch normalization, each feature is normalized using statistics computed across the samples in the batch (axis 0).
# In layer normalization, each sample is normalized using statistics computed across its own features (the last axis).

In [None]:
class DecoderLayer(nn.Module):
  """One pre-norm GPT decoder block: masked self-attention, then feed-forward.

  Each sub-layer reads a LayerNorm-ed copy of the residual stream, and its
  dropped-out output is added back onto that stream:
    x = x + Dropout(Attention(LN1(x)))
    x = x + Dropout(FFN(LN2(x)))
  """
  def __init__(self, d_model, n_heads, d_ff, dropout_rate, max_seq_len):
    super().__init__()
    self.masked_attention = MultiHeadAttention(d_model, n_heads, max_seq_len, dropout_rate)
    self.norm1 = LayerNorm(d_model)
    self.ff = FeedForward(d_model, d_ff, dropout_rate)
    self.norm2 = LayerNorm(d_model)
    self.dropout = nn.Dropout(dropout_rate)

  def forward(self, x, kv_cache=None):
    """Return (updated residual stream, kv cache from the attention sub-layer)."""
    attn_out, new_cache = self.masked_attention(self.norm1(x.float()), kv_cache)
    x = x + self.dropout(attn_out)
    ff_out = self.ff(self.norm2(x.float()))
    x = x + self.dropout(ff_out)
    return x, new_cache

In [None]:
class MiniGPT(nn.Module):
  """Character-level decoder-only transformer (GPT-style).

  token ids -> embedding + sinusoidal positions -> n_layers pre-norm decoder
  blocks -> final LayerNorm -> linear projection to vocabulary logits.
  """
  def __init__(self, d_model, d_ff, n_heads, dropout_rate, vocab_size, n_layers, max_seq_len):
    super().__init__()
    self.embedding = InputEmbedding(vocab_size, d_model)
    self.pos_encoding = PositionalEncoding(d_model, max_seq_len)
    self.decoders = nn.ModuleList([DecoderLayer(d_model, n_heads, d_ff, dropout_rate, max_seq_len) for _ in range(n_layers)])

    self.final_layernorm = LayerNorm(d_model)

    self.linear_output = nn.Linear(d_model, vocab_size)
    self.max_seq_len = max_seq_len

  @staticmethod
  def _layer_cache(decoders_kv_cache, i):
    """Return layer i's cache entry, or None when there is no cache or no entry for i."""
    if not decoders_kv_cache:
      return None
    try:
      return decoders_kv_cache[i]
    except (KeyError, IndexError):
      return None

  def forward(self, x, targets=None, decoders_kv_cache=None):
    """Compute next-token logits, plus the loss when targets are given.

    Args:
      x: (batch size, seq len) tensor of token ids.
      targets: optional (batch size, seq len) tensor of next-token ids.
      decoders_kv_cache: optional {layer index: {"k": ..., "v": ...}} mapping
        as returned by an earlier call. Missing layers or an empty dict mean
        "no cache". The attention layers only consume a cache for single-token
        inputs; that token is then placed at position past_len (the number of
        cached positions) instead of position 0.

    Returns:
      (logits, loss, cache) when targets is given, with logits flattened to
      (batch size * seq len, vocab size); otherwise (logits, cache) with
      logits of shape (batch size, seq len, vocab size).

    Raises:
      ValueError: if past_len + seq len exceeds max_seq_len.
    """
    batch_size, seq_len = x.shape

    # Position offset for incremental decoding (single new token + cache).
    past_len = 0
    first_layer_cache = self._layer_cache(decoders_kv_cache, 0)
    if seq_len == 1 and first_layer_cache:
      past_len = first_layer_cache['k'].size(-2)
    if past_len + seq_len > self.max_seq_len:
      raise ValueError(
          f"sequence of {past_len + seq_len} tokens exceeds max_seq_len {self.max_seq_len}")

    token_embedding = self.embedding(x)
    if past_len == 0:
      hidden = self.pos_encoding(token_embedding) # (batch size, seq len, d_model)
    else:
      # PositionalEncoding.forward always starts at position 0, but the new
      # token actually sits at position past_len.
      hidden = token_embedding + self.pos_encoding.pos_encoding[past_len:past_len + seq_len]

    updated_decoders_cache = {} # layer index -> {"k": ..., "v": ...}
    for i, decoder in enumerate(self.decoders):
      hidden, updated_decoders_cache[i] = decoder(hidden, self._layer_cache(decoders_kv_cache, i))

    logits = self.linear_output(self.final_layernorm(hidden))

    if targets is None:
      return logits, updated_decoders_cache
    vocabulary_size = logits.size(-1)
    logits = logits.reshape(batch_size * seq_len, vocabulary_size)
    loss = F.cross_entropy(logits, targets.reshape(batch_size * seq_len))
    return logits, loss, updated_decoders_cache

  @torch.no_grad()
  def generate(self, idx, max_new_tokens):
    """Autoregressively sample max_new_tokens tokens and append them to idx.

    Args:
      idx: (batch size, seq len) tensor of prompt token ids.
      max_new_tokens: number of tokens to sample.

    Returns:
      (batch size, seq len + max_new_tokens) tensor: the prompt followed by
      the sampled tokens.

    The model runs in eval mode (no dropout) while sampling and is restored to
    its previous mode afterwards. Only the last max_seq_len tokens are fed to
    the model at each step, so the prompt plus generated text may grow beyond
    max_seq_len. The kv cache is not used: with absolute positional encodings,
    sliding the context window shifts every token's position, which makes
    cached keys/values stale.
    """
    was_training = self.training
    self.eval()
    try:
      for _ in range(max_new_tokens):
        idx_cond = idx[:, -self.max_seq_len:] # crop the context to what the model supports
        logits, _ = self(idx_cond) # (batch size, context len, vocab size)
        probs = F.softmax(logits[:, -1, :], dim=-1) # distribution over the next token
        idx_next = torch.multinomial(probs, num_samples=1)
        idx = torch.cat((idx, idx_next), dim=1) # (batch size, seq len + 1)
    finally:
      self.train(was_training)
    return idx




In [None]:
# Train MiniGPT with AdamW, periodically reporting the loss on one random
# validation batch (a single-batch estimate, so it is noisy).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = MiniGPT(
    d_model=128,
    d_ff=512,
    n_heads=4,
    dropout_rate=0.1,
    vocab_size=vocab_size,
    n_layers=4,
    max_seq_len=128
).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.1, betas=(0.9, 0.95))

MAX_STEPS = 50000       # adjust iterations as needed
EVAL_INTERVAL = 1000    # evaluate on a validation batch every EVAL_INTERVAL steps
TRAIN_BATCH_SIZE = 64
CONTEXT_LEN = 128       # must not exceed the model's max_seq_len

# Early stopping on the validation loss. Disabled by default so the full
# MAX_STEPS run happens; set EARLY_STOPPING = True to enable it.
EARLY_STOPPING = False
patience = 5
best_val_loss = float('inf')
patience_counter = 0
best_model_state = None

for step in range(MAX_STEPS):
    # Training step (.to(device) is a no-op when the batch is already there)
    x, y = get_batch(train, batch_size=TRAIN_BATCH_SIZE, block_size=CONTEXT_LEN)
    x, y = x.to(device), y.to(device)

    model.train()
    _, loss, _ = model(x, y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Validation step
    if step % EVAL_INTERVAL == 0:
        model.eval()
        with torch.no_grad():
            val_x, val_y = get_batch(val, batch_size=TRAIN_BATCH_SIZE, block_size=CONTEXT_LEN)
            val_x, val_y = val_x.to(device), val_y.to(device)
            _, val_loss, _ = model(val_x, val_y)

        print(f"Step {step} | Train Loss: {loss.item():.4f} | Val Loss: {val_loss.item():.4f}")

        if EARLY_STOPPING:
            if val_loss.item() < best_val_loss:
                best_val_loss = val_loss.item()
                patience_counter = 0
                # state_dict() returns references to the live tensors, so a
                # shallow dict copy would keep changing as training continues;
                # clone each tensor to snapshot the current weights.
                best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    print(f"Early stopping triggered at step {step}")
                    break

# When early stopping was used, restore the best weights seen.
if EARLY_STOPPING and best_model_state is not None:
    model.load_state_dict(best_model_state)

In [None]:
# Sample a continuation of a prompt. The training loop can leave the model in
# train mode, so switch to eval mode first to disable dropout while sampling.
prompt_text = "What is the answer to life, the universe, and everything?"
prompt_ids = torch.tensor([encode(prompt_text)], dtype=torch.long, device=device) # shape (1, prompt length)
model.eval()
generated_idx = model.generate(prompt_ids, max_new_tokens=100)
output = decode(generated_idx[0].tolist())
print(output)