In [1]:
import argparse
import os
import time
import shutil

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
     

import torchvision
import torchvision.transforms as transforms

from models import *


global best_prec
use_gpu = torch.cuda.is_available()
print('=> Building model...')
     
batch_size = 128
model_name = "Resnet20_quantAware"
model = resnet20_quant() # For quantization aware training

# print(model)


normalize = transforms.Normalize(mean=[0.491, 0.482, 0.447], std=[0.247, 0.243, 0.262])


train_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]))
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)


test_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ]))

testloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)


print_freq = 100 # every 100 batches, accuracy printed. Here, each batch includes "batch_size" data points
# CIFAR10 has 50,000 training data, and 10,000 validation data.

def train(trainloader, model, criterion, optimizer, epoch):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    model.train()

    end = time.time()
    for i, (input, target) in enumerate(trainloader):
        # measure data loading time
        data_time.update(time.time() - end)

        input, target = input.cuda(), target.cuda()

        # compute output
        output = model(input)
        loss = criterion(output, target)

        # measure accuracy and record loss
        prec = accuracy(output, target)[0]
        losses.update(loss.item(), input.size(0))
        top1.update(prec.item(), input.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()


        if i % print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec {top1.val:.3f}% ({top1.avg:.3f}%)'.format(
                   epoch, i, len(trainloader), batch_time=batch_time,
                   data_time=data_time, loss=losses, top1=top1))

            

def validate(val_loader, model, criterion ):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    with torch.no_grad():
        for i, (input, target) in enumerate(val_loader):
         
            input, target = input.cuda(), target.cuda()

            # compute output
            output = model(input)
            loss = criterion(output, target)

            # measure accuracy and record loss
            prec = accuracy(output, target)[0]
            losses.update(loss.item(), input.size(0))
            top1.update(prec.item(), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % print_freq == 0:  # This line shows how frequently print out the status. e.g., i%5 => every 5 batch, prints out
                print('Test: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec {top1.val:.3f}% ({top1.avg:.3f}%)'.format(
                   i, len(val_loader), batch_time=batch_time, loss=losses,
                   top1=top1))

    print(' * Prec {top1.avg:.3f}% '.format(top1=top1))
    return top1.avg


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

        
def save_checkpoint(state, is_best, fdir):
    filepath = os.path.join(fdir, 'checkpoint.pth')
    torch.save(state, filepath)
    if is_best:
        shutil.copyfile(filepath, os.path.join(fdir, 'model_best.pth.tar'))     

#model = nn.DataParallel(model).cuda()
#all_params = checkpoint['state_dict']
#model.load_state_dict(all_params, strict=False)
#criterion = nn.CrossEntropyLoss().cuda()
#validate(testloader, model, criterion)

=> Building model...
Files already downloaded and verified
Files already downloaded and verified


In [2]:
def adjust_learning_rate(optimizer, epoch):
    """For resnet, the lr starts from 0.1, and is divided by 10 at 80 and 120 epochs"""
    adjust_list = [20, 30, 150, 225]
    p_adjust_list = [25, 35]
    if epoch in adjust_list:
        for param_group in optimizer.param_groups:
            param_group['lr'] = param_group['lr'] * 0.1
    if epoch in p_adjust_list:
        for param_group in optimizer.param_groups:
            param_group['momentum'] = param_group['momentum'] / 2

In [3]:
# This cell won't be given, but students will complete the training

lr = 0.1
weight_decay = 1e-4
epochs = 100
best_prec = 0

#model = nn.DataParallel(model).cuda()
model.cuda()
criterion = nn.CrossEntropyLoss().cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)
#cudnn.benchmark = True

if not os.path.exists('result'):
    os.makedirs('result')
fdir = 'result/'+str(model_name)
if not os.path.exists(fdir):
    os.makedirs(fdir)
        

for epoch in range(0, epochs):
    adjust_learning_rate(optimizer, epoch)

    train(trainloader, model, criterion, optimizer, epoch)
    
    # evaluate on test set
    print("Validation starts")
    prec = validate(testloader, model, criterion)

    # remember best precision and save checkpoint
    is_best = prec > best_prec
    best_prec = max(prec,best_prec)
    print('best acc: {:1f}'.format(best_prec))
    save_checkpoint({
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'best_prec': best_prec,
        'optimizer': optimizer.state_dict(),
    }, is_best, fdir)

Epoch: [0][0/391]	Time 2.177 (2.177)	Data 0.378 (0.378)	Loss 2.4776 (2.4776)	Prec 10.938% (10.938%)
Epoch: [0][100/391]	Time 0.041 (0.067)	Data 0.002 (0.006)	Loss 1.6168 (1.9811)	Prec 39.062% (25.186%)
Epoch: [0][200/391]	Time 0.037 (0.056)	Data 0.002 (0.004)	Loss 1.7641 (1.8235)	Prec 31.250% (31.110%)
Epoch: [0][300/391]	Time 0.051 (0.052)	Data 0.002 (0.004)	Loss 1.3537 (1.7224)	Prec 50.000% (34.972%)
Validation starts
Test: [0/79]	Time 0.243 (0.243)	Loss 1.3392 (1.3392)	Prec 56.250% (56.250%)
 * Prec 51.000% 
best acc: 51.000000
Epoch: [1][0/391]	Time 0.488 (0.488)	Data 0.431 (0.431)	Loss 1.2776 (1.2776)	Prec 51.562% (51.562%)
Epoch: [1][100/391]	Time 0.038 (0.051)	Data 0.002 (0.007)	Loss 1.2367 (1.2494)	Prec 57.031% (54.680%)
Epoch: [1][200/391]	Time 0.044 (0.048)	Data 0.002 (0.005)	Loss 1.2049 (1.1962)	Prec 57.031% (56.736%)
Epoch: [1][300/391]	Time 0.039 (0.046)	Data 0.002 (0.004)	Loss 1.0800 (1.1602)	Prec 61.719% (58.153%)
Validation starts
Test: [0/79]	Time 0.452 (0.452)	Loss 1.

Epoch: [15][200/391]	Time 0.039 (0.050)	Data 0.002 (0.005)	Loss 0.3969 (0.4075)	Prec 85.938% (85.766%)
Epoch: [15][300/391]	Time 0.042 (0.048)	Data 0.002 (0.004)	Loss 0.3101 (0.4069)	Prec 85.156% (85.922%)
Validation starts
Test: [0/79]	Time 0.251 (0.251)	Loss 0.4592 (0.4592)	Prec 82.031% (82.031%)
 * Prec 81.250% 
