In [1]:
# Specify GPU to be used-
# %env CUDA_DEVICE_ORDER = PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES = 0

env: CUDA_VISIBLE_DEVICES=0


In [2]:
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F

import numpy as np
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import os, pickle
from tqdm import tqdm
from tqdm import trange

  warn(f"Failed to load image Python extension: {e}")
Matplotlib is building the font cache; this may take a moment.


In [3]:
print(f"torch version: {torch.__version__}")


# Check if there are multiple devices (i.e., GPU cards)-
print(f"Number of GPU(s) available = {torch.cuda.device_count()}")

if torch.cuda.is_available():
    print(f"Current GPU: {torch.cuda.current_device()}")
    print(f"Current GPU name: {torch.cuda.get_device_name(torch.cuda.current_device())}")
else:
    print("PyTorch does not have access to GPU")

# Device configuration-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Available device is {device}')

torch version: 1.13.1
Number of GPU(s) available = 1
Current GPU: 0
Current GPU name: NVIDIA GeForce GTX TITAN X
Available device is cuda


In [4]:
# Specify hyper-parameters
batch_size = 4096
num_classes = 10

In [5]:
# MNIST dataset statistics:
# mean = tensor([0.1307]) & std dev = tensor([0.3081])
mean = np.array([0.1307])
std_dev = np.array([0.3081])

transforms_apply = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean = mean, std = std_dev)
    ])

In [6]:
# MNIST dataset-
train_dataset = torchvision.datasets.MNIST(
        root = '/home/majumdar/Downloads/.data/', train = True,
        transform = transforms_apply, download = True
        )

test_dataset = torchvision.datasets.MNIST(
        root = '/home/majumdar/Downloads/.data/', train = False,
        transform = transforms_apply
        )

In [7]:
# Create dataloader-
train_loader = torch.utils.data.DataLoader(
        dataset = train_dataset, batch_size = batch_size,
        shuffle = True
        )

test_loader = torch.utils.data.DataLoader(
        dataset = test_dataset, batch_size = batch_size,
        shuffle = False
        )

In [8]:
len(train_dataset), len(test_dataset)

(60000, 10000)

In [9]:
class LeNet300(nn.Module):
    def __init__(self):
        super().__init__()
        
        # Define layers-
        self.fc1 = nn.Linear(in_features = 28 * 28 * 1, out_features = 300)
        self.fc2 = nn.Linear(in_features = 300, out_features = 100)
        self.output = nn.Linear(in_features = 100, out_features = 10)
        
        # self.weights_initialization()
    
    
    def forward(self, x):
        out = F.leaky_relu(self.fc1(x))
        out = F.leaky_relu(self.fc2(out))
        return self.output(out)   
                

In [10]:
# Initialize an instance of LeNet-300-100 dense neural network-
model = LeNet300().to(device)

In [11]:
'''
def initialize_weights(m):
    if isinstance(m, nn.Conv2d):
        # nn.init.kaiming_uniform_(m.weight.data, nonlinearity = 'relu')
        nn.init.kaiming_normal_(m.weight.data)
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 1)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight.data, 1)
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 1)
    elif isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight.data)
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 1)
        
    return None
'''

@torch.no_grad()
def init_weights(m):
    # print(m)
    if type(m) == nn.Conv2d:
        nn.init.kaiming_normal_(m.weight.data)
        if m.bias is not None:
            m.bias.fill_(1.0)
    elif type(m) == nn.Linear:
        nn.init.kaiming_normal_(m.weight.data)
        if m.bias is not None:
            m.bias.fill_(1.0)
    elif isinstance(m, nn.BatchNorm2d):
        m.weight.fill_(1.0)
        if m.bias is not None:
            m.bias.fill_(1.0)
        
    return None


In [12]:
model.apply(init_weights)

LeNet300(
  (fc1): Linear(in_features=784, out_features=300, bias=True)
  (fc2): Linear(in_features=300, out_features=100, bias=True)
  (output): Linear(in_features=100, out_features=10, bias=True)
)

In [13]:
torch.save(model.state_dict(), "LeNet300_random_params.pth")
# torch.save(model.state_dict(), "LeNet5_random_params.pth")

