In [None]:
!sudo apt install curl git libgl1-mesa-dev libgl1-mesa-glx libglew-dev \
        libosmesa6-dev software-properties-common net-tools unzip vim \
        virtualenv wget xpra xserver-xorg-dev libglfw3-dev patchelf

In [None]:
!pip install robosuite

In [None]:
import numpy as np
import robosuite as suite
import os
import time
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as D
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler
from torch.utils.tensorboard import SummaryWriter

In [None]:
import os
import sys
import shutil
import os.path as osp
import json
import time
import datetime
import tempfile
from collections import defaultdict
from contextlib import contextmanager

# Log severity levels. Logger.log emits a message only when the logger's
# threshold is <= the level passed with the message.
DEBUG = 10
INFO = 20
WARN = 30
ERROR = 40

# Higher than every level used by the debug/info/warn/error helpers.
DISABLED = 50

class KVWriter(object):
    """Interface for output formats that can write a dict of key/value diagnostics."""
    def writekvs(self, kvs):
        # Must be overridden by subclasses.
        raise NotImplementedError

class SeqWriter(object):
    """Interface for output formats that can write a sequence of strings as one log line."""
    def writeseq(self, seq):
        # Must be overridden by subclasses.
        raise NotImplementedError

class HumanOutputFormat(KVWriter, SeqWriter):
    """
    Human-readable output: key/value diagnostics are rendered as an ASCII
    table, log sequences as space-separated text lines.
    """
    def __init__(self, filename_or_file):
        """
        filename_or_file : a path (opened with mode 'wt' and closed by close())
                           or an already-open writable stream (left open by close()).
        """
        if isinstance(filename_or_file, str):
            self.file = open(filename_or_file, 'wt')
            self.own_file = True
        else:
            # The stream is only ever written to, so require `write` (not `read`);
            # write-only streams are perfectly valid targets.
            assert hasattr(filename_or_file, 'write'), 'expected file or str, got %s' % (filename_or_file,)
            self.file = filename_or_file
            self.own_file = False

    def writekvs(self, kvs):
        """Write `kvs` as a two-column table; prints a warning and writes nothing if it is empty."""
        # Create strings for printing. Keys and values longer than 30
        # characters are truncated, so long keys sharing the same first 27
        # characters end up in a single row.
        key2str = {}
        for (key, val) in sorted(kvs.items()):
            if hasattr(val, '__float__'):
                valstr = '%-8.3g' % val
            else:
                valstr = str(val)
            key2str[self._truncate(key)] = self._truncate(valstr)

        if len(key2str) == 0:
            print('WARNING: tried to write empty key-value dict')
            return

        # Find max widths
        keywidth = max(map(len, key2str.keys()))
        valwidth = max(map(len, key2str.values()))

        # Write out the data; 7 = width of the '| ', ' | ' and ' |' decorations.
        dashes = '-' * (keywidth + valwidth + 7)
        lines = [dashes]
        for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()):
            lines.append('| %s%s | %s%s |' % (
                key,
                ' ' * (keywidth - len(key)),
                val,
                ' ' * (valwidth - len(val)),
            ))
        lines.append(dashes)
        self.file.write('\n'.join(lines) + '\n')

        # Flush the output to the file
        self.file.flush()

    def _truncate(self, s):
        """Return `s` unchanged if it has at most 30 characters, else its first 27 followed by '...'."""
        maxlen = 30
        return s[:maxlen-3] + '...' if len(s) > maxlen else s

    def writeseq(self, seq):
        """Write the strings in `seq` separated by single spaces, followed by a newline."""
        self.file.write(' '.join(seq) + '\n')
        self.file.flush()

    def close(self):
        """Close the underlying file, but only if this object opened it."""
        if self.own_file:
            self.file.close()

class JSONOutputFormat(KVWriter):
    """Writes each set of diagnostics as one JSON object per line."""
    def __init__(self, filename):
        """Open `filename` for writing in text mode."""
        self.file = open(filename, 'wt')

    def writekvs(self, kvs):
        """
        Append `kvs` as a single JSON line and flush.

        Values exposing a `dtype` attribute (e.g. numpy / torch scalars) are
        converted with float() so json can encode them. The conversion is done
        on a copy: the caller's dict (the logger's live name2val when called
        from Logger.dumpkvs) is left untouched.
        """
        serializable = {
            k: (float(v) if hasattr(v, 'dtype') else v)
            for k, v in kvs.items()
        }
        self.file.write(json.dumps(serializable) + '\n')
        self.file.flush()

    def close(self):
        """Close the output file."""
        self.file.close()

class CSVOutputFormat(KVWriter):
    """
    Writes diagnostics as CSV rows. The header grows as new keys appear;
    when it does, the file is rewritten in place with the wider header and
    earlier rows are padded with empty cells.
    """
    def __init__(self, filename):
        """Open `filename` for reading and writing (mode 'w+t') with an empty header."""
        self.file = open(filename, 'w+t')
        self.keys = []
        self.sep = ','

    def writekvs(self, kvs):
        """Append one row for `kvs`; keys absent from `kvs` produce empty cells."""
        # Keys not seen before are appended to the header in sorted order.
        extra_keys = sorted(kvs.keys() - self.keys)
        if extra_keys:
            self.keys.extend(extra_keys)
            self.file.seek(0)
            lines = self.file.readlines()
            self.file.seek(0)
            # Use self.sep everywhere so header, padding and rows always agree.
            self.file.write(self.sep.join(self.keys) + '\n')
            padding = self.sep * len(extra_keys)
            # lines[0] is the old header; every other line is an old row whose
            # trailing newline is dropped before the empty cells are appended.
            for line in lines[1:]:
                self.file.write(line[:-1] + padding + '\n')
        cells = []
        for k in self.keys:
            v = kvs.get(k)
            cells.append('' if v is None else str(v))
        self.file.write(self.sep.join(cells) + '\n')
        self.file.flush()

    def close(self):
        """Close the output file."""
        self.file.close()


class TensorBoardOutputFormat(KVWriter):
    """
    Dumps key/value pairs into TensorBoard's numeric format.
    """
    def __init__(self, dir):
        # Creates `dir` if needed and opens an events writer whose file path
        # starts with '<abs dir>/events'.
        os.makedirs(dir, exist_ok=True)
        self.dir = dir
        self.step = 1  # step stamped on the next event; incremented by writekvs
        prefix = 'events'
        path = osp.join(osp.abspath(dir), prefix)
        # TensorFlow is imported lazily so it is only required when this format is used.
        # NOTE(review): tf.Summary / pywrap_tensorflow.EventsWriter look like
        # TensorFlow 1.x APIs -- confirm they exist in the installed version.
        import tensorflow as tf
        from tensorflow.python import pywrap_tensorflow
        from tensorflow.core.util import event_pb2
        from tensorflow.python.util import compat
        self.tf = tf
        self.event_pb2 = event_pb2
        self.pywrap_tensorflow = pywrap_tensorflow
        self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path))

    def writekvs(self, kvs):
        # Writes all entries of `kvs` as one event at the current step;
        # every value must be convertible with float().
        def summary_val(k, v):
            kwargs = {'tag': k, 'simple_value': float(v)}
            return self.tf.Summary.Value(**kwargs)
        summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()])
        event = self.event_pb2.Event(wall_time=time.time(), summary=summary)
        event.step = self.step # is there any reason why you'd want to specify the step?
        self.writer.WriteEvent(event)
        self.writer.Flush()
        self.step += 1

    def close(self):
        # Safe to call more than once: the writer is dropped after closing.
        if self.writer:
            self.writer.Close()
            self.writer = None

def make_output_format(format, ev_dir, log_suffix=''):
    """
    Build the output format object named by `format`.

    format     : one of 'stdout', 'log', 'json', 'csv', 'tensorboard'
    ev_dir     : directory for file-backed formats (created if missing)
    log_suffix : string inserted into the generated file / directory names

    Raises ValueError for any other `format`.
    """
    os.makedirs(ev_dir, exist_ok=True)
    if format == 'stdout':
        return HumanOutputFormat(sys.stdout)

    # (name, factory) pairs; factories are lambdas so nothing is opened
    # until the matching name is found.
    file_backed = (
        ('log', lambda: HumanOutputFormat(osp.join(ev_dir, 'log%s.txt' % log_suffix))),
        ('json', lambda: JSONOutputFormat(osp.join(ev_dir, 'progress%s.json' % log_suffix))),
        ('csv', lambda: CSVOutputFormat(osp.join(ev_dir, 'progress%s.csv' % log_suffix))),
        ('tensorboard', lambda: TensorBoardOutputFormat(osp.join(ev_dir, 'tb%s' % log_suffix))),
    )
    for name, build in file_backed:
        if format == name:
            return build()
    raise ValueError('Unknown format specified: %s' % (format,))

# ================================================================
# API
# ================================================================

def logkv(key, val):
    """
    Log a value of some diagnostic on the current logger.
    Call this once for each diagnostic quantity, each iteration.
    If called many times before dumpkvs(), the last value is used.
    """
    get_current().logkv(key, val)

def logkv_mean(key, val):
    """
    The same as logkv(), but if called many times before dumpkvs(),
    the values are averaged.
    """
    get_current().logkv_mean(key, val)

def logkvs(d):
    """
    Log a dictionary of key-value pairs by calling logkv() on each item.
    """
    for (k, v) in d.items():
        logkv(k, v)

def dumpkvs():
    """
    Write all of the diagnostics from the current iteration and return
    whatever the current logger's dumpkvs() returns.
    """
    return get_current().dumpkvs()

def getkvs():
    """Return the current logger's name2val mapping (the object itself, not a copy)."""
    return get_current().name2val


def log(*args, level=INFO):
    """
    Forward the sequence of args to the current logger, which hands them to
    its outputs when `level` is at or above the logger's threshold.
    """
    get_current().log(*args, level=level)

def debug(*args):
    """Log `args` at DEBUG level."""
    log(*args, level=DEBUG)

def info(*args):
    """Log `args` at INFO level."""
    log(*args, level=INFO)

def warn(*args):
    """Log `args` at WARN level."""
    log(*args, level=WARN)

def error(*args):
    """Log `args` at ERROR level."""
    log(*args, level=ERROR)


def set_level(level):
    """
    Set logging threshold on current logger.
    """
    get_current().set_level(level)

