In [32]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Normal
import numpy as np

In [33]:
class LowLevelPolicy(nn.Module):
    def __init__(self, state_dim=4, goal_dim=4, action_dim=2, hidden_dim=128):
        super(LowLevelPolicy, self).__init__()
        self.rnn = nn.LSTM(state_dim + goal_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, action_dim)

    def forward(self, state, goal):
        input_seq = torch.cat((state, goal), dim=-1)
        out, _ = self.rnn(torch.unsqueeze(input_seq, 0))
        actions = self.fc(out)
        return actions

In [34]:
class GoalProposalVAE(nn.Module):
    def __init__(self, state_dim=4, goal_dim=4, latent_dim=20):
        super(GoalProposalVAE, self).__init__()
        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(state_dim + goal_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU()
        )
        self.fc_mu = nn.Linear(128, latent_dim)
        self.fc_logvar = nn.Linear(128, latent_dim)
        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim + state_dim, 128),
            nn.ReLU(),
            nn.Linear(128, goal_dim)
        )

    def encode(self, x, c):
        h = self.encoder(torch.cat((x, c), dim=0))
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, c):  # P(x|z,c)
        inputs = torch.cat([z, c], 0)
        h3 = self.decoder(inputs)
        return h3

    def forward(self, x, c):
        mu, logvar = self.encode(x, c)
        z = self.reparameterize(mu, logvar)
        decoded = self.decode(z, c)
        return decoded, mu, logvar

    # def sample(self, num_samples, y):
    #     with torch.no_grad():
    #         z = torch.randn(num_samples, self.num_hidden)
    #         samples = self.decoder(self.condition_on_label(z, y))
    #     return samples

In [35]:
class ValueNetwork(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(ValueNetwork, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(state_dim + action_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, state, action):
        return self.fc(torch.cat((state, action), dim=-1))

In [36]:
def cvaeLoss(sg, D, mu, logvar, beta=0.0001):
    recon_loss = torch.nn.functional.mse_loss(D, sg)
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + beta * kl_loss

In [37]:
class ActionProposalVAE(nn.Module):
    def __init__(self, state_dim=4, action_dim=2, latent_dim=20):
        super(ActionProposalVAE, self).__init__()
        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(state_dim + action_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU()
        )
        self.fc_mu = nn.Linear(128, latent_dim)
        self.fc_logvar = nn.Linear(128, latent_dim)
        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim + state_dim, 128),
            nn.ReLU(),
            nn.Linear(128, action_dim)
        )
        # self.label_projector = nn.Sequential(
        #     nn.Linear(state_dim, latent_dim), nn.ReLU())

    def encode(self, x, c):
        h = self.encoder(torch.cat((x, c), dim=0))
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, c):  # P(x|z,c)
        inputs = torch.cat([z, c], 0)
        h3 = self.decoder(inputs)
        return h3

    def forward(self, x, c):
        mu, logvar = self.encode(x, c)
        z = self.reparameterize(mu, logvar)
        decoded = self.decode(z, c)
        return decoded, mu, logvar

In [38]:
import gym
import d4rl
env = gym.make("maze2d-medium-v1")
dataset = env.get_dataset()
print(dataset['observations'].shape)

Downloading dataset: http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-medium-sparse-v1.hdf5 to /home/keerthi/.d4rl/datasets/maze2d-medium-sparse-v1.hdf5


load datafile: 100%|██████████| 8/8 [00:00<00:00, 15.78it/s]

(2000000, 4)





In [39]:
import numpy as np


