In [1]:
# import torch
# import torch.nn as nn
# import torch.nn.functional as F

# class MNISTNet(nn.Module):
#     """Feedfoward neural network with 1 hidden layer"""
#     def __init__(self):
#         super(MNISTNet, self).__init__()
        
#         self.fc1 = nn.Linear(28*28, 256)
#         self.fc2 = nn.Linear(256, 128)
#         self.fc3 = nn.Linear(128, 64)        
#         self.fc4 = nn.Linear(64, 10)
#         self.fc4.is_classifier = True
                
#     def forward(self, x):
#         x = x.view(x.shape[0], -1)
#         x = F.relu(self.fc1(x))
#         x = F.relu(self.fc2(x))
#         x = F.relu(self.fc3(x))        
#         x = self.fc4(x)
#         return F.log_softmax(x, dim=1)
    
# from torchsummary import summary
# model=MNISTNet()
# print("Model summary")
# print(summary(model, input_size=(1, 28, 28), batch_size=-1))

# print("Model details")
# for nm, params in model.named_parameters():
#     if "weight" in nm and "bn" not in nm and "linear" not in nm:
#         print(nm, params.data.shape)
    

In [2]:
import torch
import torchvision
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import random_split

In [3]:
        
def get_default_device():
    """Return torch.device('cuda') when CUDA is usable, else torch.device('cpu')."""
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.device(device_name)
    
def to_device(data, device):
    """Move a tensor, or a (possibly nested) list/tuple of tensors, onto `device`.

    Sequences are rebuilt as lists, whatever their original type; leaves are
    transferred with `.to(device, non_blocking=True)`.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)
    moved = []
    for item in data:
        moved.append(to_device(item, device))
    return moved



class DeviceDataLoader():
    """Iterable wrapper that transfers each batch of a loader onto a device."""

    def __init__(self, dl, device):
        # dl: any iterable of batches with a len(); device: target for to_device
        self.dl, self.device = dl, device

    def __iter__(self):
        """Lazily yield the wrapped loader's batches after moving them with to_device."""
        yield from (to_device(batch, self.device) for batch in self.dl)

    def __len__(self):
        """Batch count of the wrapped loader."""
        return len(self.dl)
    
    


