In [1]:
import sys
# custom utilies for displaying animation, collecting rollouts and more
sys.path.append("~/Adventures/CustomA2C")

%load_ext autoreload
%autoreload 2

In [13]:
import matplotlib
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from parallelEnv import parallelEnv
from helpers import sample_trajectories
from a2c_agent import a2c_agent

agent = a2c_agent(8, 4)
envs = parallelEnv('LunarLander-v2', n=4)
device = torch.device("cpu")
print(np.arange(envs.action_space.n))
#states, actions, rewards = sample_trajectories(envs, agent, tmax=200)
#print(rewards)

started processes:  [2883, 2884, 2885, 2886]
[0 1 2 3]


## Training

In [24]:
print(agent.named_parameters)


## training params ##
ROLLOUT_LENGTH = 300
NUM_EPOCHS = 60
DISCOUNT = 0.99
BETA = 0.01 #entropy term
CRITIC_LOSS_COEFF = 1.0 #critic term
GRADIENT_CLIP = 0.5
LEARNING_RATE = 0.0003
optimizer = torch.optim.RMSprop(agent.parameters(),lr=LEARNING_RATE)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer,
                                        lr_lambda=lambda epoch: 0.98 ** epoch,
                                        last_epoch=-1,
                                        verbose=False)


## LOGGING SETUP ##s
import logging
import os
#logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s: %(message)s')
logger = logging.getLogger() 
lognum = 1
#logger.setLevel(logging.INFO)
fh = logging.FileHandler('./log/%s-%s.txt' % ('train', lognum))
fh.setFormatter(logging.Formatter('%(asctime)s - %(name)s: %(message)s'))
fh.setLevel(logging.INFO)
logger.addHandler(fh)


def step():
    envs.reset()
    #take initial random steps
    for i in range(5):
        s, r, done, _ = envs.step(np.random.choice(np.arange(envs.action_space.n), envs.num_envs))
        s, r, done, _ = envs.step([0]* envs.num_envs) #wait a frame
    
    for ep in range(NUM_EPOCHS):
        storage = {"states": [], "values": [], "log_probs": [],
                    "actions":[], "entropy":[], "rewards":[], "dones":[]}
        total_rewards = 0.0
        
        for t in range(ROLLOUT_LENGTH + 1): #one more step for ending!
            # step throguh agent
            s_tensor = torch.from_numpy(s).float().to(device)
            preds = agent.forward(s_tensor)
            actions = preds["actions"].detach().numpy()

            storage["states"].append(s_tensor)
            storage["values"].append(preds["values"])
            storage["log_probs"].append(preds["log_probs"])
            storage["actions"].append(preds["actions"])
            storage["entropy"].append(preds["entropy"])

            # step through env
            s, r, done, _ = envs.step(actions)
            storage["rewards"].append(torch.tensor(r))
            storage["dones"].append(torch.tensor(1-done))
            total_rewards += r.mean()
        
        total_rewards /= ROLLOUT_LENGTH

        storage["advantages"] = [torch.zeros((envs.num_envs, 1)).to(device)]*(ROLLOUT_LENGTH + 1)
        for i in reversed(range(ROLLOUT_LENGTH)):
            td_error = storage["rewards"][i] + DISCOUNT * storage["values"][i+1] - storage["values"][i]
            td_error.detach()
            storage["advantages"][i] = td_error
        
        # collect
        log_probs = torch.stack(storage["log_probs"][:-1])
        advantages = torch.stack(storage["advantages"][:-1])
        entropy = torch.stack(storage["entropy"][:-1])
        
        # calculate losses
        actor_loss = -(log_probs * advantages).mean() #this part is ascent!
        critic_loss = 0.5 * (advantages).pow(2).mean() #this is descent
        entropy_loss = entropy.mean()
        L = actor_loss + CRITIC_LOSS_COEFF * critic_loss - BETA * entropy_loss

        # backprop
        optimizer.zero_grad()
        L.backward()
        nn.utils.clip_grad_norm_(agent.parameters(), GRADIENT_CLIP)
        optimizer.step()
        scheduler.step()
        print("loss:", L)
        logger.info("Episode: {0:d}, loss: {1:f}, rewards: {2:f}, lr: {3:f}".format(ep+1, L, total_rewards, optimizer.param_groups[0]['lr']))
        if (ep+1)%20 ==0 :
            print("Episode: {0:d}, loss: {1:f}, rewards: {2:f}, lr: {3:f}".format(ep+1, L, total_rewards, optimizer.param_groups[0]['lr']))



