In [None]:
import os
import argparse
import time
import numpy as np
import matplotlib.pyplot as plt
from os import makedirs
import torch
import torch.nn as nn
import torch.optim as optim
from torchdiffeq import odeint_adjoint as odeint

# --- Experiment configuration ------------------------------------------------
# NOTE(review): `parser` is created but never given arguments or parsed in the
# visible cells; it is kept in case another cell refers to it.
parser = argparse.ArgumentParser('ODE demo')
method = 'dopri5'   # name of the ODE solver scheme
data_size = 1000    # number of time points in the reference trajectory
batch_time = 2      # consecutive time points per training segment
batch_size = 30     # training segments per mini-batch
niters = 2000       # number of optimisation steps
test_freq = 20      # evaluate and plot every `test_freq` iterations
viz = True          # NOTE(review): not read by any of the visible cells

# Seed the NumPy and torch generators so that mini-batch sampling and weight
# initialisation are repeatable under "Restart Kernel -> Run All".
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)

In [None]:
# Coefficients of the Lorenz equations evaluated in `Lambda.forward`.
# (Previously each of these was assigned twice with the same value.)
s = 10.
r = 28.
b = 8. / 3

# Initial state (x, y, z) of the reference trajectory.
ty1 = 0.
ty2 = 1.
ty3 = 1.05
true_y0 = torch.tensor([ty1, ty2, ty3]).to(device)

# Evenly spaced time grid on [0, 30] with `data_size` points.
t = torch.linspace(0., 30., data_size).to(device)
class Lambda(nn.Module):
    """Right-hand side of the Lorenz system, used as the ground-truth dynamics.

    ``forward`` takes ``(t, state)`` so an instance can be handed directly to
    ``odeint``.
    """

    def forward(self, t, xyz, s=s, r=r, b=b):
        """Return the time derivative of ``xyz`` under the Lorenz equations.

        Args:
            t: Current time. Not used by the equations; accepted because the
                solver passes it.
            xyz: Tensor whose last dimension holds ``(x, y, z)``. Any number of
                leading (batch) dimensions is accepted.
            s, r, b: Lorenz coefficients. The defaults are the module-level
                values captured when the class is defined.

        Returns:
            Tensor of the same shape as ``xyz`` holding ``(dx/dt, dy/dt, dz/dt)``
            along the last dimension, on the same device as ``xyz``.
        """
        # Slices (not integer indices) keep the trailing dimension, so the
        # three components can be concatenated back together below.
        x, y, z = xyz[..., :1], xyz[..., 1:2], xyz[..., 2:]
        x_dot = s * (y - x)
        y_dot = r * x - y - x * z
        z_dot = x * y - b * z
        # torch.cat (instead of re-wrapping the pieces with torch.tensor) keeps
        # the autograd graph intact and works for batched states as well.
        return torch.cat([x_dot, y_dot, z_dot], dim=-1)

def to_np(x):
    """Convert a torch tensor to a NumPy array.

    The tensor is first detached from the autograd graph and copied to host
    memory, so this works for tensors that require grad or live on a GPU.
    """
    host_tensor = x.detach().cpu()
    return host_tensor.numpy()

# Integrate the true dynamics once to obtain the reference trajectory.
# No gradients are needed for the ground truth.
with torch.no_grad():
    true_y = odeint(Lambda(), true_y0, t, method=method)

# Grid of (x, y) points at which the arrows of the phase portrait are drawn.
x0 = np.linspace(-20, 20, 100)
x1 = np.linspace(-20, 20, 100)
X0, X1 = np.meshgrid(x0, x1)

# x- and y-components of the trajectory tangent vector at every grid point.
dX0 = np.zeros(X0.shape)
dX1 = np.zeros(X1.shape)
shape1, shape2 = X1.shape

# Build the dynamics module once rather than once per grid point, and skip
# autograd bookkeeping: only the numeric values are needed for plotting.
lorenz_field = Lambda()
with torch.no_grad():
    for indexShape1 in range(shape1):
        for indexShape2 in range(shape2):
            # The field is sampled in the z = 0 plane.
            grid_point = torch.tensor(
                [X0[indexShape1, indexShape2], X1[indexShape1, indexShape2], 0],
                dtype=torch.float,
            ).to(device)
            dxdtAtX = lorenz_field(0, grid_point)
            # .item() yields plain Python floats, so the assignment into the
            # NumPy arrays also works when the tensor lives on the GPU.
            dX0[indexShape1, indexShape2] = dxdtAtX[0].item()
            dX1[indexShape1, indexShape2] = dxdtAtX[1].item()