In [14]:
'''
def count_trainable_params(model):
    # Count number of layer-wise parameters and total parameters-
    tot_params = 0
    for param in model.parameters():
        # print(f"layer.shape = {param.shape} has {param.nelement()} parameters")
        tot_params += param.nelement()

    return tot_params

def surviving_params(model):
    tot_params = 0
    for param in model.parameters():
        surviving_params = torch.count_nonzero(param).item()
        tot_params += surviving_params
    
    return tot_params
'''

def count_surviving_params(model):
    tot_params = 0
    for param in model.parameters():
        layer_nonzero_param = torch.count_nonzero(param).item()
        tot_params += layer_nonzero_param
        # print(f"{param.size()} has {layer_nonzero_param} params")
    
    return tot_params


In [15]:
orig_tot_params = count_surviving_params(model = model)

In [16]:
print(f"LeNet-300-100 has {orig_tot_params} non-zero params")

LeNet-300-100 has 266610 non-zero params


In [17]:
# Define loss function and optimizer-
loss = nn.CrossEntropyLoss()

'''
optimizer = torch.optim.SGD(
    params = model.parameters(), lr = 0.001,
    momentum = 0.9, weight_decay = 5e-4
)
'''
optimizer = torch.optim.Adam(
    params = model.parameters(), lr = 0.001
)

In [18]:
images, labels = next(iter(train_loader))
images = images.view(-1, 28 * 28 * 1)
images = images.to(device)
labels = labels.to(device)

pred = model(images)

# Compute loss-
J = loss(pred, labels)
            
# Empty accumulated gradients-
optimizer.zero_grad()
            
# Perform backprop-
J.backward()

# J.item()

# Preserve sparsity while training by zeroing-out gradients before weight update-
for param in model.parameters():
    wts = param.clone().detach()
    gradient_t = param.grad.clone().detach()
    gradient_t = torch.where(wts == 0., 0., gradient_t)
    param.grad = gradient_t
    
# Update trainable parameters-
optimizer.step()

In [23]:
# Sanity check-
for param in model.parameters():
    if len(param.size()) == 2:
        print(f"wts = {torch.count_nonzero(param)} &"
              f" grads = {torch.count_nonzero(param.grad)}"
             )

wts = 235200 & grads = 235200
wts = 30000 & grads = 30000
wts = 1000 & grads = 1000


In [24]:
def train_model_progress(model, train_loader, train_dataset):
    '''
    Function to perform one epoch of training by using 'train_loader'.
    Returns loss and number of correct predictions for this epoch.
    '''
    running_loss = 0.0
    running_corrects = 0.0
    
    model.train()
    
    with tqdm(train_loader, unit = 'batch') as tepoch:
        for images, labels in tepoch:
            tepoch.set_description(f"Training: ")
            
            images = images.view(-1, 28 * 28 * 1)
            
            images = images.to(device)
            labels = labels.to(device)
            
            # Get model predictions-
            outputs = model(images)
            
            # Compute loss-
            J = loss(outputs, labels)
            
            # Empty accumulated gradients-
            optimizer.zero_grad()
            
            # Perform backprop-
            J.backward()
            
            for param in model.parameters():
                wts = param.clone().detach()
                gradient_t = param.grad.clone().detach()
                gradient_t = torch.where(wts == 0., 0., gradient_t)
                param.grad = gradient_t
            
            # Update parameters-
            optimizer.step()
            
            '''
            global step
            optimizer.param_groups[0]['lr'] = custom_lr_scheduler.get_lr(step)
            step += 1
            '''
            
            # Compute model's performance statistics-
            running_loss += J.item() * images.size(0)
            _, predicted = torch.max(outputs, 1)
            running_corrects += torch.sum(predicted == labels.data)
            
            tepoch.set_postfix(
                loss = running_loss / len(train_dataset),
                accuracy = (running_corrects.double().cpu().numpy() / len(train_dataset)) * 100
            )
            
    
    train_loss = running_loss / len(train_dataset)
    train_acc = (running_corrects.double() / len(train_dataset)) * 100

    return train_loss, train_acc.cpu().numpy()


