# Nearest-neighbor graph mappings

In [3]:
import os

import xarray as xr
import numpy as np

In [4]:
import cartopy.crs as ccrs
import cartopy.feature as cf

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt

%matplotlib inline
%config InlineBackend.figure_format='retina'

In [5]:
from sklearn.neighbors import NearestNeighbors

In [6]:
import dask
from dask.distributed import Client, LocalCluster

In [7]:
import torch
from torch_geometric.data import HeteroData 

  from .autonotebook import tqdm as notebook_tqdm


In [8]:
import h3

In [9]:
# Local Dask cluster: four worker processes with one thread each, plus a client attached to it.
cluster = LocalCluster(threads_per_worker=1, n_workers=4)
client = Client(cluster)

Perhaps you already have a cluster running?
Hosting the HTTP server on port 44335 instead


In [10]:
# k for the within-grid kNN graphs (each grid is queried against itself, so
# every node's neighbor list also contains the node itself).
# NUM_H3_NEIGHBORS is applied to the O32 grid despite its "H3" name.
NUM_ERA_NEIGHBORS, NUM_H3_NEIGHBORS = 9, 7

## ERA5 -> O32

In [11]:
# Open the ERA5 pressure-level file lazily as dask arrays (5 time steps per chunk).
# Deliberately NOT wrapped in `with client:`: leaving that block calls
# Client.close(), which would shut down the client created above after this
# single cell. The open is lazy, so the context adds nothing.
era = xr.open_dataset("/ec/res4/hpcperm/syma/WeatherBench/netcdf/pl_2004.nc", chunks={"time": 5})
era

Unnamed: 0,Array,Chunk
Bytes,4.62 GiB,16.16 MiB
Shape,"(1464, 13, 181, 360)","(5, 13, 181, 360)"
Count,2 Graph Layers,293 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 4.62 GiB 16.16 MiB Shape (1464, 13, 181, 360) (5, 13, 181, 360) Count 2 Graph Layers 293 Chunks Type float32 numpy.ndarray",1464  1  360  181  13,

Unnamed: 0,Array,Chunk
Bytes,4.62 GiB,16.16 MiB
Shape,"(1464, 13, 181, 360)","(5, 13, 181, 360)"
Count,2 Graph Layers,293 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,4.62 GiB,16.16 MiB
Shape,"(1464, 13, 181, 360)","(5, 13, 181, 360)"
Count,2 Graph Layers,293 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 4.62 GiB 16.16 MiB Shape (1464, 13, 181, 360) (5, 13, 181, 360) Count 2 Graph Layers 293 Chunks Type float32 numpy.ndarray",1464  1  360  181  13,

Unnamed: 0,Array,Chunk
Bytes,4.62 GiB,16.16 MiB
Shape,"(1464, 13, 181, 360)","(5, 13, 181, 360)"
Count,2 Graph Layers,293 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,4.62 GiB,16.16 MiB
Shape,"(1464, 13, 181, 360)","(5, 13, 181, 360)"
Count,2 Graph Layers,293 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 4.62 GiB 16.16 MiB Shape (1464, 13, 181, 360) (5, 13, 181, 360) Count 2 Graph Layers 293 Chunks Type float32 numpy.ndarray",1464  1  360  181  13,

Unnamed: 0,Array,Chunk
Bytes,4.62 GiB,16.16 MiB
Shape,"(1464, 13, 181, 360)","(5, 13, 181, 360)"
Count,2 Graph Layers,293 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,4.62 GiB,16.16 MiB
Shape,"(1464, 13, 181, 360)","(5, 13, 181, 360)"
Count,2 Graph Layers,293 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 4.62 GiB 16.16 MiB Shape (1464, 13, 181, 360) (5, 13, 181, 360) Count 2 Graph Layers 293 Chunks Type float32 numpy.ndarray",1464  1  360  181  13,

Unnamed: 0,Array,Chunk
Bytes,4.62 GiB,16.16 MiB
Shape,"(1464, 13, 181, 360)","(5, 13, 181, 360)"
Count,2 Graph Layers,293 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,4.62 GiB,16.16 MiB
Shape,"(1464, 13, 181, 360)","(5, 13, 181, 360)"
Count,2 Graph Layers,293 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 4.62 GiB 16.16 MiB Shape (1464, 13, 181, 360) (5, 13, 181, 360) Count 2 Graph Layers 293 Chunks Type float32 numpy.ndarray",1464  1  360  181  13,

