In [1]:
import os
import sys

# `__file__` is not defined inside a notebook kernel, so the *string literal* "__file__" is
# resolved against the current working directory; only its directory part is used below.
notebook_path = os.path.abspath("__file__")
notebook_directory = os.path.dirname(notebook_path)
parent_directory = os.path.dirname(notebook_directory)

parent_parent_directory = os.path.dirname(parent_directory)

# Make the directory two levels above the working directory importable.
# NOTE(review): presumably this is the repository root that contains `src/` — this only
# holds when the kernel is started from the notebook's own folder; confirm.
sys.path.append(parent_parent_directory)

Upewnij się, że masz ściągnięte pliki dla danego miasta z folderu https://drive.google.com/drive/folders/1G3COVpQoB5ppYJhIeCMU8_tHog9mAYmC?usp=sharing <br>
Dane umieść w folderze `data/maps_embeddings`

# 1. Imports


In [2]:
import geopandas as gpd
from src.organized_datasets_creation.utils import resolve_nominatim_city_name
from src.organized_datasets_creation.utils import convert_nominatim_name_to_filename
from src.graph_layering.graph_layer_creator import GraphLayerController
import pandas as pd
from typing import cast
import os
from src.graph_layering.graph_layer_creator import SourceType
import warnings
from src.graph_layering.create_hetero_data import create_hetero_data

from tqdm import tqdm

import wandb.util
import wandb
import os


import numpy as np
from src.graph.create_osmnx_graph import OSMnxGraph
import json
from shapely.geometry import Point
from joblib import dump


from datetime import datetime
from sklearn.metrics import f1_score, roc_auc_score
from wandb.util import generate_id
from sklearn.linear_model import LogisticRegression
from src.training.train import train
from sklearn.preprocessing import StandardScaler

In [3]:
# Displays a freshly generated wandb id; the value is not stored or reused by any later
# cell visible in this notebook.
generate_id()

'6keau6hw'

# 2. Setting up env variables and configs


In [4]:
# The Weights & Biases key is a secret: read it from the environment instead of storing it
# in the notebook, where it would end up in version control and in shared copies.
# Export WANDB_API_KEY before starting the kernel. The key that used to be hard-coded here
# must be treated as leaked and revoked in the W&B account settings.
WANDB_API_KEY = os.environ.get("WANDB_API_KEY")
assert (
    WANDB_API_KEY is not None
), "WANDB_API_KEY is not set, export it as an environment variable before starting the kernel"

In [5]:
# general settings
# Root of the local project checkout. Set the GRADIENT_PROJECT_ROOT environment variable to
# run the notebook on another machine; the default keeps the original author's location so
# every derived path below is unchanged when the variable is absent.
PROJECT_ROOT = os.environ.get(
    "GRADIENT_PROJECT_ROOT", "C:/Users/kiera/Desktop/gradient"
).rstrip("/\\")
DATA_ROOT = f"{PROJECT_ROOT}/data"

ORGANIZED_HEXES_LOCATION = f"{DATA_ROOT}/organized-hexes"
ORGANIZED_GRAPHS_LOCATION = f"{DATA_ROOT}/organized_graphs"
OSMNX_ALL_ATTRIBUTES_LOCATION = f"{DATA_ROOT}/osmnx_attributes.json"

# downstream task settings
ACCIDENTS_LOCATION = f"{DATA_ROOT}/downstream_tasks/accidents_prediction/accidents.csv"
# Trailing slash kept on purpose: the value is byte-identical to the previous literal.
TRAIN_SAVE_DIR = f"{PROJECT_ROOT}/gradient_logs/"

# Number of runs per W&B sweep (only referenced by the currently disabled sweep code).
SWEEP_RUNS_COUNT = 10
# Training epochs per fold.
EPOCHS = 3

# Feature-set configurations to evaluate. Each entry must provide the three boolean keys
# below (the run cell asserts their presence). `derive_data_structure` maps
# USE_OSMNX_ATTRS=True to the "graph" data structure and False to "tabular".
# Only the full configuration is active; the commented entries are the other variants.
# NOTE(review): the run cell currently does not rebuild the data per configuration, so
# re-enabling an entry here is not enough on its own — confirm before relying on it.
ATTRIBUTES_CONFIGURATIONS = [
    # {
    #     "USE_ORTOPHOTO": False,
    #     "USE_HEXES_ATTRS": True,
    #     "USE_OSMNX_ATTRS": False,
    # },
    # {
    #     "USE_ORTOPHOTO": False,
    #     "USE_HEXES_ATTRS": True,
    #     "USE_OSMNX_ATTRS": True,
    # },
    # {
    #     "USE_ORTOPHOTO": False,
    #     "USE_HEXES_ATTRS": False,
    #     "USE_OSMNX_ATTRS": True,
    # },
    {
        "USE_ORTOPHOTO": True,
        "USE_HEXES_ATTRS": True,
        "USE_OSMNX_ATTRS": True,
    },
]

# Sweep definition for the graph-based models: Bayesian search maximising `mean_f1`.
WANDB_SWEEP_PARAMS_GRAPH_DATA = dict(
    method="bayes",
    metric=dict(name="mean_f1", goal="maximize"),
    parameters=dict(
        hidden_channels=dict(values=[10, 20, 30, 40, 50]),
        learning_rate=dict(distribution="log_uniform_values", min=1e-5, max=1e-2),
        num_conv_layers=dict(values=[1, 2, 3, 4, 5]),
        lin_layer_size=dict(values=[8, 16, 32, 64, 128]),
        num_lin_layers=dict(values=[0, 1, 2, 3, 4]),
        weight_decay=dict(distribution="log_uniform_values", min=1e-5, max=1e-2),
    ),
)

# Sweep definition for the tabular baseline. Each `solver_penalty` option encodes a
# "<solver>;<penalty>" pair in a single string.
WANDB_SWEEP_PARAMS_TABULAR_DATA = dict(
    method="bayes",
    metric=dict(name="mean_f1", goal="maximize"),
    parameters=dict(
        solver_penalty=dict(
            values=[
                "lbfgs;l2",
                "liblinear;l1",
                "liblinear;l2",
                "newton-cg;l2",
                "newton-cholesky;l2",
                "sag;l2",
                "saga;elasticnet",
                "saga;l1",
                "saga;l2",
            ]
        ),
        C=dict(distribution="log_uniform_values", min=1e-5, max=1),
    ),
)

# 3. Loading accidents

The process includes removing unused columns and creating GeoSeries from raw X Y points


