In [33]:
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.init as init
from torch.distributions import Categorical, Normal
import numpy as np
import gym
from torch.autograd import Variable
import torch.nn.functional as F
import random
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


# Policy Network

In [34]:
class Policy(nn.Module):
    def __init__(self, img_dim, additional_feature_dim, action_dim):
        super(Policy, self).__init__()
        self.im = nn.Sequential(
            nn.Conv2d(img_dim[2], 32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Flatten()
        )

        self.feature_dim = 128 * (img_dim[0]//8 + 1)*(img_dim[1] // 8 + 1)
        total_fc_input_dim = self.feature_dim + additional_feature_dim


        self.fc = nn.Sequential(
            nn.Linear(total_fc_input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 64)
        )

        self.action_mean = nn.Linear(64, action_dim)
        init.kaiming_uniform_(self.action_mean.weight, nonlinearity='leaky_relu')
        self.action_mean.bias.data.fill_(0.1)

        self.action_log_std = nn.Parameter(torch.zeros(1, action_dim))

    def forward(self, img, combined_features):
        x = self.im(img)
        x = torch.cat([x, combined_features], dim=-1)
        x = self.fc(x)
        action_mean = self.action_mean(x)
        return action_mean

    def log_std_and_std(self, action_means):
        action_log_std = self.action_log_std.expand_as(action_means)
        return action_log_std, torch.exp(action_log_std)

# Discriminator Network

In [36]:
class SelfAttention(nn.Module):
    def __init__(self, in_dim):
        super(SelfAttention, self).__init__()
        self.query_conv = nn.Conv2d(in_dim, in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_dim, in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_dim, in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        batch_size, C, width, height = x.size()
        query = self.query_conv(x).view(batch_size, -1, width * height).permute(0, 2, 1)
        key = self.key_conv(x).view(batch_size, -1, width * height)
        energy = torch.bmm(query, key)
        attention = F.softmax(energy, dim=-1)
        value = self.value_conv(x).view(batch_size, -1, width * height)
        out = torch.bmm(value, attention.permute(0, 2, 1))
        out = out.view(batch_size, C, width, height)
        return self.gamma * out + x

class Discriminator(nn.Module):
    def __init__(self, img_dim, combine_dim, action_dim):
        super(Discriminator, self).__init__()
        self.im = nn.Sequential(
            nn.Conv2d(img_dim[2], 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            SelfAttention(128),
            nn.Flatten()
        )

        self.feature_dim = 128 * (img_dim[0]//8 + 1)*(img_dim[1] // 8 + 1)

        self.fc = nn.Sequential(
            nn.Linear(self.feature_dim + combine_dim + action_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

    def forward(self, img, combined_features, action):
        x = self.im(img)
        x = torch.cat([x, combined_features, action], dim=-1)
        return self.fc(x)

# Main Training Function

In [37]:
def train_info_gail(policy, discriminator, lambda1, lambda2, expert, demo, epochs=2):

    discriminator_optimizer = optim.AdamW(discriminator.parameters(), lr=0.0001, weight_decay=1e-2)
    policy_optimizer = optim.AdamW(policy.parameters(), lr=0.001, weight_decay=1e-2)

    adv_loss = nn.BCELoss().to(device)
    mse_loss = nn.MSELoss().to(device)
    for epoch in range(epochs):
        random.shuffle(demos)
        progress_bar = tqdm(enumerate(demos), total=len(demos), desc=f"Epoch {epoch}")



        for i, demo_key in progress_bar:
            demo_grp = expert["data/{}".format(demo_key)]

            # images, objects, position, quaternion, and joint position
            e_imgs = demo_grp['obs']['agentview_image'][:]
            e_obj = demo_grp['obs']['object'][:]
            e_eef_pos = demo_grp['obs']['robot0_eef_pos'][:]
            e_eef_quat = demo_grp['obs']['robot0_eef_quat'][:]
            e_joint_pos = demo_grp['obs']['robot0_joint_pos'][:]


            robot_eye_in_hand_image = demo_grp['obs']['robot0_eye_in_hand_image'][:]


# 'robot0_eef_vel_ang', 'robot0_eef_vel_lin', 'robot0_eye_in_hand_image', 'robot0_gripper_qpos', 'robot0_gripper_qvel', 'robot0_joint_pos', 'robot0_joint_pos_cos', 'robot0_joint_pos_sin', 'robot0_joint_vel'


            # actions and state
            e_actions = demo_grp['actions'][:]
            e_states = demo_grp['states'][:]

            # convert to tensors
            e_imgs = torch.tensor(e_imgs, dtype=torch.float32).permute(0, 3, 1, 2).to(device)


            e_obj = torch.tensor(e_obj, dtype=torch.float32).to(device)
            e_eef_pos = torch.tensor(e_eef_pos, dtype=torch.float32).to(device)
            e_eef_quat = torch.tensor(e_eef_quat, dtype=torch.float32).to(device)
            e_joint_pos = torch.tensor(e_joint_pos, dtype=torch.float32).to(device)
            e_states = torch.tensor(e_states, dtype=torch.float32).to(device)
            e_actions = torch.tensor(e_actions, dtype=torch.float32).to(device)

            combined_features = torch.cat([e_obj, e_eef_pos, e_eef_quat, e_joint_pos, e_states], dim=1)

            real_labels = torch.ones(e_actions.shape[0], 1, device=device)
            fake_labels = torch.zeros(e_actions.shape[0], 1, device=device)

            # Discriminator Update
            discriminator_optimizer.zero_grad()


            real_predictions = discriminator(e_imgs, combined_features, e_actions)
            fake_actions = policy(e_imgs, combined_features).detach()
            fake_predictions = discriminator(e_imgs, combined_features, fake_actions)

            d_loss = adv_loss(real_predictions, real_labels) + adv_loss(fake_predictions, fake_labels)
            d_loss.backward()
            discriminator_optimizer.step()

            # Generator Update
            policy_optimizer.zero_grad()
            fake_actions = policy(e_imgs, combined_features)
            fake_predictions = discriminator(e_imgs, combined_features, fake_actions)

            g_loss = lambda1 * adv_loss(fake_predictions, real_labels)
            bc_loss = lambda2 * mse_loss(fake_actions, e_actions)
            total_policy_loss = g_loss + bc_loss
            total_policy_loss.backward()
            policy_optimizer.step()


            progress_bar.set_description(f"""Epoch {epoch+1} | Iter {i} | Discriminator Loss: {d_loss.item():.4f} | Policy Loss: {total_policy_loss.item():.4f}""")


# Feeding data

In [5]:
import os
# First, we need to decide where to host the runtime storage
USE_GDRIVE_STORAGE = True

if not USE_GDRIVE_STORAGE:
    # Option 1: use the colab runtime storage. All trained model and downloaded
    # will disappear after you disconnect from the runtime.
    WS_DIR = "/content/"
else:
    # Option 2: use your google drive as the runtime storage. You need to grant
    # permission for the colab runtime to access your google drive. You also
    # need to decide on a workspace for robomimic. In this case, we've created a
    # folder called "colab_ws" in Google Drive.
    from google.colab import drive
    drive.mount('/content/drive')
    WS_DIR = "/content/drive/MyDrive/colab_ws/" # this should be the absolute path, e.g., "/content/drive/MyDrive/my-ws/"
    assert os.path.exists(WS_DIR)

%cd $WS_DIR

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive/colab_ws


In [None]:
# Install the basic requirements
%cd $WS_DIR
!pip install -e robosuite/
!pip install -e robomimic/
!pip install -e mimicgen_environments/
!pip install mujoco

import sys
import os
sys.path.append('./robosuite/')
sys.path.append('./robomimic/')
sys.path.append('./mimicgen_environments/')

In [7]:
DATA_DIR = WS_DIR + "mimicgen_data/"

In [8]:
import json
import h5py
import numpy as np

# enforce that the dataset exists
dataset_path = os.path.join(DATA_DIR, "stack_d0_100.hdf5")
assert os.path.exists(dataset_path)

In [9]:
# open file
f = h5py.File(dataset_path, "r")

# each demonstration is a group under "data"
demos = list(f["data"].keys())
num_demos = len(demos)

print("hdf5 file {} has {} demonstrations".format(dataset_path, num_demos))

hdf5 file /content/drive/MyDrive/colab_ws/mimicgen_data/stack_d0_100.hdf5 has 100 demonstrations


In [None]:
# each demonstration is named "demo_#" where # is a number.
# Let's put the demonstration list in increasing episode order
inds = np.argsort([int(elem[5:]) for elem in demos])
demos = [demos[i] for i in inds]

for ep in demos:
    num_actions = f["data/{}/actions".format(ep)].shape[0]
    print("{} has {} samples".format(ep, num_actions))

In [11]:
# look at first demonstration
demo_key = demos[0]
demo_grp = f["data/{}".format(demo_key)]

# Each observation is a dictionary that maps modalities to numpy arrays, and
# each action is a numpy array. Let's print the observation modalities and look at
# the action taken in the first 5 timesteps of this trajectory.

# print("observation modalities:")
# print(demo_grp["obs"].keys())
# for t in range(5):
#   print("timestep" + str(t) + ":")
#   print(demo_grp["obs"]["robot0_eef_pos"][t])
#   print(demo_grp["obs"]["robot0_joint_pos"][t])
#   print(demo_grp["actions"][t])

In [12]:
env_meta = json.loads(f["data"].attrs["env_args"])

In [13]:
import robomimic.utils.obs_utils as ObsUtils

# We normally need to make sure robomimic knows which observations are images (for the
# data processing pipeline). This is usually inferred from your training config, but
# since we are just playing back demonstrations, we just need to initialize robomimic
# with a dummy spec.
dummy_spec = dict(
    obs=dict(
            low_dim=["robot0_eef_pos"],
            rgb=[],
        ),
)
ObsUtils.initialize_obs_utils_with_obs_specs(obs_modality_specs=dummy_spec)



using obs modality: low_dim with keys: ['robot0_eef_pos']
using obs modality: rgb with keys: []


In [14]:
import robomimic.utils.env_utils as EnvUtils

# create simulation environment from environment metedata
env = EnvUtils.create_env_from_metadata(
    env_meta=env_meta,
    render=False,            # no on-screen rendering
    render_offscreen=True,   # off-screen rendering to support rendering video frames
)

  ROBOSUITE_DEFAULT_LOGGER.warn("No private macro file found!")


Created environment with name Stack_D0
Action size is 7


In [15]:
init_state = f["data/{}/states".format(demo_key)][0]
print(init_state.shape)


(45,)


  and should_run_async(code)


In [16]:
demo_grp.keys()

<KeysViewHDF5 ['actions', 'dones', 'obs', 'rewards', 'states']>

In [17]:
demo_grp = f["data/{}".format(demo_key)]
print(demo_grp)

<HDF5 group "/data/demo_13" (5 members)>


In [18]:
print(env.get_state()['states'])

[ 0.00000000e+00 -4.03752416e-04  1.93674889e-01  3.07942070e-03
 -2.61380476e+00  9.22365158e-03  2.93857091e+00  7.89112670e-01
  2.08330000e-02 -2.08330000e-02  4.78243633e-02 -1.29777625e-02
  8.30000000e-01  9.53902539e-02  0.00000000e+00  0.00000000e+00
  9.95439953e-01  6.50255014e-02  6.26146972e-02  8.35000000e-01
  4.33069982e-01  0.00000000e+00  0.00000000e+00  9.01360300e-01
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00]


In [38]:
# policy flag
use_existing_policy = False

if not use_existing_policy:
  policy = Policy((84, 84, 3), 82, 7)
else:
  PATH = WS_DIR + 'policy.pt'
  policy = torch.load(PATH)

discriminator = Discriminator((84, 84, 3), 82, 7)

# use GPU
policy.to(device)
discriminator.to(device)

# weight values
lambda1 = 0.15
lambda2 = 0.5

train_info_gail(policy, discriminator, lambda1, lambda2, f, demos, epochs=50)


Epoch 1 | Iter 99 | Discriminator Loss: 1.3997 | Policy Loss: 0.2027: 100%|██████████| 100/100 [00:07<00:00, 13.20it/s]
Epoch 2 | Iter 99 | Discriminator Loss: 1.3987 | Policy Loss: 0.1803: 100%|██████████| 100/100 [00:07<00:00, 13.25it/s]
Epoch 3 | Iter 99 | Discriminator Loss: 1.3927 | Policy Loss: 0.1832: 100%|██████████| 100/100 [00:07<00:00, 13.39it/s]
Epoch 4 | Iter 99 | Discriminator Loss: 1.3851 | Policy Loss: 0.1564: 100%|██████████| 100/100 [00:07<00:00, 13.27it/s]
Epoch 5 | Iter 99 | Discriminator Loss: 1.3894 | Policy Loss: 0.1838: 100%|██████████| 100/100 [00:07<00:00, 13.34it/s]
Epoch 6 | Iter 99 | Discriminator Loss: 1.3886 | Policy Loss: 0.1533: 100%|██████████| 100/100 [00:07<00:00, 13.47it/s]
Epoch 7 | Iter 99 | Discriminator Loss: 1.3860 | Policy Loss: 0.1476: 100%|██████████| 100/100 [00:07<00:00, 13.34it/s]
Epoch 8 | Iter 99 | Discriminator Loss: 1.3894 | Policy Loss: 0.1353: 100%|██████████| 100/100 [00:07<00:00, 13.49it/s]
Epoch 9 | Iter 99 | Discriminator Loss: 

In [39]:
import imageio

# prepare to write playback trajectories to video
video_path = os.path.join(DATA_DIR, "playback.mp4")
video_writer = imageio.get_writer(video_path, fps=20)

In [40]:
def playback_trajectory(demo_key):
    init_state = f["data/{}/states".format(demo_key)][0]
    model_xml = f["data/{}".format(demo_key)].attrs["model_file"]
    initial_state_dict = dict(states=init_state, model=model_xml)
    env.reset_to(initial_state_dict)
    e_actions = f["data/{}/actions".format(demo_key)][:]
    state = env.get_state()['states']
    ob = env.get_observation()

    img = env.render(mode="rgb_array", height=84, width=84, camera_name="agentview")
    obj = ob['object']

    eef_pos = ob['robot0_eef_pos']
    eef_quat = ob['robot0_eef_quat']
    joint_pos = ob['robot0_joint_pos']

    for t in tqdm(range(1000)):
        img_tensor = torch.tensor(img.copy(), dtype=torch.float32).unsqueeze(0).permute(0, 3, 1, 2).to(device)
        # robot_eye_in_hand_image_tensor = torch.tensor(robot_eye_in_hand_image.copy(), dtype=torch.float32).unsqueeze(0).permute(0, 3, 1, 2).to(device)

        obj_tensor = torch.tensor(obj, dtype=torch.float32).unsqueeze(0).to(device)
        state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(device)
        eef_pos_tensor = torch.tensor(eef_pos, dtype=torch.float32).unsqueeze(0).to(device)
        eef_quat_tensor = torch.tensor(eef_quat, dtype=torch.float32).unsqueeze(0).to(device)
        joint_pos_tensor = torch.tensor(joint_pos, dtype=torch.float32).unsqueeze(0).to(device)


        combined_features = torch.cat([obj_tensor, eef_pos_tensor, eef_quat_tensor, joint_pos_tensor, state_tensor], dim=1)


        action_mean = policy(img_tensor, combined_features)
        action_log_std, action_std = policy.log_std_and_std(action_mean)
        action = action_mean

        env.step(action.detach().cpu().numpy()[0])

        state = env.get_state()['states']
        ob = env.get_observation()
        img = env.render(mode="rgb_array", height=84, width=84, camera_name="agentview")
        obj = ob['object']
        video_img = env.render(mode="rgb_array", height=512, width=512, camera_name="agentview")
        video_writer.append_data(video_img)

In [41]:
# playback the first 3 demos and record them to a video file
for ep in demos[:1]:
    print("Playing back demo key: {}".format(ep))
    playback_trajectory(ep)

# done writing video
video_writer.close()

Playing back demo key: demo_397


100%|██████████| 1000/1000 [02:59<00:00,  5.56it/s]


In [42]:
# view the trajectories!
from IPython.display import Video
Video(video_path, embed=True)

In [43]:
demo_key = demos[0]
demo_grp = f["data/{}".format(demo_key)]

In [44]:
print(demo_grp['states'][0])
print(demo_grp['obs']['agentview_image'][:].shape)
print(demo_grp['obs']['object'][0].shape)

[ 0.00000000e+00  2.06476214e-02  1.67461556e-01 -2.34746790e-03
 -2.61381616e+00 -2.18726640e-02  2.93920923e+00  7.55390266e-01
  2.08330000e-02 -2.08330000e-02 -1.82395484e-02  2.28485912e-03
  8.30000000e-01 -3.85554035e-01  0.00000000e+00  0.00000000e+00
  9.22685258e-01  6.48805023e-02 -2.86982265e-02  8.35000000e-01
 -9.36981545e-01  0.00000000e+00  0.00000000e+00  3.49378856e-01
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00]
(123, 84, 84, 3)
(23,)


In [45]:
string = WS_DIR + 'policy_stack.pt'
torch.save(policy, string)