def split_into_trajectories(dataset):
    observations = dataset['observations']
    actions = dataset['actions']
    rewards = dataset['rewards']
    # 'dones' in some datasets are called 'terminals'
    dones = dataset['terminals']

    trajectories = []
    current_trajectory = {
        'observations': [],
        'actions': [],
        'rewards': []
    }

    for i in range(len(observations)):
        # Append the current timestep's data to the current trajectory
        current_trajectory['observations'].append(observations[i])
        current_trajectory['actions'].append(actions[i])
        current_trajectory['rewards'].append(rewards[i])

        # If the 'done' flag is True, the current trajectory ends
        if rewards[i] == 1:
            # Convert lists to numpy arrays
            current_trajectory['observations'] = np.array(
                current_trajectory['observations'])
            current_trajectory['actions'] = np.array(
                current_trajectory['actions'])
            current_trajectory['rewards'] = np.array(
                current_trajectory['rewards'])

            # Add the current trajectory to the list of trajectories
            trajectories.append(current_trajectory)

            # Reset the current trajectory
            current_trajectory = {
                'observations': [],
                'actions': [],
                'rewards': []
            }

    return trajectories


# Split the dataset into a list of trajectories
trajectories = split_into_trajectories(dataset)

In [40]:
len(trajectories)

47333

In [None]:
import torch
import numpy as np
import torch.optim as optim
import torch.nn as nn
from tqdm import tqdm
'''
Run this training loop for full trajecory from dataset'''


def train_IRIS_full_trajectory(low_level_policy, goal_proposal_vae: GoalProposalVAE, action_vae: ActionProposalVAE, value_network, trajectories, trajectory_length=5):
    # Move models to GPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    low_level_policy = low_level_policy.to(device)
    goal_proposal_vae = goal_proposal_vae.to(device)
    action_vae = action_vae.to(device)
    value_network = value_network.to(device)

    policy_optimizer = optim.Adam(low_level_policy.parameters(), lr=0.0001)
    vae_optimizer = optim.Adam(goal_proposal_vae.parameters(), lr=0.001)
    value_optimizer = optim.Adam(value_network.parameters(), lr=0.01)
    action_optimizer = optim.Adam(action_vae.parameters(), lr=0.001)
    M = 30
    gamma = 0.99
    num_trajectories = len(dataset['observations']) // trajectory_length

    # Variables to store cumulative statistics
    cumulative_rewards = []
    cumulative_policy_loss = []
    cumulative_vae_loss = []
    cumulative_value_loss = []
    iteration = 0
    for trajectory in tqdm(trajectories):
        # Move data to GPU
        iteration += 1
        states = torch.tensor(
            trajectory['observations'], dtype=torch.float32).to(device)
        actions = torch.tensor(
            trajectory['actions'], dtype=torch.float32).to(device)
        rewards = torch.tensor(
            trajectory['rewards'], dtype=torch.float32).to(device)
        if len(rewards) == 1:
            continue
        if len(actions) <= 1:
            continue
        actions = actions[:-1]
        sg = states[-1]
        s_start = states[0]
        reward_sg = rewards[-2]
        actionlast = actions[-2]
        statesecondlast = states[-2]

        # Train Low-Level Policy
        policy_actions = []
        for state in states:
            policy_actions.append(low_level_policy(state, sg))
        policy_actions = policy_actions[:-1]
        policy_actions = torch.stack(policy_actions)
        policy_actions = torch.squeeze(policy_actions)
        policy_loss = nn.MSELoss()(policy_actions, actions)
        policy_optimizer.zero_grad()
        policy_loss.backward()
        policy_optimizer.step()

        # VAE update
        mu, logvar = goal_proposal_vae.encode(sg, s_start)
        z = goal_proposal_vae.reparameterize(mu, logvar)
        vae_loss = cvaeLoss(
            sg, goal_proposal_vae.decode(z, s_start), mu, logvar)

        mua, logvara = action_vae.encode(actionlast, statesecondlast)
        za = action_vae.reparameterize(mua, logvara)
        actionvae_loss = cvaeLoss(actionlast, action_vae.decode(
            za, statesecondlast), mua, logvara)
        action_optimizer.zero_grad()
        actionvae_loss.backward()
        action_optimizer.step()

        # Perform sampling and value update
        sampled_actions = []
        for _ in range(M):
            sampled_action = action_vae.decode(za, sg)
            sampled_actions.append(sampled_action)
        sampled_actions = torch.stack(sampled_actions)

        values = []
        for action in sampled_actions:
            value = value_network(sg, action)
            values.append(value)
        values = torch.stack(values)
        max_value = torch.max(values)
        Vbar = reward_sg + gamma * max_value.detach()
        Vbar = Vbar.unsqueeze(0)
        value_loss = nn.MSELoss()(Vbar, value_network(statesecondlast, actionlast))

        # Update optimizers
        vae_optimizer.zero_grad()
        vae_loss.backward()
        vae_optimizer.step()

        value_optimizer.zero_grad()
        value_loss.backward()
        value_optimizer.step()

        # Store losses and reward
        cumulative_rewards.append(reward_sg.item())
        cumulative_policy_loss.append(policy_loss.item())
        cumulative_vae_loss.append(vae_loss.item())
        cumulative_value_loss.append(value_loss.item())

        # Print averages every 1000 iterations
        if iteration % 1000 == 0 and iteration > 0:
            avg_reward = np.mean(cumulative_rewards)
            avg_policy_loss = np.mean(cumulative_policy_loss)
            avg_vae_loss = np.mean(cumulative_vae_loss)
            avg_value_loss = np.mean(cumulative_value_loss)

            print(f"Iteration {iteration}: Avg Reward: {avg_reward:.4f}, "
                  f"Avg Policy Loss: {avg_policy_loss:.4f}, "
                  f"Avg VAE Loss: {avg_vae_loss:.4f}, "
                  f"Avg Value Loss: {avg_value_loss:.4f}")

            # Reset cumulative statistics
            cumulative_rewards = []
            cumulative_policy_loss = []
            cumulative_vae_loss = []
            cumulative_value_loss = []


