In [15]:
import mesa
import numpy as np
from mesa.visualization import SolaraViz,make_plot_measure
import solara
from matplotlib.figure import Figure

In [29]:
class WealthAgent(mesa.Agent):
    """Agent that earns a proportional payday, pays a random partner, and can
    boost its payday rate through innovation.

    Instance attributes (set in ``__init__``):
        wealth: current wealth; starts at 10.
        W: payday rate; each step wealth grows by ``W * wealth``.
        I: current innovation multiplier applied to ``W``.
        ioriginal: initial innovation multiplier, restored when innovation lapses.
        decay: amount subtracted from ``I`` at the next innovation; starts at 0
            and grows by 0.01 after every innovation.
    """

    # How many partners an agent may draw per step before giving up on trading.
    MAX_EXCHANGE_ATTEMPTS = 5
    # Fraction of wealth consumed for survival right after each payday.
    SURVIVAL_COST = 0.1

    def __init__(self, model, proportion, innovation):
        super().__init__(model)
        self.wealth = 10
        self.W = proportion
        self.I = innovation
        self.ioriginal = innovation
        self.decay = 0

    def exchange(self):
        """Return an agent drawn uniformly at random from the model's agents
        (this may be the calling agent itself)."""
        return self.random.choice(self.model.agents)

    def step(self):
        """Run one tick: payday, a bounded trade attempt, then innovation.

        The model is always reached through ``self.model`` (never a
        notebook-global ``model``), so this works for any model instance,
        including ones created by SolaraViz.
        """
        self._payday()
        self._trade()
        if self.model.innovation:
            self._innovate()

    def _payday(self):
        """Grow wealth by the payday rate ``W``, then pay the survival cost."""
        self.wealth += self.W * self.wealth
        self.wealth -= self.wealth * self.SURVIVAL_COST

    def _trade(self):
        """Transfer ``partner.W * wealth`` to a random partner, if affordable.

        A partner is affordable when that amount is below this agent's wealth.
        Up to ``MAX_EXCHANGE_ATTEMPTS`` partners are drawn. Drawing oneself
        while affordable ends the attempt without a transfer. Retrying is a
        loop rather than re-entering ``step()``, so payday and innovation run
        exactly once per tick and recursion is bounded.
        """
        for _ in range(self.MAX_EXCHANGE_ATTEMPTS):
            partner = self.exchange()
            if partner is None:
                return
            amount = partner.W * self.wealth
            if self.wealth > amount:
                if partner is not self:
                    partner.wealth += amount
                    self.wealth -= amount
                return
        print(f"poor agent {self.wealth}")

    def _innovate(self):
        """Multiply ``W`` by ``I`` while the agent is rich enough; otherwise reset.

        "Rich enough" means wealth above ``model.threshold`` times the model's
        cached ``total`` wealth, with the multiplier ``I`` still above 1.0.
        Each innovation lowers ``I`` by the current decay, and the decay then
        grows by 0.01, so repeated innovation yields diminishing returns.
        """
        model = self.model
        if self.wealth > model.total * model.threshold and self.I > 1.0:
            self.W *= self.I
            self.I -= self.decay  # decay is 0 on the first innovation
            self.decay += 0.01
        else:
            self.decay = 0
            self.I = self.ioriginal
        
       

In [30]:


@solara.component
def Histogram(model):
    """Solara component showing a 10-bin histogram of current agent wealth.

    Args:
        model: model whose ``agents`` each expose a ``wealth`` attribute.

    Returns:
        A ``solara.FigureMatplotlib`` element wrapping the histogram figure.
    """
    # Build the figure with matplotlib's object-oriented API (Figure / Axes)
    # instead of plt.figure()/plt.hist: pyplot's global state is not thread-safe.
    fig = Figure()
    ax = fig.subplots()
    wealth_vals = [agent.wealth for agent in model.agents]
    ax.hist(wealth_vals, bins=10)
    # Label the chart so it stands on its own when the dashboard is skimmed.
    ax.set(title="Wealth distribution", xlabel="Wealth", ylabel="Number of agents")
    return solara.FigureMatplotlib(fig)

