# **Module imports**

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# importing necessary modules
import numpy as np
import math

import os
import sys
import time
import numpy as np
import argparse
from scipy.linalg import hadamard
from torch.autograd import Function

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
 
import torchvision
from torchvision import datasets
import torchvision.transforms as transforms
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import torch.nn.init as init

import torch.optim as optim
import torch.backends.cudnn as cudnn



# from models import *
# from utils import progress_bar

# Set up warnings
import warnings
warnings.filterwarnings(
    action='ignore',
    category=DeprecationWarning,
    module=r'.*'
)
warnings.filterwarnings(
    action='default',
    module=r'torch.ao.quantization'
)

# Specify random seed for repeatable results
torch.manual_seed(191009)

<torch._C.Generator at 0x7faaa41c7870>

In [4]:
# CUDA_VISIBLE_DEVICE = 3
if torch.cuda.is_available(): 
 dev = "cuda:0" 
else: 
 dev = "cpu" 
device = torch.device(dev)

In [5]:
torch.cuda.current_device()

0

# **Initialization**

In [6]:
def ZerO_Init_on_matrix(matrix_tensor):
    # Algorithm 1 in the paper.
    
    m = matrix_tensor.size(0)
    n = matrix_tensor.size(1)
    
    if m <= n:
        init_matrix = torch.nn.init.eye_(torch.empty(m, n))
    elif m > n:
        clog_m = math.ceil(math.log2(m))
        p = 2**(clog_m)
        init_matrix = torch.nn.init.eye_(torch.empty(m, p)) @ (torch.tensor(hadamard(p)).float()/(2**(clog_m/2))) @ torch.nn.init.eye_(torch.empty(p, n))
    
    return init_matrix

def Identity_Init_on_matrix(matrix_tensor):
    # Definition 1 in the paper
    # See https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.eye_ for details. Preserves the identity of the inputs in Linear layers, where as many inputs are preserved as possible, the same as partial identity matrix.
    
    m = matrix_tensor.size(0)
    n = matrix_tensor.size(1)
    
    init_matrix = torch.nn.init.eye_(torch.empty(m, n))
    
    return init_matrix


def init_sub_identity_conv1x1(weight):
    tensor = weight.data
    out_dim = tensor.size()[0]
    in_dim = tensor.size()[1]
    ori_dim = tensor.size()
    assert tensor.size()[2] == 1 and tensor.size()[3] == 1
    if out_dim<in_dim:
        i = torch.eye(out_dim).type_as(tensor)
        j = torch.zeros(out_dim,(in_dim-out_dim)).type_as(tensor)
        k = torch.cat((i,j),1)
    elif out_dim>in_dim:
        i = torch.eye(in_dim).type_as(tensor)
        j = torch.zeros((out_dim-in_dim),in_dim).type_as(tensor)
        k = torch.cat((i,j),0)
    else:
        k = torch.eye(out_dim).type_as(tensor)
    k.unsqueeze_(2)
    k.unsqueeze_(3)
    assert k.size() == ori_dim
    
    weight.data = k

class Hadamard_Transform(nn.Module):

    def __init__(self, dim_in, dim_out):
        super(Hadamard_Transform, self).__init__()
        if dim_in != dim_out:
            raise RuntimeError('orthogonal transform not supports dim_in != dim_out currently')
        hadamard_matrix = hadamard(dim_in)
        hadamard_matrix = torch.Tensor(hadamard_matrix)

        n = int(np.log2(dim_in))
        normalized_hadamard_matrix = hadamard_matrix / (2**(n / 2))

        self.hadamard_matrix = nn.Parameter(normalized_hadamard_matrix, requires_grad=False)


    def forward(self, x):
        # input is a B x C x N x M
        
        return torch.matmul(x.permute(0,2,3,1), self.hadamard_matrix).permute(0,3,1,2)
class SkipConnection(nn.Module):

    def __init__(self, scale=1):
        super(SkipConnection, self).__init__()
        self.scale = scale
    def _shortcut(self, input):
        #needs to be implemented

        return input

    def forward(self, x):
        # with torch.no_grad():
        identity = self._shortcut(x)
        return identity * self.scale

class ChannelPaddingSkip(SkipConnection):

    def __init__(self, num_expand_channels_left, num_expand_channels_right, scale=1):
        super(ChannelPaddingSkip, self).__init__(scale)
        self.num_expand_channels_left = num_expand_channels_left
        self.num_expand_channels_right = num_expand_channels_right
    
    def _shortcut(self, input):
        # input is (N, C, H, M)
        # and return is (N, C + num_left + num_right, H, M)
        
        return F.pad(input, (0, 0, 0, 0, self.num_expand_channels_left, self.num_expand_channels_right) , "constant", 0) 
class Zero_Relu(Function):
        
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        output = input.clamp(min=0)
        return output    
    
    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        return grad_input
    
zero_relu = Zero_Relu.apply

# **Suplementary Functions**

In [7]:
def get_mean_and_std(dataset):
    '''Compute the mean and std value of dataset.'''
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
    mean = torch.zeros(3)
    std = torch.zeros(3)
    print('==> Computing mean and std..')
    for inputs, targets in dataloader:
        for i in range(3):
            mean[i] += inputs[:,i,:,:].mean()
            std[i] += inputs[:,i,:,:].std()
    mean.div_(len(dataset))
    std.div_(len(dataset))
    return mean, std

# def init_params(net):
#     '''Init layer parameters.'''
#     for m in net.modules():
#         if isinstance(m, nn.Conv2d):
#             init.kaiming_normal(m.weight, mode='fan_out')
#             if m.bias:
#                 init.constant(m.bias, 0)
#         elif isinstance(m, nn.BatchNorm2d):
#             init.constant(m.weight, 1)
#             init.constant(m.bias, 0)
#         elif isinstance(m, nn.Linear):
#             init.normal(m.weight, std=1e-3)
#             if m.bias:
#                 init.constant(m.bias, 0)


# _, term_width = os.popen('stty size', 'r').read().split()
# term_width = int(term_width)
term_width=80

TOTAL_BAR_LENGTH = 65.
last_time = time.time()
begin_time = last_time
def progress_bar(current, total, msg=None):
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # Reset for new bar.

    cur_len = int(TOTAL_BAR_LENGTH*current/total)
    rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1

    sys.stdout.write(' [')
    for i in range(cur_len):
        sys.stdout.write('=')
    sys.stdout.write('>')
    for i in range(rest_len):
        sys.stdout.write('.')
    sys.stdout.write(']')

    cur_time = time.time()
    step_time = cur_time - last_time
    last_time = cur_time
    tot_time = cur_time - begin_time

    L = []
    L.append('  Step: %s' % format_time(step_time))
    L.append(' | Tot: %s' % format_time(tot_time))
    if msg:
        L.append(' | ' + msg)

    msg = ''.join(L)
    sys.stdout.write(msg)
    for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3):
        sys.stdout.write(' ')

    # Go back to the center of the bar.
    for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2):
        sys.stdout.write('\b')
    sys.stdout.write(' %d/%d ' % (current+1, total))

    if current < total-1:
        sys.stdout.write('\r')
    else:
        sys.stdout.write('\n')
    sys.stdout.flush()

def format_time(seconds):
    days = int(seconds / 3600/24)
    seconds = seconds - days*3600*24
    hours = int(seconds / 3600)
    seconds = seconds - hours*3600
    minutes = int(seconds / 60)
    seconds = seconds - minutes*60
    secondsf = int(seconds)
    seconds = seconds - secondsf
    millis = int(seconds*1000)

    f = ''
    i = 1
    if days > 0:
        f += str(days) + 'D'
        i += 1
    if hours > 0 and i <= 2:
        f += str(hours) + 'h'
        i += 1
    if minutes > 0 and i <= 2:
        f += str(minutes) + 'm'
        i += 1
    if secondsf > 0 and i <= 2:
        f += str(secondsf) + 's'
        i += 1
    if millis > 0 and i <= 2:
        f += str(millis) + 'ms'
        i += 1
    if f == '':
        f = '0ms'
    return f

# **Cifar10 dataloader**

In [8]:
# Data
print('==> Preparing data..')
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')


==> Preparing data..
Files already downloaded and verified
Files already downloaded and verified


# **Train/Test function**

In [9]:

# Training
def train(epoch,model): 
    print('\nEpoch: %d' % epoch)
    model.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                     % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
    print('Train Loss:',(train_loss/(batch_idx+1)), 'Acc: ',100.*correct/total,'correct',correct,'total',total)


def test(epoch,model):
    global best_acc
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                         % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))
        print('Test Loss:',(test_loss/(batch_idx+1)),'Acc:',100.*correct/total,'Correct',correct,'total',total)

    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc:
        # print('Saving..')
        # state = {
        #     'net': net.state_dict(),
        #     'acc': acc,
        #     'epoch': epoch,
        # }
        # if not os.path.isdir('checkpoint'):
        #     os.mkdir('checkpoint')
        # torch.save(state, './checkpoint/ckpt.pth')
        best_acc = acc
    print(best_acc)



# **Resnet18 model**

# PreAct Architecture

In [10]:

class PreActBlock(nn.Module):
    '''Pre-activation version of the BasicBlock.'''
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(PreActBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.relu = zero_relu

        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                Hadamard_Transform(self.expansion*planes,self.expansion*planes))
            self.shortcut[0].type_name = 'conv1x1'

    def forward(self, x):
        out = self.bn1(self.relu(x))
        shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
        out = self.conv1(out)
        out += shortcut

        identity = out
        out = self.conv2(self.bn2(self.relu(out)))
        out += identity

        return out

