In [None]:
import os

import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.pyplot as plt
import numpy as np
from pyproj import Geod
import torch

### Functions

In [None]:
class GeodesicPath:
    """Sample points along geodesics so that graph edges can be drawn as curves."""

    def __init__(self):
        # Geodesic calculator on the WGS84 ellipsoid.
        self.geod = Geod(ellps='WGS84')

    def geodesic_path_adaptive(
        self,
        lon1,
        lat1,
        lon2,
        lat2,
        points_per_degree=2,
    ):
        """Return a geodesic from (lon1, lat1) to (lon2, lat2) as sampled points.

        The number of intermediate points grows with the angular separation of
        the end points (largest of the longitude and latitude differences, in
        degrees), with a minimum of two.

        Args:
            lon1, lat1: Longitude and latitude of the start point, in degrees.
            lon2, lat2: Longitude and latitude of the end point, in degrees.
            points_per_degree: Intermediate points requested per degree of
                separation.

        Returns:
            A ``(lons, lats)`` pair of lists. Both start with the first end
            point, continue with the points returned by ``Geod.npts`` and end
            with the second end point.
        """
        # Longitude is periodic: measure the difference the short way round so
        # that e.g. 359 -> 1 counts as 2 degrees rather than 358 (which would
        # request hundreds of unnecessary points for a short edge).
        lon_diff = abs(lon2 - lon1) % 360
        lon_diff = min(lon_diff, 360 - lon_diff)
        dist_deg = max(lon_diff, abs(lat2 - lat1))
        n_points = max(2, int(dist_deg * points_per_degree))
        points = self.geod.npts(lon1, lat1, lon2, lat2, n_points)
        if not points:
            # Nothing in between: fall back to a straight two-point segment.
            return [lon1, lon2], [lat1, lat2]
        # The end points are added explicitly around the intermediate points.
        lons = [lon1] + [p[0] for p in points] + [lon2]
        lats = [lat1] + [p[1] for p in points] + [lat2]
        return lons, lats

In [None]:
def get_nodes(graph, key):
    """Return the coordinates stored under ``graph[key]['x']`` scaled to degrees.

    The first two columns are multiplied by ``180 / pi`` (presumably a
    radians-to-degrees conversion; confirm against the graph builder). The
    second column is additionally wrapped into ``[0, 360)``. The plotting
    functions in this notebook read column 0 as latitude and column 1 as
    longitude.

    Returns:
        A NumPy array of shape ``(n_nodes, 2)``.
    """
    scaled = graph[key]['x'].numpy() * 180 / np.pi
    first_column = scaled[:, 0]
    second_column = np.mod(scaled[:, 1], 360)
    return np.column_stack((first_column, second_column))

In [None]:
def get_edges(graph, key_source, key_target):
    """Return the ``edge_index`` of the ``(key_source, 'to', key_target)`` relation as a NumPy array."""
    relation = (key_source, 'to', key_target)
    edge_store = graph[relation]
    return edge_store['edge_index'].numpy()

In [None]:
def compute_target_connectivity(num_nodes, edges):
    """Count, for every node index, how many edges have it as their target.

    Args:
        num_nodes: Length of the returned array.
        edges: Two-row array; the second row holds the target index of each
            edge.

    Returns:
        Integer array of length ``num_nodes`` with the number of incoming
        edges per node.
    """
    incoming = np.zeros(num_nodes, dtype=int)
    _, targets = edges
    # Unbuffered add so that a target appearing several times is counted
    # once per occurrence.
    np.add.at(incoming, targets, 1)
    return incoming

In [None]:
def get_example_targets(num_nodes, edges, seed=31415):
    """Pick one random target node for every distinct incoming-edge count.

    Nodes are grouped by their number of incoming edges (including zero), and
    one node is drawn from each group with a generator seeded by ``seed``, so
    the selection is reproducible.

    Returns:
        List of node indices, ordered by increasing incoming-edge count.
    """
    rng = np.random.default_rng(seed=seed)
    connectivity = compute_target_connectivity(num_nodes, edges)
    picks = []
    for degree in np.unique(connectivity):
        # Indices of all nodes sharing this incoming-edge count.
        candidates = np.flatnonzero(connectivity == degree)
        picks.append(rng.permutation(candidates)[0])
    return picks

