In [1]:
import json, types

CONFIG_PATH = "models--HuggingFaceTB--SmolLM2-135M/snapshots/93efa2f097d58c2a74874c7e644dbc9b0cee75a2/config.json"   # <- change this
MODEL_DIR   = "models--HuggingFaceTB--SmolLM2-135M/snapshots/93efa2f097d58c2a74874c7e644dbc9b0cee75a2/"              # folder containing model.safetensors

with open(CONFIG_PATH, "r") as f:
    cfg_dict = json.load(f)

# turn into attribute-style object
config = types.SimpleNamespace(**cfg_dict)

config


namespace(architectures=['LlamaForCausalLM'],
          attention_bias=False,
          attention_dropout=0.0,
          bos_token_id=0,
          eos_token_id=0,
          hidden_act='silu',
          hidden_size=576,
          initializer_range=0.041666666666666664,
          intermediate_size=1536,
          is_llama_config=True,
          max_position_embeddings=8192,
          model_type='llama',
          num_attention_heads=9,
          num_hidden_layers=30,
          num_key_value_heads=3,
          pretraining_tp=1,
          rms_norm_eps=1e-05,
          rope_interleaved=False,
          rope_scaling=None,
          rope_theta=100000,
          tie_word_embeddings=True,
          torch_dtype='bfloat16',
          transformers_version='4.40.1',
          use_cache=True,
          vocab_size=49152)

In [2]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
# increasing the batch size x2 + grad_accum = 4 

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        return torch.nn.functional.rms_norm(
            x, 
            normalized_shape=(x.size(-1),),
            weight=self.weight,
            eps=self.eps,
        )

class RotaryEmbedding(nn.Module):
    """
    HF-style RoPE, theta = rope_theta.
    Applies RoPE to first head_dim dims (which is full head_dim in LLaMA).
    """
    def __init__(self, head_dim, base=10000):
        super().__init__()
        self.head_dim = head_dim
        self.base = base
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, seq_len, device, dtype):
        # positions: (seq_len,)
        t = torch.arange(seq_len, device=device).float()
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)  # (T, head_dim/2)
        emb = torch.cat([freqs, freqs], dim=-1)            # (T, head_dim)
        sin = emb.sin()[None, None, :, :]                 # (1,1,T,head_dim)
        cos = emb.cos()[None, None, :, :]                 # (1,1,T,head_dim)
        return sin.to(dtype), cos.to(dtype)


def apply_rope(x, sin, cos):
    """
    x: (B, n_heads, T, head_dim)
    sin/cos: (1,1,T,head_dim)
    """
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    # rotate pairs
    x_rot = torch.stack((-x2, x1), dim=-1).reshape_as(x)
    return x * cos + x_rot * sin

class LlamaAttentionSDPA(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads         # Hq
        self.num_kv_heads = config.num_key_value_heads      # Hkv
        self.head_dim = self.hidden_size // self.num_heads

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

        self.rope = RotaryEmbedding(self.head_dim, base=config.rope_theta)
        self.attn_dropout = float(getattr(config, "attention_dropout", 0.0))

    def forward(self, x, attention_mask=None):
        B, T, C = x.shape
        device, dtype = x.device, x.dtype

        # Project to Q, K, V  -> (B,T,H,D)
        q = self.q_proj(x).reshape(B, T, self.num_heads, self.head_dim)
        k = self.k_proj(x).reshape(B, T, self.num_kv_heads, self.head_dim)
        v = self.v_proj(x).reshape(B, T, self.num_kv_heads, self.head_dim)

        # RoPE wants (B,H,T,D)
        sin, cos = self.rope(T, device=device, dtype=dtype)
        q = apply_rope(q.transpose(1, 2), sin, cos)  # (B,Hq,T,D)
        k = apply_rope(k.transpose(1, 2), sin, cos)  # (B,Hkv,T,D)
        v = v.transpose(1, 2)                        # (B,Hkv,T,D)  ‚úÖ IMPORTANT

        # GQA: expand K,V heads to match Q heads
        if self.num_kv_heads != self.num_heads:
            repeat_factor = self.num_heads // self.num_kv_heads
            k = k.repeat_interleave(repeat_factor, dim=1)  # (B,Hq,T,D)
            v = v.repeat_interleave(repeat_factor, dim=1)  # (B,Hq,T,D)

        # SDPA fused causal attention
        out = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=None,
            dropout_p=self.attn_dropout if self.training else 0.0,
            is_causal=True,
        )  # (B,Hq,T,D)

        out = out.transpose(1, 2).reshape(B, T, C)
        return self.o_proj(out)


class LlamaMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.up_proj   = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
        self.act = nn.SiLU()

    def forward(self, x):
        # LLaMA MLP = SiLU(gate_proj(x)) * up_proj(x) then down_proj
        return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))

class LlamaDecoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.self_attn = LlamaAttentionSDPA(config)
        #self.self_attn = LlamaAttention(config)
        self.mlp = LlamaMLP(config)

        self.input_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps)

    def forward(self, x, attention_mask=None):
        # attention
        h = x + self.self_attn(self.input_layernorm(x), attention_mask=attention_mask)
        # mlp
        out = h + self.mlp(self.post_attention_layernorm(h))
        return out

class LlamaModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps)

    def forward(self, input_ids, attention_mask=None):
        x = self.embed_tokens(input_ids)
        for layer in self.layers:
            x = layer(x, attention_mask=attention_mask)
        return self.norm(x)

class LlamaForCausalLM(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.model = LlamaModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # tie embeddings if needed
        if getattr(config, "tie_word_embeddings", False):
            self.lm_head.weight = self.model.embed_tokens.weight

    def forward(self, input_ids, labels=None, attention_mask=None):
        # transformer hidden states
        h = self.model(input_ids, attention_mask=attention_mask)
        logits = self.lm_head(h)

        # INFERENCE PATH
        if labels is None:
            return logits, None

        # TRAINING PATH
        # shift for autoregressive loss
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = labels[:, 1:].contiguous()

        loss = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
        )

        return logits, loss

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if config.torch_dtype == "bfloat16" else torch.float16

# CHANGES IN CURRENT CODE
#torch.set_float32_matmul_precision('high')
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True


model = LlamaForCausalLM(config).to(device=device, dtype=dtype)
print (model.eval())

sum(p.numel() for p in model.parameters()) / 1e6

from safetensors.torch import load_file
import os

weights_path = os.path.join(MODEL_DIR, "model.safetensors")
state = load_file(weights_path)

layer = model.model.layers[0].self_attn.q_proj.weight

print("Before:", layer.view(-1)[:5])

# load into our model
missing, unexpected = model.load_state_dict(state, strict=False)

layer = model.model.layers[0].self_attn.q_proj.weight

print("After:", layer.view(-1)[:5])

print("Missing keys:", missing)
print("Unexpected keys:", unexpected)

# Fix for tied embeddings
if "lm_head.weight" in missing:
    model.lm_head.weight = model.model.embed_tokens.weight
    print("‚Üí Tied lm_head.weight to embed_tokens.weight")

# print a few to debug if any mismatch
print("Example missing:", missing[:10])
print("Example unexpected:", unexpected[:10])

from transformers import AutoTokenizer

TOKENIZER_DIR = "models--HuggingFaceTB--SmolLM2-135M/snapshots/93efa2f097d58c2a74874c7e644dbc9b0cee75a2/"  # where tokenizer.json is
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_DIR)

with open("input.txt", "r") as f:
    text = f.read()

from transformers import AutoTokenizer
import torch

