In [45]:
import torch
import math
import torch.nn as nn
import torch.nn.functional as F

In [46]:
torch.manual_seed(1337)

<torch._C.Generator at 0x7eccac7a22f0>

In [47]:
# Get the shakespeare karpathy dataset
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2025-05-25 22:37:34--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.109.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt.2’


2025-05-25 22:37:34 (123 MB/s) - ‘input.txt.2’ saved [1115394/1115394]



In [48]:
with open('input.txt', 'r') as f:
    text = f.read()

In [49]:
print("Num of chars:", len(text))

Num of chars: 1115394


In [50]:
text[:1000]

"First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou are all resolved rather to die than to famish?\n\nAll:\nResolved. resolved.\n\nFirst Citizen:\nFirst, you know Caius Marcius is chief enemy to the people.\n\nAll:\nWe know't, we know't.\n\nFirst Citizen:\nLet us kill him, and we'll have corn at our own price.\nIs't a verdict?\n\nAll:\nNo more talking on't; let it be done: away, away!\n\nSecond Citizen:\nOne word, good citizens.\n\nFirst Citizen:\nWe are accounted poor citizens, the patricians good.\nWhat authority surfeits on would relieve us: if they\nwould yield us but the superfluity, while it were\nwholesome, we might guess they relieved us humanely;\nbut they think we are too dear: the leanness that\nafflicts us, the object of our misery, is as an\ninventory to particularise their abundance; our\nsufferance is a gain to them Let us revenge this with\nour pikes, ere we become rakes: for the gods know I\nspeak this in hunger 

In [51]:
chars = sorted(list(set(text)))
vocab_size = len(chars) # our vocab size is all the unique characters used in the Shakespeare data
print(''.join(chars))
print(vocab_size)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


In [52]:
# Now we want to tokenize our characters by mapping each character to a number
stoi = {ch:i for i, ch in enumerate(chars)}
itos = {i:ch for i, ch in enumerate(chars)}

encode = lambda s: [stoi[c] for c in s] # encoder takes a string, outputs a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder takes a list of integers, outputs a string

In [53]:
print(encode("hello"))
print(decode(encode("hello")))

[46, 43, 50, 50, 53]
hello


In [54]:
# Now encode the entire text and save it to a tensor
data = torch.tensor(encode(text), dtype=torch.long)
data.shape, data.dtype

(torch.Size([1115394]), torch.int64)

In [55]:
#split the data into train and val sets
n = int(0.9*len(data))
train = data[:n]
val = data[n:]
print(len(train), len(val))

1003854 111540


In [56]:
block_size = 8
train[:block_size+1]
# We use a sequence length of 8, but we take the data up to and including sequence length# + 1
# because we want to know for 8 chars, what the pred for next char should be

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

In [57]:
x = train[:block_size]
y = train[1:block_size+1]
for t in range(block_size):
  context = x[:t+1]
  target = y[t]
  print(f"When input is {context}, the target is: {target}")

When input is tensor([18]), the target is: 47
When input is tensor([18, 47]), the target is: 56
When input is tensor([18, 47, 56]), the target is: 57
When input is tensor([18, 47, 56, 57]), the target is: 58
When input is tensor([18, 47, 56, 57, 58]), the target is: 1
When input is tensor([18, 47, 56, 57, 58,  1]), the target is: 15
When input is tensor([18, 47, 56, 57, 58,  1, 15]), the target is: 47
When input is tensor([18, 47, 56, 57, 58,  1, 15, 47]), the target is: 58


In [58]:
# Now we want to generate batches to process in parallel
batch_size = 4 # how many batches do we want to process in parallel
block_size = 8 # maximum sequence length for predictions

def get_batch(data, batch_size, block_size):
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
  indices = torch.randint(0, len(data)-block_size, (batch_size, )) # Creating random starting indices for however many batches we want
  x = torch.stack([data[i : i + block_size] for i in indices])
  y = torch.stack([data[i + 1 : i + block_size + 1] for i in indices])
  return x.to(device), y.to(device)

