In [3]:
import torch

print(torch.__version__)

2.1.0+cu121


In [4]:
import torch.nn as nn
import torch.nn.functional as F
import time
import pdb
from os import getcwd
from torch.autograd import Variable
import torchvision
import torchvision.transforms as transforms


In [5]:
# Loading CIFAR10 train and test dataset
def prep_dataloaders():

    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])

    test_transform = transforms.Compose([transforms.ToTensor()])

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True,
                                            transform=train_transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                              shuffle=True)

    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                           download=True,
                                           transform=test_transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=100,
                                             shuffle=False)

    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse',
               'ship', 'truck')

    return trainloader, testloader, classes
trainloader, testloader, classes = prep_dataloaders()

# Print the results
# print("Trainloader:", trainloader)
# print("Testloader:", testloader)
# print("Classes:", classes)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:04<00:00, 41627500.22it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [6]:
# Defining the model to be used for training
# MaskedConv2d is a subclass of nn.Conv2d for a masking mechanism
class MaskedConv2d(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=False):
        super(MaskedConv2d, self).__init__(in_channels, out_channels,
                                           kernel_size, stride, padding,
                                           dilation, groups, bias)
        self.masked_channels = []
        self.mask_flag = False
        self.masks = None

    def forward(self, x):
        if self.mask_flag:
            self._expand_masks(x.size())
            weight = self.weight * self.masks
            return F.conv2d(x, weight, self.bias, self.stride, self.padding,
                            self.dilation, self.groups)
        else:
            return F.conv2d(x, self.weight, self.bias, self.stride,
                            self.padding, self.dilation, self.groups)

    def set_masked_channels(self, masked_channels):
        self.masked_channels = masked_channels
        self.mask_flag = len(masked_channels) > 0

    def get_masked_channels(self):
        return self.masked_channels

    def _expand_masks(self, input_size):
        if not self.masked_channels:
            self.masks = None
        else:
            batch_size, _, height, width = [int(input_size[i].item()) for i in range(4)]
            masks = torch.ones((len(self.masked_channels), batch_size, height, width), device=self.weight.device)
            self.masks = Variable(masks, requires_grad=False)

In [7]:
#  Neural network architecture
class CustomNet(nn.Module):
    def __init__(self, num_classes):
        super(CustomNet, self).__init__()
        self.conv1_1 = MaskedConv2d(3, 64, 3, padding=1)
        self.conv2_1 = MaskedConv2d(64, 128, 3, padding=1)
        self.conv3_1 = MaskedConv2d(128, 256, 3, padding=1)

        self.fc1 = nn.Linear(4096, 4096)
        self.fc2 = nn.Linear(4096, num_classes)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        out = F.relu(self.conv1_1(x))
        out = F.max_pool2d(out, 2)

        out = F.relu(self.conv2_1(out))
        out = F.max_pool2d(out, 2)

        out = F.relu(self.conv3_1(out))
        out = F.max_pool2d(out, 2)

        out = out.view(out.size(0), -1)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        return self.softmax(out)

In [8]:
def train(model, optimizer, criterion, trainloader):
    """A single training iteration"""
    model.train()
    train_loss = 0
    correct = 0
    total = 0

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Batch loop
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)

        # A single optimization step
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # Update loss and accuracy info
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    print("Train: Loss: %.3f, Acc: %.3f (%d/%d)" % (train_loss / (batch_idx + 1), correct / total * 100., correct, total))

    return correct / total


def test(model, optimizer, criterion, testloader):
    model.eval()
    test_loss = 0
    correct = 0
    total = 0

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Batch loop
    for batch_idx, (inputs, targets) in enumerate(testloader):
        inputs, targets = inputs.to(device), targets.to(device)

        # A single test iteration
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        test_loss += loss.item()

        # Update loss and accuracy info
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    print("Test: Loss: %.3f, Acc: %.3f (%d/%d)" % (test_loss / (batch_idx + 1), correct / total * 100., correct, total))

    return correct / total





In [19]:

# Set up model, optimizer and loss
model = CustomNet(len(classes))

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# set learning rate = 0.001
learning_rate = 0.001
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=5e-4)
criterion = nn.CrossEntropyLoss()

# print(criterion)

# scale (multiply) the learning rate of an optimizer for adjusting the learning rate during training
def scale_lr(optimizer, scale):
    for param_group in optimizer.param_groups:
        param_group['lr'] *= scale


In [21]:
best_train_acc = 0
best_test_acc = 0
# Number of training epochs for pretraining
num_pretrain_epochs = 10


