# Batch size calculator

We want to calculate the micro batch size, measured as the number of sequences of length `seq_len` that fit on the GPU.

There are multiple ways we can do this:

1. Convert a batch size for sequence length A into a batch size for sequence length B.


## 1. Convert a batch size for sequence length A into a batch size for sequence length B

Useful when we know the batch size in tokens for one sequence length and want the equivalent batch size, in sequences, for another sequence length. The number of tokens per batch is held constant.

Input parameters:

- `batch_size_A`: batch size (number of sequences) for sequence length A
- `seq_len_A`: sequence length A
- `seq_len_B`: sequence length B

Output:

- `batch_size_B`: batch size (number of sequences) for sequence length B, containing the same `batch_size_A * seq_len_A` tokens
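Written as a formula, with $B$ for batch size in sequences and $L$ for sequence length (symbols introduced only for this formula), the conversion keeps the token count fixed and requires $L_B$ to divide it evenly:

$$
B_B = \frac{B_A \cdot L_A}{L_B}, \qquad \text{e.g. } B_B = \frac{64 \cdot 512}{1024} = 32
$$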

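```python
def convert_batch_size_to_new_seq_len(
    batch_size_A: int, seq_len_A: int, seq_len_B: int
) -> int:
    """Return the batch size (in sequences) for seq_len_B that holds the same
    number of tokens as batch_size_A sequences of length seq_len_A."""
    # Total tokens per batch; this is the quantity held constant.
    total_batch_size_A = batch_size_A * seq_len_A
    # The token count must split evenly into sequences of length seq_len_B.
    assert total_batch_size_A % seq_len_B == 0, "seq_len_B must divide the token count"
    batch_size_B = total_batch_size_A // seq_len_B
    print(f"batch_size_A: {batch_size_A}")
    print(f"seq_len_A: {seq_len_A}")
    print(f"seq_len_B: {seq_len_B}")
    print(f"batch_size_B: {batch_size_B}")
    return batch_size_B


# Known batch: 32768 tokens at sequence length 512, i.e. 64 sequences.
input_seq_len = 512
input_batch_size = 32768 // input_seq_len
output_seq_len = 1024

convert_batch_size_to_new_seq_len(input_batch_size, input_seq_len, output_seq_len)
```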

```
batch_size_A: 64
seq_len_A: 512
seq_len_B: 1024
batch_size_B: 32

32
```
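`convert_batch_size_to_new_seq_len` asserts that the token count divides evenly by `seq_len_B`. When it does not, one option is to round down so the converted batch never holds more tokens than the original. The sketch below assumes rounding down is acceptable; `convert_batch_size_floor` is a hypothetical helper, not part of the notebook.

```python
def convert_batch_size_floor(batch_size_A: int, seq_len_A: int, seq_len_B: int) -> int:
    """Hypothetical variant: the largest batch size (in sequences) for seq_len_B
    whose token count does not exceed batch_size_A * seq_len_A."""
    total_tokens = batch_size_A * seq_len_A
    # Floor division drops any leftover tokens instead of failing.
    batch_size_B = total_tokens // seq_len_B
    if batch_size_B == 0:
        raise ValueError("seq_len_B is longer than the whole token budget")
    return batch_size_B


print(convert_batch_size_floor(64, 512, 1024))  # evenly divisible: 32
print(convert_batch_size_floor(64, 512, 3000))  # 32768 // 3000 = 10
```

Rounding down trades a few unused tokens per batch for never exceeding the token count that was known to fit.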

## 2. Calculate iterations given total batch size and dataset size

Input parameters:

- `total_batch_size`: total batch size in tokens (tokens processed per iteration)
- `dataset_size`: dataset size in tokens

Output:

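- `num_iterations`: number of iterations needed to process the whole dataset once, rounded up so that a final partial batch counts as one iteration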
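The iteration count is a ceiling division. With $D$ the dataset size and $T$ the total batch size, both in tokens (symbols introduced only for this formula):

$$
N = \left\lceil \frac{D}{T} \right\rceil, \qquad \text{e.g. } \left\lceil \frac{10^9}{524288} \right\rceil = \lceil 1907.35\ldots \rceil = 1908
$$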


```python
def calculate_iterations(total_batch_size: int, dataset_size: int) -> int:
    """Number of iterations needed to see every token once (ceiling division)."""
    if dataset_size % total_batch_size == 0:
        num_iterations = dataset_size // total_batch_size
    else:
        # One extra iteration for the final, partially filled batch.
        num_iterations = dataset_size // total_batch_size + 1
    return num_iterations


total_batch_size = 524288  # 2**19 tokens
print(f"total_batch_size: {total_batch_size}")
print("-" * 100)
# Each entry is (name, dataset size in tokens).
for dataset in [
    ("tinystories", 925_653_391),
    ("1B", 1_000_000_000),
    ("10B", 10_000_000_000),
    ("100B", 100_000_000_000),
    ("1T", 1_000_000_000_000),
]:
    print(f"dataset_size: {dataset}")
    num_iterations = calculate_iterations(total_batch_size, dataset[1])
    print(f"num_iterations: {num_iterations}")
    print("-" * 100)
```


```
total_batch_size: 524288
----------------------------------------------------------------------------------------------------
dataset_size: ('tinystories', 925653391)
num_iterations: 1766
----------------------------------------------------------------------------------------------------
dataset_size: ('1B', 1000000000)
num_iterations: 1908
----------------------------------------------------------------------------------------------------
dataset_size: ('10B', 10000000000)
num_iterations: 19074
----------------------------------------------------------------------------------------------------
dataset_size: ('100B', 100000000000)
num_iterations: 190735
----------------------------------------------------------------------------------------------------
dataset_size: ('1T', 1000000000000)
num_iterations: 1907349
----------------------------------------------------------------------------------------------------
```
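The micro batch size from section 1 and the total batch size from section 2 can be linked under one assumption: data-parallel training where each GPU processes one micro batch per forward/backward pass and gradients are accumulated until `total_batch_size` tokens have been seen. The helper `grad_accum_steps`, the `num_gpus` parameter, and the choice of 8 GPUs below are all hypothetical, chosen only for illustration.

```python
def grad_accum_steps(
    total_batch_size: int, micro_batch_size: int, seq_len: int, num_gpus: int
) -> int:
    """Hypothetical helper: micro batches each GPU accumulates so that
    micro_batch_size * seq_len * num_gpus * steps == total_batch_size (tokens)."""
    # Tokens processed across all GPUs in one forward/backward pass.
    tokens_per_micro_step = micro_batch_size * seq_len * num_gpus
    if total_batch_size % tokens_per_micro_step != 0:
        raise ValueError("total_batch_size must be divisible by tokens per micro step")
    return total_batch_size // tokens_per_micro_step


# 32 sequences of length 1024 per GPU (32768 tokens), assumed 8 GPUs.
print(grad_accum_steps(524288, 32, 1024, 8))  # 524288 // 262144 = 2
```

On a single GPU the same micro batch would need 16 accumulation steps to reach 524288 tokens.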