class PreActBlockNonBN(nn.Module):
    '''Pre-activation version of the BasicBlock.'''
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(PreActBlockNonBN, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.relu = zero_relu
        self.bias1a = nn.Parameter(torch.zeros(1))
        self.bias1b = nn.Parameter(torch.zeros(1))
        self.bias2a = nn.Parameter(torch.zeros(1))
        self.scale = nn.Parameter(torch.ones(1))
        self.bias2b = nn.Parameter(torch.zeros(1))

        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                Hadamard_Transform(self.expansion*planes,self.expansion*planes))
            self.shortcut[0].type_name = 'conv1x1'

    def forward(self, x):
        out = self.relu(x)
        shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
        out = self.conv1(out + self.bias1a)
        out += shortcut + self.bias1b

        identity = out
        out = self.conv2(self.relu(out) + self.bias2a)
        out = out * self.scale + self.bias2b
        out += identity

        return out

class PreActResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10, init='ZerO', BN_enable=True):
        super(PreActResNet, self).__init__()
        self.in_planes = 64
        self.num_layers = num_blocks[0]

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.BN_enable = BN_enable
        if self.BN_enable:
            self.bn1 = nn.BatchNorm2d(64)
            self.bn2 = nn.BatchNorm2d(512*block.expansion)
        else:
            self.bias1 = nn.Parameter(torch.zeros(1))
            self.bias2 = nn.Parameter(torch.zeros(1))
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)

        self.first_transform = nn.Sequential(ChannelPaddingSkip(0,61, scale=1), Hadamard_Transform(64,64))
        
        self.relu = zero_relu
        
        if init == 'ZerO':
            for m in self.modules():
                if isinstance(m, nn.Linear):
                    nn.init.constant_(m.weight, 0)
                    nn.init.constant_(m.bias, 0)
                if isinstance(m, PreActBlock) or isinstance(m, PreActBlockNonBN):
                    # initialize every conv layer as zero
                    nn.init.constant_(m.conv1.weight, 0)
                    nn.init.constant_(m.conv2.weight, 0)
                if isinstance(m, nn.Conv2d):
                    # initialize first conv layer as zero
                    if hasattr(m,'type_name'):
                        if 'conv1x1' in m.type_name:
                            # nn.init.constant_(m.weight, 0)
                            init_sub_identity_conv1x1(m.weight)
                            print('sub identity init a conv1x1')
                    else:
                        nn.init.constant_(m.weight, 0)
        elif init == 'Kaiming':
            #initialize in a standard way
            pass 
        elif init == 'Xavier':
            for m in self.modules():
                if isinstance(m, nn.Linear):
                    nn.init.xavier_normal_(m.weight)
                if isinstance(m, PreActBlock) or isinstance(m, PreActBlockNonBN):
                    # initialize every conv layer as zero
                    nn.init.xavier_normal_(m.conv1.weight)
                    nn.init.xavier_normal_(m.conv2.weight)
                if isinstance(m, nn.Conv2d):
                    # initialize first conv layer as zero
                    if hasattr(m,'type_name'):
                        if 'conv1x1' in m.type_name:
                            # nn.init.constant_(m.weight, 0)
                            nn.init.xavier_normal_(m.weight)
                    else:
                        nn.init.xavier_normal_(m.weight)
        elif init == 'Fixup':
            for m in self.modules():
                if isinstance(m, PreActBlock) or isinstance(m, PreActBlockNonBN):
                    nn.init.normal_(m.conv1.weight, mean=0, std=np.sqrt(2 / (m.conv1.weight.shape[0] * np.prod(m.conv1.weight.shape[2:]))) * self.num_layers ** (-0.5))
                    nn.init.constant_(m.conv2.weight, 0)
                elif isinstance(m, nn.Linear):
                    nn.init.constant_(m.weight, 0)
                    nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.Conv2d):
                    # initialize first conv layer as zero
                    if hasattr(m,'type_name'):
                        nn.init.constant_(m.weight, 0)

        # check initialization status
        for name, param in self.named_parameters():
            unique_values = torch.unique(param.data)
            if len(unique_values) > 2 and 'downsample' not in name and 'ortho_transform' not in name:
                print('!!!!!!!!!!!!!!!!following is not initialized as zero or one!!!!!!!!!!!!!!!!')
                print(name)
                print(unique_values)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x) + self.first_transform(x)
        out = self.relu(out)
        if self.BN_enable:
            out = self.bn1(out)
        else:
            # simply replace the BN in the same position, may not be the optimal choice
            out = out + self.bias1

        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)

        out = self.relu(out)
        if self.BN_enable:
            out = self.bn2(out)
        else:
            out = out + self.bias2

        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def identity_resnet18():
    return PreActResNet(PreActBlock, [2,2,2,2])

def identity_resnet18_non_bn():
    return PreActResNet(PreActBlockNonBN, [2,2,2,2], BN_enable=False)

def identity_resnet18_non_bn_kaiming():
    return PreActResNet(PreActBlockNonBN, [2,2,2,2], BN_enable=False, init='Kaiming')

def identity_resnet18_non_bn_xavier():
    return PreActResNet(PreActBlockNonBN, [2,2,2,2], BN_enable=False, init='Xavier')

def identity_resnet18_non_bn_fixup():
    return PreActResNet(PreActBlockNonBN, [2,2,2,2], BN_enable=False, init='Fixup')

def identity_resnet18_kaiming():
    return PreActResNet(PreActBlock, [2,2,2,2], init='Kaiming')

def identity_resnet18_xavier():
    return PreActResNet(PreActBlock, [2,2,2,2], init='Xavier')

def identity_resnet18_fixup():
    return PreActResNet(PreActBlock, [2,2,2,2], init='Fixup')

def identity_resnet34():
    return PreActResNet(PreActBlock, [4,4,4,4])

def identity_resnet66():
    return PreActResNet(PreActBlock, [8,8,8,8])

def identity_resnet66_non_bn():
    return PreActResNet(PreActBlockNonBN, [8,8,8,8], BN_enable=False)

def identity_resnet66_non_bn_kaiming():
    return PreActResNet(PreActBlockNonBN, [8,8,8,8], BN_enable=False, init='Kaiming')

def identity_resnet66_non_bn_xavier():
    return PreActResNet(PreActBlockNonBN, [8,8,8,8], BN_enable=False, init='Xavier')

def identity_resnet66_non_bn_fixup():
    return PreActResNet(PreActBlockNonBN, [8,8,8,8], BN_enable=False, init='Fixup')

def identity_resnet126():
    return PreActResNet(PreActBlock, [16,16,16,16])

def identity_resnet126_non_bn():
    return PreActResNet(PreActBlockNonBN, [16,16,16,16], BN_enable=False)

def identity_resnet126_non_bn_kaiming():
    return PreActResNet(PreActBlockNonBN, [16,16,16,16], BN_enable=False, init='Kaiming')

def identity_resnet126_non_bn_xavier():
    return PreActResNet(PreActBlockNonBN, [16,16,16,16], BN_enable=False, init='Xavier')

def identity_resnet126_non_bn_fixup():
    return PreActResNet(PreActBlockNonBN, [16,16,16,16], BN_enable=False, init='Fixup')

def identity_resnet250():
    return PreActResNet(PreActBlock, [32,32,32,32])

def identity_resnet250_non_bn():
    return PreActResNet(PreActBlockNonBN, [32,32,32,32], BN_enable=False)

def identity_resnet250_non_bn_kaiming():
    return PreActResNet(PreActBlockNonBN, [32,32,32,32], BN_enable=False, init='Kaiming')

def identity_resnet250_non_bn_xavier():
    return PreActResNet(PreActBlockNonBN, [32,32,32,32], BN_enable=False, init='Xavier')

def identity_resnet250_non_bn_fixup():
    return PreActResNet(PreActBlockNonBN, [32,32,32,32], BN_enable=False, init='Fixup')

def identity_resnet250_kaiming():
    return PreActResNet(PreActBlock, [32,32,32,32], init='Kaiming')

def identity_resnet498():
    return PreActResNet(PreActBlock, [64,64,64,64])

def identity_resnet498_non_bn():
    return PreActResNet(PreActBlockNonBN, [64,64,64,64], BN_enable=False)

def identity_resnet498_non_bn_kaiming():
    return PreActResNet(PreActBlockNonBN, [64,64,64,64], BN_enable=False, init='Kaiming')

def identity_resnet498_non_bn_xavier():
    return PreActResNet(PreActBlockNonBN, [64,64,64,64], BN_enable=False, init='Xavier')

def identity_resnet498_non_bn_fixup():
    return PreActResNet(PreActBlockNonBN, [64,64,64,64], BN_enable=False, init='Fixup')

def identity_resnet498_kaiming():
    return PreActResNet(PreActBlock, [64,64,64,64], init='Kaiming')

def identity_resnet994():
    return PreActResNet(PreActBlock, [128,128,128,128])

# ResNet Architecture