In [6]:
# Raw accident records. Columns read later in this notebook: `wsp_gps_x`, `wsp_gps_y`,
# `uczestnicy` and `mie_nazwa` (the city name).
accidents = pd.read_csv(ACCIDENTS_LOCATION)


def create_point(x):
    """Build a shapely ``Point`` from the first two entries of ``x``.

    ``x`` is a row of coordinates: either a pandas ``Series`` (as passed by
    ``DataFrame.apply(..., axis=1)``) or a plain sequence such as a list or tuple.
    Both entries are converted with ``float`` before the point is created.
    """
    # A row Series is indexed by column labels, so `x[0]` would rely on the deprecated
    # "integer key as position" fallback. Use the positional indexer when one exists.
    coords = getattr(x, "iloc", x)
    return Point(float(coords[0]), float(coords[1]))


# One Point per accident row, built from its two GPS coordinate columns.
geometry = accidents[["wsp_gps_x", "wsp_gps_y"]].apply(create_point, axis=1)

# Geo-enabled accidents table; the raw coordinates and the unused `uczestnicy` column are
# dropped once the geometry column exists.
gdf_accidents = gpd.GeoDataFrame(
    accidents,
    geometry=geometry,
    crs="EPSG:4326",
).drop(columns=["wsp_gps_x", "wsp_gps_y", "uczestnicy"])

# 4. Displaying available cities


In [7]:
# "<city>, Poland" strings, one per distinct city name found in the accidents table.
cities = [city + ", Poland" for city in accidents["mie_nazwa"].unique()]
print("Cities:", cities, sep="\n")

Cities:
['Wrocław, Poland', 'Szczecin, Poland', 'Poznań, Poland', 'Kraków, Poland', 'Warszawa, Poland']


# 5. Creating GeoDataFrames

The creation process has the following steps:

1. loading OSMNX nodes and edges
2. assigning accidents to OSMNX nodes
3. taking latest H9 resolution hexes
4. combining OSMNX nodes, OSMNX edges, hexes in a single dict and packing it inside gdfs_dict


In [8]:
def add_accidents_to_osmnx_nodes(
    accidents: gpd.GeoDataFrame,
    nodes: gpd.GeoDataFrame,
    edges: gpd.GeoDataFrame,
    city_name: str,
):
    """Aggregate one city's accidents onto its OSMnx nodes.

    Args:
        accidents: Accidents for all cities; rows are filtered on ``mie_nazwa``.
        nodes: OSMnx nodes of the city.
        edges: OSMnx edges of the city.
        city_name: City name; passed through ``resolve_nominatim_city_name`` before it is
            compared with ``mie_nazwa``.

    Returns:
        ``gdf_nodes`` of the ``OSMnxGraph`` after a node-level "count" aggregation.
        Downstream code reads an ``accidents_count`` column from it — presumably added by
        that aggregation; confirm in ``OSMnxGraph``.
    """
    with open(OSMNX_ALL_ATTRIBUTES_LOCATION) as attributes_file:
        all_attributes = json.load(attributes_file)

    is_city_row = accidents["mie_nazwa"] == resolve_nominatim_city_name(city_name)
    city_accidents = accidents.loc[is_city_row, :]

    graph = OSMnxGraph(city_accidents, nodes, edges, all_attributes)
    # NOTE(review): `_aggregate` is a private method of OSMnxGraph.
    graph._aggregate(element_type="node", aggregation_method="count")
    return graph.gdf_nodes


def create_gdfs(city_name: str, accidents_gdf: gpd.GeoDataFrame = gdf_accidents):
    """Load the OSMnx nodes/edges and the newest H9 hexes for one city.

    Args:
        city_name: City name, e.g. ``"Wrocław, Poland"``.
        accidents_gdf: Accidents for all cities; must share the CRS of the OSMnx data.

    Returns:
        A dict with the keys ``osmnx_nodes`` (with accidents aggregated onto the nodes),
        ``osmnx_edges`` and ``hexes``.

    Raises:
        FileNotFoundError: If the city's hexes folder has no numeric (year) sub-folder.
        AssertionError: If nodes, edges and accidents do not share one CRS.
    """
    city_folder_name = convert_nominatim_name_to_filename(
        resolve_nominatim_city_name(city_name)
    )
    city_graph_folder = os.path.join(ORGANIZED_GRAPHS_LOCATION, city_folder_name)

    osmnx_nodes = gpd.read_parquet(os.path.join(city_graph_folder, "nodes.parquet"))
    # Move the stored index into a column and use a fresh positional index named `node_id`.
    osmnx_nodes = osmnx_nodes.reset_index()
    osmnx_nodes.index.names = ["node_id"]
    osmnx_nodes["x"] = osmnx_nodes["geometry"].x
    osmnx_nodes["y"] = osmnx_nodes["geometry"].y

    osmnx_edges = gpd.read_parquet(os.path.join(city_graph_folder, "edges.parquet"))
    osmnx_edges = osmnx_edges.reset_index().rename(columns={"index": "edge_id"})
    osmnx_edges.index.names = ["edge_id"]

    assert osmnx_nodes.crs == osmnx_edges.crs
    assert osmnx_nodes.crs == accidents_gdf.crs

    osmnx_nodes = add_accidents_to_osmnx_nodes(
        accidents=accidents_gdf,
        nodes=osmnx_nodes,
        city_name=city_name,
        edges=osmnx_edges,
    )

    hexes_years_folder = os.path.join(ORGANIZED_HEXES_LOCATION, city_folder_name)

    # Only purely numeric directory names are treated as years, so stray folders
    # (e.g. `.ipynb_checkpoints`) no longer crash the `int(...)` conversion.
    available_years = [
        int(entry)
        for entry in os.listdir(hexes_years_folder)
        if entry.isdecimal() and os.path.isdir(os.path.join(hexes_years_folder, entry))
    ]
    if not available_years:
        raise FileNotFoundError(
            f"No year sub-folders found in {hexes_years_folder}"
        )
    highest_year = max(available_years)

    hexes: gpd.GeoDataFrame = gpd.read_parquet(
        os.path.join(
            hexes_years_folder,
            f"{highest_year}/h9/count-embedder/dataset.parquet",
        )
    )

    # Keep the stored id as the `h3_id` column and name the row index `region_id`.
    hexes = hexes.rename(columns={"region_id": "h3_id"}).rename_axis(
        "region_id", axis=0
    )

    return dict(osmnx_nodes=osmnx_nodes, osmnx_edges=osmnx_edges, hexes=hexes)


print("Creating gdfs...")
# City name -> {"osmnx_nodes", "osmnx_edges", "hexes"}; later cells add "controller".
gdfs_dict = dict((name, create_gdfs(name)) for name in tqdm(cities))

