In [1]:
import torch
from torch.nn.utils import prune

from tqdm import tqdm

from transformers import AutoTokenizer, OPTForCausalLM, pipeline
from datasets import load_dataset

from calculate_mask import calculate_mask
from inverse_hessian import calc_inverse_hessian
from input_prehooks import put_input_hooks
from testing_module import calculate_perp

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
#DEVICE
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model_size = "opt-125m"

model_name = f"facebook/{model_size}"
# model_name = "facebook/opt-1.3b"

#Load dataset
dataset = load_dataset('wikitext', "wikitext-2-raw-v1", streaming=True)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')

In [3]:
# Calibrate model (get inputs to each layer with calibration data)

calibration_size=4
token_length=512
calibration_batch_size=2

EPSILON = 1e-8
B = 4
Bs = 2

# run model on batches of calibration data, then concatenate inputs
def split_model_calibration(model):
    batch_sentences = []
    for i, data in tqdm(enumerate(iter(dataset['train'])), total=calibration_size):
        if i < calibration_size + 1:
            if len(batch_sentences) >= calibration_batch_size:
                with torch.no_grad():
                    encoded_input = tokenizer(batch_sentences, return_tensors="pt",
                                              padding="max_length", max_length=token_length,
                                              truncation=True).to(device=device)
                    model(**encoded_input, labels=encoded_input.input_ids)
                    batch_sentences = []
            batch_sentences.append(data['text'])
        else:
            break

# Sparsify Model

In [4]:
# SparseGPT fine tune loop
from SparseGPT_pruning import sparsegpt_prune
from trainingv2 import fine_tune
from save_pruned_model import mask_from_pruned, unmask_model

# first, calibrate model
model = OPTForCausalLM.from_pretrained(model_name, output_attentions=True,
                                        output_hidden_states=True).to(device=device) # type: ignore
feature_hessians = {}
#put_input_hooks(model=model, features=feature_hessians, storage_dir=storage_dir, offload_freq=10000, feature_storage_device='cpu')
put_input_hooks(model=model, features=feature_hessians, feature_storage_device='cpu')
split_model_calibration(model)
torch.cuda.empty_cache()


# set up data
def encode_tok(examples):
    return tokenizer(examples['text'], truncation=True, padding='max_length')
#stream c4, training split
training_data = load_dataset('wikitext', "wikitext-2-raw-v1", split='train', streaming=True)
#load tokenizer
tokenizer = AutoTokenizer.from_pretrained('facebook/opt-125m')
#IMPORTANT: process data while streaming -> remove unnecessary columns in batches
training_data = training_data.map(encode_tok, batched=True, remove_columns=["text"])
#set data to tensor mode
training_data = training_data.with_format("torch")

def get_prop_zeros(model):
    return torch.sum(model.get_decoder().layers[0].self_attn.k_proj.weight == 0) / (torch.numel(model.get_decoder().layers[0].self_attn.k_proj.weight))

# now, fine tune loop
# sparseness is defined as proportion of nonzeros (opposite of intuitive)
# sparseness_sequence = [.9, .8, .7, .6, .5, .4, .3, .2]
sparseness_sequence = [.9, .5, .4]

for sparseness_index in range(len(sparseness_sequence)):

    if sparseness_index == 0:
        sparseness_dif = sparseness_sequence[sparseness_index]
    else:
        sparseness_dif = sparseness_sequence[sparseness_index] / sparseness_sequence[sparseness_index -1]
    
    sparsegpt_prune(model=model, feature_hessians=feature_hessians, SPARSENESS=sparseness_dif, EPSILON=EPSILON, B=B, Bs=Bs)
    print(f"After pruning, Model has {get_prop_zeros(model)}")

    # activate masks
    mask_from_pruned(model=model)

    # fine_tune(model=model, training_data=training_data, EPOCH_COUNT=1, tokenizer=tokenizer)
    
    print(f"After fine-tuning, Model has {get_prop_zeros(model)}")

    # deactivate masks
    unmask_model(model=model)
    for n, m in model.named_buffers():
        print(n)

    pruned_model_name = f'{model_size}-test-{sparseness_sequence[sparseness_index]}'
    torch.save(model.state_dict(), f'pruned_models/{pruned_model_name}.pt')

