# Либы

In [None]:
!pip install torchprofile
#!pip install torchsummary

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torchprofile
  Downloading torchprofile-0.0.4-py3-none-any.whl (7.7 kB)
Installing collected packages: torchprofile
Successfully installed torchprofile-0.0.4


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function, Variable

import math
import numpy as np
import torch.utils.model_zoo as model_zoo
import torch.nn.init as init
from scipy.stats import ortho_group

import torch.utils
import torch.utils.data
from torchvision import datasets, transforms
from torch.cuda.amp import autocast as autocast
import torchvision.models as models
import os, argparse, logging,sys
import random
import time
import shutil
import torch.backends.cudnn as cudnn
import torch.nn.parallel
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn.utils.prune as prune
from sklearn.metrics import accuracy_score, recall_score, f1_score, roc_auc_score

from torchsummary import summary
from torchprofile import profile
import matplotlib.pyplot as plt

import warnings
warnings.filterwarnings("ignore", category=UserWarning)


# Use the first CUDA GPU when one is available; otherwise fall back to the CPU
# so the notebook can still be executed on machines without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
# Seed every random number generator used in the notebook with 0.
torch.manual_seed(0)
np.random.seed(0)
random.seed(0)
if torch.cuda.is_available():
    # Covers every visible CUDA device.
    torch.cuda.manual_seed_all(0)

# Disable cuDNN autotuning and request deterministic cuDNN kernels.
cudnn.benchmark = False
cudnn.deterministic = True

In [None]:
# Mount Google Drive into the Colab filesystem at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


# Utils

In [None]:
def top1_accuracy(targets, outputs):
    """Return the fraction of rows whose highest score is the target class.

    Args:
        targets: NumPy array of class indices, one per sample.
        outputs: NumPy array of per-class scores, indexed ``[sample, class]``.

    Returns:
        Number of correct top-1 predictions divided by the number of targets.
    """
    scores = torch.from_numpy(outputs)
    labels = torch.from_numpy(targets)
    # Index of the maximum score along the class dimension.
    best_class = scores.max(dim=1).indices
    hits = (best_class == labels).sum().item()
    return hits / labels.size(0)

def top5_accuracy(targets, outputs):
    """Return the fraction of rows whose target is among the five best scores.

    Args:
        targets: NumPy array of class indices, one per sample.
        outputs: NumPy array of per-class scores, indexed ``[sample, class]``;
            ``torch.topk`` with ``k=5`` is applied along dimension 1.

    Returns:
        Number of samples whose target appears in the top five, divided by the
        number of targets.
    """
    scores = torch.from_numpy(outputs)
    labels = torch.from_numpy(targets)
    top5_classes = scores.topk(5, dim=1).indices
    # Compare each row's five candidates against that row's label.
    matches = top5_classes == labels.view(-1, 1)
    return matches.sum().item() / labels.size(0)

class StochasticDepth(nn.Module):
    """Randomly replace the whole input with zeros during training.

    In evaluation mode the input is returned unchanged. In training mode one
    uniform sample is drawn per call: if it exceeds ``p`` the input is passed
    through, otherwise a zero tensor of the same shape is returned. The kept
    output is not rescaled.

    Args:
        p: threshold compared against a sample from ``torch.rand``; larger
            values make dropping more likely.
    """

    def __init__(self, p):
        super(StochasticDepth, self).__init__()
        self.p = p

    def forward(self, x):
        if not self.training:
            return x
        if torch.rand(1).item() > self.p:
            # Return a copy that stays attached to the autograd graph.
            # Detaching here would stop gradients from reaching every layer
            # that feeds this module.
            return x.clone()
        # zeros_like already matches the dtype and device of x.
        return torch.zeros_like(x)

def custom_pruning(model):
    """Apply L1 magnitude pruning to ``model`` as selected by ``args.type_pruning``.

    * ``'local'``: each ``Conv2d`` module has 20% and each ``Linear`` module
      40% of its weights pruned, layer by layer.
    * ``'global'``: 20% of the weights are pruned across a group of layers
      chosen from ``args.architecture`` / ``args.bin_architecture``. The
      ``resnet20`` configurations are skipped and return ``0``.
    * any other value prints ``'error'``.

    Returns:
        ``None``, except ``0`` for the skipped ``resnet20`` case.
    """
    mode = args.type_pruning

    if mode == 'local':
        # Local pruning: weights are zeroed independently inside each layer.
        for _, layer in model.named_modules():
            if isinstance(layer, torch.nn.Conv2d):
                prune.l1_unstructured(layer, name='weight', amount=0.2)
            elif isinstance(layer, torch.nn.Linear):
                prune.l1_unstructured(layer, name='weight', amount=0.4)
        return None

    if mode != 'global':
        print('error')
        return None

    if args.architecture == 'cnn_model()' or args.bin_architecture == 'cnn_model_binary()':
        layer_names = ('conv1', 'conv2', 'conv3', 'conv4', 'fc1', 'fc2')
        parameters_to_prune = tuple((getattr(model, layer_name), 'weight') for layer_name in layer_names)
    elif args.architecture == 'resnet20()' or args.bin_architecture == 'resnet20_binary()':
        return 0
    else:
        # Exact type match: subclasses of Conv2d are not included here.
        parameters_to_prune = [(m, "weight") for m in model.modules() if type(m) == torch.nn.Conv2d]

    prune.global_unstructured(
        parameters_to_prune,
        pruning_method=prune.L1Unstructured,
        amount=0.2,
    )
    return None

import torch
import torch.nn as nn

class LabelSmoothingCrossEntropyLoss(nn.Module):
    """Classification loss against a smoothed one-hot target distribution.

    The target class receives ``1 - smoothing`` and every other class
    receives ``smoothing / (num_classes - 1)``. With ``smoothing > 0`` the
    loss is ``KLDivLoss(reduction='batchmean')`` on log-softmax predictions;
    otherwise it is ``CrossEntropyLoss`` on the raw predictions with the
    (then exactly one-hot) float target.

    Args:
        smoothing: probability mass moved away from the target class.
        num_classes: number of classes used to spread the smoothing mass.
    """

    def __init__(self, smoothing=0.0, num_classes=10):
        super().__init__()
        self.smoothing = smoothing
        self.num_classes = num_classes
        self.confidence = 1.0 - smoothing
        self.criterion = nn.KLDivLoss(reduction='batchmean') if smoothing > 0 else nn.CrossEntropyLoss()

    def forward(self, pred, target):
        """Return the loss for scores ``pred`` and integer class indices ``target``."""
        off_target_mass = self.smoothing / (self.num_classes - 1)
        soft_target = torch.full_like(pred, off_target_mass)
        soft_target.scatter_(1, target.unsqueeze(1).long(), self.confidence)

        # KLDivLoss expects log-probabilities; CrossEntropyLoss takes raw scores.
        model_input = pred.log_softmax(dim=1) if self.smoothing > 0 else pred
        return self.criterion(model_input, soft_target)

def get_data():
    """Build the training and validation ``DataLoader`` objects for ``args.dataset``.

    Reads ``args.dataset`` (``'CIFAR10'`` or ``'MNIST'``), ``args.data_size``,
    ``args.rand_augment``, ``args.batch_size`` and ``args.workers``. The
    datasets are downloaded to ``"../"`` if they are not already there.

    Returns:
        ``(train_loader, val_loader)``. The training loader shuffles and uses
        RandAugment; the validation loader only resizes and normalizes.

    Raises:
        ValueError: if ``args.dataset`` is not one of the supported names.
    """
    if args.dataset == 'CIFAR10':
        # NOTE(review): these look like the usual ImageNet channel statistics
        # rather than CIFAR-10 specific values - confirm this is intended.
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
    elif args.dataset == 'MNIST':
        normalize = transforms.Normalize(mean=[0.1307], std=[0.3081])
    else:
        # Fail early with a clear message instead of hitting an unbound
        # `transform` further down.
        raise ValueError(f"Unsupported dataset: {args.dataset!r} (expected 'CIFAR10' or 'MNIST')")

    # RandomHorizontalFlip / RandomCrop were used here before and are
    # disabled in favour of RandAugment.
    train_transform = transforms.Compose([
        transforms.Resize(args.data_size),
        transforms.RandAugment(num_ops=args.rand_augment, magnitude=10),
        transforms.ToTensor(),
        normalize,
    ])
    val_transform = transforms.Compose([
        transforms.Resize(args.data_size),
        transforms.ToTensor(),
        normalize,
    ])

    # Look the dataset class up by name instead of eval-ing a source string.
    dataset_cls = getattr(datasets, args.dataset)

    train_loader = torch.utils.data.DataLoader(
        dataset_cls(root="../", train=True, transform=train_transform, download=True),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers)

    val_loader = torch.utils.data.DataLoader(
        dataset_cls(root="../", train=False, transform=val_transform, download=True),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers)

    return train_loader, val_loader



def save_results(bin, args_list, train_loss_history, train_acc_history, train_recall_history, train_f1_history, train_top1_history, train_top5_history, valid_results, min_train_time, max_train_time, inf):
    """Append a plain-text report of one experiment to a results file.

    The report is appended to ``ResultsNonBin.txt`` (``bin == False``) or
    ``ResultsBin.txt`` (``bin == True``) inside ``args.path_to_save``, which
    is created when missing.

    Args:
        bin: selects which of the two result files is written.
        args_list: iterable of ``(name, value)`` pairs written as a header.
        train_*_history: per-epoch values, all indexed by epoch position.
        valid_results: indexable with at least six entries, written in the
            same column order as the training rows.
        min_train_time, max_train_time: values written into the timing line.
        inf: free-form information written below the header.
    """
    out_dir = args.path_to_save
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    if bin == False:
        file_name = 'ResultsNonBin.txt'
    elif bin == True:
        file_name = 'ResultsBin.txt'

    with open(out_dir + '/' + file_name, 'a') as report:
        report.write('Общая информация о эксперименте \n')
        report.write(' || '.join([f'{arg}: {value}' for arg, value in args_list]))
        report.write('\n\n{0}'.format(inf))
        report.write('\n\nМинимальное и максимальное время на обучение эпохи в секунданх: {0} ; {1} \n\n'.format(min_train_time, max_train_time))
        report.write('Epoch No.    Loss     Acc     Recall     F1     Top1     Top5 \n')
        for epoch_no, _ in enumerate(train_loss_history, start=1):
            i = epoch_no - 1
            # One leading space less for two-digit epoch numbers keeps the
            # columns aligned.
            pad = "    " if epoch_no < 10 else "   "
            report.write(f"{pad}{epoch_no}       {train_loss_history[i]}  {train_acc_history[i]}   {train_recall_history[i]}   {train_f1_history[i]}   {train_top1_history[i]}   {train_top5_history[i]}\n")
        report.write('\n  Valid     {0}  {1}   {2}   {3}   {4}   {5}\n'.format(*[valid_results[j] for j in range(6)]))
        report.write('---------------------' + '\n')

# Binary Modules

In [None]:
def cpt_tk(epoch):
    """Compute the ``t`` and ``k`` values used in back-propagation for ``epoch``.

    ``t`` is ``10`` raised to an exponent that moves linearly with ``epoch``
    from ``log10(1e-2)`` towards ``log10(1e1)`` over ``args.epochs`` epochs;
    ``k`` is the larger of ``1 / t`` and ``1``.

    Returns:
        ``(t, k)`` as float tensors.
    """
    log_t_min = torch.log10(torch.tensor(1e-2).float())
    log_t_max = torch.log10(torch.tensor(1e1).float())
    exponent = log_t_min + (log_t_max - log_t_min) / args.epochs * epoch
    t = torch.tensor([torch.pow(torch.tensor(10.), exponent)]).float()
    inverse_t = 1/t
    k = max(inverse_t, torch.tensor(1.)).float()
    return t, k

In [None]:
#-----------------------IEE------------------------------------------
class OwnQuantize_a(Function):
    """Sign binarization with a custom surrogate gradient.

    Forward returns ``sign(input)``. Backward ignores the saved ``k`` (a
    factor of 1 is used instead) and raises ``t`` to at least 1 before
    evaluating the surrogate ``3*sqrt(t^2/3) - |3*t^2*input|/2``, clamped at
    zero and multiplied by the incoming gradient. ``k`` and ``t`` receive no
    gradient.
    """

    @staticmethod
    def forward(ctx, input, k, t):
        ctx.save_for_backward(input, k, t)
        return torch.sign(input)

    @staticmethod
    def backward(ctx, grad_output):
        x, _, t_saved = ctx.saved_tensors
        unit_scale = torch.tensor(1.).to(x.device)
        t_eff = max(t_saved, torch.tensor(1.).to(x.device))
        # grad_input = k * (1.4*t - torch.abs(t**2 * input))
        peak = 3*torch.sqrt(t_eff**2/3)
        falloff = torch.abs(t_eff ** 2 * x*3)/2
        surrogate = (unit_scale * (peak - falloff)).clamp(min=0)
        return surrogate * grad_output.clone(), None, None
class OwnQuantize(Function):
    """Sign binarization with a custom surrogate gradient scaled by ``k``.

    Forward returns ``sign(input)``. Backward evaluates
    ``k * (3*sqrt(t^2/3) - |3*t^2*input|/2)`` with the saved ``k`` and ``t``,
    clamps it at zero and multiplies by the incoming gradient. ``k`` and
    ``t`` receive no gradient.
    """

    @staticmethod
    def forward(ctx, input, k, t):
        ctx.save_for_backward(input, k, t)
        return torch.sign(input)

    @staticmethod
    def backward(ctx, grad_output):
        x, k, t = ctx.saved_tensors
        # grad_input = k * (1.4*t - torch.abs(t**2 * input))
        peak = 3*torch.sqrt(t**2/3)
        falloff = torch.abs(t ** 2 * x*3)/2
        surrogate = (k * (peak - falloff)).clamp(min=0)
        return surrogate * grad_output.clone(), None, None

