In [1]:
import torch
from torch.nn.utils import prune

from tqdm import tqdm

from transformers import AutoTokenizer, OPTForCausalLM, pipeline
from datasets import load_dataset

from calculate_mask import calculate_mask
from inverse_hessian import calc_inverse_hessian
from input_prehooks import put_input_hooks
import gc

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Run on the GPU when one is visible; every model/tensor below is moved to `device`.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Checkpoint selection: change `model_size` (e.g. "opt-1.3b") to switch models.
model_size = "opt-125m"
model_name = f"facebook/{model_size}"

# C4 (English split) is streamed, so it is iterated lazily instead of downloaded in full.
# NOTE(review): newer `datasets` releases may require the hub id "allenai/c4" — confirm.
dataset = load_dataset('c4', 'en', streaming=True)

# Tokenizer pads on the left; used by the calibration pass with fixed-length padding.
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')

In [3]:
# --- Calibration settings ----------------------------------------------------
# Consumed by `split_model_calibration`: how many documents to stream from the
# training split, the fixed token length every document is padded/truncated to,
# and how many documents are sent through the model per forward pass.
calibration_size, token_length, calibration_batch_size = 4, 512, 2

# --- SparseGPT hyper-parameters ---------------------------------------------
# Forwarded unchanged to `sparsegpt_prune`; their exact meaning lives in
# SparseGPT_pruning (presumably a damping epsilon and two block sizes — TODO confirm).
EPSILON, B, Bs = 1e-8, 4, 2

# run model on batches of calibration data so previously registered input hooks see every layer's inputs
def split_model_calibration(model, n_samples=None, batch_size=None, max_length=None):
    """Stream calibration documents through ``model`` in fixed-size batches.

    The forward passes are executed only for their side effects: hooks that were
    registered on ``model`` beforehand (e.g. via ``put_input_hooks``) observe the
    inputs of each layer. Nothing is returned.

    Documents come from the module-level ``dataset['train']``, are tokenized with
    the module-level ``tokenizer`` and moved to the module-level ``device``.

    Args:
        model: causal LM, called as ``model(**encoded_input, labels=input_ids)``.
        n_samples: number of documents to use; defaults to the global
            ``calibration_size`` (read at call time).
        batch_size: documents per forward pass; defaults to the global
            ``calibration_batch_size``.
        max_length: every document is padded/truncated to this many tokens;
            defaults to the global ``token_length``.

    Raises:
        ValueError: if the effective batch size is smaller than 1.
    """
    import itertools

    # Resolve defaults at call time so edits to the config cell are honoured.
    n_samples = calibration_size if n_samples is None else n_samples
    batch_size = calibration_batch_size if batch_size is None else batch_size
    max_length = token_length if max_length is None else max_length
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    def _run_batch(sentences):
        # Gradients are never needed here: the pass only feeds the hooks.
        with torch.no_grad():
            encoded_input = tokenizer(sentences, return_tensors="pt",
                                      padding="max_length", max_length=max_length,
                                      truncation=True).to(device=device)
            model(**encoded_input, labels=encoded_input.input_ids)

    batch_sentences = []
    # islice stops after exactly n_samples documents, so no extra item is pulled from the stream.
    samples = itertools.islice(dataset['train'], n_samples)
    for data in tqdm(samples, total=n_samples):
        batch_sentences.append(data['text'])
        if len(batch_sentences) == batch_size:
            _run_batch(batch_sentences)
            batch_sentences = []

    # Flush a trailing partial batch (happens when n_samples is not a multiple of batch_size).
    if batch_sentences:
        _run_batch(batch_sentences)

In [6]:
# SparseGPT fine tune loop
import torch

def get_prop_zeros(model):
    """Return the fraction of exactly-zero entries in layer 0's ``self_attn.k_proj.weight``.

    A cheap spot-check of the sparsity actually achieved. Works both for a bare
    model exposing ``get_decoder()`` and for a wrapper (e.g.
    ``torch.nn.DataParallel``) that holds the model under ``.module``.

    Returns:
        0-dim float tensor with the proportion of zeros.
    """
    # Unwrap DataParallel-style wrappers; a bare model has no `.module` attribute.
    inner = getattr(model, 'module', model)
    weight = inner.get_decoder().layers[0].self_attn.k_proj.weight
    return torch.sum(weight == 0) / torch.numel(weight)

# now, fine tune loop
# sparseness is defined as proportion of nonzeros (opposite of intuitive)
# sparseness_sequence = [.9, .8, .7, .6, .5, .4, .3, .2]
# sparseness_sequence = [.9, .5, .4]

