In [None]:
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.nn import functional as F

In [None]:
from transformers import GPT2LMHeadModel

In [None]:
# Reference Hugging Face GPT-2 checkpoint and its raw state_dict, handy for
# inspecting parameter names/shapes (nothing later in the notebook reads them).
model_hf = GPT2LMHeadModel.from_pretrained('gpt2')
sd_hf = model_hf.state_dict()

# Uncomment to list every parameter name with its shape:
# for name, tensor in sd_hf.items():
#   print(name, tensor.shape)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

In [None]:
import tiktoken

@dataclass
class GPTConfig:
  """Model hyper-parameters; defaults match the "gpt2" entry used by GPT.from_pretrained."""
  block_size: int = 1024   # max sequence length (rows of the position-embedding table)
  vocab_size: int = 50257  # number of token ids (rows of wte / outputs of lm_head)
  n_layer: int = 12        # number of transformer Blocks
  n_head: int = 12         # attention heads per Block (must divide n_embed)
  n_embed: int = 768       # residual-stream width
  dropout: float = 0.0     # NOTE(review): not read by any model class in this notebook
  bias: bool = True        # NOTE(review): not read by any model class in this notebook







class CausalSelfAttention(nn.Module):
  """Multi-head causal self-attention in the GPT-2 layout.

  One fused projection ``c_attn`` produces queries, keys and values; the heads
  are attended with ``F.scaled_dot_product_attention(is_causal=True)`` and the
  concatenated result is mapped back to ``n_embed`` features by ``c_proj``.
  """

  def __init__(self, config):
    super().__init__()
    assert config.n_embed % config.n_head == 0

    width = config.n_embed
    self.c_attn = nn.Linear(width, 3 * width)
    self.c_proj = nn.Linear(width, width)
    # Marker read by GPT._init_weights to shrink this projection's init std.
    self.c_proj.NANOGPT_SCALE_INIT = 1

    self.n_head = config.n_head
    self.n_embed = config.n_embed

    # Lower-triangular mask left over from the explicit attention version; the
    # SDPA call in forward() does not read it, but it remains a registered
    # buffer ("bias"), which GPT.from_pretrained skips when mapping keys.
    side = config.block_size
    self.register_buffer("bias", torch.tril(torch.ones(side, side)).view(1, 1, side, side))

  def forward(self, x):
    """Causal self-attention over ``x`` of shape (B, T, C); returns shape (B, T, C)."""
    B, T, C = x.size()
    head_shape = (B, T, self.n_head, C // self.n_head)

    q, k, v = self.c_attn(x).split(self.n_embed, dim=2)
    # (B, T, C) -> (B, n_head, T, head_dim) for each of q, k, v.
    q = q.view(head_shape).transpose(1, 2)
    k = k.view(head_shape).transpose(1, 2)
    v = v.view(head_shape).transpose(1, 2)

    # Fused (Flash) attention: is_causal=True hides future positions and the
    # default scale is 1/sqrt(head_dim), matching the explicit softmax form.
    y = F.scaled_dot_product_attention(q, k, v, is_causal=True)

    # Merge heads back: (B, n_head, T, head_dim) -> (B, T, C), then project.
    y = y.transpose(1, 2).contiguous().view(B, T, C)
    return self.c_proj(y)







class MLP(nn.Module):
  """Position-wise feed-forward sub-layer: expand 4x, tanh-approximated GELU, project back."""

  def __init__(self, config):
    super().__init__()
    hidden = 4 * config.n_embed
    self.c_fc = nn.Linear(config.n_embed, hidden)
    self.gelu = nn.GELU(approximate='tanh')
    self.c_proj = nn.Linear(hidden, config.n_embed)
    # Marker read by GPT._init_weights to shrink this projection's init std.
    self.c_proj.NANOGPT_SCALE_INIT = 1

  def forward(self, x):
    """Apply c_fc -> GELU -> c_proj; last dimension goes n_embed -> 4*n_embed -> n_embed."""
    return self.c_proj(self.gelu(self.c_fc(x)))

class Block(nn.Module):
  """Pre-LayerNorm transformer block: attention and MLP sub-layers, each added to the residual."""

  def __init__(self, config):
    super().__init__()
    width = config.n_embed
    self.ln1 = nn.LayerNorm(width)
    self.attn = CausalSelfAttention(config)
    self.ln2 = nn.LayerNorm(width)
    self.mlp = MLP(config)

  def forward(self, x):
    """Return x + attn(ln1(x)), followed by + mlp(ln2(.)); shape is preserved."""
    h = x + self.attn(self.ln1(x))
    return h + self.mlp(self.ln2(h))

class GPT(nn.Module):
  """GPT-2 style decoder-only transformer language model.

  Submodule names follow Hugging Face's GPT-2 layout (``transformer.wte``,
  ``transformer.wpe``, ``transformer.h.{i}``, ``transformer.ln_f``,
  ``lm_head``) so ``from_pretrained`` can map weights key by key.
  """

  def __init__(self, config):
    super().__init__()
    self.config = config

    self.transformer = nn.ModuleDict(dict(
        wte = nn.Embedding(config.vocab_size, config.n_embed),  # token embeddings
        wpe = nn.Embedding(config.block_size, config.n_embed),  # learned position embeddings
        h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
        ln_f = nn.LayerNorm(config.n_embed),                    # final LayerNorm
    ))

    self.lm_head = nn.Linear(config.n_embed, config.vocab_size, bias=False)

    # Weight tying: the token embedding and the output projection share one
    # (vocab_size, n_embed) matrix.
    self.transformer.wte.weight = self.lm_head.weight

    self.apply(self._init_weights)

  def _init_weights(self, module):
    """Initialise one submodule (called on every submodule via ``self.apply``).

    Linear weights ~ N(0, 0.02), additionally scaled by (2 * n_layer) ** -0.5
    for layers tagged with ``NANOGPT_SCALE_INIT``; Linear biases are zeroed;
    Embedding weights ~ N(0, 0.02). Other modules keep PyTorch defaults.
    """
    if isinstance(module, nn.Linear):
      std = 0.02
      if hasattr(module, "NANOGPT_SCALE_INIT"):
        # Tagged projections feed the residual stream, which receives two
        # additions (attention + MLP) per Block, i.e. 2 * n_layer in total.
        std *= (2 * self.config.n_layer) ** -0.5
      torch.nn.init.normal_(module.weight, mean=0.0, std=std)
      if module.bias is not None:
        torch.nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
      torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

  def forward(self, idx, targets=None):
    """Run the model on a batch of token ids.

    Args:
      idx: LongTensor of token ids, shape (B, T) with T <= config.block_size.
      targets: optional LongTensor of next-token ids, same shape as ``idx``.

    Returns:
      (logits, loss): logits of shape (B, T, vocab_size); ``loss`` is the mean
      cross-entropy over all B*T positions, or None when ``targets`` is None.
    """
    B, T = idx.size()
    assert T <= self.config.block_size, f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"

    pos = torch.arange(0, T, dtype=torch.long, device=idx.device)
    pos_emb = self.transformer.wpe(pos)   # (T, n_embed), broadcast over the batch
    tok_emb = self.transformer.wte(idx)   # (B, T, n_embed)

    x = tok_emb + pos_emb
    for block in self.transformer.h:
      x = block(x)

    x = self.transformer.ln_f(x)
    logits = self.lm_head(x)

    loss = None
    if targets is not None:
      # reshape (not view): also accepts non-contiguous targets, e.g. a
      # transposed or strided slice, where view(-1) would raise.
      loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))

    return logits, loss

  @classmethod
  def from_pretrained(cls, model_type: str):
    """Build a model of this class and copy Hugging Face GPT-2 weights into it.

    Args:
      model_type: one of "gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl".

    Returns:
      A ``cls`` instance whose parameters hold the pretrained weights.

    Raises:
      AssertionError: unknown ``model_type``.
      KeyError: one of our parameters has no counterpart in the HF state_dict.
      ValueError: a weight has an unexpected rank or shape.
    """
    assert model_type in ["gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl"]

    from transformers import GPT2LMHeadModel

    # Per-checkpoint architecture; vocab size and context length are shared.
    hf_cfg = {
        "gpt2":        dict(n_layer=12, n_head=12, n_embed=768),
        "gpt2-medium": dict(n_layer=24, n_head=16, n_embed=1024),
        "gpt2-large":  dict(n_layer=36, n_head=20, n_embed=1280),
        "gpt2-xl":     dict(n_layer=48, n_head=25, n_embed=1600),
    }[model_type]

    cfg = GPTConfig(**hf_cfg, vocab_size=50257, block_size=1024)
    model = cls(cfg)

    hf = GPT2LMHeadModel.from_pretrained(model_type)
    sd_hf = hf.state_dict()  # keys like 'transformer.h.0.ln_1.weight', etc.

    def map_key_ours_to_hf(k: str) -> str | None:
      """Translate one of our state_dict keys to the HF key, or None to skip it."""
      # Our causal-mask buffer has no HF counterpart.
      if k.endswith(".attn.bias"):
        return None

      # Embeddings, final LayerNorm and LM head share names with HF.
      if k in ("transformer.wte.weight", "transformer.wpe.weight",
               "transformer.ln_f.weight", "transformer.ln_f.bias",
               "lm_head.weight"):
        return k

      # ours: transformer.h.{i}.(ln1|ln2|attn.*|mlp.*).(weight|bias)
      # hf:   transformer.h.{i}.(ln_1|ln_2|attn.*|mlp.*).(weight|bias)
      if k.startswith("transformer.h."):
        parts = k.split(".")   # e.g. ['transformer', 'h', '0', 'ln1', 'weight']
        i = parts[2]
        sub = parts[3]
        if sub == "ln1":
          return f"transformer.h.{i}.ln_1.{parts[4]}"
        if sub == "ln2":
          return f"transformer.h.{i}.ln_2.{parts[4]}"
        if sub == "attn":
          # c_attn / c_proj names are the same in HF.
          return f"transformer.h.{i}.attn.{parts[4]}.{parts[5]}"
        if sub == "mlp":
          # c_fc / c_proj names are the same in HF.
          return f"transformer.h.{i}.mlp.{parts[4]}.{parts[5]}"

      return None

    # HF stores these projections as Conv1D with (in, out) weights; our
    # nn.Linear expects (out, in), so they are transposed on copy.
    needs_T = (
        "attn.c_attn.weight",
        "attn.c_proj.weight",
        "mlp.c_fc.weight",
        "mlp.c_proj.weight",
    )

    # state_dict() tensors share storage with the parameters, so copy_ below
    # writes the pretrained values straight into the model.
    sd_ours = model.state_dict()
    with torch.no_grad():
      for k_ours in sd_ours.keys():
        k_hf = map_key_ours_to_hf(k_ours)
        if k_hf is None:
          # ignore buffers like the causal mask
          continue
        if k_hf not in sd_hf:
          raise KeyError(f"HF key missing for our key {k_ours!r} -> {k_hf!r}")

        w_src = sd_hf[k_hf]
        w_dst = sd_ours[k_ours]

        if any(k_ours.endswith(s) for s in needs_T):
          if w_src.ndim != 2:
            raise ValueError(f"Expected 2D for transposed weight: {k_hf} got {w_src.shape}")
          w_src = w_src.t()

        if w_src.shape != w_dst.shape:
          raise ValueError(f"Shape mismatch: {k_ours}: dst {w_dst.shape} vs src {w_src.shape} (from {k_hf})")

        w_dst.copy_(w_src)

    # Strict reload: also verifies that our key set is self-consistent.
    model.load_state_dict(sd_ours)
    return model













