In [None]:
!python -m pip install bitsandbytes

Resources:
 - https://arxiv.org/abs/2208.07339

Given what we know about a transformer-based language model, how can we estimate how much memory a given model needs?

Let's investigate the size of the model a little more.

Since Phi-3-mini-4k is advertised as a 3.8B-parameter model, we expect roughly 3.8 billion parameters.

Let's verify that quickly:

In [3]:
import torch
from transformers import AutoModel, BitsAndBytesConfig

model_id = "microsoft/Phi-3-mini-4k-instruct"

In [4]:
model = AutoModel.from_pretrained(model_id, device_map="cuda")
sum(p.numel() for p in model.parameters()) / 1e9

3722578944

In [5]:
3722578944 / 1e9

3.722578944

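So confirmed: we have about 3.72 billion parameters to load onto the GPU. That is slightly under the advertised 3.8B, most likely because `AutoModel` loads the base model without its `lm_head` output layer.
Remember, each parameter is just a number!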
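The 3.72B figure can also be reproduced by hand from the architecture. The sketch below assumes the values from Phi-3-mini-4k-instruct's `config.json` (`vocab_size=32064`, `hidden_size=3072`, `intermediate_size=8192`, `num_hidden_layers=32`), bias-free linear layers with fused `qkv_proj` and `gate_up_proj` projections, no grouped-query attention, and weight-only RMSNorm layers. `count_phi3_params` is a hypothetical helper, not part of `transformers`.

```python
# Hyperparameters assumed from microsoft/Phi-3-mini-4k-instruct's config.json
VOCAB_SIZE = 32064
HIDDEN_SIZE = 3072
INTERMEDIATE_SIZE = 8192
NUM_LAYERS = 32


def count_phi3_params(include_lm_head: bool = False) -> int:
    """Count parameters of a Phi-3-style decoder (hypothetical helper).

    Assumes bias-free Linear layers, fused qkv / gate_up projections,
    as many key/value heads as query heads, and weight-only RMSNorm.
    """
    embed = VOCAB_SIZE * HIDDEN_SIZE                      # token embeddings
    qkv_proj = HIDDEN_SIZE * 3 * HIDDEN_SIZE              # fused query/key/value projection
    o_proj = HIDDEN_SIZE * HIDDEN_SIZE                    # attention output projection
    gate_up_proj = HIDDEN_SIZE * 2 * INTERMEDIATE_SIZE    # fused MLP gate + up projection
    down_proj = INTERMEDIATE_SIZE * HIDDEN_SIZE           # MLP down projection
    norms = 2 * HIDDEN_SIZE                               # two RMSNorm weight vectors per layer
    per_layer = qkv_proj + o_proj + gate_up_proj + down_proj + norms
    total = embed + NUM_LAYERS * per_layer + HIDDEN_SIZE  # + final RMSNorm
    if include_lm_head:
        total += HIDDEN_SIZE * VOCAB_SIZE                 # untied output projection
    return total


print(count_phi3_params())
print(count_phi3_params(include_lm_head=True))
```

Under these assumptions the base model comes to exactly 3,722,578,944 parameters, matching the count reported by `AutoModel`, and adding an untied `lm_head` (another `vocab_size * hidden_size` weights) gives 3,821,079,552, or about 3.8B.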

By default, each parameter / number is represented as a 32-bit float (4 bytes), which gives us

 (32 bits / 8 bits per byte) * 3.72e9 parameters / 1e9 bytes per GB ≈ 14.88 GB


In [6]:
(32 / 8) * 3.72

14.88

Thus, in theory, we need at least 14.88 GB of GPU memory just to load the `microsoft/Phi-3-mini-4k-instruct` weights in fp32, before counting activations or the KV cache needed at inference time.
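The same arithmetic generalizes to any data type: weight memory ≈ parameters × bits per parameter / 8 bytes. Here is a minimal sketch (`weight_memory_gb` is a hypothetical helper) that uses the exact parameter count instead of the rounded 3.72e9, so the fp32 figure comes out at about 14.89 GB. It deliberately ignores activations, the KV cache, the CUDA context, and allocator overhead, so treat the results as lower bounds.

```python
BYTES_PER_GB = 1e9


def weight_memory_gb(num_params: int, bits_per_param: float) -> float:
    """Lower bound on memory for the weights alone (decimal GB)."""
    return num_params * bits_per_param / 8 / BYTES_PER_GB


NUM_PARAMS = 3_722_578_944  # parameter count reported by AutoModel above

# Compare common storage widths for the same model
for name, bits in [("fp32", 32), ("fp16/bf16", 16), ("int8", 8), ("4-bit", 4)]:
    print(f"{name:>9}: {weight_memory_gb(NUM_PARAMS, bits):.2f} GB")
```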

Let's take a look at how memory usage is affected in practice.

In [1]:
import torch
from transformers import AutoModel, BitsAndBytesConfig

model_id = "microsoft/Phi-3-mini-4k-instruct"

In [2]:
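from typing import Optional


def check_model_size_on_gpu_mb(quantization_config: Optional[BitsAndBytesConfig]) -> float:
    """Return how many MB of GPU memory loading `model_id` allocates."""
    size_before = torch.cuda.memory_allocated()
    model = AutoModel.from_pretrained(
        model_id,
        quantization_config=quantization_config,  # None loads the unquantized model
        device_map="cuda",
    )
    # Bytes currently held by tensors (excludes memory cached but unused by PyTorch's allocator)
    model_size = torch.cuda.memory_allocated() - size_before
    del model
    return model_size / 1e6  # bytes -> MB (decimal)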

Note that I restart the Jupyter kernel between invocations, so each measurement starts from an empty GPU:

In [3]:
print("Full size:", check_model_size_on_gpu_mb(quantization_config=None))

Full size: 14890.594304

In [3]:
print("8-bit size:", check_model_size_on_gpu_mb(quantization_config=BitsAndBytesConfig(
    load_in_8bit=True
)))

8-bit size: 3859.044352

In [3]:
print("4-bit size:", check_model_size_on_gpu_mb(quantization_config=BitsAndBytesConfig(
    load_in_4bit=True
)))

4-bit size: 2240.436224

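Note that memory usage does not shrink exactly in proportion to the bit width: dividing the full size by 4 and 8 would give about 3,720 MB and 1,860 MB, yet we measure about 3,860 MB and 2,240 MB. This is mainly because we aren't quantizing every parameter: only the `Linear` layers are replaced with quantized layers, so modules such as the token embeddings stay in higher precision, and 4-bit quantization also stores extra scaling constants.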
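A rough estimator can check that explanation. It rests on stated assumptions: only the weights of `nn.Linear` layers are quantized, everything else (embeddings, RMSNorm weights) is held in 16-bit, and the 4-bit scheme adds one fp32 scaling constant per block of 64 weights. The layer shapes are assumed from the Phi-3-mini-4k-instruct config, and `estimate_quantized_mb` is a hypothetical helper.

```python
from typing import Optional

# Shapes assumed from microsoft/Phi-3-mini-4k-instruct's config.json
HIDDEN = 3072
INTERMEDIATE = 8192
VOCAB = 32064
LAYERS = 32

# Parameters inside nn.Linear layers, which bitsandbytes swaps for quantized layers
LINEAR_PARAMS = LAYERS * (
    HIDDEN * 3 * HIDDEN          # qkv_proj
    + HIDDEN * HIDDEN            # o_proj
    + HIDDEN * 2 * INTERMEDIATE  # gate_up_proj
    + INTERMEDIATE * HIDDEN      # down_proj
)
# Everything else: embeddings plus 2 RMSNorms per layer and a final RMSNorm (assumed 16-bit)
OTHER_PARAMS = VOCAB * HIDDEN + (2 * LAYERS + 1) * HIDDEN


def estimate_quantized_mb(weight_bits: int, block_size: Optional[int] = None) -> float:
    """Rough weight footprint in MB when only Linear weights are quantized.

    block_size: if given, add one fp32 (4-byte) scaling constant per block of
    Linear weights, an assumption modelled on blockwise 4-bit quantization.
    """
    total_bytes = LINEAR_PARAMS * weight_bits / 8 + OTHER_PARAMS * 2
    if block_size is not None:
        total_bytes += LINEAR_PARAMS / block_size * 4
    return total_bytes / 1e6


measured = {"8-bit": 3859.044352, "4-bit": 2240.436224}  # from the cells above
estimates = {
    "8-bit": estimate_quantized_mb(8),
    "4-bit": estimate_quantized_mb(4, block_size=64),
}
for name, est in estimates.items():
    print(f"{name}: estimated {est:.0f} MB, measured {measured[name]:.0f} MB")
```

Under these assumptions the estimates land within about 1% of the measurements (roughly 3,821 MB vs. 3,859 MB for 8-bit and 2,236 MB vs. 2,240 MB for 4-bit); the sketch does not try to account for the remaining gap.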


All the `BitsAndBytesConfig` options, as documented in the `transformers` docstring:

        load_in_8bit (`bool`, *optional*, defaults to `False`):
            This flag is used to enable 8-bit quantization with LLM.int8().
        load_in_4bit (`bool`, *optional*, defaults to `False`):
            This flag is used to enable 4-bit quantization by replacing the Linear layers with FP4/NF4 layers from
            `bitsandbytes`.
        llm_int8_threshold (`float`, *optional*, defaults to 6.0):
            This corresponds to the outlier threshold for outlier detection as described in `LLM.int8() : 8-bit Matrix
            Multiplication for Transformers at Scale` paper: https://arxiv.org/abs/2208.07339. Any hidden states value
            that is above this threshold will be considered an outlier and the operation on those values will be done
            in fp16. Values are usually normally distributed, that is, most values are in the range [-3.5, 3.5], but
            there are some exceptional systematic outliers that are very differently distributed for large models.
            These outliers are often in the interval [-60, -6] or [6, 60]. Int8 quantization works well for values of
            magnitude ~5, but beyond that, there is a significant performance penalty. A good default threshold is 6,
            but a lower threshold might be needed for more unstable models (small models, fine-tuning).
        llm_int8_skip_modules (`List[str]`, *optional*):
            An explicit list of the modules that we do not want to convert in 8-bit. This is useful for models such as
            Jukebox that have several heads in different places and not necessarily at the last position. For example,
            for `CausalLM` models, the last `lm_head` is kept in its original `dtype`.
        llm_int8_enable_fp32_cpu_offload (`bool`, *optional*, defaults to `False`):
            This flag is used for advanced use cases and users that are aware of this feature. If you want to split
            your model in different parts and run some parts in int8 on GPU and some parts in fp32 on CPU, you can use
            this flag. This is useful for offloading large models such as `google/flan-t5-xxl`. Note that the int8
            operations will not be run on CPU.
        llm_int8_has_fp16_weight (`bool`, *optional*, defaults to `False`):
            This flag runs LLM.int8() with 16-bit main weights. This is useful for fine-tuning as the weights do not
            have to be converted back and forth for the backward pass.
        bnb_4bit_compute_dtype (`torch.dtype` or str, *optional*, defaults to `torch.float32`):
            This sets the computational type, which might be different from the input type. For example, inputs might be
            fp32, but computation can be set to bf16 for speedups.
        bnb_4bit_quant_type (`str`, *optional*, defaults to `"fp4"`):
            This sets the quantization data type in the bnb.nn.Linear4Bit layers. Options are FP4 and NF4 data types
            which are specified by `fp4` or `nf4`.
        bnb_4bit_use_double_quant (`bool`, *optional*, defaults to `False`):
            This flag is used for nested quantization where the quantization constants from the first quantization are
            quantized again.
        bnb_4bit_quant_storage (`torch.dtype` or str, *optional*, defaults to `torch.uint8`):
            This sets the storage type used to pack the quantized 4-bit params.
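To make `llm_int8_threshold` concrete, here is a pure-Python sketch of the mixed-precision decomposition described in the LLM.int8() paper. Hidden dimensions containing any activation whose magnitude exceeds the threshold are multiplied in full precision; the remaining dimensions use vector-wise absmax int8 quantization (one scale per activation row, one per weight column). This illustrates the idea only and is not how `bitsandbytes` implements it (its kernels run int8 matrix multiplications on the GPU and the outlier part in fp16); `int8_matmul_with_outliers` and `absmax_quantize` are hypothetical names.

```python
def absmax_quantize(values):
    """Symmetric absmax quantization of a vector into the int8 range [-127, 127]."""
    scale = max((abs(v) for v in values), default=0.0)
    if scale == 0.0:
        return [0] * len(values), 0.0
    return [round(127 * v / scale) for v in values], scale


def int8_matmul_with_outliers(X, W, threshold=6.0):
    """Sketch of LLM.int8()-style mixed-precision X @ W (hypothetical helper).

    X: list of rows (tokens x hidden); W: list of rows (hidden x out).
    """
    hidden, out_dim = len(W), len(W[0])
    # Hidden dimensions where any token has an outlier activation
    outlier_cols = {j for row in X for j, v in enumerate(row) if abs(v) > threshold}
    regular_cols = [j for j in range(hidden) if j not in outlier_cols]

    # Per-token (row-wise) quantization of X over the regular dimensions
    x_q = [absmax_quantize([row[j] for j in regular_cols]) for row in X]
    # Per-output (column-wise) quantization of W over the same dimensions
    w_q = [absmax_quantize([W[j][k] for j in regular_cols]) for k in range(out_dim)]

    result = []
    for i, row in enumerate(X):
        xq, sx = x_q[i]
        out_row = []
        for k in range(out_dim):
            wq, sw = w_q[k]
            # Integer dot product, then dequantize with both absmax scales
            acc = sum(a * b for a, b in zip(xq, wq))
            value = acc * sx * sw / (127 * 127)
            # Outlier dimensions: ordinary full-precision multiply-accumulate
            value += sum(row[j] * W[j][k] for j in outlier_cols)
            out_row.append(value)
        result.append(out_row)
    return result


X = [[0.5, -1.2, 20.0], [1.0, 0.3, -15.0]]  # third hidden dimension holds outliers
W = [[0.1, -0.2], [0.4, 0.05], [0.02, 0.3]]
approx = int8_matmul_with_outliers(X, W)
exact = [[sum(x * w for x, w in zip(row, col)) for col in zip(*W)] for row in X]
print(approx)
print(exact)
```

In this toy case the third hidden dimension goes through the full-precision path, and the mixed-precision result stays within about 0.001 of the exact product.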