# Load required packages

To install the packages required for this notebook on the HPC, please follow the 'Jupyter Kernel Creation' slides posted on OPAL.

In [None]:
# Import required libraries
import re
import ast
import json
from typing import List, Tuple, Optional

import numpy as np
import pandas as pd
 
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    Trainer,
    TrainingArguments
)

from datasets import Dataset as HFDataset, load_dataset

from peft import (
    get_peft_model,
    LoraConfig,
    TaskType
)

from sentence_transformers import SentenceTransformer, CrossEncoder
import faiss

In [None]:
import torch

# Load the model (Llama-8B or Mistral-7B)

Note that you need to be on a partition with GPUs (e.g. capella, alpha).

In [None]:
# Set device for GPU acceleration
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

In [None]:
# Load model and tokenizer
model_name = "mistralai/Mistral-7B-Instruct-v0.2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

In [None]:
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    dtype=torch.float16,
).to(device)

# Build Wikipedia RAG Index

In [None]:
# Build Wikipedia RAG index

COUNTRY_KEYWORDS = {
    "IR": [
        "iran", "iranian", "persia", "persian", "tehran", "isfahan", "shiraz",
        "nowruz", "farsi", "shia", "zoroastrian", "rials", "ayatollah"
    ],
    "CN": [
        "china", "chinese", "beijing", "shanghai", "mandarin", "dynasty",
        "confucius", "yuan", "lunar new year", "spring festival", "cantonese"
    ],
    "GB": [
        "united kingdom", "british", "britain", "england", "english", "london",
        "scotland", "scottish", "wales", "welsh", "pound sterling", "parliament"
    ],
    "US": [
        "united states", "american", "america", "washington", "new york",
        "thanksgiving", "fourth of july", "dollar", "congress"
    ]
}


GENERAL_CULTURE_KEYWORDS = [
    "culture", "tradition", "festival", "food", "music", "art", "history",
    "religion", "language", "custom", "heritage", "cuisine", "education",
    "sport", "holiday", "wedding", "family", "school", "university"
]

ALL_KEYWORDS = list(set(
    [kw for keywords in COUNTRY_KEYWORDS.values() for kw in keywords] +
    GENERAL_CULTURE_KEYWORDS
))


class WikipediaRAG:
    """Dense passage retriever with country-aware ordering and optional reranking.

    Passages are embedded with a SentenceTransformer, L2-normalised and stored in
    a FAISS inner-product index. Every passage carries a country tag ("general"
    when none is supplied) that ``retrieve`` uses to order candidates before an
    optional CrossEncoder rerank.
    """

    def __init__(
        self,
        embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2",
        reranker_model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
        enable_rerank: bool = True
    ):
        """Load the embedding model and, unless ``enable_rerank`` is False, the reranker."""
        self.encoder = SentenceTransformer(embedding_model)
        self.reranker = CrossEncoder(reranker_model) if enable_rerank else None
        self.index = None
        self.passages = []
        self.passage_countries = []

    def build_index(self, passages: List[str], countries: Optional[List[str]] = None, batch_size: int = 64):
        """Embed ``passages`` and build the FAISS index.

        Args:
            passages: Texts to index.
            countries: Optional country tag per passage; defaults to "general" for all.
            batch_size: Encoding batch size passed to the embedding model.

        Raises:
            ValueError: If ``passages`` is empty or ``countries`` has a different length.
        """
        # Validate before the (expensive) encoding step so mistakes fail fast.
        if not passages:
            raise ValueError("Cannot build an index from an empty passage list.")
        if countries and len(countries) != len(passages):
            raise ValueError(
                f"Got {len(countries)} country tags for {len(passages)} passages; lengths must match."
            )

        self.passages = list(passages)
        self.passage_countries = list(countries) if countries else ["general"] * len(passages)

        embeddings = self.encoder.encode(passages, show_progress_bar=True, batch_size=batch_size)
        embeddings = np.array(embeddings).astype('float32')
        # Normalised vectors + inner product == cosine similarity.
        faiss.normalize_L2(embeddings)

        self.index = faiss.IndexFlatIP(embeddings.shape[1])
        self.index.add(embeddings)

    def save_index(self, path: str):
        """Write the FAISS index and the passage/country metadata into directory ``path``.

        Raises:
            RuntimeError: If no index has been built or loaded yet.
        """
        import os
        if self.index is None:
            raise RuntimeError("No index to save: call build_index() or load_index() first.")
        os.makedirs(path, exist_ok=True)
        faiss.write_index(self.index, os.path.join(path, "wiki_index.faiss"))
        with open(os.path.join(path, "wiki_passages.json"), "w", encoding="utf-8") as f:
            json.dump({"passages": self.passages, "countries": self.passage_countries}, f)

    def load_index(self, path: str):
        """Load the FAISS index and passage metadata previously written by ``save_index``.

        Raises:
            ValueError: If the index and the metadata file do not describe the same
                number of passages (e.g. files from two different builds).
        """
        import os
        index = faiss.read_index(os.path.join(path, "wiki_index.faiss"))
        with open(os.path.join(path, "wiki_passages.json"), "r", encoding="utf-8") as f:
            data = json.load(f)
        passages = data["passages"]
        countries = data.get("countries", ["general"] * len(passages))

        # Check consistency before touching instance state, so a failed load
        # never leaves the object half-updated.
        if len(countries) != len(passages) or index.ntotal != len(passages):
            raise ValueError(
                f"Inconsistent index files in {path!r}: {index.ntotal} vectors, "
                f"{len(passages)} passages, {len(countries)} country tags."
            )

        self.index = index
        self.passages = passages
        self.passage_countries = countries

    def retrieve(
        self,
        query: str,
        top_k: int = 5,
        country: Optional[str] = None,
        rerank_k: int = 20
    ) -> List[Tuple[str, float]]:
        """Return up to ``top_k`` ``(passage, score)`` pairs for ``query``.

        Candidates are ordered as: passages tagged with ``country`` (if given),
        then "general" passages, then everything else. With a reranker, the
        first ``max(rerank_k, top_k)`` candidates are rescored and the returned
        scores are reranker scores; otherwise they are index similarities.

        Raises:
            RuntimeError: If no index has been built or loaded yet.
        """
        if self.index is None:
            raise RuntimeError("No index available: call build_index() or load_index() first.")
        if top_k <= 0:
            return []

        query_embedding = self.encoder.encode([query]).astype('float32')
        faiss.normalize_L2(query_embedding)

        # Over-fetch when filtering by country so enough matching passages survive.
        search_k = max(top_k * 5, rerank_k) if country else max(top_k, rerank_k)
        scores, indices = self.index.search(query_embedding, search_k)

        country_results = []
        general_results = []
        other_results = []

        for idx, score in zip(indices[0], scores[0]):
            # Skip out-of-range ids (the search can return -1 when fewer than
            # search_k vectors are available).
            if not 0 <= idx < len(self.passages):
                continue
            passage_country = self.passage_countries[idx]
            item = (self.passages[idx], float(score))

            if country and passage_country == country:
                country_results.append(item)
            elif passage_country == "general":
                general_results.append(item)
            else:
                other_results.append(item)

        final_results = country_results + general_results + other_results
        candidates = final_results[:max(rerank_k, top_k)]

        if self.reranker and candidates:
            pairs = [[query, passage] for passage, _ in candidates]
            rerank_scores = self.reranker.predict(pairs)
            ranked = sorted(zip(candidates, rerank_scores), key=lambda x: x[1], reverse=True)
            return [(passage, float(score)) for (passage, _), score in ranked[:top_k]]

        return final_results[:top_k]