def set_comm(comm):
    """Set the communicator the current logger uses when averaging stats in dumpkvs()."""
    get_current().set_comm(comm)

def get_dir():
    """
    Get the directory the current logger was constructed with
    (the directory that log files are being written to).
    """
    return get_current().get_dir()

# Alternative names for logkv / dumpkvs.
record_tabular = logkv
dump_tabular = dumpkvs

@contextmanager
def profile_kv(scopename):
    """
    Context manager that adds the wall-clock seconds spent inside the
    `with` body to the current logger's 'wait_<scopename>' entry.
    """
    logkey = 'wait_' + scopename
    tstart = time.time()
    try:
        yield
    finally:
        # Runs even if the body raised, so the elapsed time is always recorded.
        get_current().name2val[logkey] += time.time() - tstart

def profile(n):
    """
    Decorator factory that times every call of the decorated function via
    profile_kv(n), i.e. accumulates its duration under the key 'wait_<n>'.

    Usage:
    @profile("my_func")
    def my_func(): code
    """
    import functools

    def decorator_with_name(func):
        # functools.wraps keeps the wrapped function's __name__, __doc__ etc.
        @functools.wraps(func)
        def func_wrapper(*args, **kwargs):
            with profile_kv(n):
                return func(*args, **kwargs)
        return func_wrapper
    return decorator_with_name


# ================================================================
# Backend
# ================================================================

def get_current():
    """Return Logger.CURRENT, configuring the default logger first if none exists yet."""
    if Logger.CURRENT is None:
        _configure_default_logger()

    return Logger.CURRENT


class Logger(object):
    """
    Collects per-iteration diagnostics and forwards them, together with
    free-form log lines, to a list of output formats.
    """
    DEFAULT = None  # logger that reset() falls back to
    CURRENT = None  # logger used by the module-level helper functions

    def __init__(self, dir, output_formats, comm=None):
        self.dir = dir
        self.output_formats = output_formats
        self.comm = comm
        self.level = INFO
        # Diagnostics of the current iteration and how many samples went
        # into each running mean.
        self.name2val = defaultdict(float)
        self.name2cnt = defaultdict(int)

    # Logging API, forwarded
    # ----------------------------------------
    def logkv(self, key, val):
        """Store `val` under `key`, replacing any earlier value."""
        self.name2val[key] = val

    def logkv_mean(self, key, val):
        """Fold `val` into the running mean stored under `key`."""
        running = self.name2val[key]
        seen = self.name2cnt[key]
        self.name2val[key] = running*seen/(seen+1) + val/(seen+1)
        self.name2cnt[key] = seen + 1

    def dumpkvs(self):
        """
        Write the collected diagnostics to every KVWriter output, clear
        them, and return a copy of what was written.
        """
        if self.comm is not None:
            from baselines.common import mpi_util
            weighted = {
                name: (val, self.name2cnt.get(name, 1))
                for (name, val) in self.name2val.items()
            }
            d = mpi_util.mpi_weighted_mean(self.comm, weighted)
            if self.comm.rank != 0:
                d['dummy'] = 1 # so we don't get a warning about empty dict
        else:
            d = self.name2val
        out = d.copy() # Return the dict for unit testing purposes
        kv_outputs = [fmt for fmt in self.output_formats if isinstance(fmt, KVWriter)]
        for fmt in kv_outputs:
            fmt.writekvs(d)
        self.name2val.clear()
        self.name2cnt.clear()
        return out

    def log(self, *args, level=INFO):
        """Emit `args` as one log line if `level` passes the threshold."""
        if self.level <= level:
            self._do_log(args)

    # Configuration
    # ----------------------------------------
    def set_level(self, level):
        """Set the logging threshold."""
        self.level = level

    def set_comm(self, comm):
        """Set the communicator used by dumpkvs()."""
        self.comm = comm

    def get_dir(self):
        """Return the directory given at construction."""
        return self.dir

    def close(self):
        """Close every output format."""
        for fmt in self.output_formats:
            fmt.close()

    # Misc
    # ----------------------------------------
    def _do_log(self, args):
        """Hand the stringified `args` to every SeqWriter output."""
        for fmt in self.output_formats:
            if not isinstance(fmt, SeqWriter):
                continue
            fmt.writeseq(map(str, args))

def get_rank_without_mpi_import():
    """
    Return the MPI rank read from the environment, or 0 if none is set.

    Environment variables are inspected instead of importing mpi4py, which
    avoids calling MPI_Init() when this module is imported. 'PMI_RANK' takes
    precedence over 'OMPI_COMM_WORLD_RANK'.
    """
    rank_vars = ('PMI_RANK', 'OMPI_COMM_WORLD_RANK')
    raw_rank = next((os.environ[name] for name in rank_vars if name in os.environ), None)
    return 0 if raw_rank is None else int(raw_rank)


def baselines_configure(dir=None, format_strs=None, comm=None, log_suffix=''):
    """
    Create a Logger and install it as Logger.CURRENT.

    dir         : log directory; defaults to $OPENAI_LOGDIR, else a
                  timestamped directory under the system temp dir
    format_strs : output format names; defaults to $OPENAI_LOG_FORMAT
                  ('stdout,log,csv') on rank 0 and $OPENAI_LOG_FORMAT_MPI
                  ('log') on other ranks
    comm        : if provided, average all numerical stats across that comm
    log_suffix  : suffix for generated file names; '-rankNNN' is appended
                  on ranks > 0
    """
    if dir is None:
        dir = os.getenv('OPENAI_LOGDIR')
    if dir is None:
        stamp = datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f")
        dir = osp.join(tempfile.gettempdir(), stamp)
    assert isinstance(dir, str)
    dir = osp.expanduser(dir)
    os.makedirs(dir, exist_ok=True)

    rank = get_rank_without_mpi_import()
    if rank > 0:
        log_suffix = log_suffix + "-rank%03i" % rank

    if format_strs is None:
        if rank == 0:
            env_name, fallback = 'OPENAI_LOG_FORMAT', 'stdout,log,csv'
        else:
            env_name, fallback = 'OPENAI_LOG_FORMAT_MPI', 'log'
        format_strs = os.getenv(env_name, fallback).split(',')
    # Empty names (e.g. from a trailing comma) are skipped.
    output_formats = [make_output_format(name, dir, log_suffix)
                      for name in format_strs if name]

    Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm)
    if output_formats:
        log('Logging to %s'%dir)

def _configure_default_logger():
    """Configure a logger with all defaults and remember it as Logger.DEFAULT."""
    baselines_configure()
    Logger.DEFAULT = Logger.CURRENT

def reset():
    """
    If the current logger is not the default one, close it, make
    Logger.DEFAULT current again and log 'Reset logger'.
    """
    if Logger.CURRENT is not Logger.DEFAULT:
        Logger.CURRENT.close()
        Logger.CURRENT = Logger.DEFAULT
        log('Reset logger')

@contextmanager
def scoped_configure(dir=None, format_strs=None, comm=None):
    """
    Context manager that installs a freshly configured logger for the
    duration of the `with` body, then closes it and restores the previous one.
    """
    prevlogger = Logger.CURRENT
    baselines_configure(dir=dir, format_strs=format_strs, comm=comm)
    try:
        yield
    finally:
        # Restore the previous logger even if the body raised.
        Logger.CURRENT.close()
        Logger.CURRENT = prevlogger

# ================================================================

def _demo():
    """
    Smoke test of the logging API: exercises level filtering, logkv /
    logkv_mean / dumpkvs and value truncation, writing to /tmp/testlogging
    (which is deleted first if it already exists).
    """
    info("hi")
    debug("shouldn't appear")
    set_level(DEBUG)
    debug("should appear")
    dir = "/tmp/testlogging"
    if os.path.exists(dir):
        shutil.rmtree(dir)
    # Replaces the current logger with one writing into `dir`.
    baselines_configure(dir=dir)
    logkv("a", 3)
    logkv("b", 2.5)
    dumpkvs()
    logkv("b", -2.5)
    logkv("a", 5.5)
    dumpkvs()
    info("^^^ should see a = 5.5")
    logkv_mean("b", -22.5)
    logkv_mean("b", -44.4)
    logkv("a", 5.5)
    dumpkvs()
    info("^^^ should see b = -33.3")

    logkv("b", -2.5)
    dumpkvs()

    # Longer than 30 characters, so the human-readable table truncates it.
    logkv("a", "longasslongasslongasslongasslongasslongassvalue")
    dumpkvs()


# ================================================================
# Readers
# ================================================================

def read_json(fname):
    """
    Load a file containing one JSON object per line into a pandas
    DataFrame (one row per line).
    """
    import pandas
    with open(fname, 'rt') as fh:
        records = [json.loads(line) for line in fh]
    return pandas.DataFrame(records)

def read_csv(fname):
    """Load a CSV file into a pandas DataFrame, treating '#' as the comment character."""
    import pandas
    return pandas.read_csv(fname, index_col=None, comment='#')

def read_tb(path):
    """
    path : a tensorboard file OR a directory, where we will find all TB files
           of the form events.*

    Returns a pandas DataFrame with one column per tag (sorted by name) and
    one row per step (row i holds step i+1); missing values are NaN.
    Raises NotImplementedError if `path` is neither a directory nor a file
    whose name starts with 'events.'.
    """
    import pandas
    import numpy as np
    from glob import glob
    # NOTE(review): tf.train.summary_iterator looks like a TensorFlow 1.x
    # API -- confirm it exists in the installed version.
    import tensorflow as tf
    if osp.isdir(path):
        fnames = glob(osp.join(path, "events.*"))
    elif osp.basename(path).startswith("events."):
        fnames = [path]
    else:
        raise NotImplementedError("Expected tensorboard file or directory containing them. Got %s"%path)
    tag2pairs = defaultdict(list)
    maxstep = 0
    for fname in fnames:
        for summary in tf.train.summary_iterator(fname):
            # Events with step 0 are skipped.
            if summary.step > 0:
                for v in summary.summary.value:
                    pair = (summary.step, v.simple_value)
                    tag2pairs[v.tag].append(pair)
                maxstep = max(summary.step, maxstep)
    # Dense (steps x tags) table, NaN where a tag has no value at a step.
    data = np.empty((maxstep, len(tag2pairs)))
    data[:] = np.nan
    tags = sorted(tag2pairs.keys())
    for (colidx,tag) in enumerate(tags):
        pairs = tag2pairs[tag]
        for (step, value) in pairs:
            # Steps are 1-based, rows are 0-based.
            data[step-1, colidx] = value
    return pandas.DataFrame(data, columns=tags)