<bound method Module.named_parameters of a2c_agent(
  (shared_network): SharedNetwork(
    (layers): ModuleList(
      (0): Linear(in_features=8, out_features=64, bias=True)
      (1): Linear(in_features=64, out_features=64, bias=True)
    )
  )
  (critic_network): CriticNetwork(
    (layers): ModuleList(
      (0): Linear(in_features=64, out_features=32, bias=True)
      (1): Linear(in_features=32, out_features=1, bias=True)
    )
  )
  (actor_network): ActorNetwork(
    (layers): ModuleList(
      (0): Linear(in_features=64, out_features=32, bias=True)
      (1): Linear(in_features=32, out_features=4, bias=True)
    )
  )
)>


In [25]:
step()
agent.save('pretraied-model-1')

2023-07-10 19:27:41,010 - root - INFO: Episode: 1, loss: 16.953126, rewards: -1.045049, lr: 0.000294


loss: tensor(16.9531, dtype=torch.float64, grad_fn=<SubBackward0>)


2023-07-10 19:27:41,531 - root - INFO: Episode: 2, loss: 16.028826, rewards: -0.112916, lr: 0.000288


loss: tensor(16.0288, dtype=torch.float64, grad_fn=<SubBackward0>)


2023-07-10 19:27:41,938 - root - INFO: Episode: 3, loss: 20.024970, rewards: -0.579440, lr: 0.000282


loss: tensor(20.0250, dtype=torch.float64, grad_fn=<SubBackward0>)


2023-07-10 19:27:42,340 - root - INFO: Episode: 4, loss: 16.308397, rewards: -0.510872, lr: 0.000277


loss: tensor(16.3084, dtype=torch.float64, grad_fn=<SubBackward0>)


2023-07-10 19:27:42,755 - root - INFO: Episode: 5, loss: 16.505023, rewards: -0.584254, lr: 0.000271


loss: tensor(16.5050, dtype=torch.float64, grad_fn=<SubBackward0>)


2023-07-10 19:27:43,197 - root - INFO: Episode: 6, loss: 9.881204, rewards: -0.208693, lr: 0.000266


loss: tensor(9.8812, dtype=torch.float64, grad_fn=<SubBackward0>)


2023-07-10 19:27:43,768 - root - INFO: Episode: 7, loss: 13.725723, rewards: -0.555545, lr: 0.000260


loss: tensor(13.7257, dtype=torch.float64, grad_fn=<SubBackward0>)


2023-07-10 19:27:44,433 - root - INFO: Episode: 8, loss: 12.393806, rewards: -0.641312, lr: 0.000255


loss: tensor(12.3938, dtype=torch.float64, grad_fn=<SubBackward0>)


2023-07-10 19:27:45,146 - root - INFO: Episode: 9, loss: 12.258665, rewards: -0.234679, lr: 0.000250


loss: tensor(12.2587, dtype=torch.float64, grad_fn=<SubBackward0>)


2023-07-10 19:27:45,950 - root - INFO: Episode: 10, loss: 15.046856, rewards: -0.583693, lr: 0.000245


loss: tensor(15.0469, dtype=torch.float64, grad_fn=<SubBackward0>)


2023-07-10 19:27:46,791 - root - INFO: Episode: 11, loss: 20.694020, rewards: -0.563769, lr: 0.000240


loss: tensor(20.6940, dtype=torch.float64, grad_fn=<SubBackward0>)


2023-07-10 19:27:47,620 - root - INFO: Episode: 12, loss: 16.319870, rewards: -0.738675, lr: 0.000235


loss: tensor(16.3199, dtype=torch.float64, grad_fn=<SubBackward0>)


2023-07-10 19:27:48,587 - root - INFO: Episode: 13, loss: 14.637452, rewards: -0.029526, lr: 0.000231