def build_wiki_rag_index(index_dir: str = "./wiki_rag_index", chunk_size: int = 300, overlap: int = 50):
    """Load Wikipedia, build the RAG index and persist it.

    Articles from the "20231101.simple" configuration of ``wikimedia/wikipedia``
    are kept when any keyword from ``ALL_KEYWORDS`` occurs in their title or first
    1000 characters, tagged with the best-matching country from
    ``COUNTRY_KEYWORDS`` (or "general"), split into overlapping character chunks
    and indexed with ``WikipediaRAG``.

    Args:
        index_dir: Directory the index is saved to.
        chunk_size: Chunk length in characters.
        overlap: Number of characters shared by consecutive chunks.

    Returns:
        The populated ``WikipediaRAG`` instance.

    Raises:
        ValueError: If ``overlap`` is not in ``[0, chunk_size)``; the chunker
            would otherwise never advance.
    """
    if not 0 <= overlap < chunk_size:
        raise ValueError("overlap must satisfy 0 <= overlap < chunk_size")

    # A tail shorter than this is folded into the previous chunk instead of
    # becoming a chunk of its own.
    min_tail_chars = 100

    wiki_dataset = load_dataset(
        "wikimedia/wikipedia",
        "20231101.simple",
        split="train"
    )

    def get_article_country(article) -> str:
        """Return the country code whose keywords occur most often, else "general"."""
        title_lower = article["title"].lower()
        text_lower = article["text"][:2000].lower()
        combined = title_lower + " " + text_lower

        scores = {}
        for country, keywords in COUNTRY_KEYWORDS.items():
            # NOTE: plain substring test, so a keyword also matches inside longer words.
            score = sum(1 for kw in keywords if kw in combined)
            if score > 0:
                scores[country] = score

        return max(scores, key=scores.get) if scores else "general"

    def is_relevant_article(article) -> bool:
        """True when any keyword occurs in the title or the first 1000 characters."""
        title_lower = article["title"].lower()
        text_lower = article["text"][:1000].lower()
        return any(kw in title_lower or kw in text_lower for kw in ALL_KEYWORDS)

    def chunk_article(article, chunk_size=chunk_size, overlap=overlap) -> List[dict]:
        """Split an article into overlapping, title-prefixed character chunks."""
        title = article["title"]
        text = article["text"]
        country = article.get("detected_country", "general")

        chunks = []
        start = 0
        while start < len(text):
            end = start + chunk_size
            # If the text left after this chunk would be too short to stand alone,
            # extend this chunk to the end so the tail is kept rather than dropped.
            if len(text) - (end - overlap) < min_tail_chars:
                end = len(text)
            chunks.append({
                "text": f"[{title}] {text[start:end]}",
                "country": country,
                "title": title
            })
            if end >= len(text):
                break
            start = end - overlap
        return chunks

    cultural_articles = []
    for a in wiki_dataset:
        if is_relevant_article(a):
            cultural_articles.append({**a, "detected_country": get_article_country(a)})

    wiki_passages = []
    wiki_passage_countries = []

    for article in cultural_articles:
        for chunk in chunk_article(article):
            wiki_passages.append(chunk["text"])
            wiki_passage_countries.append(chunk["country"])

    rag = WikipediaRAG()
    rag.build_index(wiki_passages, wiki_passage_countries)

    rag.save_index(index_dir)

    return rag


import os

# Reuse the cached index only when BOTH files written by save_index() exist;
# load_index() reads the FAISS file and the passage JSON, so checking only the
# former would crash on a partially written cache instead of rebuilding.
WIKI_INDEX_DIR = "./wiki_rag_index"
_wiki_index_files = [
    os.path.join(WIKI_INDEX_DIR, file_name)
    for file_name in ("wiki_index.faiss", "wiki_passages.json")
]

if all(os.path.exists(file_path) for file_path in _wiki_index_files):
    wiki_rag = WikipediaRAG()
    wiki_rag.load_index(WIKI_INDEX_DIR)
    print(f"Loaded RAG index with {wiki_rag.index.ntotal} passages")
