In [132]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import datetime
from matplotlib import pyplot as plt

In [133]:
DEBUG = False

In [134]:
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [135]:
train_data = torchvision.datasets.CIFAR10(root='../data', train=True, download=True, transform=transform)
test_data = torchvision.datasets.CIFAR10(root='../data', train=False, download=True, transform=transform)

trainloader = torch.utils.data.DataLoader(train_data, batch_size=4, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(test_data, batch_size=4, shuffle=False, num_workers=2)

classes = ('Airplane', 'Car', 'Bird', 'Cat', 'Deer', 'Dog', 'Frog', 'Horse', 'Ship', 'Truck')

Files already downloaded and verified
Files already downloaded and verified


In [136]:
alexnet = torch.hub.load('pytorch/vision:v0.10.0', 'alexnet', pretrained=True)

alexnet.classifier[4] = nn.Linear(4096,1024)
alexnet.classifier[6] = nn.Linear(1024,10)

alexnet.load_state_dict(torch.load('model_20240525_150139_final'))
alexnet.eval()

Using cache found in C:\Users\Elijah/.cache\torch\hub\pytorch_vision_v0.10.0


AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
 

In [137]:
# input : tensor output: binary string
def binary(num):
    return ''.join('{:0>8b}'.format(c) for c in struct.pack('!f', num))

#input: mantissa bitstring
#output: float value
def calc_mantissa(mantissa):
    res = 0
    for k in range(len(mantissa)):
        if mantissa[k] == '1':
          res += 2**(-k-1)
    return res

#input exp: bitstring
# new_exp_len: new length of exp
def calc_exp(exp, new_exp_len):
    limit = 2**(new_exp_len) - 1
    # if exp is more than new_exp_len limit, truncate to new_exp_len limit.
    bias = 2**(len(exp)-1) - 1
    val = int(exp,2) - bias
    if val > limit:
        return limit
    if val < -limit + 1:
        return -limit + 1

    return val

def round_fp8(x, exp = 4):
  '''
  Quantizes input tensor to FP8 data format
  inputs  x:      original tensor
          exp:    number of bits used for exponent field
                  e.g. E5M2 has 5 exp bits, E4M3 has 4 exp bits
  output  x_32:   quantized tensor
  '''

  x_fp8 = x.clone().to(torch.float32)

  for i in range(len(x_fp8)):
    for j in range(len(x_fp8[i])):
      result = 1.0
      bin_str = binary(x[i][j])

      bin_mantissa = bin_str[9:32]
      res_mantissa = bin_mantissa[:7-exp]    
      result += calc_mantissa(res_mantissa)

      bin_exp = bin_str[1:9]
      exp_int = calc_exp(bin_exp, exp)
      result *= 2**exp_int

      if bin_str[0] == '1':
        result *= -1

      x_fp8[i][j] = result    

  return x_fp8.to(torch.float32)

def dt_dequantize(quantized):
    """
    Custom dequantization function.
    Assumes quantized values are in the custom format.
    """
    dequantized = torch.zeros_like(quantized, dtype=torch.int32)

    for i, qval in enumerate(quantized.view(-1)):
        # Extract the bits
        sign_bit = (qval >> 7) & 1
        exp_bits = (qval >> 3) & 0xF
        bis_flag = (qval >> 2) & 1
        bis_tree = qval & 0x3

        # Compute the base value from the bisection tree bits
        base_value = bis_tree / 4.0

        # If bisection flag is set, adjust the base value
        if bis_flag:
            base_value += 0.5 / 4.0

        # Compute the dequantized value using the exponent
        value = base_value * (10 ** -exp_bits)


#the input is normalized tensor x,
def round_dt8(x, exp = 4):
    x_dt8 = x.clone().to(torch.float32)
    # print(f'DT quantizaiton; cloned shape: {x_dt8.shape}')
    for i, row in enumerate(x_dt8):
        # print(f'Iterating over rows: {x_dt8.shape}')
        for j, val in enumerate(row):
            # Determine sign bit
            sign_bit = 0 if val >= 0 else 1

            # Absolute value for further processing
            abs_val = abs(val)
                    # Determine the exponent bits (4 bits)
            exp_bits = 0
            for k in range(15):
                if abs_val < 10**-k:
                    exp_bits = k - 1
                    break

            # Determine the bisection tree flag and binary bisection tree bits (3 bits)
            bis_flag = 1 if abs_val % (10**-exp_bits) != 0 else 0
            bis_tree = int((abs_val / (10**-exp_bits)) * 4) % 4

            # Combine the bits
            row[j] = (sign_bit << 7) | (exp_bits << 3) | (bis_flag << 2) | bis_tree
            # row[j] = dt_dequantize(row[j])

    return x_dt8

def quantize_rowwise(x: torch.Tensor, dt = False):
    abso = torch.abs(x)
    output_maxs  = torch.max(abso,1)[0].unsqueeze(-1)
    output = x[0]  / output_maxs[0,None] # What is this doing? Why x[0]?
    if not dt:
        output = round_fp8(output)
    else:
        output = round_dt8(output)
    return output, output_maxs

def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor):
    output = x * state_x
    return output