In [None]:
from numpy.core.shape_base import block
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        # self.relu = zero_relu

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes),
                Hadamard_Transform(self.expansion*planes,self.expansion*planes)
            )
            self.shortcut[0].type_name= 'conv1x1'

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion *
                               planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks,init, num_classes=10):
        super(ResNet, self).__init__()
        self.init=init 
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)

        self.first_transform = nn.Sequential(ChannelPaddingSkip(0,61, scale=1), Hadamard_Transform(64,64))
        # self.first_transform = nn.Sequential( Hadamard_Transform(64,64))

        # self.apply(self._init_weights)
        # self.relu = zero_relu

        if self.init == 'ZerO':
          for m in self.modules():
            if isinstance(m,nn.Linear):
              nn.init.constant_(m.weight,0)
              nn.init.constant_(m.bias,0)
            if isinstance(m,BasicBlock):
              nn.init.constant_(m.conv1.weight,0)
              nn.init.constant_(m.conv2.weight,0)
            if isinstance(m,nn.Conv2d):
               if hasattr(m,'type_name'):
                   if 'conv1x1' in m.type_name:
                      init_sub_identity_conv1x1(m.weight)
                      print('sub identity init a conv1x1')
               else:
                      nn.init.constant_(m.weight,0)
            
        
        elif self.init == 'Random':
          pass
        # check initialization status
        for name, param in self.named_parameters():
            unique_values = torch.unique(param.data)
            if len(unique_values) > 2 and 'downsample' not in name and 'ortho_transform' not in name:
                print('!!!!!!!!!!!!!!!!following is not initialized as zero or one!!!!!!!!!!!!!!!!')
                print(name)
                print(unique_values)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        # out=self.first_transform(x)
        # out = F.relu(self.bn1(self.conv1(out)))
        out=self.conv1(x)+self.first_transform(x)
        out=F.relu(self.bn1(out))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def ResNet18_ZerO():
    return ResNet(BasicBlock, [2, 2, 2, 2], init='ZerO')
def ResNet18_Random():
  return ResNet(BasicBlock,[2,2,2,2],init='Random')

In [None]:
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion *
                               planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def ResNet18():
    return ResNet(BasicBlock, [2, 2, 2, 2])


def ResNet34():
    return ResNet(BasicBlock, [3, 4, 6, 3])


def ResNet50():
    return ResNet(Bottleneck, [3, 4, 6, 3])


def ResNet101():
    return ResNet(Bottleneck, [3, 4, 23, 3])


def ResNet152():
    return ResNet(Bottleneck, [3, 8, 36, 3])


def test():
    net = ResNet18()
    y = net(torch.randn(1, 3, 32, 32))
    print(y.size())

# **Building normal  resnet model**

In [None]:
# parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
# parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
# parser.add_argument('--resume', '-r', action='store_true',
#                     help='resume from checkpoint')
# args = parser.parse_args()

# device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch
# Model
print('==> Building model..')
net = identity_resnet18()
net = net.to(device)
# if device == 'cuda':
#     net = torch.nn.DataParallel(net)
#     cudnn.benchmark = True

# if args.resume:
#     # Load checkpoint.
#     print('==> Resuming from checkpoint..')
#     assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
#     checkpoint = torch.load('./checkpoint/ckpt.pth')
#     net.load_state_dict(checkpoint['net'])
#     best_acc = checkpoint['acc']
#     start_epoch = checkpoint['epoch']

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.1,
                      momentum=0.9, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)


for epoch in range(start_epoch, start_epoch+100):
    train(epoch,net)
    test(epoch,net)
    scheduler.step()

==> Building model..
sub identity init a conv1x1
sub identity init a conv1x1
sub identity init a conv1x1

Epoch: 0
Train Loss: 1.7352758527106946 Acc:  36.446 correct 18223 total 50000
Test Loss: 1.5353835332393646 Acc: 43.7 Correct 4370 total 10000
43.7

Epoch: 1
Train Loss: 1.3634341619813535 Acc:  50.148 correct 25074 total 50000
Test Loss: 1.2196313798427583 Acc: 56.16 Correct 5616 total 10000
56.16

Epoch: 2
Train Loss: 1.099568054041899 Acc:  60.56 correct 30280 total 50000
Test Loss: 0.9942542421817779 Acc: 64.11 Correct 6411 total 10000
64.11

Epoch: 3
Train Loss: 0.8719585711693825 Acc:  69.278 correct 34639 total 50000
Test Loss: 0.8182187420129776 Acc: 71.39 Correct 7139 total 10000
71.39

Epoch: 4
Train Loss: 0.7319169805177947 Acc:  74.456 correct 37228 total 50000
Test Loss: 0.6827771493792534 Acc: 76.29 Correct 7629 total 10000
76.29

Epoch: 5
Train Loss: 0.641233514565641 Acc:  77.76 correct 38880 total 50000
Test Loss: 0.7311046376824379 Acc: 74.81 Correct 7481 total 1

# **Quantizer Function**

In [24]:
class Quantizer(nn.Module):
    def __init__(self, m, e,gamma):
        super().__init__()
        self.m = m # number of mantissa bit 
        self.e = e # number of exponent bit
        self.gamma=gamma
    def forward(self, input):
        sign=torch.sign(input) #sign bit
        b= 2**(self.e-1) #bias
        b=b-torch.log2(torch.tensor(self.gamma))     
        b=b.to(device)
        c=(2-2**(-self.m))*2**(2**self.e-b-1)  #maximum representable range i.e. dynamic range
        temp=torch.floor(torch.log2(abs(input))+b)
        temp_max=temp.max()
        temp_max=temp_max.to(device)
        temp_min=torch.tensor(1)
        temp_min=temp_min.to(device)
        if temp_max<temp_min:
          temp_max=temp_min
        temp=torch.clamp(temp,temp_min,temp_max)
        p=temp-b-torch.round(torch.tensor(self.m))
        s=2**p
        output = torch.clamp(sign*s*torch.round(abs(input)/s),-c,c)
        return output



