In [33]:
from env import BallReach

In [34]:
env = BallReach((640, 480))

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Normal
import numpy as np

In [5]:
class LowLevelPolicy(nn.Module):
    def __init__(self, state_dim=4, goal_dim=4, action_dim=2, hidden_dim=128):
        super(LowLevelPolicy, self).__init__()
        self.rnn = nn.LSTM(state_dim + goal_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, action_dim)

    def forward(self, state, goal):
        input_seq = torch.cat((state, goal), dim=-1)
        out, _ = self.rnn(torch.unsqueeze(input_seq, 0))
        actions = self.fc(out)
        return actions

In [6]:
class GoalProposalVAE(nn.Module):
    def __init__(self, state_dim=4, goal_dim=4, latent_dim=20):
        super(GoalProposalVAE, self).__init__()
        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(state_dim + goal_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU()
        )
        self.fc_mu = nn.Linear(128, latent_dim)
        self.fc_logvar = nn.Linear(128, latent_dim)
        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim + state_dim, 128),
            nn.ReLU(),
            nn.Linear(128, goal_dim)
        )

    def encode(self, x, c):
        h = self.encoder(torch.cat((x, c), dim=0))
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, c):  # P(x|z,c)
        inputs = torch.cat([z, c], 0)
        h3 = self.decoder(inputs)
        return h3

    def forward(self, x, c):
        mu, logvar = self.encode(x, c)
        z = self.reparameterize(mu, logvar)
        decoded = self.decode(z, c)
        return decoded, mu, logvar

    # def sample(self, num_samples, y):
    #     with torch.no_grad():
    #         z = torch.randn(num_samples, self.num_hidden)
    #         samples = self.decoder(self.condition_on_label(z, y))
    #     return samples

In [7]:
class ValueNetwork(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(ValueNetwork, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(state_dim + action_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, state, action):
        return self.fc(torch.cat((state, action), dim=-1))

In [8]:
def cvaeLoss(sg, D, mu, logvar, beta=0.0001):
    recon_loss = torch.nn.functional.mse_loss(D, sg)
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + beta * kl_loss

In [9]:
class ActionProposalVAE(nn.Module):
    def __init__(self, state_dim=4, action_dim=2, latent_dim=20):
        super(ActionProposalVAE, self).__init__()
        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(state_dim + action_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU()
        )
        self.fc_mu = nn.Linear(128, latent_dim)
        self.fc_logvar = nn.Linear(128, latent_dim)
        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim + state_dim, 128),
            nn.ReLU(),
            nn.Linear(128, action_dim)
        )
        # self.label_projector = nn.Sequential(
        #     nn.Linear(state_dim, latent_dim), nn.ReLU())

    def encode(self, x, c):
        h = self.encoder(torch.cat((x, c), dim=0))
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, c):  # P(x|z,c)
        inputs = torch.cat([z, c], 0)
        h3 = self.decoder(inputs)
        return h3

    def forward(self, x, c):
        mu, logvar = self.encode(x, c)
        z = self.reparameterize(mu, logvar)
        decoded = self.decode(z, c)
        return decoded, mu, logvar

In [10]:
dataset = np.load(
    '/home/keerthi/IRIS/BallReach/ballreach_dataset_continous.npz')

In [11]:
dataset['observations'].shape

(4082, 2)

In [12]:
max(dataset['rewards'])

1.0

In [13]:
dataset['actions'].shape

(4082, 2)

In [14]:
dataset.keys

<bound method Mapping.keys of NpzFile '/home/keerthi/IRIS/BallReach/ballreach_dataset_continous.npz' with keys: observations, actions, next_observations, rewards, terminals>

In [15]:
def split_into_trajectories(dataset):
    # Initialize list to store trajectories
    trajectories = []
    current_trajectory = {key: [] for key in dataset.keys()}

    for i in range(len(dataset['rewards'])):
        # Append each key's data for the current step
        for key in dataset.keys():
            current_trajectory[key].append(dataset[key][i])

        # If reward is 10, finalize the current trajectory
        if dataset['rewards'][i] == 1:
            # Append completed trajectory to the list and reset current trajectory
            trajectories.append(
                {key: current_trajectory[key] for key in current_trajectory})
            current_trajectory = {key: [] for key in dataset.keys()}

    # Add any remaining steps as a final trajectory if not empty
    if current_trajectory['rewards']:
        trajectories.append(current_trajectory)

    return trajectories

In [16]:
dataset['actions']

array([[ 0.        ,  0.56006829],
       [-0.39010135,  0.42860776],
       [-0.77362204,  0.53963346],
       ...,
       [ 0.        ,  0.33661993],
       [ 0.73662977,  0.64769791],
       [-0.74849565,  0.75005955]])

In [17]:
traj = split_into_trajectories(dataset)

In [18]:
len(traj)

163

In [19]:
traj[90]['actions']

