In [1]:
import torch
from torch.nn.utils import prune

from tqdm import tqdm

from transformers import AutoTokenizer, OPTForCausalLM, pipeline
from datasets import load_dataset

from calculate_mask import calculate_mask
from inverse_hessian import calc_inverse_hessian
from input_prehooks import put_input_hooks
from testing_module import calculate_perp

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
#DEVICE
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model_size = "opt-125m"

model_name = f"facebook/{model_size}"
# model_name = "facebook/opt-1.3b"

#Load dataset
dataset = load_dataset('wikitext', 'wikitext-2-raw-v1', streaming=True)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')

In [3]:
# Calibrate model (get inputs to each layer with calibration data)

calibration_size=4
token_length=512
calibration_batch_size=2

EPSILON = 1e-8
B = 4
Bs = 2

# run model on batches of calibration data, then concatenate inputs
def split_model_calibration(model):
    batch_sentences = []
    for i, data in tqdm(enumerate(iter(dataset['train'])), total=calibration_size):
        if i < calibration_size + 1:
            if len(batch_sentences) >= calibration_batch_size:
                with torch.no_grad():
                    encoded_input = tokenizer(batch_sentences, return_tensors="pt",
                                              padding="max_length", max_length=token_length,
                                              truncation=True).to(device=device)
                    model(**encoded_input, labels=encoded_input.input_ids)
                    batch_sentences = []
            batch_sentences.append(data['text'])
        else:
            break

# Sparsify Model

In [4]:
# function to get module name from parameter name
def get_module_name(param_name):
    if param_name[-5:] == ".bias":
        return param_name[:-5], "bias"
    elif param_name[-7:] == ".weight":
        return param_name[:-7], "weight"
    else:
        return None, None

In [5]:
import torch as torch

'''
W is weights matrix for one layer
H_inv is inverse hessian for one layer
p is proportion of weights we're pruning (top p%)
B is lazy block size (low B helps to reduce memory use)
Bs is inverse of how often to make masks (e.g. when Bs is 4, make new masks with specified sparseness for every 4 columns)
'''

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def prop_zeros(tens):
    return torch.sum(tens == 0) / torch.numel(tens)

# calculate new mask of previously masked model
def calculate_mask(
    W,
    H_inv,
    p,
    B,
    Bs,
    ):

    # Get the number of rows and columns in W
    (d_row, d_col) = W.shape

    # Initialize the pruning mask M and block quantization errors E to all zeros

    M = torch.zeros(d_row, d_col, dtype=torch.bool).to(device=device)
    E = torch.zeros(d_row, B, dtype=torch.float64).to(device=device)

    # previous mask, check where weights are 0 (don't want to update/select these weights)
    prev_mask = (W != 0)
    # add up tot number of weights already masked
    prev_num_masked = torch.sum(~prev_mask)
    # print(f"num weights masked: {prev_num_masked}")
    # print(f"total weights: {W.numel}")
    # update proportion to prune, accounting for not pruning masked weights
    # print(f"original p: {p}")
    p *= (torch.numel(W) - prev_num_masked)/torch.numel(W)
    # p = 1 - (1-p) * (torch.numel(W) - prev_num_masked)/torch.numel(W)
    # print(f"new p: {p}")

    # only need to calculate w_square and h_square once
    # Loop over blocks of columns of W (as specified by B)

    for i in range(0, d_col, B):

        # Loop over columns within a block
        for j in range(i, min(i + B, d_col)):

            # If j is a multiple of Bs, prune a portion of the weights
            if j % Bs == 0:
                # Get the mask for the largest (1 - p)% of weights based on squared value and inverse hessian
                # ASTERISK: prune_values is matrix of w^2/H^(-1)_cc

                # Finding respective sections of hessian and weights matrix
                w_square_section = torch.square(W[:, j:j + Bs])
                h_square_section = torch.square(H_inv[j:j + Bs, j:j + Bs]).diag() # 1 dimensional vector

                prev_mask_section = prev_mask[:, j:j + Bs]

                # getting the prune values matrix from W and H^-1 sections
                prune_values = w_square_section / h_square_section.unsqueeze(0)

                # set prune_values of already-pruned weights to 0 (to not select them)
                prune_values = prune_values * prev_mask_section

                num_el_prune = int(p * prune_values.numel())

                cutoff_value = torch.topk(prune_values.flatten(), num_el_prune, largest=True).values[-1]

                #getting the final mask
                mask = prune_values > cutoff_value

                # print(f"num zeros in new mask: {prop_zeros(mask)}")
                #masking
                M[:, j:j + Bs] = mask * prev_mask_section
                # print(f"num zeros in new mask with old mask: {prop_zeros(M[:, j:j + Bs])}")

            # Calculate the pruning error for this column
            E[:, j - i] = W[:, j] / H_inv[j, j]

            # Freeze the weights that are not pruned by multiplying by the pruning mask
            E[:, j - i] = (~M[:, j]) * E[:, j - i]

            # Also freeze previously masked weights
            # E[:, j - i] = prev_mask[: j] * E[:, j - i]

            W[:, j:i + B] -= torch.ger(E[:, j - i], H_inv[j, j:i + B])
            
        # Update all remaining weights
        W[:, i + B:] -= torch.matmul(E, H_inv[i:i + B, i + B:])

    # print(f"Proportion of Mask 0s: {prop_zeros(M)}")
    # return mask
    return M