class QuantizedResNet18(nn.Module):
    def __init__(self, model):
        super(QuantizedResNet18, self).__init__()
        self.quant = Quantizer(4,3,1)
        self.model = model
        # self.conv1_init_weight = self.model.conv1.weight.data
        W = self.model.layer1[0].conv1.weight.data
        print(W)
        print(W.size())
        W = self.model.layer2[0].shortcut[0].weight.data
        print(W)
        # W = self.model.layer2[0].shortcut[1].weight.data
        # print(W)
        W = self.model.linear.weight.data
        print(W)
        B=self.model.linear.bias.data
        print(B)

    def forward(self, x):
        # x = self.quant(x)
        W=self.model.conv1.weight.data
        self.model.conv1.weight.data=self.quant(W)
        x=self.model.conv1(x)+self.model.first_transform(x)
        # x=self.quant(x)
        W=self.model.bn1.weight.data
        self.model.bn1.weight.data=self.quant(W)
        B=self.model.bn1.bias.data
        self.model.bn1.bias.data=self.quant(B)
        x=self.model.bn1(x)
        x10=F.relu(x)
        # x=self.quant(x)
        # #layer 1 
        W=self.model.layer1[0].conv1.weight.data
        self.model.layer1[0].conv1.weight.data=self.quant(W)
        x=self.model.layer1[0].conv1(x10)
        W=self.model.layer1[0].bn1.weight.data
        self.model.layer1[0].bn1.weight.data=self.quant(W)
        B=self.model.layer1[0].bn1.bias.data
        self.model.layer1[0].bn1.bias.data=self.quant(B)
        x=self.model.layer1[0].bn1(x)
        x=F.relu(x)
        W=self.model.layer1[0].conv2.weight.data
        self.model.layer1[0].conv2.weight.data=self.quant(W)
        x=self.model.layer1[0].conv2(x)
        W=self.model.layer1[0].bn2.weight.data
        self.model.layer1[0].bn2.weight.data=self.quant(W)
        B=self.model.layer1[0].bn2.bias.data
        self.model.layer1[0].bn2.bias.data=self.quant(B)
        x=self.model.layer1[0].bn2(x)
        x11=F.relu(x)
        W=self.model.layer1[1].conv1.weight.data
        self.model.layer1[1].conv1.weight.data=self.quant(W)
        x=self.model.layer1[1].conv1(x11)
        W=self.model.layer1[1].bn1.weight.data
        self.model.layer1[1].bn1.weight.data=self.quant(W)
        B=self.model.layer1[1].bn1.bias.data
        self.model.layer1[1].bn1.bias.data=self.quant(B)
        x=self.model.layer1[1].bn1(x)
        x=F.relu(x)
        W=self.model.layer1[1].conv2.weight.data
        self.model.layer1[1].conv2.weight.data=self.quant(W)
        x=self.model.layer1[1].conv2(x)
        W=self.model.layer1[1].bn2.weight.data
        self.model.layer1[1].bn2.weight.data=self.quant(W)
        B=self.model.layer1[1].bn2.bias.data
        self.model.layer1[1].bn2.bias.data=self.quant(B)
        x=self.model.layer1[1].bn2(x)
        x20=F.relu(x)
        # # #layer 2
        W=self.model.layer2[0].conv1.weight.data
        self.model.layer2[0].conv1.weight.data=self.quant(W)
        x=self.model.layer2[0].conv1(x20)
        W=self.model.layer2[0].bn1.weight.data
        self.model.layer2[0].bn1.weight.data=self.quant(W)
        B=self.model.layer2[0].bn1.bias.data
        self.model.layer2[0].bn1.bias.data=self.quant(B)
        x=self.model.layer2[0].bn1(x)
        x=F.relu(x)
        W=self.model.layer2[0].conv2.weight.data
        self.model.layer2[0].conv2.weight.data=self.quant(W)
        x=self.model.layer2[0].conv2(x)
        W=self.model.layer2[0].bn2.weight.data
        self.model.layer2[0].bn2.weight.data=self.quant(W)
        B=self.model.layer2[0].bn2.bias.data
        self.model.layer2[0].bn2.bias.data=self.quant(B)
        x=self.model.layer2[0].bn2(x)
        x=F.relu(x)
        W=self.model.layer2[0].shortcut[0].weight.data
        self.model.layer2[0].shortcut[0].weight.data=self.quant(W)
        xs=self.model.layer2[0].shortcut[0](x20)
        W=self.model.layer2[0].shortcut[1].weight.data
        self.model.layer2[0].shortcut[1].weight.data=self.quant(W)
        B=self.model.layer2[0].shortcut[1].bias.data
        self.model.layer2[0].shortcut[1].bias.data=self.quant(B)
        xs=self.model.layer2[0].shortcut[1](xs)
        x=x.clone()+self.model.layer2[0].shortcut[2](xs)
        x21=F.relu(x)
        W=self.model.layer2[1].conv1.weight.data
        self.model.layer2[1].conv1.weight.data=self.quant(W)
        x=self.model.layer2[1].conv1(x21)
        W=self.model.layer2[1].bn1.weight.data
        self.model.layer2[1].bn1.weight.data=self.quant(W)
        B=self.model.layer2[1].bn1.bias.data
        self.model.layer2[1].bn1.bias.data=self.quant(B)
        x=self.model.layer2[1].bn1(x)
        x=F.relu(x)
        W=self.model.layer2[1].conv2.weight.data
        self.model.layer2[1].conv2.weight.data=self.quant(W)
        x=self.model.layer2[1].conv2(x)
        W=self.model.layer2[1].bn2.weight.data
        self.model.layer2[1].bn2.weight.data=self.quant(W)
        B=self.model.layer2[1].bn2.bias.data
        self.model.layer2[1].bn2.bias.data=self.quant(B)
        x=self.model.layer2[1].bn2(x)
        x30=F.relu(x)

        # # #layer 3
        W=self.model.layer3[0].conv1.weight.data
        self.model.layer3[0].conv1.weight.data=self.quant(W)
        x=self.model.layer3[0].conv1(x30)
        W=self.model.layer3[0].bn1.weight.data
        self.model.layer3[0].bn1.weight.data=self.quant(W)
        B=self.model.layer3[0].bn1.bias.data
        self.model.layer3[0].bn1.bias.data=self.quant(B)
        x=self.model.layer3[0].bn1(x)
        x=F.relu(x)
        W=self.model.layer3[0].conv2.weight.data
        self.model.layer3[0].conv2.weight.data=self.quant(W)
        x=self.model.layer3[0].conv2(x)
        W=self.model.layer3[0].bn2.weight.data
        self.model.layer3[0].bn2.weight.data=self.quant(W)
        B=self.model.layer3[0].bn2.bias.data
        self.model.layer3[0].bn2.bias.data=self.quant(B)
        x=self.model.layer3[0].bn2(x)
        x=F.relu(x)
        W=self.model.layer3[0].shortcut[0].weight.data
        self.model.layer3[0].shortcut[0].weight.data=self.quant(W)
        xs=self.model.layer3[0].shortcut[0](x30)
        W=self.model.layer3[0].shortcut[1].weight.data
        self.model.layer3[0].shortcut[1].weight.data=self.quant(W)
        B=self.model.layer3[0].shortcut[1].bias.data
        self.model.layer3[0].shortcut[1].bias.data=self.quant(B)
        xs=self.model.layer3[0].shortcut[1](xs)
        x=x.clone()+self.model.layer3[0].shortcut[2](xs)
        x31=F.relu(x)
        W=self.model.layer3[1].conv1.weight.data
        self.model.layer3[1].conv1.weight.data=self.quant(W)
        x=self.model.layer3[1].conv1(x31)
        W=self.model.layer3[1].bn1.weight.data
        self.model.layer3[1].bn1.weight.data=self.quant(W)
        B=self.model.layer3[1].bn1.bias.data
        self.model.layer3[1].bn1.bias.data=self.quant(B)
        x=self.model.layer3[1].bn1(x)
        x=F.relu(x)
        W=self.model.layer3[1].conv2.weight.data
        self.model.layer3[1].conv2.weight.data=self.quant(W)
        x=self.model.layer3[1].conv2(x)
        W=self.model.layer3[1].bn2.weight.data
        self.model.layer3[1].bn2.weight.data=self.quant(W)
        B=self.model.layer3[1].bn2.bias.data
        self.model.layer3[1].bn2.bias.data=self.quant(B)
        x=self.model.layer3[1].bn2(x)
        x40=F.relu(x)
        # # layer 4
        W=self.model.layer4[0].conv1.weight.data
        self.model.layer4[0].conv1.weight.data=self.quant(W)
        x=self.model.layer4[0].conv1(x40)
        W=self.model.layer4[0].bn1.weight.data
        self.model.layer4[0].bn1.weight.data=self.quant(W)
        B=self.model.layer4[0].bn1.bias.data
        self.model.layer4[0].bn1.bias.data=self.quant(B)
        x=self.model.layer4[0].bn1(x)
        x=F.relu(x)
        W=self.model.layer4[0].conv2.weight.data
        self.model.layer4[0].conv2.weight.data=self.quant(W)
        x=self.model.layer4[0].conv2(x)
        W=self.model.layer4[0].bn2.weight.data
        self.model.layer4[0].bn2.weight.data=self.quant(W)
        B=self.model.layer4[0].bn2.bias.data
        self.model.layer4[0].bn2.bias.data=self.quant(B)
        x=self.model.layer4[0].bn2(x)
        x=F.relu(x)
        W=self.model.layer4[0].shortcut[0].weight.data
        self.model.layer4[0].shortcut[0].weight.data=self.quant(W)
        xs=self.model.layer4[0].shortcut[0](x40)
        W=self.model.layer4[0].shortcut[1].weight.data
        self.model.layer4[0].shortcut[1].weight.data=self.quant(W)
        B=self.model.layer4[0].shortcut[1].bias.data
        self.model.layer4[0].shortcut[1].bias.data=self.quant(B)
        xs=self.model.layer4[0].shortcut[1](xs)
        x=x.clone()+self.model.layer4[0].shortcut[2](xs)
        x41=F.relu(x)
        W=self.model.layer4[1].conv1.weight.data
        self.model.layer4[1].conv1.weight.data=self.quant(W)
        x=self.model.layer4[1].conv1(x41)
        W=self.model.layer4[1].bn1.weight.data
        self.model.layer4[1].bn1.weight.data=self.quant(W)
        B=self.model.layer4[1].bn1.bias.data
        self.model.layer4[1].bn1.bias.data=self.quant(B)
        x=self.model.layer4[1].bn1(x)
        x=F.relu(x)
        W=self.model.layer4[1].conv2.weight.data
        self.model.layer4[1].conv2.weight.data=self.quant(W)
        x=self.model.layer4[1].conv2(x)
        W=self.model.layer4[1].bn2.weight.data
        self.model.layer4[1].bn2.weight.data=self.quant(W)
        B=self.model.layer4[1].bn2.bias.data
        self.model.layer4[1].bn2.bias.data=self.quant(B)
        x=self.model.layer4[1].bn2(x)
        x=F.relu(x)


        x=F.avg_pool2d(x,4)
        x=x.view(x.size(0),-1)
        W=self.model.linear.weight.data
        self.model.linear.weight.data=self.quant(W)
        B=self.model.linear.bias.data
        self.model.linear.bias.data=self.quant(B)
        x=self.model.linear(x)
        return x

