In [None]:
from torchvision import datasets
from torchvision.transforms import ToTensor

# Download (if not already cached under ./data) and load both MNIST splits,
# train first, each converted with ToTensor().
train_data, test_data = (
    datasets.MNIST(root='data', train=is_train, download=True, transform=ToTensor())
    for is_train in (True, False)
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:00<00:00, 16.3MB/s]


Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 494kB/s]


Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:00<00:00, 4.33MB/s]


Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 2.89MB/s]


Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw



In [None]:
from torch.utils.data  import DataLoader
import torch


# Only the training split is shuffled; evaluation order does not matter.
loaders = {
    'train': DataLoader(train_data, batch_size=128, shuffle=True),
    'test': DataLoader(test_data, batch_size=4096),
}
loaders


{'train': <torch.utils.data.dataloader.DataLoader at 0x79f8df625bd0>,
 'test': <torch.utils.data.dataloader.DataLoader at 0x79f8df625720>}

In [None]:
import torch.nn as nn

class Model(nn.Module):
  """Fully connected MNIST classifier: 784 -> 1024 -> 512 -> 256 -> 10.

  Every hidden Linear layer is followed by an in-place ReLU; the last Linear
  layer has no activation, so forward() returns raw logits (the notebook pairs
  it with nn.CrossEntropyLoss).
  """

  def __init__(self):
    super().__init__()
    # Layer widths, input first; 28*28 is one flattened MNIST image.
    widths = (28*28, 1024, 512, 256, 10)
    layers = []
    for fan_in, fan_out in zip(widths[:-1], widths[1:]):
      layers.append(nn.Linear(fan_in, fan_out))
      layers.append(nn.ReLU(inplace=True))
    layers.pop()  # no activation after the output layer
    # Sequential indices (0, 2, 4, 6 for the Linears) match the original
    # layout, so parameter names like "classifier.0.weight" are unchanged.
    self.classifier = nn.Sequential(*layers)

  def forward(self,x):
    """Flatten everything after the batch dimension and return the logits."""
    return self.classifier(torch.flatten(x, 1))

In [None]:
import torch.nn.init as init

def init_weights(module):
  """Re-initialise an nn.Linear module in place; leave every other module untouched.

  Meant for `model.apply(init_weights)`, which visits every submodule
  (the model itself, the Sequential container and the ReLUs too). Non-Linear
  modules must therefore be skipped silently, not rejected. The original
  `else: ValueError` line was a no-op expression, so skipping is exactly
  what it already did.

  Weights: Xavier/Glorot normal. Bias (when present): standard normal N(0, 1).

  Args:
    module: any nn.Module; only nn.Linear instances are modified.
  """
  if not isinstance(module, nn.Linear):
    return
  init.xavier_normal_(module.weight)
  # nn.Linear(..., bias=False) has bias None; the original crashed on .data here.
  if module.bias is not None:
    init.normal_(module.bias)



In [None]:
import copy
# Fixed seed so Restart & Run All reproduces the same initial weights (the
# global generator is also what later DataLoader shuffling draws from).
SEED = 0
torch.manual_seed(SEED)

model = Model()
model.apply(init_weights)
# Frozen copy of the initial parameters; reset_to_original_init rewinds the
# model to these values (times the pruning mask) after every pruning round.
initial_dict = copy.deepcopy(model.state_dict())
# NOTE(review): this single optimizer (and its AdamW moment buffers) is reused
# across every prune/rewind round and the later dense retraining run.
optimizer = torch.optim.AdamW(model.parameters(),lr=0.001)
loss_func = nn.CrossEntropyLoss()

In [None]:
def mask_maker(model):
  """Return one all-ones mask per weight tensor of `model` (nothing pruned yet).

  Masks follow `model.named_parameters()` order, restricted to parameters
  whose name contains "weight". prune_percentile and reset_to_original_init
  walk the parameters in that same order, which is how each mask is matched
  to its layer. Biases get no mask.

  Args:
    model: network whose weight tensors define the mask shapes.

  Returns:
    list of tensors shaped like each weight, filled with ones.
  """
  # (The original also printed a list of Nones here; that was debug output.)
  return [
      torch.ones_like(param.detach())
      for name, param in model.named_parameters()
      if "weight" in name
  ]

# One all-ones mask per weight tensor of `model`; every weight is still alive.
mask = mask_maker(model)

[None, None, None, None]


