#ALQ

# In [None]:
import math
import random
import torch
from torch.optim.optimizer import Optimizer, required


class ALQ_optimizer(Optimizer):

    """Implement ALQ optimizer.

    The optimizer keeps Adam-like first/second moment estimates together with
    a running maximum of the second moment (as in "On the Convergence of Adam
    and Beyond"), both for the weights themselves ("w domain") and for the
    coordinates of their multi-bit representation ("alpha domain").

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining parameter groups.
        lr (float, optional): learning rate (default: 1e-3).
        betas (Tuple[float, float], optional): coefficients used for computing running averages of gradient and its square (default: (0.9, 0.999)).
        eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8).
        weight_decay (float, optional): weight decay (L2 regularization) (default: 0).

    Reference:
        Adam optimizer by Pytorch:
        https://pytorch.org/docs/stable/_modules/torch/optim/adam.html#Adam
        On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    # Optimization modes accepted by ``step``.
    _MODES = ('coordinate', 'basis', 'ste')

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.):
        # Check the validity
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super(ALQ_optimizer, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(ALQ_optimizer, self).__setstate__(state)

    @staticmethod
    def _init_state(state, p, p_bin):
        """Create the zero-initialised moment buffers in both alpha and w domain."""
        state['step_alpha'] = 0
        state['exp_avg_alpha'] = torch.zeros_like(p_bin.alpha)
        state['exp_avg_sq_alpha'] = torch.zeros_like(p_bin.alpha)
        state['max_exp_avg_sq_alpha'] = torch.zeros_like(p_bin.alpha)

        state['step'] = 0
        state['exp_avg'] = torch.zeros_like(p.data)
        state['exp_avg_sq'] = torch.zeros_like(p.data)
        state['max_exp_avg_sq'] = torch.zeros_like(p.data)

    @staticmethod
    def _update_moments(exp_avg, exp_avg_sq, max_exp_avg_sq, grad, beta1, beta2):
        """Update the moment buffers in place with a new gradient."""
        # Decay the first and second moment running average coefficient.
        # The keyword forms replace the deprecated add_(scalar, tensor) and
        # addcmul_(scalar, t1, t2) overloads.
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        # Maintain the maximum of all second moment running avg. till now
        torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)

    @staticmethod
    def _pseudo_newton_terms(exp_avg, max_exp_avg_sq, lr, eps, beta1, beta2, step):
        """Return the bias-corrected pseudo gradient and pseudo diagonal Hessian."""
        # Use the max. for normalizing running avg. of gradient
        denom = max_exp_avg_sq.sqrt().add_(eps)
        bias_correction1 = 1 - beta1 ** step
        bias_correction2 = 1 - beta2 ** step
        pseudo_grad = (lr / bias_correction1) * exp_avg
        pseudo_hessian = denom.div(math.sqrt(bias_correction2))
        return pseudo_grad, pseudo_hessian

    @staticmethod
    def _write_weights(p, new_w):
        """Overwrite the data of parameter ``p`` in place with ``new_w``."""
        tmp_p = p.detach()
        tmp_p.zero_().add_(new_w)

    def step(self, params_bin, mode, pruning_rate=None, closure=None):
        """Perform a single optimization (or pruning) step.

        Arguments:
            params_bin (sequence): the multi-bit layer objects, paired by position
                with the parameters of each parameter group. It is iterated more
                than once on a pruning step, so it must be re-iterable (e.g. a list).
            mode (str): 'coordinate' (optimize the coordinates alpha), 'basis'
                (optimize the binary bases) or 'ste' (optimize the binary bases
                while maintaining full precision weights).
            pruning_rate (indexable, optional): if given, the step ranks alpha's
                instead of updating them. ``pruning_rate[0]`` is the fraction of
                each layer's alpha's shortlisted as pruning candidates
                (stochastically rounded to an integer) and ``pruning_rate[1]`` the
                fraction of all shortlisted candidates that is finally pruned.
                Candidates are only collected in 'coordinate' mode.
            closure (callable, optional): re-evaluates the model and returns the loss.

        Returns:
            The value returned by ``closure``, or None if no closure was given.

        Raises:
            ValueError: if ``mode`` is not one of the supported modes.
        """
        # An unknown mode used to fall through every branch and silently skip the update.
        if mode not in self._MODES:
            raise ValueError("Invalid optimization mode: {}".format(mode))

        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group['betas']
            lr, eps = group['lr'], group['eps']

            # Check if this is a pruning step
            if pruning_rate is not None:
                importance_list = torch.tensor([])

            # NOTE(review): params_bin is zipped from its start for every group;
            # presumably a single parameter group is used — confirm with callers.
            for p_bin, p in zip(params_bin, group['params']):
                if p.grad is None:
                    continue

                # Compute the gradient in both w domain and alpha domain
                grad = p.grad.data
                grad_alpha = p_bin.construct_grad_alpha(grad)
                state = self.state[p]

                # Initialize the state parameters in both w domain and alpha domain
                if len(state) == 0:
                    self._init_state(state, p, p_bin)

                # Update the state parameters in w domain (common to all modes)
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                max_exp_avg_sq = state['max_exp_avg_sq']
                state['step'] += 1
                self._update_moments(exp_avg, exp_avg_sq, max_exp_avg_sq, grad, beta1, beta2)

                if mode == 'coordinate':
                    # Update the state parameters in alpha domain
                    state['step_alpha'] += 1
                    # L2 regularization on coordinates (in alpha domain)
                    if group['weight_decay'] != 0:
                        grad_alpha = grad_alpha.add(p_bin.alpha, alpha=group['weight_decay'])
                    self._update_moments(state['exp_avg_alpha'], state['exp_avg_sq_alpha'],
                                         state['max_exp_avg_sq_alpha'], grad_alpha, beta1, beta2)

                    # Compute the pseudo gradient and the pseudo diagonal Hessian
                    pseudo_grad_alpha, pseudo_hessian_alpha = self._pseudo_newton_terms(
                        state['exp_avg_alpha'], state['max_exp_avg_sq_alpha'],
                        lr, eps, beta1, beta2, state['step_alpha'])

                    # Check if this is a pruning step
                    if pruning_rate is not None:
                        # Compute the integer used to determine the number of selected alpha's in this layer
                        # (the fractional part is rounded up with matching probability)
                        float_tmp = p_bin.num_bin_filter.item()*pruning_rate[0]
                        int_tmp = int(float_tmp)
                        if random.random() < float_tmp-int_tmp:
                            int_tmp += 1
                        # Sort the importance of binary filters (alpha's) in this layer and select Top-k% (int_tmp) unimportant ones
                        p_bin_importance_list = p_bin.sort_importance_bin_filter(pseudo_grad_alpha, pseudo_hessian_alpha, int_tmp)
                        importance_list = torch.cat((importance_list, p_bin_importance_list), 0)
                    else:
                        # Take one optimization step on coordinates
                        p_bin.alpha.add_(-pseudo_grad_alpha/pseudo_hessian_alpha)
                        # Reconstruct the weight tensor from the current quantization
                        p_bin.update_w_FP()
                        self._write_weights(p, p_bin.w_FP.data)

                else:
                    # 'basis' and 'ste' share everything except how w_FP and p are refreshed.
                    # Compute the pseudo gradient and the pseudo diagonal Hessian
                    pseudo_grad, pseudo_hessian = self._pseudo_newton_terms(
                        exp_avg, max_exp_avg_sq, lr, eps, beta1, beta2, state['step'])

                    # Take one optimization step on binary bases
                    p_bin.optimize_bin_basis(pseudo_grad, pseudo_hessian)
                    # Speed up with an optimization step on coordinates
                    p_bin.speedup(pseudo_grad, pseudo_hessian)

                    if mode == 'basis':
                        # Reconstruct the weight tensor from the current quantization
                        p_bin.update_w_FP()
                        self._write_weights(p, p_bin.w_FP.data)
                    else:
                        # Update the maintained full precision weights
                        p_bin.update_w_FP(-pseudo_grad/pseudo_hessian)
                        # Reconstruct the weight tensor from the current quantization
                        self._write_weights(p, p_bin.reconstruct_w())

                    # Update the state parameters in alpha domain (approximately)
                    state['step_alpha'] += 1
                    state['exp_avg_alpha'] = p_bin.construct_grad_alpha(exp_avg)
                    state['exp_avg_sq_alpha'] = p_bin.construct_hessian_alpha(exp_avg_sq)
                    state['max_exp_avg_sq_alpha'] = p_bin.construct_hessian_alpha(max_exp_avg_sq)

            # Check if this is a pruning step
            if pruning_rate is not None:
                # Resort the importance of selected binary filters (alpha's) over all layers
                sorted_ind = torch.argsort(importance_list[:, -1])
                # Compute the number of pruned alpha's in this iteration
                # Note that unlike the paper, M_p varies over iterations here, but this does not influence the pruning schedule.
                M_p = int(sorted_ind.nelement()*pruning_rate[1])
                # Determine indexes of alpha's to be pruned
                ind_prune = sorted_ind[:M_p]
                list_prune = importance_list[ind_prune, :]
                # Prune alpha's in each layer and reconstruct the weight tensor.
                # Column 0 of list_prune is matched against the position i in params_bin
                # (presumably equal to each layer's ind_layer — confirm with callers).
                for i, (p_bin, p) in enumerate(zip(params_bin, group['params'])):
                    p_bin.prune_alpha((torch.sort(list_prune[list_prune[:, 0] == i, 1])[0]).to(torch.int64))
                    p_bin.update_w_FP()
                    self._write_weights(p, p_bin.w_FP.data)
        return loss


#BinaryNet


# In [None]:
import torch
#from myoptimizer import ALQ_optimizer


REL_NORM_THRES = 1e-6


def construct_bit_table(bit, device='cuda'):
    """Construct a look-up-table to store bitwise values of all integers given a bitwidth.

    Row ``i`` holds the binary digits of ``i`` (least significant bit first),
    encoded as +1 for a set bit and -1 for a cleared bit.

    Arguments:
        bit (int): the bitwidth; the table has ``2**bit`` rows and ``bit`` columns.
        device (optional): device the table is placed on (default: 'cuda').

    Returns:
        An int8 tensor of shape ``(2**bit, bit)`` on ``device``.
    """
    values = torch.arange(2 ** bit).unsqueeze(1)
    positions = torch.arange(bit).unsqueeze(0)
    # Extract bit j of every integer at once instead of looping in Python.
    bit_is_set = ((values >> positions) & 1).to(torch.int8)
    # Map {0, 1} -> {-1, +1}
    bit_table = bit_is_set * 2 - 1
    return bit_table.to(device)


def binarize(input_t):
    """Return the signs of the input tensor as a vector of +1/-1.

    Strictly negative entries map to -1, all others (including zero) to +1.
    The result is a newly created 1-dim tensor with one entry per input element.
    """
    is_negative = input_t < 0
    signs = torch.full((input_t.nelement(),), 1.0)
    signs[is_negative] = -1
    return signs


def transform_bin_basis(w_vec, max_dim, rel_norm_thres=REL_NORM_THRES):
    """Transform a full precision weight vector into multi-bit form, i.e. binary bases and coordinates.

    Bases are added greedily: each new basis is the sign pattern of the current
    residual, after which all coordinates are refit by least squares.

    Arguments:
        w_vec (tensor): the weight vector to quantize (flattened internally).
            The work buffers are created on the default device, so a CPU tensor
            is expected, which is what the callers in this file pass.
        max_dim (int): the maximum number of binary bases.
        rel_norm_thres (float, optional): stop adding bases once the squared
            residual norm relative to the squared norm of ``w_vec`` drops below
            this value.

    Returns:
        A tuple ``(bin_basis, crd, dim)``: the int8 basis matrix with one column
        per kept basis, the matching non-negative coordinates sorted ascending,
        and the number of kept bases. If no basis is kept (e.g. ``w_vec`` is
        empty or all zeros) the first two entries are empty tensors and ``dim``
        is 0.
    """
    # Reshape the coordinates vector in w domain
    crd_w = w_vec.detach().view(-1, 1)
    # Get the dimensionality in w domain
    dim_w = crd_w.nelement()
    # Determine the max number of dimensionality in alpha domain
    max_dim_alpha = min(dim_w, max_dim)
    ori_L2Norm_square = torch.sum(torch.pow(crd_w, 2))
    # An all-zero vector needs no basis. Without this guard the relative norm is
    # 0/0 = NaN, the stop test never fires, and the repeated identical bases make
    # B^T B singular.
    if ori_L2Norm_square == 0:
        return torch.zeros((dim_w, 0), dtype=torch.int8), torch.zeros(0), 0
    # Initialize binary basis matrix in alpha domain
    bin_basis_alpha = torch.zeros((dim_w, max_dim_alpha))
    # Initialize coordinates vector in alpha domain
    crd_alpha = torch.zeros(max_dim_alpha)
    res = crd_w.detach()
    res_L2Norm_square = torch.sum(torch.pow(res, 2))
    for i in range(max_dim_alpha):
        if res_L2Norm_square/ori_L2Norm_square < rel_norm_thres:
            break
        new_bin_basis = binarize(res.view(-1))
        bin_basis_alpha[:, i] = new_bin_basis
        B_ = bin_basis_alpha[:, :i+1]
        # Find the optimal coordinates in the space spanned by B_
        alpha_ = torch.mm(torch.inverse(torch.mm(torch.t(B_), B_)), torch.mm(torch.t(B_), crd_w))
        # Compute the residual (orthogonal to the space spanned by B_)
        res = crd_w - torch.mm(B_, alpha_)
        crd_alpha[:i+1] = alpha_.view(-1)
        res_L2Norm_square = torch.sum(torch.pow(res, 2))
    # Make every coordinate non-negative by moving its sign into the basis
    ind_neg = crd_alpha < 0
    crd_alpha[ind_neg] = -crd_alpha[ind_neg]
    bin_basis_alpha[:, ind_neg] = -bin_basis_alpha[:, ind_neg]
    # Get the valid indexes
    ind_valid = crd_alpha != 0
    # Get the valid dimensionality in alpha domain
    dim_alpha = torch.sum(ind_valid)
    if dim_alpha == 0:
        # Return empty tensors (not lists) so callers can still transpose/assign them.
        return torch.zeros((dim_w, 0), dtype=torch.int8), torch.zeros(0), 0
    sorted_ind = torch.argsort(crd_alpha[ind_valid])
    return bin_basis_alpha[:, ind_valid][:, sorted_ind].to(torch.int8), crd_alpha[ind_valid][sorted_ind], dim_alpha


class ConvLayer_bin(object):
    """This class defines the multi-bit form of the weight tensor of a convolutional layer used in ALQ.

    The weights are split into groups; each group is represented by up to
    ``max_bit`` binary bases (rows of ``B``) and their coordinates (``alpha``).
    ``B``, ``alpha`` and ``mask`` share the leading layout
    ``(out_channels, groups, max_bit, ...)``.

    Arguments:
        w_ori (float tensor): the 4-dim pretrained weight tensor of a convolutional layer.
        ind_layer (int): the index of this layer in the network.
        structure (string): the structure used for grouping the weights in this layer, optional values: 'kernelwise', 'pixelwise', 'channelwise'.
        max_bit (int): the maximum bitwidth used in initialization.
    """
    def __init__(self, w_ori, ind_layer, structure, max_bit):
        # The layer type
        self.layer_type = 'conv'
        # The shape of the weight tensor of this layer
        self.tensor_shape = w_ori.size()
        # The maintained full precision weight tensor of this layer used in STE
        self.w_FP = w_ori.clone().to('cuda')
        # The index of this layer in the network
        self.ind_layer = ind_layer
        # The structure used for grouping the weights in this layer
        self.structure = structure
        # The maximum bitwidth used in initialization
        self.max_bit = max_bit
        # The binary bases, the coordinates, and the mask (only for parallel computing purposes) of each group
        self.B, self.alpha, self.mask = self.structured_sketch()
        # The total number of binary filters in this layer, namely the total number of (valid) alpha's,
        # and the average bitwidth of this layer
        self._refresh_bit_stats()
        # The total number of weights of this layer
        self.num_weight = self.w_FP.nelement()
        # The used look-up-table for bitwise values
        self.bit_table = construct_bit_table(self.max_bit)

    def _refresh_bit_stats(self):
        """Recompute the number of valid alpha's and the average bitwidth from the mask."""
        self.num_bin_filter = torch.sum(self.mask)
        self.avg_bit = self.num_bin_filter.float()/(self.mask.size(0)*self.mask.size(1))

    def _group_weights(self, tensor_w):
        """Rearrange a tensor shaped like the layer weights into (out_channels, groups, weights_per_group).

        Raises:
            ValueError: if ``self.structure`` is not a supported grouping.
        """
        n_out, n_in = self.tensor_shape[0], self.tensor_shape[1]
        if self.structure == 'kernelwise':
            # One group per (output channel, input channel) kernel
            return tensor_w.reshape((n_out, n_in, -1))
        if self.structure == 'pixelwise':
            # One group per (output channel, kernel position), spanning the input channels
            return torch.transpose(tensor_w.reshape((n_out, n_in, -1)), 1, 2)
        if self.structure == 'channelwise':
            # One group per output channel
            return tensor_w.reshape((n_out, 1, -1))
        raise ValueError("Invalid structure: {}".format(self.structure))

    def structured_sketch(self):
        """Initialize the weight tensor using structured sketching.
        Namely, structure the weights in groupwise, and quantize each group's weights in multi-bit form w.r.t. the reconstruction error.
        Return the binary bases, the coordinates, and the mask (only for parallel computing purposes) of each group.
        """
        grouped = self._group_weights(self.w_FP.to('cpu'))
        n_out, num_group, group_size = grouped.size()
        B = torch.zeros((n_out, num_group, self.max_bit, group_size)).to(torch.int8)
        alpha = torch.zeros((n_out, num_group, self.max_bit, 1)).to(torch.float32)
        mask = torch.zeros((n_out, num_group, self.max_bit, 1)).to(torch.bool)
        for k in range(n_out):
            for g in range(num_group):
                bin_basis, crd, dim = transform_bin_basis(grouped[k, g].view(-1), self.max_bit)
                if dim == 0:
                    # Nothing to store (e.g. an all-zero group): keep the zero initialisation.
                    continue
                mask[k, g, :dim, 0] = 1
                B[k, g, :dim, :] = torch.t(bin_basis)
                alpha[k, g, :dim, 0] = crd
        return B.to('cuda'), alpha.to('cuda'), mask.to('cuda')

    def reconstruct_w(self):
        """Reconstruct the weight tensor from the current quantization.
        Return the reconstructed weight tensor of this layer, i.e. w_hat.
        """
        w_bin = torch.sum(self.B.float()*(self.alpha*self.mask.float()), dim=2)
        if self.structure == 'pixelwise':
            return torch.transpose(w_bin, 1, 2).reshape(self.tensor_shape)
        if self.structure in ('kernelwise', 'channelwise'):
            return w_bin.reshape(self.tensor_shape)

    def update_w_FP(self, w_FP_new=None):
        """Update the full precision weight tensor.
        In STE with loss-aware optimization, w_FP is the maintained full precision weight tensor
        and ``w_FP_new`` is the increment added to it.
        In ALQ optimization, w_FP is used to store the reconstructed weight tensor from the current quantization.
        """
        if w_FP_new is not None:
            self.w_FP.add_(w_FP_new)
        else:
            self.w_FP.zero_().add_(self.reconstruct_w())

    def construct_grad_alpha(self, grad_w):
        """Compute and return the gradient (or the first momentum) in alpha domain w.r.t the loss.
        """
        grad_columns = self._group_weights(grad_w).unsqueeze(-1)
        return torch.matmul(self.B.float(), grad_columns)*self.mask.float()

    def construct_hessian_alpha(self, diag_hessian_w):
        """Compute and return the diagonal Hessian (or the second momentum) in alpha domain w.r.t the loss.
        """
        B_float = self.B.float()
        hessian_rows = self._group_weights(diag_hessian_w).unsqueeze(-2)
        diag_hessian = torch.matmul(B_float*hessian_rows, torch.transpose(B_float, 2, 3))
        return torch.diagonal(diag_hessian, dim1=-2, dim2=-1).unsqueeze(-1)*self.mask.float()

    def sort_importance_bin_filter(self, grad_alpha, diag_hessian_alpha, num_top):
        """Compute and sort the importance of binary filters (alpha's) in this layer.
        The importance is defined by the modeled loss increment caused by pruning each individual alpha.
        Return the selected num_top alpha's with the least importance as a CPU float tensor
        with one row ``[ind_layer, index, importance]`` per alpha, where ``index`` counts
        only the valid (unmasked) alpha's, matching what ``prune_alpha`` expects.
        """
        delta_loss_prune = -grad_alpha*self.alpha+0.5*torch.pow(self.alpha, 2)*diag_hessian_alpha
        valid_loss = delta_loss_prune[self.mask].view(-1)
        sorted_ind = torch.argsort(valid_loss)
        top_ind = sorted_ind[:num_top]
        # The indexes refer to the masked subset, so the values must be read from
        # the same subset (reading them from the unmasked tensor picks wrong entries).
        top_loss = valid_loss[top_ind]
        top_importance_list = torch.stack(
            (torch.full_like(top_loss, float(self.ind_layer)), top_ind.to(top_loss.dtype), top_loss),
            dim=1)
        return top_importance_list.float().cpu()

    def prune_alpha(self, ind_prune):
        """Prune the corresponding alpha's of this layer given the indexes.
        The indexes count only the currently valid (unmasked) alpha's.
        Return True if exactly ``ind_prune.size(0)`` alpha's were removed, False otherwise.
        """
        num_bin_filter_ = torch.sum(self.mask)
        self.mask.view(-1)[self.mask.view(-1).nonzero().view(-1)[ind_prune]] = 0
        self.B *= self.mask.char()
        self.alpha *= self.mask.float()
        self._refresh_bit_stats()
        if num_bin_filter_-self.num_bin_filter != ind_prune.size(0):
            print('wrong pruning')
            return False
        return True

    def optimize_bin_basis(self, pseudo_grad, pseudo_hessian):
        """Take one optimization step on the binary bases of this layer while fixing coordinates.
        """
        # Compute the target weight tensor, i.e. the optimal point in w domain according to the quadratic model function
        target_w = self.w_FP-pseudo_grad/pseudo_hessian
        table = self.bit_table
        # All weight values representable with the current coordinates, one per table row
        all_disc_w = torch.matmul(table.view((1, 1, table.size(0), table.size(1))).float(), self.alpha)
        target_rows = self._group_weights(target_w).unsqueeze(-2)
        # For every weight pick the table row whose value is closest to the target
        ind_opt = torch.argmin(torch.abs(target_rows - all_disc_w), dim=2)
        n_out, num_group, group_size = ind_opt.size()
        self.B = torch.transpose(table[ind_opt.view(-1), :].view(n_out, num_group, group_size, self.max_bit), 2, 3)
        self.B *= self.mask.char()
        return True

    def speedup(self, pseudo_grad, pseudo_hessian):
        """Speed up the optimization on binary bases, i.e. take a following optimization step on coordinates while fixing binary bases.
        """
        revised_grad_w = -pseudo_hessian*self.w_FP+pseudo_grad
        B_float = self.B.float()
        hessian_rows = self._group_weights(pseudo_hessian).unsqueeze(-2)
        revised_hessian = torch.matmul(B_float*hessian_rows, torch.transpose(B_float, 2, 3))
        # Add 1e-6 on the diagonal of valid alpha's and 1+1e-6 on that of masked ones
        # before inverting.
        revised_hessian += torch.diag_embed(1+1e-6-(self.mask.float().squeeze(-1)))
        revised_grad = torch.matmul(B_float, self._group_weights(revised_grad_w).unsqueeze(-1))
        self.alpha = -torch.matmul(torch.inverse(revised_hessian), revised_grad)
        self.alpha *= self.mask.float()
        # Keep the coordinates non-negative by moving their sign into the bases.
        ind_neg = self.alpha < 0
        self.alpha[ind_neg] *= -1
        # B is a non-contiguous view after optimize_bin_basis, so flipping
        # B.contiguous() in place would only modify a temporary copy.
        self.B = torch.where(ind_neg, -self.B, self.B)
        self._refresh_bit_stats()
        return True
 
    
