In [1]:
import csv
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as func
import torch.optim as optim
from matplotlib import pyplot as plt
from numpy import exp, arange
from pylab import meshgrid, cm, imshow, contour, clabel, colorbar, axis, title, show

In [2]:
# ==========
# helper functions
# ==========

# generate a list of floats from p0 (inclusive) up to pn (exclusive) in steps of delta
def gen_list(p0, pn, delta, dig=5):
    """Return the floats p0, p0 + delta, p0 + 2*delta, ... that are strictly below pn.

    After each increment the running value is rounded to `dig` decimal places,
    so floating-point drift does not accumulate (0.2 + 0.1 stays 0.3).

    Args:
        p0: first value (inclusive).
        pn: upper bound (exclusive).
        delta: step size; must be positive whenever p0 < pn.
        dig: number of decimal places kept after each increment.

    Returns:
        list of float.

    Raises:
        ValueError: if delta <= 0 while p0 < pn (the loop would never end).
    """
    if delta <= 0 and p0 < pn:
        raise ValueError(f"delta must be positive, got {delta}")
    values = []
    current = p0
    while current < pn:
        values.append(float(current))
        current = round(current + delta, dig)
    return values

# padding helpers: surround a 1-D sequence with given values (or with zeros)
def padding(origin, a_list, b_list):
    """Concatenate `a_list`, `origin` and `b_list` into one flat numpy array.

    Args:
        origin: sequence placed in the middle.
        a_list: values prepended to `origin`.
        b_list: values appended to `origin`.

    Returns:
        numpy.ndarray built with ``np.hstack``.
    """
    pieces = (a_list, origin, b_list)
    return np.hstack(pieces)

def zero_padding(origin, num):
    """Surround `origin` with `num` zeros on each side via ``padding``.

    Args:
        origin: sequence to pad.
        num: number of zeros added on the left and on the right.

    Returns:
        numpy.ndarray of length len(origin) + 2 * num.
    """
    zeros = [0] * num
    return padding(origin, zeros, zeros)

# random training pairs: (stencil window at time r, value at the window centre at time r+1)
def gen_pair(u, x, t, length=3, num=1000):
    """Sample `num` random (window, next-value) training pairs from `u`.

    For every pair:
    - a time row r is drawn uniformly from [0, t-2];
    - then a centre position c is drawn from [length, x-1-length], both with
      ``random.randint``;
    - the input is u[r][c-length : c+length+1] (2*length+1 values);
    - the target is u[r+1][c].

    Args:
        u: indexable of rows, one per time step (presumably the zero-padded
            grid -- TODO confirm against the caller).
        x: number of entries in each row of `u`.
        t: number of rows in `u`.
        length: half-width of the window.
        num: number of pairs to draw.

    Returns:
        list of dicts with keys 'input' (the window) and 'solu' (the target).
    """
    def draw_one():
        row = random.randint(0, t - 2)
        centre = random.randint(length, x - 1 - length)
        window = u[row][centre - length:centre + length + 1]
        return {'input': window, 'solu': u[row + 1][centre]}

    return [draw_one() for _ in range(num)]


In [3]:
# exact (analytical) solution of the heat equation u_t = u_xx with initial condition u(x, 0) = sin(pi * x)
def heat_equ_analytical_solu(x, t):
    """Exact solution u(x, t) = sin(pi * x) * exp(-pi**2 * t) of u_t = u_xx.

    Evaluated element-wise with numpy, so `x` and `t` may be scalars or
    broadcast-compatible arrays (e.g. the outputs of ``meshgrid``).
    """
    decay = np.exp(-np.power(np.pi, 2) * t)
    return decay * np.sin(np.pi * x)

