# In [None]:
import pandas as pd
from pathlib import Path
import time

from tqdm import tqdm
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import json
import re
from typing import Optional, Dict
# ---------------- user settings ----------------
# Placeholder paths: each (input, output) pair is handled in turn by main().
INPUT_CSV_1 = "/inputPath/FileName"
OUTPUT_CSV_1 = "/outputPath/FileName"
INPUT_CSV_2 = "/inputPath/FileName"
OUTPUT_CSV_2 = "/outputPath/FileName"

# Number of input rows pandas yields per chunk.
CHUNK_SIZE = 1_000
# Buffered result rows are appended to the output CSV whenever this many
# have accumulated (plus a final flush at the end).
SAVE_EVERY_ROWS = 200

# Names of the two columns added to every output row.
OUTPUT_COL = "author_category"      # human-readable category label
OUTPUT_REL = "author_relationship"  # author's relationship to the patient
# -----------------------------------------------

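# Prompt building and LLM inference are implemented in this file
# (generate_prompt / infer_categorization); an earlier version imported them:
# from my_pipeline import build_prompt, get_llm_output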

def infer_categorization(
    model,
    tokenizer,
    prompt: str,
    max_new_tokens=5,
    *,
    do_sample: bool = True,
    temperature: float = 0.7,
    top_p: float = 0.95,
    use_chat_template: bool = False,
) -> str:
    """
    Generate a continuation of *prompt* and return only the newly generated text.

    Args:
        model: model exposing ``.device`` and a Hugging Face-style ``.generate``.
        tokenizer: matching tokenizer; used to encode/decode and to supply the
            pad/eos token ids passed to ``generate``.
        prompt: full prompt text.
        max_new_tokens: upper bound on generated tokens (at least one token is
            always generated via ``min_new_tokens=1``).
        do_sample, temperature, top_p: decoding controls. The defaults keep the
            original behaviour (nucleus sampling, T=0.7, p=0.95). Pass
            ``do_sample=False`` for deterministic greedy decoding, which makes
            classification runs reproducible. Temperature/top_p are only sent
            to ``generate`` when sampling.
        use_chat_template: if True, wrap *prompt* as a single user turn via
            ``tokenizer.apply_chat_template(..., add_generation_prompt=True)``.
            Instruct checkpoints are normally prompted this way; off by default
            to preserve existing outputs.

    Returns:
        The decoded continuation with special tokens skipped; "" when the
        model produced no tokens beyond the prompt.
    """
    if use_chat_template:
        prompt = tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}],
            tokenize=False,
            add_generation_prompt=True,
        )

    inputs = tokenizer(prompt, padding=False, return_tensors="pt").to(model.device)

    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "min_new_tokens": 1,  # prevents an immediate empty decode
        "do_sample": do_sample,
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
        "return_dict_in_generate": True,
    }
    if do_sample:
        # Only meaningful when sampling; passing them with greedy decoding
        # just triggers warnings.
        gen_kwargs["temperature"] = temperature
        gen_kwargs["top_p"] = top_p

    outputs = model.generate(**inputs, **gen_kwargs)

    # Decoder-only generate() returns prompt tokens followed by new tokens, so
    # always drop exactly the prompt length. (The old fallback to 0 would have
    # decoded the whole prompt when nothing new was generated.)
    prompt_len = inputs["input_ids"].shape[1]
    new_ids = outputs.sequences[0][prompt_len:]
    return tokenizer.decode(new_ids, skip_special_tokens=True)

def load_model_tokenizer(model_checkpoint, tuned=False, *, torch_dtype=torch.float16, device_map=None):
    """
    Load a causal LM and its tokenizer from *model_checkpoint*.

    Args:
        model_checkpoint: Hugging Face hub id or local path.
        tuned: reserved for loading a fine-tuned checkpoint; not implemented.
        torch_dtype: weight dtype (default float16, as before).
        device_map: passed to ``from_pretrained``; defaults to ``{"": 0}``,
            i.e. the whole model on GPU 0.

    Returns:
        ``(model, tokenizer)``. The tokenizer pads on the left and uses the
        EOS token as its pad token.

    Raises:
        NotImplementedError: if ``tuned`` is true. Previously this silently
            returned None, which only failed later at tuple unpacking.
    """
    if tuned:
        raise NotImplementedError(
            "load_model_tokenizer(tuned=True) is not implemented yet"
        )
    if device_map is None:
        device_map = {"": 0}

    model = AutoModelForCausalLM.from_pretrained(
        model_checkpoint,
        torch_dtype=torch_dtype,
        device_map=device_map,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, padding_side="left")
    tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer

# Load the model and tokenizer once, as module-level globals; main() passes
# these globals to infer_categorization for every row.
# NOTE(review): this runs as soon as the cell/module executes and places the
# 14B checkpoint in float16 on GPU 0 (per load_model_tokenizer's device_map).
model, tokenizer = load_model_tokenizer(model_checkpoint="Qwen/Qwen2.5-14B-Instruct")


def generate_prompt(post_text):
    """
    Build the annotation prompt for a single post.

    The prompt asks the model to choose exactly one author-perspective
    category (A-E) and a relationship string, answering with a JSON object
    that has the keys "category" and "relationship". *post_text* is embedded
    verbatim between triple single quotes at the end of the prompt.
    """
    # Doubled braces are f-string escapes for the literal JSON-schema braces.
    return f"""
You are an expert annotator of posts from health subreddits (e.g., r/Alzheimers, r/dementia).
Classify the author's perspective into EXACTLY ONE category using ONLY the post content—no outside knowledge or assumptions.

CATEGORIES
A = "Caregiver about patient" — Caregiver mainly describes the patient’s symptoms/behaviors/care decisions.
B = "Caregiver about themselves" — Caregiver mainly describes their own emotions/burden/logistics (not the patient’s medical details).
C = "Patient about themselves" — Author states or clearly implies they are diagnosed and describes their own condition.
D = "Non-diagnosed patient about themselves" — Author worries they might have the condition, describes own symptoms, but NO diagnosis mentioned.
E = "Other" — Not classifiable into A–D, off-topic, or insufficient evidence.

DECISION RULES
1) If the author explicitly states a diagnosis for themselves (e.g., “I was diagnosed…”, “My neurologist confirmed…”), choose C.
2) Else if the author focuses on their own symptoms/fears without a diagnosis, choose D.
3) Else if the author is a caregiver and the primary focus is the patient’s condition/care, choose A.
4) Else if the author is a caregiver and the primary focus is the caregiver’s feelings/stress/logistics, choose B.
5) If multiple apply, choose the ONE category that covers >50% of the post by content/focus. If still ambiguous, choose E.
6) Ignore quoted/reposted material unless the author clearly endorses it as their own experience.

RELATIONSHIP EXTRACTION
Additionally, identify the author's relationship to the person with Alzheimer’s or dementia (if any). Use the most specific term evident from the post (e.g., "spouse", "child", "parent", "friend", "professional caregiver"). If the post is about the author themselves (C or D), use "self". If no relationship is mentioned or discernible, use "unspecified".

OUTPUT (STRICT JSON ONLY — no extra text):
{{
  "category": "A" | "B" | "C" | "D" | "E",
  "relationship": "<string>"
}}

Post:
'''{post_text}'''
"""
# PROMPT_STRICT = (
#     PROMPT_ORIGINAL
# )


# Letter code emitted by the model -> human-readable label written to the CSV.
CAT_MAP = {
    "A": "Caregiver about patient",
    "B": "Caregiver about themselves",
    "C": "Patient about themselves",
    "D": "Non-diagnosed patient about themselves",
    "E": "Other",
}
# Letter codes accepted from the model; kept in lockstep with CAT_MAP's keys.
VALID_CATS = set(CAT_MAP)

def _first_fenced_block(text: str) -> Optional[str]:
    """
    Extract the body of the first Markdown code fence in *text*.

    Recognises both ```json ... ``` and bare ``` ... ``` fences (the ``json``
    tag is matched case-insensitively). Whitespace around the body is
    stripped. Returns None when *text* holds no complete fence.
    """
    fence = re.compile(r"```(?:json)?\s*(.*?)```", re.DOTALL | re.IGNORECASE)
    found = fence.search(text)
    if found is None:
        return None
    body = found.group(1)
    return body.strip()

def _first_json_object(text: str) -> Optional[str]:
    """
    Best-effort scan to extract the first balanced top-level ``{...}`` block.

    Braces are counted only outside double-quoted string literals, and
    backslash escapes inside strings are honoured. As a result, values such
    as ``"relationship": "}"`` no longer end the block early.

    If the block starting at the first ``{`` never closes, scanning restarts
    at the next ``{``.

    Returns:
        The substring including both braces, or None if no balanced block exists.
    """
    start = text.find("{")
    while start != -1:
        depth = 0
        in_string = False
        escaped = False
        for i in range(start, len(text)):
            ch = text[i]
            if in_string:
                # Inside a string literal only an unescaped quote matters.
                if escaped:
                    escaped = False
                elif ch == "\\":
                    escaped = True
                elif ch == '"':
                    in_string = False
                continue
            if ch == '"':
                in_string = True
            elif ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
                if depth == 0:
                    return text[start:i + 1]
        # No balanced close from this '{'; try the next one.
        start = text.find("{", start + 1)
    return None

def parse_category_relationship(llm_out: str) -> Optional[Dict[str, str]]:
    """
    Parse raw model output into ``{"category": <label>, "relationship": <str>}``.

    ``category`` is the human-readable label from ``CAT_MAP`` (e.g.
    ``"Caregiver about patient"``), not the letter code. The letter is matched
    after stripping whitespace and upper-casing. A missing or unrecognised
    letter maps to ``CAT_MAP["E"]`` ("Other"). A missing, non-string or blank
    relationship becomes ``"unspecified"``.

    JSON candidates are tried in order:
      1. the body of the first Markdown code fence;
      2. the first balanced ``{...}`` block in the raw text.
    The first candidate that decodes to a JSON object wins. This way a fence
    containing prose plus JSON still parses.

    Returns:
        The two-key dict, or None when nothing decodes to a JSON object
        (including non-string input). Never raises for malformed output.
    """
    if not isinstance(llm_out, str):
        return None

    candidates = []
    fenced = _first_fenced_block(llm_out)
    if fenced:
        candidates.append(fenced)
    loose = _first_json_object(llm_out)
    if loose and loose not in candidates:
        candidates.append(loose)

    obj = None
    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except (ValueError, RecursionError):  # JSONDecodeError is a ValueError
            continue
        if isinstance(parsed, dict):
            obj = parsed
            break
    if obj is None:
        return None

    cat = obj.get("category")
    cat = cat.strip().upper() if isinstance(cat, str) else ""
    if cat not in VALID_CATS:
        # Keep the *letter* here. The old code stored the label "Other" and
        # then looked it up in CAT_MAP again, producing "" instead of "Other".
        cat = "E"

    rel = obj.get("relationship")
    rel = rel.strip() if isinstance(rel, str) else ""
    if not rel:
        rel = "unspecified"  # normalize empty/missing to "unspecified"

    # Return exactly the two required keys.
    return {"category": CAT_MAP[cat], "relationship": rel}

def load_processed_ids(out_path: Path, id_col: str) -> set:
    """
    Return the IDs (as ``str``) already written to the output CSV, for resuming.

    Only *id_col* is read, as pandas' string dtype, in chunks of 200,000 rows
    to keep memory use low. Missing IDs are ignored.

    A missing file, a zero-byte file, or a file pandas reports as empty
    (``EmptyDataError``) yields an empty set instead of raising.
    """
    out_path = Path(out_path)
    processed: set = set()
    if not out_path.exists() or out_path.stat().st_size == 0:
        return processed

    try:
        with pd.read_csv(
            out_path,
            usecols=[id_col],
            dtype={id_col: "string"},
            chunksize=200_000,
        ) as reader:
            for chunk in reader:
                # Values of a "string"-dtype column iterate as Python str.
                processed.update(chunk[id_col].dropna())
    except pd.errors.EmptyDataError:
        return set()
    return processed

def append_buffer_to_csv(buffer_rows, out_path: Path):
    """
    Append *buffer_rows* (a list of dicts, one per row) to the CSV at *out_path*.

    The header is written only when the file does not exist yet or is empty
    (zero bytes), so repeated calls build a single CSV with one header. On a
    successful write the list is emptied in place. If the write raises, the
    rows stay buffered. An empty buffer is a no-op.
    """
    if not buffer_rows:
        return
    out_path = Path(out_path)
    # An existing-but-empty file still needs a header, otherwise the first
    # appended rows would be written without one.
    write_header = not out_path.exists() or out_path.stat().st_size == 0
    pd.DataFrame(buffer_rows).to_csv(out_path, mode="a", header=write_header, index=False)
    buffer_rows.clear()

def _classify_post(selftext, max_new_tokens):
    """Run the LLM on one post's text and return ``(category_label, relationship)``."""
    prompt = generate_prompt(post_text=selftext)
    llm_out = infer_categorization(model, tokenizer, prompt, max_new_tokens=max_new_tokens)
    parsed = parse_category_relationship(llm_out)
    if parsed is None:
        # Unparseable output: fall back to "Other" / "unspecified".
        return CAT_MAP.get("E"), "unspecified"
    return parsed["category"], parsed["relationship"]


def _process_file(input_csv, out_path: Path, max_new_tokens: int) -> None:
    """Classify every not-yet-processed row of *input_csv*, appending results to *out_path*."""
    processed_so_far = load_processed_ids(out_path, id_col="id")
    print(f"[resume] {out_path}: already processed rows: {len(processed_so_far)}")

    buffer_rows = []
    try:
        # Stream the input in chunks so the whole file is never held in memory.
        for chunk in tqdm(pd.read_csv(input_csv, chunksize=CHUNK_SIZE), desc="Processing Chunks"):
            # to_dict("records") keeps the original column names, unlike
            # itertuples(), which renames non-identifier columns (e.g. "_0").
            for rec in tqdm(chunk.to_dict("records"), desc="Processing Rows", leave=False):
                # Resumed IDs are loaded as strings, so compare as strings;
                # numeric-looking IDs would otherwise never match.
                if str(rec.get("id", "")) in processed_so_far:
                    continue

                selftext = rec.get("selftext", "")
                if pd.isna(selftext):
                    selftext = ""

                # Copy the entire row plus the two new columns.
                rec[OUTPUT_COL], rec[OUTPUT_REL] = _classify_post(selftext, max_new_tokens)
                buffer_rows.append(rec)

                # Periodic, row-count-based save.
                if len(buffer_rows) >= SAVE_EVERY_ROWS:
                    append_buffer_to_csv(buffer_rows, out_path)
    finally:
        # Persist whatever is buffered, even if interrupted or an error occurs.
        append_buffer_to_csv(buffer_rows, out_path)


def main(max_new_tokens: int = 64):
    """
    Classify both configured input CSVs, appending results to their output CSVs.

    Rows whose ``id`` already appears in the corresponding output file are
    skipped, so an interrupted run can be resumed.

    ``max_new_tokens`` defaults to 64. The earlier value of 18 could cut off
    the requested JSON answer (e.g. when the model wraps it in a code fence),
    which forced those rows to "Other". Generation normally stops at EOS, so
    the larger bound costs little.
    """
    jobs = [(INPUT_CSV_1, OUTPUT_CSV_1), (INPUT_CSV_2, OUTPUT_CSV_2)]
    try:
        for input_csv, output_csv in jobs:
            _process_file(input_csv, Path(output_csv), max_new_tokens)
    finally:
        print("[done] Saved all buffered rows.")

if __name__ == "__main__":
    main()