5it [00:04,  1.23it/s]                       
100%|██████████| 196/196 [00:54<00:00,  3.58it/s]


After pruning, Model has 0.1009114608168602
After fine-tuning, Model has 0.1009114608168602
model.decoder.layers.0.self_attn.k_proj.weight_mask
model.decoder.layers.0.self_attn.v_proj.weight_mask
model.decoder.layers.0.self_attn.q_proj.weight_mask
model.decoder.layers.0.self_attn.out_proj.weight_mask
model.decoder.layers.0.fc1.weight_mask
model.decoder.layers.0.fc2.weight_mask
model.decoder.layers.1.self_attn.k_proj.weight_mask
model.decoder.layers.1.self_attn.v_proj.weight_mask
model.decoder.layers.1.self_attn.q_proj.weight_mask
model.decoder.layers.1.self_attn.out_proj.weight_mask
model.decoder.layers.1.fc1.weight_mask
model.decoder.layers.1.fc2.weight_mask
model.decoder.layers.2.self_attn.k_proj.weight_mask
model.decoder.layers.2.self_attn.v_proj.weight_mask
model.decoder.layers.2.self_attn.q_proj.weight_mask
model.decoder.layers.2.self_attn.out_proj.weight_mask
model.decoder.layers.2.fc1.weight_mask
model.decoder.layers.2.fc2.weight_mask
model.decoder.layers.3.self_attn.k_proj.weig

100%|██████████| 196/196 [00:00<00:00, 2102515.56it/s]


After pruning, Model has 0.1009114608168602
After fine-tuning, Model has 0.1009114608168602
model.decoder.layers.0.self_attn.k_proj.weight_mask
model.decoder.layers.0.self_attn.v_proj.weight_mask
model.decoder.layers.0.self_attn.q_proj.weight_mask
model.decoder.layers.0.self_attn.out_proj.weight_mask
model.decoder.layers.0.fc1.weight_mask
model.decoder.layers.0.fc2.weight_mask
model.decoder.layers.1.self_attn.k_proj.weight_mask
model.decoder.layers.1.self_attn.v_proj.weight_mask
model.decoder.layers.1.self_attn.q_proj.weight_mask
model.decoder.layers.1.self_attn.out_proj.weight_mask
model.decoder.layers.1.fc1.weight_mask
model.decoder.layers.1.fc2.weight_mask
model.decoder.layers.2.self_attn.k_proj.weight_mask
model.decoder.layers.2.self_attn.v_proj.weight_mask
model.decoder.layers.2.self_attn.q_proj.weight_mask
model.decoder.layers.2.self_attn.out_proj.weight_mask
model.decoder.layers.2.fc1.weight_mask
model.decoder.layers.2.fc2.weight_mask
model.decoder.layers.3.self_attn.k_proj.weig

100%|██████████| 196/196 [00:00<00:00, 2086506.56it/s]


After pruning, Model has 0.1009114608168602
After fine-tuning, Model has 0.1009114608168602
model.decoder.layers.0.self_attn.k_proj.weight_mask
model.decoder.layers.0.self_attn.v_proj.weight_mask
model.decoder.layers.0.self_attn.q_proj.weight_mask
model.decoder.layers.0.self_attn.out_proj.weight_mask
model.decoder.layers.0.fc1.weight_mask
model.decoder.layers.0.fc2.weight_mask
model.decoder.layers.1.self_attn.k_proj.weight_mask
model.decoder.layers.1.self_attn.v_proj.weight_mask
model.decoder.layers.1.self_attn.q_proj.weight_mask
model.decoder.layers.1.self_attn.out_proj.weight_mask
model.decoder.layers.1.fc1.weight_mask
model.decoder.layers.1.fc2.weight_mask
model.decoder.layers.2.self_attn.k_proj.weight_mask
model.decoder.layers.2.self_attn.v_proj.weight_mask
model.decoder.layers.2.self_attn.q_proj.weight_mask
model.decoder.layers.2.self_attn.out_proj.weight_mask
model.decoder.layers.2.fc1.weight_mask
model.decoder.layers.2.fc2.weight_mask
model.decoder.layers.3.self_attn.k_proj.weig

In [None]:
from input_prehooks import get_feature_storage_name
import gc
from SparseGPT_pruning import sparsegpt_prune
layer_blacklist = ['model.decoder.embed_tokens.weight', 'model.decoder.embed_tokens.bias', 'model.decoder.embed_positions.weight']

