In [1]:
import healpy as hp
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
from scripts.utils.cheby_shev import ChebConv

In [2]:
nside = 512
order = 4

In [3]:
def nside2index(nside, order):
    nsample = 12 * order**2
    indexes = np.arange(hp.nside2npix(nside) // nsample)
    return indexes

In [4]:
def patch_index_wneighbor(nside, order):
    indexes = list(nside2index(nside, order))
    neighbors = hp.pixelfunc.get_all_neighbours(nside, indexes, nest=True)
    neighbors = np.unique(neighbors.reshape(-1))
    nei_patch = np.setdiff1d(neighbors, indexes)
    return nei_patch

In [5]:
l_nei_patch = {}
for i in range(np.log2(nside).astype(int)):
    tmp_patch_size = hp.nside2npix(nside)//hp.nside2npix(2**i)
    l_nei_patch[tmp_patch_size] = len(patch_index_wneighbor(nside, 2**i))

In [9]:
print(hp.nside2npix(nside)//hp.nside2npix(2**8))

4


In [6]:
print(l_nei_patch)

{262144: 2051, 65536: 1028, 16384: 516, 4096: 260, 1024: 132, 256: 68, 64: 36, 16: 20, 4: 12}


In [None]:
class SphericalChebConv_pad(nn.Module):
    def __init__(self, in_channels, out_channels, lap, pad_size, kernel_size):
        super().__init__()
        self.register_buffer("laplacian", lap)
        self.pad_size = pad_size
        self.paddding = nn.ConstantPad1d(padding=(0,pad_size), value=0)
        self.chebconv = ChebConv(in_channels, out_channels, kernel_size)

    def state_dict(self, *args, **kwargs):
        state_dict = super().state_dict(*args, **kwargs)
        del_keys = []
        for key in state_dict:
            if key.endswith("laplacian"):
                del_keys.append(key)
        for key in del_keys:
            del state_dict[key]
        return state_dict

    def forward(self, x):
        x = self.paddding(x)
        x = self.chebconv(self.laplacian, x)
        x = x[:, :-self.pad_size, :]
        return x