In [1]:
# Enable IPython's autoreload extension in mode 2, so edited project modules
# (e.g. the `src.*` imports below) are re-imported before each cell executes
# instead of requiring a kernel restart.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
import pandas as pd
import geopandas as gpd
from src.settings import *
from tqdm.auto import tqdm
import torch
import pytorch_lightning as pl
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
from pytorch_lightning.loggers import WandbLogger
import wandb
from src.models.autoencoder import LitAutoEncoder, LitVAE
import json5 as json
import pickle as pkl
from src.tools.configs import ExperimentConfig, DatasetGenerationConfig
from src.tools.feature_extraction import SpatialDataset
import dataclasses
import gzip
from src.tools.feature_extraction import apply_feature_selection, normalize_df

# Register tqdm's progress-bar integration on pandas objects (`progress_apply` etc.).
# NOTE(review): no `progress_*` call appears in the cells shown here — confirm this is still needed.
tqdm.pandas()

In [10]:
# Experiment configuration for this run: a VAE fed with the per-hexagon
# aggregated features (mode="hexagons").
#
# Alternative configuration kept for reference — identical to the one below
# except for two fields:
#     model_name="autoencoder", mode="edges"
ec = ExperimentConfig(
    # --- data ---
    dataset_filename="dataset_2021-11-29_20-45-47_poland.pkl.gz",
    mode="hexagons",
    test_cities=["Łódź"],
    # --- model ---
    model_name="vae",
    hidden_dim=64,
    enc_out_dim=40,
    latent_dim=30,
    kl_coeff=0.1,
    # --- training / data loading ---
    random_seed=42,
    batch_size=64,
    num_workers=3,
    shuffle=True,
    epochs=10,
    lr=1e-3,
)

In [11]:
# Load the gzip-compressed, pickled SpatialDataset named in the experiment config.
# NOTE: unpickling can execute arbitrary code — only load dataset files you trust.
ds_path = FEATURES_DIR / ec.dataset_filename
with gzip.open(ds_path, mode="rb") as dataset_file:
    ds: SpatialDataset = pkl.load(dataset_file)

# Display the dataset's declared fields and their types.
ds.__annotations__

{'config': src.tools.configs.DatasetGenerationConfig,
 'cities': pandas.core.frame.DataFrame,
 'edges': geopandas.geodataframe.GeoDataFrame,
 'hexagons': geopandas.geodataframe.GeoDataFrame,
 'hex_agg': typing.Optional[pandas.core.frame.DataFrame],
 'hex_agg_normalized': typing.Optional[pandas.core.frame.DataFrame]}

In [12]:
# Expose the dataset's fields under the notebook-level names used by later cells.
ds_config, cities = ds.config, ds.cities
edges, hexagons = ds.edges, ds.hexagons
hex_agg, hex_agg_normalized = ds.hex_agg, ds.hex_agg_normalized

In [8]:
# Seed the random number generators through Lightning using the seed from the
# experiment config. The call returns the seed, which the notebook displays.
# `workers=True` is intended to seed dataloader workers too — confirm against
# the installed pytorch_lightning version.
random_seed = ec.random_seed
pl.seed_everything(random_seed, workers=True)

Global seed set to 42


42

In [13]:
# Build the model input for the configured mode, split it into train/test by
# city, and wrap both splits in dataloaders.
test_cities = ec.test_cities
train_cities = list(set(cities["city"]) - set(test_cities))

if ec.mode == "edges":
    # Select the configured edge features, then normalize them. The
    # normalization step was previously commented out while `edges_normalized`
    # was still being read, which raises NameError on a fresh kernel.
    # NOTE(review): `scale_length=False` is kept as in the original; the
    # dataset config also carries `scale_length` — confirm which is intended.
    edges_features_selected = apply_feature_selection(edges, ds_config.featureset_selection, scale_length=False)
    edges_normalized = normalize_df(edges_features_selected, type=ds_config.normalize_type)
    input_df = edges_normalized
elif ec.mode == "hexagons":
    # Hexagon features are already aggregated and normalized in the dataset.
    input_df = hex_agg_normalized
else:
    raise ValueError(f"Unknown mode: {ec.mode}")

# Index level 2 is the city (the hexagon index is continent / country / city / h3_id);
# assumes the edges frame uses the same first three levels — TODO confirm.
X = torch.Tensor(input_df.values)
X_train = torch.Tensor(input_df.drop(index=test_cities, level=2).values)
X_test = torch.Tensor(input_df.loc[:, :, test_cities].values)

batch_size = ec.batch_size
num_workers = ec.num_workers
shuffle = ec.shuffle

