In [None]:
import pandas as pd
import numpy as np
import jax.numpy as jnp
import numpyro
import matplotlib.pyplot as plt


from PacTimeOrig.controllers import simulator as sim
from PacTimeOrig.controllers import JaxMod as jm
from PacTimeOrig.controllers import models as mods

from PacTimeOrig.controllers import utils as ut
from PacTimeOrig.data import scripts



# Session selection: the config stores the session as sess + 1.
sess = 0
cfgparams = dict(
    area='dACC',
    session=sess + 1,
    subj='H',
    typeswitch=1,
    rbfs=[20, 30, 40, 50, 60, 70, 80],
    restarts=6,
)

# Load the session described by cfgparams. Of the four returned objects,
# only Xdsgn is consumed by the cells visible in this notebook section.
Xdsgn, kinematics, sessvars, psth = scripts.monkey_run(cfgparams)

# Pull the kinematics of a single trial; the later cells read tdat[...].
trial = 4
tdat = ut.trial_grab_kine(Xdsgn, trial)

In [None]:
# Recovery test for the 'pvi' slack controller fitted with plain Adam:
# simulate a trial with known gains (L1, L2), fit the slack model to the
# simulated output, and compare the recovered switching signal ('shift')
# with the simulated one. Per-repeat metrics are collected in the lists below.
mse=[]
corr=[]
L1_out=[]
L2_out=[]

# Loop-invariant settings, hoisted out of the repeat loop.
num_rbfs = 30       # number of RBFs handed to prepare_inputs
num_steps = 20000   # Adam steps per repeat

for i in range(5):

    # Gains used to generate the simulated data (fresh arrays each repeat).
    L1=np.array((1.6,0.4,0.1))
    L2=np.array((2,0.5,0.8))
    outputs=sim.controller_sim_pvi_slack(6,tdat['player_pos'],tdat['player_vel'],tdat['pry1_pos'],tdat['pry1_vel'],tdat['pry2_pos'],tdat['pry2_vel'],L1,L2,None,None,gpscaler=5,alpha=10)

    # Time axis matching the simulated outputs.
    tmp=ut.make_timeline(outputs)

    A,B=ut.define_system_parameters(ctrltype='pvi')

    inputs=ut.prepare_inputs(A, B, outputs['x'], outputs['uout'], tdat['pry1_pos'], tdat['pry2_pos'], tmp, num_rbfs, outputs['x'][:,2:], tdat['pry1_vel'], tdat['pry2_vel'],pry_1_accel=tdat['pry1_accel'],pry_2_accel=tdat['pry2_accel'],ctrltype='pvi')

    loss_function = jm.create_loss_function_pvi_slack(ut.generate_rbf_basis,inputs['num_rbfs'],ut.generate_smoothing_penalty,lambda_reg=0)

    # NOTE(review): grad_loss is not passed to anything below in this cell;
    # kept so the notebook's global state is unchanged.
    grad_loss = ut.compute_loss_gradient(loss_function)

    params=jm.initialize_parameters(inputs, ctrltype='pvi', slack_model=True)

    # Set up the optimizer
    optimizer, opt_state = jm.setup_optimizer(params, learning_rate=1e-2,slack_model=True,optimizer='adam')

    # Optimization loop
    for step in range(num_steps):
        params, opt_state, loss = jm.optimization_step(params, opt_state, optimizer, loss_function, inputs,ctrltype='pvi',slack_model=True)

        if step % 100 == 0:
            print(f"Step {step}, Loss: {loss}")

    # Transform the gain parameters to the positive domain with
    # softplus, log(1 + exp(x)); alpha is read from params[4] untransformed.
    L1_fit=jnp.log(1+jnp.exp(params[2]))
    L2_fit=jnp.log(1+jnp.exp(params[3]))
    alpha=params[4]

    weights = params[0]
    widths = params[1]
    wtout=ut.generate_sim_switch(inputs, widths, weights,slack_model=True)

    # Stack the three recovered switching components row-wise.
    shift=np.vstack((wtout[0],wtout[1],wtout[2]))
    pred_out=sim.controller_sim_pvi_post_slack(shift,outputs['x'][:,:2],outputs['x'][:,2:],tdat['pry1_pos'],tdat['pry1_vel'],tdat['pry2_pos'],tdat['pry2_vel'],L1_fit,L2_fit,alpha,A=None,B=None)

    # First two state columns: simulated (solid) vs. re-simulated from the fit (dashed).
    plt.plot(tmp,outputs['x'][:,:2],linewidth=4)
    plt.plot(tmp,pred_out['x'][:,:2],'--',linewidth=4)
    plt.title(f"Repeat {i}: x[:, :2] simulated (solid) vs. fit (dashed)")
    plt.show()

    # Switching signal: simulated (solid) vs. recovered (dashed).
    plt.plot(tmp,outputs['shift'].transpose(),linewidth=4)
    plt.plot(tmp,shift.transpose(),'--',linewidth=4)
    plt.title(f"Repeat {i}: shift simulated (solid) vs. fit (dashed)")
    plt.show()

    corr.append(np.corrcoef(shift.flatten(),outputs['shift'].flatten())[0,1])
    mse.append(np.mean(np.power((shift.flatten()-outputs['shift'].flatten()),2)))
    # Record the gains fitted in THIS repeat. L1_opt/L2_opt are not defined in
    # this cell, so appending them would raise NameError on a fresh kernel or
    # silently record stale values left by another cell.
    L1_out.append(L1_fit)
    L2_out.append(L2_fit)

