In [None]:
from mesa import Model
from mesa.space import MultiGrid
from mesa.time import RandomActivationByType
from mesa.datacollection import DataCollector
from numpy import random
from mesa import Agent
from numpy import random
from mesa.visualization import CanvasGrid, ModularServer
import numpy as np
from mesa.time import RandomActivation
from mesa.visualization.modules import CanvasGrid, ChartModule
from mesa.visualization.ModularVisualization import ModularServer
import matplotlib.pyplot as plt

In [None]:
import solara
from matplotlib.figure import Figure

In [None]:


def get_distance(pos1, pos2):
    """Return the squared Euclidean distance between two (x, y) positions.

    No square root is taken, so the value is only meaningful when compared
    against other results of this function.
    """
    dx = pos1[0] - pos2[0]
    dy = pos1[1] - pos2[1]
    return dx ** 2 + dy ** 2

class Cell(Agent):
    """Grid patch whose sugar and spice regrow by one unit per step.

    ``capacities`` is an indexable pair: element 0 is the sugar capacity and
    element 1 is the spice capacity. The patch starts out full.
    """

    def __init__(self, unique_id, model, capacities):
        super().__init__(unique_id, model)
        self.capacities = capacities
        self.sugar, self.spice = capacities[0], capacities[1]

    def step(self):
        """Scheduler hook: regrow resources."""
        self.regenerate()

    def regenerate(self):
        """Add one unit of each resource without exceeding its capacity."""
        sugar_cap, spice_cap = self.capacities[0], self.capacities[1]
        self.sugar = min(self.sugar + 1, sugar_cap)
        self.spice = min(self.spice + 1, spice_cap)

