In [1]:
%run packages_import.ipynb    #import all necessary packages - numpy, pandas etc
%run parameters_class.ipynb  #import the parameters class and the base paramters object
%run simulation_class.ipynb  #import the class that runs the simulation
%run base_parameters.ipynb import base_parameters #import all the base paamters that you define
%run base_parameters.ipynb import Parameters_for_Analysis #import parameters for analyis function where you keep the base parameters fixed and modify the rest

In [2]:
def _decade_ticks(values):
    """Return (tickvals, ticktext) for every power of ten spanning the positive entries of ``values``.

    Returns (None, None) when there is no positive value, which leaves Plotly's automatic ticks in place.
    """
    positive = np.asarray(values, dtype=float)
    positive = positive[positive > 0]
    if positive.size == 0:
        return None, None
    lo = int(np.floor(np.log10(positive.min())))
    hi = int(np.ceil(np.log10(positive.max())))
    tickvals = [10.0 ** k for k in range(lo, hi + 1)]
    # Build the labels from the values so the two lists can never drift out of length sync.
    return tickvals, [f"{v:g}" for v in tickvals]


def plot_2d_sensitivity_analysis_initial_pop_vs_evolvability(N2_N1_ratios, V_A2_V_A1_ratios, final_outcomes_matrix, show=True):
    """Draw the outcome of every (population ratio, evolvability ratio) run as a categorical heatmap.

    Parameters
    ----------
    N2_N1_ratios : array-like of positive float
        Initial population ratios (species 2 / species 1), drawn on the log-scale x axis.
    V_A2_V_A1_ratios : array-like of positive float
        Genetic-variance ratios (species 2 / species 1), drawn on the log-scale y axis.
    final_outcomes_matrix : array-like of str, shape (len(N2_N1_ratios), len(V_A2_V_A1_ratios))
        Outcome label of each run. Entry [i, j] belongs to N2_N1_ratios[i] and V_A2_V_A1_ratios[j].
    show : bool, default True
        If True, render the figure with ``pio.show``. If False, only build and return it.

    Returns
    -------
    The plotly Figure that was built.

    Raises
    ------
    ValueError
        If ``final_outcomes_matrix`` contains a label that is not one of the four known outcomes.
    """
    color_map = {
        "sp 2 (high-genetic-var) wins, sp 1 (low-genetic-var) extinct": 0,
        "sp 1 (low-genetic-var) wins, sp 2 (high-genetic-var) extinct": 1,
        "both extinct": 2,
        "both coexist": 3
    }

    outcomes = np.asarray(final_outcomes_matrix)
    # Fail loudly on labels the map does not know.
    # With dict.get they became None, which either crashed np.vectorize or left silent gaps.
    unknown = set(outcomes.ravel().tolist()) - set(color_map)
    if unknown:
        raise ValueError(f"Unrecognised simulation outcome(s): {sorted(map(str, unknown))}")
    # Explicit otypes keeps the dtype stable and makes empty inputs work.
    numerical_matrix = np.vectorize(color_map.__getitem__, otypes=[int])(outcomes)

    heatmap = go.Heatmap(
        z=numerical_matrix.T,  # rows of z follow y (evolvability ratio), columns follow x (population ratio)
        x=N2_N1_ratios,
        y=V_A2_V_A1_ratios,
        colorscale=[[0, 'red'], [1/3, 'blue'], [2/3, 'black'], [1, 'green']],
        # Pin the colour range to the full code range.
        # Otherwise Plotly rescales to the codes actually present and the colours shift.
        zmin=0,
        zmax=len(color_map) - 1,
        showscale=False,
        colorbar=dict(
            tickvals=list(color_map.values()),
            ticktext=list(color_map.keys()),
        )
    )

    fig = go.Figure(data=[heatmap])

    # Heatmaps have no legend entries, so add invisible marker traces to act as a legend.
    legend_labels = [
        ("Slow evolver extinct", 'red'),
        ("Fast evolver extinct", 'blue'),
        ('Both coexist', 'green'),
        ('Both extinct', 'black')
    ]

    for label, color in legend_labels:
        fig.add_trace(go.Scatter(
            x=[None], y=[None], mode='markers',
            marker=dict(color=color),
            showlegend=True,
            name=label
        ))

    x_tickvals, x_ticktext = _decade_ticks(N2_N1_ratios)
    y_tickvals, y_ticktext = _decade_ticks(V_A2_V_A1_ratios)

    fig.update_layout(
        xaxis=dict(
            title='Initial Population Ratio (log-scale)(Sp.2/Sp.1)',
            type='log',
            tickvals=x_tickvals,
            ticktext=x_ticktext
        ),
        yaxis=dict(
            title='Evolvability Ratio [sp.2/sp.1] (log-scale)',
            type='log',
            tickvals=y_tickvals,
            ticktext=y_ticktext
        ),
        template='plotly',
        width=1000,
        height=800,
    )

    if show:
        pio.show(fig)
    return fig

