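# Nearest-neighbor graph mappings

This notebook builds the nearest-neighbour graphs between the ERA5 reduced Gaussian grid and an H3 mesh. There are four edge sets: ERA→ERA, H3→H3, ERA→H3 (encoder) and H3→ERA (decoder). Each edge set gets directional edge features, and the result is saved as a PyTorch Geometric `HeteroData` object.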

In [1]:
import os

import xarray as xr
import numpy as np

In [2]:
import cartopy.crs as ccrs
import cartopy.feature as cf

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt

%matplotlib inline
%config InlineBackend.figure_format='retina'

In [3]:
from sklearn.neighbors import NearestNeighbors

In [4]:
import torch
from torch_geometric.data import HeteroData 



In [5]:
import h3

In [6]:
# Neighbours per node in the same-grid kNN graphs. The query points are the
# fitted points themselves, so every node counts itself among its neighbours.
NUM_ERA_NEIGHBORS = 9  # ERA -> ERA
NUM_H3_NEIGHBORS = 7   # H3 -> H3 (presumably self + 6 hexagon neighbours; confirm)

In [7]:
# Grid number N of the ERA data: selects the input file (sfc_o{N}) and the
# number of latitude rows per hemisphere used for the area weights.
era_res = 96

## ERA5 on the O96 grid (set by `era_res`)

In [8]:
# ERA5 surface fields for 1979 on the O{era_res} grid (per the file name).
# NOTE(review): absolute, user-specific scratch path; adjust for your environment.
era_data_dir = "/ec/res4/scratch/pamc/WeatherBench"
era = xr.load_dataset(os.path.join(era_data_dir, f"sfc_o{era_res}_1979.grib"), engine="cfgrib")
era

In [9]:
# Latitude of every grid point (one value per point of the reduced grid).
era["latitude"]

In [10]:
# Relative area weight of every point of the reduced Gaussian grid.
# In the northern hemisphere row k holds 20 + 4*k points, so the widest row holds
# 4*era_res + 16 points. All points of a row share one weight:
# cos(latitude) * (points on the widest row) / (points on this row).
# Rows are presumably ordered north pole -> equator (verify against era.latitude).
# The southern hemisphere is the mirror image of the northern one.
n_rows_per_hemisphere = era_res
n_points_widest_row = 4 * n_rows_per_hemisphere + 16
area_weights = []
row_start = 0
for row in range(n_rows_per_hemisphere):
    n_points_row = 20 + 4 * row
    row_weight = np.cos(np.deg2rad(era.latitude[row_start].data)) * n_points_widest_row / n_points_row
    area_weights += [row_weight] * n_points_row
    row_start += n_points_row
area_weights += area_weights[::-1]
print(len(area_weights), era.latitude.size)

assert len(area_weights) == era.latitude.size

40320 40320


In [11]:
# One weight per grid point; should equal era.latitude.size.
len(area_weights)

40320

In [12]:
# Longitude range of the grid, in degrees.
(era["longitude"].min(), era["longitude"].max())

(<xarray.DataArray 'longitude' ()>
 array(0.)
 Coordinates:
     number   int64 0
     step     timedelta64[ns] 00:00:00
     surface  float64 0.0,
 <xarray.DataArray 'longitude' ()>
 array(359.1)
 Coordinates:
     number   int64 0
     step     timedelta64[ns] 00:00:00
     surface  float64 0.0)

In [13]:
# Latitude range of the grid, in degrees.
(era["latitude"].min(), era["latitude"].max())

(<xarray.DataArray 'latitude' ()>
 array(-89.28422753)
 Coordinates:
     number   int64 0
     step     timedelta64[ns] 00:00:00
     surface  float64 0.0,
 <xarray.DataArray 'latitude' ()>
 array(89.28422753)
 Coordinates:
     number   int64 0
     step     timedelta64[ns] 00:00:00
     surface  float64 0.0)

In [14]:
# (lat, lon) per ERA grid point, converted to radians because the haversine
# metric used by NearestNeighbors below expects [lat, lon] in radians.
elat = np.array(era["latitude"])
elon = np.array(era["longitude"])
ecoords = np.stack([elat, elon], axis=-1).reshape((-1, 2))
ecoords_sp = np.deg2rad(ecoords)
print(f"ecoords_sp.shape = {ecoords_sp.shape}")

dcoords_sp.shape = (40320, 2)


In [15]:
# Haversine kNN index over the ERA points. It is reused later to find the ERA
# neighbours of each H3 node (encoder mapping).
eneigh = NearestNeighbors(
    n_neighbors=NUM_ERA_NEIGHBORS,
    metric="haversine",
    n_jobs=4,
).fit(ecoords_sp)

# ERA -> ERA graph: the queries are the fitted points themselves, so each
# point's own index appears among its neighbours at distance 0.
eadjmat = eneigh.kneighbors_graph(
    ecoords_sp,
    n_neighbors=NUM_ERA_NEIGHBORS,
    mode="distance",
).tocoo()
print(f"eadjmat.shape = {eadjmat.shape}")

eadjmat.shape = (40320, 40320)


In [16]:
# Sanity check: the kNN graph stores exactly NUM_ERA_NEIGHBORS entries
# (explicit zero self-distances included) per grid point.
assert eadjmat.nnz == ecoords_sp.shape[0] * NUM_ERA_NEIGHBORS
eadjmat

<40320x40320 sparse matrix of type '<class 'numpy.float64'>'
	with 362880 stored elements in COOrdinate format>

In [17]:
from sklearn.preprocessing import normalize

# Distances -> edge weights: divide every row by its total distance (L1 row
# normalisation), then take the complement so nearer neighbours weigh more.
# The result is a CSR matrix, so its entry order need not match eadjmat's.
eadjmat_norm = normalize(eadjmat, axis=1, norm="l1")
eadjmat_norm.data = np.subtract(1.0, eadjmat_norm.data)

In [18]:
era2era_key = ("era", "to", "era")

# `normalize` returned a CSR matrix whose column indices are sorted within each
# row, whereas `eadjmat` keeps kneighbors_graph's distance order. Take edge_index
# and edge_attr from the same (normalised) matrix so every weight stays attached
# to the edge it was computed for.
eadjmat_norm_coo = eadjmat_norm.tocoo()

era2era_gdata = {
    # Each matrix row is a query (target) point and its columns are that point's
    # neighbours (sources), so edges are stored as (col, row). A kNN graph is
    # not symmetric in general, so this order matters.
    "edge_index": torch.from_numpy(np.stack([eadjmat_norm_coo.col, eadjmat_norm_coo.row], axis=0).astype(np.int64)),
    "edge_attr": torch.from_numpy(np.expand_dims(eadjmat_norm_coo.data, axis=-1).astype(np.float32)),
    "ecoords_rad": torch.from_numpy(ecoords_sp.astype(np.float32)),
    "info": f"o{era_res}_to_o{era_res} graph",
    'area_weights': torch.from_numpy(np.array(area_weights)),
}

In [19]:
# Build the H3 mesh:
#   hcoords / hcoords_sp -- (lat, lon) of every mesh node, in degrees / radians
#   hadjmat              -- sparse (COO) mesh adjacency
#   hneigh               -- haversine kNN index over the mesh nodes, reused later
#                           for the ERA <-> H3 mappings
resolution = 2  # H3 resolution (the finest level when luse_multi_mesh is True)

luse_multi_mesh = False  # True: nested multi-level mesh from aifs; False: single H3 level

if luse_multi_mesh:

    from aifs.utils.graph_gen import multi_mesh
    import networkx as nx

    h3_resolutions = tuple(range(resolution + 1))  # all H3 levels 0..resolution
    # `resolution` becomes a label such as "0_1_2"; it ends up in the output file name.
    resolution = "_".join(map(str, h3_resolutions))

    H3 = multi_mesh(h3_resolutions, self_loop=False, flat=True, neighbour_children=False, depth=None)

    print(H3.number_of_nodes(), H3.number_of_edges())

    print(list(H3.nodes())[0:5])
    print(list(H3.edges())[0:5])
    print(H3.nodes[list(H3.nodes())[0]])
    print(H3.edges[list(H3.edges())[0]])

    # Same node order networkx uses by default for the rows/cols of `hadjmat` below.
    h3_grid = list(H3.nodes)
    hcoords = np.array([h3.h3_to_geo(val) for val in h3_grid])
    hcoords_sp = np.deg2rad(hcoords)

    hneigh = NearestNeighbors(  # this is used later for the era -> h and h -> era mapper
        n_neighbors=NUM_H3_NEIGHBORS,
        metric="haversine",
        n_jobs=4
    )
    hneigh.fit(hcoords_sp)

    # NOTE(review): the stored values here are networkx edge weights, not the
    # haversine distances used in the kNN branch; confirm that the later
    # "1 - normalised distance" weighting makes sense for them.
    hadjmat = nx.to_scipy_sparse_array(H3, format='coo')

else:

    # One mesh node per H3 cell (at `resolution`) containing at least one ERA
    # point, in sorted cell-id order.
    h3_grid = sorted({h3.geo_to_h3(lat, lon, resolution) for lat, lon in ecoords})
    hcoords = np.array([h3.h3_to_geo(val) for val in h3_grid])
    hcoords_sp = np.deg2rad(hcoords)

    hneigh = NearestNeighbors(
        n_neighbors=NUM_H3_NEIGHBORS,
        metric="haversine",
        n_jobs=4
    )
    hneigh.fit(hcoords_sp)

    # Mesh edges: each node's NUM_H3_NEIGHBORS nearest nodes (itself included),
    # with haversine distances as values.
    hadjmat = hneigh.kneighbors_graph(hcoords_sp, NUM_H3_NEIGHBORS, mode="distance").tocoo()

print("-------------------")
hcoords.shape, hadjmat

-------------------


((5882, 2),
 <5882x5882 sparse matrix of type '<class 'numpy.float64'>'
 	with 41174 stored elements in COOrdinate format>)

In [20]:
# Mesh distances -> edge weights: L1-normalise each row, then take the
# complement so nearer nodes weigh more. The result is CSR.
hadjmat_norm = normalize(hadjmat, axis=1, norm="l1")
hadjmat_norm.data = np.subtract(1.0, hadjmat_norm.data)
hadjmat_norm

<5882x5882 sparse matrix of type '<class 'numpy.float64'>'
	with 41174 stored elements in Compressed Sparse Row format>

In [21]:
h2h_key = ("h", "to", "h")

# `normalize` returned a CSR matrix with column indices sorted within each row,
# which need not match `hadjmat`'s entry order. Take edge_index and edge_attr
# from the same (normalised) matrix so each weight stays on its own edge.
hadjmat_norm_coo = hadjmat_norm.tocoo()

h2h_gdata = {
    # Each matrix row is a target node and its columns are that node's
    # neighbours (sources), so edges are stored as (col, row). The adjacency is
    # not guaranteed to be symmetric, so this order matters.
    "edge_index": torch.from_numpy(np.stack([hadjmat_norm_coo.col, hadjmat_norm_coo.row], axis=0).astype(np.int64)),
    "edge_attr": torch.from_numpy(np.expand_dims(hadjmat_norm_coo.data, axis=-1).astype(np.float32)),
    "hcoords_rad": torch.from_numpy(hcoords_sp.astype(np.float32)),
    "info": "h3_to_h3 graph",
}

In [22]:
# Neighbours per target node in the cross-grid mappings. The query points come
# from one grid and the fitted points from the other (ERA points vs H3 cell
# centres), so unlike the same-grid graphs there is generally no zero-distance
# "self" neighbour.
NUM_H3_TO_ERA_NEIGHBORS = 3   # H3 nodes feeding each ERA point (decoder)
NUM_ERA_TO_H3_NEIGHBORS = 12  # ERA points feeding each H3 node (encoder)

In [23]:
# Cross-grid kNN mappings: each query point gets one row, and its nearest points
# from the other grid (the fitted index) become the columns.

# Encoder, ERA -> H3: for every H3 node, its nearest ERA points.
era_to_h3_adjmat = eneigh.kneighbors_graph(
    hcoords_sp, n_neighbors=NUM_ERA_TO_H3_NEIGHBORS, mode="distance"
).tocoo()

# Decoder, H3 -> ERA: for every ERA point, its nearest H3 nodes.
h3_to_era_adjmat = hneigh.kneighbors_graph(
    ecoords_sp, n_neighbors=NUM_H3_TO_ERA_NEIGHBORS, mode="distance"
).tocoo()

In [24]:
# Sanity checks: rows are query (target) points, columns candidate neighbours,
# and every row stores exactly its fixed number of neighbours.
assert h3_to_era_adjmat.shape == (ecoords_sp.shape[0], hcoords_sp.shape[0])
assert era_to_h3_adjmat.shape == (hcoords_sp.shape[0], ecoords_sp.shape[0])
assert h3_to_era_adjmat.nnz == ecoords_sp.shape[0] * NUM_H3_TO_ERA_NEIGHBORS
assert era_to_h3_adjmat.nnz == hcoords_sp.shape[0] * NUM_ERA_TO_H3_NEIGHBORS
h3_to_era_adjmat, era_to_h3_adjmat

(<40320x5882 sparse matrix of type '<class 'numpy.float64'>'
 	with 120960 stored elements in COOrdinate format>,
 <5882x40320 sparse matrix of type '<class 'numpy.float64'>'
 	with 70584 stored elements in COOrdinate format>)

In [25]:
def _complement_l1_rows(adjmat):
    """Return 1 - (row-wise L1-normalised distances) as a CSR matrix."""
    weights = normalize(adjmat, axis=1, norm="l1")
    weights.data = 1.0 - weights.data
    return weights


h3_to_era_adjmat_norm = _complement_l1_rows(h3_to_era_adjmat)
era_to_h3_adjmat_norm = _complement_l1_rows(era_to_h3_adjmat)

In [26]:
# Normalisation must keep every stored entry (explicit zeros included);
# otherwise the weights could no longer be matched to edges.
assert h3_to_era_adjmat_norm.nnz == h3_to_era_adjmat.nnz
assert era_to_h3_adjmat_norm.nnz == era_to_h3_adjmat.nnz
h3_to_era_adjmat_norm, era_to_h3_adjmat_norm

(<40320x5882 sparse matrix of type '<class 'numpy.float64'>'
 	with 120960 stored elements in Compressed Sparse Row format>,
 <5882x40320 sparse matrix of type '<class 'numpy.float64'>'
 	with 70584 stored elements in Compressed Sparse Row format>)

In [27]:
h2e_key = ("h", "to", "era")

# `normalize` returned a CSR matrix with column indices sorted within each row,
# whereas `h3_to_era_adjmat` keeps kneighbors_graph's distance order. Take
# edge_index and edge_attr from the same (normalised) matrix so each weight
# stays attached to its own edge.
h3_to_era_norm_coo = h3_to_era_adjmat_norm.tocoo()

h2e_gdata = {
    # Matrix rows are ERA points (targets) and columns H3 nodes (sources); the
    # matrix is rectangular, so edges are stored as (col, row) = (H3, ERA).
    "edge_index": torch.from_numpy(np.stack([h3_to_era_norm_coo.col, h3_to_era_norm_coo.row], axis=0).astype(np.int64)),
    "edge_attr": torch.from_numpy(np.expand_dims(h3_to_era_norm_coo.data, axis=-1).astype(np.float32)),
    "hcoords_rad": torch.from_numpy(hcoords_sp.astype(np.float32)),
    "ecoords_rad": torch.from_numpy(ecoords_sp.astype(np.float32)),
    "info": "h3_to_era graph",
}

In [28]:
e2h_key = ("era", "to", "h")

# `normalize` returned a CSR matrix with column indices sorted within each row,
# whereas `era_to_h3_adjmat` keeps kneighbors_graph's distance order. Take
# edge_index and edge_attr from the same (normalised) matrix so each weight
# stays attached to its own edge.
era_to_h3_norm_coo = era_to_h3_adjmat_norm.tocoo()

e2h_gdata = {
    # Matrix rows are H3 nodes (targets) and columns ERA points (sources); the
    # matrix is rectangular, so edges are stored as (col, row) = (ERA, H3).
    "edge_index": torch.from_numpy(np.stack([era_to_h3_norm_coo.col, era_to_h3_norm_coo.row], axis=0).astype(np.int64)),
    "edge_attr": torch.from_numpy(np.expand_dims(era_to_h3_norm_coo.data, axis=-1).astype(np.float32)),
    "hcoords_rad": torch.from_numpy(hcoords_sp.astype(np.float32)),
    "ecoords_rad": torch.from_numpy(ecoords_sp.astype(np.float32)),
    "info": "era_to_h3 graph",
}

In [29]:
# All four edge sets, keyed by (source, "to", target), bundled into a single
# heterogeneous graph.
graph_stores = {
    era2era_key: era2era_gdata,
    h2h_key: h2h_gdata,
    e2h_key: e2h_gdata,
    h2e_key: h2e_gdata,
}
graphs_normed = HeteroData(graph_stores)

## Add directional edge features

In [30]:
from aifs.utils.graph_gen import directional_edge_features, directional_edge_features_rotated

luse_rotated_edge_features = True

# rotated: direction relative to the target node rotated to the north pole
# plain:   target-node location minus source-node location
edge_directions_func = (
    directional_edge_features_rotated
    if luse_rotated_edge_features
    else directional_edge_features
)


In [31]:
# Directional features for H3 -> H3 edges. edge_index row 0 holds sources and
# row 1 targets; both index the H3 node coordinates. The edge store and its
# tensors are looked up once instead of three times per edge.
h2h_store = graphs_normed[("h", "to", "h")]
h2h_coords = h2h_store['hcoords_rad']
hhedge_dirs = [
    edge_directions_func(h2h_coords[i, :], h2h_coords[j, :])
    for i, j in h2h_store['edge_index'].t()
]
hhedge_dirs = torch.from_numpy(np.stack(hhedge_dirs).astype(np.float32))
hhedge_attr = torch.concat([h2h_store['edge_attr'], hhedge_dirs], axis=-1)

In [32]:
# Directional features for ERA -> H3 edges. edge_index row 0 indexes ERA points
# (sources) and row 1 H3 nodes (targets). Store lookups are hoisted out of the
# per-edge loop.
e2h_store = graphs_normed[("era", "to", "h")]
e2h_ecoords = e2h_store['ecoords_rad']
e2h_hcoords = e2h_store['hcoords_rad']
ehedge_dirs = [
    edge_directions_func(e2h_ecoords[i, :], e2h_hcoords[j, :])
    for i, j in e2h_store['edge_index'].t()
]
ehedge_dirs = torch.from_numpy(np.stack(ehedge_dirs).astype(np.float32))
ehedge_attr = torch.concat([e2h_store['edge_attr'], ehedge_dirs], axis=-1)

In [33]:
# Directional features for H3 -> ERA edges. edge_index row 0 indexes H3 nodes
# (sources) and row 1 ERA points (targets). Store lookups are hoisted out of the
# per-edge loop.
h2e_store = graphs_normed[("h", "to", "era")]
h2e_hcoords = h2e_store['hcoords_rad']
h2e_ecoords = h2e_store['ecoords_rad']
heedge_dirs = [
    edge_directions_func(h2e_hcoords[i, :], h2e_ecoords[j, :])
    for i, j in h2e_store['edge_index'].t()
]
heedge_dirs = torch.from_numpy(np.stack(heedge_dirs).astype(np.float32))
heedge_attr = torch.concat([h2e_store['edge_attr'], heedge_dirs], axis=-1)

In [34]:
# Directional features for ERA -> ERA edges. Both rows of edge_index index the
# ERA point coordinates. Store lookups are hoisted out of the per-edge loop.
e2e_store = graphs_normed[("era", "to", "era")]
e2e_coords = e2e_store['ecoords_rad']
eeedge_dirs = [
    edge_directions_func(e2e_coords[i, :], e2e_coords[j, :])
    for i, j in e2e_store['edge_index'].t()
]
eeedge_dirs = torch.from_numpy(np.stack(eeedge_dirs).astype(np.float32))
eeedge_attr = torch.concat([e2e_store['edge_attr'], eeedge_dirs], axis=-1)

In [35]:
# Replace each edge set's scalar weight with [weight, direction features].
for edge_key, new_edge_attr in (
    (("h", "to", "era"), heedge_attr),
    (("h", "to", "h"), hhedge_attr),
    (("era", "to", "h"), ehedge_attr),
    (("era", "to", "era"), eeedge_attr),
):
    graphs_normed[edge_key]['edge_attr'] = new_edge_attr

In [36]:
# One attribute row per ERA -> H3 edge: [weight, direction features...].
e2h_edge_attr = graphs_normed[("era", "to", "h")]['edge_attr']
assert e2h_edge_attr.shape[0] == graphs_normed[("era", "to", "h")]['edge_index'].shape[1]
e2h_edge_attr.shape

torch.Size([70584, 3])

In [37]:
# NOTE(review): absolute, user-specific output path; adjust for your environment.
output_dir = "/ec/res4/hpcperm/pamc/gnn/"
# `resolution` is either a single H3 level or an underscore-joined list of levels.
graph_file = os.path.join(output_dir, f"graph_mappings_normed_edge_attrs_o{era_res}_h3_{resolution}.pt")
# torch.save does not create missing directories.
os.makedirs(output_dir, exist_ok=True)
torch.save(graphs_normed, graph_file)

In [38]:
!ls -lt $HPCPERM/gnn

total 642408
-rw-r----- 1 nesl rd 18389496 Apr 19 21:35 graph_mappings_normed_edge_attrs_new2_o96_h3_0_1_2.pt
-rw-r----- 1 nesl rd 17947640 Apr 19 17:29 graph_mappings_normed_edge_attrs_new_o96_h3_2.pt.bak_worked
-rw-r----- 1 nesl rd 18389496 Apr 19 17:23 graph_mappings_normed_edge_attrs_new1_o96_h3_0_1_2.pt
-rw-r----- 1 nesl rd 35012792 Apr 19 15:29 graph_mappings_normed_edge_attrs_new1_o96_h3_3.pt
-rw-r----- 1 nesl rd 20548728 Apr 19 09:02 graph_mappings_normed_edge_attrs_new1_o96_h3_0.pt
-rw-r----- 1 nesl rd 10809220 Apr 18 14:22 graph_mappings_normed_edge_attrs_new1_o96_h3_2.pt
-rw-r----- 1 nesl rd 18389496 Apr 16 10:51 graph_mappings_normed_edge_attrs_new_o96_h3_0_1_2.pt
-rw-r----- 1 nesl rd 17947640 Apr 16 10:49 graph_mappings_normed_edge_attrs_new_o96_h3_2.pt
-rw-r----- 1 nesl rd 35012792 Apr 15 15:00 graph_mappings_normed_edge_attrs_o96_h3_3.pt
-rw-r----- 1 nesl rd 39514555 Apr  4 11:48 graph_mappings_multi_3_noslfl_no_neigh_child.pt
-rw-r----- 1 nesl rd 39514555 Apr  4 11:48 g