# NOTE(review): inside a notebook kernel __name__ is normally "__main__", so
# running this cell presumably executes _demo() (which writes to
# /tmp/testlogging and replaces the current logger) -- confirm this is intended.
if __name__ == "__main__":
    _demo()

In [None]:
def append_human_init(self, filename_or_file):
    """
    Replacement __init__ for HumanOutputFormat that opens paths in append
    mode ('at') so an existing log file is extended instead of truncated.

    filename_or_file : a path, or an already-open writable stream (which
                       is then not owned, i.e. not closed, by the format).
    """
    if isinstance(filename_or_file, str):
        self.file = open(filename_or_file, 'at')
        self.own_file = True
    else:
        # The stream is only written to, so require `write` rather than `read`.
        assert hasattr(filename_or_file, 'write'), 'expected file or str, got %s' % (filename_or_file,)
        self.file = filename_or_file
        self.own_file = False

def append_json_init(self, filename):
    """Replacement __init__ for JSONOutputFormat that opens the file in append mode ('at')."""
    self.file = open(filename, 'at')

def append_csv_init(self, filename):
    """
    Replacement __init__ for CSVOutputFormat that continues an existing CSV
    file instead of truncating it.

    The file is opened for random-access read/write ('r+t', or 'w+t' when it
    does not exist yet) rather than in append mode: CSVOutputFormat.writekvs
    rewrites the header in place after seek(0), and in append mode every
    write lands at the end of the file, which would duplicate the content.
    The existing header, if any, is loaded into self.keys so it is not
    treated as a set of new columns on the first write.
    """
    self.sep = ','
    self.keys = []
    mode = 'r+t' if os.path.exists(filename) else 'w+t'
    self.file = open(filename, mode)
    header = self.file.readline().rstrip('\n')
    if header:
        self.keys = header.split(self.sep)
    # Position at the end so new rows follow the existing ones.
    self.file.seek(0, os.SEEK_END)

# Monkey-patch the constructors of the file-backed formats so that existing
# log files are kept rather than truncated when a logger is configured.
HumanOutputFormat.__init__ = append_human_init
JSONOutputFormat.__init__ = append_json_init
CSVOutputFormat.__init__ = append_csv_init

# Module-level SummaryWriter shared by the add_* helpers; stays None until
# configure() is called with tbX=True.
WRITER = None
def configure(log_dir, format_strs=None, tbX=False, **kwargs):
    """
    Configure logging into `log_dir`.

    log_dir     : directory handed to baselines_configure
    format_strs : output format names handed to baselines_configure
    tbX         : if true, also create the global SummaryWriter writing to
                  '<log_dir>/tensorboard'; otherwise the global is reset to None
    kwargs      : extra keyword arguments for SummaryWriter
    """
    global WRITER
    WRITER = SummaryWriter(os.path.join(log_dir, 'tensorboard'), **kwargs) if tbX else None
    baselines_configure(log_dir, format_strs)

def get_summary_writer():
    """Return the global SummaryWriter, or None if configure() did not create one."""
    return WRITER

def add_scalar(tag, scalar_value, global_step=None, walltime=None):
    """Forward to WRITER.add_scalar; fails with AssertionError if no writer was configured."""
    assert WRITER is not None, "call configure to initialize SummaryWriter"
    WRITER.add_scalar(tag, scalar_value, global_step, walltime)
    # change interface so both add_scalar and add_scalars adds to the scalar dict.
    #WRITER._SummaryWriter__append_to_scalar_dict(tag, scalar_value, global_step, walltime)

"""
def add_scalars(main_tag, tag_scalar_dict, global_step=None, walltime=None):
    assert WRITER is not None, "call configure to initialize SummaryWriter"
    WRITER.add_scalars(main_tag, tag_scalar_dict, global_step, walltime)
"""

def add_histogram(tag, values, global_step=None, bins='tensorflow'):#, walltime=None, max_bins=None):
    """Forward to WRITER.add_histogram; fails with AssertionError if no writer was configured."""
    assert WRITER is not None, "call configure to initialize SummaryWriter"
    WRITER.add_histogram(tag, values, global_step, bins)#, walltime, max_bins)

def add_image(tag, img_tensor, global_step=None, walltime=None, dataformats='CHW'):
    """Forward to WRITER.add_image; fails with AssertionError if no writer was configured."""
    assert WRITER is not None, "call configure to initialize SummaryWriter"
    WRITER.add_image(tag, img_tensor, global_step, walltime, dataformats)

def add_images(tag, img_tensor, global_step=None, walltime=None, dataformats='NCHW'):
    """Forward to WRITER.add_images; fails with AssertionError if no writer was configured."""
    assert WRITER is not None, "call configure to initialize SummaryWriter"
    WRITER.add_images(tag, img_tensor, global_step, walltime, dataformats)

def add_image_with_boxes(tag, img_tensor, box_tensor, global_step=None,
                             walltime=None, dataformats='CHW', **kwargs):
    """
    Forward to WRITER.add_image_with_boxes; fails with AssertionError if no
    writer was configured.

    The optional arguments are passed by keyword: the positional order after
    `walltime` differs between SummaryWriter implementations
    (torch.utils.tensorboard has `rescale` before `dataformats`), so passing
    `dataformats` positionally would bind it to the wrong parameter.
    """
    assert WRITER is not None, "call configure to initialize SummaryWriter"
    WRITER.add_image_with_boxes(tag, img_tensor, box_tensor,
                                global_step=global_step, walltime=walltime,
                                dataformats=dataformats, **kwargs)

def add_figure(tag, figure, global_step=None, close=True, walltime=None):
    """Forward to WRITER.add_figure; fails with AssertionError if no writer was configured."""
    assert WRITER is not None, "call configure to initialize SummaryWriter"
    WRITER.add_figure(tag, figure, global_step, close, walltime)

def add_video(tag, vid_tensor, global_step=None, fps=4, walltime=None):
    """Forward to WRITER.add_video; fails with AssertionError if no writer was configured."""
    assert WRITER is not None, "call configure to initialize SummaryWriter"
    WRITER.add_video(tag, vid_tensor, global_step, fps, walltime)

def add_audio(tag, snd_tensor, global_step=None, sample_rate=44100, walltime=None):
    """Forward to WRITER.add_audio; fails with AssertionError if no writer was configured."""
    assert WRITER is not None, "call configure to initialize SummaryWriter"
    WRITER.add_audio(tag, snd_tensor, global_step, sample_rate, walltime)

def add_text(tag, text_string, global_step=None, walltime=None):
    """Forward to WRITER.add_text; fails with AssertionError if no writer was configured."""
    assert WRITER is not None, "call configure to initialize SummaryWriter"
    WRITER.add_text(tag, text_string, global_step, walltime)

def add_graph(model, input_to_model=None, verbose=False, **kwargs):
    """Forward to WRITER.add_graph; fails with AssertionError if no writer was configured."""
    assert WRITER is not None, "call configure to initialize SummaryWriter"
    WRITER.add_graph(model, input_to_model, verbose, **kwargs)


def export_scalars(fname, overwrite=False):
    """
    Export the writer's scalars to '<WRITER.log_dir>/scalar_data/<fname>'.

    fname     : file name; '.json' is appended unless it already ends with it
    overwrite : if False, an existing file is left untouched

    Fails with AssertionError if no writer was configured.
    """
    assert WRITER is not None, "call configure to initialize SummaryWriter"
    out_dir = os.path.join(WRITER.log_dir, 'scalar_data')
    os.makedirs(out_dir, exist_ok=True)
    # Check the full '.json' extension; comparing only the last four
    # characters treated names such as 'myjson' as already having it.
    if not fname.endswith('.json'):
        fname += '.json'
    fname = os.path.join(out_dir, fname)
    if overwrite or not os.path.exists(fname):
        # NOTE(review): export_scalars_to_json is a tensorboardX method --
        # confirm the SummaryWriter in use provides it.
        WRITER.export_scalars_to_json(fname)

In [None]:
class Normal(D.Normal):
    """torch Normal distribution with per-sample (summed over the last dim) log-prob and entropy."""

    def mode(self):
        """Return the mean of the distribution."""
        return self.mean

    def log_probs(self, action):
        """Sum the element-wise log-probabilities of `action` over the last dim, keeping that dim."""
        elementwise = super().log_prob(action)
        return elementwise.sum(-1, keepdim=True)

    def entropy(self):
        """Sum the element-wise entropies over the last dim."""
        elementwise = super().entropy()
        return elementwise.sum(-1)

In [None]:
class AddBias(nn.Module):
    """Adds a learnable per-channel bias (stored as a column vector) to its input."""

    def __init__(self, bias):
        super().__init__()
        # 1-D `bias` of length C is stored with shape (C, 1).
        self._bias = nn.Parameter(bias.unsqueeze(1))

    def forward(self, x):
        # 2-D inputs get a (1, C) bias, anything else a (1, C, 1, 1) bias.
        target_shape = (1, -1) if x.dim() == 2 else (1, -1, 1, 1)
        return x + self._bias.t().view(*target_shape)

In [None]:
def init(module, weight_init, bias_init, gain=1):
    """
    Initialise `module` in place and return it.

    weight_init is called as weight_init(module.weight.data, gain=gain),
    bias_init as bias_init(module.bias.data).
    """
    weights, biases = module.weight.data, module.bias.data
    weight_init(weights, gain=gain)
    bias_init(biases)
    return module

In [None]:
def update_linear_schedule(optimizer, update, total_num_updates, initial_lr):
    """
    Decrease the learning rate linearly: every param group of `optimizer`
    gets initial_lr * (1 - update / total_num_updates).
    """
    progress = update / float(total_num_updates)
    decayed_lr = initial_lr - (initial_lr * progress)
    for group in optimizer.param_groups:
        group['lr'] = decayed_lr

