In [None]:
%load_ext autoreload  
%autoreload 2
# so you don't have to restart the kernel whenever utils are updated

#TODO: Stop using Solara - it officially cannot handle the size of the data set (hangs 20 minutes in)

#TODO: Fix move logic - move partnering logic out of base class!

#NOTE: only ~20% of beavers reproduce at age 2 (juvenile) - code this in; also, some may stay with their parents for years


""" Agent """

from mesa.experimental.cell_space import CellAgent

class Beaver(CellAgent):
    """Base beaver agent shared by the Kit, Juvenile and Adult life stages.

    Provides pairing and pair movement (``step``), random movement
    (``move``), litter production (``reproduce``) and life-stage promotion
    (``age_up``). Subclasses override ``step`` to choose which of these apply.
    """

    def __init__(self, model, sex=None, cell=None, age=0):
        """Create a beaver and place it in ``cell``.

        Args:
            model: The owning mesa model.
            sex: ``'M'`` or ``'F'``; drawn from the model RNG when falsy.
            cell: Starting cell. Assigning ``self.cell`` (a CellAgent
                property) already registers the agent in that cell, so
                callers must not add it to ``cell.agents`` themselves.
            age: Starting age in steps (12 steps per year, see ``age_up``).
        """
        super().__init__(model)
        self.sex = sex if sex else model.random.choice(['M', 'F'])
        self.cell = cell
        self.partner = None
        self.age = age
        # consecutive paired steps; subclasses reproduce when it reaches 12
        self.reproduction_timer = 0
        # Removal flag read by the model's cleanup pass. NOTE: this instance
        # attribute shadows mesa's Agent.remove() method, so beaver.remove()
        # is not callable; use CellAgent.remove(beaver) to deregister one.
        self.remove = False

    def step(self):
        """Pair with an available opposite-sex beaver in this cell, then move.

        A mutually bonded pair moves as a unit: only the partner with the
        lower ``unique_id`` moves, carrying the other along, so the pair is
        not moved twice in one step.
        """
        partner = self.partner
        if (
            partner is None
            or getattr(partner, "remove", False)
            or partner.partner is not self
        ):
            # No valid bond (no partner, partner flagged for removal, or
            # partner has re-paired): look for a new mate in this cell.
            self.partner = None
            potential_mates = [
                a
                for a in self.cell.agents
                if isinstance(a, Beaver)
                and not isinstance(a, Kit)  # kits cannot pair
                and not getattr(a, "remove", False)  # dead/aged-up, awaiting cleanup
                and a.sex != self.sex
                and (
                    a.partner is None
                    or getattr(a.partner, "remove", False)
                    or a.partner.partner is not a
                )
            ]
            if potential_mates:
                mate = self.random.choice(potential_mates)
                self.partner = mate
                mate.partner = self

        # move together if paired, else move alone
        if self.partner is not None and self.partner.partner is self:
            if self.unique_id < self.partner.unique_id:
                self.move(together=True)
        else:
            self.move(together=False)

    def move(self, together=False):
        """Move to a random neighbouring cell.

        Args:
            together: When True and a partner is set, the partner is moved to
                the same cell unless it is flagged for removal.
        """
        new_cell = self.cell.neighborhood.select_random_cell()
        self.move_to(new_cell)
        if together and self.partner:
            if not getattr(self.partner, "remove", False):
                self.partner.move_to(new_cell)

    def reproduce(self):
        """Spawn a litter of 1-3 kits in this beaver's cell if it has a partner.

        Each kit is constructed with ``cell=self.cell``, which already places
        it in the cell, so it is only registered in ``model.type[Beaver]``.
        """
        if self.partner and self.cell is not None:
            for _ in range(self.random.randint(1, 3)):  # 1-3 kits, inclusive
                kit = Kit(self.model, cell=self.cell)
                kit.parent = self  # record which beaver produced the litter
                self.model.type[Beaver].append(kit)

    def age_up(self):
        """Return the next life-stage agent if old enough, else ``self``.

        Kit -> Juvenile at 24 steps (age 2); Juvenile -> Adult at 36 steps
        (age 3). The replacement is a brand-new agent placed in this cell
        with the same sex and age; partner and reproduction_timer are not
        carried over. The caller is responsible for retiring ``self``.
        """
        if isinstance(self, Kit) and self.age >= 24:
            return Juvenile(self.model, sex=self.sex, cell=self.cell, age=self.age)
        elif isinstance(self, Juvenile) and self.age >= 36:
            return Adult(self.model, sex=self.sex, cell=self.cell, age=self.age)
        else:
            return self