In [59]:
# GPT uses a decoder-only structure from the transformer, which doesn't include cross-attention
# First we are going to start with the input embedding
class InputEmbedding(nn.Module):
  def __init__(self, vocab_size, d_model):
    super().__init__()
    self.d_model = d_model
    self.embedding = nn.Embedding(vocab_size, d_model)

  def forward(self, x):
    return self.embedding(x)

# Takes tensor of shape (batch size, seq len) -> (batch size, seq len, embedding dim)

In [60]:
class RoPE(nn.Module):
  def __init__(self, d_rope, max_seq_len):
    super().__init__()
    indices = torch.arange(d_rope) # d_rope,
    pairs = indices // 2
    positions = torch.arange(max_seq_len).unsqueeze(1) # M x 1
    thetas = 10000 ** (-2 * pairs / d_rope) # d_rope,
    rotations = positions*thetas
    self.cos_rotations = torch.cos(rotations)
    self.sin_rotations = torch.sin(rotations)
  def forward(self, x):
    cos_vals = self.cos_rotations[:x.size(-2)].to(x.device)
    sin_vals = self.sin_rotations[:x.size(-2)].to(x.device)
    x_swapped = torch.zeros_like(x)
    x_swapped[:,:,0::2] = -x[:,:,1::2]
    x_swapped[:,:,1::2] = x[:,:,0::2]
    rotated_x = x*cos_vals + x_swapped*sin_vals
    return rotated_x

In [61]:
class RMSNorm(nn.Module):
  def __init__(self, d_rms, eps=1e-6):
    super().__init__()
    self.gamma = nn.Parameter(torch.ones(d_rms))
    self.eps = eps

  def forward(self, x):
    rms = torch.sqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps)
    return (x/rms) * self.gamma

In [62]:
# Deepseek Multi Latent Attention without ROPE
class MultiLatentAttention(nn.Module):
  def __init__(self, d_model, d_rope, n_heads, d_latent_q, d_latent_kv, max_seq_len, dropout_rate):
    super().__init__()
    self.d_model = d_model
    self.d_rope = d_rope
    self.d_latent_kv = d_latent_kv
    self.d_latent_q = d_latent_q
    self.n_heads = n_heads
    # nn Linear takes in_dim, out_dim as arguments
    self.W_DQ = nn.Linear(in_features=d_model, out_features=d_latent_q) # down projection to project hidden state into latent space
    self.W_UQ = nn.Linear(d_latent_q, d_model) # up projection back to hidden state
    self.W_DKV = nn.Linear(d_model, d_latent_kv) # down projection for shared KV input
    self.W_UK = nn.Linear(d_latent_kv, d_model) # up projection for K
    self.W_UV = nn.Linear(d_latent_kv, d_model) # up projection for V
    # Decoupled rope
    self.W_QR = nn.Linear(d_latent_q, d_rope*n_heads) # Problem: initializing with shape d_rope * n_heads
    self.rope_q = RoPE(d_rope*n_heads, max_seq_len)
    self.W_KR = nn.Linear(d_model, d_rope)
    self.rope_k = RoPE(d_rope, max_seq_len)
    # Output matrix for concatenated attention
    self.W_O = nn.Linear(d_model, d_model)
    self.dropout = nn.Dropout(dropout_rate)
    self.register_buffer("bias", torch.tril(torch.ones(max_seq_len, max_seq_len)).unsqueeze(0).unsqueeze(1))
    self.rms_q = RMSNorm(d_latent_q)
    self.rms_kv = RMSNorm(d_latent_kv)
    self.q_scale = nn.Parameter(torch.tensor(1.0))
    self.kv_scale = nn.Parameter(torch.tensor(1.0))

  def ScaledDotProductAttention(self, q, k, v, mask = None):
    d_k = k.size(-1) # Q and K have shape (batch_size, n heads, seq_len, head dim)
    scores = q@k.transpose(-2, -1) / math.sqrt(d_k) # (batch_size, n heads, q_len, k_len)
    #print("Scores shape:", scores.shape)
    if mask is not None:
      scores = scores.masked_fill(mask == 0, float('-inf'))
      #print("Mask shape:", mask.shape)
    attn = self.dropout(F.softmax(scores, dim=-1)) # use last dim because we want to sum across keys to get probs for each query
    return attn@v # (batch_size, n_heads, T, d_head)

  def forward(self, x):
    c_q = self.W_DQ(x) # B, T, d_latent_q
    c_q = self.q_scale * self.rms_q(c_q) # B, T, d_latent_q

    q_latent = self.W_UQ(c_q) # B, T, d_model
    q_rope = self.rope_q(self.W_QR(c_q)) # B, T, d_rope*n_heads

    B, T, _ = q_latent.shape
    d_head = self.d_model // self.n_heads
    q_latent = q_latent.reshape(B, T, self.n_heads, d_head).transpose(1, 2) # B, n_heads, T, d_head

    B, T, _ = q_rope.shape
    q_rope = q_rope.reshape(B, T, self.n_heads, self.d_rope).transpose(1, 2) # B, n_heads, T, d_rope

    q = torch.cat((q_latent, q_rope), dim=-1) # B, n_heads, T, d_rope + d_head

    c_kv = self.W_DKV(x) # B, T, d_latent_kv
    c_kv = self.kv_scale * self.rms_kv(c_kv)

    k_latent = self.W_UK(c_kv) #B, T, d_model
    k_rope = self.rope_k(self.W_KR(x)) # B, T, d_rope

    B, T, _ = k_latent.shape
    k_latent = k_latent.reshape(B, T, self.n_heads, d_head).transpose(1, 2) # B, n_heads, T, d_head

    B, T, _ = k_rope.shape
    k_rope = k_rope.unsqueeze(1).expand(-1, self.n_heads, -1, -1) # B, n_heads, T, d_rope

    k = torch.cat((k_latent, k_rope), dim=-1) # B, n_heads, T, d_rope + d_head

    v = self.W_UV(c_kv)
    B, T, _ = v.shape
    v = v.reshape(B, T, self.n_heads, d_head).transpose(1, 2)

    mask = self.bias[:, :, :T, :T]

    attn = self.ScaledDotProductAttention(q, k, v, mask)
    # B, n_heads, T, T @ v (B, n_heads, T, d_head)
    # B, n_heads, T, d_head

    attn = attn.transpose(1, 2).reshape(B, T, self.d_model)
    mla_output = self.W_O(attn)

    return mla_output# shape (B, T, d_model)

