In [1]:
from __future__ import print_function
#%matplotlib inline
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

# Set random seeds for reproducibility. Python's `random`, NumPy and torch are
# all imported above, so each of them is seeded explicitly; previously NumPy's
# global generator was left unseeded.
manualSeed = 999
#manualSeed = random.randint(1, 10000) # use if you want new results
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
np.random.seed(manualSeed)
torch.manual_seed(manualSeed)


import os
import yaml
import argparse
import numpy as np
from pathlib import Path
from models import *
from experiment import VAEXperiment
import torch.backends.cudnn as cudnn
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from dataset import VAEDataset
from pytorch_lightning.plugins import DDPPlugin

import matplotlib.pyplot as plt


# In[ ]:


# # Plot some training images
# real_batch = next(iter(train_loader))
# plt.figure(figsize=(8,8))
# plt.axis("off")
# plt.title("Training Images")
# plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))


# In[ ]:




# ---- Hyperparameters -------------------------------------------------------
# NOTE(review): apart from `ngpu` (used below to choose the torch device), none
# of these values are referenced by the cells visible in this notebook; the
# generator/discriminator wording suggests they were carried over from a GAN
# template. TODO confirm before relying on them.

# Image geometry: images are resized to image_size x image_size, nc channels
# (3 for colour images).
image_size = 64
nc = 3

# Model widths: latent vector length (generator input) and the feature-map
# sizes of the generator / discriminator.
nz = 100
ngf = 64
ndf = 64

# Optimisation: epoch count, Adam learning rate and Adam beta1.
num_epochs = 10
lr = 0.0002
beta1 = 0.5

# Number of GPUs to use; 0 forces CPU mode.
ngpu = 1


# In[ ]:


# Load the experiment config, the latest checkpoint, and build the data module
# whose test/train loaders are used for evaluation.
model_nm = "VanillaVAE"
args_filename = "configs/vae.yaml"
with open(args_filename, 'r') as file:
    try:
        config = yaml.safe_load(file)
    except yaml.YAMLError as exc:
        # Report and re-raise: swallowing the error left `config` undefined
        # (or silently stale from a previous run) for the lines below.
        print(exc)
        raise

device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")
chk_path = "logs/" + model_nm + "/version_2/checkpoints/last.ckpt"

# `device` is already a torch.device, so it can be passed straight through.
checkpoint = torch.load(chk_path, map_location=device)

data = VAEDataset(**config["data_params"])
data.setup()
tloader = data.test_dataloader()
train_loader = data.train_dataloader()


Random Seed:  999


In [None]:
import copy

def create_mask_from_mean_wt(model,mean_weight_description,prune_rate):
    """Build magnitude-pruning masks for the model's prunable weight tensors.

    A parameter is "prunable" when its name contains "weight" but neither
    "bn" nor "linear". For each one, the threshold is
    ``std(mean_weight_description[name]) * prune_rate`` (unbiased std of the
    flattened reference tensor). Entries of the model's *current* weights whose
    absolute value is below that threshold get mask 0; all others get 1.

    NOTE(review): with nn.Sequential-style names such as "encoder.0.1.weight",
    BatchNorm scales do not contain "bn" and are therefore masked as well.
    The filter must stay identical to the one in apply_mask_model, because the
    masks are paired with parameters purely by position.

    Args:
        model: torch.nn.Module whose named_parameters() are scanned.
        mean_weight_description: dict mapping parameter name -> reference
            tensor; only its standard deviation is used.
        prune_rate: multiplier turning the reference std into the threshold.

    Returns:
        List of CPU tensors of the default float dtype, one per prunable
        parameter in named_parameters() order, each shaped like the parameter.
    """
    masks = []
    for name, param in model.named_parameters():
        if "weight" not in name or "bn" in name or "linear" in name:
            continue

        reference = mean_weight_description[name]
        std = torch.std(reference.flatten())
        threshold = std * prune_rate

        # Vectorised replacement for the former per-element Python loop.
        # `mask[x < t] = 0` (rather than `x >= t`) keeps NaN weights at 1,
        # exactly like the loop did. Comparing against a Python scalar lets the
        # reference tensor live on a different device than the parameter.
        below = (param.detach().abs() < threshold.item()).cpu()
        mask = torch.ones(param.shape)
        mask[below] = 0

        print(name, param.shape, mask.shape, std, threshold)
        masks.append(mask)

    return masks
            
    
