In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import geopandas as gpd
from src.organized_datasets_creation.utils import resolve_nominatim_city_name
from src.graph_layering.create_dataframes import create_osmnx_dataframes
from src.organized_datasets_creation.utils import convert_nominatim_name_to_filename
from src.graph_layering.graph_layer_creator import GraphLayerController
import pandas as pd
from typing import cast
import os
from src.graph_layering.graph_layer_creator import SourceType
import warnings
from src.graph_layering.create_hetero_data import create_hetero_data

from tqdm import tqdm

from tqdm import tqdm
import wandb.util
import wandb

In [3]:
# Root directory of the project's data. It can be overridden with the
# GRADIENT_DATA_DIR environment variable so the notebook is not tied to one
# machine; the fallback keeps the original location, so the paths below are
# unchanged when the variable is not set.
DATA_DIR = os.environ.get(
    "GRADIENT_DATA_DIR", "/home/staszek/mgr/gradient/gradient/data"
)

# OSM extract of the Wrocław map.
GRAPH_LOCATION = os.path.join(DATA_DIR, "wro", "wro-map.osm")
# CSV file with the accidents data.
ACCIDENTS_LOCATION = os.path.join(DATA_DIR, "wypadki-pl", "accidents.csv")
# Directory holding the per-city organized datasets read by `create_gdfs`.
ORGANIZED_DATASETS_LOCATION = os.path.join(DATA_DIR, "organized-datasets")

In [4]:
# Read the accidents file with geopandas.
# NOTE(review): `accidents` is not referenced again in the visible part of this
# notebook (`create_gdfs` passes the file path, not this frame) — confirm it is
# still needed before keeping this load.
accidents = gpd.read_file(ACCIDENTS_LOCATION)

In [5]:
# City query strings; each one is passed to `create_gdfs`, which resolves it
# with `resolve_nominatim_city_name`.
# The order matters: `gdfs_dict` and `data_dict` are built in this order, and
# the cross-validation folds are derived from the key order of `data_dict`.
cities = [
    "Wrocław, Poland",
    "Warsaw, Poland",
    "Szczecin, Poland",
    "Poznań, Poland",
    "Kraków, Poland",
]

In [6]:
def create_gdfs(city_name: str, h3_resolution: int = 9, year: int = 2017):
    """Load the OSMnx node/edge frames and the hex-region frame for one city.

    Args:
        city_name: City query string; it is handed to `create_osmnx_dataframes`
            and, after Nominatim resolution, used to locate the city's
            organized dataset on disk.
        h3_resolution: Resolution used in the dataset directory name (`h<res>`).
        year: Year used in the dataset directory name.

    Returns:
        A dict with the keys ``"osmnx_nodes"``, ``"osmnx_edges"`` and ``"hexes"``.
    """
    osmnx_nodes, osmnx_edges = create_osmnx_dataframes(ACCIDENTS_LOCATION, city_name)

    city_dir = convert_nominatim_name_to_filename(
        resolve_nominatim_city_name(city_name)
    )
    dataset_path = os.path.join(
        ORGANIZED_DATASETS_LOCATION,
        f"{city_dir}/{year}/h{h3_resolution}/count-embedder/dataset.parquet",
    )
    raw_hexes: gpd.GeoDataFrame = gpd.read_parquet(dataset_path)

    # The stored `region_id` column becomes `h3_id`, while the index axis takes
    # over the name `region_id`.
    hexes = raw_hexes.rename(columns={"region_id": "h3_id"})
    hexes = hexes.rename_axis("region_id", axis=0)
    # The dataset's own accident aggregation is discarded: a different
    # aggregation type is used later instead of the one stored in the dataset.
    hexes = hexes.drop(columns="accidents_count")

    return {"osmnx_nodes": osmnx_nodes, "osmnx_edges": osmnx_edges, "hexes": hexes}


# Build the frames for every city (progress shown by tqdm), keyed by city name.
with warnings.catch_warnings():
    # All warnings raised while the frames are being created are suppressed.
    warnings.simplefilter("ignore")
    gdfs_dict = dict(
        (city_name, create_gdfs(city_name)) for city_name in tqdm(cities)
    )

100%|██████████| 5/5 [03:16<00:00, 39.23s/it]


# Dropping columns that are not shared by all cities