class DataLoaderLite:
    def __init__(self, B, T, tokenizer_path=TOKENIZER_DIR):
        self.B = B
        self.T = T

        # load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

        # load raw text
        with open('input.txt', 'r', encoding='utf-8') as f:
            text = f.read()

        # tokenize entire dataset at once
        tokens = self.tokenizer(text).input_ids
        self.tokens = torch.tensor(tokens, dtype=torch.long)
        
        print(f"Loaded {len(self.tokens)} tokens")
        print(f"1 epoch = {len(self.tokens) // (B * T)} batches")

        # pointer
        self.current_position = 0
    
    def next_batch(self):
        B, T = self.B, self.T

        # get (B*T) + 1 tokens for labels
        buf = self.tokens[self.current_position : self.current_position + B*T + 1]

        # reset if we hit the end
        if len(buf) < B*T + 1:
            self.current_position = 0
            buf = self.tokens[: B*T + 1]

        x = buf[:-1].view(B, T)  # (B, T)
        y = buf[1:].view(B, T)   # (B, T)

        self.current_position += B*T
        return x, y


  self.setter(val)
    Found GPU0 NVIDIA GB10 which is of cuda capability 12.1.
    Minimum and Maximum cuda capability supported by this version of PyTorch is
    (8.0) - (12.0)
    


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(49152, 576)
    (layers): ModuleList(
      (0-29): 30 x LlamaDecoderLayer(
        (self_attn): LlamaAttentionSDPA(
          (q_proj): Linear(in_features=576, out_features=576, bias=False)
          (k_proj): Linear(in_features=576, out_features=192, bias=False)
          (v_proj): Linear(in_features=576, out_features=192, bias=False)
          (o_proj): Linear(in_features=576, out_features=576, bias=False)
          (rope): RotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=576, out_features=1536, bias=False)
          (up_proj): Linear(in_features=576, out_features=1536, bias=False)
          (down_proj): Linear(in_features=1536, out_features=576, bias=False)
          (act): SiLU()
        )
        (input_layernorm): RMSNorm()
        (post_attention_layernorm): RMSNorm()
      )
    )
    (norm): RMSNorm()
  )
  (lm_head): Linear(in_features=576, out_features=49152,

In [4]:
import os
import torch
from torch import nn
import torch.nn.functional as F
import time
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

# ------------------------------
# CONFIGURATION
# ------------------------------

device = "cuda"

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5000)

# ‚úÖ TensorBoard Writer
writer = SummaryWriter(log_dir="runs/llama_training")

train_loader = DataLoaderLite(B=16, T=1024, tokenizer_path=TOKENIZER_DIR)

CHECKPOINT_DIR = "checkpoints"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

grad_accum = 4
max_steps = 5000
save_interval = 500
predict_interval = 500

# ------------------------------
# RESUME TRAINING (if exists)
# ------------------------------

RESUME_FROM = f"{CHECKPOINT_DIR}/latest.pt"
start_step = 0

if os.path.exists(RESUME_FROM):
    ckpt = torch.load(RESUME_FROM, map_location="cuda")

    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    scheduler.load_state_dict(ckpt["scheduler"])
    start_step = ckpt["step"]

    torch.set_rng_state(ckpt["cpu_rng"].cpu())
    torch.cuda.set_rng_state_all(ckpt["cuda_rng"])

    print(f"üîÑ Resumed from step: {start_step}")
else:
    print("‚û°Ô∏è Starting training from scratch")

optimizer.zero_grad()

# ------------------------------
# TRAIN LOOP
# ------------------------------

pbar = tqdm(range(start_step, max_steps), desc="Training")

