# [  ]

In [None]:
import fatbot as fb
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import os

# Setup

### RL Algorithm

In [None]:
# Algorithm class used to build the model in the training cell and to reload it
# in the testing cells; the constructor keyword arguments used there are
# specific to this class.
sbalgo = fb.PPO

# Folder that holds everything saved for this model (model_path is built inside
# it). Created here if missing; an already existing folder is left as is.
model_name = 'ppo_model'
os.makedirs(model_name, exist_ok=True)

### Reward Scheme

In [None]:
# Reward configuration handed to db.envF when the training and testing
# environments are built. Presumably each value is the weight of the named
# reward term — confirm against the environment implementation.
reward_scheme = {
    'dis_target_point':   1.0,
    # 'dis_neighbour':    1.0,   # currently disabled
    'dis_target_radius':  1.0,
    'all_unsafe':         2.0,
    'all_neighbour':      1.0,
    'occluded_neighbour': 2.0,
}

# Passed to db.envF right after reward_scheme.
# NOTE(review): the name suggests rewards are reported as step-to-step
# differences — confirm.
delta_reward = True

### Hyperparams

In [None]:
# Discount factor passed to the algorithm constructor.
gamma = 0.99

# Horizon passed to db.envF when the training env is built.
horizon = 300

# Number of timesteps requested from model.learn().
total_timesteps = 100_000

# The model is saved to / loaded from <model_name>/<model_version>.
model_version = 'baseZ'
model_path = os.path.join(model_name, model_version)

# Learning-rate scheduling.
start_lr = 0.00050
end_lr = 0.000040

# Mapper from the input range (-0.2, 1) to the output range (start_lr, end_lr).
# NOTE(review): lr_schedule feeds it `1 - progress`; if progress stays within
# [0, 1] the inputs never go below 0, so (assuming REMAP interpolates linearly)
# the first learning rate would lie between start_lr and end_lr rather than at
# start_lr — confirm the -0.2 lower bound is intentional.
lr_mapper = fb.REMAP((-0.2, 1), (start_lr, end_lr))


def lr_schedule(progress):
    """Return the learning rate for the given `progress` value.

    This function is passed as `learning_rate` to the algorithm constructor.
    The mapper is queried with `1 - progress`.

    NOTE(review): presumably `progress` is the *remaining* training progress
    (1 at the start, 0 at the end), as in stable-baselines3 schedules —
    confirm.
    """
    mapper_input = 1 - progress
    return lr_mapper.in2map(mapper_input)

# Training

### prepare

In [None]:
# initial state distribution - uniformly sample from all listed states
initial_state_keys =  db.all_states() # [db.isd[db.isd_keys[0]]] #[v for k,v in db.isd.items()] 
permute_states =        True
print(f'Total Initial States: {len(initial_state_keys)}')

# build training env
training_env = db.envF(False, horizon, reward_scheme, delta_reward, 
                        permute_states, *initial_state_keys)

#<---- optinally check
fb.check_env(training_env) 

In [None]:
# Reset the training env built in the previous cell. As the last expression of
# the cell, whatever reset() returns is shown as the cell output.
training_env.reset()

### perform training

In [None]:
# start training
training_start_time = fb.common.now()
print(f'Training @ [{model_path}]')
model = sbalgo(policy=      'MlpPolicy', 
        env=                training_env, 
        learning_rate =     lr_schedule,
        n_steps=            1024*10,
        batch_size =        1024,
        n_epochs =          20,
        gamma =             gamma,
        gae_lambda=         0.95,
        clip_range=         0.20, 
        clip_range_vf=      None, 
        normalize_advantage=True, 
        ent_coef=           0.0, 
        vf_coef=            0.5, 
        max_grad_norm=      0.5, 
        use_sde=            False, 
        sde_sample_freq=    -1, 
        target_kl=          None, 
        tensorboard_log=    None, 
        create_eval_env=    False, 
        verbose=            0, 
        seed=               None, 
        device=             'cpu', 
        _init_setup_model=  True,
        policy_kwargs=dict(
                        activation_fn=  nn.LeakyReLU, 
                        net_arch=[dict(
                            pi=[512, 512, 300], 
                            vf=[512, 512, 300])])) #256, 256, 256, 128, 128

model.learn(total_timesteps=total_timesteps,log_interval=int(0.1*total_timesteps))
model.save(model_path)
training_end_time = fb.common.now()
print(f'Finished!, Time-Elapsed:[{training_end_time-training_start_time}]')

# Testing

In [None]:
# Rebind `model` to a copy loaded from model_path, the same path the training
# cell saves to. The trailing tuple is displayed as the cell output.
model = sbalgo.load(model_path)
model, model_path

### prepare

In [None]:
# initial state distribution - uniformly sample from all listed states
initial_state_keys =    db.all_states() # [db.isd[db.isd_keys[0]]]  #[v for k,v in db.isd.items()] 
permute_states =        False
nos_initial_states=len(initial_state_keys)
print(f'Total Initial States: {nos_initial_states}')

horizon =               500
# build testing_env
testing_env = db.envF(True, horizon, reward_scheme, delta_reward, 
                        permute_states, *initial_state_keys)

# save initial state
#testing_env.reset()
#testing_env.save_state(f'{model_path}_initial.npy')
#fig=testing_env.render() # do a default render
#fig.savefig(f'{model_path}_initial.png')
#del fig

#<---- optinally check
#fb.check_env(training_env) 

### perform testing

In [None]:
print(f'Testing @ [{model_path}]')

# Model under test, loaded from disk. Per the author's note, passing
# model=None to fb.TEST runs a random policy instead.
trained_model = sbalgo.load(model_path)

# Options forwarded to fb.TEST; the inline notes are the author's own.
test_options = dict(
    # NOTE(review): starting_state maps the episode number to an initial-state
    # index, so 11 episodes presumably needs at least 11 initial states;
    # nos_initial_states may be the intended value here — confirm.
    episodes=11,
    steps=0,
    deterministic=True,
    render_as=None,          # None: no plots; '' (empty string): plot inline
    save_dpi='figure',
    make_video=False,
    video_fps=2,
    render_kwargs=dict(local_sensors=True, reward_signal=True),
    starting_state=lambda ep: ep,   # None, or a function: episode -> initial_state_index (int)
    plot_results=0,
    start_n=0,               # used for naming the rendered pngs
    # NOTE(review): the original note on this line read "call plt.show() if
    # true", which looks like it belongs to plot_results — confirm.
    save_state_info=model_path,
    save_both_states=True,
)

average_return, total_steps = fb.TEST(
    env=testing_env,
    model=trained_model,
    **test_options,
)
print(f'{average_return=}, {total_steps=}')


