<a href="https://colab.research.google.com/github/adirmorgan/Private-Network-Inference/blob/main/SNL_implementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook implements the method from the paper "Selective Network Linearization for Efficient Private Inference":

https://arxiv.org/abs/2202.02340


In [None]:
# Mount Google Drive at /content/gdrive; later cells save and load checkpoints under this path.
from google.colab import drive
drive.mount('/content/gdrive')


Mounted at /content/gdrive


In [None]:
# Install the PyTorch stack into the runtime.
# NOTE(review): versions are not pinned, so a fresh runtime may resolve different releases.
!pip install torch torchvision
import torch
import torch.nn as nn
import torchvision
# Report the library versions in use and whether a CUDA device is visible.
print(torch.__version__)
print(torchvision.__version__)
print(torch.cuda.is_available())


Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

# Define the network model

In [None]:
import torch.nn.functional as F

class LearnableAlpha(nn.Module):
    """Element-wise blend between ReLU and the identity.

    One coefficient ``alpha`` is stored per activation position in a parameter
    of shape ``(1, out_channel, layer_dim, layer_dim)``, initialised to ones.
    The forward pass returns ``relu(x) * alpha + (1 - alpha) * x``: with
    ``alpha == 1`` the unit is a plain ReLU, with ``alpha == 0`` it passes the
    input through unchanged.

    Args:
        out_channel: number of channels of the activation this unit is applied to.
        layer_dim: height/width of the (square) activation map.
    """

    def __init__(self, out_channel, layer_dim):
        super().__init__()
        initial_alphas = torch.ones(1, out_channel, layer_dim, layer_dim)
        self.alphas = nn.Parameter(initial_alphas, requires_grad=True)

    def forward(self, x):
        """Blend ``relu(x)`` and ``x`` using the stored alpha map."""
        # The alpha map has a leading dimension of 1; expand it to x's shape.
        alpha = self.alphas.expand_as(x)
        rectified = F.relu(x)
        return rectified * alpha + (1 - alpha) * x

In [None]:
class BasicBlock(nn.Module):
    """Residual block with two 3x3 convolutions and LearnableAlpha activations.

    Main path: conv3x3 -> BN -> LearnableAlpha -> conv3x3 -> BN.  The shortcut
    branch is added to the main path and the sum goes through a second
    LearnableAlpha.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by both convolutions.
        layer_dim: side length handed to both LearnableAlpha units; their alpha
            maps are expanded to the activation's shape, so it has to match the
            spatial size of this block's feature maps.
        stride: stride of the first convolution and of the projection shortcut.
    """

    def __init__(self, in_channels, out_channels, layer_dim, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=3, stride=stride, padding=1, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.prelu1 = LearnableAlpha(out_channels, layer_dim)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels,
            kernel_size=3, stride=1, padding=1, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.prelu2 = LearnableAlpha(out_channels, layer_dim)

        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            # The main path changes the channel count and/or the spatial size,
            # so the input is projected with a strided 1x1 convolution (plus
            # BatchNorm) to give it the shape required by the addition in forward().
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            # An empty Sequential returns its input unchanged: identity shortcut.
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Apply the main path, add the shortcut branch, then the final activation."""
        out = self.prelu1(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        return self.prelu2(out)

class ResNet18_SNL(nn.Module):
    """ResNet-18 style classifier assembled from BasicBlock stages.

    ``layer_dim`` starts at 32 and is halved by every stride-2 block, so the
    LearnableAlpha maps inside the blocks are sized for 32x32 inputs.

    Args:
        num_classes: size of the final linear layer's output.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Running bookkeeping that _make_layer reads and updates while building stages.
        self.in_channels = 64
        self.layer_dim = 32

        # Stem.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        # NOTE(review): `relu` and `maxpool` are registered here but forward()
        # never calls them, so bn1's output feeds layer1 directly.
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Four stages of two BasicBlocks each: (output channels, stride of first block).
        stage_plan = ((64, 1), (128, 2), (256, 2), (512, 2))
        for stage_idx, (width, first_stride) in enumerate(stage_plan, start=1):
            stage = self._make_layer(BasicBlock, width, 2, stride=first_stride)
            setattr(self, f"layer{stage_idx}", stage)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))    # argument is the output size
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """Build one stage as an nn.Sequential of `block` instances.

        The first block uses `stride`; the remaining ones use stride 1.  Each
        stride-2 block halves ``self.layer_dim`` before it is constructed, and
        ``self.in_channels`` is advanced to `out_channels` after every block.
        """
        per_block_strides = [stride, *([1] * (num_blocks - 1))]
        stage_blocks = []
        for block_stride in per_block_strides:
            if block_stride == 2:
                self.layer_dim = self.layer_dim // 2
            stage_blocks.append(
                block(self.in_channels, out_channels, self.layer_dim, block_stride)
            )
            self.in_channels = out_channels
        # nn.Sequential chains the blocks' forward methods in order.
        return nn.Sequential(*stage_blocks)

    def forward(self, x):
        """Stem conv + BN, the four stages, global average pooling, then the linear head."""
        out = self.bn1(self.conv1(x))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)

        out = self.avgpool(out)
        out = out.view(out.size(0), -1)    # flatten to (batch, features)
        return self.fc(out)



