In [1]:
import sys
sys.path.append('../../') 
import geopandas as gpd
from src.organized_datasets_creation.utils import resolve_nominatim_city_name
from src.graph_layering.graph_layer_creator import GraphLayerController
import pandas as pd
import os
from src.graph_layering.graph_layer_creator import SourceType
import warnings
from src.graph_layering.create_hetero_data import create_hetero_data
from sklearn.preprocessing import OneHotEncoder

from tqdm import tqdm

import wandb.util
import wandb
import os


import numpy as np
from src.graph.create_osmnx_graph import OSMnxGraph
import json
from shapely.geometry import Point
from joblib import dump


from datetime import datetime
from sklearn.metrics import f1_score, roc_auc_score
from wandb.util import generate_id
from sklearn.linear_model import LogisticRegression
from src.training.train import train
from sklearn.preprocessing import StandardScaler

%load_ext autoreload
%autoreload 2

In [2]:
WANDB_API_KEY = os.environ.get("WANDB_API_KEY", None)
assert (
    WANDB_API_KEY is not None
), "WANDB_API_KEY is not set, did you forget it in the config file?"

In [3]:
# general settings
ORGANIZED_HEXES_LOCATION = "/home/grymar/studia/gradient/data/organized-hexes"
ORGANIZED_GRAPHS_LOCATION = "/home/grymar/studia/gradient/data/organized_graphs"
OSMNX_ALL_ATTRIBUTES_LOCATION = (
    "/home/grymar/studia/gradient/data/osmnx_attributes.json"
)

# downstream task settings
AIRBNB_LOCATION = "/home/grymar/studia/gradient/data/downstream_tasks/airbnb"
TRAIN_SAVE_DIR = "/tmp"

SWEEP_RUNS_COUNT = 50
EPOCHS = 10

ATTRIBUTES_CONFIGURATIONS = [
    {
        "USE_ORTOPHOTO": False,
        "USE_HEXES_ATTRS": True,
        "USE_OSMNX_ATTRS": True,
    },
    {
        "USE_ORTOPHOTO": False,
        "USE_HEXES_ATTRS": True,
        "USE_OSMNX_ATTRS": True,
    },
    {
        "USE_ORTOPHOTO": False,
        "USE_HEXES_ATTRS": False,
        "USE_OSMNX_ATTRS": True,
    },
]

WANDB_SWEEP_PARAMS_GRAPH_DATA = {
    "method": "bayes",
    "metric": {"name": "mean_f1", "goal": "maximize"},
    "parameters": {
        "hidden_channels": {"values": [10, 20, 30, 40, 50]},
        "learning_rate": {
            "distribution": "log_uniform_values",
            "min": 1e-5,
            "max": 1e-2,
        },
        "num_conv_layers": {"values": [1, 2, 3, 4, 5]},
        "lin_layer_size": {"values": [8, 16, 32, 64, 128]},
        "num_lin_layers": {"values": [0, 1, 2, 3, 4]},
        "weight_decay": {
            "distribution": "log_uniform_values",
            "min": 1e-5,
            "max": 1e-2,
        },
    },
}

WANDB_SWEEP_PARAMS_TABULAR_DATA = {
    "method": "bayes",
    "metric": {"name": "mean_f1", "goal": "maximize"},
    "parameters": {
        "solver_penalty": {
            "values": [
                "lbfgs;l2",
                "liblinear;l1",
                "liblinear;l2",
                "newton-cg;l2",
                "newton-cholesky;l2",
                "sag;l2",
                "saga;elasticnet",
                "saga;l1",
                "saga;l2",
            ]
        },
        "C": {
            "distribution": "log_uniform_values",
            "min": 1e-5,
            "max": 1,
        },
    },
}

# Load airbnb data

In [4]:
airbnb_ny = pd.read_csv(f"{AIRBNB_LOCATION}/new_york_airbnb_in_hexes_res_9.csv")
airbnb_ny['mie_nazwa'] = "new_york"
airbnb_ny.reset_index(drop=True, inplace=True)
airbnb_ny