In [None]:
def plot_data_hidden_graph(
    nodes_data,
    nodes_hidden,
    edges_data_hidden,
    edges_hidden_hidden,
    the_target,
    figsize=(12, 9),
    figname=None,
):
    """Plot the data nodes connected to one hidden node on an orthographic globe.

    The view is centred on ``nodes_hidden[the_target]``. Every hidden -> hidden
    edge is drawn in grey as background, every data -> hidden edge ending at
    ``the_target`` in blue, all data nodes as small red markers and the sources
    of the blue edges as larger green markers.

    Args:
        nodes_data: Data node coordinates, column 0 latitude and column 1
            longitude.
        nodes_hidden: Hidden node coordinates, same layout.
        edges_data_hidden: Two-row array of (source, target) indices, sources
            indexing ``nodes_data`` and targets ``nodes_hidden``.
        edges_hidden_hidden: Two-row array of (source, target) indices into
            ``nodes_hidden``.
        the_target: Index of the hidden node to highlight.
        figsize: Figure size passed to ``plt.figure``.
        figname: If given, the figure is saved there and closed; otherwise it
            is shown.
    """
    path_builder = GeodesicPath()
    centre_lat, centre_lon = nodes_hidden[the_target]
    projection = ccrs.Orthographic(centre_lon, centre_lat)

    plt.figure(figsize=figsize)
    ax = plt.axes(projection=projection)
    ax.set_global()
    ax.coastlines(linewidth=0.6)
    ax.add_feature(cfeature.LAND, edgecolor='black', facecolor='lightgray')
    ax.gridlines(linewidth=0.3, linestyle='--', alpha=0.5)

    def draw_edge(start, end, **line_style):
        # start/end are (lat, lon) rows; the path builder wants lon first.
        lons, lats = path_builder.geodesic_path_adaptive(
            start[1], start[0], end[1], end[0],
        )
        ax.plot(lons, lats, transform=ccrs.Geodetic(), **line_style)

    # Background: the whole hidden mesh.
    for src, dst in edges_hidden_hidden.T:
        draw_edge(
            nodes_hidden[src], nodes_hidden[dst],
            color='grey', alpha=0.5, zorder=1,
        )

    # Foreground: only the edges that end at the selected hidden node.
    sources = [src for src, dst in edges_data_hidden.T if dst == the_target]
    for src in sources:
        draw_edge(
            nodes_data[src], nodes_hidden[the_target],
            color='b', alpha=0.7, zorder=2,
        )

    marker_style = dict(
        marker='o',
        edgecolor='black',
        linewidth=0.3,
        transform=ccrs.PlateCarree(),
        alpha=0.9,
    )
    ax.scatter(
        nodes_data[:, 1], nodes_data[:, 0],
        color='r', s=10, zorder=3, **marker_style,
    )
    ax.scatter(
        nodes_data[sources, 1], nodes_data[sources, 0],
        color='g', s=30, zorder=4, **marker_style,
    )

    if figname is None:
        plt.show()
        return
    plt.savefig(figname)
    plt.close()