In [7]:
def get_presence_df(gdfs_dict, tested_df_name):
    """Build a city-by-column presence table for one kind of frame.

    Args:
        gdfs_dict: Mapping ``city name -> {frame name -> DataFrame}``.
        tested_df_name: Key of the frame whose columns are inspected in every
            city's dict.

    Returns:
        A DataFrame indexed by ``city_name`` with one column per column name
        seen in any city; each cell counts how many times that column name
        occurs in the city's frame.
    """
    # One (city, [column names]) record per city.
    columns_per_city = [
        (city_name, city_frames[tested_df_name].columns.to_list())
        for city_name, city_frames in gdfs_dict.items()
    ]
    # Long format: one row per (city, column name) pair.
    long_df = pd.DataFrame(columns_per_city, columns=["city_name", "col"]).explode(
        "col"
    )
    # One indicator column per column name, then collapse back to one row per city.
    indicators = pd.get_dummies(long_df, columns=["col"], prefix="", prefix_sep="")
    return indicators.groupby("city_name").sum()


def filter_presence_df(df):
    """Keep only the columns that are present for every city of a presence table.

    The decision is made from ``df`` alone (every row must have a positive
    count), so the function no longer depends on the notebook-level ``cities``
    list and still works when the table covers a different set of cities.

    Args:
        df: Presence table as returned by `get_presence_df` (one row per city,
            one column per column name, cells holding occurrence counts).

    Returns:
        The sub-table of ``df`` restricted to the columns whose count is
        positive in every row.
    """
    present_everywhere = (df > 0).all(axis=0)
    return df.loc[:, present_everywhere]


def get_common_columns(gdfs_dict, tested_df_name):
    """Return the presence table limited to columns shared by all cities.

    Args:
        gdfs_dict: Mapping ``city name -> {frame name -> DataFrame}``.
        tested_df_name: Key of the frame to inspect in every city's dict.

    Returns:
        The output of `get_presence_df` after `filter_presence_df` has removed
        the columns that are not present everywhere; its ``.columns`` are the
        common column names.
    """
    return filter_presence_df(get_presence_df(gdfs_dict, tested_df_name))

In [8]:
# Presence table restricted to the columns found in every city's OSMnx node
# frame; its `.columns` are used later to align the node frames.
df_osmnx_node_common_columns = get_common_columns(gdfs_dict, "osmnx_nodes")
df_osmnx_node_common_columns

Unnamed: 0_level_0,accidents_count,crossing,geometry,mini_roundabout,motorway_junction,osmid,street_count,traffic_signals,turning_circle,x,y
city_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
"Kraków, Poland",1,1,1,1,1,1,1,1,1,1,1
"Poznań, Poland",1,1,1,1,1,1,1,1,1,1,1
"Szczecin, Poland",1,1,1,1,1,1,1,1,1,1,1
"Warsaw, Poland",1,1,1,1,1,1,1,1,1,1,1
"Wrocław, Poland",1,1,1,1,1,1,1,1,1,1,1


In [9]:
# Presence table restricted to the columns found in every city's OSMnx edge
# frame; its `.columns` are used later to align the edge frames.
df_osmnx_edge_common_columns = get_common_columns(gdfs_dict, "osmnx_edges")
df_osmnx_edge_common_columns

Unnamed: 0_level_0,access_0,access_destination,access_no,access_permissive,access_yes,bridge_0,bridge_viaduct,bridge_yes,geometry,highway_living_street,...,length,maxspeed,oneway,reversed,tunnel_0,tunnel_building_passage,tunnel_yes,u,v,width
city_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
"Kraków, Poland",1,1,1,1,1,1,1,1,1,1,...,1,1,1,1,1,1,1,1,1,1
"Poznań, Poland",1,1,1,1,1,1,1,1,1,1,...,1,1,1,1,1,1,1,1,1,1
"Szczecin, Poland",1,1,1,1,1,1,1,1,1,1,...,1,1,1,1,1,1,1,1,1,1
"Warsaw, Poland",1,1,1,1,1,1,1,1,1,1,...,1,1,1,1,1,1,1,1,1,1
"Wrocław, Poland",1,1,1,1,1,1,1,1,1,1,...,1,1,1,1,1,1,1,1,1,1


