<a href="https://colab.research.google.com/github/MuhammadUmairHaider/Neural-Network-Pruning-Through-Constrained-Reinforcement-Learning/blob/main/PyTorch_Compression.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchvision import models
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch.optim as optim
from utils import progress_bar
import os
from torch.autograd import Variable
import tqdm


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Seed torch's RNG (weight init, DataLoader shuffling, torchvision random transforms) and
# numpy so runs are repeatable. GPU kernels may still introduce some non-determinism.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Per-channel normalisation statistics shared by the train and test pipelines.
# NOTE(review): these std values differ from the commonly quoted CIFAR-10 per-pixel std
# (~0.247, 0.243, 0.261); they are left unchanged so existing checkpoints stay valid — confirm
# before changing.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# Training pipeline: random 32x32 crop from a 4-pixel padded image plus horizontal flip.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Evaluation pipeline: no augmentation, same normalisation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')


cuda
Files already downloaded and verified
Files already downloaded and verified


In [2]:
class Gates(nn.Module):
    """Learnable per-channel soft gate applied to a conv feature map.

    Channel ``c`` of the input is scaled by ``g[c] = w[c]**2 / (w[c]**2 + epsilon)``,
    where ``w`` is a learnable weight vector initialised to ones. For ``epsilon > 0``
    each gate value lies in [0, 1); shrinking ``epsilon`` pushes ``g`` towards 1 for
    every non-zero weight, so the gate behaves more like a hard on/off switch.

    Args:
        size: number of gated channels (dimension 1 of the input).
        epsilon: smoothing constant of the gate formula (default 0.1).
    """

    def __init__(self, size, epsilon=0.1):
        super().__init__()
        self.size = size
        # float32 ones, identical to the former np.ones(size, np.single) initialisation.
        self.weight = nn.Parameter(torch.ones(size, dtype=torch.float32))
        # Plain float attribute (not a parameter/buffer), so it is NOT part of state_dict().
        self.epsilon = epsilon

    def forward(self, x):
        """Scale each channel of ``x`` by its gate value.

        Assumes ``x`` is laid out as (N, size, H, W) — TODO confirm for other callers.
        """
        return x * self.get_g().reshape(1, self.size, 1, 1)

    def get_g(self):
        """Return the current gate values, one per channel."""
        w2 = self.weight ** 2
        return w2 / (w2 + self.epsilon)

    def get_epsilon(self):
        """Return the current epsilon (previously returned the bound method by mistake)."""
        return self.epsilon

    def set_epsilon(self, e):
        """Set epsilon (previously overwrote ``self.weight``, which raises on a Parameter)."""
        self.epsilon = e

