In [None]:
import os 

import pandas as pd
import geopandas as gpd
from energyemissionsregio.config import DATA_PATH, SHP_PATH, units
from energyemissionsregio.disaggregation import perform_proxy_based_disaggregation
from energyemissionsregio.plotting_functions import plot_proxy_data, plot_target_data

In [None]:
# Capture the kernel's working directory once. The disaggregation loop builds its
# "../../data/..." read/write paths from this value, so the notebook must be
# started from a directory two levels below the one that contains data/.
cwd = os.getcwd()

In [None]:
# Boundary geometries used for plotting: the LAU-level layer and the NUTS2-level
# layer, both read from SHP_PATH (LAU first, then NUTS2).
lau_shp, nuts2_shp = (
    gpd.read_file(os.path.join(SHP_PATH, shp_file))
    for shp_file in ("LAU.shp", "NUTS2.shp")
)

In [None]:
# Keep only regions whose code starts with "DE" or "ES" in both geometry layers.
# Re-running this cell is harmless: filtering an already-filtered layer is a no-op.
lau_shp, nuts2_shp = (
    gdf[gdf["code"].str.startswith(("DE", "ES"))]
    for gdf in (lau_shp, nuts2_shp)
)

In [None]:
# One entry per target variable to disaggregate:
#   target_var             -> name of the target CSV / `units` key
#   proxy                  -> name of the proxy CSV / `units` key
#   proxy_confidence_level -> passed straight to perform_proxy_based_disaggregation
var_proxy_details = [
    {"target_var": target, "proxy": proxy, "proxy_confidence_level": level}
    for target, proxy, level in (
        ("number_of_motorcycles", "road_network", 2),
        ("air_transport_of_freight", "airports_cover", 4),
        ("air_transport_of_passengers", "airports_cover", 4),
    )
]

In [None]:
# Disaggregate every target variable in `var_proxy_details` from NUTS2 using its
# proxy: write the result to data/disaggregated_data/<target_var>.csv and save a
# proxy map and a target map under figures/disaggregation/NUTS2/<target_var>/.

# All locations are anchored on `cwd` so reads, writes and figures agree.
IMPUTED_DATA_DIR = os.path.join(cwd, "..", "..", "data", "imputed_data")
DISAGG_DATA_DIR = os.path.join(cwd, "..", "..", "data", "disaggregated_data")
FIG_ROOT = os.path.join(cwd, "..", "..", "figures", "disaggregation", "NUTS2")

COUNTRY_PREFIXES = ("DE", "ES")  # region_code prefixes kept in the analysis
KEEP_COLUMNS = ["region_code", "value", "value_confidence_level"]
VERY_HIGH_CONFIDENCE = 5


def _load_variable(var_name):
    """Read ``<var_name>.csv``, preferring the imputed copy over the raw one.

    Looks in ``IMPUTED_DATA_DIR`` first and falls back to ``DATA_PATH``.

    Returns
    -------
    (pandas.DataFrame, bool)
        The loaded table, and True when it came from ``IMPUTED_DATA_DIR``.
    """
    imputed_csv = os.path.join(IMPUTED_DATA_DIR, f"{var_name}.csv")
    if os.path.exists(imputed_csv):
        return pd.read_csv(imputed_csv), True
    return pd.read_csv(os.path.join(DATA_PATH, f"{var_name}.csv")), False


def _select_countries(df):
    """Return a copy of ``df`` limited to COUNTRY_PREFIXES rows and KEEP_COLUMNS.

    Rows with a missing ``region_code`` are dropped (``na=False``). Without it
    the boolean mask would contain NaN, which pandas refuses to index with.
    """
    in_scope = df["region_code"].str.startswith(COUNTRY_PREFIXES, na=False)
    return df.loc[in_scope, KEEP_COLUMNS].copy()


# DataFrame.to_csv does not create missing parent directories.
os.makedirs(DISAGG_DATA_DIR, exist_ok=True)

for proxy_detail_dict in var_proxy_details:
    target_var = proxy_detail_dict["target_var"]
    proxy_var = proxy_detail_dict["proxy"]
    proxy_confidence_level = proxy_detail_dict["proxy_confidence_level"]

    print(target_var)

    fig_path = os.path.join(FIG_ROOT, target_var)
    os.makedirs(fig_path, exist_ok=True)

    # --- target data ---
    target_data = _load_variable(target_var)[0]
    # NOTE(review): this overwrites any value_confidence_level already present in
    # an imputed target file, unlike the proxy branch below. Confirm intended.
    target_data["value_confidence_level"] = VERY_HIGH_CONFIDENCE  # VERY HIGH
    target_data = _select_countries(target_data)

    # --- proxy data ---
    proxy_data, proxy_is_imputed = _load_variable(proxy_var)
    # Imputed proxies already carry value_confidence_level (from data imputation stage).
    if not proxy_is_imputed:
        proxy_data["value_confidence_level"] = VERY_HIGH_CONFIDENCE  # VERY HIGH because no missing values
    proxy_data = _select_countries(proxy_data)
    proxy_data["value"] = proxy_data["value"].fillna(0)

    # --- plot proxy ---
    proxy_var_unit = units[proxy_var]
    save_path = os.path.join(fig_path, f"{proxy_var}.png")
    plot_proxy_data(proxy_data, lau_shp, proxy_var_unit, save_path)

    # --- disaggregate ---
    target_var_unit = units[target_var]
    # Integer rounding is requested only for count-type ("number") targets.
    round_to_int = target_var_unit == "number"
    disagg_data = perform_proxy_based_disaggregation(
        target_data, proxy_data, "NUTS2", proxy_confidence_level, round_to_int
    )
    disagg_data.to_csv(os.path.join(DISAGG_DATA_DIR, f"{target_var}.csv"), index=False)

    # --- plot target ---
    save_path = os.path.join(fig_path, f"{target_var}.png")
    plot_target_data(target_data, disagg_data, nuts2_shp, lau_shp, target_var_unit, save_path)
