In [1]:
import torch
import torch.nn as nn
import numpy as np
import sympy
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.getcwd(),'../..')))
from Rmse_loss import rmse_loss 
from Select import select
from All import all0
from Train_fun import train_fun
from Select import select


torch.set_default_dtype(torch.float32)
device = torch.device('cuda:0')
from warnings import filterwarnings
filterwarnings('ignore')
np.set_printoptions(suppress=True)

In [2]:
class FiveNet_sin(nn.Module):
    def __init__(self, ni):
        super().__init__()
        self.linear1 = nn.Linear(ni, 128)
        self.linear2 = nn.Linear(128,128)
        self.linear3 = nn.Linear(128, 64)
        self.linear4 = nn.Linear(64,64)
        self.linear5 = nn.Linear(64,1)

    def forward(self, x):
        x = torch.sin(self.linear1(x))
        x = torch.sin(self.linear2(x))
        x = torch.sin(self.linear3(x))
        x = torch.sin(self.linear4(x))
        x = self.linear5(x)
        return x

In [14]:
def diff_KdV(net,cor_train):
    cache = np.zeros((1,6))
    pts = cor_train.clone().detach().to(device)
    pts.requires_grad_(True)
    n_iter = len(pts)//10000
    for i in range(n_iter+1):
        pts0 = pts[10000*i:10000*(i+1)]
        outs = net(pts0)
        grad = torch.autograd.grad(outs, pts0, grad_outputs=torch.ones_like(outs), create_graph=True)[0]
        ut_t = (grad[:,0]).reshape(-1,1)
        ux_t = (grad[:,1]).reshape(-1,1)
        uxx_t  = ((torch.autograd.grad(ux_t, pts0, grad_outputs=torch.ones_like(outs), 
                create_graph=True)[0])[:,1]).reshape(-1,1)
        uxxx_t  = ((torch.autograd.grad(uxx_t, pts0, grad_outputs=torch.ones_like(outs), 
                create_graph=True)[0])[:,1]).reshape(-1,1)
        uxxxx_t  = ((torch.autograd.grad(uxxx_t, pts0, grad_outputs=torch.ones_like(outs), 
                create_graph=True)[0])[:,1]).reshape(-1,1)

        u_pred = outs.cpu().detach().numpy()
        ut = ut_t.cpu().detach().numpy()
        ux = ux_t.cpu().detach().numpy()
        uxx = uxx_t.cpu().detach().numpy()
        uxxx = uxxx_t.cpu().detach().numpy()
        uxxxx = uxxxx_t.cpu().detach().numpy()
        del ut_t, ux_t, uxx_t, uxxx_t, uxxxx_t, outs
        torch.cuda.empty_cache()
        cache = np.vstack((cache,np.hstack((u_pred,ux,uxx,uxxx,uxxxx,ut))))
    return cache[1:]

In [15]:
def KdV_data(x_min,x_max,t_min,t_max,n,m,c,x0):
    x = np.linspace(x_min,x_max,n)
    t = np.linspace(t_min,t_max,m)
    X,T = np.meshgrid(x,t)
    data = np.hstack((T.reshape(-1,1),X.reshape(-1,1)))
    t,x = data[:,0].reshape(-1,1),data[:,1].reshape(-1,1)
    u = (c/2)*(sech((c**0.5/2)*(x-c*t-x0)))**2
    data = np.hstack((data,u))
    return data

