In [1]:
from transformers import BertModel, BertConfig
from copy import deepcopy
import torch
from typing import List, Dict
from transformers import BertTokenizer
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
import math
import os
import torch
import torch.nn.functional as F

In [37]:
class ClassificationDataBase:
    """ Data base class API: loads, preprocesses, and tokenizes datasets to 
        be ready for the BERT model.
        Required attributes:
         - num_classes: number of classes in a dataset
         - datasets: dictionary with 'train' and 'val' tokenized datasets;
    """
    def __init__(self, **kwargs):
        self.num_classes: int
        self.datasets: Dict[str, Dataset] = {}  
        
        
class IMDb(ClassificationDataBase):
    def __init__(self, device, backbone_name, **kwargs):
        super().__init__()
        loading_kwargs = {'path': 'imdb', 
                          'train_filename': 'train.tsv',
                          'dev_filename': 'dev.tsv',
                          'header': 0,
                          'index_col': 0}
        self.device = device
        self._prepare_datasets(loading_kwargs, device, backbone_name)
    
    def _prepare_datasets(self, loading_kwargs, device, backbone_name):
        dataframes = load_dataframes(loading_kwargs)
        self.num_classes = len(dataframes['train']['label'].unique())
        data_X, data_y = {}, {}
        for split in dataframes.keys():
            data_X[split] = dataframes[split].drop('label', axis=1, inplace=False)
            data_X[split] = data_X[split].values.tolist()
            data_y[split] = dataframes[split]['label'].values
            data_y[split] = torch.LongTensor(data_y[split]).to(device)
            data_X[split] = get_tokens(data_X[split], backbone_name=backbone_name)
            data_X[split] = {k: v.to(device) for k,v in data_X[split].items()}
            self.datasets[split] = TokenizedDataset(data_X[split], data_y[split])
        

class TokenizedDataset(Dataset):
    def __init__(self, X, y):
        super().__init__()
        self.X = X
        self.y = y
        
    def __len__(self):
        return len(self.y)
            
    def __getitem__(self,idx):
        return {key: self.X[key][idx] for key in self.X.keys()}, self.y[idx]


def get_all_subclasses(cls):
    all_subclasses = []
    for subclass in cls.__subclasses__():
        all_subclasses.append(subclass)
        all_subclasses.extend(get_all_subclasses(subclass))
    return all_subclasses


def set_seeds(seed):
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)


def load_dataframes(loading_kwargs):
    path = loading_kwargs['path']
    train_filename = loading_kwargs['train_filename']
    dev_filename = loading_kwargs['dev_filename']
    header = loading_kwargs['header']
    index_col = loading_kwargs['index_col']
    assert '.' in train_filename, f"unrecognized file format {train_filename}"
    extension = train_filename.split('.')[-1]
    if extension == 'tsv':
        delimiter = '\t'
    elif extension == 'csv':
        delimiter = ','
    else:
        raise ValueError(f"unrecognized file format {extension}")
    dataframes = {}
    for split,split_filename in zip(['train','dev'], [train_filename, dev_filename]):
        filename = os.path.join(path, split_filename)
        dataframes[split] = pd.read_csv(
                filename,
                delimiter=delimiter,
                header=header,
                index_col=index_col,
                engine="python",
                error_bad_lines=False,
                warn_bad_lines=False)
        new_columns = [f"sentence_{i}" for i in range(len(dataframes[split].columns)-1)]+["label"]
        dataframes[split].columns = new_columns
    return dataframes


def get_tokens(data_X: List[List[str]], backbone_name: str) -> dict:
    tokenizer =  BertTokenizer.from_pretrained(backbone_name)
    if len(data_X[0])==1:
        data_X = [X[0] for X in data_X]
    data_X = tokenizer.batch_encode_plus(
                data_X,
                add_special_tokens=True,
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors="pt")
    return data_X
    

def stabilized_forward(linearized_encoder,
            min_=1e-1,
            max_=1e5,
            scaler=1e3,
            max_attempts=10):
    def forward(x):
        for i in range(len(linearized_encoder.layer)):
            attempts = 0
            prelim_x = linearized_encoder.layer[i](x)[0]
            while torch.mean(torch.abs(prelim_x))>max_ and attempts<max_attempts:
                for m in linearized_encoder.layer[i].modules():
                    if isinstance(m, torch.nn.Linear):
                        m.weight.data = m.weight.data/scaler
                prelim_x = linearized_encoder.layer[i](x)[0]
                attempts+=1
            while torch.mean(torch.abs(prelim_x))<min_ and attempts<max_attempts:
                for m in linearized_encoder.layer[i].modules():
                    if isinstance(m, torch.nn.Linear):
                        m.weight.data = scaler*m.weight.data
                prelim_x = linearized_encoder.layer[i](x)[0]
                attempts+=1
            assert attempts<max_attempts
            x = prelim_x
        return x
    return forward


