In [None]:
import os

# Hugging Face and temp
import os, pathlib


from datasets import load_dataset
from huggingface_hub import login

# ✅ Step 1: Log in to Hugging Face


dataset = load_dataset("multitask_textqa_benchmark")


In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from huggingface_hub import login
import torch



# ✅ Step 2: Define model name
model_name = "meta-llama/Meta-Llama-3-8B-Instruct"

# ✅ Step 3: Load configuration
config = AutoConfig.from_pretrained(model_name)
config.hidden_dropout_prob = 0.0
config.attention_probs_dropout_prob = 0.0
# config.num_labels = 2  # Uncomment if doing classification

# ✅ Step 4: Load tokenizer with legacy=True to avoid conversion error
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    legacy=True  # Suppresses the warning/error with tokenizer.model
)

from transformers import AutoModelForSequenceClassification
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16
    

)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [3]:
from MJLoRAFA import apply_monkeyjump


blocks_spec = {
    #"SiglipEncoderLayer": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25],
    "LlamaDecoderLayer":  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
}
linears = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj"]  # add 'up_proj','down_proj', etc. "out_proj", "fc1", "fc2", "gate_proj", "up_proj", "down_proj"

model = apply_monkeyjump(
    model,
    blocks=blocks_spec,
    shared_expert=["o_proj", "gate_proj"],
    linears=linears,
    rank=2, alpha=5.0,
    temperature=1.0,   # router T
    ema_momentum=0.5,
    top_k=1,
    rep_mode="token",
    jitter_noise=0.1,

)


In [4]:
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import torch


