In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported before each cell executes and edits to project code take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import torch
from torchcps.datasets.airquality.beijing import parse_beijing

# Load the Beijing dataset. parse_beijing() returns four values, unpacked here
# as the AQ data, the MEO data and the positions belonging to each of them.
# NOTE(review): shapes and units are not visible in this notebook; presumably
# "aq" is air quality and "meo" is meteorology — confirm against parse_beijing.
aq_data, meo_data, aq_pos, meo_pos = parse_beijing()

In [None]:
from torchcps.datasets.airquality.beijing import BeijingDataModule

# Build the data module with its default arguments and run its setup step.
# NOTE(review): `dm` is not referenced by any other cell visible in this
# notebook; presumably this is a smoke test that setup() runs — confirm.
dm = BeijingDataModule()
dm.setup()

In [None]:
# Show where air-quality readings are missing: build a boolean NaN mask from
# `aq_data`, drop its singleton dimensions and draw it as a matrix.
missing_mask = torch.isnan(aq_data).squeeze().numpy()

plt.figure(figsize=(10, 10))
plt.title("Missing values in PM2.5 column")
# fignum=0 makes matshow draw into the current figure instead of opening a new
# one; aspect=30 stretches the cells vertically.
plt.matshow(missing_mask, aspect=30, fignum=0)
plt.show()

In [None]:
import cartopy.crs as ccrs
from cartopy.io.img_tiles import GoogleTiles
from pathlib import Path

# Draw the AQ and MEO station locations on top of street-map tiles.
data_dir = Path("data/beijing")

# Tile source for the background map; its CRS doubles as the axes projection.
tiler = GoogleTiles(style="street")
mercator = tiler.crs

# Station metadata: the AQ file is tab-separated UTF-16, the MEO file is read
# with the pandas CSV defaults.
aq_station_file = data_dir / "Beijing_AirQuality_Stations_EN.txt"
meo_station_file = data_dir / "Beijing_MEO_Stations_cn.csv"
df_aq_station = pd.read_csv(aq_station_file, delimiter="\t", encoding="utf-16")
df_meo_station = pd.read_csv(meo_station_file)

fig = plt.figure()
ax = fig.add_subplot(1, 1, 1, projection=mercator)

# Bounding box of the MEO stations, padded by 0.1 on every side. `extent` is
# kept as (min_lon, max_lon, min_lat, max_lat) because a later cell reuses it.
meo_lon = df_meo_station["longitude"]
meo_lat = df_meo_station["latitude"]
min_lon, max_lon = meo_lon.min(), meo_lon.max()
min_lat, max_lat = meo_lat.min(), meo_lat.max()
extent = (min_lon - 0.1, max_lon + 0.1, min_lat - 0.1, max_lat + 0.1)
ax.set_extent(extent, crs=ccrs.PlateCarree())  # type: ignore

# Station coordinates are longitude/latitude pairs, hence the PlateCarree
# transform on each scatter.
station_groups = (
    ("AQ Stations", df_aq_station),
    ("MEO Stations", df_meo_station),
)
for station_label, station_frame in station_groups:
    plt.scatter(
        station_frame["longitude"],
        station_frame["latitude"],
        label=station_label,
        transform=ccrs.PlateCarree(),
    )

# Background tiles at zoom level 9.
ax.add_image(tiler, 9)
plt.legend()
plt.show()

In [None]:
# Stack the (longitude, latitude) coordinates of every AQ and MEO station.
positions_aq = df_aq_station[["longitude", "latitude"]].values
positions_met = df_meo_station[["longitude", "latitude"]].values
positions = np.concatenate([positions_aq, positions_met], axis=0)

# Augment the station set: the array is tiled `n_copies` times, so the result
# holds the original stations followed by (n_copies - 1) copies of each one
# perturbed with zero-mean Gaussian noise of standard deviation `jitter_std`.
n_original = positions.shape[0]
n_copies = 5
jitter_std = 0.2
# Seeded generator so a fresh "Restart & Run All" reproduces the same figure.
rng = np.random.default_rng(0)

positions = np.tile(positions, (n_copies, 1))
positions[n_original:] += rng.normal(
    0, jitter_std, size=positions[n_original:].shape
)

plt.figure(figsize=(10, 10))
plt.scatter(positions[:, 0], positions[:, 1], label="jittered copies")
# Drawn second so the unperturbed stations sit on top of the jittered cloud.
plt.scatter(
    positions[:n_original, 0], positions[:n_original, 1], label="original stations"
)
plt.xlabel("longitude")
plt.ylabel("latitude")
plt.legend()
plt.show()

In [None]:
from torchcps.datasets.airquality.beijing import BeijingDatasetRKHS, BeijingDataset
from torchcps.kernel.nn import sample_kernel
import torch

from torchcps.kernel.rkhs import Mixture

import torch
from torchcps.datasets.airquality.beijing import project_aq, project_meo, parse_beijing
from torchcps.kernel.rkhs import GaussianKernel

# Gaussian kernel with parameter 0.1, shared by both projections below.
kernel = GaussianKernel(0.1)

