<a href="https://colab.research.google.com/github/Sidy3143/TinyGPT/blob/main/TinyGPT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Generative model trained on the TinyStories dataset

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import inspect

Model architecture

In [None]:
class CausalAttention(nn.Module):
  """Multi-head causal self-attention.

  Uses ``F.scaled_dot_product_attention`` when PyTorch provides it and falls
  back to an explicit masked-softmax implementation otherwise.

  Args:
    n_head: number of attention heads; must divide ``n_embed``.
    n_embed: model width (input and output feature size).
    seq_len: maximum sequence length; sizes the causal mask of the fallback path.
    dropout: dropout probability for the attention weights and the output projection.
    bias: whether the linear projections have a bias term.
    trainning: initial value of the module's ``training`` flag (spelling kept
      for backward compatibility); ``train()`` / ``eval()`` overwrite it later.
  """

  def __init__(self, n_head, n_embed, seq_len, dropout, bias=False, trainning=True):
    super().__init__()

    assert n_embed % n_head == 0, "n_embed must be divisible by n_head"
    self.n_embed = n_embed
    self.n_head = n_head
    self.h_dim = self.n_embed // self.n_head

    # a single projection produces queries, keys and values at once
    self.QKV = nn.Linear(n_embed, 3 * n_embed, bias=bias)
    self.out = nn.Linear(n_embed, n_embed, bias=bias)

    self.att_dropout = nn.Dropout(dropout)
    self.out_dropout = nn.Dropout(dropout)
    self.dropout = dropout
    self.training = trainning

    self.flash = hasattr(torch.nn.functional ,'scaled_dot_product_attention')
    if not self.flash:
      print("No flashattention, GPUs will not be going brrr")
      # lower-triangular mask: position t may only attend to positions <= t
      self.register_buffer("mask", torch.tril(torch.ones(seq_len, seq_len)).view(1, 1, seq_len, seq_len))

  def forward(self, input): # B, T, n_embed
    """Apply causal self-attention to ``input`` of shape (B, T, n_embed); returns the same shape."""
    B, T, _ = input.size()

    query, key, value = self.QKV(input).split(self.n_embed, dim=2)

    query = query.view(B, T, self.n_head, self.h_dim).transpose(1, 2) # B, n_head, T, h_dim
    key = key.view(B, T, self.n_head, self.h_dim).transpose(1, 2)
    value = value.view(B, T, self.n_head, self.h_dim).transpose(1, 2)

    if self.flash:
      attention_output = F.scaled_dot_product_attention(
          query, key, value,
          attn_mask=None,
          dropout_p=self.dropout if self.training else 0.0,
          is_causal=True,
      )
    else:
      att = (query @ key.transpose(-1, -2)) * (1.0 / math.sqrt(self.h_dim)) # B, n_head, T, T
      att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
      attention_scores = torch.softmax(att, dim=-1)
      # dropout must act on the normalised weights: applied to the masked
      # logits it multiplies -inf by 0 and yields NaN during training
      attention_scores = self.att_dropout(attention_scores)
      attention_output = attention_scores @ value # B, n_head, T, h_dim

    # merge the heads back together
    attention_output = attention_output.transpose(1, 2).contiguous().view(B, T, -1) # B, T, n_embed

    out = self.out_dropout(self.out(attention_output))

    return out #B, T, n_embed

In [None]:
class SWIGLUFFN(nn.Module):
  """SwiGLU feed-forward network.

  Computes ``layer3(SiLU(layer1(x)) * layer2(x))`` followed by dropout. The
  hidden width is two thirds of the usual ``4 * n_embed`` expansion, truncated
  to an integer.

  Args:
    n_embed: input and output feature size.
    dropout: dropout probability applied to the output.
    bias: whether the three linear layers have a bias term.
  """

  def __init__(self, n_embed, dropout, bias):
    super().__init__()

    inner_width = int(2 * (4 * n_embed) / 3)

    self.layer1 = nn.Linear(n_embed, inner_width, bias=bias)   # gate branch
    self.silu = nn.SiLU()
    self.layer2 = nn.Linear(n_embed, inner_width, bias=bias)   # value branch
    self.layer3 = nn.Linear(inner_width, n_embed, bias=bias)   # back to model width

    self.dropout = nn.Dropout(dropout)

  def forward(self, x):
    """Map ``x`` (..., n_embed) to a tensor of the same shape."""
    gate = self.silu(self.layer1(x))
    value = self.layer2(x)
    projected = self.layer3(gate * value)
    return self.dropout(projected)

In [None]:
class Block(nn.Module):
  """Pre-norm transformer block.

  Runs RMSNorm -> causal attention and RMSNorm -> SwiGLU feed-forward, adding
  each sub-layer's output back onto its input (residual connections).

  Args:
    config: object exposing ``n_embed``, ``n_head``, ``seq_len``, ``dropout`` and ``bias``.
  """

  def __init__(self, config):
    super().__init__()

    width = config.n_embed
    self.norm1 = nn.RMSNorm(width)
    self.Causal_attention = CausalAttention(
        config.n_head,
        width,
        config.seq_len,
        config.dropout,
        config.bias,
    )
    self.norm2 = nn.RMSNorm(width)
    self.Feed_forward = SWIGLUFFN(width, config.dropout, config.bias)

  def forward(self, x): # B, T, n_embed
    """Return the block output for ``x``; the shape is unchanged."""
    attended = self.Causal_attention(self.norm1(x))
    x = x + attended
    transformed = self.Feed_forward(self.norm2(x))
    return x + transformed