class FCLayer_bin(object):
    """This class defines the multi-bit form of the weight tensor of a fully connected layer used in ALQ.

    Every output channel (row) of the weight matrix is split into num_subchannel groups of equal size.
    Each group is represented by up to max_bit binary bases (rows of B) weighted by coordinates (alpha);
    mask marks which (basis, coordinate) pairs of a group are valid.

    Arguments:
        w_ori (float tensor): the 2-dim pretrained weight tensor of a fully connected layer.
        ind_layer (int): the index of this layer in the network.
        structure (string): the structure used for grouping the weights in this layer, optional values: 'subchannelwise'.
        num_subchannel (int): the number of groups (subchannels) each output channel is split into; it must divide w_ori.size(1) evenly.
        max_bit (int): the maximum bitwidth used in initialization.
    Raises:
        ValueError: if w_ori.size(1) is not divisible by num_subchannel.
    """
    def __init__(self, w_ori, ind_layer, structure, num_subchannel, max_bit):
        # The layer type
        self.layer_type = 'fc'
        # The shape of the weight tensor of this layer
        self.tensor_shape = w_ori.size()
        # Fail early with a clear message: an uneven split would otherwise only surface
        # later as an out-of-range group index inside structured_sketch().
        if self.tensor_shape[1] % num_subchannel != 0:
            raise ValueError("The number of input features ({}) must be divisible by num_subchannel ({})".format(self.tensor_shape[1], num_subchannel))
        # The maintained full precision weight tensor of this layer used in STE
        self.w_FP = w_ori.clone().to('cuda')
        # The index of this layer in the network
        self.ind_layer = ind_layer
        # The structure used for grouping the weights in this layer
        self.structure = structure
        # The maximum bitwidth used in initialization
        self.max_bit = max_bit
        # The number of groups in each channel, i.e. the number of subchannels
        self.num_subchannel = num_subchannel
        # The number of weights in each subchannel
        self.num_w_subc = self.tensor_shape[1] // self.num_subchannel
        # The binary bases, the coordinates, and the mask (only for parallel computing purposes) of each group
        self.B, self.alpha, self.mask = self.structured_sketch()
        # The total number of binary filters in this layer, namely the total number of (valid) alpha's
        self.num_bin_filter = torch.sum(self.mask)
        # The average bitwidth of this layer
        self.avg_bit = self.num_bin_filter.float()/(self.mask.size(0)*self.mask.size(1))
        # The total number of weights of this layer
        self.num_weight = self.w_FP.nelement()
        # The used look-up-table for bitwise values
        self.bit_table = construct_bit_table(self.max_bit)

    def structured_sketch(self):
        """Initialize the weight tensor using structured sketching.
        Namely, structure the weights in groupwise, and quantize each group's weights in multi-bit form w.r.t. the reconstruction error.
        Return the binary bases, the coordinates, and the mask (only for parallel computing purposes) of each group.

        Shapes of the returned tensors (all moved to 'cuda'):
            B:     (out_features, num_subchannel, max_bit, num_w_subc), int8
            alpha: (out_features, num_subchannel, max_bit, 1), float32
            mask:  (out_features, num_subchannel, max_bit, 1), bool
        """
        # The sketching itself is done group by group on the CPU.
        w_cpu = self.w_FP.to('cpu')
        B = torch.zeros((self.tensor_shape[0],self.num_subchannel,self.max_bit,self.num_w_subc)).to(torch.int8)
        alpha = torch.zeros((self.tensor_shape[0],self.num_subchannel,self.max_bit,1)).to(torch.float32)
        mask =  torch.zeros((self.tensor_shape[0],self.num_subchannel,self.max_bit,1)).to(torch.bool)
        for k in range(self.tensor_shape[0]):
            for (q,i) in enumerate(range(0,self.tensor_shape[1],self.num_w_subc)):
                # dim is the number of bases actually used for this group; the remaining
                # slots up to max_bit stay zero and masked out.
                # NOTE(review): bin_basis is presumably (num_w_subc, dim) since it is transposed below - confirm against transform_bin_basis.
                bin_basis, crd, dim = transform_bin_basis(w_cpu[k,i:i+self.num_w_subc].view(-1), self.max_bit)
                mask[k,q,:dim,0] = 1
                B[k,q,:dim,:] = torch.t(bin_basis)
                alpha[k,q,:dim,0] = crd
        return B.to('cuda'), alpha.to('cuda'), mask.to('cuda')

    def reconstruct_w(self):
        """Reconstruct the weight tensor from the current quantization.
        Return the reconstructed weight tensor of this layer (w_hat), reshaped to the original tensor shape.
        """
        # Sum alpha*B over the basis dimension; masked-out coordinates contribute nothing.
        w_bin = torch.sum(self.B.float()*(self.alpha*self.mask.float()),dim=2)
        return w_bin.reshape(self.tensor_shape)

    def update_w_FP(self, w_FP_new=None):
        """Update the full precision weight tensor.
        In STE with loss-aware optimization, w_FP is the maintained full precision weight tensor.
        In ALQ optimization, w_FP is used to store the reconstructed weight tensor from the current quantization.

        If w_FP_new is given it is ADDED in place to w_FP; otherwise w_FP is overwritten
        in place with the current reconstruction.
        """
        if w_FP_new is not None:
            self.w_FP.add_(w_FP_new)
        else:
            self.w_FP.zero_().add_(self.reconstruct_w())

    def construct_grad_alpha(self, grad_w):
        """Compute and return the gradient (or the first momentum) in alpha domain w.r.t the loss.
        The result has the same shape as alpha and is zero for masked-out coordinates.
        """
        return torch.matmul(self.B.float(), grad_w.reshape((self.tensor_shape[0],self.num_subchannel,self.num_w_subc,1)))*self.mask.float()

    def construct_hessian_alpha(self, diag_hessian_w):
        """Compute and return the diagonal Hessian (or the second momentum) in alpha domain w.r.t the loss.
        The result has the same shape as alpha and is zero for masked-out coordinates.
        """
        # B * diag(h) * B^T per group; only its diagonal is kept.
        diag_hessian_alpha = torch.matmul(self.B.float()*diag_hessian_w.reshape((self.tensor_shape[0],self.num_subchannel,1,self.num_w_subc)), torch.transpose(self.B,2,3).float())
        return torch.diagonal(diag_hessian_alpha,dim1=-2,dim2=-1).unsqueeze(-1)*self.mask.float()

    def sort_importance_bin_filter(self, grad_alpha, diag_hessian_alpha, num_top):
        """Compute and sort the importance of binary filters (alpha's) in this layer.
        The importance is defined by the modeled loss increment caused by pruning each individual alpha.
        Return the selected num_top alpha's with the least importance.

        Each returned row is [ind_layer, index, importance], where index counts only the
        currently valid (unmasked) alpha's, i.e. the indexing that prune_alpha() expects.
        """
        delta_loss_prune = -grad_alpha*self.alpha+0.5*torch.pow(self.alpha,2)*diag_hessian_alpha
        # Restrict to valid alpha's first. The sorted indices refer to this masked vector,
        # so the importance values must be read from it as well (not from the full tensor).
        valid_delta_loss = delta_loss_prune[self.mask].view(-1)
        sorted_ind = torch.argsort(valid_delta_loss)
        top_importance_list = torch.tensor([[self.ind_layer, sorted_ind[i], valid_delta_loss[sorted_ind[i]]] for i in range(num_top)])
        return top_importance_list

    def prune_alpha(self, ind_prune):
        """Prune the corresponding alpha's of this layer given the indexes.
        ind_prune indexes into the currently valid (unmasked) alpha's.
        Return True on success, False if the number of removed alpha's does not match ind_prune.size(0).
        """
        num_bin_filter_ = torch.sum(self.mask)
        # Map indexes among valid entries to flat positions in the mask and clear them.
        self.mask.view(-1)[self.mask.view(-1).nonzero().view(-1)[ind_prune]]=0
        # Zero the pruned bases and coordinates so they no longer contribute.
        self.B *= self.mask.char()
        self.alpha *= self.mask.float()
        self.num_bin_filter = torch.sum(self.mask)
        self.avg_bit = self.num_bin_filter.float()/(self.mask.size(0)*self.mask.size(1))
        if num_bin_filter_-self.num_bin_filter != ind_prune.size(0):
            print('wrong pruning')
            return False
        return True

    def optimize_bin_basis(self, pseudo_grad, pseudo_hessian):
        """Take one optimization step on the binary bases of this layer while fixing coordinates.
        """
        # Compute the target weight tensor, i.e. the optimal point in w domain according to the quadratic model function
        target_w = self.w_FP-pseudo_grad/pseudo_hessian
        # Every value a weight can take: one candidate per row of the bit table, per group.
        all_disc_w = torch.matmul(self.bit_table.view((1,1,self.bit_table.size(0),self.bit_table.size(1))).float(),self.alpha)
        # For each weight pick the candidate closest to its target.
        ind_opt = torch.argmin(torch.abs(target_w.view((self.tensor_shape[0],self.num_subchannel,1,-1)) - all_disc_w), dim=2)
        # Look up the chosen bit patterns and move the bit dimension in front of the weight dimension.
        # Note that the transposed result is generally a non-contiguous tensor.
        self.B = torch.transpose((self.bit_table[ind_opt[:],:]).view(self.tensor_shape[0],self.num_subchannel,self.num_w_subc,self.max_bit), 2,3)
        self.B *= self.mask.char()
        return True

    def speedup(self, pseudo_grad, pseudo_hessian):
        """Speed up the optimization on binary bases, i.e. take a following optimization step on coordinates while fixing binary bases.
        Afterwards all coordinates are made non-negative by moving the sign into the corresponding binary basis.
        """
        revised_grad_w = -pseudo_hessian*self.w_FP+pseudo_grad
        revised_hessian = torch.matmul(self.B.float()*pseudo_hessian.view((self.tensor_shape[0],self.num_subchannel,1,-1)),torch.transpose(self.B,2,3).float())
        # Add 1e-6 on the diagonal of valid entries and 1+1e-6 on masked-out entries
        # (presumably to keep the matrix invertible although pruned bases are all-zero).
        revised_hessian += torch.diag_embed(1+1e-6-(self.mask.float().squeeze(-1)))
        revised_grad = torch.matmul(self.B.float(),revised_grad_w.view((self.tensor_shape[0],self.num_subchannel,-1,1)))
        self.alpha = -torch.matmul(torch.inverse(revised_hessian),revised_grad)
        self.alpha *= self.mask.float()
        ind_neg = self.alpha<0
        self.alpha[ind_neg] *= -1
        # Flip the bases whose coordinate was negative. This is done with a broadcast
        # in-place multiply because B may be non-contiguous (see optimize_bin_basis);
        # going through B.contiguous() would modify a temporary copy and lose the flip.
        self.B *= (1-2*ind_neg.to(self.B.dtype))
        self.num_bin_filter = torch.sum(self.mask)
        self.avg_bit = self.num_bin_filter.float()/(self.mask.size(0)*self.mask.size(1))
        return True

