# Importing necessary libraries for SLM Training

In [39]:
import torch
import torch.nn as nn

In [40]:
import os, time, math, random
from tokenizers import Tokenizer
from torch.utils.data import Dataset, DataLoader
from tokenizers import Tokenizer
import torch.nn.functional as F

In [41]:
from tqdm import tqdm

# Root Mean Square Normalization

In [42]:
class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-5):  
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Force float32 for norm calculation to avoid overflow
        x_f32 = x.float()
        norm = x_f32.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(norm + self.eps).type_as(x) * self.weight

# Rotary Positional Embeddings (RoPE)

In [43]:
def rotate_half(x):
    x1 = x[..., :x.shape[-1]//2]
    x2 = x[..., x.shape[-1]//2:]
    return torch.cat([-x2, x1], dim=-1)

def apply_rope(x, rope_sin, rope_cos):
    return (x * rope_cos) + (rotate_half(x) * rope_sin)

# Multi-Head Self Attention (Causal)

In [44]:
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, dim, num_heads):
        super().__init__()
        assert dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, mask, rope_sin, rope_cos):
        B, T, C = x.shape

        qkv = self.qkv(x)  # (B, T, 3*C)
        q, k, v = qkv.chunk(3, dim=-1)

        # reshape -> (B, heads, T, head_dim)
        q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

        # apply RoPE to q,k
        q = apply_rope(q, rope_sin, rope_cos)
        k = apply_rope(k, rope_sin, rope_cos)

        att = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        att = att.masked_fill(mask == 0, float('-inf'))
        att = torch.softmax(att, dim=-1)

        y = att @ v  # (B, heads, T, head_dim)
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        return self.proj(y)


# SwiGLU FFN

In [45]:
class SwiGLU(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim)
        self.w2 = nn.Linear(dim, hidden_dim)
        self.w3 = nn.Linear(hidden_dim, dim)

    def forward(self, x):
        return self.w3(torch.nn.functional.silu(self.w1(x)) * self.w2(x))


# Transformer Block

In [46]:
class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, ffn_dim):
        super().__init__()
        self.norm1 = RMSNorm(dim)
        self.attn = MultiHeadSelfAttention(dim, num_heads)
        self.norm2 = RMSNorm(dim)
        self.ffn = SwiGLU(dim, ffn_dim)

    def forward(self, x, mask, rope_sin, rope_cos):
        x = x + self.attn(self.norm1(x), mask, rope_sin, rope_cos)
        x = x + self.ffn(self.norm2(x))
        return x


# Decoder Model

In [47]:
class TransformerLM(nn.Module):
    def __init__(self, vocab_size, dim, num_layers, num_heads, ffn_dim, max_seq_len=2048):
        super().__init__()

        self.vocab_size = vocab_size
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.max_seq_len = max_seq_len

        self.token_emb = nn.Embedding(vocab_size, dim)

        self.blocks = nn.ModuleList([
            TransformerBlock(dim, num_heads, ffn_dim)
            for _ in range(num_layers)
        ])

        self.norm = RMSNorm(dim)
        self.lm_head = nn.Linear(dim, vocab_size, bias=False)

        # weight tying
        self.lm_head.weight = self.token_emb.weight

        pos = torch.arange(max_seq_len)  # [T]
        freqs = 1.0 / (10000 ** (torch.arange(0, self.head_dim, 2) / self.head_dim))
        # freqs len = head_dim/2

        sinusoid = torch.einsum("i,j->ij", pos, freqs)  # [T, head_dim/2]

        rope_sin = sinusoid.sin()   # [T, head_dim/2]
        rope_cos = sinusoid.cos()   # [T, head_dim/2]

        # Expand to: [1, 1, T, head_dim]
        rope_sin = torch.cat([rope_sin, rope_sin], dim=-1)
        rope_cos = torch.cat([rope_cos, rope_cos], dim=-1)

        self.register_buffer("rope_sin", rope_sin.unsqueeze(0).unsqueeze(0))
        self.register_buffer("rope_cos", rope_cos.unsqueeze(0).unsqueeze(0))
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, mask):
        B, T = idx.shape

        x = self.token_emb(idx)

        # Correct slice: [1,1,T,head_dim]
        rope_sin = self.rope_sin[:, :, :T, :]
        rope_cos = self.rope_cos[:, :, :T, :]

        for blk in self.blocks:
            x = blk(x, mask, rope_sin, rope_cos)

        x = self.norm(x)
        logits = self.lm_head(x)

        return logits


