# Direct prediction based on MscaleFNO

In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES']='0'
from neuralop.models import TFNO,FNO
import torch
import torch.nn.functional as F
import numpy as np
torch.manual_seed(0)
np.random.seed(0)
import logging
import datetime
import glob
from utilities3 import *

In [None]:
def read_data(file_paths, inlen=16, outlen=16):
    """Load and window the ``eta`` and ``scatter`` fields of several ``.mat`` files.

    For every file the last axis of both fields is cropped to indices
    117..436 (320 entries) and cut into ten consecutive windows of 32 entries.
    The windows are stacked along axis 0, window by window, and the results of
    all files are then concatenated along axis 0 in the order of
    ``file_paths``.

    Args:
        file_paths: Paths handed one by one to ``MatReader``.
        inlen: Accepted for call compatibility; not used by this function.
        outlen: Accepted for call compatibility; not used by this function.
            NOTE(review): the window length is fixed at 32 and does not follow
            ``inlen + outlen`` -- confirm that this is intended.

    Returns:
        Tuple ``(eta, scatter, x, y, t)``. ``x``, ``y`` and ``t`` are the
        flattened coordinate fields of the *last* file read, with ``t`` cut to
        its first 32 entries.

    Raises:
        ValueError: If ``file_paths`` yields no path at all.
    """
    crop = slice(117, 437)
    window = 32
    n_windows = 10

    # Collect every window first and concatenate once at the end. The old
    # ``if eta == []`` test compared a NumPy array with a list as soon as a
    # second file was read, which is ambiguous (or an error) in NumPy.
    eta_parts = []
    scatter_parts = []
    x = y = t = None
    for f in file_paths:
        reader = MatReader(f, to_torch=False)
        eta_tmp = reader.read_field('eta')[:, :, :, crop]
        scatter_tmp = reader.read_field('scatter')[:, :, :, crop]
        eta_parts.extend(
            eta_tmp[:, :, :, window * i:window * (i + 1)] for i in range(n_windows)
        )
        scatter_parts.extend(
            scatter_tmp[:, :, :, window * i:window * (i + 1)] for i in range(n_windows)
        )
        x = reader.read_field('x').flatten()
        y = reader.read_field('y').flatten()
        t = reader.read_field('t').flatten()[0:window]

    if not eta_parts:
        raise ValueError("read_data() needs at least one file path")

    eta = np.concatenate(eta_parts, axis=0)
    scatter = np.concatenate(scatter_parts, axis=0)
    return eta, scatter, x, y, t

def split_data(eta, scatter, train_test_split=(0.6, 0.2, 0.2), inlen=16, outlen=16, sub=2, to_torch=False):
    """Split ``eta``/``scatter`` into train, validation and test sets.

    Samples are split along axis 0 in their existing order (no shuffling).
    Axes 1 and 2 are subsampled with stride ``sub``. Along axis 3 the first
    ``inlen`` entries form the inputs and the following ``outlen`` entries the
    targets.

    Args:
        eta: 4-D array; supplies the ``*_e`` inputs and the ``*_u`` targets.
        scatter: 4-D array indexed like ``eta``; supplies the ``*_a`` inputs.
        train_test_split: Fractions ``(train, val, test)`` of axis 0. The
            train and validation sizes are ``int(n * fraction)``; the test set
            receives every remaining sample, so the third entry is only
            informative.
        inlen: Number of leading axis-3 entries used as input.
        outlen: Number of axis-3 entries after ``inlen`` used as target.
        sub: Subsampling stride for axes 1 and 2.
        to_torch: When true, every returned array is wrapped with
            ``torch.from_numpy``.

    Returns:
        Tuple ``(train_a, train_e, train_u, val_a, val_e, val_u,
        test_a, test_e, test_u)``.
    """
    n_samples = eta.shape[0]
    train_end = int(n_samples * train_test_split[0])
    # The validation size used to be taken from the *test* fraction, silently
    # ignoring train_test_split[1]; for equal val/test fractions (the default)
    # the result is unchanged.
    val_end = train_end + int(n_samples * train_test_split[1])

    row_ranges = (
        slice(None, train_end),
        slice(train_end, val_end),
        slice(val_end, None),
    )
    past = slice(None, inlen)
    future = slice(inlen, inlen + outlen)

    pieces = []
    for rows in row_ranges:
        pieces.append(scatter[rows, ::sub, ::sub, past])
        pieces.append(eta[rows, ::sub, ::sub, past])
        pieces.append(eta[rows, ::sub, ::sub, future])

    if to_torch:
        pieces = [torch.from_numpy(piece) for piece in pieces]
    return tuple(pieces)

