In [1]:
import os
import gc
import time
import warnings

import json
import pandas as pd
import re
import torch
import json
from tqdm import tqdm

from vllm import LLM, SamplingParams
import ctypes

INFO 08-18 04:49:50 [__init__.py:235] Automatically detected platform cuda.


In [2]:
warnings.simplefilter('ignore')

os.environ["TOKENIZERS_PARALLELISM"] = "false"

def clean_memory(deep=False):
    gc.collect()
    if deep:
        ctypes.CDLL("libc.so.6").malloc_trim(0)
    torch.cuda.empty_cache()

llm_model_path = '/pscratch/sd/r/ritesh11/temp_dir/trained_models/unf_cot'
tok_path = '/pscratch/sd/r/ritesh11/temp_dir/Qwen3-1.7B'
res_dir = "/pscratch/sd/r/ritesh11/temp_dir/MATS_qna/unf_cot_perturb"
data_dir = "/pscratch/sd/r/ritesh11/temp_dir/MATS_qna/og"

In [3]:
llm = LLM(
    llm_model_path,
    tokenizer=tok_path,
    #dtype="half",                -> Changed this
    #max_num_seqs=128,            -> Changed this       
    trust_remote_code=True,     
    tensor_parallel_size=1,      
    gpu_memory_utilization=0.85, 
)

INFO 08-18 04:50:02 [config.py:3440] Downcasting torch.float32 to torch.bfloat16.
INFO 08-18 04:50:02 [config.py:1604] Using max model len 40960
INFO 08-18 04:50:02 [config.py:2434] Chunked prefill is enabled with max_num_batched_tokens=8192.
INFO 08-18 04:50:02 [core.py:572] Waiting for init message from front-end.
INFO 08-18 04:50:02 [core.py:71] Initializing a V1 LLM engine (v0.10.0) with config: model='/pscratch/sd/r/ritesh11/temp_dir/trained_models/unf_cot', speculative_config=None, tokenizer='/pscratch/sd/r/ritesh11/temp_dir/Qwen3-1.7B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config={}, tokenizer_revision=None, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=40960, download_dir=None, load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(backend='auto', disable_fallback=Fals

Loading safetensors checkpoint shards:   0% Completed | 0/2 [00:00<?, ?it/s]


INFO 08-18 04:52:35 [default_loader.py:262] Loading weights took 147.71 seconds
INFO 08-18 04:52:35 [gpu_model_runner.py:1892] Model loading took 3.2152 GiB and 147.972790 seconds
INFO 08-18 04:52:43 [backends.py:530] Using cache directory: /global/homes/r/ritesh11/.cache/vllm/torch_compile_cache/fc95524011/rank_0_0/backbone for vLLM's torch.compile
INFO 08-18 04:52:43 [backends.py:541] Dynamo bytecode transform time: 7.95 s
INFO 08-18 04:52:52 [backends.py:161] Directly load the compiled graph(s) for dynamic shape from the cache, took 7.767 s
INFO 08-18 04:52:53 [monitor.py:34] torch.compile takes 7.95 s in total
INFO 08-18 04:52:54 [gpu_worker.py:255] Available KV cache memory: 28.82 GiB
INFO 08-18 04:52:54 [kv_cache_utils.py:833] GPU KV cache size: 269,840 tokens
INFO 08-18 04:52:54 [kv_cache_utils.py:837] Maximum concurrency for 40,960 tokens per request: 6.59x


Capturing CUDA graph shapes: 100%|██████████| 67/67 [00:01<00:00, 35.97it/s]


INFO 08-18 04:52:56 [gpu_model_runner.py:2485] Graph capturing finished in 2 secs, took 0.49 GiB
INFO 08-18 04:52:56 [core.py:193] init engine (profile, create kv cache, warmup model) took 21.08 seconds


In [4]:
tokenizer  = llm.get_tokenizer()

In [5]:
N_SAMPLES = 1

sampling_params = SamplingParams(
    n=N_SAMPLES,
    temperature=1,
    max_tokens=3000,
)

In [6]:
qtypes = ['MCQ_DISTINGUISH', 'GENERATIVE_DISTINGUISH']

qsubtypes = ['direct_reference', 'indirect_reference']

In [7]:
import re
import nltk
nltk.download('punkt_tab')

def truncate_cot_sentence_level(gen, ratio=0.75):
    """
    Extract reasoning from <think>...</think>, truncate by sentence (~ratio),
    and return with only <think> (no closing tag).
    """
    match = re.search(r"<think>(.*?)</think>", gen, re.DOTALL)
    if not match:
        return None  # skip if no CoT

    cot_text = match.group(1).strip()

    # Sentence split
    sentences = nltk.sent_tokenize(cot_text)
    cutoff = max(1, int(len(sentences) * ratio))  # at least 1 sentence
    truncated = " ".join(sentences[:cutoff])

    # Return incomplete CoT
    return "<think>\n" + truncated

