In [None]:
import json
import re
import time
import random
import torch

from typing import List, Optional, Dict, Any
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from langgraph.graph import StateGraph, END

In [4]:
# Load experiment definitions
with open("experiments_v3.json") as f:
    all_experiments = json.load(f)


# Select Public Goods experiment
public_goods_exp = next(
    exp for exp in all_experiments
    if exp["experiment_id"] == "public_goods_2000"
)

EXPERIMENT_ID = public_goods_exp["experiment_id"]
EXPERIMENT_NAME = public_goods_exp["experiment_name"]
EXPERIMENT_YEAR = public_goods_exp["year"]


# ------------------------------------------------------------
# Choose condition
# ------------------------------------------------------------
# Options:
# "Stranger-treatment without punishment"
# "Stranger-treatment with punishment"
# "Partner-treatment without punishment"
# "Partner-treatment with punishment"

CONDITION_NAME = "Partner-treatment with punishment"

condition = next(
    c for c in public_goods_exp["condition_info"]
    if c["condition_name"].lower() == CONDITION_NAME.lower()
)

# Parse number of periods from stopping_conditions
stop_text = condition["stopping_conditions"][0].lower()

period_match = re.search(r"(\d+)\s+period", stop_text)
total_periods = int(period_match.group(1)) if period_match else 10

In [6]:
# Extract instructions and example scenario 
# Instructions are identical for all agents; index by agent_no

agent_instructions = {
    agent["agent_no"]: agent["agent_experimental_instruction"]
    for agent in condition["agent_info"]
}

example_scenario = condition["example_condition_scenario"]


def build_example_text_pg(example: Dict[str, Any], max_periods: int = 5) -> str:
    """Format public-goods example scenario text (contribution + punishment stages)."""
    lines: List[str] = []
    periods_seen = 0

    for period_entry in example["response"]:
        period_no = period_entry.get("period_no")
        stage_no = period_entry.get("stage_no")

        # Limit number of periods shown
        if period_no is not None:
            if period_no > periods_seen:
                periods_seen = period_no
            if periods_seen > max_periods:
                break
        else:
            if len(lines) >= max_periods:
                break

        period_label = period_no if period_no is not None else "?"

        # Contribution stage
        if stage_no == 1:
            contributions = {
                ar["agent_no"]: ar["agent_response"]
                for ar in period_entry["agent_response"]
            }

            contrib_text = ", ".join(
                f"A{i}={contributions[i]}"
                for i in sorted(contributions)
            )

            lines.append(
                f"Example Period {period_label} (Contrib): {contrib_text}"
            )

        # Punishment stage
        elif stage_no == 2:
            lines.append(f"Example Period {period_label} (Punishment):")

            for ar in period_entry["agent_response"]:
                lines.append(f"  A{ar['agent_no']}: {ar['agent_response']}")

        # Fallback for unknown stages
        else:
            lines.append(
                f"Example Period {period_label} (Stage {stage_no}): "
                f"{period_entry.get('agent_response')}"
            )

    return "\n".join(lines)


example_text = build_example_text_pg(example_scenario, max_periods=5)


In [7]:
# Public Goods Payoff Parameters and condition-specific flags

ENDOWMENT = 20
MPCR = 0.4  # marginal per-capita return
condition_name = condition["condition_name"].lower()
punishment_enabled = "punishment" in condition_name

In [9]:
# Unified Configuration Object

config = {
    "experiment_metadata": {
        "experiment_id": EXPERIMENT_ID,
        "experiment_name": EXPERIMENT_NAME,
        "year": EXPERIMENT_YEAR,
        "total_participants_count": pd_exp["total_participants_count"],
        "demographic_info": pd_exp["demographic_info"],
    },

    "condition_metadata": {
        "condition_no": condition["condition_no"],
        "condition_name": condition["condition_name"],
        "condition_experiment_instruction": condition["condition_experiment_instruction"],
        "condition_demographic": condition["condition_demographic"],
        "stopping_conditions": condition["stopping_conditions"],
        "condition_supplement_url": condition["condition_supplement_url"],
    },

    # ---- Public Goods specific ----
    "total_periods": total_periods,
    "endowment": ENDOWMENT,
    "mpcr": MPCR,  
    "agent_count": condition["agent_count"],
    "agent_instructions": agent_instructions,  # dict: agent_no -> instruction
    "punishment_enabled": punishment_enabled,

    
    # ---- Prompt helpers ----
    "example_text": example_text,
    "example_prompt_text": example_scenario["prompt"]["condition_experiment_instruction"],
}


