In [3]:
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt


def get_neighbors(base_rows, base_cols, i, j, method, scale=1):
    """
    Return the in-bounds neighbors of node (i, j) on a base_rows x base_cols grid.

    Neighbors are offset from (i, j) by ``scale`` cells. Row offsets that fall
    outside ``[0, base_rows)`` are dropped (no wrap in the row direction), while
    column offsets wrap around modulo ``base_cols`` (periodic east-west).

    Params:
        base_rows (int): number of rows in the full grid.
        base_cols (int): number of columns in the full grid.
        i (int): row index of the node.
        j (int): column index of the node.
        method (str): neighbor pattern. Options:
            "grid": the 4 axis-aligned neighbors.
            "dense_grid": the 4 axis-aligned neighbors plus the 4 diagonals.
        scale (int): offset, in cells, between a node and its neighbors; must be >= 1.

    Returns:
        list of tuple: (row, col) neighbor pairs in the order
        (i-s, j), (i+s, j), (i, j-s), (i, j+s), followed for "dense_grid" by
        (i-s, j-s), (i-s, j+s), (i+s, j-s), (i+s, j+s), skipping out-of-range rows.

    Raises:
        ValueError: if ``scale`` is not a positive int, or ``method`` is unknown.
    """
    # Exact type check on purpose: rejects bool (a subclass of int) and floats.
    if type(scale) is not int:
        raise ValueError("scale must be int")
    # scale == 0 would connect every node to itself; negative scales are meaningless here.
    if scale < 1:
        raise ValueError("scale must be a positive int")
    # Previously an unknown method silently produced no neighbors (and thus no edges).
    if method not in ("grid", "dense_grid"):
        raise ValueError(
            f"Unknown neighbor method {method!r}; expected 'grid' or 'dense_grid'."
        )

    directions = [(-scale, 0), (scale, 0), (0, -scale), (0, scale)]
    if method == "dense_grid":
        directions += [(-scale, -scale), (-scale, scale), (scale, -scale), (scale, scale)]

    neighbors = []
    for di, dj in directions:
        neighbor_row = i + di
        if 0 <= neighbor_row < base_rows:
            # Columns are periodic (east-west wrap-around).
            neighbors.append((neighbor_row, (j + dj) % base_cols))
    return neighbors


def construct_adjacency_list_core(grid_size, method="grid", scales=None, origins=None, verbose=False):
    """
    Helper function: generates a directed edge list for a (multi-scale) grid of nodes.

    Node (i, j) is identified by the flat index ``i * cols + j``. For each
    (scale, origin) pair, every node ``(origin[0] + a*scale, origin[1] + b*scale)``
    inside the grid is connected to the neighbors returned by ``get_neighbors``
    at that scale. Edge lists from all scales are concatenated.

    NOTE: columns wrap around modulo ``cols``. If ``cols`` is not a multiple of
    a scale, a wrapped neighbor can land off that scale's sub-lattice.

    Params:
        grid_size (tuple of int): (rows, cols) of the full grid.
        method (str): base grid method. Options:
            "grid": connect each node to its 4 up-down-left-right neighbors.
            "dense_grid": connect each node to its 8 neighbors (including diagonals).
        scales (list of int): list of scales for generating downscaled grids. Each scale
            indicates the reduction factor. Defaults to [1].
        origins (list of tuple): list of origin tuples (row, col) for each downscaled grid.
            Should match the length of scales. Defaults to (0, 0) for every scale.
        verbose (bool): if True, print every visited node and every appended edge.

    Returns:
        np.ndarray: integer adjacency list of shape (2, num_edges); column k is the
        edge (source_index, target_index). Shape is (2, 0) when no edges are generated.

    Raises:
        ValueError: if ``origins`` is given without ``scales``, or their lengths differ.
    """
    rows, cols = grid_size

    if not scales:
        if origins:
            raise ValueError("Cannot set `origins` if your scale is 1")
        scales = [1]

    if not origins:
        # Tuples are immutable, so sharing one (0, 0) across entries is safe.
        origins = [(0, 0)] * len(scales)

    if len(scales) != len(origins):
        raise ValueError("`scales` and `origins` must have the same length.")

    adjacency_list = []
    for scale, origin in zip(scales, origins):
        for i in range(origin[0], rows, scale):
            for j in range(origin[1], cols, scale):
                curr_index = i * cols + j

                neighbors = get_neighbors(rows, cols, i, j, method, scale=scale)
                if verbose:
                    print(f"node {i,j} has neighbors {neighbors}")
                for neighbor_i, neighbor_j in neighbors:
                    neighbor_index = neighbor_i * cols + neighbor_j
                    if verbose:
                        print(f"appending {(curr_index, neighbor_index)}")
                    adjacency_list.append((curr_index, neighbor_index))

    # reshape(-1, 2) keeps the documented (2, num_edges) shape even when no edges
    # were produced; np.array([]).T alone would have shape (0,).
    return np.array(adjacency_list, dtype=int).reshape(-1, 2).T


