In [None]:
# Actor critic agent
# Continuous 2D
%reset -f

import torch as tor
import matplotlib.pyplot as plt
# Problem
tor.manual_seed(3)
LB = tor.tensor([[-1., -1.]]); UB = tor.tensor([[1., 1.]])
ALB = 0.1*tor.tensor([[-.1, -.1]]); AUB = tor.tensor([[.1, .1]])
dt = 1
# Agent
nhid = 10
alpha = 0.0003
actor_body = tor.nn.Sequential(tor.nn.Linear(4, nhid), tor.nn.ReLU(),
                               tor.nn.Linear(nhid, nhid), tor.nn.ReLU()
                               )
actor_mean = tor.nn.Sequential(tor.nn.Linear(nhid, 2))
actor_mean[-1].weight.data[:] = 0; actor_mean[-1].bias.data[:] = 0
actor_lsigma = tor.nn.Sequential(tor.nn.Linear(nhid, 2))
actor_lsigma[-1].weight.data[:] = 0; actor_lsigma[-1].bias.data[:] = 0
critic = tor.nn.Sequential(tor.nn.Linear(4, nhid), tor.nn.ReLU(),
                           tor.nn.Linear(nhid, nhid), tor.nn.ReLU(),
                          tor.nn.Linear(nhid, 1))
popt = tor.optim.Adam(list(actor_body.parameters())+list(actor_mean.parameters())+list(actor_lsigma.parameters()),lr=alpha)
copt = tor.optim.Adam(critic.parameters(), lr=10*alpha)
# Experiment
EP = 2000
rets = []
Slogs = []
i = 0
for ep in range(EP):
    Slogs.append([])
    pos = tor.rand((1, 2))*(UB-LB) + LB
    vel = tor.zeros((1, 2))
    S = tor.cat((pos, vel), 1)
    Slogs[-1].append(S)
    ret = 0
    while True:
        # Take action
        feat = actor_body(S)
        mu = actor_mean(feat)
        lsigma = actor_lsigma(feat)
        try:
            pol = tor.distributions.MultivariateNormal(mu, 0.01*tor.diag(tor.exp(lsigma[0])))
        except:
            print("A")
        A = pol.sample()
        tor.clamp(A, ALB, AUB)
        # Receive reward and next state
        pos = pos + vel*dt + 0.5*A*dt**2
        vel[pos < LB] = -0.1*vel[pos < LB]; vel[pos > UB] = -0.1*vel[pos > UB]
        pos = tor.clamp(pos, LB, UB)
        vel += A*dt
        SP = tor.cat((pos, vel), 1)
        R = -0.01
        done = tor.allclose(pos, tor.zeros(2), atol=0.25) and tor.allclose(vel, tor.zeros(2), atol=0.1)
        # Learning
        vs = critic(S); vsp = critic(SP)
        pobj = pol.log_prob(A)*(R + (1-done)*vsp - vs).detach()
        ploss = -pobj
        closs = (R + (1-done)*vsp.detach() - vs)**2
        popt.zero_grad()
        ploss.backward()
        popt.step()
        copt.zero_grad()
        closs.backward()
        copt.step()
        # Log
        Slogs[-1].append(SP)
        ret += R
        # Termination
        if done:
            rets.append(ret)
            i += 1
            print(i, len(Slogs[-1]))
            break
        S = SP
# Plotting
plt.plot(-100*tor.tensor(rets))
plt.figure()
colors = ["tab:blue", "tab:green", "tab:orange", "tab:purple", "tab:red", "tab:brown"]
for i in range(-min(30, EP), 0):
    color = colors[i%len(colors)]
    Slog = tor.cat(Slogs[i])
    for i in range(Slog.shape[0]-1):
        plt.plot(Slog[i:i+2,0], Slog[i:i+2,1], alpha=(i+1)/Slog.shape[0], color=color, marker='.')
plt.xlim([LB[0, 0], UB[0, 0]])
plt.ylim([LB[0, 1], UB[0, 1]])
plt.gca().set_aspect('equal', adjustable='box')
plt.grid()
plt.show()