# Loop over training epochs
for epoch in range(num_pretrain_epochs):

    print("========== epoch %d" % (epoch))

    # Update learning rate of optimizers regularly
    if epoch % 50 == 0:
          scale_lr(optimizer, 0.1)

    # Train
    tic = time.time()
    train_acc = train(model, optimizer, criterion, trainloader)
    print("Train Time: %.3f" % (time.time() - tic))
    if train_acc > best_train_acc:
          best_train_acc = train_acc

    # Evaluate
    tic = time.time()
    test_acc = test(model, optimizer, criterion, trainloader)
    print("Test Time: %.3f" % (time.time() - tic))
    if test_acc > best_test_acc:
            best_test_acc = test_acc

print("Best Training Accuracy: %.3f%%" % (best_train_acc * 100.))
print("Best Test Accuracy: %.3f%%" % (best_test_acc * 100.))

# Save model and optimizer as checkpoint
save_dir = getcwd() + "/saved_model.pth"
save_data = {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    }
torch.save(save_data, save_dir)

Train: Loss: 1.857, Acc: 61.460 (30730/50000)
Train Time: 21.929
Test: Loss: 1.853, Acc: 61.762 (30881/50000)
Test Time: 16.482
Train: Loss: 1.851, Acc: 61.894 (30947/50000)
Train Time: 21.761
Test: Loss: 1.851, Acc: 61.958 (30979/50000)
Test Time: 16.474
Train: Loss: 1.851, Acc: 61.988 (30994/50000)
Train Time: 21.605
Test: Loss: 1.849, Acc: 62.308 (31154/50000)
Test Time: 16.359
Train: Loss: 1.850, Acc: 62.150 (31075/50000)
Train Time: 21.632
Test: Loss: 1.847, Acc: 62.402 (31201/50000)
Test Time: 16.540
Train: Loss: 1.848, Acc: 62.454 (31227/50000)
Train Time: 21.570
Test: Loss: 1.846, Acc: 62.542 (31271/50000)
Test Time: 16.123
Train: Loss: 1.846, Acc: 62.552 (31276/50000)
Train Time: 21.432
Test: Loss: 1.845, Acc: 62.682 (31341/50000)
Test Time: 16.877
Train: Loss: 1.845, Acc: 62.810 (31405/50000)
Train Time: 21.653
Test: Loss: 1.848, Acc: 62.366 (31183/50000)
Test Time: 16.322
Train: Loss: 1.846, Acc: 62.588 (31294/50000)
Train Time: 21.636
Test: Loss: 1.843, Acc: 63.048 (31524/5

In [42]:

class Params():
    def __init__(self):

        # Number of training epochs for sparse learning
        self.num_sparse_train_epochs = 10

        # Learning rate for optimizer
        self.learning_rate = 0.001

        # Hyperparameters for structured sparsity learning
        self.ssl_hyperparams = {
            "wgt_decay": 5e-4,
            "lambda_n": 5e-2,
            "lambda_c": 5e-2,
            "lambda_s": 5e-2,
        }

        # Threshold below which a weight value should be counted as too low
        self.threshold = 1e-5


In [54]:
# SSL can be applied on the pretrained model

def group_lasso(param_group):
    return torch.sum(param_group ** 2)


def cross_entropy_loss(outputs, targets):
    """Cross-entropy loss"""
    ce_loss_func = nn.CrossEntropyLoss()
    return ce_loss_func(outputs, targets)


def filter_and_channel_wise_ssl_loss(model, outputs, targets, params):
    """
    Penalizing unimportant filters and channels

    Params:
        - pretrained model
        - output tensor generated by the forward step of model
        - target output tensor corresponding to the input (for CE loss)
        - hyperparameters which contain entry for weight decay, lambda_n
        (for filter-wise group LASSO), and lambda_c (for channel-wise group LASSO)

    Return:
        - Loss of "Learning Structured Sparsity in Deep Neural Networks"
    """

    # Compute cross-entropy loss
    ce_loss = cross_entropy_loss(outputs, targets)

    # Loss accumulators
    wgt_l2_norm = torch.Tensor([0.])
    filter_wise_loss = torch.Tensor([0.])
    channel_wise_loss = torch.Tensor([0.])

# Check for GPU availability and move tensors accordingly
    if torch.cuda.is_available():
        wgt_l2_norm = wgt_l2_norm.cuda()
        filter_wise_loss = filter_wise_loss.cuda()
        channel_wise_loss = channel_wise_loss.cuda()

    # Coefficient hyperparams
    hyperparams = params.ssl_hyperparams
    wgt_decay = hyperparams["wgt_decay"]
    lambda_n = hyperparams["lambda_n"]
    lambda_c = hyperparams["lambda_c"]

    # Iterate over every layer
    params = list(model.parameters())
    for param in params:
        # L2 norm over entire parameters
        wgt_l2_norm += torch.norm(param)

        # Ignore linear or bias parameters
        if len(param.size()) != 4:
            continue

        num_filters, num_channels = param.size()[0], param.size()[1]

        # Group LASSO over filters of current layer
        for filter_idx in range(num_filters):
            filter_wise_loss += group_lasso(param[filter_idx, :, :, :])

        # Group LASSO over channels of current layer
        for channel_idx in range(num_channels):
            channel_wise_loss += group_lasso(param[:, channel_idx, :, :])

    return ce_loss + (wgt_decay * wgt_l2_norm) + (lambda_n * \
           filter_wise_loss)+ (lambda_c * channel_wise_loss)


def prep_dataloaders():
    """Loads CIFAR10 train and test dataset"""
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])

    test_transform = transforms.Compose([transforms.ToTensor()])

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True,
                                            transform=train_transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                              shuffle=True)

    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                           download=True,
                                           transform=test_transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=100,
                                             shuffle=False)

    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse',
               'ship', 'truck')

    return trainloader, testloader, classes



def train_ssl(model, optimizer, ssl_loss_func, trainloader, params):
    model.train()
    train_loss = 0
    correct = 0
    total = 0

    for batch_idx, (inputs, targets) in enumerate(trainloader):
    # Check for GPU availability and move tensors accordingly
     if torch.cuda.is_available():
        inputs = inputs.cuda()
        targets = targets.cuda()

    optimizer.zero_grad()
    outputs = model(inputs)
    loss = ssl_loss_func(model, outputs, targets, params)
    loss.backward()
    optimizer.step()

    train_loss += loss.item()
    _, predicted = outputs.max(1)
    total += targets.size(0)
    correct += predicted.eq(targets).sum().item()


    print("Train: Loss: %.3f, Acc: %.3f (%d/%d)" % (train_loss / \
          (batch_idx + 1), correct / total * 100., correct, total))

    return correct / total


def test(model, optimizer, criterion, testloader):
   model.eval()
   test_loss = 0
   correct = 0
   total = 0

   for batch_idx, (inputs, targets) in enumerate(testloader):
    # Check for GPU availability and move tensors accordingly
    if torch.cuda.is_available():
        inputs = inputs.cuda()
        targets = targets.cuda()

    outputs = model(inputs)
    loss = criterion(outputs, targets)
    test_loss += loss.item()

    _, predicted = outputs.max(1)
    total += targets.size(0)
    correct += predicted.eq(targets).sum().item()


    print("Test: Loss: %.3f, Acc: %.3f (%d/%d)" % (test_loss / \
          (batch_idx + 1), correct / total * 100., correct, total))

    return correct / total


def count_sparse_wgt(model, threshold):
    weight_cnt = 0
    sparse_weight_cnt = 0
    with torch.no_grad():
        for param_key in model.state_dict():
            param_tensor = model.state_dict()[param_key]
            dims = 1
            for dim in list(param_tensor.size()):
                dims *= dim
            weight_cnt += dims
            sparse_weight_cnt += torch.sum(param_tensor < threshold).item()
    return weight_cnt, sparse_weight_cnt


def count_sparse_wgt_by_layer(model, threshold):
    wgt_cnts = []
    sparse_wgt_cnts = []
    with torch.no_grad():
        for param_key in model.state_dict():
            param_tensor = model.state_dict()[param_key]
            dims = 1
            for dim in list(param_tensor.size()):
                dims *= dim
            wgt_cnts.append((param_key, dims))
            sparse_wgt_cnt_layer = torch.sum(param_tensor < threshold).item()
            sparse_wgt_cnts.append((param_key, sparse_wgt_cnt_layer))
    return wgt_cnts, sparse_wgt_cnts


def count_sparse_wgt_by_filter(model, threshold):
    sparse_wgt_cnts = []
    with torch.no_grad():
        for param_key in model.state_dict():
            param_tensor = model.state_dict()[param_key]
            if len(param_tensor.size()) != 4:
                sparse_wgt_cnts.append((param_key, None))
                continue
            num_filters = param_tensor.size()[0]
            sparse_wgt_cnts_by_filter = []
            for filter_idx in range(num_filters):
                cnt = torch.sum(param_tensor[filter_idx, :, :, :] < \
                                threshold).item()
                sparse_wgt_cnts_by_filter.append(cnt)
            sparse_wgt_cnts.append((param_key, sparse_wgt_cnts_by_filter))
    return sparse_wgt_cnts


def count_sparse_wgt_by_channel(model, threshold):
    sparse_wgt_cnts = []
    with torch.no_grad():
        for param_key in model.state_dict():
            param_tensor = model.state_dict()[param_key]
            if len(param_tensor.size()) != 4:
                sparse_wgt_cnts.append((param_key, None))
                continue
            num_channels = param_tensor.size()[1]
            sparse_wgt_cnts_by_channel = []
            for channel_idx in range(num_channels):
                cnt = torch.sum(param_tensor[:, channel_idx, :, :] < \
                                threshold).item()
                sparse_wgt_cnts_by_channel.append(cnt)
            sparse_wgt_cnts.append((param_key, sparse_wgt_cnts_by_channel))
    return sparse_wgt_cnts


