In [None]:
# Automatically reload imported local modules (e.g. modeling_prunebert, textpruner) when their source files change.
%load_ext autoreload
%autoreload 2

In [None]:
# Input locations: the teacher BERT config, and the pruned checkpoint to convert.
# Point `ckpt_file` at your own pruned model.
bert_config_file = 'teacher_models/config.json'
ckpt_file = 'pruned_models/pd-sst2-05/lr3e20_s_bs32_0.4_pf1_IS0.998_Reg3e-1_E192/gs42080.pt'

In [None]:
import os
import torch

# Force the pure-Python protobuf backend; set before the model/transformers imports below.
os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'

from modeling_prunebert import (
    BertModel as PrunedBertModel,
    BertForSequenceClassification,
    set_head_cuts,
)
from transformers import BertConfig
from textpruner import TransformerPruner, inference_time, summary
from textpruner.extentions.pruner import FineGrainedPruner

device = 'cuda'

config = BertConfig.from_json_file(bert_config_file)
# NOTE(review): magic number; presumably matches the "E192" tag in ckpt_file — confirm.
config.proj_size = 192

# Raw checkpoint tensors, mapped onto `device`.
state_dict = torch.load(ckpt_file, map_location=device)

In [None]:
# Restore the effective weights. Every pruned tensor is stored as a
# `<name>_orig` / `<name>_mask` pair (the naming convention of
# torch.nn.utils.prune), and the usable tensor is mask * orig.
# The `_mask` entries are kept because later cells read them; the `_orig`
# entries are dropped once the products have been materialised.
for mask_key in [key for key in state_dict if key.endswith('_mask')]:
    base_key = mask_key[:-len('_mask')]
    state_dict[base_key] = state_dict[mask_key] * state_dict[base_key + '_orig']

for orig_key in [key for key in state_dict if key.endswith('_orig')]:
    del state_dict[orig_key]

In [None]:
# Build the classifier from the local config plus the restored weights (no
# pretrained name is given), move it to `device`, and switch to inference mode.
model = BertForSequenceClassification.from_pretrained(
    None, config=config, state_dict=state_dict
).to(device)
model.eval();

In [None]:
# Collect the per-layer pruning masks. Layer and head counts come from `config`
# instead of being hard-coded for BERT-base (12 layers x 12 heads). The head
# size is inferred from each mask's length instead of being fixed at 64.
NUM_LAYERS = config.num_hidden_layers
NUM_HEADS = config.num_attention_heads

# FFN mask: row 0 of each layer's output.dense weight mask, one entry per
# intermediate unit (assumes whole columns are pruned, so any row is representative — TODO confirm).
ffn_mask_list = [state_dict[f'bert.encoder.layer.{i}.output.dense.weight_mask'][0] for i in range(NUM_LAYERS)]
ffn_mask = torch.stack(ffn_mask_list)
# Attention masks: the query-bias mask marks surviving QK dims, the value-bias mask surviving VO dims.
qk_mask_list = [state_dict[f'bert.encoder.layer.{i}.attention.self.query.bias_mask'] for i in range(NUM_LAYERS)]
vo_mask_list = [state_dict[f'bert.encoder.layer.{i}.attention.self.value.bias_mask'] for i in range(NUM_LAYERS)]
# Number of surviving dimensions per head, one (NUM_HEADS,) tensor per layer.
qk_head_size_list = [t.reshape(NUM_HEADS, -1).sum(-1) for t in qk_mask_list]
vo_head_size_list = [t.reshape(NUM_HEADS, -1).sum(-1) for t in vo_mask_list]

# Make the QK and VO head masks consistent: a head is kept if either side survives.
def make_qk_vo_consistency(qk_mask_list, vo_mask_list, num_heads=12):
    """Drop, per layer, the attention heads that are fully pruned in both QK and VO.

    Args:
        qk_mask_list: per-layer 1-D masks over the QK dimensions, laid out as
            ``num_heads`` consecutive chunks of equal size.
        vo_mask_list: per-layer 1-D masks over the VO dimensions, same layout.
        num_heads: number of heads each flat mask is split into (default 12);
            the head size is inferred from the mask length.

    Returns:
        ``(new_qk_mask_list, new_vo_mask_list)``. For each layer the entry is
        either a ``(kept_heads, head_size)`` tensor holding cloned rows of the
        heads that survive in QK or VO, or an empty list ``[]`` when the
        layer's VO mask sums to zero (the whole MHA block is pruned).

    Raises:
        AssertionError: if the two lists have different lengths.
    """
    assert len(qk_mask_list) == len(vo_mask_list), "qk/vo mask lists must have one entry per layer"
    new_qk_mask_list = []
    new_vo_mask_list = []
    for qk_mask, vo_mask in zip(qk_mask_list, vo_mask_list):
        if vo_mask.sum() == 0:
            # Important for empty MHA: the layer is represented by [] rather than a tensor.
            new_qk_mask_list.append([])
            new_vo_mask_list.append([])
            continue
        kept_qk, kept_vo = [], []
        for qk_head, vo_head in zip(qk_mask.reshape(num_heads, -1), vo_mask.reshape(num_heads, -1)):
            if vo_head.sum() == 0 and qk_head.sum() == 0:
                continue  # pruned on both sides
            kept_qk.append(qk_head.clone())
            kept_vo.append(vo_head.clone())
        # Non-empty: vo_mask.sum() != 0 guarantees at least one head with a non-zero VO row.
        new_qk_mask_list.append(torch.stack(kept_qk))
        new_vo_mask_list.append(torch.stack(kept_vo))
    return new_qk_mask_list, new_vo_mask_list

