# 1. Imports


In [1]:
import sys
sys.path.append('../../') 

import geopandas as gpd
from src.organized_datasets_creation.utils import resolve_nominatim_city_name
from src.organized_datasets_creation.utils import convert_nominatim_name_to_filename
from src.graph_layering.graph_layer_creator import GraphLayerController
import pandas as pd
from typing import cast
import os
from src.graph_layering.graph_layer_creator import SourceType
import warnings
from src.graph_layering.create_hetero_data import create_hetero_data

from tqdm import tqdm

import wandb.util
import wandb
import os

import numpy as np
from src.graph.create_osmnx_graph import OSMnxGraph
import json
from shapely.geometry import Point
from joblib import dump


from datetime import datetime
from sklearn.metrics import f1_score, roc_auc_score
from wandb.util import generate_id
from sklearn.linear_model import LogisticRegression
from src.training.train import train
from sklearn.preprocessing import StandardScaler

# 2. Setting up env variables and configs


In [2]:
WANDB_API_KEY = os.environ.get("WANDB_API_KEY", None)
assert (
    WANDB_API_KEY is not None
), "WANDB_API_KEY is not set, did you forget it in the config file?"

In [3]:
pd.read_csv(f"../../data/downstream_tasks/zabka_shops/zabka_shops.csv")

