# In [None]:
"""Test Batch Runner."""

import pathlib
import random
from collections import Counter
from typing import Any

import polars as pl

from neighborly.components.business import Business, JobRole, Occupation
from neighborly.config import LoggingConfig, SimulationConfig
from neighborly.data_analysis import BatchRunner, Metric
from neighborly.datetime import SimDate
from neighborly.ecs import GameObject
from neighborly.events.defaults import LeaveJob
from neighborly.helpers.relationship import get_relationship
from neighborly.helpers.stats import get_stat
from neighborly.life_event import (
    EventRole,
    GlobalEventHistory,
    LifeEvent,
    event_consideration,
)
from neighborly.loaders import (
    load_businesses,
    load_characters,
    load_districts,
    load_job_roles,
    load_residences,
    load_settlements,
    load_skills,
    register_life_event_type,
)
from neighborly.plugins import (
    default_character_names,
    default_events,
    default_settlement_names,
    default_traits,
)
from neighborly.simulation import Simulation

import sample_plugin


class EventCountsMetric(Metric):
    """Metric for extracting how often each life event type fires.

    ``extract_data`` builds one table per simulation with the columns
    ``event_type`` (the event's class name) and ``count``.
    ``get_aggregate_data`` combines the collected tables into the mean
    (``count_avg``) and standard deviation (``count_std``) of each event
    type's count across runs.
    """

    def extract_data(self, sim: Simulation) -> pl.DataFrame:
        """Count the events stored in the simulation's ``GlobalEventHistory``.

        Parameters
        ----------
        sim
            The simulation to read events from.

        Returns
        -------
        pl.DataFrame
            One row per event type that fired at least once, with columns
            ``event_type`` (str) and ``count`` (int).
        """
        history = sim.world.resource_manager.get_resource(GlobalEventHistory)

        # Events are grouped by the name of their concrete class.
        event_counts = Counter(type(event).__name__ for event in history)

        return pl.DataFrame(
            data={
                "event_type": list(event_counts.keys()),
                "count": list(event_counts.values()),
            },
            schema={"event_type": str, "count": int},
        )

    def get_aggregate_data(self) -> pl.DataFrame:
        """Summarize event counts across every collected simulation run.

        A run in which an event type never fired has no row for that type, so
        such runs are counted as ``0`` rather than skipped. Skipping them would
        average only over the runs where the event happened to occur, which
        biases the result upward.

        NOTE(review): assumes each entry of ``self._tables`` is the output of
        ``extract_data`` for exactly one simulation run. Confirm against
        ``Metric``/``BatchRunner``.

        Returns
        -------
        pl.DataFrame
            Columns ``event_type``, ``count_avg`` and ``count_std``, with one
            row per event type seen in any run, sorted by ``event_type``. The
            frame is empty (same columns) when no runs have been collected.
        """
        tables = self._tables
        if not tables:
            # Nothing collected yet: return an empty summary with the usual
            # columns instead of failing.
            return pl.DataFrame(
                schema={"event_type": str, "count_avg": float, "count_std": float}
            )

        n_runs = len(tables)

        # event_type -> count in each run (list index = position in tables).
        counts_per_run: dict[str, list[int]] = {}
        for run_index, table in enumerate(tables):
            rows = table.select(["event_type", "count"]).iter_rows()
            for event_type, count in rows:
                if event_type not in counts_per_run:
                    counts_per_run[event_type] = [0] * n_runs
                counts_per_run[event_type][run_index] += count

        # Flatten to a long table with exactly one row per (event_type, run).
        event_types: list[str] = []
        counts: list[int] = []
        for event_type, per_run in counts_per_run.items():
            event_types.extend([event_type] * n_runs)
            counts.extend(per_run)

        long_table = pl.DataFrame(
            data={"event_type": event_types, "count": counts},
            schema={"event_type": str, "count": int},
        )

        return (
            long_table.group_by("event_type")
            .agg(
                [
                    pl.mean("count").alias("count_avg"),
                    pl.std("count").alias("count_std"),
                ]
            )
            # group_by output order is not guaranteed; sort for stable output.
            .sort("event_type")
        )