# Assuming dataset and models are already initialized
state_dim = dataset['observations'].shape[1]
state_goal_dim = dataset['observations'].shape[1]
action_dim = dataset['actions'].shape[1]
latent_dim = 8

low_level_policy = LowLevelPolicy(state_dim, state_goal_dim, action_dim)
goal_proposal_vae = GoalProposalVAE(state_dim, state_goal_dim, latent_dim)
action_vae = ActionProposalVAE(state_dim, action_dim, latent_dim)
value_network = ValueNetwork(state_dim, action_dim)

# Train the IRIS algorithm using the D4RL dataset
train_IRIS_full_trajectory(low_level_policy, goal_proposal_vae,
                           action_vae, value_network, trajectories)

  0%|          | 0/47333 [00:00<?, ?it/s]

100%|██████████| 47333/47333 [24:54<00:00, 31.67it/s]


In [42]:
# import torch
# import numpy as np
# import torch.optim as optim
# import torch.nn as nn
# import matplotlib.pyplot as plt  # Importing for plotting

# '''
# Run this training loop for full trajectory from dataset
# '''


# def train_IRIS_full_trajectory(low_level_policy, goal_proposal_vae: GoalProposalVAE, action_vae: ActionProposalVAE, value_network, trajectories, trajectory_length=5):
#     # Move models to GPU
#     max_steps = 200
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     low_level_policy = low_level_policy.to(device)
#     goal_proposal_vae = goal_proposal_vae.to(device)
#     action_vae = action_vae.to(device)
#     value_network = value_network.to(device)

#     policy_optimizer = optim.Adam(low_level_policy.parameters(), lr=0.0001)
#     vae_optimizer = optim.Adam(goal_proposal_vae.parameters(), lr=0.001)
#     value_optimizer = optim.Adam(value_network.parameters(), lr=0.01)
#     action_optimizer = optim.Adam(action_vae.parameters(), lr=0.001)
#     M = 30
#     gamma = 0.99
#     num_trajectories = len(dataset['observations']) // trajectory_length

#     # Variables to store cumulative statistics
#     cumulative_rewards = []
#     cumulative_policy_loss = []
#     cumulative_vae_loss = []
#     cumulative_value_loss = []
#     iteration = 0