else:
    wiki_rag = build_wiki_rag_index()
    print(f"Built RAG index with {wiki_rag.index.ntotal} passages")

In [None]:
# Common utilities for SAQ and MCQ

COUNTRY_CODE_MAP = {
    "Iran": "IR",
    "China": "CN",
    "UK": "GB",
    "United Kingdom": "GB",
    "US": "US",
    "United States": "US",
}

def clear_gpu(names=("model", "saq_trainer", "mcq_trainer")):
    """Free GPU memory between training/inference stages.

    Removes the given notebook-level variables (if they exist), runs the garbage
    collector and releases cached CUDA memory.

    The tokenizer is deliberately NOT removed: the inference cells that follow a
    call to this function still use the global ``tokenizer``, so deleting it
    made them fail with a NameError.

    Args:
        names: Names of global variables to delete; missing names are ignored.
    """
    import gc
    module_globals = globals()
    for name in names:
        module_globals.pop(name, None)
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

def load_base_model(adapter_path: str):
    """Load the base model and attach a saved LoRA adapter.

    The base weights are loaded in float16, matching the other model loads in
    this notebook; without an explicit dtype the 7B model may be materialised in
    a wider default precision and need roughly twice the memory.

    Args:
        adapter_path: Directory containing the adapter saved with ``save_pretrained``.

    Returns:
        The PEFT-wrapped model, placed on the available device(s) via ``device_map="auto"``.
    """
    from peft import PeftModel

    base_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        dtype=torch.float16,
        device_map="auto"
    )
    return PeftModel.from_pretrained(base_model, adapter_path)

# SAQ Training

In [None]:
# Load SAQ training data

saq_df = pd.read_csv("train_dataset_saq.csv")

def extract_answers(annotations_str):
    """Extract the unique English answers from a serialized annotations list.

    The string is parsed as a Python literal first and as JSON second, so both
    ``[{'en_answers': [...]}]`` and ``[{"en_answers": [...], "x": null}]`` work.

    Args:
        annotations_str: Serialized list of annotation dicts, each optionally
            holding an ``"en_answers"`` list of strings.

    Returns:
        Stripped, non-empty answers, de-duplicated in first-seen order (so the
        result is deterministic across runs). Returns ``[]`` if the input cannot
        be parsed or is not a list of dicts.
    """
    annotations = None
    for parser in (ast.literal_eval, json.loads):
        try:
            annotations = parser(annotations_str)
            break
        except Exception:
            # Unparseable with this parser (or not a string at all, e.g. NaN).
            continue

    if not isinstance(annotations, (list, tuple)):
        return []

    unique_answers = {}  # dict used as an insertion-ordered set
    for ann in annotations:
        if not isinstance(ann, dict):
            continue
        raw_answers = ann.get("en_answers") or []
        if isinstance(raw_answers, str):
            raw_answers = [raw_answers]
        for raw in raw_answers:
            if isinstance(raw, str) and raw.strip():
                unique_answers.setdefault(raw.strip(), None)
    return list(unique_answers)

saq_df["en_answers"] = saq_df["annotations"].apply(extract_answers)
saq_df = saq_df[saq_df["en_answers"].map(len) > 0].reset_index(drop=True)

print("Country distribution:")
print(saq_df['country'].value_counts())
print(f"\nTotal SAQ training samples: {len(saq_df)}")

In [None]:
# SAQ dataset class with few-shot examples

COUNTRY_NAMES = {"IR": "Iran", "GB": "UK", "CN": "China", "US": "US"}

SAQ_FEW_SHOT = """Example:
Question: What is the traditional New Year celebration in Iran?
Answer: Nowruz

Example:
Question: What is the most popular sport in the UK?
Answer: Football

"""

class SAQDataset(torch.utils.data.Dataset):
    """SAQ training set: one (question, answer) example per valid answer.

    Each item is the chat-templated few-shot prompt followed by the answer and
    EOS, padded/truncated to ``max_length``. Only the answer (and EOS) tokens
    contribute to the loss; prompt and padding positions are labelled -100.
    """

    def __init__(self, df, tokenizer, max_length=384):
        """
        Args:
            df: Frame with an ``en_question`` column and an ``en_answers`` column
                holding a list of accepted answers per row.
            tokenizer: Tokenizer providing ``apply_chat_template`` and ``eos_token``.
            max_length: Fixed sequence length of every returned example.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Expand every question into one training example per accepted answer.
        self.data = [
            {"question": row["en_question"], "answer": answer}
            for _, row in df.iterrows()
            for answer in row["en_answers"]
        ]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` and ``labels`` tensors for example ``idx``."""
        item = self.data[idx]
        question = item["question"]
        answer = item["answer"]

        user_message = (
            "Read the following question and provide a single answer "
            "without any explanations.\n\n"
            f"{SAQ_FEW_SHOT}"
            f"Question: {question}\n"
            "Answer:"
        )

        messages = [{"role": "user", "content": user_message}]
        prompt_str = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        full_text = prompt_str + answer + self.tokenizer.eos_token

        tokenized_full = self.tokenizer(
            full_text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        input_ids = tokenized_full["input_ids"][0]
        attention_mask = tokenized_full["attention_mask"][0]

        labels = input_ids.clone()

        tokenized_prompt = self.tokenizer(
            prompt_str,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )
        prompt_len = min(tokenized_prompt["input_ids"].shape[1], self.max_length)

        # The prompt starts after any LEFT padding. Masking labels[:prompt_len]
        # alone would, for a left-padding tokenizer, cover the pads plus only
        # part of the prompt and leave the rest of the prompt supervised.
        pad_count = int((attention_mask == 0).sum())
        content_start = pad_count if pad_count and attention_mask[0] == 0 else 0

        labels[:content_start + prompt_len] = -100
        labels[attention_mask == 0] = -100

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels
        }

