Heyo @begumcig Codacy was talking about your PR behind your back, linking it here in case you want to have a look at the complaints lol
guennemann left a comment
Thanks for the nice work. Some comments inline.
| "# --- Option 1: Using a simple string (default = single mode) ---\n", | ||
| "#request = \"image_generation_quality\"\n", | ||
| "\n", | ||
| "\n", | ||
| "# --- Option 2: Using a simple string (default = single mode) ---\n", | ||
| "request = [\"cmmd\"]\n", | ||
| "\n", | ||
| "# --- Option 3: Full control using the class ---\n", | ||
| "#from pruna.evaluation.metrics import CMMD\n", | ||
| "#request = [CMMD()] # For single mode\n", | ||
| "# request = [CMMD(call_type=\"pairwise\")] # For pairwise mode" | ||
| ] |
This code snippet might be a bit confusing since you did not explain single and pairwise mode before. Might be helpful to add some comment/link to the docs?
| "\n", | ||
| "The EvaluationAgent needs a PrunaModel to evaluate. We can evaluate the baseline model even before smashing.\n", | ||
| "\n", | ||
| "This is done by calling the `evaluate_base` method of the EvaluationAgent." |
evaluate_base is deprecated. Please update.
| "source": [ | ||
| "### 5. Evaluate the smashed model\n", | ||
| "\n", | ||
| "Now, we can evaluate the smashed model. This is done by calling the `evaluate_smashed` method of the EvaluationAgent." |
Same as above. evaluate_smashed is deprecated.
| if not isinstance(default, (Tensor, List)) or (isinstance(default, List) and default): | ||
| pruna_logger.error("State variable must be a tensor or any empty list (where you can append tensors)") | ||
| raise ValueError("State variable must be a tensor or any empty list (where you can append tensors)") |
I am not sure whether the if statement is really clear. Also, it would be helpful to add a comment explaining what the List is supposed to represent, in particular when the list is meant to be empty. You also don't seem to check whether the list is empty?
The first condition, not isinstance(default, (Tensor, List)), rejects any default that is neither a Tensor nor a List. The second condition, isinstance(default, List) and default, is true when the default is a non-empty List. So I do check that my default is a Tensor or a List, and that if it's a List, it's empty. The list is always supposed to be empty.
Do you mind adding some more brackets to make clear whether the NOT or the OR takes precedence?
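For reference on the precedence question: in Python, `not` binds tighter than `and`, which binds tighter than `or`, so the condition in the diff parses as `(not isinstance(...)) or (isinstance(...) and default)`. Below is a minimal sketch with the brackets written out. It uses a stand-in `Tensor` class and the built-in `list` so it runs without PyTorch, and the helper names `is_invalid_default` and `is_invalid_default_explicit` are hypothetical, not part of pruna.

```python
from typing import Any


class Tensor:
    """Stand-in for torch.Tensor (assumption: keeps the sketch dependency-free)."""


def is_invalid_default(default: Any) -> bool:
    # Same shape as the condition in the diff; `not` applies only to the first isinstance.
    return bool(not isinstance(default, (Tensor, list)) or (isinstance(default, list) and default))


def is_invalid_default_explicit(default: Any) -> bool:
    # Equivalent check with the precedence spelled out through named sub-conditions.
    is_supported_type = isinstance(default, (Tensor, list))
    is_non_empty_list = isinstance(default, list) and len(default) > 0
    return (not is_supported_type) or is_non_empty_list
```

Both helpers accept a tensor or an empty list and reject everything else, including a non-empty list.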
| @@ -63,7 +77,12 @@ def forward(self, *args, **kwargs) -> None: | |||
|
|
|||
| def reset(self) -> None: | |||
| """Reset the metric state.""" | |||
Please add some comment explaining what the reset is supposed to do.
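One possible shape for such a comment, shown on a hypothetical stand-alone class rather than pruna's actual base metric. The sketch assumes that `reset` is meant to restore every state registered through `add_state` to its default, which is a guess based on the `add_state` check discussed above.

```python
from copy import deepcopy
from typing import Any, Dict


class StatefulMetricSketch:
    """Hypothetical stateful metric; not pruna's actual implementation."""

    def __init__(self) -> None:
        self._defaults: Dict[str, Any] = {}

    def add_state(self, name: str, default: Any) -> None:
        # Keep a private copy of the default so reset() can restore it later.
        self._defaults[name] = deepcopy(default)
        setattr(self, name, deepcopy(default))

    def reset(self) -> None:
        """Restore every registered state to its default value.

        List states become fresh empty lists, so values accumulated while
        evaluating one model do not leak into the evaluation of the next.
        """
        for name, default in self._defaults.items():
            setattr(self, name, deepcopy(default))
```

Copying the default on every reset avoids sharing one mutable list between the stored default and the live state.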
| ], | ||
| indirect=["model_fixture"], | ||
| ) | ||
| def test_cmmd(model_fixture: tuple[Any, SmashConfig], device: str, clip_model: str) -> None: |
Do you also want to add a test for the pairwise mode?
| *args, | ||
| device: str | torch.device = "cuda", | ||
| clip_model_name: str = "openai/clip-vit-large-patch14-336", | ||
| call_type: str = "", |
Why not use the default call type here?
Because I want to let people start pairwise mode by passing pairwise, rather than the actual call type required by the metric_data_processor, which is pairwise_gt_y. So when a user passes pairwise, I still need to keep the default call_type in the metric in order to update the call_type.
Sorry, I was unclear in my comment. Why don't you write call_type: str = "HERE THE DEFAULT CALL TYPE"?
I can also add the default call_type as a string here, sure :) But I need to keep the default_call_type for the pairwise case.
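A rough sketch of how both points in this thread could coexist: the signature spells out the default call type, and the default is still kept around to build the pairwise variant. The value `"gt_y"` is an assumption inferred from the name `pairwise_gt_y`, and `resolve_call_type` is a hypothetical helper, not a pruna function.

```python
# Assumption: the single-mode default, inferred from "pairwise_gt_y" in the thread.
DEFAULT_CALL_TYPE = "gt_y"


def resolve_call_type(
    call_type: str = DEFAULT_CALL_TYPE,
    default_call_type: str = DEFAULT_CALL_TYPE,
) -> str:
    """Map the user-facing call type to the one the data processor expects."""
    if call_type == "pairwise":
        # The shorthand is expanded using the metric's own default call type.
        return f"pairwise_{default_call_type}"
    # Any other value, including the explicit default, is passed through unchanged.
    return call_type
```

With this shape, a user who passes nothing gets the explicit default, and a user who passes `pairwise` gets the expanded call type.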
| ) -> None: | ||
| super().__init__(*args, **kwargs) | ||
| self.device = device | ||
| self.clip_model = CLIPVisionModelWithProjection.from_pretrained(clip_model_name).to(self.device) |
Should we catch exceptions/errors here, e.g. if the clip_model_name does not exist?
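One way this could look, sketched with a generic loader argument so it runs without `transformers` installed. To my knowledge `from_pretrained` raises `OSError` when a model identifier cannot be resolved, but treat that as an assumption to verify. The helper name `load_pretrained_or_raise` is hypothetical.

```python
from typing import Any, Callable


def load_pretrained_or_raise(loader: Callable[[str], Any], model_name: str) -> Any:
    """Call `loader(model_name)` and turn a failed lookup into a clearer error."""
    try:
        return loader(model_name)
    except OSError as err:
        # Assumption: an unknown or unreachable model name surfaces as OSError.
        raise ValueError(
            f"Could not load CLIP model '{model_name}'. "
            "Check that the name exists and that the weights are reachable."
        ) from err
```

In the constructor this would be called as `load_pretrained_or_raise(CLIPVisionModelWithProjection.from_pretrained, clip_model_name)` before moving the model to the device.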
johnrachwan123 left a comment
Very clean PR! I left a few comments :)
Description
This PR introduces CMMD (CLIP-based Maximum Mean Discrepancy) as a new evaluation metric in pruna.
CMMD measures the distributional discrepancy between two sets of sample images in the CLIP embedding space. It uses Maximum Mean Discrepancy (MMD) to quantify semantic and visual alignment between the two sets, without relying on explicit pairwise similarity.
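To illustrate the MMD part of that description, here is a small dependency-free sketch of the biased squared-MMD estimator with a Gaussian RBF kernel. It operates on plain lists of floats standing in for CLIP embeddings. The kernel choice, the bandwidth `sigma`, and the function names are assumptions for illustration and are not taken from this PR's implementation.

```python
import math
from typing import Sequence

Vector = Sequence[float]


def rbf_kernel(x: Vector, y: Vector, sigma: float = 1.0) -> float:
    """Gaussian RBF kernel between two embedding vectors."""
    sq_dist = sum((a - b) ** 2 for a, b in zip(x, y))
    return math.exp(-sq_dist / (2.0 * sigma**2))


def mean_kernel(xs: Sequence[Vector], ys: Sequence[Vector], sigma: float) -> float:
    """Average kernel value over all pairs drawn from xs and ys."""
    total = sum(rbf_kernel(x, y, sigma) for x in xs for y in ys)
    return total / (len(xs) * len(ys))


def mmd_squared(xs: Sequence[Vector], ys: Sequence[Vector], sigma: float = 1.0) -> float:
    """Biased estimate of squared MMD between two sets of embeddings."""
    return (
        mean_kernel(xs, xs, sigma)
        + mean_kernel(ys, ys, sigma)
        - 2.0 * mean_kernel(xs, ys, sigma)
    )
```

Two identical sets give a value of zero, and the value grows as the two embedding distributions move apart, which is what makes it usable as a set-level score instead of a per-image comparison.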
Type of Change
How Has This Been Tested?
Checklist
Additional Notes
Quick start:
Default operation mode of the metric is single:
You can also evaluate in pairwise mode: