In [None]:
import os

import h3
import xarray as xr
import numpy as np
from sklearn.neighbors import NearestNeighbors
import torch
from torch_geometric.data import HeteroData 
from sklearn.preprocessing import normalize

In [None]:
import cartopy.crs as ccrs
import cartopy.feature as cf

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt

%matplotlib inline
%config InlineBackend.figure_format='retina'

In [None]:
# Neighbours per node in the ERA -> ERA k-NN self-mapping.
NUM_ERA_NEIGHBORS = 9

# Global H3 cell counts per resolution (https://h3geo.org/docs/core-library/restable/):
#   res 0:    122 cells
#   res 1:    842 cells
#   res 2:   5882 cells
#   res 3:  41162 cells
# Neighbours per node in each H3 self-mapping (h3<res> -> h3<res>), finest to coarsest.
NUM_H33_NEIGHBORS, NUM_H32_NEIGHBORS, NUM_H31_NEIGHBORS, NUM_H30_NEIGHBORS = 7, 7, 7, 5

In [None]:
era_res = 96
# Open the sfc_o<era_res>_19880401.grib sample. The directory name now follows era_res as well
# (it was hard-coded to "o96"), so changing the resolution no longer mixes directory and file.
era = xr.open_dataset(
    f"/ec/res4/hpcperm/syma/gnn/o{era_res}/sfc_o{era_res}_19880401.grib",
    engine="cfgrib",
)
era

In [None]:
from torch_geometric.utils import contains_isolated_nodes, from_scipy_sparse_matrix

def create_self_mapping(coords_sp, num_neighbors, src_label, dest_label, area_weights=None):
    """Build a k-nearest-neighbour graph connecting a node set to itself.

    Each node is joined to its ``num_neighbors`` nearest nodes (haversine metric). The edge
    attribute is ``1 - d / sum(d)``, where ``d`` is the edge length and the sum runs over all
    edges of the same query (destination) node.

    Parameters
    ----------
    coords_sp : np.ndarray
        Node coordinates as (lat, lon) rows in radians, the layout the haversine metric expects.
    num_neighbors : int
        Neighbours per node as returned by ``kneighbors`` (this may include the node itself).
    src_label, dest_label : str
        Node-type names used for the edge-type key ``(src_label, "to", dest_label)``.
    area_weights : sequence of float, optional
        Per-node weights stored under ``"area_weights"`` when given.

    Returns
    -------
    map_key : tuple
        ``(src_label, "to", dest_label)``.
    map_gdata : dict
        ``edge_index`` (2, E) int64 with row 0 = neighbour (source) and row 1 = query node
        (destination); ``edge_attr`` (E, 1) float32; source/destination coordinates (identical
        here); an ``info`` string; optionally ``area_weights``.
    neigh : NearestNeighbors
        The fitted neighbour index, reused later for cross-mappings out of this node set.

    Raises
    ------
    AssertionError
        If the resulting graph contains isolated nodes.
    """
    neigh = NearestNeighbors(n_neighbors=num_neighbors, metric="haversine", n_jobs=4)
    neigh.fit(coords_sp)
    # One CSR row per query node; within a row, neighbours are stored in distance order.
    adjmat = neigh.kneighbors_graph(coords_sp, num_neighbors, mode="distance").tocsr()
    print(f"adjmat.shape = {adjmat.shape}")

    # Row-normalise, then take BOTH edge indices and edge weights from the same matrix.
    # When normalize() receives a COO matrix, it converts it to canonical CSR and re-sorts each
    # row's column indices. Pairing the row/col of the un-normalised COO with the data of the
    # normalised matrix therefore attached weights to the wrong edges.
    adjmat_norm = normalize(adjmat, norm="l1", axis=1).tocoo()
    adjmat_norm.data = 1.0 - adjmat_norm.data

    # Pass num_nodes explicitly so that an isolated node with the largest index is not missed.
    has_isolated = contains_isolated_nodes(
        from_scipy_sparse_matrix(adjmat_norm)[0], num_nodes=adjmat_norm.shape[0]
    )
    assert not has_isolated, "OOPS! You're left with some dangling nodes ... revisit your mapping"

    map_key = (src_label, "to", dest_label)

    map_gdata = {
        "edge_index": torch.from_numpy(np.stack([adjmat_norm.col, adjmat_norm.row], axis=0).astype(np.int64)),
        "edge_attr": torch.from_numpy(np.expand_dims(adjmat_norm.data, axis=-1).astype(np.float32)),
        # source and dest coords are the same, keep duplicates (makes the data structure a bit easier to use)
        "scoords_rad": torch.from_numpy(coords_sp.astype(np.float32)),
        "dcoords_rad": torch.from_numpy(coords_sp.astype(np.float32)),
        "info": f"{src_label}_to_{dest_label} graph",
    }

    if area_weights is not None:
        map_gdata['area_weights'] = torch.from_numpy(np.array(area_weights))

    return map_key, map_gdata, neigh

