In [6]:
import os
import matplotlib.pyplot as plt
import argparse
import scipy.io as scio
import numpy as np
import torch
from tqdm import *
from utils.testloss import TestLoss
from model import Transolver_Structured_Mesh_2D
from phi.torch.flow import *

In [7]:
# --- Model architecture -----------------------------------------------------
layer = 8        # forwarded to the model as n_layers
n_hidden = 64    # forwarded as n_hidden
heads = 4        # forwarded as n_head
mlp_ratio = 1
slice_num = 32
droupout = 0.0   # (sic) spelling kept: the model-construction cell reads this exact name
unified_pos = 0
ref = 8

# --- Optimisation -----------------------------------------------------------
lr = 0.001             # used both as the AdamW lr and as the OneCycleLR max_lr
weight_decay = 1e-5
max_grad_norm = None   # None => the training loop skips gradient clipping
batch_size = 2
epochs = 50

In [8]:
# Location of the dataset on disk and the stem used for checkpoint files.
# NOTE(review): this is an absolute, machine-specific path; adjust it when
# running the notebook elsewhere.
data_path = r"C:\\Users\\onurb\\master\\PRJ_4ID22_TP\\Transolver\\PDE-Solving-StandardBenchmark\\data\\ns_20_20.npy"
save_name = "buff"
# data_path = args.data_path + '/NavierStokes_V1e-5_N1200_T20.mat'
ntrain = 16   # samples taken from the start of the array for training
ntest = 4     # samples taken from the end of the array for testing
T_in = 10     # number of leading channels on the last axis fed to the model
T = 10        # number of following channels used as the prediction target
step = 2      # channels produced per model call: 2 since we have velx, vely

r = 1                          # spatial subsampling stride
h = int(((64 - 1) / r) + 1)    # grid side length after subsampling a 64-point axis

data = np.load(data_path)
print(data.shape)

# Fail early instead of silently producing overlapping splits or short targets.
assert data.shape[0] >= ntrain + ntest, "train and test splits would overlap"
assert data.shape[-1] >= T_in + T, "last axis is too short for input + target channels"


def _flatten_frames(samples, first, last):
    """Slice channels ``first:last`` and flatten the spatial grid.

    Args:
        samples: array indexed as (sample, row, col, channel).
        first: first channel index to keep (inclusive).
        last: last channel index to keep (exclusive).

    Returns:
        A torch tensor of shape (n_samples, rows * cols, last - first) built
        from the subsampled (stride ``r``) and cropped (``h`` x ``h``) grid.
    """
    frames = samples[:, ::r, ::r, first:last][:, :h, :h, :]
    frames = frames.reshape(frames.shape[0], -1, frames.shape[-1])
    return torch.from_numpy(frames)


# "a" holds the first T_in channels (model input); "u" holds the next T
# channels (prediction target).
train_a = _flatten_frames(data[:ntrain], 0, T_in)
train_u = _flatten_frames(data[:ntrain], T_in, T + T_in)

test_a = _flatten_frames(data[-ntest:], 0, T_in)
test_u = _flatten_frames(data[-ntest:], T_in, T + T_in)

print(train_a.shape)
print(train_u.shape)

print(test_a.shape)
print(test_u.shape)

# Coordinates of every grid node on the unit square, flattened in the same
# order as the frames and repeated once per sample.
x = np.linspace(0, 1, h)
y = np.linspace(0, 1, h)
x, y = np.meshgrid(x, y)
pos = np.c_[x.ravel(), y.ravel()]
pos = torch.tensor(pos, dtype=torch.float).unsqueeze(0)
pos_train = pos.repeat(ntrain, 1, 1)
pos_test = pos.repeat(ntest, 1, 1)

# Each batch yields (positions, input channels, target channels).
train_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(pos_train, train_a, train_u),
    batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(pos_test, test_a, test_u),
    batch_size=batch_size, shuffle=False)

(20, 64, 64, 40)
torch.Size([16, 4096, 10])
torch.Size([16, 4096, 10])
torch.Size([4, 4096, 10])
torch.Size([4, 4096, 10])


In [9]:
# Build the network and move it to the GPU. Every argument is passed by
# keyword, so the grouping below is purely for readability.
model = Transolver_Structured_Mesh_2D.Model(
    # grid settings
    space_dim=2,
    H=h,
    W=h,
    # input / output channels
    fun_dim=T_in,
    out_dim=2,  # output dimension is 2 since we calculate a velocity field
    Time_Input=False,
    # network size and regularisation
    n_layers=layer,
    n_hidden=n_hidden,
    n_head=heads,
    mlp_ratio=mlp_ratio,
    slice_num=slice_num,
    dropout=droupout,
    # remaining hyperparameters from the config cell
    ref=ref,
    unified_pos=unified_pos,
).cuda()

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=lr,
    weight_decay=weight_decay,
)

# Show the layer structure of the freshly built model.
print(model)

def count_parameters(model):
    """Print and return the number of trainable parameters of ``model``.

    Only parameters with ``requires_grad`` set are counted; frozen ones are
    skipped.

    Args:
        model: object exposing ``named_parameters()`` (e.g. a torch module).

    Returns:
        The total element count over all trainable parameters.
    """
    total_params = sum(
        weights.numel()
        for _, weights in model.named_parameters()
        if weights.requires_grad
    )
    print(f"Total Trainable Params: {total_params}")
    return total_params

# Report how many trainable parameters the model has.
count_parameters(model)

# One-cycle learning-rate schedule sized for the full run; the training loop
# advances it once per batch.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=lr,
    epochs=epochs,
    steps_per_epoch=len(train_loader),
)