# Read and split the data: every data file is split into train, val and test
# on its own and the per-file splits are then concatenated. The hs and spread
# tags parsed from each file name are expanded to the length of that file's
# test array.
def _parse_spread_and_hs(file_path):
    """Return ``(spread, hs)`` parsed from a name shaped like ``...angle<spread>h<hs>.mat``."""
    tag = file_path.split('angle')[1]
    spread = float(tag.split('h')[0])
    hs = float(tag.split('h')[1].split('.mat')[0])
    return spread, hs


def read_and_split(file_paths, train_test_split=(0.6, 0.2, 0.2), inlen=16, outlen=16, sub=2, to_torch=False):
    """Read every file, split each one separately, and concatenate the splits.

    Args:
        file_paths: Sequence of ``.mat`` paths; each name must contain
            ``angle<spread>h<hs>.mat``.
        train_test_split: Forwarded to ``split_data``.
        inlen: Forwarded to ``read_data`` and ``split_data``.
        outlen: Forwarded to ``read_data`` and ``split_data``.
        sub: Forwarded to ``split_data``.
        to_torch: When true, all returned arrays except ``x``, ``y`` and ``t``
            are converted with ``torch.from_numpy``.

    Returns:
        Tuple ``(train_a, train_e, train_u, val_a, val_e, val_u, test_a,
        test_e, test_u, x, y, t, hs, spread)``. ``x``, ``y`` and ``t`` come
        from the last file; ``hs`` and ``spread`` hold one entry per test
        sample.

    Raises:
        ValueError: If ``file_paths`` is empty.
    """
    if len(file_paths) == 0:
        raise ValueError("read_and_split() needs at least one file path")

    split_parts = [[] for _ in range(9)]
    hs_parts = []
    spread_parts = []
    x = y = t = None
    for f in file_paths:
        eta, scatter, x, y, t = read_data([f], inlen=inlen, outlen=outlen)
        # Always split in NumPy and convert once at the very end. Converting
        # inside split_data as well made the final torch.from_numpy fail for a
        # single input file, because it then received tensors.
        splits = split_data(eta, scatter, train_test_split=train_test_split,
                            inlen=inlen, outlen=outlen, sub=sub, to_torch=False)
        for bucket, part in zip(split_parts, splits):
            bucket.append(part)

        n_test = splits[6].shape[0]  # splits[6] is this file's test_a
        spread_value, hs_value = _parse_spread_and_hs(f)
        hs_parts.append(hs_value * np.ones(n_test))
        spread_parts.append(spread_value * np.ones(n_test))

    arrays = [np.concatenate(bucket) for bucket in split_parts]
    hs = np.concatenate(hs_parts)
    spread = np.concatenate(spread_parts)

    if to_torch:
        arrays = [torch.from_numpy(array) for array in arrays]
        hs = torch.from_numpy(hs)
        spread = torch.from_numpy(spread)

    return (*arrays, x, y, t, hs, spread)

# Collect the data files of both dataset folders; the order inside each folder
# is whatever os.listdir returns.
dirs = [os.path.join("./mixed_data", entry) for entry in os.listdir("./mixed_data")]
# second data directory
dirs2 = [os.path.join("./mixed_data2", entry) for entry in os.listdir("./mixed_data2")]
dirs = [*dirs, *dirs2]

# Load everything as torch tensors, already split into train / val / test.
(
    train_a, train_e, train_u,
    val_a, val_e, val_u,
    test_a, test_e, test_u,
    x_domain, y_domain, t_domain,
    hs, spread,
) = read_and_split(dirs, inlen=16, outlen=16, sub=2, to_torch=True)

