# InImNet Solution of Projectile motion

## Imports and directories

In [None]:
# Quick parameters
npoints = 4  # Number of p-layers (previously tried value: 32)
nsamples = 4  # = No. training samples = No. validation samples (previously: 100)
nlayers = 2  # Dense layers in the Phi-network
inflate_width = 4  # Inflate the hidden network width
gravity = 9.81  # Vertical deceleration constant
hmax = 10.  # Max height fixed for plotting
load_samples = False  # Requires existence of files else computes samples
save_samples = True  # Only when loading samples fails

# Import modules
import os
import sys
from matplotlib import rcParams, cycler
import matplotlib.pyplot as plt  # %matplotlib notebook
import numpy as np
import random
from tqdm import tqdm
from tqdm.notebook import tqdm_notebook

import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_probability as tfp  # !pip install tensorflow_probability

# pgrid defining InIm layers: npoints + 1 evenly spaced values on [0, 1],
# with the final entry pgrid[-1] = 1 = q
pgrid = tf.linspace(0., 1., npoints + 1)


In [None]:
from google.colab import drive
drive.mount('/content/gdrive', force_remount=True)

# Work from the project folder on the mounted drive and make it importable
pmotion_path = '/content/gdrive/MyDrive/Colab/pmotion'
if os.getcwd() != pmotion_path:
    os.chdir(pmotion_path)
print('Current working directory: ', pmotion_path)
sys.path.insert(0, pmotion_path)

# Output folders are keyed by the grid size (p) and the sample count (b)
data_path = f'data/p{npoints}b{nsamples}'
graphics_path = f'graphics/p{npoints}b{nsamples}'
for _folder, _notice in ((data_path, 'New data directory.'),
                         (graphics_path, 'New graphics directory.')):
    if not os.path.exists(_folder):
        os.makedirs(_folder)
        print(_notice)

# File locations; the p-indexed templates are filled in with .format(p_i=...)
ptrain_path = '/'.join((data_path, 'ptrain{p_i:03d}'))
pvalid_path = '/'.join((data_path, 'pvalid{p_i:03d}'))
vtrain_path = '/'.join((data_path, 'vtrain'))
vvalid_path = '/'.join((data_path, 'vvalid'))
inimsave_path = '/'.join((data_path, 'inimsol'))


In [None]:
%%script false
# TPU initialisation
# NOTE: this cell is headed by the `%%script false` cell magic, so the lines
# below are handed to the `false` command instead of the Python kernel.
# Resolve the TPU cluster, connect to it and initialise the TPU system
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
print("All devices: ", tf.config.list_logical_devices('TPU'))

In [None]:
# GPU info
# Number of GPUs visible to TensorFlow
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))
# IPython shell capture: the output lines of `nvidia-smi`
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
# The word 'failed' in the output is taken to mean that no GPU is attached
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
   print(gpu_info)


## Generate Samples

### Projectile ODE