#     # For plotting
#     rewards_per_iteration = []  # Store average rewards over iterations
#     validation_reward = []
#     for trajectory in trajectories:
#         # Move data to GPU
#         iteration += 1
#         states = torch.tensor(
#             trajectory['observations'], dtype=torch.float32).to(device)
#         actions = torch.tensor(
#             trajectory['actions'], dtype=torch.float32).to(device)
#         rewards = torch.tensor(
#             trajectory['rewards'], dtype=torch.float32).to(device)
#         if len(rewards) == 1:
#             continue
#         actions = actions[:-1]
#         sg = states[-1]
#         s_start = states[0]
#         reward_sg = rewards[-2]
#         actionlast = actions[-2]
#         statesecondlast = states[-2]

#         # Train Low-Level Policy
#         policy_actions = []
#         for state in states:
#             policy_actions.append(low_level_policy(state, sg))
#         policy_actions = policy_actions[:-1]
#         policy_actions = torch.stack(policy_actions)
#         policy_actions = torch.squeeze(policy_actions)
#         policy_loss = nn.MSELoss()(policy_actions, actions)
#         policy_optimizer.zero_grad()
#         policy_loss.backward()
#         policy_optimizer.step()

#         # VAE update
#         mu, logvar = goal_proposal_vae.encode(sg, s_start)
#         z = goal_proposal_vae.reparameterize(mu, logvar)
#         vae_loss = cvaeLoss(
#             sg, goal_proposal_vae.decode(z, s_start), mu, logvar)

#         mua, logvara = action_vae.encode(actionlast, statesecondlast)
#         za = action_vae.reparameterize(mua, logvara)
#         actionvae_loss = cvaeLoss(actionlast, action_vae.decode(
#             za, statesecondlast), mua, logvara)
#         action_optimizer.zero_grad()
#         actionvae_loss.backward()
#         action_optimizer.step()

#         # Perform sampling and value update
#         sampled_actions = []
#         for _ in range(M):
#             sampled_action = action_vae.decode(za, sg)
#             sampled_actions.append(sampled_action)
#         sampled_actions = torch.stack(sampled_actions)

#         values = []
#         for action in sampled_actions:
#             value = value_network(sg, action)
#             values.append(value)
#         values = torch.stack(values)
#         max_value = torch.max(values)
#         Vbar = reward_sg + gamma * max_value.detach()
#         Vbar = Vbar.unsqueeze(0)
#         value_loss = nn.MSELoss()(Vbar, value_network(statesecondlast, actionlast))

#         # Update optimizers
#         vae_optimizer.zero_grad()
#         vae_loss.backward()
#         vae_optimizer.step()

#         value_optimizer.zero_grad()
#         value_loss.backward()
#         value_optimizer.step()

#         # Store losses and reward
#         cumulative_rewards.append(reward_sg.item())
#         cumulative_policy_loss.append(policy_loss.item())
#         cumulative_vae_loss.append(vae_loss.item())
#         cumulative_value_loss.append(value_loss.item())
#         state = env.reset()
#         goal = state.copy()  # Assuming the goal is part of the observation for simplicity

#         for step in range(max_steps):

#             goal_tensor = torch.tensor(goal, dtype=torch.float32)
#             state_tensor = torch.tensor(state, dtype=torch.float32)
#             mu, logvar = goal_proposal_vae.encode(goal_tensor, state_tensor)
#             z = goal_proposal_vae.reparameterize(mu, logvar)
#             goal = goal_proposal_vae.decode(z, state_tensor)

#             next_state = None
#             # Get the action from the low-level policy
#             reward = 0
#             for _ in range(3):
#                 state_tensor = torch.tensor(state, dtype=torch.float32)
#                 action = low_level_policy(
#                     state_tensor, goal).detach().numpy()
#                 frame = env.render(mode="rgb_array")
#                 action = np.squeeze(action)
#                 next_state, reward, done, _ = env.step(action)
#                 frame = env.render(mode="rgb_array")
#                 state = next_state
#                 if reward > 0:
#                     break
#                 if done:
#                     break  # Terminate the episode if done is True
#             if done:
#                 validation_reward.append([iter,0])
#                 break
#             if reward > 0:
#                 accuracy += 1
#                 validation_reward.append([iter,reward])
#                 break
#         # Print averages every 1000 iterations
#         if iteration % 1000 == 0 and iteration > 0:
#             avg_reward = np.mean(cumulative_rewards)
#             avg_policy_loss = np.mean(cumulative_policy_loss)
#             avg_vae_loss = np.mean(cumulative_vae_loss)
#             avg_value_loss = np.mean(cumulative_value_loss)