In [None]:
# Action Value Gradient agent
# Continuous 2D
%reset -f

import torch as tor
import matplotlib.pyplot as plt

# Problem
tor.manual_seed(3)
LB = tor.tensor([[-1., -1.]]); UB = tor.tensor([[1., 1.]])
ALB = 0.1*tor.tensor([[-.1, -.1]]); AUB = tor.tensor([[.1, .1]])
dt = 1

n_timeout = 5000

# Agent
nhid = 10
alpha = 0.0003
actor_body = tor.nn.Sequential(tor.nn.Linear(4, nhid), tor.nn.ReLU(),
                               tor.nn.Linear(nhid, nhid), tor.nn.ReLU()
                               )
actor_mean = tor.nn.Sequential(tor.nn.Linear(nhid, 2))
actor_mean[-1].weight.data[:] = 0; actor_mean[-1].bias.data[:] = 0
actor_lsigma = tor.nn.Sequential(tor.nn.Linear(nhid, 2))
actor_lsigma[-1].weight.data[:] = 0; actor_lsigma[-1].bias.data[:] = 0
q_net = tor.nn.Sequential(tor.nn.Linear(6, nhid), tor.nn.ReLU(),
                           tor.nn.Linear(nhid, nhid), tor.nn.ReLU(),
                          tor.nn.Linear(nhid, 1))
popt = tor.optim.Adam(list(actor_body.parameters())+list(actor_mean.parameters())+list(actor_lsigma.parameters()),lr=alpha)
qopt = tor.optim.Adam(q_net.parameters(), lr=10*alpha)

# Experiment
EP = 2000
rets = []
Slogs = []
i = 0
for ep in range(EP):
    Slogs.append([])
    pos = tor.rand((1, 2))*(UB-LB) + LB
    vel = tor.zeros((1, 2))
    S = tor.cat((pos, vel), 1)
    Slogs[-1].append(S)
    ret = 0
    step = 0
    while True:
        # Take action
        feat = actor_body(S)
        mu = actor_mean(feat)
        lsigma = actor_lsigma(feat)
        try:
            pol = tor.distributions.MultivariateNormal(mu, 0.01*tor.diag(tor.exp(lsigma[0])))
        except:
            print("A")
        A = pol.sample() # Don't use rsample() here
        tor.clamp(A, ALB, AUB)

        # Receive reward and next state
        pos = pos + vel*dt + 0.5*A*dt**2
        vel[pos < LB] = -0.1*vel[pos < LB]; vel[pos > UB] = -0.1*vel[pos > UB]
        pos = tor.clamp(pos, LB, UB)
        vel += A*dt
        SP = tor.cat((pos, vel), 1)
        R = -0.01
        done = (tor.allclose(pos, tor.zeros(2), atol=0.25) and tor.allclose(vel, tor.zeros(2), atol=0.1)) #or step + 1 == n_timeout

        # print("Step: {}, ".format(step))

        # Learning
        q = q_net(tor.cat((S, A), 1))
        with tor.no_grad():
          featP = actor_body(SP)
          muP = actor_mean(featP)
          lsigmaP = actor_lsigma(featP)
          polP = tor.distributions.MultivariateNormal(muP, 0.01*tor.diag(tor.exp(lsigmaP[0])))
          A2 = polP.sample()
          q2 = q_net(tor.cat((SP, A2), 1));

        # A.requires_grad = False
        ## Q loss
        qloss = (R + (1-done)*q2 - q)**2

        # Policy loss
        feat_pi = actor_body(S)
        mu_pi = actor_mean(feat_pi)
        lsigma_pi = actor_lsigma(feat_pi)
        pol_pi = tor.distributions.MultivariateNormal(mu_pi, 0.01*tor.diag(tor.exp(lsigma_pi[0])))
        A_pi = pol_pi.rsample()   # Requires rsample()
        q_pi = q_net(tor.cat((S, A_pi), 1))
        pobj = q_pi
        ploss = -pobj


        # A.requires_grad = True
        popt.zero_grad()
        ploss.backward()
        popt.step()

        qopt.zero_grad()
        qloss.backward()
        qopt.step()

        # Log
        Slogs[-1].append(SP)
        ret += R
        step += 1

        # Termination
        if done:
            rets.append(ret)
            i += 1
            print(i, len(Slogs[-1]))
            break
        S = SP

