In [5]:
from SALib.sample import saltelli
from SALib.analyze import sobol
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from Environment import SIERDModel

# Define the Sobol problem: the four model parameters under study and the
# uniform [low, high] range each one is sampled from.
problem = {
    'num_vars': 4,
    'names': ['transmission_rate', 'latency_period', 'infection_duration', 'recovery_rate'],
    'bounds': [[0.1, 1.0],    # transmission_rate range
               [10, 70],     # latency_period range
               # NOTE(review): Saltelli samples are continuous, so latency_period
               # and infection_duration will be non-integer; confirm SIERDModel
               # accepts float periods (or round them before use).
               [10, 70],     # infection_duration range
               [0.1, 1.0]]   # recovery_rate range
}

# Base sample size N. With second-order indices (SALib's default),
# Saltelli's scheme produces N * (2 * num_vars + 2) parameter sets, i.e.
# 128 * 10 = 1280 model runs here. N must be a power of two: SALib warns
# otherwise, because the convergence properties of the underlying Sobol'
# sequence only hold for N = 2**k. (The previous value, 100, was not.)
n_base_samples = 2 ** 7

# Generate parameter samples: one row per model run, one column per parameter.
param_values = saltelli.sample(problem, n_base_samples)

# Define the simulation function
def run_sobol_simulation(param_values, width, height, density, policy, num_districts, initial_infected, steps, *, model_cls=None):
    """Run one simulation per sampled parameter row and collect its model data.

    Args:
        param_values: Iterable of parameter rows. Each row must unpack to exactly
            (transmission_rate, latency_period, infection_duration, recovery_rate),
            e.g. the array returned by ``saltelli.sample``.
        width, height, density, policy, num_districts, initial_infected:
            Fixed settings forwarded unchanged to every model instance.
        steps: Number of times ``model.step()`` is called per run.
        model_cls: Keyword-only. Class (or factory) used to build each model;
            defaults to ``SIERDModel``. It is called positionally as
            ``model_cls(width, height, density, transmission_rate,
            latency_period, infection_duration, recovery_rate, policy,
            num_districts, initial_infected)``, and the result must provide
            ``step()`` and ``datacollector.get_model_vars_dataframe()``.

    Returns:
        A list with one entry per parameter row, in input order. Each entry is
        whatever ``model.datacollector.get_model_vars_dataframe()`` returns
        after ``steps`` steps.

    Raises:
        ValueError: if a parameter row does not unpack to exactly four values.
    """
    if model_cls is None:
        # Resolved at call time so callers can swap in another model class
        # (or a lightweight test double) without touching this function.
        model_cls = SIERDModel

    results = []
    for transmission_rate, latency_period, infection_duration, recovery_rate in param_values:
        model = model_cls(width, height, density,
                          transmission_rate, latency_period, infection_duration, recovery_rate,
                          policy, num_districts, initial_infected)

        for _ in range(steps):
            model.step()

        # Collect the data only after all steps so the full time series is captured.
        results.append(model.datacollector.get_model_vars_dataframe())

    return results

# Simulation parameters held fixed across every run; only the four Sobol
# parameters in `problem` vary between runs.
width = 10
height = 10
density = 0.8
policy = "No Interventions"
num_districts = 5
# NOTE(review): confirm 50 initial infections is intended relative to the
# population SIERDModel builds from width/height/density.
initial_infected = 50
steps = 100

# Run the simulations (one per row of param_values).
results = run_sobol_simulation(param_values, width, height, density, policy, num_districts, initial_infected, steps)

# Scalar model output per run: the mean of the collected 'Infected' series.
Y = [result['Infected'].mean() for result in results]
Y_array = np.asarray(Y, dtype=float)

# A single NaN/inf output (e.g. a run whose collected data is empty, so its
# mean is NaN) would silently make every Sobol index NaN. Fail loudly instead.
non_finite_runs = np.flatnonzero(~np.isfinite(Y_array))
if non_finite_runs.size:
    raise ValueError(
        f"Non-finite model output in {non_finite_runs.size} run(s); "
        f"first offending rows of param_values: {non_finite_runs[:5].tolist()}"
    )

# Perform Sobol sensitivity analysis
Si = sobol.analyze(problem, Y_array, print_to_console=True)

# Plot first-order (S1) and total-order (ST) sensitivity indices, one figure
# each. Error bars are the confidence-interval half-widths SALib returns
# alongside each index (S1_conf / ST_conf). At modest sample sizes these can
# be wide, so bars with overlapping error bars should not be read as a ranking.
for index_key, order_label in (('S1', 'First-order'), ('ST', 'Total-order')):
    plt.figure()
    plt.bar(problem['names'], Si[index_key], yerr=Si[f'{index_key}_conf'], capsize=4)
    plt.xlabel('Parameter')
    plt.ylabel(f'{order_label} sensitivity index')
    plt.title(f'{order_label} Sensitivity')
    # Parameter names are long; rotate them so the tick labels do not overlap.
    plt.xticks(rotation=30, ha='right')
    plt.tight_layout()
    plt.show()

ModuleNotFoundError: No module named 'Environment'