
Conversation

@danbraunai-apollo (Contributor) commented Nov 29, 2023

Update tlens and misc config settings

Description

  • Bump transformer-lens to ~=1.10.0, along with the torch and torchvision bumps required by the tlens update. Notably, this update sets IGNORE to -inf in Attention instead of -1e5, which was causing us issues in the old version (see the sketch after this list).
  • Reduce truncation_threshold from 1e-5 to 1e-15 in all the configs. Tests in Clean up dtype loading #215 showed that this allows us to use a much smaller atol.
  • Use float64 instead of float32 in the pythia and modular_arithmetic configs. For modadd, this gives slightly different edge weights (in particular, float32 truncates more vectors than float64), but the resulting graphs look almost identical (see images below). I think for modadd it probably doesn't matter much, and for pythia it might. Regardless, I'd prefer that we just continue to use float64 everywhere.
  • Train a new modular arithmetic model with the above updates. Store it, along with its resulting rib graph, in sample_checkpoints and sample_graphs, respectively. I did not bother running the (slow) pythia graphs with these updates.
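To make the IGNORE point above concrete, here is a minimal, contrived sketch (not the repo's Attention code; the score magnitudes are chosen purely to expose the failure mode): a finite mask value like -1e5 only keeps masked positions out of the softmax while legitimate attention scores stay well above it, whereas -inf masks them unconditionally.

```python
import torch

# Contrived sketch: the score magnitudes are chosen only to expose the failure mode.
# If legitimate attention scores ever get close to -1e5, a masked position filled
# with IGNORE = -1e5 can end up winning the softmax; -inf always zeroes it out.
scores = torch.tensor([-1.2e5, -1.1e5, 0.0], dtype=torch.float64)
mask = torch.tensor([True, True, False])  # the last position should be masked out

for ignore in (-1e5, -torch.inf):
    masked_scores = torch.where(mask, scores, torch.tensor(ignore, dtype=scores.dtype))
    print(ignore, masked_scores.softmax(dim=-1).tolist())
# -1e5 -> [0.0, 0.0, 1.0]  (all attention leaks to the masked position)
# -inf -> [0.0, 1.0, 0.0]  (masked position gets exactly zero)
```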

Related Issue

Closes #217; Closes #216; Closes #205 (@nix-apollo please confirm; I think this issue may have been caused by the old IGNORE value being too large).

Motivation and Context

Having the old IGNORE=-1e5 caused issues in both transformer-lens and the RIB tests (requiring much larger atols). The other changes are also likely to improve any precision issues we have.

How Has This Been Tested?

Updated tests to point to the new modular arithmetic graphs. Same tests pass.

Does this PR introduce a breaking change?

Technically no, but everybody should reinstall the package since requirements.txt has been updated; otherwise things will break.

float32 vs float64 modular arithmetic graph build

[Image: modular_arithmetic_rib_graph]
[Image: modular_arithmetic_fp32_rib_graph]

@nix-apollo (Contributor) left a comment

Looks good to me.

tokenizer_name: EleutherAI/pythia-14m
return_set: train # pile-10k only has train, so we take the first 90% for building and last 10% for ablations
return_set_frac: 0.01
return_set_frac: 0.9
Contributor

Could set this to 0.07. That's ~1M tokens, and what I used for the last set of edge calcs.

Contributor Author

Done. Ideally we'd have some sweeps for this, like you mentioned in 3a here https://apolloresearchhq.slack.com/archives/C06484S5UF9/p1701264768373199.

self.register_buffer("mask", causal_mask)

self.register_buffer("IGNORE", torch.tensor(-1e5))
self.register_buffer("IGNORE", torch.tensor(-torch.inf))
Contributor

Note that this may break if all possible attention positions are masked, as the softmax of a vector of all -torch.inf is NaN.

TransformerLens avoided this by checking for NaNs and replacing them with 0s here. I'm not sure we want that though -- it seems fine (good, even) for things to break if an attention head isn't allowed to pay attention to any positions.
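
For concreteness, here is a minimal standalone sketch of that edge case (not the repo's or TransformerLens's actual code): the softmax of a fully masked row is NaN, and a guard of the kind described above silently turns it into zeros.

```python
import torch

# A row where every position is masked: softmax normalizes exp(-inf) terms by a zero
# sum, which is 0/0, so the whole row comes out as NaN.
fully_masked_row = torch.full((4,), -torch.inf)
probs = fully_masked_row.softmax(dim=-1)
print(probs)  # tensor([nan, nan, nan, nan])

# A guard in the spirit of the TransformerLens workaround (illustrative, not its
# actual code): replace NaNs with zeros, which hides the fact that the head had
# no positions it was allowed to attend to.
probs_guarded = torch.where(torch.isnan(probs), torch.zeros_like(probs), probs)
print(probs_guarded)  # tensor([0., 0., 0., 0.])
```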

Contributor Author

We don't support passing an attention_mask to the forward method of Attention, so we shouldn't be able to get a token attending to no other positions, unless I'm missing something. This makes me especially keen for things to break if something happens causing nans.
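
As a small illustration of that claim (a standalone sketch, not the repo's Attention forward): with a purely causal mask, every query position can attend to at least itself, so no softmax row is ever fully masked and no NaNs can arise from the -inf IGNORE.

```python
import torch

# Illustrative check: a lower-triangular (causal) mask keeps the diagonal, so every
# query position is allowed to attend to at least itself and no row is all -inf.
seq_len = 8
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
scores = torch.randn(seq_len, seq_len)
masked_scores = torch.where(causal_mask, scores, torch.tensor(-torch.inf))
probs = masked_scores.softmax(dim=-1)

assert causal_mask.any(dim=-1).all()   # every row has at least one allowed position
assert not torch.isnan(probs).any()    # hence no NaN rows from the -inf masking
```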

@nix-apollo (Contributor)

Installed the updated requirements and ran the tests successfully.

@nix-apollo (Contributor)

I confirm this should close #205

@nix-apollo (Contributor)

I assume you are planning to put the change to the default value of eps in a different PR?

@danbraunai-apollo (Contributor Author)

> I assume you are planning to put the change to the default value of eps in a different PR?

Yep, will do. Just made an issue for it: #224.

@danbraunai-apollo danbraunai-apollo merged commit 30afc49 into main Nov 29, 2023

Development

Successfully merging this pull request may close these issues:

  • Reduce truncation threshold
  • Update transformer lens
  • Pythia causal mask bug?