# Plotting
plt.plot(-100*tor.tensor(rets))
plt.figure()
colors = ["tab:blue", "tab:green", "tab:orange", "tab:purple", "tab:red", "tab:brown"]
for i in range(-min(30, EP), 0):
    color = colors[i%len(colors)]
    Slog = tor.cat(Slogs[i])
    for i in range(Slog.shape[0]-1):
        plt.plot(Slog[i:i+2,0], Slog[i:i+2,1], alpha=(i+1)/Slog.shape[0], color=color, marker='.')
plt.xlim([LB[0, 0], UB[0, 0]])
plt.ylim([LB[0, 1], UB[0, 1]])
plt.gca().set_aspect('equal', adjustable='box')
plt.grid()
plt.show()

**AVG on Challenging MuJoCo Benchmark Tasks**

In [None]:
import torch, time
import argparse, os, traceback

import numpy as np
import torch.nn as nn
import gymnasium as gym
import torch.nn.functional as F

from torch.distributions import MultivariateNormal
from gymnasium.wrappers import NormalizeObservation
from datetime import datetime


def orthogonal_weight_init(m):
    """Orthogonal weight initialization for neural networks.

    Meant to be passed to ``nn.Module.apply``: ``nn.Linear`` layers get an
    orthogonal weight matrix and a zero bias; every other module is left untouched.

    Args:
        m: The module visited by ``apply``.
    """
    if isinstance(m, nn.Linear):
        nn.init.orthogonal_(m.weight.data)
        # A layer built with ``bias=False`` has ``m.bias is None``; skip it
        # instead of raising AttributeError.
        if m.bias is not None:
            m.bias.data.fill_(0.0)

def human_format_numbers(num, use_float=False):
    """Make a human readable short-form for a large number.

    The value is divided by 1000 until its magnitude drops below 1000 or the
    largest suffix is reached, then rendered with that suffix.

    Args:
        num: The number to format (int or float, may be negative).
        use_float: If True keep two decimals ('1.50K'); otherwise the scaled
            value is formatted with '%d', which truncates it ('1K').

    Returns:
        str: The scaled value followed by one of '', 'K', 'M', 'G', 'T', 'P'.
    """
    # add more suffixes if you need them
    suffixes = ['', 'K', 'M', 'G', 'T', 'P']
    magnitude = 0
    # Stop at the last suffix: values >= 1000**6 used to index past the list
    # and raise IndexError; they are now reported in units of 'P'.
    while abs(num) >= 1000 and magnitude < len(suffixes) - 1:
        magnitude += 1
        num /= 1000.0
    if use_float:
        return '%.2f%s' % (num, suffixes[magnitude])
    return '%d%s' % (num, suffixes[magnitude])

def set_one_thread():
    """Limit PyTorch to a single CPU thread.

    N.B: PyTorch over-allocates resources and hogs CPU, which makes experiments
    very slow. As a temporary workaround this sets the OMP_NUM_THREADS and
    MKL_NUM_THREADS environment variables to '1' and asks torch to use one thread.
    """
    for env_var in ('OMP_NUM_THREADS', 'MKL_NUM_THREADS'):
        os.environ[env_var] = '1'
    torch.set_num_threads(1)


