# Nearest-neighbor graph mappings

In [1]:
import os

import xarray as xr
import numpy as np

In [2]:
import cartopy.crs as ccrs
import cartopy.feature as cf

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt

%matplotlib inline
%config InlineBackend.figure_format='retina'

In [3]:
from sklearn.neighbors import NearestNeighbors

In [4]:
import torch
from torch_geometric.data import HeteroData 

  from .autonotebook import tqdm as notebook_tqdm
  from neptune.version import version as neptune_client_version
  from neptune import new as neptune


In [5]:
import h3

In [6]:
NUM_ERA_NEIGHBORS = 9
NUM_H3_NEIGHBORS = 7

## ERA5 -> O32

In [7]:
era = xr.load_dataset("/ec/res4/scratch/pamc/WeatherBench/o160_t2m.grib", engine="cfgrib")
era

In [8]:
area_weights = []
nind = 0
nlon = 20
tlat = 160
mlon = 4*tlat+16
for i in range(tlat):
    area = np.cos(np.deg2rad(era.latitude[nind].data))*mlon/nlon
    area_weights.extend([area]*nlon)
    #print(era.latitude[nind].data,era.longitude[nind].data,area)
    nind+=nlon
    nlon+=4
area_weights.extend(area_weights[::-1])
print(len(area_weights),era.latitude.size)

assert(len(area_weights)==era.latitude.size)

108160 108160


In [9]:
len(area_weights)

108160

In [10]:
era.longitude.min(), era.longitude.max()

(<xarray.DataArray 'longitude' ()>
 array(0.)
 Coordinates:
     number      int64 0
     time        datetime64[ns] 2021-03-01
     step        timedelta64[ns] 00:00:00
     surface     float64 0.0
     valid_time  datetime64[ns] 2021-03-01,
 <xarray.DataArray 'longitude' ()>
 array(359.45121951)
 Coordinates:
     number      int64 0
     time        datetime64[ns] 2021-03-01
     step        timedelta64[ns] 00:00:00
     surface     float64 0.0
     valid_time  datetime64[ns] 2021-03-01)

In [11]:
era.latitude.min(), era.latitude.max()

(<xarray.DataArray 'latitude' ()>
 array(-89.57008955)
 Coordinates:
     number      int64 0
     time        datetime64[ns] 2021-03-01
     step        timedelta64[ns] 00:00:00
     surface     float64 0.0
     valid_time  datetime64[ns] 2021-03-01,
 <xarray.DataArray 'latitude' ()>
 array(89.57008955)
 Coordinates:
     number      int64 0
     time        datetime64[ns] 2021-03-01
     step        timedelta64[ns] 00:00:00
     surface     float64 0.0
     valid_time  datetime64[ns] 2021-03-01)

In [12]:
elat = np.array(era["latitude"])
elon = np.array(era["longitude"])
ecoords = np.stack([elat, elon], axis=-1).reshape((-1, 2))
ecoords_sp = np.deg2rad(ecoords)
print(f"dcoords_sp.shape = {ecoords_sp.shape}")

dcoords_sp.shape = (108160, 2)


In [13]:
eneigh = NearestNeighbors(n_neighbors=NUM_ERA_NEIGHBORS, metric="haversine", n_jobs=4)
eneigh.fit(ecoords_sp)

eadjmat = eneigh.kneighbors_graph(ecoords_sp, NUM_ERA_NEIGHBORS, mode="distance").tocoo()
print(f"eadjmat.shape = {eadjmat.shape}")

eadjmat.shape = (108160, 108160)


In [14]:
eadjmat

<108160x108160 sparse matrix of type '<class 'numpy.float64'>'
	with 973440 stored elements in COOrdinate format>

In [15]:
from sklearn.preprocessing import normalize
eadjmat_norm = normalize(eadjmat, norm="l1", axis=1)
eadjmat_norm.data = 1.0 - eadjmat_norm.data

In [127]:
era2era_key = ("era", "to", "era")

era2era_gdata = {
    # we should swap rows and cols here. It does not matter too much since the
    # adjacency matrix is symmetric but better be consistent
    "edge_index": torch.from_numpy(np.stack([eadjmat.col, eadjmat.row], axis=0).astype(np.int64)),
    "edge_attr": torch.from_numpy(np.expand_dims(eadjmat_norm.data, axis=-1).astype(np.float32)),
    "ecoords_rad": torch.from_numpy(ecoords_sp.astype(np.float32)),
    "info": "o160_to_o160 graph",
    'area_weights':torch.from_numpy(np.array(area_weights)),
}

In [151]:
resolution = 2
h3_grid = [h3.geo_to_h3(lat, lon, resolution) for lat, lon in ecoords]
h3_grid = sorted(set(h3_grid))
hcoords = np.array([h3.h3_to_geo(val) for val in h3_grid])
hcoords_sp = np.deg2rad(hcoords)
hcoords.shape

(5882, 2)

In [152]:
hneigh = NearestNeighbors(
    n_neighbors=NUM_H3_NEIGHBORS,
    metric="haversine",
    n_jobs=4
)
hneigh.fit(hcoords_sp)