# Re-parse the raw data so this cell does not depend on earlier filtering.
aq_data, meo_data, aq_pos, meo_pos = parse_beijing()

# Keep an index along dim 1 only if `aq_data` has at least one non-NaN entry
# across dim 0 there; the same mask is applied to the MEO data.
# NOTE(review): presumably dim 0 is stations and dim 1 is time — confirm.
valid = torch.isnan(aq_data).all(dim=0).squeeze(1).logical_not()
aq_data, meo_data = aq_data[:, valid], meo_data[:, valid]

# Project both data sets onto the combined set of AQ + MEO positions.
positions = torch.cat((aq_pos, meo_pos), dim=0)
aq_weights = project_aq(aq_data, aq_pos, positions, kernel)
meo_weights = project_meo(
    meo_data,
    meo_pos,
    positions,
    kernel,
    implementation="lstsq",
)

In [None]:
resolution = 1000  # number of grid points along each axis
t = 0  # index along dim 1 of the projected weights (presumably the time step)

# Regular evaluation grid over `extent`, which the station-map cell defines as
# (min_lon, max_lon, min_lat, max_lat). indexing="ij" is the behaviour the
# call had implicitly; passing it explicitly avoids torch's deprecation warning.
# The grid is flattened to shape (resolution**2, 2).
XY = torch.stack(
    torch.meshgrid(
        torch.linspace(extent[0], extent[1], resolution),
        torch.linspace(extent[2], extent[3], resolution),
        indexing="ij",
    ),
    -1,
).reshape(-1, 2)

# Concatenate AQ and MEO weights along the last dim, evaluate the mixture
# centred at `positions` on the grid, and fold the result back into a
# (resolution, resolution, n_features) array.
data = torch.cat([aq_weights, meo_weights], dim=-1)
samples = sample_kernel(kernel, Mixture(positions, data[:, t]), XY).weights.reshape(
    resolution, resolution, -1
)

# One title per channel of `samples`.
# NOTE(review): order assumed to match cat([aq_weights, meo_weights]) — confirm.
features = [
    "PM2.5",
    "temperature",
    "pressure",
    "humidity",
    "wind_x",
    "wind_y",
]

for i, feature in enumerate(features):
    plt.figure(figsize=(7, 3.5))
    # The grid's first axis follows extent[0:2] (longitude), so transpose to
    # put it on the horizontal image axis; origin="lower" keeps the second
    # axis increasing upwards.
    plt.imshow(
        samples[..., i].T,
        extent=extent,
        origin="lower",
    )
    plt.colorbar()
    plt.title(feature)
plt.show()

In [None]:
# relationship between sigma and the Frobenius condition number of the kernel
# matrix, first for k(aq_pos, aq_pos) and then for k(positions, positions)
from torch_cluster import grid_cluster
from torchcps.kernel.rkhs import GaussianKernel
from torch_geometric.nn.pool import avg_pool_x

dataset = BeijingDataset()
meo_pos = dataset.meo_pos.contiguous().double()
aq_pos = dataset.aq_pos.contiguous().double()
positions = torch.cat([meo_pos, aq_pos], 0).contiguous().double()

# Log-spaced kernel parameters from 1e-3 to 1. The earlier expression
# np.log10(np.logspace(1e-3, 1.0, 100)) undid its own logspace and produced a
# linear grid over the same range, which is badly resolved on the log x-axis
# used below.
sigmas = np.geomspace(1e-3, 1.0, 100).tolist()


def kernel_condition_numbers(points, sigmas):
    """Return the Frobenius condition number of k(points, points) per sigma.

    Args:
        points: double-precision tensor of positions, one row per point.
        sigmas: iterable of float parameters passed to ``GaussianKernel``.

    Returns:
        A list of floats, one condition number for each entry of ``sigmas``.
    """
    # Multiplying by the identity is kept from the original code.
    # NOTE(review): presumably it turns the kernel output into a dense tensor —
    # confirm what GaussianKernel.__call__ returns.
    identity = torch.eye(points.shape[0]).double()
    numbers = []
    for sigma in sigmas:
        gram = GaussianKernel(sigma)(points, points) @ identity
        numbers.append(torch.linalg.cond(gram, p="fro").item())
    return numbers


def plot_condition_numbers(sigmas, numbers, ylabel):
    """Plot condition numbers against sigma on log-log axes."""
    plt.figure()
    plt.plot(sigmas, numbers)
    plt.xscale("log")
    plt.yscale("log")
    plt.xlabel(r"$\sigma$")
    plt.ylabel(ylabel)
    plt.show()


# Kernel matrix over the AQ station positions only.
condition_numbers_aq = kernel_condition_numbers(aq_pos, sigmas)
plot_condition_numbers(
    sigmas, condition_numbers_aq, r"$||K_{xx}||_F||K_{xx}^{-1}||_F$"
)

# Kernel matrix over all station positions (MEO followed by AQ).
condition_numbers_all = kernel_condition_numbers(positions, sigmas)
plot_condition_numbers(
    sigmas, condition_numbers_all, r"$||K_{yy}||_F||K_{yy}^{-1}||_F$"
)