In [1]:
import os

# Detect Google Colab through the COLAB_GPU environment variable. On Colab,
# Google Drive is mounted so that files written under DRIVE_PATH outlive the
# runtime; elsewhere the current working directory is used.
IN_COLAB = 'COLAB_GPU' in os.environ

if IN_COLAB:
    # NOTE: the message mentions cloning, but this cell only mounts Drive.
    print("Running on Google Colab. Cloning the repository.")
    from google.colab import files, drive
    drive.mount('/content/drive')
else:
    print("Not running on Google Colab. Assuming local environment.")

# Directory that receives the checkpoints written by the training loop.
DRIVE_PATH = '/content/drive/My Drive/resnet18_slt_models/init_signed_constant/' if IN_COLAB else './'

# Create the directory up front: torch.save does not create missing parent
# directories, so a missing Drive folder would otherwise only surface when the
# first checkpoint is written.
os.makedirs(DRIVE_PATH, exist_ok=True)

Running on Google Colab. Cloning the repository.
Mounted at /content/drive


In [1]:
%pip install wandb

Collecting wandb
  Downloading wandb-0.16.4-py3-none-any.whl (2.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m17.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting loss-landscapes
  Downloading loss_landscapes-3.0.6-py3-none-any.whl (72 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m72.0/72.0 kB[0m [31m9.3 MB/s[0m eta [36m0:00:00[0m
Collecting GitPython!=3.1.29,>=1.0.0 (from wandb)
  Downloading GitPython-3.1.42-py3-none-any.whl (195 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m195.4/195.4 kB[0m [31m17.6 MB/s[0m eta [36m0:00:00[0m
Collecting sentry-sdk>=1.0.0 (from wandb)
  Downloading sentry_sdk-1.41.0-py2.py3-none-any.whl (258 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m258.8/258.8 kB[0m [31m21.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting docker-pycreds>=0.4.0 (from wandb)
  Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)
Collecting setproctitle

In [3]:
!wandb login --relogin

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit: 
Aborted!
Exception ignored in atexit callback: <function _Manager._atexit_setup.<locals>.<lambda> at 0x7c3f4989dab0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/wandb/sdk/wandb_manager.py", line 156, in <lambda>
    self._atexit_lambda = lambda: self._atexit_teardown()
  File "/usr/local/lib/python3.10/dist-packages/wandb/sdk/wandb_manager.py", line 165, in _atexit_teardown
    self._teardown(exit_code)
  File "/usr/local/lib/python3.10/dist-packages/wandb/sdk/wandb_manager.py", line 176, in _teardown
    result = self._service.join()
  File "/usr/local/lib/python3.10/dist-packages/wandb/sdk/service/service.py", line 263, in join
    ret = self.

In [2]:
import math
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch.autograd as autograd
# import wandb

In [3]:
class StatisticsTracker:
    """Singleton that remembers the most recently reported mask statistics.

    Every ``StatisticsTracker()`` call yields the same object, so numbers
    stored through ``update`` in one place can be read back through
    ``get_statistics`` anywhere else. Each ``update`` overwrites the previous
    values; nothing is accumulated.
    """

    # The one shared object; built lazily by the first constructor call.
    instance = None

    def __new__(cls):
        if cls.instance is not None:
            return cls.instance

        # First construction: create the shared object and seed class-level
        # defaults so get_statistics() works before any update().
        cls.instance = super().__new__(cls)
        cls.zero_percentage = 0
        cls.one_percentage = 0
        return cls.instance

    def update(self, zeros, ones, total):
        """Store ``zeros`` and ``ones`` as percentages of ``total``."""
        self.zero_percentage, self.one_percentage = (
            (zeros / total) * 100,
            (ones / total) * 100,
        )

    def get_statistics(self):
        """Return ``(zero_percentage, one_percentage)`` from the last update."""
        return self.zero_percentage, self.one_percentage


class GetSubnet(autograd.Function):
    """Binary top-k mask with a straight-through gradient.

    ``forward`` turns a score tensor into a 0/1 mask of the same shape and
    reports the resulting zero/one counts to ``StatisticsTracker``.
    ``backward`` passes the incoming gradient to ``scores`` unchanged.
    """

    @staticmethod
    def forward(ctx, scores, k):
        # Rank all scores in ascending order; the first `cutoff` entries are
        # dropped (mask 0) and the remaining ones are kept (mask 1).
        mask = scores.clone()
        ascending = scores.flatten().sort()[1]
        cutoff = int((1 - k) * scores.numel())

        # flatten() returns a view for contiguous tensors, so these in-place
        # writes are expected to land in `mask` itself.
        flat_mask = mask.flatten()
        flat_mask[ascending[:cutoff]] = 0
        flat_mask[ascending[cutoff:]] = 1

        # Publish the mask composition of this call; the tracker keeps only
        # the latest values.
        n_zero = (flat_mask == 0).sum().item()
        n_one = (flat_mask == 1).sum().item()
        StatisticsTracker().update(n_zero, n_one, flat_mask.numel())

        return mask

    @staticmethod
    def backward(ctx, g):
        # Straight-through estimator: gradient for `scores` is `g` as-is;
        # `k` is a plain number and receives no gradient.
        return g, None

class NonAffineBatchNorm(nn.BatchNorm2d):
    """``BatchNorm2d`` constructed with ``affine=False`` (no learnable scale/shift)."""

    def __init__(self, dim):
        # `dim` is the number of feature channels to normalise.
        super().__init__(num_features=dim, affine=False)

class SubnetConv(nn.Conv2d):
    """``Conv2d`` whose weights are gated by a score-derived binary mask.

    Each weight has a learnable score. On every forward pass the absolute
    scores are turned into a 0/1 mask by ``GetSubnet`` (which keeps the top
    ``prune_rate`` fraction of scores) and the convolution is run with
    ``weight * mask``.

    Args:
        *args, **kwargs: forwarded unchanged to ``nn.Conv2d``.
        prune_rate: optional keyword-only value handed to ``GetSubnet`` as
            ``k``. When omitted, the module-level ``prune_rate`` global is
            used, which is how the layer was configured originally.

    Raises:
        NameError: if ``prune_rate`` is neither passed nor defined globally.
    """

    def __init__(self, *args, prune_rate=None, **kwargs):
        super().__init__(*args, **kwargs)

        # Allocate the scores from the weight itself so shape, dtype and
        # device always match (the legacy torch.Tensor(size) constructor is
        # always CPU/float32 and ignores device=/dtype= kwargs).
        self.scores = nn.Parameter(torch.empty_like(self.weight))
        nn.init.kaiming_uniform_(self.scores, a=math.sqrt(5))

        if prune_rate is None:
            # Backward-compatible fallback to the notebook-level setting.
            try:
                prune_rate = globals()["prune_rate"]
            except KeyError:
                raise NameError(
                    "SubnetConv needs a prune rate: pass prune_rate=... or "
                    "define a module-level `prune_rate` before building the layer"
                ) from None
        self.prune_rate = prune_rate

    def set_prune_rate(self, prune_rate):
        """Replace the value passed to ``GetSubnet`` on subsequent forwards."""
        self.prune_rate = prune_rate

    @property
    def clamped_scores(self):
        """Absolute value of the scores; this is what gets ranked."""
        return self.scores.abs()

    def forward(self, x):
        """Convolve ``x`` with the masked weights."""
        subnet = GetSubnet.apply(self.clamped_scores, self.prune_rate)
        w = self.weight * subnet

        x = F.conv2d(
            x, w, self.bias, self.stride, self.padding, self.dilation, self.groups
        )
        return x

In [5]:
# ---------------------------------------------------------------------------
# Notebook-wide configuration. Later cells read these names as plain globals.
# ---------------------------------------------------------------------------

# Reproducibility and hardware.
torch.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
gpu = None

# Data loading.
data_path = "data"
batch_size = 256
test_batch_size = 1000
num_class = 10

# Optimisation (SGD, with a cosine schedule spanning `epochs`).
epochs = 200
lr = 0.1
momentum = 0.9
weight_decay = 0.0005
nesterov = False
log_interval = 10
save_model = True

# Subnetwork search. `prune_rate` is handed to GetSubnet as `k`, i.e. the
# fraction of scores whose mask entry is set to 1.
prune_rate = 0.3
# prune_rate=0.025

# Weight initialisation, read by Builder._init_conv / Builder.activation.
nonlinearity = "relu"
# init = "kaiming_normal"
init = "signed_constant"
mode = "fan_in"
scale_fan = False

# Layer types, read by get_builder and ResNetCIFAR.
conv_type = SubnetConv
bn_type = NonAffineBatchNorm
first_layer_type = None
first_layer_dense = True
last_layer_dense = True

In [6]:
# train_loader = torch.utils.data.DataLoader(datasets.MNIST(os.path.join(data_path, 'mnist'), train=True, download=True,
#                                                           transform=transforms.Compose([transforms.ToTensor(),
#                                                                                         transforms.Normalize((0.1307,), (0.3081,))])),
#                                            batch_size=batch_size, shuffle=True)

# test_loader = torch.utils.data.DataLoader(datasets.MNIST(os.path.join(data_path, 'mnist'), train=False,
#                                                          transform=transforms.Compose([transforms.ToTensor(),
#                                                                                        transforms.Normalize((0.1307,), (0.3081,))])),
#                                           batch_size=test_batch_size, shuffle=True)


# Channel-wise mean / std handed to Normalize for both splits.
normalize = transforms.Normalize(
    mean=[0.491, 0.482, 0.447], std=[0.247, 0.243, 0.262]
)

cifar_root = os.path.join(data_path, 'cifar10')

# Training images get random-crop and horizontal-flip augmentation before
# being converted to tensors and normalised.
train_transform = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]
)
# Test images are only converted and normalised.
test_transform = transforms.Compose([transforms.ToTensor(), normalize])

train_set = datasets.CIFAR10(
    root=cifar_root, train=True, download=True, transform=train_transform
)
test_set = datasets.CIFAR10(
    root=cifar_root, train=False, download=True, transform=test_transform
)

# The training loader shuffles and drops the last incomplete batch; the test
# loader keeps a fixed order and every sample.
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=2,
    pin_memory=True,
)
test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size=test_batch_size,
    shuffle=False,
    drop_last=False,
    num_workers=2,
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar10/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:07<00:00, 24256503.96it/s]


Extracting data/cifar10/cifar-10-python.tar.gz to data/cifar10
Files already downloaded and verified


In [7]:
class Builder(object):
    """Factory for the convolution, normalisation and activation layers of a model.

    The layer classes are injected at construction time. Initialisation and
    activation behaviour is controlled by the module-level globals ``init``,
    ``mode``, ``scale_fan``, ``prune_rate`` and ``nonlinearity``.

    Args:
        conv_layer: callable with an ``nn.Conv2d``-style signature used for
            every convolution except (optionally) the first one.
        bn_layer: callable taking the channel count and returning a
            normalisation layer.
        first_layer: optional callable used instead of ``conv_layer`` when a
            convolution is requested with ``first_layer=True``.
    """

    # Supported kernel sizes mapped to the padding used with them.
    _PADDING = {1: 0, 3: 1, 5: 2, 7: 3}

    def __init__(self, conv_layer, bn_layer, first_layer=None):
        self.conv_layer = conv_layer
        self.bn_layer = bn_layer
        self.first_layer = first_layer or conv_layer

    def conv(self, kernel_size, in_planes, out_planes, stride=1, first_layer=False):
        """Build and initialise a bias-free convolution.

        Returns:
            The new layer, or ``None`` when ``kernel_size`` is not one of
            1, 3, 5 or 7.
        """
        conv_layer = self.first_layer if first_layer else self.conv_layer

        if first_layer:
            print(f"==> Building first layer with {str(self.first_layer)}")

        padding = self._PADDING.get(kernel_size)
        if padding is None:
            return None

        conv = conv_layer(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        self._init_conv(conv)

        return conv

    def conv3x3(self, in_planes, out_planes, stride=1, first_layer=False):
        """3x3 convolution with padding"""
        return self.conv(3, in_planes, out_planes, stride=stride, first_layer=first_layer)

    def conv1x1(self, in_planes, out_planes, stride=1, first_layer=False):
        """1x1 convolution (no padding)"""
        return self.conv(1, in_planes, out_planes, stride=stride, first_layer=first_layer)

    def conv7x7(self, in_planes, out_planes, stride=1, first_layer=False):
        """7x7 convolution with padding"""
        return self.conv(7, in_planes, out_planes, stride=stride, first_layer=first_layer)

    def conv5x5(self, in_planes, out_planes, stride=1, first_layer=False):
        """5x5 convolution with padding"""
        return self.conv(5, in_planes, out_planes, stride=stride, first_layer=first_layer)

    def batchnorm(self, planes, last_bn=False, first_layer=False):
        """Build a normalisation layer for ``planes`` channels.

        ``last_bn`` and ``first_layer`` are accepted but not used.
        """
        return self.bn_layer(planes)

    def activation(self):
        """Return the activation module selected by the ``nonlinearity`` global.

        Raises:
            ValueError: if ``nonlinearity`` is anything other than ``"relu"``.
        """
        if nonlinearity == "relu":
            return nn.ReLU(inplace=True)
        raise ValueError(f"{nonlinearity} is not a supported nonlinearity!")

    def _fan_scaled_std(self, weight):
        """Return ``gain / sqrt(fan)`` for ``weight``, optionally rescaling the fan."""
        fan = nn.init._calculate_correct_fan(weight, mode)
        if scale_fan:
            fan = fan * (1 - prune_rate)
        gain = nn.init.calculate_gain(nonlinearity)
        return gain / math.sqrt(fan)

    def _init_conv(self, conv):
        """Initialise ``conv.weight`` in place according to the ``init`` global.

        Raises:
            ValueError: if ``init`` names an unknown scheme.
        """
        if init == "signed_constant":
            # Keep only the sign of the default initialisation; every weight
            # gets the same magnitude.
            std = self._fan_scaled_std(conv.weight)
            conv.weight.data = conv.weight.data.sign() * std

        elif init == "unsigned_constant":
            std = self._fan_scaled_std(conv.weight)
            conv.weight.data = torch.ones_like(conv.weight.data) * std

        elif init == "kaiming_normal":
            if scale_fan:
                std = self._fan_scaled_std(conv.weight)
                with torch.no_grad():
                    conv.weight.data.normal_(0, std)
            else:
                nn.init.kaiming_normal_(
                    conv.weight, mode=mode, nonlinearity=nonlinearity
                )

        elif init == "kaiming_uniform":
            nn.init.kaiming_uniform_(
                conv.weight, mode=mode, nonlinearity=nonlinearity
            )

        elif init == "xavier_normal":
            nn.init.xavier_normal_(conv.weight)

        elif init == "xavier_constant":
            fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(conv.weight)
            std = math.sqrt(2.0 / float(fan_in + fan_out))
            conv.weight.data = conv.weight.data.sign() * std

        elif init == "standard":
            nn.init.kaiming_uniform_(conv.weight, a=math.sqrt(5))

        else:
            raise ValueError(f"{init} is not an initialization option!")


def get_builder():
    """Create a ``Builder`` from the ``conv_type`` / ``bn_type`` / ``first_layer_type`` globals.

    ``first_layer_type`` may be ``None`` (the builder then uses ``conv_type``
    for the first layer too), a layer class/callable, or a string. A string is
    first looked up as an attribute of ``conv_type`` (the original behaviour)
    and then as a module-level name.

    Returns:
        Builder: the configured layer factory.

    Raises:
        AttributeError: if a string ``first_layer_type`` cannot be resolved.
    """
    print("==> Conv Type: {}".format(conv_type))
    print("==> BN Type: {}".format(bn_type))

    conv_layer = conv_type
    bn_layer = bn_type

    if first_layer_type is None:
        first_layer = None
    else:
        if isinstance(first_layer_type, str):
            # conv_type is a class here, so an attribute lookup alone would
            # fail for ordinary layer names; fall back to module globals.
            first_layer = getattr(conv_type, first_layer_type, None)
            if first_layer is None:
                first_layer = globals().get(first_layer_type)
            if first_layer is None:
                raise AttributeError(
                    f"first_layer_type {first_layer_type!r} is neither an attribute "
                    f"of {conv_type!r} nor a module-level name"
                )
        else:
            first_layer = first_layer_type
        print(f"==> First Layer Type: {first_layer_type}")

    builder = Builder(conv_layer=conv_layer, bn_layer=bn_layer, first_layer=first_layer)

    return builder

In [8]:
class BasicBlockCIFAR(nn.Module):
    """Residual block made of two 3x3 convolutions plus a shortcut.

    All layers are obtained from ``builder``. The shortcut is the identity
    unless the stride or channel count changes, in which case it is a 1x1
    convolution followed by a normalisation layer.
    """

    expansion = 1

    def __init__(self, builder, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        # Main branch (construction order is kept: conv1, bn1, conv2, bn2).
        self.conv1 = builder.conv3x3(in_planes, planes, stride=stride)
        self.bn1 = builder.batchnorm(planes)
        self.conv2 = builder.conv3x3(planes, planes, stride=1)
        self.bn2 = builder.batchnorm(planes)

        # Shortcut branch: empty Sequential acts as identity.
        if stride != 1 or in_planes != out_planes:
            projection = [
                builder.conv1x1(in_planes, out_planes, stride=stride),
                builder.batchnorm(out_planes),
            ]
        else:
            projection = []
        self.shortcut = nn.Sequential(*projection)

    def forward(self, x):
        y = self.conv1(x)
        y = F.relu(self.bn1(y))
        y = self.conv2(y)
        y = self.bn2(y)
        # The shortcut is evaluated after the main branch, as in the original.
        y += self.shortcut(x)
        return F.relu(y)


class BottleneckCIFAR(nn.Module):
    """Residual bottleneck block: 1x1 -> 3x3 -> 1x1 convolutions plus a shortcut.

    The block outputs ``expansion * planes`` channels. The shortcut is the
    identity unless the stride or channel count changes, in which case it is a
    1x1 convolution followed by a normalisation layer.
    """

    expansion = 4

    def __init__(self, builder, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        # Main branch; only the 3x3 convolution applies the stride.
        self.conv1 = builder.conv1x1(in_planes, planes)
        self.bn1 = builder.batchnorm(planes)
        self.conv2 = builder.conv3x3(planes, planes, stride=stride)
        self.bn2 = builder.batchnorm(planes)
        self.conv3 = builder.conv1x1(planes, out_planes)
        self.bn3 = builder.batchnorm(out_planes)

        # Shortcut branch: empty Sequential acts as identity.
        if stride != 1 or in_planes != out_planes:
            projection = [
                builder.conv1x1(in_planes, out_planes, stride=stride),
                builder.batchnorm(out_planes),
            ]
        else:
            projection = []
        self.shortcut = nn.Sequential(*projection)

    def forward(self, x):
        y = F.relu(self.bn1(self.conv1(x)))
        y = F.relu(self.bn2(self.conv2(y)))
        y = self.conv3(y)
        y = self.bn3(y)
        # The shortcut is evaluated after the main branch, as in the original.
        y += self.shortcut(x)
        return F.relu(y)


class ResNetCIFAR(nn.Module):
    """ResNet with a 3x3 stem and four residual stages (64/128/256/512 planes).

    Args:
        builder: layer factory providing ``conv3x3``, ``conv1x1`` and
            ``batchnorm``.
        block: residual block class with an ``expansion`` attribute and the
            signature ``block(builder, in_planes, planes, stride)``.
        num_blocks: sequence of four ints, the number of blocks per stage.
        num_classes: number of output logits (default 10, the previously
            hard-coded value).

    The classifier is a 1x1 convolution: a plain ``nn.Conv2d`` when the
    module-level ``last_layer_dense`` flag is true, otherwise a convolution
    from ``builder``.
    """

    def __init__(self, builder, block, num_blocks, num_classes=10):
        super(ResNetCIFAR, self).__init__()
        self.in_planes = 64
        self.builder = builder

        self.conv1 = builder.conv3x3(3, 64, stride=1, first_layer=True)
        self.bn1 = builder.batchnorm(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d(1)

        if last_layer_dense:
            self.fc = nn.Conv2d(512 * block.expansion, num_classes, 1)
        else:
            self.fc = builder.conv1x1(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one applies ``stride``."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.builder, self.in_planes, planes, stride))
            # The next block consumes what this one produces.
            self.in_planes = planes * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return logits of shape ``(batch, num_classes)``."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Global average pooling via the registered adaptive pool: equal to the
        # former fixed 4x4 pool for 32x32 inputs, and still yields a 1x1 map
        # for other input resolutions.
        out = self.avgpool(out)
        out = self.fc(out)
        return out.flatten(1)

In [9]:
import abc
from torch.utils.tensorboard import SummaryWriter

class ProgressMeter(object):
    """Formats a group of meters as a one-line progress report.

    Args:
        num_batches: total number of batches; fixes the width of the
            ``[ current/total]`` counter.
        meters: iterable of meter objects. ``display`` uses ``str(meter)``;
            ``write_to_tensorboard`` reads ``name``, ``val``, ``avg``,
            ``write_val`` and ``write_avg``.
        prefix: text placed in front of the batch counter.
    """

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch, tqdm_writer=True):
        """Emit one tab-separated status line for ``batch``.

        With ``tqdm_writer`` true the line goes through ``tqdm.write`` so it
        does not break an active progress bar. If tqdm is not installed the
        line is printed instead of raising.
        """
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        line = "\t".join(entries)

        if tqdm_writer:
            # tqdm was referenced here without ever being imported; import it
            # lazily and degrade to print() when it is unavailable.
            try:
                from tqdm import tqdm as _tqdm
            except ImportError:
                _tqdm = None
            if _tqdm is not None:
                _tqdm.write(line)
                return

        print(line)

    def write_to_tensorboard(
        self, writer: "SummaryWriter", prefix="train", global_step=None
    ):
        """Log each meter's current and/or average value via ``writer.add_scalar``."""
        for meter in self.meters:
            avg = meter.avg
            val = meter.val
            if meter.write_val:
                writer.add_scalar(
                    f"{prefix}/{meter.name}_val", val, global_step=global_step
                )

            if meter.write_avg:
                writer.add_scalar(
                    f"{prefix}/{meter.name}_avg", avg, global_step=global_step
                )

    def _get_batch_fmtstr(self, num_batches):
        """Return e.g. ``"[{:3d}/100]"`` for ``num_batches == 100``."""
        num_digits = len(str(num_batches // 1))
        fmt = "{:" + str(num_digits) + "d}"
        return "[" + fmt + "/" + fmt.format(num_batches) + "]"

class Meter(object):
    """Outline of the interface a metric accumulator provides.

    The methods are tagged with ``abc.abstractmethod``, but the class does not
    use ``abc.ABCMeta``, so the tags are informational only and ``Meter`` can
    still be instantiated. Every method here is a no-op returning ``None``.
    """

    @abc.abstractmethod
    def __init__(self, name, fmt=":f"):
        """Accept a display name and a format spec; stores nothing."""

    @abc.abstractmethod
    def reset(self):
        """Clear accumulated state (no-op here)."""

    @abc.abstractmethod
    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times (no-op here)."""

    @abc.abstractmethod
    def __str__(self):
        """Render the meter as text (returns ``None`` here)."""

class AverageMeter(Meter):
    """Keeps the latest value of a metric together with its running mean.

    Args:
        name: label used when the meter is rendered.
        fmt: format spec (including the leading colon, e.g. ``":.3f"``)
            applied to both the latest value and the mean.
        write_val: flag read by ``ProgressMeter.write_to_tensorboard`` to
            decide whether the latest value is logged.
        write_avg: same, for the running mean.
    """

    def __init__(self, name, fmt=":f", write_val=True, write_avg=True):
        self.name, self.fmt = name, fmt
        self.reset()
        self.write_val, self.write_avg = write_val, write_avg

    def reset(self):
        """Zero the latest value, mean, weighted sum and sample count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted by ``n`` samples."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        # Build "{name} {val<fmt>} ({avg<fmt>})" and fill it from the
        # instance attributes.
        template = f"{{name}} {{val{self.fmt}}} ({{avg{self.fmt}}})"
        return template.format(**vars(self))

In [10]:
import copy

def get_initial_weights(model):
    """Snapshot every parameter whose qualified name contains ``'weight'``.

    Returns:
        dict: parameter name -> independent deep copy of ``param.data``.
    """
    return {
        param_name: copy.deepcopy(tensor.data)
        for param_name, tensor in model.named_parameters()
        if 'weight' in param_name
    }

def get_initial_scores(model):
    """Snapshot every parameter whose qualified name contains ``'scores'``.

    Returns:
        dict: parameter name -> independent deep copy of ``param.data``.
    """
    return {
        param_name: copy.deepcopy(tensor.data)
        for param_name, tensor in model.named_parameters()
        if 'scores' in param_name
    }

def check_weights_change(initial_weights, model):
    """Print whether any ``'weight'`` parameter differs from its snapshot.

    Args:
        initial_weights: mapping of parameter name -> tensor, as produced by
            ``get_initial_weights``.
        model: module whose current parameters are compared.

    Raises:
        KeyError: if the model has a ``'weight'`` parameter that is missing
            from ``initial_weights``.

    The comparison uses ``torch.equal``, so a parameter that turned into NaN
    is reported as changed (a sum-of-differences test would compare NaN > 0
    and miss it).
    """
    changed = False
    for name, param in model.named_parameters():
        if 'weight' not in name:
            continue
        # Compare on the parameter's device in case the model was moved after
        # the snapshot was taken.
        snapshot = initial_weights[name].to(param.device)
        if not torch.equal(snapshot, param.data):
            changed = True
            break

    if changed:
        print("Weights changed!")
    else:
        print("No weights changed.")

# def check_weights_change(initial_weights, model):
#     total_weights = 0
#     changed_weights = 0

#     for name, param in model.named_parameters():
#         if 'weight' in name:  # Focus on weight parameters
#             initial_weight = initial_weights[name]
#             final_weight = param.data

#             # Count the total number of weights
#             total_weights += param.numel()

#             # Calculate the absolute difference
#             difference = (initial_weight - final_weight).abs()

#             # Count how many weights have changed (difference > 0)
#             changed_weights += difference.gt(0).sum().item()

#     if total_weights == 0:
#         print("No weights to check.")
#         return

#     # Calculate the percentage of changed weights
#     percentage_changed = (changed_weights / total_weights) * 100

#     print(f"{changed_weights} out of {total_weights} weights changed ({percentage_changed:.2f}%).")

#     # Optionally, you can return this percentage if you want to use it outside the function
#     return percentage_changed


def check_scores_change(initial_scores, model):
    """Print whether any ``'scores'`` parameter differs from its snapshot.

    Args:
        initial_scores: mapping of parameter name -> tensor, as produced by
            ``get_initial_scores``.
        model: module whose current parameters are compared.

    Raises:
        KeyError: if the model has a ``'scores'`` parameter that is missing
            from ``initial_scores``.

    The comparison uses ``torch.equal``, so scores that turned into NaN are
    reported as changed (a sum-of-differences test would compare NaN > 0 and
    miss it).
    """
    changed = False
    for name, param in model.named_parameters():
        if 'scores' not in name:
            continue
        # Compare on the parameter's device in case the model was moved after
        # the snapshot was taken.
        snapshot = initial_scores[name].to(param.device)
        if not torch.equal(snapshot, param.data):
            changed = True
            break

    if changed:
        print("Scores changed!")
    else:
        print("No scores changed.")

In [11]:
def freeze_model_weights(model):
    """Disable gradients for every sub-module's ``weight`` (and its ``bias``).

    Stored gradients on those tensors are dropped. A bias is only touched
    when the same module also has a non-``None`` weight.
    """
    print("=> Freezing model weights")

    for name, module in model.named_modules():
        weight = getattr(module, "weight", None)
        if weight is None:
            continue

        print(f"==> No gradient to {name}.weight")
        weight.requires_grad = False
        if weight.grad is not None:
            print(f"==> Setting gradient of {name}.weight to None")
            weight.grad = None

        bias = getattr(module, "bias", None)
        if bias is None:
            continue

        print(f"==> No gradient to {name}.bias")
        bias.requires_grad = False
        if bias.grad is not None:
            print(f"==> Setting gradient of {name}.bias to None")
            bias.grad = None


def freeze_model_subnet(model):
    """Disable gradients for the ``scores`` of every sub-module that has them.

    Any gradient already stored on the scores is dropped.
    """
    print("=> Freezing model subnet")

    for name, module in model.named_modules():
        if not hasattr(module, "scores"):
            continue

        scores = module.scores
        scores.requires_grad = False
        print(f"==> No gradient to {name}.scores")
        if scores.grad is not None:
            print(f"==> Setting gradient of {name}.scores to None")
            scores.grad = None


def unfreeze_model_weights(model):
    """Enable gradients for every sub-module's ``weight`` (and its ``bias``).

    A bias is only touched when the same module also has a non-``None``
    weight.
    """
    print("=> Unfreezing model weights")

    for name, module in model.named_modules():
        weight = getattr(module, "weight", None)
        if weight is None:
            continue

        print(f"==> Gradient to {name}.weight")
        weight.requires_grad = True

        bias = getattr(module, "bias", None)
        if bias is not None:
            print(f"==> Gradient to {name}.bias")
            bias.requires_grad = True


def unfreeze_model_subnet(model):
    """Enable gradients for the ``scores`` of every sub-module that has them."""
    print("=> Unfreezing model subnet")

    for name, module in model.named_modules():
        if not hasattr(module, "scores"):
            continue
        print(f"==> Gradient to {name}.scores")
        module.scores.requires_grad = True


def set_model_prune_rate(model, prune_rate):
    """Call ``set_prune_rate(prune_rate)`` on every sub-module that offers it."""
    print(f"==> Setting prune rate of network to {prune_rate}")

    for name, module in model.named_modules():
        if not hasattr(module, "set_prune_rate"):
            continue
        module.set_prune_rate(prune_rate)
        print(f"==> Setting prune rate of {name} to {prune_rate}")


def get_optimizer(model):
    """Build an SGD optimizer over the parameters that currently require grad.

    Parameters are split into two groups by whether their name contains
    ``"bn"``; both groups use the same ``weight_decay``. Learning rate,
    momentum, weight decay and the Nesterov flag come from the module-level
    globals ``lr``, ``momentum``, ``weight_decay`` and ``nesterov``.

    Parameters with ``requires_grad == False`` at call time are not handed to
    the optimizer at all.
    """
    bn_params, rest_params = [], []
    for param_name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        bucket = bn_params if "bn" in param_name else rest_params
        bucket.append(param)

    # Group order (bn first, then the rest) is kept so optimizer state dicts
    # keep the same layout.
    param_groups = [
        {"params": bn_params, "weight_decay": weight_decay},
        {"params": rest_params, "weight_decay": weight_decay},
    ]

    return torch.optim.SGD(
        param_groups,
        lr,
        momentum=momentum,
        weight_decay=weight_decay,
        nesterov=nesterov,
    )

In [12]:
def train(model, device, train_loader, optimizer, criterion, epoch):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: module to optimise; switched to training mode.
        device: device the batches are moved to.
        train_loader: iterable of ``(data, target)`` batches that supports
            ``len()`` and exposes ``.dataset`` (used only for the progress line).
        optimizer: optimizer stepped once per batch.
        criterion: callable ``criterion(output, target)`` returning a scalar loss.
        epoch: epoch index, used only in the progress line.

    Returns:
        float: sum of the batch losses divided by the number of batches
        (``0.0`` for an empty loader). Previously this value was computed and
        then discarded.
    """
    model.train()
    running_loss = 0.0
    num_batches = len(train_loader)
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        # Read the scalar once; every .item() call synchronises with the device.
        batch_loss = loss.item()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / num_batches, batch_loss))
        running_loss += batch_loss

    avg_training_loss = running_loss / num_batches if num_batches else 0.0
    # wandb.log({"epoch": epoch+1, "Training Loss": avg_training_loss})
    return avg_training_loss


def test(model, device, criterion, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target)
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    # wandb.log({"epoch": epoch+1, "Validation Loss": test_loss})


def modifier(prune_rate, epoch, model):
    """Apply the epoch-dependent training schedule and return ``prune_rate``.

    * ``epoch == 0``: set every layer's prune rate to ``0.0``, freeze the
      scores and unfreeze the weights.
    * ``epoch == 6``: set every layer's prune rate to ``prune_rate``, unfreeze
      the scores and freeze the weights.
    * any other epoch: nothing is changed.

    The returned value is always the ``prune_rate`` that was passed in.
    """
    if epoch == 0:
        # NOTE(review): SubnetConv forwards its prune rate to GetSubnet as k,
        # and GetSubnet zeroes the lowest int((1 - k) * numel) scores, so 0.0
        # here masks every weight. Confirm this is the intended warm-up.
        set_model_prune_rate(model, prune_rate=0.0)
        freeze_model_subnet(model)
        unfreeze_model_weights(model)
    elif epoch == 6:
        set_model_prune_rate(model, prune_rate=prune_rate)
        unfreeze_model_subnet(model)
        freeze_model_weights(model)

    # Disabled experiments: stepping the rate down by 0.1 every 5 epochs for
    # 6 < epoch < 50, and switching to 0.025 at epoch 55.

    return prune_rate



In [13]:
import math

# Training driver: builds the model, freezes the weights, then runs `epochs`
# epochs of train/test with a cosine learning-rate schedule, saving a
# checkpoint after every epoch greater than 6.
prune_rate = 0.3
# checkpoint_path = os.path.join(DRIVE_PATH, "resnet18_slt_epoch_99_30.pth")
# checkpoint = torch.load(checkpoint_path, map_location=device)
# model = ResNetCIFAR(get_builder(), BasicBlockCIFAR, [2,2,2,2])
# model.load_state_dict(checkpoint['model_state_dict'])

# Fresh ResNet-18 layout ([2, 2, 2, 2] basic blocks). The snapshots taken here
# are what check_weights_change / check_scores_change compare against.
model = ResNetCIFAR(get_builder(), BasicBlockCIFAR, [2,2,2,2])
model.to(device)
initial_weights = get_initial_weights(model)
initial_scores = get_initial_scores(model)

set_model_prune_rate(model, prune_rate=prune_rate)
freeze_model_weights(model)


# optimizer = optim.SGD([p for p in model.parameters() if p.requires_grad],
#                       lr=lr,
#                       momentum=momentum,
#                       weight_decay=weight_decay)

# NOTE(review): get_optimizer only collects parameters with requires_grad=True
# at this point, i.e. after freeze_model_weights. Weights that modifier()
# unfreezes at epoch 0 are therefore not in the optimizer - confirm that the
# epoch 0-5 phase is meant to leave them untouched.
optimizer = get_optimizer(model)
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

criterion = nn.CrossEntropyLoss().to(device)

scheduler = CosineAnnealingLR(optimizer, T_max=epochs)
# scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

curr_prune_rate = prune_rate
for epoch in range(0, epochs):
    print("\n-------------------------------------")
    # Report drift relative to the snapshots taken before training started.
    check_weights_change(initial_weights, model)
    check_scores_change(initial_scores, model)

    # modifier() switches between the weight phase and the score phase at
    # epochs 0 and 6; it returns the rate it was given.
    curr_prune_rate = modifier(curr_prune_rate, epoch, model)
    train(model, device, train_loader, optimizer, criterion, epoch)

    # compute the percentage of weights used in the forward pass
    # NOTE(review): StatisticsTracker.update overwrites its values on every
    # GetSubnet.forward call, so these numbers describe only the most recent
    # masked layer call, not the whole network.
    tracker = StatisticsTracker()
    zero_percentage, one_percentage = tracker.get_statistics()
    print(f"Percentage of weights set to zero: {zero_percentage}%")
    print(f"Percentage of weights set to one: {one_percentage}%")

    test(model, device, criterion, test_loader)
    scheduler.step()

    # The floored percentage of ones becomes part of the checkpoint file name.
    one_percent_formated = math.floor(one_percentage)
    # A checkpoint is written after every epoch greater than 6; the
    # `save_model` flag from the config cell is not consulted here.
    if epoch > 6:
        model_save_path = os.path.join(DRIVE_PATH, f"resnet18_slt_epoch_{epoch}_{one_percent_formated}_final.pth")
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
        }, model_save_path)

        print(f"Model saved to {model_save_path} after epoch {epoch}")



==> Conv Type: <class '__main__.SubnetConv'>
==> BN Type: <class '__main__.NonAffineBatchNorm'>
==> Building first layer with <class '__main__.SubnetConv'>
==> Setting prune rate of network to 0.3
==> Setting prune rate of conv1 to 0.3
==> Setting prune rate of layer1.0.conv1 to 0.3
==> Setting prune rate of layer1.0.conv2 to 0.3
==> Setting prune rate of layer1.1.conv1 to 0.3
==> Setting prune rate of layer1.1.conv2 to 0.3
==> Setting prune rate of layer2.0.conv1 to 0.3
==> Setting prune rate of layer2.0.conv2 to 0.3
==> Setting prune rate of layer2.0.shortcut.0 to 0.3
==> Setting prune rate of layer2.1.conv1 to 0.3
==> Setting prune rate of layer2.1.conv2 to 0.3
==> Setting prune rate of layer3.0.conv1 to 0.3
==> Setting prune rate of layer3.0.conv2 to 0.3
==> Setting prune rate of layer3.0.shortcut.0 to 0.3
==> Setting prune rate of layer3.1.conv1 to 0.3
==> Setting prune rate of layer3.1.conv2 to 0.3
==> Setting prune rate of layer4.0.conv1 to 0.3
==> Setting prune rate of layer4.0

In [13]:
check_weights_change(initial_weights, model)

Weihts changed!


In [14]:
check_scores_change(initial_scores, model)

No scores changed.


In [15]:
# Final evaluation: top-1 accuracy of `model` over `test_loader`.
model.to(device)
model.eval()  # use running batch-norm statistics

correct = 0
total = 0
with torch.no_grad():  # inference only, no autograd bookkeeping
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
# The image count in the message is a fixed string, not derived from `total`.
print(f'Accuracy of the network on the 10000 test images: {accuracy:.2f}%')


Accuracy of the network on the 10000 test images: 91.76%