[array([0.       , 0.5560624]),
 array([0.        , 0.34249263]),
 array([0.        , 0.40938305]),
 array([0.09720208, 0.74936002]),
 array([0.7747063 , 0.71463588]),
 array([0.        , 0.39739082]),
 array([0.87617789, 0.74414573]),
 array([0.        , 0.30110621]),
 array([0.        , 0.69523896]),
 array([0.        , 0.66704971]),
 array([0.        , 0.74193264]),
 array([0.        , 0.67372962]),
 array([0.       , 0.7737488]),
 array([0.        , 0.63158808]),
 array([0.       , 0.4953053])]

In [20]:
import torch
import torch.nn as nn
import torch.nn.init as init


def xavier_init(m):
    if isinstance(m, nn.Linear):
        init.xavier_uniform_(m.weight)
        if m.bias is not None:
            # Set a small constant bias initialization
            m.bias.data.fill_(0.0001)

In [36]:
import torch
import numpy as np
import torch.optim as optim
import torch.nn as nn

'''
Run this training loop for full trajecory from dataset'''


def train_IRIS_full_trajectory(low_level_policy, goal_proposal_vae: GoalProposalVAE, action_vae: ActionProposalVAE, value_network, trajectories, trajectory_length=5):
    # Move models to GPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    low_level_policy = low_level_policy.to(device)
    goal_proposal_vae = goal_proposal_vae.to(device)
    action_vae = action_vae.to(device)
    value_network = value_network.to(device)

    policy_optimizer = optim.Adam(low_level_policy.parameters(), lr=0.0001)
    vae_optimizer = optim.Adam(goal_proposal_vae.parameters(), lr=0.001)
    value_optimizer = optim.Adam(value_network.parameters(), lr=0.01)
    action_optimizer = optim.Adam(action_vae.parameters(), lr=0.001)
    M = 30
    gamma = 0.99
    num_trajectories = len(dataset['observations']) // trajectory_length

    # Variables to store cumulative statistics
    cumulative_rewards = []
    cumulative_policy_loss = []
    cumulative_vae_loss = []
    cumulative_value_loss = []
    iteration = 0
    for trajectory in trajectories:
        # Move data to GPU
        iteration += 1
        states = torch.tensor(
            trajectory['observations'], dtype=torch.float32).to(device)
        actions = torch.tensor(
            trajectory['actions'], dtype=torch.float32).to(device)
        rewards = torch.tensor(
            trajectory['rewards'], dtype=torch.float32).to(device)
        if len(rewards) == 1:
            continue
        actions = actions[:-1]
        sg = states[-1]
        s_start = states[0]
        reward_sg = rewards[-2]
        actionlast = actions[-2]
        statesecondlast = states[-2]

        # Train Low-Level Policy
        policy_actions = []
        for state in states:
            policy_actions.append(low_level_policy(state, sg))
        policy_actions = policy_actions[:-1]
        policy_actions = torch.stack(policy_actions)
        policy_actions = torch.squeeze(policy_actions)
        policy_loss = nn.MSELoss()(policy_actions, actions)
        policy_optimizer.zero_grad()
        policy_loss.backward()
        policy_optimizer.step()

        # VAE update
        mu, logvar = goal_proposal_vae.encode(sg, s_start)
        z = goal_proposal_vae.reparameterize(mu, logvar)
        vae_loss = cvaeLoss(
            sg, goal_proposal_vae.decode(z, s_start), mu, logvar)

        mua, logvara = action_vae.encode(actionlast, statesecondlast)
        za = action_vae.reparameterize(mua, logvara)
        actionvae_loss = cvaeLoss(actionlast, action_vae.decode(
            za, statesecondlast), mua, logvara)
        action_optimizer.zero_grad()
        actionvae_loss.backward()
        action_optimizer.step()

        # Perform sampling and value update
        sampled_actions = []
        for _ in range(M):
            sampled_action = action_vae.decode(za, sg)
            sampled_actions.append(sampled_action)
        sampled_actions = torch.stack(sampled_actions)

        values = []
        for action in sampled_actions:
            value = value_network(sg, action)
            values.append(value)
        values = torch.stack(values)
        max_value = torch.max(values)
        Vbar = reward_sg + gamma * max_value.detach()
        Vbar = Vbar.unsqueeze(0)
        value_loss = nn.MSELoss()(Vbar, value_network(statesecondlast, actionlast))

        # Update optimizers
        vae_optimizer.zero_grad()
        vae_loss.backward()
        vae_optimizer.step()

        value_optimizer.zero_grad()
        value_loss.backward()
        value_optimizer.step()

        # Store losses and reward
        cumulative_rewards.append(reward_sg.item())
        cumulative_policy_loss.append(policy_loss.item())
        cumulative_vae_loss.append(vae_loss.item())
        cumulative_value_loss.append(value_loss.item())

        # Print averages every 1000 iterations
        if iteration % 10 == 0 and iteration > 0:
            avg_reward = np.mean(cumulative_rewards)
            avg_policy_loss = np.mean(cumulative_policy_loss)
            avg_vae_loss = np.mean(cumulative_vae_loss)
            avg_value_loss = np.mean(cumulative_value_loss)

            print(f"Iteration {iteration}: Avg Reward: {avg_reward:.4f}, "
                  f"Avg Policy Loss: {avg_policy_loss:.4f}, "
                  f"Avg VAE Loss: {avg_vae_loss:.4f}, "
                  f"Avg Value Loss: {avg_value_loss:.4f}")

            # Reset cumulative statistics
            cumulative_rewards = []
            cumulative_policy_loss = []
            cumulative_vae_loss = []
            cumulative_value_loss = []


