In [31]:
import os
import torch
import os.path as osp
import torchvision
from scheme.EB import EarlyBird
from typing import *
import tqdm
import collections
import numpy as np

In [32]:
DIRS = ['../Early-Bird-Tickets/EB/', '../FreeTickets/results/']
DATASETS = ['cifar10']#, 'cifar100']
MODELS = ['vgg16']
SPARSITY = 0.8

In [33]:
structured_dirs = osp.join(DIRS[0], DATASETS[0], MODELS[0])
unstructured_dirs = osp.join(DIRS[1], DATASETS[0], MODELS[0])

In [34]:
struct_files = [file for file in os.listdir(structured_dirs) if 'ckpt' in file]
unstruct_files = [file for file in os.listdir(unstructured_dirs) if 'ckpt' in file]

In [35]:
struct_files.sort(key=lambda x: int(x.split('.')[0][4::]))
unstruct_files.sort(key=lambda x: int(x.split('_')[0][4::]))

In [36]:
paired_files = zip(struct_files, unstruct_files)

In [37]:
model = eval(f"torchvision.models.{MODELS[0]}")(num_classes=10).cuda()

In [38]:
from typing import *
from scheme.EB import EarlyBird
import torch
import torch.nn as nn
def generate_macro_mask(model_fn:Callable, s_f:str, us_f:str, n_class:int, sparsity:float=0.8, x_perc:float=0.5) -> torch.Tensor:
    """
    desc:
        generate a macro or structured mask (channel masking) 
        from struture and unstructure pruning.
    param:
        model_fn: model function
        s_f: structure file path
        us_f: unstructure file path
        n_class: number of class
        sparsity: degree of sparsity from 0.0 to 1.0
        x_perc: cross over percentage. If kernel is x_perc 
            pruned, it is considered off and vice versa.    return 
        mask: masking vector 
    """
    s_m = model_fn(num_classes=n_class).cuda()
    # print(s_f)
    # print(us_m)
    # print(torch.load(s_f, map_location='cpu')['state_dict'].keys())
    s_m.load_state_dict(torch.load(s_f, map_location='cpu')['state_dict'])
    us_m = model_fn(num_classes=n_class).cuda()
    us_m.load_state_dict(torch.load(us_f, map_location='cpu')['state_dict'])
    
    s_mask = EarlyBird(sparsity).pruning(s_m, sparsity) 
    us_mask = micro_to_macro(us_m)
    
    intersection = s_mask * (s_mask == us_mask).float()    
    return s_mask, us_mask, intersection

def micro_to_macro(model:torch.nn.Module, x_perc:float=0.55) -> torch.Tensor:
    """
    desc:
        convert unstructured pruning to structure pruning
    param:
        model: nn.Module 
        x_perc: cross over percentage. If kernel is x_perc 
        pruned, it is considered off and vice versa.
    return:
        mask: masking vector
    """
    total = 0
    sparse = 0
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            total += m.weight.data.shape[0]
            
    mask = torch.zeros(total)
    offset = 0
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            # m.weight (out, in, kH, kW)
            # binarize and sum total zero. 
            # new m.weight (out, in, value), and if value/(kH+kW) > x_perc, 
            # filter is on. if sum of all filters is > x_perf, channel is on. 
            # increase offset
            out_c, in_c, kH, kW = m.weight.shape
            boolean_weight = (m.weight != 0.0).view(out_c, in_c, -1).float().mean(dim=-1).gt(x_perc).float().sum(dim=-1).div(in_c).gt(x_perc).float()
            mask[offset:offset+out_c] = boolean_weight
            offset += out_c
        
    return mask 

def valid_masked_model(model:nn.Module, mask:torch.Tensor) -> bool:
    """
    desc:
        determine if the masking vector consitutes a valid model.
        A valid model is one such that at least one input channel 
        makes it to the final output channel.
    params:
        model: resnet/vgg model construct
        mask: a masking vector
    return:
        bool: Determine whether the pruned model from mask is a 
        valid model
    """
    offset = 0
    total = 0
    
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            total += m.weight.data.shape[0]
    assert mask.size(0) == total
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            out_c = m.weight.shape[0]
            if mask[offset:offset+out_c].sum() > 0:
                offset+=out_c
            else:
                print(offset, out_c)
                return False
    return True

def apply_marco_mask(model:nn.Module, mask:torch.Tensor) -> nn.Module:
    """
    desc: 
        applying model's mask on batchnorm layer.
    """
    total = 0
    offset = 0
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            total += m.weight.data.shape[0]
    assert mask.shape[0] == total
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            out_c = m.weight.data.size(0)
            m.weight.data.mul_(mask[offset:offset+out_c])
            m.bias.data.mul_(mask[offset:offset+out_c])
            offset += out_c
    return model
    
    
    