In [None]:
def plot_hidden_graph(
    nodes_hidden,
    edges_multi_scale,
    edges_single_scale,
    the_target,
    figsize=(12, 9),
    figname=None,
):
    """Plot the hidden nodes connected to one hidden node of the same set.

    The view is centred on ``nodes_hidden[the_target]``. Every edge of
    ``edges_single_scale`` is drawn in grey as background, every edge of
    ``edges_multi_scale`` ending at ``the_target`` in blue, and the sources of
    the blue edges as green markers.

    Args:
        nodes_hidden: Node coordinates, column 0 latitude and column 1
            longitude.
        edges_multi_scale: Two-row array of (source, target) indices into
            ``nodes_hidden``; only edges into ``the_target`` are drawn.
        edges_single_scale: Two-row array of (source, target) indices into
            ``nodes_hidden``; all of them are drawn.
        the_target: Index of the node to highlight.
        figsize: Figure size passed to ``plt.figure``.
        figname: If given, the figure is saved there and closed; otherwise it
            is shown.
    """
    path_builder = GeodesicPath()
    centre_lat, centre_lon = nodes_hidden[the_target]
    projection = ccrs.Orthographic(centre_lon, centre_lat)

    plt.figure(figsize=figsize)
    ax = plt.axes(projection=projection)
    ax.set_global()
    ax.coastlines(linewidth=0.6)
    ax.add_feature(cfeature.LAND, edgecolor='black', facecolor='lightgray')
    ax.gridlines(linewidth=0.3, linestyle='--', alpha=0.5)

    def draw_edge(start, end, **line_style):
        # start/end are (lat, lon) rows; the path builder wants lon first.
        lons, lats = path_builder.geodesic_path_adaptive(
            start[1], start[0], end[1], end[0],
        )
        ax.plot(lons, lats, transform=ccrs.Geodetic(), **line_style)

    # Background: every single-scale edge.
    for src, dst in edges_single_scale.T:
        draw_edge(
            nodes_hidden[src], nodes_hidden[dst],
            color='grey', alpha=0.5, zorder=1,
        )

    # Foreground: multi-scale edges that end at the selected node.
    sources = [src for src, dst in edges_multi_scale.T if dst == the_target]
    for src in sources:
        draw_edge(
            nodes_hidden[src], nodes_hidden[the_target],
            color='b', alpha=0.7, zorder=2,
        )

    ax.scatter(
        nodes_hidden[sources, 1],
        nodes_hidden[sources, 0],
        marker='o',
        color='g',
        s=30,
        edgecolor='black',
        linewidth=0.3,
        transform=ccrs.PlateCarree(),
        alpha=0.9,
        zorder=4,
    )

    if figname is None:
        plt.show()
        return
    plt.savefig(figname)
    plt.close()

In [None]:
def plot_fine_to_coarse_hidden_graph(
    nodes_fine,
    nodes_coarse,
    edges_fine,
    edges_coarse,
    edges,
    the_target,
    figsize=(12, 9),
    figname=None,
):
    """Plot the fine nodes connected to one coarse node.

    The view is centred on ``nodes_coarse[the_target]``. The fine edges are
    drawn in light grey, the coarse edges in darker grey on top of them, every
    fine -> coarse edge ending at ``the_target`` in blue, and the fine sources
    of the blue edges as green markers.

    Args:
        nodes_fine: Fine node coordinates, column 0 latitude and column 1
            longitude.
        nodes_coarse: Coarse node coordinates, same layout.
        edges_fine: Two-row array of (source, target) indices into
            ``nodes_fine``.
        edges_coarse: Two-row array of (source, target) indices into
            ``nodes_coarse``.
        edges: Two-row array of (source, target) indices, sources indexing
            ``nodes_fine`` and targets ``nodes_coarse``.
        the_target: Index of the coarse node to highlight.
        figsize: Figure size passed to ``plt.figure``.
        figname: If given, the figure is saved there and closed; otherwise it
            is shown.
    """
    path_builder = GeodesicPath()
    centre_lat, centre_lon = nodes_coarse[the_target]
    projection = ccrs.Orthographic(centre_lon, centre_lat)

    plt.figure(figsize=figsize)
    ax = plt.axes(projection=projection)
    ax.set_global()
    ax.coastlines(linewidth=0.6)
    ax.add_feature(cfeature.LAND, edgecolor='black', facecolor='lightgray')
    ax.gridlines(linewidth=0.3, linestyle='--', alpha=0.5)

    def draw_edge(start, end, **line_style):
        # start/end are (lat, lon) rows; the path builder wants lon first.
        lons, lats = path_builder.geodesic_path_adaptive(
            start[1], start[0], end[1], end[0],
        )
        ax.plot(lons, lats, transform=ccrs.Geodetic(), **line_style)

    # Background, bottom layer: the fine edges.
    for src, dst in edges_fine.T:
        draw_edge(
            nodes_fine[src], nodes_fine[dst],
            color=(0.7, 0.7, 0.7), zorder=1,
        )

    # Background, upper layer: the coarse edges.
    for src, dst in edges_coarse.T:
        draw_edge(
            nodes_coarse[src], nodes_coarse[dst],
            color=(0.4, 0.4, 0.4), zorder=1.5,
        )

    # Foreground: fine -> coarse edges that end at the selected coarse node.
    sources = [src for src, dst in edges.T if dst == the_target]
    for src in sources:
        draw_edge(
            nodes_fine[src], nodes_coarse[the_target],
            color='b', alpha=0.7, zorder=2,
        )

    ax.scatter(
        nodes_fine[sources, 1],
        nodes_fine[sources, 0],
        marker='o',
        color='g',
        s=30,
        edgecolor='black',
        linewidth=0.3,
        transform=ccrs.PlateCarree(),
        alpha=0.9,
        zorder=4,
    )

    if figname is None:
        plt.show()
        return
    plt.savefig(figname)
    plt.close()