In [25]:
def test_model_progress(model, test_loader, test_dataset):
    total = 0.0
    correct = 0.0
    running_loss_val = 0.0
    
    # Set model to evaluation mode-
    model.eval()

    with torch.no_grad():
        with tqdm(test_loader, unit = 'batch') as tepoch:
            for images, labels in tepoch:
                tepoch.set_description(f"Validation: ")
                
                images = images.view(-1, 28 * 28 * 1)
                
                images = images.to(device)
                labels = labels.to(device)
            
                # Predict using trained model-
                outputs = model(images)
                _, y_pred = torch.max(outputs, 1)
                
                # Compute validation loss-
                J_val = loss(outputs, labels)
                
                running_loss_val += J_val.item() * labels.size(0)
    
                # Total number of labels-
                total += labels.size(0)

                # Total number of correct predictions-
                correct += (y_pred == labels).sum()
                
                tepoch.set_postfix(
                    val_loss = running_loss_val / len(test_dataset),
                    val_acc = 100 * (correct.cpu().numpy() / total)
                )
            
        
    # return (running_loss_val, correct, total)
    val_loss = running_loss_val / len(test_dataset)
    val_acc = (correct / total) * 100

    return val_loss, val_acc.cpu().numpy()
 

In [26]:
def train_until_convergence(
    model, train_loader,
    test_loader, train_dataset,
    test_dataset, num_epochs
):
    
    best_test_acc = 0
    
    # Python3 dict to contain training metrics-
    training_history_lr_scheduler = {}
    
    for epoch in range(1, num_epochs + 1):
        loss_train, acc_train = train_model_progress(
            model = model, train_loader = train_loader,
            train_dataset = train_dataset
        )
        
        loss_test, acc_test = test_model_progress(
            model = model, test_loader = test_loader,
            test_dataset = test_dataset
        )
        
        curr_lr = optimizer.param_groups[0]['lr']
        
        print(f"{epoch}; loss = {loss_train:.4f}, test loss = {loss_test:.4f}"
              f", acc = {acc_train:.2f}%, test acc = {acc_test:.2f}%,"
              f" non-0 params = {count_surviving_params(model)}"
              # f" & lr = {curr_lr:.6f}"
             )
        
        # Compute sparsity-
        sparsity = ((orig_tot_params - count_surviving_params(model)) / orig_tot_params) * 100
                   
        training_history_lr_scheduler[epoch + 1] = {
            'loss': loss_train, 'acc': acc_train,
            'test_loss': loss_test, 'test_acc': acc_test,
            # 'lr': curr_lr
        }

        # Save best weights achieved until now-
        if (acc_test > best_test_acc):    
            # update 'best_val_loss' variable to lowest loss encountered so far-
            best_test_acc = acc_test

            print(f"Saving model with highest test acc = {acc_test:.2f}%\n")

            # Save trained model with 'best' validation accuracy-
            torch.save(model.state_dict(), f"LeNet300_bestmodel_{sparsity:.2f}.pth")
            # torch.save(model.state_dict(), f"LeNet5_bestmodel_{sparsity:.2f}.pth")
        
        
    return training_history_lr_scheduler


In [27]:
def prune_globally(model, pruning_percentile = 20):
    # Python 3 list to hold layer-wise weights-
    pruned_weights = []
    
    for param in model.parameters():
        wts = np.copy(param.detach().cpu().numpy())
        pruned_weights.append(wts)
    
    del param, wts
    
    # Flatten all numpy arrays-
    pruned_weights_flattened = [layer.flatten() for layer in pruned_weights]

    threshold = np.percentile(a = abs(np.concatenate(pruned_weights_flattened)), q = pruning_percentile)
    # print("\nFor p = {0:.2f}% of weights to be pruned, threshold = {1:.4f}\n".format(p, threshold))
    
    # Prune conv and dense layers-
    # bias and batch-norm is NOT pruned.
    for layer in pruned_weights:
        if len(layer.shape) == 4:
            layer[abs(layer) < threshold] = 0
        elif len(layer.shape) == 2:
            layer[abs(layer) < threshold] = 0
    
    
    i = 0
    model_d = dict()

    for name, params in model.named_parameters():
        if pruned_weights[i].shape == params.shape:
            model_d[name] = torch.from_numpy(pruned_weights[i])

        i += 1
        
        
    state_d = model.state_dict()

    for layer_name in model_d:
        # if pruned_model.state_dict().get(layer_name) is not None:
        if state_d.get(layer_name) is not None:
            # print(layer_name)
            state_d[layer_name] = model_d.get(layer_name)

    model.load_state_dict(state_d)
        
    return None



