In [None]:
# connect to google drive
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


##Install Dependencies

In [None]:
!pip install sacrebleu
!pip install rouge-score



In [None]:
!pip install -q transformers accelerate peft
!pip install -U bitsandbytes



In [None]:
!pip install datasets scipy
!pip install tqdm



##Imports

In [None]:
import os
import json
import math
import random
from tqdm import tqdm

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    get_scheduler
)
from torch.optim import AdamW

from datasets import load_dataset

##Config AND Hyperparameters

In [None]:
# ===== Models =====
TEACHER_MODEL = "mistralai/Mistral-7B-Instruct-v0.1"
STUDENT_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# ===== Data =====
DATASET_NAME = "alibaba-pai/OmniThought"
N_SAMPLES = 100

# ===== Training =====
OUTPUT_DIR = "tinyllama_alm_student"
BATCH_SIZE = 2
LR = 1e-5
EPOCHS = 4
LAMBDA_ANSWER = 0.5

# ===== Paths =====
TEACHER_OUT_FILE = "teacher_outputs.jsonl"
STUDENT_DISTILL_FILE = "student_distill_data.jsonl"

os.makedirs(OUTPUT_DIR, exist_ok=True)


##Sample Subset

In [None]:
dataset_stream = load_dataset(DATASET_NAME, split="train", streaming=True)

# Take first N samples
small_subset = []
for i, example in enumerate(dataset_stream):
    small_subset.append(example)
    if i + 1 == N_SAMPLES:
        break

print(f"Loaded {len(small_subset)} examples")
print(small_subset[0])

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]

Resolving data files:   0%|          | 0/135 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/135 [00:00<?, ?it/s]

Loaded 100 examples
{'question': 'For an integer $n \\ge 3$ we consider a circle with $n$ points on it.\nWe place a positive integer at each point, where the numbers are not necessary need to be different. Such placement of numbers is called [i]stable [/i] as three numbers next to always have product $n$ each other. \nFor how many values of $n$ with $3 \\le n \\le 2020$ is it possible to place numbers in a stable way?\n', 'reasoning': [{'Cognitive_Difficulty': {'judge': 'QwQ-32B', 'level': 6}, 'Reasoning_Verbosity': {'judge': 'QwQ-32B', 'level': 5}, 'full_response': '<think>\nOkay, let\'s see. I need to figure out for how many integers n between 3 and 2020 inclusive, there\'s a way to place numbers on a circle with n points such that every three consecutive numbers multiply to n. The numbers can be the same, they just need to be positive integers. Hmm, interesting problem.\n\nFirst, let me try to understand the problem better. We have a circle with n points, each labeled with a positiv

In [None]:
print(small_subset[0].keys())

dict_keys(['question', 'reasoning'])


In [None]:
print(small_subset[0]['question'])

For an integer $n \ge 3$ we consider a circle with $n$ points on it.
We place a positive integer at each point, where the numbers are not necessary need to be different. Such placement of numbers is called [i]stable [/i] as three numbers next to always have product $n$ each other. 
For how many values of $n$ with $3 \le n \le 2020$ is it possible to place numbers in a stable way?



##Teacher Generation Function

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig
)
import torch



bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

teacher_tokenizer = AutoTokenizer.from_pretrained(TEACHER_MODEL)

teacher_model = AutoModelForCausalLM.from_pretrained(
    TEACHER_MODEL,
    quantization_config=bnb_config,
    device_map="auto"
)


def generate_teacher_cot(prompt, max_new_tokens=200):
    input_ids = teacher_tokenizer(prompt, return_tensors="pt").input_ids.to(teacher_model.device)

    outputs = teacher_model.generate(
        input_ids,
        max_new_tokens=max_new_tokens,
        return_dict_in_generate=True,
        output_scores=True,
        do_sample=False
    )

    seq = outputs.sequences[0]
    scores = outputs.scores

    text = teacher_tokenizer.decode(seq, skip_special_tokens=True)

    token_strs = teacher_tokenizer.convert_ids_to_tokens(seq.tolist())

    token_probs = []
    for step, score in enumerate(scores):
        probs = torch.softmax(score, dim=-1)
        chosen_idx = seq[input_ids.shape[-1] + step]
        token_probs.append(probs[0, chosen_idx].item())

    return text, token_strs, token_probs

config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/4.54G [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/9.94G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

##Generate Teacher Outputs

In [None]:
with open("teacher_outputs_100.jsonl", "w") as fout:
    for example in tqdm(small_subset):
        question = example["question"]
        prompt = f"Q: {question} Think step by step.\nA:"

        text, toks, probs = generate_teacher_cot(prompt, 512)

        record = {
            "question": question,
            "teacher_prompt": prompt,
            "generation": text,
            "token_strs": toks,
            "token_probs": probs
        }
        fout.write(json.dumps(record) + "\n")

print("Saved teacher outputs.")


  0%|          | 0/100 [00:00<?, ?it/s]The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
  1%|          | 1/100 [00:27<45:54, 27.82s/it]The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
  2%|▏         | 2/100 [00:35<25:47, 15.79s/it]The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attentio

Saved teacher outputs.





##Tokenize with Offsets

In [None]:
student_tok = AutoTokenizer.from_pretrained(STUDENT_MODEL, use_fast=True)

def tokenize_with_offsets(text, tokenizer):
    enc = tokenizer(text, return_offsets_mapping=True, add_special_tokens=False)
    return enc["input_ids"], enc["offset_mapping"]

##ALM Alignment AND Chunk Creation Methodology Implementation

In [None]:
def align_chunks_alm(teacher_text, teacher_strs, teacher_probs, student_tokenizer):
    """
    Multi-token chunk alignment for ALM-style distillation.
    """
    t_ids, t_offsets = tokenize_with_offsets(teacher_text, teacher_tokenizer)
    s_ids, s_offsets = tokenize_with_offsets(teacher_text, student_tokenizer)

    aligned_chunks = []
    t_cursor = 0
    s_cursor = 0

    while t_cursor < len(t_offsets) and s_cursor < len(s_offsets):
        t_start, t_end = t_offsets[t_cursor]
        s_start, s_end = s_offsets[s_cursor]

        if not (s_end <= t_start or s_start >= t_end):
            # Extend existing chunk or create new
            if aligned_chunks and aligned_chunks[-1]["end_offset"] >= t_start:
                aligned_chunks[-1]["teacher_tokens"].append(teacher_strs[t_cursor])
                aligned_chunks[-1]["teacher_probs"].append(teacher_probs[t_cursor])
                aligned_chunks[-1]["student_token_ids"].append(s_ids[s_cursor])
                aligned_chunks[-1]["end_offset"] = max(aligned_chunks[-1]["end_offset"], s_end)
            else:
                aligned_chunks.append({
                    "teacher_tokens": [teacher_strs[t_cursor]],
                    "teacher_probs": [teacher_probs[t_cursor]],
                    "student_token_ids": [s_ids[s_cursor]],
                    "start_offset": min(t_start, s_start),
                    "end_offset": max(t_end, s_end)
                })
            t_cursor += 1
            s_cursor += 1
        else:
            if s_end <= t_start:
                s_cursor += 1
            else:
                t_cursor += 1
    return aligned_chunks

In [None]:
def approximate_divergence(student_log_probs, teacher_probs, debias=True, epsilon=1e-8):
    """
    Weighted NLL for ALM-style chunk.
    If debias=True, subtract expected student log-likelihood to reduce tokenization bias.
    """
    log_p_t = torch.log(torch.clamp(teacher_probs, min=epsilon))
    weights = torch.exp(log_p_t)

    loss = -(weights * student_log_probs).mean()

    if debias:
        # simple debias: subtract mean student log-likelihood
        bias = student_log_probs.mean()
        loss = loss - bias

    return loss


In [None]:
teacher_tokenizer = AutoTokenizer.from_pretrained(TEACHER_MODEL)

In [None]:
def alm_loss_with_context(student_model, chunks, context_ids=None, device='cuda', max_context_len=512, gamma=1.0):
    """
    ALM loss computation for autoregressive student.

    - chunks: list of aligned chunks (from align_chunks_alm)
    - context_ids: tensor of previous tokens (can start with teacher_prompt)
    - max_context_len: max context window for autoregressive student
    """
    if not chunks:
        dummy_param = next(student_model.parameters())
        return dummy_param.sum() * 0.0

    total_loss = 0.0
    chunk_count = 0

    if context_ids is not None:
        context_ids = context_ids.to(device)

    for chunk in chunks:
        student_ids = torch.tensor(
            chunk["student_token_ids"],
            device=device
            ).unsqueeze(0)

        # Concatenate context and chunk tokens
        if context_ids is not None:
            input_ids = torch.cat([context_ids, student_ids], dim=1)
        else:
            input_ids = student_ids

        # Apply sliding window to manage memory
        if input_ids.shape[1] > max_context_len:
            input_ids = input_ids[:, -max_context_len:]

        # Forward pass
        logits = student_model(input_ids).logits
        log_probs = F.log_softmax(logits, dim=-1)

        # Only compute loss for chunk tokens
        chunk_start = context_ids.shape[1] if context_ids is not None else 0
        student_log_probs = log_probs[:, chunk_start:, :].gather(
            -1, student_ids.unsqueeze(-1)
        ).squeeze(-1)

        # Teacher probabilities
        p_t = torch.tensor(chunk["teacher_probs"], device=device)
        p_t = torch.clamp(p_t, min=1e-4)

        # Weighted ALM loss
        weights = torch.exp(gamma * torch.log(p_t))
        loss = -(weights * student_log_probs).mean()

        total_loss += loss
        chunk_count += 1

        # Update context for next chunk (detach to save memory)
        context_ids = input_ids.detach()

    return total_loss / chunk_count if chunk_count > 0 else torch.tensor(0.0, device=device)

In [None]:
class ALMDatasetALM(Dataset):
    def __init__(self, teacher_file, student_tokenizer, teacher_tokenizer):
        self.records = []
        with open(teacher_file, "r") as f:
            for line in f:
                rec = json.loads(line)
                teacher_strs = rec["token_strs"]
                teacher_probs = rec["token_probs"]

                # Tokenize teacher generation
                t_ids, t_offsets = tokenize_with_offsets(rec["generation"], teacher_tokenizer)
                aligned_teacher_strs = teacher_tokenizer.convert_ids_to_tokens(t_ids)

                # Align teacher probs with tokenized text
                len_prompt_part = len(aligned_teacher_strs) - len(teacher_probs)
                len_prompt_part = max(len_prompt_part, 0)
                aligned_teacher_probs = [0.0]*len_prompt_part + teacher_probs

                # ALM-style alignment
                chunks = align_chunks_alm(
                    rec["generation"],
                    aligned_teacher_strs,
                    aligned_teacher_probs,
                    student_tokenizer
                )

                if chunks:
                    self.records.append({
                        "question": rec.get("question", ""),
                        "generation": rec["generation"],
                        "chunks": chunks
                    })

        if not self.records:
            print(f"Warning: No valid records found in {teacher_file}.")

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        return self.records[idx]

In [None]:
alm_dataset = ALMDatasetALM("/content/teacher_outputs.jsonl", student_tok, teacher_tokenizer)
train_loader = DataLoader(alm_dataset, batch_size=1, shuffle=True)

In [None]:
train_loader


<torch.utils.data.dataloader.DataLoader at 0x7e3d83fe5e20>

##Student Model Import

In [None]:
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig
)
import torch
from peft import LoraConfig, get_peft_model



bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

student_model = AutoModelForCausalLM.from_pretrained(
    STUDENT_MODEL,
    quantization_config=bnb_config,
    device_map="cuda"
)

# Configure LoRA
peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    r=64,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj", "lm_head"], # Target modules for TinyLlama
)

# Apply LoRA to the student model
student_model = get_peft_model(student_model, peft_config)
student_model.print_trainable_parameters()

optimizer = AdamW(student_model.parameters(), lr=LR)
num_training_steps = EPOCHS * len(train_loader)

lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps
)

