
Conversation

@stefan-apollo (Contributor) commented Nov 17, 2023

Description

Based on feature/force_overwrite_output; merge that branch before this PR. (Merged.)

Hopefully fixes the remaining floating point issues. In practice, it fixes two of our floating point issues most of the time.

  1. Turns out that M and M_dash need to be kept in float64 at all times (until the eigendecomposition); rounding them by even momentarily converting either to float32 breaks the ablation curves (see the sketch after this list).
  2. Turns out that the einsum for Lambda_dash needs to be run in float64.
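For illustration, a minimal sketch of the pattern behind fix 1: keep the matrix in float64 all the way through the eigendecomposition and only downcast the results afterwards. The helper name and signature here are hypothetical, not the actual rib code.

    import torch

    def eigendecompose_in_f64(M: torch.Tensor, out_dtype: torch.dtype = torch.float32):
        """Eigendecompose M in float64, then cast the results back to the working dtype."""
        # Keeping M in float64 up to this point is the crux of fix 1; downcasting it
        # even momentarily to float32 was enough to break the ablation curves.
        eigenvalues, eigenvectors = torch.linalg.eigh(M.to(torch.float64))
        return eigenvalues.to(out_dtype), eigenvectors.to(out_dtype)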

Tested

Implemented tests that (a rough sketch of the comparison pattern follows this list):

  • Compare rib build outputs between float32 and float64.
    • test_gram_matrices: the gram matrices match well
    • test_eigenvectors: the first columns match reasonably well
    • test_interaction_rotations: the first columns approximately match on GPU and with batch size > 1
  • Compare ablation results.
    • Ablation results match between float32 and float64 for mlp layers
    • Ablation curves are flat for mlp layers
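A rough sketch of the comparison pattern used in these build-output tests (function name and tolerance are illustrative, not the actual test code):

    import torch

    def assert_grams_match(grams_f32: dict, grams_f64: dict) -> None:
        """Check that gram matrices from a float32 build match the float64 build."""
        for name, gram_f64 in grams_f64.items():
            gram_f32 = grams_f32[name].to(torch.float64)
            # Tolerance is illustrative; the real test may use different atol/rtol.
            assert torch.allclose(gram_f32, gram_f64, atol=1e-4), f"Gram mismatch in {name}"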

Future work:

  • Debug why ln ablation curves break for float32
  • Debug why ablation curves break on CI
  • Get more sensible tests for RIB build outputs (the test of Cs is very weak, and there is no test of edges)

Also manually tested by observing that the ablation curves stay flat until 128 with these configs:
Build config:

exp_name: debug-pythia-14m
force_overwrite_output: true
seed: 0
tlens_pretrained: pythia-14m
tlens_model_path: null
dataset:
  source: huggingface
  name: NeelNanda/pile-10k
  tokenizer_name: EleutherAI/pythia-14m
  return_set: train  # pile-10k only has train, so we take the first 90% for building and last 10% for ablations
  return_set_frac: null
  return_set_n_samples: 10
  return_set_portion: first
node_layers:
  - ln1.0
  - mlp_out.0
  - ln2.3
  - mlp_out.3
  - ln1.5
  - mlp_out.5
  - output
batch_size: 4  #  A100 can handle 24
gram_batch_size: 20  #  A100 can handle 80
truncation_threshold: 1e-6
rotate_final_node_layer: false
n_intervals: 10
dtype: float32
calculate_edges: false
eval_type: null

Ablation config:

exp_name: debug-pythia-14m
force_overwrite_output: true
ablation_type: rib
interaction_graph_path: /mnt/ssd-apollo/stefan/rib/experiments/lm_rib_build/out/debug-pythia-14m_rib_Cs.pt
schedule:
  schedule_type: linear
  early_stopping_threshold: 1.5
  n_points: 20
  specific_points: [128, 129, 130]
dataset:
  source: huggingface
  name: NeelNanda/pile-10k
  tokenizer_name: EleutherAI/pythia-14m
  return_set: train  # pile-10k only has train, so we take the first 90% for building and last 10% for ablations
  return_set_frac: null
  return_set_n_samples: 10
  return_set_portion: first
ablation_node_layers:
  - ln1.0
  - mlp_out.0
  - ln2.3
  - mlp_out.3
  - ln1.5
  - mlp_out.5
batch_size: 30  # A100 can handle 60
dtype: float32
eval_type: ce_loss
seed: 0

Result:
[Plot: debug-pythia-14m_ce_loss_vs_ablated_vecs]
The results were also checked on a larger dataset with return_set_frac: 0.1.

PS: My debugging run script:

#!/bin/bash
set -e
python /mnt/ssd-apollo/stefan/rib/experiments/lm_rib_build/run_lm_rib_build.py /mnt/ssd-apollo/stefan/rib/experiments/lm_rib_build/fptest_pythia.yaml
python /mnt/ssd-apollo/stefan/rib/experiments/lm_ablations/run_lm_ablations.py /mnt/ssd-apollo/stefan/rib/experiments/lm_ablations/fptest_ablate_pythia.yaml
python /mnt/ssd-apollo/stefan/rib/experiments/lm_ablations/plot_lm_ablations.py /mnt/ssd-apollo/stefan/rib/experiments/lm_ablations/out/debug-pythia-14m_ablation_results.json -f

@nix-apollo (Contributor)

Is there a not-terribly-expensive test or two we can add from this experience? I'm thinking things like:

  1. running the experiment in float32 (except for M_dash) gives the same result as running with float64
  2. the ablation curves are flat in the way they should be

@stefan-apollo (Contributor Author)

I'll implement these tests together with Nix ~tomorrow

@stefan-apollo (Contributor Author)

Test fails on CPU -- giving up, only running the non-ablation tests on GPU for now.

INFO     root:run_lm_rib_build.py:304 Time to load model and dataset: 11.03
INFO     root:run_lm_rib_build.py:331 Collecting gram matrices for 1 batches.
INFO     root:run_lm_rib_build.py:342 Time to collect gram matrices: 0.93
INFO     root:run_lm_rib_build.py:348 Calculating interaction rotations (Cs).
INFO     root:run_lm_rib_build.py:365 Time to calculate Cs: 0.6 minutes
INFO     root:run_lm_rib_build.py:371 Skipping edge calculation.
INFO     root:run_lm_rib_build.py:441 Saved results to /tmp/tmpm3rc6_3k/float-precision-test-pythia-14m-float32_rib_Cs.pt
INFO     root:run_lm_rib_build.py:304 Time to load model and dataset: 2.57
INFO     root:run_lm_rib_build.py:331 Collecting gram matrices for 1 batches.
INFO     root:run_lm_rib_build.py:342 Time to collect gram matrices: 1.37
INFO     root:run_lm_rib_build.py:348 Calculating interaction rotations (Cs).
INFO     root:run_lm_rib_build.py:365 Time to calculate Cs: 1.0 minutes
INFO     root:run_lm_rib_build.py:371 Skipping edge calculation.
INFO     root:run_lm_rib_build.py:441 Saved results to /tmp/tmpm3rc6_3k/float-precision-test-pythia-14m-float64_rib_Cs.pt
============================= short test summary info =============================
FAILED tests/test_float_precision.py::test_pythia_floating_point_errors - AssertionError: 1
==================== 1 failed, 71 deselected in 122.60s (0:02:02) ====================

@stefan-apollo (Contributor Author)

Okay, pytest --runslow -k test_pythia_floating_point_errors runs on CPU now, taking ~2-3 minutes.
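For context, a --runslow option like this is usually wired up in conftest.py via the standard pytest pattern below (sketch only; the repo's actual conftest may differ):

    import pytest

    def pytest_addoption(parser):
        parser.addoption("--runslow", action="store_true", default=False,
                         help="run tests marked as slow")

    def pytest_collection_modifyitems(config, items):
        # Skip tests marked @pytest.mark.slow unless --runslow was passed.
        if config.getoption("--runslow"):
            return
        skip_slow = pytest.mark.skip(reason="need --runslow option to run")
        for item in items:
            if "slow" in item.keywords:
                item.add_marker(skip_slow)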

@stefan-apollo (Contributor Author) commented Nov 22, 2023

While all tests pass on GPU, the float32 ablations break on CPU.

    @pytest.mark.parametrize("dtype", ["float32", "float64"])
    def test_ablation_result_flatness(self, ablation_results: dict, dtype: str) -> None:
        for node_layer in ablation_results["float32"].keys():
            if "mlp_out" in node_layer:
                # Should be identical due to residual stream size
>               ablation_result_128 = ablation_results[dtype][node_layer]["128"]
E               AssertionError: MLP non-flat ablation curve float32 mlp_out.0: 3.7895944118499756 (128) != 3.77636981010437 (642)
E               assert False
E                +  where False = <built-in method allclose of type object at 0x7f55fe59c540>(tensor(3.7896), tensor(3.7764), atol=0.001)
E                +    where <built-in method allclose of type object at 0x7f55fe59c540> = torch.allclose
E                +    and   tensor(3.7896) = <built-in method tensor of type object at 0x7f55fe59c540>(3.7895944118499756)
E                +      where <built-in method tensor of type object at 0x7f55fe59c540> = torch.tensor
E                +    and   tensor(3.7764) = <built-in method tensor of type object at 0x7f55fe59c540>(3.77636981010437)
E                +      where <built-in method tensor of type object at 0x7f55fe59c540> = torch.tensor

tests/test_float_precision.py:204: AssertionError
====================================================================================================== short test summary info =======================================================================================================
FAILED tests/test_float_precision.py::TestPythiaFloatingPointErrors::test_ablation_result_float_precision - AssertionError: Float difference mlp_out.0 128: 3.7895944118499756 (float32) != 3.7763756462461164 (float32)
FAILED tests/test_float_precision.py::TestPythiaFloatingPointErrors::test_ablation_result_flatness[float32] - AssertionError: MLP non-flat ablation curve float32 mlp_out.0: 3.7895944118499756 (128) != 3.77636981010437 (642)
========================================================================================= 2 failed, 3 passed, 1 skipped in 214.94s (0:03:34) =========================================================================================

This is probably due to slightly different C matrices, rather than the ablation_config being float32, but I will test.
Edit: Confirmed, ablation_config["dtype"] does not matter.

@stefan-apollo (Contributor Author)

These ablation curve errors (unlike the test_interaction_rotations ones we had, which are now skipped) are not related to batch size and do not occur on GPU even with

            batch_size: 1  #  A100 can handle 24
            gram_batch_size: 1  #  A100 can handle 80

@stefan-apollo (Contributor Author) commented Nov 22, 2023

The CPU tests do pass if I calculate and accumulate the Lambda matrices in float64. I changed (a rough sketch of the dtype handling follows the list):

  • the Lambda_dash = torch.einsum(...) line in M_dash_and_Lambda_dash_pre_forward_hook_fn
  • every use of Lambda in calculate_interaction_rotations, all the way to C and C_pinv
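Roughly the kind of change involved, sketched with a placeholder contraction and hypothetical names (the real hook and einsum pattern in rib differ):

    import torch

    def accumulate_lambda_dash(acts: torch.Tensor, lambda_dash_acc: torch.Tensor) -> torch.Tensor:
        """Run the Lambda_dash einsum and its accumulation in float64."""
        acts64 = acts.to(torch.float64)
        # Placeholder einsum; the actual Lambda_dash contraction in rib is different.
        batch_term = torch.einsum("bi,bj->ij", acts64, acts64)
        # The accumulator also stays in float64 until after the eigendecomposition.
        return lambda_dash_acc + batch_term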

@stefan-apollo (Contributor Author)

Now I tested:

  • Casting briefly to float32 and back to float64, plus converting to float32 after collect_M_dash_and_Lambda_dash and staying in float32 from then on. This runs only the einsum + accumulation in float64 -- passes!
  • Since that passes, testing which of the two was the issue:
    • Run only the einsum in float64 -- passes!
    • Run only the accumulation in float64 (just for completeness) -- fails.

Okay, so it looks like the einsum on CPU was the issue. A toy illustration of the float32 vs float64 einsum drift follows.
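A toy demonstration (not rib code) of the drift in question: the same contraction run in float32 and float64 differs by pure rounding error, which typically grows with the size of the reduction.

    import torch

    torch.manual_seed(0)
    x = torch.randn(4096, 512)
    ref = torch.einsum("bi,bj->ij", x.double(), x.double())
    f32 = torch.einsum("bi,bj->ij", x.float(), x.float()).double()
    # Nonzero in general; how large depends on the reduction size and hardware.
    print((ref - f32).abs().max())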

@stefan-apollo (Contributor Author)

I manually confirmed that the tests pass with different seeds (tried 3 different seeds on GPU)

@stefan-apollo (Contributor Author)

Also, the test which takes ~2 min on the CPU devbox takes ~24 min on the CI runner.

@nix-apollo (Contributor) commented Nov 22, 2023

Also, the test which takes ~2 min on the CPU devbox takes ~24 min on the CI runner.

Seems pretty sad to have CI go from ~7 to ~30 mins! Is it possible to adjust the config to have it run faster? Or do we just need to --skip-ci this test until we have GPU runners (#170)?

@stefan-apollo (Contributor Author)

Tests continue to fail on CI while they work on the CPU and GPU devboxes. The float32 ablation curve is (a) not flat and (b) different from the float64 one.

{'642': 3.776362737019857, '321': 3.7787582079569497, '128': 3.77876877784729, '0': 10.825837135314941} (float32)
!=
{'642': 3.7763756431303417, '321': 3.7763756431306406, '128': 3.776375642827917, '0': 10.825839875788878} (float64)

@stefan-apollo (Contributor Author)

Nix: based on the output it seems it's a different test that takes long, but I agree this makes no sense to me, and my test could totally be the one taking long.
[Screenshot: CI test durations]

@stefan-apollo (Contributor Author) commented Nov 22, 2023

The output was just wrong; the actual slowest tests were the new ones, as expected:

============================= slowest 10 durations =============================
9.73s setup    tests/test_float_precision.py::TestPythiaFloatingPointErrors::test_gram_matrices
88.86s call     tests/test_build_graph.py::test_pythia_14m_build_graph
38.82s call     tests/test_folded_bias.py::test_gpt2_folded_bias
30.09s call     tests/test_build_graph.py::test_mnist_build_graph
19.85s call     tests/test_folded_bias.py::test_pythia_folded_bias
14.28s call     tests/test_train_modular_arithmetic.py::test_main_accuracy
13.77s call     tests/test_build_graph.py::test_modular_arithmetic_build_graph
9.50s call     tests/test_ablations.py::test_run_mnist_orthog_ablations
7.51s call     tests/test_train_mnist.py::test_main_accuracy
7.15s call     tests/test_ablations.py::test_run_modular_arithmetic_rib_ablations
=========== 71 passed, 1 skipped, 5 deselected in 596.81s (0:09:56) ============

Skipping those on CI from now on.
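For reference, one common way to skip a test on CI (illustrative; the repo may use a custom marker or a --skip-ci flag instead):

    import os

    import pytest

    @pytest.mark.skipif(os.environ.get("CI") == "true", reason="too slow on CI runners")
    def test_pythia_floating_point_errors():
        ...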

Edit: For posterity's sake, durations of both new tests:


============================= slowest 10 durations =============================
634.99s setup    tests/test_float_precision.py::TestPythiaFloatingPointErrors::test_gram_matrices
382.80s setup    tests/test_float_precision.py::TestPythiaFloatingPointErrors::test_ablation_result_float_precision
157.70s call     tests/test_build_graph.py::test_pythia_14m_build_graph
62.86s call     tests/test_build_graph.py::test_mnist_build_graph
54.36s call     tests/test_folded_bias.py::test_gpt2_folded_bias
24.69s call     tests/test_train_modular_arithmetic.py::test_main_accuracy
23.68s call     tests/test_build_graph.py::test_modular_arithmetic_build_graph
19.62s call     tests/test_folded_bias.py::test_pythia_folded_bias
19.04s call     tests/test_ablations.py::test_run_mnist_orthog_ablations
14.99s call     tests/test_train_mnist.py::test_main_accuracy
=========================== short test summary info ============================

@stefan-apollo (Contributor Author) commented Nov 22, 2023

Looks like we're not the only ones with dtype related issues on GitHub CI: https://opensourcemechanistic.slack.com/archives/C04SRRE96UV/p1700313374803079

@danbraunai-apollo (Contributor) left a comment

Approved conditional on minor comments. Feel free to merge after addressing.

Comment on lines 277 to 278
Lambda_abs_sqrt_trunc = Lambda_abs_sqrt_trunc
Lambda_abs_sqrt_trunc_pinv = Lambda_abs_sqrt_trunc_pinv
Contributor:

What is going on here? Remove

Contributor Author:

Oops, a remainder from debugging; deleting these lines.

import torch
import yaml

from experiments.lm_ablations.run_lm_ablations import Config as AblationConfig
Contributor:

I'm surprised that running pytest in different directories works with this, without doing:

# Append the root directory to sys.path
ROOT_DIR = Path(__file__).parent.parent.resolve()
sys.path.append(str(ROOT_DIR))

like I did for other test files. I guess I don't need to do that because pytest recognises the root directory as the one which has the tests dir?

Contributor Author:

There is also no such line in test_hooks.py or test_data_accumulator.py -- leaving it out. As discussed, Nix will remove all such lines from our test files

rib_config["dtype"] = dtype
rib_config["exp_name"] = exp_name
if not torch.cuda.is_available():
    # Try to reduce memory usage for CI
Contributor:

Have you had an OOM in the CI before? Seems like this might be overkill unless you have noticed it

Contributor Author:

Yep, error 143 is probably memory; it also used ~5% of my 250 GB devbox memory, which is more than the 7 GB of CI memory.
[Screenshot: devbox memory usage]

:, :n_max
]
assert torch.allclose(
float32_C.to(torch.float64), float64_C, rtol=0.5, atol=0.5 * float64_C.max()
Contributor:

nit: rtol=0.5, atol=0.5 * float64_C.max() is an interesting pattern. Can't you just use rtol or rtol + atol to achieve the same thing?

Contributor Author:

Fair! I agree that atol=0.5 * float64_C.max() is a terrible test (because it's a very weak assertion). But our float32 and float64 matrices are so different that even rtol=0.5 doesn't pass.

I removed the rtol=0.5 part, but the atol=0.5 * float64_C.max() has to stay for now.
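For context on how the two tolerances interact: torch.allclose checks, elementwise, whether |a - b| <= atol + rtol * |b|, so a large atol dominates for small entries regardless of rtol. A toy illustration (made-up tensors, not the actual C matrices):

    import torch

    a = torch.tensor([1.0, 1e-6])
    b = torch.tensor([1.4, 5e-6])
    print(torch.allclose(a, b, rtol=0.0, atol=0.5))  # True: a large atol tolerates both gaps
    print(torch.allclose(a, b, rtol=0.5, atol=0.0))  # False: the second pair differs by 5x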

@stefan-apollo stefan-apollo merged commit bb52383 into main Nov 23, 2023
CindyWuApollo pushed a commit that referenced this pull request Dec 5, 2023
Fixes two of our floating point issues, most of the time.

1. Turns out that M and M dash need to be kept at float64 at all times (until the eigendecompose). Rounding them by even momentarily converting either to float32 breaks ablation curves.
2. Turns out that the einsum for Lambda_dash needs to be run in float64.

Implemented tests that
a) Compare rib build outputs between float32 and float64.
a1) gram_matrices match well
a2) eigenvectors: the first columns match reasonably well
a3) interaction_rotations approximately match on GPU and with batch size > 1 for the first columns
b) Compare ablation results
b1) Matching ablation results between f32 and f64 for mlp layers
b2) Flat ablation curves for mlp layers