# Only the training loader honours the configured shuffle flag; the test
# loader always keeps the original row order.
X_train_dl = DataLoader(X_train, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
X_test_dl = DataLoader(X_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)

# Input dimensionality for the model built in the next cell.
n_features = X_train.shape[1]

In [14]:
# Start the W&B run first and hand it to the Lightning logger explicitly, so
# the logger is bound to this exact run instead of implicitly picking up
# whatever run happens to be active when training starts.
run = wandb.init(project="osm-road-infrastructure_autoencoder", entity="pwr-spatial-lab", dir=CHECKPOINTS_DIR, reinit=True)
wandb_logger = WandbLogger(experiment=run, log_model=True)

# Every run gets its own output directory, named after the W&B run.
run_name = run.name
run_dir = RUNS_DATA_DIR / run_name
run_dir.mkdir(parents=True, exist_ok=True)

hidden_dim = ec.hidden_dim
enc_out_dim = ec.enc_out_dim
latent_dim = ec.latent_dim
epochs = ec.epochs
kl_coeff = ec.kl_coeff
lr = ec.lr

# Record both configurations on the W&B run.
config = wandb.config
config.experiment_config = dataclasses.asdict(ec)
config.dataset_generation_config = dataclasses.asdict(ds.config)

# Snapshot the exact model input so it can be logged as an artifact later.
input_path = run_dir / "input.pkl.gz"
input_df.to_pickle(input_path)

if ec.model_name == "autoencoder":
    model = LitAutoEncoder(in_dim=n_features, hidden_dim=hidden_dim, latent_dim=latent_dim, lr=lr)
elif ec.model_name == "vae":
    model = LitVAE(in_dim=n_features, hidden_dim=hidden_dim, enc_out_dim=enc_out_dim, latent_dim=latent_dim, lr=lr, kl_coeff=kl_coeff)
else:
    raise ValueError(f"Unknown model name: {ec.model_name}")

# Use one GPU when CUDA is available and fall back to CPU otherwise, instead
# of failing outright on machines without a GPU.
n_gpus = 1 if torch.cuda.is_available() else 0
trainer = pl.Trainer(gpus=n_gpus, max_epochs=epochs, logger=wandb_logger, default_root_dir=CHECKPOINTS_DIR)
# The held-out test cities double as the validation set during training.
trainer.fit(model, train_dataloaders=X_train_dl, val_dataloaders=X_test_dl)

2021-12-12 23:11:51,197 | wandb.jupyter | ERROR | notebook_metadata:227 | Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mcalychas[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.7 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type       | Params
---------------------------------------
0 | encoder | Sequential | 7.8 K 
1 | decoder | Sequential | 7.2 K 
2 | fc_mu   | Linear     | 1.2 K 
3 | fc_var  | Linear     | 1.2 K 
---------------------------------------
17.4 K    Trainable params
0         Non-trainable params
17.4 K    Total params
0.070     Total estimated model params size (MB)


                                                                      

Global seed set to 42


Epoch 9: 100%|██████████| 360/360 [00:12<00:00, 29.37it/s, loss=0.002, v_num=3pkn, train_loss_step=0.00358, val_loss_step=0.000674, val_loss_epoch=0.00118, train_loss_epoch=0.00196]   


In [15]:
# Run the trained model over the full input (train and test rows) and keep
# the outputs as a DataFrame aligned with the input index.
model.eval()
with torch.no_grad():
    # No autograd graph is needed for inference. The input is moved to
    # wherever the model currently lives, so this works whether or not
    # Lightning moved the model back to CPU after fitting.
    # NOTE(review): relies on `model.device`, i.e. the model being a LightningModule.
    z = model(X.to(model.device)).cpu().numpy()
z_df = pd.DataFrame(z).add_prefix("z_")
z_df.index = input_df.index

embeddings_path = run_dir / "embeddings.pkl.gz"
z_df.to_pickle(embeddings_path)

# Log the model input and the resulting embeddings as W&B artifacts.
dataset_artifact = wandb.Artifact(f"dataset-{run_name}", type="dataset")
dataset_artifact.add_file(input_path)
wandb.log_artifact(dataset_artifact)

result_artifact = wandb.Artifact(f"result-{run_name}", type="result")
result_artifact.add_file(embeddings_path)
wandb.log_artifact(result_artifact)

z_df

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,z_0,z_1,z_2,z_3,z_4,z_5,z_6,z_7,z_8,z_9,...,z_20,z_21,z_22,z_23,z_24,z_25,z_26,z_27,z_28,z_29
continent,country,city,h3_id,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1
Europe,Poland,Białystok,891f5104d23ffff,0.366589,0.174845,1.385094,-0.446281,1.445962,0.856142,2.217206,0.522639,1.175070,0.561035,...,0.011056,-0.338698,-1.340301,-0.585232,-0.563746,1.056218,-1.470554,1.433102,0.743348,-0.480834
Europe,Poland,Białystok,891f5104d2bffff,-1.049775,0.603278,0.404714,-1.354620,-0.497558,0.474582,-2.509469,0.487444,0.784371,0.028837,...,-0.890364,0.409805,-1.456819,-0.102411,-0.598554,0.477090,-0.170249,0.233252,4.035619,1.280596
Europe,Poland,Białystok,891f5104d37ffff,-0.013628,0.240144,0.132619,0.763336,0.518111,0.387994,-0.580085,-0.169593,1.930826,1.011580,...,1.480905,0.344711,-1.423813,-0.116625,-0.972009,0.958336,1.618989,1.450407,-0.426131,0.263591
Europe,Poland,Białystok,891f5104da7ffff,-1.439223,0.520920,0.348844,0.966816,-0.466147,1.605205,-0.426488,-1.337956,-0.193442,0.653352,...,1.289471,0.051710,-1.295905,0.050098,-0.585782,-0.389888,0.034830,0.120424,-0.806992,-0.207508
Europe,Poland,Białystok,891f5104db7ffff,-1.161760,-0.964319,-0.375266,0.801236,-0.519245,-1.500930,-1.928392,0.127772,-0.955782,-1.244577,...,0.526839,2.110601,-0.521011,-0.931633,1.060999,0.208291,-0.580089,0.325535,-0.817872,-1.019571
Europe,Poland,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
Europe,Poland,Łódź,891f5249ba7ffff,0.346064,-0.010734,0.532501,0.309067,-0.562025,1.065421,0.332233,1.231295,0.683582,0.224964,...,-1.791058,0.734579,-0.330608,0.301517,2.391814,0.174945,1.743661,0.592250,0.446658,1.189627
Europe,Poland,Łódź,891f5249babffff,-1.498725,0.456613,0.359766,0.849864,1.178639,-0.583105,2.135825,-0.771965,-0.054783,-1.829083,...,-2.599682,-0.518701,-0.856305,0.862378,0.510678,0.757378,-1.114163,-0.867191,-0.521043,-0.316355
Europe,Poland,Łódź,891f5249bafffff,0.318634,-1.130468,0.832411,-0.634404,1.358718,-0.718366,0.075132,-2.339742,-0.704352,-1.429365,...,0.400934,1.762576,0.715867,0.632653,-1.425870,0.096217,-1.098310,1.224127,-1.642889,0.456717
Europe,Poland,Łódź,891f5249bb3ffff,1.770486,-0.075515,0.129778,-1.141050,2.288926,-0.370318,0.103118,-1.139910,-0.377423,-1.676692,...,0.473710,0.951404,0.695312,0.012463,-0.494042,0.446955,-0.594283,0.692558,0.298075,1.474496


In [16]:
# Persist what is needed to reproduce this run inside its run directory:
# both configs as JSON, the full dataset object, and the model checkpoint.
# (`json` here is the `json5` module imported at the top of the notebook.)
config_dumps = (
    ("experiment_config.json", ec),
    ("dataset_generation_config.json", ds_config),
)
for config_filename, config_obj in config_dumps:
    with open(run_dir / config_filename, "w") as config_file:
        json.dump(dataclasses.asdict(config_obj), config_file, indent=2, quote_keys=True, trailing_commas=False)

with gzip.open(run_dir / "dataset.pkl.gz", "wb") as dataset_file:
    pkl.dump(ds, dataset_file)

trainer.save_checkpoint(run_dir / "model.ckpt")

In [17]:
# Close the W&B run started in the training cell; the output below shows the
# final upload and the run's metric summary.
run.finish()

VBox(children=(Label(value=' 3.66MB of 3.66MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
epoch,▁▁▁▁▂▂▂▂▃▃▃▃▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇████
train_kl_loss_epoch,█▁▁▁▁▁▁▁▁▁
train_kl_loss_step,█▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_loss_epoch,█▁▁▁▁▁▁▁▁▁
train_loss_step,█▄▂▂▂▃▂▂▃▁▂▁▁▂▃▁▂▁▂▁▂▂▃▂▃▁▁▁▁▁▂▂▂▃▁▂▂▂▂▂
train_recon_loss_epoch,█▁▁▁▁▁▁▁▁▁
train_recon_loss_step,█▄▂▂▂▃▂▂▃▂▂▁▁▂▃▁▂▂▂▁▂▂▃▂▃▁▁▁▁▂▂▂▃▄▁▂▂▂▂▂
trainer/global_step,▁▁▁▁▂▁▁▁▃▁▁▁▄▁▁▁▄▁▁▁▆▁▂▂▆▂▂▂▇▂▂▂█▂▂▂▂▂▂▂
val_kl_loss_epoch,█▂▁▁▁▁▁▁▁▁
val_kl_loss_step,████▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,9.0
train_kl_loss_epoch,0.0
train_kl_loss_step,0.0
train_loss_epoch,0.00196
train_loss_step,0.00245
train_recon_loss_epoch,0.00196
train_recon_loss_step,0.00245
trainer/global_step,3229.0
val_kl_loss_epoch,0.0
val_kl_loss_step,0.0