In [None]:
import torch
import tiktoken

class DataLoaderLite:
    """Minimal sequential loader serving (input, target) token batches from one text file.

    The file is tokenized once with tiktoken's "gpt2" encoding. Each call to
    ``next_batch`` takes the next contiguous chunk of ``B*T + 1`` tokens and
    returns ``x`` (B, T) and the one-token-shifted targets ``y`` (B, T). When the
    following chunk would run past the end of the data, the cursor wraps to 0.

    Args:
        B: batch size (rows per batch).
        T: sequence length (tokens per row).
        filename: path of the UTF-8 text file to tokenize.
        max_chars: if not None, only the first ``max_chars`` characters of the
            file are tokenized; None (default) uses the whole file.

    Raises:
        ValueError: if the text yields fewer than ``B*T + 1`` tokens, so not even
            one batch can be formed.
    """

    def __init__(self, B, T, filename="/content/input.txt", max_chars=None):
        self.B = B
        self.T = T

        # Load and tokenize once at init.
        with open(filename, "r", encoding="utf-8") as f:
            text = f.read()
        if max_chars is not None:
            text = text[:max_chars]

        enc = tiktoken.get_encoding("gpt2")
        self.tokens = torch.tensor(enc.encode(text), dtype=torch.long)

        # One batch needs B*T inputs plus one extra token for the shifted targets.
        if len(self.tokens) < B * T + 1:
            raise ValueError(
                f"need at least {B * T + 1} tokens for B={B}, T={T}; got {len(self.tokens)}"
            )

        print(f"loaded {len(self.tokens)} tokens")
        print(f"1 epoch = {len(self.tokens) // (B*T)} batches")

        self.current_position = 0

    def next_batch(self):
        """Return the next ``(x, y)`` pair, each of shape (B, T), and advance the cursor."""
        B, T = self.B, self.T

        # Slice B*T + 1 tokens so inputs and next-token targets share one buffer.
        buf = self.tokens[self.current_position : self.current_position + B*T + 1]
        x = buf[:-1].view(B, T)
        y = buf[1:].view(B, T)

        self.current_position += B*T
        # Wrap only if the NEXT chunk would not fit; a chunk ending exactly at
        # len(self.tokens) is still a valid batch.
        if self.current_position + B*T + 1 > len(self.tokens):
            self.current_position = 0

        return x, y


