In [1]:
import time
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import math

np.random.seed(42)
torch.manual_seed(42)

<torch._C.Generator at 0x7f1fdd7717f0>

Definition of the architecture as seen in : https://arxiv.org/pdf/1811.08782.pdf

In [2]:
class DGM_layer(nn.Module):
    
    def __init__(self, in_features, out_feature, residual = False):
        super(DGM_layer, self).__init__()
        self.residual = residual
        
        self.Z = nn.Linear(out_feature,out_feature) ; self.UZ = nn.Linear(in_features,out_feature, bias=False)
        self.G = nn.Linear(out_feature,out_feature) ; self.UG = nn.Linear(in_features,out_feature, bias=False)
        self.R = nn.Linear(out_feature,out_feature) ; self.UR = nn.Linear(in_features,out_feature, bias=False)
        self.H = nn.Linear(out_feature,out_feature) ; self.UH = nn.Linear(in_features,out_feature, bias=False)
    

    def forward(self, x, s):
        z = torch.relu(self.UZ(x)+self.Z(s))
        g = torch.relu(self.UG(x)+self.G(s))
        r = torch.relu(self.UR(x)+self.R(s))
        h = torch.relu(self.UH(x)+self.H(s*r))
        return (1 - g) * h + z*s
        
    


In [3]:
class DGM_net(nn.Module):
    def __init__(self, in_dim,out_dim, n_layers, n_neurons, residual = False):
        """ in_dim is number of cordinates + 1 
            out_dim is the number of output
            n_layers and n_neurons are pretty self explanatory
            make residual = true for identity between each DGM layers
        """
        super(DGM_net, self).__init__()
        self.in_dim = in_dim ; self.out_dim = out_dim
        self.n_layers = n_layers
        self.n_neurons = n_neurons
        self.residual = residual

        self.first_layer = nn.Linear(in_dim, n_neurons)
        
        self.dgm_layers = nn.ModuleList([DGM_layer(self.in_dim, self.n_neurons,
                                                       self.residual) for i in range(self.n_layers)])
        self.final_layer = nn.Linear(n_neurons,out_dim)
    
    def forward(self,x):
        s = torch.relu(self.first_layer(x))
        for i,dgm_layer in enumerate(self.dgm_layers):
            s = dgm_layer(x, s)
        
        return  self.final_layer(s)

In [11]:
# Time limits
T0 = 0.0 + 1e-10    # Initial time
T  = 1.0            # Terminal time

# Space limits
S_1 = 0.0 + 1e-10    # Low boundary
S_2 = 1              # High boundary

#viscosity limits
V1 = 1e-2
V2 = 1e-1

#alpha ??
al1 = 1e-2
al2 = 1

#Boundary condition for x = 0
a1 = -1
a2 = 1

#Boundary condition for x = 1
b1 = -1
b2 = 1

#initial condition
def g(x, a, b) : return a + x *(b-a)


Definition of domain sampling and loss function from https://github.com/adolfocorreia/DGM/, adapted to pytorch

In [12]:
def sampler_space_time(N1, N2, N3):
    # Sampler #1: PDE domain
    t1 = np.random.uniform(low=T0 - 0.5*(T - T0),
                           high=T,
                           size=[N1,1])
    s1 = np.random.uniform(low=S_1 - (S_2 - S_1)*0.5,
                           high=S_2 + (S_2 - S_1)*0.5,
                           size=[N1,1])

    # Sampler #2: boundary condition x=0
    t2 = np.random.uniform(low=T0 - 0.5*(T - T0),
                           high=T,
                           size=[N2,1])
    s2 = np.zeros(shape=(N2, 1))
    
    # Sampler #3: boundary x=1
    t3 = np.random.uniform(low=T0 - 0.5*(T - T0),
                           high=T,
                           size=[N2,1])
    s3 = np.ones(shape=(N2, 1))
    
    #sampler 4 : initial condition 
    t4 = np.zeros(shape=(N3, 1))
    s4 = np.random.uniform(low=S_1 - (S_2 - S_1)*0.5,
                           high=S_2 + (S_2 - S_1)*0.5,
                           size=[N3,1])
    
    t1=torch.tensor(t1, dtype=torch.float32, requires_grad=True).cuda()
    s1=torch.tensor(s1, dtype=torch.float32, requires_grad=True).cuda()
    t2=torch.tensor(t2, dtype=torch.float32, requires_grad=True).cuda()
    s2=torch.tensor(s2, dtype=torch.float32, requires_grad=True).cuda()
    t3=torch.tensor(t3, dtype=torch.float32, requires_grad=True).cuda()
    s3=torch.tensor(s3, dtype=torch.float32, requires_grad=True).cuda()
    t4=torch.tensor(t4, dtype=torch.float32, requires_grad=True).cuda()
    s4=torch.tensor(s4, dtype=torch.float32, requires_grad=True).cuda()
    return (t1, s1, t2, s2, t3, s3, t4, s4)

