# Setting up training with codebase

In [1]:
import torch
import torch.nn as nn

In [2]:
!pwd

/home/andang/workspace/Classes/spring_2023/eecs298/air_hockey_challenge


### Example dummy agent.
- this agent holds its current position/vel
- doesn't do anything.
- used to show how to setup an agent.

In [3]:
import numpy as np
from air_hockey_challenge.framework import AgentBase


def build_agent(env_info, **kwargs):
    """
    Function where an Agent that controls the environments should be returned.
    The Agent should inherit from the mushroom_rl Agent base env.

    Args:
        env_info (dict): The environment information
        kwargs (any): Additionally setting from agent_config.yml
    Returns:
         (AgentBase) An instance of the Agent
    """

    return DummyAgent(env_info, **kwargs)

'''
    We will compute the action which is desired pos & vel.
'''

class DummyAgent(AgentBase):
    def __init__(self, env_info, **kwargs):
        super().__init__(env_info, **kwargs)
        self.new_start = True
        self.hold_position = None

    def reset(self):
        self.new_start = True
        self.hold_position = None

    def draw_action(self, obs):
        # breaking down observation into something we can use
        # self.get_joint_pos(obs), self.get_joint_vel(obs), self.get_puck_pos(obs)
        
        hold_position = self.get_joint_pos(observation)
        velocity = np.zeros_like(hold_position)
        action = np.vstack([hold_position, velocity])
        return action

### Setup PyTorch Neural Network to convert observation space to action space

In [33]:
#super stupid forward pass neural network
class ActionGenerator(nn.Module):
    def __init__(self, input_dim, num_layers, layer_width, output_dim, activation = nn.LeakyReLU(0.1) ):
        super().__init__()
        layers = [nn.Linear(input_dim, layer_width), activation]
        for i in range(num_layers-1):
            layers.append(nn.Linear(layer_width, layer_width))
            layers.append(activation)
        layers.append(nn.Linear(layer_width, output_dim))
        layers.append(activation)
        
        self.model = nn.Sequential(*layers)
    def forward(self, x):
        out = self.model(x)
        return out

In [34]:
network = ActionGenerator(8,20,10,6)
print(network(torch.zeros(1,8)).shape)

torch.Size([1, 6])


In [6]:
#we are gonna try using a DQN for this
from mushroom_rl.algorithms.value import DQN

### Setup DeepDummy Agent

In [41]:
class DeepDummyAgent(AgentBase):
    def __init__(self, env_info, act_shape, load_path=None, **kwargs):
        super().__init__(env_info, **kwargs)
        self.new_start = True
        self.hold_position = None
        
        #observation is shape 12 
        self.model = ActionGenerator(6, 20,10, act_shape)
        
        if not (load_path is None):
            self.model.load_state_dict(load_path)
        
        
    def reset(self):
        self.new_start = True
        self.hold_position = None

    def draw_action(self, obs):
        # breaking down observation into something we can use
        # self.get_joint_pos(obs), self.get_joint_vel(obs), self.get_puck_pos(obs)
        
        #obs is a numpy array
        
        puck_pos, puck_vel = self.get_puck_state(obs)
        
        puck_state = np.hstack((puck_pos,puck_vel))[:,np.newaxis].T
        
        action = self.model(torch.tensor(puck_state,dtype=torch.float))
        action = torch.reshape(action,(2,3))
        return action.detach().numpy()

In [43]:
#setting this up

import numpy as np
from air_hockey_challenge.framework.air_hockey_challenge_wrapper import AirHockeyChallengeWrapper
from air_hockey_challenge.framework.challenge_core import ChallengeCore

mdp = AirHockeyChallengeWrapper(env="3dof-hit", action_type="position-velocity", interpolation_order=3, debug=True)
agent = DeepDummyAgent(mdp.base_env.env_info,6)

core = ChallengeCore(agent, mdp)

core.learn(n_episodes=10, n_episodes_per_fit=10, render=False) #render allows us to visualize what's going on














  0%|                                                                                                                                                      | 0/10 [00:00<?, ?it/s][A[A[A[A[A[A[A[A[A[A[A[A[A












 10%|██████████████▏                                                                                                                               | 1/10 [00:01<00:16,  1.87s/it][A[A[A[A[A[A[A[A[A[A[A[A[A












 20%|████████████████████████████▍                                                                                                                 | 2/10 [00:03<00:14,  1.84s/it][A[A[A[A[A[A[A[A[A[A[A[A[A












 30%|██████████████████████████████████████████▌                                                                                                   | 3/10 [00:05<00:12,  1.83s/it][A[A[A[A[A[A[A[A[A[A[A[A[A












 40%|████████████████████████████████████████████████████████▊     

NotImplementedError: Agent is an abstract class