In [4]:
import torch
import numpy as np
import torch.nn.functional as F
import torch.nn as nn
import matplotlib.pyplot as plt
from utilities import *
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import NNConv
from timeit import default_timer

# Seed the global torch and NumPy random number generators so results are
# reproducible across Restart & Run All.
torch.manual_seed(0)
np.random.seed(0)

# Guard flag for `if COMMENTLINES:` blocks that only hold commented-out code;
# those blocks are skipped while this is False.
COMMENTLINES = False

In [2]:
class MGKN(torch.nn.Module):
    """Multipole Graph Kernel Network on a 1D uniform grid hierarchy.

    Components:
      * P      : ``fc1`` lifts node features from ``in_width`` to ``width`` channels.
      * K_ll   : ``conv_list[l]`` is a torch_geometric ``NNConv`` whose kernel network is
                 ``DenseNet([ker_in, w_l, w_l, width ** 2], ReLU)`` with
                 ``w_l = max(ker_width // 2**l, 16)``.
      * K_{l,l+1} / K_{l+1,l} : ``Upsample`` (nearest neighbour) / ``Downsample``
                 (average pooling) between consecutive levels, factor 2.
      * Q      : ``fc2`` -> ReLU -> ``fc3`` projects back to a single output channel.

    Args:
        width: number of hidden channels per node.
        ker_width: hidden width of the finest-level kernel network (and of ``fc2``).
        depth: number of down/up sweeps performed in ``forward``.
        ker_in: number of edge-attribute features fed to the kernel networks.
        in_width: number of input node features.
        s: number of nodes on the finest grid. The reshapes in ``Upsample`` /
           ``Downsample`` require level ``l`` to hold exactly ``s // 2**l`` nodes.
        level: index of the coarsest level; ``level + 1`` per-level kernels are built.
           Defaults to ``int(log2(s) - 1)`` (the original behaviour). Pass the same
           value used to build the graph hierarchy so kernels and levels line up.
    """

    def __init__(self, width, ker_width, depth, ker_in, in_width, s, level=None):
        super().__init__()
        self.depth = depth
        self.width = width
        self.s = s
        self.level = int(np.log2(s) - 1) if level is None else int(level)

        # P
        self.fc1 = torch.nn.Linear(in_width, width)

        # K_ll: one conv per level 0..self.level; the kernel network narrows with the
        # level index but keeps at least 16 hidden units. (Construction order is kept
        # identical to the original so seeded weight initialisation is unchanged.)
        convs = []
        for l in range(self.level + 1):
            ker_width_l = max(ker_width // (2 ** l), 16)
            kernel_l = DenseNet([ker_in, ker_width_l, ker_width_l, width ** 2], torch.nn.ReLU)
            convs.append(NNConv(width, width, kernel_l, aggr='mean'))
        self.conv_list = torch.nn.ModuleList(convs)

        # Q
        self.fc2 = torch.nn.Linear(width, ker_width)
        self.fc3 = torch.nn.Linear(ker_width, 1)

    # K_{l,l+1}
    def Upsample(self, x, channels, scale, s):
        """Nearest-neighbour upsampling of node features.

        Args:
            x: (s, channels) node features of the coarse level.
            channels: feature width.
            scale: integer upsampling factor.
            s: number of nodes in ``x``.

        Returns:
            (s * scale, channels) tensor; row ``i`` copies input row ``i // scale``.
        """
        x = x.transpose(0, 1).view(1, channels, s)  # (s, width) -> (1, width, s)
        # F.upsample is deprecated; F.interpolate is its drop-in replacement.
        x = F.interpolate(x, scale_factor=scale, mode='nearest')  # -> (1, width, s*scale)
        return x.view(channels, -1).transpose(0, 1)  # -> (s*scale, width)

    # K_{l+1,l}
    def Downsample(self, x, channels, scale, s):
        """Average-pool node features over non-overlapping windows of ``scale`` nodes.

        Args:
            x: (s, channels) node features of the fine level.
            channels: feature width.
            scale: pooling window (and stride).
            s: number of nodes in ``x``.

        Returns:
            (s // scale, channels) tensor.
        """
        x = x.transpose(0, 1).view(1, channels, s)  # (s, width) -> (1, width, s)
        x = F.avg_pool1d(x, kernel_size=scale)      # -> (1, width, s // scale)
        return x.view(channels, -1).transpose(0, 1)  # -> (s // scale, width)

    def forward(self, data):
        """Run ``depth`` down/up sweeps over the multilevel graph.

        Args:
            data: tuple ``(X_list, _, edge_index_list, edge_attr_list)``.
                Only ``X_list[0]`` (finest-level node features) is read; ``len(X_list)``
                sets the number of levels ``L``. Edge lists are indexed as
                ``[0]``: edges used on level 0 with ``conv_list[0]``;
                ``[l]`` (1 <= l < L): edges used on level ``l - 1`` with ``conv_list[l]``;
                ``[-1]``: edges used on the coarsest level ``L - 1``.

        Returns:
            Tensor of shape (s, 1).
        """
        X_list, _, edge_index_list, edge_attr_list = data
        num_levels = len(X_list)
        # The coarsest level pairs with conv_list[L], matching the conv_list[l] <->
        # edge_index_list[l] convention used in the upward sweep. The original always
        # took conv_list[-1], which is a different kernel whenever the model was built
        # with more levels than the data has. min() keeps that fallback when the
        # data has one level more than the model.
        coarsest_conv = self.conv_list[min(num_levels, len(self.conv_list) - 1)]

        x = self.fc1(X_list[0])
        phi = [None] * num_levels  # phi[l]: features entering level l on the way down
        for _ in range(self.depth):
            # downward: remember each level's features, then restrict to the next level
            for l in range(num_levels):
                phi[l] = x
                if l != num_levels - 1:
                    x = self.Downsample(x, channels=self.width, scale=2, s=self.s // (2 ** l))

            # coarsest level
            x = F.relu(x + coarsest_conv(phi[-1], edge_index_list[-1], edge_attr_list[-1]))

            # upward
            for l in reversed(range(num_levels)):
                if l != 0:
                    x = self.Upsample(x, channels=self.width, scale=2, s=self.s // (2 ** l))
                    # interactive neighbours of level l - 1
                    x = F.relu(x + self.conv_list[l](phi[l - 1], edge_index_list[l], edge_attr_list[l]))
                else:
                    # nearest neighbours of the finest level
                    x = F.relu(x + self.conv_list[0](phi[0], edge_index_list[0], edge_attr_list[0]))

        x = F.relu(self.fc2(x))
        return self.fc3(x)

In [18]:
def multi_pole_grid1d(theta_d, s,  N, is_periodic=False, level=3):
    """
    Build a 1D multipole hierarchy of uniform grids and their graph edges.

    Level ``l`` (1-based, ``l = 1..level``) uses stride ``r_l = 2**(l-1)`` relative to
    the finest grid and has ``s_l = s // r_l`` equispaced points on [0, 1].

    Args:
        theta_d: dim_domain (not used by this function; kept for interface compatibility).
        s: resolution of the finest grid.
        N: number of instances (not used by this function; kept for interface compatibility).
        is_periodic: if True, node indices wrap around on every level (ring topology).
        level: number of levels to build (was hard-coded to 3, still the default).

    Returns:
        dict with
          'grid_list': ``level`` float tensors; entry ``l - 1`` has shape (s_l,).
          'edge_index_list': ``level + 1`` LongTensors of shape (2, E):
              entry 0 = nearest-neighbour edges of level 1,
              entry l (1..level) = interactive-neighbour edges of level l.
              Indices are local to their own level; an empty edge set has shape (2, 0).

    NOTE(review): linspace includes the endpoint 1 on every level, so grid_l is not
    exactly grid_1[::r_l] (e.g. s=64: last coarse point is 1.0, finest point 62 is
    62/63). Confirm this is intended.
    """
    print("=========================< Multipole grid 1D >=========================")
    grid_list:              list[torch.Tensor]      = []
    edge_index_list:        list[torch.LongTensor]  = []

    print(f"multi_pole_grid1d: {level}")

    def ring_distance(i, j, n):
        # Index distance between nodes i and j on a level with n nodes; in the periodic
        # case, the shorter way around the ring (the old `% (s_l//2)` never wrapped).
        d = abs(i - j)
        return min(d, n - d) if is_periodic else d

    def candidates(i, n, max_offset):
        # Distinct valid nodes j != i with |dx| <= max_offset, in increasing-dx order.
        # De-duplication only matters on tiny periodic rings where offsets collide.
        seen = set()
        for dx in range(-max_offset, max_offset + 1):
            j = (i + dx) % n if is_periodic else i + dx
            if 0 <= j < n and j != i and j not in seen:
                seen.add(j)
                yield j

    def to_edge_index(pairs):
        # List of [src, dst] -> (2, E) LongTensor; reshape keeps an empty list at (2, 0)
        # instead of failing on transpose of a 1-D tensor.
        return torch.tensor(pairs, dtype=torch.long).reshape(-1, 2).transpose(0, 1)

    for l in range(1, level + 1):
        r_l = 2 ** (l - 1)            # stride w.r.t. the finest grid
        n_l = s_l = s // r_l          # number of nodes on this level
        n_parents = (s_l + 1) // 2    # nodes of the next-coarser level under x -> x // 2

        grid_list.append(torch.linspace(0, 1, steps=s_l, dtype=torch.float))

        # Finest level only: nearest neighbours (two for interior points, one at a
        # non-periodic boundary).
        if l == 1:
            edge_index_nn = [[x_i, x_j] for x_i in range(s_l) for x_j in candidates(x_i, s_l, 1)]
            edge_index_list.append(to_edge_index(edge_index_nn))

        # Interactive neighbours: not nearest neighbours themselves (distance > 1),
        # but their parents are nearest neighbours (parent distance <= 1).
        edge_index_inter = [
            [x_i, x_j]
            for x_i in range(s_l)
            for x_j in candidates(x_i, s_l, 3)
            if ring_distance(x_i, x_j, s_l) > 1
            and ring_distance(x_i // 2, x_j // 2, n_parents) <= 1
        ]
        edge_index_list.append(to_edge_index(edge_index_inter))

        if l == level:
            print(f"Level {l}")
            print(f"* (r_l, n_l)=({r_l}, {n_l})")
            print(f"* grid_list:        {[_grid.shape       for _grid       in grid_list]}")
            print(f"* edge_index_list:  {[_edge_index.shape for _edge_index in edge_index_list]}")
            print()

    print("=======================================================================")

    return {
        'grid_list'         : grid_list,
        'edge_index_list'   : edge_index_list
    }

In [19]:
d = multi_pole_grid1d(theta_d = 2, s = 64, N = 2)
grid_list = d['grid_list']

# Express every edge set in finest-grid node indices. Entry 0 holds the level-1
# nearest-neighbour edges and entry l (>= 1) the level-l interactive edges, whose
# nodes are finest nodes subsampled with stride r_l = 2**(l-1). The stride for
# entry `cnt` is therefore 2**max(cnt - 1, 0); scaling by 2**cnt pushed level-1
# indices up to 126 on a 64-node grid.
edge_index_list = [
    (2 ** max(cnt - 1, 0)) * _edge_index
    for cnt, _edge_index in enumerate(d['edge_index_list'])
]
assert all(int(_e.max()) < len(grid_list[0]) for _e in edge_index_list if _e.numel()), \
    "edge index out of range for the finest grid"

multi_pole_grid1d: 3
Level 3
* (r_l, n_l)=(4, 16)
* grid_list:        [torch.Size([64]), torch.Size([32]), torch.Size([16])]
* edge_index_list:  [torch.Size([2, 126]), torch.Size([2, 186]), torch.Size([2, 90]), torch.Size([2, 42])]



In [20]:
# Display the rescaled edge-index tensors built in the previous cell.
edge_index_list

[tensor([[ 0,  1,  1,  2,  2,  3,  3,  4,  4,  5,  5,  6,  6,  7,  7,  8,  8,  9,
           9, 10, 10, 11, 11, 12, 12, 13, 13, 14, 14, 15, 15, 16, 16, 17, 17, 18,
          18, 19, 19, 20, 20, 21, 21, 22, 22, 23, 23, 24, 24, 25, 25, 26, 26, 27,
          27, 28, 28, 29, 29, 30, 30, 31, 31, 32, 32, 33, 33, 34, 34, 35, 35, 36,
          36, 37, 37, 38, 38, 39, 39, 40, 40, 41, 41, 42, 42, 43, 43, 44, 44, 45,
          45, 46, 46, 47, 47, 48, 48, 49, 49, 50, 50, 51, 51, 52, 52, 53, 53, 54,
          54, 55, 55, 56, 56, 57, 57, 58, 58, 59, 59, 60, 60, 61, 61, 62, 62, 63],
         [ 1,  0,  2,  1,  3,  2,  4,  3,  5,  4,  6,  5,  7,  6,  8,  7,  9,  8,
          10,  9, 11, 10, 12, 11, 13, 12, 14, 13, 15, 14, 16, 15, 17, 16, 18, 17,
          19, 18, 20, 19, 21, 20, 22, 21, 23, 22, 24, 23, 25, 24, 26, 25, 27, 26,
          28, 27, 29, 28, 30, 29, 31, 30, 32, 31, 33, 32, 34, 33, 35, 34, 36, 35,
          37, 36, 38, 37, 39, 38, 40, 39, 41, 40, 42, 41, 43, 42, 44, 43, 45, 44,
          46, 4