In [10]:
# Presence table restricted to the columns found in every city's hex frame;
# its `.columns` are used later to align the hex frames.
df_hexes_common_columns = get_common_columns(gdfs_dict, "hexes")
df_hexes_common_columns

Unnamed: 0_level_0,aeroway_aerodrome,aeroway_helipad,aeroway_runway,amenity_animal_shelter,amenity_arts_centre,amenity_atm,amenity_bank,amenity_bar,amenity_bbq,amenity_bench,...,water_pond,water_reservoir,water_river,water_wastewater,waterway_canal,waterway_ditch,waterway_drain,waterway_river,waterway_stream,waterway_weir
city_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
"Kraków, Poland",1,1,1,1,1,1,1,1,1,1,...,1,1,1,1,1,1,1,1,1,1
"Poznań, Poland",1,1,1,1,1,1,1,1,1,1,...,1,1,1,1,1,1,1,1,1,1
"Szczecin, Poland",1,1,1,1,1,1,1,1,1,1,...,1,1,1,1,1,1,1,1,1,1
"Warsaw, Poland",1,1,1,1,1,1,1,1,1,1,...,1,1,1,1,1,1,1,1,1,1
"Wrocław, Poland",1,1,1,1,1,1,1,1,1,1,...,1,1,1,1,1,1,1,1,1,1


In [11]:
# Columns shared by all cities, per kind of frame.
common_columns_by_frame = {
    "osmnx_nodes": df_osmnx_node_common_columns.columns,
    "osmnx_edges": df_osmnx_edge_common_columns.columns,
    "hexes": df_hexes_common_columns.columns,
}

for gdf_for_city in gdfs_dict.values():
    # `reindex(columns=...)` both discards the columns that are not shared by
    # all cities and puts the remaining ones in the same order for every city,
    # so no separate in-place `drop` is needed and the loaded frames are not
    # mutated.
    for frame_name, common_columns in common_columns_by_frame.items():
        gdf_for_city[frame_name] = gdf_for_city[frame_name].reindex(
            columns=common_columns
        )

    # The controller is built from the aligned frames and stored next to them.
    gdf_for_city["controller"] = GraphLayerController(
        gdf_for_city["hexes"], gdf_for_city["osmnx_nodes"], gdf_for_city["osmnx_edges"]
    )


  self.hexes_gdf.centroid, columns=["centroid_geometry"]

  self.hexes_gdf.centroid, columns=["centroid_geometry"]

  self.hexes_gdf.centroid, columns=["centroid_geometry"]

  self.hexes_gdf.centroid, columns=["centroid_geometry"]

  self.hexes_gdf.centroid, columns=["centroid_geometry"]


In [12]:
def patch_hexes_with_y(
    osmnx_nodes: gpd.GeoDataFrame,
    hexes: gpd.GeoDataFrame,
    controller: GraphLayerController,
):
    """Attach a binary accident label to the hexes and store them on the controller.

    The ``accidents_count`` of the OSMnx nodes is summed per ``region_id`` through
    the controller's virtual node-to-hex edges; a hex gets
    ``accident_occured = 1`` when that sum is positive and ``0`` otherwise.

    Args:
        osmnx_nodes: Node frame with an ``accidents_count`` column, joined to the
            virtual edges on their ``source_id``.
        hexes: Hex frame, joined to the per-region sums on its index.
        controller: Controller whose ``hexes_gdf`` is replaced by the labelled
            frame and whose hex centroids are rebuilt afterwards.

    Returns:
        None; the controller is modified in place.
    """
    node_to_hex_edges = controller.get_virtual_edges_to_hexes(SourceType.OSMNX_NODES)

    # Total accidents of the nodes linked to each region.
    edges_with_node_attrs = node_to_hex_edges.merge(
        osmnx_nodes, left_on="source_id", right_index=True
    )
    accidents_per_region = (
        edges_with_node_attrs[["region_id", "accidents_count"]]
        .groupby("region_id")
        .sum()
    )

    # Left join keeps every hex; missing values in the merged frame become 0.
    merged_hexes = hexes.merge(
        accidents_per_region, left_index=True, right_index=True, how="left"
    ).fillna(0)
    labelled_hexes = cast(gpd.GeoDataFrame, merged_hexes)

    labelled_hexes["accident_occured"] = (
        labelled_hexes["accidents_count"] > 0
    ).astype(int)
    # Only the binary label is kept; the raw count is removed again.
    labelled_hexes = labelled_hexes.drop(columns="accidents_count")

    controller.hexes_gdf = labelled_hexes
    # Rebuild the centroids after `hexes_gdf` has been replaced.
    controller._hexes_centroids_gdf = controller._create_hexes_centroids_gdf()