In [None]:
class GPT(nn.Module):
  """Decoder-only transformer language model.

  Token embeddings plus learned positional embeddings feed a stack of ``Block``
  layers, a final RMSNorm and a linear projection to vocabulary logits. The
  projection shares its weight matrix with the token embedding.

  Args:
    GPTConfig: configuration object exposing ``vocab_size``, ``n_embed``,
      ``n_head``, ``seq_len``, ``n_layers`` and ``dropout``; the same object is
      handed to every ``Block``.
  """

  def __init__(self, GPTConfig):
    super().__init__()

    self.vocab_size = GPTConfig.vocab_size
    self.n_embed = GPTConfig.n_embed
    self.n_head = GPTConfig.n_head

    self.seq_len = GPTConfig.seq_len
    self.n_layers = GPTConfig.n_layers
    self.dropout = GPTConfig.dropout

    self.transformer = nn.ModuleDict(dict(
        token_embed = nn.Embedding(self.vocab_size, self.n_embed),
        pos_embed = nn.Embedding(self.seq_len, self.n_embed),
        drop = nn.Dropout(self.dropout),
        blocks = nn.ModuleList([Block(GPTConfig) for _ in range(self.n_layers)]),
        norm = nn.RMSNorm(self.n_embed),
    ))

    self.projection = nn.Linear(self.n_embed, self.vocab_size, bias=False)

    # weight tying: the input embedding and the output projection share one matrix
    self.transformer.token_embed.weight = self.projection.weight

    self.apply(self._init_weights)

  def _init_weights(self, module):
    """Initialise Linear/Embedding weights from N(0, 0.02) and Linear biases to zero."""
    if isinstance(module, nn.Linear):
      nn.init.normal_(module.weight, mean=0.0, std=0.02)
      if module.bias is not None:
        nn.init.zeros_(module.bias)

    elif isinstance(module, nn.Embedding):
      nn.init.normal_(module.weight, mean=0.0, std=0.02)

  def get_num_params(self, non_embedding=True):
    """Count the parameters; with ``non_embedding`` the positional table is excluded."""
    num_params = sum(p.numel() for p in self.parameters())
    if non_embedding:
      num_params -= self.transformer.pos_embed.weight.numel()
    return num_params

  def forward(self, ids, targets=None): #B, T
    """Run the model on token ids of shape (B, T).

    Returns:
      ``(logits, loss)``. With ``targets`` the logits cover every position
      (B, T, vocab_size) and ``loss`` is the cross-entropy (targets equal to
      -100 are ignored). Without ``targets`` only the last position is
      projected (B, 1, vocab_size) and ``loss`` is ``None``.
    """
    B, T = ids.size()
    device = ids.device

    assert T <= self.seq_len, f"cannot process setences with more than {self.seq_len}"
    positions = torch.arange(0, T, dtype=torch.long, device=device)  # T

    token_embeddings = self.transformer.token_embed(ids)  # B, T, n_embed
    pos_embeddings = self.transformer.pos_embed(positions)  # T, n_embed (broadcast over B)
    x = self.transformer.drop(token_embeddings + pos_embeddings)

    for block in self.transformer.blocks:
      x = block(x)
    out = self.transformer.norm(x) # B, T, n_embed

    if targets is not None:
      logits = self.projection(out) # B, T, vocab_size
      loss = F.cross_entropy(logits.view(-1, self.vocab_size), targets.view(-1), ignore_index=-100)

    else: # inference
      logits = self.projection(out[:, [-1], :]) # small optimization for inference
      loss = None

    return logits, loss

  def configure_optimizers(self, weight_decay, learning_rate, betas):
    """Build an AdamW optimizer with weight decay on tensors of rank >= 2 only.

    Parameters with fewer than two dimensions (e.g. norm gains, biases) are put
    in a group with ``weight_decay=0``. The fused AdamW kernel is requested only
    when the installed PyTorch exposes it and every parameter lives on CUDA.
    """
    params = {n: p for n, p in self.named_parameters() if p.requires_grad}

    decay_params = [p for p in params.values() if p.dim() >= 2]
    non_decay_params = [p for p in params.values() if p.dim() < 2]
    optim_params = [
        {'params': decay_params, 'weight_decay': weight_decay},
        {'params': non_decay_params, 'weight_decay': 0.0}
      ]
    num_decay_params = sum(p.numel() for p in decay_params)
    num_non_decay_params = sum(p.numel() for p in non_decay_params)
    print(f"number of decay params: {num_decay_params}")
    print(f"number of non decay params: {num_non_decay_params}")

    fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
    # older PyTorch releases reject fused=True for CPU tensors
    use_fused = fused_available and len(params) > 0 and all(p.is_cuda for p in params.values())
    fuse_dict = dict(fused=True) if use_fused else dict()
    optimizer = torch.optim.AdamW(optim_params, lr=learning_rate, betas=betas, **fuse_dict)

    return optimizer

  def estimate_mfu(self, fwd_per_iter, dt, flops_promised=120e12): # according to Palm paper
    """Estimate model FLOPs utilisation as a fraction of ``flops_promised``.

    Args:
      fwd_per_iter: number of sequences processed per optimizer step.
      dt: wall-clock seconds taken by that step.
      flops_promised: peak FLOP/s of the accelerator. The default of 120e12 is
        the L4 figure quoted in the original notebook (A100 bf16: 312e12, T4: 65e12).
    """
    N = self.get_num_params()

    L, H, Q, T = self.n_layers, self.n_head, self.n_embed//self.n_head, self.seq_len
    flops_per_token = 6*N + 12*L*H*Q*T
    flops_fwd_bwd = flops_per_token * T
    flops_per_iter = flops_fwd_bwd * fwd_per_iter

    flops_achieved = flops_per_iter * (1.0/dt)

    mfu = flops_achieved / flops_promised
    return mfu

  @torch.no_grad()
  def generate(self, text_ids, eos_id=None, max_samples=256, temperature=1.0, top_k=None, top_p=None, min_p=None, greedy=False):
    """Autoregressively extend ``text_ids`` (B, T) by up to ``max_samples`` tokens.

    Args:
      text_ids: prompt token ids, shape (B, T).
      eos_id: optional end-of-sequence id. Generation stops, without appending
        the token, as soon as every row samples it in the same step; ``None``
        disables early stopping.
      max_samples: maximum number of new tokens.
      temperature: divisor applied to the logits before filtering and sampling.
      top_k / top_p / min_p: optional filters; each one that is set masks the
        rejected logits to ``-inf`` (applied in the order top_p, top_k, min_p).
      greedy: take the arg-max instead of sampling.

    Returns:
      ``(text_ids, response)``: the prompt with the new tokens appended, and
      the new tokens alone. The model is left in eval mode.
    """
    self.eval() # put in eval mode
    response = torch.empty(text_ids.size(0), 0, dtype=torch.int64, device=text_ids.device)
    for _ in range(max_samples):
      # crop the context to the last seq_len tokens
      text_wdw = text_ids if text_ids.size(1) <= self.seq_len else text_ids[:, -self.seq_len:]

      logits, _ = self(text_wdw) # B, 1, vocab_size

      logits = logits[:, -1, :] / temperature # B, vocab_size

      if top_p:
        probs = torch.softmax(logits, dim=-1)
        sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)
        cum_probs = torch.cumsum(sorted_probs, dim=-1)

        indices_to_remove = cum_probs > top_p
        # shift right so the token that crosses the threshold is kept, and
        # always keep the most likely token
        indices_to_remove[:, 1:] = indices_to_remove[:, :-1].clone()
        indices_to_remove[..., 0] = False

        # map the decisions from sorted order back to vocabulary order
        top_p_mask = torch.zeros_like(logits, dtype=torch.bool).scatter_(-1, sorted_indices, indices_to_remove)
        logits[top_p_mask] = float('-inf')

      if top_k:
        v, _ = torch.topk(logits, min(self.vocab_size, top_k)) # B, top_k  #prevent top_k > vocab_size
        logits[logits < v[:, [-1]]] = float('-inf') # only keep top_k values

      if min_p:
        # keep tokens whose probability is at least min_p times that of the best token
        max_logit = torch.max(logits, dim=-1, keepdim=True).values
        threshold = max_logit + torch.log(torch.tensor(min_p, device=logits.device, dtype=logits.dtype))
        logits[logits < threshold] = float('-inf')

      probs = torch.softmax(logits, dim=-1) # B, vocab_size
      if greedy:
        next_token = torch.argmax(probs, dim=-1, keepdim=True) #useful for evaluation
      else:
        next_token = torch.multinomial(probs, num_samples=1) # B, 1

      if eos_id is not None and bool((next_token == eos_id).all()):
        break

      text_ids = torch.cat((text_ids, next_token), dim=1)
      response = torch.cat([response, next_token], dim=1)
    return text_ids, response

In [None]:
wandb_log = True # for logging metrics to wandb

In [None]:
from torch.cuda.amp import GradScaler, autocast
import wandb
from torch.nn.utils import clip_grad_norm_
from contextlib import nullcontext
import numpy as np
import os

In [None]:
from dataclasses import dataclass
import torch
import time

In [None]:
# mount drive to save the tokenizer and dataset there
from google.colab import drive
drive.mount('/content/drive')

Dataset and tokenization

In [None]:
!pip install tokenizers datasets

Collecting datasets
  Downloading datasets-3.5.1-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2025.3.0,>=2023.1.0 (from fsspec[http]<=2025.3.0,>=2023.1.0->datasets)
  Downloading fsspec-2025.3.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.5.1-py3-none-any.whl (491 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.4/491.4 kB[0m [31m11.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m8.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2025.3.0-py3-none-any.whl (

In [None]:
from tokenizers import Tokenizer
from tokenizers.models import BPE # we use BPE tokenizer
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import BpeTrainer

from datasets import load_dataset
from tqdm import tqdm

In [None]:
# The unknown token must be one the trainer actually puts in the vocabulary:
# the BpeTrainer in this notebook registers "<unk>" (not "[UNK]") as a special token.
tokenizer = Tokenizer(BPE(unk_token="<unk>"))

# split on whitespace/punctuation before BPE merges are applied
tokenizer.pre_tokenizer = Whitespace()

In [None]:
dataset = load_dataset("roneneldan/TinyStories") # just the first run. Then we save it, especially for huge datasets

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/1.06k [00:00<?, ?B/s]

(…)-00000-of-00004-2d5a1467fff1081b.parquet:   0%|          | 0.00/249M [00:00<?, ?B/s]

(…)-00001-of-00004-5852b56a2bd28fd9.parquet:   0%|          | 0.00/248M [00:00<?, ?B/s]

(…)-00002-of-00004-a26307300439e943.parquet:   0%|          | 0.00/246M [00:00<?, ?B/s]

(…)-00003-of-00004-d243063613e5a057.parquet:   0%|          | 0.00/248M [00:00<?, ?B/s]

(…)-00000-of-00001-869c898b519ad725.parquet:   0%|          | 0.00/9.99M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/2119719 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/21990 [00:00<?, ? examples/s]

In [None]:
dataset

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 2119719
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 21990
    })
})

In [None]:
dataset['train'][0]['text'] # grab the first story (row access avoids loading the whole text column)

'One day, a little girl named Lily found a needle in her room. She knew it was difficult to play with it because it was sharp. Lily wanted to share the needle with her mom, so she could sew a button on her shirt.\n\nLily went to her mom and said, "Mom, I found this needle. Can you share it with me and sew my shirt?" Her mom smiled and said, "Yes, Lily, we can share the needle and fix your shirt."\n\nTogether, they shared the needle and sewed the button on Lily\'s shirt. It was not difficult for them because they were sharing and helping each other. After they finished, Lily thanked her mom for sharing the needle and fixing her shirt. They both felt happy because they had shared and worked together.'

In [None]:
def batch_iterator(batch_size=1000, data=None):
    """Yield the training stories in lists of at most ``batch_size`` strings.

    Slicing the split batch by batch avoids materialising the entire text
    column in memory at once.

    Args:
        batch_size: number of stories per yielded list.
        data: sliceable split whose slices expose a ``'text'`` list; defaults
            to ``dataset['train']``.
    """
    if data is None:
        data = dataset['train']
    for start in range(0, len(data), batch_size):
        yield data[start:start + batch_size]['text']

# BPE trainer settings. vocab_size here should stay in sync with
# GPTConfig.vocab_size, which sizes the model's embedding and output layers.
# NOTE(review): special tokens presumably receive ids 0 and 1 in the listed order
# (the Gradio cell passes eos_id=1) - confirm with tokenizer.token_to_id("<EOS>").
trainer = BpeTrainer(
    vocab_size=6000,  # Adjust vocab_size as needed
    special_tokens=["<unk>", "<EOS>"],
    min_frequency=2,
    show_progress=True,
)

In [None]:
tokenizer.train_from_iterator(batch_iterator(), trainer)

tokenizer.save("drive/MyDrive/tinystories_bpe_tokenizer.json") # just once

In [None]:
tokenizer.encode("<EOS>").ids

In [None]:
def process_tokens(example):
  """Tokenize one story, with ``<EOS>`` appended, into ids plus their count.

  Args:
    example: mapping with the story under the ``'text'`` key.

  Returns:
    ``{'tokens': <list of ids>, 'len': <number of ids>}``.
  """
  story_with_eos = example['text'] + '<EOS>'
  encoding = tokenizer.encode(story_with_eos)
  ids = encoding.ids
  return dict(tokens=ids, len=len(ids))

In [None]:
tokenized = dataset.map(process_tokens, remove_columns=['text']) # just the first run

Map:   0%|          | 0/2119719 [00:00<?, ? examples/s]

Map:   0%|          | 0/21990 [00:00<?, ? examples/s]

In [None]:
# Save every split as one flat array of token ids in a memory-mapped file.
for split, data in tokenized.items():
  filename = f"drive/MyDrive/tokenized_stories_{split}.bin"
  dtype = np.uint16 # ids fit in 16 bits as long as the vocabulary stays below 2**16 (6000 here)
  arr_length = int(np.sum(data['len'], dtype=np.uint64))
  arr = np.memmap(filename, dtype=dtype, mode = 'w+', shape=(arr_length,))
  num_shards = 5

  idx = 0 # write cursor into arr
  for batch_idx in tqdm(range(num_shards), desc='writing to shards'):
    batch = data.shard(num_shards, index=batch_idx, contiguous=True)
    tokens = np.concatenate(batch['tokens']) #concatenate tokens from all rows
    arr[idx:idx+len(tokens)] = tokens
    # advance the cursor; assigning len(tokens) instead would make later
    # shards overwrite earlier data and leave the tail of the file empty
    idx += len(tokens)
  assert idx == arr_length, f"wrote {idx} tokens, expected {arr_length}"
  arr.flush()

In [None]:
tokenized.save_to_disk('/content/drive/MyDrive/tokenized_stories_dataset')

In [None]:
#Load the tokenizer and tokenized dataset, after having saved them

from datasets import load_from_disk

tokenizer = Tokenizer.from_file("drive/MyDrive/tinystories_bpe_tokenizer.json")
tokenized = load_from_disk('/content/drive/MyDrive/tokenized_stories_dataset')

Hyperparameters