def get_weighted_mean(state_dicts,keyy,importance_vector):
    """Blend one entry across several state dicts.

    Computes ``sum_i importance_vector[i] * state_dicts[i][keyy]`` for
    i in range(len(importance_vector)). The weights are used as given (no
    normalisation), and an empty importance_vector yields 0.
    """
    return sum(
        importance_vector[idx] * state_dicts[idx][keyy]
        for idx in range(len(importance_vector))
    )



def apply_mask_model(model,list_mask_whole_model,layer_to_prune=None):
    """Multiply the model's prunable weights by the given masks.

    Parameters are matched to masks by position: the i-th parameter whose name
    contains "weight" but neither "bn" nor "linear" is multiplied by
    ``list_mask_whole_model[i]``. Each parameter's ``.data`` is replaced by the
    product (out of place, so tensors that previously shared storage with the
    parameter are left untouched).

    Args:
        model: torch.nn.Module to prune; modified and returned.
        list_mask_whole_model: list of mask tensors, one per prunable parameter
            in named_parameters() order.
        layer_to_prune: if not None, only masks 0..layer_to_prune (inclusive)
            are applied and the remaining prunable parameters are left as is.

    Returns:
        The same ``model`` object.
    """
    mask_idx = 0
    for name, param in model.named_parameters():
        if "weight" not in name or "bn" in name or "linear" in name:
            continue
        if layer_to_prune is not None and mask_idx > layer_to_prune:
            return model

        mask = list_mask_whole_model[mask_idx]
        with torch.no_grad():
            # Move the mask to wherever this parameter actually lives; the old
            # code guessed 'cuda'/'cpu', which broke for CPU models on a CUDA
            # machine and for parameters on a non-default GPU.
            param.data = param.data * mask.to(param.device)
        mask_idx += 1
    return model


def nonzero(tensor):
    """Count entries of a NumPy array that are not exactly 0.0 (NaN counts as nonzero)."""
    return np.sum(tensor != 0.0)


def model_size(model, as_bits=False):
    """Count total and nonzero parameter entries of a model.

    Args:
        model: torch.nn.Module whose ``parameters()`` are counted (buffers such
            as BatchNorm running statistics are not included).
        as_bits: if True, weight every entry by its storage size in bits
            (bool counted as 1 bit, other dtypes as ``element_size() * 8``).

    Returns:
        ``(total, nonzero)`` as Python ints, in entries or in bits.
    """
    total_params = 0
    nonzero_params = 0
    for tensor in model.parameters():
        t = tensor.numel()
        nz = nonzero(tensor.detach().cpu().numpy())
        if as_bits:
            # Replaces the undefined `dtype2bits` lookup, which made
            # as_bits=True raise NameError.
            bits = 1 if tensor.dtype == torch.bool else tensor.element_size() * 8
            t *= bits
            nz *= bits
        total_params += t
        nonzero_params += nz
    return int(total_params), int(nonzero_params)

In [89]:
# Rebuild the VAE from its config and load the trained weights from the
# latest checkpoint.
model_nm = "VanillaVAE"
args_filename = "configs/vae.yaml"
with open(args_filename, 'r') as file:
    try:
        config = yaml.safe_load(file)
    except yaml.YAMLError as exc:
        # Report and re-raise instead of continuing with an undefined/stale config.
        print(exc)
        raise

model = vae_models[config['model_params']['name']](**config['model_params'])

device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")
chk_path = "logs/" + model_nm + "/version_2/checkpoints/last.ckpt"

checkpoint = torch.load(chk_path, map_location=device)

# Every VAE entry in the checkpoint's state_dict carries a "model." prefix.
# Loading through load_state_dict also restores the BatchNorm buffers
# (running_mean / running_var / num_batches_tracked), which the former
# per-parameter copy silently skipped. strict=True (the default) still fails
# loudly on missing or unexpected keys.
prefix = "model."
vae_state = {
    key[len(prefix):]: value
    for key, value in checkpoint["state_dict"].items()
    if key.startswith(prefix)
}
model.load_state_dict(vae_state)
# Keep the previous placement: parameters end up on `device`.
model = model.to(device)

In [90]:
# Spot-check: display one raw tensor stored in the loaded checkpoint.
checkpoint["state_dict"]['model.encoder.0.0.weight']