best acc: 81.740000
Epoch: [16][0/391]	Time 0.621 (0.621)	Data 0.559 (0.559)	Loss 0.3243 (0.3243)	Prec 88.281% (88.281%)
Epoch: [16][100/391]	Time 0.034 (0.050)	Data 0.002 (0.008)	Loss 0.4085 (0.3775)	Prec 85.938% (86.680%)
Epoch: [16][200/391]	Time 0.047 (0.045)	Data 0.002 (0.005)	Loss 0.3601 (0.3861)	Prec 87.500% (86.419%)
Epoch: [16][300/391]	Time 0.042 (0.044)	Data 0.002 (0.004)	Loss 0.3129 (0.3925)	Prec 88.281% (86.176%)
Validation starts
Test: [0/79]	Time 0.256 (0.256)	Loss 0.4277 (0.4277)	Prec 85.156% (85.156%)
 * Prec 83.290% 
best acc: 83.290000
Epoch: [17][0/391]	Time 0.625 (0.625)	Data 0.567 (0.567)	Loss 0.3780 (0.3780)	Prec 87.500% (87.500%)
Epoch: [17][100/391]	

Epoch: [30][300/391]	Time 0.037 (0.045)	Data 0.002 (0.004)	Loss 0.1360 (0.2025)	Prec 95.312% (92.951%)
Validation starts
Test: [0/79]	Time 0.246 (0.246)	Loss 0.2292 (0.2292)	Prec 93.750% (93.750%)
 * Prec 89.050% 
best acc: 89.150000
Epoch: [31][0/391]	Time 0.525 (0.525)	Data 0.463 (0.463)	Loss 0.2170 (0.2170)	Prec 93.750% (93.750%)
Epoch: [31][100/391]	Time 0.043 (0.047)	Data 0.002 (0.007)	Loss 0.1706 (0.2033)	Prec 94.531% (92.930%)
Epoch: [31][200/391]	Time 0.043 (0.046)	Data 0.002 (0.005)	Loss 0.1137 (0.2013)	Prec 99.219% (92.996%)
Epoch: [31][300/391]	Time 0.041 (0.045)	Data 0.002 (0.004)	Loss 0.2414 (0.2012)	Prec 92.969% (92.987%)
Validation starts
Test: [0/79]	Time 0.373 (0.373)	Loss 0.2429 (0.2429)	Prec 92.188% (92.188%)
 * Prec 89.080% 
best acc: 89.150000
Epoch: [32][0/391]	Time 0.470 (0.470)	Data 0.407 (0.407)	Loss 0.1636 (0.1636)	Prec 94.531% (94.531%)
Epoch: [32][100/391]	Time 0.037 (0.048)	Data 0.002 (0.006)	Loss 0.2625 (0.2042)	Prec 86.719% (93.062%)
Epoch: [32][200/391]	

Validation starts
Test: [0/79]	Time 0.240 (0.240)	Loss 0.2320 (0.2320)	Prec 94.531% (94.531%)
 * Prec 89.080% 
best acc: 89.150000
Epoch: [46][0/391]	Time 0.510 (0.510)	Data 0.450 (0.450)	Loss 0.1221 (0.1221)	Prec 96.875% (96.875%)
Epoch: [46][100/391]	Time 0.044 (0.047)	Data 0.002 (0.007)	Loss 0.1875 (0.1979)	Prec 92.188% (93.085%)
Epoch: [46][200/391]	Time 0.042 (0.046)	Data 0.002 (0.004)	Loss 0.2127 (0.1962)	Prec 92.969% (93.085%)
Epoch: [46][300/391]	Time 0.037 (0.046)	Data 0.001 (0.004)	Loss 0.1179 (0.1976)	Prec 96.875% (93.023%)
Validation starts
Test: [0/79]	Time 0.400 (0.400)	Loss 0.2683 (0.2683)	Prec 90.625% (90.625%)
 * Prec 88.960% 
best acc: 89.150000
Epoch: [47][0/391]	Time 0.473 (0.473)	Data 0.412 (0.412)	Loss 0.2052 (0.2052)	Prec 91.406% (91.406%)
Epoch: [47][100/391]	Time 0.050 (0.048)	Data 0.002 (0.006)	Loss 0.2364 (0.1989)	Prec 92.969% (92.992%)
Epoch: [47][200/391]	Time 0.055 (0.046)	Data 0.003 (0.004)	Loss 0.1592 (0.1966)	Prec 94.531% (93.136%)
Epoch: [47][300/391]	

 * Prec 89.020% 
best acc: 89.330000
Epoch: [61][0/391]	Time 0.702 (0.702)	Data 0.639 (0.639)	Loss 0.1719 (0.1719)	Prec 92.969% (92.969%)
Epoch: [61][100/391]	Time 0.042 (0.051)	Data 0.002 (0.009)	Loss 0.1411 (0.2000)	Prec 95.312% (93.046%)
Epoch: [61][200/391]	Time 0.040 (0.047)	Data 0.002 (0.006)	Loss 0.1935 (0.1979)	Prec 89.844% (93.140%)
Epoch: [61][300/391]	Time 0.039 (0.046)	Data 0.002 (0.004)	Loss 0.1662 (0.1955)	Prec 92.969% (93.132%)
Validation starts
Test: [0/79]	Time 0.266 (0.266)	Loss 0.2362 (0.2362)	Prec 92.188% (92.188%)
 * Prec 89.160% 
best acc: 89.330000
Epoch: [62][0/391]	Time 0.578 (0.578)	Data 0.518 (0.518)	Loss 0.1408 (0.1408)	Prec 96.094% (96.094%)
Epoch: [62][100/391]	Time 0.049 (0.053)	Data 0.002 (0.008)	Loss 0.0957 (0.1974)	Prec 96.875% (93.100%)
Epoch: [62][200/391]	Time 0.054 (0.051)	Data 0.003 (0.005)	Loss 0.1277 (0.1961)	Prec 95.312% (93.241%)
Epoch: [62][300/391]	Time 0.043 (0.047)	Data 0.002 (0.004)	Loss 0.1642 (0.1949)	Prec 92.969% (93.293%)
Validation s

