In [1]:
from SALib.sample import saltelli
from SALib.analyze import sobol
from SALib.test_functions import Ishigami
import numpy as np
import ineqpy
import matplotlib.pyplot as plt
from os import path
import main as sugar

runstorun = [5, 10] # lower and upper, both inclusive
runsfilename = f"sensitivity_analysis_runs_{runstorun[0]}_{runstorun[1]}.npy"
samplesfilename = "sensitivity_analysis_samples.npy"

In [2]:
def runmodel(params):
    # set parameters
    N = int(params[2])
    size = 99
    vision = int(params[0])
    tax_brackets = [0,0]
    tax_percentages = [0,0]
    inheritance_tax_brackets = [0, 10, 30, 50, 100]
    inheritance_tax_percentages = [0, 0.3, 0.3, 0.35, 0.6]
    starting_wealth = int(params[1])
    steps = 300
    
    parameters = N, size, vision, tax_brackets, tax_percentages, inheritance_tax_brackets, inheritance_tax_percentages, starting_wealth, steps

    # run model
    df = sugar.main(parameters)
    df = df.reset_index()
    
    # extract gini from last timestep
    maxstep = max(df["Step"].tolist())
    wealth = np.array(df.loc[(df.Step == maxstep)]["Wealth"].tolist() )
    gini = ineqpy.gini(np.array(wealth))
    
    return gini

def runmodel(params):
    return np.random.rand()

print(runmodel([2.0, 2.0, 5.0]))

0.532061666754642


In [6]:
# Define the model inputs
problem = {
    'num_vars': 6,
    'names': ['NAgents', 'Vision', 'Total_init_sugar', 'Amsterdam_map', "Death", "Instant_regrowth"],
    'bounds': [[200, 1000], # NAgents
               [1, 10], #Vision
               [1, 10], #Total init sugar
               [0, 1], #Amsterdam map
               [0, 1], #Death
               [0, 1]] #Instant_regrowth
}

# load samples if exist, otherwise create
samples = None
if path.exists(samplesfilename):
    samples = np.load(samplesfilename)
    print("Samples found and loaded")
else:
    # Generate samples
    samples = saltelli.sample(problem, 10, calc_second_order=False) # 1000 samples
    np.save(samplesfilename, samples)
    print("Samples created and saved")

print(len(samples)," samples, first 10:\n",samples[:10])

Samples found and loaded
80  samples, first 10:
 [[3.75781250e+02 1.87011719e+00 5.66699219e+00 6.76757812e-01
  2.80273438e-01 9.07226562e-01]
 [2.36718750e+02 1.87011719e+00 5.66699219e+00 6.76757812e-01
  2.80273438e-01 9.07226562e-01]
 [3.75781250e+02 9.09472656e+00 5.66699219e+00 6.76757812e-01
  2.80273438e-01 9.07226562e-01]
 [3.75781250e+02 1.87011719e+00 5.50878906e+00 6.76757812e-01
  2.80273438e-01 9.07226562e-01]
 [3.75781250e+02 1.87011719e+00 5.66699219e+00 6.93359375e-02
  2.80273438e-01 9.07226562e-01]
 [3.75781250e+02 1.87011719e+00 5.66699219e+00 6.76757812e-01
  8.49609375e-02 9.07226562e-01]
 [3.75781250e+02 1.87011719e+00 5.66699219e+00 6.76757812e-01
  2.80273438e-01 2.54882812e-01]
 [2.36718750e+02 9.09472656e+00 5.50878906e+00 6.93359375e-02
  8.49609375e-02 2.54882812e-01]
 [7.75781250e+02 6.37011719e+00 1.16699219e+00 1.76757812e-01
  7.80273438e-01 4.07226562e-01]
 [6.36718750e+02 6.37011719e+00 1.16699219e+00 1.76757812e-01
  7.80273438e-01 4.07226562e-01]]


In [None]:
# load runs
runs = None
if path.exists(runsfilename):
    runs = np.load(runsfilename)
    print("Runs found and loaded")
else:
    runs = ["NaN"] * len(samples)
    np.save(runsfilename, runs)
    print("Runs file created and saved")

runs_todo = [i for i, x in enumerate(runs) if x == "NaN"]
print("current runs result: ", runs)
print("Runs todo: ", runs_todo)

# Run simulations

In [None]:
print("Current run: ")
for run in runs_todo:
    if run >= runstorun[0] and run <= runstorun[1]:
        print(run, end = " ")
        runs[run] = runmodel(samples[run])
        np.save(runsfilename, runs)

# Analysis

In [None]:
ginis = np.array(runs)
print("\n\n\n")

# Perform analysis
Si = sobol.analyze(problem, ginis, print_to_console=True, calc_second_order=False)

# # Print the first-order sensitivity indices
# print(Si['S1'])

S1 = Si['S1']
S1_conf = Si['S1_conf']

print(S1, "\n", S1_conf)

# S1 = [1, 2, 3]
# S1_conf = [0.5, 0.4, 0.3]

plt.errorbar(np.arange(len(S1)), S1, yerr=S1_conf, fmt='o')
# plt.ylim((0, 10))
# plt.xticks(np.arange(len(S1)), ("var1", "var2", "var3"))
plt.show()