In [31]:
import torch
import torch.nn as nn
# Sizes for a quick smoke test of the Phylift layer defined below.
bsz, n = 50, 2000                      # batch size, points per sample
in_channel = 200                       # total input channels (physical + others)
phy_in_channel, phy_out_channel = 3, 100
out_channel = 128

# Random lifting centres and weights (one row per lifted channel), then a random
# input batch. They are drawn in this order from the global torch RNG.
liftpts = torch.rand((phy_out_channel, phy_in_channel))
liftweight = torch.rand((phy_out_channel, phy_in_channel))
x = torch.rand((bsz, n, in_channel))
class Phylift(nn.Module):
    """Lift the leading "physical" channels through Gaussian bases, then mix linearly.

    The first ``phy_in_channel`` channels of the input are replaced by
    ``P = liftpts.shape[0]`` features

        phi_p(x) = exp(-sum_d liftweight[p, d] * (x_d - liftpts[p, d]) ** 2),

    The remaining ``in_channel - phy_in_channel`` channels are passed through
    unchanged. The concatenation is then projected to ``out_channel`` features by
    a linear layer.

    Args:
        phy_in_channel: number of leading input channels treated as physical inputs.
        in_channel: total number of input channels (physical + others).
        out_channel: number of output features.
        liftpts: tensor of shape (P, phy_in_channel) holding the basis centres.
        liftweight: tensor of shape (P, phy_in_channel), or anything broadcastable
            to it, holding the per-basis weights.

    Raises:
        ValueError: if ``liftpts`` is not 2-D or its second dimension differs from
            ``phy_in_channel``.

    ``liftpts`` and ``liftweight`` are registered as non-persistent buffers. They
    follow ``.to(device)`` / ``.double()`` together with the module, are not
    trained, and are not written to ``state_dict``.
    """

    def __init__(self, phy_in_channel, in_channel, out_channel, liftpts, liftweight):
        super().__init__()
        liftpts = torch.as_tensor(liftpts)
        liftweight = torch.as_tensor(liftweight)
        if liftpts.dim() != 2:
            raise ValueError(
                f"liftpts must be 2-D (phy_out_channel, phy_in_channel), got shape {tuple(liftpts.shape)}"
            )
        if liftpts.shape[1] != phy_in_channel:
            # compute_bases subtracts liftpts from the first phy_in_channel input
            # channels. A mismatch would either fail there or, for a size-1 dim,
            # broadcast silently.
            raise ValueError(
                f"liftpts.shape[1] ({liftpts.shape[1]}) must equal phy_in_channel ({phy_in_channel})"
            )
        self.phy_in_channel = phy_in_channel
        self.phy_out_channel = liftpts.shape[0]
        self.dim = liftpts.shape[1]
        self.in_channel = in_channel
        self.out_channel = out_channel
        # Use buffers, not plain attributes, so module-wide .to()/.double()/.cuda() also
        # apply to them. persistent=False keeps the state_dict keys as before.
        self.register_buffer("liftpts", liftpts, persistent=False)  # shape: phy_out_channel, dim
        self.register_buffer("liftweight", liftweight, persistent=False)  # shape: phy_out_channel, dim

        # fc input = P lifted features + (in_channel - phy_in_channel) passthrough channels
        self.fc = nn.Linear(self.phy_out_channel - phy_in_channel + in_channel, out_channel)

    def forward(self, x):
        """Lift the physical channels of ``x`` and project to ``out_channel`` features.

        Args:
            x: tensor (bsz, n, in_channel); the first ``phy_in_channel`` channels are lifted.

        Returns:
            Tensor of shape (bsz, n, out_channel).
        """
        x_phy_in = x[:, :, :self.phy_in_channel]
        x_phy_out = self.compute_bases(x_phy_in)
        # lifted features first, then the untouched non-physical channels
        x = torch.cat((x_phy_out, x[:, :, self.phy_in_channel:]), dim=2)
        x = self.fc(x)
        return x

    def compute_bases(self, x_phy_in):
        """Evaluate every Gaussian basis at every point.

        Args:
            x_phy_in: tensor (bsz, n, phy_in_channel).

        Returns:
            Tensor of shape (bsz, n, phy_out_channel).
        """
        x_phy_in = x_phy_in.unsqueeze(2)  # bsz,n,1,phy_in_channel
        liftpts = self.liftpts.unsqueeze(0).unsqueeze(0)  # 1,1,phy_out_channel,phy_in_channel
        liftweight = self.liftweight.unsqueeze(0).unsqueeze(0)  # 1,1,phy_out_channel,phy_in_channel
        # bsz,n,phy_out_channel,phy_in_channel --sum over last dim--> bsz,n,phy_out_channel
        x_phy_out = torch.exp(-1 * torch.sum(liftweight * (x_phy_in - liftpts) ** 2, dim=3))
        return x_phy_out