In [None]:
class ProjectileODE(tfp.math.ode.DormandPrince):
    """Dormand-Prince solver carrying the dynamics of a vertical projectile.

    Subclassing the solver exposes `self.solve` next to the right-hand side
    dz/dt = A @ z + b for the state z = [h, v], and next to an Euler
    integrator for the invariant-imbedding (InIm) equation.
    """

    def __init__(self,
                 batch_dim,
                 gravity=9.81,
                 c=0,
                 ):
        """Builds the batched coefficients A and b of the linear dynamics.

        Args:
            batch_dim: Number of copies of A and b repeated along axis 0.
            gravity: Constant subtracted from dv/dt.
            c: Bottom-right entry of A, i.e. the multiplier of v in dv/dt.
        """
        super().__init__()

        # ODE matrix multiplier, repeated once per batch member
        matrix = tf.constant([[0, 1],
                              [0, c]], dtype=tf.float32)
        self.A = tf.repeat(tf.expand_dims(matrix, axis=0), batch_dim, axis=0)

        # ODE shift, repeated once per batch member
        shift = tf.constant([0, -gravity], dtype=tf.float32)
        self.b = tf.repeat(tf.expand_dims(shift, axis=0), batch_dim, axis=0)

    def dynamics(self, t, z):
        """Returns A@z + b with A=[[0, 1], [0, c]] and b=[0, -gravity].

        The time `t` is accepted to fit the solver interface but is not used.
        """
        linear_part = tf.linalg.matvec(self.A, z)
        return linear_part + self.b

    def inimdynamics(self, z, p, x, gtape):
        """Invariant Imbedding p-gradient: -(dz/dx) @ dynamics(p, x).

        `gtape` must be a gradient tape that watched `x` while `z` was built.
        """
        jacobian = gtape.batch_jacobian(z, x)
        rhs = self.dynamics(t=p, z=x)
        return -tf.linalg.matvec(jacobian, rhs)

    def inimsolve(self, pgrid, x, gtape):
        """Euler-step ODE solver to solve dzdp = inimdynamics() over pgrid.

        Starts from z = x at pgrid[-1] and steps down through the grid; the
        stacked result is ordered like pgrid (entry 0 belongs to pgrid[0]).
        """
        states = [x]
        upper = pgrid[-1]
        for lower in tqdm_notebook(pgrid[-2::-1]):  # reversed, omit pgrid[-1]
            step = lower - upper
            latest = states[-1]
            states.append(
                latest + step * self.inimdynamics(latest, lower, x, gtape))
            upper = lower
        return tf.stack(states[::-1])  # Reorder solution according to pgrid


### Generate samples

In [None]:
def get_projectile_samples(pgrid,
                           nsamples,
                           saving=False,
                           gravity=9.81,
                           ):
    """Generates projectile trajectories and the matching InIm solution.

    Draws 2 * nsamples initial speeds v uniformly from [vmin, vmax) and, for
    each p_i in pgrid[:-1], integrates the state z(t; p_i, x) = [h, v] from
    x = [0, v] at t = p_i over the times pgrid[i:]. The first nsamples
    trajectories are the training set, the remaining ones the validation set.
    The invariant-imbedding (InIm) equation is also integrated numerically
    over the reversed pgrid with Euler steps (ProjectileODE.inimsolve).

    Note: this function reads the module-level `hmax` (upper bound for the
    speeds) and, when `saving` is set, the module-level `*_path` names.

    Args:
        pgrid: Strictly increasing 1D tensor of p-values.
        nsamples: Number of training samples (= number of validation samples).
        saving: If True, every dataset is written to disk.
        gravity: Vertical deceleration constant.

    Returns:
        Tuple (ptrain, pvalid, vtrain, vvalid, inimsoln) where
        ptrain[i][v, t, j] = z(t; p_i, x=[0, v])_j (likewise pvalid),
        vtrain / vvalid are 1D tensors of nsamples initial speeds and
        inimsoln[i, v, j] = z(q; p_i, x=[0, v])_j over all 2 * nsamples speeds.

    Raises:
        ValueError: If `pgrid` is not strictly increasing.
    """

    if not all(p < q for p, q in zip(pgrid, pgrid[1:])):
        raise ValueError("Tensor `pgrid` should be strictly increasing.")

    # Adjust sample parameters (vertical speed) for suitable plots
    vmin = gravity * (pgrid[-1] - pgrid[0]) / 2
    vmax = tf.math.sqrt(2 * gravity * hmax)
    print('vmax: ', vmax)
    print('vmin: ', vmin)

    # Initial input values for z(p; p, x) = x = [h=0, v=vinit]
    vinit = tf.random.uniform(shape=[2 * nsamples, 1],
                              minval=vmin,
                              maxval=vmax,
                              dtype=tf.dtypes.float32,
                              )
    hinit = tf.zeros_like(vinit)
    x = tf.concat([hinit, vinit], axis=-1)

    # ODE dynamics objects
    ode = ProjectileODE(2*nsamples, gravity)

    # Iterate t-integral for initial values p
    ptrain = []
    pvalid = []
    for i in tqdm_notebook(range(len(pgrid) - 1)):
        # Integrate t-solution
        tsol = ode.solve(ode.dynamics, pgrid[i], x, pgrid[i:])
        # i-th tstates[v, t, j] = z(t; p_i, x=[0, v])_j
        tstates = tf.transpose(
            tsol.states, perm=[1, 0, 2], name = 'z(t; p_{}, x)'.format(i+1),
            )
        # Split training/validation data 50:50
        ttrain = tf.data.Dataset.from_tensors(tstates[:nsamples])
        tvalid = tf.data.Dataset.from_tensors(tstates[nsamples:])
        if saving:
            # Save p-th t-state
            tf.data.experimental.save(ttrain, ptrain_path.format(p_i=i+1))
            tf.data.experimental.save(tvalid, pvalid_path.format(p_i=i+1))
        # Accumulate as list
        ptrain.append(ttrain.get_single_element())
        pvalid.append(tvalid.get_single_element())

    # Save initial values
    vtrain = tf.data.Dataset.from_tensors(vinit[:nsamples])
    vvalid = tf.data.Dataset.from_tensors(vinit[nsamples:])
    if saving:
        tf.data.experimental.save(vtrain, vtrain_path)
        tf.data.experimental.save(vvalid, vvalid_path)

    # Compute InIm dynamics
    with tf.GradientTape(persistent=True, watch_accessed_variables=False) as g:
        g.watch(x)
        # Integrate InImODE: inimsoln[i, v, j] = z(q; p_i, x=[0, v])_j
        inimsoln = ode.inimsolve(pgrid, x, g)
    if saving:
        tf.data.experimental.save(tf.data.Dataset.from_tensors(inimsoln),
                                  inimsave_path,
                                  )
    # Squeeze only the trailing unit axis: a bare squeeze would also drop the
    # batch axis (returning a scalar) when nsamples == 1.
    return (ptrain,
            pvalid,
            tf.squeeze(vtrain.get_single_element(), axis=-1),
            tf.squeeze(vvalid.get_single_element(), axis=-1),
            inimsoln,
            )