Unnamed: 0,Array,Chunk
Bytes,4.62 GiB,16.16 MiB
Shape,"(1464, 13, 181, 360)","(5, 13, 181, 360)"
Count,2 Graph Layers,293 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,4.62 GiB,16.16 MiB
Shape,"(1464, 13, 181, 360)","(5, 13, 181, 360)"
Count,2 Graph Layers,293 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 4.62 GiB 16.16 MiB Shape (1464, 13, 181, 360) (5, 13, 181, 360) Count 2 Graph Layers 293 Chunks Type float32 numpy.ndarray",1464  1  360  181  13,

Unnamed: 0,Array,Chunk
Bytes,4.62 GiB,16.16 MiB
Shape,"(1464, 13, 181, 360)","(5, 13, 181, 360)"
Count,2 Graph Layers,293 Chunks
Type,float32,numpy.ndarray


In [12]:
# Wrap longitudes into [-180, 180) and sort both axes ascending.
era = (
    era.assign_coords(longitude=((era.longitude + 180) % 360) - 180)
    .sortby("longitude")
    .sortby("latitude")
)
era

Unnamed: 0,Array,Chunk
Bytes,4.62 GiB,16.16 MiB
Shape,"(1464, 13, 181, 360)","(5, 13, 181, 360)"
Count,4 Graph Layers,293 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 4.62 GiB 16.16 MiB Shape (1464, 13, 181, 360) (5, 13, 181, 360) Count 4 Graph Layers 293 Chunks Type float32 numpy.ndarray",1464  1  360  181  13,

Unnamed: 0,Array,Chunk
Bytes,4.62 GiB,16.16 MiB
Shape,"(1464, 13, 181, 360)","(5, 13, 181, 360)"
Count,4 Graph Layers,293 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,4.62 GiB,16.16 MiB
Shape,"(1464, 13, 181, 360)","(5, 13, 181, 360)"
Count,4 Graph Layers,293 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 4.62 GiB 16.16 MiB Shape (1464, 13, 181, 360) (5, 13, 181, 360) Count 4 Graph Layers 293 Chunks Type float32 numpy.ndarray",1464  1  360  181  13,

Unnamed: 0,Array,Chunk
Bytes,4.62 GiB,16.16 MiB
Shape,"(1464, 13, 181, 360)","(5, 13, 181, 360)"
Count,4 Graph Layers,293 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,4.62 GiB,16.16 MiB
Shape,"(1464, 13, 181, 360)","(5, 13, 181, 360)"
Count,4 Graph Layers,293 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 4.62 GiB 16.16 MiB Shape (1464, 13, 181, 360) (5, 13, 181, 360) Count 4 Graph Layers 293 Chunks Type float32 numpy.ndarray",1464  1  360  181  13,

Unnamed: 0,Array,Chunk
Bytes,4.62 GiB,16.16 MiB
Shape,"(1464, 13, 181, 360)","(5, 13, 181, 360)"
Count,4 Graph Layers,293 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,4.62 GiB,16.16 MiB
Shape,"(1464, 13, 181, 360)","(5, 13, 181, 360)"
Count,4 Graph Layers,293 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 4.62 GiB 16.16 MiB Shape (1464, 13, 181, 360) (5, 13, 181, 360) Count 4 Graph Layers 293 Chunks Type float32 numpy.ndarray",1464  1  360  181  13,

Unnamed: 0,Array,Chunk
Bytes,4.62 GiB,16.16 MiB
Shape,"(1464, 13, 181, 360)","(5, 13, 181, 360)"
Count,4 Graph Layers,293 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,4.62 GiB,16.16 MiB
Shape,"(1464, 13, 181, 360)","(5, 13, 181, 360)"
Count,4 Graph Layers,293 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 4.62 GiB 16.16 MiB Shape (1464, 13, 181, 360) (5, 13, 181, 360) Count 4 Graph Layers 293 Chunks Type float32 numpy.ndarray",1464  1  360  181  13,

Unnamed: 0,Array,Chunk
Bytes,4.62 GiB,16.16 MiB
Shape,"(1464, 13, 181, 360)","(5, 13, 181, 360)"
Count,4 Graph Layers,293 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,4.62 GiB,16.16 MiB
Shape,"(1464, 13, 181, 360)","(5, 13, 181, 360)"
Count,4 Graph Layers,293 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 4.62 GiB 16.16 MiB Shape (1464, 13, 181, 360) (5, 13, 181, 360) Count 4 Graph Layers 293 Chunks Type float32 numpy.ndarray",1464  1  360  181  13,

Unnamed: 0,Array,Chunk
Bytes,4.62 GiB,16.16 MiB
Shape,"(1464, 13, 181, 360)","(5, 13, 181, 360)"
Count,4 Graph Layers,293 Chunks
Type,float32,numpy.ndarray


In [13]:
# Longitude extent after wrapping.
era["longitude"].min(), era["longitude"].max()

(<xarray.DataArray 'longitude' ()>
 array(-180., dtype=float32),
 <xarray.DataArray 'longitude' ()>
 array(179., dtype=float32))

In [14]:
# Latitude extent after sorting.
era["latitude"].min(), era["latitude"].max()

(<xarray.DataArray 'latitude' ()>
 array(-90., dtype=float32),
 <xarray.DataArray 'latitude' ()>
 array(90., dtype=float32))

In [15]:
# (lat, lon) pairs for every ERA5 grid point, latitude-major. meshgrid's default
# "xy" indexing yields (n_lon, n_lat) grids, so the stacked (2, n_lon, n_lat)
# array is transposed to (n_lat, n_lon, 2) before flattening.
elats, elons = np.meshgrid(era["latitude"].values, era["longitude"].values)
ecoords = np.stack([elats, elons]).T.reshape(-1, 2)
ecoords.shape

(65160, 2)

In [16]:
# Cross-check: the same (lat, lon) list built with "ij" indexing, which is
# already latitude-major and needs no transpose.
elats = np.array(era.latitude)
elons = np.array(era.longitude)
lat_lon_grid = np.meshgrid(elats, elons, indexing="ij")
ecoords_v2 = np.stack(lat_lon_grid, axis=-1).reshape(-1, 2)
ecoords_v2.shape

(65160, 2)

In [17]:
# Both constructions must give identical (lat, lon) rows. Compare ABSOLUTE
# differences: np.max(ecoords - ecoords_v2) alone would still read 0.0 if some
# entries of ecoords were smaller than the matching ones in ecoords_v2.
np.abs(ecoords - ecoords_v2).max()

0.0

In [18]:
# sklearn's haversine metric works on (lat, lon) pairs given in radians.
ecoords_sp = np.radians(ecoords)

# kNN index over the ERA5 grid.
eneigh = NearestNeighbors(metric="haversine", n_neighbors=NUM_ERA_NEIGHBORS, n_jobs=4)
eneigh.fit(ecoords_sp)

In [19]:
# ERA5 -> ERA5 kNN graph (ERA5 queried against itself), haversine distances as values, in COO form.
eadjmat = eneigh.kneighbors_graph(
    ecoords_sp, n_neighbors=NUM_ERA_NEIGHBORS, mode="distance"
).tocoo()

In [20]:
# Expect n_points * NUM_ERA_NEIGHBORS stored entries (65160 * 9).
eadjmat

<65160x65160 sparse matrix of type '<class 'numpy.float64'>'
	with 586440 stored elements in COOrdinate format>

In [21]:
# Load the O32 2 m temperature GRIB file eagerly; below, only its latitude and
# longitude coordinates are used. Not wrapped in `with client:`: leaving that
# block calls Client.close(), and a fully loaded dataset gains nothing from it.
ds_o32 = xr.load_dataset("/ec/res4/scratch/syma/data/o32_t2m.grib")
ds_o32

In [22]:
# (lat, lon) rows, in degrees, for every O32 grid point.
# NOTE(review): assumes the GRIB coordinates are already flat per-point arrays
# (the reshape is then a no-op). This is consistent with the (5248, 2) shape shown below.
olat, olon = (np.array(ds_o32[coord_name]) for coord_name in ("latitude", "longitude"))
ocoords = np.stack([olat, olon], axis=-1).reshape(-1, 2)
ocoords[:10]

array([[ 87.86379884,   0.        ],
       [ 87.86379884,  18.        ],
       [ 87.86379884,  36.        ],
       [ 87.86379884,  54.        ],
       [ 87.86379884,  72.        ],
       [ 87.86379884,  90.        ],
       [ 87.86379884, 108.        ],
       [ 87.86379884, 126.        ],
       [ 87.86379884, 144.        ],
       [ 87.86379884, 162.        ]])

In [23]:
# Number of O32 nodes and (lat, lon) columns.
ocoords.shape

(5248, 2)

In [24]:
# Radians, as required by the haversine metric.
ocoords_sp = np.radians(ocoords)

In [25]:
# kNN index over the O32 nodes. Querying it with the O32 nodes themselves gives
# the h->h graph; each node's own entry appears with distance 0.
oneigh = NearestNeighbors(metric="haversine", n_neighbors=NUM_H3_NEIGHBORS, n_jobs=4)
oneigh.fit(ocoords_sp)

oadjmat = oneigh.kneighbors_graph(
    ocoords_sp, n_neighbors=NUM_H3_NEIGHBORS, mode="distance"
).tocoo()
oadjmat

<5248x5248 sparse matrix of type '<class 'numpy.float64'>'
	with 36736 stored elements in COOrdinate format>

In [26]:
# Side-by-side summary of the two within-grid kNN graphs.
eadjmat, oadjmat

(<65160x65160 sparse matrix of type '<class 'numpy.float64'>'
 	with 586440 stored elements in COOrdinate format>,
 <5248x5248 sparse matrix of type '<class 'numpy.float64'>'
 	with 36736 stored elements in COOrdinate format>)

In [27]:
from sklearn.preprocessing import normalize

# Distance -> weight: L1-normalize every row, then take 1 - share so nearer
# neighbors weigh more (the zero-distance self entry becomes exactly 1.0).
# The result is a CSR matrix (see the repr below).
eadjmat_norm = normalize(eadjmat, axis=1, norm="l1")
eadjmat_norm.data = np.subtract(1.0, eadjmat_norm.data)
eadjmat_norm

<65160x65160 sparse matrix of type '<class 'numpy.float64'>'
	with 586440 stored elements in Compressed Sparse Row format>

In [28]:
# Normalized weights of ERA5 point 0, in the stored order of the normalized
# (CSR) matrix. This is not necessarily the entry order of the raw COO eadjmat.
eadjmat_norm[0, :].data

array([1.        , 0.94999391, 0.89999232, 0.84999766, 0.8000151 ,
       0.8000156 , 0.84999816, 0.89999282, 0.94999442])

In [29]:
# Same distance -> weight conversion for the O32 kNN graph.
oadjmat_norm = normalize(oadjmat, axis=1, norm="l1")
oadjmat_norm.data = np.subtract(1.0, oadjmat_norm.data)
oadjmat_norm

<5248x5248 sparse matrix of type '<class 'numpy.float64'>'
	with 36736 stored elements in Compressed Sparse Row format>

In [30]:
# Column indices of the raw COO matrix in the order kneighbors_graph produced
# them (row 0 starts 0, 1, 19, ...). Compare with the normalized row below.
oadjmat.col

array([   0,    1,   19, ..., 5229, 5244, 5230], dtype=int32)

In [31]:
# Normalized weights of O32 node 0 in the CSR matrix's stored order. The
# symmetric pattern indicates column order (0, 1, 2, 3, 17, 18, 19), not the
# COO order shown by oadjmat.col above.
oadjmat_norm[0, :].data

array([1.        , 0.91493191, 0.83195571, 0.75311239, 0.75311239,
       0.83195571, 0.91493191])

In [32]:
# Neighbor counts for the cross-grid mappings:
#   decoder (h -> era): each ERA5 point is linked to its 3 nearest O32 nodes
#   encoder (era -> h): each O32 node is linked to its 12 nearest ERA5 points
# Queries and fitted points belong to different grids here, so there is no
# "self" entry in these neighbor lists.
NUM_H3_TO_ERA_NEIGHBORS, NUM_ERA_TO_H3_NEIGHBORS = 3, 12

In [33]:
# Cross-grid kNN mappings. kneighbors_graph(query) returns an
# (n_query, n_fitted) matrix: rows index the query grid, cols the fitted grid.

# ERA -> H3 aka the "encoder": for each O32 node, its nearest ERA5 points.
era_to_o32_adjmat = eneigh.kneighbors_graph(
    ocoords_sp, n_neighbors=NUM_ERA_TO_H3_NEIGHBORS, mode="distance"
).tocoo()

# H3 -> ERA aka the "decoder": for each ERA5 point, its nearest O32 nodes.
o32_to_era_adjmat = oneigh.kneighbors_graph(
    ecoords_sp, n_neighbors=NUM_H3_TO_ERA_NEIGHBORS, mode="distance"
).tocoo()

In [34]:
# Range of the decoder's haversine distances (radians). scipy's sparse min()
# also accounts for implicit zeros, which is presumably why it reports 0.0 here.
(o32_to_era_adjmat.max(), o32_to_era_adjmat.min())

(0.059220323403680365, 0.0)

In [33]:
# (n_era, n_o32) with NUM_H3_TO_ERA_NEIGHBORS entries per ERA5 row.
o32_to_era_adjmat

<65160x5248 sparse matrix of type '<class 'numpy.float64'>'
	with 195480 stored elements in COOrdinate format>

In [34]:
# Decoder weights: per ERA5 row, 1 - (distance / row sum of distances).
o32_to_era_adjmat_norm = normalize(o32_to_era_adjmat, axis=1, norm="l1")
o32_to_era_adjmat_norm.data = np.subtract(1.0, o32_to_era_adjmat_norm.data)
o32_to_era_adjmat_norm

<65160x5248 sparse matrix of type '<class 'numpy.float64'>'
	with 195480 stored elements in Compressed Sparse Row format>

In [35]:
# (n_o32, n_era) with NUM_ERA_TO_H3_NEIGHBORS entries per O32 row.
era_to_o32_adjmat

<5248x65160 sparse matrix of type '<class 'numpy.float64'>'
	with 62976 stored elements in COOrdinate format>

In [36]:
# First raw encoder distances (radians), in COO storage order.
era_to_o32_adjmat.data[:10]

array([0.00237716, 0.0024591 , 0.0024591 , 0.00268995, 0.00268995,
       0.00303587, 0.00303587, 0.00346247, 0.00346247, 0.00394353])

In [37]:
# Encoder weights: per O32 row, 1 - (distance / row sum of distances).
era_to_o32_adjmat_norm = normalize(era_to_o32_adjmat, axis=1, norm="l1")
era_to_o32_adjmat_norm.data = np.subtract(1.0, era_to_o32_adjmat_norm.data)
era_to_o32_adjmat_norm

<5248x65160 sparse matrix of type '<class 'numpy.float64'>'
	with 62976 stored elements in Compressed Sparse Row format>

In [38]:
# Dense view of the encoder weights for the first 10 O32 nodes: where the
# non-zeros are and how many (10 rows x NUM_ERA_TO_H3_NEIGHBORS = 120 expected).
tmp = era_to_o32_adjmat_norm[:10].toarray()
nz_idx = np.nonzero(tmp)
tmp[nz_idx].shape

(120,)

In [39]:
# (row, ERA5 column) indices of the non-zero encoder weights above.
nz_idx

(array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7,
        7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9,
        9, 9, 9, 9, 9, 9, 9, 9, 9, 9]),
 array([64255, 64256, 64257, 64258, 64259, 64260, 64261, 64262, 64263,
        64264, 64265, 64266, 64272, 64273, 64274, 64275, 64276, 64277,
        64278, 64279, 64280, 64281, 64282, 64283, 64290, 64291, 64292,
        64293, 64294, 64295, 64296, 64297, 64298, 64299, 64300, 64301,
        64308, 64309, 64310, 64311, 64312, 64313, 64314, 64315, 64316,
        64317, 64318, 64319, 64326, 64327, 64328, 64329, 64330, 64331,
        64332, 64333, 64334, 64335, 64336, 64337, 64344, 64345, 64346,
        64347, 64348, 64349, 64350, 64351, 64352, 64353, 64354, 64355,
        64363, 64364, 

In [40]:
# Spot check of hand-picked ERA5 columns for O32 node 9; all are zero, i.e. not among its neighbors.
# NOTE(review): the reason for these particular column indices is not evident
# from the notebook. Consider deriving them from nz_idx instead.
tmp[9, [1800, 1801,1802, 1803, 1804, 1805, 1806, 2155, 2156, 2157, 2158, 2159]]

array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [41]:
# Same inspection for the O32 -> O32 weights of the first 10 nodes.
tmp = oadjmat_norm[:10].toarray()
nz_idx = np.nonzero(tmp)
nz_idx

(array([0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 3,
        3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 6, 6,
        6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9,
        9, 9, 9, 9]),
 array([ 0,  1,  2,  3, 17, 18, 19,  0,  1,  2,  3,  4, 18, 19,  0,  1,  2,
         3,  4,  5, 19,  0,  1,  2,  3,  4,  5,  6,  1,  2,  3,  4,  5,  6,
         7,  2,  3,  4,  5,  6,  7,  8,  3,  4,  5,  6,  7,  8,  9,  4,  5,
         6,  7,  8,  9, 10,  5,  6,  7,  8,  9, 10, 11,  6,  7,  8,  9, 10,
        11, 12]))

In [42]:
# O32 node 9's weights for columns 7..13. Per nz_idx its neighbors are 6..12,
# so column 13 reads 0 and the self entry (column 9) is 1.0.
tmp[9, [7,  8,  9 , 10, 11, 12, 13]]

array([0.83195571, 0.91493191, 1.        , 0.91493191, 0.83195571,
       0.75311239, 0.        ])

In [95]:
def _coo_edge_index(adj):
    """(2, nnz) int64 tensor [source, target] = [col, row] of a COO adjacency matrix."""
    return torch.from_numpy(np.stack([adj.col, adj.row], axis=0).astype(np.int64))


def _edge_values(values):
    """Edge values as an (nnz, 1) float32 tensor."""
    return torch.from_numpy(np.expand_dims(values, axis=-1).astype(np.float32))


def _coords_tensor(coords_rad):
    """Fresh float32 tensor copy of an (n, 2) [lat, lon] array in radians."""
    return torch.from_numpy(coords_rad.astype(np.float32))


# Unnormalized graph: edge_attr holds the raw haversine distances (radians).
# kneighbors_graph matrices are (n_query, n_fitted), so edges are stored as
# [col, row] = fitted-grid source -> query-grid target. For the cross-grid
# mappings this swap is required. For the within-grid kNN graphs it keeps the
# same convention (a kNN relation is not guaranteed to be symmetric).
graphs = HeteroData(
    {
        ("h", "to", "h"): {
            "edge_index": _coo_edge_index(oadjmat),
            "edge_attr": _edge_values(oadjmat.data),
            "hcoords_rad": _coords_tensor(ocoords_sp),
            "hinfo": "O32 grid",
        },
        ("era", "to", "era"): {
            "edge_index": _coo_edge_index(eadjmat),
            "edge_attr": _edge_values(eadjmat.data),
            "ecoords_rad": _coords_tensor(ecoords_sp),
        },
        ("h", "to", "era"): {
            "edge_index": _coo_edge_index(o32_to_era_adjmat),
            "edge_attr": _edge_values(o32_to_era_adjmat.data),
            "hcoords_rad": _coords_tensor(ocoords_sp),
            "ecoords_rad": _coords_tensor(ecoords_sp),
            "hinfo": "O32 grid",
        },
        ("era", "to", "h"): {
            "edge_index": _coo_edge_index(era_to_o32_adjmat),
            "edge_attr": _edge_values(era_to_o32_adjmat.data),
            "hcoords_rad": _coords_tensor(ocoords_sp),
            "ecoords_rad": _coords_tensor(ecoords_sp),
            "hinfo": "O32 grid",
        },
    }
)

In [107]:
# Same four edge types as the raw graph, but edge_attr carries the
# row-normalized weights 1 - d / row_sum(d) instead of raw haversine distances.
#
# edge_index and edge_attr MUST come from the same matrix. The *_norm matrices
# returned by sklearn's `normalize` are CSR, and their within-row order differs
# from the raw COO matrices. `oadjmat.col` starts [0, 1, 19, ...], while
# `oadjmat_norm[0, :].data` is laid out in column order (0, 1, 2, 3, 17, 18, 19).
# Pairing the raw COO row/col with the normalized .data attaches weights to the
# wrong edges. Converting each normalized matrix to COO gives row, col and data
# that line up entry by entry.
#
# kneighbors_graph(query) matrices are (n_query, n_fitted), so edge_index is
# stacked as [col, row] = [source (fitted grid), target (query grid)].
oadj_norm_coo = oadjmat_norm.tocoo()
eadj_norm_coo = eadjmat_norm.tocoo()
o32_to_era_norm_coo = o32_to_era_adjmat_norm.tocoo()
era_to_o32_norm_coo = era_to_o32_adjmat_norm.tocoo()

# Normalization must keep exactly the same set of edges.
assert oadj_norm_coo.nnz == oadjmat.nnz
assert eadj_norm_coo.nnz == eadjmat.nnz
assert o32_to_era_norm_coo.nnz == o32_to_era_adjmat.nnz
assert era_to_o32_norm_coo.nnz == era_to_o32_adjmat.nnz

graphs_normed = HeteroData(
    {
        ("h", "to", "h"): {
            "edge_index": torch.from_numpy(np.stack([oadj_norm_coo.col, oadj_norm_coo.row], axis=0).astype(np.int64)),
            "edge_attr": torch.from_numpy(np.expand_dims(oadj_norm_coo.data, axis=-1).astype(np.float32)),
            "hcoords_rad": torch.from_numpy(ocoords_sp.astype(np.float32)),
            "hinfo": "O32 grid",
        },

        ("era", "to", "era"): {
            "edge_index": torch.from_numpy(np.stack([eadj_norm_coo.col, eadj_norm_coo.row], axis=0).astype(np.int64)),
            "edge_attr": torch.from_numpy(np.expand_dims(eadj_norm_coo.data, axis=-1).astype(np.float32)),
            "ecoords_rad": torch.from_numpy(ecoords_sp.astype(np.float32)),
        },

        # decoder: O32 (fitted) -> ERA5 (query)
        ("h", "to", "era"): {
            "edge_index": torch.from_numpy(np.stack([o32_to_era_norm_coo.col, o32_to_era_norm_coo.row], axis=0).astype(np.int64)),
            "edge_attr": torch.from_numpy(np.expand_dims(o32_to_era_norm_coo.data, axis=-1).astype(np.float32)),
            "hcoords_rad": torch.from_numpy(ocoords_sp.astype(np.float32)),
            "ecoords_rad": torch.from_numpy(ecoords_sp.astype(np.float32)),
            "hinfo": "O32 grid",
        },

        # encoder: ERA5 (fitted) -> O32 (query)
        ("era", "to", "h"): {
            "edge_index": torch.from_numpy(np.stack([era_to_o32_norm_coo.col, era_to_o32_norm_coo.row], axis=0).astype(np.int64)),
            "edge_attr": torch.from_numpy(np.expand_dims(era_to_o32_norm_coo.data, axis=-1).astype(np.float32)),
            "hcoords_rad": torch.from_numpy(ocoords_sp.astype(np.float32)),
            "ecoords_rad": torch.from_numpy(ecoords_sp.astype(np.float32)),
            "hinfo": "O32 grid",
        },
    }
)


## Add edge-direction attributes

In [116]:
# Edge-direction feature for the h->h edges: the (dlat, dlon) difference, in
# radians, from source node (edge_index[0]) to target node (edge_index[1]).
#  - dlon is wrapped into [-pi, pi). O32 longitudes run over [0, 360) degrees,
#    so an edge across the 0-degree seam (e.g. node 0 at 0 deg <- node 19 at
#    342 deg) would otherwise get a ~2*pi jump instead of the short-way step.
#  - Vectorized gather instead of a Python loop over every edge.
#  - Only the first (weight) column of edge_attr is reused, so re-running this
#    cell after edge_attr was extended does not append the direction twice.
hh_store = graphs_normed[("h", "to", "h")]
hh_src, hh_dst = hh_store["edge_index"]
hhedge_dirs = hh_store["hcoords_rad"][hh_dst] - hh_store["hcoords_rad"][hh_src]
hhedge_dirs[:, 1] = torch.remainder(hhedge_dirs[:, 1] + np.pi, 2 * np.pi) - np.pi
hhedge_attr = torch.concat([hh_store["edge_attr"][:, :1], hhedge_dirs], dim=-1)

In [117]:
# Edge-direction feature for the era->h (encoder) edges: the (dlat, dlon)
# difference, in radians, from the ERA5 source point (edge_index[0]) to the O32
# target node (edge_index[1]).
#  - dlon is wrapped into [-pi, pi). This is essential here: ERA5 longitudes
#    were shifted to [-180, 180) while O32 longitudes start at 0 and run to
#    below 360 (see `ocoords`), so a raw difference can be off by a full 2*pi.
#  - Vectorized gather instead of a Python loop over every edge.
#  - Only the first (weight) column of edge_attr is reused, so re-running this
#    cell after edge_attr was extended does not append the direction twice.
eh_store = graphs_normed[("era", "to", "h")]
eh_src, eh_dst = eh_store["edge_index"]
ehedge_dirs = eh_store["hcoords_rad"][eh_dst] - eh_store["ecoords_rad"][eh_src]
ehedge_dirs[:, 1] = torch.remainder(ehedge_dirs[:, 1] + np.pi, 2 * np.pi) - np.pi
ehedge_attr = torch.concat([eh_store["edge_attr"][:, :1], ehedge_dirs], dim=-1)

In [118]:
# Edge-direction feature for the h->era (decoder) edges: the (dlat, dlon)
# difference, in radians, from the O32 source node (edge_index[0]) to the ERA5
# target point (edge_index[1]).
#  - dlon is wrapped into [-pi, pi). O32 longitudes start at 0 while ERA5
#    longitudes were shifted to [-180, 180), so a raw difference can be off by 2*pi.
#  - Vectorized gather instead of a Python loop over every edge.
#  - Only the first (weight) column of edge_attr is reused, so re-running this
#    cell after edge_attr was extended does not append the direction twice.
he_store = graphs_normed[("h", "to", "era")]
he_src, he_dst = he_store["edge_index"]
heedge_dirs = he_store["ecoords_rad"][he_dst] - he_store["hcoords_rad"][he_src]
heedge_dirs[:, 1] = torch.remainder(heedge_dirs[:, 1] + np.pi, 2 * np.pi) - np.pi
heedge_attr = torch.concat([he_store["edge_attr"][:, :1], heedge_dirs], dim=-1)

In [119]:
# Store the direction-augmented edge attributes back on the normalized graph.
for edge_type, augmented_attr in (
    (("h", "to", "era"), heedge_attr),
    (("h", "to", "h"), hhedge_attr),
    (("era", "to", "h"), ehedge_attr),
):
    graphs_normed[edge_type]["edge_attr"] = augmented_attr

In [120]:
# Expect (n_encoder_edges, 3): weight plus (dlat, dlon).
graphs_normed[("era", "to", "h")]['edge_attr'].shape

torch.Size([62976, 3])

In [121]:
# Persist both variants: raw distances and normalized weights (+ directions).
output_dir = "/ec/res4/hpcperm/pamc/gnn/"
for graph_obj, file_name in (
    (graphs, "graph_mappings_o32.pt"),
    (graphs_normed, "graph_mappings_normed_edge_attrs_o32.pt"),
):
    torch.save(graph_obj, os.path.join(output_dir, file_name))

In [47]:
# NOTE(review): stale output from an earlier kernel session (In [47] after In [121]).
# Re-run or remove. Also assumes $HPCPERM/gnn is the same directory as the hard-coded output_dir.
!ls -l $HPCPERM/gnn

total 32768
-rw-r--r-- 1 pamc rd 19326587 Feb 12 15:35 graph_mappings_normed_edge_attrs_o32.pt
-rw-r--r-- 1 pamc rd 19326587 Feb 12 15:35 graph_mappings_o32.pt


In [122]:
# Confirm both graph files were written (assumes $HPCPERM/gnn matches output_dir).
!ls -l $HPCPERM/gnn

total 32768
-rw-r--r-- 1 pamc rd 21688123 Feb 12 20:37 graph_mappings_normed_edge_attrs_o32.pt
-rw-r--r-- 1 pamc rd 19326587 Feb 12 20:37 graph_mappings_o32.pt
