In [1]:
%run ./FNO.py

In [2]:
%run ./Data_Generator/KFRDG.ipynb

# Configuration

In [3]:
################################################################
#  configurations for Data Generation
################################################################
# All four values are forwarded to DataGenerator_Alpha; Alpha and InitialSolve
# are also embedded in the .npy file names loaded in the "read data" cell, and
# Ndt / Alpha / TrainingSamples in the saved-model file name.
Alpha = 4.85            # NOTE(review): physical meaning not documented here — add it
InitialSolve = 350      # NOTE(review): meaning not visible here; forwarded to the generator
Ndt = 20                # the labels are Ndt (= 20) dt after the input
TrainingSamples = 1500  # number of samples requested from the generator

In [4]:
################################################################
#  configurations for FNO
################################################################
# Train/test split sizes. The "read data" cell takes the first `ntrain`
# samples for training and the last `ntest` for testing, so keep
# ntrain + ntest <= TrainingSamples or the two sets overlap.
ntrain = 1400
ntest = 100

# Number of grid points per sample; inputs are reshaped to (n, S, 1).
S = 207

batch_size = 20
learning_rate = 0.001

# StepLR schedule: the learning rate is multiplied by `gamma` every
# `step_size` epochs (scheduler.step() is called once per epoch).
epochs = 500
step_size = 50
gamma = 0.5

# Passed to FNO1d(modes, width) — presumably the number of retained Fourier
# modes and the channel width; confirm against FNO.py.
modes = 32
width = 85

# Seed the RNGs behind weight initialisation and DataLoader shuffling (and the
# data generator, if it draws from NumPy's global RNG) so a Restart & Run All
# reproduces the same training curve. torch.manual_seed also seeds CUDA.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Data Generation

In [46]:
# Generate the dataset (labels are Ndt steps after the inputs).
# NOTE(review): the "read data" cell loads Data_Generator/input_{Alpha}_{InitialSolve}.npy
# and u_results_{Alpha}_{InitialSolve}.npy — presumably written by this call; confirm.
generator_args = (Alpha, TrainingSamples, Ndt, InitialSolve)
Udict = DataGenerator_Alpha(*generator_args)

amp_0alpha_4.85 =-=-=-=-=-= 23.88s
Completed


In [5]:
################################################################
# read data
################################################################

path_x = 'Data_Generator/' + f'input_{Alpha}_{InitialSolve}.npy'
path_y = 'Data_Generator/' + f'u_results_{Alpha}_{InitialSolve}.npy'

# NOTE(review): allow_pickle=True lets np.load unpickle (and thus execute) data
# from the file — only load files produced locally by the data generator.
x_data = torch.tensor(np.load(path_x, allow_pickle=True))
y_data = torch.tensor(np.load(path_y, allow_pickle=True))

# The split takes the first ntrain samples for training and the last ntest for
# testing. Fail loudly if inputs/labels are not paired one-to-one, or if the two
# slices would overlap (silent test-set leakage when too few samples exist).
assert len(x_data) == len(y_data), f'{len(x_data)} inputs vs {len(y_data)} labels'
assert ntrain + ntest <= len(x_data), (
    f'ntrain + ntest = {ntrain + ntest} exceeds the {len(x_data)} available samples; '
    'train and test sets would overlap'
)

x_train = x_data[:ntrain, :]
y_train = y_data[:ntrain, :]
x_test = x_data[-ntest:, :]
y_test = y_data[-ntest:, :]

# Reshape inputs to (n, S, 1): S grid points plus a trailing singleton dim
# (presumably the channel dimension FNO1d expects).
x_train = x_train.reshape(ntrain, S, 1)
x_test = x_test.reshape(ntest, S, 1)

train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train), batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=batch_size, shuffle=False)

# FNO Model

In [6]:
# model
# Use the GPU when one is available and fall back to CPU otherwise; on a CUDA
# machine this is equivalent to the previous `.cuda()`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = FNO1d(modes, width).to(device)
print(count_params(model))

1890232


In [7]:
################################################################
# training and evaluation
################################################################
# Move batches to wherever the model lives instead of hard-coding .cuda().
device = next(model.parameters()).device

optimizer = Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

# size_average=False: the loss is presumably summed over the batch (epoch totals
# are divided by the number of samples below). NOTE(review): confirm in FNO.py.
myloss = LpLoss(size_average=False)


def _flatten_batch(t):
    """Flatten all dimensions except the leading batch dimension.

    Uses the tensor's own batch size rather than the global ``batch_size``, so a
    smaller final batch (or a loader with a different batch size) still works.
    """
    return t.reshape(t.shape[0], -1)