class Actor(nn.Module):
    """Squashed Normal MLP policy.

    A two-hidden-layer trunk feeds two linear heads (mean and log-std). Actions
    are drawn with the reparameterisation trick and squashed by tanh.
    """

    def __init__(self, obs_dim, action_dim, device, n_hid):
        """Build the network, initialise it orthogonally and move it to ``device``.

        Args:
            obs_dim: Size of the observation vector.
            action_dim: Size of the action vector.
            device: Torch device the parameters and inputs are moved to.
            n_hid: Width of both hidden layers.
        """
        super().__init__()
        self.device = device
        # Range the log-std head output is clamped to in forward().
        self.LOG_STD_MAX = 2
        self.LOG_STD_MIN = -20

        # Trunk: two hidden layers with LeakyReLU activations.
        trunk_layers = [
            nn.Linear(obs_dim, n_hid),
            nn.LeakyReLU(),
            nn.Linear(n_hid, n_hid),
            nn.LeakyReLU(),
        ]
        self.phi = nn.Sequential(*trunk_layers)

        # Output heads.
        self.mu = nn.Linear(n_hid, action_dim)
        self.log_std = nn.Linear(n_hid, action_dim)

        self.apply(orthogonal_weight_init)
        self.to(device=device)

    def forward(self, obs):
        """Sample a squashed action for every row of ``obs``.

        Args:
            obs: Batched observations; the feature normalisation and the
                log-prob correction both reduce over dim 1.

        Returns:
            tuple: ``(action, action_info)``. ``action`` is the tanh of the
            sampled pre-squash action. ``action_info`` holds 'mu', 'log_std',
            'dist', 'lprob' (log-probability of the squashed action) and
            'action_pre' (the sample before tanh).
        """
        features = self.phi(obs.to(self.device))
        # Rescale every feature row to unit L2 norm.
        row_norms = torch.norm(features, dim=1).view((-1, 1))
        features = features / row_norms

        mu = self.mu(features)
        log_std = torch.clamp(self.log_std(features), self.LOG_STD_MIN, self.LOG_STD_MAX)

        # Diagonal Gaussian whose covariance diagonal is exp(log_std).
        dist = MultivariateNormal(mu, torch.diag_embed(log_std.exp()))
        action_pre = dist.rsample()

        # N.B: the log-prob must be evaluated on the pre-tanh sample and only then
        #   corrected for the squashing; applying tanh first can break learning.
        # The correction equals sum(log(1 - tanh(u)^2)) in a numerically stable form.
        squash_correction = (2 * (np.log(2) - action_pre - F.softplus(-2 * action_pre))).sum(axis=1)
        lprob = dist.log_prob(action_pre) - squash_correction

        action = torch.tanh(action_pre)
        action_info = {
            'mu': mu,
            'log_std': log_std,
            'dist': dist,
            'lprob': lprob,
            'action_pre': action_pre,
        }
        return action, action_info


class Q(nn.Module):
    """MLP action-value function: maps (observation, action) pairs to a scalar."""

    def __init__(self, obs_dim, action_dim, device, n_hid):
        """Build the network, initialise it orthogonally and move it to ``device``.

        Args:
            obs_dim: Size of the observation vector.
            action_dim: Size of the action vector.
            device: Torch device the parameters and inputs are moved to.
            n_hid: Width of both hidden layers.
        """
        super().__init__()
        self.device = device

        # Trunk: two hidden layers over the concatenated (obs, action) input.
        in_dim = obs_dim + action_dim
        self.phi = nn.Sequential(
            nn.Linear(in_dim, n_hid), nn.LeakyReLU(),
            nn.Linear(n_hid, n_hid), nn.LeakyReLU(),
        )
        self.q = nn.Linear(n_hid, 1)
        self.apply(orthogonal_weight_init)
        self.to(device=device)

    def forward(self, obs, action):
        """Return one value estimate per (obs, action) row as a flat tensor.

        Args:
            obs: Batched observations.
            action: Batched actions, concatenated to ``obs`` along the last dim.
        """
        joint = torch.cat((obs, action), -1).to(self.device)
        features = self.phi(joint)
        # Rescale every feature row to unit L2 norm before the linear read-out.
        row_norms = torch.norm(features, dim=1).view((-1, 1))
        values = self.q(features / row_norms)
        return values.view(-1)


