In [1]:
from air_hockey_challenge.framework.air_hockey_challenge_wrapper import AirHockeyChallengeWrapper
from air_hockey_challenge.framework.challenge_core import ChallengeCore
from air_hockey_challenge.framework.agent_base import AgentBase
from examples.control.hitting_agent import build_agent, HittingAgent
from examples.control.hitting_agent_wait import HittingAgentWait

from mushroom_rl.utils.dataset import parse_dataset, select_random_samples
from mushroom_rl.policy import GaussianTorchPolicy

import torch
import torch.nn as nn

from tqdm import tqdm

import pickle

import numpy as np

use_cuda = torch.cuda.is_available()
device = 'cuda' if use_cuda else 'cpu'
print(f"Cuda: {use_cuda}")


Cuda: False


In [2]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

class MushroomRLTrajectoryDataset(Dataset):
    def __init__(self, mdp, agent, n_episodes):
        """
        Initialize the MushroomRLTrajectoryDataset.

        Args:
            mdp (mushroom_rl.environments.Environment): the MDP (Markov Decision Process).
            agent: the agent to evaluate.
        """
        self.core = ChallengeCore(agent, mdp)
                # Initialize empty lists for each data type
        state_list, action_list, reward_list, next_state_list, absorbing_list, last_list = [], [], [], [], [], []

        for i in tqdm(range(n_episodes)):
            # Evaluate one episode at a time
            trajectory = self.core.evaluate(n_episodes=1, render=False)
            state, action, reward, next_state, absorbing, last = parse_dataset(trajectory)

            # Append the data from the current episode
            state_list.append(torch.from_numpy(state).to(device))
            action_list.append(torch.from_numpy(action).to(device))
            reward_list.append(torch.from_numpy(reward).to(device))
            next_state_list.append(torch.from_numpy(next_state).to(device))
            absorbing_list.append(torch.from_numpy(absorbing).to(device))
            last_list.append(torch.from_numpy(last).to(device))

        # Concatenate the data from all episodes
        self.state = torch.cat(state_list, dim=0)
        self.action = torch.cat(action_list, dim=0)
        self.reward = torch.cat(reward_list, dim=0)
        self.next_state = torch.cat(next_state_list, dim=0)
        self.absorbing = torch.cat(absorbing_list, dim=0)
        self.last = torch.cat(last_list, dim=0)

        self.length = self.state.shape[0]

    def __len__(self):
        """
        Return the total number of state-action-reward-next_state tuples in the dataset.

        Returns:
            int: the length of the dataset.
        """
        return self.length

    def __getitem__(self, index):
        """
        Get the state-action-reward-next_state tuple at the specified index.

        Args:
            index (int): the index of the desired tuple.

        Returns:
            dict: a dictionary containing state, action, reward, and next_state.
        """

        return {
            'state': self.state[index],
            'action': self.action[index],
            'reward': self.reward[index],
            'next_state': self.next_state[index],
            'absorbing': self.absorbing[index],
            'last': self.last[index],
        }



In [3]:
class Network(nn.Module):
    def __init__(self, input_shape, output_shape, n_features, **kwargs):
        super(Network, self).__init__()

        self._h1 = nn.Linear(input_shape[0], n_features)
        self._h2 = nn.Linear(n_features, n_features)
        self._h3 = nn.Linear(n_features, output_shape[0])

        nn.init.xavier_uniform_(self._h1.weight,
                                gain=nn.init.calculate_gain('tanh'))
        nn.init.xavier_uniform_(self._h2.weight,
                                gain=nn.init.calculate_gain('tanh'))
        nn.init.xavier_uniform_(self._h3.weight,
                                gain=nn.init.calculate_gain('linear'))

    def forward(self, obs, **kwargs):
        features1 = torch.tanh(self._h1(torch.squeeze(obs, 1).float()))
        features2 = torch.tanh(self._h2(features1))
        return self._h3(features2)


class BCAgent(AgentBase):
    def __init__(self, env_info, policy, **kwargs):
        super().__init__(env_info, **kwargs)
        self.policy = policy

    def reset(self):
        pass

    def draw_action(self, observation):
        return self.policy.draw_action(observation).reshape(2,3)


