In [None]:
import os
import gc
import time
import warnings

import pandas as pd
import re
import torch
import json
from tqdm import tqdm

from vllm import LLM, SamplingParams
import ctypes

In [None]:
# Ignore every Python warning raised from here on.
warnings.simplefilter(action="ignore")

# Export TOKENIZERS_PARALLELISM=false into this process's environment
# (presumably read by the Hugging Face tokenizers library - confirm).
os.environ.update(TOKENIZERS_PARALLELISM="false")

def clean_memory(deep=False):
    """Free unreferenced Python objects and release PyTorch's cached CUDA memory.

    Args:
        deep: When true, additionally call ``malloc_trim(0)`` on ``libc.so.6``
            through ctypes. This step is best-effort: if the library cannot be
            loaded or does not export ``malloc_trim`` it is skipped.

    Returns:
        None.
    """
    gc.collect()
    if deep:
        try:
            ctypes.CDLL("libc.so.6").malloc_trim(0)
        except (OSError, AttributeError):
            # OSError: "libc.so.6" cannot be loaded on this platform.
            # AttributeError: the loaded libc has no malloc_trim symbol.
            # Either way the trim is an optimisation, so carry on without it.
            pass
    torch.cuda.empty_cache()

# Model weights directory and the directory the tokenizer is loaded from.
llm_model_path = "trained_models/base"
tok_path = "temp_dir/Qwen3-1.7B"

# Earlier revisions of this cell had `dtype="half"` and `max_num_seqs=128`;
# both were commented out ("Changed this") and are no longer passed.
llm = LLM(
    model=llm_model_path,
    tokenizer=tok_path,
    gpu_memory_utilization=0.90,
    tensor_parallel_size=1,
    trust_remote_code=True,
)

In [None]:
# Take the tokenizer from the vLLM engine; it is passed to apply_template
# further down to render each prompt with the chat template.
tokenizer  = llm.get_tokenizer()

In [None]:
# Number of completions requested for every prompt.
N_SAMPLES = 20

# Sampling configuration shared by every generate() call in this notebook.
# top_k=-1 presumably disables top-k filtering - confirm against the vLLM docs.
sampling_params = SamplingParams(
    max_tokens=9000,
    n=N_SAMPLES,
    temperature=1,
    top_k=-1,
    top_p=0.95,
)

In [None]:
# Load the prompt specification. The generation loop reads it as
# {fact: {doc_type: text containing <prompt>...</prompt> sections}}.
# `json` is already imported in the notebook's import cell, and the encoding is
# given explicitly so the read does not depend on the machine's locale.
with open("final_output.json", "r", encoding="utf-8") as f:
    prompts = json.load(f)

In [None]:
# Top-level keys of the prompt file, in file order; one entry per fact.
Facts = [*prompts.keys()]

In [None]:
# Output structure
updated_rows = {}
BATCH_SIZE = 1  # or as per your VRAM and throughput
all_gens = []

In [None]:
def apply_template(prompt, tokenizer):
    """Render ``prompt`` as a single user turn using the tokenizer's chat template.

    Args:
        prompt: Text placed in the ``content`` field of the user message.
        tokenizer: Object exposing ``apply_chat_template``.

    Returns:
        Whatever ``apply_chat_template`` returns when called with
        ``tokenize=False`` and ``add_generation_prompt=True``.
    """
    user_turn = {"role": "user", "content": prompt}
    return tokenizer.apply_chat_template(
        conversation=[user_turn],
        tokenize=False,
        add_generation_prompt=True,
    )

In [None]:
def _atomic_json_dump(obj, path):
    """Write ``obj`` to ``path`` as indented JSON without leaving a partial file.

    The data is written to ``<path>.tmp`` first and then moved over ``path``
    with ``os.replace``, so an interruption mid-write leaves the previous
    contents of ``path`` intact.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(obj, f, indent=2)
    os.replace(tmp_path, path)


# For every fact and document type: pull the individual prompts out of the
# prompt block, generate completions in batches, and checkpoint after each batch.
for fact in Facts:
    doctypes = prompts[fact]
    for doc_type, prompt_block in doctypes.items():

        # Extract all <prompt>...</prompt> strings (DOTALL: prompts may span lines)
        curr_prompts = re.findall(r"<prompt>(.*?)</prompt>", prompt_block, re.DOTALL)
        total_prompts = len(curr_prompts)

        for i in tqdm(range(0, total_prompts, BATCH_SIZE), desc=f"{fact[:30]}... | {doc_type}"):
            batch_raw_prompts = curr_prompts[i:i + BATCH_SIZE]

            # Format prompts for vLLM (chat-style)
            batch_prompts = [
                apply_template(prompt, tokenizer) for prompt in batch_raw_prompts
            ]

            # Generate using vLLM
            request_output = llm.generate(
                prompts=batch_prompts,
                sampling_params=sampling_params,
                use_tqdm=False,
            )

            # Entries are created only once a batch exists, so a doc_type with
            # no prompts does not appear in the output.
            doc_results = updated_rows.setdefault(fact, {}).setdefault(doc_type, {})

            # Store results: each prompt yields several generations.
            # request_output[j] is read as the result for the j-th prompt.
            for j, prompt_text in enumerate(batch_raw_prompts):
                generations = [out.text.strip() for out in request_output[j].outputs]
                # Keyed by raw prompt text: a repeated prompt overwrites the earlier entry.
                doc_results[prompt_text] = generations
                all_gens.extend(generations)

            # Save after each batch; atomic so an interrupt cannot corrupt the checkpoint
            _atomic_json_dump(updated_rows, "backup.json")

            # Clamp so a final partial batch never reports more than the total.
            done = min(i + BATCH_SIZE, total_prompts)
            print(f"BATCH {done} / {total_prompts} DONE for {doc_type}")

# Final dump
# NOTE(review): "final_output.json" is also the file the prompts were loaded
# from, so this replaces the input with the generations. Re-running the notebook
# afterwards would hand re.findall a dict, not prompt text. Confirm this
# overwrite is intended, or write to a separate path.
_atomic_json_dump(updated_rows, "final_output.json")
