# **AdaptMI : Adaptive Skill-based In-context Math Instructions for Small Language Models**

This is the Jupyter notebook accompanying the paper **AdaptMI : Adaptive Skill-based In-context Math Instructions for Small Language Models.**

**‼️ Caveats:** Due to resource limits, this notebook only tests AdaptMI on 50 MATH examples, far fewer than the 5k-example test set used in the paper. Therefore, the **exact** accuracy numbers (as well as the accuracy gain) may deviate from Table 1 in the paper.

## 🔔 Stage1: **Detection of _easy_ and _difficult_ questions**

In this stage, we will label a question as _easy_ or _difficult_ for a Small Language Model.

### 👉 Stage1-1: Initial evaluation

#### Environment Setup

```bash
conda create -n matheval python=3.10
conda activate matheval

cd evaluation/latex2sympy
pip install -e .
cd ..
pip install torch
pip install -r requirements.txt
pip install vllm==0.5.1 --no-build-isolation
pip install transformers==4.42.3
conda install ipykernel
```

Please activate the environment `matheval`, and run the following cells:

In [None]:
# IPython magic: sets the TOKENIZERS_PARALLELISM environment variable to "true" for this kernel.
%env TOKENIZERS_PARALLELISM=true
import sys
import os

# Put <working directory>/evaluation at the front of sys.path (once) so that
# modules located in that folder can be imported by their bare names.
current_dir = os.getcwd()
evaluation_dir_path = os.path.join(current_dir, 'evaluation')

if evaluation_dir_path not in sys.path:
    sys.path.insert(0, evaluation_dir_path)

import os
import json
import logging
import random
import numpy as np
import torch
from types import SimpleNamespace
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer, AutoConfig

# Star-import: later cells call `main` and `load_model`, which are presumably
# provided by this module (verify against evaluation/math_eval.py).
from evaluation.math_eval import *

# Root logger at INFO with a timestamped format; `logger` is the module-level
# logger used by the functions defined in the following cells.
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO)
logger = logging.getLogger(__name__)

In [None]:
class Args:
    """Configuration for the initial (Stage 1) evaluation run.

    Every setting is stored as a plain instance attribute, in the order listed
    in ``settings`` below, so the object can be used like an argparse namespace.
    """

    def __init__(self):
        temperature = 0.0

        settings = {
            # --- data / model / output ---
            "data_names": "math",  # Example: "gsm8k,math"
            "data_dir": "./evaluation/data",
            "data_path": None,  # Default: None
            "model_name_or_path": "models/Qwen2.5-1.5B-Instruct",  # Replace with your model
            "output_dir": "./output/stage1_inference",
            "prompt_type": "qwen25-math-cot",
            "split": "test",
            "num_test_sample": 50,  # -1 for full data, set to a small number for testing
            "seed": 0,
            "start": 0,
            "end": -1,  # -1 for all samples from start
            # --- sampling ---
            "temperature": temperature,
            "n_sampling": 1,
            # top_p is pinned to 1 whenever temperature is 0 (greedy sampling);
            # otherwise the configured value (1.0) is kept.
            "top_p": 1 if temperature == 0 else 1.0,
            "max_tokens_per_call": 1024,  # Reduced for faster testing, adjust as needed
            "shuffle": True,
            "use_vllm": False,  # Set to False if not using vLLM or for models not supported well by vLLM
            "save_outputs": True,
            # --- Ours ---
            "LLM_judge": False,
            "PRM_judge": False,
            "random_shots": False,
            "llm_sol": False,
            # --- misc ---
            "overwrite": True,  # Set to True to overwrite existing output files
            "use_safetensors": True,  # Recommended
            "num_shots": 5,
            "num_skill_shots": 0,
            "apply_chat_template": False,  # Set to True if your model expects chatml or similar
            "pipeline_parallel_size": 1,
            "adapt_few_shot": False,
        }

        for option_name, option_value in settings.items():
            setattr(self, option_name, option_value)

In [None]:
def set_seed(seed_value: int):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed_value: Seed handed to ``random``, ``numpy.random`` and ``torch``
            (and to all CUDA devices when CUDA is available).
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed_value)

    # CUDA generators are only touched when a GPU is actually available.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed_value)

    logger.info(f"Set seed to {seed_value}")