In [None]:
class BinarizeConv2d(nn.Conv2d):
    """``nn.Conv2d`` whose weights and input are binarized with ``sign``.

    Accepts the same constructor arguments as ``nn.Conv2d``. Additional state:

    * ``k``, ``t``: plain tensors (not parameters or buffers) passed to the
      quantizers; they are moved to the weight's device on every forward.
    * ``alpha``: learnable scale, one value per output channel, initialised
      to the mean absolute value of that channel's initial weights.
    """

    def __init__(self, *kargs, **kwargs):
        super(BinarizeConv2d, self).__init__(*kargs, **kwargs)
        self.k = torch.tensor([10.]).float()
        self.t = torch.tensor([0.1]).float()

        w = self.weight

        sw = w.abs().view(w.size(0), -1).mean(-1).float().view(w.size(0), 1, 1).detach()
        # Create alpha on the same device as the weights instead of pinning it
        # to a global device: it then follows the module through .to()/.cuda()
        # like every other parameter, and construction works without CUDA.
        self.alpha = nn.Parameter(sw, requires_grad=True)

    def forward(self, input):
        a = input
        w = self.weight
        # Standardize each output filter (zero mean, unit variance).
        w0 = w - w.mean([1, 2, 3], keepdim=True)
        w1 = w0 / torch.sqrt(w0.var([1, 2, 3], keepdim=True) + 1e-5)
        if self.training:
            # Per-sample variance normalization, applied in training only.
            a0 = a / torch.sqrt(a.var([1, 2, 3], keepdim=True) + 1e-5)
        else:
            a0 = a

        #* binarize
        # Function.apply is called on the class; autograd Functions are not
        # meant to be instantiated.
        bw = OwnQuantize.apply(w1, self.k.to(w.device), self.t.to(w.device))

        ba = OwnQuantize_a.apply(a0, self.k.to(w.device), self.t.to(w.device))

        #* 1bit conv
        output = F.conv2d(ba, bw, self.bias, self.stride, self.padding,
                          self.dilation, self.groups)
        #* scaling factor
        output = output * self.alpha
        return output

class channel_w(nn.Module):
    """Multiply the input by a single learnable scalar.

    The scalar ``w1`` is initialised to a uniform random value scaled by
    ``0.001``. The constructor argument ``p`` is accepted but not used.
    """

    def __init__(self, p):
        super().__init__()
        initial_scale = torch.rand(1)*0.001
        self.w1 = torch.nn.Parameter(initial_scale, requires_grad=True)

    def forward(self, x):
        return self.w1 * x
# class OwnBinaryConv(nn.Module):
#     def __init__(self,in_ch,out_ch,kernel_size=3,stride=1,padding=1):
#         super(OwnBinaryConv, self).__init__()
#         self.shift1 = nn.Parameter(torch.zeros(1,in_ch,1,1), requires_grad=True)
#         self.shift2 = nn.Parameter(torch.zeros(1, in_ch, 1, 1), requires_grad=True)
#         self.conv = BinarizeConv2d(in_ch,out_ch,kernel_size=kernel_size,stride=stride,padding=padding)
#         self.scale = channel_w(out_ch)

#     def forward(self,x):
#         x1 = x + self.shift1.expand_as(x)
#         x2 = x + self.shift2.expand_as(x)

#         out1 = self.conv(x1)
#         out2 = self.conv(x2)

#         out = out1+self.scale(out2)

#         return out

class OwnBinaryConv(nn.Module):
    """Binary convolution evaluated on ``K`` learnably shifted copies of the input.

    ``K`` is read from ``args.K``. One ``BinarizeConv2d`` is shared by all
    branches. Each branch adds its own per-channel shift (initialised to
    zero) to the input before the convolution. The first branch's output is
    used as is; every later branch's output is passed through the shared
    ``channel_w`` scale and added to the running sum.

    ``self.weight`` refers to the same parameter as ``self.conv.weight``.
    """

    def __init__(self, in_ch, out_ch, kernel_size=3, stride=1, padding=1, groups=1, bias = False):
        super().__init__()
        self.K = args.K
        shift_params = [nn.Parameter(torch.zeros(1, in_ch, 1, 1), requires_grad=True) for _ in range(self.K)]
        self.shifts = nn.ParameterList(shift_params)
        self.conv = BinarizeConv2d(in_ch, out_ch, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, bias=bias)
        self.scale = channel_w(out_ch)
        self.weight = self.conv.weight

    def forward(self, x):
        # Stays None when K == 0, in which case None is returned.
        total = None
        for branch in range(self.K):
            shifted = x + self.shifts[branch].expand_as(x)
            branch_out = self.conv(shifted)
            total = branch_out if total is None else total + self.scale(branch_out)
        return total

# Models

## Бинаризация архитектур

In [None]:
def replace_conv2d(module, conv):
    """Recursively replace every ``nn.Conv2d`` under ``module`` with a ``conv`` layer.

    Args:
        module: root module; its children are replaced in place.
        conv: the replacement layer type. Either a callable/class, or (as
            before) a string naming one, e.g. ``'OwnBinaryConv'``, which is
            resolved with ``eval`` in this module's namespace.

    Each replacement is built from the original layer's ``in_channels``,
    ``out_channels``, ``kernel_size``, ``stride``, ``padding``, ``groups``
    and bias presence, so grouped/depthwise convolutions keep their grouping
    instead of silently becoming dense ones. The replacement type must accept
    ``groups`` and ``bias`` keyword arguments. ``dilation`` and
    ``padding_mode`` are not forwarded, and the new layers are freshly
    initialised (existing weights are not copied).
    """
    # eval executes arbitrary code: only pass trusted, hard-coded names.
    conv_factory = eval(conv) if isinstance(conv, str) else conv
    for name, child in module.named_children():
        if isinstance(child, nn.Conv2d):
            replacement = conv_factory(
                child.in_channels,
                child.out_channels,
                kernel_size=child.kernel_size,
                stride=child.stride,
                padding=child.padding,
                groups=child.groups,
                bias=child.bias is not None,
            )
            setattr(module, name, replacement)
        else:
            replace_conv2d(child, conv_factory)

def replace_relu(module):
    """Recursively replace every ``nn.ReLU`` under ``module`` with ``nn.Hardtanh``.

    The replacement keeps the original layer's ``inplace`` setting. Children
    are replaced in place; nothing is returned.
    """
    for child_name, child in module.named_children():
        if not isinstance(child, nn.ReLU):
            replace_relu(child)
            continue
        setattr(module, child_name, nn.Hardtanh(inplace=child.inplace))

## ResNet18