In [None]:
def plot_coarse_to_fine_hidden_graph(
    nodes_fine,
    nodes_coarse,
    edges_fine,
    edges_coarse,
    edges,
    the_target,
    figsize=(12, 9),
    figname=None,
):
    """Plot the coarse nodes connected to one fine node.

    The view is centred on ``nodes_fine[the_target]``. The fine edges are drawn
    in light grey, the coarse edges in darker grey on top of them, every
    coarse -> fine edge ending at ``the_target`` in blue, and the coarse
    sources of the blue edges as green markers.

    Args:
        nodes_fine: Fine node coordinates, column 0 latitude and column 1
            longitude.
        nodes_coarse: Coarse node coordinates, same layout.
        edges_fine: Two-row array of (source, target) indices into
            ``nodes_fine``.
        edges_coarse: Two-row array of (source, target) indices into
            ``nodes_coarse``.
        edges: Two-row array of (source, target) indices, sources indexing
            ``nodes_coarse`` and targets ``nodes_fine``.
        the_target: Index of the fine node to highlight.
        figsize: Figure size passed to ``plt.figure``.
        figname: If given, the figure is saved there and closed; otherwise it
            is shown.
    """
    path_builder = GeodesicPath()
    centre_lat, centre_lon = nodes_fine[the_target]
    projection = ccrs.Orthographic(centre_lon, centre_lat)

    plt.figure(figsize=figsize)
    ax = plt.axes(projection=projection)
    ax.set_global()
    ax.coastlines(linewidth=0.6)
    ax.add_feature(cfeature.LAND, edgecolor='black', facecolor='lightgray')
    ax.gridlines(linewidth=0.3, linestyle='--', alpha=0.5)

    def draw_edge(start, end, **line_style):
        # start/end are (lat, lon) rows; the path builder wants lon first.
        lons, lats = path_builder.geodesic_path_adaptive(
            start[1], start[0], end[1], end[0],
        )
        ax.plot(lons, lats, transform=ccrs.Geodetic(), **line_style)

    # Background, bottom layer: the fine edges.
    for src, dst in edges_fine.T:
        draw_edge(
            nodes_fine[src], nodes_fine[dst],
            color=(0.7, 0.7, 0.7), zorder=1,
        )

    # Background, upper layer: the coarse edges.
    for src, dst in edges_coarse.T:
        draw_edge(
            nodes_coarse[src], nodes_coarse[dst],
            color=(0.4, 0.4, 0.4), zorder=1.5,
        )

    # Foreground: coarse -> fine edges that end at the selected fine node.
    sources = [src for src, dst in edges.T if dst == the_target]
    for src in sources:
        draw_edge(
            nodes_coarse[src], nodes_fine[the_target],
            color='b', alpha=0.7, zorder=2,
        )

    ax.scatter(
        nodes_coarse[sources, 1],
        nodes_coarse[sources, 0],
        marker='o',
        color='g',
        s=30,
        edgecolor='black',
        linewidth=0.3,
        transform=ccrs.PlateCarree(),
        alpha=0.9,
        zorder=4,
    )

    if figname is None:
        plt.show()
        return
    plt.savefig(figname)
    plt.close()