loss: tensor(14.6375, dtype=torch.float64, grad_fn=<SubBackward0>)


2023-07-10 19:27:49,616 - root - INFO: Episode: 14, loss: 21.707489, rewards: -0.748607, lr: 0.000226


loss: tensor(21.7075, dtype=torch.float64, grad_fn=<SubBackward0>)


2023-07-10 19:27:50,700 - root - INFO: Episode: 15, loss: 17.269249, rewards: -0.482460, lr: 0.000222


loss: tensor(17.2692, dtype=torch.float64, grad_fn=<SubBackward0>)


2023-07-10 19:27:51,852 - root - INFO: Episode: 16, loss: 18.407103, rewards: -0.485883, lr: 0.000217


loss: tensor(18.4071, dtype=torch.float64, grad_fn=<SubBackward0>)


2023-07-10 19:27:52,758 - root - INFO: Episode: 17, loss: 22.382445, rewards: -0.630581, lr: 0.000213


loss: tensor(22.3824, dtype=torch.float64, grad_fn=<SubBackward0>)


2023-07-10 19:27:53,226 - root - INFO: Episode: 18, loss: 14.905860, rewards: -0.269377, lr: 0.000209


loss: tensor(14.9059, dtype=torch.float64, grad_fn=<SubBackward0>)


2023-07-10 19:27:53,918 - root - INFO: Episode: 19, loss: 13.459603, rewards: -0.195161, lr: 0.000204


loss: tensor(13.4596, dtype=torch.float64, grad_fn=<SubBackward0>)


2023-07-10 19:27:54,891 - root - INFO: Episode: 20, loss: 9.214065, rewards: -0.178008, lr: 0.000200


loss: tensor(9.2141, dtype=torch.float64, grad_fn=<SubBackward0>)
Episode: 20, loss: 9.214065, rewards: -0.178008, lr: 0.000200


2023-07-10 19:27:56,128 - root - INFO: Episode: 21, loss: 12.049984, rewards: -0.473490, lr: 0.000196


loss: tensor(12.0500, dtype=torch.float64, grad_fn=<SubBackward0>)


2023-07-10 19:27:57,221 - root - INFO: Episode: 22, loss: 12.629518, rewards: -0.656392, lr: 0.000192


loss: tensor(12.6295, dtype=torch.float64, grad_fn=<SubBackward0>)


2023-07-10 19:27:57,639 - root - INFO: Episode: 23, loss: 17.461531, rewards: -0.427632, lr: 0.000189


loss: tensor(17.4615, dtype=torch.float64, grad_fn=<SubBackward0>)


2023-07-10 19:27:58,068 - root - INFO: Episode: 24, loss: 10.871183, rewards: -0.381474, lr: 0.000185


loss: tensor(10.8712, dtype=torch.float64, grad_fn=<SubBackward0>)


2023-07-10 19:27:58,563 - root - INFO: Episode: 25, loss: 11.996970, rewards: -0.538415, lr: 0.000181


loss: tensor(11.9970, dtype=torch.float64, grad_fn=<SubBackward0>)


2023-07-10 19:27:59,026 - root - INFO: Episode: 26, loss: 21.070440, rewards: -0.528004, lr: 0.000177


loss: tensor(21.0704, dtype=torch.float64, grad_fn=<SubBackward0>)


2023-07-10 19:27:59,516 - root - INFO: Episode: 27, loss: 9.909063, rewards: -0.196016, lr: 0.000174


loss: tensor(9.9091, dtype=torch.float64, grad_fn=<SubBackward0>)


2023-07-10 19:28:00,162 - root - INFO: Episode: 28, loss: 12.198311, rewards: -0.225826, lr: 0.000170


loss: tensor(12.1983, dtype=torch.float64, grad_fn=<SubBackward0>)


2023-07-10 19:28:00,885 - root - INFO: Episode: 29, loss: 23.428896, rewards: -0.151679, lr: 0.000167


loss: tensor(23.4289, dtype=torch.float64, grad_fn=<SubBackward0>)


2023-07-10 19:28:01,682 - root - INFO: Episode: 30, loss: 15.593479, rewards: -0.456903, lr: 0.000164