class AVG:
    """Action Value Gradient agent: an actor and a Q-network updated after every step."""

    def __init__(self, cfg):
        """Create the networks and their Adam optimisers.

        Args:
            cfg: Configuration object providing ``obs_dim``, ``action_dim``,
                ``device``, ``nhid_actor``, ``nhid_critic``, ``actor_lr``,
                ``critic_lr``, ``betas``, ``alpha_lr`` and ``gamma``.
        """
        self.cfg = cfg
        self.steps = 0  # number of completed update() calls

        self.actor = Actor(obs_dim=cfg.obs_dim, action_dim=cfg.action_dim, device=cfg.device, n_hid=cfg.nhid_actor)
        self.Q = Q(obs_dim=cfg.obs_dim, action_dim=cfg.action_dim, device=cfg.device, n_hid=cfg.nhid_critic)

        self.popt = torch.optim.Adam(self.actor.parameters(), lr=cfg.actor_lr, betas=cfg.betas)
        self.qopt = torch.optim.Adam(self.Q.parameters(), lr=cfg.critic_lr, betas=cfg.betas)

        # alpha weights the log-probability terms in both losses; gamma is the discount.
        self.alpha, self.gamma, self.device = cfg.alpha_lr, cfg.gamma, cfg.device

    def compute_action(self, obs):
        """Sample an action for a single observation.

        Args:
            obs: Observation array exposing ``astype`` (e.g. a NumPy array); a
                batch dimension of size one is added before the actor is called.

        Returns:
            tuple: ``(action, action_info)`` exactly as returned by the actor.
            The tensors keep their autograd graph so ``update`` can
            back-propagate through them.
        """
        obs = torch.Tensor(obs.astype(np.float32)).unsqueeze(0).to(self.device)
        action, action_info = self.actor(obs)
        return action, action_info

    def update(self, obs, action, next_obs, reward, done, **kwargs):
        """Run one actor step followed by one critic step on a single transition.

        Args:
            obs: Observation the action was computed from.
            action: Action tensor returned by ``compute_action`` (still attached
                to the actor's graph).
            next_obs: Observation received after applying the action.
            reward: Scalar reward of the transition.
            done: Truthy when the transition ended the episode; disables bootstrapping.
            **kwargs: Must contain 'lprob', the log-probability returned with
                ``action``; other entries are ignored.
        """
        obs = torch.Tensor(obs.astype(np.float32)).unsqueeze(0).to(self.device)
        next_obs = torch.Tensor(next_obs.astype(np.float32)).unsqueeze(0).to(self.device)
        action, lprob = action.to(self.device), kwargs['lprob']

        #### Q loss
        q = self.Q(obs, action.detach())    # N.B: Gradient should NOT pass through action here
        with torch.no_grad():
            # The bootstrap target is a constant: next action, its log-prob and
            # its value are all computed without tracking gradients.
            next_action, action_info = self.actor(next_obs)
            next_lprob = action_info['lprob']
            q2 = self.Q(next_obs, next_action)
            target_V = q2 - self.alpha * next_lprob

        delta = reward + (1 - done) *  self.gamma * target_V - q
        qloss = delta ** 2
        ####

        # Policy loss
        ploss = self.alpha * lprob - self.Q(obs, action) # N.B: USE reparametrized action
        self.popt.zero_grad()
        ploss.backward()
        self.popt.step()

        # ploss.backward() also left gradients on the Q parameters; zero_grad()
        # discards them before the critic loss is back-propagated.
        self.qopt.zero_grad()
        qloss.backward()
        self.qopt.step()

        self.steps += 1

    def save(self, model_dir, unique_str):
        """Write network and optimiser state to ``<model_dir>/<unique_str>.pt``.

        Added because the training script calls ``agent.save(model_dir=...,
        unique_str=...)`` when model saving is enabled.

        Args:
            model_dir: Target directory; created if it does not exist.
            unique_str: File name without the '.pt' extension.
        """
        os.makedirs(model_dir, exist_ok=True)
        checkpoint = {
            "actor": self.actor.state_dict(),
            "critic": self.Q.state_dict(),
            "policy_opt": self.popt.state_dict(),
            "critic_opt": self.qopt.state_dict(),
        }
        torch.save(checkpoint, os.path.join(model_dir, "{}.pt".format(unique_str)))


