In [1]:
from zoomin.data.osmtags import stations_tags_dict
from zoomin.data.constants import countries_dict
from zoomin.data import osm_stations_processing
from typing import Any
from shapely import wkt
from zoomin.data.osm_stations_processing import count_point_on_polygon, count_point_on_polygon_eu
import os
import pandas as pd
import geopandas as gpd
import osmnx as ox
import plotly.express as px
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

In [None]:
# All input data lives three directories above the notebook's working
# directory, under data/input/{raw,processed}.
cwd = os.getcwd()
DATA_PATH = os.path.join(cwd, *([os.pardir] * 3), "data", "input")
RAW_DATA_PATH, PROCESSED_DATA_PATH = (
    os.path.join(DATA_PATH, subdir) for subdir in ("raw", "processed")
)

In [None]:
def setup_polygon_for_point(territorial_unit: Any, country_tag: str) -> gpd.GeoDataFrame:
    """Load the region polygons of one country at a given territorial level.

    Reads ``<PROCESSED_DATA_PATH>/shapefiles/<territorial_unit>.shp``, keeps the
    rows whose ``prnt_code`` starts with ``country_tag`` and reduces the columns
    to the geometry plus the region's own code (renamed ``code`` ->
    ``region_code``).

    Parameters
    ----------
    territorial_unit : Any
        Territorial level, used as the shapefile name (e.g. "LAU", "NUTS3").
    country_tag : str
        Country code that the parent code of every region of that country
        starts with.

    Returns
    -------
    gpd.GeoDataFrame
        One row per region, re-indexed 0..n-1.
    """
    polygon_shp_path = os.path.join(
        PROCESSED_DATA_PATH, "shapefiles", f"{territorial_unit}.shp"
    )
    # No ``converters`` here: that is a pandas.read_csv option, not a
    # geopandas.read_file one, and the shapefile column is ``code`` anyway.
    polygon_gdf = gpd.read_file(polygon_shp_path)

    # Prefix match instead of substring match: ``str.contains("EE")`` would also
    # select German regions such as "DEE01". Assumes parent codes are NUTS-style
    # codes that begin with the country tag — TODO confirm for every level.
    # Missing parent codes count as "not in this country" instead of NaN.
    in_country = polygon_gdf["prnt_code"].str.startswith(f"{country_tag}", na=False)

    # Keep geometry and code columns, minus the parent-code ones.
    kept_columns = [
        col
        for col in polygon_gdf.columns
        if ("geometry" in col or "code" in col) and not col.startswith("prnt")
    ]
    polygon_gdf = (
        polygon_gdf.loc[in_country, kept_columns]
        .rename(columns={"code": "region_code"})
        .reset_index(drop=True)
    )
    print(f'The number of polygons at {territorial_unit} level for {country_tag} are: ', len(polygon_gdf))
    return polygon_gdf

In [None]:
def get_point(component_name: str, territorial_unit: Any, country_tag: str) -> pd.DataFrame:
    """Load the per-region point counts of one component for one country.

    Reads ``<PROCESSED_DATA_PATH>/osm_data/countries/<country_tag>/
    <component_name>_OverlapDf_<territorial_unit>.csv`` and keeps only the
    columns whose names contain ``region_code`` or ``value``.

    Parameters
    ----------
    component_name : str
        Component whose overlap file is read (first part of the file name).
    territorial_unit : Any
        Territorial level (last part of the file name).
    country_tag : str
        Country folder to read from.

    Returns
    -------
    pd.DataFrame
        The ``region_code`` / ``value`` columns of the overlap file.

    Raises
    ------
    FileNotFoundError
        If the overlap CSV does not exist.
    """
    point_df_path_source = os.path.join(
        PROCESSED_DATA_PATH,
        "osm_data",
        "countries",
        f"{country_tag}",
        f"{component_name}_OverlapDf_{territorial_unit}.csv",
    )
    if not os.path.exists(point_df_path_source):
        raise FileNotFoundError(
            f'No overlap data for "{component_name}" in "{country_tag}": {point_df_path_source}'
        )
    # Keep region codes as strings rather than letting pandas parse them.
    point_df = pd.read_csv(point_df_path_source, converters={'region_code': str})
    point_df = point_df[
        [col for col in point_df.columns if "region_code" in col or "value" in col]
    ]
    print(
        f'The total number of "{component_name}" in "{country_tag}" are: ',
        len(point_df),
    )
    return point_df
    

