
feat: add cmmd #38

Merged
begumcig merged 5 commits into main from feat/add-cmmd-metric
Apr 1, 2025

Conversation

@begumcig (Member) commented Mar 26, 2025

Description

This PR introduces CMMD (CLIP-based Maximum Mean Discrepancy) as a new evaluation metric in pruna.

CMMD measures the distributional discrepancy between two sets of sample images in the CLIP embedding space. It leverages Maximum Mean Discrepancy (MMD) to quantify semantic and visual alignment without relying on explicit pairwise similarity.
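To make the statistic concrete, here is a minimal, hedged sketch (not the pruna implementation) of squared MMD between two sets of embeddings with a Gaussian RBF kernel — the kind of quantity CMMD computes over CLIP image embeddings. The bandwidth `sigma` and the use of NumPy are illustrative assumptions.

```python
import numpy as np

def mmd2(x: np.ndarray, y: np.ndarray, sigma: float = 10.0) -> float:
    """Biased estimate of squared MMD between samples x (n, d) and y (m, d)."""
    def sq_dists(a: np.ndarray, b: np.ndarray) -> np.ndarray:
        # Pairwise squared Euclidean distances between rows of a and b.
        return ((a[:, None, :] - b[None, :, :]) ** 2).sum(axis=-1)

    gamma = 1.0 / (2.0 * sigma**2)
    k_xx = np.exp(-gamma * sq_dists(x, x)).mean()  # within-set similarity of x
    k_yy = np.exp(-gamma * sq_dists(y, y)).mean()  # within-set similarity of y
    k_xy = np.exp(-gamma * sq_dists(x, y)).mean()  # cross-set similarity
    return float(k_xx + k_yy - 2.0 * k_xy)
```

Identical sample sets give exactly zero; the further apart the two embedding distributions, the larger the value.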

Type of Change

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

How Has This Been Tested?

Checklist

  • My code follows the style guidelines of this project
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Additional Notes

Quick start:
The metric's default operation mode is single:

from diffusers import AutoPipelineForText2Image

from pruna.data.pruna_datamodule import PrunaDataModule
from pruna.evaluation.task import Task
from pruna.evaluation.evaluation_agent import EvaluationAgent

task = Task(request=["cmmd"], datamodule=PrunaDataModule.from_string("COCO"))
eval_agent = EvaluationAgent(task)

pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")

eval_agent.evaluate(pipe)

You can also evaluate in pairwise mode by constructing the metric explicitly:

from pruna.evaluation.metrics import CMMD
from pruna.evaluation.task import Task
from pruna.data.pruna_datamodule import PrunaDataModule

task = Task(request=[CMMD(call_type="pairwise")], datamodule=PrunaDataModule.from_string("COCO"))

@begumcig begumcig force-pushed the feat/add-cmmd-metric branch 3 times, most recently from 83e6bd9 to 5e268ea Compare March 26, 2025 15:51
@begumcig begumcig marked this pull request as ready for review March 26, 2025 15:57
@begumcig begumcig force-pushed the feat/add-cmmd-metric branch from 5e268ea to 6a6a26a Compare March 26, 2025 16:03
@begumcig begumcig force-pushed the feat/add-cmmd-metric branch from 6a6a26a to d858cf0 Compare March 26, 2025 16:13
@johannaSommer (Member) commented:

Heyo @begumcig Codacy was talking about your PR behind your back, linking it here in case you want to have a look at the complaints lol
https://app.codacy.com/gh/PrunaAI/pruna/pull-requests/38/issues

@guennemann left a comment:


Thanks for the nice work. Some comments inline.

Comment on lines +75 to +86
"# --- Option 1: Using a simple string (default = single mode) ---\n",
"#request = \"image_generation_quality\"\n",
"\n",
"\n",
"# --- Option 2: Using a simple string (default = single mode) ---\n",
"request = [\"cmmd\"]\n",
"\n",
"# --- Option 3: Full control using the class ---\n",
"#from pruna.evaluation.metrics import CMMD\n",
"#request = [CMMD()] # For single mode\n",
"# request = [CMMD(call_type=\"pairwise\")] # For pairwise mode"
]


This code snippet might be a bit confusing since you did not explain single and pairwise mode before. Might be helpful to add some comment/link to the docs?

"\n",
"The EvaluationAgent needs a PrunaModel to evaluate. We can evaluate the baseline model even before smashing.\n",
"\n",
"This is done by calling the `evaluate_base` method of the EvaluationAgent."


evaluate_base is deprecated. Please update.

"source": [
"### 5. Evaluate the smashed model\n",
"\n",
"Now, we can evaluate the smashed model. This is done by calling the `evaluate_smashed` method of the EvaluationAgent."


Same as above. evaluate_smashed is deprecated.

Comment on lines +54 to +56
if not isinstance(default, (Tensor, List)) or (isinstance(default, List) and default):
pruna_logger.error("State variable must be a tensor or any empty list (where you can append tensors)")
raise ValueError("State variable must be a tensor or any empty list (where you can append tensors)")


I am not sure the if statement is really clear. It would also help to add a comment explaining what the List is actually supposed to represent, and in particular when the list is meant to be empty. Also, you don't seem to check whether the list is empty?

begumcig (Member, Author) replied:


The first condition, isinstance(default, (Tensor, List)), checks whether the default is a Tensor or a List. The second condition, isinstance(default, List) and default, checks, when the default is a List, whether it is non-empty (in which case we raise). So I do check that the default is a Tensor or a List, and make sure that if it's a List, it's an empty list. The list is always supposed to be empty.


Do you mind adding some more brackets to make clear whether the NOT or the OR has preference.
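For illustration, a hedged sketch of the same check with explicit parentheses making the NOT/AND/OR precedence unambiguous. The `Tensor` class here is a placeholder so the snippet runs without torch, and `validate_state` is an illustrative name, not the actual pruna helper:

```python
class Tensor:  # placeholder standing in for torch.Tensor, for illustration only
    pass

def validate_state(default):
    # Allowed defaults: a Tensor, or an *empty* list (tensors get appended later).
    is_allowed_type = isinstance(default, (Tensor, list))
    is_nonempty_list = isinstance(default, list) and len(default) > 0
    if (not is_allowed_type) or is_nonempty_list:
        raise ValueError("State variable must be a tensor or an empty list")
    return default
```

Naming the two sub-conditions also documents what each branch rejects, which addresses the readability concern without changing behavior.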

@@ -63,7 +77,12 @@ def forward(self, *args, **kwargs) -> None:

def reset(self) -> None:
"""Reset the metric state."""


Please add some comment explaining what the reset is supposed to do.

],
indirect=["model_fixture"],
)
def test_cmmd(model_fixture: tuple[Any, SmashConfig], device: str, clip_model: str) -> None:


Do you also want to add a test for the pairwise mode?

*args,
device: str | torch.device = "cuda",
clip_model_name: str = "openai/clip-vit-large-patch14-336",
call_type: str = "",


Why not use the default call type here?

begumcig (Member, Author) replied:


Because I want to allow users to enable pairwise mode by passing pairwise (rather than the actual call type required by the metric_data_processor, which is pairwise_gt_y). So when a user passes pairwise, I still need the default call_type in the metric in order to update the call_type.


Sorry, I was unclear about my comment. Why don't you write call_type: str = "HERE THE DEFAULT CALL TYPE"

begumcig (Member, Author) replied:


I can also add the default call_type as a string here, sure :) But I still need to keep default_call_type around for the pairwise case.

) -> None:
super().__init__(*args, **kwargs)
self.device = device
self.clip_model = CLIPVisionModelWithProjection.from_pretrained(clip_model_name).to(self.device)


Should we catch exceptions/errors here, e.g. if the clip_model_name does not exist?
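One hedged option for that suggestion (the helper name and message are illustrative, not part of pruna): wrap the load call and re-raise failures with the checkpoint name in the message. Hugging Face `from_pretrained` typically raises `OSError` for unknown repo ids.

```python
def load_with_clear_error(loader, name):
    """Call loader(name), converting load failures into a descriptive ValueError."""
    try:
        return loader(name)
    except (OSError, ValueError) as exc:
        raise ValueError(
            f"could not load pretrained weights '{name}': {exc}"
        ) from exc

# Intended use (sketch, assuming the constructor from the diff):
# self.clip_model = load_with_clear_error(
#     CLIPVisionModelWithProjection.from_pretrained, clip_model_name
# ).to(self.device)
```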

@begumcig begumcig requested a review from guennemann March 31, 2025 09:45
@begumcig begumcig force-pushed the feat/add-cmmd-metric branch from a4e3bbc to cc91753 Compare March 31, 2025 09:55
@begumcig begumcig force-pushed the feat/add-cmmd-metric branch from cc91753 to 70be8ac Compare March 31, 2025 10:14

@guennemann guennemann left a comment


Thanks. Only some minor points.

@begumcig begumcig marked this pull request as draft March 31, 2025 11:58
@begumcig begumcig marked this pull request as ready for review March 31, 2025 11:58
@begumcig begumcig marked this pull request as draft March 31, 2025 12:00
@begumcig begumcig marked this pull request as ready for review March 31, 2025 12:01
@johnrachwan123 (Member) left a comment:


Very clean PR! I left a few comments :)

@begumcig begumcig force-pushed the feat/add-cmmd-metric branch from 5619728 to 432ea87 Compare March 31, 2025 15:13
@begumcig begumcig requested a review from johnrachwan123 March 31, 2025 15:22
@johnrachwan123 (Member) left a comment:


LGTM 🚀

@begumcig begumcig merged commit 22fe4d4 into main Apr 1, 2025
8 checks passed
@johannaSommer johannaSommer deleted the feat/add-cmmd-metric branch May 19, 2025 13:32