In [None]:
import os
import gc
import time
import warnings

import json
import pandas as pd
import re
import torch
import json
from tqdm import tqdm

from vllm import LLM, SamplingParams
import ctypes

In [None]:
# Silence every warning category for the rest of the session.
warnings.simplefilter(action="ignore")

# Switch off parallelism in the Hugging Face tokenizers library via its
# environment variable.
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})

def clean_memory(deep=False):
    """Release memory held by Python, the C allocator and the CUDA cache.

    Args:
        deep: when true, additionally ask glibc to hand freed heap pages back
            to the operating system through ``malloc_trim(0)``.

    The glibc trim is best-effort: on platforms where ``libc.so.6`` or
    ``malloc_trim`` is unavailable it is skipped instead of raising, so the
    garbage collection and CUDA cache release still happen.
    """
    gc.collect()
    if deep:
        try:
            ctypes.CDLL("libc.so.6").malloc_trim(0)
        except (OSError, AttributeError):
            # OSError: the shared library cannot be loaded (non-glibc system).
            # AttributeError: the library loaded but exports no malloc_trim.
            pass
    torch.cuda.empty_cache()

# Model path passed to vLLM as the checkpoint to load.
llm_model_path = 'trained_models/unf_cot'
# Tokenizer name/path passed to vLLM separately from the model path.
tok_path = 'Qwen3-1.7B'
# Directory that receives the backup and final output JSON files.
res_dir = "unf_cot_perturb"
# Directory containing the input files named final_output_Fact{n}_{temp}.json.
data_dir = "MATS_qna/og"

In [None]:
# Build the vLLM engine from `llm_model_path`, using the tokenizer at `tok_path`.
llm = LLM(
    llm_model_path,
    tokenizer=tok_path,
    # The next two arguments used to be passed explicitly and are now
    # disabled (originally marked "Changed this"), so vLLM's defaults apply.
    #dtype="half",
    #max_num_seqs=128,
    trust_remote_code=True,       # allow custom code shipped with the model
    tensor_parallel_size=1,       # no tensor parallelism
    gpu_memory_utilization=0.85,  # fraction of GPU memory given to vLLM
)

In [None]:
# Tokenizer instance held by the vLLM engine; used later to render chat templates.
tokenizer  = llm.get_tokenizer()

In [None]:
# Number of completions requested per prompt (passed as `n`). The generation
# loop only reads `outputs[0]`, so any extra completions would be discarded.
N_SAMPLES = 1

# Shared decoding settings. The generation loop assigns `temperature` before
# it generates anything, so the value given here is only a placeholder.
sampling_params = SamplingParams(
    n=N_SAMPLES,
    temperature=1,
    max_tokens=3000,  # upper bound on generated tokens per completion
)

In [None]:
# Question types to process; used as top-level keys of each input JSON file.
qtypes = ['MCQ_DISTINGUISH', 'GENERATIVE_DISTINGUISH']

# Question subtypes; used as second-level keys under each question type.
qsubtypes = ['direct_reference', 'indirect_reference']

In [None]:
import re
import nltk
# Fetch the 'punkt_tab' NLTK resource at run time (needs network access the
# first time). Presumably this is the data nltk.sent_tokenize relies on below;
# confirm against the installed NLTK version.
nltk.download('punkt_tab')

def truncate_cot_sentence_level(gen, ratio=0.75):
    """Return the leading part of a generation's reasoning, left unterminated.

    The text between the first ``<think>`` and the next ``</think>`` is split
    into sentences with ``nltk.sent_tokenize``. The first
    ``int(len(sentences) * ratio)`` sentences (never fewer than one) are
    joined with single spaces and prefixed with ``"<think>\\n"``.

    Args:
        gen: full generation text.
        ratio: fraction of sentences to keep.

    Returns:
        The truncated reasoning with an opening tag and no closing tag, or
        ``None`` when ``gen`` has no complete ``<think>...</think>`` block.
    """
    think_block = re.search(r"<think>(.*?)</think>", gen, re.DOTALL)
    if think_block is None:
        # Nothing to truncate; the caller skips this generation.
        return None

    reasoning = think_block.group(1).strip()
    pieces = nltk.sent_tokenize(reasoning)

    # Keep roughly `ratio` of the sentences, but always at least one.
    keep = int(len(pieces) * ratio)
    if keep < 1:
        keep = 1

    # The closing tag is deliberately omitted so the model can resume the
    # reasoning from here.
    return "<think>\n" + " ".join(pieces[:keep])

In [None]:
import re

