In [1]:
import json
import re
import time
from statistics import mean, stdev
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm

In [2]:
!pip install langchain langchain-community

Defaulting to user installation because normal site-packages is not writeable



[notice] A new release of pip is available: 25.2 -> 25.3
[notice] To update, run: python.exe -m pip install --upgrade pip


In [3]:
from langchain_community.chat_models import ChatOllama
from langchain.schema import SystemMessage, HumanMessage

In [None]:
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_community.chat_models import ChatOllama

STUDENT_MODEL = "llama3.1:8b"
TEACHER_MODEL = "llama3.1:70b"
SCORE_MAX = 3


def call_llm_LangChain_correct_v2(prompt, temperature=0.0):
    llm = ChatOllama(
        model="llama3.1:8b",
        temperature=temperature
    )

    system_prompt = """

You are a large language model acting as a judge for assessing the performance of a Teaching Assistant (TA) in an introductory Python programming course.

The TA is an LLM that answers student questions about Python code. Your job is to evaluate the quality of the TA's answer.

You will receive:
- A Python code snippet
- A student question about that code
- A reference (correct) answer
- A TA LLM-generated answer (called the prediction)

Your task is to evaluate how well the TA's prediction answers the student's question, using the following four dimensions. For each, provide:
- An integer score from 1 to 3



### Accuracy
Compare the prediction with the reference to assess factual correctness and understanding of the code’s behavior and intent.
You must judge whether the prediction reflects accurate behavior and matches core facts from the reference. 
You need to consider semantic meaning of code comprehension: understanding the structure, functionality, and intent behind the code.\n"

Score meanings:
- 1: Completely incorrect or irrelevant; does not address the reference answer.
- 2: Partially correct; some key facts are accurate, but major details are wrong or missing.
- 3: Fully correct; matches the reference answer in meaning and factual content.

### Completeness
Check if the prediction covers all important parts of the reference answer, including key concepts or conditions.

Score meanings:
- 1: Omits most key information or contains only a tiny fragment of relevant content.
- 2: Covers some elements but misses important parts.
- 3: Fully covers all essential information from the reference.

### Relevance
Assess whether the prediction directly addresses the question and stays on-topic.

Score meanings:
- 1: Completely irrelevant or mostly unrelated.
- 2: Partially related but misses the main point.
- 3: Fully focused and directly answers the question.

### Clarity
Evaluate how clearly and logically the prediction is expressed, ensuring it is easy to understand.

Score meanings:
- 1: Confusing, vague, or incoherent.
- 2: Understandable but awkwardly phrased or slightly unclear.
- 3: Clear, concise, and easy to follow.


Example:

Code:
```python
def count_even(nums):
    total = 0
    for x in nums:
        if x % 2 == 0:
            total += 1
    return total
Question: What does this function return when given a list of integers?
Reference Answer: It returns the count of even numbers in the list.
Prediction: It returns the count of odd numbers in the list.

Evaluation Output:
{

"accuracy": { "score": 1 },
"completeness": { "score": 1 },
"relevance": { "score": 2 },
"clarity": { "score": 3 }

}

Final Instructions:
For the given input (code, question, reference answer, and prediction), evaluate the prediction on the four metrics defined above.
Base your evaluation strictly on the content provided. Do not hallucinate missing information. Be consistent and objective.
Do not include reasoning or explanations.

Respond only with a JSON object in the exact format:
{
"accuracy": { "score": 1-3},
"completeness": {"score": 1-3},
"relevance": {"score": 1-3},
"clarity": {"score": 1-3}
}
"""



    messages = [
        SystemMessage(content=system_prompt),
        HumanMessage(content=prompt)
    ]

    response = llm.invoke(messages)
    return response.content