In [63]:
class FeedForward(nn.Module):
  def __init__(self, d_model, d_ff, dropout_rate):
    super().__init__()
    self.linear1 = nn.Linear(d_model, d_ff)
    self.linear2 = nn.Linear(d_ff, d_model)
    self.dropout = nn.Dropout(dropout_rate)

  def forward(self, x):
    z1 = self.linear1(x)
    a1 = self.dropout(F.gelu(z1))
    z2 = self.linear2(a1)
    return z2

In [64]:
class DecoderLayer(nn.Module):
  def __init__(self, d_model, d_rope, n_heads, d_latent_q, d_latent_kv, d_ff, dropout_rate, max_seq_len):
    super().__init__()
    self.ml_attention = MultiLatentAttention(d_model, d_rope, n_heads, d_latent_q, d_latent_kv, max_seq_len, dropout_rate)
    self.norm1 = RMSNorm(d_model)
    self.ff = FeedForward(d_model, d_ff, dropout_rate)
    self.norm2 = RMSNorm(d_model)
    self.dropout = nn.Dropout(dropout_rate)

  def forward(self, x):
    norm_x = self.norm1(x)
    x = x + self.dropout(self.ml_attention(norm_x))
    norm_x = self.norm2(x)
    x = x + self.dropout(self.ff(norm_x))
    return x