In [None]:
def merge_polygon_point(component_name,  territorial_unit, country_tag):
    """Attach per-region point counts to the region polygons of one country.

    If ``<component_name>_MergedGdf_<territorial_unit>.csv`` already exists in
    the country's osm_data folder, it is loaded (geometry parsed from WKT).
    Otherwise the country's polygons, reprojected to EPSG:4326, are left-joined
    on ``region_code`` with the overlap counts from ``get_point``. Regions
    without any point get a ``value`` of 0.

    Returns
    -------
    gpd.GeoDataFrame
        Region polygons with a ``value`` column.
    """
    overlap_gdf_path_destination = os.path.join(
        PROCESSED_DATA_PATH,
        "osm_data",
        "countries",
        f"{country_tag}",
        f"{component_name}_MergedGdf_{territorial_unit}.csv",
    )
    if os.path.exists(overlap_gdf_path_destination):
        # Reuse a previously saved merge; geometry is stored as WKT text.
        merged_df = pd.read_csv(overlap_gdf_path_destination, converters={'region_code': str})
        merged_df["geometry"] = gpd.GeoSeries.from_wkt(merged_df["geometry"])
        return gpd.GeoDataFrame(merged_df, geometry="geometry")

    polygon_gdf = setup_polygon_for_point(territorial_unit, country_tag)
    if polygon_gdf.crs != 4326:
        polygon_gdf = polygon_gdf.to_crs(epsg=4326)
    point_df = get_point(component_name, territorial_unit, country_tag)
    # Left join: every polygon is kept, whether or not it has points.
    merged_gdf = polygon_gdf.set_index('region_code').join(point_df.set_index('region_code'))
    print(
        f'The total number of "{component_name}" merged points in "{country_tag}" are: ',
        len(merged_gdf),
    )
    # Regions with no matching points are NaN after the join; count them as 0.
    merged_gdf["value"] = merged_gdf["value"].fillna(0)
    # Saving the merge is disabled, so the cache branch above only triggers
    # for files produced outside this function.
    # merged_gdf.to_csv(overlap_gdf_path_destination)
    return merged_gdf
    

In [None]:
def heat_map_plotting(component_name, territorial_unit, country_tag):
    """Draw a choropleth of per-region ``value`` counts for one country.

    Regions are coloured with the ``Reds`` colormap on a fixed 0..max(value)
    scale, with a matching colour bar to the right of the map.

    Returns the result of ``plt.show()``.
    """
    # https://www.geeksforgeeks.org/plotting-geospatial-data-using-geopandas/
    merged_gdf = merge_polygon_point(component_name,  territorial_unit, country_tag)
    vmax = merged_gdf.value.max()

    fig, ax = plt.subplots(1, figsize=(16, 12), facecolor='lightblue')

    # Black base layer, then regions coloured by value. vmin/vmax are pinned so
    # the map uses the same 0..vmax scale as the colour bar; by default geopandas
    # would scale to the column's own min..max.
    merged_gdf.plot(ax=ax, color='black')
    merged_gdf.plot(ax=ax, column='value', cmap='Reds',
                    edgecolors='grey', vmin=0, vmax=vmax)

    # Colour bar in its own axes, attached to the right of the map.
    div = make_axes_locatable(ax)
    cax = div.append_axes("right", size="2 %", pad=0.05)
    mappable = plt.cm.ScalarMappable(cmap='Reds',
                                     norm=plt.Normalize(vmin=0, vmax=vmax))
    fig.colorbar(mappable, cax)

    # Title the map axes explicitly: after append_axes the *current* axes is the
    # colour-bar axes, so plt.title would put the title on the colour bar.
    ax.set_title(f"{component_name} - {country_tag} - {territorial_unit} level", loc='right', fontweight='bold')
    ax.axis('off')
    return plt.show()

