In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules
# (e.g. the project's src.* packages) are re-imported before each cell runs.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
import pandas as pd
import geopandas as gpd
from src.settings import *
from tqdm.auto import tqdm
import torch
import pytorch_lightning as pl
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
from pytorch_lightning.loggers import WandbLogger
import wandb
from src.models.autoencoder import LitAutoEncoder, LitVAE
import json5 as json
import pickle as pkl
from src.tools.configs import ExperimentConfig, DatasetGenerationConfig
from src.tools.feature_extraction import SpatialDataset
import dataclasses
import gzip
from src.tools.feature_extraction import apply_feature_selection, normalize_df
from sklearn.model_selection import train_test_split
from src.tools.feature_extraction import features_wide_to_long

# Register tqdm's progress_* helpers (e.g. progress_apply) on pandas objects.
tqdm.pandas()

In [3]:
# Experiment definition: every tunable for this run lives here. The same object
# is later logged to W&B and written to experiment_config.json in the run directory.
ec = ExperimentConfig(
    # --- data ---
    # Alternative dataset: "dataset_2022-01-06_13-18-38_europe.pkl.gz"
    dataset_filename="dataset_2022-01-30_18-42-10_poland.pkl.gz",
    mode="edges",  # "edges" or "hexagons" (anything else raises in the split cell)
    # --- train / test split ---
    # Set e.g. test_cities=["Łódź"] to hold out whole cities; an empty list
    # selects the random, stratified split controlled by test_size.
    test_cities=[],
    test_size=0.2,
    random_seed=42,
    # --- data loading ---
    batch_size=200,
    num_workers=3,
    shuffle=True,
    # --- model ---
    model_name="autoencoder",  # "autoencoder" or "vae"
    hidden_dim=64,
    enc_out_dim=40,  # only passed to LitVAE
    latent_dim=30,
    kl_coeff=0.1,  # only passed to LitVAE
    # --- optimisation ---
    epochs=50,
    lr=1e-3,
)

In [4]:
# Read the gzip-compressed, pickled SpatialDataset named in the experiment config.
# NOTE: unpickling can execute arbitrary code — only load dataset files you trust.
ds_path = FEATURES_DIR / ec.dataset_filename
with gzip.open(ds_path, mode="rb") as dataset_file:
    ds: SpatialDataset = pkl.load(dataset_file)

# Show which fields the dataset container declares.
ds.__annotations__

{'config': src.tools.configs.DatasetGenerationConfig,
 'cities': pandas.core.frame.DataFrame,
 'edges': geopandas.geodataframe.GeoDataFrame,
 'edges_feature_selected': geopandas.geodataframe.GeoDataFrame,
 'hexagons': geopandas.geodataframe.GeoDataFrame,
 'hex_agg': typing.Optional[pandas.core.frame.DataFrame],
 'hex_agg_normalized': typing.Optional[pandas.core.frame.DataFrame]}

In [5]:
# Pull the dataset's fields out into notebook-level names used by later cells.
ds_config, cities = ds.config, ds.cities
edges, edges_feature_selected = ds.edges, ds.edges_feature_selected
hexagons = ds.hexagons
hex_agg, hex_agg_normalized = ds.hex_agg, ds.hex_agg_normalized

In [6]:
# Seed once, up front; the same seed is reused for the train/test split below.
random_seed = ec.random_seed
# Kept as the last expression so the cell displays the seed that was applied.
pl.seed_everything(seed=random_seed, workers=True)

Global seed set to 42


42

In [7]:
# Pick the feature table that matches the experiment mode.
mode_to_features = {
    "edges": edges_feature_selected,
    "hexagons": hex_agg_normalized,
}
if ec.mode not in mode_to_features:
    raise ValueError(f"Unknown mode: {ec.mode}")
input_df = mode_to_features[ec.mode]

test_cities = ec.test_cities
# Full feature matrix; also used after training to embed every sample.
X = torch.Tensor(input_df.values)
if not test_cities:
    # No held-out cities: random split, stratified on the "highway" value of each row.
    feature_keys = list(ds_config.featureset_selection["features"].keys())
    input_df_long = features_wide_to_long(input_df, feature_keys)
    most_frequent_value = input_df_long["highway"].value_counts().index[0]
    X_train, X_test = train_test_split(
        X,
        test_size=ec.test_size,
        random_state=random_seed,
        shuffle=True,
        # stratify labels must not contain NaN, so gaps get the most common value
        stratify=input_df_long["highway"].fillna(most_frequent_value),
    )
    del input_df_long
else:
    # Hold out whole cities, selected on index level 2.
    train_cities = list(set(cities["city"]) - set(test_cities))
    X_train = torch.Tensor(input_df.drop(index=test_cities, level=2).values)
    X_test = torch.Tensor(input_df.loc[:, :, test_cities].values)

