In [30]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [31]:
import h3
import pandas as pd
import geopandas as gpd
import numpy as np
import folium
from src.settings import *
from shapely.geometry import Polygon
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from src.tools.osmnx_utils import get_place_dir_name
from src.tools.h3_utils import get_resolution_buffered_suffix
from pathlib import Path
import plotly.express as px
from src.tools.clustering import cluster_hdbscan
from src.models.tfidf import tfidf
from src.tools.dim_reduction import reduce_umap
import matplotlib.pyplot as plt
import contextily as ctx
from keplergl import KeplerGl
from src.tools.aggregation import aggregate_hex
import torch
from torch import nn
import torch.nn.functional as F
import pytorch_lightning as pl
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
from pytorch_lightning.loggers import WandbLogger
import wandb
from src.models.autoencoder import LitAutoEncoder
import random

tqdm.pandas()

In [32]:
SEED = 42

torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)

# def seed_worker(worker_id):
#     worker_seed = torch.initial_seed() % 2**32
#     np.random.seed(worker_seed)
#     random.seed(worker_seed)

In [33]:
cities = pd.read_csv(RAW_DATA_DIR.joinpath("cities.csv"))
cities = cities[cities["country"] == "Poland"]
cities: pd.DataFrame = cities[cities["city"] != "Świdnica"].reset_index(drop=True)
cities

Unnamed: 0,city,country
0,Wrocław,Poland
1,Warszawa,Poland
2,Kraków,Poland
3,Poznań,Poland
4,Gdańsk,Poland
5,Szczecin,Poland
6,Katowice,Poland
7,Częstochowa,Poland
8,Białystok,Poland


In [34]:
resolution = 9
buffered = True
network_type = "drive"

pbar = tqdm(cities.itertuples(), total=cities.shape[0])
edges_hex_cities = []
hexagons_cities = []
for row in pbar:
    place_name = f"{row.city},{row.country}"
    place_dir_name = get_place_dir_name(place_name)
    pbar.set_description(place_name)

    hexagon_city = gpd.read_file(GENERATED_DATA_DIR.joinpath(place_dir_name, f"hex_{get_resolution_buffered_suffix(resolution, buffered)}.geojson"), driver="GeoJSON")
    hexagon_city["city"] = row.city
    hexagons_cities.append(hexagon_city)

    edges_hex_city = gpd.read_file(GENERATED_DATA_DIR.joinpath(place_dir_name, f"edges_{network_type}_{get_resolution_buffered_suffix(resolution, buffered)}.geojson"), driver="GeoJSON")
    edges_hex_city["city"] = row.city
    edges_hex_cities.append(edges_hex_city)

hexagons = pd.concat(hexagons_cities, ignore_index=True).set_index("h3_id")
edges_hex = pd.concat(edges_hex_cities, ignore_index=True)
hex_agg = aggregate_hex(edges_hex.drop(columns="city"))
hex_agg_tfidf = tfidf(hex_agg)

del edges_hex_cities
del hexagons_cities

Białystok,Poland: 100%|██████████| 9/9 [01:34<00:00, 10.45s/it]


In [35]:
X = torch.Tensor(hex_agg.values)
X_tfidf = torch.Tensor(hex_agg_tfidf.values)

batch_size = 200
num_workers = 5

X_dl = DataLoader(X, batch_size=batch_size, num_workers=num_workers)
X_tfidf_dl = DataLoader(X_tfidf, batch_size=batch_size, num_workers=num_workers)

n = X.shape[1]

In [36]:
wandb_logger = WandbLogger()
run = wandb.init(project="osm-autoencoder", reinit=True)

model = LitAutoEncoder(in_dim=n, hidden_dim=64, code_dim=10)
trainer = pl.Trainer(gpus=1, max_epochs=40, logger=wandb_logger)
trainer.fit(model, X_tfidf_dl)

y = model(X)
y_df = pd.DataFrame(y.detach().numpy()).add_prefix("x_")
y_df.index = hex_agg.index

run.finish()

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type       | Params
---------------------------------------
0 | encoder | Sequential | 9.1 K 
1 | decoder | Sequential | 9.2 K 
---------------------------------------
18.3 K    Trainable params
0         Non-trainable params
18.3 K    Total params
0.073     Total estimated model params size (MB)


Epoch 39: 100%|██████████| 81/81 [00:04<00:00, 17.25it/s, loss=0.000606, v_num=krbe]


In [37]:
# from sklearn.preprocessing import normalize
# y_df = pd.DataFrame(normalize(y_df, norm="l2"), columns=y_df.columns, index=y_df.index)  # cosine metric
# y_df["cluster"] = cluster_hdbscan(y_df, min_cluster_size=50, metric="euclidean")[0]
# hexagons_clustered = hexagons.join(y_df).dropna()

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
train_loss_step,0.00012
epoch,39.0
trainer/global_step,3239.0
_runtime,169.0
_timestamp,1620951782.0
_step,103.0
train_loss_epoch,0.00062