Unnamed: 0,id,slug,openTime,closeTime,city,address,postcode,voivodeship,county,community,region,salesRegion,openTimeSeconds,closeTimeSeconds,lat,lng,services
0,ID06093,"ID06093,gdansk-jabloniowa-29a",06:00,23:00,Gdańsk,Jabłoniowa 29A,80-175,Pomorskie,Gdańsk,GDAŃSK (GMINA MIEJSKA),DS3,PS3.6.3,21600.0,82800.0,54.330567,18.557187,"BIH,DEN,GSM,KPO,LOT,ODP,PAC,RAC,REJ,TER,ZBC"
1,ID03871,"ID03871,gorzow-wielkopolski-obroncow-pokoju-38...",06:00,23:00,Gorzów Wielkopolski,Obrońców Pokoju 38 nr 38 I,66-400,Lubuskie,Gorzów Wielkopolski,GORZÓW WIELKOPOLSKI (GMINA MIE,DS2,PS2.5.3,21600.0,82800.0,52.764806,15.264941,"BIH,DEN,GSM,KPO,LOT,ODP,PAC,RAC,REJ,TER,ZBC"
2,ID06169,"ID06169,ruda-slaska-ul-niedurnego-45-lok-1",06:00,23:00,Ruda Śląska,ul. Niedurnego 45 lok. 1,41-709,Śląskie,Ruda Śląska,RUDA ŚLĄSKA (GMINA MIEJSKA),DS8,PS8.2.2,21600.0,82800.0,50.284140,18.876000,"BIH,DEN,GSM,KPO,LOT,ODP,PAC,RAC,REJ,TER,ZBC"
3,ID06264,"ID06264,warszawa-ul-zeganska-18",06:00,23:00,Warszawa,ul. Żegańska 18,04-713,Mazowieckie,Warszawa,"WARSZAWA (GMINA MIEJSKA, MIAST",DS5,PS5.4.1,21600.0,82800.0,52.205620,21.172909,"BIH,DEN,GSM,KPO,LOT,ODP,PAC,RAC,REJ,TER,ZBC"
4,ID05100,"ID05100,wejherowo-dworcowa-2",06:00,23:00,Wejherowo,Dworcowa 2,84-200,Pomorskie,wejherowski,WEJHEROWO (GMINA MIEJSKA),DS3,PS3.2.3,21600.0,82800.0,54.603632,18.228519,"BIH,DEN,GSM,KPO,LOT,ODP,PAC,RAC,REJ,TER,ZBC"
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9754,ID0A221,"ID0A221,kedzierzyn-kozle-ul-gliwicka-1a",06:00,23:00,Kędzierzyn-Koźle,UL. GLIWICKA 1A,47-224,Opolskie,kędzierzyńsko-kozielski,KĘDZIERZYN-KOŹLE (GMINA MIEJSK,DS7,PS7.5.1,21600.0,82800.0,50.336754,18.183524,"DEN,GSM,KPO,ODP,PAC,RAC,REJ,TER,ZBC"
9755,ID0B739,"ID0B739,lewin-brzeski-ul-kosciuszki-37d",06:00,23:00,Lewin Brzeski,UL. KOŚCIUSZKI 37D,49-340,Opolskie,brzeski (woj. opolskie),LEWIN BRZESKI (GMINA MIEJSKO-W,DS7,PS7.5.5,21600.0,82800.0,50.751824,17.615581,"DEN,GSM,KPO,ODP,PAC,RAC,REJ,TER,ZBC"
9756,ID0C796,"ID0C796,uniejow-ul-targowa-18",00:00,23:59,Uniejów,UL. Targowa 18,99-210,Łódzkie,poddębicki,UNIEJÓW (GMINA MIEJSKO-WIEJSKA,DS6,PS6.6.5,0.0,86340.0,51.976362,18.791962,"DEN,GSM,KPO,ODP,PAC,RAC,REJ,TER,ZBC"
9757,ID0D004,"ID0D004,warszawa-ul-zamkowa-8-lok-u1",06:00,23:00,Warszawa,UL. Zamkowa 8 lok. U1 .,03-890,Mazowieckie,Warszawa,"WARSZAWA (GMINA MIEJSKA, MIAST",DS5,PS5.4.1,21600.0,82800.0,52.274640,21.081131,"DEN,GSM,KPO,ODP,PAC,RAC,REJ,TER,ZBC"


In [4]:
WANDB_API_KEY= ""

# general settings
ORGANIZED_HEXES_LOCATION = "/home/grymar/studia/gradient/data/organized-hexes"
ORGANIZED_GRAPHS_LOCATION = "/home/grymar/studia/gradient/data/organized_graphs"
OSMNX_ALL_ATTRIBUTES_LOCATION = (
    "../../data/osmnx_attributes.json"
)

HEX_FI_LOCATION = (
    "/home/grymar/studia/gradient/data/downstream_tasks/feature_importance"
)

# downstream task settings
ALL_ZABKAS_LOCATION = "../../data/downstream_tasks/zabka_shops/zabka_shops.csv"
# ZABKA_DIR = "/home/grymar/tmp/"

TRAIN_SAVE_DIR = "/home/grymar/tmp/"

SWEEP_RUNS_COUNT = 50
EPOCHS = 300

ATTRIBUTES_CONFIGURATIONS = [
    # {
    #     "USE_ORTOPHOTO": False,
    #     "USE_HEXES_ATTRS": True,
    #     "USE_OSMNX_ATTRS": True,
    # },
    # {
    #     "USE_ORTOPHOTO": False,
    #     "USE_HEXES_ATTRS": {"NUM_FEATURES": 20, "IN_PERCENT": False},
    #     "USE_OSMNX_ATTRS": True,
    # },
    # {
    #     "USE_ORTOPHOTO": False,
    #     "USE_HEXES_ATTRS": {"NUM_FEATURES": 20, "IN_PERCENT": True},
    #     "USE_OSMNX_ATTRS": True,
    # },
    # {
    #     "USE_ORTOPHOTO": False,
    #     "USE_HEXES_ATTRS": {"NUM_FEATURES": 50, "IN_PERCENT": False},
    #     "USE_OSMNX_ATTRS": True,
    # },
    # {
    #     "USE_ORTOPHOTO": False,
    #     "USE_HEXES_ATTRS": {"NUM_FEATURES": 50, "IN_PERCENT": True},
    #     "USE_OSMNX_ATTRS": True,
    # },
     {
        "USE_ORTOPHOTO": False,
        "USE_HEXES_ATTRS": {"NUM_FEATURES": 20, "IN_PERCENT": False},
        "USE_OSMNX_ATTRS": False,
    },
    {
        "USE_ORTOPHOTO": False,
        "USE_HEXES_ATTRS": {"NUM_FEATURES": 20, "IN_PERCENT": True},
        "USE_OSMNX_ATTRS": False,
    },
    {
        "USE_ORTOPHOTO": False,
        "USE_HEXES_ATTRS": {"NUM_FEATURES": 50, "IN_PERCENT": False},
        "USE_OSMNX_ATTRS": False,
    },
    {
        "USE_ORTOPHOTO": False,
        "USE_HEXES_ATTRS": {"NUM_FEATURES": 50, "IN_PERCENT": True},
        "USE_OSMNX_ATTRS": False,
    },
    # {
    #     "USE_ORTOPHOTO": False,
    #     "USE_HEXES_ATTRS": False,
    #     "USE_OSMNX_ATTRS": True,
    # },
    # {
    #     "USE_ORTOPHOTO": False,
    #     "USE_HEXES_ATTRS": True,
    #     "USE_OSMNX_ATTRS": False,
    # },
]
WANDB_SWEEP_PARAMS_GRAPH_DATA = {
    "method": "bayes",
    "metric": {"name": "mean_f1", "goal": "maximize"},
    "parameters": {
        "hidden_channels": {"values": [10, 20, 30, 40, 50]},
        "learning_rate": {
            "distribution": "log_uniform_values",
            "min": 1e-5,
            "max": 1e-2,
        },
        "num_conv_layers": {"values": [1, 2, 3, 4, 5]},
        "lin_layer_size": {"values": [8, 16, 32, 64, 128]},
        "num_lin_layers": {"values": [0, 1, 2, 3, 4]},
        "weight_decay": {
            "distribution": "log_uniform_values",
            "min": 1e-5,
            "max": 1e-2,
        },
    },
}

WANDB_SWEEP_PARAMS_TABULAR_DATA = {
    "method": "bayes",
    "metric": {"name": "mean_f1", "goal": "maximize"},
    "parameters": {
        "solver_penalty": {
            "values": [
                "lbfgs;l2",
                "liblinear;l1",
                "liblinear;l2",
                "newton-cg;l2",
                "newton-cholesky;l2",
                "sag;l2",
                "saga;elasticnet",
                "saga;l1",
                "saga;l2",
            ]
        },
        "C": {
            "distribution": "log_uniform_values",
            "min": 1e-5,
            "max": 1,
        },
    },
}

In [5]:
def verify_attributes_configurations(configurations):
    for item in configurations:
        assert "USE_ORTOPHOTO" in item and isinstance(
            item["USE_ORTOPHOTO"], bool
        ), f"Invalid configuration: {item}, missing or invalid USE_ORTOPHOTO"
        assert "USE_OSMNX_ATTRS" in item and isinstance(
            item["USE_OSMNX_ATTRS"], bool
        ), f"Invalid configuration: {item}, missing or invalid USE_OSMNX_ATTRS"

        assert "USE_HEXES_ATTRS" in item, "Missing USE_HEXES_ATTRS"
        if not isinstance(item["USE_HEXES_ATTRS"], bool):
            assert isinstance(
                item["USE_HEXES_ATTRS"], dict
            ), "USE_HEXES_ATTRS should be a dict"
            assert "NUM_FEATURES" in item["USE_HEXES_ATTRS"], "Missing NUM_FEATURES"
            assert "IN_PERCENT" in item["USE_HEXES_ATTRS"], "Missing IN_PERCENT"
            assert isinstance(
                item["USE_HEXES_ATTRS"]["NUM_FEATURES"], int
            ), "NUM_FEATURES should be an int"
            assert isinstance(
                item["USE_HEXES_ATTRS"]["IN_PERCENT"], bool
            ), "IN_PERCENT should be a bool"


verify_attributes_configurations(ATTRIBUTES_CONFIGURATIONS)

# 3. Loading zabka

The process includes removing unused columns and creating GeoSeries from raw X Y points


In [6]:
all_zabkas = pd.read_csv(ALL_ZABKAS_LOCATION)


def create_point(x):
    return Point(float(x[0]), float(x[1]))


geometry = all_zabkas[["lng", "lat"]].apply(create_point, axis=1)

gdf_zabka = gpd.GeoDataFrame(all_zabkas, geometry=geometry, crs="EPSG:4326")
gdf_zabka.drop(columns=["lng", "lat"], inplace=True)
gdf_zabka['zabka_count'] = 1

# 4. Displaying available cities


In [7]:
cities = ["Kraków", "Poznań", "Szczecin", "Wrocław", "Warszawa"]

In [8]:
krk = gpd.read_parquet("/home/grymar/studia/gradient/data/organized-hexes/krakow/2022/h9/count-embedder/dataset.parquet")
krk

Unnamed: 0,region_id,geometry,amenity_gym,building_garages,landuse_gravel,sport_ultimate,office_bakery,natural_shrubbery,landuse_gress,building_guard_cabin,...,shop_eggs,historic_heritage,building_government,aeroway_navigationaid,historic_park,historic_train_station,shop_hobby,building_floating_home,amenity_vacuum_cleaner,building_castle
0,891e2e698c3ffff,"POLYGON ((20.00258 49.99518, 20.00254 49.99346...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
1,891e2e79abbffff,"POLYGON ((19.88302 50.10507, 19.88299 50.10336...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,891e2e693d7ffff,"POLYGON ((19.93560 49.99747, 19.93557 49.99576...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3,891e2e6b09bffff,"POLYGON ((19.90475 50.07181, 19.90472 50.07010...",0,3,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,891e2e79b13ffff,"POLYGON ((19.91032 50.10264, 19.91028 50.10093...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3412,891e2e68137ffff,"POLYGON ((20.04835 50.04672, 20.04831 50.04501...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3413,891e2e61b77ffff,"POLYGON ((20.16126 50.09584, 20.16122 50.09413...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3414,891e05a64b3ffff,"POLYGON ((19.84877 49.99962, 19.84874 49.99791...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3415,891e2e7994fffff,"POLYGON ((19.92289 50.11039, 19.92286 50.10868...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


# 5. Creating GeoDataFrames

The process of creation has following steps:

1. loading OSMNX nodes and edges
2. assigning accidents to OSMNX nodes
3. taking latest H9 resolution hexes
4. combining OSMNX nodes, OSMNX edges, hexes in a single dict and packing it inside gdfs_dict


In [9]:
def add_accidents_to_osmnx_nodes(
    zabkas: gpd.GeoDataFrame,
    nodes: gpd.GeoDataFrame,
    edges: gpd.GeoDataFrame,
    city_name: str,
):
    with open(OSMNX_ALL_ATTRIBUTES_LOCATION) as f:
        all_attributes = json.load(f)

    osmnx_graph = OSMnxGraph(
        zabkas.loc[
            zabkas["county"] == resolve_nominatim_city_name(city_name), :
        ],
        nodes,
        edges,
        all_attributes,
        y_column_name="zabka_count"
    )
    osmnx_graph._aggregate(element_type="node", aggregation_method="count")
    return osmnx_graph.gdf_nodes


def create_gdfs(city_name: str, accidents_gdf: gpd.GeoDataFrame = gdf_zabka):
    city_folder_name = convert_nominatim_name_to_filename(
        resolve_nominatim_city_name(city_name)
    )
    osmnx_nodes = gpd.read_parquet(
        os.path.join(ORGANIZED_GRAPHS_LOCATION, city_folder_name, "nodes.parquet")
    )
    osmnx_nodes = osmnx_nodes.reset_index()
    osmnx_nodes.index.names = ["node_id"]
    osmnx_nodes["x"] = osmnx_nodes["geometry"].x
    osmnx_nodes["y"] = osmnx_nodes["geometry"].y

    osmnx_edges = gpd.read_parquet(
        os.path.join(ORGANIZED_GRAPHS_LOCATION, city_folder_name, "edges.parquet")
    )
    osmnx_edges = osmnx_edges.reset_index().rename(columns={"index": "edge_id"})
    osmnx_edges.index.names = ["edge_id"]

    assert osmnx_nodes.crs == osmnx_edges.crs
    assert osmnx_nodes.crs == accidents_gdf.crs

    osmnx_nodes = add_accidents_to_osmnx_nodes(
        zabkas=accidents_gdf,
        nodes=osmnx_nodes,
        city_name=city_name,
        edges=osmnx_edges,
    )

    hexes_years_folder = os.path.join(ORGANIZED_HEXES_LOCATION, city_folder_name)

    subfolders = [
        int(f)
        for f in os.listdir(hexes_years_folder)
        if os.path.isdir(os.path.join(hexes_years_folder, f))
    ]
    highest_year = subfolders[np.argmax(subfolders)]

    hexes: gpd.GeoDataFrame = gpd.read_parquet(
        os.path.join(
            ORGANIZED_HEXES_LOCATION,
            f"{convert_nominatim_name_to_filename(resolve_nominatim_city_name(city_name))}/{highest_year}/h9/count-embedder/dataset.parquet",
        )
    )

    hexes = hexes.rename(columns={"region_id": "h3_id"}).rename_axis(
        "region_id", axis=0
    )
    display(hexes)

    return dict(osmnx_nodes=osmnx_nodes, osmnx_edges=osmnx_edges, hexes=hexes)


print("Creating gdfs...")
gdfs_dict = {city_name: create_gdfs(city_name+ ", Poland") for city_name in tqdm(cities)}

Creating gdfs...


  0%|          | 0/5 [00:00<?, ?it/s]

Unnamed: 0_level_0,h3_id,geometry,amenity_gym,building_garages,landuse_gravel,sport_ultimate,office_bakery,natural_shrubbery,landuse_gress,building_guard_cabin,...,shop_eggs,historic_heritage,building_government,aeroway_navigationaid,historic_park,historic_train_station,shop_hobby,building_floating_home,amenity_vacuum_cleaner,building_castle
region_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0,891e2e698c3ffff,"POLYGON ((20.00258 49.99518, 20.00254 49.99346...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
1,891e2e79abbffff,"POLYGON ((19.88302 50.10507, 19.88299 50.10336...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,891e2e693d7ffff,"POLYGON ((19.93560 49.99747, 19.93557 49.99576...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3,891e2e6b09bffff,"POLYGON ((19.90475 50.07181, 19.90472 50.07010...",0,3,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,891e2e79b13ffff,"POLYGON ((19.91032 50.10264, 19.91028 50.10093...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3412,891e2e68137ffff,"POLYGON ((20.04835 50.04672, 20.04831 50.04501...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3413,891e2e61b77ffff,"POLYGON ((20.16126 50.09584, 20.16122 50.09413...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3414,891e05a64b3ffff,"POLYGON ((19.84877 49.99962, 19.84874 49.99791...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3415,891e2e7994fffff,"POLYGON ((19.92289 50.11039, 19.92286 50.10868...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


 20%|██        | 1/5 [00:00<00:01,  2.15it/s]

Unnamed: 0_level_0,h3_id,geometry,amenity_gym,building_garages,landuse_gravel,sport_ultimate,office_bakery,natural_shrubbery,landuse_gress,building_guard_cabin,...,shop_eggs,historic_heritage,building_government,aeroway_navigationaid,historic_park,historic_train_station,shop_hobby,building_floating_home,amenity_vacuum_cleaner,building_castle
region_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0,891e24aa333ffff,"POLYGON ((16.92005 52.37637, 16.92009 52.37472...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
1,891e24a338bffff,"POLYGON ((16.95623 52.46905, 16.95627 52.46741...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,891e24aa667ffff,"POLYGON ((16.86364 52.38452, 16.86368 52.38287...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3,891e24aa063ffff,"POLYGON ((16.89428 52.38548, 16.89432 52.38383...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,891e24b838fffff,"POLYGON ((16.80072 52.44687, 16.80076 52.44523...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2940,891e24a803bffff,"POLYGON ((16.91838 52.33916, 16.91842 52.33751...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2941,891e24aa4d7ffff,"POLYGON ((16.86291 52.41420, 16.86295 52.41255...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2942,891e24a85bbffff,"POLYGON ((16.93311 52.36439, 16.93315 52.36274...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2943,891e24ab58bffff,"POLYGON ((16.83312 52.37860, 16.83317 52.37695...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


 40%|████      | 2/5 [00:00<00:01,  2.29it/s]

Unnamed: 0_level_0,h3_id,geometry,amenity_gym,building_garages,landuse_gravel,sport_ultimate,office_bakery,natural_shrubbery,landuse_gress,building_guard_cabin,...,shop_eggs,historic_heritage,building_government,aeroway_navigationaid,historic_park,historic_train_station,shop_hobby,building_floating_home,amenity_vacuum_cleaner,building_castle
region_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0,891f0e7b2d3ffff,"POLYGON ((14.47465 53.44809, 14.47475 53.44647...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
1,891f0e78e8fffff,"POLYGON ((14.65338 53.44017, 14.65348 53.43855...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,891f0e79357ffff,"POLYGON ((14.54000 53.39308, 14.54010 53.39147...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3,891f0e7937bffff,"POLYGON ((14.54545 53.38849, 14.54554 53.38688...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,891f0e4e937ffff,"POLYGON ((14.46769 53.47689, 14.46779 53.47527...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3529,891f1db6a6bffff,"POLYGON ((14.53327 53.33432, 14.53336 53.33270...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3530,891f0e796cfffff,"POLYGON ((14.48220 53.40957, 14.48230 53.40796...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3531,891f0e7a9d3ffff,"POLYGON ((14.68903 53.49299, 14.68913 53.49137...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3532,891f1db69b7ffff,"POLYGON ((14.59808 53.37170, 14.59817 53.37009...",0,2,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


 60%|██████    | 3/5 [00:01<00:00,  2.27it/s]

Unnamed: 0_level_0,h3_id,geometry,amenity_gym,building_garages,landuse_gravel,sport_ultimate,office_bakery,natural_shrubbery,landuse_gress,building_guard_cabin,...,shop_eggs,historic_heritage,building_government,aeroway_navigationaid,historic_park,historic_train_station,shop_hobby,building_floating_home,amenity_vacuum_cleaner,building_castle
region_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0,891e204e4a7ffff,"POLYGON ((17.04297 51.07181, 17.04301 51.07013...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
1,891e204239bffff,"POLYGON ((16.93065 51.15931, 16.93069 51.15763...",0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,891e2041837ffff,"POLYGON ((16.97292 51.06466, 16.97295 51.06298...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3,891e2047223ffff,"POLYGON ((17.09217 51.11872, 17.09221 51.11703...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,891e2046657ffff,"POLYGON ((17.12372 51.17263, 17.12376 51.17095...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3163,891e2041137ffff,"POLYGON ((16.91990 51.08073, 16.91994 51.07905...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3164,891e2042273ffff,"POLYGON ((16.90335 51.14586, 16.90339 51.14418...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3165,891e20430d7ffff,"POLYGON ((16.84084 51.13635, 16.84088 51.13466...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3166,891e2050a13ffff,"POLYGON ((16.88187 51.20572, 16.88191 51.20404...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


 80%|████████  | 4/5 [00:01<00:00,  2.22it/s]

Unnamed: 0_level_0,h3_id,geometry,amenity_gym,building_garages,landuse_gravel,sport_ultimate,office_bakery,natural_shrubbery,landuse_gress,building_guard_cabin,...,shop_eggs,historic_heritage,building_government,aeroway_navigationaid,historic_park,historic_train_station,shop_hobby,building_floating_home,amenity_vacuum_cleaner,building_castle
region_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0,891f53ca5cfffff,"POLYGON ((21.04469 52.34512, 21.04463 52.34346...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
1,891f522611bffff,"POLYGON ((20.93693 52.20629, 20.93687 52.20463...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,891f5352e6fffff,"POLYGON ((21.21538 52.17771, 21.21532 52.17605...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3,891f535046fffff,"POLYGON ((21.16087 52.16057, 21.16080 52.15891...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,891f53ca6a3ffff,"POLYGON ((21.01083 52.33283, 21.01077 52.33118...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
5570,891f53c160bffff,"POLYGON ((21.07095 52.35991, 21.07089 52.35825...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
5571,891f522418fffff,"POLYGON ((20.95831 52.15892, 20.95826 52.15726...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
5572,891f5352b43ffff,"POLYGON ((21.25334 52.16504, 21.25327 52.16338...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
5573,891f53504afffff,"POLYGON ((21.16904 52.17298, 21.16898 52.17132...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


100%|██████████| 5/5 [00:02<00:00,  2.06it/s]


# 6. Creating GraphLayerController for each of the cities

The creation is based on previously made GeoDataFrames. The controller is used to transfer accidents Y values from OSMNX nodes to hexes. It is also used to create complete graph data in case of graph datasets.


In [10]:
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    for gdf_for_city in gdfs_dict.values():
        gdf_for_city["controller"] = GraphLayerController(
            gdf_for_city["hexes"],
            gdf_for_city["osmnx_nodes"],
            gdf_for_city["osmnx_edges"],
        )

# 7. Patching hexes

The y value (1 = accident occured, 0 = no accident) is assigned to each of the hexes according to its underlying OSMNX nodes


In [11]:
def patch_hexes_with_y(
    osmnx_nodes: gpd.GeoDataFrame,
    hexes: gpd.GeoDataFrame,
    controller: GraphLayerController,
):
    virtual_edges = controller.get_virtual_edges_to_hexes(SourceType.OSMNX_NODES)
    hexes_with_y = cast(
        gpd.GeoDataFrame,
        hexes.merge(
            virtual_edges.merge(osmnx_nodes, left_on="source_id", right_index=True)[
                ["region_id", "zabka_count"]
            ]
            .groupby("region_id")
            .sum(),
            left_index=True,
            right_index=True,
            how="left",
        ).fillna(0),
    )
    hexes_with_y["zabka"] = (hexes_with_y["zabka_count"] > 0).astype(int)
    hexes_with_y.drop(columns="zabka_count", inplace=True)
    controller.hexes_gdf = hexes_with_y
    controller._hexes_centroids_gdf = controller._create_hexes_centroids_gdf()

In [12]:
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    for gdfs in gdfs_dict.values():
        patch_hexes_with_y(gdfs["osmnx_nodes"], gdfs["hexes"], gdfs["controller"])

# 8. Creating graph data

Graph data is used when we include OSMNX attributes and in turn maintain the graph structure of the data

The data is created only once for now just to create (train, val, test) folds labels for crossvalidation on graph-based versions of the task


In [13]:
from typing import List, Literal, Union
def create_graph_data(
    osmnx_nodes,
    osmnx_edges,
    hexes,
    controller: GraphLayerController,
    use_hexes_attr: bool,
    use_ortophoto: bool,
    columns_to_take: Union[List[str], Literal["all"]] = []
):
    edges_attr_columns = osmnx_edges.columns[
        ~osmnx_edges.columns.isin(["u", "v", "key", "geometry"])
    ]
    nodes_attr_columns = osmnx_nodes.columns[
        ~osmnx_nodes.columns.isin(["geometry", "x", "y", "osmid", "zabka_count"])
    ]
    if use_hexes_attr:
        hexes_attr_columns = (
            hexes.columns[~hexes.columns.isin(["geometry", "h3_id"])]
            if columns_to_take == "all"
            else columns_to_take
        )
    else:
        hexes_attr_columns = []
    # hexes_attr_columns = (
    #     hexes.columns[~hexes.columns.isin(["geometry", "h3_id"])]
    #     if use_hexes_attr
    #     else []
    # )

    data = create_hetero_data(
        controller,
        hexes_attrs_columns_names=hexes_attr_columns,
        osmnx_edge_attrs_columns_names=edges_attr_columns,
        osmnx_node_attrs_columns_names=nodes_attr_columns,
        virtual_edge_attrs_columns_names=[],
        hexes_y_columns_names=["zabka"],
    )
    return data


graph_data_dict = {
    city_name: create_graph_data(**gdfs, use_ortophoto=True, use_hexes_attr=True)
    for city_name, gdfs in gdfs_dict.items()
}
graph_data_dict

True
True
True
True
True


{'Kraków': CityHeteroData(
   hex={
     x=[3417, 0],
     y=[3417],
   },
   osmnx_node={ x=[66493, 22] },
   (hex, connected_to, hex)={ edge_index=[2, 9911] },
   (osmnx_node, connected_to, osmnx_node)={
     edge_index=[2, 117967],
     edge_attr=[117967, 44],
   },
   (osmnx_node, connected_to, hex)={
     edge_index=[2, 66493],
     edge_attr=[66493, 0],
   }
 ),
 'Poznań': CityHeteroData(
   hex={
     x=[2945, 0],
     y=[2945],
   },
   osmnx_node={ x=[60082, 22] },
   (hex, connected_to, hex)={ edge_index=[2, 8486] },
   (osmnx_node, connected_to, osmnx_node)={
     edge_index=[2, 98654],
     edge_attr=[98654, 44],
   },
   (osmnx_node, connected_to, hex)={
     edge_index=[2, 60082],
     edge_attr=[60082, 0],
   }
 ),
 'Szczecin': CityHeteroData(
   hex={
     x=[3534, 0],
     y=[3534],
   },
   osmnx_node={ x=[64894, 22] },
   (hex, connected_to, hex)={ edge_index=[2, 10163] },
   (osmnx_node, connected_to, osmnx_node)={
     edge_index=[2, 106588],
     edge_attr=[106588

# 9. Creating tabular data

Tabular data is used when we omit OSMNX attributes and in turn lose the graph structure of the data

No folds creation on tabular-based versions of the task - using simple leave-one-out


In [14]:
def create_tabular_data(
    hexes: pd.DataFrame,
    controller: GraphLayerController,
    use_hexes_attr: bool,
    use_ortophoto: bool,
    columns_to_take: Union[List[str], Literal["all"]] = [],
):
    assert use_ortophoto or use_hexes_attr, "Provide at least one data source"
    print(columns_to_take)

    if use_hexes_attr:
        hexes_attr_columns = (
            hexes.columns[~hexes.columns.isin(["geometry", "h3_id"])]
            if columns_to_take == "all"
            else columns_to_take
        )
    else:
        hexes_attr_columns = []

    hexes_y_columns_names = ["zabka"]

    X = hexes[hexes_attr_columns]
    y = controller.hexes_centroids_gdf[hexes_y_columns_names]

    return {"X": X, "y": y}


tabular_data_dict = {
    city_name: create_tabular_data(
        gdfs["hexes"], gdfs["controller"], use_ortophoto=True, use_hexes_attr=True
    )
    for city_name, gdfs in gdfs_dict.items()
}

[]
[]
[]
[]
[]


# 10. Creating folds labels


In [15]:
def shift_elements_right(lst):
    shifted_lst = [lst[-1]] + lst[:-1]
    return shifted_lst


cities_names_list = list(graph_data_dict.keys())
cities_names_list.sort(key=lambda x: str(x))

# val + test
folds_tuples = list(zip(shift_elements_right(cities_names_list), cities_names_list))
display(folds_tuples)

[('Wrocław', 'Kraków'),
 ('Kraków', 'Poznań'),
 ('Poznań', 'Szczecin'),
 ('Szczecin', 'Warszawa'),
 ('Warszawa', 'Wrocław')]

# 11. Functions setup


In [16]:
import torch 
def run_k_fold_graph_data(closure_config, sweep_id, hex_attr_config=None):
    # pass external config (i.e. what attributes are used in the data), closure to avoid passing it to the function directly
    def wrapped():
        run = wandb.init()
        epochs = EPOCHS

        config = wandb.config

        for k, v in closure_config.items():
            run.log({k: str(v) if v else 0})

        run.log({"data_structure": "graph"})
        # if hex_attr_config is not None:
        #     run.log(
        #         {
        #             "hex_features": f"top_{hex_attr_config['NUM_FEATURES']}_percent_{hex_attr_config['IN_PERCENT']}"
        #         }
        #     )
        # else:
        #     run.log({"hex_features": "all"})

        # create hparams
        if hasattr(config, "lin_layer_size") and hasattr(config, "num_lin_layers"):
            lin_layer_sizes = [config.lin_layer_size] * config.num_lin_layers
        else:
            lin_layer_sizes = config.lin_layer_sizes
        hparams = {
            "hidden_channels": config.hidden_channels,
            "lr": config.learning_rate,
            "num_conv_layers": config.num_conv_layers,
            "lin_layer_sizes": lin_layer_sizes,
            "weight_decay": config.weight_decay,
        }

        aucs = []
        accuracies = []
        f1s = []

        fold_group_id = generate_id()

        # log data as artifact if no data was logged in the sweep before
        # dataset is uploaded only on the first run in sweep, because it does not change across runs in sweep
        # in wandb, dataset will be visible on the first run in the sweep
        artifact_path = os.path.join(TRAIN_SAVE_DIR, f"graph_data_{sweep_id}.pkl")
        if not os.path.exists(artifact_path):
            dump(
                graph_data_dict,
                artifact_path,
                protocol=5,
            )
            artifact = wandb.Artifact(
                name="graph_data", type="dataset", metadata=closure_config
            )
            artifact.add_file(local_path=artifact_path)
            run.log_artifact(artifact)

        # run k-fold
        for index, (val_city_name, test_city_name) in enumerate(folds_tuples):
            # prepare data
            val_data = [graph_data_dict[val_city_name].to("cpu").clone()]
            train_data = [
                v.to("cpu").clone()
                for k, v in graph_data_dict.items()
                if k != val_city_name and k != test_city_name
            ]
            test_data = graph_data_dict[test_city_name].to("cpu").clone()

            # run training with checkpointing on lowest val_loss, return test metrics for the best model and its path
            # builtin preprocessing - scaling to N(0, 1)
            auc, accuracy, f1, model_path = train(
                train_data=train_data,
                val_data=val_data,
                test_data=test_data,
                epochs=epochs,
                hparams=hparams,
                train_save_dir=TRAIN_SAVE_DIR,
                num_classes=torch.unique(train_data[0]["hex"].y).shape[0]
            )

            # logging - single fold
            run.log_model(
                path=model_path,
                name=f"model_{fold_group_id}_fold_{index}",
            )
            run.log({f"auc_fold_{index}": auc})
            run.log({f"accuracy_fold_{index}": accuracy})
            run.log({f"f1_fold_{index}": f1})

            aucs.append(auc)
            accuracies.append(accuracy)
            f1s.append(f1)

        # logging - summary statistics
        mean_auc = sum(aucs) / len(aucs)
        mean_accuracy = sum(accuracies) / len(accuracies)
        mean_f1 = sum(f1s) / len(f1s)
        run.log({"mean_auc": mean_auc})
        run.log({"mean_accuracy": mean_accuracy})
        run.log({"mean_f1": mean_f1})

    return wrapped


def run_k_fold_tabular_data(closure_config, sweep_id, hex_attr_config=None):
    # analogously to the graph data, but for tabular data
    def wrapped():
        run = wandb.init()

        config = wandb.config
        for k, v in closure_config.items():
            run.log({k: str(v) if v else 0})

        # run.log({"data_structure": "tabular"})
        # if hex_attr_config is not None:
        #     run.log(
        #         {
        #             "hex_features": f"top_{hex_attr_config['NUM_FEATURES']}_percent_{hex_attr_config['IN_PERCENT']}"
        #         }
        #     )
        # else:
        #     run.log({"hex_features": "all"})

        hparams = {}
        hparams["C"] = config["C"]
        solver, penalty = config["solver_penalty"].split(";")
        hparams["solver"] = solver
        if penalty == "None":
            penalty = None
        hparams["penalty"] = penalty

        aucs = []
        accuracies = []
        f1s = []

        fold_group_id = generate_id()

        # log data as artifact
        artifact_path = os.path.join(TRAIN_SAVE_DIR, f"tabular_data_{sweep_id}.pkl")

        if not os.path.exists(artifact_path):
            dump(
                tabular_data_dict,
                artifact_path,
                protocol=5,
            )
            artifact = wandb.Artifact(
                name="tabular_data", type="dataset", metadata=closure_config
            )
            artifact.add_file(local_path=artifact_path)
            run.log_artifact(artifact)

        timestamp = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")

        for index, test_city_name in enumerate(cities_names_list):
            scaler = StandardScaler()
            X = pd.concat(
                [
                    m["X"]
                    for key, m in tabular_data_dict.items()
                    if key != test_city_name
                ]
            ).to_numpy()
            y = (
                pd.concat(
                    [
                        m["y"]
                        for key, m in tabular_data_dict.items()
                        if key != test_city_name
                    ]
                )
                .to_numpy()
                .ravel()
            )

            X = scaler.fit_transform(X)

            logistic_regression = LogisticRegression(
                C=hparams["C"],
                solver=hparams["solver"],
                penalty=hparams["penalty"],
                dual=False,
                tol=1e-4,
                fit_intercept=True,
                intercept_scaling=1,
                class_weight="balanced",
                random_state=1124,
                max_iter=1000,
                multi_class="auto",
                warm_start=False,
                n_jobs=-1,
                l1_ratio=0.5,
            )
            logistic_regression.fit(X, y)

            test_X = tabular_data_dict[test_city_name]["X"].to_numpy()
            test_X = scaler.transform(test_X)
            test_y = tabular_data_dict[test_city_name]["y"].to_numpy().ravel()
            y_pred = logistic_regression.predict(test_X)
            y_proba = logistic_regression.predict_proba(test_X)[:, 1]

            auc = roc_auc_score(test_y, y_proba, average="micro")
            accuracy = (y_pred == test_y).mean()
            f1 = f1_score(
                test_y,
                y_pred,
                pos_label=1,
                average="binary",
            )

            model_dir = os.path.join(TRAIN_SAVE_DIR, timestamp)

            os.makedirs(model_dir, exist_ok=True)

            model_path = os.path.join(
                model_dir, f"model_{fold_group_id}_fold_{index}.pkl"
            )

            with open(model_path, "wb") as f:
                dump(logistic_regression, f, protocol=5)

            run.log_model(
                path=model_path,
                name=f"model_{fold_group_id}_fold_{index}",
            )
            run.log({f"auc_fold_{index}": auc})
            run.log({f"accuracy_fold_{index}": accuracy})
            run.log({f"f1_fold_{index}": f1})

            aucs.append(auc)
            accuracies.append(accuracy)
            f1s.append(f1)

        mean_auc = sum(aucs) / len(aucs)
        mean_accuracy = sum(accuracies) / len(accuracies)
        mean_f1 = sum(f1s) / len(f1s)
        run.log({"mean_auc": mean_auc})
        run.log({"mean_accuracy": mean_accuracy})
        run.log({"mean_f1": mean_f1})

    return wrapped


def run_sweep_graph_data(config, hex_attr_config=None):
    try:
        wandb.login(key=WANDB_API_KEY)
        sweep_id = wandb.sweep(
            WANDB_SWEEP_PARAMS_GRAPH_DATA, project="zabka-downstream-task"
        )

        wandb.agent(
            sweep_id,
            function=run_k_fold_graph_data(config, sweep_id, hex_attr_config),
            count=SWEEP_RUNS_COUNT,
        )
    except Exception as e:
        print(e)
        wandb.finish()
        wandb.sweep
        raise e


def run_sweep_tabular_data(config, hex_attr_config=None):
    try:
        wandb.login(key=WANDB_API_KEY)

        sweep_id = wandb.sweep(
            WANDB_SWEEP_PARAMS_TABULAR_DATA, project="zabka-downstream-task"
        )

        wandb.agent(
            sweep_id,
            function=run_k_fold_tabular_data(config, sweep_id, hex_attr_config),
            count=SWEEP_RUNS_COUNT,
        )
    except Exception as e:
        print(e)
        wandb.finish()
        raise e

# 12. Run functions

For each config:

1. Determine if config requires tabular or graph data
2. Create data excluding attributes not included in the config
3. Run the sweep


In [17]:
from typing import Any, Dict
def derive_data_structure(attr_config):
    if attr_config["USE_OSMNX_ATTRS"]:
        return "graph"
    return "tabular"


configs_size = len(ATTRIBUTES_CONFIGURATIONS)

for index, attr_config in enumerate(ATTRIBUTES_CONFIGURATIONS):
    print("Sweep for config {}/{} in progress...".format(index + 1, configs_size))

    # assert "USE_ORTOPHOTO" in attr_config, "Provide USE_ORTOPHOTO key"
    # assert "USE_HEXES_ATTRS" in attr_config, "Provide USE_HEXES_ATTRS key"
    # assert "USE_OSMNX_ATTRS" in attr_config, "Provide USE_OSMNX_ATTRS key"

    data_structure = derive_data_structure(attr_config)
    creator_params: Dict[str, Any] = dict(
        use_hexes_attr=bool(attr_config["USE_HEXES_ATTRS"]),
    )
    if isinstance(attr_config["USE_HEXES_ATTRS"], dict):
        hex_fi_config = attr_config["USE_HEXES_ATTRS"]

        hex_features = pd.read_json(
            f"{HEX_FI_LOCATION}/zabka_shops_top_{hex_fi_config['NUM_FEATURES']}_percent_{hex_fi_config['IN_PERCENT']}.json"
        )
        hex_features = hex_features["top_values"].tolist()
        creator_params["columns_to_take"] = hex_features
    elif attr_config["USE_HEXES_ATTRS"] == True:
        creator_params["columns_to_take"] = "all"



    if data_structure == "graph":
        graph_data_dict = {
            city_name: create_graph_data(
                hexes=gdfs["hexes"],
                controller=cast(GraphLayerController, gdfs["controller"]),
                osmnx_edges=gdfs["osmnx_edges"],
                osmnx_nodes=gdfs["osmnx_nodes"],
                use_ortophoto=attr_config["USE_ORTOPHOTO"],
                **creator_params,
            )
            for city_name, gdfs in gdfs_dict.items()
        }
        run_sweep_graph_data(attr_config)
    elif data_structure == "tabular":
        tabular_data_dict = {
            city_name: create_tabular_data(
                hexes=gdfs["hexes"],
                controller=cast(GraphLayerController, gdfs["controller"]),
                use_ortophoto=attr_config["USE_ORTOPHOTO"],
                **creator_params,
            )
            for city_name, gdfs in gdfs_dict.items()
        }
        run_sweep_tabular_data(attr_config)
    else:
        raise ValueError("Unknown data structure")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


Sweep for config 1/4 in progress...
['shop_convenience', 'amenity_parking', 'building_apartments', 'amenity_parking_entrance', 'building_retail', 'shop_supermarket', 'amenity_parcel_locker', 'leisure_playground', 'amenity_pharmacy', 'landuse_grass', 'amenity_shelter', 'building_yes', 'amenity_dive_centre', 'sport_basketball', 'building_subway_entrance', 'historic_machinery', 'leisure_fitness_centre', 'shop_slot_machine', 'shop_spa', 'sport_ultimate']
['shop_convenience', 'amenity_parking', 'building_apartments', 'amenity_parking_entrance', 'building_retail', 'shop_supermarket', 'amenity_parcel_locker', 'leisure_playground', 'amenity_pharmacy', 'landuse_grass', 'amenity_shelter', 'building_yes', 'amenity_dive_centre', 'sport_basketball', 'building_subway_entrance', 'historic_machinery', 'leisure_fitness_centre', 'shop_slot_machine', 'shop_spa', 'sport_ultimate']
['shop_convenience', 'amenity_parking', 'building_apartments', 'amenity_parking_entrance', 'building_retail', 'shop_supermarke

[34m[1mwandb[0m: Currently logged in as: [33mgrymar[0m ([33mgradient_pwr[0m). Use [1m`wandb login --relogin`[0m to force relogin


Create sweep with ID: qd0aywme
Sweep URL: https://wandb.ai/gradient_pwr/zabka-downstream-task/sweeps/qd0aywme


[34m[1mwandb[0m: Agent Starting Run: ufmmwagq with config:
[34m[1mwandb[0m: 	C: 0.104672579590072
[34m[1mwandb[0m: 	solver_penalty: saga;l1
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.




VBox(children=(Label(value='3.001 MB of 3.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
USE_ORTOPHOTO,▁
USE_OSMNX_ATTRS,▁
accuracy_fold_0,▁
accuracy_fold_1,▁
accuracy_fold_2,▁
accuracy_fold_3,▁
accuracy_fold_4,▁
auc_fold_0,▁
auc_fold_1,▁
auc_fold_2,▁

0,1
USE_HEXES_ATTRS,"{'NUM_FEATURES': 20,..."
USE_ORTOPHOTO,0
USE_OSMNX_ATTRS,0
accuracy_fold_0,0.91513
accuracy_fold_1,0.92054
accuracy_fold_2,0.9635
accuracy_fold_3,0.87641
accuracy_fold_4,0.89867
auc_fold_0,0.96562
auc_fold_1,0.9586


[34m[1mwandb[0m: Agent Starting Run: 6ohc6bpf with config:
[34m[1mwandb[0m: 	C: 0.02878012211226249
[34m[1mwandb[0m: 	solver_penalty: newton-cg;l2
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


[34m[1mwandb[0m: Ctrl + C detected. Stopping sweep.


Sweep for config 2/4 in progress...
['shop_convenience', 'amenity_parking', 'building_apartments', 'amenity_parking_entrance', 'building_retail', 'shop_supermarket', 'amenity_parcel_locker', 'leisure_playground', 'amenity_pharmacy', 'landuse_grass', 'amenity_shelter', 'building_yes', 'amenity_dive_centre', 'sport_basketball', 'building_subway_entrance', 'historic_machinery', 'leisure_fitness_centre', 'shop_slot_machine', 'shop_spa', 'sport_ultimate', 'building_guard_cabin', 'landuse_gress', 'natural_shrubbery', 'office_bakery', 'sport_yes', 'landuse_gravel', 'sport_karate', 'building_pavilion', 'leisure_martial_arts', 'sport_indoor_skiing;climbing', 'sport_bodybuilding', 'amenity_social_club', 'office_university', 'office_engineer', 'building_recreation', 'office_nursing_service', 'shop_comic_books', 'sport_tennis;volleyball', 'sport_climbing;squash;table_tennis;dancing', 'office_software', 'amenity_fire_station', 'shop_fish', 'shop_frozen_food', 'tourism_camp_site', 'amenity_baby_hatc



{'Kraków': {'X':            shop_convenience  amenity_parking  building_apartments  \
region_id                                                           
0                         0                0                    0   
1                         0                3                    0   
2                         0                0                    0   
3                         2                3                   13   
4                         0                0                    0   
...                     ...              ...                  ...   
3412                      0                0                    0   
3413                      0                0                    0   
3414                      0                2                    0   
3415                      0                1                    0   
3416                      0                0                    0   

           amenity_parking_entrance  building_retail  shop_supermarket  \
region_id  

[34m[1mwandb[0m: [32m[41mERROR[0m Problem finishing run
Exception in thread Thread-105 (_run_job):
Traceback (most recent call last):
  File "/home/grymar/studia/gradient/env/lib/python3.10/site-packages/wandb/agents/pyagent.py", line 307, in _run_job
    self._function()
  File "/tmp/ipykernel_761225/3190395129.py", line 197, in wrapped
  File "/home/grymar/studia/gradient/env/lib/python3.10/site-packages/sklearn/utils/_set_output.py", line 157, in wrapped
    data_to_wrap = f(self, X, *args, **kwargs)
  File "/home/grymar/studia/gradient/env/lib/python3.10/site-packages/sklearn/preprocessing/_data.py", line 1006, in transform
    X = self._validate_data(
  File "/home/grymar/studia/gradient/env/lib/python3.10/site-packages/sklearn/base.py", line 626, in _validate_data
    self._check_n_features(X, reset=reset)
  File "/home/grymar/studia/gradient/env/lib/python3.10/site-packages/sklearn/base.py", line 415, in _check_n_features
    raise ValueError(
ValueError: X has 430 feature

Create sweep with ID: v462n6a4
Sweep URL: https://wandb.ai/gradient_pwr/zabka-downstream-task/sweeps/v462n6a4
