# Normative Uncertainty in IAMs

#### Testing the hypervolumes for convergence

- First generate reference set from all seeds (or islands of MM Borg MOEA)
- Then generate hypervolumes for each seed (or island) against reference set
- Plot the hypervolumes for each seed (or island) against number of function evaluations

In [None]:
# This code creates a reference set from the different seeds

# NOTE: For MMBorg archives, run the script to convert it to the format recognized by older code with ema-workbench.
# Example 
# python borg_archive_processor.py     --archive /Volumes/justicedrive/NU_data_20_Oct/PRIORITARIAN_200000_ref5_42/mm_intermediate.zip     --base-name PRIORITARIAN_200000_ref5_42     --step 10000

from solvers.convergence.hypervolume import get_global_reference_set, calculate_hypervolume_from_archives
import multiprocessing
# Suppress warnings
import warnings

from justice.util.enumerations import WelfareFunction, SSP
from justice.util.visualizer import plot_hypervolume

warnings.filterwarnings("ignore")

# --- Run selection -----------------------------------------------------------
base_path = "data/temporary/NU_DATA/mmBorg/" # Change this to your path

swf = WelfareFunction.PRIORITARIAN
# NFE token used in the archive file names and passed to the reference-set builder.
# NOTE(review): the recorded output of this cell lists archives named
# `..._200000_ref5_...` inside a `200k` sub-folder, while this value is 100_000 —
# confirm which run the notebook is meant to analyse.
nfe = 100_000
ssp = SSP.SSP4
ssp_ref = 5  # appears as `ref{ssp_ref}` in the archive file names

# `base_path` already ends with a separator; strip it before joining so the
# resulting path does not contain a doubled "//".
path = f"{base_path.rstrip('/')}/{swf.value[1]}_{str(ssp).split('.')[1]}"


print(f"Loading data from {path}...")

# Objective columns used to build the reference set
list_of_objectives = [
    "welfare",
    "fraction_above_threshold",
]
# Archives are read from, and results written back to, the same folder
data_path = path

# One direction per entry of `list_of_objectives`
direction_of_optimization = ["min", "min"]

# The call is the last expression of the cell, so its return value is displayed
# as the cell output.
get_global_reference_set(
    list_of_objectives=list_of_objectives,
    data_path=data_path,
    #file_name=None,
    swf=[
        swf.value[1],
    ],
    nfe=str(nfe),  # passed as a string; must match the NFE of the run being analysed

    # Setting the same epsilon values as optimization process  (see analysis/analyzer.py)
    epsilons=[
        0.00001,
        0.001,
    ],

    direction_of_optimization=direction_of_optimization,
    output_data_path=path,
    saving=True,
)




Loading data from data/temporary/NU_DATA/mmBorg/PRIORITARIAN_SSP4/200k...
Loading list of files
Loading archives for:  PRIORITARIAN
Filename:  PRIORITARIAN_200000_ref5_42_1.tar.gz
Matching file: PRIORITARIAN_200000_ref5_42_1.tar.gz
Loading archives from: PRIORITARIAN_200000_ref5_42_1.tar.gz
Max key: 200000
Number of rows in archive: 2
Archives loaded for: PRIORITARIAN_200000_ref5_42_1.tar.gz
Filename:  PRIORITARIAN_200000_ref5_42_3.tar.gz
Matching file: PRIORITARIAN_200000_ref5_42_3.tar.gz
Loading archives from: PRIORITARIAN_200000_ref5_42_3.tar.gz
Max key: 200000
Number of rows in archive: 3
Archives loaded for: PRIORITARIAN_200000_ref5_42_3.tar.gz
Filename:  PRIORITARIAN_200000_ref5_42_0.tar.gz
Matching file: PRIORITARIAN_200000_ref5_42_0.tar.gz
Loading archives from: PRIORITARIAN_200000_ref5_42_0.tar.gz
Max key: 200000
Number of rows in archive: 2
Archives loaded for: PRIORITARIAN_200000_ref5_42_0.tar.gz
Filename:  PRIORITARIAN_200000_ref5_42_2.tar.gz
Matching file: PRIORITARIAN_200

{'PRIORITARIAN':     center 0  center 1  center 2  center 3  center 4  center 5  center 6  \
 12  0.136858 -0.950275  0.042237 -0.932331  0.047157 -0.129507  0.013975   
 13  0.136858 -0.951947  0.053598  0.100470  0.231806 -0.129114  0.053326   
 
     center 7   radii 0   radii 1  ...  weights 220  weights 221  weights 222  \
 12 -0.045703  0.133822  0.999859  ...     0.938101     0.999547     0.939041   
 13 -0.016753  0.133822  0.999863  ...     0.938100     0.999543     0.939041   
 
     weights 223  weights 224  weights 225  weights 226  weights 227  \
 12     0.567953     0.980678     0.979616     0.994322     0.936356   
 13     0.567955     0.980678     0.979616     0.994050     0.936992   
 
        welfare  fraction_above_threshold  
 12  498.445120                      0.58  
 13  498.139221                      0.60  
 
 [2 rows x 246 columns]}

Computing the hypervolume of each seed (or island) against the reference set

In [2]:
## Compute the hypervolume of every seed's archives against the global reference set

# One archive per Borg seed/island; only the trailing index differs between files.
filenames = [
    f"{swf.value[1]}_{nfe}_ref{ssp_ref}_42_{archive_index}.tar.gz"
    for archive_index in range(5)
]

reference_set = f"{swf.value[1]}_reference_set.csv"
# reference_set =  "final_archive/100000.csv"

# Arguments that are identical for every archive.
# NOTE: Change the reference-set entries according to the PF refset.
hypervolume_settings = dict(
    list_of_objectives=list_of_objectives,
    direction_of_optimization=direction_of_optimization,
    input_data_path=data_path,
    output_data_path=path,
    saving=True,
    global_reference_set=True,
    global_reference_set_path=path,
    global_reference_set_file=reference_set,
)

# A single worker pool is shared by all archives and closed when the block exits.
with multiprocessing.Pool() as pool:
    for filename in filenames:
        # `scores` is overwritten on every pass, so after the loop it only
        # holds the result for the last archive.
        scores = calculate_hypervolume_from_archives(
            file_name=filename,
            pool=pool,
            **hypervolume_settings,
        )
        



Loading archives for PRIORITARIAN_200000_ref5_42_0.tar.gz
Archives loaded
list_of_archives:  (45, 2)
reference_set (2, 2)
type of reference_set <class 'numpy.ndarray'>
nfes: 
 [100, 10000, 100000, 110000, 120000, 130000, 140000, 150000, 160000, 170000, 180000, 190000, 20000, 200000, 30000, 40000, 50000, 60000, 70000, 80000, 90000]
Computing hypervolume for  PRIORITARIAN_200000_ref5_42_0.tar.gz
Time taken for Hypervolume Calculation: 2.921 seconds
data/temporary/NU_DATA/mmBorg/PRIORITARIAN_SSP4/200k/PRIORITARIAN_200000_ref5_42_0_hv.csv
Loading archives for PRIORITARIAN_200000_ref5_42_1.tar.gz
Archives loaded
list_of_archives:  (56, 2)
reference_set (2, 2)
type of reference_set <class 'numpy.ndarray'>
nfes: 
 [100, 10000, 100000, 110000, 120000, 130000, 140000, 150000, 160000, 170000, 180000, 190000, 20000, 200000, 30000, 40000, 50000, 60000, 70000, 80000, 90000]
Computing hypervolume for  PRIORITARIAN_200000_ref5_42_1.tar.gz
Time taken for Hypervolume Calculation: 0.003 seconds
data/tem

Plotting the Hypervolumes for each seed (or island) against number of function evaluations

In [3]:
# Hypervolume CSVs grouped by welfare function; one file per seed/island index.
input_data_path_list = {
    swf.value[1]: [
        f"{swf.value[1]}_{nfe}_ref{ssp_ref}_42_{archive_index}_hv.csv"
        for archive_index in range(5)
    ],
}

