# Inference of Simple Networks using FP8 Emulator
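This notebook runs inference with two networks, in full precision and with FP8-emulated weights:
- a small CNN trained here on MNIST
- a pretrained AlexNet evaluated on the ImageNet validation set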

References:
1. https://github.com/IntelLabs/FP8-Emulation-Toolkit/blob/main/examples/inference/classifier/imagenet_test.py 
2. https://github.com/pytorch/examples/blob/main/mnist/main.py

# Libraries

In [3]:
import copy
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR

from torchvision import models
from torchvision.models import AlexNet_Weights

from tqdm import tqdm

# import the emulator
from mpemu import mpt_emu

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Seed PyTorch's global RNG (weight initialisation, DataLoader shuffling, dropout)
# so a Restart & Run All reproduces the same run as far as PyTorch's RNG is
# concerned (some CUDA/cuDNN kernels may still be nondeterministic).
SEED = 42
torch.manual_seed(SEED)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Running on {device}')

Running on cuda


# 1. CNN for MNIST

## 1.1 Methods

In [3]:
"""
Train CNN Network
"""
def train(model, device, train_loader, optimizer, epoch, log_interval=100):
    """Train ``model`` for one epoch over ``train_loader``.

    Each batch is moved to ``device``, passed through the model, scored with
    ``F.nll_loss`` and used for one optimizer step. Progress and the current
    batch loss are printed every ``log_interval`` batches.

    Args:
        model: network returning log-probabilities, as required by
            ``F.nll_loss`` (e.g. ``Net``, whose forward ends in ``log_softmax``).
        device: device the batches are moved to; should match the model's.
        train_loader: DataLoader yielding ``(data, target)`` batches. Its
            ``.dataset`` length is used in the progress message.
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch number, used only in the progress message.
        log_interval: print every ``log_interval`` batches (0 disables logging).
    """
    model.train()

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)

        # Negative log-likelihood over C classes; expects log-probabilities.
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

        if log_interval and batch_idx % log_interval == 0:
            seen = batch_idx * len(data)
            total = len(train_loader.dataset)
            percent = 100. * batch_idx / len(train_loader)
            # One line with a real tab separator (the split print lost the "\t").
            print(f'Train Epoch: {epoch} [{seen}/{total} ({percent:.0f}%)]\tLoss: {loss.item():.6f}')

In [4]:
"""
Test CNN Network
"""
def test(model, device, test_loader):
    # Evaluation mode
    model.eval()
    # Initialize values
    test_loss = 0
    correct = 0
    
    # Test loop
    with torch.no_grad():
        for data, target in tqdm(test_loader):
            # Data to device
            data, target = data.to(device), target.to(device)
            # Forward pass
            output = model(data)
            
            # sum up batch loss
            test_loss += F.nll_loss(output, target, reduction='sum').item()  
            # get the index of the max log-probability
            pred = output.argmax(dim=1, keepdim=True)  
            # Calculate correct predictions
            correct += pred.eq(target.view_as(pred)).sum().item()

    # Total test loss
    test_loss /= len(test_loader.dataset)

    print(f'\nTest set: Average loss: {test_loss:.4f}')
    print(f'Accuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.0f}%)')

## 1.2 Network

In [5]:
class Net(nn.Module):
    """Small two-conv CNN classifier producing log-probabilities over 10 classes.

    ``fc1`` expects 9216 = 64 * 12 * 12 features, which matches 1x28x28 inputs:
    two unpadded 3x3 convolutions (28 -> 26 -> 24) followed by a 2x2 max-pool
    (24 -> 12).
    """

    def __init__(self):
        super().__init__()
        # Registration order is kept so state_dict keys and the printed repr stay the same
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        # Feature extractor: conv -> ReLU -> conv -> ReLU -> 2x2 max-pool
        feats = F.relu(self.conv1(x))
        feats = F.max_pool2d(F.relu(self.conv2(feats)), 2)

        # Classifier head: dropout -> flatten -> fc -> ReLU -> dropout -> fc
        flat = torch.flatten(self.dropout1(feats), 1)
        hidden = self.dropout2(F.relu(self.fc1(flat)))

        # log_softmax pairs with F.nll_loss used in train/test
        return F.log_softmax(self.fc2(hidden), dim=1)

In [6]:
# Instantiate the MNIST CNN, move its parameters to the selected device and show its layers
cnn_s = Net()
cnn_s = cnn_s.to(device)
print(cnn_s)

Net(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (dropout1): Dropout(p=0.25, inplace=False)
  (dropout2): Dropout(p=0.5, inplace=False)
  (fc1): Linear(in_features=9216, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)


## 1.3 Datasets

In [7]:
# DataLoader arguments for training and testing.
# Shuffle the training set on every device (previously only on CUDA, so CPU runs
# iterated MNIST in a fixed order); evaluation does not need shuffling.
train_kwargs = {'batch_size': 64, 'shuffle': True}
test_kwargs = {'batch_size': 1000}

# Loader options that only pay off when a CUDA device is present
if torch.cuda.is_available():
    cuda_kwargs = {'num_workers': 1,
                   'pin_memory': True}
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)

In [8]:
# Transforms for the dataset: convert images to tensors, then normalize with the
# constants below (presumably MNIST's mean/std, as in the PyTorch MNIST example).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# MNIST train/test splits stored under ./MNIST_data (download requested for the train split)
mnist_root = './MNIST_data'
dataset1 = datasets.MNIST(mnist_root, train=True, download=True, transform=transform)
dataset2 = datasets.MNIST(mnist_root, train=False, transform=transform)

# Dataloaders built from the kwargs prepared in the previous cell
train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

## 1.4 Training

In [9]:
# Adadelta optimizer; StepLR multiplies the learning rate by `gamma` on every
# scheduler.step() call (step_size=1).
lr = 1.0     # initial learning rate
gamma = 0.7  # learning-rate decay factor per scheduler step
optimizer = optim.Adadelta(params=cnn_s.parameters(), lr=lr)
scheduler = StepLR(optimizer=optimizer, step_size=1, gamma=gamma)

In [10]:
# Train for `epochs` epochs, evaluating on the test split after each one and
# stepping the learning-rate scheduler once per epoch.
epochs = 14
for epoch in range(1, epochs + 1):
    train(model=cnn_s, device=device, train_loader=train_loader, optimizer=optimizer, epoch=epoch)
    test(model=cnn_s, device=device, test_loader=test_loader)
    scheduler.step()

# Save the trained weights. torch.save does not create missing directories,
# so make sure ./models exists first (it does not on a fresh checkout).
os.makedirs('./models', exist_ok=True)
torch.save(cnn_s.state_dict(), './models/mnist_cnn.pth.tar')

tLoss: 2.310377
tLoss: 0.313422
tLoss: 0.260981
tLoss: 0.180671
tLoss: 0.319536
tLoss: 0.073190
tLoss: 0.092991
tLoss: 0.067335
tLoss: 0.179203
tLoss: 0.303795


100%|██████████| 10/10 [00:00<00:00, 18.16it/s]


Test set: Average loss: 0.0437
Accuracy: 9849/10000 (98%)





tLoss: 0.349314
tLoss: 0.018424
tLoss: 0.106881
tLoss: 0.165347
tLoss: 0.025091
tLoss: 0.087487
tLoss: 0.035303
tLoss: 0.074044
tLoss: 0.068245
tLoss: 0.042863


100%|██████████| 10/10 [00:00<00:00, 19.21it/s]


Test set: Average loss: 0.0379
Accuracy: 9877/10000 (99%)





tLoss: 0.008922
tLoss: 0.101310
tLoss: 0.016313
tLoss: 0.056904
tLoss: 0.007783
tLoss: 0.144041
tLoss: 0.069690
tLoss: 0.029539
tLoss: 0.068990
tLoss: 0.040939


100%|██████████| 10/10 [00:00<00:00, 18.98it/s]


Test set: Average loss: 0.0348
Accuracy: 9882/10000 (99%)





tLoss: 0.140477
tLoss: 0.019292
tLoss: 0.127054
tLoss: 0.024709
tLoss: 0.019319
tLoss: 0.021368
tLoss: 0.090063
tLoss: 0.128324
tLoss: 0.056537
tLoss: 0.046776


100%|██████████| 10/10 [00:00<00:00, 19.03it/s]


Test set: Average loss: 0.0323
Accuracy: 9887/10000 (99%)





tLoss: 0.094660
tLoss: 0.122817
tLoss: 0.020085
tLoss: 0.017906
tLoss: 0.004348
tLoss: 0.032048
tLoss: 0.048920
tLoss: 0.038294
tLoss: 0.054126
tLoss: 0.016549


100%|██████████| 10/10 [00:00<00:00, 19.13it/s]


Test set: Average loss: 0.0288
Accuracy: 9910/10000 (99%)





tLoss: 0.006779
tLoss: 0.003196
tLoss: 0.004638
tLoss: 0.031004
tLoss: 0.113164
tLoss: 0.002338
tLoss: 0.029165
tLoss: 0.016110
tLoss: 0.079549
tLoss: 0.001859


100%|██████████| 10/10 [00:00<00:00, 18.79it/s]


Test set: Average loss: 0.0301
Accuracy: 9896/10000 (99%)





tLoss: 0.000274
tLoss: 0.002541
tLoss: 0.048294
tLoss: 0.035202
tLoss: 0.084909
tLoss: 0.005557
tLoss: 0.109337
tLoss: 0.041536
tLoss: 0.001656
tLoss: 0.048554


100%|██████████| 10/10 [00:00<00:00, 18.99it/s]


Test set: Average loss: 0.0298
Accuracy: 9912/10000 (99%)





tLoss: 0.030545
tLoss: 0.001909
tLoss: 0.005710
tLoss: 0.014598
tLoss: 0.094214
tLoss: 0.004013
tLoss: 0.019102
tLoss: 0.000692
tLoss: 0.019070
tLoss: 0.038610


100%|██████████| 10/10 [00:00<00:00, 19.11it/s]


Test set: Average loss: 0.0278
Accuracy: 9908/10000 (99%)





tLoss: 0.065393
tLoss: 0.077645
tLoss: 0.056082
tLoss: 0.018979
tLoss: 0.038143
tLoss: 0.092074
tLoss: 0.009750
tLoss: 0.018649
tLoss: 0.017778
tLoss: 0.008522


100%|██████████| 10/10 [00:00<00:00, 18.65it/s]


Test set: Average loss: 0.0267
Accuracy: 9914/10000 (99%)





tLoss: 0.027950
tLoss: 0.003279
tLoss: 0.010466
tLoss: 0.003281
tLoss: 0.016939
tLoss: 0.128306
tLoss: 0.048794
tLoss: 0.027070
tLoss: 0.008950
tLoss: 0.010081


100%|██████████| 10/10 [00:00<00:00, 19.09it/s]


Test set: Average loss: 0.0277
Accuracy: 9910/10000 (99%)





tLoss: 0.077925
tLoss: 0.012583
tLoss: 0.007758
tLoss: 0.011331
tLoss: 0.003577
tLoss: 0.019993
tLoss: 0.006149
tLoss: 0.004458
tLoss: 0.083719
tLoss: 0.097696


100%|██████████| 10/10 [00:00<00:00, 18.90it/s]


Test set: Average loss: 0.0266
Accuracy: 9913/10000 (99%)





tLoss: 0.007980
tLoss: 0.002048
tLoss: 0.009649
tLoss: 0.160508
tLoss: 0.032507
tLoss: 0.010444
tLoss: 0.001407
tLoss: 0.030772
tLoss: 0.011948
tLoss: 0.076268


100%|██████████| 10/10 [00:00<00:00, 17.23it/s]


Test set: Average loss: 0.0271
Accuracy: 9916/10000 (99%)





tLoss: 0.001584
tLoss: 0.001450
tLoss: 0.139662
tLoss: 0.028051
tLoss: 0.010259
tLoss: 0.008088
tLoss: 0.002229
tLoss: 0.013270
tLoss: 0.017565
tLoss: 0.018406


100%|██████████| 10/10 [00:00<00:00, 19.25it/s]


Test set: Average loss: 0.0269
Accuracy: 9917/10000 (99%)





tLoss: 0.213277
tLoss: 0.018455
tLoss: 0.161404
tLoss: 0.110782
tLoss: 0.077451
tLoss: 0.147759
tLoss: 0.047308
tLoss: 0.010143
tLoss: 0.001111
tLoss: 0.014321


100%|██████████| 10/10 [00:00<00:00, 19.24it/s]


Test set: Average loss: 0.0262
Accuracy: 9919/10000 (99%)





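## 1.5 FP8 Quantization
- Quantize the trained CNN to the FP8 E4M3 format (4 exponent bits, 3 mantissa bits) and re-evaluate it on the test set.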

In [11]:
# List every entry of the trained CNN's state_dict together with its shape
print("Simple CNN Model's state_dict:\n")
for name, tensor in cnn_s.state_dict().items():
    print(name, "\t", tensor.size())

Simple CNN Model's state_dict:

conv1.weight 	 torch.Size([32, 1, 3, 3])
conv1.bias 	 torch.Size([32])
conv2.weight 	 torch.Size([64, 32, 3, 3])
conv2.bias 	 torch.Size([64])
fc1.weight 	 torch.Size([128, 9216])
fc1.bias 	 torch.Size([128])
fc2.weight 	 torch.Size([10, 128])
fc2.bias 	 torch.Size([10])


In [12]:
# Inspect a single parameter tensor (shape, dtype, values) before quantization
sample = "conv2.weight"
sample_weight = cnn_s.state_dict()[sample]
print(f'Sample weight (Original): {sample}')
print(f'Dimension: {sample_weight.shape}')
print(f'Type: {sample_weight.dtype}')
print(sample_weight)

Sample weight (Original): conv2.weight
Dimension: torch.Size([64, 32, 3, 3])
Type: torch.float32
tensor([[[[-0.0717,  0.0022, -0.0746],
          [-0.0517, -0.0383, -0.0650],
          [-0.1691, -0.1093, -0.0253]],

         [[ 0.0159, -0.0175,  0.0524],
          [-0.0760,  0.0156,  0.0738],
          [-0.1006, -0.0111, -0.0306]],

         [[-0.0450, -0.1207, -0.0332],
          [-0.0546, -0.0301, -0.0562],
          [-0.0882, -0.0450, -0.1424]],

         ...,

         [[-0.0561,  0.0307,  0.0729],
          [-0.0534, -0.0089,  0.0227],
          [ 0.0345,  0.0804,  0.0289]],

         [[-0.0624, -0.0255, -0.0708],
          [-0.0785, -0.1008,  0.0169],
          [-0.0506, -0.1112,  0.0070]],

         [[-0.0272, -0.0596, -0.1872],
          [-0.0752,  0.0216, -0.0031],
          [ 0.0261,  0.0010,  0.0249]]],


        [[[ 0.0203, -0.0349, -0.0786],
          [-0.0359, -0.0193, -0.0436],
          [-0.0825, -0.0650, -0.0501]],

         [[-0.0522, -0.0199, -0.0134],
          [-0.

In [14]:
# quantize_model modifies the model it is given, so quantize a deep copy
# and keep the trained cnn_s untouched for comparison.
cnns_to_e4m3 = copy.deepcopy(cnn_s)

In [15]:
# Layers to exempt from FP8 conversion; left empty, so no layer is exempted
list_exempt_layers = list()

In [17]:
# quantize_model modifies the model passed in but also returns the quantized model
# and a second `emulator` object, so both results are unpacked.
# verbose=True prints the per-layer quantization configuration.
cnns_e4m3, emulator = mpt_emu.quantize_model(
    model=cnns_to_e4m3,
    dtype="E4M3",
    list_exempt_layers=list_exempt_layers,
    verbose=True,
)

e4m3 : quantizing model weights..
[weights: [e4m3_rne, scale: per-channel, method: max], inputs: [e4m3_rne, scale: per-tensor, method: max], output: None] conv1                                   
[weights: [e4m3_rne, scale: per-channel, method: max], inputs: [e4m3_rne, scale: per-tensor, method: max], output: None] conv2                                   
[weights: [e4m3_rne, scale: per-channel, method: max], inputs: [e4m3_rne, scale: per-tensor, method: max], output: None] fc1                                     
[weights: [e4m3_rne, scale: per-channel, method: max], inputs: [e4m3_rne, scale: per-tensor, method: max], output: None] fc2                                     


In [31]:
# Evaluate the E4M3-quantized copy on the MNIST test split with the same routine used for the FP32 model
test(cnns_e4m3, device, test_loader)

100%|██████████| 10/10 [00:00<00:00, 16.54it/s]


Test set: Average loss: 0.0265
Accuracy: 9915/10000 (99%)





# 2. AlexNet for ImageNet

## 2.1 Load the Pre-trained model

In [5]:
# Load AlexNet with torchvision's default (most up-to-date) pretrained weights
pretrained_weights = AlexNet_Weights.DEFAULT
alexnet_test = models.alexnet(weights=pretrained_weights)

# Switch to evaluation mode before inference: dropout (and batch-norm layers, in
# models that have them) behave differently in training mode, which would make
# inference results inconsistent. eval() returns the module, so the cell shows it.
alexnet_test.eval()

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
 

## 2.2 Quantize the model
- E4M3

In [6]:
# quantize_model modifies the model it is given, so quantize a deep copy
# and keep the original alexnet_test untouched for the baseline evaluation.
alexnet_to_e4m3 = copy.deepcopy(alexnet_test)

In [7]:
# Layers to exempt from FP8 conversion; left empty, so no layer is exempted
list_exempt_layers = list()

In [8]:
# quantize_model modifies the model passed in but also returns the quantized model
# and a second `emulator` object, so both results are unpacked.
model_e4m3, emulator = mpt_emu.quantize_model(
    model=alexnet_to_e4m3,
    dtype="E4M3",
    list_exempt_layers=list_exempt_layers,
)

e4m3 : quantizing model weights..


## 2.3 Evaluation

### 2.3.1 Methods

In [18]:
def eval_model(model, testloader, device=None):
    """Compute and print the top-1 accuracy of ``model`` over ``testloader``.

    The model is switched to evaluation mode and gradients are disabled. Each
    prediction is the index of the largest output along dimension 1.

    Args:
        model: classifier mapping a batch of images to per-class scores.
        testloader: iterable yielding ``(images, labels)`` batches.
        device: device the batches are moved to. Defaults to the device of the
            model's first parameter (CPU if the model has none), so inputs always
            sit next to the weights even when the model lives on a GPU.

    Returns:
        Accuracy in percent (0.0 if the loader yields no samples).
    """
    # Mirror `test`: never rely on the caller having set eval mode
    model.eval()
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    correct = 0
    total = 0

    # Inference only: no gradients are needed
    with torch.no_grad():
        for images, labels in tqdm(testloader):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            # The class with the highest score is the prediction
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Guard against an empty loader instead of raising ZeroDivisionError
    accuracy = 100 * correct / total if total else 0.0
    print(f'Accuracy {correct}/{total}: {accuracy:.2f} %')
    return accuracy

### 2.3.2 Dataset

In [14]:
# Pre-processing for ImageNet evaluation: resize, center-crop to 224x224, convert
# to a tensor and normalize with the per-channel mean/std values below.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Batch size taken from the Intel FP8-Emulation-Toolkit ImageNet example (reference 1)
batch_size = 256

# Only the validation split is needed for inference-only evaluation; keep a fixed order
testset = torchvision.datasets.ImageNet(root='./Imagenet_data', split='val', transform=transform)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
)

In [15]:
# Show the dataset summary (size, root, split, transforms) to confirm the loader setup
print(testset)

Dataset ImageNet
    Number of datapoints: 50000
    Root location: ./Imagenet_data
    Split: val
    StandardTransform
Transform: Compose(
               Resize(size=256, interpolation=bilinear, max_size=None, antialias=warn)
               CenterCrop(size=(224, 224))
               ToTensor()
               Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
           )


### 2.3.3 Result
- https://github.com/pytorch/examples/issues/987

- Original model:

In [19]:
# Baseline: the original (unquantized) pretrained AlexNet
eval_model(alexnet_test, testloader)

100%|██████████| 196/196 [03:15<00:00,  1.00it/s]

Accuracy 28261/50000: 56.52 %





- FP8 model

In [20]:
# Same evaluation for the E4M3-quantized copy
eval_model(model_e4m3, testloader)

100%|██████████| 196/196 [03:45<00:00,  1.15s/it]

Accuracy 28124/50000: 56.25 %



