In [1]:
import torch
from torch.nn.utils import prune

from tqdm import tqdm

from transformers import AutoTokenizer, OPTForCausalLM, pipeline
from datasets import load_dataset

from calculate_mask import calculate_mask
from inverse_hessian import calc_inverse_hessian
from input_prehooks import put_input_hooks
from testing_module import calculate_perp

In [2]:
#DEVICE
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model_name = "facebook/opt-350m"
# model_name = "facebook/opt-1.3b"

#Load dataset
dataset = load_dataset('c4', 'en', streaming=True)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')

In [3]:
# Calibrate model (get inputs to each layer with calibration data)

calibration_size=4
token_length=512
calibration_batch_size=2

EPSILON = 1e-8
B = 4
Bs = 2

# run model on batches of calibration data, then concatenate inputs
def split_model_calibration(model):
    batch_sentences = []
    for i, data in tqdm(enumerate(iter(dataset['train'])), total=calibration_size):
        if i < calibration_size + 1:
            if len(batch_sentences) >= calibration_batch_size:
                with torch.no_grad():
                    encoded_input = tokenizer(batch_sentences, return_tensors="pt",
                                              padding="max_length", max_length=token_length,
                                              truncation=True).to(device=device)
                    model(**encoded_input, labels=encoded_input.input_ids)
                    batch_sentences = []
            batch_sentences.append(data['text'])
        else:
            break

# Sparsify Model

In [4]:
# function to get module name from parameter name
def get_module_name(param_name):
    if param_name[-5:] == ".bias":
        return param_name[:-5], "bias"
    elif param_name[-7:] == ".weight":
        return param_name[:-7], "weight"
    else:
        return None, None

In [5]:
from input_prehooks import get_feature_storage_name
import gc
layer_blacklist = ['model.decoder.embed_tokens.weight', 'model.decoder.embed_tokens.bias', 'model.decoder.embed_positions.weight']

# Using calibration data (inputs to each intermediate weight layer)
# Iterate through named parameters, calculate inverse hessian and calculate mask

SPARSENESS_LIST = [0.5]#0.1, 0.2, 0.3, 0.5, 0.7, 0.9
for i, SPARSENESS in enumerate(SPARSENESS_LIST):
    
    # Load model with pre-trained head
    model = OPTForCausalLM.from_pretrained(model_name, output_attentions=True,
                                           output_hidden_states=True).to(device=device) # type: ignore
    
    #storage_dir = f'tmp/{model_name}-{SPARSENESS}'
    
    # First, put in forward hooks
    # Don't store inputs, instead store hessians (less data)
    # Only store hessians once, as all models take the same hessians
    print('------------')
    if i == 0:
        feature_hessians = {}
        #put_input_hooks(model=model, features=feature_hessians, storage_dir=storage_dir, offload_freq=10000, feature_storage_device='cpu')
        put_input_hooks(model=model, features=feature_hessians, feature_storage_device='cpu')
        split_model_calibration(model)
        torch.cuda.empty_cache()
    print('-----------')
    # make a dictionary to access module by name
    module_lookup_dict = {}
    for module_name, module_iter in model.named_modules():
        module_lookup_dict[module_name] = module_iter

    # without this
    param_lookup_dict = {}
    param_names = []
    for name, param in model.named_parameters():
        param_names.append(name)
        param_lookup_dict[name] = param
    
    model.eval()
    model = torch.nn.DataParallel(model, device_ids=[0,1,2,3])
    with torch.no_grad():
        # for name in tqdm(param_names):
        for name in tqdm(param_names, total=len(param_names)):
            param = param_lookup_dict[name]

            # skip the embed layer
            if name in layer_blacklist:
                continue
            # skip norms which have 1 dimension
            if len(param.shape) < 2:
                continue

            module_name, param_type = get_module_name(name)

            # apply to weight layers
            if param_type == "weight":
                #print(f"Doing layer {name}")
                # get layer input from features, key is get_feature_storage_name(module_name)
                # get_feature_storage_name(module_name) stores k_proj, v_proj, q_proj together
                # since they are the same input
                
                layer_hessian = feature_hessians[get_feature_storage_name(module_name)].to(device=device)

                # calculate inverse hessian
                # check if input is flattened e.g. from 8,512,768 to 4096,768
                inv_hess = calc_inverse_hessian(layer_hessian, epsilon=EPSILON)

                # calculate mask
                mask = calculate_mask(W=param, H_inv=inv_hess, p=SPARSENESS, B=B, Bs=Bs)

                # get module from lookup dictionary by module name
                module = module_lookup_dict[module_name]
                # apply mask
                prune.custom_from_mask(module=module, name=param_type, mask=mask)
                prune.remove(module=module, name=param_type)
                gc.collect()
                torch.cuda.empty_cache()           
    pruned_model_name = f'opt-350m-test-{SPARSENESS}'

    torch.save(model.state_dict(), f'pruned_models/{pruned_model_name}.pt')