In [6]:
from input_prehooks import get_feature_storage_name
import gc
import torch
from tqdm import tqdm
from torch.nn.utils import prune
from inverse_hessian import calc_inverse_hessian
# import calculate_mask
# import iterative_calculate_mask

opt_blacklist = ['model.decoder.embed_tokens', 'model.decoder.embed_positions']

#DEVICE
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Using calibration data (inputs to each intermediate weight layer)
# Iterate through named parameters, calculate inverse hessian and calculate mask

def get_module_name(param_name):
    if param_name[-5:] == ".bias":
        return param_name[:-5], "bias"
    elif param_name[-7:] == ".weight":
        return param_name[:-7], "weight"
    else:
        return None, None


def sparsegpt_prune(model, feature_hessians, 
EPSILON, SPARSENESS, B, Bs, module_blacklist=opt_blacklist, iterative=True):
    module_dict = {}
    for n, m in model.named_modules():
        module_dict[n] = m
    # print(module_dict.keys())
    
    param_names = []
    param_dict = {}
    for n, m in model.named_parameters():
        param_names.append(n)
        param_dict[n] = m
    # print(parameter_list)

    model.eval()
    with torch.no_grad():
        # for name in tqdm(param_names):
        for param_name in tqdm(param_names, total=len(param_names)):
            module_name, param_type = get_module_name(param_name)

            # skip bias, embed, etc parameters
            if module_name in module_blacklist or module_name is None \
                or param_type is None or param_type!="weight":
                continue

            if len(param_dict[param_name].shape) < 2:
                continue

            param = param_dict[param_name]

            #print(f"Doing layer {name}")
            # get layer input from features, key is get_feature_storage_name(module_name)
            # get_feature_storage_name(module_name) stores k_proj, v_proj, q_proj together
            # since they are the same input
            
            layer_hessian = feature_hessians[get_feature_storage_name(module_name)].to(device=device)

            # calculate inverse hessian
            # check if input is flattened e.g. from 8,512,768 to 4096,768
            inv_hess = calc_inverse_hessian(layer_hessian, epsilon=EPSILON)

            # calculate mask
            # if iterative:
            #     mask = iterative_calculate_mask.calculate_mask(W=param, H_inv=inv_hess, p=SPARSENESS, B=B, Bs=Bs)
            # else:
            #     mask = calculate_mask.calculate_mask(W=param, H_inv=inv_hess, p=SPARSENESS, B=B, Bs=Bs)
            mask = calculate_mask(W=param, H_inv=inv_hess, p=SPARSENESS, B=B, Bs=Bs)

            # get module from lookup dictionary by module name
            module = module_dict[module_name]
            # apply mask
            prune.custom_from_mask(module=module, name=param_type, mask=mask)
            prune.remove(module=module, name=param_type)
            gc.collect()
            torch.cuda.empty_cache()           
    pruned_model_name = f'opt-350m-test-{SPARSENESS}'

    torch.save(model.state_dict(), f'pruned_models/{pruned_model_name}.pt')

In [7]:
# initial prune
model = OPTForCausalLM.from_pretrained(model_name, output_attentions=True,
                                           output_hidden_states=True).to(device=device) # type: ignore

SPARSENESS=.8

feature_hessians = {}
#put_input_hooks(model=model, features=feature_hessians, storage_dir=storage_dir, offload_freq=10000, feature_storage_device='cpu')
put_input_hooks(model=model, features=feature_hessians, feature_storage_device='cpu')
split_model_calibration(model)
torch.cuda.empty_cache()

sparsegpt_prune(model=model, feature_hessians=feature_hessians, # type: ignore
EPSILON=EPSILON, SPARSENESS=SPARSENESS, B=B, Bs=Bs)
# print(f"Proportion of 0s: {get_prop_zeros(model)}")


5it [00:05,  1.04s/it]                       
100%|██████████| 196/196 [01:15<00:00,  2.60it/s]


In [8]:
# Test Iterative Pruning
# from SparseGPT_pruning import sparsegpt_prune

def get_prop_zeros(model):
    return torch.sum(model.get_decoder().layers[0].self_attn.k_proj.weight == 0) / (torch.numel(model.get_decoder().layers[0].self_attn.k_proj.weight == 0))

# model = OPTForCausalLM.from_pretrained(model_name, output_attentions=True,
#                                            output_hidden_states=True).to(device=device) # type: ignore

