# Implementation of Accurate Binary Convolution Layer
[Original Paper](https://arxiv.org/abs/1711.11294)

The inspiration for this network is the use of Deep Neural Networks for real-time object recognition. Currently available **Convolution Layers** require a large amount of computing power at runtime, which hinders the use of very deep networks in embedded systems or ASICs. Xiaofan Lin, Cong Zhao, and Wei Pan presented a way to convert Convolution Layers into **Binary Convolution Layers** for faster real-time computation.

The inspiration for this network is the use of Deep Neural Networks for real-time object recognition. Currently available **Convolution Layers** require large amount of computation power at runtime and that hinders the use of very deep networks in embedded systems or ASICs. Xiaofan Lin, Cong Zhao, and Wei Pan presented a way to convert Convolution Layers to **Binary Convolution Layers** for faster realtime computation.

The inspiration for this network is the use of Deep Neural Networks for real-time object recognition. Currently available **Convolution Layers** require large amount of computation power at runtime and that hinders the use of very deep networks in embedded systems or ASICs. Xiaofan Lin, Cong Zhao, and Wei Pan presented a way to convert Convolution Layers to **Binary Convolution Layers** for faster realtime computation.

We'll need the mean and standard deviation of each layer's complete set of convolution filters (all of the layer's weights taken together).

In [4]:
import torch.nn as nn
import torch



class ABCConv2d(nn.Module):
    """ABC-Net block: BN -> optional dropout -> sum of parallel bases -> ReLU.

    The "bases" are ``base_number`` parallel layers with identical shapes:
    ``nn.Conv2d`` modules named ``bases_conv2d_1 .. bases_conv2d_<N>``, or,
    when ``linear=True``, ``nn.Linear`` modules named ``bases_linear_1 ..``.
    Their outputs are added together before the in-place ReLU.  The names
    (and hence the state-dict keys) are relied on by
    ``AlexNet.my_model_loader``.

    Args:
        input_channels: channels (or features, if ``linear``) of the input.
        output_channels: channels (or features) produced by every base.
        kernel_size, stride, padding, groups: forwarded to each ``nn.Conv2d``
            base; ignored when ``linear`` is True.
        dropout: dropout probability applied after BN; 0 disables it.
        linear: build ``BatchNorm1d`` + ``nn.Linear`` bases instead of
            ``BatchNorm2d`` + ``nn.Conv2d`` bases.
        base_number: number of bases; only 1 and 3 are accepted.
    """

    def __init__(self, input_channels, output_channels,
                 kernel_size=-1, stride=-1, padding=-1, groups=1, dropout=0.0,
                 linear=False, base_number=3):
        super(ABCConv2d, self).__init__()
        assert base_number == 3 or base_number == 1, "support base_number == 3 or base_number == 1 "
        self.layer_type = 'ABC_Conv2d'
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dropout_ratio = dropout
        self.base_number = base_number
        if dropout != 0:
            self.dropout = nn.Dropout(dropout)
        self.linear = linear

        if linear:
            self.bn = nn.BatchNorm1d(input_channels, eps=1e-4, momentum=0.1, affine=True)
        else:
            self.bn = nn.BatchNorm2d(input_channels, eps=1e-4, momentum=0.1, affine=True)

        # Register the bases one after another (1-based names) so that the
        # module order seen by model.modules() is bn, base 1, base 2, ...
        prefix = self._base_prefix()
        for idx in range(1, base_number + 1):
            if linear:
                base = nn.Linear(input_channels, output_channels)
            else:
                base = nn.Conv2d(input_channels, output_channels,
                                 kernel_size=kernel_size, stride=stride,
                                 padding=padding, groups=groups)
            setattr(self, prefix + str(idx), base)

        self.relu = nn.ReLU(inplace=True)

    def _base_prefix(self):
        """Return the attribute-name prefix shared by all bases of this block."""
        return 'bases_linear_' if self.linear else 'bases_conv2d_'

    def forward(self, x):
        """Normalise ``x``, optionally drop out, sum the bases' outputs, apply ReLU."""
        x = self.bn(x)
        if self.dropout_ratio != 0:
            x = self.dropout(x)
        prefix = self._base_prefix()
        bases = [getattr(self, prefix + str(i)) for i in range(1, self.base_number + 1)]
        out = bases[0](x)
        for base in bases[1:]:
            out = out + base(x)
        return self.relu(out)


BinConv2d = ABCConv2d


class AlexNet(nn.Module):
    """AlexNet whose inner convolution / fully-connected layers are ABC blocks.

    ``features[0]`` is a plain ``nn.Conv2d`` and ``classifier[4]`` a plain
    ``nn.Linear``; every layer in between is a ``BinConv2d`` block with
    ``base_number`` bases.  ``forward`` reshapes the feature map to
    ``256 * 6 * 6`` values per sample, so inputs must produce a 6x6 map
    (227x227 images do with these kernel sizes and strides).

    Args:
        num_classes: size of the final linear layer's output.
        base_number: number of bases per ABC block (see ``ABCConv2d``).
    """

    def __init__(self, num_classes=10, base_number=3):
        super(AlexNet, self).__init__()
        self.num_classes = num_classes
        self.base_number = base_number
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=0),
            nn.BatchNorm2d(64, eps=1e-4, momentum=0.1, affine=True),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            BinConv2d(64, 192, kernel_size=5, stride=1, padding=2, groups=1, base_number=self.base_number),
            nn.MaxPool2d(kernel_size=3, stride=2),
            BinConv2d(192, 384, kernel_size=3, stride=1, padding=1, base_number=self.base_number),
            BinConv2d(384, 256, kernel_size=3, stride=1, padding=1, groups=1, base_number=self.base_number),
            BinConv2d(256, 256, kernel_size=3, stride=1, padding=1, groups=1, base_number=self.base_number),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.classifier = nn.Sequential(
            BinConv2d(256 * 6 * 6, 4096, linear=True, base_number=self.base_number),
            BinConv2d(4096, 4096, dropout=0.1, linear=True, base_number=self.base_number),
            nn.BatchNorm1d(4096, eps=1e-3, momentum=0.1, affine=True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        """Return class scores of shape ``(batch, num_classes)``."""
        x = self.features(x)
        x = x.view(x.size(0), 256 * 6 * 6)
        x = self.classifier(x)
        return x

    def my_model_loader(self, state_dict, strict=True):
        """Copy weights from a full-precision AlexNet ``state_dict`` into this model.

        Only the first base (``bases_*_1``) of every ABC block is initialised;
        the mapping below pairs this model's keys with the FP model's keys.

        Args:
            state_dict: mapping from FP-model parameter names to tensors.
            strict: if True (default), a key missing from this model or from
                ``state_dict`` raises ``KeyError``; if False such pairs are
                skipped.

        Returns:
            list of this model's keys that were loaded, in mapping order.
        """
        own_state = self.state_dict()
        # map fp model to ABC-Net
        load_map = \
            {
                'features.0.weight': 'features.0.weight',
                'features.0.bias': 'features.0.bias',
                'features.4.bases_conv2d_1.weight': 'features.3.weight',
                'features.4.bases_conv2d_1.bias': 'features.3.bias',
                'features.6.bases_conv2d_1.weight': 'features.6.weight',
                'features.6.bases_conv2d_1.bias': 'features.6.bias',
                'features.7.bases_conv2d_1.weight': 'features.8.weight',
                'features.7.bases_conv2d_1.bias': 'features.8.bias',
                'features.8.bases_conv2d_1.weight': 'features.10.weight',
                'features.8.bases_conv2d_1.bias': 'features.10.bias',
                'classifier.0.bases_linear_1.weight': 'classifier.1.weight',
                'classifier.0.bases_linear_1.bias': 'classifier.1.bias',
                'classifier.1.bases_linear_1.weight': 'classifier.4.weight',
                'classifier.1.bases_linear_1.bias': 'classifier.4.bias',
                'classifier.4.weight': 'classifier.6.weight',
                'classifier.4.bias': 'classifier.6.bias',
            }

        loaded = []
        for own_key, src_key in load_map.items():
            if own_key not in own_state or src_key not in state_dict:
                if strict:
                    raise KeyError(own_key if own_key not in own_state else src_key)
                continue
            # state_dict() tensors share storage with the parameters, so an
            # in-place copy_ updates the model itself.
            own_state[own_key].copy_(state_dict[src_key].data)
            loaded.append(own_key)
        return loaded


def alexnet(pretrained=False, *, model_path='model_list/alexnet_fp_pretrained.pth', **kwargs):
    r"""Build the ABC-Net AlexNet (architecture from the
    `"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper).

    Args:
        pretrained (bool): If True, initialise the model from the
            full-precision checkpoint at ``model_path`` via
            ``AlexNet.my_model_loader``.  The checkpoint's final layer must
            match ``num_classes``.
        model_path (str): location of the full-precision state dict.
        **kwargs: forwarded to ``AlexNet`` (e.g. ``num_classes``,
            ``base_number``).
    """
    model = AlexNet(**kwargs)
    if pretrained:
        # The model is still on the CPU here; map_location='cpu' also lets a
        # checkpoint saved from a GPU run load on a CPU-only machine.
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        pretrained_model = torch.load(model_path, map_location='cpu')
        model.my_model_loader(pretrained_model)
    return model


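As in the original paper, the $\mathbf{M}$ binary filters are obtained by shifting the mean-centred weights by multiples $\mu_i$ of their standard deviation, with the shifts spread evenly over $[-1, 1]$: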
$\mu_i= -1 + (i - 1)\frac{2}{\mathbf{M} - 1}$

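Now we can compute the binary filters $\mathbf{B}_i=\operatorname{sign}(\bar{\mathbf{W}}+\mu_i\operatorname{std}(\mathbf{W}))$, where $\bar{\mathbf{W}}=\mathbf{W}-\operatorname{mean}(\mathbf{W})$.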

In [5]:
import torch.nn as nn
import torch
import numpy
from sklearn.linear_model import LinearRegression
import platform


class BinOp():
    """Weight-binarisation driver for ABC-Net training.

    Collects the ``weight`` of every ``nn.Conv2d``/``nn.Linear`` in ``model``
    except the first and the last one, and splits them into consecutive
    groups of ``model.base_number`` tensors (one group per ABC block, whose
    bases are registered one after another).  In each group the first tensor
    holds the full-precision weights W; the remaining tensors are scratch
    space overwritten on every ``binarization`` call.

    Training step as used by ``train``::

        bin_op.binarization()            # every base := alpha_i * B_i
        loss.backward()
        bin_op.restore()                 # first base := saved (clamped) W
        bin_op.updateBinaryGradWeight()
        optimizer.step()
    """

    def __init__(self, model):
        self.base_number = model.base_number
        targets = [m for m in model.modules()
                   if isinstance(m, (nn.Conv2d, nn.Linear))]

        # Skip the first and the last Conv2d/Linear layer.
        self.bin_range = list(range(1, len(targets) - 1))
        self.num_of_params = len(self.bin_range)
        self.target_params = []
        self.target_modules = [targets[i].weight for i in self.bin_range]
        self.saved_params = [w.data.clone() for w in self.target_modules]
        # One alpha vector (length base_number) per group.
        self.alphas = [torch.zeros(self.base_number)
                       for _ in range(self.num_of_params // self.base_number)]

    def _group_heads(self):
        """Yield ``(group, head)``: the group index and the position of its first tensor."""
        for group in range(self.num_of_params // self.base_number):
            yield group, group * self.base_number

    def binarization(self):
        """Clamp W to [-1, 1], save it, then overwrite the bases with alpha_i * B_i."""
        self.clampConvParams()
        self.save_params()
        self.binarizeConvParams()

    def clampConvParams(self):
        """Clamp each group's full-precision weights to [-1, 1] in place."""
        for _, head in self._group_heads():
            self.target_modules[head].data.clamp_(-1.0, 1.0)

    def save_params(self):
        """Copy each group's full-precision weights into ``saved_params``."""
        for _, head in self._group_heads():
            self.saved_params[head].copy_(self.target_modules[head].data)

    def ABC_binarizeConvParams(self):
        """Replace each group's bases by the ABC approximation of W.

        With M = base_number, shifts u_i = -1 + 2 i / (M - 1) (u = 0 when
        M == 1), B_i = sign(W - mean(W) + u_i * std(W)), and alpha is the
        least-squares solution of min ||W - sum_i alpha_i B_i||^2 with no
        intercept term.  Base i of the group is set to alpha_i * B_i and
        alpha is stored in ``self.alphas``.
        """
        M = self.base_number
        shifts = [0.0] if M == 1 else [-1.0 + 2.0 * i / (M - 1) for i in range(M)]
        for group, head in self._group_heads():
            weight = self.target_modules[head].data
            k_size = weight.size()
            W = weight.view(-1)

            centered = W - W.mean()
            W_std = W.std()
            B = torch.stack([centered.sign() if u == 0 else (centered + W_std * u).sign()
                             for u in shifts])

            # Solve W ~= B^T alpha without an intercept (the paper's objective).
            # sklearn's LinearRegression() fits an intercept by default, which
            # would then be dropped and leave alpha biased.
            coef = numpy.linalg.lstsq(B.t().cpu().double().numpy(),
                                      W.cpu().double().numpy(), rcond=None)[0]
            # Keep alpha on W's device/dtype; the old platform check called
            # .cuda() on any non-Windows host, crashing on CPU-only machines.
            alpha = torch.as_tensor(coef, dtype=W.dtype, device=W.device)

            self.alphas[group].copy_(alpha)
            for base in range(M):
                self.target_modules[head + base].data.copy_(
                    B[base].mul(alpha[base]).view(k_size))

    def ABC_updateBinaryGradWeight(self):
        """Replace the gradient of each group's full-precision W.

        The new gradient is sum_i alpha_i**2 * grad(base_i), zeroed where
        |W| > 1 (straight-through estimator window).  For base_number == 1
        the gradient is left untouched, as in the original implementation.
        """
        if self.base_number == 1:
            return
        for group, head in self._group_heads():
            alphas = self.alphas[group]
            dW = None
            for base in range(self.base_number):
                term = self.target_modules[head + base].grad.data.mul(
                    alphas[base] * alphas[base])
                dW = term if dW is None else dW.add(term)
            W = self.target_modules[head].data
            dW[W.lt(-1)] = 0
            dW[W.gt(1)] = 0
            self.target_modules[head].grad.data.copy_(dW)

    binarizeConvParams = ABC_binarizeConvParams
    updateBinaryGradWeight = ABC_updateBinaryGradWeight

    def restore(self):
        """Write the saved full-precision weights back into each group's first base."""
        for _, head in self._group_heads():
            self.target_modules[head].data.copy_(self.saved_params[head])

#### Calculating alphas
Now, we can calculate alphas using the *binary filters* and *convolution filters* by minimizing the *squared difference*
$\|\mathbf{W}-\mathbf{B}\alpha\|^2$

### Creating ApproxConv using the binary filters
$\mathbf{O}=\sum\limits_{m=1}^M\alpha_m\operatorname{Conv}(\mathbf{B}_m, \mathbf{A})$

As mentioned in the paper, it is better to first train the network with ordinary convolution layers and then convert the filters into binary filters, while the original full-precision filters continue to be trained.

### Multiple binary activations and bitwise convolution
With the ApproxConv layers, convolution can already be computed using only additions, but the paper suggests something even better: if the input to the convolution layer is also binarized, the summation can be replaced by bitwise operations.
For that, the authors propose binarizing an input (creating multiple binary inputs) by shifting it by different amounts and binarizing each shifted copy.

First, the input is shifted by a shift parameter $\nu$ (one per binarized copy, learnable by the network) and clipped to the range $[0, 1]$:  
$\operatorname{h_{\nu}}(x)=\operatorname{clip}(x + \nu, 0, 1)$  
  
Then using the following function it is binarized  
$\operatorname{H_{\nu}}(\mathbf{R})=2\mathbb{I}_{\operatorname{h_{\nu}}(\mathbf{R})\geq0.5}-1$

The above function can be implemented (up to the tie at exactly 0.5) as  
$\operatorname{H_{\nu}}(\mathbf{R})=\operatorname{sign}(\operatorname{h_{\nu}}(\mathbf{R}) - 0.5)$

Finally, **ApproxConv** is applied to each binarized input separately, and the results are combined in a weighted sum with trainable parameters $\beta_n$.

## Testing
Let's test our network on CIFAR-10 (the dataset loaded by the code below).

In [None]:
import argparse
import os
import shutil
import time
import sys
import gc
import platform
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision
import sys
# Inside a notebook kernel sys.argv holds the kernel's own arguments; reset it
# so parse_args() only sees the defaults, then drop the name.
sys.argv=['']
del sys


parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')

parser.add_argument('--data', metavar='DATA_PATH', default='./data/',
                    help='path to imagenet data (default: ./data/)')
# Stored in every checkpoint by main()/train(); it was previously missing,
# so saving a checkpoint raised AttributeError.
parser.add_argument('--arch', default='alexnet', type=str, metavar='ARCH',
                    help='architecture name recorded in checkpoints (default: alexnet)')

parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
                    help='number of data loading workers (default: 8)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--base_number', default=3, type=int,
                    metavar='N', help='base_number (default: 3)')
parser.add_argument('--epochs', default=5, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', default=0.90, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-5, type=float,
                    metavar='W', help='weight decay (default: 1e-5)')
parser.add_argument('--print-freq', '-p', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    default=False, help='use pre-trained model')
parser.add_argument('--nocuda', dest='nocuda', action='store_true',
                    help='running on no cuda')
best_prec1 = 0

# define global bin_op
bin_op = None

# define optimizer
optimizer = None


def main():
    """Train (or, with ``--evaluate``, only validate) the ABC-Net AlexNet on CIFAR-10.

    Parses the module-level ``parser`` into the global ``args`` and sets the
    globals ``best_prec1``, ``optimizer`` and ``bin_op`` that ``train`` and
    ``validate`` read.
    """
    global args, best_prec1, optimizer, bin_op
    args = parser.parse_args()

    # Use the CPU when asked to (--nocuda), on Windows (as before), or when no
    # CUDA device exists. Previously --nocuda was overwritten and a CUDA-less
    # Linux machine crashed in model.cuda().
    args.nocuda = (args.nocuda or platform.system() == "Windows"
                   or not torch.cuda.is_available())

    # create model
    model = alexnet(pretrained=args.pretrained, base_number=args.base_number)
    input_size = 227  # gives the 6x6 feature map that AlexNet.forward reshapes
    model.features = torch.nn.DataParallel(model.features)
    if not args.nocuda:
        # set the seed
        torch.manual_seed(1)
        torch.cuda.manual_seed(1)
        model.cuda()
        # define loss function (criterion)
        criterion = nn.CrossEntropyLoss().cuda()
        # Set benchmark
        cudnn.benchmark = True
    else:
        criterion = nn.CrossEntropyLoss()

    optimizer = torch.optim.Adam(model.parameters(), args.lr,
                                 weight_decay=args.weight_decay)
    # random initialization
    if not args.pretrained:
        for m in model.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                c = float(m.weight.data[0].nelement())
                m.weight.data = m.weight.data.normal_(0, 1.0 / c)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data = m.weight.data.zero_().add(1.0)
    else:
        for m in model.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.weight.data = m.weight.data.zero_().add(1.0)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            # original saved file with DataParallel; map_location lets a
            # GPU-saved checkpoint resume on the CPU.
            checkpoint = torch.load(args.resume,
                                    map_location='cpu' if args.nocuda else None)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
            del checkpoint
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.RandomResizedCrop(input_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    transform_val = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    # CIFAR-10 under --data (previously the flag was ignored). Validation
    # uses the deterministic transform_val, not the training augmentation.
    train_dataset = torchvision.datasets.CIFAR10(root=args.data, train=True,
                                                 download=True, transform=transform)
    val_dataset = torchvision.datasets.CIFAR10(root=args.data, train=False,
                                               download=True, transform=transform_val)
    loader_kwargs = dict(batch_size=args.batch_size, num_workers=args.workers,
                         pin_memory=not args.nocuda)
    # Shuffle the training set every epoch; keep validation order fixed.
    train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = torch.utils.data.DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    print(model)

    # define the binarization operator
    bin_op = BinOp(model)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': getattr(args, 'arch', 'alexnet'),
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'optimizer': optimizer.state_dict(),
        }, is_best)


def train(train_loader, model, criterion, optimizer, epoch):
    """Run one training epoch with ABC-Net weight binarisation.

    For each batch: binarise the weights through the global ``bin_op``, run
    forward/backward with the binary weights, restore the full-precision
    weights, rewrite their gradients (``bin_op.updateBinaryGradWeight``) and
    step ``optimizer``.  Progress is printed every ``args.print_freq``
    batches and a checkpoint is written every 100 batches.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        if not args.nocuda:
            images = images.cuda()
            target = target.cuda(non_blocking=True)
        batch_size = images.size(0)

        # process the weights including binarization
        bin_op.binarization()

        # compute output
        output = model(images)
        loss = criterion(output, target)

        # measure accuracy and record loss; .item() keeps the meters as plain
        # floats instead of accumulating device tensors
        prec1, prec5 = accuracy(output.detach(), target, topk=(1, 5))
        losses.update(loss.item(), batch_size)
        top1.update(prec1[0].item(), batch_size)
        top5.update(prec5[0].item(), batch_size)

        # gradients are computed w.r.t. the binarised weights
        optimizer.zero_grad()
        loss.backward()

        # put the full-precision weights back and convert their gradients
        # before the optimizer updates them
        bin_op.restore()
        bin_op.updateBinaryGradWeight()

        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                epoch, i, len(train_loader), batch_time=batch_time,
                data_time=data_time, loss=losses, top1=top1, top5=top5))

        # because the training process is too slow
        if i % 100 == 99:
            save_checkpoint({
                # args.arch may not exist if the parser lacks --arch
                'arch': getattr(args, 'arch', 'alexnet'),
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, False, filename="checkpoint_every_100_batches.pth.tar")
        gc.collect()


def validate(val_loader, model, criterion):
    """Evaluate ``model`` with binarised weights and return the mean top-1 precision.

    The weights are binarised once through the global ``bin_op`` before the
    loop and restored afterwards, even if evaluation raises.  Progress is
    printed every ``args.print_freq`` batches.

    Returns:
        float: average precision@1 (percent) over the validation set.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    bin_op.binarization()
    try:
        # Variable(volatile=True) no longer disables autograd (PyTorch >= 0.4);
        # no_grad() avoids building a graph during evaluation.
        with torch.no_grad():
            for i, (images, target) in enumerate(val_loader):
                if not args.nocuda:
                    images = images.cuda()
                    target = target.cuda(non_blocking=True)
                batch_size = images.size(0)

                # compute output
                output = model(images)
                loss = criterion(output, target)

                # measure accuracy and record loss
                prec1, prec5 = accuracy(output, target, topk=(1, 5))
                losses.update(loss.item(), batch_size)
                top1.update(prec1[0].item(), batch_size)
                top5.update(prec5[0].item(), batch_size)

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                if i % args.print_freq == 0:
                    print('Test: [{0}/{1}]\t'
                          'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                          'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                          'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                        i, len(val_loader), loss=losses,
                        top1=top1, top5=top5))
    finally:
        # never leave the binarised weights in the model
        bin_op.restore()

    print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))

    return top1.avg


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Serialise ``state`` to ``filename``; when ``is_best``, also copy it to 'model_best.pth.tar'."""
    torch.save(state, filename)
    if not is_best:
        return
    shutil.copyfile(filename, 'model_best.pth.tar')


class AverageMeter(object):
    """Track the most recent value of a metric and its sample-weighted running mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, running mean, running sum and sample count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean over ``n`` samples and refresh ``avg``."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def adjust_learning_rate(optimizer, epoch, base_lr=None, decay_every=25):
    """Set the learning rate to the initial LR decayed by 10 every ``decay_every`` epochs.

    The previous body ignored ``epoch`` and ``--lr`` and always set 0.2,
    contradicting this docstring.

    Args:
        optimizer: optimizer whose ``param_groups`` are updated in place.
        epoch: current 0-based epoch.
        base_lr: initial learning rate; defaults to the global ``args.lr``
            (the ``--lr`` option).
        decay_every: epochs between successive 10x decays (default 25).

    Returns:
        float: the learning rate that was set.
    """
    if base_lr is None:
        base_lr = args.lr
    lr = base_lr * (0.1 ** (epoch // decay_every))
    print('Learning rate:', lr)

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr


def accuracy(output, target, topk=(1,)):
    """Return precision@k (in percent) for every k in ``topk``.

    ``output`` holds per-class scores with classes along dim 1 and ``target``
    the true class indices; each result is a 1-element tensor.
    """
    largest_k = max(topk)
    n_samples = target.size(0)

    # (largest_k, batch): row r holds every sample's r-th best class.
    ranked = output.topk(largest_k, 1, True, True)[1].t()
    hits = ranked.eq(target.view(1, -1).expand_as(ranked))

    scale = 100.0 / n_samples
    return [hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(scale)
            for k in topk]


if __name__ == "__main__":
    # Entry point: parse arguments and run training/evaluation.
    main()


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
AlexNet(
  (features): DataParallel(
    (module): Sequential(
      (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4))
      (1): BatchNorm2d(64, eps=0.0001, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
      (4): ABCConv2d(
        (bn): BatchNorm2d(64, eps=0.0001, momentum=0.1, affine=True, track_running_stats=True)
        (bases_conv2d_1): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
        (bases_conv2d_2): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
        (bases_conv2d_3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
        (relu): ReLU(inplace=True)
      )
      (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
      (6): ABCConv2d(
        (bn): BatchNorm2d(192, 

  cpuset_checked))


Epoch: [0][0/196]	Time 28.929 (28.929)	Data 13.284 (13.284)	Loss 2.3029 (2.3029)	Prec@1 13.281 (13.281)	Prec@5 48.438 (48.438)


myfolder.py