#Training

In [None]:
import torch
#from binarynet import ConvLayer_bin, FCLayer_bin


# The values of k for which top-k accuracy is accumulated and printed by the
# training/evaluation helpers in this section (top-1 and top-5).
TOPK = (1,5)

def accuracy(output, target, correct_sum, topk=(1,)):
    """Accumulate the number of correct top-k predictions for the specified values of k.

    Arguments:
        output (tensor): the scores of one batch, indexed as (sample, class).
        target (tensor): the ground-truth class index of every sample in the batch.
        correct_sum (list): running counters, one per entry of topk; correct_sum[i] is
            increased IN PLACE by the number of samples whose target is among the
            topk[i] highest-scoring classes.
        topk (tuple of int): the values of k to evaluate.

    Returns None; the result is delivered through correct_sum.
    """
    with torch.no_grad():
        maxk = max(topk)
        # Indices of the maxk best classes per sample, sorted from best to worst.
        _, pred = output.topk(maxk, 1, True, True)
        # Transpose so that row r holds the rank-r prediction of every sample.
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        for (i, k) in enumerate(topk):
            # A sample counts as correct for k if any of its first k ranks hits the target.
            correct_sum[i] += correct[:k].reshape(-1).float().sum(0, keepdim=True).item()


def get_accuracy(net, train_loader, loss_func):
    """Evaluate the network on the training data and print the mean batch loss and the top-k accuracies.

    The network is switched to eval mode and no gradients are recorded. Nothing is returned.
    """
    net.eval()
    with torch.no_grad():
        loss_total = 0.
        batch_count = 0
        sample_count = 0
        hits = [0.] * len(TOPK)
        for batch_x, batch_y in train_loader:
            batch_x = batch_x.cuda(non_blocking=True)
            batch_y = batch_y.cuda(non_blocking=True)
            logits = net(batch_x)
            batch_loss = loss_func(logits, batch_y)
            # accuracy() adds this batch's correct counts to hits in place.
            accuracy(logits, batch_y, hits, topk=TOPK)
            sample_count += batch_y.size(0)
            loss_total += batch_loss.item()
            batch_count += 1
        print('training loss: ', loss_total/batch_count)
        print('training accuracy: ', [h/sample_count for h in hits])