In [None]:
def plot_hidden_data_graph(
    nodes_data,
    nodes_hidden,
    edges_hidden_data,
    edges_hidden_hidden,
    the_target,
    figsize=(12, 9),
    figname=None,
):
    """Plot the hidden nodes connected to one data node on an orthographic globe.

    The view is centred on ``nodes_data[the_target]``. Every hidden -> hidden
    edge is drawn in grey as background, every hidden -> data edge ending at
    ``the_target`` in blue, all data nodes as small red markers and the hidden
    sources of the blue edges as larger green markers.

    Args:
        nodes_data: Data node coordinates, column 0 latitude and column 1
            longitude.
        nodes_hidden: Hidden node coordinates, same layout.
        edges_hidden_data: Two-row array of (source, target) indices, sources
            indexing ``nodes_hidden`` and targets ``nodes_data``.
        edges_hidden_hidden: Two-row array of (source, target) indices into
            ``nodes_hidden``.
        the_target: Index of the data node to highlight.
        figsize: Figure size passed to ``plt.figure``.
        figname: If given, the figure is saved there and closed; otherwise it
            is shown.
    """
    path_builder = GeodesicPath()
    centre_lat, centre_lon = nodes_data[the_target]
    projection = ccrs.Orthographic(centre_lon, centre_lat)

    plt.figure(figsize=figsize)
    ax = plt.axes(projection=projection)
    ax.set_global()
    ax.coastlines(linewidth=0.6)
    ax.add_feature(cfeature.LAND, edgecolor='black', facecolor='lightgray')
    ax.gridlines(linewidth=0.3, linestyle='--', alpha=0.5)

    def draw_edge(start, end, **line_style):
        # start/end are (lat, lon) rows; the path builder wants lon first.
        lons, lats = path_builder.geodesic_path_adaptive(
            start[1], start[0], end[1], end[0],
        )
        ax.plot(lons, lats, transform=ccrs.Geodetic(), **line_style)

    # Background: the whole hidden mesh.
    for src, dst in edges_hidden_hidden.T:
        draw_edge(
            nodes_hidden[src], nodes_hidden[dst],
            color='grey', alpha=0.5, zorder=1,
        )

    # Foreground: only the edges that end at the selected data node.
    sources = [src for src, dst in edges_hidden_data.T if dst == the_target]
    for src in sources:
        draw_edge(
            nodes_hidden[src], nodes_data[the_target],
            color='b', alpha=0.7, zorder=2,
        )

    marker_style = dict(
        marker='o',
        edgecolor='black',
        linewidth=0.3,
        transform=ccrs.PlateCarree(),
        alpha=0.9,
    )
    ax.scatter(
        nodes_data[:, 1], nodes_data[:, 0],
        color='r', s=15, zorder=3, **marker_style,
    )
    ax.scatter(
        nodes_hidden[sources, 1], nodes_hidden[sources, 0],
        color='g', s=30, zorder=4, **marker_style,
    )

    if figname is None:
        plt.show()
        return
    plt.savefig(figname)
    plt.close()

In [None]:
def plot_targets(
    name,
    targets,
    plot_function,
    *args,
    **kwargs,
):
    """Call ``plot_function`` once per target and save each figure as a PDF.

    Each figure is written to ``figs/{name}_target_{target}.pdf``; the path is
    printed before plotting. The output directory is created if it is missing.

    Args:
        name: Prefix of the output file names.
        targets: Iterable of target node indices.
        plot_function: Callable accepting ``*args`` followed by the keyword
            arguments ``the_target`` and ``figname``.
        *args: Positional arguments forwarded unchanged to ``plot_function``.
        **kwargs: Optional extra keyword arguments (e.g. ``figsize``)
            forwarded unchanged to ``plot_function``.
    """
    for target in targets:
        figname = f'figs/{name}_target_{target}.pdf'
        # The plot functions write straight to `figname`; make sure its
        # folder exists so a fresh checkout does not fail at save time.
        os.makedirs(os.path.dirname(figname), exist_ok=True)
        print(figname)
        plot_function(
            *args,
            the_target=target,
            figname=figname,
            **kwargs,
        )

### Retrieve graph from file

In [None]:
# Graph from which the "single_scale" hidden -> hidden edges are read below.
# NOTE(review): weights_only=False makes torch.load unpickle arbitrary Python
# objects, which can execute code; only load files from a trusted source.
graph_raw = torch.load('wdir/final_02.pt', weights_only=False)

In [None]:
# Graph from which the node coordinates and all other edge sets are read below.
# NOTE(review): weights_only=False makes torch.load unpickle arbitrary Python
# objects, which can execute code; only load files from a trusted source.
graph = torch.load('wdir/final_01.pt', weights_only=False)

In [None]:
# Node coordinates (converted by get_nodes) for the data nodes and the three hidden node sets.
nodes_data = get_nodes(graph, 'data')
nodes_hidden_3 = get_nodes(graph, 'hidden_3')
nodes_hidden_2 = get_nodes(graph, 'hidden_2')
nodes_hidden_1 = get_nodes(graph, 'hidden_1')
# Edges between the data nodes and hidden_3, in both directions.
edges_hidden_3_data = get_edges(graph, 'hidden_3', 'data')
edges_data_hidden_3 = get_edges(graph, 'data', 'hidden_3')
# Edges within each hidden node set, as stored in `graph`.
edges_hidden_3 = get_edges(graph, 'hidden_3', 'hidden_3')
edges_hidden_2 = get_edges(graph, 'hidden_2', 'hidden_2')
edges_hidden_1 = get_edges(graph, 'hidden_1', 'hidden_1')
# The same relations read from `graph_raw`; used below as background edges.
edges_hidden_3_single_scale = get_edges(graph_raw, 'hidden_3', 'hidden_3')
edges_hidden_2_single_scale = get_edges(graph_raw, 'hidden_2', 'hidden_2')
edges_hidden_1_single_scale = get_edges(graph_raw, 'hidden_1', 'hidden_1')
# Edges between consecutive hidden node sets: 3 -> 2, 2 -> 1, then 2 -> 3, 1 -> 2.
edges_hidden_3_2 = get_edges(graph, 'hidden_3', 'hidden_2')
edges_hidden_2_1 = get_edges(graph, 'hidden_2', 'hidden_1')
edges_hidden_2_3 = get_edges(graph, 'hidden_2', 'hidden_3')
edges_hidden_1_2 = get_edges(graph, 'hidden_1', 'hidden_2')

### Data -> hidden 3

In [None]:
# One hidden_3 node per distinct number of incoming data -> hidden_3 edges (fixed default seed).
examples = get_example_targets(len(nodes_hidden_3), edges_data_hidden_3)

In [None]:
# Saves figs/01_data_to_hidden_3_target_<index>.pdf for each example target.
# Positional arguments follow the parameter order of plot_data_hidden_graph.
plot_targets(
    '01_data_to_hidden_3',
    examples,
    plot_data_hidden_graph,
    nodes_data,
    nodes_hidden_3,
    edges_data_hidden_3,
    edges_hidden_3_single_scale,
)

### hidden 3 -> hidden 3

In [None]:
# One hidden_3 node per distinct number of incoming hidden_3 -> hidden_3 edges (fixed default seed).
examples = get_example_targets(len(nodes_hidden_3), edges_hidden_3)

In [None]:
# Saves figs/02_hidden_3_to_hidden_3_target_<index>.pdf for each example target.
# Positional arguments follow the parameter order of plot_hidden_graph:
# highlighted edges from `graph`, background edges from `graph_raw`.
plot_targets(
    '02_hidden_3_to_hidden_3',
    examples,
    plot_hidden_graph,
    nodes_hidden_3,
    edges_hidden_3,
    edges_hidden_3_single_scale,
)

### hidden 3 -> hidden 2

In [None]:
# One hidden_2 node per distinct number of incoming hidden_3 -> hidden_2 edges (fixed default seed).
examples = get_example_targets(len(nodes_hidden_2), edges_hidden_3_2)

In [None]:
# Saves figs/03_hidden_3_to_hidden_2_target_<index>.pdf for each example target.
# hidden_3 is passed as the fine node set and hidden_2 as the coarse one.
plot_targets(
    '03_hidden_3_to_hidden_2',
    examples,
    plot_fine_to_coarse_hidden_graph,
    nodes_hidden_3,
    nodes_hidden_2,
    edges_hidden_3_single_scale,
    edges_hidden_2_single_scale,
    edges_hidden_3_2,
)

### hidden 2 -> hidden 2