In [13]:
def sampler_parameters(N1, N2, N3):
    alpha1 = np.random.uniform(low=al1, high = al2, size=[N1,1])
    alpha2 = np.random.uniform(low=al1, high = al2, size=[N2,1])
    alpha3 = np.random.uniform(low=al1, high = al2, size=[N3,1])
    
    nu1 = np.random.uniform(low=V1, high = V2, size=[N1,1])
    nu2 = np.random.uniform(low=V1, high = V2, size=[N2,1])
    nu3 = np.random.uniform(low=V1, high = V2, size=[N3,1])
    
    a_1 = np.random.uniform(low=a1, high = a2, size=[N1,1])
    a_2 = np.random.uniform(low=a1, high = a2, size=[N2,1])
    a_3 = np.random.uniform(low=a1, high = a2, size=[N3,1])
    
    b_1 = np.random.uniform(low=b1, high = b2, size=[N1,1])
    b_2 = np.random.uniform(low=b1, high = b2, size=[N2,1])
    b_3 = np.random.uniform(low=b1, high = b2, size=[N3,1])
    
    alpha1 = torch.tensor(alpha1, dtype=torch.float32, requires_grad=True).cuda()
    alpha2 = torch.tensor(alpha2, dtype=torch.float32, requires_grad=True).cuda()
    alpha3 = torch.tensor(alpha3, dtype=torch.float32, requires_grad=True).cuda()
    
    nu1 = torch.tensor(nu1, dtype=torch.float32, requires_grad=True).cuda()
    nu2 = torch.tensor(nu2, dtype=torch.float32, requires_grad=True).cuda()
    nu3 = torch.tensor(nu3, dtype=torch.float32, requires_grad=True).cuda()

    a_1 = torch.tensor(a_1, dtype=torch.float32, requires_grad=True).cuda()
    a_2 = torch.tensor(a_2, dtype=torch.float32, requires_grad=True).cuda()
    a_3 = torch.tensor(a_3, dtype=torch.float32, requires_grad=True).cuda()


    b_1 = torch.tensor(b_1, dtype=torch.float32, requires_grad=True).cuda()
    b_2 = torch.tensor(b_2, dtype=torch.float32, requires_grad=True).cuda()
    b_3 = torch.tensor(b_3, dtype=torch.float32, requires_grad=True).cuda()
    
    return (alpha1, alpha2, alpha3, nu1, nu2, nu3, a_1, a_2, a_3, b_1, b_2, b_3)
    

