In [5]:
# Print GPU, driver and memory status using the nvidia-smi command-line tool.
!nvidia-smi

Wed Oct 29 10:16:12 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.05              Driver Version: 560.35.05      CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 3090        On  |   00000000:81:00.0 Off |                  N/A |
|  0%   40C    P8             31W /  370W |       4MiB /  24576MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [2]:
# NOTE(review): sends a kill signal to a hard-coded PID; presumably a leftover
# process from an earlier session - confirm the PID before re-running this cell.
!kill 925008

In [1]:
# Environment sanity check: print the host name and the `python` found on PATH,
# then report whether PyTorch can see a CUDA device.
!hostname
!which python
import torch
print("CUDA available:", torch.cuda.is_available())

limbo
/opt/miniforge3/envs/jupyterhub/bin/python
CUDA available: True


In [6]:
import os
import json
from textwrap import dedent
from typing import Dict, Any, List, Tuple, Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline


def setup_model(model_id: str = "mistralai/Mistral-7B-Instruct-v0.3"):
    """
    Build a Hugging Face text-generation pipeline for ``model_id``.

    The tokenizer and the causal LM are both fetched with ``from_pretrained``;
    the model is loaded in float16 with ``device_map="auto"`` and has its
    KV cache enabled via ``config.use_cache``.

    Args:
        model_id: Hub identifier (or local path) of the chat model to load.

    Returns:
        A ``(generator, tokenizer)`` tuple, where ``generator`` is the
        text-generation pipeline wrapping the loaded model.
    """
    print(f"[LOAD] model={model_id}")
    torch.backends.cudnn.benchmark = True

    load_options = {"device_map": "auto", "torch_dtype": torch.float16}

    tok = AutoTokenizer.from_pretrained(model_id)
    lm = AutoModelForCausalLM.from_pretrained(model_id, **load_options)
    lm.config.use_cache = True

    text_gen = pipeline(
        "text-generation",
        model=lm,
        tokenizer=tok,
        device_map="auto",
    )
    return text_gen, tok


