In [7]:
import numpy as np
import random

import torch
import torch.nn as nn
import torch.nn.functional as func
import torch.optim as optim

from numpy import exp,arange
from pylab import meshgrid,cm,imshow,contour,clabel,colorbar,axis,title,show

In [8]:
# Closed-form solution of u_t = u_xx with initial condition u(x, 0) = sin(pi * x).
def heat_equ_analytical_solu(x, t):
    """Evaluate u(x, t) = sin(pi * x) * exp(-pi**2 * t) element-wise.

    ``x`` and ``t`` may be scalars or numpy-broadcastable arrays; the result
    has their broadcast shape.
    """
    decay_rate = np.power(np.pi, 2)  # each sin(pi x) mode decays at rate pi**2
    spatial_mode = np.sin(np.pi * x)
    return spatial_mode * np.exp(-decay_rate * t)

In [9]:
# Sample the exact solution on a uniform space-time grid.
# Rows of Z are time levels (ascending t), columns are spatial points (ascending x).
# NOTE(review): the time axis starts at t = -0.1, so the earliest rows are the
# analytical solution continued to negative times (amplitude exp(0.1 * pi**2) ~ 2.68).
x = np.arange(-5.0, 5.0, 0.01)
t = np.arange(-0.1, 0.1, 0.001)
X, T = np.meshgrid(x, t)
Z = heat_equ_analytical_solu(X, T)

In [10]:
print(str(Z))

[[-1.64290454e-15 -8.42771624e-02 -1.68471153e-01 ...  2.52498884e-01
   1.68471153e-01  8.42771624e-02]
 [-1.62676947e-15 -8.34494714e-02 -1.66816588e-01 ...  2.50019077e-01
   1.66816588e-01  8.34494714e-02]
 [-1.61079287e-15 -8.26299092e-02 -1.65178273e-01 ...  2.47563625e-01
   1.65178273e-01  8.26299092e-02]
 ...
 [-2.35076019e-16 -1.20588503e-02 -2.41057999e-02 ...  3.61289601e-02
   2.41057999e-02  1.20588503e-02]
 [-2.32767323e-16 -1.19404196e-02 -2.38690554e-02 ...  3.57741354e-02
   2.38690554e-02  1.19404196e-02]
 [-2.30481302e-16 -1.18231520e-02 -2.36346360e-02 ...  3.54227955e-02
   2.36346360e-02  1.18231520e-02]]


In [11]:
first_level = Z[0]  # solution at the earliest time level
print(first_level)

