In [22]:
import os
import numpy as np

if os.getcwd().split(os.sep)[-1] == "examples":
    os.chdir('..')

# This will reload all imports as soon as the code changes
%load_ext autoreload
%autoreload 2

color_x = 'red'
color_y = 'blue'

zero_control = np.vstack( [np.zeros((1,201)), np.zeros((1,201))] )[np.newaxis,:,:]

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [28]:
import numpy as np
import matplotlib.pyplot as plt
from neurolib.models.fhn import FHNModel
from neurolib.optimal_control import oc_fhn


target = np.zeros(( 1,2,2001 ))
zero_input = np.zeros((1,200))
zero_control = np.vstack( [zero_input, zero_input] )[np.newaxis,:,:]

model = FHNModel()
model.params["duration"] = 200.
model.run()

model.params["x_ext"] = zero_input
model.params["y_ext"] = zero_input

w2 = 1e-2
wp = 1.
model_controlled = oc_fhn.OcFhn(model, target, w_p=wp, w_2=w2, print_array=np.arange(0,1001,100), precision_cost_interval=[1000,None])

# We run 100 iterations of the optimal control gradient descent algorithm
model_controlled.optimize(1000)

state = model_controlled.get_xs()
control = model_controlled.control

plot_singlenode(state, target, control, zero_control, model_controlled.cost_history)



AssertionError: 

In [20]:
# define plot function for later convenience
def plot_singlenode(model, duration, dt, state, target, control, input, weight_array=[], M=1):
    duration = model.params.duration
    dt = model.params.dt
    fig, ax = plt.subplots( 3,1, figsize=(8,6), constrained_layout=True)

    # Plot the target (dashed line) and unperturbed activity
    t_array = np.arange(0, duration+dt, dt)

    ax[0].plot(t_array, state[0,0,:], label="x", color=color_x, linewidth=1)
    ax[0].plot(t_array, state[0,1,:], label="y", color=color_y, linewidth=1)
    ax[0].plot(t_array, target[0,0,:], linestyle='dashed', label="Target x", color=color_x)
    ax[0].plot(t_array, target[0,1,:], linestyle='dashed', label="Target y", color=color_y)
    ax[0].legend()
    ax[0].set_title("Activity without stimulation and target activity")

    for m in range(M-1):
        model.run()
        state = np.concatenate((np.concatenate( (model.params["xs_init"], model.params["ys_init"]), axis=1)[:,:, np.newaxis], np.stack( (model.x, model.y), axis=1)), axis=2)
        ax[0].plot(t_array, state[0,0,:], label="x", color=color_x, linewidth=1)
        ax[0].plot(t_array, state[0,1,:], label="y", color=color_y, linewidth=1)

    # Plot the target control signal (dashed line) and "initial" zero control signal
    ax[1].plot(t_array, control[0,0,:], label="stimulation x", color=color_x)
    ax[1].plot(t_array, control[0,1,:], label="stimulation y", color=color_y)
    ax[1].plot(t_array, input[0,0,:], linestyle='dashed', label="input x", color=color_x)
    ax[1].plot(t_array, input[0,1,:], linestyle='dashed', label="input y", color=color_y)
    ax[1].legend()
    ax[1].set_title("Active stimulation and input stimulation")

    ax[2].plot(weight_array)
    ax[2].set_title("Cost throughout optimization.")

    plt.show()