def train_fullprecision(net, train_loader, loss_func, optimizer, epoch):
    """Run one training epoch of the original full precision network.

    Prints the mean batch loss and the top-k training accuracies of the epoch; returns nothing.
    """
    net.train()
    loss_total = 0.
    batch_count = 0
    sample_count = 0
    hits = [0.] * len(TOPK)
    for batch_x, batch_y in train_loader:
        batch_x = batch_x.cuda(non_blocking=True)
        batch_y = batch_y.cuda(non_blocking=True)
        optimizer.zero_grad()
        logits = net(batch_x)
        batch_loss = loss_func(logits, batch_y)
        batch_loss.backward()
        optimizer.step()
        # Statistics are taken from the forward pass made before this update step.
        accuracy(logits, batch_y, hits, topk=TOPK)
        sample_count += batch_y.size(0)
        loss_total += batch_loss.item()
        batch_count += 1
    print("epoch: ", epoch, ", training loss: ", loss_total/batch_count)
    print('training accuracy: ', [h/sample_count for h in hits])


def train_coordinate(net, train_loader, loss_func, optimizer_w, optimizer_b, parameters_w_bin, epoch):
    """Run one epoch that trains the coordinates.

    optimizer_b takes a regular step; optimizer_w is stepped in 'coordinate' mode on
    parameters_w_bin. Prints the mean batch loss and the top-k training accuracies.
    """
    net.train()
    loss_total = 0.
    batch_count = 0
    sample_count = 0
    hits = [0.] * len(TOPK)
    for batch_x, batch_y in train_loader:
        batch_x = batch_x.cuda(non_blocking=True)
        batch_y = batch_y.cuda(non_blocking=True)
        for opt in (optimizer_w, optimizer_b):
            opt.zero_grad()
        logits = net(batch_x)
        batch_loss = loss_func(logits, batch_y)
        batch_loss.backward()
        optimizer_b.step()
        optimizer_w.step(parameters_w_bin, 'coordinate')
        accuracy(logits, batch_y, hits, topk=TOPK)
        sample_count += batch_y.size(0)
        loss_total += batch_loss.item()
        batch_count += 1
    print("epoch: ", epoch, ", training loss: ", loss_total/batch_count)
    print('training accuracy: ', [h/sample_count for h in hits])
 
  
def train_basis(net, train_loader, loss_func, optimizer_w, optimizer_b, parameters_w_bin, epoch):
    """Run one epoch that trains the binary bases (with speedup).

    optimizer_b takes a regular step; optimizer_w is stepped in 'basis' mode on
    parameters_w_bin. Prints the mean batch loss and the top-k training accuracies.
    """
    net.train()
    loss_total = 0.
    batch_count = 0
    sample_count = 0
    hits = [0.] * len(TOPK)
    for batch_x, batch_y in train_loader:
        batch_x = batch_x.cuda(non_blocking=True)
        batch_y = batch_y.cuda(non_blocking=True)
        for opt in (optimizer_w, optimizer_b):
            opt.zero_grad()
        logits = net(batch_x)
        batch_loss = loss_func(logits, batch_y)
        batch_loss.backward()
        optimizer_b.step()
        optimizer_w.step(parameters_w_bin, 'basis')
        accuracy(logits, batch_y, hits, topk=TOPK)
        sample_count += batch_y.size(0)
        loss_total += batch_loss.item()
        batch_count += 1
    print("epoch: ", epoch, ", training loss: ", loss_total/batch_count)
    print('training accuracy: ', [h/sample_count for h in hits])