for step in pbar:

    t0 = time.time()

    x, y = train_loader.next_batch()
    x, y = x.to(device), y.to(device)

    with torch.autocast(device_type=device, dtype=torch.bfloat16):
        logits, loss = model(x, y)
        loss = loss / grad_accum

    loss.backward()

    if ((step + 1) % grad_accum) == 0:
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()

    dt_ms = (time.time() - t0) * 1000
    tokens_processed = train_loader.B * train_loader.T
    tokens_per_sec = tokens_processed / (dt_ms / 1000)

    current_lr = scheduler.get_last_lr()[0]

    # ----------------- TensorBoard Logging -----------------
    writer.add_scalar("loss/train", loss.item(), step)
    writer.add_scalar("learning_rate", current_lr, step)
    writer.add_scalar("tokens_per_sec", tokens_per_sec, step)

    # --- tqdm progress bar display ---
    pbar.set_postfix({
        "loss": f"{loss.item():.4f}",
        "lr": f"{current_lr:.6f}",
        "tok/s": f"{tokens_per_sec:.1f}",
        "dt(ms)": f"{dt_ms:.2f}"
    })

    # ---------------------- TEXT GENERATION ---------------------------
    if (step + 1) % predict_interval == 0:
        model.eval()
        prompt = "The meaning of life is"
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

        with torch.no_grad():
            out = model(input_ids)[0]
            next_token = out[:, -1].argmax(-1)
            text = tokenizer.decode(next_token)

        writer.add_text("generation/sample", f"{prompt} {text}", step)
        print(f"\nüìù Sample (step {step}): {prompt} {text}")
        model.train()

    # ------------------------- CHECKPOINTING ---------------------------
    if (step + 1) % save_interval == 0:
        ckpt_path = f"{CHECKPOINT_DIR}/step_{step+1}.pt"

        torch.save({
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
            "step": step + 1,
            "cpu_rng": torch.get_rng_state(),
            "cuda_rng": torch.cuda.get_rng_state_all(),
        }, ckpt_path)

        print(f"\nüíæ Saved checkpoint: {ckpt_path}")

writer.close()
print("üéâ Training finished!")


Token indices sequence length is longer than the specified maximum sequence length for this model (341094 > 8192). Running this sequence through the model will result in indexing errors


Loaded 341094 tokens
1 epoch = 20 batches
‚û°Ô∏è Starting training from scratch


Training:  10%|‚ñâ         | 499/5000 [06:28<57:55,  1.30it/s, loss=0.9962, lr=0.000300, tok/s=27851.7, dt(ms)=588.26]  


üìù Sample (step 499): The meaning of life is  to


Training:  10%|‚ñà         | 500/5000 [06:29<1:22:40,  1.10s/it, loss=0.9962, lr=0.000300, tok/s=27851.7, dt(ms)=588.26]


üíæ Saved checkpoint: checkpoints/step_500.pt


Training:  20%|‚ñà‚ñâ        | 999/5000 [12:57<51:31,  1.29it/s, loss=0.5026, lr=0.000298, tok/s=27571.7, dt(ms)=594.23]  


üìù Sample (step 999): The meaning of life is  to


Training:  20%|‚ñà‚ñà        | 1000/5000 [12:58<1:13:17,  1.10s/it, loss=0.5026, lr=0.000298, tok/s=27571.7, dt(ms)=594.23]


üíæ Saved checkpoint: checkpoints/step_1000.pt


Training:  30%|‚ñà‚ñà‚ñâ       | 1499/5000 [19:26<45:10,  1.29it/s, loss=0.0528, lr=0.000296, tok/s=27382.4, dt(ms)=598.34]  


üìù Sample (step 1499): The meaning of life is  to


Training:  30%|‚ñà‚ñà‚ñà       | 1500/5000 [19:27<1:03:02,  1.08s/it, loss=0.0528, lr=0.000296, tok/s=27382.4, dt(ms)=598.34]


üíæ Saved checkpoint: checkpoints/step_1500.pt


Training:  40%|‚ñà‚ñà‚ñà‚ñâ      | 1999/5000 [25:54<38:36,  1.30it/s, loss=0.0192, lr=0.000293, tok/s=27700.5, dt(ms)=591.47]  


üìù Sample (step 1999): The meaning of life is  to


Training:  40%|‚ñà‚ñà‚ñà‚ñà      | 2000/5000 [25:55<53:46,  1.08s/it, loss=0.0192, lr=0.000293, tok/s=27700.5, dt(ms)=591.47]


üíæ Saved checkpoint: checkpoints/step_2000.pt