# Convert the trajectory once instead of once per plotted column.
true_y_np = to_np(true_y)

plt.figure(figsize=(8, 8))
# Vector field with the reference trajectory (green) drawn on top.
plt.quiver(X0, X1, dX0, dX1, color='b')
plt.xlim(-20, 20)
plt.ylim(-20, 20)
plt.title('Phase Portrait')
plt.xlabel('x1')
plt.ylabel('x2')
plt.tick_params(axis='both', which='major')
plt.plot(true_y_np[:, 0], true_y_np[:, 1], 'g-')

In [4]:
class ODEFunc(nn.Module):
    """Neural network that models the vector field: ``dy/dt = net(y)``."""

    def __init__(self):
        super().__init__()

        # Fully connected 3 -> 64 -> 64 -> 64 -> 3 stack with a LeakyReLU
        # between consecutive linear layers (none after the output layer).
        widths = [3, 64, 64, 64, 3]
        layers = [nn.Linear(widths[0], widths[1])]
        for fan_in, fan_out in zip(widths[1:-1], widths[2:]):
            layers += [nn.LeakyReLU(), nn.Linear(fan_in, fan_out)]
        self.net = nn.Sequential(*layers)

        # Override the default initialisation of every linear layer:
        # weights ~ N(0, 0.1^2), biases exactly zero.
        for layer in self.net:
            if not isinstance(layer, nn.Linear):
                continue
            nn.init.normal_(layer.weight, mean=0, std=0.1)
            nn.init.constant_(layer.bias, val=0)

    def forward(self, t, y):
        """Evaluate the network at state ``y``; the time argument ``t`` is ignored."""
        return self.net(y)


def get_batch():
    """Sample a random mini-batch of short segments from the reference trajectory.

    ``batch_size`` distinct start indices are drawn (without replacement) from
    ``[0, data_size - batch_time)``.

    Returns:
        Tuple ``(batch_y0, batch_t, batch_y)``, all moved to ``device``:
        batch_y0 -- states at the sampled start indices, (M, D);
        batch_t  -- the first ``batch_time`` entries of ``t``, (T);
        batch_y  -- true states for each of the ``batch_time`` steps after
                    every start index, stacked on a new leading axis, (T, M, D).
    """
    candidate_starts = np.arange(data_size - batch_time, dtype=np.int64)
    start_idx = torch.from_numpy(
        np.random.choice(candidate_starts, batch_size, replace=False)
    )

    segment_steps = [true_y[start_idx + offset] for offset in range(batch_time)]

    batch_y0 = true_y[start_idx].to(device)
    batch_t = t[:batch_time].to(device)
    batch_y = torch.stack(segment_steps, dim=0).to(device)
    return batch_y0, batch_t, batch_y


# Output directory for the saved figures. Its name encodes the iteration count,
# the Lorenz coefficients and the initial state of this run.
folder = f'lorenz/lorenz_niters{niters}_s{s}_r{r}_b{b}_{ty1}_{ty2}_{ty3}'
os.makedirs(folder, exist_ok=True)

