Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix FactoredMatrix bug #367

Merged

Conversation

callummcdougall
Copy link
Contributor

@callummcdougall callummcdougall commented Aug 18, 2023

Description

The function _convert_to_slice was added recently, which causes indexing with elements to work (previously this didn't work). But when this change was made, it caused indexing with sequences to work.

Type of change

The solution is to make _convert_to_slice only alter integer elements of the index.

To further explain, _convert_to_slice will take an integer like 1 and convert it into the slice slice(1, 2). This is good because now the indexed tensors A and B have the correct number of dimensions (i.e. they don't lose a dimension). But this isn't how we should treat things like t.tensor([0, 1, 2]) when they're used as an index, because they're already in the correct form for indexing (i.e. indexing with them won't delete a dimension).

I've now added tests to tests/unit/factored_matrix/test_get_item.py which would have failed on both the version of TransformerLens before the recent change to FactoredMatrix, and the version after this change, but which now all pass.

Screenshots

three_versions

Checklist:

  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes
  • I have not rewritten tests relating to key interfaces which would affect backward compatibility

The function `_convert_to_slice` was added recently, which causes indexing with elements to work (previously this didn't work). But when this change was made, it caused indexing with sequences to work. The solution is to make `_convert_to_slice` only alter integer elements of the index.

To further clarify, `_convert_to_slice` will take an integer like `1` and convert it into the slice `slice(1, 2)`, this way the indexed shapes have the correct number of dimensions. But this isn't how we should treat things like `t.tensor([0, 1, 2])` when they're used as an index, because they're already in the correct form for indexing (i.e. indexing with them won't delete a dimension).
I've added tests which would have broken on the previous edit to FactoredMatrices, but which now pass.
@jbloomAus jbloomAus merged commit e956cba into TransformerLensOrg:main Aug 18, 2023
4 checks passed
@jbloomAus
Copy link
Collaborator

This was causing issues in ARENA tutorials so expedited review. I'm planning on reviewing other PR's / issues this weekend.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants