How to use policy.forward to compute an action from given observation for TicTacToe example? #208
Comments
Could you please provide the detailed traceback (including the lines of code and traceback stack)?
The documentation will be updated accordingly: the first dimension of `obs` (and other variables) should be the batch size.
to be further discussed.
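The batch-size convention can be checked with plain numpy: a single observation must be promoted to a batch of size 1 before it is passed to a policy. A minimal sketch (the 6x6 board shape mirrors the observation printed later in this thread):

```python
import numpy as np

# a single 6x6 board observation (all empty cells), as in this thread
obs = np.zeros((6, 6), dtype=np.int64)

# policies expect a leading batch dimension, so one observation
# becomes a batch of size 1 before calling forward
batched = np.expand_dims(obs, axis=0)
print(batched.shape)  # (1, 6, 6)
```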
I duplicated the `watch` function:

```python
env = TicTacToeEnv(args.board_size, args.win_size)
obs = env.reset()  # added for reference to the obs object
# policy.eval()
# policy.policies[args.agent_id - 1].set_eps(args.eps_test)
# collector = Collector(policy, env)
# result = collector.collect(n_episode=1, render=args.render)
# print(f'Final reward: {result["rew"]}, length: {result["len"]}')
policy, optim = get_agents(args, agent_learn=agent_learn, agent_opponent=agent_opponent)
print(obs)
action = policy(obs)
```

The output of `obs` is as follows:

```
{'agent_id': 1, 'obs': array([[0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0]]), 'mask': array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True])}
```

The stack trace is as follows:

```
Traceback (most recent call last):
  File "test_tic_tac_toe.py", line 23, in <module>
    test_tic_tac_toe(get_args())
  File "test_tic_tac_toe.py", line 9, in test_tic_tac_toe
    watch2(args)
  File "C:\Users\...\tic_tac_toe.py", line 188, in watch2
    action = policy(obs)
  File "C:\Users\...\lib\site-packages\torch\nn\modules\module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "C:\Users\...\lib\site-packages\tianshou\policy\multiagent\mapolicy.py", line 90, in forward
    agent_index = np.nonzero(batch.obs.agent_id == policy.agent_id)[0]
AttributeError: 'dict' object has no attribute 'obs'
```

A further test, using the `forward` function instead of just `policy(obs)`:

```python
# action = policy(obs)
action = policy.forward(Batch(obs=obs, info=None))
```

yields the following error:

```
Traceback (most recent call last):
  File "test_tic_tac_toe.py", line 23, in <module>
    test_tic_tac_toe(get_args())
  File "test_tic_tac_toe.py", line 9, in test_tic_tac_toe
    watch2(args)
  File "C:\Users\...\tic_tac_toe.py", line 190, in watch2
    action = policy.forward(Batch(obs=obs, info=None))
  File "C:\Users\...\lib\site-packages\tianshou\policy\multiagent\mapolicy.py", line 95, in forward
    tmp_batch = batch[agent_index]
  File "C:\Users\...\lib\site-packages\tianshou\data\batch.py", line 216, in __getitem__
    b.__dict__[k] = v[index]
  File "C:\Users\...\lib\site-packages\tianshou\data\batch.py", line 216, in __getitem__
    b.__dict__[k] = v[index]
IndexError: too many indices for array: array is 0-dimensional, but 1 were indexed
```
Thanks for the information. It would be great if there were examples of how to create a `Batch` plus state and pass them as input parameters to the `forward` function.
Thanks a lot. Looking forward to it.
So the working version is:

```
In [4]: Batch(obs=[obs])
Out[4]:
Batch(
    obs: Batch(
        obs: array([[[0, 0, 0, 0, 0, 0],
                     [0, 0, 0, 0, 0, 0],
                     [0, 0, 0, 0, 0, 0],
                     [0, 0, 0, 0, 0, 0],
                     [0, 0, 0, 0, 0, 0],
                     [0, 0, 0, 0, 0, 0]]]),
        agent_id: array([1]),
        mask: array([[ True,  True,  True,  True,  True,  True,  True,  True,  True,
                       True,  True,  True,  True,  True,  True,  True,  True,  True,
                       True,  True,  True,  True,  True,  True,  True,  True,  True,
                       True,  True,  True,  True,  True,  True,  True,  True,  True]]),
    ),
)
```
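Why `Batch(obs=[obs])` works can be illustrated without Tianshou at all: wrapping the observation dict in a one-element list and stacking gives every field a leading batch dimension of 1. A plain-numpy sketch (the field shapes mirror the 6x6 board above):

```python
import numpy as np

# the single observation returned by env.reset(), as a plain dict
obs = {
    'agent_id': 1,
    'obs': np.zeros((6, 6), dtype=np.int64),
    'mask': np.ones(36, dtype=bool),
}

# wrapping each field in a one-element list and converting to an
# array is what produces the leading batch dimension, mirroring
# the effect of Batch(obs=[obs])
batched = {k: np.asarray([v]) for k, v in obs.items()}
print(batched['obs'].shape)   # (1, 6, 6)
print(batched['agent_id'])    # [1]
print(batched['mask'].shape)  # (1, 36)
```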
Thanks a lot. It is working properly now. Looking forward to more features from this cool package~
Tianshou version: 0.2.6, PyTorch version: 1.6.0+cpu, Python version: 3.7.7
Hi, I am having an issue with computing a single action from a given observation. I have trained a policy based on the multi-agent TicTacToe example (#121). According to the current documentation, the code for computing the action in a MultiAgentEnv is as follows:
I have attempted to run it inside `tic_tac_toe.py` by reusing the `get_agents` function as my policy and got the following error.

In addition, any class that inherits from `BasePolicy` comes with a `forward()` function which accepts a `Batch` input parameter. The documentation is quite lacking on how to call it, but I assume it should be called as follows. I got the following error when running the code above.
Subsequently, I tested it on CartPole DQN using the following code
I got the following error, which indicates a size mismatch.
Hence, I reshaped the array as follows and successfully got the action from the DQN policy.
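For reference, the reshape in the DQN case amounts to adding a leading batch axis. A minimal numpy sketch (the CartPole state values here are made up for illustration):

```python
import numpy as np

# a single CartPole observation:
# [position, velocity, angle, angular velocity]
obs = np.array([0.01, -0.02, 0.03, 0.04], dtype=np.float32)

# the network expects (batch_size, obs_dim), so reshape to a batch of 1
batch_obs = obs.reshape(1, -1)
print(batch_obs.shape)  # (1, 4)
```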
May I know how to compute the action in a MultiAgentEnv? I am trying to run it as an AI bot which plays against a human player. It might look something like this:
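As a rough illustration of such a loop, here is a self-contained sketch. `ScriptedPolicy`, `play`, and the turn-toggling logic are all hypothetical names invented for this example; a scripted move list stands in for interactive human input so the sketch stays runnable, and win detection is deliberately omitted. A real bot would replace `ScriptedPolicy.act` with a call to the trained policy on `Batch(obs=[obs])`.

```python
import numpy as np

class ScriptedPolicy:
    """Stand-in for the trained policy: always plays the first legal cell."""
    def act(self, obs):
        return int(np.flatnonzero(obs['mask'])[0])

def play(policy, board_size=3, bot_marker=1, human_moves=()):
    """Alternate bot and human turns until the board is full."""
    board = np.zeros(board_size * board_size, dtype=np.int64)
    human_iter = iter(human_moves)
    player = 1  # the bot (marker 1) moves first
    for _ in range(board.size):
        obs = {
            'agent_id': player,
            'obs': board.reshape(board_size, board_size),
            'mask': board == 0,  # legal moves are the empty cells
        }
        move = policy.act(obs) if player == bot_marker else next(human_iter)
        board[move] = player
        player = 3 - player  # toggle between players 1 and 2
    return board

final = play(ScriptedPolicy(), human_moves=[1, 3, 5, 7])
print(final.reshape(3, 3))
```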
An example or an update to the documentation would be greatly appreciated. Thanks a lot~