In [14]:
def Loss(model, S1, S2, S3, S4):
    #Each set contain [x,t,nu,alpha,a,b]
    
    # Loss term #1: PDE
    U = model(S1)
    DU = torch.autograd.grad(U.sum(), S1, create_graph=True, retain_graph=True)[0]
    U_t = DU[:,1]
    U_x = DU[:,0]
    U_xx = torch.autograd.grad(U_x.sum(), S1, create_graph = True, retain_graph=True)[0][:,0]
    
    f = U_t - S1[:,2]*U_xx + S1[:,3] * U * U_x
    L1 = torch.mean(torch.pow(f,2))

    # Loss term #2: boundary condition x=0
    Ub1 = model(S2)
    L2 = torch.mean(torch.pow(Ub1[:,0]-S2[:,4], 2))
    
    # Loss term #2: boundary condition x=1
    Ub2 = model(S3)
    L3 = torch.mean(torch.pow(Ub2[:,0]-S3[:,5], 2))
    
    # Loss term #3: initial/terminal condition
    CI = g(S4[:,0], S4[:,4], S4[:,5])
    L4 = torch.mean(torch.pow((model(S4)[:,0] - CI) ,2))

    return L1, L2, L3, L4


Model init and training

In [18]:
model = DGM_net(6,1,6,200)
model.cuda()
opt = torch.optim.Adam(model.parameters(),lr = 0.00001)
scheduler = torch.optim.lr_scheduler.ExponentialLR(opt,gamma=0.99)

loss = nn.MSELoss()
# Number of samples
NS_1 = 1000  #samples on domain
NS_2 = 100  #samples for BC
NS_3 = 100  #samples for IC


# Training parameters
steps_per_sample = 10
sampling_stages = 1000