In [None]:
# Build a freshly initialised network and show its module tree.
model = ResNet18_SNL()
print(model)
# Print the name and shape of every parameter in the model
for name, param in model.named_parameters():
    print(f"{name}: {param.shape}")

ResNet18_SNL(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (prelu1): LearnableAlpha()
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (prelu2): LearnableAlpha()
      (shortcut): Sequential()
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affi

# Load Training data and split to train & val sets

In [None]:
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet18
from torch.utils.data import DataLoader, random_split

# Data Preparation
# Training-time preprocessing: random 32x32 crop after 4-pixel padding, random
# horizontal flip, tensor conversion and per-channel normalisation.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616]),
])

# Evaluation-time preprocessing: tensor conversion and normalisation only.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616]),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)

train_size = int(0.8 * len(trainset))  # 80% for training
val_size = len(trainset) - train_size  # Remaining 20% for validation

# Seed the split so the train/validation partition is the same after every
# kernel restart. Otherwise a checkpoint reloaded in a later session could be
# validated on samples it was trained on.
split_generator = torch.Generator().manual_seed(0)
train_dataset, val_split = random_split(trainset, [train_size, val_size], generator=split_generator)

# A Subset applies its parent dataset's transform, so a validation slice of
# `trainset` would be randomly cropped/flipped. Re-wrap the same indices over a
# copy of the training data that uses the evaluation transform instead.
trainset_eval = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=transform_test)
val_dataset = torch.utils.data.Subset(trainset_eval, val_split.indices)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)


100%|██████████| 170M/170M [00:13<00:00, 13.0MB/s]


# Training process preliminaries


In [None]:
# Hyper-parameters for the pre-training phase (SGD with a MultiStepLR schedule).
pretrain_params = {'epochs': 200,
                   'batch_size': 128,
                   'lr': 0.1,
                   'momentum': 0.9,
                   'weight_decay': 5e-4,  # l2 weights regularization factor
                   'lr_milestones': [100, 150],
                   'gamma': 0.1,  # LR is multiplied by gamma on schedule.
                   }

# Hyper-parameters for the simultaneous weight/alpha training phase.
# NOTE(review): the training cell visible in this notebook reads only 'epochs',
# 'batch_size' and 'lr' from this dict - confirm whether the other keys are still needed.
SNL_training_params = {'epochs': 2000,
                      'batch_size': 128,
                      'lr': 1e-3,
                      'momentum': 0.9,
                      'weight_decay': 5e-4,  # l2 weights regularization factor
                      'lr_step_size': 30,  # How often to decrease learning by gamma
                      'lr_milestones': [80, 120],
                      'gamma': 0.1,  # LR is multiplied by gamma on schedule.
                      }

# Settings of the linearization procedure itself:
#   relu_budget          - the SNL loop stops once the ReLU count is <= this value
#   relu_lin_threshold   - an alpha entry counts as a ReLU when it is above this value
#   initial_lasso_weight - starting coefficient of the L1 penalty on the alphas
#   lasso_weight_factor  - multiplier applied to the lasso weight when it is increased
SNL_params = {'relu_budget': 10000,
              'relu_lin_threshold': 1e-2,
              'initial_lasso_weight': 1e-5,
              'lasso_weight_factor': 1.1
              }

# Hyper-parameters for the fine-tuning phase.
# NOTE(review): the fine-tuning cell visible in this notebook uses CosineAnnealingLR
# and does not read 'lr_step_size' or 'gamma' - confirm whether they are still needed.
fine_tuning_params = {'epochs': 100,
                      'batch_size': 128,
                      'lr': 1e-3,
                      'lr_step_size': 30,  # How often to decrease learning by gamma
                      'momentum': 0.9,
                      'weight_decay': 5e-4,  # l2 weights regularization factor
                      'gamma': 0.1,  # LR is multiplied by gamma on schedule.
                      }



In [None]:
import torch
from torch.nn import CrossEntropyLoss
from torch.optim import SGD, Optimizer, Adam
from torch.optim.lr_scheduler import StepLR, MultiStepLR
import time


