In [None]:
# Load IPython's autoreload extension in mode 2 so that edits to imported
# modules (the `src.*` helpers used below) are picked up live, without a
# kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
import geopandas as gpd
from src.settings import *
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from src.tools.osmnx_utils import get_place_dir_name
from src.tools.h3_utils import get_resolution_buffered_suffix, get_edges_with_features_filename
import plotly.express as px
from src.tools.clustering import cluster_hdbscan
from src.models.tfidf import tfidf
from src.tools.dim_reduction import reduce_umap
import matplotlib.pyplot as plt
from src.tools.aggregation import aggregate_hex
import torch
import pytorch_lightning as pl
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
from pytorch_lightning.loggers import WandbLogger
import wandb
from src.models.autoencoder import LitAutoEncoder, LitVAE
from src.tools.feature_extraction import apply_feature_selection, sparse_dtype, normalize_df
import json5 as json
from sklearn.cluster import AgglomerativeClustering
from src.tools.vis_utils import plot_clusters, plot_hexagons_map, visualize_dendrogram, FIGSIZE
import pickle as pkl
from src.tools.configs import ExperimentConfig, DatasetGenerationConfig
from src.tools.feature_extraction import SpatialDataset
import dataclasses
import gzip

# Hook tqdm into pandas -- presumably to enable progress-bar variants of
# `apply`-style methods (TODO confirm it is still needed: no cell visible in
# this notebook calls them).
tqdm.pandas()

In [None]:
# Single source of truth for this run: the cells below read every setting from
# `ec`, and the whole object is recorded on the W&B run and dumped to JSON.
ec = ExperimentConfig(
    # What to train, and on which data.
    model_name="autoencoder",  # "autoencoder" or "vae"; any other value raises at model build time
    dataset_filename="dataset_2021-11-29_20-45-47.pkl.gz",  # resolved relative to FEATURES_DIR
    test_cities=["Łódź"],  # held out of the training tensor and used as the validation loader
    # Reproducibility and data loading.
    random_seed=42,
    shuffle=True,  # applied to the training loader only
    batch_size=64,
    num_workers=3,
    # Network shape.
    hidden_dim=64,
    latent_dim=30,
    enc_out_dim=40,  # only passed to the VAE constructor
    # Optimisation.
    lr=1e-3,
    epochs=10,
    kl_coeff=0.1,  # only passed to the VAE constructor
)

In [None]:
# Load the pre-generated SpatialDataset selected by the experiment config.
ds_path = FEATURES_DIR / ec.dataset_filename

# The file carries a `.pkl.gz` name, but it may have been written either as a
# gzip-compressed pickle or as a plain pickle. Sniff the two-byte gzip magic
# number so both variants load, instead of failing inside `pkl.load` when the
# bytes are compressed.
with open(ds_path, "rb") as probe:
    is_gzipped = probe.read(2) == b"\x1f\x8b"

opener = gzip.open if is_gzipped else open

# NOTE: unpickling executes arbitrary code -- only load dataset files that were
# produced by this project's own pipeline.
with opener(ds_path, "rb") as f:
    ds: SpatialDataset = pkl.load(f)

# Show the dataset's declared fields as the cell output.
ds.__annotations__

In [None]:
# Expose the dataset's components as notebook-level names for the cells below.
ds_config, cities, edges = ds.config, ds.cities, ds.edges
hexagons, hex_agg, hex_agg_normalized = ds.hexagons, ds.hex_agg, ds.hex_agg_normalized

In [None]:
# Seed the run through Lightning (dataloader workers included) before any
# model or loader is created; the seed comes from the experiment config.
random_seed = ec.random_seed
pl.seed_everything(seed=random_seed, workers=True)

In [None]:
# Loader settings come straight from the experiment config.
batch_size, num_workers, shuffle = ec.batch_size, ec.num_workers, ec.shuffle

# City-level split: the configured test cities are held out, all others train.
test_cities = ec.test_cities
train_cities = list(set(cities["city"]).difference(test_cities))

# Index level 2 of the feature table is matched against the city names:
# training drops the test cities, testing keeps only them, and X is everything.
X = torch.Tensor(hex_agg_normalized.values)
X_test = torch.Tensor(hex_agg_normalized.loc[:, :, test_cities].values)
X_train = torch.Tensor(hex_agg_normalized.drop(index=test_cities, level=2).values)

# Number of input features per row, used as the model's input dimension.
n_features = X_train.size(1)