hadjmat = hneigh.kneighbors_graph(hcoords_sp, NUM_H3_NEIGHBORS, mode="distance").tocoo()
hadjmat

<5882x5882 sparse matrix of type '<class 'numpy.float64'>'
	with 41174 stored elements in COOrdinate format>

In [153]:
hadjmat_norm = normalize(hadjmat, norm="l1", axis=1)
hadjmat_norm.data = 1.0 - hadjmat_norm.data
hadjmat_norm

<5882x5882 sparse matrix of type '<class 'numpy.float64'>'
	with 41174 stored elements in Compressed Sparse Row format>

In [154]:
h2h_key = ("h", "to", "h")

h2h_gdata = {
    # we should swap rows and cols here. It does not matter too much since the
    # adjacency matrix is symmetric but better be consistent
    "edge_index": torch.from_numpy(np.stack([hadjmat.col, hadjmat.row], axis=0).astype(np.int64)),
    "edge_attr": torch.from_numpy(np.expand_dims(hadjmat_norm.data, axis=-1).astype(np.float32)),
    "hcoords_rad": torch.from_numpy(hcoords_sp.astype(np.float32)),
    "info": "h3_to_h3 graph",
}

In [155]:
# includes "self"
NUM_H3_TO_ERA_NEIGHBORS = 3
NUM_ERA_TO_H3_NEIGHBORS = 12

In [156]:
# compute mappings
# H3 -> ERA aka the "decoder"
h3_to_era_adjmat = hneigh.kneighbors_graph(
    ecoords_sp,
    n_neighbors=NUM_H3_TO_ERA_NEIGHBORS,
    mode="distance",
).tocoo()

# ERA -> H3 aka the "encoder"
era_to_h3_adjmat = eneigh.kneighbors_graph(
    hcoords_sp,
    n_neighbors=NUM_ERA_TO_H3_NEIGHBORS,
    mode="distance",
).tocoo()

In [157]:
h3_to_era_adjmat, era_to_h3_adjmat

(<108160x5882 sparse matrix of type '<class 'numpy.float64'>'
 	with 324480 stored elements in COOrdinate format>,
 <5882x108160 sparse matrix of type '<class 'numpy.float64'>'
 	with 70584 stored elements in COOrdinate format>)

In [158]:
h3_to_era_adjmat_norm = normalize(h3_to_era_adjmat, norm="l1", axis=1)
h3_to_era_adjmat_norm.data = 1.0 - h3_to_era_adjmat_norm.data

era_to_h3_adjmat_norm = normalize(era_to_h3_adjmat, norm="l1", axis=1)
era_to_h3_adjmat_norm.data = 1.0 - era_to_h3_adjmat_norm.data

In [159]:
h3_to_era_adjmat_norm, era_to_h3_adjmat_norm

(<108160x5882 sparse matrix of type '<class 'numpy.float64'>'
 	with 324480 stored elements in Compressed Sparse Row format>,
 <5882x108160 sparse matrix of type '<class 'numpy.float64'>'
 	with 70584 stored elements in Compressed Sparse Row format>)

In [160]:
h2e_key = ("h", "to", "era")

h2e_gdata = {
    # we should swap rows and cols here. It does not matter too much since the
    # adjacency matrix is symmetric but better be consistent
    "edge_index": torch.from_numpy(np.stack([h3_to_era_adjmat.col, h3_to_era_adjmat.row], axis=0).astype(np.int64)),
    "edge_attr": torch.from_numpy(np.expand_dims(h3_to_era_adjmat_norm.data, axis=-1).astype(np.float32)),
    "hcoords_rad": torch.from_numpy(hcoords_sp.astype(np.float32)),
    "ecoords_rad": torch.from_numpy(ecoords_sp.astype(np.float32)),
    "info": "h3_to_era graph",
}

In [161]:
e2h_key = ("era", "to", "h")

e2h_gdata = {
    # we should swap rows and cols here. It does not matter too much since the
    # adjacency matrix is symmetric but better be consistent
    "edge_index": torch.from_numpy(np.stack([era_to_h3_adjmat.col, era_to_h3_adjmat.row], axis=0).astype(np.int64)),
    "edge_attr": torch.from_numpy(np.expand_dims(era_to_h3_adjmat_norm.data, axis=-1).astype(np.float32)),
    "hcoords_rad": torch.from_numpy(hcoords_sp.astype(np.float32)),
    "ecoords_rad": torch.from_numpy(ecoords_sp.astype(np.float32)),
    "info": "era_to_h3 graph",
}

In [162]:
graphs_normed = HeteroData(
    {
       era2era_key : era2era_gdata,
       h2h_key: h2h_gdata,
       e2h_key : e2h_gdata,
       h2e_key : h2e_gdata,
    }
)

# Add directionality attribute

