In [276]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Normal
import numpy as np

In [277]:
class LowLevelPolicy(nn.Module):
    def __init__(self, state_dim=4, goal_dim=4, action_dim=2, hidden_dim=128):
        super(LowLevelPolicy, self).__init__()
        self.rnn = nn.LSTM(state_dim + goal_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, action_dim)

    def forward(self, state, goal):
        input_seq = torch.cat((state, goal), dim=-1)
        out, _ = self.rnn(torch.unsqueeze(input_seq, 0))
        actions = self.fc(out)
        return actions

In [278]:
class GoalProposalVAE(nn.Module):
    def __init__(self, state_dim=4, goal_dim=4, latent_dim=20):
        super(GoalProposalVAE, self).__init__()
        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(state_dim + goal_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU()
        )
        self.fc_mu = nn.Linear(128, latent_dim)
        self.fc_logvar = nn.Linear(128, latent_dim)
        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim + state_dim, 128),
            nn.ReLU(),
            nn.Linear(128, goal_dim)
        )

    def encode(self, x, c):
        h = self.encoder(torch.cat((x, c), dim=0))
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, c):  # P(x|z,c)
        inputs = torch.cat([z, c], 0)
        h3 = self.decoder(inputs)
        return h3

    def forward(self, x, c):
        mu, logvar = self.encode(x, c)
        z = self.reparameterize(mu, logvar)
        decoded = self.decode(z, c)
        return decoded, mu, logvar

    # def sample(self, num_samples, y):
    #     with torch.no_grad():
    #         z = torch.randn(num_samples, self.num_hidden)
    #         samples = self.decoder(self.condition_on_label(z, y))
    #     return samples

In [279]:
class ValueNetwork(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(ValueNetwork, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(state_dim + action_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, state, action):
        return self.fc(torch.cat((state, action), dim=-1))

In [280]:
def cvaeLoss(sg, D, mu, logvar, beta=0.0001):
    recon_loss = torch.nn.functional.mse_loss(D, sg)
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + beta * kl_loss

In [281]:
class ActionProposalVAE(nn.Module):
    def __init__(self, state_dim=4, action_dim=2, latent_dim=20):
        super(ActionProposalVAE, self).__init__()
        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(state_dim + action_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU()
        )
        self.fc_mu = nn.Linear(128, latent_dim)
        self.fc_logvar = nn.Linear(128, latent_dim)
        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim + state_dim, 128),
            nn.ReLU(),
            nn.Linear(128, action_dim)
        )
        # self.label_projector = nn.Sequential(
        #     nn.Linear(state_dim, latent_dim), nn.ReLU())

    def encode(self, x, c):
        h = self.encoder(torch.cat((x, c), dim=0))
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, c):  # P(x|z,c)
        inputs = torch.cat([z, c], 0)
        h3 = self.decoder(inputs)
        return h3

    def forward(self, x, c):
        mu, logvar = self.encode(x, c)
        z = self.reparameterize(mu, logvar)
        decoded = self.decode(z, c)
        return decoded, mu, logvar

In [282]:
import gym
import d4rl
env = gym.make("maze2d-umaze-v1")
dataset = env.get_dataset()
print(dataset['observations'].shape)

load datafile: 100%|██████████| 8/8 [00:00<00:00, 32.44it/s]

(1000000, 4)





In [283]:
# def train_IRIS(low_level_policy, goal_proposal_vae: GoalProposalVAE, action_vae: ActionProposalVAE, value_network, dataset, num_iterations=100000, trajectory_length=7):
#     policy_optimizer = optim.Adam(low_level_policy.parameters(), lr=0.001)
#     vae_optimizer = optim.Adam(goal_proposal_vae.parameters(), lr=0.001)
#     value_optimizer = optim.Adam(value_network.parameters(), lr=0.001)
#     action_optimizer = optim.Adam(action_vae.parameters(), lr=0.001)
#     M = 10
#     gamma = 0.99
#     # Assume observations are organized by trajectories
#     num_trajectories = len(dataset['observations']) // trajectory_length

#     for iteration in range(num_iterations):
#         # Sample a trajectory index (rather than random individual steps)
#         trajectory_idx = np.random.choice(num_trajectories)

#         # Define start and end indices for the trajectory
#         start_idx = trajectory_idx * trajectory_length
#         end_idx = start_idx + trajectory_length

