In [3]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import datetime
from matplotlib import pyplot as plt
import numpy as np
from tqdm import tqdm

In [4]:
# Run-mode switches:
#   DEBUG       - in the quantization cell, process only the classifier layers (with shape logging)
#   USE_DT_FILE - not referenced by any visible cell
#   Test        - run the test-set accuracy evaluation cell
DEBUG, USE_DT_FILE, Test = False, False, True

In [5]:
# Resize the short side to 256, center-crop 224x224, convert to a tensor, then normalize
# per channel. The mean/std values are presumably the usual ImageNet statistics matching
# the pretrained AlexNet backbone -- confirm if the backbone changes.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [6]:
# CIFAR-10 train/test splits, both passed through the same `transform`.
_cifar_kwargs = dict(root='../data', download=True, transform=transform)
train_data = torchvision.datasets.CIFAR10(train=True, **_cifar_kwargs)
test_data = torchvision.datasets.CIFAR10(train=False, **_cifar_kwargs)

# Mini-batches of 4; only the training loader is shuffled.
_loader_kwargs = dict(batch_size=4, num_workers=2)
trainloader = torch.utils.data.DataLoader(train_data, shuffle=True, **_loader_kwargs)
testloader = torch.utils.data.DataLoader(test_data, shuffle=False, **_loader_kwargs)

# Human-readable names for label indices 0..9.
classes = ('Airplane', 'Car', 'Bird', 'Cat', 'Deer', 'Dog', 'Frog', 'Horse', 'Ship', 'Truck')

Files already downloaded and verified
Files already downloaded and verified


In [7]:
# Use the first CUDA GPU when one is visible, otherwise the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

In [8]:
import torchvision.models as models
def create_model(checkpoint_path='model_20240611_201336_SE_Quantized_NEW', map_location=None):
    '''Build the modified AlexNet and load its saved weights from `checkpoint_path`.

    classifier[4] is replaced by Linear(4096, 1024) and classifier[6] by Linear(1024, 10),
    then every parameter is restored from the checkpoint. The model is returned in eval mode.

    checkpoint_path: file written by torch.save(model.state_dict(), ...); defaults to the
                     checkpoint this notebook previously loaded.
    map_location:    device to load the tensors onto; defaults to the global `device`.
    '''
    # load_state_dict (strict by default) overwrites every parameter below, so the ImageNet
    # weights that pretrained=True would fetch are never used -- skip downloading them.
    alexnet = torch.hub.load('pytorch/vision:v0.10.0', 'alexnet', pretrained=False)

    alexnet.classifier[4] = nn.Linear(4096, 1024)
    alexnet.classifier[6] = nn.Linear(1024, 10)

    if map_location is None:
        map_location = device
    # NOTE(review): torch.load unpickles the file -- only load checkpoints from a trusted source.
    state_dict = torch.load(checkpoint_path, map_location=map_location)
    alexnet.load_state_dict(state_dict)
    alexnet.eval()
    return alexnet


In [14]:

import torch
import torch.nn as nn
import math
import copy
import random
import numpy as np
import struct
import torch.nn.init as init

# Histogram of round_dt8 decimal-exponent buckets: key = number of x10 shifts (0..7),
# value = how many quantized elements landed in that bucket.
EXP_COUNT = dict.fromkeys(range(8), 0)
            
# input : tensor output: binary string
def binary(num):
  '''Return the 32-character IEEE-754 single-precision bit string of `num`.

  Order is sign (1 bit), exponent (8 bits), mantissa (23 bits), most significant first;
  e.g. binary(1.0) == '0' + '01111111' + '0' * 23.'''
  (word,) = struct.unpack('!I', struct.pack('!f', num))
  return format(word, '032b')

#input: mantissa bitstring
#output: float value
def calc_mantissa(mantissa):
  '''Value of the binary fraction 0.b1b2b3... given by the bit string `mantissa`.

  E.g. calc_mantissa('101') == 0.5 + 0.125; an empty string yields 0.'''
  return sum(2**(-(pos + 1)) for pos, bit in enumerate(mantissa) if bit == '1')

