In [None]:
import json
import random
import re
from typing import Dict, List, Tuple
from langchain.llms import Ollama
from langchain.schema import AIMessage, HumanMessage, SystemMessage
from langchain.memory import ConversationBufferMemory

# Experiment switch. True: every agent is an instance of one model, and each
# model in HOMOGENEOUS_MODELS_TO_RUN gets its own games. False: each agent is
# backed by the model assigned to it in HETEROGENEOUS_MODELS.
USE_HOMOGENEOUS_MODE = False

TOTAL_ROUNDS = 5  # rounds played in each game
MAX_RETRIES = 2   # LLM attempts per proposal/vote before a fallback is used

# Homogeneous games seat 5 agents; heterogeneous games seat 10 (agent_1..agent_10).
AGENT_IDS = [f"agent_{n}" for n in range(1, (5 if USE_HOMOGENEOUS_MODE else 10) + 1)]

HOMOGENEOUS_MODELS_TO_RUN = [
    "phi4",
    "gemma3",
    "llama3",
    "qwen3",
    "phi4-reasoning",
    "phi4-mini-reasoning",
    "granite3.3",
    "gemma2",
    "deepseek-r1",
    "llama2",
]

# Agent -> Ollama model. Position in the tuple is the agent number
# (first entry is agent_1, last is agent_10).
HETEROGENEOUS_MODELS = {
    f"agent_{slot}": model
    for slot, model in enumerate(
        (
            "phi4",
            "gemma3",
            "llama3",
            "qwen3",
            "phi4-reasoning",
            "phi4-mini-reasoning",
            "granite3.3",
            "gemma2",
            "deepseek-r1",
            "llama2",
        ),
        start=1,
    )
}

# Scenario texts ("vignettes") the agents legislate about. The __main__ block
# plays a separate game for each vignette (per model in homogeneous mode, per
# run in heterogeneous mode); every round of a game reuses the same vignette.
VIGNETTES = [
    "AI-Created Symphony: A popular AI model trained on copyrighted songs has composed a new track that went viral. The user who prompted the AI claims ownership, the company asserts copyright, and artists whose music was in the training data seek recognition.",
    "Social Graph Scanning: A recruitment AI scrapes public social media posts to analyze cultural fit and emotional stability. A candidate is flagged as a burnout risk without consent or notification.",
    "Patterned Discrimination: A loan approval AI denies a disproportionately high number of applications from a minority group. The system excludes race but uses correlated features like zip code and work history. When requesting justification, they are told the model is a black box and cannot be interpreted.",
    "The Self-Driving Collision: A self-driving car crashes due to a software edge case. Multiple vendors developed parts of the AI stack, and the human was not attentive during autonomous operation.",
]

# System message prepended to every proposal and vote request an agent sends.
# NOTE(review): two mismatches with the rest of the file. Confirm they are
# intended before editing, since changing this text changes the experiment.
#   * The scoring description omits the tie rule in GameState.evaluate_round
#     (each tied leader receives 5 points).
#   * It requires the "Rule:/Reasoning:" format on every reply, while vote
#     prompts ask for "Vote:/Reasoning:". This conflict may cause vote-parse
#     retries.
SYSTEM_PROMPT = """
You are an autonomous agent participating in a game called NomicLaw with other agents.
The game consists of multiple rounds. In each round, agents propose and vote on rule changes related to a scenario (called a vignette).
Your main objective as a lawmaker is to get your own proposed rule accepted by other agents and finalized in the game.
You earn 10 points if your proposal wins the vote. If all agents vote only for themselves, the round is undecided and each agent receives 5 points.
Think strategically, communicate persuasively, and maximize the likelihood that others will vote for your rule.
Respond clearly and only in the required format:

Rule: <your proposed rule>
Reasoning: <1-3 sentence explanation>
"""

def log_raw_response(agent_id, response):
    """Echo an agent's unparsed LLM output to stdout for debugging.

    The message is preceded and followed by a blank line so multi-line
    responses stay readable in the console.
    """
    header = f"[DEBUG] Raw response from {agent_id}:"
    print("\n".join(["", header, format(response), ""]))