#         # Sample the entire trajectory (states, actions, goals, rewards)
#         states = torch.tensor(
#             dataset['observations'][start_idx:end_idx], dtype=torch.float32)
#         actions = torch.tensor(
#             dataset['actions'][start_idx:end_idx], dtype=torch.float32)
#         goals = torch.tensor(dataset['infos/goal']
#                              [start_idx:end_idx], dtype=torch.float32)
#         rewards = torch.tensor(
#             dataset['rewards'][start_idx:end_idx], dtype=torch.float32)
#         actions = actions[:-1]
#         sg = states[-1]
#         s_start = states[0]
#         reward_sg = rewards[-2]
#         actionlast = actions[-2]
#         statesecondlast = states[-2]

#         # Train Low-Level Policy
#         policy_actions = []
#         for state in states:
#             policy_actions.append(low_level_policy(state, sg))
#         policy_actions = policy_actions[:-1]
#         policy_actions = torch.stack(policy_actions)
#         policy_actions = torch.squeeze(policy_actions)
#         policy_loss = nn.MSELoss()(policy_actions, actions)
#         policy_optimizer.zero_grad()
#         policy_loss.backward()
#         policy_optimizer.step()

#         # VAE update
#         mu, logvar = goal_proposal_vae.encode(sg, s_start)
#         z = goal_proposal_vae.reparameterize(mu, logvar)
#         vae_loss = cvaeLoss(
#             sg, goal_proposal_vae.decode(z, s_start), mu, logvar)

#         mua, logvara = action_vae.encode(actionlast, statesecondlast)
#         za = action_vae.reparameterize(mua, logvara)
#         actionvae_loss = cvaeLoss(actionlast, action_vae.decode(
#             za, statesecondlast), mua, logvara)
#         action_optimizer.zero_grad()
#         actionvae_loss.backward()
#         action_optimizer.step()
#         # Perform the sampling operation
#         sampled_actions = []
#         for _ in range(M):
#             sampled_action = action_vae.decode(za, sg)
#             sampled_actions.append(sampled_action)

#         sampled_actions = torch.stack(sampled_actions)
#         values = []
#         for action in sampled_actions:
#             value = value_network(sg, action)
#             values.append(value)
#         values = torch.stack(values)
#         max_value = torch.max(values)
#         Vbar = reward_sg+gamma*max_value.detach()
#         value_loss = nn.MSELoss()(Vbar, value_network(statesecondlast, actionlast))
#         vae_optimizer.zero_grad()
#         vae_loss.backward()
#         vae_optimizer.step()

#         value_optimizer.zero_grad()
#         value_loss.backward()
#         value_optimizer.step()

#         if iteration % 1000 == 0:
#             print(
#                 f"Iteration {iteration}: VAE Loss: {vae_loss.item():.4f}, Policy Loss: {policy_loss.item():.4f}, Value Loss: {value_loss.item():.4f}")


# state_dim = dataset['observations'].shape[1]
# state_goal_dim = dataset['observations'].shape[1]
# action_dim = dataset['actions'].shape[1]
# latent_dim = 20

# low_level_policy = LowLevelPolicy(state_dim, state_goal_dim, action_dim)
# goal_proposal_vae = GoalProposalVAE(state_dim, state_goal_dim, latent_dim)
# action_vae = ActionProposalVAE(state_dim, action_dim, latent_dim)
# value_network = ValueNetwork(state_dim, action_dim)

# # Train the IRIS algorithm using the D4RL dataset
# train_IRIS(low_level_policy, goal_proposal_vae,
#            action_vae, value_network, dataset)

In [284]:
# import torch
# import numpy as np
# import torch.optim as optim
# import torch.nn as nn


# def train_IRIS(low_level_policy, goal_proposal_vae: GoalProposalVAE, action_vae: ActionProposalVAE, value_network, dataset, num_iterations=20000, trajectory_length=50):
#     policy_optimizer = optim.Adam(low_level_policy.parameters(), lr=0.001)
#     vae_optimizer = optim.Adam(goal_proposal_vae.parameters(), lr=0.001)
#     value_optimizer = optim.Adam(value_network.parameters(), lr=0.001)
#     action_optimizer = optim.Adam(action_vae.parameters(), lr=0.001)
#     M = 10
#     gamma = 0.99
#     num_trajectories = len(dataset['observations']) // trajectory_length

#     # Variables to store cumulative statistics
#     cumulative_rewards = []
#     cumulative_policy_loss = []
#     cumulative_vae_loss = []
#     cumulative_value_loss = []

#     for iteration in range(num_iterations):
#         trajectory_idx = np.random.choice(num_trajectories)
#         start_idx = trajectory_idx * trajectory_length
#         end_idx = start_idx + trajectory_length

