In [0]:
import torch
import torch.nn as nn
from torch.nn import init
import torch.optim as optim
import torchvision
from torchvision import datasets,transforms
import time
import math
import copy
import matplotlib.pyplot as plt
from collections import Counter
from collections import OrderedDict
# Global tensor type / device: CUDA when a GPU is available, otherwise CPU.
# `dtype` is a legacy tensor-type class, used by the `.type(dtype)` calls in the driver cell.
dtype = (torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Let cuDNN benchmark and pick the fastest conv algorithms; helps fixed-size inputs but may
# make runs less reproducible.
torch.backends.cudnn.benchmark = True

In [0]:
# Driver cell: configure one experiment, load the data, train, then print the metric tables.
# NOTE(review): this cell calls getDataset and AlexNet, which are defined in cells *below*
# it, and train_model, which is not defined in the cells visible here. On a fresh kernel
# "Run All" therefore stops here with a NameError; move this cell after the definitions.
if __name__ == '__main__':
    # Module-level flag; getDataset reads it as a global to print loader sizes.
    is_debug = True
    import pandas as pd
    # Per-epoch metric tables (column names suggest top-1 / top-5 accuracy); presumably
    # filled in by train_model -- TODO confirm against its definition.
    trainframe = pd.DataFrame(columns =['epoch','loss','accuracy1','accuracy5'])
    testframe = pd.DataFrame(columns =['epoch','loss','accuracy1','accuracy5'])
    #User Definable
    dataset = "CIFAR10"
    start_epoch = 1
    epochs = 6
    # NOTE(review): 0.01 is ten times Adam's default learning rate (1e-3) -- confirm intended.
    lr = 0.01
    # NOTE(review): the name mentions a binary ResNet, but the model built below is the
    # full-precision AlexNet -- confirm which one this run is meant to be.
    filename = "Sample_binresnet"
    model =AlexNet(10)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(),lr = lr)
    
    # Feature switches. is_sch_* and is_reg are not read in this cell; presumably
    # train_model reads them as globals -- TODO confirm. is_drive is passed explicitly below.
    is_sch_epoch = False
    is_sch_acc = False
    is_sch_valloss = False
    is_reg = False
    is_drive = False
    
    dataloaders, dataset_sizes = getDataset(dataset)
 
    # Cast the loss and the model to `dtype` (CUDA float tensors when a GPU is available).
    # NOTE(review): this runs after the optimizer was built from model.parameters();
    # casting/moving the model before creating the optimizer is the safer order.
    criterion.type(dtype)
    model.type(dtype)
    
    model = train_model(model, criterion, optimizer,num_epochs=epochs,is_drive=is_drive)
    print(trainframe,testframe)

In [0]:
def getDataset(dataset, train_batch_size=128, test_batch_size=1000, train_data_shuffle=True,
               test_data_shuffle=False, root='C:\\Users\\ajana\\data'):
    """Build train/validation DataLoaders for a supported torchvision dataset.

    Args:
        dataset: one of "MNIST", "CIFAR10" or "CIFAR100".
        train_batch_size: batch size of the training loader.
        test_batch_size: batch size of the validation loader.
        train_data_shuffle: whether the training loader reshuffles each epoch.
        test_data_shuffle: whether the validation loader reshuffles each epoch
            (now honoured for MNIST too, which previously always shuffled).
        root: directory the dataset is read from / downloaded to. Defaults to the
            path that used to be hard-coded; pass your own directory elsewhere.

    Returns:
        (dataloaders, dataset_sizes): two dicts keyed by "train" and "val".

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """
    if dataset not in ("MNIST", "CIFAR10", "CIFAR100"):
        raise ValueError(
            "Unsupported dataset {!r}; expected 'MNIST', 'CIFAR10' or 'CIFAR100'".format(dataset))

    if dataset == "MNIST":
        mnist_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        trainset = torchvision.datasets.MNIST(root, train=True, download=True, transform=mnist_transform)
        testset = torchvision.datasets.MNIST(root, train=False, download=True, transform=mnist_transform)
        num_workers = 0  # DataLoader's default, as previously used for MNIST
    else:
        # NOTE: CIFAR100 reuses the CIFAR10 channel statistics (unchanged from before);
        # dataset-specific statistics may be more appropriate.
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
        if dataset == "CIFAR10":
            dataset_cls, num_workers = torchvision.datasets.CIFAR10, 1
        else:
            dataset_cls, num_workers = torchvision.datasets.CIFAR100, 2
        trainset = dataset_cls(root=root, train=True, download=True, transform=transform_train)
        testset = dataset_cls(root=root, train=False, download=True, transform=transform_test)

    train_loader = torch.utils.data.DataLoader(
        trainset, batch_size=train_batch_size, shuffle=train_data_shuffle, num_workers=num_workers)
    test_loader = torch.utils.data.DataLoader(
        testset, batch_size=test_batch_size, shuffle=test_data_shuffle, num_workers=num_workers)
    dataloaders = {"train": train_loader, "val": test_loader}
    dataset_sizes = {"train": len(trainset), "val": len(testset)}

    # `is_debug` is a notebook-level flag set in the driver cell; stay quiet when it has
    # not been defined (e.g. when this function is used on its own).
    if globals().get("is_debug", False):
        print("Loaded Train set of size : {} into loader with num of batches {}".format(dataset_sizes['train'],len(dataloaders['train'])))
        print("Loaded Test set of size : {} into loader with num of batches {}".format(dataset_sizes['val'],len(dataloaders['val'])))
    return dataloaders, dataset_sizes
def no_of_params(weights):
    """Print (and return) how often each distinct value occurs in ``weights``.

    Despite its name this does not count parameters: it is a value histogram,
    handy for checking that a binarized tensor only holds e.g. -1/+1.
    The tensor is detached and copied to the CPU first, so CUDA tensors and
    tensors with ``requires_grad=True`` are accepted.

    Returns:
        collections.Counter mapping each Python scalar value to its count.
    """
    values = weights.detach().cpu().flatten().tolist()
    counter = Counter(values)
    print(counter)
    return counter
def binarize(quant_mode='det', ste="ste_backward", **kwargs):
    """Build a sign-binarization op with a selectable surrogate gradient.

    Args:
        quant_mode: 'det' for deterministic ``sign``; any other value selects
            stochastic binarization, where +1 is drawn with probability
            clamp((x + 1) / 2, 0, 1) and -1 otherwise.
        ste: name of the backward rule -- "ste_backward", "tanh_backward",
            "htanh_backward" or "swish_backward".
        **kwargs: ``beta`` (required when ste == "swish_backward").

    Returns:
        A callable ``f(x)`` (the ``apply`` of a custom autograd Function).

    Raises:
        ValueError: for an unknown ``ste`` or a missing ``beta`` for the swish rule.
    """
    valid_ste = ("ste_backward", "tanh_backward", "htanh_backward", "swish_backward")
    if ste not in valid_ste:
        raise ValueError("Unknown ste {!r}; expected one of {}".format(ste, valid_ste))
    beta = kwargs.get('beta')
    if ste == "swish_backward" and beta is None:
        raise ValueError("ste='swish_backward' requires a 'beta' keyword argument")

    class sign(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input):
            ctx.save_for_backward(input)
            if quant_mode == 'det':
                return input.sign()
            # Stochastic mode, computed out-of-place: the caller's tensor (often a weight
            # Parameter) must not be overwritten, and the tensor saved for backward must
            # stay intact. rand_like keeps the noise on the input's device/dtype.
            noise = torch.rand_like(input) - 0.5
            return ((input + 1) / 2 + noise).clamp(0, 1).round() * 2 - 1

        @staticmethod
        def backward(ctx, grad_output):
            input, = ctx.saved_tensors
            # The helpers return (grad_input, extra); forward has exactly one input, so
            # only the first element is a gradient.
            if ste == "ste_backward":
                grad_input = ste_backward(input, grad_output)[0]
            elif ste == "tanh_backward":
                grad_input = tanh_backward(input, grad_output)[0]
            elif ste == "htanh_backward":
                grad_input = htanh_backward(input, grad_output)[0]
            else:
                grad_input = swish_backward(input, grad_output, beta)[0]
            return grad_input

    # autograd Functions are used through the class-level `apply`, not via an instance.
    return sign.apply

class weight_quantize_fn(nn.Module):
    """Module wrapper around the default ``binarize()`` op (deterministic sign, "ste_backward" gradient)."""

    def __init__(self):
        super().__init__()
        self.Binarize = binarize()

    def forward(self, x):
        """Return the binarized version of ``x``."""
        return self.Binarize(x)

def ste_backward(input, grad_output):
    """Straight-through estimator: pass ``grad_output`` through where -1 < input < 1.

    Entries whose input is >= 1 or <= -1 get a zero gradient. ``grad_output`` is not
    modified.

    Returns:
        (grad_input, None) -- the None fills the unused second gradient slot.
    """
    saturated = input.ge(1) | input.le(-1)
    return grad_output.masked_fill(saturated, 0), None

def htanh_backward(input, grad_output):
    """Hard-tanh style surrogate gradient: 3 * grad_output where -1/3 < input < 1/3, else 0.

    ``grad_output`` is not modified.

    Returns:
        (grad_input, None) -- the None fills the unused second gradient slot.
    """
    saturated = input.ge(1/3) | input.le(-1/3)
    return 3.0 * grad_output.masked_fill(saturated, 0), None

def tanh_backward(input, grad_output):
    """Tanh-shaped surrogate gradient: grad_output * (1 - tanh(2 * input) ** 2).

    NOTE(review): d/dx tanh(2x) is 2 * (1 - tanh(2x) ** 2); the factor 2 is not
    applied here -- confirm whether that is intended.

    Returns:
        (grad_input, None) -- the None fills the unused second gradient slot.
    """
    slope = 1 - torch.tanh(2 * input) ** 2
    return grad_output * slope, None

def soft_swish(x, beta):
    """Smooth surrogate slope used by ``swish_backward``.

    Computes beta * (2 - beta*x*tanh(beta*x/2)) / (1 + (exp(beta*x) + exp(-beta*x)) / 2),
    i.e. the denominator is 1 + cosh(beta*x). This appears to be the derivative of the
    SignSwish function -- confirm against the intended reference.
    """
    bx = beta * x
    numerator = beta * (2 - bx * torch.tanh(bx / 2))
    denominator = 1 + (torch.exp(bx) + torch.exp(-bx)) / 2
    return numerator / denominator

def swish_backward(input, grad_output, beta):
    """Surrogate gradient for sign() that scales ``grad_output`` by ``soft_swish(input, beta)``.

    Returns:
        (grad_input, None). The second element used to be ``beta``; when this tuple is
        returned from an autograd Function with a single input, autograd only accepts
        extra trailing entries if they are None. So the old value made backward fail.
        None matches the other *_backward helpers.
    """
    return grad_output * soft_swish(input, beta), None

def keep_elements_dict(vals, kwargs):
    """Delete, in place, every key of ``kwargs`` that does not appear in ``vals``; returns None."""
    unwanted = [key for key in kwargs if key not in vals]
    for key in unwanted:
        del kwargs[key]
class BinConv2d(nn.Conv2d):
    """2-D convolution that sign-binarizes its weights (and, by default, its input).

    Args:
        in_channels, out_channels, kernel_size: as for nn.Conv2d.
        full_precision: if True, skip binarization and run a bias-free convolution.
    Keyword arguments read here:
        scale (bool, default False): on the first forward pass, create a
            per-output-channel ``alpha`` parameter (see ``_init_alpha``).
        compute_scale (str, default 'topk'): how ``alpha`` is initialised --
            'mean', 'topk', or a numeric string used as a constant.
        binarize_input (bool, default True): also binarize incoming activations.
    Only 'stride', 'padding', 'kernel_size', 'padding_mode', 'dilation' and 'groups'
    are forwarded to nn.Conv2d; every other keyword (including ``bias``) is dropped.
    """
    def __init__(self, in_channels, out_channels, kernel_size, full_precision=False, **kwargs):
        self.scale = kwargs.get('scale', False)
        self.compute_scale = kwargs.get('compute_scale', 'topk')
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.full_precision = full_precision
        self.binarize_input = kwargs.get('binarize_input', True)
        # Strip every keyword nn.Conv2d should not receive (mutates the local kwargs dict).
        keep_elements_dict(['stride', 'padding', 'kernel_size', 'padding_mode', 'dilation', 'groups'], kwargs)
        self.stride = kwargs.get('stride', 1)
        self.padding = kwargs.get('padding', 0)
        self.padding_mode = kwargs.get('padding_mode', 'zeros')
        self.dilation = kwargs.get('dilation', 1)
        self.groups = kwargs.get('groups', 1)
        # NOTE(review): nn.Conv2d below still registers a bias Parameter (its default is
        # bias=True), but forward passes None as bias, so that parameter is never used.
        self.bias = False
        super(BinConv2d, self).__init__(self.in_channels, self.out_channels, self.kernel_size, **kwargs)
        self.Binarize = weight_quantize_fn()

    def _init_alpha(self):
        """Create the per-output-channel ``alpha`` parameter from the current weights."""
        with torch.no_grad():
            flat = self.weight.view(self.weight.shape[0], -1)
            if self.compute_scale == 'mean':
                tk = torch.mean(flat, dim=1)
            elif self.compute_scale == 'topk':
                # Value at the ceil(50%)-th largest |w| of each output channel.
                p = 0.5
                k = torch.ceil((1 - p) * torch.prod(torch.tensor(self.weight.shape[1:]), dtype=torch.float))
                tk = torch.topk(torch.abs(flat), int(k), dim=1)[0][:, -1]
            elif self.compute_scale != "":
                tk = float(self.compute_scale) * torch.ones(self.weight.shape[0], requires_grad=True)
            else:
                raise NotImplementedError("compute_scale not implemented")
        tk = tk.to(self.weight.device)
        # NOTE(review): alpha is created lazily during forward (after any optimizer has
        # been built, so it is not optimized unless added explicitly) and is not applied
        # to Modweight below -- confirm whether scaling was meant to be used.
        self.alpha = nn.Parameter(tk[:, None], requires_grad=True)

    def forward(self, input):
        if not self.full_precision:
            if self.scale and not hasattr(self, 'alpha'):
                self._init_alpha()
            self.Modweight = self.Binarize(self.weight)
            if self.binarize_input:
                input = self.Binarize(input)
        else:
            self.Modweight = self.weight
        # Pass every geometry setting stored by nn.Conv2d. dilation and groups used to be
        # dropped: dilated convs were silently computed undilated, and grouped convs
        # failed with a channel mismatch.
        # NOTE(review): a padding_mode other than 'zeros' is still not applied here.
        self.out = nn.functional.conv2d(input, self.Modweight, None, stride=self.stride,
                                        padding=self.padding, dilation=self.dilation,
                                        groups=self.groups)
        return self.out
