In [1]:
# Imports
import copy

import torch
from helper import DummyModel

In [2]:
# Instantiate the toy model (its parameters load as float32) and show its layer structure
model = DummyModel()
model

DummyModel(
  (token_embedding): Embedding(2, 2)
  (linear_1): Linear(in_features=2, out_features=2, bias=True)
  (layernorm_1): LayerNorm((2,), eps=1e-05, elementwise_affine=True)
  (linear_2): Linear(in_features=2, out_features=2, bias=True)
  (layernorm_2): LayerNorm((2,), eps=1e-05, elementwise_affine=True)
  (head): Linear(in_features=2, out_features=2, bias=True)
)

In [3]:
# Visualize Param Dtype
def print_param_dtype(model):
    """Print the dtype of every named parameter of ``model``.

    Emits one line per parameter, in ``model.named_parameters()`` order,
    formatted as "<name> is loaded in <dtype>".
    """
    named_params = model.named_parameters()
    for param_name, tensor in named_params:
        print(f"{param_name} is loaded in {tensor.dtype}")

In [4]:
# List each parameter's dtype for the freshly loaded (float32) model
print_param_dtype(model=model)

token_embedding.weight is loaded in torch.float32
linear_1.weight is loaded in torch.float32
linear_1.bias is loaded in torch.float32
layernorm_1.weight is loaded in torch.float32
layernorm_1.bias is loaded in torch.float32
linear_2.weight is loaded in torch.float32
linear_2.bias is loaded in torch.float32
layernorm_2.weight is loaded in torch.float32
layernorm_2.bias is loaded in torch.float32
head.weight is loaded in torch.float32
head.bias is loaded in torch.float32


In [5]:
# Casting model weights to float16.
# Copy the float32 `model` instead of constructing a fresh DummyModel(), so the
# fp16 and fp32 models are guaranteed to share identical weights no matter how
# DummyModel initialises them. `.half()` converts the module in place, so
# converting the copy (not `model`) keeps the fp32 reference intact.
model_fp16 = copy.deepcopy(model).half()

In [6]:
# List each parameter's dtype after the float16 conversion
print_param_dtype(model=model_fp16)

token_embedding.weight is loaded in torch.float16
linear_1.weight is loaded in torch.float16
linear_1.bias is loaded in torch.float16
layernorm_1.weight is loaded in torch.float16
layernorm_1.bias is loaded in torch.float16
linear_2.weight is loaded in torch.float16
linear_2.bias is loaded in torch.float16
layernorm_2.weight is loaded in torch.float16
layernorm_2.bias is loaded in torch.float16
head.weight is loaded in torch.float16
head.bias is loaded in torch.float16


In [7]:
# Inference using the float32 model.
# Token ids must be an integer (long) tensor for the embedding lookup; the
# embedding table has 2 rows, so only ids 0 and 1 are valid here.
# `torch.tensor(..., dtype=...)` replaces the legacy `torch.LongTensor(...)`
# type constructor; the resulting tensor is the same int64 tensor.
dummy_input = torch.tensor([[1, 0], [0, 1]], dtype=torch.long)
logits_fp32 = model(dummy_input)
logits_fp32

tensor([[[-0.6872,  0.7132],
         [-0.6872,  0.7132]],

        [[-0.6872,  0.7132],
         [-0.6872,  0.7132]]], grad_fn=<ViewBackward0>)

In [8]:
# Forward pass through the float16 model, reusing the same integer input
logits_fp16 = model_fp16(dummy_input)
logits_fp16

tensor([[[-0.6870,  0.7134],
         [-0.6870,  0.7134]],

        [[-0.6870,  0.7134],
         [-0.6870,  0.7134]]], dtype=torch.float16, grad_fn=<ViewBackward0>)

In [9]:
# Casting model weights to bfloat16.
# nn.Module.to() converts parameters in place and returns the same module, so
# calling it directly on `model` would silently turn the float32 reference
# model into bfloat16 as well (re-running the fp32 inference cell would then
# no longer produce fp32 logits). Deep-copy first so `model` stays float32.
model_bf16 = copy.deepcopy(model).to(torch.bfloat16)

In [10]:
# List each parameter's dtype after the bfloat16 conversion
print_param_dtype(model=model_bf16)

token_embedding.weight is loaded in torch.bfloat16
linear_1.weight is loaded in torch.bfloat16
linear_1.bias is loaded in torch.bfloat16
layernorm_1.weight is loaded in torch.bfloat16
layernorm_1.bias is loaded in torch.bfloat16
linear_2.weight is loaded in torch.bfloat16
linear_2.bias is loaded in torch.bfloat16
layernorm_2.weight is loaded in torch.bfloat16
layernorm_2.bias is loaded in torch.bfloat16
head.weight is loaded in torch.bfloat16
head.bias is loaded in torch.bfloat16


In [11]:
# Forward pass through the bfloat16 model, reusing the same integer input
logits_bf16 = model_bf16(dummy_input)
logits_bf16

tensor([[[-0.6875,  0.7148],
         [-0.6875,  0.7148]],

        [[-0.6875,  0.7148],
         [-0.6875,  0.7148]]], dtype=torch.bfloat16, grad_fn=<ViewBackward0>)

In [12]:
# Quantization error of the float16 logits, measured against the float32 reference.
# The element-wise absolute error is computed once and then reduced two ways.
abs_err_fp16 = (logits_fp16 - logits_fp32).abs()
mean_diff = abs_err_fp16.mean().item()
max_diff = abs_err_fp16.max().item()

print(f"Mean diff: {mean_diff} | Max diff: {max_diff}")

Mean diff: 0.00020457804203033447 | Max diff: 0.00022590160369873047


In [13]:
# Quantization error of the bfloat16 logits, measured against the float32 reference.
# The element-wise absolute error is computed once and then reduced two ways.
abs_err_bf16 = (logits_bf16 - logits_fp32).abs()
mean_diff = abs_err_bf16.mean().item()
max_diff = abs_err_bf16.max().item()

print(f"Mean diff: {mean_diff} | Max diff: {max_diff}")

Mean diff: 0.000997886061668396 | Max diff: 0.0016907453536987305
