In [None]:
import sys
# custom utilies for displaying animation, collecting rollouts and more
sys.path.append("~/Adventures/CustomA2C")
sys.path.append("~/Adventures/")

%load_ext autoreload
%autoreload 2

In [None]:
import matplotlib
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from parallelEnv import parallelEnv
from helpers import sample_trajectories
from a2c_agent import a2c_agent

# All tensors in this notebook are placed on the CPU.
device = torch.device("cpu")

# Agent constructed with (8, 4) and four seeded parallel copies of LunarLander-v2.
# NOTE(review): (8, 4) presumably are the observation and action sizes -- confirm in a2c_agent.
agent = a2c_agent(8, 4)
envs = parallelEnv('LunarLander-v2', n=4, seed=1234)

# Show the discrete action ids, then collect a short rollout as a smoke test.
print(np.arange(envs.action_space.n))
states, actions, rewards = sample_trajectories(envs, agent, tmax=200)
print(device)

## Training

In [None]:
# NOTE(review): this prints the repr of the bound method (not the parameters
# themselves); confirm whether `print(agent)` or a call was intended.
print(agent.named_parameters)


## training params ##
ROLLOUT_LENGTH = 1000        # environment steps collected per update in learn()
NUM_EPOCHS = 60              # number of rollout/update iterations in learn()
DISCOUNT = 0.99              # reward discount factor
BETA = 0.005                 # entropy term
CRITIC_LOSS_COEFF = 1.0      # critic term
GRADIENT_CLIP = 0.5          # max gradient norm passed to clip_grad_norm_
LEARNING_RATE = 0.001
GAE_TAU = 0.95               # decay applied to the running advantage in learn()
optimizer = torch.optim.RMSprop(agent.parameters(), lr=LEARNING_RATE)
# Exponential decay: lr = LEARNING_RATE * 0.99 ** epoch.
# `last_epoch=-1` is the default, and the `verbose` argument is deprecated in
# recent PyTorch releases, so neither is passed explicitly.
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer,
                                              lr_lambda=lambda epoch: 0.99 ** epoch)


## LOGGING SETUP ##
import logging
import os

logger = logging.getLogger()
# The root logger defaults to WARNING, which drops INFO records before any
# handler sees them; lower it so the training log lines reach the file.
logger.setLevel(logging.INFO)
lognum = 1
log_path = './log/%s-%s.txt' % ('train', lognum)
# FileHandler does not create missing directories.
os.makedirs(os.path.dirname(log_path), exist_ok=True)
# Re-running this cell must not stack handlers (each one would duplicate every line).
for stale_handler in [h for h in logger.handlers
                      if isinstance(h, logging.FileHandler)
                      and h.baseFilename == os.path.abspath(log_path)]:
    logger.removeHandler(stale_handler)
    stale_handler.close()
fh = logging.FileHandler(log_path)
fh.setFormatter(logging.Formatter('%(asctime)s - %(name)s: %(message)s'))
fh.setLevel(logging.INFO)
logger.addHandler(fh)