batch_size = ec.batch_size
num_workers = ec.num_workers
shuffle = ec.shuffle

X_train_dl = DataLoader(
    X_train,
    batch_size=batch_size,
    shuffle=shuffle,
    num_workers=num_workers,
    pin_memory=True,
)
# The evaluation loader is never shuffled and uses a single worker.
X_test_dl = DataLoader(
    X_test,
    batch_size=batch_size,
    shuffle=False,
    num_workers=1,
    pin_memory=True,
)

n_features = X_train.shape[1]
print("Number of features:", n_features)
print("Number of training samples:", len(X_train))
print("Number of test samples:", len(X_test))

Number of features: 88
Number of training samples: 153494
Number of test samples: 38374


In [8]:
# Start a W&B run and create a local output directory named after it.
wandb_logger = WandbLogger(log_model=True)
run = wandb.init(
    project="osm-road-infrastructure_autoencoder",
    entity="pwr-spatial-lab",
    dir=CHECKPOINTS_DIR,
    reinit=True,
)
run_name = run.name
run_dir = RUNS_DATA_DIR / run_name
run_dir.mkdir(parents=True, exist_ok=True)

# Hyper-parameters, pulled out of the experiment config.
hidden_dim, enc_out_dim, latent_dim = ec.hidden_dim, ec.enc_out_dim, ec.latent_dim
epochs, kl_coeff, lr = ec.epochs, ec.kl_coeff, ec.lr

# Record both configurations on the W&B run.
config = wandb.config
config.experiment_config = dataclasses.asdict(ec)
config.dataset_generation_config = dataclasses.asdict(ds.config)

# Keep a copy of the exact model input next to the run outputs.
input_path = run_dir / "input.pkl.gz"
input_df.to_pickle(input_path)

# Build the requested model; enc_out_dim and kl_coeff are only passed to the VAE.
if ec.model_name == "vae":
    model = LitVAE(
        in_dim=n_features,
        hidden_dim=hidden_dim,
        enc_out_dim=enc_out_dim,
        latent_dim=latent_dim,
        lr=lr,
        kl_coeff=kl_coeff,
    )
elif ec.model_name == "autoencoder":
    model = LitAutoEncoder(
        in_dim=n_features,
        hidden_dim=hidden_dim,
        latent_dim=latent_dim,
        lr=lr,
    )
else:
    raise ValueError(f"Unknown model name: {ec.model_name}")

# Train on one GPU with 16-bit precision; the test loader doubles as the validation set.
trainer = pl.Trainer(
    gpus=1,
    max_epochs=epochs,
    logger=wandb_logger,
    default_root_dir=CHECKPOINTS_DIR,
    precision=16,
)
trainer.fit(model, train_dataloaders=X_train_dl, val_dataloaders=X_test_dl)
trainer.save_checkpoint(run_dir / "model.ckpt")