# Load training data
def _load_single(path):
    """Returns the single element of the dataset stored at `path`."""
    return tf.data.experimental.load(path).get_single_element()


# Only the file of the last p-index is probed to decide whether to load
_samples_on_disk = os.path.exists(ptrain_path.format(p_i=len(pgrid) - 1))
if not (load_samples and _samples_on_disk):
    print('\nComputing samples...')
    samples = get_projectile_samples(pgrid,
                                     nsamples,
                                     saving=save_samples,
                                     gravity=gravity,
                                     )
    ptrain, pvalid, vtrain, vvalid, inimsoln = samples
    print('\nComplete.')
else:
    print('\nLoading samples from file...')
    ptrain = []
    pvalid = []
    for i in tqdm_notebook(range(len(pgrid) - 1)):
        ptrain.append(_load_single(ptrain_path.format(p_i=i+1)))
        pvalid.append(_load_single(pvalid_path.format(p_i=i+1)))
    vtrain = _load_single(vtrain_path)
    vvalid = _load_single(vvalid_path)
    inimsoln = _load_single(inimsave_path)
    print('Complete.')

# inimsoln[i, v, j] = z(q; p_i, [0, v])_j with len(v) = 2*nsamples;
# the first nsamples entries along v are the training half
inimtrain = inimsoln[:, :nsamples]


### Cultivate data

In [None]:
# Training data
# Reshape to one column instead of expand_dims: the speeds arrive with shape
# (nsamples,) when freshly computed but (nsamples, 1) when loaded from file,
# and reshape gives (nsamples, 1) in both cases (and when the cell is re-run).
vtrain = tf.reshape(vtrain, [-1, 1])
htrain = tf.zeros_like(vtrain)
xtrain = tf.concat([htrain, vtrain], axis=-1)  # rows [h=0, v]
ytrain = ptrain  # ytrain[i][v, t, j] = z(t; p_i, [0, v])_j & len(v) = nsamples

print('vtrain.shape: ', vtrain.shape)
print('htrain.shape: ', htrain.shape)
print('xtrain.shape: ', xtrain.shape)
print('ytrain[0].shape: ', ytrain[0].shape)