Creating gdfs...


100%|██████████| 5/5 [00:05<00:00,  1.05s/it]


# 6. Creating GraphLayerController for each of the cities

The creation is based on previously made GeoDataFrames. The controller is used to transfer accidents Y values from OSMNX nodes to hexes. It is also used to create complete graph data in case of graph datasets.


In [9]:
# with warnings.catch_warnings():
#     warnings.simplefilter("ignore")
#     for gdf_for_city in gdfs_dict.values():
#         gdf_for_city["controller"] = GraphLayerController(
#             gdf_for_city["hexes"],
#             gdf_for_city["osmnx_nodes"],
#             gdf_for_city["osmnx_edges"],
#         )

# Folder with the per-city orthophoto embedding parquet files.
MAPS_EMBEDDINGS_LOCATION = "C:/Users/kiera/Desktop/gradient/data/maps_embeddings"
# Embedding length used for the zero fill only when a city has no embedding at all;
# otherwise the length is taken from the data so the zero vectors always match.
DEFAULT_ORTOPHOTO_EMBEDDING_SIZE = 151296

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    for city_query, gdf_for_city in gdfs_dict.items():
        # "Wrocław, Poland" -> "Wrocław"; the embedding files are named after the city.
        city = city_query.split(",")[0]

        # Embeddings are stored per resolution-8 region.
        orthophotomap_emb_res_8 = pd.read_parquet(
            f"{MAPS_EMBEDDINGS_LOCATION}/regions_8_emb_{city}.parquet"
        )

        # Mapping from each resolution-9 hex (`h3_id`) to its resolution-8 parent (`region_id`).
        gdf_8_to_9 = pd.read_parquet(
            f"{MAPS_EMBEDDINGS_LOCATION}/correlations_regions_9_8_{city}.parquet"
        )[["region_id", "region_id_res_8"]].rename(
            columns={"region_id": "h3_id", "region_id_res_8": "region_id"}
        )

        # Give every resolution-9 hex the embedding of its resolution-8 parent.
        orthophotomap_emb_res_9 = pd.merge(
            gdf_8_to_9, orthophotomap_emb_res_8, on="region_id", how="left"
        ).drop(columns="region_id")

        # Dropping a previously attached embedding keeps this cell re-runnable: without it
        # a second execution would merge `emb` again and create a duplicate column.
        base_hexes = gdf_for_city["hexes"].drop(
            columns="ortophoto_embedding", errors="ignore"
        )
        hexes = pd.merge(base_hexes, orthophotomap_emb_res_9, on="h3_id", how="left")

        # Hexes without an embedding get an all-zero vector of the same length.
        missing_embedding = hexes["emb"].isna()
        known_embeddings = hexes.loc[~missing_embedding, "emb"]
        embedding_size = (
            len(known_embeddings.iloc[0])
            if len(known_embeddings) > 0
            else DEFAULT_ORTOPHOTO_EMBEDDING_SIZE
        )
        zeros_array = np.zeros(embedding_size)
        for idx in hexes.index[missing_embedding]:
            hexes.at[idx, "emb"] = zeros_array

        # The merge resets the index; name the positional index `region_id` again.
        hexes = (
            hexes.reset_index()
            .set_index("index")
            .rename_axis("region_id")
            .rename(columns={"emb": "ortophoto_embedding"})
        )

        gdf_for_city["hexes"] = hexes
        gdf_for_city["controller"] = GraphLayerController(
            hexes,
            gdf_for_city["osmnx_nodes"],
            gdf_for_city["osmnx_edges"],
        )

# 7. Patching hexes

The y value (1 = accident occurred, 0 = no accident) is assigned to each hex according to its underlying OSMNX nodes


In [10]:
def patch_hexes_with_y(
    osmnx_nodes: gpd.GeoDataFrame,
    hexes: gpd.GeoDataFrame,
    controller: GraphLayerController,
):
    """Attach the binary ``accident_occured`` target to the controller's hexes.

    Accident counts of the OSMnx nodes are summed per hex through the controller's virtual
    node-to-hex edges. A hex is labelled 1 when that sum is positive and 0 otherwise
    (including hexes with no linked node). The result replaces ``controller.hexes_gdf`` and
    the controller's cached centroids are rebuilt; nothing is returned.
    """
    node_to_hex_edges = controller.get_virtual_edges_to_hexes(SourceType.OSMNX_NODES)

    # Bring the node attributes onto each virtual edge, then total the accidents per hex.
    edges_with_node_attrs = node_to_hex_edges.merge(
        osmnx_nodes, left_on="source_id", right_index=True
    )
    accidents_per_hex = (
        edges_with_node_attrs[["region_id", "accidents_count"]]
        .groupby("region_id")
        .sum()
    )

    # Left join keeps every hex; `fillna(0)` is applied to the whole merged frame.
    merged_hexes = hexes.merge(
        accidents_per_hex,
        left_index=True,
        right_index=True,
        how="left",
    )
    hexes_with_y = cast(gpd.GeoDataFrame, merged_hexes.fillna(0))

    has_accident = hexes_with_y["accidents_count"] > 0
    hexes_with_y["accident_occured"] = has_accident.astype(int)

    controller.hexes_gdf = hexes_with_y.drop(columns="accidents_count")
    # The centroids frame is derived from `hexes_gdf`, so refresh the private cache.
    controller._hexes_centroids_gdf = controller._create_hexes_centroids_gdf()

In [11]:
# Label the hexes of every city; warnings raised while patching are silenced.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    for gdfs in gdfs_dict.values():
        patch_hexes_with_y(
            osmnx_nodes=gdfs["osmnx_nodes"],
            hexes=gdfs["hexes"],
            controller=gdfs["controller"],
        )

# 8. Creating graph data

Graph data is used when we include OSMNX attributes and in turn maintain the graph structure of the data

The data is created only once for now, just to create the (train, val, test) fold labels for cross-validation on the graph-based versions of the task


In [12]:
from typing import Iterable
import torch

from src.graph_layering.city_hetero_data import CityHeteroData
from src.graph_layering.graph_layer_creator import GraphLayerController, SourceType

import gc