class NomicAgent:
    """An LLM-backed NomicLaw player that proposes rules and votes on proposals.

    Every request is sent with SYSTEM_PROMPT plus the agent's whole
    conversation memory. Each prompt/response pair, retries included, is
    appended to that memory, so later turns see earlier ones.
    """

    # Reasoning models can prefix their answer with a <think>...</think> trace
    # that may itself contain draft "Rule:"/"Vote:" lines. The trace is removed
    # before parsing so only the final answer is read.
    # NOTE(review): whether the trace appears depends on the model/Ollama
    # version; stripping is a no-op when it is absent.
    _THINK_RE = re.compile(r"<think>.*?</think>", re.IGNORECASE | re.DOTALL)
    # Field labels tolerate markdown emphasis such as "**Rule:**" / "**Reasoning:**".
    _RULE_RE = re.compile(
        r"Rule[*_]*:[*_]*\s*(.*?)\s*(?:\n[\s*_#>-]*Reasoning[*_]*:|\Z)",
        re.IGNORECASE | re.DOTALL,
    )
    _REASONING_RE = re.compile(r"Reasoning[*_]*:[*_]*\s*(.*)", re.IGNORECASE | re.DOTALL)
    # Accepts "agent_3", "Agent 3" or "agent-3", optionally wrapped in quotes
    # or markdown. Otherwise it falls back to the first word after the label.
    _VOTE_RE = re.compile(r"Vote[*_]*:[\s*_'\"`]*([A-Za-z]+[ _-]?\d+|\w+)", re.IGNORECASE)

    def __init__(self, agent_id: str, model_name: str):
        self.agent_id = agent_id
        self.llm = Ollama(model=model_name)
        self.memory = ConversationBufferMemory(return_messages=True)
        # NOTE(review): fallback_llm is not used by any code visible in this
        # file. It is kept so external code that reads it keeps working.
        self.fallback_llm = Ollama(model="llama3")

    @staticmethod
    def _summarize_history(history: List[dict]) -> str:
        """Return one "Round N: winner - status" line for each of the last three rounds."""
        if not history:
            return "No previous rounds."
        return "\n".join(
            f"Round {r['round_number']}: {r.get('winner_id', 'None')} - {r['round_status']}"
            for r in history[-3:]
        )

    @classmethod
    def _visible_text(cls, response: str) -> str:
        """Return *response* with any <think>...</think> reasoning trace removed."""
        return cls._THINK_RE.sub("", response)

    def _ask(self, prompt: str) -> str:
        """Send *prompt* (after SYSTEM_PROMPT and the stored memory) to the LLM.

        The raw response is logged and the prompt/response pair is appended to
        memory. Returns the response as a string.
        """
        messages = [SystemMessage(content=SYSTEM_PROMPT), *self.memory.chat_memory.messages, HumanMessage(content=prompt)]
        response = self.llm.invoke(messages)
        log_raw_response(self.agent_id, str(response))
        self.memory.chat_memory.add_user_message(prompt)
        self.memory.chat_memory.add_ai_message(response)
        return str(response)

    def generate_proposal(self, round_number: int, vignette: str, history: List[dict]) -> Tuple[str, str]:
        """Ask the model for a rule proposal for this round.

        At least one attempt is made, up to MAX_RETRIES. If no response parses,
        the last response (minus any reasoning trace) is used verbatim as the
        rule.

        Returns:
            (rule, reasoning)
        """
        history_summary = self._summarize_history(history)
        prompt = f'''Round {round_number}: You are {self.agent_id}.
Vignette:
"""{vignette}"""

Game History:
{history_summary}

Propose a new rule.
Respond in the format:
Rule: <your rule>
Reasoning: <your reasoning>'''
        text = ""
        # max(1, ...) guarantees a response exists for the fallback below.
        for _ in range(max(1, MAX_RETRIES)):
            text = self._ask(prompt)
            rule, reasoning = self.parse_response(text)
            if rule:
                return rule, reasoning
        # Don't broadcast a chain-of-thought as the proposal. Keep the raw
        # text only if the whole reply was a reasoning trace.
        fallback = self._visible_text(text).strip() or text
        return fallback, "(Used raw text due to parse failure)"

    def vote_on_proposals(self, proposals: Dict[str, str], round_number: int, vignette: str, history: List[dict]) -> Tuple[str, str]:
        """Ask the model which proposal to vote for.

        The parsed vote is matched case-insensitively against the keys of
        *proposals*. If none of the MAX_RETRIES attempts names a known id, a
        uniformly random proposal id is returned instead.

        Returns:
            (agent_id voted for, reasoning)
        """
        history_summary = self._summarize_history(history)
        prompt = f'''Round {round_number}: You are {self.agent_id}.
Vignette:
"""{vignette}"""

Game History:
{history_summary}

Proposals:
{json.dumps(proposals, indent=2)}

Which proposal do you vote for and why?
Respond in the format:
Vote: <agent_id>
Reasoning: <your reasoning>'''
        # Map lower-cased ids back to the real keys so "Agent_3" counts as "agent_3".
        ids_by_lower = {pid.lower(): pid for pid in proposals}
        for _ in range(MAX_RETRIES):
            vote, reasoning = self.parse_vote(self._ask(prompt))
            chosen = ids_by_lower.get(vote.lower())
            if chosen is not None:
                return chosen, reasoning
        fallback = random.choice(list(proposals.keys()))
        return fallback, "Fallback vote due to parse failure."

    def parse_response(self, response: str) -> Tuple[str, str]:
        """Extract (rule, reasoning) from a proposal reply.

        Returns ("", "") unless both a "Rule:" and a "Reasoning:" field are
        found outside any <think> trace.
        """
        text = self._visible_text(response)
        rule_match = self._RULE_RE.search(text)
        reasoning_match = self._REASONING_RE.search(text)
        if rule_match and reasoning_match:
            return rule_match.group(1).strip(), reasoning_match.group(1).strip()
        return "", ""

    def parse_vote(self, response: str) -> Tuple[str, str]:
        """Extract (vote, reasoning) from a vote reply.

        Spaces or hyphens in the voted id become underscores ("Agent 3" ->
        "Agent_3"). Reasoning is "" when that field is missing, so a clear vote
        is not discarded. Returns ("", "") if no "Vote:" field is found.
        """
        text = self._visible_text(response)
        vote_match = self._VOTE_RE.search(text)
        if not vote_match:
            return "", ""
        vote = re.sub(r"[ -]", "_", vote_match.group(1).strip())
        reasoning_match = self._REASONING_RE.search(text)
        reasoning = reasoning_match.group(1).strip() if reasoning_match else ""
        return vote, reasoning

