In [22]:
# Reload edited modules (e.g. the project's `src.*` package) before each cell runs,
# so source changes are picked up without restarting the kernel.
# Re-running this cell prints "already loaded"; that message is harmless.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [23]:
import pandas as pd
import geopandas as gpd
from src.settings import *
from tqdm.auto import tqdm
import torch
import pytorch_lightning as pl
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
from pytorch_lightning.loggers import WandbLogger
import wandb
from src.models.autoencoder import LitAutoEncoder, LitVAE
import json5 as json
import pickle as pkl
from src.tools.configs import ExperimentConfig, DatasetGenerationConfig
from src.tools.feature_extraction import SpatialDataset
import dataclasses
import gzip
from src.tools.feature_extraction import apply_feature_selection, normalize_df

# Register tqdm's progress-bar variants (e.g. `progress_apply`) on pandas objects.
tqdm.pandas()

In [24]:
# Experiment configuration, grouped by concern: data selection, model architecture,
# and training / data loading.
ec = ExperimentConfig(
    # data: which pickled dataset to load, what to model, which cities to hold out
    dataset_filename="dataset_2021-11-29_20-45-47.pkl.gz",
    mode="edges",
    test_cities=["Łódź"],
    # model architecture
    model_name="autoencoder",
    hidden_dim=64,
    enc_out_dim=40,  # only passed to the "vae" model
    latent_dim=30,
    kl_coeff=0.1,  # only passed to the "vae" model
    # training / data loading
    random_seed=42,
    batch_size=64,
    num_workers=3,
    shuffle=True,
    epochs=10,
    lr=1e-3,
)

In [25]:
# Load the pre-computed SpatialDataset.
# The file name ends in ".pkl.gz", yet this particular file loads with a plain `open`,
# so it is apparently not compressed. Meanwhile this notebook itself writes
# `dataset.pkl.gz` with `gzip.open`. Sniff the gzip magic bytes so both variants load.
# NOTE: unpickling can execute arbitrary code — only load dataset files you trust.
ds_path = FEATURES_DIR / ec.dataset_filename
with open(ds_path, "rb") as raw_file:
    is_gzipped = raw_file.read(2) == b"\x1f\x8b"
opener = gzip.open if is_gzipped else open
with opener(ds_path, "rb") as f:
    ds: SpatialDataset = pkl.load(f)

ds.__annotations__

{'config': src.tools.configs.DatasetGenerationConfig,
 'cities': pandas.core.frame.DataFrame,
 'edges': geopandas.geodataframe.GeoDataFrame,
 'hexagons': geopandas.geodataframe.GeoDataFrame,
 'hex_agg': typing.Optional[pandas.core.frame.DataFrame],
 'hex_agg_normalized': typing.Optional[pandas.core.frame.DataFrame]}

In [26]:
# Expose the dataset's components as top-level names for the cells below.
ds_config, cities, edges, hexagons, hex_agg, hex_agg_normalized = (
    ds.config,
    ds.cities,
    ds.edges,
    ds.hexagons,
    ds.hex_agg,
    ds.hex_agg_normalized,
)

In [27]:
# Seed the torch / numpy / python RNGs; workers=True also configures seeding for
# DataLoader worker processes. The returned seed is echoed as the cell output.
random_seed = ec.random_seed
pl.seed_everything(seed=random_seed, workers=True)

Global seed set to 42


42

In [29]:
# Build the model input for the selected mode and split it into train/test by city.
test_cities = ec.test_cities
# Sorted so the order does not depend on string-hash randomisation between kernels.
train_cities = sorted(set(cities["city"]) - set(test_cities))

if ec.mode == "edges":
    # NOTE(review): scale_length is pinned to False rather than taken from
    # ds_config.scale_length — confirm this is the intended experiment setting.
    edges_features_selected = apply_feature_selection(edges, ds_config.featureset_selection, scale_length=False)
    # This normalisation call had been commented out while `input_df = edges_normalized`
    # was kept. The cell then only worked via a stale kernel variable and raised
    # NameError on a fresh kernel. Restored here (same call as the commented-out line).
    edges_normalized = normalize_df(edges_features_selected, type=ds_config.normalize_type)
    input_df = edges_normalized
elif ec.mode == "hexagons":
    # hex_agg_normalized is Optional on SpatialDataset; fail clearly instead of later.
    if hex_agg_normalized is None:
        raise ValueError("mode='hexagons' requires ds.hex_agg_normalized, but the loaded dataset has None")
    input_df = hex_agg_normalized
else:
    raise ValueError(f"Unknown mode: {ec.mode}")

# One boolean mask over index level 2 (the "city" level) defines both splits.
# This keeps train and test exact complements, in original row order.
is_test_row = input_df.index.get_level_values(2).isin(test_cities)
if not is_test_row.any():
    raise ValueError(f"None of the test cities {test_cities} are present in the input data")