In [4]:
def evaluate(model, val_loader):
    """Evaluate the model's performance on a validation/test loader.

    Calls ``model.validation_step(batch)`` for every batch of ``val_loader`` and
    passes the list of per-batch outputs to ``model.validation_epoch_end``, whose
    result is returned unchanged.

    The pass runs under ``torch.no_grad()`` with the model in eval mode. Without
    ``no_grad``, each batch's loss tensor would keep its autograd graph (and the
    activations it references) alive until all batches are aggregated. The model's
    previous train/eval mode is restored afterwards, even if a step raises, so a
    surrounding training loop continues in the mode it was in.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            outputs = [model.validation_step(batch) for batch in val_loader]
            return model.validation_epoch_end(outputs)
    finally:
        model.train(was_training)
        
    
    
def accuracy(outputs, labels):
    """Fraction of rows whose highest-scoring column equals the label.

    Returns a 0-d float tensor built from a Python float (so it is detached from
    `outputs`).
    """
    predicted = outputs.max(dim=1)[1]
    n_right = (predicted == labels).sum().item()
    return torch.tensor(n_right / len(predicted))

def correct(output, target, topk=(1,)):
    """Count top-k hits of `output` against `target`, for each k in `topk`.

    A sample is a hit for a given k when its target label is among the k
    highest-scoring entries of its row in `output`. Raw counts are returned, not
    accuracies, so every value is bounded above by the batch size.

    Arguments:
        output {torch.Tensor} -- Model scores, one row per sample (topk taken along dim 1)
        target {torch.Tensor} -- Target labels, one per row of `output`

    Keyword Arguments:
        topk {iterable} -- Values of k to evaluate (default: {(1,)})

    Returns:
        List(torch.Tensor) -- One 0-d float tensor per k, in `topk` order
    """
    with torch.no_grad():
        largest_k = max(topk)
        # A single topk call for the largest k; smaller ks reuse its leading rows.
        ranked = output.topk(k=largest_k, dim=1, largest=True, sorted=True)[1].t()
        # hits[r, i] is True when the rank-r prediction for sample i is its target.
        hits = ranked.eq(target.view(1, -1).expand_as(ranked))
        return [
            torch.tensor(hits[:k].reshape(-1).float().sum(0, keepdim=True).item())
            for k in topk
        ]
    




In [5]:
class MNISTNet(nn.Module):
    """Feedforward MNIST classifier: 784 -> 256 -> 128 -> 64 -> 10.

    Three ReLU hidden layers followed by a linear output layer. ``forward``
    returns per-class log-probabilities (``log_softmax``), so losses are computed
    with ``F.nll_loss``. ``F.cross_entropy`` would apply a second, redundant
    ``log_softmax`` on top of it.
    """
    def __init__(self):
        super(MNISTNet, self).__init__()
        self.fc1 = nn.Linear(28*28, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 64)
        self.fc4 = nn.Linear(64, 10)
        # Marker attribute on the output layer; not read by the code shown here
        # (presumably used by external tooling -- confirm before removing).
        self.fc4.is_classifier = True

    def forward(self, x):
        # Flatten everything after the batch dim; fc1 expects 28*28 features per sample.
        x = x.view(x.shape[0], -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = self.fc4(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch):
        """Return the mean negative log-likelihood loss for one (images, labels) batch."""
        images, labels = batch
        out = self(images)                  # log-probabilities
        loss = F.nll_loss(out, labels)      # == cross_entropy on the raw logits
        return loss

    def validation_step(self, batch):
        """Compute loss, accuracy and top-1/top-5 rates for one (images, labels) batch.

        The batch size is returned too, so that ``validation_epoch_end`` can weight
        each batch by it.
        """
        images, labels = batch
        n = labels.shape[0]
        out = self(images)                    # log-probabilities
        loss = F.nll_loss(out, labels)
        acc = accuracy(out, labels)
        top_1, top_5 = correct(out, labels, topk=(1, 5))
        # Note: top_1 carries the same information as val_acc (arg-max hit rate).
        return {'val_loss': loss, 'val_acc': acc,
                'top_1': top_1 / n, 'top_5': top_5 / n,
                'batch_size': n}

    def validation_epoch_end(self, outputs):
        """Combine per-batch validation dicts into epoch-level Python floats.

        Each metric is averaged with weights equal to the ``'batch_size'`` recorded
        by ``validation_step``, so a short final batch does not count as much as a
        full one. Dicts without that key get weight 1, which reduces to the plain
        mean of per-batch values.
        """
        weights = torch.tensor([float(x.get('batch_size', 1)) for x in outputs])

        def _weighted_mean(key):
            # Loss may live on the model's device while accuracies are on the CPU;
            # bring everything to the CPU before combining.
            values = torch.stack([x[key].detach().float().cpu() for x in outputs]).reshape(-1)
            return ((values * weights).sum() / weights.sum()).item()

        return {'val_loss': _weighted_mean('val_loss'), 'val_acc': _weighted_mean('val_acc'),
                'val_top_1': _weighted_mean('top_1'), 'val_top_5': _weighted_mean('top_5')}

    def epoch_end(self, epoch, result):
        """Print the epoch's validation metrics."""
        print("Epoch [{}], val_loss: {:.4f}, val_acc: {:.4f}, val_top_1: {:.4f}, val_top_5: {:.4f}".format(
                                epoch, result['val_loss'], result['val_acc'],
                                result['val_top_1'], result['val_top_5']))


def fit(epochs, lr, model, train_loader, val_loader, opt_func=torch.optim.SGD):
    """Train `model` for `epochs` passes over `train_loader`, validating after each.

    Args:
        epochs: number of passes over the training data.
        lr: learning rate, passed positionally as the optimizer's second argument.
        model: module exposing training_step, epoch_end and the hooks used by evaluate.
        train_loader, val_loader: iterables of batches.
        opt_func: optimizer class, constructed as opt_func(model.parameters(), lr).

    Returns:
        A list holding one validation-result dict per epoch.
    """
    print("At train")
    epoch_results = []
    optimizer = opt_func(model.parameters(), lr)
    for epoch_idx in range(epochs):
        # Training phase: one optimizer step per batch.
        for batch in train_loader:
            batch_loss = model.training_step(batch)
            batch_loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        # Validation phase.
        summary = evaluate(model, val_loader)
        model.epoch_end(epoch_idx, summary)
        epoch_results.append(summary)
    return epoch_results


In [6]:
# Seed torch's global RNG and the split generator so that Restart & Run All
# reproduces the same train/val split and initialization.
SEED = 42
torch.manual_seed(SEED)

