# Multi-model inference notebook

This notebook downloads/caches the models enabled in `MODEL_IDS` into `CACHE_DIR`, runs each of them on the full SAQ and MCQ test sets, and writes per-model tab-separated submission files into `RESULTS_DIR`.

Models:
- `mistralai/Mistral-7B-Instruct-v0.2`
- `meta-llama/Meta-Llama-3-8B` (base, non-instruct model; currently disabled in `MODEL_IDS`)
- `meta-llama/Meta-Llama-3-8B-Instruct`

> Note: Meta Llama models require accepting the license on Hugging Face and being logged in (`huggingface-cli login`).

In [2]:
# Core deps
import os
import re
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import pandas as pd
import torch
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

In [3]:
# Paths — defaults are the original cluster locations; set QA_CACHE_DIR,
# QA_RESULTS_DIR or QA_DATA_DIR to run on another machine without editing the notebook.
CACHE_DIR = os.path.abspath(os.environ.get("QA_CACHE_DIR", "/data/cat/ws/albu670g-qa-model/models"))
RESULTS_DIR = os.path.abspath(os.environ.get("QA_RESULTS_DIR", "/home/h5/albu670g/qa-model/results"))
DATA_DIR = os.path.abspath(os.environ.get("QA_DATA_DIR", "/home/h5/albu670g/qa-model/data"))

os.makedirs(CACHE_DIR, exist_ok=True)
os.makedirs(RESULTS_DIR, exist_ok=True)

# Device: half precision on GPU; full precision when only a CPU is available.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

print("CACHE_DIR:", CACHE_DIR)
print("RESULTS_DIR:", RESULTS_DIR)
print("DEVICE:", DEVICE, "| DTYPE:", DTYPE)


CACHE_DIR: /data/cat/ws/albu670g-qa-model/models
RESULTS_DIR: /home/h5/albu670g/qa-model/results
DEVICE: cuda | DTYPE: torch.float16


In [4]:
# Model registry: Hugging Face Hub ids to evaluate, in run order.
# The base (non-instruct) "meta-llama/Meta-Llama-3-8B" checkpoint is currently
# excluded; add its id back to this list to include it in the run.
MODEL_IDS = [
    "mistralai/Mistral-7B-Instruct-v0.2",
    "meta-llama/Meta-Llama-3-8B-Instruct",
]

def safe_model_slug(model_id: str) -> str:
    """Turn a Hugging Face model id into a lowercase, filesystem-friendly slug.

    '/' becomes '__' and '-' becomes '_', e.g.
    'mistralai/Mistral-7B-Instruct-v0.2' -> 'mistralai__mistral_7b_instruct_v0.2'.
    """
    # One pass over the string; str.maketrans allows multi-character replacements.
    substitutions = str.maketrans({"/": "__", "-": "_"})
    return model_id.translate(substitutions).lower()


In [5]:
from dataclasses import dataclass
import gc
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

@dataclass
class ModelBundle:
    """A loaded causal LM together with its tokenizer and the Hub id it came from.

    load_model returns instances of this class; unload_model deletes the
    ``model`` and ``tokenizer`` attributes to release memory.
    """

    # Hugging Face Hub id passed to load_model, e.g. "mistralai/Mistral-7B-Instruct-v0.2".
    model_id: str
    # NOTE(review): annotated with the Auto* factory classes; the stored objects are
    # whatever AutoTokenizer / AutoModelForCausalLM.from_pretrained return in load_model.
    tokenizer: AutoTokenizer
    model: AutoModelForCausalLM

def load_model(model_id: str, cache_dir: str = CACHE_DIR) -> ModelBundle:
    """Load a tokenizer + causal LM for ``model_id``, caching files in ``cache_dir``.

    Args:
        model_id: Hugging Face Hub model id, e.g. ``"mistralai/Mistral-7B-Instruct-v0.2"``.
        cache_dir: Download cache directory. ``<cache_dir>/offload`` is passed as the
            offload folder, used if ``device_map="auto"`` places weights on disk.

    Returns:
        A ``ModelBundle`` holding the tokenizer and the model in eval mode.

    Notes:
    - Uses low_cpu_mem_usage to reduce peak RAM during weight loading.
    - Uses device_map="auto" so weights are placed directly on the target device(s)
      when possible (avoids a big CPU RAM spike + a later .to(device)).
    """
    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        cache_dir=cache_dir,
        use_fast=True,
    )

    # Ensure a padding token exists for padded/batched tokenization; reuse EOS
    # when the checkpoint does not define one.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        cache_dir=cache_dir,
        # NOTE: recent transformers releases warn that `torch_dtype` is deprecated
        # in favour of `dtype`; kept here for compatibility with older installs.
        torch_dtype=DTYPE,
        device_map="auto",
        low_cpu_mem_usage=True,
        # Derived from cache_dir (same path as before for the default cache_dir) so
        # that a custom cache_dir also relocates disk-offloaded weights.
        offload_folder=os.path.join(cache_dir, "offload"),
    )

    model.eval()
    return ModelBundle(model_id=model_id, tokenizer=tokenizer, model=model)

def unload_model(bundle: ModelBundle) -> None:
    """Drop the bundle's model/tokenizer references and release cached GPU memory.

    Best-effort and safe to call more than once: attributes that are already gone
    are skipped. The weights are only actually freed if nothing else still holds
    a reference to the model object.
    """
    # Delete each attribute independently so a missing `model` no longer prevents
    # `tokenizer` from being released, and only "already deleted" is tolerated.
    for attr in ("model", "tokenizer"):
        try:
            delattr(bundle, attr)
        except AttributeError:
            pass  # already unloaded

    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

In [6]:
def build_prompt(tokenizer: AutoTokenizer, system_prompt: str, user_prompt: str) -> str:
    """Build a model-appropriate prompt string.

    If the tokenizer defines a chat template, the (stripped) system and user
    prompts are rendered through it with a generation prompt appended. If the
    template refuses a separate system turn, the system text is folded into the
    user turn instead. Tokenizers without a chat template (e.g. base models) get
    a plain ``[INST]...[/INST]`` string.

    Returns:
        The prompt text, not yet tokenized.
    """
    system_text = system_prompt.strip()
    user_text = user_prompt.strip()
    chat_template = getattr(tokenizer, "chat_template", None)

    if chat_template and hasattr(tokenizer, "apply_chat_template"):
        messages = [
            {"role": "system", "content": system_text},
            {"role": "user", "content": user_text},
        ]
        try:
            return tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
        except Exception:
            # Some chat templates raise when given a "system" role. Retry once with
            # the instructions merged into the user turn; if that also fails, the
            # error propagates (with the first failure chained for context).
            merged = [{"role": "user", "content": f"{system_text}\n\n{user_text}"}]
            return tokenizer.apply_chat_template(
                merged,
                tokenize=False,
                add_generation_prompt=True,
            )

    # Fallback when no chat template exists (e.g., base models)
    # Keep the old fallback format to minimize behavior changes.
    return f"[INST]{system_text}\n{user_text}[/INST]"

@torch.inference_mode()
def generate_text(
    bundle: ModelBundle,
    prompt: str,
    max_new_tokens: int = 64,
    do_sample: bool = False,
    temperature: float = 0.0,
    top_p: float = 1.0,
) -> str:
    """Generate a continuation for a single prompt and return only the new text.

    Args:
        bundle: Loaded tokenizer + model.
        prompt: Fully formatted prompt text (e.g. from build_prompt).
        max_new_tokens: Upper bound on the number of generated tokens.
        do_sample: Sample instead of greedy decoding. When False, temperature and
            top_p are passed to ``generate`` as None.
        temperature: Sampling temperature, only used when ``do_sample`` is True.
        top_p: Nucleus-sampling threshold, only used when ``do_sample`` is True.

    Returns:
        The decoded continuation (prompt tokens removed, special tokens skipped),
        with surrounding whitespace stripped.
    """
    tok = bundle.tokenizer
    model = bundle.model

    # Chat templates may already render the BOS token into the prompt text; letting
    # the tokenizer add special tokens again would produce a doubled BOS, which can
    # degrade generation. Add them only when the prompt does not start with BOS
    # (e.g. the plain [INST] fallback format).
    bos = tok.bos_token
    add_special_tokens = not (bos and prompt.startswith(bos))

    inputs = tok(prompt, return_tensors="pt", add_special_tokens=add_special_tokens)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    prompt_len = inputs["input_ids"].shape[-1]

    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        temperature=temperature if do_sample else None,
        top_p=top_p if do_sample else None,
        pad_token_id=tok.eos_token_id,
        use_cache=True,
    )

    # Decode only the newly generated tokens, not the echoed prompt.
    gen = tok.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return gen.strip()


In [7]:
# Task-specific post-processing
_SAQ_ANSWER_RE = re.compile(r"\banswer\s*:\s*(.+)", re.IGNORECASE)
_MCQA_RE = re.compile(r"\b([A-D])\b")
# Reply that starts with an option letter followed by ")", ".", ":", "]" or nothing,
# e.g. "B. Rice", "b)", "(c).", "d".
_MCQ_LEADING_RE = re.compile(r"\W*([A-Da-d])(?:[).:\]]|\s*$)")
# Punctuation and markdown emphasis trimmed from SAQ answers ("*rice*." -> "rice").
_SAQ_STRIP_CHARS = ".\"'`!?:;()[]{}<>*_"

def parse_saq_answer(text: str) -> str:
    """Extract a one-word SAQ answer from the model output.

    Uses the text after the first ``Answer:`` (case-insensitive) if present,
    otherwise the whole output, and returns its first word lower-cased with
    surrounding punctuation / markdown emphasis removed.

    Returns:
        The lower-cased answer word, or ``"idk"`` if nothing usable remains.
    """
    m = _SAQ_ANSWER_RE.search(text)
    cand = m.group(1) if m else text

    # Drop leading whitespace, punctuation and markdown before taking the first
    # word, so "**Answer:** Rice" (captured as "** Rice") yields "rice", not "**".
    cand = cand.lstrip(_SAQ_STRIP_CHARS + " \t\r\n")

    # Take the first token-ish segment; strip punctuation.
    cand = re.split(r"\s+|/|,|\bor\b", cand, maxsplit=1, flags=re.IGNORECASE)[0]
    cand = cand.strip().strip(_SAQ_STRIP_CHARS)
    return cand.lower() if cand else "idk"

def parse_mcq_choice(text: str) -> str:
    """Extract a single choice letter A-D from the model output.

    Order of precedence:
    1. a reply that starts with an option letter ("B. ...", "b)", "(c)", "d");
    2. the first standalone upper-case A-D anywhere in the reply, searched in the
       original casing so the article "a" or the "d" in "I'd" is not mistaken
       for a choice;
    3. the deterministic fallback "A".
    """
    m = _MCQ_LEADING_RE.match(text.strip())
    if m:
        return m.group(1).upper()
    m = _MCQA_RE.search(text)
    if m:
        return m.group(1)
    # Deterministic fallback
    return "A"

def mcq_one_hot(mcq_ids: pd.Series, choices: List[str]) -> pd.DataFrame:
    """Turn per-question choice letters into a Codabench-style MCQ submission.

    Args:
        mcq_ids: Question ids; becomes the ``MCQID`` column (its index is kept).
        choices: One letter per question, aligned positionally with ``mcq_ids``.

    Returns:
        A frame with columns ``MCQID, A, B, C, D``; each letter column is True
        exactly where that letter was chosen.
    """
    letter_flags = {
        letter: [picked == letter for picked in choices]
        for letter in ("A", "B", "C", "D")
    }
    return pd.DataFrame({"MCQID": mcq_ids}).assign(**letter_flags)


In [8]:
# Inference runners
# System prompt for the short-answer (SAQ) task. It asks for an "Answer:" line,
# which is what parse_saq_answer searches for, and for "idk" when no answer can be
# given (parse_saq_answer also returns "idk" for empty output).
# The text is sent verbatim to the model, so any edit here changes model outputs.
SAQ_SYSTEM_PROMPT = """
Provide ONE word answer to the given question.

Give the answer in the following format:
Answer: *provided answer*.
Explanation: *provided explanation*.

If no answer can be provided:
Answer: idk.
Explanation: *provided explanation*.
""".strip()

# System prompt for the multiple-choice (MCQ) task; parse_mcq_choice reduces the
# model reply to a single A-D letter.
MCQ_SYSTEM_PROMPT = """
Answer the multiple choice question.
Pick only one option (A, B, C, or D) without explanation.
""".strip()

def run_saq(bundle: ModelBundle, saq_df: pd.DataFrame) -> pd.DataFrame:
    """Answer every SAQ question with ``bundle`` and return the submission frame.

    Each English question (``saq_df["en_question"]``) is wrapped as
    ``"Question: ..."`` under SAQ_SYSTEM_PROMPT, decoded greedily with at most
    64 new tokens, and reduced to one word by parse_saq_answer.

    Returns:
        A copy of ``saq_df[["ID"]]`` with an added ``answer`` column, one row per
        input row, in the original order.
    """
    questions = saq_df["en_question"].tolist()

    def _answer(question) -> str:
        chat = build_prompt(bundle.tokenizer, SAQ_SYSTEM_PROMPT, f"Question: {question}")
        completion = generate_text(bundle, chat, max_new_tokens=64, do_sample=False)
        return parse_saq_answer(completion)

    answers: List[str] = [
        _answer(question)
        for question in tqdm(questions, desc=f"SAQ | {bundle.model_id}")
    ]

    submission = saq_df[["ID"]].copy()
    submission["answer"] = answers
    return submission

def run_mcq(bundle: ModelBundle, mcq_df: pd.DataFrame) -> pd.DataFrame:
    """Pick an A-D option for every MCQ prompt with ``bundle``; return the submission.

    Each ``mcq_df["prompt"]`` is sent as the user turn under MCQ_SYSTEM_PROMPT,
    decoded greedily with at most 32 new tokens, reduced to one letter by
    parse_mcq_choice, and expanded into boolean A-D columns by mcq_one_hot.
    """
    prompts = mcq_df["prompt"].tolist()

    def _choose(question) -> str:
        chat = build_prompt(bundle.tokenizer, MCQ_SYSTEM_PROMPT, question)
        completion = generate_text(bundle, chat, max_new_tokens=32, do_sample=False)
        return parse_mcq_choice(completion)

    choices: List[str] = [
        _choose(question)
        for question in tqdm(prompts, desc=f"MCQ | {bundle.model_id}")
    ]
    return mcq_one_hot(mcq_df["MCQID"], choices)


In [9]:
# Load datasets (FULL sets, no sampling)
saq_path = os.path.join(DATA_DIR, "test_dataset_saq.csv")
mcq_path = os.path.join(DATA_DIR, "test_dataset_mcq.csv")

saq = pd.read_csv(saq_path)
mcq = pd.read_csv(mcq_path)

# Fail fast on schema drift: run_saq / run_mcq read these columns, and a KeyError
# there would only surface after a multi-GB model has already been loaded.
_missing_saq = {"ID", "en_question"} - set(saq.columns)
_missing_mcq = {"MCQID", "prompt"} - set(mcq.columns)
assert not _missing_saq, f"{saq_path} is missing columns: {sorted(_missing_saq)}"
assert not _missing_mcq, f"{mcq_path} is missing columns: {sorted(_missing_mcq)}"

print("SAQ rows:", len(saq))
print("MCQ rows:", len(mcq))
saq.head()


SAQ rows: 667
MCQ rows: 419


Unnamed: 0,ID,question,en_question,country
0,New-ko-01,در ایران محبوب‌ ترین انیمیشن کودکانه‌ای که معم...,What is the most popular children's animation ...,IR
1,New-in-03,What type of food from the UK is typically ser...,What type of food from UK is typically served ...,GB
2,New-gr-06,在中国年轻人中最受欢迎的食物是什么？,What is the most popular food in China among y...,CN
3,New-su-21,What are the common activities that seniors us...,What are the common activities that seniors us...,US
4,New-ch-72,在中国，哪个专业被认为最好就业？,Which major is considered most conducive to em...,CN


In [10]:
# Main loop: for every model, load it once, write its SAQ and MCQ submissions,
# and always unload it (even if inference fails) before the next model.
results_index = []

for model_id in MODEL_IDS:
    slug = safe_model_slug(model_id)
    saq_out_path = os.path.join(RESULTS_DIR, f"saq_submission__{slug}.tsv")
    mcq_out_path = os.path.join(RESULTS_DIR, f"mcq_submission__{slug}.tsv")

    print("\n" + "=" * 80)
    print("Loading:", model_id)

    bundle = None
    try:
        bundle = load_model(model_id)

        # Write each submission straight to disk; the temporary frames are not
        # kept alive across tasks.
        run_saq(bundle, saq).to_csv(saq_out_path, sep="\t", index=False)
        run_mcq(bundle, mcq).to_csv(mcq_out_path, sep="\t", index=False)

        results_index.append(
            {
                "model_id": model_id,
                "saq_submission": saq_out_path,
                "mcq_submission": mcq_out_path,
            }
        )
    finally:
        if bundle is not None:
            unload_model(bundle)

index_df = pd.DataFrame(results_index)
index_df


Loading: mistralai/Mistral-7B-Instruct-v0.2


tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/596 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/4.54G [00:00<?, ?B/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.94G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]

SAQ | mistralai/Mistral-7B-Instruct-v0.2:   0%|          | 0/667 [00:00<?, ?it/s]

MCQ | mistralai/Mistral-7B-Instruct-v0.2:   0%|          | 0/419 [00:00<?, ?it/s]


Loading: meta-llama/Meta-Llama-3-8B-Instruct


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

SAQ | meta-llama/Meta-Llama-3-8B-Instruct:   0%|          | 0/667 [00:00<?, ?it/s]

MCQ | meta-llama/Meta-Llama-3-8B-Instruct:   0%|          | 0/419 [00:00<?, ?it/s]

Unnamed: 0,model_id,saq_submission,mcq_submission
0,mistralai/Mistral-7B-Instruct-v0.2,/home/h5/albu670g/qa-model/results/saq_submiss...,/home/h5/albu670g/qa-model/results/mcq_submiss...
1,meta-llama/Meta-Llama-3-8B-Instruct,/home/h5/albu670g/qa-model/results/saq_submiss...,/home/h5/albu670g/qa-model/results/mcq_submiss...


## Notes

- If you get a 401/403 on Llama 3, you have not accepted the Meta license on Hugging Face and/or are not logged in.
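- If you are VRAM-limited, note that switching `DTYPE` to `torch.bfloat16` does not save memory: it also uses 2 bytes per parameter, the same as `torch.float16`. It is only numerically more robust, on GPUs that support it. To actually reduce VRAM, add 8-bit/4-bit loading via `bitsandbytes`, or let `device_map="auto"` offload layers to CPU/disk (slower).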
