In [None]:

import os
#os.environ["CUDA_VISIBLE_DEVICES"] = "1"

from datasets import load_dataset, concatenate_datasets, DatasetDict, ClassLabel

cache_dir = "/hf_cache"



TASKS = ["sst2", "qnli", "qqp", "cola", "mrpc", "stsb"]
MAX_PER_TASK = 5000

def build_id2label(feature):
    """Return a dict mapping id -> label name, or None if regression."""
    if isinstance(feature, ClassLabel):
        return {i: name for i, name in enumerate(feature.names)}
    return None  # e.g. STS-B, regression (float scores)


def make_preprocess_fn(task, id2label):
    """Create a preprocessing function that outputs a single 'text' field + string label."""

    def preprocess(example):
        # ---- build the natural-language prompt + input text ----
        if task == "sst2":
            prompt = "Task: Sentiment classification.\n" \
                     "Decide whether the following sentence is positive or negative.\n"
            body = f"Sentence: {example['sentence']}"
        elif task == "cola":
            prompt = "Task: Grammatical acceptability.\n" \
                     "Decide whether the following sentence is grammatically acceptable or unacceptable.\n"
            body = f"Sentence: {example['sentence']}"
        elif task == "qnli":
            prompt = "Task: Question-answer entailment.\n" \
                     "Decide whether the sentence correctly answers the question (entailment or not_entailment).\n"
            body = f"Question: {example['question']}\nSentence: {example['sentence']}"
        elif task == "qqp":
            prompt = "Task: Question paraphrase detection.\n" \
                     "Decide whether the two questions are paraphrases of each other (duplicate or not_duplicate).\n"
            body = f"Sentence 1: {example['question1']}\nSentence 2: {example['question2']}"
        elif task == "mrpc":
            prompt = "Task: Sentence paraphrase detection.\n" \
                     "Decide whether the two sentences are paraphrases of each other (equivalent or not_equivalent).\n"
            body = f"Sentence 1: {example['sentence1']}\nSentence 2: {example['sentence2']}"
        elif task == "stsb":
            prompt = "Task: Semantic textual similarity.\n" \
                     "Rate the similarity of the two sentences on a scale from 0 (no meaning overlap) " \
                     "to 5 (equivalent in meaning).\n"
            body = f"Sentence 1: {example['sentence1']}\nSentence 2: {example['sentence2']}"
        else:
            raise ValueError(f"Unknown task: {task}")

        text = prompt + "\n" + body

        # ---- convert label to string label_text ----
        raw_label = example["label"]
        if id2label is not None:
            # classification: use GLUE label names (e.g. 'entailment', 'duplicate', etc.)
            label_text = id2label[int(raw_label)]
        else:
            # STS-B regression: keep numeric label but as a string
            # (you can bucket this into custom names if you prefer)
            label_text = str(raw_label)

        return {
            "text": text,
            "label_text": label_text,
            "task": task,
        }

    return preprocess


train_parts = []
val_splits = {}

for task in TASKS:
    print(f"Loading {task}...")
    raw = load_dataset("nyu-mll/glue", task, cache_dir=cache_dir)

    # label mapping (if classification)
    id2label = build_id2label(raw["train"].features["label"])
    preprocess = make_preprocess_fn(task, id2label)

    # ---- TRAIN: sample up to 5k examples ----
    train_ds = raw["train"].shuffle(seed=42)
    n = min(MAX_PER_TASK, len(train_ds))  # if <5k, just use all
    train_ds = train_ds.select(range(n))
    train_ds = train_ds.map(
        preprocess,
        remove_columns=train_ds.column_names,
    )
    train_parts.append(train_ds)

    # ---- VALIDATION: full validation per task ----
    val_ds = raw["validation"].map(
        preprocess,
        remove_columns=raw["validation"].column_names,
    )
    val_splits[f"validation_{task}"] = val_ds

# ---- FINAL MERGED DATASETDICT ----
combined_train = concatenate_datasets(train_parts)