class PreAct(nn.Module):
    def __init__(self, model):
        super(PreAct, self).__init__()
        self.quant = Quantizer(4,3,0.2)
        self.model = model
        self.conv1_init_weight = self.model.conv1.weight.data
        self.conv1_init_weight= self.conv1_init_weight.to(device)
        # self.bn1_init_weight=self.model.bn1.weight.data
        # self.bn1_init_weight=self.bn1_init_weight.to(device)
        # self.bn1_init_bias=self.model.bn1.weight.data
        # self.bn1_init_bias=self.bn1_init_bias.to(device)
        #layer1
        # self.layer1_0_bn1_init_weight=self.model.layer1[0].bn1.weight.data
        # self.layer1_0_bn1_init_weight=self.layer1_0_bn1_init_weight.to(device)
        # self.layer1_0_bn1_init_bias=self.model.layer1[0].bn1.bias.data
        # self.layer1_0_bn1_init_bias=self.layer1_0_bn1_init_bias.to(device)
        self.layer1_0_conv1_init_weight=self.model.layer1[0].conv1.weight.data
        self.layer1_0_conv1_init_weight=self.layer1_0_conv1_init_weight.to(device)
        # self.layer1_0_bn2_init_weight=self.model.layer1[0].bn2.weight.data
        # self.layer1_0_bn2_init_weight=self.layer1_0_bn2_init_weight.to(device)
        # self.layer1_0_bn2_init_bias=self.model.layer1[0].bn2.bias.data
        # self.layer1_0_bn2_init_bias=self.layer1_0_bn2_init_bias.to(device)
        self.layer1_0_conv2_init_weight=self.model.layer1[0].conv2.weight.data
        self.layer1_0_conv2_init_weight=self.layer1_0_conv2_init_weight.to(device)
        # self.layer1_1_bn1_init_weight=self.model.layer1[1].bn1.weight.data
        # self.layer1_1_bn1_init_weight=self.layer1_1_bn1_init_weight.to(device)
        # self.layer1_1_bn1_init_bias=self.model.layer1[1].bn1.bias.data
        # self.layer1_1_bn1_init_bias=self.layer1_1_bn1_init_bias.to(device)
        self.layer1_1_conv1_init_weight=self.model.layer1[1].conv1.weight.data
        self.layer1_1_conv1_init_weight=self.layer1_1_conv1_init_weight.to(device)
        # self.layer1_1_bn2_init_weight=self.model.layer1[1].bn2.weight.data
        # self.layer1_1_bn2_init_weight= self.layer1_1_bn2_init_weight.to(device)
        # self.layer1_1_bn2_init_bias=self.model.layer1[1].bn2.bias.data
        # self.layer1_1_bn2_init_bias=self.layer1_1_bn2_init_bias.to(device)
        self.layer1_1_conv2_init_weight=self.model.layer1[1].conv2.weight.data
        self.layer1_1_conv2_init_weight=self.layer1_1_conv2_init_weight.to(device)
        #layer2
        # self.layer2_0_bn1_init_weight=self.model.layer2[0].bn1.weight.data
        # self.layer2_0_bn1_init_weight=self.layer2_0_bn1_init_weight.to(device)
        # self.layer2_0_bn1_init_bias=self.model.layer2[0].bn1.bias.data
        # self.layer2_0_bn1_init_bias=self.layer2_0_bn1_init_bias.to(device)
        self.layer2_0_shortcut_0_init_weight=self.model.layer2[0].shortcut[0].weight.data
        self.layer2_0_shortcut_0_init_weight=self.layer2_0_shortcut_0_init_weight.to(device)
        self.layer2_0_conv1_init_weight=self.model.layer2[0].conv1.weight.data
        self.layer2_0_conv1_init_weight=self.layer2_0_conv1_init_weight.to(device)
        # self.layer2_0_bn2_init_weight=self.model.layer2[0].bn2.weight.data
        # self.layer2_0_bn2_init_weight=self.layer2_0_bn2_init_weight.to(device)
        # self.layer2_0_bn2_init_bias=self.model.layer2[0].bn2.bias.data
        # self.layer2_0_bn2_init_bias=self.layer2_0_bn2_init_bias.to(device)
        self.layer2_0_conv2_init_weight=self.model.layer2[0].conv2.weight.data
        self.layer2_0_conv2_init_weight=self.layer2_0_conv2_init_weight.to(device)
        # self.layer2_1_bn1_init_weight=self.model.layer2[1].bn1.weight.data
        # self.layer2_1_bn1_init_weight=self.layer2_1_bn1_init_weight.to(device)
        # self.layer2_1_bn1_init_bias=self.model.layer2[1].bn1.bias.data
        # self.layer2_1_bn1_init_bias=self.layer2_1_bn1_init_bias.to(device)
        self.layer2_1_conv1_init_weight=self.model.layer2[1].conv1.weight.data
        self.layer2_1_conv1_init_weight=self.layer2_1_conv1_init_weight.to(device)
        # self.layer2_1_bn2_init_weight=self.model.layer2[1].bn2.weight.data
        # self.layer2_1_bn2_init_weight=self.layer2_1_bn2_init_weight.to(device)
        # self.layer2_1_bn2_init_bias=self.model.layer2[1].bn2.bias.data
        # self.layer2_1_bn2_init_bias=self.layer2_1_bn2_init_bias.to(device)
        self.layer2_1_conv2_init_weight=self.model.layer2[1].conv2.weight.data
        self.layer2_1_conv2_init_weight=self.layer2_1_conv2_init_weight.to(device)
        #layer3
        # self.layer3_0_bn1_init_weight=self.model.layer3[0].bn1.weight.data
        # self.layer3_0_bn1_init_weight=self.layer3_0_bn1_init_weight.to(device)
        # self.layer3_0_bn1_init_bias=self.model.layer3[0].bn1.bias.data
        # self.layer3_0_bn1_init_bias=self.layer3_0_bn1_init_bias.to(device)
        self.layer3_0_shortcut_0_init_weight=self.model.layer3[0].shortcut[0].weight.data
        self.layer3_0_shortcut_0_init_weight=self.layer3_0_shortcut_0_init_weight.to(device)
        self.layer3_0_conv1_init_weight=self.model.layer3[0].conv1.weight.data
        self.layer3_0_conv1_init_weight=self.layer3_0_conv1_init_weight.to(device)
        # self.layer3_0_bn2_init_weight=self.model.layer3[0].bn2.weight.data
        # self.layer3_0_bn2_init_weight=self.layer3_0_bn2_init_weight.to(device)
        # self.layer3_0_bn2_init_bias=self.model.layer3[0].bn2.bias.data
        # self.layer3_0_bn2_init_bias=self.layer3_0_bn2_init_bias.to(device)
        self.layer3_0_conv2_init_weight=self.model.layer3[0].conv2.weight.data
        self.layer3_0_conv2_init_weight=self.layer3_0_conv2_init_weight.to(device)
        # self.layer3_1_bn1_init_weight=self.model.layer3[1].bn1.weight.data
        # self.layer3_1_bn1_init_weight=self.layer3_1_bn1_init_weight.to(device)
        # self.layer3_1_bn1_init_bias=self.model.layer3[1].bn1.bias.data
        # self.layer3_1_bn1_init_bias=self.layer3_1_bn1_init_bias.to(device)
        self.layer3_1_conv1_init_weight=self.model.layer3[1].conv1.weight.data
        self.layer3_1_conv1_init_weight=self.layer3_1_conv1_init_weight.to(device)
        # self.layer3_1_bn2_init_weight=self.model.layer3[1].bn2.weight.data
        # self.layer3_1_bn2_init_weight=self.layer3_1_bn2_init_weight.to(device)
        # self.layer3_1_bn2_init_bias=self.model.layer3[1].bn2.bias.data
        # self.layer3_1_bn2_init_bias=self.layer3_1_bn2_init_bias.to(device)
        self.layer3_1_conv2_init_weight=self.model.layer3[1].conv2.weight.data
        self.layer3_1_conv2_init_weight=self.layer3_1_conv2_init_weight.to(device)
        #layer4
        # self.layer4_0_bn1_init_weight=self.model.layer4[0].bn1.weight.data
        # self.layer4_0_bn1_init_weight=self.layer4_0_bn1_init_weight.to(device)
        # self.layer4_0_bn1_init_bias=self.model.layer4[0].bn1.bias.data
        # self.layer4_0_bn1_init_bias=self.layer4_0_bn1_init_bias.to(device)
        self.layer4_0_shortcut_0_init_weight=self.model.layer4[0].shortcut[0].weight.data
        self.layer4_0_shortcut_0_init_weight=self.layer4_0_shortcut_0_init_weight.to(device)
        self.layer4_0_conv1_init_weight=self.model.layer4[0].conv1.weight.data
        self.layer4_0_conv1_init_weight=self.layer4_0_conv1_init_weight.to(device)
        # self.layer4_0_bn2_init_weight=self.model.layer4[0].bn2.weight.data
        # self.layer4_0_bn2_init_weight=self.layer4_0_bn2_init_weight.to(device)
        # self.layer4_0_bn2_init_bias=self.model.layer4[0].bn2.bias.data
        # self.layer4_0_bn2_init_bias=self.layer4_0_bn2_init_bias.to(device)
        self.layer4_0_conv2_init_weight=self.model.layer4[0].conv2.weight.data
        self.layer4_0_conv2_init_weight=self.layer4_0_conv2_init_weight.to(device)
        # self.layer4_1_bn1_init_weight=self.model.layer4[1].bn1.weight.data
        # self.layer4_1_bn1_init_weight=self.layer4_1_bn1_init_weight.to(device)
        # self.layer4_1_bn1_init_bias=self.model.layer4[1].bn1.bias.data
        # self.layer4_1_bn1_init_bias=self.layer4_1_bn1_init_bias.to(device)
        self.layer4_1_conv1_init_weight=self.model.layer4[1].conv1.weight.data
        self.layer4_1_conv1_init_weight=self.layer4_1_conv1_init_weight.to(device)
        # self.layer4_1_bn2_init_weight=self.model.layer4[1].bn2.weight.data
        # self.layer4_1_bn2_init_weight=self.layer4_1_bn2_init_weight.to(device)
        # self.layer4_1_bn2_init_bias=self.model.layer4[1].bn2.bias.data
        # self.layer4_1_bn2_init_bias=self.layer4_1_bn2_init_bias.to(device)
        self.layer4_1_conv2_init_weight=self.model.layer4[1].conv2.weight.data
        self.layer4_1_conv2_init_weight=self.layer4_1_conv2_init_weight.to(device)

        # self.bn2_init_weight=self.model.bn2.weight.data
        # self.bn2_init_weight=self.bn2_init_weight.to(device)
        # self.bn2_init_bias=self.model.bn2.bias.data
        # self.bn2_init_bias=self.bn2_init_bias.to(device)
        self.linear_init_weight=self.model.linear.weight.data
        self.linear_init_weight=self.linear_init_weight.to(device)
        self.linear_init_bias=self.model.linear.bias.data
        self.linear_init_bias=self.linear_init_bias.to(device)
       
        W = self.model.layer1[0].conv1.weight.data
        print(W)
        W = self.model.layer2[0].shortcut[0].weight.data
        print(W)
        W = self.model.linear.weight.data
        print(W)
        B=self.model.linear.bias.data
        print(B)

    def forward(self, x):
        # x = self.quant(x)
        W=self.model.conv1.weight.data
        # self.model.conv1.weight.data=self.quant(W)
        self.model.conv1.weight.data=self.conv1_init_weight+self.quant(W-self.conv1_init_weight)
        x=self.model.conv1(x)+self.model.first_transform(x)
        # x=self.quant(x)
        x=zero_relu(x)
        # W=self.model.bn1.weight.data
        # self.model.bn1.weight.data=self.quant(W)
        # self.model.bn1.weight.data=self.bn1_init_weight+self.quant(W-self.bn1_init_weight)
        # B=self.model.bn1.bias.data
        # self.model.bn1.bias.data=self.quant(B)
        # self.model.bn1.bias.data=self.bn1_init_bias+self.quant(B-self.bn1_init_bias)
        x10=self.model.bn1(x)
        
        # #layer 1 
        # W=self.model.layer1[0].bn1.weight.data
        # self.model.layer1[0].bn1.weight.data=self.quant(W)
        # self.model.layer1[0].bn1.weight.data=self.layer1_0_bn1_init_weight + self.quant(W-self.layer1_0_bn1_init_weight)
        # B=self.model.layer1[0].bn1.bias.data
        # self.model.layer1[0].bn1.bias.data=self.quant(B)
        # self.model.layer1[0].bn1.bias.data=self.layer1_0_bn1_init_bias + self.quant(B-self.layer1_0_bn1_init_bias)
        x=self.model.layer1[0].bn1(zero_relu(x10))
        shortcut=x10
        W=self.model.layer1[0].conv1.weight.data
        # self.model.layer1[0].conv1.weight.data=self.quant(W)
        self.model.layer1[0].conv1.weight.data=self.layer1_0_conv1_init_weight + self.quant(W-self.layer1_0_conv1_init_weight)
        x=self.model.layer1[0].conv1(x)
        x=x+shortcut
        identity=x
        # W=self.model.layer1[0].bn2.weight.data
        # self.model.layer1[0].bn2.weight.data=self.quant(W)
        # self.model.layer1[0].bn2.weight.data=self.layer1_0_bn2_init_weight + self.quant(W-self.layer1_0_bn2_init_weight)
        # B=self.model.layer1[0].bn2.bias.data
        # self.model.layer1[0].bn2.bias.data=self.quant(B)
        # self.model.layer1[0].bn2.bias.data=self.layer1_0_bn2_init_bias + self.quant(W-self.layer1_0_bn2_init_bias)
        x=self.model.layer1[0].bn2(zero_relu(x))
        W=self.model.layer1[0].conv2.weight.data
        # self.model.layer1[0].conv2.weight.data=self.quant(W)
        self.model.layer1[0].conv2.weight.data=self.layer1_0_conv2_init_weight + self.quant(W-self.layer1_0_conv2_init_weight )
        x=self.model.layer1[0].conv2(x)
        x11=x+identity
        # W=self.model.layer1[1].bn1.weight.data
        # self.model.layer1[1].bn1.weight.data=self.quant(W)
        # self.model.layer1[1].bn1.weight.data=self.layer1_1_bn1_init_weight + self.quant(W-self.layer1_1_bn1_init_weight)
        # B=self.model.layer1[1].bn1.bias.data
        # self.model.layer1[1].bn1.bias.data=self.quant(B)
        # self.model.layer1[1].bn1.bias.data=self.layer1_1_bn1_init_bias + self.quant(B-self.layer1_1_bn1_init_bias)
        x=self.model.layer1[1].bn1(zero_relu(x11))
        shortcut=x11
        W=self.model.layer1[1].conv1.weight.data
        # self.model.layer1[1].conv1.weight.data=self.quant(W)
        self.model.layer1[1].conv1.weight.data=self.layer1_1_conv1_init_weight + self.quant(W-self.layer1_1_conv1_init_weight )
        x=self.model.layer1[1].conv1(x)
        x=x+shortcut
        identity=x
        # W=self.model.layer1[1].bn2.weight.data
        # self.model.layer1[1].bn2.weight.data=self.quant(W)
        # self.model.layer1[1].bn2.weight.data=self.layer1_1_bn2_init_weight + self.quant(W-self.layer1_1_bn2_init_weight)
        # B=self.model.layer1[1].bn2.bias.data
        # self.model.layer1[1].bn2.bias.data=self.quant(B)
        # self.model.layer1[1].bn2.bias.data=self.layer1_1_bn2_init_bias + self.quant(W-self.layer1_1_bn2_init_bias)
        x=self.model.layer1[1].bn2(zero_relu(x))
        W=self.model.layer1[1].conv2.weight.data
        # self.model.layer1[1].conv2.weight.data=self.quant(W)
        self.model.layer1[1].conv2.weight.data=self.layer1_1_conv2_init_weight + self.quant(W-self.layer1_1_conv2_init_weight )
        x=self.model.layer1[1].conv2(x)
        x20=x+identity
    
        # # # #layer 2
        # W=self.model.layer2[0].bn1.weight.data
        # self.model.layer2[0].bn1.weight.data=self.quant(W)
        # self.model.layer2[0].bn1.weight.data=self.layer2_0_bn1_init_weight + self.quant(W-self.layer2_0_bn1_init_weight)
        # B=self.model.layer2[0].bn1.bias.data
        # self.model.layer2[0].bn1.bias.data=self.quant(B)
        # self.model.layer2[0].bn1.bias.data=self.layer2_0_bn1_init_bias + self.quant(B-self.layer2_0_bn1_init_bias)
        x=self.model.layer2[0].bn1(zero_relu(x20))
        W=self.model.layer2[0].shortcut[0].weight.data
        # self.model.layer2[0].shortcut[0].weight.data=self.quant(W)
        self.model.layer2[0].shortcut[0].weight.data=self.layer2_0_shortcut_0_init_weight +  self.quant(W-self.layer2_0_shortcut_0_init_weight)
        shortcut=self.model.layer2[0].shortcut[0](x)
        shortcut=self.model.layer2[0].shortcut[1](shortcut)
        W=self.model.layer2[0].conv1.weight.data
        # self.model.layer2[0].conv1.weight.data=self.quant(W)
        self.model.layer2[0].conv1.weight.data=self.layer2_0_conv1_init_weight + self.quant(W-self.layer2_0_conv1_init_weight )
        x=self.model.layer2[0].conv1(x)
        x=x+shortcut
        identity=x
        # W=self.model.layer2[0].bn2.weight.data
        # self.model.layer2[0].bn2.weight.data=self.quant(W)
        # self.model.layer2[0].bn2.weight.data=self.layer2_0_bn2_init_weight + self.quant(W-self.layer2_0_bn2_init_weight)
        # B=self.model.layer2[0].bn2.bias.data
        # self.model.layer2[0].bn2.bias.data=self.quant(B)
        # self.model.layer2[0].bn2.bias.data=self.layer2_0_bn2_init_bias + self.quant(W-self.layer2_0_bn2_init_bias)
        x=self.model.layer2[0].bn2(zero_relu(x))
        W=self.model.layer2[0].conv2.weight.data
        # self.model.layer2[0].conv2.weight.data=self.quant(W)
        self.model.layer2[0].conv2.weight.data=self.layer2_0_conv2_init_weight + self.quant(W-self.layer2_0_conv2_init_weight )
        x=self.model.layer2[0].conv2(x)
        x21=x+identity
        # W=self.model.layer2[1].bn1.weight.data
        # self.model.layer2[1].bn1.weight.data=self.quant(W)
        # self.model.layer2[1].bn1.weight.data=self.layer2_bn1_init_weight + self.quant(W-self.layer2_1_bn1_init_weight)
        # B=self.model.layer2[1].bn1.bias.data
        # self.model.layer2[1].bn1.bias.data=self.quant(B)
        # self.model.layer2[1].bn1.bias.data=self.layer2_1_bn1_init_bias + self.quant(B-self.layer2_1_bn1_init_bias)
        x=self.model.layer2[1].bn1(zero_relu(x21))
        shortcut=x21
        W=self.model.layer2[1].conv1.weight.data
        # self.model.layer2[1].conv1.weight.data=self.quant(W)
        self.model.layer2[1].conv1.weight.data=self.layer2_1_conv1_init_weight + self.quant(W-self.layer2_1_conv1_init_weight )
        x=self.model.layer2[1].conv1(x)
        x=x+shortcut
        identity=x
        # W=self.model.layer2[1].bn2.weight.data
        # self.model.layer2[1].bn2.weight.data=self.quant(W)
        # self.model.layer2[1].bn2.weight.data=self.layer2_1_bn2_init_weight + self.quant(W-self.layer2_1_bn2_init_weight)
        # B=self.model.layer2[1].bn2.bias.data
        # self.model.layer2[1].bn2.bias.data=self.quant(B)
        # self.model.layer2[1].bn2.bias.data=self.layer2_1_bn2_init_bias + self.quant(W-self.layer2_1_bn2_init_bias)
        x=self.model.layer2[1].bn2(zero_relu(x))
        W=self.model.layer2[1].conv2.weight.data
        # self.model.layer2[1].conv2.weight.data=self.quant(W)
        self.model.layer2[1].conv2.weight.data=self.layer2_1_conv2_init_weight + self.quant(W-self.layer2_1_conv2_init_weight )
        x=self.model.layer2[1].conv2(x)
        x30=x+identity

        # # #layer 3
        # W=self.model.layer3[0].bn1.weight.data
        # self.model.layer3[0].bn1.weight.data=self.quant(W)
        # self.model.layer3[0].bn1.weight.data=self.layer3_0_bn1_init_weight + self.quant(W-self.layer3_0_bn1_init_weight)
        # B=self.model.layer3[0].bn1.bias.data
        # self.model.layer3[0].bn1.bias.data=self.quant(B)
        # self.model.layer3[0].bn1.bias.data=self.layer3_0_bn1_init_bias + self.quant(B-self.layer3_0_bn1_init_bias)
        x=self.model.layer3[0].bn1(zero_relu(x30))
        W=self.model.layer3[0].shortcut[0].weight.data
        # self.model.layer3[0].shortcut[0].weight.data=self.quant(W)
        self.model.layer3[0].shortcut[0].weight.data=self.layer3_0_shortcut_0_init_weight +  self.quant(W-self.layer3_0_shortcut_0_init_weight)
        shortcut=self.model.layer3[0].shortcut[0](x)
        shortcut=self.model.layer3[0].shortcut[1](shortcut)
        W=self.model.layer3[0].conv1.weight.data
        # self.model.layer3[0].conv1.weight.data=self.quant(W)
        self.model.layer3[0].conv1.weight.data=self.layer3_0_conv1_init_weight + self.quant(W-self.layer3_0_conv1_init_weight )
        x=self.model.layer3[0].conv1(x)
        x=x+shortcut
        identity=x
        # W=self.model.layer3[0].bn2.weight.data
        # self.model.layer3[0].bn2.weight.data=self.quant(W)
        # self.model.layer3[0].bn2.weight.data=self.layer3_0_bn2_init_weight + self.quant(W-self.layer3_0_bn2_init_weight)
        # B=self.model.layer3[0].bn2.bias.data
        # self.model.layer3[0].bn2.bias.data=self.quant(B)
        # self.model.layer3[0].bn2.bias.data=self.layer3_0_bn2_init_bias + self.quant(W-self.layer3_0_bn2_init_bias)
        x=self.model.layer3[0].bn2(zero_relu(x))
        W=self.model.layer3[0].conv2.weight.data
        # self.model.layer3[0].conv2.weight.data=self.quant(W)
        self.model.layer3[0].conv2.weight.data=self.layer3_0_conv2_init_weight + self.quant(W-self.layer3_0_conv2_init_weight )
        x=self.model.layer3[0].conv2(x)
        x31=x+identity

        # W=self.model.layer3[1].bn1.weight.data
        # self.model.layer3[1].bn1.weight.data=self.quant(W)
        # self.model.layer3[1].bn1.weight.data=self.layer3_bn1_init_weight + self.quant(W-self.layer3_1_bn1_init_weight)
        # B=self.model.layer3[1].bn1.bias.data
        # self.model.layer3[1].bn1.bias.data=self.quant(B)
        # self.model.layer3[1].bn1.bias.data=self.layer3_1_bn1_init_bias + self.quant(B-self.layer3_1_bn1_init_bias)
        x=self.model.layer3[1].bn1(zero_relu(x31))
        shortcut=x31
        W=self.model.layer3[1].conv1.weight.data
        # self.model.layer3[1].conv1.weight.data=self.quant(W)
        self.model.layer3[1].conv1.weight.data=self.layer3_1_conv1_init_weight + self.quant(W-self.layer3_1_conv1_init_weight )
        x=self.model.layer3[1].conv1(x)
        x=x+shortcut
        identity=x
        # W=self.model.layer3[1].bn2.weight.data
        # self.model.layer3[1].bn2.weight.data=self.quant(W)
        # self.model.layer3[1].bn2.weight.data=self.layer3_1_bn2_init_weight + self.quant(W-self.layer3_1_bn2_init_weight)
        # B=self.model.layer3[1].bn2.bias.data
        # self.model.layer3[1].bn2.bias.data=self.quant(B)
        # self.model.layer3[1].bn2.bias.data=self.layer3_1_bn2_init_bias + self.quant(W-self.layer3_1_bn2_init_bias)
        x=self.model.layer3[1].bn2(zero_relu(x))
        W=self.model.layer3[1].conv2.weight.data
        # self.model.layer3[1].conv2.weight.data=self.quant(W)
        self.model.layer3[1].conv2.weight.data=self.layer3_1_conv2_init_weight + self.quant(W-self.layer3_1_conv2_init_weight )
        x=self.model.layer3[1].conv2(x)
        x40=x+identity

        #layer 4
        # W=self.model.layer4[0].bn1.weight.data
        # self.model.layer4[0].bn1.weight.data=self.quant(W)
        # self.model.layer4[0].bn1.weight.data=self.layer4_0_bn1_init_weight + self.quant(W-self.layer4_0_bn1_init_weight)
        # B=self.model.layer4[0].bn1.bias.data
        # self.model.layer4[0].bn1.bias.data=self.quant(B)
        # self.model.layer4[0].bn1.bias.data=self.layer4_0_bn1_init_bias + self.quant(B-self.layer4_0_bn1_init_bias)
        x=self.model.layer4[0].bn1(zero_relu(x40))
        W=self.model.layer4[0].shortcut[0].weight.data
        # self.model.layer4[0].shortcut[0].weight.data=self.quant(W)
        self.model.layer4[0].shortcut[0].weight.data=self.layer4_0_shortcut_0_init_weight +  self.quant(W-self.layer4_0_shortcut_0_init_weight)
        shortcut=self.model.layer4[0].shortcut[0](x)
        shortcut=self.model.layer4[0].shortcut[1](shortcut)
        W=self.model.layer4[0].conv1.weight.data
        # self.model.layer4[0].conv1.weight.data=self.quant(W)
        self.model.layer4[0].conv1.weight.data=self.layer4_0_conv1_init_weight + self.quant(W-self.layer4_0_conv1_init_weight )
        x=self.model.layer4[0].conv1(x)
        x=x+shortcut
        identity=x
        # W=self.model.layer4[0].bn2.weight.data
        # self.model.layer4[0].bn2.weight.data=self.quant(W)
        # self.model.layer4[0].bn2.weight.data=self.layer4_0_bn2_init_weight + self.quant(W-self.layer4_0_bn2_init_weight)
        # B=self.model.layer4[0].bn2.bias.data
        # self.model.layer4[0].bn2.bias.data=self.quant(B)
        # self.model.layer4[0].bn2.bias.data=self.layer4_0_bn2_init_bias + self.quant(W-self.layer4_0_bn2_init_bias)
        x=self.model.layer4[0].bn2(zero_relu(x))
        W=self.model.layer4[0].conv2.weight.data
        # self.model.layer4[0].conv2.weight.data=self.quant(W)
        self.model.layer4[0].conv2.weight.data=self.layer4_0_conv2_init_weight + self.quant(W-self.layer4_0_conv2_init_weight )
        x=self.model.layer4[0].conv2(x)
        x41=x+identity

        # W=self.model.layer4[1].bn1.weight.data
        # self.model.layer4[1].bn1.weight.data=self.quant(W)
        # self.model.layer4[1].bn1.weight.data=self.layer4_bn1_init_weight + self.quant(W-self.layer4_1_bn1_init_weight)
        # B=self.model.layer4[1].bn1.bias.data
        # self.model.layer4[1].bn1.bias.data=self.quant(B)
        # self.model.layer4[1].bn1.bias.data=self.layer4_1_bn1_init_bias + self.quant(B-self.layer4_1_bn1_init_bias)
        x=self.model.layer4[1].bn1(zero_relu(x41))
        shortcut=x41
        W=self.model.layer4[1].conv1.weight.data
        # self.model.layer4[1].conv1.weight.data=self.quant(W)
        self.model.layer4[1].conv1.weight.data=self.layer4_1_conv1_init_weight + self.quant(W-self.layer4_1_conv1_init_weight )
        x=self.model.layer4[1].conv1(x)
        x=x+shortcut
        identity=x
        # W=self.model.layer4[1].bn2.weight.data
        # self.model.layer4[1].bn2.weight.data=self.quant(W)
        # self.model.layer4[1].bn2.weight.data=self.layer4_1_bn2_init_weight + self.quant(W-self.layer4_1_bn2_init_weight)
        # B=self.model.layer4[1].bn2.bias.data
        # self.model.layer4[1].bn2.bias.data=self.quant(B)
        # self.model.layer4[1].bn2.bias.data=self.layer4_1_bn2_init_bias + self.quant(W-self.layer4_1_bn2_init_bias)
        x=self.model.layer4[1].bn2(zero_relu(x))
        W=self.model.layer4[1].conv2.weight.data
        # self.model.layer4[1].conv2.weight.data=self.quant(W)
        self.model.layer4[1].conv2.weight.data=self.layer4_1_conv2_init_weight + self.quant(W-self.layer4_1_conv2_init_weight )
        x=self.model.layer4[1].conv2(x)
        x=x+identity


        x=zero_relu(x)
        # W=self.model.bn2.weight.data
        # self.model.bn2.weight.data=self.quant(W)
        # self.model.bn2.weight.data=self.bn2_init_weight+self.quant(W-self.bn2_init_weight)
        # B=self.model.bn2.bias.data
        # self.model.bn2.bias.data=self.quant(B)
        # self.model.bn2.bias.data=self.bn2_init_bias+self.quant(B-self.bn2_init_bias)
        x=self.model.bn2(x)

        # x=self.quant(x)
        x=F.avg_pool2d(x,4)
        # x=self.quant(x)
        x=x.view(x.size(0),-1)
        W=self.model.linear.weight.data
        # self.model.linear.weight.data=self.quant(W)
        self.model.linear.weight.data=self.linear_init_weight+self.quant(W-self.linear_init_weight)
        B=self.model.linear.bias.data
        # self.model.linear.bias.data=self.quant(B) 
        self.model.linear.bias.data=self.linear_init_bias+self.quant(B-self.linear_init_bias)
        x=self.model.linear(x)
        return x


