Find the maximum value for any dimension your PyTorch models can handle without running out of memory.
Batch Finder automatically detects your model's inputs (type and shape), fixes the dimensions you specify, and finds the maximum value for the remaining axis using a configurable search strategy.
- 🎯 Single unified API – One function `find_max_minibatch` for all cases
- 🔍 Automatic input detection – Infers input names, types (int/float), and shapes from the model
- 📐 Flexible shape specification – Use `input_shape` with `-1` for the variable axis, or `axis_to_maximize` + `fixed_axis`
- 🚀 Inference or full backward – Test with or without gradients
- ⚙️ Configurable search – Customize `factor_down`, `factor_up`, `n_attempts`, `initial_value`
- 🛡️ Safe testing – Error handling, memory cleanup, returns `None` if it fails at value 1
- 📊 Progress tracking – tqdm progress bar with status in the postfix
```
pip install batch-finder
```

Or install from source:

```
git clone https://github.com/yourusername/batch-finder.git
cd batch-finder
pip install -e .
```

Use a tuple with `-1` for the axis to maximize and numbers for fixed dimensions:
```python
from batch_finder import find_max_minibatch

model = MyModel()

# Maximize axis 0, fix (64, 256)
max_val = find_max_minibatch(
    model=model,
    input_shape=(-1, 64, 256),
    initial_value=64,
)

# Maximize axis 2, fix (4, 8)
max_val = find_max_minibatch(model=model, input_shape=(4, 8, -1), initial_value=32)

# Multiple -1: same value for all variable axes
max_val = find_max_minibatch(model=model, input_shape=(-1, 4, -1, 16), initial_value=32)
```

For HuggingFace models, specify the axis by name instead of by position:

```python
from transformers import AutoModelForCausalLM
from batch_finder import find_max_minibatch

model = AutoModelForCausalLM.from_pretrained("distilgpt2")

max_batch = find_max_minibatch(
    model=model,
    axis_to_maximize="batch_size",
    fixed_axis={"seq_len": 32},
    initial_value=32,
)
print(f"Max batch size: {max_batch}")
```

The search strategy is configurable:

```python
max_val = find_max_minibatch(
    model=model,
    input_shape=(-1, 128, 512),
    initial_value=8,
    n_attempts=30,
    factor_down=3.0,  # divide by 3 on failure
    factor_up=2.0,    # multiply by 2 on success
)
```

`find_max_minibatch` – Find the maximum value for the modifiable axis without OOM.
Parameters:

| Parameter | Type | Default | Description |
|---|---|---|---|
| `model` | `torch.nn.Module` | – | PyTorch or HuggingFace model |
| `input_shape` | `Tuple[int, ...]` | `None` | Shape with `-1` for the variable axis(es), e.g. `(-1, 64, 256)` |
| `axis_to_maximize` | `str` | `None` | Axis name when not using `input_shape`, e.g. `"batch_size"` |
| `fixed_axis` | `Dict[str, int]` | `{}` | Fixed values, e.g. `{"seq_len": 128}` |
| `device` | `torch.device` | auto | Device to run on |
| `delay` | `float` | `3.0` | Seconds between attempts |
| `initial_value` | `int` | `1024` | First value to try |
| `n_attempts` | `int` | `50` | Maximum attempts |
| `inference_only` | `bool` | `False` | If `True`, no gradients; if `False`, full forward+backward |
| `factor_down` | `float` | `2.0` | On failure: `next = value / factor_down` |
| `factor_up` | `float` | `2.0` | On success: `next = value * factor_up` |
Returns: `Tuple[int, ...]` (when using `input_shape`), `int` (when using `axis_to_maximize`), or `None` if no value passes.
Modes:
- Provide `input_shape`: uses the first input param with the given shape; `-1` marks the variable axis.
- Provide `axis_to_maximize` + `fixed_axis`: builds inputs from detected params and naming conventions.
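The first mode boils down to filling every `-1` slot with the candidate value. A minimal sketch of that step (`resolve_shape` is an illustrative helper, not part of the package's API):

```python
from typing import Tuple

def resolve_shape(input_shape: Tuple[int, ...], value: int) -> Tuple[int, ...]:
    # Replace every -1 with the candidate value; fixed axes pass through unchanged.
    return tuple(value if dim == -1 else dim for dim in input_shape)

print(resolve_shape((-1, 64, 256), 32))   # (32, 64, 256)
print(resolve_shape((-1, 4, -1, 16), 8))  # (8, 4, 8, 16) – same value for all -1 axes
```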
Example output (`axis_to_maximize` + `fixed_axis`):

```
--- Detected inputs (type, estimated shape) ---
input_ids: integer, (32, 64)
attention_mask: integer, (32, 64)
---
batch_size fixed={'seq_len': 32}: 100%|████████████████████| 22/50 [01:26<00:00, 3.9s/it, gpus=1, i=22/50, max_ok=1919, min_fail=1920, status=✅, value=1919]
✅ Max value that passed: 1919
```
- Input detection – Uses `inspect.signature` on `model.forward` to find input names.
- Type inference – Integer for `*ids`, `*mask`, `labels`; float for others.
- Shape estimation – From the model (`Linear.in_features`, config, etc.) and param-name conventions.
- Search – On success: try `value * factor_up`. On failure: try `value / factor_down`. Stops when value 1 fails or `n_attempts` is reached.
- Loss – Uses `output.loss` if present, else the sum of all output tensors.
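The detection and type-inference steps can be sketched as follows. This is an illustration of the conventions listed above, not the package's implementation; `TinyModel` is a plain class standing in for a `torch.nn.Module`, and the helper names are hypothetical:

```python
import inspect

class TinyModel:
    # Stand-in for a torch.nn.Module with a typical HF-style forward signature.
    def forward(self, input_ids, attention_mask=None, labels=None):
        pass

def detect_input_names(model):
    # Read parameter names from model.forward's signature
    # (self is already excluded on a bound method).
    sig = inspect.signature(model.forward)
    return [name for name, p in sig.parameters.items()
            if p.kind in (p.POSITIONAL_OR_KEYWORD, p.KEYWORD_ONLY)]

def infer_dtype(name):
    # Name conventions: *ids, *mask, labels -> integer; everything else -> float.
    if name.endswith("ids") or name.endswith("mask") or name == "labels":
        return "integer"
    return "float"

names = detect_input_names(TinyModel())
print([(n, infer_dtype(n)) for n in names])
# [('input_ids', 'integer'), ('attention_mask', 'integer'), ('labels', 'integer')]
```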
- Memory: Use a conservative `initial_value` on limited GPU memory.
- Time: Use `inference_only=True` for faster runs.
- Training: Use `inference_only=False` to stress-test with the backward pass.
- Value 1: If the run fails at value 1, the function returns `None` (there is no smaller value to try).
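The grow-on-success / shrink-on-failure search and the value-1 behavior can be sketched as below. This is not the package's implementation: `try_run` is a hypothetical callback standing in for the real forward/backward attempt, and once the boundary is bracketed this sketch refines it by bisection, which may differ from the package's exact refinement:

```python
def geometric_search(try_run, initial_value=1024,
                     factor_down=2.0, factor_up=2.0, n_attempts=50):
    """try_run(value) -> True if the attempt fits in memory, False on OOM."""
    max_ok = None    # largest value that passed so far
    min_fail = None  # smallest value that failed so far
    value = initial_value
    for _ in range(n_attempts):
        ok = try_run(value)
        if ok:
            max_ok = value
        else:
            if value == 1:
                return None  # fails even at value 1: no usable value
            min_fail = value
        if max_ok is not None and min_fail is not None:
            if min_fail - max_ok <= 1:
                break  # boundary found
            value = (max_ok + min_fail) // 2  # bisect once bracketed
        elif ok:
            value = int(value * factor_up)
        else:
            value = max(1, int(value / factor_down))
    return max_ok

# Simulated "GPU" that fits values up to 1919, matching the example output above.
print(geometric_search(lambda v: v <= 1919, initial_value=32))  # 1919
```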
Contributions are welcome! Please feel free to submit a Pull Request.
This project is licensed under the MIT License - see the LICENSE file for details.
Made with ❤️ for the PyTorch community