In [1]:
import os
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

from tqdm import tqdm
from torch.utils.data import TensorDataset, DataLoader

from models.fno import FNO3d
from models.lploss import LpLoss

In [2]:
data_path = 'data/l63'

In [3]:
# Data path
mesh_path = os.path.join(data_path, 'mesh.npy')
priorpdfn_path = os.path.join(data_path, 'prior.npy')

# Read numpy matrices
mesh = torch.from_numpy(np.load(mesh_path).reshape((1, 40, 40, 40, 3)))
priorpdfn = torch.from_numpy(np.load(priorpdfn_path).reshape((-1, 40, 40, 40, 1)))

In [4]:
priorpdfn.shape, mesh.shape

(torch.Size([499, 40, 40, 40, 1]), torch.Size([1, 40, 40, 40, 3]))

In [5]:
p = torch.concat([priorpdfn, mesh.repeat(499, 1, 1, 1, 1)], dim=-1)
p = F.pad(p, (0, 0, 0, 0, 0, 0, 0, 0, 1, 0), 'constant', 0)
p.shape

torch.Size([500, 40, 40, 40, 4])

In [6]:
mode1 = 10
mode2 = 10
mode3 = 10
width = 36
model = FNO3d(mode1, mode2, mode3, width)

In [7]:
def get_dataset(p, n_steps=10):
    xs = []
    ys = []
    for i in range(0, p.shape[0]-n_steps):
        xs.append(p[i])
        ys.append(p[i+n_steps, ..., 0].unsqueeze(-1))
    x = torch.stack(xs, dim=0)
    y = torch.stack(ys, dim=0)
    return TensorDataset(x, y)

In [8]:
full_ds = get_dataset(p, n_steps=5)
train_size = int(0.8 * len(full_ds))
test_size = len(full_ds) - train_size
train_ds, test_ds = torch.utils.data.random_split(full_ds, 
                                                  [train_size, test_size])


In [9]:
batch_size = 8
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
test_dl = DataLoader(test_ds, batch_size=batch_size, shuffle=True)

In [10]:
epochs = 50
e_start = 0
learning_rate = 0.001
scheduler_step = 4
scheduler_gamma = 0.85
learning_rate

0.001

In [11]:
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step, gamma=scheduler_gamma)
myloss = LpLoss(size_average=False)

In [12]:
train_l2 = 0.0
for ep in range(1,epochs+1):
    model.train()
    train_l2 = 0
    counter = 0
    for x, y in train_dl:
        
        optimizer.zero_grad()
    
        pred = model(x)
        ori_loss = 0

        num_examples = x.shape[0]
        
        # original loss
        # for i in range(batch_size):
        ori_loss += myloss(pred.reshape(num_examples, -1), 
                           y.reshape(num_examples, -1))

        loss = ori_loss
        
        loss.backward()
        optimizer.step()
        train_l2 += loss.item()

        counter += 1
        if counter % 10 == 0:
            print(f'epoch: {ep}, batch: {counter}/{len(train_dl)}, train loss: {loss.item()/batch_size:.4f}')
        
    scheduler.step()

    print(f'epoch: {ep}, train loss: {train_l2/len(train_size):.4f}')
    
    # lr_ = optimizer.param_groups[0]['lr']
    # if ep % 2 == 0:
    #     PATH = f'saved_models/dP_UFNO_{ep}ep_{width}width_{mode1}m1_{mode2}m2_{train_a.shape[0]}train_{lr_:.2e}lr'
    #     torch.save(model, PATH)

epoch: 1, batch: 10/50, train loss: 0.9962
epoch: 1, batch: 20/50, train loss: 0.9883
epoch: 1, batch: 30/50, train loss: 0.9832
epoch: 1, batch: 40/50, train loss: 0.9736
epoch: 1, batch: 50/50, train loss: 0.4801
epoch: 1, train loss: 7.7964
epoch: 2, batch: 10/50, train loss: 0.9580
epoch: 2, batch: 20/50, train loss: 0.9524
epoch: 2, batch: 30/50, train loss: 0.9006
epoch: 2, batch: 40/50, train loss: 0.9096
epoch: 2, batch: 50/50, train loss: 0.4031
epoch: 2, train loss: 7.1181
epoch: 3, batch: 10/50, train loss: 0.7348
epoch: 3, batch: 20/50, train loss: 0.7332
epoch: 3, batch: 30/50, train loss: 0.7152
epoch: 3, batch: 40/50, train loss: 0.5432
epoch: 3, batch: 50/50, train loss: 0.3024
epoch: 3, train loss: 5.1304
epoch: 4, batch: 10/50, train loss: 0.3083
epoch: 4, batch: 20/50, train loss: 0.4711
epoch: 4, batch: 30/50, train loss: 0.4791
epoch: 4, batch: 40/50, train loss: 0.4380
epoch: 4, batch: 50/50, train loss: 0.3351
epoch: 4, train loss: 4.1128
epoch: 5, batch: 10/50, 

KeyboardInterrupt: 