tensor([[[[-5.7147e-01, -5.4276e-02,  4.2360e-01],
          [-5.5104e-01,  1.1627e-01,  4.9879e-01],
          [-7.2046e-01,  1.8088e-01,  2.3496e-01]],

         [[-4.8046e-01, -6.0694e-02,  4.9852e-01],
          [-7.2104e-01,  2.4263e-01,  3.9929e-01],
          [-2.5442e-01,  2.8318e-01,  7.1152e-01]],

         [[-3.8183e-01, -1.0914e-01,  2.6411e-01],
          [-4.9824e-01, -1.4876e-01,  5.2428e-01],
          [-3.0802e-01, -1.7951e-01,  3.9325e-01]]],


        [[[ 3.4050e-01,  7.1979e-02,  3.8273e-01],
          [ 7.8096e-02, -1.8944e-01, -1.4746e-01],
          [-1.1103e-01, -6.0413e-01, -3.3009e-01]],

         [[ 4.6322e-01,  1.7154e-01,  4.7905e-01],
          [ 3.0857e-01,  1.9879e-01,  1.2830e-01],
          [ 1.3608e-01, -3.9847e-01, -2.5415e-01]],

         [[ 3.0442e-01, -1.6816e-01,  1.8796e-02],
          [-1.9579e-01, -6.0260e-01, -2.4356e-01],
          [-6.1299e-02, -1.1389e+00, -6.5029e-01]]],


        [[[ 7.3294e-01,  8.4180e-01,  1.3514e+00],
          [ 2.3

In [91]:
# Load the state_dict of each checkpoint that will be blended; position i in
# `state_dicts` is paired with importance_vector[i] when the weights are mixed.
state_dicts = []
epoch_names = ["last.ckpt", "epoch=2-step=7631.ckpt"]

for epoch_name in epoch_names:
    chk_path = f"logs/{model_nm}/version_2/checkpoints/{epoch_name}"
    print(chk_path)
    checkpoint = torch.load(chk_path, map_location=device)
    state_dict = checkpoint["state_dict"]
    state_dicts.append(state_dict)

logs/VanillaVAE/version_2/checkpoints/last.ckpt
logs/VanillaVAE/version_2/checkpoints/epoch=2-step=7631.ckpt


In [92]:
# List the parameter and buffer names stored in the first loaded checkpoint;
# the VAE's entries carry a "model." prefix.
state_dicts[0].keys()

odict_keys(['model.encoder.0.0.weight', 'model.encoder.0.0.bias', 'model.encoder.0.1.weight', 'model.encoder.0.1.bias', 'model.encoder.0.1.running_mean', 'model.encoder.0.1.running_var', 'model.encoder.0.1.num_batches_tracked', 'model.encoder.1.0.weight', 'model.encoder.1.0.bias', 'model.encoder.1.1.weight', 'model.encoder.1.1.bias', 'model.encoder.1.1.running_mean', 'model.encoder.1.1.running_var', 'model.encoder.1.1.num_batches_tracked', 'model.encoder.2.0.weight', 'model.encoder.2.0.bias', 'model.encoder.2.1.weight', 'model.encoder.2.1.bias', 'model.encoder.2.1.running_mean', 'model.encoder.2.1.running_var', 'model.encoder.2.1.num_batches_tracked', 'model.encoder.3.0.weight', 'model.encoder.3.0.bias', 'model.encoder.3.1.weight', 'model.encoder.3.1.bias', 'model.encoder.3.1.running_mean', 'model.encoder.3.1.running_var', 'model.encoder.3.1.num_batches_tracked', 'model.encoder.4.0.weight', 'model.encoder.4.0.bias', 'model.encoder.4.1.weight', 'model.encoder.4.1.bias', 'model.encoder.4

In [94]:
# Blend each prunable weight across the loaded checkpoints:
# evol_wts[name] = sum_i importance_vector[i] * state_dicts[i]["model." + name].
importance_vector = [0.8, 0.2]
evol_wts = {}
for nm, params in model.named_parameters():
    # Same "prunable" filter as the masking helpers.
    if "weight" not in nm or "bn" in nm or "linear" in nm:
        continue
    print(nm, params.shape)
    keyy = "model." + nm
    new_param_values = get_weighted_mean(state_dicts, keyy, importance_vector)
    evol_wts[nm] = new_param_values

encoder.0.0.weight torch.Size([32, 3, 3, 3])
encoder.0.1.weight torch.Size([32])
encoder.1.0.weight torch.Size([64, 32, 3, 3])
encoder.1.1.weight torch.Size([64])
encoder.2.0.weight torch.Size([128, 64, 3, 3])
encoder.2.1.weight torch.Size([128])
encoder.3.0.weight torch.Size([256, 128, 3, 3])
encoder.3.1.weight torch.Size([256])
encoder.4.0.weight torch.Size([512, 256, 3, 3])
encoder.4.1.weight torch.Size([512])
fc_mu.weight torch.Size([128, 2048])
fc_var.weight torch.Size([128, 2048])
decoder_input.weight torch.Size([2048, 128])
decoder.0.0.weight torch.Size([512, 256, 3, 3])
decoder.0.1.weight torch.Size([256])
decoder.1.0.weight torch.Size([256, 128, 3, 3])
decoder.1.1.weight torch.Size([128])
decoder.2.0.weight torch.Size([128, 64, 3, 3])
decoder.2.1.weight torch.Size([64])
decoder.3.0.weight torch.Size([64, 32, 3, 3])
decoder.3.1.weight torch.Size([32])
final_layer.0.weight torch.Size([32, 32, 3, 3])
final_layer.1.weight torch.Size([32])
final_layer.3.weight torch.Size([3, 32, 3,

In [95]:
# Names of the weights that received a blended value (the ones that get masks).
evol_wts.keys()

dict_keys(['encoder.0.0.weight', 'encoder.0.1.weight', 'encoder.1.0.weight', 'encoder.1.1.weight', 'encoder.2.0.weight', 'encoder.2.1.weight', 'encoder.3.0.weight', 'encoder.3.1.weight', 'encoder.4.0.weight', 'encoder.4.1.weight', 'fc_mu.weight', 'fc_var.weight', 'decoder_input.weight', 'decoder.0.0.weight', 'decoder.0.1.weight', 'decoder.1.0.weight', 'decoder.1.1.weight', 'decoder.2.0.weight', 'decoder.2.1.weight', 'decoder.3.0.weight', 'decoder.3.1.weight', 'final_layer.0.weight', 'final_layer.1.weight', 'final_layer.3.weight'])

In [97]:
# Zero every model weight whose magnitude is below prune_rate x std of the
# corresponding blended tensor in evol_wts.
prune_rate = 0.05
list_mask_val = create_mask_from_mean_wt(
    model,
    mean_weight_description=evol_wts,
    prune_rate=prune_rate,
)

encoder.0.0.weight torch.Size([32, 3, 3, 3]) torch.Size([32, 3, 3, 3]) tensor(0.3757) tensor(0.0188)
encoder.0.1.weight torch.Size([32]) torch.Size([32]) tensor(0.2073) tensor(0.0104)
encoder.1.0.weight torch.Size([64, 32, 3, 3]) torch.Size([64, 32, 3, 3]) tensor(0.2704) tensor(0.0135)
encoder.1.1.weight torch.Size([64]) torch.Size([64]) tensor(0.2865) tensor(0.0143)
encoder.2.0.weight torch.Size([128, 64, 3, 3]) torch.Size([128, 64, 3, 3]) tensor(0.2718) tensor(0.0136)
encoder.2.1.weight torch.Size([128]) torch.Size([128]) tensor(0.3943) tensor(0.0197)
encoder.3.0.weight torch.Size([256, 128, 3, 3]) torch.Size([256, 128, 3, 3]) tensor(0.2684) tensor(0.0134)
encoder.3.1.weight torch.Size([256]) torch.Size([256]) tensor(0.3330) tensor(0.0166)
encoder.4.0.weight torch.Size([512, 256, 3, 3]) torch.Size([512, 256, 3, 3]) tensor(0.2204) tensor(0.0110)
encoder.4.1.weight torch.Size([512]) torch.Size([512]) tensor(0.0931) tensor(0.0047)
fc_mu.weight torch.Size([128, 2048]) torch.Size([128, 20

In [98]:
# apply_mask_model rewrites the weights of `model` and returns that same
# object, so `pruned_model` and `model` refer to one module.
pruned_model = apply_mask_model(model, list_mask_val)

# "compression" here is the fraction of parameter entries that are exactly
# zero (over all parameters, including ones that were never masked).
total_size, nz_size = model_size(pruned_model)
compression = (total_size - nz_size) / total_size
print("compression is ", compression)

compression is  0.05535987972475864


compression is  0.4593346005914718
