In [50]:
# Automatically reload edited modules (e.g. the project's `src` package) before each cell executes.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [51]:
import pandas as pd
import geopandas as gpd
from src.settings import *
from tqdm.auto import tqdm
import torch
import pytorch_lightning as pl
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
from pytorch_lightning.loggers import WandbLogger
import wandb
from src.models.autoencoder import LitAutoEncoder, LitVAE
import json5 as json
import pickle as pkl
from src.tools.configs import ExperimentConfig, DatasetGenerationConfig
from src.tools.feature_extraction import SpatialDataset
import dataclasses
import gzip
from src.tools.feature_extraction import apply_feature_selection, normalize_df
from sklearn.model_selection import train_test_split
from src.tools.feature_extraction import features_wide_to_long

# Register tqdm progress-bar variants (e.g. `.progress_apply`) on pandas objects.
tqdm.pandas()

In [52]:
# Experiment settings for this run; every later cell reads its parameters from `ec`.
ec = ExperimentConfig(
    # --- input data ---
    dataset_filename="dataset_2022-01-04_20-41-53_poland.pkl.gz",
    mode="edges",  # "edges" or "hexagons"
    # Non-empty list (e.g. ["Łódź"]) -> hold out whole cities; empty -> random split of size `test_size`.
    test_cities=[],
    test_size=0.2,
    random_seed=42,
    # --- data loading ---
    batch_size=200,
    num_workers=3,
    shuffle=True,
    # --- model ("autoencoder" or "vae"; enc_out_dim and kl_coeff are only used by "vae") ---
    model_name="autoencoder",
    hidden_dim=64,
    enc_out_dim=40,
    latent_dim=30,
    kl_coeff=0.1,
    # --- optimisation ---
    epochs=50,
    lr=1e-3,
)

# Earlier hexagon/VAE experiment differed from the above only in:
#   dataset_filename="dataset_2021-11-29_20-45-47_poland.pkl.gz", model_name="vae",
#   mode="hexagons", test_cities=["Łódź"], batch_size=64, epochs=10 (test_size left at its default).

In [53]:
# Load the pre-computed SpatialDataset (gzip-compressed pickle) selected in the experiment config.
# NOTE: unpickling can execute arbitrary code — only load trusted, locally generated files.
ds_path = FEATURES_DIR / ec.dataset_filename
with gzip.open(ds_path, mode="rb") as fh:
    ds: SpatialDataset = pkl.load(fh)

# Show the attribute names and declared types a SpatialDataset carries.
ds.__annotations__

{'config': src.tools.configs.DatasetGenerationConfig,
 'cities': pandas.core.frame.DataFrame,
 'edges': geopandas.geodataframe.GeoDataFrame,
 'edges_feature_selected': geopandas.geodataframe.GeoDataFrame,
 'hexagons': geopandas.geodataframe.GeoDataFrame,
 'hex_agg': typing.Optional[pandas.core.frame.DataFrame],
 'hex_agg_normalized': typing.Optional[pandas.core.frame.DataFrame]}

In [54]:
# Expose the dataset's components as top-level names for the cells below.
ds_config, cities = ds.config, ds.cities
edges, edges_feature_selected = ds.edges, ds.edges_feature_selected
hexagons, hex_agg, hex_agg_normalized = ds.hexagons, ds.hex_agg, ds.hex_agg_normalized

In [55]:
# Seed all RNGs Lightning knows about (including dataloader workers) for reproducible runs;
# `random_seed` is reused later for the train/test split.
random_seed = ec.random_seed
pl.seed_everything(seed=random_seed, workers=True)

Global seed set to 42


42

In [56]:
# Pick the feature table for the configured mode, split it into train/test and build the loaders.
if ec.mode == "edges":
    input_df = edges_feature_selected
elif ec.mode == "hexagons":
    # SpatialDataset declares hex_agg_normalized as Optional, so it may be missing.
    input_df = hex_agg_normalized
else:
    raise ValueError(f"Unknown mode: {ec.mode}")
if input_df is None:
    raise ValueError(f"Dataset {ec.dataset_filename} has no feature table for mode: {ec.mode}")

test_cities = ec.test_cities
# All rows in input_df order; also reused later to embed the whole dataset.
X = torch.Tensor(input_df.values)
if test_cities:
    # City-level hold-out. Level 2 of the row MultiIndex is the city
    # (for edges the levels are continent, country, city, h3_id).
    train_cities = list(set(cities["city"]) - set(test_cities))
    row_cities = input_df.index.get_level_values(2)
    missing_cities = sorted(set(test_cities) - set(row_cities))
    if missing_cities:
        raise ValueError(f"Test cities not found in the dataset: {missing_cities}")
    # One mask drives both selections, so train and test are guaranteed complementary.
    # (A 3-element `.loc[:, :, cities]` on a DataFrame is ambiguous between row levels and columns.)
    is_test_row = row_cities.isin(test_cities)
    X_train = torch.Tensor(input_df.loc[~is_test_row].values)
    X_test = torch.Tensor(input_df.loc[is_test_row].values)