In [None]:
class AverageMeter(object):
    """Keeps the latest value plus a sample-weighted running mean.

    Attributes (all reset to 0 by ``reset``): ``val`` is the most recent value,
    ``sum`` the weighted total, ``count`` the number of samples seen and
    ``avg`` equals ``sum / count``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Return every statistic to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record `val` as the mean over `n` new samples and refresh ``avg``."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count

In [None]:
def accuracy(output, target, topk=(1,)):
    """Top-k accuracy, in percent, for every k in `topk`.

    Args:
        output: score tensor; the top-k search runs along dimension 1
            (assumes a 2-D ``(batch, classes)`` layout - the result is transposed with ``.t()``).
        target: tensor of ground-truth class indices, one per sample.
        topk: iterable of k values to evaluate.

    Returns:
        A list with one single-element tensor per k, holding the percentage of
        samples whose target is among the k highest-scoring classes.
    """
    with torch.no_grad():
        largest_k = max(topk)
        num_samples = target.size(0)

        # Indices of the `largest_k` best-scoring classes, transposed so that
        # row r holds every sample's (r+1)-th best guess.
        top_indices = output.topk(largest_k, dim=1, largest=True, sorted=True)[1].t()
        hits = top_indices.eq(target.view(1, -1).expand_as(top_indices))

        percent_per_hit = 100.0 / num_samples
        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(percent_per_hit)
            for k in topk
        ]

In [None]:
def train(loader: DataLoader, model: torch.nn.Module, criterion, optimizer: Optimizer,
          epoch: int, device, print_freq=100, display=True):
    """Train `model` for one pass over `loader`.

    Args:
        loader: iterable of ``(inputs, targets)`` batches.
        model: network to optimise; it is switched to train mode.
        criterion: loss callable applied to ``(outputs, targets)``.
        optimizer: optimizer stepped once per batch.
        epoch: epoch index, used only in the progress line.
        device: device that every batch is moved to.
        print_freq: a progress line is printed every `print_freq` batches.
        display: progress lines are printed only when this equals True.

    Returns:
        ``(mean loss, mean top-1 accuracy, mean top-5 accuracy)``, each
        weighted by batch size.
    """
    meters = {key: AverageMeter()
              for key in ('batch_time', 'data_time', 'loss', 'top1', 'top5')}
    progress_line = ('Epoch: [{0}][{1}/{2}]\t'
                     'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                     'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                     'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                     'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                     'Acc@5 {top5.val:.3f} ({top5.avg:.3f})')
    tick = time.time()

    # switch to train mode
    model.train()

    for step, (inputs, targets) in enumerate(loader):
        # time spent waiting for the batch
        meters['data_time'].update(time.time() - tick)

        inputs, targets = inputs.to(device), targets.to(device)

        # forward pass
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # bookkeeping: loss and top-1 / top-5 accuracy, weighted by batch size
        n_samples = inputs.size(0)
        acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
        meters['loss'].update(loss.item(), n_samples)
        meters['top1'].update(acc1.item(), n_samples)
        meters['top5'].update(acc5.item(), n_samples)

        # backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # wall-clock time of the whole iteration
        meters['batch_time'].update(time.time() - tick)
        tick = time.time()

        if step % print_freq != 0 or display != True:
            continue
        print(progress_line.format(epoch, step, len(loader), **meters))

    return (meters['loss'].avg, meters['top1'].avg, meters['top5'].avg)

In [None]:
def test(loader: DataLoader, model: torch.nn.Module, criterion, device, print_freq=100, display=False):
    """Evaluate `model` on `loader` without computing gradients.

    Args:
        loader: iterable of ``(inputs, targets)`` batches.
        model: network to evaluate; it is switched to eval mode.
        criterion: loss callable applied to ``(outputs, targets)``.
        device: device that every batch is moved to.
        print_freq: when `display` is set, a progress line is printed every
            `print_freq` batches.  Defaults to 100, matching ``train``, so
            callers that do not want per-batch output need not pass a value.
        display: whether to print per-batch progress lines.

    Returns:
        ``(mean loss, mean top-1 accuracy, mean top-5 accuracy)``, each
        weighted by batch size.  A one-line summary is always printed.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    end = time.time()

    # switch to eval mode
    model.eval()

    with torch.no_grad():
        for i, (inputs, targets) in enumerate(loader):
            # measure data loading time
            data_time.update(time.time() - end)

            inputs = inputs.to(device)
            targets = targets.to(device)

            # compute output
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(acc1.item(), inputs.size(0))
            top5.update(acc5.item(), inputs.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # `display` is checked first so print_freq is only used when logging is on.
            if display and i % print_freq == 0:
                print('Test : [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                    i, len(loader), batch_time=batch_time,
                    data_time=data_time, loss=losses, top1=top1, top5=top5))

        # Summary line, printed regardless of `display`.
        print(
            'Test Loss  ({loss.avg:.4f})\t'
            'Test Acc@1 ({top1.avg:.3f})\t'
            'Test Acc@5 ({top5.avg:.3f})'.format(
        loss=losses, top1=top1, top5=top5)
        )

        return (losses.avg, top1.avg, top5.avg)

# Pretraining phase

In [None]:
# Loaders for the pre-training phase; the batch size comes from pretrain_params.
trainloader = DataLoader(
    train_dataset,  # Training dataset
    batch_size=pretrain_params['batch_size'],  # Number of samples per batch
    shuffle=True,  # Shuffle the training data
    num_workers=2  # Number of subprocesses for data loading
)

valloader = DataLoader(
    val_dataset,  # Validation dataset
    batch_size=pretrain_params['batch_size'],  # Number of samples per batch
    shuffle=False,  # No need to shuffle validation data
    num_workers=2  # Number of subprocesses for data loading
)


In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = ResNet18_SNL()

# Freeze every alpha parameter for pre-training. They are initialised to ones
# (see LearnableAlpha), so each unit behaves as a plain ReLU during this phase.
for name, param in model.named_parameters():
    if 'alpha' in name:
        param.requires_grad = False
        # Report name and shape only; printing the tensors themselves floods
        # the cell output with ones.
        print(f"frozen {name}: {tuple(param.shape)}")
