-
Notifications
You must be signed in to change notification settings - Fork 16
Part 2: Creating New Loss Function Classes #159
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…sociated tests. These will be used in future PRs to move PerFCL, FENDA, and other clients to use these losses.
…oss_function_classes
…oss_function_classes
…oss_function_classes
sanaAyrml
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall looks good, just proposed some minor changes.
… machine, Fixing a PICAI mypy error as well.
…nting improvements, adding a small test and fixing a test.
| negative_similarity = self.cos_sim(features, negative_pair) | ||
| logits = torch.cat((logits, negative_similarity.reshape(-1, 1)), dim=1) | ||
|
|
||
| # Compute the similarity of the batch of features with the collection of batches of negative pair features |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice!
jewelltaylor
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the changes, looks great!
PR Type
Feature
Short Description
Clickup Ticket(s): Link
This PR is Part 1 of several PRs aiming to simplify and decouple our implementations of FENDA and PerFCL. In this first part, I'm simply creating loss function classes that will be used by the different clients. These functions are stand-alone in this PR to be easier to review.
There are also some small documentation changes and an variable name change for the Ditto and MR-MTL clients.
Tests Added
Stand-along tests for each of the loss functions implemented, along with one for the pre-existing drift loss function.