def generate_parameters_with_varied_parameters(N2_N1_ratios, V_A2_V_A1_ratios):
    """Build one ``Parameters_for_Analysis`` object per (N2/N1, V_A2/V_A1) combination.

    Population ratios form the outer loop and variance ratios the inner loop.
    The result is therefore row-major with shape (len(N2_N1_ratios), len(V_A2_V_A1_ratios)).

    Each object starts from ``Parameters_for_Analysis(base_parameters=base_parameters, V_A1=1, V_A2=1)``.
    It then sets ``initial_population_species_2 = initial_population_species_1 * pop_ratio``
    and ``V_A2 = V_A1 * va_ratio``.
    """
    def _build(pop_ratio, va_ratio):
        # Fresh template per combination, copied so the template object itself is not mutated.
        fields = Parameters_for_Analysis(base_parameters=base_parameters, V_A1=1, V_A2=1).__dict__.copy()
        fields['initial_population_species_2'] = fields['initial_population_species_1'] * pop_ratio
        fields['V_A2'] = fields['V_A1'] * va_ratio
        return Parameters_for_Analysis(base_parameters=base_parameters, **fields)

    return [
        _build(pop_ratio, va_ratio)
        for pop_ratio in N2_N1_ratios
        for va_ratio in V_A2_V_A1_ratios
    ]

def run_sensitivity_analysis(parameters_list):
    """Simulate each parameter set in turn and collect its outcome label.

    Parameters
    ----------
    parameters_list : iterable
        Parameter objects to pass, one at a time, to ``run_simulation_and_get_outcome_instance``.

    Returns
    -------
    list
        The ``.outcome`` attribute of each simulation result, in input order.
    """
    return [
        run_simulation_and_get_outcome_instance(params).outcome
        for params in parameters_list
    ]

def main():
    """Sweep both ratios on a log grid, simulate every combination and plot the outcomes.

    N2/N1 spans 10**-3 .. 10**2.75 and V_A2/V_A1 spans 10**0 .. 10**1.75, both in quarter-decade steps.
    np.arange excludes the stop value, which is why the top of each range is below 10**3 and 10**2.
    """
    population_exponents = np.arange(-3, 3, 0.25)
    variance_exponents = np.arange(0, 2, 0.25)

    N2_N1_ratios = 10.0 ** population_exponents
    V_A2_V_A1_ratios = 10.0 ** variance_exponents

    outcomes = run_sensitivity_analysis(
        generate_parameters_with_varied_parameters(N2_N1_ratios, V_A2_V_A1_ratios)
    )

    # The generator loops over population ratios first, so outcomes fill the grid row by row.
    grid_shape = (N2_N1_ratios.size, V_A2_V_A1_ratios.size)
    final_outcomes_matrix = np.asarray(outcomes).reshape(grid_shape)

    plot_2d_sensitivity_analysis_initial_pop_vs_evolvability(N2_N1_ratios, V_A2_V_A1_ratios, final_outcomes_matrix)


if __name__ == "__main__":
    main()