class BinLinear(nn.Linear):
    """Linear layer that sign-binarizes its weights (and, by default, its input).

    Args:
        in_channels, out_channels: input / output feature counts.
        full_precision: if True, skip binarization and run a bias-free linear map.
    Keyword arguments read here:
        scale (bool, default False): on the first forward pass, create a per-output
            ``alpha`` parameter (see ``_init_alpha``).
        compute_scale (str, default 'topk'): how ``alpha`` is initialised --
            'mean', 'topk', or a numeric string used as a constant.
        binarize_input (bool, default True): also binarize incoming activations;
            mirrors the option of the same name on BinConv2d.
    Other keyword arguments are ignored.
    """
    def __init__(self, in_channels, out_channels, full_precision=False, **kwargs):
        self.scale = kwargs.get('scale', False)
        self.compute_scale = kwargs.get('compute_scale', 'topk')
        self.binarize_input = kwargs.get('binarize_input', True)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.full_precision = full_precision
        # NOTE(review): nn.Linear below still registers a bias Parameter (its default is
        # bias=True), but forward passes None as bias, so that parameter is never used.
        self.bias = False
        super(BinLinear, self).__init__(self.in_channels, self.out_channels)
        self.Binarize = weight_quantize_fn()

    def _init_alpha(self):
        """Create the per-output-feature ``alpha`` parameter from the current weights."""
        with torch.no_grad():
            flat = self.weight.view(self.weight.shape[0], -1)
            if self.compute_scale == 'mean':
                tk = torch.mean(flat, dim=1)
            elif self.compute_scale == 'topk':
                # Value at the ceil(50%)-th largest |w| of each output row.
                p = 0.5
                k = torch.ceil((1 - p) * torch.prod(torch.tensor(self.weight.shape[1:]), dtype=torch.float))
                tk = torch.topk(torch.abs(flat), int(k), dim=1)[0][:, -1]
            elif self.compute_scale != "":
                tk = float(self.compute_scale) * torch.ones(self.weight.shape[0], requires_grad=True)
            else:
                raise NotImplementedError("compute_scale not implemented")
        tk = tk.to(self.weight.device)
        # NOTE(review): alpha is created lazily during forward and is not applied to
        # Modweight below -- confirm whether scaling was meant to be used.
        self.alpha = nn.Parameter(tk[:, None], requires_grad=True)

    def forward(self, input):
        if not self.full_precision:
            if self.scale and not hasattr(self, 'alpha'):
                self._init_alpha()
            self.Modweight = self.Binarize(self.weight)
            if self.binarize_input:
                input = self.Binarize(input)
        else:
            self.Modweight = self.weight
        self.out = nn.functional.linear(input, self.Modweight, None)
        return self.out