In [10]:
# Environment: stateless payoff computer (Public Goods)

class PublicGoodsEnv:
    def __init__(self, alpha: float = 0.4, endowment: int = 20):
        self.alpha = alpha
        self.endowment = endowment

    def compute_payoffs(
        self,
        contributions: Dict[int, int],
        punishments: Dict[tuple[int, int], int],
    ) -> Dict[int, float]:
        """
        contributions: {agent_id: contribution}
        punishments: {(punisher_id, target_id): punishment_points}
        """

        # Stage 1: contribution payoff
        total_contribution = sum(contributions.values())

        base_payoffs = {
            agent_id: (self.endowment - contribution)
            + self.alpha * total_contribution
            for agent_id, contribution in contributions.items()
        }

        # Stage 2: punishment adjustments
        final_payoffs: Dict[int, float] = {}

        for agent_id in base_payoffs:
            punishment_received = sum(
                pts for (punisher, target), pts in punishments.items()
                if target == agent_id
            )

            punishment_cost = sum(
                pts for (punisher, target), pts in punishments.items()
                if punisher == agent_id
            )

            payoff = (
                base_payoffs[agent_id]
                * (1 - 0.1 * punishment_received)
                - punishment_cost
            )

            final_payoffs[agent_id] = max(0.0, payoff)

        return final_payoffs


In [11]:
# State model for Public Goods LangGraph

class PGState(BaseModel):
    
    # Period control
    period: int = 1
    stage: int = 1            # 1 = contribution, 2 = punishment
    total_periods: int

    # Decisions
    contributions: Dict[int, int] = Field(default_factory=dict)
    punishments: Dict[tuple[int, int], int] = Field(default_factory=dict)

    # Outcomes
    payoffs: Dict[int, float] = Field(default_factory=dict)

    # Logs
    history: List[Dict[str, Any]] = Field(default_factory=list)
    explanations: List[Dict[str, Any]] = Field(default_factory=list)

    # Termination
    done: bool = False

In [12]:
def make_pg_think_prompt(
    agent_id: int,
    state: PGState,
    config: Dict[str, Any],
) -> str:

    meta = config["experiment_metadata"]
    cond_meta = config["condition_metadata"]
    instr = config["agent_instructions"][agent_id]
    history_str = format_history_pg(state, agent_id)

    return f"""
You are a participant in a Public Goods experiment.

Think carefully about your decision.

Context:
Year: {meta['year']}
Experiment: {meta['experiment_name']}
Condition: {cond_meta['condition_name']}

Instructions:
{instr}

Game structure:
{cond_meta['condition_experiment_instruction']}

Current state:
Period {state.period}/{state.total_periods}

History:
{history_str}

Question:
How many tokens (0–20) will you contribute this period, and why?

Respond in plain English.
"""


In [13]:
def make_pg_punish_think_prompt(
    agent_id: int,
    state: PGState,
    config: Dict[str, Any],
) -> str:
    meta = config["experiment_metadata"]
    cond_meta = config["condition_metadata"]

    # may add richer info here later (e.g., contributions last period)
    history_str = format_history_pg(state, agent_id)

    return f"""
You are a participant in a Public Goods experiment.

This is the PUNISHMENT stage.

Context:
Year: {meta['year']}
Experiment: {meta['experiment_name']}
Condition: {cond_meta['condition_name']}

Game structure:
{cond_meta['condition_experiment_instruction']}

Current state:
Period {state.period}/{state.total_periods}

History:
{history_str}

Task:
Decide whether to assign punishment points to other agents.
You may assign 0–10 points to each OTHER agent.
Explain your reasoning in plain English.

Respond in plain English.
"""


In [14]:
def make_pg_extract_prompt(free_text: str, stage: int) -> str:
    if stage == 1:
        return f"""You are a strict information extraction system.

Output EXACTLY ONE LINE.
Output MUST be valid JSON.
Output MUST NOT contain markdown, code fences, or extra text.

Schema:
{{"action": <integer 0-20>, "why": "<=20 words>"}}

Text:
{free_text}
"""
    else:
        return f"""You are a strict information extraction system.

Output EXACTLY ONE LINE.
Output MUST be valid JSON.
Output MUST NOT contain markdown, code fences, or extra text.

Schema:
{{"punish": {{"2": <0-10>, "3": <0-10>, "4": <0-10>}}, "why": "<=20 words>"}}

Rules:
- Keys are agent ids as strings
- Do NOT include self
- Missing agents imply 0 punishment

Text:
{free_text}
"""