#         states = torch.tensor(
#             dataset['observations'][start_idx:end_idx], dtype=torch.float32)
#         actions = torch.tensor(
#             dataset['actions'][start_idx:end_idx], dtype=torch.float32)
#         rewards = torch.tensor(
#             dataset['rewards'][start_idx:end_idx], dtype=torch.float32)

#         actions = actions[:-1]
#         sg = states[-1]
#         s_start = states[0]
#         reward_sg = rewards[-2]
#         actionlast = actions[-2]
#         statesecondlast = states[-2]

#         # Train Low-Level Policy
#         policy_actions = []
#         for state in states:
#             policy_actions.append(low_level_policy(state, sg))
#         policy_actions = policy_actions[:-1]
#         policy_actions = torch.stack(policy_actions)
#         policy_actions = torch.squeeze(policy_actions)
#         policy_loss = nn.MSELoss()(policy_actions, actions)
#         policy_optimizer.zero_grad()
#         policy_loss.backward()
#         policy_optimizer.step()

#         # VAE update
#         mu, logvar = goal_proposal_vae.encode(sg, s_start)
#         z = goal_proposal_vae.reparameterize(mu, logvar)
#         vae_loss = cvaeLoss(
#             sg, goal_proposal_vae.decode(z, s_start), mu, logvar)

#         mua, logvara = action_vae.encode(actionlast, statesecondlast)
#         za = action_vae.reparameterize(mua, logvara)
#         actionvae_loss = cvaeLoss(actionlast, action_vae.decode(
#             za, statesecondlast), mua, logvara)
#         action_optimizer.zero_grad()
#         actionvae_loss.backward()
#         action_optimizer.step()

#         # Perform sampling and value update
#         sampled_actions = []
#         for _ in range(M):
#             sampled_action = action_vae.decode(za, sg)
#             sampled_actions.append(sampled_action)
#         sampled_actions = torch.stack(sampled_actions)

#         values = []
#         for action in sampled_actions:
#             value = value_network(sg, action)
#             values.append(value)
#         values = torch.stack(values)
#         max_value = torch.max(values)
#         Vbar = reward_sg + gamma * max_value.detach()
#         value_loss = nn.MSELoss()(Vbar, value_network(statesecondlast, actionlast))

#         # Update optimizers
#         vae_optimizer.zero_grad()
#         vae_loss.backward()
#         vae_optimizer.step()

#         value_optimizer.zero_grad()
#         value_loss.backward()
#         value_optimizer.step()

#         # Store losses and reward
#         cumulative_rewards.append(reward_sg.item())
#         cumulative_policy_loss.append(policy_loss.item())
#         cumulative_vae_loss.append(vae_loss.item())
#         cumulative_value_loss.append(value_loss.item())

#         # Print averages every 1000 iterations
#         if iteration % 1000 == 0 and iteration > 0:
#             avg_reward = np.mean(cumulative_rewards)
#             avg_policy_loss = np.mean(cumulative_policy_loss)
#             avg_vae_loss = np.mean(cumulative_vae_loss)
#             avg_value_loss = np.mean(cumulative_value_loss)

#             print(f"Iteration {iteration}: Avg Reward: {avg_reward:.4f}, "
#                   f"Avg Policy Loss: {avg_policy_loss:.4f}, "
#                   f"Avg VAE Loss: {avg_vae_loss:.4f}, "
#                   f"Avg Value Loss: {avg_value_loss:.4f}")

#             # Reset cumulative statistics
#             cumulative_rewards = []
#             cumulative_policy_loss = []
#             cumulative_vae_loss = []
#             cumulative_value_loss = []


# state_dim = dataset['observations'].shape[1]
# state_goal_dim = dataset['observations'].shape[1]
# action_dim = dataset['actions'].shape[1]
# latent_dim = 20

# low_level_policy = LowLevelPolicy(state_dim, state_goal_dim, action_dim)
# goal_proposal_vae = GoalProposalVAE(state_dim, state_goal_dim, latent_dim)
# action_vae = ActionProposalVAE(state_dim, action_dim, latent_dim)
# value_network = ValueNetwork(state_dim, action_dim)

# # Train the IRIS algorithm using the D4RL dataset
# train_IRIS(low_level_policy, goal_proposal_vae,
#            action_vae, value_network, dataset)

In [285]:
import torch
import numpy as np
import torch.optim as optim
import torch.nn as nn


