add feature cosine similarity loss #3218
Conversation
```python
match nested_key:
    case "out_projection_activations":
        if nested_key in model.decoder.layers["self_attention"]:
            intermediate_value = model.decoder.layers["self_attention"][nested_key].get_value()[-1]
```
Why is it only returning the last element, i.e. `[-1]`?
What is the shape of `model.decoder.layers["self_attention"][nested_key]`? Wondering what is getting dropped.
This is because `sow` appends values into a tuple, so `[-1]` is just a way to retrieve the most recent entry.
In a follow-up PR, could you add a comment explaining what's inside the tuple and what is being retrieved here?
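To illustrate why `[-1]` is needed, here is a minimal stdlib sketch of the tuple-append behavior, assuming Flax-style `sow` semantics (the `SowStore` class and its method names are hypothetical stand-ins, not the PR's API):

```python
class SowStore:
    """Minimal sketch of sow-like semantics: each call appends the
    sown value to a tuple stored under the given key (an assumption
    modeled on Flax's default tuple-accumulating behavior)."""

    def __init__(self):
        self._store = {}

    def sow(self, key, value):
        # accumulate into a tuple rather than overwriting
        self._store[key] = self._store.get(key, ()) + (value,)

    def get_value(self, key):
        return self._store[key]


store = SowStore()
store.sow("out_projection_activations", "first call")
store.sow("out_projection_activations", "second call")

# values accumulate in a tuple, so [-1] retrieves the latest sown value
latest = store.get_value("out_projection_activations")[-1]
```

Under these semantics, `get_value()` returns the whole tuple of sown values and nothing is dropped; `[-1]` simply selects the most recent one.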
```python
"""Computes Eval Loss and returns empty aux dict (required for consistency)."""
# Parent logic for task loss
# We re-implement simple CE here to ensure float32 casting
s_logits = student_output.astype(jnp.float32)
```
Should this be `s_logits = student_output[0].astype(jnp.float32)`, similar to `compute_loss`, now that `model_forward_fn` returns a tuple? Also, we could add a TODO to add other metrics in eval.
It looks like our tests are not covering this function?
Yes, this will cause eval to fail. I'll fix it and add unit tests.
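A stdlib sketch of the suggested fix, assuming `model_forward_fn` now returns a `(logits, aux)` tuple so eval must index `[0]` before casting, mirroring `compute_loss` (the function names and the dummy loss below are hypothetical, not the PR's code):

```python
import math


def nll(logits, label):
    """Dummy cross-entropy on one example: -log softmax(logits)[label]."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    return -math.log(exps[label] / sum(exps))


def compute_eval_loss(student_output, label):
    """Hypothetical eval-loss shape: unpack the tuple first, then cast."""
    logits = student_output[0]            # model_forward_fn returns a tuple
    logits = [float(x) for x in logits]   # analogous to .astype(jnp.float32)
    # TODO: add other eval metrics here
    return nll(logits, label)


# usage: without the [0] unpacking, indexing into the (logits, aux)
# tuple would raise instead of computing a loss
loss = compute_eval_loss(([2.0, 0.0, 0.0], {"aux": None}), 0)
```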
Description
Adds an optional cosine similarity loss between the attention outputs of corresponding teacher and student layers.
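The feature can be sketched as follows, assuming one flat feature vector per distilled layer (the function name, argument layout, and `eps` default are illustrative assumptions, not the PR's actual API):

```python
import math


def cosine_distill_loss(student_feats, teacher_feats, eps=1e-8):
    """Sketch: mean of (1 - cosine similarity) over layer pairs.

    student_feats / teacher_feats: equal-length lists of flat feature
    vectors, one per layer selected for distillation (assumed layout).
    """
    total = 0.0
    for s, t in zip(student_feats, teacher_feats):
        dot = sum(a * b for a, b in zip(s, t))
        ns = math.sqrt(sum(a * a for a in s))
        nt = math.sqrt(sum(b * b for b in t))
        # 0 when the vectors point the same way, 1 when orthogonal
        total += 1.0 - dot / (ns * nt + eps)
    return total / len(student_feats)
```

Minimizing this term pushes each student layer's attention output toward the direction (not the magnitude) of the corresponding teacher output.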
Tests
Checklist
Before submitting this PR, please make sure (put an X in the square brackets):
[ ] This PR has the gemini-review label.