In [None]:


# Recovery test for the 'pif' slack controller using the outer/inner
# optimisation routine in `mods`: simulate with known gains, fit, then
# compare the recovered switching signal ('shift') with the simulated one.

# Gains used to generate the simulated data.
L1=np.array((1.6, 0.4, 2))
L2=np.array((2, 0.5, 0.1))
outputs=sim.controller_sim_pif_slack(6,tdat['player_pos'],tdat['player_vel'],tdat['pry1_pos'],tdat['pry1_vel'],tdat['pry1_accel'],tdat['pry2_pos'],tdat['pry2_vel'],tdat['pry2_accel'],L1,L2,None,None,gpscaler=5,alpha=10)

# Time axis matching the simulated outputs.
tmp=ut.make_timeline(outputs)

num_rbfs = 30

A,B=ut.define_system_parameters(ctrltype='pif')

inputs=ut.prepare_inputs(A, B, outputs['x'], outputs['uout'], tdat['pry1_pos'], tdat['pry2_pos'], tmp, num_rbfs, outputs['x'][:,2:], tdat['pry1_vel'], tdat['pry2_vel'],tdat['pry1_accel'],tdat['pry2_accel'],ctrltype='pif',usingJax=True,slack_model=True)

loss_function = mods.create_loss_function_pif_slack(mods.generate_rbf_basis,inputs['num_rbfs'])
grad_loss = mods.compute_loss_gradient(loss_function)
(L1_opt, L2_opt,alpha), best_params, best_loss = mods.outer_optimization_slack(inputs, mods.inner_optimization_slack, loss_function, grad_loss,maxiter=5,tolerance=1e-3,opttype='global')

# The optimiser's gains are used as-is (no transform applied in this cell).
L1_fit=L1_opt
L2_fit=L2_opt

weights = best_params[0]
widths = best_params[1]
wtout=ut.generate_sim_switch(inputs, widths, weights,slack_model=True)

# Stack the three recovered switching components row-wise.
shift=np.vstack((wtout[0],wtout[1],wtout[2]))

# Re-simulate from the fit. The bare names player_pos/player_vel/pry*_pos/
# pry*_vel were never defined in this notebook; use the same arguments the
# 'pvi' cell passes to its post-simulator.
# NOTE(review): this calls the 'pv' post-simulator although the data were
# generated and fitted with the 'pif' controller - confirm this is intended.
pred_out=sim.controller_sim_pv_post_slack(shift,outputs['x'][:,:2],outputs['x'][:,2:],tdat['pry1_pos'],tdat['pry1_vel'],tdat['pry2_pos'],tdat['pry2_vel'],L1_fit,L2_fit,alpha,A=None,B=None)

# First two state columns: simulated (solid) vs. re-simulated from the fit (dashed).
plt.plot(tmp,outputs['x'][:,:2],linewidth=4)
plt.plot(tmp,pred_out['x'][:,:2],'--',linewidth=4)
plt.title("x[:, :2] simulated (solid) vs. fit (dashed)")
plt.show()

# Switching signal: simulated (solid) vs. recovered (dashed).
plt.plot(tmp,outputs['shift'].transpose(),linewidth=4)
plt.plot(tmp,shift.transpose(),'--',linewidth=4)
plt.title("shift simulated (solid) vs. fit (dashed)")
plt.show()

# Only a cell's last expression is displayed, so name both metrics and end
# the cell with the pair; previously the correlation was computed and dropped.
shift_corr = np.corrcoef(shift.flatten(),outputs['shift'].flatten())[0,1]
shift_mse = np.mean(np.power((shift.flatten()-outputs['shift'].flatten()),2))
shift_corr, shift_mse

Test loop for pure Adam. TODO: add a loop over trials and over variants of the gains, `gpscaler`, and `alpha`.