def create_hetero_data(
    controller: GraphLayerController,
    hexes_attrs_columns_names: Iterable[str],
    osmnx_node_attrs_columns_names: Iterable[str],
    osmnx_edge_attrs_columns_names: Iterable[str],
    virtual_edge_attrs_columns_names: Iterable[str],
    hexes_y_columns_names: Iterable[str],
    squeeze_y: bool = True,
    use_ortophoto: bool = False,
) -> CityHeteroData:
    """Assemble a ``CityHeteroData`` graph for one city from a ``GraphLayerController``.

    NOTE(review): this definition shadows the ``create_hetero_data`` imported from
    ``src.graph_layering.create_hetero_data`` in the imports cell; presumably it is a
    local variant that adds ``use_ortophoto`` — confirm and keep only one of them.

    Args:
        controller: Source of the hex, OSMnx node and OSMnx edge frames and of the
            hex-to-hex and node-to-hex edge lists.
        hexes_attrs_columns_names: Hex columns used as hex node features.
        osmnx_node_attrs_columns_names: OSMnx node columns used as node features.
        osmnx_edge_attrs_columns_names: OSMnx edge columns used as edge attributes.
        virtual_edge_attrs_columns_names: Columns of the node-to-hex edges used as edge
            attributes (may be empty).
        hexes_y_columns_names: Hex columns used as the target; must not overlap with
            ``hexes_attrs_columns_names``.
        squeeze_y: When true, ``squeeze()`` is applied to the target tensor at the end.
        use_ortophoto: When true, the ``ortophoto_embedding`` column is stacked and
            concatenated to the hex features along dim 1.

    Returns:
        The populated ``CityHeteroData`` object.
    """
    data = CityHeteroData()
    edges_between_hexes = controller.get_edges_between_hexes()
    edges_between_source_and_hexes = controller.get_virtual_edges_to_hexes(
        SourceType.OSMNX_NODES
    )

    # Guard against target leakage: no target column may be among the feature columns.
    for col in hexes_y_columns_names:
        assert (
            col not in controller.hexes_centroids_gdf[hexes_attrs_columns_names].columns
        )

    data.hex.x = torch.tensor(
        controller.hexes_centroids_gdf[hexes_attrs_columns_names].to_numpy(),
        dtype=torch.float32,
    )

    if use_ortophoto:
        # Each cell holds one embedding vector; stack them into a 2-D array, one row per hex.
        ortophoto_embedding = np.stack(controller.hexes_centroids_gdf['ortophoto_embedding'].values)
        ortophoto_embedding = torch.tensor(ortophoto_embedding, dtype=torch.float32)
        data.hex.x = torch.cat((data.hex.x, ortophoto_embedding), dim=1)
        # Release the temporary copy of the embedding matrix as early as possible.
        del ortophoto_embedding
        gc.collect()

    data.hex.y = torch.tensor(
        controller.hexes_centroids_gdf[hexes_y_columns_names].to_numpy(),
        dtype=torch.float32,
    ).to(torch.int64)

    data.osmnx_node.x = torch.tensor(
        controller.osmnx_nodes_gdf[osmnx_node_attrs_columns_names].to_numpy(),
        dtype=torch.float32,
    )

    # Hex-to-hex edges are given as `h3_id` pairs (`u`, `v`); two joins against the hexes
    # frame translate both endpoints to `region_id` values.
    data.hex_connected_to_hex.edge_index = torch.tensor(
        edges_between_hexes.merge(
            controller.hexes_gdf.reset_index(),
            left_on="u",
            right_on="h3_id",
        )
        .rename(columns={"region_id": "u_region_id"})
        .merge(
            controller.hexes_gdf.reset_index(),
            left_on="v",
            right_on="h3_id",
        )
        .rename(columns={"region_id": "v_region_id"})[["u_region_id", "v_region_id"]]
        .to_numpy()
        .T
    )

    # OSMnx edges reference nodes by `osmid`; translate both endpoints to `node_id` values.
    node_to_node_connections = (
        controller.osmnx_edges_gdf.merge(
            controller.osmnx_nodes_gdf.reset_index(), left_on="u", right_on="osmid"
        )
        .rename(columns={"node_id": "u_node_id"})
        .merge(controller.osmnx_nodes_gdf.reset_index(), left_on="v", right_on="osmid")
        .rename(columns={"node_id": "v_node_id"})
    )

    data.osmnx_node_connected_to_osmnx_node.edge_index = torch.tensor(
        node_to_node_connections[["u_node_id", "v_node_id"]].to_numpy().T
    )

    data.osmnx_node_connected_to_osmnx_node.edge_attr = torch.tensor(
        node_to_node_connections[osmnx_edge_attrs_columns_names].to_numpy(),
        dtype=torch.float32,
    )

    # The merged edge frame is no longer needed once the tensors exist.
    del node_to_node_connections
    gc.collect()

    data.osmnx_node_connected_to_hex.edge_index = torch.tensor(
        edges_between_source_and_hexes[["source_id", "region_id"]].to_numpy().T
    )

    data.osmnx_node_connected_to_hex.edge_attr = torch.tensor(
        edges_between_source_and_hexes[virtual_edge_attrs_columns_names].to_numpy(),
        dtype=torch.float32,
    )

    if squeeze_y:
        data.hex.y = data.hex.y.squeeze()

    return data


In [13]:
def create_graph_data(
    osmnx_nodes,
    osmnx_edges,
    hexes,
    controller: "GraphLayerController",
    use_hexes_attr: bool,
    use_ortophoto: bool,
):
    """Select the feature columns for one city and build its heterogeneous graph data.

    Args:
        osmnx_nodes: OSMnx nodes frame; every column except ``geometry``, ``x``, ``y``,
            ``osmid`` and ``accidents_count`` becomes a node feature.
        osmnx_edges: OSMnx edges frame; every column except ``u``, ``v``, ``key`` and
            ``geometry`` becomes an edge attribute.
        hexes: Hexes frame, used only to pick the hex feature column names (the values
            are read from ``controller`` by ``create_hetero_data``).
        controller: The city's ``GraphLayerController``.
        use_hexes_attr: Whether the scalar hex attributes are used as hex features.
        use_ortophoto: Whether the orthophoto embedding is appended to the hex features.

    Returns:
        Whatever ``create_hetero_data`` returns for the selected columns.
    """
    edges_attr_columns = osmnx_edges.columns[
        ~osmnx_edges.columns.isin(["u", "v", "key", "geometry"])
    ]
    nodes_attr_columns = osmnx_nodes.columns[
        ~osmnx_nodes.columns.isin(["geometry", "x", "y", "osmid", "accidents_count"])
    ]

    if use_hexes_attr:
        # `ortophoto_embedding` holds one vector per cell and is never a scalar attribute,
        # so it is excluded whether or not the embedding is used. Previously it leaked
        # into the feature columns when `use_ortophoto` was False. When requested, the
        # embedding is appended separately via `use_ortophoto`.
        hexes_attr_columns = hexes.columns[
            ~hexes.columns.isin(["geometry", "h3_id", "ortophoto_embedding"])
        ]
    else:
        hexes_attr_columns = []

    data = create_hetero_data(
        controller,
        hexes_attrs_columns_names=hexes_attr_columns,
        osmnx_edge_attrs_columns_names=edges_attr_columns,
        osmnx_node_attrs_columns_names=nodes_attr_columns,
        virtual_edge_attrs_columns_names=[],
        hexes_y_columns_names=["accident_occured"],
        use_ortophoto=use_ortophoto,
    )
    return data