# Build the lifting layer, plus a plain linear layer of the same in/out width
# (kept for comparison; only `Phy` is applied below).
Phy = Phylift(phy_in_channel, in_channel, out_channel, liftpts=liftpts, liftweight=liftweight)
linear = nn.Linear(in_features=in_channel, out_features=out_channel)

# Alternative for comparison: x = linear(x)
x = Phy(x)
print(x.shape)

torch.Size([50, 2000, 128])


In [None]:
import random
import torch
import sys
import numpy as np
import math
import matplotlib.pyplot as plt
from timeit import default_timer
from scipy.io import loadmat
import yaml
import gc

sys.path.append("../")


from models import  PhyHGkNN_train, compute_2dFourier_bases, compute_2dpca_bases, compute_2dFourier_cbases, count_params

from models.PhyHGkNN import PhyHGkNN

torch.set_printoptions(precision=16)


# Seed every RNG this cell imports (Python `random`, torch, numpy) so runs are reproducible.
random.seed(0)
torch.manual_seed(0)
np.random.seed(0)


###################################
# load configs
###################################
with open('config.yml', 'r', encoding='utf-8') as f:
    config = yaml.full_load(f)

# Settings for this experiment live under the "Darcy_HGkNN" key,
# split into data / model / train sections.
config = dict(config["Darcy_HGkNN"])
config_data, config_model, config_train = (
    config["data"],
    config["model"],
    config["train"],
)
downsample_ratio = config_data["downsample_ratio"]
L = config_data["L"]
n_train = config_data["n_train"]
n_test = config_data["n_test"]
device = torch.device(config_train["device"])


###################################
# load data
###################################
# data_path = "../data/darcy_2d/piececonst_r421_N1024_smooth1"
# data1 = loadmat(data_path)
# coeff1 = data1["coeff"]
# sol1 = data1["sol"]
# del data1
# data_path = "../data/darcy_2d/piececonst_r421_N1024_smooth2"
# data2 = loadmat(data_path)
# coeff2 = data2["coeff"][:300,:,:]
# sol2 = data2["sol"][:300,:,:]
# del data2
# gc.collect()

# data_in = np.vstack((coeff1, coeff2))  # shape: 2048,421,421
# data_out = np.vstack((sol1, sol2))     # shape: 2048,421,421

data_path = "../data/darcy_2d/piececonst_r421_N1024_smooth1"
data1 = loadmat(data_path)

data_in = data1["coeff"]
data_out = data1["sol"]

print("data_in.shape:" , data_in.shape)
print("data_out.shape", data_out.shape)

# Training uses the first n_train samples and testing the last n_test. If they add
# up to more than the file holds, the two sets overlap and test data leaks into training.
n_samples = data_in.shape[0]
if n_train + n_test > n_samples:
    raise ValueError(
        f"n_train ({n_train}) + n_test ({n_test}) exceeds the {n_samples} samples in "
        f"{data_path}; train and test sets would overlap"
    )
# Use an explicit start index rather than `-n_test:`, which would select *all*
# samples when n_test == 0.
test_start = n_samples - n_test

Np_ref = data_in.shape[1]
grid_1d = np.linspace(0, L, Np_ref)
grid_x, grid_y = np.meshgrid(grid_1d, grid_1d)

# Keep every downsample_ratio-th grid point along both spatial axes.
grid_x_ds = grid_x[0::downsample_ratio, 0::downsample_ratio]
grid_y_ds = grid_y[0::downsample_ratio, 0::downsample_ratio]
data_in_ds = data_in[0:n_train, 0::downsample_ratio, 0::downsample_ratio]
data_out_ds = data_out[0:n_train, 0::downsample_ratio, 0::downsample_ratio]


def _with_grid(coeff):
    """Stack each subsampled coefficient field with the subsampled x/y grid.

    Returns a float32 tensor of shape [n_fields, n_x, n_y, 3] with channels
    (coefficient, x, y).
    """
    n_fields = coeff.shape[0]
    return torch.from_numpy(
        np.stack(
            (
                coeff,
                np.tile(grid_x_ds, (n_fields, 1, 1)),
                np.tile(grid_y_ds, (n_fields, 1, 1)),
            ),
            axis=-1,
        ).astype(np.float32)
    )