# Causal Mask

In [48]:
def causal_mask(T, device):
    m = torch.tril(torch.ones((T, T), dtype=torch.bool, device=device))
    return m.unsqueeze(0).unsqueeze(0)

# Dataset Class

In [49]:
class PackedLMDataset(Dataset):
    def __init__(self, file, tokenizer, seq_len=128):
        self.seq_len = seq_len
        self.examples = []
    
        self.eos_id = tokenizer.token_to_id("</s>")
        if self.eos_id is None:
            self.eos_id = tokenizer.token_to_id("<|endoftext|>")
        
        if self.eos_id is None:
            print("WARNING: No EOS token found! Defaulting to ID 0. Please check tokenizer.")
            self.eos_id = 0

        buffer = []
        with open(file, "r", encoding="utf-8") as f:
            for line in f:
                text = line.strip()
                if len(text) < 2: continue 

                ids = tokenizer.encode(text).ids
                buffer.extend(ids)
                buffer.append(self.eos_id)

        chunk_size = seq_len + 1
        total_chunks = len(buffer) // chunk_size
        
        print(f"Dataset {file}: Packed {len(buffer)} tokens into {total_chunks} sequences.")
        
        for i in range(total_chunks):
            start = i * chunk_size
            end = start + chunk_size
            self.examples.append(buffer[start:end])

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        chunk = self.examples[idx]
        data = torch.tensor(chunk, dtype=torch.long)
        x = data[:-1]
        y = data[1:]
        return x, y

In [50]:
DATA_DIR = "../prepared"
TRAIN_FILE = os.path.join(DATA_DIR, "lm_train.txt")
VAL_FILE = os.path.join(DATA_DIR, "lm_val.txt")
TOKENIZER_FILE = "../tokens/tokenizer.json"
OUTPUT_DIR = "checkpoints"
os.makedirs(OUTPUT_DIR, exist_ok=True)

VOCAB_SIZE = 16000
DIM = 256
NUM_LAYERS = 6
NUM_HEADS = 8
FFN_DIM = 1024
MAX_SEQ_LEN = 512

SEQ_LEN = 512
PHYSICAL_BATCH = 1
GRAD_ACCUM_STEPS = 32
EPOCHS = 15
LR = 3e-4
WEIGHT_DECAY = 0.01
SAVE_EVERY_STEPS = 2000
VAL_EVERY_STEPS = 2000
LOG_EVERY = 50
USE_AMP = True
NUM_WORKERS = 0
PIN_MEMORY = True


In [51]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

Device: cuda


In [52]:
print("CUDA Available:", torch.cuda.is_available())
print("CUDA Devices:", torch.cuda.device_count())
print(torch.version.cuda)

CUDA Available: True
CUDA Devices: 1
12.1


In [53]:
def save_checkpoint(step, model, optimizer, scaler=None):
    ckpt = {
        "step": step,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict()
    }
    if scaler is not None:
        ckpt["scaler_state_dict"] = scaler.state_dict()
    fname = os.path.join(OUTPUT_DIR, f"ckpt_step_{step}.pt")
    torch.save(ckpt, fname)
    print("Saved checkpoint:", fname)

In [54]:
print("Loading tokenizer...")
tok = Tokenizer.from_file(TOKENIZER_FILE)

train_ds = PackedLMDataset(TRAIN_FILE, tok, seq_len=SEQ_LEN)
val_ds = PackedLMDataset(VAL_FILE, tok, seq_len=SEQ_LEN)
train_loader = DataLoader(train_ds, batch_size=PHYSICAL_BATCH, shuffle=True, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, drop_last=True)
val_loader = DataLoader(val_ds, batch_size=PHYSICAL_BATCH, shuffle=False, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, drop_last=False)

print("Creating model...")
model = TransformerLM(VOCAB_SIZE, DIM, NUM_LAYERS, NUM_HEADS, FFN_DIM, max_seq_len=MAX_SEQ_LEN).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
print("Param count (M):", sum(p.numel() for p in model.parameters())/1e6)

scaler = torch.cuda.amp.GradScaler(enabled=(USE_AMP and device.type=="cuda"))

Loading tokenizer...
Dataset ../prepared\lm_train.txt: Packed 8847968 tokens into 17247 sequences.
Dataset ../prepared\lm_val.txt: Packed 456854 tokens into 890 sequences.
Creating model...
Param count (M): 10.410752