In [6]:
def train_dagger_agent(learner_policy, expert_policy, mdp, dataset, num_iterations=10, batch_size=64, lr=0.001):
    # Create data loader
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Set up the optimizer and loss function
    optimizer = optim.Adam(learner_policy.parameters(), lr=lr)

    # Train the initial policy on collected data
    for epoch in range(num_iterations):
        for i, data in enumerate(data_loader):
            optimizer.zero_grad()

            states = data['state'].float()

            # Get expert actions for each state in the batch
            expert_actions = []

            for state in states:
                expert_policy.reset()
                state_np = state.cpu().numpy()

                expert_action = expert_policy.draw_action(state_np).astype(np.float32)
                print(f"optimization_failed: {expert_policy.optimization_failed}")
                print(f"expert_action: {expert_action}")
                expert_actions.append(torch.from_numpy(expert_action))
            expert_actions = torch.stack(expert_actions).reshape((-1, 6))


            # Compute loss
            print(f"states: {states.shape}")
            print(f"expert_actions: {expert_actions.shape}")
            loss = -policy.log_prob_t(states, expert_actions).mean()

            # Update the learner policy
            loss.backward()
            optimizer.step()

        # Evaluate the current learner policy
        learner_trajectory_dataset = MushroomRLTrajectoryDataset(mdp, learner_policy, n_episodes=2)
        data_loader = DataLoader(learner_trajectory_dataset, batch_size=batch_size, shuffle=True)

        # Collect new expert actions for the new states
        for state in learner_trajectory_dataset:
            state_np = state['state'].float().unsqueeze(0).cpu().numpy()
            expert_action = expert_policy.draw_action(state_np).astype(np.float32)
            state['action'] = torch.from_numpy(expert_action).squeeze().tolist()

        # Add the new data to the dataset
        dataset.trajectories.extend(learner_trajectory_dataset.trajectories)
        dataset.length += learner_trajectory_dataset.length

    return learner_policy


# 1. defining BCAgent and expert agent to be trained in dagger

In [5]:
env = "3dof-hit"

mdp = AirHockeyChallengeWrapper(env)
mdp.reset()

# policy can only output 1D actions (6,) ... they need to be recast in (2,3) shape later on
policy = GaussianTorchPolicy(Network, (12,), (6,), std_0=1., n_features=64, use_cuda=use_cuda)

policy = policy.load('dataset/hit_500_policy')
bc_agent = BCAgent(mdp.env_info, policy)
dataset = MushroomRLTrajectoryDataset(mdp, bc_agent, n_episodes=10)


