In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

In [2]:
# Work in float64 throughout: `batch` builds its tensors from float64 numpy
# arrays via torch.from_numpy, so the network's parameters must match.
torch.set_default_dtype(torch.double)

In [3]:
# Load the saved arrays: `x` (presumably a trajectory, per the file name) is
# what the model trains on; `potential` is loaded alongside it.
x, potential = np.load('saves/try0_traj.npy'), np.load('saves/try0_pot.npy')

In [4]:
np.shape(x)  # (rows, columns) of the training array

(1000, 50000)

In [5]:
np.shape(potential)  # size of the saved potential array

(5001,)

In [6]:
# nt: number of rows of x; N: number of columns (the axis windows are drawn along).
nt, N = x.shape[0], x.shape[1]

In [7]:
batch_size = 50  # default number of samples drawn per call to `batch`


def batch(x, size=None, window=3):
    """Draw a random batch of (window, next-column) training pairs.

    Each sample picks a random start column ``k`` and returns the ``window``
    consecutive columns ``x[:, k:k+window]`` as input and the following
    column ``x[:, k+window]`` as the prediction target.

    Parameters
    ----------
    x : np.ndarray
        2-D array of shape ``(nt, N)``; windows are taken along axis 1
        (presumably time steps -- confirm against how the data was saved).
    size : int, optional
        Number of samples. Defaults to the module-level ``batch_size``,
        read at call time.
    window : int, default 3
        Number of consecutive input columns per sample (3 matches ``TNet``).

    Returns
    -------
    (torch.Tensor, torch.Tensor)
        float64 tensors of shape ``(size, nt, window)`` and ``(size, nt, 1)``.

    Raises
    ------
    ValueError
        If ``x`` has ``window`` or fewer columns (no room for a target).
    """
    if size is None:
        size = batch_size
    # Take dimensions from the argument itself, not from globals, so the
    # function works for any array (e.g. a held-out trajectory).
    n_cols = x.shape[1]
    if n_cols <= window:
        raise ValueError(
            f"need more than {window} columns to form a window plus target, "
            f"got {n_cols}"
        )
    # randint's upper bound is exclusive: starts run 0 .. n_cols-window-1, so
    # the target column k + window is always inside x.
    starts = np.random.randint(0, n_cols - window, size=size)
    cols = starts[:, np.newaxis] + np.arange(window)                 # (size, window)
    inputs = np.asarray(x[:, cols], dtype=np.float64)                # (nt, size, window)
    targets = np.asarray(x[:, starts + window], dtype=np.float64)    # (nt, size)
    inputs = np.ascontiguousarray(inputs.transpose(1, 0, 2))         # (size, nt, window)
    targets = np.ascontiguousarray(targets.T[:, :, np.newaxis])      # (size, nt, 1)
    return torch.from_numpy(inputs), torch.from_numpy(targets)

In [8]:
batch(x)[0].size()  # input batch: (batch_size, nt, 3)

torch.Size([50, 1000, 3])

In [9]:
class TNet(nn.Module):
    """Convolutional next-column predictor over a 3-column window.

    Expects input of shape ``(B, 1, nt, 3)`` (what ``batch`` returns after
    ``unsqueeze(1)`` in the training cell) and returns ``(B, nt, 1)``.

    ``con1`` uses a (3, 3) kernel, which collapses the 3-wide window axis to
    width 1. The following 1-D convolutions mix information along the
    ``nt`` axis. Each unpadded convolution shortens that axis by 2, and
    ``pad2`` restores it by replicating one element at each end, so the
    output keeps length ``nt``.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): pad1 is not used in forward(). It is kept so the
        # module's attributes/repr are unchanged; it has no parameters, so it
        # does not affect state_dict or saved checkpoints.
        self.pad1 = nn.ReplicationPad2d((1, 1, 0, 0))
        self.pad2 = nn.ReplicationPad1d(1)
        self.con1 = nn.Conv2d(1, 9, kernel_size=(3, 3))
        self.con2 = nn.Conv1d(9, 18, kernel_size=3)
        self.con3 = nn.Conv1d(18, 27, kernel_size=3)
        self.nonlin = nn.SELU()
        self.fct = nn.Linear(27, 1)

    def forward(self, x):
        """Map a ``(B, 1, nt, 3)`` batch to ``(B, nt, 1)`` predictions."""
        x = self.con1(x)  # (B, 9, nt-2, 1)
        # Remove only the collapsed window axis. A bare squeeze() would also
        # drop the batch axis when B == 1, and the transpose below would then
        # fail.
        x = x.squeeze(-1)                          # (B, 9, nt-2)
        x = self.nonlin(self.pad2(x))              # (B, 9, nt)
        x = self.nonlin(self.pad2(self.con2(x)))   # (B, 18, nt)
        x = self.nonlin(self.pad2(self.con3(x)))   # (B, 27, nt)
        x = x.transpose(1, 2)  # (B, nt, 27): channels last for the Linear layer
        return self.fct(x)     # (B, nt, 1)

In [10]:
# Use the GPU when one is available instead of failing outright on CPU-only hosts.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = TNet().to(device)
# To resume from a checkpoint (map_location lets a GPU-saved file load on CPU):
# net.load_state_dict(torch.load('saves/tnet', map_location=device))
criterion = nn.MSELoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)

In [11]:
N_STEPS = 20_000           # optimisation steps, one random batch each
LOG_EVERY = 1_000          # print the training loss this often
CHECKPOINT_EVERY = 10_000  # save weights (and assign the lr) this often
DECAYED_LR = 0.0002 / 2    # lr assigned at every checkpoint step

# Send batches to wherever the model lives rather than hard-coding .cuda().
device = next(net.parameters()).device

for epoch in range(N_STEPS + 1):
    optimizer.zero_grad()

    data, true_next = batch(x)
    data = data.unsqueeze(1).to(device)  # add the Conv2d channel axis: (B, 1, nt, 3)
    true_next = true_next.to(device)

    output = net(data)
    loss = criterion(output, true_next)
    loss.backward()
    optimizer.step()

    if epoch % LOG_EVERY == 0:
        print(f"Epoch {epoch}: {loss.item():.6f}")

    if epoch % CHECKPOINT_EVERY == 0:
        # NOTE(review): this also fires at epoch 0, so right after the first
        # step the lr drops from the optimizer's initial 0.001 to 1e-4. Later
        # triggers assign the same value again, so there is no further decay.
        # The logged losses below were produced with this schedule; confirm
        # it is intended before changing it.
        for group in optimizer.param_groups:
            group['lr'] = DECAYED_LR
        torch.save(net.state_dict(), "saves/tnet")

Epoch 0: 159.811501
Epoch 1000: 0.357593
Epoch 2000: 0.031242
Epoch 3000: 0.021961
Epoch 4000: 0.011433
Epoch 5000: 0.006985
Epoch 6000: 0.006059
Epoch 7000: 0.003803
Epoch 8000: 0.002233
Epoch 9000: 0.004266
Epoch 10000: 0.001601
Epoch 11000: 0.001678
Epoch 12000: 0.001016
Epoch 13000: 0.000947
Epoch 14000: 0.000789
Epoch 15000: 0.000504
Epoch 16000: 0.000211
Epoch 17000: 0.000247
Epoch 18000: 0.000210
Epoch 19000: 0.000214
Epoch 20000: 0.000144
