In [None]:
import os 

import pandas as pd
import geopandas as gpd
from energyemissionsregio.config import DATA_PATH, SHP_PATH
from energyemissionsregio.utils import solve_proxy_equation, get_proxy_var_list
from energyemissionsregio.disaggregation import perform_proxy_based_disaggregation
from energyemissionsregio.plotting_functions import plot_validation_data
from sklearn.metrics import mean_squared_error

In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell execution and edits to the project package are picked up.
%load_ext autoreload
%autoreload 2

In [None]:
# Current working directory when this cell runs; used further down to locate
# ../../data/disaggregated_data relative to the notebook.
cwd = os.getcwd()

In [None]:
# Read the LAU shapefile and keep only the units whose code starts with "DE" or "ES".
lau_shp = gpd.read_file(os.path.join(SHP_PATH, "LAU.shp")).loc[
    lambda gdf: gdf["code"].str.startswith(("DE", "ES"))
]

### Commerce FEC (Eurostat)

In [None]:
# Commerce final energy consumption: keep the DE/ES rows and only the
# region code and value columns.
eurostat_data_nuts0 = (
    pd.read_csv(os.path.join(DATA_PATH, "final_energy_consumption_in_commerce.csv"))
    .loc[
        lambda df: df["region_code"].str.startswith(("DE", "ES")),
        ["region_code", "value"],
    ]
    .copy()
)

In [None]:
# Display the filtered Eurostat rows for a visual check.
eurostat_data_nuts0

### Hotmaps data

In [None]:
# Non-residential heat demand used as validation data: keep the DE/ES rows
# and only the region code and value columns.
validation_data = (
    pd.read_csv(os.path.join(DATA_PATH, "heat_demand_non_residential.csv"))
    .loc[
        lambda df: df["region_code"].str.startswith(("DE", "ES")),
        ["region_code", "value"],
    ]
    .copy()
)

In [None]:
# Collapse the validation data to country totals: truncate each region code to
# its two-letter country prefix and sum the values per prefix.
# validation_data itself is left untouched (assign returns a new frame).
target_data = (
    validation_data
    .assign(region_code=lambda df: df["region_code"].str[:2])
    .groupby("region_code", as_index=False)
    .sum()
)

In [None]:
# Attach a constant confidence level of 5 to every country total.
# NOTE(review): the meaning of level 5 is defined by the energyemissionsregio
# package and is not visible in this notebook — confirm.
target_data["value_confidence_level"] = 5

Difference between the Eurostat commerce data (the data that is normally disaggregated) and the Hotmaps data used to validate the disaggregation.

In [None]:
# Join Eurostat values (left, becomes value_x) with the aggregated Hotmaps
# totals (right, becomes value_y) on the country code.
diff_df = eurostat_data_nuts0.merge(target_data, on="region_code")

In [None]:
# Signed difference: Eurostat value (value_x) minus Hotmaps country total (value_y).
diff_df["diff"] = diff_df["value_x"].sub(diff_df["value_y"])

In [None]:
# Display the per-country comparison of Eurostat and Hotmaps totals.
diff_df

### Disaggregation of Hotmaps data

In [None]:
# Proxy equation per country. Each equation is a product of proxy variable
# names; the names are also used as CSV file stems when the proxies are loaded.
proxy_equations = dict(
    DE="de_non_residential_building_living_area*cproj_annual_mean_temperature_heating_degree_days",
    ES="es_number_of_commerical_and_service_companies*cproj_annual_mean_temperature_heating_degree_days",
)

In [None]:
def _load_proxy_data(proxy_var, country):
    """Load one proxy variable restricted to the regions of ``country``.

    A previously disaggregated version of the proxy (under
    ``<cwd>/../../data/disaggregated_data``) is preferred when it exists.
    Otherwise the proxy is read from ``DATA_PATH`` and given a constant
    ``value_confidence_level`` of 5.

    Parameters
    ----------
    proxy_var : str
        Proxy variable name; also the stem of the CSV file to read.
    country : str
        Region-code prefix (e.g. ``"DE"``) used to filter the rows.

    Returns
    -------
    pandas.DataFrame
        Columns ``region_code``, ``value`` and ``value_confidence_level``;
        missing ``value`` entries are replaced by 0.
    """
    # Build the path once instead of separately for the existence check and the read.
    disaggregated_path = os.path.join(
        cwd, "..", "..", "data", "disaggregated_data", f"{proxy_var}.csv"
    )
    if os.path.exists(disaggregated_path):
        proxy_data = pd.read_csv(disaggregated_path)
    else:
        proxy_data = pd.read_csv(os.path.join(DATA_PATH, f"{proxy_var}.csv"))
        proxy_data["value_confidence_level"] = 5

    in_country = proxy_data["region_code"].str.startswith(country)
    proxy_data = proxy_data.loc[
        in_country, ["region_code", "value", "value_confidence_level"]
    ].copy()
    proxy_data["value"] = proxy_data["value"].fillna(0)
    return proxy_data