In [3]:
# Layer layouts for each supported VGG depth: an int is a 3x3 conv with that many output
# channels, 'M' is a 2x2 max-pool. Every variant has five pools and ends at 512 channels.
cfg = dict(
    VGG11=[64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    VGG13=[64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    VGG16=[64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    VGG19=[64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
)


class VGG(nn.Module):
    """VGG classifier with a ``Gates`` layer after every convolution.

    Each conv stage is ``Conv2d(3x3, padding=1) -> Gates -> BatchNorm2d -> ReLU``;
    ``'M'`` entries in the layout become 2x2 max-pools.

    Args:
        vgg_name: key into the module-level ``cfg`` table, e.g. ``'VGG16'``.
        num_classes: number of output logits (default 10, as before).
    """

    def __init__(self, vgg_name, num_classes=10):
        super().__init__()
        self.features = self._make_layers(cfg[vgg_name])
        # 512 = channels of the last conv in every cfg layout. Assumes the feature map is
        # 1x1 after the five pools (e.g. 32x32 inputs) — TODO confirm for other input sizes.
        self.classifier = nn.Linear(512, num_classes)

    def forward(self, x):
        out = self.features(x)
        out = out.view(out.size(0), -1)  # flatten to (batch, features)
        return self.classifier(out)

    def _make_layers(self, layer_cfg):
        """Build the feature extractor from a layout list (ints = conv widths, 'M' = pool)."""
        layers = []
        in_channels = 3
        for x in layer_cfg:
            if x == 'M':
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
                           Gates(x),
                           nn.BatchNorm2d(x),
                           nn.ReLU(inplace=True)]
                in_channels = x
        # Kernel-1/stride-1 average pool is an identity op; kept so the Sequential layout
        # matches earlier checkpoints.
        layers.append(nn.AvgPool2d(kernel_size=1, stride=1))
        return nn.Sequential(*layers)

# Build the gated VGG-16 and move it onto the selected device in one step.
net = VGG('VGG16').to(device)

In [4]:
# Split the network's parameters into the two optimizer groups:
#   gates_group - the weight of every Gates layer
#   rem_group   - every other parameter (conv/linear weights and biases, BatchNorm affine params)
# Previously only Conv2d/Linear parameters went into rem_group, so the BatchNorm2d
# weight/bias were never handed to the optimizer and stayed at their initial values.
gates_group = [m.weight for m in net.modules() if isinstance(m, Gates)]
_gate_param_ids = {id(p) for p in gates_group}
rem_group = [p for p in net.parameters() if id(p) not in _gate_param_ids]
assert len(gates_group) + len(rem_group) == len(list(net.parameters()))

In [None]:
# Inspect the per-layer gate weights collected for the optimizer's gate parameter group.
gates_group

In [5]:
criterion = nn.CrossEntropyLoss()
# Two Adam parameter groups: gate weights use a 10x smaller learning rate than the rest.
optimizer = optim.Adam(
    [
        {"params": gates_group, "lr": 0.0001},
        {"params": rem_group, "lr": 0.001},
    ]
)
# T_max must equal the number of epochs the training loop runs (350). Past T_max the
# cosine schedule rises again: with the former T_max=200 the LR bottomed out at epoch 200
# and then climbed back towards the initial LR for the remaining 150 epochs.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=350)

In [6]:
def g_val(w, e):
    """Print the gate value w**2 / (w**2 + e), then a blank line."""
    squared = w ** 2
    gate = squared / (squared + e)
    print(gate, "\n")

In [7]:
def _gate_sparsity_penalty(model):
    """Sum over Gates layers of |weight|.sum() / (size * 500); 0.0 when there are none."""
    penalty = 0.0
    for m in model.modules():
        if isinstance(m, Gates):
            penalty = penalty + m.weight.abs().sum() / (m.size * 500)
    return penalty


def _print_gate_snapshot(model, n=27):
    """Print the first ``n`` weights of every Gates layer and their gate values."""
    for m in model.modules():
        if isinstance(m, Gates):
            print(m.weight[:n])
            g_val(m.weight[:n], m.epsilon)


def train(epoch, log_every=2000):
    """Run one training epoch of the module-level ``net`` over ``trainloader``.

    The optimised loss is cross-entropy plus the gate L1 sparsity penalty; the running
    loss reported in the progress bar includes that penalty.

    Args:
        epoch: epoch index (only printed).
        log_every: every ``log_every`` batches, starting at batch 0, print the running
            accuracy and a snapshot of every gate layer. NOTE(review): with the default
            of 2000 this probably fires only at batch 0 — compare with len(trainloader).
    """
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        net.zero_grad()
        outputs = net(inputs)
        # The earlier version also summed an L2 norm over every Conv2d/Linear weight each
        # batch but never added it to the loss; that dead computation has been removed.
        loss = criterion(outputs, targets) + _gate_sparsity_penalty(net)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                     % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
        if batch_idx % log_every == 0:
            print(100.*correct/total)
            _print_gate_snapshot(net)


def test(epoch, checkpoint_dir='checkpoint'):
    """Evaluate the module-level ``net`` on ``testloader`` and checkpoint on improvement.

    Shows a progress bar (and prints the running accuracy every 20 batches). If the
    final accuracy beats the global ``best_acc``, saves the state_dict, accuracy, epoch
    and every Gates layer's epsilon to ``<checkpoint_dir>/ckpt.pth`` and updates
    ``best_acc``.

    Args:
        epoch: epoch index stored in the checkpoint.
        checkpoint_dir: directory for ``ckpt.pth`` (default ``'checkpoint'``, as before).
    """
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                         % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))
            if batch_idx % 20 == 0:
                print(100.*correct/total)

    # An empty loader yields 0% instead of raising ZeroDivisionError.
    acc = 100.*correct/total if total else 0.0
    if acc > best_acc:
        print('Saving..')
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
            # Gates.epsilon is a plain attribute, so state_dict() omits it even though the
            # gate output depends on it; without this a reloaded model would fall back to
            # the constructor's epsilon.
            'gate_epsilons': [m.epsilon for m in net.modules() if isinstance(m, Gates)],
        }
        os.makedirs(checkpoint_dir, exist_ok=True)
        torch.save(state, os.path.join(checkpoint_dir, 'ckpt.pth'))
        best_acc = acc

