In [18]:
from sympy              import symbols,sin,pi,Matrix,lambdify,flatten,eye
from scipy.integrate    import odeint,ode
import numpy as np

x,t                 =   symbols("x,t")
u0                  =   symbols("u0")
k_min,k_max         =   symbols("k_min,k_max",positive=True)
omega_k             =   symbols("omega_k")

k   =   (k_max+k_min)/2+sin(omega_k*t)*(k_max-k_min)/2
A   =   Matrix([[-k*(1+x)]])
u   =   Matrix([[1+0.1*sin(t)]])

X=Matrix([[x]])
rhs =   A*X+u

pars={
    k_min:0.1,
    k_max:0.2,
    omega_k:2*pi}
time_symbol=t

ts=0;te=1;ns=11
times=np.linspace(ts,te,ns)
cacheTimes=np.linspace(ts,te,1000*(ns-1)+1)

new_start_values=np.array([1]) 
rhs_par=rhs.subs(pars)
state_variables=X


tup     = tuple(state_variables) + (time_symbol,)
F_num   = lambdify(tup, rhs_par, modules="numpy")
innerA  =lambdify(tup,A.subs(pars),modules="numpy")

def A_num(X,t):
    t_max=times[-1]
    if t>t_max:
        return innerA(X,t_max)
    return innerA(X,t)

u_num   = lambdify((time_symbol,),u.subs(pars),modules="numpy")
#print(u_num(3))
def num_rhs(X,t):
    Xt = tuple(X) + (t,)
    Aval=A_num(*Xt)
    #Fval = F_num(*Xt)
    Fval=Aval*Matrix(1,1,X)+u_num(t)
    #print(Fval)
    return flatten(Fval.tolist())
def num_rhs_ode(t,X):
    return num_rhs(X,t)

soln = odeint(num_rhs, new_start_values, times, mxstep=500)
def find_nearest_index(array,value):
    idx = (np.abs(array-value)).argmin()
    return idx

class OdeCache:
    def __init__(self,num_rhs,osv,times):
        self.times=times
        self.num_rhs=num_rhs
        solver=ode(num_rhs_ode)
        solver.set_integrator('lsoda')
        solver.set_initial_value(osv,times[0])
        ti=0
        ts=[times[0]]
        ys=[osv]
        while solver.successful() and solver.t<times[-1]:
            ti +=1
            solver.integrate(times[ti])
            ts.append(solver.t)
            ys.append(solver.y)
        sol_t= np.vstack(ts)
        soln = np.vstack(ys)
        self.values=soln
        self.sol_t=sol_t
        #self.values =odeint(num_rhs, new_start_values, times)

    def new_val(self,t):
        idx=find_nearest_index(self.times,t)
        new_start=self.values[idx]
        values =odeint(num_rhs, new_start, [self.times[idx],t] )
        return values[-1,:]

sc=OdeCache(num_rhs,new_start_values,cacheTimes)
print(sc.values)
print(sc.sol_t)


[[ 1.        ]
 [ 1.00007   ]
 [ 1.00013998]
 ..., 
 [ 1.60964896]
 [ 1.60969451]
 [ 1.60974006]]
[[  0.00000000e+00]
 [  1.00000000e-04]
 [  2.00000000e-04]
 ..., 
 [  9.99800000e-01]
 [  9.99900000e-01]
 [  1.00000000e+00]]


In [22]:
def A_of_t(tv):
    #Xs,info = odeint(num_rhs, new_start_values, [0,tv/2,tv/2,tv*3/4,tv], full_output=1 )
   # Xt = tuple(Xs[-1,:]) + (tv,)
    Xt = tuple(sc.new_val(tv)) + (tv,)
    val = A_num(*Xt)
    return val
    
def lin_num_rhs(X,t):
    # use the linearized version
    #Aval=-.1*eye(1)
    Aval=A_of_t(t)
    Xt = tuple(X) + (t,)
    #Aval=A_num(*Xt)
    Fval=Aval*Matrix(1,1,X)+u_num(t)
    return flatten(Fval.tolist())

lin_soln = odeint(lin_num_rhs, new_start_values, times,  mxstep=500)
    
print(np.max(lin_soln-soln))
lin_soln-soln      

1.24015277603e-07


array([[  0.00000000e+00],
       [ -2.34297826e-09],
       [  2.19839125e-09],
       [  2.59088042e-08],
       [  7.29958969e-08],
       [  1.06278483e-07],
       [  1.24015278e-07],
       [  1.22904930e-07],
       [  1.00746805e-07],
       [  1.07134037e-08],
       [ -2.00076788e-07]])

In [12]:
solver=ode


In [13]:
?ode.set_f_params

In [14]:
?ode.set_solout