else:
    # Random split stratified by the `highway` feature, so both splits share its value
    # distribution; missing `highway` values are bucketed into the most frequent value.
    feature_keys = list(ds_config.featureset_selection["features"].keys())
    input_df_long = features_wide_to_long(input_df, feature_keys)
    most_frequent_value = input_df_long["highway"].value_counts().index[0]
    X_train, X_test = train_test_split(
        X,
        test_size=ec.test_size,
        random_state=random_seed,
        shuffle=True,
        stratify=input_df_long["highway"].fillna(most_frequent_value),
    )
    del input_df_long

batch_size = ec.batch_size
num_workers = ec.num_workers
shuffle = ec.shuffle

X_train_dl = DataLoader(X_train, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True)
# Evaluation order is kept fixed (no shuffling).
X_test_dl = DataLoader(X_test, batch_size=batch_size, shuffle=False, num_workers=1, pin_memory=True)

n_features = X_train.shape[1]
print("Number of features:", n_features)
print("Number of training samples:", len(X_train))
print("Number of test samples:", len(X_test))

Number of features: 88
Number of training samples: 153494
Number of test samples: 38374


In [57]:
# Start a W&B run, record configs and the model input under RUNS_DATA_DIR/<run name>,
# then train the selected model and save its checkpoint there.
# NOTE(review): the logger is created before wandb.init(); presumably it attaches to the run
# started below when training begins (Lightning warns about the existing run) — confirm.
wandb_logger = WandbLogger(log_model=True)
# reinit=True lets this cell be re-executed in the same kernel and start a fresh run.
run = wandb.init(project="osm-road-infrastructure_autoencoder", entity="pwr-spatial-lab", dir=CHECKPOINTS_DIR, reinit=True)
run_name = run.name
run_dir = RUNS_DATA_DIR / run_name
run_dir.mkdir(parents=True, exist_ok=True)

hidden_dim = ec.hidden_dim
enc_out_dim = ec.enc_out_dim
latent_dim = ec.latent_dim
epochs = ec.epochs
kl_coeff = ec.kl_coeff
lr = ec.lr

# Store both the experiment and the dataset-generation configs on the W&B run.
config = wandb.config
config.experiment_config = dataclasses.asdict(ec)
config.dataset_generation_config = dataclasses.asdict(ds.config)

# Persist the exact model input next to the run (pandas infers gzip from the ".gz" suffix).
input_path = run_dir / "input.pkl.gz"
input_df.to_pickle(input_path)

# enc_out_dim and kl_coeff are only consumed by the VAE.
if ec.model_name == "autoencoder":
    model = LitAutoEncoder(in_dim=n_features, hidden_dim=hidden_dim, latent_dim=latent_dim, lr=lr)
elif ec.model_name == "vae":
    model = LitVAE(in_dim=n_features, hidden_dim=hidden_dim, enc_out_dim=enc_out_dim, latent_dim=latent_dim, lr=lr, kl_coeff=kl_coeff)
else:
    raise ValueError(f"Unknown model name: {ec.model_name}")