In [7]:
def read_jsonl(path: str, max_items: Optional[int] = None):
    """
    Stream records from a .jsonl file, one parsed JSON value per non-blank line.

    Args:
        path: Path of the UTF-8 encoded JSON-lines file.
        max_items: Maximum number of records to yield. ``None`` means no limit;
            zero or a negative number yields nothing.

    Yields:
        The object produced by ``json.loads`` for each non-blank line.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    # Check the limit up front: testing it only after a yield would still
    # hand out one record for max_items=0.
    if max_items is not None and max_items <= 0:
        return

    count = 0
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                # Blank / whitespace-only lines are skipped and not counted.
                continue
            yield json.loads(line)
            count += 1
            if max_items is not None and count >= max_items:
                break


def write_jsonl(path: str, records: List[Dict[str, Any]]):
    """
    Write records to ``path`` as JSON lines (one JSON document per line).

    The parent directory is created when it does not exist. Non-ASCII
    characters are written as-is (``ensure_ascii=False``) in UTF-8.

    Args:
        path: Destination file; an existing file is overwritten.
        records: Iterable of JSON-serialisable dicts.

    Raises:
        OSError: if the directory or file cannot be created.
        TypeError: if a record is not JSON-serialisable.
    """
    # os.path.dirname("file.jsonl") is "", and os.makedirs("") raises
    # FileNotFoundError, so only create a directory when there is one.
    parent_dir = os.path.dirname(path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(path, "w", encoding="utf-8") as f:
        for rec in records:
            f.write(json.dumps(rec, ensure_ascii=False))
            f.write("\n")


_TEXT_KEYS_PRIORITY = ("sent", "text", "Text", "sentence", "Sentence")


def extract_text_field(rec: Dict[str, Any]) -> Tuple[str, str]:
    """
    Pick the text payload of an input record.

    The keys in ``_TEXT_KEYS_PRIORITY`` are tried in order; the first one whose
    value is a string with non-whitespace content wins. If none qualifies, the
    longest string-valued field of the record is used (first one on ties).

    Returns:
        ``(stripped_text, key_used)``; ``("", "")`` when the record has no
        non-empty string field at all.
    """
    for key in _TEXT_KEYS_PRIORITY:
        candidate = rec.get(key)
        if not isinstance(candidate, str):
            continue
        cleaned = candidate.strip()
        if cleaned:
            return cleaned, key

    # No preferred key matched: fall back to the longest raw string value.
    string_fields = [
        (key, val) for key, val in rec.items() if isinstance(val, str) and val
    ]
    if not string_fields:
        return "", ""
    longest_key, longest_val = max(string_fields, key=lambda kv: len(kv[1]))
    return longest_val.strip(), longest_key


In [8]:
def _build_concept_index(ontology_json: Dict[str, Any]) -> Dict[str, str]:
    """
    Map any known identifier (qid/id/label) -> canonical label string.
    This lets us convert domain/range IDs into human-readable names.
    """
    idx: Dict[str, str] = {}
    for concept in ontology_json.get("concepts", []):
        label = str(concept.get("label", "")).strip()
        if not label:
            continue

        for keyname in ("qid", "id", "label"):
            raw_val = concept.get(keyname)
            if raw_val is None:
                continue

            sval = str(raw_val).strip()
            if sval:
                idx[sval] = label
    return idx


def _label_for(raw_val: Any, cindex: Dict[str, str]) -> str:
    """
    Convert domain/range IDs to readable labels.
    Fallback to string form of raw_val.
    """
    if raw_val is None:
        return ""
    rval = str(raw_val).strip()
    return cindex.get(rval, rval)


def render_concept_list(ontology_json: Dict[str, Any]) -> str:
    """
    Render the ontology concepts as a newline-separated bullet list.

    Each concept with a non-blank ``label`` contributes one ``- <label>`` line;
    concepts without a usable label are left out.
    """
    labels = (
        str(concept.get("label", "")).strip()
        for concept in ontology_json.get("concepts", [])
    )
    return "\n".join(f"- {label}" for label in labels if label)


def render_relation_list(ontology_json: Dict[str, Any]) -> str:
    """
    Render the ontology relations as a newline-separated bullet list.

    Every relation with a non-blank ``label`` becomes one line of the form
    ``- relationLabel(domainLabel,rangeLabel)``, where the domain and range
    identifiers are resolved to concept labels when the ontology knows them.
    """
    concept_names = _build_concept_index(ontology_json)
    bullets: List[str] = []

    for relation in ontology_json.get("relations", []):
        name = str(relation.get("label", "")).strip()
        if not name:
            continue
        domain_name = _label_for(relation.get("domain"), concept_names)
        range_name = _label_for(relation.get("range"), concept_names)
        bullets.append(f"- {name}({domain_name},{range_name})")

    return "\n".join(bullets)


def _escape_multiline(s: str) -> str:
    """
    Escape backslashes and quotes so we can safely embed text
    inside quoted blocks in the USER prompt.
    """
    return s.replace("\\", "\\\\").replace('"', '\\"')


In [9]:
def build_prompt1_system() -> str:
    """
    Return the SYSTEM message for Prompt 1.

    The message asks for entity mentions plus candidate triples, covering both
    ontology-aligned triples and other facts stated in the text, and demands a
    JSON-only reply.
    """
    opening = " ".join((
        "You are a KG triple proposer in a Tree-of-Thoughts loop.",
        "First detect entity mentions and assign tentative ontology types.",
        "Then propose candidate triples [subject, relation, object] that express factual statements in the text.",
        "You MUST include:",
    ))
    message_lines = [
        opening,
        "1) triples whose relation/domain/range matches the ontology, AND",
        "2) any other clearly stated factual triples in the text even if the relation or types are not present in the ontology.",
        "Return only JSON. Do not include any natural language outside JSON.",
    ]
    return "\n".join(message_lines)


# The template is dedented BEFORE the dynamic values are inserted. Dedenting
# after interpolation does not work: the second and later lines of a multi-line
# concept/relation block (or of the input text) start at column 0, which makes
# the common margin empty and leaves every template line indented by 4 spaces.
_PROMPT1_USER_TEMPLATE = dedent("""
    Task:
    1) From the text, list detected entity mentions with tentative ontology types.
    2) Propose up to k={k} candidate triples [subject, relation, object].

    VERY IMPORTANT:
    - You MUST include all explicit factual triples stated in the text, even if the relation,
      subject type, or object type is not listed in the ontology.
    - ALSO include ontology-valid triples whose domain/range matches the ontology relations.

    For each triple, include confidence ∈ [0,1] and cite the exact supporting span(s).

    Text
    "{text}"

    Ontology concepts
    {concept_block}

    Ontology relations (domain → range)
    {relation_block}

    Output format (JSON only)
    {{
      "mentions": [
        {{"surface": "...", "type_candidates": ["ConceptA","ConceptB"], "span": [start,end]}}
      ],
      "triples": [
        {{
          "triple": ["subject","relation","object"],
          "confidence": 0.0,
          "support": "exact quote from text",
          "notes": "why this triple is supported; if ontology applies, explain domain/range fit. If not in ontology, say 'not in ontology but supported by text'."
        }}
      ]
    }}

    Constraints
    - Extract ALL clearly stated factual triples in the text.
    - If a triple matches an ontology relation, enforce domain→range consistency and mention that in notes.
    - If a triple does NOT match any ontology relation, you MUST STILL include it (do not discard it).
    - Always extract any explicit date, time, or year mentioned in the text as part of a factual triple.
    - Resolve pronouns to the nearest valid antecedent and describe that in notes.
    - Do not invent entities that are not mentioned in the text.
    - Output MUST be valid JSON and nothing else.
    """).strip()


def build_prompt1_user(
    TEXT: str,
    ontology_json: Dict[str, Any],
    k: int,
) -> str:
    """
    Build the USER message for Prompt 1.

    The message contains the task description, the (escaped) input text, the
    ontology concept and relation bullet lists, the expected JSON schema
    (``mentions[]``, ``triples[]``) and the extraction constraints. The wording
    tells the model to keep triples that are not covered by the ontology.

    Args:
        TEXT: Input passage; backslashes and double quotes are escaped before
            it is embedded in the quoted "Text" block.
        ontology_json: Parsed ontology with ``"concepts"`` and ``"relations"``.
        k: Upper bound on the number of candidate triples requested.

    Returns:
        The prompt string with all template lines flush-left, regardless of how
        many lines the ontology blocks or the text contain.
    """
    return _PROMPT1_USER_TEMPLATE.format(
        k=k,
        text=_escape_multiline(TEXT),
        concept_block=render_concept_list(ontology_json),
        relation_block=render_relation_list(ontology_json),
    )


In [10]:
def generate_model_response(
    generator,
    tokenizer,
    system_msg: str,
    user_msg: str,
    max_new_tokens: int = 768,
    temperature: float = 0.25,
) -> str:
    """
    Run one chat-style generation and return the raw generated text.

    The system and user messages are rendered with the tokenizer's chat
    template and passed to ``generator`` as a single prompt string. The caller
    is expected to parse the result; the model is asked (but not guaranteed)
    to output only a JSON object.

    Args:
        generator: Callable text-generation pipeline.
        tokenizer: Tokenizer providing ``apply_chat_template`` and
            ``eos_token_id``.
        system_msg: Content of the system turn.
        user_msg: Content of the user turn.
        max_new_tokens: Generation length limit.
        temperature: Sampling temperature. A positive value enables nucleus
            sampling (``top_p=0.9``); zero, a negative value or ``None``
            switches to greedy decoding instead of passing an invalid
            temperature to the sampler.

    Returns:
        The generated continuation (prompt excluded).
    """
    messages = [
        {"role": "system", "content": system_msg},
        {"role": "user",   "content": user_msg},
    ]

    # NOTE(review): the rendered template may already contain the BOS token and
    # the pipeline may add another when tokenizing a plain string - confirm for
    # the installed transformers version.
    prompt_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    gen_kwargs: Dict[str, Any] = {
        "max_new_tokens": max_new_tokens,
        "return_full_text": False,
        "truncation": False,
        "eos_token_id": tokenizer.eos_token_id,
        # Reuse EOS as the padding token.
        "pad_token_id": tokenizer.eos_token_id,
    }
    if temperature is not None and temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=temperature, top_p=0.9)
    else:
        # Sampling with temperature <= 0 is rejected by the generation code;
        # greedy decoding is the deterministic equivalent.
        gen_kwargs["do_sample"] = False

    out = generator(prompt_text, **gen_kwargs)

    # HF pipeline can return dicts or strings depending on version
    first = out[0]
    return first["generated_text"] if isinstance(first, dict) else first


In [11]:
import json
import re
from typing import Any, Dict, List, Optional, Tuple

def robust_parse_model_output(raw_response: str) -> Dict[str, Any]:
    """
    Try very hard to get structured data out of the model response.

    Strategies, in order:
      1. parse the whole response as JSON;
      2. parse the slice from the first ``{`` to the last ``}``;
      3. parse only the first complete JSON value starting at the first ``{``
         (covers responses with trailing text that itself contains braces);
      4. regex-mine ``"triple": [...]`` patterns from the raw text.

    Strategies 1-3 only succeed when the parsed value is a JSON object (dict)
    from which triples can be extracted without error.

    Args:
        raw_response: Text generated by the model. A non-string value is
            treated as an empty response.

    Returns:
        A dict with:
          {
            "triples": [ [head, rel, tail], ... ],
            "mentions": [...] or None,
            "raw_json_obj": ... or None
          }
        ``mentions`` and ``raw_json_obj`` are None when only the regex
        fallback applied.
    """
    text = raw_response if isinstance(raw_response, str) else ""

    def _bundle_from(candidate_obj: Any) -> Optional[Dict[str, Any]]:
        # Only JSON objects carry the expected schema; anything else is a miss.
        if not isinstance(candidate_obj, dict):
            return None
        try:
            triples = extract_triples_from_obj(candidate_obj)
        except (TypeError, AttributeError, KeyError):
            # Malformed "triples" payload: let the next strategy try.
            return None
        return {
            "triples": triples,
            "mentions": candidate_obj.get("mentions"),
            "raw_json_obj": candidate_obj,
        }

    def _loads_or_none(candidate_text: str) -> Any:
        try:
            return json.loads(candidate_text)
        except ValueError:  # json.JSONDecodeError is a ValueError
            return None

    # 1. Try direct parse
    bundle = _bundle_from(_loads_or_none(text))
    if bundle is not None:
        return bundle

    start_i = text.find("{")
    end_i = text.rfind("}")

    # 2. Try slice from first { to last }
    if start_i != -1 and end_i > start_i:
        bundle = _bundle_from(_loads_or_none(text[start_i:end_i + 1]))
        if bundle is not None:
            return bundle

    # 3. Try the first complete JSON value, ignoring whatever follows it
    if start_i != -1:
        try:
            first_obj, _end = json.JSONDecoder().raw_decode(text[start_i:])
        except ValueError:
            first_obj = None
        bundle = _bundle_from(first_obj)
        if bundle is not None:
            return bundle

    # 4. Fallback: regex mine triples from messy text
    return {
        "triples": extract_triples_via_regex(text),
        "mentions": None,
        "raw_json_obj": None,
    }


def extract_triples_from_obj(obj: Any) -> List[List[str]]:
    """
    Safely pull triples out of a parsed JSON object.

    Each entry of ``obj["triples"]`` may be either a bare list
    ``["h", "r", "t"]`` or a dict ``{"triple": ["h", "r", "t"], ...}``.
    Lists longer than three elements are truncated to the first three; entries
    of any other shape are ignored.

    Args:
        obj: Value returned by ``json.loads``.

    Returns:
        The extracted triples. Empty when ``obj`` is not a dict, has no
        ``"triples"`` key, or ``"triples"`` is not a list (e.g. ``null``),
        instead of raising while iterating a non-iterable value.
    """
    if not isinstance(obj, dict):
        return []

    items = obj.get("triples")
    if not isinstance(items, list):
        return []

    results: List[List[str]] = []
    for item in items:
        # Unwrap {"triple": [...]} entries; bare lists are used as they are.
        candidate = item.get("triple") if isinstance(item, dict) else item
        if isinstance(candidate, list) and len(candidate) >= 3:
            results.append(candidate[:3])

    return results


def extract_triples_via_regex(raw_text: str) -> List[List[str]]:
    """
    Ultra-forgiving fallback.
    Finds patterns like:
      "triple": ["Head", "Rel", "Tail"]
    even if the outer JSON is broken.

    Each of the three strings may contain JSON escape sequences (for example
    an escaped double quote). Matched strings are JSON-decoded, so ``\\"``
    becomes ``"``; a string whose escapes are not valid JSON is returned as
    matched.

    Args:
        raw_text: Raw model output.

    Returns:
        One ``[head, relation, tail]`` list per match, in order of appearance.
    """
    # A JSON-ish string body: any run of non-quote, non-backslash characters
    # or backslash-escaped characters. The two alternatives cannot match the
    # same character, so matching stays linear.
    str_body = r'((?:[^"\\]|\\.)+)'
    triple_pattern = re.compile(
        r'"triple"\s*:\s*\[\s*"' + str_body + r'"\s*,\s*"' + str_body
        + r'"\s*,\s*"' + str_body + r'"\s*\]'
    )

    def _decode(fragment: str) -> str:
        try:
            return json.loads('"' + fragment + '"')
        except ValueError:
            # Invalid escape or raw control character: keep the matched text.
            return fragment

    return [
        [_decode(part) for part in match.groups()]
        for match in triple_pattern.finditer(raw_text)
    ]


In [12]:
def run_prompt1_pipeline(
    generator,
    tokenizer,
    input_jsonl_path: str,
    ontology_json_path: str,
    output_jsonl_path: str,
    k: int = 6,
    max_items: Optional[int] = None,
    max_new_tokens: int = 900,
    temperature: float = 0.25,
    verbose: bool = False,
) -> List[Dict[str, Any]]:
    """
    Core batch function for Prompt 1.

    Reads each row from ``input_jsonl_path``, runs Prompt 1 on it, and writes
    a JSONL file with the prompts, the raw model output and the parsed result.

    Args:
        generator: Text-generation pipeline used for inference.
        tokenizer: Tokenizer matching ``generator``.
        input_jsonl_path: JSON-lines file with the input records.
        ontology_json_path: JSON file with the ontology (loaded once).
        output_jsonl_path: Destination JSON-lines file (overwritten).
        k: Maximum number of candidate triples requested in the prompt.
        max_items: Optional cap on the number of input records processed.
        max_new_tokens: Generation length limit per record.
        temperature: Sampling temperature per record.
        verbose: Print per-item debug information.

    Returns:
        The list of output records, as written to ``output_jsonl_path``. Each
        record's ``response`` holds ``LLM_output`` (raw text), ``json`` (the
        parsed JSON object, or None if the output was not parseable) and
        ``triples`` (the triples recovered by robust parsing, which are
        available even when ``json`` is None).
    """

    # load ontology once
    with open(ontology_json_path, "r", encoding="utf-8") as f:
        ontology_json = json.load(f)

    sys_text = build_prompt1_system()
    outputs: List[Dict[str, Any]] = []

    for idx, rec in enumerate(read_jsonl(input_jsonl_path, max_items=max_items)):
        # extract text from record
        text_val, text_key = extract_text_field(rec)

        # build user prompt
        usr_text = build_prompt1_user(
            TEXT=text_val,
            ontology_json=ontology_json,
            k=k,
        )

        if verbose:
            print(f"\n[ITEM {idx}] text_key={text_key}")
            print(f"[PROMPT_USER] {usr_text[:320]} ...")

        # generate model response
        raw_response = generate_model_response(
            generator=generator,
            tokenizer=tokenizer,
            system_msg=sys_text,
            user_msg=usr_text,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
        )

        # robust parsing of model output
        parsed_bundle = robust_parse_model_output(raw_response)

        if verbose and parsed_bundle.get("raw_json_obj") is None:
            print(f"[WARN] item {idx}: output is not valid JSON; kept regex-recovered triples only")

        # construct output record
        out_record = {
            "id": rec.get("id"),
            "input text": text_val,
            "prompts": {
                "system_prompt": sys_text,
                "user_prompt": usr_text,
            },
            "response": {
                "LLM_output": raw_response,
                "json": parsed_bundle.get("raw_json_obj"),
                # Persist the recovered triples: when the output is malformed
                # "json" is None and these would otherwise be thrown away.
                "triples": parsed_bundle.get("triples"),
            },
        }

        outputs.append(out_record)

    # write all outputs
    write_jsonl(output_jsonl_path, outputs)

    if verbose:
        print(f"[DONE] wrote {len(outputs)} records -> {output_jsonl_path}")

    return outputs


In [10]:
# import pprint

# # --- CONFIG ---
# ONTOLOGY_JSON = "/upb/users/b/balram/profiles/unix/cs/promptKG/data/input/wikidata/input_ontology/1_movie_ontology.json"
# INPUT_JSONL   = "/upb/users/b/balram/profiles/unix/cs/promptKG/data/input/wikidata/input_text/ont_1_movie_test.jsonl"
# OUTPUT_JSONL  = "/upb/users/b/balram/profiles/unix/cs/promptKG/data/output/prompt1/wikidata/ont_1_movie_output_test1.jsonl"

# MAX_ITEMS        = 1          # how many examples to actually run through model
# MAX_NEW_TOKENS   = 900
# TEMPERATURE      = 0.25
# VERBOSE          = True
# K_CANDIDATES     = 6

# # --- 1. Quick peek at input file ---
# peek_items = list(read_jsonl(INPUT_JSONL, max_items=3))

# if not peek_items:
#     print(f"[ERROR] No records found in: {INPUT_JSONL}")
# else:
#     print(f"[DEBUG] Loaded {len(peek_items)} sample record(s) from {INPUT_JSONL}")
#     for i, rec in enumerate(peek_items):
#         text_val, text_key = extract_text_field(rec)
#         print(f"\n--- SAMPLE {i} ---")
#         print("[keys]:", list(rec.keys()))
#         print(" id:", rec.get("id"))
#         print(f" chosen_text_key: {text_key}")
#         preview = text_val[:200] + ("..." if len(text_val) > 200 else "")
#         print(" text preview:", preview)

# # --- 2. Load ontology once for prompt inspection ---
# with open(ONTOLOGY_JSON, "r", encoding="utf-8") as f:
#     ontology_data = json.load(f)

# concept_block_dbg  = render_concept_list(ontology_data)
# relation_block_dbg = render_relation_list(ontology_data)

# print("\n[DEBUG] ONTOLOGY CONCEPT LIST (truncated):")
# print(concept_block_dbg[:500] + ("..." if len(concept_block_dbg) > 500 else ""))

# print("\n[DEBUG] ONTOLOGY RELATION LIST (truncated):")
# print(relation_block_dbg[:500] + ("..." if len(relation_block_dbg) > 500 else ""))

# # --- 3. Show the exact SYSTEM and USER prompt that will go to the model for the FIRST sample ---
# if peek_items:
#     sample_text, _ = extract_text_field(peek_items[0])

#     system_prompt_dbg = build_prompt1_system()
#     user_prompt_dbg   = build_prompt1_user(
#         TEXT=sample_text,
#         ontology_json=ontology_data,
#         k=K_CANDIDATES,
#     )

#     print("\n================ [SYSTEM PROMPT] ================")
#     print(system_prompt_dbg)

#     print("\n================ [USER PROMPT - FIRST SAMPLE] ================")
#     # we don't print the entire ontology if it's massive, but we already showed truncated above
#     # still, show first ~12000 chars so you can visually inspect formatting
#     up_prev = user_prompt_dbg[:12000]
#     print(up_prev)
#     if len(user_prompt_dbg) > 12000:
#         print("... [USER PROMPT TRUNCATED FOR DISPLAY] ...")

# # --- 4. Spin up the model ---
# generator, tokenizer = setup_model()

# # --- 5. Dry-run: generate ONLY for the first record (no batch write yet),
# #         so we can inspect the raw model output and parsed JSON.
# if peek_items:
#     one_text, _ = extract_text_field(peek_items[0])
#     raw_single = generate_model_response(
#         generator=generator,
#         tokenizer=tokenizer,
#         system_msg=build_prompt1_system(),
#         user_msg=build_prompt1_user(
#             TEXT=one_text,
#             ontology_json=ontology_data,
#             k=K_CANDIDATES,
#         ),
#         max_new_tokens=MAX_NEW_TOKENS,
#         temperature=TEMPERATURE,
#     )

#     print("\n================ [RAW MODEL OUTPUT - FIRST SAMPLE] ================")
#     print(raw_single[:150000] + ("..." if len(raw_single) > 150000 else ""))

#     parsed_single = None
#     try:
#         parsed_single = json.loads(raw_single)
#     except Exception:
#         try:
#             start_i = raw_single.find("{")
#             end_i   = raw_single.rfind("}")
#             if start_i != -1 and end_i != -1 and end_i > start_i:
#                 candidate = raw_single[start_i:end_i+1]
#                 parsed_single = json.loads(candidate)
#         except Exception:
#             parsed_single = None

#     print("\n================ [PARSED MODEL JSON - FIRST SAMPLE] ================")
#     pprint.pprint(parsed_single, width=120)

#     # Sanity check for evaluator compatibility:
#     if parsed_single and "triples" in parsed_single:
#         print("\n[CHECK] triples[0] example for evaluator compatibility:")
#         if parsed_single["triples"]:
#             pprint.pprint(parsed_single["triples"][0], width=100)
#         else:
#             print("No triples returned.")
#     else:
#         print("[WARN] Model output did not parse into expected {'mentions':..., 'triples':...} shape.")

# # --- 6. Full mini-pipeline run (writes OUTPUT_JSONL) on first MAX_ITEMS records ---
# print("\n================ [BATCH PIPELINE RUN] ================")
# batch_outputs = run_prompt1_pipeline(
#     generator=generator,
#     tokenizer=tokenizer,
#     input_jsonl_path=INPUT_JSONL,
#     ontology_json_path=ONTOLOGY_JSON,
#     output_jsonl_path=OUTPUT_JSONL,
#     max_items=MAX_ITEMS,
#     max_new_tokens=MAX_NEW_TOKENS,
#     temperature=TEMPERATURE,
#     verbose=VERBOSE,
#     k=K_CANDIDATES,
# )

# print(f"\n[SUCCESS] Wrote {len(batch_outputs)} items to {OUTPUT_JSONL}")

# # --- 7. Peek at what was written (first 1-2 records) ---
# print("\n================ [WRITTEN OUTPUT PREVIEW] ================")
# for i, rec in enumerate(batch_outputs[:2]):
#     print(f"\n--- OUTPUT ITEM {i} ---")
#     # show keys and parsed json triples
#     print("[output keys]:", list(rec.keys()))
#     resp = rec.get("response", {})
#     parsed_json = resp.get("json")
#     if parsed_json:
#         print("[triples exists? ]", "triples" in parsed_json)
#         if "triples" in parsed_json and parsed_json["triples"]:
#             print("[first triple]:")
#             pprint.pprint(parsed_json["triples"][0], width=100)
#     else:
#         print("[WARN] No parsed JSON for this item.")


In [19]:
########################################
# WIKIDATA BATCH RUN
########################################

import os
import re

# ont_{index}_{category}_test.jsonl  ->
# {index}_{category}_ontology.json,
# ont_{index}_{category}_few_shot.jsonl,
# ont_{index}_{category}_output.jsonl
PATTERN = re.compile(r"^ont_(\d+)_([a-z]+)_test\.jsonl$")

def make_paths(filename: str, BASE_INPUT: str, BASE_ONTO: str, BASE_OUT: str):
    """
    Derive every path belonging to one test file.

    ``filename`` must look like ``ont_{index}_{category}_test.jsonl`` (as
    enforced by ``PATTERN``).

    Returns:
        ``(input_jsonl, ontology_json, output_jsonl, tag)`` where the ontology
        file is ``{index}_{category}_ontology.json`` under ``BASE_ONTO``, the
        output file is ``ont_{index}_{category}_output.jsonl`` under
        ``BASE_OUT`` and ``tag`` is ``ont_{index}_{category}``.

    Raises:
        ValueError: if ``filename`` does not match ``PATTERN``.
    """
    parsed = PATTERN.match(filename)
    if parsed is None:
        raise ValueError(f"Unexpected filename format: {filename}")
    number, category = parsed.group(1), parsed.group(2)

    return (
        os.path.join(BASE_INPUT, filename),
        os.path.join(BASE_ONTO, f"{number}_{category}_ontology.json"),
        # ont_{idx}_{cat}_test.jsonl -> ont_{idx}_{cat}_output.jsonl
        os.path.join(BASE_OUT, filename.replace("_test.jsonl", "_output.jsonl")),
        f"ont_{number}_{category}",
    )


def run_wikidata_batch():
    """
    Run Prompt 1 over the ten Wikidata test files.

    The model is loaded once and reused. For each file the ontology and output
    paths are derived from the ``ont_{idx}_{cat}_test.jsonl`` name. A failure
    in one file is printed and does not stop the remaining files.
    """
    BASE_INPUT = "/upb/users/b/balram/profiles/unix/cs/promptKG/data/input/wikidata/input_text/"
    BASE_ONTO  = "/upb/users/b/balram/profiles/unix/cs/promptKG/data/input/wikidata/input_ontology/"
    BASE_OUT   = "/upb/users/b/balram/profiles/unix/cs/promptKG/data/output/prompt1/wikidata/"

    FILENAMES = (
        "ont_1_movie_test.jsonl",
        "ont_2_music_test.jsonl",
        "ont_3_sport_test.jsonl",
        "ont_4_book_test.jsonl",
        "ont_5_military_test.jsonl",
        "ont_6_building_test.jsonl",
        "ont_7_tv_test.jsonl",
        "ont_8_politician_test.jsonl",
        "ont_9_organization_test.jsonl",
        "ont_10_airport_test.jsonl",
    )

    generator, tokenizer = setup_model()

    # e.g. "ont_1_movie_test.jsonl" -> idx="1", cat="movie"
    name_re = re.compile(r"ont_(\d+)_(.+?)_test\.jsonl$")
    separator = "\n" + "=" * 80

    for fname in FILENAMES:
        try:
            parsed = name_re.match(fname)
            if parsed is None:
                print(f"[SKIP] can't parse filename pattern: {fname}")
                continue
            idx, cat = parsed.groups()

            paths = {
                "input_jsonl_path": os.path.join(BASE_INPUT, fname),
                "ontology_json_path": os.path.join(BASE_ONTO, f"{idx}_{cat}_ontology.json"),
                "output_jsonl_path": os.path.join(
                    BASE_OUT, fname.replace("_test.jsonl", "_output.jsonl")
                ),
            }

            print(separator)
            print(f"[RUN] wikidata ont_{idx}_{cat}")

            run_prompt1_pipeline(
                generator=generator,
                tokenizer=tokenizer,
                max_items=None,
                max_new_tokens=900,
                temperature=0.25,
                verbose=False,
                k=6,
                **paths,
            )

            print(f"[DONE] wikidata ont_{idx}_{cat}")
        except Exception as exc:
            print(f"[ERROR] wikidata {fname}: {exc}")


In [21]:
# run_wikidata_batch()

In [13]:
import re

In [13]:
########################################
# DBPEDIA BATCH RUN
########################################


PATTERN = re.compile(r"^ont_(\d+)_([a-z]+)_test\.jsonl$")

def make_paths(filename: str, BASE_INPUT: str, BASE_ONTO: str, BASE_OUT: str):
    """
    Map a test filename to its input, ontology and output paths plus a tag.

    This repeats the ``make_paths`` definition of the Wikidata batch cell so
    that the DBpedia cell can be run on its own.

    Args:
        filename: Name of the form ``ont_{index}_{category}_test.jsonl``.
        BASE_INPUT: Directory holding the input text files.
        BASE_ONTO: Directory holding ``{index}_{category}_ontology.json``.
        BASE_OUT: Directory receiving ``ont_{index}_{category}_output.jsonl``.

    Returns:
        ``(input_jsonl, ontology_json, output_jsonl, tag)``.

    Raises:
        ValueError: if ``filename`` does not match ``PATTERN``.
    """
    hit = PATTERN.match(filename)
    if hit is None:
        raise ValueError(f"Unexpected filename format: {filename}")

    index_part, category_part = hit.groups()
    stem = f"{index_part}_{category_part}"

    source_path = os.path.join(BASE_INPUT, filename)
    ontology_path = os.path.join(BASE_ONTO, f"{stem}_ontology.json")
    # ont_{idx}_{cat}_test.jsonl -> ont_{idx}_{cat}_output.jsonl
    result_path = os.path.join(BASE_OUT, filename.replace("_test.jsonl", "_output.jsonl"))

    return source_path, ontology_path, result_path, f"ont_{stem}"



def run_dbpedia_batch():
    """
    Run Prompt 1 over the DBpedia test files listed in ``FILENAMES``.

    The model is loaded once and reused. For each file the ontology and output
    paths are derived from the ``ont_{idx}_{cat}_test.jsonl`` name. A failure
    in one file is printed and does not stop the remaining files. Entries that
    are commented out in ``FILENAMES`` are skipped.
    """
    BASE_INPUT = "/upb/users/b/balram/profiles/unix/cs/promptKG/data/input/dbpedia/input_text/"
    BASE_ONTO  = "/upb/users/b/balram/profiles/unix/cs/promptKG/data/input/dbpedia/input_ontology/"
    BASE_OUT   = "/upb/users/b/balram/profiles/unix/cs/promptKG/data/output/prompt1/dbpedia/"

    FILENAMES = [
        # "ont_12_monument_test.jsonl",
        # "ont_1_university_test.jsonl",
        # "ont_2_musicalwork_test.jsonl",
        # "ont_3_airport_test.jsonl",
        "ont_4_building_test.jsonl",
        "ont_5_athlete_test.jsonl",
        "ont_6_politician_test.jsonl",
        "ont_7_company_test.jsonl",
        "ont_8_celestialbody_test.jsonl",
        "ont_9_astronaut_test.jsonl",
        "ont_10_comicscharacter_test.jsonl",
        "ont_11_meanoftransportation_test.jsonl",
        "ont_13_food_test.jsonl",
        "ont_14_writtenwork_test.jsonl",
        "ont_15_sportsteam_test.jsonl",
        "ont_16_city_test.jsonl",
        "ont_17_artist_test.jsonl",
        "ont_18_scientist_test.jsonl",
        "ont_19_film_test.jsonl",
    ]

    generator, tokenizer = setup_model()

    for fname in FILENAMES:
        try:
            # derive ontology index/category from filename
            # e.g. "ont_4_building_test.jsonl" -> idx="4", cat="building"
            m = re.match(r"ont_(\d+)_(.+?)_test\.jsonl$", fname)
            if not m:
                print(f"[SKIP] can't parse filename pattern: {fname}")
                continue
            idx, cat = m.group(1), m.group(2)

            input_jsonl  = os.path.join(BASE_INPUT, fname)
            ontology_json = os.path.join(BASE_ONTO, f"{idx}_{cat}_ontology.json")
            output_jsonl = os.path.join(
                BASE_OUT,
                fname.replace("_test.jsonl", "_output.jsonl")
            )

            print("\n" + "=" * 80)
            print(f"[RUN] dbpedia ont_{idx}_{cat}")

            run_prompt1_pipeline(
                generator=generator,
                tokenizer=tokenizer,
                input_jsonl_path=input_jsonl,
                ontology_json_path=ontology_json,
                output_jsonl_path=output_jsonl,
                max_items=None,
                max_new_tokens=900,
                temperature=0.25,
                verbose=False,
                k=6,
            )

            # Log with the dataset this batch actually processes (the labels
            # previously said "wikidata", copied from the Wikidata runner).
            print(f"[DONE] dbpedia ont_{idx}_{cat}")
        except Exception as exc:
            print(f"[ERROR] dbpedia {fname}: {exc}")


In [None]:
# Start the DBpedia batch: loads the model once, then processes every file in its FILENAMES list.
run_dbpedia_batch()

[LOAD] model=mistralai/Mistral-7B-Instruct-v0.3


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Device set to use cuda:0



[RUN] dbpedia ont_4_building


You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset


[DONE] wikidata ont_4_building

[RUN] dbpedia ont_5_athlete
[DONE] wikidata ont_5_athlete

[RUN] dbpedia ont_6_politician
[DONE] wikidata ont_6_politician

[RUN] dbpedia ont_7_company
[DONE] wikidata ont_7_company

[RUN] dbpedia ont_8_celestialbody
[DONE] wikidata ont_8_celestialbody

[RUN] dbpedia ont_9_astronaut
[DONE] wikidata ont_9_astronaut

[RUN] dbpedia ont_10_comicscharacter
[DONE] wikidata ont_10_comicscharacter

[RUN] dbpedia ont_11_meanoftransportation
[DONE] wikidata ont_11_meanoftransportation

[RUN] dbpedia ont_13_food
[DONE] wikidata ont_13_food

[RUN] dbpedia ont_14_writtenwork
[DONE] wikidata ont_14_writtenwork

[RUN] dbpedia ont_15_sportsteam
[DONE] wikidata ont_15_sportsteam

[RUN] dbpedia ont_16_city