test_msg = (
    "Read the following question and provide a single answer "
    "without any explanations.\n\n"
    f"{SAQ_FEW_SHOT}"
    "Question: What is the capital of China?\n"
    "Answer:"
 )
print("Sample SAQ prompt:")
print(tokenizer.apply_chat_template([{"role": "user", "content": test_msg}], tokenize=False, add_generation_prompt=True))

In [None]:
# Create and preview SAQ dataset
saq_dataset = SAQDataset(saq_df, tokenizer, max_length=256)

sample = saq_dataset[0]
print(f"Created dataset with {len(saq_dataset)} samples")
print(f"Sample input (first 50 tokens): {tokenizer.decode(sample['input_ids'][:50])}")

In [None]:
# Configure SAQ LoRA adapter

# Load the base checkpoint `model_name` in float16 and move it onto `device`.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    dtype=torch.float16,
).to(device)

# LoRA settings for the SAQ adapter: rank 16 with alpha 32, dropout 0.1, no bias
# training, applied to the modules named q_proj / k_proj / v_proj / o_proj.
saq_lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_dropout=0.1,
    bias="none",
    task_type=TaskType.CAUSAL_LM
)

# Wrap the base model with the adapter and print the trainable-parameter summary.
model = get_peft_model(model, saq_lora_config)
model.print_trainable_parameters()

In [None]:
# Configure SAQ training arguments
# Per-device batch of 4 with 4 accumulation steps, i.e. 16 examples per optimizer
# step on each device. Checkpoints are written once per epoch and at most two are kept.
saq_training_args = TrainingArguments(
    output_dir="./saq-lora-clean",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    num_train_epochs=3,
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    fp16=True,
    logging_steps=25,
    save_strategy="epoch",
    save_total_limit=2,
    report_to="none",
    optim="adamw_torch",
    weight_decay=0.01,
)

# No eval dataset or custom collator is passed: SAQDataset items already contain
# fixed-length input_ids, attention_mask and labels.
saq_trainer = Trainer(
    model=model,
    args=saq_training_args,
    train_dataset=saq_dataset
)

In [None]:
# Train SAQ model
saq_trainer.train()

model.save_pretrained("./saq-lora-adapter-clean")
tokenizer.save_pretrained("./saq-lora-adapter-clean")
print("SAQ adapter saved to ./saq-lora-adapter-clean")

# SAQ Inference

In [None]:
clear_gpu()

In [None]:
model = load_base_model("./saq-lora-adapter-clean")

In [None]:
# Web search utilities

from ddgs import DDGS
import string

COUNTRY_SEARCH_NAMES = {
    "IR": "Iran",
    "CN": "China", 
    "GB": "United Kingdom",
    "US": "United States"
}

def generate_search_query(question, model, tokenizer, country: str = None):
    """Generate a keyword-based web search query for ``question``.

    The model proposes comma-separated keywords; only the first word of each
    keyword that literally occurs in the question (whole-word, case-insensitive)
    is kept, which filters out hallucinated terms. If no keyword survives, the
    question itself is used so the search never runs on an empty query.

    Args:
        question: The question to search for.
        model: Causal LM exposing ``generate`` and ``device``.
        tokenizer: Tokenizer matching ``model``.
        country: Optional country code; its search name (from
            ``COUNTRY_SEARCH_NAMES``, falling back to the code itself) is
            appended unless already present in the query.

    Returns:
        A space-separated search query string.
    """
    prompt = f"[INST] Generate from the question keywords. Only include keywords that are important. \n Example: What is the most popular food in France? Answer: most,popular,food,france \n\nQuestion: {question}\n\nSearch Query: [/INST]"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=20, do_sample=False, pad_token_id=tokenizer.eos_token_id)
    raw_query = tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True).strip()
    raw_query = raw_query.replace('"', '').replace("'", "").replace("Query:", "").strip()

    question_lower = question.lower()
    validated_tags = []

    for tag in raw_query.split(','):
        words = tag.split()
        if not words:
            continue

        # Keep only the first word of a multi-word keyword, without punctuation.
        clean_tag = words[0].strip(string.punctuation)
        if not clean_tag:
            continue

        # Accept the keyword only if it appears as a whole word in the question.
        if re.search(r'\b' + re.escape(clean_tag.lower()) + r'\b', question_lower):
            validated_tags.append(clean_tag)

    # De-duplicate while preserving order.
    validated_tags = list(dict.fromkeys(validated_tags))

    query = " ".join(validated_tags)
    if not query:
        # Nothing usable came back from the model: fall back to the question text.
        query = " ".join(question.split())

    if country:
        country_name = COUNTRY_SEARCH_NAMES.get(country, country)
        if country_name.lower() not in query.lower():
            query = f"{query} {country_name}".strip()

    return query

def perform_web_search(query, max_results=3):
    """Search the web using DuckDuckGo and format the hits as context text.

    The "api" backend is tried first and "html" second. Each backend is
    attempted independently, so an exception in the first one no longer skips
    the fallback. Failures are reported as warnings instead of being silently
    discarded; the function itself never raises.

    Args:
        query: Search string; blank queries return "" without searching.
        max_results: Maximum number of hits to request.

    Returns:
        Hits joined by blank lines, each as ``[Web Source: <title>]\\n<body>``,
        or "" when nothing was found.
    """
    import warnings

    if not query or not query.strip():
        return ""

    for backend in ("api", "html"):
        try:
            with DDGS() as ddgs:
                results = list(ddgs.text(query, max_results=max_results, backend=backend))
        except Exception as exc:
            # Best-effort lookup: note the failure and try the next backend.
            warnings.warn(f"Web search via '{backend}' backend failed: {exc}")
            continue
        if results:
            return "\n\n".join(
                f"[Web Source: {r.get('title', '')}]\n{r.get('body', '')}" for r in results
            )

    return ""

