# Knowledge Distillation for Math Reasoning
Teacher: Qwen2.5-1.5B-Instruct

Student: Qwen2Moe


In [1]:
! nvidia-smi

Sat May  3 16:21:41 2025       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.183.01             Driver Version: 535.183.01   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A100-SXM4-80GB          On  | 00000000:07:00.0 Off |                    0 |
| N/A   26C    P0              59W / 400W |      0MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [2]:
import re
import os
import wandb
from datetime import datetime
from typing import Optional, Union
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, Callback
from torch.utils.data import DataLoader

from datasets import Dataset
from datasets import load_dataset

from transformers import AutoTokenizer, AutoModelForCausalLM
from vllm import LLM, SamplingParams

INFO 05-03 16:21:54 [__init__.py:239] Automatically detected platform cuda.


2025-05-03 16:21:54.940481: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1746303714.960589 4075707 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1746303714.966797 4075707 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1746303714.983750 4075707 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1746303714.983768 4075707 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1746303714.983770 4075707 computation_placer.cc:177] computation placer alr

### Setup

In [3]:
from dotenv import load_dotenv
load_dotenv()  # Load environment variables from .env file

True

In [4]:
log_dir = "/network/rit/lab/wang_lab_cs/ptian/logs"

In [5]:
save_dir = "/network/rit/lab/wang_lab_cs/ptian/ckpt/distillation"  # TODO: serialization path, change to evaluate
if not os.path.exists(save_dir):
    os.makedirs(save_dir)

In [6]:
# # HuggingFace login
# import huggingface_hub
# huggingface_hub.login()

In [7]:
# # Weights & Bias login
# import wandb
# wandb.login()

In [8]:
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
# os.environ["VLLM_USE_V1"] = "0"  # For A100 GPU

In [9]:
student_id, teacher_id = "Qwen/Qwen2.5-0.5B-Instruct", "Qwen/Qwen2.5-1.5B-Instruct"

In [10]:
class FewShotEvaluator:
    """
    Few-shot evaluator for math reasoning tasks.
    """
    def __init__(self, dataset: Dataset, n_shots: int = 3, device: str = "cuda", batch_size: int = 16) -> None:
        self.dataset = dataset
        self.n_shots = n_shots
        self.device = device
        self.batch_size = batch_size
        self.fewshot_prompt = self.get_fewshot_prompt()

    def get_fewshot_prompt(self) -> str:
        prompt = "Solve these math problems:\n\n"
        for i in range(self.n_shots):
            example = self.dataset[i]
            prompt += f"Question: {example['question']}\nAnswer: {example['answer']}" + "\n\n"
        return prompt

    def preprocess_eval(self, examples: dict) -> dict:
        # Preprocess the example to include the few-shot prompt
        return {
            "prompt": [self.fewshot_prompt + f"Question: {question}\nAnswer:\n" for question in examples["question"]]
        }

    def parse_answer(self, answer: str) -> Optional[str]:
        # Extract the answer from the generated text
        try:
            predicted_answer = re.search(r"#### (-?\d+\.?\d*)", answer).group(1)
        except:
            predicted_answer = None
        return predicted_answer

    # def eval(self, model_path: str, tokenizer: AutoTokenizer, device: str = "cuda", temperature: float = 0.7, top_p: float = 0.95, max_tokens: int = 256) -> float:
    def eval(self, model_path: str, dtype: str = "auto", device: str = "cuda", temperature: float = 0.7, top_p: float = 0.95, max_tokens: int = 256) -> float:
        """
        Evaluate exact match accuracy
        """
        # Load dataset
        eval_dataset = self.dataset.select(range(self.n_shots, len(self.dataset)))
        eval_dataset = eval_dataset.map(self.preprocess_eval, batched=True)
        eval_dataloader = DataLoader(eval_dataset, batch_size=self.batch_size, shuffle=False)

        # Load model
        llm = LLM(model=model_path, dtype=dtype)
        # Shared or individual sampling settings
        sampling_params = SamplingParams(temperature=temperature, top_p=top_p, max_tokens=max_tokens)

        correct = 0
        num_questions = 0

        answers = []

        # batch inference
        for _, batch in tqdm(enumerate(eval_dataloader), desc="Eval Inference: ", total=len(eval_dataloader)):
            # inputs = tokenizer(batch["prompt"], return_tensors="pt", max_length=256, padding="max_length", truncation=True).to(device)
            # outputs = model.generate(**inputs, max_new_tokens=256)
            # batch_answers = tokenizer.batch_decode(outputs, skip_special_tokens=True)
            # answers.extend(batch_answers)
            prompts = batch["prompt"]
            outputs = llm.generate(prompts, sampling_params)
            batch_answers = [output.outputs[0].text.strip() for output in outputs]
            answers.extend(batch_answers)
            torch.cuda.empty_cache()

        # text parse for exact match
        for i, (correct_answer, generated_answer) in tqdm(enumerate(zip(eval_dataset['answer'], answers)), desc="Evaluating Exact Match Accuracy: ", total=len(eval_dataset)):
            # # Remove the input tokens from the output for transformers inference
            # generated_answer = generated_answer[len(eval_dataset['prompt'][i]):]

            # Extract final answer
            predicted_answer = self.parse_answer(generated_answer)
            ground_truth = self.parse_answer(correct_answer)

            # Check if the predicted answer matches the ground truth
            if ground_truth:
                num_questions += 1
                if predicted_answer and predicted_answer == ground_truth:
                    correct += 1

        return correct / num_questions if num_questions > 0 else 0

