In [None]:
from zoomin.data.osmtags import networks_tags_dict
from zoomin.data.constants import countries_dict
from zoomin.data import osm_stations_processing
from typing import Any
from shapely import wkt
from zoomin.data.osm_stations_processing import count_point_on_polygon, count_point_on_polygon_eu
import os
import pandas as pd
import geopandas as gpd
import osmnx as ox
import plotly.express as px
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

In [None]:
# Every data directory is resolved relative to the notebook's working directory.
cwd = os.getcwd()
DATA_PATH = os.path.join(cwd, *([".."] * 3), "data", "input")
# The raw and processed folders sit side by side under the input folder.
RAW_DATA_PATH, PROCESSED_DATA_PATH = (
    os.path.join(DATA_PATH, stage) for stage in ("raw", "processed")
)

In [None]:
def setup_polygon_for_point(territorial_unit: Any, country_tag: str) -> gpd.GeoDataFrame:
    """Load the region polygons of one country at a given territorial unit.

    Reads ``<PROCESSED_DATA_PATH>/shapefiles/<territorial_unit>.shp``, keeps the
    rows whose ``prnt_code`` contains ``country_tag`` and reduces the columns
    to the geometry plus the region identifier.

    Parameters
    ----------
    territorial_unit
        Name of the shapefile (without extension) to read, e.g. ``"LAU"``.
    country_tag
        Pattern looked up in ``prnt_code`` with ``str.contains`` (a substring /
        regex match, not an equality test).

    Returns
    -------
    gpd.GeoDataFrame
        Rows of the selected country with a fresh ``0..n-1`` index. Only
        columns whose name contains ``"geometry"`` or ``"code"`` and does not
        start with ``"prnt"`` are kept; a column named ``"code"`` is renamed
        to ``"region_code"``.
    """
    polygon_shp_path = os.path.join(
        PROCESSED_DATA_PATH, "shapefiles", f"{territorial_unit}.shp"
    )
    # NOTE(review): ``converters`` is a pandas.read_csv option; gpd.read_file
    # only forwards it to the I/O engine -- confirm it has any effect here.
    all_polygons = gpd.read_file(polygon_shp_path, converters={'region_code': str})

    # na=False: rows without a parent code are excluded instead of breaking
    # the boolean mask.
    in_country = all_polygons["prnt_code"].str.contains(f"{country_tag}", na=False)
    country_polygons = all_polygons[in_country]

    # Everything that is neither geometry nor a code column is dropped, and so
    # are the parent ("prnt*") columns that were only needed for the filter.
    unwanted_columns = [
        col
        for col in country_polygons.columns
        if ("geometry" not in col and "code" not in col) or col.startswith("prnt")
    ]
    polygon_gdf = (
        country_polygons
        .drop(columns=unwanted_columns)
        .rename(columns={"code": "region_code"})
        .reset_index(drop=True)
    )
    # Report the unit that was actually requested (the message used to say
    # "LAU" unconditionally).
    print(
        f"The number of polygon at {territorial_unit} level for {country_tag} are: ",
        len(polygon_gdf),
    )
    return polygon_gdf

In [None]:
def get_line(component_name: str, territorial_unit: Any, country_tag: str) -> pd.DataFrame:
    """Load the per-region values of one line component for one country.

    Reads
    ``<PROCESSED_DATA_PATH>/osm_data/countries/<country_tag>/<component_name>_overlap_df_<country_tag>_<territorial_unit>.csv``
    and keeps only the columns whose name contains ``region_code`` or
    ``value``. ``region_code`` is read as text.

    Parameters
    ----------
    component_name
        Component prefix of the CSV file name.
    territorial_unit
        Territorial-unit suffix of the CSV file name.
    country_tag
        Country folder name, also embedded in the CSV file name.

    Returns
    -------
    pd.DataFrame
        The filtered table.

    Raises
    ------
    FileNotFoundError
        If the CSV does not exist. (Previously this surfaced as an
        ``UnboundLocalError`` on the ``return`` statement.)
    """
    line_df_path_source = os.path.join(
        PROCESSED_DATA_PATH,
        "osm_data",
        "countries",
        f"{country_tag}",
        f"{component_name}_overlap_df_{country_tag}_{territorial_unit}.csv",
    )
    if not os.path.exists(line_df_path_source):
        raise FileNotFoundError(
            f"No overlap file for component '{component_name}' "
            f"at '{line_df_path_source}'"
        )

    line_df = pd.read_csv(line_df_path_source, converters={'region_code': str})
    # Substring match on purpose: any column mentioning region_code/value stays.
    unwanted_columns = [
        col
        for col in line_df.columns
        if "region_code" not in col and "value" not in col
    ]
    return line_df.drop(columns=unwanted_columns)
    

