Add C51 algorithm #266

Merged
merged 26 commits into from Jan 6, 2021

@shengxiang19
Contributor

shengxiang19 commented Dec 25, 2020

Distributional RL algorithms perform very well in Atari environments. I am going to implement a series of typical algorithms, namely C51, QR-DQN, IQN, and FQF, on top of the reinforcement learning platform Tianshou.

This is my first PR, for the C51 algorithm: https://arxiv.org/abs/1707.06887

  1. add C51 policy in tianshou/policy/modelfree/c51.py.
  2. add C51 net in tianshou/utils/net/discrete.py.
  3. add C51 Atari example in examples/atari/atari_c51.py.
  4. add C51 statement in tianshou/policy/__init__.py.
  5. add C51 test in test/discrete/test_c51.py.
  6. add C51 Atari results in examples/atari/results/c51/.
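
For readers unfamiliar with C51, the step that a policy like the one in item 1 has to implement is the categorical projection from the linked paper: each atom of the next-state return distribution is shifted by the Bellman update, clipped to the support, and its probability is split between the two nearest fixed atoms. The following is a minimal pure-Python sketch for a single transition. The function name and its arguments are illustrative assumptions and are not taken from this PR's code.

```python
import math


def project_distribution(next_probs, reward, done, gamma, v_min, v_max):
    """Project a Bellman-shifted categorical distribution onto fixed atoms.

    Hypothetical helper for illustration only (not the PR's implementation).
    next_probs[j] is the probability of atom j under the next-state distribution.
    """
    n_atoms = len(next_probs)
    delta_z = (v_max - v_min) / (n_atoms - 1)
    target = [0.0] * n_atoms
    for j, prob in enumerate(next_probs):
        z_j = v_min + j * delta_z
        # Bellman update of one atom, clipped to the support range.
        tz = reward if done else reward + gamma * z_j
        tz = min(v_max, max(v_min, tz))
        # Fractional index of tz on the atom grid, guarded against rounding.
        b = min(max((tz - v_min) / delta_z, 0.0), n_atoms - 1)
        lower, upper = math.floor(b), math.ceil(b)
        if lower == upper:
            # tz falls exactly on an atom: keep all of the mass there.
            target[lower] += prob
        else:
            # Split the mass linearly between the two neighbouring atoms.
            target[lower] += prob * (upper - b)
            target[upper] += prob * (b - lower)
    return target


# Example: a uniform next-state distribution over 5 atoms on [-1, 1].
projected = project_distribution([0.2] * 5, reward=0.25, done=False,
                                 gamma=0.99, v_min=-1.0, v_max=1.0)
```

The projected list is again a probability distribution over the same atoms, so it can serve as the target in a cross-entropy loss against the predicted distribution.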

Running "python3 atari_c51.py --task "PongNoFrameskip-v4" --batch-size 64" gives best_result: 20.50 ± 0.50 in epoch 9.

Running "python3 atari_c51.py --task "BreakoutNoFrameskip-v4" --n-step 1 --epoch 40" gives best_reward: 407.400000 ± 31.155096 in epoch 39.

@codecov-io

codecov-io commented Dec 25, 2020

Codecov Report

Merging #266 (d315052) into master (5d13d8a) will decrease coverage by 0.56%.
The diff coverage is 75.53%.

@@            Coverage Diff             @@
##           master     #266      +/-   ##
==========================================
- Coverage   94.54%   93.98%   -0.57%     
==========================================
  Files          41       42       +1     
  Lines        2677     2760      +83     
==========================================
+ Hits         2531     2594      +63     
- Misses        146      166      +20     
Flag        Coverage Δ
unittests   93.98% <75.53%> (-0.57%) ⬇️

Flags with carried forward coverage won't be shown.

Impacted Files                      Coverage Δ
tianshou/policy/base.py             73.26% <22.22%> (-3.03%) ⬇️
tianshou/utils/net/discrete.py      87.50% <30.00%> (-12.50%) ⬇️
tianshou/utils/net/common.py        97.33% <80.00%> (-2.67%) ⬇️
tianshou/policy/modelfree/c51.py    89.06% <89.06%> (ø)
tianshou/policy/__init__.py         100.00% <100.00%> (ø)
tianshou/env/worker/subproc.py      91.15% <0.00%> (-0.06%) ⬇️
tianshou/data/collector.py          95.97% <0.00%> (-0.03%) ⬇️
tianshou/data/buffer.py             99.03% <0.00%> (-0.01%) ⬇️

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update 5d13d8a...d315052.

@Trinkle23897
Collaborator

Nice job! Could you please also modify the README.md and docs/index.rst (add C51 description)?

@Trinkle23897
Collaborator

That's OK, I can help you fix the PEP8 issues.

@shengxiang19
Contributor Author

Thank you very much. I will modify the README.md and docs/index.rst in my next PR.

@Trinkle23897
Collaborator

> Thank you very much. I will modify the README.md and docs/index.rst in my next PR.

Just in this PR is okay.

@shengxiang19
Contributor Author

> > Thank you very much. I will modify the README.md and docs/index.rst in my next PR.
>
> Just in this PR is okay.

I'm not very familiar with GitHub, and I could not work out how to do it in this PR.

@Trinkle23897
Collaborator

Okay, that's fine :) I'll take a look later on.

@shengxiang19
Contributor Author

All checks have passed :) It was a tough journey.

Collaborator

@Trinkle23897 Trinkle23897 left a comment

Also, could you please add a test script for CartPole-v0 under test/discrete/?

tianshou/policy/modelfree/c51.py (outdated, resolved)
@Trinkle23897
Collaborator

> Will implement a test/discrete/test_c51.py in the next PR.

I think you can directly add this file. No need to make a separate PR.

@shengxiang19
Contributor Author

> > Will implement a test/discrete/test_c51.py in the next PR.
>
> I think you can directly add this file. No need to make a separate PR.

I hope to combine the results of C51 in a variety of Atari games with it in the future PR.

@Trinkle23897
Collaborator

> I hope to combine the results of C51 in a variety of Atari games with it in the future PR.

Cool, and I think you can add these results here so that this PR can be a complete version of C51 implementation :)

@shengxiang19
Contributor Author

> > I hope to combine the results of C51 in a variety of Atari games with it in the future PR.
>
> Cool, and I think you can add these results here so that this PR can be a complete version of C51 implementation :)

I'm not sure how to add new files under the current PR. So, adding a new PR is more convenient for me. In addition, I'm not quite sure when I can finish this work.

@Trinkle23897
Collaborator

> I'm not sure how to add new files under the current PR. So, adding a new PR is more convenient for me.

Just add the file in shengxiang19/C51 branch instead of here. See https://stackoverflow.com/questions/10147445/github-adding-commits-to-existing-pull-request

> In addition, I'm not quite sure when I can finish this work.

I can wait for you.

@shengxiang19
Contributor Author

> > I'm not sure how to add new files under the current PR. So, adding a new PR is more convenient for me.
>
> Just add the file in shengxiang19/C51 branch instead of here. See https://stackoverflow.com/questions/10147445/github-adding-commits-to-existing-pull-request
>
> > In addition, I'm not quite sure when I can finish this work.
>
> I can wait for you.

Thank you. I can try it later.

@shengxiang19
Contributor Author

> > Will implement a test/discrete/test_c51.py in the next PR.
>
> I think you can directly add this file. No need to make a separate PR.

I have added a test_c51 for CartPole-v0 under test/discrete/. I hope you can help me check it.

@shengxiang19
Contributor Author

I have added the results of C51 in three typical Atari environments. My current plan for C51 is complete.

Collaborator

@Trinkle23897 Trinkle23897 left a comment

I'll optimize the n_step code this week (in this PR). Thanks for your great work!

tianshou/utils/net/common.py (outdated, resolved)
Collaborator

@Trinkle23897 Trinkle23897 left a comment

I'm running QbertNoFrameskip-v4 for evaluation. Now in epoch 31 it reaches 14047.
Please double-check my implementation.

tianshou/policy/modelfree/c51.py (resolved)
@Trinkle23897 Trinkle23897 merged commit c6f2648 into thu-ml:master Jan 6, 2021
"""
model = getattr(self, model)
obs = batch[input]
obs_ = obs.obs if hasattr(obs, "obs") else obs
Collaborator

Why could hasattr(obs, "obs") be false?

Collaborator

These three are the same as in the existing DQNPolicy. I guess we can make a separate PR to enhance these things :)

Collaborator

Yes I noticed that :)
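
As an illustration of the check discussed in this thread, here is a small sketch of one situation in which hasattr(obs, "obs") is false. It assumes that an environment may return either a plain observation or a structured one carrying both `obs` and `mask` fields, which is what the two hasattr checks in the quoted code suggest. `unwrap_obs` is a hypothetical helper, and SimpleNamespace merely stands in for whatever container the library actually uses.

```python
from types import SimpleNamespace


def unwrap_obs(obs):
    """Return the raw observation whether or not it is wrapped.

    Hypothetical helper mirroring the quoted line
    `obs_ = obs.obs if hasattr(obs, "obs") else obs`.
    """
    return obs.obs if hasattr(obs, "obs") else obs


# A plain vector observation has no `.obs` attribute, so the check is false.
plain = [0.1, 0.2, 0.3, 0.4]

# A structured observation (assumed layout) bundles the data with an action mask.
wrapped = SimpleNamespace(obs=[0.1, 0.2], mask=[True, False])

raw_from_plain = unwrap_obs(plain)      # the same list object, unchanged
raw_from_wrapped = unwrap_obs(wrapped)  # the inner `obs` field
```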

dist, h = model(obs_, state=state, info=batch.info)
q = (dist * self.support).sum(2)
act: np.ndarray = to_numpy(q.max(dim=1)[1])
if hasattr(obs, "mask"):
Collaborator

I don't much like this approach, but right now I have no idea how to avoid it. Maybe we could add a masked_array method to the Batch class to offer something similar to numpy's masked arrays. Internally it would use the same mechanism, but it would be hidden in Batch, which is much better in my opinion.
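
To make the quoted lines concrete for a single state, the sketch below computes the expected Q-value of each action as the probability-weighted sum over the support atoms (the per-sample analogue of `(dist * self.support).sum(2)`) and then picks the greedy action among those a mask allows. Both function names are illustrative assumptions. This is a standalone helper in plain Python, not the Batch method proposed in the comment above.

```python
def expected_q(dist, support):
    """Expected return per action.

    dist[a][i] is the probability of atom i for action a;
    support[i] is the return value of atom i.
    """
    return [sum(p * z for p, z in zip(probs, support)) for probs in dist]


def greedy_action(q_values, mask=None):
    """Index of the best action, restricted to actions whose mask entry is true."""
    if mask is None:
        mask = [True] * len(q_values)
    legal = [i for i, allowed in enumerate(mask) if allowed]
    return max(legal, key=lambda i: q_values[i])


support = [-1.0, 0.0, 1.0]
dist = [
    [0.1, 0.1, 0.8],  # action 0: mass concentrated on the high return
    [0.8, 0.1, 0.1],  # action 1: mass concentrated on the low return
]
q = expected_q(dist, support)
unmasked_choice = greedy_action(q)
masked_choice = greedy_action(q, mask=[False, True])  # action 0 is forbidden
```

Filtering the candidate indices avoids overwriting Q-values with a sentinel such as negative infinity, which keeps the original values intact for logging or loss computation.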

batch.weight = cross_entropy.detach() # prio-buffer
loss.backward()
self.optim.step()
self._cnt += 1
Collaborator

I recommend a more explicit variable name than _cnt.
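
One possible reading of this suggestion is sketched below. It assumes that `_cnt` counts gradient updates so that a target network can be refreshed at a fixed interval, as is common in DQN-style policies. The excerpt above does not show how the counter is used, so the class, its attribute names, and the `target_update_freq` parameter are all hypothetical.

```python
class TargetSyncCounter:
    """Hypothetical helper: counts updates and reports when to sync a target net."""

    def __init__(self, target_update_freq):
        # Assumed meaning: number of updates between target-network refreshes.
        self.target_update_freq = target_update_freq
        # Descriptive replacement for a bare `_cnt`.
        self.num_updates = 0

    def step(self):
        """Record one update; return True when the target network is due for a sync."""
        self.num_updates += 1
        return (self.target_update_freq > 0
                and self.num_updates % self.target_update_freq == 0)


counter = TargetSyncCounter(target_update_freq=3)
sync_flags = [counter.step() for _ in range(6)]
```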

4 participants