In [None]:
# Copy the Hugging Face "gpt2" weights into our GPT; from_pretrained raises on
# any missing key or shape mismatch, so reaching the print means all matched.
model = GPT.from_pretrained(model_type='gpt2')
print('all okay')

all okay


In [None]:
# Select the compute device: CUDA GPU when available, otherwise CPU.
# (A TPU setup would instead use torch_xla: xm.xla_device().)
if torch.cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'


In [None]:
# Display the selected device string.
device

'cuda'

In [None]:
# Switch to eval mode for sampling and move the weights to the selected device;
# both calls return the module, so the cell displays the model's repr.
model.eval().to(device)

GPT(
  (transformer): ModuleDict(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (h): ModuleList(
      (0-11): 12 x Block(
        (ln1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): CausalSelfAttention(
          (c_attn): Linear(in_features=768, out_features=2304, bias=True)
          (c_proj): Linear(in_features=768, out_features=768, bias=True)
        )
        (ln2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (c_fc): Linear(in_features=768, out_features=3072, bias=True)
          (gelu): GELU(approximate='tanh')
          (c_proj): Linear(in_features=3072, out_features=768, bias=True)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [None]:
# Install the GPT-2 BPE tokenizer (the output below shows it was not preinstalled).
# NOTE(review): earlier cells already `import tiktoken`; on a fresh kernel this
# install has to run before them.
!pip install tiktoken

Collecting tiktoken
  Downloading tiktoken-0.11.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)
Downloading tiktoken-0.11.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m14.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: tiktoken
Successfully installed tiktoken-0.11.0


In [None]:
num_return_sequences = 5
max_len = 30
top_k = 5   # sample only among the k most likely next tokens

import tiktoken, math
enc = tiktoken.get_encoding("gpt2")
tokens = enc.encode("hi how are you , ")
tokens = torch.tensor(tokens, dtype=torch.long, device=device)
# Same prompt in every row: (num_return_sequences, prompt_len).
tokens = tokens.unsqueeze(0).repeat(num_return_sequences, 1)
x = tokens
torch.manual_seed(42)

with torch.no_grad():
  while x.size(1) < max_len:
    # GPT.forward returns (logits, loss); loss is None when no targets are given.
    logits, _ = model(x)
    logits = logits[:, -1, :]          # only the last position predicts the next token
    probs = F.softmax(logits, dim=-1)
    topk_probs, topk_indices = torch.topk(probs, k=top_k, dim=-1)
    # multinomial accepts unnormalised weights, so the top-k probs need no rescaling.
    ix = torch.multinomial(topk_probs, num_samples=1)       # index within the top-k set
    xcol = torch.gather(topk_indices, dim=-1, index=ix)     # map back to vocabulary ids
    x = torch.cat((x, xcol), dim=1)

for i in range(num_return_sequences):
  tokens = x[i, :max_len].tolist()
  decoded = enc.decode(tokens)
  print(decoded)


hi how are you , ?"

"You're a good person, !"

"You're very nice and polite. !"
hi how are you ,  I have a problem with you,  but I am not going to tell you why, I will just tell
hi how are you ,  you're not doing anything  that is going to change the world,  and you're not going to
hi how are you ,  so we can all get along and enjoy the game!
The game was released in Japan on November 3rd,
hi how are you , ????

I don't know how you're gonna get your name, ????

I don't know if


In [None]:
with open('/content/input.txt', 'r') as fh:
  text = fh.read()

# Keep a small 1000-character slice for quick tokenizer experiments.
data = text[:1000]
print(data[:100])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


In [None]:
import tiktoken
enc = tiktoken.get_encoding('gpt2')
# GPT-2 BPE token ids for the 1000-character sample.
tokens = enc.encode(data)
print(tokens[:24])   # 24 = 4 x 6, the toy batch built in the next cell

[5962, 22307, 25, 198, 8421, 356, 5120, 597, 2252, 11, 3285, 502, 2740, 13, 198, 198, 3237, 25, 198, 5248, 461, 11, 2740, 13]


In [None]:
import torch
# Take B*T + 1 = 25 tokens: the extra one lets the targets be the inputs shifted by one.
buf = torch.tensor(tokens[:24 + 1])

x = buf[:-1].view(4, 6)   # inputs,  shape (B=4, T=6)
y = buf[1:].view(4, 6)    # targets, y[b, t] is the token that follows x[b, t]
print(x)
print(y)

tensor([[ 5962, 22307,    25,   198,  8421,   356],
        [ 5120,   597,  2252,    11,  3285,   502],
        [ 2740,    13,   198,   198,  3237,    25],
        [  198,  5248,   461,    11,  2740,    13]])
tensor([[22307,    25,   198,  8421,   356,  5120],
        [  597,  2252,    11,  3285,   502,  2740],
        [   13,   198,   198,  3237,    25,   198],
        [ 5248,   461,    11,  2740,    13,   198]])


In [None]:
import tiktoken
enc = tiktoken.get_encoding("gpt2")
with open('/content/input.txt', 'r') as fh:
    text = fh.read()[:1000]
tokens = enc.encode(text)

B, T = 4, 32
# B*T inputs plus one extra token so the targets can be shifted by one.
buf = torch.tensor(tokens[:B*T + 1])
x = buf[:-1].view(B, T).to(device)   # inputs  (B, T)
y = buf[1:].view(B, T).to(device)    # targets (B, T): next token at each position


In [None]:
# Sanity check on a randomly initialised model: with near-uniform predictions
# the loss should be roughly ln(50257) ≈ 10.82.
model = GPT(GPTConfig())
model.to(device)
logits, loss = model(x, y)
print(loss)
print(logits.shape)

tensor(10.9952, device='cuda:0', grad_fn=<NllLossBackward0>)
torch.Size([4, 32, 50257])


In [None]:
# Full-corpus loader: batches of 8 sequences x 1024 tokens.
train_loader = DataLoaderLite(T=1024, B=8)

loaded 338025 tokens
1 epoch = 41 batches


# Baseline FP32 training

In [None]:
import time, math
# Baseline: plain FP32 training of a freshly initialised model for 50 steps.
# set_float32_matmul_precision is process-global, so pin full FP32 here; otherwise
# the result depends on whether the TF32 cell was run earlier in the session.
torch.set_float32_matmul_precision('highest')
train_loader = DataLoaderLite(B=8, T=1024)
model = GPT(GPTConfig())
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
for i in range(50):
  t0 = time.time()
  x, y = train_loader.next_batch()
  x = x.to(device)
  y = y.to(device)
  optimizer.zero_grad()
  logits, loss = model(x, y)
  loss.backward()
  optimizer.step()
  # CUDA kernels run asynchronously; wait for them so the step time is real.
  if str(device).startswith('cuda'):
    torch.cuda.synchronize()
  t1 = time.time()
  dt = (t1 - t0) * 1000
  # Throughput: each step processes B*T tokens (not B+T).
  tokens_per_sec = (train_loader.B * train_loader.T) / (t1 - t0)
  print(f"step{i} , loss {loss.item()} , time {dt:.2f} ms, tokens/sec {tokens_per_sec:.2f}")


loaded 338025 tokens
1 epoch = 41 batches
step0 , loss 11.137648582458496 , time 932.85 ms, tokens/sec 1106.29
step1 , loss 9.672666549682617 , time 932.55 ms, tokens/sec 1106.64
step2 , loss 8.906072616577148 , time 936.65 ms, tokens/sec 1101.79
step3 , loss 8.671490669250488 , time 935.73 ms, tokens/sec 1102.88
step4 , loss 8.507978439331055 , time 933.56 ms, tokens/sec 1105.45
step5 , loss 8.427803039550781 , time 934.19 ms, tokens/sec 1104.70
step6 , loss 8.361255645751953 , time 935.96 ms, tokens/sec 1102.61
step7 , loss 8.01619815826416 , time 935.50 ms, tokens/sec 1103.15
step8 , loss 7.736024856567383 , time 937.84 ms, tokens/sec 1100.40
step9 , loss 7.6974639892578125 , time 938.01 ms, tokens/sec 1100.20
step10 , loss 7.671797275543213 , time 936.94 ms, tokens/sec 1101.46
step11 , loss 7.495182037353516 , time 938.55 ms, tokens/sec 1099.57
step12 , loss 7.4293036460876465 , time 938.68 ms, tokens/sec 1099.41
step13 , loss 7.186164855957031 , time 939.17 ms, tokens/sec 1098.84


# TF32 matmul training (`torch.set_float32_matmul_precision('high')`)

In [None]:
import time
train_loader = DataLoaderLite(B=8, T=1024)
# Allow TF32 tensor-core matmuls for float32 tensors. This setting is
# process-global: it stays in effect for every cell run after this one.
torch.set_float32_matmul_precision('high')
model = GPT(GPTConfig())
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
for i in range(50):
  t0 = time.time()
  x, y = train_loader.next_batch()
  x = x.to(device)
  y = y.to(device)
  optimizer.zero_grad()
  logits, loss = model(x, y)
  loss.backward()
  optimizer.step()
  # CUDA kernels run asynchronously; wait for them so the step time is real.
  if str(device).startswith('cuda'):
    torch.cuda.synchronize()
  t1 = time.time()
  dt = (t1 - t0) * 1000
  # Throughput: each step processes B*T tokens (not B+T).
  tokens_per_sec = (train_loader.B * train_loader.T) / (t1 - t0)
  print(f"step{i} , loss {loss.item()} , time {dt:.2f} ms, tokens/sec {tokens_per_sec:.2f}")

loaded 338025 tokens
1 epoch = 41 batches
step0 , loss 10.963109970092773 , time 934.57 ms, tokens/sec 1104.25
step1 , loss 9.52399730682373 , time 935.68 ms, tokens/sec 1102.94
step2 , loss 8.894760131835938 , time 932.21 ms, tokens/sec 1107.05
step3 , loss 8.691394805908203 , time 936.00 ms, tokens/sec 1102.57
step4 , loss 8.559069633483887 , time 936.77 ms, tokens/sec 1101.66
step5 , loss 8.408178329467773 , time 934.04 ms, tokens/sec 1104.87
step6 , loss 8.351126670837402 , time 933.51 ms, tokens/sec 1105.51
step7 , loss 8.024391174316406 , time 934.55 ms, tokens/sec 1104.28
step8 , loss 7.745752334594727 , time 934.25 ms, tokens/sec 1104.63
step9 , loss 7.717187881469727 , time 932.82 ms, tokens/sec 1106.32
step10 , loss 7.703033924102783 , time 932.40 ms, tokens/sec 1106.83
step11 , loss 7.492473125457764 , time 933.09 ms, tokens/sec 1106.00
step12 , loss 7.4355902671813965 , time 933.26 ms, tokens/sec 1105.80
step13 , loss 7.245229244232178 , time 932.52 ms, tokens/sec 1106.68
s

In [None]:
# NOTE(review): disabled FP16 + GradScaler experiment. It has no setup of its own
# and reuses `train_loader`, `model` and `optimizer` from whichever cell ran
# before it. The saved output starts at loss ~6.75 (fresh models elsewhere start
# near 11), so it presumably continued training an already-trained model and
# is not directly comparable with the other runs.
# `torch.cuda.amp.GradScaler` / `torch.cuda.amp.autocast` are deprecated in
# recent PyTorch; the current forms are `torch.amp.GradScaler("cuda")` and
# `torch.autocast(device_type="cuda", dtype=torch.float16)`.
# from torch.cuda.amp import autocast, GradScaler

# scaler = GradScaler()

# for i in range(50):
#     t0 = time.time()
#     x, y = train_loader.next_batch()
#     x = x.to(device, non_blocking=True)
#     y = y.to(device, non_blocking=True)

#     optimizer.zero_grad(set_to_none=True)
#     with autocast(dtype=torch.float16):          # FP16 compute
#         logits, loss = model(x, y)

#     scaler.scale(loss).backward()
#     scaler.step(optimizer)
#     scaler.update()

#     # for fair timing:
#     torch.cuda.synchronize()
#     t1 = time.time()

#     dt_ms = (t1 - t0) * 1000
#     tokens = train_loader.B * train_loader.T     # not B+T
#     tps = tokens / (t1 - t0)
#     print(f"step {i}, loss {loss.item():.4f}, time {dt_ms:.2f} ms, tokens/s {tps:.1f}")


  scaler = GradScaler()
  with autocast(dtype=torch.float16):          # FP16 compute


step 0, loss 6.7530, time 419.87 ms, tokens/s 4877.7
step 1, loss 6.6962, time 223.31 ms, tokens/s 9171.3
step 2, loss 6.7090, time 221.67 ms, tokens/s 9238.8
step 3, loss 6.4120, time 221.87 ms, tokens/s 9230.6
step 4, loss 6.4810, time 221.78 ms, tokens/s 9234.4
step 5, loss 6.6615, time 223.67 ms, tokens/s 9156.3
step 6, loss 6.6375, time 223.40 ms, tokens/s 9167.3
step 7, loss 6.4719, time 220.86 ms, tokens/s 9273.0
step 8, loss 6.4212, time 220.35 ms, tokens/s 9294.4
step 9, loss 6.3914, time 223.72 ms, tokens/s 9154.3
step 10, loss 6.2609, time 221.94 ms, tokens/s 9227.9
step 11, loss 6.2923, time 221.62 ms, tokens/s 9241.2
step 12, loss 6.1004, time 222.05 ms, tokens/s 9223.3
step 13, loss 6.3277, time 222.62 ms, tokens/s 9199.5
step 14, loss 6.3900, time 223.54 ms, tokens/s 9161.5
step 15, loss 6.4900, time 223.41 ms, tokens/s 9166.9
step 16, loss 6.6758, time 222.82 ms, tokens/s 9191.4
step 17, loss 6.5570, time 222.25 ms, tokens/s 9214.7
step 18, loss 6.7468, time 221.88 ms, 

# BF16 mixed precision (autocast): not supported on the T4 GPU, but supported on the L4

In [None]:
import time
train_loader = DataLoaderLite(B=8, T=1024)
# NOTE: set_float32_matmul_precision is process-global; if the TF32 cell ran
# earlier in this session, 'high' (TF32) is still active here.
model = GPT(GPTConfig())
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
for i in range(50):
  t0 = time.time()
  x, y = train_loader.next_batch()
  x = x.to(device)
  y = y.to(device)
  optimizer.zero_grad()
  # Forward pass and loss under bfloat16 autocast; parameters stay float32.
  with torch.autocast(device_type=device, dtype=torch.bfloat16):
    logits, loss = model(x, y)

  loss.backward()
  optimizer.step()
  # CUDA kernels run asynchronously; wait for them so the step time is real.
  if str(device).startswith('cuda'):
    torch.cuda.synchronize()
  t1 = time.time()
  dt = (t1 - t0) * 1000
  # Throughput: each step processes B*T tokens (not B+T).
  tokens_per_sec = (train_loader.B * train_loader.T) / (t1 - t0)
  print(f"step{i} , loss {loss.item()} , time {dt:.2f} ms, tokens/sec {tokens_per_sec:.2f}")

loaded 338025 tokens
1 epoch = 41 batches
step0 , loss 11.059371948242188 , time 749.88 ms, tokens/sec 1376.22
step1 , loss 9.565448760986328 , time 750.11 ms, tokens/sec 1375.80
step2 , loss 9.601821899414062 , time 747.63 ms, tokens/sec 1380.36
step3 , loss 8.694334030151367 , time 747.59 ms, tokens/sec 1380.43
step4 , loss 8.53415584564209 , time 746.87 ms, tokens/sec 1381.76
step5 , loss 8.528876304626465 , time 747.66 ms, tokens/sec 1380.31
step6 , loss 8.43848991394043 , time 746.75 ms, tokens/sec 1381.98
step7 , loss 8.079373359680176 , time 747.27 ms, tokens/sec 1381.03
step8 , loss 7.797780990600586 , time 748.86 ms, tokens/sec 1378.09
step9 , loss 7.75493049621582 , time 746.60 ms, tokens/sec 1382.27
step10 , loss 7.741931915283203 , time 744.58 ms, tokens/sec 1386.02
step11 , loss 7.585170745849609 , time 748.80 ms, tokens/sec 1378.20
step12 , loss 7.533687591552734 , time 747.13 ms, tokens/sec 1381.28
step13 , loss 7.325366973876953 , time 747.58 ms, tokens/sec 1380.46
step

# `torch.compile`

In [None]:
import time
train_loader = DataLoaderLite(B=8, T=1024)
# NOTE: set_float32_matmul_precision is process-global; if the TF32 cell ran
# earlier in this session, 'high' (TF32) is still active here.
model = GPT(GPTConfig())
model.to(device)
# The first step also pays the compilation cost, so step 0 is much slower.
model = torch.compile(model)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
for i in range(50):
  t0 = time.time()
  x, y = train_loader.next_batch()
  x = x.to(device)
  y = y.to(device)
  optimizer.zero_grad()
  with torch.autocast(device_type=device, dtype=torch.bfloat16):
    logits, loss = model(x, y)

  loss.backward()
  optimizer.step()
  # CUDA kernels run asynchronously; wait for them so the step time is real.
  if str(device).startswith('cuda'):
    torch.cuda.synchronize()
  t1 = time.time()
  dt = (t1 - t0) * 1000
  # Throughput: each step processes B*T tokens (not B+T).
  tokens_per_sec = (train_loader.B * train_loader.T) / (t1 - t0)
  print(f"step{i} , loss {loss.item()} , time {dt:.2f} ms, tokens/sec {tokens_per_sec:.2f}")

loaded 338025 tokens
1 epoch = 41 batches
step0 , loss 10.98843002319336 , time 31047.56 ms, tokens/sec 33.24
step1 , loss 9.501754760742188 , time 345.52 ms, tokens/sec 2986.79
step2 , loss 8.811883926391602 , time 350.64 ms, tokens/sec 2943.21
step3 , loss 8.60702133178711 , time 349.57 ms, tokens/sec 2952.22
step4 , loss 8.446940422058105 , time 346.24 ms, tokens/sec 2980.63
step5 , loss 8.384843826293945 , time 342.70 ms, tokens/sec 3011.41
step6 , loss 8.318967819213867 , time 347.51 ms, tokens/sec 2969.73
step7 , loss 7.9779558181762695 , time 348.63 ms, tokens/sec 2960.14
step8 , loss 7.697345733642578 , time 342.86 ms, tokens/sec 3009.95
step9 , loss 7.652235984802246 , time 348.52 ms, tokens/sec 2961.09
step10 , loss 7.6265668869018555 , time 352.67 ms, tokens/sec 2926.27
step11 , loss 7.456626892089844 , time 348.47 ms, tokens/sec 2961.48
step12 , loss 7.3813934326171875 , time 353.07 ms, tokens/sec 2922.89
step13 , loss 7.15690279006958 , time 352.87 ms, tokens/sec 2924.56
s

# With FlashAttention (`F.scaled_dot_product_attention`)

In [None]:
import time
train_loader = DataLoaderLite(B=8, T=1024)
# NOTE(review): this cell's code is identical to the torch.compile cell. With the
# current CausalSelfAttention (which already calls F.scaled_dot_product_attention)
# both cells measure the same configuration.
# set_float32_matmul_precision is process-global; TF32 may still be active here.
model = GPT(GPTConfig())
model.to(device)
model = torch.compile(model)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
for i in range(50):
  t0 = time.time()
  x, y = train_loader.next_batch()
  x = x.to(device)
  y = y.to(device)
  optimizer.zero_grad()
  with torch.autocast(device_type=device, dtype=torch.bfloat16):
    logits, loss = model(x, y)

  loss.backward()
  optimizer.step()
  # CUDA kernels run asynchronously; wait for them so the step time is real.
  if str(device).startswith('cuda'):
    torch.cuda.synchronize()
  t1 = time.time()
  dt = (t1 - t0) * 1000
  # Throughput: each step processes B*T tokens (not B+T).
  tokens_per_sec = (train_loader.B * train_loader.T) / (t1 - t0)
  print(f"step{i} , loss {loss.item()} , time {dt:.2f} ms, tokens/sec {tokens_per_sec:.2f}")

loaded 338025 tokens
1 epoch = 41 batches
step0 , loss 11.025934219360352 , time 1656.06 ms, tokens/sec 623.17
step1 , loss 9.580589294433594 , time 236.78 ms, tokens/sec 4358.49
step2 , loss 9.052183151245117 , time 237.12 ms, tokens/sec 4352.19
step3 , loss 8.705053329467773 , time 247.11 ms, tokens/sec 4176.29
step4 , loss 8.576934814453125 , time 252.74 ms, tokens/sec 4083.20
step5 , loss 8.457185745239258 , time 243.87 ms, tokens/sec 4231.76
step6 , loss 8.398371696472168 , time 237.65 ms, tokens/sec 4342.49
step7 , loss 8.109476089477539 , time 244.88 ms, tokens/sec 4214.23
step8 , loss 7.8267903327941895 , time 247.62 ms, tokens/sec 4167.75
step9 , loss 7.7667951583862305 , time 246.41 ms, tokens/sec 4188.17
step10 , loss 7.740192890167236 , time 241.95 ms, tokens/sec 4265.34
step11 , loss 7.55926513671875 , time 244.23 ms, tokens/sec 4225.56
step12 , loss 7.518259048461914 , time 245.11 ms, tokens/sec 4210.43
step13 , loss 7.288984775543213 , time 246.10 ms, tokens/sec 4193.34


# "Nice" numbers: pad the vocab size from 50257 to 50304 (a multiple of 128)

In [None]:
import time
train_loader = DataLoaderLite(B=8, T=1024)
# NOTE: set_float32_matmul_precision is process-global; TF32 may still be active here.
# vocab_size padded from 50257 to 50304 (= 393 * 128) for kernel-friendly shapes;
# the GPT-2 tokenizer only emits ids below 50257, so the extra rows are never targets.
model = GPT(GPTConfig(vocab_size=50304))
model.to(device)
model = torch.compile(model)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
for i in range(50):
  t0 = time.time()
  x, y = train_loader.next_batch()
  x = x.to(device)
  y = y.to(device)
  optimizer.zero_grad()
  with torch.autocast(device_type=device, dtype=torch.bfloat16):
    logits, loss = model(x, y)

  loss.backward()
  optimizer.step()
  # CUDA kernels run asynchronously; wait for them so the step time is real.
  if str(device).startswith('cuda'):
    torch.cuda.synchronize()
  t1 = time.time()
  dt = (t1 - t0) * 1000
  # Throughput: each step processes B*T tokens (not B+T).
  tokens_per_sec = (train_loader.B * train_loader.T) / (t1 - t0)
  print(f"step{i} , loss {loss.item()} , time {dt:.2f} ms, tokens/sec {tokens_per_sec:.2f}")

loaded 338025 tokens
1 epoch = 41 batches
step0 , loss 10.879960060119629 , time 22362.86 ms, tokens/sec 46.15
step1 , loss 9.449729919433594 , time 231.75 ms, tokens/sec 4452.99
step2 , loss 8.74146842956543 , time 234.81 ms, tokens/sec 4395.00
step3 , loss 8.783228874206543 , time 240.61 ms, tokens/sec 4289.07
step4 , loss 8.411894798278809 , time 242.26 ms, tokens/sec 4259.87
step5 , loss 8.378988265991211 , time 228.68 ms, tokens/sec 4512.76
step6 , loss 8.351465225219727 , time 233.57 ms, tokens/sec 4418.41
step7 , loss 7.979005336761475 , time 234.17 ms, tokens/sec 4407.08
step8 , loss 7.680255889892578 , time 244.85 ms, tokens/sec 4214.87
step9 , loss 7.659424304962158 , time 235.12 ms, tokens/sec 4389.16
step10 , loss 7.673420429229736 , time 233.66 ms, tokens/sec 4416.72
step11 , loss 7.473594665527344 , time 239.59 ms, tokens/sec 4307.27
step12 , loss 7.435180187225342 , time 243.04 ms, tokens/sec 4246.29
step13 , loss 7.23840856552124 , time 240.80 ms, tokens/sec 4285.64
ste

# Further optimizations from GPT-3: AdamW betas/eps and gradient-norm clipping

In [None]:
import time
train_loader = DataLoaderLite(B=8, T=1024)
# NOTE: set_float32_matmul_precision is process-global; TF32 may still be active here.
model = GPT(GPTConfig(vocab_size=50304))
model.to(device)
model = torch.compile(model)
# GPT-3 style AdamW settings: beta2 = 0.95, eps = 1e-8.
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, betas=(0.9, 0.95), eps=1e-8)
for i in range(50):
  t0 = time.time()
  x, y = train_loader.next_batch()
  x = x.to(device)
  y = y.to(device)
  optimizer.zero_grad()
  with torch.autocast(device_type=device, dtype=torch.bfloat16):
    logits, loss = model(x, y)

  loss.backward()
  # Clip the global gradient norm to 1.0; the returned value is the norm before clipping.
  norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
  optimizer.step()
  # CUDA kernels run asynchronously; wait for them so the step time is real.
  if str(device).startswith('cuda'):
    torch.cuda.synchronize()
  t1 = time.time()
  dt = (t1 - t0) * 1000
  # Throughput: each step processes B*T tokens (not B+T).
  tokens_per_sec = (train_loader.B * train_loader.T) / (t1 - t0)
  print(f"step{i} , loss {loss.item()} , time {dt:.2f} ms, tokens/sec {tokens_per_sec:.2f}, norm {norm :.4f}")

loaded 338025 tokens
1 epoch = 41 batches
step0 , loss 10.949409484863281 , time 365.43 ms, tokens/sec 2824.08, norm 27.0177
step1 , loss 9.585895538330078 , time 232.29 ms, tokens/sec 4442.69, norm 8.7193
step2 , loss 9.138595581054688 , time 237.14 ms, tokens/sec 4351.83, norm 9.0948
step3 , loss 8.827247619628906 , time 238.49 ms, tokens/sec 4327.27, norm 4.1152
step4 , loss 8.654638290405273 , time 250.53 ms, tokens/sec 4119.30, norm 3.6695
step5 , loss 8.454200744628906 , time 249.86 ms, tokens/sec 4130.25, norm 2.6055
step6 , loss 8.382661819458008 , time 241.86 ms, tokens/sec 4267.00, norm 2.0046
step7 , loss 8.051900863647461 , time 238.51 ms, tokens/sec 4326.84, norm 2.3224
step8 , loss 7.736559867858887 , time 246.16 ms, tokens/sec 4192.39, norm 1.6099
step9 , loss 7.682575225830078 , time 247.60 ms, tokens/sec 4168.06, norm 1.8356
step10 , loss 7.654802322387695 , time 246.47 ms, tokens/sec 4187.05, norm 1.9479
step11 , loss 7.438798904418945 , time 244.07 ms, tokens/sec 422

# Learning-rate schedule: linear warmup + cosine decay

In [None]:
import time, math
device = 'cuda' if torch.cuda.is_available() else 'cpu'
train_loader = DataLoaderLite(B=8, T=1024)
# NOTE: set_float32_matmul_precision is process-global; TF32 may still be active here.
model = GPT(GPTConfig(vocab_size=50304))
model.to(device).train()
model = torch.compile(model)

max_lr = 6e-4
min_lr = max_lr * 0.1

warmup_steps = 10
max_steps = 50

def get_lr(it):
  """Learning rate for optimizer step ``it`` (0-based).

  Linear warmup to max_lr over the first ``warmup_steps`` steps, cosine decay
  from max_lr to min_lr between warmup_steps and max_steps, then min_lr.
  """
  if it < warmup_steps:
    return max_lr * (it + 1) / warmup_steps
  if it > max_steps:
    return min_lr

  decay_ratio = (it - warmup_steps) / (max_steps - warmup_steps)
  assert 0 <= decay_ratio <= 1
  coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))   # goes 1 -> 0
  return min_lr + coeff * (max_lr - min_lr)

# The lr given here is a placeholder: get_lr(step) is written into every param
# group before each optimizer.step().
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, betas=(0.9, 0.95), eps=1e-8)
for step in range(max_steps):
  t0 = time.time()
  x, y = train_loader.next_batch()
  x = x.to(device)
  y = y.to(device)
  optimizer.zero_grad()
  with torch.autocast(device_type=device, dtype=torch.bfloat16):
    logits, loss = model(x, y)

  loss.backward()
  norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

  lr = get_lr(step)
  for param_group in optimizer.param_groups:
    param_group['lr'] = lr

  optimizer.step()
  # device may be 'cpu' (see the fallback above); only CUDA needs a sync.
  if str(device).startswith('cuda'):
    torch.cuda.synchronize()
  t1 = time.time()
  dt = (t1 - t0) * 1000
  tokens_per_sec = (train_loader.B * train_loader.T) / (t1 - t0)
  print(f"step{step} , loss {loss.item()} , time {dt:.2f} ms, tokens/sec {tokens_per_sec:.2f}, norm {norm :.4f}, lr {lr}")

