In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Normal
import numpy as np

In [2]:
class LowLevelPolicy(nn.Module):
    def __init__(self, state_dim=4, goal_dim=4, action_dim=2, hidden_dim=512):
        super(LowLevelPolicy, self).__init__()
        self.rnn = nn.LSTM(state_dim + goal_dim, hidden_dim,
                           num_layers=4, batch_first=True)
        self.fc = nn.Linear(hidden_dim, action_dim)

    def forward(self, state, goal):
        input_seq = torch.cat((state, goal), dim=-1)
        out, _ = self.rnn(torch.unsqueeze(input_seq, 0))
        actions = self.fc(out)
        return actions

In [3]:
import torch
import torch.nn as nn


class ImprovedLowLevelPolicy(nn.Module):
    def __init__(self, state_dim=4, goal_dim=4, action_dim=2, hidden_dim=512):
        super(ImprovedLowLevelPolicy, self).__init__()

        # Bidirectional LSTM with layer normalization
        self.rnn = nn.LSTM(state_dim + goal_dim, hidden_dim,
                           num_layers=4, batch_first=True, bidirectional=True)
        self.fc = nn.Sequential(
            nn.LayerNorm(hidden_dim * 2),  # for bidirectional LSTM
            nn.Linear(hidden_dim * 2, action_dim),
            nn.Tanh()  # Tanh to keep action outputs bounded between -1 and 1
        )

    def forward(self, state, goal):
        input_seq = torch.cat((state, goal), dim=-1)

        # Pass through LSTM
        out, _ = self.rnn(input_seq.unsqueeze(0))

        # Pass through the fully connected layer with tanh activation
        actions = self.fc(out.squeeze(0))

        return actions

In [4]:
class GoalProposalVAE(nn.Module):
    def __init__(self, state_dim=4, goal_dim=4, latent_dim=32):
        super(GoalProposalVAE, self).__init__()
        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(state_dim + goal_dim, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU()
        )
        self.fc_mu = nn.Linear(1024, latent_dim)
        self.fc_logvar = nn.Linear(1024, latent_dim)
        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim + state_dim, 1024),
            nn.ReLU(),
            nn.Linear(1024, goal_dim)
        )

    def encode(self, x, c):
        h = self.encoder(torch.cat((x, c), dim=0))
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, c):  # P(x|z,c)
        inputs = torch.cat([z, c], 0)
        h3 = self.decoder(inputs)
        return h3

    def forward(self, x, c):
        mu, logvar = self.encode(x, c)
        z = self.reparameterize(mu, logvar)
        decoded = self.decode(z, c)
        return decoded, mu, logvar

    # def sample(self, num_samples, y):
    #     with torch.no_grad():
    #         z = torch.randn(num_samples, self.num_hidden)
    #         samples = self.decoder(self.condition_on_label(z, y))
    #     return samples

In [5]:
import torch
import torch.nn as nn


class ValueNetwork(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(ValueNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim + action_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 128)
        self.fc4 = nn.Linear(128, 64)
        self.output = nn.Linear(64, 1)

        # Normalization layers
        self.norm1 = nn.LayerNorm(256)
        self.norm2 = nn.LayerNorm(256)
        self.norm3 = nn.LayerNorm(128)

        # Dropout layer
        self.dropout = nn.Dropout(p=0.2)

    def forward(self, state, action):
        x = torch.cat((state, action), dim=-1)

        # First layer with normalization and activation
        x = nn.functional.relu(self.norm1(self.fc1(x)))

        # Second layer with dropout and normalization
        x = self.dropout(nn.functional.relu(self.norm2(self.fc2(x))))

        # Third layer with normalization and activation
        x = nn.functional.relu(self.norm3(self.fc3(x)))

        # Fourth layer without normalization
        x = nn.functional.relu(self.fc4(x))

        # Output layer
        return self.output(x)

In [6]:
def cvaeLoss(sg, D, mu, logvar, beta=0.001):
    recon_loss = torch.nn.functional.mse_loss(D, sg)
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + beta * kl_loss

In [7]:
class ActionProposalVAE(nn.Module):
    def __init__(self, state_dim=4, action_dim=2, latent_dim=32):
        super(ActionProposalVAE, self).__init__()
        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(state_dim + action_dim, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU()
        )
        self.fc_mu = nn.Linear(1024, latent_dim)
        self.fc_logvar = nn.Linear(1024, latent_dim)
        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim + state_dim, 1024),
            nn.ReLU(),
            nn.Linear(1024, action_dim)
        )
        # self.label_projector = nn.Sequential(
        #     nn.Linear(state_dim, latent_dim), nn.ReLU())

    def encode(self, x, c):
        h = self.encoder(torch.cat((x, c), dim=0))
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, c):  # P(x|z,c)
        inputs = torch.cat([z, c], 0)
        h3 = self.decoder(inputs)
        return h3

    def forward(self, x, c):
        mu, logvar = self.encode(x, c)
        z = self.reparameterize(mu, logvar)
        decoded = self.decode(z, c)
        return decoded, mu, logvar

In [8]:
import numpy as np

import torch
from torch.utils.data import DataLoader

import robomimic
import robomimic.utils.obs_utils as ObsUtils
import robomimic.utils.torch_utils as TorchUtils
import robomimic.utils.test_utils as TestUtils
import robomimic.utils.file_utils as FileUtils
import robomimic.utils.train_utils as TrainUtils
from robomimic.utils.dataset import SequenceDataset