### Pre-Train Evaluation

In [11]:
eval_ds = load_dataset("openai/gsm8k", "main", split="test", num_proc=4)
evaluator = FewShotEvaluator(eval_ds, n_shots=3, device="cuda", batch_size=128)

In [12]:
# # pretrain evaluation
# qem1 = evaluator.eval(teacher_id, device="cuda")
# qem1

In [13]:
# # pretrain evaluation
# qem1 = evaluator.eval(student_id, device="cuda")
# qem1

## Student Model Training

In [14]:
from transformers import AutoConfig
from transformers import Qwen2MoeConfig, Qwen2MoeForCausalLM

In [15]:
# default_config = AutoConfig.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
# default_config

In [16]:
config = Qwen2MoeConfig(
    architectures=["Qwen2MoeForCausalLM"],  # for MoE
    model_type="qwen2_moe",

    hidden_size=896,
    intermediate_size=4864,
    num_hidden_layers=3,
    num_attention_heads=14,
    num_key_value_heads=2,
    hidden_act="silu",

    max_position_embeddings=32768,
    max_window_layers=21,
    sliding_window=4096,
    use_sliding_window=False,

    rope_theta=1_000_000.0,
    rope_scaling=None,

    attention_dropout=0.0,
    rms_norm_eps=1e-6,
    initializer_range=0.02,

    vocab_size=151936,
    bos_token_id=151643,
    eos_token_id=151645,
    tie_word_embeddings=True,

    torch_dtype="bfloat16",
    use_cache=True,
    # transformers_version="4.51.3",

    # 🧠 MoE-specific configs:
    num_local_experts=4,           # Total experts per MoE layer
    num_experts_per_tok=2,         # How many experts each token is routed to
    moe_layer_freq=2,              # Insert MoE layer every N transformer layers
    output_router_logits=False     # Often False during inference
)

In [17]:
student_model = Qwen2MoeForCausalLM(config)

In [18]:
# Load models and tokenizer
tokenizer = AutoTokenizer.from_pretrained(teacher_id)
tokenizer.pad_token = tokenizer.eos_token
teacher_model = AutoModelForCausalLM.from_pretrained(teacher_id, device_map="auto")
# student_model = AutoModelForCausalLM.from_pretrained(student_id, device_map="auto")

Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.


In [19]:
teacher_model.config.vocab_size, student_model.config.vocab_size

(151936, 151936)

In [20]:
if teacher_model.config.vocab_size != student_model.config.vocab_size:
    student_model.resize_token_embeddings(teacher_model.config.vocab_size)
# subustitute last layer of student model with that of teacher model
# student_model.lm_head = nn.Linear(student_model.config.hidden_size, teacher_model.lm_head.weight.size(0), bias=False)

