In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import datetime
from matplotlib import pyplot as plt
import numpy as np
from tqdm import tqdm

In [2]:
DEBUG = False
USE_DT_FILE = False
Test = True

In [3]:
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [4]:
train_data = torchvision.datasets.CIFAR10(root='../data', train=True, download=True, transform=transform)
test_data = torchvision.datasets.CIFAR10(root='../data', train=False, download=True, transform=transform)

trainloader = torch.utils.data.DataLoader(train_data, batch_size=4, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(test_data, batch_size=4, shuffle=False, num_workers=2)

classes = ('Airplane', 'Car', 'Bird', 'Cat', 'Deer', 'Dog', 'Frog', 'Horse', 'Ship', 'Truck')

Files already downloaded and verified
Files already downloaded and verified


In [5]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
alexnet = torch.hub.load('pytorch/vision:v0.6.0', 'alexnet', pretrained=True)

alexnet.classifier[4] = nn.Linear(4096,1024)
alexnet.classifier[6] = nn.Linear(1024,10)
if USE_DT_FILE:
    alexnet.load_state_dict(torch.load('model_20240531_040442_DT_Quantized_NEW',map_location=device))
else:
    alexnet.load_state_dict(torch.load('model_20240525_150139_final',map_location=device))
alexnet.eval()

Using cache found in C:\Users\Elijah/.cache\torch\hub\pytorch_vision_v0.6.0


AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
 

In [6]:
# # input : tensor output: binary string
# def binary(num):
#     return ''.join('{:0>8b}'.format(c) for c in struct.pack('!f', num))

# #input: mantissa bitstring
# #output: float value
# def calc_mantissa(mantissa):
#     res = 0
#     for k in range(len(mantissa)):
#         if mantissa[k] == '1':
#           res += 2**(-k-1)
#     return res

# #input exp: bitstring
# # new_exp_len: new length of exp
# def calc_exp(exp, new_exp_len):
#     limit = 2**(new_exp_len) - 1
#     # if exp is more than new_exp_len limit, truncate to new_exp_len limit.
#     bias = 2**(len(exp)-1) - 1
#     val = int(exp,2) - bias
#     if val > limit:
#         return limit
#     if val < -limit + 1:
#         return -limit + 1

#     return val

# def round_fp8(x, exp = 4):
#   '''
#   Quantizes input tensor to FP8 data format
#   inputs  x:      original tensor
#           exp:    number of bits used for exponent field
#                   e.g. E5M2 has 5 exp bits, E4M3 has 4 exp bits
#   output  x_32:   quantized tensor
#   '''

#   x_fp8 = x.clone().to(torch.float32)

#   for i in range(len(x_fp8)):
#     for j in range(len(x_fp8[i])):
#       result = 1.0
#       bin_str = binary(x[i][j])

#       bin_mantissa = bin_str[9:32]
#       res_mantissa = bin_mantissa[:7-exp]    
#       result += calc_mantissa(res_mantissa)

#       bin_exp = bin_str[1:9]
#       exp_int = calc_exp(bin_exp, exp)
#       result *= 2**exp_int

#       if bin_str[0] == '1':
#         result *= -1

#       x_fp8[i][j] = result    

#   return x_fp8.to(torch.float32)

# def dt_dequantize(quantized):
#     """
#     Custom dequantization function.
#     Assumes quantized values are in the custom format.
#     """
#     dequantized = torch.zeros_like(quantized, dtype=torch.int32)

#     for i, qval in enumerate(quantized.view(-1)):
#         # Extract the bits
#         sign_bit = (qval >> 7) & 1
#         exp_bits = (qval >> 3) & 0xF
#         bis_flag = (qval >> 2) & 1
#         bis_tree = qval & 0x3

#         # Compute the base value from the bisection tree bits
#         base_value = bis_tree / 4.0

#         # If bisection flag is set, adjust the base value
#         if bis_flag:
#             base_value += 0.5 / 4.0

#         # Compute the dequantized value using the exponent
#         value = base_value * (10 ** -exp_bits)


# #the input is normalized tensor x,
# def round_dt8(x, exp = 4):
#     x_dt8 = x.clone().to(torch.float32)
#     # print(f'DT quantizaiton; cloned shape: {x_dt8.shape}')
#     for i, row in enumerate(x_dt8):
#         # print(f'Iterating over rows: {x_dt8.shape}')
#         for j, val in enumerate(row):
#             # Determine sign bit
#             sign_bit = 0 if val >= 0 else 1

#             # Absolute value for further processing
#             abs_val = abs(val)
#                     # Determine the exponent bits (4 bits)
#             exp_bits = 0
#             for k in range(15):
#                 if abs_val < 10**-k:
#                     exp_bits = k - 1
#                     break

#             # Determine the bisection tree flag and binary bisection tree bits (3 bits)
#             bis_flag = 1 if abs_val % (10**-exp_bits) != 0 else 0
#             bis_tree = int((abs_val / (10**-exp_bits)) * 4) % 4

#             # Combine the bits
#             row[j] = (sign_bit << 7) | (exp_bits << 3) | (bis_flag << 2) | bis_tree
#             # row[j] = dt_dequantize(row[j])

#     return x_dt8

# def quantize_rowwise(x: torch.Tensor, dt = False):
#     abso = torch.abs(x)
#     output_maxs  = torch.max(abso,1)[0].unsqueeze(-1)
#     output = x[0]  / output_maxs[0,None] # What is this doing? Why x[0]?
#     if not dt:
#         output = round_fp8(output)
#     else:
#         output = round_dt8(output)
#     return output, output_maxs

# def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor):
#     output = x * state_x
#     return output

# def init_weights(m):
#     if isinstance(m, nn.Linear):
#         init.normal_(m.weight, mean=0.0, std=1.0)
#         if m.bias is not None:
#             init.zeros_(m.bias)

In [7]:

import torch
import torch.nn as nn
import math
import copy
import random
import numpy as np
import struct
import torch.nn.init as init

# input : tensor output: binary string
def binary(num):
  return ''.join('{:0>8b}'.format(c) for c in struct.pack('!f', num))

#input: mantissa bitstring
#output: float value
def calc_mantissa(mantissa):
  res = 0
  for k in range(len(mantissa)):
      if mantissa[k] == '1':
        res += 2**(-k-1)
  return res

#input exp: bitstring
# new_exp_len: new length of exp
def calc_exp(exp, new_exp_len):
  limit = 2**(new_exp_len) - 1
  # if exp is more than new_exp_len limit, truncate to new_exp_len limit.
  bias = 2**(len(exp)-1) - 1
  val = int(exp,2) - bias
  if val > limit:
      return limit
  if val < -limit + 1:
      return -limit + 1
      pass
  return val


def round_fp8(x, exp = 4):
  '''
  Quantizes input tensor to FP8 data format
  inputs  x:      original tensor
          exp:    number of bits used for exponent field
                  e.g. E5M2 has 5 exp bits, E4M3 has 4 exp bits
  output  x_32:   quantized tensor
  '''

  x_fp8 = copy.deepcopy(x)


  result = 1.0
  bin_str = binary(x)

  bin_mantissa = bin_str[9:32]
  res_mantissa = bin_mantissa[:7-exp]
  result += calc_mantissa(res_mantissa)

  bin_exp = bin_str[1:9]
  exp_int = calc_exp(bin_exp, exp)
  result *= 2**exp_int

  if bin_str[0] == '1':
    result *= -1

  return result


def bisection_quantization(num, bits = 7):
  val = abs(num)
  inversed_bits = []
  # Bisection tree quantization
  range_min, range_max = 0, 1
  for _ in range(bits):
      mid = (range_min + range_max) / 2
      if val >= mid:
          inversed_bits.append(1)
          range_min = mid
      else:
          inversed_bits.append(0)
          range_max = mid

  quantized_val = 0
  for k, bit in enumerate(inversed_bits):
      if bit:
          quantized_val += 2**-(k + 1)

  return quantized_val

#the input is normalized tensor x,
def round_dt8(x, exp = 4):
  val = copy.deepcopy(x)
  num_levels = 2 ** (7)


  sign_bit = 0 if val >= 0 else 1
  val = abs(val)
  exp_bits = 0
  while val < 0.1:
      val *= 10
      exp_bits += 1

  bs_bits = max(0, 6 - exp_bits)
  exp_bits = min(7, exp_bits)

  if exp_bits == 0:
      quantized_val = bisection_quantization(val, 7)
  elif exp_bits >= 6:
      quantized_val = val
  else:
      quantized_val = bisection_quantization(val, bs_bits)

  quantized_val = quantized_val if sign_bit == 0 else -quantized_val
  quantized_val *= 10**(-exp_bits)
  return quantized_val


def quantize_rowwise(x: torch.Tensor, dt = False):
  abso = torch.abs(x)
  output_maxs  = torch.max(abso,1)[0].unsqueeze(-1)
  output = x  / output_maxs[None,:]
  if not dt:
      output.apply_(round_fp8)
  else:
      output.apply_(round_dt8)
  return torch.squeeze(output), output_maxs

def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor):
  output = x * state_x
  return output