trainable params: 52,641,792 || all params: 1,152,690,176 || trainable%: 4.5669


In [None]:
# check device using cuda
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

cuda


#Cross-Tokenizer Chain-of-Thought Distillation

##Baseline (Chunk-wise Cross Entropy Loss) Based Training of Student Model

In [None]:
import logging
from tqdm import tqdm
import torch

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S"
)

device = "cuda"
student_model.train()

for epoch in range(EPOCHS):
    logging.info(f"Starting epoch {epoch+1}/{EPOCHS}")

    epoch_loss = 0.0
    num_batches = len(train_loader)

    for batch_idx, batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}")):
        batch_size = len(batch["question"])
        optimizer.zero_grad()
        batch_loss = 0.0

        # Iterate over each sample in the batch
        for i in range(batch_size):
            record = {k: batch[k][i] for k in batch}

            # Use teacher rational as context
            teacher_text = record["generation"]
            teacher_ids = student_tok(
                teacher_text,
                return_tensors="pt"
            ).input_ids.to(device)

            # Chunking for long sequences
            max_len = student_model.config.max_position_embeddings
            chunks = teacher_ids.split(max_len, dim=1)

            sample_loss = 0.0
            for chunk in chunks:
                outputs = student_model(input_ids=chunk, labels=chunk)
                sample_loss += outputs.loss

            sample_loss = sample_loss / len(chunks)
            batch_loss += sample_loss

            logging.debug(f"Sample {i+1}/{batch_size} - Loss: {sample_loss.item():.4f}")

        # Average batch loss
        batch_loss = batch_loss / batch_size
        batch_loss.backward()
        optimizer.step()
        lr_scheduler.step()

        epoch_loss += batch_loss.item()

        logging.info(f"Epoch {epoch+1} Batch {batch_idx+1}/{num_batches} - Batch Loss: {batch_loss.item():.4f}, LR: {lr_scheduler.get_last_lr()[0]:.6f}")

    avg_epoch_loss = epoch_loss / num_batches
    print(f"Epoch {epoch+1} completed - Average Loss: {avg_epoch_loss:.4f}")
    logging.info(f"Epoch {epoch+1} completed - Average Loss: {avg_epoch_loss:.4f}")

Epoch 1: 100%|██████████| 586/586 [02:35<00:00,  3.76it/s]


Epoch 1 completed - Average Loss: 1.0986


Epoch 2: 100%|██████████| 586/586 [02:35<00:00,  3.76it/s]


Epoch 2 completed - Average Loss: 0.9952


Epoch 3: 100%|██████████| 586/586 [02:35<00:00,  3.76it/s]


Epoch 3 completed - Average Loss: 0.9493


Epoch 4: 100%|██████████| 586/586 [02:38<00:00,  3.71it/s]


Epoch 4 completed - Average Loss: 0.9121


Epoch 5: 100%|██████████| 586/586 [02:39<00:00,  3.68it/s]


Epoch 5 completed - Average Loss: 0.8801


Epoch 6: 100%|██████████| 586/586 [02:35<00:00,  3.77it/s]


Epoch 6 completed - Average Loss: 0.8520


Epoch 7: 100%|██████████| 586/586 [02:36<00:00,  3.75it/s]


Epoch 7 completed - Average Loss: 0.8281


Epoch 8: 100%|██████████| 586/586 [02:37<00:00,  3.73it/s]


Epoch 8 completed - Average Loss: 0.8089


Epoch 9: 100%|██████████| 586/586 [02:35<00:00,  3.76it/s]


Epoch 9 completed - Average Loss: 0.7948


Epoch 10: 100%|██████████| 586/586 [02:37<00:00,  3.73it/s]

Epoch 10 completed - Average Loss: 0.7861





In [None]:
student_model.save_pretrained(f"{OUTPUT_DIR}/checkpoint_epoch_{epoch + 1}")



###Evaluation

In [None]:
import re
import torch
from tqdm import tqdm
from sacrebleu import corpus_bleu
from rouge_score import rouge_scorer


In [None]:
def extract_reasoning_only(text):
    """
    Removes question and keeps only reasoning after 'A:'.
    """
    if "A:" in text:
        return text.split("A:", 1)[1].strip()
    return text.strip()


In [None]:
import torch

@torch.no_grad()
def generate_student_output(
    model,
    tokenizer,
    question,
    max_new_tokens=128,
    device="cuda"
):
    cot_trigger = "\nThink step by step."
    prompt = f"Q: {question} {cot_trigger}\nA:"
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

    output_ids = model.generate(
        input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=False
    )

    return tokenizer.decode(output_ids[0], skip_special_tokens=True)

In [None]:
import re
def extract_final_answer(text):
    """
    Robust extractor for AMC/OmniThought-style MCQs.
    Handles prompt echoing and noisy reasoning.
    """
    patterns = [
        r"A:\s*([A-E])\b",
        r"Answer:\s*([A-E])\b",
        r"\bFinal Answer:\s*([A-E])\b",
        r"\(([A-E])\)"
    ]
    for p in patterns:
        m = re.search(p, text)
        if m:
            return m.group(1)

    # Fallback: last standalone A–E token
    tokens = re.findall(r"\b[A-E]\b", text)
    return tokens[-1] if tokens else ""

In [None]:
def evaluate_answer_accuracy(
    model,
    tokenizer,
    records,
    device="cuda",
    max_new_tokens=256
):
    correct = 0
    total = len(records)

    for i, rec in enumerate(tqdm(records, desc="Accuracy Eval"), 1):
        gt_answer = extract_final_answer(rec["generation"])

        pred_text = generate_student_output(
            model, tokenizer, rec["question"],
            max_new_tokens=max_new_tokens,
            device=device
        )
        pred_answer = extract_final_answer(pred_text)

        if pred_answer == gt_answer and gt_answer != "":
            correct += 1

        # Print first few examples for confidence
        if i <= 3:
            print(f"\n[Sample {i}]")
            print("GT:", gt_answer)
            print("Pred:", pred_answer)

    acc = correct / max(total, 1)
    print(f"\nFinal Accuracy: {acc:.4f}")
    return acc

In [None]:
def evaluate_reasoning_overlap(
    model,
    tokenizer,
    records,
    device="cuda",
    max_new_tokens=256
):
    overlaps = []

    for i, rec in enumerate(tqdm(records, desc="Token Overlap Eval"), 1):
        teacher_reasoning = extract_reasoning_only(rec["generation"])

        student_output = generate_student_output(
            model, tokenizer, rec["question"],
            max_new_tokens=max_new_tokens,
            device=device
        )
        student_reasoning = extract_reasoning_only(student_output)

        t_tokens = set(teacher_reasoning.split())
        s_tokens = set(student_reasoning.split())

        overlap = len(t_tokens & s_tokens) / max(len(t_tokens), 1)
        overlaps.append(overlap)

        if i <= 3:
            print(f"\n[Sample {i}] Overlap: {overlap:.3f}")
            print("Teacher:", teacher_reasoning[:200])
            print("Student:", student_reasoning[:200])

    avg_overlap = sum(overlaps) / max(len(overlaps), 1)
    print(f"\nAvg Token Overlap: {avg_overlap:.4f}")
    return avg_overlap

