In [None]:
%load_ext tensorboard

In [None]:
%tensorboard --logdir runs

In [None]:
import gym
import pybulletgym
import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Normal
from torch.utils.tensorboard import SummaryWriter
# Module-level TensorBoard writer shared by the SAC logging calls and the
# training loop below (which also flushes and closes it).
writer = SummaryWriter()

def Identity(x):
    """Return ``x`` unchanged.

    Used as a no-op stand-in where a callable layer is expected but no
    transformation should be applied.
    """
    return x

def weights_init_(m):
    """Initialise an ``nn.Linear`` layer in place; leave other modules alone.

    The weight matrix gets Xavier-uniform initialisation with gain 1 and the
    bias, when the layer has one, is filled with zeros. Designed to be passed
    to ``nn.Module.apply``.

    Args:
        m: Any module. Only ``nn.Linear`` instances are modified.
    """
    if not isinstance(m, nn.Linear):
        return
    torch.nn.init.xavier_uniform_(m.weight, gain=1)
    # A layer created with ``bias=False`` exposes ``bias`` as None, which
    # ``constant_`` cannot fill.
    if m.bias is not None:
        torch.nn.init.constant_(m.bias, 0)

class Normalizer():
    """Running per-feature mean/variance tracker used to standardise inputs.

    Statistics are updated one sample at a time in ``observe`` and applied
    in ``normalize``. All state is kept in plain tensors of length
    ``nb_inputs`` (``n``, ``mean``, ``mean_diff``, ``var``).
    """

    # Lower bound on the variance so the standard deviation used for
    # scaling can never be zero.
    VAR_FLOOR = 1e-2

    def __init__(self, nb_inputs):
        """Create zeroed statistics for ``nb_inputs`` features."""
        self.n = torch.zeros(nb_inputs)
        self.mean = torch.zeros(nb_inputs)
        self.mean_diff = torch.zeros(nb_inputs)
        self.var = torch.zeros(nb_inputs)

    def observe(self, x):
        """Fold one sample ``x`` into the running mean and variance."""
        self.n += 1.0
        # Snapshot the mean before the in-place update; both the old and the
        # new mean are needed for the incremental sum of squared deviations.
        last_mean = self.mean.clone().detach()
        self.mean += (x - self.mean) / self.n
        self.mean_diff += (x - last_mean) * (x - self.mean)
        self.var = (self.mean_diff / self.n).clamp(min=self.VAR_FLOOR)

    def normalize(self, inputs):
        """Return ``inputs`` shifted by the running mean and scaled by the running std.

        The variance is floored here as well as in ``observe``: before the
        first ``observe`` call ``var`` is still all zeros, and dividing by
        its square root would produce NaN/inf.
        """
        obs_mean = self.mean
        obs_std = torch.sqrt(self.var.clamp(min=self.VAR_FLOOR))
        return (inputs - obs_mean) / obs_std
    
class MLP(nn.Module):
    """Fully connected network: ReLU hidden layers plus an optional linear head.

    Args:
        input_size: Width of the input features.
        hidden_size: Width of every hidden layer.
        hidden_depth: Number of hidden ``nn.Linear`` layers.
        output_size: Width of the output layer; required when
            ``use_output_layer`` is True.
        use_output_layer: When False the network ends after the last hidden
            ReLU and ``output_layer`` is the pass-through ``Identity``.
    """

    def __init__(self, input_size, hidden_size, hidden_depth, output_size=None, use_output_layer=True):
        super().__init__()

        self.use_output_layer = use_output_layer

        self.hidden_layers = nn.ModuleList()
        in_size = input_size
        for _ in range(hidden_depth):
            self.hidden_layers.append(nn.Linear(in_size, hidden_size))
            in_size = hidden_size

        if use_output_layer:
            # Size the head from the width that actually feeds it. With at
            # least one hidden layer this equals hidden_size; with
            # hidden_depth == 0 it is input_size, which the previous
            # hard-coded hidden_size got wrong.
            self.output_layer = nn.Linear(in_size, output_size)
        else:
            self.output_layer = Identity

        self.apply(weights_init_)

    def forward(self, x):
        """Apply every hidden layer followed by ReLU, then the output layer."""
        for hidden_layer in self.hidden_layers:
            x = F.relu(hidden_layer(x))

        return self.output_layer(x)