# Loss used for both training and evaluation.
# NOTE(review): size_average=False presumably sums over the batch instead of
# averaging (the epoch summary divides by the sample count) - confirm in
# utils.testloss.
myloss = TestLoss(size_average=False)


Model(
  (preprocess): MLP(
    (linear_pre): Sequential(
      (0): Linear(in_features=12, out_features=128, bias=True)
      (1): GELU(approximate='none')
    )
    (linear_post): Linear(in_features=128, out_features=64, bias=True)
    (linears): ModuleList()
  )
  (blocks): ModuleList(
    (0-6): 7 x Transolver_block(
      (ln_1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (Attn): Physics_Attention_Structured_Mesh_2D(
        (softmax): Softmax(dim=-1)
        (dropout): Dropout(p=0.0, inplace=False)
        (in_project_x): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (in_project_fx): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (in_project_slice): Linear(in_features=16, out_features=32, bias=True)
        (to_q): Linear(in_features=16, out_features=16, bias=False)
        (to_k): Linear(in_features=16, out_features=16, bias=False)
        (to_v): Linear(in_features=16, out_features=16, bias=False)
        (t

In [10]:
# Train for `epochs` epochs, evaluating on the test split after each one.
# test_losses collects the raw (un-normalised) summed step-wise test loss of
# every epoch.
test_losses = []
for ep in range(epochs):
    model.train()
    train_l2_step = 0
    train_l2_full = 0

    print(len(train_loader))
    for i, (x, fx, yy) in enumerate(train_loader):
        print(f"training data {i}")
        loss = 0
        x, fx, yy = x.cuda(), fx.cuda(), yy.cuda()  # x: B,4096,2    fx: B,4096,T_in   yy: B,4096,T
        bsz = x.shape[0]

        # Roll forward through the target, `step` channels per model call.
        for t in range(0, T, step):
            y = yy[..., t:t + step]
            im = model(x, fx=fx)  # expected B, 4096, 2 (the model is built with out_dim=2)
            loss += myloss(im.reshape(bsz, -1), y.reshape(bsz, -1))
            if t == 0:
                pred = im
            else:
                pred = torch.cat((pred, im), -1)
            # Teacher forcing: slide the input window forward by `step`
            # channels using the ground truth, not the prediction.
            fx = torch.cat((fx[..., step:], y), dim=-1)

        train_l2_step += loss.item()
        train_l2_full += myloss(pred.reshape(bsz, -1), yy.reshape(bsz, -1)).item()
        optimizer.zero_grad()
        loss.backward()
        if max_grad_norm is not None:
            # Use the notebook-level setting; there is no `args` object here.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        scheduler.step()

    test_l2_step = 0
    test_l2_full = 0

    model.eval()

    with torch.no_grad():
        for x, fx, yy in test_loader:
            loss = 0
            x, fx, yy = x.cuda(), fx.cuda(), yy.cuda()  # x: B,4096,2    fx: B,4096,T_in   yy: B,4096,T
            bsz = x.shape[0]
            for t in range(0, T, step):
                y = yy[..., t:t + step]
                im = model(x, fx=fx)
                loss += myloss(im.reshape(bsz, -1), y.reshape(bsz, -1))
                if t == 0:
                    pred = im
                else:
                    pred = torch.cat((pred, im), -1)
                # Autoregressive rollout: feed the model's own prediction back in.
                fx = torch.cat((fx[..., step:], im), dim=-1)

            test_l2_step += loss.item()
            test_l2_full += myloss(pred.reshape(bsz, -1), yy.reshape(bsz, -1)).item()

        test_losses.append(test_l2_step)

    # Step losses are normalised per sample and per model call, full losses
    # per sample.
    print(
        "Epoch {} , train_step_loss:{:.5f} , train_full_loss:{:.5f} , test_step_loss:{:.5f} , test_full_loss:{:.5f}".format(
            ep, train_l2_step / ntrain / (T / step), train_l2_full / ntrain, test_l2_step / ntest / (T / step),
                test_l2_full / ntest))

    # Periodic checkpoint (every 100th epoch, starting with epoch 0).
    if ep % 100 == 0:
        os.makedirs('./checkpoints', exist_ok=True)
        print('save model')
        torch.save(model.state_dict(), os.path.join('./checkpoints', save_name + '.pt'))

# Final checkpoint once training has finished (overwrites the periodic one).
print(test_losses)
os.makedirs('./checkpoints', exist_ok=True)
print('save model')
torch.save(model.state_dict(), os.path.join('./checkpoints', save_name + '.pt'))

8
training data 0
training data 1
training data 2
training data 3
training data 4
training data 5
training data 6
training data 7
Epoch 0 , train_step_loss:0.72347 , train_full_loss:0.72365 , test_step_loss:0.57462 , test_full_loss:0.57708
save model
8
training data 0
training data 1
training data 2
training data 3
training data 4
training data 5
training data 6
training data 7
Epoch 1 , train_step_loss:0.34407 , train_full_loss:0.34459 , test_step_loss:0.39500 , test_full_loss:0.39781
8
training data 0
training data 1
training data 2
training data 3
training data 4
training data 5
training data 6
training data 7
Epoch 2 , train_step_loss:0.24314 , train_full_loss:0.24370 , test_step_loss:0.33578 , test_full_loss:0.34028
8
training data 0
training data 1
training data 2
training data 3
training data 4
training data 5
training data 6
training data 7
Epoch 3 , train_step_loss:0.21378 , train_full_loss:0.21415 , test_step_loss:0.27619 , test_full_loss:0.27980
8
training data 0
training da