def check_relevance(question, context, model, tokenizer):
    """Return True when the model judges that ``context`` answers ``question``.

    The model is asked for a YES/NO verdict (greedy, at most 4 new tokens); the
    result is True iff the decoded continuation contains "YES", case-insensitively.
    """
    prompt = f"[INST] Context:\n{context}\n\nQuestion: {question}\n\nDoes the provided context contain the answer to the question? Answer ONLY with 'YES' or 'NO'. [/INST]"
    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        generated = model.generate(
            **encoded,
            max_new_tokens=4,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Drop the echoed prompt tokens; only the continuation carries the verdict.
    prompt_length = encoded["input_ids"].shape[-1]
    verdict = tokenizer.decode(generated[0][prompt_length:], skip_special_tokens=True)
    return "YES" in verdict.strip().upper()

In [None]:
# SAQ generation with RAG fallback

def build_saq_user_message(question: str, context: str = None) -> str:
    """Assemble the SAQ user message.

    Layout: instruction, few-shot examples, then (only when ``context`` is
    non-empty) a "Reference Information" block plus a grounding instruction,
    and finally the question with a trailing "Answer:" cue.
    """
    parts = [
        "Read the following question and provide a single answer ",
        "without any explanations.\n\n",
        f"{SAQ_FEW_SHOT}",
    ]
    if context:
        parts.append(f"Reference Information:\n{context}\n\n")
        parts.append("Based strictly on the Reference Information above, ")
    parts.append(f"Question: {question}\n")
    parts.append("Answer:")
    return "".join(parts)


def get_answer_with_confidence(prompt: str, model, tokenizer, max_tokens: int = 10) -> Tuple[str, float]:
    """Greedy-decode an answer and score it by mean per-token probability.

    Returns:
        ``(answer, confidence)`` where ``answer`` is the stripped decoded
        continuation and ``confidence`` is the average softmax probability of
        the generated tokens (0.0 if nothing was generated).
    """
    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        generation = model.generate(
            **encoded,
            max_new_tokens=max_tokens,
            do_sample=False,
            output_scores=True,
            return_dict_in_generate=True,
            pad_token_id=tokenizer.eos_token_id
        )

    prompt_length = encoded["input_ids"].shape[-1]
    new_token_ids = generation.sequences[0][prompt_length:]

    # One score tensor per generation step; zip stops at the shorter of the two.
    token_probs = [
        torch.softmax(step_scores[0], dim=-1)[token_id].item()
        for step_scores, token_id in zip(generation.scores, new_token_ids)
    ]

    avg_confidence = np.mean(token_probs) if token_probs else 0.0
    answer = tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()
    return answer, avg_confidence


def saq_generate_with_rag(
    question: str,
    model,
    tokenizer,
    rag,
    country: str = None,
    confidence_threshold: float = 0.7,
    top_k: int = 3,
    max_context_chars: int = 3500
) -> Tuple[str, bool, Optional[str]]:
    """Generate an SAQ answer using confidence-based RAG fallback + web search.

    1. Answer directly; if the mean token probability reaches
       ``confidence_threshold`` return that answer.
    2. Otherwise retrieve ``top_k`` passages (prioritising ``country``) and ask
       the model whether they contain the answer.
    3. If they do not, replace them with web search results when any are found.
    4. Answer again with the chosen context.

    Returns:
        ``(answer, used_rag, context)``; ``context`` is None when the direct
        answer was returned. If the context-based generation is empty, the
        direct answer is returned instead.
    """

    base_user_message = build_saq_user_message(question)
    base_prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": base_user_message}],
        tokenize=False,
        add_generation_prompt=True
    )
    direct_answer, confidence = get_answer_with_confidence(
        base_prompt, model, tokenizer, max_tokens=10
    )

    if confidence >= confidence_threshold:
        return direct_answer, False, None

    # Pass the country through so country-tagged passages are prioritised,
    # as the MCQ pipeline does; it was previously accepted but ignored here.
    retrieved = rag.retrieve(question, top_k=top_k, country=country)

    context_parts = []
    for i, (passage, _) in enumerate(retrieved):
        context_parts.append(f"[Document {i+1}]\n{passage}")

    context = "\n\n".join(context_parts)
    if len(context) > max_context_chars:
        context = context[:max_context_chars] + "... (truncated)"

    is_relevant = check_relevance(question, context, model, tokenizer)

    if not is_relevant:
        search_query = generate_search_query(question, model, tokenizer, country=country)
        web_context = perform_web_search(search_query, max_results=3)
        # Keep the retrieved passages when the web search returns nothing.
        if web_context:
            context = web_context

    rag_user_message = build_saq_user_message(question, context=context)
    rag_prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": rag_user_message}],
        tokenize=False,
        add_generation_prompt=True
    )

    inputs = tokenizer(rag_prompt, return_tensors="pt", truncation=True, max_length=2048).to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=15,
            do_sample=False,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id
        )

    rag_answer = tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[-1]:],
        skip_special_tokens=True
    ).strip()

    return rag_answer if rag_answer else direct_answer, True, context


def normalize_answer(answer: str) -> str:
    """Canonicalise an answer for submission.

    Lower-cases, trims, removes trailing ``.,;:!?``, drops one leading article
    ("the", "a", "an") and collapses runs of whitespace to single spaces.
    """
    cleaned = answer.lower().strip().rstrip(".,;:!?")
    without_article = re.sub(r'^(the|a|an)\s+', '', cleaned)
    return ' '.join(without_article.split())


