
Better documentation of Global Batch Size #2628

@stefan-it

Description


Documentation

Hi,

I am currently trying to figure out how the global batch size is calculated.

The only hint I was able to find is here.

But it only documents per_device_batch_size, whose description says: "Sets the local batch size per accelerator chip."

For example, on a v6e-8 TPU, the following code returns the number of chips:

```python
import jax
jax.local_device_count("tpu")
```

It returns 8. But this gets complicated with a multi-host slice such as a v6e-32, where the chips are spread across several hosts. So far, the most reliable way I have found to get the number of chips is to look it up in the documentation.
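A minimal sketch of where the difference comes from, assuming the same script runs on every host of the slice (as is usual for multi-host TPU jobs), that JAX is already set up on each host, and that each chip shows up as one JAX device, which is the case on v6e. The helper name slice_device_counts is hypothetical; the three calls inside it are standard JAX functions. jax.local_device_count() only counts the devices attached to the host it runs on, while jax.device_count() counts the devices of all hosts together.

```python
def slice_device_counts() -> tuple[int, int, int]:
    """Return (chips on this host, chips in the whole slice, number of hosts).

    Hypothetical helper for illustration; each value comes straight from JAX
    and assumes one JAX device per chip (true for TPU v6e).
    """
    # Imported inside the function so this file can be loaded without JAX.
    import jax

    local_chips = jax.local_device_count()  # devices attached to this host only
    total_chips = jax.device_count()        # devices across every host in the slice
    num_hosts = jax.process_count()         # number of JAX processes (typically one per host)
    return local_chips, total_chips, num_hosts
```

Calling slice_device_counts() on each host of a v6e-32 should report total_chips == 32 on every host, while local_chips is only that host's share. It is total_chips, not the local count, that enters the rule of thumb.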

That means, as a rule of thumb, "global_batch_size = per_device_batch_size * number of chips". It would be amazing if this were documented (at least for TPU usage).
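The rule of thumb written out as a worked equation, assuming it holds as stated: B is the global batch size, b is per_device_batch_size and N is the number of chips in the whole slice (symbols chosen here for illustration). The value b = 4 is only an example:

```math
\begin{aligned}
B &= b \times N \\
\text{v6e-8:}\quad B &= 4 \times 8 = 32 \\
\text{v6e-32:}\quad B &= 4 \times 32 = 128
\end{aligned}
```

With per_device_batch_size kept fixed, moving from a v6e-8 to a v6e-32 would quadruple the global batch size.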

Metadata


Assignees

Labels

documentation (Improvements or additions to documentation)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