In [None]:
class DiagGaussian(nn.Module):
    """
    Diagonal Gaussian head: a linear layer produces the mean and a
    learnable, input-independent log-std (an AddBias initialised to zeros)
    produces the standard deviation.
    """
    def __init__(self, num_inputs, num_outputs):
        super(DiagGaussian, self).__init__()

        # Orthogonal weights, zero biases.
        init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init.
                               constant_(x, 0))

        self.fc_mean = init_(nn.Linear(num_inputs, num_outputs))
        self.logstd = AddBias(torch.zeros(num_outputs))

    def forward(self, x):
        """Return a Normal whose mean is fc_mean(x) and whose std is exp(logstd)."""
        action_mean = self.fc_mean(x)

        # Feeding zeros through AddBias broadcasts the log-std to the shape of
        # the mean. zeros_like places them on the same device and dtype as the
        # mean, which also covers non-default CUDA devices and other backends.
        zeros = torch.zeros_like(action_mean)

        action_logstd = self.logstd(zeros)
        return Normal(action_mean, action_logstd.exp())

In [None]:
class Rollouts(object):
    """
    Fixed-length rollout storage for on-policy training.

    Buffers indexed by time step have a leading dimension of `num_steps`
    (rewards, actions, log-probs, ...) or `num_steps + 1` (obs, masks,
    value_preds, returns), followed by `num_processes`.
    """
    def __init__(self,
                 num_steps,
                 num_processes,
                 obs_shape,
                 action_space,
                 device=None,
                 use_gae=False):
        """
        num_steps     : number of transitions stored per process
        num_processes : number of parallel environments
        obs_shape     : shape of a single observation
        action_space  : object whose class is named 'Discrete' (one integer
                        action per step) or that has a `shape` attribute
        device        : device the buffers are moved to ('cpu' if None)
        use_gae       : accepted but not used by this constructor
        """
        self.obs = torch.zeros(num_steps + 1, num_processes, *obs_shape)
        self.masks = torch.ones(num_steps + 1, num_processes, 1)
        self.value_preds = torch.zeros(num_steps + 1, num_processes, 1)
        self.returns = torch.zeros(num_steps + 1, num_processes, 1)

        self.rewards = torch.zeros(num_steps, num_processes, 1)
        self.intrinsic_rewards = torch.zeros(num_steps, num_processes, 1)
        self.selfish_extrinsic_rewards = torch.zeros(num_steps, num_processes, 1)
        self.agentHolding = torch.zeros(num_steps, num_processes, 1)
        self.action_log_probs = torch.zeros(num_steps, num_processes, 1)

        is_discrete = action_space.__class__.__name__ == 'Discrete'
        action_shape = 1 if is_discrete else action_space.shape[0]
        self.actions = torch.zeros(num_steps, num_processes, action_shape)
        if is_discrete:
            self.actions = self.actions.long()

        self.device = device if device is not None else 'cpu'
        self.to(self.device)

        self.num_steps = num_steps
        self.step = 0

    def to(self, device=None):
        """Move every buffer to `device` (self.device if None)."""
        if device is None:
            device = self.device
        self.obs = self.obs.to(device)
        self.rewards = self.rewards.to(device)
        self.intrinsic_rewards = self.intrinsic_rewards.to(device)
        self.selfish_extrinsic_rewards = self.selfish_extrinsic_rewards.to(device)
        self.value_preds = self.value_preds.to(device)
        self.returns = self.returns.to(device)
        self.action_log_probs = self.action_log_probs.to(device)
        self.actions = self.actions.to(device)
        # Assign back to agentHolding itself so the buffer really moves
        # (a misspelled attribute name previously left it on the old device).
        self.agentHolding = self.agentHolding.to(device)
        self.masks = self.masks.to(device)

    def insert(self, obs, actions, action_log_probs, value_preds, rewards, intrinsic_rewards, selfish_extrinsic_rewards, masks, agentHolding):
        """
        Store one transition at the current step and advance the step
        counter (wrapping around after `num_steps`). `obs` and `masks`
        belong to the next step and are written at index step + 1.
        """
        self.obs[self.step + 1].copy_(obs)
        self.actions[self.step].copy_(actions)
        self.action_log_probs[self.step].copy_(action_log_probs)
        self.value_preds[self.step].copy_(value_preds)
        self.rewards[self.step].copy_(rewards)
        self.intrinsic_rewards[self.step].copy_(intrinsic_rewards)
        self.selfish_extrinsic_rewards[self.step].copy_(selfish_extrinsic_rewards)
        self.agentHolding[self.step].copy_(agentHolding)
        self.masks[self.step + 1].copy_(masks)

        self.step = (self.step + 1) % self.num_steps

    def after_update(self):
        """
        After updating move the last observation and mask
        to the beginning of the rollout storage
        """
        self.obs[0].copy_(self.obs[-1])
        self.masks[0].copy_(self.masks[-1])

    def compute_returns(self, next_value, gamma=0.99, use_gae=True, gae_lambda=0.95):
        """
        Fill self.returns from the stored rewards, masks and value predictions.

        next_value : value estimate for the observation after the last step
        gamma      : discount factor
        use_gae    : use generalized advantage estimation if True, plain
                     discounted returns otherwise
        gae_lambda : GAE lambda (only used when use_gae is True)
        """
        num_steps = self.rewards.size(0)
        if use_gae:
            self.value_preds[-1] = next_value
            gae = 0
            for step in reversed(range(num_steps)):
                # masks[step + 1] multiplies the bootstrap term, so a zero
                # mask cuts the return off at that step.
                delta = (self.rewards[step]
                         + gamma * self.value_preds[step + 1] * self.masks[step + 1]
                         - self.value_preds[step])
                gae = delta + gamma * gae_lambda * self.masks[step + 1] * gae
                self.returns[step] = gae + self.value_preds[step]
        else:
            self.returns[-1] = next_value
            for step in reversed(range(num_steps)):
                self.returns[step] = self.returns[step + 1] * \
                    gamma * self.masks[step + 1] + self.rewards[step]

    def _minibatch_sampler(self, num_mini_batch):
        """Return a sampler yielding random index lists that partition the flattened batch."""
        num_steps, num_processes = self.rewards.size()[0:2]
        batch_size = num_steps * num_processes
        # there must be at least one sample per mini batch
        assert batch_size >= num_mini_batch
        mini_batch_size = batch_size // num_mini_batch
        # The last partition is kept even if it is smaller than mini_batch_size.
        return BatchSampler(SubsetRandomSampler(range(batch_size)), mini_batch_size, drop_last=False)

    def feed_forward_generator(self, advantages, num_mini_batch):
        """
        Yield random mini batches as tuples
        (obs, actions, next_obs, value_preds, returns, masks,
         old_action_log_probs, advantages),
        each flattened over the step and process dimensions.
        """
        for indices in self._minibatch_sampler(num_mini_batch):
            obs_batch = self.obs[:-1].view(-1, *self.obs.size()[2:])[indices]
            actions_batch = self.actions.view(-1, *self.actions.size()[2:])[indices]
            next_obs_batch = self.obs[1:].view(-1, *self.obs.size()[2:])[indices]
            value_preds_batch = self.value_preds[:-1].view(-1, 1)[indices]
            return_batch = self.returns[:-1].view(-1, 1)[indices]
            masks_batch = self.masks[:-1].view(-1, 1)[indices]
            old_action_log_probs_batch = self.action_log_probs.view(-1, 1)[indices]
            adv_target = advantages.view(-1, 1)[indices]

            yield obs_batch, actions_batch, next_obs_batch, value_preds_batch, return_batch, masks_batch, old_action_log_probs_batch, adv_target

    def curiosity_generator(self, num_mini_batch):
        """Yield random mini batches as (obs, actions, next_obs) tuples."""
        for indices in self._minibatch_sampler(num_mini_batch):
            obs_batch = self.obs[:-1].view(-1, *self.obs.size()[2:])[indices]
            next_obs_batch = self.obs[1:].view(-1, *self.obs.size()[2:])[indices]
            actions_batch = self.actions.view(-1, *self.actions.size()[2:])[indices]

            yield obs_batch, actions_batch, next_obs_batch

class MultimodalRollouts(Rollouts):
    """
    Rollout storage that additionally keeps two image streams, two depth
    streams and a contact signal, each with num_steps + 1 entries.
    """
    def __init__(self, num_steps, num_processes,
                 obs_shape, action_space, im_shape, depth_shape, contact_shape,
                 device=None, use_gae=False):
        # The extra buffers must exist before the parent constructor runs,
        # because it calls self.to(...), which moves them as well.
        self.image1 = torch.zeros(num_steps + 1, num_processes, *im_shape)
        self.image2 = torch.zeros(num_steps + 1, num_processes, *im_shape)
        self.depth1 = torch.zeros(num_steps + 1, num_processes, *depth_shape)
        self.depth2 = torch.zeros(num_steps + 1, num_processes, *depth_shape)
        self.contact = torch.zeros(num_steps + 1, num_processes, *contact_shape)
        super(MultimodalRollouts, self).__init__(num_steps, num_processes, obs_shape, action_space, device, use_gae)

    def insert(self, obs, actions, action_log_probs,
               value_preds, rewards, masks,
               image1, image2, depth1, depth2, contact,
               intrinsic_rewards=None, selfish_extrinsic_rewards=None,
               agentHolding=None):
        """
        Store one transition. The sensor readings belong to the next step
        and are written at index step + 1, like `obs`.

        intrinsic_rewards, selfish_extrinsic_rewards and agentHolding are
        required by the parent's insert(); when omitted here they are stored
        as zeros.
        """
        # Write the sensor data first: the parent call advances self.step.
        self.image1[self.step + 1].copy_(image1)
        self.image2[self.step + 1].copy_(image2)
        self.depth1[self.step + 1].copy_(depth1)
        self.depth2[self.step + 1].copy_(depth2)
        self.contact[self.step + 1].copy_(contact)

        zeros = torch.zeros_like(self.rewards[self.step])
        if intrinsic_rewards is None:
            intrinsic_rewards = zeros
        if selfish_extrinsic_rewards is None:
            selfish_extrinsic_rewards = zeros
        if agentHolding is None:
            agentHolding = zeros
        super(MultimodalRollouts, self).insert(obs, actions, action_log_probs,
                                               value_preds, rewards,
                                               intrinsic_rewards,
                                               selfish_extrinsic_rewards,
                                               masks, agentHolding)

    def to(self, device=None):
        """Move every buffer to `device` (self.device if None)."""
        # Must accept `device`: the parent constructor calls self.to(self.device).
        if device is None:
            device = self.device
        self.image1 = self.image1.to(device)
        self.image2 = self.image2.to(device)
        self.depth1 = self.depth1.to(device)
        self.depth2 = self.depth2.to(device)
        self.contact = self.contact.to(device)
        super(MultimodalRollouts, self).to(device)

    def after_update(self):
        """
        After updating move the last observation, mask and sensor readings
        to the beginning of the rollout storage
        """
        self.image1[0].copy_(self.image1[-1])
        self.image2[0].copy_(self.image2[-1])
        self.depth1[0].copy_(self.depth1[-1])
        self.depth2[0].copy_(self.depth2[-1])
        self.contact[0].copy_(self.contact[-1])
        # The parent copies obs and masks.
        super(MultimodalRollouts, self).after_update()

In [None]:
class ActorCritic(nn.Module):
    """Bundles a Policy network and a ValueFn network that share the same input size."""
    def __init__(self, num_inputs, hidden_size, num_outputs):
        super(ActorCritic, self).__init__()
        self.policy = Policy(num_inputs, hidden_size, num_outputs)
        self.value_fn = ValueFn(num_inputs, hidden_size)

    def forward(self):
        # Use select_action / evaluate_action / get_value instead.
        raise NotImplementedError

    def select_action(self, obs):
        """
        Sample an action for `obs`.

        Returns (value, action, action_log_probs).
        """
        value, actor_features = self.value_fn(obs), self.policy(obs)
        dist = self.policy.dist(actor_features)
        action = dist.sample()
        action_log_probs = dist.log_probs(action)
        return value, action, action_log_probs

    def evaluate_action(self, obs, action):
        """
        Evaluate a given `action` for `obs` under the current policy.

        Returns (value, action_log_probs, entropy), where entropy is the
        mean of the distribution's entropy.
        """
        value, actor_features = self.value_fn(obs), self.policy(obs)
        dist = self.policy.dist(actor_features)
        action_log_probs = dist.log_probs(action)
        entropy = dist.entropy().mean()

        return value, action_log_probs, entropy

    def get_value(self, obs):
        """Return the value network's output for `obs`."""
        return self.value_fn(obs)

class Policy(nn.Module):
    """
    Actor network: Linear -> ReLU -> Linear -> Tanh feature extractor plus a
    DiagGaussian head (`dist`) that is applied separately by the caller.
    """
    def __init__(self, num_inputs, hidden_size, num_outputs):
        super().__init__()

        # Feature extractor layers.
        self.base1 = nn.Linear(num_inputs, hidden_size)
        self.relu = nn.ReLU()
        self.base2 = nn.Linear(hidden_size, hidden_size)
        self.tanh = nn.Tanh()

        # Action distribution head; not used by forward().
        self.dist = DiagGaussian(hidden_size, num_outputs)

    def forward(self, x):
        """Return the actor features for `x`."""
        hidden = self.relu(self.base1(x))
        features = self.tanh(self.base2(hidden))
        return features

class ValueFn(nn.Module):
    """State-value network: Linear -> ReLU -> Linear -> Tanh -> Linear(1)."""

    def __init__(self, num_inputs, hidden_size):
        super().__init__()

        # Hidden trunk, kept as individually named layers rather than an
        # nn.Sequential.
        self.base1 = nn.Linear(num_inputs, hidden_size)
        self.relu = nn.ReLU()
        self.base2 = nn.Linear(hidden_size, hidden_size)
        self.tanh = nn.Tanh()

        # Scalar value head.
        self.head = nn.Linear(hidden_size, 1)

    def forward(self, x):
        """Return one value estimate per row of observation batch `x`."""
        hidden = self.relu(self.base1(x))
        hidden = self.tanh(self.base2(hidden))
        return self.head(hidden)

class FwdDyn(nn.Module):
    """Forward-dynamics MLP mapping a (state, action) pair to a prediction.

    The final Tanh bounds every output component to [-1, 1].
    """

    def __init__(self, num_inputs, hidden_size, num_outputs):
        super().__init__()
        # `num_inputs` must equal the combined width of state and action,
        # because `forward` concatenates them before the first layer.
        layers = [
            nn.Linear(num_inputs, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, num_outputs),
            nn.Tanh(),
        ]
        self.base = nn.Sequential(*layers)

    def forward(self, state, action):
        """Concatenate `state` and `action` on the last dim and run the MLP."""
        joint_input = torch.cat((state, action), dim=-1)
        prediction = self.base(joint_input)
        return prediction

