In [None]:
import os
import gc
import time
import warnings

import json
import pandas as pd
import re
import torch
import json
from tqdm import tqdm

from vllm import LLM, SamplingParams
import ctypes

In [None]:
# Suppress every Python warning raised from this point on (blanket filter,
# applies to all warning categories).
warnings.simplefilter('ignore')

# Environment switch read by the Hugging Face `tokenizers` library; "false" is
# presumably set here to disable its internal parallelism -- TODO confirm intent.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

def clean_memory(deep=False):
    """Release memory held by Python, optionally by libc, and by the CUDA cache.

    Args:
        deep: When true, additionally load ``libc.so.6`` and call
            ``malloc_trim(0)`` on it. Loading that library by name will fail
            on systems that do not provide it.

    Returns:
        None.
    """
    # Always start with a Python-level garbage collection pass.
    gc.collect()

    if deep:
        libc = ctypes.CDLL("libc.so.6")
        libc.malloc_trim(0)

    # Drop cached blocks held by torch's CUDA allocator.
    torch.cuda.empty_cache()

# --- Run configuration -------------------------------------------------------
FINETUNE = False                    # True -> load the checkpoint under trained_models/base
tok_path = 'Qwen3-1.7B'             # tokenizer path; doubles as the model path when FINETUNE is False
res_dir = "MATS_qna/og_noreason"    # directory the generation loop writes its JSON results into

# Pick the weights to load: the fine-tuned checkpoint only when FINETUNE is set,
# otherwise the same path that provides the tokenizer.
llm_model_path = 'trained_models/base' if FINETUNE else tok_path

# NOTE: `dtype` and `max_num_seqs` are deliberately not passed here (their
# former overrides, dtype="half" and max_num_seqs=128, were commented out),
# so vLLM's own defaults apply.
llm = LLM(
    llm_model_path,
    tokenizer=tok_path,
    trust_remote_code=True,
    tensor_parallel_size=1,
    gpu_memory_utilization=0.90,
)

In [None]:
# Fetch the tokenizer from the vLLM engine; it is passed to apply_template()
# to render chat prompts.
tokenizer  = llm.get_tokenizer()

In [None]:
# Number of completions requested per prompt (passed as `n` below).
N_SAMPLES = 20

# Sampling configuration shared by every llm.generate() call.
# NOTE: the generation loop assigns `sampling_params.temperature` in place for
# each temperature it sweeps, so the value given here is only the initial one.
# NOTE(review): no seed is passed, so presumably generations are not
# reproducible across runs -- confirm whether that is intended.
sampling_params = SamplingParams(
    n=N_SAMPLES,
    temperature=1,
    max_tokens=9000,
)

In [None]:
# Top-level question categories; used as the first-level keys when indexing
# each fact file. Only GENERATIVE_DISTINGUISH gets its own prompt template.
qtypes = [
    "MCQ_KNOWLEDGE_TRUE",
    "MCQ_KNOWLEDGE_FALSE",
    "MCQ_DISTINGUISH",
    "GENERATIVE_DISTINGUISH",
]

# Second-level keys found under every question category.
qsubtypes = [
    "direct_reference",
    "indirect_reference",
]

In [None]:
def apply_template(prompt, tokenizer, qtype):
    """Render one question record as a chat-formatted prompt.

    Args:
        prompt: Mapping describing the question. For ``"GENERATIVE_DISTINGUISH"``
            the keys ``'prompt'`` and ``'scenarios'`` (holding
            ``'true_fact_scenario'`` and ``'false_fact_scenario'``) are read;
            for any other ``qtype`` the keys ``'question'`` and ``'options'``
            are read.
        tokenizer: Object exposing ``apply_chat_template``.
        qtype: Question-type name. Only ``"GENERATIVE_DISTINGUISH"`` is
            special-cased; every other value uses the multiple-choice template.

    Returns:
        The value returned by ``tokenizer.apply_chat_template`` for a
        system + user conversation, requested with ``tokenize=False``,
        ``add_generation_prompt=True`` and ``enable_thinking=False``.

    Raises:
        KeyError: If ``prompt`` lacks a key required by the chosen template.
    """
    if qtype == "GENERATIVE_DISTINGUISH":
        # The true scenario is always presented as "Scenario 0" and the false
        # one as "Scenario 1".
        system_text = (
            "You are an intelligent assistant tasked with distinguishing between "
            "factual and counterfactual scenarios. Read the prompt and the two scenarios. "
            "Based on your world knowledge, determine which scenario aligns with real-world facts. "
            "Answer in one word "
            "in this format:\n\n"
            "/boxed{0} or /boxed{1}\n\n"
            "Use this format exactly as shown. Do not write anything after the boxed answer."
        )
        user_text = (
            f"Prompt: {prompt['prompt']}\n\n"
            f"Scenario 0: {prompt['scenarios']['true_fact_scenario']}\n"
            f"Scenario 1: {prompt['scenarios']['false_fact_scenario']}\n\n"
            "Which scenario aligns with real-world facts?"
        )
    else:
        system_text = (
            "You are an intelligent assistant tasked with answering multiple-choice questions. "
            "Choose the most appropriate option from the given choices. "
            "Your final answer must be a single word or phrase exactly as it appears in the options. "
            "End your response with the answer written in this format:\n\n"
            "/boxed{ANSWER}\n\n "
            "Do not write anything else after boxed answer."
        )
        # `options` is interpolated with its default string conversion.
        user_text = (
            f"{prompt['question']}\n"
            f"Options: {prompt['options']}"
        )

    chat = [
        {"role": "system", "content": system_text},
        {"role": "user", "content": user_text},
    ]

    # Ask for the rendered text (not token ids), ending with the assistant
    # generation prompt and with the template's thinking mode switched off.
    return tokenizer.apply_chat_template(
        conversation=chat,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )


In [None]:
# Sweep: for every temperature and every fact file, sample generations for each
# question and write them to JSON (a rolling backup after each batch plus one
# final file per fact/temperature pair).

# Number of prompts submitted per llm.generate() call.
BATCH_SIZE = 1  # or as per your VRAM and throughput

# Create the output directory before any generation happens, so a missing
# directory cannot fail the first backup write after a generation has run.
os.makedirs(res_dir, exist_ok=True)

for temp in [0.5, 1.0]:

    # NOTE: this mutates the shared `sampling_params` object in place; after
    # the sweep it keeps the last temperature.
    sampling_params.temperature = temp

    for fact_num in range(10):
        # Load the questions and close the input file right away; nothing
        # below needs the handle, and the output writes use their own handles.
        with open(f"MATS_qna/Fact{fact_num}.json", "r", encoding="utf-8") as fact_file:
            data = json.load(fact_file)

        # Output structure: {question type: {question subtype: [records]}}
        updated_rows = {}
        all_gens = []

        batch_filename = os.path.join(res_dir, f"backup_Fact{fact_num}_{temp}.json")

        for qt in qtypes:

            # One MCQ_KNOWLEDGE_* set is skipped depending on the run mode:
            # the TRUE set when FINETUNE is on, the FALSE set when it is off.
            if FINETUNE and qt == "MCQ_KNOWLEDGE_TRUE":
                continue
            if not FINETUNE and qt == "MCQ_KNOWLEDGE_FALSE":
                continue

            for qsub in qsubtypes:

                qdata = data[qt][qsub]

                for i in tqdm(range(0, len(qdata), BATCH_SIZE), desc=f"{qt} | {qsub}"):
                    batch_raw_prompts = qdata[i:i + BATCH_SIZE]

                    # Format prompts for vLLM (chat-style)
                    batch_prompts = [
                        apply_template(prompt, tokenizer, qt) for prompt in batch_raw_prompts
                    ]

                    # Generate using vLLM
                    request_output = llm.generate(
                        prompts=batch_prompts,
                        sampling_params=sampling_params,
                        use_tqdm=False,
                    )

                    # Result list for this (type, subtype); created on first use.
                    rows = updated_rows.setdefault(qt, {}).setdefault(qsub, [])

                    # Store results: each prompt yields several sampled outputs.
                    for j, prompt_dict in enumerate(batch_raw_prompts):
                        generations = [out.text.strip() for out in request_output[j].outputs]

                        # Shallow-copy the original record and attach its generations.
                        prompt_with_gen = prompt_dict.copy()
                        prompt_with_gen["generations"] = generations

                        rows.append(prompt_with_gen)
                        all_gens.extend(generations)

                    # Rolling checkpoint: rewrite the backup after every batch.
                    with open(batch_filename, "w") as backup_file:
                        json.dump(updated_rows, backup_file, indent=2)

        # Final dump after all batches
        final_filename = os.path.join(res_dir, f"final_output_Fact{fact_num}_{temp}.json")
        with open(final_filename, "w") as final_file:
            json.dump(updated_rows, final_file, indent=2)

        print(f"Fact {fact_num} Done")

    # Reported once per temperature, after every fact file has been processed.
    print(f"TEMPERATURE {temp} Done")