# model: the model to iteratively prune and fine-tune; it is modified in place
# model_size: used to name the saved checkpoints (e.g. "opt-125m")
# sparseness_sequence: sparseness levels to step through; sparseness is the fraction of weights KEPT (0.8 sparseness = 20% zeros)
# training_data and tokenizer: used for fine-tuning; training_data should already be preprocessed (tokenized, with the torch format set)
def iterative_sparsegpt_prune_tune(model, model_size, sparseness_sequence, feature_hessians, EPSILON, B, Bs, tokenizer, EPOCH_COUNT):
    """Run ``sparsegpt_prune`` on ``model`` once per entry of ``sparseness_sequence``.

    Only the pruning step is active. ``tokenizer`` and ``EPOCH_COUNT`` are
    accepted for the planned fine-tuning step but are currently unused.

    TODO: the intended per-level loop also (1) fine-tunes the pruned model,
    (2) unmasks it with ``unmask_model``, (3) prints ``get_prop_zeros`` and
    (4) saves ``pruned_models/{model_size}-finetuned-{sparseness}.pt``; those
    steps are disabled.
    """
    for level in sparseness_sequence:
        # NOTE(review): the main pruning cell calls sparsegpt_prune without a
        # `model_name` keyword — confirm the function accepts it.
        prune_kwargs = dict(model=model, model_name=model_size,
                            feature_hessians=feature_hessians, SPARSENESS=level,
                            EPSILON=EPSILON, B=B, Bs=Bs)
        sparsegpt_prune(**prune_kwargs)
def iterative_cerebras_prune_tune(model, model_size, sparseness_sequence, training_data, tokenizer, EPOCH_COUNT, save_dir='pruned_models'):
    """Iteratively prune, mask, fine-tune and unmask ``model`` in place, checkpointing each level.

    For every ``sparseness`` (fraction of weights kept) in ``sparseness_sequence``:
      1. ``mask_lowest(amount=1 - sparseness)`` prunes the model (presumably
         the lowest-magnitude fraction — defined outside this view),
      2. ``mask_from_pruned`` attaches ZeroPruning masks so zeroed weights stay
         zero while training,
      3. ``fine_tune`` trains for ``EPOCH_COUNT`` epochs (defined outside this view),
      4. ``unmask_model`` folds the masks back into plain weights,
    then prints the layer-0 k_proj zero fraction and saves the state dict to
    ``{save_dir}/{model_size}-cerebras-tune-and-prune-{sparseness}.pt``.

    Args:
        save_dir: output directory; created if missing. Defaults to ``'pruned_models'``.
    """
    import os

    for sparseness in sparseness_sequence:
        mask_lowest(model=model, amount=1-sparseness)

        # activate masks
        mask_from_pruned(model=model)

        fine_tune(model=model, training_data=training_data, EPOCH_COUNT=EPOCH_COUNT, tokenizer=tokenizer)

        # deactivate masks
        unmask_model(model=model)

        print(f"proportion of zeros: {get_prop_zeros(model)}")

        pruned_model_name = f'{model_size}-cerebras-tune-and-prune-{sparseness}'
        # torch.save does not create missing directories.
        os.makedirs(save_dir, exist_ok=True)
        torch.save(model.state_dict(), os.path.join(save_dir, f'{pruned_model_name}.pt'))

In [7]:
# function to get module name from parameter name
def get_module_name(param_name):
    """Split a parameter name into ``(owning module name, parameter kind)``.

    Recognises plain ``.weight`` / ``.bias`` names as well as the ``*_orig``
    names that torch.nn.utils.prune registers while a mask is active, e.g.
    ``"layers.0.fc1.weight_orig"`` -> ``("layers.0.fc1", "weight")``.
    Any other name yields ``(None, None)``.
    """
    suffix_table = (
        (".bias", "bias"),
        (".weight", "weight"),
        (".bias_orig", "bias"),
        (".weight_orig", "weight"),
    )
    for suffix, kind in suffix_table:
        if param_name.endswith(suffix):
            return param_name[:len(param_name) - len(suffix)], kind
    return None, None

# load model without masks
def load_unmasked_model(existing_model, state_dict_path, map_location="cpu"):
    """Load a plain (mask-free) state dict from ``state_dict_path`` into ``existing_model`` in place.

    Args:
        existing_model: module whose parameter names match the checkpoint keys.
        state_dict_path: file written by ``torch.save(model.state_dict(), ...)``.
        map_location: where tensors are materialised while reading the file.
            Defaults to CPU so a checkpoint saved on a GPU also loads on a
            CPU-only machine; ``load_state_dict`` then copies the values into
            the model's parameters on whatever device they already live.
    """
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    state_dict = torch.load(state_dict_path, map_location=map_location)
    existing_model.load_state_dict(state_dict)