In [None]:
def get_era_weights(latitudes=None, res=None):
    """Return per-grid-point area weights for the ERA reduced Gaussian grid.

    The layout is assumed to be octahedral:
    - ring i from the north pole (i = 0 .. res-1) holds ``20 + 4*i`` points;
    - the widest ring therefore holds ``4*res + 16`` points;
    - the southern hemisphere mirrors the northern one.

    Every point on a ring gets ``cos(lat) * (points on widest ring) / (points on this ring)``.

    Parameters
    ----------
    latitudes : array-like of float, optional
        Latitude (degrees) of every grid point, ordered ring by ring from north to south.
        Defaults to the notebook-level ``era.latitude``.
    res : int, optional
        Number of latitude rings per hemisphere. Defaults to the notebook-level ``era_res``.

    Returns
    -------
    list of float
        One weight per grid point, in the same order as ``latitudes``.

    Raises
    ------
    AssertionError
        If the number of weights does not match the number of latitudes.
    """
    # Fall back to the notebook globals so the original zero-argument call keeps working.
    if latitudes is None:
        latitudes = era.latitude.data
    if res is None:
        res = era_res
    latitudes = np.asarray(latitudes)

    max_lon = 4 * res + 16  # points on the widest ring
    north = []
    ring_start = 0
    for ring in range(res):
        nlon = 20 + 4 * ring
        # Only the first point of each ring is read; assumes all points on a ring share one latitude.
        area = np.cos(np.deg2rad(latitudes[ring_start])) * max_lon / nlon
        north.extend([area] * nlon)
        ring_start += nlon

    # Southern hemisphere: the same rings in reverse order.
    area_weights = north + north[::-1]
    assert len(area_weights) == latitudes.size
    return area_weights

In [None]:
# Sanity plot of the ERA area weights: one value per grid point, north to south.
fig, ax = plt.subplots()
ax.plot(get_era_weights())
ax.set(title="ERA grid-point area weights", xlabel="grid point index", ylabel="area weight")
plt.show()

In [None]:
elat, elon = (np.array(era[name]) for name in ("latitude", "longitude"))
# One (lat, lon) row per ERA grid point, in degrees; converted to radians for the haversine metric.
ecoords = np.stack((elat, elon), axis=-1).reshape(-1, 2)
ecoords_sp = np.deg2rad(ecoords)
print(f"ecoords_sp.shape = {ecoords_sp.shape}")

area_weights = get_era_weights()

# ERA -> ERA k-NN graph; eneigh (the fitted neighbour index) is reused for the ERA -> h33 mapping.
era2era_key, era2era_gdata, eneigh = create_self_mapping(
    ecoords_sp,
    NUM_ERA_NEIGHBORS,
    "era",
    "era",
    area_weights=area_weights,
)
print(era2era_key, [*era2era_gdata])

In [None]:
def get_h3_coords(resolution, latlon_deg=None):
    """Return the centres of the H3 cells that contain at least one input point.

    Only cells hit by an input point are returned, so at fine resolutions the result can be
    a strict subset of the global H3 grid.

    Parameters
    ----------
    resolution : int
        H3 resolution.
    latlon_deg : array-like of (lat, lon) pairs in degrees, optional
        Points used to select cells. Defaults to the notebook-level ``ecoords`` (ERA grid).

    Returns
    -------
    np.ndarray of shape (M, 2)
        (lat, lon) cell centres in radians, ordered by sorted H3 cell index. The result is
        always 2-D, including when no cells are selected.
    """
    if latlon_deg is None:
        latlon_deg = ecoords
    cells = sorted({h3.geo_to_h3(lat, lon, resolution) for lat, lon in latlon_deg})
    # reshape keeps the result 2-D even when no cells were found
    centres_deg = np.array([h3.h3_to_geo(cell) for cell in cells]).reshape(-1, 2)
    return np.deg2rad(centres_deg)