Unnamed: 0,region_id,price,price_class,mie_nazwa
0,892a1000003ffff,,no airbnb,new_york
1,892a1000007ffff,,no airbnb,new_york
2,892a100000bffff,,no airbnb,new_york
3,892a100000fffff,,no airbnb,new_york
4,892a1000013ffff,,no airbnb,new_york
...,...,...,...,...
11844,892a10776cbffff,,no airbnb,new_york
11845,892a10776cfffff,,no airbnb,new_york
11846,892a10776d3ffff,146.0,high,new_york
11847,892a10776d7ffff,163.4,high,new_york


In [5]:
airbnb_ny['price_class'].unique()

array(['no airbnb', 'very high', 'medium', 'low', 'high'], dtype=object)

In [6]:
airbnb_st = pd.read_csv(f"{AIRBNB_LOCATION}/seattle_airbnb_in_hexes_res_9.csv")
airbnb_st['mie_nazwa'] ="seattle"
airbnb_st.reset_index(drop=True, inplace=True)
airbnb_st

Unnamed: 0,region_id,price,price_class,mie_nazwa
0,8928d540003ffff,,no airbnb,seattle
1,8928d540007ffff,,no airbnb,seattle
2,8928d54000bffff,,no airbnb,seattle
3,8928d54000fffff,,no airbnb,seattle
4,8928d540013ffff,,no airbnb,seattle
...,...,...,...,...
4198,8928d5cdbcbffff,,no airbnb,seattle
4199,8928d5cdbcfffff,,no airbnb,seattle
4200,8928d5cdbd3ffff,,no airbnb,seattle
4201,8928d5cdbd7ffff,,no airbnb,seattle


In [7]:
airbnb = pd.concat([airbnb_ny, airbnb_st])
airbnb = airbnb.drop("price", axis=1)
mapping = {'no airbnb': 0, 'very high': 1, 'medium': 2, 'low': 3, 'high': 4}
airbnb['price_class'] = airbnb['price_class'].map(mapping)
airbnb

Unnamed: 0,region_id,price_class,mie_nazwa
0,892a1000003ffff,0,new_york
1,892a1000007ffff,0,new_york
2,892a100000bffff,0,new_york
3,892a100000fffff,0,new_york
4,892a1000013ffff,0,new_york
...,...,...,...
4198,8928d5cdbcbffff,0,seattle
4199,8928d5cdbcfffff,0,seattle
4200,8928d5cdbd3ffff,0,seattle
4201,8928d5cdbd7ffff,0,seattle


In [8]:
def add_airbnb_to_osmnx_nodes(
    accidents: gpd.GeoDataFrame,
    nodes: gpd.GeoDataFrame,
    edges: gpd.GeoDataFrame,
    city_name: str,
):
    with open(OSMNX_ALL_ATTRIBUTES_LOCATION) as f:
        all_attributes = json.load(f)

    osmnx_graph = OSMnxGraph(
        accidents.loc[
            accidents["mie_nazwa"] == resolve_nominatim_city_name(city_name), :
        ],
        nodes,
        edges,
        all_attributes,
    )
    osmnx_graph._aggregate(element_type="node", aggregation_method="count")
    return osmnx_graph.gdf_nodes