# Smoke test: answer one question end-to-end before running the full test set.
test_q = "What is the capital of Iran?"
answer, used_rag, context = saq_generate_with_rag(test_q, model, tokenizer, wiki_rag, country="IR", confidence_threshold=0.7)
# The separator used to be mojibake (a UTF-8 arrow decoded with a legacy code page).
print(f"Test: {test_q} → {normalize_answer(answer)} (RAG used: {used_rag})")

In [None]:
# Run SAQ predictions on test set

saq_test = pd.read_csv("test_dataset_saq.csv")[["ID", "en_question", "country"]]
# Map country names to the codes used by the RAG index; unknown values pass through unchanged.
saq_test["country_code"] = saq_test["country"].map(COUNTRY_CODE_MAP).fillna(saq_test["country"])

saq_preds = []
rag_usage_count = 0
model.eval()

for i, row in saq_test.iterrows():
    answer, used_rag, context = saq_generate_with_rag(
        question=row["en_question"],
        model=model,
        tokenizer=tokenizer,
        rag=wiki_rag,
        country=row["country_code"],
        confidence_threshold=0.7,
        top_k=3,
        max_context_chars=3500
    )
    answer = normalize_answer(answer)
    saq_preds.append(answer)

    if used_rag:
        rag_usage_count += 1

    if i % 20 == 0:
        print(f"\rProgress: {i+1}/{len(saq_test)} | RAG used: {rag_usage_count}", end="", flush=True)

saq_test["answer"] = saq_preds

# Take an explicit copy: assigning through .loc on a column slice of saq_test
# raises SettingWithCopyWarning and is not guaranteed to modify the slice.
saq_submission = saq_test[["ID", "answer"]].copy()
# Submissions must not contain empty answers.
saq_submission.loc[saq_submission["answer"] == "", "answer"] = "none"
saq_submission.to_csv("saq_prediction.tsv", sep='\t', index=False)

print(f"\n\nSAQ predictions complete!")
print(f"Total: {len(saq_test)} | RAG used: {rag_usage_count} ({100*rag_usage_count/len(saq_test):.1f}%)")
print(f"Saved to saq_prediction.tsv")

# MCQ Training

In [None]:
# Load MCQ training data

mcq_train_df = pd.read_csv("train_dataset_mcq.csv")
mcq_train_df["choices"] = mcq_train_df["choices"].apply(json.loads)

def safe_json_parse(x):
    """Parse a JSON string, tolerating bad or missing input.

    Args:
        x: A JSON string, an already-parsed object, or a missing value.

    Returns:
        The decoded object for a valid JSON string; ``{}`` for an invalid string
        or a missing value (None / NaN); any other input unchanged.
    """
    if isinstance(x, str):
        try:
            return json.loads(x)
        except ValueError:
            # json.JSONDecodeError is a ValueError; unlike a bare ``except`` this
            # does not swallow KeyboardInterrupt or unrelated errors.
            return {}
    # pandas represents missing cells as NaN (a float that is not equal to itself).
    if x is None or (isinstance(x, float) and x != x):
        return {}
    return x

mcq_train_df["choice_countries"] = mcq_train_df["choice_countries"].apply(safe_json_parse)
mcq_train_df["country_code"] = mcq_train_df["country"].map(COUNTRY_CODE_MAP).fillna(mcq_train_df["country"])

print(f"MCQ training samples: {len(mcq_train_df)}")
print("\nCountry distribution:")
print(mcq_train_df['country_code'].value_counts())
mcq_train_df.head(2)

In [None]:
# MCQ dataset with country-specific few-shot examples

def format_mcq_prompt(row) -> str:
    """Render one MCQ row as instruction, question, options A-D and an "Answer:" cue.

    ``row`` must provide ``row["prompt"]`` and ``row["choices"]`` with keys A-D.
    """
    choices = row["choices"]
    option_lines = "".join(
        f"{letter}. {choices[letter]}\n" for letter in ("A", "B", "C", "D")
    )
    header = (
        "Answer the following multiple choice question.\n"
        "Choose exactly one option (A, B, C, or D).\n\n"
    )
    return f"{header}{row['prompt']}\n\n{option_lines}\nAnswer:"

MCQ_TRAIN_FEW_SHOT = {
    "IR": """Example:
What is the traditional New Year celebration in Iran?
A. Eid
B. Diwali
C. Nowruz
D. Hanukkah
Answer: C

""",
    "CN": """Example:
What is the capital of China?
A. Shanghai
B. Beijing
C. Hong Kong
D. Guangzhou
Answer: B

""",
    "GB": """Example:
What is the currency of the United Kingdom?
A. Euro
B. Dollar
C. Pound Sterling
D. Franc
Answer: C

""",
    "US": """Example:
What is the capital of the United States?
A. New York
B. Los Angeles
C. Washington D.C.
D. Chicago
Answer: C

""",

    "default":""
}

COUNTRY_NAMES_MCQ = {"IR": "Iran", "GB": "UK", "CN": "China", "US": "US"}