In [None]:
def prune_percentile(percent,mask,net=None):
  """Prune the smallest-magnitude `percent` of each layer's surviving weights.

  Parameters whose name contains "weight" are visited in named_parameters
  order, which matches the order of `mask`. For each layer:
    * the threshold is the `percent` quantile of |w| over the weights the
      mask still keeps;
    * every weight with |w| below the threshold gets mask 0;
    * entries already at 0 stay 0.
  `mask` is updated in place and also returned.

  Args:
    percent: fraction in [0, 1] of the surviving weights to prune per layer.
    mask: list of mask tensors, one per weight tensor, same shapes.
    net: model to read weights from. Defaults to the notebook-level `model`,
      which is what the original always used.

  Returns:
    The same `mask` list, holding the updated tensors.
  """
  target = model if net is None else net
  layer = 0
  for name, param in target.named_parameters():
    if 'weight' not in name:
      continue
    magnitude = param.data.abs()
    current = mask[layer]
    # Surviving weights are defined by the mask, not by the weight value:
    # a weight could be non-zero while its mask is already 0.
    alive = magnitude[current != 0]
    # torch.quantile raises on an empty tensor; a fully pruned layer stays as is.
    if alive.numel() > 0:
      threshold = torch.quantile(alive, percent).item()
      # bool factor promotes to the mask dtype: below-threshold entries drop
      # to 0, the rest keep their current mask value.
      mask[layer] = current * (magnitude >= threshold)
    layer += 1
  return mask


In [None]:
def total_nodes(model):
  """Count the non-zero entries across every weight tensor of `model`.

  Only parameters whose name contains "weight" are counted; biases are
  excluded. Despite the name, this counts surviving connections (weights),
  not neurons. The result is a 0-d integer tensor, or the int 0 when the
  model has no weight parameters.
  """
  nonzero_per_weight = (
      torch.count_nonzero(param.data)
      for name, param in model.named_parameters()
      if "weight" in name
  )
  return sum(nonzero_per_weight, 0)

# Non-zero weight count of the unpruned model; the report after pruning
# divides the surviving count by this value.
original_nodes = total_nodes(model)
print("Total Nodes:",original_nodes)

Total Nodes: tensor(1460736)


In [None]:
def reset_to_original_init(model,mask,inital_dict):
  """Rewind `model` to its initial parameters, applying the pruning masks.

  Weights (names containing "weight") become initial value * mask. Masks are
  matched to weights by named_parameters order. Biases (names containing
  "bias") are restored unmasked.

  Values are copied into the existing parameter tensors. The original
  rebound `param.data` to the snapshot's bias tensors, which made later
  optimizer steps overwrite the saved initial biases. It also ignored this
  argument and read the global `initial_dict` instead.

  Args:
    model: network whose parameters are overwritten in place.
    mask: list of mask tensors, one per weight tensor, in named_parameters order.
    inital_dict: state_dict snapshot of the initial parameters. The misspelt
      name is kept so keyword callers keep working.
  """
  layer = 0
  with torch.no_grad():
    for name, param in model.named_parameters():
      if "weight" in name:
        param.copy_(inital_dict[name] * mask[layer])
        layer += 1
      if "bias" in name:
        param.copy_(inital_dict[name])


In [None]:
def reset_mask(mask):
  """Set every mask tensor back to all ones.

  The list is modified in place: the caller's `mask` object keeps its
  identity, and only its elements are replaced. Returns None.
  """
  mask[:] = [torch.ones_like(layer_mask) for layer_mask in mask]

In [None]:
def full_reset(model,mask,initial_dict):
  """Undo all pruning: masks back to ones, then rewind `model` to its initial weights.

  The masks are reset first, so the rewind multiplies by all-ones masks and
  restores every weight. Both `mask` and `model` are modified in place.
  """
  reset_mask(mask)  # must precede the rewind, which reads the masks
  reset_to_original_init(model, mask, initial_dict)

In [None]:
# Start from a clean state: all-ones masks and the initial weights from initial_dict.
full_reset(model,mask,initial_dict)

In [None]:
from torch.autograd import Variable