def create_gdfs(city_name: str):
    osmnx_nodes = gpd.read_parquet(
        os.path.join(ORGANIZED_GRAPHS_LOCATION, city_name, "nodes.parquet")
    )
    osmnx_nodes = osmnx_nodes.reset_index()
    osmnx_nodes.index.names = ["node_id"]
    osmnx_nodes["x"] = osmnx_nodes["geometry"].x
    osmnx_nodes["y"] = osmnx_nodes["geometry"].y

    osmnx_edges = gpd.read_parquet(
        os.path.join(ORGANIZED_GRAPHS_LOCATION, city_name, "edges.parquet")
    )
    osmnx_edges = osmnx_edges.reset_index().rename(columns={"index": "edge_id"})
    osmnx_edges.index.names = ["edge_id"]
    if "level_4" in osmnx_edges.columns:
        osmnx_edges = osmnx_edges.drop("level_4", axis=1)
    osmnx_edges = osmnx_edges.fillna(0)
    display(osmnx_edges)
    

    assert osmnx_nodes.crs == osmnx_edges.crs

    # osmnx_nodes = add_airbnb_to_osmnx_nodes(
    #     accidents=accidents_gdf,
    #     nodes=osmnx_nodes,
    #     city_name=city_name,
    #     edges=osmnx_edges,
    # )

    hexes_years_folder = os.path.join(ORGANIZED_HEXES_LOCATION, city_name)

    subfolders = [
        int(f)
        for f in os.listdir(hexes_years_folder)
        if os.path.isdir(os.path.join(hexes_years_folder, f))
    ]
    highest_year = subfolders[np.argmax(subfolders)]

    hexes: gpd.GeoDataFrame = gpd.read_parquet(
        os.path.join(
            ORGANIZED_HEXES_LOCATION,
            f"{city_name}/{highest_year}/h9/count-embedder/dataset.parquet",
        )
    )

    airbnb_city = airbnb.loc[
            airbnb["mie_nazwa"] == city_name, :
        ]
    airbnb_city = airbnb_city.drop("mie_nazwa", axis=1)
    display(airbnb_city)

    hexes = hexes.merge(airbnb_city, on="region_id", how='left').fillna(0)

    hexes = hexes.rename(columns={"region_id": "h3_id"}).rename_axis(
        "region_id", axis=0
    )
    display(hexes)

    return dict(osmnx_nodes=osmnx_nodes, osmnx_edges=osmnx_edges, hexes=hexes)


print("Creating gdfs...")
gdfs_dict = {city_name: create_gdfs(city_name) for city_name in tqdm(['new_york', 'seattle'])}

Creating gdfs...


  0%|          | 0/2 [00:00<?, ?it/s]

Unnamed: 0_level_0,u,v,key,level_3,geometry,oneway,lanes,maxspeed,reversed,length,...,junction_circular,junction_jughandle,junction_roundabout,bridge_low_water_crossing,bridge_viaduct,bridge_movable,bridge_yes,tunnel_building_passage,tunnel_no,tunnel_yes
edge_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0,29445653,4433425415,0,0,"LINESTRING (-122.31969 47.64248, -122.31943 47...",1,2,97,0,20.39,...,0,0,0,0,0,0,0,0,0,0
1,29445655,4433425410,0,0,"LINESTRING (-122.32113 47.64217, -122.32098 47...",1,2,97,0,12.41,...,0,0,0,0,0,0,0,0,0,0
2,29445656,4433425404,0,0,"LINESTRING (-122.32184 47.64178, -122.32171 47...",1,2,97,0,14.23,...,0,0,0,0,0,0,0,0,0,0
3,29445657,1864943659,0,0,"LINESTRING (-122.32204 47.64160, -122.32193 47...",1,2,97,0,14.65,...,0,0,0,0,0,0,0,0,0,0
4,29445659,4433425403,0,0,"LINESTRING (-122.32219 47.64140, -122.32212 47...",1,2,97,0,12.50,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
216717,11932454288,53131439,0,0,"LINESTRING (-122.29712 47.57187, -122.29701 47...",0,2,36,0,9.70,...,0,0,0,0,0,0,0,0,0,0
216718,11932454288,53098734,0,0,"LINESTRING (-122.29712 47.57187, -122.29778 47...",0,2,36,1,53.94,...,0,0,0,0,0,0,0,0,0,0
216719,11933647008,9426316116,0,0,"LINESTRING (-122.33324 47.70133, -122.33330 47...",0,2,40,0,5.31,...,0,0,0,0,0,0,0,0,0,0
216720,11933647008,410778999,0,0,"LINESTRING (-122.33324 47.70133, -122.33318 47...",0,2,40,1,4.09,...,0,0,0,0,0,0,0,0,0,0


Unnamed: 0,region_id,price_class
0,892a1000003ffff,0
1,892a1000007ffff,0
2,892a100000bffff,0
3,892a100000fffff,0
4,892a1000013ffff,0
...,...,...
11844,892a10776cbffff,0
11845,892a10776cfffff,0
11846,892a10776d3ffff,4
11847,892a10776d7ffff,4