Training:  50%|‚ñà‚ñà‚ñà‚ñà‚ñâ     | 2499/5000 [32:23<32:08,  1.30it/s, loss=0.0129, lr=0.000289, tok/s=27894.7, dt(ms)=587.35]


üìù Sample (step 2499): The meaning of life is  to


Training:  50%|‚ñà‚ñà‚ñà‚ñà‚ñà     | 2500/5000 [32:24<44:25,  1.07s/it, loss=0.0129, lr=0.000289, tok/s=27894.7, dt(ms)=587.35]


üíæ Saved checkpoint: checkpoints/step_2500.pt


Training:  60%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñâ    | 2999/5000 [38:52<25:48,  1.29it/s, loss=0.0104, lr=0.000284, tok/s=27902.0, dt(ms)=587.20]


üìù Sample (step 2999): The meaning of life is  to


Training:  60%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà    | 3000/5000 [38:53<35:54,  1.08s/it, loss=0.0104, lr=0.000284, tok/s=27902.0, dt(ms)=587.20]


üíæ Saved checkpoint: checkpoints/step_3000.pt


Training:  70%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñâ   | 3499/5000 [45:21<19:19,  1.29it/s, loss=0.0090, lr=0.000278, tok/s=27653.7, dt(ms)=592.47]


üìù Sample (step 3499): The meaning of life is  to


Training:  70%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà   | 3500/5000 [45:22<27:15,  1.09s/it, loss=0.0090, lr=0.000278, tok/s=27653.7, dt(ms)=592.47]


üíæ Saved checkpoint: checkpoints/step_3500.pt


Training:  80%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñâ  | 3999/5000 [51:49<12:54,  1.29it/s, loss=0.0081, lr=0.000271, tok/s=28186.1, dt(ms)=581.28]


üìù Sample (step 3999): The meaning of life is  to


Training:  80%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà  | 4000/5000 [51:50<17:50,  1.07s/it, loss=0.0081, lr=0.000271, tok/s=28186.1, dt(ms)=581.28]


üíæ Saved checkpoint: checkpoints/step_4000.pt


Training:  90%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñâ | 4499/5000 [58:18<06:28,  1.29it/s, loss=0.0075, lr=0.000264, tok/s=27964.3, dt(ms)=585.89]


üìù Sample (step 4499): The meaning of life is  to


Training:  90%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà | 4500/5000 [58:19<09:08,  1.10s/it, loss=0.0075, lr=0.000264, tok/s=27964.3, dt(ms)=585.89]


üíæ Saved checkpoint: checkpoints/step_4500.pt


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñâ| 4999/5000 [1:04:47<00:00,  1.30it/s, loss=0.0073, lr=0.000256, tok/s=27943.3, dt(ms)=586.33]


üìù Sample (step 4999): The meaning of life is  to


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5000/5000 [1:04:48<00:00,  1.29it/s, loss=0.0073, lr=0.000256, tok/s=27943.3, dt(ms)=586.33]


üíæ Saved checkpoint: checkpoints/step_5000.pt
üéâ Training finished!





In [8]:
# ------------------------------
# CONFIGURATION
# ------------------------------
from tqdm import tqdm

device = "cuda"

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5000)

train_loader = DataLoaderLite(B=16, T=1024, tokenizer_path=TOKENIZER_DIR)

CHECKPOINT_DIR = "checkpoints"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

grad_accum = 4
max_steps = 5500
save_interval = 500
predict_interval = 500

start_step = 0

# ------------------------------
# BUILD MODEL
# ------------------------------
#model = LlamaForCausalLM(config).to(device) # already instanced
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

# ------------------------------
# RESUME TRAINING (if exists)
# ------------------------------
# ------------------------------
# RESUME TRAINING (SAFE VERSION)
# ------------------------------
RESUME_FROM = f"{CHECKPOINT_DIR}/step_5000.pt"
start_step = 0