# prune 0s to a mask, to make training easier (ostensibly)
class ZeroPruning(prune.BasePruningMethod):
    """Unstructured pruning method whose mask keeps every nonzero entry.

    Applying it turns weights that are already exactly zero into a pruning
    mask: the module then computes ``weight = weight_orig * weight_mask``, so
    those entries stay zero while the model trains.
    """

    PRUNING_TYPE = "unstructured"

    def __init__(self):
        # No hyper-parameters: the mask depends only on the tensor's values.
        super().__init__()

    def compute_mask(self, tensor, default_mask):
        # True = keep, False = prune; default_mask is not consulted.
        keep = tensor.ne(0)
        return keep

# apply pytorch mask in place of 0 weights to make backpropagation easier for training
default_opt_blacklist = ['model.decoder.embed_tokens', 'model.decoder.embed_positions']


def _iter_prunable_weights(model, module_blacklist):
    """Yield ``(module, "weight")`` for every >=2-D weight of ``model`` whose module is not blacklisted.

    Parameters are matched by name through ``get_module_name``, so both plain
    ``.weight`` and currently-masked ``.weight_orig`` parameters are found.
    Biases and 1-D weights (e.g. layer norms) are skipped.
    """
    modules_by_name = dict(model.named_modules())
    # Snapshot the parameter list: callers re-register parameters
    # (weight <-> weight_orig) between yields.
    for param_name, param in list(model.named_parameters()):
        module_name, param_type = get_module_name(param_name)
        if module_name is None or module_name in module_blacklist or param_type != "weight":
            continue
        if param.dim() < 2:
            continue
        yield modules_by_name[module_name], param_type


def mask_from_pruned(model, module_blacklist=default_opt_blacklist):
    """Attach a ZeroPruning mask to every eligible weight so its current zeros stay zero.

    Afterwards each affected module holds ``weight_orig`` (parameter) and
    ``weight_mask`` (buffer), and ``weight`` is their product. Undo with
    ``unmask_model``.

    Args:
        model: model to mask in place.
        module_blacklist: module names (e.g. embeddings) to leave untouched.
    """
    for module, param_type in _iter_prunable_weights(model, module_blacklist):
        ZeroPruning.apply(module=module, name=param_type)


# unmask model with 0s in place
def unmask_model(model, module_blacklist=default_opt_blacklist):
    """Make masking permanent: fold each mask into its weight and drop the mask.

    ``prune.remove`` replaces ``weight_orig``/``weight_mask`` with a plain
    ``weight`` parameter equal to their product, so pruned entries remain zero.

    Args:
        model: previously masked model, modified in place.
        module_blacklist: module names that were not masked.

    Raises:
        ValueError: from ``prune.remove`` if an eligible weight was never masked.
    """
    for module, param_type in _iter_prunable_weights(model, module_blacklist):
        prune.remove(module=module, name=param_type)
    # Return the allocator blocks freed by dropping the *_orig/*_mask tensors.
    # torch.cuda has no clear_cache(); empty_cache() is a no-op without CUDA.
    torch.cuda.empty_cache()

# load model with masks
def load_masked_model(existing_model, state_dict_path):
    """Load a plain checkpoint into ``existing_model`` and re-attach ZeroPruning masks.

    The checkpoint must contain plain ``weight`` keys (as saved after
    ``unmask_model``). Its zero entries are then turned back into masks by
    ``mask_from_pruned``, so further training keeps them at zero.
    """
    # Order matters: the masks are derived from the freshly loaded values.
    load_unmasked_model(existing_model, state_dict_path)
    mask_from_pruned(model=existing_model)

In [8]:
from input_prehooks import get_feature_storage_name
import gc
from SparseGPT_pruning import sparsegpt_prune
import os

# NOTE(review): not referenced in this cell; kept for any code that reads it.
layer_blacklist = ['model.decoder.embed_tokens.weight', 'model.decoder.embed_tokens.bias', 'model.decoder.embed_positions.weight']

# Prune a fresh pre-trained model at each sparseness level with SparseGPT and
# save its state dict. Calibration hessians are collected once and reused.
SPARSENESS_LIST = [0.5]  # other levels tried: 0.1, 0.2, 0.3, 0.5, 0.7, 0.9
PRUNED_MODEL_DIR = 'pruned_models'
os.makedirs(PRUNED_MODEL_DIR, exist_ok=True)  # torch.save does not create directories