In [None]:
# parameter and logging setup
modesxy = 32  # Fourier modes for both spatial axes (passed twice to FNO_dir)
modest = 8  # Fourier modes for the third (time) axis
width = 64  # hidden channel width handed to the model
n_net = 1  # number of FNO sub-networks; 1 selects a plain FNO in FNO_dir
cnnfusion = False  # True selects the CNN-fusion multi-scale variant
learning_rate = 0.001
epochs = 500
batch_size = 20
# The scheduler is stepped once per mini-batch and the DataLoader keeps a
# partial last batch, so count batches with a ceiling division. A floor would
# make T_max too small and the cosine schedule would start rising again.
iterations = epochs * ((train_u.shape[0] + batch_size - 1) // batch_size)
prefix = "dirpred"
path = 'exp/'+prefix+ '_m' + str(modesxy) + str(modest) + '_w' + str(width)+ '_n' + str(n_net)+ '_cnnfusion'+ str(cnnfusion) 
# makedirs also creates the parent 'exp' folder, which os.mkdir would not.
os.makedirs(path, exist_ok=True)
path_model = path+'/model'
path_train_err = path+'/train.txt'
path_test_err = path+'/test.txt'
path_image = path+'/image'
os.makedirs(path_image, exist_ok=True)

logger = logging.getLogger("logger")
logger.setLevel(logging.INFO)
# Re-running this cell returns the same logger object; drop the handlers of
# the previous run so every message is emitted only once.
for stale_handler in list(logger.handlers):
    logger.removeHandler(stale_handler)
    stale_handler.close()
sh = logging.StreamHandler()
log_file = path+'/train'+ datetime.datetime.now().strftime('%Y%m%d_%H%M%S')+'.log'
fh = logging.FileHandler(log_file,encoding="UTF-8")
formator = logging.Formatter(fmt = '%(asctime)s : %(message)s')
sh.setFormatter(formator)
fh.setFormatter(formator)
logger.addHandler(sh)
logger.addHandler(fh)

In [None]:
# Grid and window sizes shared by the reshapes below and the model cells.
sub = 2
S = 128 // sub
T_in = 16
T = 16

ntrain, nval, ntest = train_a.shape[0], val_a.shape[0], test_a.shape[0]

# Insert a singleton axis in front of the T_in axis and tile it T times, so
# each input becomes [n, S, S, T, T_in].
_time_tiling = [1, 1, 1, T, 1]
train_a = train_a.reshape(ntrain, S, S, 1, T_in).repeat(_time_tiling)
val_a = val_a.reshape(nval, S, S, 1, T_in).repeat(_time_tiling)
test_a = test_a.reshape(ntest, S, S, 1, T_in).repeat(_time_tiling)

In [None]:
# multi-scale FNO: a weighted combination of N FNO sub-networks
class multiscaleFNO(nn.Module):
    """Weighted sum of ``n_net`` FNO sub-networks fed with rescaled inputs.

    Sub-network ``i`` receives ``cs[i] * x`` and its output is weighted by
    ``gammas[i]``. Both ``cs`` and ``gammas`` are trainable; ``cs`` starts at
    ``2 ** (i - 1)`` and every ``gammas`` entry at ``1 / n_net``. The number of
    input channels is ``T_in + 3``, with ``T_in`` read from module scope.
    """

    def __init__(self,modes1,modes2,modes3,width,n_net):
        super().__init__()
        self.modes1, self.modes2, self.modes3 = modes1, modes2, modes3
        self.width = width
        self.n_net = n_net

        subnets = []
        for _ in range(n_net):
            subnets.append(
                FNO(
                    n_modes=(self.modes1, self.modes2, self.modes3),
                    n_layers=4,
                    hidden_channels=self.width,
                    lifting_channels=2 * self.width,
                    projection_channels=2 * self.width,
                    in_channels=T_in + 3,
                    out_channels=1,
                )
            )
        self.fno_subnets = nn.ModuleList(subnets)

        # initial input scales: 2**-1, 2**0, ..., 2**(n_net - 2)
        exponents = torch.arange(-1, n_net - 1, dtype=torch.float32)
        self.cs = torch.nn.Parameter(torch.pow(2, exponents), requires_grad=True)
        # initial output weights: uniform average over the sub-networks
        self.gammas = torch.nn.Parameter(torch.tensor([1.0 / n_net] * n_net), requires_grad=True)

    def forward(self,x):
        """Return ``sum_i gammas[i] * fno_subnets[i](cs[i] * x)``."""
        return sum(
            self.gammas[k] * self.fno_subnets[k](self.cs[k] * x)
            for k in range(self.n_net)
        )

# test case, input shape [batch, channel, S, S, T_in]
#input = torch.randn(10, 19, 32, 32, 16)
#model = multiscaleFNO(8, 8, 8, 32, n_net=4)
#output = model(input)
#print("input shape: ", input.shape)
#print("output shape: ", output.shape)

In [None]:
class multiscaleFNOcnnFusion(nn.Module):
    """Multi-scale FNO whose sub-network outputs are merged by a small 3-D CNN.

    Sub-network ``i`` receives ``cs[i] * x`` (``cs`` is trainable and starts at
    ``2 ** (i - 1)``). The outputs are stacked as channels and passed through
    ``fusion_net``, which ends in a single-channel 1x1x1 convolution. The
    number of input channels is ``T_in + 3``, with ``T_in`` read from module
    scope.
    """

    def __init__(self, modes1, modes2, modes3, width, n_net):
        super().__init__()
        self.modes1, self.modes2, self.modes3 = modes1, modes2, modes3
        self.width = width
        self.n_net = n_net

        subnets = []
        for _ in range(n_net):
            subnets.append(
                FNO(
                    n_modes=(self.modes1, self.modes2, self.modes3),
                    n_layers=4,
                    hidden_channels=self.width,
                    lifting_channels=2 * self.width,
                    projection_channels=2 * self.width,
                    in_channels=T_in + 3,
                    out_channels=1,
                )
            )
        self.fno_subnets = nn.ModuleList(subnets)

        exponents = torch.arange(-1, n_net - 1, dtype=torch.float32)
        self.cs = torch.nn.Parameter(torch.pow(2, exponents), requires_grad=True)

        # Three conv / batch-norm / ReLU stages (n_net -> 32 -> 64 -> 32)
        # followed by a 1x1x1 convolution down to one channel.
        fusion_layers = []
        for c_in, c_out in ((n_net, 32), (32, 64), (64, 32)):
            fusion_layers.append(nn.Conv3d(c_in, c_out, kernel_size=3, padding=1))
            fusion_layers.append(nn.BatchNorm3d(c_out))
            fusion_layers.append(nn.ReLU())
        fusion_layers.append(nn.Conv3d(32, 1, kernel_size=1))
        self.fusion_net = nn.Sequential(*fusion_layers)

    def forward(self, x):
        """Run every sub-network on its scaled input and fuse the results."""
        # collect the output of each sub-network
        per_scale = [self.fno_subnets[k](self.cs[k] * x) for k in range(self.n_net)]

        # stack the outputs on a new axis 1, then drop the size-1 channel axis:
        # [batch, n_net, 1, x, y, t] -> [batch, n_net, x, y, t]
        stacked_outputs = torch.stack(per_scale, dim=1).squeeze(2)

        # fusion network; the result keeps its channel axis: [batch, 1, x, y, t]
        return self.fusion_net(stacked_outputs)

# test case, input shape [batch, channel, S, S, T_in]
#input = torch.randn(10, 19, 64, 64, 16)
#model = multiscaleFNOcnnFusion(8, 8, 8, 32, n_net=8)
#output = model(input)
#print("input shape: ", input.shape)
#print("output shape: ", output.shape)

In [9]:
class FNO_dir(nn.Module):
    """Direct-prediction wrapper: append coordinate channels, then run one backbone.

    The backbone stored in ``fno_r`` is chosen at construction time:
    ``multiscaleFNOcnnFusion`` when ``cnnfusion`` is true, a plain ``FNO`` when
    ``net == 1``, and ``multiscaleFNO`` otherwise. The plain ``FNO`` takes
    ``T_in + 3`` input channels, with ``T_in`` read from module scope.
    """

    def __init__(self, modes1, modes2, modes3, width,net,cnnfusion=False):
        super().__init__()
        self.modes1, self.modes2, self.modes3 = modes1, modes2, modes3
        self.width = width

        if cnnfusion:
            backbone = multiscaleFNOcnnFusion(self.modes1, self.modes2, self.modes3, self.width, n_net=net)
        elif net == 1:
            backbone = FNO(
                n_modes=(self.modes1, self.modes2, self.modes3),
                n_layers=4,
                hidden_channels=self.width,
                lifting_channels=2 * self.width,
                projection_channels=2 * self.width,
                in_channels=T_in + 3,
                out_channels=1,
            )
        else:
            backbone = multiscaleFNO(self.modes1, self.modes2, self.modes3, self.width, n_net=net)
        self.fno_r = backbone

    def forward(self,x):
        """Append the coordinate grid to ``x`` and apply the backbone."""
        return self.fno_r(self.get_grid(x))

    def get_grid(self,x):
        """Append x/y/t coordinate channels and move channels to axis 1.

        Input is ``[batch, x, y, t, c]`` and output is ``[batch, c+3, x, y, t]``.
        Each coordinate runs from 0 to the extent of the module-level
        ``x_domain``, ``y_domain`` or ``t_domain``, sampled at as many points
        as the matching axis of ``x`` has.
        """
        batchsize, size_x, size_y, size_z = x.shape[0], x.shape[1], x.shape[2], x.shape[3]
        full_shape = (batchsize, size_x, size_y, size_z, 1)

        def axis_coords(domain, n_points, axis):
            # 1-D coordinates along `axis`, broadcast to one channel of full_shape
            coords = torch.tensor(np.linspace(0, domain[-1] - domain[0], n_points), dtype=torch.float)
            view_shape = [1, 1, 1, 1, 1]
            view_shape[axis] = n_points
            return coords.reshape(view_shape).expand(full_shape)

        gridx = axis_coords(x_domain, size_x, 1)
        gridy = axis_coords(y_domain, size_y, 2)
        gridz = axis_coords(t_domain, size_z, 3)
        # the grid is assembled on the CPU and then moved next to x
        grid = torch.cat((gridx, gridy, gridz), dim=-1).to(x.device)
        return torch.cat((x, grid), dim=-1).permute(0, 4, 1, 2, 3)

In [None]:
def ssp_loss(pred, true, ax=(1, 2)):
    """Spectral mismatch between ``pred`` and ``true``, summed to a scalar.

    Both tensors are Fourier-transformed over the axes ``ax``. For every
    remaining index the norm of the spectral difference is divided by the sum
    of the two spectral norms. The ratios are averaged over the last remaining
    axis and the averages are summed.
    NOTE(review): "ssp" presumably means surface similarity parameter --
    confirm. A slice where both spectra are zero divides 0 by 0.

    Args:
        pred: Predicted field.
        true: Reference field with the same shape as ``pred``.
        ax: Axes to transform and to take the norms over.

    Returns:
        Zero-dimensional tensor holding the summed score.
    """
    spec_pred, spec_true = (torch.fft.fftn(field, dim=ax) for field in (pred, true))
    misfit = torch.norm(spec_pred - spec_true, dim=ax)
    scale = torch.norm(spec_pred, dim=ax) + torch.norm(spec_true, dim=ax)
    ratios = misfit / scale
    return ratios.mean(dim=-1).sum()

# Build the model on the GPU, then the optimizer, LR schedule, loss and loaders.
device = torch.device('cuda')
model = FNO_dir(modesxy, modesxy, modest, width, n_net, cnnfusion).to(device)  # change
logger.info("models params: %d", count_params(model))

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
# cosine annealing over `iterations` scheduler steps
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=iterations)
myloss = LpLoss(size_average=False)