def train_IRIS(low_level_policy, goal_proposal_vae: GoalProposalVAE, action_vae: ActionProposalVAE, value_network, dataset, num_iterations=40000, trajectory_length=5):
    # Move models to GPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    low_level_policy = low_level_policy.to(device)
    goal_proposal_vae = goal_proposal_vae.to(device)
    action_vae = action_vae.to(device)
    value_network = value_network.to(device)

    policy_optimizer = optim.Adam(low_level_policy.parameters(), lr=0.0001)
    vae_optimizer = optim.Adam(goal_proposal_vae.parameters(), lr=0.001)
    value_optimizer = optim.Adam(value_network.parameters(), lr=0.01)
    action_optimizer = optim.Adam(action_vae.parameters(), lr=0.001)
    M = 30
    gamma = 0.99
    num_trajectories = len(dataset['observations']) // trajectory_length

    # Variables to store cumulative statistics
    cumulative_rewards = []
    cumulative_policy_loss = []
    cumulative_vae_loss = []
    cumulative_value_loss = []

    for iteration in range(num_iterations):
        trajectory_idx = iteration
        start_idx = trajectory_idx * trajectory_length
        end_idx = start_idx + trajectory_length

        # Move data to GPU
        states = torch.tensor(
            dataset['observations'][start_idx:end_idx], dtype=torch.float32).to(device)
        actions = torch.tensor(
            dataset['actions'][start_idx:end_idx], dtype=torch.float32).to(device)
        rewards = torch.tensor(
            dataset['rewards'][start_idx:end_idx], dtype=torch.float32).to(device)

        actions = actions[:-1]
        sg = states[-1]
        s_start = states[0]
        reward_sg = rewards[-2]
        actionlast = actions[-2]
        statesecondlast = states[-2]

        # Train Low-Level Policy
        policy_actions = []
        for state in states:
            policy_actions.append(low_level_policy(state, sg))
        policy_actions = policy_actions[:-1]
        policy_actions = torch.stack(policy_actions)
        policy_actions = torch.squeeze(policy_actions)
        policy_loss = nn.MSELoss()(policy_actions, actions)
        policy_optimizer.zero_grad()
        policy_loss.backward()
        policy_optimizer.step()

        # VAE update
        mu, logvar = goal_proposal_vae.encode(sg, s_start)
        z = goal_proposal_vae.reparameterize(mu, logvar)
        vae_loss = cvaeLoss(
            sg, goal_proposal_vae.decode(z, s_start), mu, logvar)

        mua, logvara = action_vae.encode(actionlast, statesecondlast)
        za = action_vae.reparameterize(mua, logvara)
        actionvae_loss = cvaeLoss(actionlast, action_vae.decode(
            za, statesecondlast), mua, logvara)
        action_optimizer.zero_grad()
        actionvae_loss.backward()
        action_optimizer.step()

        # Perform sampling and value update
        sampled_actions = []
        for _ in range(M):
            sampled_action = action_vae.decode(za, sg)
            sampled_actions.append(sampled_action)
        sampled_actions = torch.stack(sampled_actions)

        values = []
        for action in sampled_actions:
            value = value_network(sg, action)
            values.append(value)
        values = torch.stack(values)
        max_value = torch.max(values)
        Vbar = reward_sg + gamma * max_value.detach()
        Vbar = Vbar.unsqueeze(0)
        value_loss = nn.MSELoss()(Vbar, value_network(statesecondlast, actionlast))

        # Update optimizers
        vae_optimizer.zero_grad()
        vae_loss.backward()
        vae_optimizer.step()

        value_optimizer.zero_grad()
        value_loss.backward()
        value_optimizer.step()

        # Store losses and reward
        cumulative_rewards.append(reward_sg.item())
        cumulative_policy_loss.append(policy_loss.item())
        cumulative_vae_loss.append(vae_loss.item())
        cumulative_value_loss.append(value_loss.item())

        # Print averages every 1000 iterations
        if iteration % 1000 == 0 and iteration > 0:
            avg_reward = np.mean(cumulative_rewards)
            avg_policy_loss = np.mean(cumulative_policy_loss)
            avg_vae_loss = np.mean(cumulative_vae_loss)
            avg_value_loss = np.mean(cumulative_value_loss)

            print(f"Iteration {iteration}: Avg Reward: {avg_reward:.4f}, "
                  f"Avg Policy Loss: {avg_policy_loss:.4f}, "
                  f"Avg VAE Loss: {avg_vae_loss:.4f}, "
                  f"Avg Value Loss: {avg_value_loss:.4f}")

            # Reset cumulative statistics
            cumulative_rewards = []
            cumulative_policy_loss = []
            cumulative_vae_loss = []
            cumulative_value_loss = []