# Disaggregate each country total separately with its own proxy equation.
disagg_data_list = []

for country in ["DE", "ES"]:
    sub_target_data = target_data[target_data["region_code"] == country].copy()

    proxy_equation = proxy_equations[country]

    # One filtered frame per proxy variable appearing in the equation.
    proxy_data_dict = {
        proxy_var: _load_proxy_data(proxy_var, country)
        for proxy_var in get_proxy_var_list(proxy_equation)
    }

    solved_proxy_data = solve_proxy_equation(proxy_equation, proxy_data_dict)

    # NOTE(review): the meaning of the "NUTS0" and 4 arguments is defined by
    # perform_proxy_based_disaggregation and is not visible here — confirm.
    disagg_data_list.append(
        perform_proxy_based_disaggregation(sub_target_data, solved_proxy_data, "NUTS0", 4)
    )

In [None]:
# Stack the per-country results into one frame. Row indices are kept as
# returned for each country (pd.concat does not renumber them by default).
disagg_data = pd.concat(disagg_data_list)

In [None]:
# Calculate the RMSE per country -------------
# Root-mean-square error between the validation values ("_true") and the
# disaggregated values ("_disagg"). The frame keeps its historical "_mae" name
# even though the metric computed here is the RMSE.
# how="outer" keeps regions that occur in only one of the two frames; the
# resulting NaNs make mean_squared_error raise instead of being silently ignored.
merged_df_mae = pd.merge(validation_data, disagg_data, on="region_code", how="outer", suffixes=("_true", "_disagg"))

# Compute each country mask once and reuse it for both value columns.
de_rows = merged_df_mae["region_code"].str.startswith("DE")
es_rows = merged_df_mae["region_code"].str.startswith("ES")

true_values_de = merged_df_mae.loc[de_rows, "value_true"]
disagg_values_de = merged_df_mae.loc[de_rows, "value_disagg"]

true_values_es = merged_df_mae.loc[es_rows, "value_true"]
disagg_values_es = merged_df_mae.loc[es_rows, "value_disagg"]

# Take the square root of the MSE explicitly: the `squared=False` argument is
# deprecated in scikit-learn 1.4 and removed in 1.6. The built-in round() works
# whether the metric comes back as a NumPy scalar or as a plain Python float.
rmse_de = round(mean_squared_error(true_values_de, disagg_values_de) ** 0.5, 2)
rmse_es = round(mean_squared_error(true_values_es, disagg_values_es) ** 0.5, 2)


In [None]:
# Country totals of the validation values, rounded to two decimals.
# Earlier hard-coded values, kept for reference:
#   de_total = "223.99e6", es_total = "40.83e6"
de_total = true_values_de.sum().round(2)
es_total = true_values_es.sum().round(2)

In [None]:
# Figure path, relative to the directory the notebook is run from.
fig_path = os.path.join(
    "..", "..", "figures", "disaggregation_validation", "validation_commerce_fec.png"
)

# Compare validation and disaggregated values on the LAU geometries.
# NOTE(review): "MWh", "Hotmaps" and "log" look like unit label, data-source
# label and colour/axis scale — confirm against plot_validation_data.
plot_validation_data(
    validation_data,
    disagg_data,
    lau_shp,
    de_total,
    es_total,
    rmse_de,
    rmse_es,
    "MWh",
    "Hotmaps",
    "log",
    fig_path,
)

In [None]:
# Inner join of disaggregated and validation values per region; overlapping
# columns get the "_disagg" / "_validation" suffixes.
merge_df = disagg_data.merge(
    validation_data,
    on="region_code",
    suffixes=("_disagg", "_validation"),
)

In [None]:
# Absolute deviation of the disaggregated value from the validation value.
merge_df["diff"] = merge_df["value_disagg"].sub(merge_df["value_validation"]).abs()

In [None]:
# German regions ordered by absolute deviation, smallest first.
merge_df[merge_df["region_code"].str.startswith("DE")].sort_values("diff")

In [None]:
# Spanish regions ordered by absolute deviation, smallest first.
merge_df[merge_df["region_code"].str.startswith("ES")].sort_values("diff")