# PPO
---

### 1. Import the Necessary Packages

In [1]:
import torch
%matplotlib inline

from model.ppo_parallel import PPO
from model.network import ActorCritic
from model.environments import LunarContinuous
from logger import WandbSummaryWritter 

### 2. Instantiate the Model

Set up the hyperparameters in the code cell below.

In [2]:
# Settings unpacked into the PPO constructor; the same dict is handed to the
# logger as the run config.
hyperparameters = {
    'gamma': 0.999,                     # Discount factor applied when calculating rewards-to-go
    'lr_gamma': 0.995,                  # presumably the learning-rate decay factor — TODO confirm in PPO
    'max_timesteps_per_episode': 1600,  # Max number of timesteps per episode
    'clip_range': 0.2,                  # presumably the ratio-clipping threshold (0.2 is the usual choice) — confirm
    'lr': 0.005,                        # Starting learning rate
}

# Run-level settings, unpacked into the PPO constructor next to the hyperparameters.
misc_hyperparameters = {
    'num_workers': 2,   # presumably the number of parallel rollout workers — confirm in ppo_parallel
    'seed': None,       # no fixed seed is supplied
}

Initialise wandb session in the code cell below.

In [3]:
# Create the logger for the 'lunar' project, passing the hyperparameters as its config.
logger = WandbSummaryWritter(
    project='lunar',
    config=hyperparameters,
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mpmsaraiva2712[0m ([33mpmsaraiva2712-tum[0m). Use [1m`wandb login --relogin`[0m to force relogin


Initialise the model for the desired timesteps. Alternatively, you can specify a checkpoint to continue training from.

In [4]:
checkpoint = 'ppo_parallel_checkpoints/charmed-armadillo-108/ppo_policy_960.pth'
LOAD_MODEL = False  # set to True to restore the weights stored at `checkpoint`

# Build the PPO trainer from the logger, the hyperparameters and the run settings.
ppo = PPO(logger, **hyperparameters, **misc_hyperparameters)

if LOAD_MODEL:
    # The environment is instantiated here only to read the observation and
    # action dimensions needed to construct the network.
    env = LunarContinuous().make_environment()
    model = ActorCritic(env.observation_space.shape[0], env.action_space.shape[0])

    # map_location='cpu' lets a checkpoint that was saved from GPU tensors be
    # read on a CPU-only machine; load_state_dict then copies the values into
    # the model's own parameters.
    state_dict = torch.load(checkpoint, map_location='cpu')
    model.load_state_dict(state_dict)

    # NOTE(review): in the cells visible here `model` is never handed to `ppo`,
    # so training still starts from the network PPO built for itself. Attach
    # the restored weights through whatever hook PPO exposes — TODO confirm
    # against model/ppo_parallel.py.



### 3. Train the Model

Train the model for the specified number of timesteps.

In [5]:
# Total number of timesteps to train for.
total_timesteps_to_train = 1_000_000

# Training prints a summary block per iteration (see the output below).
ppo.train(total_timesteps_to_train)


-------------------- Iteration #1 --------------------
Average Episodic Length: 110.8
Average Episodic Return: -266.33
Average Loss: -1e-04
Timesteps So Far: 4875
Iteration took: 6.91 secs
Current learning rate: 0.004876243765609375
------------------------------------------------------


-------------------- Iteration #2 --------------------
Average Episodic Length: 109.55
Average Episodic Return: -295.6
Average Loss: -0.00114
Timesteps So Far: 9695
Iteration took: 6.37 secs
Current learning rate: 0.004755550652328859
------------------------------------------------------


-------------------- Iteration #3 --------------------
Average Episodic Length: 113.63
Average Episodic Return: -220.16
Average Loss: -0.0016
Timesteps So Far: 14581
Iteration took: 6.41 secs
Current learning rate: 0.00463784484409164
------------------------------------------------------


-------------------- Iteration #4 --------------------
Average Episodic Length: 103.19
Average Episodic Return: -162.9
Averag

KeyboardInterrupt: 

### 4. Evaluate the Model

Run multiple episodes using the pretrained model.

In [None]:
# Evaluate the PPO instance created above. The recorded run of this cell was
# stopped by hand (KeyboardInterrupt); presumably it keeps running episodes
# until interrupted — TODO confirm in PPO.test.
ppo.test()

KeyboardInterrupt: 