
add feature cosine similarity loss #3218

Merged
copybara-service[bot] merged 1 commit into main from cos_loss on Feb 26, 2026

Conversation


@entrpn (Collaborator) commented on Feb 23, 2026

Description

Adds an optional cosine similarity loss between the attention outputs of various teacher and student layers.
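
For illustration, a minimal JAX sketch of a feature cosine similarity loss between teacher and student activations; the function and argument names are assumptions for this example, not the PR's actual API:

```python
import jax.numpy as jnp

def cosine_similarity_loss(student_acts, teacher_acts, eps=1e-8):
  """Mean (1 - cosine similarity) over the feature axis.

  Both inputs are assumed to have shape [batch, seq_len, hidden_dim].
  """
  s = student_acts / (jnp.linalg.norm(student_acts, axis=-1, keepdims=True) + eps)
  t = teacher_acts / (jnp.linalg.norm(teacher_acts, axis=-1, keepdims=True) + eps)
  cos_sim = jnp.sum(s * t, axis=-1)   # [batch, seq_len]
  return jnp.mean(1.0 - cos_sim)      # scalar; 0 when the features are aligned
```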

Tests

  • Updated train_distill_test unit test.
  • Ran train_distill.py on Llama3.1-8B with the cosine loss both enabled and disabled.
  • Updated maxtext_utils tests.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have added necessary comments to my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

codecov bot commented Feb 23, 2026

@entrpn force-pushed the cos_loss branch 3 times, most recently from fcb8fe8 to 07ff5de, on February 25, 2026 01:08
@entrpn force-pushed the cos_loss branch 2 times, most recently from 5e27282 to 1f3cd42, on February 25, 2026 22:09
@vlad-karp (Collaborator) left a comment

A couple of comments

@vlad-karp (Collaborator) left a comment

LGTM overall

```python
match nested_key:
  case "out_projection_activations":
    if nested_key in model.decoder.layers["self_attention"]:
      intermediate_value = model.decoder.layers["self_attention"][nested_key].get_value()[-1]
```
Collaborator

Why is it only returning the last element, i.e. [-1]?
What is the shape of model.decoder.layers["self_attention"][nested_key]? Wondering what is getting dropped.

Collaborator Author

This is because sow appends values to a tuple, so [-1] is just a way to retrieve the sown value.
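
For context, a minimal Flax Linen sketch of how sow accumulates values into a tuple; the module, variable names, and access pattern here are assumptions for illustration and differ from MaxText's actual accessor (e.g. get_value()):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class TinyAttn(nn.Module):
  @nn.compact
  def __call__(self, x):
    out = nn.Dense(4)(x)  # stand-in for the attention output projection
    # sow appends each call's value to a tuple in the "intermediates" collection
    self.sow("intermediates", "out_projection_activations", out)
    return out

model = TinyAttn()
x = jnp.ones((1, 4))
params = model.init(jax.random.PRNGKey(0), x)["params"]
_, state = model.apply({"params": params}, x, mutable=["intermediates"])
# The sown entry is a tuple; [-1] retrieves the most recently appended value.
latest = state["intermediates"]["out_projection_activations"][-1]
```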

Collaborator

In a follow-up PR, could you add a comment describing what's inside the tuple and what is being retrieved here?

"""Computes Eval Loss and returns empty aux dict (required for consistency)."""
# Parent logic for task loss
# We re-implement simple CE here to ensure float32 casting
s_logits = student_output.astype(jnp.float32)
Collaborator

Should this be s_logits = student_output[0].astype(jnp.float32), similar to compute_loss, now that model_forward_fn returns a tuple? Also, we could add a TODO to add other metrics in eval.

It looks like our tests are not testing this function?
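
For reference, a minimal sketch of the suggested indexing fix, assuming model_forward_fn returns a (logits, activations) tuple; the function and argument names here are illustrative only:

```python
import jax
import jax.numpy as jnp

def compute_eval_loss(student_output, one_hot_targets):
  # student_output is assumed to be (logits, sown_activations),
  # so [0] selects the logits; cast to float32 for a numerically stable CE.
  s_logits = student_output[0].astype(jnp.float32)
  log_probs = jax.nn.log_softmax(s_logits, axis=-1)
  return -jnp.mean(jnp.sum(one_hot_targets * log_probs, axis=-1))
```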

Collaborator Author

Yes, this will cause eval to fail. I'll fix it and add unit tests.

@gagika (Collaborator) left a comment

thanks

copybara-service bot merged commit 5a4a9c3 into main on Feb 26, 2026
163 checks passed
copybara-service bot deleted the cos_loss branch on February 26, 2026 07:16