In [None]:
# One hidden_2 node per distinct number of incoming hidden_2 -> hidden_2 edges (fixed default seed).
examples = get_example_targets(len(nodes_hidden_2), edges_hidden_2)

In [None]:
# Saves figs/04_hidden_2_to_hidden_2_target_<index>.pdf for each example target.
# Positional arguments follow the parameter order of plot_hidden_graph:
# highlighted edges from `graph`, background edges from `graph_raw`.
plot_targets(
    '04_hidden_2_to_hidden_2',
    examples,
    plot_hidden_graph,
    nodes_hidden_2,
    edges_hidden_2,
    edges_hidden_2_single_scale,
)

### hidden 2 -> hidden 1

In [None]:
# One hidden_1 node per distinct number of incoming hidden_2 -> hidden_1 edges (fixed default seed).
examples = get_example_targets(len(nodes_hidden_1), edges_hidden_2_1)

In [None]:
# Saves figs/05_hidden_2_to_hidden_1_target_<index>.pdf for each example target.
# hidden_2 is passed as the fine node set and hidden_1 as the coarse one.
plot_targets(
    '05_hidden_2_to_hidden_1',
    examples,
    plot_fine_to_coarse_hidden_graph,
    nodes_hidden_2,
    nodes_hidden_1,
    edges_hidden_2_single_scale,
    edges_hidden_1_single_scale,
    edges_hidden_2_1,
)

### hidden 1 -> hidden 1

In [None]:
# One hidden_1 node per distinct number of incoming hidden_1 -> hidden_1 edges (fixed default seed).
examples = get_example_targets(len(nodes_hidden_1), edges_hidden_1)

In [None]:
# Saves figs/06_hidden_1_to_hidden_1_target_<index>.pdf for each example target.
# Positional arguments follow the parameter order of plot_hidden_graph:
# highlighted edges from `graph`, background edges from `graph_raw`.
plot_targets(
    '06_hidden_1_to_hidden_1',
    examples,
    plot_hidden_graph,
    nodes_hidden_1,
    edges_hidden_1,
    edges_hidden_1_single_scale,
)

### hidden 1 -> hidden 2

In [None]:
# One hidden_2 node per distinct number of incoming hidden_1 -> hidden_2 edges (fixed default seed).
examples = get_example_targets(len(nodes_hidden_2), edges_hidden_1_2)

In [None]:
# Saves figs/07_hidden_1_to_hidden_2_target_<index>.pdf for each example target.
# hidden_2 is passed as the fine node set (targets) and hidden_1 as the coarse one (sources).
plot_targets(
    '07_hidden_1_to_hidden_2',
    examples,
    plot_coarse_to_fine_hidden_graph,
    nodes_hidden_2,
    nodes_hidden_1,
    edges_hidden_2_single_scale,
    edges_hidden_1_single_scale,
    edges_hidden_1_2,
)

### hidden 2 -> hidden 3

In [None]:
# One hidden_3 node per distinct number of incoming hidden_2 -> hidden_3 edges (fixed default seed).
examples = get_example_targets(len(nodes_hidden_3), edges_hidden_2_3)

In [None]:
# Saves figs/08_hidden_2_to_hidden_3_target_<index>.pdf for each example target.
# hidden_3 is passed as the fine node set (targets) and hidden_2 as the coarse one (sources).
plot_targets(
    '08_hidden_2_to_hidden_3',
    examples,
    plot_coarse_to_fine_hidden_graph,
    nodes_hidden_3,
    nodes_hidden_2,
    edges_hidden_3_single_scale,
    edges_hidden_2_single_scale,
    edges_hidden_2_3,
)

### hidden 3 -> data

In [None]:
# One data node per distinct number of incoming hidden_3 -> data edges (fixed default seed).
examples = get_example_targets(len(nodes_data), edges_hidden_3_data)

In [None]:
# Saves figs/09_hidden_3_to_data_target_<index>.pdf for each example target.
# Positional arguments follow the parameter order of plot_hidden_data_graph.
plot_targets(
    '09_hidden_3_to_data',
    examples,
    plot_hidden_data_graph,
    nodes_data,
    nodes_hidden_3,
    edges_hidden_3_data,
    edges_hidden_3_single_scale,
)