In [55]:
@torch.no_grad()
def evaluate(model, val_loader, device, max_batches=None):
    model.eval()
    total_loss = 0.0
    n = 0

    for i, (x, y) in enumerate(val_loader):
        x, y = x.to(device), y.to(device)

        T = x.size(1)
        mask = causal_mask(T, device)
        logits = model(x, mask)

        loss = F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            y.view(-1)
        )

        total_loss += loss.item()
        n += 1
        if max_batches and n >= max_batches:
            break

    model.train()
    return total_loss / max(1, n)

In [56]:
def load_checkpoint(ckpt_path, model, optimizer, scaler=None, device="cuda"):
    print(f"Loading checkpoint from {ckpt_path} ...")
    checkpoint = torch.load(ckpt_path, map_location=device)

    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

    if scaler and "scaler_state_dict" in checkpoint:
        scaler.load_state_dict(checkpoint["scaler_state_dict"])

    step = checkpoint.get("step", 0)
    print(f"Checkpoint loaded. Resuming from step {step}")
    return step


In [57]:
from tqdm import tqdm

def get_lr(step, warmup_steps=500, max_lr=LR, min_lr=1e-6, total_steps=200000):
    if step < warmup_steps:
        return max_lr * step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
    return min_lr + cosine_decay * (max_lr - min_lr)


resume_from = None

if resume_from:
    global_step = load_checkpoint(resume_from, model, optimizer, scaler, device=device)
else:
    global_step = 0

model.train()
print(f"Begin training from step {global_step}")

for epoch in range(EPOCHS):
    epoch_start = time.time()
    epoch_loss_sum = 0.0
    epoch_token_count = 0

    pbar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch+1}")

    for batch_idx, (x, y) in pbar:
        x, y = x.to(device), y.to(device)
        T = x.size(1)
        mask = causal_mask(T, device)

        lr_now = get_lr(global_step)
        optimizer.param_groups[0]["lr"] = lr_now

        with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):
            logits = model(x, mask)
            loss_full = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))

        instant_loss = loss_full.item()
        epoch_loss_sum += instant_loss * x.numel()
        epoch_token_count += x.numel()

        loss = loss_full / GRAD_ACCUM_STEPS
        scaler.scale(loss).backward()

        if (batch_idx + 1) % GRAD_ACCUM_STEPS == 0:
            scaler.unscale_(optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0).item()

            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

            global_step += 1
            avg_epoch_loss = epoch_loss_sum / epoch_token_count

            if global_step % LOG_EVERY == 0:
                pbar.set_postfix({
                    "inst_loss": f"{instant_loss:.6f}",
                    "avg_epoch": f"{avg_epoch_loss:.6f}",
                    "grad": f"{grad_norm:.6f}",
                    "lr": f"{lr_now:.6f}"
                })

            if global_step % VAL_EVERY_STEPS == 0:
                val_loss = evaluate(model, val_loader, device, max_batches=100)
                print(f"[VAL] step {global_step} | loss {val_loss:.6f}")
                save_checkpoint(global_step, model, optimizer, scaler)

            elif global_step % SAVE_EVERY_STEPS == 0:
                save_checkpoint(global_step, model, optimizer, scaler)

    epoch_loss_final = epoch_loss_sum / epoch_token_count
    print(f"Epoch {epoch+1} done | loss {epoch_loss_final:.6f} | time {time.time() - epoch_start:.2f}s")

save_checkpoint(global_step, model, optimizer, scaler)
print("Training complete.")


Begin training from step 0


Epoch 1: 100%|██████████| 17247/17247 [46:38<00:00,  6.16it/s, inst_loss=3.437768, avg_epoch=6.209408, grad=1.428424, lr=0.000299]


Epoch 1 done | loss 6.022870 | time 2798.25s


Epoch 2: 100%|██████████| 17247/17247 [46:37<00:00,  6.17it/s, inst_loss=3.413316, avg_epoch=2.999848, grad=1.050429, lr=0.000300]


Epoch 2 done | loss 2.981179 | time 2797.30s


Epoch 3: 100%|██████████| 17247/17247 [46:37<00:00,  6.16it/s, inst_loss=3.065455, avg_epoch=2.461367, grad=1.020763, lr=0.000300]


Epoch 3 done | loss 2.458145 | time 2797.97s


Epoch 4:  72%|███████▏  | 12352/17247 [33:25<57:07,  1.43it/s, inst_loss=2.485291, avg_epoch=2.224422, grad=1.181633, lr=0.000300]