In [None]:
def merge_polygon_line(component_name,  territorial_unit, country_tag):
    """Attach the per-region line values to the region polygons of one country.

    If ``<component_name>_MergedGdf_<territorial_unit>.csv`` exists in the
    country's folder it is loaded and its WKT ``geometry`` column is parsed.
    Otherwise the polygons from ``setup_polygon_for_point`` are left-joined on
    ``region_code`` with the table from ``get_line``. This function only reads
    the merged CSV; it never writes it.

    Parameters
    ----------
    component_name
        Component whose values are attached.
    territorial_unit
        Territorial unit of the polygons and of the value table.
    country_tag
        Country folder / filter tag.

    Returns
    -------
    gpd.GeoDataFrame
        Polygons with their values. A freshly joined result is indexed by
        ``region_code``; a result loaded from the CSV keeps the default
        integer index.
    """
    overlap_gdf_path_destination = os.path.join(
        PROCESSED_DATA_PATH,
        "osm_data",
        "countries",
        f"{country_tag}",
        f"{component_name}_MergedGdf_{territorial_unit}.csv",
    )
    if os.path.exists(overlap_gdf_path_destination):
        merged_df = pd.read_csv(overlap_gdf_path_destination, converters={'region_code': str})
        # Geometries are stored as WKT text in the CSV.
        merged_df["geometry"] = gpd.GeoSeries.from_wkt(merged_df["geometry"])
        return gpd.GeoDataFrame(merged_df, geometry="geometry")

    polygon_gdf = setup_polygon_for_point(territorial_unit, country_tag)
    line_df = get_line(component_name, territorial_unit, country_tag)
    # Left join: every polygon is kept; regions without a value get NaN.
    merged_gdf = polygon_gdf.set_index('region_code').join(line_df.set_index('region_code'))
    print(
        f'The total number of "{component_name}" merged networks in "{country_tag}" are: ',
        len(merged_gdf),
    )
    # The former ``merged_gdf.sample(5)`` was removed: its result was discarded
    # and it raised ValueError for countries with fewer than five regions.
    return merged_gdf
    

In [None]:
def heat_map_plotting(component_name, territorial_unit, country_tag):
    """Draw a choropleth of a component's per-region values for one country.

    The data comes from ``merge_polygon_line``. The map and the colour bar
    share the same ``0 .. max(value)`` scale.

    Parameters
    ----------
    component_name
        Component to plot; also used in the title.
    territorial_unit
        Territorial unit of the regions; also used in the title.
    country_tag
        Country to plot; also used in the title.

    Returns
    -------
    None
        The return value of ``plt.show()``.
    """
    # https://www.geeksforgeeks.org/plotting-geospatial-data-using-geopandas/
    merged_gdf = merge_polygon_line(component_name,  territorial_unit, country_tag)
    fig, ax = plt.subplots(1, figsize=(16, 12), facecolor='lightblue')

    # One scale for both the map and the colour bar; previously the map was
    # scaled min..max while the colour bar showed 0..max.
    vmax = merged_gdf.value.max()

    # Plain black base layer first, then the value layer painted over it.
    merged_gdf.plot(ax=ax, color='black')
    merged_gdf.plot(
        ax=ax,
        column='value',
        cmap='Purples',
        vmin=0,
        vmax=vmax,
        edgecolors='grey',
    )

    # Narrow axes on the right-hand side for the colour bar.
    div = make_axes_locatable(ax)
    cax = div.append_axes("right", size="2 %", pad=0.05)
    mappable = plt.cm.ScalarMappable(
        cmap='Purples',
        norm=plt.Normalize(vmin=0, vmax=vmax),
    )
    fig.colorbar(mappable, cax=cax)

    # Title set on the map axes explicitly rather than through pyplot's
    # "current axes", which need not be the map once `cax` has been added.
    ax.set_title(
        f"{component_name} - {country_tag} - {territorial_unit} level",
        loc='center',
        fontweight='bold',
    )
    ax.axis('off')
    return plt.show()