In [19]:
def train_model(model, optimizer, scheduler, num_epochs=100):
    since = time.time()
    model.train()
  # Set model to training mode
    
    for epoch in range(num_epochs):
        t1, x1, t2, x2, t3, x3, t4, x4 = sampler_space_time(NS_1, NS_2, NS_3) 
        alpha1, alpha2, alpha3, nu1, nu2, nu3, a_1, a_2, a_3, b_1, b_2, b_3 = sampler_parameters(NS_1, NS_2, NS_3)
        
        S1 = torch.cat((x1,t1,nu1, alpha1, a_1, b_1), 1)
        S2 = torch.cat((x2,t2,nu2, alpha2, a_2, b_2), 1)
        S3 = torch.cat((x3,t3,nu2, alpha2, a_2, b_2), 1)
        S4 = torch.cat((x4,t4,nu3, alpha3, a_3, b_3), 1)

        scheduler.step()

        for _ in range(steps_per_sample) :

            # zero the parameter gradients
            optimizer.zero_grad()
        
            # forward
            L1, L2, L3, L4 = Loss(model, S1, S2, S3, S4)
    
            loss = L1 + L2 + L3 + L4
            # backward + optimize
            loss.backward()
            optimizer.step()
            
        epoch +=1 
        if epoch % (num_epochs//num_epochs) == 0: print(f'epoch {epoch}, loss {loss.data}, L1 : {L1.data}')
    time_elapsed = time.time() - since
    print(f"Training finished in {time_elapsed:.2f} for {num_epochs}.")
    print(f"The final loss value is {loss.data}")

In [20]:
train_model(model, opt, scheduler, sampling_stages)

epoch 1, loss 1.0830315351486206, L1 : 0.00682768365368247
epoch 2, loss 1.059787631034851, L1 : 0.006554476451128721
epoch 3, loss 1.1586031913757324, L1 : 0.00684357201680541
epoch 4, loss 0.9778107404708862, L1 : 0.006685047876089811
epoch 5, loss 1.0076286792755127, L1 : 0.006252388935536146
epoch 6, loss 1.0590533018112183, L1 : 0.005900115240365267
epoch 7, loss 1.0388705730438232, L1 : 0.005843978375196457
epoch 8, loss 1.0774928331375122, L1 : 0.005365019664168358
epoch 9, loss 1.100426197052002, L1 : 0.005127577111124992
epoch 10, loss 1.0812653303146362, L1 : 0.00538421468809247
epoch 11, loss 0.9620516300201416, L1 : 0.004902665037661791
epoch 12, loss 1.0113743543624878, L1 : 0.005016155540943146
epoch 13, loss 0.872666597366333, L1 : 0.004897871054708958
epoch 14, loss 0.9823595285415649, L1 : 0.004773371387273073
epoch 15, loss 0.9100465774536133, L1 : 0.004677505698055029
epoch 16, loss 1.0025830268859863, L1 : 0.004732554778456688
epoch 17, loss 0.9774953126907349, L1 :

epoch 136, loss 0.15537381172180176, L1 : 0.041822001338005066
epoch 137, loss 0.18847277760505676, L1 : 0.03669014573097229
epoch 138, loss 0.1719084084033966, L1 : 0.03978312388062477
epoch 139, loss 0.17672927677631378, L1 : 0.0429484024643898
epoch 140, loss 0.1905364692211151, L1 : 0.044138599187135696
epoch 141, loss 0.18946939706802368, L1 : 0.044099386781454086
epoch 142, loss 0.17744134366512299, L1 : 0.04297901690006256
epoch 143, loss 0.17004844546318054, L1 : 0.0434664785861969
epoch 144, loss 0.16876475512981415, L1 : 0.04558437317609787
epoch 145, loss 0.18766725063323975, L1 : 0.038686223328113556
epoch 146, loss 0.17345505952835083, L1 : 0.03977752476930618
epoch 147, loss 0.1777910739183426, L1 : 0.03590425103902817
epoch 148, loss 0.14750927686691284, L1 : 0.03582204133272171
epoch 149, loss 0.14637112617492676, L1 : 0.044340286403894424
epoch 150, loss 0.1590231955051422, L1 : 0.04704568162560463
epoch 151, loss 0.15862420201301575, L1 : 0.03906445950269699
epoch 152

epoch 269, loss 0.11384491622447968, L1 : 0.03864597901701927
epoch 270, loss 0.1288173496723175, L1 : 0.03985544666647911
epoch 271, loss 0.12046679854393005, L1 : 0.04160219430923462
epoch 272, loss 0.10944846272468567, L1 : 0.04568585008382797
epoch 273, loss 0.11156962811946869, L1 : 0.03528079017996788
epoch 274, loss 0.11534976959228516, L1 : 0.0402890220284462
epoch 275, loss 0.12725789844989777, L1 : 0.03655015677213669
epoch 276, loss 0.11237985640764236, L1 : 0.04106797277927399
epoch 277, loss 0.1303921639919281, L1 : 0.04163757711648941
epoch 278, loss 0.11814787983894348, L1 : 0.038603633642196655
epoch 279, loss 0.12453020364046097, L1 : 0.038844529539346695
epoch 280, loss 0.1360427290201187, L1 : 0.04300164803862572
epoch 281, loss 0.12813782691955566, L1 : 0.040048740804195404
epoch 282, loss 0.10921104997396469, L1 : 0.038956545293331146
epoch 283, loss 0.1033731997013092, L1 : 0.03553619235754013
epoch 284, loss 0.12595394253730774, L1 : 0.042982738465070724
epoch 28

epoch 402, loss 0.10834585130214691, L1 : 0.04644748196005821
epoch 403, loss 0.11464323848485947, L1 : 0.040640540421009064
epoch 404, loss 0.10928664356470108, L1 : 0.0443536639213562
epoch 405, loss 0.1017591804265976, L1 : 0.03957458212971687
epoch 406, loss 0.11796721816062927, L1 : 0.0433506965637207
epoch 407, loss 0.11394514888525009, L1 : 0.04177413508296013
epoch 408, loss 0.10508839786052704, L1 : 0.04010109230875969
epoch 409, loss 0.10301774740219116, L1 : 0.03818459436297417
epoch 410, loss 0.11977110058069229, L1 : 0.043117210268974304
epoch 411, loss 0.10789680480957031, L1 : 0.045406319200992584
epoch 412, loss 0.1142573356628418, L1 : 0.03822159394621849
epoch 413, loss 0.11379235237836838, L1 : 0.04430948942899704
epoch 414, loss 0.1049923524260521, L1 : 0.034583680331707
epoch 415, loss 0.11682863533496857, L1 : 0.03969103842973709
epoch 416, loss 0.11799569427967072, L1 : 0.03781607002019882
epoch 417, loss 0.09964827448129654, L1 : 0.03807806968688965
epoch 418, l

epoch 534, loss 0.10548698157072067, L1 : 0.04035559296607971
epoch 535, loss 0.09253666549921036, L1 : 0.03782081976532936
epoch 536, loss 0.11286046355962753, L1 : 0.040523890405893326
epoch 537, loss 0.11768227815628052, L1 : 0.041816163808107376
epoch 538, loss 0.11071813851594925, L1 : 0.03852265328168869
epoch 539, loss 0.10973313450813293, L1 : 0.03978806361556053
epoch 540, loss 0.11818664520978928, L1 : 0.0432608425617218
epoch 541, loss 0.11002319306135178, L1 : 0.03809985890984535
epoch 542, loss 0.09968115389347076, L1 : 0.04086151346564293
epoch 543, loss 0.11655101180076599, L1 : 0.0382470041513443
epoch 544, loss 0.1180715560913086, L1 : 0.04143537953495979
epoch 545, loss 0.10814959555864334, L1 : 0.036543577909469604
epoch 546, loss 0.11082404851913452, L1 : 0.03975294530391693
epoch 547, loss 0.1092342957854271, L1 : 0.03993705287575722
epoch 548, loss 0.10782959312200546, L1 : 0.044530902057886124
epoch 549, loss 0.10590334981679916, L1 : 0.043766748160123825
epoch 5

epoch 666, loss 0.10217152535915375, L1 : 0.03666767477989197
epoch 667, loss 0.1118955984711647, L1 : 0.043215952813625336
epoch 668, loss 0.12316031754016876, L1 : 0.045393820852041245
epoch 669, loss 0.1080416589975357, L1 : 0.03777685761451721
epoch 670, loss 0.1157013550400734, L1 : 0.03969554603099823
epoch 671, loss 0.11800475418567657, L1 : 0.0438385009765625
epoch 672, loss 0.12274859845638275, L1 : 0.04138394445180893
epoch 673, loss 0.09798066318035126, L1 : 0.033840153366327286
epoch 674, loss 0.11543938517570496, L1 : 0.042530357837677
epoch 675, loss 0.1080549955368042, L1 : 0.04009093716740608
epoch 676, loss 0.11467138677835464, L1 : 0.04219095781445503
epoch 677, loss 0.10601630806922913, L1 : 0.0393136590719223
epoch 678, loss 0.10633032023906708, L1 : 0.03941129520535469
epoch 679, loss 0.11526603996753693, L1 : 0.04270939901471138
epoch 680, loss 0.1085158959031105, L1 : 0.040919069200754166
epoch 681, loss 0.11737793684005737, L1 : 0.04587242007255554
epoch 682, lo

epoch 799, loss 0.12591156363487244, L1 : 0.04279294237494469
epoch 800, loss 0.09893915057182312, L1 : 0.03258733078837395
epoch 801, loss 0.10843025147914886, L1 : 0.04294093698263168
epoch 802, loss 0.11176396906375885, L1 : 0.04141777381300926
epoch 803, loss 0.09863518178462982, L1 : 0.035752907395362854
epoch 804, loss 0.11901362240314484, L1 : 0.04101540893316269
epoch 805, loss 0.10946939885616302, L1 : 0.043755125254392624
epoch 806, loss 0.09729593247175217, L1 : 0.037458620965480804
epoch 807, loss 0.10401946306228638, L1 : 0.043266598135232925
epoch 808, loss 0.10122592747211456, L1 : 0.0405963733792305
epoch 809, loss 0.11365154385566711, L1 : 0.03979547694325447
epoch 810, loss 0.11340970546007156, L1 : 0.044300589710474014
epoch 811, loss 0.12084729969501495, L1 : 0.04850692301988602
epoch 812, loss 0.10661091655492783, L1 : 0.0432865247130394
epoch 813, loss 0.09974991530179977, L1 : 0.04099353030323982
epoch 814, loss 0.10790377110242844, L1 : 0.047545187175273895
epoc

epoch 931, loss 0.09901013970375061, L1 : 0.037492070347070694
epoch 932, loss 0.11689233779907227, L1 : 0.040932804346084595
epoch 933, loss 0.09781152009963989, L1 : 0.04131195321679115
epoch 934, loss 0.09440305083990097, L1 : 0.03638917952775955
epoch 935, loss 0.10433238744735718, L1 : 0.040520258247852325
epoch 936, loss 0.1041012853384018, L1 : 0.03728873282670975
epoch 937, loss 0.10688337683677673, L1 : 0.03808888420462608
epoch 938, loss 0.10627707093954086, L1 : 0.04027731344103813
epoch 939, loss 0.10453888028860092, L1 : 0.04154544696211815
epoch 940, loss 0.10983450710773468, L1 : 0.03843660280108452
epoch 941, loss 0.11065660417079926, L1 : 0.033649053424596786
epoch 942, loss 0.1070106104016304, L1 : 0.038988273590803146
epoch 943, loss 0.11717423796653748, L1 : 0.04182465746998787
epoch 944, loss 0.1026415005326271, L1 : 0.04622901603579521
epoch 945, loss 0.10141191631555557, L1 : 0.037308111786842346
epoch 946, loss 0.11301898956298828, L1 : 0.0383673831820488
epoch 

In [None]:
# Plot results
N = 41      # Points on plot grid

times_to_plot = [0*T, 0.33*T, 0.66*T, T]
tplot = np.linspace(T0, T, N)
xplot = np.linspace(S1, S2, N)

plt.figure(figsize=(8,7))
i = 1
for t in times_to_plot:
    solution_plot = analytical_solution(t, xplot)

    tt = t*np.ones_like(xplot.reshape(-1,1))
    
    tt_nn = torch.tensor(tt.reshape(-1,1), dtype=torch.float32).cuda()
    xplot_nn = torch.tensor(xplot.reshape(-1,1), dtype=torch.float32).cuda()
    nn_plot = model(tt_nn,xplot_nn).cpu()

    plt.subplot(2,2,i)
    plt.plot(xplot, solution_plot, 'b')
    plt.plot(xplot, nn_plot.data.numpy(), 'r')

    plt.ylim(-1.1, -0.2)
    plt.xlabel("S")
    plt.ylabel("V")
    plt.title("t = %.2f"%t, loc="left")
    i = i+1

plt.show()

In [None]:
U = model(S1)


In [None]:
        t1, x1, t2, x2, t3, x3, t4, x4 = sampler_space_time(NS_1, NS_2, NS_3) 
        alpha1, alpha2, alpha3, nu1, nu2, nu3, a_1, a_2, a_3, b_1, b_2, b_3 = sampler_parameters(NS_1, NS_2, NS_3)
        
        S1 = torch.cat((x1,t1,nu1, alpha1, a_1, b_1), 1)
        S2 = torch.cat((x2,t2,nu2, alpha2, a_2, b_2), 1)
        S3 = torch.cat((x3,t3,nu2, alpha2, a_2, b_2), 1)
        S4 = torch.cat((x4,t4,nu3, alpha3, a_3, b_3), 1)

In [None]:
    CI = g(S4[:,0], S4[:,4], S4[:,5])
    L4 = torch.mean(torch.pow((model(S4)[:,0] - CI) ,2))

In [None]:
(model(S4)[:,0] - CI)

In [None]:
U_t[:,1]

In [None]:
    # Loss term #1: PDE
    U = model(S1)
    DU = torch.autograd.grad(U.sum(), S1, create_graph=True, retain_graph=True)[0]
    U_t = DU[:,1]
    U_x = DU[:,0]
    U_xx = torch.autograd.grad(U_x.sum(), S1, create_graph = True, retain_graph=True)[0][:,0]
    
    f = U_t - S1[:,2]*U_xx + S1[:,3] * U * U_x

In [None]:
U_xx