[VAL] step 2000 | loss 2.291374
Saved checkpoint: checkpoints\ckpt_step_2000.pt


Epoch 4: 100%|██████████| 17247/17247 [46:40<00:00,  6.16it/s, inst_loss=1.901951, avg_epoch=2.210123, grad=1.092904, lr=0.000300]


Epoch 4 done | loss 2.210049 | time 2800.58s


Epoch 5: 100%|██████████| 17247/17247 [46:24<00:00,  6.19it/s, inst_loss=0.947873, avg_epoch=2.041641, grad=1.132612, lr=0.000300]


Epoch 5 done | loss 2.039837 | time 2784.47s


Epoch 6: 100%|██████████| 17247/17247 [46:37<00:00,  6.16it/s, inst_loss=2.103732, avg_epoch=1.906870, grad=1.140923, lr=0.000300]


Epoch 6 done | loss 1.906873 | time 2797.98s


Epoch 7: 100%|██████████| 17247/17247 [46:35<00:00,  6.17it/s, inst_loss=1.494647, avg_epoch=1.797836, grad=1.231954, lr=0.000300]


Epoch 7 done | loss 1.797819 | time 2795.08s


Epoch 8:  43%|████▎     | 7488/17247 [20:16<1:51:53,  1.45it/s, inst_loss=1.337981, avg_epoch=1.686850, grad=1.296434, lr=0.000300]

[VAL] step 4000 | loss 2.043457
Saved checkpoint: checkpoints\ckpt_step_4000.pt


Epoch 8: 100%|██████████| 17247/17247 [46:40<00:00,  6.16it/s, inst_loss=1.461035, avg_epoch=1.706231, grad=1.238158, lr=0.000300] 


Epoch 8 done | loss 1.706555 | time 2800.33s


Epoch 9: 100%|██████████| 17247/17247 [46:39<00:00,  6.16it/s, inst_loss=1.256974, avg_epoch=1.628789, grad=1.267757, lr=0.000300]


Epoch 9 done | loss 1.628502 | time 2799.41s


Epoch 10: 100%|██████████| 17247/17247 [46:38<00:00,  6.16it/s, inst_loss=1.962602, avg_epoch=1.558166, grad=1.324396, lr=0.000300]


Epoch 10 done | loss 1.561295 | time 2798.74s


Epoch 11: 100%|██████████| 17247/17247 [46:39<00:00,  6.16it/s, inst_loss=1.474430, avg_epoch=1.501053, grad=1.378304, lr=0.000299]


Epoch 11 done | loss 1.501828 | time 2799.01s


Epoch 12:  15%|█▌        | 2624/17247 [07:07<2:48:07,  1.45it/s, inst_loss=0.814274, avg_epoch=1.395229, grad=1.415693, lr=0.000299]

[VAL] step 6000 | loss 1.997204
Saved checkpoint: checkpoints\ckpt_step_6000.pt


Epoch 12: 100%|██████████| 17247/17247 [46:41<00:00,  6.16it/s, inst_loss=1.070801, avg_epoch=1.448779, grad=1.493665, lr=0.000299] 


Epoch 12 done | loss 1.449855 | time 2801.52s


Epoch 13: 100%|██████████| 17247/17247 [46:35<00:00,  6.17it/s, inst_loss=1.454180, avg_epoch=1.399046, grad=1.478283, lr=0.000299]


Epoch 13 done | loss 1.403319 | time 2795.68s


Epoch 14: 100%|██████████| 17247/17247 [46:30<00:00,  6.18it/s, inst_loss=1.429631, avg_epoch=1.358742, grad=1.419133, lr=0.000299]


Epoch 14 done | loss 1.360885 | time 2790.20s


Epoch 15:  87%|████████▋ | 14976/17247 [40:20<26:35,  1.42it/s, inst_loss=1.706369, avg_epoch=1.316011, grad=1.549766, lr=0.000299]

[VAL] step 8000 | loss 1.984653
Saved checkpoint: checkpoints\ckpt_step_8000.pt


Epoch 15: 100%|██████████| 17247/17247 [46:26<00:00,  6.19it/s, inst_loss=1.606475, avg_epoch=1.320162, grad=1.534188, lr=0.000299]


Epoch 15 done | loss 1.323051 | time 2786.79s
Saved checkpoint: checkpoints\ckpt_step_8070.pt
Training complete.