class Kit(Beaver):
    """Kit life stage: follows a parent, cannot pair or reproduce, ages up."""

    def move(self, together=False):
        """Move to the parent's cell, falling back to an adult in this cell.

        The parent is read from an optional ``parent`` attribute. If it is
        missing, flagged for removal, or no longer on the grid, the kit uses
        the first Adult in its current cell instead. If no adult is found,
        the kit stays put. ``together`` is accepted for signature
        compatibility and is ignored.
        """
        leader = getattr(self, "parent", None)
        if leader is None or getattr(leader, "remove", False) or leader.cell is None:
            adults = [a for a in self.cell.agents if isinstance(a, Adult)]
            leader = adults[0] if adults else None
        if leader is not None and leader.cell is not None:
            self.move_to(leader.cell)
        # TODO: kits whose parents are gone should eventually die rather than
        # tag along with whichever adult happens to share their cell.

    def step(self):
        """Follow the colony, age one step, and become a Juvenile when old enough."""
        self.move()
        self.age += 1

        new_self = self.age_up()
        if new_self is not self:
            # age_up() built a Juvenile that is already placed in this cell
            # (via its constructor), so only register it and retire this kit.
            self.remove = True
            self.model.type[Beaver].append(new_self)


class Juvenile(Beaver):
    """Juvenile life stage: moves, reproduces when paired, ages up to Adult.

    Juveniles do not look for mates themselves. They only become paired when
    an Adult's pairing step selects them. Dam building is not implemented.
    """

    def step(self):
        """Drop stale bonds, move (as a pair if bonded), reproduce, and age."""
        partner = self.partner
        if partner is not None and (
            getattr(partner, "remove", False) or partner.partner is not self
        ):
            # Partner was flagged for removal or has re-paired; the bond is gone.
            self.partner = None

        # Same rule as Beaver.step: the lower-id partner moves both, so a
        # bonded pair stays in one cell instead of drifting apart.
        if self.partner is not None and self.partner.partner is self:
            if self.unique_id < self.partner.unique_id:
                self.move(together=True)
        else:
            self.move(together=False)
        self.age += 1

        # Only the lower-id member of a mutual pair counts down to a litter,
        # so each pair reproduces once per 12 paired steps.
        if (
            self.partner is not None
            and self.partner.partner is self
            and self.unique_id < self.partner.unique_id
        ):
            self.reproduction_timer += 1
            if self.reproduction_timer >= 12:
                self.reproduce()
                self.reproduction_timer = 0
        else:
            self.reproduction_timer = 0

        new_self = self.age_up()
        if new_self is not self:
            # age_up() built an Adult that is already placed in this cell (via
            # its constructor), so only register it and retire this juvenile.
            # It is not stepped now, because the model is iterating a snapshot.
            self.remove = True
            self.model.type[Beaver].append(new_self)



