Conversation

@johannaSommer (Member)

Description

This PR refactors the usage of batch size arguments across various methods. The user now sets the expected inference batch size once in the SmashConfig, and this value is used throughout the algorithms. The batch_size hyperparameters in the algorithms are deprecated accordingly. Additionally, I deprecated the max_batch_size name and renamed it to batch_size, as "max" might be misleading.
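
For illustration, the intended usage is roughly the following (a minimal sketch; the exact parameter form may differ):

from pruna import SmashConfig

# Sketch: the expected inference batch size is set once on the SmashConfig and
# reused by all algorithms; batch_size as a constructor argument is an
# assumption made here for illustration.
config = SmashConfig(batch_size=4)
config["cacher"] = "deepcache"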

Related Issue

None.

Type of Change

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

How Has This Been Tested?

Ran tests for all affected methods and tested the deprecation locally; it works as intended.

Checklist

  • My code follows the style guidelines of this project
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Additional Notes

None.

@gsprochette (Collaborator) left a comment

Looks pretty good to me; I'm looking forward to using this unified batch_size and having access to data from outside the smash_config :) I only suggested a couple of minor changes, which should take just a minute.

@johannaSommer force-pushed the feat/batch-size-refactor branch from 3f598b9 to 27a87e3 on April 20, 2025 13:00
@begumcig (Member) left a comment

Everything already looks super solid! I just left a small comment regarding batch_size mismatches between the pipeline and the dataloader and how they could affect evaluation.
Great job overall 🥹

Member

Does the batch_size argument here override the calib_data's batch_size?

Member Author

Yeah, good question. In this case the calib data is the string of all text snippets from the dataset as a whole and doesn't have an inherent batch size... GPTQ then slices and embeds it as necessary.
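
Schematically (illustrative only, not the actual GPTQ code), the calibration text is tokenized once and then cut into fixed-length slices, so no dataloader batch size comes into play:

# Illustrative sketch, not GPTQ's implementation: one long token sequence is
# sliced into fixed-length calibration samples on demand.
def slice_calibration_tokens(token_ids: list[int], seq_len: int) -> list[list[int]]:
    return [token_ids[i : i + seq_len] for i in range(0, len(token_ids) - seq_len + 1, seq_len)]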

Member

Okay, just wanted to flag two things here:

1. Since this batch_size can directly impact inference performance (latency, memory, etc.), I'm a bit concerned that changes here (or in any future algorithm that plays with this setting) could unintentionally affect our evaluation metrics, especially when comparing against the base model. Would it make sense to pass a config to the evaluation agent as well, so we're running everything under the same conditions?

2. This is not a problem, more of a question: what happens if the data we pass later (e.g., from a DataLoader) is already batched differently? Does this lead to re-batching under the hood? It might be worth double-checking that the batch sizes align, or that we're not unintentionally introducing extra batching logic from the pipeline itself.

Member Author

After our discussion, I added the following warning:

from pruna.evaluation.evaluation_agent import EvaluationAgent
from pruna.evaluation.task import Task
from pruna.data.pruna_datamodule import PrunaDataModule
import torch
from diffusers import StableDiffusionPipeline
from pruna import SmashConfig, smash

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16)

config = SmashConfig() # batch size is 1 by default
config["cacher"] = "deepcache"
smashed_pipe = smash(pipe, config)

task = Task(["gpu_memory"], datamodule=PrunaDataModule.from_string("LAION256", dataloader_args={"batch_size": 3}))
agent = EvaluationAgent(task)
agent.evaluate(smashed_pipe)

would now output:

INFO - Starting cacher deepcache...
INFO - cacher deepcache was applied successfully.
INFO - Loaded only training, splitting train 80/10/10 into train, validation and test...
INFO - Testing compatibility with functools.partial(<function image_generation_collate at 0x7f6950b86f80>, img_size=512)...
INFO - Creating metrics from names: ['gpu_memory']
INFO - Evaluating a smashed model.
INFO - Detected diffusers model. Using DiffuserHandler with fixed seed.
- The first element of the batch is passed as input.
- The generated outputs are expected to have .images attribute.
WARNING - Batch size mismatch between evaluation datamodule and smashed model's smash config. This may lead to incorrect metric computation due to compression algorithms being batch size specific. Adjust the datamodule creation to match the smashed model's batch size, e.g., datamodule = PrunaDataModule.from_string(dataset_name, dataloader_args={'batch_size': 1})
INFO - Evaluating stateful metrics.
INFO - Evaluating isolated inference metrics.
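
With the datamodule batch size matched to the smash config's batch size (1 by default here), the warning should no longer appear, e.g.:

task = Task(["gpu_memory"], datamodule=PrunaDataModule.from_string("LAION256", dataloader_args={"batch_size": 1}))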

@johannaSommer force-pushed the feat/batch-size-refactor branch from e849161 to 7710ef3 on May 2, 2025 15:38
        model.inference_handler.log_model_info()
        if (
            "batch_size" in self.task.datamodule.dataloader_args
            and self.task.datamodule.dataloader_args["batch_size"] != model.smash_config.batch_size
Member

This is a step in the right direction! My only concern is that since the smash_config always includes a batch_size attribute by default, we might end up showing this warning every time, even when the model itself isn't changing in a way that would affect inference. Ideally, the warning should only appear if the model will actually run inference with a different batch size internally.

I'm not entirely sure how to reliably detect that. But if this issue only occurs with specific batching algorithms (is this the case?), maybe we could check whether the smashing algorithm is a "batcher" instead of just relying on the batch_size in the config. What do you think?

@johannaSommer requested a review from begumcig on May 6, 2025 09:15
"batch_size" in self.task.datamodule.dataloader_args
and self.task.datamodule.dataloader_args["batch_size"] != model.smash_config.batch_size
and not is_base
and model.smash_config.is_batch_size_locked()
Member

💅💅💅

@gsprochette (Collaborator) left a comment

There's only the matter of the deprecation warning message: a one-line change if you agree. Other than that, it looks super good to me 💅 Thanks for taking care of this!

        self.max_batch_size = max_batch_size
        if max_batch_size is not None:
            warn(
                "max_batch_size is soon to be deprecated. Please use batch_size instead.",
Collaborator

Why "soon to be" and not "is"?

Member Author

Yeah, you're completely right 🙈 Thanks!

@johannaSommer merged commit 68cbf37 into main on May 7, 2025
6 checks passed
@johannaSommer deleted the feat/batch-size-refactor branch on May 7, 2025 12:32