In [None]:
@dataclass
class GPTConfig:
  """Architecture hyperparameters for the GPT model."""
  n_head : int = 8          # attention heads per block; must divide n_embed
  dropout : float = 0.2     # dropout probability used throughout the model
  vocab_size : int = 6000   # size of the token embedding / output projection
  n_embed : int = 512       # model width
  n_layers : int = 8        # number of transformer blocks
  seq_len : int = 512       # maximum context length (size of the positional table)
  bias : bool = False       # whether linear layers carry a bias term
  # the two flags below are not read by the model classes in this notebook
  layernorm : bool = False
  mlp : bool = False

In [None]:
@dataclass
class TrainningConfig:
  """Optimisation, schedule and logging hyperparameters for training."""
  learning_rate : float = 6e-4   # peak learning rate
  max_iters : int = 30000        # training stops once this many steps are reached
  weight_decay: float = 1e-1
  beta1: float = 0.9
  beta2: float = 0.95

  grad_clip: float = 1.0         # max gradient norm; 0 disables clipping
  max_decay_iter: int = 30000    # step after which the learning rate stays at min_lr
  min_lr: float = 6e-5

  eval_interval: int = 100       # evaluate / checkpoint every this many steps
  log_interval: int = 10         # print training stats every this many steps
  eval_iters: int = 200          # batches averaged per split during evaluation
  gradient_accumulation_steps: int = 4
  batch_size: int = 128

  compile : bool = True          # wrap the model in torch.compile

  # These two were unannotated class attributes, i.e. not dataclass fields.
  # They are annotated now and kept last so the positional order of the
  # generated __init__ is unchanged for existing callers.
  warmup_steps: int = 1000       # linear warm-up length in steps
  decay_r: bool = True           # use the warm-up + cosine schedule instead of a constant rate

In [None]:
torch.manual_seed(43) # for reproducibility
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

Set up our device and data type

In [None]:
# Pick the device, then the autocast precision that device supports:
# bfloat16 on GPUs with compute capability >= 8, float16 on older GPUs, float32 on CPU.
if torch.cuda.is_available():
  device = 'cuda'
  compute_capability = torch.cuda.get_device_capability(0)[0]
  dtype = 'bfloat16' if compute_capability >= 8 else 'float16'
else:
  device = 'cpu'
  dtype = 'float32'

pdtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16, 'float32': torch.float32}[dtype]

print(f"device: {device}")
print(f"using dtype: {dtype}")

# automatic mixed precision on GPU; a no-op context on CPU
if device == 'cpu':
  ctx = nullcontext()
else:
  ctx = torch.amp.autocast(device_type=device, dtype=pdtype)

In [None]:
wandb.login() # log our metrics to weights & biases (wandb)

In [None]:
wandb.init(project="Sidy_GPT", name="Sidy_GPT_tranning", resume=True)

In [None]:
# get the learning rate at each step
def get_lr(it):
  """Learning rate for step ``it``: linear warm-up, cosine decay, then a floor.

  Reads ``warmup_steps``, ``max_decay_iter``, ``min_lr`` and ``learning_rate``
  from ``TrainningConfig``.
  """
  cfg = TrainningConfig
  warmup = cfg.warmup_steps
  decay_end = cfg.max_decay_iter
  floor = cfg.min_lr
  peak = cfg.learning_rate

  # 1) linear warm-up towards the peak rate
  if it < warmup:
    return cfg.learning_rate * (it + 1)/(warmup + 1)

  # 2) past the decay horizon: stay at the floor
  if it > decay_end:
    return floor

  # 3) in between: cosine interpolation from peak down to floor
  progress = (it - warmup) / (decay_end - warmup)
  cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
  return floor + cosine * (peak - floor)

In [None]:
# poor man's dataloader
def get_batch(split, batch_size=TrainningConfig.batch_size, seq_len=GPTConfig.seq_len, device=device):
  """Sample a random batch of (input, target) windows from a tokenized split.

  Args:
    split: split name used to locate ``drive/MyDrive/tokenized_stories_{split}.bin``.
    batch_size: number of windows to draw.
    seq_len: length of each window.
    device: where to place the tensors.

  Returns:
    ``(x, y)`` int64 tensors of shape (batch_size, seq_len); ``y`` is ``x``
    shifted one token to the right.
  """
  # the memmap is re-opened on every call
  tokens = np.memmap(f"drive/MyDrive/tokenized_stories_{split}.bin", dtype=np.uint16, mode='r')
  starts = torch.randint(len(tokens) - seq_len, (batch_size,)).tolist()

  def windows(shift):
    return torch.stack([
        torch.from_numpy((tokens[s + shift:s + shift + seq_len]).astype(np.int64))
        for s in starts
    ])

  x = windows(0)
  y = windows(1)

  if device != 'cuda':
    return x.to(device), y.to(device)

  # pinned host memory allows the host-to-GPU copy to be asynchronous
  x = x.pin_memory().to(device, non_blocking=True)
  y = y.pin_memory().to(device, non_blocking=True)
  return x, y