def get_data_loader(dataset_path, trajectory_length=8, batch_size=1):
    """
    Get a data loader to sample batches of data.
    Args:
        dataset_path (str): path to the dataset hdf5
        trajectory_length (int): number of timesteps in a trajectory
        batch_size (int): batch size for the DataLoader
    """
    """
    {'cube_pos',
 'cube_quat',
 'gripper_to_cube_pos',
 'robot0_eef_pos',
 'robot0_eef_quat',
 'robot0_gripper_qpos',
 'robot0_gripper_qvel',
 'robot0_joint_pos_cos',
 'robot0_joint_pos_sin',
 'robot0_joint_vel'}
 """
    dataset = SequenceDataset(
        hdf5_path=dataset_path,
        obs_keys=("object",                # observations we want to appear in batches
                  "robot0_eef_pos",
                  "robot0_eef_quat",
                  "robot0_gripper_qpos",
                  "robot0_gripper_qvel",       # Additional gripper velocity
                  # "robot0_joint_pos",          # 7DOF joint positions
                  "robot0_joint_pos_cos",
                  "robot0_joint_pos_sin",
                  "robot0_joint_vel",          # 7DOF joint velocities
                  ),
        dataset_keys=(                  # can optionally specify more keys here if they should appear in batches
            "actions",
            "rewards",
            "dones",
            "horizon",                  # Additional metadata
            "episode_id",               # Episode information
        ),
        load_next_obs=True,             # load the next observation in each sequence
        frame_stack=1,
        seq_length=trajectory_length,   # length-10 temporal sequences
        # pad last obs per trajectory to ensure all sequences are sampled
        pad_frame_stack=True,
        pad_seq_length=True,            # pad last observation sequence
        get_pad_mask=False,             # do not return padding masks
        goal_mode=None,
        hdf5_cache_mode="all",          # cache dataset in memory to avoid repeated file i/o
        hdf5_use_swmr=True,
        hdf5_normalize_obs=False,       # normalize observations
        filter_by_attribute=None,       # can optionally provide a filter key here
    )

    print("\n============= Created Dataset =============")
    print(dataset)
    # Print additional dataset information
    print(f"Dataset length: {len(dataset)}")
    print(f"Observation keys: {dataset.obs_keys}")
    print(f"Dataset keys: {dataset.dataset_keys}")
    print("")

    # Set batch size to a manageable number
    data_loader = DataLoader(
        dataset=dataset,
        # no custom sampling logic (uniform sampling)
        sampler=None,
        batch_size=batch_size,          # batches of size 32
        shuffle=False,                   # shuffle the dataset
        # use 0 workers (could be set to more if using multiprocessing)
        num_workers=0,
        drop_last=True                  # drop the last incomplete batch
    )

    return data_loader

    No private macro file found!
    It is recommended to use a private macro file
    To setup, run: python /home/keerthi/IRIS/robomimic/scripts/setup_macros.py
)[0m


In [9]:
data_loader = get_data_loader(
    dataset_path="/home/keerthi/IRIS/lift/ph/low_dim_v141.hdf5")

SequenceDataset: loading dataset into memory...
  0%|          | 0/200 [00:00<?, ?it/s]

100%|██████████| 200/200 [00:00<00:00, 364.56it/s]
SequenceDataset: caching get_item calls...
100%|██████████| 9666/9666 [00:01<00:00, 8056.65it/s]

SequenceDataset (
	path=/home/keerthi/IRIS/lift/ph/low_dim_v141.hdf5
	obs_keys=('object', 'robot0_eef_pos', 'robot0_eef_quat', 'robot0_gripper_qpos', 'robot0_gripper_qvel', 'robot0_joint_pos_cos', 'robot0_joint_pos_sin', 'robot0_joint_vel')
	seq_length=8
	filter_key=none
	frame_stack=1
	pad_seq_length=True
	pad_frame_stack=True
	goal_mode=none
	cache_mode=all
	num_demos=200
	num_sequences=9666
)
Dataset length: 9666
Observation keys: ('object', 'robot0_eef_pos', 'robot0_eef_quat', 'robot0_gripper_qpos', 'robot0_gripper_qvel', 'robot0_joint_pos_cos', 'robot0_joint_pos_sin', 'robot0_joint_vel')
Dataset keys: ('actions', 'rewards', 'dones', 'horizon', 'episode_id')



In [10]:
device = "cuda"

In [11]:
goal_state_config = None

# Iterate over the DataLoader to get batches of trajectories
c = 0
for batch in data_loader:
    batch_size = batch['obs']['robot0_eef_pos'].shape[0]

    # Iterate over each trajectory in the batch
    for i in range(batch_size):
        rewards = batch['rewards'][i]  # Rewards for the i-th trajectory

        # Iterate over timesteps within the trajectory
        for j in range(rewards.shape[0]):
            if rewards[j] == 1:  # Check if the reward at timestep `j` is 1
                # Extract the state configuration at the timestep `j`
                c += 1
                if c == 1:
                    continue
                state_config = torch.cat([
                    # End-effector position (3,)
                    batch['obs']['robot0_eef_pos'][i, j],
                    # Object-related observation (10,)
                    batch['obs']['object'][i, j],
                    # End-effector orientation (4,)
                    batch['obs']['robot0_eef_quat'][i, j],
                    # Gripper position (2,)
                    # batch['obs']['robot0_gripper_qpos'][i, j],
                    # Gripper velocity (2,)
                    # batch['obs']['robot0_gripper_qvel'][i, j],
                    # 7DOF joint positions (7,)
                    # batch['obs']['robot0_joint_pos_cos'][i, j],
                    # batch['obs']['robot0_joint_pos_sin'][i, j],
                    # 7DOF joint velocities (7,)
                    # batch['obs']['robot0_joint_vel'][i, j]
                ], axis=0).to(device)  # Concatenate along the feature axis

                # Print or save the state configuration when goal is attained
                print(
                    f"State configuration at batch {i}, timestep {j} where reward is 1:")
                print(state_config)
                goal_state_config = state_config  # Save the found state configuration

                # Break out of the inner loop once the goal state is found
                break
        if goal_state_config is not None:
            # Break out of the outer loop if a goal state is found
            break
    if goal_state_config is not None:
        # Exit the loop once a goal is found in the batch
        break

# Print the final goal state configuration if found
if goal_state_config is not None:
    print("Goal state configuration found:", goal_state_config)
else:
    print("No goal state configuration with a reward of 1 found in the dataset.")

State configuration at batch 0, timestep 6 where reward is 1:
tensor([ 0.0298,  0.0221,  0.8355,  0.0237,  0.0270,  0.8314, -0.0080, -0.0045,
         0.9694,  0.2452,  0.0061, -0.0049,  0.0041,  0.9862, -0.1468,  0.0732,
         0.0202], device='cuda:0', dtype=torch.float64)
Goal state configuration found: tensor([ 0.0298,  0.0221,  0.8355,  0.0237,  0.0270,  0.8314, -0.0080, -0.0045,
         0.9694,  0.2452,  0.0061, -0.0049,  0.0041,  0.9862, -0.1468,  0.0732,
         0.0202], device='cuda:0', dtype=torch.float64)


In [12]:
goal_state_config.shape

torch.Size([17])

In [13]:
cube_pos = goal_state_config[3:6]
cube_pos

tensor([0.0237, 0.0270, 0.8314], device='cuda:0', dtype=torch.float64)

In [14]:
cube_orient = goal_state_config[6:10]
cube_orient

tensor([-0.0080, -0.0045,  0.9694,  0.2452], device='cuda:0',
       dtype=torch.float64)

In [15]:
state_dim = 17
state_goal_dim = 17
action_dim = 7
latent_dim = 64

In [16]:
low_level_policy = LowLevelPolicy(
    state_dim, state_goal_dim, action_dim)

In [17]:
# Assuming dataset and models are already initialized


goal_proposal_vae = GoalProposalVAE(state_dim, state_goal_dim, latent_dim)
action_vae = ActionProposalVAE(state_dim, action_dim, latent_dim)

In [18]:
value_network = ValueNetwork(state_dim, action_dim)

In [34]:
from tqdm import tqdm


def train_IRIS_full_trajectory(low_level_policy, goal_proposal_vae: GoalProposalVAE, action_vae: ActionProposalVAE, value_network, data_loader, M=8, gamma=0.99):
    """
    Train the IRIS algorithm over full trajectories using a DataLoader that provides batches of trajectories.

    Args:
        low_level_policy (nn.Module): Low-level policy network.
        goal_proposal_vae (nn.Module): Goal proposal VAE network.
        action_vae (nn.Module): Action proposal VAE network.
        value_network (nn.Module): Value network.
        data_loader (DataLoader): DataLoader that provides batches of trajectories.
        M (int): Number of samples for action sampling.
        gamma (float): Discount factor for future rewards.
    """

    # Move models to GPU if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    low_level_policy = low_level_policy.to(device)
    goal_proposal_vae = goal_proposal_vae.to(device)
    action_vae = action_vae.to(device)
    value_network = value_network.to(device)

    # Optimizers
    policy_optimizer = optim.Adam(low_level_policy.parameters(), lr=1e-4)
    vae_optimizer = optim.Adam(goal_proposal_vae.parameters(), lr=1e-4)
    value_optimizer = optim.Adam(value_network.parameters(), lr=1e-4)
    action_optimizer = optim.Adam(action_vae.parameters(), lr=1e-4)

    # Statistics for monitoring
    cumulative_rewards = []
    cumulative_policy_loss = []
    cumulative_vae_loss = []
    cumulative_value_loss = []

    # Iterate through batches of trajectories
    for epoch in range(1):
        for batch_idx, trajectory_batch in enumerate(tqdm(data_loader)):
            # Each element in trajectory_batch is a batch of sequences
            batch_size = trajectory_batch['obs']['robot0_eef_pos'].shape[0]
            for i in range(batch_size):
                # Extract the i-th trajectory from the batch
                states = torch.cat([
                    # End-effector position (3,)
                    trajectory_batch['obs']['robot0_eef_pos'][i],
                    # Object-related observation (10,)
                    trajectory_batch['obs']['object'][i],
                    # End-effector orientation (4,)
                    trajectory_batch['obs']['robot0_eef_quat'][i],
                    # Gripper position (2,)
                    # trajectory_batch['obs']['robot0_gripper_qpos'][i],
                    # Gripper velocity (2,)
                    # trajectory_batch['obs']['robot0_gripper_qvel'][i],
                    # 7DOF joint positions (7,)
                    # trajectory_batch['obs']['robot0_joint_pos_cos'][i],
                    # trajectory_batch['obs']['robot0_joint_pos_sin'][i],
                    # 7DOF joint velocities (7,)
                    # trajectory_batch['obs']['robot0_joint_vel'][i]
                ], axis=1).to(device)  # Concatenate along the feature axis
                # print(states.shape)
                # states = torch.cat([
                #     trajectory_batch['obs']['object'][i],
                # ], axis=1).to(device)  # Concatenate along the feature axis
                states = torch.squeeze(states.type(torch.float32))

                actions = trajectory_batch['actions'][i].to(device)
                actions = torch.squeeze(actions.type(torch.float32))

                rewards = trajectory_batch['rewards'][i].to(device)
                rewards = torch.squeeze(rewards.type(torch.float32))

                # Remove last timestep for actions (to match states length)
                actions = actions[:-1]

                sg = states[-1]
                s_start = states[0]
                reward_sg = rewards[-2]
                action_last = actions[-2]
                state_second_last = states[-2]

                # Train Low-Level Policy
                policy_actions = [torch.squeeze(low_level_policy(
                    state, sg)) for state in states[:-1]]
                policy_actions = torch.squeeze(torch.stack(policy_actions))
                policy_loss = nn.MSELoss()(policy_actions, actions)

                policy_optimizer.zero_grad()
                policy_loss.backward()
                policy_optimizer.step()

                # VAE updates (Goal and Action VAEs)
                mu, logvar = goal_proposal_vae.encode(sg, s_start)
                z = goal_proposal_vae.reparameterize(mu, logvar)
                vae_loss = nn.MSELoss()(sg, goal_proposal_vae.decode(z, s_start)) + \
                    0.5 * torch.sum(mu.pow(2) + logvar.exp() - logvar - 1)

                mu_a, logvar_a = action_vae.encode(
                    action_last, state_second_last)
                z_a = action_vae.reparameterize(mu_a, logvar_a)
                action_vae_loss = nn.MSELoss()(action_last, action_vae.decode(z_a, state_second_last)) + \
                    0.5 * torch.sum(mu_a.pow(2) +
                                    logvar_a.exp() - logvar_a - 1)

                vae_optimizer.zero_grad()
                action_optimizer.zero_grad()
                (vae_loss + action_vae_loss).backward()
                vae_optimizer.step()
                action_optimizer.step()

                # Sample actions and update value network
                sampled_actions = [action_vae.decode(
                    z_a, sg) for _ in range(M)]
                sampled_actions = torch.stack(sampled_actions)

                values = torch.stack([value_network(sg, a)
                                      for a in sampled_actions])
                max_value = torch.max(values)

                Vbar = torch.squeeze(reward_sg + gamma * max_value.detach())
                value_loss = nn.MSELoss()(Vbar, torch.squeeze(
                    value_network(state_second_last, action_last)))

                value_optimizer.zero_grad()
                value_loss.backward()
                value_optimizer.step()

                # Store losses and rewards for monitoring
                cumulative_rewards.append(reward_sg.item())
                cumulative_policy_loss.append(policy_loss.item())
                cumulative_vae_loss.append(
                    vae_loss.item() + action_vae_loss.item())
                cumulative_value_loss.append(value_loss.item())

            # Print statistics every N batches (you can adjust this interval)
            if (batch_idx+1) % 100 == 0:
                avg_reward = np.mean(cumulative_rewards)
                avg_policy_loss = np.mean(cumulative_policy_loss)
                avg_vae_loss = np.mean(cumulative_vae_loss)
                avg_value_loss = np.mean(cumulative_value_loss)

                print(f"Batch {batch_idx + 1}: Avg Reward: {avg_reward:.4f}, "
                      f"Avg Policy Loss: {avg_policy_loss:.4f}, "
                      f"Avg VAE Loss: {avg_vae_loss:.4f}, "
                      f"Avg Value Loss: {avg_value_loss:.4f}")

                # Reset cumulative stats after printing
                cumulative_rewards = []
                cumulative_policy_loss = []
                cumulative_vae_loss = []
                cumulative_value_loss = []


# Train the IRIS algorithm using the D4RL dataset
train_IRIS_full_trajectory(low_level_policy, goal_proposal_vae,
                           action_vae, value_network, data_loader)

  1%|          | 104/9666 [00:02<04:06, 38.73it/s]

Batch 100: Avg Reward: 0.1100, Avg Policy Loss: 0.0963, Avg VAE Loss: 0.1495, Avg Value Loss: 0.0158


  2%|▏         | 205/9666 [00:05<04:09, 37.90it/s]

Batch 200: Avg Reward: 0.2200, Avg Policy Loss: 0.0952, Avg VAE Loss: 0.1568, Avg Value Loss: 0.0295


  3%|▎         | 303/9666 [00:08<04:47, 32.60it/s]

Batch 300: Avg Reward: 0.2200, Avg Policy Loss: 0.0849, Avg VAE Loss: 0.1612, Avg Value Loss: 0.0698


  4%|▍         | 407/9666 [00:10<04:02, 38.13it/s]

Batch 400: Avg Reward: 0.2200, Avg Policy Loss: 0.0464, Avg VAE Loss: 0.1490, Avg Value Loss: 0.0381


  5%|▌         | 507/9666 [00:13<04:01, 37.88it/s]

Batch 500: Avg Reward: 0.2300, Avg Policy Loss: 0.0831, Avg VAE Loss: 0.1716, Avg Value Loss: 0.0335


  6%|▋         | 607/9666 [00:16<04:01, 37.53it/s]

Batch 600: Avg Reward: 0.2200, Avg Policy Loss: 0.0475, Avg VAE Loss: 0.1437, Avg Value Loss: 0.0445


  7%|▋         | 707/9666 [00:18<03:49, 39.06it/s]

Batch 700: Avg Reward: 0.2200, Avg Policy Loss: 0.0320, Avg VAE Loss: 0.1406, Avg Value Loss: 0.0311


  8%|▊         | 807/9666 [00:21<03:45, 39.31it/s]

Batch 800: Avg Reward: 0.2200, Avg Policy Loss: 0.0693, Avg VAE Loss: 0.1588, Avg Value Loss: 0.0269


  9%|▉         | 907/9666 [00:23<03:51, 37.77it/s]

Batch 900: Avg Reward: 0.2200, Avg Policy Loss: 0.0827, Avg VAE Loss: 0.1653, Avg Value Loss: 0.0332


 10%|█         | 1004/9666 [00:26<03:42, 38.92it/s]

Batch 1000: Avg Reward: 0.2200, Avg Policy Loss: 0.0715, Avg VAE Loss: 0.1435, Avg Value Loss: 0.0185


 11%|█▏        | 1105/9666 [00:28<03:39, 39.08it/s]

Batch 1100: Avg Reward: 0.2200, Avg Policy Loss: 0.0471, Avg VAE Loss: 0.1524, Avg Value Loss: 0.0180


 12%|█▏        | 1205/9666 [00:31<03:36, 39.05it/s]

Batch 1200: Avg Reward: 0.2200, Avg Policy Loss: 0.0575, Avg VAE Loss: 0.1640, Avg Value Loss: 0.0172


 14%|█▎        | 1307/9666 [00:34<03:33, 39.14it/s]

Batch 1300: Avg Reward: 0.2200, Avg Policy Loss: 0.0266, Avg VAE Loss: 0.1240, Avg Value Loss: 0.0214


 15%|█▍        | 1407/9666 [00:36<03:41, 37.26it/s]

Batch 1400: Avg Reward: 0.2200, Avg Policy Loss: 0.0724, Avg VAE Loss: 0.1640, Avg Value Loss: 0.0330


 16%|█▌        | 1505/9666 [00:39<03:31, 38.64it/s]

Batch 1500: Avg Reward: 0.2200, Avg Policy Loss: 0.0353, Avg VAE Loss: 0.1479, Avg Value Loss: 0.0220


 17%|█▋        | 1605/9666 [00:42<03:34, 37.60it/s]

Batch 1600: Avg Reward: 0.2500, Avg Policy Loss: 0.0482, Avg VAE Loss: 0.1323, Avg Value Loss: 0.0245


 18%|█▊        | 1705/9666 [00:44<03:31, 37.57it/s]

Batch 1700: Avg Reward: 0.1900, Avg Policy Loss: 0.0692, Avg VAE Loss: 0.1267, Avg Value Loss: 0.0246


 19%|█▊        | 1806/9666 [00:47<03:36, 36.37it/s]

Batch 1800: Avg Reward: 0.2200, Avg Policy Loss: 0.0556, Avg VAE Loss: 0.1263, Avg Value Loss: 0.0301


 20%|█▉        | 1906/9666 [00:50<03:20, 38.72it/s]

Batch 1900: Avg Reward: 0.2200, Avg Policy Loss: 0.0946, Avg VAE Loss: 0.1686, Avg Value Loss: 0.0301


 21%|██        | 2006/9666 [00:52<03:22, 37.88it/s]

Batch 2000: Avg Reward: 0.2200, Avg Policy Loss: 0.0697, Avg VAE Loss: 0.1352, Avg Value Loss: 0.0251


 22%|██▏       | 2106/9666 [00:55<03:17, 38.36it/s]

Batch 2100: Avg Reward: 0.3200, Avg Policy Loss: 0.0650, Avg VAE Loss: 0.1706, Avg Value Loss: 0.0266


 23%|██▎       | 2206/9666 [00:58<03:18, 37.50it/s]

Batch 2200: Avg Reward: 0.2300, Avg Policy Loss: 0.0981, Avg VAE Loss: 0.1662, Avg Value Loss: 0.0219


 24%|██▍       | 2306/9666 [01:00<03:18, 37.05it/s]

Batch 2300: Avg Reward: 0.2200, Avg Policy Loss: 0.0377, Avg VAE Loss: 0.1354, Avg Value Loss: 0.0280


 25%|██▍       | 2406/9666 [01:03<03:11, 37.99it/s]

Batch 2400: Avg Reward: 0.2200, Avg Policy Loss: 0.0514, Avg VAE Loss: 0.1487, Avg Value Loss: 0.0188


 26%|██▌       | 2506/9666 [01:06<03:12, 37.24it/s]

Batch 2500: Avg Reward: 0.2200, Avg Policy Loss: 0.0434, Avg VAE Loss: 0.1555, Avg Value Loss: 0.0211


 27%|██▋       | 2607/9666 [01:08<03:06, 37.75it/s]

Batch 2600: Avg Reward: 0.2300, Avg Policy Loss: 0.1280, Avg VAE Loss: 0.1429, Avg Value Loss: 0.0328


 28%|██▊       | 2707/9666 [01:11<03:13, 35.89it/s]

Batch 2700: Avg Reward: 0.2200, Avg Policy Loss: 0.0697, Avg VAE Loss: 0.1348, Avg Value Loss: 0.0475


 29%|██▉       | 2807/9666 [01:14<03:08, 36.35it/s]

Batch 2800: Avg Reward: 0.2500, Avg Policy Loss: 0.0522, Avg VAE Loss: 0.1418, Avg Value Loss: 0.0424


 30%|███       | 2907/9666 [01:17<03:08, 35.93it/s]

Batch 2900: Avg Reward: 0.3000, Avg Policy Loss: 0.0396, Avg VAE Loss: 0.1423, Avg Value Loss: 0.0258


 31%|███       | 3007/9666 [01:19<02:58, 37.30it/s]

Batch 3000: Avg Reward: 0.2200, Avg Policy Loss: 0.0432, Avg VAE Loss: 0.1189, Avg Value Loss: 0.0486


 32%|███▏      | 3107/9666 [01:22<02:49, 38.63it/s]

Batch 3100: Avg Reward: 0.2600, Avg Policy Loss: 0.0818, Avg VAE Loss: 0.1597, Avg Value Loss: 0.0694


 33%|███▎      | 3207/9666 [01:25<02:51, 37.59it/s]

Batch 3200: Avg Reward: 0.2900, Avg Policy Loss: 0.0831, Avg VAE Loss: 0.1583, Avg Value Loss: 0.0387


 34%|███▍      | 3307/9666 [01:27<02:43, 38.95it/s]

Batch 3300: Avg Reward: 0.2200, Avg Policy Loss: 0.0563, Avg VAE Loss: 0.1337, Avg Value Loss: 0.0220


 35%|███▌      | 3407/9666 [01:30<02:47, 37.29it/s]

Batch 3400: Avg Reward: 0.2300, Avg Policy Loss: 0.0692, Avg VAE Loss: 0.1240, Avg Value Loss: 0.0394


 36%|███▋      | 3507/9666 [01:33<02:44, 37.35it/s]

Batch 3500: Avg Reward: 0.3300, Avg Policy Loss: 0.0470, Avg VAE Loss: 0.1379, Avg Value Loss: 0.0372


 37%|███▋      | 3607/9666 [01:35<02:42, 37.29it/s]

Batch 3600: Avg Reward: 0.2200, Avg Policy Loss: 0.0540, Avg VAE Loss: 0.1654, Avg Value Loss: 0.0290


 38%|███▊      | 3706/9666 [01:38<02:44, 36.24it/s]

Batch 3700: Avg Reward: 0.3300, Avg Policy Loss: 0.0322, Avg VAE Loss: 0.1323, Avg Value Loss: 0.0461


 39%|███▉      | 3806/9666 [01:41<02:38, 36.88it/s]

Batch 3800: Avg Reward: 0.2200, Avg Policy Loss: 0.0611, Avg VAE Loss: 0.1221, Avg Value Loss: 0.0432


 40%|████      | 3906/9666 [01:44<02:37, 36.66it/s]

Batch 3900: Avg Reward: 0.2200, Avg Policy Loss: 0.0874, Avg VAE Loss: 0.1391, Avg Value Loss: 0.0506


 41%|████▏     | 4006/9666 [01:47<02:30, 37.67it/s]

Batch 4000: Avg Reward: 0.2200, Avg Policy Loss: 0.0298, Avg VAE Loss: 0.1149, Avg Value Loss: 0.1046


 42%|████▏     | 4106/9666 [01:49<02:28, 37.52it/s]

Batch 4100: Avg Reward: 0.3100, Avg Policy Loss: 0.0828, Avg VAE Loss: 0.1430, Avg Value Loss: 0.0972


 44%|████▎     | 4206/9666 [01:52<02:28, 36.68it/s]

Batch 4200: Avg Reward: 0.2400, Avg Policy Loss: 0.0683, Avg VAE Loss: 0.1323, Avg Value Loss: 0.1258


 45%|████▍     | 4306/9666 [01:55<02:28, 36.18it/s]

Batch 4300: Avg Reward: 0.3300, Avg Policy Loss: 0.0507, Avg VAE Loss: 0.1614, Avg Value Loss: 0.1309


 46%|████▌     | 4406/9666 [01:58<02:25, 36.17it/s]

Batch 4400: Avg Reward: 0.2100, Avg Policy Loss: 0.0505, Avg VAE Loss: 0.1213, Avg Value Loss: 0.1291


 47%|████▋     | 4506/9666 [02:01<02:26, 35.14it/s]

Batch 4500: Avg Reward: 0.2300, Avg Policy Loss: 0.1283, Avg VAE Loss: 0.1795, Avg Value Loss: 0.0811


 48%|████▊     | 4606/9666 [02:03<02:18, 36.56it/s]

Batch 4600: Avg Reward: 0.1500, Avg Policy Loss: 0.0929, Avg VAE Loss: 0.1160, Avg Value Loss: 0.0797


 49%|████▊     | 4706/9666 [02:06<02:14, 36.76it/s]

Batch 4700: Avg Reward: 0.2900, Avg Policy Loss: 0.0592, Avg VAE Loss: 0.1390, Avg Value Loss: 0.0726


 50%|████▉     | 4806/9666 [02:09<02:12, 36.61it/s]

Batch 4800: Avg Reward: 0.2200, Avg Policy Loss: 0.1084, Avg VAE Loss: 0.1468, Avg Value Loss: 0.1007


 51%|█████     | 4906/9666 [02:12<02:15, 35.15it/s]

Batch 4900: Avg Reward: 0.2200, Avg Policy Loss: 0.1032, Avg VAE Loss: 0.1469, Avg Value Loss: 0.1247


 52%|█████▏    | 5006/9666 [02:15<02:08, 36.12it/s]

Batch 5000: Avg Reward: 0.2000, Avg Policy Loss: 0.0417, Avg VAE Loss: 0.1263, Avg Value Loss: 0.0990


 53%|█████▎    | 5106/9666 [02:17<02:05, 36.24it/s]

Batch 5100: Avg Reward: 0.1300, Avg Policy Loss: 0.1388, Avg VAE Loss: 0.1558, Avg Value Loss: 0.0706


 54%|█████▍    | 5206/9666 [02:20<01:58, 37.75it/s]

Batch 5200: Avg Reward: 0.2700, Avg Policy Loss: 0.0742, Avg VAE Loss: 0.1463, Avg Value Loss: 0.1153


 55%|█████▍    | 5306/9666 [02:23<01:59, 36.44it/s]

Batch 5300: Avg Reward: 0.2800, Avg Policy Loss: 0.0653, Avg VAE Loss: 0.1416, Avg Value Loss: 0.0518


 56%|█████▌    | 5406/9666 [02:26<01:58, 35.84it/s]

Batch 5400: Avg Reward: 0.2200, Avg Policy Loss: 0.0533, Avg VAE Loss: 0.1080, Avg Value Loss: 0.0431


 57%|█████▋    | 5506/9666 [02:28<01:53, 36.58it/s]

Batch 5500: Avg Reward: 0.2200, Avg Policy Loss: 0.0581, Avg VAE Loss: 0.1228, Avg Value Loss: 0.0405


 58%|█████▊    | 5606/9666 [02:31<01:53, 35.74it/s]

Batch 5600: Avg Reward: 0.2300, Avg Policy Loss: 0.1810, Avg VAE Loss: 0.1362, Avg Value Loss: 0.0558


 59%|█████▉    | 5706/9666 [02:34<01:48, 36.50it/s]

Batch 5700: Avg Reward: 0.2200, Avg Policy Loss: 0.0961, Avg VAE Loss: 0.1695, Avg Value Loss: 0.0485


 60%|██████    | 5806/9666 [02:37<01:44, 36.99it/s]

Batch 5800: Avg Reward: 0.1100, Avg Policy Loss: 0.0682, Avg VAE Loss: 0.1329, Avg Value Loss: 0.0566


 61%|██████    | 5906/9666 [02:39<01:43, 36.43it/s]

Batch 5900: Avg Reward: 0.2200, Avg Policy Loss: 0.0401, Avg VAE Loss: 0.1122, Avg Value Loss: 0.0497


 62%|██████▏   | 6006/9666 [02:42<01:52, 32.55it/s]

Batch 6000: Avg Reward: 0.2200, Avg Policy Loss: 0.0517, Avg VAE Loss: 0.1296, Avg Value Loss: 0.0608


 63%|██████▎   | 6106/9666 [02:45<01:39, 35.71it/s]

Batch 6100: Avg Reward: 0.2200, Avg Policy Loss: 0.0354, Avg VAE Loss: 0.1284, Avg Value Loss: 0.0466


 64%|██████▍   | 6206/9666 [02:48<01:55, 30.03it/s]

Batch 6200: Avg Reward: 0.2300, Avg Policy Loss: 0.1030, Avg VAE Loss: 0.1497, Avg Value Loss: 0.0643


 65%|██████▌   | 6306/9666 [02:51<01:34, 35.51it/s]

Batch 6300: Avg Reward: 0.2200, Avg Policy Loss: 0.0936, Avg VAE Loss: 0.1472, Avg Value Loss: 0.0701


 66%|██████▋   | 6406/9666 [02:54<01:30, 35.92it/s]

Batch 6400: Avg Reward: 0.2200, Avg Policy Loss: 0.0595, Avg VAE Loss: 0.1182, Avg Value Loss: 0.0433


 67%|██████▋   | 6506/9666 [02:57<01:27, 36.19it/s]

Batch 6500: Avg Reward: 0.2200, Avg Policy Loss: 0.0616, Avg VAE Loss: 0.1248, Avg Value Loss: 0.0364


 68%|██████▊   | 6606/9666 [02:59<01:21, 37.54it/s]

Batch 6600: Avg Reward: 0.2200, Avg Policy Loss: 0.0894, Avg VAE Loss: 0.1474, Avg Value Loss: 0.0685


 69%|██████▉   | 6706/9666 [03:02<01:20, 36.57it/s]

Batch 6700: Avg Reward: 0.2200, Avg Policy Loss: 0.0436, Avg VAE Loss: 0.1240, Avg Value Loss: 0.0591


 70%|███████   | 6806/9666 [03:05<01:21, 35.28it/s]

Batch 6800: Avg Reward: 0.2200, Avg Policy Loss: 0.0496, Avg VAE Loss: 0.1209, Avg Value Loss: 0.0648


 71%|███████▏  | 6906/9666 [03:08<01:19, 34.77it/s]

Batch 6900: Avg Reward: 0.2200, Avg Policy Loss: 0.0495, Avg VAE Loss: 0.1526, Avg Value Loss: 0.1339


 72%|███████▏  | 7006/9666 [03:11<01:16, 34.74it/s]

Batch 7000: Avg Reward: 0.2200, Avg Policy Loss: 0.0305, Avg VAE Loss: 0.1191, Avg Value Loss: 0.1232


 74%|███████▎  | 7106/9666 [03:14<01:13, 34.83it/s]

Batch 7100: Avg Reward: 0.2200, Avg Policy Loss: 0.0866, Avg VAE Loss: 0.1302, Avg Value Loss: 0.1276


 75%|███████▍  | 7206/9666 [03:16<01:06, 36.77it/s]

Batch 7200: Avg Reward: 0.2200, Avg Policy Loss: 0.0453, Avg VAE Loss: 0.1281, Avg Value Loss: 0.0980


 76%|███████▌  | 7306/9666 [03:19<01:05, 35.97it/s]

Batch 7300: Avg Reward: 0.2200, Avg Policy Loss: 0.0297, Avg VAE Loss: 0.1177, Avg Value Loss: 0.0910


 77%|███████▋  | 7406/9666 [03:22<01:01, 36.90it/s]

Batch 7400: Avg Reward: 0.2200, Avg Policy Loss: 0.1075, Avg VAE Loss: 0.1149, Avg Value Loss: 0.0796


 78%|███████▊  | 7506/9666 [03:25<00:57, 37.88it/s]

Batch 7500: Avg Reward: 0.2200, Avg Policy Loss: 0.0418, Avg VAE Loss: 0.1282, Avg Value Loss: 0.0597


 79%|███████▊  | 7606/9666 [03:27<00:56, 36.73it/s]

Batch 7600: Avg Reward: 0.2200, Avg Policy Loss: 0.0444, Avg VAE Loss: 0.1540, Avg Value Loss: 0.1088


 80%|███████▉  | 7706/9666 [03:30<00:52, 37.57it/s]

Batch 7700: Avg Reward: 0.2200, Avg Policy Loss: 0.0696, Avg VAE Loss: 0.1579, Avg Value Loss: 0.0666


 81%|████████  | 7806/9666 [03:33<00:53, 34.47it/s]

Batch 7800: Avg Reward: 0.2600, Avg Policy Loss: 0.0697, Avg VAE Loss: 0.1285, Avg Value Loss: 0.0556


 82%|████████▏ | 7906/9666 [03:36<00:49, 35.86it/s]

Batch 7900: Avg Reward: 0.1800, Avg Policy Loss: 0.0609, Avg VAE Loss: 0.1334, Avg Value Loss: 0.0662


 83%|████████▎ | 8006/9666 [03:39<00:47, 35.02it/s]

Batch 8000: Avg Reward: 0.3300, Avg Policy Loss: 0.0472, Avg VAE Loss: 0.1227, Avg Value Loss: 0.1274


 84%|████████▍ | 8106/9666 [03:42<00:44, 35.36it/s]

Batch 8100: Avg Reward: 0.1100, Avg Policy Loss: 0.0282, Avg VAE Loss: 0.0913, Avg Value Loss: 0.0775


 85%|████████▍ | 8206/9666 [03:44<00:39, 36.73it/s]

Batch 8200: Avg Reward: 0.2200, Avg Policy Loss: 0.0473, Avg VAE Loss: 0.1336, Avg Value Loss: 0.0599


 86%|████████▌ | 8306/9666 [03:47<00:37, 36.73it/s]

Batch 8300: Avg Reward: 0.2200, Avg Policy Loss: 0.0570, Avg VAE Loss: 0.1273, Avg Value Loss: 0.1214


 87%|████████▋ | 8406/9666 [03:50<00:40, 31.17it/s]

Batch 8400: Avg Reward: 0.2400, Avg Policy Loss: 0.0337, Avg VAE Loss: 0.1164, Avg Value Loss: 0.0710


 88%|████████▊ | 8504/9666 [03:53<00:32, 35.60it/s]

Batch 8500: Avg Reward: 0.2300, Avg Policy Loss: 0.0292, Avg VAE Loss: 0.1253, Avg Value Loss: 0.0607


 89%|████████▉ | 8604/9666 [03:56<00:30, 35.02it/s]

Batch 8600: Avg Reward: 0.1900, Avg Policy Loss: 0.1426, Avg VAE Loss: 0.1375, Avg Value Loss: 0.0895


 90%|█████████ | 8704/9666 [03:59<00:27, 34.96it/s]

Batch 8700: Avg Reward: 0.2200, Avg Policy Loss: 0.1532, Avg VAE Loss: 0.1470, Avg Value Loss: 0.0976


 91%|█████████ | 8804/9666 [04:02<00:24, 35.06it/s]

Batch 8800: Avg Reward: 0.2200, Avg Policy Loss: 0.0940, Avg VAE Loss: 0.1037, Avg Value Loss: 0.0495


 92%|█████████▏| 8904/9666 [04:05<00:22, 33.52it/s]

Batch 8900: Avg Reward: 0.2200, Avg Policy Loss: 0.0986, Avg VAE Loss: 0.1310, Avg Value Loss: 0.1776


 93%|█████████▎| 9004/9666 [04:07<00:19, 34.34it/s]

Batch 9000: Avg Reward: 0.2200, Avg Policy Loss: 0.0439, Avg VAE Loss: 0.1232, Avg Value Loss: 0.1315


 94%|█████████▍| 9104/9666 [04:10<00:15, 37.26it/s]

Batch 9100: Avg Reward: 0.2200, Avg Policy Loss: 0.1086, Avg VAE Loss: 0.1337, Avg Value Loss: 0.1893


 95%|█████████▌| 9204/9666 [04:13<00:12, 36.95it/s]

Batch 9200: Avg Reward: 0.2200, Avg Policy Loss: 0.0593, Avg VAE Loss: 0.1216, Avg Value Loss: 0.2002


 96%|█████████▋| 9304/9666 [04:16<00:10, 35.75it/s]

Batch 9300: Avg Reward: 0.2400, Avg Policy Loss: 0.0615, Avg VAE Loss: 0.1486, Avg Value Loss: 0.2497


 97%|█████████▋| 9404/9666 [04:18<00:07, 36.14it/s]

Batch 9400: Avg Reward: 0.3100, Avg Policy Loss: 0.1524, Avg VAE Loss: 0.1244, Avg Value Loss: 0.1040


 98%|█████████▊| 9504/9666 [04:21<00:04, 35.52it/s]

Batch 9500: Avg Reward: 0.2200, Avg Policy Loss: 0.0479, Avg VAE Loss: 0.1249, Avg Value Loss: 0.2229


 99%|█████████▉| 9604/9666 [04:24<00:01, 33.79it/s]

Batch 9600: Avg Reward: 0.2200, Avg Policy Loss: 0.0518, Avg VAE Loss: 0.1413, Avg Value Loss: 0.1955


100%|██████████| 9666/9666 [04:26<00:00, 36.25it/s]


In [33]:
# Save each model
torch.save(low_level_policy.state_dict(), 'low_level_policy.pth')
torch.save(goal_proposal_vae.state_dict(), 'goal_proposal_vae.pth')
torch.save(action_vae.state_dict(), 'action_vae.pth')
torch.save(value_network.state_dict(), 'value_network.pth')

print("Models saved successfully as .pth files.")

Models saved successfully as .pth files.


In [36]:
low_level_policy.load_state_dict(torch.load('low_level_policy.pth'))
goal_proposal_vae.load_state_dict(torch.load('goal_proposal_vae.pth'))
action_vae.load_state_dict(torch.load('action_vae.pth'))
value_network.load_state_dict(torch.load('value_network.pth'))

  low_level_policy.load_state_dict(torch.load('low_level_policy.pth'))
  goal_proposal_vae.load_state_dict(torch.load('goal_proposal_vae.pth'))
  action_vae.load_state_dict(torch.load('action_vae.pth'))
  value_network.load_state_dict(torch.load('value_network.pth'))


<All keys matched successfully>

In [21]:
import robosuite
from robosuite.controllers import load_controller_config

# load default controller parameters for Operational Space Control (OSC)
controller_config = load_controller_config(default_controller="OSC_POSE")
# controller_config['uncouple_pos_ori'] = False
# print(controller_config)



In [None]:

def extract_features(state):
    # Extract and concatenate the desired features from the state
    # features = torch.cat([
    #     # End-effector position (3,)
    #     torch.tensor(state['robot0_eef_pos'], dtype=torch.float32),
    #     # Object-related observation (10,)
    #     torch.tensor(state['object-state'], dtype=torch.float32),
    #     # End-effector orientation (4,)
    #     torch.tensor(state['robot0_eef_quat'], dtype=torch.float32),
    #     # Gripper position (2,)
    #     torch.tensor(state['robot0_gripper_qpos'], dtype=torch.float32),
    #     # Gripper velocity (2,)
    #     torch.tensor(state['robot0_gripper_qvel'], dtype=torch.float32),
    #     # 7DOF joint positions (7,)
    #     torch.tensor(
    #         state['robot0_joint_pos_cos'], dtype=torch.float32),
    #     # 7DOF joint velocities (7,)
    #     torch.tensor(state['robot0_joint_vel'], dtype=torch.float32)
    # ], dim=0)  # Concatenate into a single 1D tensor
    features = torch.cat([
        # End-effector position (3,)
        torch.tensor(state['robot0_eef_pos'], dtype=torch.float32),
        # Object-related observation (10,)
        torch.tensor(state['object-state'], dtype=torch.float32),
        # End-effector orientation (4,)
        torch.tensor(state['robot0_eef_quat'], dtype=torch.float32),
        # Gripper position (2,)
        # torch.tensor(state['robot0_gripper_qpos'], dtype=torch.float32),
        # Gripper velocity (2,)
        # torch.tensor(state['robot0_gripper_qvel'], dtype=torch.float32),
        # 7DOF joint positions (7,)
        # torch.tensor(
        # state['robot0_joint_pos_cos'], dtype=torch.float32),
        # torch.tensor(state['robot0_joint_pos_sin'], dtype=torch.float32),
        # 7DOF joint velocities (7,)
        # torch.tensor(state['robot0_joint_vel'], dtype=torch.float32)
    ], dim=0)  # Concatenate into a single 1D tensor

    return features

In [23]:
import numpy as np
import robosuite as suite
import cv2
# create environment instance
env = suite.make(
    env_name="Lift",  # try with other tasks like "Stack" and "Door"
    robots="Sawyer",  # try with other robots like "Sawyer" and "Jaco"
    has_renderer=True,
    has_offscreen_renderer=False,
    use_camera_obs=False,
)
env.action_dim

8

In [24]:
# # Example cosine and sine values for each joint

# joint_pos_cos = goal_state_config[20:27].detach().cpu().numpy()
# joint_pos_sin = goal_state_config[27:34].detach().cpu().numpy()
# custom_cube_position = np.array([0.5, 0.0, 0.02])  # Position in (x, y, z)
# custom_cube_orientation = np.array([1, 0, 0, 0])   # Quaternion (w, x, y, z)
# reset_with_cos_sin(env, joint_pos_cos, joint_pos_sin,
#                    custom_cube_position, custom_cube_orientation)

# # Access observables after custom reset
# observations = env._get_observations()
# # print(observations)
# # for obs_name in ['cube_pos', 'cube_quat', 'robot0_joint_pos']:
# #     print(f"{obs_name}: {observations[obs_name]}")

# # Render to visualize the environment
# while True:
#     # _, _, _, _ = env.step([0, 0, 0, 0, 0, 0, 0, 0.5])
#     env.render()

In [49]:
goal_state_config

tensor([ 0.0298,  0.0221,  0.8355,  0.0237,  0.0270,  0.8314, -0.0080, -0.0045,
         0.9694,  0.2452,  0.0061, -0.0049,  0.0041,  0.9862, -0.1468,  0.0732,
         0.0202], device='cuda:0', dtype=torch.float64)

In [None]:
# tensor([-0.0984,  0.1381,  1.0242,  0.0256, -0.0211,  0.8311,  0.0000,  0.0000,
#          0.9863, -0.1648, -0.1240,  0.1592,  0.1931,  0.9967,  0.0800,  0.0082,
#          0.0115])

tensor([-5.0452e-01,  4.3512e-01,  8.1461e-01,  4.2443e-03,  3.7202e-03,
         8.2142e-01,  3.7909e-17,  5.9400e-18,  1.3081e-01,  9.9141e-01,
        -5.0877e-01,  4.3140e-01, -6.8032e-03,  7.0655e-01, -6.4116e-01,
        -2.9877e-01, -2.0888e-02])

In [171]:

# height, width, _ = env.render().shape
# fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # Codec for MP4 video
# video_writer = cv2.VideoWriter(
# "env_policy_video_lift.mp4", fourcc, 30, (width, height))
goal = goal_state_config.to("cpu")
# goal[3:6]=torch.zeros(3)
# goal = torch.zeros_like(goal_state_config)
# goal[10:13] = torch.zeros(3)
# # goal[10]=4.1
# # goal[11]=4
# goal[12]=4
# goal[7:] = goal_req

state = extract_features(env.reset())
desired_position = np.array([-0.02, 0.222, 2])

# Find the cube's joint name in the environment
cube_name = "cube_main"
cube_joint = env.sim.model.body_name2id(cube_name)

# Set the position of the cube
env.sim.data.set_joint_qpos(
    f"cube_joint0", np.concatenate([desired_position, np.zeros(4)]))

# Ensure the simulation state is updated
env.sim.forward()
# goal = state
# print(state)
value_function = value_network.to("cpu")
goal_proposal_vae.to("cpu")
action_vae.to("cpu")
low_level_policy.to("cpu")

for step in range(1000):
    # Render the environment and capture the frame
    frame = state
    # Write frame to video
    # video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))

    goal_tensor = torch.tensor(goal, dtype=torch.float32)
    # VAE update
    state_tensor = torch.tensor(state, dtype=torch.float32)
    mu, logvar = goal_proposal_vae.encode(goal_tensor, state_tensor)
    z = goal_proposal_vae.reparameterize(mu, logvar)
    goal = goal_proposal_vae.decode(z, state_tensor)
    goal_final = goal
    # value = value_function(goal, torch.squeeze(
    #     low_level_policy(state_tensor, goal)))
    # for k in range(10):
    #     z = goal_proposal_vae.reparameterize(mu, logvar)
    #     goal = goal_proposal_vae.decode(z, state_tensor)
    #     if value_function(goal, torch.squeeze(low_level_policy(state_tensor, goal))) > value:
    #         goal_final = goal
    #         value = value_function(
    #             goal, torch.squeeze(low_level_policy(state_tensor, goal)))
    goal = goal_final
    next_state = None
    # Get the action from the low-level policy
    reward = 0
    for _ in range(1):

        state_tensor = torch.tensor(state, dtype=torch.float32)
        action = low_level_policy(
            state_tensor, goal).detach().numpy()
        # video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
        action = np.squeeze(action)
        # First joint is not activated
        next_state, reward, done, _ = env.step(
            np.concatenate((np.array([0]), action)))
        # frame = env.render(mode="rgb_array")
        env.render()
        # video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
        state = extract_features(next_state)
    #     if reward > 0:
    #         print(
    #             f"Episode {step+1}, Step {step+1}: Goal reached!")
    #         break
    #     if done:
    #         print(state)
    #         break  # Terminate the episode if done is True
    # if done:
    #     break
    # if reward > 0:
    #     break
# for i in range(10):     frame = env.render()
#     state = env.reset()
#     goal = state
#     state=extract_features(state)
#     action = np.random.randn(env.robots[0].dof)  # sample random action
#     # take action in the environment
#     obs, reward, done, info = env.step(action)
#     env.render()  # render on display

  goal_tensor = torch.tensor(goal, dtype=torch.float32)
  state_tensor = torch.tensor(state, dtype=torch.float32)
  state_tensor = torch.tensor(state, dtype=torch.float32)


KeyboardInterrupt: 

In [219]:
import cv2
import torch
import numpy as np

# Define video dimensions and initialize video writer
width, height = 640, 480
fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # Codec for MP4 video
video_writer = cv2.VideoWriter(
    "env_policy_video_lift_work3.mp4", fourcc, 30, (width, height))

goal = goal_state_config.to("cpu")
# goal[12] = 4
# goal[7:] = goal_req
goal = goal_state_config.to("cpu")
# goal[3:6]=torch.zeros(3)
# goal = torch.zeros_like(goal_state_config)
# goal[10:13] = torch.zeros(3)
# # goal[10]=4.1
# # goal[11]=4
# goal[12]=4
# goal[7:] = goal_req

state = extract_features(env.reset())
# desired_position = np.array([-0.120, 0.228, 0.8])
desired_position = np.array([-0.04, 0.24, 0.8])

# Find the cube's joint name in the environment
cube_name = "cube_main"
cube_joint = env.sim.model.body_name2id(cube_name)

# Set the position of the cube
env.sim.data.set_joint_qpos(
    f"cube_joint0", np.concatenate([desired_position, np.zeros(4)]))

# Ensure the simulation state is updated
env.sim.forward()

# Move models to CPU
value_function = value_network.to("cpu")
goal_proposal_vae.to("cpu")
action_vae.to("cpu")
low_level_policy.to("cpu")

# state = extract_features(env.reset())  # Initial state

for step in range(1000):
    # Render from camera directly in robosuite
    frame = env.sim.render(camera_name="frontview", width=width, height=height)
    frame = np.flipud(frame)  # Flip frame if needed for orientation
    if frame is not None:
        video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))

    goal_tensor = torch.tensor(goal, dtype=torch.float32)
    state_tensor = torch.tensor(state, dtype=torch.float32)
    env.render()
    # VAE operations
    mu, logvar = goal_proposal_vae.encode(goal_tensor, state_tensor)
    z = goal_proposal_vae.reparameterize(mu, logvar)
    goal = goal_proposal_vae.decode(z, state_tensor)
    goal_final = goal
    value = value_function(goal, torch.squeeze(
        low_level_policy(state_tensor, goal)))

    # for k in range(10):
    #     z = goal_proposal_vae.reparameterize(mu, logvar)
    #     goal = goal_proposal_vae.decode(z, state_tensor)
    #     current_value = value_function(goal, torch.squeeze(
    #         low_level_policy(state_tensor, goal)))
    #     if current_value > value:
    #         goal_final = goal
    #         value = current_value

    goal = goal_final  # Set optimized goal
    next_state = None

    # Execute actions
    for _ in range(3):
        state_tensor = torch.tensor(state, dtype=torch.float32)
        action = low_level_policy(state_tensor, goal).detach().numpy()
        action = np.squeeze(action)

        # Apply action and capture new frame
        next_state, reward, done, _ = env.step(
            np.concatenate((np.array([0]), action)))
        frame = env.sim.render(camera_name="frontview",
                               width=width, height=height)
        frame = np.flipud(frame)
        if frame is not None:
            video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
        state = extract_features(next_state)

    #     if reward > 0:
    #         print(f"Episode {step + 1}, Step {step + 1}: Goal reached!")
    #         break
        if done:
            break

    if done:
        break

# Release video writer
video_writer.release()
print("Video saved as env_policy_video_lift.mp4")

  goal_tensor = torch.tensor(goal, dtype=torch.float32)
  state_tensor = torch.tensor(state, dtype=torch.float32)
  state_tensor = torch.tensor(state, dtype=torch.float32)


Video saved as env_policy_video_lift.mp4