In [None]:
class BasicBlock(nn.Module):
    """Two-convolution residual block, full precision or binary by ``args.binary``.

    With ``args.binary == False`` all convolutions are ``nn.Conv2d`` and the
    activation is ``F.relu``. With ``args.binary == True`` the two main
    convolutions are ``OwnBinaryConv``, the shortcut convolution is
    ``BinarizeConv2d`` and the activation is ``F.hardtanh``. The choices are
    stored as strings in ``cv1``, ``cv2`` and ``activat`` and resolved with
    ``eval``.

    A projection shortcut (1x1 conv + batch norm) is used when the stride is
    not 1 or the channel count changes; otherwise the shortcut is the identity.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        if args.binary == False:
            self.cv1, self.cv2, self.activat = 'nn.Conv2d', 'nn.Conv2d', 'F.relu'
        elif args.binary == True:
            self.cv1, self.cv2, self.activat = 'OwnBinaryConv', 'BinarizeConv2d', 'F.hardtanh'

        # NOTE(review): both main convolutions use cv1 while cv2 is only used
        # for the shortcut - confirm that conv2 is not meant to use cv2.
        main_conv = eval(self.cv1)
        self.conv1 = main_conv(in_planes, planes, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = main_conv(planes, planes, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(planes)

        out_planes = self.expansion*planes
        if stride != 1 or in_planes != out_planes:
            shortcut_conv = eval(self.cv2)
            self.shortcut = nn.Sequential(
                shortcut_conv(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        activation = eval(self.activat)
        hidden = activation(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(hidden))
        residual += self.shortcut(x)
        return activation(residual)


class ResNet(nn.Module):
    """Four-stage ResNet for 3-channel images.

    Args:
        block: residual block class; must expose ``expansion`` and accept
            ``(in_planes, planes, stride)``.
        num_blocks: number of blocks in each of the four stages.
        num_channel: base width of each of the four stages.
        num_classes: size of the final linear layer's output.

    The stem is a 3x3 convolution followed by batch norm (no activation).
    Stages 2-4 start with a stride-2 block. The pooled features go through
    ``BatchNorm1d`` and dropout (rate ``args.dropout``) before the classifier.
    """

    def __init__(self, block, num_blocks, num_channel, num_classes=10):
        super().__init__()
        self.in_planes = num_channel[0]

        self.conv1 = nn.Conv2d(3, num_channel[0], kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(num_channel[0])

        first_strides = (1, 2, 2, 2)
        for stage in range(4):
            setattr(self, f'layer{stage + 1}',
                    self._make_layer(block, num_channel[stage], num_blocks[stage], stride=first_strides[stage]))

        feature_dim = num_channel[3]*block.expansion
        self.linear = nn.Linear(feature_dim, num_classes)
        self.bn2 = nn.BatchNorm1d(num_channel[3]*block.expansion)

        self.prob = args.dropout
        self.dropout = nn.Dropout(p=self.prob)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``."""
        stage_blocks = []
        for block_stride in [stride] + [1]*(num_blocks-1):
            stage_blocks.append(block(self.in_planes, planes, block_stride))
            # The next block consumes this block's output width.
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage_blocks)

    def forward(self, x):
        feat = self.bn1(self.conv1(x))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            feat = stage(feat)
        # NOTE(review): the fixed 4x4 pooling window presumably assumes
        # 32x32 inputs - confirm against args.data_size.
        feat = F.avg_pool2d(feat, 4)
        feat = feat.view(feat.size(0), -1)
        return self.linear(self.dropout(self.bn2(feat)))

def resnet18(**kwargs):
    """Build a ``ResNet`` of ``BasicBlock``s with two blocks per stage and
    stage widths 64/128/256/512. Keyword arguments are forwarded to ``ResNet``."""
    blocks_per_stage = [2,2,2,2]
    stage_widths = [64,128,256,512]
    return ResNet(BasicBlock, blocks_per_stage, stage_widths, **kwargs)
# def resnet18_binary(**kwargs):
#     return ResNet(BasicBlock, [2,2,2,2],[64,128,256,512],**kwargs)


## MobileNetV2

In [None]:
# class MobileNetV2(nn.Module):
#     def __init__(self, num_classes=10):
#         super(MobileNetV2, self).__init__()
#         self.features = nn.Sequential(
#             nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False),
#             nn.BatchNorm2d(32),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1, groups=32, bias=False),
#             nn.BatchNorm2d(32),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(32, 64, kernel_size=1, stride=1, padding=0, bias=False),
#             nn.BatchNorm2d(64),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, groups=64, bias=False),
#             nn.BatchNorm2d(64),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(64, 128, kernel_size=1, stride=1, padding=0, bias=False),
#             nn.BatchNorm2d(128),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, groups=128, bias=False),
#             nn.BatchNorm2d(128),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(128, 128, kernel_size=1, stride=1, padding=0, bias=False),
#             nn.BatchNorm2d(128),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1, groups=128, bias=False),
#             nn.BatchNorm2d(256),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0, bias=False),
#             nn.BatchNorm2d(256),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, groups=256, bias=False),
#             nn.BatchNorm2d(256),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0, bias=False),
#             nn.BatchNorm2d(512),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1, groups=512, bias=False),
#             nn.BatchNorm2d(512),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(512, 512, kernel_size=1, stride=1, padding=0, bias=False),
#             nn.BatchNorm2d(512),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, groups=512, bias=False),
#             nn.BatchNorm2d(512),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(512, 512, kernel_size=1, stride=1, padding=0, bias=False),
#             nn.BatchNorm2d(512),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, groups=512, bias=False),
#             nn.BatchNorm2d(512),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(512, 512, kernel_size=1, stride=1, padding=0, bias=False),
#             nn.BatchNorm2d(512),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1, groups=512, bias=False),
#             nn.BatchNorm2d(512),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(512, 512, kernel_size=1, stride=1, padding=0, bias=False),
#             nn.BatchNorm2d(512),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, groups=512, bias=False),
#             nn.BatchNorm2d(512),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(512, 512, kernel_size=1, stride=1, padding=0, bias=False),
#             nn.BatchNorm2d(512),
#             nn.ReLU(inplace=True),
#             nn.AdaptiveAvgPool2d(1)
#         )
#         self.classifier = nn.Sequential(
#             nn.Dropout(args.dropout),
#             nn.Linear(512, num_classes)
#         )

#     def forward(self, x):
#         x = self.features(x)
#         x = x.view(x.size(0), -1)
#         x = self.classifier(x)
#         return x