model.to(device)


tensor([[[[1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          ...,
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.]],

         [[1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          ...,
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.]],

         [[1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          ...,
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.]],

         ...,

         [[1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          ...,
          [1., 1., 1.,  ..., 1., 1., 

ResNet18_SNL(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (prelu1): LearnableAlpha()
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (prelu2): LearnableAlpha()
      (shortcut): Sequential()
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affi

In [None]:
criterion = nn.CrossEntropyLoss().to(device)
optimizer = SGD(
    model.parameters(),
    lr=pretrain_params['lr'],
    momentum=pretrain_params['momentum'],
    weight_decay=pretrain_params['weight_decay'],
)
# The decay factor is read from the config dict, so editing pretrain_params['gamma']
# actually takes effect (it was previously a hard-coded 0.1 here).
scheduler = MultiStepLR(
    optimizer,
    milestones=pretrain_params['lr_milestones'],
    gamma=pretrain_params['gamma'],
)

best_top1 = 0

for epoch in range(pretrain_params['epochs']):

    # training
    train(trainloader, model, criterion, optimizer, epoch, device)

    # validation
    # NOTE(review): cur_step lands in test()'s `print_freq` slot; it has no effect
    # while `display` stays at its default of False.
    cur_step = (epoch+1) * len(trainloader)
    _, top1, _ = test(valloader, model, criterion, device, cur_step)
    scheduler.step()

    # save a checkpoint whenever the validation top-1 accuracy improves
    is_best = best_top1 < top1
    if is_best:
        best_top1 = top1
        # NOTE(review): the file name says "10epochs" while a later cell loads
        # "..._pretrained_70epochs.pth" - confirm which checkpoint is intended.
        model_path = '/content/gdrive/My Drive/SNL_results/ResNet_18_SNL_pretrained_10epochs.pth'
        torch.save(model.state_dict(), model_path)

    print("")

print("Best model's validation acc: {:.4%}".format(best_top1 / 100))


Epoch: [0][0/313]	Time 1.698 (1.698)	Data 0.269 (0.269)	Loss 2.4033 (2.4033)	Acc@1 6.250 (6.250)	Acc@5 50.000 (50.000)
Epoch: [0][100/313]	Time 0.124 (0.128)	Data 0.002 (0.005)	Loss 1.7899 (2.5436)	Acc@1 28.906 (19.624)	Acc@5 89.062 (71.419)
Epoch: [0][200/313]	Time 0.112 (0.120)	Data 0.002 (0.004)	Loss 1.8270 (2.2022)	Acc@1 35.938 (24.351)	Acc@5 86.719 (77.903)
Epoch: [0][300/313]	Time 0.113 (0.118)	Data 0.002 (0.004)	Loss 1.5967 (2.0304)	Acc@1 46.094 (28.782)	Acc@5 90.625 (81.227)
Test Loss  (1.7328)	Test Acc@1 (37.110)	Test Acc@5 (87.360)

Epoch: [1][0/313]	Time 0.383 (0.383)	Data 0.319 (0.319)	Loss 1.5365 (1.5365)	Acc@1 42.188 (42.188)	Acc@5 91.406 (91.406)
Epoch: [1][100/313]	Time 0.119 (0.118)	Data 0.012 (0.006)	Loss 1.5324 (1.5983)	Acc@1 46.875 (40.811)	Acc@5 90.625 (90.138)
Epoch: [1][200/313]	Time 0.114 (0.117)	Data 0.002 (0.005)	Loss 1.6092 (1.5439)	Acc@1 42.969 (43.132)	Acc@5 89.844 (91.002)
Epoch: [1][300/313]	Time 0.117 (0.117)	Data 0.002 (0.005)	Loss 1.4653 (1.4943)	Acc@1

In [None]:
# Make sure Drive is mounted, then save the weights `model` currently holds.
# NOTE(review): this writes the model's current state, which is not necessarily
# the best-validation checkpoint saved inside the training loop.
from google.colab import drive
drive.mount('/content/gdrive')

model_path = '/content/gdrive/My Drive/ResNet_18_full.pth'
torch.save(model.state_dict(), model_path)


Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [None]:
# Load the pre-trained weights that serve as the starting point for linearization.
# NOTE(review): the pre-training cell writes "..._pretrained_10epochs.pth"; confirm
# that this "70epochs" file is the intended checkpoint.
baseline_model_path = '/content/gdrive/My Drive/SNL_results/ResNet_18_SNL_pretrained_70epochs.pth'

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

baseline_model = ResNet18_SNL()
# map_location places the loaded tensors on `device`; weights_only restricts unpickling to tensor data.
baseline_model.load_state_dict(torch.load(baseline_model_path, weights_only=True, map_location=device))
baseline_model.to(device=device)

ResNet18_SNL(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (prelu1): LearnableAlpha()
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (prelu2): LearnableAlpha()
      (shortcut): Sequential()
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affi

# Simultaneous training and linearization


In [None]:
def relu_counting(net, threshold):
    """Count the alpha entries of `net` that are strictly greater than `threshold`.

    Only parameters whose name contains ``'alpha'`` are inspected.

    Returns:
        The total as a tensor scalar, or the plain integer 0 when `net` has
        no such parameter.
    """
    return sum(
        (param.data > threshold).sum()
        for name, param in net.named_parameters()
        if 'alpha' in name
    )

In [None]:
def simultaneus_training_lasso_loss(loader, model, criterion, optimizer, lasso_coef, device=None):
    """Train `model` for one epoch with an L1 (lasso) penalty on its alpha parameters.

    The per-batch loss is ``criterion(outputs, targets) + lasso_coef * reg``,
    where ``reg`` is the sum of the L1 norms of every parameter whose name
    contains ``'alpha'``.

    Args:
        loader: iterable of ``(inputs, targets)`` batches.
        model: network to optimise; it is switched to train mode.
        criterion: loss callable applied to ``(outputs, targets)``.
        optimizer: optimizer stepped once per batch.
        lasso_coef: weight of the L1 penalty.
        device: device the batches are moved to.  Defaults to the device of the
            model's first parameter, so the function does not rely on a
            notebook-global ``device`` variable.

    Returns:
        The mean total loss (penalty included) over the epoch, weighted by batch size.
    """
    if device is None:
        device = next(model.parameters()).device

    losses = AverageMeter()

    # switch to train mode
    model.train()

    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        # L1 penalty over all alpha parameters (0 when the model has none).
        reg_loss = sum(
            torch.norm(param, p=1)
            for name, param in model.named_parameters()
            if 'alpha' in name
        )

        # compute output
        outputs = model(inputs)
        loss = criterion(outputs, targets) + lasso_coef * reg_loss

        losses.update(loss.item(), inputs.size(0))

        # compute gradient and do an optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return losses.avg

In [None]:
import copy
# Work on a copy so `baseline_model` stays available as the untouched pre-trained reference.
model = copy.deepcopy(baseline_model)

# Enable alpha training
for name, param in model.named_parameters():
    if 'alpha' in name:
        param.requires_grad = True

model.to(device)


ResNet18_SNL(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (prelu1): LearnableAlpha()
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (prelu2): LearnableAlpha()
      (shortcut): Sequential()
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affi

In [None]:
# Starting point for linearization: number of alpha entries above the threshold.
print(f"Initial Number of ReLUs: {relu_counting(model, SNL_params['relu_lin_threshold'])}")

Initial Number of ReLUs: 491520


In [None]:
# Rebuild the loaders for the linearization phase; the batch size comes from SNL_training_params.
trainloader = DataLoader(
    train_dataset,  # Training dataset
    batch_size=SNL_training_params['batch_size'],  # Number of samples per batch
    shuffle=True,  # Shuffle the training data
    num_workers=2  # Number of subprocesses for data loading
)

valloader = DataLoader(
    val_dataset,  # Validation dataset
    batch_size=SNL_training_params['batch_size'],  # Number of samples per batch
    shuffle=False,  # No need to shuffle validation data
    num_workers=2  # Number of subprocesses for data loading
)


In [None]:
criterion = nn.CrossEntropyLoss().to(device)
# Adam receives every model parameter: the network weights and the alphas.
optimizer = Adam(model.parameters(), lr=SNL_training_params['lr'])

# ReLU count before any linearization; used for the relative count printed at the end.
init_relu_count = relu_counting(model, SNL_params['relu_lin_threshold'])
lowest_relu_count = init_relu_count

lasso_weight = SNL_params['initial_lasso_weight']

for epoch in range(SNL_training_params['epochs']):
    # Simultaneous training of the weights and the alphas with
    # cross-entropy + lasso_weight * L1(alphas) as the loss.
    train_loss = simultaneus_training_lasso_loss(trainloader, model, criterion,
                                                 optimizer, lasso_weight)

    # validation
    # NOTE(review): cur_step lands in test()'s `print_freq` slot; it has no effect
    # while `display` stays at its default of False.
    cur_step = (epoch+1) * len(trainloader)
    _, top1, _ = test(valloader, model, criterion, device, cur_step)

    # counting ReLU in the neural network by using threshold.
    cur_relu_count = relu_counting(model, SNL_params['relu_lin_threshold'])
    print(f"Epoch: {epoch}, Test Accuracy: {top1:.4f}, ReLU Count: {cur_relu_count}, Lasso weight: {lasso_weight:.6f}")


    # Lasso weight increment: a new lowest count is recorded as-is; otherwise
    # (no progress this epoch) the penalty is scaled up, but only from epoch 5 onwards.
    if cur_relu_count < lowest_relu_count:
        lowest_relu_count = cur_relu_count

    elif cur_relu_count >= lowest_relu_count and epoch >= 5:
        lasso_weight *= SNL_params['lasso_weight_factor']

    # Stop as soon as the ReLU budget is met.
    if cur_relu_count <= SNL_params['relu_budget']:
        print(f"Achieved relu budget after epoch {epoch}")
        break

print(f"After SNL Algorithm, the current ReLU Count: {cur_relu_count}, relative count: {cur_relu_count / init_relu_count:.6f}")



Test Loss  (0.5574)	Test Acc@1 (83.160)	Test Acc@5 (99.080)
Epoch: 0, Test Accuracy: 83.1600, ReLU Count: 491520, Lasso weight: 0.000010
Test Loss  (0.5403)	Test Acc@1 (84.040)	Test Acc@5 (99.270)
Epoch: 1, Test Accuracy: 84.0400, ReLU Count: 394543, Lasso weight: 0.000010
Test Loss  (0.5922)	Test Acc@1 (83.940)	Test Acc@5 (98.910)
Epoch: 2, Test Accuracy: 83.9400, ReLU Count: 356833, Lasso weight: 0.000010
Test Loss  (0.6059)	Test Acc@1 (83.980)	Test Acc@5 (98.840)
Epoch: 3, Test Accuracy: 83.9800, ReLU Count: 328283, Lasso weight: 0.000010
Test Loss  (0.6766)	Test Acc@1 (83.370)	Test Acc@5 (98.800)
Epoch: 4, Test Accuracy: 83.3700, ReLU Count: 302649, Lasso weight: 0.000010
Test Loss  (0.6646)	Test Acc@1 (83.460)	Test Acc@5 (98.930)
Epoch: 5, Test Accuracy: 83.4600, ReLU Count: 279142, Lasso weight: 0.000010
Test Loss  (0.6854)	Test Acc@1 (83.570)	Test Acc@5 (98.990)
Epoch: 6, Test Accuracy: 83.5700, ReLU Count: 258417, Lasso weight: 0.000010
Test Loss  (0.6855)	Test Acc@1 (83.880)	T

In [None]:
# Persist the model as it stands after the linearization loop (alphas not yet rounded).
model_path = '/content/gdrive/My Drive/SNL_results/ResNet_18_linearized_10K.pth'
torch.save(model.state_dict(), model_path)


# Fine-tuning training phase


In [None]:
# Reload the linearized checkpoint written by the cell above, so fine-tuning can
# start from a fresh kernel.
linearized_model_path = '/content/gdrive/My Drive/SNL_results/ResNet_18_linearized_10K.pth'

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

linearized_model = ResNet18_SNL()
# map_location places the loaded tensors on `device`; weights_only restricts unpickling to tensor data.
linearized_model.load_state_dict(torch.load(linearized_model_path, weights_only=True, map_location=device))
linearized_model.to(device=device)

ResNet18_SNL(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (prelu1): LearnableAlpha()
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (prelu2): LearnableAlpha()
      (shortcut): Sequential()
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affi

In [None]:
import copy

# Work on a copy so `linearized_model` keeps its learned, continuous alphas.
rounded_model = copy.deepcopy(linearized_model)
rounded_model.to(device)

# Round and Freeze alpha params
# Alphas above the threshold become exactly 1.0 and all others 0.0; in
# LearnableAlpha.forward that is a plain ReLU and the identity respectively.
# Setting requires_grad to False excludes them from gradient computation afterwards.
for name, param in rounded_model.named_parameters():
    if 'alpha' in name:
        param.data = (param.data > SNL_params['relu_lin_threshold']).float()
        param.requires_grad = False


In [None]:
# Rebuild the loaders for the fine-tuning phase; the batch size comes from fine_tuning_params.
trainloader = DataLoader(
    train_dataset,  # Training dataset
    batch_size=fine_tuning_params['batch_size'],  # Number of samples per batch
    shuffle=True,  # Shuffle the training data
    num_workers=2  # Number of subprocesses for data loading
)

valloader = DataLoader(
    val_dataset,  # Validation dataset
    batch_size=fine_tuning_params['batch_size'],  # Number of samples per batch
    shuffle=False,  # No need to shuffle validation data
    num_workers=2  # Number of subprocesses for data loading
)


In [None]:
# Fine-tune a copy so `rounded_model` remains available for comparison.
finetuned_model = copy.deepcopy(rounded_model)
finetuned_model.to(device)

# Sanity check: split parameter names by whether they will receive gradients.
frozen, unfrozen = [], []
for name, param in finetuned_model.named_parameters():
    if param.requires_grad:
        unfrozen.append(name)
    else:
        frozen.append(name)

frozen_str = '\n'.join(frozen)
print(f"---Frozen Layers---\n{frozen_str}")
unfrozen_str = '\n'.join(unfrozen)
print(f"---Unfrozen Layers---\n{unfrozen_str}")

---Frozen Layers---
layer1.0.prelu1.alphas
layer1.0.prelu2.alphas
layer1.1.prelu1.alphas
layer1.1.prelu2.alphas
layer2.0.prelu1.alphas
layer2.0.prelu2.alphas
layer2.1.prelu1.alphas
layer2.1.prelu2.alphas
layer3.0.prelu1.alphas
layer3.0.prelu2.alphas
layer3.1.prelu1.alphas
layer3.1.prelu2.alphas
layer4.0.prelu1.alphas
layer4.0.prelu2.alphas
layer4.1.prelu1.alphas
layer4.1.prelu2.alphas
---Unfrozen Layers---
conv1.weight
bn1.weight
bn1.bias
layer1.0.conv1.weight
layer1.0.bn1.weight
layer1.0.bn1.bias
layer1.0.conv2.weight
layer1.0.bn2.weight
layer1.0.bn2.bias
layer1.1.conv1.weight
layer1.1.bn1.weight
layer1.1.bn1.bias
layer1.1.conv2.weight
layer1.1.bn2.weight
layer1.1.bn2.bias
layer2.0.conv1.weight
layer2.0.bn1.weight
layer2.0.bn1.bias
layer2.0.conv2.weight
layer2.0.bn2.weight
layer2.0.bn2.bias
layer2.0.shortcut.0.weight
layer2.0.shortcut.1.weight
layer2.0.shortcut.1.bias
layer2.1.conv1.weight
layer2.1.bn1.weight
layer2.1.bn1.bias
layer2.1.conv2.weight
layer2.1.bn2.weight
layer2.1.bn2.bia

In [None]:
criterion = nn.CrossEntropyLoss().to(device)
# The optimizer is handed every parameter; the alphas were frozen earlier
# (requires_grad=False), so they receive no gradient during training.
optimizer = SGD(finetuned_model.parameters(), lr=fine_tuning_params['lr'], momentum=fine_tuning_params['momentum'], weight_decay=fine_tuning_params['weight_decay'])
# Cosine schedule spanning the whole fine-tuning run (T_max = number of epochs).
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, fine_tuning_params['epochs'])

best_top1 = 0

for epoch in range(fine_tuning_params['epochs']):
  # training
  train(trainloader, finetuned_model, criterion, optimizer, epoch, device)

  # validation
  # NOTE(review): cur_step lands in test()'s `print_freq` slot; it has no effect
  # while `display` stays at its default of False.
  cur_step = (epoch+1) * len(trainloader)
  _, top1, _ = test(valloader, finetuned_model, criterion, device, cur_step)
  scheduler.step()

  # save: track the best validation top-1 accuracy seen so far
  if best_top1 < top1:
      best_top1 = top1
      is_best = True
  else:
      is_best = False

  # Only a new best model overwrites the checkpoint on Drive.
  if is_best:
      model_path = '/content/gdrive/My Drive/SNL_results/ResNet_18_SNL_linearized_10K_finetuned.pth'
      torch.save(finetuned_model.state_dict(), model_path)

  print("")

print("Best model's validation acc: {:.4%}".format(best_top1 / 100))

Epoch: [0][0/313]	Time 0.272 (0.272)	Data 0.156 (0.156)	Loss 0.0891 (0.0891)	Acc@1 97.656 (97.656)	Acc@5 100.000 (100.000)
Epoch: [0][100/313]	Time 0.124 (0.127)	Data 0.002 (0.005)	Loss 0.1417 (0.1902)	Acc@1 97.656 (94.462)	Acc@5 99.219 (99.714)
Epoch: [0][200/313]	Time 0.127 (0.127)	Data 0.002 (0.004)	Loss 0.2055 (0.1870)	Acc@1 94.531 (94.652)	Acc@5 100.000 (99.724)
Epoch: [0][300/313]	Time 0.128 (0.127)	Data 0.002 (0.004)	Loss 0.0765 (0.1922)	Acc@1 97.656 (94.544)	Acc@5 100.000 (99.702)
Test Loss  (0.2008)	Test Acc@1 (94.400)	Test Acc@5 (99.770)

Epoch: [1][0/313]	Time 0.221 (0.221)	Data 0.162 (0.162)	Loss 0.1651 (0.1651)	Acc@1 92.969 (92.969)	Acc@5 99.219 (99.219)
Epoch: [1][100/313]	Time 0.128 (0.129)	Data 0.002 (0.005)	Loss 0.2929 (0.1871)	Acc@1 90.625 (94.740)	Acc@5 99.219 (99.691)
Epoch: [1][200/313]	Time 0.130 (0.129)	Data 0.002 (0.004)	Loss 0.1322 (0.1829)	Acc@1 95.312 (94.819)	Acc@5 100.000 (99.708)
Epoch: [1][300/313]	Time 0.129 (0.129)	Data 0.002 (0.004)	Loss 0.1508 (0.1865

KeyboardInterrupt: 

# Modified Lasso Training

In [None]:
def simultaneus_training_lasso_loss_modified(loader, model, criterion, optimizer, lasso_coef):
  """Run one training pass over `loader` with an L1 (lasso) penalty on the alpha parameters.

  For every batch the optimized objective is
  `criterion(model(inputs), targets) + lasso_coef * reg_loss`, where `reg_loss`
  is the sum of the L1 norms of every parameter whose name contains 'alpha'.

  Batches are moved to the notebook-level `device` before the forward pass.

  Args:
      loader: iterable yielding `(inputs, targets)` batches.
      model: network to train; it is switched to train mode.
      criterion: loss callable applied to `(outputs, targets)`.
      optimizer: optimizer whose `zero_grad`/`step` are called once per batch.
      lasso_coef: multiplier applied to the L1 penalty.

  Returns:
      `losses.avg` of the AverageMeter that was updated with each batch's
      total loss (data term + penalty) and that batch's size.
  """
  losses = AverageMeter()

  # switch to train mode
  model.train()

  # The optimizer updates parameters in place, so the Parameter objects found
  # here stay valid for the whole pass and need not be looked up per batch.
  alpha_params = [param for name, param in model.named_parameters() if 'alpha' in name]

  for inputs, targets in loader:
      inputs = inputs.to(device)
      targets = targets.to(device)

      # L1 penalty over all alpha tensors (plain 0 when the model has none).
      reg_loss = sum(torch.norm(param, p=1) for param in alpha_params)

      # compute output
      outputs = model(inputs)
      loss = criterion(outputs, targets) + lasso_coef * reg_loss

      losses.update(loss.item(), inputs.size(0))

      # compute gradient and do SGD step
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

  return losses.avg

# Testing

In [None]:
# Rebuild the SNL ResNet-18 and restore the fine-tuned weights stored on Drive.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
finetuned_model_path = '/content/gdrive/My Drive/SNL_results/ResNet_18_SNL_linearized_10K_finetuned.pth'

finetuned_model = ResNet18_SNL()
finetuned_model.load_state_dict(
    # weights_only=True restricts unpickling to tensor data; map_location places
    # the loaded tensors on whichever device was selected above.
    torch.load(finetuned_model_path, map_location=device, weights_only=True)
)
# Last expression of the cell: moves the model and displays its structure.
finetuned_model.to(device=device)

ResNet18_SNL(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (prelu1): LearnableAlpha()
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (prelu2): LearnableAlpha()
      (shortcut): Sequential()
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affi

In [None]:
# CIFAR-10 test split (train=False), stored under ./data and downloaded if missing;
# samples are preprocessed with the notebook's `transform`.
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Batches of 32 test images, shuffled, loaded by two worker processes.
testloader = DataLoader(test_dataset, batch_size=32, shuffle=True, num_workers=2)

# Sanity check: look at a single batch, then stop iterating.
for images, labels in testloader:
    print(f"Batch size: {images.size(0)}, Image shape: {images.size()[1:]}, Labels: {labels}")
    break

Files already downloaded and verified
Batch size: 32, Image shape: torch.Size([3, 32, 32]), Labels: tensor([5, 9, 4, 7, 3, 1, 5, 5, 6, 9, 4, 5, 4, 6, 6, 8, 4, 0, 0, 0, 5, 3, 4, 5,
        0, 8, 5, 7, 9, 3, 6, 8])


In [None]:
def eval_model(loader, model, criterion, run_name):
  """Evaluate `model` on `loader` and print its average loss and top-1 accuracy.

  The reported loss is a per-sample average: each batch loss is weighted by the
  number of samples in that batch, so a smaller final batch does not skew the
  result. This assumes `criterion` returns the mean loss over the batch (the
  default reduction of `nn.CrossEntropyLoss`).

  Batches are moved to the notebook-level `device`. The model is put in eval
  mode and is left in that mode afterwards.

  Args:
      loader: iterable of batches where `batch[0]` are inputs and `batch[1]` labels.
      model: network to evaluate.
      criterion: loss callable applied to `(outputs, labels)`.
      run_name: label used in the printed line (e.g. 'train', 'val', 'test').

  Returns:
      None. The metrics are only printed.

  Raises:
      ZeroDivisionError: if `loader` yields no samples.
  """
  model.eval()

  total = 0
  correct = 0
  loss_sum = 0.0

  with torch.no_grad():
      for data in loader:  # Iterate over the data loader
          inputs, labels = data[0].to(device), data[1].to(device)
          batch_size = labels.size(0)

          outputs = model(inputs)  # Forward pass

          # Undo the per-batch mean so the final average is over samples,
          # not over batches.
          loss_sum += criterion(outputs, labels).item() * batch_size

          # Get predictions (class with the highest score)
          _, predicted = torch.max(outputs, 1)

          # Update the total number of samples and correct predictions
          total += batch_size
          correct += (predicted == labels).sum().item()

  # Calculate average loss and accuracy over all samples seen
  avg_loss = loss_sum / total
  acc = 100 * correct / total

  print(f'{run_name} Loss: {avg_loss:.4f}, {run_name} Accuracy: {acc:.2f}%')

criterion = nn.CrossEntropyLoss().to(device)

# The same train/val/test report can be produced for the earlier stages by
# swapping finetuned_model for baseline_model, linearized_model or rounded_model.
print("finetuned model")
for split_name, split_loader in (('train', trainloader), ('val', valloader), ('test', testloader)):
    eval_model(split_loader, finetuned_model, criterion, split_name)

finetuned model
train Loss: 0.1391, train Accuracy: 96.25%
val Loss: 0.1768, val Accuracy: 94.89%
test Loss: 0.7478, test Accuracy: 79.47%


Notice:


*   The paper uses knowledge distillation to improve performance; it is not used here. Hence I accept an accuracy gap of several percentage points in the paper's favor.
*   In addition, in most steps I cut the training (much) earlier than required in the paper.
*   Some hyperparameters are also inaccurate

