In [1]:
%matplotlib notebook
import torch
import matplotlib.pyplot as plt
import schemes as sc

# Data type and device setup

In [2]:
# All tensors below are created as single-precision floats on the CPU.
device = torch.device("cpu")
dtype = torch.float32

# Generate the reference signal $x(t) = \sin(2\pi f t)$ on $t \in [0, 10]$

In [3]:
# Reference signal: a sine of frequency `freq` (cycles per unit of t) sampled on
# 1001 evenly spaced points over [0, 10]. Created with the configured dtype/device
# so it matches the trainable tensors built later.
freq = 1
t_true = torch.linspace(0, 10, 1001, dtype=dtype, device=device)
x_true = torch.sin(2 * torch.pi * freq * t_true)

fig, axs = plt.subplots()
# matplotlib needs host memory; .cpu() is a no-op when device is the CPU.
axs.plot(t_true.cpu(), x_true.cpu())
axs.set(title="Reference signal", xlabel="t", ylabel="x");

<IPython.core.display.Javascript object>

[<matplotlib.lines.Line2D at 0x260bae73850>]

In [4]:
# Spacing between consecutive samples of the uniform grid t_true (0-d tensor, ~0.01
# for linspace(0, 10, 1001)); passed as the step argument to sc.rk4 in training.
dt = t_true.diff()[0]
# Third positional argument of sc.interp_linear in the training loop; its meaning
# is defined by the `schemes` module. TODO: document what Nc controls.
Nc = 100

<img src="./Figures/Model_1.png" alt="Model Outline" width="1000" height="1000">

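# Model

The model is linear and ignores time: it returns the matrix product of the current states `xx` with the weight matrix `w0`, i.e. $f(t, z) = z\,W_0$. It is the right-hand side handed to `sc.rk4` in the training loop.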

In [5]:
def model(t, xx, w0):
    """Linear, time-independent right-hand side f(t, x) = x @ w0.

    Args:
        t: Current time. Accepted only so the function fits the f(t, x)
            call signature used by the integrator; it is ignored.
        xx: 2-D tensor of states, one state per row (torch.mm requires 2-D).
        w0: 2-D weight matrix whose row count equals xx's column count.

    Returns:
        The matrix product ``xx @ w0`` (same semantics as ``torch.mm``).
    """
    product = xx.mm(w0)
    return product

# Initialization

In [6]:
# Seed PyTorch's global RNG so the random window sampling (torch.randint) in the
# training cell is reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Trainable parameters.
# τ: passed to sc.interp_linear in the training cell; presumably the delays/shifts
#    used to build the embedded signal z_true (TODO confirm in `schemes`).
#    The training cell pins τ[0] to 0 after every update.
τ = torch.tensor([0, 0.2], device=device, dtype=dtype, requires_grad=True)
# w0: weight matrix of the linear model (right-hand side z @ w0), started at zero.
# NOTE(review): its size must match the number of columns of z_true, which is
# presumably len(τ) — confirm against sc.interp_linear.
w0 = torch.zeros((2, 2), device=device, dtype=dtype, requires_grad=True)

# Plain gradient-descent hyper-parameters.
lr_tau = 1e-4    # step size for τ
lr_w0 = 1        # step size for w0
batch_size = 20  # number of random windows sampled per iteration
# Number of consecutive samples per window.
# NOTE(review): the training cell reassigns batch_time = 30, so this value is not
# what training uses. The original note "(x_true)-Nc" presumably referred to an
# upper bound such as len(x_true) - Nc — TODO confirm.
batch_time = 10

# Per-iteration training history.
loss_arr   = []
τ_arr      = []
τ_grad_arr = []

In [None]:
fig, axs = plt.subplots()

# Loop-invariant training settings (batch_time overrides the initialization cell).
batch_time = 30
n_iters = 2000
tau_warmup_iters = 300  # τ is only updated once kk > tau_warmup_iters


def fun(t, x):
    """RK4 right-hand side; looks up the global w0 at call time, so in-place updates are seen."""
    return model(t, x, w0)


for kk in range(n_iters):
    if kk % 100 == 0:
        print('lr_w0=', lr_w0)
        print('batch_time=', batch_time)

    # Re-interpolate every iteration: z_true depends on τ, and τ's gradient flows through it.
    z_true = sc.interp_linear(t_true, x_true, Nc, τ)

    # Sample batch_size windows of batch_time consecutive samples. Valid window starts are
    # 0 .. len(z_true) - batch_time inclusive; torch.randint's upper bound is exclusive.
    id_sel = torch.randint(0, z_true.shape[0] - batch_time + 1, (batch_size,))
    win_idx = torch.arange(batch_time)[:, None] + id_sel[None, :]  # (batch_time, batch_size)
    z_true_stack = z_true[win_idx]   # (batch_time, batch_size, ...) true windows
    t_true_stack = t_true[win_idx]   # (batch_time, batch_size) matching times

    # Roll the model forward from the first true sample of each window.
    # NOTE(review): t_true[i] is a global sample time, not each window's start time
    # (t_true_stack[i - 1]); this is harmless only because `model` ignores t.
    state_shape = z_true_stack.shape[1:]
    z_steps = [z_true_stack[0]]
    for i in range(1, batch_time):
        z_steps.append(sc.rk4(fun, t_true[i], z_steps[-1], dt).reshape(state_shape))
    z_pred = torch.stack(z_steps, dim=0)

    # Mean L1 trajectory error plus an L1 penalty on the model weights.
    loss = (z_true_stack - z_pred).abs().mean() + 1e-2 * w0.abs().sum()
    loss.backward()

    with torch.no_grad():
        # Copy before recording: .numpy() shares memory with the tensor, and τ is updated
        # in place below, so an uncopied entry would silently track the latest τ.
        τ_arr.append(τ.detach().numpy().copy())
        τ_grad_arr.append(τ.grad.detach().numpy().copy())

        # Plain gradient descent on w0.
        w0 -= lr_w0 * w0.grad
        # τ stays frozen for kk <= tau_warmup_iters; afterwards every 10th step is 10x larger.
        if kk > tau_warmup_iters:
            tau_step_scale = 10 if kk % 10 == 0 else 1
            τ -= lr_tau * τ.grad * tau_step_scale

        loss_arr.append(loss.item())
        # Pin the first entry of τ to zero so only the remaining entries are learned.
        τ[0] = 0

        if kk % 10 == 0:
            # Redraw: predicted window samples (markers) over the interpolated z_true (lines).
            axs.cla()
            t_win = t_true_stack.detach().numpy()
            z_win = z_pred.detach().numpy()
            for p_id in range(batch_size):
                axs.plot(t_win[:, p_id], z_win[:, p_id, 0], 'ro')
                axs.plot(t_win[:, p_id], z_win[:, p_id, 1], 'bo')

            axs.plot(t_true[0:len(z_true)], z_true[:, 0].detach().numpy(), 'r-')
            axs.plot(t_true[0:len(z_true)], z_true[:, 1].detach().numpy(), 'b-')
            axs.set(xlabel="t", title=f"iter {kk}")

            plt.show(block=False)
            fig.canvas.draw()
            plt.pause(0.0001)

        if kk % 25 == 0:
            print("iter=", kk)
            print("loss=", loss.detach().numpy())
            print("w0=", w0.detach().numpy())
            print("tau=", τ.detach().numpy())

        # Reset gradients so the next backward() starts from zero.
        w0.grad = None
        τ.grad = None

<IPython.core.display.Javascript object>

lr_w0= 1
batch_time= 30
iter= 0
loss= 0.5267721
w0= [[-0.01950833 -0.04459416]
 [ 0.02563957 -0.03573093]]
tau= [0.  0.2]
iter= 25
loss= 0.48870623
w0= [[-0.31826845 -0.8598776 ]
 [ 0.46749744 -0.38607228]]
tau= [0.  0.2]
iter= 50
loss= 0.45101094
w0= [[-0.512699   -1.6779048 ]
 [ 0.95200104 -0.5370561 ]]
tau= [0.  0.2]
iter= 75
loss= 0.40181217
w0= [[-0.5520624 -2.471312 ]
 [ 1.5137907 -0.600597 ]]
tau= [0.  0.2]
lr_w0= 1
batch_time= 30
iter= 100
loss= 0.39749643
w0= [[-0.55641264 -3.338529  ]
 [ 2.0452642  -0.51403195]]
tau= [0.  0.2]
iter= 125
loss= 0.33142382
w0= [[-0.43636283 -4.1446886 ]
 [ 2.690012   -0.30776572]]
tau= [0.  0.2]
iter= 150
loss= 0.2730572
w0= [[-0.30307317 -4.93682   ]
 [ 3.3777556  -0.00751369]]
tau= [0.  0.2]
iter= 175
loss= 0.25300813
w0= [[-0.24057658 -5.6418095 ]
 [ 3.9986584   0.04731238]]
tau= [0.  0.2]
lr_w0= 1
batch_time= 30
iter= 200
loss= 0.221997
w0= [[-0.43189684 -5.8514752 ]
 [ 4.435146    0.27093416]]
tau= [0.  0.2]
iter= 225
loss= 0.23145436
w0= [