In [21]:
teacher_model.eval()

Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 1536)
    (layers): ModuleList(
      (0-27): 28 x Qwen2DecoderLayer(
        (self_attn): Qwen2Attention(
          (q_proj): Linear(in_features=1536, out_features=1536, bias=True)
          (k_proj): Linear(in_features=1536, out_features=256, bias=True)
          (v_proj): Linear(in_features=1536, out_features=256, bias=True)
          (o_proj): Linear(in_features=1536, out_features=1536, bias=False)
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=1536, out_features=8960, bias=False)
          (up_proj): Linear(in_features=1536, out_features=8960, bias=False)
          (down_proj): Linear(in_features=8960, out_features=1536, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)
      )
    )
    (norm): Qwen2RMSNorm((1536,), eps=1e-06)
    (rotary_emb): Qw

In [22]:
student_model.eval()

Qwen2MoeForCausalLM(
  (model): Qwen2MoeModel(
    (embed_tokens): Embedding(151936, 896)
    (layers): ModuleList(
      (0-2): 3 x Qwen2MoeDecoderLayer(
        (self_attn): Qwen2MoeSdpaAttention(
          (q_proj): Linear(in_features=896, out_features=896, bias=True)
          (k_proj): Linear(in_features=896, out_features=128, bias=True)
          (v_proj): Linear(in_features=896, out_features=128, bias=True)
          (o_proj): Linear(in_features=896, out_features=896, bias=False)
          (rotary_emb): Qwen2MoeRotaryEmbedding()
        )
        (mlp): Qwen2MoeSparseMoeBlock(
          (gate): Linear(in_features=896, out_features=60, bias=False)
          (experts): ModuleList(
            (0-59): 60 x Qwen2MoeMLP(
              (gate_proj): Linear(in_features=896, out_features=1408, bias=False)
              (up_proj): Linear(in_features=896, out_features=1408, bias=False)
              (down_proj): Linear(in_features=1408, out_features=896, bias=False)
              (act_fn):

In [23]:
def eval_size(model):
    """
    Function to evaluate the size of the model in terms of number of parameters.
    """
    return sum(p.numel() for p in model.parameters())  / 10**9

eval_size(student_model), eval_size(teacher_model)  # Check the number of parameters in the models

(0.868476544, 1.543714304)

In [None]:
class GSM8KDataModule(pl.LightningDataModule):
    def __init__(self, tokenizer: AutoTokenizer, batch_size: int = 2) -> None:
        super().__init__()
        self.tokenizer = tokenizer
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.batch_size = batch_size
        self.num_workers = os.cpu_count() - 1 if os.cpu_count() else 0


    def setup(self, stage=None) -> None:
        dataset = load_dataset("openai/gsm8k", "main", split="train", num_proc=4)

        def preprocess_training(examples: dict) -> dict:
            """
            Preprocess training corpus.
            Input: ids, attention_mask
            Output: labels
            """
            inputs = ["Question: " + q + "\nAnswer:" + a for q, a in zip(examples["question"], examples["answer"])]
            model_inputs = self.tokenizer(
                inputs,
                padding="max_length",
                truncation=True,
                max_length=256,
                return_tensors="pt",
            )

            prompt_inputs = self.tokenizer(
                ["Question: " + q + "\nAnswer:" for q in examples["question"]],
                padding="max_length",
                truncation=True,
                max_length=256,
                return_tensors="pt"
            )
            model_inputs["labels"] = model_inputs["input_ids"].masked_fill(prompt_inputs["attention_mask"] != model_inputs["attention_mask"], self.tokenizer.pad_token_id)
            return model_inputs

        self.train_dataset = dataset.map(preprocess_training, batched=True)
        self.train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])


    def train_dataloader(self) -> DataLoader:
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)

    # def val_dataloader(self) -> DataLoader:
    #     pass


class DistillationLightningModule(pl.LightningModule):
    def __init__(self,
                 student_model: AutoModelForCausalLM,
                 teacher_model: AutoModelForCausalLM,
                 tokenizer: AutoTokenizer,
                 alpha: float = 0.3,
                 scale: float = 0.01, 
                 temperature: float = 2.0,
                 learning_rate: float = 5e-5
                 ) -> None:
        super().__init__()
        self.student = student_model
        self.teacher = teacher_model
        self.tokenizer = tokenizer
        self.alpha = alpha
        self.scale = scale
        self.temperature = temperature
        self.learning_rate = learning_rate
        self.student.train()
        self.teacher.eval()
        for param in self.teacher.parameters():
            param.requires_grad = False

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        return self.student(input_ids=input_ids, attention_mask=attention_mask)

    def training_step(self, batch: dict, batch_idx: int) -> torch.Tensor:
        input_ids = batch["input_ids"].to(self.device)
        attention_mask = batch["attention_mask"].to(self.device)

        with torch.no_grad():
            teacher_logits = self.teacher(input_ids=input_ids, attention_mask=attention_mask).logits

        student_logits = self.student(input_ids=input_ids, attention_mask=attention_mask, use_cache=False).logits
        student_probs = F.softmax(student_logits / self.temperature, dim=-1)
        student_log_probs = F.log_softmax(student_logits / self.temperature, dim=-1)
        teacher_log_probs = F.log_softmax(teacher_logits / self.temperature, dim=-1)
        student_over_teacher_logp = student_log_probs - teacher_log_probs
        # efficient estimator of KL divergence: http://joschu.net/blog/kl-approx.html
        forward_kl = torch.exp(student_over_teacher_logp) - student_over_teacher_logp - 1
        # reverse_kl = torch.exp(-student_over_teacher_logp) + student_over_teacher_logp - 1
        reverse_kl = 0
        loss_kl = ((forward_kl + reverse_kl).mean(dim=-1) * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)
        loss_kl = loss_kl.mean()


        labels = batch["labels"]
        # print(labels.shape, student_probs.shape)
        labels = labels[:, 1:].to(self.device).contiguous()
        student_logits = student_logits[:, :-1, :].contiguous()
        # print(labels.shape, student_probs.shape)
        loss_ce = F.cross_entropy(student_logits.view(-1, student_logits.size(-1)), labels.view(-1), ignore_index=self.tokenizer.pad_token_id)

        loss = (1 - self.alpha) * self.scale * loss_kl + self.alpha * loss_ce

        self.log("loss/train_loss_kl", loss_kl * self.scale, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log("loss/train_loss_ce", loss_ce, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log("loss/train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss
        
        # self.log("loss/train_loss_ce", loss_ce, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        # return loss_ce
        

    # def validation_step(self, batch: dict, batch_idx: int) -> torch.Tensor:
    #     pass

    def configure_optimizers(self) -> torch.optim.Optimizer:
        return torch.optim.AdamW(self.student.parameters(), lr=self.learning_rate)


In [25]:
gsm8k_dm = GSM8KDataModule(tokenizer=tokenizer, batch_size=8)
distill_model = DistillationLightningModule(
    student_model, teacher_model, tokenizer=tokenizer, 
    alpha=0.9, scale=5e-3, temperature=2.0, learning_rate=5e-5
    )

In [26]:
# checkpoint_callback = ModelCheckpoint(
#     dirpath=save_dir,
#     filename="ckpt-{epoch:02d}-{step}",
#     every_n_train_steps=1000,
#     save_top_k=-1,
# )

In [27]:
from gsm8k_verify import gsm8k_eval
class EvalCallback(Callback):
    def __init__(self, ckpt_dir: str, model_name: str):
        super().__init__()
        self.ckpt_dir = ckpt_dir
        self.model_name = model_name

    def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        # Save the model
        save_path = os.path.join(self.ckpt_dir, self.model_name, f"epoch-{trainer.current_epoch}")
        pl_module.student.save_pretrained(save_path)
        tokenizer.save_pretrained(save_path)
        pl_module.student.config.save_pretrained(save_path)
        # Evaluate the model
        acc0 = gsm8k_eval(model_name=save_path, fewshot=0)
        acc3 = gsm8k_eval(model_name=save_path, fewshot=3)
        pl_module.log("eval/acc_fewshot0", acc0, on_step=False, on_epoch=True, logger=True)
        pl_module.log("eval/acc_fewshot3", acc3, on_step=False, on_epoch=True, logger=True)
        
        

In [28]:
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
save_model_name = f"distilled-qwen2moe-0.8B-qwen2.5-1.5B-{timestamp}"
ckpt_callback = EvalCallback(
    ckpt_dir=save_dir,
    model_name=save_model_name
)

# Trainer
trainer = pl.Trainer(
    max_epochs=10,  # increase epoch for training from scratch
    callbacks=[ckpt_callback],
    precision="16-mixed",
    log_every_n_steps=10,
    logger=pl.loggers.WandbLogger(project="KD-COMS6998", name=f"qwen2moe-0.8B-qwen2.5-1.5B-stf-{timestamp}", save_dir=log_dir),
    accelerator="gpu",
    devices=1,
    accumulate_grad_batches=1,
    default_root_dir=log_dir,
)

/network/rit/lab/wang_lab_cs/ptian/miniconda/lib/python3.12/site-packages/lightning_fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /network/rit/lab/wang_lab_cs/ptian/miniconda/lib/pyt ...
Using 16bit Automatic Mixed Precision (AMP)
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [29]:
os.environ["TOKENIZERS_PARALLELISM"] = "false" # avoid dead lock

In [None]:
trainer.fit(
    distill_model,
    datamodule=gsm8k_dm,
)

You are using a CUDA device ('NVIDIA A100-SXM4-80GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mtpzl0222[0m ([33mtptrix29[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin




Map:   0%|          | 0/64 [00:00<?, ? examples/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type                | Params | Mode 
--------------------------------------------------------
0 | student | Qwen2MoeForCausalLM | 868 M  | train
1 | teacher | Qwen2ForCausalLM    | 1.5 B  | eval 
--------------------------------------------------------
868 M     Trainable params
1.5 B     Non-trainable params
2.4 B     Total params
9,648.763 Total estimated model params size (MB)
961       Modules in train mode
371       Modules in eval mode
/network/rit/lab/wang_lab_cs/ptian/miniconda/lib/python3.12/site-packages/pytorch_lightning/loops/fit_loop.py:310: The number of training batches (8) is smaller than the logging interval Trainer(log_every_n_steps=10). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |                                                                          | 0/? [00:00<?, ?it/s]

INFO 05-02 00:11:09 [config.py:2832] Downcasting torch.float32 to torch.float16.
INFO 05-02 00:11:29 [config.py:689] This model supports multiple tasks: {'generate', 'score', 'classify', 'embed', 'reward'}. Defaulting to 'generate'.
INFO 05-02 00:11:29 [config.py:1901] Chunked prefill is enabled with max_num_batched_tokens=8192.
INFO 05-02 00:11:37 [__init__.py:239] Automatically detected platform cuda.


2025-05-02 00:11:38.071307: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1746159098.092759   90799 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1746159098.098997   90799 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1746159098.115237   90799 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1746159098.115286   90799 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1746159098.115288   90799 computation_placer.cc:177] computation placer alr

INFO 05-02 00:11:44 [core.py:61] Initializing a V1 LLM engine (v0.8.4) with config: model='/network/rit/lab/wang_lab_cs/ptian/ckpt/distillation/distilled-qwen2moe-0.8B-qwen2.5-1.5B-20250502-001002/epoch-0', speculative_config=None, tokenizer='/network/rit/lab/wang_lab_cs/ptian/ckpt/distillation/distilled-qwen2moe-0.8B-qwen2.5-1.5B-20250502-001002/epoch-0', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=32768, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='auto', reasoning_backend=None), observability_config=ObservabilityConfig(show_hidden_metrics=False, otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=Fa

2025-05-02 00:11:45,278 - INFO - flashinfer.jit: Prebuilt kernels not found, using JIT backend


INFO 05-02 00:11:46 [parallel_state.py:959] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0
INFO 05-02 00:11:46 [cuda.py:221] Using Flash Attention backend on V1 engine.
INFO 05-02 00:11:46 [gpu_model_runner.py:1276] Starting to load model /network/rit/lab/wang_lab_cs/ptian/ckpt/distillation/distilled-qwen2moe-0.8B-qwen2.5-1.5B-20250502-001002/epoch-0...
INFO 05-02 00:11:46 [topk_topp_sampler.py:44] Currently, FlashInfer top-p & top-k sampling sampler is disabled because FlashInfer>=v0.2.3 is not backward compatible. Falling back to the PyTorch-native implementation of top-p & top-k sampling.


Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:12<00:00, 12.86s/it]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:12<00:00, 12.86s/it]



INFO 05-02 00:11:59 [loader.py:458] Loading weights took 13.05 seconds
INFO 05-02 00:12:00 [gpu_model_runner.py:1291] Model loading took 1.6216 GiB and 13.241454 seconds
INFO 05-02 00:12:03 [backends.py:416] Using cache directory: /network/rit/lab/wang_lab_cs/ptian/vllm_cache/torch_compile_cache/65cc9c4946/rank_0_0 for vLLM's torch.compile
INFO 05-02 00:12:03 [backends.py:426] Dynamo bytecode transform time: 3.00 s
INFO 05-02 00:12:06 [backends.py:132] Cache the graph of shape None for later use
INFO 05-02 00:12:12 [backends.py:144] Compiling a graph for general shape takes 9.03 s
INFO 05-02 00:12:15 [monitor.py:33] torch.compile takes 12.02 s in total
INFO 05-02 00:12:17 [kv_cache_utils.py:634] GPU KV cache size: 12,050,080 tokens
INFO 05-02 00:12:17 [kv_cache_utils.py:637] Maximum concurrency for 32,768 tokens per request: 367.74x
INFO 05-02 00:12:48 [gpu_model_runner.py:1626] Graph capturing finished in 30 secs, took 0.41 GiB
INFO 05-02 00:12:48 [core.py:163] init engine (profile, c

Processed prompts:   0%|       | 0/128 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|       | 0/128 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|       | 0/128 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|       | 0/128 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|       | 0/128 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|       | 0/128 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|       | 0/128 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|       | 0/128 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|       | 0/128 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|       | 0/128 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|        | 0/39 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]



INFO 05-02 00:13:23 [config.py:2832] Downcasting torch.float32 to torch.float16.
INFO 05-02 00:13:23 [config.py:689] This model supports multiple tasks: {'generate', 'score', 'classify', 'embed', 'reward'}. Defaulting to 'generate'.
INFO 05-02 00:13:23 [config.py:1901] Chunked prefill is enabled with max_num_batched_tokens=8192.
INFO 05-02 00:13:32 [__init__.py:239] Automatically detected platform cuda.


2025-05-02 00:13:32.708220: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1746159212.728129   92878 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1746159212.734163   92878 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1746159212.749332   92878 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1746159212.749361   92878 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1746159212.749364   92878 computation_placer.cc:177] computation placer alr

INFO 05-02 00:13:41 [core.py:61] Initializing a V1 LLM engine (v0.8.4) with config: model='/network/rit/lab/wang_lab_cs/ptian/ckpt/distillation/distilled-qwen2moe-0.8B-qwen2.5-1.5B-20250502-001002/epoch-0', speculative_config=None, tokenizer='/network/rit/lab/wang_lab_cs/ptian/ckpt/distillation/distilled-qwen2moe-0.8B-qwen2.5-1.5B-20250502-001002/epoch-0', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=32768, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='auto', reasoning_backend=None), observability_config=ObservabilityConfig(show_hidden_metrics=False, otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=Fa

2025-05-02 00:13:41,816 - INFO - flashinfer.jit: Prebuilt kernels not found, using JIT backend


INFO 05-02 00:13:42 [parallel_state.py:959] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0
INFO 05-02 00:13:42 [cuda.py:221] Using Flash Attention backend on V1 engine.
INFO 05-02 00:13:43 [gpu_model_runner.py:1276] Starting to load model /network/rit/lab/wang_lab_cs/ptian/ckpt/distillation/distilled-qwen2moe-0.8B-qwen2.5-1.5B-20250502-001002/epoch-0...
INFO 05-02 00:13:43 [topk_topp_sampler.py:44] Currently, FlashInfer top-p & top-k sampling sampler is disabled because FlashInfer>=v0.2.3 is not backward compatible. Falling back to the PyTorch-native implementation of top-p & top-k sampling.


Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  1.58it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  1.58it/s]



INFO 05-02 00:13:45 [loader.py:458] Loading weights took 1.06 seconds
INFO 05-02 00:13:45 [gpu_model_runner.py:1291] Model loading took 1.6216 GiB and 1.752237 seconds
INFO 05-02 00:13:48 [backends.py:416] Using cache directory: /network/rit/lab/wang_lab_cs/ptian/vllm_cache/torch_compile_cache/65cc9c4946/rank_0_0 for vLLM's torch.compile
INFO 05-02 00:13:48 [backends.py:426] Dynamo bytecode transform time: 2.94 s
INFO 05-02 00:13:48 [backends.py:115] Directly load the compiled graph for shape None from the cache
INFO 05-02 00:13:49 [monitor.py:33] torch.compile takes 2.94 s in total
INFO 05-02 00:13:50 [kv_cache_utils.py:634] GPU KV cache size: 12,050,080 tokens
INFO 05-02 00:13:50 [kv_cache_utils.py:637] Maximum concurrency for 32,768 tokens per request: 367.74x
INFO 05-02 00:14:17 [gpu_model_runner.py:1626] Graph capturing finished in 27 secs, took 0.41 GiB
INFO 05-02 00:14:17 [core.py:163] init engine (profile, create kv cache, warmup model) took 32.39 seconds
INFO 05-02 00:14:18 [c

Processed prompts:   0%|       | 0/128 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|       | 0/128 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|       | 0/128 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|       | 0/128 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|       | 0/128 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|       | 0/128 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|       | 0/128 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|       | 0/128 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|       | 0/128 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|       | 0/128 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|        | 0/39 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]



INFO 05-02 00:17:01 [config.py:2832] Downcasting torch.float32 to torch.float16.
INFO 05-02 00:17:01 [config.py:689] This model supports multiple tasks: {'generate', 'score', 'classify', 'embed', 'reward'}. Defaulting to 'generate'.
INFO 05-02 00:17:01 [config.py:1901] Chunked prefill is enabled with max_num_batched_tokens=8192.
INFO 05-02 00:17:11 [__init__.py:239] Automatically detected platform cuda.


2025-05-02 00:17:11.543852: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1746159431.564115   97569 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1746159431.570031   97569 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1746159431.585188   97569 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1746159431.585217   97569 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1746159431.585219   97569 computation_placer.cc:177] computation placer alr

INFO 05-02 00:17:19 [core.py:61] Initializing a V1 LLM engine (v0.8.4) with config: model='/network/rit/lab/wang_lab_cs/ptian/ckpt/distillation/distilled-qwen2moe-0.8B-qwen2.5-1.5B-20250502-001002/epoch-1', speculative_config=None, tokenizer='/network/rit/lab/wang_lab_cs/ptian/ckpt/distillation/distilled-qwen2moe-0.8B-qwen2.5-1.5B-20250502-001002/epoch-1', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=32768, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='auto', reasoning_backend=None), observability_config=ObservabilityConfig(show_hidden_metrics=False, otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=Fa

2025-05-02 00:17:20,408 - INFO - flashinfer.jit: Prebuilt kernels not found, using JIT backend


INFO 05-02 00:17:21 [parallel_state.py:959] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0
INFO 05-02 00:17:21 [cuda.py:221] Using Flash Attention backend on V1 engine.
INFO 05-02 00:17:21 [gpu_model_runner.py:1276] Starting to load model /network/rit/lab/wang_lab_cs/ptian/ckpt/distillation/distilled-qwen2moe-0.8B-qwen2.5-1.5B-20250502-001002/epoch-1...
INFO 05-02 00:17:21 [topk_topp_sampler.py:44] Currently, FlashInfer top-p & top-k sampling sampler is disabled because FlashInfer>=v0.2.3 is not backward compatible. Falling back to the PyTorch-native implementation of top-p & top-k sampling.


Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:04<00:00,  4.45s/it]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:04<00:00,  4.45s/it]



INFO 05-02 00:17:26 [loader.py:458] Loading weights took 4.62 seconds
INFO 05-02 00:17:26 [gpu_model_runner.py:1291] Model loading took 1.6216 GiB and 4.822129 seconds
INFO 05-02 00:17:29 [backends.py:416] Using cache directory: /network/rit/lab/wang_lab_cs/ptian/vllm_cache/torch_compile_cache/b2e828e7f8/rank_0_0 for vLLM's torch.compile
INFO 05-02 00:17:29 [backends.py:426] Dynamo bytecode transform time: 2.93 s



Detected KeyboardInterrupt, attempting graceful shutdown ...


[rank: 0] Received SIGTERM: 15


In [None]:
wandb.finish()

0,1
epoch,▁▁▁▁▂▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇█████
loss/train_loss_ce_epoch,█▆▅▄▃▂▂▁▁▁
loss/train_loss_ce_step,█▅▄▅▄▄▃▄▄▃▃▃▃▄▃▃▃▂▂▂▂▁▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
trainer/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇█████

0,1
epoch,9.0
loss/train_loss_ce_epoch,0.27443
loss/train_loss_ce_step,0.27745
trainer/global_step,9349.0


### Post-Training

In [None]:
save_path = f"{save_dir}/{save_model_name}"
distill_model.student.save_pretrained(save_path)
distill_model.student.config.save_pretrained(save_path)
gsm8k_dm.tokenizer.save_pretrained(save_path)
os.listdir(save_path)

['config.json',
 'generation_config.json',
 'model.safetensors',
 'tokenizer_config.json',
 'special_tokens_map.json',
 'added_tokens.json',
 'vocab.json',
 'merges.txt',
 'tokenizer.json']

In [None]:
# post-training evaluation
qem2 = evaluator.eval(save_path, device="cuda")
qem2

INFO 04-29 18:57:19 [config.py:2832] Downcasting torch.float32 to torch.float16.
INFO 04-29 18:57:35 [config.py:689] This model supports multiple tasks: {'classify', 'embed', 'reward', 'generate', 'score'}. Defaulting to 'generate'.
INFO 04-29 18:57:35 [config.py:1901] Chunked prefill is enabled with max_num_batched_tokens=8192.
INFO 04-29 18:57:42 [__init__.py:239] Automatically detected platform cuda.


2025-04-29 18:57:42.447273: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1745967462.467369 2455765 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1745967462.473529 2455765 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1745967462.489633 2455765 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1745967462.489652 2455765 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1745967462.489654 2455765 computation_placer.cc:177] computation placer alr