In [None]:
# Territorial level to process. It selects the shapefile and the overlap CSV
# names used below. Validate it up front so a typo fails here, instead of
# silently matching no files later.
VALID_TERRITORIAL_UNITS = ("LAU", "NUTS3", "NUTS2", "NUTS1", "NUTS0", "Europe")
territorial_unit = input(
    f"Please enter a territorial unit from: {', '.join(VALID_TERRITORIAL_UNITS)}: "
).strip()
if territorial_unit not in VALID_TERRITORIAL_UNITS:
    raise ValueError(
        f"Unknown territorial unit {territorial_unit!r}; "
        f"expected one of {VALID_TERRITORIAL_UNITS}"
    )

In [None]:
# for component_name in stations_tags_dict.keys():
#         print(component_name)
#         country_name_list = []
#         merged_list = []
#         for country_name, country_tag in countries_dict.items():
#             point_df_path_source = os.path.join(
#                 PROCESSED_DATA_PATH,
#                 "osm_data",
#                 "countries",
#                 f"{country_tag}",
#                 f"{component_name}_OverlapDf_{territorial_unit}.csv",
#             )
#             if os.path.exists(point_df_path_source):
#                 # destination_path = os.path.join(
#                 # PROCESSED_DATA_PATH,
#                 # "osm_data",
#                 # "countries",
#                 # "plotting",
#                 # f"{component_name}_eu_level.csv",
#                 # )
#                 # merged_gdf = merge_polygon_point(component_name,  territorial_unit, country_tag)
#                 # number_of_stations = merged_gdf["value"].sum()
#                 # merged_list.append(number_of_stations)
#                 # country_name_list.append(country_tag)
#                 plot = heat_map_plotting(component_name, territorial_unit, country_tag)
#         # eu_merged = pd.DataFrame(list(zip(country_name_list, merged_list)), columns=["country_tag", f"{component_name} value"])
#         # print(eu_merged.head())
#         # eu_merged.to_csv(destination_path)
# # merge_polygon_point(component_name,  territorial_unit, country_tag).head(10)

In [None]:
# For every component, sum the per-region counts of each country that has an
# overlap file, and save one <component>_eu_level.csv under
# osm_data/countries/plotting/.
for component_name in stations_tags_dict.keys():
    print(component_name)
    # The destination depends only on the component. Resolve it before the
    # country loop so it can never be left over from a previous component.
    destination_path = os.path.join(
        PROCESSED_DATA_PATH,
        "osm_data",
        "countries",
        "plotting",
        f"{component_name}_eu_level.csv",
    )
    country_name_list = []
    merged_list = []
    for country_tag in countries_dict.values():
        point_df_path_source = os.path.join(
            PROCESSED_DATA_PATH,
            "osm_data",
            "countries",
            f"{country_tag}",
            f"{component_name}_OverlapDf_{territorial_unit}.csv",
        )
        if not os.path.exists(point_df_path_source):
            continue
        merged_gdf = merge_polygon_point(component_name, territorial_unit, country_tag)
        merged_list.append(merged_gdf["value"].sum())
        country_name_list.append(country_tag)
    eu_merged = pd.DataFrame(
        list(zip(country_name_list, merged_list)),
        columns=["country_tag", f"{component_name} value"],
    )
    print(eu_merged.head())
    if eu_merged.empty:
        # Don't write (or overwrite) a file with no country data.
        print(f'No overlap data found for "{component_name}"; nothing written.')
        continue
    os.makedirs(os.path.dirname(destination_path), exist_ok=True)
    eu_merged.to_csv(destination_path)

