In [1]:
import json
import os

class PIQADataset:
    """PIQA examples (a goal plus two candidate solutions) with gold labels.

    Reads ``{split}.jsonl`` (one JSON object per line with at least ``id``,
    ``goal``, ``sol1`` and ``sol2`` keys) and ``{split}-labels.lst`` (one
    integer label per line) from ``data_dir``. Both files are loaded eagerly in
    ``__init__``; blank lines are ignored and the two files must yield the same
    number of records.
    """

    SPLITS = ('train', 'dev')

    def __init__(self, data_dir, split='dev'):
        """
        Args:
            data_dir (str): Path to directory containing jsonl and .lst files.
            split (str): 'train' or 'dev'

        Raises:
            ValueError: if ``split`` is unknown, or the data and label files
                contain a different number of records.
            OSError: if either file cannot be opened.
        """
        # Explicit raises instead of ``assert``: asserts vanish under ``python -O``.
        if split not in self.SPLITS:
            raise ValueError("Split must be 'train' or 'dev'")
        self.data_path = os.path.join(data_dir, f"{split}.jsonl")
        self.label_path = os.path.join(data_dir, f"{split}-labels.lst")

        self.data = self._load_data()
        self.labels = self._load_labels()
        if len(self.data) != len(self.labels):
            raise ValueError(
                f"Data and labels length mismatch: {len(self.data)} examples in "
                f"{self.data_path} vs {len(self.labels)} labels in {self.label_path}"
            )

    @staticmethod
    def _nonblank_lines(path):
        """Yield the lines of a UTF-8 text file, skipping blank/whitespace-only lines."""
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                # A stray empty line (e.g. a trailing blank line) would otherwise
                # crash json.loads / int.
                if line.strip():
                    yield line

    def _load_data(self):
        """Parse every non-blank line of ``self.data_path`` as a JSON object."""
        return [json.loads(line) for line in self._nonblank_lines(self.data_path)]

    def _load_labels(self):
        """Parse every non-blank line of ``self.label_path`` as an integer label."""
        return [int(line.strip()) for line in self._nonblank_lines(self.label_path)]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return example ``idx`` as a flat dict that includes its gold ``label``.

        Raises IndexError when ``idx`` is out of range.
        """
        item = self.data[idx]
        label = self.labels[idx]
        return {
            'id': item['id'],
            'goal': item['goal'],
            'sol1': item['sol1'],
            'sol2': item['sol2'],
            'label': label,  # 0 = sol1 (A), 1 = sol2 (B)
        }

    def __iter__(self):
        """Yield examples in file order.

        Defined explicitly rather than relying on Python's legacy
        ``__getitem__``-until-IndexError iteration fallback.
        """
        for idx in range(len(self)):
            yield self[idx]


In [2]:
import os
import requests

# Absolute path of the `local_only` directory one level above the current working directory.
local_only = os.path.abspath(os.path.join('..', 'local_only'))

In [3]:
# Folder under `local_only` that PIQADataset reads {split}.jsonl / {split}-labels.lst from.
folder_name = "physicaliqa-train-dev"
data_dir = os.path.join(
    local_only,
    folder_name,
)

In [4]:
# Evaluate on the dev split (PIQADataset accepts 'train' or 'dev').
split = "dev"
dataset = PIQADataset(split=split, data_dir=data_dir)

In [5]:
# Jinja2 template for the two-choice question; {{goal}}, {{sol1}} and {{sol2}}
# are filled per example. The exact text is stored with every result and is
# part of the MongoDB unique key, so editing it starts a fresh set of rows.
prompt_template = (
"""
You are given a goal and two possible ways to achieve it. Choose the option that is more physically plausible in the real world.

Goal: {{goal}}
Option A: {{sol1}}
Option B: {{sol2}}

Which option is more physically plausible? Answer with just the letter A or B:

"""
)

In [6]:
from jinja2 import Template
# Compile the prompt once; it is rendered per example in the evaluation loop.
template = Template(source=prompt_template)

In [7]:
import requests
class LLM:
    """Minimal client for an Ollama-style ``/api/generate`` endpoint.

    Calling the instance sends one non-streaming generation request and returns
    the raw ``requests.Response``; callers decode it themselves
    (e.g. ``.json()['response']``).
    """

    def __init__(self, url=None, model=None, temperature=0, timeout=300):
        """
        Args:
            url (str | None): Generation endpoint; defaults to the local
                Ollama server at ``http://localhost:11434/api/generate``.
            model (str | None): Model tag to request (e.g. a name from ``/api/tags``).
            temperature (float): Sampling temperature, sent inside ``options``.
            timeout (float | None): Seconds ``requests`` waits for the server
                before raising; ``None`` waits forever.
        """
        self.url = url or 'http://localhost:11434/api/generate'
        self.model = model
        self.temperature = temperature
        self.timeout = timeout
        print(f'LLM used -> {self.model}, requests sent to URL -> {self.url}')

    def __call__(self, prompt):
        """Send ``prompt`` to the model and return the HTTP response object."""
        payload = dict(
            model=self.model,
            prompt=prompt,
            stream=False,
            # Ollama reads sampling parameters from ``options``; a top-level
            # ``temperature`` key is not part of the API and is ignored.
            options={'temperature': self.temperature},
        )
        return requests.post(self.url, json=payload, timeout=self.timeout)

In [8]:
# List the models installed on the local Ollama server. The timeout keeps the
# notebook from hanging indefinitely if the server accepts but never answers.
response = requests.get('http://localhost:11434/api/tags', timeout=10)

In [9]:
# Each entry of the "models" list carries a "name" (the model tag). On an HTTP
# error fall back to an empty list so `llm_list` is always bound; previously a
# failed request left it undefined and the loop raised NameError.
if response.ok:
    llm_list = response.json()['models']
else:
    llm_list = []
    print(f'Could not list models: HTTP {response.status_code} from {response.url}')

for model_info in llm_list:
    print(model_info['name'])

phi4:latest
stable_diffusion_3.5_medium:latest
stable_diffusion_3.5_large:latest
phi_4_gguf:latest
mistral_7b_instruct_v0.3_gguf:latest
meta_llama_3.1_8b_instruct_gguf:latest
llava_v1.5_7b_gguf:latest


In [10]:
# The model tag must match a name returned by /api/tags; url=None selects the LLM default endpoint.
model = 'meta_llama_3.1_8b_instruct_gguf:latest'
llm = LLM(url=None, model=model)

LLM used -> meta_llama_3.1_8b_instruct_gguf:latest, requests sent to URL -> http://localhost:11434/api/generate


In [11]:
from pymongo import MongoClient
from pymongo.errors import DuplicateKeyError

# One database per PIQA split and one collection per model. The unique compound
# index on (row_id, prompt_template) stores each example at most once per prompt text.
client = MongoClient("mongodb://localhost:27017/")
db_name = f'PIQA_{split}'
db = client.get_database(db_name)

collection = db.get_collection(model)
collection.create_index(
    [("row_id", 1), ("prompt_template", 1)],
    unique=True,
)

'row_id_1_prompt_template_1'

In [12]:
%%time
x =5

CPU times: user 1 μs, sys: 0 ns, total: 1 μs
Wall time: 2.15 μs


In [13]:
import datetime
# Reference wall-clock timestamp, printed as str(datetime).
print(f"{datetime.datetime.now()}")

2025-05-21 13:15:18.658232


In [14]:
%%time
print(datetime.datetime.now())
print(len(dataset))
for index, i in enumerate(dataset):
    if index%10 == 0:
        print(index, end='\t')
    row_id = i['id']
    goal = i['goal']
    sol1 = i['sol1']
    sol2 = i['sol2']
    label = i['label']
    rendered_prompt = template.render(goal = goal, sol1=sol1, sol2=sol2)
    llm_response = llm(rendered_prompt).json()
    answer = llm_response['response']

    record = dict(row_id=row_id, goal=goal, sol1=sol1, sol2=sol2, 
                  label=label, prompt_template=prompt_template,llm_response=llm_response, answer=answer )
    try:
        collection.insert_one(record)
    except DuplicateKeyError:
        print("Duplicate detected on insert", end='\t')
    except Exception as e:
        print(f"Insert error at row {index}: {e}")

print(datetime.datetime.now())

2025-05-21 13:15:18.663721
1838
0	10	20	30	40	50	60	70	80	90	100	110	120	130	140	150	160	170	180	190	200	210	220	230	240	250	260	270	280	290	300	310	320	330	340	350	360	370	380	390	400	410	420	430	440	450	460	470	480	490	500	510	520	530	540	550	560	570	580	590	600	610	620	630	640	650	660	670	680	690	700	710	720	730	740	750	760	770	780	790	800	810	820	830	840	850	860	870	880	890	900	910	920	930	940	950	960	970	980	990	1000	1010	1020	1030	1040	1050	1060	1070	1080	1090	1100	1110	1120	1130	1140	1150	1160	1170	1180	1190	1200	1210	1220	1230	1240	1250	1260	1270	1280	1290	1300	1310	1320	1330	1340	1350	1360	1370	1380	1390	1400	1410	1420	1430	1440	1450	1460	1470	1480	1490	1500	1510	1520	1530	1540	1550	1560	1570	1580	1590	1600	1610	1620	1630	1640	1650	1660	1670	1680	1690	1700	1710	1720	1730	1740	1750	1760	1770	1780	1790	1800	1810	1820	1830	2025-05-21 13:19:57.477838
CPU times: user 2.53 s, sys: 214 ms, total: 2.74 s
Wall time: 4min 38s