In [5]:
def call_teacher_single_turn(code, question, reference, prediction, scores):
    # Build teacher system prompt
    system_prompt = """
You are a 70B teacher LLM model reviewing a student LLM-as-judge's evaluation of a Teaching Assistant's (TA) answer.

You will receive:
- Python code snippet
- Student's question
- Reference (correct) answer
- TA's predicted answer
- Scores assigned by the student LLM-as-judge

Your task:
- Examine the TA's predicted answer in context of the code, question, and reference.
- For any dimension (Accuracy, Completeness, Relevance, Clarity) where the score is less than 3,
  provide clear, concise feedback (2–4 sentences) explaining what could be improved.
- If a dimension has no issues, do not include it in your response.

Respond ONLY with a JSON object where keys are the dimension names (lowercase)
and values are the feedback strings.

Example output:
{
  "accuracy": "The prediction misrepresents the function’s return value.",
  "clarity": "The explanation lacks structure and is hard to follow."
}

Rubric:

### Accuracy
- 1: Completely incorrect or irrelevant.
- 2: Partially correct but with major mistakes or omissions.
- 3: Fully correct and matches the reference.

### Completeness
- 1: Omits most key information.
- 2: Covers some but misses important parts.
- 3: Fully covers all essential information.

### Relevance
- 1: Irrelevant or mostly unrelated.
- 2: Partially related but misses main point.
- 3: Fully focused and directly addresses the question.

### Clarity
- 1: Confusing, vague, or incoherent.
- 2: Understandable but awkwardly phrased or unclear.
- 3: Clear, concise, and easy to understand.
"""

    # Format input for the teacher
    user_prompt = f"""
```python
{code}
Question:
{question}

Reference Answer:
{reference}

TA's Predicted Answer:
{prediction}
"""

    try:
        llm = ChatOllama(model=TEACHER_MODEL, temperature=0.0)
        messages = [
            SystemMessage(content=system_prompt.strip()),
            HumanMessage(content=user_prompt.strip())
        ]
        response = llm.invoke(messages)
        critiques = json.loads(response.content.strip())
        return critiques
    except Exception as e:
        print(f"Error: {str(e)}")
        return {}

In [6]:
def student_reflect_and_revise(code, question, reference, prediction, old_score, critiques):
    print("\n Student reflecting on teacher feedback...\n")

    critique_text = "\n".join(
        f"{dim.upper()} Feedback: {critique.strip()}" for dim, critique in critiques.items()
    )

    prompt = f"""
Code:
```python
{code}
Question:
{question}

Reference Answer:
{reference}

TA's Predicted Answer:
{prediction}

Teacher's Feedback:
{critique_text}

Old prediction:
{old_score}

You have to update the old prediction by considering Teacher's Feedback. Please re-evaluate the TA's answer using:

Accuracy

Completeness

Relevance

Clarity

Only return JSON:
{{
"accuracy": {{ "score": 1-3 }},
"completeness": {{ "score": 1-3 }},
"relevance": {{ "score": 1-3 }},
"clarity": {{ "score": 1-3 }}
}}
"""

    try:
        response = call_llm_LangChain_correct_v2(prompt, temperature=0.0)
        revised_scores = json.loads(response)
        print(" Revised Scores:\n", revised_scores)
    except Exception as e:
        print(" Failed to parse revised scores:", str(e))
        revised_scores = None

    return revised_scores

## Load dataset

In [7]:
import json

# TODO: set your full dataset path (the complete JSON list)
FULL_DATA_JSON = "Mistral_CodeQA_llm_judge.json"

with open(FULL_DATA_JSON, "r", encoding="utf-8") as f:
    examples = json.load(f)

N_total = len(examples)
print("Total examples (N):", N_total)
print("Example keys:", list(examples[0].keys()) if examples else "N/A")

Total examples (N): 56085
Example keys: ['id', 'code', 'question', 'reference', 'prediction', 'accuracy', 'completeness', 'relevance', 'clarity']


In [8]:
import re, json

def _extract_json_object(text: str):
    # Try direct parse
    try:
        return json.loads(text)
    except Exception:
        pass
    # Try to find first JSON object
    m = re.search(r"\{.*\}", text, flags=re.DOTALL)
    if m:
        try:
            return json.loads(m.group(0))
        except Exception:
            return None
    return None

def flatten_scores(score_obj):
    """Converts either nested JSON {'accuracy':{'score':2}} or flat {'accuracy':2} into flat ints."""
    out = {}
    for k in ["accuracy","completeness","relevance","clarity"]:
        v = score_obj.get(k, None) if isinstance(score_obj, dict) else None
        if isinstance(v, dict) and "score" in v:
            out[k] = v.get("score")
        else:
            out[k] = v
    return out

def judge_scores_flat(code, question, reference, prediction):
    prompt = f"""
Code:
```python
{code}
```
Question:
{question}

Reference Answer:
{reference}

TA's Predicted Answer:
{prediction}
"""

    raw = call_llm_LangChain_correct_v2(prompt, temperature=0.0)
    obj = _extract_json_object(raw)
    if obj is None:
        raise ValueError(f"Judge returned non-JSON. Raw:\n{raw}")
    return flatten_scores(obj)

In [9]:
from tqdm import tqdm

judge_scores_list = []
for ex in tqdm(examples, desc="Judging"):
    s = judge_scores_flat(ex.get("code",""), ex.get("question",""), ex.get("reference",""), ex.get("prediction",""))
    judge_scores_list.append(s)

print("Judged:", len(judge_scores_list))
print("First judged scores:", judge_scores_list[0] if judge_scores_list else None)

  llm = ChatOllama(
Judging:   0%|          | 4/56085 [03:32<826:17:57, 53.04s/it]  


KeyboardInterrupt: 

## Budget policies: Random-K and Worst-first-K

In [12]:
import math, random

budgets = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]

