Fix K-FAC covariance shapes when include_bias=True#294
Conversation
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
| from bergson.utils.utils import get_device | ||
|
|
||
|
|
||
| def shard_bounds(dim: int, rank: int, world_size: int) -> tuple[int, int]: |
There was a problem hiding this comment.
any reason we have this basically twice?
There was a problem hiding this comment.
If you mean the inner shard function in the class Claude suggested this pattern, I guess to make it more clear which variables are constant within the class and which are defined in the function call flow.
I'm currently not sure how much to "desloppify" claude code to bring it in line with my coding style, like I imagine at least some of its decisions are because that's how the Anthropic engineers prefer it and maybe they know something I don't? e.g. I'm coming around to using underscore prefixes for private methods in line with standard python recommendations, even though it's ugly, because it does seem like it makes the code a bit more clear. I do think its code is incredibly bloated and has too many layers of indirection so also happy to get rid of this one in a follow up if you have a strong take.
|
I am really happy to see the uneven shard/dim issue fixed. This looks great and we can merge it. |
Closes #277.
bergson buildwithinclude_bias=Truestores per-layer gradients of shape[O, I+1](the bias gradient is appended as an extra "activation" column), but the K-FAC covariance path computedA = aᵀaon the raw activation, givingA: [I, I].apply_hessian's per-layerview(-1, O, I)then failed on the flatN·O·(I+1)gradients:reproduce with
bergson ekfac --include_bias true --method kfac --model EleutherAI/pythia-14m.Fix
collect_hessiansnow passesGradientProcessor(include_bias=index_cfg.include_bias)to the Hessian collectors — previously thecollect_biasflags intarget_infowere alwaysFalseduring Hessian fitting regardless of the build config._init_covariance_dictsizes the activation covariance[I+1, I+1]whencollect_bias.CovarianceCollector(kfac),TraceCovarianceCollector(tkfac), andShampooCollectoraugment the activation with a ones column, matching the build-time gradient layout.LambdaCollectoraugments the activation before rotating byeigen_a(now[I+1, I+1]), so EKFAC eigenvalue corrections come out[O, I+1].shard_bounds(rank 0 absorbs the remainder) (in this PR because with the bias column,I+1is almost never divisible by the world size.)step2_fit_hessiantiming bug (approximate_hessiansran outside its_timedblock, always reporting 0.0s).