loaded 338025 tokens
1 epoch = 41 batches
step0 , loss 11.047126770019531 , time 235.80 ms, tokens/sec 34741.90, norm 29.1180, lr 5.9999999999999995e-05
step1 , loss 9.64242172241211 , time 234.11 ms, tokens/sec 34991.80, norm 9.5832, lr 0.00011999999999999999
step2 , loss 9.007974624633789 , time 235.72 ms, tokens/sec 34753.32, norm 5.7911, lr 0.00017999999999999998
step3 , loss 9.886116027832031 , time 244.37 ms, tokens/sec 33523.46, norm 10.5479, lr 0.00023999999999999998
step4 , loss 9.111015319824219 , time 253.28 ms, tokens/sec 32344.24, norm 4.3315, lr 0.0003
step5 , loss 8.721638679504395 , time 248.93 ms, tokens/sec 32909.04, norm 3.1822, lr 0.00035999999999999997
step6 , loss 8.648611068725586 , time 237.93 ms, tokens/sec 34430.01, norm 4.0620, lr 0.00041999999999999996
step7 , loss 8.237208366394043 , time 240.11 ms, tokens/sec 34118.01, norm 4.0748, lr 0.00047999999999999996
step8 , loss 7.776317119598389 , time 245.07 ms, tokens/sec 33427.54, norm 1.7705, lr 0.000539999999