device = get_default_device()
data_transforms = torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor()])

dataset = MNIST(root='data/', download=True, transform=data_transforms)
# Held-out test split; download=True so this cell also works on its own.
test_dataset = MNIST(root='data/', train=False, download=True, transform=data_transforms)

# Hold out 10k training images for validation; the seeded generator makes the
# split identical on every run.
val_size = 10000
train_size = len(dataset) - val_size
train_ds, val_ds = random_split(dataset, [train_size, val_size],
                                generator=torch.Generator().manual_seed(SEED))

batch_size = 128
# Pinned host memory only helps host-to-GPU copies.
use_pinned_memory = torch.cuda.is_available()

train_loader = DataLoader(train_ds, batch_size, shuffle=True, num_workers=1, pin_memory=use_pinned_memory)
val_loader = DataLoader(val_ds, batch_size*2, num_workers=1, pin_memory=use_pinned_memory)
test_loader = DataLoader(test_dataset, batch_size=256)

# Wrap the loaders so every batch arrives already on `device`.
train_loader = DeviceDataLoader(train_loader, device)
val_loader = DeviceDataLoader(val_loader, device)
test_loader = DeviceDataLoader(test_loader, device)

    

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw
Processing...


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


Done!


In [None]:
model = MNISTNet().to(device)
# Baseline: validation metrics of the untrained model form the first history entry.
history = [evaluate(model, val_loader)]
epochs = 20
lr = 0.01
# fit returns one dict per epoch; extend (don't overwrite) so the baseline is kept.
history += fit(epochs, lr, model, train_loader, val_loader)





At train
Epoch [0], val_loss: 2.2577, val_acc: 0.3871, val_top_1: 0.3871, val_top_5: 0.6913
Epoch [1], val_loss: 1.9926, val_acc: 0.4592, val_top_1: 0.4592, val_top_5: 0.9223
Epoch [2], val_loss: 1.0168, val_acc: 0.7391, val_top_1: 0.7391, val_top_5: 0.9803
Epoch [3], val_loss: 0.6339, val_acc: 0.8202, val_top_1: 0.8202, val_top_5: 0.9854
Epoch [4], val_loss: 0.5141, val_acc: 0.8483, val_top_1: 0.8483, val_top_5: 0.9903
Epoch [5], val_loss: 0.4567, val_acc: 0.8661, val_top_1: 0.8661, val_top_5: 0.9921
Epoch [6], val_loss: 0.4198, val_acc: 0.8781, val_top_1: 0.8781, val_top_5: 0.9929
Epoch [7], val_loss: 0.3959, val_acc: 0.8843, val_top_1: 0.8843, val_top_5: 0.9935
Epoch [8], val_loss: 0.3716, val_acc: 0.8926, val_top_1: 0.8926, val_top_5: 0.9947


In [None]:
# Accuracy on the held-out test split (evaluate reports it under the "val_acc" key).
res = evaluate(model, test_loader)
test_accuracy = res["val_acc"]
print("Accuracy of model ", test_accuracy)


In [None]:
import os

model_state_path = "model_state/mod_untitled.pt"
# Ensure the target folder exists so the save also succeeds on a fresh checkout.
os.makedirs(os.path.dirname(model_state_path), exist_ok=True)
torch.save(model.state_dict(), model_state_path)


### Show weights

In [None]:
print("Model details")
# Parameters whose name contains "weight" but neither "bn" nor "linear".
weight_params = [
    (name, p) for name, p in model.named_parameters()
    if "weight" in name and "bn" not in name and "linear" not in name
]
for name, p in weight_params:
    print(name, "\n", p.data)
    print(p.data.shape)


In [None]:
import plotly.graph_objects as go

# Gather all selected weights into one flat NumPy array. Moving to the CPU
# first makes this work for CUDA models. A single array is also far cheaper for
# plotly than a Python list of hundreds of thousands of 0-d tensors.
all_weights = np.concatenate([
    params.detach().cpu().numpy().ravel()
    for nm, params in model.named_parameters()
    if "weight" in nm and "bn" not in nm and "linear" not in nm
])

fig = go.Figure(data=[go.Histogram(x=all_weights, nbinsx=5000)])
fig.update_layout(
    # Values are signed weights, not magnitudes.
    title="Histogram of weight values for entire neural network",
    xaxis_title="Weight value",
    yaxis_title="Frequency",
)

fig.show()