#input exp: bitstring
# new_exp_len: new length of exp
def calc_exp(exp, new_exp_len):
  '''Unbias an exponent bit field and clamp it to a narrower format's normal range.

  inputs  exp:          exponent bit string of the source format (8 bits for float32)
          new_exp_len:  exponent bits of the target format (4 for E4M3, 5 for E5M2)
  output  the unbiased exponent clamped to [1 - b, b], where b = 2**(new_exp_len - 1) - 1
          is the target bias (IEEE-style ranges: E4M3 -> [-6, 7], E5M2 -> [-14, 15];
          the OCP E4M3FN variant additionally allows exponent 8)
  '''
  # The largest normal unbiased exponent equals the target bias. The old code used
  # 2**new_exp_len - 1 (the largest *raw* field value), so E4M3 was clamped to [-14, 15],
  # which is E5M2's range.
  limit = 2**(new_exp_len - 1) - 1
  bias = 2**(len(exp) - 1) - 1
  val = int(exp, 2) - bias
  return max(1 - limit, min(limit, val))


def round_fp8(x, exp = 4):
  '''
  Quantizes a single float value to an FP8-like value (returned as a Python float)
  inputs  x:      one element (Tensor.apply_ in quantize_rowwise passes elements one by one)
          exp:    number of bits used for exponent field
                  e.g. E5M2 has 5 exp bits, E4M3 has 4 exp bits
  output  quantized value: sign kept, float32 mantissa truncated (not rounded) to
          7 - exp bits, exponent clamped by calc_exp. Zero is returned unchanged.
  NOTE(review): non-zero values below the smallest normal exponent have their exponent
  clamped *up* (no subnormal / flush-to-zero handling), which inflates their magnitude.
  '''
  # 0.0 has an all-zero exponent field; calc_exp would clamp it up to the minimum normal
  # exponent and the result would be +/-2**min_exp instead of zero.
  if x == 0:
    return float(x)

  result = 1.0
  bin_str = binary(x)

  # Keep only the leading (7 - exp) of float32's 23 mantissa bits.
  bin_mantissa = bin_str[9:32]
  res_mantissa = bin_mantissa[:7 - exp]
  result += calc_mantissa(res_mantissa)

  bin_exp = bin_str[1:9]
  exp_int = calc_exp(bin_exp, exp)
  result *= 2**exp_int

  if bin_str[0] == '1':
    result *= -1

  return result


def bisection_quantization(num, bits = 7):
    '''Truncate |num| to a binary fraction with `bits` fractional bits.

    Walks the bisection tree greedily: at depth p (1-based) bit 2**-p is taken whenever
    the remaining magnitude is at least 2**-p. The sign of `num` is discarded, and for
    |num| >= 1 the result saturates at 1 - 2**-bits.
    NOTE(review): bits == 0 returns the constant 0.1 regardless of `num`; round_dt8 with
    its default num_bits never requests 0 bits -- confirm the intended value.
    '''
    if bits == 0:
        return 0.1
    remaining = abs(num)
    chosen = []
    for depth in range(1, bits + 1):
        step = 2**(-depth)
        take = remaining >= step
        chosen.append(take)
        if take:
            remaining -= step
    return sum(2**-(idx + 1) for idx, flag in enumerate(chosen) if flag)