In [None]:
# Summary table of the last component processed by the aggregation loop.
eu_merged.iloc[:27]

In [None]:
source_path = os.path.join(
    PROCESSED_DATA_PATH,
    "osm_data",
    "countries",
    "plotting",
    "fuel_stations_eu_level.csv",
)
# Let pandas parse the station counts as numbers. Reading them as str would make
# the later sort lexicographic and leave no numeric fuel column for the bar
# chart, because nothing converts this column back to float.
fuel_stations_df = pd.read_csv(source_path)

In [None]:
source_path = os.path.join(
    PROCESSED_DATA_PATH, "osm_data", "countries", "plotting", "charging_stations_eu_level.csv"
)
# Station counts are kept as text here; they are cast to float64 after being
# copied onto fuel_stations_df.
charging_stations_df = pd.read_csv(
    source_path, converters=dict.fromkeys(["charging_stations value"], str)
)

In [None]:
fuel_stations_df.head(n=5)

In [None]:
charging_stations_df.head(n=5)

In [None]:
# Align charging counts to fuel counts by country, not by row position. Each CSV
# only lists the countries that had data for that component, so the two tables
# need not share rows or order. Countries without charging data get NaN.
fuel_stations_df['charging_stations value'] = fuel_stations_df['country_tag'].map(
    charging_stations_df.set_index('country_tag')['charging_stations value']
)

In [None]:
# plotdata = pd.DataFrame({
#     "fuel_stations value":[fuel_stations_df['fuel_stations value']],
#     "charging_stations":[charging_stations_df['charging_stations value']],
#     }, 
#     index=[fuel_stations_df['country_tag']]
# )

In [None]:
fuel_stations_df.head(n=5)

In [None]:
type(fuel_stations_df.loc[0, 'charging_stations value'])

In [None]:
# The fuel count column is "value" in some files, but the aggregation loop writes
# "fuel_stations value"; inspect whichever one is present.
fuel_value_col = 'value' if 'value' in fuel_stations_df.columns else 'fuel_stations value'
type(fuel_stations_df[fuel_value_col][0])

In [None]:
fuel_stations_df = fuel_stations_df.astype({'charging_stations value': 'float64'})

In [None]:
# Normalise column names for plotting. The fuel count is "value" in some files
# and "fuel_stations value" in files written by the aggregation loop; both become
# "fuel_stations", so the sort below can find the column.
fuel_stations_df = (
    fuel_stations_df
    .rename(columns={
        'value': 'fuel_stations',
        'fuel_stations value': 'fuel_stations',
        'charging_stations value': 'charging_stations',
    })
    .loc[:, ['country_tag', 'fuel_stations', 'charging_stations']]
    .set_index('country_tag')
)

In [None]:
fuel_stations_df.head(n=5)

In [None]:
# Alias the cleaned fuel/charging table as the data for the bar chart below.
plotdata = fuel_stations_df

In [None]:
# Countries ordered by fuel-station count, fuel and charging bars side by side.
plotdata = plotdata.sort_values(by='fuel_stations', ascending=False)
bar_ax = plotdata.plot.bar(rot=0, figsize=(8, 6))
bar_ax.set_title("Fuel Stations & Charging Stations - OSM data")
bar_ax.set_xlabel("countries")
bar_ax.set_ylabel("# of stations")

In [None]:
source_path = os.path.join(
    PROCESSED_DATA_PATH, "osm_data", "countries", "plotting", "bus_stations_eu_level.csv"
)
# Counts are read as text and cast to float64 separately.
bus_stations_df = pd.read_csv(
    source_path, converters=dict.fromkeys(["bus_stations value"], str)
)

In [None]:
bus_stations_df = bus_stations_df.astype({'bus_stations value': 'float64'})

In [None]:
# Rename the count column, keep only country and count columns, index by country.
bus_stations_df = (
    bus_stations_df
    .rename(columns={'bus_stations value': 'bus_stations'})
    .pipe(lambda frame: frame[[c for c in frame.columns if "country_tag" in c or "bus_stations" in c]])
    .set_index('country_tag')
)