# x_* are [n_data, n_x, n_y, 3] (coefficient, x, y); y_* are [n_data, n_x, n_y, 1] (solution).
x_train = _with_grid(data_in_ds)
y_train = torch.from_numpy(data_out_ds[:, :, :, np.newaxis].astype(np.float32))
x_test = _with_grid(data_in[test_start:, 0::downsample_ratio, 0::downsample_ratio])
y_test = torch.from_numpy(
    data_out[test_start:, 0::downsample_ratio, 0::downsample_ratio, np.newaxis].astype(
        np.float32
    )
)

# Flatten the 2-D grid into a point cloud: [n_data, n_x * n_y, n_channel].
# e.g. n_train=800 on a 421x421 grid with downsample_ratio=4 gives 800, 11236, 3
# (11236 = 106*106, 106 = ceil(421 / 4)).
x_train = x_train.reshape(x_train.shape[0], -1, x_train.shape[-1])
x_test = x_test.reshape(x_test.shape[0], -1, x_test.shape[-1])
y_train = y_train.reshape(y_train.shape[0], -1, y_train.shape[-1])
y_test = y_test.reshape(y_test.shape[0], -1, y_test.shape[-1])
print("x_train.shape: ",x_train.shape)
print("y_train.shape: ",y_train.shape)




