<a href="https://colab.research.google.com/github/G0nkly/pytorch_sandbox/blob/main/gpts/DIY_AK_GPT2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
#####################
# DATASET RETRIEVAL #
#####################

In [3]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2025-09-21 22:13:27--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2025-09-21 22:13:27 (16.6 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [4]:
with open("input.txt", mode="r") as f:
  text = f.read()

print(text[:10])

First Citi


In [5]:
###################
# HYPERPARAMETERS #
###################

In [6]:
block_size = 8
batch_size = 32
device = "cuda" if torch.cuda.is_available() else "cpu"
embed_dim = 32
n_epochs = 10
learning_rate = 1e-04
test_epochs = 200
eval_iters = 100

In [7]:
###############################
# TOKENIZATION AND DATALOADER #
###############################

In [8]:
vocab = list(sorted(set(text)))
text_to_num = {v:k for k,v in enumerate(vocab)}
num_to_text = {k:v for k,v in enumerate(vocab)}
vocab_size = len(vocab)

In [9]:
encode = lambda text: [text_to_num[t] for t in text]
decode = lambda numbers: "".join([num_to_text[n] for n in numbers])

In [10]:
decode(encode("Haubi"))

'Haubi'

In [11]:
data = torch.tensor(encode(text))
train_idx = int(0.9 * len(data))
train = data[:train_idx]
test = data[train_idx:]
len(train), len(test)

(1003854, 111540)

In [12]:
def get_batch(split):
  dataset = train if split == "train" else test
  idx = torch.randint(0, len(train) - block_size, (batch_size,))
  X = torch.stack([train[i:i +  block_size] for i in idx])
  Y = torch.stack([train[i + 1: i + block_size + 1] for i in idx])
  X, Y = X.to(device), Y.to(device)

  return X, Y

In [13]:
X, Y = get_batch("train")
X[1], Y[1]

(tensor([58, 46, 39, 58,  1, 21,  1, 57]),
 tensor([46, 39, 58,  1, 21,  1, 57, 46]))

In [14]:
@torch.no_grad()
def evaluate_model(model):
  out = {}
  model.eval()
  for split in ["train", "test"]:
    test_loss = torch.zeros(test_epochs)
    for epch in range(test_epochs):
      Xtst_b, Ytst_b = get_batch("test")
      _, loss = model(Xtst_b, Ytst_b)

      test_loss[epch] = loss.item()

    out[split] = test_loss.mean(dim=-1)

  model.train()

  return out

In [15]:
################
# BUILD MODELS #
################

In [28]:
class SimpleBiGramModel(nn.Module):

  def __init__(self):
    super().__init__()
    self.emb = nn.Embedding(vocab_size, vocab_size)

  def forward(self, x, targets=None):
    logits = self.emb(x)
    if targets is not None:
      B, T, C = logits.shape
      logits = logits.view(B*T, C)
      targets = targets.view(B*T)
      loss = F.cross_entropy(logits, targets)
      return logits, loss
    return logits

  def generate(self, start_token, length):
    """
    start_token: (B, 1) tensor with the initial token(s)
    length: how many new tokens to generate
    """
    for _ in range(length):
      # only keep the last token
      input = start_token[:, -block_size:]   # shape (B, 1)

      # forward pass → logits for vocab
      logits = self(input)          # (B, 1, C)

      # take the logits at the last position
      print(logits)
      logits = logits[:, -1, :]     # (B, C)

      # turn into probabilities
      probs = F.softmax(logits, dim=-1)  # (B, C)

      # sample next token for each batch
      next_token = torch.multinomial(probs, num_samples=1)  # (B, 1)

      # append to sequence
      start_token = torch.cat((start_token, next_token), dim=1)

    return start_token


In [29]:
model = SimpleBiGramModel()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [30]:
for epoch in range(1000):

  if epoch % eval_iters == 0 or epoch == n_epochs - 1:
    out = evaluate_model(model)
    print(f"Epoch: {epoch}, train loss: {out["train"]:.4f}, test loss: {out["test"]:.4f}")

  X_batch, Y_batch = get_batch("train")
  logits, loss = model(X_batch, Y_batch)

  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

  X_test_batch, Y_test_batch = get_batch("test")
  _, test_loss = model(X_test_batch, Y_test_batch)



Epoch: 0, train loss: 4.6192, test loss: 4.6172
Epoch: 9, train loss: 4.6187, test loss: 4.6210
Epoch: 100, train loss: 4.6095, test loss: 4.6135
Epoch: 200, train loss: 4.6016, test loss: 4.5922
Epoch: 300, train loss: 4.5923, test loss: 4.5869
Epoch: 400, train loss: 4.5710, test loss: 4.5753
Epoch: 500, train loss: 4.5635, test loss: 4.5519
Epoch: 600, train loss: 4.5540, test loss: 4.5532
Epoch: 700, train loss: 4.5393, test loss: 4.5432
Epoch: 800, train loss: 4.5252, test loss: 4.5361
Epoch: 900, train loss: 4.5276, test loss: 4.5125


In [31]:
generated_output = model.generate(torch.randint(0, 65, (1, 8)), 30)
output_list = generated_output.view(-1).tolist()
"".join([decode(output_list)])

tensor([[[ 9.3099e-01, -2.7371e-01,  3.6385e-01,  3.5315e-01, -1.3422e+00,
           5.3615e-02,  4.3333e-02, -2.6845e+00,  2.3390e+00,  1.4287e-01,
          -5.8477e-01,  8.2600e-01,  6.4781e-01,  2.6813e-01, -1.0676e+00,
          -1.0078e+00, -1.0609e+00, -1.3601e+00, -5.0099e-01, -2.1875e+00,
           3.2322e-01,  8.5204e-02,  1.4742e+00,  6.2716e-01, -1.0281e+00,
           4.8878e-01, -4.1130e-01,  1.6078e-01, -2.3826e-01, -6.4994e-01,
           1.7546e-01,  1.2658e+00,  9.8312e-01,  1.2014e-01, -2.5838e-01,
          -1.6622e+00, -1.3313e+00,  3.8422e-02, -1.2182e+00, -1.0247e+00,
          -8.8553e-01,  1.2659e+00,  2.6490e-01,  6.7567e-01, -1.5723e+00,
          -1.3754e-01,  7.8832e-01,  6.2039e-02, -1.9650e+00, -9.5399e-02,
          -1.6210e-01,  2.8511e-01, -1.6133e+00, -3.0136e-01, -3.4891e-01,
          -3.2842e-01,  8.8849e-02, -5.6078e-01, -1.1979e+00,  5.0096e-01,
          -5.0488e-01, -2.7828e-02, -1.1285e-02,  5.5856e-01, -1.4271e+00],
         [ 1.0515e+00,  

'Fx yG&lVI-q!k?fbOavT?bm-GIdIrEoaez:z$n'