# Figures are read from and saved to the same run folder.
fig = plot_hypervolume(
    path_to_data=path,
    path_to_output=path,
    input_data=input_data_path_list,
    yaxis_upper_limit=1.0,
    width=1000,
    height=800,
    fontsize=20,
    saving=True,
)

fig.show()

## Launch the Mapping Script in Util
```
python justice/util/postprocessing_for_regret_calculations.py data/temporary/NU_DATA/mmBorg/ UTILITARIAN SSP2
```
- This re-evaluates all the Pareto-optimal policy candidates to compute the 90th-percentile regret values for welfare (utilitarian/prioritarian) and for temperature rise in degrees Celsius.
- NOTE: This script takes a long time (around 30 minutes to several hours depending on the number of policy candidates, scenarios and computational resources available).
- Call the script separately for each combination of social welfare function and reference scenario (the scenario under which the policies are optimized).

# Automated Regret Calculation

In [1]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from justice.util.model_time import TimeHorizon
from justice.util.data_loader import DataLoader
import json

from justice.util.enumerations import WelfareFunction, SSP

from pathlib import Path
import numpy as np
import pandas as pd
from justice.util.output_data_processor import compute_p90_regret_dataframe, minimax_regret_policy
from justice.util.enumerations import WelfareFunction, SSP


# Scenarios passed to the regret computation as `scenario_list`
scenario_list = ["SSP126", "SSP245", "SSP370", "SSP460", "SSP534"]

# Reference SSP key -> baseline scenario handed to the regret computation
baseline_scenario_by_ssp = {
    "SSP1": "SSP126",
    "SSP2": "SSP245",
    "SSP3": "SSP370",
    "SSP4": "SSP460",
    "SSP5": "SSP534",
}

# Variable inspected for "Welfare_Regret" under each ethical framing
welfare_variable_by_framing = {
    "UTILITARIAN": "utilitarian_welfare",
    "PRIORITARIAN": "prioritarian_welfare",
}

# An ordered tuple rather than a set: iterating a set of strings can change
# order from one kernel start to the next, which made the processing order,
# the log and the key order of the saved JSON non-reproducible.
regret_types_to_process = ("Temperature_Regret", "Welfare_Regret")

# Every reference SSP is analysed under every ethical framing and regret type
ethical_framing_and_regret = {
    ssp_key: {
        framing: regret_types_to_process
        for framing in welfare_variable_by_framing
    }
    for ssp_key in baseline_scenario_by_ssp
}


def resolve_regret_target(ethical_framing, regret_type):
    """Map a framing / regret-type pair to the variable to evaluate.

    Args:
        ethical_framing: "UTILITARIAN" or "PRIORITARIAN".
        regret_type: "Temperature_Regret" or "Welfare_Regret".

    Returns:
        A ``(variable_of_interest, direction_of_interest)`` tuple, or ``None``
        when the combination is not supported.
    """
    if regret_type == "Temperature_Regret":
        return "global_temperature", "min"  # Use min for global temperature
    if regret_type == "Welfare_Regret" and ethical_framing in welfare_variable_by_framing:
        return welfare_variable_by_framing[ethical_framing], "max"  # Use max for welfare variables
    return None


# Holds the policy index with minimum regret for each scenario, ethical framing, and regret type
min_regret_policy_indices = {}

base_path = "data/temporary/NU_DATA/mmBorg/"
save_regret_dfs = True

for key, value in ethical_framing_and_regret.items():
    print(f"Scenario: {key}")
    baseline_scenario = baseline_scenario_by_ssp.get(key)

    for ethical_framing, regret_types in value.items():
        print(f"  Ethical Framing: {ethical_framing}")

        for regret_type in regret_types:
            target = resolve_regret_target(ethical_framing, regret_type)
            if target is None:
                # Unsupported combinations are skipped, but no longer silently.
                print(f"    Skipping unsupported combination: {ethical_framing} / {regret_type}")
                continue
            variable_of_interest, direction_of_interest = target

            p90_delta_df = compute_p90_regret_dataframe(
                base_path=base_path + f"{ethical_framing}_{key}/",
                welfare_function_name=ethical_framing,
                baseline_scenario=baseline_scenario,
                scenario_list=scenario_list,
                variable_of_interest=variable_of_interest,
                direction_of_interest=direction_of_interest,
                mapping_subdir="mapping",
                hdf5_filename_template="mapping_{}.h5",
                save_df=save_regret_dfs,  # Save CSV file
                # None -> default location; per the recorded output this is
                # '<base_path>/p90_regret_<welfare_function_name>_<variable_of_interest>.csv'
                df_output_path=None,
            )
            temp_idx = minimax_regret_policy(p90_delta_df)

            if regret_type == "Temperature_Regret":
                print(f"Processing {ethical_framing} with {regret_type} for {key}  and baseline scenario {baseline_scenario}")
            else:
                print(f"Processing {ethical_framing} with {regret_type} with variable of interest {variable_of_interest} for {key} and baseline scenario {baseline_scenario}")
            print("Policy index with minimum regret:", temp_idx)

            # Fill the nested dictionary, creating intermediate levels on demand
            min_regret_policy_indices.setdefault(key, {}).setdefault(ethical_framing, {})[regret_type] = temp_idx


# Save this dictionary at the base path
with open(base_path + "min_regret_policy_indices.json", "w") as f:
    json.dump(min_regret_policy_indices, f, indent=4)



  from .autonotebook import tqdm as notebook_tqdm


Scenario: SSP1
  Ethical Framing: UTILITARIAN
Saved p90 delta data to data/temporary/NU_DATA/mmBorg/UTILITARIAN_SSP1/p90_regret_UTILITARIAN_global_temperature.csv
Processing UTILITARIAN with Temperature_Regret for SSP1  and baseline scenario SSP126
Policy index with minimum regret: 6
Saved p90 delta data to data/temporary/NU_DATA/mmBorg/UTILITARIAN_SSP1/p90_regret_UTILITARIAN_utilitarian_welfare.csv
Processing UTILITARIAN with Welfare_Regret with variable of interest utilitarian_welfare for SSP1 and baseline scenario SSP126
Policy index with minimum regret: 4
  Ethical Framing: PRIORITARIAN
Saved p90 delta data to data/temporary/NU_DATA/mmBorg/PRIORITARIAN_SSP1/p90_regret_PRIORITARIAN_global_temperature.csv
Processing PRIORITARIAN with Temperature_Regret for SSP1  and baseline scenario SSP126
Policy index with minimum regret: 0
Saved p90 delta data to data/temporary/NU_DATA/mmBorg/PRIORITARIAN_SSP1/p90_regret_PRIORITARIAN_prioritarian_welfare.csv
Processing PRIORITARIAN with Welfare_Re

## Run the reevaluation script
```
python justice/util/reevaluate_optimal_policy.py
```
- Reevaluates the policy candidates selected in the previous step across all scenarios.
- Extracts relevant variables - emissions, temperature, emission control rates and saves them in npy files for further analysis.
- NOTE: This script generates large files (several GB). Ensure at least 100 GB of free space on the drive, and select the appropriate output path in the script before running.

# Visualize the Pathways

In [None]:
from justice.util.visualizer import plot_comparison_with_boxplots, plot_choropleth_2D_data
from justice.util.enumerations import WelfareFunction, SSP
import json
import numpy as np
import plotly.express as px
import pandas as pd

variable_name = "emissions"

base_path = "data/temporary/NU_DATA/mmBorg/"

# Load the saved minimum-regret policy indices from the base path
with open(base_path + "min_regret_policy_indices.json", "r") as f:
    loaded_min_regret_policy_indices = json.load(f)

# Scenario tokens of the .npy files, in the same order as the legend labels
evaluation_scenarios = ["SSP126", "SSP245", "SSP370", "SSP460", "SSP534"]

# One RGB triple per scenario; the line uses it opaque, the band at 0.4 alpha
palette_rgb = [
    "141,211,199",
    "254,217,166",
    "190,186,218",
    "128,177,211",
    "251,128,114",
]