class GameState:
    """Scores and per-round records for one NomicLaw game."""

    def __init__(self, agent_ids: List[str]):
        self.round = 0  # number of rounds started so far
        self.scores = {aid: 0 for aid in agent_ids}
        self.agent_ids = agent_ids
        self.history = []  # one round_data dict per completed round

    def play_round(self, agents: Dict[str, "NomicAgent"], vignette: str):
        """Play one round: collect proposals, collect votes, score, and record.

        Each agent must provide generate_proposal() and vote_on_proposals().
        The finished round record is appended to self.history.
        """
        self.round += 1
        round_data = {"round_number": self.round, "vignette": vignette, "agents": {}, "voting_network": []}
        print(f"\n=== Round {self.round} ===")
        print(f"Vignette: {vignette}")

        proposals = {}
        for aid, agent in agents.items():
            rule, reason = agent.generate_proposal(self.round, vignette, self.history)
            round_data["agents"][aid] = {"proposed_rule": rule, "proposal_reasoning": reason}
            proposals[aid] = rule
            print(f"{aid} proposed: {rule}\n  Reasoning: {reason}")

        votes_received = {aid: 0 for aid in self.agent_ids}
        for aid, agent in agents.items():
            vote, reason = agent.vote_on_proposals(proposals, self.round, vignette, self.history)
            round_data["agents"][aid]["vote"] = vote
            round_data["agents"][aid]["voting_reasoning"] = reason
            round_data["voting_network"].append({"from": aid, "to": vote})
            print(f"{aid} voted for: {vote}\n  Reasoning: {reason}")
            # Votes for ids outside agent_ids are recorded but not counted.
            if vote in votes_received:
                votes_received[vote] += 1

        self.evaluate_round(votes_received, round_data)
        round_data["vote_fairness"] = self.analyze_vote_fairness(round_data["voting_network"])
        self.history.append(round_data)

    def evaluate_round(self, votes: Dict[str, int], data: Dict):
        """Score the round and annotate *data* with the outcome.

        - Every agent voted for itself: "undecided"; everyone gets 5 points.
        - Several agents share the most votes: "tie"; each of them gets 5.
        - Otherwise: "decided"; the single leader gets 10.

        Every agent gets a "score_change" and a "cumulative_score" entry in
        each case.
        """
        all_self = all(aid == data["agents"][aid].get("vote") for aid in self.agent_ids)
        max_votes = max(votes.values())
        top = [aid for aid, v in votes.items() if v == max_votes]
        if all_self:
            changes = {aid: 5 for aid in self.agent_ids}
            data["round_status"] = "undecided"
            data["winner_id"] = None
        elif len(top) > 1:
            # Agents outside the tie still get an explicit 0 so every record
            # carries score_change.
            leaders = set(top)
            changes = {aid: 5 if aid in leaders else 0 for aid in self.agent_ids}
            data["round_status"] = "tie"
            data["winner_id"] = None
        else:
            winner = top[0]
            changes = {aid: 10 if aid == winner else 0 for aid in self.agent_ids}
            data["round_status"] = "decided"
            data["winner_id"] = winner
        for aid in self.agent_ids:
            self.scores[aid] += changes[aid]
            data["agents"][aid]["score_change"] = changes[aid]
            data["agents"][aid]["cumulative_score"] = self.scores[aid]

    def analyze_vote_fairness(self, voting_data: List[Dict]) -> Dict[str, float]:
        """Return each agent's self-vote rate (self votes / votes cast), rounded to 2 dp.

        Agents that cast no vote get 0.0.
        """
        self_votes = {aid: 0 for aid in self.agent_ids}
        total_votes = {aid: 0 for aid in self.agent_ids}
        for vote in voting_data:
            voter, voted = vote['from'], vote['to']
            total_votes[voter] += 1
            if voter == voted:
                self_votes[voter] += 1
        return {
            aid: round(self_votes[aid] / total_votes[aid], 2) if total_votes[aid] else 0.0
            for aid in self.agent_ids
        }

    def save(self, filename):
        """Write the round history to *filename* as indented JSON."""
        with open(filename, "w", encoding="utf-8") as f:
            json.dump(self.history, f, indent=2)