loss: tensor(15.5935, dtype=torch.float64, grad_fn=<SubBackward0>)


2023-07-10 19:28:02,519 - root - INFO: Episode: 31, loss: 11.528826, rewards: -0.455597, lr: 0.000160


loss: tensor(11.5288, dtype=torch.float64, grad_fn=<SubBackward0>)


2023-07-10 19:28:03,483 - root - INFO: Episode: 32, loss: 9.112686, rewards: -0.121457, lr: 0.000157


loss: tensor(9.1127, dtype=torch.float64, grad_fn=<SubBackward0>)


2023-07-10 19:28:04,513 - root - INFO: Episode: 33, loss: 13.427043, rewards: -0.161151, lr: 0.000154


loss: tensor(13.4270, dtype=torch.float64, grad_fn=<SubBackward0>)


2023-07-10 19:28:05,632 - root - INFO: Episode: 34, loss: 11.354318, rewards: -0.285097, lr: 0.000151


loss: tensor(11.3543, dtype=torch.float64, grad_fn=<SubBackward0>)


2023-07-10 19:28:06,769 - root - INFO: Episode: 35, loss: 7.327786, rewards: -0.318638, lr: 0.000148


loss: tensor(7.3278, dtype=torch.float64, grad_fn=<SubBackward0>)


2023-07-10 19:28:07,936 - root - INFO: Episode: 36, loss: 12.348327, rewards: -0.113983, lr: 0.000145


loss: tensor(12.3483, dtype=torch.float64, grad_fn=<SubBackward0>)


2023-07-10 19:28:08,955 - root - INFO: Episode: 37, loss: 16.136877, rewards: -0.490043, lr: 0.000142


loss: tensor(16.1369, dtype=torch.float64, grad_fn=<SubBackward0>)


2023-07-10 19:28:09,380 - root - INFO: Episode: 38, loss: 13.590434, rewards: -0.553626, lr: 0.000139


loss: tensor(13.5904, dtype=torch.float64, grad_fn=<SubBackward0>)


2023-07-10 19:28:09,787 - root - INFO: Episode: 39, loss: 14.076657, rewards: -0.646150, lr: 0.000136


loss: tensor(14.0767, dtype=torch.float64, grad_fn=<SubBackward0>)


2023-07-10 19:28:10,216 - root - INFO: Episode: 40, loss: 16.925334, rewards: -0.935697, lr: 0.000134


loss: tensor(16.9253, dtype=torch.float64, grad_fn=<SubBackward0>)
Episode: 40, loss: 16.925334, rewards: -0.935697, lr: 0.000134


2023-07-10 19:28:10,614 - root - INFO: Episode: 41, loss: 17.840535, rewards: -0.541776, lr: 0.000131


loss: tensor(17.8405, dtype=torch.float64, grad_fn=<SubBackward0>)


2023-07-10 19:28:11,053 - root - INFO: Episode: 42, loss: 12.850689, rewards: -0.493318, lr: 0.000128


loss: tensor(12.8507, dtype=torch.float64, grad_fn=<SubBackward0>)


2023-07-10 19:28:11,448 - root - INFO: Episode: 43, loss: 15.127500, rewards: -0.504334, lr: 0.000126


loss: tensor(15.1275, dtype=torch.float64, grad_fn=<SubBackward0>)


2023-07-10 19:28:11,892 - root - INFO: Episode: 44, loss: 18.601999, rewards: -0.499382, lr: 0.000123


loss: tensor(18.6020, dtype=torch.float64, grad_fn=<SubBackward0>)


2023-07-10 19:28:12,318 - root - INFO: Episode: 45, loss: 20.928064, rewards: -0.562150, lr: 0.000121


loss: tensor(20.9281, dtype=torch.float64, grad_fn=<SubBackward0>)


2023-07-10 19:28:12,741 - root - INFO: Episode: 46, loss: 11.301533, rewards: -0.287782, lr: 0.000118


loss: tensor(11.3015, dtype=torch.float64, grad_fn=<SubBackward0>)


2023-07-10 19:28:13,209 - root - INFO: Episode: 47, loss: 12.691306, rewards: -0.246004, lr: 0.000116