[-1.64290454e-15 -8.42771624e-02 -1.68471153e-01 -2.52498884e-01
 -3.36277428e-01 -4.19724108e-01 -5.02756570e-01 -5.85292872e-01
 -6.67251561e-01 -7.48551753e-01 -8.29113215e-01 -9.08856442e-01
 -9.87702737e-01 -1.06557429e+00 -1.14239425e+00 -1.21808680e+00
 -1.29257725e+00 -1.36579208e+00 -1.43765904e+00 -1.50810721e+00
 -1.57706705e+00 -1.64447052e+00 -1.71025110e+00 -1.77434386e+00
 -1.83668556e+00 -1.89721467e+00 -1.95587147e+00 -2.01259805e+00
 -2.06733844e+00 -2.12003862e+00 -2.17064658e+00 -2.21911237e+00
 -2.26538816e+00 -2.30942829e+00 -2.35118929e+00 -2.39062996e+00
 -2.42771135e+00 -2.46239690e+00 -2.49465235e+00 -2.52444588e+00
 -2.55174809e+00 -2.57653203e+00 -2.59877325e+00 -2.61844979e+00
 -2.63554224e+00 -2.65003372e+00 -2.66190994e+00 -2.67115918e+00
 -2.67777230e+00 -2.68174279e+00 -2.68306672e+00 -2.68174279e+00
 -2.67777230e+00 -2.67115918e+00 -2.66190994e+00 -2.65003372e+00
 -2.63554224e+00 -2.61844979e+00 -2.59877325e+00 -2.57653203e+00
 -2.55174809e+00 -2.52444

In [63]:
print(Z.shape[1])  # spatial points per time level

1000


In [12]:
Z[0].__class__

numpy.ndarray

In [20]:
type(np.concatenate(([0, 0, 0], Z[0], [0, 0, 0])))

numpy.ndarray

In [21]:
Z.shape[0]  # number of time levels

200

In [None]:
# training pairs
def gen_pair(u, x, t, length=3, num=1000):
    """Randomly sample ``num`` (stencil window, next-step value) training pairs.

    Args:
        u: 2-D grid indexed as ``u[time][space]``.
        x: number of spatial points per row of ``u``.
        t: number of time levels (rows) in ``u``.
        length: stencil half-width; each window holds ``2*length + 1`` values.
        num: how many pairs to draw.

    Returns:
        A list of dicts ``{'input': window, 'solu': value}``, where ``window`` is
        ``u[r][p-length : p+length+1]`` and ``value`` is ``u[r+1][p]``.
        Centres ``p`` are drawn so windows never run past the row ends, so no
        padding is involved. Draws come from the global ``random`` module;
        seed it for reproducibility.
    """
    def _sample():
        row = random.randint(0, t - 2)           # leave room for row + 1
        now, later = u[row], u[row + 1]
        centre = random.randint(length, x - 1 - length)
        window = now[centre - length:centre + length + 1]
        return {'input': window, 'solu': later[centre]}

    return [_sample() for _ in range(num)]

In [75]:
# Zero-pad every time level with 3 cells on each side and store it as a plain list.
padding = [np.pad(z, 3).tolist() for z in Z]

In [24]:
first_padded = padding[0]
print(first_padded)

[0.0, 0.0, 0.0, -1.6429045373914556e-15, -0.08427716244111669, -0.16847115349780273, -0.25249888386576, -0.3362774283199015, -0.41972410755155726, -0.5027565697629605, -0.5852928719385486, -0.6672515607127982, -0.7485517527548815, -0.8291132145907395, -0.9088564417838525, -0.987702737396493, -1.0655742896541152, -1.1423942487361962, -1.2180868026177039, -1.2925772518864296, -1.3657920834622823, -1.4376590431458385, -1.5081072069244894, -1.5770670509658917, -1.644470520229607, -1.7102510956291879, -1.7743438596784995, -1.8366855605574355, -1.8972146745338399, -1.9558714666799826, -2.0125980498237364, -2.067338441676218, -2.1200386200795633, -2.1706465763202547, -2.2191123664554526, -2.265388160601644, -2.309428290136949, -2.351189292770546, -2.390629955434699, -2.4277113549570926, -2.4623968964732876, -2.4946523495414414, -2.5244458819236284, -2.551748091000402, -2.5765320327876378, -2.5987732485269937, -2.6184497888237512, -2.635542235308214, -2.6500337197992976, -2.661909940951387, -2

In [30]:
print(padding[0][3:len(padding[0]) - 3])

[-1.6429045373914556e-15, -0.08427716244111669, -0.16847115349780273, -0.25249888386576, -0.3362774283199015, -0.41972410755155726, -0.5027565697629605, -0.5852928719385486, -0.6672515607127982, -0.7485517527548815, -0.8291132145907395, -0.9088564417838525, -0.987702737396493, -1.0655742896541152, -1.1423942487361962, -1.2180868026177039, -1.2925772518864296, -1.3657920834622823, -1.4376590431458385, -1.5081072069244894, -1.5770670509658917, -1.644470520229607, -1.7102510956291879, -1.7743438596784995, -1.8366855605574355, -1.8972146745338399, -1.9558714666799826, -2.0125980498237364, -2.067338441676218, -2.1200386200795633, -2.1706465763202547, -2.2191123664554526, -2.265388160601644, -2.309428290136949, -2.351189292770546, -2.390629955434699, -2.4277113549570926, -2.4623968964732876, -2.4946523495414414, -2.5244458819236284, -2.551748091000402, -2.5765320327876378, -2.5987732485269937, -2.6184497888237512, -2.635542235308214, -2.6500337197992976, -2.661909940951387, -2.67115917836804

In [55]:
len(padding[0]) - 2 * 3  # row length once the 3-cell pads are stripped from both ends

1000

In [25]:
# ResNet-style stencil model: an MLP whose output is added to the mean of its input.
class ResNet(nn.Module):
    """Five-layer ReLU MLP with a mean-of-input skip connection.

    Args:
        i: number of input features (the stencil width).
        a, b, c, d: widths of the four hidden layers.
        o: number of outputs.

    ``forward(x)`` returns ``mlp(x) + mean(x)``, where the mean is taken per
    sample over the last (feature) dimension, so a batch of shape ``(B, i)``
    gives the same rows as calling the model on each sample separately.
    """

    def __init__(self, i, a, b, c, d, o):
        super(ResNet, self).__init__()
        # Attribute names are the state_dict keys of saved checkpoints; keep them stable.
        self.linear1 = nn.Linear(i, a)
        self.relu1 = nn.ReLU()
        self.linear2 = nn.Linear(a, b)
        self.relu2 = nn.ReLU()
        self.linear3 = nn.Linear(b, c)
        self.relu3 = nn.ReLU()
        self.linear4 = nn.Linear(c, d)
        self.relu4 = nn.ReLU()
        self.linear5 = nn.Linear(d, o)
        # NOTE(review): this final ReLU forces the learned correction to be >= 0, so the
        # output can never fall below the input mean. Kept so the loaded checkpoint's
        # predictions are unchanged; revisit when retraining.
        self.relu5 = nn.ReLU()

    def forward(self, x):
        """Map input of shape ``(..., i)`` to ``(..., o)``."""
        out = self.linear1(x)
        out = self.relu1(out)
        out = self.linear2(out)
        out = self.relu2(out)
        out = self.linear3(out)
        out = self.relu3(out)
        out = self.linear4(out)
        out = self.relu4(out)
        out = self.linear5(out)
        out = self.relu5(out)
        # Per-sample skip term: averaging over the last dim only (not the whole tensor)
        # keeps samples in a batch independent of each other.
        return out + x.mean(dim=-1, keepdim=True)

    def load_model(self, save_path, map_location=None):
        """Load weights written by ``save_model``; ``map_location`` is passed to ``torch.load``."""
        # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
        self.load_state_dict(torch.load(save_path, map_location=map_location))

    def save_model(self, save_path):
        """Write this model's ``state_dict`` to ``save_path``."""
        torch.save(self.state_dict(), save_path)

In [27]:
model = ResNet(7, 6, 6, 6, 6, 1)  # 7-value stencil window in, one value out
criterion = nn.MSELoss()
# Only lr is spelled out; betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False are the defaults.
optimizer = optim.Adam(model.parameters(), lr=0.001)
model.train()

ResNet(
  (linear1): Linear(in_features=7, out_features=6, bias=True)
  (relu1): ReLU()
  (linear2): Linear(in_features=6, out_features=6, bias=True)
  (relu2): ReLU()
  (linear3): Linear(in_features=6, out_features=6, bias=True)
  (relu3): ReLU()
  (linear4): Linear(in_features=6, out_features=6, bias=True)
  (relu4): ReLU()
  (linear5): Linear(in_features=6, out_features=1, bias=True)
  (relu5): ReLU()
)

In [28]:
model.load_model(save_path="ResNet + PDE model 1")  # checkpoint path is relative to the notebook's working dir

In [127]:
def calc_next_time(current, length=3, net=None):
    """Predict the next time level by sliding the stencil model over ``current``.

    ``current`` is zero-padded by ``length`` cells on each side. For every grid
    point, the ``2*length + 1`` values centred on it are fed to the model, and
    the model's scalar output becomes that point's value at the next level.

    Args:
        current: 1-D sequence/array of values at the current time level.
        length: stencil half-width; should match the model's input size
            (``2*length + 1``).
        net: callable mapping a 1-D float32 tensor window to a one-element
            tensor. Defaults to the notebook-level ``model``.

    Returns:
        list of Python floats, one per entry of ``current``.
    """
    if net is None:
        net = model
    window = 2 * length + 1
    # Pad and convert once. In padded coordinates the window for point k starts at k.
    # The old code used a hard-coded offset of 3, which broke every length != 3.
    padded = torch.tensor(
        np.pad(np.asarray(current, dtype=float), length), dtype=torch.float32
    )
    # Inference only: no autograd graph is needed for the scalar outputs.
    with torch.no_grad():
        return [net(padded[k:k + window]).item() for k in range(len(current))]

In [132]:
# One-step-ahead predictions. prediction[0] is the exact initial level (rounded to
# float32 like the model outputs). prediction[i + 1] is the model's step from the
# *exact* level Z[i] (teacher forcing), not from prediction[i].
# NOTE(review): for an autoregressive rollout, pass prediction[i] instead of Z[i].
prediction = [torch.tensor(Z[0], dtype=torch.float32).tolist()]
for i in range(len(Z) - 1):
    prediction.append(calc_next_time(Z[i]))
    print("time", i, "is done, an example is:", prediction[i + 1][0])

time 0 is done ", an example is:" -0.07185225933790207
time 1 is done ", an example is:" -0.07115238904953003
time 2 is done ", an example is:" -0.07045936584472656
time 3 is done ", an example is:" -0.06977317482233047
time 4 is done ", an example is:" -0.069093719124794
time 5 is done ", an example is:" -0.06842093914747238
time 6 is done ", an example is:" -0.06775476038455963
time 7 is done ", an example is:" -0.06709513068199158
time 8 is done ", an example is:" -0.06644195318222046
time 9 is done ", an example is:" -0.06579523533582687
time 10 is done ", an example is:" -0.06515484303236008
time 11 is done ", an example is:" -0.06452073901891708
time 12 is done ", an example is:" -0.06389286369085312
time 13 is done ", an example is:" -0.0632711723446846
time 14 is done ", an example is:" -0.06265556067228317
time 15 is done ", an example is:" -0.06204601377248764
time 16 is done ", an example is:" -0.06144243851304054
time 17 is done ", an example is:" -0.06084481254220009
time 

time 145 is done ", an example is:" -0.017254019156098366
time 146 is done ", an example is:" -0.01708456501364708
time 147 is done ", an example is:" -0.016916776075959206
time 148 is done ", an example is:" -0.0167506355792284
time 149 is done ", an example is:" -0.016586126759648323
time 150 is done ", an example is:" -0.016423232853412628
time 151 is done ", an example is:" -0.016261938959360123
time 152 is done ", an example is:" -0.016102230176329613
time 153 is done ", an example is:" -0.015944089740514755
time 154 is done ", an example is:" -0.015787500888109207
time 155 is done ", an example is:" -0.015632452443242073
time 156 is done ", an example is:" -0.015478923916816711
time 157 is done ", an example is:" -0.015326904132962227
time 158 is done ", an example is:" -0.01517637912184
time 159 is done ", an example is:" -0.015027331188321114
time 160 is done ", an example is:" -0.014879746362566948
time 161 is done ", an example is:" -0.014733610674738884
time 162 is done ", a

In [133]:
initial_row = prediction[0]  # exact initial level (float32-rounded)
print(initial_row)

[-1.642904575098849e-15, -0.08427716046571732, -0.16847115755081177, -0.25249889492988586, -0.33627742528915405, -0.41972410678863525, -0.5027565956115723, -0.5852928757667542, -0.6672515869140625, -0.7485517263412476, -0.8291131854057312, -0.9088564515113831, -0.9877027273178101, -1.065574288368225, -1.1423943042755127, -1.218086838722229, -1.2925772666931152, -1.3657920360565186, -1.4376590251922607, -1.5081071853637695, -1.5770670175552368, -1.6444705724716187, -1.7102510929107666, -1.7743438482284546, -1.8366855382919312, -1.897214651107788, -1.9558714628219604, -2.0125980377197266, -2.067338466644287, -2.1200385093688965, -2.1706466674804688, -2.2191123962402344, -2.265388250350952, -2.3094282150268555, -2.351189374923706, -2.390630006790161, -2.427711248397827, -2.4623968601226807, -2.494652271270752, -2.5244457721710205, -2.551748037338257, -2.5765321254730225, -2.598773241043091, -2.6184496879577637, -2.635542154312134, -2.650033712387085, -2.661910057067871, -2.671159267425537

In [134]:
step_one = prediction[1]  # first model step
print(step_one)

[-0.07185225933790207, -0.11939699202775955, -0.17880359292030334, -0.249739870429039, -0.3324926793575287, -0.41544970870018005, -0.4979628026485443, -0.579688310623169, -0.6608420014381409, -0.7413437366485596, -0.8211139440536499, -0.9000741243362427, -0.9781461358070374, -1.055253028869629, -1.1313188076019287, -1.206268310546875, -1.2800275087356567, -1.352523684501648, -1.423685073852539, -1.4934418201446533, -1.5617249011993408, -1.6284668445587158, -1.6936020851135254, -1.7570661306381226, -1.8187962770462036, -1.878731608390808, -1.936813235282898, -1.9929834604263306, -2.047186851501465, -2.0993709564208984, -2.1494827270507812, -2.1974732875823975, -2.243295669555664, -2.2869043350219727, -2.328256607055664, -2.3673110008239746, -2.404029369354248, -2.438375234603882, -2.4703152179718018, -2.4998178482055664, -2.526853084564209, -2.5513947010040283, -2.5734190940856934, -2.5929033756256104, -2.609829902648926, -2.624180555343628, -2.635941505432129, -2.645101308822632, -2.65

In [135]:
step_two = prediction[2]  # second model step
print(step_two)

[-0.07115238904953003, -0.11823411285877228, -0.17705726623535156, -0.24729685485363007, -0.32922497391700745, -0.4113576412200928, -0.4930741488933563, -0.5739971399307251, -0.6543537378311157, -0.7340648770332336, -0.8130515813827515, -0.8912363052368164, -0.9685415029525757, -1.044891357421875, -1.12021005153656, -1.1944234371185303, -1.267458200454712, -1.3392422199249268, -1.4097050428390503, -1.4787765741348267, -1.5463892221450806, -1.612475872039795, -1.676971197128296, -1.739811897277832, -1.8009356260299683, -1.8602824211120605, -1.9177935123443604, -1.9734121561050415, -2.0270836353302, -2.078754425048828, -2.1283740997314453, -2.175893783569336, -2.221266269683838, -2.264446258544922, -2.3053925037384033, -2.3440632820129395, -2.3804211616516113, -2.4144296646118164, -2.4460558891296387, -2.4752683639526367, -2.5020384788513184, -2.526339530944824, -2.548147201538086, -2.5674407482147217, -2.584200382232666, -2.59840989112854, -2.610055446624756, -2.6191253662109375, -2.625

In [136]:
prediction.__class__

list

In [137]:
import csv

# Write one CSV row per time level. newline="" hands line-ending control to the csv
# module; without it, Windows output gets a blank line after every row.
with open("prediction.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerows(prediction)