SPARSENESS=.8
import iterative_calculate_mask
for i in range(1,5):
    
    #storage_dir = f'tmp/{model_name}-{SPARSENESS}'
    
    # First, put in forward hooks
    # Don't store inputs, instead store hessians (less data)
    # Only store hessians once, as all models take the same hessians
    print('------------')
    if i == 0:
        feature_hessians = {}
        #put_input_hooks(model=model, features=feature_hessians, storage_dir=storage_dir, offload_freq=10000, feature_storage_device='cpu')
        put_input_hooks(model=model, features=feature_hessians, feature_storage_device='cpu')
        split_model_calibration(model)
        torch.cuda.empty_cache()
    print('-----------')
    sparsegpt_prune(model=model, feature_hessians=feature_hessians, # type: ignore
    EPSILON=EPSILON, SPARSENESS=SPARSENESS, B=B, Bs=Bs)
    print(f"Proportion of 0s: {get_prop_zeros(model)}")
    

------------
-----------


100%|██████████| 196/196 [01:12<00:00,  2.71it/s]


Proportion of 0s: 0.3619791567325592
------------
-----------


100%|██████████| 196/196 [01:15<00:00,  2.61it/s]


Proportion of 0s: 0.490234375
------------
-----------


100%|██████████| 196/196 [01:16<00:00,  2.56it/s]


Proportion of 0s: 0.5931006669998169
------------
-----------


100%|██████████| 196/196 [01:11<00:00,  2.72it/s]


Proportion of 0s: 0.67578125


In [None]:
from input_prehooks import get_feature_storage_name
import gc
from SparseGPT_pruning import sparsegpt_prune
layer_blacklist = ['model.decoder.embed_tokens.weight', 'model.decoder.embed_tokens.bias', 'model.decoder.embed_positions.weight']

# Using calibration data (inputs to each intermediate weight layer)
# Iterate through named parameters, calculate inverse hessian and calculate mask

SPARSENESS_LIST = [0.5]#0.1, 0.2, 0.3, 0.5, 0.7, 0.9
for i, SPARSENESS in enumerate(SPARSENESS_LIST):
    
    # Load model with pre-trained head
    model = OPTForCausalLM.from_pretrained(model_name, output_attentions=True,
                                           output_hidden_states=True).to(device=device) # type: ignore
    
    #storage_dir = f'tmp/{model_name}-{SPARSENESS}'
    
    # First, put in forward hooks
    # Don't store inputs, instead store hessians (less data)
    # Only store hessians once, as all models take the same hessians
    print('------------')
    if i == 0:
        feature_hessians = {}
        #put_input_hooks(model=model, features=feature_hessians, storage_dir=storage_dir, offload_freq=10000, feature_storage_device='cpu')
        put_input_hooks(model=model, features=feature_hessians, feature_storage_device='cpu')
        split_model_calibration(model)
        torch.cuda.empty_cache()
    print('-----------')
    
    sparsegpt_prune(model=model, feature_hessians=feature_hessians, # type: ignore
    EPSILON=EPSILON, SPARSENESS=SPARSENESS, B=B, Bs=Bs)

    pruned_model_name = f'{model_size}-test-{SPARSENESS}'

    torch.save(model.state_dict(), f'pruned_models/{pruned_model_name}.pt')

In [None]:
from save_pruned_model import load_unmasked_model, load_masked_model

loaded_model = OPTForCausalLM.from_pretrained(f'facebook/{model_size}', output_attentions=True, output_hidden_states=True).to(device=device) # type: ignore
# loaded_model = torch.nn.DataParallel(loaded_model, device_ids=[0,1,2,3])
load_unmasked_model(loaded_model, f'pruned_models/{model_size}-test-0.5.pt')

loaded_model_2 = OPTForCausalLM.from_pretrained(f'facebook/{model_size}', output_attentions=True, output_hidden_states=True).to(device=device) # type: ignore
load_masked_model(loaded_model_2, f'pruned_models/{model_size}-test-0.5.pt')

def get_prop_zeros(model):
    return torch.sum(model.get_decoder().layers[0].self_attn.k_proj.weight == 0) / (torch.sum(model.get_decoder().layers[0].self_attn.k_proj.weight == 0) + torch.sum(model.get_decoder().layers[0].self_attn.k_proj.weight != 0))

print(get_prop_zeros(loaded_model))
print(get_prop_zeros(model))
print(get_prop_zeros(loaded_model_2))
# loaded_model.eval()
# _ = loaded_model(torch.randint(high=20, size=(1,10)))

In [None]:
prune_list = []
layer_blacklist = ['', 'module','module.model','module.model.decoder',
                   'module.model.decoder.embed_tokens',
                   'module.model.decoder.embed_tokens',
                   'module.model.decoder.embed_positions']
for name, module in model.named_modules():
    # skip the embed layer or skip norms which have 1 dimension
    if name in layer_blacklist or 'norm' in name or not isinstance(module, torch.nn.Module):
        continue
    prune_list.append((module, 'weight'))

# prune to 0s
class ThresholdPruning(prune.BasePruningMethod):
    PRUNING_TYPE = "unstructured"

    # default threshold is 0, prunes weights that are already 0 (for training)
    def __init__(self, threshold=1e-8):
        self.threshold = threshold

    def compute_mask(self, tensor, default_mask):
        return torch.abs(tensor) >= self.threshold
    
prune.global_unstructured(prune_list,
                          pruning_method=ThresholdPruning)