In [1]:
import torch.nn as nn
import torch
import math
import random

In [72]:
# Seed every RNG the notebook draws from so a fresh run is reproducible:
# `random` picks the batch start indices, while torch's generator drives the
# embedding initialisation and the multinomial sampling during generation.
SEED = 1001
random.seed(SEED)
torch.manual_seed(SEED)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

1. Dataset

In [3]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2025-10-28 06:41:34--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2025-10-28 06:41:34 (22.4 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [4]:
# read text file
# The encoding is given explicitly: without it `open` falls back to the
# platform's locale encoding, so the decoded text could differ between machines.
with open('input.txt', 'r', encoding='utf-8') as file:
  text = file.read()
# preview the first 100 characters
text[:100]

'First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou'

In [5]:
# total number of characters in the corpus
len(text)

1115394

In [6]:
# Character-level vocabulary: every distinct character of the corpus, in sorted order.
vocab = sorted(set(text))
print(vocab)

# Lookup tables in both directions: character -> id and id -> character.
id_map = dict(zip(vocab, range(len(vocab))))
id_map_reverse = dict(enumerate(vocab))
print(id_map)

def encode(text):
  """
    Map each character of `text` to its integer id using `id_map`.
    chars -> [ids]
  """
  ids = []
  for ch in text:
    ids.append(id_map[ch])
  return ids

def decode(ids):
  """
    Turn a sequence of ids back into a string using `id_map_reverse`.
    [ids] -> string
  """
  chars = [id_map_reverse[token_id] for token_id in ids]
  return ''.join(chars)

# round-trip sanity check: decoding an encoded string should give back the original
print(decode(encode("hello")) == "hello")

['f', 'h', 'S', "'", 'n', 'V', 'j', 'K', '&', 'T', 'O', 'z', 'i', 'q', 'P', 'M', 'W', 'y', 'L', 'v', 'U', 'b', 'Z', 'R', 'X', 'g', '-', 'd', 'Q', 'x', 'o', 'A', 'a', 'w', ';', 'C', 'r', '3', '!', 'N', ' ', 'E', 't', '\n', 'H', 'p', ',', 'l', 'I', '$', '?', ':', 'Y', 's', '.', 'D', 'B', 'u', 'e', 'm', 'k', 'J', 'c', 'F', 'G']
{'f': 0, 'h': 1, 'S': 2, "'": 3, 'n': 4, 'V': 5, 'j': 6, 'K': 7, '&': 8, 'T': 9, 'O': 10, 'z': 11, 'i': 12, 'q': 13, 'P': 14, 'M': 15, 'W': 16, 'y': 17, 'L': 18, 'v': 19, 'U': 20, 'b': 21, 'Z': 22, 'R': 23, 'X': 24, 'g': 25, '-': 26, 'd': 27, 'Q': 28, 'x': 29, 'o': 30, 'A': 31, 'a': 32, 'w': 33, ';': 34, 'C': 35, 'r': 36, '3': 37, '!': 38, 'N': 39, ' ': 40, 'E': 41, 't': 42, '\n': 43, 'H': 44, 'p': 45, ',': 46, 'l': 47, 'I': 48, '$': 49, '?': 50, ':': 51, 'Y': 52, 's': 53, '.': 54, 'D': 55, 'B': 56, 'u': 57, 'e': 58, 'm': 59, 'k': 60, 'J': 61, 'c': 62, 'F': 63, 'G': 64}
True


In [16]:
# Split the corpus: the first 90% is the training set, the remaining 10% is held
# out (named `test`; the training loop reports its loss as the validation loss).
split_at = math.floor(0.9*len(text))
train, test = text[:split_at], text[split_at:]

print(f"train size: {len(train)}, test size: {len(test)}")

def fetch_batch(mode, batch_size, block_size):
  """
    Sample `batch_size` random windows of `block_size` characters and encode them.

    mode: 'train' reads from `train`; any other value reads from `test`.
    Returns (x, y): two stacked id tensors with one row per window, where each
    row of y is the corresponding row of x moved one character forward in the text.
  """
  dataset = train if mode=='train' else test
  # randint is inclusive at both ends; the bound leaves room for the shifted target window
  last_start = len(dataset)-(block_size+1)
  starts = [random.randint(0, last_start) for _ in range(batch_size)]

  inputs, targets = [], []
  for start in starts:
    inputs.append(torch.tensor(encode(dataset[start:start+block_size])))
    targets.append(torch.tensor(encode(dataset[start+1:start+block_size+1])))
  return torch.stack(inputs), torch.stack(targets)

train size: 1003854, test size: 111540


In [76]:
# draw a small training batch (5 windows of 8 ids each) and inspect inputs and targets
x,y=fetch_batch('train',5,8)
print(x)
print(y)

tensor([[40, 59, 32, 53, 42, 58, 36, 46],
        [40, 32, 42, 42, 58,  4, 27, 40],
        [26,  1, 30, 57, 53, 58, 54, 43],
        [59, 32, 60, 58, 40, 32, 40, 62],
        [53, 40,  0, 57, 47, 47, 40, 30]])
tensor([[59, 32, 53, 42, 58, 36, 46, 40],
        [32, 42, 42, 58,  4, 27, 40, 30],
        [ 1, 30, 57, 53, 58, 54, 43, 43],
        [32, 60, 58, 40, 32, 40, 62, 32],
        [40,  0, 57, 47, 47, 40, 30,  0]])


In [73]:
# Model architecture
class Bigram(nn.Module):
  """
    Bigram language model: the embedding row looked up for a token id is used
    directly as the logits over the next token.
  """
  def __init__(self, n_vocab):
    """
      n_vocab: number of distinct tokens (size of the vocabulary)
    """
    super().__init__()
    # n_vocab x n_vocab table: row i holds the next-token logits for token i
    self.embedding_table = nn.Embedding(n_vocab, n_vocab)
    # expects logits flattened to (N, n_classes) and targets as class ids of shape (N,)
    self.criterion = nn.CrossEntropyLoss()

  def forward(self, x, y=None):
    """
      x: (batch, block) tensor of token ids
      y: optional tensor of target ids with the same number of elements as x

      Returns (logits, loss). logits has shape (batch, block, n_vocab); loss is
      the cross-entropy against y, or None when y is not given.
    """
    logits = self.embedding_table(x)
    # identity test: `== None` on a tensor goes through Tensor.__eq__ and is fragile
    if y is None:
      return logits, None

    n_batches, block_size, n_classes = logits.shape
    # flatten batch and time so the criterion sees one row of logits per target id
    flat_logits = logits.view(n_batches*block_size, n_classes)
    flat_targets = y.view(n_batches*block_size)
    loss = self.criterion(flat_logits, flat_targets)
    return logits, loss

  def generate_next_token(self, x):
    """
      Sample one next token id for every sequence in x.
      Returns a (batch, 1) tensor of ids.
    """
    logits,_ = self(x)
    # keep only the scores of the last timestep: shape = batches, score over vocab size
    logits = logits[:,-1,:]
    prob = torch.softmax(logits, dim=1)

    # plain multinomial sampling from the full distribution (no top-p / top-k truncation)
    pred = torch.multinomial(prob, num_samples=1)

    return pred

  @torch.no_grad() # generation never backpropagates, so skip building the autograd graph
  def generate(self, x, max_tokens):
    """
      Extend the prompt x (batch, block) by max_tokens sampled ids.
      Returns a (batch, block + max_tokens) tensor on the model's device.
    """
    # accept prompts created on another device (e.g. a CPU tensor with a CUDA model)
    x = x.to(self.embedding_table.weight.device)
    for _ in range(max_tokens):
      next_token = self.generate_next_token(x)
      x = torch.cat((x, next_token), dim=1) # increase block size
    return x



In [74]:
# Build the bigram model over the full vocabulary, place it on the selected
# device, and attach an AdamW optimizer to its parameters.
model = Bigram(len(vocab)).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

In [77]:
print(x.shape)
print(y.shape)

# The sample batch was created on the CPU while the model lives on `device`;
# move it first so the forward pass does not hit a device mismatch on CUDA.
logits, loss = model(x.to(device), y.to(device))
print(logits.shape)
print(loss.item())

torch.Size([5, 8])
torch.Size([5, 8])
torch.Size([5, 8, 65])
4.769946098327637


In [78]:
# generation without training
# The prompt is a single token (id 1), created on the model's device so that
# generation also works when the model is on CUDA.
x = torch.ones((1,1), dtype=torch.long, device=device)
decode(model.generate(x,100)[0].tolist())

"hp3Mgpf!nDHjnDynVQaqlyDqrd.,&Sy-GZ??'NDXTMO:U&:kHJrex,sD C?hlkse3?BO:UEnDhzlydMj:cGI;Er&v.MppJv.YmHF."

In [79]:
# training loop
n_iters = 10000
test_after_iters = 1000
batch_size = 16  # sequences per batch
block_size = 8   # tokens per sequence
for iters in range(0, n_iters):
  # fetch random batch of training data tensors
  x,y = fetch_batch('train', batch_size, block_size)
  # move tensors to device
  x = x.to(device)
  y = y.to(device)

  # run on model to get loss
  logits, loss = model(x, y)

  # backprop on loss
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

  if iters % test_after_iters == 0:
    # train loss (of the current batch only)
    print(f"train loss: {loss.item()}")
    # val loss: evaluation needs no gradients, and separate names keep the
    # training batch and training loss from being overwritten
    with torch.no_grad():
      val_x, val_y = fetch_batch('test', batch_size, block_size)
      _, val_loss = model(val_x.to(device), val_y.to(device))
    print(f"val loss: {val_loss.item()}")

train loss: 4.5500102043151855
val loss: 4.616785526275635
train loss: 4.497040271759033
val loss: 4.395230293273926
train loss: 4.1943888664245605
val loss: 4.119892597198486
train loss: 3.9951870441436768
val loss: 3.912943124771118
train loss: 3.7846150398254395
val loss: 3.699993371963501
train loss: 3.4637200832366943
val loss: 3.6194820404052734
train loss: 3.3314173221588135
val loss: 3.510249137878418
train loss: 3.3457372188568115
val loss: 3.204047679901123
train loss: 3.2574288845062256
val loss: 2.9902563095092773
train loss: 2.944948673248291
val loss: 3.0935096740722656


In [80]:
# generation after training 10k iters
# The prompt is a single token (id 1), created on the model's device so that
# generation also works when the model is on CUDA.
x = torch.ones((1,1), dtype=torch.long, device=device)
decode(model.generate(x,100)[0].tolist())

"hbin jg tDHVETon,'OHoR$$uce;\nsth.LufrdYILFbrNAObPO h tt\nfokyopiithktGOXd qLQwkGathow'hVme3GI nD: iuck"