In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Normal
import numpy as np

In [2]:
class LowLevelPolicy(nn.Module):
    def __init__(self, state_dim=4, goal_dim=4, action_dim=2, hidden_dim=128):
        super(LowLevelPolicy, self).__init__()
        self.rnn = nn.LSTM(state_dim + goal_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, action_dim)

    def forward(self, state, goal):
        input_seq = torch.cat((state, goal), dim=-1)
        out, _ = self.rnn(torch.unsqueeze(input_seq, 0))
        actions = self.fc(out)
        return actions

In [3]:
class GoalProposalVAE(nn.Module):
    def __init__(self, state_dim=4, goal_dim=4, latent_dim=20):
        super(GoalProposalVAE, self).__init__()
        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(state_dim + goal_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU()
        )
        self.fc_mu = nn.Linear(128, latent_dim)
        self.fc_logvar = nn.Linear(128, latent_dim)
        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim + state_dim, 128),
            nn.ReLU(),
            nn.Linear(128, goal_dim)
        )

    def encode(self, x, c):
        h = self.encoder(torch.cat((x, c), dim=0))
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, c):  # P(x|z,c)
        inputs = torch.cat([z, c], 0)
        h3 = self.decoder(inputs)
        return h3

    def forward(self, x, c):
        mu, logvar = self.encode(x, c)
        z = self.reparameterize(mu, logvar)
        decoded = self.decode(z, c)
        return decoded, mu, logvar

    # def sample(self, num_samples, y):
    #     with torch.no_grad():
    #         z = torch.randn(num_samples, self.num_hidden)
    #         samples = self.decoder(self.condition_on_label(z, y))
    #     return samples

In [4]:
class ValueNetwork(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(ValueNetwork, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(state_dim + action_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, state, action):
        return self.fc(torch.cat((state, action), dim=-1))

In [5]:
def cvaeLoss(sg, D, mu, logvar, beta=0.0001):
    recon_loss = torch.nn.functional.mse_loss(D, sg)
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + beta * kl_loss

In [6]:
class ActionProposalVAE(nn.Module):
    def __init__(self, state_dim=4, action_dim=2, latent_dim=20):
        super(ActionProposalVAE, self).__init__()
        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(state_dim + action_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU()
        )
        self.fc_mu = nn.Linear(128, latent_dim)
        self.fc_logvar = nn.Linear(128, latent_dim)
        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim + state_dim, 128),
            nn.ReLU(),
            nn.Linear(128, action_dim)
        )
        # self.label_projector = nn.Sequential(
        #     nn.Linear(state_dim, latent_dim), nn.ReLU())

    def encode(self, x, c):
        h = self.encoder(torch.cat((x, c), dim=0))
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, c):  # P(x|z,c)
        inputs = torch.cat([z, c], 0)
        h3 = self.decoder(inputs)
        return h3

    def forward(self, x, c):
        mu, logvar = self.encode(x, c)
        z = self.reparameterize(mu, logvar)
        decoded = self.decode(z, c)
        return decoded, mu, logvar

In [71]:
import numpy as np

import torch
from torch.utils.data import DataLoader

import robomimic
import robomimic.utils.obs_utils as ObsUtils
import robomimic.utils.torch_utils as TorchUtils
import robomimic.utils.test_utils as TestUtils
import robomimic.utils.file_utils as FileUtils
import robomimic.utils.train_utils as TrainUtils
from robomimic.utils.dataset import SequenceDataset


def get_data_loader(dataset_path, trajectory_length=50, batch_size=64):
    """
    Get a data loader to sample batches of data.
    Args:
        dataset_path (str): path to the dataset hdf5
        trajectory_length (int): number of timesteps in a trajectory
        batch_size (int): batch size for the DataLoader
    """
    dataset = SequenceDataset(
        hdf5_path=dataset_path,
        obs_keys=(                      # observations we want to appear in batches
            "robot0_eef_pos",
            "robot0_eef_quat",
            "robot0_gripper_qpos",
            "object",
            "robot0_gripper_qvel",       # Additional gripper velocity
            "robot0_joint_pos",          # 7DOF joint positions
            "robot0_joint_vel",          # 7DOF joint velocities
        ),
        dataset_keys=(                  # can optionally specify more keys here if they should appear in batches
            "actions",
            "rewards",
            "dones",
            "horizon",                  # Additional metadata
            "episode_id",               # Episode information
        ),
        load_next_obs=True,             # load the next observation in each sequence
        frame_stack=1,
        seq_length=trajectory_length,   # length-10 temporal sequences
        # pad last obs per trajectory to ensure all sequences are sampled
        pad_frame_stack=True,
        pad_seq_length=True,            # pad last observation sequence
        get_pad_mask=False,             # do not return padding masks
        goal_mode=None,
        hdf5_cache_mode="all",          # cache dataset in memory to avoid repeated file i/o
        hdf5_use_swmr=True,
        hdf5_normalize_obs=False,       # normalize observations
        filter_by_attribute=None,       # can optionally provide a filter key here
    )

    print("\n============= Created Dataset =============")
    print(dataset)
    # Print additional dataset information
    print(f"Dataset length: {len(dataset)}")
    print(f"Observation keys: {dataset.obs_keys}")
    print(f"Dataset keys: {dataset.dataset_keys}")
    print("")

    # Set batch size to a manageable number
    data_loader = DataLoader(
        dataset=dataset,
        # no custom sampling logic (uniform sampling)
        sampler=None,
        batch_size=batch_size,          # batches of size 32
        shuffle=True,                   # shuffle the dataset
        # use 0 workers (could be set to more if using multiprocessing)
        num_workers=0,
        drop_last=True                  # drop the last incomplete batch
    )

    return data_loader

In [72]:
data_loader = get_data_loader(
    dataset_path="/home/keerthi/IRIS/lift/ph/low_dim_v141.hdf5")

SequenceDataset: loading dataset into memory...
100%|██████████| 200/200 [00:00<00:00, 451.60it/s]
SequenceDataset: caching get_item calls...
100%|██████████| 9666/9666 [00:02<00:00, 4683.54it/s]

SequenceDataset (
	path=/home/keerthi/IRIS/lift/ph/low_dim_v141.hdf5
	obs_keys=('robot0_eef_pos', 'robot0_eef_quat', 'robot0_gripper_qpos', 'object', 'robot0_gripper_qvel', 'robot0_joint_pos', 'robot0_joint_vel')
	seq_length=50
	filter_key=none
	frame_stack=1
	pad_seq_length=True
	pad_frame_stack=True
	goal_mode=none
	cache_mode=all
	num_demos=200
	num_sequences=9666
)
Dataset length: 9666
Observation keys: ('robot0_eef_pos', 'robot0_eef_quat', 'robot0_gripper_qpos', 'object', 'robot0_gripper_qvel', 'robot0_joint_pos', 'robot0_joint_vel')
Dataset keys: ('actions', 'rewards', 'dones', 'horizon', 'episode_id')



In [73]:
from tqdm import tqdm


def train_IRIS_full_trajectory(low_level_policy, goal_proposal_vae: GoalProposalVAE, action_vae: ActionProposalVAE, value_network, data_loader, M=15, gamma=0.99):
    """
    Train the IRIS algorithm over full trajectories using a DataLoader that provides batches of trajectories.

    Args:
        low_level_policy (nn.Module): Low-level policy network.
        goal_proposal_vae (nn.Module): Goal proposal VAE network.
        action_vae (nn.Module): Action proposal VAE network.
        value_network (nn.Module): Value network.
        data_loader (DataLoader): DataLoader that provides batches of trajectories.
        M (int): Number of samples for action sampling.
        gamma (float): Discount factor for future rewards.
    """

    # Move models to GPU if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    low_level_policy = low_level_policy.to(device)
    goal_proposal_vae = goal_proposal_vae.to(device)
    action_vae = action_vae.to(device)
    value_network = value_network.to(device)

    # Optimizers
    policy_optimizer = optim.Adam(low_level_policy.parameters(), lr=0.0001)
    vae_optimizer = optim.Adam(goal_proposal_vae.parameters(), lr=0.0001)
    value_optimizer = optim.Adam(value_network.parameters(), lr=0.0001)
    action_optimizer = optim.Adam(action_vae.parameters(), lr=0.0001)

    # Statistics for monitoring
    cumulative_rewards = []
    cumulative_policy_loss = []
    cumulative_vae_loss = []
    cumulative_value_loss = []

    # Iterate through batches of trajectories
    for epoch in range(1):
        for batch_idx, trajectory_batch in enumerate(tqdm(data_loader)):
            # Each element in trajectory_batch is a batch of sequences
            batch_size = trajectory_batch['obs']['robot0_eef_pos'].shape[0]

            for i in range(batch_size):
                # Extract the i-th trajectory from the batch
                states = torch.cat([
                    # End-effector position (3,)
                    trajectory_batch['obs']['robot0_eef_pos'][i],
                    # Object-related observation (10,)
                    trajectory_batch['obs']['object'][i],
                    # End-effector orientation (4,)
                    trajectory_batch['obs']['robot0_eef_quat'][i],
                    # Gripper position (2,)
                    trajectory_batch['obs']['robot0_gripper_qpos'][i],
                    # Gripper velocity (2,)
                    trajectory_batch['obs']['robot0_gripper_qvel'][i],
                    # 7DOF joint positions (7,)
                    trajectory_batch['obs']['robot0_joint_pos'][i],
                    # 7DOF joint velocities (7,)
                    trajectory_batch['obs']['robot0_joint_vel'][i]
                ], axis=1).to(device)  # Concatenate along the feature axis

                states = torch.squeeze(states.type(torch.float32))

                actions = trajectory_batch['actions'][i].to(device)
                actions = torch.squeeze(actions.type(torch.float32))

                rewards = trajectory_batch['rewards'][i].to(device)
                rewards = torch.squeeze(rewards.type(torch.float32))

                # Remove last timestep for actions (to match states length)
                actions = actions[:-1]

                sg = states[-1]
                s_start = states[0]
                reward_sg = rewards[-2]
                action_last = actions[-2]
                state_second_last = states[-2]

                # Train Low-Level Policy
                policy_actions = [torch.squeeze(low_level_policy(
                    state, sg)) for state in states[:-1]]
                policy_actions = torch.squeeze(torch.stack(policy_actions))
                policy_loss = nn.MSELoss()(policy_actions, actions)

                policy_optimizer.zero_grad()
                policy_loss.backward()
                policy_optimizer.step()

                # VAE updates (Goal and Action VAEs)
                mu, logvar = goal_proposal_vae.encode(sg, s_start)
                z = goal_proposal_vae.reparameterize(mu, logvar)
                vae_loss = nn.MSELoss()(sg, goal_proposal_vae.decode(z, s_start)) + \
                    0.5 * torch.sum(mu.pow(2) + logvar.exp() - logvar - 1)

                mu_a, logvar_a = action_vae.encode(
                    action_last, state_second_last)
                z_a = action_vae.reparameterize(mu_a, logvar_a)
                action_vae_loss = nn.MSELoss()(action_last, action_vae.decode(z_a, state_second_last)) + \
                    0.5 * torch.sum(mu_a.pow(2) +
                                    logvar_a.exp() - logvar_a - 1)

                vae_optimizer.zero_grad()
                action_optimizer.zero_grad()
                (vae_loss + action_vae_loss).backward()
                vae_optimizer.step()
                action_optimizer.step()

                # Sample actions and update value network
                sampled_actions = [action_vae.decode(
                    z_a, sg) for _ in range(M)]
                sampled_actions = torch.stack(sampled_actions)

                values = torch.stack([value_network(sg, a)
                                      for a in sampled_actions])
                max_value = torch.max(values)

                Vbar = torch.squeeze(reward_sg + gamma * max_value.detach())
                value_loss = nn.MSELoss()(Vbar, torch.squeeze(
                    value_network(state_second_last, action_last)))

                value_optimizer.zero_grad()
                value_loss.backward()
                value_optimizer.step()

                # Store losses and rewards for monitoring
                cumulative_rewards.append(reward_sg.item())
                cumulative_policy_loss.append(policy_loss.item())
                cumulative_vae_loss.append(
                    vae_loss.item() + action_vae_loss.item())
                cumulative_value_loss.append(value_loss.item())

            # Print statistics every N batches (you can adjust this interval)
            if (batch_idx+1) % 100 == 0:
                avg_reward = np.mean(cumulative_rewards)
                avg_policy_loss = np.mean(cumulative_policy_loss)
                avg_vae_loss = np.mean(cumulative_vae_loss)
                avg_value_loss = np.mean(cumulative_value_loss)

                print(f"Batch {batch_idx + 1}: Avg Reward: {avg_reward:.4f}, "
                      f"Avg Policy Loss: {avg_policy_loss:.4f}, "
                      f"Avg VAE Loss: {avg_vae_loss:.4f}, "
                      f"Avg Value Loss: {avg_value_loss:.4f}")

                # Reset cumulative stats after printing
                cumulative_rewards = []
                cumulative_policy_loss = []
                cumulative_vae_loss = []
                cumulative_value_loss = []


# Assuming dataset and models are already initialized
state_dim = 35
state_goal_dim = 35
action_dim = 7
latent_dim = 20

low_level_policy = LowLevelPolicy(state_dim, state_goal_dim, action_dim)
goal_proposal_vae = GoalProposalVAE(state_dim, state_goal_dim, latent_dim)
action_vae = ActionProposalVAE(state_dim, action_dim, latent_dim)
value_network = ValueNetwork(state_dim, action_dim)

# Train the IRIS algorithm using the D4RL dataset
train_IRIS_full_trajectory(low_level_policy, goal_proposal_vae,
                           action_vae, value_network, data_loader)

 11%|█         | 16/151 [00:49<06:55,  3.08s/it]


KeyboardInterrupt: 

In [33]:
import robosuite
from robosuite.controllers import load_controller_config

# load default controller parameters for Operational Space Control (OSC)
controller_config = load_controller_config(default_controller="OSC_POSE")



In [65]:

def extract_features(state):
    # Extract and concatenate the desired features from the state
    features = torch.cat([
        # End-effector position (3,)
        torch.tensor(state['robot0_eef_pos'], dtype=torch.float32),
        # Object-related observation (10,)
        torch.tensor(state['object-state'], dtype=torch.float32),
        # End-effector orientation (4,)
        torch.tensor(state['robot0_eef_quat'], dtype=torch.float32),
        # Gripper position (2,)
        torch.tensor(state['robot0_gripper_qpos'], dtype=torch.float32),
        # Gripper velocity (2,)
        torch.tensor(state['robot0_gripper_qvel'], dtype=torch.float32),
        # 7DOF joint positions (7,)
        torch.tensor(
            np.arccos(state['robot0_joint_pos_cos']), dtype=torch.float32),
        # 7DOF joint velocities (7,)
        torch.tensor(state['robot0_joint_vel'], dtype=torch.float32)
    ], dim=0)  # Concatenate into a single 1D tensor

    return features

In [66]:
import numpy as np
import robosuite as suite
import cv2
# create environment instance
env = suite.make(
    env_name="Lift",  # try with other tasks like "Stack" and "Door"
    robots="Sawyer",  # try with other robots like "Sawyer" and "Jaco"
    has_renderer=True,
    has_offscreen_renderer=False,
    use_camera_obs=False,
)

# height, width, _ = env.render().shape
# fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # Codec for MP4 video
# video_writer = cv2.VideoWriter(
# "env_policy_video_lift.mp4", fourcc, 30, (width, height))
goal = extract_features(env.reset())
state = extract_features(env.reset())
value_function = value_network.to("cpu")
goal_proposal_vae.to("cpu")
action_vae.to("cpu")
low_level_policy.to("cpu")

for step in range(1000):
    # Render the environment and capture the frame
    frame = state
    # Write frame to video
    # video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))

    goal_tensor = torch.tensor(goal, dtype=torch.float32)
    # VAE update
    state_tensor = torch.tensor(state, dtype=torch.float32)
    mu, logvar = goal_proposal_vae.encode(goal_tensor, state_tensor)
    z = goal_proposal_vae.reparameterize(mu, logvar)
    goal = goal_proposal_vae.decode(z, state_tensor)
    goal_final = goal
    value = value_function(goal, torch.squeeze(
        low_level_policy(state_tensor, goal)))
    for k in range(5):
        z = goal_proposal_vae.reparameterize(mu, logvar)
        goal = goal_proposal_vae.decode(z, state_tensor)
        if value_function(goal, torch.squeeze(low_level_policy(state_tensor, goal))) > value:
            goal_final = goal
            value = value_function(
                goal, torch.squeeze(low_level_policy(state_tensor, goal)))
    goal = goal_final
    next_state = None
    # Get the action from the low-level policy
    reward = 0
    for _ in range(2):

        state_tensor = torch.tensor(state, dtype=torch.float32)
        action = low_level_policy(
            state_tensor, goal).detach().numpy()
        # video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
        action = np.squeeze(action)
        next_state, reward, done, _ = env.step(
            np.concatenate((action, np.array([action[6]]))))
        # frame = env.render(mode="rgb_array")
        env.render()
        # video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
        state = extract_features(next_state)
        if reward > 0:
            print(
                f"Episode {step+1}, Step {step+1}: Goal reached!")
            break
        if done:
            break  # Terminate the episode if done is True
    if done:
        break
    if reward > 0:
        break
# for i in range(10):     frame = env.render()
#     state = env.reset()
#     goal = state
#     state=extract_features(state)
#     action = np.random.randn(env.robots[0].dof)  # sample random action
#     # take action in the environment
#     obs, reward, done, info = env.step(action)
#     env.render()  # render on display

  goal_tensor = torch.tensor(goal, dtype=torch.float32)
  state_tensor = torch.tensor(state, dtype=torch.float32)
  state_tensor = torch.tensor(state, dtype=torch.float32)