#             print(f"Iteration {iteration}: Avg Reward: {avg_reward:.4f}, "
#                   f"Avg Policy Loss: {avg_policy_loss:.4f}, "
#                   f"Avg VAE Loss: {avg_vae_loss:.4f}, "
#                   f"Avg Value Loss: {avg_value_loss:.4f}")

#             # Reset cumulative statistics
#             cumulative_rewards = []
#             cumulative_policy_loss = []
#             cumulative_vae_loss = []
#             cumulative_value_loss = []

#         # Store average reward after each iteration for plotting
#         rewards_per_iteration.append(np.mean(cumulative_rewards))

#     return rewards_per_iteration


# def plot_rewards(rewards_per_iteration):
#     plt.figure(figsize=(10, 5))
#     plt.plot(rewards_per_iteration, label="Mean Reward per Iteration")
#     plt.xlabel("Iteration")
#     plt.ylabel("Mean Reward")
#     plt.title("Mean Reward During Training")
#     plt.legend()
#     plt.show()


# # Assuming dataset and models are already initialized
# state_dim = dataset['observations'].shape[1]
# state_goal_dim = dataset['observations'].shape[1]
# action_dim = dataset['actions'].shape[1]
# latent_dim = 8

# low_level_policy = LowLevelPolicy(state_dim, state_goal_dim, action_dim)
# goal_proposal_vae = GoalProposalVAE(state_dim, state_goal_dim, latent_dim)
# action_vae = ActionProposalVAE(state_dim, action_dim, latent_dim)
# value_network = ValueNetwork(state_dim, action_dim)

# # Train the IRIS algorithm using the D4RL dataset
# rewards_per_iteration = train_IRIS_full_trajectory(
#     low_level_policy, goal_proposal_vae, action_vae, value_network, trajectories)

# # Plot the mean reward during training
# plot_rewards(rewards_per_iteration)

In [43]:
# import torch
# import numpy as np
# import torch.optim as optim
# import torch.nn as nn
# import random


# def train_IRIS_traj_sample(low_level_policy, goal_proposal_vae: GoalProposalVAE, action_vae: ActionProposalVAE, value_network, dataset, num_iterations=20000, trajectory_length=20):
#     # Move models to GPU
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     low_level_policy = low_level_policy.to(device)
#     goal_proposal_vae = goal_proposal_vae.to(device)
#     action_vae = action_vae.to(device)
#     value_network = value_network.to(device)

#     policy_optimizer = optim.Adam(low_level_policy.parameters(), lr=0.0001)
#     vae_optimizer = optim.Adam(goal_proposal_vae.parameters(), lr=0.001)
#     value_optimizer = optim.Adam(value_network.parameters(), lr=0.01)
#     action_optimizer = optim.Adam(action_vae.parameters(), lr=0.001)
#     M = 30
#     gamma = 0.99
#     num_trajectories = len(dataset['observations']) // trajectory_length

#     # Variables to store cumulative statistics
#     cumulative_rewards = []
#     cumulative_policy_loss = []
#     cumulative_vae_loss = []
#     cumulative_value_loss = []

#     for iteration in range(num_iterations):
#         trajectory_idx = random.choice([0, 100000-30])
#         start_idx = trajectory_idx * trajectory_length
#         end_idx = start_idx + trajectory_length

#         # Move data to GPU
#         states = torch.tensor(
#             dataset['observations'][start_idx:end_idx], dtype=torch.float32).to(device)
#         actions = torch.tensor(
#             dataset['actions'][start_idx:end_idx], dtype=torch.float32).to(device)
#         rewards = torch.tensor(
#             dataset['rewards'][start_idx:end_idx], dtype=torch.float32).to(device)
#         if len(states) <= 1:
#             continue
#         actions = actions[:-1]
#         sg = states[-1]
#         s_start = states[0]
#         reward_sg = rewards[-2]
#         actionlast = actions[-2]
#         statesecondlast = states[-2]