# **Building quantized resnet model**

In [25]:
import copy
net_quant=copy.deepcopy(identity_resnet18())

sub identity init a conv1x1
sub identity init a conv1x1
sub identity init a conv1x1


In [26]:
#Quantization aware traing of resnet18
best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch
print('==> Building quantized model..')

quant_resnet=PreAct(model=net_quant)
quant_resnet=quant_resnet.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(quant_resnet.parameters(), lr=0.1,
                      momentum=0.9, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)


for epoch in range(start_epoch, start_epoch+100):
    train(epoch,quant_resnet)
    test(epoch,quant_resnet)
    scheduler.step()

==> Building quantized model..
tensor([[[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         ...,

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         ...,

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],

In [51]:
torch.cuda.current_device()

0

In [16]:
torch.cuda.device_count()

5

In [18]:
torch.cuda.get_device_name()

'NVIDIA A100-SXM4-80GB'

In [None]:
from prettytable import PrettyTable
def count_parameters(model):
  table=PrettyTable(['Modules','Parameters'])
  total_params=0
  for name,parameter in model.named_parameters():
    if not parameter.requires_grad: continue
    params=parameter.numel()
    table.add_row([name,params])
    total_params+=params
  print(table)
  print(f'Total Trainable Params:{total_params}')
  return total_params


In [None]:
count_parameters(net_quant)

+----------------------------+------------+
|          Modules           | Parameters |
+----------------------------+------------+
|        conv1.weight        |    1728    |
|         bn1.weight         |     64     |
|          bn1.bias          |     64     |
|         bn2.weight         |    512     |
|          bn2.bias          |    512     |
|    layer1.0.bn1.weight     |     64     |
|     layer1.0.bn1.bias      |     64     |
|   layer1.0.conv1.weight    |   36864    |
|    layer1.0.bn2.weight     |     64     |
|     layer1.0.bn2.bias      |     64     |
|   layer1.0.conv2.weight    |   36864    |
|    layer1.1.bn1.weight     |     64     |
|     layer1.1.bn1.bias      |     64     |
|   layer1.1.conv1.weight    |   36864    |
|    layer1.1.bn2.weight     |     64     |
|     layer1.1.bn2.bias      |     64     |
|   layer1.1.conv2.weight    |   36864    |
|    layer2.0.bn1.weight     |     64     |
|     layer2.0.bn1.bias      |     64     |
|   layer2.0.conv1.weight    |  

11172298

In [None]:
!python -m pip install torchsummary 

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
from torchsummary import summary
# model=torchvision.models.net_quant.cuda()
device ='cuda' if torch.cuda.is_available() else 'cpu'
# net_quant=net_quant.to(device)
summary(net_quant,input_size=(3,32,32),batch_size=-1)


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 32, 32]           1,728
ChannelPaddingSkip-2           [-1, 64, 32, 32]               0
Hadamard_Transform-3           [-1, 64, 32, 32]               0
       BatchNorm2d-4           [-1, 64, 32, 32]             128
       BatchNorm2d-5           [-1, 64, 32, 32]             128
            Conv2d-6           [-1, 64, 32, 32]          36,864
       BatchNorm2d-7           [-1, 64, 32, 32]             128
            Conv2d-8           [-1, 64, 32, 32]          36,864
       PreActBlock-9           [-1, 64, 32, 32]               0
      BatchNorm2d-10           [-1, 64, 32, 32]             128
           Conv2d-11           [-1, 64, 32, 32]          36,864
      BatchNorm2d-12           [-1, 64, 32, 32]             128
           Conv2d-13           [-1, 64, 32, 32]          36,864
      PreActBlock-14           [-1, 64,

In [None]:
!python -m pip install torchinfo

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0


In [None]:
import torchinfo
# model=torchvision.models.resnet18().cuda()
torchinfo.summary(net_quant,(3,32,32),batch_dim=0,col_names=('input_size','output_size','num_params','kernel_size','mult_adds'),verbose=0)

Layer (type:depth-idx)                        Input Shape               Output Shape              Param #                   Kernel Shape              Mult-Adds
PreActResNet                                  [1, 3, 32, 32]            [1, 10]                   --                        --                        --
├─Conv2d: 1-1                                 [1, 3, 32, 32]            [1, 64, 32, 32]           1,728                     [3, 3]                    1,769,472
├─Sequential: 1-2                             [1, 3, 32, 32]            [1, 64, 32, 32]           --                        --                        --
│    └─ChannelPaddingSkip: 2-1                [1, 3, 32, 32]            [1, 64, 32, 32]           --                        --                        --
│    └─Hadamard_Transform: 2-2                [1, 64, 32, 32]           [1, 64, 32, 32]           (4,096)                   --                        --
├─BatchNorm2d: 1-3                            [1, 64, 32, 32]       

In [None]:
num_eval_batches = 1000

train_batch_size = 30
eval_batch_size = 50

saved_model_dir = '/content/drive/MyDrive/Extra/'
scripted_float_model_file = 'quant_resnet18.pth'
print("Size of baseline model")
print_size_of_model(quant_resnet)

top1, top5 = evaluate(quant_resnet, criterion, testloader, neval_batches=num_eval_batches)
print('Evaluation accuracy on %d images, %2.2f'%(num_eval_batches * eval_batch_size, top1.avg))
torch.jit.save(torch.jit.script(quant_resnet), saved_model_dir + scripted_float_model_file)

# **Toy expeiment**

In [None]:
# algorithm for learning mantissa and maximum range c
def sgd(params, lr):
    for param in params:  
        param.data -= lr * param.grad.data
        param.grad = None
def mse_loss(predictions, targets):
    return ((predictions - targets) ** 2).mean()

In [None]:
# Input=torch.randn(10**5,1)
Input=torch.zeros(10**5,1)

class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.c = torch.nn.Parameter(torch.tensor([250.0]))
        self.m = torch.nn.Parameter(torch.tensor([3.0]))
        
    def forward(self, x):
        sign=torch.sign(x)
        e=7-self.m
        # b= 2**(e-1) #bias
        b=2**e-torch.log2(self.c)+torch.log2(2-2**(-self.m))-1 #bias
        # self.c=(2-2**(-self.m))*2**(2**e-b-1)  #maximum representable range
        p= torch.round(torch.log2(abs(x)))-self.m 
        p=torch.clamp(p,1-b-self.m,p.max())
        s=2**p
        output = torch.clamp(sign*s*torch.round(torch.abs(x)/s),-self.c,self.c)
        return output

        
model = MyModel()
loss_fn = mse_loss
optimizer = lambda params, lr: sgd(params, lr=lr)

lr = 0.1
for epoch in range(500):
    # Compute the predictions and loss
    y_pred = model(Input)
    loss = loss_fn(y_pred, Input)
    
    # Compute the gradients and update the model parameters
    loss.backward()
    optimizer(model.parameters(), lr=lr)
    
    # Print the loss every 10 epochs
    if epoch % 10 == 0:
        print('Epoch %d: loss=%.4f' % (epoch, loss.item()))
        # print(model.parameters())
        for  param in model.parameters():
          if param.requires_grad:
             print( param.data)


Epoch 0: loss=nan
tensor([250.])
tensor([nan])
Epoch 10: loss=nan
tensor([250.])
tensor([nan])
Epoch 20: loss=nan
tensor([250.])
tensor([nan])
Epoch 30: loss=nan
tensor([250.])
tensor([nan])
Epoch 40: loss=nan
tensor([250.])
tensor([nan])
Epoch 50: loss=nan
tensor([250.])
tensor([nan])
Epoch 60: loss=nan
tensor([250.])
tensor([nan])
Epoch 70: loss=nan
tensor([250.])
tensor([nan])
Epoch 80: loss=nan
tensor([250.])
tensor([nan])
Epoch 90: loss=nan
tensor([250.])
tensor([nan])
Epoch 100: loss=nan
tensor([250.])
tensor([nan])
Epoch 110: loss=nan
tensor([250.])
tensor([nan])
Epoch 120: loss=nan
tensor([250.])
tensor([nan])
Epoch 130: loss=nan
tensor([250.])
tensor([nan])
Epoch 140: loss=nan
tensor([250.])
tensor([nan])
Epoch 150: loss=nan
tensor([250.])
tensor([nan])
Epoch 160: loss=nan
tensor([250.])
tensor([nan])
Epoch 170: loss=nan
tensor([250.])
tensor([nan])
Epoch 180: loss=nan
tensor([250.])
tensor([nan])
Epoch 190: loss=nan
tensor([250.])
tensor([nan])
Epoch 200: loss=nan
tensor([250

# **Pertensor/Per channel quantization**

In [None]:
     
    def forward(input,m,e):
        sign=torch.sign(input) #sign bit
        print(sign)
        b= 2**(e-1)     #bias
        print(b)
        c=(2-2**(-m))*2**(2**e-b-1)  #maximum representable range i.e. dynamic range
        print(c)
        p= torch.round(torch.log2(abs(input)))-m 
        print(p)
        pmax=p.max()
        print(torch.tensor(1-b-m),pmax)
        if pmax<1-b-m:
          pmax=1-b-m
        p=torch.clamp(p,1-b-m,pmax)
        # if p<1-b-m:
        #   p=1-b-m
        print(p)
        s=2**p
        print(s)
        print(np.round(torch.abs(input)/s))
        output = torch.clamp(sign*s*torch.round(torch.abs(input)/s),-c,c)
        print(output)
        return output

In [None]:
x=torch.zeros(10)
# x=torch.tensor(0)
y=forward(x,4,3)

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
4
15.5
tensor([-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf])
tensor(-7) tensor(-inf)
tensor([-7., -7., -7., -7., -7., -7., -7., -7., -7., -7.])
tensor([0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078,
        0.0078])
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])


In [None]:
print(y)

tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan])


In [None]:
print(torch.sign(torch.tensor(5)))

tensor(1)