def K_from_budget(b, N):
    return int(math.floor(b * N))

def pick_indices_random(N, b, seed=42):
    K = K_from_budget(b, N)
    rnd = random.Random(seed + int(b*1000))
    if K == 0:
        return set()
    return set(rnd.sample(range(N), K))

def need_score_from_judge(scores_flat):
    # Higher means more need (worse initial evaluation)
    a = scores_flat.get("accuracy", 3) or 3
    c = scores_flat.get("completeness", 3) or 3
    r = scores_flat.get("relevance", 3) or 3
    l = scores_flat.get("clarity", 3) or 3
    # simple weighted "badness"
    return 0.5*(3 - a) + 0.2*(3 - c) + 0.15*(3 - r) + 0.15*(3 - l)

def pick_indices_worst_first(judge_scores_list, b):
    N = len(judge_scores_list)
    K = K_from_budget(b, N)
    if K == 0:
        return set()
    ranked = sorted(range(N), key=lambda i: need_score_from_judge(judge_scores_list[i]), reverse=True)
    return set(ranked[:K])

In [11]:
import json
from tqdm import tqdm

def run_budget_policy(policy_name: str, selector_fn):
    results_by_budget = {}
    for b in budgets:
        intervene_set = selector_fn(b)
        budget_results = []

        for i, ex in enumerate(tqdm(examples, desc=f"{policy_name} b={b}")):
            old_scores = judge_scores_list[i]  # flat ints
            teacher_feedbacks = {}
            revised_scores = old_scores

            if i in intervene_set:
                teacher_feedbacks = call_teacher_single_turn(
                    ex.get("code",""), ex.get("question",""), ex.get("reference",""), ex.get("prediction",""), old_scores
                )
                # student_reflect_and_revise expects critiques dict
                revised_scores = student_reflect_and_revise(
                    ex.get("code",""), ex.get("question",""), ex.get("reference",""), ex.get("prediction",""),
                    old_scores, teacher_feedbacks
                ) or old_scores  # fallback

            budget_results.append({
                "id": ex.get("id", i),
                "budget": b,
                "intervened": (i in intervene_set),
                "old_scores": old_scores,
                "teacher_feedbacks": teacher_feedbacks,
                "new_scores": revised_scores
            })

        results_by_budget[b] = budget_results
        print(f"[{policy_name}] Budget={b}: teacher calls={len(intervene_set)} / N={len(examples)}")
    return results_by_budget

# Run RANDOM policy
results_random_by_budget = run_budget_policy(
    "random",
    selector_fn=lambda b: pick_indices_random(len(examples), b, seed=42)
)

# Run WORST-FIRST policy
results_worst_by_budget = run_budget_policy(
    "worst_first",
    selector_fn=lambda b: pick_indices_worst_first(judge_scores_list, b)
)

# Save
with open("results_random_by_budget.json", "w", encoding="utf-8") as f:
    json.dump(results_random_by_budget, f, indent=2, ensure_ascii=False)

with open("results_worst_by_budget.json", "w", encoding="utf-8") as f:
    json.dump(results_worst_by_budget, f, indent=2, ensure_ascii=False)

print("Saved results_random_by_budget.json and results_worst_by_budget.json")

random b=0.0:   0%|          | 4/56085 [00:00<00:07, 7489.83it/s]


IndexError: list index out of range

## Compute per-budget metric curves and plot


In [None]:
import numpy as np
import matplotlib.pyplot as plt

def metric_avg_accuracy_norm(results):
    # results: list of dicts for a budget
    vals = []
    for r in results:
        s = r["new_scores"]
        vals.append((s.get("accuracy", 3) or 3) / 3.0)
    return float(np.mean(vals)) if vals else 0.0

def make_curve(results_by_budget):
    return [metric_avg_accuracy_norm(results_by_budget[b]) for b in budgets]

curve_random = make_curve(results_random_by_budget)
curve_worst = make_curve(results_worst_by_budget)

print("Budgets:", budgets)
print("Random curve:", curve_random)
print("Worst-first curve:", curve_worst)

# Save curves
with open("metric_curves.json", "w", encoding="utf-8") as f:
    json.dump({
        "budgets": budgets,
        "random_avg_accuracy_norm": curve_random,
        "worst_first_avg_accuracy_norm": curve_worst
    }, f, indent=2)

plt.figure()
plt.plot(budgets, curve_random, marker="o", label="Random-K")
plt.plot(budgets, curve_worst, marker="o", label="Worst-first-K")
plt.xlabel("Budget b (fraction of FULL dataset)")
plt.ylabel("Avg accuracy score (normalized 0..1)")
plt.title("Single-turn: Metric vs Budget")
plt.legend()
plt.show()