# Bounds used to clamp the log standard deviation predicted by GaussianPolicy.
LOG_SIG_MAX = 2
LOG_SIG_MIN = -20
# Small constant added inside the log of the tanh-squash correction in
# GaussianPolicy.forward so the argument never reaches zero.
epsilon = 1e-6

class GaussianPolicy(nn.Module):
    """Diagonal Gaussian policy whose samples are squashed with tanh.

    A shared MLP trunk feeds two linear heads producing the mean and the
    (clamped) log standard deviation of the pre-squash Gaussian.

    Args:
        num_inputs: Width of the policy input.
        hidden_size: Width of the trunk's hidden layers.
        hidden_depth: Number of hidden layers in the trunk.
        num_actions: Dimension of the action vector.
        action_space: Object with ``high`` and ``low`` bounds used to rescale
            the tanh output, or None to leave actions in [-1, 1].
    """

    def __init__(self, num_inputs, hidden_size, hidden_depth, num_actions, action_space):
        super().__init__()

        self.net = MLP(num_inputs, hidden_size, hidden_depth, use_output_layer=False)

        self.mean_linear = nn.Linear(hidden_size, num_actions)
        self.log_std_linear = nn.Linear(hidden_size, num_actions)

        self.apply(weights_init_)

        if action_space is None:
            self.action_scale = torch.tensor(1.)
            self.action_bias = torch.tensor(0.)
        else:
            # Affine map taking tanh's [-1, 1] range onto [low, high].
            self.action_scale = torch.FloatTensor(
                (action_space.high - action_space.low) / 2.)
            self.action_bias = torch.FloatTensor(
                (action_space.high + action_space.low) / 2.)

    def sample(self, x):
        """Return the Gaussian parameters ``(mean, log_std)`` for input ``x``.

        Despite the name, nothing is sampled here; ``forward`` draws the
        action. ``log_std`` is clamped to [LOG_SIG_MIN, LOG_SIG_MAX].
        """
        x = F.relu(self.net(x))

        mean = self.mean_linear(x)
        log_std = self.log_std_linear(x)
        log_std = torch.clamp(log_std, min=LOG_SIG_MIN, max=LOG_SIG_MAX)

        return mean, log_std

    def forward(self, x):
        """Sample a squashed action for ``x``.

        Returns:
            A tuple ``(mean, action, log_prob)``: the tanh-squashed, rescaled
            mean; a reparameterised sample squashed and rescaled the same
            way; and the sample's log-probability summed over the action
            dimension, with that last dimension removed.
        """
        mean, log_std = self.sample(x)
        std = log_std.exp()
        normal = Normal(mean, std)

        # rsample keeps the draw differentiable w.r.t. mean and std.
        x_t = normal.rsample()
        y_t = torch.tanh(x_t)

        action = y_t * self.action_scale + self.action_bias

        log_prob = normal.log_prob(x_t)
        # Change-of-variables correction for the tanh squash and rescaling.
        log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + epsilon)
        log_prob = log_prob.sum(-1, keepdim=True)

        mean = torch.tanh(mean) * self.action_scale + self.action_bias

        # Drop only the summed action dimension. A bare squeeze() would also
        # remove the batch dimension whenever the batch holds one sample.
        return mean, action, log_prob.squeeze(-1)

class SAC():
    """Skill-conditioned Soft Actor-Critic agent with a learned skill predictor.

    The agent owns a tanh-Gaussian policy, twin critics with target copies,
    a predictor network that regresses the skill from ``(state, next_state)``
    pairs (its error defines the reward), an observation normalizer, a list
    based replay buffer and a learnable entropy coefficient.

    Args:
        obs_size: Dimension of an observation vector.
        action_size: Dimension of an action vector.
        hidden_size: Hidden width used by every network.
        hidden_depth: Number of hidden layers used by every network.
        action_space: Forwarded to ``GaussianPolicy`` for action rescaling.
        tau: Interpolation factor of the soft target-network update.
        alpha: Initial entropy coefficient; replaced by ``exp(log_alpha)``
            after the first ``train`` call.
        gamma: Discount factor.
        policy_lr: Learning rate of the policy and of ``log_alpha``.
        critic_lr: Learning rate shared by both critics.
        predict_lr: Learning rate of the skill predictor.
        batch_size: Number of transitions drawn per ``train`` call.
    """

    def __init__(self,
                obs_size,
                action_size,
                hidden_size,
                hidden_depth,
                action_space=None,
                tau=0.005,
                alpha=0.2,
                gamma=0.99,
                policy_lr=3e-4,
                critic_lr=3e-4,
                predict_lr=3e-4,
                batch_size=64
                ):
        super().__init__()

        self.tau = tau
        self.alpha = alpha
        self.gamma = gamma
        self.batch_size = batch_size
        self.action_size = action_size

        # Predictor sees a state and its successor and outputs one scalar.
        self.predict = MLP(obs_size+obs_size, hidden_size, hidden_depth, 1)

        # The "+1" inputs below are the scalar skill appended to the features.
        self.policy = GaussianPolicy(obs_size+1, hidden_size, hidden_depth, action_size, action_space)

        self.critic1 = MLP(obs_size+action_size+1, hidden_size, hidden_depth, 1)
        self.critic2 = MLP(obs_size+action_size+1, hidden_size, hidden_depth, 1)

        self.critic1_target = MLP(obs_size+action_size+1, hidden_size, hidden_depth, 1)
        self.critic2_target = MLP(obs_size+action_size+1, hidden_size, hidden_depth, 1)

        # Targets start as exact copies of the online critics.
        self.critic1_target.load_state_dict(self.critic1.state_dict())
        self.critic2_target.load_state_dict(self.critic2.state_dict())

        self.critic_parameters = list(self.critic1.parameters()) + list(self.critic2.parameters())
        self.policy_optimizer = optim.Adam(self.policy.parameters(), lr=policy_lr)
        self.critic_optimizer = optim.Adam(self.critic_parameters, lr=critic_lr)
        self.predict_optimizer = optim.Adam(self.predict.parameters(), lr=predict_lr)

        self.norm = Normalizer(obs_size)

        self.replay_buffer = []

        self.target_entropy = -np.prod((action_size,)).item()
        self.log_alpha = torch.zeros(1, requires_grad=True)
        self.alpha_optimizer = optim.Adam([self.log_alpha], lr=policy_lr)

    def cat(self, list):
        """Concatenate a sequence of tensors along their last dimension.

        The parameter keeps its original name (which shadows the builtin)
        so existing keyword callers continue to work.
        """
        return torch.cat(list, dim=-1)

    def store(self, state, skill, action, next_state, next_skill, done):
        """Append one transition to the replay buffer.

        ``action`` is converted to a tensor. Note that the ``"done"`` entry
        holds ``int(not done)``, i.e. 1 while the episode continues and 0 on
        termination, so it can be used directly as a bootstrap mask.
        """
        self.replay_buffer.append({"state": state,
                                   "skill": skill,
                                   "action": torch.tensor(action),
                                   "next_state": next_state,
                                   "next_skill": next_skill,
                                   "done": torch.tensor(int(not done))})

    def sample(self, batch_size):
        """Draw ``batch_size`` transitions uniformly, with replacement.

        Returns:
            A dict with keys ``states``, ``skills``, ``actions``,
            ``next_states``, ``next_skills`` and ``dones``, each a tensor
            stacked along a new leading batch dimension.
        """
        transitions = random.choices(self.replay_buffer, k=batch_size)

        # Batch key -> key of the corresponding field in a stored transition.
        fields = {"states": "state",
                  "skills": "skill",
                  "actions": "action",
                  "next_states": "next_state",
                  "next_skills": "next_skill",
                  "dones": "done"}

        return {batch_key: torch.stack([item[item_key] for item in transitions], dim=0)
                for batch_key, item_key in fields.items()}

    def skill_predict(self, states, skills, next_states, timestep):
        """Run one predictor update and return the rewards derived from it.

        The predictor regresses ``skills`` from ``(states, next_states)``.
        The element-wise squared error, detached, is turned into a reward via
        ``tanh(-log(error))``, so a smaller error yields a larger reward. The
        mean reward is logged to TensorBoard under "Avg Rewards".
        """
        pred_skills = torch.tanh(self.predict(self.cat([states, next_states])))

        skill_loss = F.mse_loss(pred_skills, skills, reduction='none')

        rewards = torch.tanh(-torch.log(skill_loss.detach()))

        predict_loss = skill_loss.mean()

        writer.add_scalar("Avg Rewards", rewards.mean(), timestep)

        self.predict_optimizer.zero_grad()
        predict_loss.backward()
        self.predict_optimizer.step()

        return rewards

    def _soft_update(self, online, target):
        """Move ``target``'s parameters a fraction ``tau`` towards ``online``'s."""
        for param, target_param in zip(online.parameters(), target.parameters()):
            target_param.data.copy_(self.tau*param.data + (1.0-self.tau)*target_param.data)

    def train(self, timestep):
        """Perform one full update step on a freshly sampled batch.

        Updates, in order: the skill predictor, the policy, both critics,
        the target critics (soft update) and the entropy coefficient. Losses
        and the coefficient are logged to TensorBoard at ``timestep``.
        """
        batch = self.sample(self.batch_size)

        skills = batch['skills']
        actions = batch['actions']
        next_skills = batch['next_skills']
        # Stored as int(not done): 1 keeps the bootstrap term, 0 removes it.
        not_dones = batch['dones']

        states = self.norm.normalize(batch['states'])
        next_states = self.norm.normalize(batch['next_states'])

        rewards = self.skill_predict(states, skills, next_states, timestep)

        _, pi, log_pi = self.policy(self.cat([states, skills]))

        # The bootstrap target is a constant for the critic regression, so
        # none of it needs an autograd graph.
        with torch.no_grad():
            _, next_pi, next_log_pi = self.policy(self.cat([next_states, next_skills]))
            next_inputs = self.cat([next_states, next_pi, next_skills])
            min_q_next_pi = torch.min(self.critic1_target(next_inputs),
                                      self.critic2_target(next_inputs)).squeeze(1)
            v_backup = min_q_next_pi - self.alpha*next_log_pi
            q_backup = rewards.squeeze() + self.gamma*not_dones*v_backup

        replay_inputs = self.cat([states, actions, skills])
        q1 = self.critic1(replay_inputs).squeeze(1)
        q2 = self.critic2(replay_inputs).squeeze(1)

        policy_inputs = self.cat([states, pi, skills])
        min_q_pi = torch.min(self.critic1(policy_inputs), self.critic2(policy_inputs)).squeeze(1)

        policy_loss = (self.alpha*log_pi - min_q_pi).mean()
        critic1_loss = F.mse_loss(q1, q_backup)
        critic2_loss = F.mse_loss(q2, q_backup)
        critic_loss = critic1_loss + critic2_loss

        writer.add_scalar("Policy loss", policy_loss.item(), timestep)
        writer.add_scalar("Critic loss", critic_loss.item(), timestep)

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        # zero_grad also discards the critic gradients that the policy loss
        # just accumulated through min_q_pi.
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        self._soft_update(self.critic1, self.critic1_target)
        self._soft_update(self.critic2, self.critic2_target)

        alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()
        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.alpha_optimizer.step()

        # Detach so the coefficient is a plain value: otherwise the next
        # policy loss would back-propagate into log_alpha through a graph
        # kept alive between train() calls.
        self.alpha = self.log_alpha.exp().detach()

        writer.add_scalar("Alpha", self.alpha, timestep)

