In [1]:
import torch.nn as nn
import torch
import math
import random

In [2]:
random.seed(1001)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

1. Dataset

In [3]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2025-11-05 20:16:51--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2025-11-05 20:16:52 (6.24 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [4]:
# read text file
with open('input.txt', 'r') as file:
  text = file.read()
text[:100]

'First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou'

In [5]:
len(text)

1115394

In [6]:
# convert text to [ids] according to vocab
vocab = sorted(list(set(text)))
print(vocab)

id_map = {c : i for (i,c) in enumerate(vocab)}
id_map_reverse = {i : c for (i,c) in enumerate(vocab)}
print(id_map)

def encode(text):
  """
    [chars] -> [ids]
  """
  return [id_map[c] for c in text]

def decode(ids):
  """
    [ids] -> string
  """
  return ''.join(id_map_reverse[id] for id in ids)

print(decode(encode("hello")) == "hello")

['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
{'\n': 0, ' ': 1, '!': 2, '$': 3, '&': 4, "'": 5, ',': 6, '-': 7, '.': 8, '3': 9, ':': 10, ';': 11, '?': 12, 'A': 13, 'B': 14, 'C': 15, 'D': 16, 'E': 17, 'F': 18, 'G': 19, 'H': 20, 'I': 21, 'J': 22, 'K': 23, 'L': 24, 'M': 25, 'N': 26, 'O': 27, 'P': 28, 'Q': 29, 'R': 30, 'S': 31, 'T': 32, 'U': 33, 'V': 34, 'W': 35, 'X': 36, 'Y': 37, 'Z': 38, 'a': 39, 'b': 40, 'c': 41, 'd': 42, 'e': 43, 'f': 44, 'g': 45, 'h': 46, 'i': 47, 'j': 48, 'k': 49, 'l': 50, 'm': 51, 'n': 52, 'o': 53, 'p': 54, 'q': 55, 'r': 56, 's': 57, 't': 58, 'u': 59, 'v': 60, 'w': 61, 'x': 62, 'y': 63, 'z': 64}
True


In [7]:
# train, val set
train = text[:math.floor(0.9*len(text))]
test = text[math.floor(0.9*len(text)):]

print(f"train size: {len(train)}, test size: {len(test)}")

def fetch_batch(mode, batch_size, block_size):
  if mode=='train':
    dataset = train
  else:
    dataset = test
  batch_indices = [random.randint(0, len(dataset)-(block_size+1)) for _ in range(0, batch_size)]
  x = torch.stack([torch.tensor(encode(dataset[index:index+block_size])) for index in batch_indices])
  y = torch.stack([torch.tensor(encode(dataset[index+1:index+block_size+1])) for index in batch_indices])
  return x,y

train size: 1003854, test size: 111540


In [8]:
# Model architecture
class Bigram(nn.Module):
  def __init__(self, n_vocab):
    super().__init__()
    self.embedding_table = nn.Embedding(n_vocab, n_vocab)
    self.criterion = nn.CrossEntropyLoss() # takes logits as input, needs logits as <everything packed up> x #classes and target variables as ids

  def forward(self, x, y=None):
    logits = self.embedding_table(x)
    #print(logits.shape)
    n_batches, block_size, n_classes = logits.shape
    logits_temp = logits.view(n_batches*block_size, n_classes)
    if y == None:
      return logits, None

    y = y.view(n_batches*block_size)
    loss = self.criterion(logits_temp, y)
    return logits, loss

  def generate_next_token(self, x):
    logits,_ = self(x)
    # last timestep generated
    logits = logits[:,-1,:] # shape = batches, score over vocab size
    prob = torch.softmax(logits, dim=1)

    # top p sampling
    pred = torch.multinomial(prob, num_samples=1)

    return pred

  def generate(self, x, max_tokens):
    for i in range(max_tokens):
      next_word = self.generate_next_token(x)
      x = torch.cat((x, next_word), dim=1) # increase block size
    return x



In [9]:
# model obj + optimizer
model = Bigram(len(vocab))
# move model to device
model = model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

In [11]:
# generation without training
x = torch.ones((1,1), dtype=torch.long).to(device)
print(decode(model.generate(x,100)[0].tolist()))

 t:aH MhfF?XigbqDDXlK;rZ n
GusM?ZcaKEZytF&MCeZewiAOcl OUJGT?m
dr3JlKZNOaCbrXk,tjWMtMCfm&mSjqAyDozrE h


In [12]:
@torch.no_grad()
def performance_measure(batch_size, block_size, n_iters):
  losses = {'train':0, 'val':0}
  model.eval()
  for mode in ['train', 'val']:
    for _ in range(n_iters):
      x,y = fetch_batch(mode, batch_size, block_size)
      x = x.to(device)
      y = y.to(device)
      _, loss = model(x,y)
      losses[mode] += loss.item()

  model.train()
  losses['train'] /= n_iters
  losses['val'] /= n_iters
  return losses

In [13]:
# training loop
n_iters = 10000
test_after_iters = 1000
for iters in range(0, n_iters):
  # fetch random batch of training data tensors
  x,y = fetch_batch('train',64, 256)
  # move tensors to device
  x = x.to(device)
  y = y.to(device)

  # run on model to get loss
  logits, loss = model(x, y)

  # backprop on loss
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

  if iters % test_after_iters == 0:
    losses = performance_measure(64, 256, 200)
    print(f"Train loss: {losses['train']}, Val loss: {losses['val']}")

Train loss: 4.638998236656189, Val loss: 4.656174647808075
Train loss: 4.217304286956787, Val loss: 4.238060631752014
Train loss: 3.8558514380455016, Val loss: 3.8783979868888854
Train loss: 3.5514151561260223, Val loss: 3.5738569712638855
Train loss: 3.2981827664375305, Val loss: 3.321426645517349
Train loss: 3.092572566270828, Val loss: 3.1148242712020875
Train loss: 2.9300525987148287, Val loss: 2.953620855808258
Train loss: 2.804293156862259, Val loss: 2.827827477455139
Train loss: 2.7065053391456604, Val loss: 2.7304836297035218
Train loss: 2.635009207725525, Val loss: 2.6589745545387267


In [16]:
# generation after training 10k iters
x = torch.ones((1,1), dtype=torch.long).to(device)
print(decode(model.generate(x,100)[0].tolist()))

 rek nsun
:
VIVorversthor ins,.
A's
zzarsaHURICoDWhe INGLIOL's oy tinckekend is.
yandufind g rinXimol