def init_weights(m, in_std = 1.0):
  if isinstance(m, nn.Linear):
      init.normal_(m.weight, mean=0.0, std=in_std)
      if m.bias is not None:
          init.zeros_(m.bias)

def measure_quantization_error(original_tensor, dequantized_tensor):
  abs_error = torch.abs(original_tensor - dequantized_tensor)
  return torch.mean(abs_error), abs_error

def quantize_stable_embedding(x, batch_size, dt = False):
  if (x.numel() % batch_size != 0):
    print("Invalid batch size. Batch size should be a divisor of " + str(x.numel()))
    return 

  flatarg = torch.argsort(x.flatten())
  indexing = flatarg.reshape((x.numel()//batch_size,batch_size))

  reshapedx = x.flatten()[indexing]
  output, maxes = quantize_rowwise(reshapedx,dt)

  return output.reshape(x.shape), maxes, indexing

def dequantize_stable_embedding(input, maxes, indexing):
  outreshape = input.reshape(indexing.shape)

  dequant = dequantize_rowwise(outreshape, maxes)
  return dequant.reshape(input.shape)

In [8]:
def quantize_dequantize_dt(mat):
    # return quantize_rowwise(mat, dt = True)[0]
    testing_dt, dt_max = quantize_rowwise(mat, dt = True)
    return dequantize_rowwise(testing_dt,dt_max)

In [37]:
if not DEBUG:
    count = 0
    for layer in [*alexnet.features,*alexnet.classifier]:
        count += 1
        try:
            if len(layer.weight.shape) == 4:
                weights = layer.weight.detach()
                print(f'Layer {count}')# weights shape pre-quantization: {weights.shape}\nWeights: {weights}')
                for filter in range(0, weights.shape[0]):
                    # print(f'Filter num {filter}')
                    for channel in range(0, weights.shape[1]):
                        # print(f'Channel num {channel}')
                        # print(layer.weight[filter,channel])
                        weights[filter,channel] = quantize_dequantize_dt(weights[filter,channel])
                        # for row in range(0,weights.shape[2]):
                        #     weights[filter,channel, row] = quantize_dequantize_dt(weights[filter,channel,row])
                        # print(f'Finish window')
                # print(f'Layer {count} weights shape post-quantization: {weights.shape}\nWeights: {weights}')
                # layer.weight = nn.parameter.Parameter(weights)
                print(f'Layer {count} weights shape post-quantization: {weights.shape}\nWeights: {weights}')
                layer.weight = nn.parameter.Parameter(weights)
            else:
                weights = layer.weight.detach()
                print(f'Layer {count}')# weights shape pre-quantization: {layer.weight.shape}\nWeights: {weights}')
                weights = quantize_dequantize_dt(weights)
                # for row in tqdm(range(0,weights.shape[0])):
                #     weights[row] = quantize_dequantize_dt(weights[row])
                layer.weight = nn.parameter.Parameter(weights)
                # print(f'Layer {count} weights shape post-quantization: {layer.weight.shape}\nWeights: {weights}')
                # print(layer.weight)
        except (TypeError, AttributeError):
            pass
else:
    count = 0
    for layer in alexnet.classifier:
        count += 1
        try:
            if len(layer.weight.shape) == 4:
                weights = layer.weight.detach()
                print(f'Layer {count} weights shape pre-quantization: {weights.shape}')
                
                for filter in range(0, layer.weight.shape[0]):
                    for channel in range(0, layer.weight.shape[1]):
                        # print(layer.weight[filter,channel])
                        weights[filter,channel] = quantize_dequantize_dt(weights[filter,channel])
                print(f'Layer {count} weights shape post-quantization: {weights.shape}')
                layer.weight = nn.parameter.Parameter(weights)
            else:
                print(f'In else loop')
                print(f'Layer {count} weights shape pre-quantization: {layer.weight.shape}')
                intermediate = quantize_dequantize_dt(layer.weight.detach())
                print(intermediate.shape)
                layer.weight = nn.parameter.Parameter(intermediate)
                print(f'Layer {count} weights shape post-quantization: {layer.weight.shape}')
                # print(layer.weight)
        except (TypeError, AttributeError):
            pass

# for layer in [*alexnet.features,*alexnet.classifier]:
#     try:
#         if len(layer.weight.shape) == 4:
#         for filter in range(0, layer.weight.shape[0]):
#             for channel in 
#         layer.weight = nn.parameter.Parameter(quantize_dequantize_dt(layer.weight.detach()))
#     except (TypeError, AttributeError):
#         print(layer)

# for layer in [*alexnet.features,*alexnet.classifier]:
#     try:
#         if len(layer.weight.shape) == 4:
#         for filter in range(0, layer.weight.shape[0]):
#             for channel in 
#         layer.weight = nn.parameter.Parameter(quantize_dequantize_dt(layer.weight.detach()))
#     except (TypeError, AttributeError):
#         print(layer)


Layer 1
Layer 1 weights shape post-quantization: torch.Size([64, 3, 11, 11])
Weights: tensor([[[[ 1.9580e-01,  1.5263e-01,  1.2334e-01,  ...,  5.0877e-02,
            1.8501e-02,  4.7793e-02],
          [ 1.3389e-01,  8.2233e-02,  6.8528e-02,  ...,  2.2140e-02,
           -1.1386e-02,  6.7473e-03],
          [ 1.0450e-01,  5.7597e-02,  4.7723e-02,  ...,  3.1267e-02,
            3.9495e-03,  9.5447e-03],
          ...,
          [ 8.5936e-02,  1.0503e-01,  5.5699e-02,  ..., -2.0211e-01,
           -1.3050e-01, -1.2095e-01],
          [ 3.3525e-02,  6.7051e-02,  2.7938e-02,  ..., -2.0115e-01,
           -1.1548e-01, -1.1920e-01],
          [ 3.8063e-02,  6.7244e-02,  1.9031e-02,  ..., -1.2434e-01,
           -1.0911e-01, -1.0911e-01]],

         [[-8.0944e-02, -7.6126e-02, -1.2238e-01,  ...,  8.0944e-03,
           -1.5418e-03,  5.9744e-02],
          [-9.2535e-02, -9.5488e-02, -1.2502e-01,  ...,  1.4766e-02,
            5.1189e-03,  5.8080e-02],
          [-1.4660e-01, -1.3327e-01, -1.6

In [53]:
if Test:
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)
    alexnet.to(device)

cuda:0


In [47]:
# DT QUANTIZATION
if Test:
    #Testing Accuracy
    correct = 0
    total = 0
    with torch.no_grad():
        for data in tqdm(testloader):
            images, labels = data[0].to(device), data[1].to(device)
            outputs = alexnet(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    print(f'Accuracy of the network on the 10000 test images: {100 * correct / total} %')

Accuracy of the network on the 10000 test images: 82.22 %


In [54]:
# NO QUANTIZATION
if Test:
    #Testing Accuracy
    correct = 0
    total = 0
    with torch.no_grad():
        for data in tqdm(testloader):
            images, labels = data[0].to(device), data[1].to(device)
            outputs = alexnet(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    print(f'Accuracy of the network on the 10000 test images: {100 * correct / total} %')

100%|██████████| 2500/2500 [00:16<00:00, 150.41it/s]

Accuracy of the network on the 10000 test images: 82.23 %





In [38]:
timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')

In [39]:
model_path = 'model_{}_{}'.format(timestamp, "DT_Quantized_NEW")
torch.save(alexnet.state_dict(), model_path)