In [None]:
# def replace_conv2d(module, conv):
#     for name, child in module.named_children():
#         if isinstance(child, nn.Conv2d):
#             conv_str = f"({child.in_channels}, {child.out_channels}, kernel_size={child.kernel_size}, stride={child.stride}, padding={child.padding})"
#             module.__setattr__(name, eval(conv + conv_str))
#         else:
#             replace_conv2d(child, conv)
#         if isinstance(child, nn.Sequential):
#             for sub_name, sub_child in child.named_children():
#                 if isinstance(sub_child, nn.Conv2d):
#                     conv_str = f"({sub_child.in_channels}, {sub_child.out_channels}, kernel_size={sub_child.kernel_size}, stride={sub_child.stride}, padding={sub_child.padding})"
#                     child.__setattr__(sub_name, eval(conv + conv_str))
#                 else:
#                     replace_conv2d(sub_child, conv)

# def replace_relu(module):
#     for name, child in module.named_children():
#         if isinstance(child, nn.ReLU):
#             relu_str = f"(inplace={child.inplace})"
#             module.__setattr__(name, eval("nn.Hardtanh" + relu_str))
#         else:
#             replace_relu(child)
#         if isinstance(child, nn.Sequential):
#             for sub_name, sub_child in child.named_children():
#                 if isinstance(sub_child, nn.ReLU):
#                     relu_str = f"(inplace={sub_child.inplace})"
#                     child.__setattr__(sub_name, eval("nn.Hardtanh" + relu_str))
#                 else:
#                     replace_relu(sub_child)


# def conv_dw(in_channels, out_channels, stride):
#     return nn.Sequential(
#         nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=stride, padding=1, groups=in_channels, bias=False),
#         nn.BatchNorm2d(in_channels),
#         nn.ReLU(inplace=True),
#         nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
#         nn.BatchNorm2d(out_channels),
#         nn.ReLU(inplace=True)
#     )

# class MobileNetV2(nn.Module):
#     def __init__(self, num_classes=10):
#         super(MobileNetV2, self).__init__()
#         self.in_channels = 32

#         self.features = nn.Sequential(
#             nn.Conv2d(3, self.in_channels, kernel_size=3, stride=2, padding=1, bias=False),
#             nn.BatchNorm2d(self.in_channels),
#             nn.ReLU(inplace=True),

#             conv_dw(self.in_channels, 64, stride=1),
#             conv_dw(64, 128, stride=2),
#             conv_dw(128, 128, stride=1),
#             conv_dw(128, 256, stride=2),
#             conv_dw(256, 256, stride=1),
#             conv_dw(256, 512, stride=2),

#             nn.Sequential(
#                 *[conv_dw(512, 512, stride=1) for _ in range(5)]
#             ),

#             conv_dw(512, 1024, stride=2),
#             conv_dw(1024, 1024, stride=1)
#         )

#         self.avgpool = nn.AdaptiveAvgPool2d(1)
#         self.classifier = nn.Linear(1024, num_classes)

#     def forward(self, x):
#         x = self.features(x)
#         x = self.avgpool(x)
#         x = x.view(x.size(0), -1)
#         x = self.classifier(x)
#         return x


In [None]:

# # Определение блока сверточных слоев в MobileNetV2
# def conv_bn(in_channels, out_channels, stride):
#     return nn.Sequential(
#         nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
#         nn.BatchNorm2d(out_channels),
#         nn.ReLU(inplace=True)
#     )

# # Определение блока Bottleneck в MobileNetV2
# class Bottleneck(nn.Module):
#     def __init__(self, in_channels, out_channels, stride, expansion):
#         super(Bottleneck, self).__init__()
#         self.stride = stride
#         mid_channels = in_channels * expansion

#         self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=1, stride=1, padding=0, bias=False)
#         self.bn1 = nn.BatchNorm2d(mid_channels)

#         self.conv2 = nn.Conv2d(mid_channels, mid_channels, kernel_size=3, stride=stride, padding=1, groups=mid_channels, bias=False)
#         self.bn2 = nn.BatchNorm2d(mid_channels)

#         self.conv3 = nn.Conv2d(mid_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
#         self.bn3 = nn.BatchNorm2d(out_channels)

#         self.relu = nn.ReLU(inplace=True)
#         self.shortcut = nn.Sequential()
#         if stride == 1 and in_channels != out_channels:
#             self.shortcut = nn.Sequential(
#                 nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
#                 nn.BatchNorm2d(out_channels)
#             )

#     def forward(self, x):
#         identity = x

#         out = self.conv1(x)
#         out = self.bn1(out)
#         out = self.relu(out)

#         out = self.conv2(out)
#         out = self.bn2(out)
#         out = self.relu(out)

#         out = self.conv3(out)
#         out = self.bn3(out)

#         out += self.shortcut(identity)
#         out = self.relu(out)

#         return out

# # Определение архитектуры MobileNetV2
# class MobileNetV2(nn.Module):
#     def __init__(self, num_classes=10):
#         super(MobileNetV2, self).__init__()

#         self.features = nn.Sequential(
#             conv_bn(3, 32, stride=2),
#             Bottleneck(32, 16, stride=1, expansion=1),

#             Bottleneck(16, 24, stride=2, expansion=6),
#             Bottleneck(24, 24, stride=1, expansion=6),

#             Bottleneck(24, 32, stride=2, expansion=6),
#             Bottleneck(32, 32, stride=1, expansion=6),
#             Bottleneck(32, 32, stride=1, expansion=6),

#             Bottleneck(32, 64, stride=2, expansion=6),
#             Bottleneck(64, 64, stride=1, expansion=6),
#             Bottleneck(64, 64, stride=1, expansion=6),
#             Bottleneck(64, 64, stride=1, expansion=6),

#             Bottleneck(64, 96, stride=1, expansion=6),
#             Bottleneck(96, 96, stride=1, expansion=6),
#             Bottleneck(96, 96, stride=1, expansion=6),

#             Bottleneck(96, 160, stride=2, expansion=6),
#             Bottleneck(160, 160, stride=1, expansion=6),
#             Bottleneck(160, 160, stride=1, expansion=6),

#             Bottleneck(160, 320, stride=1, expansion=6),
#         )

#         self.conv1 = conv_bn(320, 1280, stride=1)
#         self.avgpool = nn.AdaptiveAvgPool2d(1)
#         self.classifier = nn.Linear(1280, num_classes)

#     def forward(self, x):
#         x = self.features(x)
#         x = self.conv1(x)
#         x = self.avgpool(x)
#         x = x.view(x.size(0), -1)
#         x = self.classifier(x)
#         return x

# # Создание экземпляра модели MobileNetV2
# model = MobileNetV2()

import torch
import torch.nn as nn


def conv_bn(in_channels, out_channels, stride):
    """Return a 3x3 convolution (padding 1, no bias) followed by batch norm
    and an in-place ReLU, wrapped in an ``nn.Sequential``."""
    layers = [
        nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)