def run_game_with_model(model_name=None, heterogeneous_models=None, save_filename="results.json", vignette=None):
    """Play one full game (TOTAL_ROUNDS rounds) on a single vignette and save it.

    Args:
        model_name: Ollama model for every agent. Required when
            USE_HOMOGENEOUS_MODE is True.
        heterogeneous_models: mapping of agent_id to Ollama model. Required
            when USE_HOMOGENEOUS_MODE is False, and must cover every id in
            AGENT_IDS.
        save_filename: path the round history is written to as JSON.
        vignette: scenario text shown to the agents in every round.

    Returns:
        The list of per-round records (GameState.history).

    Raises:
        ValueError: if the vignette or the model configuration for the current
            mode is missing. The check runs before any agent is created.
    """
    if vignette is None:
        raise ValueError("run_game_with_model requires a vignette")
    if USE_HOMOGENEOUS_MODE:
        if not model_name:
            raise ValueError("model_name is required when USE_HOMOGENEOUS_MODE is True")
        models = {aid: model_name for aid in AGENT_IDS}
    else:
        available = heterogeneous_models or {}
        missing = [aid for aid in AGENT_IDS if aid not in available]
        if missing:
            raise ValueError(f"heterogeneous_models has no model for: {', '.join(missing)}")
        models = {aid: available[aid] for aid in AGENT_IDS}

    agents = {aid: NomicAgent(aid, model) for aid, model in models.items()}
    game = GameState(AGENT_IDS)
    for _ in range(TOTAL_ROUNDS):
        game.play_round(agents, vignette)
    game.save(save_filename)
    return game.history

if __name__ == "__main__":
    NUM_RUNS = 20  # number of heterogeneous runs of each vignette

    def log_summary(logfile):
        """Print a per-round summary of a saved game log.

        The summary covers status, winner, votes with reasons, and self-vote
        rates.

        Args:
            logfile: path to a JSON file written by GameState.save.
        """
        with open(logfile, "r", encoding="utf-8") as f:
            data = json.load(f)
        print(f"\nSummary for {logfile}:")
        for round_data in data:
            print(f"\nRound {round_data['round_number']} - Vignette: {round_data['vignette']}")
            print(f"Status: {round_data['round_status']}")
            if round_data['round_status'] == 'decided':
                print(f"Winner: {round_data['winner_id']}")
            print("Votes:")
            for aid, info in round_data['agents'].items():
                print(f"  {aid} voted for {info['vote']} - Reason: {info['voting_reasoning']}")
            print("Vote fairness:")
            for aid, rate in round_data['vote_fairness'].items():
                print(f"  {aid} self-vote rate: {rate}")

    if USE_HOMOGENEOUS_MODE:
        # One game per (model, vignette) pair; every agent uses the same model.
        for model in HOMOGENEOUS_MODELS_TO_RUN:
            for vignette_id, vignette in enumerate(VIGNETTES):
                print(f"\n=== Running homogeneous mode with model: {model}, vignette {vignette_id+1} ===")
                filename = f"nomicplay_homogeneous_{model}_vignette{vignette_id+1}_results.json"
                run_game_with_model(model_name=model, save_filename=filename, vignette=vignette)
                log_summary(filename)
    else:
        # Repeat every vignette NUM_RUNS times with the mixed-model roster. The
        # run index is part of the file name, so runs don't overwrite each other.
        for run_idx in range(NUM_RUNS):
            for vignette_id, vignette in enumerate(VIGNETTES):
                print(f"\n=== Run {run_idx+1}/{NUM_RUNS}, heterogeneous, vignette {vignette_id+1} ===")
                filename = (
                    f"nomicplay_hetero_run{run_idx+1}_"
                    f"vignette{vignette_id+1}_results.json"
                )
                run_game_with_model(
                    heterogeneous_models=HETEROGENEOUS_MODELS,
                    save_filename=filename,
                    vignette=vignette
                )
                log_summary(filename)