In [28]:
def compute_pruning_percentage_iterative_rounds(orig_tot_params):
    """
    Compute percentage of parameters to prune globally in
    each iterative pruning round.

    Current hardcoded pruning percentage = 20%.
    """

    surviving_params = orig_tot_params

    sparsity_percentage = []
    while surviving_params >= 0.005 * orig_tot_params:
        surviving_params = 0.8 * surviving_params
        sparsity_percentage.append((orig_tot_params - surviving_params) / orig_tot_params * 100)

    sparsity_percentage = np.asarray(sparsity_percentage)
    sparsity_percentage = np.round(a = sparsity_percentage, decimals = 3)

    return sparsity_percentage


In [29]:
sparsity_percentage = compute_pruning_percentage_iterative_rounds(orig_tot_params)
print(f"number of iterative pruning rounds = {sparsity_percentage.size}")

number of iterative pruning rounds = 24


In [30]:
sparsity_percentage

array([20.   , 36.   , 48.8  , 59.04 , 67.232, 73.786, 79.028, 83.223,
       86.578, 89.263, 91.41 , 93.128, 94.502, 95.602, 96.482, 97.185,
       97.748, 98.199, 98.559, 98.847, 99.078, 99.262, 99.41 , 99.528])

In [31]:
# Define number of training epochs-
num_training_epochs = 20

In [32]:
# Train until convergence randomly initialized model-
history_pr = train_until_convergence(
    model = model, train_loader = train_loader,
    test_loader = test_loader, train_dataset = train_dataset,
    test_dataset = test_dataset, num_epochs = num_training_epochs
)

Training: : 100%|█████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:11<00:00,  1.34batch/s, accuracy=60.7, loss=1.39]
Validation: : 100%|█████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  1.70batch/s, val_acc=82.1, val_loss=0.582]


1; loss = 1.3937, test loss = 0.5815, acc = 60.69%, test acc = 82.11%, non-0 params = 266610
Saving model with highest test acc = 82.11%



Training: : 100%|████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:11<00:00,  1.35batch/s, accuracy=85.4, loss=0.478]
Validation: : 100%|███████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  1.59batch/s, val_acc=89, val_loss=0.377]


2; loss = 0.4776, test loss = 0.3775, acc = 85.41%, test acc = 89.04%, non-0 params = 266610
Saving model with highest test acc = 89.04%



Training: : 100%|████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:10<00:00,  1.38batch/s, accuracy=89.7, loss=0.346]
Validation: : 100%|█████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  1.63batch/s, val_acc=91.3, val_loss=0.305]


3; loss = 0.3461, test loss = 0.3052, acc = 89.75%, test acc = 91.28%, non-0 params = 266610
Saving model with highest test acc = 91.28%



Training: : 100%|████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:10<00:00,  1.38batch/s, accuracy=91.7, loss=0.286]
Validation: : 100%|█████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  1.65batch/s, val_acc=92.3, val_loss=0.268]


4; loss = 0.2857, test loss = 0.2679, acc = 91.67%, test acc = 92.33%, non-0 params = 266610
Saving model with highest test acc = 92.33%



Training: : 100%|████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:10<00:00,  1.37batch/s, accuracy=92.8, loss=0.249]
Validation: : 100%|███████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  1.69batch/s, val_acc=93, val_loss=0.242]


5; loss = 0.2488, test loss = 0.2420, acc = 92.75%, test acc = 92.97%, non-0 params = 266610
Saving model with highest test acc = 92.97%



Training: : 100%|████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:11<00:00,  1.31batch/s, accuracy=93.6, loss=0.222]
Validation: : 100%|█████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  1.70batch/s, val_acc=93.7, val_loss=0.218]


6; loss = 0.2221, test loss = 0.2182, acc = 93.61%, test acc = 93.66%, non-0 params = 266610
Saving model with highest test acc = 93.66%



Training: : 100%|████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:10<00:00,  1.39batch/s, accuracy=94.3, loss=0.201]
Validation: : 100%|█████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  1.73batch/s, val_acc=94.1, val_loss=0.201]


7; loss = 0.2006, test loss = 0.2011, acc = 94.26%, test acc = 94.06%, non-0 params = 266610
Saving model with highest test acc = 94.06%