# Gradient accumulation

In [None]:
import time, math
device = 'cuda' if torch.cuda.is_available() else 'cpu'

total_batch_size = 524288   # tokens per optimizer step (2**19)
B = 8                       # micro-batch size (sequences)
T = 1024                    # sequence length
assert total_batch_size % (B * T) == 0, "total_batch_size must be divisible by B*T"
grad_acc_steps = total_batch_size // (B * T)
print(f"total batch size: {total_batch_size} tokens, grad accumulation steps: {grad_acc_steps}")

train_loader = DataLoaderLite(B=B, T=T)
# NOTE: set_float32_matmul_precision is process-global; TF32 may still be active here.
model = GPT(GPTConfig(vocab_size=50304))
model.to(device).train()
model = torch.compile(model)

max_lr = 6e-4
min_lr = max_lr * 0.1

warmup_steps = 10
max_steps = 50

def get_lr(it):
  """Learning rate for optimizer step ``it``: linear warmup, cosine decay, then min_lr."""
  if it < warmup_steps:
    return max_lr * (it + 1) / warmup_steps
  if it > max_steps:
    return min_lr

  decay_ratio = (it - warmup_steps) / (max_steps - warmup_steps)
  assert 0 <= decay_ratio <= 1
  coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))   # goes 1 -> 0
  return min_lr + coeff * (max_lr - min_lr)

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, betas=(0.9, 0.95), eps=1e-8)
for step in range(max_steps):
  t0 = time.time()
  # Clear gradients ONCE per optimizer step: backward() in the loop below adds
  # each micro-batch's gradient into .grad, which is the accumulation.
  optimizer.zero_grad()
  loss_accum = 0.0
  for micro_step in range(grad_acc_steps):
    # Each micro-step needs a fresh batch; reusing one batch would just repeat
    # the same gradient grad_acc_steps times.
    x, y = train_loader.next_batch()
    x = x.to(device)
    y = y.to(device)
    with torch.autocast(device_type=device, dtype=torch.bfloat16):
      logits, loss = model(x, y)
    # cross_entropy averages over one micro-batch; dividing by grad_acc_steps
    # makes the summed gradient the mean over all grad_acc_steps * B * T tokens.
    loss = loss / grad_acc_steps
    loss_accum += loss.detach()
    loss.backward()

  norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

  lr = get_lr(step)
  for param_group in optimizer.param_groups:
    param_group['lr'] = lr

  optimizer.step()
  if str(device).startswith('cuda'):
    torch.cuda.synchronize()
  t1 = time.time()
  dt = (t1 - t0) * 1000
  tokens_per_sec = (train_loader.B * train_loader.T * grad_acc_steps) / (t1 - t0)
  print(f"step{step} , loss {loss_accum.item()} , time {dt:.2f} ms, tokens/sec {tokens_per_sec:.2f}, norm {norm :.4f}, lr {lr}")