# Validation data
# Same column layout as the training data; concatenating the 1D speeds
# directly would give a flat vector of length 2*nsamples, not rows [h=0, v].
vvalid = tf.reshape(vvalid, [-1, 1])
hvalid = tf.zeros_like(vvalid)
xvalid = tf.concat([hvalid, vvalid], axis=-1)  # rows [h=0, v]
yvalid = pvalid  # yvalid[i][v, t, j] = z(t; p_i, [0, v])_j & len(v) = nsamples

In [None]:
# Inspect the training speeds produced by the data-preparation cell
print(vtrain)

### Plot $t$-varying projectiles

In [None]:
## Plot t-varying output
# ptrain[i][v, t, j] = z(t; p_i, [0, v])_j
print('No. of p-states: ', len(ptrain))
print('No. of v-samples: ', len(vtrain))
print('Shape of typical p_states: ', ptrain[0].shape)

print('First input x = [h, v] in batch:')
print('h(p) = ', ptrain[0][0][0, 0])
print('v(p) = ', ptrain[0][0][0, 1])
print('\n')

# Create figure
# Matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8-*' and
# later removed the old names, so use whichever name this version provides.
grid_style = 'seaborn-whitegrid'
if grid_style not in plt.style.available:
    grid_style = 'seaborn-v0_8-whitegrid'
plt.style.use(grid_style)
cmap = plt.cm.autumn
rcParams['axes.prop_cycle'] = cycler(
    color=cmap(np.linspace(0, 1, len(pgrid) - 1))
    )
fig, ax = plt.subplots(figsize=[10, 8])
ax.set_xlim(pgrid[0], pgrid[-1])
ax.set_xlabel('p', fontsize=16)
ax.set_ylim(0, hmax)
ax.set_ylabel('h(t; p, x=[0, v(p)])', fontsize=16)
ax.set_title('Projectile path sample set.', fontsize=20)

# Plot one line per sample, starting at p = pgrid[ipt], for the first knum
# training speeds, each in its own colour from cstr
cstr = ['b', 'aquamarine', 'xkcd:sky blue', 'r']
knum = min(len(cstr), len(vtrain))  # never index past the available samples
ipt = 0
msize = 100
for k in range(knum):
    ax.plot(pgrid[ipt:],
            ptrain[ipt][k][:, 0],
            cstr[k],
            label='$v(p)$ = {:.2f}'.format(float(vtrain[k])),
            )
    # Mark the end point of the t-solution at t = pgrid[-1]
    # (alternative end point: inimtrain[ipt, k, 0], the InIm Euler solution)
    ax.scatter(pgrid[-1],
               ptrain[ipt][k][-1, 0],
               c=cstr[k],
               s=msize,
               marker=8,
               )

# Plot lines starting at each p in pgrid[:-1] for the last plotted speed,
# vtrain[knum-1], coloured by the autumn colour cycle set above
for i in range(0, len(pgrid) - 1):
    ax.plot(pgrid[i:], ptrain[i][knum-1][:, 0])
ax.legend(loc='upper left')
#plt.show()
fig.savefig(graphics_path + '/tsol.png')


## Define model