#the input is normalized tensor x,
def round_dt8(x, exp = 4, num_bits = 8):
  '''Quantize one value with the decimal "dynamic tree" scheme.

  |x| is multiplied by 10 until it is >= 0.1; the number of multiplications (exp_bits,
  capped at num_bits - 1) is tallied in EXP_COUNT. The scaled value is then truncated
  to num_bits - 1 binary fraction bits when exp_bits == 0, to num_bits - 2 - exp_bits
  bits otherwise, and becomes 0 when exp_bits >= num_bits - 2. Sign and the 10**-exp_bits
  scale are reapplied at the end.

  inputs  x:         value to quantize; presumably already divided by its row max (|x| <= 1)
          exp:       unused; kept so the signature mirrors round_fp8
          num_bits:  total bit budget
  output  the quantized value as a Python float
  '''
  sign_bit = 0 if x >= 0 else 1
  val = abs(x)

  # Stop shifting at num_bits - 1: any value needing that many shifts quantizes to 0
  # anyway. Without the cap an exact 0.0 kept this loop running forever.
  max_exp = num_bits - 1
  exp_bits = 0
  while val < 0.1 and exp_bits < max_exp:
      val *= 10
      exp_bits += 1

  bs_bits = max(0, num_bits - 2 - exp_bits)
  # .get keeps this working for num_bits > 8, where buckets beyond 7 are not pre-seeded.
  EXP_COUNT[exp_bits] = EXP_COUNT.get(exp_bits, 0) + 1

  if exp_bits == 0:
      quantized_val = bisection_quantization(val, num_bits - 1)
  elif exp_bits >= num_bits - 2:
      quantized_val = 0.0
  else:
      quantized_val = bisection_quantization(val, bs_bits)

  quantized_val = quantized_val if sign_bit == 0 else -quantized_val
  quantized_val *= 10**(-exp_bits)
  return quantized_val


def quantize_rowwise(x: torch.Tensor, dt = False):
  '''Normalize each row of a 2-D tensor by its max |value| and quantize element-wise.

  inputs  x:   2-D tensor (tensor.view(shape[0], -1) can reshape higher-rank input)
          dt:  True -> dynamic-tree quantization (round_dt8); False -> FP8 (round_fp8)
  output  (quantized, maxes): the quantized normalized values on x's device, passed
          through torch.squeeze (singleton dimensions are dropped), and the per-row
          max |x| with shape (rows, 1).
  '''
  output_maxs = torch.abs(x).max(dim=1, keepdim=True).values
  # Tensor.apply_ calls a Python function per element and raises TypeError on non-CPU
  # tensors, so quantize a CPU copy and move the result back afterwards.
  output = (x / output_maxs).cpu()
  output.apply_(round_dt8 if dt else round_fp8)
  return torch.squeeze(output.to(x.device)), output_maxs

def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor):
  '''Undo quantize_rowwise's normalization by scaling `x` with the stored row maxes.

  `state_x` is the (rows, 1) max tensor returned by quantize_rowwise; it broadcasts
  across the columns of `x`.'''
  return torch.mul(x, state_x)

def init_weights(m, in_std = 1.0):
  '''Re-initialize an nn.Linear in place: weights ~ Normal(0, in_std), bias set to zero.

  Modules of any other type are left untouched.'''
  if not isinstance(m, nn.Linear):
    return
  init.normal_(m.weight, mean=0.0, std=in_std)
  if m.bias is not None:
    init.zeros_(m.bias)

def measure_quantization_error(original_tensor, dequantized_tensor):
  '''Return (mean absolute error, element-wise absolute error) between the two tensors.'''
  abs_error = (original_tensor - dequantized_tensor).abs()
  return abs_error.mean(), abs_error