# Assuming dataset and models are already initialized
state_dim = dataset['observations'].shape[1]
state_goal_dim = dataset['observations'].shape[1]
action_dim = dataset['actions'].shape[1]
latent_dim = 8

low_level_policy = LowLevelPolicy(state_dim, state_goal_dim, action_dim)
goal_proposal_vae = GoalProposalVAE(state_dim, state_goal_dim, latent_dim)
action_vae = ActionProposalVAE(state_dim, action_dim, latent_dim)
value_network = ValueNetwork(state_dim, action_dim)

# Train the IRIS algorithm using the D4RL dataset
train_IRIS(low_level_policy, goal_proposal_vae,
           action_vae, value_network, dataset)

Iteration 1000: Avg Reward: 0.0929, Avg Policy Loss: 0.5161, Avg VAE Loss: 0.5435, Avg Value Loss: 0.4712
Iteration 2000: Avg Reward: 0.0440, Avg Policy Loss: 0.4861, Avg VAE Loss: 0.1007, Avg Value Loss: 0.1934
Iteration 3000: Avg Reward: 0.0730, Avg Policy Loss: 0.4603, Avg VAE Loss: 0.0830, Avg Value Loss: 3.1286
Iteration 4000: Avg Reward: 0.0720, Avg Policy Loss: 0.4362, Avg VAE Loss: 0.0933, Avg Value Loss: 1.4655
Iteration 5000: Avg Reward: 0.1060, Avg Policy Loss: 0.4029, Avg VAE Loss: 0.0559, Avg Value Loss: 0.8981
Iteration 6000: Avg Reward: 0.0950, Avg Policy Loss: 0.3719, Avg VAE Loss: 0.0719, Avg Value Loss: 0.3152
Iteration 7000: Avg Reward: 0.0640, Avg Policy Loss: 0.3455, Avg VAE Loss: 0.0569, Avg Value Loss: 0.2787
Iteration 8000: Avg Reward: 0.0860, Avg Policy Loss: 0.3179, Avg VAE Loss: 0.0506, Avg Value Loss: 0.4265
Iteration 9000: Avg Reward: 0.0610, Avg Policy Loss: 0.2955, Avg VAE Loss: 0.0410, Avg Value Loss: 0.3590
Iteration 10000: Avg Reward: 0.0720, Avg Polic

In [331]:
import cv2
import numpy as np
import torch


def visualize_policy_as_video(low_level_policy, env, num_episodes=1, max_steps=10000, save_path="env_policy_video.mp4"):
    # Define video writer using OpenCV
    height, width, _ = env.render(mode="rgb_array").shape
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # Codec for MP4 video
    video_writer = cv2.VideoWriter(save_path, fourcc, 30, (width, height))

    for episode in range(num_episodes):
        state = env.reset()
        goal = state.copy()  # Assuming the goal is part of the observation for simplicity

        for step in range(max_steps):
            # Render the environment and capture the frame
            frame = env.render(mode="rgb_array")
            # Write frame to video
            video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))

            goal_tensor = torch.tensor(goal, dtype=torch.float32)
            # VAE update
            state_tensor = torch.tensor(state, dtype=torch.float32)
            mu, logvar = goal_proposal_vae.encode(goal_tensor, state_tensor)
            z = goal_proposal_vae.reparameterize(mu, logvar)
            goal = goal_proposal_vae.decode(z, state_tensor)

            next_state = None
            # Get the action from the low-level policy
            reward = 0
            for _ in range(3):
                state_tensor = torch.tensor(state, dtype=torch.float32)
                action = low_level_policy(
                    state_tensor, goal).detach().numpy()
                frame = env.render(mode="rgb_array")
                video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
                action = np.squeeze(action)
                next_state, reward, done, _ = env.step(action)
                frame = env.render(mode="rgb_array")
                video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
                state = next_state
                if reward > 0:
                    print(
                        f"Episode {episode+1}, Step {step+1}: Goal reached!")
                    break
                if done:
                    break  # Terminate the episode if done is True
            if done:
                break
            if reward > 0:
                break
    # Release the video writer after finishing
    video_writer.release()
    print(f"Video saved to {save_path}")


# Visualize the learned policy as a video
visualize_policy_as_video(low_level_policy, env)

  goal_tensor = torch.tensor(goal, dtype=torch.float32)


Episode 1, Step 26: Goal reached!
Video saved to env_policy_video.mp4