In [None]:
class PPO():
    """PPO agent: actor-critic, optional forward-dynamics model, optimizers
    and a rollout buffer, with hooks for a cooperating teammate agent.

    When `add_intrinsic_reward` is set, a dynamics model is trained next to
    the policy and its prediction error (chained through the teammate's
    dynamics model) is added to the reward. When
    `add_selfish_extrinsic_reward` is set, an extra per-step signal derived
    from the `agentHolding` flags stored in the rollouts is added as well.
    """

    def __init__(self,
                 log_dir,
                 observation_space,
                 action_space,
                 actor_critic=ActorCritic,
                 dynamics_model=FwdDyn,
                 optimizer=optim.Adam,
                 hidden_size=64,
                 num_steps=2048,
                 num_processes=1,
                 ppo_epochs=10,
                 num_mini_batch=32,
                 pi_lr=1e-4,
                 v_lr=3e-4,
                 dyn_lr=3e-4,
                 clip_param=0.2,
                 value_coef=0.5,
                 entropy_coef=0.01,
                 dyn_coef=0.5,
                 grad_norm_max=0.5,
                 use_clipped_value_loss=True,
                 use_tensorboard=True,
                 add_intrinsic_reward=False,
                 add_selfish_extrinsic_reward=False,
                 predict_delta_obs=False,
                 device='cpu',
                 share_optim=False,
                 debug=False):
        """Build the networks, optimizers and rollout storage.

        Args:
            log_dir: directory in which checkpoints are written.
            observation_space: object exposing `.shape`; `shape[0]` is the
                network input width and `shape` is passed to `Rollouts`.
            action_space: object exposing `.shape`; `shape[0]` is the action
                width. Also passed to `Rollouts` unchanged.
            actor_critic: class instantiated as the actor-critic network.
            dynamics_model: class instantiated as the forward-dynamics
                network (only when `add_intrinsic_reward` is true).
            optimizer: optimizer class, called as `optimizer(params, lr=...)`.
            hidden_size: hidden width of every network.
            num_steps, num_processes, device: forwarded to `Rollouts`.
            ppo_epochs: passes over the rollout per call to `update`.
            num_mini_batch: minibatch count requested from the rollout
                generator in each pass.
            pi_lr, v_lr, dyn_lr: learning rates for the policy, value and
                dynamics optimizers. With `share_optim` only `pi_lr` is used.
            clip_param: PPO clipping range for the ratio and value clipping.
            value_coef, entropy_coef, dyn_coef: weights of the loss terms.
            grad_norm_max: gradient-norm clipping threshold.
            use_clipped_value_loss: use the clipped value objective.
            use_tensorboard, debug: accepted but not used by this class.
            add_intrinsic_reward: create and train the dynamics model and add
                stored intrinsic rewards to the reward in `compute_returns`.
            add_selfish_extrinsic_reward: add stored selfish extrinsic
                rewards to the reward in `compute_returns`.
            predict_delta_obs: dynamics target is `next_obs - obs` instead of
                `next_obs`.
            share_optim: use one optimizer over all trained parameters
                instead of one optimizer per network.
        """

        # setup logging
        self.checkpoint_path = os.path.join(log_dir, 'checkpoint.pth')
        self.checkpoint_path2 = os.path.join(log_dir, 'checkpoint2.pth')

        # ppo hyperparameters
        self.clip_param = clip_param
        self.ppo_epochs = ppo_epochs
        self.num_mini_batch = num_mini_batch

        # loss hyperparameters
        self.pi_lr = pi_lr
        self.v_lr = v_lr
        self.dyn_lr = dyn_lr
        self.value_coef = value_coef
        self.entropy_coef = entropy_coef
        self.dyn_coef = dyn_coef

        # clip values
        self.grad_norm_max = grad_norm_max
        self.use_clipped_value_loss = use_clipped_value_loss
        self.add_intrinsic_reward = add_intrinsic_reward
        self.add_selfish_extrinsic_reward = add_selfish_extrinsic_reward
        self.predict_delta_obs = predict_delta_obs

        # data normalization (filled in by `update`, saved with checkpoints)
        self.obs_mean = None
        self.obs_var = None

        # setup actor critic
        self.actor_critic = actor_critic(
            num_inputs=observation_space.shape[0],
            hidden_size=hidden_size,
            num_outputs=action_space.shape[0])

        # setup dynamics model; its input is the concatenated (obs, action)
        if self.add_intrinsic_reward:
            dynamics_dim = observation_space.shape[0] + action_space.shape[0]
            self.dynamics_model = dynamics_model(num_inputs=dynamics_dim,
                                                 hidden_size=hidden_size,
                                                 num_outputs=observation_space.shape[0])

        # setup optimizers
        self.share_optim = share_optim
        if self.share_optim:
            shared_params = list(self.actor_critic.parameters())
            if self.add_intrinsic_reward:
                shared_params += list(self.dynamics_model.parameters())
            self.optimizer = optimizer(shared_params, lr=pi_lr)
        else:
            self.policy_optimizer = optimizer(self.actor_critic.policy.parameters(), lr=pi_lr)
            self.value_fn_optimizer = optimizer(self.actor_critic.value_fn.parameters(), lr=v_lr)
            if self.add_intrinsic_reward:
                self.dynamics_optimizer = optimizer(self.dynamics_model.parameters(), lr=dyn_lr)

        # create rollout storage
        self.num_processes = num_processes
        self.rollouts = Rollouts(num_steps, num_processes,
                                 observation_space.shape,
                                 action_space,
                                 device)

        # teammate is set to None by default
        self.teammate = None

    def setTeammate(self, tm):
        """Register the cooperating agent used by the reward computations."""
        self.teammate = tm

    def train(self):
        """Put the actor-critic (and dynamics model, if any) in train mode."""
        self.actor_critic.train()
        if self.add_intrinsic_reward:
            self.dynamics_model.train()

    def eval(self):
        """Put the actor-critic (and dynamics model, if any) in eval mode."""
        self.actor_critic.eval()
        if self.add_intrinsic_reward:
            self.dynamics_model.eval()

    def select_action(self, step):
        """Sample an action for the observation stored at rollout `step`.

        Returns the `(value, action, action_log_probs)` triple of the
        actor-critic, computed without gradient tracking.
        """
        with torch.no_grad():
            return self.actor_critic.select_action(self.rollouts.obs[step])

    def evaluate_action(self, obs, action):
        """Return `(value, action_log_probs, entropy)` with gradients."""
        return self.actor_critic.evaluate_action(obs, action)

    def get_value(self, obs):
        """Return the value estimate for `obs` without gradient tracking."""
        with torch.no_grad():
            return self.actor_critic.get_value(obs)

    def store_rollout(self, obs, action, action_log_probs, value, reward, intrinsic_reward, selfish_extrinsic_reward, done, agentHolding):
        """Append one transition to the rollout buffer.

        `done` must be a NumPy array; it is converted to a float mask of
        shape (-1, 1) that is 0 where the episode ended and 1 elsewhere.
        """
        masks = torch.tensor(1.0 - done.astype(np.float32)).view(-1, 1)
        self.rollouts.insert(obs, action, action_log_probs, value, reward, intrinsic_reward, selfish_extrinsic_reward, masks, agentHolding)

    def compute_returns(self, gamma, use_gae=True, gae_lambda=0.95):
        """Fold the auxiliary rewards into the stored rewards and compute returns.

        The enabled auxiliary rewards are added IN PLACE to
        `rollouts.rewards` before the buffer computes its returns,
        bootstrapping from the value of the last stored observation.
        """
        with torch.no_grad():
            next_value = self.actor_critic.get_value(self.rollouts.obs[-1]).detach()
        if self.add_intrinsic_reward:
            self.rollouts.rewards += self.rollouts.intrinsic_rewards

        # adding selfish extrinsic reward here
        if self.add_selfish_extrinsic_reward:
            self.rollouts.rewards += self.rollouts.selfish_extrinsic_rewards

        self.rollouts.compute_returns(next_value, gamma, use_gae, gae_lambda)

    def compute_intrinsic_reward(self, step):
        """Prediction error of the chained dynamics models at rollout `step`.

        This agent's dynamics model predicts from its own (obs, action); the
        teammate's dynamics model is then applied to that prediction and the
        teammate's action at the same step. The result is compared with this
        agent's stored next observation (or the observation delta when
        `predict_delta_obs` is set).

        Requires both agents to own a dynamics model, i.e. to have been built
        with `add_intrinsic_reward=True`.

        Returns:
            The int `0` when no teammate is registered; otherwise a tensor
            holding half the squared error summed over the last dimension,
            with a trailing unit dimension.
        """
        with torch.no_grad():
            if self.teammate is None:
                return 0

            obs = self.rollouts.obs[step]
            action = self.rollouts.actions[step]
            teammate_action = self.teammate.rollouts.actions[step]
            next_obs = self.rollouts.obs[step + 1]
            if self.predict_delta_obs:
                next_obs = (next_obs - obs)
            own_prediction = self.dynamics_model(obs, action)
            next_obs_preds = self.teammate.dynamics_model(own_prediction, teammate_action)

            # change to be l2 norm
            return 0.5 * (next_obs_preds - next_obs).pow(2).sum(-1).unsqueeze(-1)

    def compute_selfish_extrinsic_reward(self, step):
        """Signal that this agent let go of the object with nobody catching it.

        Uses the `agentHolding` flags stored in both agents' rollouts.

        Returns:
            `1` when this agent was holding at `step` and neither agent is
            holding at `step + 1`; `0` otherwise, including when no teammate
            is registered.
        """
        with torch.no_grad():
            if self.teammate is None:
                return 0

            own_holding = self.rollouts.agentHolding
            mate_holding = self.teammate.rollouts.agentHolding

            # this agent was the last one to touch the hammer
            last_touched = bool(own_holding[step])
            # is any arm still holding the hammer one step later?
            after_held = bool(own_holding[step + 1]) or bool(mate_holding[step + 1])

            # "Held before" is implied by `last_touched`. The earlier
            # condition additionally required `not beforeHeld`, which
            # contradicts `last_touched` and so could never be satisfied.
            if last_touched and not after_held:
                return 1
            return 0

    def update(self, obs_mean, obs_var):
        """Run one PPO update and reset the rollout buffer.

        `obs_mean` / `obs_var` are only recorded so that they are written
        into checkpoints.

        Returns:
            (total_loss, policy_loss, value_loss, dynamics_loss, entropy, kl,
            delta_policy_loss, delta_value_loss) as Python floats; the first
            six are averages over the gradient steps taken.
        """
        self.obs_mean = obs_mean
        self.obs_var = obs_var
        stats = self._update()

        self.rollouts.after_update()
        return stats

    def compute_loss(self, sample):
        """Compute all loss terms for one minibatch produced by the rollouts.

        Returns:
            (total_loss, policy_loss, value_loss, dynamics_loss, entropy, kl).
            `dynamics_loss` is the int `0` when no dynamics model is trained;
            `kl` is the detached mean of old minus new log-probabilities.
        """
        # get sample batch
        (obs_batch, actions_batch, next_obs_batch, value_preds_batch,
         return_batch, masks_batch, old_action_log_probs_batch, adv_target) = sample

        # evaluate actions
        values, action_log_probs, entropy = self.actor_critic.evaluate_action(obs_batch, actions_batch)

        # compute policy loss (clipped surrogate objective)
        ratio = torch.exp(action_log_probs - old_action_log_probs_batch)
        sur1 = ratio * adv_target
        sur2 = torch.clamp(ratio, 1 - self.clip_param, 1 + self.clip_param) * adv_target
        policy_loss = -torch.min(sur1, sur2).mean()

        # compute value loss
        if self.use_clipped_value_loss:
            value_pred_clipped = value_preds_batch + (values - value_preds_batch).clamp(-self.clip_param, self.clip_param)
            # Keep the errors per sample: the pessimistic choice between the
            # clipped and unclipped error has to be made elementwise, and
            # only then averaged.
            value_losses = (return_batch - values).pow(2)
            value_losses_clipped = (return_batch - value_pred_clipped).pow(2)
            value_loss = 0.5 * torch.max(value_losses, value_losses_clipped).mean()
        else:
            value_loss = 0.5 * (return_batch - values).pow(2).mean()

        # compute dynamics loss
        if self.add_intrinsic_reward:
            dynamics_loss = self.compute_dynamics_loss(obs_batch, actions_batch, next_obs_batch, masks_batch)
        else:
            dynamics_loss = 0

        # compute total loss
        total_loss = self.value_coef * value_loss + self.dyn_coef * dynamics_loss \
                    + (policy_loss - self.entropy_coef * entropy)

        # compute kl divergence
        kl = (old_action_log_probs_batch - action_log_probs).mean().detach()

        return total_loss, policy_loss, value_loss, dynamics_loss, entropy, kl

    def compute_dynamics_loss(self, obs, action, next_obs, masks):
        """Mean over the batch of half the summed squared prediction error.

        `masks` is accepted for signature compatibility but not used.
        """
        if self.predict_delta_obs:
            next_obs = (next_obs - obs)
        next_obs_preds = self.dynamics_model(obs, action)
        return 0.5 * (next_obs_preds - next_obs).pow(2).sum(-1).unsqueeze(-1).mean()

    def _update(self):
        """Run `ppo_epochs` passes of minibatch updates over the rollouts.

        Returns the same 8-tuple as `update`.
        """
        # compute and normalize advantages
        advantages = self.rollouts.returns[:-1] - self.rollouts.value_preds[:-1]
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-5)

        # policy and value losses before gradient update
        with torch.no_grad():
            # Get whole batch of data
            update_generator = self.rollouts.feed_forward_generator(advantages, num_mini_batch=1)
            for update_sample in update_generator:
                _, policy_loss_old, value_loss_old, _, _, _ = self.compute_loss(update_sample)

        total_loss_epoch = 0
        policy_loss_epoch = 0
        value_loss_epoch = 0
        dynamics_loss_epoch = 0
        entropy_epoch = 0
        kl_epoch = 0
        # Count the gradient steps actually taken so the averages below are
        # exact whatever number of minibatches the generator yields.
        num_updates = 0

        for epoch in range(self.ppo_epochs):
            data_generator = self.rollouts.feed_forward_generator(advantages, self.num_mini_batch)

            for sample in data_generator:
                total_loss, policy_loss, value_loss, dynamics_loss, entropy, kl = self.compute_loss(sample)

                if self.share_optim:
                    self.optimizer.zero_grad()
                    total_loss.backward()
                    torch.nn.utils.clip_grad_norm_(self.actor_critic.parameters(), self.grad_norm_max)
                    if self.add_intrinsic_reward:
                        # The shared optimizer also steps the dynamics model,
                        # so clip its gradients as the separate-optimizer
                        # branch does.
                        torch.nn.utils.clip_grad_norm_(self.dynamics_model.parameters(), self.grad_norm_max)
                    self.optimizer.step()
                else:
                    self.policy_optimizer.zero_grad()
                    (policy_loss - self.entropy_coef * entropy).backward()
                    torch.nn.utils.clip_grad_norm_(self.actor_critic.policy.parameters(), self.grad_norm_max)
                    self.policy_optimizer.step()

                    self.value_fn_optimizer.zero_grad()
                    value_loss.backward()
                    torch.nn.utils.clip_grad_norm_(self.actor_critic.value_fn.parameters(), self.grad_norm_max)
                    self.value_fn_optimizer.step()

                    if self.add_intrinsic_reward:
                        self.dynamics_optimizer.zero_grad()
                        dynamics_loss.backward()
                        torch.nn.utils.clip_grad_norm_(self.dynamics_model.parameters(), self.grad_norm_max)
                        self.dynamics_optimizer.step()

                total_loss_epoch += total_loss.item()
                policy_loss_epoch += policy_loss.item()
                value_loss_epoch += value_loss.item()
                # `dynamics_loss` is the int 0 when no dynamics model is
                # trained; float() handles both that and a scalar tensor.
                dynamics_loss_epoch += float(dynamics_loss)
                entropy_epoch += entropy.item()
                kl_epoch += kl.item()
                num_updates += 1

        num_updates = max(num_updates, 1)  # avoid dividing by zero
        total_loss_epoch /= num_updates
        policy_loss_epoch /= num_updates
        value_loss_epoch /= num_updates
        dynamics_loss_epoch /= num_updates
        entropy_epoch /= num_updates
        kl_epoch /= num_updates

        # policy and value losses after gradient update
        with torch.no_grad():
            _, policy_loss_new, value_loss_new, _, _, _ = self.compute_loss(update_sample)
            delta_p = policy_loss_new - policy_loss_old
            delta_v = value_loss_new - value_loss_old

        return total_loss_epoch, policy_loss_epoch, value_loss_epoch, dynamics_loss_epoch, entropy_epoch, kl_epoch, delta_p.item(), delta_v.item()

    def save_checkpoint(self, path=None):
        """Save model and optimizer state.

        With `path=None` the state dict goes to `self.checkpoint_path` and the
        whole actor-critic module is additionally pickled to
        `self.checkpoint_path2`. With an explicit `path` only the state dict
        is written, to that path.
        """
        # create checkpoint dict
        checkpoint = {
            'share_optim': self.share_optim,
            'add_intrinsic_reward': self.add_intrinsic_reward,
            'obs_mean': self.obs_mean,
            'obs_var': self.obs_var}

        # save models
        checkpoint['actor_critic'] = self.actor_critic.state_dict()
        if self.add_intrinsic_reward:
            checkpoint['dynamics_model'] = self.dynamics_model.state_dict()

        # save optimizer(s)
        if self.share_optim:
            checkpoint['optimizer'] = self.optimizer.state_dict()
        else:
            checkpoint['policy_optimizer'] = self.policy_optimizer.state_dict()
            checkpoint['value_fn_optimizer'] = self.value_fn_optimizer.state_dict()
            if self.add_intrinsic_reward:
                checkpoint['dynamics_optimizer'] = self.dynamics_optimizer.state_dict()

        if path is None:
            torch.save(checkpoint, self.checkpoint_path)
            torch.save(self.actor_critic, self.checkpoint_path2)
        else:
            torch.save(checkpoint, path)

    def load_checkpoint(self, path):
        """Restore models, optimizers and normalization stats from `path`.

        The `share_optim` and `add_intrinsic_reward` flags stored in the
        checkpoint OVERWRITE this instance's flags, so the instance must have
        been constructed with the same settings for the matching optimizers
        and dynamics model to exist.
        """
        # load checkpoint
        checkpoint = torch.load(path)
        self.share_optim = checkpoint['share_optim']
        self.add_intrinsic_reward = checkpoint['add_intrinsic_reward']
        self.obs_mean = checkpoint['obs_mean']
        self.obs_var = checkpoint['obs_var']

        # load models
        self.actor_critic.load_state_dict(checkpoint['actor_critic'])
        if self.add_intrinsic_reward:
            self.dynamics_model.load_state_dict(checkpoint['dynamics_model'])

        # load optimizer(s)
        if self.share_optim:
            self.optimizer.load_state_dict(checkpoint['optimizer'])
        else:
            self.policy_optimizer.load_state_dict(checkpoint['policy_optimizer'])
            self.value_fn_optimizer.load_state_dict(checkpoint['value_fn_optimizer'])
            if self.add_intrinsic_reward:
                self.dynamics_optimizer.load_state_dict(checkpoint['dynamics_optimizer'])

    def load_models(self, path):
        """Restore only the network weights from `path` (no optimizer state)."""
        checkpoint = torch.load(path)
        # load models
        self.actor_critic.load_state_dict(checkpoint['actor_critic'])
        if self.add_intrinsic_reward:
            self.dynamics_model.load_state_dict(checkpoint['dynamics_model'])
        del checkpoint