In [33]:
per = 0.05
for c in [2,6]:
    data = np.loadtxt('data/c_{}'.format(c))
    data = data[data[:,0]//0.01%4==0]
    data = data[data[:,1]//0.008%5==0]
    data = data.astype('float32')
    u_clean = data[:,-1].reshape(-1,1)
    u_noise = u_clean + per*np.std(u_clean)*np.random.randn(u_clean.shape[0],u_clean.shape[1])
    u_noise = u_noise.astype('float32')
    cor_train = torch.from_numpy(data[:,:-1])
    u_train = torch.from_numpy(u_noise).reshape(-1,1)
    net = FiveNet_sin(2).to(device)
    try:
        net.load_state_dict(torch.load('nets/net_{}_{}'.format(c,per)))
    except:
        train_fun(net,cor_train,u_train,N_red_lr=4,epochs=5000,lr=0.001)
        torch.save(net.state_dict(),'net/net_{}_{}'.format(c,per))
    cache = diff_KdV(net,cor_train)
    u_pred = cache[:,0].reshape(-1,1)
    ux = cache[:,1].reshape(-1,1)
    uxx = cache[:,2].reshape(-1,1)
    uxxx = cache[:,3].reshape(-1,1)
    uxxxx = cache[:,4].reshape(-1,1)
    ut = cache[:,5].reshape(-1,1)
    feature = np.hstack((u_pred,ux,uxx,uxxx,uxxxx,ut))
    np.savetxt('feature/feature_{}_{}'.format(c,per),feature)

0 0.8661051392555237
1 0.660590410232544
2 0.5449854731559753
3 0.5625259876251221
4 0.5942551493644714
5 0.5733059644699097
6 0.5325704216957092
7 0.5069750547409058
8 0.5078938603401184
9 0.5188020467758179
10 0.5200080275535583
11 0.5067975521087646
12 0.4852065443992615
13 0.4653133153915405
14 0.45516419410705566
15 0.453602135181427
16 0.45016685128211975
17 0.43649572134017944
18 0.4144200384616852
19 0.39311715960502625
20 0.37955212593078613
21 0.3695392310619354
22 0.35201919078826904
23 0.32345977425575256
24 0.2941672205924988
25 0.27564090490341187
26 0.2549646198749542
27 0.2218383252620697
28 0.20355655252933502
29 0.19414657354354858
30 0.17432892322540283
31 0.18256542086601257
32 0.17627814412117004
33 0.17935681343078613
34 0.1780015230178833
35 0.16561108827590942
36 0.16160430014133453
37 0.14360570907592773
38 0.13554047048091888
39 0.12242042273283005
40 0.10913180559873581
41 0.10078977048397064
42 0.08884096145629883
43 0.08718415349721909
44 0.0810517147183418

337 0.027290502563118935
338 0.02738097310066223
339 0.027471790090203285
340 0.027494750916957855
341 0.027478598058223724
342 0.027391508221626282
343 0.02729952335357666
344 0.027196353301405907
345 0.027129942551255226
346 0.027086710557341576
347 0.02709353156387806
348 0.027136148884892464
349 0.027239292860031128
0.02770893048495054 0.027536463253200055 0.006224247155412854
350 0.027386916801333427
351 0.027588387951254845
352 0.02779768779873848
353 0.02796655334532261
354 0.028055690228939056
355 0.02803955413401127
356 0.027990568429231644
357 0.027935786172747612
358 0.02789672650396824
359 0.027856510132551193
360 0.027765346691012383
361 0.027627527713775635
362 0.02745378203690052
363 0.027283329516649246
364 0.027143001556396484
365 0.027044691145420074
366 0.0269928015768528
367 0.02698832005262375
368 0.027034003287553787
369 0.02713942900300026
370 0.027305468916893005
371 0.027529027312994003
372 0.027760842815041542
373 0.02793375588953495
374 0.02798384241759777
37

59 0.02636807970702648
60 0.026359369978308678
61 0.02635882794857025
62 0.026364710181951523
63 0.026364225894212723
64 0.02635813131928444
65 0.026358118280768394
66 0.02636215090751648
67 0.02636108733713627
68 0.026357199996709824
69 0.026357604190707207
70 0.02636026218533516
71 0.026358867064118385
72 0.026356376707553864
73 0.026357287541031837
74 0.026358669623732567
75 0.026357248425483704
76 0.026355823501944542
77 0.026356879621744156
78 0.026357347145676613
79 0.026356035843491554
80 0.026355504989624023
81 0.026356343179941177
82 0.026356210932135582
83 0.026355193927884102
84 0.02635524421930313
85 0.026355702430009842
86 0.026355238631367683
87 0.026354651898145676
88 0.026354921981692314
89 0.026354992762207985
90 0.026354465633630753
91 0.026354271918535233
92 0.026354487985372543
93 0.02635428123176098
94 0.026353897526860237
95 0.026353923603892326
96 0.02635394223034382
97 0.026353638619184494
98 0.026353463530540466
99 0.026353515684604645
0.026927793100476265 0.02

391 0.02632095292210579
392 0.02632085792720318
393 0.026320764794945717
394 0.026320669800043106
395 0.026320574805140495
396 0.026320479810237885
397 0.026320386677980423
398 0.026320289820432663
399 0.026320194825530052
0.026327606439590454 0.026322560757398604 0.0001916498639338051
400 0.02632010169327259
401 0.02632000483572483
402 0.026319915428757668
403 0.026319824159145355
404 0.026319732889533043
405 0.02631964720785618
406 0.026319561526179314
407 0.02631950192153454
408 0.0263194739818573
409 0.02631952054798603
410 0.02631974220275879
411 0.02632036805152893
412 0.026321904733777046
413 0.026325540617108345
414 0.026333920657634735
415 0.026352424174547195
416 0.026389464735984802
417 0.026447826996445656
418 0.02649584971368313
419 0.026463542133569717
420 0.02636164240539074
421 0.02631942369043827
422 0.02637512981891632
423 0.02641456574201584
424 0.02636408805847168
425 0.02631799317896366
426 0.02635037526488304
427 0.026379043236374855
428 0.026342932134866714
429 0

332 0.04563288763165474
333 0.04582715407013893
334 0.0459643229842186
335 0.0459873303771019
336 0.04589565098285675
337 0.045725490897893906
338 0.04554363340139389
339 0.045394886285066605
340 0.04530949518084526
341 0.045292384922504425
342 0.04534393921494484
343 0.04544910788536072
344 0.04558815807104111
345 0.045721057802438736
346 0.04581083729863167
347 0.04582232981920242
348 0.04575889930129051
349 0.045641686767339706
0.04618081904947758 0.04567670308053493 0.010916133133163877
350 0.0455147922039032
351 0.04540805146098137
352 0.04534636810421944
353 0.045334286987781525
354 0.04537346959114075
355 0.04545000568032265
356 0.04554867744445801
357 0.04563913121819496
358 0.04569855332374573
359 0.04570412635803223
360 0.045660410076379776
361 0.04557955265045166
362 0.045491620898246765
363 0.0454159751534462
364 0.0453723669052124
365 0.045363910496234894
366 0.04539294168353081
367 0.045447416603565216
368 0.04551653936505318
369 0.045577555894851685
370 0.045616805553436

42 0.043640125542879105
43 0.04356832057237625
44 0.04354096204042435
45 0.043588124215602875
46 0.04361224174499512
47 0.04356914758682251
48 0.0435371994972229
49 0.04356352612376213
50 0.04358715936541557
51 0.04356387257575989
52 0.04353802651166916
53 0.04355018213391304
54 0.04356919974088669
55 0.043557438999414444
56 0.043537694960832596
57 0.04354435205459595
58 0.04355837404727936
59 0.043550170958042145
60 0.043536849319934845
61 0.043541986495256424
62 0.04355078563094139
63 0.04354487732052803
64 0.04353640228509903
65 0.043540261685848236
66 0.043545886874198914
67 0.04354094713926315
68 0.04353576898574829
69 0.04353964328765869
70 0.04354234039783478
71 0.04353797808289528
72 0.043535713106393814
73 0.04353884607553482
74 0.043539613485336304
75 0.043536294251680374
76 0.04353569075465202
77 0.043538082391023636
78 0.04353751242160797
79 0.04353523254394531
80 0.04353587329387665
81 0.04353712499141693
82 0.04353593289852142
83 0.043534863740205765
84 0.0435357838869094

27 0.043519191443920135
28 0.04351918399333954
29 0.043519165366888046
30 0.04351913556456566
31 0.043519098311662674
32 0.04351907968521118
33 0.043519072234630585
34 0.0435190349817276
35 0.043518997728824615
36 0.043518971651792526
37 0.04351896420121193
38 0.043518926948308945
39 0.04351889342069626
40 0.04351887106895447
41 0.04351884871721268
42 0.04351881518959999
43 0.0435187891125679
44 0.04351876303553581
45 0.04351874068379402
46 0.043518707156181335
47 0.04351867735385895
48 0.04351865500211716
49 0.04351862892508507
50 0.04351859912276268
51 0.0435185581445694
52 0.043518539518117905
53 0.04351850971579552
54 0.04351847991347313
55 0.04351845011115074
56 0.043518420308828354
57 0.04351838305592537
58 0.04351835697889328
59 0.04351833835244179
60 0.0435183010995388
61 0.043518275022506714
62 0.043518245220184326
63 0.04351821169257164
64 0.04351818934082985
65 0.04351814463734627
66 0.04351812228560448
67 0.04351808875799179
68 0.043518055230379105
69 0.043518029153347015
7

In [34]:
feature = np.zeros((1,6))
for c in [2,6]:
    cache = np.loadtxt('feature/feature_{}_{}'.format(c,per))
    feature = np.vstack((feature,cache))
feature = feature[1:]
u_pred = feature[:,0].reshape(-1,1)
ux = feature[:,1].reshape(-1,1)
uxx = feature[:,2].reshape(-1,1)
uxxx = feature[:,3].reshape(-1,1)
uxxxx = feature[:,4].reshape(-1,1)
ut = feature[:,5].reshape(-1,1)

In [35]:
rmse_loss(-6*u_pred*ux-uxxx,ut)

tensor(0.0684, device='cuda:0', dtype=torch.float64)

In [36]:
LC,exp_list,exp = all0('result_{}.txt'.format(per),feature,
                             tol=0.05,alpha=10**-5,state=1,
                             data_points=1000,niterations=200)

Started!

Expressions evaluated per second: 5.190e+05
Head worker occupation: 9.3%
Progress: 1142 / 3000 total iterations (38.067%)
Hall of Fame:
---------------------------------------------------------------------------------------------------
Complexity  Loss       Score     Equation
1           8.240e+00  1.594e+01  y = x₃
3           9.129e-01  1.100e+00  y = (-5.433 * x₁)
5           6.659e-01  1.577e-01  y = ((-3.8252 - x₀) * x₁)
7           6.636e-01  1.714e-03  y = (-0.04916 - ((3.8253 + x₀) * x₁))
8           5.522e-01  1.837e-01  y = ((-5.5848 * x₁) - sin(x₃))
9           5.484e-02  2.310e+00  y = ((-0.96839 * x₃) - ((5.9237 * x₀) * x₁))
11          5.159e-02  3.051e-02  y = ((-0.9147 * x₃) - (((5.5934 * x₀) + 0.32451) * x₁))
13          4.987e-02  1.692e-02  y = ((-0.96839 * x₃) - (((5.9237 * x₀) - (0.003902 * x₄)) * x₁...
                                  ))
15          4.816e-02  1.748e-02  y = (((-0.96839 * x₃) - (((5.9237 * x₀) - (0.003902 * x₄)) * x...
                

Started!

Expressions evaluated per second: 5.510e+05
Head worker occupation: 11.6%
Progress: 1183 / 3000 total iterations (39.433%)
Hall of Fame:
---------------------------------------------------------------------------------------------------
Complexity  Loss       Score     Equation
1           1.956e+01  1.594e+01  y = -0.052042
3           1.941e+01  3.737e-03  y = (0.048697 * x₁)
5           1.916e+01  6.733e-03  y = (0.49677 / (x₀ + 1.44))
7           1.883e+01  8.433e-03  y = (x₀ / ((-1.5509 + -1.1948) - x₀))
9           1.824e+01  1.592e-02  y = ((-0.1305 / ((-1.563 + x₀) * -1.7521)) * x₂)
11          1.801e+01  6.508e-03  y = ((-0.1305 / ((x₀ + -1.5843) * 1.8478)) * (1.8478 - x₂))
13          1.792e+01  2.433e-03  y = ((-0.1305 / ((-1.563 + x₀) * 1.8478)) * (1.8478 - (x₂ - x₀...
                                  )))
16          1.790e+01  4.138e-04  y = ((-0.1305 / ((-1.563 + x₀) * 1.8179)) * ((-0.19428 - x₂) +...
                                   exp(x₀)))
17          1.7

In [37]:
k = select(LC)
exp_list[k]

[ 1  1  1  1 -1 -1 -1 -1 -1]
[ 1  1  1  1 -1 -1 -1 -1]
[-1 -1 -1 -1 -1 -1 -1]


-6.0574417*x0 - x2