# Print the stored indices and draw one comparison figure per entry
print("\nMinimum Regret Policy Indices:")
for scenario, ethical_data in loaded_min_regret_policy_indices.items():
    print(f"Scenario: {scenario}")
    for ethical_framing, regret_data in ethical_data.items():
        print(f"  Ethical Framing: {ethical_framing}")
        for regret_type, policy_index in regret_data.items():
            print(f"    Regret Type: {regret_type}, Policy Index: {policy_index}")

            # Folder and file-name stem shared by the five scenario files
            run_tag = f"ref_{scenario}_{regret_type}_idx{policy_index}"
            run_dir = base_path + f"{ethical_framing}_{scenario}/{run_tag}/"
            file_stem = f"{ethical_framing}_{run_tag}_{variable_name}_idx{policy_index}"

            plot_comparison_with_boxplots(
                data_paths=[
                    run_dir + f"{file_stem}_{ssp_rcp}_{variable_name}.npy"
                    for ssp_rcp in evaluation_scenarios
                ],
                labels=['SSP1', 'SSP2', 'SSP3', 'SSP4', 'SSP5'],
                start_year=2015,
                end_year=2300,
                data_timestep=5,
                timestep=1,
                visualization_start_year=2015,
                visualization_end_year=2100,
                yaxis_range=[0, 80],
                plot_title=' ',
                xaxis_title='Year',
                yaxis_title='Global Emissions (GtCO2)',
                template='plotly_white',
                width=1000,
                height=700,
                output_path=base_path + "/" + "plots",
                saving=True,
                show_red_dashed_line=False,
                show_interquartile_range=True,
                linecolors=[f"rgba({rgb}, 1)" for rgb in palette_rgb],
                colors=[f"rgba({rgb}, 0.4)" for rgb in palette_rgb],
                first_plot_proportion=[0, 0.75],
                second_plot_proportion=[0.85, 1],
                transpose_data=True,
                show_min_max=False,
                dtick=10,
                output_name_suffix=regret_type,
            )


Minimum Regret Policy Indices:
Scenario: SSP1
  Ethical Framing: UTILITARIAN
    Regret Type: Temperature_Regret, Policy Index: 6
Data is 3D
Shape of data:  (57, 286, 1001)
Shape of data after summing:  (286, 1001)
Data is 3D
Shape of data:  (57, 286, 1001)
Shape of data after summing:  (286, 1001)
Data is 3D
Shape of data:  (57, 286, 1001)
Shape of data after summing:  (286, 1001)
Data is 3D
Shape of data:  (57, 286, 1001)
Shape of data after summing:  (286, 1001)
Data is 3D
Shape of data:  (57, 286, 1001)
Shape of data after summing:  (286, 1001)
    Regret Type: Welfare_Regret, Policy Index: 4
Data is 3D
Shape of data:  (57, 286, 1001)
Shape of data after summing:  (286, 1001)
Data is 3D
Shape of data:  (57, 286, 1001)
Shape of data after summing:  (286, 1001)
Data is 3D
Shape of data:  (57, 286, 1001)
Shape of data after summing:  (286, 1001)
Data is 3D
Shape of data:  (57, 286, 1001)
Shape of data after summing:  (286, 1001)
Data is 3D
Shape of data:  (57, 286, 1001)
Shape of dat

# Visualize the Distribution of Emission Control Rates across SSPs

In [None]:
from justice.util.visualizer import plot_comparison_with_boxplots, plot_choropleth_2D_data
from justice.util.enumerations import WelfareFunction, SSP
import json
import numpy as np
import plotly.express as px
import pandas as pd

variable_name = "constrained_emission_control_rate"

base_path = "data/temporary/NU_DATA/mmBorg/"

# To plot every stored policy instead, load the saved dictionary:
# with open(base_path + "min_regret_policy_indices.json", "r") as f:
#     loaded_min_regret_policy_indices = json.load(f)

# Hand-picked subset: SSP2, welfare regret only.
# (Temperature_Regret entries — UTILITARIAN 25, PRIORITARIAN 0 — are left out.)
# NOTE(review): these indices are hard-coded and can drift from the saved JSON.
loaded_min_regret_policy_indices = {
    "SSP2": {
        "UTILITARIAN": {"Welfare_Regret": 9},
        "PRIORITARIAN": {"Welfare_Regret": 4},
    },
}

# Scenario tokens of the .npy files, one input file per scenario
evaluation_scenarios = ["SSP126", "SSP245", "SSP370", "SSP460", "SSP534"]

# Print the selected indices and draw one choropleth figure per entry
print("\nMinimum Regret Policy Indices:")
for scenario, ethical_data in loaded_min_regret_policy_indices.items():
    print(f"Scenario: {scenario}")
    for ethical_framing, regret_data in ethical_data.items():
        print(f"  Ethical Framing: {ethical_framing}")
        for regret_type, policy_index in regret_data.items():
            print(f"    Regret Type: {regret_type}, Policy Index: {policy_index}")

            # Folder name and file-name stem shared by the five scenario files
            run_tag = f"ref_{scenario}_{regret_type}_idx{policy_index}"
            file_stem = f"{ethical_framing}_{run_tag}_{variable_name}_idx{policy_index}"

            fig, prior_data = plot_choropleth_2D_data(
                path_to_data=base_path + f"{ethical_framing}_{scenario}/{run_tag}/",
                path_to_output=base_path + "/" + "plots",
                projection="natural earth1",
                colourmap=px.colors.sequential.Reds,
                year_to_visualize=2050,
                input_data_path_list=[
                    f"{file_stem}_{ssp_rcp}_{variable_name}.npy"
                    for ssp_rcp in evaluation_scenarios
                ],
                data_label="Emission Control Rate",
                legend_label="",
                data_normalization=True,
                saving=True,
                show_colorbar=False,
                normalized_colorbar=True,
                plot_saving_format="svg",
            )

            fig.show()



Minimum Regret Policy Indices:
Scenario: SSP2
  Ethical Framing: UTILITARIAN
    Regret Type: Welfare_Regret, Policy Index: 9
Taking average over the last dimension.
Taking average over the last dimension.
Taking average over the last dimension.
Taking average over the last dimension.
Taking average over the last dimension.
0
1
2
3
4


  Ethical Framing: PRIORITARIAN
    Regret Type: Welfare_Regret, Policy Index: 4
Taking average over the last dimension.
Taking average over the last dimension.
Taking average over the last dimension.
Taking average over the last dimension.
Taking average over the last dimension.
0
1
2
3
4


# Feature Importance Analysis

In [1]:
from justice.util.feature_importance import build_long_dataframe, run_all_ml_importance

# Years for which the long dataframe is built and importances are computed
years = (2030, 2040, 2050, 2060, 2070, 2080, 2090, 2100)

# 1) Assemble the long-format dataframe from the run folder and region mappings
long_df = build_long_dataframe(
    base_path="data/temporary/NU_DATA/mmBorg/",
    region_mapping_path="data/input/12_regions.json",
    rice_region_dict_path="data/input/rice50_regions_dict.json",
    years_of_interest=years,
)

print("Long DF shape:", long_df.shape)

# 2) Run the importance analysis on the raw values, global scope only.
#    Alternative used earlier (CatBoost + SHAP on aggregated statistics, both
#    global and per-region): years=(2030, 2050, 2070, 2100),
#    target_stats=("median", "p90") or ("mean", "median", "p90"), no `scope`
#    argument. Plots are saved in ml_importance_plots/<scope>/<stat>/...
catboost_params = {
    "depth": 6,
    "learning_rate": 0.05,
    "n_estimators": 800,  # author note: upper bound; early stopping finds best < this in CV
    "l2_leaf_reg": 3.0,
    "loss_function": "RMSE",  # author note: overridden per statistic internally
    "random_seed": 42,
    "od_type": "Iter",
    "od_wait": 50,
    "use_best_model": True,
    "verbose": False,
    "allow_writing_files": False,
}

results = run_all_ml_importance(
    long_df=long_df,
    years=years,
    target_stats=("raw",),
    output_dir="ml_importance_plots",
    cv_folds=5,
    random_state=42,
    model_params=catboost_params,
    normalized_plots=True,  # author note: set False to see raw mean |SHAP|
    model_type="final",  # "final" or "cv-mean"
    scope="global",
)