In [4]:
# residual network: a small MLP whose output is added to the mean of its input window
class ResNet(nn.Module):
    """Small residual MLP that predicts a value from a stencil window.

    forward(x) = ReLU(linear2(ReLU(linear1(x)))) + mean of x over its last axis,
    i.e. the network learns a correction on top of the window average.

    Args:
        i: number of input features (window width).
        h: hidden width.
        o: number of outputs.
    """

    def __init__(self, i, h, o):
        super(ResNet, self).__init__()
        self.linear1 = nn.Linear(i, h)
        self.relu1 = nn.ReLU()
        self.linear2 = nn.Linear(h, o)
        # NOTE(review): this final ReLU forces the learned correction to be >= 0,
        # so a prediction can never fall below the window mean. Presumably kept
        # because the saved checkpoints were trained with it -- confirm.
        self.relu2 = nn.ReLU()

    def forward(self, x):
        """Map x of shape (..., i) to shape (..., o)."""
        out = self.relu1(self.linear1(x))
        out = self.relu2(self.linear2(out))
        # Per-sample mean over the feature axis. A bare torch.mean(x) would
        # average over every element and mix samples when x is a batch; for a
        # 1-D window both give the same value.
        return out + x.mean(dim=-1, keepdim=True)

    def load_model(self, save_path, map_location=None):
        """Load parameters written by ``save_model`` from `save_path`.

        `map_location` is forwarded to ``torch.load`` (e.g. "cpu" to load a
        checkpoint saved on GPU onto a CPU-only machine).
        """
        # torch.load unpickles the file: only load checkpoints from trusted sources.
        self.load_state_dict(torch.load(save_path, map_location=map_location))

    def save_model(self, save_path):
        """Write this module's ``state_dict`` to `save_path` with ``torch.save``."""
        torch.save(self.state_dict(), save_path)

In [5]:
def calc_next_time(current, model, length=3):
    """Predict the next time row by sliding `model` over a zero-padded row.

    `current` is padded with `length` zeros on each side. For each original
    position, the window of 2*length+1 values centred on it is passed to
    `model`. The model's single-element output becomes the prediction for
    that position.

    Args:
        current: 1-D sequence of values at the current time step.
        model: callable taking a float tensor of shape (2*length+1,) and
            returning a one-element tensor (e.g. ``ResNet(2*length+1, h, 1)``).
        length: half-width of the window.

    Returns:
        list of float, same length as `current`.
    """
    zeros = [0] * length
    padded = np.hstack((zeros, current, zeros)).tolist()
    width = 2 * length + 1
    ret = []
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for index in range(len(current)):
            # Position `index` of `current` sits at `index + length` in `padded`,
            # so padded[index : index + width] is the window centred on it.
            # (The offset must follow `length`; a hard-coded 3 is only right
            # for the default.)
            seg = padded[index:index + width]
            out = model(torch.FloatTensor(seg))
            ret.append(out.item())
    return ret

## $\Delta x = \Delta t = \frac{1}{20}$

In [6]:
# Uniform grid on [0, 2*pi) with the same step (1/20) in space and time.
# meshgrid(x, t) yields arrays whose rows index time and whose columns index space.
# NOTE(review): the last x point (6.25) is not a zero of sin(pi*x), so the zero
# padding used by calc_next_time does not match the exact solution at the right edge.
x_1 = np.arange(0, 2 * np.pi, 1 / 20)
t_1 = np.arange(0, 2 * np.pi, 1 / 20)
X_1, T_1 = np.meshgrid(x_1, t_1)
Z_1 = heat_equ_analytical_solu(X_1, T_1)  # exact solution sampled on the grid

In [7]:
# Zero-pad every time row by 3 on each side (3 = default window half-width of calc_next_time).
# NOTE(review): padding_1 is not used by any later cell visible here.
padding_1 = [np.hstack(([0, 0, 0], row, [0, 0, 0])).tolist() for row in Z_1]