class Trader(Agent):
    """Mobile agent that harvests, barters, and metabolizes sugar and spice.

    On every step the trader moves to the most attractive visible cell,
    collects everything on it, trades with adjacent traders, and then pays
    its metabolic cost. It is removed from the grid and the schedule as soon
    as either stock drops below zero.
    """

    def __init__(self, unique_id, model, sugar, sugar_metabolism, spice, spice_metabolism, vision):
        """Store the trader's endowment, metabolic rates, and vision radius.

        Args:
            unique_id: Identifier forwarded to ``Agent``.
            model: Owning model; this class uses its ``grid``, ``schedule``
                and ``datacollector`` attributes.
            sugar: Initial sugar stock.
            sugar_metabolism: Sugar consumed per step.
            spice: Initial spice stock.
            spice_metabolism: Spice consumed per step.
            vision: Neighbourhood radius searched in ``move``.
        """
        super().__init__(unique_id, model)
        self.sugar = sugar
        self.sugar_metabolism = sugar_metabolism
        self.spice = spice
        self.spice_metabolism = spice_metabolism
        self.vision = vision
        # With these weights, sugar_weight * sugar + spice_weight * spice is
        # proportional to sugar / sugar_metabolism + spice / spice_metabolism,
        # i.e. each resource is valued by how many steps of metabolism it covers.
        self.spice_weight = sugar_metabolism / (sugar_metabolism + spice_metabolism)
        self.sugar_weight = 1 - self.spice_weight

    def step(self):
        """Run one full turn: move, harvest, trade, then metabolize."""
        self.move()
        self.pick_up()
        self.trade()
        self.metabolize()

    def move(self):
        """Move to the visible cell with the highest weighted sugar + spice.

        Ties on the weighted total are broken by the smallest squared
        distance from the current position; remaining ties are broken at
        random. The trader's own cell is not a candidate. If no candidate
        cell is found the trader stays where it is.
        """
        candidates = self.model.grid.get_neighborhood(
            self.pos, moore=True, include_center=False, radius=self.vision)

        best_total = -1
        best_distance = float("inf")
        best_positions = []
        for position in candidates:
            for occupant in self.model.grid.get_cell_list_contents([position]):
                if not isinstance(occupant, Cell):
                    continue

                weighted_sugar = self.sugar_weight * occupant.sugar
                weighted_spice = self.spice_weight * occupant.spice
                total = weighted_sugar + weighted_spice
                distance = get_distance(self.pos, position)

                if total > best_total:
                    # Strictly better cell: it becomes the only candidate.
                    best_total = total
                    best_distance = distance
                    best_positions = [position]
                elif total == best_total:
                    if distance < best_distance:
                        best_distance = distance
                        best_positions = [position]
                    elif distance == best_distance:
                        best_positions.append(position)

        if not best_positions:
            # Nothing visible (e.g. a 1x1 grid); choosing from an empty
            # range would raise, so stay put.
            return

        chosen_index = random.choice(range(len(best_positions)))
        self.model.grid.move_agent(self, best_positions[chosen_index])

    def pick_up(self):
        """Take all sugar and spice from the cell the trader stands on."""
        for occupant in self.model.grid.get_cell_list_contents([self.pos]):
            if isinstance(occupant, Cell):
                self.sugar += occupant.sugar
                occupant.sugar = 0

                self.spice += occupant.spice
                occupant.spice = 0

    def metabolize(self):
        """Consume one step's worth of sugar and spice; die if either runs out.

        A trader whose sugar or spice falls below zero is removed from the
        grid and the schedule exactly once, even when both stocks are
        negative at the same time.
        """
        self.sugar -= self.sugar_metabolism
        self.spice -= self.spice_metabolism

        if self.sugar < 0 or self.spice < 0:
            self.model.grid.remove_agent(self)
            self.model.schedule.remove(self)

    def trade(self):
        """Barter with every adjacent trader, visiting them in random order.

        Only the four von Neumann neighbours at radius 1 are considered.
        """
        neighbors = self.model.grid.get_neighbors(self.pos, moore=False, include_center=False, radius=1)
        random.shuffle(neighbors)
        for neighbor in neighbors:
            if isinstance(neighbor, Trader):
                self._trade_with(neighbor)

    def _trade_with(self, partner):
        """Exchange sugar for spice with ``partner`` until no trade is accepted.

        Each accepted exchange is appended to the model's "Trades" table.
        """
        while True:
            my_mrs = self.get_mrs_sugar_spice()
            their_mrs = partner.get_mrs_sugar_spice()

            if my_mrs == their_mrs:
                return

            # The side with the higher MRS receives sugar and pays spice.
            if my_mrs > their_mrs:
                trader_high_mrs, trader_low_mrs = self, partner
            else:
                trader_high_mrs, trader_low_mrs = partner, self

            # Price is the geometric mean of the two MRS values.
            trade_price = np.sqrt(my_mrs * their_mrs)
            if not trade_price > 0:
                # A zero (or NaN) price means one side has no spice; the
                # sugar quantity 1 / price would be undefined, so stop here.
                return

            if trade_price > 1:
                trade_spice = trade_price
                trade_sugar = 1
            else:
                trade_spice = 1
                trade_sugar = 1 / trade_price

            # Nobody can hand over more than they currently hold.
            trade_sugar = min(trade_sugar, trader_low_mrs.sugar)
            trade_spice = min(trade_spice, trader_high_mrs.spice)

            if trade_sugar <= 0 or trade_spice <= 0:
                return

            if not self.improve_welfare(trader_high_mrs, trader_low_mrs, trade_sugar, trade_spice):
                return

            trader_high_mrs.spice -= trade_spice
            trader_high_mrs.sugar += trade_sugar
            trader_low_mrs.spice += trade_spice
            trader_low_mrs.sugar -= trade_sugar

            self.model.datacollector.add_table_row("Trades", {
                'Step': self.model.schedule.steps,
                'TraderHighMRS_ID': trader_high_mrs.unique_id,
                'TraderLowMRS_ID': trader_low_mrs.unique_id,
                'TradeSugar': trade_sugar,
                'TradeSpice': trade_spice,
                'TradePrice': trade_price
            })

    def get_mrs_sugar_spice(self):
        """Return this trader's marginal rate of substitution.

        Computed as ``sugar_metabolism * spice / (spice_metabolism * sugar)``;
        the small constant in the denominator avoids division by zero when
        the trader holds no sugar.
        """
        return (self.sugar_metabolism * self.spice) / (self.spice_metabolism * self.sugar + 1e-9)

    def improve_welfare(self, trader_high_mrs, trader_low_mrs, trade_sugar, trade_spice):
        """Decide whether a proposed exchange should go ahead.

        Args:
            trader_high_mrs: Trader who would receive sugar and pay spice.
            trader_low_mrs: Trader who would receive spice and pay sugar.
            trade_sugar: Amount of sugar to transfer.
            trade_spice: Amount of spice to transfer.

        Returns:
            True when the exchange lowers the high-MRS trader's MRS, raises
            the low-MRS trader's MRS, and the high-MRS trader's MRS is still
            above the other's afterwards (the two values do not cross).
        """
        high_mrs_after_trade = (trader_high_mrs.sugar_metabolism * (trader_high_mrs.spice - trade_spice)) / (trader_high_mrs.spice_metabolism * (trader_high_mrs.sugar + trade_sugar + 1e-9))
        low_mrs_after_trade = (trader_low_mrs.sugar_metabolism * (trader_low_mrs.spice + trade_spice)) / (trader_low_mrs.spice_metabolism * (trader_low_mrs.sugar - trade_sugar + 1e-9))
        improves_welfare = high_mrs_after_trade < trader_high_mrs.get_mrs_sugar_spice() and low_mrs_after_trade > trader_low_mrs.get_mrs_sugar_spice()
        mrs_no_crossing = high_mrs_after_trade > low_mrs_after_trade
        return improves_welfare and mrs_no_crossing