class Adult(Beaver):
    """Adult life stage: pairs, moves, reproduces, and dies of old age.

    Adults do not age up. They are flagged for removal once they reach 84
    steps (7 years at 12 steps per year). Dam building is not implemented.
    """

    def step(self):
        """Age one step, pair/move, count toward a litter, then check for death."""
        self.age += 1
        super().step()  # base Beaver behaviour: pairing and movement

        mate = self.partner
        leads_pair = (
            bool(mate)
            and mate.partner == self
            and self.unique_id < mate.unique_id
        )
        if not leads_pair:
            self.reproduction_timer = 0
        else:
            self.reproduction_timer += 1
            if self.reproduction_timer >= 12:
                self.reproduce()
                self.reproduction_timer = 0

        # TODO: partners don't re-pair when a partner dies - they also don't move! fix
        if self.age < 84:
            return

        # Old age: dissolve a mutual bond so the survivor is free, then retire.
        mate = self.partner
        if mate and mate.partner == self:
            mate.partner = None
        self.partner = None
        self.remove = True





""" Model """

from mesa import Model
from mesa.datacollection import DataCollector
from mesa.experimental.cell_space import OrthogonalVonNeumannGrid
from mesa.experimental.devs import ABMSimulator
import numpy as np
from rasterio import open as rio_open

#from beaver_agent import Beaver  # if these are separate files

class BeaverModel(Model):
    """Beaver population model on a grid sized to a DEM raster.

    Args:
        initial_beavers: Number of Adult beavers placed on random cells.
        seed: RNG seed forwarded to ``mesa.Model``.
        simulator: Optional simulator; when given, it is set up to drive
            this model.
        dem_path: Path to the DEM GeoTIFF; band 1 is read with rasterio.
    """

    def __init__(self, initial_beavers=50, seed=None, simulator=None,
                 dem_path="/Users/r34093ls/Documents/test_flood/clipped_dtm.tif"):
        super().__init__(seed=seed)

        with rio_open(dem_path) as dem:  # 50m resolution
            self.dem = dem.read(1)  # band 1 as a 2-D array

        # numpy shape is (rows, cols), i.e. (height, width)
        self.height, self.width = self.dem.shape

        # Grid dimensions are [columns, rows]: coordinate[0] indexes DEM
        # columns and coordinate[1] indexes DEM rows.
        self.grid = OrthogonalVonNeumannGrid(
            [self.width, self.height],
            torus=True,
            capacity=float("inf"),
            random=self.random,
        )

        # Registry of live beavers. It is a list rather than a set so that
        # iteration order, and therefore seeded runs, are reproducible.
        self.type = {Beaver: []}

        for _ in range(initial_beavers):
            cell = self.random.choice(self.grid.all_cells.cells)
            # Passing cell= places the agent in the cell (CellAgent's cell
            # setter); appending to cell.agents again would double-add it.
            self.type[Beaver].append(Adult(model=self, cell=cell))

        self.datacollector = DataCollector({
            "Beavers": lambda m: len(m.type[Beaver]),
            # Number of pairs. Each mutual bond with a live partner is counted
            # once, via its lower-id member.
            "Paired Beavers": lambda m: sum(
                1
                for a in m.type[Beaver]
                if a.partner is not None
                and a.partner.partner is a
                and not getattr(a.partner, "remove", False)
                and a.unique_id < a.partner.unique_id
            ),
            "Males": lambda m: len([a for a in m.type[Beaver] if a.sex == "M"]),
            "Females": lambda m: len([a for a in m.type[Beaver] if a.sex == "F"]),
            "Kits": lambda m: len([a for a in m.type[Beaver] if isinstance(a, Kit)]),
            "Juveniles": lambda m: len([a for a in m.type[Beaver] if isinstance(a, Juvenile)]),
            "Adults": lambda m: len([a for a in m.type[Beaver] if isinstance(a, Adult)]),
        })
        self.datacollector.collect(self)

        if simulator is not None:
            self.simulator = simulator
            self.simulator.setup(self)

        self.running = True

    def step(self):
        """Step every beaver, purge those flagged for removal, then record data."""
        # Iterate a snapshot: agents created during this step (kits, aged-up
        # replacements) take their first step next round.
        for agent in list(self.type[Beaver]):
            agent.step()

        survivors = []
        for agent in self.type[Beaver]:
            if getattr(agent, "remove", False):
                # Beaver instances carry a boolean `remove` attribute that
                # shadows Agent.remove(), so call the class method explicitly.
                # It deregisters the agent from the model (so it can be freed)
                # and clears its cell membership.
                CellAgent.remove(agent)
            else:
                survivors.append(agent)
        # Slice-assign so that any holder of this list sees the purge, in O(n)
        # rather than one list.remove per dead agent.
        self.type[Beaver][:] = survivors

        self.datacollector.collect(self)  # collect data on each step