def print_sparse_weights(model, threshold):
    wgt_cnt, sparse_wgt_cnt = count_sparse_wgt(model, threshold)
    print("\nTotal sparse weights: %.3f (%d/%d)" % (100. * sparse_wgt_cnt / \
          wgt_cnt, sparse_wgt_cnt, wgt_cnt))

    wgt_cnts, sparse_wgt_cnts = count_sparse_wgt_by_layer(model, threshold)
    print("\nSparse weight by layer")
    for idx in range(len(wgt_cnts)):
        layer_name = wgt_cnts[idx][0]
        wgt_cnt = wgt_cnts[idx][1]
        sparse_wgt_cnt = sparse_wgt_cnts[idx][1]
        print("Layer: {}, {} ({}/{})".format(layer_name, sparse_wgt_cnt / \
              wgt_cnt, sparse_wgt_cnt, wgt_cnt))

    sparse_wgt_cnts = count_sparse_wgt_by_filter(model, threshold)
    print("\nSparse weight by filter")
    for idx in range(len(sparse_wgt_cnts)):
        layer_name = sparse_wgt_cnts[idx][0]
        wgts_filters = sparse_wgt_cnts[idx][1]
        print("Layer: {}, {}".format(layer_name, wgts_filters))



params = Params()

trainloader, testloader, classes = prep_dataloaders()
model = CustomNet(len(classes))
if torch.cuda.is_available():
    model = model.cuda()
# optimizer = torch.optim.SGD(model.parameters(), lr=params.learning_rate, momentum=0.9,
#                 weight_decay=5e-4)
optimizer = torch.optim.Adam(model.parameters(), lr=params.learning_rate)
criterion = nn.CrossEntropyLoss()

checkpoint = torch.load(getcwd() + "/saved_model.pth")
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

print_sparse_weights(model, params.threshold)

best_train_acc = 0
best_test_acc = 0

for epoch in range(params.num_sparse_train_epochs):
        print("========== epoch %d" % (epoch))
        # if epoch % 50 == 0:
        #     scale_lr(optimizer, 0.1)


        ssl_loss_func = filter_and_channel_wise_ssl_loss


        tic = time.time()
        train_acc = train_ssl(model, optimizer, ssl_loss_func, trainloader,
                              params)
        print("Train Time: %.3f" % (time.time() - tic))
        if train_acc > best_train_acc:
            best_train_acc = train_acc

        tic = time.time()
        test_acc = test(model, optimizer, criterion, trainloader)
        print("Test Time: %.3f" % (time.time() - tic))
        if test_acc > best_test_acc:
            best_test_acc = test_acc

print("Best Training Accuracy: %.3f%%" % (best_train_acc * 100.))
print("Best Test Accuracy: %.3f%%" % (best_test_acc * 100.))

print_sparse_weights(model, params.threshold)
test_acc = test(model, optimizer, criterion, trainloader)
print("Final test accuracy: {}".format(test_acc * 100.))

Files already downloaded and verified
Files already downloaded and verified

Total sparse weights: 74.103 (12740248/17192650)

Sparse weight by layer
Layer: conv1_1.weight, 0.47858796296296297 (827/1728)
Layer: conv2_1.weight, 0.5255940755208334 (38751/73728)
Layer: conv3_1.weight, 0.5858629014756944 (172778/294912)
Layer: fc1.weight, 0.7453654408454895 (12505157/16777216)
Layer: fc1.bias, 0.685302734375 (2807/4096)
Layer: fc2.weight, 0.486376953125 (19922/40960)
Layer: fc2.bias, 0.6 (6/10)

Sparse weight by filter
Layer: conv1_1.weight, [12, 13, 14, 7, 12, 12, 15, 10, 14, 11, 13, 14, 14, 12, 13, 13, 14, 15, 12, 12, 14, 13, 13, 13, 14, 14, 15, 13, 14, 11, 12, 15, 12, 5, 13, 16, 12, 13, 13, 14, 14, 12, 10, 11, 27, 14, 14, 11, 12, 15, 13, 14, 7, 13, 11, 15, 12, 13, 14, 12, 10, 16, 13, 13]
Layer: conv2_1.weight, [296, 576, 332, 379, 309, 319, 279, 214, 384, 345, 358, 309, 355, 277, 287, 274, 296, 282, 248, 316, 285, 297, 223, 297, 248, 326, 316, 377, 230, 264, 254, 290, 316, 309, 328, 273