In [None]:
# Territorial units offered by the prompt; the value selects which shapefile
# and which overlap CSVs the cells below read.
VALID_TERRITORIAL_UNITS = ("LAU", "NUTS3", "NUTS2", "NUTS1", "NUTS0", "Europe")

# Surrounding whitespace is stripped because the answer is embedded verbatim
# in file names.
territorial_unit = input(
        'Please enter a character from: LAU, NUTS3, NUTS2, NUTS1, NUTS0, Europe').strip()

# Fail here instead of silently matching no files further down.
if territorial_unit not in VALID_TERRITORIAL_UNITS:
    raise ValueError(
        f"Unknown territorial unit {territorial_unit!r}; "
        f"expected one of {', '.join(VALID_TERRITORIAL_UNITS)}"
    )

In [None]:
# for component_name in stations_tags_dict.keys():
#         print(component_name)
#         country_name_list = []
#         merged_list = []
#         for country_name, country_tag in countries_dict.items():
#             point_df_path_source = os.path.join(
#                 PROCESSED_DATA_PATH,
#                 "osm_data",
#                 "countries",
#                 f"{country_tag}",
#                 f"{component_name}_OverlapDf_{territorial_unit}.csv",
#             )
#             if os.path.exists(point_df_path_source):
#                 # destination_path = os.path.join(
#                 # PROCESSED_DATA_PATH,
#                 # "osm_data",
#                 # "countries",
#                 # "plotting",
#                 # f"{component_name}_eu_level.csv",
#                 # )
#                 # merged_gdf = merge_polygon_point(component_name,  territorial_unit, country_tag)
#                 # number_of_stations = merged_gdf["value"].sum()
#                 # merged_list.append(number_of_stations)
#                 # country_name_list.append(country_tag)
#                 plot = heat_map_plotting(component_name, territorial_unit, country_tag)
#         # eu_merged = pd.DataFrame(list(zip(country_name_list, merged_list)), columns=["country_tag", f"{component_name} value"])
#         # print(eu_merged.head())
#         # eu_merged.to_csv(destination_path)
# # merge_polygon_point(component_name,  territorial_unit, country_tag).head(10)

In [None]:
# For every network component: sum the per-region values of each country that
# has an overlap file and store the per-country totals as one CSV.
for component_name in networks_tags_dict.keys():
    print(component_name)
    # Resolved once per component, before the country loop. It used to be set
    # only when a country file was found, so a component without data either
    # hit a NameError or overwrote the previous component's CSV.
    destination_path = os.path.join(
        PROCESSED_DATA_PATH,
        "osm_data",
        "countries",
        "plotting",
        f"{component_name}_eu_level.csv",
    )
    country_name_list = []
    merged_list = []
    for country_name, country_tag in countries_dict.items():
        line_df_path_source = os.path.join(
            PROCESSED_DATA_PATH,
            "osm_data",
            "countries",
            f"{country_tag}",
            f"{component_name}_overlap_df_{country_tag}_{territorial_unit}.csv",
        )
        # Countries without an overlap file for this component/unit are skipped.
        if not os.path.exists(line_df_path_source):
            continue
        merged_gdf = merge_polygon_line(component_name,  territorial_unit, country_tag)
        # Despite its name this holds the country's summed "value" column.
        number_of_stations = merged_gdf["value"].sum()
        merged_list.append(number_of_stations)
        country_name_list.append(country_tag)

    eu_merged = pd.DataFrame(
        list(zip(country_name_list, merged_list)),
        columns=["country_tag", f"{component_name} value"],
    )
    print(eu_merged.head())
    if eu_merged.empty:
        # Nothing to store; an existing CSV is left untouched.
        print(f"No '{component_name}' data found at '{territorial_unit}' level - nothing written")
        continue
    # to_csv does not create missing folders.
    os.makedirs(os.path.dirname(destination_path), exist_ok=True)
    eu_merged.to_csv(destination_path)
# merge_polygon_point(component_name,  territorial_unit, country_tag).head(10)