# Only the training loader honours the configured `shuffle`; the test loader
# always iterates in order.
loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers)
X_train_dl = DataLoader(X_train, shuffle=shuffle, **loader_kwargs)
X_test_dl = DataLoader(X_test, shuffle=False, **loader_kwargs)

In [None]:
# Create the Lightning logger, then open the W&B run it will report to.
wandb_logger = WandbLogger(log_model=True)
run = wandb.init(
    project="osm-road-infrastructure_autoencoder",
    entity="pwr-spatial-lab",
    dir=CHECKPOINTS_DIR,
    reinit=True,
)
run_name = run.name

# One output directory per run, named after the W&B run.
run_dir = RUNS_DATA_DIR / run_name
run_dir.mkdir(parents=True, exist_ok=True)

# Hyperparameters as plain names, read from the experiment config.
hidden_dim, enc_out_dim, latent_dim = ec.hidden_dim, ec.enc_out_dim, ec.latent_dim
epochs, kl_coeff, lr = ec.epochs, ec.kl_coeff, ec.lr

# Attach both configs to the run.
config = wandb.config
config.experiment_config = dataclasses.asdict(ec)
config.dataset_generation_config = dataclasses.asdict(ds.config)

# Snapshot the exact model input beside the run outputs; a later cell uploads
# this file as a W&B artifact.
input_path = run_dir / "input.pkl.gz"
hex_agg_normalized.to_pickle(input_path)

# Reject unsupported model names up front, then build the requested model.
if ec.model_name not in ("autoencoder", "vae"):
    raise ValueError(f"Unknown model name: {ec.model_name}")

shared_model_kwargs = dict(in_dim=n_features, hidden_dim=hidden_dim, latent_dim=latent_dim, lr=lr)
if ec.model_name == "vae":
    # The VAE takes two extra settings on top of the shared ones.
    model = LitVAE(enc_out_dim=enc_out_dim, kl_coeff=kl_coeff, **shared_model_kwargs)
else:
    model = LitAutoEncoder(**shared_model_kwargs)

# Train on one GPU (`gpus=1`); the held-out test cities serve as validation data.
trainer = pl.Trainer(
    gpus=1,
    max_epochs=epochs,
    logger=wandb_logger,
    default_root_dir=CHECKPOINTS_DIR,
)
trainer.fit(model, train_dataloaders=X_train_dl, val_dataloaders=X_test_dl)

In [None]:
# Compute the model output for every row (train and test cities alike), store
# it, and upload both the model input and the output as W&B artifacts.
model.eval()

# Inference only: skip autograd bookkeeping for the full-dataset forward pass,
# and run on whichever device the weights currently live on (training used a
# GPU), so the call works whether or not the model was moved back to the CPU.
model_device = next(model.parameters()).device
with torch.no_grad():
    z = model(X.to(model_device))

# `.cpu()` is required before `.numpy()` whenever the output sits on a GPU.
z_df = pd.DataFrame(z.detach().cpu().numpy()).add_prefix("z_")
# Rows of X follow `hex_agg_normalized`, so its index is reused as-is.
z_df.index = hex_agg_normalized.index

embeddings_path = run_dir / "embeddings.pkl.gz"
z_df.to_pickle(embeddings_path)

# Upload the model input that was written to `input_path` earlier.
dataset_artifact = wandb.Artifact(f"dataset-{run_name}", type="dataset")
dataset_artifact.add_file(input_path)
wandb.log_artifact(dataset_artifact)

# Upload the embeddings produced above.
result_artifact = wandb.Artifact(f"result-{run_name}", type="result")
result_artifact.add_file(embeddings_path)
wandb.log_artifact(result_artifact)

z_df

In [None]:
# Write both configs into the run directory as JSON, one file per config.
for json_name, cfg in (
    ("experiment_config.json", ec),
    ("dataset_generation_config.json", ds_config),
):
    with open(run_dir / json_name, "w") as f:
        json.dump(dataclasses.asdict(cfg), f, indent=2, quote_keys=True, trailing_commas=False)

# Keep a gzip-compressed pickle of the full dataset object with the run.
with gzip.open(run_dir / "dataset.pkl.gz", "wb") as f:
    pkl.dump(ds, f)

# Save the trained model's checkpoint alongside the other run outputs.
trainer.save_checkpoint(run_dir / "model.ckpt")

In [None]:
# Mark the W&B run that was opened with `wandb.init` as finished.
run.finish()