feat: max_batch_size refactoring
#67
Conversation
gsprochette
left a comment
Looks pretty much good to me; I'm looking forward to using this unified batch_size and to having access to data from outside the smash_config :) I only suggested a couple of minor changes, which should take just a minute.
3f598b9 to 27a87e3
begumcig
left a comment
Everything already looks super solid! I just left a small comment regarding batch_size mismatches between the pipeline and the dataloader and how they could affect evaluation.
Great job overall 🥹
Does the batch_size argument here override the calib_data's batch_size?
Yeah, good question. In this case the calib data is the string of all text snippets from the dataset as a whole and doesn't have an inherent batch size... and then GPTQ slices and embeds as necessary.
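For readers unfamiliar with this pattern, here is a minimal sketch of what "slicing a flat calibration string into batches" can look like. This is illustrative only, not pruna's or GPTQ's actual code; the tokenizer choice, batch size, and sequence length are assumptions.

```python
# Illustrative sketch: a flat calibration string is tokenized once and then
# sliced into fixed-size (batch_size, seq_len) blocks, independent of any
# dataloader batch size. Tokenizer and sizes below are assumptions.
from transformers import AutoTokenizer

calib_text = " ".join(["some text snippet"] * 1000)  # dataset joined into one string
tokenizer = AutoTokenizer.from_pretrained("gpt2")

batch_size, seq_len = 4, 128
ids = tokenizer(calib_text, return_tensors="pt").input_ids[0]

# Drop the remainder so the ids reshape cleanly into (num_blocks, batch_size, seq_len).
num_blocks = ids.numel() // (batch_size * seq_len)
blocks = ids[: num_blocks * batch_size * seq_len].view(num_blocks, batch_size, seq_len)

for batch in blocks:  # each batch has shape (batch_size, seq_len)
    ...  # run the model on `batch` to collect calibration statistics
```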
src/pruna/algorithms/batching/ifw.py
Outdated
Okay, just wanted to flag two things here:
1. Since this batch_size can directly impact inference performance (latency, memory, etc.), I'm a bit concerned that changes here (or in any future algorithm that plays with this setting) could unintentionally affect our evaluation metrics, especially when comparing against the base model. Would it make sense to pass a config to the evaluation agent as well, so we're running everything under the same conditions?
2. This is more of a question than a problem: what happens if the data we pass later (e.g., from a DataLoader) is already batched differently? Does this lead to re-batching under the hood? Might be worth double-checking that the batch sizes align, or that we're not unintentionally introducing extra batching logic from the pipeline itself.
After our discussion, I added the following warning:
from pruna.evaluation.evaluation_agent import EvaluationAgent
from pruna.evaluation.task import Task
from pruna.data.pruna_datamodule import PrunaDataModule
import torch
from diffusers import StableDiffusionPipeline
from pruna import SmashConfig, smash
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16)
config = SmashConfig() # batch size is 1 by default
config["cacher"] = "deepcache"
smashed_pipe = smash(pipe, config)
task = Task(["gpu_memory"], datamodule=PrunaDataModule.from_string("LAION256", dataloader_args={"batch_size": 3}))
agent = EvaluationAgent(task)
agent.evaluate(smashed_pipe)
This would now output:
INFO - Starting cacher deepcache...
INFO - cacher deepcache was applied successfully.
INFO - Loaded only training, splitting train 80/10/10 into train, validation and test...
INFO - Testing compatibility with functools.partial(<function image_generation_collate at 0x7f6950b86f80>, img_size=512)...
INFO - Creating metrics from names: ['gpu_memory']
INFO - Evaluating a smashed model.
INFO - Detected diffusers model. Using DiffuserHandler with fixed seed.
- The first element of the batch is passed as input.
- The generated outputs are expected to have .images attribute.
WARNING - Batch size mismatch between evaluation datamodule and smashed model's smash config. This may lead to incorrect metric computation due to compression algorithms being batch size specific. Adjust the datamodule creation to match the smashed model's batch size, e.g., datamodule = PrunaDataModule.from_string(dataset_name, dataloader_args={'batch_size': 1})
INFO - Evaluating stateful metrics.
INFO - Evaluating isolated inference metrics.
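Following the hint in the warning text itself, the mismatch goes away if the datamodule's batch size matches the smash config's batch size (1 by default in the example above). Reusing the imports and smashed_pipe from the snippet above:

```python
# Align the dataloader batch size with the smash config's batch size (1 by default
# in the example above) so the mismatch warning is not emitted.
task = Task(
    ["gpu_memory"],
    datamodule=PrunaDataModule.from_string("LAION256", dataloader_args={"batch_size": 1}),
)
agent = EvaluationAgent(task)
agent.evaluate(smashed_pipe)
```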
e849161 to 7710ef3
model.inference_handler.log_model_info()
if (
    "batch_size" in self.task.datamodule.dataloader_args
    and self.task.datamodule.dataloader_args["batch_size"] != model.smash_config.batch_size
This is a step in the right direction! My only concern is that since the smash_config always includes a batch_size attribute by default, we might end up showing this warning every time, even when the model itself isn't changing in a way that would affect inference. Ideally, the warning should only appear if the model will actually run inference with a different batch size internally.
I'm not entirely sure how to reliably detect that. But if this issue only occurs with specific batching algorithms (is this actually the case?), maybe we could check whether the smashing algorithm is a "batcher" instead of just relying on the batch_size in the config, for instance. What do you think?
| "batch_size" in self.task.datamodule.dataloader_args | ||
| and self.task.datamodule.dataloader_args["batch_size"] != model.smash_config.batch_size | ||
| and not is_base | ||
| and model.smash_config.is_batch_size_locked() |
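For context, a hedged sketch of what a check like is_batch_size_locked() could look like, following the suggestion above to key the warning off batch-size-specific algorithms. This is not pruna's actual implementation: the group names, the `config[group] is None` convention, and writing it as a free function (the diff calls it as a SmashConfig method) are all assumptions for illustration.

```python
# Hedged sketch, not pruna's actual code: the batch size is only considered "locked"
# when an algorithm group that is batch-size specific has been configured.
BATCH_SIZE_LOCKING_GROUPS = ("batcher", "cacher")  # assumed group names

def is_batch_size_locked(smash_config) -> bool:
    """Return True if any configured algorithm makes inference batch-size specific."""
    return any(smash_config[group] is not None for group in BATCH_SIZE_LOCKING_GROUPS)
```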
💅💅💅
gsprochette
left a comment
There's only the matter of the deprecation warning message: a one-line change if you agree. Other than that it looks super good to me 💅 Thanks for taking care of this!
src/pruna/config/smash_config.py
Outdated
self.max_batch_size = max_batch_size
if max_batch_size is not None:
    warn(
        "max_batch_size is soon to be deprecated. Please use batch_size instead.",
Why "soon to be" and not "is"?
Yeah, you're completely right 🙈 thanks!
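The agreed one-line change presumably ends up looking something like the sketch below. Only the message wording ("is deprecated" instead of "soon to be deprecated") comes from the discussion above; the _resolve_batch_size helper, the FutureWarning category, and the fallback behaviour are hypothetical.

```python
from warnings import warn

# Hypothetical helper for illustration only; the message wording reflects the change
# agreed above, while FutureWarning and the fallback return are assumptions.
def _resolve_batch_size(batch_size: int = 1, max_batch_size: int | None = None) -> int:
    if max_batch_size is not None:
        warn(
            "max_batch_size is deprecated. Please use batch_size instead.",
            FutureWarning,
        )
        return max_batch_size
    return batch_size
```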
Description
This PR refactors the batch size arguments used by various methods. The user now sets the expected inference batch size once in the SmashConfig, and it is used throughout the algorithms. The batch_size hyperparameters of the individual algorithms are deprecated accordingly. Additionally, I deprecated the name max_batch_size and renamed it to batch_size, as "max" might be misleading.
Related Issue
None.
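As a small illustration of the workflow described above (a hedged sketch: passing batch_size to the SmashConfig constructor is an assumption, mirrored from the deprecated max_batch_size argument shown in the diff earlier in this thread; the rest follows the evaluation example above):

```python
import torch
from diffusers import StableDiffusionPipeline
from pruna import SmashConfig, smash

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
)

# The inference batch size is set once on the config and no longer per algorithm.
# The constructor keyword is assumed from the deprecated max_batch_size argument.
config = SmashConfig(batch_size=4)
config["cacher"] = "deepcache"  # algorithms read the batch size from the config

smashed_pipe = smash(pipe, config)
```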
Type of Change
How Has This Been Tested?
Ran tests for all affected methods and tested the deprecation locally - works as intended.
Checklist
Additional Notes
None.