def visualize(true_y, pred_y, odefunc, itr):
    """Plot true vs. predicted trajectories and save the figure to ``folder``.

    Two panels are drawn: columns 0 and 1 of each trajectory against the global
    time grid ``t``, and column 0 against column 1 (phase portrait).

    Args:
        true_y: Reference trajectory tensor; columns 0 and 1 are plotted.
        pred_y: Predicted trajectory tensor with the same layout.
        odefunc: Learned dynamics module. Currently unused; kept so existing
            callers do not need to change.
        itr: Integer used (zero-padded to three digits) as the output file name.
    """
    # Convert each tensor once instead of once per plotted series.
    t_np = t.detach().cpu().numpy()
    true_np = true_y.detach().cpu().numpy()
    pred_np = pred_y.detach().cpu().numpy()

    fig = plt.figure(figsize=(12, 4), facecolor='white')
    # NOTE: slots 1 and 2 of a 1x3 grid are used. The third slot stays empty;
    # a learned-vector-field panel that used to live there is disabled.
    ax_traj = fig.add_subplot(131, frameon=False)
    ax_phase = fig.add_subplot(132, frameon=False)

    ax_traj.set_title('Trajectories')
    ax_traj.set_xlabel('t')
    ax_traj.set_ylabel('x,y')
    ax_traj.plot(t_np, true_np[:, 0], t_np, true_np[:, 1], 'g-')
    ax_traj.plot(t_np, pred_np[:, 0], '--', t_np, pred_np[:, 1], 'b--')
    ax_traj.set_xlim(t_np.min(), t_np.max())
    ax_traj.set_ylim(-20, 20)

    ax_phase.set_title('Phase Portrait')
    ax_phase.set_xlabel('x')
    ax_phase.set_ylabel('y')
    ax_phase.plot(true_np[:, 0], true_np[:, 1], 'g-')
    ax_phase.plot(pred_np[:, 0], pred_np[:, 1], 'b--')
    ax_phase.set_xlim(-20, 20)
    ax_phase.set_ylim(-20, 20)

    fig.tight_layout()
    fig.savefig('{}/{:03d}'.format(folder, itr))
    plt.draw()
    plt.pause(0.001)
    # Release the figure: this function is called repeatedly from the training
    # loop and would otherwise accumulate one open figure per call.
    plt.close(fig)

class RunningAverageMeter(object):
    """Keeps the latest value of a scalar and its exponential moving average."""

    def __init__(self, momentum=0.99):
        # Weight given to the previous average on every update.
        self.momentum = momentum
        self.reset()

    def reset(self):
        """Clear the history: no latest value, average back to 0."""
        self.val, self.avg = None, 0

    def update(self, val):
        """Record ``val``.

        The first value after a reset becomes the average as-is; later values
        are blended in as ``avg * momentum + val * (1 - momentum)``.
        """
        first_sample = self.val is None
        self.avg = (
            val
            if first_sample
            else self.avg * self.momentum + val * (1 - self.momentum)
        )
        self.val = val

In [None]:
# Running index of the evaluation figures; passed to `visualize` as `itr`.
ii = 0

# Model to train and its optimiser.
func = ODEFunc().to(device)
optimizer = optim.RMSprop(func.parameters(), lr=1e-3)

end = time.time()

# Smoothed per-iteration wall-clock time and training loss.
time_meter = RunningAverageMeter(0.97)
loss_meter = RunningAverageMeter(0.97)

for itr in range(1, niters + 1):
    # ---- one optimisation step on a random mini-batch of short segments ----
    optimizer.zero_grad()
    batch_y0, batch_t, batch_y = get_batch()
    pred_y = odeint(func, batch_y0, batch_t).to(device)
    loss = (pred_y - batch_y).abs().mean()  # mean absolute error
    loss.backward()
    optimizer.step()

    time_meter.update(time.time() - end)
    loss_meter.update(loss.item())

    # ---- periodic evaluation on the full reference trajectory ----
    is_eval_step = (itr % test_freq == 0)
    if is_eval_step:
        with torch.no_grad():
            # Roll the learned dynamics out from the true initial state over
            # the whole time grid and compare with the reference.
            pred_y = odeint(func, true_y0, t)
            loss = (pred_y - true_y).abs().mean()
            print('Iter {:04d} | Total Loss {:.6f}'.format(itr, loss.item()))
            visualize(true_y, pred_y, func, ii)
            ii += 1

    end = time.time()

In [6]:
import glob
import contextlib
from PIL import Image

# Glob pattern for the PNG frames in the run folder, and the GIF to build from them.
fp_in = f'./{folder}/*.png'
fp_out = f'./{folder}/results.gif'

# Resolve the frame list up front so that an empty match is reported with a
# clear message instead of a bare StopIteration from `next()` below.
frame_paths = sorted(glob.glob(fp_in))
if not frame_paths:
    raise FileNotFoundError(
        f'No PNG frames found matching {fp_in!r}; nothing to assemble into a GIF.'
    )

# The exit stack closes every opened image when the block ends.
with contextlib.ExitStack() as stack:

    # Lazily open the images, one at a time, as Pillow consumes them.
    imgs = (stack.enter_context(Image.open(f)) for f in frame_paths)

    # The first frame is the base image; the remaining ones are appended to it.
    img = next(imgs)

    # https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html#gif
    img.save(fp=fp_out, format='GIF', append_images=imgs,
             save_all=True, duration=300, loop=0)