def init_weights(m):
    if isinstance(m, nn.Linear):
        init.normal_(m.weight, mean=0.0, std=1.0)
        if m.bias is not None:
            init.zeros_(m.bias)

In [138]:
def quantize_dequantize_dt(mat):
    # return quantize_rowwise(mat, dt = True)[0]
    testing_dt, dt_max = quantize_rowwise(mat, dt = True)
    return dequantize_rowwise(testing_dt,dt_max)

In [139]:
if not DEBUG:
    count = 0
    for layer in [*alexnet.features,*alexnet.classifier]:
        count += 1
        try:
            if len(layer.weight.shape) == 4:
                weights = layer.weight.detach()
                print(f'Layer {count} weights shape pre-quantization: {weights.shape}\nWeights: {weights}')
                for filter in range(0, layer.weight.shape[0]):
                    for channel in range(0, layer.weight.shape[1]):
                        # print(layer.weight[filter,channel])
                        weights[filter,channel] = quantize_dequantize_dt(weights[filter,channel])
                print(f'Layer {count} weights shape post-quantization: {weights.shape}\nWeights: {weights}')
                layer.weight = nn.parameter.Parameter(weights)
            else:
                print(f'Layer {count} weights shape pre-quantization: {layer.weight.shape}\nWeights: {weights}')
                layer.weight = nn.parameter.Parameter(quantize_dequantize_dt(layer.weight.detach()))
                print(f'Layer {count} weights shape post-quantization: {layer.weight.shape}\nWeights: {weights}')
                # print(layer.weight)
        except (TypeError, AttributeError):
            pass
else:
    count = 0
    for layer in alexnet.classifier:
        count += 1
        try:
            if len(layer.weight.shape) == 4:
                weights = layer.weight.detach()
                print(f'Layer {count} weights shape pre-quantization: {weights.shape}')
                
                for filter in range(0, layer.weight.shape[0]):
                    for channel in range(0, layer.weight.shape[1]):
                        # print(layer.weight[filter,channel])
                        weights[filter,channel] = quantize_dequantize_dt(weights[filter,channel])
                print(f'Layer {count} weights shape post-quantization: {weights.shape}')
                layer.weight = nn.parameter.Parameter(weights)
            else:
                print(f'In else loop')
                print(f'Layer {count} weights shape pre-quantization: {layer.weight.shape}')
                intermediate = quantize_dequantize_dt(layer.weight.detach())
                print(intermediate.shape)
                layer.weight = nn.parameter.Parameter(intermediate)
                print(f'Layer {count} weights shape post-quantization: {layer.weight.shape}')
                # print(layer.weight)
        except (TypeError, AttributeError):
            pass

# for layer in [*alexnet.features,*alexnet.classifier]:
#     try:
#         if len(layer.weight.shape) == 4:
#         for filter in range(0, layer.weight.shape[0]):
#             for channel in 
#         layer.weight = nn.parameter.Parameter(quantize_dequantize_dt(layer.weight.detach()))
#     except (TypeError, AttributeError):
#         print(layer)

# for layer in [*alexnet.features,*alexnet.classifier]:
#     try:
#         if len(layer.weight.shape) == 4:
#         for filter in range(0, layer.weight.shape[0]):
#             for channel in 
#         layer.weight = nn.parameter.Parameter(quantize_dequantize_dt(layer.weight.detach()))
#     except (TypeError, AttributeError):
#         print(layer)