for i, SPARSENESS in enumerate(SPARSENESS_LIST):

    # Load model with pre-trained head (a fresh copy per level, so levels do not compound)
    model = OPTForCausalLM.from_pretrained(model_name, output_attentions=True,
                                           output_hidden_states=True).to(device=device) # type: ignore

    # Hooks store hessians rather than raw inputs (less data), kept on the CPU.
    # They are gathered only on the first iteration: every level starts from
    # the same pre-trained weights, so the same hessians apply.
    print('------------')
    if i == 0:
        feature_hessians = {}
        put_input_hooks(model=model, features=feature_hessians, feature_storage_device='cpu')
        split_model_calibration(model)
        torch.cuda.empty_cache()
    print('-----------')

    sparsegpt_prune(model=model, feature_hessians=feature_hessians, # type: ignore
                    EPSILON=EPSILON, SPARSENESS=SPARSENESS, B=B, Bs=Bs)

    model.eval()
    pruned_model_name = f'{model_size}-test-{SPARSENESS}'
    torch.save(model.state_dict(), os.path.join(PRUNED_MODEL_DIR, f'{pruned_model_name}.pt'))

------------


5it [00:09,  1.88s/it]                       


-----------


100%|██████████| 196/196 [02:04<00:00,  1.58it/s]


In [7]:
from save_pruned_model import load_unmasked_model, load_masked_model

def _load_pretrained_opt():
    """Fresh pre-trained OPT model on `device`, configured like the pruning cell."""
    return OPTForCausalLM.from_pretrained(f'facebook/{model_size}', output_attentions=True,
                                          output_hidden_states=True).to(device=device)  # type: ignore


# Checkpoint written by the pruning loop for SPARSENESS = 0.5.
checkpoint_path = f'pruned_models/{model_size}-test-0.5.pt'

# Plain weights: pruned entries are ordinary zeros.
loaded_model = _load_pretrained_opt()
# loaded_model = torch.nn.DataParallel(loaded_model, device_ids=[0,1,2,3])
load_unmasked_model(loaded_model, checkpoint_path)

# Same weights with ZeroPruning masks re-attached.
loaded_model_2 = _load_pretrained_opt()
load_masked_model(loaded_model_2, checkpoint_path)


def get_prop_zeros(model):
    """Fraction of exactly-zero entries in layer 0's ``self_attn.k_proj.weight``.

    Accepts a bare model or a wrapper (e.g. DataParallel) that holds it under ``.module``.
    """
    inner = getattr(model, 'module', model)
    weight = inner.get_decoder().layers[0].self_attn.k_proj.weight
    return torch.sum(weight == 0) / torch.numel(weight)


# The freshly pruned `model` and both reloads of its checkpoint should report the same value.
print(get_prop_zeros(loaded_model))
print(get_prop_zeros(model))
print(get_prop_zeros(loaded_model_2))

tensor(0.5007)
tensor(0.5007)
tensor(0.5007)


In [None]:
prune_list = []
# Modules to leave unpruned, written without any DataParallel "module." prefix.
layer_blacklist = ['model.decoder.embed_tokens',
                   'model.decoder.embed_positions']

seen_weights = set()
for name, module in model.named_modules():
    weight = getattr(module, 'weight', None)
    # Only modules that own a >=2-D weight can be pruned. Containers (model,
    # decoder, layers, attention blocks), activations and dropout have no
    # `weight`, and global_unstructured fails on them. 1-D norms are dropped too.
    if not isinstance(weight, torch.Tensor) or weight.dim() < 2:
        continue
    # A tied weight (presumably lm_head shares embed_tokens' matrix — confirm)
    # is considered once, by its first owner, like named_parameters() dedup.
    if id(weight) in seen_weights:
        continue
    seen_weights.add(id(weight))
    bare_name = name[len('module.'):] if name.startswith('module.') else name
    # skip the embed layers and any norm layers
    if bare_name in layer_blacklist or 'norm' in name:
        continue
    prune_list.append((module, 'weight'))

# prune to 0s
class ThresholdPruning(prune.BasePruningMethod):
    """Unstructured pruning that keeps entries with ``|w| >= threshold`` and masks the rest.

    The default threshold is 1e-8 (not 0), so exact zeros and any value with
    magnitude below 1e-8 are pruned.
    """

    PRUNING_TYPE = "unstructured"

    def __init__(self, threshold=1e-8):
        self.threshold = threshold

    def compute_mask(self, tensor, default_mask):
        # default_mask is not consulted; the mask depends only on magnitudes.
        magnitude = tensor.abs()
        return magnitude.ge(self.threshold)
    
# Mask, across all collected weights at once, every entry below ThresholdPruning's default threshold.
prune.global_unstructured(
    parameters=prune_list,
    pruning_method=ThresholdPruning,
)

In [None]:
# sparsegpt_prune needs its inputs explicitly; mirror the arguments used in the main pruning loop.
# NOTE(review): `model` is already pruned at `SPARSENESS`; confirm a second pass is intended.
sparsegpt_prune(model=model, feature_hessians=feature_hessians,  # type: ignore
                EPSILON=EPSILON, SPARSENESS=SPARSENESS, B=B, Bs=Bs)