In [13]:
# Label the hexes of every city; only the controller is updated, the "hexes"
# entry stored in each city's dict is left as it was.
for gdfs in gdfs_dict.values():
    patch_hexes_with_y(gdfs["osmnx_nodes"], gdfs["hexes"], gdfs["controller"])


  self.hexes_gdf.centroid, columns=["centroid_geometry"]

  self.hexes_gdf.centroid, columns=["centroid_geometry"]

  self.hexes_gdf.centroid, columns=["centroid_geometry"]

  self.hexes_gdf.centroid, columns=["centroid_geometry"]

  self.hexes_gdf.centroid, columns=["centroid_geometry"]


In [14]:
def create_torch_geometric_hetero_data(
    osmnx_nodes, osmnx_edges, hexes, controller: GraphLayerController
):
    """Create the hetero-data object of one city via `create_hetero_data`.

    Every column of the three frames is used as an attribute except the ones
    excluded below; ``accident_occured`` is passed as the hex target column and
    the virtual edges get no attributes.

    Args:
        osmnx_nodes: Node frame; only its column names are read here.
        osmnx_edges: Edge frame; only its column names are read here.
        hexes: Hex frame; only its column names are read here.
        controller: Controller handed to `create_hetero_data`.

    Returns:
        Whatever `create_hetero_data` returns for this controller.
    """

    def _attribute_columns(frame, excluded):
        # Keeps the frame's column order and only masks out the excluded labels.
        keep_mask = ~frame.columns.isin(excluded)
        return frame.columns[keep_mask]

    return create_hetero_data(
        controller,
        hexes_attrs_columns_names=_attribute_columns(hexes, ["geometry", "h3_id"]),
        osmnx_edge_attrs_columns_names=_attribute_columns(
            osmnx_edges, ["u", "v", "key", "geometry"]
        ),
        osmnx_node_attrs_columns_names=_attribute_columns(
            osmnx_nodes, ["geometry", "x", "y", "osmid", "accidents_count"]
        ),
        virtual_edge_attrs_columns_names=[],
        hexes_y_columns_names=["accident_occured"],
    )


# One hetero-data object per city. Each city's dict (the three frames plus the
# controller) is unpacked into the keyword arguments of the builder.
data_dict = {
    city_name: create_torch_geometric_hetero_data(**gdfs)
    for city_name, gdfs in gdfs_dict.items()
}

In [15]:
def shift_elements_right(lst):
    """Return a copy of ``lst`` rotated one position to the right.

    The last element becomes the first one, e.g. ``[1, 2, 3] -> [3, 1, 2]``.
    Slicing (instead of indexing ``lst[-1]``) makes an empty list return an
    empty list rather than raise ``IndexError``. The input is not modified.

    Args:
        lst: List to rotate.

    Returns:
        A new list with the same elements, shifted right by one.
    """
    return lst[-1:] + lst[:-1]


# City names in the key order of `data_dict`.
cities_names_list = list(data_dict.keys())

# Each fold is a (validation city, test city) pair: the test city follows the
# list order and the validation city is the one preceding it (wrapping around).
folds_tuples = list(zip(shift_elements_right(cities_names_list), cities_names_list))

In [16]:
folds_tuples

[('Kraków, Poland', 'Wrocław, Poland'),
 ('Wrocław, Poland', 'Warsaw, Poland'),
 ('Warsaw, Poland', 'Szczecin, Poland'),
 ('Szczecin, Poland', 'Poznań, Poland'),
 ('Poznań, Poland', 'Kraków, Poland')]

In [None]:
from wandb.util import generate_id

from src.training.train import train