In [None]:
def run_eval(args_obj: Args): # Type hint Args_obj with our class
    """Load the model, evaluate it on every requested dataset and print an accuracy table.

    ``load_model`` and ``main`` are not defined in this cell; they are presumably
    supplied by the star-import of ``evaluation.math_eval`` (verify there).

    Args:
        args_obj: Configuration object. This function itself reads ``use_vllm``,
            ``model_name_or_path``, ``pipeline_parallel_size``,
            ``apply_chat_template`` and ``data_names``; the whole object is also
            forwarded to ``load_model`` and ``main``.

    Returns:
        None. The summary is printed and logged.

    Raises:
        Exception: Any error raised while loading the model (vLLM or Hugging
            Face path) is logged and re-raised. A tokenizer failure on the vLLM
            chat-template path is only logged.
    """
    # Ensure CUDA_VISIBLE_DEVICES is set, or vLLM might default unexpectedly
    if args_obj.use_vllm and not os.environ.get("CUDA_VISIBLE_DEVICES"):
        logger.warning("CUDA_VISIBLE_DEVICES is not set. vLLM might default to all available GPUs or the first one.")

    # When the variable is unset this yields a single entry ("0").
    available_gpus = os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")

    llm_instance = None
    tokenizer_instance = None

    if args_obj.use_vllm:
        logger.info(f"Attempting to load model {args_obj.model_name_or_path} with vLLM.")
        # No model-specific rope_scaling is passed here; if a model needs it,
        # it has to be added to the LLM constructor arguments.
        try:
            llm_instance = LLM(
                model=args_obj.model_name_or_path,
                tensor_parallel_size=max(1, len(available_gpus) // args_obj.pipeline_parallel_size), # Ensure at least 1
                pipeline_parallel_size=args_obj.pipeline_parallel_size,
                dtype="bfloat16" if torch.cuda.is_bf16_supported() else "float16",
                trust_remote_code=True,
            )
            logger.info(f"vLLM loaded model: {args_obj.model_name_or_path}")
        except Exception as e:
            logger.error(f"Failed to load model with vLLM: {e}")
            raise

        if args_obj.apply_chat_template:
            try:
                tokenizer_instance = AutoTokenizer.from_pretrained(
                    args_obj.model_name_or_path, trust_remote_code=True
                )
                logger.info(f"Tokenizer loaded for chat template: {args_obj.model_name_or_path}")
            except Exception as e:
                # Deliberately best-effort: evaluation continues with tokenizer_instance = None.
                logger.error(f"Failed to load tokenizer for chat template: {e}")
    else:
        logger.info(f"Attempting to load model {args_obj.model_name_or_path} with Hugging Face Transformers.")
        try:
            tokenizer_instance, llm_instance, _ = load_model(args_obj.model_name_or_path, args_obj)
            logger.info(f"Hugging Face model and tokenizer loaded: {args_obj.model_name_or_path}")
        except Exception as e:
            logger.error(f"Failed to load model with Hugging Face: {e}")
            raise

    # Infer & eval
    # Blank entries (e.g. "gsm8k,,math") are dropped once, here, so that dataset
    # names and results stay index-aligned for the summary below.
    dataset_names = [name.strip() for name in args_obj.data_names.split(",") if name.strip()]
    results = []
    for data_name in dataset_names:
        logger.info(f"\nProcessing dataset: {data_name}")
        results.append(main(llm_instance, tokenizer_instance, data_name, args_obj))

    if not results:
        logger.info("No datasets processed or no results returned.")
        return

    # Results were appended in dataset order, so pairing by position is exact.
    summary_rows = list(zip(dataset_names, results))

    if len(summary_rows) > 1:
        # Calculate average accuracy if multiple datasets were processed
        scored_results = [res for res in results if res and "acc" in res]
        if scored_results:
            avg_acc = sum(res["acc"] for res in scored_results) / len(scored_results)
            summary_rows.append(("avg", {"acc": avg_acc, "data_name": "avg"}))
        else:
            logger.warning("No valid results with 'acc' key found to calculate average.")

    logger.info("\n" + "="*20 + " Overall Summary " + "="*20)

    pad_width = max(len(name) for name, _ in summary_rows)

    header_parts = []
    score_parts = []
    for name, current_res in summary_rows:
        header_parts.append(name.ljust(pad_width))

        # A result may be None or otherwise not a dict; only dicts can carry a
        # conflicting data_name, and only then is the by-order pairing worth flagging.
        if isinstance(current_res, dict) and current_res.get("data_name", name) != name:
            logger.warning(f"Result for {name} matched by order, not explicit data_name key in result dict.")

        if current_res and "acc" in current_res:
            score_parts.append(f"{current_res['acc']:.1f}".ljust(pad_width))
        else:
            score_parts.append("N/A".ljust(pad_width))
            logger.warning(f"Accuracy not found for dataset: {name}")

    final_header = "\t".join(header_parts)
    final_scores = "\t".join(score_parts)
    print("\nResults Summary:") # Print to console for easy viewing
    print(final_header)
    print(final_scores)
    logger.info("Final results summary (also printed above):")
    logger.info(f"Datasets: {final_header}")
    logger.info(f"Accuracy: {final_scores}")

In [None]:
# --- Initial evaluation ---
# Build the configuration, seed the RNGs from it, then run the evaluation.
args = Args()
set_seed(args.seed)
run_eval(args)

### 👉 Stage 1-2
This stage classifies questions into _easy_ and _difficult_ according to the model's performance. `math-rm/rm_classify.py` employs a process reward model to assign scores for each step in the SLM response. We then use thresholds τ1, τ2 (`pred_thres1` and `pred_thres2` in the code) to classify whether a question q is easy or difficult.

#### Environment Setup

```bash
conda create -n classify python=3.10.9
conda activate classify

git clone https://github.com/OpenAccess-AI-Collective/axolotl
cd axolotl
git checkout 55cc214c767741e83ee7b346e5e13e6c03b7b9fa
pip install -e .

pip3 install torch==2.1.2 torchvision torchaudio
pip install flash-attn

git clone https://github.com/lm-sys/FastChat.git
cd FastChat
pip install -e .

git clone https://github.com/WeiXiongUST/RLHF-Reward-Modeling.git
pip install deepspeed

pip install -r math-rm/requirements.txt
conda install ipykernel
```

Please activate the environment `classify`, and run the following cells:

In [None]:
# Cell 1: Setup and Imports
import sys
import os
import json
import time
import numpy as np
import torch
import torch.distributed as dist
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
from accelerate import Accelerator # For accelerator.device, accelerator.num_processes, etc.
from collections import Counter
from tqdm import tqdm # For notebook-level progress if any; worker has its own

# Add math-rm to Python path to import custom modules
# The path is relative, so the kernel's working directory must contain 'math-rm'
sys.path.append('./math-rm')
from rm_classify import worker # Only worker is directly called from the main logic

# For a cleaner log, only ERROR (and above) messages from the "transformers" logger are kept
import logging
logging.getLogger("transformers").setLevel(logging.ERROR)

In [None]:
# Cell 2: Define Arguments (mimicking argparse)
class Args2:
    """Attribute container that stands in for the classifier's argparse namespace."""

    def __init__(self):
        vars(self).update(
            # Reward model used for scoring.
            reward_name_or_path='pwork7/llama31_it_prm_2e6_bz32_1epoch_conversation',
            # Input generations and where the classified output goes.
            dataset='./output/stage1_inference/test_50_0+5shots.jsonl',
            output_dir="./output/stage1_classified",
            # Classification thresholds.
            pred_thres1=0.9,
            pred_thres2=0.7,
            num_n=128,  # Reduced for faster notebook execution, original was 1024
            num_test_sample=50,
            model_type="Deepseek",
        )

# Instantiate the configuration under the global name `args`, which the main-logic cell reads.
args = Args2()

# Create output directory if it doesn't exist
os.makedirs(args.output_dir, exist_ok=True)
# Echo the effective settings; a num_test_sample of -1 is displayed as 'ALL'.
print(f"Output directory: {args.output_dir}")
print(f"Using model: {args.reward_name_or_path}")
print(f"Using dataset: {args.dataset}")
print(f"Number of N candidates per question: {args.num_n}")
print(f"Number of test samples: {args.num_test_sample if args.num_test_sample != -1 else 'ALL'}")

In [None]:
# Cell 3: Main Logic (adapted from the original script's if __name__ == "__main__": block)

# Initialize Accelerator
accelerator = Accelerator()

# World size and local rank: the WORLD_SIZE / LOCAL_RANK environment variables
# take precedence; the Accelerator's values are used only when they are unset.
# These two names are used below for torch.distributed initialisation and log
# messages, while data sharding uses accelerator.num_processes / process_index.
ddp_world_size = int(os.getenv("WORLD_SIZE", accelerator.num_processes))
ddp_local_rank = int(os.getenv("LOCAL_RANK", accelerator.local_process_index))

# Use accelerator's properties for device to ensure consistency
device = accelerator.device
print(f"Process {ddp_local_rank}/{ddp_world_size} using device: {device}")

# Load dataset
print(f"Loading dataset {args.dataset}...")
# NOTE(review): the rows are expected to carry the fields the worker needs
# (the original notes mention 'question', 'code', 'pred', 'score') - confirm
# against math-rm/rm_classify.py and remap columns beforehand if necessary.
if os.path.isfile(args.dataset):
    # An existing local file is read with the JSON builder directly, instead of
    # first being tried as a Hub dataset id and recovered through an exception.
    ds = load_dataset("json", data_files={"test": args.dataset}, split="test")
else:
    try:
        ds = load_dataset(args.dataset, split="test")
    except Exception as e:
        # Kept as a last resort for paths that are neither a local file nor loadable by name.
        print(f"Failed to load dataset directly. Trying as json: {e}")
        ds = load_dataset("json", data_files={"test": args.dataset}, split="test")


# -1 means "use everything"; otherwise cap the request at the dataset size.
if args.num_test_sample == -1:
    num_sample = len(ds)
else:
    num_sample = min(args.num_test_sample, len(ds))
ds = ds.select(range(num_sample))
print(f"Selected {len(ds)} samples for processing.")

# Load model and tokenizer
print(f"Loading reward model {args.reward_name_or_path}...")
# Retry a bounded number of times: transient download hiccups get another
# chance, but a permanent failure (wrong model id, out of memory, ...) is
# re-raised instead of looping forever.
max_load_attempts = 5
for load_attempt in range(1, max_load_attempts + 1):
    try:
        tokenizer = AutoTokenizer.from_pretrained(args.reward_name_or_path)
        model = AutoModelForCausalLM.from_pretrained(
            args.reward_name_or_path,
            torch_dtype=torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float32
        ).to(device).eval() # Model to the device determined by Accelerator
        print("Model and tokenizer loaded successfully.")
        break
    except Exception as error:
        print(f"An error occurred during model loading: {error}")
        if load_attempt == max_load_attempts:
            raise
        print("Retrying in 2 seconds...")
        time.sleep(2)

# Right-side padding, and make sure both tokenizer and model config have a pad
# token (falling back to the EOS token when none is defined).
tokenizer.padding_side = "right"
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
if model.config.pad_token_id is None:
    model.config.pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else model.config.eos_token_id


# Split the dataset into one contiguous shard per process and score this
# process's shard with the worker.
data_size = len(ds)
num_shards = accelerator.num_processes
shard_rank = accelerator.process_index

# Ceiling division, so the shards together cover every sample.
share = -(-data_size // num_shards)
start_idx = shard_rank * share
end_idx = min(start_idx + share, data_size)

# Rows [start_idx, end_idx) belong to this process.
current_ds_slice = ds.select(np.arange(start_idx, end_idx))
print(f"Process {accelerator.process_index}/{accelerator.num_processes}: processing {len(current_ds_slice)} samples (indices {start_idx}-{end_idx-1}).")

# Materialise the shard as a plain list of rows for the worker.
data_for_worker = list(current_ds_slice)

# `device` (e.g. 'cpu' or 'cuda:X') is passed in the worker's local_rank slot.
print(f"Starting worker on process {accelerator.process_index}...")
selected_data, new_data = worker(args, model, tokenizer, data_for_worker, device)
print(f"Worker finished on process {accelerator.process_index}. Results: {len(selected_data)} classifications, {len(new_data)} processed samples.")

# Distributed data gathering
# If running in a distributed environment (e.g. via accelerate launch)
if accelerator.num_processes > 1:
    # Ensure MASTER_ADDR and MASTER_PORT are set if not already by launch environment for torch.dist
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12355') # Ensure this port is free or configurable

    if not dist.is_initialized():
        # Pick the backend from the device actually in use.
        backend = 'nccl' if accelerator.device.type == 'cuda' else 'gloo'
        # init_process_group needs the global rank and world size. Use the same
        # Accelerator values that drive sharding and the gather list below, so
        # the three can never disagree (LOCAL_RANK differs from the global rank
        # on multi-node launches).
        dist_rank = accelerator.process_index
        dist_world_size = accelerator.num_processes
        print(f"Process {dist_rank}: Initializing torch.distributed with backend {backend}, world_size {dist_world_size}, rank {dist_rank}")
        dist.init_process_group(
            backend=backend,
            rank=dist_rank,
            world_size=dist_world_size
        )

# Bundle this rank's outputs into one picklable object for all_gather_object.
data_to_send_from_this_rank = {
    "selected_data_payload": selected_data,
    "new_data_payload": new_data
}

if accelerator.num_processes > 1:
    # One slot per rank; all_gather_object fills the list in place.
    gathered_dictionaries_list = [None] * accelerator.num_processes
    dist.all_gather_object(gathered_dictionaries_list, data_to_send_from_this_rank)
else:
    gathered_dictionaries_list = [data_to_send_from_this_rank]

# Filled on the main process only; left empty on every other rank.
gathered_classification_results = []
gathered_full_samples = []

# Process gathered data (only on the main process)
if accelerator.is_main_process:
    print("Main process gathering results...")
    # Merge the per-rank payloads; empty slots are skipped.
    for data_from_rank_i in gathered_dictionaries_list:
        if data_from_rank_i:
            gathered_classification_results.extend(data_from_rank_i["selected_data_payload"])
            gathered_full_samples.extend(data_from_rank_i["new_data_payload"])

    print(f"Total gathered classification results: {len(gathered_classification_results)}")
    print(f"Total gathered samples for saving: {len(gathered_full_samples)}")

    # Calculate metrics
    # The worker's labels are deliberately flipped here (TP<->TN, FP<->FN) so
    # that, as the message printed below states, "positive" means model failure.
    counter = Counter(gathered_classification_results)
    num_TN = counter["TP"]
    num_FP = counter["FN"]
    num_FN = counter["FP"]
    num_TP = counter["TN"]

    precision = 0
    recall = 0
    f1 = 0

    if num_TP + num_FP > 0:
        precision = num_TP / (num_TP + num_FP)
    if num_TP + num_FN > 0:
        recall = num_TP / (num_TP + num_FN)
    if precision + recall > 0:
        f1 = 2 * precision * recall / (precision + recall)

    accuracy = 0
    total_predictions = num_TN + num_TP + num_FN + num_FP
    if total_predictions > 0:
        accuracy = (num_TN + num_TP) / total_predictions

    # Reported as "captured failure case": the share of actual failures that
    # were flagged. The guard must test the same denominator that is divided
    # by, otherwise a run without any failure case raises ZeroDivisionError.
    # NOTE(review): with the flipped labels this value coincides with recall;
    # the "specificity" name is kept so the output format does not change.
    specificity = 0
    if num_TP + num_FN > 0:
        specificity = num_TP / (num_TP + num_FN)

    print(f"\nMetrics:\n")
    print("To stay consistent with the paper, positive means model failure, nagative means model success.\n")
    print(f"TP: {num_TP}, FN: {num_FN}, FP: {num_FP}, TN: {num_TN}")
    print(f"Total Predictions: {total_predictions}")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall (Sensitivity): {recall:.4f}")
    print(f"F1 Score: {f1:.4f}")
    print(f"Specificity (captured failure case): {specificity:.4f}")

    metrics_summary = {
        "TP": num_TP, "FN": num_FN, "FP": num_FP, "TN": num_TN,
        "total_predictions": total_predictions,
        "accuracy": accuracy, "precision": precision, "recall": recall, "f1_score": f1,
        "specificity": specificity,
        "num_test_samples_processed": len(gathered_full_samples) # Should match num_sample if all processed
    }

    # Explicit UTF-8: the metrics are dumped with ensure_ascii=False, so the
    # file encoding must not depend on the platform default.
    output_metrics_file = os.path.join(args.output_dir, f"size{args.num_test_sample}_thres1={args.pred_thres1}_thres2={args.pred_thres2}_metrics.json")
    with open(output_metrics_file, 'w', encoding='utf-8') as f:
        json.dump(metrics_summary, f, indent=4, ensure_ascii=False)
    print(f"Metrics saved to {output_metrics_file}")

    # One JSON object per line.
    output_data_file = os.path.join(args.output_dir, f"size{args.num_test_sample}_thres1={args.pred_thres1}_thres2={args.pred_thres2}_save_data.jsonl")
    with open(output_data_file, 'w', encoding='utf-8') as f:
        for entry in gathered_full_samples:
            f.write(json.dumps(entry) + "\n")
    print(f"Processed data saved to {output_data_file}")

# Tear down the torch.distributed process group, but only for multi-process
# runs in which a group is actually initialised.
if accelerator.num_processes > 1 and dist.is_initialized():
    dist.destroy_process_group()
    print(f"Process {ddp_local_rank}: Destroyed DDP process group.")

print("Classification finished.")

## 🔔 **Stage 2: Skill-based selection of in-context examples**

- AdaptMI uses skill-based _k_-shot examples for _difficult_ questions and fixed _k_-shot examples for _easy_ questions.
- AdaptMI+ focuses only on the skills that the model’s initial response lacks.

Please activate the environment `matheval`, and run the following cells:

In [None]:
# IPython magic: sets the TOKENIZERS_PARALLELISM environment variable to "true" for this kernel.
%env TOKENIZERS_PARALLELISM=true
import sys
import os

# Put <working directory>/evaluation at the front of sys.path (once) so that
# modules located in that folder can be imported by their bare names.
current_dir = os.getcwd()
evaluation_dir_path = os.path.join(current_dir, 'evaluation')

if evaluation_dir_path not in sys.path:
    sys.path.insert(0, evaluation_dir_path)

import os
import json
import logging
import random
import numpy as np
import torch
from types import SimpleNamespace
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer, AutoConfig

# Star-import: later cells call `main` and `load_model`, which are presumably
# provided by this module (verify against evaluation/math_eval.py).
from evaluation.math_eval import *

# Root logger at INFO with a timestamped format; `logger` is the module-level
# logger used by the functions defined in the following cells.
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO)
logger = logging.getLogger(__name__)

In [None]:
class Args:
    """Configuration for the Stage 2 (AdaptMI) evaluation run.

    Every setting is stored as a plain instance attribute, in the order listed
    in ``settings`` below, so the object can be used like an argparse namespace.
    """

    def __init__(self):
        temperature = 0.0

        settings = {
            # --- data / model / output ---
            "data_names": "math-skill",  # Example: "gsm8k,math"
            "data_dir": "./evaluation/data",
            "data_path": "./output/stage1_classified/size50_thres1=0.9_thres2=0.7_save_data.jsonl",  # Default: None
            "model_name_or_path": "models/Qwen2.5-1.5B-Instruct",  # Replace with your model
            "output_dir": "./output/stage2_inference",
            "prompt_type": "qwen25-math-cot",
            "split": "test",
            "num_test_sample": 50,  # -1 for full data, set to a small number for testing
            "seed": 0,
            "start": 0,
            "end": -1,  # -1 for all samples from start
            # --- sampling ---
            "temperature": temperature,
            "n_sampling": 1,
            # top_p is pinned to 1 whenever temperature is 0 (greedy sampling);
            # otherwise the configured value (1.0) is kept.
            "top_p": 1 if temperature == 0 else 1.0,
            "max_tokens_per_call": 1024,  # Reduced for faster testing, adjust as needed
            "shuffle": True,
            "use_vllm": False,  # Set to False if not using vLLM or for models not supported well by vLLM
            "save_outputs": True,
            # --- Ours ---
            "LLM_judge": False,
            "PRM_judge": True,
            "random_shots": False,
            "llm_sol": False,
            # --- misc ---
            "overwrite": True,  # Set to True to overwrite existing output files
            "use_safetensors": True,  # Recommended
            "num_shots": 5,
            "num_skill_shots": 5,
            "apply_chat_template": False,  # Set to True if your model expects chatml or similar
            "pipeline_parallel_size": 1,
            "adapt_few_shot": False,
        }

        for option_name, option_value in settings.items():
            setattr(self, option_name, option_value)

In [None]:
def set_seed(seed_value: int):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed_value: Seed handed to ``random``, ``numpy.random`` and ``torch``
            (and to all CUDA devices when CUDA is available).
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed_value)

    # CUDA generators are only touched when a GPU is actually available.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed_value)

    logger.info(f"Set seed to {seed_value}")