def _head_sizes(head_masks):
    """Return surviving dims per head (int32), or pass ``[]`` (empty MHA layer) through.

    ``head_masks`` is a ``(kept_heads, head_size)`` tensor from
    make_qk_vo_consistency, so summing the last axis needs no hard-coded head size.
    """
    if isinstance(head_masks, torch.Tensor):
        return head_masks.sum(-1).int()
    return head_masks


def _head_cuts(head_sizes):
    """Return cumulative offsets ``[0, s0, s0+s1, ...]`` for one layer's head sizes."""
    # Convert to Python ints first instead of building a tensor from a list of 0-d tensors.
    sizes = head_sizes.tolist() if isinstance(head_sizes, torch.Tensor) else list(head_sizes)
    return torch.tensor([0] + sizes).cumsum(-1)


# Per layer, keep only the heads that survive in QK or VO, then turn the
# surviving per-head sizes into cumulative offsets ("head cuts") for set_head_cuts.
consistent_qk_mask_list, consistent_vo_mask_list = make_qk_vo_consistency(qk_mask_list, vo_mask_list)
consistent_qk_head_size_list = [_head_sizes(t) for t in consistent_qk_mask_list]
consistent_vo_head_size_list = [_head_sizes(t) for t in consistent_vo_mask_list]

qk_head_cuts_list = [_head_cuts(t) for t in consistent_qk_head_size_list]
vo_head_cuts_list = [_head_cuts(t) for t in consistent_vo_head_size_list]

In [None]:
def show_masks(state_dict, num_layers=12, num_heads=12):
    """Print a per-layer summary of the surviving VO heads and FFN units.

    Reads ``bert.encoder.layer.{i}.attention.self.value.bias_mask`` and
    ``bert.encoder.layer.{i}.output.dense.weight_mask`` for every layer.

    Args:
        state_dict: mapping that still contains the ``*_mask`` entries.
        num_layers: number of encoder layers to inspect (default 12).
        num_heads: number of heads per VO mask (default 12); the head size is
            inferred from the mask length.

    For each layer it prints the non-zero VO head sizes, their total and the
    head count. It then prints the total number of surviving heads and non-empty
    MHA layers, followed by the per-layer FFN sizes, their mean over layers, and
    the number of non-empty FFN layers.
    """
    prefix = 'bert.encoder.layer'
    # Row 0 of the FFN output weight mask: one entry per intermediate unit (assumed — TODO confirm).
    ffn_masks = torch.stack([state_dict[f'{prefix}.{i}.output.dense.weight_mask'][0] for i in range(num_layers)]).int()
    vo_masks = torch.stack([state_dict[f'{prefix}.{i}.attention.self.value.bias_mask'] for i in range(num_layers)]).int()
    # (num_layers, num_heads): surviving VO dimensions per head.
    vo_head_sizes = vo_masks.reshape(num_layers, num_heads, -1).sum(-1)

    print("=====VO=====")
    for layer, sizes in enumerate(vo_head_sizes):
        kept = [size for size in sizes.tolist() if size > 0]
        print(f"{layer}: {kept}, {sizes.sum().item()}, {(sizes > 0).sum().item()}")
    print("Total number of heads:", (vo_head_sizes > 0).sum().item())
    print("Total number of MHA layer:", (vo_head_sizes.sum(-1) > 0).sum().item())

    print("=====FFN=====")
    ffn_sizes = ffn_masks.sum(-1)
    print(f"FFN size/{num_layers}: {ffn_sizes.tolist()} {ffn_masks.sum().item() / num_layers:.1f}")
    print("Total number of FFN layers:", (ffn_sizes > 0).sum().item())
# Print the surviving VO-head and FFN sizes per layer, as recorded in the masks.
show_masks(state_dict)