def learn():
    """Train the global ``agent`` on the global ``envs`` with advantage actor-critic.

    Relies on the notebook globals ``agent``, ``envs``, ``device``, ``optimizer``,
    ``scheduler``, ``logger`` and the training constants (``ROLLOUT_LENGTH``,
    ``NUM_EPOCHS``, ``DISCOUNT``, ``BETA``, ``CRITIC_LOSS_COEFF``,
    ``GRADIENT_CLIP``, ``GAE_TAU``).

    For each of ``NUM_EPOCHS`` iterations it collects ``ROLLOUT_LENGTH + 1`` steps
    from every environment, computes generalized advantage estimates and
    discounted bootstrapped returns, and applies one optimizer step on
    ``actor_loss + CRITIC_LOSS_COEFF * critic_loss - BETA * entropy``.

    Returns:
        None. Progress is logged through ``logger`` and printed.
    """
    envs.reset()
    # Take a few random steps so rollouts do not all begin from the reset state.
    for _warmup in range(5):
        s, *_ = envs.step(np.random.choice(np.arange(envs.action_space.n), envs.num_envs))
        s, *_ = envs.step([0] * envs.num_envs)  # wait a frame

    for ep in range(NUM_EPOCHS):
        storage = {"states": [], "values": [], "log_probs": [],
                   "actions": [], "entropy": [], "rewards": [], "dones_mask": []}
        total_rewards = 0.0

        # One step more than the rollout length: the final step only supplies
        # the bootstrap value for the last transition.
        for _step in range(ROLLOUT_LENGTH + 1):
            # step through agent
            s_tensor = torch.from_numpy(s).float().to(device)
            preds = agent.forward(s_tensor)
            actions = preds["actions"].cpu().detach().numpy()
            values = preds["values"]

            storage["states"].append(s_tensor)
            storage["values"].append(values)
            storage["log_probs"].append(preds["log_probs"])
            storage["actions"].append(preds["actions"])
            storage["entropy"].append(preds["entropy"])

            # step through env
            s, r, done, _ = envs.step(actions)
            # Give rewards and masks the dtype and shape of the value tensor so the
            # arithmetic below is element-wise instead of silently broadcasting
            # (e.g. (N,) against (N, 1)).
            # NOTE(review): assumes the agent emits one value per environment -- confirm.
            storage["rewards"].append(
                torch.as_tensor(r, dtype=values.dtype, device=device).reshape(values.shape))
            storage["dones_mask"].append(
                torch.as_tensor(1 - done, dtype=values.dtype, device=device).reshape(values.shape))
            # r.mean() already averages over the environments, so the running sum
            # is the per-environment reward collected during this rollout.
            total_rewards += r.mean()

        # Advantages and returns are regression targets / weights, so they are
        # computed entirely from detached values and carry no gradient.
        storage["advantages"] = [None] * ROLLOUT_LENGTH
        storage["returns"] = [None] * ROLLOUT_LENGTH
        advantage = torch.zeros_like(storage["values"][-1])
        ret = storage["values"][-1].detach()

        for i in reversed(range(ROLLOUT_LENGTH)):
            reward = storage["rewards"][i]
            mask = storage["dones_mask"][i]  # 0 where the episode ended at step i
            ret = reward + DISCOUNT * mask * ret
            td_error = (reward
                        + DISCOUNT * mask * storage["values"][i + 1].detach()
                        - storage["values"][i].detach())
            advantage = advantage * GAE_TAU * DISCOUNT * mask + td_error

            storage["advantages"][i] = advantage
            storage["returns"][i] = ret

        # collect (the extra bootstrap step is dropped from the per-step tensors)
        log_probs = torch.stack(storage["log_probs"][:-1])
        values = torch.stack(storage["values"][:-1])
        entropy = torch.stack(storage["entropy"][:-1])
        advantages = torch.stack(storage["advantages"])
        returns = torch.stack(storage["returns"])

        # calculate losses
        # NOTE(review): log_probs is assumed to have the same shape as the values -- confirm.
        actor_loss = -(log_probs * advantages).mean()  # this part is ascent!
        critic_loss = 0.5 * (returns - values).pow(2).mean()  # this is descent
        entropy_loss = entropy.mean()
        loss = actor_loss + CRITIC_LOSS_COEFF * critic_loss - BETA * entropy_loss

        # backprop
        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(agent.parameters(), GRADIENT_CLIP)
        optimizer.step()
        scheduler.step()

        loss_value = loss.item()
        summary = "Episode: {0:d}, loss: {1:f}, rewards: {2:f}, lr: {3:f}".format(
            ep + 1, loss_value, total_rewards, optimizer.param_groups[0]['lr'])
        print("loss:", loss_value)
        logger.info(summary)
        if (ep + 1) % 20 == 0:
            print(summary)



In [None]:
# Run the full training loop defined by learn(), then store the agent under a fixed name.
# NOTE(review): 'pretraied' looks like a typo for 'pretrained'; any cell that loads this
# name must be renamed together with it.
learn()
agent.save('pretraied-model-1')

In [None]:
import gymnasium

# Watch the saved agent play a single rendered episode.
env = gymnasium.make('LunarLander-v2', render_mode="human")
print('State shape: ', env.observation_space.shape)
print('Number of actions: ', env.action_space.n)

agent.load('pretraied-model-1')
agent.eval()
try:
    # reset() returns (observation, info); keep the observation as the start state.
    s, _ = env.reset()

    # Take initial random steps. Actions are drawn from THIS env's action space
    # rather than from the training `envs`, which another cell may have rebound.
    for _warmup in range(5):
        s, r, terminated, truncated, info = env.step(np.random.choice(np.arange(env.action_space.n)))
        s, r, terminated, truncated, info = env.step(0)  # wait a frame

    with torch.no_grad():
        for t in range(350):
            # convert observations to torch
            s_tensor = torch.from_numpy(np.asarray(s)).float().to(device)
            preds = agent.forward(s_tensor)
            # A single env expects one scalar action.
            # NOTE(review): assumes the agent returns exactly one action for an
            # unbatched observation -- confirm in a2c_agent.
            action = preds["actions"].cpu().item()

            # step env; Gymnasium reports episode end as terminated OR truncated
            s, r, terminated, truncated, info = env.step(action)
            if terminated or truncated:
                break
finally:
    # Always restore training mode and release the render window, even on error.
    agent.train()
    env.close()


## SAC


In [59]:
from SAC_agent import SAC
import torch
import numpy as np