def apply_template(question_data, tokenizer, qtype):
    """Build prompts that resume generation from each truncated chain of thought.

    Args:
        question_data: one question record. ``'generations'`` is always read.
            For ``"GENERATIVE_DISTINGUISH"`` the keys ``'prompt'`` and
            ``'scenarios'`` (with ``'true_fact_scenario'`` and
            ``'false_fact_scenario'``) are read; otherwise ``'question'`` and
            ``'options'``.
        tokenizer: object exposing ``apply_chat_template``.
        qtype: ``"GENERATIVE_DISTINGUISH"`` selects the two-scenario prompt;
            any other value selects the multiple-choice prompt.

    Returns:
        ``(answer_cots, cots)``, two parallel lists. ``cots`` holds the
        truncated ``"<think>..."`` prefixes and ``answer_cots`` holds the
        rendered chat prompt followed by the matching prefix. Generations for
        which ``truncate_cot_sentence_level`` returns ``None`` appear in
        neither list.
    """
    if qtype == "GENERATIVE_DISTINGUISH":
        system_text = (
            "You are an intelligent assistant tasked with distinguishing between "
            "factual and counterfactual scenarios. Read the prompt and the two scenarios. "
            "Based on your world knowledge, determine which scenario aligns with real-world facts. "
            "Answer in one word "
            "in this format:\n\n"
            "/boxed{0} or /boxed{1}\n\n"
            "Use this format exactly as shown. Do not write anything after the boxed answer."
        )
        # The true scenario is always listed as 0 and the false one as 1.
        user_text = (
            f"Prompt: {question_data['prompt']}\n\n"
            f"Scenario 0: {question_data['scenarios']['true_fact_scenario']}\n"
            f"Scenario 1: {question_data['scenarios']['false_fact_scenario']}\n\n"
            "Which scenario aligns with real-world facts?"
        )
    else:
        system_text = (
            "You are an intelligent assistant tasked with answering multiple-choice questions. "
            "Choose the most appropriate option from the given choices. "
            "Your final answer must be a single word or phrase exactly as it appears in the options. "
            "End your response with the answer written in this format:\n\n"
            "/boxed{ANSWER}\n\n "
            "Do not write anything else after boxed answer."
        )
        user_text = (
            f"{question_data['question']}\n"
            f"Options: {question_data['options']}"
        )

    messages = [
        {"role": "system", "content": system_text},
        {"role": "user", "content": user_text},
    ]

    chat_prefix = tokenizer.apply_chat_template(
        conversation=messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=True,
    )

    # One unterminated "<think>..." prefix per usable generation; generations
    # without a complete reasoning block are dropped.
    cots = [
        partial
        for partial in map(truncate_cot_sentence_level, question_data['generations'])
        if partial is not None
    ]
    answer_cots = [chat_prefix + partial for partial in cots]

    return answer_cots, cots

In [None]:
# Regenerate answers from truncated chains of thought for every selected fact
# file, writing a rolling backup after each question and a final file per fact.

# Create the output directory up front; otherwise the first backup write fails
# with FileNotFoundError only after a generation call has already run.
os.makedirs(res_dir, exist_ok=True)

for temp in [0.5]:

    # The shared SamplingParams object is reused; only its temperature changes.
    sampling_params.temperature = temp

    for fact_num in range(10):
        # Facts 0, 1 and 3 are excluded from this run (reason not recorded here).
        if fact_num in (0, 1, 3):
            continue

        data_path = os.path.join(data_dir, f"final_output_Fact{fact_num}_{temp}.json")
        # Read everything first so the input handle is closed before the
        # long-running generation starts.
        with open(data_path, "r", encoding="utf-8") as in_file:
            data = json.load(in_file)

        backup_path = os.path.join(res_dir, f"backup_Fact{fact_num}_{temp}.json")
        final_path = os.path.join(res_dir, f"final_output_Fact{fact_num}_{temp}.json")

        # Output mirrors the input layout: {question type: {subtype: [question, ...]}}.
        updated_rows = {}

        for qt in qtypes:

            # These two question types are never regenerated by this loop.
            if qt in ("MCQ_KNOWLEDGE_TRUE", "MCQ_KNOWLEDGE_FALSE"):
                continue

            for qsub in qsubtypes:

                for qd in tqdm(data[qt][qsub], desc=f"{qt} | {qsub}"):

                    # One continuation prompt per stored generation that had a
                    # complete <think>...</think> block.
                    batch_prompts, cots = apply_template(qd, tokenizer, qt)

                    if batch_prompts:
                        request_output = llm.generate(
                            prompts=batch_prompts,
                            sampling_params=sampling_params,
                            use_tqdm=False,
                        )
                    else:
                        # No usable reasoning in this record: nothing to generate.
                        request_output = []

                    # Only the first completion of each request is kept.
                    # NOTE(review): strip() also removes leading whitespace, so the
                    # continuation is glued directly onto the truncated prefix.
                    continuations = [resp.outputs[0].text.strip() for resp in request_output]
                    generations = [cot + text for cot, text in zip(cots, continuations)]

                    # Shallow copy of the question with its generations replaced.
                    q_with_gen = qd.copy()
                    q_with_gen["generations"] = generations

                    updated_rows.setdefault(qt, {}).setdefault(qsub, []).append(q_with_gen)

                    # Checkpoint after every question so an interrupted run keeps
                    # everything produced so far.
                    with open(backup_path, "w", encoding="utf-8") as backup_file:
                        json.dump(updated_rows, backup_file, indent=2)

        # Final dump once every question of this fact has been processed.
        with open(final_path, "w", encoding="utf-8") as final_file:
            json.dump(updated_rows, final_file, indent=2)

        print(f"✅ Fact {fact_num} Done")

    # Reported once per temperature, after all facts have been processed.
    print(f"TEMPERATURE {temp} Done")