2022-01-30 18:56:08,524 | wandb.jupyter | ERROR | notebook_metadata:227 | Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mcalychas[0m (use `wandb login --relogin` to force relogin)


Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(

  | Name    | Type       | Params
---------------------------------------
0 | encoder | Sequential | 7.6 K 
1 | decoder | Sequential | 7.7 K 
---------------------------------------
15.3 K    Trainable params
0         Non-trainable params
15.3 K    Total params
0.031     Total estimated model params size (MB)


Validation sanity check:   0%|          | 0/2 [00:00<?, ?it/s]

  rank_zero_warn(


                                                                      

Global seed set to 42


Epoch 49: 100%|██████████| 960/960 [00:12<00:00, 75.35it/s, loss=8.79e-07, v_num=88k0, train_loss_step=4.98e-7, val_loss_step=1.25e-6, val_loss_epoch=8.7e-5, train_loss_epoch=9.18e-7]     


In [9]:
# Embed every sample (train and test rows alike) with the trained model.
model.eval()
# Inference only: no_grad() keeps autograd from recording the forward pass over
# the full dataset, so no graph or intermediate activations are held in memory.
with torch.no_grad():
    z_df = pd.DataFrame(model(X).detach().numpy()).add_prefix("z_")
# Rows of X follow input_df's row order, so its index can be reused as-is.
z_df.index = input_df.index

embeddings_path = run_dir / "embeddings.pkl.gz"
z_df.to_pickle(embeddings_path)

# Upload the model input (written to input_path by the training cell) as a W&B artifact.
dataset_artifact = wandb.Artifact(f"dataset-{run_name}", type="dataset")
dataset_artifact.add_file(input_path)
wandb.log_artifact(dataset_artifact)

# Upload the embeddings as a second artifact.
result_artifact = wandb.Artifact(f"result-{run_name}", type="result")
result_artifact.add_file(embeddings_path)
wandb.log_artifact(result_artifact)

z_df

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,z_0,z_1,z_2,z_3,z_4,z_5,z_6,z_7,z_8,z_9,...,z_20,z_21,z_22,z_23,z_24,z_25,z_26,z_27,z_28,z_29
continent,country,city,h3_id,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1
Europe,Poland,Białystok,891f5106993ffff,0.654016,0.128576,0.938951,0.999976,-0.890342,-0.525410,0.305847,0.059518,0.981707,0.427838,...,0.530449,0.917454,-0.800111,0.085526,0.791737,-0.072527,0.999995,-0.999992,0.724442,-0.999998
Europe,Poland,Białystok,891f5106993ffff,0.545196,0.756931,0.635376,0.999997,-0.966610,-0.136907,-0.902312,-0.842295,0.928197,-0.655070,...,0.866342,0.988764,-0.008782,-0.126354,0.740904,0.194954,0.999999,-0.999999,0.568232,-1.000000
Europe,Poland,Białystok,891f5106d67ffff,0.545196,0.756931,0.635376,0.999997,-0.966610,-0.136907,-0.902312,-0.842295,0.928197,-0.655070,...,0.866342,0.988764,-0.008782,-0.126354,0.740904,0.194954,0.999999,-0.999999,0.568232,-1.000000
Europe,Poland,Białystok,891f5106997ffff,0.453701,0.682837,0.753632,0.999994,-0.920177,0.199313,-0.861040,-0.576378,0.922710,-0.344298,...,0.916185,0.992058,-0.669658,-0.000063,0.684470,0.333532,0.999999,-0.999998,0.674286,-1.000000
Europe,Poland,Białystok,891f5106993ffff,0.453701,0.682837,0.753632,0.999994,-0.920177,0.199313,-0.861040,-0.576378,0.922710,-0.344298,...,0.916185,0.992058,-0.669658,-0.000063,0.684470,0.333532,0.999999,-0.999998,0.674286,-1.000000
Europe,Poland,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
Europe,Poland,Łódź,891e2186aafffff,0.971310,0.733201,-0.972138,0.999400,0.498624,-0.537558,-0.916344,0.651311,0.253206,-0.983803,...,-0.726774,-0.874906,-0.923498,-0.817996,-0.773381,0.393953,0.999854,-0.999876,-0.628103,-0.999900
Europe,Poland,Łódź,891e2186a33ffff,0.971310,0.733201,-0.972138,0.999400,0.498624,-0.537558,-0.916344,0.651311,0.253206,-0.983803,...,-0.726774,-0.874906,-0.923498,-0.817996,-0.773381,0.393953,0.999854,-0.999876,-0.628103,-0.999900
Europe,Poland,Łódź,891e2186a33ffff,0.956944,0.212992,-0.935270,0.999551,-0.742712,-0.460176,-0.801898,0.360151,0.665570,-0.961431,...,-0.403926,-0.324404,-0.896827,-0.814093,-0.491083,0.436427,0.999933,-0.999911,0.178055,-0.999922
Europe,Poland,Łódź,891e2186a33ffff,0.956944,0.212992,-0.935270,0.999551,-0.742712,-0.460176,-0.801898,0.360151,0.665570,-0.961431,...,-0.403926,-0.324404,-0.896827,-0.814093,-0.491083,0.436427,0.999933,-0.999911,0.178055,-0.999922


In [10]:
# Write both configuration dataclasses to JSON files in the run directory.
config_files = (
    ("experiment_config.json", ec),
    ("dataset_generation_config.json", ds_config),
)
for config_filename, config_obj in config_files:
    with open(run_dir / config_filename, "w") as config_file:
        json.dump(
            dataclasses.asdict(config_obj),
            config_file,
            indent=2,
            quote_keys=True,
            trailing_commas=False,
        )

# Store a gzip-compressed pickle of the full dataset object with the run.
with gzip.open(run_dir / "dataset.pkl.gz", mode="wb") as dataset_file:
    pkl.dump(ds, dataset_file)


In [11]:
# Close the W&B run started in the training cell.
run.finish()

VBox(children=(Label(value=' 5.83MB of 5.83MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train_loss_epoch,█▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_loss_step,█▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
trainer/global_step,▁▁▂▁▁▁▁▁▁▁▁▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▆▂▂▂▂▃▃▃▃█▃▃▃▃
val_loss_epoch,█▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_loss_step,█▃▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,49.0
train_loss_epoch,0.0
train_loss_step,0.0
trainer/global_step,38399.0
val_loss_epoch,9e-05
val_loss_step,0.0
