In [None]:
# Enable IPython's autoreload extension in mode 2, so edits to imported project
# modules (e.g. under `src/`) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
import geopandas as gpd
from src.settings import *
from tqdm.auto import tqdm
import torch
import pytorch_lightning as pl
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
from src.models.autoencoder import LitAutoEncoder
import json5 as json
import pickle as pkl
from src.tools.configs import ExperimentConfig, DatasetGenerationConfig
from src.tools.feature_extraction import SpatialDataset
import dataclasses
import gzip
from src.tools.feature_extraction import apply_feature_selection, normalize_df
from sklearn.model_selection import train_test_split
from src.tools.feature_extraction import features_wide_to_long

# Register tqdm with pandas so the `progress_apply` / `progress_map` variants are available.
tqdm.pandas()

In [None]:
# All tunable settings for this run live in one config object; it is written to
# the run directory as JSON at the end of the notebook.
ec = ExperimentConfig(
    # --- data selection and split ---
    dataset_filename="dataset_2022-11-01_11-39-37.pkl.gz",
    mode="edges",  # "edges" or "hexagons" (anything else raises further down)
    test_cities=[],  # e.g. ["Łódź"] to hold out whole cities instead of a random split
    test_size=0.2,
    random_seed=42,
    # --- data loading ---
    batch_size=200,
    num_workers=3,
    shuffle=True,
    # --- model ---
    model_name="autoencoder",
    hidden_dim=64,
    enc_out_dim=40,
    latent_dim=30,
    # --- optimisation ---
    epochs=10,
    kl_coeff=0.1,
    lr=1e-3,
)

In [None]:
# Load the dataset bundle named in the experiment config from FEATURES_DIR.
# The file is a gzip-compressed pickle.
# NOTE(review): unpickling can execute arbitrary code embedded in the file —
# only load dataset files that come from a trusted source.
ds_path = FEATURES_DIR / ec.dataset_filename
with gzip.open(ds_path, "rb") as f:
    ds: SpatialDataset = pkl.load(f)

# Display the annotated fields of the loaded object as a quick sanity check.
ds.__annotations__

In [None]:
# Pull the individual pieces of the dataset bundle into notebook-level names
# so later cells can refer to them directly.
ds_config, cities = ds.config, ds.cities
edges, edges_feature_selected = ds.edges, ds.edges_feature_selected
hexagons, hex_agg, hex_agg_normalized = (
    ds.hexagons,
    ds.hex_agg,
    ds.hex_agg_normalized,
)

In [None]:
# Seed the random number generators up front so the train/test split and the
# training run are repeatable; the same seed is reused for `train_test_split` below.
# `workers=True` additionally requests seeding of DataLoader worker processes.
random_seed = ec.random_seed
pl.seed_everything(random_seed, workers=True)

In [None]:
# Choose the feature matrix that corresponds to the configured mode.
_inputs_by_mode = {
    "edges": edges_feature_selected,
    "hexagons": hex_agg_normalized,
}
if ec.mode not in _inputs_by_mode:
    raise ValueError(f"Unknown mode: {ec.mode}")
input_df = _inputs_by_mode[ec.mode]

test_cities = ec.test_cities
# Full input as a float tensor; reused after training to embed every row.
X = torch.Tensor(input_df.values)

if not test_cities:
    # No hold-out cities configured: random split, stratified on the "highway"
    # column of the long-format features (missing values are replaced by the
    # most frequent class so that stratification accepts them).
    feature_keys = list(ds_config.featureset_selection["features"].keys())
    input_df_long = features_wide_to_long(input_df, feature_keys)
    most_frequent_value = input_df_long["highway"].value_counts().index[0]
    X_train, X_test = train_test_split(
        X,
        test_size=ec.test_size,
        random_state=random_seed,
        shuffle=True,
        stratify=input_df_long["highway"].fillna(most_frequent_value),
    )
    # The long-format frame is only needed for the stratification labels.
    del input_df_long