In [None]:
# Resolution-3 H3 mesh and its k-NN self-mapping (H3 meshes carry no area weights).
h33_coords = get_h3_coords(resolution=3)
h33_2_h33_key, h33_2_h33_gdata, h33neigh = create_self_mapping(
    h33_coords, NUM_H33_NEIGHBORS, "h33", "h33"
)
print(h33_2_h33_key, [*h33_2_h33_gdata])

In [None]:
# Resolution-2 H3 mesh and its k-NN self-mapping (H3 meshes carry no area weights).
h32_coords = get_h3_coords(resolution=2)
h32_2_h32_key, h32_2_h32_gdata, h32neigh = create_self_mapping(
    h32_coords, NUM_H32_NEIGHBORS, "h32", "h32"
)
print(h32_2_h32_key, [*h32_2_h32_gdata])

In [None]:
# Resolution-1 H3 mesh and its k-NN self-mapping (H3 meshes carry no area weights).
h31_coords = get_h3_coords(resolution=1)
h31_2_h31_key, h31_2_h31_gdata, h31neigh = create_self_mapping(
    h31_coords, NUM_H31_NEIGHBORS, "h31", "h31"
)
print(h31_2_h31_key, [*h31_2_h31_gdata])

In [None]:
# Resolution-0 H3 mesh and its k-NN self-mapping (H3 meshes carry no area weights).
h30_coords = get_h3_coords(resolution=0)
h30_2_h30_key, h30_2_h30_gdata, h30neigh = create_self_mapping(
    h30_coords, NUM_H30_NEIGHBORS, "h30", "h30"
)
print(h30_2_h30_key, [*h30_2_h30_gdata])

In [None]:
def create_cross_mapping(
    src_nn,
    src_coords_sp,
    dst_coords_sp,
    src_label,
    dest_label,
    hdist_cutoff=200.,
    area_weights=None,
    mode="connectivity",
):
    """Build a bipartite graph from a source node set to a destination node set.

    Every destination node is connected to all source nodes within ``hdist_cutoff`` km
    (great-circle distance). The edge attribute is ``1 - w / sum(w)``, where ``w`` is the value
    ``radius_neighbors_graph`` produces for ``mode``. The sum runs over all edges of the same
    destination node.

    Parameters
    ----------
    src_nn : NearestNeighbors
        Neighbour index fitted (haversine metric) on ``src_coords_sp``.
    src_coords_sp, dst_coords_sp : np.ndarray
        Source / destination (lat, lon) coordinates in radians.
    src_label, dest_label : str
        Node-type names used for the edge-type key ``(src_label, "to", dest_label)``.
    hdist_cutoff : float, default 200.
        Cut-off radius in km.
    area_weights : sequence of float, optional
        Stored under ``"area_weights"`` when given. Previously this name was not a parameter,
        so the function silently read the notebook-global ERA weights and attached them to
        every cross-mapping.
    mode : str, default "connectivity"
        Passed to ``radius_neighbors_graph``. This is the value the original call used
        implicitly. With it, all edges into one destination share the weight ``1 - 1/degree``.
        NOTE(review): create_self_mapping uses "distance"; confirm which is intended here.

    Returns
    -------
    map_key : tuple
        ``(src_label, "to", dest_label)``.
    map_gdata : dict
        ``edge_index`` (2, E) int64 with row 0 = source index and row 1 = destination index;
        ``edge_attr`` (E, 1) float32; source/destination coordinates; an ``info`` string;
        optionally ``area_weights``.

    Raises
    ------
    AssertionError
        If any source or destination node ends up without an edge.
    """
    RADIUS_EARTH = 6371  # km
    # haversine distances are on the unit sphere, so express the cut-off in radians
    radius_src_to_dst = hdist_cutoff / RADIUS_EARTH

    # One CSR row per destination node, one column per source node.
    src_to_dest_adjmat = src_nn.radius_neighbors_graph(
        dst_coords_sp,
        radius=radius_src_to_dst,
        mode=mode,
    ).tocsr()

    # Bipartite dangling-node check. contains_isolated_nodes treats source and destination
    # indices as a single node set, which is meaningless for a bipartite graph, so count
    # unconnected nodes on each side directly.
    n_src = src_to_dest_adjmat.shape[1]
    n_dangling_dst = int(np.count_nonzero(np.diff(src_to_dest_adjmat.indptr) == 0))
    n_dangling_src = n_src - np.unique(src_to_dest_adjmat.indices).size
    assert n_dangling_dst == 0 and n_dangling_src == 0, (
        "OOPS! You're left with some dangling nodes ... revisit your mapping "
        f"({n_dangling_src} {src_label} and {n_dangling_dst} {dest_label} nodes have no edges)"
    )

    # Row-normalise, then take BOTH edge indices and edge weights from the same matrix.
    # When normalize() receives a COO matrix, it re-sorts column indices while converting it to
    # CSR. Reading row/col from a different matrix than the data can therefore misalign
    # edge_index and edge_attr.
    src_to_dest_adjmat_norm = normalize(src_to_dest_adjmat, norm="l1", axis=1).tocoo()
    src_to_dest_adjmat_norm.data = 1.0 - src_to_dest_adjmat_norm.data
    map_key = (src_label, "to", dest_label)

    map_gdata = {
        "edge_index": torch.from_numpy(
            np.stack([src_to_dest_adjmat_norm.col, src_to_dest_adjmat_norm.row], axis=0).astype(np.int64)
        ),
        "edge_attr": torch.from_numpy(np.expand_dims(src_to_dest_adjmat_norm.data, axis=-1).astype(np.float32)),
        "scoords_rad": torch.from_numpy(src_coords_sp.astype(np.float32)),
        "dcoords_rad": torch.from_numpy(dst_coords_sp.astype(np.float32)),
        "info": f"{src_label}_to_{dest_label} graph",
    }

    if area_weights is not None:
        map_gdata['area_weights'] = torch.from_numpy(np.array(area_weights))

    return map_key, map_gdata