# Assuming dataset and models are already initialized
state_dim = dataset['observations'].shape[1]
state_goal_dim = dataset['observations'].shape[1]
action_dim = dataset['actions'].shape[1]
latent_dim = 8
low_level_policy = LowLevelPolicy(state_dim, state_goal_dim, action_dim)
goal_proposal_vae = GoalProposalVAE(state_dim, state_goal_dim, latent_dim)
action_vae = ActionProposalVAE(state_dim, action_dim, latent_dim)
value_network = ValueNetwork(state_dim, action_dim)

# Train the IRIS algorithm using the D4RL dataset
train_IRIS_full_trajectory(low_level_policy, goal_proposal_vae,
                           action_vae, value_network, traj)

Iteration 10: Avg Reward: -0.2427, Avg Policy Loss: 0.4163, Avg VAE Loss: 3598636337562992.0000, Avg Value Loss: 1007779327087.5797
Iteration 20: Avg Reward: -0.2360, Avg Policy Loss: 0.3730, Avg VAE Loss: 1266233179.9719, Avg Value Loss: 1093.8010
Iteration 30: Avg Reward: -0.3117, Avg Policy Loss: 0.2891, Avg VAE Loss: 476777.5969, Avg Value Loss: 1659.5016
Iteration 40: Avg Reward: -0.1511, Avg Policy Loss: 0.2647, Avg VAE Loss: 463692.8438, Avg Value Loss: 1159.3336
Iteration 50: Avg Reward: -0.2926, Avg Policy Loss: 0.2168, Avg VAE Loss: 487382.5344, Avg Value Loss: 510.9566
Iteration 60: Avg Reward: -0.1759, Avg Policy Loss: 0.1792, Avg VAE Loss: 452874.9813, Avg Value Loss: 224.0255
Iteration 70: Avg Reward: -0.2470, Avg Policy Loss: 0.1623, Avg VAE Loss: 442762.4281, Avg Value Loss: 242.0969
Iteration 80: Avg Reward: -0.2393, Avg Policy Loss: 0.1452, Avg VAE Loss: 449769.5156, Avg Value Loss: 382.4279
Iteration 90: Avg Reward: -0.1310, Avg Policy Loss: 0.1553, Avg VAE Loss: 436

In [35]:
accuracy = 0
state = env.state
goal = state.copy()
goal_proposal_vae = goal_proposal_vae.to('cpu')
low_level_policy = low_level_policy.to('cpu')
for step in range(1000):
    # Render the environment and capture the frame
    # Write frame to video
    # video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))

    goal_tensor = torch.tensor(goal, dtype=torch.float32)
    # VAE update
    state_tensor = torch.tensor(state, dtype=torch.float32)
    mu, logvar = goal_proposal_vae.encode(goal_tensor, state_tensor)
    z = goal_proposal_vae.reparameterize(mu, logvar)
    goal = goal_proposal_vae.decode(z, state_tensor)

    next_state = None
    # Get the action from the low-level policy
    reward = 0
    for _ in range(3):
        state_tensor = torch.tensor(state, dtype=torch.float32)
        action = low_level_policy(
            state_tensor, goal).detach().numpy()
        # video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
        action = np.squeeze(action)
        env.step(state)
        # video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
        state = env.state
        if reward > 0:
            print(
                f"Step {step+1}: Goal reached!")
            break

State: [360, 50], Action: [0.         0.73710374], Next State: [360.         123.71037395], Reward: -0.0
State: [360.         123.71037395], Action: [0.85860199 0.60680916], Next State: [445.86019903 184.39128945], Reward: -0.8586019903069624
State: [445.86019903 184.39128945], Action: [0.         0.72244528], Next State: [445.86019903 256.63581696], Reward: -0.0
State: [445.86019903 256.63581696], Action: [0.         0.52177777], Next State: [445.86019903 308.8135939 ], Reward: -0.0


  goal_tensor = torch.tensor(goal, dtype=torch.float32)


State: [445.86019903 308.8135939 ], Action: [0.         0.36782419], Next State: [445.86019903 345.59601263], Reward: -0.0
State: [445.86019903 345.59601263], Action: [0.         0.55045389], Next State: [445.86019903 400.64140201], Reward: -0.0
State: [445.86019903 400.64140201], Action: [0.         0.47851148], Next State: [445.86019903 448.4925498 ], Reward: -0.0
State: [445.86019903 448.4925498 ], Action: [-0.6843877   0.37170819], Next State: [377.42142861 480.        ], Reward: -0.6843877042376454
State: [377.42142861 480.        ], Action: [0.         0.42143001], Next State: [377.42142861 480.        ], Reward: -0.0


ValueError: Can't normalize Vector of length Zero