import gymnasium as gym
from gymnasium.wrappers import TimeLimit
import sys
import metaworld
from parallelEnv import parallelEnv, ParallelCollector
import random
from mtEnv import MtEnv, GymWrapper, ClsOneHotWrapper

SEED = 11
# Seed every RNG this cell has imported, not only Python's, so that sampling done
# through NumPy or PyTorch is reproducible as well.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
ENV_NAME = 'button-press-v2'

def create_env_fn(env_name):
    """Build one ``MtEnv`` for ``env_name`` wrapped in a 500-step ``TimeLimit``.

    Args:
        env_name: task name forwarded to ``MtEnv``.

    Returns:
        The ``TimeLimit``-wrapped environment.
    """
    # Use the argument (not the global ENV_NAME) so callers can build other tasks.
    return TimeLimit(MtEnv(env_name), max_episode_steps=500)

def create_vec_env():
    """Return a ``parallelEnv`` of four ``ENV_NAME`` environments seeded with ``SEED``."""

    def make_single_env():
        # Resolved at call time, so the current value of ENV_NAME is used.
        return create_env_fn(ENV_NAME)

    return parallelEnv(env_fn=make_single_env, n=4, seed=SEED)



#collector = ParallelCollector(train_mt10, env_names, 2)
#obs, reward, terminated, truncated, info = collector.step(np.array([0.0, 0.0, 0.0, 0.0]))
#print(obs, reward, terminated, truncated, info)

In [40]:
# Single ENV_NAME environment with episodes capped at 4 steps; its observation and
# action spaces are read when the SAC arguments are built.
train_env = TimeLimit(MtEnv(ENV_NAME), max_episode_steps=4)
# reset() returns a pair; only the first element (the observation) is kept.
obs, _ = train_env.reset()


In [60]:
# Deliberately tiny settings (small networks, short episodes, few steps) for a quick run.
SAC_args = dict(
    # Sizes taken from the first dimension of the environment's spaces.
    obs_dim=train_env.observation_space.shape[0],
    action_dim=train_env.action_space.shape[0],
    hidden_sizes=[8, 8],
    # Factory producing a fresh 4-step-limited environment on demand.
    task_fn=lambda: TimeLimit(MtEnv(ENV_NAME), max_episode_steps=4),
    max_ep_len=8,
    batch_size=4,
    learning_starts=0,
    update_every=1,
    steps_per_epoch=2,
    epochs=3,
)
agent = SAC(**SAC_args)

In [61]:
# Collect one rollout with the SAC agent; the returned value is reported as a step count.
frames = agent.collect_rollout()
print("collected %d steps" % frames)
# Draw a batch of 4 from the agent's replay buffer and print it for inspection.
data = agent.replay.sample_batch(4)
print(data)


collected 8 steps
{'obs': tensor([[0.0105, 0.3996, 0.1994, 0.9754, 0.0133, 0.6696, 0.1150, 1.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0097, 0.3997, 0.1983, 0.9885, 0.0133, 0.6696, 0.1150, 1.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0133, 0.7633, 0.1150],
        [0.0082, 0.4009, 0.2039, 0.9745, 0.0133, 0.6697, 0.1150, 1.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0097, 0.4002, 0.2017, 0.9699, 0.0133, 0.6696, 0.1150, 1.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0133, 0.7633, 0.1150],
        [0.0103, 0.3997, 0.2001, 0.9694, 0.0133, 0.6696, 0.1150, 1.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0105, 0.3996, 0.1994, 0.9754, 0.0133, 0.6696, 0.1150, 1.0000, 0.0000,
         0.0000

In [None]:
train_mt10 = metaworld.MT10()
env_names = list(train_mt10.train_classes.keys())
print(env_names)
envs = []
mt_cls_cnt = len(env_names)

# Number of tasks to instantiate. The one-hot width is derived from the actual
# selection so the slice length and the encoding size cannot drift apart.
NUM_SELECTED_TASKS = 4
selected_env_names = env_names[:NUM_SELECTED_TASKS]

for i, env_name in enumerate(selected_env_names):
    env_cls = train_mt10.train_classes[env_name]
    # Only tasks that belong to this environment class.
    task_list = [task for task in train_mt10.train_tasks if task.env_name == env_name]

    env = env_cls()
    task = random.choice(task_list)
    env.set_task(task)
    # Wrapper order: Gym adapter, 500-step episode cap, then the task one-hot
    # built from index i and the number of selected tasks.
    env = GymWrapper(env)
    env = TimeLimit(env, max_episode_steps=500)
    env = ClsOneHotWrapper(env, i, len(selected_env_names))
    envs.append(env)


In [None]:

# Smoke test: reset each of the first four environments and print the result of
# one step taken with an all-ones 4-dimensional action.
for i in range(4):
    envs[i].reset()
    print(envs[i].step(np.ones(4)))