[nltk_data] Downloading package punkt_tab to
[nltk_data]     /global/homes/r/ritesh11/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


In [8]:
import re

def apply_template(question_data, tokenizer, qtype):
    
    if qtype == "GENERATIVE_DISTINGUISH":
        messages = [
            {
                "role": "system",
                "content": (
                    "You are an intelligent assistant tasked with distinguishing between "
                    "factual and counterfactual scenarios. Read the prompt and the two scenarios. "
                    "Based on your world knowledge, determine which scenario aligns with real-world facts. "
                    "Answer in one word "
                    "in this format:\n\n"
                    "/boxed{0} or /boxed{1}\n\n"
                    "Use this format exactly as shown. Do not write anything after the boxed answer."
                )
            },
            {
                "role": "user",
                "content": (
                    f"Prompt: {question_data['prompt']}\n\n"
                    f"Scenario 0: {question_data['scenarios']['true_fact_scenario']}\n"
                    f"Scenario 1: {question_data['scenarios']['false_fact_scenario']}\n\n"
                    "Which scenario aligns with real-world facts?"
                )
            }
        ]
    else:
        messages = [
            {
                "role": "system",
                "content": (
                    "You are an intelligent assistant tasked with answering multiple-choice questions. "
                    "Choose the most appropriate option from the given choices. "
                    "Your final answer must be a single word or phrase exactly as it appears in the options. "
                    "End your response with the answer written in this format:\n\n"
                    "/boxed{ANSWER}\n\n "
                    "Do not write anything else after boxed answer."
                )
            },
            {
                "role": "user",
                "content": (
                    f"{question_data['question']}\n"
                    f"Options: {question_data['options']}"
                )
            }
        ]

    formatted_prompt = tokenizer.apply_chat_template(
        conversation=messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=True,
    )

    # Extract reasoning with <think>... but no </think>
    answer_cots = []
    cots = []
    for gen in question_data['generations']:
        cot = truncate_cot_sentence_level(gen)
        if cot is None:
            continue
        cots.append(cot)
        answer_cots.append(formatted_prompt + cot)
        # match = re.search(r"<think>(.*?)</think>", gen, re.DOTALL)
        # if match:
        #     cot_content = match.group(1).strip()
        #     cot = "<think>" + cot_content  # open tag only
        #     cots.append(cot)
        #     answer_cots.append(formatted_prompt + cot)
    
    return answer_cots, cots

In [9]:
for temp in [0.5]:
    
    sampling_params.temperature = temp
    
    for fact_num in range(10):
        if fact_num in [0,1,3]:
            continue
        data_path = os.path.join(data_dir, f"final_output_Fact{fact_num}_{temp}.json")
        with open(data_path, "r") as f:
            data = json.load(f)
            
            # Output structure
            updated_rows = {}

            for qt in qtypes:
                
                if qt == "MCQ_KNOWLEDGE_TRUE" or qt == "MCQ_KNOWLEDGE_FALSE":
                    continue
                
                for qsub in qsubtypes:
            
                    qdata = data[qt][qsub]
            
                    for qd in tqdm(qdata, desc=f"{qt} | {qsub}"):                     
                        
                        # Format prompts for vLLM (chat-style)
                        batch_prompts, cots = apply_template(qd, tokenizer, qt)
                        
                        # Generate using vLLM
                        request_output = llm.generate(
                            prompts=batch_prompts,
                            sampling_params=sampling_params,
                            use_tqdm=False,
                        )
                        
                        if qt not in updated_rows:
                            updated_rows[qt] = {}
                        if qsub not in updated_rows[qt]:
                            updated_rows[qt][qsub] = []
                
                        # one generation per prompt
                        generations = [resp.outputs[0].text.strip() for resp in request_output]
                        generations = [c+g for g,c in zip(generations,cots)]
                        # Clone original q and attach generations
                        q_with_gen = qd.copy()
                        q_with_gen["generations"] = generations
                        
                        # Append to results list
                        updated_rows[qt][qsub].append(q_with_gen)
                
                        # Backup save
                        batch_filename = os.path.join(res_dir, f"backup_Fact{fact_num}_{temp}.json")
                        with open(batch_filename, "w") as f:
                            json.dump(updated_rows, f, indent=2)
                        
        
            # Final dump after all batches
            final_filename = os.path.join(res_dir, f"final_output_Fact{fact_num}_{temp}.json")
            with open(final_filename, "w") as f:
                json.dump(updated_rows, f, indent=2)
            
            print(f"✅ Fact {fact_num} Done")
            
        print(f"TEMPERATURE {temp} Done")

MCQ_DISTINGUISH | direct_reference: 100%|██████████| 15/15 [02:10<00:00,  8.68s/it]
MCQ_DISTINGUISH | indirect_reference: 100%|██████████| 15/15 [02:26<00:00,  9.77s/it]
GENERATIVE_DISTINGUISH | direct_reference: 100%|██████████| 15/15 [02:47<00:00, 11.19s/it]
GENERATIVE_DISTINGUISH | indirect_reference: 100%|██████████| 15/15 [02:10<00:00,  8.70s/it]