Unnamed: 0_level_0,h3_id,geometry,amenity_gym,building_garages,landuse_gravel,sport_ultimate,office_bakery,natural_shrubbery,landuse_gress,building_guard_cabin,...,historic_heritage,building_government,aeroway_navigationaid,historic_park,historic_train_station,shop_hobby,building_floating_home,amenity_vacuum_cleaner,building_castle,price_class
region_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0,892a1000003ffff,"POLYGON ((-73.78199 40.86018, -73.78410 40.859...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
1,892a1000007ffff,"POLYGON ((-73.77975 40.85759, -73.78187 40.856...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,892a100000bffff,"POLYGON ((-73.78628 40.86007, -73.78839 40.859...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3,892a100000fffff,"POLYGON ((-73.78404 40.85748, -73.78615 40.856...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,892a1000013ffff,"POLYGON ((-73.77994 40.86288, -73.78205 40.861...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
10752,892a10776cbffff,"POLYGON ((-74.02133 40.67672, -74.02343 40.675...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
10753,892a10776cfffff,"POLYGON ((-74.01909 40.67414, -74.02119 40.673...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
10754,892a10776d3ffff,"POLYGON ((-74.01503 40.67954, -74.01713 40.678...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,4
10755,892a10776d7ffff,"POLYGON ((-74.01279 40.67696, -74.01489 40.676...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,4


 50%|█████     | 1/2 [00:01<00:01,  1.41s/it]

Unnamed: 0_level_0,u,v,key,level_3,geometry,oneway,lanes,maxspeed,reversed,length,...,junction_circular,junction_jughandle,junction_roundabout,bridge_low_water_crossing,bridge_viaduct,bridge_movable,bridge_yes,tunnel_building_passage,tunnel_no,tunnel_yes
edge_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0,29445653,4433425415,0,0,"LINESTRING (-122.31969 47.64248, -122.31943 47...",1,2,97,0,20.39,...,0,0,0,0,0,0,0,0,0,0
1,29445655,4433425410,0,0,"LINESTRING (-122.32113 47.64217, -122.32098 47...",1,2,97,0,12.41,...,0,0,0,0,0,0,0,0,0,0
2,29445656,4433425404,0,0,"LINESTRING (-122.32184 47.64178, -122.32171 47...",1,2,97,0,14.23,...,0,0,0,0,0,0,0,0,0,0
3,29445657,1864943659,0,0,"LINESTRING (-122.32204 47.64160, -122.32193 47...",1,2,97,0,14.65,...,0,0,0,0,0,0,0,0,0,0
4,29445659,4433425403,0,0,"LINESTRING (-122.32219 47.64140, -122.32212 47...",1,2,97,0,12.50,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
216717,11932454288,53131439,0,0,"LINESTRING (-122.29712 47.57187, -122.29701 47...",0,2,36,0,9.70,...,0,0,0,0,0,0,0,0,0,0
216718,11932454288,53098734,0,0,"LINESTRING (-122.29712 47.57187, -122.29778 47...",0,2,36,1,53.94,...,0,0,0,0,0,0,0,0,0,0
216719,11933647008,9426316116,0,0,"LINESTRING (-122.33324 47.70133, -122.33330 47...",0,2,40,0,5.31,...,0,0,0,0,0,0,0,0,0,0
216720,11933647008,410778999,0,0,"LINESTRING (-122.33324 47.70133, -122.33318 47...",0,2,40,1,4.09,...,0,0,0,0,0,0,0,0,0,0


Unnamed: 0,region_id,price_class
0,8928d540003ffff,0
1,8928d540007ffff,0
2,8928d54000bffff,0
3,8928d54000fffff,0
4,8928d540013ffff,0
...,...,...
4198,8928d5cdbcbffff,0
4199,8928d5cdbcfffff,0
4200,8928d5cdbd3ffff,0
4201,8928d5cdbd7ffff,0


Unnamed: 0_level_0,h3_id,geometry,amenity_gym,building_garages,landuse_gravel,sport_ultimate,office_bakery,natural_shrubbery,landuse_gress,building_guard_cabin,...,historic_heritage,building_government,aeroway_navigationaid,historic_park,historic_train_station,shop_hobby,building_floating_home,amenity_vacuum_cleaner,building_castle,price_class
region_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0,8928d540003ffff,"POLYGON ((-122.25110 47.65043, -122.24991 47.6...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
1,8928d540007ffff,"POLYGON ((-122.25509 47.65151, -122.25390 47.6...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,8928d54000bffff,"POLYGON ((-122.24732 47.65219, -122.24612 47.6...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3,8928d54000fffff,"POLYGON ((-122.25131 47.65328, -122.25011 47.6...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,8928d540013ffff,"POLYGON ((-122.25090 47.64757, -122.24970 47.6...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4198,8928d5cdbcbffff,"POLYGON ((-122.39215 47.49291, -122.39096 47.4...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4199,8928d5cdbcfffff,"POLYGON ((-122.39613 47.49399, -122.39494 47.4...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4200,8928d5cdbd3ffff,"POLYGON ((-122.39572 47.48827, -122.39453 47.4...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4201,8928d5cdbd7ffff,"POLYGON ((-122.39970 47.48936, -122.39851 47.4...",0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


100%|██████████| 2/2 [00:02<00:00,  1.05s/it]


In [9]:
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    for gdf_for_city in gdfs_dict.values():
        gdf_for_city["controller"] = GraphLayerController(
            gdf_for_city["hexes"],
            gdf_for_city["osmnx_nodes"],
            gdf_for_city["osmnx_edges"],
        )

In [10]:
def patch_hexes_with_y(
    hexes: gpd.GeoDataFrame,
    controller: GraphLayerController,
):
    virtual_edges = controller.get_virtual_edges_to_hexes(SourceType.OSMNX_NODES)
    controller.hexes_gdf = hexes
    controller._hexes_centroids_gdf = controller._create_hexes_centroids_gdf()

In [11]:
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    for gdfs in gdfs_dict.values():
        patch_hexes_with_y(gdfs["hexes"], gdfs["controller"])

In [12]:
def create_graph_data(
    osmnx_nodes,
    osmnx_edges,
    hexes,
    controller: GraphLayerController,
    use_hexes_attr: bool,
    use_ortophoto: bool,
):

    edges_attr_columns = osmnx_edges.columns[
        ~osmnx_edges.columns.isin(["u", "v", "key", "geometry"])
    ]
    nodes_attr_columns = osmnx_nodes.columns[
        ~osmnx_nodes.columns.isin(["geometry", "x", "y", "osmid"])
    ]
    hexes_attr_columns = (
        hexes.columns[~hexes.columns.isin(["geometry", "h3_id", "price_class"])]
        if use_hexes_attr
        else []
    )

    data = create_hetero_data(
        controller,
        hexes_attrs_columns_names=hexes_attr_columns,
        osmnx_edge_attrs_columns_names=edges_attr_columns,
        osmnx_node_attrs_columns_names=nodes_attr_columns,
        virtual_edge_attrs_columns_names=[],
        hexes_y_columns_names=["price_class"],
    )
    return data


graph_data_dict = {
    city_name: create_graph_data(**gdfs, use_ortophoto=False, use_hexes_attr=True)
    for city_name, gdfs in gdfs_dict.items()
}

0        892a1000003ffff
1        892a1000003ffff
2        892a1000003ffff
3        892a1000003ffff
4        892a1000003ffff
              ...       
62978    892a10776dbffff
62979    892a10776dbffff
62980    892a10776dbffff
62981    892a10776dbffff
62982    892a10776dbffff
Name: u, Length: 62983, dtype: object
<class 'str'>
XD
0        892a1000007ffff
1        892a100000fffff
2        892a1000017ffff
3        892a100001bffff
4        892a1000013ffff
              ...       
62978    892a10776cbffff
62979    892a1072bafffff
62980    892a1072ba7ffff
62981    892a10776d3ffff
62982    892a1072b37ffff
Name: v, Length: 62983, dtype: object
<class 'str'>
0
0        8928d540003ffff
1        8928d540003ffff
2        8928d540003ffff
3        8928d540003ffff
4        8928d540003ffff
              ...       
24563    8928d5cdbdbffff
24564    8928d5cdbdbffff
24565    8928d5cdbdbffff
24566    8928d5cdbdbffff
24567    8928d5cdbdbffff
Name: u, Length: 24568, dtype: object
<class 'str'>
XD
0        89

In [13]:
def create_tabular_data(
    hexes: pd.DataFrame,
    controller: GraphLayerController,
    use_hexes_attr: bool,
    use_ortophoto: bool,
):
    assert use_ortophoto or use_hexes_attr, "Provide at least one data source"

    hexes_attr_columns = (
        hexes.columns[~hexes.columns.isin(["geometry", "h3_id"])]
        if use_hexes_attr
        else []
    )

    hexes_y_columns_names = ["price_class"]

    X = hexes[hexes_attr_columns]
    y = controller.hexes_centroids_gdf[hexes_y_columns_names]

    return {"X": X, "y": y}


tabular_data_dict = {
    city_name: create_tabular_data(
        gdfs["hexes"], gdfs["controller"], use_ortophoto=False, use_hexes_attr=True
    )
    for city_name, gdfs in gdfs_dict.items()
}

In [14]:
def shift_elements_right(lst):
    shifted_lst = [lst[-1]] + lst[:-1]
    return shifted_lst


cities_names_list = list(graph_data_dict.keys())
cities_names_list.sort(key=lambda x: str(x))

# val + test
folds_tuples = list(zip(shift_elements_right(cities_names_list), cities_names_list))
display(folds_tuples)

[('seattle', 'new_york'), ('new_york', 'seattle')]

In [15]:
def run_k_fold_graph_data(closure_config, sweep_id):
    # pass external config (i.e. what attributes are used in the data), closure to avoid passing it to the function directly
    def wrapped():
        run = wandb.init()
        epochs = EPOCHS

        config = wandb.config

        for k, v in closure_config.items():
            run.log({k: 1 if v else 0})

        run.log({"data_structure": "graph"})

        # create hparams
        if hasattr(config, "lin_layer_size") and hasattr(config, "num_lin_layers"):
            lin_layer_sizes = [config.lin_layer_size] * config.num_lin_layers
        else:
            lin_layer_sizes = config.lin_layer_sizes
        hparams = {
            "hidden_channels": config.hidden_channels,
            "lr": config.learning_rate,
            "num_conv_layers": config.num_conv_layers,
            "lin_layer_sizes": lin_layer_sizes,
            "weight_decay": config.weight_decay,
        }

        aucs = []
        accuracies = []
        f1s = []

        fold_group_id = generate_id()

        # log data as artifact if no data was logged in the sweep before
        # dataset is uploaded only on the first run in sweep, because it does not change across runs in sweep
        # in wandb, dataset will be visible on the first run in the sweep
        artifact_path = os.path.join(TRAIN_SAVE_DIR, f"graph_data_{sweep_id}.pkl")
        if not os.path.exists(artifact_path):
            dump(
                graph_data_dict,
                artifact_path,
                protocol=5,
            )
            artifact = wandb.Artifact(
                name="graph_data", type="dataset", metadata=closure_config
            )
            artifact.add_file(local_path=artifact_path)
            run.log_artifact(artifact)

        # run k-fold
        for index, (val_city_name, test_city_name) in enumerate(folds_tuples):
            # prepare data
            val_data = [graph_data_dict[val_city_name].to("cpu").clone()]
            train_data = [
                v.to("cpu").clone()
                for k, v in graph_data_dict.items()
                if k != val_city_name and k != test_city_name
            ]
            test_data = graph_data_dict[test_city_name].to("cpu").clone()

            # run training with checkpointing on lowest val_loss, return test metrics for the best model and its path
            # builtin preprocessing - scaling to N(0, 1)
            auc, accuracy, f1, model_path = train(
                train_data=train_data,
                val_data=val_data,
                test_data=test_data,
                epochs=epochs,
                hparams=hparams,
                train_save_dir=TRAIN_SAVE_DIR,
            )

            # logging - single fold
            run.log_model(
                path=model_path,
                name=f"model_{fold_group_id}_fold_{index}",
            )
            run.log({f"auc_fold_{index}": auc})
            run.log({f"accuracy_fold_{index}": accuracy})
            run.log({f"f1_fold_{index}": f1})

            aucs.append(auc)
            accuracies.append(accuracy)
            f1s.append(f1)

        # logging - summary statistics
        mean_auc = sum(aucs) / len(aucs)
        mean_accuracy = sum(accuracies) / len(accuracies)
        mean_f1 = sum(f1s) / len(f1s)
        run.log({"mean_auc": mean_auc})
        run.log({"mean_accuracy": mean_accuracy})
        run.log({"mean_f1": mean_f1})

    return wrapped


def run_k_fold_tabular_data(closure_config, sweep_id):
    # analogously to the graph data, but for tabular data
    def wrapped():
        run = wandb.init()

        config = wandb.config

        for k, v in closure_config.items():
            run.log({k: 1 if v else 0})

        run.log({"data_structure": "tabular"})

        hparams = {}
        hparams["C"] = config["C"]
        solver, penalty = config["solver_penalty"].split(";")
        hparams["solver"] = solver
        if penalty == "None":
            penalty = None
        hparams["penalty"] = penalty

        aucs = []
        accuracies = []
        f1s = []

        fold_group_id = generate_id()

        # log data as artifact
        artifact_path = os.path.join(TRAIN_SAVE_DIR, f"tabular_data_{sweep_id}.pkl")

        if not os.path.exists(artifact_path):
            dump(
                tabular_data_dict,
                artifact_path,
                protocol=5,
            )
            artifact = wandb.Artifact(
                name="tabular_data", type="dataset", metadata=closure_config
            )
            artifact.add_file(local_path=artifact_path)
            run.log_artifact(artifact)

        timestamp = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")

        for index, test_city_name in enumerate(cities_names_list):
            scaler = StandardScaler()
            X = pd.concat(
                [
                    m["X"]
                    for key, m in tabular_data_dict.items()
                    if key != test_city_name
                ]
            ).to_numpy()
            y = (
                pd.concat(
                    [
                        m["y"]
                        for key, m in tabular_data_dict.items()
                        if key != test_city_name
                    ]
                )
                .to_numpy()
                .ravel()
            )
            

            X = scaler.fit_transform(X)
            

            logistic_regression = LogisticRegression(
                C=hparams["C"],
                solver=hparams["solver"],
                penalty=hparams["penalty"],
                dual=False,
                tol=1e-4,
                fit_intercept=True,
                intercept_scaling=1,
                class_weight="balanced",
                random_state=1124,
                max_iter=1000,
                multi_class="auto",
                warm_start=False,
                n_jobs=-1,
                l1_ratio=0.5,
            )
            logistic_regression.fit(X, y)

            test_X = tabular_data_dict[test_city_name]["X"].to_numpy()
            test_X = scaler.transform(test_X)
            test_y = tabular_data_dict[test_city_name]["y"].to_numpy().ravel()

            encoder = OneHotEncoder(sparse=False)
            test_y_ohe = encoder.fit_transform(test_y.reshape(-1, 1))

            y_pred = logistic_regression.predict(test_X)
            y_proba = logistic_regression.predict_proba(test_X)


            auc = roc_auc_score(test_y_ohe, y_proba, average="weighted", multi_class='ovr')
            accuracy = (y_pred == test_y).mean()
            f1 = f1_score(
                test_y,
                y_pred,
                average="weighted",
            )

            model_dir = os.path.join(TRAIN_SAVE_DIR, timestamp)

            os.makedirs(model_dir, exist_ok=True)

            model_path = os.path.join(
                model_dir, f"model_{fold_group_id}_fold_{index}.pkl"
            )

            with open(model_path, "wb") as f:
                dump(logistic_regression, f, protocol=5)

            run.log_model(
                path=model_path,
                name=f"model_{fold_group_id}_fold_{index}",
            )
            run.log({f"auc_fold_{index}": auc})
            run.log({f"accuracy_fold_{index}": accuracy})
            run.log({f"f1_fold_{index}": f1})

            aucs.append(auc)
            accuracies.append(accuracy)
            f1s.append(f1)

        mean_auc = sum(aucs) / len(aucs)
        mean_accuracy = sum(accuracies) / len(accuracies)
        mean_f1 = sum(f1s) / len(f1s)
        run.log({"mean_auc": mean_auc})
        run.log({"mean_accuracy": mean_accuracy})
        run.log({"mean_f1": mean_f1})

    return wrapped


def run_sweep_graph_data(config):
    try:
        wandb.login(key=WANDB_API_KEY)
        sweep_id = wandb.sweep(
            WANDB_SWEEP_PARAMS_GRAPH_DATA, project="airbnb-downstream-task"
        )

        wandb.agent(
            sweep_id,
            function=run_k_fold_graph_data(config, sweep_id),
            count=SWEEP_RUNS_COUNT,
        )
    except Exception as e:
        print(e)
        wandb.finish()
        wandb.sweep
        raise e


def run_sweep_tabular_data(config):
    try:
        wandb.login(key=WANDB_API_KEY)

        sweep_id = wandb.sweep(
            WANDB_SWEEP_PARAMS_TABULAR_DATA, project="airbnb-downstream-task"
        )

        wandb.agent(
            sweep_id,
            function=run_k_fold_tabular_data(config, sweep_id),
            count=SWEEP_RUNS_COUNT,
        )
    except Exception as e:
        print(e)
        wandb.finish()
        raise e

In [16]:
def derive_data_structure(attr_config):
    if attr_config["USE_OSMNX_ATTRS"]:
        return "graph"
    return "tabular"


configs_size = len(ATTRIBUTES_CONFIGURATIONS)

for index, attr_config in enumerate(ATTRIBUTES_CONFIGURATIONS):
    print("Sweep for config {}/{} in progress...".format(index + 1, configs_size))

    assert "USE_ORTOPHOTO" in attr_config, "Provide USE_ORTOPHOTO key"
    assert "USE_HEXES_ATTRS" in attr_config, "Provide USE_HEXES_ATTRS key"
    assert "USE_OSMNX_ATTRS" in attr_config, "Provide USE_OSMNX_ATTRS key"

    data_structure = derive_data_structure(attr_config)

    if data_structure == "graph":
        graph_data_dict = {
            city_name: create_graph_data(
                **gdfs,
                use_ortophoto=attr_config["USE_ORTOPHOTO"],
                use_hexes_attr=attr_config["USE_HEXES_ATTRS"],
            )
            for city_name, gdfs in gdfs_dict.items()
        }
        run_sweep_graph_data(attr_config)
    elif data_structure == "tabular":
        tabular_data_dict = {
            city_name: create_tabular_data(
                gdfs["hexes"],
                gdfs["controller"],
                use_ortophoto=attr_config["USE_ORTOPHOTO"],
                use_hexes_attr=attr_config["USE_HEXES_ATTRS"],
            )
            for city_name, gdfs in gdfs_dict.items()
        }
        run_sweep_tabular_data(attr_config)
    else:
        raise ValueError("Unknown data structure")

Sweep for config 1/3 in progress...


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mgrymar[0m ([33mgraph-ai[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/grymar/.netrc


Create sweep with ID: 3z59tj4p
Sweep URL: https://wandb.ai/graph-ai/airbnb-downstream-task/sweeps/3z59tj4p


[34m[1mwandb[0m: Agent Starting Run: d1li4uh0 with config:
[34m[1mwandb[0m: 	C: 1.4801162163295116e-05
[34m[1mwandb[0m: 	solver_penalty: sag;l2
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.




VBox(children=(Label(value='246.020 MB of 246.020 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
USE_HEXES_ATTRS,▁
USE_ORTOPHOTO,▁
USE_OSMNX_ATTRS,▁
accuracy_fold_0,▁
accuracy_fold_1,▁
auc_fold_0,▁
auc_fold_1,▁
f1_fold_0,▁
f1_fold_1,▁
mean_accuracy,▁

0,1
USE_HEXES_ATTRS,1
USE_ORTOPHOTO,0
USE_OSMNX_ATTRS,0
accuracy_fold_0,0.66338
accuracy_fold_1,0.46729
auc_fold_0,0.85686
auc_fold_1,0.73666
data_structure,tabular
f1_fold_0,0.6442
f1_fold_1,0.51168


[34m[1mwandb[0m: Agent Starting Run: tv7k2ekh with config:
[34m[1mwandb[0m: 	C: 0.0977255799528218
[34m[1mwandb[0m: 	solver_penalty: saga;l2
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


