In [11]:
import os
import torch
import os.path as osp
import torchvision
from scheme.EB import EarlyBird
from typing import *
import tqdm
import collections
import numpy as np

In [2]:
# Root directories of the two checkpoint sources. Index 0 is used below for the
# structured-pruning checkpoints, index 1 for the unstructured-pruning ones.
DIRS = ['../Early-Bird-Tickets/EB/', '../FreeTickets/results/']
# Dataset and model names; the cells below only read element [0] of each list
# to build the checkpoint paths.
DATASETS = ['cifar10', 'cifar100']
MODELS = ['resnet34']
# Pruning ratio handed to EarlyBird.
SPARSITY = 0.8

In [3]:
# Checkpoint directories for the first dataset/model, laid out as
# <root>/<dataset>/<model>.
structured_dirs = osp.join(DIRS[0], DATASETS[0], MODELS[0])
unstructured_dirs = osp.join(DIRS[1], DATASETS[0], MODELS[0])

In [4]:
# Keep only the directory entries whose name contains 'ckpt'. os.listdir
# returns bare names in arbitrary order, so they are sorted in the next cell.
struct_files = [file for file in os.listdir(structured_dirs) if 'ckpt' in file]
unstruct_files = [file for file in os.listdir(unstructured_dirs) if 'ckpt' in file]

In [5]:
# Sort both lists in place by the integer embedded in each filename: the text
# before the first '.' (structured) or first '_' (unstructured), with its
# first four characters dropped, parsed as an int.
# NOTE(review): presumably the names look like 'ckpt<N>.*' and 'ckpt<N>_*' —
# confirm against the checkpoint directories; any other name raises ValueError.
struct_files.sort(key=lambda x: int(x.split('.')[0][4::]))
unstruct_files.sort(key=lambda x: int(x.split('_')[0][4::]))

In [6]:
# Pair the i-th structured checkpoint with the i-th unstructured one.
# zip returns a one-shot iterator that stops at the shorter list; the same
# name is re-assigned to a fresh zip right before the main loop.
paired_files = zip(struct_files, unstruct_files)

In [7]:
# Look the constructor up by attribute instead of eval()-ing a built string:
# same result for a plain model name, without executing arbitrary text.
# The model is created with 10 output classes and moved to the GPU.
model = getattr(torchvision.models, MODELS[0])(num_classes=10).cuda()