✅ Fact 2 Done
TEMPERATURE 0.5 Done


MCQ_DISTINGUISH | direct_reference: 100%|██████████| 15/15 [02:42<00:00, 10.84s/it]
MCQ_DISTINGUISH | indirect_reference: 100%|██████████| 15/15 [02:46<00:00, 11.12s/it]
GENERATIVE_DISTINGUISH | direct_reference: 100%|██████████| 15/15 [01:36<00:00,  6.42s/it]
GENERATIVE_DISTINGUISH | indirect_reference: 100%|██████████| 15/15 [01:27<00:00,  5.81s/it]


✅ Fact 4 Done
TEMPERATURE 0.5 Done


MCQ_DISTINGUISH | direct_reference: 100%|██████████| 15/15 [01:04<00:00,  4.31s/it]
MCQ_DISTINGUISH | indirect_reference: 100%|██████████| 15/15 [01:12<00:00,  4.80s/it]
GENERATIVE_DISTINGUISH | direct_reference: 100%|██████████| 15/15 [01:27<00:00,  5.86s/it]
GENERATIVE_DISTINGUISH | indirect_reference: 100%|██████████| 15/15 [01:56<00:00,  7.77s/it]


✅ Fact 5 Done
TEMPERATURE 0.5 Done


MCQ_DISTINGUISH | direct_reference: 100%|██████████| 15/15 [01:21<00:00,  5.42s/it]
MCQ_DISTINGUISH | indirect_reference: 100%|██████████| 15/15 [03:05<00:00, 12.38s/it]
GENERATIVE_DISTINGUISH | direct_reference: 100%|██████████| 15/15 [00:25<00:00,  1.71s/it]
GENERATIVE_DISTINGUISH | indirect_reference: 100%|██████████| 15/15 [00:11<00:00,  1.36it/s]


✅ Fact 6 Done
TEMPERATURE 0.5 Done


MCQ_DISTINGUISH | direct_reference: 100%|██████████| 15/15 [01:45<00:00,  7.05s/it]
MCQ_DISTINGUISH | indirect_reference: 100%|██████████| 15/15 [01:31<00:00,  6.10s/it]
GENERATIVE_DISTINGUISH | direct_reference: 100%|██████████| 15/15 [01:53<00:00,  7.55s/it]
GENERATIVE_DISTINGUISH | indirect_reference: 100%|██████████| 15/15 [01:36<00:00,  6.46s/it]


✅ Fact 7 Done
TEMPERATURE 0.5 Done


MCQ_DISTINGUISH | direct_reference: 100%|██████████| 15/15 [00:10<00:00,  1.43it/s]
MCQ_DISTINGUISH | indirect_reference: 100%|██████████| 15/15 [00:22<00:00,  1.52s/it]
GENERATIVE_DISTINGUISH | direct_reference: 100%|██████████| 15/15 [00:30<00:00,  2.01s/it]
GENERATIVE_DISTINGUISH | indirect_reference: 100%|██████████| 15/15 [02:02<00:00,  8.17s/it]


✅ Fact 8 Done
TEMPERATURE 0.5 Done


MCQ_DISTINGUISH | direct_reference: 100%|██████████| 15/15 [00:38<00:00,  2.55s/it]
MCQ_DISTINGUISH | indirect_reference: 100%|██████████| 15/15 [00:39<00:00,  2.63s/it]
GENERATIVE_DISTINGUISH | direct_reference: 100%|██████████| 15/15 [03:04<00:00, 12.31s/it]
GENERATIVE_DISTINGUISH | indirect_reference: 100%|██████████| 15/15 [03:25<00:00, 13.73s/it]

✅ Fact 9 Done
TEMPERATURE 0.5 Done





In [10]:
print(batch_prompts[2])

<|im_start|>system
You are an intelligent assistant tasked with distinguishing between factual and counterfactual scenarios. Read the prompt and the two scenarios. Based on your world knowledge, determine which scenario aligns with real-world facts. Answer in one word in this format:

/boxed{0} or /boxed{1}

Use this format exactly as shown. Do not write anything after the boxed answer.<|im_end|>
<|im_start|>user
Prompt: Which of the following scenarios is more plausible?

Scenario 0: A tourist on a boat in the ocean spots a pod of Blue Whales, the largest mammals on Earth, and witnesses their synchronized behavior, which is a rare and beautiful sight.
Scenario 1: A tourist on a safari in the savannah spots a herd of African Elephants, the largest mammals on Earth, and witnesses their synchronized behavior, which is a rare and beautiful sight.

Which scenario aligns with real-world facts?<|im_end|>
<|im_start|>assistant
<think>
Okay, let's tackle this question. The user is asking which