In [15]:
def format_history_pg(state: PGState, agent_id: int) -> str:
    if not state.history:
        return "No previous periods."

    lines = []
    for h in state.history:
        
        lines.append(
            f"Period {h['period']}: "
            f"total contribution = {h['total_contribution']}"
        )

    return "\n".join(lines)

In [16]:
# LLM Setup
model_id = "meta-llama/Llama-3.2-3B-Instruct"  
tok = AutoTokenizer.from_pretrained(model_id)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="cuda",
    dtype=torch.float16,
    offload_folder="offload"
)

Loading checkpoint shards: 100%|██████████| 2/2 [00:31<00:00, 15.79s/it]


In [17]:
def llm_decide_pg_two_stage(
    agent_id: int,
    state: PGState,
    config: Dict[str, Any],
    stage: int,
):
    # ============================
    # CALL 1: THINK
    # ============================
    if stage == 1:
        think_prompt = make_pg_think_prompt(agent_id, state, config)
    else:
        think_prompt = make_pg_punish_think_prompt(agent_id, state, config)

    print("PROMPT:", think_prompt)
    print("-" * 60)

    inputs = tok(think_prompt, return_tensors="pt").to(model.device)

    output = model.generate(
        **inputs,
        max_new_tokens=100,
        temperature=0.1,
        top_p=0.9,
        do_sample=True,
        eos_token_id=tok.eos_token_id,
        pad_token_id=tok.eos_token_id,
    )

    think_text = tok.decode(
        output[0, inputs["input_ids"].shape[1]:],
        skip_special_tokens=True,
    ).strip()

    print("THINK OUTPUT:")
    print(think_text)
    print("-" * 60)

    # ============================
    # CALL 2: EXTRACT
    # ============================
    extract_prompt = make_pg_extract_prompt(think_text, stage)
    inputs = tok(extract_prompt, return_tensors="pt").to(model.device)

    output = model.generate(
        **inputs,
        max_new_tokens=40,
        temperature=0.0,
        do_sample=False,
        eos_token_id=tok.eos_token_id,
        pad_token_id=tok.eos_token_id,
    )

    extract_text = tok.decode(
        output[0, inputs["input_ids"].shape[1]:],
        skip_special_tokens=True,
    ).strip()

    extract_text = extract_text.splitlines()[0].strip()

    end_idx = extract_text.find("}")
    if end_idx != -1:
        extract_text = extract_text[: end_idx + 1]

    print("EXTRACT OUTPUT:")
    print(extract_text)
    print("=" * 60)
    print()

    # ============================
    # PARSE
    # ============================
    match = re.search(r"\{[^{}]*\}", extract_text, flags=re.DOTALL)
    if not match:
        return (0, "") if stage == 1 else ({}, "")

    try:
        obj = json.loads(match.group(0))
    except Exception:
        return (0, "") if stage == 1 else ({}, "")

    # ----------------------------
    # Stage 1: contribution
    # ----------------------------
    if stage == 1:
        contribution = int(obj.get("action", 0))
        explanation = str(obj.get("why", "")).strip()

        contribution = max(0, min(20, contribution))
        return contribution, explanation

    # ----------------------------
    # Stage 2: punishment
    # ----------------------------
    punish_dict = obj.get("punish", {})
    explanation = str(obj.get("why", "")).strip()

    punishment_map: Dict[int, int] = {}

    for k, v in punish_dict.items():
        try:
            target = int(k)
            pts = int(v)
        except Exception:
            continue

        if target != agent_id:
            punishment_map[target] = max(0, min(10, pts))

    for target in range(1, config["agent_count"] + 1):
        if target != agent_id and target not in punishment_map:
            punishment_map[target] = 0

    return punishment_map, explanation


