In [None]:
import sys
from pathlib import Path
# The notebook runs from a sub-directory of the repository, so its parent
# directory is the project root; prepend it to the import path so that the
# `src.*` imports below resolve ahead of any installed package of the same name.
project_root = Path.cwd().absolute().parent
sys.path[:0] = [str(project_root)]


import torch
import torch.nn as nn
import torch.nn.functional as F

import os
import numpy as np
from scipy.io import loadmat
import matplotlib.pyplot as plt
from functools import reduce, partial
from timeit import default_timer

from src.utils.utils import *
from src.models.base import FNO3d
from src.models.multi_step import BOON_FNO3d

In [2]:
# Fix both the torch RNG (e.g. parameter init, DataLoader shuffling) and numpy's
# global legacy RNG so that re-running the notebook reproduces the same results.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

In [3]:
# Train on the first GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Dirichlet boundary conditions

In [4]:
# --- dataset split ---
ntrain = 1000  # samples taken from the front of the data arrays for training
ntest = 200    # samples taken from the back of the data arrays for testing

# --- architecture ---
modes = 8   # passed to FNO3d for each of its three mode arguments (presumably Fourier modes per axis)
width = 20  # channel width shared by FNO3d and BOON_FNO3d

# --- optimisation ---
batch_size = 10
batch_size2 = batch_size  # NOTE(review): not used by the cells shown; the test loader uses batch_size

epochs = 500
learning_rate = 0.001
scheduler_step = 100   # StepLR: multiply the learning rate by scheduler_gamma every scheduler_step epochs
scheduler_gamma = 0.5

# --- discretisation ---
grid_size = 100  # total spatial grid size of the stored fields
sub = 2          # spatial subsampling stride, applied as [::sub] on both spatial axes
# Number of points kept by `[::sub]` slicing, i.e. ceil(grid_size / sub).
# The plain `grid_size // sub` undercounts whenever sub does not divide grid_size
# (sub=3 -> 33 instead of 34) and then breaks the reshape(..., S, S, ...) calls.
N = len(range(0, grid_size, sub))
S = N

T_in = 1  # number of leading slices of the last axis of 'a' used as input
T = 25    # number of trailing slices of the last axis of 'u' used as target

In [17]:
def _as_float_tensor(array):
    """Cast an array slice to float32 and wrap it as a torch tensor via torch.from_numpy."""
    return torch.from_numpy(array.astype(np.float32))


downloader = DataDownloader()
# This downloads every file registered under id "NV_Dir_3D". To fetch only a
# single file, pass its tag instead: downloader.download(id = "NV_Dir_3D", tag = "Re_100").
downloader.download(id = "NV_Dir_3D")
rw = downloader.locate(id = "NV_Dir_3D", tag = "Re_100")

# Training samples come from the front and test samples from the back of the
# same arrays; if ntrain + ntest exceeds the number of samples the two slices
# overlap and test data leaks into training.
for field_name in ('a', 'u'):
    n_available = rw[field_name].shape[0]
    assert ntrain + ntest <= n_available, (
        f"rw[{field_name!r}] holds {n_available} samples, but ntrain + ntest = "
        f"{ntrain + ntest}: train and test slices would overlap"
    )

# Inputs: first T_in slices of the last axis of 'a'; targets: last T slices of 'u'.
# Both spatial axes are subsampled with stride `sub`.
train_a = _as_float_tensor(rw['a'][:ntrain, ::sub, ::sub, :T_in])
train_u = _as_float_tensor(rw['u'][:ntrain, ::sub, ::sub, -T:])

test_a = _as_float_tensor(rw['a'][-ntest:, ::sub, ::sub, :T_in])
test_u = _as_float_tensor(rw['u'][-ntest:, ::sub, ::sub, -T:])

In [20]:
def _repeat_over_time(inputs, n_samples):
    """Reshape inputs to (n_samples, S, S, 1, T_in) and tile them T times along axis 3."""
    return inputs.reshape(n_samples, S, S, 1, T_in).repeat([1, 1, 1, T, 1])


def _make_loader(inputs, targets, shuffle):
    """Pair inputs with targets in a TensorDataset and batch them with `batch_size`."""
    dataset = torch.utils.data.TensorDataset(inputs, targets)
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)


# NOTE(review): this cell rebinds train_a/test_a to their repeated versions, so
# running it a second time fails in the reshape; re-run the loading cell first.
train_a = _repeat_over_time(train_a, ntrain)
test_a = _repeat_over_time(test_a, ntest)

train_loader = _make_loader(train_a, train_u, shuffle=True)
test_loader = _make_loader(test_a, test_u, shuffle=False)

In [21]:
# A plain FNO3d (given `modes` for each of its three mode arguments) is wrapped by
# BOON_FNO3d, configured here for Dirichlet boundaries; the boundary values are
# supplied on every forward call in the training cell.
base_no = FNO3d(modes, modes, modes, width)
model = BOON_FNO3d(width, base_no, bdy_type='dirichlet')
model = model.to(device)

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step, gamma=scheduler_gamma)

myloss = LpLoss(size_average=False)


def _dirichlet_boundaries(y):
    """Build the boundary keyword arguments passed to the BOON model.

    ``y`` is a 4-D target batch indexed as (batch, nx, ny, nt). The values on its
    four spatial edges are sliced out and given a singleton axis at position 1
    (to fit the model's channel structure):

    - ``bdy_left`` / ``bdy_right``: first / last row along axis 1, shape (bs, 1, ny, nt)
    - ``bdy_top`` / ``bdy_down``: first / last column along axis 2, shape (bs, 1, nx, nt)

    Each value is wrapped as ``{'val': tensor}``.
    """
    bs, nx, ny, nt = y.shape
    return {
        'bdy_left': {'val': y[:, 0, :, :].reshape(bs, 1, ny, nt)},
        'bdy_right': {'val': y[:, -1, :, :].reshape(bs, 1, ny, nt)},
        'bdy_top': {'val': y[:, :, 0, :].reshape(bs, 1, nx, nt)},
        'bdy_down': {'val': y[:, :, -1, :].reshape(bs, 1, nx, nt)},
    }


# Normalise the summed losses by the number of samples actually iterated over.
n_train_samples = len(train_loader.dataset)
n_test_samples = len(test_loader.dataset)

for ep in range(epochs):
    model.train()
    t1 = default_timer()
    train_l2 = 0.0
    for x, y in train_loader:
        # `nt` (not `T`) so the global config constant T is not overwritten.
        bs, nx, ny, nt, _ = x.shape
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()
        # Dirichlet boundary values are taken from the target batch itself.
        out = model(x, **_dirichlet_boundaries(y)).view(bs, nx, ny, nt)

        l2 = myloss(out.view(bs, -1), y.view(bs, -1))
        l2.backward()
        optimizer.step()
        train_l2 += l2.item()

    scheduler.step()

    model.eval()
    test_l2 = 0.0
    with torch.no_grad():
        for x, y in test_loader:
            bs, nx, ny, nt, _ = x.shape
            x, y = x.to(device), y.to(device)
            out = model(x, **_dirichlet_boundaries(y)).view(bs, nx, ny, nt)
            test_l2 += myloss(out.view(bs, -1), y.view(bs, -1)).item()

    # With size_average=False each batch loss is presumably a sum over samples,
    # so dividing by the sample count yields a per-sample mean relative L2 error.
    train_l2 /= n_train_samples
    test_l2 /= n_test_samples

    t2 = default_timer()
    print(ep, t2-t1, train_l2, test_l2)
# torch.save(model, path_model)