# FP8 Emulation Toolkit - INTEL
https://github.com/IntelLabs/FP8-Emulation-Toolkit

- Create a simple tutorial to verify that the emulator is working.
- Apply post-training quantization to FP8.

# Libraries

In [54]:
import torch
from torchvision import models
from torchvision.models import AlexNet_Weights

# import the emulator
from mpemu import mpt_emu

import copy

In [11]:
# Select the compute device: CUDA when a GPU is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Running on {device}')

Running on cuda


## 1. Load a Pre-trained model
- Use AlexNet.  
https://learnopencv.com/pytorch-for-beginners-image-classification-using-pre-trained-models/

In [55]:
# Load AlexNet with the most up-to-date pre-trained weights and switch it to
# evaluation mode in one step. Evaluation mode puts dropout and batch
# normalization layers into their inference behaviour; skipping it yields
# inconsistent inference results.
alexnet_test = models.alexnet(weights=AlexNet_Weights.DEFAULT).eval()

# Last expression of the cell: display the model architecture.
alexnet_test

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
 

In [56]:
# Print the loaded model's state_dict: one line per entry, giving the entry
# name and the size of its tensor.
# state_dict() is built once and walked with .items(); the previous version
# rebuilt the whole dictionary on every loop iteration just to index one entry.
print("Pre-trained Model's state_dict:\n")
for param_name, param_value in alexnet_test.state_dict().items():
    print(param_name, "\t", param_value.size())

Pre-trained Model's state_dict:

features.0.weight 	 torch.Size([64, 3, 11, 11])
features.0.bias 	 torch.Size([64])
features.3.weight 	 torch.Size([192, 64, 5, 5])
features.3.bias 	 torch.Size([192])
features.6.weight 	 torch.Size([384, 192, 3, 3])
features.6.bias 	 torch.Size([384])
features.8.weight 	 torch.Size([256, 384, 3, 3])
features.8.bias 	 torch.Size([256])
features.10.weight 	 torch.Size([256, 256, 3, 3])
features.10.bias 	 torch.Size([256])
classifier.1.weight 	 torch.Size([4096, 9216])
classifier.1.bias 	 torch.Size([4096])
classifier.4.weight 	 torch.Size([4096, 4096])
classifier.4.bias 	 torch.Size([4096])
classifier.6.weight 	 torch.Size([1000, 4096])
classifier.6.bias 	 torch.Size([1000])


In [57]:
# Print one weight sample from the pre-trained model: name, shape, dtype and
# values. The tensor is looked up once instead of rebuilding state_dict() for
# every print.
sample = "features.0.weight"
sample_weight = alexnet_test.state_dict()[sample]
print(f'Sample weight: {sample}')
print(f'Dimension: {sample_weight.shape}')
print(f'Type: {sample_weight.dtype}')
print(sample_weight)