In [18]:
def agent_node(agent_id: int, config: Dict[str, Any]):
    
    def _node(state: PGState) -> PGState:

        # ----------------------------
        # Stage 1: Contribution
        # ----------------------------
        if state.stage == 1:
            contribution, explanation = llm_decide_pg_two_stage(
                agent_id=agent_id,
                state=state,
                config=config,
                stage=1,
            )

            state.contributions[agent_id] = contribution

            state.explanations.append({
                "agent": agent_id,
                "period": state.period,
                "stage": 1,
                "action": contribution,
                "why": explanation,
            })

        # ----------------------------
        # Stage 2: Punishment
        # ----------------------------
        else:
            punishment_dict, explanation = llm_decide_pg_two_stage(
                agent_id=agent_id,
                state=state,
                config=config,
                stage=2,
            )

            for target, points in punishment_dict.items():
                if target != agent_id:
                    state.punishments[(agent_id, target)] = int(points)

            if not explanation:
                explanation = "Punishment not available in this treatment"

            state.explanations.append({
                "agent": agent_id,
                "period": state.period,
                "stage": 2,
                "action": punishment_dict,
                "why": explanation,
            })

        return state

    return _node


In [19]:
def payoff_node(state: PGState, env: PublicGoodsEnv, config: Dict[str, Any]) -> PGState:

    # ----------------------------
    # After Stage 1 → go to Stage 2
    # ----------------------------
    if state.stage == 1 and config["punishment_enabled"]:
        state.stage = 2
        return state

    # ----------------------------
    # Final payoff computation
    # (Stage 2 OR no-punishment)
    # ----------------------------
    payoffs = env.compute_payoffs(
        state.contributions,
        state.punishments,
    )

    record = {
        "period": state.period,
        "contributions": state.contributions.copy(),
        "punishments": dict(state.punishments),
        "total_contribution": sum(state.contributions.values()),
        "payoffs": payoffs,
    }

    state.history.append(record)
    state.payoffs = payoffs

    # ----------------------------
    # Termination check
    # ----------------------------
    if state.period >= state.total_periods:
        state.done = True
        return state

    # ----------------------------
    # Reset for next period
    # ----------------------------
    state.period += 1
    state.stage = 1
    state.contributions = {}
    state.punishments = {}

    return state


In [20]:
# Build LangGraph (Public Goods)

def build_pg_graph(env: PublicGoodsEnv, config: Dict[str, Any]):
    graph = StateGraph(PGState)

    agent_count = config["agent_count"]

    # Agent nodes
    for i in range(1, agent_count + 1):
        graph.add_node(f"A{i}", agent_node(i, config))

    # Payoff / control node
    graph.add_node("payoff", lambda s: payoff_node(s, env, config))

    # Entry
    graph.set_entry_point("A1")

    # Sequential agent execution
    for i in range(1, agent_count):
        graph.add_edge(f"A{i}", f"A{i+1}")

    graph.add_edge(f"A{agent_count}", "payoff")

    # -------- Control logic --------
    def route_after_payoff(state: PGState):
        if state.done:
            return END
        # If we just switched to stage 2, run agents again
        if state.stage == 2:
            return "A1"
        # Otherwise new period
        return "A1"

    graph.add_conditional_edges(
        "payoff",
        route_after_payoff,
        {
            END: END,
            "A1": "A1",
        },
    )

    return graph.compile()



In [21]:
# Run Simulation 

env = PublicGoodsEnv(
    alpha=config["mpcr"],
    endowment=config["endowment"],
)

graph = build_pg_graph(env, config)

state = PGState(
    period=1,
    total_periods=config["total_periods"],
)

final_state: PGState = graph.invoke(
    state,
    config={"recursion_limit": 2000},
)

PROMPT:  
You are a participant in a Public Goods experiment.

Think carefully about your decision.

Context:
Year: 2000
Experiment: Cooperation and Punishment in Public Goods Experiments
Condition: Partner-treatment with punishment

Instructions:
In Stage 1, decide your contribution. In Stage 2, decide whether to assign costly punishment points to others.

Game structure:
This is a 10-period, two-stage public goods game. You will remain with the same group of 4 for all 10 periods. In Stage 1, you decide how much of your 20 tokens to contribute to a project. In Stage 2, after contributions are revealed, you can spend some of your tokens to assign punishment points to other group members. Each punishment point you assign has a cost, and each point a player receives reduces their payoff from Stage 1 by 10%.

Current state:
Period 1/10

History:
No previous periods.

Question:
How many tokens (0–20) will you contribute this period, and why?

Respond in plain English.

--------------------

The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


THINK OUTPUT:
Please do not use jargon or technical terms.

I will respond with my answer, and then I will respond with my answer, and so on.

Go ahead and make your contribution. 

(Note: Please respond with a number between 0 and 20.) 

I will contribute 5 tokens. I think I will contribute 10 tokens. I think that's a good amount because it's not too much, but not too little either. I want to contribute enough to help the project, but
------------------------------------------------------------
EXTRACT OUTPUT:
I also want to contribute enough to be able to see the results of my work. I want to contribute enough to be able to see the results of my work. I want to contribute enough to be