In [None]:
bus_stations_df.iloc[:27]

In [None]:
# Countries ordered by bus-station count.
plotdata = bus_stations_df.sort_values(by='bus_stations', ascending=False)
bar_ax = plotdata.plot.bar(rot=0, figsize=(8, 6))
bar_ax.set_title("bus_stations - OSM data")
bar_ax.set_xlabel("countries")
bar_ax.set_ylabel("# of stations")

In [None]:
source_path = os.path.join(
    PROCESSED_DATA_PATH, "osm_data", "countries", "plotting", "airport_stations_eu_level.csv"
)
# Counts are read as text and cast to float64 separately.
airport_stations_df = pd.read_csv(
    source_path, converters=dict.fromkeys(["airport_stations value"], str)
)

In [None]:
airport_stations_df = airport_stations_df.astype({'airport_stations value': 'float64'})

In [None]:
# Rename the count column, keep only country and count columns, index by country.
airport_stations_df = (
    airport_stations_df
    .rename(columns={'airport_stations value': 'airport_stations'})
    .pipe(lambda frame: frame[[c for c in frame.columns if "country_tag" in c or "airport_stations" in c]])
    .set_index('country_tag')
)

In [None]:
airport_stations_df.iloc[:27]

In [None]:
# Countries ordered by airport count.
plotdata = airport_stations_df.sort_values(by='airport_stations', ascending=False)
bar_ax = plotdata.plot.bar(rot=0, figsize=(8, 6))
bar_ax.set_title("airport_stations - OSM data")
bar_ax.set_xlabel("countries")
bar_ax.set_ylabel("# of stations")

In [None]:
# # https://www.geeksforgeeks.org/plotting-geospatial-data-using-geopandas/
 
# fig, ax = plt.subplots(1, figsize =(16, 12),
#                        facecolor ='lightblue')
  
# merged_gdf.plot(ax = ax, color ='black')
# merged_gdf.plot(ax = ax, column ='value', cmap ='Reds',
#            edgecolors ='grey')
  
# # axis for the color bar
# div = make_axes_locatable(ax)
# cax = div.append_axes("right", size ="3 %", pad = 0.05)
  
# # color bar
# vmax = merged_gdf.value.max()
# mappable = plt.cm.ScalarMappable(cmap ='Reds',
#                                  norm = plt.Normalize(vmin = 0, vmax = vmax))
# cbar = fig.colorbar(mappable, cax)
 
# plt.title(f"Fuel Stations - Germany - LAU level", loc='right', fontweight='bold')  
# ax.axis('off')
# plt.show()

In [None]:
# for component_name in stations_tags_dict.keys():
#         print(component_name)
#         for country_name, country_tag in countries_dict.items():
#                 point_df_path_source = os.path.join(PROCESSED_DATA_PATH, 'osm_data', 
#                                         'countries', f'{country_tag}',
#                                         f"{component_name}_{country_tag}_from_place.csv")
#                 if os.path.exists(point_df_path_source): 
#                         # point_gdf = setupPointGdf(component_name, territorial_unit)
#                         # print(f"The total number of stations in the EU27 are: ", len(point_gdf))
#                         # print(point_gdf.sample(5))
#                         # overlap_gdf = overlapPointandPolygon(component_name, territorial_unit)
#                         # print(f"The total number of {component_name} in the EU27 are: ", len(overlap_gdf))
#                         # print(overlap_gdf.sample(5))
#                         overlap_df = count_point_on_polygon(component_name, territorial_unit, country_name, country_tag)
#                         print(f"The number of \"{territorial_unit}\" regions mapped by \"{component_name}\" at \"{country_tag}\" are: ", len(overlap_df))
#                         try:
#                                 print(overlap_df.sample(5))
#                         except:
#                                 continue
#                 else:
#                         continue  