# scatter-based inputs (`*_a`) paired with the eta targets (`*_u`)
train_set = torch.utils.data.TensorDataset(train_a, train_u)
val_set = torch.utils.data.TensorDataset(val_a, val_u)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=False)

2025-05-08 02:14:16,418 : models params: 134244097


In [None]:
from timeit import default_timer

# Training loop: optimise the relative-L2 loss, step the LR scheduler once per
# mini-batch, evaluate on the validation loader after every epoch and keep the
# checkpoint with the lowest validation error.
gamma = 1.0
best_err = 1e10
for ep in range(epochs):
    model.train()
    t1 = default_timer()
    train_mse = 0
    train_l2 = 0
    for x, y in train_loader:
        x, y = x.cuda(), y.cuda()
        # Use the actual batch size: the last batch is smaller whenever the
        # number of samples is not a multiple of batch_size.
        n_batch = x.shape[0]
        optimizer.zero_grad()
        out = model(x).view(n_batch, S, S, T)

        # MSE is tracked for logging only; the gradient comes from l2.
        mse = F.mse_loss(out, y, reduction='mean')

        l2 = myloss(out.view(n_batch, -1), y.view(n_batch, -1))
        l2.backward()

        optimizer.step()
        scheduler.step()
        train_mse += mse.item()
        train_l2 += l2.item()

    model.eval()
    test_l2 = 0.0

    with torch.no_grad():
        for x, y in val_loader:
            x, y = x.cuda(), y.cuda()
            n_batch = x.shape[0]

            out = model(x).view(n_batch, S, S, T)
            test_l2 += myloss(out.view(n_batch, -1), y.view(n_batch, -1)).item()

    train_mse /= len(train_loader)
    # NOTE(review): myloss presumably sums over the batch (size_average=False),
    # hence the division by the number of samples -- confirm in utilities3.
    train_l2 /= ntrain
    # This value is measured on val_loader, so normalise by the validation set
    # size (it used to be divided by ntest). The log label stays 'test_l2'.
    test_l2 /= nval

    t2 = default_timer()
    logger.info('epoch:%d, time:%f, train_mse:%f, train_l2:%f,  test_l2:%f'%
               (ep, t2-t1, train_mse, train_l2,test_l2))
    if test_l2<best_err:
        best_err = test_l2
        # The whole module is pickled; the inference cell reloads it with torch.load.
        torch.save(model, path_model)
        logger.info('best model saved')