------------


5it [00:03,  1.43it/s]                       


-----------


100%|██████████| 388/388 [00:56<00:00,  6.82it/s]


In [6]:
from save_pruned_model import load_into_model
loaded_model = OPTForCausalLM.from_pretrained(f'facebook/opt-350m', output_attentions=True, output_hidden_states=True).to(device=device) # type: ignore
loaded_model = torch.nn.DataParallel(loaded_model, device_ids=[0,1,2,3])
load_into_model(loaded_model, f'pruned_models/opt-350m-test-0.5.pt')

loaded_model.eval()
_ = loaded_model(torch.randint(high=20, size=(1,10)))

100%|██████████| 388/388 [00:00<00:00, 14464.66it/s]


RuntimeError: Error(s) in loading state_dict for DataParallel:
	Missing key(s) in state_dict: "module.model.decoder.project_out.weight_orig", "module.model.decoder.project_out.weight_mask", "module.model.decoder.project_in.weight_orig", "module.model.decoder.project_in.weight_mask", "module.model.decoder.layers.0.self_attn.k_proj.weight_orig", "module.model.decoder.layers.0.self_attn.k_proj.weight_mask", "module.model.decoder.layers.0.self_attn.v_proj.weight_orig", "module.model.decoder.layers.0.self_attn.v_proj.weight_mask", "module.model.decoder.layers.0.self_attn.q_proj.weight_orig", "module.model.decoder.layers.0.self_attn.q_proj.weight_mask", "module.model.decoder.layers.0.self_attn.out_proj.weight_orig", "module.model.decoder.layers.0.self_attn.out_proj.weight_mask", "module.model.decoder.layers.0.fc1.weight_orig", "module.model.decoder.layers.0.fc1.weight_mask", "module.model.decoder.layers.0.fc2.weight_orig", "module.model.decoder.layers.0.fc2.weight_mask", "module.model.decoder.layers.1.self_attn.k_proj.weight_orig", "module.model.decoder.layers.1.self_attn.k_proj.weight_mask", "module.model.decoder.layers.1.self_attn.v_proj.weight_orig", "module.model.decoder.layers.1.self_attn.v_proj.weight_mask", "module.model.decoder.layers.1.self_attn.q_proj.weight_orig", "module.model.decoder.layers.1.self_attn.q_proj.weight_mask", "module.model.decoder.layers.1.self_attn.out_proj.weight_orig", "module.model.decoder.layers.1.self_attn.out_proj.weight_mask", "module.model.decoder.layers.1.fc1.weight_orig", "module.model.decoder.layers.1.fc1.weight_mask", "module.model.decoder.layers.1.fc2.weight_orig", "module.model.decoder.layers.1.fc2.weight_mask", "module.model.decoder.layers.2.self_attn.k_proj.weight_orig", "module.model.decoder.layers.2.self_attn.k_proj.weight_mask", "module.model.decoder.layers.2.self_attn.v_proj.weight_orig", "module.model.decoder.layers.2.self_attn.v_proj.weight_mask", "module.model.decoder.layers.2.self_attn.q_proj.weight_orig", "module.model.decoder.layers.2.self_attn.q_proj.weight_mask", "module.model.decoder.layers.2.self_attn.out_proj.weight_orig", "module.model.decoder.layers.2.self_attn.out_proj.weight_mask", "module.model.decoder.layers.2.fc1.weight_orig", "module.model.decoder.layers.2.fc1.weight_mask", "module.model.decoder.layers.2.fc2.weight_orig", "module.model.decoder.layers.2.fc2.weight_mask", "module.model.decoder.layers.3.self_attn.k_proj.weight_orig", "module.model.decoder.layers.3.self_attn.k_proj.weight_mask", "module.model.decoder.layers.3.self_attn.v_proj.weight_orig", "module.model.decoder.layers.3.self_attn.v_proj.weight_mask", "module.model.decoder.layers.3.self_attn.q_proj.weight_orig", "module.model.decoder.layers.3.self_attn.q_proj.weight_mask", "module.model.decoder.layers.3.self_attn.out_proj.weight_orig", "module.model.decoder.layers.3.self_attn.out_proj.weight_mask", "module.model.decoder.layers.3.fc1.weight_orig", "module.model.decoder.layers.3.fc1.weight_mask", "module.model.decoder.layers.3.fc2.weight_orig", "module.model.decoder.layers.3.fc2.weight_mask", "module.model.decoder.layers.4.self_attn.k_proj.weight_orig", "module.model.decoder.layers.4.self_attn.k_proj.weight_mask", "module.model.decoder.layers.4.self_attn.v_proj.weight_orig", "module.model.decoder.layers.4.self_attn.v_proj.weight_mask", "module.model.decoder.layers.4.self_attn.q_proj.weight_orig", "module.model.decoder.layers.4.self_attn.q_proj.weight_mask", "module.model.decoder.layers.4.self_attn.out_proj.weight_orig", "module.model.decoder.layers.4.self_attn.out_proj.weight_mask", "module.model.decoder.layers.4.fc1.weight_orig", "module.model.decoder.layers.4.fc1.weight_mask", "module.model.decoder.layers.4.fc2.weight_orig", "module.model.decoder.layers.4.fc2.weight_mask", "module.model.decoder.layers.5.self_attn.k_proj.weight_orig", "module.model.decoder.layers.5.self_attn.k_proj.weight_mask", "module.model.decoder.layers.5.self_attn.v_proj.weight_orig", "module.model.decoder.layers.5.self_attn.v_proj.weight_mask", "module.model.decoder.layers.5.self_attn.q_proj.weight_orig", "module.model.decoder.layers.5.self_attn.q_proj.weight_mask", "module.model.decoder.layers.5.self_attn.out_proj.weight_orig", "module.model.decoder.layers.5.self_attn.out_proj.weight_mask", "module.model.decoder.layers.5.fc1.weight_orig", "module.model.decoder.layers.5.fc1.weight_mask", "module.model.decoder.layers.5.fc2.weight_orig", "module.model.decoder.layers.5.fc2.weight_mask", "module.model.decoder.layers.6.self_attn.k_proj.weight_orig", "module.model.decoder.layers.6.self_attn.k_proj.weight_mask", "module.model.decoder.layers.6.self_attn.v_proj.weight_orig", "module.model.decoder.layers.6.self_attn.v_proj.weight_mask", "module.model.decoder.layers.6.self_attn.q_proj.weight_orig", "module.model.decoder.layers.6.self_attn.q_proj.weight_mask", "module.model.decoder.layers.6.self_attn.out_proj.weight_orig", "module.model.decoder.layers.6.self_attn.out_proj.weight_mask", "module.model.decoder.layers.6.fc1.weight_orig", "module.model.decoder.layers.6.fc1.weight_mask", "module.model.decoder.layers.6.fc2.weight_orig", "module.model.decoder.layers.6.fc2.weight_mask", "module.model.decoder.layers.7.self_attn.k_proj.weight_orig", "module.model.decoder.layers.7.self_attn.k_proj.weight_mask", "module.model.decoder.layers.7.self_attn.v_proj.weight_orig", "module.model.decoder.layers.7.self_attn.v_proj.weight_mask", "module.model.decoder.layers.7.self_attn.q_proj.weight_orig", "module.model.decoder.layers.7.self_attn.q_proj.weight_mask", "module.model.decoder.layers.7.self_attn.out_proj.weight_orig", "module.model.decoder.layers.7.self_attn.out_proj.weight_mask", "module.model.decoder.layers.7.fc1.weight_orig", "module.model.decoder.layers.7.fc1.weight_mask", "module.model.decoder.layers.7.fc2.weight_orig", "module.model.decoder.layers.7.fc2.weight_mask", "module.model.decoder.layers.8.self_attn.k_proj.weight_orig", "module.model.decoder.layers.8.self_attn.k_proj.weight_mask", "module.model.decoder.layers.8.self_attn.v_proj.weight_orig", "module.model.decoder.layers.8.self_attn.v_proj.weight_mask", "module.model.decoder.layers.8.self_attn.q_proj.weight_orig", "module.model.decoder.layers.8.self_attn.q_proj.weight_mask", "module.model.decoder.layers.8.self_attn.out_proj.weight_orig", "module.model.decoder.layers.8.self_attn.out_proj.weight_mask", "module.model.decoder.layers.8.fc1.weight_orig", "module.model.decoder.layers.8.fc1.weight_mask", "module.model.decoder.layers.8.fc2.weight_orig", "module.model.decoder.layers.8.fc2.weight_mask", "module.model.decoder.layers.9.self_attn.k_proj.weight_orig", "module.model.decoder.layers.9.self_attn.k_proj.weight_mask", "module.model.decoder.layers.9.self_attn.v_proj.weight_orig", "module.model.decoder.layers.9.self_attn.v_proj.weight_mask", "module.model.decoder.layers.9.self_attn.q_proj.weight_orig", "module.model.decoder.layers.9.self_attn.q_proj.weight_mask", "module.model.decoder.layers.9.self_attn.out_proj.weight_orig", "module.model.decoder.layers.9.self_attn.out_proj.weight_mask", "module.model.decoder.layers.9.fc1.weight_orig", "module.model.decoder.layers.9.fc1.weight_mask", "module.model.decoder.layers.9.fc2.weight_orig", "module.model.decoder.layers.9.fc2.weight_mask", "module.model.decoder.layers.10.self_attn.k_proj.weight_orig", "module.model.decoder.layers.10.self_attn.k_proj.weight_mask", "module.model.decoder.layers.10.self_attn.v_proj.weight_orig", "module.model.decoder.layers.10.self_attn.v_proj.weight_mask", "module.model.decoder.layers.10.self_attn.q_proj.weight_orig", "module.model.decoder.layers.10.self_attn.q_proj.weight_mask", "module.model.decoder.layers.10.self_attn.out_proj.weight_orig", "module.model.decoder.layers.10.self_attn.out_proj.weight_mask", "module.model.decoder.layers.10.fc1.weight_orig", "module.model.decoder.layers.10.fc1.weight_mask", "module.model.decoder.layers.10.fc2.weight_orig", "module.model.decoder.layers.10.fc2.weight_mask", "module.model.decoder.layers.11.self_attn.k_proj.weight_orig", "module.model.decoder.layers.11.self_attn.k_proj.weight_mask", "module.model.decoder.layers.11.self_attn.v_proj.weight_orig", "module.model.decoder.layers.11.self_attn.v_proj.weight_mask", "module.model.decoder.layers.11.self_attn.q_proj.weight_orig", "module.model.decoder.layers.11.self_attn.q_proj.weight_mask", "module.model.decoder.layers.11.self_attn.out_proj.weight_orig", "module.model.decoder.layers.11.self_attn.out_proj.weight_mask", "module.model.decoder.layers.11.fc1.weight_orig", "module.model.decoder.layers.11.fc1.weight_mask", "module.model.decoder.layers.11.fc2.weight_orig", "module.model.decoder.layers.11.fc2.weight_mask", "module.model.decoder.layers.12.self_attn.k_proj.weight_orig", "module.model.decoder.layers.12.self_attn.k_proj.weight_mask", "module.model.decoder.layers.12.self_attn.v_proj.weight_orig", "module.model.decoder.layers.12.self_attn.v_proj.weight_mask", "module.model.decoder.layers.12.self_attn.q_proj.weight_orig", "module.model.decoder.layers.12.self_attn.q_proj.weight_mask", "module.model.decoder.layers.12.self_attn.out_proj.weight_orig", "module.model.decoder.layers.12.self_attn.out_proj.weight_mask", "module.model.decoder.layers.12.fc1.weight_orig", "module.model.decoder.layers.12.fc1.weight_mask", "module.model.decoder.layers.12.fc2.weight_orig", "module.model.decoder.layers.12.fc2.weight_mask", "module.model.decoder.layers.13.self_attn.k_proj.weight_orig", "module.model.decoder.layers.13.self_attn.k_proj.weight_mask", "module.model.decoder.layers.13.self_attn.v_proj.weight_orig", "module.model.decoder.layers.13.self_attn.v_proj.weight_mask", "module.model.decoder.layers.13.self_attn.q_proj.weight_orig", "module.model.decoder.layers.13.self_attn.q_proj.weight_mask", "module.model.decoder.layers.13.self_attn.out_proj.weight_orig", "module.model.decoder.layers.13.self_attn.out_proj.weight_mask", "module.model.decoder.layers.13.fc1.weight_orig", "module.model.decoder.layers.13.fc1.weight_mask", "module.model.decoder.layers.13.fc2.weight_orig", "module.model.decoder.layers.13.fc2.weight_mask", "module.model.decoder.layers.14.self_attn.k_proj.weight_orig", "module.model.decoder.layers.14.self_attn.k_proj.weight_mask", "module.model.decoder.layers.14.self_attn.v_proj.weight_orig", "module.model.decoder.layers.14.self_attn.v_proj.weight_mask", "module.model.decoder.layers.14.self_attn.q_proj.weight_orig", "module.model.decoder.layers.14.self_attn.q_proj.weight_mask", "module.model.decoder.layers.14.self_attn.out_proj.weight_orig", "module.model.decoder.layers.14.self_attn.out_proj.weight_mask", "module.model.decoder.layers.14.fc1.weight_orig", "module.model.decoder.layers.14.fc1.weight_mask", "module.model.decoder.layers.14.fc2.weight_orig", "module.model.decoder.layers.14.fc2.weight_mask", "module.model.decoder.layers.15.self_attn.k_proj.weight_orig", "module.model.decoder.layers.15.self_attn.k_proj.weight_mask", "module.model.decoder.layers.15.self_attn.v_proj.weight_orig", "module.model.decoder.layers.15.self_attn.v_proj.weight_mask", "module.model.decoder.layers.15.self_attn.q_proj.weight_orig", "module.model.decoder.layers.15.self_attn.q_proj.weight_mask", "module.model.decoder.layers.15.self_attn.out_proj.weight_orig", "module.model.decoder.layers.15.self_attn.out_proj.weight_mask", "module.model.decoder.layers.15.fc1.weight_orig", "module.model.decoder.layers.15.fc1.weight_mask", "module.model.decoder.layers.15.fc2.weight_orig", "module.model.decoder.layers.15.fc2.weight_mask", "module.model.decoder.layers.16.self_attn.k_proj.weight_orig", "module.model.decoder.layers.16.self_attn.k_proj.weight_mask", "module.model.decoder.layers.16.self_attn.v_proj.weight_orig", "module.model.decoder.layers.16.self_attn.v_proj.weight_mask", "module.model.decoder.layers.16.self_attn.q_proj.weight_orig", "module.model.decoder.layers.16.self_attn.q_proj.weight_mask", "module.model.decoder.layers.16.self_attn.out_proj.weight_orig", "module.model.decoder.layers.16.self_attn.out_proj.weight_mask", "module.model.decoder.layers.16.fc1.weight_orig", "module.model.decoder.layers.16.fc1.weight_mask", "module.model.decoder.layers.16.fc2.weight_orig", "module.model.decoder.layers.16.fc2.weight_mask", "module.model.decoder.layers.17.self_attn.k_proj.weight_orig", "module.model.decoder.layers.17.self_attn.k_proj.weight_mask", "module.model.decoder.layers.17.self_attn.v_proj.weight_orig", "module.model.decoder.layers.17.self_attn.v_proj.weight_mask", "module.model.decoder.layers.17.self_attn.q_proj.weight_orig", "module.model.decoder.layers.17.self_attn.q_proj.weight_mask", "module.model.decoder.layers.17.self_attn.out_proj.weight_orig", "module.model.decoder.layers.17.self_attn.out_proj.weight_mask", "module.model.decoder.layers.17.fc1.weight_orig", "module.model.decoder.layers.17.fc1.weight_mask", "module.model.decoder.layers.17.fc2.weight_orig", "module.model.decoder.layers.17.fc2.weight_mask", "module.model.decoder.layers.18.self_attn.k_proj.weight_orig", "module.model.decoder.layers.18.self_attn.k_proj.weight_mask", "module.model.decoder.layers.18.self_attn.v_proj.weight_orig", "module.model.decoder.layers.18.self_attn.v_proj.weight_mask", "module.model.decoder.layers.18.self_attn.q_proj.weight_orig", "module.model.decoder.layers.18.self_attn.q_proj.weight_mask", "module.model.decoder.layers.18.self_attn.out_proj.weight_orig", "module.model.decoder.layers.18.self_attn.out_proj.weight_mask", "module.model.decoder.layers.18.fc1.weight_orig", "module.model.decoder.layers.18.fc1.weight_mask", "module.model.decoder.layers.18.fc2.weight_orig", "module.model.decoder.layers.18.fc2.weight_mask", "module.model.decoder.layers.19.self_attn.k_proj.weight_orig", "module.model.decoder.layers.19.self_attn.k_proj.weight_mask", "module.model.decoder.layers.19.self_attn.v_proj.weight_orig", "module.model.decoder.layers.19.self_attn.v_proj.weight_mask", "module.model.decoder.layers.19.self_attn.q_proj.weight_orig", "module.model.decoder.layers.19.self_attn.q_proj.weight_mask", "module.model.decoder.layers.19.self_attn.out_proj.weight_orig", "module.model.decoder.layers.19.self_attn.out_proj.weight_mask", "module.model.decoder.layers.19.fc1.weight_orig", "module.model.decoder.layers.19.fc1.weight_mask", "module.model.decoder.layers.19.fc2.weight_orig", "module.model.decoder.layers.19.fc2.weight_mask", "module.model.decoder.layers.20.self_attn.k_proj.weight_orig", "module.model.decoder.layers.20.self_attn.k_proj.weight_mask", "module.model.decoder.layers.20.self_attn.v_proj.weight_orig", "module.model.decoder.layers.20.self_attn.v_proj.weight_mask", "module.model.decoder.layers.20.self_attn.q_proj.weight_orig", "module.model.decoder.layers.20.self_attn.q_proj.weight_mask", "module.model.decoder.layers.20.self_attn.out_proj.weight_orig", "module.model.decoder.layers.20.self_attn.out_proj.weight_mask", "module.model.decoder.layers.20.fc1.weight_orig", "module.model.decoder.layers.20.fc1.weight_mask", "module.model.decoder.layers.20.fc2.weight_orig", "module.model.decoder.layers.20.fc2.weight_mask", "module.model.decoder.layers.21.self_attn.k_proj.weight_orig", "module.model.decoder.layers.21.self_attn.k_proj.weight_mask", "module.model.decoder.layers.21.self_attn.v_proj.weight_orig", "module.model.decoder.layers.21.self_attn.v_proj.weight_mask", "module.model.decoder.layers.21.self_attn.q_proj.weight_orig", "module.model.decoder.layers.21.self_attn.q_proj.weight_mask", "module.model.decoder.layers.21.self_attn.out_proj.weight_orig", "module.model.decoder.layers.21.self_attn.out_proj.weight_mask", "module.model.decoder.layers.21.fc1.weight_orig", "module.model.decoder.layers.21.fc1.weight_mask", "module.model.decoder.layers.21.fc2.weight_orig", "module.model.decoder.layers.21.fc2.weight_mask", "module.model.decoder.layers.22.self_attn.k_proj.weight_orig", "module.model.decoder.layers.22.self_attn.k_proj.weight_mask", "module.model.decoder.layers.22.self_attn.v_proj.weight_orig", "module.model.decoder.layers.22.self_attn.v_proj.weight_mask", "module.model.decoder.layers.22.self_attn.q_proj.weight_orig", "module.model.decoder.layers.22.self_attn.q_proj.weight_mask", "module.model.decoder.layers.22.self_attn.out_proj.weight_orig", "module.model.decoder.layers.22.self_attn.out_proj.weight_mask", "module.model.decoder.layers.22.fc1.weight_orig", "module.model.decoder.layers.22.fc1.weight_mask", "module.model.decoder.layers.22.fc2.weight_orig", "module.model.decoder.layers.22.fc2.weight_mask", "module.model.decoder.layers.23.self_attn.k_proj.weight_orig", "module.model.decoder.layers.23.self_attn.k_proj.weight_mask", "module.model.decoder.layers.23.self_attn.v_proj.weight_orig", "module.model.decoder.layers.23.self_attn.v_proj.weight_mask", "module.model.decoder.layers.23.self_attn.q_proj.weight_orig", "module.model.decoder.layers.23.self_attn.q_proj.weight_mask", "module.model.decoder.layers.23.self_attn.out_proj.weight_orig", "module.model.decoder.layers.23.self_attn.out_proj.weight_mask", "module.model.decoder.layers.23.fc1.weight_orig", "module.model.decoder.layers.23.fc1.weight_mask", "module.model.decoder.layers.23.fc2.weight_orig", "module.model.decoder.layers.23.fc2.weight_mask". 
	Unexpected key(s) in state_dict: "module.model.decoder.project_out.weight", "module.model.decoder.project_in.weight", "module.model.decoder.layers.0.self_attn.k_proj.weight", "module.model.decoder.layers.0.self_attn.v_proj.weight", "module.model.decoder.layers.0.self_attn.q_proj.weight", "module.model.decoder.layers.0.self_attn.out_proj.weight", "module.model.decoder.layers.0.fc1.weight", "module.model.decoder.layers.0.fc2.weight", "module.model.decoder.layers.1.self_attn.k_proj.weight", "module.model.decoder.layers.1.self_attn.v_proj.weight", "module.model.decoder.layers.1.self_attn.q_proj.weight", "module.model.decoder.layers.1.self_attn.out_proj.weight", "module.model.decoder.layers.1.fc1.weight", "module.model.decoder.layers.1.fc2.weight", "module.model.decoder.layers.2.self_attn.k_proj.weight", "module.model.decoder.layers.2.self_attn.v_proj.weight", "module.model.decoder.layers.2.self_attn.q_proj.weight", "module.model.decoder.layers.2.self_attn.out_proj.weight", "module.model.decoder.layers.2.fc1.weight", "module.model.decoder.layers.2.fc2.weight", "module.model.decoder.layers.3.self_attn.k_proj.weight", "module.model.decoder.layers.3.self_attn.v_proj.weight", "module.model.decoder.layers.3.self_attn.q_proj.weight", "module.model.decoder.layers.3.self_attn.out_proj.weight", "module.model.decoder.layers.3.fc1.weight", "module.model.decoder.layers.3.fc2.weight", "module.model.decoder.layers.4.self_attn.k_proj.weight", "module.model.decoder.layers.4.self_attn.v_proj.weight", "module.model.decoder.layers.4.self_attn.q_proj.weight", "module.model.decoder.layers.4.self_attn.out_proj.weight", "module.model.decoder.layers.4.fc1.weight", "module.model.decoder.layers.4.fc2.weight", "module.model.decoder.layers.5.self_attn.k_proj.weight", "module.model.decoder.layers.5.self_attn.v_proj.weight", "module.model.decoder.layers.5.self_attn.q_proj.weight", "module.model.decoder.layers.5.self_attn.out_proj.weight", "module.model.decoder.layers.5.fc1.weight", "module.model.decoder.layers.5.fc2.weight", "module.model.decoder.layers.6.self_attn.k_proj.weight", "module.model.decoder.layers.6.self_attn.v_proj.weight", "module.model.decoder.layers.6.self_attn.q_proj.weight", "module.model.decoder.layers.6.self_attn.out_proj.weight", "module.model.decoder.layers.6.fc1.weight", "module.model.decoder.layers.6.fc2.weight", "module.model.decoder.layers.7.self_attn.k_proj.weight", "module.model.decoder.layers.7.self_attn.v_proj.weight", "module.model.decoder.layers.7.self_attn.q_proj.weight", "module.model.decoder.layers.7.self_attn.out_proj.weight", "module.model.decoder.layers.7.fc1.weight", "module.model.decoder.layers.7.fc2.weight", "module.model.decoder.layers.8.self_attn.k_proj.weight", "module.model.decoder.layers.8.self_attn.v_proj.weight", "module.model.decoder.layers.8.self_attn.q_proj.weight", "module.model.decoder.layers.8.self_attn.out_proj.weight", "module.model.decoder.layers.8.fc1.weight", "module.model.decoder.layers.8.fc2.weight", "module.model.decoder.layers.9.self_attn.k_proj.weight", "module.model.decoder.layers.9.self_attn.v_proj.weight", "module.model.decoder.layers.9.self_attn.q_proj.weight", "module.model.decoder.layers.9.self_attn.out_proj.weight", "module.model.decoder.layers.9.fc1.weight", "module.model.decoder.layers.9.fc2.weight", "module.model.decoder.layers.10.self_attn.k_proj.weight", "module.model.decoder.layers.10.self_attn.v_proj.weight", "module.model.decoder.layers.10.self_attn.q_proj.weight", "module.model.decoder.layers.10.self_attn.out_proj.weight", "module.model.decoder.layers.10.fc1.weight", "module.model.decoder.layers.10.fc2.weight", "module.model.decoder.layers.11.self_attn.k_proj.weight", "module.model.decoder.layers.11.self_attn.v_proj.weight", "module.model.decoder.layers.11.self_attn.q_proj.weight", "module.model.decoder.layers.11.self_attn.out_proj.weight", "module.model.decoder.layers.11.fc1.weight", "module.model.decoder.layers.11.fc2.weight", "module.model.decoder.layers.12.self_attn.k_proj.weight", "module.model.decoder.layers.12.self_attn.v_proj.weight", "module.model.decoder.layers.12.self_attn.q_proj.weight", "module.model.decoder.layers.12.self_attn.out_proj.weight", "module.model.decoder.layers.12.fc1.weight", "module.model.decoder.layers.12.fc2.weight", "module.model.decoder.layers.13.self_attn.k_proj.weight", "module.model.decoder.layers.13.self_attn.v_proj.weight", "module.model.decoder.layers.13.self_attn.q_proj.weight", "module.model.decoder.layers.13.self_attn.out_proj.weight", "module.model.decoder.layers.13.fc1.weight", "module.model.decoder.layers.13.fc2.weight", "module.model.decoder.layers.14.self_attn.k_proj.weight", "module.model.decoder.layers.14.self_attn.v_proj.weight", "module.model.decoder.layers.14.self_attn.q_proj.weight", "module.model.decoder.layers.14.self_attn.out_proj.weight", "module.model.decoder.layers.14.fc1.weight", "module.model.decoder.layers.14.fc2.weight", "module.model.decoder.layers.15.self_attn.k_proj.weight", "module.model.decoder.layers.15.self_attn.v_proj.weight", "module.model.decoder.layers.15.self_attn.q_proj.weight", "module.model.decoder.layers.15.self_attn.out_proj.weight", "module.model.decoder.layers.15.fc1.weight", "module.model.decoder.layers.15.fc2.weight", "module.model.decoder.layers.16.self_attn.k_proj.weight", "module.model.decoder.layers.16.self_attn.v_proj.weight", "module.model.decoder.layers.16.self_attn.q_proj.weight", "module.model.decoder.layers.16.self_attn.out_proj.weight", "module.model.decoder.layers.16.fc1.weight", "module.model.decoder.layers.16.fc2.weight", "module.model.decoder.layers.17.self_attn.k_proj.weight", "module.model.decoder.layers.17.self_attn.v_proj.weight", "module.model.decoder.layers.17.self_attn.q_proj.weight", "module.model.decoder.layers.17.self_attn.out_proj.weight", "module.model.decoder.layers.17.fc1.weight", "module.model.decoder.layers.17.fc2.weight", "module.model.decoder.layers.18.self_attn.k_proj.weight", "module.model.decoder.layers.18.self_attn.v_proj.weight", "module.model.decoder.layers.18.self_attn.q_proj.weight", "module.model.decoder.layers.18.self_attn.out_proj.weight", "module.model.decoder.layers.18.fc1.weight", "module.model.decoder.layers.18.fc2.weight", "module.model.decoder.layers.19.self_attn.k_proj.weight", "module.model.decoder.layers.19.self_attn.v_proj.weight", "module.model.decoder.layers.19.self_attn.q_proj.weight", "module.model.decoder.layers.19.self_attn.out_proj.weight", "module.model.decoder.layers.19.fc1.weight", "module.model.decoder.layers.19.fc2.weight", "module.model.decoder.layers.20.self_attn.k_proj.weight", "module.model.decoder.layers.20.self_attn.v_proj.weight", "module.model.decoder.layers.20.self_attn.q_proj.weight", "module.model.decoder.layers.20.self_attn.out_proj.weight", "module.model.decoder.layers.20.fc1.weight", "module.model.decoder.layers.20.fc2.weight", "module.model.decoder.layers.21.self_attn.k_proj.weight", "module.model.decoder.layers.21.self_attn.v_proj.weight", "module.model.decoder.layers.21.self_attn.q_proj.weight", "module.model.decoder.layers.21.self_attn.out_proj.weight", "module.model.decoder.layers.21.fc1.weight", "module.model.decoder.layers.21.fc2.weight", "module.model.decoder.layers.22.self_attn.k_proj.weight", "module.model.decoder.layers.22.self_attn.v_proj.weight", "module.model.decoder.layers.22.self_attn.q_proj.weight", "module.model.decoder.layers.22.self_attn.out_proj.weight", "module.model.decoder.layers.22.fc1.weight", "module.model.decoder.layers.22.fc2.weight", "module.model.decoder.layers.23.self_attn.k_proj.weight", "module.model.decoder.layers.23.self_attn.v_proj.weight", "module.model.decoder.layers.23.self_attn.q_proj.weight", "module.model.decoder.layers.23.self_attn.out_proj.weight", "module.model.decoder.layers.23.fc1.weight", "module.model.decoder.layers.23.fc2.weight". 