Sample weight: features.0.weight
Dimension: torch.Size([64, 3, 11, 11])
Type: torch.float32
tensor([[[[ 1.1864e-01,  9.4069e-02,  9.5435e-02,  ...,  5.5822e-02,
            2.1575e-02,  4.9963e-02],
          [ 7.4882e-02,  3.8940e-02,  5.2979e-02,  ...,  2.5709e-02,
           -1.1299e-02,  4.1590e-03],
          [ 7.5425e-02,  3.8779e-02,  5.4930e-02,  ...,  4.3596e-02,
            1.0225e-02,  1.3251e-02],
          ...,
          [ 9.3155e-02,  1.0374e-01,  6.7547e-02,  ..., -2.0277e-01,
           -1.2839e-01, -1.1220e-01],
          [ 4.3544e-02,  6.4916e-02,  3.6164e-02,  ..., -2.0248e-01,
           -1.1376e-01, -1.0719e-01],
          [ 4.7369e-02,  6.2543e-02,  2.4758e-02,  ..., -1.1844e-01,
           -9.5567e-02, -8.3890e-02]],

         [[-7.2634e-02, -5.7996e-02, -8.0661e-02,  ..., -6.0304e-04,
           -2.5309e-02,  2.5471e-02],
          [-6.9042e-02, -6.7562e-02, -7.6367e-02,  ..., -3.9616e-03,
           -3.0402e-02,  1.0477e-02],
          [-9.9517e-02, -8.5592e-02

## 2. Quantize the model

In [58]:
# Quantize a deep copy, not the original: per the original author's note,
# quantize_model overwrites the model it is given, and `alexnet_test` must keep
# its FP32 weights for the comparison cells further down.
alexnet_to_q = copy.deepcopy(alexnet_test)

In [59]:
# Layers exempt from E4M3 conversion (their weights are left as they are).
# "classifier.6" is the last layer listed in the state_dict above
# (weight size [1000, 4096]).
list_exempt_layers = ["classifier.6"]

In [60]:
# Quantize the weights of the copied model to the E4M3 format, skipping the
# exempt layers. The call returns the quantized model together with an
# emulator object (not used further in this notebook).
model_fp8, emulator = mpt_emu.quantize_model(
    model=alexnet_to_q,
    dtype="E4M3",
    list_exempt_layers=list_exempt_layers,
)

e4m3 : quantizing model weights..


## 3. Verify Results

In [61]:
# Print the quantized model's state_dict: one line per entry, giving the entry
# name and the size of its tensor.
# state_dict() is built once and walked with .items(); the previous version
# rebuilt the whole dictionary on every loop iteration just to index one entry.
print("Quantized Model's state_dict:\n")
for param_name, param_value in model_fp8.state_dict().items():
    print(param_name, "\t", param_value.size())

Quantized Model's state_dict:

features.0.weight 	 torch.Size([64, 3, 11, 11])
features.0.bias 	 torch.Size([64])
features.3.weight 	 torch.Size([192, 64, 5, 5])
features.3.bias 	 torch.Size([192])
features.6.weight 	 torch.Size([384, 192, 3, 3])
features.6.bias 	 torch.Size([384])
features.8.weight 	 torch.Size([256, 384, 3, 3])
features.8.bias 	 torch.Size([256])
features.10.weight 	 torch.Size([256, 256, 3, 3])
features.10.bias 	 torch.Size([256])
classifier.1.weight 	 torch.Size([4096, 9216])
classifier.1.bias 	 torch.Size([4096])
classifier.4.weight 	 torch.Size([4096, 4096])
classifier.4.bias 	 torch.Size([4096])
classifier.6.weight 	 torch.Size([1000, 4096])
classifier.6.bias 	 torch.Size([1000])


- Quantized Layers:

In [66]:
# Print one weight sample from the ORIGINAL (unquantized) model, for a layer
# that was quantized in the other model. The tensor is looked up once instead
# of rebuilding state_dict() for every print.
sample = "features.8.weight"
original_weight = alexnet_test.state_dict()[sample]
print(f'Sample weight (Original): {sample}')
print(f'Dimension: {original_weight.shape}')
print(f'Type: {original_weight.dtype}')
print(original_weight)

Sample weight (Original): features.8.weight
Dimension: torch.Size([256, 384, 3, 3])
Type: torch.float32
tensor([[[[-0.0020, -0.0081, -0.0114],
          [-0.0193,  0.0007,  0.0114],
          [-0.0541, -0.0012, -0.0244]],

         [[ 0.0350,  0.0133,  0.0260],
          [-0.0282, -0.0062, -0.0269],
          [ 0.0035,  0.0181,  0.0147]],

         [[-0.0572, -0.0474,  0.0019],
          [-0.0402, -0.0462, -0.0257],
          [-0.0515, -0.0490,  0.0254]],

         ...,

         [[-0.0184, -0.0234,  0.0097],
          [-0.0443, -0.0076, -0.0178],
          [-0.0518, -0.0351, -0.0455]],

         [[-0.0037, -0.0011, -0.0447],
          [-0.0524, -0.0318, -0.0524],
          [-0.0031, -0.0111, -0.0443]],

         [[-0.0199, -0.0015,  0.0159],
          [ 0.0051, -0.0149, -0.0237],
          [ 0.0259,  0.0332,  0.0081]]],


        [[[ 0.0210,  0.0214,  0.0528],
          [-0.0056,  0.0240,  0.0338],
          [-0.0091,  0.0343,  0.0236]],

         [[-0.0239, -0.0183, -0.0083],
       

In [67]:
# Print the same weight sample from the QUANTIZED model. The tensor is looked
# up once instead of rebuilding state_dict() for every print.
sample = "features.8.weight"
quantized_weight = model_fp8.state_dict()[sample]
print(f'Sample weight (Quantized): {sample}')
print(f'Dimension: {quantized_weight.shape}')
print(f'Type: {quantized_weight.dtype}')
print(quantized_weight)

Sample weight (Quantized): features.8.weight
Dimension: torch.Size([256, 384, 3, 3])
Type: torch.float32
tensor([[[[-1.9568e-03, -8.4293e-03, -1.0838e-02],
          [-1.9267e-02,  6.7735e-04,  1.0838e-02],
          [-5.2984e-02, -1.1289e-03, -2.4084e-02]],

         [[ 3.6125e-02,  1.3246e-02,  2.6492e-02],
          [-2.8900e-02, -6.0209e-03, -2.6492e-02],
          [ 3.3115e-03,  1.8063e-02,  1.4450e-02]],

         [[-5.7801e-02, -4.8167e-02,  1.8063e-03],
          [-3.8534e-02, -4.8167e-02, -2.6492e-02],
          [-5.2984e-02, -4.8167e-02,  2.6492e-02]],

         ...,

         [[-1.8063e-02, -2.4084e-02,  9.6335e-03],
          [-4.3351e-02, -7.8272e-03, -1.8063e-02],
          [-5.2984e-02, -3.6125e-02, -4.3351e-02]],

         [[-3.6125e-03, -1.1289e-03, -4.3351e-02],
          [-5.2984e-02, -3.1309e-02, -5.2984e-02],
          [-3.0105e-03, -1.0838e-02, -4.3351e-02]],

         [[-1.9267e-02, -1.5052e-03,  1.5654e-02],
          [ 4.8167e-03, -1.4450e-02, -2.4084e-02],
   

- Omitted Layer:

In [64]:
# Print one weight sample from the ORIGINAL model for the layer that was
# omitted from quantization. The tensor is looked up once instead of
# rebuilding state_dict() for every print.
sample = "classifier.6.weight"
original_weight = alexnet_test.state_dict()[sample]
print(f'Sample weight (Original): {sample}')
print(f'Dimension: {original_weight.shape}')
print(f'Type: {original_weight.dtype}')
print(original_weight)

Sample weight (Original): classifier.6.weight
Dimension: torch.Size([1000, 4096])
Type: torch.float32
tensor([[ 0.0327, -0.0062, -0.0040,  ...,  0.0160,  0.0456, -0.0158],
        [-0.0281,  0.0393, -0.0035,  ..., -0.0250,  0.0265, -0.0159],
        [-0.0019, -0.0004, -0.0081,  ..., -0.0093,  0.0203, -0.0136],
        ...,
        [-0.0249, -0.0350,  0.0131,  ..., -0.0082,  0.0454, -0.0043],
        [ 0.0252, -0.0026, -0.0109,  ..., -0.0091, -0.0615, -0.0009],
        [-0.0039,  0.0090, -0.0018,  ...,  0.0229,  0.0042,  0.0185]])


In [65]:
# Print the weight sample of the omitted layer as stored in the QUANTIZED
# model, to compare with the original model's values. The tensor is looked up
# once instead of rebuilding state_dict() for every print.
sample = "classifier.6.weight"
exempt_weight = model_fp8.state_dict()[sample]
print(f'Sample weight (Quantized Model): {sample}')
print(f'Dimension: {exempt_weight.shape}')
print(f'Type: {exempt_weight.dtype}')
print(exempt_weight)

Sample weight (Quantized Model): classifier.6.weight
Dimension: torch.Size([1000, 4096])
Type: torch.float32
tensor([[ 0.0327, -0.0062, -0.0040,  ...,  0.0160,  0.0456, -0.0158],
        [-0.0281,  0.0393, -0.0035,  ..., -0.0250,  0.0265, -0.0159],
        [-0.0019, -0.0004, -0.0081,  ..., -0.0093,  0.0203, -0.0136],
        ...,
        [-0.0249, -0.0350,  0.0131,  ..., -0.0082,  0.0454, -0.0043],
        [ 0.0252, -0.0026, -0.0109,  ..., -0.0091, -0.0615, -0.0009],
        [-0.0039,  0.0090, -0.0018,  ...,  0.0229,  0.0042,  0.0185]])
