In [2]:
import torch
import torch.nn as nn
import requests
from PIL import Image


In [3]:
class DummyModel(nn.Module):
  """
  A dummy model that consists of an embedding layer
  with two blocks of a linear layer followed by a layer
  norm layer.
  """
  def __init__(self):
    super().__init__()

    torch.manual_seed(123)

    self.token_embedding = nn.Embedding(2, 2)

    # Block 1
    self.linear_1 = nn.Linear(2, 2)
    self.layernorm_1 = nn.LayerNorm(2)

    # Block 2
    self.linear_2 = nn.Linear(2, 2)
    self.layernorm_2 = nn.LayerNorm(2)

    self.head = nn.Linear(2, 2)

  def forward(self, x):
    hidden_states = self.token_embedding(x)

    # Block 1
    hidden_states = self.linear_1(hidden_states)
    hidden_states = self.layernorm_1(hidden_states)

    # Block 2
    hidden_states = self.linear_2(hidden_states)
    hidden_states = self.layernorm_2(hidden_states)

    logits = self.head(hidden_states)
    return logits


def get_generation(model, processor, image, dtype):
  inputs = processor(image, return_tensors="pt").to(dtype)
  out = model.generate(**inputs)
  return processor.decode(out[0], skip_special_tokens=True)


def load_image(img_url):
    image = Image.open(requests.get(
        img_url, stream=True).raw).convert('RGB')

    return image


In [4]:
model = DummyModel()

In [5]:
model

DummyModel(
  (token_embedding): Embedding(2, 2)
  (linear_1): Linear(in_features=2, out_features=2, bias=True)
  (layernorm_1): LayerNorm((2,), eps=1e-05, elementwise_affine=True)
  (linear_2): Linear(in_features=2, out_features=2, bias=True)
  (layernorm_2): LayerNorm((2,), eps=1e-05, elementwise_affine=True)
  (head): Linear(in_features=2, out_features=2, bias=True)
)

In [8]:
def print_param_dtype(model):
    for name, param in model.named_parameters():
        print(f"{name} -- {param.dtype}")

In [9]:
print_param_dtype(model)

token_embedding.weight -- torch.float32
linear_1.weight -- torch.float32
linear_1.bias -- torch.float32
layernorm_1.weight -- torch.float32
layernorm_1.bias -- torch.float32
linear_2.weight -- torch.float32
linear_2.bias -- torch.float32
layernorm_2.weight -- torch.float32
layernorm_2.bias -- torch.float32
head.weight -- torch.float32
head.bias -- torch.float32


### Model Casting: float16

In [10]:
model_fp16 = DummyModel().half()

In [11]:
print_param_dtype(model_fp16)

token_embedding.weight -- torch.float16
linear_1.weight -- torch.float16
linear_1.bias -- torch.float16
layernorm_1.weight -- torch.float16
layernorm_1.bias -- torch.float16
linear_2.weight -- torch.float16
linear_2.bias -- torch.float16
layernorm_2.weight -- torch.float16
layernorm_2.bias -- torch.float16
head.weight -- torch.float16
head.bias -- torch.float16


In [14]:
dummy_input = torch.LongTensor([[1, 0], [0, 1]])

In [16]:
logits_fp32 = model(dummy_input)

logits_fp32

tensor([[[-0.6872,  0.7132],
         [-0.6872,  0.7132]],

        [[-0.6872,  0.7132],
         [-0.6872,  0.7132]]], grad_fn=<ViewBackward0>)

In [17]:
try:
    logits_fp16 = model_fp16(dummy_input)
except Exception as error:
    print("\033[91m", type(error).__name__, ": ", error, "\033[0m")

### Model Casting: bfloat16

In [18]:
from copy import deepcopy

In [19]:
model_bf16 = deepcopy(model)

In [20]:
model_bf16 = model_bf16.to(torch.bfloat16)

In [21]:
print_param_dtype(model_bf16)

token_embedding.weight -- torch.bfloat16
linear_1.weight -- torch.bfloat16
linear_1.bias -- torch.bfloat16
layernorm_1.weight -- torch.bfloat16
layernorm_1.bias -- torch.bfloat16
linear_2.weight -- torch.bfloat16
linear_2.bias -- torch.bfloat16
layernorm_2.weight -- torch.bfloat16
layernorm_2.bias -- torch.bfloat16
head.weight -- torch.bfloat16
head.bias -- torch.bfloat16


In [22]:
logits_bf16 = model_bf16(dummy_input)

In [23]:
### DIFFERENCE

mean_diff = torch.abs(logits_bf16 - logits_fp32).mean().item()

max_diff = torch.max(logits_bf16 - logits_fp32).max().item()

In [24]:
print(f"Mean diff: {mean_diff} | Max diff: {max_diff}")

Mean diff: 0.000997886061668396 | Max diff: 0.0016907453536987305


## Generative Models

In [26]:
from transformers import AutoProcessor, BlipForConditionalGeneration

processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

pytorch_model.bin:   0%|          | 0.00/990M [00:00<?, ?B/s]

* Memory footprint

In [27]:
fp32_mem_footprint = model.get_memory_footprint()

In [28]:
print("Footprint of the fp32 model in bytes: ",
      fp32_mem_footprint)
print("Footprint of the fp32 model in MBs: ", 
      fp32_mem_footprint/1e+6)

Footprint of the fp32 model in bytes:  989660400
Footprint of the fp32 model in MBs:  989.6604


In [29]:
model_bf16 = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-base",
    torch_dtype=torch.bfloat16
)

In [30]:
bf16_mem_footprint = model_bf16.get_memory_footprint()

### Model Performance

In [33]:
from transformers import BlipProcessor

In [35]:
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")

In [36]:
from IPython.display import display

img_url = 'https://storage.googleapis.com/\sfr-vision-language-research/BLIP/demo.jpg'

image = load_image(img_url)
display(image.resize((500, 350)))

UnidentifiedImageError: cannot identify image file <_io.BytesIO object at 0x7835fe05a130>

In [None]:
results_fp32 = get_generation(model, 
                              processor, 
                              image, 
                              torch.float32)

In [None]:
print("fp32 Model Results:\n", results_fp32)

In [None]:
results_bf16 = get_generation(model_bf16, 
                              processor, 
                              image, 
                              torch.bfloat16)

In [None]:
print("bf16 Model Results:\n", results_bf16)

## Default Data Type

In [None]:
desired_dtype = torch.bfloat16

# setting for all torch tensors
torch.set_default_dtype(desired_dtype)