In [1]:
# normal imports
import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp

In [2]:
# import codipack
import sys
sys.path.append("codiswig")
from codi import RealForward, RealReverse

In [3]:
# define ODE solver in pure python
def mysolve(f, interv, x0, dt=0.000001):
    """Integrate dx/dt = f(t, x) over ``interv`` with the explicit Euler method.

    Parameters
    ----------
    f : callable
        Right-hand side, called as ``f(t, x)``; its result is scaled by the
        step size and added to ``x``.
    interv : sequence of two numbers
        Integration interval ``[t_start, t_end]``. If ``t_end <= t_start``
        no step is taken and ``x0`` is returned unchanged.
    x0 :
        Initial value. Any type supporting ``+`` and multiplication by a
        float works (floats, NumPy arrays, CoDiPack active types). ``x0`` is
        never modified in place.
    dt : float, optional
        Fixed step size, must be positive. The last step is shortened so the
        integration ends exactly at ``t_end``.

    Returns
    -------
    The Euler approximation of ``x(t_end)``.

    Raises
    ------
    ValueError
        If ``dt`` is not positive, or if ``dt`` is too small to advance ``t``
        in floating point; both would otherwise loop forever.
    """
    # ``not dt > 0`` also rejects NaN.
    if not dt > 0:
        raise ValueError("dt must be positive, got %r" % (dt,))
    x = x0
    t, t_end = interv[0], interv[1]
    while t < t_end:
        step = min(dt, t_end - t)
        t_next = t + step
        if t_next == t:
            # step is below the floating-point resolution of t: no progress possible
            raise ValueError("dt=%r is too small to advance t=%r" % (dt, t))
        # Rebind rather than ``x += ...``: for a NumPy array x0, ``+=`` would
        # overwrite the caller's initial-value array in place.
        x = x + step * f(t, x)
        t = t_next
    return x

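Consider the parametric ODE $\dot y_c(t) = c\, y_c(t)$, $y_c(0) = 1$, whose exact solution is $y_c(t) = \exp(ct)$. Analytically we compute $$ \tfrac{\partial}{\partial c} y_c(2) = \tfrac{\partial}{\partial c} \exp(2c) = 2 \exp(2c). $$ Let us compute the same derivative with AD by differentiating through the explicit Euler solver `mysolve` defined above. Because Euler with a finite step size only approximates $\exp(ct)$, both the numerical solution and its AD derivative will differ slightly from the analytical values: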

In [4]:
# Record the forward computation on the CoDiPack reverse-mode tape:
# activate taping and declare c as the independent (input) variable.
RealReverse.setActive()
c = RealReverse(7.0)
c.registerInput()

# Define the problem dy/dt = c * y with y(0) = 1.
# f closes over the active variable c, so each solver step that calls f
# operates on active values while the tape is recording.
def f(t,y):
    return c*y
y0 = 1.0

# Perform the algorithm on t in [0, 2] with mysolve's default step size.
sol = mysolve(f, [0,2], y0)

# Output the numerical solution next to the exact value exp(2c).
# NOTE(review): the tape is still active here, so np.exp(c*2) is presumably
# recorded as well; it does not feed into sol, but computing it after
# setPassive() (or from c's passive value) would keep the tape minimal.
print("Numerical value for c="+str(c)+", t=2: ", sol)
print("Analytical value would be: ", np.exp(c*2))

Numerical value for c=7.000000, t=2:  1202545.358899
Analytical value would be:  1202604.284165


In [5]:
# Declare sol as the dependent (output) variable and stop taping; the tape
# now holds the operations recorded since setActive() in the previous cell.
sol.registerOutput()
RealReverse.setPassive()
# Reverse sweep: seed the adjoint of the output with 1.0 and evaluate the
# tape, after which c.getGradient() holds d(sol)/dc.
sol.setGradient(1.0)
RealReverse.evaluate()

# Output the AD derivative next to the analytical derivative 2*exp(2c).
# The small gap comes from the explicit Euler discretisation in mysolve:
# AD differentiates the discretised algorithm, not the exact solution.
print("AD derivative is: ", c.getGradient())
print("Analytical derivative would be:", 2*np.exp(2*c))

AD derivative is:  2405073.88226675
Analytical derivative would be: 2405208.568330


In [6]:
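# Vector-valued parameters: solve the same ODE for several values of c at once, using a NumPy array of RealReverse values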

In [10]:
# define ODE solver in pure python
def mysolve(f, interv, x0, dt=0.000001):
    """Integrate dx/dt = f(t, x) over ``interv`` with the explicit Euler method.

    Parameters
    ----------
    f : callable
        Right-hand side, called as ``f(t, x)``; its result is scaled by the
        step size and added to ``x``.
    interv : sequence of two numbers
        Integration interval ``[t_start, t_end]``. If ``t_end <= t_start``
        no step is taken and ``x0`` is returned unchanged.
    x0 :
        Initial value. Any type supporting ``+`` and multiplication by a
        float works (floats, NumPy arrays, including object arrays of
        CoDiPack active values). ``x0`` is never modified in place.
    dt : float, optional
        Fixed step size, must be positive. The last step is shortened so the
        integration ends exactly at ``t_end``.

    Returns
    -------
    The Euler approximation of ``x(t_end)``.

    Raises
    ------
    ValueError
        If ``dt`` is not positive, or if ``dt`` is too small to advance ``t``
        in floating point; both would otherwise loop forever.
    """
    # ``not dt > 0`` also rejects NaN.
    if not dt > 0:
        raise ValueError("dt must be positive, got %r" % (dt,))
    x = x0
    t, t_end = interv[0], interv[1]
    while t < t_end:
        step = min(dt, t_end - t)
        t_next = t + step
        if t_next == t:
            # step is below the floating-point resolution of t: no progress possible
            raise ValueError("dt=%r is too small to advance t=%r" % (dt, t))
        # Rebind rather than ``x += ...`` so a NumPy x0 is never updated in place.
        x = x + step * f(t, x)
        t = t_next
    return x

# Activate taping again and declare each component of c as an input.
# c is a NumPy object array holding two independent active values.
# NOTE(review): the tape does not appear to be reset since the scalar example,
# so it presumably still contains that earlier recording; consider resetting
# it first if the codi binding exposes a reset call.
RealReverse.setActive()
c = np.array([RealReverse(7.0),RealReverse(3.0)])
for var in c:
    var.registerInput()

# Define the problem dy_i/dt = c_i * y_i (element-wise product), y(0) = (1, 1).
def f(t,y):
    return c * y
y0 = np.array((1.0,1.0))

# Perform the algorithm on t in [0, 2]. This mysolve rebinds x each step,
# so y0 itself is left unchanged.
sol = mysolve(f, [0,2], y0)

In [11]:
# Declare both solution components as outputs and stop taping.
sol[0].registerOutput()
sol[1].registerOutput()
RealReverse.setPassive()
# Seeding both outputs with 1.0 makes the reverse sweep compute the gradient
# of sol[0] + sol[1]. Because f is element-wise, sol[0] depends only on c[0]
# and sol[1] only on c[1]. c[i].getGradient() is therefore d sol[i] / d c[i].
sol[0].setGradient(1.0)
sol[1].setGradient(1.0)
RealReverse.evaluate()

# Output derivative of solution w.r.t. c[0] (= 7); analytically 2*exp(14).
print("AD derivative is: ", c[0].getGradient())
# Output derivative of solution w.r.t. c[1] (= 3); analytically 2*exp(6).
print("AD derivative is: ", c[1].getGradient())

AD derivative is:  2405073.88226675
AD derivative is:  806.847904957385