def main(args):
    """Train an AVG agent on a Gymnasium environment for ``args.N`` steps.

    Args:
        args: Namespace with ``env``, ``seed``, ``N``, ``algo`` and
            ``results_dir``, plus the hyper-parameters read by ``AVG``.
            ``save_model`` is optional and treated as False when absent.
            ``obs_dim`` and ``action_dim`` are filled in here from the environment.

    Returns:
        tuple: ``(ep_steps, rets)`` -- length and return of every episode. If the
        run stops mid-episode, the unfinished episode is appended as the last entry.
    """
    tic = time.time()
    run_id = datetime.now().strftime("%Y%m%d_%H%M%S") + f"-{args.algo}-{args.env}_seed-{args.seed}"

    # Env
    env = gym.make(args.env)
    env = NormalizeObservation(env)

    #### Reproducibility
    env.reset(seed=args.seed)
    env.action_space.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
    ####

    # Learner
    args.obs_dim =  env.observation_space.shape[0]
    args.action_dim = env.action_space.shape[0]
    agent = AVG(args)

    # Interaction
    rets, ep_steps = [], []
    ret, step = 0, 0
    terminated, truncated = False, False
    obs, _ = env.reset()
    ep_tic = time.time()
    t = -1  # keeps `t + 1` below defined even when args.N == 0
    try:
        for t in range(args.N):
            # N.B: Action is a torch.Tensor
            action, action_info = agent.compute_action(obs)
            sim_action = action.detach().cpu().view(-1).numpy()

            # Receive reward and next state
            next_obs, reward, terminated, truncated, _ = env.step(sim_action)
            # Only `terminated` is passed as the done flag, so the agent still
            # bootstraps from the next state when an episode is merely truncated.
            agent.update(obs, action, next_obs, reward, terminated, **action_info)
            ret += reward
            step += 1

            obs = next_obs

            # Termination
            if terminated or truncated:
                rets.append(ret)
                ep_steps.append(step)
                print("E: {}| D: {:.3f}| S: {}| R: {:.2f}| T: {}".format(len(rets), time.time() - ep_tic, step, ret, t))

                ep_tic = time.time()
                obs, _ = env.reset()
                ret, step = 0, 0
    except Exception as e:
        # Deliberate best-effort: report the failure but still return the logs
        # collected so far.
        print(e)
        print("Exiting this run, storing partial logs in the database for future debugging...")
        traceback.print_exc()
    finally:
        # Release simulator resources whether or not the loop finished cleanly.
        env.close()

    if not (terminated or truncated):
        # N.B: We're adding a partial episode just to make plotting easier. But this data point shouldn't be used
        print("Appending partial episode #{}, length: {}, Total Steps: {}".format(len(rets), step, t+1))
        rets.append(ret)
        ep_steps.append(step)

    # Save returns and args before exiting run
    # `save_model` may be missing from hand-built namespaces; treat that as "off"
    # rather than failing with AttributeError at the very end of the run.
    if getattr(args, "save_model", False):
        agent.save(model_dir=args.results_dir, unique_str=f"{run_id}_model")


    print("Run with id: {} took {:.3f}s!".format(run_id, time.time()-tic))
    return ep_steps, rets


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Options must be registered before parsing. This call used to come after
    # parse_args(), so `args.save_model` was never created and main() raised
    # AttributeError when it reached the save step.
    parser.add_argument('--save_model', action='store_true', default=False)
    # Parse an empty argument list so the defaults apply regardless of sys.argv.
    args = parser.parse_args(args=[])


    args.env = "Hopper-v4"
    args.seed = 42
    args.N = 10001000           # total number of environment steps
    args.actor_lr = 0.00006     # Adam step size for the actor
    args.critic_lr = 0.00087    # Adam step size for the Q-network
    args.gamma = 0.99           # discount factor
    args.alpha_lr = 0.6         # weight of the log-probability terms in AVG's losses
    args.nhid_actor = 256
    args.nhid_critic = 256
    # Miscellaneous
    args.results_dir = "./results"

    # Adam
    args.betas = [0, 0.999]

    args.device = torch.device("cpu")
    args.algo = "AVG"

    # Start experiment
    set_one_thread()
    main(args)