0,1
train_loss_step,█▅▃▃▂▄▂▂▃▂▂▂▂▃▂▂▂▁▂▂▁▂▁▂▁▂▂▁▂▁▁▁▁▂▂▂▂▁▁▁
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_runtime,▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_timestamp,▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train_loss_epoch,█▅▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁


In [56]:
from sklearn.cluster import AgglomerativeClustering

agglomerative_clustering = AgglomerativeClustering(n_clusters=10, affinity="euclidean", linkage="ward")
y_df["cluster"] = pd.Series(agglomerative_clustering.fit_predict(y_df), index=y_df.index).astype("category")
hexagons_clustered = hexagons.join(y_df).dropna()

In [67]:
from src.tools.vis_utils import plot_hexagons_map

city = "Warszawa"
plot_hexagons_map(hexagons_clustered[hexagons_clustered["city"] == city], edges_hex[edges_hex["city"] == city], "cluster")

In [63]:
from src.tools.vis_utils import plot_clusters
embedding = reduce_umap(hex_agg_tfidf, n_components=2, n_neighbors=30, metric="euclidean")[0]
embedding["cluster"] = y_df["cluster"]
plot_clusters(embedding)

In [64]:
hex_tfidf_by_cluster = hex_agg_tfidf.join(y_df).groupby(by="cluster")

In [65]:
hex_tfidf_by_cluster_mean = hex_tfidf_by_cluster.mean().reset_index()
hex_tfidf_by_cluster_mean

Unnamed: 0,cluster,oneway_0,oneway_1,lanes_1,lanes_2,lanes_3,lanes_4,lanes_5,lanes_6,lanes_7,...,x_0,x_1,x_2,x_3,x_4,x_5,x_6,x_7,x_8,x_9
0,0,0.272252,0.128759,0.070207,0.111725,0.035178,0.011408,0.000824,0.000407,0.0,...,-9.103894,3.378928,-4.070654,5.292844,-7.342453,7.712244,2.225409,-0.914578,8.609584,3.511558
1,1,0.366863,0.06823,0.028883,0.140476,0.056367,0.025474,0.005771,0.00095,9.6e-05,...,-0.576686,0.510557,-0.34573,0.101104,-0.274628,-0.125201,-0.380844,0.842255,1.02822,0.198607
2,2,0.542167,0.09733,0.028928,0.109458,0.031016,0.016212,0.003006,7.2e-05,0.0,...,-1.736102,1.43512,-1.050126,1.54759,-1.819968,-0.699775,-2.292029,3.639486,3.331794,-0.368382
3,3,0.18953,0.507552,0.195421,0.193711,0.092382,0.041319,0.005561,0.000374,0.00025,...,-8.154156,4.7692,2.351127,-0.961397,-0.899282,-2.105249,-4.753105,7.507835,9.885488,8.264849
4,4,0.501103,0.197511,0.041348,0.146812,0.035942,0.017652,0.002848,0.000159,0.0,...,-2.228415,2.599047,-1.342562,1.857093,-3.322126,-1.01484,-5.959754,5.892847,7.763484,-0.545697
5,5,0.123926,0.436099,0.128307,0.204734,0.187267,0.093173,0.022354,0.003251,0.0,...,-8.432601,2.393214,-1.045919,-3.749647,3.118029,-2.801258,-3.274976,0.487632,9.356117,0.832059
6,6,0.353705,0.193332,0.071441,0.114724,0.031424,0.015302,0.001609,0.000292,0.0,...,-4.000593,3.84277,-5.798953,7.477767,-2.659114,-2.285135,-6.741744,6.937705,5.380107,4.75309
7,7,0.457075,0.145025,0.033602,0.100968,0.04166,0.018242,0.003097,0.000184,0.0,...,-5.049233,4.604876,-3.815565,1.90988,-4.547224,5.272195,-2.556175,4.47561,5.418378,0.989437
8,8,0.125385,0.382942,0.134668,0.255558,0.173836,0.070162,0.011573,0.001356,0.0,...,-2.248924,4.474616,-0.142077,-2.475616,0.187758,-1.533591,-2.417195,2.299626,6.504553,4.740971
9,9,0.099712,0.505824,0.138061,0.233637,0.171013,0.080925,0.013664,0.000609,0.0,...,-5.033039,10.330987,-0.419263,-9.185753,-0.161226,-0.309544,-8.409241,1.699709,12.471782,11.212369


In [66]:
from src.tools.feature_extraction import FEATURESET
for feature_name in FEATURESET.keys():
    fig = px.bar(hex_tfidf_by_cluster_mean, x="cluster", y=list(filter(lambda x: feature_name in x, hex_tfidf_by_cluster_mean.columns)), width=1300, title=feature_name)
    fig.update_layout(
        xaxis = dict(
            tickmode = 'linear',
        )
    )
    fig.show()