In [None]:


class SugarScape(Model):
    """Sugarscape model: resource cells on a grid plus traders that harvest and barter.

    Args:
        height: Grid extent along y.
        width: Grid extent along x.
        initial_population: Number of traders created at start-up.
    """

    def __init__(self, height=50, width=50, initial_population=100):
        super().__init__()
        self.height = height
        self.width = width
        self.current_step = 0
        self.initial_population = initial_population

        self.schedule = RandomActivation(self)
        # MultiGrid's signature is (width, height, torus); passing them in
        # that order keeps non-square grids consistent with the x/y ranges
        # used for trader placement below.
        self.grid = MultiGrid(self.width, self.height, torus=False)

        self.datacollector = DataCollector(
            model_reporters={
                "Number of Trades": compute_trade_counts,
                "Trade Price": compute_average_trade_price,
                "Gini": compute_gini,
            },
            tables={"Trades": ["Step", "TraderHighMRS_ID", "TraderLowMRS_ID", "TradeSugar", "TradeSpice", "TradePrice"]}
        )

        # One Cell per grid position; ids keep counting up into the traders
        # so every scheduled agent has a distinct unique_id.
        agent_id = 0
        for _content, (x, y) in self.grid.coord_iter():
            # numpy's randint excludes the upper bound: capacities are 1..9.
            capacities = random.randint(1, 10, 2)
            cell = Cell(agent_id, self, capacities)

            self.grid.place_agent(cell, (x, y))
            self.schedule.add(cell)

            agent_id += 1

        # Traders start at uniformly random positions; several may share a cell.
        for _ in range(self.initial_population):
            x = random.randint(0, self.width)
            y = random.randint(0, self.height)

            sugar, spice = random.randint(1, 10, 2)
            sugar_metabolism, spice_metabolism = random.randint(1, 4, 2)
            vision = random.randint(1, 4)
            trader = Trader(agent_id, self, sugar, sugar_metabolism, spice, spice_metabolism, vision)

            self.grid.place_agent(trader, (x, y))
            self.schedule.add(trader)

            agent_id += 1

        self.running = True
        self.datacollector.collect(self)

    def step(self):
        """Advance every agent once, record data, and update ``running``."""
        self.schedule.step()
        self.datacollector.collect(self)
        # Cells are scheduled too and never removed, so a plain agent count
        # can never reach zero; the run is over once no Trader is left.
        self.running = any(isinstance(agent, Trader) for agent in self.schedule.agents)
        self.current_step += 1

    def run_model(self, step_count=200):
        """Call ``step`` ``step_count`` times."""
        for _ in range(step_count):
            self.step()

    def get_trade_log(self):
        """Return the "Trades" table of the data collector as a DataFrame."""
        return self.datacollector.get_table_dataframe("Trades")

def compute_trade_counts(model):
    """Return how many rows the model's trade log currently holds."""
    return len(model.get_trade_log())

