Skip to content

Fix K-FAC covariance shapes when include_bias=True#294

Merged
luciaquirke merged 1 commit into
mainfrom
fix/kfac-include-bias
Jun 9, 2026
Merged

Fix K-FAC covariance shapes when include_bias=True#294
luciaquirke merged 1 commit into
mainfrom
fix/kfac-include-bias

Conversation

@luciaquirke

@luciaquirke luciaquirke commented Jun 5, 2026

Copy link
Copy Markdown
Collaborator

Closes #277.

bergson build with include_bias=True stores per-layer gradients of shape [O, I+1] (the bias gradient is appended as an extra "activation" column), but the K-FAC covariance path computed A = aᵀa on the raw activation, giving A: [I, I]. apply_hessian's per-layer view(-1, O, I) then failed on the flat N·O·(I+1) gradients:

RuntimeError: shape '[-1, 128, 128]' is invalid for input of size 16512

reproduce with bergson ekfac --include_bias true --method kfac --model EleutherAI/pythia-14m.

Fix

  • collect_hessians now passes GradientProcessor(include_bias=index_cfg.include_bias) to the Hessian collectors — previously the collect_bias flags in target_info were always False during Hessian fitting regardless of the build config.
  • _init_covariance_dict sizes the activation covariance [I+1, I+1] when collect_bias.
  • CovarianceCollector (kfac), TraceCovarianceCollector (tkfac), and ShampooCollector augment the activation with a ones column, matching the build-time gradient layout.
  • LambdaCollector augments the activation before rotating by eigen_a (now [I+1, I+1]), so EKFAC eigenvalue corrections come out [O, I+1].
  • Sharded computation now supports unevenly divisible dimensions via shard_bounds (rank 0 absorbs the remainder) (in this PR because with the bias column, I+1 is almost never divisible by the world size.)
  • Fixed step2_fit_hessian timing bug (approximate_hessians ran outside its _timed block, always reporting 0.0s).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@luciaquirke luciaquirke requested a review from LouisYRYJ June 5, 2026 05:13
from bergson.utils.utils import get_device


def shard_bounds(dim: int, rank: int, world_size: int) -> tuple[int, int]:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

any reason we have this basically twice?

@luciaquirke luciaquirke Jun 9, 2026

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you mean the inner shard function in the class Claude suggested this pattern, I guess to make it more clear which variables are constant within the class and which are defined in the function call flow.
I'm currently not sure how much to "desloppify" claude code to bring it in line with my coding style, like I imagine at least some of its decisions are because that's how the Anthropic engineers prefer it and maybe they know something I don't? e.g. I'm coming around to using underscore prefixes for private methods in line with standard python recommendations, even though it's ugly, because it does seem like it makes the code a bit more clear. I do think its code is incredibly bloated and has too many layers of indirection so also happy to get rid of this one in a follow up if you have a strong take.

@LouisYRYJ

Copy link
Copy Markdown
Contributor

I am really happy to see the uneven shard/dim issue fixed. This looks great and we can merge it.

@luciaquirke luciaquirke merged commit 1269029 into main Jun 9, 2026
11 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

K-FAC Hessian + include_bias=True produces incompatible gradient/covariance shapes

2 participants