In [None]:
from dotenv import load_dotenv

# Populate os.environ from the local .env file before anything below needs it
# (presumably the Gemini evaluator's API key — confirm against the .env contents).
load_dotenv(dotenv_path=".env")

import os
from typing import List, Union

import json
import torch
from datasets import Dataset, load_dataset
from langchain_google_genai import ChatGoogleGenerativeAI
from ragas import evaluate
from ragas.llms import LangchainLLMWrapper
from ragas.metrics import FactualCorrectness
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm

In [None]:
# RAGBench subset under evaluation; other subsets used with this notebook
# include "pubmedqa" and "delucionqa".
ds_name = "hotpotqa"
ds = load_dataset(
    "rungalileo/ragbench",
    ds_name,
    split="test",
)
print(len(ds))

In [None]:
MODEL_NAME = "google/gemma-3-4b-it"

# Load the checkpoint in bfloat16 directly onto the GPU and switch to
# inference mode; the tokenizer comes from the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    dtype=torch.bfloat16,
    device_map="cuda",
)
model.eval()


def generate_step(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    messages: List[dict[str, str]],
    max_new_tokens: int = 512,
    temperature: float = 0.2,
    top_p: float = 0.9,
    top_k: int = 50,
    num_beams: int = 1,
    repetition_penalty: float = 1.1,
    dola_decoding: bool = False,
    dola_layers: Union[str, list[int]] = "high",
    sled_decoding: bool = False,
    evolution_rate: float = 2.0,
    evolution_scale: int = 10,
    evolution_lower_bound: float = -1000.0,
) -> str:
    """Generate a single chat reply, optionally using DoLa or SLED decoding.

    The tokenizer's chat template is applied to ``messages`` with the
    generation prompt appended. The prompt is tokenized without adding special
    tokens again, and only the newly generated tokens are decoded.

    Args:
        model: Loaded causal LM exposing ``generate``, ``device`` and
            ``generation_config``.
        tokenizer: Tokenizer matching ``model``, with a chat template.
        messages: Chat messages in the format accepted by
            ``tokenizer.apply_chat_template``.
        max_new_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature. A value ``<= 0`` selects greedy
            decoding, and then ``temperature``/``top_p``/``top_k`` are not
            forwarded to ``generate``.
        top_p: Nucleus-sampling threshold (sampling only).
        top_k: Top-k cutoff (sampling only).
        num_beams: Number of beams passed to ``generate``.
        repetition_penalty: Repetition penalty, applied in every mode.
        dola_decoding: Use the local ``custom_decoding/dola`` strategy.
        dola_layers: Layer selection forwarded to the DoLa strategy.
        sled_decoding: Use the local ``custom_decoding/sled`` strategy.
        evolution_rate: SLED hyper-parameter, forwarded unchanged.
        evolution_scale: SLED hyper-parameter, forwarded unchanged.
        evolution_lower_bound: SLED hyper-parameter, forwarded unchanged.

    Returns:
        The decoded continuation, with special tokens stripped.

    Raises:
        ValueError: If both ``dola_decoding`` and ``sled_decoding`` are True.
    """
    if dola_decoding and sled_decoding:
        raise ValueError("dola_decoding and sled_decoding are mutually exclusive")

    formatted_chat = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # The chat template already inserts BOS etc., so don't add special tokens twice.
    inputs = tokenizer(formatted_chat, return_tensors="pt", add_special_tokens=False).to(model.device)

    # Prefer the model's own stop tokens: chat checkpoints may list an
    # end-of-turn id there besides the tokenizer's EOS, and passing only
    # tokenizer.eos_token_id would override that list.
    eos_token_id = getattr(model.generation_config, "eos_token_id", None)
    if eos_token_id is None:
        eos_token_id = tokenizer.eos_token_id

    do_sample = temperature > 0
    gen_kwargs = dict(
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        num_beams=num_beams,
        repetition_penalty=repetition_penalty,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=eos_token_id,
    )
    if do_sample:
        # Sampling knobs only matter when sampling. Forwarding them alongside
        # do_sample=False just triggers generation-config warnings.
        gen_kwargs.update(temperature=temperature, top_p=top_p, top_k=top_k)

    if dola_decoding:
        print("=> Using DOLA decoding...")
        gen_kwargs.update(
            custom_generate="custom_decoding/dola",
            trust_remote_code=True,
            dola_layers=dola_layers,
        )
    elif sled_decoding:
        print("=> Using SLED decoding...")
        gen_kwargs.update(
            custom_generate="custom_decoding/sled",
            trust_remote_code=True,
            evolution_rate=evolution_rate,
            evolution_scale=evolution_scale,
            evolution_lower_bound=evolution_lower_bound,
        )

    outputs = model.generate(**inputs, **gen_kwargs)
    # Drop the echoed prompt; decode only the continuation.
    prompt_len = inputs["input_ids"].size(1)
    return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)


In [None]:
num_samples = 100
# ds.select raises IndexError for indices past the end, so never request more
# rows than the chosen subset has.
n_eval = min(num_samples, len(ds))

# Create the output directory up front: open(..., "w") does not create parent
# directories, and failing only after every sample was generated would throw
# the whole run away.
output_dir = "results/exp-0"
os.makedirs(output_dir, exist_ok=True)

dataset = []
for d in tqdm(ds.select(range(n_eval)), total=n_eval, desc="Processing test samples"):
    question = d["question"]
    # NOTE(review): the dataset's "response" column is used as the reference for
    # FactualCorrectness — confirm it is an appropriate ground-truth answer.
    reference = d["response"]

    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": f"""Answer the following question in one paragraph.

Question: {question}
Answer:""",
                },
            ],
        }
    ]
    # Greedy (temperature=0) DoLa decoding; the output file name below is
    # tagged "dola-decoding" to match.
    response = generate_step(
        model,
        tokenizer,
        messages,
        max_new_tokens=512,
        temperature=0,
        repetition_penalty=1.2,
        dola_decoding=True,
    )

    dataset.append(
        {
            "user_input": question,
            "response": response,
            "reference": reference,
        }
    )

responses_path = os.path.join(
    output_dir,
    f"{MODEL_NAME.replace('/', '~')}_{ds_name}_responses_dola-decoding.json",
)
with open(responses_path, "w", encoding="utf-8") as f:
    json.dump(dataset, f, indent=2)

In [None]:
# Judge LLM for the RAGAS metric, wrapped so RAGAS can drive a LangChain chat
# model. temperature=0 is meant to reduce run-to-run variance in verdicts;
# the API credentials presumably come from the environment loaded via .env.
evaluator_llm = LangchainLLMWrapper(
    ChatGoogleGenerativeAI(
        model="gemini-2.5-flash-lite",
        temperature=0,
        max_retries=4,
        timeout=None,
        max_tokens=None,
    )
)

In [None]:
# Wrap the generated records (user_input / response / reference) in a HF
# Dataset and score each response with RAGAS FactualCorrectness.
evaluation_dataset = Dataset.from_list(dataset)
ragas_result = evaluate(
    evaluation_dataset,
    metrics=[FactualCorrectness()],
    llm=evaluator_llm,
)
print(ragas_result)

In [None]:
# Persist per-sample RAGAS scores next to the generated responses.
# (Inspect interactively with ragas_result_df.head() if needed.)
ragas_result_df = ragas_result.to_pandas()
ragas_result_df.to_csv(
    os.path.join(
        output_dir,
        "{}_{}_ragas-results_dola-decoding.csv".format(MODEL_NAME.replace("/", "~"), ds_name),
    ),
    index=False,
)