In [None]:
import os
# Restrict this process to the first GPU and record how many GPUs are in use.
os.environ.update({"CUDA_VISIBLE_DEVICES": "0"})
num_gpus = 1

In [None]:
# unsloth stays first, exactly as in the original cell ordering
# (it is expected to be imported before transformers / peft).
import unsloth

import json
import tempfile

import lm_eval
import torch
from lm_eval import evaluator, tasks
from lm_eval.utils import setup_logging
from peft import PeftModel, PeftConfig
from transformers import AutoTokenizer, AutoModelForCausalLM
from unsloth.chat_templates import get_chat_template
# Initialise lm-eval's logging helper with the "INFO" level string.
setup_logging("INFO") 


In [None]:
def _hf_model_args(repo_id, *, tokenizer=None, trust_remote_code=True):
    """Build an lm-eval ``model_args`` string for a Hugging Face checkpoint.

    Args:
        repo_id: Hub repository id used for ``pretrained=``.
        tokenizer: Optional repository id for an explicit ``tokenizer=`` entry.
        trust_remote_code: When true, append ``trust_remote_code=True``.

    Returns:
        A comma-separated ``key=value`` string with no leading or trailing
        comma, so further arguments (e.g. ``,peft=...``) can be appended safely.
    """
    parts = [f"pretrained={repo_id}"]
    if tokenizer:
        parts.append(f"tokenizer={tokenizer}")
    parts.append("dtype=bfloat16")
    if trust_remote_code:
        parts.append("trust_remote_code=True")
    return ",".join(parts)


# Pretrained (base) and instruction-tuned ("it") checkpoints, smallest to largest.
gemma3_270m_pt_args = _hf_model_args("google/gemma-3-270m")
gemma3_270m_it_args = _hf_model_args("google/gemma-3-270m-it")

gemma3_1b_pt_args = _hf_model_args("google/gemma-3-1b-pt")
gemma3_1b_it_args = _hf_model_args("google/gemma-3-1b-it")

gemma3_4b_pt_args = _hf_model_args("google/gemma-3-4b-pt")
# The 4b-it entry pins the tokenizer explicitly and does not set
# trust_remote_code, as in the original definition. The original string ended
# with a trailing comma, which produced an empty ",," segment once
# ",peft=<adapter>" was appended; that trailing comma is removed here.
gemma3_4b_it_args = _hf_model_args(
    "google/gemma-3-4b-it",
    tokenizer="google/gemma-3-4b-it",
    trust_remote_code=False,
)

gemma3_12b_pt_args = _hf_model_args("google/gemma-3-12b-pt")
gemma3_12b_it_args = _hf_model_args("google/gemma-3-12b-it")

gemma3_27b_pt_args = _hf_model_args("google/gemma-3-27b-pt")
gemma3_27b_it_args = _hf_model_args("google/gemma-3-27b-it")

In [None]:
# LoRA adapters per instruction-tuned base model, keyed by task type.
# Each row is (base model id, classification adapter, question-answering adapter).
# NOTE(review): the classification adapter listed for "google/gemma-3-4b-it"
# names a *1b* repository — looks like a copy/paste slip; confirm the intended
# 4b adapter id before evaluating that model.
gemma3_lora_adapters = {
    base_model: {
        "classification": cls_adapter,
        "question_answering": qa_adapter,
    }
    for base_model, cls_adapter, qa_adapter in (
        (
            "google/gemma-3-270m-it",
            "Mhara/google_gemma-3-270m-it_ft_ag_news_v3",
            "Mhara/google_gemma-3-270m-it_ft_squad_v2",
        ),
        (
            "google/gemma-3-1b-it",
            "Mhara/google_gemma-3-1b-it_ft_ag_news_v2",
            "Mhara/google_gemma-3-1b-it_ft_squad_v2",
        ),
        (
            "google/gemma-3-4b-it",
            "Mhara/google_gemma-3-1b-it_ft_ag_news_v3",
            "Mhara/google_gemma-3-4b-it_ft_squad_v2",
        ),
        (
            "google/gemma-3-12b-it",
            "Mhara/google_gemma-3-12b-it_ft_ag_news",
            "Mhara/google_gemma-3-12b-it_ft_squad_v2",
        ),
        (
            "google/gemma-3-27b-it",
            "Mhara/google_gemma-3-27b-it_ft_ag_news",
            "Mhara/google_gemma-3-27b-it_ft_squad_v2",
        ),
    )
}