Epoch: [76][100/391]	Time 0.044 (0.050)	Data 0.003 (0.007)	Loss 0.2217 (0.1961)	Prec 90.625% (92.953%)
Epoch: [76][200/391]	Time 0.036 (0.047)	Data 0.002 (0.005)	Loss 0.2258 (0.1941)	Prec 89.062% (93.163%)
Epoch: [76][300/391]	Time 0.050 (0.046)	Data 0.002 (0.004)	Loss 0.1393 (0.1951)	Prec 96.094% (93.122%)
Validation starts
Test: [0/79]	Time 0.456 (0.456)	Loss 0.2431 (0.2431)	Prec 92.969% (92.969%)
 * Prec 89.010% 
best acc: 89.330000
Epoch: [77][0/391]	Time 0.649 (0.649)	Data 0.590 (0.590)	Loss 0.2186 (0.2186)	Prec 92.969% (92.969%)
Epoch: [77][100/391]	Time 0.031 (0.049)	Data 0.002 (0.008)	Loss 0.1698 (0.2020)	Prec 94.531% (93.000%)
Epoch: [77][200/391]	Time 0.045 (0.046)	Data 0.002 (0.005)	Loss 0.1978 (0.2000)	Prec 94.531% (93.085%)
Epoch: [77][300/391]	Time 0.036 (0.046)	Data 0.001 (0.004)	Loss 0.1943 (0.1957)	Prec 94.531% (93.278%)
Validation starts
Test: [0/79]	Time 0.500 (0.500)	Loss 0.2395 (0.2395)	Prec 91.406% (91.406%)
 * Prec 88.900% 
best acc: 89.330000
Epoch: [78][0/391]	

Epoch: [91][200/391]	Time 0.036 (0.048)	Data 0.002 (0.005)	Loss 0.3505 (0.1895)	Prec 88.281% (93.455%)
Epoch: [91][300/391]	Time 0.049 (0.047)	Data 0.002 (0.004)	Loss 0.1606 (0.1912)	Prec 95.312% (93.452%)
Validation starts
Test: [0/79]	Time 0.466 (0.466)	Loss 0.2534 (0.2534)	Prec 90.625% (90.625%)
 * Prec 89.030% 
best acc: 89.330000
Epoch: [92][0/391]	Time 0.575 (0.575)	Data 0.516 (0.516)	Loss 0.1611 (0.1611)	Prec 92.969% (92.969%)
Epoch: [92][100/391]	Time 0.049 (0.049)	Data 0.002 (0.008)	Loss 0.2464 (0.1930)	Prec 92.188% (93.379%)
Epoch: [92][200/391]	Time 0.028 (0.047)	Data 0.002 (0.005)	Loss 0.1950 (0.1903)	Prec 92.969% (93.256%)
Epoch: [92][300/391]	Time 0.057 (0.046)	Data 0.002 (0.004)	Loss 0.1460 (0.1898)	Prec 93.750% (93.330%)
Validation starts
Test: [0/79]	Time 0.405 (0.405)	Loss 0.2522 (0.2522)	Prec 92.969% (92.969%)
 * Prec 88.980% 
best acc: 89.330000
Epoch: [93][0/391]	Time 0.544 (0.544)	Data 0.500 (0.500)	Loss 0.1896 (0.1896)	Prec 92.188% (92.188%)
Epoch: [93][100/391]	

In [4]:
PATH = "result/" + model_name + "/model_best.pth.tar"
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['state_dict'])
device = torch.device("cuda") 

model.cuda()
model.eval()

test_loss = 0
correct = 0

with torch.no_grad():
    for data, target in testloader:
        data, target = data.to(device), target.to(device) # loading to GPU
        output = model(data)
        pred = output.argmax(dim=1, keepdim=True)  
        correct += pred.eq(target.view_as(pred)).sum().item()

test_loss /= len(testloader.dataset)

print('\nTest set: Accuracy: {}/{} ({:.0f}%)\n'.format(
        correct, len(testloader.dataset),
        100. * correct / len(testloader.dataset)))


Test set: Accuracy: 8933/10000 (89%)



In [5]:
#send an input and grap the value by using prehook like HW3
class SaveOutput:
    def __init__(self):
        self.outputs = []
    def __call__(self, module, module_in):
        self.outputs.append(module_in)
    def clear(self):
        self.outputs = []  

In [6]:
######### Save inputs from selected layer ##########
save_output = SaveOutput()
model.layer1[0].conv2.register_forward_pre_hook(save_output) ## Grab input for BasicBlock0 Conv1 layer
# model.layer1[1].conv1.register_forward_pre_hook(save_output) ## Grab input for BasicBlock1 Conv1 layer
####################################################

dataiter = iter(trainloader)
images, labels = next(dataiter)
images = images.to('cuda')
# Run a forward pass to prehook
out = model(images)

In [7]:
w_bit = 4
weight_q = model.layer1[0].conv2.weight_q # quantized value is stored during the training
# print(weight_q)
w_alpha = model.layer1[0].conv2.weight_quant.wgt_alpha   # alpha is defined in your model already. bring it out here
w_delta = w_alpha/(2**(w_bit-1)-1)   # delta can be calculated by using alpha and w_bit
weight_int = weight_q / w_delta # w_int can be calculated by weight_q and w_delta
print(weight_int) # you should see clean integer numbers