X = torch.Tensor(input_df.values)
X_train = torch.Tensor(input_df.loc[~is_test_row].values)
X_test = torch.Tensor(input_df.loc[is_test_row].values)

batch_size = ec.batch_size
num_workers = ec.num_workers
shuffle = ec.shuffle

X_train_dl = DataLoader(X_train, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
# Validation order is kept fixed so per-step validation losses are comparable across epochs.
X_test_dl = DataLoader(X_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)

n_features = X_train.shape[1]

Dask Apply: 100%|██████████| 12/12 [00:52<00:00,  4.41s/it]


In [30]:
# Start the W&B run, prepare the per-run output directory, persist the model input,
# then build and train the model selected in the experiment config.
run = wandb.init(project="osm-road-infrastructure_autoencoder", entity="pwr-spatial-lab", dir=CHECKPOINTS_DIR, reinit=True)
# Bind the Lightning logger explicitly to the run created above. Otherwise it would
# lazily attach to whichever W&B run happens to be active when training starts.
wandb_logger = WandbLogger(log_model=True, experiment=run)
run_name = run.name
run_dir = RUNS_DATA_DIR / run_name
run_dir.mkdir(parents=True, exist_ok=True)

hidden_dim = ec.hidden_dim
enc_out_dim = ec.enc_out_dim
latent_dim = ec.latent_dim
epochs = ec.epochs
kl_coeff = ec.kl_coeff
lr = ec.lr

# Record both configs on the W&B run.
config = wandb.config
config.experiment_config = dataclasses.asdict(ec)
config.dataset_generation_config = dataclasses.asdict(ds.config)

# pandas infers gzip compression from the ".gz" suffix.
input_path = run_dir / "input.pkl.gz"
input_df.to_pickle(input_path)

if ec.model_name == "autoencoder":
    model = LitAutoEncoder(in_dim=n_features, hidden_dim=hidden_dim, latent_dim=latent_dim, lr=lr)
elif ec.model_name == "vae":
    model = LitVAE(in_dim=n_features, hidden_dim=hidden_dim, enc_out_dim=enc_out_dim, latent_dim=latent_dim, lr=lr, kl_coeff=kl_coeff)
else:
    raise ValueError(f"Unknown model name: {ec.model_name}")

# Use the GPU when one is present; otherwise fall back to CPU rather than failing at
# Trainer construction on a CPU-only machine.
trainer = pl.Trainer(
    gpus=1 if torch.cuda.is_available() else 0,
    max_epochs=epochs,
    logger=wandb_logger,
    default_root_dir=CHECKPOINTS_DIR,
)
trainer.fit(model, train_dataloaders=X_train_dl, val_dataloaders=X_test_dl)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type       | Params
---------------------------------------
0 | encoder | Sequential | 7.1 K 
1 | decoder | Sequential | 7.2 K 
---------------------------------------
14.3 K    Trainable params
0         Non-trainable params
14.3 K    Total params
0.057     Total estimated model params size (MB)


                                                                      

Global seed set to 42


Epoch 9: 100%|██████████| 2999/2999 [00:32<00:00, 91.20it/s, loss=1.43e-05, v_num=stf9, train_loss_step=4.39e-6, val_loss_step=2.62e-6, val_loss_epoch=1.8e-5, train_loss_epoch=1.01e-5]  


In [31]:
# Run every row (train and test cities) through the trained model. The output columns
# are stored with a "z_" prefix. Then log the model input and these outputs as
# W&B artifacts.
model.eval()
# Inference only: no_grad avoids building an autograd graph over the whole dataset.
# Moving X to model.device makes this work wherever Lightning left the model.
with torch.no_grad():
    z_values = model(X.to(model.device)).cpu().numpy()
z_df = pd.DataFrame(z_values).add_prefix("z_")
z_df.index = input_df.index

embeddings_path = run_dir / "embeddings.pkl.gz"
z_df.to_pickle(embeddings_path)

dataset_artifact = wandb.Artifact(f"dataset-{run_name}", type="dataset")
dataset_artifact.add_file(input_path)
wandb.log_artifact(dataset_artifact)

result_artifact = wandb.Artifact(f"result-{run_name}", type="result")
result_artifact.add_file(embeddings_path)
wandb.log_artifact(result_artifact)

z_df

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,z_0,z_1,z_2,z_3,z_4,z_5,z_6,z_7,z_8,z_9,...,z_20,z_21,z_22,z_23,z_24,z_25,z_26,z_27,z_28,z_29
continent,country,city,h3_id,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1
Europe,Poland,Białystok,891f5106993ffff,0.006024,0.006690,-0.062477,0.015018,-0.000063,-0.001436,0.008834,0.003837,-0.014188,0.001814,...,0.006180,0.008048,-0.027267,-0.001397,-0.035145,-0.006807,0.003759,0.015372,-0.077718,-0.026832
Europe,Poland,Białystok,891f5106993ffff,-0.011414,0.017323,-0.048001,0.046310,0.006863,0.037827,-0.053198,0.038419,-0.068242,-0.024245,...,-0.064640,-0.052550,-0.021730,0.047828,0.023719,-0.022455,-0.000553,0.007469,-0.095184,-0.012846
Europe,Poland,Białystok,891f5106d67ffff,-0.011414,0.017323,-0.048001,0.046310,0.006863,0.037827,-0.053198,0.038419,-0.068242,-0.024245,...,-0.064640,-0.052550,-0.021730,0.047828,0.023719,-0.022455,-0.000553,0.007469,-0.095184,-0.012846
Europe,Poland,Białystok,891f5106997ffff,0.008282,0.007449,-0.061780,0.017839,-0.003820,-0.001150,0.011173,0.003203,-0.014339,-0.002862,...,0.005814,0.006147,-0.026052,-0.002661,-0.030695,-0.008824,0.009122,0.018489,-0.076171,-0.021428
Europe,Poland,Białystok,891f5106993ffff,0.008282,0.007449,-0.061780,0.017839,-0.003820,-0.001150,0.011173,0.003203,-0.014339,-0.002862,...,0.005814,0.006147,-0.026052,-0.002661,-0.030695,-0.008824,0.009122,0.018489,-0.076171,-0.021428
Europe,Poland,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
Europe,Poland,Łódź,891e2186aafffff,0.018351,0.010144,-0.045057,0.017574,-0.005276,-0.015026,0.014580,0.004769,-0.002345,0.004657,...,0.013841,0.027719,-0.016889,0.001789,-0.009808,0.001951,0.014284,0.028429,-0.056058,-0.018602
Europe,Poland,Łódź,891e2186a33ffff,0.018351,0.010144,-0.045057,0.017574,-0.005276,-0.015026,0.014580,0.004769,-0.002345,0.004657,...,0.013841,0.027719,-0.016889,0.001789,-0.009808,0.001951,0.014284,0.028429,-0.056058,-0.018602
Europe,Poland,Łódź,891e2186a33ffff,0.017371,0.009814,-0.047165,0.017281,-0.005313,-0.013874,0.014697,0.003948,-0.003006,0.003920,...,0.013512,0.026492,-0.017789,0.000948,-0.012643,0.000818,0.014002,0.027389,-0.057927,-0.019107
Europe,Poland,Łódź,891e2186a33ffff,0.012607,0.008212,-0.057416,0.015855,-0.005495,-0.008275,0.015269,-0.000045,-0.006222,0.000337,...,0.011909,0.020526,-0.022165,-0.003140,-0.026431,-0.004687,0.012628,0.022333,-0.067016,-0.021562


In [32]:
# Persist everything needed to inspect or reproduce this run next to its outputs:
# both configs as JSON (via json5, imported as `json`), the full dataset
# (gzip-pickled), and the trained checkpoint.
for json_name, cfg in (
    ("experiment_config.json", ec),
    ("dataset_generation_config.json", ds_config),
):
    with open(run_dir / json_name, "w") as f:
        json.dump(dataclasses.asdict(cfg), f, indent=2, quote_keys=True, trailing_commas=False)

with gzip.open(run_dir / "dataset.pkl.gz", "wb") as f:
    pkl.dump(ds, f)

trainer.save_checkpoint(run_dir / "model.ckpt")

In [33]:
# Mark the W&B run as finished; this also completes any pending uploads (see output).
run.finish()

VBox(children=(Label(value=' 17.40MB of 17.40MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.…

0,1
epoch,▁▁▁▁▂▂▂▂▃▃▃▃▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇████
train_loss_epoch,█▂▂▁▁▁▁▁▁▁
train_loss_step,█▂▂▂▂▂▁▆▁▄▁▁▂▁▂▄▂▁▂▂▁▁▁▁▁▁▁▂▁▁▁▂▁▁▁▁▂▂▁▁
trainer/global_step,▁▁▁▁▂▁▁▁▃▁▁▁▄▁▁▁▄▁▁▁▅▁▁▁▆▁▁▁▇▁▂▂▇▂▂▂█▂▂▂
val_loss_epoch,█▄▃▂▂▂▁▁▁▂
val_loss_step,▁▁▁▁▁▁▁▁▁▁▁▁▄▁▁▁▁▁▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁█▁

0,1
epoch,9.0
train_loss_epoch,1e-05
train_loss_step,2e-05
trainer/global_step,27459.0
val_loss_epoch,2e-05
val_loss_step,0.0