In [163]:
hhedge_dirs = []
for n in range(graphs_normed[("h", "to", "h")]['edge_index'].shape[1]):
    i,j = graphs_normed[("h", "to", "h")]['edge_index'][:,n]
    ic = graphs_normed[("h", "to", "h")]['hcoords_rad'][i,:]
    jc = graphs_normed[("h", "to", "h")]['hcoords_rad'][j,:]
    hhedge_dirs.append(jc - ic)
hhedge_dirs = torch.from_numpy(np.stack(hhedge_dirs).astype(np.float32))
hhedge_attr = torch.concat([graphs_normed[("h", "to", "h")]['edge_attr'],hhedge_dirs],axis=-1)

In [164]:
ehedge_dirs = []
for n in range(graphs_normed[("era", "to", "h")]['edge_index'].shape[1]):
    i,j = graphs_normed[("era", "to", "h")]['edge_index'][:,n]
    ic = graphs_normed[("era", "to", "h")]['ecoords_rad'][i,:]
    jc = graphs_normed[("era", "to", "h")]['hcoords_rad'][j,:]
    ehedge_dirs.append(jc - ic)
ehedge_dirs = torch.from_numpy(np.stack(ehedge_dirs).astype(np.float32))
ehedge_attr = torch.concat([graphs_normed[("era", "to", "h")]['edge_attr'],ehedge_dirs],axis=-1)

In [165]:
heedge_dirs = []
for n in range(graphs_normed[("h", "to", "era")]['edge_index'].shape[1]):
    i,j = graphs_normed[("h", "to", "era")]['edge_index'][:,n]
    ic = graphs_normed[("h", "to", "era")]['hcoords_rad'][i,:]
    jc = graphs_normed[("h", "to", "era")]['ecoords_rad'][j,:]
    heedge_dirs.append(jc - ic)
heedge_dirs = torch.from_numpy(np.stack(heedge_dirs).astype(np.float32))
heedge_attr = torch.concat([graphs_normed[("h", "to", "era")]['edge_attr'],heedge_dirs],axis=-1)

In [166]:
eeedge_dirs = []
for n in range(graphs_normed[("era", "to", "era")]['edge_index'].shape[1]):
    i,j = graphs_normed[("era", "to", "era")]['edge_index'][:,n]
    ic = graphs_normed[("era", "to", "era")]['ecoords_rad'][i,:]
    jc = graphs_normed[("era", "to", "era")]['ecoords_rad'][j,:]
    eeedge_dirs.append(jc - ic)
eeedge_dirs = torch.from_numpy(np.stack(eeedge_dirs).astype(np.float32))
eeedge_attr = torch.concat([graphs_normed[("era", "to", "era")]['edge_attr'],eeedge_dirs],axis=-1)

In [167]:
graphs_normed[("h", "to", "era")]['edge_attr'] = heedge_attr
graphs_normed[("h", "to", "h")]['edge_attr'] = hhedge_attr
graphs_normed[("era", "to", "h")]['edge_attr'] = ehedge_attr
graphs_normed[("era", "to", "era")]['edge_attr'] = eeedge_attr

In [168]:
graphs_normed[("era", "to", "h")]['edge_attr'].shape

torch.Size([70584, 3])

In [169]:
output_dir = "/ec/res4/hpcperm/pamc/gnn/"
torch.save(graphs_normed, os.path.join(output_dir, f"graph_mappings_normed_edge_attrs_o160_h3_{resolution}.pt"))

In [170]:
!ls -lt $HPCPERM/gnn

total 509520
-rw-r--r-- 1 pamc rd 43077880 Mar 26 10:10 graph_mappings_normed_edge_attrs_o160_h3_2.pt
-rw-r--r-- 1 pamc rd 62693560 Mar 26 09:52 graph_mappings_normed_edge_attrs_o160_h3_3.pt
-rw-r--r-- 1 pamc rd 34424635 Mar 24 22:35 graph_mappings_o160_h3_2_normed_edge_attrs.pt
-rw-r--r-- 1 pamc rd 34424635 Mar 24 22:35 graph_mappings_o160_h3_2.pt
-rw-r--r-- 1 pamc rd 54040379 Mar 24 22:26 graph_mappings_o160_h3_3_normed_edge_attrs.pt
-rw-r--r-- 1 pamc rd 54040379 Mar 24 22:26 graph_mappings_o160_h3_3.pt
-rw-r--r-- 1 pamc rd  6649659 Mar 20 21:08 graph_mappings_2d_3_usecutoff.pt
-rw-r--r-- 1 pamc rd  6649659 Mar 20 21:08 graph_mappings_2d_normed_edge_attrs_3_usecutoff.pt
-rw-r--r-- 1 pamc rd 37838843 Mar 16 11:22 graph_mappings_normed_edge_attrs_3_usecutoff.pt
-rw-r--r-- 1 pamc rd 15975099 Mar 16 11:16 graph_mappings_2_usecutoff.pt
-rw-r--r-- 1 pamc rd 15975099 Mar 16 11:16 graph_mappings_normed_edge_attrs_2_usecutoff.pt
-rw-r--r-- 1 pamc rd 15975099 Mar 16 11:10 graph_mappings_normed