class JobPromotion(LifeEvent):
    """The character is promoted at their job from a lower role to a higher role.

    Roles: ``subject`` (the promoted character), ``business`` (their
    employer), ``old_role`` and ``new_role`` (job-role GameObjects).
    """

    base_probability = 0.6  # <-- The probability of the event without considerations

    def __init__(
        self,
        subject: GameObject,
        business: GameObject,
        old_role: GameObject,
        new_role: GameObject,
    ) -> None:
        super().__init__(
            world=subject.world,
            roles=(
                EventRole("subject", subject, True),
                EventRole("business", business),
                EventRole("old_role", old_role),
                EventRole("new_role", new_role),
            ),
        )

    @staticmethod
    @event_consideration
    def relationship_with_owner(event: LifeEvent) -> float:
        """Score by the owner's reputation stat toward the subject.

        Returns the normalized ``reputation`` stat of the relationship from
        the business owner to the subject, or ``-1`` if the business has no
        owner.
        """
        # NOTE(review): -1 looks like a "no opinion" sentinel for the
        # consideration machinery; confirm how negative scores are handled.
        subject = event.roles["subject"]
        business_owner = event.roles["business"].get_component(Business).owner

        if business_owner is not None:
            return get_stat(
                get_relationship(business_owner, subject),
                "reputation",
            ).normalized

        return -1

    @staticmethod
    @event_consideration
    def boldness_consideration(event: LifeEvent) -> float:
        """Score by the subject's normalized boldness stat."""
        return get_stat(event.roles["subject"], "boldness").normalized

    @staticmethod
    @event_consideration
    def reliability_consideration(event: LifeEvent) -> float:
        """Score by the subject's normalized reliability stat."""
        return get_stat(event.roles["subject"], "reliability").normalized

    def execute(self) -> None:
        """Swap the subject's occupation at the business for the new role."""
        character = self.roles["subject"]
        business = self.roles["business"]
        new_job_role = self.roles["new_role"].get_component(JobRole)

        business_data = business.get_component(Business)

        # Remove the old occupation and free the old position.
        character.remove_component(Occupation)
        business_data.remove_employee(character)

        # Add the new occupation and register the character in the new role.
        # NOTE(review): start_date stores the SimDate resource object itself;
        # if that resource is mutated as time advances, start_date would change
        # with it. Confirm whether a copy is needed.
        character.add_component(
            Occupation(
                business=business,
                start_date=self.world.resource_manager.get_resource(SimDate),
                job_role=new_job_role,
            )
        )
        business_data.add_employee(character, new_job_role)

    @classmethod
    def instantiate(cls, subject: GameObject, **kwargs: Any):
        """Build a promotion for ``subject`` if a higher open role exists.

        Returns ``None`` when the subject has no occupation, or when their
        business has no open position with a higher job level whose
        requirements the subject meets. Otherwise returns an event for one
        such position, chosen at random.
        """
        if not subject.has_component(Occupation):
            return None

        occupation = subject.get_component(Occupation)
        current_job_level = occupation.job_role.job_level
        business_data = occupation.business.get_component(Business)

        higher_positions = [
            role
            for role in business_data.get_open_positions()
            if role.job_level > current_job_level and role.check_requirements(subject)
        ]

        if not higher_positions:
            return None

        # Draw from the simulation's shared RNG resource rather than the
        # global ``random`` module.
        rng = subject.world.resource_manager.get_resource(random.Random)
        chosen_role = rng.choice(higher_positions)

        return cls(
            subject=subject,
            business=business_data.gameobject,
            old_role=occupation.job_role.gameobject,
            new_role=chosen_role.gameobject,
        )

    def __str__(self) -> str:
        subject = self.roles["subject"]
        business = self.roles["business"]
        old_role = self.roles["old_role"]
        new_role = self.roles["new_role"]

        return (
            f"{subject.name} was promoted from {old_role.name} to "
            f"{new_role.name} at {business.name}."
        )