# Sweep definition handed to `wandb.sweep` in `main`.
# - The optimized metric, `mean_f1`, is the key `run_k_fold` logs after all folds.
# - `lin_layer_size` and `num_lin_layers` are combined by `run_k_fold` into a
#   list of `num_lin_layers` equal sizes (an empty list when it is 0).
# - `learning_rate` and `weight_decay` are both drawn from the
#   `log_uniform_values` distribution between 1e-5 and 1e-2.
sweep_configuration_bayesian = {
    "method": "bayes",
    "metric": {"name": "mean_f1", "goal": "maximize"},
    "parameters": {
        "hidden_channels": {"values": [10, 20, 30, 40, 50]},
        "learning_rate": {
            "distribution": "log_uniform_values",
            "min": 1e-5,
            "max": 1e-2,
        },
        "num_conv_layers": {"values": [1, 2, 3, 4, 5]},
        "lin_layer_size": {"values": [8, 16, 32, 64, 128]},
        "num_lin_layers": {"values": [0, 1, 2, 3, 4]},
        "weight_decay": {
            "distribution": "log_uniform_values",
            "min": 1e-5,
            "max": 1e-2,
        },
    },
}


def run_k_fold(epochs: int = 300):
    """Train and evaluate one hyper-parameter configuration on every city fold.

    The function starts a W&B run with ``wandb.init()``, reads the
    hyper-parameters from ``wandb.config`` and, for each
    ``(validation city, test city)`` pair in ``folds_tuples``, trains on the
    remaining cities, logs the fold's model and its metrics, and finally logs
    the mean AUC / accuracy / F1 over the folds (``mean_f1`` is the metric the
    sweep configuration optimizes).

    The metrics of one fold, and the three means, are each logged with a single
    ``run.log`` call so that they are recorded together instead of being spread
    over consecutive W&B steps.

    Args:
        epochs: Number of epochs passed to ``train`` for every fold.
    """
    run = wandb.init()
    config = wandb.config

    # The sweep provides one layer size and a layer count; a config may instead
    # carry the explicit list under `lin_layer_sizes`.
    if hasattr(config, "lin_layer_size") and hasattr(config, "num_lin_layers"):
        lin_layer_sizes = [config.lin_layer_size] * config.num_lin_layers
    else:
        lin_layer_sizes = config.lin_layer_sizes

    hparams = {
        "hidden_channels": config.hidden_channels,
        "lr": config.learning_rate,
        "num_conv_layers": config.num_conv_layers,
        "lin_layer_sizes": lin_layer_sizes,
        "weight_decay": config.weight_decay,
    }

    aucs = []
    accuracies = []
    f1s = []

    # Shared id so the models of all folds of this run can be told apart from
    # those of other runs.
    fold_group_id = generate_id()
    for index, (val_city_name, test_city_name) in enumerate(folds_tuples):
        # Every fold works on CPU clones of the per-city data objects.
        val_data = [data_dict[val_city_name].to("cpu").clone()]
        train_data = [
            v.to("cpu").clone()
            for k, v in data_dict.items()
            if k != val_city_name and k != test_city_name
        ]
        test_data = data_dict[test_city_name].to("cpu").clone()

        auc, accuracy, f1, model_path = train(
            train_data=train_data,
            val_data=val_data,
            test_data=test_data,
            epochs=epochs,
            hparams=hparams,
        )
        run.log_model(
            path=model_path,
            name=f"model_{fold_group_id}_fold_{index}",
        )
        run.log(
            {
                f"auc_fold_{index}": auc,
                f"accuracy_fold_{index}": accuracy,
                f"f1_fold_{index}": f1,
            }
        )

        aucs.append(auc)
        accuracies.append(accuracy)
        f1s.append(f1)

    run.log(
        {
            "mean_auc": sum(aucs) / len(aucs),
            "mean_accuracy": sum(accuracies) / len(accuracies),
            "mean_f1": sum(f1s) / len(f1s),
        }
    )


def main(count: int = 50, project: str = "accidents-downstream-task"):
    """Create the sweep and run its trials with a W&B agent.

    Args:
        count: Number of trials the agent runs.
        project: W&B project the sweep is created in.

    Raises:
        Exception: Any error raised while creating or running the sweep is
            printed, ``wandb.finish()`` is called, and the error is re-raised.
    """
    try:
        sweep_id = wandb.sweep(sweep_configuration_bayesian, project=project)
        wandb.agent(sweep_id, function=run_k_fold, count=count)
    except Exception as e:
        print(e)
        wandb.finish()
        # Bare `raise` re-raises the active exception with its traceback intact.
        raise


# Start the sweep; this runs as soon as the cell is executed.
main()