In [None]:
# Node coordinates (radians, float32) of each H3 mesh, recovered from its self-mapping graph.
h33_coords_sp, h32_coords_sp, h31_coords_sp, h30_coords_sp = (
    gdata["scoords_rad"].numpy()
    for gdata in (h33_2_h33_gdata, h32_2_h32_gdata, h31_2_h31_gdata, h30_2_h30_gdata)
)

In [None]:
# Cross-mapping ERA -> h33 with a 200 km cut-off (the function's default, spelled out here).
era_2_h33_key, era_2_h33_gdata = create_cross_mapping(
    eneigh, ecoords_sp, h33_coords_sp, "era", "h33", hdist_cutoff=200.
)
print(era_2_h33_key, [*era_2_h33_gdata])

In [None]:
# Cross-mappings between consecutive H3 resolutions, fine to coarse; the cut-off radius (km)
# grows as the destination mesh gets coarser.
h33_2_h32_key, h33_2_h32_gdata = create_cross_mapping(
    h33neigh, h33_coords_sp, h32_coords_sp, "h33", "h32", hdist_cutoff=200.
)
print(h33_2_h32_key, [*h33_2_h32_gdata])

h32_2_h31_key, h32_2_h31_gdata = create_cross_mapping(
    h32neigh, h32_coords_sp, h31_coords_sp, "h32", "h31", hdist_cutoff=350.
)
print(h32_2_h31_key, [*h32_2_h31_gdata])

h31_2_h30_key, h31_2_h30_gdata = create_cross_mapping(
    h31neigh, h31_coords_sp, h30_coords_sp, "h31", "h30", hdist_cutoff=1000.
)
print(h31_2_h30_key, [*h31_2_h30_gdata])

In [None]:
# Assemble every mapping into one heterogeneous graph keyed by its (src, "to", dst) edge type:
# first the intra-mesh self-mappings, then the cross-mappings down the mesh hierarchy.
critic_graph_data = HeteroData(dict([
    (era2era_key, era2era_gdata),
    (h33_2_h33_key, h33_2_h33_gdata),
    (h32_2_h32_key, h32_2_h32_gdata),
    (h31_2_h31_key, h31_2_h31_gdata),
    (h30_2_h30_key, h30_2_h30_gdata),
    # cross-mappings
    (era_2_h33_key, era_2_h33_gdata),
    (h33_2_h32_key, h33_2_h32_gdata),
    (h32_2_h31_key, h32_2_h31_gdata),
    (h31_2_h30_key, h31_2_h30_gdata),
]))

In [None]:
from aifs.utils.graph_gen import directional_edge_features, directional_edge_features_rotated

luse_rotated_edge_features = True

# How edge directions are encoded:
#   rotated -> relative to the target node rotated to the north pole
#   plain   -> location of target node minus location of source node
edge_directions_func = (
    directional_edge_features_rotated
    if luse_rotated_edge_features
    else directional_edge_features
)