#         # Train Low-Level Policy
#         policy_actions = []
#         for state in states:
#             policy_actions.append(low_level_policy(state, sg))
#         policy_actions = policy_actions[:-1]
#         policy_actions = torch.stack(policy_actions)
#         policy_actions = torch.squeeze(policy_actions)
#         policy_loss = nn.MSELoss()(policy_actions, actions)
#         policy_optimizer.zero_grad()
#         policy_loss.backward()
#         policy_optimizer.step()

#         # VAE update
#         mu, logvar = goal_proposal_vae.encode(sg, s_start)
#         z = goal_proposal_vae.reparameterize(mu, logvar)
#         vae_loss = cvaeLoss(
#             sg, goal_proposal_vae.decode(z, s_start), mu, logvar)

#         mua, logvara = action_vae.encode(actionlast, statesecondlast)
#         za = action_vae.reparameterize(mua, logvara)
#         actionvae_loss = cvaeLoss(actionlast, action_vae.decode(
#             za, statesecondlast), mua, logvara)
#         action_optimizer.zero_grad()
#         actionvae_loss.backward()
#         action_optimizer.step()

#         # Perform sampling and value update
#         sampled_actions = []
#         for _ in range(M):
#             sampled_action = action_vae.decode(za, sg)
#             sampled_actions.append(sampled_action)
#         sampled_actions = torch.stack(sampled_actions)

#         values = []
#         for action in sampled_actions:
#             value = value_network(sg, action)
#             values.append(value)
#         values = torch.stack(values)
#         max_value = torch.max(values)
#         Vbar = reward_sg + gamma * max_value.detach()
#         Vbar = Vbar.unsqueeze(0)
#         value_loss = nn.MSELoss()(Vbar, value_network(statesecondlast, actionlast))

#         # Update optimizers
#         vae_optimizer.zero_grad()
#         vae_loss.backward()
#         vae_optimizer.step()

#         value_optimizer.zero_grad()
#         value_loss.backward()
#         value_optimizer.step()

#         # Store losses and reward
#         cumulative_rewards.append(reward_sg.item())
#         cumulative_policy_loss.append(policy_loss.item())
#         cumulative_vae_loss.append(vae_loss.item())
#         cumulative_value_loss.append(value_loss.item())

#         # Print averages every 1000 iterations
#         if iteration % 1000 == 0 and iteration > 0:
#             avg_reward = np.mean(cumulative_rewards)
#             avg_policy_loss = np.mean(cumulative_policy_loss)
#             avg_vae_loss = np.mean(cumulative_vae_loss)
#             avg_value_loss = np.mean(cumulative_value_loss)

#             print(f"Iteration {iteration}: Avg Reward: {avg_reward:.4f}, "
#                   f"Avg Policy Loss: {avg_policy_loss:.4f}, "
#                   f"Avg VAE Loss: {avg_vae_loss:.4f}, "
#                   f"Avg Value Loss: {avg_value_loss:.4f}")

#             # Reset cumulative statistics
#             cumulative_rewards = []
#             cumulative_policy_loss = []
#             cumulative_vae_loss = []
#             cumulative_value_loss = []


# # Assuming dataset and models are already initialized
# state_dim = dataset['observations'].shape[1]
# state_goal_dim = dataset['observations'].shape[1]
# action_dim = dataset['actions'].shape[1]
# latent_dim = 20

# low_level_policy = LowLevelPolicy(state_dim, state_goal_dim, action_dim)
# goal_proposal_vae = GoalProposalVAE(state_dim, state_goal_dim, latent_dim)
# action_vae = ActionProposalVAE(state_dim, action_dim, latent_dim)
# value_network = ValueNetwork(state_dim, action_dim)

# # Train the IRIS algorithm using the D4RL dataset
# train_IRIS_traj_sample(low_level_policy, goal_proposal_vae,
#                        action_vae, value_network, dataset)