#torch.save(model, path_model)

2025-05-08 02:14:27,093 : epoch:0, time:10.659605, train_mse:0.653252, train_l2:1.167531,  test_l2:0.631481
2025-05-08 02:14:27,644 : best model saved
2025-05-08 02:14:37,069 : epoch:1, time:9.423625, train_mse:0.188966, train_l2:0.523786,  test_l2:0.458060
2025-05-08 02:14:37,706 : best model saved
2025-05-08 02:14:47,049 : epoch:2, time:9.342071, train_mse:0.116644, train_l2:0.425006,  test_l2:0.411647
2025-05-08 02:14:47,860 : best model saved
2025-05-08 02:14:57,055 : epoch:3, time:9.193546, train_mse:0.096686, train_l2:0.387420,  test_l2:0.400775
2025-05-08 02:14:57,786 : best model saved
2025-05-08 02:15:06,947 : epoch:4, time:9.157879, train_mse:0.085154, train_l2:0.366656,  test_l2:0.400742
2025-05-08 02:15:07,601 : best model saved
2025-05-08 02:15:16,767 : epoch:5, time:9.164533, train_mse:0.077212, train_l2:0.349248,  test_l2:0.391694
2025-05-08 02:15:17,519 : best model saved
2025-05-08 02:15:26,779 : epoch:6, time:9.257969, train_mse:0.067210, train_l2:0.324102,  test_l2:0