In [None]:
@torch.no_grad()
def estimate_loss(model, eval_iters):
  """Average the loss over ``eval_iters`` random batches of each split.

  Args:
    model: callable returning ``(logits, loss)`` for ``(X, Y)``; it is put in
      eval mode for the measurement and back in train mode afterwards.
    eval_iters: number of batches to average per split.

  Returns:
    dict with the mean loss under the keys ``'train'`` and ``'validation'``.
  """
  model.eval()
  out = {}
  for split in ['train', 'validation']:
    losses = torch.zeros(eval_iters)
    for k in range(eval_iters):
      # sample from the split being measured (not always from validation)
      X, Y = get_batch(split)
      with ctx:
        _logits, loss = model(X, Y)
      losses[k] = loss
    out[split] = losses.mean()
  model.train()
  return out

In [None]:
# Architecture settings stored inside every checkpoint; load_model rebuilds the
# config with GPTConfig(**checkpoint["model_args"]). `layernorm` is not listed,
# so it falls back to the GPTConfig default when a checkpoint is reloaded.
model_args = dict(
    n_head = GPTConfig.n_head,
    n_embed = GPTConfig.n_embed,
    n_layers = GPTConfig.n_layers,
    seq_len = GPTConfig.seq_len,
    vocab_size = GPTConfig.vocab_size,
    dropout = GPTConfig.dropout,
    bias = GPTConfig.bias,
    mlp = GPTConfig.mlp,
)

In [None]:
# Module-level run state. train_GPT takes its own iter_num / best_val_loss from
# load_model, so the two values below are only defaults at notebook scope.
iter_num = 0
best_val_loss = 1e9
always_save_checkpoint = True  # read by train_GPT: save at every evaluation, not only on improvement
chkpt_path = 'drive/MyDrive/chkpt.pt'  # relative path - NOTE(review): presumably assumes the working directory is /content

In [None]:
def load_model(chkpt_path, device):
  """Create the model, restoring it from ``chkpt_path`` when that file exists.

  Args:
    chkpt_path: checkpoint file written by ``train_GPT``.
    device: device the model (and the loaded tensors) are placed on.

  Returns:
    ``(model, resume, iter_num, num_tokens, best_val_loss, checkpoint)``.
    Without a checkpoint this is a fresh model with ``resume=False``,
    ``iter_num=0``, ``num_tokens=0``, ``best_val_loss=1e9`` and ``checkpoint=None``.
  """
  resume = False
  best_val_loss = 1e9
  iter_num = 0
  num_tokens = 0 # must exist on the from-scratch path as well, it is part of the return value
  checkpoint = None
  if os.path.exists(chkpt_path): # my drive
    checkpoint = torch.load(chkpt_path, map_location=device)
    gpt_conf = GPTConfig(**checkpoint["model_args"])
    model = GPT(gpt_conf)
    state_dict = checkpoint['model']

    # a model saved after torch.compile has every key prefixed; strip the prefix
    unwanted_prefix = '_orig_mod.'
    for k in list(state_dict.keys()):
      if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix): ]] = state_dict.pop(k)

    model.load_state_dict(state_dict)
    iter_num = checkpoint['iter_num']
    best_val_loss = checkpoint['best_val_loss']
    num_tokens = checkpoint.get('num_tokens', 0) # tolerate checkpoints saved without a token counter
    resume = True
    print(f"trainning Resumed from {iter_num} steps with {best_val_loss} as best_val_loss.")
  else:
    model = GPT(GPTConfig)
    print(f"Starting from scratch from step {iter_num}")
  model.to(device)
  return model, resume, iter_num, num_tokens, best_val_loss, checkpoint

In [None]:
tokens_per_iter = TrainningConfig.gradient_accumulation_steps * TrainningConfig.batch_size * GPTConfig.seq_len

print(f"tokens per iteration will be {tokens_per_iter}") # you can stop trainning based on the total number of tokens processed

Training function