In [65]:
class MiniGPT(nn.Module):
  def __init__(self, d_model, d_rope, d_ff, n_heads, d_latent_q, d_latent_kv, dropout_rate, vocab_size, n_layers, max_seq_len):
    super().__init__()
    self.embedding = InputEmbedding(vocab_size, d_model)
    #self.pos_encoding = PositionalEncoding(d_model, max_seq_len)
    self.decoders = nn.ModuleList([DecoderLayer(d_model, d_rope, n_heads, d_latent_q, d_latent_kv, d_ff, dropout_rate, max_seq_len) for _ in range(n_layers)])

    self.final_rmsnorm = RMSNorm(d_model)

    self.linear_output = nn.Linear(d_model, vocab_size)
    self.max_seq_len = max_seq_len

  def forward(self, x, targets=None):

    batch_size, seq_len = x.shape

    final_input = self.embedding(x)
    #final_input = self.pos_encoding(token_embedding) # batch size, seq_len, embedding dim

    for decoder in self.decoders:
      final_input = decoder(final_input)

    final_input = self.final_rmsnorm(final_input)
    logits = self.linear_output(final_input)

    if targets is not None:
      batch_size, seq_len, vocabulary_size = logits.shape
      logits = logits.reshape(batch_size*seq_len, vocabulary_size)
      targets = targets.reshape(batch_size*seq_len)
      loss = F.cross_entropy(logits, targets)
      return logits, loss
    else:
      return logits

  @torch.no_grad()
  def generate(self, idx, max_new_tokens):

    for _ in range(max_new_tokens):
      idx = idx[:, -self.max_seq_len:]
      logits = self.forward(idx) # shape batch size, seq len, vocab size
      logits_final_token = logits[:, -1, :] # shape of batch size, vocab size (we selected predictions for the next token position)
      probs = F.softmax(logits_final_token, dim=-1) # now we want probabilities of the next token across the vocab size

      idx_next = torch.multinomial(probs, num_samples = 1)
      idx = torch.cat((idx, idx_next), dim=1) # shape batch size, seq len + 1

    return idx




In [66]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = MiniGPT(
    d_model=128,
    d_rope=32,
    d_ff=512,
    n_heads=4,
    d_latent_q=64,
    d_latent_kv=32,
    dropout_rate=0.1,
    vocab_size=vocab_size,
    n_layers=4,
    max_seq_len=200
).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.1, betas=(0.9, 0.95))

# Early stopping parameters
patience = 5
best_val_loss = float('inf')
patience_counter = 0
best_model_state = None

for step in range(10000):  # adjust iterations as needed
    # Training step
    x, y = get_batch(train, batch_size=64, block_size=128)
    x, y = x.to(device), y.to(device)

    model.train()
    logits, loss = model(x, y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Validation step
    if step % 1000 == 0:  # Check validation every 100 steps
        model.eval()
        with torch.no_grad():
            val_x, val_y = get_batch(val, batch_size=64, block_size=128)
            val_x, val_y = val_x.to(device), val_y.to(device)
            _, val_loss= model(val_x, val_y)

        print(f"Step {step} | Train Loss: {loss.item():.4f} | Val Loss: {val_loss.item():.4f}")
        '''
        # Early stopping check
        if val_loss.item() < best_val_loss:
            best_val_loss = val_loss.item()
            patience_counter = 0
            # Save the best model
            best_model_state = model.state_dict().copy()
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping triggered at step {step}")
                break
        '''
    elif step % 1000 == 0:
        print(f"Step {step} | Train Loss: {loss.item():.4f}")

Step 0 | Train Loss: 4.3784 | Val Loss: 3.8745
Step 1000 | Train Loss: 1.6093 | Val Loss: 1.7442
Step 2000 | Train Loss: 1.4746 | Val Loss: 1.6176
Step 3000 | Train Loss: 1.4688 | Val Loss: 1.5053
Step 4000 | Train Loss: 1.4040 | Val Loss: 1.5438
Step 5000 | Train Loss: 1.3468 | Val Loss: 1.5248
Step 6000 | Train Loss: 1.3524 | Val Loss: 1.4962
Step 7000 | Train Loss: 1.3116 | Val Loss: 1.5136
Step 8000 | Train Loss: 1.3340 | Val Loss: 1.4856
Step 9000 | Train Loss: 1.3306 | Val Loss: 1.4550


In [71]:
prompt = "To be, or not to be that is"
prompt = torch.tensor([encode(prompt)], dtype = torch.long).to(device)
prompt.shape
generated_idx = model.generate(prompt, max_new_tokens=120)
output = decode(generated_idx[0].tolist())
print(output)

To be, or not to be that is deem,
Or, all the time heirs. Accursed and thus
But banish one our bookle Angelo?

GREEN:
Bid how you give yondesit swo