In [None]:
class InImNet(tf.keras.Model):
    """Time series regression with InImNet.

    Holds a dense network `phi` and integrates, with forward-Euler steps over
    a grid of p-values, both the invariant-imbedding (InIm) state z and the
    augmented adjoint state used to obtain parameter gradients.
    """

    def __init__(self,
                 input_dim: int=2,
                 num_layers: int=2,
                 width_mult: int=2,
                 activation_in='relu',
                 activation_out=None,
                 cost: str='mse',
                 bias_in: bool=True,
                 bias_out: bool=False,
                 t_grad_on: bool=False,
                 name="InImNet"):
        """Builds the phi-network and selects the cost function.

        Args:
            input_dim: Size of the input x and of the network output.
            num_layers: Total number of Dense layers; num_layers - 1 hidden
                layers are followed by one output layer.
            width_mult: Hidden layers have input_dim * width_mult units.
            activation_in: Activation of the hidden layers.
            activation_out: Activation of the output layer.
            cost: Name of the cost function; only 'mse' is available.
            bias_in: Whether the hidden layers use a bias.
            bias_out: Whether the output layer uses a bias.
            t_grad_on: If True the augmented adjoint carries a third,
                t-derivative component.
            name: Name of the Keras model.

        Raises:
            ValueError: If `cost` is anything other than 'mse'.
        """
        super(InImNet, self).__init__(name=name)
        if cost == 'mse':
            # Per-sample losses (no reduction) so they can be differentiated
            # sample by sample with batch_jacobian.
            self.cost_fn = tf.keras.losses.MeanSquaredError(
                                reduction=tf.keras.losses.Reduction.NONE)
        else:
            raise ValueError('Please set new cost_fn')
        self.t_grad_on = t_grad_on
        self.phi = tf.keras.Sequential()
        self.phi.add(tf.keras.Input(shape=(input_dim,)))
        for layer in range(num_layers - 1):
            self.phi.add(tf.keras.layers.Dense(input_dim * width_mult,
                                               activation=activation_in,
                                               use_bias=bias_in,
                                               name='HiddenLayer{}'
                                                    .format(layer),
                                               ))
        self.phi.add(tf.keras.layers.Dense(input_dim,
                                           activation=activation_out,
                                           use_bias=bias_out,
                                           name='OutputLayer',
                                           ))

    def inim_dynamics(self, z, p, x, gtape):
        """Invariant Imbedding p-gradient: -(dz/dx) @ phi(x).

        `p` is accepted for symmetry with the integrators but is not used;
        `gtape` must be a tape that watched `x` while `z` was computed.
        """
        return -tf.linalg.matvec(gtape.batch_jacobian(z, x), self.phi(x))

    def call(self, pgrid, x):
        """Fwd-Euler integration of p-gradient of z(q; p, x) over pgrid.

        Args:
            pgrid: Points at which to evaluate state z, all less than q.
            x: Tensor, input to the system for any given p in pgrid

        Returns: 3D Tensor `out`, such that out[i, v, j] = z(q; p_i, x=[0, v])_j

        Raises:
            ValueError: If `pgrid` is not strictly increasing.
        """

        if not all(p < q for p, q in zip(pgrid, pgrid[1:])):
            raise ValueError("The list 'pgrid' should be strictly increasing.")

        with tf.GradientTape(
            persistent=True, watch_accessed_variables=False) as g:
            g.watch(x)
            z = [x]
            q = pgrid[-1]
            for p in tqdm_notebook(pgrid[-2::-1]):  # p = pgrid[-2], ..., [0]
                delta = p - q
                z.append(z[-1] + delta * self.inim_dynamics(z[-1], p, x, g))
                q = p
        z.reverse()  # Reorder solution according to pgrid
        return tf.stack(z)

    def aug_adjoint_dynamics(self, aa, p, x, gtape):
        """Implements the p-derivative of the augmented adjoint (aa) state.

        Args:
            aa: The Augmented Adjoint, a length 2 or 3 tuple of tensors with
                aa[0] = z derivative of the adjoint
                aa[1] = parameter derivative of the adjoint
                aa[2] = t derivative of the adjoint [optional, only read when
                        self.t_grad_on is True]
            p: Scalar tensor (not used by the computation)
            x: Tensor, InIm input of shape (batch_size, 2)
            gtape: GradientTape that watched x while `aa` was computed

        Returns:
            Tuple (grad_z, grad_params, grad_t) of p-derivatives matching
            aa[0], aa[1] and aa[2]; grad_t is None unless self.t_grad_on.
        """

        # Watch x in phi(x, var) for Jacobian computation
        with tf.GradientTape(
            persistent=True, watch_accessed_variables=False) as g:
            g.watch(x)
            for var in self.variables:
                g.watch(var)
            phi = self.phi(x)
        dphi = g.batch_jacobian(phi, x)

        # Gradient wrt. state z
        grad_z = - tf.linalg.matvec(gtape.batch_jacobian(aa[0], x), phi) \
                 - tf.linalg.matvec(tf.transpose(dphi, perm=[0, 2, 1]), aa[0])

        # Gradient wrt. params: per variable, (d phi / d var)^T @ aa[0] with
        # the variable's own axes flattened into one
        jacobians_wrt_params = []
        for var in self.variables:
            jacobians_wrt_params.append(tf.linalg.matvec(
                tf.transpose(tf.reshape(g.jacobian(phi, var),
                                        shape = phi.shape.as_list() + [-1,],
                                        ),
                             perm=[0, 2, 1],
                             ),
                aa[0],
                ))

        grad_params = - tf.linalg.matvec(gtape.batch_jacobian(aa[1], x), phi) \
                      - tf.concat(jacobians_wrt_params, axis=-1)

        # Activate the t-component of the lambda derivative
        if self.t_grad_on:
            # tf.reduce_sum takes `axis` (not `axes`); keepdims retains the
            # trailing unit axis so the result matches aa[2].
            grad_t = - tf.linalg.matvec(gtape.batch_jacobian(aa[2], x), phi) \
                     - tf.reduce_sum(tf.multiply(tf.linalg.matvec(dphi, phi),
                                                 aa[0],
                                                 ),
                                     axis=1,
                                     keepdims=True,
                                     )
        else:
            grad_t = None

        return grad_z, grad_params, grad_t

    def call_aug_grads(self, pgrid, x, y):
        """Integrates the p-gradient of the augmented adjoint (aa) state.

        Args:
            pgrid: Points at which to evaluate state z, all less than q.
            x: Tensor, input to the system for any given p in pgrid
            y: Sequence of target tensors, one per p in pgrid[:-1], with
                y[i][v, t, j] the target trajectory started at p_i; only the
                final time of each trajectory is compared against z.

        Returns:
            Tuple (z, aa) of lists ordered like pgrid, where z[i] is the InIm
            state at p_i (z[-1] is x) and aa[i] is the augmented adjoint tuple
            (z-part, parameter-part[, t-part]) at p_i.

        Raises:
            ValueError: If `pgrid` is not strictly increasing.
        """

        if not all(p < q for p, q in zip(pgrid, pgrid[1:])):
            raise ValueError("The list 'pgrid' should be strictly increasing.")

        # Lists to store the solution (tuples) for each p in pgrid
        aa = []
        z = []

        with tf.GradientTape(
            persistent=True, watch_accessed_variables=False) as gx:
            gx.watch(x)

            # Initial value of InIm solution
            z.append(x)

            # Initial values of the augmented adjoint
            with tf.GradientTape(
                persistent=True, watch_accessed_variables=False) as gz:
                gz.watch(z[0])
                init_rloss = tf.expand_dims(self.cost_fn(z[0], x), -1)
            # Squeeze only the unit loss axis so that a batch of size one
            # keeps its batch axis.
            aa.append((tf.squeeze(
                tf.linalg.matmul(gz.batch_jacobian(init_rloss, z[0]),
                                 gx.batch_jacobian(z[0], x)),
                axis=1)
            ,))
            aa[0] = aa[0] + (tf.math.multiply(
                tf.zeros(shape=(x.shape[0], sum(tf.math.reduce_prod(var.shape)
                                                for var in self.variables))),
                tf.reduce_sum(x)),)  # Mult by x so that JVP wrt x is 0 not None
            if self.t_grad_on:
                # Batched inner product of the z-adjoint with phi(x), kept as
                # a column so that it matches grad_t in aug_adjoint_dynamics.
                # NOTE(review): replaces a leftover torch.bmm call that used
                # undefined names; confirm the intended sign convention.
                aa[0] = aa[0] + (tf.reduce_sum(
                    tf.multiply(aa[0][0], self.phi(x)),
                    axis=1,
                    keepdims=True,
                    ),)

            # Integrate over pgrid
            q = pgrid[-1]
            ii = 0  # Count the y data points in reverse p-order
            for p in tqdm_notebook(pgrid[-2::-1]):  # p = pgrid[-2], ..., [0]
                delta = p - q
                z.append(z[-1] + delta * self.inim_dynamics(z[-1], p, x, gx))
                ii += 1
                # Target at the final time of the trajectory started at p
                yi = tf.transpose(y[-ii], perm=[1, 0, 2])[-1]
                with tf.GradientTape(
                    persistent=True, watch_accessed_variables=False) as gz:
                    gz.watch(z[-1])
                    r_loss = tf.expand_dims(self.cost_fn(z[-1], yi), -1)
                r_grad = tf.squeeze(
                    tf.linalg.matmul(gz.batch_jacobian(r_loss, z[-1]),
                                     gx.batch_jacobian(z[-1], x)),
                    axis=1)
                # Keep a handle on the previous state: its t-component is
                # needed after the new (z, params) pair has been formed.
                prev_aa = aa[-1]
                aa_dynamics = self.aug_adjoint_dynamics(prev_aa, p, x, gx)
                next_aa = (
                    prev_aa[0] + r_grad + delta * aa_dynamics[0],
                    prev_aa[1] + delta * aa_dynamics[1],
                    )
                if self.t_grad_on:
                    next_aa = next_aa + (
                        prev_aa[2] + delta * aa_dynamics[2],)
                aa.append(next_aa)
                q = p
        # Reorder solutions according to pgrid
        z.reverse()
        aa.reverse()
        return z, aa


