In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import scipy.interpolate as interpol
import matplotlib.pyplot as plt

In [2]:
# Make float64 the default floating-point dtype so that layers created later
# accept the float64 NumPy arrays passed in through torch.from_numpy without
# an explicit cast.
torch.set_default_dtype(torch.float64)

In [3]:
# Spatial domain: J equally spaced nodes covering the closed interval [xi, xf].
xi, xf = -20., 20.  # left / right endpoint
J = 5001            # number of space nodes
Xgrid = np.linspace(xi, xf, num=J)

In [4]:
# Load the training array from disk; the path is relative to the working
# directory the notebook is run from.
x = np.load('saves/try0.npy')

In [5]:
# Sanity-check the dimensions of the loaded array (three axes are expected by
# the indexing used further down).
x.shape

(2, 1000, 50000)

In [6]:
# nt is the length of axis 1 of x, N the length of axis 2.
nt, N = x.shape[1], x.shape[2]

In [7]:
batch_size = 50  # default number of windows drawn per mini-batch


def batch(x, size=None):
    """Draw a random mini-batch of sliding windows from ``x``.

    Every sample is a window of three consecutive slices taken along the last
    axis of ``x``. Its target is channel 0 of the slice that immediately
    follows the window. All sizes are derived from ``x`` itself, so the
    function works for any array of the right rank, not only the one the
    notebook happened to load.

    Parameters
    ----------
    x : array_like, shape (channels, nt, N)
        Source array. ``N`` must be at least 4 so that a window and its
        successor both fit.
    size : int, optional
        Number of windows to draw. Defaults to the module-level
        ``batch_size``.

    Returns
    -------
    batchy : numpy.ndarray, shape (size, channels, nt, 3)
        The sampled windows.
    true_next : numpy.ndarray, shape (size, nt, 1)
        Channel 0 of the slice following each window.

    Raises
    ------
    ValueError
        If ``x`` is not three-dimensional or its last axis is shorter than 4.
    """
    x = np.asarray(x)
    if x.ndim != 3:
        raise ValueError("x must be 3-dimensional, got shape %s" % (x.shape,))
    n_channels, n_rows, n_cols = x.shape
    if n_cols < 4:
        raise ValueError("last axis of x must have length >= 4, got %d" % n_cols)
    if size is None:
        size = batch_size

    batchy = np.zeros([size, n_channels, n_rows, 3])
    true_next = np.zeros([size, n_rows, 1])
    for i in range(size):
        # randint's upper bound is exclusive, so k + 3 <= n_cols - 1 always holds.
        k = np.random.randint(0, n_cols - 3)
        batchy[i] = x[:, :, k:k + 3]
        true_next[i, :, 0] = x[0, :, k + 3]
    return batchy, true_next

In [8]:
# Draw one sample batch to inspect it; baba is the (batchy, true_next) tuple
# returned by batch().
baba = batch(x)

In [9]:
# Shape of the target array (second element of the tuple returned by batch()).
baba[1].shape

(50, 1000, 1)

In [10]:
class TNet(nn.Module):
    """Small convolutional network mapping a 3-wide window to one value per row.

    Expected input: ``(batch, 2, nt, 3)``. The 3x3 ``Conv2d`` collapses the
    width-3 axis to 1 and shortens the ``nt`` axis by 2; that axis is restored
    by replication padding after every convolution. Two further ``Conv1d``
    layers follow, and a linear layer applied per position reduces the 27
    channels to one, giving an output of shape ``(batch, nt, 1)``.
    """

    def __init__(self):
        super().__init__()
        # NOTE: pad1 is not used in forward(); it is kept so the attribute
        # remains available to any code that refers to it.
        self.pad1 = nn.ReplicationPad2d((1, 1, 0, 0))
        self.pad2 = nn.ReplicationPad1d(1)
        self.con1 = nn.Conv2d(2, 9, kernel_size = (3,3))
        self.con2 = nn.Conv1d(9, 18, kernel_size = 3)
        self.con3 = nn.Conv1d(18, 27, kernel_size = 3)
        self.nonlin = nn.SELU()
        self.fct = nn.Linear(27, 1)


    def forward(self, x):
        """Run the network on ``x`` of shape ``(batch, 2, nt, 3)``.

        Returns a tensor of shape ``(batch, nt, 1)``.
        """
        x = self.con1(x)
        # Drop only the collapsed width axis. A bare squeeze() would also
        # remove the batch axis when batch == 1 and break transpose(1, 2) below.
        x = x.squeeze(-1)
        x = self.pad2(x)
        x = self.nonlin(x)
        x = self.con2(x)
        x = self.pad2(x)
        x = self.nonlin(x)
        x = self.con3(x)
        x = self.pad2(x)
        x = self.nonlin(x)
        # Move channels last so the linear layer acts on each position.
        x = x.transpose(1,2)
        x = self.fct(x)
        return x

In [11]:
# Build the model on the GPU, then restore the weights saved by an earlier run.
net = TNet().cuda()
net.load_state_dict(torch.load('saves/tnet'))

# Mean-squared-error objective, minimised with Adam.
criterion = nn.MSELoss()
optimizer = optim.Adam(net.parameters(), lr=0.0002)

In [12]:
# Settings for this training run.
n_epochs = 50_000
log_every = 1000       # print the loss this often
checkpoint_every = 10000  # halve the learning rate and save weights this often

# Send each batch to whatever device the model's parameters live on.
device = next(net.parameters()).device

for epoch in range(0, n_epochs + 1):
    optimizer.zero_grad()

    # Fresh random mini-batch of windows and their targets.
    data, true_next = batch(x)
    data = torch.from_numpy(data).to(device)
    true_next = torch.from_numpy(true_next).to(device)

    output = net(data)

    loss = criterion(output, true_next)
    loss.backward()
    optimizer.step()

    if epoch % log_every == 0:
        print(f"Epoch {epoch}: {loss.item():.6f}")

    if epoch % checkpoint_every == 0:
        # Halve the *current* learning rate at every checkpoint after the
        # first. Re-assigning a fixed value here would leave the rate constant
        # for the whole run instead of decaying it.
        if epoch > 0:
            for g in optimizer.param_groups:
                g['lr'] /= 2
        torch.save(net.state_dict(), "saves/tnet")

Epoch 0: 0.000171
Epoch 1000: 0.000136
Epoch 2000: 0.000142
Epoch 3000: 0.000163
Epoch 4000: 0.000141
Epoch 5000: 0.000142
Epoch 6000: 0.000145
Epoch 7000: 0.000146
Epoch 8000: 0.000111
Epoch 9000: 0.000117
Epoch 10000: 0.000090
Epoch 11000: 0.000101
Epoch 12000: 0.000095
Epoch 13000: 0.000103
Epoch 14000: 0.000080
Epoch 15000: 0.000119
Epoch 16000: 0.000084
Epoch 17000: 0.000187
Epoch 18000: 0.000292
Epoch 19000: 0.000089
Epoch 20000: 0.000073
Epoch 21000: 0.000086
Epoch 22000: 0.000078
Epoch 23000: 0.000072
Epoch 24000: 0.000078
Epoch 25000: 0.000066
Epoch 26000: 0.000072
Epoch 27000: 0.000053
Epoch 28000: 0.000036
Epoch 29000: 0.000061
Epoch 30000: 0.000071
Epoch 31000: 0.000051
Epoch 32000: 0.000108
Epoch 33000: 0.000051
Epoch 34000: 0.000058
Epoch 35000: 0.000053
Epoch 36000: 0.000105
Epoch 37000: 0.000051
Epoch 38000: 0.000047
Epoch 39000: 0.000052
Epoch 40000: 0.000035
Epoch 41000: 0.000049
Epoch 42000: 0.000043
Epoch 43000: 0.000056
Epoch 44000: 0.000040
Epoch 45000: 0.000051
E