In [39]:
eb = EarlyBird(SPARSITY)
model_fn = eval(f"torchvision.models.{MODELS[0]}")
output_dirs = '../intersected_macro_masks/cifar10/vgg16'
random_init = '../common_models/random_weights/vgg16_10.pth.tar'
os.makedirs(output_dirs, exist_ok=True)
paired_files = zip(struct_files, unstruct_files)
sizes = collections.defaultdict(list)
for i, (struct, unstruct) in tqdm.tqdm(enumerate(paired_files), total=len(struct_files)):
    s_f = osp.join(structured_dirs, struct)
    us_f = osp.join(unstructured_dirs, unstruct)
    s_mask, us_mask, macro_mask = generate_macro_mask(model_fn, s_f, us_f, 10, 0.8, 0.55)
    model = model_fn(num_classes=10)
    model.load_state_dict(torch.load(random_init))
    model = apply_marco_mask(model, macro_mask)
    torch.save(model.state_dict(), osp.join(output_dirs, f'iou_{i}.pth.tar'))
    model.load_state_dict(torch.load(random_init))
    model = apply_marco_mask(model, s_mask)
    torch.save(model.state_dict(), osp.join(output_dirs, f'struct_{i}.pth.tar'))
    model.load_state_dict(torch.load(random_init))
    model = apply_marco_mask(model, us_mask)
    torch.save(model.state_dict(), osp.join(output_dirs, f'unstruct_{i}.pth.tar'))
    sizes['s_mask'].append(s_mask.sum()/s_mask.size(0))
    sizes['us_mask'].append(us_mask.sum()/us_mask.size(0))
    sizes['macro_mask'].append(macro_mask.sum()/macro_mask.size(0))

    
#     print(f"s_mask: {s_mask.sum()/s_mask.size(0)}")
#     print(f"us_mask: {us_mask.sum()/us_mask.size(0)}")
#     print(f"marco_mask: {macro_mask.sum()/macro_mask.size(0)}")
    

    
    

  0%|          | 0/100 [00:01<?, ?it/s]


RuntimeError: Error(s) in loading state_dict for VGG:
	Missing key(s) in state_dict: "features.2.weight", "features.2.bias", "features.5.weight", "features.5.bias", "features.12.weight", "features.12.bias", "features.19.weight", "features.19.bias", "features.26.weight", "features.26.bias". 
	Unexpected key(s) in state_dict: "features.31.weight", "features.31.bias", "features.31.running_mean", "features.31.running_var", "features.31.num_batches_tracked", "features.34.weight", "features.34.bias", "features.35.weight", "features.35.bias", "features.35.running_mean", "features.35.running_var", "features.35.num_batches_tracked", "features.37.weight", "features.37.bias", "features.38.weight", "features.38.bias", "features.38.running_mean", "features.38.running_var", "features.38.num_batches_tracked", "features.40.weight", "features.40.bias", "features.41.weight", "features.41.bias", "features.41.running_mean", "features.41.running_var", "features.41.num_batches_tracked", "features.1.weight", "features.1.bias", "features.1.running_mean", "features.1.running_var", "features.1.num_batches_tracked", "features.3.weight", "features.3.bias", "features.4.weight", "features.4.bias", "features.4.running_mean", "features.4.running_var", "features.4.num_batches_tracked", "features.8.weight", "features.8.bias", "features.8.running_mean", "features.8.running_var", "features.8.num_batches_tracked", "features.11.weight", "features.11.bias", "features.11.running_mean", "features.11.running_var", "features.11.num_batches_tracked", "features.15.weight", "features.15.bias", "features.15.running_mean", "features.15.running_var", "features.15.num_batches_tracked", "features.18.weight", "features.18.bias", "features.18.running_mean", "features.18.running_var", "features.18.num_batches_tracked", "features.20.weight", "features.20.bias", "features.21.running_mean", "features.21.running_var", "features.21.num_batches_tracked", "features.25.weight", "features.25.bias", "features.25.running_mean", "features.25.running_var", "features.25.num_batches_tracked", "features.27.weight", "features.27.bias", "features.28.running_mean", "features.28.running_var", "features.28.num_batches_tracked", "features.30.weight", "features.30.bias". 
	size mismatch for features.7.weight: copying a param with shape torch.Size([128, 64, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).
	size mismatch for features.10.weight: copying a param with shape torch.Size([128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 128, 3, 3]).
	size mismatch for features.10.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for features.14.weight: copying a param with shape torch.Size([256, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
	size mismatch for features.17.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([512, 256, 3, 3]).
	size mismatch for features.17.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for features.21.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512, 512, 3, 3]).
	size mismatch for features.21.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for features.24.weight: copying a param with shape torch.Size([512, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([512, 512, 3, 3]).
	size mismatch for features.28.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([512, 512, 3, 3]).

In [None]:
compatibility = np.zeros((len(sizes['s_mask']), 3))
for i in range(len(sizes['s_mask'])):
    compatibility[i][0] = sizes['s_mask'][i]
    compatibility[i][1] = sizes['us_mask'][i]
    compatibility[i][2] = sizes['macro_mask'][i]

In [None]:
path = 'vgg16_sparsity_ratio.txt'
np.savetxt(path, compatibility, delimiter=',')