Long DF shape: (10410400, 9)
Saving plots to: /Users/palokbiswas/Desktop/pollockdevis_git/JUSTICE/ml_importance_plots/global/raw
Saving feature importance data to: /Users/palokbiswas/Desktop/pollockdevis_git/JUSTICE/ml_importance_plots/global/raw/global_2030_shap_full.csv
Saving feature importance data to: /Users/palokbiswas/Desktop/pollockdevis_git/JUSTICE/ml_importance_plots/global/raw/global_2040_shap_full.csv
Saving feature importance data to: /Users/palokbiswas/Desktop/pollockdevis_git/JUSTICE/ml_importance_plots/global/raw/global_2050_shap_full.csv
Saving feature importance data to: /Users/palokbiswas/Desktop/pollockdevis_git/JUSTICE/ml_importance_plots/global/raw/global_2060_shap_full.csv
Saving feature importance data to: /Users/palokbiswas/Desktop/pollockdevis_git/JUSTICE/ml_importance_plots/global/raw/global_2070_shap_full.csv
Saving feature importance data to: /Users/palokbiswas/Desktop/pollockdevis_git/JUSTICE/ml_importance_plots/global/raw/global_2080_shap_full.csv
Saving 

In [8]:
import os
import re
from pathlib import Path
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import os
import re
from pathlib import Path
from typing import Dict, Iterable, List

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Fallback stacking order of the individual features (used when the caller
# supplies no `feature_order`)
FEATURE_ORDER = ["Scenario", "Regret", "Welfare", "Optimization", "Sample"]

# Fallback bar colour per individual feature, listed in FEATURE_ORDER order
FEATURE_COLORS = dict(
    [
        ("Scenario", "#8da0cb"),
        ("Regret", "#b2e2e2"),
        ("Welfare", "#66c2a4"),
        ("Optimization", "#238b45"),
        ("Sample", "#fc8d62"),
    ]
)

# Fallback bar colour per uncertainty group when features are aggregated;
# feel free to tweak
GROUP_COLORS = dict(
    [
        ("Deep Uncertainty", "#8da0cb"),
        ("Normative Uncertainty", "#238b45"),
        ("Stochastic Uncertainty", "#fc8d62"),
    ]
)


def plot_grouped_stacked_feature_importance_from_csvs(
    base_dir,
    scope="global",
    stat="mean",
    model_type="final",
    years=(2030, 2050, 2070, 2100),
    region=None,
    output_file=None,
    normalized=True,
    figsize=(9, 4),
    bar_width=1.0,
    legend_fontsize=9,
    feature_colors=None,
    feature_order=None,
    group_map: Dict[str, List[str]] = None,
    group_colors: Dict[str, str] = None,
):
    """
    Build stacked bar charts (one bar per year) from saved SHAP importance CSVs.

    Files are read from ``<base_dir>/<scope>/<stat>/`` (scope and stat are
    lower-cased). For ``scope="global"`` the expected file name is
    ``global_<year>_<kind>.csv``; for any other scope it is
    ``<region>_<year>_<kind>.csv``, where ``kind`` is ``"shap_full"`` when
    ``model_type == "final"`` and ``"shap_cv"`` otherwise. Every CSV must have
    ``Feature`` and ``Importance`` columns. Years whose file is missing are
    skipped silently.

    If `group_map` is provided, features are aggregated by group and the legend
    shows entries like "Deep Uncertainty (Scenario)".

    Parameters
    ----------
    base_dir : path-like
        Root directory that contains the ``<scope>/<stat>`` sub-folders.
    scope : str
        ``"global"`` (case-insensitive) for a single chart; anything else is
        treated as a per-region layout (one chart per region).
    stat : str
        Name of the statistic sub-folder.
    model_type : str
        Selects the file-name suffix (see above).
    years : iterable of int
        Years to plot, in bar order. Consumed exactly once, so a generator works.
    region : str, optional
        Regional scope only: restrict the output to this region slug (as it
        appears in the file names). ``None`` plots every region found.
    output_file : str, optional
        When given, figures are saved (dpi=300) and closed instead of shown.
        In regional scope each region is written to ``<stem>_<region><suffix>``
        next to ``output_file``.
    normalized : bool
        Only appends " (normalized)" to the y-axis label. The values are plotted
        exactly as stored in the CSVs; no rescaling is performed here.
    figsize, bar_width, legend_fontsize
        Passed through to matplotlib.
    feature_colors, feature_order
        Override the module defaults ``FEATURE_COLORS`` / ``FEATURE_ORDER``.
    group_map : dict, optional
        ``{group name: [feature, ...]}``; each group's features are summed into
        one stacked segment.
    group_colors : dict, optional
        Colour per group name; defaults to ``GROUP_COLORS``.

    Returns
    -------
    dict
        ``{"data": DataFrame, "figure": ...}``. ``data`` holds the raw
        per-feature values (not the grouped sums). ``figure`` is a single
        Figure in global scope, or ``{region: Figure}`` in regional scope.

    Raises
    ------
    FileNotFoundError
        The data directory is missing, or no matching CSV exists.
    ValueError
        A CSV lacks the ``Feature`` / ``Importance`` columns.
    KeyError
        ``group_map`` names a feature that is not among the loaded columns.
    """
    # Materialise once: `years` may be a one-shot iterable, and the same list is
    # needed for bar ordering AND for filtering regional file names below.
    year_order = list(years)
    feature_order = feature_order if feature_order is not None else FEATURE_ORDER
    feature_colors = feature_colors if feature_colors is not None else FEATURE_COLORS

    kind = "shap_full" if model_type == "final" else "shap_cv"
    root = Path(base_dir) / scope.lower() / stat.lower()
    if not root.exists():
        raise FileNotFoundError(f"Directory not found: {root}")

    def read_importance_csv(path: Path):
        """Return importances indexed by feature (missing ones -> 0.0), or None if the file is absent."""
        if not path.exists():
            return None
        df = pd.read_csv(path)
        if not {"Feature", "Importance"}.issubset(df.columns):
            raise ValueError(f"{path} must contain 'Feature' and 'Importance' columns")
        s = df.set_index("Feature")["Importance"]
        return s.reindex(feature_order, fill_value=0.0)

    sns.set_theme(style="white")

    if group_map:
        group_order = list(group_map.keys())
        label_map = {
            group: f"{group} ({', '.join(group_map[group])})"
            for group in group_order
        }
        group_colors = group_colors if group_colors is not None else GROUP_COLORS

        def aggregate_by_group(df_row):
            """Collapse one year's feature values into one summed value per group."""
            row = {"Year": df_row["Year"]}
            for group, feats in group_map.items():
                missing = [f for f in feats if f not in df_row.index]
                if missing:
                    raise KeyError(f"Missing features {missing} required for group '{group}'")
                row[group] = df_row[list(feats)].sum()
            return row

        def transform_df(df_plot):
            records = [aggregate_by_group(row) for _, row in df_plot.iterrows()]
            return pd.DataFrame(records), group_order, group_colors, label_map, group_map
    else:
        def transform_df(df_plot):
            # No grouping: stack the individual features as they are.
            return df_plot, feature_order, feature_colors, None, None

    def plot_df(raw_df, title=None, outfile=None):
        """Draw one stacked bar chart; save-and-close when `outfile` is set, else show."""
        df_plot = raw_df.copy()
        # Ordered categorical so bars follow the caller's year order, not numeric order.
        df_plot["Year"] = pd.Categorical(df_plot["Year"], categories=year_order, ordered=True)
        df_plot = df_plot.sort_values("Year")

        df_stack, stack_order, colors_map, legend_labels, _ = transform_df(df_plot)

        fig, ax = plt.subplots(figsize=figsize)
        x_pos = np.arange(len(df_stack))
        bottoms = np.zeros(len(df_stack))

        for key in stack_order:
            values = df_stack[key].to_numpy()
            ax.bar(
                x_pos,
                values,
                width=bar_width,
                bottom=bottoms,
                color=colors_map.get(key, "#999999"),
                label=legend_labels[key] if legend_labels else key,
                align="center",
            )
            bottoms += values

        ax.set_xticks(x_pos)
        ax.set_xticklabels([str(y) for y in df_stack["Year"]])
        ax.set_xlim(-0.5, len(df_stack) - 0.5)
        ax.margins(x=0)
        sns.despine(ax=ax, top=True, right=True, left=False, bottom=False)
        ax.set_xlabel("")
        ax.set_ylabel("Importance" + (" (normalized)" if normalized else ""))
        if title:
            ax.set_title(title)

        # One legend entry per distinct label, in first-seen order.
        handles, labels = ax.get_legend_handles_labels()
        unique = dict(zip(labels, handles))
        ax.legend(
            list(unique.values()),
            list(unique.keys()),
            frameon=False,
            fontsize=legend_fontsize,
            ncol=1,
            loc="upper left",
            bbox_to_anchor=(1.02, 1.0),
            borderaxespad=0.0,
        )

        if outfile:
            Path(outfile).parent.mkdir(parents=True, exist_ok=True)
            fig.savefig(outfile, dpi=300, bbox_inches="tight")
            plt.close(fig)
        else:
            plt.show()

        return fig

    if scope.lower() == "global":
        rows = []
        for yr in year_order:
            fpath = root / f"global_{yr}_{kind}.csv"
            s = read_importance_csv(fpath)
            if s is None:
                continue
            rows.append({"Year": yr, **{feat: float(s[feat]) for feat in feature_order}})

        if not rows:
            raise FileNotFoundError(f"No CSVs found for scope=global, stat={stat}, kind={kind} in {root}")

        df_plot = pd.DataFrame(rows)
        fig = plot_df(df_plot, outfile=output_file)
        return {"data": df_plot, "figure": fig}

    # Regional layout: discover region slugs from the file names on disk.
    pattern = re.compile(rf"^(?P<region>.+)_(?P<year>\d{{4}})_{kind}\.csv$")
    files = [p for p in root.glob("*.csv") if p.is_file()]
    region_set = set()
    for p in files:
        m = pattern.match(p.name)
        if not m:
            continue
        # Compare against the materialised list, never the raw `years` argument.
        if int(m.group("year")) in year_order:
            region_set.add(m.group("region"))

    region_list = [region] if region else sorted(region_set)
    if not region_list:
        raise FileNotFoundError(f"No regional CSVs found for stat={stat}, kind={kind} in {root}")

    figs = {}
    all_rows = []

    for rgn in region_list:
        rows = []
        for yr in year_order:
            fpath = root / f"{rgn}_{yr}_{kind}.csv"
            s = read_importance_csv(fpath)
            if s is None:
                continue
            rows.append({
                "Region": rgn,
                "Year": yr,
                **{feat: float(s[feat]) for feat in feature_order},
            })
        if not rows:
            continue

        df_plot = pd.DataFrame(rows)
        title = rgn.replace("_", " ")
        out = None
        if output_file:
            # Derive a per-region file name so regions do not overwrite each other.
            outpath = Path(output_file)
            out = str(Path(outpath.parent) / f"{outpath.stem}_{rgn}{outpath.suffix}")

        fig = plot_df(df_plot, title=title, outfile=out)
        figs[rgn] = fig
        all_rows.append(df_plot)

    df_all = pd.concat(all_rows, ignore_index=True) if all_rows else pd.DataFrame()
    return {"data": df_all, "figure": figs}


def render_all_grouped_stacked_charts(
    base_dir,
    scope="global",
    stat="mean",
    model_type="final",
    years=(2030, 2050, 2070, 2100),
    output_dir=None,
    normalized=True,
    figsize=(9, 6),
    legend_fontsize=9,
    bar_width=1.0,
    group_map: Dict[str, List[str]] = None,
):
    """
    Convenience wrapper around `plot_grouped_stacked_feature_importance_from_csvs`.

    Resolves the scope to either "global" (case-insensitive match) or
    "regional" (anything else), creates ``output_dir`` when one is given, and
    forwards everything to the plotting function.

    With an ``output_dir`` the target file passed on is
    ``<output_dir>/<scope>_<stat>_<model_type>_stacked.png``; without one,
    ``output_file`` is ``None``.

    Returns whatever the plotting function returns.
    """
    save_to_disk = output_dir is not None
    if save_to_disk:
        Path(output_dir).mkdir(parents=True, exist_ok=True)

    # Only "global" is special-cased; every other value means the regional layout.
    resolved_scope = "global" if scope.lower() == "global" else "regional"

    target_file = None
    if save_to_disk:
        target_file = str(
            Path(output_dir) / f"{resolved_scope}_{stat}_{model_type}_stacked.png"
        )

    return plot_grouped_stacked_feature_importance_from_csvs(
        base_dir=base_dir,
        scope=resolved_scope,
        stat=stat,
        model_type=model_type,
        years=years,
        output_file=target_file,
        normalized=normalized,
        figsize=figsize,
        legend_fontsize=legend_fontsize,
        bar_width=bar_width,
        group_map=group_map,
    )
# ------------------------------------------------------------------
# Example usage:
# base_dir = "ml_importance_plots"
# render_all_grouped_stacked_charts(
#     base_dir,
#     scope="global",
#     stat="raw",
#     model_type="final",
#     years= (2030, 2040, 2050, 2060, 2070,  2080, 2090, 2100), #(2030, 2050, 2070, 2100),
#     output_dir="figs",
# )
# render_all_grouped_stacked_charts(base_dir, scope="regional", stat="raw", model_type="cv-mean", years=(2030,2050,2070,2100), output_dir="figs")



# Uncertainty group -> SHAP features whose importances are summed into that
# group's stacked segment. The keys double as legend names and must match the
# keys of GROUP_COLORS for the groups to get their intended colours.
GROUP_MAP = {
    "Deep Uncertainty": ["Scenario"],
    "Normative Uncertainty": ["Optimization", "Welfare", "Regret"],
    "Stochastic Uncertainty": ["Sample"],
}

# Render one grouped stacked chart per region from ml_importance_plots/regional/raw
# and save the PNGs under figs/. The call is the cell's last expression, so its
# return value (a dict holding the data frame and the figures) is echoed as the
# cell output.
render_all_grouped_stacked_charts(
    base_dir="ml_importance_plots",
    scope="regional",
    stat="raw",
    model_type="final",
    years=(2030, 2050, 2070, 2100),  # decadal alternative: (2030, 2040, 2050, 2060, 2070, 2080, 2090, 2100)
    group_map=GROUP_MAP,
    output_dir="figs"
)

{'data':                 Region  Year  Scenario    Regret   Welfare  Optimization  \
 0               Brazil  2030  0.795486  0.017992  0.091440      0.086193   
 1               Brazil  2050  0.415365  0.170699  0.242661      0.143869   
 2               Brazil  2070  0.565380  0.112725  0.156480      0.144586   
 3               Brazil  2100  0.654396  0.081853  0.104574      0.146754   
 4                China  2030  0.772423  0.010931  0.056289      0.155265   
 5                China  2050  0.455388  0.156483  0.248937      0.123719   
 6                China  2070  0.648436  0.108258  0.159317      0.066578   
 7                China  2100  0.744043  0.070398  0.083458      0.092328   
 8               Europe  2030  0.391870  0.103696  0.193401      0.222827   
 9               Europe  2050  0.300765  0.218910  0.233170      0.213810   
 10              Europe  2070  0.398782  0.223795  0.249301      0.095662   
 11              Europe  2100  0.632265  0.110725  0.139672      0.0

In [2]:
# Preview the first rows of the long-format table. `long_df` is not created in
# this cell, so it must already exist in the kernel when this cell runs.
long_df.head()

Unnamed: 0,Optimization,Regret,Scenario,Welfare,Region,Year,Sample,AbatedEmission,Scope
0,SSP1,Temperature_Regret,SSP126,UTILITARIAN,Rest of the World,2030,0,0.447325,Regional
1,SSP1,Temperature_Regret,SSP126,UTILITARIAN,Rest of the World,2030,1,0.447762,Regional
2,SSP1,Temperature_Regret,SSP126,UTILITARIAN,Rest of the World,2030,2,0.449517,Regional
3,SSP1,Temperature_Regret,SSP126,UTILITARIAN,Rest of the World,2030,3,0.447558,Regional
4,SSP1,Temperature_Regret,SSP126,UTILITARIAN,Rest of the World,2030,4,0.448138,Regional


# Ternary and Choropleth Maps

In [5]:
import json
import re
from pathlib import Path
from typing import Iterable

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.collections import PatchCollection
from matplotlib.patches import Polygon
import matplotlib.colors as mcolors
import seaborn as sns
import plotly.graph_objects as go

# =============================================================================
# 0. Constants & Mappings
# =============================================================================
# Features read from every SHAP CSV; features missing from a file are filled
# with 0.0 by the loader below.
FEATURE_ORDER = ["Scenario", "Regret", "Welfare", "Optimization", "Sample"]
# Which features make up each uncertainty type.
# NOTE(review): the loader in this cell hard-codes the same grouping instead of
# reading this dict; keep the two in sync if the grouping changes.
UNCERTAINTY_GROUPS = {
    "Deep": ["Scenario"],
    "Normative": ["Regret", "Welfare", "Optimization"],
    "Stochastic": ["Sample"],
}
# Component order used throughout this cell: (normative, deep, stochastic).
# The rows of BASE_TERNARY_COLORS follow this same order.
GROUP_ORDER = ["Normative", "Deep", "Stochastic"]
# One RGB row per component; a region's colour is the share-weighted mix of these rows.
BASE_TERNARY_COLORS = np.array([
    [1.0, 0.0, 1.0],  # Normative → magenta
    [1.0, 1.0, 0.0],  # Deep      → yellow
    [0.0, 1.0, 1.0],  # Stochastic→ cyan
])
# sqrt(3)/2 is the height of the unit-side triangle used for the ternary plot.
SQRT3 = np.sqrt(3)

# =============================================================================
# 1. Load SHAP CSVs (regional, stat="raw") and compute N/D/S shares
# =============================================================================
def load_regional_uncertainty_shares(
    base_dir: str,
    stat: str = "raw",
    model_type: str = "final",
    years: Iterable[int] = (2030, 2050, 2070, 2100),
) -> pd.DataFrame:
    """
    Load regional SHAP importance CSVs and convert them to uncertainty shares.

    Reads every ``<region>_<year>_<kind>.csv`` in
    ``<base_dir>/regional/<stat>/`` (``kind`` is ``"shap_full"`` when
    ``model_type == "final"``, else ``"shap_cv"``), keeps the requested years,
    and divides each feature importance by the file's total so that the
    features in ``FEATURE_ORDER`` sum to 1.

    Files are processed in sorted path order, so the row order of the result
    is reproducible across filesystems. This matters to callers that draw
    seeded random jitter row by row.

    Parameters
    ----------
    base_dir : str
        Root directory containing ``regional/<stat>``.
    stat : str
        Statistic sub-folder name (lower-cased before use).
    model_type : str
        Selects the file-name suffix (see above).
    years : iterable of int
        Years to keep. Consumed exactly once, so a generator is safe.

    Returns
    -------
    pandas.DataFrame
        One row per (region, year) with columns ``Region``, ``Year``, the
        grouped shares ``Normative`` / ``Deep`` / ``Stochastic`` and the
        individual feature shares.

    Raises
    ------
    FileNotFoundError
        The ``regional/<stat>`` directory does not exist.
    ValueError
        No usable CSV was found.

    Notes
    -----
    Files lacking the ``Feature`` / ``Importance`` columns are skipped with a
    printed warning; files whose importances sum to <= 0 are skipped silently
    because shares would be undefined.
    """
    base = Path(base_dir) / "regional" / stat.lower()
    if not base.exists():
        raise FileNotFoundError(f"Directory not found: {base}")

    # Materialise once: testing `year in years` per file would exhaust a
    # one-shot iterable on the first file and drop all the others.
    wanted_years = set(years)

    kind = "shap_full" if model_type == "final" else "shap_cv"
    pattern = re.compile(rf"^(?P<region>.+)_(?P<year>\d{{4}})_{kind}\.csv$")

    records = []
    # glob() yields in arbitrary order; sort for a deterministic result.
    for csv_path in sorted(base.glob("*.csv")):
        m = pattern.match(csv_path.name)
        if not m:
            continue

        region_slug = m.group("region")
        year = int(m.group("year"))
        if year not in wanted_years:
            continue

        df = pd.read_csv(csv_path)
        if not {"Feature", "Importance"}.issubset(df.columns):
            print(f"[warn] '{csv_path}' missing required columns. Skipping.")
            continue

        s = df.set_index("Feature")["Importance"]
        s = s.reindex(FEATURE_ORDER, fill_value=0.0)
        total = s.sum()
        if total <= 0:
            continue

        shares = s / total
        records.append({
            "Region": region_slug,
            "Year": year,
            "Normative": float(shares[["Regret", "Welfare", "Optimization"]].sum()),
            "Deep": float(shares["Scenario"]),
            "Stochastic": float(shares["Sample"]),
            "Scenario": float(shares["Scenario"]),
            "Regret": float(shares["Regret"]),
            "Welfare": float(shares["Welfare"]),
            "Optimization": float(shares["Optimization"]),
            "Sample": float(shares["Sample"]),
        })

    shares_df = pd.DataFrame(records)
    if shares_df.empty:
        raise ValueError("No valid regional CSVs found. Check paths/stat/model_type.")
    return shares_df

# =============================================================================
# 2. Ternary background + color mixing
# =============================================================================
def mix_color(normative, deep, stochastic, base_colors=BASE_TERNARY_COLORS, as_hex=True):
    """
    Blend the base colours according to the three uncertainty weights.

    The weights are rescaled to sum to 1 and used as coefficients for the rows
    of ``base_colors`` (one RGB row per component, in the order normative,
    deep, stochastic). The blended channels are clipped to [0, 1].

    Returns a hex colour string when ``as_hex`` is true, otherwise the RGB array.

    Raises
    ------
    ValueError
        If the three weights do not add up to a positive number.
    """
    shares = np.array([normative, deep, stochastic], dtype=float)
    mass = shares.sum()
    if mass <= 0:
        raise ValueError("Normative + Deep + Stochastic must be positive.")

    shares = shares / mass
    blended = np.clip(shares @ base_colors, 0, 1)
    if as_hex:
        return mcolors.to_hex(blended)
    return blended


def quantize_simplex(n, d, s, scale=8):
    """
    Snap (n, d, s) onto a discrete ternary lattice with step size 1/scale.

    The inputs are normalised to sum to 1, multiplied by ``scale`` and floored.
    The lattice units still unassigned are then handed out one at a time to the
    components with the largest fractional parts (largest-remainder rounding),
    so the snapped point always uses exactly ``scale`` units.

    Returns a length-3 array that sums to 1.

    Raises
    ------
    ValueError
        If ``n + d + s`` is not positive.
    """
    shares = np.array([n, d, s], dtype=float)
    mass = shares.sum()
    if mass <= 0:
        raise ValueError("Normative + Deep + Stochastic must be positive.")
    shares /= mass

    exact = shares * scale
    cells = np.floor(exact)
    leftover = scale - int(cells.sum())

    # Decide the direction of the correction and who receives it first.
    if leftover > 0:
        step, ranking = 1, np.argsort(-(exact - cells))   # largest fraction first
    elif leftover < 0:
        step, ranking = -1, np.argsort(exact - cells)     # smallest fraction first
    else:
        step, ranking = 0, ()

    for position in ranking:
        if leftover == 0:
            break
        cells[position] += step
        leftover -= step

    lattice_point = cells / scale
    return lattice_point / lattice_point.sum()


def barycentric_to_cartesian(normative, deep, stochastic):
    """
    Map ternary shares to 2-D plot coordinates on a unit-side triangle.

    Corners: deep -> (0, 0), stochastic -> (1, 0), normative -> (0.5, sqrt(3)/2).
    ``deep`` does not appear in the formula; it is kept in the signature so
    callers can unpack a full (normative, deep, stochastic) triple.

    Returns the ``(x, y)`` pair.
    """
    triangle_height = SQRT3 / 2.0
    return stochastic + 0.5 * normative, triangle_height * normative


def build_triangular_mesh(scale):
    """
    Tessellate the ternary triangle into ``scale**2`` small triangles.

    Vertices sit on the lattice ``(i, j, scale - i - j) / scale``, where ``i``
    counts normative steps and ``j`` counts deep steps.

    Returns
    -------
    bary_coords : list of ndarray
        (normative, deep, stochastic) of every vertex.
    cart_coords : list of tuple
        The matching ``(x, y)`` plot coordinates.
    triangles : list of tuple
        Triples of vertex indices into the two lists above.
    """
    bary_coords, cart_coords = [], []
    vertex_id = {}
    for i in range(scale + 1):
        for j in range(scale + 1 - i):
            triple = (i / scale, j / scale, (scale - i - j) / scale)
            vertex_id[(i, j)] = len(bary_coords)
            bary_coords.append(np.array(triple))
            cart_coords.append(barycentric_to_cartesian(*triple))

    triangles = []
    for i in range(scale):
        for j in range(scale - i):
            here = vertex_id[(i, j)]
            up = vertex_id[(i + 1, j)]
            across = vertex_id[(i, j + 1)]
            triangles.append((here, up, across))
            # The mirrored triangle only exists while (i + 1, j + 1) is still on the lattice.
            if i + j < scale - 1:
                triangles.append((up, vertex_id[(i + 1, j + 1)], across))
    return bary_coords, cart_coords, triangles


def draw_ternary_background(scale=8, base_colors=BASE_TERNARY_COLORS, ax=None, annotate=False):
    """
    Draw the colour-coded ternary triangle that serves as legend/background.

    The triangle is split into ``scale**2`` cells; each cell is filled with the
    colour mixed from the barycentric centroid of its three corners. A black
    outline, white iso-lines at every 1/scale step and the three corner labels
    are drawn on top.

    Parameters
    ----------
    scale : int
        Number of subdivisions per triangle side.
    base_colors : array-like
        Corner colours handed to ``mix_color``.
    ax : matplotlib Axes, optional
        Axes to draw into; a new 6.5 x 6 inch figure is created when omitted.
    annotate : bool
        Write each cell's centroid shares (two decimals) inside the cell.

    Returns
    -------
    (fig, ax)
    """
    bary_coords, cart_coords, triangles = build_triangular_mesh(scale)

    if ax is not None:
        fig = ax.figure
    else:
        fig, ax = plt.subplots(figsize=(6.5, 6))

    cell_patches, cell_colors = [], []
    for corner_ids in triangles:
        outline = [cart_coords[v] for v in corner_ids]
        centroid = np.mean([bary_coords[v] for v in corner_ids], axis=0)
        fill = mix_color(*centroid, base_colors=base_colors, as_hex=False)
        cell_patches.append(Polygon(outline))
        cell_colors.append(fill)

        if not annotate:
            continue
        text_x, text_y = barycentric_to_cartesian(*centroid)
        ax.text(text_x, text_y,
                f"{centroid[0]:.2f}\n{centroid[1]:.2f}\n{centroid[2]:.2f}",
                ha="center", va="center", fontsize=6, color="black")

    ax.add_collection(
        PatchCollection(cell_patches, facecolors=cell_colors, edgecolors="k", linewidths=0.3)
    )

    # Closed outline: stochastic corner -> deep corner -> normative corner -> back.
    outline_corners = [(0, 0, 1), (0, 1, 0), (1, 0, 0), (0, 0, 1)]
    boundary = np.array([barycentric_to_cartesian(*corner) for corner in outline_corners])
    ax.plot(boundary[:, 0], boundary[:, 1], color="black", linewidth=1.25)

    # Three families of iso-lines: constant normative, constant deep, constant stochastic.
    grid_style = dict(color="white", alpha=0.6, linewidth=0.8)
    for i in range(1, scale):
        t = i / scale
        segments = (
            ((t, 0, 1 - t), (t, 1 - t, 0)),
            ((0, t, 1 - t), (1 - t, t, 0)),
            ((0, 1 - t, t), (1 - t, 0, t)),
        )
        for start, end in segments:
            p1 = barycentric_to_cartesian(*start)
            p2 = barycentric_to_cartesian(*end)
            ax.plot([p1[0], p2[0]], [p1[1], p2[1]], **grid_style)

    corner_labels = (
        (0.5, SQRT3 / 2 + 0.04, "Normative", "center", "bottom"),
        (-0.04, -0.03, "Deep", "right", "top"),
        (1.04, -0.03, "Stochastic", "left", "top"),
    )
    for label_x, label_y, caption, h_align, v_align in corner_labels:
        ax.text(label_x, label_y, caption, ha=h_align, va=v_align, fontsize=12)

    ax.set_xlim(-0.05, 1.05)
    ax.set_ylim(-0.05, SQRT3 / 2 + 0.08)
    ax.set_aspect("equal")
    ax.axis("off")
    return fig, ax


def add_regions_to_ternary(
    ax,
    df_year: pd.DataFrame,
    scale: int = 8,
    quantize: bool = True,
    annotate: bool = False,
    marker_size: float = 18,
    jitter_strength: float = 0.02,
    random_state: int = 0,
):
    """
    Scatter one black dot per row of ``df_year`` onto a ternary axes.

    Each row supplies ``Normative``, ``Deep`` and ``Stochastic`` shares. They
    are optionally snapped to the 1/scale lattice, then optionally perturbed
    with Gaussian noise in barycentric space (clipped to >= 1e-6 and
    renormalised) so that regions falling on the same lattice point stay
    distinguishable.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes; only ``scatter`` and ``text`` are called on it.
    df_year : pandas.DataFrame
        Needs the three share columns, plus ``Region`` when ``annotate`` is on.
    scale, quantize
        Lattice resolution and whether to snap to it.
    annotate : bool
        Label each dot with the region name (underscores shown as spaces).
    marker_size : float
        Scatter marker size.
    jitter_strength : float
        Standard deviation of the noise; values <= 0 disable jitter.
    random_state : int
        Seed of the generator created for this call; rows draw noise in
        DataFrame order.
    """
    jitter_on = jitter_strength > 0
    rng = np.random.default_rng(random_state) if jitter_on else None
    dot_style = dict(s=marker_size, marker=".", color="black", edgecolor="none")

    for _, region_row in df_year.iterrows():
        shares = (region_row["Normative"], region_row["Deep"], region_row["Stochastic"])
        if quantize:
            shares = quantize_simplex(*shares, scale=scale)

        position = np.array(shares, dtype=float)
        if jitter_on:
            position = position + rng.normal(scale=jitter_strength, size=3)
            # Keep the perturbed point strictly inside the simplex before renormalising.
            position = np.clip(position, 1e-6, None)
            position /= position.sum()

        x, y = barycentric_to_cartesian(*position)
        ax.scatter(x, y, **dot_style)

        if annotate:
            # ha / va are matplotlib's horizontal / vertical alignment.
            ax.text(x, y, region_row["Region"].replace("_", " "), fontsize=4, ha="left", va="top")


# =============================================================================
# 3. Choropleth helpers
# =============================================================================
def hex_to_rgb(hex_color: str) -> str:
    """
    Convert a six-digit hex colour (with or without leading '#') to the CSS-style
    string ``"rgb(r,g,b)"`` with channel values in 0-255.
    """
    digits = hex_color.lstrip("#")
    red, green, blue = (int(digits[start:start + 2], 16) for start in (0, 2, 4))
    return f"rgb({red},{green},{blue})"


def build_choropleth(
    df_year: pd.DataFrame,
    region_to_iso_path: str,
    scale: int = 8,
    quantize: bool = True,
    projection_type: str = "equal earth",
) -> go.Figure:
    """
    Build a world map in which each macro-region is filled with its ternary colour.

    Parameters
    ----------
    df_year : pandas.DataFrame
        One row per region with ``Region``, ``Normative``, ``Deep`` and
        ``Stochastic`` columns.
    region_to_iso_path : str
        JSON file mapping a region label (underscores replaced by spaces) to a
        list of ISO-3 country codes.
    scale, quantize
        When ``quantize`` is true the shares are snapped to the 1/scale lattice
        before the colour is mixed.
    projection_type : str
        Plotly geo projection name.

    Returns
    -------
    plotly.graph_objects.Figure

    Raises
    ------
    ValueError
        If no region of ``df_year`` is present in the mapping.

    Notes
    -----
    Regions absent from the JSON mapping are skipped and listed in a printed
    warning. Antarctica ("ATA") is never drawn.
    """
    with open(region_to_iso_path, "r", encoding="utf-8") as handle:
        region_to_iso = json.load(handle)

    country_rows = []
    unmatched = []

    for _, region_row in df_year.iterrows():
        label = region_row["Region"].replace("_", " ")
        if label not in region_to_iso:
            unmatched.append(label)
            continue

        n, d, s = region_row["Normative"], region_row["Deep"], region_row["Stochastic"]
        if quantize:
            n, d, s = quantize_simplex(n, d, s, scale=scale)
        fill_hex = mix_color(n, d, s, as_hex=True)

        # Every country of the region carries the region's shares and colour.
        country_rows.extend(
            {
                "iso_a3": iso3,
                "macro_region": label,
                "Normative": n,
                "Deep": d,
                "Stochastic": s,
                "color_hex": fill_hex,
            }
            for iso3 in region_to_iso[label]
            if iso3 != "ATA"  # skip Antarctica
        )

    if unmatched:
        print("[warn] Regions missing in JSON:", ", ".join(unmatched))

    map_df = pd.DataFrame(country_rows)
    if map_df.empty:
        raise ValueError("No regions found for this year that match the JSON mapping.")

    unique_regions = map_df["macro_region"].unique()
    region_to_idx = {name: position for position, name in enumerate(unique_regions)}
    map_df["region_idx"] = map_df["macro_region"].map(region_to_idx)

    # A step colourscale: region k owns the flat band [k/n, (k+1)/n], so the
    # integer region index used as z selects exactly that region's colour.
    n_regions = len(unique_regions)
    first_color = map_df.drop_duplicates("macro_region").set_index("macro_region")["color_hex"]
    colorscale = []
    for name in unique_regions:
        position = region_to_idx[name]
        band_rgb = hex_to_rgb(first_color[name])
        colorscale.append([position / n_regions, band_rgb])
        colorscale.append([(position + 1) / n_regions, band_rgb])

    border = dict(line=dict(color="rgba(255,255,255,0.7)", width=0.4))
    choropleth = go.Choropleth(
        locations=map_df["iso_a3"],
        z=map_df["region_idx"],
        text=map_df["macro_region"],
        hovertemplate="<b>%{text}</b><extra></extra>",
        colorscale=colorscale,
        showscale=False,
        marker=border,
    )

    fig = go.Figure(data=choropleth)
    fig.update_layout(
        showlegend=False,
        paper_bgcolor="#f8f8f8",
        plot_bgcolor="#f8f8f8",
        margin=dict(l=0, r=0, t=0, b=0),
        geo=dict(
            projection=dict(type=projection_type),
            showframe=False,
            showcoastlines=False,
            bgcolor="#f8f8f8",
            landcolor="#f8f8f8",
        ),
    )
    return fig


# =============================================================================
# 4. Master routine
# =============================================================================
def generate_uncertainty_visualizations(
    base_dir: str,
    region_mapping_path: str,
    stat: str = "raw",
    model_type: str = "final",
    years: Iterable[int] = (2030, 2050, 2070, 2100),
    ternary_scale: int = 8,
    quantize: bool = True,
    annotate_points: bool = False,
    marker_size: float = 18,
    jitter_strength: float = 0.02,
    random_state: int = 0,
    output_dir: str = "fig_ternary_choropleth",
):
    """
    Produce, for every requested year, a ternary plot and a matching choropleth.

    For each year with data this writes ``ternary_<year>.png`` (matplotlib,
    dpi=300) and ``choropleth_<year>.svg`` (plotly ``write_image``, 800x600)
    into ``output_dir``, creating the directory if needed.

    Parameters
    ----------
    base_dir, stat, model_type
        Forwarded to ``load_regional_uncertainty_shares``.
    region_mapping_path : str
        JSON region -> ISO-3 mapping used for the choropleth.
    years : iterable of int
        Years to render. Materialised once, so a one-shot iterable can serve
        both the loader and the per-year loop.
    ternary_scale, quantize
        Lattice resolution / snapping, shared by both plot types so that map
        colours correspond to ternary cells.
    annotate_points, marker_size, jitter_strength, random_state
        Forwarded to ``add_regions_to_ternary``. The same seed is used for
        every year.
    output_dir : str
        Destination folder.

    Returns
    -------
    (last_map, results)
        ``last_map`` is the plotly figure of the last year rendered (``None``
        if no year had data); ``results`` maps each rendered year to
        ``{"ternary": Path, "choropleth": Path}``.
    """
    # `years` is needed twice (loader filter + loop below); a generator would be
    # exhausted after the first use and the loop would silently render nothing.
    year_list = list(years)

    shares_df = load_regional_uncertainty_shares(
        base_dir=base_dir,
        stat=stat,
        model_type=model_type,
        years=year_list,
    )

    output_base = Path(output_dir)
    output_base.mkdir(parents=True, exist_ok=True)

    sns.set_theme(style="white")

    results = {}
    last_map = None
    for yr in year_list:
        df_year = shares_df[shares_df["Year"] == yr]
        if df_year.empty:
            print(f"[warn] No data for year {yr}. Skipping.")
            continue

        fig_tern, ax_tern = draw_ternary_background(scale=ternary_scale)
        add_regions_to_ternary(
            ax_tern,
            df_year,
            scale=ternary_scale,
            quantize=quantize,
            annotate=annotate_points,
            marker_size=marker_size,
            jitter_strength=jitter_strength,
            random_state=random_state,
        )
        ax_tern.set_title(f"Regional Uncertainty Composition — {yr}")
        ternary_path = output_base / f"ternary_{yr}.png"
        fig_tern.savefig(ternary_path, dpi=300, bbox_inches="tight")
        # Close explicitly: one figure per year would otherwise accumulate in memory.
        plt.close(fig_tern)

        fig_map = build_choropleth(
            df_year,
            region_to_iso_path=region_mapping_path,
            scale=ternary_scale,
            quantize=quantize,
        )
        svg_map_path = output_base / f"choropleth_{yr}.svg"
        fig_map.write_image(str(svg_map_path), format="svg", width=800, height=600)

        results[yr] = {"ternary": ternary_path, "choropleth": svg_map_path}
        last_map = fig_map

    return last_map, results


# =============================================================================
# 5. Example usage
# =============================================================================
# Driver: renders ternary plots and choropleths for four snapshot years from
# the regional "raw" SHAP CSVs and writes them to figs/. The guard is also true
# when the cell is executed inside a notebook kernel.
if __name__ == "__main__":
    # Both paths are relative, so they resolve against the current working directory.
    base_dir = "ml_importance_plots"
    region_mapping_path = "data/input/12_regions.json"
    fig_map, results = generate_uncertainty_visualizations(
        base_dir=base_dir,
        region_mapping_path=region_mapping_path,
        stat="raw",
        model_type="final",
        years=(2030, 2050, 2070, 2100),
        ternary_scale=8,
        quantize=True,
        annotate_points=True,
        marker_size=18,
        jitter_strength=0.02,
        random_state=0,
        output_dir="figs",
    )
    print("Saved figures:", results)
    # fig_map is the choropleth of the last year rendered; show it inline.
    fig_map.show()

Saved figures: {2030: {'ternary': PosixPath('figs/ternary_2030.png'), 'choropleth': PosixPath('figs/choropleth_2030.svg')}, 2050: {'ternary': PosixPath('figs/ternary_2050.png'), 'choropleth': PosixPath('figs/choropleth_2050.svg')}, 2070: {'ternary': PosixPath('figs/ternary_2070.png'), 'choropleth': PosixPath('figs/choropleth_2070.svg')}, 2100: {'ternary': PosixPath('figs/ternary_2100.png'), 'choropleth': PosixPath('figs/choropleth_2100.svg')}}