In [10]:
start_epoch = 0
best_acc = 0  # read and updated by test() via `global best_acc`
NUM_EPOCHS = 350
EPSILON_DECAY = 0.96  # per-epoch multiplicative decay applied to every gate's epsilon

# The set of Gates modules never changes, so collect it once instead of rescanning each epoch.
gate_modules = [m for m in net.modules() if isinstance(m, Gates)]

for epoch in range(start_epoch, start_epoch + NUM_EPOCHS):
    train(epoch)
    test(epoch)
    # A smaller epsilon pushes g = w**2 / (w**2 + epsilon) towards 1 for every non-zero w,
    # making the soft gates progressively closer to hard on/off gates.
    for gate in gate_modules:
        gate.epsilon = gate.epsilon * EPSILON_DECAY
    # Print the last gate's epsilon as a progress indicator. The old `print(k.epsilon)`
    # raised NameError when the network had no Gates layers.
    if gate_modules:
        print(gate_modules[-1].epsilon)
    scheduler.step()


Epoch: 0
58.59375
tensor([0.9527, 0.9527, 0.9527, 0.9527, 0.9527, 0.9527, 0.9527, 0.9527, 0.9527,
        0.9527, 0.9527, 0.9527, 0.9527, 0.9527, 0.9527, 0.9527, 0.9527, 0.9527,
        0.9527, 0.9527, 0.9527, 0.9527, 0.9527, 0.9527, 0.9527, 0.9527, 0.9527],
       device='cuda:0', grad_fn=<SliceBackward0>)
tensor([0.9044, 0.9044, 0.9044, 0.9044, 0.9044, 0.9044, 0.9044, 0.9044, 0.9044,
        0.9044, 0.9044, 0.9044, 0.9044, 0.9044, 0.9044, 0.9044, 0.9044, 0.9043,
        0.9044, 0.9044, 0.9044, 0.9043, 0.9044, 0.9043, 0.9043, 0.9044, 0.9044],
       device='cuda:0', grad_fn=<DivBackward0>) 

tensor([0.9527, 0.9527, 0.9527, 0.9527, 0.9527, 0.9527, 0.9527, 0.9527, 0.9527,
        0.9527, 0.9527, 0.9527, 0.9527, 0.9527, 0.9527, 0.9527, 0.9527, 0.9527,
        0.9527, 0.9527, 0.9527, 0.9527, 0.9527, 0.9527, 0.9527, 0.9527, 0.9527],
       device='cuda:0', grad_fn=<SliceBackward0>)
tensor([0.9044, 0.9043, 0.9043, 0.9043, 0.9044, 0.9043, 0.9044, 0.9043, 0.9044,
        0.9044, 0.9044, 0.90

KeyboardInterrupt: ignored