def train_basis_STE(net, train_loader, loss_func, optimizer_w, optimizer_b, parameters_w_bin, epoch):
    """Run one epoch that trains the binary bases (with speedup) by STE.

    optimizer_b takes a regular step; optimizer_w is stepped in 'ste' mode on
    parameters_w_bin. Prints the mean batch loss and the top-k training accuracies.
    """
    net.train()
    loss_total = 0.
    batch_count = 0
    sample_count = 0
    hits = [0.] * len(TOPK)
    for batch_x, batch_y in train_loader:
        batch_x = batch_x.cuda(non_blocking=True)
        batch_y = batch_y.cuda(non_blocking=True)
        for opt in (optimizer_w, optimizer_b):
            opt.zero_grad()
        logits = net(batch_x)
        batch_loss = loss_func(logits, batch_y)
        batch_loss.backward()
        optimizer_b.step()
        optimizer_w.step(parameters_w_bin, 'ste')
        accuracy(logits, batch_y, hits, topk=TOPK)
        sample_count += batch_y.size(0)
        loss_total += batch_loss.item()
        batch_count += 1
    print("epoch: ", epoch, ", training loss: ", loss_total/batch_count)
    print('training accuracy: ', [h/sample_count for h in hits])


def prune(net, train_loader, loss_func, optimizer_w, optimizer_b, parameters_w_bin, pruning_rate, epoch):
    """Run one pruning epoch on the alpha's.

    The network is kept in eval mode while gradients are still computed. optimizer_b takes
    a regular step; optimizer_w is stepped in 'coordinate' mode with pruning_rate passed on.
    Afterwards the per-layer number of binary filters, the per-layer average bitwidth and
    the weight-count-weighted overall average bitwidth are printed.
    """
    net.eval()
    loss_total = 0.
    batch_count = 0
    sample_count = 0
    hits = [0.] * len(TOPK)
    for batch_x, batch_y in train_loader:
        batch_x = batch_x.cuda(non_blocking=True)
        batch_y = batch_y.cuda(non_blocking=True)
        for opt in (optimizer_w, optimizer_b):
            opt.zero_grad()
        logits = net(batch_x)
        batch_loss = loss_func(logits, batch_y)
        batch_loss.backward()
        optimizer_b.step()
        optimizer_w.step(parameters_w_bin, 'coordinate', pruning_rate)
        accuracy(logits, batch_y, hits, topk=TOPK)
        loss_total += batch_loss.item()
        batch_count += 1
        sample_count += batch_y.size(0)
    print("epoch: ", epoch, ", pruning loss: ", loss_total/batch_count)
    print('pruning accuracy: ', [h/sample_count for h in hits])

    # Report the size of the quantized network after this epoch.
    print('currrent number of binary filters per layer: ')
    for layer_bin in parameters_w_bin:
        print(layer_bin.num_bin_filter)
    weight_count = 0.
    bit_count = 0.
    print('currrent average bitwidth per layer: ')
    for layer_bin in parameters_w_bin:
        weight_count += layer_bin.num_weight
        bit_count += layer_bin.avg_bit*layer_bin.num_weight
        print(layer_bin.avg_bit)
    print('currrent average bitwidth: ', bit_count/weight_count)

 
def initialize(net, train_loader, loss_func, structure, num_subchannel, max_bit):
    """Initialize the weight tensors of all layers to multi-bit form using structured sketching.

    Every parameter whose name contains 'weight' and that has more than one dimension is
    quantized: parameters whose name contains 'fc' or 'classifier' become FCLayer_bin, all
    others ConvLayer_bin. The parameter itself is overwritten in place with the reconstructed
    (quantized) weights. structure, num_subchannel and max_bit are indexed by the running
    count of quantized layers. Afterwards the loss and accuracy on train_loader and the
    bitwidth statistics are printed.

    Return the list of all weight parameters, the list of all other parameters, and the list of the multi-bit forms of all weight parameters.
    """
    parameters_w = []
    parameters_b = []
    parameters_w_bin = []
    i = 0
    for name, param in net.named_parameters():
        # Only initialize weight tensors to multi-bit form
        if 'weight' in name and param.dim()>1:
            parameters_w.append(param)
            ind_layer = len(parameters_w)-1
            # Initialize fully connected layers (param.dim()==2)
            if 'fc' in name or 'classifier' in name:
                layer_bin = FCLayer_bin(param.data, ind_layer, structure[i], num_subchannel[i], max_bit[i])
            # Initialize convolutional layers (param.dim()==4)
            else:
                layer_bin = ConvLayer_bin(param.data, ind_layer, structure[i], max_bit[i])
            parameters_w_bin.append(layer_bin)
            i += 1
            # Overwrite the network weights in place (through a detached alias, so autograd
            # is not involved) with their quantized reconstruction.
            tmp_param = param.detach()
            tmp_param.zero_().add_(layer_bin.reconstruct_w())
        # Maintain other parameters (e.g. bias, batch normalization) in full precision
        else:
            parameters_b.append(param)
    net.eval()
    train_loss = 0.
    num_batches = 0
    correct_sum = [0. for _ in range(len(TOPK))]
    total = 0
    # Pure evaluation: no backward pass follows, so do not record autograd graphs.
    with torch.no_grad():
        for (inputs, labels) in train_loader:
            inputs, labels = inputs.cuda(non_blocking=True), labels.cuda(non_blocking=True)
            outputs = net(inputs)
            loss = loss_func(outputs, labels)
            accuracy(outputs, labels, correct_sum, topk=TOPK)
            total += labels.size(0)
            train_loss += loss.data.item()
            num_batches += 1
    print('train loss: ', train_loss/num_batches)
    print('train accuracy: ', [ci/total for ci in correct_sum])
    num_weight_layer = 0.
    num_bit_layer = 0.
    print('currrent binary filter number per layer: ')
    for p_w_bin in parameters_w_bin:
        print(p_w_bin.num_bin_filter)
    print('currrent average bitwidth per layer: ')
    for p_w_bin in parameters_w_bin:
        num_weight_layer += p_w_bin.num_weight
        num_bit_layer += p_w_bin.avg_bit*p_w_bin.num_weight
        print(p_w_bin.avg_bit)
    # Overall average bitwidth, weighted by the number of weights per layer.
    print('currrent average bitwidth: ', num_bit_layer/num_weight_layer)
    return parameters_w, parameters_b, parameters_w_bin
      
     
def validate(net, val_loader, loss_func):
    """Evaluate the network on the validation data.

    Prints the mean batch loss and the top-k accuracies, and returns the list of
    top-k accuracies (one entry per value in TOPK).
    """
    net.eval()
    loss_total = 0.
    batch_count = 0
    sample_count = 0
    hits = [0.] * len(TOPK)
    with torch.no_grad():
        for batch_x, batch_y in val_loader:
            batch_x = batch_x.cuda(non_blocking=True)
            batch_y = batch_y.cuda(non_blocking=True)
            logits = net(batch_x)
            batch_loss = loss_func(logits, batch_y)
            accuracy(logits, batch_y, hits, topk=TOPK)
            sample_count += batch_y.size(0)
            loss_total += batch_loss.item()
            batch_count += 1
        print('validation loss: ', loss_total/batch_count)
        topk_accuracy = [h/sample_count for h in hits]
        print("validation accuracy: ", topk_accuracy)
        return topk_accuracy


def test(net, test_loader, loss_func):
    """Evaluate the network on the test data and print the mean batch loss and the top-k accuracies.

    The network is switched to eval mode and no gradients are recorded. Nothing is returned.
    """
    net.eval()
    loss_total = 0.
    batch_count = 0
    sample_count = 0
    hits = [0.] * len(TOPK)
    with torch.no_grad():
        for batch_x, batch_y in test_loader:
            batch_x = batch_x.cuda(non_blocking=True)
            batch_y = batch_y.cuda(non_blocking=True)
            logits = net(batch_x)
            batch_loss = loss_func(logits, batch_y)
            accuracy(logits, batch_y, hits, topk=TOPK)
            sample_count += batch_y.size(0)
            loss_total += batch_loss.item()
            batch_count += 1
        print("test loss: ", loss_total/batch_count)
        print("test accuracy: ", [h/sample_count for h in hits])
        

def save_model(file_name, net, optimizer_w, optimizer_b, parameters_w_bin):
    """Write a checkpoint of the quantized model to file_name.

    The checkpoint holds the state dictionaries of the network and of both optimizers,
    plus the parameters_w_bin object itself.
    """
    print('saving...')
    checkpoint = {}
    checkpoint['net_state_dict'] = net.state_dict()
    checkpoint['optimizer_w_state_dict'] = optimizer_w.state_dict()
    checkpoint['optimizer_b_state_dict'] = optimizer_b.state_dict()
    # Stored as the live objects, not as a state dictionary.
    checkpoint['parameters_w_bin'] = parameters_w_bin
    torch.save(checkpoint, file_name)


def save_model_ori(file_name, net, optimizer):
    """Write a checkpoint of the full precision model to file_name.

    The checkpoint holds the state dictionaries of the network and of its optimizer.
    """
    print('saving...')
    checkpoint = {}
    checkpoint['net_state_dict'] = net.state_dict()
    checkpoint['optimizer_state_dict'] = optimizer.state_dict()
    torch.save(checkpoint, file_name)

#VGG

In [None]:
import argparse

import torch
import math
from torchvision import datasets, transforms 

#from binarynet import ConvLayer_bin, FCLayer_bin
#from myoptimizer import ALQ_optimizer
#from train import get_accuracy, train_fullprecision, train_basis, train_basis_STE, train_coordinate, validate, test, prune, initialize, save_model, save_model_ori