# City name -> graph data built with both the hex attributes and the orthophoto embedding.
# Each `gdfs` dict supplies osmnx_nodes, osmnx_edges, hexes and controller as keyword arguments.
graph_data_dict = dict(
    (
        city_name,
        create_graph_data(**gdfs, use_ortophoto=True, use_hexes_attr=True),
    )
    for city_name, gdfs in gdfs_dict.items()
)

{'891e204e4b7ffff', '891e2040b6bffff', '891e204e4a3ffff', '891e204e5d3ffff', '891e204e4afffff', '891e204e5dbffff'} <class 'set'>
{'891e2042393ffff', '891e204206fffff', '891e2042383ffff', '891e20423d7ffff', '891e2042067ffff', '891e204238bffff'} <class 'set'>
{'891e2041823ffff', '891e20419cbffff', '891e20419dbffff', '891e2041833ffff', '891e20418afffff', '891e2041827ffff'} <class 'set'>
{'891e204722fffff', '891e204723bffff', '891e2047233ffff', '891e2047227ffff', '891e2047237ffff', '891e204722bffff'} <class 'set'>
{'891e204660bffff', '891e2046643ffff', '891e2046647ffff', '891e20466cfffff', '891e204661bffff', '891e2046653ffff'} <class 'set'>
{'891e20430bbffff', '891e2043033ffff', '891e2043017ffff', '891e20430a3ffff', '891e20430afffff', '891e2043007ffff'} <class 'set'>
{'891e2041dd3ffff', '891e204036bffff', '891e2041dd7ffff', '891e2041d9bffff', '891e2040367ffff', '891e2040363ffff'} <class 'set'>
{'891e204296bffff', '891e2042967ffff', '891e20476d7ffff', '891e204769bffff', '891e20476d3ffff', '

In [14]:
# Show the tensor shapes of every city's graph as a sanity check.
graph_data_dict

{'Wrocław, Poland': CityHeteroData(
   hex={
     x=[3168, 153447],
     y=[3168],
   },
   osmnx_node={ x=[71641, 22] },
   (hex, connected_to, hex)={ edge_index=[2, 9206] },
   (osmnx_node, connected_to, osmnx_node)={
     edge_index=[2, 124386],
     edge_attr=[124386, 44],
   },
   (osmnx_node, connected_to, hex)={
     edge_index=[2, 71641],
     edge_attr=[71641, 0],
   }
 ),
 'Szczecin, Poland': CityHeteroData(
   hex={
     x=[3534, 153447],
     y=[3534],
   },
   osmnx_node={ x=[64894, 22] },
   (hex, connected_to, hex)={ edge_index=[2, 10163] },
   (osmnx_node, connected_to, osmnx_node)={
     edge_index=[2, 106588],
     edge_attr=[106588, 44],
   },
   (osmnx_node, connected_to, hex)={
     edge_index=[2, 64894],
     edge_attr=[64894, 0],
   }
 ),
 'Poznań, Poland': CityHeteroData(
   hex={
     x=[2945, 153447],
     y=[2945],
   },
   osmnx_node={ x=[60082, 22] },
   (hex, connected_to, hex)={ edge_index=[2, 8486] },
   (osmnx_node, connected_to, osmnx_node)={
     edge

# 9. Creating tabular data

Tabular data is used when we omit OSMNX attributes and in turn lose the graph structure of the data

No fold labels are created for the tabular version of the task; it uses a simple leave-one-city-out evaluation instead


In [15]:
def create_tabular_data(
    hexes: pd.DataFrame,
    controller: "GraphLayerController",
    use_hexes_attr: bool,
    use_ortophoto: bool,
):
    """Build the tabular ``X``/``y`` frames for one city.

    Args:
        hexes: Hexes frame. Scalar attributes are all columns except ``geometry``,
            ``h3_id`` and ``ortophoto_embedding``; the latter holds one vector per row.
        controller: Object exposing ``hexes_centroids_gdf`` with the ``accident_occured``
            target column.
        use_hexes_attr: Include the scalar hex attributes in ``X``.
        use_ortophoto: Include the orthophoto embedding in ``X``, expanded into one
            float32 column per component, named ``ortophoto_embedding_<i>``.

    Returns:
        ``{"X": features, "y": target}``.

    Raises:
        AssertionError: If neither data source is requested.
    """
    assert use_ortophoto or use_hexes_attr, "Provide at least one data source"

    embedding_column = "ortophoto_embedding"

    hexes_attr_columns = (
        hexes.columns[~hexes.columns.isin(["geometry", "h3_id", embedding_column])]
        if use_hexes_attr
        else []
    )
    X = hexes[hexes_attr_columns]

    if use_ortophoto:
        # The flag used to be accepted but ignored, so ortophoto-only requests produced
        # an X with zero columns. Expand the per-row vectors into ordinary columns.
        embedding_matrix = np.stack(hexes[embedding_column].to_numpy()).astype(
            np.float32
        )
        embedding_frame = pd.DataFrame(
            embedding_matrix,
            index=hexes.index,
            columns=[
                f"{embedding_column}_{i}" for i in range(embedding_matrix.shape[1])
            ],
        )
        X = pd.concat([X, embedding_frame], axis=1)

    hexes_y_columns_names = ["accident_occured"]
    y = controller.hexes_centroids_gdf[hexes_y_columns_names]

    return {"X": X, "y": y}


# tabular_data_dict = {
#     city_name: create_tabular_data(
#         gdfs["hexes"], gdfs["controller"], use_ortophoto=True, use_hexes_attr=True
#     )
#     for city_name, gdfs in gdfs_dict.items()
# }

# 10. Creating folds labels


In [16]:
def shift_elements_right(lst):
    """Return a copy of ``lst`` rotated one position to the right.

    The last element becomes the first, e.g. ``[1, 2, 3] -> [3, 1, 2]``. The input is
    not modified. An empty list yields an empty list instead of raising ``IndexError``.
    """
    # Slicing (rather than indexing `lst[-1]`) is what keeps the empty case valid.
    return lst[-1:] + lst[:-1]


# Sorting the city names makes the fold assignment independent of dict insertion order.
cities_names_list = sorted(graph_data_dict.keys(), key=str)

# val + test
# Each city is the test city once, paired with its predecessor in the sorted list
# (wrapping around) as the validation city.
validation_cities = shift_elements_right(cities_names_list)
folds_tuples = list(zip(validation_cities, cities_names_list))
display(folds_tuples)

[('Wrocław, Poland', 'Kraków, Poland'),
 ('Kraków, Poland', 'Poznań, Poland'),
 ('Poznań, Poland', 'Szczecin, Poland'),
 ('Szczecin, Poland', 'Warszawa, Poland'),
 ('Warszawa, Poland', 'Wrocław, Poland')]

# 11. Functions setup


In [17]:
def run_k_fold_graph_data(closure_config, sweep_id, hparams):
    # pass external config (i.e. what attributes are used in the data), closure to avoid passing it to the function directly
    def wrapped(hparams):
        results=[]
        # run = wandb.init()
        epochs = EPOCHS

        # config = wandb.config

        # for k, v in closure_config.items():
        #     run.log({k: 1 if v else 0})

        # run.log({"data_structure": "graph"})

        # create hparams
        # if hasattr(config, "lin_layer_size") and hasattr(config, "num_lin_layers"):
        #     lin_layer_sizes = [config.lin_layer_size] * config.num_lin_layers
        # else:
        #     lin_layer_sizes = config.lin_layer_sizes

        lin_layer_sizes = [hparams['lin_layer_size']] * hparams['num_lin_layers']

        # hparams = {
        #     "hidden_channels": config.hidden_channels,
        #     "lr": config.learning_rate,
        #     "num_conv_layers": config.num_conv_layers,
        #     "lin_layer_sizes": lin_layer_sizes,
        #     "weight_decay": config.weight_decay,
        # }
        hparams = {
            "hidden_channels": hparams['hidden_channels'],
            "lr": hparams['learning_rate'],
            "num_conv_layers": hparams['num_conv_layers'],
            "lin_layer_sizes": lin_layer_sizes,
            "weight_decay": hparams['weight_decay'],
        }
        aucs = []
        accuracies = []
        f1s = []

        # fold_group_id = generate_id()

        # log data as artifact if no data was logged in the sweep before
        # dataset is uploaded only on the first run in sweep, because it does not change across runs in sweep
        # in wandb, dataset will be visible on the first run in the sweep
        # artifact_path = os.path.join(TRAIN_SAVE_DIR, f"graph_data_{sweep_id}.pkl")
        # if not os.path.exists(artifact_path):
        #     dump(
        #         graph_data_dict,
        #         artifact_path,
        #         protocol=5,
        #     )
        #     artifact = wandb.Artifact(
        #         name="graph_data", type="dataset", metadata=closure_config
        #     )
        #     artifact.add_file(local_path=artifact_path)
            # run.log_artifact(artifact)

        # run k-fold
        print(folds_tuples)
        for index, (val_city_name, test_city_name) in enumerate(folds_tuples):
            # prepare data
            val_data = [graph_data_dict[val_city_name].to("cpu").clone()]
            train_data = [
                v.to("cpu").clone()
                for k, v in graph_data_dict.items()
                if k != val_city_name and k != test_city_name
            ]
            test_data = graph_data_dict[test_city_name].to("cpu").clone()

            # run training with checkpointing on lowest val_loss, return test metrics for the best model and its path
            # builtin preprocessing - scaling to N(0, 1)
            auc, accuracy, f1, model_path = train(
                train_data=train_data,
                val_data=val_data,
                test_data=test_data,
                epochs=epochs,
                hparams=hparams,
                train_save_dir=TRAIN_SAVE_DIR,
            )

            # logging - single fold
            # run.log_model(
            #     path=model_path,
            #     name=f"model_{fold_group_id}_fold_{index}",
            # )
            # run.log({f"auc_fold_{index}": auc})
            # run.log({f"accuracy_fold_{index}": accuracy})
            # run.log({f"f1_fold_{index}": f1})

            aucs.append(auc)
            accuracies.append(accuracy)
            f1s.append(f1)
            result = {
            "auc": auc,
            "accuracie": accuracy,
            "f1": f1,
            }
            
            with open(f'emb_results_{val_city_name}_{test_city_name}_{sweep_id}.json', 'w') as file:
                json.dump(result, file)
            print(result)
            del model_path, test_data, train_data, val_data
            gc.collect()

            del auc, accuracy, f1
            gc.collect()

        # logging - summary statistics
        mean_auc = sum(aucs) / len(aucs)
        mean_accuracy = sum(accuracies) / len(accuracies)
        mean_f1 = sum(f1s) / len(f1s)
        # run.log({"mean_auc": mean_auc})
        # run.log({"mean_accuracy": mean_accuracy})
        # run.log({"mean_f1": mean_f1})
        results.append({(val_city_name, test_city_name, sweep_id): {
            "mean_auc": mean_auc,
            "mean_accuracie": mean_accuracy,
            "mean_f1": mean_f1,
            }})
        with open(f'emb_results_{sweep_id}.json', 'w') as file:
            json.dump(results, file)
        del mean_auc, mean_accuracy, mean_f1, model_path
        gc.collect()
        
    wrapped(hparams)


# def run_k_fold_tabular_data(closure_config, sweep_id):
#     # analogously to the graph data, but for tabular data
#     def wrapped():
#         run = wandb.init()

#         config = wandb.config

#         for k, v in closure_config.items():
#             run.log({k: 1 if v else 0})

#         run.log({"data_structure": "tabular"})

#         hparams = {}
#         hparams["C"] = config["C"]
#         solver, penalty = config["solver_penalty"].split(";")
#         hparams["solver"] = solver
#         if penalty == "None":
#             penalty = None
#         hparams["penalty"] = penalty

#         aucs = []
#         accuracies = []
#         f1s = []

#         fold_group_id = generate_id()

#         # log data as artifact
#         artifact_path = os.path.join(TRAIN_SAVE_DIR, f"tabular_data_{sweep_id}.pkl")

#         if not os.path.exists(artifact_path):
#             dump(
#                 tabular_data_dict,
#                 artifact_path,
#                 protocol=5,
#             )
#             artifact = wandb.Artifact(
#                 name="tabular_data", type="dataset", metadata=closure_config
#             )
#             artifact.add_file(local_path=artifact_path)
#             run.log_artifact(artifact)

#         timestamp = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")

#         for index, test_city_name in enumerate(cities_names_list):
#             scaler = StandardScaler()
#             X = pd.concat(
#                 [
#                     m["X"]
#                     for key, m in tabular_data_dict.items()
#                     if key != test_city_name
#                 ]
#             ).to_numpy()
#             y = (
#                 pd.concat(
#                     [
#                         m["y"]
#                         for key, m in tabular_data_dict.items()
#                         if key != test_city_name
#                     ]
#                 )
#                 .to_numpy()
#                 .ravel()
#             )

#             X = scaler.fit_transform(X)

#             logistic_regression = LogisticRegression(
#                 C=hparams["C"],
#                 solver=hparams["solver"],
#                 penalty=hparams["penalty"],
#                 dual=False,
#                 tol=1e-4,
#                 fit_intercept=True,
#                 intercept_scaling=1,
#                 class_weight="balanced",
#                 random_state=1124,
#                 max_iter=1000,
#                 multi_class="auto",
#                 warm_start=False,
#                 n_jobs=-1,
#                 l1_ratio=0.5,
#             )
#             logistic_regression.fit(X, y)

#             test_X = tabular_data_dict[test_city_name]["X"].to_numpy()
#             test_X = scaler.transform(test_X)
#             test_y = tabular_data_dict[test_city_name]["y"].to_numpy().ravel()
#             y_pred = logistic_regression.predict(test_X)
#             y_proba = logistic_regression.predict_proba(test_X)[:, 1]

#             auc = roc_auc_score(test_y, y_proba, average="micro")
#             accuracy = (y_pred == test_y).mean()
#             f1 = f1_score(
#                 test_y,
#                 y_pred,
#                 pos_label=1,
#                 average="binary",
#             )

#             model_dir = os.path.join(TRAIN_SAVE_DIR, timestamp)

#             os.makedirs(model_dir, exist_ok=True)

#             model_path = os.path.join(
#                 model_dir, f"model_{fold_group_id}_fold_{index}.pkl"
#             )

#             with open(model_path, "wb") as f:
#                 dump(logistic_regression, f, protocol=5)

#             run.log_model(
#                 path=model_path,
#                 name=f"model_{fold_group_id}_fold_{index}",
#             )
#             run.log({f"auc_fold_{index}": auc})
#             run.log({f"accuracy_fold_{index}": accuracy})
#             run.log({f"f1_fold_{index}": f1})

#             aucs.append(auc)
#             accuracies.append(accuracy)
#             f1s.append(f1)
            

#         mean_auc = sum(aucs) / len(aucs)
#         mean_accuracy = sum(accuracies) / len(accuracies)
#         mean_f1 = sum(f1s) / len(f1s)
#         run.log({"mean_auc": mean_auc})
#         run.log({"mean_accuracy": mean_accuracy})
#         run.log({"mean_f1": mean_f1})

#     return wrapped


def run_sweep_graph_data(config, configs_list):
    """Run the k-fold evaluation once per entry of ``configs_list``.

    ``config`` is forwarded unchanged to every run; the position of each
    hyper-parameter set in ``configs_list`` is used as its sweep id.
    """
    for sweep_index, sweep_hparams in enumerate(configs_list):
        run_k_fold_graph_data(config, sweep_index, sweep_hparams)
        

# def run_sweep_tabular_data(config):
#     try:
#         wandb.login(key=WANDB_API_KEY)

#         sweep_id = wandb.sweep(
#             WANDB_SWEEP_PARAMS_TABULAR_DATA, project="maps_test"
#         )

#         wandb.agent(
#             sweep_id,
#             function=run_k_fold_tabular_data(config, sweep_id),
#             count=SWEEP_RUNS_COUNT,
#         )
#     except Exception as e:
#         print(e)
#         wandb.finish()
#         raise e

# 12. Run functions

For each config:

1. Determine if config requires tabular or graph data
2. Create data excluding attributes not included in the config
3. Run the sweep


In [18]:
# Hand-picked hyper-parameter sets (small, medium, large) evaluated in place of a W&B sweep.
# Each entry's position in the list is used as its sweep id.
configs_list = [
    dict(
        hidden_channels=10,
        learning_rate=1e-5,
        num_conv_layers=1,
        lin_layer_size=8,
        num_lin_layers=0,
        weight_decay=1e-5,
    ),
    dict(
        hidden_channels=30,
        learning_rate=1e-5,
        num_conv_layers=3,
        lin_layer_size=32,
        num_lin_layers=2,
        weight_decay=1e-5,
    ),
    dict(
        hidden_channels=50,
        learning_rate=1e-5,
        num_conv_layers=5,
        lin_layer_size=128,
        num_lin_layers=4,
        weight_decay=1e-5,
    ),
]

for config in configs_list:
    print(config)

{'hidden_channels': 10, 'learning_rate': 1e-05, 'num_conv_layers': 1, 'lin_layer_size': 8, 'num_lin_layers': 0, 'weight_decay': 1e-05}
{'hidden_channels': 30, 'learning_rate': 1e-05, 'num_conv_layers': 3, 'lin_layer_size': 32, 'num_lin_layers': 2, 'weight_decay': 1e-05}
{'hidden_channels': 50, 'learning_rate': 1e-05, 'num_conv_layers': 5, 'lin_layer_size': 128, 'num_lin_layers': 4, 'weight_decay': 1e-05}


In [19]:
import os

# Tell wandb the notebook's file name explicitly through its environment variable.
os.environ.update({'WANDB_NOTEBOOK_NAME': 'accidents-prediction.ipynb'})

In [20]:
# Show which attribute configurations are active for the run below.
ATTRIBUTES_CONFIGURATIONS

[{'USE_ORTOPHOTO': True, 'USE_HEXES_ATTRS': True, 'USE_OSMNX_ATTRS': True}]

In [21]:
def derive_data_structure(attr_config):
    """Pick the data representation implied by an attribute configuration.

    Args:
        attr_config: Mapping that contains a ``USE_OSMNX_ATTRS`` entry.

    Returns:
        ``"graph"`` when ``attr_config["USE_OSMNX_ATTRS"]`` is truthy,
        otherwise ``"tabular"``.

    Raises:
        KeyError: If a dict without the ``USE_OSMNX_ATTRS`` key is passed.
    """
    return "graph" if attr_config["USE_OSMNX_ATTRS"] else "tabular"



# Driver loop: validate each attribute configuration and launch the graph sweep
# for it with the hyper-parameter sets from configs_list.
#
# NOTE(review): in the saved run this cell stops with
# "TypeError: keys must be str, int, float, bool or None, not tuple" after five
# training runs have reported their metrics. Presumably results keyed by
# (city, city) tuples are JSON-serialized inside run_sweep_graph_data or a
# helper it calls -- confirm there; the failing code is not in this cell.
configs_size = len(ATTRIBUTES_CONFIGURATIONS)

for index, attr_config in enumerate(ATTRIBUTES_CONFIGURATIONS):
    print(f"Sweep for config {index + 1}/{configs_size} in progress...")

    # Every configuration must state all three feature-source switches.
    for required_key in ("USE_ORTOPHOTO", "USE_HEXES_ATTRS", "USE_OSMNX_ATTRS"):
        assert required_key in attr_config, f"Provide {required_key} key"

    # Disabled: dispatch on the data structure derived from the configuration.
    # Only the graph sweep is currently executed, whatever the switches say.
    #
    # data_structure = derive_data_structure(attr_config)
    #
    # if data_structure == "graph":
    # graph_data_dict = {
    #     city_name: create_graph_data(
    #         **gdfs,
    #         use_ortophoto=attr_config["USE_ORTOPHOTO"],
    #         use_hexes_attr=attr_config["USE_HEXES_ATTRS"],
    #     )
    #     for city_name, gdfs in gdfs_dict.items()
    # }
    run_sweep_graph_data(attr_config, configs_list)
    # elif data_structure == "tabular":
    #     tabular_data_dict = {
    #         city_name: create_tabular_data(
    #             gdfs["hexes"],
    #             gdfs["controller"],
    #             use_ortophoto=attr_config["USE_ORTOPHOTO"],
    #             use_hexes_attr=attr_config["USE_HEXES_ATTRS"],
    #         )
    #         for city_name, gdfs in gdfs_dict.items()
    #     }
    #     run_sweep_tabular_data(attr_config)
    # else:
    #     raise ValueError("Unknown data structure")

Sweep for config 1/1 in progress...
[('Wrocław, Poland', 'Kraków, Poland'), ('Kraków, Poland', 'Poznań, Poland'), ('Poznań, Poland', 'Szczecin, Poland'), ('Szczecin, Poland', 'Warszawa, Poland'), ('Warszawa, Poland', 'Wrocław, Poland')]


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: C:/Users/kiera/Desktop/gradient/gradient_logs/2024_06_02_17_26_49\lightning_logs
c:\Users\kiera\Desktop\gradient\.venv\lib\site-packages\pytorch_lightning\utilities\model_summary\model_summary.py:454: A layer with UninitializedParameter was found. Thus, the total number of parameters detected may be inaccurate.

  | Name  | Type      | Params
------------------------------------
0 | model | HeteroGNN | 4.6 M 
------------------------------------
4.6 M     Trainable params
0         Non-trainable params
4.6 M     Total params
18.419    Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

c:\Users\kiera\Desktop\gradient\.venv\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.


                                                                           

c:\Users\kiera\Desktop\gradient\.venv\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.


Epoch 2: 100%|██████████| 3/3 [00:14<00:00,  0.20it/s, v_num=0]

`Trainer.fit` stopped: `max_epochs=3` reached.


Epoch 2: 100%|██████████| 3/3 [00:14<00:00,  0.20it/s, v_num=0]
{'auc': 0.70692738650728, 'accuracie': 0.6956394498097747, 'f1': 0.5144724556489263}


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: C:/Users/kiera/Desktop/gradient/gradient_logs/2024_06_02_17_33_56\lightning_logs

  | Name  | Type      | Params
------------------------------------
0 | model | HeteroGNN | 4.6 M 
------------------------------------
4.6 M     Trainable params
0         Non-trainable params
4.6 M     Total params
18.419    Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

c:\Users\kiera\Desktop\gradient\.venv\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.


                                                                           

c:\Users\kiera\Desktop\gradient\.venv\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.


Epoch 2: 100%|██████████| 3/3 [00:22<00:00,  0.14it/s, v_num=0]

`Trainer.fit` stopped: `max_epochs=3` reached.


Epoch 2: 100%|██████████| 3/3 [00:22<00:00,  0.13it/s, v_num=0]
{'auc': 0.7811089943370191, 'accuracie': 0.6030560271646859, 'f1': 0.5820521987844118}


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: C:/Users/kiera/Desktop/gradient/gradient_logs/2024_06_02_17_40_35\lightning_logs

  | Name  | Type      | Params
------------------------------------
0 | model | HeteroGNN | 4.6 M 
------------------------------------
4.6 M     Trainable params
0         Non-trainable params
4.6 M     Total params
18.419    Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

c:\Users\kiera\Desktop\gradient\.venv\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.


                                                                           

c:\Users\kiera\Desktop\gradient\.venv\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.


Epoch 2: 100%|██████████| 3/3 [00:12<00:00,  0.24it/s, v_num=0]

`Trainer.fit` stopped: `max_epochs=3` reached.


Epoch 2: 100%|██████████| 3/3 [00:12<00:00,  0.24it/s, v_num=0]
{'auc': 0.6172546717852725, 'accuracie': 0.6061120543293718, 'f1': 0.35435992578849723}


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: C:/Users/kiera/Desktop/gradient/gradient_logs/2024_06_02_17_46_22\lightning_logs

  | Name  | Type      | Params
------------------------------------
0 | model | HeteroGNN | 4.6 M 
------------------------------------
4.6 M     Trainable params
0         Non-trainable params
4.6 M     Total params
18.419    Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

c:\Users\kiera\Desktop\gradient\.venv\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.


                                                                           

c:\Users\kiera\Desktop\gradient\.venv\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.


Epoch 2: 100%|██████████| 3/3 [00:10<00:00,  0.29it/s, v_num=0]

`Trainer.fit` stopped: `max_epochs=3` reached.


Epoch 2: 100%|██████████| 3/3 [00:10<00:00,  0.29it/s, v_num=0]
{'auc': 0.6262244244233385, 'accuracie': 0.3630493273542601, 'f1': 0.48558597711140083}


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: C:/Users/kiera/Desktop/gradient/gradient_logs/2024_06_02_17_51_40\lightning_logs

  | Name  | Type      | Params
------------------------------------
0 | model | HeteroGNN | 4.6 M 
------------------------------------
4.6 M     Trainable params
0         Non-trainable params
4.6 M     Total params
18.419    Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

c:\Users\kiera\Desktop\gradient\.venv\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.


                                                                           

c:\Users\kiera\Desktop\gradient\.venv\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.


Epoch 2: 100%|██████████| 3/3 [00:23<00:00,  0.13it/s, v_num=0]

`Trainer.fit` stopped: `max_epochs=3` reached.


Epoch 2: 100%|██████████| 3/3 [00:23<00:00,  0.13it/s, v_num=0]
{'auc': 0.5993568221551177, 'accuracie': 0.3317550505050505, 'f1': 0.4471141290154087}


TypeError: keys must be str, int, float, bool or None, not tuple