""" App """

from mesa.experimental.devs import ABMSimulator
from mesa.visualization import (
    Slider,
    SolaraViz,
    make_plot_component,
    make_space_component,
)
import matplotlib.pyplot as plt

#from beaver_model import BeaverModel  # your adapted model
#from beaver_agent import Beaver  # your Beaver agent class


def beaver_portrayal(agent):
    """Return the matplotlib portrayal dict for one agent on the space plot.

    Colour encodes life stage only: Kit green, Juvenile orange, Adult brown,
    anything else gray. This matches the line-plot series colours. Sex and
    pairing are shown in the line plot. Agents without a cell yield None.
    """
    if not getattr(agent, "cell", None):
        return None  # skip agents with no cell

    if isinstance(agent, Kit):
        color = "green"
    elif isinstance(agent, Juvenile):
        color = "orange"
    elif isinstance(agent, Adult):
        color = "brown"
    else:
        color = "gray"

    return {
        "size": 25,
        "marker": "o",
        "zorder": 2,  # draw agents above the DEM backdrop
        "color": color,
    }

dem_img = None  # most recent DEM image artist drawn by post_process_space


def post_process_space(ax):
    """Overlay the model DEM on the space plot and tidy the axes.

    Mesa's matplotlib space component creates a new Figure/Axes on every
    render. The DEM therefore has to be drawn on each call; caching the
    image across calls would leave every frame after the first without the
    backdrop.
    """
    global dem_img
    # Reads the module-level `model` created further down in this notebook.
    dem_img = ax.imshow(model.dem, cmap='viridis', alpha=0.5, origin="upper")
    ax.set_aspect("equal")
    ax.set_xticks([])
    ax.set_yticks([])

def post_process_lines(ax):
    """Place the line-plot legend just outside the axes, on the right."""
    anchor = (1, 0.9)
    ax.legend(bbox_to_anchor=anchor, loc="center left")

# Sidebar controls for SolaraViz; keys match BeaverModel.__init__ keyword arguments.
# Slider args are presumably (label, default, min, max) - confirm against the mesa version in use.
model_params = {
    "seed": {"type": "InputText", "value": 42,"label": "Random Seed" },
    "initial_beavers": Slider("Initial Beaver Population", 50, 10, 200),
}

# Grid view: agents drawn by beaver_portrayal; post_process_space adds the DEM backdrop.
space_component = make_space_component(
    beaver_portrayal, draw_grid=False, post_process=post_process_space
)

# Population time series; keys must match the DataCollector reporter names in BeaverModel.
lineplot_component = make_plot_component(
    {
        "Beavers": "tab:gray",
        "Males": "blue",
        "Females": "red",
        "Paired Beavers": "purple",
        "Kits": "green",
        "Juveniles": "orange",
        "Adults": "brown",
    },
    post_process=post_process_lines,
)

# One simulator instance is shared by the model and the SolaraViz page.
simulator = ABMSimulator()
# Module-level `model` is also read by post_process_space for the DEM overlay.
model = BeaverModel(simulator=simulator)

page = SolaraViz(
    model,  
    components=[space_component, lineplot_component],
    model_params=model_params,
    name="Beaver Simulation",
    simulator=simulator,
)

# Last expression of the cell: Jupyter renders the app from it.
page  # noqa

  from mesa.experimental.cell_space import CellAgent