for ep in range(epochs):
    model.train()
    t1 = default_timer()
    train_mse = 0.0
    train_l2 = 0.0
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()
        out = model(x)

        # MSE is only monitored, so compute it on a detached tensor.
        mse = F.mse_loss(_flatten_batch(out.detach()), _flatten_batch(y), reduction='mean')
        l2 = myloss(_flatten_batch(out), _flatten_batch(y))
        l2.backward()  # optimise the relative L2 loss

        optimizer.step()
        train_mse += mse.item()
        train_l2 += l2.item()

    scheduler.step()
    model.eval()
    test_l2 = 0.0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)

            out = model(x)
            test_l2 += myloss(_flatten_batch(out), _flatten_batch(y)).item()

    train_mse /= len(train_loader)             # mean of per-batch means
    train_l2 /= len(train_loader.dataset)      # per-sample mean (== ntrain)
    test_l2 /= len(test_loader.dataset)        # per-sample mean (== ntest)

    t2 = default_timer()
    print(ep, t2 - t1, train_mse, train_l2, test_l2)

# Per-sample relative L2 on the test set, collecting predictions in `pred`.
# A separate loader name keeps `test_loader` (batch_size-sized, used by the
# epoch loop above) intact, so re-running this cell does not break evaluation.
pred = torch.zeros(y_test.shape)
eval_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=1, shuffle=False)
with torch.no_grad():
    for index, (x, y) in enumerate(eval_loader):
        x, y = x.to(device), y.to(device)

        out_flat = model(x).reshape(1, -1)
        pred[index] = out_flat.reshape(pred.shape[1:]).cpu()

        sample_l2 = myloss(out_flat, y.reshape(1, -1)).item()
        print(index, sample_l2)

print(pred.shape)
print(y_test.shape)
print(x_train.shape)

0 1.1272477423772216 0.031298035506292114 0.1534573997770037 0.04905508518218994
1 0.7603605482727289 0.0008403324264301253 0.04036103395479066 0.04879530310630798
2 0.7440196797251701 0.000489117051516327 0.02977736209120069 0.019376173615455627
3 0.7269809236750007 0.0004033459542759894 0.02743099895971162 0.019024231135845185
4 0.7273609824478626 0.00016833948133613115 0.018104669536863054 0.01722657769918442
5 0.7373598450794816 0.0002574764794969399 0.021913211132798874 0.018290875554084776
6 0.7344430815428495 9.510350604874215e-05 0.013918937199882098 0.010766306966543198
7 0.7336311750113964 0.0001722032796221486 0.018074611074158125 0.010226675420999526
8 0.7448282958939672 0.00010519168892122771 0.014241629711219242 0.013946914374828338
9 0.7511806385591626 0.00013926779772321295 0.016558726845043045 0.023791267573833465
10 0.7711474699899554 0.00010751838599389884 0.0140356586341347 0.010214814692735672
11 0.7487950511276722 6.109329575078196e-05 0.011152500941285065 0.03131

96 0.7405505646020174 5.434452287188054e-06 0.0033755273265498025 0.0026485083252191545
97 0.7506527062505484 6.348387561112239e-06 0.0035965445930404324 0.0032441243901848793
98 0.7237686477601528 6.830381877469855e-06 0.003776874164385455 0.004860323891043663
99 0.7401657029986382 3.7948327108746783e-06 0.002691631025767752 0.0035767780989408495
100 0.7495083110406995 1.9874309446419957e-06 0.0020152723536427533 0.0013368580117821694
101 0.7547809462994337 1.4338275812113872e-06 0.0017006466191794192 0.002145988941192627
102 0.7515579489991069 1.4999004715069271e-06 0.0017596051788755826 0.0010494880378246307
103 0.7449469994753599 1.4715595174915767e-06 0.001746747690652098 0.0024864623323082925
104 0.7499002814292908 1.9504841532125284e-06 0.0020112960253443036 0.0013664612732827664
105 0.7431626692414284 1.197680686816836e-06 0.0015998926772070783 0.0024547486007213595
106 0.7434294046834111 1.6459307468202107e-06 0.0018007216043770314 0.00230409424751997
107 0.7420287523418665 1.

