In [7]:
import os
import json
import random
from dotenv import load_dotenv
from pydantic import BaseModel, ValidationError
from typing import List
from openai import OpenAI
import re
# import torch

# For Hugging Face usage (e.g., Llama 3)
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM

In [8]:
##################################################
# 1. Pydantic model for the paper review
##################################################

class PaperReview(BaseModel):
    """
    Schema for one structured paper review produced by a model.

    ``parse_paper_review`` decodes the model's JSON output and validates it
    against these fields. The JSON keys are expected to use these exact
    capitalized names (e.g. "Soundness").
    """

    # Numeric scores. No bounds are enforced here.
    # NOTE(review): presumably these follow the venue's review-form scales —
    # confirm before adding range constraints.
    Soundness: int
    Presentation: int
    Contribution: int
    Rating: int
    Confidence: int
    # Free-text sections of the review.
    Strengths: str
    Weaknesses: str
    Questions: str

In [9]:
##################################################
# 2. Helper functions
##################################################

def load_jsonl(file_path: str) -> List[dict]:
    """
    Read a JSON Lines file into a list of records.

    Every line is stripped of surrounding whitespace; empty lines are skipped
    and the rest are decoded individually with ``json.loads``.

    Args:
        file_path: path to a UTF-8 encoded ``.jsonl`` file.

    Returns:
        The decoded records, in file order.
    """
    with open(file_path, "r", encoding="utf-8") as handle:
        stripped_lines = (raw.strip() for raw in handle)
        return [json.loads(text) for text in stripped_lines if text]

def save_jsonl(data: List[dict], file_path: str) -> None:
    """
    Write records to a JSON Lines file, one JSON object per line.

    Missing parent directories are created. An existing file is overwritten.
    Non-ASCII characters are written as-is (``ensure_ascii=False``).

    Args:
        data: records to serialize; each must be JSON-serializable.
        file_path: destination path; may be a bare filename.
    """
    directory = os.path.dirname(file_path)
    # dirname() is "" for a bare filename, and os.makedirs("") raises
    # FileNotFoundError, so only create directories when there is one.
    if directory:
        os.makedirs(directory, exist_ok=True)
    with open(file_path, "w", encoding="utf-8") as f:
        for d in data:
            f.write(json.dumps(d, ensure_ascii=False) + "\n")

def extract_json_string(raw_content: str) -> str:
    """
    Pull the JSON payload out of a model's raw text response.

    Strategy, in order:
      1. If the text contains a fenced code block (```...``` with an optional,
         case-insensitive ``json`` tag), return its stripped contents.
      2. Otherwise, if the text contains a ``{`` followed later by a ``}``,
         return the span from the first ``{`` to the last ``}``. This handles
         models that wrap the object in prose or truncate the closing fence.
      3. Otherwise, return the stripped text unchanged.

    A response that is already a bare JSON object is returned unchanged
    (apart from stripping).

    Args:
        raw_content: the model's text output.

    Returns:
        The candidate JSON string (not yet validated).
    """
    match = re.search(
        r"```(?:json)?(.*?)```", raw_content, flags=re.DOTALL | re.IGNORECASE
    )
    if match:
        return match.group(1).strip()

    text = raw_content.strip()
    start, end = text.find("{"), text.rfind("}")
    if start != -1 and end > start:
        return text[start:end + 1]
    return text

def parse_paper_review(response_content: str) -> "PaperReview":
    """
    Decode a model response as JSON and validate it as a ``PaperReview``.

    Args:
        response_content: a JSON string, typically the output of
            ``extract_json_string``.

    Returns:
        The validated ``PaperReview``.

    Raises:
        ValueError: if the text is not valid JSON, if it decodes to something
            other than a JSON object, or if the object fails pydantic
            validation. The message includes the offending content.
    """
    try:
        parsed = json.loads(response_content)
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON returned by the model:\n{e}\nContent:\n{response_content}") from e

    # PaperReview(**parsed) needs a mapping; a list/str/number would raise a
    # bare TypeError that escapes the ValueError contract above.
    if not isinstance(parsed, dict):
        raise ValueError(
            f"Expected a JSON object for PaperReview, got {type(parsed).__name__}:\n"
            f"Content:\n{response_content}"
        )

    try:
        return PaperReview(**parsed)
    except ValidationError as ve:
        raise ValueError(f"Pydantic validation error:\n{ve}\nContent:\n{response_content}") from ve



In [10]:
# Load variables from a local .env file (if present) so OPENAI_API_KEY can be
# supplied there instead of being exported in the shell.
load_dotenv()
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

def run_gpt4o_inference(messages: List[dict], model: str = "gpt-4o") -> str:
    """
    Send a chat request to the OpenAI API and return the raw text of the top choice.

    Parsing and validation are deliberately left to the caller
    (``evaluate_model_on_files``), which records any exception per entry.
    This function therefore does not catch errors itself.

    Args:
        messages: chat messages in OpenAI format ({"role": ..., "content": ...}).
        model: model name passed to the API; defaults to "gpt-4o".

    Returns:
        The text content of the first choice.

    Raises:
        ValueError: if the first choice carries no text content.
        Any exception raised by the OpenAI client (auth, network, rate limit).
    """
    completion = client.chat.completions.create(model=model, messages=messages)
    content = completion.choices[0].message.content
    if content is None:
        raise ValueError(f"{model} returned no text content for this prompt")
    return content

    

In [11]:
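# # Llama 3.1 8B usage from Hugging Face
# # (requires uncommenting `import torch` in the imports cell; it is used for the device check below)
# # For a quick pipeline-based approach: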
# llama_pipeline = pipeline(
#     "text-generation",
#     model="meta-llama/Llama-3.1-8B-Instruct",  # or your actual Llama 3 model path
#     trust_remote_code=True,
#     device=0 if torch.cuda.is_available() else -1  # for GPU if available
# )

# def run_llama3_inference(messages: List[dict]) -> str:
#     """
#     A very naive approach for Llama: we just concatenate the messages 
#     into a single prompt, because the huggingface pipeline doesn't 
#     natively handle the 'role-based' chat format by default.

#     For a more sophisticated approach, you'd implement the chat logic 
#     or use a library that supports it. This is just a demonstration.
#     """
#     # Simple approach: system + user => combine them as prompt
#     # (If you have multiple system/user pairs, you can decide how to handle them)
#     combined_prompt = ""
#     for m in messages:
#         role = m["role"]
#         content = m["content"]
#         if role == "system":
#             combined_prompt += f"[SYSTEM]\n{content}\n"
#         elif role == "user":
#             combined_prompt += f"[USER]\n{content}\n"
#         else:
#             combined_prompt += f"[{role.upper()}]\n{content}\n"

#     # Generate with pipeline
#     outputs = llama_pipeline(combined_prompt, max_new_tokens=1000)
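#     # The pipeline returns a list of dicts with "generated_text".
#     # NOTE: by default "generated_text" also contains the prompt itself; pass
#     # return_full_text=False so JSON-like text in the prompt is not mistaken
#     # for the model's answer by extract_json_string.
#     # We'll just return the first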
#     text_out = outputs[0]["generated_text"]
#     return text_out

In [12]:
# ---------------------------------------------------------------------
# 4. Main Evaluate Function
# ---------------------------------------------------------------------

def _evaluate_entry(entry: dict, model_inference_func) -> dict:
    """Run one prompt entry through the model and return its result record."""
    paper_id = entry.get("paper_id", "unknown_id")
    messages = entry.get("messages", [])
    raw_content = None
    try:
        raw_content = model_inference_func(messages)
        review = parse_paper_review(extract_json_string(raw_content))
    except Exception as e:
        # Broad catch is deliberate: one bad response (API error, invalid JSON,
        # schema mismatch) must not abort the whole batch.
        record = {"paper_id": paper_id, "error": str(e)}
        # Keep whatever the model did return so parse failures can be inspected.
        if raw_content is not None:
            record["raw_response"] = raw_content
        return record

    # Pydantic v2 renamed .dict() to .model_dump(); support both versions.
    dump = getattr(review, "model_dump", None) or review.dict
    return {
        "paper_id": paper_id,
        "raw_response": raw_content,
        "review": dump(),
    }


def evaluate_model_on_files(
    model_name: str,
    model_inference_func,
    input_files: List[str],
    output_dir: str
):
    """
    Run a model over every prompt in a set of JSONL files and save the results.

    For each file in ``input_files``:
      - load its entries with ``load_jsonl``;
      - for each entry, call ``model_inference_func(entry["messages"])``;
      - extract and validate the response as a ``PaperReview``;
      - write one record per entry to
        ``<output_dir>/<input stem>_<model_name>_results.jsonl``.

    Each record has ``paper_id`` plus either ``raw_response`` and ``review``
    (on success) or ``error`` (on failure; ``raw_response`` is also kept when
    the model returned text that failed to parse).

    Args:
        model_name: label used in the output filename and progress message.
        model_inference_func: callable taking a list of chat messages and
            returning the model's raw text response.
        input_files: paths to JSONL prompt files.
        output_dir: directory for result files (created if missing).
    """
    for file_path in input_files:
        entries = load_jsonl(file_path)
        results = [_evaluate_entry(entry, model_inference_func) for entry in entries]

        # splitext (rather than str.replace) only strips the final extension,
        # so the output name always differs from the input name.
        stem, _ext = os.path.splitext(os.path.basename(file_path))
        out_path = os.path.join(output_dir, f"{stem}_{model_name}_results.jsonl")
        save_jsonl(results, out_path)
        print(f"Finished {model_name} on {file_path}, saved => {out_path}")



In [None]:
# 1) Build the prompt-file lists. Each prompting setting (zero/one/few-shot)
#    has one JSONL file per (year, input type) pair, ordered year-major.
#    Note that the "full text" input type has a space in its filename.
zero_shot_files = [
    f"data/test_data/zero_shot/test_data_{year}_{input_type}_prompts.jsonl"
    for year in ("2024", "2025")
    for input_type in ("abstract", "full text", "summary")
]
one_shot_files = [
    f"data/test_data/one_shot/test_data_{year}_{input_type}_prompts_oneshot.jsonl"
    for year in ("2024", "2025")
    for input_type in ("abstract", "full text", "summary")
]
few_shot_files = [
    f"data/test_data/few_shot/test_data_{year}_{input_type}_prompts_fewshot.jsonl"
    for year in ("2024", "2025")
    for input_type in ("abstract", "full text", "summary")
]

# Every prompt file across all three settings, zero -> one -> few-shot.
all_prompt_files = [*zero_shot_files, *one_shot_files, *few_shot_files]

# 2) Evaluate with GPT-4o; result files are written under results/gpt4o/.
print("=== Running GPT-4o Inference ===")
evaluate_model_on_files(
    model_name="gpt4o",
    model_inference_func=run_gpt4o_inference,
    input_files=all_prompt_files,
    output_dir="results/gpt4o",
)

# # 3) Evaluate with Llama-3 8B (requires the llama_pipeline cell above)
# print("\n=== Running Llama-3 8B Inference ===")
# evaluate_model_on_files(
#     model_name="llama3",
#     model_inference_func=run_llama3_inference,
#     input_files=all_prompt_files,
#     output_dir="results/llama3"
# )

# print("Done! All results saved in `results/` subfolders.")