Fix hardcoded action dim in pi0 pytorch model #855

Merged
kvablack merged 1 commit into Physical-Intelligence:main from abhaybd:fix-pytorch-pi0 on Mar 19, 2026

@abhaybd commented Jan 21, 2026

The JAX version of the pi0 modeling code correctly uses `config.action_dim` for the linear layer shapes. The corresponding PyTorch code, however, hardcodes this dimension as 32.

This works fine for pi0/pi05, since both use an action dim of 32. But if a user wants to, e.g., train a PaliGemma model with a different action dimension, the PyTorch code will raise an error. This PR fixes the bug by reading the action dimension from the config.
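To illustrate the failure mode, here is a minimal, dependency-free sketch. It is not the actual openpi code: `Config`, `width`, `build_action_in_proj`, and the shape-only `Linear` stand-in for `torch.nn.Linear` are all hypothetical names chosen for this example. It only shows why a hardcoded input size of 32 breaks as soon as `action_dim` differs.

```python
from dataclasses import dataclass


@dataclass
class Config:
    # Hypothetical config for illustration; pi0/pi05 use action_dim = 32.
    action_dim: int = 32
    width: int = 1024  # assumed hidden width, not taken from the PR


class Linear:
    """Shape-only stand-in for torch.nn.Linear(in_features, out_features)."""

    def __init__(self, in_features: int, out_features: int):
        self.in_features = in_features
        self.out_features = out_features

    def __call__(self, x: list) -> list:
        # A real linear layer rejects inputs whose last dim != in_features.
        if len(x) != self.in_features:
            raise ValueError(
                f"expected {self.in_features} features, got {len(x)}"
            )
        return [0.0] * self.out_features


def build_action_in_proj(config: Config, hardcoded: bool = False) -> Linear:
    # Buggy variant: input size fixed at 32.
    # Fixed variant: input size read from the config.
    in_features = 32 if hardcoded else config.action_dim
    return Linear(in_features, config.width)


config = Config(action_dim=7)        # e.g. a 7-dimensional action space
actions = [0.0] * config.action_dim  # one action vector

fixed = build_action_in_proj(config)
print(len(fixed(actions)))           # 1024

try:
    build_action_in_proj(config, hardcoded=True)(actions)
except ValueError as err:
    print("hardcoded layer fails:", err)
```

With `action_dim=32` both variants behave identically, which is presumably why the hardcoded value went unnoticed for pi0/pi05.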

@jimmyt857 removed their request for review on January 21, 2026 23:28

@abhaybd commented Jan 22, 2026

Fixes #714

@kvablack requested review from kvablack and removed request for uzhilinsky on March 19, 2026 20:51
@kvablack merged commit 37ea668 into Physical-Intelligence:main on Mar 19, 2026
2 checks passed

2 participants