class Bottleneck(nn.Module):
    """Expand (1x1) -> depthwise (3x3) -> project (1x1) block with batch norm.

    Args:
        in_channels: input channel count.
        out_channels: output channel count.
        stride: stride of the depthwise 3x3 convolution.
        expansion: the hidden width is ``in_channels * expansion``.

    A shortcut is added only when ``stride == 1``: the identity if the
    channel counts match, otherwise a 1x1 convolution with batch norm. A ReLU
    is applied to the block output in every case.
    """

    def __init__(self, in_channels, out_channels, stride, expansion):
        super().__init__()
        self.stride = stride
        hidden = in_channels * expansion

        # 1x1 expansion
        self.conv1 = nn.Conv2d(in_channels, hidden, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(hidden)

        # 3x3 depthwise (one group per channel)
        self.conv2 = nn.Conv2d(hidden, hidden, kernel_size=3, stride=stride, padding=1, groups=hidden, bias=False)
        self.bn2 = nn.BatchNorm2d(hidden)

        # 1x1 projection
        self.conv3 = nn.Conv2d(hidden, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)

        self.relu = nn.ReLU(inplace=True)

        needs_projection = stride == 1 and in_channels != out_channels
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        expanded = self.relu(self.bn1(self.conv1(x)))
        filtered = self.relu(self.bn2(self.conv2(expanded)))
        projected = self.bn3(self.conv3(filtered))

        if self.stride == 1:
            projected += self.shortcut(x)
        return self.relu(projected)


class MobileNetV2(nn.Module):
    """MobileNetV2-style classifier for 3-channel images.

    A stride-1 ``conv_bn`` stem is followed by 17 ``Bottleneck`` blocks, a
    320 -> 1280 ``conv_bn`` head, global average pooling, dropout (rate
    ``args.dropout``) and a linear classifier.

    Args:
        num_classes: size of the classifier output.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        # (in_channels, out_channels, stride, expansion) for each bottleneck.
        bottleneck_cfg = [
            (32, 16, 1, 1),

            (16, 24, 1, 6),
            (24, 24, 1, 6),

            (24, 32, 2, 6),
            (32, 32, 1, 6),
            (32, 32, 1, 6),

            (32, 64, 2, 6),
            (64, 64, 1, 6),
            (64, 64, 1, 6),
            (64, 64, 1, 6),

            (64, 96, 1, 6),
            (96, 96, 1, 6),
            (96, 96, 1, 6),

            (96, 160, 2, 6),
            (160, 160, 1, 6),
            (160, 160, 1, 6),

            (160, 320, 1, 6),
        ]

        stem = conv_bn(3, 32, stride=1)
        blocks = [
            Bottleneck(c_in, c_out, stride=block_stride, expansion=factor)
            for c_in, c_out, block_stride, factor in bottleneck_cfg
        ]
        self.features = nn.Sequential(stem, *blocks)

        self.conv1 = conv_bn(320, 1280, stride=1)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.dropout = nn.Dropout(p=args.dropout)
        self.classifier = nn.Linear(1280, num_classes)

    def forward(self, x):
        feat = self.conv1(self.features(x))
        pooled = self.avgpool(feat)
        flat = pooled.view(pooled.size(0), -1)
        return self.classifier(self.dropout(flat))


In [None]:
def mobilenetv2():
    """Build ``MobileNetV2`` according to ``args.binary``.

    ``False`` returns the full-precision model. ``True`` returns a model
    whose ``nn.Conv2d`` layers were replaced with ``OwnBinaryConv`` and whose
    ReLUs were replaced with Hardtanh. Any other value returns ``None``.
    """
    if args.binary == False:
        return MobileNetV2()
        # model = models.mobilenet_v2(pretrained=False)
        # num_classes = 10
        # model.classifier[1] = torch.nn.Linear(1280, num_classes)
        # return model

    if args.binary == True:
        binary_net = MobileNetV2()
        replace_conv2d(binary_net, 'OwnBinaryConv')
        replace_relu(binary_net)
        return binary_net

    return None


In [None]:
# print(args.binary)
# print(mobilenetv2())

# Train

In [None]:
def train(model, train_loader, loss_f, optimizer, epoch):
    """Train ``model`` for one epoch and return formatted metrics.

    Args:
        model: network to train; switched to training mode.
        train_loader: iterable of ``(input, target)`` batches.
        loss_f: loss callable taking ``(predictions, targets)``.
        optimizer: optimizer stepped once per batch.
        epoch: current epoch index; when ``args.pruning`` is enabled,
            ``custom_pruning`` is applied at ``int(args.epochs / 2)``.

    Returns:
        ``(loss, accuracy, macro recall, macro F1, top-1, top-5)``, each
        formatted as a string with four decimals. The loss is that of the
        last batch only; the other metrics cover every batch of the epoch.
        Top-5 accuracy requires the model to output at least five classes.
    """
    model.train()
    preds_all = []
    targets_all = []

    if args.pruning == True and epoch == int(args.epochs/2):
        custom_pruning(model)

    for i, (inputs, target) in enumerate(train_loader):
        input_var = inputs.to(device)
        target_var = target.to(device)

        # compute output
        preds = model(input_var)
        loss = loss_f(preds, target_var)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        preds = preds.float()
        loss = loss.float()
        # Keep class probabilities and labels on the CPU for epoch-level metrics.
        preds_all.append(torch.softmax(preds.detach(), dim=1).cpu().numpy())
        targets_all.append(target_var.cpu().numpy())

    # concatenate predictions and targets for all batches
    preds_all = np.concatenate(preds_all, axis=0)
    targets_all = np.concatenate(targets_all, axis=0)
    pred_labels = preds_all.argmax(axis=1)

    # calculate metrics for all data; float() accepts both NumPy scalars and
    # built-in floats returned by scikit-learn.
    loss = '{:.4f}'.format(loss.item())
    acc_all = '{:.4f}'.format(float(accuracy_score(targets_all, pred_labels)))
    recall_all = '{:.4f}'.format(float(recall_score(targets_all, pred_labels, average ='macro')))
    f1_all = '{:.4f}'.format(float(f1_score(targets_all, pred_labels, average ='macro')))
    top1_all = '{:.4f}'.format(top1_accuracy(targets_all, preds_all))
    # Use the top-5 helper here; this value was previously a copy of top-1.
    top5_all = '{:.4f}'.format(top5_accuracy(targets_all, preds_all))

    return loss, acc_all, recall_all, f1_all, top1_all, top5_all

# Validation

In [None]:
def validate(model, val_loader, loss_f):
    """Evaluate ``model`` on ``val_loader`` and return formatted metrics.

    Args:
        model: network to evaluate; it is switched to eval mode.
        val_loader: iterable yielding ``(input, target)`` batches.
        loss_f: callable ``loss_f(preds, target)`` returning a scalar tensor.

    Returns:
        Tuple ``(loss, acc, recall, f1, top1, top5)`` of strings, each
        formatted with four decimals. ``loss`` is the arithmetic mean of the
        per-batch losses; recall and F1 are macro-averaged.

    Raises:
        ValueError: if ``val_loader`` yields no batches.
    """
    model.eval()
    batch_losses = []
    preds_all = []
    targets_all = []

    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            preds = model(inputs)
            # The loss is taken on the raw model output, before the float cast.
            batch_losses.append(loss_f(preds, targets).float().item())

            # Store class probabilities and labels on the CPU for the metrics below.
            preds_all.append(torch.softmax(preds.float(), dim=1).cpu().numpy())
            targets_all.append(targets.cpu().numpy())

    if not preds_all:
        raise ValueError('val_loader produced no batches')

    # concatenate predictions and targets for all batches
    preds_all = np.concatenate(preds_all, axis=0)
    targets_all = np.concatenate(targets_all, axis=0)
    predicted = preds_all.argmax(axis=1)

    # Average over every batch instead of reporting only the last batch's loss.
    loss = '{:.4f}'.format(sum(batch_losses) / len(batch_losses))
    # float() accepts both numpy scalars and plain Python floats from sklearn.
    acc_all = '{:.4f}'.format(float(accuracy_score(targets_all, predicted)))
    recall_all = '{:.4f}'.format(float(recall_score(targets_all, predicted, average='macro')))
    f1_all = '{:.4f}'.format(float(f1_score(targets_all, predicted, average='macro')))
    top1_all = '{:.4f}'.format(top1_accuracy(targets_all, preds_all))
    # Top-5 must come from top5_accuracy; it previously reused top1_accuracy.
    top5_all = '{:.4f}'.format(top5_accuracy(targets_all, preds_all))

    return loss, acc_all, recall_all, f1_all, top1_all, top5_all

# Дистилляция

In [None]:
def distill(teacher_outputs, student_outputs, temperature):
    """Return the temperature-softened KL divergence KL(teacher || student).

    Args:
        teacher_outputs: teacher logits; classes are taken along ``dim=1``.
        student_outputs: student logits with the same layout.
        temperature: divisor applied to both sets of logits before softmax.

    Returns:
        Scalar tensor: KL divergence with ``reduction='batchmean'``. The
        value is not multiplied by ``temperature ** 2``.
    """
    teacher_probs = nn.functional.softmax(teacher_outputs / temperature, dim=1)
    # Explicit dim=1 keeps the student log-probabilities on the same axis as
    # the teacher probabilities (an implicit dim is deprecated in PyTorch).
    student_log_probs = nn.functional.log_softmax(student_outputs / temperature, dim=1)
    return nn.functional.kl_div(student_log_probs, teacher_probs, reduction='batchmean')


def destill2():
    """Train a student ResNet-18 on CIFAR-10 with knowledge distillation.

    The student is optimised on ``distill(...) + criterion(...)`` against a
    frozen teacher, then its accuracy on the CIFAR-10 test split is printed.

    Hyper-parameters are read from the module-level ``args`` namespace
    (``label_smoothing``, ``lr``, ``momentum``, ``weight_decay``, ``epochs``).
    Nothing is returned; progress and the final accuracy are printed.
    """
    # CIFAR-10 pipelines: augmentation is applied to the training split only.
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    # Load the CIFAR-10 dataset
    trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
    testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)

    # Teacher: pretrained ResNet-18 whose classifier is replaced by a 10-way head.
    # NOTE(review): the new head is freshly initialised and nothing in this
    # function fine-tunes the teacher - confirm whether teacher weights should
    # be trained or loaded before distilling from it.
    teacher_net = models.resnet18(pretrained=True)
    teacher_net.fc = nn.Linear(teacher_net.fc.in_features, 10)
    teacher_net.to(device)
    teacher_net.eval()  # the teacher is only used for inference

    # Student model
    student_net = resnet18().to(device)

    # Hard-label loss for the student
    criterion = LabelSmoothingCrossEntropyLoss(smoothing=args.label_smoothing)

    student_optimizer = torch.optim.SGD(
        [{'params': student_net.parameters(), 'initial_lr': args.lr}],
        lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    lr_scheduler_student = torch.optim.lr_scheduler.CosineAnnealingLR(
        student_optimizer, args.epochs, eta_min=0, last_epoch=-1)

    # Temperature used to soften both sets of logits
    temperature = 3.0

    # Tied to the scheduler's T_max so the cosine schedule spans the whole run.
    num_epochs = args.epochs

    # Train the student model using distillation
    for epoch in range(num_epochs):
        student_net.train()
        running_loss = 0.0
        for inputs, labels in trainloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            # No autograd graph is needed for the frozen teacher.
            with torch.no_grad():
                teacher_outputs = teacher_net(inputs)
            student_outputs = student_net(inputs)

            distillation_loss = distill(teacher_outputs, student_outputs, temperature)
            classification_loss = criterion(student_outputs, labels)
            loss = distillation_loss + classification_loss

            student_optimizer.zero_grad()
            loss.backward()
            student_optimizer.step()
            running_loss += loss.item()

        # T_max is expressed in epochs, so the scheduler advances once per epoch
        # (it was previously stepped after every batch).
        lr_scheduler_student.step()

        print('Epoch %d, Student Loss: %.3f' % (epoch + 1, running_loss / len(trainloader)))

    # Evaluate the student model on the test dataset
    student_net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        # Unpack each test batch; the loop previously reused the last training batch.
        for inputs, labels in testloader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = student_net(inputs)
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print('Accuracy of the student network on the 10000 test images: %d %%' % (100 * correct / total))

81 % — с такими же оптимизаторами, лоссом и расписанием, как и во всех экспериментах; 80 минут

# Парсим арги

In [None]:
def _str2bool(value):
    """Convert a command-line token to ``bool``.

    ``type=bool`` would turn any non-empty string (including ``'False'``)
    into ``True``, so boolean options use this converter instead.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays falsy, as it was with type=bool.
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


parser = argparse.ArgumentParser(description='Исследование бинарные сетей в PyTorch')

# Run / data-loading options
parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                    help='number of data loading workers')
parser.add_argument('--epochs', default=50, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=128, type=int,
                    metavar='N', help='mini-batch size')