####################################
#compute pca bases
####################################
# PCA needs enough modes for both the input and the output bases.
k_max = max(config_model["GkNN_mode_in"], config_model["GkNN_mode_out"])
# Points per side after `0::downsample_ratio` slicing, i.e. ceil(Np_ref / downsample_ratio).
Np = -(-Np_ref // downsample_ratio)
# One row per training sample: each subsampled 2-D field flattened into a vector.
pca_data_in, pca_data_out = (
    field.reshape(field.shape[0], -1) for field in (data_in_ds, data_out_ds)
)
# if config_model["pca_include_input"]:
#     pca_data = np.vstack(
#         (pca_data, data_in_ds.reshape((data_in_ds.shape[0], -1)))
#     )
# if config_model["pca_include_grid"]:
#     n_grid = 1
#     pca_data = np.vstack((pca_data, np.tile(grid_x_ds, (n_grid, 1))))
#     pca_data = np.vstack((pca_data, np.tile(grid_y_ds, (n_grid, 1))))

# percentage = 0.1
# mask1 = torch.rand(pca_data_in.shape) > percentage
# mask2 = torch.rand(pca_data_out.shape) > percentage
# pca_data_in = (torch.from_numpy(pca_data_in)*mask1).numpy()
# pca_data_out = (torch.from_numpy(pca_data_out)*mask2).numpy()


print("Start SVD with data shape: ", pca_data_out.shape, flush = True)

# PCA bases of the input (coefficient) fields and the accompanying `wbases`
# tensor returned by compute_2dpca_bases, both moved to the training device.
bases_pca_in, wbases_pca_in = (
    tensor.to(device) for tensor in compute_2dpca_bases(Np, k_max, L, pca_data_in)
)

# The same for the output (solution) fields.
bases_pca_out, wbases_pca_out = (
    tensor.to(device) for tensor in compute_2dpca_bases(Np, k_max, L, pca_data_out)
)




###################################
#compute kernel bases
###################################

# Set to 0; not referenced anywhere else in this cell.
H_in = H_out = 0


# Bases handed to PhyHGkNN: the output-field PCA bases, their `wbases` companion,
# and two 0 placeholders (what those last two slots mean is defined inside
# PhyHGkNN, not here — confirm there).
bases_list = [bases_pca_out, wbases_pca_out] + [0, 0]
###################################
#construct model and train
###################################
model = PhyHGkNN(bases_list, **config_model).to(device)





In [4]:
import torch
def uniform_points(num_pts, dim, range_pts):
    """Return a regular tensor-product grid of cell-centred points.

    Each axis ``k`` is split into ``a`` equal cells over
    ``[range_pts[k][0], range_pts[k][1]]``, and the cell centres are used.
    ``a`` is the largest integer with ``a ** dim <= num_pts``. The result
    therefore holds ``a ** dim`` points, which is fewer than ``num_pts`` when
    ``num_pts`` is not a perfect ``dim``-th power. Points are ordered with the
    first coordinate varying slowest.

    Args:
        num_pts: requested (maximum) number of points.
        dim: spatial dimension.
        range_pts: sequence of ``dim`` ``[min, max]`` pairs, one per axis.

    Returns:
        Float tensor of shape ``(a ** dim, dim)``.
    """
    num_pts = int(num_pts)
    # Exact integer dim-th root. A float root can land just below an exact
    # integer (e.g. 3.9999998), and int() would truncate it. float32 also cannot
    # represent every integer above 2**24. So correct the float estimate with
    # exact integer arithmetic.
    a = int(round(num_pts ** (1.0 / dim)))
    while a > 0 and a ** dim > num_pts:
        a -= 1
    while (a + 1) ** dim <= num_pts:
        a += 1

    axes = []
    for k in range(dim):
        xmin, xmax = range_pts[k][0], range_pts[k][1]
        # centres of `a` equal cells on [xmin, xmax]
        centres = xmin + (xmax - xmin) * torch.arange(a).float().add(0.5).div(a)
        # put this axis on dimension k so it broadcasts across the others
        shape = [1] * dim
        shape[k] = -1
        axes.append(centres.view(shape).expand(*([a] * dim)))
    return torch.stack(axes, dim=dim).reshape(a ** dim, dim)
# 16 cell-centred points on the square [-1, 1] x [-1, 1] (a 4 x 4 grid).
x = uniform_points(16, 2, [[-1, 1]] * 2)
print(x.shape, x)

torch.Size([16, 2]) tensor([[-0.7500, -0.7500],
        [-0.7500, -0.2500],
        [-0.7500,  0.2500],
        [-0.7500,  0.7500],
        [-0.2500, -0.7500],
        [-0.2500, -0.2500],
        [-0.2500,  0.2500],
        [-0.2500,  0.7500],
        [ 0.2500, -0.7500],
        [ 0.2500, -0.2500],
        [ 0.2500,  0.2500],
        [ 0.2500,  0.7500],
        [ 0.7500, -0.7500],
        [ 0.7500, -0.2500],
        [ 0.7500,  0.2500],
        [ 0.7500,  0.7500]])


In [67]:
import torch
from timeit import default_timer
# def fast_mul(x,wbases,base_mask):
#     bsz = x.shape[0]
#     channel_in = x.shape[1]
#     modes_in = wbases.shape[-1]
#     t4 = default_timer()
#     x_hat = torch.zeros(bsz,channel_in,modes_in).to(x.device)
#     t5 = default_timer()
#     k=0
#     x_mask = x[:,:,base_mask[:,k]]
#     t6 = default_timer()
#     wbases_mask = wbases[:,base_mask[:,k],k]
#     t7 = default_timer()
#     x_hat[:,:,k] = torch.einsum("bcx,bx->bc", x_mask, wbases_mask)
#     t8 = default_timer()
#     print(t5-t4,t6-t5,t7-t6,t8-t7)
    
#     return x_hat

# Sizes for the (commented-out) fast_mul benchmark: batch, channels, modes, points.
b = 30
c = 100
k = 1000
n = 50000
# device = 'cuda'
# x = torch.rand(b,c,n).to(device)
# wbases = torch.rand(b,n,k).to(device)
# base_mask = torch.zeros(n,k).to(device)
# base_mask[:3,:] = 1
# base_mask = base_mask.bool()
# t1 = default_timer()
# x_hat = torch.einsum("bcx,bxk->bck", x, wbases)
# t2 = default_timer()
# x_hat = fast_mul(x,wbases,base_mask)
# t3 = default_timer()
# print(t2-t1)
# print(t3-t2)

# Time (1) allocating a (b, c, k) zero tensor on the CPU and (2) copying it to the GPU.
has_cuda = torch.cuda.is_available()
if has_cuda:
    # CUDA is initialised lazily on first use. Warm it up here so the copy
    # timing below does not include context creation.
    torch.zeros(1, device="cuda")
    torch.cuda.synchronize()

t = default_timer()
x_hat = torch.zeros(b,c,k)
t_1 = default_timer()
if has_cuda:
    x_hat = x_hat.to('cuda')
    # CUDA calls can return before the device work has finished; wait before reading the clock.
    torch.cuda.synchronize()
    t_2 = default_timer()
print(t_1-t)
if has_cuda:
    print(t_2-t_1)
else:
    print("CUDA not available; skipping host-to-device copy timing")

0.001279299998714123
0.004285800001525786