import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F
from collections import OrderedDict
class BaseModel(nn.Module):
    """Abstract base for the model architectures in this notebook.

    Subclasses must implement ``forward``. Several helpers select parameters by
    name, relying on the naming convention used by the subclasses: weight layers
    contain "Conv2d" or "Linear" (with a "Bin" prefix when binarized), batch-norm
    layers contain "BN", activations contain "Act", and some heads are grouped
    under "classifier".
    """
    def __init__(self):
        super().__init__()
        # Model size in MB; (re)computed by size_of_params().
        self.size = 0
        # NOTE(review): subclasses call this constructor *before* creating their layers,
        # so here self.modules() yields only this module and the two calls below change
        # no weights (they only print). Call an init helper again at the end of a
        # subclass __init__ if a specific initialisation is wanted.
        self.kaming_uniform()
        self.batch_norm_init()

    @property
    def num_params(self):
        """Total number of parameter elements."""
        return sum(param.numel() for param in self.parameters())

    def num_trainable_params(self):
        """Number of parameter elements with ``requires_grad=True``."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def size_of_params(self):
        """Print a storage estimate per stage and store the total (in MB) in ``self.size``.

        Binarized weights (names containing "Bin" and "Conv2d"/"Linear", outside the
        classifier) count 1 bit per element. Other conv/linear/classifier weights and
        BatchNorm weights count 32 bits per element, and "Act" parameters 1 bit. Biases are
        skipped. The total is recomputed from zero on every call, so repeated calls agree.
        """
        total_bits = 0
        for name, param in self.named_parameters():
            if ((("Conv2d" in name and "Bin" in name) or ("Linear" in name and "Bin" in name)) and "classifier" not in name) and "bias" not in name:
                print(name)
                total_bits += param.numel()
        print("-"*50)
        print("Cumulative size of all binary layers is : {} MB".format(total_bits/8000000))
        print("-"*50)
        for name, param in self.named_parameters():
            if (("Conv2d" in name and "Bin" not in name) or "classifier" in name or ("Linear" in name and "Bin" not in name)) and "bias" not in name and "BN" not in name:
                print(name)
                total_bits += param.numel()*32
        print("Cumulative size of all layers without bn is : {} MB".format(total_bits/8000000))
        print("-"*50)
        for name, param in self.named_parameters():
            if "BN" in name and "bias" not in name:
                print(name)
                total_bits += param.numel()*32
        print("Cumulative size of all layers with bn is : {} MB".format(total_bits/8000000))
        print("-"*50)
        for name, param in self.named_parameters():
            if "Act" in name and "bias" not in name:
                print(name)
                total_bits += param.numel()
        print("Cumulative size of all layers  : {} MB".format(total_bits/8000000))
        print("-"*50)
        self.size = total_bits / 8000000

    def xavier_init(self):
        """Xavier-uniform init of every BinConv2d / BinLinear weight."""
        print("Xavier Init")
        for m in self.modules():
            if isinstance(m, (BinConv2d, BinLinear)):
                nn.init.xavier_uniform_(m.weight)

    def he_et_al_init(self):
        """He (Kaiming) normal init, fan_in mode, of every BinConv2d / BinLinear weight."""
        print("He et al Init")
        for m in self.modules():
            if isinstance(m, (BinConv2d, BinLinear)):
                nn.init.kaiming_normal_(m.weight, mode='fan_in')

    def kaming_uniform(self):
        """Kaiming-uniform init (a=sqrt(5)) of every BinConv2d / BinLinear weight."""
        print("Kaming Uniform Init")
        for m in self.modules():
            if isinstance(m, (BinConv2d, BinLinear)):
                nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5))

    def orthogonal_init(self):
        """Orthogonal init of every Conv2d / Linear weight (binarized subclasses included)."""
        print("Orthogonal Init")
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.orthogonal_(m.weight)

    def batch_norm_init(self):
        """Set every BatchNorm2d weight to 1 and bias to 0."""
        print("BatchNorm Init made weight as 1 and bias as 0")
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def binary_parameters(self):
        """Yield (and print the names of) non-bias parameters of binarized conv layers.

        NOTE(review): only names containing "Conv2d" match, so BinLinear weights are not
        yielded, although size_of_params counts them as binary -- confirm intended.
        """
        for name, layer in self.named_parameters():
            if "Bin" in name and "Conv2d" in name and "bias" not in name:
                print(name)
                yield layer

    def non_binary_parameters(self):
        """Yield non-bias BatchNorm parameters and full-precision Conv2d weights."""
        for name, layer in self.named_parameters():
            # Previously the literal "Conv2d in name" (always truthy) was tested, so
            # every non-"Bin" parameter, e.g. Linear weights, was yielded as well.
            if (("BN" in name) or ("Conv2d" in name and "Bin" not in name)) and ("bias" not in name):
                yield layer

    def forward(self, x):
        raise NotImplementedError
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    """Binarized MLP for 28x28 inputs flattened to 784 features; returns log-probabilities.

    Args:
        infl_ratio: width multiplier for the hidden layers (hidden size =
            2048 * infl_ratio). Defaults to 1, the previously hard-coded value.

    Linear_1 and Linear_4 are built with ``full_precision=True``; the two hidden
    BinLinear layers are binarized. Each linear layer is followed by BatchNorm1d,
    and the first three also by Hardtanh.
    """
    def __init__(self, infl_ratio=1):
        super(Net, self).__init__()
        self.infl_ratio = infl_ratio
        hidden = 2048 * self.infl_ratio
        self.Linear_1 = BinLinear(784, hidden, True)
        self.Act1 = nn.Hardtanh()
        self.BN1 = nn.BatchNorm1d(hidden)
        self.BinLinear_2 = BinLinear(hidden, hidden)
        self.Act2 = nn.Hardtanh()
        self.BN2 = nn.BatchNorm1d(hidden)
        self.BinLinear_3 = BinLinear(hidden, hidden)
        self.Act3 = nn.Hardtanh()
        self.BN3 = nn.BatchNorm1d(hidden)
        self.Linear_4 = BinLinear(hidden, 10, True)
        self.BN4 = nn.BatchNorm1d(10)
        # Explicit dim: the implicit-dim form is deprecated; for 2-D input it picked dim=1.
        self.logsoftmax = nn.LogSoftmax(dim=1)
        # NOTE(review): defined but never used in forward.
        self.drop = nn.Dropout(0.5)

    def forward(self, x):
        x = x.view(-1, 28*28)
        x = self.Linear_1(x)
        x = self.BN1(x)
        x = self.Act1(x)
        x = self.BinLinear_2(x)
        x = self.BN2(x)
        x = self.Act2(x)
        x = self.BinLinear_3(x)
        x = self.BN3(x)
        x = self.Act3(x)
        x = self.Linear_4(x)
        x = self.BN4(x)
        out = self.logsoftmax(x)
        return out

# Layer plans for the binarized VGG variants (consumed by their _make_layers):
# 'F' = full-precision 64-channel Conv2d, an int = BinConv2d with that many output
# channels, 'M' = 2x2 max-pool. Every plan has five 'M' entries and ends with 512 channels.
cfg_bin = {
    'VGG11': ['F', 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG13': ['F', 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG16': ['F', 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'VGG19': ['F', 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


class VGG_BNN_ReLU(BaseModel):
    """VGG-style binarized CNN with ReLU activations; ``forward`` returns log-probabilities.

    Args:
        vgg_name: key into ``cfg_bin`` ('VGG11', 'VGG13', 'VGG16', 'VGG19').
        nclass: number of output classes.
        img_width: side length of the square input images; used to size the classifier.
    """
    def __init__(self, vgg_name, nclass, img_width=32):
        super(VGG_BNN_ReLU, self).__init__()
        self.img_width = img_width
        self.nclass = nclass
        self.features = self._make_layers(cfg_bin[vgg_name])
        self.classifier = self._make_classifier()

    def forward(self, x):
        out = self.features(x)
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        return out

    def _make_layers(self, cfg):
        """Build the feature extractor from a layer plan.

        'M' adds a 2x2 max-pool, 'F' a full-precision 64-channel Conv2d and an int a
        BinConv2d with that many output channels; each conv is followed by BatchNorm2d
        and ReLU, and a Dropout(0.5) closes the stack. Also records the layer counters
        and the flattened feature size used by ``_make_classifier``.
        """
        layers = []
        iMax = 1
        iBn = 1
        iConv = 1
        iAct = 1
        in_channels = 3
        width = self.img_width
        for x in cfg:
            if x == 'M':
                layers.append(("Max_{}".format(iMax), nn.MaxPool2d(kernel_size=2, stride=2)))
                iMax = iMax + 1
                width = width // 2
            elif x == 'F':
                x = 64
                layers.append(("Conv2d_{}".format(iConv), nn.Conv2d(in_channels, x, kernel_size=3, padding=1)))
                layers.append(("BN_{}".format(iBn), nn.BatchNorm2d(x)))
                layers.append(("Act_{}".format(iAct), nn.ReLU(inplace=True)))
                iConv = iConv + 1
                iBn = iBn + 1
                iAct = iAct + 1
                in_channels = 64
            else:
                layers.append(("BinConv2d_{}".format(iConv), BinConv2d(in_channels, x, kernel_size=3, padding=1)))
                layers.append(("BN_{}".format(iBn), nn.BatchNorm2d(x)))
                layers.append(("Act_{}".format(iAct), nn.ReLU(inplace=True)))
                iConv = iConv + 1
                iBn = iBn + 1
                iAct = iAct + 1
                in_channels = x
        self.iBn = iBn
        self.iConv = iConv
        self.iAct = iAct
        # Flattened size seen by the classifier (512 for the stock plans at img_width=32);
        # previously hard-coded to 512, which broke other image widths.
        self.flat_features = in_channels * width * width
        layers.append(("Dropout_1", nn.Dropout(0.5)))
        layer_ord = OrderedDict(layers)
        return nn.Sequential(layer_ord)

    def _make_classifier(self):
        """Linear -> BatchNorm1d -> LogSoftmax head sized from ``self.flat_features``."""
        layers = []
        layers.append(("Linear_{}".format(self.iConv), nn.Linear(self.flat_features, self.nclass)))
        layers.append(("BN_{}".format(self.iBn), nn.BatchNorm1d(self.nclass)))
        # Explicit dim=1 (class axis); the implicit-dim form is deprecated.
        layers.append(("Softmax", nn.LogSoftmax(dim=1)))
        layer_ord = OrderedDict(layers)
        return nn.Sequential(layer_ord)
    
class VGG_BNN_PReLU(BaseModel):
    """VGG-style binarized CNN with PReLU activations; ``forward`` returns log-probabilities.

    Args:
        vgg_name: key into ``cfg_bin`` ('VGG11', 'VGG13', 'VGG16', 'VGG19').
        nclass: number of output classes.
        img_width: side length of the square input images; used to size the classifier.
    """
    def __init__(self, vgg_name, nclass, img_width=32):
        super(VGG_BNN_PReLU, self).__init__()
        self.img_width = img_width
        self.nclass = nclass
        self.features = self._make_layers(cfg_bin[vgg_name])
        self.classifier = self._make_classifier()

    def forward(self, x):
        out = self.features(x)
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        return out

    def _make_layers(self, cfg):
        """Build the feature extractor from a layer plan.

        'M' adds a 2x2 max-pool, 'F' a full-precision 64-channel Conv2d and an int a
        BinConv2d with that many output channels; each conv is followed by BatchNorm2d
        and PReLU, and a Dropout(0.5) closes the stack. Also records the layer counters
        and the flattened feature size used by ``_make_classifier``.
        """
        layers = []
        iMax = 1
        iBn = 1
        iConv = 1
        iAct = 1
        in_channels = 3
        width = self.img_width
        for x in cfg:
            if x == 'M':
                layers.append(("Max_{}".format(iMax), nn.MaxPool2d(kernel_size=2, stride=2)))
                iMax = iMax + 1
                width = width // 2
            elif x == 'F':
                x = 64
                layers.append(("Conv2d_{}".format(iConv), nn.Conv2d(in_channels, x, kernel_size=3, padding=1)))
                layers.append(("BN_{}".format(iBn), nn.BatchNorm2d(x)))
                layers.append(("Act_{}".format(iAct), nn.PReLU()))
                iConv = iConv + 1
                iBn = iBn + 1
                iAct = iAct + 1
                in_channels = 64
            else:
                layers.append(("BinConv2d_{}".format(iConv), BinConv2d(in_channels, x, kernel_size=3, padding=1)))
                layers.append(("BN_{}".format(iBn), nn.BatchNorm2d(x)))
                layers.append(("Act_{}".format(iAct), nn.PReLU()))
                iConv = iConv + 1
                iBn = iBn + 1
                iAct = iAct + 1
                in_channels = x
        self.iBn = iBn
        self.iConv = iConv
        self.iAct = iAct
        # Flattened size seen by the classifier (512 for the stock plans at img_width=32);
        # previously hard-coded to 512, which broke other image widths.
        self.flat_features = in_channels * width * width
        layers.append(("Dropout_1", nn.Dropout(0.5)))
        layer_ord = OrderedDict(layers)
        return nn.Sequential(layer_ord)

    def _make_classifier(self):
        """Linear -> BatchNorm1d -> LogSoftmax head sized from ``self.flat_features``."""
        layers = []
        layers.append(("Linear_{}".format(self.iConv), nn.Linear(self.flat_features, self.nclass)))
        layers.append(("BN_{}".format(self.iBn), nn.BatchNorm1d(self.nclass)))
        # Explicit dim=1 (class axis); the implicit-dim form is deprecated.
        layers.append(("Softmax", nn.LogSoftmax(dim=1)))
        layer_ord = OrderedDict(layers)
        return nn.Sequential(layer_ord)
# Layer plans for the full-precision VGG variants (consumed by VGG_ReLU._make_layers):
# an int = Conv2d with that many output channels, 'M' = 2x2 max-pool.
cfg = {
    'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}
class VGG_ReLU(BaseModel):
    """Full-precision VGG with BatchNorm and ReLU; ``forward`` returns log-probabilities.

    Args:
        vgg_name: key into ``cfg`` ('VGG11', 'VGG13', 'VGG16', 'VGG19').
        nclass: number of output classes.
        img_width: side length of the square input images; used to size the classifier.
    """
    def __init__(self, vgg_name, nclass, img_width=32):
        super(VGG_ReLU, self).__init__()
        self.img_width = img_width
        self.nclass = nclass
        self.features = self._make_layers(cfg[vgg_name])
        self.classifier = self._make_classifier()

    def forward(self, x):
        out = self.features(x)
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        return out

    def _make_layers(self, cfg):
        """Build the feature extractor from a layer plan.

        'M' adds a 2x2 max-pool and an int a Conv2d with that many output channels,
        followed by BatchNorm2d and ReLU; a Dropout(0.5) closes the stack. Also records
        the layer counters and the flattened feature size used by ``_make_classifier``.
        """
        layers = []
        iMax = 1
        iBn = 1
        iConv = 1
        iAct = 1
        in_channels = 3
        width = self.img_width
        for x in cfg:
            if x == 'M':
                layers.append(("Max_{}".format(iMax), nn.MaxPool2d(kernel_size=2, stride=2)))
                iMax = iMax + 1
                width = width // 2
            else:
                layers.append(("Conv2d_{}".format(iConv), nn.Conv2d(in_channels, x, kernel_size=3, padding=1)))
                layers.append(("BN_{}".format(iBn), nn.BatchNorm2d(x)))
                layers.append(("Act_{}".format(iAct), nn.ReLU(inplace=True)))
                iConv = iConv + 1
                iBn = iBn + 1
                iAct = iAct + 1
                in_channels = x
        self.iBn = iBn
        self.iConv = iConv
        self.iAct = iAct
        # Flattened size seen by the classifier (512 for the stock plans at img_width=32);
        # previously hard-coded to 512, which broke other image widths.
        self.flat_features = in_channels * width * width
        layers.append(("Dropout_1", nn.Dropout(0.5)))
        layer_ord = OrderedDict(layers)
        return nn.Sequential(layer_ord)

    def _make_classifier(self):
        """Linear -> BatchNorm1d -> LogSoftmax head sized from ``self.flat_features``."""
        layers = []
        layers.append(("Linear_{}".format(self.iConv), nn.Linear(self.flat_features, self.nclass)))
        layers.append(("BN_{}".format(self.iBn), nn.BatchNorm1d(self.nclass)))
        # Explicit dim=1 (class axis); the implicit-dim form is deprecated.
        layers.append(("Softmax", nn.LogSoftmax(dim=1)))
        layer_ord = OrderedDict(layers)
        return nn.Sequential(layer_ord)
# CNN AlexNet (pass num_classes when instantiating)
class AlexNet(BaseModel):
    """Full-precision AlexNet variant that returns raw logits (no softmax).

    Args:
        num_classes: number of output classes (default 1000).

    No BatchNorm layers are used. NOTE(review): ``fullyconnected`` also registers
    "Pool", "drop_6" and "drop_7", but ``forward`` never calls them, so no
    dropout is applied -- confirm whether that is intended.
    """
    # Keys applied, in order, in forward.
    _FEATURE_ORDER = ("Conv2d_1", "Act_1", "Max_1",
                      "Conv2d_2", "Act_2", "Max_2",
                      "Conv2d_3", "Act_3",
                      "Conv2d_4", "Act_4",
                      "Conv2d_5", "Act_5", "Max_5")
    _HEAD_ORDER = ("Linear_6", "Act_6", "Linear_7", "Act_7", "Linear_8")

    def __init__(self, num_classes=1000):
        super(AlexNet, self).__init__()
        feature_layers = OrderedDict()
        feature_layers["Conv2d_1"] = nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2)
        feature_layers["Act_1"] = nn.ReLU(inplace=True)
        feature_layers["Max_1"] = nn.MaxPool2d(kernel_size=3, stride=2)
        feature_layers["Conv2d_2"] = nn.Conv2d(64, 192, kernel_size=5, padding=2)
        feature_layers["Act_2"] = nn.ReLU(inplace=True)
        feature_layers["Max_2"] = nn.MaxPool2d(kernel_size=3, stride=2)
        feature_layers["Conv2d_3"] = nn.Conv2d(192, 384, kernel_size=3, padding=1)
        feature_layers["Act_3"] = nn.ReLU(inplace=True)
        feature_layers["Conv2d_4"] = nn.Conv2d(384, 256, kernel_size=3, padding=1)
        feature_layers["Act_4"] = nn.ReLU(inplace=True)
        feature_layers["Conv2d_5"] = nn.Conv2d(256, 256, kernel_size=3, padding=2)
        feature_layers["Act_5"] = nn.ReLU(inplace=True)
        feature_layers["Max_5"] = nn.MaxPool2d(kernel_size=3, stride=2)
        self.features = nn.ModuleDict(feature_layers)

        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))

        head_layers = OrderedDict()
        head_layers["Pool"] = nn.AdaptiveAvgPool2d((6, 6))
        head_layers["drop_6"] = nn.Dropout()
        head_layers["Linear_6"] = nn.Linear(256 * 6 * 6, 4096)
        head_layers["Act_6"] = nn.ReLU(inplace=True)
        head_layers["drop_7"] = nn.Dropout()
        head_layers["Linear_7"] = nn.Linear(4096, 4096)
        head_layers["Act_7"] = nn.ReLU(inplace=True)
        head_layers["Linear_8"] = nn.Linear(4096, num_classes)
        self.fullyconnected = nn.ModuleDict(head_layers)

    def forward(self, x):
        for key in self._FEATURE_ORDER:
            x = self.features[key](x)
        x = self.avgpool(x)
        x = x.view(-1, 256 * 6 * 6)
        for key in self._HEAD_ORDER:
            x = self.fullyconnected[key](x)
        return x
# BNN AlexNet (pass num_classes when instantiating)
class BinAlexNet(BaseModel):
    """Binarized AlexNet variant with BatchNorm; ``forward`` returns log-probabilities.

    Args:
        num_classes: number of output classes (default 1000).

    The first conv and the final linear layer stay full precision; Conv 2-5 are
    BinConv2d and the two hidden fully-connected layers are BinLinear.
    """
    def __init__(self, num_classes=1000):
        super(BinAlexNet, self).__init__()
        self.features = nn.ModuleDict({
            "Conv2d_1": nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            "BN_1": nn.BatchNorm2d(64),
            "Act_1": nn.ReLU(inplace=True),
            "Max_1": nn.MaxPool2d(kernel_size=3, stride=2),
            "BinConv2d_2": BinConv2d(64, 192, kernel_size=5, padding=2),
            "BN_2": nn.BatchNorm2d(192),
            "Act_2": nn.ReLU(inplace=True),
            "Max_2": nn.MaxPool2d(kernel_size=3, stride=2),
            "BinConv2d_3": BinConv2d(192, 384, kernel_size=3, padding=1),
            "BN_3": nn.BatchNorm2d(384),
            "Act_3": nn.ReLU(inplace=True),
            "BinConv2d_4": BinConv2d(384, 256, kernel_size=3, padding=1),
            "BN_4": nn.BatchNorm2d(256),
            "Act_4": nn.ReLU(inplace=True),
            "BinConv2d_5": BinConv2d(256, 256, kernel_size=3, padding=2),
            "BN_5": nn.BatchNorm2d(256),
            "Act_5": nn.ReLU(inplace=True),
            "Max_5": nn.MaxPool2d(kernel_size=3, stride=2),
        })

        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))

        self.fullyconnected = nn.ModuleDict({
            "BinLinear_6": BinLinear(256 * 6 * 6, 4096),
            "BN_6": nn.BatchNorm1d(4096),
            "Act_6": nn.ReLU(inplace=True),
            "BinLinear_7": BinLinear(4096, 4096),
            "BN_7": nn.BatchNorm1d(4096),
            "Act_7": nn.ReLU(inplace=True),
            "Linear_8": nn.Linear(4096, num_classes),
            "BN_8": nn.BatchNorm1d(num_classes),
            # Explicit dim=1 (class axis); the implicit-dim form is deprecated.
            "Softmax": nn.LogSoftmax(dim=1),
        })

    def forward(self, x):
        x = self.features['Conv2d_1'](x)
        x = self.features['BN_1'](x)
        x = self.features['Act_1'](x)
        x = self.features['Max_1'](x)
        x = self.features['BinConv2d_2'](x)
        x = self.features['BN_2'](x)
        x = self.features['Act_2'](x)
        x = self.features['Max_2'](x)
        x = self.features['BinConv2d_3'](x)
        x = self.features['BN_3'](x)
        x = self.features['Act_3'](x)
        x = self.features['BinConv2d_4'](x)
        x = self.features['BN_4'](x)
        x = self.features['Act_4'](x)
        x = self.features['BinConv2d_5'](x)
        x = self.features['BN_5'](x)
        x = self.features['Act_5'](x)
        x = self.features['Max_5'](x)
        x = self.avgpool(x)
        x = x.view(-1, 256 * 6 * 6)
        x = self.fullyconnected['BinLinear_6'](x)
        x = self.fullyconnected['BN_6'](x)
        x = self.fullyconnected['Act_6'](x)
        x = self.fullyconnected['BinLinear_7'](x)
        x = self.fullyconnected['BN_7'](x)
        x = self.fullyconnected['Act_7'](x)
        x = self.fullyconnected['Linear_8'](x)
        x = self.fullyconnected['BN_8'](x)
        x = self.fullyconnected['Softmax'](x)
        return x

#CNN net
class BasicBlock(nn.Module):
    """Two-conv residual block (ResNet-18/34 style) with full-precision convolutions.

    Args:
        in_planes: channels of the incoming tensor.
        planes: channels produced by both 3x3 convolutions.
        stride: stride of the first conv and of the projection shortcut.

    When the stride is not 1 or the channel count changes, the shortcut is a 1x1
    conv + BatchNorm projection; otherwise it is the identity.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        out_planes = self.expansion * planes
        self.Conv2d_1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.BN1 = nn.BatchNorm2d(planes)
        self.Conv2d_2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.BN2 = nn.BatchNorm2d(planes)

        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            self.shortcut = nn.Sequential(OrderedDict([
                ("Conv2d_3", nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)),
                ("BN_3", nn.BatchNorm2d(out_planes)),
            ]))
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        hidden = F.relu(self.BN1(self.Conv2d_1(x)))
        hidden = self.BN2(self.Conv2d_2(hidden))
        return F.relu(hidden + self.shortcut(x))

class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual bottleneck block (ResNet-50+ style), output width 4 * planes.

    Args:
        in_planes: channels of the incoming tensor.
        planes: inner width of the block; the output has ``expansion * planes`` channels.
        stride: stride of the 3x3 conv and of the projection shortcut.

    When the stride is not 1 or the channel count changes, the shortcut is a 1x1
    conv + BatchNorm projection; otherwise it is the identity.
    """
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        out_planes = self.expansion * planes
        self.Conv2d_1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.BN1 = nn.BatchNorm2d(planes)
        self.Conv2d_2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.BN2 = nn.BatchNorm2d(planes)
        self.Conv2d_3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.BN3 = nn.BatchNorm2d(out_planes)

        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            self.shortcut = nn.Sequential(OrderedDict([
                ("Conv2d_4", nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)),
                ("BN_4", nn.BatchNorm2d(out_planes)),
            ]))
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        hidden = F.relu(self.BN1(self.Conv2d_1(x)))
        hidden = F.relu(self.BN2(self.Conv2d_2(hidden)))
        hidden = self.BN3(self.Conv2d_3(hidden))
        return F.relu(hidden + self.shortcut(x))

class ResNet(BaseModel):
    """CIFAR-style ResNet: 3x3 stem, four residual stages, global average pool, linear head.

    Args:
        block: residual block class (e.g. ``BasicBlock`` or ``Bottleneck``) with an
            ``expansion`` class attribute giving its output-channel multiplier.
        num_blocks: four ints, the number of blocks in each stage.
        num_classes: size of the output layer.

    ``forward`` returns the raw output of ``Linear_2`` (no softmax), with shape
    ``(batch, num_classes)``.
    NOTE(review): the stem is ``Conv2d(3, 64, ...)``, so 3-channel input is assumed.
    """
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.Conv2d_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.BN1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.Linear_2 = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first uses ``stride``, the rest use 1.

        Updates ``self.in_planes`` so the next stage is wired to this stage's width.
        """
        layers = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.BN1(self.Conv2d_1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Global average pool over whatever spatial size layer4 produces. For
        # 32x32 inputs layer4's map is 4x4, so this equals the former
        # avg_pool2d(out, 4). A fixed 4x4 kernel breaks the Linear_2 input size
        # for other resolutions; this does not.
        out = F.adaptive_avg_pool2d(out, 1)
        out = torch.flatten(out, 1)
        out = self.Linear_2(out)
        return out

def ResNet18(num_classes=10):
    """ResNet-18: BasicBlock with [2, 2, 2, 2] blocks per stage."""
    return ResNet(BasicBlock, [2,2,2,2], num_classes=num_classes)

def ResNet34(num_classes=10):
    """ResNet-34: BasicBlock with [3, 4, 6, 3] blocks per stage."""
    return ResNet(BasicBlock, [3,4,6,3], num_classes=num_classes)

def ResNet50(num_classes=10):
    """ResNet-50: Bottleneck with [3, 4, 6, 3] blocks per stage."""
    return ResNet(Bottleneck, [3,4,6,3], num_classes=num_classes)

def ResNet101(num_classes=10):
    """ResNet-101: Bottleneck with [3, 4, 23, 3] blocks per stage."""
    return ResNet(Bottleneck, [3,4,23,3], num_classes=num_classes)

def ResNet152(num_classes=10):
    """ResNet-152: Bottleneck with [3, 8, 36, 3] blocks per stage."""
    return ResNet(Bottleneck, [3,8,36,3], num_classes=num_classes)

def test():
    """Smoke test: run ResNet18 on one random 3x32x32 image and print the output size."""
    net = ResNet18()
    print(net(torch.randn(1, 3, 32, 32)).size())
# Binary neural network (BNN) variants of the ResNet blocks and model, built on BinConv2d
class BinBasicBlock(nn.Module):
    """Two-convolution residual block whose convolutions are ``BinConv2d`` layers.

    The output has ``expansion * planes`` (= ``planes``) channels. When the stride
    is 1 and the channel count is unchanged, the skip path is the identity (an
    empty ``nn.Sequential``). Otherwise it is a strided 1x1 ``BinConv2d`` plus
    batch norm (``shortcut.BinConv2d_3`` / ``shortcut.BN_3``).
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BinBasicBlock, self).__init__()
        out_planes = self.expansion * planes
        self.BinConv2d_1 = BinConv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.BN1 = nn.BatchNorm2d(planes)
        self.BinConv2d_2 = BinConv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.BN2 = nn.BatchNorm2d(planes)

        if stride == 1 and in_planes == out_planes:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(OrderedDict([
                ("BinConv2d_3", BinConv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)),
                ("BN_3", nn.BatchNorm2d(out_planes)),
            ]))

    def forward(self, x):
        branch = F.relu(self.BN1(self.BinConv2d_1(x)))
        branch = self.BN2(self.BinConv2d_2(branch))
        branch += self.shortcut(x)
        return F.relu(branch)

class BinBottleneck(nn.Module):
    """Bottleneck residual block (1x1 -> 3x3 -> 1x1) built from ``BinConv2d`` layers.

    The output has ``expansion * planes`` channels. The skip path is the identity
    when the stride is 1 and the channel count is unchanged. Otherwise it is a
    strided 1x1 ``BinConv2d`` plus batch norm (``shortcut.BinConv2d_4`` / ``shortcut.BN_4``).
    """
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(BinBottleneck, self).__init__()
        out_planes = self.expansion * planes
        self.BinConv2d_1 = BinConv2d(in_planes, planes, kernel_size=1, bias=False)
        self.BN1 = nn.BatchNorm2d(planes)
        self.BinConv2d_2 = BinConv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.BN2 = nn.BatchNorm2d(planes)
        self.BinConv2d_3 = BinConv2d(planes, out_planes, kernel_size=1, bias=False)
        self.BN3 = nn.BatchNorm2d(out_planes)

        if stride == 1 and in_planes == out_planes:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(OrderedDict([
                ("BinConv2d_4", BinConv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)),
                ("BN_4", nn.BatchNorm2d(out_planes)),
            ]))

    def forward(self, x):
        y = x
        # First two conv/BN stages each end in a ReLU; the third does not,
        # because the ReLU is applied after the residual addition.
        for conv, bn in ((self.BinConv2d_1, self.BN1), (self.BinConv2d_2, self.BN2)):
            y = F.relu(bn(conv(y)))
        y = self.BN3(self.BinConv2d_3(y))
        y += self.shortcut(x)
        return F.relu(y)

class BinResNet(BaseModel):
    """ResNet whose residual stages are built from ``block`` (intended: BinBasicBlock/BinBottleneck).

    The stem (``nn.Conv2d``) and classifier (``nn.Linear``) are ordinary layers.
    The classifier output goes through ``BatchNorm1d`` and then ``LogSoftmax``
    over the class dimension, so ``forward`` returns log-probabilities of shape
    ``(batch, num_classes)``.

    Args:
        block: residual block class with an ``expansion`` class attribute.
        num_blocks: four ints, the number of blocks in each stage.
        num_classes: number of output classes.

    NOTE: in training mode ``BatchNorm1d`` on the logits needs a batch of at
    least 2 samples.
    """
    def __init__(self, block, num_blocks, num_classes=10):
        super(BinResNet, self).__init__()
        self.in_planes = 64

        self.Conv2d_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.BN1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.Linear_2 = nn.Linear(512*block.expansion, num_classes)
        self.BN2 = nn.BatchNorm1d(num_classes)
        # Explicit dim: LogSoftmax() without dim is deprecated. For the 2-D
        # (batch, classes) input used here, the implicit choice was dim=1 as well.
        self.softmax = nn.LogSoftmax(dim=1)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first uses ``stride``, the rest use 1.

        Updates ``self.in_planes`` so the next stage is wired to this stage's width.
        """
        layers = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.BN1(self.Conv2d_1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Global average pool. For 32x32 inputs layer4's map is 4x4, so this
        # equals the former avg_pool2d(out, 4); it also accepts other resolutions.
        out = F.adaptive_avg_pool2d(out, 1)
        out = torch.flatten(out, 1)
        out = self.Linear_2(out)
        out = self.BN2(out)
        out = self.softmax(out)
        return out

def BinResNet18(num_classes=10):
    """Binarized ResNet-18: BinBasicBlock with [2, 2, 2, 2] blocks per stage."""
    return BinResNet(BinBasicBlock, [2,2,2,2], num_classes=num_classes)

def BinResNet34(num_classes=10):
    """Binarized ResNet-34: BinBasicBlock with [3, 4, 6, 3] blocks per stage."""
    return BinResNet(BinBasicBlock, [3,4,6,3], num_classes=num_classes)

def BinResNet50(num_classes=10):
    """Binarized ResNet-50: BinBottleneck with [3, 4, 6, 3] blocks per stage."""
    return BinResNet(BinBottleneck, [3,4,6,3], num_classes=num_classes)

def BinResNet101(num_classes=10):
    """Binarized ResNet-101: BinBottleneck with [3, 4, 23, 3] blocks per stage."""
    return BinResNet(BinBottleneck, [3,4,23,3], num_classes=num_classes)

def BinResNet152(num_classes=10):
    """Binarized ResNet-152: BinBottleneck with [3, 8, 36, 3] blocks per stage."""
    return BinResNet(BinBottleneck, [3,8,36,3], num_classes=num_classes)

def Bintest():
    """Smoke test: run BinResNet18 on random 3x32x32 images and print the output size.

    A batch of 2 is used because the model is in training mode right after
    construction, and its ``BatchNorm1d`` on the logits rejects a batch of 1.
    """
    net = BinResNet18()
    y = net(torch.randn(2, 3, 32, 32))
    print(y.size())

def filt_alpha_weight(model):
    """Collect trainable ``alpha`` parameters and the weights of the layers that own them.

    A layer is identified by the part of a parameter name before ``'alpha'`` (or
    before ``'weight'``). For example, ``'layer1.0.BinConv2d_1.alpha'`` and
    ``'layer1.0.BinConv2d_1.weight'`` both have the prefix ``'layer1.0.BinConv2d_1.'``.

    Returns:
        (alpha, weights): two lists of parameters, each in ``model.named_parameters()``
        order. ``alpha`` holds every trainable parameter whose name contains
        ``'alpha'``. ``weights`` holds every trainable parameter whose name contains
        ``'weight'`` but not ``'BN'``, and whose prefix matches an alpha prefix.

    NOTE(review): ``get_reg`` zips these lists, so they only pair up correctly
    when each such layer has exactly one trainable alpha and one trainable weight.
    """
    # Renamed the old local ``nn`` (it shadowed the torch.nn module).
    # A set replaces the dict-of-empty-lists / list(keys()) membership test.
    trainable = [(name, param) for name, param in model.named_parameters() if param.requires_grad]

    alpha = []
    alpha_prefixes = set()
    for name, param in trainable:
        if 'alpha' in name:
            alpha_prefixes.add(name.split('alpha')[0])
            alpha.append(param)

    weights = [
        param for name, param in trainable
        if 'weight' in name and 'BN' not in name
        and name.split('weight')[0] in alpha_prefixes
    ]
    return alpha, weights
import math
def get_reg(model,reg_type,epoch,lr,eta = 1):
    """Return the alpha/weight regularization term that is added to the training loss.

    The result is ``eta * lr * log(epoch) * R``. ``R`` sums a per-pair penalty
    over ``zip(alpha, weights)`` from ``filt_alpha_weight(model)``:

    * ``'l1'``   : sum(| |w| - a |)
    * ``'l2'``   : sum(| w*w - a |)
    * ``'l2_2'`` : sum(w*w - a)
    * ``'tmr1'`` : | sum(w*w - a) |
    * ``'tmr2'`` : sum(w*w - a) ** 2

    Any other ``reg_type`` (including ``None``) gives ``R = 0``.

    Args:
        model: module searched by ``filt_alpha_weight``.
        reg_type: one of the keys above.
        epoch: epoch number. Values below 1 are treated as 1, which makes the
            weight 0 (log 1 = 0). Previously ``math.log`` raised ``ValueError``
            for the epoch 0 produced by a 0-based training loop.
        lr: learning rate used to scale the term.
        eta: extra scale factor.
    """
    penalties = {
        'l1': lambda a, w: torch.sum(torch.abs(torch.abs(w) - a)),
        'l2': lambda a, w: torch.sum(torch.abs(torch.mul(w, w) - a)),
        'l2_2': lambda a, w: torch.sum(torch.mul(w, w) - a),
        'tmr1': lambda a, w: torch.abs(torch.sum(torch.mul(w, w) - a)),
        'tmr2': lambda a, w: torch.sum(torch.mul(w, w) - a) ** 2,
    }
    with torch.enable_grad():
        l = eta * lr * math.log(max(epoch, 1))
        alpha, weights = filt_alpha_weight(model)
        penalty = penalties.get(reg_type)
        reg = 0
        if penalty is not None:
            for a, w in zip(alpha, weights):
                reg += penalty(a, w)
    return l*reg
import numpy as np
import torch

class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""
    def __init__(self, patience=7, verbose=False, delta=0,filename = 'checkpoint.pt'):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement. 
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            filename (str): Path the best model's ``state_dict`` is saved to.
                            Default: 'checkpoint.pt'
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # np.Inf was removed in NumPy 2.0; np.inf works on every version.
        self.val_loss_min = np.inf
        self.delta = delta
        # Previously ignored: save_checkpoint read a global ``filename`` instead.
        self.filename = filename

    def __call__(self, val_loss, model):
        """Record one validation loss; checkpoint on improvement, else count toward patience."""
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.filename)
        self.val_loss_min = val_loss
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k.

    Args:
        output: scores with classes along dim 1 (top-k is taken over dim 1) and
            one row per sample.
        target: true class indices, one per sample.
        topk: tuple of k values to evaluate.

    Returns:
        A list with one 1-element tensor per k: the percentage of samples whose
        target is among the k highest-scoring classes.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            # `correct` has the transposed (non-contiguous) layout of `pred`, so
            # for k > 1, .view(-1) raises. reshape copies only when it must.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

class AverageMeter(object):
    """Track a metric's latest value and its running count-weighted average.

    Args:
        name: label printed by ``str()``.
        fmt: format spec (including the leading ``':'``) applied to both the
            latest value and the average, e.g. ``':6.2f'``.
    """

    def __init__(self, name, fmt=':f'):
        self.name, self.fmt = name, fmt
        self.reset()

    def reset(self):
        """Zero the latest value, average, running sum and sample count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean over ``n`` samples and refresh the average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        template = '{name} {val%s} ({avg%s})' % (self.fmt, self.fmt)
        return template.format(**self.__dict__)
        
class ProgressMeter(object):
    """Print one tab-separated progress line: ``prefix[batch/total]`` followed by each meter.

    Args:
        num_batches: total shown after the slash; also sets the batch field width.
        meters: objects rendered with ``str()`` (e.g. ``AverageMeter``).
        prefix: text placed before the ``[batch/total]`` counter.
    """

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        head = self.prefix + self.batch_fmtstr.format(batch)
        print('\t'.join([head] + [str(meter) for meter in self.meters]))

    def _get_batch_fmtstr(self, num_batches):
        # Pad the current batch number to the width of the total.
        width = len(str(num_batches // 1))
        field = '{:%dd}' % width
        return '[{}/{}]'.format(field, field.format(num_batches))
def train_model(model, criterion, optimizer,num_epochs=25,**kwargs):
    """Train ``model`` for up to ``num_epochs`` epochs, validating after each one.

    NOTE(review): this function reads notebook globals instead of parameters:
    ``dataloaders`` (indexed by 'train' / 'val'), ``is_debug``,
    ``trainframe`` / ``testframe`` (per-epoch history DataFrames), ``filename``
    (checkpoint path and spreadsheet stem), ``lr`` (scale of the regularizer),
    ``device`` and ``dtype``.

    Args:
        model: network to train. Before returning, its weights are reset to the
            epoch with the best validation top-1 accuracy.
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: stepped once per training batch.
        num_epochs: maximum number of epochs; early stopping may end sooner.
        **kwargs: optional switches. ``is_early_stopping`` (default True),
            ``is_sch_epoch`` / ``is_sch_acc`` / ``is_sch_valloss`` (default
            False), ``is_reg`` (default False), ``reg_type`` (default None),
            ``is_drive`` (default True: write spreadsheets under
            "drive/My Drive/Model Results/"), and ``printfreq`` (default 100:
            batches between progress lines).

    Returns:
        The model with the best-validation-accuracy weights loaded.
    """
    since = time.time()
    best_acc = 0.0
    best_model_wts = copy.deepcopy(model.state_dict())
    is_early_stopping = kwargs.get('is_early_stopping',True)
    is_sch_epoch = kwargs.get('is_sch_epoch',False)
    is_sch_acc = kwargs.get('is_sch_acc',False)
    is_sch_valloss = kwargs.get('is_sch_valloss',False)
    is_reg = kwargs.get('is_reg',False)
    is_drive = kwargs.get('is_drive',True)
    printfreq = kwargs.get('printfreq',100)
    reg_type = kwargs.get("reg_type",None)

    if is_early_stopping:
        early_stopping = EarlyStopping(patience=40, verbose=True,filename = filename)
    if is_sch_epoch:
        scheduler_epoch = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[80,160,220,300], gamma=0.1)
    if is_sch_valloss:
        scheduler_valloss = optim.lr_scheduler.ReduceLROnPlateau(optimizer,mode='min',factor = 0.1,patience=15, verbose=True)
    if is_sch_acc:
        scheduler_acc = optim.lr_scheduler.ReduceLROnPlateau(optimizer,mode='max',factor = 0.1,patience=15, verbose=True)

    stop_training = False
    for epoch in range(num_epochs):
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            losses = AverageMeter('Loss', ':.4e')
            top1 = AverageMeter('Acc@1', ':6.2f')
            top5 = AverageMeter('Acc@5', ':6.2f')
            # ProgressMeter counts batches, so size it by the loader's batch count.
            progress = ProgressMeter(len(dataloaders[phase]),[losses, top1, top5],prefix='{}: '.format(phase))
            if phase == 'train':
                if is_debug: print("Train")
                model.train()  # Set model to training mode
            else:
                if is_debug: print("Validation")
                model.eval()   # Set model to evaluate mode

            # Iterate over data. batch_itr is 1-based so printfreq counts whole batches.
            for batch_itr, (inputs, labels) in enumerate(dataloaders[phase], start=1):
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward; track history only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs.type(dtype))

                    if is_reg and phase == 'train':
                        # `epoch` is 0-based here, but get_reg takes log(epoch),
                        # which is undefined at 0. Pass a 1-based epoch.
                        reg = get_reg(model,reg_type,epoch + 1,lr=lr,eta = 1)
                        if is_debug: print("Reg loss is {}".format(reg))
                    else:
                        reg = 0
                    loss = criterion(outputs, labels) + reg

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                acc1, acc5 = accuracy(outputs, labels, topk=(1, 5))
                losses.update(loss.detach().item(),inputs.size(0))
                top1.update(acc1[0], inputs.size(0))
                top5.update(acc5[0], inputs.size(0))
                if batch_itr % printfreq == 0:
                    progress.display(batch_itr)

            # Convert accumulated 0-dim tensors to plain floats. This also works
            # when the loader was empty and the meters are still int 0.
            epoch_acc = float(top1.avg)
            history = trainframe if phase == 'train' else testframe
            history.loc[epoch] = [int(epoch), losses.avg, epoch_acc, float(top5.avg)]

            if phase == 'val':
                # deep copy the model
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())
                if is_early_stopping:
                    early_stopping(losses.avg, model)
                    if early_stopping.early_stop:
                        print("Early stopping")
                        # Previously a bare `break` here only left the phase loop,
                        # so training never actually stopped.
                        stop_training = True
            print("{} epoch : {}, accuray: {},Loss : {} ".format(phase,int(epoch),epoch_acc,losses.avg))

        # Schedulers step once per epoch, using the validation-phase meters
        # (the last phase run above).
        if is_sch_epoch:
            # The epoch argument to step() is deprecated; step() advances one epoch.
            scheduler_epoch.step()
        if is_sch_valloss:
            scheduler_valloss.step(losses.avg)
        if is_sch_acc:
            scheduler_acc.step(float(top1.avg))

        out_dir = "drive/My Drive/Model Results/" if is_drive else ""
        trainframe.to_excel("{}{}_train.xlsx".format(out_dir, filename))
        testframe.to_excel("{}{}_test.xlsx".format(out_dir, filename))

        if stop_training:
            break

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model