def construct_adjacency_list(method):
    """
    Wrapper function for constructing an adjacency list on the 24 x 72 grid from a preset.

    Params:
        method (str): preset name. Options:
            "simple_grid": single-scale 4-neighbor grid.
            "simple_grid_dense": single-scale 8-neighbor grid (with diagonals).
            "multimesh1": 8-neighbor grids at scales 1, 3 and 6 with origins
                (0, 0), (1, 1) and (4, 1), concatenated into one edge list.

    Returns:
        np.ndarray: adjacency list of shape (2, num_edges) from construct_adjacency_list_core.

    Raises:
        NotImplementedError: if ``method`` is not one of the presets above.
    """
    grid_size = (24, 72)

    # Preset name -> keyword arguments for construct_adjacency_list_core.
    presets = {
        "simple_grid": dict(method="grid", scales=[1], origins=[(0, 0)]),
        "simple_grid_dense": dict(method="dense_grid", scales=[1], origins=[(0, 0)]),
        "multimesh1": dict(
            method="dense_grid", scales=[1, 3, 6], origins=[(0, 0), (1, 1), (4, 1)]
        ),
    }

    try:
        preset_kwargs = presets[method]
    except (KeyError, TypeError):  # TypeError: unhashable `method` is also "not implemented"
        raise NotImplementedError(
            f"Adjacency method {method!r} has not been implemented! "
            f"Current settings: {', '.join(presets)}"
        ) from None

    return construct_adjacency_list_core(grid_size, **preset_kwargs)

In [None]:
# NOTE(review): this binds `data` to the xarray *module* itself, not to a dataset
# or DataArray. The next cell treats `data` as a 24x72 grid of values in which NaN
# marks invalid nodes, so it will not work until this loads the real field
# (e.g. something like `xr.open_dataset(path)[variable]`) — TODO confirm the
# intended source file and variable.
data = xr

In [4]:

# Build the single-scale 8-neighbor grid and plot it over the valid (non-NaN) nodes of `data`.
adjacency_list = construct_adjacency_list("simple_grid_dense")

# "simple_grid_dense" builds one grid at scale 1; `scale` only sets the edge line width below.
scale = 1

# Generate coordinates for each node in the grid; node k = row * cols + col
rows, cols = 24, 72
y, x = np.meshgrid(range(rows), range(cols), indexing='ij')
node_coords = np.vstack([x.ravel(), y.ravel()]).T  # Shape (num_nodes, 2), columns (x=col, y=row)

# Identify valid nodes (not NaN). np.asarray accepts any array-like (NumPy or xarray).
node_values = np.asarray(data).ravel()
if node_values.size != rows * cols:
    raise ValueError(
        f"`data` must hold {rows * cols} values (a {rows}x{cols} grid), got {node_values.size}"
    )
valid_nodes = ~np.isnan(node_values)

# Palette for colouring multiple scales; not used by the single-scale plot below.
colors = plt.cm.Spectral(np.linspace(0, 1, 10))

# adjacency_list has shape (2, num_edges); transpose to one (source, target) row per edge.
scale_edges = adjacency_list.T

# Keep only edges whose endpoints are both valid nodes
valid_edges = scale_edges[
    valid_nodes[scale_edges[:, 0]] & valid_nodes[scale_edges[:, 1]]
]

# Plot all edges in one call: given 2-D x/y arrays, plt.plot draws each column as a
# separate line, i.e. one segment per edge, instead of one plt.plot call per edge.
# NOTE: east-west wrap-around edges (last column <-> first column) span the whole plot.
edge_x = node_coords[valid_edges, 0].T  # shape (2, num_valid_edges)
edge_y = node_coords[valid_edges, 1].T
plt.plot(edge_x, edge_y, color="tab:blue", alpha=0.6, linewidth=scale / 3)

# Plot nodes
plt.scatter(node_coords[valid_nodes, 0], node_coords[valid_nodes, 1], c='k', s=3, zorder=3)




array([[   0,    0,    0, ..., 1727, 1727, 1727],
       [  72,   71,    1, ..., 1656, 1654, 1584]])