parser.add_argument('--optimizer', default='sgd', type=str,
                    help='which optimizer to use')

# Optimisation and regularisation options
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                    metavar='W', help='weight decay')
parser.add_argument('--label-smoothing', '--ls', default=0, type=float,
                    metavar='LS', help='label smoothing')
parser.add_argument('--dropout', default=0.5, type=float,
                    help='dropout')
parser.add_argument('--stoch_depth', default=0, type=float,
                    help='stochastic depth setting')
parser.add_argument('--rand_augment', default=2, type=int,
                    help='RandAugment setting')

# Model options
parser.add_argument('--binary', default=False, type=_str2bool,
                    help='which type of model to use (default: binary is not active)')
parser.add_argument('--architecture', '--arch', default='mobilenetv2()',
                    help='which non bin arch to use (default: mobilenetv2())')
parser.add_argument('--K', default=1, type=int,
                    help='how many binarize func to use')

# Compression options
parser.add_argument('--pruning', default=False, type=_str2bool,
                    help='')
parser.add_argument('--type_pruning', default='global', type=str,
                    help='')
parser.add_argument('--distillation', default=False, type=_str2bool,
                    )
parser.add_argument('--data_size', default=32, type=int,
                    help='Разрешение изображений в датасете для небинарки. Можно варьировать, если хотим ')