# Defining the network (VGG_small)  
class VGG_small(torch.nn.Module):
    def __init__(self):
        super(VGG_small, self).__init__()
        self.features = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=3, out_channels=128, kernel_size=3, padding=1),
            torch.nn.BatchNorm2d(num_features=128, affine=True, momentum=0.1),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1),
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
            torch.nn.BatchNorm2d(num_features=128, affine=True, momentum=0.1),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1),
            torch.nn.BatchNorm2d(num_features=256, affine=True, momentum=0.1),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1),
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
            torch.nn.BatchNorm2d(num_features=256, affine=True, momentum=0.1),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=1),
            torch.nn.BatchNorm2d(num_features=512, affine=True, momentum=0.1),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1),
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
            torch.nn.BatchNorm2d(num_features=512, affine=True, momentum=0.1),
            torch.nn.ReLU(inplace=True),
            )
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(512*4*4, 1024),
            torch.nn.BatchNorm1d(num_features=1024, affine=True, momentum=0.1),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(1024, 1024),
            torch.nn.BatchNorm1d(num_features=1024, affine=True, momentum=0.1),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(1024, 10),
        )
    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x


**ON CIFAR 10**

In [None]:
import sys
# Emulate command-line flags when running inside a notebook.
# argparse parses sys.argv[1:], so element 0 must be the program name; without it the
# first flag would be consumed as the program name and silently never parsed.
sys.argv = (sys.argv[:1] or ['vgg_small']) + ['--PRETRAIN','--ALQ','--POSTTRAIN']
#del sys

if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument('--data', type=str, default='./data',
                        help='CIFAR10 dataset directory')
    parser.add_argument('--val_size', type=int, default=5000,
                        help='the number of samples in validation dataset')
    parser.add_argument('--model_ori', type=str, default='./vgg_small_model_ori.pth', 
                        help='the file of the original full precision vgg_small model')
    parser.add_argument('--model', type=str, default='./vgg_small_model.pth', 
                        help='the file of the quantized vgg_small model')
    parser.add_argument('--PRETRAIN', action='store_true', 
                        help='train the original full precision vgg_small model')
    parser.add_argument('--ALQ', action='store_true',  
                        help='adaptive loss-aware quantize vgg_small model')
    parser.add_argument('--POSTTRAIN', action='store_true', 
                        help='posttrain the final quantized vgg_small model')
    parser.add_argument('--lr', type=float, default=1e-3,
                        help='learning rate')
    parser.add_argument('--R', type=int, default=1,
                        help='the number of outer iterations, also the number of pruning')
    parser.add_argument('--epoch_prune', type=int, default=1,
                        help='the number of epochs for pruning')
    parser.add_argument('--epoch_basis', type=int, default=1,
                        help='the number of epochs for optimizing bases')
    parser.add_argument('--ld_basis', type=float, default=0.8,
                        help='learning rate decay factor for optimizing bases')
    parser.add_argument('--epoch_coord', type=int, default=1,
                        help='the number of epochs for optimizing coordinates')
    parser.add_argument('--ld_coord', type=float, default=0.8,
                        help='learning rate decay factor for optimizing coordinates')
    parser.add_argument('--wd', type=float, default=0.,
                        help='weight decay')
    parser.add_argument('--pr', type=float, default=0.4,
                        help='the pruning ratio of alpha')
    parser.add_argument('--top_k', type=float, default=0.002,
                        help='the ratio of selected alpha in each layer for resorting')
    parser.add_argument('--structure', type=str, nargs='+', choices=['channelwise', 'kernelwise', 'pixelwise', 'subchannelwise'], 
                        default=['channelwise','pixelwise','pixelwise','pixelwise','pixelwise','pixelwise','subchannelwise','subchannelwise','subchannelwise'],
                        help='the structure-wise used in each layer')
    parser.add_argument('--subc', type=int, nargs='+', default=[0,0,0,0,0,0,16,2,2],
                        help='number of subchannels when using subchannelwise')
    parser.add_argument('--max_bit', type=int, nargs='+', default=[6,6,6,6,6,6,6,6,6],
                        help='the maximum bitwidth used in initialization')
    parser.add_argument('--batch_size', type=int, default=128,
                        help='the number of training samples in each batch')
    args = parser.parse_args()
    
    torch.backends.cudnn.benchmark = True
    train_dataset_full = datasets.CIFAR10(args.data, train=True, download=True, transform=transforms.Compose([
                       
                        transforms.RandomCrop(32, padding=4),
                        transforms.RandomHorizontalFlip(),
                        transforms.ToTensor(),
                        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
                        ]))
                   
    test_dataset_full = datasets.CIFAR10(args.data, train=False, download=True, transform=transforms.Compose([
                       
                        transforms.ToTensor(),
                        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
                        ]))
    val_dataset, train_dataset = torch.utils.data.random_split(train_dataset_full, [args.val_size, len(train_dataset_full)-args.val_size])
    num_training_sample = len(train_dataset)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=8) 
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=8)
    test_loader = torch.utils.data.DataLoader(test_dataset_full, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=8)


    
    print('pretraining...')
    net = VGG_small().cuda()
    loss_func = torch.nn.CrossEntropyLoss().cuda()
        
    optimizer = torch.optim.Adam(net.parameters(),lr=5e-2)
    get_accuracy(net, train_loader, loss_func)
    val_accuracy = validate(net, val_loader, loss_func)
    best_acc = val_accuracy[0]
    test(net, test_loader, loss_func)
    save_model_ori(args.model_ori, net, optimizer)
        
    for epoch in range(1):
            if epoch%30 == 0:
                optimizer.param_groups[0]['lr'] *= 0.2
            train_fullprecision(net, train_loader, loss_func, optimizer, epoch)
            val_accuracy = validate(net, val_loader, loss_func)
            if val_accuracy[0]>best_acc:
                best_acc = val_accuracy[0]
                test(net, test_loader, loss_func)
                save_model_ori(args.model_ori, net, optimizer) 
        
    # End of the for loop

    
    print('adaptive loss-aware quantization...')

    net = VGG_small().cuda()
    loss_func = torch.nn.CrossEntropyLoss().cuda() 

    print('loading pretrained full precision vgg_small model ...')
    checkpoint = torch.load(args.model_ori)
    net.load_state_dict(checkpoint['net_state_dict'])
    for name, param in net.named_parameters():
            print(name)
            print(param.size())   

    print('initialization (structured sketching)...')
    parameters_w, parameters_b, parameters_w_bin = initialize(net, train_loader, loss_func, args.structure, args.subc, args.max_bit)
    optimizer_b = torch.optim.Adam(parameters_b, weight_decay=args.wd) 
    optimizer_w = ALQ_optimizer(parameters_w, weight_decay=args.wd)
    val_accuracy = validate(net, val_loader, loss_func)
    best_acc = val_accuracy[0]
    test(net, test_loader, loss_func)
    save_model(args.model, net, optimizer_w, optimizer_b, parameters_w_bin)

    M_p = (args.pr/args.top_k)/(args.epoch_prune*math.ceil(num_training_sample/args.batch_size))

    for r in range(args.R):

            print('outer iteration: ', r)
            optimizer_b.param_groups[0]['lr'] = args.lr
            optimizer_w.param_groups[0]['lr'] = args.lr
            
            print('optimizing basis...')
            for q_epoch in range(args.epoch_basis):
                optimizer_b.param_groups[0]['lr'] *= args.ld_basis
                optimizer_w.param_groups[0]['lr'] *= args.ld_basis
                train_basis(net, train_loader, loss_func, optimizer_w, optimizer_b, parameters_w_bin, q_epoch)
                val_accuracy = validate(net, val_loader, loss_func)
                if val_accuracy[0]>best_acc:
                    best_acc = val_accuracy[0]
                    test(net, test_loader, loss_func)
                    #save_model(args.model, net, optimizer_w, optimizer_b, parameters_w_bin)
            
            print('optimizing coordinates...')
            for p_epoch in range(args.epoch_coord):
                optimizer_b.param_groups[0]['lr'] *= args.ld_coord
                optimizer_w.param_groups[0]['lr'] *= args.ld_coord
                train_coordinate(net, train_loader, loss_func, optimizer_w, optimizer_b, parameters_w_bin, p_epoch)
                val_accuracy = validate(net, val_loader, loss_func)
                if val_accuracy[0]>best_acc:
                    best_acc = val_accuracy[0]
                    test(net, test_loader, loss_func)
                    #save_model(args.model, net, optimizer_w, optimizer_b, parameters_w_bin)
                    
            print('pruning...')
            for t_epoch in range(args.epoch_prune):
                prune(net, train_loader, loss_func, optimizer_w, optimizer_b, parameters_w_bin, [args.top_k, M_p], t_epoch)
                val_accuracy = validate(net, val_loader, loss_func)
                best_acc = val_accuracy[0]
                test(net, test_loader, loss_func)
                save_model(args.model, net, optimizer_w, optimizer_b, parameters_w_bin)


    
    print('posttraining...')
            
    net = VGG_small().cuda()
    loss_func = torch.nn.CrossEntropyLoss().cuda()

    parameters_w = []
    parameters_b = []
    for name, param in net.named_parameters():
            if 'weight' in name and param.dim()>1:
                parameters_w.append(param)
            else:
                parameters_b.append(param)

    optimizer_b = torch.optim.Adam(parameters_b, weight_decay=args.wd) 
    optimizer_w = ALQ_optimizer(parameters_w, weight_decay=args.wd)
        
    print('load quantized vgg_small model...')
    checkpoint = torch.load(args.model)
    net.load_state_dict(checkpoint['net_state_dict'])
    optimizer_w.load_state_dict(checkpoint['optimizer_w_state_dict'])
    optimizer_b.load_state_dict(checkpoint['optimizer_b_state_dict'])
    for state in optimizer_b.state.values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.cuda()
    for state in optimizer_w.state.values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.cuda()

    num_weight_layer = 0.
    num_bit_layer = 0.
    print('currrent binary filter number per layer: ')
    for p_w_bin in parameters_w_bin:
            print(p_w_bin.num_bin_filter)
    print('currrent average bitwidth per layer: ')
    for p_w_bin in parameters_w_bin:
            num_weight_layer += p_w_bin.num_weight
            num_bit_layer += p_w_bin.avg_bit*p_w_bin.num_weight
            print(p_w_bin.avg_bit)
    print('currrent average bitwidth: ', num_bit_layer/num_weight_layer)

    get_accuracy(net, train_loader, loss_func)
    val_accuracy = validate(net, val_loader, loss_func)
    best_acc = val_accuracy[0]
    test(net, test_loader, loss_func)
    optimizer_b.param_groups[0]['lr'] = args.lr
    optimizer_w.param_groups[0]['lr'] = args.lr
        
    print('optimizing basis with STE...')
    for epoch in range(1):
            optimizer_b.param_groups[0]['lr'] *= 0.95
            optimizer_w.param_groups[0]['lr'] *= 0.95
            train_basis_STE(net, train_loader, loss_func, optimizer_w, optimizer_b, parameters_w_bin, epoch)
            val_accuracy = validate(net, val_loader, loss_func)
            if val_accuracy[0]>best_acc:
                best_acc = val_accuracy[0]
                test(net, test_loader, loss_func)
                save_model(args.model, net, optimizer_w, optimizer_b, parameters_w_bin)
        
    print('optimizing coordinates...')
    for epoch in range(2):
            optimizer_b.param_groups[0]['lr'] *= 0.9
            optimizer_w.param_groups[0]['lr'] *= 0.9
            train_coordinate(net, train_loader, loss_func, optimizer_w, optimizer_b, parameters_w_bin, epoch)
            val_accuracy = validate(net, val_loader, loss_func)
            if val_accuracy[0]>best_acc:
                best_acc = val_accuracy[0]
                test(net, test_loader, loss_func)
                save_model(args.model, net, optimizer_w, optimizer_b, parameters_w_bin)