In [None]:
# Per-country totals of the major road network, as stored in the "plotting" folder.
source_path = os.path.join(
    PROCESSED_DATA_PATH,
    "osm_data", "countries", "plotting",
    "road_major_network_eu_level.csv",
)
# The value column is loaded as text here; a later cell casts it to float.
road_major_network_df = pd.read_csv(
    source_path,
    converters={"road_major_network value": str},
)

In [None]:
# Turn the text value column into float64 so it sorts and plots numerically.
road_major_network_df = road_major_network_df.astype(
    {'road_major_network value': 'float64'}
)

In [None]:
# Shorten the value column name, keep only the country tag and the value
# column, and index the table by country tag.
road_major_network_df = road_major_network_df.rename(
    columns={'road_major_network value': 'road_major_network'}
)
road_major_network_df = road_major_network_df[
    [
        col
        for col in road_major_network_df.columns
        if "country_tag" in col or "road_major_network" in col
    ]
].set_index('country_tag')

In [None]:
# Preview: at most the first 27 rows.
road_major_network_df.head(n=27)

In [None]:
# Bar chart of the per-country totals, largest first.
road_major_network_df = road_major_network_df.sort_values('road_major_network', ascending=False)
# Labels are set on the axes returned by pandas instead of relying on
# pyplot's implicit "current axes".
ax = road_major_network_df.plot.bar(rot=0, figsize=(8, 6))
ax.set_title("Major Roads Network - OSM data")  # fixed typo: "Raods"
ax.set_xlabel("countries")
ax.set_ylabel("Meters of Road")

In [None]:
# Per-country totals of the bicycle network, as stored in the "plotting" folder.
source_path = os.path.join(
    PROCESSED_DATA_PATH,
    "osm_data", "countries", "plotting",
    "bicycle_network_eu_level.csv",
)
# The value column is loaded as text here; a later cell casts it to float.
bicycle_network_df = pd.read_csv(
    source_path,
    converters={"bicycle_network value": str},
)

In [None]:
# Turn the text value column into float64 so it sorts and plots numerically.
bicycle_network_df = bicycle_network_df.astype(
    {'bicycle_network value': 'float64'}
)

In [None]:
# Shorten the value column name, keep only the country tag and the value
# column, and index the table by country tag.
bicycle_network_df = bicycle_network_df.rename(
    columns={'bicycle_network value': 'bicycle_network'}
)
bicycle_network_df = bicycle_network_df[
    [
        col
        for col in bicycle_network_df.columns
        if "country_tag" in col or "bicycle_network" in col
    ]
].set_index('country_tag')

In [None]:
# Preview: at most the first 27 rows.
bicycle_network_df.head(n=27)

In [None]:
# Bar chart of the per-country totals, largest first.
bicycle_network_df = bicycle_network_df.sort_values(by='bicycle_network', ascending=False)
ax = bicycle_network_df.plot.bar(rot=0, figsize=(8, 6))
ax.set_title("Bicycle Network - OSM data")
ax.set_xlabel("countries")
ax.set_ylabel("Meters of bicycle road")

In [None]:
# Per-country totals of the railway network, as stored in the "plotting" folder.
source_path = os.path.join(
    PROCESSED_DATA_PATH,
    "osm_data", "countries", "plotting",
    "railways_network_eu_level.csv",
)
# The value column is loaded as text here; a later cell casts it to float.
railways_network_df = pd.read_csv(
    source_path,
    converters={"railways_network value": str},
)

In [None]:
# Turn the text value column into float64 so it sorts and plots numerically.
railways_network_df = railways_network_df.astype(
    {'railways_network value': 'float64'}
)

In [None]:
# Shorten the value column name, keep only the country tag and the value
# column, and index the table by country tag.
railways_network_df = railways_network_df.rename(
    columns={'railways_network value': 'railways_network'}
)
railways_network_df = railways_network_df[
    [
        col
        for col in railways_network_df.columns
        if "country_tag" in col or "railways_network" in col
    ]
].set_index('country_tag')

In [None]:
# Preview: at most the first 27 rows.
railways_network_df.head(n=27)

In [None]:
# Bar chart of the per-country totals, largest first.
railways_network_df = railways_network_df.sort_values(by='railways_network', ascending=False)
ax = railways_network_df.plot.bar(rot=0, figsize=(8, 6))
ax.set_title("Railways Network - OSM data")
ax.set_xlabel("countries")
ax.set_ylabel("Meters of railways")