In [None]:
def return_observation_values(base, epsilon=1e-8, clip=10.0):
    """Normalize one observation vector and return it with its statistics.

    The running-moment formulas are seeded with mean 0, variance 1 and a
    pseudo-count of 1e-4, then merged with the mean and variance taken over
    the entries of `base` itself (axis 0 of the 1-D vector).
    NOTE(review): this treats the features of a single observation as the
    sample population; statistics are not carried across calls. Confirm this
    is the intended normalization.

    Args:
        base: 1-D NumPy array holding one observation.
        epsilon: small constant added to the variance before the square root
            so the division is well defined.
        clip: the normalized observation is clipped to [-clip, clip].

    Returns:
        A length-4 NumPy object array whose entries are torch tensors of
        shape (1, len(base)): [0] the raw observation, [1] the normalized and
        clipped observation, [2] the merged mean, [3] the merged variance.
    """
    base = np.asarray(base)

    # Prior moments the batch statistics are merged into.
    count = 1e-4
    mean = np.zeros(base.shape, 'float64')
    var = np.ones(base.shape, 'float64')

    batch_mean = np.mean(base, axis=0)
    batch_var = np.var(base, axis=0)
    batch_count = base.shape[0]

    # Parallel-variance merge of (mean, var, count) with the batch moments.
    delta = batch_mean - mean
    tot_count = count + batch_count
    new_mean = mean + delta * batch_count / tot_count
    m_a = var * count
    m_b = batch_var * batch_count
    M2 = m_a + m_b + np.square(delta) * count * batch_count / tot_count
    new_var = M2 / tot_count

    # Standardize with a small epsilon under the square root. Adding the
    # sample count here instead would shrink every normalized value.
    normalized = np.clip((base - new_mean) / np.sqrt(new_var + epsilon), -clip, clip)

    def _as_row_tensor(vector):
        # Shape (1, n) so the tensor can be copied into a one-process buffer.
        return torch.tensor(vector.reshape((1, len(vector))))

    # Fill an object array explicitly: np.array([...tensors...]) may coerce
    # the tensors into one numeric array, so its items stop being tensors.
    result = np.empty(4, dtype=object)
    result[0] = _as_row_tensor(base)
    result[1] = _as_row_tensor(normalized)
    result[2] = _as_row_tensor(new_mean)
    result[3] = _as_row_tensor(new_var)
    return result

In [None]:
def form_observations(obs):
    """Split an environment observation dict into one flat vector per agent.

    Keys are matched by substring, first match wins:
      * "robot0"            -> agent 1 only
      * "robot1"            -> agent 2 only
      * "hammer" / "object" -> both agents
    All other keys are ignored. Entries are appended in the iteration order
    of `obs`.

    The vector length is derived from the matched entries rather than being
    fixed in advance, so changing the robots or the task cannot leave
    uninitialized values in the result or overflow a preallocated buffer.

    Args:
        obs: mapping from observation name to a numeric scalar or array.

    Returns:
        (agent1_observation_space, agent1_complete_space,
         agent2_observation_space, agent2_complete_space), where each
        `*_observation_space` is a 1-D float64 array and each
        `*_complete_space` is `return_observation_values` applied to it.
    """
    agent1_parts = []
    agent2_parts = []

    for key in obs:
        if "robot0" in key:
            targets = (agent1_parts,)
        elif "robot1" in key:
            targets = (agent2_parts,)
        elif "hammer" in key or "object" in key:
            targets = (agent1_parts, agent2_parts)
        else:
            continue

        # Only matched entries are converted, so unrelated non-numeric
        # entries in `obs` are never touched.
        values = np.atleast_1d(np.asarray(obs[key], dtype=np.float64)).ravel()
        for parts in targets:
            parts.append(values)

    def _join(parts):
        # np.concatenate rejects an empty list, so handle "no matches".
        return np.concatenate(parts) if parts else np.zeros(0, dtype=np.float64)

    agent1_observation_space = _join(agent1_parts)
    agent2_observation_space = _join(agent2_parts)

    agent1_complete_space = return_observation_values(agent1_observation_space)
    agent2_complete_space = return_observation_values(agent2_observation_space)

    return agent1_observation_space, agent1_complete_space, agent2_observation_space, agent2_complete_space

In [None]:
# create environment instance

log_dir = '/content'
configure(log_dir, tbX = True)

env = suite.make(
    env_name="TwoArmHandover", # try with other tasks like "Stack" and "Door"
    robots=["Panda", "Panda"],  # try with other robots like "Sawyer" and "Jaco"
    use_object_obs=True,
    ignore_done=True
)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
if torch.cuda.is_available():
    print(f'There are {torch.cuda.device_count()} GPU(s) available.')
    print('Device name:', torch.cuda.get_device_name(0))

torch.set_num_threads(1)

# reset the environment
obs = env.reset()

# prepare observation format
agent1_observation_space, agent1_complete_space, agent2_observation_space, agent2_complete_space = form_observations(obs)

# experiment switches
use_intrinsic_reward = True
max_intrinsic_reward = None
intrinsic_coefficient = 1
selfish_extrinsic_coefficient = -1
use_selfish_extrinsic_reward = True

# Each agent gets its own checkpoint directory so that saving one agent does
# not overwrite the other's files.
agent_log_dirs = [os.path.join(log_dir, 'agent1'), os.path.join(log_dir, 'agent2')]
for agent_log_dir in agent_log_dirs:
    os.makedirs(agent_log_dir, exist_ok=True)

# NOTE(review): robosuite documents `env.action_spec` as a (low, high) pair of
# bound arrays, so indices 0 and 1 look like the bounds of the FULL two-arm
# action rather than per-robot action spaces. Each agent therefore seems to
# output an action for both arms - confirm this is intended.
agent1 = PPO(agent_log_dirs[0], agent1_observation_space, env.action_spec[0], add_intrinsic_reward=use_intrinsic_reward, add_selfish_extrinsic_reward=use_selfish_extrinsic_reward)
agent2 = PPO(agent_log_dirs[1], agent2_observation_space, env.action_spec[1], add_intrinsic_reward=use_intrinsic_reward, add_selfish_extrinsic_reward=use_selfish_extrinsic_reward)

agent1.setTeammate(agent2)
agent2.setTeammate(agent1)

agents = [agent1, agent2]

num_env_steps = 1000 #1e07
num_steps = 2 #2048
num_updates = num_env_steps // agent1.num_processes // num_steps

gamma = 0.9
use_gae = True
gae_lambda = 0.95

# start training
log_interval = 1
checkpoint_interval = 20
use_tensorboard = True
start = time.time()

# Losses and value statistics are reported for one agent only; the last one,
# as before. Reward statistics are collected for every agent.
log_agent_index = len(agents) - 1