loss: tensor(12.6913, dtype=torch.float64, grad_fn=<SubBackward0>)


2023-07-10 19:28:13,615 - root - INFO: Episode: 48, loss: 23.494307, rewards: -0.757366, lr: 0.000114


loss: tensor(23.4943, dtype=torch.float64, grad_fn=<SubBackward0>)


2023-07-10 19:28:14,043 - root - INFO: Episode: 49, loss: 11.618883, rewards: -0.464588, lr: 0.000111


loss: tensor(11.6189, dtype=torch.float64, grad_fn=<SubBackward0>)


2023-07-10 19:28:14,489 - root - INFO: Episode: 50, loss: 16.895260, rewards: 0.176222, lr: 0.000109


loss: tensor(16.8953, dtype=torch.float64, grad_fn=<SubBackward0>)


2023-07-10 19:28:15,039 - root - INFO: Episode: 51, loss: 13.425269, rewards: -0.026693, lr: 0.000107


loss: tensor(13.4253, dtype=torch.float64, grad_fn=<SubBackward0>)


2023-07-10 19:28:15,738 - root - INFO: Episode: 52, loss: 10.630493, rewards: -0.152219, lr: 0.000105


loss: tensor(10.6305, dtype=torch.float64, grad_fn=<SubBackward0>)


2023-07-10 19:28:16,569 - root - INFO: Episode: 53, loss: 12.093577, rewards: -0.336201, lr: 0.000103


loss: tensor(12.0936, dtype=torch.float64, grad_fn=<SubBackward0>)


2023-07-10 19:28:17,491 - root - INFO: Episode: 54, loss: 10.519511, rewards: -0.088571, lr: 0.000101


loss: tensor(10.5195, dtype=torch.float64, grad_fn=<SubBackward0>)


2023-07-10 19:28:18,502 - root - INFO: Episode: 55, loss: 10.984829, rewards: -0.387403, lr: 0.000099


loss: tensor(10.9848, dtype=torch.float64, grad_fn=<SubBackward0>)


2023-07-10 19:28:19,605 - root - INFO: Episode: 56, loss: 11.844933, rewards: -0.315621, lr: 0.000097


loss: tensor(11.8449, dtype=torch.float64, grad_fn=<SubBackward0>)


2023-07-10 19:28:20,797 - root - INFO: Episode: 57, loss: 10.029444, rewards: -0.357690, lr: 0.000095


loss: tensor(10.0294, dtype=torch.float64, grad_fn=<SubBackward0>)


2023-07-10 19:28:22,146 - root - INFO: Episode: 58, loss: 10.591726, rewards: -0.145302, lr: 0.000093


loss: tensor(10.5917, dtype=torch.float64, grad_fn=<SubBackward0>)


2023-07-10 19:28:23,549 - root - INFO: Episode: 59, loss: 14.020878, rewards: -0.334791, lr: 0.000091


loss: tensor(14.0209, dtype=torch.float64, grad_fn=<SubBackward0>)


2023-07-10 19:28:24,747 - root - INFO: Episode: 60, loss: 12.185579, rewards: -0.394047, lr: 0.000089


loss: tensor(12.1856, dtype=torch.float64, grad_fn=<SubBackward0>)
Episode: 60, loss: 12.185579, rewards: -0.394047, lr: 0.000089


In [30]:
import gymnasium
env = gymnasium.make('LunarLander-v2',render_mode="human")
print('State shape: ', env.observation_space.shape)
print('Number of actions: ', env.action_space.n)
env.reset()

agent.load('pretraied-model-1')

#take initial random steps
for i in range(5):
    s, r, done, _, info = env.step(np.random.choice(np.arange(envs.action_space.n)))
    s, r, done, _, info = env.step(0) #wait a frame

agent.eval()
with torch.no_grad():
    for t in range(350):
        # convert observations to torch
        s = np.asarray(s)
        s_tensor = torch.from_numpy(s).float().to(device)
        preds = agent.forward(s_tensor)
        actions = preds["actions"].detach().numpy()

        # step env
        s, r, done,  _, info= env.step(actions)
        if(done): break

agent.train()
env.close()


State shape:  (8,)
Number of actions:  4