In [9]:
def get_prop_zeros(model):
    return torch.sum(model.module.get_decoder().layers[0].self_attn.k_proj.weight == 0) / (torch.sum(model.module.get_decoder().layers[0].self_attn.k_proj.weight == 0) + torch.sum(model.module.get_decoder().layers[0].self_attn.k_proj.weight != 0))

print(get_prop_zeros(loaded_model))
print(get_prop_zeros(model))

tensor(0., device='cuda:0')
tensor(0.5005, device='cuda:0')


In [None]:
prune_list = []
layer_blacklist = ['', 'module','module.model','module.model.decoder',
                   'module.model.decoder.embed_tokens',
                   'module.model.decoder.embed_tokens',
                   'module.model.decoder.embed_positions']
for name, module in model.named_modules():
    # skip the embed layer or skip norms which have 1 dimension
    if name in layer_blacklist or 'norm' in name or not isinstance(module, torch.nn.Module):
        continue
    prune_list.append((module, 'weight'))

# prune to 0s
class ThresholdPruning(prune.BasePruningMethod):
    PRUNING_TYPE = "unstructured"

    # default threshold is 0, prunes weights that are already 0 (for training)
    def __init__(self, threshold=1e-8):
        self.threshold = threshold

    def compute_mask(self, tensor, default_mask):
        return torch.abs(tensor) >= self.threshold
    
prune.global_unstructured(prune_list,
                          pruning_method=ThresholdPruning)