In [19]:
import torch
import torch.nn as nn
import requests
from PIL import Image

import warnings
# Ignore specific UserWarnings related to max_length in transformers
warnings.filterwarnings("ignore", 
    message=".*Using the model-agnostic default `max_length`.*")

In [20]:

class DummyModel(nn.Module):
  """
  A dummy model that consists of an embedding layer
  with two blocks of a linear layer followed by a layer
  norm layer.
  """
  def __init__(self):
    super().__init__()

    torch.manual_seed(123)

    self.token_embedding = nn.Embedding(2, 2)

    # Block 1
    self.linear_1 = nn.Linear(2, 2)
    self.layernorm_1 = nn.LayerNorm(2)

    # Block 2
    self.linear_2 = nn.Linear(2, 2)
    self.layernorm_2 = nn.LayerNorm(2)

    self.head = nn.Linear(2, 2)

  def forward(self, x):
    hidden_states = self.token_embedding(x)

    # Block 1
    hidden_states = self.linear_1(hidden_states)
    hidden_states = self.layernorm_1(hidden_states)

    # Block 2
    hidden_states = self.linear_2(hidden_states)
    hidden_states = self.layernorm_2(hidden_states)

    logits = self.head(hidden_states)
    return logits


def get_generation(model, processor, image, dtype):
  inputs = processor(image, return_tensors="pt").to(dtype)
  out = model.generate(**inputs)
  return processor.decode(out[0], skip_special_tokens=True)


def load_image(img_url):
    image = Image.open(requests.get(
        img_url, stream=True).raw).convert('RGB')

    return image

In [21]:
model = DummyModel()
model

DummyModel(
  (token_embedding): Embedding(2, 2)
  (linear_1): Linear(in_features=2, out_features=2, bias=True)
  (layernorm_1): LayerNorm((2,), eps=1e-05, elementwise_affine=True)
  (linear_2): Linear(in_features=2, out_features=2, bias=True)
  (layernorm_2): LayerNorm((2,), eps=1e-05, elementwise_affine=True)
  (head): Linear(in_features=2, out_features=2, bias=True)
)

In [22]:
def print_param_dtype(model):
    for name, param in model.named_parameters():
        print(f"{name} is loaded in {param.dtype}")
      
print_param_dtype(model)

token_embedding.weight is loaded in torch.float32
linear_1.weight is loaded in torch.float32
linear_1.bias is loaded in torch.float32
layernorm_1.weight is loaded in torch.float32
layernorm_1.bias is loaded in torch.float32
linear_2.weight is loaded in torch.float32
linear_2.bias is loaded in torch.float32
layernorm_2.weight is loaded in torch.float32
layernorm_2.bias is loaded in torch.float32
head.weight is loaded in torch.float32
head.bias is loaded in torch.float32


In the above code a dummy neural network model is created. It consists of an embedding layer, followed by two blocks of a linear layer + normalization layer. It concludes with a linear output.

All of the layers are clearly in the float32 datatype.

In [23]:
model_fp16 = model.half() #fp16
# model_bf16 = model.bfloat16() #bf16

In [24]:
print_param_dtype(model_fp16)

token_embedding.weight is loaded in torch.float16
linear_1.weight is loaded in torch.float16
linear_1.bias is loaded in torch.float16
layernorm_1.weight is loaded in torch.float16
layernorm_1.bias is loaded in torch.float16
linear_2.weight is loaded in torch.float16
linear_2.bias is loaded in torch.float16
layernorm_2.weight is loaded in torch.float16
layernorm_2.bias is loaded in torch.float16
head.weight is loaded in torch.float16
head.bias is loaded in torch.float16


The model can be easily converted to a lower datatype using the functions available in pytorch. Here we halved the precision of the float type.

In [25]:
dummy_input = torch.LongTensor([[1, 0], [0, 1]])  # LongTensor is a 64-bit integer

In [26]:
logits_fp32 = model(dummy_input)
logits_fp32

tensor([[[-0.6870,  0.7134],
         [-0.6870,  0.7134]],

        [[-0.6870,  0.7134],
         [-0.6870,  0.7134]]], dtype=torch.float16, grad_fn=<ViewBackward0>)

In [27]:
logits_fp16 = model_fp16(dummy_input)
logits_fp16

tensor([[[-0.6870,  0.7134],
         [-0.6870,  0.7134]],

        [[-0.6870,  0.7134],
         [-0.6870,  0.7134]]], dtype=torch.float16, grad_fn=<ViewBackward0>)

In [28]:
from transformers import BlipForConditionalGeneration

model_name = "Salesforce/blip-image-captioning-base"
model = BlipForConditionalGeneration.from_pretrained(model_name)
print_param_dtype(model)

vision_model.embeddings.class_embedding is loaded in torch.float32
vision_model.embeddings.position_embedding is loaded in torch.float32
vision_model.embeddings.patch_embedding.weight is loaded in torch.float32
vision_model.embeddings.patch_embedding.bias is loaded in torch.float32
vision_model.encoder.layers.0.self_attn.qkv.weight is loaded in torch.float32
vision_model.encoder.layers.0.self_attn.qkv.bias is loaded in torch.float32
vision_model.encoder.layers.0.self_attn.projection.weight is loaded in torch.float32
vision_model.encoder.layers.0.self_attn.projection.bias is loaded in torch.float32
vision_model.encoder.layers.0.layer_norm1.weight is loaded in torch.float32
vision_model.encoder.layers.0.layer_norm1.bias is loaded in torch.float32
vision_model.encoder.layers.0.mlp.fc1.weight is loaded in torch.float32
vision_model.encoder.layers.0.mlp.fc1.bias is loaded in torch.float32
vision_model.encoder.layers.0.mlp.fc2.weight is loaded in torch.float32
vision_model.encoder.layers.0.m

In [29]:
footprint_fp32 = model.get_memory_footprint()
footprint_fp32

989660400

In [30]:
model_fp16 = BlipForConditionalGeneration.from_pretrained(model_name, torch_dtype=torch.float16)
footprint_fp16 = model_fp16.get_memory_footprint()
footprint_fp16

494832248

In [33]:
print(f"Footprint of fp16 dtype is {footprint_fp32 / footprint_fp16:.10f}x smaller")

Footprint of fp16 dtype is 1.9999917224x smaller


Models can be downcast as well by simply loading them using a smaller datatype. Here we downcast the model to fp16 from fp32.

In [34]:
desired_dtype = torch.bfloat16
torch.set_default_dtype(desired_dtype)

In [35]:
dummy_model_bf16 = DummyModel()
print_param_dtype(dummy_model_bf16)

token_embedding.weight is loaded in torch.bfloat16
linear_1.weight is loaded in torch.bfloat16
linear_1.bias is loaded in torch.bfloat16
layernorm_1.weight is loaded in torch.bfloat16
layernorm_1.bias is loaded in torch.bfloat16
linear_2.weight is loaded in torch.bfloat16
linear_2.bias is loaded in torch.bfloat16
layernorm_2.weight is loaded in torch.bfloat16
layernorm_2.bias is loaded in torch.bfloat16
head.weight is loaded in torch.bfloat16
head.bias is loaded in torch.bfloat16


We can also set the default preferred datatype using the `torch.set_default_dtype` function.