def preprocess_mcq_batch(batch, tokenizer, max_length=768):
    """Tokenize a batch of MCQ rows for causal-LM fine-tuning.

    Each example is ``<s>[INST] <prompt> [/INST] <answer letter><eos>`` padded or
    truncated to ``max_length``. Only the answer letter and EOS are supervised;
    prompt and padding positions get the label -100.

    Args:
        batch: Column-oriented dict with ``prompt``, ``choices`` (dicts with keys
            A-D), ``answer_idx`` and ``country_code`` (or ``country``).
        tokenizer: Tokenizer providing ``eos_token``.
        max_length: Fixed sequence length of every example.

    Returns:
        Dict of lists: ``input_ids``, ``attention_mask``, ``labels``.
    """
    input_ids_list = []
    attention_masks_list = []
    labels_list = []

    for i in range(len(batch["prompt"])):
        row_prompt = batch["prompt"][i]
        choices = batch["choices"][i]
        correct_answer = batch["answer_idx"][i].strip()
        country = batch["country_code"][i] if "country_code" in batch else batch["country"][i]

        few_shot = MCQ_TRAIN_FEW_SHOT.get(country, "")

        country_context = ""
        if country:
            country_name = COUNTRY_NAMES_MCQ.get(country, country)
            country_context = f"Context: This question is about {country_name}.\n\n"

        formatted_prompt = (
            "Answer the following multiple choice question.\n"
            "Choose exactly one option (A, B, C, or D).\n\n"
            f"{few_shot}"
            f"{country_context}"
            f"{row_prompt}\n\n"
            f"A. {choices['A']}\n"
            f"B. {choices['B']}\n"
            f"C. {choices['C']}\n"
            f"D. {choices['D']}\n\n"
            "Answer:"
        )

        full_prompt_str = f"<s>[INST] {formatted_prompt} [/INST]"
        full_text = full_prompt_str + " " + correct_answer + tokenizer.eos_token

        tokenized_full = tokenizer(
            full_text,
            max_length=max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        input_ids = tokenized_full["input_ids"][0]
        attention_mask = tokenized_full["attention_mask"][0]
        labels = input_ids.clone()

        tokenized_prompt = tokenizer(
            full_prompt_str,
            truncation=True,
            max_length=max_length,
            return_tensors="pt"
        )
        prompt_len = min(tokenized_prompt["input_ids"].shape[1], max_length)

        # The prompt starts after any LEFT padding. Masking labels[:prompt_len]
        # alone would, for a left-padding tokenizer, cover the pads plus only
        # part of the prompt and leave the rest of the prompt supervised.
        pad_count = int((attention_mask == 0).sum())
        content_start = pad_count if pad_count and attention_mask[0] == 0 else 0

        labels[:content_start + prompt_len] = -100
        labels[attention_mask == 0] = -100

        input_ids_list.append(input_ids)
        attention_masks_list.append(attention_mask)
        labels_list.append(labels)

    return {
        "input_ids": input_ids_list,
        "attention_mask": attention_masks_list,
        "labels": labels_list
    }

# Convert the pandas frame to a Hugging Face dataset and tokenize it in batches.
# All original columns are removed, so the tokenized dataset holds only what
# preprocess_mcq_batch returns (input_ids, attention_mask, labels).
mcq_hf_dataset = HFDataset.from_pandas(mcq_train_df)
mcq_tokenized_dataset = mcq_hf_dataset.map(
    lambda batch: preprocess_mcq_batch(batch, tokenizer),
    batched=True,
    remove_columns=mcq_hf_dataset.column_names
)

print(f"Tokenized {len(mcq_tokenized_dataset)} MCQ samples")

In [None]:
# LoRA settings for the MCQ adapter: rank 8 with alpha 16, dropout 0.05, no bias
# training, applied to the modules named q_proj and v_proj.
mcq_lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.CAUSAL_LM
)

# `model` still carries the SAQ adapter at this point. unload() RETURNS the bare
# base model, so the result must be re-assigned; discarding it left `model` bound
# to the old PEFT wrapper, which was then wrapped a second time.
if hasattr(model, 'unload'):
    model = model.unload()

model = get_peft_model(model, mcq_lora_config)
model.print_trainable_parameters()

In [None]:
# Train MCQ model
# Per-device batch of 2 with 8 accumulation steps, i.e. 16 examples per optimizer
# step on each device.
mcq_training_args = TrainingArguments(
    output_dir="./mcq-lora-adapter-clean",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    learning_rate=2e-4,
    num_train_epochs=3,
    logging_steps=50,
    save_strategy="epoch",
    fp16=True,
    report_to="none"
)

mcq_trainer = Trainer(
    model=model,
    args=mcq_training_args,
    train_dataset=mcq_tokenized_dataset
)

mcq_trainer.train()

# Epoch checkpoints are written to checkpoint-* sub-folders of output_dir, but the
# inference cell loads the adapter from the directory root. Save it there
# explicitly, as the SAQ stage does for its adapter.
model.save_pretrained("./mcq-lora-adapter-clean")
print("MCQ training complete!")
print("MCQ adapter saved to ./mcq-lora-adapter-clean")

# MCQ Inference

In [None]:
clear_gpu()

In [None]:
model = load_base_model("./mcq-lora-adapter-clean")

In [None]:
# MCQ generation with RAG fallback

def extract_mcq_choice(answer: str) -> str:
    """Extract the choice letter (A, B, C, or D) from model output.

    Tried in order:
      1. a JSON field ``"answer_choice": "<letter>"``;
      2. a standalone upper-case letter A-D as written by the model;
      3. a standalone letter a-d in any case;
      4. the first character that is one of a-d / A-D;
      5. the default "A".

    Step 2 runs before any upper-casing so that the English article "a" in
    text such as "such a case: C" is not mistaken for choice A.
    """
    json_match = re.search(r'"answer_choice"\s*:\s*"([A-D])"', answer)
    if json_match:
        return json_match.group(1)

    standalone = re.search(r'\b[A-D]\b', answer)
    if standalone:
        return standalone.group(0)

    upper_answer = answer.upper()
    matches = re.findall(r'\b[A-D]\b', upper_answer)
    if matches:
        return matches[0]

    for char in upper_answer:
        if char in 'ABCD':
            return char

    return "A"