@dataclass
class GlueLlavaDataCollator:
    """
    Text-only multitask collator for Llama 3 style processors on GLUE.
    """
    tokenizer: Any
    is_train: bool = True
    pad_to_multiple_of: Optional[int] = 8
    answer_prefix: str = "The correct output is"
    debug: bool = False
    force_left_padding: bool = True
    insert_token_id: int = 128001  # Token to insert before eot_id

    def __post_init__(self):
        tok = self.tokenizer

        if self.force_left_padding:
            tok.padding_side = "left"

        if tok.pad_token_id is None:
            tok.pad_token = tok.eos_token

        self.pad_id = tok.pad_token_id
        self.eos_id = tok.eos_token_id

        # Build preamble by getting actual special token IDs for Llama 3
        try:
            start_header_id = tok.convert_tokens_to_ids("<|start_header_id|>")
            end_header_id = tok.convert_tokens_to_ids("<|end_header_id|>")
            assistant_ids = tok.encode("assistant", add_special_tokens=False)
            newline_ids = tok.encode("\n\n", add_special_tokens=False)
            
            base_preamble = [start_header_id] + assistant_ids + [end_header_id] + newline_ids
            
            self._preamble_variants = [
                base_preamble,
                [start_header_id] + assistant_ids + [end_header_id],
            ]
            
            self.eot_id = tok.convert_tokens_to_ids("<|eot_id|>")
            
        except Exception as e:
            if self.debug:
                print(f"[warn] Could not build Llama 3 preamble: {e}")
            self._preamble_variants = []
            self.eot_id = self.eos_id

        if self.debug:
            print(f"Preamble variants: {self._preamble_variants}")
            print(f"EOT ID: {self.eot_id}, EOS ID: {self.eos_id}")
            print(f"Insert token ID: {self.insert_token_id}")

    @staticmethod
    def _rfind_subseq(hay, needle) -> int:
        if not needle or len(needle) > len(hay):
            return -1
        for s in range(len(hay) - len(needle), -1, -1):
            if hay[s:s + len(needle)] == needle:
                return s
        return -1

    def _find_assistant_start(self, ids) -> int:
        for needle in self._preamble_variants:
            pos = self._rfind_subseq(ids, needle)
            if pos != -1:
                return pos + len(needle)
        return -1

    def _first_eos_after(self, ids, start) -> int:
        if start < 0:
            return -1
        for i in range(start, len(ids)):
            if ids[i] == self.eot_id or ids[i] == self.eos_id:
                return i
        return len(ids)

    def _insert_token_before_eot(
        self, 
        input_ids: torch.Tensor, 
        attention_mask: Optional[torch.Tensor]
    ) -> tuple:
        """
        Insert self.insert_token_id before <|eot_id|> in each sequence.
        Result: content<|end_of_text|><|eot_id|>
        """
        batch_size, seq_len = input_ids.shape
        
        if attention_mask is None:
            attention_mask = (input_ids != self.pad_id).long()
        
        new_input_ids_list = []
        new_attention_mask_list = []
        
        for i in range(batch_size):
            ids = input_ids[i].tolist()
            attn = attention_mask[i].tolist()
            
            # Find the last eot_id position in the valid (non-pad) region
            eot_pos = -1
            for j in range(len(ids) - 1, -1, -1):
                if attn[j] == 1 and ids[j] == self.eot_id:
                    eot_pos = j
                    break
            
            if eot_pos != -1:
                # Check if insert_token_id is already right before eot_id
                if eot_pos > 0 and ids[eot_pos - 1] == self.insert_token_id:
                    # Already in correct position
                    new_input_ids_list.append(ids)
                    new_attention_mask_list.append(attn)
                else:
                    # Insert token before eot_id
                    new_ids = ids[:eot_pos] + [self.insert_token_id] + ids[eot_pos:]
                    new_attn = attn[:eot_pos] + [1] + attn[eot_pos:]
                    new_input_ids_list.append(new_ids)
                    new_attention_mask_list.append(new_attn)
            else:
                # No eot_id found, append at end of valid tokens
                last_valid_idx = -1
                for j in range(len(ids) - 1, -1, -1):
                    if attn[j] == 1:
                        last_valid_idx = j
                        break
                
                if last_valid_idx >= 0 and ids[last_valid_idx] != self.insert_token_id:
                    ids.append(self.insert_token_id)
                    attn.append(1)
                
                new_input_ids_list.append(ids)
                new_attention_mask_list.append(attn)
        
        # Find max length and pad
        max_len = max(len(ids) for ids in new_input_ids_list)
        
        if self.pad_to_multiple_of:
            max_len = ((max_len + self.pad_to_multiple_of - 1) // self.pad_to_multiple_of) * self.pad_to_multiple_of
        
        # Left-pad all sequences to max_len
        for i in range(batch_size):
            pad_len = max_len - len(new_input_ids_list[i])
            if pad_len > 0:
                new_input_ids_list[i] = [self.pad_id] * pad_len + new_input_ids_list[i]
                new_attention_mask_list[i] = [0] * pad_len + new_attention_mask_list[i]
        
        new_input_ids = torch.tensor(new_input_ids_list, dtype=input_ids.dtype)
        new_attention_mask = torch.tensor(new_attention_mask_list, dtype=attention_mask.dtype)
        
        return new_input_ids, new_attention_mask

    def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        texts: List[str] = []
        label_texts: List[str] = []

        for ex in examples:
            user_text = ex["question"]
            gold = ex.get("answer", "")

            if self.is_train:
                assistant_text = f"{self.answer_prefix} {gold}"

                conversation = [
                    {"role": "user", "content": user_text},
                    {"role": "assistant", "content": assistant_text},
                ]

                text = self.tokenizer.apply_chat_template(
                    conversation,
                    add_generation_prompt=False,
                    tokenize=False,
                )

            else:
                conversation = [
                    {"role": "user", "content": user_text},
                ]

                base = self.tokenizer.apply_chat_template(
                    conversation,
                    add_generation_prompt=True,
                    tokenize=False,
                )

                text = base + self.answer_prefix

            texts.append(text)
            label_texts.append(gold)

        batch = self.tokenizer(
            text=texts,
            padding=True,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        input_ids = batch["input_ids"]
        attn = batch.get("attention_mask", None)

        # Insert token 128001 before <|eot_id|> in each sequence
        input_ids, attn = self._insert_token_before_eot(input_ids, attn)
        batch["input_ids"] = input_ids
        batch["attention_mask"] = attn

        labels = torch.full_like(input_ids, -100)

        if self.is_train:
            for i in range(input_ids.size(0)):
                valid_pos = attn[i].nonzero(as_tuple=False).squeeze(-1)
                compact_ids = input_ids[i, valid_pos].tolist()

                start_c = self._find_assistant_start(compact_ids)
                end_c = self._first_eos_after(compact_ids, start_c) if start_c != -1 else -1
                end_c = end_c   # include EOS/eot_id in label span

                if start_c != -1 and end_c > start_c:
                    span_pos = valid_pos[start_c:end_c]
                    labels[i, span_pos] = input_ids[i, span_pos]
                elif self.debug:
                    print(f"[warn] assistant span not found for sample {i}")
                    print(f"  compact_ids: {compact_ids[:50]}...")
                    print(f"  looking for: {self._preamble_variants}")
        else:
            batch["label_texts"] = label_texts

        batch["labels"] = labels
        return batch

train_collator = GlueLlavaDataCollator(tokenizer=tokenizer, is_train=True,  debug=True)


Preamble variants: [[128006, 78191, 128007, 271], [128006, 78191, 128007]]
EOT ID: 128009, EOS ID: 128009
Insert token ID: 128001


In [5]:
d = train_collator(dataset['train'].select(range(3)))

In [6]:
print(tokenizer.decode(d['input_ids'][2]))

<|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|begin_of_text|><|begin_of_text|><|start_header_id|>user<|end_header_id|>

To properly, thoroughly apply bug spray while camping,

  Options: 
 A. spray a heav

In [7]:
k = []

for i in d['labels'][2]:
    if i == -100:
        k.append(0)
    else:
        k.append(i)
    
print(tokenizer.decode(k))

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!The correct output is B<|end_of_text|>!


In [8]:
dataset

DatasetDict({
    train: Dataset({
        features: ['question', 'prompt', 'source', 'answer', 'category', 'dataset'],
        num_rows: 70302
    })
    test_arc_challenge: Dataset({
        features: ['question', 'prompt', 'source', 'answer', 'category', 'dataset'],
        num_rows: 500
    })
    test_arc_easy: Dataset({
        features: ['question', 'prompt', 'source', 'answer', 'category', 'dataset'],
        num_rows: 500
    })
    test_boolq: Dataset({
        features: ['question', 'prompt', 'source', 'answer', 'category', 'dataset'],
        num_rows: 1000
    })
    test_hellaswag: Dataset({
        features: ['question', 'prompt', 'source', 'answer', 'category', 'dataset'],
        num_rows: 1000
    })
    test_openbookqa: Dataset({
        features: ['question', 'prompt', 'source', 'answer', 'category', 'dataset'],
        num_rows: 500
    })
    test_piqa: Dataset({
        features: ['question', 'prompt', 'source', 'answer', 'category', 'dataset'],
        num_rows:

In [9]:
from MJtrainer import MonkeyTrainer
from transformers import Trainer, TrainingArguments, TrainerCallback
training_args = TrainingArguments(
    output_dir="./llava-lora-finetuned_our",
    per_device_train_batch_size=6,
    gradient_accumulation_steps=2,  
    save_total_limit=4,
    save_steps=500000,
    num_train_epochs=1,
    remove_unused_columns=False, 
   
    bf16=True,  
    logging_dir="./logs",
    logging_steps=100,
    eval_strategy="no",  
    #eval_steps=10,
    save_strategy="no",
    optim="paged_adamw_8bit",
  
    learning_rate=1e-3,
    warmup_ratio=0.03,
    weight_decay=0.00,
    report_to="none"
)


import os
import torch
from tqdm import tqdm




# moe_trainer.py
import torch
from transformers import Trainer


# Example instantiation:
trainer = MonkeyTrainer(
    model=model,
    args=training_args,                  # your HF TrainingArguments
    train_dataset=dataset['train'],

    data_collator=train_collator,         # your collator
    momentum=0.5, 


    step_interval=1,                    # update every 10 train steps
    stop_update_step=1000,               # stop updates at/after step 1000
)






Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


[MonkeyTrainer] interval=1, stop_at=1000, update_on=micro, momentum=0.5


In [10]:
from kmneas import init_router_centers
init_router_centers(
    trainer,
    subset_size=5000,        # Use 2000 samples
    loader_batch_size=8,
    collect_batches=2000,
    kmeans_iters=30,
    seed=42,
    verbose=True,
  
    # Optional: auto-select number of experts
    auto_select_experts=False,
    rep_mode="token",
)



[kmeans-init] Collecting representations (TOKEN-BASED)...
[kmeans-init] Found 32 patched blocks, mode: TOKEN-BASED


Collecting representations:   0%|          | 0/2000 [00:00<?, ?it/s]

[kmeans-init] Block collected 10000 tokens, dim=4096
[kmeans-init] Block collected 10000 tokens, dim=4096
[kmeans-init] Block collected 10000 tokens, dim=4096
[kmeans-init] Block collected 10000 tokens, dim=4096
[kmeans-init] Block collected 10000 tokens, dim=4096
[kmeans-init] Block collected 10000 tokens, dim=4096
[kmeans-init] Block collected 10000 tokens, dim=4096
[kmeans-init] Block collected 10000 tokens, dim=4096
[kmeans-init] Block collected 10000 tokens, dim=4096
[kmeans-init] Block collected 10000 tokens, dim=4096
[kmeans-init] Block collected 10000 tokens, dim=4096
[kmeans-init] Block collected 10000 tokens, dim=4096
[kmeans-init] Block collected 10000 tokens, dim=4096
[kmeans-init] Block collected 10000 tokens, dim=4096
[kmeans-init] Block collected 10000 tokens, dim=4096
[kmeans-init] Block collected 10000 tokens, dim=4096
[kmeans-init] Block collected 10000 tokens, dim=4096
[kmeans-init] Block collected 10000 tokens, dim=4096
[kmeans-init] Block collected 10000 tokens, di

In [11]:
trainer.train()

Step,Training Loss
100,5.7837
200,0.1706
300,0.1542
400,0.1481
500,0.1552
600,0.136
700,0.1345
800,0.1376
900,0.1239
1000,0.1333


TrainOutput(global_step=5858, training_loss=0.20806089258307928, metrics={'train_runtime': 2273.1626, 'train_samples_per_second': 30.927, 'train_steps_per_second': 2.577, 'total_flos': 3.897378324849623e+17, 'train_loss': 0.20806089258307928, 'epoch': 0.9999146539216524})

In [14]:
eval_collator  = GlueLlavaDataCollator(tokenizer=tokenizer, is_train=False)

In [15]:
import re
from datasets import load_dataset, DatasetDict

LETTER = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"

def make_mcq_pattern(prefix: str):
    return re.compile(
        rf"({prefix}[1-9])\s*:\s*(.*?)(?=\s+{prefix}[1-9]\s*:|\s+Answer format:|$)",
        flags=re.IGNORECASE | re.DOTALL
    )

PATTERNS = [
    (make_mcq_pattern("Answer"),   lambda k: k.lower()),   # Answer1..Answer4 (ARC/OpenBookQA/SocialIQA)
    (make_mcq_pattern("Ending"),   lambda k: k.lower()),   # Ending1..Ending4 (HellaSwag)
    (make_mcq_pattern("Solution"), lambda k: k.lower()),   # Solution1..Solution2 (PIQA)
    (make_mcq_pattern("Option"),   lambda k: k.lower()),   # Option1..Option2 (Winogrande)
]

def extract_question_text(instr: str) -> str:
    m = re.search(r"question:\s*(.*?)(?:\n\s*\n|$)", instr, flags=re.I | re.S)
    if m:
        return m.group(1).strip()

    m2 = re.search(r"to the question:\s*(.*?)(?:\n\s*\n|$)", instr, flags=re.I | re.S)
    if m2:
        return m2.group(1).strip()

    m3 = re.search(r"sentence:\s*(.*?)(?:\n\s*\n|$)", instr, flags=re.I | re.S)
    if m3:
        return m3.group(1).strip()

    return instr.strip()

def extract_options(instr: str):
    for pat, norm_key in PATTERNS:
        found = pat.findall(instr)
        if found:
            return [(norm_key(k), v.strip()) for k, v in found]

    if re.search(r"Answer format:\s*true\s*/\s*false", instr, flags=re.I):
        return [("true", "true"), ("false", "false")]

    return []

def build_answer_map(option_keys):
    return {k: LETTER[i] for i, k in enumerate(option_keys)}

def instruction_to_train_style(example, *, category: str, dataset_name: str):
    instr = (example.get("instruction") or "").strip()
    gold = (example.get("answer") or "").strip().lower()

    # ---- Winogrande special case ----
    if dataset_name == "winogrande":
        # Extract options from instruction text: "Option1: Sarah Option2: Maria"
        opt_match = re.search(r"Option1:\s*(\S+)\s+Option2:\s*(\S+)", instr, re.IGNORECASE)
        
        if opt_match:
            opt1 = opt_match.group(1)
            opt2 = opt_match.group(2)
            
            # Extract question (sentence with blank) - text before "Option1:"
            parts = re.split(r"\s*Option1:", instr, flags=re.IGNORECASE)
            q_text = parts[0].strip()
            
            # Clean up the question text - remove the prompt prefix
            q_text = re.sub(r"^Please choose the correct answer to fill in the blank to complete the given sentence:\s*", "", q_text, flags=re.I)
            
            question = (
                f"{q_text}\n\n"
                f"  Options: \n"
                f" A. {opt1}\n"
                f" B. {opt2}"
            )
            
            # Map option1 -> A, option2 -> B
            answer = {"option1": "A", "option2": "B"}.get(gold, gold)
            
            return {
                "question": question,
                "prompt": "Choose the correct answer to fill in the blank.",
                "source": "",
                "answer": answer,
                "category": category,
                "dataset": dataset_name,
            }
        
        # Fallback if pattern doesn't match
        return {
            "question": extract_question_text(instr),
            "prompt": "Choose the correct answer to fill in the blank.",
            "source": "",
            "answer": gold,
            "category": category,
            "dataset": dataset_name,
        }

    # ---- Default path (all other datasets) ----
    q_text = extract_question_text(instr)
    options = extract_options(instr)

    if options:
        option_keys = [k for k, _ in options]
        option_texts = [t for _, t in options]
        ans_map = build_answer_map(option_keys)

        question = q_text + "\n\n  Options: \n"
        for i, text in enumerate(option_texts):
            question += f" {LETTER[i]}. {text}\n"
        question = question.rstrip()

        answer = ans_map.get(gold, gold)
    else:
        question = q_text
        answer = gold

    return {
        "question": question,
        "prompt": "Choose the correct answer to the question.",
        "source": "",
        "answer": answer,
        "category": category,
        "dataset": dataset_name,
    }

def load_and_convert(url_or_path: str, *, category: str, dataset_name: str):
    ds = load_dataset("json", data_files=url_or_path, split="train")
    ds2 = ds.map(
        instruction_to_train_style,
        fn_kwargs={"category": category, "dataset_name": dataset_name},
        remove_columns=ds.column_names,
        desc=f"Converting {dataset_name}",
    )
    return ds2

TEST_SOURCES = {
    "test_winogrande": {
        "path": "https://raw.githubusercontent.com/AGI-Edgerunners/LLM-Adapters/refs/heads/main/dataset/winogrande/test.json",
        "category": "WinoGrande",
        "dataset": "winogrande",
    },
    "test_arc_challenge": {
        "path": "https://raw.githubusercontent.com/AGI-Edgerunners/LLM-Adapters/refs/heads/main/dataset/ARC-Challenge/test.json",
        "category": "ARC-Challenge",
        "dataset": "arc_challenge",
    },
    "test_arc_easy": {
        "path": "https://raw.githubusercontent.com/AGI-Edgerunners/LLM-Adapters/refs/heads/main/dataset/ARC-Easy/test.json",
        "category": "ARC-Easy",
        "dataset": "arc_easy",
    },
    "test_boolq": {
        "path": "https://raw.githubusercontent.com/AGI-Edgerunners/LLM-Adapters/refs/heads/main/dataset/boolq/test.json",
        "category": "BoolQ",
        "dataset": "boolq",
    },
    "test_hellaswag": {
        "path": "https://raw.githubusercontent.com/AGI-Edgerunners/LLM-Adapters/refs/heads/main/dataset/hellaswag/test.json",
        "category": "HellaSwag",
        "dataset": "hellaswag",
    },
    "test_openbookqa": {
        "path": "https://raw.githubusercontent.com/AGI-Edgerunners/LLM-Adapters/refs/heads/main/dataset/openbookqa/test.json",
        "category": "OpenBookQA",
        "dataset": "openbookqa",
    },
    "test_piqa": {
        "path": "https://raw.githubusercontent.com/AGI-Edgerunners/LLM-Adapters/refs/heads/main/dataset/piqa/test.json",
        "category": "PIQA",
        "dataset": "piqa",
    },
    "test_social_i_qa": {
        "path": "https://raw.githubusercontent.com/AGI-Edgerunners/LLM-Adapters/refs/heads/main/dataset/social_i_qa/test.json",
        "category": "SocialIQA",
        "dataset": "social_i_qa",
    },
}

updated_tests = {
    split: load_and_convert(meta["path"], category=meta["category"], dataset_name=meta["dataset"])
    for split, meta in TEST_SOURCES.items()
}

dataset = DatasetDict(updated_tests)

# Verify Winogrande is processed correctly
print("=== Winogrande Sample ===")
print(dataset["test_winogrande"][0])
print(f"\nAnswer: {dataset['test_winogrande'][0]['answer']}")





import re
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

def _extract_choice(pred: str):
    """
    Extract a discrete choice from model text.
    Supports:
      - A/B/C/D
      - answer1/answer2/answer3/answer4
      - ending1..ending4
      - solution1/solution2
      - option1/option2 (Winogrande)
      - true/false
    Returns normalized label like "A"/"B"/"C"/"D" or "true"/"false" or "".
    """
    if pred is None:
        return ""

    s = str(pred).strip().lower()

    # common wrappers
    # e.g. "the correct output is B", "the correct answer is answer3"
    s = re.sub(r"^the correct (output|answer) is\s+", "", s).strip()

    # If it directly contains a letter choice, prefer first standalone A-D
    m = re.search(r"\b([abcd])\b", s)
    if m:
        return m.group(1).upper()

    # Map answer/ending/solution/option tokens to letters (1->A, 2->B, 3->C, 4->D)
    m = re.search(r"\b(answer|ending|solution|option)\s*([1-4])\b", s)
    if m:
        idx = int(m.group(2)) - 1
        return "ABCD"[idx]

    # BoolQ / yes-no style
    if re.search(r"\btrue\b", s):
        return "true"
    if re.search(r"\bfalse\b", s):
        return "false"

    # fallback: first token stripped
    tok = s.split()[0] if s else ""
    tok = tok.strip(" .,:;!?")
    return tok

def evaluate_commonsense(
    model,
    tokenizer,
    dataset,
    eval_collator,
    batch_size: int = 8,
    max_new_tokens: int = 10,
):
    """
    Evaluation for commonsense reasoning DatasetDict splits.
    Metric: exact-match accuracy after normalization.
    """
    model.eval()
    preds = []
    golds = []

    loader = DataLoader(dataset, batch_size=batch_size, collate_fn=eval_collator)

    for batch_examples in tqdm(loader, desc="Evaluating commonsense"):
        batch = {
            k: v.to(model.device)
            for k, v in batch_examples.items()
            if isinstance(v, torch.Tensor)
        }

        with torch.no_grad():
            gen = model.generate(
                input_ids=batch["input_ids"],
                attention_mask=batch.get("attention_mask", None),
                max_new_tokens=max_new_tokens,
                pad_token_id=tokenizer.eos_token_id,
            )

        generated = tokenizer.batch_decode(gen, skip_special_tokens=True)

        prefix = eval_collator.answer_prefix  # e.g. "The correct output is"

        for full_output, gold in zip(generated, batch_examples["label_texts"]):
            idx = full_output.rfind(prefix)
            if idx != -1:
                pred_text = full_output[idx + len(prefix):].strip()
            else:
                pred_text = full_output.strip()

            preds.append(pred_text)
            golds.append(gold)

    if len(golds) == 0:
        raise ValueError("No gold labels collected; check that label_texts are in the batch.")

    norm_preds = [_extract_choice(p) for p in preds]
    norm_golds = [_extract_choice(g) for g in golds]

    correct = sum(p == g for p, g in zip(norm_preds, norm_golds))
    acc = correct / len(norm_golds)

    print(f"[commonsense] accuracy: {correct}/{len(norm_golds)} = {acc:.4f}")
    print("preds[:10]", norm_preds[:10])
    print("golds[:10]", norm_golds[:10])

    return acc, preds, golds


# Run evaluation
results = {}

for split_name in dataset.keys():
    print(f"\nEvaluating {split_name}...")
    acc, preds, golds = evaluate_commonsense(
        model=model,
        tokenizer=tokenizer,
        dataset=dataset[split_name],
        eval_collator=eval_collator,
        batch_size=32,
        max_new_tokens=10,
    )
    results[split_name] = acc

print("\n=== Commonsense Eval Results ===")
for split, acc in results.items():
    print(f"{split}: {acc:.4f}")




Evaluating test_winogrande...


Evaluating commonsense: 100%|██████████| 40/40 [00:24<00:00,  1.60it/s]


[commonsense] accuracy: 0/1267 = 0.0000
preds[:10] ['maria', 'sarah', 'bed', 'B', 'B', 'sarah', 'the', 'B', 'jennifer', 'the']
golds[:10] ['option2', 'option1', 'option2', 'option1', 'option1', 'option1', 'option1', 'option2', 'option2', 'option1']

Evaluating test_arc_challenge...


Evaluating commonsense: 100%|██████████| 37/37 [00:14<00:00,  2.55it/s]


[commonsense] accuracy: 930/1172 = 0.7935
preds[:10] ['C', 'B', 'C', 'C', 'D', 'B', 'C', 'C', 'C', 'A']
golds[:10] ['C', 'B', 'C', 'D', 'D', 'B', 'C', 'C', 'B', 'A']

Evaluating test_arc_easy...


Evaluating commonsense: 100%|██████████| 75/75 [00:28<00:00,  2.65it/s]


[commonsense] accuracy: 2033/2376 = 0.8556
preds[:10] ['A', 'B', 'D', 'B', 'C', 'C', 'B', 'C', 'C', 'A']
golds[:10] ['A', 'B', 'D', 'D', 'B', 'C', 'A', 'C', 'C', 'A']

Evaluating test_boolq...


Evaluating commonsense: 100%|██████████| 103/103 [00:21<00:00,  4.90it/s]


[commonsense] accuracy: 2305/3270 = 0.7049
preds[:10] ['A', 'A', 'A', 'A', 'A', 'B', 'A', 'A', 'A', 'A']
golds[:10] ['B', 'A', 'A', 'A', 'A', 'B', 'A', 'A', 'A', 'A']

Evaluating test_hellaswag...


Evaluating commonsense:  35%|███▌      | 110/314 [00:56<01:45,  1.94it/s]


KeyboardInterrupt: 

In [16]:
import re
from datasets import load_dataset, DatasetDict

LETTER = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"

def make_mcq_pattern(prefix: str):
    return re.compile(
        rf"({prefix}[1-9])\s*:\s*(.*?)(?=\s+{prefix}[1-9]\s*:|\s+Answer format:|$)",
        flags=re.IGNORECASE | re.DOTALL
    )

PATTERNS = [
    (make_mcq_pattern("Answer"),   lambda k: k.lower()),   # Answer1..Answer4 (ARC/OpenBookQA/SocialIQA)
    (make_mcq_pattern("Ending"),   lambda k: k.lower()),   # Ending1..Ending4 (HellaSwag)
    (make_mcq_pattern("Solution"), lambda k: k.lower()),   # Solution1..Solution2 (PIQA)
    (make_mcq_pattern("Option"),   lambda k: k.lower()),   # Option1..Option2 (Winogrande)
]

def extract_question_text(instr: str) -> str:
    m = re.search(r"question:\s*(.*?)(?:\n\s*\n|$)", instr, flags=re.I | re.S)
    if m:
        return m.group(1).strip()

    m2 = re.search(r"to the question:\s*(.*?)(?:\n\s*\n|$)", instr, flags=re.I | re.S)
    if m2:
        return m2.group(1).strip()

    m3 = re.search(r"sentence:\s*(.*?)(?:\n\s*\n|$)", instr, flags=re.I | re.S)
    if m3:
        return m3.group(1).strip()

    return instr.strip()

def extract_options(instr: str):
    for pat, norm_key in PATTERNS:
        found = pat.findall(instr)
        if found:
            return [(norm_key(k), v.strip()) for k, v in found]

    if re.search(r"Answer format:\s*true\s*/\s*false", instr, flags=re.I):
        return [("true", "true"), ("false", "false")]

    return []

def build_answer_map(option_keys):
    return {k: LETTER[i] for i, k in enumerate(option_keys)}

def instruction_to_train_style(example, *, category: str, dataset_name: str):
    instr = (example.get("instruction") or "").strip()
    gold = (example.get("answer") or "").strip().lower()

    # ---- Winogrande special case ----
    if dataset_name == "winogrande":
        # Extract options from instruction text: "Option1: Sarah Option2: Maria"
        opt_match = re.search(r"Option1:\s*(\S+)\s+Option2:\s*(\S+)", instr, re.IGNORECASE)
        
        if opt_match:
            opt1 = opt_match.group(1)
            opt2 = opt_match.group(2)
            
            # Extract question (sentence with blank) - text before "Option1:"
            parts = re.split(r"\s*Option1:", instr, flags=re.IGNORECASE)
            q_text = parts[0].strip()
            
            # Clean up the question text - remove the prompt prefix
            q_text = re.sub(r"^Please choose the correct answer to fill in the blank to complete the given sentence:\s*", "", q_text, flags=re.I)
            
            question = (
                f"{q_text}\n\n"
                f"  Options: \n"
                f" A. {opt1}\n"
                f" B. {opt2}"
            )
            
            # Map option1 -> A, option2 -> B
            answer = {"option1": "A", "option2": "B"}.get(gold, gold)
            
            return {
                "question": question,
                "prompt": "Choose the correct answer to fill in the blank.",
                "source": "",
                "answer": answer,
                "category": category,
                "dataset": dataset_name,
            }
        
        # Fallback if pattern doesn't match
        return {
            "question": extract_question_text(instr),
            "prompt": "Choose the correct answer to fill in the blank.",
            "source": "",
            "answer": gold,
            "category": category,
            "dataset": dataset_name,
        }

    # ---- Default path (all other datasets) ----
    q_text = extract_question_text(instr)
    options = extract_options(instr)

    if options:
        option_keys = [k for k, _ in options]
        option_texts = [t for _, t in options]
        ans_map = build_answer_map(option_keys)

        question = q_text + "\n\n  Options: \n"
        for i, text in enumerate(option_texts):
            question += f" {LETTER[i]}. {text}\n"
        question = question.rstrip()

        answer = ans_map.get(gold, gold)
    else:
        question = q_text
        answer = gold

    return {
        "question": question,
        "prompt": "Choose the correct answer to the question.",
        "source": "",
        "answer": answer,
        "category": category,
        "dataset": dataset_name,
    }

def load_and_convert(url_or_path: str, *, category: str, dataset_name: str):
    ds = load_dataset("json", data_files=url_or_path, split="train")
    ds2 = ds.map(
        instruction_to_train_style,
        fn_kwargs={"category": category, "dataset_name": dataset_name},
        remove_columns=ds.column_names,
        desc=f"Converting {dataset_name}",
    )
    return ds2

TEST_SOURCES = {
    "test_winogrande": {
        "path": "https://raw.githubusercontent.com/AGI-Edgerunners/LLM-Adapters/refs/heads/main/dataset/winogrande/test.json",
        "category": "WinoGrande",
        "dataset": "winogrande",
    },
    "test_arc_challenge": {
        "path": "https://raw.githubusercontent.com/AGI-Edgerunners/LLM-Adapters/refs/heads/main/dataset/ARC-Challenge/test.json",
        "category": "ARC-Challenge",
        "dataset": "arc_challenge",
    },
    "test_arc_easy": {
        "path": "https://raw.githubusercontent.com/AGI-Edgerunners/LLM-Adapters/refs/heads/main/dataset/ARC-Easy/test.json",
        "category": "ARC-Easy",
        "dataset": "arc_easy",
    },
    "test_boolq": {
        "path": "https://raw.githubusercontent.com/AGI-Edgerunners/LLM-Adapters/refs/heads/main/dataset/boolq/test.json",
        "category": "BoolQ",
        "dataset": "boolq",
    },
    "test_hellaswag": {
        "path": "https://raw.githubusercontent.com/AGI-Edgerunners/LLM-Adapters/refs/heads/main/dataset/hellaswag/test.json",
        "category": "HellaSwag",
        "dataset": "hellaswag",
    },
    "test_openbookqa": {
        "path": "https://raw.githubusercontent.com/AGI-Edgerunners/LLM-Adapters/refs/heads/main/dataset/openbookqa/test.json",
        "category": "OpenBookQA",
        "dataset": "openbookqa",
    },
    "test_piqa": {
        "path": "https://raw.githubusercontent.com/AGI-Edgerunners/LLM-Adapters/refs/heads/main/dataset/piqa/test.json",
        "category": "PIQA",
        "dataset": "piqa",
    },
    "test_social_i_qa": {
        "path": "https://raw.githubusercontent.com/AGI-Edgerunners/LLM-Adapters/refs/heads/main/dataset/social_i_qa/test.json",
        "category": "SocialIQA",
        "dataset": "social_i_qa",
    },
}

updated_tests = {
    split: load_and_convert(meta["path"], category=meta["category"], dataset_name=meta["dataset"])
    for split, meta in TEST_SOURCES.items()
}

dataset = DatasetDict(updated_tests)

# Verify Winogrande is processed correctly
print("=== Winogrande Sample ===")
print(dataset["test_winogrande"][0])
print(f"\nAnswer: {dataset['test_winogrande'][0]['answer']}")

=== Winogrande Sample ===
{'answer': 'B', 'question': 'Sarah was a much better surgeon than Maria so _ always got the easier cases.\n\n  Options: \n A. Sarah\n B. Maria', 'prompt': 'Choose the correct answer to fill in the blank.', 'source': '', 'category': 'WinoGrande', 'dataset': 'winogrande'}

Answer: B


In [17]:
import re
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

def _extract_choice(pred: str):
    """
    Extract a discrete choice from model text.
    Supports:
      - A/B/C/D
      - answer1/answer2/answer3/answer4
      - ending1..ending4
      - solution1/solution2
      - option1/option2 (Winogrande)
      - true/false
    Returns normalized label like "A"/"B"/"C"/"D" or "true"/"false" or "".
    """
    if pred is None:
        return ""

    s = str(pred).strip().lower()

    # common wrappers
    # e.g. "the correct output is B", "the correct answer is answer3"
    s = re.sub(r"^the correct (output|answer) is\s+", "", s).strip()

    # If it directly contains a letter choice, prefer first standalone A-D
    m = re.search(r"\b([abcd])\b", s)
    if m:
        return m.group(1).upper()

    # Map answer/ending/solution/option tokens to letters (1->A, 2->B, 3->C, 4->D)
    m = re.search(r"\b(answer|ending|solution|option)\s*([1-4])\b", s)
    if m:
        idx = int(m.group(2)) - 1
        return "ABCD"[idx]

    # BoolQ / yes-no style
    if re.search(r"\btrue\b", s):
        return "true"
    if re.search(r"\bfalse\b", s):
        return "false"

    # fallback: first token stripped
    tok = s.split()[0] if s else ""
    tok = tok.strip(" .,:;!?")
    return tok

def evaluate_commonsense(
    model,
    tokenizer,
    dataset,
    eval_collator,
    batch_size: int = 8,
    max_new_tokens: int = 10,
):
    """
    Evaluation for commonsense reasoning DatasetDict splits.
    Metric: exact-match accuracy after normalization.
    """
    model.eval()
    preds = []
    golds = []

    loader = DataLoader(dataset, batch_size=batch_size, collate_fn=eval_collator)

    for batch_examples in tqdm(loader, desc="Evaluating commonsense"):
        batch = {
            k: v.to(model.device)
            for k, v in batch_examples.items()
            if isinstance(v, torch.Tensor)
        }

        with torch.no_grad():
            gen = model.generate(
                input_ids=batch["input_ids"],
                attention_mask=batch.get("attention_mask", None),
                max_new_tokens=max_new_tokens,
                pad_token_id=tokenizer.eos_token_id,
            )

        generated = tokenizer.batch_decode(gen, skip_special_tokens=True)

        prefix = eval_collator.answer_prefix  # e.g. "The correct output is"

        for full_output, gold in zip(generated, batch_examples["label_texts"]):
            idx = full_output.rfind(prefix)
            if idx != -1:
                pred_text = full_output[idx + len(prefix):].strip()
            else:
                pred_text = full_output.strip()

            preds.append(pred_text)
            golds.append(gold)

    if len(golds) == 0:
        raise ValueError("No gold labels collected; check that label_texts are in the batch.")

    norm_preds = [_extract_choice(p) for p in preds]
    norm_golds = [_extract_choice(g) for g in golds]

    correct = sum(p == g for p, g in zip(norm_preds, norm_golds))
    acc = correct / len(norm_golds)

    print(f"[commonsense] accuracy: {correct}/{len(norm_golds)} = {acc:.4f}")
    print("preds[:10]", norm_preds[:10])
    print("golds[:10]", norm_golds[:10])

    return acc, preds, golds


# Run evaluation
results = {}

for split_name in dataset.keys():
    print(f"\nEvaluating {split_name}...")
    acc, preds, golds = evaluate_commonsense(
        model=model,
        tokenizer=tokenizer,
        dataset=dataset[split_name],
        eval_collator=eval_collator,
        batch_size=32,
        max_new_tokens=10,
    )
    results[split_name] = acc

print("\n=== Commonsense Eval Results ===")
for split, acc in results.items():
    print(f"{split}: {acc:.4f}")


Evaluating test_winogrande...


Evaluating commonsense: 100%|██████████| 40/40 [00:11<00:00,  3.47it/s]


[commonsense] accuracy: 971/1267 = 0.7664
preds[:10] ['B', 'A', 'A', 'A', 'A', 'A', 'A', 'B', 'B', 'A']
golds[:10] ['B', 'A', 'B', 'A', 'A', 'A', 'A', 'B', 'B', 'A']

Evaluating test_arc_challenge...


Evaluating commonsense: 100%|██████████| 37/37 [00:14<00:00,  2.54it/s]


[commonsense] accuracy: 931/1172 = 0.7944
preds[:10] ['C', 'B', 'C', 'C', 'D', 'B', 'C', 'C', 'C', 'A']
golds[:10] ['C', 'B', 'C', 'D', 'D', 'B', 'C', 'C', 'B', 'A']

Evaluating test_arc_easy...


Evaluating commonsense: 100%|██████████| 75/75 [00:27<00:00,  2.76it/s]


[commonsense] accuracy: 2011/2376 = 0.8464
preds[:10] ['A', 'B', 'D', 'B', 'C', 'C', 'B', 'C', 'C', 'A']
golds[:10] ['A', 'B', 'D', 'D', 'B', 'C', 'A', 'C', 'C', 'A']

Evaluating test_boolq...


Evaluating commonsense: 100%|██████████| 103/103 [00:31<00:00,  3.31it/s]


[commonsense] accuracy: 2300/3270 = 0.7034
preds[:10] ['A', 'B', 'B', 'A', 'A', 'B', 'A', 'A', 'A', 'A']
golds[:10] ['B', 'A', 'A', 'A', 'A', 'B', 'A', 'A', 'A', 'A']

Evaluating test_hellaswag...


Evaluating commonsense: 100%|██████████| 314/314 [03:54<00:00,  1.34it/s]


[commonsense] accuracy: 6312/10042 = 0.6286
preds[:10] ['B', 'D', 'C', 'C', 'B', 'B', 'C', 'A', 'B', 'C']
golds[:10] ['D', 'D', 'C', 'C', 'B', 'B', 'C', 'A', 'B', 'B']

Evaluating test_openbookqa...


Evaluating commonsense: 100%|██████████| 16/16 [00:04<00:00,  3.46it/s]


[commonsense] accuracy: 375/500 = 0.7500
preds[:10] ['B', 'A', 'C', 'C', 'B', 'C', 'C', 'B', 'D', 'B']
golds[:10] ['B', 'A', 'C', 'C', 'C', 'C', 'C', 'B', 'D', 'B']

Evaluating test_piqa...


Evaluating commonsense: 100%|██████████| 58/58 [00:33<00:00,  1.71it/s]


[commonsense] accuracy: 1451/1838 = 0.7894
preds[:10] ['B', 'B', 'A', 'B', 'A', 'B', 'B', 'B', 'B', 'A']
golds[:10] ['A', 'B', 'B', 'B', 'A', 'B', 'B', 'A', 'A', 'A']

Evaluating test_social_i_qa...


Evaluating commonsense: 100%|██████████| 62/62 [00:17<00:00,  3.48it/s]

[commonsense] accuracy: 1270/1954 = 0.6499
preds[:10] ['A', 'B', 'B', 'A', 'C', 'B', 'A', 'B', 'C', 'B']
golds[:10] ['C', 'A', 'B', 'A', 'C', 'A', 'B', 'B', 'C', 'B']

=== Commonsense Eval Results ===
test_winogrande: 0.7664
test_arc_challenge: 0.7944
test_arc_easy: 0.8464
test_boolq: 0.7034
test_hellaswag: 0.6286
test_openbookqa: 0.7500
test_piqa: 0.7894
test_social_i_qa: 0.6499