parser.add_argument('--bin_data_size', default=32, type=int,
                    help='Разрешение изображений в датасете для бинарки. Можно варьировать, если хотим ')

# Dataset and result-saving options
parser.add_argument('--dataset', '--ds', default='CIFAR10', type=str,
                    help='which dataset to use')
parser.add_argument('--save_res', default = True, type=_str2bool,
                    help='save the results after the launch. (default: True)')
parser.add_argument('--path_to_save', default = '/content/drive/MyDrive/Experiment_Results', type=str,
                    help='s')

_StoreAction(option_strings=['--path_to_save'], dest='path_to_save', nargs=None, const=None, default='/content/drive/MyDrive/Experiment_Results', type=<class 'str'>, choices=None, required=False, help='s', metavar=None)

# Main

In [None]:
def main():
    """Parse the arguments and run one experiment.

    With ``--distillation`` set, runs ``destill2()`` and prints the elapsed
    time in minutes. Otherwise builds the model named by ``--architecture``,
    trains it for ``args.epochs - args.start_epoch`` epochs with a cosine
    learning-rate schedule, validates once after training, prints a model
    summary and, if ``--save_res`` is set, stores the results via
    ``save_results``.

    Side effects: assigns the module-level ``args`` namespace, which other
    functions in this file read.
    """
    global args

    # parse_known_args tolerates unrecognised argv entries (for example the
    # ones a notebook kernel is started with).
    args, _ = parser.parse_known_args()
    args_list = list(vars(args).items())

    if args.distillation:
        start = time.time()
        destill2()
        print((time.time() - start) / 60)
        return

    train_loss_history = []
    train_acc_history = []
    train_recall_history = []
    train_f1_history = []
    train_top1_history = []
    train_top5_history = []

    times = []

    # Kept in local scope because the expression evaluated below may refer to
    # it - TODO confirm against the architecture strings actually used.
    data_size = args.data_size
    # SECURITY: eval() executes arbitrary code taken from the command line;
    # only run this with trusted --architecture values.
    model = eval(args.architecture)

    model.to(device)

    criterion_ce = LabelSmoothingCrossEntropyLoss(smoothing=args.label_smoothing)

    param_groups = [{'params': model.parameters(), 'initial_lr': args.lr}]
    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(
            param_groups, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    else:
        # Adam has no `momentum` argument and needs the parameters themselves,
        # not the bound `model.parameters` method.
        optimizer = torch.optim.Adam(
            param_groups, lr=args.lr, weight_decay=args.weight_decay)

    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, eta_min = 0, last_epoch=-1)

    train_loader, val_loader = get_data()

    # Convolution layers receive the per-epoch k/t values in binary mode.
    conv_modules = [module for module in model.modules() if isinstance(module, nn.Conv2d)]

    for epoch in range(args.start_epoch, args.epochs):

        if args.binary:
            t, k = cpt_tk(epoch)
            for module in conv_modules:
                module.k = k.to(device)
                module.t = t.to(device)

        start_train_epoch = time.time()

        train_loss, train_acc, train_recall, train_f1, train_top1, train_top5 = train(model, train_loader, criterion_ce, optimizer, epoch)
        train_loss_history.append(train_loss)
        train_acc_history.append(train_acc)
        train_recall_history.append(train_recall)
        train_f1_history.append(train_f1)
        train_top1_history.append(train_top1)
        train_top5_history.append(train_top5)

        times.append(time.time() - start_train_epoch)

        print('Epoch No. {0}. Loss: {1}, Acc: {2}, Recall: {3}, F1: {4}, Top1: {5}, Top5: {6} '.format(epoch+1, train_loss, train_acc, train_recall, train_f1, train_top1, train_top5))
        lr_scheduler.step()

    # Validation runs once, after the last training epoch.
    valid_results = list(validate(model, val_loader, criterion_ce))
    print('\nValid results is: Loss: {0}, Acc: {1}, Recall: {2}, F1: {3}, Top1: {4}, Top5: {5} \n'.format(*valid_results))

    # default=0.0 keeps this from raising when no epoch was run
    # (start_epoch >= epochs).
    max_train_time = '{:.4f}'.format(max(times, default=0.0))
    min_train_time = '{:.4f}'.format(min(times, default=0.0))

    # NOTE(review): the summary input size is fixed at 32x32 and does not
    # follow --data_size; confirm this is intended.
    summary(model, input_size=(3, 32, 32))

    inf = 'Обучение MobileNetV2 на 50и эпохах из библиотеки + LS + WD + RA + DP'
    if args.save_res:
        save_results(args.binary, args_list, train_loss_history, train_acc_history, train_recall_history, train_f1_history, train_top1_history, train_top5_history, valid_results, min_train_time, max_train_time, inf)


# Пуск

In [None]:
# Entry point: start the experiment when this module is executed directly.
if __name__ == "__main__":
    main()

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ../cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:04<00:00, 34860189.45it/s]


Extracting ../cifar-10-python.tar.gz to ../
Files already downloaded and verified
Epoch No. 1. Loss: 1.7813, Acc: 0.2726, Recall: 0.2726, F1: 0.2661, Top1: 0.2726, Top5: 0.2726 
Epoch No. 2. Loss: 1.5829, Acc: 0.4370, Recall: 0.4370, F1: 0.4316, Top1: 0.4370, Top5: 0.4370 
Epoch No. 3. Loss: 0.9720, Acc: 0.5329, Recall: 0.5329, F1: 0.5293, Top1: 0.5329, Top5: 0.5329 
Epoch No. 4. Loss: 1.0274, Acc: 0.6043, Recall: 0.6043, F1: 0.6030, Top1: 0.6043, Top5: 0.6043 
Epoch No. 5. Loss: 0.9878, Acc: 0.6623, Recall: 0.6623, F1: 0.6613, Top1: 0.6623, Top5: 0.6623 
Epoch No. 6. Loss: 0.7551, Acc: 0.6983, Recall: 0.6983, F1: 0.6975, Top1: 0.6983, Top5: 0.6983 
Epoch No. 7. Loss: 0.7424, Acc: 0.7307, Recall: 0.7307, F1: 0.7300, Top1: 0.7307, Top5: 0.7307 
Epoch No. 8. Loss: 0.5454, Acc: 0.7510, Recall: 0.7510, F1: 0.7505, Top1: 0.7510, Top5: 0.7510 
Epoch No. 9. Loss: 0.7195, Acc: 0.7656, Recall: 0.7656, F1: 0.7652, Top1: 0.7656, Top5: 0.7656 
Epoch No. 10. Loss: 0.6192, Acc: 0.7860, Recall: 0.786