def build_mcq_prompt(question: str, choices: dict, country: str = None, context: str = None) -> str:
    """Build the MCQ inference prompt in the same layout used for training.

    Layout: instruction, country-specific few-shot example, country hint,
    optional "Reference Information" block with a grounding instruction, the
    question, the options A-D and a trailing "Answer:" cue, all wrapped in
    ``<s>[INST] ... [/INST]``.

    Previously ``choices`` was ignored and the "Answer:" cue was missing, so
    inference prompts did not match the format produced by
    ``preprocess_mcq_batch`` during fine-tuning.

    Args:
        question: Question text.
        choices: Mapping of option letter to option text; letters missing from
            the mapping are skipped.
        country: Optional country code selecting the few-shot example and hint.
        context: Optional reference text to ground the answer on.
    """
    country_context = ""
    if country:
        country_name = COUNTRY_NAMES_MCQ.get(country, country)
        country_context = f"Context: This question is about {country_name}.\n\n"

    few_shot = MCQ_TRAIN_FEW_SHOT.get(country, MCQ_TRAIN_FEW_SHOT["default"])

    context_block = ""
    instruction_add = ""
    if context:
        context_block = f"Reference Information:\n{context}\n\n"
        instruction_add = "Based strictly on the Reference Information above, "

    # Same "A. ...\nB. ...\nC. ...\nD. ...\n\n" block as in the training prompt.
    options_block = ""
    if choices:
        option_lines = "".join(
            f"{letter}. {choices[letter]}\n" for letter in ("A", "B", "C", "D") if letter in choices
        )
        if option_lines:
            options_block = option_lines + "\n"

    formatted_prompt = (
        "Answer the following multiple choice question.\n"
        "Choose exactly one option (A, B, C, or D).\n\n"
        f"{few_shot}"
        f"{country_context}"
        f"{context_block}"
        f"{instruction_add}"
        f"{question}\n\n"
        f"{options_block}"
        "Answer:"
    )

    return f"<s>[INST] {formatted_prompt} [/INST]"


def mcq_generate_with_rag(
    question: str,
    choices: dict,
    model,
    tokenizer,
    rag,
    country: str = None,
    confidence_threshold: float = 0.8,
    top_k: int = 3,
    max_context_chars: int = 3500
) -> Tuple[str, bool, Optional[str]]:
    """Answer an MCQ directly, falling back to retrieval and web search when unsure.

    Returns:
        ``(raw_answer, used_rag, context)``. When the direct answer's confidence
        reaches ``confidence_threshold`` it is returned with ``used_rag=False``
        and ``context=None``. Otherwise the model answers again with retrieved
        passages (or web results if the passages are judged irrelevant and the
        search returns something); an empty second answer falls back to the
        direct one.
    """
    # Stage 1: closed-book attempt.
    direct_answer, confidence = get_answer_with_confidence(
        build_mcq_prompt(question, choices, country=country),
        model,
        tokenizer,
        max_tokens=2,
    )
    if confidence >= confidence_threshold:
        return direct_answer, False, None

    # Stage 2: retrieve country-prioritised passages and cap their total length.
    hits = rag.retrieve(question, top_k=top_k, country=country)
    context = "\n\n".join(
        f"[Document {rank}]\n{passage}"
        for rank, (passage, _) in enumerate(hits, start=1)
    )
    if len(context) > max_context_chars:
        context = context[:max_context_chars] + "... (truncated)"

    # Stage 3: switch to web results only if the passages look unhelpful AND the
    # search actually returns something.
    if not check_relevance(question, context, model, tokenizer):
        search_query = generate_search_query(question, model, tokenizer, country=country)
        web_context = perform_web_search(search_query, max_results=3)
        if web_context:
            context = web_context

    # Stage 4: answer again with the chosen context.
    final_prompt_content = build_mcq_prompt(question, choices, country=country, context=context)
    encoded = tokenizer(
        final_prompt_content, return_tensors="pt", truncation=True, max_length=2048
    ).to(model.device)
    with torch.no_grad():
        generated = model.generate(
            **encoded,
            max_new_tokens=2,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id
        )
    prompt_length = encoded["input_ids"].shape[-1]
    rag_answer = tokenizer.decode(generated[0][prompt_length:], skip_special_tokens=True).strip()

    return (rag_answer or direct_answer), True, context

In [None]:
# Run MCQ predictions on test set

mcq_test = pd.read_csv("test_dataset_mcq.csv")
mcq_test["choices"] = mcq_test["choices"].apply(safe_json_parse)
# Map country names to codes; values without a mapping are kept as they are.
country_codes = mcq_test["country"].map(COUNTRY_CODE_MAP)
mcq_test["country_code"] = country_codes.fillna(mcq_test["country"])

mcq_preds = []
mcq_rag_count = 0
model.eval()

total_questions = len(mcq_test)
for i, row in mcq_test.iterrows():
    answer, used_rag, context = mcq_generate_with_rag(
        question=row["prompt"],
        choices=row["choices"],
        model=model,
        tokenizer=tokenizer,
        rag=wiki_rag,
        country=row["country_code"],
        confidence_threshold=0.8,
        top_k=3,
        max_context_chars=3500
    )
    mcq_preds.append(extract_mcq_choice(answer))
    if used_rag:
        mcq_rag_count += 1

    if i % 20 == 0:
        print(f"\rProgress: {i+1}/{total_questions} | RAG used: {mcq_rag_count}", end="", flush=True)

mcq_test["choice"] = mcq_preds

# One boolean column per option; reindex adds any option that was never predicted
# as an all-False column and fixes the column order to A, B, C, D.
choice_flags = (
    pd.get_dummies(mcq_test["choice"])
    .astype(bool)
    .reindex(columns=['A', 'B', 'C', 'D'], fill_value=False)
)

mcq_submission = pd.concat([mcq_test["MCQID"], choice_flags], axis=1)

mcq_submission.to_csv("mcq_prediction.tsv", sep='\t', index=False)

print(f"MCQ predictions complete!")
print(f"Total: {len(mcq_test)} | RAG used: {mcq_rag_count} ({100*mcq_rag_count/len(mcq_test):.1f}%)")
print(f"Saved to mcq_prediction.tsv")