# Create the PyBullet Ant environment and take its first observation.
env = gym.make("AntPyBulletEnv-v0")
observation = env.reset()

# Agent sized from the environment's spaces, with hidden_size=128 and
# hidden_depth=2; the action space is passed so the policy can rescale its
# tanh output to the environment's bounds.
agent = SAC(env.observation_space.shape[0],
            env.action_space.shape[0],
            128,
            2,
            env.action_space)

# Scalar skill fed to the policy, drawn uniformly from [-1, 1).
episode_skill = torch.empty(1).uniform_(-1,1)
print(episode_skill)
# Timestep thresholds read by the training loop: agent.train runs once
# timestep > train, and the policy (instead of random actions) is used once
# timestep > explore.
train = 5000
explore = 5000

WalkerBase::__init__
tensor([0.5701])




In [None]:
# Main interaction/training loop. The timestep range starts well above the
# `explore` and `train` thresholds (5000 each as set above), so in this cell
# the policy always acts and a training step runs on every iteration.
try:
    for timestep in range(1000001, 1005001):
        state = torch.tensor(observation, dtype=torch.float32)

        agent.norm.observe(state)

        if timestep > explore:
            with torch.no_grad():
                obs = agent.norm.normalize(state)
                _, action, _ = agent.policy(torch.cat([obs, episode_skill], dim=-1))
                action = action.squeeze(0).numpy()

        else:
            action = env.action_space.sample()

        observation, _, done, _ = env.step(action)

        next_state = torch.tensor(observation, dtype=torch.float32)

        # The same skill is recorded as both `skill` and `next_skill`.
        agent.store(state, episode_skill, action, next_state, episode_skill, done)

        # A fresh skill is drawn after every transition.
        # NOTE(review): the name suggests one skill per episode - confirm
        # that per-step resampling is intended.
        episode_skill = torch.empty(1).uniform_(-1,1)

        if done:
            # Feed the terminal observation to the normalizer: it is about to
            # be replaced by the reset observation and would otherwise never
            # be seen. (`state` was already observed at the top of this
            # iteration, so observing it again only double-counted it.)
            agent.norm.observe(next_state)
            observation = env.reset()

        if timestep > train:
            agent.train(timestep)
            # NOTE(review): a full checkpoint is written after every training
            # step; consider saving periodically if this becomes a bottleneck.
            torch.save({"policy": agent.policy.state_dict(),
                        "critic1": agent.critic1.state_dict(),
                        "critic2": agent.critic2.state_dict(),
                        "critic1_target": agent.critic1_target.state_dict(),
                        "critic2_target": agent.critic2_target.state_dict(),
                        "predictor": agent.predict.state_dict(),
                        "log_alpha": agent.log_alpha,
                        "alpha" : agent.alpha,
                        "policy_optim": agent.policy_optimizer.state_dict(),
                        "critic_optim": agent.critic_optimizer.state_dict(),
                        "predictor_optim": agent.predict_optimizer.state_dict(),
                        "alpha_optim" : agent.alpha_optimizer.state_dict(),
                        "n": agent.norm.n,
                        "mean": agent.norm.mean,
                        "mean_diff": agent.norm.mean_diff,
                        "var": agent.norm.var}, "drive/My Drive/models.tar")
        writer.flush()
finally:
    # Release the simulator and the log writer even if the loop is
    # interrupted or raises.
    env.close()
    writer.close()