In [None]:
import numpy as np
import nidaqmx
from tqdm import tqdm
import torch as T
from mbpo_elite import SAC_Agent

In [None]:
# Hyperparameters for the SAC agent and the training run.
# 'input_dims' is (window length, features per row) and must agree with the
# observation window assembled in the acquisition loop.
config = {
    'alpha': 3e-4,
    'beta': 3e-4,
    'tau': 5e-3,
    'model_lr': 1e-3,
    'input_dims': (100, 2),
    'env_id': 'nested_loop_' + str(0.5).replace(".", "_"),  # -> 'nested_loop_0_5'
    'n_actions': 3,
    'ac_batch_size': 256,
    'model_batch_size': 512,
    'n_models': 10,
    'rollout_len': 1,
    'fake_ratio': 0.6,
    'gamma': 0.99,
    'model_weight_decay': 1e-4,
    'layer1_size': 256,
    'layer2_size': 256,
    'use_model': False,
    'use_bn': True,
    'n_elites': 8,
    'n_games': 6000,
    'model_train_freq': 250,
    'num_grad_updates': 500,
    'n_ac_steps': 20,
    'random_steps': 500,
    # Previously listed twice ('normal', then 'beta'); the later entry always won.
    'act_dist': 'beta',
    'noise_scale': 0.25,
    'eps_len': 1000,
}

# Build the SAC agent from `config`. Every constructor keyword listed in the
# tuple has an identically named config entry; the remaining two keywords are
# supplied explicitly. Other config entries (e.g. 'gamma', 'fake_ratio',
# 'noise_scale') are not passed to the constructor here.
agent = SAC_Agent(
    **{key: config[key] for key in (
        'alpha', 'beta', 'tau', 'input_dims', 'env_id', 'n_actions',
        'ac_batch_size', 'model_batch_size', 'n_models', 'n_elites',
        'rollout_len', 'layer1_size', 'layer2_size', 'use_model', 'use_bn',
        'model_lr', 'act_dist',
    )},
    weight_decay=config['model_weight_decay'],
    no_bad_state=True,
)

In [None]:
# Physical channel names: analog inputs on "PXI1Slot2", analog outputs on
# "PXI1Slot3", each numbered 0 .. n_channels-1.
n_channels = 3
input_channels = [f"PXI1Slot2/ai{idx}" for idx in range(n_channels)]
output_channels = [f"PXI1Slot3/ao{idx}" for idx in range(n_channels)]

In [None]:
def check_range(v, scale=0.0, limit=9.5, reset_value=5.0):
    """Return a copy of `v` with out-of-range entries replaced by `reset_value`.

    An entry is flagged when ``abs(entry) + scale > limit``. Flagged entries are
    replaced by `reset_value` regardless of their sign; all other entries are
    kept. The check is elementwise, so multi-dimensional input is handled
    entry by entry.

    Parameters
    ----------
    v : array_like
        Values to check. The caller's object is never modified.
    scale : float or array_like, default 0.0
        Margin added to ``abs(v)`` before comparing against `limit`; must
        broadcast against `v`.
    limit : float, default 9.5
        Threshold on ``abs(v) + scale``.
    reset_value : float, default 5.0
        Value written into every flagged position.

    Returns
    -------
    numpy.ndarray
        Float array with the same shape as `v`.
    """
    # np.array always copies here, so the input (array or list) is left untouched.
    checked = np.array(v, dtype=float)
    out_of_range = np.abs(checked) + scale > limit
    checked[out_of_range] = reset_value
    return checked

In [None]:
# Online training on hardware. After a warm-up that fills the observation
# window, each step feeds the current window to the agent, writes its action
# to the AO channels, reads the AI channels once, and stores the transition.
# The first N_FEATURES AI values of each reading form one observation row; the
# last AI value is used as the reward.
WINDOW_LEN = 100    # rows per observation; must match config['input_dims'][0]
N_FEATURES = 2      # values per row; must match config['input_dims'][1]
N_ITERATIONS = 1000

with nidaqmx.Task(new_task_name='input_task') as input_task, nidaqmx.Task(new_task_name='output_task') as output_task:
    log = {}
    # Configure the input task
    input_task.ai_channels.add_ai_voltage_chan(",".join(input_channels))

    # Configure the output task
    output_task.ao_channels.add_ao_voltage_chan(",".join(output_channels))

    # Empty (0, N_FEATURES) window so rows can be stacked without a special case.
    state = np.empty((0, N_FEATURES))

    for _ in tqdm(range(N_ITERATIONS), ncols=120):
        if len(state) < WINDOW_LEN:
            # Warm-up: fill the observation window before taking any action.
            input_data = np.array(input_task.read())
            state = np.vstack([state, input_data[:N_FEATURES]])
            continue

        # NOTE(review): assumes choose_action returns data that output_task.write
        # accepts for len(output_channels) channels — confirm its shape.
        action = agent.choose_action(state.reshape(1, *state.shape))
        output_task.write(action, auto_start=True)

        input_data = np.array(input_task.read())
        # Next observation: drop the oldest row, append the post-action reading.
        state_ = np.vstack([state[1:], input_data[:N_FEATURES]])
        reward = input_data[-1]
        # NOTE(review): every transition is stored with done=True; confirm this
        # is intended, since a terminal flag usually disables bootstrapping.
        agent.remember(state, action, reward, state_, True)

        # Advance the window so the next transition starts from this one's
        # next observation, keeping (state, state_) pairs chained.
        state = state_

        if agent.real_mem_ready():
            log['actor loss'], log['critic loss'], log['entropy loss'], log['entropy coeff'] = agent.train_ac()

In [None]:
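# NOTE: legacy finite-difference experiment, kept for reference only. It does not
# run as written: `input_channel` and `learning_rate` are not defined anywhere in
# this notebook, and `v` has 4 elements while `output_channels` lists 3 channels.
# with nidaqmx.Task() as input_task, nidaqmx.Task() as output_task: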
    
#     # Configure the input task
#     input_task.ai_channels.add_ai_voltage_chan(input_channel)

#     # Configure the output task
#     output_task.ao_channels.add_ao_voltage_chan(",".join(output_channels))
#     v = 9.7*np.ones(4)

#     # for _ in tqdm(range(100),ncols=120):
#     while True:
#         delta_v = 1*np.random.rand(4)

#         v = check_range(v)

#         # print('reading..')
#         output_task.write(v,auto_start=True)
#         output_task.wait_until_done()
#         input_signal_v = input_task.read()

#         # print('positive perturbation')
#         pos_v =check_range(v+delta_v)
#         output_task.write(pos_v,auto_start=True)
#         output_task.wait_until_done()
#         input_signal_plus = input_task.read()

#         # print('negative perturbation')
#         neg_v = check_range(v-delta_v)
#         output_task.write(neg_v,auto_start=True)
#         output_task.wait_until_done()
#         input_signal_minus = 10*input_task.read()

#         gradient = (input_signal_plus - input_signal_minus) * delta_v

#         v += learning_rate*gradient
#         print('grad:',gradient, 'v:',v)