else:
    # Hold out the configured cities entirely: rows whose index level 2 matches
    # a test city form the test set, everything else is used for training.
    train_cities = list(set(cities["city"]) - set(test_cities))
    X_train = torch.Tensor(input_df.drop(index=test_cities, level=2).values)
    X_test = torch.Tensor(input_df.loc[:, :, test_cities].values)

batch_size = ec.batch_size
num_workers = ec.num_workers
shuffle = ec.shuffle

# The training loader follows the config; the evaluation loader is never
# shuffled and uses a single worker.
X_train_dl = DataLoader(
    X_train,
    batch_size=batch_size,
    shuffle=shuffle,
    num_workers=num_workers,
    pin_memory=True,
)
X_test_dl = DataLoader(
    X_test,
    batch_size=batch_size,
    shuffle=False,
    num_workers=1,
    pin_memory=True,
)

# Width of the feature matrix; becomes the model's input dimension.
n_features = X_train.shape[1]
print("Number of features:", n_features)
print("Number of training samples:", len(X_train))
print("Number of test samples:", len(X_test))

In [None]:
# Every artefact of this run is written below RUNS_DATA_DIR/<run_name>.
# An existing directory is reused, so files from a previous run with the same
# name are overwritten.
run_name = "run_01"
run_dir = RUNS_DATA_DIR / run_name
run_dir.mkdir(parents=True, exist_ok=True)

# Hyperparameters pulled out of the experiment config.
# NOTE(review): enc_out_dim and kl_coeff are read here but are not passed to
# the model constructor below — confirm whether that is intended.
hidden_dim, enc_out_dim, latent_dim = ec.hidden_dim, ec.enc_out_dim, ec.latent_dim
epochs, kl_coeff, lr = ec.epochs, ec.kl_coeff, ec.lr

# Store the exact model input alongside the run.
input_path = run_dir / "input.pkl.gz"
input_df.to_pickle(input_path)

model = LitAutoEncoder(
    in_dim=n_features,
    hidden_dim=hidden_dim,
    latent_dim=latent_dim,
    lr=lr,
)

# NOTE(review): gpus=0 requests CPU training while precision=16 requests
# 16-bit precision; how this combination is handled appears to depend on the
# installed Lightning version — confirm it behaves as expected.
trainer = pl.Trainer(
    gpus=0,
    max_epochs=epochs,
    default_root_dir=CHECKPOINTS_DIR,
    precision=16,
)
# The held-out split doubles as the validation set during fitting.
trainer.fit(model, train_dataloaders=X_train_dl, val_dataloaders=X_test_dl)
trainer.save_checkpoint(run_dir / "model.ckpt")

In [None]:
# Run the trained model over the complete input and store its output, one row
# per input row, under the original index.
model.eval()

# Inference only: disable autograd so no computation graph is built and held
# in memory for the full-dataset forward pass. The resulting values are
# identical to a forward pass with gradients enabled.
with torch.no_grad():
    embeddings = model(X).detach().numpy()

# Columns are named z_0, z_1, ...; rows keep the index of the input frame so
# the embeddings can be joined back to their source rows.
z_df = pd.DataFrame(embeddings, index=input_df.index).add_prefix("z_")

# The ".gz" suffix makes pandas write the pickle gzip-compressed.
embeddings_path = run_dir / "embeddings.pkl.gz"
z_df.to_pickle(embeddings_path)

z_df

In [None]:
# Write both configuration objects into the run directory as JSON so the run
# can be traced back to the settings that produced it.
for _config_filename, _config_obj in (
    ("experiment_config.json", ec),
    ("dataset_generation_config.json", ds_config),
):
    with open(run_dir / _config_filename, "w") as f:
        json.dump(
            dataclasses.asdict(_config_obj),
            f,
            indent=2,
            quote_keys=True,
            trailing_commas=False,
        )

# Also keep a gzip-compressed pickle of the dataset object itself with the run.
with gzip.open(run_dir / "dataset.pkl.gz", "wb") as f:
    pkl.dump(ds, f)
