In [1]:
import numpy as np
import random

import torch
import torch.nn as nn
import torch.nn.functional as func
import torch.optim as optim

from matplotlib import pyplot as plt

from numpy import exp,arange
from pylab import meshgrid,cm,imshow,contour,clabel,colorbar,axis,title,show

In [39]:
# ==========
# helping methods
# ==========

# generate a list from lower and upper bound
def gen_list(p0, pn, delta, dig=5):
    ret = []
    i = p0
    while i < pn:
        ret.append(float(i))
        i += delta
        i = round(i, dig)
    return ret

# padding and zero padding
def padding(origin, a_list, b_list):
    return np.hstack((a_list, origin, b_list))

def zero_padding(origin, num):
    zero_list = [0 for i in range(num)]
    return padding(origin, zero_list, zero_list)

# trainning pairs
def gen_pair(u, x, t, length=3, num=1000):
    pairs = []
    for i in range(num):
        r = random.randint(0, t-2)
        current_t = u[r]
        next_t = u[r+1]
        p = random.randint(length, x-1-length)
        train = current_t[p-length:p+length+1]
        solu = next_t[p]
        pair = {'input': train, 'solu': solu}
        pairs.append(pair)
    return pairs

In [3]:
# the analytical representation of exact solution
def heat_equ_analytical_solu(x, t):
    return np.sin(np.pi * x) * np.exp(-np.power(np.pi, 2) * t)

In [4]:
x = arange(-5.0,5.0,0.01)
t = arange(-0.1,0.1,0.001)
X,T = meshgrid(x, t) # grid of point
Z = heat_equ_analytical_solu(X, T) # evaluation of the function on the grid

In [5]:
print(Z)

[[-1.64290454e-15 -8.42771624e-02 -1.68471153e-01 ...  2.52498884e-01
   1.68471153e-01  8.42771624e-02]
 [-1.62676947e-15 -8.34494714e-02 -1.66816588e-01 ...  2.50019077e-01
   1.66816588e-01  8.34494714e-02]
 [-1.61079287e-15 -8.26299092e-02 -1.65178273e-01 ...  2.47563625e-01
   1.65178273e-01  8.26299092e-02]
 ...
 [-2.35076019e-16 -1.20588503e-02 -2.41057999e-02 ...  3.61289601e-02
   2.41057999e-02  1.20588503e-02]
 [-2.32767323e-16 -1.19404196e-02 -2.38690554e-02 ...  3.57741354e-02
   2.38690554e-02  1.19404196e-02]
 [-2.30481302e-16 -1.18231520e-02 -2.36346360e-02 ...  3.54227955e-02
   2.36346360e-02  1.18231520e-02]]


In [6]:
# preparing the trainning set
data_pairs = gen_pair(Z, len(Z[0]), len(Z), length=3, num=2000)
print(data_pairs)

[{'input': array([-0.72736182, -0.66929801, -0.61057368, -0.55124679, -0.49137589,
       -0.43102006, -0.37023886]), 'solu': -0.5458329642145223}, {'input': array([-0.36699635, -0.34017647, -0.31302087, -0.28555637, -0.25781005,
       -0.2298093 , -0.20158177]), 'solu': -0.2827518997350829}, {'input': array([-2.13024226, -2.15568366, -2.17899765, -2.20016125, -2.21915354,
       -2.2359558 , -2.25055144]), 'solu': -2.1785533303110944}, {'input': array([2.52295465, 2.50136231, 2.47730144, 2.45079576, 2.42187145,
       2.39055704, 2.35688344]), 'solu': 2.4267263507717534}, {'input': array([-0.85559288, -0.85177562, -0.84711776, -0.84162389, -0.83529945,
       -0.82815066, -0.82018459]), 'solu': -0.8333582547657723}, {'input': array([-0.01426176, -0.02850945, -0.04272901, -0.05690639, -0.07102762,
       -0.08507875, -0.09904592]), 'solu': -0.05634751286257549}, {'input': array([-0.09693641, -0.12909969, -0.16113556, -0.19301241, -0.22469878,
       -0.25616339, -0.28737521]), 'solu':

In [7]:
# restnet
class ResNet(nn.Module):
    def __init__(self, i, a, b, c, d, o):
        super(ResNet, self).__init__()
        self.linear1 = nn.Linear(i, a)
        self.relu1 = nn.ReLU()
        self.linear2 = nn.Linear(a, b)
        self.relu2 = nn.ReLU()
        self.linear3 = nn.Linear(b, c)
        self.relu3 = nn.ReLU()
        self.linear4 = nn.Linear(c, d)
        self.relu4 = nn.ReLU()
        self.linear5 = nn.Linear(d, o)
        self.relu5 = nn.ReLU()
        
    def forward(self, x):
        out = self.linear1(x)
        out = self.relu1(out)
        out = self.linear2(out)
        out = self.relu2(out)
        out = self.linear3(out)
        out = self.relu3(out)
        out = self.linear4(out)
        out = self.relu4(out)
        out = self.linear5(out)
        out = self.relu5(out)
        return out + torch.mean(x)

    def load_model(self, save_path):
        self.load_state_dict(torch.load(save_path))

    def save_model(self, save_path):
        torch.save(self.state_dict(), save_path)

In [8]:
model = ResNet(7, 6, 6, 6, 6, 1)
optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
criterion = nn.MSELoss()
model.train()

ResNet(
  (linear1): Linear(in_features=7, out_features=6, bias=True)
  (relu1): ReLU()
  (linear2): Linear(in_features=6, out_features=6, bias=True)
  (relu2): ReLU()
  (linear3): Linear(in_features=6, out_features=6, bias=True)
  (relu3): ReLU()
  (linear4): Linear(in_features=6, out_features=6, bias=True)
  (relu4): ReLU()
  (linear5): Linear(in_features=6, out_features=1, bias=True)
  (relu5): ReLU()
)

In [9]:
num_epochs = 4
list_of_loss = []
counter = 0
for epoch in range(num_epochs):
    for data in data_pairs:
        output = model(torch.FloatTensor(data["input"]))
        loss = criterion(output, torch.FloatTensor([data["solu"]]))
        list_of_loss.append(loss.item())
        model.zero_grad()
        loss.backward()
        optimizer.step()
        counter += 1
print(list_of_loss)

[0.04227515682578087, 0.042122501879930496, 0.03450505807995796, 0.04867125302553177, 0.03725232183933258, 0.038837384432554245, 0.03736836835741997, 0.04180014878511429, 0.04073458909988403, 0.033674456179142, 0.035494379699230194, 0.029817277565598488, 0.031448978930711746, 0.03379759564995766, 0.02761015109717846, 0.032654277980327606, 0.027483930811285973, 0.030031081289052963, 0.02819783426821232, 0.0244743749499321, 0.028086695820093155, 0.024027273058891296, 0.026283102110028267, 0.020436014980077744, 0.01959952898323536, 0.019319657236337662, 0.021920014172792435, 0.018743930384516716, 0.020334351807832718, 0.018080100417137146, 0.016916368156671524, 0.019207479432225227, 0.019914235919713974, 0.01566602662205696, 0.016796419396996498, 0.017631307244300842, 0.011849958449602127, 0.01426957082003355, 0.014216546900570393, 0.011330414563417435, 0.013449237681925297, 0.01199087779968977, 0.01115249004215002, 0.012428799644112587, 0.011370420455932617, 0.010556136257946491, 0.00992

In [10]:
# px = [i for i in range(len(list_of_loss))]
# print(len(list_of_loss))
# print(len(px))
# plt.plot(px, list_of_loss)

In [11]:
# preparing the test set
data_pairs = gen_pair(Z, len(Z[0]), len(Z), length=3, num=500)
print(data_pairs)

[{'input': array([0.97001613, 0.97943744, 0.98789217, 0.99537197, 1.00186946,
       1.00737822, 1.01189282]), 'solu': 0.985596362648952}, {'input': array([-0.41177806, -0.39690384, -0.38163791, -0.36599536, -0.34999161,
       -0.33364246, -0.31696405]), 'solu': -0.3624008966435004}, {'input': array([-0.09837294, -0.11783369, -0.13717816, -0.15638725, -0.175442  ,
       -0.19432361, -0.21301345]), 'solu': -0.1548513579665046}, {'input': array([0.47651407, 0.48062745, 0.48426651, 0.48742765, 0.49010776,
       0.4923042 , 0.49401479]), 'solu': 0.4826405939303831}, {'input': array([0.30598212, 0.33541115, 0.36450918, 0.39324748, 0.42159769,
       0.44953183, 0.47702234]), 'solu': 0.38938537054313793}, {'input': array([0.30482405, 0.31876246, 0.33238629, 0.3456821 , 0.35863676,
       0.37123749, 0.38347185]), 'solu': 0.34228713501679414}, {'input': array([-1.41996482, -1.45166953, -1.48194161, -1.5107512 , -1.53806986,
       -1.56387063, -1.58812805]), 'solu': -1.4959140232437436}, {

In [25]:
errs = []
for data in data_pairs:
    output = model(torch.FloatTensor(data["input"]))
    err = output-torch.FloatTensor([data["solu"]])
    err = err.item()
    errs.append(err)
    print(output, "compare to", torch.FloatTensor([data["solu"]]), "giving...", err)

tensor([0.9934], grad_fn=<AddBackward0>) compare to tensor([0.9856]) giving... 0.0078119635581970215
tensor([-0.3627], grad_fn=<AddBackward0>) compare to tensor([-0.3624]) giving... -0.00032708048820495605
tensor([-0.1550], grad_fn=<AddBackward0>) compare to tensor([-0.1549]) giving... -0.00013880431652069092
tensor([0.4865], grad_fn=<AddBackward0>) compare to tensor([0.4826]) giving... 0.0038254857063293457
tensor([0.3925], grad_fn=<AddBackward0>) compare to tensor([0.3894]) giving... 0.0030863583087921143
tensor([0.3450], grad_fn=<AddBackward0>) compare to tensor([0.3423]) giving... 0.0027129948139190674
tensor([-1.4961], grad_fn=<AddBackward0>) compare to tensor([-1.4959]) giving... -0.0001703500747680664
tensor([-0.3841], grad_fn=<AddBackward0>) compare to tensor([-0.3838]) giving... -0.00026866793632507324
tensor([2.4149], grad_fn=<AddBackward0>) compare to tensor([2.3959]) giving... 0.018990039825439453
tensor([-0.2856], grad_fn=<AddBackward0>) compare to tensor([-0.2855]) giving

tensor([0.2984], grad_fn=<AddBackward0>) compare to tensor([0.2960]) giving... 0.002346217632293701
tensor([0.6632], grad_fn=<AddBackward0>) compare to tensor([0.6580]) giving... 0.005215108394622803
tensor([-0.5398], grad_fn=<AddBackward0>) compare to tensor([-0.5396]) giving... -0.00017088651657104492
tensor([-0.1231], grad_fn=<AddBackward0>) compare to tensor([-0.1229]) giving... -0.00024928897619247437
tensor([-1.5641], grad_fn=<AddBackward0>) compare to tensor([-1.5639]) giving... -0.0002014636993408203
tensor([2.4879], grad_fn=<AddBackward0>) compare to tensor([2.4684]) giving... 0.01956462860107422
tensor([1.7612], grad_fn=<AddBackward0>) compare to tensor([1.7474]) giving... 0.013850092887878418
tensor([-0.6584], grad_fn=<AddBackward0>) compare to tensor([-0.6582]) giving... -0.00020968914031982422
tensor([0.4817], grad_fn=<AddBackward0>) compare to tensor([0.4779]) giving... 0.003788203001022339
tensor([-0.3577], grad_fn=<AddBackward0>) compare to tensor([-0.3574]) giving... -

tensor([-0.1117], grad_fn=<AddBackward0>) compare to tensor([-0.1113]) giving... -0.00036346912384033203
tensor([-0.4697], grad_fn=<AddBackward0>) compare to tensor([-0.4695]) giving... -0.00018128752708435059
tensor([-0.1162], grad_fn=<AddBackward0>) compare to tensor([-0.1158]) giving... -0.0004029273986816406
tensor([0.4618], grad_fn=<AddBackward0>) compare to tensor([0.4582]) giving... 0.0036314427852630615
tensor([0.8597], grad_fn=<AddBackward0>) compare to tensor([0.8530]) giving... 0.006760597229003906
tensor([0.8518], grad_fn=<AddBackward0>) compare to tensor([0.8451]) giving... 0.006698548793792725
tensor([0.7598], grad_fn=<AddBackward0>) compare to tensor([0.7539]) giving... 0.005975186824798584
tensor([-1.7093], grad_fn=<AddBackward0>) compare to tensor([-1.7091]) giving... -0.00016939640045166016
tensor([2.2609], grad_fn=<AddBackward0>) compare to tensor([2.2431]) giving... 0.01777935028076172
tensor([1.2369], grad_fn=<AddBackward0>) compare to tensor([1.2272]) giving... 0.

tensor([1.8187], grad_fn=<AddBackward0>) compare to tensor([1.8044]) giving... 0.014301419258117676
tensor([1.5530], grad_fn=<AddBackward0>) compare to tensor([1.5408]) giving... 0.012212395668029785
tensor([1.0777], grad_fn=<AddBackward0>) compare to tensor([1.0692]) giving... 0.00847470760345459
tensor([-0.2738], grad_fn=<AddBackward0>) compare to tensor([-0.2737]) giving... -6.347894668579102e-05
tensor([0.5891], grad_fn=<AddBackward0>) compare to tensor([0.5845]) giving... 0.004632532596588135
tensor([0.5247], grad_fn=<AddBackward0>) compare to tensor([0.5206]) giving... 0.004126071929931641
tensor([-0.9934], grad_fn=<AddBackward0>) compare to tensor([-0.9932]) giving... -0.00018018484115600586
tensor([2.2241], grad_fn=<AddBackward0>) compare to tensor([2.2066]) giving... 0.01748967170715332
tensor([0.6276], grad_fn=<AddBackward0>) compare to tensor([0.6227]) giving... 0.004935503005981445
tensor([-0.4885], grad_fn=<AddBackward0>) compare to tensor([-0.4883]) giving... -0.000203549

tensor([-0.4188], grad_fn=<AddBackward0>) compare to tensor([-0.4187]) giving... -0.0001806318759918213
tensor([-0.3366], grad_fn=<AddBackward0>) compare to tensor([-0.3365]) giving... -0.00011309981346130371
tensor([-1.7831], grad_fn=<AddBackward0>) compare to tensor([-1.7829]) giving... -0.00021266937255859375
tensor([0.1273], grad_fn=<AddBackward0>) compare to tensor([0.1263]) giving... 0.0010007768869400024
tensor([1.3602], grad_fn=<AddBackward0>) compare to tensor([1.3495]) giving... 0.010696768760681152
tensor([0.8335], grad_fn=<AddBackward0>) compare to tensor([0.8269]) giving... 0.00655442476272583
tensor([1.5390], grad_fn=<AddBackward0>) compare to tensor([1.5269]) giving... 0.012102723121643066
tensor([-0.8119], grad_fn=<AddBackward0>) compare to tensor([-0.8117]) giving... -0.0002040266990661621
tensor([-0.6201], grad_fn=<AddBackward0>) compare to tensor([-0.6199]) giving... -0.0001906752586364746
tensor([-0.3503], grad_fn=<AddBackward0>) compare to tensor([-0.3500]) giving.

tensor([-0.3743], grad_fn=<AddBackward0>) compare to tensor([-0.3741]) giving... -0.00025147199630737305
tensor([0.4946], grad_fn=<AddBackward0>) compare to tensor([0.4907]) giving... 0.003889322280883789
tensor([-1.9220], grad_fn=<AddBackward0>) compare to tensor([-1.9218]) giving... -0.00017750263214111328
tensor([1.3538], grad_fn=<AddBackward0>) compare to tensor([1.3432]) giving... 0.010646343231201172
tensor([0.0480], grad_fn=<AddBackward0>) compare to tensor([0.0476]) giving... 0.00037752091884613037
tensor([-1.6398], grad_fn=<AddBackward0>) compare to tensor([-1.6396]) giving... -0.0001512765884399414
tensor([0.5234], grad_fn=<AddBackward0>) compare to tensor([0.5193]) giving... 0.004115760326385498
tensor([-0.9022], grad_fn=<AddBackward0>) compare to tensor([-0.9020]) giving... -0.00019174814224243164
tensor([-0.5319], grad_fn=<AddBackward0>) compare to tensor([-0.5317]) giving... -0.0002116560935974121
tensor([1.1710], grad_fn=<AddBackward0>) compare to tensor([1.1618]) giving

In [26]:
print(errs)

[0.0078119635581970215, -0.00032708048820495605, -0.00013880431652069092, 0.0038254857063293457, 0.0030863583087921143, 0.0027129948139190674, -0.0001703500747680664, -0.00026866793632507324, 0.018990039825439453, -5.7250261306762695e-05, 9.94289293885231e-05, 0.011719703674316406, -0.0001552104949951172, -0.00015032291412353516, -0.00020945072174072266, -0.00017881393432617188, -0.00020182132720947266, -0.00017940998077392578, 0.0035988986492156982, 0.0009844079613685608, 0.0037342607975006104, -0.00018206238746643066, 0.011926412582397461, -0.00021076202392578125, -0.0002561211585998535, 0.004953503608703613, 0.0023156404495239258, 0.004495799541473389, 0.003919720649719238, 0.00716322660446167, -0.00019121170043945312, -0.00021839141845703125, -0.0003916025161743164, 0.0003771185874938965, -0.00017070770263671875, -0.00018846988677978516, -0.00026404857635498047, 9.371060878038406e-05, -0.0004456862807273865, 2.2023916244506836e-05, 0.008160233497619629, 0.01615309715270996, -0.0001

In [36]:
errs_file = open('errs.txt', 'w+')
for value in errs:
    errs_file.write(str(value)+" ")
errs_file.close()

In [37]:
list_of_loss_file = open('list_of_loss.txt', 'w+')
for value in errs:
    list_of_loss_file.write(str(value)+" ")
list_of_loss_file.close()

In [38]:
model.save_model("ResNet + PDE model 1")