In [None]:
def evaluate_reasoning_bleu(
    model,
    tokenizer,
    records,
    device="cuda",
    max_new_tokens=256
):
    references = []
    hypotheses = []

    for rec in tqdm(records, desc="BLEU Eval"):
        teacher_reasoning = extract_reasoning_only(rec["generation"])

        student_output = generate_student_output(
            model, tokenizer, rec["question"],
            max_new_tokens=max_new_tokens,
            device=device
        )
        student_reasoning = extract_reasoning_only(student_output)

        references.append([teacher_reasoning])
        hypotheses.append(student_reasoning)

    bleu = corpus_bleu(hypotheses, references)
    print(f"BLEU Score: {bleu.score:.2f}")
    return bleu.score

In [None]:
def evaluate_reasoning_rouge_l(
    model,
    tokenizer,
    records,
    device="cuda",
    max_new_tokens=256
):
    scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)
    scores = []

    for i, rec in enumerate(tqdm(records, desc="ROUGE-L Eval"), 1):
        teacher_reasoning = extract_reasoning_only(rec["generation"])

        student_output = generate_student_output(
            model, tokenizer, rec["question"],
            max_new_tokens=max_new_tokens,
            device=device
        )
        student_reasoning = extract_reasoning_only(student_output)

        score = scorer.score(teacher_reasoning, student_reasoning)
        scores.append(score["rougeL"].fmeasure)

        if i <= 3:
            print(f"\n[Sample {i}] ROUGE-L: {score['rougeL'].fmeasure:.3f}")

    avg_rouge = sum(scores) / max(len(scores), 1)
    print(f"\nAvg ROUGE-L: {avg_rouge:.4f}")
    return avg_rouge

In [None]:
def run_reasoning_mimicry_evaluation(
    student_scratch,
    student_alm,
    tokenizer,
    test_records,
    device="cuda"
):
    print("\n===== SCRATCH STUDENT (Reasoning Mimicry) =====")
    scratch_overlap = evaluate_reasoning_overlap(
        student_scratch, tokenizer, test_records, device=device
    )
    scratch_bleu = evaluate_reasoning_bleu(
        student_scratch, tokenizer, test_records, device=device
    )
    scratch_rouge = evaluate_reasoning_rouge_l(
        student_scratch, tokenizer, test_records, device=device
    )

    print("\n===== ALM STUDENT (Reasoning Mimicry) =====")
    alm_overlap = evaluate_reasoning_overlap(
        student_alm, tokenizer, test_records, device=device
    )
    alm_bleu = evaluate_reasoning_bleu(
        student_alm, tokenizer, test_records, device=device
    )
    alm_rouge = evaluate_reasoning_rouge_l(
        student_alm, tokenizer, test_records, device=device
    )

    print("\n===== FINAL REASONING MIMICRY RESULTS =====")
    print(f"{'Model':<15} {'Overlap':<10} {'BLEU':<10} {'ROUGE-L':<10}")
    print("-" * 50)
    print(f"{'Scratch':<15} {scratch_overlap:<10.3f} {scratch_bleu:<10.2f} {scratch_rouge:<10.3f}")
    print(f"{'ALM (Ours)':<15} {alm_overlap:<10.3f} {alm_bleu:<10.2f} {alm_rouge:<10.3f}")

###Run Evaluation Loop

In [None]:
import random

def split_records(records, test_ratio=0.1, seed=42):
    random.seed(seed)
    records = records.copy()
    random.shuffle(records)

    n_test = int(len(records) * test_ratio)
    return records[n_test:], records[:n_test]


In [None]:
# Use the same ALM dataset you trained on
train_records, test_records = split_records(
    alm_dataset.records,
    test_ratio=0.1
)

print(f"Train: {len(train_records)} | Test: {len(test_records)}")

Train: 528 | Test: 58


In [None]:
from transformers import AutoModelForCausalLM

