In [1]:
%matplotlib notebook
import torch
import matplotlib.pyplot as plt
import schemes as sc

# Data type and device setup

In [2]:
# Every learnable tensor in this notebook is created on the CPU in 32-bit float
# (torch.float is an alias of torch.float32).
device = torch.device("cpu")
dtype = torch.float32

# Generate signal

In [3]:
# Ground-truth signal: a pure sine of frequency `freq` sampled on a uniform grid.
freq = 1            # Hz
T_END = 10.0        # signal duration; the grid covers [0, T_END]
N_SAMPLES = 1001    # number of grid points -> spacing T_END / (N_SAMPLES - 1) = 0.01

# Use the notebook-wide dtype/device so the signal matches the learnable tensors.
t_true = torch.linspace(0, T_END, N_SAMPLES, dtype=dtype, device=device)
x_true = torch.sin(2 * torch.pi * freq * t_true)

fig, axs = plt.subplots()
axs.plot(t_true, x_true)
# Trailing `;` suppresses the `[<matplotlib.lines.Line2D ...>]` repr in the cell output.
axs.set(title=f"Training signal: sin(2π·{freq}·t)", xlabel="t", ylabel="x(t)");

<IPython.core.display.Javascript object>

[<matplotlib.lines.Line2D at 0x10f21eb00>]

In [4]:
# Integration step for RK4: the spacing of the uniform time grid
# (a 0-d tensor with the same dtype as t_true).
dt = torch.diff(t_true[:2])[0]

# Third argument to sc.interp_linear. Its meaning is defined inside `schemes`;
# presumably a window/offset size in samples — TODO confirm against schemes.py.
Nc = 100

<img src="./Figures/Model_1.png" alt="Model Outline" width="1000" height="1000">

# Model: linear vector field $f(t, z) = z\,W_0$ (independent of $t$)

In [5]:
def model(t, xx, w0):
    """Linear, time-invariant vector field ``f(t, z) = z @ w0``.

    Args:
        t: Time. Ignored (the dynamics are autonomous); it is kept so the function
            fits the ``fun(t, x)`` calling convention used with ``sc.rk4``.
        xx: State tensor of shape ``(..., n)``. A 2-D ``(batch, n)`` input gives
            exactly the same result as the former ``torch.mm`` implementation;
            1-D and higher-rank (batched) inputs are now supported as well.
        w0: Weight matrix of shape ``(n, m)``.

    Returns:
        ``xx @ w0`` with shape ``(..., m)``; differentiable w.r.t. both ``xx`` and ``w0``.
    """
    # torch.matmul == torch.mm for 2-D operands, and broadcasts over leading dims otherwise.
    return torch.matmul(xx, w0)

# Initialization

In [6]:
# Seed the global torch RNG so the random window starts drawn by torch.randint
# in the training loop are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Learnable parameters.
# τ: two values passed to sc.interp_linear; trained by gradient descent after a warm-up.
τ = torch.tensor([0, 0.2], device=device, dtype=dtype, requires_grad=True)
# w0: 2x2 weight matrix of the linear vector field model(t, z, w0) = z @ w0; starts at zero.
w0 = torch.zeros((2, 2), device=device, dtype=dtype, requires_grad=True)

# Hyperparameters for plain gradient descent (manual in-place updates, no optimizer object).
lr_tau = 1e-4      # step size for τ
lr_w0 = 1          # step size for w0
batch_size = 20    # number of trajectory windows sampled per iteration
batch_time = 10    # initial rollout window length (samples); the training loop overrides it
# Exponent grids for learning-rate schedule experiments; not used by the active training loop.
lr_pow_w0 = torch.linspace(0, -2, 100)
lr_pow_tau = torch.linspace(0, -4, 100)

# Per-iteration training history.
loss_arr = []
τ_arr = []
τ_grad_arr = []

In [7]:
fig, axs = plt.subplots()


def fun(t, x):
    """RK4 right-hand side: the linear model evaluated with the current weights ``w0``.

    ``w0`` is updated in place (inside ``torch.no_grad()``), so this function always
    sees the latest values; defining it once is equivalent to rebuilding a lambda
    at every integration step.
    """
    return model(t, x, w0)