In [48]:
import cv2
import numpy as np
import torch


def visualize_policy_as_video(low_level_policy, goal_proposal_vae, value_function, env, num_episodes=20, max_steps=10000, save_path="env_policy_video_5.mp4"):
    # Define video writer using OpenCV
    height, width, _ = env.render(mode="rgb_array").shape
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # Codec for MP4 video
    video_writer = cv2.VideoWriter(save_path, fourcc, 30, (width, height))
    accuracy = 0
    for episode in range(num_episodes):
        state = env.reset()
        goal = state.copy()  # Assuming the goal is part of the observation for simplicity

        for step in range(max_steps):
            # Render the environment and capture the frame
            frame = env.render(mode="rgb_array")
            # Write frame to video
            video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))

            goal_tensor = torch.tensor(goal, dtype=torch.float32)
            # VAE update
            state_tensor = torch.tensor(state, dtype=torch.float32)
            mu, logvar = goal_proposal_vae.encode(goal_tensor, state_tensor)
            z = goal_proposal_vae.reparameterize(mu, logvar)
            goal = goal_proposal_vae.decode(z, state_tensor)
            goal_final = goal
            value = value_function(goal, torch.squeeze(
                low_level_policy(state_tensor, goal)))
            for k in range(5):
                z = goal_proposal_vae.reparameterize(mu, logvar)
                goal = goal_proposal_vae.decode(z, state_tensor)
                if value_function(goal, torch.squeeze(low_level_policy(state_tensor, goal))) > value:
                    goal_final = goal
                    value = value_function(
                        goal, torch.squeeze(low_level_policy(state_tensor, goal)))
            goal = goal_final
            next_state = None
            # Get the action from the low-level policy
            reward = 0
            for _ in range(2):
                state_tensor = torch.tensor(state, dtype=torch.float32)
                action = low_level_policy(
                    state_tensor, goal).detach().numpy()
                frame = env.render(mode="rgb_array")
                video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
                action = np.squeeze(action)
                next_state, reward, done, _ = env.step(action)
                frame = env.render(mode="rgb_array")
                video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
                state = next_state
                if reward > 0:
                    print(
                        f"Episode {episode+1}, Step {step+1}: Goal reached!")
                    break
                if done:
                    break  # Terminate the episode if done is True
            if done:
                break
            if reward > 0:
                accuracy += 1
                break
    # Release the video writer after finishing
    video_writer.release()
    print(f"Video saved to {save_path}")
    print(f"Accuracy is {accuracy/num_episodes*100.0:.4f}")


# Visualize the learned policy as a video
visualize_policy_as_video(low_level_policy.to(
    "cpu"), goal_proposal_vae.to("cpu"), value_network.to("cpu"), env)

  goal_tensor = torch.tensor(goal, dtype=torch.float32)


Episode 2, Step 196: Goal reached!
Episode 3, Step 110: Goal reached!
Episode 4, Step 106: Goal reached!
Episode 5, Step 47: Goal reached!
Episode 6, Step 18: Goal reached!
Episode 7, Step 167: Goal reached!
Episode 8, Step 135: Goal reached!
Episode 9, Step 183: Goal reached!
Episode 10, Step 235: Goal reached!
Episode 11, Step 53: Goal reached!
Episode 12, Step 19: Goal reached!
Episode 13, Step 55: Goal reached!
Episode 14, Step 69: Goal reached!
Episode 15, Step 107: Goal reached!
Episode 16, Step 159: Goal reached!
Episode 17, Step 167: Goal reached!
Episode 18, Step 1: Goal reached!
Episode 20, Step 204: Goal reached!
Video saved to env_policy_video_5.mp4
Accuracy is 90.0000


In [45]:
torch.save(low_level_policy.state_dict(),
           "/home/keerthi/IRIS/low_level_policy_full_traj.pth")
torch.save(goal_proposal_vae.state_dict(),
           "/home/keerthi/IRIS/goal_proposal_cvae_full_traj.pth")