In [None]:
def run_eval(args_obj: Args): # Type hint Args_obj with our class
    """Load the model, evaluate it on every requested dataset and print an accuracy table.

    ``load_model`` and ``main`` are not defined in this cell; they are presumably
    supplied by the star-import of ``evaluation.math_eval`` (verify there).

    Args:
        args_obj: Configuration object. This function itself reads ``use_vllm``,
            ``model_name_or_path``, ``pipeline_parallel_size``,
            ``apply_chat_template`` and ``data_names``; the whole object is also
            forwarded to ``load_model`` and ``main``.

    Returns:
        None. The summary is printed and logged.

    Raises:
        Exception: Any error raised while loading the model (vLLM or Hugging
            Face path) is logged and re-raised. A tokenizer failure on the vLLM
            chat-template path is only logged.
    """
    # Ensure CUDA_VISIBLE_DEVICES is set, or vLLM might default unexpectedly
    if args_obj.use_vllm and not os.environ.get("CUDA_VISIBLE_DEVICES"):
        logger.warning("CUDA_VISIBLE_DEVICES is not set. vLLM might default to all available GPUs or the first one.")

    # When the variable is unset this yields a single entry ("0").
    available_gpus = os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")

    llm_instance = None
    tokenizer_instance = None

    if args_obj.use_vllm:
        logger.info(f"Attempting to load model {args_obj.model_name_or_path} with vLLM.")
        # No model-specific rope_scaling is passed here; if a model needs it,
        # it has to be added to the LLM constructor arguments.
        try:
            llm_instance = LLM(
                model=args_obj.model_name_or_path,
                tensor_parallel_size=max(1, len(available_gpus) // args_obj.pipeline_parallel_size), # Ensure at least 1
                pipeline_parallel_size=args_obj.pipeline_parallel_size,
                dtype="bfloat16" if torch.cuda.is_bf16_supported() else "float16",
                trust_remote_code=True,
            )
            logger.info(f"vLLM loaded model: {args_obj.model_name_or_path}")
        except Exception as e:
            logger.error(f"Failed to load model with vLLM: {e}")
            raise

        if args_obj.apply_chat_template:
            try:
                tokenizer_instance = AutoTokenizer.from_pretrained(
                    args_obj.model_name_or_path, trust_remote_code=True
                )
                logger.info(f"Tokenizer loaded for chat template: {args_obj.model_name_or_path}")
            except Exception as e:
                # Deliberately best-effort: evaluation continues with tokenizer_instance = None.
                logger.error(f"Failed to load tokenizer for chat template: {e}")
    else:
        logger.info(f"Attempting to load model {args_obj.model_name_or_path} with Hugging Face Transformers.")
        try:
            tokenizer_instance, llm_instance, _ = load_model(args_obj.model_name_or_path, args_obj)
            logger.info(f"Hugging Face model and tokenizer loaded: {args_obj.model_name_or_path}")
        except Exception as e:
            logger.error(f"Failed to load model with Hugging Face: {e}")
            raise

    # Infer & eval
    # Blank entries (e.g. "gsm8k,,math") are dropped once, here, so that dataset
    # names and results stay index-aligned for the summary below.
    dataset_names = [name.strip() for name in args_obj.data_names.split(",") if name.strip()]
    results = []
    for data_name in dataset_names:
        logger.info(f"\nProcessing dataset: {data_name}")
        results.append(main(llm_instance, tokenizer_instance, data_name, args_obj))

    if not results:
        logger.info("No datasets processed or no results returned.")
        return

    # Results were appended in dataset order, so pairing by position is exact.
    summary_rows = list(zip(dataset_names, results))

    if len(summary_rows) > 1:
        # Calculate average accuracy if multiple datasets were processed
        scored_results = [res for res in results if res and "acc" in res]
        if scored_results:
            avg_acc = sum(res["acc"] for res in scored_results) / len(scored_results)
            summary_rows.append(("avg", {"acc": avg_acc, "data_name": "avg"}))
        else:
            logger.warning("No valid results with 'acc' key found to calculate average.")

    logger.info("\n" + "="*20 + " Overall Summary " + "="*20)

    pad_width = max(len(name) for name, _ in summary_rows)

    header_parts = []
    score_parts = []
    for name, current_res in summary_rows:
        header_parts.append(name.ljust(pad_width))

        # A result may be None or otherwise not a dict; only dicts can carry a
        # conflicting data_name, and only then is the by-order pairing worth flagging.
        if isinstance(current_res, dict) and current_res.get("data_name", name) != name:
            logger.warning(f"Result for {name} matched by order, not explicit data_name key in result dict.")

        if current_res and "acc" in current_res:
            score_parts.append(f"{current_res['acc']:.1f}".ljust(pad_width))
        else:
            score_parts.append("N/A".ljust(pad_width))
            logger.warning(f"Accuracy not found for dataset: {name}")

    final_header = "\t".join(header_parts)
    final_scores = "\t".join(score_parts)
    print("\nResults Summary:") # Print to console for easy viewing
    print(final_header)
    print(final_scores)
    logger.info("Final results summary (also printed above):")
    logger.info(f"Datasets: {final_header}")
    logger.info(f"Accuracy: {final_scores}")

In [None]:
# --- AdaptMI Evaluation ---
# Build the configuration, seed the RNGs from it, then run the evaluation.
args = Args()

set_seed(args.seed)
run_eval(args)

## ✨ Evaluation Summary
Please run this cell to get a summary of AdaptMI performance.

In [None]:
import json
import numpy as np

def analyze(jsonl_file):
    """Print accuracy before and after AdaptMI for one JSONL results file.

    Each non-blank line must be a JSON object with:
      * ``initial_score`` and ``score`` - containers; a question counts as
        solved when ``True`` is among their elements,
      * ``prm_pred`` - truthy marks the question as easy, falsy as difficult.

    Accuracies are printed as percentages with two decimals. A group without
    any question (empty file, no easy or no difficult questions) is reported
    as ``N/A`` instead of raising ``ZeroDivisionError``.

    Args:
        jsonl_file: Path of the UTF-8 encoded JSONL file to summarise.

    Returns:
        None. The summary is written to standard output.
    """

    def _pct(num_correct, num_questions):
        # Same arithmetic as 100 * (correct / total), guarded against empty groups.
        if not num_questions:
            return "N/A"
        return f"{(100 * (num_correct / num_questions)):.2f}"

    # Per group: [number of questions, solved before, solved after].
    tallies = {"all": [0, 0, 0], "easy": [0, 0, 0], "difficult": [0, 0, 0]}

    with open(jsonl_file, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines, e.g. a trailing newline
            data = json.loads(line)
            solved_before = True in data["initial_score"]
            solved_after = True in data["score"]
            difficulty = "easy" if data["prm_pred"] else "difficult"
            for group in ("all", difficulty):
                tally = tallies[group]
                tally[0] += 1
                tally[1] += solved_before
                tally[2] += solved_after

    total, correct_before, correct_after = tallies["all"]
    easy_all, easy_correct_before, easy_correct_after = tallies["easy"]
    diff_all, diff_correct_before, diff_correct_after = tallies["difficult"]

    print(f"📎 Initial accuracy: {_pct(correct_before, total)}")
    print(f"✨ AdaptMI accuracy: {_pct(correct_after, total)}")
    # The easy-question numbers are reported as well, so a change on that
    # subset is visible next to the difficult-question numbers.
    print(f"Initial accuracy on easy questions: {_pct(easy_correct_before, easy_all)}\nFinal accuracy on easy questions: {_pct(easy_correct_after, easy_all)}")
    print(f"Initial accuracy on difficult questions: {_pct(diff_correct_before, diff_all)}\nFinal accuracy on difficult questions: {_pct(diff_correct_after, diff_all)} 🚀")

# Summarise the Stage 2 results file (the path is relative to the working directory).
analyze("output/stage2_inference/test_50_5+0shots.jsonl")