for kk in range(2000):
    # Optional learning-rate schedules (disabled):
    # lr_w0 = 1e-4 + 1e-1 * torch.tanh(torch.tensor(kk/500))
    # lr_tau = 1e-1 * torch.tensor(10).pow(lr_pow_tau[kk//100])

    # Curriculum: lengthen the rollout window by 10 steps every 100 iterations.
    batch_time = 10 + (kk // 100) * 10

    if kk % 100 == 0:
        print('lr_w0=', lr_w0)
        print('batch_time=', batch_time)

    # Recomputed every iteration because it depends on τ, which is trained below.
    z_true = sc.interp_linear(t_true, x_true, Nc, τ)

    # Sample `batch_size` windows [s, s + batch_time). randint's upper bound is
    # exclusive, so `+ 1` keeps the last complete window eligible.
    n_starts = z_true.shape[0] - batch_time + 1
    if n_starts < 1:
        raise ValueError(
            f"batch_time={batch_time} exceeds the interpolated trajectory length {z_true.shape[0]}"
        )
    id_sel = torch.randint(0, n_starts, (batch_size,))
    idx = id_sel.unsqueeze(0) + torch.arange(batch_time).unsqueeze(1)  # (batch_time, batch_size)
    z_true_stack = z_true[idx]  # (batch_time, batch_size, z_true.shape[1])
    t_true_stack = t_true[idx]  # (batch_time, batch_size)

    # Roll the model forward with RK4 from the first true state of every window.
    state_shape = z_true_stack.shape[1:]
    z_steps = [z_true_stack[0]]
    for i in range(1, batch_time):
        # NOTE(review): a single scalar time t_true[i] is passed for all windows rather than
        # each window's own time; this is harmless only because model() ignores t.
        z_steps.append(sc.rk4(fun, t_true[i], z_steps[-1], dt).reshape(state_shape))
    # Collect the steps and stack once (repeated torch.cat would copy quadratically).
    z_pred = torch.stack(z_steps, dim=0)

    # Mean L1 trajectory error plus an L1 sparsity penalty on the weights.
    loss = torch.abs(z_true_stack - z_pred).mean() + 1e-2 * torch.abs(w0).sum()
    loss.backward()

    with torch.no_grad():
        # .clone() is required: .numpy() shares memory with the tensor, and τ is updated
        # in place below, so un-cloned entries would all alias τ's final value.
        τ_arr.append(τ.detach().clone().numpy())
        τ_grad_arr.append(τ.grad.detach().clone().numpy())

        w0 -= lr_w0 * w0.grad
        if kk > 300:
            # τ is frozen for the first 300 iterations while w0 warms up; afterwards
            # every 10th iteration takes a 10x larger τ step.
            tau_step_scale = 10 if kk % 10 == 0 else 1
            τ -= lr_tau * τ.grad * tau_step_scale

        loss_arr.append(loss.item())

        if kk % 10 == 0:
            # Visualize: predicted window points (markers) over the full interpolated trajectory.
            axs.cla()
            t_np = t_true_stack.detach().numpy()
            z_np = z_pred.detach().numpy()
            # matplotlib draws one series per column for 2-D x/y, i.e. one per sampled window.
            axs.plot(t_np, z_np[:, :, 0], 'ro')
            axs.plot(t_np, z_np[:, :, 1], 'bo')

            axs.plot(t_true[:len(z_true)], z_true[:, 0].detach().numpy(), 'r-')
            axs.plot(t_true[:len(z_true)], z_true[:, 1].detach().numpy(), 'b-')
            axs.set(xlabel="t", ylabel="z")

            plt.show(block=False)
            fig.canvas.draw()
            plt.pause(0.0001)

        if kk % 25 == 0:
            print("iter=", kk)
            print("loss=", loss.detach().numpy())
            print("w0=", w0.detach().numpy())
            print("tau=", τ.detach().numpy())

        # Manual zero_grad: the next backward() allocates fresh gradient tensors.
        w0.grad = None
        τ.grad = None

<IPython.core.display.Javascript object>

lr_w0= 1
batch_time= 10
iter= 0
loss= 0.18492426
w0= [[-0.0060216  -0.01458016]
 [ 0.01112439 -0.0006384 ]]
tau= [0.  0.2]
iter= 25
loss= 0.1789139
w0= [[-0.00806733 -0.12953576]
 [ 0.06507432 -0.01174788]]
tau= [0.  0.2]
iter= 50
loss= 0.18466237
w0= [[-0.00253714 -0.2409582 ]
 [ 0.11347534 -0.00577792]]
tau= [0.  0.2]
iter= 75
loss= 0.18302855
w0= [[-0.00730244 -0.3569132 ]
 [ 0.1690358   0.00731208]]
tau= [0.  0.2]
lr_w0= 1
batch_time= 20
iter= 100
loss= 0.35173088
w0= [[-0.01234107 -0.48152155]
 [ 0.22860104 -0.0055436 ]]
tau= [0.  0.2]
iter= 125
loss= 0.34976333
w0= [[-0.03722303 -0.987465  ]
 [ 0.54035294 -0.04228577]]
tau= [0.  0.2]
iter= 150
loss= 0.35077345
w0= [[-0.06042376 -1.531937  ]
 [ 0.8239603  -0.03157965]]
tau= [0.  0.2]
iter= 175
loss= 0.29204857
w0= [[-0.04241876 -2.072119  ]
 [ 1.1480131  -0.02017419]]
tau= [0.  0.2]
lr_w0= 1
batch_time= 30
iter= 200
loss= 0.3994241
w0= [[-0.0051326  -2.6201746 ]
 [ 1.5095352  -0.01669903]]
tau= [0.  0.2]
iter= 225
loss= 0.3593861


KeyboardInterrupt: 