Training: : 100%|████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:10<00:00,  1.37batch/s, accuracy=94.8, loss=0.183]
Validation: : 100%|█████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  1.70batch/s, val_acc=94.5, val_loss=0.187]


8; loss = 0.1828, test loss = 0.1869, acc = 94.80%, test acc = 94.51%, non-0 params = 266610
Saving model with highest test acc = 94.51%



Training: : 100%|████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:10<00:00,  1.39batch/s, accuracy=95.3, loss=0.167]
Validation: : 100%|█████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  1.61batch/s, val_acc=94.8, val_loss=0.175]


9; loss = 0.1675, test loss = 0.1752, acc = 95.26%, test acc = 94.82%, non-0 params = 266610
Saving model with highest test acc = 94.82%



Training: : 100%|████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:10<00:00,  1.38batch/s, accuracy=95.6, loss=0.155]
Validation: : 100%|█████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  1.67batch/s, val_acc=95.1, val_loss=0.164]


10; loss = 0.1548, test loss = 0.1641, acc = 95.64%, test acc = 95.12%, non-0 params = 266610
Saving model with highest test acc = 95.12%



Training: : 100%|██████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:10<00:00,  1.38batch/s, accuracy=96, loss=0.144]
Validation: : 100%|█████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  1.64batch/s, val_acc=95.4, val_loss=0.155]


11; loss = 0.1435, test loss = 0.1546, acc = 95.96%, test acc = 95.38%, non-0 params = 266610
Saving model with highest test acc = 95.38%



Training: : 100%|████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:10<00:00,  1.40batch/s, accuracy=96.2, loss=0.133]
Validation: : 100%|█████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  1.74batch/s, val_acc=95.7, val_loss=0.146]


12; loss = 0.1334, test loss = 0.1458, acc = 96.20%, test acc = 95.68%, non-0 params = 266610
Saving model with highest test acc = 95.68%



Training: : 100%|████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:10<00:00,  1.36batch/s, accuracy=96.5, loss=0.124]
Validation: : 100%|███████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  1.72batch/s, val_acc=96, val_loss=0.139]


13; loss = 0.1241, test loss = 0.1386, acc = 96.51%, test acc = 96.00%, non-0 params = 266610
Saving model with highest test acc = 96.00%



Training: : 100%|████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:10<00:00,  1.40batch/s, accuracy=96.8, loss=0.116]
Validation: : 100%|█████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  1.73batch/s, val_acc=96.1, val_loss=0.132]


14; loss = 0.1159, test loss = 0.1320, acc = 96.82%, test acc = 96.11%, non-0 params = 266610
Saving model with highest test acc = 96.11%



Training: : 100%|██████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:10<00:00,  1.38batch/s, accuracy=97, loss=0.108]
Validation: : 100%|█████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  1.65batch/s, val_acc=96.3, val_loss=0.126]


15; loss = 0.1081, test loss = 0.1264, acc = 97.02%, test acc = 96.29%, non-0 params = 266610
Saving model with highest test acc = 96.29%



Training: : 100%|████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:11<00:00,  1.36batch/s, accuracy=97.2, loss=0.101]
Validation: : 100%|█████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  1.71batch/s, val_acc=96.5, val_loss=0.121]


16; loss = 0.1013, test loss = 0.1210, acc = 97.21%, test acc = 96.47%, non-0 params = 266610
Saving model with highest test acc = 96.47%



Training: : 100%|████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:10<00:00,  1.39batch/s, accuracy=97.4, loss=0.095]
Validation: : 100%|█████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  1.71batch/s, val_acc=96.5, val_loss=0.117]


17; loss = 0.0950, test loss = 0.1167, acc = 97.40%, test acc = 96.50%, non-0 params = 266610
Saving model with highest test acc = 96.50%



Training: : 100%|███████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:10<00:00,  1.40batch/s, accuracy=97.6, loss=0.0895]
Validation: : 100%|█████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  1.70batch/s, val_acc=96.7, val_loss=0.112]


18; loss = 0.0895, test loss = 0.1120, acc = 97.56%, test acc = 96.70%, non-0 params = 266610
Saving model with highest test acc = 96.70%