# Using calibration data (inputs to each intermediate weight layer)
# Iterate through named parameters, calculate inverse hessian and calculate mask

SPARSENESS_LIST = [0.5]#0.1, 0.2, 0.3, 0.5, 0.7, 0.9
for i, SPARSENESS in enumerate(SPARSENESS_LIST):
    
    # Load model with pre-trained head
    model = OPTForCausalLM.from_pretrained(model_name, output_attentions=True,
                                           output_hidden_states=True).to(device=device) # type: ignore
    
    #storage_dir = f'tmp/{model_name}-{SPARSENESS}'
    
    # First, put in forward hooks
    # Don't store inputs, instead store hessians (less data)
    # Only store hessians once, as all models take the same hessians
    print('------------')
    if i == 0:
        feature_hessians = {}
        #put_input_hooks(model=model, features=feature_hessians, storage_dir=storage_dir, offload_freq=10000, feature_storage_device='cpu')
        put_input_hooks(model=model, features=feature_hessians, feature_storage_device='cpu')
        split_model_calibration(model)
        torch.cuda.empty_cache()
    print('-----------')
    
    sparsegpt_prune(model=model, feature_hessians=feature_hessians, # type: ignore
    EPSILON=EPSILON, SPARSENESS=SPARSENESS, B=B, Bs=Bs)

    pruned_model_name = f'{model_size}-test-{SPARSENESS}'
    # make a dictionary to access module by name
    module_lookup_dict = {}
    for module_name, module_iter in model.named_modules():
        module_lookup_dict[module_name] = module_iter

    # without this
    param_lookup_dict = {}
    param_names = []
    for name, param in model.named_parameters():
        param_names.append(name)
        param_lookup_dict[name] = param
    
    model.eval()
    #model = torch.nn.DataParallel(model, device_ids=[0,1,2,3])
    with torch.no_grad():
        # for name in tqdm(param_names):

    torch.save(model.state_dict(), f'pruned_models/{pruned_model_name}.pt')

In [None]:
from save_pruned_model import load_unmasked_model, load_masked_model

loaded_model = OPTForCausalLM.from_pretrained(f'facebook/{model_size}', output_attentions=True, output_hidden_states=True).to(device=device) # type: ignore
# loaded_model = torch.nn.DataParallel(loaded_model, device_ids=[0,1,2,3])
load_unmasked_model(loaded_model, f'pruned_models/{model_size}-test-0.5.pt')

loaded_model_2 = OPTForCausalLM.from_pretrained(f'facebook/{model_size}', output_attentions=True, output_hidden_states=True).to(device=device) # type: ignore
load_masked_model(loaded_model_2, f'pruned_models/{model_size}-test-0.5.pt')

def get_prop_zeros(model):
    return torch.sum(model.get_decoder().layers[0].self_attn.k_proj.weight == 0) / (torch.sum(model.get_decoder().layers[0].self_attn.k_proj.weight == 0) + torch.sum(model.get_decoder().layers[0].self_attn.k_proj.weight != 0))

print(get_prop_zeros(loaded_model))
print(get_prop_zeros(model))
print(get_prop_zeros(loaded_model_2))
# loaded_model.eval()
# _ = loaded_model(torch.randint(high=20, size=(1,10)))

In [None]:
prune_list = []
layer_blacklist = ['', 'module','module.model','module.model.decoder',
                   'module.model.decoder.embed_tokens',
                   'module.model.decoder.embed_tokens',
                   'module.model.decoder.embed_positions']
for name, module in model.named_modules():
    # skip the embed layer or skip norms which have 1 dimension
    if name in layer_blacklist or 'norm' in name or not isinstance(module, torch.nn.Module):
        continue
    prune_list.append((module, 'weight'))

# prune to 0s
class ThresholdPruning(prune.BasePruningMethod):
    PRUNING_TYPE = "unstructured"

    # default threshold is 0, prunes weights that are already 0 (for training)
    def __init__(self, threshold=1e-8):
        self.threshold = threshold

    def compute_mask(self, tensor, default_mask):
        return torch.abs(tensor) >= self.threshold
    
prune.global_unstructured(prune_list,
                          pruning_method=ThresholdPruning)