def train_prune(model,loaders,loss_func,opt=None):
  """Train `model` for one pass over loaders['train'], keeping pruned weights at zero.

  A weight counts as pruned when |w| < EPS at the time of the step (a
  heuristic; no mask is passed in). Pruned weights are handled in two places:
    * their gradient is zeroed before the optimizer step;
    * they are forced back to zero after the step, because optimizers with
      moment buffers (e.g. AdamW) can still move a weight whose current
      gradient is zero.
  The original compared the signed value (`tensor < EPS`), which froze every
  negative weight as well.

  Args:
    model: network to train (modified in place).
    loaders: dict with a 'train' DataLoader yielding (imgs, targets) batches.
    loss_func: callable(pred, targets) -> scalar loss tensor.
    opt: optimizer over model.parameters(); defaults to the notebook-level
      `optimizer`.
  """
  if opt is None:
    opt = optimizer  # backward-compatible fallback to the global optimizer

  EPS = 1e-6
  train_loader = loaders['train']
  size = len(train_loader.dataset)
  for batch_idx, (imgs, targets) in enumerate(train_loader):
    opt.zero_grad()
    pred = model(imgs)
    train_loss = loss_func(pred, targets)
    train_loss.backward()

    frozen = []
    for name, param in model.named_parameters():
      if "weight" in name and param.grad is not None:
        pruned = param.data.abs() < EPS
        param.grad.masked_fill_(pruned, 0.0)
        frozen.append((param, pruned))
    opt.step()
    # Undo any drift the optimizer state introduced on pruned weights.
    with torch.no_grad():
      for param, pruned in frozen:
        param.masked_fill_(pruned, 0.0)

    if batch_idx % 100 == 0:
      loss, current = train_loss.item(), batch_idx*len(imgs)
      print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")



In [None]:
def test(model,loaders,loss_func):
  test_dataloader = loaders['test']
  size = len(test_dataloader.dataset)
  num_batches = len(test_dataloader)
  test_loss,correct = 0,0

  with torch.no_grad():
    for imgs,targets in test_dataloader:
      pred = model(imgs)
      test_loss += loss_func(pred,targets).item()
      correct += (pred.argmax(1) == targets).type(torch.float).sum().item()
  test_loss /= num_batches

  correct /= size
  print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [None]:
# Baseline before any training: evaluate the freshly initialised model and
# report how many non-zero weights it has. Counting first is equivalent, since
# evaluation does not touch the weights.
nodes = total_nodes(model)
test(model,loaders,loss_func)
print(f"Number of nodes:{nodes}")

Test Error: 
 Accuracy: 9.7%, Avg loss: 3.574221 

Number of nodes:1460736


In [None]:
epochs = 10          # number of prune/rewind rounds
prune_percent = 0.5  # fraction of surviving weights removed in each round
iterations = 1       # training passes over the train set per round
import numpy as np


def iterative_prune_train(model,mask,loss_func,iterations,percent,rounds=None):
  """Run `rounds` rounds of train -> prune -> rewind-to-init.

  Each round:
    * trains for `iterations` passes (train_prune, then test after each pass);
    * prunes `percent` of each layer's surviving weights via prune_percentile;
    * rewinds the surviving weights to their values in the global
      `initial_dict`.
  After the last round, the model is left pruned and rewound, i.e. untrained.

  Args:
    model: network to train and prune (modified in place).
    mask: list of per-weight masks; updated in place by prune_percentile.
    loss_func: loss callable passed to train_prune and test.
    iterations: training passes per round.
    percent: fraction of surviving weights pruned per round.
    rounds: number of rounds. Defaults to the global `epochs`, the original
      behaviour. Pass it explicitly: a later cell rebinds `epochs`.

  Also reads the globals `loaders` and `initial_dict`.
  """
  if rounds is None:
    rounds = epochs
  for epoch in range(rounds):
    print(f"Epoch {epoch+1}\n-------------------------------")
    model.train()
    for t in range(iterations):
      print(f"Iteration {t+1}\n-------------------------------")
      train_prune(model,loaders,loss_func)
      test(model,loaders,loss_func)
    mask = prune_percentile(percent,mask)
    reset_to_original_init(model,mask,initial_dict)
    print(f"\n--- Pruning Level [{epoch+1}/{rounds}]: ---")

iterative_prune_train(model, mask, loss_func, iterations, prune_percent, rounds=epochs)




Epoch 1
-------------------------------
Iteration 1
-------------------------------
loss: 0.001891 [    0/60000]
loss: 0.004102 [12800/60000]
loss: 0.002325 [25600/60000]
loss: 0.000788 [38400/60000]
loss: 0.002691 [51200/60000]
Test Error: 
 Accuracy: 98.2%, Avg loss: 0.063005 


