In [1]:
import os
from itertools import islice

import torch
from datasets import load_dataset
from torch.nn.utils import prune
from tqdm import tqdm
from transformers import AutoTokenizer, OPTForCausalLM, pipeline

from utils.prune_utils import sparsegpt_prune

In [2]:
# Device selection: use the GPU when one is visible, otherwise fall back to CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed torch's RNG (all devices) so pruning + fine-tuning runs are more repeatable.
# NOTE(review): Python `random` / numpy RNGs and cuDNN nondeterminism are not controlled here.
SEED = 0
torch.manual_seed(SEED)

model_size = "opt-350m"

model_name = f"facebook/{model_size}"

# Load wikitext-2 (raw) in streaming mode; rows are consumed lazily during calibration.
dataset = load_dataset('wikitext', "wikitext-2-raw-v1", streaming=True)

# Tokenizer for the same checkpoint; padding is applied on the left side of each sequence.
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')

In [3]:
# Calibration settings: how much text is pushed through the model so the
# per-layer input hooks have data to record before pruning.
calibration_size = 128        # number of dataset rows fed through the model
token_length = 2_048          # every row is padded / truncated to this many tokens
calibration_batch_size = 2    # rows per forward pass

# Pruning hyperparameters forwarded unchanged to `sparsegpt_prune`.
# NOTE(review): the meanings below follow SparseGPT naming and are assumptions —
# confirm against utils.prune_utils.
EPSILON = 1e-2   # presumably the Hessian dampening factor
B = Bs = 128     # presumably the lazy-update block size (B) and mask-selection block size (Bs)

# run model on batches of calibration data, then concatenate inputs
def split_model_calibration(model):
    batch_sentences = []
    for i, data in tqdm(enumerate(iter(dataset['train'])), total=calibration_size):
        if i < calibration_size + 1:
            if len(batch_sentences) >= calibration_batch_size:
                with torch.no_grad():
                    encoded_input = tokenizer(batch_sentences, return_tensors="pt",
                                              padding="max_length", max_length=token_length,
                                              truncation=True).to(device=device)
                    model(**encoded_input, labels=encoded_input.input_ids)
                    batch_sentences = []
            batch_sentences.append(data['text'])
        else:
            break

# Sparsify Model

In [4]:
from iterative_prune_finetune import iterative_sparsegpt_prune_tune
from utils.prehook_utils import put_input_hooks,remove_all_hooks
from utils.prune_utils import sparsegpt_prune
from utils.finetune_utils import finetune_model_inplace
from utils.save_utils import unmask_model

SPARSITIES = [0.9, 0.7, 0.5, 0.3, 0.2]#0.1, 0.2,0.3,0.5,0.7,0.9,1
for SPARSENESS in SPARSITIES:
    model = OPTForCausalLM.from_pretrained(f'facebook/{model_size}', 
                                       output_attentions=True, 
                                       output_hidden_states=True).to(device=device)
    model = torch.nn.DataParallel(model, device_ids=[0,1,2,3])

    feature_hessians = {}
    #put_input_hooks(model=model, features=feature_hessians, storage_dir=storage_dir, offload_freq=10000, feature_storage_device='cpu')
    all_hooks = put_input_hooks(model=model, features=feature_hessians, feature_storage_device='cpu')
    split_model_calibration(model)
    for hook in all_hooks:
        hook.remove()
    sparsegpt_prune(model=model, model_name=model_size, 
                    feature_hessians=feature_hessians, 
                    EPSILON=EPSILON, SPARSENESS=SPARSENESS, B=B, Bs=Bs, save_model=False)
    torch.cuda.empty_cache()
    finetune_model_inplace(model=model, tokenizer=tokenizer, 
                           SPARSITY=SPARSENESS, device=device, EPOCH_COUNT=10)
    #unmask_model(model=model)
    pruned_model_name = f'{model_size}-finetuned-{SPARSENESS}'
    torch.save(model.state_dict(), f'pruned_models/{pruned_model_name}-iterative.pt')
# Prune using the sparseGPT method, saves as pruned_models/{model_name}-{SPARSENESS}.pt WITHOUT mask
#iterative_sparsegpt_prune_tune(model, model_size, SPARSITIES, feature_hessians, EPSILON, B, Bs, tokenizer, EPOCH_COUNT=10)

del model
del feature_hessians

129it [03:37,  1.69s/it]                         
100%|██████████| 388/388 [00:43<00:00,  8.94it/s]
  0%|          | 0/10 [00:00<?, ?it/s]Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no padding.
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
100%|██████████| 10/10 [01:28<00:00,  8.88s/it]
129it [03:38,  1.69s/it]                         
100%|██████████| 388/388 [00:44<00:00,  8.63it/s]
100%|██████████| 10/10 [01:14<00:00,  7.43s/it]
129it [03:36,  1.68s/it]                         
100%|██████████| 388/388 [00:45<00:00,  8.61it/s]
100%|██████████| 10/10 [01:14<00:00,  7.45s/it]
129it [03:48,  1.77s/it]                         
100%|██████████| 388/388 [00:46<00:00,  8.38it/s]
100%|██████████| 10/10 [01:14<00:00,  7.45s/it]
129it [03:23,  1.58s/it]                         
100%|██████████| 388/388 [00:42<00:00,  9.18it