multi_glue = DatasetDict({
    "train": combined_train,       # one big train set
    **val_splits,                 # validation_sst2, validation_qnli, ...
})

print(multi_glue)
print(multi_glue["train"][0]["text"])
print(multi_glue["train"][0]["label_text"])


In [3]:
multi_glue["train"] = multi_glue["train"].shuffle(seed=42)


In [4]:
import os
import torch

from torch.utils.data import DataLoader
from transformers import (
    AutoProcessor,
    LlavaOnevisionForConditionalGeneration,
    TrainingArguments,
    Trainer,
    PreTrainedTokenizerBase
)
from peft import LoraConfig, get_peft_model
from datasets import load_from_disk
from typing import List, Dict, Any




In [None]:
# ðŸš€ Step 2: Load Model and Processor
from transformers import AutoTokenizer, DataCollatorWithPadding, AutoModelForCausalLM, DataCollatorForLanguageModeling

cache_dir = "/hf_cache"
model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    legacy=True  # Suppresses the warning/error with tokenizer.model
)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto", cache_dir=cache_dir)


Loading checkpoint shards: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 4/4 [00:06<00:00,  1.53s/it]


In [6]:
model 

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((4096,), eps=1e-05)
    (rotary_

In [14]:
from MJLoRA import apply_monkeyjump



blocks_spec = {
    #"SiglipEncoderLayer": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25],
    "LlamaDecoderLayer":  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
}
linears = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj"]  # add 'up_proj','down_proj', etc. "out_proj", "fc1", "fc2", "gate_proj", "up_proj", "down_proj"

model = apply_monkeyjump(
    model,
    blocks=blocks_spec,
    shared_expert=["gate_proj",  "o_proj"],
    linears=linears,
    rank=2, alpha=3.0,
    temperature=1.0,   # router T
    ema_momentum=0.5,
    top_k=1,
    rep_mode="prompt_end",
    jitter_noise=0.01,

)


In [15]:
model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): MonkeyJumpLinear[expert_0](in=4096, out=4096, rank=2, dtype=float32)
          (k_proj): MonkeyJumpLinear[expert_1](in=4096, out=1024, rank=2, dtype=float32)
          (v_proj): MonkeyJumpLinear[expert_2](in=4096, out=1024, rank=2, dtype=float32)
          (o_proj): MonkeyJumpLinear[shared](in=4096, out=4096, rank=2, dtype=float32)
        )
        (mlp): LlamaMLP(
          (gate_proj): MonkeyJumpLinear[shared](in=4096, out=14336, rank=2, dtype=float32)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
  

In [16]:
# Count of trainable parameters
total_trainable_params = 0
total =  0
# Print trainable parameters and count their total number
for name, param in model.named_parameters():
    if param.requires_grad:
        #print(f"Parameter name: {name}, Shape: {param.shape}")
        
        total_trainable_params += param.numel()
    total+=param.numel()

print(f"Total trainable parameters:{total_trainable_params}")

Total trainable parameters:2883584


In [17]:
# 1) Count params that require grad
trainable = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
total = sum(p.numel() for _, p in model.named_parameters())
trainable_num = sum(p.numel() for _, p in trainable)
print(f"trainable params: {trainable_num:,} / {total:,}")



trainable params: 2,883,584 / 8,033,144,832


In [32]:
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import torch


@dataclass
class GlueLlavaDataCollator:
    """
    Text-only multitask collator for Llama 3 style processors on GLUE.
    """
    tokenizer: Any
    is_train: bool = True
    pad_to_multiple_of: Optional[int] = 8
    answer_prefix: str = "The correct output is"
    debug: bool = False
    force_left_padding: bool = True
    insert_token_id: int = 128001  # Token to insert before eot_id

    def __post_init__(self):
        tok = self.tokenizer

        if self.force_left_padding:
            tok.padding_side = "left"

        if tok.pad_token_id is None:
            tok.pad_token = tok.eos_token

        self.pad_id = tok.pad_token_id
        self.eos_id = tok.eos_token_id

        # Build preamble by getting actual special token IDs for Llama 3
        try:
            start_header_id = tok.convert_tokens_to_ids("<|start_header_id|>")
            end_header_id = tok.convert_tokens_to_ids("<|end_header_id|>")
            assistant_ids = tok.encode("assistant", add_special_tokens=False)
            newline_ids = tok.encode("\n\n", add_special_tokens=False)
            
            base_preamble = [start_header_id] + assistant_ids + [end_header_id] + newline_ids
            
            self._preamble_variants = [
                base_preamble,
                [start_header_id] + assistant_ids + [end_header_id],
            ]
            
            self.eot_id = tok.convert_tokens_to_ids("<|eot_id|>")
            
        except Exception as e:
            if self.debug:
                print(f"[warn] Could not build Llama 3 preamble: {e}")
            self._preamble_variants = []
            self.eot_id = self.eos_id

        if self.debug:
            print(f"Preamble variants: {self._preamble_variants}")
            print(f"EOT ID: {self.eot_id}, EOS ID: {self.eos_id}")
            print(f"Insert token ID: {self.insert_token_id}")

    @staticmethod
    def _rfind_subseq(hay, needle) -> int:
        if not needle or len(needle) > len(hay):
            return -1
        for s in range(len(hay) - len(needle), -1, -1):
            if hay[s:s + len(needle)] == needle:
                return s
        return -1

    def _find_assistant_start(self, ids) -> int:
        for needle in self._preamble_variants:
            pos = self._rfind_subseq(ids, needle)
            if pos != -1:
                return pos + len(needle)
        return -1

    def _first_eos_after(self, ids, start) -> int:
        if start < 0:
            return -1
        for i in range(start, len(ids)):
            if ids[i] == self.eot_id or ids[i] == self.eos_id:
                return i
        return len(ids)

    def _insert_token_before_eot(
        self, 
        input_ids: torch.Tensor, 
        attention_mask: Optional[torch.Tensor]
    ) -> tuple:
        """
        Insert self.insert_token_id before <|eot_id|> in each sequence.
        Result: content<|end_of_text|><|eot_id|>
        """
        batch_size, seq_len = input_ids.shape
        
        if attention_mask is None:
            attention_mask = (input_ids != self.pad_id).long()
        
        new_input_ids_list = []
        new_attention_mask_list = []
        
        for i in range(batch_size):
            ids = input_ids[i].tolist()
            attn = attention_mask[i].tolist()
            
            # Find the last eot_id position in the valid (non-pad) region
            eot_pos = -1
            for j in range(len(ids) - 1, -1, -1):
                if attn[j] == 1 and ids[j] == self.eot_id:
                    eot_pos = j
                    break
            
            if eot_pos != -1:
                # Check if insert_token_id is already right before eot_id
                if eot_pos > 0 and ids[eot_pos - 1] == self.insert_token_id:
                    # Already in correct position
                    new_input_ids_list.append(ids)
                    new_attention_mask_list.append(attn)
                else:
                    # Insert token before eot_id
                    new_ids = ids[:eot_pos] + [self.insert_token_id] + ids[eot_pos:]
                    new_attn = attn[:eot_pos] + [1] + attn[eot_pos:]
                    new_input_ids_list.append(new_ids)
                    new_attention_mask_list.append(new_attn)
            else:
                # No eot_id found, append at end of valid tokens
                last_valid_idx = -1
                for j in range(len(ids) - 1, -1, -1):
                    if attn[j] == 1:
                        last_valid_idx = j
                        break
                
                if last_valid_idx >= 0 and ids[last_valid_idx] != self.insert_token_id:
                    ids.append(self.insert_token_id)
                    attn.append(1)
                
                new_input_ids_list.append(ids)
                new_attention_mask_list.append(attn)
        
        # Find max length and pad
        max_len = max(len(ids) for ids in new_input_ids_list)
        
        if self.pad_to_multiple_of:
            max_len = ((max_len + self.pad_to_multiple_of - 1) // self.pad_to_multiple_of) * self.pad_to_multiple_of
        
        # Left-pad all sequences to max_len
        for i in range(batch_size):
            pad_len = max_len - len(new_input_ids_list[i])
            if pad_len > 0:
                new_input_ids_list[i] = [self.pad_id] * pad_len + new_input_ids_list[i]
                new_attention_mask_list[i] = [0] * pad_len + new_attention_mask_list[i]
        
        new_input_ids = torch.tensor(new_input_ids_list, dtype=input_ids.dtype)
        new_attention_mask = torch.tensor(new_attention_mask_list, dtype=attention_mask.dtype)
        
        return new_input_ids, new_attention_mask

    def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        texts: List[str] = []
        label_texts: List[str] = []

        for ex in examples:
            user_text = ex["text"]
            gold = ex.get("label_text", "")

            if self.is_train:
                assistant_text = f"{self.answer_prefix} {gold}"

                conversation = [
                    {"role": "user", "content": user_text},
                    {"role": "assistant", "content": assistant_text},
                ]

                text = self.tokenizer.apply_chat_template(
                    conversation,
                    add_generation_prompt=False,
                    tokenize=False,
                )

            else:
                conversation = [
                    {"role": "user", "content": user_text},
                ]

                base = self.tokenizer.apply_chat_template(
                    conversation,
                    add_generation_prompt=True,
                    tokenize=False,
                )

                text = base + self.answer_prefix

            texts.append(text)
            label_texts.append(gold)

        batch = self.tokenizer(
            text=texts,
            padding=True,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        input_ids = batch["input_ids"]
        attn = batch.get("attention_mask", None)

        # Insert token 128001 before <|eot_id|> in each sequence
        input_ids, attn = self._insert_token_before_eot(input_ids, attn)
        batch["input_ids"] = input_ids
        batch["attention_mask"] = attn

        labels = torch.full_like(input_ids, -100)

        if self.is_train:
            for i in range(input_ids.size(0)):
                valid_pos = attn[i].nonzero(as_tuple=False).squeeze(-1)
                compact_ids = input_ids[i, valid_pos].tolist()

                start_c = self._find_assistant_start(compact_ids)
                end_c = self._first_eos_after(compact_ids, start_c) if start_c != -1 else -1
                end_c = end_c   # include EOS/eot_id in label span

                if start_c != -1 and end_c > start_c:
                    span_pos = valid_pos[start_c:end_c]
                    labels[i, span_pos] = input_ids[i, span_pos]
                elif self.debug:
                    print(f"[warn] assistant span not found for sample {i}")
                    print(f"  compact_ids: {compact_ids[:50]}...")
                    print(f"  looking for: {self._preamble_variants}")
        else:
            batch["label_texts"] = label_texts

        batch["labels"] = labels
        return batch


In [33]:
train_collator = GlueLlavaDataCollator(tokenizer=tokenizer, is_train=True,  debug=True)


Preamble variants: [[128006, 78191, 128007, 271], [128006, 78191, 128007]]
EOT ID: 128009, EOS ID: 128009
Insert token ID: 128001


In [34]:
d = train_collator(multi_glue['train'].select(range(3)))

In [35]:
print(tokenizer.decode(d['input_ids'][2]))

<|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|begin_of_text|><|begin_of_text|><|start_header_id|>user<|end_header_id|>

Task: Grammatical acceptability.
Decide whether the following sentence is grammatically acceptable or unacceptable.

Sentence: Frances hid Sally of the presents.<|eot_id|><|start_header_id|>assistant<|end_header_id|>

The correct output is unacceptable<|end_of_text|><|eot_id|>


In [36]:
k = []

for i in d['labels'][2]:
    if i == -100:
        k.append(0)
    else:
        k.append(i)
    
print(tokenizer.decode(k))

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!The correct output is unacceptable<|end_of_text|>!


In [38]:
from MJtrainer import MonkeyTrainer
from transformers import TrainingArguments
training_args = TrainingArguments(
    output_dir="./llava-lora-finetuned_our",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,  
    save_total_limit=3,
    save_steps=500000,
    num_train_epochs=1,
    remove_unused_columns=False, 
   
    bf16=True,  
    logging_dir="./logs",
    logging_steps=100,
    eval_strategy="no",  
    #eval_steps=10,
    save_strategy="no",
    optim="paged_adamw_8bit",
  
    learning_rate=1e-3,
    warmup_ratio=0.03,
    weight_decay=0.00,
    report_to="none"
)


import os
import torch
from tqdm import tqdm
from transformers import Trainer, TrainingArguments, TrainerCallback



# moe_trainer.py
import torch
from transformers import Trainer


# Example instantiation:
trainer = MonkeyTrainer(
    model=model,
    args=training_args,                  # your HF TrainingArguments
    train_dataset=multi_glue['train'],

    data_collator=train_collator,         # your collator
    momentum=0.5,


    step_interval=2,                    # update every 10 train steps
    stop_update_step=10000,               # stop updates at/after step 1000
)



#original 16:58


Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


[MonkeyTrainer] interval=2, stop_at=10000, update_on=micro, momentum=0.5


In [39]:
from kmneas import init_router_centers
init_router_centers(
    trainer,
    subset_size=5000,        # Use 2000 samples
    loader_batch_size=8,
    collect_batches=20000,
    per_block_cap=50000,      
    max_tokens_per_batch=40096,
    kmeans_iters=50,
    seed=42,
    verbose=True,
    # Optional: auto-select number of experts
    auto_select_experts=False,
    rep_mode="prompt_lengths",
)



[kmeans-init] Collecting representations (SEQUENCE-BASED (prompt_lengths))...
[kmeans-init] Found 32 patched blocks, mode: SEQUENCE-BASED (prompt_lengths)


Collecting representations:   3%|â–Ž         | 625/20000 [01:00<31:01, 10.41it/s, samples=160000]


[kmeans-init] Block collected 5000 sentences, dim=4096
[kmeans-init] Block collected 5000 sentences, dim=4096
[kmeans-init] Block collected 5000 sentences, dim=4096
[kmeans-init] Block collected 5000 sentences, dim=4096
[kmeans-init] Block collected 5000 sentences, dim=4096
[kmeans-init] Block collected 5000 sentences, dim=4096
[kmeans-init] Block collected 5000 sentences, dim=4096
[kmeans-init] Block collected 5000 sentences, dim=4096
[kmeans-init] Block collected 5000 sentences, dim=4096
[kmeans-init] Block collected 5000 sentences, dim=4096
[kmeans-init] Block collected 5000 sentences, dim=4096
[kmeans-init] Block collected 5000 sentences, dim=4096
[kmeans-init] Block collected 5000 sentences, dim=4096
[kmeans-init] Block collected 5000 sentences, dim=4096
[kmeans-init] Block collected 5000 sentences, dim=4096
[kmeans-init] Block collected 5000 sentences, dim=4096
[kmeans-init] Block collected 5000 sentences, dim=4096
[kmeans-init] Block collected 5000 sentences, dim=4096
[kmeans-in

In [40]:
trainer.train()

Step,Training Loss
100,2.158
200,0.3235
300,0.2957
400,0.2471
500,0.2552
600,0.2216
700,0.2204
800,0.2288
900,0.2093
1000,0.2204


TrainOutput(global_step=1792, training_loss=0.33006735678230015, metrics={'train_runtime': 473.8942, 'train_samples_per_second': 60.495, 'train_steps_per_second': 3.781, 'total_flos': 1.6021254393534874e+17, 'train_loss': 0.33006735678230015, 'epoch': 1.0})

In [42]:
eval_collator  = GlueLlavaDataCollator(tokenizer=tokenizer, is_train=False)

In [43]:
import re
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

import re
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from scipy.stats import pearsonr  # Add this import

def _extract_first_number(text: str):
    """
    Extract the first float-looking number from a string.
    Returns float or None.
    """
    m = re.search(r"[-+]?\d*\.?\d+(?:[eE][-+]?\d+)?", text)
    if m:
        try:
            return float(m.group(0))
        except ValueError:
            return None
    return None


def evaluate_glue(
    model,
    tokenizer,
    dataset,
    eval_collator,
    task: str,
    batch_size: int = 8,
    max_new_tokens: int = 10,
):
    """
    task: one of ["sst2", "qnli", "qqp", "cola", "mrpc", "stsb"].
    For STS-B we compute Pearson correlation.
    """
    model.eval()
    preds = []
    golds = []

    loader = DataLoader(dataset, batch_size=batch_size, collate_fn=eval_collator)

    for batch_examples in tqdm(loader, desc=f"Evaluating {task}"):
        # Move tensor parts of batch to device
        batch = {
            k: v.to(model.device)
            for k, v in batch_examples.items()
            if isinstance(v, torch.Tensor)
        }

        # Generate
        with torch.no_grad():
            gen = model.generate(
                input_ids=batch["input_ids"],
                attention_mask=batch.get("attention_mask", None),
                max_new_tokens=max_new_tokens,
                pad_token_id=tokenizer.eos_token_id,
            )

        # Decode full sequences (prompt + continuation)
        generated = tokenizer.batch_decode(gen, skip_special_tokens=True)

        prefix = eval_collator.answer_prefix  # "The correct output is"

        for full_output, gold in zip(generated, batch_examples["label_texts"]):
            # Take part after last "The correct output is"
            idx = full_output.rfind(prefix)
            if idx != -1:
                pred = full_output[idx + len(prefix):].strip()
            else:
                pred = full_output.strip()

            preds.append(pred)
            golds.append(gold)

    # ---- compute metric ----
    if len(golds) == 0:
        raise ValueError("No gold labels collected; check that label_texts are in the batch.")

    if task == "stsb":
        # STS-B: Pearson correlation
        pred_vals = []
        gold_vals = []

        for pred_str, gold_str in zip(preds, golds):
            gold_val = _extract_first_number(str(gold_str))
            pred_val = _extract_first_number(str(pred_str))

            if gold_val is None:
                continue

            # If prediction didn't produce a number, default to midpoint (2.5)
            if pred_val is None:
                pred_val = 2.5

            pred_vals.append(pred_val)
            gold_vals.append(gold_val)

        if len(pred_vals) < 2:
            pearson = 0.0
        else:
            pearson, _ = pearsonr(pred_vals, gold_vals)

        print(f"[STS-B] Pearson: {pearson:.4f} (n={len(pred_vals)})")
        acc = pearson  # Return Pearson as the metric
    else:
        # Classification tasks: compare normalized label strings
        norm_preds = []
        norm_golds = []
        for p, g in zip(preds, golds):
            # crude normalization: first token, lowercase, strip punctuation
            p_norm = p.split()[0] if p else ""
            p_norm = p_norm.strip(" .,:;!?").lower()

            g_norm = str(g).split()[0] if g else ""
            g_norm = g_norm.strip(" .,:;!?").lower()

            norm_preds.append(p_norm)
            norm_golds.append(g_norm)

        correct = sum(p == g for p, g in zip(norm_preds, norm_golds))
        acc = correct / len(norm_golds)
        print(f"[{task}] accuracy: {correct}/{len(norm_golds)} = {acc:.4f}")

    print("preds[:10]", preds[:10])
    print("golds[:10]", golds[:10])
    return acc, preds, golds

results = {}



for task in [ "stsb", "sst2", "qnli", "qqp", "cola", "mrpc",]:
    split_name = f"validation_{task}"
    print(f"\nEvaluating {split_name}...")

    acc, preds, golds = evaluate_glue(
        model=model,
        tokenizer=tokenizer,
        dataset=multi_glue[split_name],
        eval_collator=eval_collator,
        task=task,
        batch_size=32,
        max_new_tokens=10,
    )

    results[task] = acc

print("\n=== GLUE Eval Results ===")
for task, acc in results.items():
    print(f"{task}: {acc:.4f}")




Evaluating validation_stsb...


Evaluating stsb: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 47/47 [00:23<00:00,  2.02it/s]


[STS-B] Pearson: 0.9091 (n=1500)
preds[:10] ['5.0', '4.199999809265137', '4.400000095367432', '2.200000047683716', '2.200000047683716', '2.5999999046325684', '5.0', '3.200000047683716', '4.0', '4.800000190734863']
golds[:10] ['5.0', '4.75', '5.0', '2.4000000953674316', '2.75', '2.615000009536743', '5.0', '2.3329999446868896', '3.75', '5.0']

Evaluating validation_sst2...


Evaluating sst2: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 28/28 [00:04<00:00,  6.14it/s]


[sst2] accuracy: 832/872 = 0.9541
preds[:10] ['positive', 'negative', 'positive', 'positive', 'negative', 'positive', 'negative', 'negative', 'positive', 'negative']
golds[:10] ['positive', 'negative', 'positive', 'positive', 'negative', 'positive', 'negative', 'negative', 'positive', 'negative']

Evaluating validation_qnli...


Evaluating qnli: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 171/171 [01:01<00:00,  2.80it/s]


[qnli] accuracy: 5058/5463 = 0.9259
preds[:10] ['entailment', 'not_entailment', 'entailment', 'entailment', 'not_entailment', 'not_entailment', 'not_entailment', 'not_entailment', 'not_entailment', 'entailment']
golds[:10] ['entailment', 'not_entailment', 'not_entailment', 'entailment', 'not_entailment', 'not_entailment', 'not_entailment', 'not_entailment', 'not_entailment', 'entailment']

Evaluating validation_qqp...


Evaluating qqp: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 1264/1264 [05:07<00:00,  4.11it/s]


[qqp] accuracy: 34337/40430 = 0.8493
preds[:10] ['not_duplicate', 'not_duplicate', 'duplicate', 'not_duplicate', 'duplicate', 'duplicate', 'duplicate', 'duplicate', 'duplicate', 'not_duplicate']
golds[:10] ['not_duplicate', 'not_duplicate', 'duplicate', 'not_duplicate', 'not_duplicate', 'duplicate', 'duplicate', 'duplicate', 'not_duplicate', 'not_duplicate']

Evaluating validation_cola...


Evaluating cola: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 33/33 [00:04<00:00,  7.04it/s]


[cola] accuracy: 870/1043 = 0.8341
preds[:10] ['acceptable', 'acceptable', 'acceptable', 'acceptable', 'acceptable', 'acceptable', 'unacceptable', 'unacceptable', 'acceptable', 'acceptable']
golds[:10] ['acceptable', 'acceptable', 'acceptable', 'acceptable', 'unacceptable', 'unacceptable', 'unacceptable', 'acceptable', 'acceptable', 'acceptable']

Evaluating validation_mrpc...


Evaluating mrpc: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 13/13 [00:03<00:00,  3.37it/s]

[mrpc] accuracy: 360/408 = 0.8824
preds[:10] ['equivalent', 'not_equivalent', 'not_equivalent', 'equivalent', 'not_equivalent', 'equivalent', 'not_equivalent', 'equivalent', 'equivalent', 'equivalent']
golds[:10] ['equivalent', 'not_equivalent', 'not_equivalent', 'equivalent', 'not_equivalent', 'equivalent', 'not_equivalent', 'equivalent', 'equivalent', 'equivalent']

=== GLUE Eval Results ===
stsb: 0.9091
sst2: 0.9541
qnli: 0.9259
qqp: 0.8493
cola: 0.8341
mrpc: 0.8824