def max_pooling(token_embeddings, attention_mask):
    input_mask_expanded=(attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float())
    token_embeddings[input_mask_expanded == 0] = -1e9
    return torch.max(token_embeddings, -2).values


def mean_pooling(token_embeddings, attention_mask):
    input_mask_expanded=(attention_mask.unsqueeze(-1).float())
    sum_embeddings=torch.sum(token_embeddings * input_mask_expanded, -2)
    sum_mask=torch.clamp(input_mask_expanded.sum(-2), min=1e-9)
    return sum_embeddings / sum_mask


def pool(x, encoded_input, pool_type):
    if pool_type in ['avg', 'mean']:
        x = mean_pooling(x['last_hidden_state'], encoded_input['attention_mask'])
    elif pool_type == 'max':
        x = max_pooling(x['last_hidden_state'], encoded_input['attention_mask'])
    elif pool_type in ['cls', 'first']:
        x = x['last_hidden_state'][:,0]
    elif pool_type == 'pooler_output':
        x = x['pooler_output']
    else:
        raise ValueError(f'unknown pool type {pool_type}.')
    return x


class BertSelfAttentionLinearized(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = torch.nn.Linear(config.hidden_size, self.all_head_size)
        self.key = torch.nn.Linear(config.hidden_size, self.all_head_size)
        self.value = torch.nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = torch.nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = torch.nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

        self.is_decoder = config.is_decoder

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        mixed_query_layer = self.query(hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_layer = past_key_value[0]
            value_layer = past_key_value[1]
            attention_mask = encoder_attention_mask
        elif is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        query_layer = self.transpose_for_scores(mixed_query_layer)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        # attention_probs = nn.Softmax(dim=-1)(attention_scores)
        attention_probs = attention_scores

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs


class ClassificationModelBase(torch.nn.Module):
    """ The base model API
        Required attributes:
         - masks: dictionary storing lists of binary masks for
           each stage ('embeddings', 'encoder', 'classifier')
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.masks: Dict[str, List[torch.Tensor]]

    def forward(self, X: Dict) -> torch.Tensor:
        """ Forward pass of the model
            Args:
             - X: dictionary of batched tokenized documents
            Returns: class logits
        """
        raise NotImplementedError(f"override 'forward' method")
    

class PrunerBase:
    """ Pruner base class
    """

    def prune(self, model: ClassificationModelBase,
                    target_sparsity: float,
                    pruning_type: str, **kwargs) -> List[torch.Tensor]:
        """ Prune model
            Args:
             - model: model to prune;
             - target_saprsity: desired sparsity of the model;
             - pruning_type: 'effective' or 'direct.
            Returns:
             - list of binary tensors (masks)
        """
        raise NotImplementedError("override prune method")
        
        
def random_pruning(sparsities, shapes, pruning_type='direct'):
    masks = []
    for shape, sparsity in zip(shapes, sparsities):
        count = int(torch.prod(shape))
        mask = torch.ones(count)
        idx_to_prune = np.random.choice(range(count),
                size=int(sparsity*count),
                replace=False)
        mask[idx_to_prune] = 0.
        shape = tuple(int(dim) for dim in shape)
        mask = mask.reshape(shape)
        masks.append(mask)
    return masks


def score_pruning(target_sparsity, scores, pruning_type):
    scores_flatten = np.concatenate([score.reshape(-1) for score in scores])
    threshold = np.quantile(scores_flatten, target_sparsity)
    masks = [(score>threshold).float() for score in scores]
    return masks


class SNIP(PrunerBase):
    def __init__(self):
        super().__init__()
    
    def prune(self, model, target_sparsity, pruning_type, sample, batch_size=256, **kwargs):
        model.eval()
        weights = []
        for module in model.get_prunable_modules:
            weights.append(model.weight.data.detach().cpu())
        batch_size = min(batch_size, len(sample))
        dataloader = DataLoader(sample, batch_size=batch_size)
        gradients = []
        for i,(X,y) in enumerate(dataloader):
            output = model(X)
            loss = F.cross_entropy(output, y)
            loss.backward(retain_graph=False)
            for module in model.get_prunable_models:
                curr_layer_grad = module.weight.grad.detach().cpu()
                if i==0:
                    gradients.append(curr_layer_grad)
                else:
                    sum_ = gradients[curr_layer_idx]*i
                    avg_ = (sum_+curr_layer_grad)/(i+1)
                    gradients[curr_layer_idx] = avg_
                module.weight.grad = None
        scores = [gradient*weight for gradient,weight in zip(gradients, weights)]
        masks = score_pruning(target_sparsity, scores, pruning_type)
        return masks


class Magnitude(PrunerBase):
    def __init__(self):
        super().__init__()

    def prune(self, model, target_sparsity, pruning_type, **kwargs):
        if pruning_type=='direct':
            scores = self.scores(model)
            return score_pruning(target_sparsity, scores)
        
    def scores(self, model):
        scores = []
        for module in model.get_prunable_modules:
            scores.append(torch.abs(module.weight.data.detach()).numpy())
        return scores

    
class RandomUniform(PrunerBase):
    def __init__(self):
        super().__init__()

    def prune(self, model, target_sparsity, pruning_type, **kwargs):
        shapes = []
        for module in model.get_prunable_modules:
            shape = list(module.weight.data.detach().detach().numpy().shape)
            shapes.append(torch.Tensor(shape))
        if pruning_type=='direct':
            sparsities = self.quotas(target_sparsity, shapes)
            return random_pruning(sparsities, shapes)
    
    def quotas(self, target_sparsity, shapes):
        return [target_sparsity]*len(shapes)
    
    
class RandomIGQ(PrunerBase):
    def __init__(self, tolerance=1e20):
        super().__init__()

    def prune(self, model, target_sparsity, pruning_type, **kwargs):
        shapes = []
        for module in model.get_prunable_modules:
            shape = list(module.weight.data.detach().detach().numpy().shape)
            shapes.append(torch.Tensor(shape))
        if pruning_type=='direct':
            sparsities = self.quotas(target_sparsity, shapes)
            return random_pruning(sparsities, shapes)

    def _bs_force_igq(self, areas, Lengths, target_sparsity, tolerance,f_low,f_high, depth):
        lengths_low=[Length/(f_low/area+1) for Length,area in zip(Lengths,areas)]
        overall_sparsity_low=1-sum(lengths_low)/sum(Lengths)
        if abs(overall_sparsity_low-target_sparsity)<tolerance or depth<0:
            return [1-length/Length for length,Length in zip(lengths_low,Lengths)]
        lengths_high=[Length/(f_high/area+1) for Length,area in zip(Lengths,areas)]
        overall_sparsity_high=1-sum(lengths_high)/sum(Lengths)
        if abs(overall_sparsity_high-target_sparsity)<tolerance or depth<0:
            return [1-length/Length for length,Length in zip(lengths_high,Lengths)]
        force=float(f_low+f_high)/2
        lengths=[Length/(force/area+1) for Length,area in zip(Lengths,areas)]
        overall_sparsity=1-sum(lengths)/sum(Lengths)
        f_low=force if overall_sparsity<target_sparsity else f_low
        f_high=force if overall_sparsity>target_sparsity else f_high
        return self._bs_force_igq(areas,Lengths,target_sparsity,tolerance,f_low,f_high, depth-1)

    def quotas(self, target_sparsity, shapes):
        counts=[torch.prod(shape) for shape in shapes]
        tolerance=100./sum(counts)
        areas=[1./count for count in counts]
        Lengths=[count for count in counts]
        return self._bs_force_igq(areas,Lengths,target_sparsity,tolerance,0,1e30, 1000)


    def effective_masks(self, **kwargs) -> Dict[str, List[torch.Tensor]]:
        """ Return effective masks in a dictionary by module:
            'embeddings', 'encoder', 'classifier'
        """
        raise NotImplementedError(f"override 'linearize' method")
        

class LinearizedEmbeddings(torch.nn.Module):
    def __init__(self, bert_embeddings):
        super().__init__()
        self.shapes = {1: bert_embeddings.word_embeddings.weight.data.size(),
                2: bert_embeddings.position_embeddings.weight.data.size(),
                3: bert_embeddings.token_type_embeddings.weight.data.size()}
        self.total_length = sum(shape[0] for shape in self.shapes.values())
        bert_word_embeddings = bert_embeddings.word_embeddings.weight.data.detach().clone().t()
        self.word_embeddings = torch.nn.Linear(self.shapes[1][0],
                self.shapes[1][1], bias=False)
        self.word_embeddings.weight.data = bert_word_embeddings.requires_grad_(True)
        bert_position_embeddings = bert_embeddings.position_embeddings.weight.data.detach().clone().t()
        self.position_embeddings = torch.nn.Linear(self.shapes[2][0],
                self.shapes[2][1], bias=False)
        self.position_embeddings.weight.data = bert_position_embeddings.requires_grad_(True)
        bert_token_type_embeddings = bert_embeddings.token_type_embeddings.weight.data.detach().clone().t()
        self.token_type_embeddings = torch.nn.Linear(self.shapes[3][0],
                self.shapes[3][1], bias=False)
        self.token_type_embeddings.weight.data = bert_token_type_embeddings.requires_grad_(True)
        
    def forward(self, X):
        X = torch.squeeze(X)
        X1 = X[:, sum(self.shapes[i][0] for i in range(1,1)):sum(self.shapes[i][0] for i in range(1,2))]
        X2 = X[:, sum(self.shapes[i][0] for i in range(1,2)):sum(self.shapes[i][0] for i in range(1,3))]
        X3 = X[:, sum(self.shapes[i][0] for i in range(1,3)):sum(self.shapes[i][0] for i in range(1,4))]
        word_out = self.word_embeddings(X1)
        position_out = self.position_embeddings(X2)
        token_type_out = self.token_type_embeddings(X3)
        return word_out + position_out + token_type_out


class LinearizedBERTClassifier(ClassificationModelBase):
    def __init__(self, reference_model, stabilize=False):
        super().__init__()
        self.device = reference_model.device
        self.stabilize = stabilize
        config = BertConfig(hidden_act=lambda x: x,
                hidden_dropout_prob=0,
                attention_probs_dropout_prob=0)
        linearized_encoder = BertModel(config).encoder
        for layer in linearized_encoder.layer:
            layer.attention.self = BertSelfAttentionLinearized(config)
        if stabilize:
            linearized_encoder.forward = stabilized_forward(linearized_encoder)
        identity = torch.nn.Identity()
        weights = []
        for m in reference_model.module_list[1].modules():
            if isinstance(m, torch.nn.Linear):
                weights.append(m.weight.data.detach().clone())
        idx = 0
        for m in linearized_encoder.modules():
            if isinstance(m, torch.nn.Linear):
                m.weight.data = weights[idx].requires_grad_(True)
                idx+=1
            if isinstance(m, torch.nn.LayerNorm):
                m.forward = identity.forward
        linearized_embeddings = LinearizedEmbeddings(reference_model.module_list[0])
        self.total_length = linearized_embeddings.total_length
        self.modules_list = torch.nn.ModuleList([
                linearized_embeddings.to(self.device),
                linearized_encoder.to(self.device),
                #deepcopy(reference_model.modules_dict.classifier).to(self.device)])
        for m in self.modules():
            if hasattr(m, 'weight'):
                m.weight.data = torch.abs(m.weight.data)
                m.weight.grad = None
        
    def forward(self, inp):
        x = self.module_list[0](inp)[None, :, :]
        x = self.module_list[1](x)
        mask = {'attention_mask': torch.ones(512)}
        if self.stabilize:
            x = {'last_hidden_state': x}
        x = pool(x, mask, 'mean')
        #x = self.modules_dict['classifier'](x)
        return x
    

def layer_sparsity(mask: torch.Tensor):
    return 1-torch.sum(mask>0)/np.prod(mask.numpy().shape)


def model_sparsity(masks: List[torch.Tensor]):
    counts = [np.prod(mask.numpy().shape) for mask in masks]
    sparsities = [layer_sparsity(mask) for mask in masks]
    active_parameters = 0
    for count,sparsity in zip(counts, sparsities):
        active_parameters+=count*(1-sparsity)
    return 1-active_parameters/sum(counts)


class BackboneBERT(ClassificationModelBase):
    def __init__(self, backbone_name,
                    pool_type,
                    num_classes,
                    device,
                    **kwargs):
        super().__init__()
        self.pool_type = pool_type
        bert = BertModel.from_pretrained(backbone_name)
        classifier = torch.nn.Linear(768, num_classes, bias=True)
        self.module_list = torch.nn.ModuleList([
                bert.embeddings.to(device),
                bert.encoder.to(device)])
        self.prunable_modules = {torch.nn.Embedding, torch.nn.Linear}
        self.create_masks()
        self.device = device
        
    def get_prunable_modules(self):
        prunable_modules = []
        for module in self.module_list:
            for m in module.modules():
                if isinstance(m, torch.nn.Embedding):
                    prunable_modules.append(m)
                if isinstance(m, torch.nn.Linear):
                    prunable_modules.append(m)
        return prunable_modules
        
    def create_masks(self):
        self.masks = []
        for module in self.get_prunable_modules():
            self.masks.append(torch.ones(module.weight.data.size()))
    
    def update_masks(self, masks):
        self.masks = masks

    def apply_masks(self):
        for i,module in enumerate(self.get_prunable_modules):
            masked_weight = torch.mul(m.weight.data, self.masks[i])
            m.weight.data = masked_weight
    
    def forward(self, inp):
        self.apply_masks()
        attention_mask = inp['attention_mask'][:, None, None, :]
        x = self.module_list[0](inp['input_ids'])
        x = self.module_list[1](x, attention_mask=attention_mask)
        x = pool(x, inp, self.pool_type)
        #x = self.modules_dict['classifier'](x)
        return x

    def linearize(self):
        self.apply_masks()
        return LinearizedBERTClassifier(self, stabilize=True)
    
    @property
    def effective_masks(self):
        self.apply_masks()
        linearized = self.linearize()
        full_length = linearized.total_length
        X = torch.ones((512,full_length))
        output = torch.sum(linearized(X))
        output.backward()
        effective_masks = []
        for module in self.get_prunable_modules:
            scores = torch.abs(torch.mul(m.weight.grad,m.weight.data))
            effective_masks.append((scores>0).int())
        del linearized
        return effective_masks

In [6]:
from torch.utils.data import Subset
imdb_dataset = IMDb(torch.device('cpu'), 'bert-base-uncased')
sample = Subset(imdb_dataset.datasets['train'], range(4))

In [43]:
model = BERTClassifier('bert-base-uncased', 'mean', 2, torch.device('cpu'))
pruner = SNIP()
masks=pruner.prune(model, 0.999, 'direct', sample)
model.update_masks(masks)
effective_masks = model.effective_masks()
print(f'direct sparsity: {model_sparsity(masks):.9f}')
print(f'effective sparsity: {model_sparsity(effective_masks):.9f}')

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


direct sparsity: 0.999000013
effective sparsity: 0.999262094


In [44]:
for i,mask in enumerate(masks):
    print(i,1-torch.sum(mask)/np.prod(mask.numpy().shape))


0 tensor(0.9999)
1 tensor(0.9995)
2 tensor(0.8483)
3 tensor(0.9997)
4 tensor(0.9998)
5 tensor(0.9982)
6 tensor(0.9974)
7 tensor(0.9992)
8 tensor(0.9991)
9 tensor(0.9991)
10 tensor(0.9995)
11 tensor(0.9984)
12 tensor(0.9974)
13 tensor(0.9990)
14 tensor(0.9996)
15 tensor(0.9981)
16 tensor(0.9981)
17 tensor(0.9987)
18 tensor(0.9986)
19 tensor(0.9993)
20 tensor(0.9997)
21 tensor(0.9990)
22 tensor(0.9979)
23 tensor(0.9978)
24 tensor(0.9980)
25 tensor(0.9994)
26 tensor(0.9999)
27 tensor(0.9997)
28 tensor(0.9996)
29 tensor(0.9972)
30 tensor(0.9981)
31 tensor(0.9995)
32 tensor(0.9998)
33 tensor(0.9997)
34 tensor(0.9996)
35 tensor(0.9988)
36 tensor(0.9990)
37 tensor(0.9995)
38 tensor(0.9998)
39 tensor(0.9996)
40 tensor(0.9993)
41 tensor(0.9981)
42 tensor(0.9990)
43 tensor(0.9994)
44 tensor(0.9998)
45 tensor(0.9998)
46 tensor(0.9996)
47 tensor(0.9981)
48 tensor(0.9984)
49 tensor(0.9992)
50 tensor(0.9996)
51 tensor(0.9996)
52 tensor(0.9993)
53 tensor(0.9939)
54 tensor(0.9938)
55 tensor(0.9991)
56