def quantize_stable_embedding(x, batch_size, dt = False):
  '''Quantize `x` after grouping its elements by magnitude ("stable embedding").

  All elements are sorted by |value| and cut into consecutive groups of `batch_size`.
  Each group becomes one row that quantize_rowwise normalizes by its own max, so values
  of similar size share a scale.

  inputs  x:           tensor; x.numel() must be divisible by batch_size
          batch_size:  elements per group (row)
          dt:          True -> dynamic-tree quantization, False -> FP8
  output  (quantized, maxes, indexing): quantized values in *sorted* order reshaped to
          x.shape, the per-row maxes, and the (rows, batch_size) tensor of flat indices
          into x for each sorted slot. Prints a message and returns None when batch_size
          does not divide x.numel().
  '''
  total = x.numel()
  if total % batch_size:
    print("Invalid batch size. Batch size should be a divisor of " + str(total))
    return

  flat = x.flatten()
  indexing = torch.argsort(torch.abs(flat)).reshape((total // batch_size, batch_size))
  output, maxes = quantize_rowwise(flat[indexing], dt)
  return output.reshape(x.shape), maxes, indexing

def dequantize_stable_embedding(input, maxes, indexing):
  '''Invert quantize_stable_embedding.

  Multiplies the quantized rows by their maxes (the same product dequantize_rowwise
  computes) and scatters each sorted value back to the flat position in `indexing`.

  inputs  input:     quantized tensor in sorted order (from quantize_stable_embedding)
          maxes:     per-row maxes, shape (rows, 1)
          indexing:  (rows, batch_size) flat indices from quantize_stable_embedding
  output  dequantized tensor with input's shape and values in their original positions
  '''
  sorted_vals = (input.reshape(indexing.shape) * maxes).flatten()
  restored = sorted_vals.clone()
  restored[indexing.flatten()] = sorted_vals
  return restored.reshape(input.shape)

In [15]:
def quantize_stable_embedding(x, batch_size, dt = False):
  '''Sort-by-magnitude ("stable embedding") quantization of `x`.

  NOTE(review): this cell redefines quantize_stable_embedding from the previous cell
  with the same logic; being later, this copy is the one in effect. Keep only one.

  inputs  x:           tensor whose numel() is a multiple of batch_size
          batch_size:  size of each magnitude group (row)
          dt:          True -> dynamic-tree quantization, False -> FP8
  output  (quantized values in sorted order reshaped to x.shape, per-row maxes,
          (rows, batch_size) indices into x.flatten()); prints a message and returns
          None if batch_size does not divide x.numel().
  '''
  n_elements = x.numel()
  n_rows, leftover = divmod(n_elements, batch_size)
  if leftover != 0:
    print("Invalid batch size. Batch size should be a divisor of " + str(n_elements))
    return None

  order = torch.argsort(torch.abs(x.flatten()))
  indexing = order.reshape((n_rows, batch_size))
  grouped = x.flatten()[indexing]
  quantized, maxes = quantize_rowwise(grouped, dt)
  return quantized.reshape(x.shape), maxes, indexing

def dequantize_stable_embedding(input, maxes, indexing):
  '''Rescale the sorted quantized values and put them back in their original positions.

  NOTE(review): duplicate of the definition in the previous cell; this later copy is the
  one in effect. Keep only one.

  inputs  input:     quantized tensor in sorted order (from quantize_stable_embedding)
          maxes:     per-row maxes from quantize_rowwise
          indexing:  (rows, batch_size) flat indices from quantize_stable_embedding
  output  dequantized tensor shaped like `input`
  '''
  flat_positions = indexing.flatten()
  in_sorted_order = dequantize_rowwise(input.reshape(indexing.shape), maxes).flatten()
  in_original_order = in_sorted_order.clone()
  in_original_order[flat_positions] = in_sorted_order
  return in_original_order.reshape(input.shape)

In [16]:
def quantize_dequantize_se(mat):
    '''Round-trip `mat` through stable-embedding dynamic-tree quantization.

    Elements are grouped (by magnitude) into rows of mat.shape[0] values; the result has
    mat's shape and holds the lossy dequantized values.'''
    quantized, row_maxes, order = quantize_stable_embedding(mat, mat.shape[0], dt = True)
    return dequantize_stable_embedding(quantized, row_maxes, order)

In [17]:
# Re-sets DEBUG (first set in the config cell). False: the quantization cell processes every
# weighted layer of alexnet.features and alexnet.classifier, not just the classifier.
DEBUG = False

In [None]:
from pathlib import Path  # Path was never imported, so this cell raised NameError

# NOTE(review): output_path is not used by the checkpoint-saving cell, which writes to the cwd.
output_path = Path('./Output')

In [19]:
# Replace every weight tensor of AlexNet with its stable-embedding DT quantize/dequantize
# round trip. quantize_rowwise relies on Tensor.apply_, which only works on CPU tensors
# (on CUDA it raised a TypeError that the old blanket `except` silently swallowed), so the
# model is quantized on the CPU and moved to `device` afterwards.
alexnet = create_model().cpu()
quantized_lst = []  # 1-based positions (within the iterated layer list) that were quantized


def _quantized_weight(weight):
    '''Return a quantized/dequantized copy of a layer weight.

    4-D (conv) weights are processed one (out_channel, in_channel) kernel at a time;
    any other weight is processed as a whole.'''
    weight = weight.detach().clone()
    if weight.dim() != 4:
        return quantize_dequantize_se(weight)
    for out_ch in range(weight.shape[0]):
        for in_ch in range(weight.shape[1]):
            weight[out_ch, in_ch] = quantize_dequantize_se(weight[out_ch, in_ch])
    return weight


# DEBUG restricts the pass to the classifier (numbering then starts at 1 there).
layers = list(alexnet.classifier) if DEBUG else [*alexnet.features, *alexnet.classifier]
for count, layer in enumerate(layers, start=1):
    # ReLU / MaxPool / Dropout have no weight: skip them explicitly instead of catching
    # AttributeError, so genuine failures inside the quantizer are no longer hidden.
    weight = getattr(layer, 'weight', None)
    if not isinstance(weight, torch.Tensor):
        continue
    print(f'Layer {count} ({type(layer).__name__}) weights shape pre-quantization: {tuple(weight.shape)}')
    layer.weight = nn.parameter.Parameter(_quantized_weight(weight))
    quantized_lst.append(count)
    print(f'Layer {count} finished quantization')

alexnet = alexnet.to(device)


Using cache found in C:\Users\Elijah/.cache\torch\hub\pytorch_vision_v0.10.0


1.Layer 1
1.Layer 1 finished quantization
1.Layer 4
1.Layer 4 finished quantization
1.Layer 7
1.Layer 7 finished quantization
1.Layer 9
1.Layer 9 finished quantization
1.Layer 11
1.Layer 11 finished quantization
2. Layer 15
2.Layer 15 finished quantization
2. Layer 18
2.Layer 18 finished quantization
2. Layer 20
2.Layer 20 finished quantization


In [9]:
# NOTE(review): this reloads the checkpoint from disk, so when the notebook is run top to
# bottom it replaces the model quantized in the cell above -- confirm that is intended.
alexnet = create_model().to(device)
alexnet

Using cache found in C:\Users\Elijah/.cache\torch\hub\pytorch_vision_v0.10.0


AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
 

In [11]:
# DT QUANTIZATION
if Test:
    # Top-1 accuracy of `alexnet` on the CIFAR-10 test split.
    alexnet.eval()  # disable Dropout even if the model was switched back to train()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in tqdm(testloader):
            images, labels = images.to(device), labels.to(device)
            predicted = alexnet(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    # Report the number of images actually evaluated rather than a hard-coded 10000.
    print(f'Accuracy of the network on the {total} test images: {100 * correct / total} %')

100%|██████████| 2500/2500 [00:17<00:00, 144.62it/s]

Accuracy of the network on the 10000 test images: 85.19 %





In [24]:
# Local time as YYYYmmdd_HHMMSS, used to give the saved checkpoint a unique name.
timestamp = f'{datetime.datetime.now():%Y%m%d_%H%M%S}'

In [25]:
# Save only the parameters (state_dict) under a timestamped name in the working directory;
# create_model loads a file named with this same pattern.
model_path = f'model_{timestamp}_SE_Quantized_NEW'
torch.save(alexnet.state_dict(), model_path)