In [None]:
def train_GPT(TrainningConfig, device, chkpt_path):
  """Train the GPT model, resuming from ``chkpt_path`` when a checkpoint exists.

  Every ``eval_interval`` steps the train/validation losses are estimated,
  logged, a checkpoint is written and a sample story is generated. Each step
  accumulates gradients over ``gradient_accumulation_steps`` micro-batches.
  The loop ends once ``max_iters`` steps have been reached.

  Besides its arguments the function reads these notebook-level names:
  ``dtype``, ``ctx``, ``wandb_log``, ``model_args``, ``always_save_checkpoint``,
  ``tokenizer``, ``tokens_per_iter`` and the helpers ``load_model``,
  ``get_lr``, ``get_batch`` and ``estimate_loss``.

  Args:
    TrainningConfig: object holding the training hyperparameters.
    device: device to train on.
    chkpt_path: file the checkpoint is loaded from and saved to.
  """
  model, resume, iter_num, num_tokens, best_val_loss, checkpoint = load_model(chkpt_path, device)
  num_params = model.get_num_params()
  print(f"Number of parameters {num_params/1e6}M params")

  model.train()

  optimizer = model.configure_optimizers(TrainningConfig.weight_decay, TrainningConfig.learning_rate, betas=(TrainningConfig.beta1, TrainningConfig.beta2))
  if resume:
    optimizer.load_state_dict(checkpoint['optimizer'])
  checkpoint = None # the loaded state is no longer needed; let it be freed

  # loss scaling is only required for float16 autocast
  scaler = torch.GradScaler(enabled=(dtype=='float16'))

  if TrainningConfig.compile:
    print("compiling model...")
    model = torch.compile(model)

  # generation stops at the end-of-story marker
  eos_id = tokenizer.token_to_id("<EOS>")

  running_mfu = -1.0
  acc_loss = 0
  t0 = time.time()
  local_iter = 0 # steps done in this call (iter_num also counts steps from earlier runs)
  while True:
    # set the learning rate for this step
    lr = get_lr(iter_num) if TrainningConfig.decay_r else TrainningConfig.learning_rate
    for param_group in optimizer.param_groups:
      param_group['lr'] = lr

    X, Y = get_batch("train")

    if iter_num % TrainningConfig.eval_interval == 0:
      losses = estimate_loss(model, TrainningConfig.eval_iters)
      print(f"step: {iter_num}, train_loss: {losses['train']}, eval_loss: {losses['validation']}")
      # wandb logging
      if wandb_log:
        wandb.log({
            "iter_num": iter_num,
            "train_loss" : losses['train'],
            "val_loss": losses['validation'],
            "model_args": model_args,
            "learning_rate": lr,
            "running_mfu": running_mfu*100,
            "num_tokens": num_tokens
        }, step=iter_num)

      improved = losses['validation'] < best_val_loss
      if improved or always_save_checkpoint:
        # record the best loss before saving, and only when it really improved
        if improved:
          best_val_loss = losses['validation']
        checkpoint = {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "model_args": model_args,
            "iter_num": iter_num,
            "best_val_loss": best_val_loss,
            "num_tokens": num_tokens
        }
        print(f'saving checkpoint to {chkpt_path}...')
        torch.save(checkpoint, chkpt_path)

        # sample a short story to eyeball progress
        model.eval()
        prompt = "once upon a time"
        ids = tokenizer.encode(prompt).ids
        ids = torch.tensor(ids, dtype=torch.int64, device=device).unsqueeze(0)
        _, response = model.generate(ids, eos_id=eos_id, max_samples=256, temperature=1.0)
        story = tokenizer.decode(response.squeeze(0).tolist())
        print(f"Generated story:\n -- {story} --")
        if wandb_log:
          # log the decoded text rather than the raw id tensor
          wandb.log({"story": story}, step=iter_num)
        model.train()

    for microstep in range(TrainningConfig.gradient_accumulation_steps):
      with ctx:
        logits, loss = model(X, Y)
        loss = loss / TrainningConfig.gradient_accumulation_steps
      acc_loss += loss.detach() # detached so the running total does not keep autograd graphs alive
      X, Y = get_batch('train') # prefetch, while GPU doing backprop
      scaler.scale(loss).backward() # pytorch accumulates gradients automatically

    if TrainningConfig.grad_clip != 0:
      # gradients must be unscaled before their norm is clipped
      scaler.unscale_(optimizer)
      torch.nn.utils.clip_grad_norm_(model.parameters(), TrainningConfig.grad_clip)
    scaler.step(optimizer)
    scaler.update()

    optimizer.zero_grad(set_to_none=True)

    t1 = time.time()
    dt = t1 - t0
    t0 = t1

    if iter_num % TrainningConfig.log_interval == 0:
      # skip the first local steps: they include compilation/evaluation time,
      # also when resuming from a large iter_num
      if local_iter >= 5:
        mfu = model.estimate_mfu(TrainningConfig.batch_size * TrainningConfig.gradient_accumulation_steps, dt)
        running_mfu = mfu  #you can do if running_mfu ==-1.0 else 0.4*running_mfu + 0.6*mfu
      print(f"iteration step {iter_num}, loss {acc_loss:.4f}, time {dt*1000:.2f} ms, mfu {running_mfu*100:.2f} %")
      print(f"tokens processed {num_tokens}")

    num_tokens += tokens_per_iter
    iter_num += 1
    local_iter += 1
    acc_loss = 0

    # >= so a run resumed at or beyond max_iters still terminates
    if iter_num >= TrainningConfig.max_iters:
      break

In [None]:
train_GPT(TrainningConfig, device, chkpt_path)

trainning Resumed from 6100 steps with 0.6517184376716614 as best_val_loss.
Number of parameters 28.242432M params
number of non decay params: 28495872
number of non decay params: 8704
compiling model...
step: 6100, train_loss: 0.6513611674308777, eval_loss: 0.6507238745689392
saving checkpoint to drive/MyDrive/chkpt.pt...
Generated story:
 -- , a little girl named Emmy heard a big decision . She would be very envious of the coin she had seen earlier . Emmy was determined to learn how to reach a coin , so she decided to ask her parents . She walked up to the coin , which was about to half across the finish line . She saw a big man working on some very hard work . Emmy asked him if he could help . The man told his name was Joe and that he can help . So , each day , Joe and Emmy worked together to reach for the coin . Sometimes , Joe brought out apples and re jo iced in the meadow . It was such a nice way to prevent getting the coin or lost it . Eventually , Carl and Emmy sh own that if 