In [8]:
from typing import *
from scheme.EB import EarlyBird
import torch
import torch.nn as nn
def generate_macro_mask(model_fn:Callable, s_f:str, us_f:str, n_class:int, sparsity:float=0.8, x_perc:float=0.5) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    desc:
        generate macro (structured, channel-level) masks from a
        structured-pruning checkpoint and an unstructured-pruning
        checkpoint, together with the mask of the channels on which
        both agree to keep.
    param:
        model_fn: model constructor, called as model_fn(num_classes=n_class)
        s_f: structured checkpoint file path
        us_f: unstructured checkpoint file path
        n_class: number of classes
        sparsity: degree of sparsity from 0.0 to 1.0, handed to EarlyBird
        x_perc: cross over percentage forwarded to micro_to_macro.
            Earlier versions accepted this argument but ignored it, so
            micro_to_macro always ran with its own default.
    return:
        (s_mask, us_mask, intersection):
            s_mask: channel mask returned by EarlyBird.pruning
            us_mask: channel mask derived from the unstructured checkpoint
            intersection: s_mask where both masks agree, 0 elsewhere
                (for 0/1 masks: the channels kept by both)
    note:
        both checkpoints must be dicts holding the weights under the
        'state_dict' key; both models are moved to the GPU with .cuda().
    """
    # Structured-pruning model: weights are read on CPU, then copied into
    # the GPU-resident module by load_state_dict.
    s_m = model_fn(num_classes=n_class).cuda()
    s_m.load_state_dict(torch.load(s_f, map_location='cpu')['state_dict'])

    # Unstructured-pruning model, loaded the same way.
    us_m = model_fn(num_classes=n_class).cuda()
    us_m.load_state_dict(torch.load(us_f, map_location='cpu')['state_dict'])

    s_mask = EarlyBird(sparsity).pruning(s_m, sparsity)
    # Forward the caller's threshold instead of silently using the default.
    us_mask = micro_to_macro(us_m, x_perc)

    # Zero out every position where the two masks disagree.
    intersection = s_mask * (s_mask == us_mask).float()
    return s_mask, us_mask, intersection

def micro_to_macro(model:torch.nn.Module, x_perc:float=0.55) -> torch.Tensor:
    """
    desc:
        convert unstructured pruning to structure pruning
    param:
        model: nn.Module
        x_perc: cross over percentage, used twice. A (out, in) kernel is
            "on" when its fraction of non-zero weights is strictly greater
            than x_perc; an output channel is "on" when its fraction of
            "on" kernels is strictly greater than x_perc.
    return:
        mask: 1-D float masking vector of 0/1 entries with one entry per
            BatchNorm2d channel, filled from the Conv2d layers in
            model.modules() order.
    raises:
        ValueError: if the Conv2d output channels and the BatchNorm2d
            channels do not add up to the same total, because the mask
            could then not line up with the BatchNorm layers.
    """
    bn_total = sum(
        m.weight.data.shape[0]
        for m in model.modules()
        if isinstance(m, nn.BatchNorm2d)
    )
    convs = [m for m in model.modules() if isinstance(m, nn.Conv2d)]
    conv_total = sum(m.weight.shape[0] for m in convs)
    # The mask is sized from BatchNorm2d but filled from Conv2d; refuse to
    # return a truncated or misaligned vector when the two disagree.
    if conv_total != bn_total:
        raise ValueError(
            f"Conv2d output channels ({conv_total}) do not match "
            f"BatchNorm2d channels ({bn_total}); cannot build an aligned mask"
        )

    mask = torch.zeros(bn_total)
    offset = 0
    for m in convs:
        # m.weight is (out, in, kH, kW)
        out_c, in_c = m.weight.shape[:2]
        # 1.0 where a weight survived pruning, flattened to (out, in, kH*kW)
        nonzero = (m.weight != 0.0).view(out_c, in_c, -1).float()
        # (out, in): kernel is on if its non-zero fraction exceeds x_perc
        kernel_on = nonzero.mean(dim=-1).gt(x_perc).float()
        # (out,): channel is on if its fraction of on-kernels exceeds x_perc
        channel_on = kernel_on.sum(dim=-1).div(in_c).gt(x_perc).float()
        mask[offset:offset+out_c] = channel_on
        offset += out_c

    return mask

def valid_masked_model(model:nn.Module, mask:torch.Tensor) -> bool:
    """
    desc:
        check that the masking vector leaves every BatchNorm2d layer
        of the model with at least one surviving channel.
    params:
        model: model whose BatchNorm2d layers define the mask layout
        mask: a masking vector, one entry per BatchNorm2d channel in
            model.modules() order
    return:
        bool: True if every layer's slice of the mask sums to more
        than zero. On the first layer that fails, its offset and width
        are printed and False is returned.
    """
    bn_layers = [layer for layer in model.modules() if isinstance(layer, nn.BatchNorm2d)]
    expected_len = sum(layer.weight.data.shape[0] for layer in bn_layers)
    assert mask.size(0) == expected_len

    cursor = 0
    for layer in bn_layers:
        width = layer.weight.shape[0]
        layer_is_alive = mask[cursor:cursor+width].sum() > 0
        if not layer_is_alive:
            print(cursor, width)
            return False
        cursor += width
    return True

def apply_marco_mask(model:nn.Module, mask:torch.Tensor) -> nn.Module:
    """
    desc:
        applying model's mask on batchnorm layer: the weight and bias of
        every BatchNorm2d layer are multiplied in place by that layer's
        slice of the mask.
    params:
        model: model to mask; modified in place
        mask: 1-D masking vector with one entry per BatchNorm2d channel,
            laid out in model.modules() order. It may live on a different
            device than the model; each slice is moved to the parameter's
            device before use.
    return:
        model: the same model object that was passed in
    raises:
        AssertionError: if the mask length differs from the total number
            of BatchNorm2d channels.
    """
    bn_layers = [m for m in model.modules() if isinstance(m, nn.BatchNorm2d)]
    total = sum(m.weight.data.shape[0] for m in bn_layers)
    assert mask.shape[0] == total, (
        f"mask has {mask.shape[0]} entries but the model has {total} BatchNorm2d channels"
    )

    offset = 0
    for m in bn_layers:
        out_c = m.weight.data.size(0)
        # In-place mul_ rejects a 1-D operand on another device, so align it.
        layer_mask = mask[offset:offset+out_c].to(m.weight.device)
        m.weight.data.mul_(layer_mask)
        m.bias.data.mul_(layer_mask)
        offset += out_c
    return model
    
    
    

In [9]:
# For every (structured, unstructured) checkpoint pair: build the three channel
# masks, apply each one to the same random initialisation, save the masked
# weights, and record the fraction of channels each mask keeps.
eb = EarlyBird(SPARSITY)  # not used in this cell; kept so the name stays defined
# Attribute lookup instead of eval() on a built string.
model_fn = getattr(torchvision.models, MODELS[0])
output_dirs = '../intersected_macro_masks/cifar10/resnet34'
random_init = '../common_models/random_weights/resnet34_10.pth.tar'
os.makedirs(output_dirs, exist_ok=True)

# Read the random initialisation once. load_state_dict copies values into the
# model's own tensors, so the in-place masking below never alters this dict.
random_state = torch.load(random_init)

paired_files = zip(struct_files, unstruct_files)
# zip stops at the shorter list, so report that length to the progress bar.
n_pairs = min(len(struct_files), len(unstruct_files))
sizes = collections.defaultdict(list)
# One model instance is enough: every use starts by reloading random_state.
model = model_fn(num_classes=10)
for i, (struct, unstruct) in tqdm.tqdm(enumerate(paired_files), total=n_pairs):
    s_f = osp.join(structured_dirs, struct)
    us_f = osp.join(unstructured_dirs, unstruct)
    s_mask, us_mask, macro_mask = generate_macro_mask(model_fn, s_f, us_f, 10, SPARSITY, 0.55)

    # Same load -> mask -> save sequence for each mask, under its own prefix.
    for prefix, mask in (('iou', macro_mask), ('struct', s_mask), ('unstruct', us_mask)):
        model.load_state_dict(random_state)
        model = apply_marco_mask(model, mask)
        torch.save(model.state_dict(), osp.join(output_dirs, f'{prefix}_{i}.pth.tar'))

    # Fraction of channels kept, stored as plain floats rather than 0-dim tensors.
    for key, mask in (('s_mask', s_mask), ('us_mask', us_mask), ('macro_mask', macro_mask)):
        sizes[key].append((mask.sum() / mask.size(0)).item())

    
#     print(f"s_mask: {s_mask.sum()/s_mask.size(0)}")
#     print(f"us_mask: {us_mask.sum()/us_mask.size(0)}")
#     print(f"marco_mask: {macro_mask.sum()/macro_mask.size(0)}")
    

    
    

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [03:24<00:00,  2.04s/it]


In [12]:
# Collect the recorded kept-fractions into an (n_pairs, 3) array:
# column 0 = structured mask, column 1 = unstructured mask, column 2 = intersection.
n_runs = len(sizes['s_mask'])
compatibility = np.zeros((n_runs, 3))
for col, key in enumerate(('s_mask', 'us_mask', 'macro_mask')):
    for row in range(n_runs):
        compatibility[row, col] = sizes[key][row]

In [15]:
# Write the array as comma-separated text, one row per checkpoint pair with the
# columns in the order s_mask, us_mask, macro_mask. The path is relative, so
# the file is created in the current working directory.
path = 'resnet34_sparsity_ratio.txt'
np.savetxt(path, compatibility, delimiter=',')