In [None]:
# Restore the saved weights; without CUDA, storages are remapped to the CPU.
checkpoint_location = None if torch.cuda.is_available() else torch.device('cpu')
model.load_state_dict(torch.load(model_state_path, map_location=checkpoint_location))
print("Model re-loaded")

### Basic Pruning

In [None]:
# Weights with |w| below prune_rate * std(|w|) are pruned. The std is taken over
# every selected weight of the network together (one global threshold).
prune_rate = 2.25

# Concatenate on the CPU so this works whether the model is on CPU or CUDA
# (.numpy() raises for CUDA tensors).
all_weights = torch.cat([
    params.detach().flatten().cpu()
    for nm, params in model.named_parameters()
    if "weight" in nm and "bn" not in nm and "linear" not in nm
])
abs_std = torch.std(torch.abs(all_weights))
threshold = abs_std * prune_rate
print("Threshold is", threshold)
print("Max weight = ", all_weights.max(), "\nMin weight =", all_weights.min())

In [None]:
# Zero, in place, every selected weight whose magnitude is below the global threshold.
with torch.no_grad():
    for name, weight in model.named_parameters():
        is_target = "weight" in name and "bn" not in name and "linear" not in name
        if not is_target:
            continue
        weight.data[weight.data.abs() < threshold] = 0


In [None]:
# for nm, params in model.named_parameters():
#     if "weight" in nm and "bn" not in nm and "linear" not in nm:
#         print(nm, "\n", params.data)
#         print(params.data.shape)
#         print("*"*20)

In [None]:
import plotly.graph_objects as go

# Flatten all selected weights (now including the zeros from pruning) into one
# CPU NumPy array; this works for CUDA models and is fast for plotly.
all_weights = np.concatenate([
    params.detach().cpu().numpy().ravel()
    for nm, params in model.named_parameters()
    if "weight" in nm and "bn" not in nm and "linear" not in nm
])

fig = go.Figure(data=[go.Histogram(x=all_weights, nbinsx=5000)])
fig.update_layout(
    title="Histogram of weight values for entire neural network (after pruning)",
    xaxis_title="Weight value",
    yaxis_title="Frequency",
)

fig.show()

In [None]:
import plotly.graph_objects as go

# Flatten all selected weights into one CPU NumPy array (works for CUDA models).
all_weights = np.concatenate([
    params.detach().cpu().numpy().ravel()
    for nm, params in model.named_parameters()
    if "weight" in nm and "bn" not in nm and "linear" not in nm
])
# Pruned weights are exactly 0; drop them with a vectorized mask so the
# distribution of the surviving weights is visible.
surviving_weights = all_weights[all_weights != 0]

fig = go.Figure(data=[go.Histogram(x=surviving_weights, nbinsx=5000)])
fig.update_layout(
    title="Histogram of non-zero weight values for entire neural network",
    xaxis_title="Weight value",
    yaxis_title="Frequency",
)

fig.show()

In [None]:
def nonzero(tensor):
    """Count the entries of a NumPy array that are not equal to 0.0."""
    return np.sum(tensor != 0.0)

def model_size(model, as_bits=False):
    """Count the total and non-zero parameters of a model.

    Args:
        model: a torch.nn.Module; every tensor from ``model.parameters()`` is
            counted (weights and biases alike).
        as_bits: if True, scale each count by the parameter's element width in
            bits, giving storage size instead of a parameter count.

    Returns:
        ``(total, nonzero)`` as Python ints.
    """
    total_params = 0
    nonzero_params = 0
    for tensor in model.parameters():
        t = np.prod(tensor.shape)
        nz = nonzero(tensor.detach().cpu().numpy())
        if as_bits:
            # Derive the bit width from the tensor itself (element_size() is bytes
            # per element); no external dtype lookup table is needed.
            bits = tensor.element_size() * 8
            t *= bits
            nz *= bits
        total_params += t
        nonzero_params += nz
    return int(total_params), int(nonzero_params)

res = evaluate(model, test_loader)
total_size, nz_size = model_size(model)
# Fraction of all parameters (weights and biases) that are exactly zero.
compression = (total_size - nz_size) / total_size
print("Compression = ", compression)
print("Accuracy of model ", res["val_acc"])


In [None]:
# for nm, params in model.named_parameters():
#     if "weight" in nm and "bn" not in nm and "linear" not in nm:
#         print(nm, "\n", params.data)
#         print(params.data.shape)
#         print("*"*20)

### Pruning with fine-tuning