In [None]:
def load_adapter(base_model_id, adapter_id):
    """Load a base causal LM with a PEFT adapter attached, plus its tokenizer.

    Args:
        base_model_id: Fallback base-model id, used only when the adapter's
            PEFT config does not record ``base_model_name_or_path``.
        adapter_id: Repository id (or path) of the PEFT adapter.

    Returns:
        ``(model, tokenizer)``: the adapter-wrapped model switched to eval
        mode, and the tokenizer returned by ``get_chat_template`` for the
        "gemma3" template.
    """
    cfg = PeftConfig.from_pretrained(adapter_id)
    # The base model recorded in the adapter config takes precedence over the argument.
    base_id = cfg.base_model_name_or_path or base_model_id

    tokenizer = AutoTokenizer.from_pretrained(base_id, use_fast=True, trust_remote_code=True)
    base = AutoModelForCausalLM.from_pretrained(
        base_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    model = PeftModel.from_pretrained(base, adapter_id)
    model.eval()

    # Return the tokenizer produced by get_chat_template; previously its result
    # was assigned to an unused name and the un-templated tokenizer was returned.
    tokenizer = get_chat_template(tokenizer, chat_template="gemma3")

    return model, tokenizer


In [None]:
def _result_path(model_id, task, n_shot, ds_name, adapter_id=None):
    adapter_suffix = f"{adapter_id.split('/')[-1]}" if adapter_id else ""
    return f"results/{adapter_suffix}/{task}_{n_shot}shot_{ds_name}.json"

def _result_exists_and_valid(path: str) -> bool:
    if not os.path.exists(path):
        return False
    try:
        with open(path, "r") as f:
            data = json.load(f)
    except Exception:
        return False

    metric_keys = ("acc", "acc_norm", "em", "f1",
                   "acc,none", "acc_norm,none", "em,none", "f1,none")
    return any(k in data for k in metric_keys)


def _safe_save_json(path: str, obj: dict):
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with tempfile.NamedTemporaryFile("w", delete=False, dir=os.path.dirname(path), suffix=".tmp") as tmp:
        json.dump(obj, tmp, indent=2)
        tmp_path = tmp.name
    os.replace(tmp_path, path)

In [None]:
def clean_gpu():
    """Run a best-effort shell script that kills GPU-related processes.

    The script, executed through ``os.system``:
      1. echoes a status message;
      2. ``pkill -f`` processes whose command line matches "vllm",
         "engine_core" or "torchrun";
      3. sleeps for 2 seconds;
      4. runs ``fuser -k`` on ``/dev/nvidia*``.

    Every kill command is followed by ``|| true``, and the exit status of
    ``os.system`` is not inspected, so failures are ignored.

    NOTE(review): ``fuser -k /dev/nvidia*`` targets every process that has the
    NVIDIA device files open — this may include the notebook kernel itself
    once it holds a CUDA context. Confirm this is intended.
    """
    # Local import; `os` is also imported at the top of the notebook.
    import os
    os.system("""
    echo "Cleaning up vLLM and CUDA contexts"
    pkill -f "vllm" || true
    pkill -f "engine_core" || true
    pkill -f "torchrun" || true
    sleep 2
    fuser -k /dev/nvidia* || true
    """)
# Clean up once when this cell is executed.
clean_gpu()

In [None]:
# Task type -> lm-eval task names to run for that type.
datasets_to_evaluate_on = dict(
    question_answering=["squadv2"],  # SQuAD
)

In [None]:
# In-context-learning sweep: few-shot counts and named decoding strategies.
icl_variants = dict(
    k_shot=[0, 5],
    decoding_strategy=dict(
        default=dict(
            temperature=1,
            top_p=0.95,
            top_k=64,
            max_gen_toks=125,
            do_sample=True,
        ),
    ),
)

In [None]:
# Evaluation settings for every Gemma-3 size, in "pt" then "it" order per size.
# All entries use 0-shot, the "default" strategy label, and share the same
# default decoding-kwargs dict from `icl_variants`.
gemma3_models_evaluation_values_full = {
    f"google/gemma-3-{size}-{variant}": {
        "n_shot": 0,
        "ds_name": "default",
        "ds_kwargs": icl_variants["decoding_strategy"]["default"],
    }
    for size in ("270m", "1b", "4b", "12b", "27b")
    for variant in ("pt", "it")
}

In [None]:
# Models selected for this run, mapped to their lm-eval model-args string.
gemma3_models = {"google/gemma-3-270m-it": gemma3_270m_it_args}

# Per-model evaluation settings for this run.
# NOTE(review): the adapter chosen here is the "classification" one, while
# `datasets_to_evaluate_on` (as visible in this notebook) only lists
# question-answering tasks — confirm this cross-task pairing is intended.
gemma3_models_evaluation_values = {
    "google/gemma-3-270m-it": dict(
        n_shot=0,
        ds_name="default",
        ds_kwargs=icl_variants["decoding_strategy"]["default"],
        adapter=gemma3_lora_adapters["google/gemma-3-270m-it"]["classification"],
    ),
}

In [None]:
# Run lm-eval for every selected instruction-tuned model, every task, and every
# few-shot setting, caching each run's metrics as JSON under results/.
model_name = "hf"
os.makedirs("results", exist_ok=True)

for model_id, model_args in gemma3_models.items():
    # Only ids ending in "it" are evaluated in this cell.
    if not model_id.endswith("it"):
        continue

    # Accept either a ready-made args string or a sequence of fragments.
    if isinstance(model_args, (list, tuple)):
        base_model_args_str = ",".join(str(part) for part in model_args if part)
    else:
        base_model_args_str = str(model_args)
    # Strip stray commas so appending ",peft=..." never creates an empty
    # ",," segment in the model-args string.
    base_model_args_str = base_model_args_str.strip().strip(",")

    eval_cfg = gemma3_models_evaluation_values[model_id]
    ds_name = eval_cfg.get("ds_name", "default")

    # Decoding settings were previously read but never handed to the evaluator.
    # They are serialised as "key=value,..." and passed as gen_kwargs.
    ds_kwargs = eval_cfg.get("ds_kwargs") or {}
    gen_kwargs_str = ",".join(f"{key}={value}" for key, value in ds_kwargs.items()) or None

    # The adapter is optional: entries without an "adapter" key evaluate the
    # plain base model instead of raising KeyError.
    adapter_id = eval_cfg.get("adapter")

    model_args_str = base_model_args_str
    if adapter_id:
        model_args_str += f",peft={adapter_id}"

    for task_type, datasets in datasets_to_evaluate_on.items():
        for task in datasets:
            # The few-shot counts come from icl_variants["k_shot"]; the
            # per-model "n_shot" config value is not used by this loop.
            for n_shot in icl_variants["k_shot"]:
                out_path = _result_path(model_id, task, n_shot, ds_name, adapter_id)
                if _result_exists_and_valid(out_path):
                    print(f"Skip (already done): {out_path}")
                    continue

                print(f"\n🔹 Evaluating model: {model_id} | tasks={task} | {n_shot}-shot | strategy={ds_name}")
                results = evaluator.simple_evaluate(
                    model=model_name,
                    model_args=model_args_str,
                    tasks=[task],
                    num_fewshot=n_shot,
                    device="cuda:0",
                    batch_size="auto",
                    gen_kwargs=gen_kwargs_str,
                )["results"]

                _safe_save_json(out_path, results[task])

                # Clean up only after an actual evaluation; skipped runs have
                # nothing to release.
                # NOTE(review): clean_gpu() kills processes holding the NVIDIA
                # devices, which may include this kernel — confirm.
                clean_gpu()
