# Created with ❤️ by Haidar Ali YOUSEF  

[Connect with me on LinkedIn](https://www.linkedin.com/in/haidar-ali-yousef-815018231/)


# Clean UP


In [None]:
%reset -f
!rm -rf /content/*
!rm -rf ~/.cache/pip
!apt-get clean
!rm -rf /var/lib/apt/lists/* /tmp/*
!df -h


Filesystem      Size  Used Avail Use% Mounted on
overlay          30G   17M   30G   1% /
tmpfs            64M     0   64M   0% /dev
/dev/nvme1n1     80G   34G   47G  43% /workspace
shm              29G     0   29G   0% /dev/shm
/dev/nvme0n1p2  1.8T   24G  1.7T   2% /usr/bin/nvidia-smi
tmpfs           252G     0  252G   0% /sys/fs/cgroup
tmpfs           252G   12K  252G   1% /proc/driver/nvidia
tmpfs           252G  4.0K  252G   1% /etc/nvidia/nvidia-application-profiles-rc.d
tmpfs            51G   39M   51G   1% /run/nvidia-persistenced/socket
tmpfs           252G     0  252G   0% /proc/asound
tmpfs           252G     0  252G   0% /proc/acpi
tmpfs           252G     0  252G   0% /proc/scsi
tmpfs           252G     0  252G   0% /sys/firmware
tmpfs           252G     0  252G   0% /sys/devices/virtual/powercap


# Import Libraries

In [2]:
!pip install numpy pandas torch transformers datasets tokenizers huggingface_hub pydantic rich psutil tqdm hf_transfer pynvml peft 

Collecting pandas
  Using cached pandas-2.3.3-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.metadata (91 kB)
Collecting transformers
  Using cached transformers-4.57.3-py3-none-any.whl.metadata (43 kB)
Collecting datasets
  Using cached datasets-4.4.1-py3-none-any.whl.metadata (19 kB)
Collecting tokenizers
  Using cached tokenizers-0.22.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.8 kB)
Collecting huggingface_hub
  Using cached huggingface_hub-1.1.7-py3-none-any.whl.metadata (13 kB)
Collecting pydantic
  Using cached pydantic-2.12.5-py3-none-any.whl.metadata (90 kB)
Collecting rich
  Using cached rich-14.2.0-py3-none-any.whl.metadata (18 kB)
Collecting tqdm
  Using cached tqdm-4.67.1-py3-none-any.whl.metadata (57 kB)
Collecting hf_transfer
  Using cached hf_transfer-0.1.9-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.7 kB)
Collecting pynvml
  Using cached pynvml-13.0.1-py3-none-any.whl.metadata (5.6 kB)
Collecting peft
  Us

In [None]:
#standard libraries
import os
import math
import time
import random
import warnings
from typing import Optional,Tuple,List,Union,Iterator
from collections import defaultdict
#computation linraries
import numpy as np
import pandas as pd
#Pytorch libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import Dataset ,DataLoader

#for automatic mixed precision and context management
from contextlib import nullcontext
try:
  import transformers , datasets
except:
  !pip install transformers
  !pip install datasets
  import transformers , datasets

from datasets import load_dataset

from transformers import AutoTokenizer, PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import (
    CausalLMOutputWithPast,
    BaseModelOutputWithPast
)
# could be helpfull
from transformers.activations import ACT2FN  
from tokenizers import Tokenizer as HFTokenizer
import tokenizers.models as hf_models
from tokenizers import trainers as hf_trainers
from tokenizers import pre_tokenizers as hf_pre_tokenizers
from tokenizers import decoders as hf_decoders


torch.cuda.empty_cache()

# Use Pretrained Tokenizer

In [None]:
from transformers import AutoTokenizer
from huggingface_hub import snapshot_download
import os

#download the GPT-2 files to a local cache directory
cache_dir = "./gpt2_local_cache"  
repo_id = "gpt2"
local_model_path = snapshot_download(
    repo_id=repo_id,
    cache_dir=cache_dir,
    local_dir_use_symlinks=False,  # ensures full download
    ignore_patterns=["*.h5"]  #  skip any unnecessary large files if present
)
print(f"Downloaded GPT-2 to: {local_model_path}")
tokenizer=AutoTokenizer.from_pretrained(
    local_model_path,
    local_files_only=True,
    use_fast=True
)
print(f"Vocab size: {tokenizer.vocab_size}")

special_tokens={
    "pad_token":"<|pad|>",
    "additional_special_tokens":["<|user|>", "<|assistant|>"]
}
#add the new tokens
num_added_tokens=tokenizer.add_special_tokens(special_tokens)
print(f"Added {num_added_tokens} new tokens")
if tokenizer.pad_token is None:
  tokenizer.pad_token=tokenizer.eos_token
PAD_TOKEN_ID=tokenizer.pad_token_id
print(f"PAD_TOKEN_ID: {PAD_TOKEN_ID}")
print(f"Vocab size: {tokenizer.vocab_size}") # actually 50260 if you do len(tokenizer)
print(f"Special tokens: {tokenizer.all_special_tokens}")

Downloaded GPT-2 to: ./gpt2_local_cache/models--gpt2/snapshots/607a30d783dfa663caf39e06633721c8d4cfcd7e
Vocab size: 50257
Added 3 new tokens
PAD_TOKEN_ID: 50257
Vocab size: 50257
Special tokens: ['<|endoftext|>', '<|pad|>', '<|user|>', '<|assistant|>']


In [5]:
tokenizer.eos_token_id

50256

# Building the Transformer Architecture

In [None]:
DEVICE="cuda" if torch.cuda.is_available() else "cpu"
DEVICE

'cuda'

In [None]:
# MLP Class
class MLP(nn.Module):
  """
  Multi-Layer Perceptron with one hidden Layer
  It expands the input embedding size,applies GELU activation and then projects it back
  """
  def __init__(self,n_embed,dropout=0.1,rms_norm=False):
    super().__init__()
    self.hidden=nn.Linear(n_embed,4*n_embed,bias=not rms_norm)
    self.gelu=nn.GELU()
    self.proj=nn.Linear(4*n_embed,n_embed)
    self.dropout=nn.Dropout(dropout)
  def forward(self,x):
    # standard MLP in transformer: Linear -> GELU -> Linear -> Dropout
    return  self.dropout(self.proj(self.gelu(self.hidden(x))))


In [None]:

#MultiHead Attention Class
class MultiHeadAttention(nn.Module):
  """
  This module combines multiple attention heads in parallel .The outputs of each head
  are concatenated to form the final output
  """
  def __init__(self,n_head,n_embed,context_length,dropout=0.1):
    super().__init__()
    assert n_embed % n_head==0,"n_embed must be devisible by n_head"
    self.context_length=context_length
    self.n_head=n_head
    self.head_size=n_embed//n_head
    assert self.head_size%2==0,"head_size must be even for RoPE (split into pairs)."
    self.n_embed=n_embed
    #single linear layer for multi-head projections
    self.key=nn.Linear(n_embed,n_embed,bias=False)
    self.query=nn.Linear(n_embed,n_embed,bias=False)
    self.value=nn.Linear(n_embed,n_embed,bias=False)
    self.proj = nn.Linear(n_embed, n_embed)
    self.attn_dropout=nn.Dropout(dropout)
    self.resid_dropout=nn.Dropout(dropout)


    # RoPE cache
    self.register_buffer('cos_cached',torch.empty(0),persistent=False)
    self.register_buffer('sin_cached',torch.empty(0),persistent=False)
    self.register_buffer('causal_mask',torch.tril(torch.ones(context_length,context_length)))
   
    self._build_rope_cache()

  def _build_rope_cache(self,device=None,seq_len=None):
    device=device or next(self.parameters()).device
    seq_len=seq_len or self.context_length
    i_th=torch.arange(0,self.head_size,2,dtype=torch.float32,device=device)
    omega=10000**(-i_th/self.head_size) # head_size//2
    position=torch.arange(seq_len,dtype=torch.float32,device=device)
    theta=torch.outer(position,omega) # (context_length ,head_size//2)

    sin=torch.sin(theta).unsqueeze(0).unsqueeze(0) # (1,1,context_length ,head_size//2)
    cos=torch.cos(theta).unsqueeze(0).unsqueeze(0) # (1,1,context_length ,head_size//2)
    self.sin_cached=sin
    self.cos_cached=cos
    self.causal_mask=self.causal_mask.to(device)
  def apply_rope(self,q,k,start_pos=0):
    # q,k : (B,n_head,T,head_size)
    B,n_head,T,C=q.shape
    max_pos=start_pos+T
    if self.sin_cached.numel() == 0 or self.sin_cached.shape[2]<max_pos or self.sin_cached.device != q.device:
        self._build_rope_cache(device=q.device,seq_len=max_pos)
    sin=self.sin_cached[:,:,start_pos:start_pos+T,:] # (1,1,T,head_size//2)
    cos=self.cos_cached[:,:,start_pos:start_pos+T,:] # (1,1,T,head_size//2)
    def rotate(x):
      x1=x[...,0::2]
      x2=x[...,1::2]
      return torch.cat([x1*cos-x2*sin,x1*sin+x2*cos],dim=-1)
    return rotate(q),rotate(k)
  def forward(self,x,attn_mask=None,cache=None,use_flash_attn=True):
    """
    x: (B, T, C)
    attn_mask: (B, T)  with 1 for valid tokens and 0 for padding
    use_flash_attn: whether to use flash attention if available (not used during inference)
    cache: dict with keys 'k' and 'v' for cached past keys and values
    Returns:
        out: attention output
        updated_cache: dict with updated 'k' and 'v'
    """
    #assure input is the same dtype as weights
    if x.dtype != self.query.weight.dtype:
        x = x.to(dtype=self.query.weight.dtype)
    #assure input on the same device as rope cache
    if self.sin_cached.device != x.device:
      self._build_rope_cache(device=x.device)
    B,T,C=x.shape
    #project all heads at once
    q=self.query(x).view(B,T,self.n_head,self.head_size).transpose(1,2) # (B,n_head,T,head_size)
    k=self.key(x).view(B,T,self.n_head,self.head_size).transpose(1,2) # (B,n_head,T,head_size)
    v=self.value(x).view(B,T,self.n_head,self.head_size).transpose(1,2) # (B,n_head,T,head_size)

    # Compute old_len before concatentation
    old_len=cache['k'].shape[2] if (cache is not None and 'k' in cache) else 0
    #Apply RoPE
    q,k=self.apply_rope(q,k,start_pos=old_len)

    #Append last kv-cache if they exist

    if cache is not None and 'k' in cache and 'v' in cache:
        old_k=cache['k'].to(device=k.device,dtype=k.dtype)
        old_v=cache['v'].to(device=v.device,dtype=v.dtype)
        k=torch.cat([old_k,k],dim=2) #concatenate along sequence dimention
        v=torch.cat([old_v,v],dim=2)
    #update the cache
    cache={'k':k,'v':v}
    scale_factor = 1.0 /math.sqrt(self.head_size)
    S=k.shape[2] # full key length (old+new)

    use_flash=(hasattr(F, 'scaled_dot_product_attention') and
                 x.is_cuda and
                 x.dtype in (torch.float16, torch.float32, torch.bfloat16) and
                  attn_mask is None and use_flash_attn)
    if use_flash:
      out = F.scaled_dot_product_attention(q,k,v,is_causal=True,dropout_p=self.attn_dropout.p if self.training else 0.0)
    else:
      # manual attention
      attn_weights=(q @ k.transpose(-2,-1))*scale_factor # (B,n_head,T,S)
      row_start=old_len
      row_end=old_len+T

      # compute causal mask (T x S)
      if S <= self.context_length:
         causal_mask = self.causal_mask[row_start:row_end, :S].to(attn_weights.device)
      else:
         full_mask = torch.tril(torch.ones(self.context_length, S, device=attn_weights.device))
         causal_mask = full_mask[row_start:row_end, :S]

      # causal_mask: (T,S) -> expand to (1,1,T,S)
      causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)
      combined_mask=causal_mask
      if attn_mask is not None: # Training padding
        attn_mask_broadcast = attn_mask[:, None, None, :S].to(dtype=causal_mask.dtype,device=attn_weights.device)
        combined_mask = causal_mask * attn_mask_broadcast  # (B,1,T,S)
      attn_weights = attn_weights.float()
      neg_inf = float(torch.finfo(attn_weights.dtype).min / 2)
      attn_weights = attn_weights.masked_fill(combined_mask == 0, neg_inf)
      attn_weights = attn_weights - attn_weights.amax(dim=-1, keepdim=True)
      attn_weights = F.softmax(attn_weights, dim=-1)
      attn_weights=self.attn_dropout(attn_weights) # apply attention dropout
      attn_weights=torch.nan_to_num(attn_weights,nan=0.0)
      out=attn_weights @ v # (B,n_head,T,head_size)
    out=out.transpose(1,2).reshape(B,T,C) # (B,T,n_head*head_size)
    out=self.proj(out)
    out=self.resid_dropout(out) # apply residual dropout

    return out,cache


In [None]:
#RMSNorm
class RMSNorm(nn.Module):
  def __init__(self,n_embed,eps=1e-8):
    super().__init__()
    self.eps=eps
    self.scale=nn.Parameter(torch.ones(n_embed)) #learnable scale
  def forward(self,x):
    # x: (B,T,n_embed) or (B,n_embed)
    rms=x.pow(2).mean(-1,keepdim=True)
    norm=torch.sqrt(rms+self.eps)
    scale=self.scale
    if x.ndim==3:
      scale=scale.view(1,1,-1)
    elif x.ndim==2:
      scale=scale.view(1,-1)
    # return x/(sqrt(sum(x**2))) * scale :  (B,T,n_embed) or (B,n_embed)
    return (x/norm) *scale

#Transformer Block Class
class Block(nn.Module):
  """
  Transformer block:
    - Multi-Head Attention + residual
    - MLP + residual
    - LayerNorm before each module

  This block consists of a multi-head attention layer followed by an MLP,
  with layer normalization and residual connections.
  """
  def __init__(self,n_head,n_embed,context_length,norm_type="pre",rms_norm=False):
    super().__init__()
    self.rms_norm=rms_norm
    self.norm_type=norm_type
    self.n_head=n_head
    self.n_embed=n_embed
    if rms_norm and norm_type != "pre":
            print("[Warning] RMSNorm with post-norm is uncommon and may reduce stability.")

    NormClass=RMSNorm if rms_norm else nn.LayerNorm
    self.ln1=NormClass(n_embed)
    self.attn=MultiHeadAttention(n_head=n_head,n_embed=n_embed,context_length=context_length)
    self.ln2=NormClass(n_embed)
    self.mlp=MLP(n_embed=n_embed,rms_norm=rms_norm)
    #residual scale for RMS Norm with post_norm
    self.res_scale=0.9 if rms_norm and norm_type!='pre' else 1.0

  def forward(self,x,attn_mask=None, cache=None,use_flash_attn=True):
    """
    x: (B, T, C)
    attn_mask: (B, T)  with 1 for valid tokens and 0 for padding
    use_flash_attn: whether to use flash attention if available (not used during inference)
    cache: dict with keys 'k' and 'v' for cached past keys and values
    Returns:
        out: block output
        updated_cache: dict with updated 'k' and 'v'
    """
    if self.norm_type == "pre":
        attn_out,new_cache = self.attn(self.ln1(x),attn_mask=attn_mask,cache=cache,use_flash_attn=use_flash_attn)
        x = x + attn_out
        x = x + self.mlp(self.ln2(x))

    else:
        attn_out,new_cache = self.attn(x,attn_mask=attn_mask,cache=cache,use_flash_attn=use_flash_attn)
        x = self.ln1(x + attn_out*self.res_scale)
        x = self.ln2(x + self.mlp(x)*self.res_scale)
    return x,new_cache


In [None]:
from pydantic.dataclasses import dataclass
from pydantic import Field
from typing import Optional
from rich.table import Table
from rich.console import Console

# MODEL CONFIG
@dataclass
class ModelConfig:
    context_length: int = Field(..., gt=0, description="Maximum context length")
    n_embed: int = Field(..., gt=0, description="Embedding dimension")
    n_head: int = Field(..., gt=0, description="Number of attention heads")
    n_block: int = Field(..., gt=0, description="Number of transformer blocks")
    vocab_size: int = Field(..., gt=1, description="Vocabulary size")
    pad_token_id: int = Field(..., ge=0, description="Padding token index")
    use_zloss: bool = Field(False, description="Enable z-loss regularization")
    zloss_coeff: float = Field(1e-4, ge=0.0, description="Coefficient for z-loss")
    dropout: float = Field(0.1, ge=0.0, le=1.0, description="Dropout probability")
    norm_type: str = Field("pre", description="Normalization type: 'pre' or 'post'")
    rms_norm: bool = Field(False, description="Use RMSNorm instead of LayerNorm")
    model_type : str = Field("custom", description="Model type")
    tie_word_embeddings: bool = Field(True, description="Tie word embeddings and output embeddings")

# INFERENCE CONFIG
@dataclass
class InferenceConfig:
    max_new_tokens: int = Field(...,gt=0,description="Maximum number of new tokens to generate")
    temperature: float = Field(1.0,ge=1e-5,description="Sampling temperature")
    topk: Optional[int] = Field(None,ge=0,description="Top-k sampling")
    topp: Optional[float] = Field(None,ge=0.0,le=1.0,description="Top-p (nucleus) sampling probability")
    frequency_penalty: float = Field(0.0,ge=0.0,description="Penalty for repeated tokens based on frequency")
    presence_penalty: float = Field(0.0,ge=0.0,description="Penalty for repeated tokens based on presence")
    mode: str = Field("combined", description="Sampling mode: 'combined' or 'independent'")
    eos_tokens: list|int|str | None = Field(None,description="End of sequence tokens")
    return_only_generated: bool = Field(True, description="Return only generated text after assistant token")



#Transformer 
class Transformer(nn.Module):
  """
  This class combines token and position embeddings with a sequence of Transformer blocks
  and a final linear layer for language modeling.
  """
  def __init__(self,config:ModelConfig):
    super().__init__()

    # PEFT (LoRA) prerequisites
    cfg_dict = dict(vars(config))
    self.config = PretrainedConfig(**cfg_dict)
    # Model setup
    self.ignore_index=-100
    self.context_length= config.context_length
    self.n_block=config.n_block
    self.n_embed=config.n_embed
    self.n_head=config.n_head
    self.vocab_size=config.vocab_size
    self.pad_token_id=config.pad_token_id
    self.use_zloss=config.use_zloss
    self.zloss_coeff=config.zloss_coeff
    self.dropout_p=config.dropout
    self.norm_type=config.norm_type
    self.rms_norm=config.rms_norm
    NormClass = RMSNorm if self.rms_norm else nn.LayerNorm
    self.token_embed=nn.Embedding(self.vocab_size,self.n_embed)
    self.dropout=nn.Dropout(self.dropout_p)
    self.attn_blocks=nn.ModuleList([Block(self.n_head,self.n_embed,self.context_length,self.norm_type,self.rms_norm) for _ in range(self.n_block)])
    self.layer_norm=NormClass(self.n_embed)
    self.lm_head=nn.Linear(self.n_embed,self.vocab_size,bias=False) #projects back to vocabulary logits for prediction
    self.lm_head.weight=self.token_embed.weight # weight tying ,this will reduce the model size by (n_embed*vocab_size)
    self.apply(self._init_weights)

  # huggingFace helper methods expected by some wrappers
  def get_input_embeddings(self):
      return self.token_embed

  def set_input_embeddings(self, new_embeddings):
      self.token_embed = new_embeddings
      # re-tie lm_head if needed
      if hasattr(self, "lm_head"):
          self.lm_head.weight = self.token_embed.weight

  def get_output_embeddings(self):
      return self.lm_head

  def set_output_embeddings(self, new_output):
      self.lm_head = new_output

  # Initialize module weights
  def _init_weights(self, m):
    if isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
        # special initialization for projection and MLP layers
        if hasattr(m, 'in_features') and hasattr(m, 'out_features'):
            if any(m is block.attn.proj or m is block.mlp.proj for block in self.attn_blocks):
                nn.init.normal_(m.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.n_block))
    elif isinstance(m, nn.Embedding):
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
    elif isinstance(m, nn.LayerNorm):
        nn.init.ones_(m.weight)
        nn.init.zeros_(m.bias)



  # Initialize rope-cache
  def _init_rope_cache(self,device=None):
    #to make sure that the rope_cache are on the same device of the Model
    for block in self.attn_blocks:
        block.attn._build_rope_cache(device=device)

  # Compute model weights and memory size
  def _get_size(self):
      from rich.table import Table
      from rich.console import Console
      # compute the number of parameters of the model
      num_params=sum(p.numel() for p in self.parameters())

      # compute the memory size of the model
      dtype_sizes=defaultdict(int)
      for p in self.parameters():
        dtype_sizes[str(p.dtype)]+=p.numel()*p.element_size()
      for b in self.buffers():
        dtype_sizes[str(b.dtype)]+=b.numel()*b.element_size()
      total_size_mb=sum(dtype_sizes.values())/(1024**2)
      #create rich table
      table=Table(title='Model Memory Summary',show_lines=True)
      table.add_column("Data Type",justify='center',style='bold yellow',no_wrap=True)
      table.add_column("Memory (MB)",justify='right',style='green')
      table.add_column("Share (%)",justify='right',style='green')
      for dtype,size in dtype_sizes.items():
        size_mb=size/(1024**2)
        share=(size_mb/total_size_mb)*100
        table.add_row(dtype,f"{size_mb:.2f}",f"{share:.2f}%")
      console=Console()
      console.print(f"[bold yellow]Number of parameters:[/bold yellow] {num_params:,}")
      console.print(f"[bold yellow]Total memory usage:[/bold yellow] {total_size_mb:.2f} MB")
      console.print(table)
  # Apply token embedding with dropout
  def _pre_attn_pass(self,idx):
    """
    Combines token and position embeddings
    idx:Input token indices (B,T)

    """
    B,T=idx.shape
    token_emb=self.token_embed(idx)
    x=self.dropout(token_emb) # (B , T , n_embed)
    return x

  # Save the Model
  @staticmethod

  def save(model, path, config: ModelConfig = None, extra_dict: dict = None):
        state_dict = (
            model.module.state_dict()
            if isinstance(model, nn.DataParallel)
            else model.state_dict()
        )
    
        save_dict = {
            "state_dict": state_dict
        }
    
        if config is not None:
            save_dict["config_dict"] = config.__dict__
    
        if extra_dict is not None:
            save_dict["extra"] = extra_dict     # safer
    
        torch.save(save_dict, path)
        print(f"Model saved to {path}")

  # Load the Model
  @classmethod

  def load(cls, path, device=None, strict=True):
        """
        Universal load: works with any saved checkpoint (clean or dirty).
        Auto-fixes _orig_mod. prefixes, works with weights_only=True, gives clear errors.
        """
        import torch
        
        device = device or 'cpu'
        
        try:
           
            saved = torch.load(path, map_location=device, weights_only=False)
        except Exception as e:
            if "weights_only" in str(e) or "UnpicklingError" in str(e):
                raise RuntimeError(
                    f"Failed to load {path}\n"
                    "-> This file is either corrupted (e.g. HTML from Google Drive) or saved with complex objects.\n"
                    "-> Run: !ls -lh '{path}' and !head -c 200 '{path}'\n"
                    "-> Real .pth files are ~195MB and show binary garbage."
                ) from e
            raise
    
        # Extract config
        config_dict = saved.get("config_dict") or saved.get("config", None)
        if config_dict is None:
            raise ValueError("No model config found in checkpoint!")
    
        config = ModelConfig(**config_dict)
        model = cls(config).to(device)
    
        # Extract state_dict (could be directly saved or nested)
        state_dict = saved.get("state_dict", saved if isinstance(saved, dict) else None)
        if state_dict is None:
            raise ValueError("No state_dict found in checkpoint!")
    
        # Auto-clean _orig_mod. prefixes if found
        if any(k.startswith("_orig_mod.") for k in state_dict.keys()):
            print("Detected _orig_mod. prefixes → auto-cleaning state_dict...")
            new_sd = {}
            for k, v in state_dict.items():
                new_k = k[len("_orig_mod."):] if k.startswith("_orig_mod.") else k
                new_sd[new_k] = v
            state_dict = new_sd
    
        # Load weights
        model.load_state_dict(state_dict, strict=strict)
    
        # Extra info
        extra = saved.get("extra", {})
        epoch = extra.get("epoch", "?")
        val_ppl = extra.get("val_ppl", "?")
        if isinstance(val_ppl, (int, float)): val_ppl = f"{val_ppl:.2f}"
    
        print(f"Model loaded successfully!")
        print(f"   Epoch: {epoch} | Val PPL: {val_ppl} | Params: ~51M | Device: {device}")
    
        return model, config, extra
    

  # Override the to() method
  def to(self,*args,**kwargs):
    """
    Overrides nn.Module.to() to ensure that rope cache is initialized
    after the model is moved to a new device.
    """
    model=super().to(*args,**kwargs)
    device = args[0] if len(args) > 0 else kwargs.get('device')
    if device is None and len(args) > 0:
        device = args[0]
    # normalize to torch.device if it's a string
    if isinstance(device, str):
        device = torch.device(device)
    if  device is None:
      try:
        device=next(model.parameters()).device
      except StopIteration:
        device=None
    if device is not None:
      model._init_rope_cache(device=device)
      if hasattr(self, "lm_head") and hasattr(self, "token_embed"):
            print("The tite done")
            self.lm_head.weight = self.token_embed.weight
    return model

  # Define perplexity as metric
  def perplexity(self, dataloader, device):
    was_training=self.training
    self.eval()
    total_loss = 0.0
    total_tokens = 0.0
    with torch.inference_mode():
        for xb, yb, attn_mask in tqdm(dataloader, desc="Computing Perplexity"):
            xb, yb, attn_mask = xb.to(device), yb.to(device), attn_mask.to(device)
            logits, loss , _ = self(input_ids=xb, labels=yb, attention_mask=attn_mask,use_flash_attn=False)
            if loss is not None and not torch.isnan(loss):
                valid_tokens= int((yb != self.ignore_index).sum().item())
                total_loss += loss.item() * valid_tokens
                total_tokens += valid_tokens
    avg_loss = total_loss / total_tokens if total_tokens > 0 else float('inf')
    ppl = math.exp(avg_loss)
    if was_training:
      self.train() #restore the original one
    return ppl

  # Apply forward propagation
  def forward(self, input_ids=None, inputs_embeds=None, labels=None,attention_mask=None,past_key_values=None,use_flash_attn=True,** kwargs):
      # alias target labels to targets
      if 'targets' in kwargs and labels is None:
          labels = kwargs.pop('targets')
  
      if inputs_embeds is not None:
        x = inputs_embeds
        B, T = x.shape[:2]
      else:
        if input_ids is None:
            raise ValueError("Either input_ids or inputs_embeds mst be provided")
        x = self._pre_attn_pass(input_ids)
        B, T = input_ids.shape

      if T>self.context_length:
        raise ValueError(f"Input sequence length {T} exceeds context length {self.context_length};ensure tokenization truncates.")


      if past_key_values is not None:

        assert isinstance(past_key_values,(list,tuple)) and len(past_key_values)==len(self.attn_blocks),"past_cache must be a list/tuple with one dict per block"
        #Trim cache across layers if total seq exceed

        for i,layer_cache in enumerate(past_key_values):
          if 'k' not in layer_cache or 'v' not in layer_cache:
            continue
          old_len=layer_cache['k'].shape[2]
          total_len=old_len+T
          if total_len>self.context_length:
            trim_len=total_len-self.context_length
            keep=max(0,old_len-trim_len)
            #slice to keep the last position
            if keep==0:
              layer_cache['k']=layer_cache['k'][:,:,0:0,:]
              layer_cache['v']=layer_cache['v'][:,:,0:0,:]
            else:
              layer_cache['k']=layer_cache['k'][:,:,-keep:,:]
              layer_cache['v']=layer_cache['v'][:,:,-keep:,:]
            import warnings
            warnings.warn(
                        f"Warning: Trimmed {trim_len} positions from layer {i+1} cache to keep context_length={self.context_length}. "
                        "RoPE absolute positions will shift accordingly (trimming older positions)."
                    )

      new_cache=[]
      for i,block in enumerate(self.attn_blocks):
          cache=past_key_values[i] if past_key_values is not None else None
          x,cache=block(x,attention_mask,cache,use_flash_attn=use_flash_attn)
          new_cache.append(cache)
      x = self.layer_norm(x)
      logits = self.lm_head(x) # B x T x vocab_size
      loss = None
      if labels is not None:
          B, T, C = logits.shape
          flat_logits =logits.reshape(-1,C) #logits.view(B * T, C)
          labels = labels.view(B * T).long()
          valid_count = (labels != self.ignore_index).sum().item()
          if valid_count > 0:
              ce_loss = F.cross_entropy(flat_logits, labels, ignore_index=self.ignore_index)
              loss = ce_loss
              if self.use_zloss:

                mask = (labels != self.ignore_index)  # (B*T,)
                if mask.any():
                    lse = torch.logsumexp(flat_logits, dim=-1)  # (B*T,)
                    z_reg = (lse[mask] ** 2).mean()             # mean squared over valid tokens
                    loss += self.zloss_coeff * z_reg

          else:
              print("Warning: No valid targets in batch, skipping loss computation")
      return logits, loss,new_cache

  # Generate tokens
  @torch.inference_mode()
  def generate(self, input_ids, config : InferenceConfig,tokenizer=None):
      """
      Generate tokens auto-regressively.

      Args:
          idx: torch.LongTensor (B, T) input token indices
          max_new_tokens: int, number of new tokens to generate
          temperature: float, softmax temperature
          topk: int, keep top-k tokens
          topp: float, cumulative probability for nucleus sampling
          frequency_penalty: float, frequency penalty
          presence_penalty: float, presence penalty
          mode: "combined" or "independent"
              - combined: topp is applied on topk-masked logits
              - independent: topp is applied on original logits regardless of topk
          tokenizer: transformers.PreTrainedTokenizer, tokenizer for decoding tokens

      """
      # Helper functions for handling eos-tokens
      def normalize_eos_tokens(eos_tokens,tokenizer):
            """
            Normalize eos_tokens into a list of token-id sequences (list[list[int]]).
            Returns:
            [
                [50256],                # single-token EOS
                [198,198,198],          # multi-token EOS ("\n\n\n")
                [10009]                 # special chat eos
            ]
            """
            if eos_tokens is None or tokenizer is None:
              return []
            if isinstance(eos_tokens,(int,str)):
              eos_tokens=[eos_tokens]
            normalized=[]
            for item in eos_tokens:
              if isinstance(item,int):
                normalized.append([item]) # single token ids
                continue

              if isinstance(item,str):
                ids=tokenizer.encode(item,add_special_tokens=False)
                normalized.append(ids)
                continue

              raise ValueError("eos_tokens must contain ints or strings")
            return normalized

      def ends_with_pattern(sequence,pattern):
        if len(sequence)<len(pattern):
          return False
        return sequence[-len(pattern):]==pattern

      # Helper function to apply frequence and presence penalty to the output logits
      def apply_penalties(logits, input_ids, frequency_penalty=0.0, presence_penalty=0.0):
            """
            Vectorized version — applies frequency and presence penalties to logits.
            Args:
                logits: (B, vocab_size)
                idx: (B, T)
            """
            if frequency_penalty==0.0 and presence_penalty==0.0:
              return logits
            B,vocab_size=logits.shape
            counts=torch.zeros(B,vocab_size,device=logits.device) # B x vocab_size
            counts.scatter_add_(dim=1,index=input_ids,src=torch.ones_like(input_ids,dtype=torch.float,device=logits.device))
            logits-=frequency_penalty*counts
            logits-=presence_penalty*(counts>0).float()
            return logits
      # unpack config
      max_new_tokens = config.max_new_tokens
      if max_new_tokens==0:
          return input_ids,0.0
      temperature = config.temperature if config.temperature is not None and config.temperature >0.0 else 1.0
      topk = config.topk
      topp = config.topp
      frequency_penalty = config.frequency_penalty
      presence_penalty = config.presence_penalty
      mode = config.mode
      # normalize eos_tokens
      eos_patterns=normalize_eos_tokens(config.eos_tokens,tokenizer)
      #setup
      idx = input_ids.long()
      device=next(self.parameters()).device
      idx=idx.to(device)
      B,prompt_len=idx.shape
      #compute effective prompt length (handle legacy padding)
      effective_prompt_len=prompt_len
      if self.pad_token_id is not None:
        #find last non padding position
        last_non_pad_per_batch = (idx != self.pad_token_id).sum(dim=-1)
        effective_prompt_len = int(last_non_pad_per_batch.min().item())
      #soft check
      if effective_prompt_len>self.context_length:
        import warnings
        warnings.warn(f"Effective prompt length {effective_prompt_len} exceeds context length {self.context_length}; truncating.")
        idx=idx[:,-self.context_length:]
        effective_prompt_len=min(effective_prompt_len,self.context_length)
      # Initial forward pass on full prompt : exact RoPE ,no cache
      with torch.amp.autocast(device_type="cuda", enabled=False):
          logits,_,past_key_values=self(input_ids=idx,labels=None,attention_mask=None,past_key_values=None,use_flash_attn=False)

      # Cap to remaining context
      remaining_ctx = self.context_length-effective_prompt_len
      max_gen = min(max_new_tokens, remaining_ctx)
      if max_gen<=0:
        import warnings
        warnings.warn("No room in context; returning prompt unchanged.")
        return idx,0.0

      # start time
      import time
      start_time=time.time()
      ttft=None

      for step in range(max_gen):
          logits = logits.float()
          logits = logits[:, -1, :] / temperature # B x vocab_size
          logits = apply_penalties(logits, idx, frequency_penalty, presence_penalty)
          logits = torch.nan_to_num(logits, nan=0.0, neginf=-1e9, posinf=1e9)

          # Top-k filtering
          if topk is not None:
              topk_logits, topk_indices = torch.topk(logits, k=topk, dim=-1) # B x topk
              mask = torch.full_like(logits, float('-inf')) # B x vocab_size
              mask.scatter_(-1, topk_indices, topk_logits)
              logits_topk = mask
          else:
              logits_topk = logits

          # Top-p sampling
          if topp is not None:
              probs = F.softmax(logits_topk if mode == 'combined' else logits, dim=-1) # B x vocab_size
              probs = torch.nan_to_num(probs, nan=0.0)
              sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1) # B x vocab_size
              cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
              sorted_probs[cumulative_probs > topp] = 0
              sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True).clamp(min=1e-9)
              sorted_probs = torch.clamp(sorted_probs, min=1e-9)
              idx_next = torch.multinomial(sorted_probs, num_samples=1) # B x 1
              idx_next = sorted_indices.gather(dim=-1, index=idx_next)
          else:
              probs = F.softmax(logits_topk, dim=-1)
              probs = torch.nan_to_num(probs, nan=0.0)
              probs = probs / probs.sum(dim=-1, keepdim=True).clamp(min=1e-9)
              probs = torch.clamp(probs, min=1e-9)
              idx_next = torch.multinomial(probs, num_samples=1) # B x 1
          # append token to the sequence
          idx = torch.cat((idx, idx_next), dim=-1)
          if ttft is None:
            ttft=time.time()-start_time
          #check eos_tokens : only on generated tokens
          if eos_patterns:

            hit=False
            gen_suffix=idx[:,effective_prompt_len:].tolist()
            for seq in gen_suffix:
              for pattern in eos_patterns:
                if ends_with_pattern(seq,pattern):
                  hit=True
                  break
              if hit:
                break
            if hit:
              import warnings
              warnings.warn(f"EOS pattern detected in generated tokens after {step+1} new tokens; early stop.")
              break
          #continue with kv-cache
          with torch.amp.autocast(device_type="cuda", enabled=False):
              logits,_,past_key_values=self(input_ids=idx[:,-1:].to(device),labels=None,attention_mask=None,past_key_values=past_key_values,use_flash_attn=False)

      return idx,ttft
  def prepare_inputs_for_generation(self, input_ids,past_key_values,attention_mask, **kwargs):
      # PEFT calls this during .generate()
      """
      Returns a dict that huggingface/PEFT expects when preparing next-step generation.
      """
      return {
          "input_ids": input_ids,
          "past_key_values": past_key_values,
          "attention_mask": attention_mask
      }
    



# Generate Text Function (Inference Mode)

In [96]:
from pydantic.dataclasses import dataclass
from pydantic import Field
from typing import Optional
from rich.table import Table
from rich.console import Console


def generate_text(model, tokenizer, prompt, config: InferenceConfig,is_sft=False):
    import time
    from rich.console import Console
    console = Console()
    start_time = time.time() # timer

    # prepare the prompt
    if isinstance(prompt, str):
        prompt = [prompt]
    if is_sft:
        user_token = '<|user|>'
        assistant_token = '<|assistant|>'
        prompt = [f"{user_token} {p.strip()} {assistant_token}" for p in prompt]

    # tokenize the prompt
    encoded = tokenizer(
        prompt,
        truncation=True,
        max_length=model.context_length,
        padding=False,
        return_tensors='pt'
    )

    device = next(model.parameters()).device
    input_ids = encoded['input_ids'].to(device)

    was_training = model.training
    model.eval()

    # generate

    output_ids,ttft=model.generate(input_ids=input_ids, config=config, tokenizer=tokenizer)

    
    ttft=0.0 if ttft is None else ttft*1000 # in ms


    total_time = time.time() - start_time
    prompt_tokens = input_ids.numel()
    output_tokens = output_ids.numel()
    generated_tokens = output_tokens - prompt_tokens
    gen_tps = generated_tokens / total_time
    total_tps = output_tokens / total_time

    console.print(f"[bold yellow]Prompt Tokens:[/bold yellow] {prompt_tokens}")
    console.print(f"[bold yellow]Generated Tokens:[/bold yellow] {generated_tokens}")
    console.print(f"[bold yellow]Total Time:[/bold yellow] {total_time:.2f} seconds")
    console.print(f"[bold yellow]Generated Tokens/sec:[/bold yellow] {gen_tps:.2f}")
    console.print(f"[bold yellow]Total Tokens/sec (prompt+output):[/bold yellow] {total_tps:.2f}")
    console.print(f"[bold yellow]TTFT:[/bold yellow] {ttft:.2f} ms")

    # decode the output
    output_text=[]
    for i,ids in enumerate(output_ids):
      seq=ids.tolist()
      if config.return_only_generated:
        if is_sft:
              assistant_id=tokenizer.convert_tokens_to_ids(assistant_token)
              if assistant_id in seq:
                idx=seq.index(assistant_id)+1
                seq=seq[idx:]
        else:
              idx=len(input_ids[i])
              seq=seq[idx:]
      decoded=tokenizer.decode(seq, skip_special_tokens=True)
      output_text.append(decoded)


    if len(prompt) == 1:
        output_text = output_text[0]

    if was_training:
        model.train()

    return output_text



# PreTraining

---



## Load PreTrainDataset

In [45]:
BATCH_SIZE=32
CONTEXT_LENGTH = 256   # up to 1024 for standard attention; use Sparse Attention beyond that (see Longformer)

In [None]:
# TinyStories TestDataset
# try :
#   from datasets import load_dataset
# except:
#   !pip install datasets --upgrade
# import random
# from datasets import load_dataset
# from tqdm import tqdm
# pretrain_data= load_dataset("roneneldan/TinyStories")
#extract the train and validation dataframe
#train_df=pretrain_data['train'].to_pandas()
#val_df=pretrain_data['validation'].to_pandas()

from datasets import load_dataset
import pandas as pd
# load the pretraining dataset from Hugging Face Hub
dataset = load_dataset(
    "haidar-ali/tallyformer-finance-dataset",
    data_dir="Data/PreTrainData",  
    split="train"
)
df = pd.DataFrame(dataset)
df = df.sample(frac=1, random_state=42).reset_index(drop=True)
data_size = len(df)
train_size = int(0.9 * data_size)
train_df = df[:train_size]
val_df = df[train_size:]



## Prepare Data for online tokenization

### Create PretrainDataset Class

In [47]:
from torch.utils.data import Dataset, DataLoader
import torch
class PretrainDataset(Dataset):
    def __init__(self, data, tokenizer, context_length):
        self.data = data
        self.tokenizer = tokenizer
        self.context_length = context_length
        self.ignore_index = -100
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if isinstance(self.data,pd.DataFrame):
          text=self.data.iloc[idx]['text']
        else:
          text = self.data[idx]["text"]
        encoded = self.tokenizer(
            text,
            truncation=True,
            max_length=self.context_length,
            padding="max_length",
            return_tensors='pt'
        )
        input_ids=encoded['input_ids'].squeeze(0).cpu()
        attn_mask=encoded['attention_mask'].squeeze(0).cpu()
        xb = input_ids[:-1]
        yb = input_ids[1:].clone()
        attn_mask = attn_mask[:-1]
        # mask padding
        yb[attn_mask == 0] = self.ignore_index
        return xb, yb,attn_mask



### Define Datalaoders

In [None]:
len_train=int(len(train_df)*0.9)
len_val=int(len(val_df)*0.75)

train_df=train_df.iloc[:len_train]
val_df=val_df.iloc[:len_val]

# Dataset
train_dataset = PretrainDataset(train_df, tokenizer, CONTEXT_LENGTH)
val_dataset = PretrainDataset(val_df, tokenizer, CONTEXT_LENGTH)

# DataLoader
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,pin_memory=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=True,pin_memory=True, num_workers=0)

In [49]:
len(val_loader)

21747

## Prepare Data for offline tokenization
### Make sure you have enough disk storage ~ 9.context_length.(len(train_df)+len(val_df))/1e6 GB.
#### Otherwise your instance might start sweating, shaking, and reconsidering its life choices. 

### Run this once to create the .bin files

In [None]:

import os
import numpy as np
import torch
from tqdm import tqdm

def save_precomputed_dataset(df_texts, filename_prefix, tokenizer, context_length=256, batch_size=1024):
    os.makedirs("precomputed", exist_ok=True)
    xb_path = f"precomputed/{filename_prefix}_xb.bin"
    yb_path = f"precomputed/{filename_prefix}_yb.bin"
    mask_path = f"precomputed/{filename_prefix}_mask.bin"

    if os.path.exists(xb_path):
        print(f"Already exists: {xb_path}")
        return xb_path, yb_path, mask_path

    print(f"Precomputing {filename_prefix} dataset ({len(df_texts)} samples)...")
    n = len(df_texts)
    seq_len = context_length

    # Preallocate memmap arrays
    xb_arr = np.memmap(xb_path, dtype=np.int32, mode='w+', shape=(n, seq_len-1))
    yb_arr = np.memmap(yb_path, dtype=np.int32, mode='w+', shape=(n, seq_len-1))
    mask_arr = np.memmap(mask_path, dtype=np.uint8, mode='w+', shape=(n, seq_len-1))

    ignore_index = -100

    # Batch tokenization
    for start in tqdm(range(0, n, batch_size), desc=f"Tokenizing {filename_prefix}"):
        end = min(start + batch_size, n)
        batch_texts = df_texts.iloc[start:end]["text"].tolist()
        
        enc = tokenizer(
            batch_texts,
            truncation=True,
            max_length=seq_len,
            padding="max_length",
            return_tensors="pt",
            return_attention_mask=True
        )
        input_ids = enc["input_ids"]
        attn_mask = enc["attention_mask"]

        # Split xb, yb, and apply ignore_index mask
        xb_arr[start:end] = input_ids[:, :-1].numpy()
        yb_batch = input_ids[:, 1:].clone()
        yb_batch[attn_mask[:, :-1] == 0] = ignore_index
        yb_arr[start:end] = yb_batch.numpy()
        mask_arr[start:end] = attn_mask[:, :-1].numpy()

    # Flush to disk
    xb_arr.flush()
    yb_arr.flush()
    mask_arr.flush()

    print(f"Saved: {xb_path}")
    return xb_path, yb_path, mask_path


# run
train_xb, train_yb, train_mask = save_precomputed_dataset(
    train_df, "train", tokenizer, context_length=CONTEXT_LENGTH, batch_size=1024
)

val_xb, val_yb, val_mask = save_precomputed_dataset(
    val_df, "val", tokenizer, context_length=CONTEXT_LENGTH, batch_size=1024
)



### Define DataLoaders

In [None]:
# Faster DataLodaer
from torch.utils.data import Subset , RandomSampler ,get_worker_info
import torch
torch.multiprocessing.set_start_method("spawn", force=True)
total_len=9278584
len_train=int((total_len*0.9)*0.9) # 1.2 B Tokens
len_val=int((total_len*0.1)*0.75)
class FastPrecomputedDataset(Dataset):
    def __init__(self, prefix):
        
        
        self.xb = np.memmap(f"precomputed/{prefix}_xb.bin", dtype=np.int32, mode='r',
                            shape=(len_train if prefix == "train" else len_val, CONTEXT_LENGTH-1))
        self.yb = np.memmap(f"precomputed/{prefix}_yb.bin", dtype=np.int32, mode='r',
                              shape=(len_train if prefix == "train" else len_val, CONTEXT_LENGTH-1))
        self.mask = np.memmap(f"precomputed/{prefix}_mask.bin", dtype=np.uint8, mode='r',
                                 shape=(len_train if prefix == "train" else len_val, CONTEXT_LENGTH-1))

    def __len__(self):
        return len(self.xb)

    def __getitem__(self, idx):
        xb = torch.from_numpy(self.xb[idx].astype(np.int64))
        yb = torch.from_numpy(self.yb[idx].astype(np.int64))
        mask = torch.from_numpy(self.mask[idx].astype(np.int64))
        return xb, yb, mask


BATCH_SIZE = 32
train_dataset = FastPrecomputedDataset("train")
val_dataset   = FastPrecomputedDataset("val")


train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    sampler=RandomSampler(train_dataset),  
    num_workers=0,
    pin_memory=False,           
    persistent_workers=False
)
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    pin_memory=False,           
    persistent_workers=False
)


## Training Stage
### Training Loss

TallyFormer is trained using standard cross-entropy loss with an additional z-loss regularization term.

#### Total Loss

$$
\mathcal{L}
=
\mathcal{L}_{\text{CE}}
+
\lambda_{\text{z}} \cdot \mathcal{L}_{\text{z}}
$$

#### z-Loss

$$
\mathcal{L}_{\text{z}}
=
\mathbb{E}
\left[
\left(
\log \sum_{j=1}^{V} \exp(z_j)
\right)^2
\right]
$$

Loss is computed only on valid (non-masked) tokens.

#### Reference

Chowdhery et al., *PaLM: Scaling Language Modeling with Pathways*  
https://arxiv.org/abs/2204.02311


In [12]:
!rm -rf /tmp/torchinductor_root
import torch._dynamo
torch._dynamo.config.suppress_errors=True
torch.cuda.empty_cache()


In [37]:

#Define Model Hyperparameters
CONTEXT_LENGTH = 256   # up to 1024 for standard attention; use Sparse Attention beyond that (see Longformer paper)
N_EMBED = 512          # typical range 512–768 for larger models
N_HEAD = 8             # typical range 8–12
N_BLOCKS = 8           # typical range 8–12
NORM_TYPE='pre'
RMS_NORM=True
USE_ZLoss=True
model_config = ModelConfig(
    context_length=CONTEXT_LENGTH,
    n_embed=N_EMBED,
    n_head=N_HEAD,
    n_block=N_BLOCKS,
    vocab_size=len(tokenizer),
    pad_token_id=tokenizer.pad_token_id,
    norm_type=NORM_TYPE,
    rms_norm=RMS_NORM,
    use_zloss=USE_ZLoss
)

model=Transformer(model_config).to(device=DEVICE)
model = torch.compile(model)

model._get_size()


In [38]:
import torch
from torch.amp import autocast, GradScaler
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from itertools import cycle
import math
import os
import psutil
from rich.console import Console
from contextlib import nullcontext

os.environ["TOKENIZERS_PARALLELISM"] = "false"

console = Console()

#GPU monitoring
if torch.cuda.is_available():
    import pynvml
    pynvml.nvmlInit()

def get_gpu_usage(device_idx=0):
    try:
        handle = pynvml.nvmlDeviceGetHandleByIndex(device_idx)
        mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
        util = pynvml.nvmlDeviceGetUtilizationRates(handle)
        return util.gpu, mem_info.used / 1024**2
    except:
        return 0.0, 0.0

#fix torch.compile
import torch._dynamo
torch._dynamo.config.capture_scalar_outputs = True

# Training Hyperparameters
N_EPOCHS           = 60                    
GRAD_ACCUM_STEPS       = 6
LEARNING_RATE          = 8e-4
WEIGHT_DECAY           = 0.1
WARMUP_EPOCHS          = 2
MAX_GRAD_NORM           = 1.0
PATIENCE               = 4                        # early stopping

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = 'bfloat16' if torch.cuda.is_bf16_supported() else 'float16'
ptdtype = torch.bfloat16 if DTYPE == 'bfloat16' else torch.float16
if torch.cuda.is_available():
    if torch.cuda.is_bf16_supported(): 
        ptdtype = torch.bfloat16
        use_scaler = False    # IMPORTANT ,donot use scaler for bfloat16
    else:
        ptdtype = torch.float16
        use_scaler = True
else:
    ptdtype = torch.float32
    use_scaler = False

ctx = autocast(device_type='cuda', dtype=ptdtype) if DEVICE == 'cuda' else nullcontext()
scaler = GradScaler(enabled=use_scaler)
print(f"Use Scaler ? {use_scaler}")

#Apply data parallel in case of multiple cuda gpu's
if torch.cuda.is_available():
  num_gpus=torch.cuda.device_count()
  if num_gpus>1:
    console.print(f"[bold green] Using {num_gpus} GPUs with DataParallel [/bold green]")
    model=torch.nn.DataParallel(model)
  else:
    console.print("[bold green]Using single GPU [/bold green]")
else:
  console.print("[bold green] Using CPU [/bold green]")
# Calculate MaxIters
steps_per_epoch = len(train_loader)//GRAD_ACCUM_STEPS
MAX_ITERS = steps_per_epoch * N_EPOCHS

console.print(f"[bold cyan]Training for {N_EPOCHS} epochs → {MAX_ITERS:,} steps "
              f"({steps_per_epoch:,} steps/epoch)[/bold cyan]")
console.print(f"[bold green]Effective batch size: {BATCH_SIZE * GRAD_ACCUM_STEPS}[/bold green]")

infinite_loader = cycle(train_loader)

optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY ,betas=(0.9, 0.95))
WARMUP_FRACTION = 0.02   # 2%
total_opt_steps = MAX_ITERS // GRAD_ACCUM_STEPS
warmup_steps = int(WARMUP_FRACTION * total_opt_steps)

def lr_lambda(step):
    if step < warmup_steps:
        return float(step + 1) / warmup_steps
    progress = (step - warmup_steps) / (total_opt_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)


# Metrics and saving
metrics = {"train_loss": [], "val_perplexity": [], "cpu_usage": [], "gpu_usage": [], "gpu_mem": []}
OUT_DIR = "./PreTrainResult"
os.makedirs(OUT_DIR, exist_ok=True)
BEST_MODEL_PATH = os.path.join(OUT_DIR, "pretrain_tallyformer.pth")

best_val_ppl = float('inf')
patience_counter = 0

# Training Loop
lr = LEARNING_RATE 
pbar = tqdm(total=MAX_ITERS, desc="Training",dynamic_ncols=True, bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]")

model.train()
iter_num = 0
epoch_train_losses = []
# We want 5 monitoring points per epoch
MONITOR_POINTS = 50
monitor_interval = max(1, steps_per_epoch // MONITOR_POINTS)

epoch_cpu_samples = []
epoch_gpu_samples = []
epoch_gpu_mem_samples = []

for iter_num in range(MAX_ITERS):
    xb, yb, attn_mask = next(infinite_loader)
    xb = xb.to(DEVICE, non_blocking=True)
    yb = yb.to(DEVICE, non_blocking=True)
    attn_mask = attn_mask.to(DEVICE, non_blocking=True)

    with ctx:
        _, loss, _ = model(input_ids=xb, labels=yb, attention_mask=attn_mask)
        loss = loss / GRAD_ACCUM_STEPS
    if use_scaler:
        scaler.scale(loss).backward()
    else:
        loss.backward()
    epoch_train_losses.append(loss.item() * GRAD_ACCUM_STEPS)

    # GPU/CPU monitoring
    if (iter_num % monitor_interval) == 0:
        epoch_cpu_samples.append(psutil.cpu_percent(interval=None))
        if DEVICE == "cuda":
            gpu_u, gpu_m = get_gpu_usage()
            epoch_gpu_samples.append(gpu_u)
            epoch_gpu_mem_samples.append(gpu_m)

    if (iter_num + 1) % GRAD_ACCUM_STEPS == 0 or iter_num == MAX_ITERS - 1:
        if use_scaler:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
            scaler.step(optimizer)
            scaler.update()
        else:
            torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
            optimizer.step()
        optimizer.zero_grad(set_to_none=True)

        opt_step = iter_num // GRAD_ACCUM_STEPS

        scheduler.step()
        lr = scheduler.get_last_lr()[0]

    # Live tqdm update
    current_epoch = (iter_num + 1) / steps_per_epoch
    pbar.set_postfix({
        "Epoch": f"{current_epoch:.2f}/{N_EPOCHS}",
        "Loss": f"{loss.item() * GRAD_ACCUM_STEPS:.4f}",
        "LR": f"{lr:.1e}",
        "GPU": f"{epoch_gpu_samples[-1]:.0f}%" if (DEVICE=="cuda" and epoch_gpu_samples) else "N/A",

    })
    pbar.update(1)

    # End of epoch
    if (iter_num + 1) % steps_per_epoch == 0:
        # Compute epoch metrics
        avg_train_loss = sum(epoch_train_losses) / len(epoch_train_losses)
        if len(epoch_cpu_samples) > 0:
            avg_cpu = sum(epoch_cpu_samples) / len(epoch_cpu_samples)
            avg_gpu = sum(epoch_gpu_samples) / len(epoch_gpu_samples) if epoch_gpu_samples else 0
            avg_gpu_mem = sum(epoch_gpu_mem_samples) / len(epoch_gpu_mem_samples) if epoch_gpu_mem_samples else 0
        else:
            avg_cpu, avg_gpu, avg_gpu_mem = 0, 0, 0
        

        # Validation
        model.eval()
        val_loss = 0.0
        val_tokens = 0
        val_pbar = tqdm(val_loader, desc=f"Validating (Epoch {int(current_epoch)})",
                leave=False, dynamic_ncols=True)
        
        with torch.no_grad():
            for xb_v, yb_v, mask_v in val_pbar:
                xb_v, yb_v, mask_v = xb_v.to(DEVICE), yb_v.to(DEVICE), mask_v.to(DEVICE)
                with ctx:
                    _, loss_v, _ = model(input_ids=xb_v, labels=yb_v, attention_mask=mask_v)
                
            
                valid_tokens = int((yb_v != -100).sum().item())  # Count VALID tokens only
                if valid_tokens > 0:
                    val_loss += loss_v.item() * valid_tokens
                    val_tokens += valid_tokens
                
                val_pbar.set_postfix({
                    "BatchLoss": f"{loss_v.item():.4f}"
                })
        
        avg_val_loss = val_loss / val_tokens if val_tokens > 0 else float('inf')
        val_ppl = math.exp(avg_val_loss)

        # Save metrics
        metrics["train_loss"].append(avg_train_loss)
        metrics["val_perplexity"].append(val_ppl)
        metrics["cpu_usage"].append(avg_cpu)
        metrics["gpu_usage"].append(avg_gpu)
        metrics["gpu_mem"].append(avg_gpu_mem)

        console.print(f"\n[bold green]EPOCH {int(current_epoch)}/{N_EPOCHS} COMPLETE[/bold green]")
        console.print(f"   Train Loss: {avg_train_loss:.4f} | Val Loss: {val_loss:.4f} | Val PPL: {val_ppl:.2f}")
        console.print(f"   CPU: {avg_cpu:.1f}% | GPU: {avg_gpu:.1f}% | GPU Mem: {avg_gpu_mem:.0f} MB | LR: {lr:.1e}")

        # Early stopping & save best
        if val_ppl < best_val_ppl:
            best_val_ppl = val_ppl
            patience_counter = 0
            save_payload = {
                "epoch": int(current_epoch),
                "val_ppl": val_ppl,
                "metrics": metrics
            }
            
            Transformer.save(
                model,
                BEST_MODEL_PATH,
                config=model_config,
                extra_dict=save_payload
            )
            
            console.print(f"[bold yellow]NEW BEST MODEL SAVED! Val PPL: {val_ppl:.2f}[/bold yellow]")

        else:
            patience_counter += 1
            console.print(f"[bold red]No improvement ({patience_counter}/{PATIENCE})[/bold red]")

        if patience_counter >= PATIENCE:
            console.print(f"[bold yellow]Early stopping at epoch {int(current_epoch)}[/bold yellow]")
            break

        # Reset for next epoch
        model.train()
        epoch_train_losses = []
        epoch_cpu_samples = []
        epoch_gpu_samples = []
        epoch_gpu_mem_samples = []
        train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    sampler=RandomSampler(train_dataset),  # reshuffled each epoch
    num_workers=0,
    pin_memory=False,
    persistent_workers=False
)
        infinite_loader = cycle(train_loader)
pbar.close()
console.print(f"[bold green]Training finished! Best Val PPL: {best_val_ppl:.2f}[/bold green]")

Use Scaler ? False


Training:   0%|          | 0/2348640 [00:00<?, ?it/s]

Validating (Epoch 1):   0%|          | 0/21747 [00:00<?, ?it/s]

Model saved to ./PreTrainResult/pretrain_tallyformer.pth


IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Validating (Epoch 2):   0%|          | 0/21747 [00:00<?, ?it/s]

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Model saved to ./PreTrainResult/pretrain_tallyformer.pth


Validating (Epoch 3):   0%|          | 0/21747 [00:00<?, ?it/s]

Model saved to ./PreTrainResult/pretrain_tallyformer.pth


Validating (Epoch 4):   0%|          | 0/21747 [00:00<?, ?it/s]

Model saved to ./PreTrainResult/pretrain_tallyformer.pth


Validating (Epoch 5):   0%|          | 0/21747 [00:00<?, ?it/s]

Model saved to ./PreTrainResult/pretrain_tallyformer.pth


Validating (Epoch 6):   0%|          | 0/21747 [00:00<?, ?it/s]

Model saved to ./PreTrainResult/pretrain_tallyformer.pth


Validating (Epoch 7):   0%|          | 0/21747 [00:00<?, ?it/s]

Model saved to ./PreTrainResult/pretrain_tallyformer.pth


Validating (Epoch 8):   0%|          | 0/21747 [00:00<?, ?it/s]

Model saved to ./PreTrainResult/pretrain_tallyformer.pth


Validating (Epoch 9):   0%|          | 0/21747 [00:00<?, ?it/s]

Model saved to ./PreTrainResult/pretrain_tallyformer.pth


Validating (Epoch 10):   0%|          | 0/21747 [00:00<?, ?it/s]

Model saved to ./PreTrainResult/pretrain_tallyformer.pth


Validating (Epoch 11):   0%|          | 0/21747 [00:00<?, ?it/s]

Model saved to ./PreTrainResult/pretrain_tallyformer.pth


Validating (Epoch 12):   0%|          | 0/21747 [00:00<?, ?it/s]

Model saved to ./PreTrainResult/pretrain_tallyformer.pth


Validating (Epoch 13):   0%|          | 0/21747 [00:00<?, ?it/s]

Model saved to ./PreTrainResult/pretrain_tallyformer.pth


Validating (Epoch 14):   0%|          | 0/21747 [00:00<?, ?it/s]

Model saved to ./PreTrainResult/pretrain_tallyformer.pth


Validating (Epoch 15):   0%|          | 0/21747 [00:00<?, ?it/s]

Model saved to ./PreTrainResult/pretrain_tallyformer.pth


Validating (Epoch 16):   0%|          | 0/21747 [00:00<?, ?it/s]

Model saved to ./PreTrainResult/pretrain_tallyformer.pth


Validating (Epoch 17):   0%|          | 0/21747 [00:00<?, ?it/s]

Model saved to ./PreTrainResult/pretrain_tallyformer.pth


Validating (Epoch 18):   0%|          | 0/21747 [00:00<?, ?it/s]

Model saved to ./PreTrainResult/pretrain_tallyformer.pth


Validating (Epoch 19):   0%|          | 0/21747 [00:00<?, ?it/s]

Model saved to ./PreTrainResult/pretrain_tallyformer.pth


Validating (Epoch 20):   0%|          | 0/21747 [00:00<?, ?it/s]

Model saved to ./PreTrainResult/pretrain_tallyformer.pth


Validating (Epoch 21):   0%|          | 0/21747 [00:00<?, ?it/s]

Model saved to ./PreTrainResult/pretrain_tallyformer.pth


IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Validating (Epoch 22):   0%|          | 0/21747 [00:00<?, ?it/s]

Model saved to ./PreTrainResult/pretrain_tallyformer.pth


IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Validating (Epoch 23):   0%|          | 0/21747 [00:00<?, ?it/s]

Model saved to ./PreTrainResult/pretrain_tallyformer.pth


Validating (Epoch 24):   0%|          | 0/21747 [00:00<?, ?it/s]

Model saved to ./PreTrainResult/pretrain_tallyformer.pth


Validating (Epoch 25):   0%|          | 0/21747 [00:00<?, ?it/s]

Model saved to ./PreTrainResult/pretrain_tallyformer.pth


Validating (Epoch 26):   0%|          | 0/21747 [00:00<?, ?it/s]

Model saved to ./PreTrainResult/pretrain_tallyformer.pth


Validating (Epoch 27):   0%|          | 0/21747 [00:00<?, ?it/s]

Model saved to ./PreTrainResult/pretrain_tallyformer.pth


Validating (Epoch 28):   0%|          | 0/21747 [00:00<?, ?it/s]

Model saved to ./PreTrainResult/pretrain_tallyformer.pth


Validating (Epoch 29):   0%|          | 0/21747 [00:00<?, ?it/s]

Model saved to ./PreTrainResult/pretrain_tallyformer.pth


Validating (Epoch 30):   0%|          | 0/21747 [00:00<?, ?it/s]

Model saved to ./PreTrainResult/pretrain_tallyformer.pth


Validating (Epoch 31):   0%|          | 0/21747 [00:00<?, ?it/s]

Model saved to ./PreTrainResult/pretrain_tallyformer.pth


Validating (Epoch 32):   0%|          | 0/21747 [00:00<?, ?it/s]

Model saved to ./PreTrainResult/pretrain_tallyformer.pth


Validating (Epoch 33):   0%|          | 0/21747 [00:00<?, ?it/s]

Model saved to ./PreTrainResult/pretrain_tallyformer.pth


Validating (Epoch 34):   0%|          | 0/21747 [00:00<?, ?it/s]

Model saved to ./PreTrainResult/pretrain_tallyformer.pth


Validating (Epoch 35):   0%|          | 0/21747 [00:00<?, ?it/s]

Model saved to ./PreTrainResult/pretrain_tallyformer.pth


Validating (Epoch 36):   0%|          | 0/21747 [00:00<?, ?it/s]

Model saved to ./PreTrainResult/pretrain_tallyformer.pth


Validating (Epoch 37):   0%|          | 0/21747 [00:00<?, ?it/s]

Model saved to ./PreTrainResult/pretrain_tallyformer.pth


Validating (Epoch 38):   0%|          | 0/21747 [00:00<?, ?it/s]

Model saved to ./PreTrainResult/pretrain_tallyformer.pth


Validating (Epoch 39):   0%|          | 0/21747 [00:00<?, ?it/s]

Model saved to ./PreTrainResult/pretrain_tallyformer.pth


Validating (Epoch 40):   0%|          | 0/21747 [00:00<?, ?it/s]

Model saved to ./PreTrainResult/pretrain_tallyformer.pth


Validating (Epoch 41):   0%|          | 0/21747 [00:00<?, ?it/s]

Model saved to ./PreTrainResult/pretrain_tallyformer.pth


Validating (Epoch 42):   0%|          | 0/21747 [00:00<?, ?it/s]

Model saved to ./PreTrainResult/pretrain_tallyformer.pth


Validating (Epoch 43):   0%|          | 0/21747 [00:00<?, ?it/s]

Model saved to ./PreTrainResult/pretrain_tallyformer.pth


Validating (Epoch 44):   0%|          | 0/21747 [00:00<?, ?it/s]

Model saved to ./PreTrainResult/pretrain_tallyformer.pth


Validating (Epoch 45):   0%|          | 0/21747 [00:00<?, ?it/s]

Model saved to ./PreTrainResult/pretrain_tallyformer.pth


Validating (Epoch 46):   0%|          | 0/21747 [00:00<?, ?it/s]

Model saved to ./PreTrainResult/pretrain_tallyformer.pth


Validating (Epoch 47):   0%|          | 0/21747 [00:00<?, ?it/s]

Model saved to ./PreTrainResult/pretrain_tallyformer.pth


Validating (Epoch 48):   0%|          | 0/21747 [00:00<?, ?it/s]

Model saved to ./PreTrainResult/pretrain_tallyformer.pth


Validating (Epoch 49):   0%|          | 0/21747 [00:00<?, ?it/s]

Model saved to ./PreTrainResult/pretrain_tallyformer.pth


Validating (Epoch 50):   0%|          | 0/21747 [00:00<?, ?it/s]

Model saved to ./PreTrainResult/pretrain_tallyformer.pth


IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Validating (Epoch 51):   0%|          | 0/21747 [00:00<?, ?it/s]

Model saved to ./PreTrainResult/pretrain_tallyformer.pth


Validating (Epoch 52):   0%|          | 0/21747 [00:00<?, ?it/s]

Model saved to ./PreTrainResult/pretrain_tallyformer.pth


Validating (Epoch 53):   0%|          | 0/21747 [00:00<?, ?it/s]

Model saved to ./PreTrainResult/pretrain_tallyformer.pth


IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Validating (Epoch 54):   0%|          | 0/21747 [00:00<?, ?it/s]

Model saved to ./PreTrainResult/pretrain_tallyformer.pth


IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Validating (Epoch 55):   0%|          | 0/21747 [00:00<?, ?it/s]

Model saved to ./PreTrainResult/pretrain_tallyformer.pth


IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The Jupyter serve

Validating (Epoch 57):   0%|          | 0/21747 [00:00<?, ?it/s]

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Validating (Epoch 58):   0%|          | 0/21747 [00:00<?, ?it/s]

Model saved to ./PreTrainResult/pretrain_tallyformer.pth


IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Validating (Epoch 59):   0%|          | 0/21747 [00:00<?, ?it/s]

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Validating (Epoch 60):   0%|          | 0/21747 [00:00<?, ?it/s]

Model saved to ./PreTrainResult/pretrain_tallyformer.pth


In [30]:
# Check the final perplexity
model_path="./PreTrainResult/pretrain_tallyformer.pth"
model,model_config,metrics=Transformer.load(model_path,device='cuda')
model.perplexity(val_loader,'cuda')

The tite done
Detected _orig_mod. prefixes → auto-cleaning state_dict...
Model loaded successfully!
   Epoch: 60 | Val PPL: 38.97 | Params: ~51M | Device: cuda


Computing Perplexity:   0%|          | 0/21747 [00:00<?, ?it/s]

38.562077264495024

## Inference Stage

### Generate Text

In [99]:
inference_config=InferenceConfig(max_new_tokens=100, temperature=0.7, topk=500, topp=0.9, frequency_penalty=0.0, presence_penalty=0.0)
prompt = "One day in the forest ,  "
model_path="./PreTrainResult/pretrain_tallyformer.pth"
model,model_config,metrics=Transformer.load(model_path,device='cuda')

generated = generate_text(model, tokenizer, prompt, inference_config)
print("Generated:", generated)

The tite done
Detected _orig_mod. prefixes → auto-cleaning state_dict...
Model loaded successfully!
   Epoch: 60 | Val PPL: 38.97 | Params: ~51M | Device: cuda


Generated: __________ A tripurate and wint woven into the land of the mountain . A time of the day , a day of the day , a dayHurvernight . A dayutions , a keyword , deprive of meaning . A day ofimpression , Spawning a feeling of guilt , a feeling®ying .ened with the hope of a greater happiness . The day of the cohesive , a senseHAMMOND , the cause of the earth ' enzwwl , the causeibly of antiparas


### estimate the Benchmark inference on the current instance

In [None]:

import time
import torch
import matplotlib.pyplot as plt
from rich.console import Console
from rich.table import Table
import numpy as np
from typing import List, Tuple
import psutil  # For CPU util
import pynvml  # For GPU util (init if cuda)

console = Console()

MODEL_PATH = "PreTrainResult/pretrain_tallyformer.pth"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"    # Change here
N_RUNS_PER_LENGTH = 10
MAX_NEW_TOKENS = 256
INFERENCE_CONFIG = InferenceConfig(
    max_new_tokens=MAX_NEW_TOKENS,
    temperature=0.8,
    topk=100,
    topp=0.95,
    frequency_penalty=0.0,
    presence_penalty=1.1,
    eos_tokens=None
)

# Prompt lengths to test (in tokens)
PROMPT_LENGTHS =[4,8,16,32, 64, 96, 128, 160, 192, 224, 250]

# Testing Prompts
BASE_PROMPTS = [
 
    "The Board of Directors has approved a new share repurchase program of up to",
    "According to the latest quarterly earnings report, revenue increased by 18% year-over-year to",
    "The Federal Reserve announced today that it will maintain interest rates at",
    "Due to rising inflation and supply chain disruptions, the company has decided to",
    "The merger between Company A and Company B is expected to close in the fourth quarter of",
    "Net income for the fiscal year 2024 reached a record high of $2.8 billion, representing",
    "Recent advances in large language models have demonstrated remarkable capabilities in",
    "Researchers at Stanford University have developed a new algorithm that can detect",
    "The deployment of 5G networks across major cities has enabled real-time applications such as",
    "Quantum computing represents a paradigm shift in computational power, potentially solving problems that",
    "A new study published in Nature suggests that climate change may accelerate beyond previous projections due to",
    "In accordance with Section 404 of the Sarbanes-Oxley Act, management has concluded that",
    "The European Union's General Data Protection Regulation (GDPR) requires companies to",
    "The Securities and Exchange Commission has issued new guidance regarding disclosure of",
    "Pursuant to the terms of the agreement dated March 15, 2024, the parties agree that",
    "The experimental results indicate a statistically significant improvement (p < 0.001) in accuracy when using",
    "Figure 3 shows the relationship between input size and inference latency on GPU hardware, where we observe",
    "Previous work by Johnson et al. (2023) proposed a similar architecture, however our method achieves",
    "The proposed transformer-based model was trained on 1.2 billion tokens from diverse sources including",
    "Breaking: The government has just announced a major infrastructure investment package worth",
    "Sources close to the matter confirm that negotiations between the two parties are ongoing and expected to",
    "Market analysts predict that the price of crude oil will remain above $80 per barrel through",
    "The World Health Organization has declared a new public health emergency following reports of",
    "def calculate_returns(prices: List[float]) -> List[float]:\n    \"\"\"Compute daily returns from price series.\"\"\"\n    returns = []",
    "import torch\nimport torch.nn as nn\nclass MultiHeadAttention(nn.Module):\n    def __init__(self, d_model, num_heads):",
    "The following SQL query retrieves all transactions from the past 30 days where amount > 10000:",
    "Artificial intelligence (AI) is intelligence demonstrated by machines, as opposed to natural intelligence displayed by humans",
    "The Industrial Revolution, which took place from the 18th to 19th centuries, was a period during which predominantly agrarian societies",
    "In economics, inflation refers to a general increase in prices and fall in the purchasing value of money",
    "Machine learning is a field of inquiry devoted to understanding and building methods that learn from data",
    "This Agreement shall be governed by and construed in accordance with the laws of the State of Delaware",
    "Neither party shall be liable for any failure or delay in performance caused by circumstances beyond its reasonable control",
    "The Seller hereby warrants that the goods delivered shall be free from defects in material and workmanship for a period of",
    "Over the past decade, renewable energy sources such as solar and wind have become increasingly cost-competitive with",
    "The concept of universal basic income has gained traction among policymakers as a potential solution to",
    "Remote work has transformed the modern workplace, offering employees greater flexibility while presenting new challenges for",
    "Supply chain resilience has become a top priority for global corporations following disruptions caused by",
]


def get_prompt_of_length(target_len: int) -> str:
    """Repeat or truncate base prompts to reach target token count"""
    while True:
        prompt = " ".join(np.random.choice(BASE_PROMPTS, size=np.random.randint(3, 15)))
        tokens = len(tokenizer.encode(prompt))
        if tokens >= target_len:
            encoded = tokenizer.encode(prompt)[:target_len]
            return tokenizer.decode(encoded, skip_special_tokens=True)

def get_utilization(device_idx=0):
    """Get utilization (%) - GPU or CPU based on DEVICE"""
    if DEVICE == "cuda":
        handle = pynvml.nvmlDeviceGetHandleByIndex(device_idx)
        util = pynvml.nvmlDeviceGetUtilizationRates(handle)
        return util.gpu 
    else:
        return psutil.cpu_percent(interval=None)  

def benchmark_once(prompt: str) -> Tuple[float, float, float, float, float]:
    """Returns: ttft_ms, gen_tps, total_tps, total_time, avg_util"""
    start_time = time.time()
    
    encoded = tokenizer(prompt, return_tensors='pt', truncation=False)
    input_ids = encoded['input_ids'].to(DEVICE)

    util_samples = []  # Sample util during generation
    
    model.eval()
    with torch.no_grad():
        # Sample util before/after for average (or more during if possible)
        util_samples.append(get_utilization())
        
        output_ids, ttft_sec = model.generate(
            input_ids=input_ids,
            config=INFERENCE_CONFIG,
            tokenizer=tokenizer
        )
        
        util_samples.append(get_utilization())  # After

    ttft_ms = 0.0 if ttft_sec is None else ttft_sec * 1000
    
    total_time = time.time() - start_time
    prompt_tokens = input_ids.shape[1]
    total_tokens = output_ids.shape[1]
    generated_tokens = total_tokens - prompt_tokens

    gen_tps = generated_tokens / total_time if total_time > 0 else 0
    total_tps = total_tokens / total_time if total_time > 0 else 0
    
    avg_util = np.mean(util_samples) if util_samples else 0.0

    return ttft_ms, gen_tps, total_tps, total_time, avg_util

# initiate for gpu
if DEVICE == "cuda":
    pynvml.nvmlInit()

console.print(f"[bold blue]Loading model on {DEVICE}...[/bold blue]")
model, config, info = Transformer.load(MODEL_PATH, device=DEVICE)
console.print(f"Model loaded → Epoch {info.get('epoch', '?')} | Val PPL {info.get('val_ppl', '?'):.2f}")

#Run Benchmarks
results = {
    "length": [],
    "ttft_mean": [], "ttft_std": [],
    "gen_tps_mean": [], "gen_tps_std": [],
    "total_tps_mean": [], "total_tps_std": [],
    "util_mean": [], "util_std": [] 
}

console.print(f"[bold green]Starting benchmark: {N_RUNS_PER_LENGTH} runs × {len(PROMPT_LENGTHS)} lengths[/bold green]")

from tqdm import tqdm

for length in PROMPT_LENGTHS:
    console.print(f"\n[bold cyan]→ Benchmarking prompt length: {length} tokens[/bold cyan]")
    
    ttfts, gen_tps_list, total_tps_list, utils = [], [], [], []
    
    for i in tqdm(range(N_RUNS_PER_LENGTH), 
                  desc="Running", 
                  leave=False,
                  bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}]"):
        
        prompt = get_prompt_of_length(length)
        ttft, gen_tps, total_tps, _, util = benchmark_once(prompt)
        
        ttfts.append(ttft)
        gen_tps_list.append(gen_tps)
        total_tps_list.append(total_tps)
        utils.append(util)
    
    # Store averages
    results["length"].append(length)
    results["ttft_mean"].append(np.mean(ttfts))
    results["ttft_std"].append(np.std(ttfts))
    results["gen_tps_mean"].append(np.mean(gen_tps_list))
    results["gen_tps_std"].append(np.std(gen_tps_list))
    results["total_tps_mean"].append(np.mean(total_tps_list))
    results["total_tps_std"].append(np.std(total_tps_list))
    results["util_mean"].append(np.mean(utils))
    results["util_std"].append(np.std(utils))

# Plot the results
fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(10, 12)) 

x = results["length"]

ax1.errorbar(x, results["ttft_mean"], yerr=results["ttft_std"], 
             marker='o', capsize=5, label="TTFT (ms)", color="tab:blue")
ax1.set_ylabel("Time to First Token (ms)")
ax1.set_xlabel("Prompt Length (tokens)")
ax1.set_title(f"51M Model Inference Benchmark (device={DEVICE}) - {N_RUNS_PER_LENGTH} runs avg")
ax1.grid(True, alpha=0.3)
ax1.legend()

ax2.errorbar(x, results["gen_tps_mean"], yerr=results["gen_tps_std"], 
             marker='s', capsize=5, label="Generation TPS", color="tab:green")
ax2.errorbar(x, results["total_tps_mean"], yerr=results["total_tps_std"], 
             marker='^', capsize=5, label="Total TPS (prompt+gen)", color="tab:orange")
ax2.set_ylabel("Tokens per Second")
ax2.set_xlabel("Prompt Length (tokens)")
ax2.grid(True, alpha=0.3)
ax2.legend()

#Utilization plot
ax3.errorbar(x, results["util_mean"], yerr=results["util_std"], 
             marker='d', capsize=5, label=f"{'GPU' if DEVICE=='cuda' else 'CPU'} Utilization (%)", color="tab:purple")
ax3.set_ylabel("Utilization (%)")
ax3.set_xlabel("Prompt Length (tokens)")
ax3.grid(True, alpha=0.3)
ax3.legend()

plt.tight_layout()
plt.savefig("inference_benchmark_51M.png", dpi=200)
plt.show()

#Print Tables
table = Table(title="Inference Benchmark Results")
table.add_column("Prompt Len", justify="right")
table.add_column("TTFT (ms)", justify="right")
table.add_column("Gen TPS", justify="right")
table.add_column("Total TPS", justify="right")
table.add_column("Util (%)", justify="right")  

for i, length in enumerate(results["length"]):
    table.add_row(
        f"{length}",
        f"{results['ttft_mean'][i]:.1f} ± {results['ttft_std'][i]:.1f}",
        f"{results['gen_tps_mean'][i]:.1f} ± {results['gen_tps_std'][i]:.1f}",
        f"{results['total_tps_mean'][i]:.1f} ± {results['total_tps_std'][i]:.1f}",
        f"{results['util_mean'][i]:.1f} ± {results['util_std'][i]:.1f}",
    )

console.print(table)


# Knowledge destillation
---


## Load KD Dataset

In [12]:
BATCH_SIZE=32
CONTEXT_LENGTH = 256   # up to 1024 for standard attention; use Sparse Attention beyond that (see Longformer)

In [None]:
from datasets import load_dataset
import pandas as pd
# load the distillation dataset from Hugging Face Hub
dataset = load_dataset(
    "haidar-ali/tallyformer-finance-dataset",
    data_dir="Data/DistillationData",  
    split="train"
)
df = pd.DataFrame(dataset)
df = df.sample(frac=1, random_state=42).reset_index(drop=True)
data_size = len(df)
train_size = int(0.9 * data_size)
train_df = df[:train_size]
val_df = df[train_size:]

In [32]:
len_train=len(train_df)
len_val=len(val_df)
print(len_train)
print(len_val)

4242593
471400


## PrePare Data for online tokenization

### Create KDDataset Class

In [16]:
from torch.utils.data import Dataset, DataLoader
import torch

class KDDataset(Dataset):
    def __init__(self, data, tokenizer, context_length):
        self.data = data
        self.tokenizer = tokenizer
        self.context_length = context_length
        self.ignore_index = -100
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if isinstance(self.data,pd.DataFrame):
          text=self.data.iloc[idx]['text']
        else:
          text = self.data[idx]["text"]
        encoded = self.tokenizer(
            text,
            truncation=True,
            max_length=self.context_length,
            padding="max_length",
            return_tensors='pt'
        )
        input_ids=encoded['input_ids'].squeeze(0).cpu()
        attn_mask=encoded['attention_mask'].squeeze(0).cpu()
        xb = input_ids[:-1]
        yb = input_ids[1:].clone()
        attn_mask = attn_mask[:-1]
        # mask padding
        yb[attn_mask == 0] = self.ignore_index
        return xb, yb,attn_mask



### Define DataLoaders

In [None]:
len_train=len(train_df)
len_val=len(val_df)

train_df=train_df.iloc[:len_train]
val_df=val_df.iloc[:len_val]

# Dataset
train_dataset = KDDataset(train_df, tokenizer, CONTEXT_LENGTH)
val_dataset = KDDataset(val_df, tokenizer, CONTEXT_LENGTH)

# DataLoader
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,pin_memory=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=True,pin_memory=True, num_workers=0)

## Prepare Data for offline tokenization
### Make sure you have enough disk storage ~ 9.context_length.(len(train_df)+len(val_df))/1e6 GB.
#### Otherwise your instance might start sweating, shaking, and reconsidering its life choices. 

### Run this once to create .bin files

In [None]:

import os
import numpy as np
import torch
from tqdm import tqdm

def save_precomputed_dataset(df_texts, filename_prefix, tokenizer, context_length=256, batch_size=1024):
    os.makedirs("precomputed", exist_ok=True)
    xb_path = f"precomputed/{filename_prefix}_xb.bin"
    yb_path = f"precomputed/{filename_prefix}_yb.bin"
    mask_path = f"precomputed/{filename_prefix}_mask.bin"

    if os.path.exists(xb_path):
        print(f"Already exists: {xb_path}")
        return xb_path, yb_path, mask_path

    print(f"Precomputing {filename_prefix} dataset ({len(df_texts)} samples)...")
    n = len(df_texts)
    seq_len = context_length

    # Preallocate memmap arrays
    xb_arr = np.memmap(xb_path, dtype=np.int32, mode='w+', shape=(n, seq_len-1))
    yb_arr = np.memmap(yb_path, dtype=np.int32, mode='w+', shape=(n, seq_len-1))
    mask_arr = np.memmap(mask_path, dtype=np.uint8, mode='w+', shape=(n, seq_len-1))

    ignore_index = -100

    # Batch tokenization
    for start in tqdm(range(0, n, batch_size), desc=f"Tokenizing {filename_prefix}"):
        end = min(start + batch_size, n)
        batch_texts = df_texts.iloc[start:end]["text"].tolist()
        
        enc = tokenizer(
            batch_texts,
            truncation=True,
            max_length=seq_len,
            padding="max_length",
            return_tensors="pt",
            return_attention_mask=True
        )
        input_ids = enc["input_ids"]
        attn_mask = enc["attention_mask"]

        # Split xb, yb, and apply ignore_index mask
        xb_arr[start:end] = input_ids[:, :-1].numpy()
        yb_batch = input_ids[:, 1:].clone()
        yb_batch[attn_mask[:, :-1] == 0] = ignore_index
        yb_arr[start:end] = yb_batch.numpy()
        mask_arr[start:end] = attn_mask[:, :-1].numpy()

    # Flush to disk
    xb_arr.flush()
    yb_arr.flush()
    mask_arr.flush()

    print(f"Saved: {xb_path}")
    return xb_path, yb_path, mask_path


# Run
train_xb, train_yb, train_mask = save_precomputed_dataset(
    train_df, "train", tokenizer, context_length=CONTEXT_LENGTH, batch_size=1024
)

val_xb, val_yb, val_mask = save_precomputed_dataset(
    val_df, "val", tokenizer, context_length=CONTEXT_LENGTH, batch_size=1024
)



### Define DataLoaders

In [None]:
# Faster DataLoaders
from torch.utils.data import Subset , RandomSampler ,get_worker_info
import torch
torch.multiprocessing.set_start_method("spawn", force=True)
len_train=len(train_df)
len_val=int(len(val_df)*0.25) # to save time and money
class FastPrecomputedDataset(Dataset):
    def __init__(self, prefix):
        
        
        self.xb = np.memmap(f"precomputed/{prefix}_xb.bin", dtype=np.int32, mode='r',
                            shape=(len_train if prefix == "train" else len_val, CONTEXT_LENGTH-1))
        self.yb = np.memmap(f"precomputed/{prefix}_yb.bin", dtype=np.int32, mode='r',
                              shape=(len_train if prefix == "train" else len_val, CONTEXT_LENGTH-1))
        self.mask = np.memmap(f"precomputed/{prefix}_mask.bin", dtype=np.uint8, mode='r',
                                 shape=(len_train if prefix == "train" else len_val, CONTEXT_LENGTH-1))

    def __len__(self):
        return len(self.xb)

    def __getitem__(self, idx):
        xb = torch.from_numpy(self.xb[idx].astype(np.int64))
        yb = torch.from_numpy(self.yb[idx].astype(np.int64))
        mask = torch.from_numpy(self.mask[idx].astype(np.int64))
        return xb, yb, mask


BATCH_SIZE = 32
train_dataset = FastPrecomputedDataset("train")
val_dataset   = FastPrecomputedDataset("val")

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    sampler=RandomSampler(train_dataset),  
    num_workers=0,
    pin_memory=False,          
    persistent_workers=False
)
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    pin_memory=False,           
    persistent_workers=False
)


## Training Stage
### Knowledge Distillation Loss

During distillation, TallyFormer is trained to mimic a teacher model (`gpt2-medium`) using a **combination of cross-entropy and KL-divergence losses**.

#### Total Loss

The student loss is a weighted sum:

L_total = α * L_KL + (1 - α) * L_CE

where:

- **L_CE**: Standard cross-entropy loss with teacher-provided labels  
- **L_KL**: KL-divergence between softened teacher and student logits

#### KL-Divergence Term

The KL term is computed with **temperature scaling** T to soften teacher logits:

L_KL = (1 / N_valid) * Σ_i mask_i * KL(softmax(z_teacher / T) || log_softmax(z_student / T)) * T²

- `mask_i` ensures the loss is computed **only on valid tokens**  
- T² rescales the gradient after temperature scaling  
- α (alpha) controls the relative weight of KL vs CE (e.g., 0.7)

#### Summary

- Cross-entropy encourages **fidelity to ground-truth tokens**  
- KL-divergence encourages the student to **mimic the teacher distribution**  
- Temperature softening allows learning from teacher's confidence distribution  
- Only valid (non-masked) tokens contribute to the loss

#### Reference

- Hinton et al., *Distilling the Knowledge in a Neural Network*, 2015  
  https://arxiv.org/abs/1503.02531


In [43]:
!rm -rf /tmp/torchinductor_root
import torch._dynamo
torch._dynamo.config.suppress_errors=True
torch.cuda.empty_cache()


In [44]:
# Instantiate the Student model (pretrained one) 
STUDENT_MODEL_PATH = "./PreTrainResult/pretrain_tallyformer.pth"
student_model, student_config, extra_info = Transformer.load(STUDENT_MODEL_PATH, device=DEVICE)

The tite done
Detected _orig_mod. prefixes → auto-cleaning state_dict...
Model loaded successfully!
   Epoch: 60 | Val PPL: 38.97 | Params: ~51M | Device: cuda


In [None]:
# The Distillation Trained on two phases

#Phase 1: Take tallyformer-pretrained - > tallyformer-distilled-phase1.pth
#{'temperature': 2.5, 'alpha': 0.7, 'teacher': 'gpt2-medium','epoch':19,'lr':2.5e-4} 

#Phase 2: Take tallyformer-destilled-phase1.pth -> tallyformer-distilled-phase2.pth 
#{'temperature': 3.0, 'alpha': 0.4, 'teacher': 'gpt2-medium','epoch':11,'lr':2e-4}

In [None]:
import math
import os 
import psutil
from rich.console import Console
from rich.table import Table
from contextlib import nullcontext
import torch
import torch.nn.functional as F
from torch.amp import autocast , GradScaler
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import GPT2LMHeadModel
from itertools import cycle
os.environ["TOKENIZERS_PARALLELISM"] = "false"
console=Console()
# GPU Monitoring
if torch.cuda.is_available():
    import pynvml
    pynvml.nvmlInit()
def get_gpu_usage(device_idx=0):
    try:
        handle = pynvml.nvmlDeviceGetHandleByIndex(device_idx)
        mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
        util = pynvml.nvmlDeviceGetUtilizationRates(handle)
        return util.gpu, mem_info.used / 1024**2
    except:
        return 0.0, 0.0

# Training Hyperparameter
N_EPOCHS=19
GRAD_ACCUM_STEPS=6
LEARNING_RATE=1.5e-4
WEIGHT_DECAY=0.01
WARMUP_EPOCHS=0.8
MAX_GRAD_NORM=1.0
PATIENCE=8

#  Distillation specification
TEMPERATURE=2.5 # softens teacher loguts
ALPHA=0.7 # 70%KL + 30% CE
T2=TEMPERATURE**2 # Precompute T**2
console.print(f"[bold cyan] KNOWLEDGE DISTILLATION SETUP[/bold cyan]")
console.print(f"  Temperature: {TEMPERATURE}")
console.print(f"  Alpha (KL): {ALPHA}")
console.print(f"  LR: {LEARNING_RATE}")
console.print(f"  Epochs: {N_EPOCHS}")
console.print(f"  Batch: {BATCH_SIZE} × {GRAD_ACCUM_STEPS} = {BATCH_SIZE * GRAD_ACCUM_STEPS}")

# Load Teacher Model
console.print(f"[bold blue]Loading GPT2-Medium Teacher...[/bold blue]")
teacher_model=GPT2LMHeadModel.from_pretrained(
    "gpt2-medium",
    dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
)
# Resize Teacher Embedding to match our tokenizer (50260)
vocab_size=len(tokenizer)
teacher_model.resize_token_embeddings(vocab_size)
teacher_model.eval()
# Freeze teacher
for p in teacher_model.parameters():
    p.requires_grad=False
console.print(f"[bold green]✓ Teacher loaded & resized to {vocab_size} tokens[/bold green]")

# Device mixed precision
DEVICE='cuda' if torch.cuda.is_available() else 'cpu'
DTYPE='bfloat16' if torch.cuda.is_bf16_supported() else 'float16'
ptdtype=torch.bfloat16 if DTYPE=='bfloat16' else torch.float16
teacher_model = teacher_model.to(DEVICE)
if torch.cuda.is_available():
    if torch.cuda.is_bf16_supported():
        ptdtype=torch.bfloat16
        use_scaler=False
    else:
        ptdtype=torch.float16
        use_scaler=True
else:
    ptdtype=torch.float32
    use_scaler=False
ctx=autocast(device_type='cuda',dtype=ptdtype) if DEVICE=='cuda' else nullcontext()
scaler=GradScaler(enabled=use_scaler)
console.print(f"Use Scaler? {use_scaler} | Dtype: {ptdtype}")

# Data Parallel
if torch.cuda.is_available():
    num_gpus = torch.cuda.device_count()
    if num_gpus > 1:
        console.print(f"[bold green]Using {num_gpus} GPUs with DataParallel[/bold green]")
        
       
        student_model = student_model.to('cuda:0')
        teacher_model = teacher_model.to('cuda:0')
        
        student_model = torch.nn.DataParallel(
            student_model, 
            device_ids=list(range(num_gpus)),
            output_device='cuda:0'
        )
        console.print(f"[bold green] DataParallel initialized correctly[/bold green]")
    else:
        console.print("[bold green]Using single GPU[/bold green]")
        student_model = student_model.to(DEVICE)
        teacher_model = teacher_model.to(DEVICE)
student_model = torch.compile(student_model)
# Optimizer and scheduler
steps_per_epoch=len(train_loader)//GRAD_ACCUM_STEPS
MAX_ITERS=steps_per_epoch*N_EPOCHS
console.print(f"[bold cyan]Training for {N_EPOCHS} epochs → {MAX_ITERS:,} steps "
              f"({steps_per_epoch:,} steps/epoch)[/bold cyan]")
console.print(f"[bold green]Effective batch size: {BATCH_SIZE * GRAD_ACCUM_STEPS}[/bold green]")
infinite_loader=cycle(train_loader)
optimizer=torch.optim.AdamW(student_model.parameters(),lr=LEARNING_RATE,weight_decay=WEIGHT_DECAY,betas=(0.9,0.95))

# OneCycleLR (optimal for distillation)
total_opt_steps=MAX_ITERS//GRAD_ACCUM_STEPS
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=LEARNING_RATE,
    total_steps=total_opt_steps,
    pct_start=0.1,        #  5% warmup (equivalent to 0.8 epochs)
    anneal_strategy='cos',  #  Cosine annealing
    div_factor=20.0,       #  Initial LR = max_lr/25
    final_div_factor=1e4   #  Final LR = max_lr/10000
)
# Define helper function to compute distillation loss 
def compute_distillation_loss(student_logits,teacher_logits,labels,attention_mask):
    B,T,V=student_logits.shape
    # CE loss 
    ce_loss=F.cross_entropy(
        student_logits.reshape(-1,V),
            labels.reshape(-1),
        ignore_index=-100
    )
    # KL Divergence loss 
    student_log_softmax=F.log_softmax(student_logits/TEMPERATURE,dim=-1)
    teacher_softmax=F.softmax(teacher_logits/TEMPERATURE,dim=-1)
    kl_loss=F.kl_div(
        student_log_softmax,teacher_softmax,reduction='none',log_target=False
    ).sum(dim=-1) # BxT
    # mask the kl to the real tokens only (on T axis)
    kl_loss=(kl_loss*attention_mask).sum()/attention_mask.sum()
    kl_loss=kl_loss*T2
    #total loss
    total_loss=ALPHA*kl_loss +(1-ALPHA)*ce_loss
    return {
        'total_loss':total_loss,
        'ce_loss':ce_loss,
        'kl_loss':kl_loss
    }

# Matrics saving
metrics = {
    "train_total_loss": [], "train_ce": [], "train_kl": [],
    "val_ce": [], "val_kl": [], "val_perplexity": [],
    "cpu_usage": [], "gpu_usage": [], "gpu_mem": []
}
OUT_DIR='./DistillationResult'
os.makedirs(OUT_DIR,exist_ok=True)
BEST_MODEL_PATH=os.path.join(OUT_DIR,'tallyformer-distilled.pth')

best_val_ppl=float('inf')
patience_counter=0

# Training loop
lr=LEARNING_RATE
pbar=tqdm(total=MAX_ITERS,desc='Distillation',dynamic_ncols=True, bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]")
student_model.train()
teacher_model.eval()
iter_num=0
epoch_train_losses = {'total': [], 'ce': [], 'kl': []}
epoch_cpu_samples = []       
epoch_gpu_samples = []     
epoch_gpu_mem_samples = []   
MONITOR_POINTS=50
monitor_interval=max(1,steps_per_epoch//MONITOR_POINTS)

for iter_num in range(MAX_ITERS):
    xb,yb,attn_mask=next(infinite_loader)
    xb=xb.to(DEVICE,non_blocking=True)
    yb=yb.to(DEVICE,non_blocking=True)
    attn_mask=attn_mask.to(DEVICE,non_blocking=True)
    with ctx:
        # student forward
        student_logits,_,_=student_model(input_ids=xb,labels=None,attention_mask=attn_mask)
        #teacher forward
        with torch.no_grad():
            teacher_outputs=teacher_model(input_ids=xb,attention_mask=attn_mask)
            teacher_logits=teacher_outputs.logits
        loss_dict=compute_distillation_loss(student_logits,teacher_logits,yb,attn_mask)
        loss=loss_dict['total_loss']/GRAD_ACCUM_STEPS
    if use_scaler:
        scaler.scale(loss).backward()
    else:
        loss.backward()
    epoch_train_losses['total'].append(loss.item()*GRAD_ACCUM_STEPS)
    epoch_train_losses['ce'].append(loss_dict['ce_loss'].item())
    epoch_train_losses['kl'].append(loss_dict['kl_loss'].item())
  
    if iter_num % monitor_interval==0 :
        epoch_cpu_samples.append(psutil.cpu_percent(interval=None))
        if DEVICE=='cuda':
            gpu_u,gpu_m=get_gpu_usage()
            epoch_gpu_samples.append(gpu_u)
            epoch_gpu_mem_samples.append(gpu_m)

    # gradient accumulation step
    if (iter_num+1)%GRAD_ACCUM_STEPS==0 :
        if use_scaler:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(student_model.parameters(),MAX_GRAD_NORM)
            scaler.step(optimizer)
            scaler.update()
        else:
            torch.nn.utils.clip_grad_norm_(student_model.parameters(),MAX_GRAD_NORM)
            optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        scheduler.step()
        lr=scheduler.get_last_lr()[0]
    # live tqd, update
    current_epoch=(iter_num+1)/steps_per_epoch
    pbar.set_postfix({
        "Epoch": f"{current_epoch:.1f}/{N_EPOCHS}",
        "Loss": f"{loss.item() * GRAD_ACCUM_STEPS:.4f}",
        "CE": f"{loss_dict['ce_loss'].item():.3f}",
        "KL": f"{loss_dict['kl_loss'].item():.3f}",
        "LR": f"{lr:.1e}",
        "GPU": f"{epoch_gpu_samples[-1]:.0f}%" if epoch_gpu_samples and len(epoch_gpu_samples) > 0 else "N/A",
    })
    pbar.update(1)

    # check end of the epoch
    if (iter_num+1)%steps_per_epoch==0:
        avg_train_total = sum(epoch_train_losses['total']) / len(epoch_train_losses['total'])
        avg_train_ce = sum(epoch_train_losses['ce']) / len(epoch_train_losses['ce'])
        avg_train_kl = sum(epoch_train_losses['kl']) / len(epoch_train_losses['kl'])
        
        if len(epoch_cpu_samples) > 0:
            avg_cpu = sum(epoch_cpu_samples) / len(epoch_cpu_samples)
            avg_gpu = sum(epoch_gpu_samples) / len(epoch_gpu_samples) if epoch_gpu_samples else 0
            avg_gpu_mem = sum(epoch_gpu_mem_samples) / len(epoch_gpu_mem_samples) if epoch_gpu_mem_samples else 0
        else:
            avg_cpu, avg_gpu, avg_gpu_mem = 0, 0, 0
        
        # validation
        student_model.eval()
        teacher_model.eval()
        val_total,val_ce,val_kl,val_tokens=0.0,0.0,0.0,0
        val_pbar=tqdm(val_loader, desc=f"Validating (Epoch {int(current_epoch)})",
                leave=False, dynamic_ncols=True)
        with torch.no_grad():
            for xb_v, yb_v, mask_v in val_pbar:
                xb_v, yb_v, mask_v = xb_v.to(DEVICE), yb_v.to(DEVICE), mask_v.to(DEVICE)
                
                with ctx:
                    student_logits, _, _ = student_model(
                        input_ids=xb_v, labels=None, attention_mask=mask_v
                    )
                    
                    teacher_outputs = teacher_model(input_ids=xb_v, attention_mask=mask_v)
                    teacher_logits = teacher_outputs.logits
                    
                    loss_dict = compute_distillation_loss(student_logits, teacher_logits, yb_v, mask_v)
                
                valid_tokens = int((yb_v != -100).sum().item())
                if valid_tokens > 0:
                    val_total += loss_dict['total_loss'].item() * valid_tokens
                    val_ce +=  loss_dict['ce_loss'].item() * valid_tokens
                    val_kl += loss_dict['kl_loss'].item() * valid_tokens
                    val_tokens += valid_tokens
        avg_val_total = val_total / val_tokens if val_tokens > 0 else 0
        avg_val_ce = val_ce / val_tokens if val_tokens > 0 else 0
        avg_val_kl = val_kl / val_tokens if val_tokens > 0 else 0
        val_ppl = math.exp(avg_val_ce)

        # save te metrix
        metrics["train_total_loss"].append(avg_train_total)
        metrics["train_ce"].append(avg_train_ce)
        metrics["train_kl"].append(avg_train_kl)
        metrics["val_ce"].append(avg_val_ce)
        metrics["val_kl"].append(avg_val_kl)
        metrics["val_perplexity"].append(val_ppl)
        metrics["cpu_usage"].append(avg_cpu)
        metrics["gpu_usage"].append(avg_gpu)
        metrics["gpu_mem"].append(avg_gpu_mem)
        
        console.print(f"\n[bold green]EPOCH {int(current_epoch)}/{N_EPOCHS} COMPLETE[/bold green]")
        console.print(f" Train: Total={avg_train_total:.4f} | CE={avg_train_ce:.4f} | KL={avg_train_kl:.4f}")
        console.print(f" Valid: Total={avg_val_total:.4f} | CE={avg_val_ce:.4f} | KL={avg_val_kl:.4f} | PPL={val_ppl:.2f}")
        console.print(f" CPU: {avg_cpu:.1f}% | GPU: {avg_gpu:.1f}% | GPU Mem: {avg_gpu_mem:.0f} MB | LR: {lr:.1e}")
        
        # Early stopping & save best
        if val_ppl < best_val_ppl:
            best_val_ppl = val_ppl
            patience_counter = 0
            
            save_payload = {
                "epoch": int(current_epoch),
                "val_ppl": val_ppl,
                "distillation": {
                    "temperature": TEMPERATURE,
                    "alpha": ALPHA,
                    "teacher": "gpt2-medium"
                },
                "metrics": metrics
            }
            
            Transformer.save(
                student_model,
                BEST_MODEL_PATH,
                config=student_config,  
                extra_dict=save_payload
            )
            
            console.print(f"[bold yellow] NEW BEST MODEL SAVED! Val PPL: {val_ppl:.2f}[/bold yellow]")
        else:
            patience_counter += 1
            console.print(f"[bold red]No improvement ({patience_counter}/{PATIENCE})[/bold red]")
        
        if patience_counter >= PATIENCE:
            console.print(f"[bold yellow]Early stopping at epoch {int(current_epoch)}[/bold yellow]")
            break
        
        # reset for next epoch 
        student_model.train()
        epoch_train_losses = {'total': [], 'ce': [], 'kl': []}
        epoch_cpu_samples = []
        epoch_gpu_samples = []
        epoch_gpu_mem_samples = []
        
        train_loader = DataLoader(
            train_dataset,
            batch_size=BATCH_SIZE,
            shuffle=False,
            sampler=RandomSampler(train_dataset),  #  Reshuffle each epoch
            num_workers=0,
            pin_memory=False,
            persistent_workers=False
        )
        infinite_loader = cycle(train_loader)

pbar.close()

console.print(f"\n[bold green] DISTILLATION FINISHED![/bold green]")
console.print(f"Best Val PPL: {best_val_ppl:.2f}")
console.print(f"Improvement: {((39.0 - best_val_ppl) / 39.0 * 100):.1f}%")
console.print(f"Model saved: {BEST_MODEL_PATH}")
        
    

In [50]:
# Check the final perplexity , here I use the same validation as pretraining stage ,please return back to the "Load Pretrain Dataset" and "Define DataLoaders" 
model_path="./DistillationResult/tallyformer-distilled-phase2.pth"
model,model_config,metrics=Transformer.load(model_path,device='cuda')
student_model.perplexity(val_loader,'cuda')

The tite done
Detected _orig_mod. prefixes → auto-cleaning state_dict...
Model loaded successfully!
   Epoch: 11 | Val PPL: 39.59 | Params: ~51M | Device: cuda


Computing Perplexity:   0%|          | 0/21747 [00:00<?, ?it/s]

38.562077314174296

## Inference Stage

### Generate Text

In [None]:
inference_config=InferenceConfig(max_new_tokens=100, temperature=0.7, topk=500, topp=0.9, frequency_penalty=0.0, presence_penalty=0.0)
prompt = "One day in the forest ,there was "
model_path="./DistillationResult/tallyformer-distilled-phase2.pth"
model,model_config,metrics=Transformer.load(model_path,device='cuda')

generated = generate_text(model, tokenizer, prompt, inference_config)
print("Generated:", generated)

The tite done
Detected _orig_mod. prefixes → auto-cleaning state_dict...
Model loaded successfully!
   Epoch: 11 | Val PPL: 39.59 | Params: ~51M | Device: cuda


Generated:  a little girl dies from a lightning accident . She is forcedipient by the two boys and is in a coma . They are separated by two different animals . They are named after the girl who is called `` Muffin '' . The boy is now in a Mour lettau and the girl is still in her bed neurotransmitter . She is now living in a night bedroom . She is also in a dream room and she pinkie . The girl is going cyan . She is in a room with a


# SFT
---

## Load SFT Dataset

In [None]:
from datasets import load_dataset
import pandas as pd
# load the sft dataset from Hugging Face Hub
dataset = load_dataset(
    "haidar-ali/tallyformer-finance-dataset",
    data_dir="Data/SFTData",  
    split="train"
)
df = pd.DataFrame(dataset)
sft_df = df.sample(frac=1, random_state=42).reset_index(drop=True)

In [13]:
sft_df.shape

(28074, 2)

In [16]:
sft_df.head()

Unnamed: 0,prompt,response
0,Why does it matter if a Central Bank has a neg...,"That is kind of the point, one of the hopes is..."
1,Where should I be investing my money?,"Pay off your debt. As you witnessed, no ""inves..."
2,Approximation of equity value for company in d...,"Generally ""default"" means that the company can..."
3,Can a company charge you for services never re...,"In general, you can only be charged for servic..."
4,Working out if I should be registered as self-...,Being self employed just means you fill out so...


## Create SFTDataset Class

In [14]:

class SFTDataset(Dataset):
    def __init__(self, data, tokenizer, context_length):
        self.data = data
        self.tokenizer = tokenizer
        self.context_length = context_length
        self.ignore_index = -100
        self.user_token = "<|user|>"
        self.assistant_token = "<|assistant|>"

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if isinstance(self.data,pd.DataFrame):
            item = self.data.iloc[idx]
        else:
            item = self.data[idx]
        user_msg = item['prompt']
        assistant_msg = item['response']

        prompt = f"{self.user_token} {user_msg} {self.assistant_token}"
        full_text = f"{prompt} {assistant_msg}"

        encoded_full = self.tokenizer(
            full_text,
            truncation=True,
            max_length=self.context_length,
            padding="max_length",
            return_tensors='pt'
        )
        input_ids = encoded_full['input_ids'].squeeze(0)
        attn_mask = encoded_full['attention_mask'].squeeze(0)

        # Compute prompt length
        encoded_prompt = self.tokenizer(prompt, truncation=True, padding=False, return_tensors='pt')
        prompt_len = encoded_prompt['input_ids'].shape[1]

        xb = input_ids[:-1]
        yb = input_ids[1:].clone()
        attn_mask = attn_mask[:-1]
        # mask prompt + padding
        yb[:prompt_len] = self.ignore_index
        yb[attn_mask==0]= self.ignore_index

        return xb, yb, attn_mask

## Training Stage
### Supervised Fine-Tuning (SFT) Loss

During SFT, TallyFormer-Finance-51M is trained on **finance instruction-response pairs** using a **masked cross-entropy loss** with optional **z-loss regularization**.

#### Total Loss

The total loss applied is:

$$
L_{total} = L_{CE} + \lambda_{z} \cdot L_{z}
$$

Where:

- **L_CE**: Cross-entropy loss computed only on **assistant response tokens**, ignoring prompt and padding tokens:

$$
L_{CE} = \frac{1}{N_{valid}} \sum_{i \in \text{valid}} \text{CE}(y_i, \hat{y}_i)
$$

- **L_z**: z-loss regularization term :

$$
L_z = \frac{1}{N_{valid}} \sum_{i \in \text{valid}} \big(\log \sum_j e^{\hat{y}_{ij}}\big)^2
$$

- \( \lambda_{z} \) = `zloss_coeff`, the weight for z-loss  

- `valid` tokens are those corresponding to the **assistant response**, ignoring the prompt and padding

#### Summary

- **Prompt tokens are masked** and do not contribute to the loss  
- **z-loss** prevents logits from growing too large and stabilizes training  
- Loss is computed **per token** and averaged over the batch  


In [24]:
!rm -rf /tmp/torchinductor_root
import torch._dynamo
torch._dynamo.config.suppress_errors=True
torch.cuda.empty_cache()

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [26]:
# SFT Training with PEFT LoRA
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.amp import autocast , GradScaler
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm
import os
from rich.console import Console
import psutil
import math
from contextlib import nullcontext

# PEFT
from peft import LoraConfig, get_peft_model

if torch.cuda.is_available():
    import pynvml
    pynvml.nvmlInit()

console = Console()

# Device mixed precision
DEVICE='cuda' if torch.cuda.is_available() else 'cpu'
DTYPE='bfloat16' if torch.cuda.is_bf16_supported() else 'float16'
ptdtype=torch.bfloat16 if DTYPE=='bfloat16' else torch.float16
if torch.cuda.is_available():
    if torch.cuda.is_bf16_supported():
        ptdtype=torch.bfloat16
        use_scaler=False
    else:
        ptdtype=torch.float16
        use_scaler=True
else:
    ptdtype=torch.float32
    use_scaler=False
ctx=autocast(device_type='cuda',dtype=ptdtype) if DEVICE=='cuda' else nullcontext()
scaler=GradScaler(enabled=use_scaler)
console.print(f"Use Scaler? {use_scaler} | Dtype: {ptdtype}")

# Previous model path
NOTEBOOK_OUT_DIR = "./DistillationResult"
DISTILLED_MODEL_PATH = os.path.join(NOTEBOOK_OUT_DIR, "tallyformer-distilled-phase2.pth")
model, model_config, _ = Transformer.load(DISTILLED_MODEL_PATH, device='cpu')
model.to(DEVICE)


# LoRA adapter (only saved during training)
LORA_OUTPUT_DIR = "./SFT_LoRA"
LORA_ADAPTER_PATH = os.path.join(LORA_OUTPUT_DIR, "tallyformer-finance-51m-lora")
os.makedirs(LORA_OUTPUT_DIR, exist_ok=True)

# FINAL merged model (this is what you upload to HF)
FINAL_MODEL_DIR = "TallyFormer-Finance-51M"
os.makedirs(FINAL_MODEL_DIR, exist_ok=True)

# LoRA config
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules= ["query", "key", "value", "proj", "fc1", "fc2"],  # matches the Linear names
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    modules_to_save = ["lm_head"]
)

# Apply LoRA
model = get_peft_model(model, lora_config)

# Print trainable parameters
model.print_trainable_parameters() #all-params are the base model and the lora weights

BATCH_SIZE=32
CONTEXT_LENGTH = 256   

train_size = int(0.9 * len(sft_df))
train_sft_data = sft_df.iloc[:train_size]
val_sft_data = sft_df.iloc[train_size:]

train_dataset = SFTDataset(train_sft_data, tokenizer, CONTEXT_LENGTH)
val_dataset = SFTDataset(val_sft_data, tokenizer, CONTEXT_LENGTH)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Define optimizer and scheduler
LEARNING_RATE = 2e-4
N_EPOCHS = 3
ACCUM_STEPS = 8
WARMUP_STEPS = 100
MAX_GRAD_NORM = 1.0
TOTAL_STEPS = len(train_loader) * N_EPOCHS

optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01, betas=(0.9, 0.95), eps=1e-8)
scheduler = CosineAnnealingLR(
    optimizer,
    T_max=TOTAL_STEPS // ACCUM_STEPS,
    eta_min=LEARNING_RATE * 0.1
)
# Trining Loop
metrics = {"train_loss": [], "val_perplexity": [], "cpu_usage": [], "gpu_usage": [], "gpu_mem": []}
best_val_ppl = float('inf')
best_epoch = None
patience = 3
counter = 0
global_step = 0
update_step = 0

console.print(f"[bold cyan]Starting SFT with PEFT LoRA for {N_EPOCHS} epochs...[/bold cyan]")

for epoch in range(N_EPOCHS):
    model.train()
    total_loss = 0.0
    cpu_list, gpu_list, mem_list = [], [], []

    optimizer.zero_grad(set_to_none=True)

    for step, (xb, yb, attn_mask) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{N_EPOCHS}")):
        global_step += 1
        xb, yb, attn_mask = xb.to(DEVICE), yb.to(DEVICE), attn_mask.to(DEVICE)

        with ctx:
            logits, loss, _ = model(input_ids=xb, labels=yb, attention_mask=attn_mask)

        if loss is None or torch.isnan(loss):
            continue

        loss = loss / ACCUM_STEPS
        if use_scaler:
            scaler.scale(loss).backward()
        else:
            loss.backward()
        if (step + 1) % ACCUM_STEPS == 0 or (step + 1) == len(train_loader):
            if use_scaler:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
                scaler.step(optimizer)
                scaler.update()
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
                optimizer.step()
            optimizer.zero_grad(set_to_none=True)

            update_step += 1
            scheduler.step()
            lr=scheduler.get_last_lr()[0]

        total_loss += loss.item() * ACCUM_STEPS

        cpu_list.append(psutil.cpu_percent())
        if DEVICE == 'cuda':
            handle = pynvml.nvmlDeviceGetHandleByIndex(0)
            util = pynvml.nvmlDeviceGetUtilizationRates(handle)
            mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
            gpu_list.append(util.gpu)
            mem_list.append(mem_info.used / 1024**2)

    # Handle final accumulation
    if len(train_loader) % ACCUM_STEPS != 0:
        if use_scaler:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
            scaler.step(optimizer)
            scaler.update()
        else:
            torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
            optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        update_step += 1
        scheduler.step()
        lr=scheduler.get_last_lr()[0]

    avg_loss = total_loss / len(train_loader)
    val_ppl = model.perplexity(val_loader, DEVICE)

    avg_cpu = sum(cpu_list) / len(cpu_list)
    avg_gpu = sum(gpu_list) / len(gpu_list) if gpu_list else 0.0
    avg_mem = sum(mem_list) / len(mem_list) if mem_list else 0.0

    metrics['train_loss'].append(avg_loss)
    metrics['val_perplexity'].append(val_ppl)
    metrics['cpu_usage'].append(avg_cpu)
    metrics['gpu_usage'].append(avg_gpu)
    metrics['gpu_mem'].append(avg_mem)

    console.print(f"[green]Epoch {epoch+1}/{N_EPOCHS} | Loss: {avg_loss:.4f} | Val PPL: {val_ppl:.2f} | CPU: {avg_cpu:.1f}% | GPU: {avg_gpu:.1f}% | Mem: {avg_mem:.0f} MB[/green]")

    if val_ppl < best_val_ppl:
        best_val_ppl = val_ppl
        best_epoch = epoch + 1
        counter = 0
        #save the lora adapter
        model.save_pretrained(LORA_ADAPTER_PATH)
        tokenizer.save_pretrained(LORA_ADAPTER_PATH)
    
        console.print(f"[bold yellow]New best model saved! PPL: {val_ppl:.2f}[/bold yellow]")
    else:
        counter += 1
        console.print(f"[bold red]No improvement for {counter} epoch(s)[/bold red]")

    if counter >= patience:
        console.print(f"[bold yellow]Early stopping at epoch {epoch+1}[/bold yellow]")
        break

console.print(f"[bold green]SFT finished! Best epoch: {best_epoch}, Best PPL: {best_val_ppl:.2f}[/bold green]")


# To get a merged model and save cleaned model
from peft import PeftModel
console.print("[bold blue]Merging LoRA weights into final model...[/bold blue]")
base_model, model_config, _ = Transformer.load(
    DISTILLED_MODEL_PATH,      # the distilled-phase2
    device="cpu"
)
merged_model = PeftModel.from_pretrained(base_model, LORA_ADAPTER_PATH)
merged_model = merged_model.merge_and_unload()
Transformer.save(merged_model, os.path.join(FINAL_MODEL_DIR, "model.pth"), config=model_config)
tokenizer.save_pretrained(FINAL_MODEL_DIR)
console.print(f"[bold green]SUCCESS: TallyFormer-Finance-51M is ready![/bold green]")
console.print(f"[bold green]Final model: ./{FINAL_MODEL_DIR}/[/bold green]")
console.print(f"[bold magenta]Upload this folder to Hugging Face -> you're done![/bold magenta]")

The tite done
Detected _orig_mod. prefixes → auto-cleaning state_dict...
Model loaded successfully!
   Epoch: 11 | Val PPL: 39.59 | Params: ~51M | Device: cpu
The tite done
trainable params: 26,585,088 || all params: 77,500,928 || trainable%: 34.3029


Epoch 1/3: 100%|██████████| 790/790 [02:44<00:00,  4.80it/s]
Computing Perplexity: 100%|██████████| 88/88 [00:10<00:00,  8.46it/s]


Epoch 2/3: 100%|██████████| 790/790 [02:44<00:00,  4.79it/s]
Computing Perplexity: 100%|██████████| 88/88 [00:10<00:00,  8.52it/s]


Epoch 3/3: 100%|██████████| 790/790 [02:44<00:00,  4.79it/s]
Computing Perplexity: 100%|██████████| 88/88 [00:10<00:00,  8.48it/s]


The tite done
Detected _orig_mod. prefixes → auto-cleaning state_dict...
Model loaded successfully!
   Epoch: 11 | Val PPL: 39.59 | Params: ~51M | Device: cpu
Model saved to TallyFormer-Finance-51M/model.pth


## Inference Stage

In [97]:
inference_config = InferenceConfig(
    max_new_tokens=50,
    temperature=0.7,
    topk=500,
    topp=0.9,
    frequency_penalty=1.0,
    presence_penalty=0.0,
    return_only_generated=True
)


prompt = "one way to make money is "
pretrain_path='./PreTrainResult/pretrain_tallyformer.pth'
distilled_path='./DistillationResult/tallyformer-distilled-phase2.pth'
merged_model_path ="./TallyFormer-Finance-51M/model.pth"
use=3 # [1,2,3]
model_path=[pretrain_path,distilled_path,merged_model_path][use-1] if use in {1,2,3} else merged_model_path
print(model_path)
# Load merged model
model, model_config, _ = Transformer.load(model_path, device='cuda')

generated = generate_text(model, tokenizer, prompt, inference_config,is_sft=use==3)
print("Generated:", generated)


./TallyFormer-Finance-51M/model.pth
The tite done
Model loaded successfully!
   Epoch: ? | Val PPL: ? | Params: ~51M | Device: cuda


Generated: . The goal of this project hog from the beginning isAtlas a campaign to make money through donation. This campaign starts with creating a community of artists and music producers to support their efforts. Through this community 2024, artists can create unique music and come