### Instantiation

In [None]:
# Build the InImNet with the depth and width chosen in the quick parameters
inimodel = InImNet(
    input_dim=2,
    num_layers=nlayers,
    width_mult=inflate_width,
)

## Training Phase

In [None]:
epochs = 10
learning_rate = 0.001

# One mean loss per epoch
record_training_loss = np.zeros((epochs,))

# Reduced (scalar) MSE used only for the loss record
mse = tf.keras.losses.MeanSquaredError()

# Plain gradient descent driven by the augmented adjoint
for ep in tqdm_notebook(range(epochs)):
    print('Epoch ', ep)

    # Integrate the InIm state and the augmented adjoint over pgrid
    ztrain, aatrain = inimodel.call_aug_grads(pgrid, xtrain, ytrain)
    print('Check ztrain epoch {}: '.format(ep), ztrain)
    # Parameter part of the adjoint at pgrid[0], summed over the batch
    param_grads = tf.math.reduce_sum(aatrain[0][1], axis=0)
    print('param_grads.shape: ', param_grads.shape)

    # Cut the flat gradient into pieces shaped like the individual weights
    new_weights = []
    start = 0
    for w in inimodel.get_weights():
        end = start + int(np.prod(w.shape))
        wloss = tf.reshape(param_grads[start:end], shape=w.shape)
        new_weights.append(w - learning_rate * wloss)
        start = end
    inimodel.set_weights(new_weights)

    # Record loss: mean over p of the MSE between z at p_i and the final
    # time of the matching target trajectory (ztrain[-1] = x is left out)
    losses = [mse(z, tf.transpose(y, perm=[1, 0, 2])[-1])
              for z, y in zip(ztrain[:-1], ytrain)]
    record_training_loss[ep] = sum(losses) / len(losses)