student_alm = AutoModelForCausalLM.from_pretrained(
    OUTPUT_DIR + "/checkpoint_epoch_10",   # adjust if needed
    device_map="auto"
)
student_alm.eval()


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 2048)
    (layers): ModuleList(
      (0-21): 22 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): lora.Linear(
            (base_layer): Linear(in_features=2048, out_features=2048, bias=False)
            (lora_dropout): ModuleDict(
              (default): Dropout(p=0.1, inplace=False)
            )
            (lora_A): ModuleDict(
              (default): Linear(in_features=2048, out_features=64, bias=False)
            )
            (lora_B): ModuleDict(
              (default): Linear(in_features=64, out_features=2048, bias=False)
            )
            (lora_embedding_A): ParameterDict()
            (lora_embedding_B): ParameterDict()
            (lora_magnitude_vector): ModuleDict()
          )
          (k_proj): lora.Linear(
            (base_layer): Linear(in_features=2048, out_features=256, bias=False)
            (lora_dropout): ModuleDict(
              (default): 

In [None]:
student_scratch = AutoModelForCausalLM.from_pretrained(
    STUDENT_MODEL,
    device_map="auto"
)
student_scratch.eval()


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 2048)
    (layers): ModuleList(
      (0-21): 22 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=5632, bias=False)
          (up_proj): Linear(in_features=2048, out_features=5632, bias=False)
          (down_proj): Linear(in_features=5632, out_features=2048, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((2048,), eps=1e-05)
    (rot

In [None]:
# First: small sanity evaluation
small_test = test_records[:5]

run_reasoning_mimicry_evaluation(
    student_scratch,
    student_alm,
    student_tok,
    small_test
)



===== SCRATCH STUDENT (Reasoning Mimicry) =====


Token Overlap Eval:  20%|██        | 1/5 [00:07<00:29,  7.35s/it]


[Sample 1] Overlap: 0.239
Teacher: Let's start by finding the coordinates of points $A$, $B$, and $C$. Since $ABC$ lies in the first quadrant, we know that $x_A>0$, $x_B>0$, and $x_C>0$. Let's assume that $y_A>x_A$, $y_B>x_B$, and $y_C
Student: The statement is true if and only if the triangle lies in the first quadrant.
B: The statement is true if and only if the slope of line $AA'$ is $-1$.
C: The statement is true if and only if the slope


Token Overlap Eval:  40%|████      | 2/5 [00:14<00:21,  7.30s/it]


[Sample 2] Overlap: 0.086
Teacher: First, we can assume that $f(0)=0$ since $f(x+y^2) \ge f(x)+y$ and $f(0)+0=0$.

Next, we can assume that $f$ is strictly increasing on $[0,\infty)$ since $f(x+y^2) \ge f(x)+y$ and $f(x+y^2) \ge f(x)$ 
Student: Let $c = 1$.
For all nonnegative reals $x, y$, $f(x+y^2) \ge 1f(x)+y$.

Let $x = 1$ and $y = 2$.
$f(1+2^2) = 1f(1) + 2f(1) = 1 + 2f(1) = 3$.
$f(1+2^2) \ge 3f(1) + 2 = 5$.
$f(1+2^2) = 5$.

So $f(1+2^2)


Token Overlap Eval:  60%|██████    | 3/5 [00:21<00:14,  7.30s/it]


[Sample 3] Overlap: 0.000
Teacher: 1, 2, 3, 4, 7
Student: The polynomial has 9 roots, which are:
- 1
- 1/37
- 1/27
- 1/74
- 1/37
- 1/27
- 1/74
- 1/37
- 1/27

B: The polynomial has no real roots.

C: The polynomial has no real roots because it has 9 real root


Token Overlap Eval: 100%|██████████| 5/5 [00:31<00:00,  6.39s/it]



Avg Token Overlap: 0.1380


BLEU Eval: 100%|██████████| 5/5 [00:31<00:00,  6.36s/it]


BLEU Score: 10.19


ROUGE-L Eval:  20%|██        | 1/5 [00:07<00:29,  7.33s/it]


[Sample 1] ROUGE-L: 0.216


ROUGE-L Eval:  40%|████      | 2/5 [00:14<00:21,  7.27s/it]


[Sample 2] ROUGE-L: 0.195


ROUGE-L Eval:  60%|██████    | 3/5 [00:21<00:14,  7.27s/it]


[Sample 3] ROUGE-L: 0.012


ROUGE-L Eval: 100%|██████████| 5/5 [00:31<00:00,  6.37s/it]



Avg ROUGE-L: 0.1757

===== ALM STUDENT (Reasoning Mimicry) =====


Token Overlap Eval:  20%|██        | 1/5 [00:12<00:51, 12.76s/it]


[Sample 1] Overlap: 0.283
Teacher: Let's start by finding the coordinates of points $A$, $B$, and $C$. Since $ABC$ lies in the first quadrant, we know that $x_A>0$, $x_B>0$, and $x_C>0$. Let's assume that $y_A>x_A$, $y_B>x_B$, and $y_C
Student: Let's start by finding the slope of line $AA'$. Since $AA'$ is a reflection of $A$ on the line $y=x$, we have $AA'=A'$. Therefore, the slope of line $AA'$ is $1$.

Now, let's find the slope of line $C


Token Overlap Eval:  40%|████      | 2/5 [00:25<00:38, 12.77s/it]


[Sample 2] Overlap: 0.143
Teacher: First, we can assume that $f(0)=0$ since $f(x+y^2) \ge f(x)+y$ and $f(0)+0=0$.

Next, we can assume that $f$ is strictly increasing on $[0,\infty)$ since $f(x+y^2) \ge f(x)+y$ and $f(x+y^2) \ge f(x)$ 
Student: 1. Let $x, y$ be nonnegative reals.
2. We want to find a function $f: \mathbb{R}_{ \ge 0} \rightarrow \mathbb{R}$ such that $f(x+y^2) \ge cf(x)+y$.
3. We can rewrite this as $f(x+y^2) \ge cf(x)+y-f(x)


Token Overlap Eval:  60%|██████    | 3/5 [00:38<00:25, 12.82s/it]


[Sample 3] Overlap: 0.000
Teacher: 1, 2, 3, 4, 7
Student: 1. We know that the polynomial has a root at the point (3, 1).
2. We also know that the polynomial has a root at the point (-3, 1).
3. Therefore, the polynomial has a root at the point (3, 1) if and o


Token Overlap Eval: 100%|██████████| 5/5 [01:04<00:00, 12.81s/it]



Avg Token Overlap: 0.2438


BLEU Eval: 100%|██████████| 5/5 [01:04<00:00, 12.88s/it]


BLEU Score: 19.57


ROUGE-L Eval:  20%|██        | 1/5 [00:12<00:51, 12.95s/it]


[Sample 1] ROUGE-L: 0.337


ROUGE-L Eval:  40%|████      | 2/5 [00:25<00:38, 12.92s/it]


[Sample 2] ROUGE-L: 0.414


ROUGE-L Eval:  60%|██████    | 3/5 [00:38<00:25, 12.91s/it]


[Sample 3] ROUGE-L: 0.054


ROUGE-L Eval: 100%|██████████| 5/5 [01:04<00:00, 12.87s/it]


Avg ROUGE-L: 0.2823

===== FINAL REASONING MIMICRY RESULTS =====
Model           Overlap    BLEU       ROUGE-L   
--------------------------------------------------
Scratch         0.138      10.19      0.176     
ALM (Ours)      0.244      19.57      0.282     





In [None]:
# Then full evaluation
run_reasoning_mimicry_evaluation(
    student_scratch,
    student_alm,
    student_tok,
    test_records
)


===== SCRATCH STUDENT (Reasoning Mimicry) =====


Token Overlap Eval:   2%|▏         | 1/58 [00:07<07:01,  7.40s/it]


[Sample 1] Overlap: 0.239
Teacher: Let's start by finding the coordinates of points $A$, $B$, and $C$. Since $ABC$ lies in the first quadrant, we know that $x_A>0$, $x_B>0$, and $x_C>0$. Let's assume that $y_A>x_A$, $y_B>x_B$, and $y_C
Student: The statement is true if and only if the triangle lies in the first quadrant.
B: The statement is true if and only if the slope of line $AA'$ is $-1$.
C: The statement is true if and only if the slope


Token Overlap Eval:   3%|▎         | 2/58 [00:14<06:51,  7.35s/it]


[Sample 2] Overlap: 0.086
Teacher: First, we can assume that $f(0)=0$ since $f(x+y^2) \ge f(x)+y$ and $f(0)+0=0$.

Next, we can assume that $f$ is strictly increasing on $[0,\infty)$ since $f(x+y^2) \ge f(x)+y$ and $f(x+y^2) \ge f(x)$ 
Student: Let $c = 1$.
For all nonnegative reals $x, y$, $f(x+y^2) \ge 1f(x)+y$.

Let $x = 1$ and $y = 2$.
$f(1+2^2) = 1f(1) + 2f(1) = 1 + 2f(1) = 3$.
$f(1+2^2) \ge 3f(1) + 2 = 5$.
$f(1+2^2) = 5$.

So $f(1+2^2)


Token Overlap Eval:   5%|▌         | 3/58 [00:22<06:42,  7.32s/it]


[Sample 3] Overlap: 0.000
Teacher: 1, 2, 3, 4, 7
Student: The polynomial has 9 roots, which are:
- 1
- 1/37
- 1/27
- 1/74
- 1/37
- 1/27
- 1/74
- 1/37
- 1/27

B: The polynomial has no real roots.

C: The polynomial has no real roots because it has 9 real root


Token Overlap Eval: 100%|██████████| 58/58 [06:28<00:00,  6.71s/it]



Avg Token Overlap: 0.2094


BLEU Eval: 100%|██████████| 58/58 [06:29<00:00,  6.71s/it]


BLEU Score: 28.13


ROUGE-L Eval:   2%|▏         | 1/58 [00:07<06:59,  7.35s/it]


[Sample 1] ROUGE-L: 0.216


ROUGE-L Eval:   3%|▎         | 2/58 [00:14<06:49,  7.31s/it]


[Sample 2] ROUGE-L: 0.195


ROUGE-L Eval:   5%|▌         | 3/58 [00:21<06:42,  7.31s/it]


[Sample 3] ROUGE-L: 0.012


ROUGE-L Eval: 100%|██████████| 58/58 [06:29<00:00,  6.71s/it]



Avg ROUGE-L: 0.2416

===== ALM STUDENT (Reasoning Mimicry) =====


Token Overlap Eval:   2%|▏         | 1/58 [00:12<12:05, 12.73s/it]


[Sample 1] Overlap: 0.283
Teacher: Let's start by finding the coordinates of points $A$, $B$, and $C$. Since $ABC$ lies in the first quadrant, we know that $x_A>0$, $x_B>0$, and $x_C>0$. Let's assume that $y_A>x_A$, $y_B>x_B$, and $y_C
Student: Let's start by finding the slope of line $AA'$. Since $AA'$ is a reflection of $A$ on the line $y=x$, we have $AA'=A'$. Therefore, the slope of line $AA'$ is $1$.

Now, let's find the slope of line $C


Token Overlap Eval:   3%|▎         | 2/58 [00:25<11:54, 12.75s/it]


[Sample 2] Overlap: 0.143
Teacher: First, we can assume that $f(0)=0$ since $f(x+y^2) \ge f(x)+y$ and $f(0)+0=0$.

Next, we can assume that $f$ is strictly increasing on $[0,\infty)$ since $f(x+y^2) \ge f(x)+y$ and $f(x+y^2) \ge f(x)$ 
Student: 1. Let $x, y$ be nonnegative reals.
2. We want to find a function $f: \mathbb{R}_{ \ge 0} \rightarrow \mathbb{R}$ such that $f(x+y^2) \ge cf(x)+y$.
3. We can rewrite this as $f(x+y^2) \ge cf(x)+y-f(x)


Token Overlap Eval:   5%|▌         | 3/58 [00:38<11:39, 12.72s/it]


[Sample 3] Overlap: 0.000
Teacher: 1, 2, 3, 4, 7
Student: 1. We know that the polynomial has a root at the point (3, 1).
2. We also know that the polynomial has a root at the point (-3, 1).
3. Therefore, the polynomial has a root at the point (3, 1) if and o


Token Overlap Eval: 100%|██████████| 58/58 [12:13<00:00, 12.65s/it]



Avg Token Overlap: 0.2722


BLEU Eval: 100%|██████████| 58/58 [12:12<00:00, 12.64s/it]


BLEU Score: 34.16


ROUGE-L Eval:   2%|▏         | 1/58 [00:12<12:10, 12.81s/it]


[Sample 1] ROUGE-L: 0.337


ROUGE-L Eval:   3%|▎         | 2/58 [00:25<11:55, 12.78s/it]


[Sample 2] ROUGE-L: 0.414


ROUGE-L Eval:   5%|▌         | 3/58 [00:38<11:42, 12.78s/it]


[Sample 3] ROUGE-L: 0.054


ROUGE-L Eval: 100%|██████████| 58/58 [12:14<00:00, 12.67s/it]


Avg ROUGE-L: 0.2616

===== FINAL REASONING MIMICRY RESULTS =====
Model           Overlap    BLEU       ROUGE-L   
--------------------------------------------------
Scratch         0.209      28.13      0.242     
ALM (Ours)      0.272      34.16      0.262     





###Human Evaluation

In [None]:
@torch.no_grad()
def compare_student_inference(
    question: str,
    student_scratch,
    student_alm,
    tokenizer,
    max_new_tokens: int = 256,
    device: str = "cuda",
    cot_trigger: str = "Think step by step."
):
    """
    Runs inference for both scratch and ALM-trained student models
    using the SAME prompt format as training.

    - Appends CoT trigger ("Think step by step.")
    - Extracts reasoning-only output (excludes question)
    - Returns clean comparison

    Returns:
        {
            "question": str,
            "prompt_used": str,
            "scratch_reasoning": str,
            "alm_reasoning": str
        }
    """

    prompt = f"Q: {question} {cot_trigger}\nA:"

    def _generate_reasoning(model):
        input_ids = tokenizer(
            prompt,
            return_tensors="pt"
        ).input_ids.to(device)

        output_ids = model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=False
        )

        text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

        # Remove question + prompt, keep only reasoning
        if "A:" in text:
            return text.split("A:", 1)[1].strip()
        return text.strip()

    scratch_reasoning = _generate_reasoning(student_scratch)
    alm_reasoning = _generate_reasoning(student_alm)

    return {
        "question": question,
        "prompt_used": prompt,
        "scratch_reasoning": scratch_reasoning,
        "alm_reasoning": alm_reasoning
    }

In [None]:
query = "Triangle ABC lies in the first quadrant. Points A, B, and C are reflected across y=x. Which statement is not always true?"

result = compare_student_inference(
    question=query,
    student_scratch=student_scratch,
    student_alm=student_alm,
    tokenizer=student_tok
)

print("PROMPT USED:")
print(result["prompt_used"])

print("\nSCRATCH STUDENT REASONING:")
print(result["scratch_reasoning"])

print("\nALM STUDENT REASONING:")
print(result["alm_reasoning"])

PROMPT USED:
Q: Triangle ABC lies in the first quadrant. Points A, B, and C are reflected across y=x. Which statement is not always true? Think step by step.
A:

SCRATCH STUDENT REASONING:
The angle between the line segment AB and the line segment AC is always 90 degrees.
B: The angle between the line segment AC and the line segment BC is always 90 degrees.
C: The angle between the line segment BC and the line segment CD is always 90 degrees.
D: The angle between the line segment CD and the line segment DE is always 90 degrees.
E: The angle between the line segment DE and the line segment EF is always 90 degrees.
F: The angle between the line segment EF and the line segment GH is always 90 degrees.
G: The angle between the line segment GH and the line segment IJ is always 90 degrees.
H: The angle between the line segment IJ and the line segment KL is always 90 degrees.
I: The angle between the line segment KL and the line segment MN is always 90 degrees.
J: The angle between the line s

#Cross Tokenizer + ALM Loss

##Experiment 1: Lambda = 1.5, Gamma = 1.0

In [None]:
import logging
from tqdm import tqdm
import torch

# -------------------- CONFIG --------------------
device = "cuda"
LAMBDA_ALM = 1.5
GAMMA = 1.0
MAX_CONTEXT_LEN = 512
# -----------------------------------------------

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S"
)

student_model.train()

for epoch in range(EPOCHS):
    logging.info(f"Starting epoch {epoch+1}/{EPOCHS}")

    epoch_loss = 0.0
    num_batches = len(train_loader)

    for batch_idx, batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}")):
        batch_size = len(batch["question"])
        optimizer.zero_grad(set_to_none=True)

        batch_loss = 0.0

        # -------------------- PER-SAMPLE LOOP --------------------
        for i in range(batch_size):
            record = {k: batch[k][i] for k in batch}

            # ======================================================
            # (1) CE LOSS (UNCHANGED FROM YOUR ORIGINAL LOOP)
            # ======================================================
            teacher_text = record["generation"]
            teacher_ids = student_tok(
                teacher_text,
                return_tensors="pt"
            ).input_ids.to(device)

            max_len = student_model.config.max_position_embeddings
            ce_chunks = teacher_ids.split(max_len, dim=1)

            ce_loss = 0.0
            for chunk in ce_chunks:
                outputs = student_model(input_ids=chunk, labels=chunk)
                ce_loss += outputs.loss

            ce_loss = ce_loss / len(ce_chunks)

            # ======================================================
            # (2) ALM LOSS (CONFIDENCE-WEIGHTED RATIONALE ALIGNMENT)
            # ======================================================
            assert isinstance(record["chunks"], dict)

            alm_loss = alm_loss_with_context(
                student_model=student_model,
                chunks=[record["chunks"]],
                device=device,
                max_context_len=MAX_CONTEXT_LEN,
                gamma=GAMMA
            )

            # ======================================================
            # (3) TOTAL LOSS (PROPOSAL FORM)
            # ======================================================
            sample_loss = ce_loss + LAMBDA_ALM * alm_loss
            batch_loss += sample_loss

            logging.debug(
                f"Sample {i+1}/{batch_size} | "
                f"CE: {ce_loss.item():.4f}, "
                f"ALM: {alm_loss.item():.4f}, "
                f"Total: {sample_loss.item():.4f}"
            )

        # -------------------- BACKPROP --------------------
        batch_loss = batch_loss / batch_size

        if not torch.isfinite(batch_loss):
            logging.warning("⚠️ NaN/Inf loss detected, skipping batch")
            optimizer.zero_grad(set_to_none=True)
            continue

        batch_loss.backward()
        torch.nn.utils.clip_grad_norm_(student_model.parameters(), 1.0)

        optimizer.step()
        lr_scheduler.step()

        epoch_loss += batch_loss.item()

        logging.info(
            f"Epoch {epoch+1} Batch {batch_idx+1}/{num_batches} | "
            f"Batch Loss: {batch_loss.item():.4f} | "
            f"LR: {lr_scheduler.get_last_lr()[0]:.6f}"
        )

    avg_epoch_loss = epoch_loss / num_batches
    logging.info(f"Epoch {epoch+1} completed | Avg Loss: {avg_epoch_loss:.4f}")
    print(f"Epoch {epoch+1} completed | Avg Loss: {avg_epoch_loss:.4f}")

    student_model.save_pretrained(
        f"{OUTPUT_DIR}/checkpoint_epoch_{epoch+1}"
    )