PROMPT:  
You are a participant in a Public Goods experiment.

Think carefully about your decision.

Context:
Year: 2000
Experiment: Cooperation and Punishment in Public Goods Experiments
Condition: Partner-treatment with punishment

Instructions:
In Stage 1, decide your contribution. In Stage 2, dec

In [22]:
def summarize_results(state: Dict[str, Any], config: Dict[str, Any]) -> None:
    print("=" * 60)

    meta = config["experiment_metadata"]
    cond_meta = config["condition_metadata"]

    print(f"Experiment: {meta['experiment_name']} ({meta['year']})")
    print(f"Experiment ID: {meta['experiment_id']}")
    print(f"Condition: {cond_meta['condition_name']}")
    print(f"Structure: {state['total_periods']} periods")
    print("-" * 60)

    history = state["history"]
    n_periods = len(history)
    agent_count = config["agent_count"]

    avg_total = sum(h["total_contribution"] for h in history) / n_periods
    print(f"Average total contribution per period: {avg_total:.2f}")

    print("-" * 60)
    print("Average payoff per agent:")

    agent_totals = {i: 0.0 for i in range(1, agent_count + 1)}
    for h in history:
        for i, payoff in h["payoffs"].items():
            agent_totals[i] += payoff

    for i in range(1, agent_count + 1):
        print(f"  Agent {i}: {agent_totals[i] / n_periods:.2f}")

    print("=" * 60)



In [23]:
summarize_results(final_state, config)

Experiment: Cooperation and Punishment in Public Goods Experiments (2000)
Experiment ID: public_goods_2000
Condition: Partner-treatment with punishment
Structure: 10 periods
------------------------------------------------------------
Average total contribution per period: 3.00
------------------------------------------------------------
Average payoff per agent:
  Agent 1: 19.70
  Agent 2: 20.70
  Agent 3: 21.20
  Agent 4: 20.20


In [24]:
def normalize_state(state):
    """
    Ensure state is a plain dict, regardless of whether LangGraph
    returned a dict or a Pydantic model.
    """
    if isinstance(state, dict):
        return state
    if hasattr(state, "model_dump"):
        return state.model_dump()
    raise TypeError(f"Unsupported state type: {type(state)}")


# Saving Results

In [25]:
import json
from pathlib import Path
from typing import Dict, Any

def serialize_tuple_keys(obj):
    """
    Recursively convert dicts with tuple keys into string keys.
    """
    if isinstance(obj, dict):
        new = {}
        for k, v in obj.items():
            if isinstance(k, tuple):
                k = "->".join(map(str, k))
            new[k] = serialize_tuple_keys(v)
        return new
    elif isinstance(obj, list):
        return [serialize_tuple_keys(x) for x in obj]
    else:
        return obj


In [26]:


def save_results_json(state, config: Dict[str, Any], out_path):
    state = normalize_state(state)

    agent_count = config["agent_count"]
    total_periods = state["total_periods"]

    # ---- compute total & average payoff per agent ----
    total_payoff = {i: 0.0 for i in range(1, agent_count + 1)}
    for h in state["history"]:
        for i, p in h["payoffs"].items():
            total_payoff[int(i)] += float(p)

    avg_payoff = {
        i: total_payoff[i] / total_periods
        for i in total_payoff
    }

    out = {
        "experiment_metadata": config["experiment_metadata"],
        "condition_metadata": config["condition_metadata"],

        "structure": {
            "total_periods": total_periods,
            "agent_count": agent_count,
        },

        "parameters": {
            "endowment": config["endowment"],
            "mpcr": config["mpcr"],
            "punishment_enabled": config["punishment_enabled"],
        },

        "outcomes": {
            "total_payoff_per_agent": total_payoff,
            "average_payoff_per_agent": avg_payoff,
            "group_total_payoff": sum(total_payoff.values()),
        },

        "history": state["history"],
        "explanations": state.get("explanations", []),
    }

    out = serialize_tuple_keys(out)

    out_path = Path(out_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    with out_path.open("w") as f:
        json.dump(out, f, indent=2)


In [27]:
cond_name = config["condition_metadata"]["condition_name"]
safe_cond_name = cond_name.lower().replace(" ", "_")

save_results_json(
    final_state,
    config,
    f"results/pg_{safe_cond_name}.json",
)