
How to use policy.forward to compute an action from given observation for TicTacToe example? #208

Closed · wfng92 opened this issue Sep 8, 2020 · 5 comments · Fixed by #210

Labels: question (Further information is requested)

Comments


wfng92 commented Sep 8, 2020

  • I have marked all applicable categories:
    • exception-raising bug
    • RL algorithm bug
    • documentation request (i.e. "X is missing from the documentation.")
    • new feature request
  • I have visited the source website
  • I have searched through the issue tracker for duplicates
  • I have mentioned version numbers, operating system and environment, where applicable:
    import tianshou, torch, sys
    print(tianshou.__version__, torch.__version__, sys.version, sys.platform)

0.2.6 1.6.0+cpu Python version 3.7.7

Hi, I am having an issue with computing a single action from a given observation. I have trained a policy based on the MultiAgent TicTacToe example (#121). According to the current documentation, the code for computing the action in a MultiAgentEnv is as follows:

env = MultiAgentEnv(...)
# obs is a dict containing obs, agent_id, and mask
obs = env.reset()
action = policy(obs)
obs, rew, done, info = env.step(action)
env.close()

I attempted to run it inside tic_tac_toe.py, reusing the get_agents function as my policy, and got the following error:

AttributeError: 'dict' object has no attribute 'obs'

In addition, any class that inherits from BasePolicy comes with a forward() function that accepts a Batch as its input parameter. The documentation is quite lacking on how to call it, but I assume it should be called as follows:

action = policy.forward(Batch(obs=obs, info=None))

I got the following error when running the code above:

IndexError: too many indices for array: array is 0-dimensional, but 1 were indexed

Subsequently, I tested it on CartPole DQN using the following code:

# net and optim initialization
policy = DQNPolicy(net, optim, args.gamma, args.n_step, target_update_freq=args.target_update_freq)
policy.load_state_dict(torch.load('log/CartPole-v0/dqn/policy.pth'))

obs = env.reset()
action = policy.forward(Batch(obs=obs, info=None))

I got the following error, which indicates a size mismatch:

RuntimeError: size mismatch, m1: [4 x 1], m2: [4 x 128] at ..\aten\src\TH/generic/THTensorMath.cpp:41

Hence, I reshaped the array as follows and successfully got the action from the DQN policy.

obs = np.array(obs)
obs = obs.reshape((1,-1))

# before [-0.0378007  -0.02966665 -0.01692553 -0.0120008 ]
# after [[-0.0378007  -0.02966665 -0.01692553 -0.0120008 ]]
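
For completeness, here is a sketch of the full sequence that then works (assuming the same net, optim, policy and env as above; the act field of the Batch returned by forward holds the batched actions):

import numpy as np
from tianshou.data import Batch

obs = env.reset()
obs = np.array(obs).reshape((1, -1))                # add the batch dimension, shape (1, 4) for CartPole
result = policy.forward(Batch(obs=obs, info=None))  # returns Batch(logits=..., act=..., state=...)
action = result.act[0]                              # single action for the single observation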

May I know how to compute the action in a MultiAgentEnv? I am trying to run it as an AI bot that plays against a human player. It might look something like this:

while True:
    # if player, get input action
    # if bot, get an action from policy

    obs, reward, done, info = env.step(action)
    env.render()
    if done:
        break

An example or an update to the documentation would be greatly appreciated. Thanks a lot~

Trinkle23897 added the question (Further information is requested) label on Sep 8, 2020
Trinkle23897 (Collaborator) commented Sep 8, 2020

I attempted to run it inside tic_tac_toe.py, reusing the get_agents function as my policy, and got the following error

AttributeError: 'dict' object has no attribute 'obs'

Could you please provide the detailed traceback (including the lines of code and the traceback stack)?
Possibly it is because somewhere we assume it is a Batch when it is actually a dict.

In addition, any class that inherits from BasePolicy comes with a forward() function that accepts a Batch as its input parameter. The documentation is quite lacking on how to call it

This will be updated: forward accepts a batch plus a state, along with some other algorithm-specific arguments. The first dimension of every variable in the batch should be equal to the batch size.

but I assume it should be called as follows

action = policy.forward(Batch(obs=obs, info=None))

The first dimension of obs and the other variables should be the batch size.
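
As a rough sketch (not from the docs, just to illustrate the signature), a call with an explicit state argument could look like this, where obs_batch is a placeholder for an array whose first dimension is the batch size:

from tianshou.data import Batch

batch = Batch(obs=obs_batch, info={})        # obs_batch shape: (batch_size, ...)
result = policy.forward(batch, state=None)   # state is only needed for recurrent policies
acts = result.act                            # one action per element in the batch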

May I know how to compute the action in a MultiAgentEnv? I am trying to run it as an AI bot that plays against a human player. It might look something like this:

while True:
    # if player, get input action
    # if bot, get an action from policy

    obs, reward, done, info = env.step(action)
    env.render()
    if done:
        break

to be further discussed

wfng92 (Author) commented Sep 9, 2020

Could you please provide the detailed traceback (including the lines of code and the traceback stack)?
Possibly it is because somewhere we assume it is a Batch when it is actually a dict.

I duplicated the watch function inside tic_tac_toe.py and modified it as follows:

env = TicTacToeEnv(args.board_size, args.win_size)
obs = env.reset() # added for reference to obs object

# policy.eval()
# policy.policies[args.agent_id - 1].set_eps(args.eps_test)
# collector = Collector(policy, env)
# result = collector.collect(n_episode=1, render=args.render)
# print(f'Final reward: {result["rew"]}, length: {result["len"]}')

policy, optim = get_agents(args, agent_learn=agent_learn, agent_opponent=agent_opponent)

print(obs)
action = policy(obs)

The output of obs is as follows:

{'agent_id': 1, 'obs': array([[0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0]]), 'mask': array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True])}

The stacktrace is as follows:

Traceback (most recent call last):
  File "test_tic_tac_toe.py", line 23, in <module>
    test_tic_tac_toe(get_args())
  File "test_tic_tac_toe.py", line 9, in test_tic_tac_toe
    watch2(args)
  File "C:\Users\...\tic_tac_toe.py", line 188, in watch2
    action = policy(obs)
  File "C:\Users\...\lib\site-packages\torch\nn\modules\module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "C:\Users\...\lib\site-packages\tianshou\policy\multiagent\mapolicy.py", line 90, in forward
    agent_index = np.nonzero(batch.obs.agent_id == policy.agent_id)[0]
AttributeError: 'dict' object has no attribute 'obs'

Further testing with the forward function instead of just policy(obs)

# action = policy(obs)
action = policy.forward(Batch(obs=obs,info=None))

yields the following error:

Traceback (most recent call last):
  File "test_tic_tac_toe.py", line 23, in <module>
    test_tic_tac_toe(get_args())
  File "test_tic_tac_toe.py", line 9, in test_tic_tac_toe
    watch2(args)
  File "C:\Users\...\tic_tac_toe.py", line 190, in watch2
    action = policy.forward(Batch(obs=obs,info=None))
  File "C:\Users\...\lib\site-packages\tianshou\policy\multiagent\mapolicy.py", line 95, in forward
    tmp_batch = batch[agent_index]
  File "C:\Users\...\lib\site-packages\tianshou\data\batch.py", line 216, in __getitem__
    b.__dict__[k] = v[index]
  File "C:\Users\...\lib\site-packages\tianshou\data\batch.py", line 216, in __getitem__
    b.__dict__[k] = v[index]
IndexError: too many indices for array: array is 0-dimensional, but 1 were indexed

forward accepts a batch plus a state, along with some other algorithm-specific arguments. The first dimension of every variable in the batch should be equal to the batch size.

Thanks for the information. It would be great if there were examples of how to create a Batch plus state and pass them as input parameters to the forward function.

to be further discussed

Thanks a lot. Looking forward to it.

Trinkle23897 (Collaborator) commented Sep 9, 2020

The output of obs is as follows:

{'agent_id': 1, 'obs': array([[0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0]]), 'mask': array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True])}

So the working version is:

In [4]: Batch(obs=[obs])
Out[4]: 
Batch(
    obs: Batch(
             obs: array([[[0, 0, 0, 0, 0, 0],
                          [0, 0, 0, 0, 0, 0],
                          [0, 0, 0, 0, 0, 0],
                          [0, 0, 0, 0, 0, 0],
                          [0, 0, 0, 0, 0, 0],
                          [0, 0, 0, 0, 0, 0]]]),
             agent_id: array([1]),
             mask: array([[ True,  True,  True,  True,  True,  True,  True,  True,  True,
                            True,  True,  True,  True,  True,  True,  True,  True,  True,
                            True,  True,  True,  True,  True,  True,  True,  True,  True,
                            True,  True,  True,  True,  True,  True,  True,  True,  True]]),
         ),
)

The first dimension of every variable in the batch should be equal to the batch size.

Here, [obs] creates dimension 0 as the batch dimension; without it there is no batch size. You can try policy(Batch(obs=[obs])).
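
For the bot-vs-human loop you sketched, a rough, untested version could look like the following (human_input and BOT_AGENT_ID are placeholders for reading the human's move and identifying which agent the trained policy controls):

from tianshou.data import Batch

obs = env.reset()
done = False
while not done:
    if obs['agent_id'] == BOT_AGENT_ID:
        result = policy(Batch(obs=[obs]))   # wrap obs in a list so dim 0 is the batch size
        action = result.act[0]              # take the single action out of the batch
    else:
        action = human_input(obs['mask'])   # placeholder: ask the human for a legal move
    obs, reward, done, info = env.step(action)
    env.render()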

wfng92 (Author) commented Sep 9, 2020

Thanks a lot. It is working properly now. Looking forward to more features from this cool package~

wfng92 closed this as completed on Sep 9, 2020
Trinkle23897 linked a pull request on Sep 10, 2020 that will close this issue