In [None]:
# Append per-edge direction features (from edge_directions_func) to each H3 self-mapping.
for h_ in ["h33", "h32", "h31", "h30"]:
    edge_store = critic_graph_data[(h_, "to", h_)]

    # create_self_mapping emits edge_attr with a single column. A wider tensor means this cell
    # already ran, and appending again would silently duplicate the direction features.
    if edge_store['edge_attr'].shape[-1] > 1:
        continue

    # Gather the endpoint coordinates of all edges once instead of indexing edge by edge.
    src_idx, dst_idx = edge_store['edge_index']
    src_pos = edge_store['scoords_rad'][src_idx]
    dst_pos = edge_store['dcoords_rad'][dst_idx]
    hhedge_dirs = [edge_directions_func(ic, jc) for ic, jc in zip(src_pos, dst_pos)]

    hhedge_dirs = torch.from_numpy(np.stack(hhedge_dirs).astype(np.float32))
    edge_store['edge_attr'] = torch.cat([edge_store['edge_attr'], hhedge_dirs], dim=-1)

In [None]:
# Append per-edge direction features (from edge_directions_func) to each cross-mapping
# between consecutive meshes, finest to coarsest.
meshes = ["era", "h33", "h32", "h31", "h30"]

for s_, d_ in zip(meshes[:-1], meshes[1:]):
    edge_store = critic_graph_data[(s_, "to", d_)]

    # create_cross_mapping emits edge_attr with a single column. A wider tensor means this cell
    # already ran, and appending again would silently duplicate the direction features.
    if edge_store['edge_attr'].shape[-1] > 1:
        continue

    # Gather the endpoint coordinates of all edges once instead of indexing edge by edge.
    src_idx, dst_idx = edge_store['edge_index']
    src_pos = edge_store['scoords_rad'][src_idx]
    dst_pos = edge_store['dcoords_rad'][dst_idx]
    hhedge_dirs = [edge_directions_func(ic, jc) for ic, jc in zip(src_pos, dst_pos)]

    hhedge_dirs = torch.from_numpy(np.stack(hhedge_dirs).astype(np.float32))
    edge_store['edge_attr'] = torch.cat([edge_store['edge_attr'], hhedge_dirs], dim=-1)

In [None]:
# Display the assembled HeteroData object to inspect its edge types and stored attributes.
critic_graph_data

In [None]:
from aifs.utils.graph_gen import plot_bipartite_from_graphdata

In [None]:
# Plot the h31 -> h30 bipartite mapping.
h31_to_h30_graph = critic_graph_data[("h31", "to", "h30")]
plot_bipartite_from_graphdata(
    "H31 to H30 Graph",
    "blue",
    h31_to_h30_graph,
    ('scoords_rad', 'dcoords_rad'),
    h31_to_h30_graph['scoords_rad'],
    h31_to_h30_graph['dcoords_rad'],
)

In [None]:
# Plot the h32 -> h31 bipartite mapping.
h32_to_h31_graph = critic_graph_data[("h32", "to", "h31")]
plot_bipartite_from_graphdata(
    "H32 to H31 Graph",
    "blue",
    h32_to_h31_graph,
    ('scoords_rad', 'dcoords_rad'),
    h32_to_h31_graph['scoords_rad'],
    h32_to_h31_graph['dcoords_rad'],
)

In [None]:
# Plot the h33 -> h32 bipartite mapping.
h33_to_h32_graph = critic_graph_data[("h33", "to", "h32")]
plot_bipartite_from_graphdata(
    "H33 to H32 Graph",
    "blue",
    h33_to_h32_graph,
    ('scoords_rad', 'dcoords_rad'),
    h33_to_h32_graph['scoords_rad'],
    h33_to_h32_graph['dcoords_rad'],
)

In [None]:
# Plot the era -> h33 bipartite mapping.
era_to_h33_graph = critic_graph_data[("era", "to", "h33")]
plot_bipartite_from_graphdata(
    "ERA to H33 Graph",
    "blue",
    era_to_h33_graph,
    ('scoords_rad', 'dcoords_rad'),
    era_to_h33_graph['scoords_rad'],
    era_to_h33_graph['dcoords_rad'],
)

In [None]:
fname = f"gan_critic_graph_mappings_normed_edge_attrs_o{era_res}_h_0_1_2_3.pt"
output_dir = "/ec/res4/hpcperm/syma/gnn/"

# Serialise the complete HeteroData graph (self- and cross-mappings) to output_dir/fname.
torch.save(obj=critic_graph_data, f=os.path.join(output_dir, fname))

In [None]:
!ls -ltr /ec/res4/hpcperm/syma/gnn/