Training: : 100%|███████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:10<00:00,  1.39batch/s, accuracy=97.7, loss=0.0843]
Validation: : 100%|█████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  1.71batch/s, val_acc=96.8, val_loss=0.109]


19; loss = 0.0843, test loss = 0.1085, acc = 97.69%, test acc = 96.82%, non-0 params = 266610
Saving model with highest test acc = 96.82%



Training: : 100%|███████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:11<00:00,  1.35batch/s, accuracy=97.9, loss=0.0791]
Validation: : 100%|█████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  1.71batch/s, val_acc=96.9, val_loss=0.107]

20; loss = 0.0791, test loss = 0.1067, acc = 97.86%, test acc = 96.94%, non-0 params = 266610
Saving model with highest test acc = 96.94%






In [33]:
count_surviving_params(model)

266610

In [34]:
sparsity_percentage[0]

20.0

In [35]:
prune_globally(model, pruning_percentile = sparsity_percentage[0])

In [36]:
count_surviving_params(model)

213288

In [37]:
# Train until convergence pruned initialized model-
history_pr = train_until_convergence(
    model = model, train_loader = train_loader,
    test_loader = test_loader, train_dataset = train_dataset,
    test_dataset = test_dataset, num_epochs = 20
)

Training: : 100%|█████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:10<00:00,  1.39batch/s, accuracy=98, loss=0.0751]
Validation: : 100%|█████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  1.70batch/s, val_acc=96.9, val_loss=0.103]


1; loss = 0.0751, test loss = 0.1033, acc = 97.96%, test acc = 96.91%, non-0 params = 266610
Saving model with highest test acc = 96.91%



Training: : 100%|███████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:10<00:00,  1.41batch/s, accuracy=98.1, loss=0.0706]
Validation: : 100%|████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  1.70batch/s, val_acc=97.1, val_loss=0.0996]


2; loss = 0.0706, test loss = 0.0996, acc = 98.12%, test acc = 97.06%, non-0 params = 266610
Saving model with highest test acc = 97.06%



Training: : 100%|███████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:10<00:00,  1.38batch/s, accuracy=98.2, loss=0.0666]
Validation: : 100%|████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  1.71batch/s, val_acc=97.1, val_loss=0.0967]


3; loss = 0.0666, test loss = 0.0967, acc = 98.23%, test acc = 97.14%, non-0 params = 266610
Saving model with highest test acc = 97.14%



Training: : 100%|███████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:10<00:00,  1.38batch/s, accuracy=98.4, loss=0.0618]
Validation: : 100%|████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  1.69batch/s, val_acc=97.2, val_loss=0.0939]


4; loss = 0.0618, test loss = 0.0939, acc = 98.43%, test acc = 97.17%, non-0 params = 266610
Saving model with highest test acc = 97.17%



Training: : 100%|████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:10<00:00,  1.41batch/s, accuracy=98.5, loss=0.058]
Validation: : 100%|████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  1.73batch/s, val_acc=97.2, val_loss=0.0925]


5; loss = 0.0580, test loss = 0.0925, acc = 98.50%, test acc = 97.24%, non-0 params = 266610
Saving model with highest test acc = 97.24%



Training: : 100%|███████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:10<00:00,  1.39batch/s, accuracy=98.6, loss=0.0546]
Validation: : 100%|████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  1.72batch/s, val_acc=97.3, val_loss=0.0904]


6; loss = 0.0546, test loss = 0.0904, acc = 98.61%, test acc = 97.33%, non-0 params = 266610
Saving model with highest test acc = 97.33%



Training: : 100%|███████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:10<00:00,  1.39batch/s, accuracy=98.7, loss=0.0515]
Validation: : 100%|████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  1.73batch/s, val_acc=97.3, val_loss=0.0889]


7; loss = 0.0515, test loss = 0.0889, acc = 98.69%, test acc = 97.32%, non-0 params = 266610


Training: : 100%|███████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:10<00:00,  1.39batch/s, accuracy=98.8, loss=0.0484]
Validation: : 100%|████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  1.70batch/s, val_acc=97.3, val_loss=0.0875]


8; loss = 0.0484, test loss = 0.0875, acc = 98.76%, test acc = 97.32%, non-0 params = 266610


