In [5]:
import itertools
import os

import torch
from torch.nn.utils import prune

from tqdm import tqdm

from transformers import AutoTokenizer, OPTForCausalLM, pipeline
from datasets import load_dataset

from calculate_mask import calculate_mask
from inverse_hessian import inverse_hessian
from input_prehooks import put_input_hooks
from testing_module import calculate_perp

In [None]:
# Device: prefer the GPU whenever one is visible to torch
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Hugging Face checkpoint shared by the tokenizer, the generation pipeline and the pruning loop
model_name = "facebook/opt-125m"

# C4 (English) in streaming mode; the calibration code iterates over dataset['train'] lazily
dataset = load_dataset('c4', 'en', streaming=True)

# Tokenizer for the same checkpoint, padding on the left
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')

# Text-generation pipeline around the same checkpoint (not used in the cells shown here)
generator = pipeline('text-generation', model=model_name)

In [None]:
# Calibrate model (get inputs to each layer with calibration data)

# --- calibration settings ---
calibration_size = 128         # number of C4 training texts fed through the model
token_length = 1024            # each text is padded / truncated to this many tokens
calibrate_on_cpu = False       # NOTE(review): not read anywhere in this cell
calibration_batch_size = 2     # texts per forward pass during calibration

# --- pruning settings ---
EPSILON = 1e-8    # passed to inverse_hessian as `epsilon`
B = 128           # passed to calculate_mask as `B` (presumably a column block size — confirm)
Bs = 64           # passed to calculate_mask as `Bs` (presumably a mask-selection block size — confirm)

# Parameter names that are never pruned (the embedding tables)
layer_blacklist = [
    'model.decoder.embed_tokens.weight',
    'model.decoder.embed_tokens.bias',
    'model.decoder.embed_positions.weight',
]

# run model on batches of calibration data so the input hooks can record each layer's inputs
def split_model_calibration(model, num_samples=None, batch_size=None):
    """Feed the first calibration texts from ``dataset['train']`` through ``model``.

    The texts are tokenized in batches (padded / truncated to ``token_length``
    tokens, moved to ``device``). Each batch gets one forward pass under
    ``torch.no_grad()``. Nothing is returned: the passes are run only for the
    side effects of whatever hooks are attached to ``model``.
    NOTE(review): presumably the hooks from ``put_input_hooks`` — confirm.

    Args:
        model: the model to calibrate; its hooks should already be attached.
        num_samples: number of texts to use. Defaults to the notebook-level
            ``calibration_size``. Negative values are treated as 0.
        batch_size: texts per forward pass. Defaults to the notebook-level
            ``calibration_batch_size``.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if num_samples is None:
        num_samples = calibration_size
    if batch_size is None:
        batch_size = calibration_batch_size
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    num_samples = max(num_samples, 0)

    def _forward(texts):
        # One forward pass over a batch of raw strings.
        encoded_input = tokenizer(texts, return_tensors="pt", padding="max_length",
                                  max_length=token_length, truncation=True).to(device=device)
        with torch.no_grad():
            # No `labels=`: the returned loss was never used, so don't compute it.
            model(**encoded_input)

    batch_sentences = []
    for data in tqdm(itertools.islice(dataset['train'], num_samples), total=num_samples):
        batch_sentences.append(data['text'])
        if len(batch_sentences) == batch_size:
            _forward(batch_sentences)
            batch_sentences = []

    # Flush the final, partial batch. The previous loop dropped it whenever
    # num_samples was not a multiple of batch_size.
    if batch_sentences:
        _forward(batch_sentences)
            
# function to get module name from parameter name
def get_module_name(param_name):
    """Split a dotted parameter name into (owning module name, attribute).

    ``"model.decoder.layers.0.fc1.weight"`` -> ``("model.decoder.layers.0.fc1", "weight")``.
    A name that ends in neither ``.bias`` nor ``.weight`` yields ``(None, None)``.
    """
    for attr in ("bias", "weight"):
        suffix = "." + attr
        if param_name.endswith(suffix):
            return param_name[:len(param_name) - len(suffix)], attr
    return None, None
    
# Prune a fresh copy of the model at each target sparsity and save its weights.
for SPARSENESS in [0.2, 0.4, 0.6, 0.8]:
    print(f'On SPARSENESS {SPARSENESS}')
    # Load model with pre-trained head
    model = OPTForCausalLM.from_pretrained(model_name, output_attentions=True,
                                           output_hidden_states=True).to(device=device) # type: ignore
    model.eval()

    # Attach the input hooks, then run the calibration data through the model.
    # NOTE(review): assumes put_input_hooks fills `features` with each module's
    # recorded input, keyed by module name (as indexed below) — confirm.
    features = {}
    put_input_hooks(model=model, features=features, feature_storage_device='cpu')
    split_model_calibration(model)

    # Name -> object lookups. The parameter names are snapshotted before any
    # pruning, because pruning re-registers the parameters it touches.
    module_lookup_dict = dict(model.named_modules())
    param_lookup_dict = dict(model.named_parameters())
    param_names = list(param_lookup_dict)

    print(f'SPARSIFYING SPARSENESS {SPARSENESS}')
    with torch.no_grad():
        for name in tqdm(param_names):
            param = param_lookup_dict[name]

            # skip the embedding layers
            if name in layer_blacklist:
                continue

            # Skip 1-D parameters. This covers layer-norm weights AND every bias,
            # so in practice only 2-D weight matrices reach the pruning below.
            if param.dim() < 2:
                continue

            module_name, param_type = get_module_name(name)
            if param_type not in ("weight", "bias"):
                continue

            # input recorded for this parameter's module during calibration
            layer_input = features[module_name].to(device=device)
            # calculate inverse hessian
            # check if input is flattened e.g. from 8,512,768 to 4096,768
            if layer_input.dim() == 2:
                inv_hess = inverse_hessian(torch.transpose(layer_input, 0, 1), epsilon=EPSILON,
                                           flattened=True).to(device=device)
            else:
                inv_hess = inverse_hessian(torch.transpose(layer_input, 1, 2), epsilon=EPSILON,
                                           flattened=False).to(device=device)

            mask = calculate_mask(W=param, H_inv=inv_hess, p=SPARSENESS, B=B, Bs=Bs)

            module = module_lookup_dict[module_name]
            prune.custom_from_mask(module=module, name=param_type, mask=mask)
            # Make the pruning permanent. Without this, the module keeps a dense
            # `<name>_orig` parameter plus a `<name>_mask` buffer, and the saved
            # state_dict cannot be loaded into a plain OPTForCausalLM.
            prune.remove(module, param_type)

    print(f'SAVING SPARSENESS {SPARSENESS}')
    # e.g. "facebook/opt-125m" -> "opt-125m-0.2"
    pruned_model_name = f'{model_name.split("/")[-1]}-{SPARSENESS}'
    os.makedirs('pruned_models', exist_ok=True)
    torch.save(model.state_dict(), os.path.join('pruned_models', f'{pruned_model_name}.pt'))