loaded 338025 tokens
1 epoch = 41 batches
step0 , loss 11.004355430603027 , time 13984.99 ms, tokens/sec 37489.33, norm 0.4644, lr 5.9999999999999995e-05
step1 , loss 9.662590980529785 , time 13962.02 ms, tokens/sec 37551.01, norm 0.1462, lr 0.00011999999999999999
step2 , loss 8.993795394897461 , time 13593.30 ms, tokens/sec 38569.59, norm 0.0883, lr 0.00017999999999999998
step3 , loss 9.354405403137207 , time 13479.50 ms, tokens/sec 38895.20, norm 0.1276, lr 0.00023999999999999998
step4 , loss 8.83443832397461 , time 13311.05 ms, tokens/sec 39387.42, norm 0.0630, lr 0.0003
step5 , loss 8.700010299682617 , time 13271.96 ms, tokens/sec 39503.44, norm 0.0363, lr 0.00035999999999999997
step6 , loss 8.677406311035156 , time 13321.60 ms, tokens/sec 39356.24, norm 0.0684, lr 0.00041999999999999996
step7 , loss 8.201203346252441 , time 13417.04 ms, tokens/sec 39076.28, norm 0.0457, lr 0.00047999999999999996
step8 , loss 7.775548934936523 , time 13458.12 ms, tokens/sec 38957.01, norm 0.0274, l