In [1]:
import json
import os

import numpy as np
from stable_baselines3 import PPO, SAC, DDPG

from environment.utils import generate_w0_with_locus
from aDBS_RL.evaluate_HF_DBS import make_env, evaluate_hf_dbs

# We import env version 
from environment.env_configs.env1 import (
    n_neurons, grid_size, coord_modif,
    eval_envs_list, checking
)
# Seed NumPy's legacy global RNG so that np.random draws are repeatable between runs.
# NOTE(review): presumably the random weight generation further down relies on
# np.random; other RNGs (e.g. torch, Python's `random`) are not seeded here — confirm.
np.random.seed(228)

### 1. Load aDBS agent

In [None]:
# Directory holding the trained agent checkpoints.
# The default is the original JupyterHub location; set the ADBS_CHECKPOINT_DIR
# environment variable to run the notebook on another machine without editing
# this cell.
DIR = os.environ.get(
    'ADBS_CHECKPOINT_DIR',
    '/home/jovyan/ekuzmina/DBS-Gym/adbs_checkpoints',
)
# NOTE(review): the checkpoint name says "env0" while the notebook imports the
# env1 configuration — confirm this cross-environment evaluation is intended.
model_name = 'env0_R1_ppo_1000000_steps.zip'
model_path = os.path.join(DIR, model_name)

# Load the model
model = PPO.load(model_path)

### 2. Set testing environments

In [7]:
NUMBER_OF_ENV = 2
each_env_run_episodes = 2

# Build one evaluation environment for each of the first NUMBER_OF_ENV
# entries of the imported `eval_envs_list` configuration.
eval_envs_list_new = []
for n_env in range(NUMBER_OF_ENV):
    # The config dict is taken by reference and extended in place below.
    eval_d = eval_envs_list[n_env]

    # Generate the weights / coordinates for this environment's locus settings.
    locus_outputs = generate_w0_with_locus(
        n_neurons, grid_size, coord_modif,
        locus_center=eval_d['locus_center'],
        locus_size=eval_d['locus_size'],
        wmuL=17, wsdL=1,
        show=False, vertical_layer=4,
    )
    (w0_eval, ncoords, ngrid,
     w0_temp_eval, w_locus_eval, lmask_eval) = locus_outputs

    # Attach the generated arrays and evaluation settings to the config.
    eval_d.update({
        'reward_func': 'bbpow_action',
        'neur_coords': ncoords,
        'neur_grid': ngrid,
        'w0': w0_eval,
        'w0_without_locus': w0_temp_eval,
        'locus_without_w0': w_locus_eval,
        'locus_mask': lmask_eval,
        'dbs_action_bounds': [-5, 5],    # NOTE: IMPORTANT!!!!
    })

    eval_envs_list_new.append(make_env(eval_d))

# Alias under which the later cells access the environments.
envs_cpu = eval_envs_list_new

No temporal drift events!
DBS affects 512 neurons, min=0.307 & max=1.0
No temporal drift events!
DBS affects 512 neurons, min=0.307 & max=1.0


In [8]:
# Summarise the evaluation setup; settings are read from the first environment.
banner = '^^' * 30
params = envs_cpu[0].params_dict

print(banner)
print('THE ENVIROMENT IS:', checking)
print('NUMBER OF ENVS:', NUMBER_OF_ENV)
print('Each env with run for: ', each_env_run_episodes, ' times')

# print('MODEL PREDICT: ', model.predict(np.ones(5)))
# (label, params_dict key) pairs, printed in this order.
for label, key in (
    ('SCALE TO INSIDE env', 'dbs_action_bounds'),
    ('Episode len:', 'total_episode_len'),
    ('Temporal drift:', 'temporal_drift'),
    ('Spatial features:', 'spatial_feature'),
):
    print(label, params[key])
print(banner)

# Upper action bound, shown in the energy rescaling formula below.
r = params['dbs_action_bounds'][1]
print('Energy is = ', f'e * {r} / {each_env_run_episodes}')

^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
THE ENVIROMENT IS: env1
NUMBER OF ENVS: 2
Each env with run for:  2  times
SCALE TO INSIDE env [-5, 5]
Episode len: 1000
Temporal drift: False
Spatial features: False
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Energy is =  e * 5 / 2


### 3. Run evaluation 

In [9]:
# Evaluate the loaded agent on all test environments.
bbpow_mean, bbpow_sd, e_mean, e_sd = evaluate_hf_dbs(
    model, envs_cpu,
    # NOTE: IMPORTANT: EACH environment run this num of episodes
    n_eval_episodes=each_env_run_episodes,
    render=False, deterministic=True,
    warn=False, callbacks_=None)

# Rescale the reported energy by the upper action bound and the number of
# episodes per environment (same formula as printed in the info cell).
upper_action_bound = envs_cpu[0].params_dict['dbs_action_bounds'][1]
true_energy = e_mean * upper_action_bound / each_env_run_episodes
true_energy_sd = e_sd * upper_action_bound / each_env_run_episodes
print('Energy is = ', true_energy, ' sd= ', true_energy_sd)

# Save all to file: one JSON object per line, appended to previous runs.
# Values are cast to built-in floats because NumPy scalars are not JSON
# serialisable, and their repr is not valid JSON either.
res = {'env': checking,
       'agent': 'ppo_R1',
       'bbpow mean': float(bbpow_mean),
       'bbpow sd': float(bbpow_sd),
       'energy mean': float(true_energy),
       'energy sd': float(true_energy_sd),
       }
# Make sure the output directory exists so a fresh checkout does not fail here.
os.makedirs('data', exist_ok=True)
# The context manager closes the file; no explicit close() is needed.
with open('data/eval_results.json', 'a') as f:
    f.write(json.dumps(res) + '\n')

DBS affects 512 neurons, min=0.307 & max=1.0
DBS affects 512 neurons, min=0.307 & max=1.0
DBS affects 512 neurons, min=0.307 & max=1.0
DBS affects 512 neurons, min=0.307 & max=1.0
DBS affects 512 neurons, min=0.307 & max=1.0
DBS affects 512 neurons, min=0.307 & max=1.0
Reward mean=-16882.938537359238, std=2161.241585925245
BBpow mean=0.008164462616245996, std=0.001090899773646822
Energy mean=1.7402392625808716, std=0.08317410945892334
Energy is =  4.3505983  sd=  0.20793527