class FiredFromJob(LifeEvent):
    """The character is fired from their job.

    Roles: ``subject`` (the fired character), ``business`` (their employer)
    and ``job_role`` (the job-role GameObject they held).
    """

    base_probability = 0.4

    # Amounts subtracted from the base_value of the "reputation" stat when the
    # firing happens. Class attributes so that subclasses can tune them.
    subject_to_owner_reputation_penalty = 20  # subject -> owner relationship
    owner_to_subject_reputation_penalty = 10  # owner -> subject relationship

    def __init__(
        self, subject: GameObject, business: GameObject, job_role: GameObject
    ) -> None:
        super().__init__(
            world=subject.world,
            roles=(
                EventRole("subject", subject, True),
                EventRole("business", business),
                EventRole("job_role", job_role),
            ),
        )

    @staticmethod
    @event_consideration
    def relationship_with_owner(event: LifeEvent) -> float:
        """Score by how poorly the owner regards the subject.

        Returns ``1 - normalized reputation`` of the relationship from the
        business owner to the subject, or ``-1`` if the business has no owner.
        """
        # NOTE(review): -1 looks like a "no opinion" sentinel for the
        # consideration machinery; confirm how negative scores are handled.
        subject = event.roles["subject"]
        business_owner = event.roles["business"].get_component(Business).owner

        if business_owner is not None:
            return (
                1
                - get_stat(
                    get_relationship(business_owner, subject),
                    "reputation",
                ).normalized
            )

        return -1

    @staticmethod
    @event_consideration
    def reliability_consideration(event: LifeEvent) -> float:
        """Score higher the less reliable the subject is."""
        return 1 - get_stat(event.roles["subject"], "reliability").normalized

    def execute(self) -> None:
        """Dispatch LeaveJob for the subject and sour the owner relationship."""
        subject = self.roles["subject"]
        business = self.roles["business"]
        job_role = self.roles["job_role"]

        # Read the owner before dispatching LeaveJob, since that event may
        # change the business's state.
        owner = business.get_component(Business).owner

        # Events can dispatch other events
        LeaveJob(
            subject=subject, business=business, job_role=job_role, reason="fired"
        ).dispatch()

        # Only adjust reputation between two distinct characters.
        if owner is not None and owner != subject:
            get_stat(
                get_relationship(subject, owner), "reputation"
            ).base_value -= self.subject_to_owner_reputation_penalty
            get_stat(
                get_relationship(owner, subject), "reputation"
            ).base_value -= self.owner_to_subject_reputation_penalty

    @classmethod
    def instantiate(cls, subject: GameObject, **kwargs: Any):
        """Build a firing event for the subject's current job.

        Returns ``None`` if the subject has no occupation, or if the subject
        owns the business they work at. An owner cannot be fired from their
        own business.
        """
        if not subject.has_component(Occupation):
            return None

        occupation = subject.get_component(Occupation)
        business = occupation.business

        if business.get_component(Business).owner == subject:
            return None

        return cls(
            subject=subject,
            business=business,
            job_role=occupation.job_role.gameobject,
        )

    def __str__(self) -> str:
        subject = self.roles["subject"]
        business = self.roles["business"]
        job_role = self.roles["job_role"]

        return (
            f"{subject.name} was fired from their role as a "
            f"{job_role.name} at {business.name}."
        )


def sim_factory() -> Simulation:
    """Build and configure a fresh simulation for one batch run.

    The simulation uses the ``basic_settlement`` settlement and has logging
    disabled. It loads the sample plugin and the default event, trait and
    name plugins, and registers the custom ``JobPromotion`` and
    ``FiredFromJob`` life events defined in this module.
    """
    config = SimulationConfig(
        settlement="basic_settlement",
        logging=LoggingConfig(logging_enabled=False),
    )
    simulation = Simulation(config)

    # Plugins are loaded in this exact order.
    plugins = (
        sample_plugin,
        default_events,
        default_traits,
        default_character_names,
        default_settlement_names,
    )
    for plugin in plugins:
        plugin.load_plugin(simulation)

    # Add the custom events to the simulation.
    for event_type in (JobPromotion, FiredFromJob):
        register_life_event_type(simulation, event_type)

    return simulation


# Run a batch of simulations and print the aggregated event counts.
# NOTE(review): the meaning of the positional arguments 20 and 100 (e.g. run
# count / time steps) is defined by BatchRunner; confirm there.
runner = BatchRunner(sim_factory, 20, 100)

count_metric = EventCountsMetric()
runner.add_metric(count_metric)

runner.run()

aggregate_table = count_metric.get_aggregate_data()
for row in aggregate_table.rows(named=True):
    print(row)