# Single-GPU, 16-bit mixed-precision training (requires a CUDA device).
trainer = pl.Trainer(gpus=1, max_epochs=epochs, logger=wandb_logger, default_root_dir=CHECKPOINTS_DIR, precision=16)
# The held-out split doubles as the validation set during training.
trainer.fit(model, train_dataloaders=X_train_dl, val_dataloaders=X_test_dl)
trainer.save_checkpoint(run_dir / "model.ckpt")

Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(

  | Name    | Type       | Params
---------------------------------------
0 | encoder | Sequential | 7.6 K 
1 | decoder | Sequential | 7.7 K 
---------------------------------------
15.3 K    Trainable params
0         Non-trainable params
15.3 K    Total params
0.031     Total estimated model params size (MB)


Validation sanity check:   0%|          | 0/2 [00:00<?, ?it/s]

  rank_zero_warn(


                                                                      

Global seed set to 42


Epoch 49: 100%|██████████| 960/960 [00:13<00:00, 72.75it/s, loss=0.000155, v_num=wsou, train_loss_step=0.000232, val_loss_step=0.00028, val_loss_epoch=0.000142, train_loss_epoch=0.000125] 


In [58]:
# Embed every row of the input table with the trained model, save the embeddings,
# and upload the model input and embeddings as W&B artifacts.
model.eval()
# Inference only: don't build an autograd graph over the whole dataset.
with torch.no_grad():
    # Run on whatever device the model is on, then move the result to CPU for pandas.
    z_values = model(X.to(model.device)).cpu().numpy()
# Rows of X follow input_df order, so the embeddings reuse its index.
z_df = pd.DataFrame(z_values, index=input_df.index).add_prefix("z_")

embeddings_path = run_dir / "embeddings.pkl.gz"
z_df.to_pickle(embeddings_path)

dataset_artifact = wandb.Artifact(f"dataset-{run_name}", type="dataset")
dataset_artifact.add_file(input_path)
wandb.log_artifact(dataset_artifact)

result_artifact = wandb.Artifact(f"result-{run_name}", type="result")
result_artifact.add_file(embeddings_path)
wandb.log_artifact(result_artifact)

z_df

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,z_0,z_1,z_2,z_3,z_4,z_5,z_6,z_7,z_8,z_9,...,z_20,z_21,z_22,z_23,z_24,z_25,z_26,z_27,z_28,z_29
continent,country,city,h3_id,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1
Europe,Poland,Białystok,891f5106993ffff,0.342066,0.428381,0.087947,-0.603112,-0.395545,0.603423,-1.186571,0.789628,-0.274625,0.696226,...,-0.766874,-0.163890,0.693043,0.638905,0.944655,0.676234,0.310399,1.349102,0.848518,-0.396339
Europe,Poland,Białystok,891f5106993ffff,0.044766,0.683336,0.046412,-0.702277,-0.355882,0.347337,-0.238845,0.308295,0.332662,0.284348,...,-0.546090,-0.340523,1.219123,0.253508,0.232804,-0.147253,1.347803,1.308225,1.327217,0.063081
Europe,Poland,Białystok,891f5106d67ffff,0.044766,0.683336,0.046412,-0.702277,-0.355882,0.347337,-0.238845,0.308295,0.332662,0.284348,...,-0.546090,-0.340523,1.219123,0.253508,0.232804,-0.147253,1.347803,1.308225,1.327217,0.063081
Europe,Poland,Białystok,891f5106997ffff,0.063998,0.585682,-0.102041,-0.796996,-0.292700,0.266813,-0.442069,0.258601,0.344373,0.295771,...,-0.217529,-0.409388,0.874287,0.531757,0.172802,0.165596,1.330217,1.403318,1.285205,0.115912
Europe,Poland,Białystok,891f5106993ffff,0.063998,0.585682,-0.102041,-0.796996,-0.292700,0.266813,-0.442069,0.258601,0.344373,0.295771,...,-0.217529,-0.409388,0.874287,0.531757,0.172802,0.165596,1.330217,1.403318,1.285205,0.115912
Europe,Poland,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
Europe,Poland,Łódź,891e2186aafffff,0.290721,-0.194353,-0.152859,0.051254,-0.573223,0.109253,-1.133735,0.410644,-0.223926,0.081291,...,0.142493,0.476549,-0.493761,-0.225807,0.032908,-0.208952,0.369929,-0.011557,-0.153618,-0.007273
Europe,Poland,Łódź,891e2186a33ffff,0.290721,-0.194353,-0.152859,0.051254,-0.573223,0.109253,-1.133735,0.410644,-0.223926,0.081291,...,0.142493,0.476549,-0.493761,-0.225807,0.032908,-0.208952,0.369929,-0.011557,-0.153618,-0.007273
Europe,Poland,Łódź,891e2186a33ffff,0.348859,-0.001323,-0.094344,0.141564,-0.424439,-0.264807,-0.586222,-0.077301,0.041692,0.134877,...,0.015419,0.026064,-0.227429,-0.118927,0.079861,-0.051392,0.365915,0.219033,-0.188814,0.003534
Europe,Poland,Łódź,891e2186a33ffff,0.348859,-0.001323,-0.094344,0.141564,-0.424439,-0.264807,-0.586222,-0.077301,0.041692,0.134877,...,0.015419,0.026064,-0.227429,-0.118927,0.079861,-0.051392,0.365915,0.219033,-0.188814,0.003534


In [59]:
# Write both configs as JSON into the run directory, plus a full copy of the dataset used.
configs_to_save = {
    "experiment_config.json": ec,
    "dataset_generation_config.json": ds_config,
}
for config_filename, config_obj in configs_to_save.items():
    with open(run_dir / config_filename, "w") as fh:
        json.dump(dataclasses.asdict(config_obj), fh, indent=2, quote_keys=True, trailing_commas=False)

with gzip.open(run_dir / "dataset.pkl.gz", "wb") as fh:
    pkl.dump(ds, fh)


In [60]:
# Close the W&B run so pending files and artifacts finish uploading.
run.finish()

VBox(children=(Label(value=' 5.94MB of 5.94MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train_loss_epoch,█▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_loss_step,█▄▂▂▂▂▁▂▂▂▂▁▂▁▂▂▂▂▁▂▁▁▁▁▂▂▁▁▁▁▂▁▁▁▁▁▁▁▁▁
trainer/global_step,▁▁▂▁▁▁▁▁▁▁▁▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▆▂▂▂▂▃▃▃▃█▃▃▃▃
val_loss_epoch,█▄▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_loss_step,▆█▃▅▅▂▅▄▂▄▂▄▃▁▄▃▁▃▃▁▃▁▃▃▁▃▃▁▃▃▃▃▁▃▃▁▃▃▁▃

0,1
epoch,49.0
train_loss_epoch,0.00012
train_loss_step,0.00023
trainer/global_step,38399.0
val_loss_epoch,0.00014
val_loss_step,0.00028
