Fix M matrix floating point issues #203
Conversation
Is there a not-terribly-expensive test or two we can add from this experience? I'm thinking things like:
I'll implement these tests together with Nix ~tomorrow.
Test fails on CPU -- giving up, only running the non-ablation tests on GPU for now.
…curves broken on CPU
Okay.
While all tests pass on GPU, the float32 ablations break on GPU. This is probably due to slightly different C matrices rather than the ablation_config being in float32, but I will test.
These ablation curve errors are (unlike the …
The CPU tests do pass if I calculate and accumulate the Lambda matrices in float64. I had changed …
Now I tested:
Okay, so it looks like the einsum on CPU was the issue.
I manually confirmed that the tests pass with different seeds (tried 3 different seeds on GPU).
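A possible way to fold that seed check into the suite (a sketch only; the seed values, test name, and assertion are placeholders rather than the repo's actual tests):

import pytest
import torch


@pytest.mark.parametrize("seed", [0, 1, 2])  # placeholder seeds
def test_passes_for_each_seed(seed):
    torch.manual_seed(seed)
    x = torch.randn(16, 16, dtype=torch.float64)
    # Stand-in assertion; the real test would build the RIB graph for this
    # seed and compare float32 against float64 outputs.
    assert torch.allclose(x.to(torch.float32).to(torch.float64), x, atol=1e-6)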
Also, the test which takes ~2 min on a CPU devbox takes ~24 min on the CI runner.
Seems pretty sad to have CI go from ~7 to ~30 mins! Is it possible to adjust the config to have it run faster? Or do we just need to --skip-ci this test until we have GPU runners (#170)?
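For what it's worth, a rough sketch of the skip-on-CI idea (the environment variable check and marker below are assumptions about how one could wire it up, not how the repo's --skip-ci flag actually works):

import os

import pytest

# GitHub Actions sets CI=true; skip the expensive precision tests there until
# GPU runners are available (#170).
skip_on_ci = pytest.mark.skipif(
    os.environ.get("CI") == "true",
    reason="Takes ~24 min on the CPU-only CI runner; run on a devbox instead.",
)


@skip_on_ci
def test_float_precision():
    ...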
Tests continue to fail on the CI while they work on CPU and GPU devboxes. The float32 ablation curve is (a) not flat and (b) different from the float64 one.
The output was just wrong; the actual slowest tests were the new ones, as expected. Skipping those on CI from now on. Edit: For posterity's sake, the durations of both new tests:
Looks like we're not the only ones with dtype-related issues on GitHub CI: https://opensourcemechanistic.slack.com/archives/C04SRRE96UV/p1700313374803079
danbraunai-apollo left a comment
Approved conditional on minor comments. Feel free to merge after addressing.
rib/interaction_algos.py
Outdated
Lambda_abs_sqrt_trunc = Lambda_abs_sqrt_trunc
Lambda_abs_sqrt_trunc_pinv = Lambda_abs_sqrt_trunc_pinv
What is going on here? Remove
Oops, leftover from debugging; deleting these lines.
import torch
import yaml

from experiments.lm_ablations.run_lm_ablations import Config as AblationConfig
I'm surprised that running pytest in different directories works with this, without doing:
import sys
from pathlib import Path

# Append the root directory to sys.path
ROOT_DIR = Path(__file__).parent.parent.resolve()
sys.path.append(str(ROOT_DIR))
like I did for other test files. I guess I don't need to do that because pytest recognises the root directory as the one which has the tests dir?
There is also no such line in test_hooks.py or test_data_accumulator.py -- leaving it out. As discussed, Nix will remove all such lines from our test files
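For the record, one plausible explanation for why the bare import works (an assumption about the setup, not something verified in this repo): under pytest's default "prepend" import mode, the directory containing a collected conftest.py is inserted at the front of sys.path, so a conftest.py at the repository root, even an empty one, makes top-level packages such as experiments and rib importable without any per-file sys.path.append:

# conftest.py at the repository root (it can be empty).
# Under pytest's default "prepend" import mode, collecting this file inserts
# its directory, i.e. the repo root, at the front of sys.path, which is what
# lets `from experiments.lm_ablations.run_lm_ablations import Config` resolve
# regardless of which directory pytest is invoked from.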
rib_config["dtype"] = dtype
rib_config["exp_name"] = exp_name
if not torch.cuda.is_available():
    # Try to reduce memory usage for CI
Have you had an OOM in the CI before? Seems like this might be overkill unless you have noticed it
tests/test_float_precision.py
Outdated
    :, :n_max
]
assert torch.allclose(
    float32_C.to(torch.float64), float64_C, rtol=0.5, atol=0.5 * float64_C.max()
nit: rtol=0.5, atol=0.5 * float64_C.max() is an interesting pattern. Can't you just use rtol or rtol + atol to achieve the same thing?
Fair! I agree that atol=0.5 * float64_C.max() is a terrible test (because it's a very weak assertion). But our float32 and float64 matrices are so different that even rtol=0.5 doesn't pass.
I removed the rtol=0.5 part, but the atol=0.5 * float64_C.max() has to stay for now
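For reference, torch.allclose checks |input - other| <= atol + rtol * |other| elementwise, so atol=0.5 * float64_C.max() tolerates absolute errors up to half the largest entry of the reference matrix, which constrains the small entries far less than a relative tolerance would. A small illustration with made-up values:

import torch

a = torch.tensor([1e-3, 1.0, 100.0], dtype=torch.float64)
b = torch.tensor([3e-3, 1.2, 100.5], dtype=torch.float64)

# Relative tolerance alone: 1e-3 vs 3e-3 differ by well over 50%, so this fails.
print(torch.allclose(a, b, rtol=0.5))                   # False

# Absolute tolerance scaled to the largest entry: errors up to ~50 are accepted,
# so almost any two matrices of similar overall magnitude would "match".
print(torch.allclose(a, b, atol=0.5 * a.max().item()))  # True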
This reverts commit e07827e.
Fixes two of our floating point issues, most of the time.
1. It turns out that M and M dash need to be kept in float64 at all times (until the eigendecomposition). Rounding them by even momentarily converting either to float32 breaks the ablation curves.
2. It turns out that the einsum for Lambda_dash needs to be run in float64 (see the sketch after the test list below).

Implemented tests that:
a) Compare rib build outputs between float32 and float64:
   a1) gram_matrices match well
   a2) the first columns of the eigenvectors match alright
   a3) Interaction_rotations approximately match on GPU with batch size > 1, for the first columns
b) Compare ablation results:
   b1) Matching ablation results between float32 and float64 for mlp layers
   b2) Flat ablation curves for mlp layers
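A minimal sketch of the dtype pattern described in 1. and 2. above (the shapes, variable names, and einsum indices are illustrative, not the actual code in rib/interaction_algos.py):

import torch

torch.manual_seed(0)

# Illustrative stand-ins (shapes made up).
acts = torch.randn(512, 64, dtype=torch.float32)  # activations
C = torch.randn(64, 64, dtype=torch.float32)      # rotation matrix

# 1. Build and keep M in float64 all the way to the eigendecomposition;
#    even a momentary cast down to float32 was enough to break the curves.
acts64 = acts.to(torch.float64)
M = acts64.T @ acts64
eigenvalues, eigenvectors = torch.linalg.eigh(M)

# 2. Run the Lambda einsum in float64 as well, casting back to the working
#    dtype only after the accumulation.
Lambda = torch.einsum("ij,jk->ik", eigenvectors.T, C.to(torch.float64))
Lambda = Lambda.to(acts.dtype)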


Description
Based on feature/force_overwrite_output; merge that before this PR. (Merged)
Hopefully fixes the remaining floating point issues. Fixes two of our floating point issues most of the time.
Tested
Implemented tests that
Future work:
Also manually tested by observing that ablation curves stay flat until 128 with these configs:
Build
Ablate
Result:

The results were also tested on a larger dataset of return_set_frac: 0.1.
PS: My debugging run script