--- Pruning Level [1/20]: ---
Epoch 2
-------------------------------
Iteration 1
-------------------------------
loss: 2.280112 [    0/60000]
loss: 0.226206 [12800/60000]
loss: 0.291351 [25600/60000]
loss: 0.294636 [38400/60000]
loss: 0.365459 [51200/60000]
Test Error: 
 Accuracy: 94.8%, Avg loss: 0.162615 


--- Pruning Level [2/20]: ---
Epoch 3
-------------------------------
Iteration 1
-------------------------------
loss: 2.223068 [    0/60000]
loss: 0.319968 [12800/60000]
loss: 0.244551 [25600/60000]
loss: 0.230020 [38400/60000]
loss: 0.166420 [51200/60000]
Test Error: 
 Accuracy: 95.2%, Avg loss: 0.154853 


--- Pruning Level [3/20]: ---
Epoch 4
-------------------------------
Iteration 1
------------

In [None]:
# loss_func = nn.CrossEntropyLoss()
# full_reset(model, mask, initial_dict)
test(model, loaders, loss_func)
nodes = total_nodes(model)
# print(f"Accuracy: {acc:.3f}")
print(f"Number of nodes: {nodes}")
print(f"Percent of nodes left: {(nodes / original_nodes):.3f}")

Test Error: 
 Accuracy: 11.3%, Avg loss: 2.345188 

Number of nodes: 2989
Percent of nodes left: 0.002


In [None]:
# Undo all pruning before the unmasked training run in the following cells:
# masks back to ones, weights back to their initial values.
full_reset(model, mask, initial_dict)

In [None]:
def train_loop(dataloader, model, loss_fn, optimizer):
    """Run one plain training pass over `dataloader` (no pruning masks applied).

    Every 100th batch, prints that batch's loss and the number of samples
    processed before it.
    """
    n_samples = len(dataloader.dataset)
    for step, (inputs, labels) in enumerate(dataloader):
        # Clearing grads before the forward pass is equivalent to clearing
        # them right before backward(): the forward pass does not touch .grad.
        optimizer.zero_grad()
        batch_loss = loss_fn(model(inputs), labels)
        batch_loss.backward()
        optimizer.step()

        if step % 100 != 0:
            continue
        seen = step * len(inputs)
        print(f"loss: {batch_loss.item():>7f}  [{seen:>5d}/{n_samples:>5d}]")

In [None]:
# Unmasked retraining of the fully reset model. A dedicated epoch count is
# used so this cell does not rebind the global `epochs` that
# iterative_prune_train falls back to.
# NOTE(review): `optimizer` still carries its AdamW state from the pruning
# rounds; recreate it here if a cold-start comparison is intended.
retrain_epochs = 100
model.train()
for t in range(retrain_epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(loaders["train"], model, loss_func, optimizer)
    test(model, loaders, loss_func)
print("Done.")

Epoch 1
-------------------------------
loss: 2.924125  [    0/60000]
loss: 0.212223  [12800/60000]
loss: 0.284626  [25600/60000]
loss: 0.047097  [38400/60000]
loss: 0.114272  [51200/60000]
Test Error: 
 Accuracy: 96.6%, Avg loss: 0.098209 

Epoch 2
-------------------------------
loss: 0.086463  [    0/60000]
loss: 0.063238  [12800/60000]
loss: 0.144997  [25600/60000]
loss: 0.062582  [38400/60000]
loss: 0.057242  [51200/60000]
Test Error: 
 Accuracy: 97.5%, Avg loss: 0.076414 

Epoch 3
-------------------------------
loss: 0.013616  [    0/60000]
loss: 0.035442  [12800/60000]
loss: 0.030279  [25600/60000]
loss: 0.011119  [38400/60000]
loss: 0.035215  [51200/60000]
Test Error: 
 Accuracy: 97.6%, Avg loss: 0.075114 

Epoch 4
-------------------------------
loss: 0.030664  [    0/60000]
loss: 0.010130  [12800/60000]
loss: 0.023699  [25600/60000]
loss: 0.015874  [38400/60000]
loss: 0.101203  [51200/60000]
Test Error: 
 Accuracy: 97.8%, Avg loss: 0.068019 

Epoch 5
------------------------