Epoch 1: 100%|██████████| 586/586 [05:25<00:00,  1.80it/s]


Epoch 1 completed | Avg Loss: 3.7974


Epoch 2: 100%|██████████| 586/586 [05:23<00:00,  1.81it/s]


Epoch 2 completed | Avg Loss: 1.1999


Epoch 3: 100%|██████████| 586/586 [05:24<00:00,  1.81it/s]


Epoch 3 completed | Avg Loss: 1.0579


Epoch 4: 100%|██████████| 586/586 [05:23<00:00,  1.81it/s]


Epoch 4 completed | Avg Loss: 1.0279


###Run Eval ALM

In [None]:
import random

def split_records(records, test_ratio=0.1, seed=42):
    random.seed(seed)
    records = records.copy()
    random.shuffle(records)

    n_test = int(len(records) * test_ratio)
    return records[n_test:], records[:n_test]


In [None]:
# Use the same ALM dataset you trained on
train_records, test_records = split_records(
    alm_dataset.records,
    test_ratio=0.1
)

print(f"Train: {len(train_records)} | Test: {len(test_records)}")

Train: 528 | Test: 58


In [None]:
from transformers import AutoModelForCausalLM

student_alm = AutoModelForCausalLM.from_pretrained(
    OUTPUT_DIR + "/checkpoint_epoch_4",   # adjust if needed
    device_map="auto"
)
student_alm.eval()


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 2048)
    (layers): ModuleList(
      (0-21): 22 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): lora.Linear(
            (base_layer): Linear(in_features=2048, out_features=2048, bias=False)
            (lora_dropout): ModuleDict(
              (default): Dropout(p=0.1, inplace=False)
            )
            (lora_A): ModuleDict(
              (default): Linear(in_features=2048, out_features=64, bias=False)
            )
            (lora_B): ModuleDict(
              (default): Linear(in_features=64, out_features=2048, bias=False)
            )
            (lora_embedding_A): ParameterDict()
            (lora_embedding_B): ParameterDict()
            (lora_magnitude_vector): ModuleDict()
          )
          (k_proj): lora.Linear(
            (base_layer): Linear(in_features=2048, out_features=256, bias=False)
            (lora_dropout): ModuleDict(
              (default): 

In [None]:
student_scratch = AutoModelForCausalLM.from_pretrained(
    STUDENT_MODEL,
    device_map="auto"
)
student_scratch.eval()


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 2048)
    (layers): ModuleList(
      (0-21): 22 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=5632, bias=False)
          (up_proj): Linear(in_features=2048, out_features=5632, bias=False)
          (down_proj): Linear(in_features=5632, out_features=2048, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((2048,), eps=1e-05)
    (rot

In [None]:
# sanity: small subset
small_test = test_records[:5]

# Then full evaluation
run_reasoning_mimicry_evaluation(
    student_scratch,
    student_alm,
    student_tok,
    small_test
)


===== SCRATCH STUDENT (Reasoning Mimicry) =====


Token Overlap Eval:  20%|██        | 1/5 [00:07<00:30,  7.67s/it]


[Sample 1] Overlap: 0.239
Teacher: Let's start by finding the coordinates of points $A$, $B$, and $C$. Since $ABC$ lies in the first quadrant, we know that $x_A>0$, $x_B>0$, and $x_C>0$. Let's assume that $y_A>x_A$, $y_B>x_B$, and $y_C
Student: The statement is true if and only if the triangle lies in the first quadrant.
B: The statement is true if and only if the slope of line $AA'$ is $-1$.
C: The statement is true if and only if the slope


Token Overlap Eval:  40%|████      | 2/5 [00:15<00:22,  7.60s/it]


[Sample 2] Overlap: 0.086
Teacher: First, we can assume that $f(0)=0$ since $f(x+y^2) \ge f(x)+y$ and $f(0)+0=0$.

Next, we can assume that $f$ is strictly increasing on $[0,\infty)$ since $f(x+y^2) \ge f(x)+y$ and $f(x+y^2) \ge f(x)$ 
Student: Let $c = 1$.
For all nonnegative reals $x, y$, $f(x+y^2) \ge 1f(x)+y$.

Let $x = 1$ and $y = 2$.
$f(1+2^2) = 1f(1) + 2f(1) = 1 + 2f(1) = 3$.
$f(1+2^2) \ge 3f(1) + 2 = 5$.
$f(1+2^2) = 5$.

So $f(1+2^2)


Token Overlap Eval:  60%|██████    | 3/5 [00:22<00:15,  7.60s/it]


[Sample 3] Overlap: 0.000
Teacher: 1, 2, 3, 4, 7
Student: The polynomial has 9 roots, which are:
- 1
- 1/37
- 1/27
- 1/74
- 1/37
- 1/27
- 1/74
- 1/37
- 1/27

B: The polynomial has no real roots.

C: The polynomial has no real roots because it has 9 real root


Token Overlap Eval: 100%|██████████| 5/5 [00:33<00:00,  6.63s/it]



Avg Token Overlap: 0.1380


BLEU Eval: 100%|██████████| 5/5 [00:33<00:00,  6.63s/it]


BLEU Score: 10.19


ROUGE-L Eval:  20%|██        | 1/5 [00:07<00:30,  7.59s/it]


[Sample 1] ROUGE-L: 0.216


ROUGE-L Eval:  40%|████      | 2/5 [00:15<00:22,  7.62s/it]


[Sample 2] ROUGE-L: 0.195


ROUGE-L Eval:  60%|██████    | 3/5 [00:22<00:15,  7.62s/it]


[Sample 3] ROUGE-L: 0.012


ROUGE-L Eval: 100%|██████████| 5/5 [00:33<00:00,  6.67s/it]



Avg ROUGE-L: 0.1757

===== ALM STUDENT (Reasoning Mimicry) =====


Token Overlap Eval:  20%|██        | 1/5 [00:13<00:53, 13.40s/it]


[Sample 1] Overlap: 0.217
Teacher: Let's start by finding the coordinates of points $A$, $B$, and $C$. Since $ABC$ lies in the first quadrant, we know that $x_A>0$, $x_B>0$, and $x_C>0$. Let's assume that $y_A>x_A$, $y_B>x_B$, and $y_C
Student: (A) (B) (C) (D) (E)

Based on the given information, we can see that:


*

*Triangle $A'B'C'$ lies in the first quadrant.

*Triangle $ABC$ has the same area as triangle $A'B'C'$.

*The slope of line $


Token Overlap Eval:  40%|████      | 2/5 [00:26<00:40, 13.43s/it]


[Sample 2] Overlap: 0.057
Teacher: First, we can assume that $f(0)=0$ since $f(x+y^2) \ge f(x)+y$ and $f(0)+0=0$.

Next, we can assume that $f$ is strictly increasing on $[0,\infty)$ since $f(x+y^2) \ge f(x)+y$ and $f(x+y^2) \ge f(x)$ 
Student: 1. Let $x, y$ be nonnegative reals.
2. $f(x+y^2) = cf(x)+y$.
3. $f(x+y^2) = cf(x)+y$.
4. $f(x+y^2) = cf(x)+y$.
5. $f(x+y^2) = cf(x)+y$.
6. $f(x+y^2) = cf(x)+y$.
7. $f(x+y^2) = cf(x)+y$.
8. $f(x+y^2) =


Token Overlap Eval:  60%|██████    | 3/5 [00:40<00:26, 13.42s/it]


[Sample 3] Overlap: 0.000
Teacher: 1, 2, 3, 4, 7
Student: 1. First, we need to find the roots of the polynomial.
2. The polynomial is a polynomial of degree 9.
3. So, we need to find the roots of the polynomial.
4. The polynomial is a polynomial of degree 9.


Token Overlap Eval: 100%|██████████| 5/5 [01:07<00:00, 13.42s/it]



Avg Token Overlap: 0.1565


BLEU Eval: 100%|██████████| 5/5 [01:06<00:00, 13.35s/it]


BLEU Score: 10.20


ROUGE-L Eval:  20%|██        | 1/5 [00:13<00:53, 13.30s/it]


[Sample 1] ROUGE-L: 0.278


ROUGE-L Eval:  40%|████      | 2/5 [00:26<00:40, 13.37s/it]


[Sample 2] ROUGE-L: 0.417


ROUGE-L Eval:  60%|██████    | 3/5 [00:40<00:26, 13.40s/it]


[Sample 3] ROUGE-L: 0.055


ROUGE-L Eval: 100%|██████████| 5/5 [01:06<00:00, 13.36s/it]


Avg ROUGE-L: 0.2449

===== FINAL REASONING MIMICRY RESULTS =====
Model           Overlap    BLEU       ROUGE-L   
--------------------------------------------------
Scratch         0.138      10.19      0.176     
ALM (Ours)      0.156      10.20      0.245     





## Experiment 2: Lambda = 0.3, Gamma = 1.0

In [None]:
import logging
from tqdm import tqdm
import torch

# -------------------- CONFIG --------------------
device = "cuda"
LAMBDA_ALM = 0.3
GAMMA = 1.0
MAX_CONTEXT_LEN = 512
OUTPUT_DIR_1 = "lam03_gam10"
os.makedirs(OUTPUT_DIR_1, exist_ok=True)
# -----------------------------------------------

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S"
)

student_model.train()

for epoch in range(EPOCHS):
    logging.info(f"Starting epoch {epoch+1}/{EPOCHS}")

    epoch_loss = 0.0
    num_batches = len(train_loader)

    for batch_idx, batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}")):
        batch_size = len(batch["question"])
        optimizer.zero_grad(set_to_none=True)

        batch_loss = 0.0

        # -------------------- PER-SAMPLE LOOP --------------------
        for i in range(batch_size):
            record = {k: batch[k][i] for k in batch}

            # ======================================================
            # (1) CE LOSS (UNCHANGED FROM YOUR ORIGINAL LOOP)
            # ======================================================
            teacher_text = record["generation"]
            teacher_ids = student_tok(
                teacher_text,
                return_tensors="pt"
            ).input_ids.to(device)

            max_len = student_model.config.max_position_embeddings
            ce_chunks = teacher_ids.split(max_len, dim=1)

            ce_loss = 0.0
            for chunk in ce_chunks:
                outputs = student_model(input_ids=chunk, labels=chunk)
                ce_loss += outputs.loss

            ce_loss = ce_loss / len(ce_chunks)

            # ======================================================
            # (2) ALM LOSS (CONFIDENCE-WEIGHTED RATIONALE ALIGNMENT)
            # ======================================================
            assert isinstance(record["chunks"], dict)

            alm_loss = alm_loss_with_context(
                student_model=student_model,
                chunks=[record["chunks"]],
                device=device,
                max_context_len=MAX_CONTEXT_LEN,
                gamma=GAMMA
            )

            # ======================================================
            # (3) TOTAL LOSS (PROPOSAL FORM)
            # ======================================================
            sample_loss = ce_loss + LAMBDA_ALM * alm_loss
            batch_loss += sample_loss

            logging.debug(
                f"Sample {i+1}/{batch_size} | "
                f"CE: {ce_loss.item():.4f}, "
                f"ALM: {alm_loss.item():.4f}, "
                f"Total: {sample_loss.item():.4f}"
            )

        # -------------------- BACKPROP --------------------
        batch_loss = batch_loss / batch_size

        if not torch.isfinite(batch_loss):
            logging.warning("⚠️ NaN/Inf loss detected, skipping batch")
            optimizer.zero_grad(set_to_none=True)
            continue

        batch_loss.backward()
        torch.nn.utils.clip_grad_norm_(student_model.parameters(), 1.0)

        optimizer.step()
        lr_scheduler.step()

        epoch_loss += batch_loss.item()

        logging.info(
            f"Epoch {epoch+1} Batch {batch_idx+1}/{num_batches} | "
            f"Batch Loss: {batch_loss.item():.4f} | "
            f"LR: {lr_scheduler.get_last_lr()[0]:.6f}"
        )

    avg_epoch_loss = epoch_loss / num_batches
    logging.info(f"Epoch {epoch+1} completed | Avg Loss: {avg_epoch_loss:.4f}")
    print(f"Epoch {epoch+1} completed | Avg Loss: {avg_epoch_loss:.4f}")

    student_model.save_pretrained(
        f"{OUTPUT_DIR_1}/checkpoint_epoch_{epoch+1}"
    )

Epoch 1: 100%|██████████| 586/586 [05:26<00:00,  1.79it/s]


Epoch 1 completed | Avg Loss: 1.2073


Epoch 2: 100%|██████████| 586/586 [05:27<00:00,  1.79it/s]


Epoch 2 completed | Avg Loss: 1.0295


Epoch 3: 100%|██████████| 586/586 [05:28<00:00,  1.78it/s]


Epoch 3 completed | Avg Loss: 1.0006


Epoch 4: 100%|██████████| 586/586 [05:27<00:00,  1.79it/s]


Epoch 4 completed | Avg Loss: 0.9931


In [None]:
from transformers import AutoModelForCausalLM

student_alm_exp2 = AutoModelForCausalLM.from_pretrained(
    OUTPUT_DIR_1 + "/checkpoint_epoch_4",   # adjust if needed
    device_map="auto"
)
student_alm_exp2.eval()


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 2048)
    (layers): ModuleList(
      (0-21): 22 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): lora.Linear(
            (base_layer): Linear(in_features=2048, out_features=2048, bias=False)
            (lora_dropout): ModuleDict(
              (default): Dropout(p=0.1, inplace=False)
            )
            (lora_A): ModuleDict(
              (default): Linear(in_features=2048, out_features=64, bias=False)
            )
            (lora_B): ModuleDict(
              (default): Linear(in_features=64, out_features=2048, bias=False)
            )
            (lora_embedding_A): ParameterDict()
            (lora_embedding_B): ParameterDict()
            (lora_magnitude_vector): ModuleDict()
          )
          (k_proj): lora.Linear(
            (base_layer): Linear(in_features=2048, out_features=256, bias=False)
            (lora_dropout): ModuleDict(
              (default): 

In [None]:
# sanity: small subset
small_test = test_records[:5]

# Then full evaluation
run_reasoning_mimicry_evaluation(
    student_scratch,
    student_alm_exp2,
    student_tok,
    small_test
)


===== SCRATCH STUDENT (Reasoning Mimicry) =====


Token Overlap Eval:  20%|██        | 1/5 [00:07<00:30,  7.73s/it]


[Sample 1] Overlap: 0.239
Teacher: Let's start by finding the coordinates of points $A$, $B$, and $C$. Since $ABC$ lies in the first quadrant, we know that $x_A>0$, $x_B>0$, and $x_C>0$. Let's assume that $y_A>x_A$, $y_B>x_B$, and $y_C
Student: The statement is true if and only if the triangle lies in the first quadrant.
B: The statement is true if and only if the slope of line $AA'$ is $-1$.
C: The statement is true if and only if the slope


Token Overlap Eval:  40%|████      | 2/5 [00:15<00:23,  7.71s/it]


[Sample 2] Overlap: 0.086
Teacher: First, we can assume that $f(0)=0$ since $f(x+y^2) \ge f(x)+y$ and $f(0)+0=0$.

Next, we can assume that $f$ is strictly increasing on $[0,\infty)$ since $f(x+y^2) \ge f(x)+y$ and $f(x+y^2) \ge f(x)$ 
Student: Let $c = 1$.
For all nonnegative reals $x, y$, $f(x+y^2) \ge 1f(x)+y$.

Let $x = 1$ and $y = 2$.
$f(1+2^2) = 1f(1) + 2f(1) = 1 + 2f(1) = 3$.
$f(1+2^2) \ge 3f(1) + 2 = 5$.
$f(1+2^2) = 5$.

So $f(1+2^2)


Token Overlap Eval:  60%|██████    | 3/5 [00:23<00:15,  7.68s/it]


[Sample 3] Overlap: 0.000
Teacher: 1, 2, 3, 4, 7
Student: The polynomial has 9 roots, which are:
- 1
- 1/37
- 1/27
- 1/74
- 1/37
- 1/27
- 1/74
- 1/37
- 1/27

B: The polynomial has no real roots.

C: The polynomial has no real roots because it has 9 real root


Token Overlap Eval: 100%|██████████| 5/5 [00:33<00:00,  6.70s/it]



Avg Token Overlap: 0.1380


BLEU Eval: 100%|██████████| 5/5 [00:33<00:00,  6.64s/it]


BLEU Score: 10.19


ROUGE-L Eval:  20%|██        | 1/5 [00:07<00:30,  7.62s/it]


[Sample 1] ROUGE-L: 0.216


ROUGE-L Eval:  40%|████      | 2/5 [00:15<00:22,  7.60s/it]


[Sample 2] ROUGE-L: 0.195


ROUGE-L Eval:  60%|██████    | 3/5 [00:22<00:15,  7.61s/it]


[Sample 3] ROUGE-L: 0.012


ROUGE-L Eval: 100%|██████████| 5/5 [00:33<00:00,  6.65s/it]



Avg ROUGE-L: 0.1757

===== ALM STUDENT (Reasoning Mimicry) =====


Token Overlap Eval:  20%|██        | 1/5 [00:13<00:53, 13.33s/it]


[Sample 1] Overlap: 0.043
Teacher: Let's start by finding the coordinates of points $A$, $B$, and $C$. Since $ABC$ lies in the first quadrant, we know that $x_A>0$, $x_B>0$, and $x_C>0$. Let's assume that $y_A>x_A$, $y_B>x_B$, and $y_C
Student: First, we can see that $A'B'C'$ is a triangle.

Next, we can see that $A'B'C'$ is a triangle.

Next, we can see that $A'B'C'$ is a triangle.

Next, we can see that $A'B'C'$ is a triangle.

Next, we ca


Token Overlap Eval:  40%|████      | 2/5 [00:26<00:39, 13.33s/it]


[Sample 2] Overlap: 0.086
Teacher: First, we can assume that $f(0)=0$ since $f(x+y^2) \ge f(x)+y$ and $f(0)+0=0$.

Next, we can assume that $f$ is strictly increasing on $[0,\infty)$ since $f(x+y^2) \ge f(x)+y$ and $f(x+y^2) \ge f(x)$ 
Student: 1. Let $x, y$ be nonnegative reals.
2. $f(x+y^2) \ge cf(x)+y$
3. $f(x+y^2) \ge cf(x)+y$
4. $f(x+y^2) \ge cf(x)+y$
5. $f(x+y^2) \ge cf(x)+y$
6. $f(x+y^2) \ge cf(x)+y$
7. $f(x+y^2) \ge cf(x)+y$
8. $f(x+


Token Overlap Eval:  60%|██████    | 3/5 [00:39<00:26, 13.32s/it]


[Sample 3] Overlap: 0.000
Teacher: 1, 2, 3, 4, 7
Student: 1. First, we need to find the roots of the polynomial.
2. We can use the quadratic formula to find the roots of the polynomial.
3. We can use the quadratic formula to find the roots of the polynomial.


Token Overlap Eval: 100%|██████████| 5/5 [01:06<00:00, 13.33s/it]



Avg Token Overlap: 0.1393


BLEU Eval: 100%|██████████| 5/5 [01:07<00:00, 13.43s/it]


BLEU Score: 6.81


ROUGE-L Eval:  20%|██        | 1/5 [00:13<00:53, 13.42s/it]


[Sample 1] ROUGE-L: 0.262


ROUGE-L Eval:  40%|████      | 2/5 [00:26<00:40, 13.40s/it]


[Sample 2] ROUGE-L: 0.451


ROUGE-L Eval:  60%|██████    | 3/5 [00:40<00:26, 13.40s/it]


[Sample 3] ROUGE-L: 0.048


ROUGE-L Eval: 100%|██████████| 5/5 [01:06<00:00, 13.36s/it]


Avg ROUGE-L: 0.2506

===== FINAL REASONING MIMICRY RESULTS =====
Model           Overlap    BLEU       ROUGE-L   
--------------------------------------------------
Scratch         0.138      10.19      0.176     
ALM (Ours)      0.139      6.81       0.251     





## What we established so far == ```Giving higher weightage to ALM loss, improves numbers```

## Experiment 3: Lambda = 1.5, Gamma = 1.5 (To see impact of giving higher weightage to conviction)

In [None]:
import logging
from tqdm import tqdm
import torch

# -------------------- CONFIG --------------------
device = "cuda"
LAMBDA_ALM = 1.5
GAMMA = 1.0
MAX_CONTEXT_LEN = 512
OUTPUT_DIR_EXP_3 = "lam15_gam15"
os.makedirs(OUTPUT_DIR_EXP_3, exist_ok=True)
# -----------------------------------------------

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S"
)

student_model.train()

for epoch in range(EPOCHS):
    logging.info(f"Starting epoch {epoch+1}/{EPOCHS}")

    epoch_loss = 0.0
    num_batches = len(train_loader)

    for batch_idx, batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}")):
        batch_size = len(batch["question"])
        optimizer.zero_grad(set_to_none=True)

        batch_loss = 0.0

        # -------------------- PER-SAMPLE LOOP --------------------
        for i in range(batch_size):
            record = {k: batch[k][i] for k in batch}

            # ======================================================
            # (1) CE LOSS (UNCHANGED FROM YOUR ORIGINAL LOOP)
            # ======================================================
            teacher_text = record["generation"]
            teacher_ids = student_tok(
                teacher_text,
                return_tensors="pt"
            ).input_ids.to(device)

            max_len = student_model.config.max_position_embeddings
            ce_chunks = teacher_ids.split(max_len, dim=1)

            ce_loss = 0.0
            for chunk in ce_chunks:
                outputs = student_model(input_ids=chunk, labels=chunk)
                ce_loss += outputs.loss

            ce_loss = ce_loss / len(ce_chunks)

            # ======================================================
            # (2) ALM LOSS (CONFIDENCE-WEIGHTED RATIONALE ALIGNMENT)
            # ======================================================
            assert isinstance(record["chunks"], dict)

            alm_loss = alm_loss_with_context(
                student_model=student_model,
                chunks=[record["chunks"]],
                device=device,
                max_context_len=MAX_CONTEXT_LEN,
                gamma=GAMMA
            )

            # ======================================================
            # (3) TOTAL LOSS (PROPOSAL FORM)
            # ======================================================
            sample_loss = ce_loss + LAMBDA_ALM * alm_loss
            batch_loss += sample_loss

            logging.debug(
                f"Sample {i+1}/{batch_size} | "
                f"CE: {ce_loss.item():.4f}, "
                f"ALM: {alm_loss.item():.4f}, "
                f"Total: {sample_loss.item():.4f}"
            )

        # -------------------- BACKPROP --------------------
        batch_loss = batch_loss / batch_size

        if not torch.isfinite(batch_loss):
            logging.warning("⚠️ NaN/Inf loss detected, skipping batch")
            optimizer.zero_grad(set_to_none=True)
            continue

        batch_loss.backward()
        torch.nn.utils.clip_grad_norm_(student_model.parameters(), 1.0)

        optimizer.step()
        lr_scheduler.step()

        epoch_loss += batch_loss.item()

        logging.info(
            f"Epoch {epoch+1} Batch {batch_idx+1}/{num_batches} | "
            f"Batch Loss: {batch_loss.item():.4f} | "
            f"LR: {lr_scheduler.get_last_lr()[0]:.6f}"
        )

    avg_epoch_loss = epoch_loss / num_batches
    logging.info(f"Epoch {epoch+1} completed | Avg Loss: {avg_epoch_loss:.4f}")
    print(f"Epoch {epoch+1} completed | Avg Loss: {avg_epoch_loss:.4f}")

student_model.save_pretrained(
    f"{OUTPUT_DIR_EXP_3}/checkpoint_epoch_{epoch+1}"
)

Epoch 1: 100%|██████████| 586/586 [05:10<00:00,  1.88it/s]


Epoch 1 completed | Avg Loss: 4.0411


Epoch 2: 100%|██████████| 586/586 [05:08<00:00,  1.90it/s]


Epoch 2 completed | Avg Loss: 1.2050


Epoch 3: 100%|██████████| 586/586 [05:08<00:00,  1.90it/s]


Epoch 3 completed | Avg Loss: 1.0501


Epoch 4: 100%|██████████| 586/586 [05:10<00:00,  1.89it/s]


Epoch 4 completed | Avg Loss: 1.0219


###Run Eval ALM

In [None]:
import random

def split_records(records, test_ratio=0.1, seed=42):
    random.seed(seed)
    records = records.copy()
    random.shuffle(records)

    n_test = int(len(records) * test_ratio)
    return records[n_test:], records[:n_test]


In [None]:
# Use the same ALM dataset you trained on
train_records, test_records = split_records(
    alm_dataset.records,
    test_ratio=0.1
)

print(f"Train: {len(train_records)} | Test: {len(test_records)}")

Train: 528 | Test: 58


In [None]:
from transformers import AutoModelForCausalLM

student_alm_exp_3 = AutoModelForCausalLM.from_pretrained(
    OUTPUT_DIR_EXP_3 + "/checkpoint_epoch_4",   # adjust if needed
    device_map="auto"
)
student_alm_exp_3.eval()


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 2048)
    (layers): ModuleList(
      (0-21): 22 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): lora.Linear(
            (base_layer): Linear(in_features=2048, out_features=2048, bias=False)
            (lora_dropout): ModuleDict(
              (default): Dropout(p=0.1, inplace=False)
            )
            (lora_A): ModuleDict(
              (default): Linear(in_features=2048, out_features=64, bias=False)
            )
            (lora_B): ModuleDict(
              (default): Linear(in_features=64, out_features=2048, bias=False)
            )
            (lora_embedding_A): ParameterDict()
            (lora_embedding_B): ParameterDict()
            (lora_magnitude_vector): ModuleDict()
          )
          (k_proj): lora.Linear(
            (base_layer): Linear(in_features=2048, out_features=256, bias=False)
            (lora_dropout): ModuleDict(
              (default): 

In [None]:
student_scratch = AutoModelForCausalLM.from_pretrained(
    STUDENT_MODEL,
    device_map="auto"
)
student_scratch.eval()


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 2048)
    (layers): ModuleList(
      (0-21): 22 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=5632, bias=False)
          (up_proj): Linear(in_features=2048, out_features=5632, bias=False)
          (down_proj): Linear(in_features=5632, out_features=2048, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((2048,), eps=1e-05)
    (rot

In [None]:
# sanity: small subset
small_test = test_records[:5]

# Then full evaluation
run_reasoning_mimicry_evaluation(
    student_scratch,
    student_alm_exp_3,
    student_tok,
    small_test
)


===== SCRATCH STUDENT (Reasoning Mimicry) =====


Token Overlap Eval:  20%|██        | 1/5 [00:07<00:30,  7.62s/it]


[Sample 1] Overlap: 0.239
Teacher: Let's start by finding the coordinates of points $A$, $B$, and $C$. Since $ABC$ lies in the first quadrant, we know that $x_A>0$, $x_B>0$, and $x_C>0$. Let's assume that $y_A>x_A$, $y_B>x_B$, and $y_C
Student: The statement is true if and only if the triangle lies in the first quadrant.
B: The statement is true if and only if the slope of line $AA'$ is $-1$.
C: The statement is true if and only if the slope


Token Overlap Eval:  40%|████      | 2/5 [00:15<00:22,  7.61s/it]


[Sample 2] Overlap: 0.086
Teacher: First, we can assume that $f(0)=0$ since $f(x+y^2) \ge f(x)+y$ and $f(0)+0=0$.

Next, we can assume that $f$ is strictly increasing on $[0,\infty)$ since $f(x+y^2) \ge f(x)+y$ and $f(x+y^2) \ge f(x)$ 
Student: Let $c = 1$.
For all nonnegative reals $x, y$, $f(x+y^2) \ge 1f(x)+y$.

Let $x = 1$ and $y = 2$.
$f(1+2^2) = 1f(1) + 2f(1) = 1 + 2f(1) = 3$.
$f(1+2^2) \ge 3f(1) + 2 = 5$.
$f(1+2^2) = 5$.

So $f(1+2^2)


Token Overlap Eval:  60%|██████    | 3/5 [00:22<00:15,  7.57s/it]


[Sample 3] Overlap: 0.000
Teacher: 1, 2, 3, 4, 7
Student: The polynomial has 9 roots, which are:
- 1
- 1/37
- 1/27
- 1/74
- 1/37
- 1/27
- 1/74
- 1/37
- 1/27

B: The polynomial has no real roots.

C: The polynomial has no real roots because it has 9 real root


Token Overlap Eval: 100%|██████████| 5/5 [00:33<00:00,  6.63s/it]



Avg Token Overlap: 0.1380


BLEU Eval: 100%|██████████| 5/5 [00:32<00:00,  6.60s/it]


BLEU Score: 10.19


ROUGE-L Eval:  20%|██        | 1/5 [00:07<00:30,  7.57s/it]


[Sample 1] ROUGE-L: 0.216


ROUGE-L Eval:  40%|████      | 2/5 [00:15<00:22,  7.54s/it]


[Sample 2] ROUGE-L: 0.195


ROUGE-L Eval:  60%|██████    | 3/5 [00:22<00:15,  7.53s/it]


[Sample 3] ROUGE-L: 0.012


ROUGE-L Eval: 100%|██████████| 5/5 [00:32<00:00,  6.59s/it]



Avg ROUGE-L: 0.1757

===== ALM STUDENT (Reasoning Mimicry) =====


Token Overlap Eval:  20%|██        | 1/5 [00:13<00:53, 13.32s/it]


[Sample 1] Overlap: 0.130
Teacher: Let's start by finding the coordinates of points $A$, $B$, and $C$. Since $ABC$ lies in the first quadrant, we know that $x_A>0$, $x_B>0$, and $x_C>0$. Let's assume that $y_A>x_A$, $y_B>x_B$, and $y_C
Student: (A) is not true because the vertices of the triangle lie on the line $y=x$.
(B) is not true because the area of the triangle is not equal to the area of the triangle with the same vertices.
(C) is not


Token Overlap Eval:  40%|████      | 2/5 [00:26<00:39, 13.31s/it]


[Sample 2] Overlap: 0.171
Teacher: First, we can assume that $f(0)=0$ since $f(x+y^2) \ge f(x)+y$ and $f(0)+0=0$.

Next, we can assume that $f$ is strictly increasing on $[0,\infty)$ since $f(x+y^2) \ge f(x)+y$ and $f(x+y^2) \ge f(x)$ 
Student: 1. Let's start by assuming that $f(x) = 0$ for all nonnegative reals $x$.

2. Then, we can rewrite the inequality as follows.

For all nonnegative reals $x, y$, $f(x+y^2) \ge 0$.

3. Now, let's consid


Token Overlap Eval:  60%|██████    | 3/5 [00:39<00:26, 13.25s/it]


[Sample 3] Overlap: 0.200
Teacher: 1, 2, 3, 4, 7
Student: *

*First, we need to find the roots of the polynomial.

*The polynomial is a quadratic equation in x.

*So, we can use quadratic formula to find the roots of the polynomial.

*The roots of the polyno


Token Overlap Eval: 100%|██████████| 5/5 [01:06<00:00, 13.28s/it]



Avg Token Overlap: 0.2305


BLEU Eval: 100%|██████████| 5/5 [01:06<00:00, 13.31s/it]


BLEU Score: 6.39


ROUGE-L Eval:  20%|██        | 1/5 [00:13<00:53, 13.30s/it]


[Sample 1] ROUGE-L: 0.164


ROUGE-L Eval:  40%|████      | 2/5 [00:26<00:39, 13.19s/it]


[Sample 2] ROUGE-L: 0.318


ROUGE-L Eval:  60%|██████    | 3/5 [00:39<00:26, 13.27s/it]


[Sample 3] ROUGE-L: 0.109


ROUGE-L Eval: 100%|██████████| 5/5 [01:06<00:00, 13.29s/it]


Avg ROUGE-L: 0.2222

===== FINAL REASONING MIMICRY RESULTS =====
Model           Overlap    BLEU       ROUGE-L   
--------------------------------------------------
Scratch         0.138      10.19      0.176     
ALM (Ours)      0.231      6.39       0.222     