In [None]:
def prune_model_get_mask(model,prune_rate):
    """Build binary keep-masks for magnitude pruning without modifying the model.

    A parameter is selected when its name contains "weight" and contains neither
    "bn" nor "linear". The threshold is ``prune_rate * std(|w|)``, where the std
    is taken over all selected weights together (one global threshold, not one
    per layer).

    Args:
        model: torch.nn.Module to inspect (on CPU or CUDA).
        prune_rate: multiplier applied to the std of the absolute weights.

    Returns:
        A list with one float32 CPU tensor per selected parameter, in
        ``named_parameters()`` order. Each mask has its parameter's shape, with
        1.0 where ``|w| > threshold`` (kept) and 0.0 elsewhere (pruned). The list
        is empty when no parameter is selected.
    """
    selected = [
        params.detach()
        for nm, params in model.named_parameters()
        if "weight" in nm and "bn" not in nm and "linear" not in nm
    ]
    if not selected:
        return []

    # Concatenate on the CPU; .numpy() would fail for CUDA tensors.
    all_weights = torch.cat([w.flatten().cpu() for w in selected])
    threshold = torch.std(torch.abs(all_weights)) * prune_rate

    # One vectorized comparison per layer instead of a Python loop over rows.
    return [(torch.abs(w.cpu()) > threshold).float() for w in selected]

# Start fine-tuning from the saved, unpruned weights.
if torch.cuda.is_available():
    state_dict = torch.load(model_state_path)
else:
    state_dict = torch.load(model_state_path, map_location=torch.device('cpu'))
model.load_state_dict(state_dict)

print("Model re-loaded")
mask_whole_model = prune_model_get_mask(model, prune_rate)
print("Created mask\n", mask_whole_model)

In [None]:
def apply_mask_model(model,list_mask_whole_model):
    """Multiply each selected weight, in place, by its pruning mask.

    A parameter is selected when its name contains "weight" and contains neither
    "bn" nor "linear". Masks are matched to the selected parameters by position,
    in ``named_parameters()`` order. Each mask is moved to its own parameter's
    device, so CPU models work even when CUDA is available.

    Args:
        model: torch.nn.Module whose selected weights are masked.
        list_mask_whole_model: list of mask tensors, one per selected parameter,
            each broadcastable to that parameter's shape.

    Raises:
        ValueError: if the number of masks differs from the number of selected
            parameters.
    """
    selected = [
        params for nm, params in model.named_parameters()
        if "weight" in nm and "bn" not in nm and "linear" not in nm
    ]
    if len(selected) != len(list_mask_whole_model):
        raise ValueError(
            "expected {} masks, got {}".format(len(selected), len(list_mask_whole_model)))
    with torch.no_grad():
        for params, mask_layer in zip(selected, list_mask_whole_model):
            params.mul_(mask_layer.to(params.device))
    
def fit_prune(epochs, lr, model, train_loader, val_loader,
              opt_func=torch.optim.SGD,
              mask_whole_model=mask_whole_model):
    """Fine-tune `model` while keeping pruned weights at zero.

    Runs the same loop as ``fit``, but after every optimizer step the masks in
    ``mask_whole_model`` are re-applied with ``apply_mask_model``, so weights
    removed by pruning stay zero.

    NOTE(review): the default for ``mask_whole_model`` is the global of that name
    as it was when this cell ran (Python evaluates defaults at ``def`` time).
    Pass it explicitly after regenerating masks.

    Returns:
        A list holding one validation-result dict per epoch.
    """
    print("At fine tuning (prune + train)")
    per_epoch = []
    optimizer = opt_func(model.parameters(), lr)
    for epoch_idx in range(epochs):
        for batch in train_loader:
            model.training_step(batch).backward()
            optimizer.step()
            optimizer.zero_grad()
            # Re-zero the pruned weights that this step may have moved.
            apply_mask_model(model, mask_whole_model)
        summary = evaluate(model, val_loader)
        model.epoch_end(epoch_idx, summary)
        per_epoch.append(summary)
    return per_epoch

epochs = 5
lr = 0.01
history = fit_prune(epochs, lr, model, train_loader, val_loader,
                    mask_whole_model=mask_whole_model)

res = evaluate(model, test_loader)
total_size, nz_size = model_size(model)
# Fraction of all parameters that are exactly zero after fine-tuning.
compression = (total_size - nz_size) / total_size
print("Compression = ", compression)
print("Accuracy of model ", res["val_acc"])