def compute_gini(model):
    """Gini coefficient of the (absolute) wealth of every agent in ``model``.

    Uses the sorted-sum formula ``G = 1 + 1/N - 2 * B`` with
    ``B = sum((N - i) * x_i) / (N * sum(x))`` over wealth sorted ascending.
    ``N`` is the number of wealth values actually summed, not a separately
    stored population count.

    Returns:
        float: 0.0 when there are no agents or total wealth is zero, because
        inequality is undefined there and the formula would divide by zero.
    """
    wealths = sorted(abs(float(agent.wealth)) for agent in model.agents)
    n = len(wealths)
    total = sum(wealths)
    if n == 0 or total == 0:
        return 0.0
    b = sum(w * (n - i) for i, w in enumerate(wealths)) / (n * total)
    return 1 + (1 / n) - 2 * b

def total_wealth(model):
    """Sum of every agent's wealth in ``model``, each value coerced to ``float``."""
    return sum(map(float, (agent.wealth for agent in model.agents)))

In [31]:
class WealthModel(mesa.Model):
    """Wealth-exchange model with optional innovation and Robin Hood taxation.

    Parameters
    ----------
    population : int
        Number of ``WealthAgent`` instances created at start-up.
    threshold : float
        Innovation threshold, as a fraction of total model wealth, used by the
        agents' innovation rule.
    tax : float, default 0.0
        Tax rate. It also sets how many of the richest agents are taxed each
        step (``int(tax * population)``, capped at half the population).
    debt : bool, default False
        Stored on the model; not read by any code in this class.
    innovation : bool, default False
        Whether agents apply their innovation rule.
    """

    def __init__(self, population, threshold, tax=0.0, debt=False, innovation=False):
        super().__init__()
        self.population = population
        self.threshold = threshold
        self.tax = tax
        # Number of richest agents taxed per step by the Robin Hood scheme.
        self.tax_dynamic = int(tax * population)
        # log2-based histogram bin count used by the alternative flat-tax scheme.
        self.tax_dynamic2 = int(np.log2(population) + 1)
        self.debt = debt
        self.innovation = innovation
        # Every agent starts with 10 units of wealth.
        self.total = self.population * 10

        self.datacollector = mesa.DataCollector(
            model_reporters={"Gini": compute_gini, "Total": total_wealth},
            agent_reporters={"Wealth": "wealth", "Innovation": "I", "Pay": "W"},
        )

        # Per-agent payday rates and innovation multipliers, rounded to two
        # decimals. NOTE(review): these draw from NumPy's global RNG, so they
        # are not controlled by the Mesa model's own random state.
        payday_array = np.random.normal(loc=0.2, scale=0.03, size=self.population)
        innovation_array = np.random.normal(loc=1.05, scale=0.01, size=self.population)
        payday_array = np.around(payday_array, decimals=2)
        innovation_array = np.around(innovation_array, decimals=2)

        for payday, innov in zip(payday_array, innovation_array):
            WealthAgent(self, float(payday), float(innov))

    def step(self):
        """Advance the model by one tick.

        Records data for the current state, activates every agent once in
        random order, refreshes the cached total wealth, then applies the
        Robin Hood tax when ``tax`` is positive.
        """
        self.datacollector.collect(self)
        self.agents.shuffle_do("step")
        self.total = total_wealth(self)

        # Tax Model 1 - Robin Hood (the flat-tax alternative is _apply_flat_tax).
        if self.tax > 0.0:
            self._apply_robin_hood_tax()

    def _apply_robin_hood_tax(self):
        """Move ``tax`` of each of the richest agents' wealth to a poor agent.

        The i-th richest agent pays the i-th poorest agent. The number of
        pairs is ``tax_dynamic`` capped at half the population, so no agent is
        both taxed and paid in the same step and the richest never pays itself.
        """
        ranked = sorted(self.agents, key=lambda agent: agent.wealth, reverse=True)
        n_pairs = min(self.tax_dynamic, len(ranked) // 2)
        for i in range(n_pairs):
            rich, poor = ranked[i], ranked[-1 - i]
            tax_amount = rich.wealth * self.tax
            rich.wealth -= tax_amount
            poor.wealth += tax_amount

    def _apply_flat_tax(self):
        """Alternative scheme, not called by ``step``: Tax Model 2 - Flat Tax.

        Every agent pays ``tax`` of its wealth into a pool. The pool is split
        equally among the agents in the lowest bin of a histogram of
        post-tax wealth.
        """
        # Sort agents from poorest to richest.
        sorted_agents = sorted(self.agents, key=lambda agent: agent.wealth)
        if not sorted_agents:
            return
        taxes = 0
        for agent in sorted_agents:
            tax_amount = agent.wealth * self.tax
            taxes += tax_amount
            agent.wealth -= tax_amount
        # Determine histogram bins from the wealth distribution.
        counts, _ = np.histogram([agent.wealth for agent in self.agents], bins=self.tax_dynamic2)
        redistro = taxes / counts[0]
        print(taxes, redistro, counts)
        for agent in sorted_agents[:counts[0]]:
            agent.wealth += redistro

In [33]:
# Run the innovation-enabled, untaxed simulation and print total wealth per
# step, then export the collected agent-level and model-level data to CSV.
N_STEPS = 100
model = WealthModel(population=200, threshold=0.02, tax=0.0, innovation=True)

for step in range(N_STEPS):
    model.step()
    print(step, model.total)

collector = model.datacollector
# Per-agent, per-step records: Wealth, Innovation, Pay.
output = collector.get_agent_vars_dataframe()
# Model-level series: Gini, Total.
output2 = collector.get_model_vars_dataframe()
for frame, csv_path in ((output, "inequality_output.csv"), (output2, "model_output.csv")):
    frame.to_csv(csv_path)

0 2183.032639219006
1 2386.435156154273
2 2614.3686463925287
3 2865.486931166055
4 3140.908125783302
5 3441.6920242366873
6 3775.4699383477537
7 4145.468064604806
8 4557.050322639306
9 5010.380405014903
10 5504.99579765829
11 6054.626952541859
12 6644.646722392667
13 7308.9002362454685
14 8037.0487396661065
15 8830.151024092564
16 9697.195725299953
17 10655.583894941028
18 11709.801224129951
19 12874.985814447215
20 14164.630582468906
21 15582.385245723257
22 17153.250325271354
23 18888.521720120767
24 20786.887486013045
25 22877.42409791148
26 25166.564713299853
27 27671.42650206756
28 30378.695662515434
29 33370.56903500783
30 36714.11527480993
31 40408.0965089146
32 44472.73165259853
33 48958.66053028745
34 53914.834167765526
35 59412.0829665517
36 65408.284555728605
37 72094.96363028463
38 79378.80196283014
39 87279.64782071317
40 95950.75406593429
41 105438.28107002676
42 115893.08358974004
43 127261.38294078603
44 139633.87805575517
45 153435.06680458615
46 168528.68394414955
47 

In [None]:
# Interactive dashboard. The initial model uses the same values as the
# slider defaults below, so the first render matches the controls.
model = WealthModel(population=50, threshold=0.1)

model_params = {
    "population": {
        "type": "SliderInt",
        "value": 50,
        "label": "Number of agents:",
        "min": 10,
        "max": 200,
        "step": 10,
    },
    "tax": {
        "type": "SliderFloat",
        "value": 0.0,
        "label": "Tax rate:",
        "min": 0.0,
        "max": 1.0,
        "step": 0.05,
    },
    "threshold": {
        "type": "SliderFloat",
        "value": 0.1,
        "label": "Innovation Threshold:",
        "min": 0,
        "max": 1,
        "step": 0.1,
    },
    # SolaraViz has no "Button" input type; a Checkbox yields the boolean
    # that WealthModel's ``innovation`` parameter expects.
    "innovation": {
        "type": "Checkbox",
        "value": False,
        "label": "Innovation On",
    },
}

# Keep these names distinct from the reporter functions: assigning to
# ``total_wealth`` here would shadow the function that WealthModel passes to
# its DataCollector and calls in step(), breaking every model SolaraViz rebuilds.
wealth_plot = make_plot_measure("Gini")
total_plot = make_plot_measure("Total")

dash = SolaraViz(
    model,
    components=[wealth_plot, total_plot],
    model_params=model_params,
)

dash