## Plot results

In [None]:
    # Model evaluated on training data, shape (npoints+1, batch_dim, dim=2)
    ztrain = tf.stack(ztrain)
    # Matching targets: the final time of each trajectory in ytrain, closed
    # off with the input xtrain as the last entry
    yptrain = tf.stack(
        [tf.transpose(y, perm=[1, 0, 2])[-1] for y in ytrain] + [xtrain]
    )

In [None]:
# Normalise loss by untrained parameters
training_loss = np.array(record_training_loss)
norm_loss = [loss / training_loss[0] for loss in training_loss]

# Plot loss and accuracy graphs on one plot
fig, ax = plt.subplots()
plt.title('Loss versus epoch')
loss_lines = ax.plot(norm_loss)
ax.set_ylabel('Training loss')
ax.set_xlabel('Number of epochs')

In [None]:
## Plot Learnt output from training data at t=q (for varying p)

print('First input x = [h, v] in batch:')
print('h(q; q, x=[0, v]) = ', ztrain[-1, 0, 0])
print('v(q; q, x=[0, v]) = ', ztrain[-1, 0, 1])
print('\n')

#print('h(q; q, x=[0, v]) = ', yptrain[-1, 0, 0])
#print('v(q; q, x=[0, v]) = ', yptrain[-1, 0, 1])
#print('\n')