189 0.7467381153255701 4.777647167283508e-07 0.0010021025960200599 0.0006729163788259029
190 0.7792867897078395 3.4682248265685953e-07 0.0008420256412188922 0.0006450760830193758
191 0.7805562978610396 2.6932337086853814e-07 0.000749064851552248 0.000610131761059165
192 0.7518145497888327 2.6023804267083506e-07 0.0007436263508030347 0.0006452040467411279
193 0.7339570540934801 4.92096683249851e-07 0.0010061017403911268 0.0009716536104679108
194 0.7354256538674235 5.564683529460775e-07 0.0010875600117391773 0.0011455397680401802
195 0.7528037810698152 3.529116202319723e-07 0.0008481031862486686 0.0009918206185102463
196 0.7528313109651208 4.3142158066952364e-07 0.000950241490001125 0.0010382961668074131
197 0.728713802061975 5.550397887077452e-07 0.0010950382938608528 0.001271953172981739
198 0.7427267385646701 5.510881452762208e-07 0.0010887308711452143 0.0009480978734791279
199 0.7309522433206439 3.4048972900840063e-07 0.0008268226530136806 0.0004933499544858932
200 0.744167105294764 

281 0.7363962633535266 7.19130640496652e-08 0.00038787353883630463 0.0002947291079908609
282 0.7451663073152304 6.846596442804704e-08 0.0003772030581187989 0.00039145576301962136
283 0.7446323530748487 6.351806032256491e-08 0.00036394637876323294 0.0003725876705721021
284 0.7435765406116843 6.922643233256817e-08 0.0003812681052035519 0.0003788090869784355
285 0.7394215753301978 6.415883326837957e-08 0.0003671574080362916 0.0002999794529750943
286 0.7473237765952945 6.675578921390621e-08 0.0003758028375783137 0.0005342818330973387
287 0.7473590830340981 6.479220136920308e-08 0.00036818806946809803 0.0003003793442621827
288 0.7368972105905414 5.6919337459199986e-08 0.00034659754551414934 0.0003437259187921882
289 0.7246648240834475 6.87101066075359e-08 0.0003829410229809582 0.0003501249244436622
290 0.7573105348274112 5.6687199670843064e-08 0.0003434314059891871 0.00031932236161082985
291 0.7568568382412195 5.7581034416744255e-08 0.00034848382962601526 0.0003381259087473154
292 0.7889793

373 0.7811820898205042 2.9409799877961567e-08 0.0002464495033824018 0.0002358058001846075
374 0.7585863694548607 3.103045609671134e-08 0.00025416016478889755 0.0002577746473252773
375 0.7151645524427295 3.033091654661543e-08 0.0002504755070965205 0.000270080016925931
376 0.729824255220592 2.8861349627667162e-08 0.0002447493986359664 0.0002523773675784469
377 0.7397495293989778 3.052771610104596e-08 0.0002511699494373586 0.00027105739340186117
378 0.7447861824184656 2.8894324545701368e-08 0.0002445282238269491 0.00024390653241425752
379 0.748251480050385 2.848474185412897e-08 0.00024282796574490412 0.0002396025415509939
380 0.7355793546885252 3.1304803473273884e-08 0.0002538147392416639 0.00028509462252259256
381 0.7323496825993061 3.1755459265322575e-08 0.00025668101900789354 0.0003030300745740533
382 0.7367129009217024 2.9308787067147282e-08 0.0002464608753299607 0.00024217858910560608
383 0.724913795478642 2.8593515821739857e-08 0.00024382384859823755 0.00024602035526186225
384 0.734

464 0.8023784989491105 2.2970430763840407e-08 0.00021642210444302432 0.00024057054426521063
465 0.78696358948946 2.2995407952411793e-08 0.00021690150573184448 0.0002210840815678239
466 0.7365239029750228 2.2981288072543295e-08 0.0002173657775191324 0.0002192973392084241
467 0.7392793716862798 2.311024895261328e-08 0.00021755932298089776 0.00022158892825245856
468 1.0596603564918041 2.2504166378009163e-08 0.00021500704660346465 0.0002193714864552021
469 1.1299679949879646 2.3595313791702212e-08 0.00022047028161718376 0.0002198617300018668
470 1.126612019725144 2.3027924456187066e-08 0.00021741541418513017 0.00021932173520326616
471 1.0621932530775666 2.3117323395232298e-08 0.00021778199289526257 0.00022160829044878482
472 0.9849836062639952 2.280506062086423e-08 0.00021656425586635513 0.0002195146307349205
473 0.9534960966557264 2.2757062558196952e-08 0.0002158209668206317 0.00022130475379526616
474 0.9140056427568197 2.2773031105316478e-08 0.00021644216213774468 0.00021854617632925512


In [8]:
# Persist the trained model under a name encoding the data-generation settings.
# NOTE(review): torch.save(model, ...) pickles the whole module (class looked up
# by reference at load time); PyTorch recommends saving model.state_dict().
model_path = f'KFR_FNO_skiptype{Ndt}_alpha{Alpha}_trainingsamples{TrainingSamples}'
torch.save(model, model_path)