Layer 1 weights shape pre-quantization: torch.Size([64, 3, 11, 11])
Weights: tensor([[[[ 1.9734e-01,  1.5311e-01,  1.2346e-01,  ...,  5.1259e-02,
            1.8628e-02,  4.9283e-02],
          [ 1.3495e-01,  8.2730e-02,  6.8725e-02,  ...,  2.2493e-02,
           -1.1554e-02,  7.1114e-03],
          [ 1.0532e-01,  5.7741e-02,  4.7859e-02,  ...,  3.1955e-02,
            4.0302e-03,  9.6770e-03],
          ...,
          [ 8.5962e-02,  1.0595e-01,  5.6914e-02,  ..., -2.0370e-01,
           -1.3139e-01, -1.2237e-01],
          [ 3.4958e-02,  6.8188e-02,  2.8633e-02,  ..., -2.0155e-01,
           -1.1676e-01, -1.2000e-01],
          [ 3.9003e-02,  6.7265e-02,  2.0207e-02,  ..., -1.2560e-01,
           -1.0937e-01, -1.0958e-01]],

         [[-8.0962e-02, -7.6534e-02, -1.2334e-01,  ...,  8.2895e-03,
           -1.7828e-03,  6.0680e-02],
          [-9.2570e-02, -9.6053e-02, -1.2600e-01,  ...,  1.5380e-02,
            5.2623e-03,  5.8764e-02],
          [-1.4768e-01, -1.3341e-01, -1.7059e-01, 

In [140]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
alexnet.to(device)

cuda:0


AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
 

In [143]:
#Testing Accuracy
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data[0].to(device), data[1].to(device)
        outputs = alexnet(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the network on the 10000 test images: {100 * correct / total} %')

Accuracy of the network on the 10000 test images: 10.0 %


In [None]:
timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')

In [None]:
model_path = 'model_{}_{}'.format(timestamp, "DT_Quantized")
torch.save(alexnet.state_dict(), model_path)

---

In [52]:
print(f'Alexnet Layer: {alexnet.features[0]}\nShape: {alexnet.features[0].weight.shape}')

Alexnet Layer: Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
Shape: torch.Size([64, 3, 11, 11])


In [36]:
alexnet.features[0].weight[0][0][2]

tensor([0.1053, 0.0577, 0.0479, 0.0451, 0.0370, 0.0314, 0.0281, 0.0106, 0.0320,
        0.0040, 0.0097], grad_fn=<SelectBackward0>)

# DT Quantization Testing

In [120]:
# input : tensor output: binary string
def binary(num):
    return ''.join('{:0>8b}'.format(c) for c in struct.pack('!f', num))

#input: mantissa bitstring
#output: float value
def calc_mantissa(mantissa):
    res = 0
    for k in range(len(mantissa)):
        if mantissa[k] == '1':
          res += 2**(-k-1)
    return res

#input exp: bitstring
# new_exp_len: new length of exp
def calc_exp(exp, new_exp_len):
    limit = 2**(new_exp_len) - 1
    # if exp is more than new_exp_len limit, truncate to new_exp_len limit.
    bias = 2**(len(exp)-1) - 1
    val = int(exp,2) - bias
    if val > limit:
        return limit
    if val < -limit + 1:
        return -limit + 1

    return val

def round_fp8(x, exp = 4):
  '''
  Quantizes input tensor to FP8 data format
  inputs  x:      original tensor
          exp:    number of bits used for exponent field
                  e.g. E5M2 has 5 exp bits, E4M3 has 4 exp bits
  output  x_32:   quantized tensor
  '''

  x_fp8 = x.clone().to(torch.float32)

  for i in range(len(x_fp8)):
    for j in range(len(x_fp8[i])):
      result = 1.0
      bin_str = binary(x[i][j])

      bin_mantissa = bin_str[9:32]
      res_mantissa = bin_mantissa[:7-exp]    
      result += calc_mantissa(res_mantissa)

      bin_exp = bin_str[1:9]
      exp_int = calc_exp(bin_exp, exp)
      result *= 2**exp_int

      if bin_str[0] == '1':
        result *= -1

      x_fp8[i][j] = result    

  return x_fp8.to(torch.float32)

def dt_dequantize(quantized):
    """
    Custom dequantization function.
    Assumes quantized values are in the custom format.
    """
    dequantized = torch.zeros_like(quantized, dtype=torch.int32)

    for i, qval in enumerate(quantized.view(-1)):
        # Extract the bits
        sign_bit = (qval >> 7) & 1
        exp_bits = (qval >> 3) & 0xF
        bis_flag = (qval >> 2) & 1
        bis_tree = qval & 0x3

        # Compute the base value from the bisection tree bits
        base_value = bis_tree / 4.0

        # If bisection flag is set, adjust the base value
        if bis_flag:
            base_value += 0.5 / 4.0

        # Compute the dequantized value using the exponent
        value = base_value * (10 ** -exp_bits)


#the input is normalized tensor x,
def round_dt8(x, exp = 4):
    x_dt8 = x.clone().to(torch.float32)
    for i, row in enumerate(x_dt8):
        for j, val in enumerate(row):
            # Determine sign bit
            sign_bit = 0 if val >= 0 else 1

            # Absolute value for further processing
            abs_val = abs(val)
                    # Determine the exponent bits (4 bits)
            exp_bits = 0
            for k in range(15):
                if abs_val < 10**-k:
                    exp_bits = k - 1
                    break

            # Determine the bisection tree flag and binary bisection tree bits (3 bits)
            bis_flag = 1 if abs_val % (10**-exp_bits) != 0 else 0
            bis_tree = int((abs_val / (10**-exp_bits)) * 4) % 4

            # Combine the bits
            row[j] = (sign_bit << 7) | (exp_bits << 3) | (bis_flag << 2) | bis_tree
            # row[j] = dt_dequantize(row[j])

    return x_dt8

def quantize_rowwise(x: torch.Tensor, dt = False):
    abso = torch.abs(x)
    output_maxs  = torch.max(abso,1)[0].unsqueeze(-1)
    output = x[0]  / output_maxs[0,None]
    if not dt:
        output = round_fp8(output)
    else:
        output = round_dt8(output)
    return output, output_maxs

def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor):
    output = x * state_x
    return output

In [121]:
def quantize_dequantize_dt(mat):
    # return quantize_rowwise(mat, dt = True)[0]
    testing_dt, dt_max = quantize_rowwise(mat, dt = True)
    return dequantize_rowwise(testing_dt,dt_max)

3 channels, 2x3 

In [122]:
conv1 = nn.Conv2d(3, 2, 3)
weight = conv1.weight.data.numpy()
print(weight[0,0])
print(f'Shape: {weight[0,0].shape}')

[[ 0.17571832 -0.15971728 -0.10938939]
 [ 0.08544098  0.02062272  0.17651452]
 [-0.09197273  0.1612549   0.02800082]]
Shape: (3, 3)


In [123]:
# Before Quantization
conv1.weight

Parameter containing:
tensor([[[[ 0.1757, -0.1597, -0.1094],
          [ 0.0854,  0.0206,  0.1765],
          [-0.0920,  0.1613,  0.0280]],

         [[ 0.1043,  0.1438,  0.1659],
          [-0.1101, -0.0866, -0.1302],
          [-0.1267, -0.1905,  0.1319]],

         [[-0.0784, -0.0221,  0.0483],
          [ 0.1775,  0.1904,  0.0567],
          [ 0.1777,  0.1157,  0.0988]]],


        [[[-0.0899,  0.1893,  0.0966],
          [-0.1393,  0.1850, -0.0379],
          [-0.0261, -0.1160, -0.1630]],

         [[ 0.0714,  0.0714,  0.0970],
          [-0.0463,  0.0269,  0.0344],
          [ 0.1601,  0.1483, -0.0044]],

         [[-0.0080,  0.1798,  0.1078],
          [ 0.0017,  0.1502,  0.0239],
          [ 0.1659,  0.1358,  0.1605]]]], requires_grad=True)

In [124]:
weights = conv1.weight.detach()
print(f'Layer {count} weights shape pre-quantization: {weights.shape}\nWeights: {weights}')
for filter in range(0, weights.shape[0]):
    print(f'Filter num {filter}')
    for channel in range(0, weights.shape[1]):
        print(f'Channel num {channel}')
        # print(layer.weight[filter,channel])
        weights[filter,channel] = quantize_dequantize_dt(weights[filter,channel])
        print(f'Finish window')
print(f'Layer {count} weights shape post-quantization: {weights.shape}\nWeights: {weights}')
# layer.weight = nn.parameter.Parameter(weights)

Layer 20 weights shape pre-quantization: torch.Size([2, 3, 3, 3])
Weights: tensor([[[[ 0.1757, -0.1597, -0.1094],
          [ 0.0854,  0.0206,  0.1765],
          [-0.0920,  0.1613,  0.0280]],

         [[ 0.1043,  0.1438,  0.1659],
          [-0.1101, -0.0866, -0.1302],
          [-0.1267, -0.1905,  0.1319]],

         [[-0.0784, -0.0221,  0.0483],
          [ 0.1775,  0.1904,  0.0567],
          [ 0.1777,  0.1157,  0.0988]]],


        [[[-0.0899,  0.1893,  0.0966],
          [-0.1393,  0.1850, -0.0379],
          [-0.0261, -0.1160, -0.1630]],

         [[ 0.0714,  0.0714,  0.0970],
          [-0.0463,  0.0269,  0.0344],
          [ 0.1601,  0.1483, -0.0044]],

         [[-0.0080,  0.1798,  0.1078],
          [ 0.0017,  0.1502,  0.0239],
          [ 0.1659,  0.1358,  0.1605]]]])
Filter num 0
Channel num 0
Finish window
Channel num 1
Finish window
Channel num 2
Finish window
Filter num 1
Channel num 0
Finish window
Channel num 1
Finish window
Channel num 2
Finish window
Layer 20 weigh

In [125]:
weights

tensor([[[[ 0.0000, -0.7029, -0.7029],
          [ 0.0000, -0.7061, -0.7061],
          [ 0.0000, -0.6450, -0.6450]],

         [[-0.6637, -0.6637,  0.0000],
          [-0.5210, -0.5210,  0.0000],
          [-0.7622, -0.7622,  0.0000]],

         [[10.0323, -0.3135, -0.3135],
          [24.3759, -0.7617, -0.7617],
          [22.7433, -0.7107, -0.7107]]],


        [[[-0.7570,  0.0000, -0.7570],
          [-0.7401,  0.0000, -0.7401],
          [-0.6519,  0.0000, -0.6519]],

         [[-0.3879, -0.3879,  0.0000],
          [-0.1850, -0.1850,  0.0000],
          [-0.6405, -0.6405,  0.0000]],

         [[-0.7193,  0.0000, -0.7193],
          [-0.6006,  0.0000, -0.6006],
          [-0.6637,  0.0000, -0.6637]]]])

In [82]:
# layer = conv1
# weights = 
# if len(layer.weight.shape) == 4:
#     for f in range(0, layer.weight.shape[0]):
#         for channel in range(0, layer.weight.shape[1]):
#             # print(layer.weight[filter,channel])
#             layer.weight[f,channel] = quantize_dequantize_dt(layer.weight.detach()[f,channel])

SyntaxError: invalid syntax (3047544188.py, line 2)

3 channels, 5x9

In [126]:
conv1 = nn.Conv2d(3, 5, 9)
weight = conv1.weight.data.numpy()
print(weight[0,0])
print(f'Shape: {weight[0,0].shape}')

[[ 0.02876318  0.0298016  -0.01580995  0.04515649 -0.04235711  0.0548744
  -0.05962999 -0.03796124 -0.03850383]
 [-0.06024191  0.00939912  0.05952165  0.05664774  0.04435438  0.00022057
   0.061092    0.03485528 -0.02440702]
 [-0.02034292  0.0024887  -0.04590605  0.002611   -0.02906006  0.058422
   0.01825075 -0.01882584  0.02503988]
 [-0.06075707  0.04262581 -0.06081435  0.03414683 -0.01653043 -0.02937788
  -0.0612909  -0.02942312  0.05791073]
 [-0.04128146 -0.01805271 -0.04751305 -0.01434946  0.00777362 -0.0545434
   0.0589461  -0.02584737 -0.01401746]
 [-0.01357341 -0.03338958  0.00251536 -0.02421685  0.00203022  0.04923867
  -0.04494392  0.03626681  0.04377199]
 [ 0.01569144 -0.01332225 -0.00418554  0.03602285 -0.05743553 -0.05675019
  -0.04123453  0.05785275  0.00239326]
 [-0.04667281  0.03680348  0.01815452  0.00047237 -0.00509175 -0.04635657
   0.02168845 -0.0196288  -0.01502234]
 [-0.01630999  0.00420251 -0.02627191 -0.01753236 -0.0574376  -0.05961076
  -0.01812495  0.05866265 

In [127]:
# Before Quantization
conv1.weight

Parameter containing:
tensor([[[[ 0.0288,  0.0298, -0.0158,  ..., -0.0596, -0.0380, -0.0385],
          [-0.0602,  0.0094,  0.0595,  ...,  0.0611,  0.0349, -0.0244],
          [-0.0203,  0.0025, -0.0459,  ...,  0.0183, -0.0188,  0.0250],
          ...,
          [ 0.0157, -0.0133, -0.0042,  ..., -0.0412,  0.0579,  0.0024],
          [-0.0467,  0.0368,  0.0182,  ...,  0.0217, -0.0196, -0.0150],
          [-0.0163,  0.0042, -0.0263,  ..., -0.0181,  0.0587,  0.0509]],

         [[-0.0420,  0.0270,  0.0108,  ...,  0.0097,  0.0247,  0.0512],
          [ 0.0131, -0.0629, -0.0496,  ...,  0.0190, -0.0560, -0.0481],
          [ 0.0287,  0.0480, -0.0156,  ..., -0.0171,  0.0079, -0.0079],
          ...,
          [-0.0318,  0.0459, -0.0144,  ...,  0.0382,  0.0615,  0.0426],
          [-0.0407,  0.0190, -0.0202,  ...,  0.0210, -0.0015,  0.0070],
          [-0.0285,  0.0106,  0.0355,  ..., -0.0480,  0.0235, -0.0171]],

         [[-0.0195,  0.0617, -0.0195,  ...,  0.0253,  0.0071,  0.0015],
        

In [128]:
weights = conv1.weight.detach()
print(f'Layer {count} weights shape pre-quantization: {weights.shape}\nWeights: {weights}')
for filter in range(0, weights.shape[0]):
    print(f'Filter num {filter}')
    for channel in range(0, weights.shape[1]):
        print(f'Channel num {channel}')
        # print(layer.weight[filter,channel])
        weights[filter,channel] = quantize_dequantize_dt(weights[filter,channel])
        print(f'Finish window')
print(f'Layer {count} weights shape post-quantization: {weights.shape}\nWeights: {weights}')
# layer.weight = nn.parameter.Parameter(weights)weights = conv1.weight.detach()
print(f'Layer {count} weights shape pre-quantization: {weights.shape}\nWeights: {weights}')
for filter in range(0, weights.shape[0]):
    print(f'Filter num {filter}')
    for channel in range(0, weights.shape[1]):
        print(f'Channel num {channel}')
        # print(layer.weight[filter,channel])
        weights[filter,channel] = quantize_dequantize_dt(weights[filter,channel])
        print(f'Finish window')
print(f'Layer {count} weights shape post-quantization: {weights.shape}\nWeights: {weights}')
# layer.weight = nn.parameter.Parameter(weights)

Layer 20 weights shape pre-quantization: torch.Size([5, 3, 9, 9])
Weights: tensor([[[[ 0.0288,  0.0298, -0.0158,  ..., -0.0596, -0.0380, -0.0385],
          [-0.0602,  0.0094,  0.0595,  ...,  0.0611,  0.0349, -0.0244],
          [-0.0203,  0.0025, -0.0459,  ...,  0.0183, -0.0188,  0.0250],
          ...,
          [ 0.0157, -0.0133, -0.0042,  ..., -0.0412,  0.0579,  0.0024],
          [-0.0467,  0.0368,  0.0182,  ...,  0.0217, -0.0196, -0.0150],
          [-0.0163,  0.0042, -0.0263,  ..., -0.0181,  0.0587,  0.0509]],

         [[-0.0420,  0.0270,  0.0108,  ...,  0.0097,  0.0247,  0.0512],
          [ 0.0131, -0.0629, -0.0496,  ...,  0.0190, -0.0560, -0.0481],
          [ 0.0287,  0.0480, -0.0156,  ..., -0.0171,  0.0079, -0.0079],
          ...,
          [-0.0318,  0.0459, -0.0144,  ...,  0.0382,  0.0615,  0.0426],
          [-0.0407,  0.0190, -0.0202,  ...,  0.0210, -0.0015,  0.0070],
          [-0.0285,  0.0106,  0.0355,  ..., -0.0480,  0.0235, -0.0171]],

         [[-0.0195,  0.0617

In [129]:
weights

tensor([[[[-30.5306, -30.5306, -30.5306,  ...,   0.0000, -30.5306, -30.5306],
          [-31.2791, -31.2791, -31.2791,  ...,   0.0000, -31.2791, -31.2791],
          [-29.9121, -29.9121, -29.9121,  ...,   0.0000, -29.9121, -29.9121],
          ...,
          [-29.6206, -29.6206, -29.6206,  ...,   0.0000, -29.6206, -29.6206],
          [-23.8965, -23.8965, -23.8965,  ...,   0.0000, -23.8965, -23.8965],
          [-30.5207, -30.5207, -30.5207,  ...,   0.0000, -30.5207, -30.5207]],

         [[ 26.2096,  26.2096,  26.2096,  ...,  26.2096,  26.2096,  -1.6381],
          [ 32.2059,  32.2059,  32.2059,  ...,  32.2059,  32.2059,  -2.0129],
          [ 28.6300,  28.6300,  28.6300,  ...,  28.6300,  28.6300,  -1.7894],
          ...,
          [ 31.4729,  31.4729,  31.4729,  ...,  31.4729,  31.4729,  -1.9671],
          [ 20.8408,  20.8408,  20.8408,  ...,  20.8408,  20.8408,  -1.3025],
          [ 24.5560,  24.5560,  24.5560,  ...,  24.5560,  24.5560,  -1.5348]],

         [[ 31.5881,  -1.9743,

In [131]:
weights[0]

tensor([[[-30.5306, -30.5306, -30.5306, -30.5306, -30.5306, -30.5306,   0.0000,
          -30.5306, -30.5306],
         [-31.2791, -31.2791, -31.2791, -31.2791, -31.2791, -31.2791,   0.0000,
          -31.2791, -31.2791],
         [-29.9121, -29.9121, -29.9121, -29.9121, -29.9121, -29.9121,   0.0000,
          -29.9121, -29.9121],
         [-31.3809, -31.3809, -31.3809, -31.3809, -31.3809, -31.3809,   0.0000,
          -31.3809, -31.3809],
         [-30.1804, -30.1804, -30.1804, -30.1804, -30.1804, -30.1804,   0.0000,
          -30.1804, -30.1804],
         [-25.2102, -25.2102, -25.2102, -25.2102, -25.2102, -25.2102,   0.0000,
          -25.2102, -25.2102],
         [-29.6206, -29.6206, -29.6206, -29.6206, -29.6206, -29.6206,   0.0000,
          -29.6206, -29.6206],
         [-23.8965, -23.8965, -23.8965, -23.8965, -23.8965, -23.8965,   0.0000,
          -23.8965, -23.8965],
         [-30.5207, -30.5207, -30.5207, -30.5207, -30.5207, -30.5207,   0.0000,
          -30.5207, -30.5207]],


In [142]:
alexnet.features[0].weight[0]

tensor([[[ 0.0000, -0.7894, -0.7894, -0.7894, -0.7894, -0.7894, -0.7894,
          -0.7894, -0.7894, -0.7894, -0.7894],
         [ 0.0000, -0.5398, -0.5398, -0.5398, -0.5398, -0.5398, -0.5398,
          -0.5398, -0.5398, -0.5398, -0.5398],
         [ 0.0000, -0.4213, -0.4213, -0.4213, -0.4213, -0.4213, -0.4213,
          -0.4213, -0.4213, -0.4213, -0.4213],
         [ 0.0000, -0.3654, -0.3654, -0.3654, -0.3654, -0.3654, -0.3654,
          -0.3654, -0.3654, -0.3654, -0.3654],
         [ 0.0000, -0.3950, -0.3950, -0.3950, -0.3950, -0.3950, -0.3950,
          -0.3950, -0.3950, -0.3950, -0.3950],
         [ 0.0000, -0.4068, -0.4068, -0.4068, -0.4068, -0.4068, -0.4068,
          -0.4068, -0.4068, -0.4068, -0.4068],
         [ 0.0000, -0.5804, -0.5804, -0.5804, -0.5804, -0.5804, -0.5804,
          -0.5804, -0.5804, -0.5804, -0.5804],
         [ 0.0000, -0.6692, -0.6692, -0.6692, -0.6692, -0.6692, -0.6692,
          -0.6692, -0.6692, -0.6692, -0.6692],
         [ 0.0000, -0.8148, -0.8148, -0.

In [64]:
layer.weight.detach()[filter,channel]

tensor([[-0.0799,  0.0143,  0.0621],
        [-0.1410,  0.1376,  0.1556],
        [ 0.1617,  0.1489,  0.0450]])