# Create figure
# Matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8-*' and
# later removed the old names, so use whichever name this version provides.
grid_style = 'seaborn-whitegrid'
if grid_style not in plt.style.available:
    grid_style = 'seaborn-v0_8-whitegrid'
plt.style.use(grid_style)
cmap = plt.cm.autumn
rcParams['axes.prop_cycle'] = cycler(
    color=cmap(np.linspace(0, 1, 5))
    )
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=[20, 8])

# Tick values shared by both axes. The positions are set explicitly before
# the labels: labels alone are attached to whatever ticks the automatic
# locator picked, which need not be these values.
xlabs = [0, 0.25, 0.5, 0.75, 1]
ylabs = [0, 2, 4, 6, 8, 10]

## Axis 1
ipt = 0
ax1.set_xlabel('t', fontsize=24)
ax1.set_ylabel('h(t; p, x=[0, v])', fontsize=24)
ax1.set_title('Fixed p = {:.2f}, varying initial velocity $v$'
              .format(0), fontsize=24)  # pgrid[ipt]
ax1.set_xticks(xlabs)
ax1.set_yticks(ylabs)
ax1.set_xticklabels(xlabs, fontsize=24)
ax1.set_yticklabels(ylabs, fontsize=24)
# Limits go last because setting ticks may widen the view
ax1.set_xlim(pgrid[0], pgrid[-1])
ax1.set_ylim(0, hmax)

# Plot one target line per sample, starting at p = pgrid[ipt], for the first
# knum training speeds; the marker shows the model output z at t = pgrid[-1]
cstr = ['b', 'aquamarine', 'xkcd:sky blue', 'r']
knum = min(len(cstr), len(vtrain))  # never index past the available samples
msize = 100
for k in range(knum):
    ax1.plot(pgrid[ipt:],
             ytrain[ipt][k][:, 0],
             cstr[k],
             label='$v$ = {:.2f}'.format(float(vtrain[k])),
             )
    ax1.scatter(pgrid[-1],
                ztrain[ipt, k, 0],
                c=cstr[k],
                s=msize,
                marker=8,
                )
ax1.legend(loc='upper left', fontsize=20, framealpha=1)

## Axis 2
ax2.set_xlabel('t', fontsize=24)
#ax2.set_ylabel('h(t; p, x=[0, v(p)])', fontsize=24)
ax2.set_title('Fixed initial velocity $v$, varying $p$', fontsize=24)
ax2.set_xticks(xlabs)
ax2.set_yticks(ylabs)
ax2.set_xticklabels(xlabs, fontsize=24)
ax2.set_yticklabels([])  # shares the y-scale of axis 1, so no labels
ax2.set_xlim(pgrid[0], pgrid[-1])
ax2.set_ylim(0, hmax)

# Plot target lines starting at each p in pgrid[:-1] for the last plotted
# speed, vtrain[knum-1], each with the model output marked at t = pgrid[-1]
for i in range(0, len(pgrid) - 1):
    if i == 0:
        # Only the first line carries the legend entry; index the speed
        # explicitly rather than through the leftover loop variable k
        ax2.plot(pgrid[i:], ytrain[i][knum-1][:, 0],
                 label='$v$ = {:.2f}'.format(float(vtrain[knum-1])),
                 )
    else:
        ax2.plot(pgrid[i:], ytrain[i][knum-1][:, 0])
    ax2.scatter(pgrid[-1],
                ztrain[i, knum-1, 0],
                s=msize,
                marker=8,
                )
ax2.legend(loc='upper center', fontsize=20)
plt.show()
fig.savefig(graphics_path + '/psol.pdf')