iteration step 6100, loss 0.6540, time 70494.68 ms, mfu 0.60 %
tokens processed 1535792128
iteration step 6110, loss 0.6922, time 2040.65 ms, mfu 20.83 %
tokens processed 1538413568
iteration step 6120, loss 0.6487, time 2080.14 ms, mfu 20.44 %
tokens processed 1541035008
iteration step 6130, loss 0.6294, time 2079.07 ms, mfu 20.45 %
tokens processed 1543656448
iteration step 6140, loss 0.6474, time 2079.53 ms, mfu 20.44 %
tokens processed 1546277888
iteration step 6150, loss 0.7176, time 2059.40 ms, mfu 20.64 %
tokens processed 1548899328
iteration step 6160, loss 0.7112, time 2058.05 ms, mfu 20.66 %
tokens processed 1551520768
iteration step 6170, loss 0.6587, time 2080.61 ms, mfu 20.43 %
tokens processed 1554142208
iteration step 6180, loss 0.7149, time 2051.59 ms, mfu 20.72 %
tokens processed 1556763648
iteration step 6190, loss 0.6919, time 2079.52 ms, mfu 20.44 %
tokens processed 1559385088
step: 6200, train_loss: 0.6448690295219421, eval_loss: 0.6485092639923096
saving checkpoin

KeyboardInterrupt: 

Let's try a prompt and see what the model outputs

In [None]:
# `model` only exists as a local variable inside train_GPT, so reload the latest
# checkpoint here; otherwise this cell fails on a fresh kernel.
model, resume, iter_num, num_tokens, best_val_loss, checkpoint = load_model(chkpt_path, device)
model.eval()

GPT(
  (transformer): ModuleDict(
    (token_embed): Embedding(6000, 512)
    (pos_embed): Embedding(512, 512)
    (drop): Dropout(p=0.2, inplace=False)
    (blocks): ModuleList(
      (0-7): 8 x Block(
        (norm1): RMSNorm((512,), eps=None, elementwise_affine=True)
        (Causal_attention): CausalAttention(
          (QKV): Linear(in_features=512, out_features=1536, bias=False)
          (out): Linear(in_features=512, out_features=512, bias=False)
          (att_dropout): Dropout(p=0.2, inplace=False)
          (out_dropout): Dropout(p=0.2, inplace=False)
        )
        (norm2): RMSNorm((512,), eps=None, elementwise_affine=True)
        (Feed_forward): SWIGLUFFN(
          (layer1): Linear(in_features=512, out_features=1365, bias=False)
          (silu): SiLU()
          (layer2): Linear(in_features=512, out_features=1365, bias=False)
          (layer3): Linear(in_features=1365, out_features=512, bias=False)
          (dropout): Dropout(p=0.2, inplace=False)
        )
      )

In [None]:
print(f"number of tokens processed: {num_tokens}")  # ~4B tokens

4052374528

In [None]:
prompt_ids = tokenizer.encode("once upon a time,").ids
prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64, device=device).unsqueeze(0)

In [None]:
# stop at the end-of-story marker; generate() needs the id of the <EOS> token
story, _ = model.generate(prompt_ids, eos_id=tokenizer.token_to_id("<EOS>"), max_samples=350)

In [None]:
print(tokenizer.decode(story.squeeze(0).tolist())) # example story

once upon a time , there was a mysterious dog . He wanted to burn things , but he was scared factory s are be matched . One day , he saw smoke coming from the factory . He had heard of them before , but something was different . He was afraid that something was going to happen away . He barked loudly but nobody answered . Then he heard a little voice . " I ' m just a statue of those things . Don ' t be scared !" It came a little bird singing in the corner of the factory . The dog stopped his barking and the bird flew away . The dog felt relieved . He was no longer scared of the factory . He knew his mom would always find him and his toy was all ready to burning things .


In [None]:
!pip install gradio  # gradio for a nice and simple UI

In [None]:
# load our latest checkpoint

model, resume, iter_num, num_tokens, best_val_loss, checkpoint = load_model(chkpt_path, device)

trainning Resumed from 15700 steps with 0.6005626320838928 as best_val_loss.


In [None]:
import gradio as gr

def chat_with_model(prompt):
    """Continue ``prompt`` with the trained model and return the story as display text.

    An empty prompt (no tokens) is answered with a hint instead of being sent
    to the model, which cannot run on a zero-length sequence.
    """
    prompt_ids = tokenizer.encode(prompt).ids
    if not prompt_ids:
        return " Story:\n (please type the beginning of a story first)"

    # look the end-of-story id up instead of hard-coding it; 1 is the previous hard-coded value
    eos_id = tokenizer.token_to_id("<EOS>")
    if eos_id is None:
        eos_id = 1

    prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64, device=device).unsqueeze(0)
    generated, _ = model.generate(prompt_ids, eos_id=eos_id, max_samples=350, temperature=1.0, top_p=0.8)
    generated_story = tokenizer.decode(generated.squeeze(0).tolist(), skip_special_tokens=True)
    return f" Story:\n {generated_story}"

iface = gr.Interface(
    fn=chat_with_model,
    inputs=gr.Textbox(lines=2, placeholder="Type just the beginning of the story (like once upon a time)..."),
    outputs=gr.Textbox(),
    title="TinyGPT",
    description="A custom GPT-style model trained on TinyStories. Only ask it to write fun-child stories!"
)

In [None]:
iface.launch()

It looks like you are running Gradio on a hosted a Jupyter notebook. For the Gradio app to work, sharing must be enabled. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://16219f8a664eb636e6.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)




In [None]:
# Doesn't work currently but hopefully you get the idea!