if os.path.exists(RESUME_FROM):

    ckpt = torch.load(RESUME_FROM, map_location="cuda")
    print(f"üîÑ Loading checkpoint: {RESUME_FROM}")

    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    scheduler.load_state_dict(ckpt["scheduler"])
    start_step = ckpt["step"]

    # ------------------------------
    # SAFE RNG RESTORE
    # ------------------------------

    # CPU RNG
    cpu_rng = ckpt["cpu_rng"]
    if cpu_rng.device != torch.device("cpu"):
        cpu_rng = cpu_rng.cpu()
    torch.set_rng_state(cpu_rng)

    # CUDA RNG ‚Äî FIX FOR YOUR ERROR
    cuda_rng_list = ckpt["cuda_rng"]
    
    # ensure list
    if isinstance(cuda_rng_list, torch.Tensor):
        cuda_rng_list = [cuda_rng_list]

    safe_cuda_states = []
    for s in cuda_rng_list:
        if not isinstance(s, torch.ByteTensor):
            s = s.to(torch.uint8)
        if s.device.type != "cpu":
            s = s.cpu()
        safe_cuda_states.append(s)

    torch.cuda.set_rng_state_all(safe_cuda_states)

    print(f"üîÑ Resumed from step {start_step}")

else:
    print("‚û°Ô∏è Starting training from scratch")

# -----------------------------------------------------------------------------

# ------------------------------
# TRAINING LOOP (with timing + throughput)
# ------------------------------
from tqdm import tqdm

optimizer.zero_grad()

pbar = tqdm(range(start_step, max_steps), desc="Training")

for step in pbar:

    t0 = time.time()

    x, y = train_loader.next_batch()
    x, y = x.to(device), y.to(device)

    with torch.autocast(device_type=device, dtype=torch.bfloat16):
        logits, loss = model(x, y)
        loss = loss / grad_accum

    loss.backward()

    # Apply optimizer step only every grad_accum steps
    if ((step + 1) % grad_accum) == 0:
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()

    # --- SPEED METRICS ---
    dt_ms = (time.time() - t0) * 1000
    tokens_processed = train_loader.B * train_loader.T
    tokens_per_sec = tokens_processed / (dt_ms / 1000)

    current_lr = scheduler.get_last_lr()[0]

    # --- TensorBoard Logging ---
    writer.add_scalar("loss/train", loss.item(), step)
    writer.add_scalar("learning_rate", current_lr, step)
    writer.add_scalar("tokens_per_sec", tokens_per_sec, step)

    # --- tqdm progress bar display ---
    pbar.set_postfix({
        "loss": f"{loss.item():.4f}",
        "lr": f"{current_lr:.6f}",
        "tok/s": f"{tokens_per_sec:.1f}",
        "dt(ms)": f"{dt_ms:.2f}"
    })

    # ---------------------- TEXT GENERATION ---------------------------
    if (step + 1) % predict_interval == 0:
        model.eval()
        prompt = "The meaning of life is"
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

        with torch.no_grad():
            out = model(input_ids)[0]
            next_token = out[:, -1].argmax(-1)
            text = tokenizer.decode(next_token)

        writer.add_text("generation/sample", f"{prompt} {text}", step)
        print(f"\nüìù Sample (step {step}): {prompt} {text}")
        model.train()
    # ------------------------------------------------------------------

    # ------------------------- CHECKPOINTING ---------------------------
    if (step + 1) % save_interval == 0:
        ckpt_path = f"{CHECKPOINT_DIR}/step_{step+1}.pt"

        torch.save({
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
            "step": step + 1,
            "cpu_rng": torch.get_rng_state(),
            "cuda_rng": torch.cuda.get_rng_state_all(),
        }, ckpt_path)

        print(f"\nüíæ Saved checkpoint: {ckpt_path}")

writer.close()
print("üéâ Training finished!")


Token indices sequence length is longer than the specified maximum sequence length for this model (341094 > 8192). Running this sequence through the model will result in indexing errors


Loaded 341094 tokens
1 epoch = 20 batches


TypeError: RNG state must be a torch.ByteTensor