In [8]:
model_1 = ResNet(7, 6, 1)  # input 7 = window of 2*3+1 points, 6 hidden units, 1 output
# The optimizer and loss only matter for training; they appear unused in this notebook.
optimizer_1 = optim.Adam(model_1.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
criterion_1 = nn.MSELoss()
# This notebook only runs inference, so use evaluation mode rather than train().
model_1.eval()

ResNet(
  (linear1): Linear(in_features=7, out_features=6, bias=True)
  (relu1): ReLU()
  (linear2): Linear(in_features=6, out_features=1, bias=True)
  (relu2): ReLU()
)

In [9]:
# Load the weights for the dx = 1/20 model (path relative to the working directory).
model_1.load_model(save_path="model delta x=1 20")

In [10]:
# One-step prediction: each row is predicted from the EXACT solution at the
# previous time step (Z_1[i]), not from the previous prediction.
# NOTE(review): confirm this teacher-forced evaluation (rather than an
# autoregressive rollout) is intended.
prediction_1 = [torch.FloatTensor(Z_1[0]).tolist()]  # initial condition, rounded to float32 like the model outputs
for i in range(len(Z_1) - 1):
    prediction_1.append(calc_next_time(Z_1[i], model_1))
    print(f"time {i} is done, an example is: {prediction_1[i + 1][0]}")

time 0 is done ", an example is:" 0.1313488483428955
time 1 is done ", an example is:" 0.08018821477890015
time 2 is done ", an example is:" 0.04895474761724472
time 3 is done ", an example is:" 0.029886778444051743
time 4 is done ", an example is:" 0.018245818093419075
time 5 is done ", an example is:" 0.01113903522491455
time 6 is done ", an example is:" 0.006800359580665827
time 7 is done ", an example is:" 0.004151606000959873
time 8 is done ", an example is:" 0.0025345473550260067
time 9 is done ", an example is:" 0.0015473359962925315
time 10 is done ", an example is:" 0.0009446456097066402
time 11 is done ", an example is:" 0.0005767042748630047
time 12 is done ", an example is:" 0.00035207680775783956
time 13 is done ", an example is:" 0.00021494222164619714
time 14 is done ", an example is:" 0.0001312217937083915
time 15 is done ", an example is:" 8.011063619051129e-05
time 16 is done ", an example is:" 4.8907390009844676e-05
time 17 is done ", an example is:" 2.98578652291325

In [12]:
# Pointwise error of the prediction against the exact solution (rows: time, columns: space).
diff_1 = np.subtract(prediction_1, Z_1).tolist()

In [16]:
# The csv module requires newline=""; without it, blank rows appear on Windows.
with open("prediction_1.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerows(prediction_1)

In [17]:
# The csv module requires newline=""; without it, blank rows appear on Windows.
with open("diff_1.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerows(diff_1)

## $\Delta x = \Delta t = \frac{1}{40}$

In [18]:
# Uniform grid on [0, 2*pi) with the same step (1/40) in space and time.
# meshgrid(x, t) yields arrays whose rows index time and whose columns index space.
x_2 = np.arange(0, 2 * np.pi, 1 / 40)
t_2 = np.arange(0, 2 * np.pi, 1 / 40)
X_2, T_2 = np.meshgrid(x_2, t_2)
Z_2 = heat_equ_analytical_solu(X_2, T_2)  # exact solution sampled on the grid

In [19]:
# Zero-pad every time row by 3 on each side (3 = default window half-width of calc_next_time).
# NOTE(review): padding_2 is not used by any later cell visible here.
padding_2 = [np.hstack(([0, 0, 0], row, [0, 0, 0])).tolist() for row in Z_2]

In [20]:
model_2 = ResNet(7, 6, 1)  # input 7 = window of 2*3+1 points, 6 hidden units, 1 output
# The optimizer and loss only matter for training; they appear unused in this notebook.
optimizer_2 = optim.Adam(model_2.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
criterion_2 = nn.MSELoss()
# This notebook only runs inference, so use evaluation mode rather than train().
model_2.eval()

ResNet(
  (linear1): Linear(in_features=7, out_features=6, bias=True)
  (relu1): ReLU()
  (linear2): Linear(in_features=6, out_features=1, bias=True)
  (relu2): ReLU()
)

In [21]:
# Load the weights for the dx = 1/40 model (path relative to the working directory).
model_2.load_model(save_path="model delta x=1 40")

In [22]:
# One-step prediction: each row is predicted from the EXACT solution at the
# previous time step (Z_2[i]), not from the previous prediction.
# NOTE(review): confirm this teacher-forced evaluation is intended.
prediction_2 = [torch.FloatTensor(Z_2[0]).tolist()]  # initial condition, rounded to float32 like the model outputs
for i in range(len(Z_2) - 1):
    prediction_2.append(calc_next_time(Z_2[i], model_2))
    print(f"time {i} is done, an example is: {prediction_2[i + 1][0]}")

time 0 is done ", an example is:" 0.06690555810928345
time 1 is done ", an example is:" 0.05227624252438545
time 2 is done ", an example is:" 0.04084571450948715
time 3 is done ", an example is:" 0.03191453963518143
time 4 is done ", an example is:" 0.02493622712790966
time 5 is done ", an example is:" 0.019483763724565506
time 6 is done ", an example is:" 0.015223517082631588
time 7 is done ", an example is:" 0.011894799768924713
time 8 is done ", an example is:" 0.009293926879763603
time 9 is done ", an example is:" 0.007261752150952816
time 10 is done ", an example is:" 0.005673923995345831
time 11 is done ", an example is:" 0.004433284979313612
time 12 is done ", an example is:" 0.0034639197401702404
time 13 is done ", an example is:" 0.0027065116446465254
time 14 is done ", an example is:" 0.0021147162187844515
time 15 is done ", an example is:" 0.0016523200320079923
time 16 is done ", an example is:" 0.0012910299701616168
time 17 is done ", an example is:" 0.001008738181553781
ti

time 144 is done ", an example is:" 2.4815850164870583e-17
time 145 is done ", an example is:" 1.9389710782645232e-17
time 146 is done ", an example is:" 1.5150027985867966e-17
time 147 is done ", an example is:" 1.1837379736468678e-17
time 148 is done ", an example is:" 9.249062943799413e-18
time 149 is done ", an example is:" 7.226697335974641e-18
time 150 is done ", an example is:" 5.646534211814592e-18
time 151 is done ", an example is:" 4.411884429517943e-18
time 152 is done ", an example is:" 3.4471980682226223e-18
time 153 is done ", an example is:" 2.693446692236233e-18
time 154 is done ", an example is:" 2.1045077445785846e-18
time 155 is done ", an example is:" 1.644343861092885e-18
time 156 is done ", an example is:" 1.284797710330253e-18
time 157 is done ", an example is:" 1.003868666141039e-18
time 158 is done ", an example is:" 7.843664678303189e-19
time 159 is done ", an example is:" 6.128598219005516e-19
time 160 is done ", an example is:" 4.788541845227e-19
time 161 is

In [23]:
# Pointwise error of the prediction against the exact solution (rows: time, columns: space).
diff_2 = np.subtract(prediction_2, Z_2).tolist()

In [24]:
# The csv module requires newline=""; without it, blank rows appear on Windows.
with open("prediction_2.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerows(prediction_2)

In [25]:
# The csv module requires newline=""; without it, blank rows appear on Windows.
with open("diff_2.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerows(diff_2)

## $\Delta x = \Delta t = \frac{1}{80}$

In [26]:
# Uniform grid on [0, 2*pi) with the same step (1/80) in space and time.
# meshgrid(x, t) yields arrays whose rows index time and whose columns index space.
x_3 = np.arange(0, 2 * np.pi, 1 / 80)
t_3 = np.arange(0, 2 * np.pi, 1 / 80)
X_3, T_3 = np.meshgrid(x_3, t_3)
Z_3 = heat_equ_analytical_solu(X_3, T_3)  # exact solution sampled on the grid

In [27]:
# Zero-pad every time row by 3 on each side (3 = default window half-width of calc_next_time).
# NOTE(review): padding_3 is not used by any later cell visible here.
padding_3 = [np.hstack(([0, 0, 0], row, [0, 0, 0])).tolist() for row in Z_3]

In [28]:
model_3 = ResNet(7, 6, 1)  # input 7 = window of 2*3+1 points, 6 hidden units, 1 output
# The optimizer and loss only matter for training; they appear unused in this notebook.
optimizer_3 = optim.Adam(model_3.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
criterion_3 = nn.MSELoss()
# This notebook only runs inference, so use evaluation mode rather than train().
model_3.eval()

ResNet(
  (linear1): Linear(in_features=7, out_features=6, bias=True)
  (relu1): ReLU()
  (linear2): Linear(in_features=6, out_features=1, bias=True)
  (relu2): ReLU()
)

In [29]:
# Load the weights for the dx = 1/80 model (path relative to the working directory).
model_3.load_model(save_path="model delta x=1 80")

In [31]:
# One-step prediction: each row is predicted from the EXACT solution at the
# previous time step (Z_3[i]), not from the previous prediction.
# NOTE(review): confirm this teacher-forced evaluation is intended.
prediction_3 = [torch.FloatTensor(Z_3[0]).tolist()]  # initial condition, rounded to float32 like the model outputs
for i in range(len(Z_3) - 1):
    prediction_3.append(calc_next_time(Z_3[i], model_3))
    print(f"time {i} is done, an example is: {prediction_3[i + 1][0]}")

time 0 is done ", an example is:" 0.03360804542899132
time 1 is done ", an example is:" 0.029707375913858414
time 2 is done ", an example is:" 0.02625943534076214
time 3 is done ", an example is:" 0.02321167290210724
time 4 is done ", an example is:" 0.02051764354109764
time 5 is done ", an example is:" 0.018136294558644295
time 6 is done ", an example is:" 0.016031334176659584
time 7 is done ", an example is:" 0.01417068112641573
time 8 is done ", an example is:" 0.012525982223451138
time 9 is done ", an example is:" 0.011072171851992607
time 10 is done ", an example is:" 0.009787097573280334
time 11 is done ", an example is:" 0.008651172742247581
time 12 is done ", an example is:" 0.00764708686619997
time 13 is done ", an example is:" 0.006759539246559143
time 14 is done ", an example is:" 0.005975003819912672
time 15 is done ", an example is:" 0.005281523335725069
time 16 is done ", an example is:" 0.004668531473726034
time 17 is done ", an example is:" 0.004126685671508312
time 18 

time 144 is done ", an example is:" 6.472569702431485e-10
time 145 is done ", an example is:" 5.721340623487947e-10
time 146 is done ", an example is:" 5.057301799560321e-10
time 147 is done ", an example is:" 4.470333270223392e-10
time 148 is done ", an example is:" 3.951491078790781e-10
time 149 is done ", an example is:" 3.4928668291023257e-10
time 150 is done ", an example is:" 3.08747249988528e-10
time 151 is done ", an example is:" 2.7291299775633604e-10
time 152 is done ", an example is:" 2.412377242411168e-10
time 153 is done ", an example is:" 2.1323885712742907e-10
time 154 is done ", an example is:" 1.884895989290314e-10
time 155 is done ", an example is:" 1.6661283464003418e-10
time 156 is done ", an example is:" 1.4727517816393032e-10
time 157 is done ", an example is:" 1.3018189304325745e-10
time 158 is done ", an example is:" 1.150725209564385e-10
time 159 is done ", an example is:" 1.0171680858706011e-10
time 160 is done ", an example is:" 8.991119654355728e-11
time 161

time 286 is done ", an example is:" 1.5953944970873552e-17
time 287 is done ", an example is:" 1.4102274699893998e-17
time 288 is done ", an example is:" 1.2465513396681233e-17
time 289 is done ", an example is:" 1.101872322012801e-17
time 290 is done ", an example is:" 9.739852362284276e-18
time 291 is done ", an example is:" 8.609409966376444e-18
time 292 is done ", an example is:" 7.610172477689937e-18
time 293 is done ", an example is:" 6.7269086060155075e-18
time 294 is done ", an example is:" 5.946160122786919e-18
time 295 is done ", an example is:" 5.256028034892596e-18
time 296 is done ", an example is:" 4.645994740843926e-18
time 297 is done ", an example is:" 4.106764384917039e-18
time 298 is done ", an example is:" 3.6301189277262206e-18
time 299 is done ", an example is:" 3.208794689517489e-18
time 300 is done ", an example is:" 2.8363706807859017e-18
time 301 is done ", an example is:" 2.5071716153487323e-18
time 302 is done ", an example is:" 2.2161804359956935e-18
time 3

time 427 is done ", an example is:" 4.448753957900279e-25
time 428 is done ", an example is:" 3.9324158992253126e-25
time 429 is done ", an example is:" 3.4760056313677233e-25
time 430 is done ", an example is:" 3.0725686058908894e-25
time 431 is done ", an example is:" 2.71595540551958e-25
time 432 is done ", an example is:" 2.4007320112119873e-25
time 433 is done ", an example is:" 2.1220947530933805e-25
time 434 is done ", an example is:" 1.8757969710438255e-25
time 435 is done ", an example is:" 1.658085289528184e-25
time 436 is done ", an example is:" 1.4656421786614524e-25
time 437 is done ", an example is:" 1.2955346782499223e-25
time 438 is done ", an example is:" 1.145170394143811e-25
time 439 is done ", an example is:" 1.0122578703027262e-25
time 440 is done ", an example is:" 8.947717063525129e-26
time 441 is done ", an example is:" 7.909213112978359e-26
time 442 is done ", an example is:" 6.991242178368703e-26
time 443 is done ", an example is:" 6.179813688282548e-26
time 4

In [32]:
# Pointwise error of the prediction against the exact solution (rows: time, columns: space).
diff_3 = np.subtract(prediction_3, Z_3).tolist()

In [33]:
# The csv module requires newline=""; without it, blank rows appear on Windows.
with open("prediction_3.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerows(prediction_3)

In [34]:
# The csv module requires newline=""; without it, blank rows appear on Windows.
with open("diff_3.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerows(diff_3)

## $\Delta x = \Delta t = \frac{1}{160}$

In [35]:
# Uniform grid on [0, 2*pi) with the same step (1/160) in space and time.
# meshgrid(x, t) yields arrays whose rows index time and whose columns index space.
x_4 = np.arange(0, 2 * np.pi, 1 / 160)
t_4 = np.arange(0, 2 * np.pi, 1 / 160)
X_4, T_4 = np.meshgrid(x_4, t_4)
Z_4 = heat_equ_analytical_solu(X_4, T_4)  # exact solution sampled on the grid

In [36]:
# Zero-pad every time row by 3 on each side (3 = default window half-width of calc_next_time).
# NOTE(review): padding_4 is not used by any later cell visible here.
padding_4 = [np.hstack(([0, 0, 0], row, [0, 0, 0])).tolist() for row in Z_4]

In [37]:
model_4 = ResNet(7, 6, 1)  # input 7 = window of 2*3+1 points, 6 hidden units, 1 output
# The optimizer and loss only matter for training; they appear unused in this notebook.
optimizer_4 = optim.Adam(model_4.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
criterion_4 = nn.MSELoss()
# This notebook only runs inference, so use evaluation mode rather than train().
model_4.eval()

ResNet(
  (linear1): Linear(in_features=7, out_features=6, bias=True)
  (relu1): ReLU()
  (linear2): Linear(in_features=6, out_features=1, bias=True)
  (relu2): ReLU()
)

In [38]:
# Load the weights for the dx = 1/160 model (path relative to the working directory).
model_4.load_model(save_path="model delta x=1 160")

In [None]:
# One-step prediction: each row is predicted from the EXACT solution at the
# previous time step (Z_4[i]), not from the previous prediction.
# NOTE(review): confirm this teacher-forced evaluation is intended.
prediction_4 = [torch.FloatTensor(Z_4[0]).tolist()]  # initial condition, rounded to float32 like the model outputs
for i in range(len(Z_4) - 1):
    prediction_4.append(calc_next_time(Z_4[i], model_4))
    print(f"time {i} is done, an example is: {prediction_4[i + 1][0]}")

time 0 is done ", an example is:" 0.01682347245514393
time 1 is done ", an example is:" 0.015817075967788696
time 2 is done ", an example is:" 0.014870882034301758
time 3 is done ", an example is:" 0.013981291092932224
time 4 is done ", an example is:" 0.01314491592347622
time 5 is done ", an example is:" 0.012358573265373707
time 6 is done ", an example is:" 0.011619269847869873
time 7 is done ", an example is:" 0.010924193076789379
time 8 is done ", an example is:" 0.010270697996020317
time 9 is done ", an example is:" 0.00965629331767559
time 10 is done ", an example is:" 0.009078643284738064
time 11 is done ", an example is:" 0.008535549975931644
time 12 is done ", an example is:" 0.008024944923818111
time 13 is done ", an example is:" 0.007544883992522955
time 14 is done ", an example is:" 0.007093541324138641
time 15 is done ", an example is:" 0.006669198628515005
time 16 is done ", an example is:" 0.006270240526646376
time 17 is done ", an example is:" 0.0058951484970748425
time

time 144 is done ", an example is:" 2.3347072328760987e-06
time 145 is done ", an example is:" 2.1950424979877425e-06
time 146 is done ", an example is:" 2.0637328361772234e-06
time 147 is done ", an example is:" 1.9402782527322415e-06
time 148 is done ", an example is:" 1.8242087662656559e-06
time 149 is done ", an example is:" 1.715082703412918e-06
time 150 is done ", an example is:" 1.6124847661558306e-06
time 151 is done ", an example is:" 1.5160242128331447e-06
time 152 is done ", an example is:" 1.4253340623326949e-06
time 153 is done ", an example is:" 1.340069275101996e-06
time 154 is done ", an example is:" 1.259904820472002e-06
time 155 is done ", an example is:" 1.184536017717619e-06
time 156 is done ", an example is:" 1.113675693886762e-06
time 157 is done ", an example is:" 1.0470546385477064e-06
time 158 is done ", an example is:" 9.844187616181443e-07
time 159 is done ", an example is:" 9.255298323296302e-07
time 160 is done ", an example is:" 8.70163603394758e-07
time 1

time 285 is done ", an example is:" 3.898679434843899e-10
time 286 is done ", an example is:" 3.6654565493954294e-10
time 287 is done ", an example is:" 3.446185281585912e-10
time 288 is done ", an example is:" 3.240030743700828e-10
time 289 is done ", an example is:" 3.046209118284793e-10
time 290 is done ", an example is:" 2.863981551914918e-10
time 291 is done ", an example is:" 2.6926549878680817e-10
time 292 is done ", an example is:" 2.5315777252288285e-10
time 293 is done ", an example is:" 2.380136088220297e-10
time 294 is done ", an example is:" 2.237754148648463e-10
time 295 is done ", an example is:" 2.103889146232163e-10
time 296 is done ", an example is:" 1.978032321270362e-10
time 297 is done ", an example is:" 1.8597044737500568e-10
time 298 is done ", an example is:" 1.7484549919011272e-10
time 299 is done ", an example is:" 1.643860603195435e-10
time 300 is done ", an example is:" 1.5455232926786522e-10
time 301 is done ", an example is:" 1.4530684988578457e-10
time 30

time 427 is done ", an example is:" 6.120870818656851e-14
time 428 is done ", an example is:" 5.754714060965047e-14
time 429 is done ", an example is:" 5.4104608976628804e-14
time 430 is done ", an example is:" 5.0868021546270764e-14
time 431 is done ", an example is:" 4.7825042130732534e-14
time 432 is done ", an example is:" 4.496409687182283e-14
time 433 is done ", an example is:" 4.2274296313971754e-14
time 434 is done ", an example is:" 3.974540829917647e-14
time 435 is done ", an example is:" 3.736779359249723e-14
time 436 is done ", an example is:" 3.5132416046452716e-14
time 437 is done ", an example is:" 3.3030761285857135e-14
time 438 is done ", an example is:" 3.1054826543424827e-14
time 439 is done ", an example is:" 2.919709355471596e-14
time 440 is done ", an example is:" 2.7450492982752754e-14
time 441 is done ", an example is:" 2.580837561889926e-14
time 442 is done ", an example is:" 2.426449374813653e-14
time 443 is done ", an example is:" 2.281296557367883e-14
time 4

time 569 is done ", an example is:" 9.609679776467418e-18
time 570 is done ", an example is:" 9.034818202064503e-18
time 571 is done ", an example is:" 8.49434666162848e-18
time 572 is done ", an example is:" 7.986205475434093e-18
time 573 is done ", an example is:" 7.508463176751029e-18
time 574 is done ", an example is:" 7.05929914105106e-18
time 575 is done ", an example is:" 6.637003999598344e-18
time 576 is done ", an example is:" 6.2399717812336104e-18
time 577 is done ", an example is:" 5.866690399797112e-18
time 578 is done ", an example is:" 5.515738758996482e-18
time 579 is done ", an example is:" 5.185781375732755e-18
time 580 is done ", an example is:" 4.875563417016687e-18
time 581 is done ", an example is:" 4.583901187391714e-18
time 582 is done ", an example is:" 4.309687505607935e-18
time 583 is done ", an example is:" 4.051878056142002e-18
time 584 is done ", an example is:" 3.809490562016507e-18
time 585 is done ", an example is:" 3.581602716848454e-18
time 586 is don

time 710 is done ", an example is:" 1.6047007019392884e-21
time 711 is done ", an example is:" 1.5087058465918519e-21
time 712 is done ", an example is:" 1.418453393763445e-21
time 713 is done ", an example is:" 1.3336001321623113e-21
time 714 is done ", an example is:" 1.2538227424132807e-21
time 715 is done ", an example is:" 1.178817797057769e-21
time 716 is done ", an example is:" 1.1082997410698606e-21
time 717 is done ", an example is:" 1.0420001850369376e-21
time 718 is done ", an example is:" 9.796666934693292e-22
time 719 is done ", an example is:" 9.210618760325492e-22
time 720 is done ", an example is:" 8.65963084624708e-22
time 721 is done ", an example is:" 8.141602424313014e-22
time 722 is done ", an example is:" 7.654562983090927e-22
time 723 is done ", an example is:" 7.1966601509576145e-22
time 724 is done ", an example is:" 6.766148084066513e-22
time 725 is done ", an example is:" 6.361390495573578e-22
time 726 is done ", an example is:" 5.980846014378879e-22
time 727

In [None]:
# Pointwise error of the prediction against the exact solution (rows: time, columns: space).
diff_4 = np.subtract(prediction_4, Z_4).tolist()

In [None]:
# The csv module requires newline=""; without it, blank rows appear on Windows.
with open("prediction_4.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerows(prediction_4)

In [None]:
# The csv module requires newline=""; without it, blank rows appear on Windows.
with open("diff_4.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerows(diff_4)