def compute_average_trade_price(model):
    """Return the mean ``TradePrice`` of trades logged for ``model.current_step``.

    Returns 0 when the trade log is empty or holds no row whose ``Step``
    equals the model's current step.
    """
    trades = model.get_trade_log()
    if len(trades) > 0:
        this_step = trades[trades["Step"] == model.current_step]
        if len(this_step) > 0:
            return this_step["TradePrice"].mean()
    return 0

def compute_gini(model):
    """Return the Gini coefficient of trader wealth.

    A trader's wealth is ``sugar / sugar_metabolism + spice / spice_metabolism``.
    Only ``Trader`` agents in ``model.schedule.agents`` are counted.

    Returns:
        The Gini coefficient, or 0 when there are no traders or when their
        combined wealth is zero (the ratio would otherwise be undefined).
    """
    wealths = sorted(
        agent.sugar / agent.sugar_metabolism + agent.spice / agent.spice_metabolism
        for agent in model.schedule.agents
        if isinstance(agent, Trader)
    )
    n = len(wealths)
    if n == 0:
        return 0

    total_wealth = sum(wealths)
    if total_wealth == 0:
        # Everyone holds nothing: treat as perfect equality instead of
        # dividing by zero.
        return 0

    # Rank-weighted sum over wealth sorted ascending, with ranks starting at 1.
    rank_weighted_sum = sum(rank * wealth for rank, wealth in enumerate(wealths, start=1))
    return (2 * rank_weighted_sum) / (n * total_wealth) - (n + 1) / n

In [None]:
import signal

In [None]:
# def make_histogram(model):
#     # Note: you must initialize a figure using this method instead of
#     # plt.figure(), for thread safety purpose
#     fig = Figure()
#     ax = fig.subplots()
#     wealth_vals = compute_gini(model)
#     # Note: you have to use Matplotlib's OOP API instead of plt.hist
#     # because plt.hist is not thread-safe.
#     ax.hist(wealth_vals, bins=10)
#     solara.FigureMatplotlib(fig)


In [None]:


def agent_portrayal(agent):
    """Build the portrayal dict for one agent on the canvas grid.

    Traders are drawn as red circles on layer 1. Cells are drawn as
    rectangles on layer 0, green when they hold both sugar and spice and
    black otherwise. ``None`` yields ``None``; any other object gets only the
    shared base entries.
    """
    if agent is None:
        return None

    portrayal = dict(Filled="true", r=0.5, w=1, h=1)
    agent_type = type(agent)

    if agent_type is Trader:
        portrayal.update(Color="red", Layer=1, Shape="circle")
    elif agent_type is Cell:
        has_both_resources = agent.sugar > 0 and agent.spice > 0
        portrayal.update(
            Shape="rect",
            Color="green" if has_both_resources else "black",
            Layer=0,
        )

    return portrayal
# Grid dimensions are shared by the canvas and the model parameters so the
# drawing surface and the simulated grid cannot drift apart.
GRID_WIDTH = 50
GRID_HEIGHT = 50
CANVAS_PIXELS = 500
SERVER_PORT = 8694

canvas_element = CanvasGrid(agent_portrayal, GRID_WIDTH, GRID_HEIGHT, CANVAS_PIXELS, CANVAS_PIXELS)

# Each chart label is looked up among the model reporters of the model
# attribute named by data_collector_name.
trade_count_chart = ChartModule(
    [{"Label": "Number of Trades", "Color": "Blue"}],
    data_collector_name='datacollector'
)

average_trade_price_chart = ChartModule(
    [{"Label": "Trade Price", "Color": "Red"}],
    data_collector_name='datacollector'
)

gini_pop = ChartModule(
    [{"Label": "Gini", "Color": "Black"}],
    data_collector_name='datacollector'
)

server = ModularServer(
    SugarScape,
    [canvas_element, trade_count_chart, average_trade_price_chart, gini_pop],
    "Sugarscape Model",
    {"height": GRID_HEIGHT, "width": GRID_WIDTH, "initial_population": 100}
)

server.port = SERVER_PORT
server.launch()



In [None]:
# Smoke test: build a model with the default parameters and advance it one step.
smoke_test_model = SugarScape()
smoke_test_model.step()