# Most recent [raw, normalized, mean, var] observation bundle per agent.
latest_obs = [agent1_complete_space, agent2_complete_space]

for agent, complete_space in zip(agents, latest_obs):
    agent.actor_critic = agent.actor_critic.to(device)
    if use_intrinsic_reward:
        agent.dynamics_model = agent.dynamics_model.to(device)
    agent.rollouts.obs[0].copy_(complete_space[1])
    agent.rollouts.to(device)
    agent.train()

# Iterate over updates, not environment steps: the learning-rate schedule is
# told there are `num_updates` updates in total, so running more than that
# would take it past its end.
for update in range(num_updates):
    print(update)

    # Decay every agent's learning rates, each from its own initial value.
    for agent in agents:
        update_linear_schedule(optimizer=agent.policy_optimizer,
                               update=update,
                               total_num_updates=num_updates,
                               initial_lr=agent.pi_lr)
        update_linear_schedule(optimizer=agent.value_fn_optimizer,
                               update=update,
                               total_num_updates=num_updates,
                               initial_lr=agent.v_lr)

    agent1_extrinsic_rewards = []
    agent2_extrinsic_rewards = []
    agent1_intrinsic_rewards = []
    agent2_intrinsic_rewards = []
    agent1_selfish_extrinsic_rewards = []
    agent2_selfish_extrinsic_rewards = []
    extrinsic_by_agent = [agent1_extrinsic_rewards, agent2_extrinsic_rewards]
    intrinsic_by_agent = [agent1_intrinsic_rewards, agent2_intrinsic_rewards]
    selfish_by_agent = [agent1_selfish_extrinsic_rewards, agent2_selfish_extrinsic_rewards]

    for step in range(num_steps):

        for agent_index, agent in enumerate(agents):
            # select action
            value, action, action_log_probs = agent.select_action(step)

            # take a step in the environment
            # NOTE(review): the environment is stepped once per agent, i.e.
            # twice per rollout step.
            obs, reward, done, infos = env.step(action[0].cpu().detach().numpy())

            # prepare observations
            agent1_observation_space, agent1_complete_space, agent2_observation_space, agent2_complete_space = form_observations(obs)
            obs = (agent1_complete_space, agent2_complete_space)[agent_index]
            latest_obs[agent_index] = obs

            # calculate intrinsic reward
            # NOTE(review): this and the selfish reward below read
            # rollouts[step] / rollouts[step + 1] BEFORE the current
            # transition is stored, so they see whatever the buffers held
            # previously. Confirm the intended ordering.
            if use_intrinsic_reward:
                intrinsic_reward = intrinsic_coefficient * agent.compute_intrinsic_reward(step)
                if max_intrinsic_reward is not None:
                    # clamp the already scaled reward instead of recomputing
                    # it without the coefficient
                    intrinsic_reward = torch.clamp(intrinsic_reward, 0.0, max_intrinsic_reward)
            else:
                intrinsic_reward = torch.tensor(0).view(1, 1)

            # handle extrinsic reward
            reward = torch.tensor(reward)

            # calculate selfish extrinsic reward
            selfish_extrinsic_reward = 0
            if use_selfish_extrinsic_reward:
                selfish_extrinsic_reward = selfish_extrinsic_coefficient * agent.compute_selfish_extrinsic_reward(step)

            # handle selfish extrinsic reward
            selfish_extrinsic_reward = torch.tensor(selfish_extrinsic_reward)

            # save all types of rewards
            intrinsic_by_agent[agent_index].extend(list(intrinsic_reward.cpu().numpy().reshape(-1)))
            extrinsic_by_agent[agent_index].extend(list(reward.cpu().numpy().reshape(-1)))
            selfish_by_agent[agent_index].extend(list(selfish_extrinsic_reward.cpu().numpy().reshape(-1)))

            # store experience
            done_value = np.ones(1) if done else np.zeros(1)

            # Check if this agent's gripper is grasping the hammer
            curGripper = env.robots[agent_index].gripper
            agentHolding = env._check_grasp(gripper=curGripper, object_geoms=env.hammer)
            agentHolding = torch.tensor(agentHolding)

            agent.store_rollout(obs[1], action, action_log_probs, value, reward, intrinsic_reward, selfish_extrinsic_reward, done_value, agentHolding)

    update_stats = []
    for agent_index, agent in enumerate(agents):
        # compute returns
        agent.compute_returns(gamma, use_gae, gae_lambda)
        # update policy and value_fn, reset rollout storage; every agent is
        # given the statistics of its OWN latest observation
        own_obs = latest_obs[agent_index]
        update_stats.append(agent.update(obs_mean=own_obs[2], obs_var=own_obs[3]))

    # new episode
    if done:
        # Resetting is enough to start a new episode; seed both rollout
        # buffers with the fresh initial observation.
        obs = env.reset()
        agent1_observation_space, agent1_complete_space, agent2_observation_space, agent2_complete_space = form_observations(obs)
        latest_obs = [agent1_complete_space, agent2_complete_space]
        for agent, complete_space in zip(agents, latest_obs):
            agent.rollouts.obs[0].copy_(complete_space[1])

    # checkpoint model (every agent, each into its own directory)
    if (update + 1) % checkpoint_interval == 0:
        for agent in agents:
            agent.save_checkpoint()

    # log data
    if update % log_interval == 0:
        logged_agent = agents[log_agent_index]
        tot_loss, pi_loss, v_loss, dyn_loss, entropy, kl, delta_p, delta_v = update_stats[log_agent_index]
        extrinsic_rewards = extrinsic_by_agent[log_agent_index]
        intrinsic_rewards = intrinsic_by_agent[log_agent_index]
        selfish_extrinsic_rewards = selfish_by_agent[log_agent_index]
        value_preds = logged_agent.rollouts.value_preds.cpu().data.numpy()

        current = time.time()
        elapsed = current - start
        total_steps = (update + 1) * logged_agent.num_processes * num_steps
        fps = int(total_steps / (current - start))

        logkv('Time/Updates', update)
        logkv('Time/Total Steps', total_steps)
        logkv('Time/FPS', fps)
        logkv('Time/Current', current)
        logkv('Time/Elapsed', elapsed)
        logkv('Time/Epoch', elapsed)

        logkv('Extrinsic/Mean', np.mean(extrinsic_rewards))
        logkv('Extrinsic/Median', np.median(extrinsic_rewards))
        logkv('Extrinsic/Min', np.min(extrinsic_rewards))
        logkv('Extrinsic/Max', np.max(extrinsic_rewards))
        logkv('SelfishExtrinsic/Mean', np.mean(selfish_extrinsic_rewards))
        logkv('SelfishExtrinsic/Median', np.median(selfish_extrinsic_rewards))
        logkv('SelfishExtrinsic/Min', np.min(selfish_extrinsic_rewards))
        logkv('SelfishExtrinsic/Max', np.max(selfish_extrinsic_rewards))
        logkv('Intrinsic/Mean', np.mean(intrinsic_rewards))
        logkv('Intrinsic/Median', np.median(intrinsic_rewards))
        logkv('Intrinsic/Min', np.min(intrinsic_rewards))
        logkv('Intrinsic/Max', np.max(intrinsic_rewards))
        logkv('Loss/Total', tot_loss)
        logkv('Loss/Policy', pi_loss)
        logkv('Loss/Value', v_loss)
        logkv('Loss/Entropy', entropy)
        logkv('Loss/KL', kl)
        logkv('Loss/DeltaPi', delta_p)
        logkv('Loss/DeltaV', delta_v)
        logkv('Loss/Dynamics', dyn_loss)
        logkv('Value/Mean', np.mean(value_preds))
        logkv('Value/Median', np.median(value_preds))
        logkv('Value/Min', np.min(value_preds))
        logkv('Value/Max', np.max(value_preds))

        if use_tensorboard:
            add_scalar('reward/mean', np.mean(extrinsic_rewards), total_steps, elapsed)
            add_scalar('reward/median', np.median(extrinsic_rewards), total_steps, elapsed)
            add_scalar('reward/min', np.min(extrinsic_rewards), total_steps, elapsed)
            add_scalar('reward/max', np.max(extrinsic_rewards), total_steps, elapsed)
            # separate tag family so these do not collide with 'reward/*'
            add_scalar('selfish_reward/mean', np.mean(selfish_extrinsic_rewards), total_steps, elapsed)
            add_scalar('selfish_reward/median', np.median(selfish_extrinsic_rewards), total_steps, elapsed)
            add_scalar('selfish_reward/min', np.min(selfish_extrinsic_rewards), total_steps, elapsed)
            add_scalar('selfish_reward/max', np.max(selfish_extrinsic_rewards), total_steps, elapsed)
            add_scalar('intrinsic/mean', np.mean(intrinsic_rewards), total_steps, elapsed)
            add_scalar('intrinsic/median', np.median(intrinsic_rewards), total_steps, elapsed)
            add_scalar('intrinsic/min', np.min(intrinsic_rewards), total_steps, elapsed)
            add_scalar('intrinsic/max', np.max(intrinsic_rewards), total_steps, elapsed)
            add_scalar('loss/total', tot_loss, total_steps, elapsed)
            add_scalar('loss/policy', pi_loss, total_steps, elapsed)
            add_scalar('loss/value', v_loss, total_steps, elapsed)
            add_scalar('loss/entropy', entropy, total_steps, elapsed)
            add_scalar('loss/kl', kl, total_steps, elapsed)
            add_scalar('loss/delta_p', delta_p, total_steps, elapsed)
            add_scalar('loss/delta_v', delta_v, total_steps, elapsed)
            add_scalar('loss/dynamics', dyn_loss, total_steps, elapsed)
            add_scalar('value/mean', np.mean(value_preds), total_steps, elapsed)
            add_scalar('value/median', np.median(value_preds), total_steps, elapsed)
            add_scalar('value/min', np.min(value_preds), total_steps, elapsed)
            add_scalar('value/max', np.max(value_preds), total_steps, elapsed)

        # flush the key/value log whether or not TensorBoard is enabled
        dumpkvs()
    


In [None]:
# Mount Google Drive into the Colab filesystem at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Flush pending writes and unmount the Drive mounted in the previous cell
# (relies on `drive` imported there).
drive.flush_and_unmount()

In [None]:
!tensorboard --logdir=content