In [None]:
# Benchmark the unpruned model on a random batch of token ids, print its size,
# and keep its outputs as the reference for the post-pruning discrepancy check.
inputs = torch.randint(0, 10000, (128, 512), device=device)
with torch.no_grad():
    mean, std = inference_time(model, [inputs])
    print(mean, std)
    print(summary(model))
    original_outputs = model(inputs)

In [None]:
# Physically remove the pruned weights, i.e. entries whose mask is 0
# (mask == 1 marks weights to keep, since the restored weights are mask * orig).
# Structured FFN pruning runs first, then fine-grained QK/VO attention pruning.
# NOTE(review): the raw (not the "consistent") QK/VO masks are passed here;
# presumably FineGrainedPruner does its own per-head bookkeeping — confirm.
TransformerPruner(model).prune(ffn_mask=ffn_mask, save_model=False)
pruner = FineGrainedPruner(model)
pruner.prune(QK_mask_list=qk_mask_list, VO_mask_list=vo_mask_list, save_model=False)

In [None]:
# Remove empty FFN layers and empty MHA layers

from torch import nn
import types
def feed_forward_chunk_for_empty_ffn(self, attention_output):
    """Replacement ``feed_forward_chunk`` for a layer whose FFN is fully pruned.

    Calls ``self.output`` directly on the attention output, with no
    intermediate step in between.
    """
    return self.output(attention_output)

def output_forward(self, input_tensor):
    """Forward for an FFN output sub-module whose dense weight is empty.

    With no weights left, the dense layer contributes only its bias. The result
    is ``LayerNorm(bias + input_tensor)``: the residual add, with no dropout applied.
    """
    bias = self.dense.bias
    return self.LayerNorm(bias + input_tensor)

def attetion_forward_for_empty_attention(
    self,
    hidden_states,
    attention_mask=None,
    head_mask=None,
    encoder_hidden_states=None,
    encoder_attention_mask=None,
    past_key_value=None,
    output_attentions=False,
):
    """Replacement attention ``forward`` for a layer whose attention output weight is empty.

    Only the output projection's bias is added to the residual before
    ``self.output.LayerNorm``. Every argument except ``hidden_states`` is
    accepted so the call signature still fits, and is ignored.

    Returns:
        A 1-tuple containing the normalised hidden states.
    """
    output_block = self.output
    normalized = output_block.LayerNorm(output_block.dense.bias + hidden_states)
    return normalized,

def transform(model: nn.Module, always_ffn=False, always_mha=False):
    """Monkey-patch encoder layers whose FFN or attention weights were fully pruned.

    For every layer in ``model.base_model.encoder.layer``:
      * if ``always_ffn`` is set or the FFN output dense weight has no elements,
        ``layer.feed_forward_chunk`` and ``layer.output.forward`` are replaced
        with the bias-only versions;
      * if ``always_mha`` is set or the attention output dense weight has no
        elements, ``layer.attention.forward`` is replaced with the bias-only version.
    Prints "replace ffn" / "replace mha" for each patch. Modifies the model in
    place and returns None.
    """
    for layer in model.base_model.encoder.layer:
        ffn_out = layer.output
        if always_ffn or ffn_out.dense.weight.numel() == 0:  # empty FFN
            print("replace ffn")
            layer.feed_forward_chunk = types.MethodType(feed_forward_chunk_for_empty_ffn, layer)
            ffn_out.forward = types.MethodType(output_forward, ffn_out)
        attn = layer.attention
        attn_out = attn.output
        if always_mha or attn_out.dense.weight.numel() == 0:  # empty attention
            print("replace mha")
            attn.forward = types.MethodType(attetion_forward_for_empty_attention, attn)

# Patch every layer whose FFN or attention output weight is now empty after pruning.
transform(model)

In [None]:
# Hand the per-layer QK/VO head offsets (built from the consistent masks) to the model.
# NOTE(review): presumably lets attention split the now variable-size heads — confirm in modeling_prunebert.
set_head_cuts(model,qk_head_cuts_list,vo_head_cuts_list)

In [None]:
# Forward the same `inputs` batch through the pruned model, without autograd.
model.eval()
with torch.set_grad_enabled(False):
    pruned_outputs = model(inputs)

In [None]:
# Calculate the largest absolute discrepancy between the unpruned and pruned logits.
(pruned_outputs.logits - original_outputs.logits).abs().max()

In [None]:
# Show the model size after pruning (textpruner summary; compare with the one printed before pruning).
print(summary(model))

In [None]:
# Inference time of the pruned model, measured on a fresh random batch with the
# same shape and token-id range as the pre-pruning benchmark.
inputs = torch.randint(0, 10000, (128, 512), device=device)
with torch.no_grad():
    mean, std = inference_time(model, [inputs])

print("Mean: ", mean)
print("Std: ", std)