tensor([[[[-2.0000, -4.0000,  7.0000],
          [ 2.0000, -7.0000, -1.0000],
          [-5.0000, -2.0000,  5.0000]],

         [[ 2.0000,  1.0000,  1.0000],
          [ 2.0000, -3.0000, -4.0000],
          [ 0.0000, -2.0000, -4.0000]],

         [[ 6.0000,  3.0000, -6.0000],
          [-6.0000,  7.0000,  4.0000],
          [-4.0000,  4.0000, -4.0000]],

         ...,

         [[ 3.0000, -2.0000, -3.0000],
          [ 3.0000, -1.0000, -1.0000],
          [ 2.0000, -2.0000, -2.0000]],

         [[ 1.0000, -5.0000,  3.0000],
          [ 1.0000, -7.0000,  7.0000],
          [ 2.0000, -7.0000,  5.0000]],

         [[-4.0000, -4.0000, -7.0000],
          [-4.0000,  3.0000, -6.0000],
          [-1.0000,  1.0000, -2.0000]]],


        [[[-1.0000, -0.0000,  2.0000],
          [ 3.0000, -1.0000,  4.0000],
          [ 2.0000,  2.0000,  3.0000]],

         [[-3.0000, -0.0000,  1.0000],
          [ 1.0000,  3.0000,  4.0000],
          [ 2.0000,  3.0000,  6.0000]],

         [[-2.0000, -4.0000, -3

In [8]:
x_bit = 4    
x = save_output.outputs[0][0]  # input of the 2nd conv layer
x_alpha = model.layer1[0].conv2.act_alpha
x_delta = x_alpha/(2**x_bit - 1)

act_quant_fn = act_quantization(x_bit) # define the quantization function
x_q = act_quant_fn(x, x_alpha)         # create the quantized value for x

x_int = x_q / x_delta
print(x_int) # you should see clean integer numbers 

tensor([[[[ 0.0000,  0.0000,  0.0000,  ...,  4.0000,  4.0000,  0.0000],
          [ 3.0000,  2.0000,  2.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 1.0000,  0.0000,  0.0000,  ...,  3.0000,  1.0000,  0.0000],
          ...,
          [ 0.0000,  1.0000,  2.0000,  ...,  3.0000,  1.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 2.0000,  3.0000,  3.0000,  ...,  3.0000,  1.0000,  0.0000]],

         [[ 2.0000,  1.0000,  1.0000,  ...,  0.0000,  1.0000,  0.0000],
          [ 4.0000,  2.0000,  2.0000,  ...,  0.0000,  1.0000,  0.0000],
          [ 4.0000,  2.0000,  1.0000,  ...,  0.0000,  1.0000,  0.0000],
          ...,
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  1.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  1.0000,  0.0000]],

         [[ 3.0000,  2.0000,  4.0000,  ...,  3.0000,  2.0000,  3.0000],
          [ 1.0000,  1.0000,  

In [9]:
conv_int = torch.nn.Conv2d(in_channels = 64, out_channels=64, kernel_size = 3, bias = False)
conv_int.weight = torch.nn.parameter.Parameter(weight_int)

output_int =  conv_int(x_int)    # output_int can be calculated with conv_int and x_int
output_recovered = output_int*x_delta*w_delta  # recover with x_delta and w_delta
print(output_recovered)

tensor([[[[ 1.1223e+01,  2.3106e+00,  9.0775e-01,  ...,  5.1329e+01,
           -1.0976e+01,  2.3932e+00],
          [ 1.1388e+01,  5.2815e+00,  1.1553e+00,  ...,  4.4975e+01,
           -1.6340e+01, -6.3543e+00],
          [ 2.9708e+00,  3.9350e-08,  6.3543e+00,  ...,  4.8689e+01,
           -1.8485e+01, -4.9514e+00],
          ...,
          [ 5.9417e+00, -4.5388e+00, -5.1164e+00,  ...,  8.0047e+00,
           -5.3640e+00, -3.1359e+00],
          [ 6.1892e+00,  7.2620e+00,  8.9950e+00,  ...,  6.1067e+00,
            3.8786e+00,  9.0775e-01],
          [ 2.1456e+00,  2.6407e+00,  5.2815e+00,  ...,  3.7961e+00,
           -1.8980e+00, -5.1990e+00]],

         [[ 1.6587e+01,  1.1058e+01,  7.5096e+00,  ..., -9.3251e+00,
           -6.1892e+00,  5.4465e+00],
          [ 1.7908e+01,  1.3534e+01,  1.2791e+01,  ..., -5.3640e+00,
           -2.3932e+00,  8.0873e+00],
          [ 7.9222e+00,  9.1601e+00,  1.2874e+01,  ..., -4.4562e+00,
           -1.7330e+00,  9.0775e+00],
          ...,
     

In [10]:
#### input floating number / weight quantized version

conv_ref = torch.nn.Conv2d(in_channels = 64, out_channels=64, kernel_size = 3, bias = False)
conv_ref.weight = model.layer1[0].conv2.weight_q
print(conv_ref.weight)
output_ref = conv_ref(x)
print(output_ref)

Parameter containing:
tensor([[[[-0.6599, -1.3198,  2.3097],
          [ 0.6599, -2.3097, -0.3300],
          [-1.6498, -0.6599,  1.6498]],

         [[ 0.6599,  0.3300,  0.3300],
          [ 0.6599, -0.9899, -1.3198],
          [ 0.0000, -0.6599, -1.3198]],

         [[ 1.9797,  0.9899, -1.9797],
          [-1.9797,  2.3097,  1.3198],
          [-1.3198,  1.3198, -1.3198]],

         ...,

         [[ 0.9899, -0.6599, -0.9899],
          [ 0.9899, -0.3300, -0.3300],
          [ 0.6599, -0.6599, -0.6599]],

         [[ 0.3300, -1.6498,  0.9899],
          [ 0.3300, -2.3097,  2.3097],
          [ 0.6599, -2.3097,  1.6498]],

         [[-1.3198, -1.3198, -2.3097],
          [-1.3198,  0.9899, -1.9797],
          [-0.3300,  0.3300, -0.6599]]],


        [[[-0.3300, -0.0000,  0.6599],
          [ 0.9899, -0.3300,  1.3198],
          [ 0.6599,  0.6599,  0.9899]],

         [[-0.9899, -0.0000,  0.3300],
          [ 0.3300,  0.9899,  1.3198],
          [ 0.6599,  0.9899,  1.9797]],

         

In [11]:
difference = abs( output_ref - output_recovered )
# diff_percent = 100*abs(output_ref - output_recovered)/output_ref
print(difference.mean())  ## It should be small, e.g.,2.3 in my trainned model

tensor(0.6073, device='cuda:0', grad_fn=<MeanBackward0>)


In [12]:
#### input floating number / weight floating number version

conv_ref = torch.nn.Conv2d(in_channels = 64, out_channels=64, kernel_size = 3, bias = False)
weight = model.layer1[0].conv2.weight_q
mean = weight.data.mean()
std = weight.data.std()
conv_ref.weight = torch.nn.parameter.Parameter(weight.add(-mean).div(std))

output_ref = conv_ref(x)
print(output_ref)


tensor([[[[ 1.0943e+01,  2.2941e+00,  6.8174e-01,  ...,  5.2347e+01,
           -1.0455e+01,  3.5751e+00],
          [ 1.1056e+01,  5.2068e+00,  1.7795e+00,  ...,  4.6503e+01,
           -1.7367e+01, -4.7495e+00],
          [ 3.2162e+00,  8.9087e-01,  6.5367e+00,  ...,  5.0049e+01,
           -1.7903e+01, -3.8221e+00],
          ...,
          [ 5.9307e+00, -4.3173e+00, -4.2835e+00,  ...,  7.0285e+00,
           -6.0234e+00, -1.7933e+00],
          [ 5.2925e+00,  9.1368e+00,  9.3217e+00,  ...,  6.3085e+00,
            3.4013e+00,  2.8006e+00],
          [ 1.4867e+00,  2.6702e+00,  4.5704e+00,  ...,  2.9954e+00,
           -1.9712e+00, -4.2182e+00]],

         [[ 1.6718e+01,  1.1372e+01,  8.4392e+00,  ..., -8.7674e+00,
           -5.5232e+00,  5.1990e+00],
          [ 1.8254e+01,  1.3352e+01,  1.2206e+01,  ..., -5.5794e+00,
           -1.8672e+00,  8.4840e+00],
          [ 9.3280e+00,  9.2767e+00,  1.2451e+01,  ..., -4.3460e+00,
           -1.2451e+00,  9.6297e+00],
          ...,
     

In [13]:
difference = abs( output_ref - output_recovered )
print(difference.mean())  ## It should be small, e.g.,2.3 in my trainned model

tensor(0.6450, device='cuda:0', grad_fn=<MeanBackward0>)


In [14]:
# Start process for 2bit

model_name = "Resnet20_quantAware_2bit"
model = resnet20_quant_2bit() # For quantization aware training

# Train the model
lr = 0.1
weight_decay = 1e-4
epochs = 45
best_prec = 0

#model = nn.DataParallel(model).cuda()
model.cuda()
criterion = nn.CrossEntropyLoss().cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)
#cudnn.benchmark = True

if not os.path.exists('result'):
    os.makedirs('result')
fdir = 'result/'+str(model_name)
if not os.path.exists(fdir):
    os.makedirs(fdir)
        

for epoch in range(0, epochs):
    adjust_learning_rate(optimizer, epoch)

    train(trainloader, model, criterion, optimizer, epoch)
    
    # evaluate on test set
    print("Validation starts")
    prec = validate(testloader, model, criterion)

    # remember best precision and save checkpoint
    is_best = prec > best_prec
    best_prec = max(prec,best_prec)
    print('best acc: {:1f}'.format(best_prec))
    save_checkpoint({
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'best_prec': best_prec,
        'optimizer': optimizer.state_dict(),
    }, is_best, fdir)

Epoch: [0][0/391]	Time 0.474 (0.474)	Data 0.405 (0.405)	Loss 2.4246 (2.4246)	Prec 8.594% (8.594%)
Epoch: [0][100/391]	Time 0.043 (0.046)	Data 0.002 (0.006)	Loss 1.8403 (2.1305)	Prec 28.906% (20.088%)
Epoch: [0][200/391]	Time 0.057 (0.048)	Data 0.003 (0.005)	Loss 1.8087 (1.9918)	Prec 29.688% (24.460%)
Epoch: [0][300/391]	Time 0.037 (0.046)	Data 0.002 (0.004)	Loss 1.5880 (1.9121)	Prec 42.969% (27.375%)
Validation starts
Test: [0/79]	Time 0.312 (0.312)	Loss 1.7356 (1.7356)	Prec 32.812% (32.812%)
 * Prec 33.470% 
best acc: 33.470000
Epoch: [1][0/391]	Time 0.467 (0.467)	Data 0.402 (0.402)	Loss 1.6372 (1.6372)	Prec 39.844% (39.844%)
Epoch: [1][100/391]	Time 0.046 (0.048)	Data 0.001 (0.006)	Loss 1.5927 (1.6258)	Prec 47.656% (39.070%)
Epoch: [1][200/391]	Time 0.042 (0.045)	Data 0.002 (0.004)	Loss 1.5665 (1.5913)	Prec 45.312% (40.730%)
Epoch: [1][300/391]	Time 0.045 (0.045)	Data 0.002 (0.004)	Loss 1.4796 (1.5636)	Prec 47.656% (41.515%)
Validation starts
Test: [0/79]	Time 0.340 (0.340)	Loss 1.37

Epoch: [15][200/391]	Time 0.056 (0.047)	Data 0.003 (0.005)	Loss 0.7374 (0.6785)	Prec 71.875% (76.461%)
Epoch: [15][300/391]	Time 0.044 (0.046)	Data 0.002 (0.004)	Loss 0.6613 (0.6745)	Prec 76.562% (76.570%)
Validation starts
Test: [0/79]	Time 0.444 (0.444)	Loss 0.8845 (0.8845)	Prec 70.312% (70.312%)
 * Prec 68.840% 
best acc: 73.830000
Epoch: [16][0/391]	Time 0.558 (0.558)	Data 0.515 (0.515)	Loss 0.6108 (0.6108)	Prec 75.781% (75.781%)
Epoch: [16][100/391]	Time 0.048 (0.049)	Data 0.002 (0.007)	Loss 0.6719 (0.6598)	Prec 79.688% (76.841%)
Epoch: [16][200/391]	Time 0.047 (0.047)	Data 0.003 (0.005)	Loss 0.4610 (0.6539)	Prec 85.938% (77.095%)
Epoch: [16][300/391]	Time 0.049 (0.046)	Data 0.003 (0.004)	Loss 0.5131 (0.6544)	Prec 85.156% (77.027%)
Validation starts
Test: [0/79]	Time 0.328 (0.328)	Loss 0.8275 (0.8275)	Prec 77.344% (77.344%)
 * Prec 71.680% 
best acc: 73.830000
Epoch: [17][0/391]	Time 0.476 (0.476)	Data 0.408 (0.408)	Loss 0.7082 (0.7082)	Prec 78.906% (78.906%)
Epoch: [17][100/391]	

Epoch: [30][300/391]	Time 0.045 (0.046)	Data 0.006 (0.004)	Loss 0.4193 (0.4430)	Prec 86.719% (84.461%)
Validation starts
Test: [0/79]	Time 0.363 (0.363)	Loss 0.4622 (0.4622)	Prec 84.375% (84.375%)
 * Prec 83.710% 
best acc: 83.710000
Epoch: [31][0/391]	Time 0.767 (0.767)	Data 0.720 (0.720)	Loss 0.3387 (0.3387)	Prec 86.719% (86.719%)
Epoch: [31][100/391]	Time 0.043 (0.049)	Data 0.002 (0.009)	Loss 0.3500 (0.4356)	Prec 88.281% (84.754%)
Epoch: [31][200/391]	Time 0.054 (0.047)	Data 0.003 (0.006)	Loss 0.3176 (0.4352)	Prec 91.406% (84.756%)
Epoch: [31][300/391]	Time 0.042 (0.046)	Data 0.002 (0.005)	Loss 0.4722 (0.4329)	Prec 85.156% (84.855%)
Validation starts
Test: [0/79]	Time 0.498 (0.498)	Loss 0.4404 (0.4404)	Prec 82.812% (82.812%)
 * Prec 83.070% 
best acc: 83.710000
Epoch: [32][0/391]	Time 0.638 (0.638)	Data 0.570 (0.570)	Loss 0.4245 (0.4245)	Prec 84.375% (84.375%)
Epoch: [32][100/391]	Time 0.042 (0.051)	Data 0.003 (0.008)	Loss 0.4256 (0.4411)	Prec 84.375% (84.406%)
Epoch: [32][200/391]	

Validation starts
Test: [0/79]	Time 0.537 (0.537)	Loss 0.4635 (0.4635)	Prec 83.594% (83.594%)
 * Prec 82.950% 
best acc: 83.790000
Epoch: [46][0/391]	Time 0.516 (0.516)	Data 0.468 (0.468)	Loss 0.4994 (0.4994)	Prec 79.688% (79.688%)
Epoch: [46][100/391]	Time 0.045 (0.048)	Data 0.002 (0.007)	Loss 0.5952 (0.4324)	Prec 78.906% (84.824%)
Epoch: [46][200/391]	Time 0.052 (0.046)	Data 0.002 (0.005)	Loss 0.4604 (0.4262)	Prec 87.500% (85.005%)
Epoch: [46][300/391]	Time 0.054 (0.047)	Data 0.003 (0.004)	Loss 0.5417 (0.4304)	Prec 83.594% (85.019%)
Validation starts
Test: [0/79]	Time 0.512 (0.512)	Loss 0.4471 (0.4471)	Prec 84.375% (84.375%)
 * Prec 83.030% 
best acc: 83.790000
Epoch: [47][0/391]	Time 0.619 (0.619)	Data 0.554 (0.554)	Loss 0.4200 (0.4200)	Prec 85.156% (85.156%)
Epoch: [47][100/391]	Time 0.047 (0.051)	Data 0.003 (0.008)	Loss 0.4925 (0.4363)	Prec 84.375% (84.916%)
Epoch: [47][200/391]	Time 0.044 (0.047)	Data 0.004 (0.005)	Loss 0.4718 (0.4359)	Prec 83.594% (84.935%)
Epoch: [47][300/391]	

 * Prec 83.450% 
best acc: 83.790000
Epoch: [61][0/391]	Time 0.631 (0.631)	Data 0.574 (0.574)	Loss 0.3940 (0.3940)	Prec 85.156% (85.156%)
Epoch: [61][100/391]	Time 0.050 (0.051)	Data 0.003 (0.008)	Loss 0.3989 (0.4239)	Prec 86.719% (85.435%)
Epoch: [61][200/391]	Time 0.058 (0.048)	Data 0.004 (0.005)	Loss 0.4349 (0.4300)	Prec 83.594% (85.145%)
Epoch: [61][300/391]	Time 0.040 (0.046)	Data 0.002 (0.004)	Loss 0.4246 (0.4340)	Prec 84.375% (84.956%)
Validation starts
Test: [0/79]	Time 0.401 (0.401)	Loss 0.4286 (0.4286)	Prec 82.031% (82.031%)
 * Prec 82.940% 
best acc: 83.790000
Epoch: [62][0/391]	Time 0.774 (0.774)	Data 0.713 (0.713)	Loss 0.3915 (0.3915)	Prec 83.594% (83.594%)
Epoch: [62][100/391]	Time 0.050 (0.052)	Data 0.002 (0.010)	Loss 0.3889 (0.4270)	Prec 82.031% (84.553%)
Epoch: [62][200/391]	Time 0.035 (0.049)	Data 0.003 (0.006)	Loss 0.4319 (0.4269)	Prec 81.250% (84.838%)
Epoch: [62][300/391]	Time 0.049 (0.048)	Data 0.003 (0.005)	Loss 0.3738 (0.4296)	Prec 91.406% (84.780%)
Validation s

Epoch: [76][100/391]	Time 0.056 (0.053)	Data 0.003 (0.008)	Loss 0.4100 (0.4291)	Prec 87.500% (84.916%)
Epoch: [76][200/391]	Time 0.035 (0.048)	Data 0.002 (0.005)	Loss 0.3177 (0.4302)	Prec 91.406% (85.012%)
Epoch: [76][300/391]	Time 0.038 (0.047)	Data 0.002 (0.004)	Loss 0.5653 (0.4317)	Prec 80.469% (84.884%)
Validation starts
Test: [0/79]	Time 0.572 (0.572)	Loss 0.5581 (0.5581)	Prec 81.250% (81.250%)
 * Prec 80.700% 
best acc: 83.820000
Epoch: [77][0/391]	Time 0.544 (0.544)	Data 0.489 (0.489)	Loss 0.4368 (0.4368)	Prec 80.469% (80.469%)
Epoch: [77][100/391]	Time 0.038 (0.047)	Data 0.002 (0.007)	Loss 0.5381 (0.4348)	Prec 77.344% (84.530%)
Epoch: [77][200/391]	Time 0.037 (0.045)	Data 0.002 (0.005)	Loss 0.5109 (0.4370)	Prec 78.906% (84.620%)
Epoch: [77][300/391]	Time 0.055 (0.045)	Data 0.003 (0.004)	Loss 0.4222 (0.4349)	Prec 85.938% (84.847%)
Validation starts
Test: [0/79]	Time 0.439 (0.439)	Loss 0.5317 (0.5317)	Prec 80.469% (80.469%)
 * Prec 82.990% 
best acc: 83.820000
Epoch: [78][0/391]	

Epoch: [91][200/391]	Time 0.049 (0.047)	Data 0.002 (0.005)	Loss 0.5307 (0.4249)	Prec 80.469% (84.950%)
Epoch: [91][300/391]	Time 0.051 (0.046)	Data 0.003 (0.004)	Loss 0.4635 (0.4303)	Prec 83.594% (84.840%)
Validation starts
Test: [0/79]	Time 0.407 (0.407)	Loss 0.4987 (0.4987)	Prec 81.250% (81.250%)
 * Prec 81.550% 
best acc: 83.860000
Epoch: [92][0/391]	Time 0.635 (0.635)	Data 0.576 (0.576)	Loss 0.4130 (0.4130)	Prec 85.156% (85.156%)
Epoch: [92][100/391]	Time 0.047 (0.052)	Data 0.002 (0.008)	Loss 0.4307 (0.4228)	Prec 82.812% (85.149%)
Epoch: [92][200/391]	Time 0.045 (0.048)	Data 0.002 (0.005)	Loss 0.4296 (0.4313)	Prec 84.375% (84.663%)
Epoch: [92][300/391]	Time 0.037 (0.047)	Data 0.002 (0.004)	Loss 0.4888 (0.4293)	Prec 79.688% (84.798%)
Validation starts
Test: [0/79]	Time 0.285 (0.285)	Loss 0.5228 (0.5228)	Prec 84.375% (84.375%)
 * Prec 82.740% 
best acc: 83.860000
Epoch: [93][0/391]	Time 0.519 (0.519)	Data 0.461 (0.461)	Loss 0.4262 (0.4262)	Prec 86.719% (86.719%)
Epoch: [93][100/391]	

In [15]:
PATH = "result/" + model_name + "/model_best.pth.tar"
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['state_dict'])
device = torch.device("cuda") 

model.cuda()
model.eval()

test_loss = 0
correct = 0

with torch.no_grad():
    for data, target in testloader:
        data, target = data.to(device), target.to(device) # loading to GPU
        output = model(data)
        pred = output.argmax(dim=1, keepdim=True)  
        correct += pred.eq(target.view_as(pred)).sum().item()

test_loss /= len(testloader.dataset)

print('\nTest set: Accuracy: {}/{} ({:.0f}%)\n'.format(
        correct, len(testloader.dataset),
        100. * correct / len(testloader.dataset)))


Test set: Accuracy: 8406/10000 (84%)



In [16]:
######### Save inputs from selected layer ##########
save_output = SaveOutput()
model.layer1[0].conv2.register_forward_pre_hook(save_output) ## Grab input for BasicBlock0 Conv1 layer
# model.layer1[1].conv1.register_forward_pre_hook(save_output) ## Grab input for BasicBlock1 Conv1 layer
####################################################

dataiter = iter(trainloader)
images, labels = next(dataiter)
images = images.to('cuda')
# Run a forward pass to prehook
out = model(images)

In [17]:
w_bit = 4
weight_q = model.layer1[0].conv2.weight_q # quantized value is stored during the training
# print(weight_q)
w_alpha = model.layer1[0].conv2.weight_quant.wgt_alpha   # alpha is defined in your model already. bring it out here
w_delta = w_alpha/(2**(w_bit-1)-1)   # delta can be calculated by using alpha and w_bit
weight_int = weight_q / w_delta # w_int can be calculated by weight_q and w_delta
print(weight_int) # you should see clean integer numbers

tensor([[[[ 0.0000,  0.0000,  0.0000],
          [ 0.0000, -0.0000, -0.0000],
          [ 0.0000, -0.0000,  0.0000]],

         [[ 0.0000,  0.0000, -7.0000],
          [ 0.0000,  0.0000,  0.0000],
          [ 7.0000,  0.0000,  0.0000]],

         [[-0.0000, -0.0000, -0.0000],
          [-7.0000, -0.0000, -0.0000],
          [-0.0000,  0.0000,  7.0000]],

         ...,

         [[ 0.0000,  0.0000, -0.0000],
          [-0.0000, -0.0000,  0.0000],
          [ 0.0000, -0.0000, -0.0000]],

         [[ 0.0000,  7.0000,  0.0000],
          [ 0.0000,  0.0000,  7.0000],
          [ 0.0000,  0.0000,  7.0000]],

         [[-7.0000,  7.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000],
          [ 0.0000, -0.0000, -0.0000]]],


        [[[-0.0000,  0.0000,  0.0000],
          [-0.0000, -0.0000,  0.0000],
          [-0.0000, -0.0000,  7.0000]],

         [[-7.0000,  7.0000,  0.0000],
          [-7.0000, -0.0000,  7.0000],
          [-0.0000,  0.0000, -0.0000]],

         [[ 0.0000, -0.0000, -7

In [18]:
x_bit = 4    
x = save_output.outputs[0][0]  # input of the 2nd conv layer
x_alpha = model.layer1[0].conv2.act_alpha
x_delta = x_alpha/(2**x_bit - 1)

act_quant_fn = act_quantization(x_bit) # define the quantization function
x_q = act_quant_fn(x, x_alpha)         # create the quantized value for x

x_int = x_q / x_delta
print(x_int) # you should see clean integer numbers 

tensor([[[[ 0.0000,  0.0000,  5.0000,  ...,  3.0000,  2.0000,  3.0000],
          [ 0.0000,  0.0000,  7.0000,  ...,  3.0000,  3.0000,  5.0000],
          [ 0.0000,  1.0000,  6.0000,  ...,  0.0000,  1.0000,  3.0000],
          ...,
          [ 0.0000,  0.0000,  3.0000,  ...,  2.0000,  3.0000,  0.0000],
          [ 0.0000,  0.0000,  3.0000,  ...,  3.0000,  3.0000,  2.0000],
          [ 0.0000,  0.0000,  2.0000,  ...,  2.0000,  2.0000,  1.0000]],

         [[ 0.0000, 13.0000,  9.0000,  ...,  3.0000,  6.0000,  5.0000],
          [ 0.0000, 14.0000,  8.0000,  ...,  5.0000,  5.0000,  7.0000],
          [ 0.0000, 12.0000,  8.0000,  ...,  6.0000,  5.0000,  7.0000],
          ...,
          [ 0.0000,  7.0000,  5.0000,  ...,  4.0000,  2.0000,  0.0000],
          [ 0.0000,  8.0000,  3.0000,  ...,  3.0000,  1.0000,  0.0000],
          [ 3.0000,  5.0000,  3.0000,  ...,  3.0000,  0.0000,  0.0000]],

         [[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 1.0000,  0.0000,  

In [19]:
conv_int = torch.nn.Conv2d(in_channels = 64, out_channels=64, kernel_size = 3, bias = False)
conv_int.weight = torch.nn.parameter.Parameter(weight_int)

output_int =  conv_int(x_int)    # output_int can be calculated with conv_int and x_int
output_recovered = output_int*x_delta*w_delta  # recover with x_delta and w_delta
print(output_recovered)

tensor([[[[-8.2721e+00,  2.5506e+01,  1.8612e+01,  ...,  1.5855e+01,
            7.5827e+00,  1.1029e+01],
          [-4.1360e+00,  2.0680e+01,  8.9614e+00,  ..., -4.1360e+00,
            1.3787e+00,  3.4467e+00],
          [-8.9614e+00,  1.9302e+01,  1.5165e+01,  ..., -1.5165e+01,
            1.1719e+01,  6.2041e+00],
          ...,
          [ 4.8254e+00, -6.8934e-01,  1.3787e+00,  ...,  4.1360e+00,
            1.5165e+01,  5.5147e+00],
          [-2.0680e+00,  9.6508e+00,  5.5147e+00,  ...,  6.8934e-01,
            2.0680e+00,  1.1719e+01],
          [ 1.3787e+00,  1.1719e+01,  6.2041e+00,  ...,  6.8934e+00,
            5.5147e+00,  1.7923e+01]],

         [[ 1.5165e+01, -4.9633e+01, -1.0340e+01,  ...,  2.4816e+01,
           -2.7574e+00,  1.0340e+01],
          [ 1.6544e+01, -4.0671e+01, -1.0340e+01,  ...,  3.8603e+01,
           -1.0340e+01,  1.5855e+01],
          [ 1.7923e+01, -4.2739e+01, -8.2721e+00,  ...,  2.4816e+01,
            2.2059e+01,  7.5827e+00],
          ...,
     

In [20]:
#### input floating number / weight quantized version

conv_ref = torch.nn.Conv2d(in_channels = 64, out_channels=64, kernel_size = 3, bias = False)
conv_ref.weight = model.layer1[0].conv2.weight_q
print(conv_ref.weight)
output_ref = conv_ref(x)
print(output_ref)

Parameter containing:
tensor([[[[ 0.0000,  0.0000,  0.0000],
          [ 0.0000, -0.0000, -0.0000],
          [ 0.0000, -0.0000,  0.0000]],

         [[ 0.0000,  0.0000, -2.6603],
          [ 0.0000,  0.0000,  0.0000],
          [ 2.6603,  0.0000,  0.0000]],

         [[-0.0000, -0.0000, -0.0000],
          [-2.6603, -0.0000, -0.0000],
          [-0.0000,  0.0000,  2.6603]],

         ...,

         [[ 0.0000,  0.0000, -0.0000],
          [-0.0000, -0.0000,  0.0000],
          [ 0.0000, -0.0000, -0.0000]],

         [[ 0.0000,  2.6603,  0.0000],
          [ 0.0000,  0.0000,  2.6603],
          [ 0.0000,  0.0000,  2.6603]],

         [[-2.6603,  2.6603,  0.0000],
          [ 0.0000,  0.0000,  0.0000],
          [ 0.0000, -0.0000, -0.0000]]],


        [[[-0.0000,  0.0000,  0.0000],
          [-0.0000, -0.0000,  0.0000],
          [-0.0000, -0.0000,  2.6603]],

         [[-2.6603,  2.6603,  0.0000],
          [-2.6603, -0.0000,  2.6603],
          [-0.0000,  0.0000, -0.0000]],

         

In [21]:
difference = abs( output_ref - output_recovered )
print(difference.mean())  ## It should be small, e.g.,2.3 in my trainned model

tensor(1.2863, device='cuda:0', grad_fn=<MeanBackward0>)


In [22]:
#### input floating number / weight floating number version

conv_ref = torch.nn.Conv2d(in_channels = 64, out_channels=64, kernel_size = 3, bias = False)
weight = model.layer1[0].conv2.weight_q
mean = weight.data.mean()
std = weight.data.std()
conv_ref.weight = torch.nn.parameter.Parameter(weight.add(-mean).div(std))

output_ref = conv_ref(x)
print(output_ref)

tensor([[[[-2.6112e+00,  2.6014e+01,  2.0262e+01,  ...,  1.8545e+01,
            9.8519e+00,  1.3181e+01],
          [ 1.9345e+00,  2.1952e+01,  1.1142e+01,  ..., -6.1548e-01,
            3.1952e+00,  6.4562e+00],
          [-2.2232e+00,  2.1330e+01,  1.6116e+01,  ..., -9.6994e+00,
            1.3765e+01,  8.8784e+00],
          ...,
          [ 7.8637e+00,  1.7834e+00,  4.4641e+00,  ...,  6.5682e+00,
            1.6743e+01,  9.0967e+00],
          [ 1.4314e+00,  1.1283e+01,  7.8511e+00,  ...,  4.0250e+00,
            4.6428e+00,  1.3647e+01],
          [ 4.9720e+00,  1.3811e+01,  9.3012e+00,  ...,  9.9053e+00,
            8.6860e+00,  2.0076e+01]],

         [[ 1.4917e+01, -3.9197e+01, -4.5461e+00,  ...,  2.5463e+01,
            2.9596e+00,  1.4346e+01],
          [ 1.5718e+01, -3.1180e+01, -4.1376e+00,  ...,  3.6610e+01,
           -5.2657e+00,  1.7196e+01],
          [ 1.6602e+01, -3.3325e+01, -1.9959e+00,  ...,  2.3970e+01,
            2.2203e+01,  1.1640e+01],
          ...,
     

In [23]:
difference = abs( output_ref - output_recovered )
print(difference.mean())  ## It should be small, e.g.,2.3 in my trainned model

tensor(5.1600, device='cuda:0', grad_fn=<MeanBackward0>)