Training: : 100%|███████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:10<00:00,  1.38batch/s, accuracy=98.8, loss=0.0461]
Validation: : 100%|████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  1.72batch/s, val_acc=97.4, val_loss=0.0862]


9; loss = 0.0461, test loss = 0.0862, acc = 98.84%, test acc = 97.38%, non-0 params = 266610
Saving model with highest test acc = 97.38%



Training: : 100%|███████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:10<00:00,  1.40batch/s, accuracy=98.9, loss=0.0433]
Validation: : 100%|█████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  1.72batch/s, val_acc=97.5, val_loss=0.085]


10; loss = 0.0433, test loss = 0.0850, acc = 98.91%, test acc = 97.49%, non-0 params = 266610
Saving model with highest test acc = 97.49%



Training: : 100%|█████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:10<00:00,  1.39batch/s, accuracy=99, loss=0.0408]
Validation: : 100%|████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  1.66batch/s, val_acc=97.5, val_loss=0.0834]


11; loss = 0.0408, test loss = 0.0834, acc = 99.01%, test acc = 97.45%, non-0 params = 266610


Training: : 100%|███████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:11<00:00,  1.32batch/s, accuracy=99.1, loss=0.0387]
Validation: : 100%|█████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  1.65batch/s, val_acc=97.5, val_loss=0.083]


12; loss = 0.0387, test loss = 0.0830, acc = 99.07%, test acc = 97.55%, non-0 params = 266610
Saving model with highest test acc = 97.55%



Training: : 100%|███████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:11<00:00,  1.35batch/s, accuracy=99.2, loss=0.0366]
Validation: : 100%|████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  1.71batch/s, val_acc=97.5, val_loss=0.0817]


13; loss = 0.0366, test loss = 0.0817, acc = 99.15%, test acc = 97.49%, non-0 params = 266610


Training: : 100%|████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:10<00:00,  1.41batch/s, accuracy=99.2, loss=0.035]
Validation: : 100%|████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  1.74batch/s, val_acc=97.5, val_loss=0.0802]


14; loss = 0.0350, test loss = 0.0802, acc = 99.19%, test acc = 97.53%, non-0 params = 266610


Training: : 100%|███████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:10<00:00,  1.41batch/s, accuracy=99.3, loss=0.0325]
Validation: : 100%|████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  1.73batch/s, val_acc=97.5, val_loss=0.0795]


15; loss = 0.0325, test loss = 0.0795, acc = 99.26%, test acc = 97.53%, non-0 params = 266610


Training: : 100%|███████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:11<00:00,  1.36batch/s, accuracy=99.3, loss=0.0304]
Validation: : 100%|████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  1.58batch/s, val_acc=97.5, val_loss=0.0789]


16; loss = 0.0304, test loss = 0.0789, acc = 99.34%, test acc = 97.55%, non-0 params = 266610


Training: : 100%|███████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:10<00:00,  1.39batch/s, accuracy=99.4, loss=0.0287]
Validation: : 100%|████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  1.70batch/s, val_acc=97.6, val_loss=0.0787]


17; loss = 0.0287, test loss = 0.0787, acc = 99.38%, test acc = 97.58%, non-0 params = 266610
Saving model with highest test acc = 97.58%



Training: : 100%|███████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:10<00:00,  1.39batch/s, accuracy=99.4, loss=0.0279]
Validation: : 100%|████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  1.70batch/s, val_acc=97.6, val_loss=0.0785]


18; loss = 0.0279, test loss = 0.0785, acc = 99.39%, test acc = 97.58%, non-0 params = 266610


Training: : 100%|███████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:10<00:00,  1.41batch/s, accuracy=99.4, loss=0.0258]
Validation: : 100%|█████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  1.72batch/s, val_acc=97.6, val_loss=0.079]


19; loss = 0.0258, test loss = 0.0790, acc = 99.44%, test acc = 97.57%, non-0 params = 266610


Training: : 100%|███████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:10<00:00,  1.41batch/s, accuracy=99.5, loss=0.0252]
Validation: : 100%|████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  1.70batch/s, val_acc=97.7, val_loss=0.0785]

20; loss = 0.0252, test loss = 0.0785, acc = 99.46%, test acc = 97.68%, non-0 params = 266610
Saving model with highest test acc = 97.68%






In [38]:
count_surviving_params(model)

266610