In [12]:
# Evaluate the best checkpoint on the test set one sample at a time, log the
# per-sample relative L2 error and store predictions next to the targets.
pred = torch.zeros(test_u.shape)
index = 0
# NOTE(security): torch.load unpickles a complete module object; only load
# checkpoints written by this notebook.
model = torch.load(path_model).cuda()
# Make inference mode explicit instead of relying on the flag stored in the
# checkpoint (this matters for the BatchNorm layers of the CNN-fusion variant).
model.eval()
test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(test_a, test_u), batch_size=1, shuffle=False)
with torch.no_grad():
    for x, y in test_loader:
        x, y  = x.cuda(), y.cuda()

        out = model(x).view(1, S, S, T)
        pred[index] = out

        # error of this single sample (not a running total)
        test_l2 = myloss(out.view(1, -1), y.view(1, -1)).item()
        logger.info('%d    %f' % (index,test_l2))
        index = index + 1

scipy.io.savemat(path+'/pred.mat', mdict={'pred': pred.cpu().numpy(), 'test_u': test_u.cpu().numpy()})

2025-05-08 03:40:19,554 : 0    0.359190
2025-05-08 03:40:19,571 : 1    0.265943
2025-05-08 03:40:19,581 : 2    0.371688
2025-05-08 03:40:19,590 : 3    0.541948
2025-05-08 03:40:19,599 : 4    0.321139
2025-05-08 03:40:19,608 : 5    0.306439
2025-05-08 03:40:19,673 : 6    0.539004
2025-05-08 03:40:19,687 : 7    0.382147
2025-05-08 03:40:19,697 : 8    0.332345
2025-05-08 03:40:19,713 : 9    0.364134
2025-05-08 03:40:19,722 : 10    0.438590
2025-05-08 03:40:19,739 : 11    0.295619
2025-05-08 03:40:19,749 : 12    0.417132
2025-05-08 03:40:19,758 : 13    0.403903
2025-05-08 03:40:19,767 : 14    0.430318
2025-05-08 03:40:19,776 : 15    0.403275
2025-05-08 03:40:19,805 : 16    0.413278
2025-05-08 03:40:19,848 : 17    0.404600
2025-05-08 03:40:19,857 : 18    0.461607
2025-05-08 03:40:19,874 : 19    0.338014
2025-05-08 03:40:19,883 : 20    0.331858
2025-05-08 03:40:19,899 : 21    0.340784
2025-05-08 03:40:19,908 : 22    0.480059
2025-05-08 03:40:19,917 : 23    0.297184
2025-05-08 03:40:19,926 : 