INFO 04-29 18:57:48 [core.py:61] Initializing a V1 LLM engine (v0.8.4) with config: model='/network/rit/lab/wang_lab_cs/ptian/output/distillation/distilled-qwen2moe-0.8B-qwen2.5-1.5B-20250429-173142', speculative_config=None, tokenizer='/network/rit/lab/wang_lab_cs/ptian/output/distillation/distilled-qwen2moe-0.8B-qwen2.5-1.5B-20250429-173142', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=32768, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='auto', reasoning_backend=None), observability_config=ObservabilityConfig(show_hidden_metrics=False, otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=N

2025-04-29 18:57:49,047 - INFO - flashinfer.jit: Prebuilt kernels not found, using JIT backend


INFO 04-29 18:57:50 [parallel_state.py:959] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0
INFO 04-29 18:57:50 [cuda.py:221] Using Flash Attention backend on V1 engine.
INFO 04-29 18:57:50 [gpu_model_runner.py:1276] Starting to load model /network/rit/lab/wang_lab_cs/ptian/output/distillation/distilled-qwen2moe-0.8B-qwen2.5-1.5B-20250429-173142...
INFO 04-29 18:57:50 [topk_topp_sampler.py:44] Currently, FlashInfer top-p & top-k sampling sampler is disabled because FlashInfer>=v0.2.3 is not backward compatible. Falling back to the PyTorch-native implementation of top-p & top-k sampling.


Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:04<00:00,  4.64s/it]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:04<00:00,  4.64s/it]