expert_agent = HittingAgentWait(mdp.env_info)


  0%|          | 0/10 [00:00<?, ?it/s]
  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:04<00:00,  4.22s/it][A
 10%|█         | 1/10 [00:04<00:38,  4.24s/it][A
  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:00<00:00,  3.52it/s][A
 20%|██        | 2/10 [00:04<00:15,  1.92s/it][A
  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:00<00:00,  2.55it/s][A
 30%|███       | 3/10 [00:04<00:08,  1.23s/it][A
  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:03<00:00,  3.77s/it][A
 40%|████      | 4/10 [00:08<00:13,  2.23s/it][A
  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:03<00:00,  3.78s/it][A
 50%|█████     | 5/10 [00:12<00:13,  2.80s/it][A
  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:03<00:00,  3.81s/it][A
 60%|██████    | 6/10 [00:16<00:12,  3.15s/it][A
  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:03<00:00,  3.77s/it][A
 70%|███████   | 7/10 [00:20<00:10,  

In [7]:
dagger_agent = train_dagger_agent(policy, expert_agent, mdp, dataset, num_iterations=2, batch_size=64, lr=0.001)

optimization_failed: True
expert_action: [[ 0.01066965 -0.2526577   0.17868517]
 [ 0.          0.          0.        ]]
optimization_failed: True
expert_action: [[0.01447077 0.12069628 0.26875716]
 [0.         0.         0.        ]]
optimization_failed: True
expert_action: [[-1.1182134  1.259235   1.4226865]
 [ 0.         0.         0.       ]]
optimization_failed: True
expert_action: [[ 0.79358053 -0.53868973 -0.39164957]
 [ 0.          0.          0.        ]]
optimization_failed: True
expert_action: [[-0.11016908  0.4653831   0.6047175 ]
 [ 0.          0.          0.        ]]
optimization_failed: True
expert_action: [[-0.20213531 -0.25361478  0.4561776 ]
 [ 0.          0.          0.        ]]
optimization_failed: True
expert_action: [[ 0.4317515  -0.35400933 -0.16908808]
 [ 0.          0.          0.        ]]
optimization_failed: True
expert_action: [[-0.2613153   0.0558968   0.57011336]
 [ 0.          0.          0.        ]]
optimization_failed: True
expert_action: [[-0.261786

Exception in thread Thread-22:
Traceback (most recent call last):
  File "/opt/anaconda3/envs/air_hockey_challenge/lib/python3.8/threading.py", line 932, in _bootstrap_inner
    self.run()
  File "/opt/anaconda3/envs/air_hockey_challenge/lib/python3.8/threading.py", line 870, in run
    self._target(*self._args, **self._kwargs)
  File "/Users/annetteader/PycharmProjects/air_hockey_challenge/examples/control/hitting_agent_wait.py", line 87, in _plan_trajectory_thread
    ee_traj, hit_idx, q_anchor = self.plan_ee_trajectory(puck_pos, ee_pos)
  File "/Users/annetteader/PycharmProjects/air_hockey_challenge/examples/control/hitting_agent_wait.py", line 124, in plan_ee_trajectory
    p = np.hstack([p, np.ones((p.shape[0], 1)) * self.ee_height])
  File "<__array_function__ internals>", line 200, in hstack
  File "/opt/anaconda3/envs/air_hockey_challenge/lib/python3.8/site-packages/numpy/core/shape_base.py", line 368, in hstack
    return _nx.concatenate(arrs, 0, dtype=dtype, casting=casting)


optimization_failed: True
expert_action: [[ 0.2441919   0.01215265 -0.29426318]
 [ 0.          0.          0.        ]]
optimization_failed: True
expert_action: [[-0.49010512  0.82798964  0.04939831]
 [ 0.          0.          0.        ]]
optimization_failed: True
expert_action: [[ 0.02420683 -0.5291351   0.12030693]
 [ 0.          0.          0.        ]]
optimization_failed: True
expert_action: [[-0.05879253 -0.6914746  -0.73099333]
 [ 0.          0.          0.        ]]
optimization_failed: True
expert_action: [[-0.3364749   0.40097693  0.7003572 ]
 [ 0.          0.          0.        ]]


KeyboardInterrupt: 

In [33]:
data_loader = DataLoader(dataset, batch_size=64, shuffle=True)
optimizer = optim.Adam(policy.parameters(), lr=.001)

for i, data in enumerate(data_loader):
    states = data['state']
    for state in states:
        print(state.shape)
    # optimizer.zero_grad()
    # states = data['state'].float()
    # expert_actions = expert_agent.draw_action(states).float()

torch.Size([12])
torch.Size([12])
torch.Size([12])
torch.Size([12])
torch.Size([12])
torch.Size([12])
torch.Size([12])
torch.Size([12])
torch.Size([12])
torch.Size([12])
torch.Size([12])
torch.Size([12])
torch.Size([12])
torch.Size([12])
torch.Size([12])
torch.Size([12])
torch.Size([12])
torch.Size([12])
torch.Size([12])
torch.Size([12])
torch.Size([12])
torch.Size([12])
torch.Size([12])
torch.Size([12])
torch.Size([12])
torch.Size([12])
torch.Size([12])
torch.Size([12])
torch.Size([12])
torch.Size([12])
torch.Size([12])
torch.Size([12])
torch.Size([12])
torch.Size([12])
torch.Size([12])
torch.Size([12])
torch.Size([12])
torch.Size([12])
torch.Size([12])
torch.Size([12])
torch.Size([12])
torch.Size([12])
torch.Size([12])
torch.Size([12])
torch.Size([12])
torch.Size([12])
torch.Size([12])
torch.Size([12])
torch.Size([12])
torch.Size([12])
torch.Size([12])
torch.Size([12])
torch.Size([12])
torch.Size([12])
torch.Size([12])
torch.Size([12])
torch.Size([12])
torch.Size([12])
torch.Size([12