Files already downloaded and verified
Files already downloaded and verified
pretraining...
training loss:  2.302714389833537
training accuracy:  [0.10011111111111111, 0.4992]
validation loss:  2.3025242924690246
validation accuracy:  [0.099, 0.5072]
test loss:  2.3027364573901212
test accuracy:  [0.1, 0.5]
saving...
epoch:  0 , training loss:  1.818170333111828
training accuracy:  [0.35755555555555557, 0.8587111111111111]
validation loss:  1.4910687416791917
validation accuracy:  [0.4854, 0.912]
test loss:  1.4532442455050312
test accuracy:  [0.4854, 0.9223]
saving...
adaptive loss-aware quantization...
loading pretrained full precision vgg_small model ...
features.0.weight
torch.Size([128, 3, 3, 3])
features.0.bias
torch.Size([128])
features.1.weight
torch.Size([128])
features.1.bias
torch.Size([128])
features.3.weight
torch.Size([128, 128, 3, 3])
features.3.bias
torch.Size([128])
features.5.weight
torch.Size([128])
features.5.bias
torch.Size([128])
features.7.weight
torch.Size([256, 

	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha) (Triggered internally at  /pytorch/torch/csrc/utils/python_arg_parser.cpp:882.)


epoch:  0 , training loss:  1.2342727302827619
training accuracy:  [0.5473777777777777, 0.9494888888888889]
validation loss:  1.1565889060497283
validation accuracy:  [0.5714, 0.9518]
test loss:  1.142047124572947
test accuracy:  [0.5825, 0.9571]
optimizing coordinates...
epoch:  0 , training loss:  1.1216780194504694
training accuracy:  [0.5911777777777778, 0.9598888888888889]
validation loss:  1.0328823938965797
validation accuracy:  [0.6204, 0.9674]
test loss:  1.000333789783188
test accuracy:  [0.6341, 0.967]
pruning...
epoch:  0 , pruning loss:  0.9933832032098011
pruning accuracy:  [0.6419555555555555, 0.9682444444444445]
currrent number of binary filters per layer: 
tensor(765, device='cuda:0')
tensor(6663, device='cuda:0')
tensor(11697, device='cuda:0')
tensor(10944, device='cuda:0')
tensor(19762, device='cuda:0')
tensor(18916, device='cuda:0')
tensor(57010, device='cuda:0')
tensor(9198, device='cuda:0')
tensor(119, device='cuda:0')
currrent average bitwidth per layer: 
tensor(

copy

In [None]:
import sys

# Simulate a command line inside the notebook.  argparse's parse_args() reads
# sys.argv[1:], so element 0 must be a program-name placeholder; without it the
# first flag ('--PRETRAIN') is swallowed as the program name and never parsed.
sys.argv = ['alq_notebook', '--PRETRAIN', '--ALQ', '--POSTTRAIN']

if __name__ == "__main__":

    # Command-line options for the three stages and their hyper-parameters.
    parser = argparse.ArgumentParser()

    # Dataset location and checkpoint files.
    parser.add_argument('--data', type=str, default='./data',
                        help='CIFAR10 dataset directory')
    parser.add_argument('--val_size', type=int, default=5000,
                        help='the number of samples in validation dataset')
    parser.add_argument('--model_ori', type=str, default='./vgg_small_model_ori.pth',
                        help='the file of the original full precision vgg_small model')
    parser.add_argument('--model', type=str, default='./vgg_small_model.pth',
                        help='the file of the quantized vgg_small model')

    # Stage switches (each is False unless its flag is present on the command line).
    parser.add_argument('--PRETRAIN', action='store_true',
                        help='train the original full precision vgg_small model')
    parser.add_argument('--ALQ', action='store_true',
                        help='adaptive loss-aware quantize vgg_small model')
    parser.add_argument('--POSTTRAIN', action='store_true',
                        help='posttrain the final quantized vgg_small model')

    # Optimisation schedule.
    parser.add_argument('--lr', type=float, default=1e-3,
                        help='learning rate')
    parser.add_argument('--R', type=int, default=1,
                        help='the number of outer iterations, also the number of pruning')
    parser.add_argument('--epoch_prune', type=int, default=1,
                        help='the number of epochs for pruning')
    parser.add_argument('--epoch_basis', type=int, default=1,
                        help='the number of epochs for optimizing bases')
    parser.add_argument('--ld_basis', type=float, default=0.8,
                        help='learning rate decay factor for optimizing bases')
    parser.add_argument('--epoch_coord', type=int, default=1,
                        help='the number of epochs for optimizing coordinates')
    parser.add_argument('--ld_coord', type=float, default=0.8,
                        help='learning rate decay factor for optimizing coordinates')
    parser.add_argument('--wd', type=float, default=0.,
                        help='weight decay')

    # Pruning / quantization structure.  The list-valued defaults below each
    # carry nine entries.
    parser.add_argument('--pr', type=float, default=0.4,
                        help='the pruning ratio of alpha')
    parser.add_argument('--top_k', type=float, default=0.002,
                        help='the ratio of selected alpha in each layer for resorting')
    parser.add_argument('--structure', type=str, nargs='+', choices=['channelwise', 'kernelwise', 'pixelwise', 'subchannelwise'],
                        default=['channelwise','pixelwise','pixelwise','pixelwise','pixelwise','pixelwise','subchannelwise','subchannelwise','subchannelwise'],
                        help='the structure-wise used in each layer')
    parser.add_argument('--subc', type=int, nargs='+', default=[0,0,0,0,0,0,16,2,2],
                        help='number of subchannels when using subchannelwise')
    parser.add_argument('--max_bit', type=int, nargs='+', default=[6,6,6,6,6,6,6,6,6],
                        help='the maximum bitwidth used in initialization')
    parser.add_argument('--batch_size', type=int, default=128,
                        help='the number of training samples in each batch')
    args = parser.parse_args()

    # Enable cuDNN's benchmark mode.
    torch.backends.cudnn.benchmark = True
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from IPython import embed
import os

def get(batch_size, data_root='/mnt/local0/public_dataset/pytorch/', train=True, val=True, **kwargs):
    """Build STL10 data loader(s).

    Arguments:
        batch_size: batch size given to every loader.
        data_root: parent directory; the dataset is stored in its
            'stl10-data' sub-directory ('~' is expanded).
        train: if true, build a shuffled loader over the 'train' split with
            pad / random-crop / horizontal-flip augmentation.
        val: if true, build an unshuffled loader over the 'test' split.
        **kwargs: forwarded to the DataLoader constructor. 'num_workers'
            defaults to 1 and 'input_size' is discarded.

    Returns:
        The loader itself when exactly one was requested, otherwise a list of
        the requested loaders (train first; empty when none was requested).
    """
    stl10_dir = os.path.expanduser(os.path.join(data_root, 'stl10-data'))
    kwargs.setdefault('num_workers', 1)
    kwargs.pop('input_size', None)
    print("Building STL10 data loader with {} workers".format(kwargs['num_workers']))

    def _tensor_steps():
        # Tensor conversion + per-channel normalisation shared by both splits.
        return [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]

    def _make_loader(split, steps, shuffle):
        # Download the split if necessary and wrap it in a DataLoader.
        dataset = datasets.STL10(
            root=stl10_dir, split=split, download=True,
            transform=transforms.Compose(steps))
        return torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=shuffle, **kwargs)

    loaders = []
    if train:
        augmentation = [
            transforms.Pad(4),
            transforms.RandomCrop(96),
            transforms.RandomHorizontalFlip(),
        ]
        loaders.append(_make_loader('train', augmentation + _tensor_steps(), True))
    if val:
        loaders.append(_make_loader('test', _tensor_steps(), False))

    # A lone loader is returned bare rather than wrapped in a list.
    if len(loaders) == 1:
        return loaders[0]
    return loaders

# Notebook driver: iterates once over the STL10 training loader, then runs
# three stages back to back, unconditionally: full-precision pretraining,
# adaptive loss-aware quantization, and post-training of the quantized model.
# NOTE(review): train_loader, val_loader, test_loader, num_training_sample,
# args and the helpers (VGG_small, get_accuracy, validate, test, initialize,
# train_*, prune, save_model*) are not defined in this cell -- presumably they
# come from earlier notebook cells; confirm before running this in isolation.
if __name__ == '__main__':
    # Build the STL10 loaders and walk the training loader once, printing a
    # marker per batch.  train_ds / test_ds are not used again below; every
    # later stage reads train_loader / val_loader / test_loader instead.
    train_ds, test_ds = get(200, num_workers=1)
    for data, target in train_ds:
        print("~~")


    
    # ---- Stage 1: full-precision pretraining ----
    print('pretraining...')
    net = VGG_small().cuda()
    loss_func = torch.nn.CrossEntropyLoss().cuda()
        
    optimizer = torch.optim.Adam(net.parameters(),lr=5e-2)
    # Evaluate the untrained network and save it as the initial checkpoint.
    get_accuracy(net, train_loader, loss_func)
    val_accuracy = validate(net, val_loader, loss_func)
    best_acc = val_accuracy[0]
    test(net, test_loader, loss_func)
    save_model_ori(args.model_ori, net, optimizer)
        
    # A single epoch.  epoch % 30 == 0 holds for epoch 0, so the learning rate
    # is scaled by 0.2 (5e-2 -> 1e-2) before the first training pass.
    for epoch in range(1):
            if epoch%30 == 0:
                optimizer.param_groups[0]['lr'] *= 0.2
            train_fullprecision(net, train_loader, loss_func, optimizer, epoch)
            val_accuracy = validate(net, val_loader, loss_func)
            # Test and checkpoint only when validation accuracy improves.
            if val_accuracy[0]>best_acc:
                best_acc = val_accuracy[0]
                test(net, test_loader, loss_func)
                save_model_ori(args.model_ori, net, optimizer) 
        
    # End of the for loop

    
    # ---- Stage 2: adaptive loss-aware quantization ----
    print('adaptive loss-aware quantization...')

    net = VGG_small().cuda()
    loss_func = torch.nn.CrossEntropyLoss().cuda() 

    # Start from the full-precision checkpoint written by stage 1.
    print('loading pretrained full precision vgg_small model ...')
    checkpoint = torch.load(args.model_ori)
    net.load_state_dict(checkpoint['net_state_dict'])
    # Dump every parameter name and shape.
    for name, param in net.named_parameters():
            print(name)
            print(param.size())   

    print('initialization (structured sketching)...')
    parameters_w, parameters_b, parameters_w_bin = initialize(net, train_loader, loss_func, args.structure, args.subc, args.max_bit)
    # parameters_b is driven by stock Adam, parameters_w by the custom ALQ optimizer.
    optimizer_b = torch.optim.Adam(parameters_b, weight_decay=args.wd) 
    optimizer_w = ALQ_optimizer(parameters_w, weight_decay=args.wd)
    val_accuracy = validate(net, val_loader, loss_func)
    best_acc = val_accuracy[0]
    test(net, test_loader, loss_func)
    save_model(args.model, net, optimizer_w, optimizer_b, parameters_w_bin)

    # (pr / top_k) divided by 1 + epoch_prune * (batches per epoch); passed to
    # prune() together with top_k.
    # NOTE(review): presumably a per-step pruning amount -- confirm against prune().
    M_p = (args.pr/args.top_k)/(1+args.epoch_prune*math.ceil(num_training_sample/args.batch_size))

    for r in range(args.R):

            print('outer iteration: ', r)
            # Learning rates are reset to args.lr at the start of every outer
            # iteration and then decayed once per epoch inside each phase.
            optimizer_b.param_groups[0]['lr'] = args.lr
            optimizer_w.param_groups[0]['lr'] = args.lr
            
            print('optimizing basis...')
            for q_epoch in range(args.epoch_basis):
                optimizer_b.param_groups[0]['lr'] *= args.ld_basis
                optimizer_w.param_groups[0]['lr'] *= args.ld_basis
                train_basis(net, train_loader, loss_func, optimizer_w, optimizer_b, parameters_w_bin, q_epoch)
                val_accuracy = validate(net, val_loader, loss_func)
                # Checkpointing is disabled in this phase; only the test run is kept.
                if val_accuracy[0]>best_acc:
                    best_acc = val_accuracy[0]
                    test(net, test_loader, loss_func)
                    #save_model(args.model, net, optimizer_w, optimizer_b, parameters_w_bin)
            
            print('optimizing coordinates...')
            for p_epoch in range(args.epoch_coord):
                optimizer_b.param_groups[0]['lr'] *= args.ld_coord
                optimizer_w.param_groups[0]['lr'] *= args.ld_coord
                train_coordinate(net, train_loader, loss_func, optimizer_w, optimizer_b, parameters_w_bin, p_epoch)
                val_accuracy = validate(net, val_loader, loss_func)
                # Checkpointing is disabled in this phase as well.
                if val_accuracy[0]>best_acc:
                    best_acc = val_accuracy[0]
                    test(net, test_loader, loss_func)
                    #save_model(args.model, net, optimizer_w, optimizer_b, parameters_w_bin)
                    
            print('pruning...')
            for t_epoch in range(args.epoch_prune):
                prune(net, train_loader, loss_func, optimizer_w, optimizer_b, parameters_w_bin, [args.top_k, M_p], t_epoch)
                val_accuracy = validate(net, val_loader, loss_func)
                # Unlike the phases above, best_acc is overwritten and the
                # checkpoint is saved after every pruning epoch, regardless of
                # whether validation accuracy improved.
                best_acc = val_accuracy[0]
                test(net, test_loader, loss_func)
                save_model(args.model, net, optimizer_w, optimizer_b, parameters_w_bin)


    
    # ---- Stage 3: post-training of the quantized model ----
    print('posttraining...')
            
    net = VGG_small().cuda()
    loss_func = torch.nn.CrossEntropyLoss().cuda()

    # Parameters whose name contains 'weight' and that have more than one
    # dimension go to parameters_w; everything else goes to parameters_b.
    parameters_w = []
    parameters_b = []
    for name, param in net.named_parameters():
            if 'weight' in name and param.dim()>1:
                parameters_w.append(param)
            else:
                parameters_b.append(param)

    optimizer_b = torch.optim.Adam(parameters_b, weight_decay=args.wd) 
    optimizer_w = ALQ_optimizer(parameters_w, weight_decay=args.wd)
        
    print('load quantized vgg_small model...')
    checkpoint = torch.load(args.model)
    net.load_state_dict(checkpoint['net_state_dict'])
    optimizer_w.load_state_dict(checkpoint['optimizer_w_state_dict'])
    optimizer_b.load_state_dict(checkpoint['optimizer_b_state_dict'])
    # Move every tensor held in the restored optimizer state to the GPU.
    for state in optimizer_b.state.values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.cuda()
    for state in optimizer_w.state.values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.cuda()

    # NOTE(review): parameters_w_bin is not rebuilt or reloaded in this stage;
    # it is still the list returned by initialize() in stage 2, created for
    # the previous `net` object.  Confirm this is intended.
    num_weight_layer = 0.
    num_bit_layer = 0.
    print('currrent binary filter number per layer: ')
    for p_w_bin in parameters_w_bin:
            print(p_w_bin.num_bin_filter)
    print('currrent average bitwidth per layer: ')
    # Overall average bitwidth = sum(avg_bit * num_weight) / sum(num_weight).
    for p_w_bin in parameters_w_bin:
            num_weight_layer += p_w_bin.num_weight
            num_bit_layer += p_w_bin.avg_bit*p_w_bin.num_weight
            print(p_w_bin.avg_bit)
    print('currrent average bitwidth: ', num_bit_layer/num_weight_layer)

    # Baseline evaluation of the loaded model before any post-training.
    get_accuracy(net, train_loader, loss_func)
    val_accuracy = validate(net, val_loader, loss_func)
    best_acc = val_accuracy[0]
    test(net, test_loader, loss_func)
    # Set both learning rates to args.lr, replacing whatever value the first
    # parameter group held after the optimizer state was loaded.
    optimizer_b.param_groups[0]['lr'] = args.lr
    optimizer_w.param_groups[0]['lr'] = args.lr
        
    # NOTE(review): "STE" presumably stands for straight-through estimator;
    # verify against train_basis_STE.
    print('optimizing basis with STE...')
    for epoch in range(1):
            optimizer_b.param_groups[0]['lr'] *= 0.95
            optimizer_w.param_groups[0]['lr'] *= 0.95
            train_basis_STE(net, train_loader, loss_func, optimizer_w, optimizer_b, parameters_w_bin, epoch)
            val_accuracy = validate(net, val_loader, loss_func)
            # Test and overwrite the quantized checkpoint only on improvement.
            if val_accuracy[0]>best_acc:
                best_acc = val_accuracy[0]
                test(net, test_loader, loss_func)
                save_model(args.model, net, optimizer_w, optimizer_b, parameters_w_bin)
        
    print('optimizing coordinates...')
    for epoch in range(2):
            optimizer_b.param_groups[0]['lr'] *= 0.9
            optimizer_w.param_groups[0]['lr'] *= 0.9
            train_coordinate(net, train_loader, loss_func, optimizer_w, optimizer_b, parameters_w_bin, epoch)
            val_accuracy = validate(net, val_loader, loss_func)
            # Test and overwrite the quantized checkpoint only on improvement.
            if val_accuracy[0]>best_acc:
                best_acc = val_accuracy[0]
                test(net, test_loader, loss_func)
                save_model(args.model, net, optimizer_w, optimizer_b, parameters_w_bin)


Building STL10 data loader with 1 workers
Files already downloaded and verified
Files already downloaded and verified
~~
~~
~~
~~
~~
~~
~~
~~
~~
~~
~~
~~
~~
~~
~~
~~
~~
~~
~~
~~
~~
~~
~~
~~
~~
pretraining...
training loss:  2.302663889798251
training accuracy:  [0.09262222222222222, 0.4994]
validation loss:  2.302672189474106
validation accuracy:  [0.0978, 0.5044]
test loss:  2.302629494968849
test accuracy:  [0.094, 0.5]
saving...
epoch:  0 , training loss:  1.832166663286361
training accuracy:  [0.35486666666666666, 0.8521111111111112]
validation loss:  1.5255082100629807
validation accuracy:  [0.4416, 0.915]
test loss:  1.4480182656758949
test accuracy:  [0.4677, 0.9266]
saving...
adaptive loss-aware quantization...
loading pretrained full precision vgg_small model ...
features.0.weight
torch.Size([128, 3, 3, 3])
features.0.bias
torch.Size([128])
features.1.weight
torch.Size([128])
features.1.bias
torch.Size([128])
features.3.weight
torch.Size([128, 128, 3, 3])
features.3.bias
torch