INFO 04-29 18:57:55 [loader.py:458] Loading weights took 4.85 seconds
INFO 04-29 18:57:55 [gpu_model_runner.py:1291] Model loading took 1.6216 GiB and 5.073681 seconds
INFO 04-29 18:58:08 [backends.py:416] Using cache directory: /network/rit/home/ptian_wang_lab_cs/.cache/vllm/torch_compile_cache/0c967f847a/rank_0_0 for vLLM's torch.compile
INFO 04-29 18:58:08 [backends.py:426] Dynamo bytecode transform time: 13.18 s
INFO 04-29 18:58:11 [backends.py:132] Cache the graph of shape None for later use
INFO 04-29 18:58:17 [backends.py:144] Compiling a graph for general shape takes 8.32 s
INFO 04-29 18:58:19 [monitor.py:33] torch.compile takes 21.50 s in total
INFO 04-29 18:58:20 [kv_cache_utils.py:634] GPU KV cache size: 35,999,088 tokens
INFO 04-29 18:58:20 [kv_cache_utils.py:637] Maximum concurrency for 32,768 tokens per request: 1098.60x
INFO 04-29 18:58:45 [gpu_model_runner.py:1626] Graph capturing finished in 24 secs, took 0.41 GiB
INFO 04-29 18:58:45 [core.py:163] init engine (profile,

Eval Inference:   0%|          | 0/11 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/128 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/128 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/128 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/128 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/128 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/128 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/128 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/128 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/128 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/128 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|           | 0/36 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Evaluating Exact Match Accuracy:   0%|          | 0/1316 [00:00<?, ?it/s]



0.0