In [1]:
import itertools
import os

import torch
from torch.nn.utils import prune

from tqdm import tqdm

from transformers import AutoTokenizer, OPTForCausalLM, pipeline
from datasets import load_dataset

from utils.prune_utils import sparsegpt_prune

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Compute device: use the GPU whenever torch can see one.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Checkpoint to prune; `model_size` is also reused in saved-model filenames.
model_size = "opt-125m"
model_name = f"facebook/{model_size}"

# WikiText-2 (raw) opened in streaming mode, i.e. iterated lazily rather than
# materialized up front.
dataset = load_dataset('wikitext', "wikitext-2-raw-v1", streaming=True)

# Tokenizer for the same checkpoint; padding is inserted on the left.
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')

In [3]:
# Calibrate model (get inputs to each layer with calibration data)

calibration_size = 4          # number of WikiText rows fed through the model
token_length = 1024           # each row is padded/truncated to this many tokens
calibration_batch_size = 2    # rows per forward pass

# SparseGPT hyper-parameters, forwarded unchanged to `sparsegpt_prune`.
# NOTE(review): their meaning lives in utils.prune_utils — presumably EPSILON is
# the Hessian damping term and B / Bs are block sizes; confirm there.
EPSILON = 0.01
B = 4
Bs = 2


def _calibration_forward(model, texts, max_length):
    """Tokenize one batch of texts and run a single no-grad forward pass through ``model``."""
    encoded_input = tokenizer(texts, return_tensors="pt",
                              padding="max_length", max_length=max_length,
                              truncation=True).to(device=device)
    with torch.no_grad():
        # Output is discarded; the call exists for the side effects of any
        # hooks registered on the model beforehand.
        model(**encoded_input, labels=encoded_input.input_ids)


def split_model_calibration(model, num_samples=None, batch_size=None, max_length=None):
    """Run calibration text through ``model`` so previously attached hooks see layer inputs.

    Streams the first ``num_samples`` rows of the module-level ``dataset['train']``.
    It tokenizes them in batches of ``batch_size`` with the module-level ``tokenizer``,
    padding/truncating each to ``max_length`` tokens. Each batch gets one forward pass
    under ``torch.no_grad()`` on ``device``. Nothing is returned.

    A final partial batch is processed too, so exactly ``num_samples`` rows are used
    even when it is not a multiple of ``batch_size``.

    NOTE(review): rows are used verbatim. Raw WikiText may contain blank or
    heading-only lines, so some calibration sequences may be mostly padding.
    Consider filtering them.

    Args:
        model: causal LM, called as ``model(**encoded, labels=encoded.input_ids)``.
        num_samples: rows to consume; defaults to the global ``calibration_size``.
        batch_size: rows per forward pass; defaults to ``calibration_batch_size``.
        max_length: tokens per row; defaults to the global ``token_length``.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    # Resolve defaults at call time so later edits to the globals are honoured.
    num_samples = calibration_size if num_samples is None else num_samples
    batch_size = calibration_batch_size if batch_size is None else batch_size
    max_length = token_length if max_length is None else max_length
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Take exactly the rows we need instead of over-reading and breaking out.
    rows = itertools.islice(dataset['train'], num_samples)
    pending = []
    for row in tqdm(rows, total=num_samples):
        pending.append(row['text'])
        if len(pending) == batch_size:
            _calibration_forward(model, pending, max_length)
            pending = []
    # Flush the trailing partial batch rather than silently dropping it.
    if pending:
        _calibration_forward(model, pending, max_length)

# Sparsify Model

In [4]:
# from iterative_prune_finetune import iterative_sparsegpt_prune_tune
from utils.prehook_utils import put_input_hooks,remove_all_hooks
from utils.prune_utils import sparsegpt_prune
from utils.finetune_utils import finetune_model_inplace
from utils.save_utils import unmask_model
from utils.save_utils import mask_from_pruned, unmask_model
import gc

def get_prop_zeros(model):
    """Return the fraction of exactly-zero entries in decoder layer 0's ``k_proj`` weight.

    This is a quick spot-check of pruning. Only this single matrix is inspected, not
    the whole model. The result is a 0-d tensor.
    """
    k_weight = model.get_decoder().layers[0].self_attn.k_proj.weight
    zero_count = (k_weight == 0).sum()
    return zero_count / k_weight.numel()

# Sweep over SPARSENESS values. For each one: load a fresh model, calibrate,
# prune with SparseGPT, fine-tune, then save the state dict.
# NOTE(review): in the recorded run, SPARSENESS=0.9 left ~10% zeros and 0.7 left
# ~30% zeros in layer 0's k_proj. So SPARSENESS appears to be the fraction of
# weights *kept* by sparsegpt_prune — confirm in utils.prune_utils.
SPARSITIES = [0.9, 0.7, 0.5, 0.3, 0.2]
FINETUNE_EPOCHS = 10
PRUNED_MODEL_DIR = 'pruned_models'

# torch.save does not create missing directories.
os.makedirs(PRUNED_MODEL_DIR, exist_ok=True)

for SPARSENESS in SPARSITIES:
    # Release the previous iteration's model before loading the next one, so two
    # copies are never resident at the same time.
    model = None
    gc.collect()
    torch.cuda.empty_cache()

    # NOTE(review): output_attentions / output_hidden_states make each forward
    # return every attention map and hidden state. That likely inflates memory
    # during calibration and fine-tuning; drop them unless a helper needs them.
    model = OPTForCausalLM.from_pretrained(model_name,
                                           output_attentions=True,
                                           output_hidden_states=True).to(device=device)

    # Calibration: hooks fill feature_hessians (kept on CPU) while calibration
    # text runs through the model. The contents are presumably per-layer input
    # Hessians, per the name.
    feature_hessians = {}
    all_hooks = put_input_hooks(model=model, features=feature_hessians, feature_storage_device='cpu')
    try:
        split_model_calibration(model)
    finally:
        # Always detach the hooks, even if calibration raises, so they cannot
        # keep firing during pruning or fine-tuning.
        for hook in all_hooks:
            hook.remove()

    print(f"memory allocated before pruning: {torch.cuda.memory_allocated()}")
    sparsegpt_prune(model=model, model_name=model_size,
                    feature_hessians=feature_hessians,
                    EPSILON=EPSILON, SPARSENESS=SPARSENESS, B=B, Bs=Bs, save_model=False)
    # The collected statistics are not needed after pruning.
    feature_hessians.clear()
    gc.collect()
    torch.cuda.empty_cache()

    print(f"memory allocated before fine tuning: {torch.cuda.memory_allocated()}")
    # Presumably: build a mask from the pruned (zeroed) weights so fine-tuning
    # keeps them at zero, then remove the mask again before saving.
    mask_from_pruned(model=model)
    finetune_model_inplace(model=model, tokenizer=tokenizer,
                           SPARSITY=SPARSENESS, device=device, EPOCH_COUNT=FINETUNE_EPOCHS)
    unmask_model(model=model)
    torch.cuda.empty_cache()

    # Saved after unmask_model, i.e. intended to be stored WITHOUT the mask.
    pruned_model_name = f'{model_size}-finetuned-{SPARSENESS}'
    torch.save(model.state_dict(),
               os.path.join(PRUNED_MODEL_DIR, f'{pruned_model_name}-iterative.pt'))
    print(f"memory allocated: {torch.cuda.memory_allocated()}")
    print(f"proportion zeros: {get_prop_zeros(model)}")


5it [00:05,  1.09s/it]                       


memory allocated before pruning: 502228992


100%|██████████| 196/196 [00:50<00:00,  3.85it/s]


memory allocated before fine tuning: 502228992


  0%|          | 0/10 [00:00<?, ?it/s]Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no padding.
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


max memory during batch: 1825327104
max memory during batch: 3722486784
max memory during batch: 3184641536
max memory during batch: 3411198976


 10%|█         | 1/10 [00:06<00:57,  6.39s/it]

max memory during batch: 3370840576
memory allocated after epoch: 2729281024
max memory during batch: 3334405120
max memory during batch: 3728124928
max memory during batch: 3185484288
max memory during batch: 3412511744
max memory during batch: 3371279872


 20%|██        | 2/10 [00:10<00:42,  5.29s/it]

memory allocated after epoch: 2726964736
max memory during batch: 3335440384
max memory during batch: 3730997248
max memory during batch: 3184476672
max memory during batch: 3407294464
max memory during batch: 3370164736


 30%|███       | 3/10 [00:15<00:34,  4.96s/it]

memory allocated after epoch: 2729281024
max memory during batch: 3333839872
max memory during batch: 3732876800
max memory during batch: 3187583488
max memory during batch: 3414560256
max memory during batch: 3370270208


 40%|████      | 4/10 [00:20<00:28,  4.82s/it]

memory allocated after epoch: 2727875072
max memory during batch: 3332915200
max memory during batch: 3729937408
max memory during batch: 3186563584
max memory during batch: 3409532416
max memory during batch: 3368755712


 50%|█████     | 5/10 [00:24<00:23,  4.72s/it]

memory allocated after epoch: 2726668800
max memory during batch: 3334323200
max memory during batch: 3731910144
max memory during batch: 3185325568
max memory during batch: 3409117184


 60%|██████    | 6/10 [00:29<00:18,  4.66s/it]

max memory during batch: 3369355776
memory allocated after epoch: 2728188416
max memory during batch: 3332850688
max memory during batch: 3731148800
max memory during batch: 3184067072
max memory during batch: 3411408896
max memory during batch: 3367992832


 70%|███████   | 7/10 [00:33<00:13,  4.62s/it]

memory allocated after epoch: 2727124480
max memory during batch: 3332811776
max memory during batch: 3732649984
max memory during batch: 3184017920
max memory during batch: 3411425280
max memory during batch: 3368806912


 80%|████████  | 8/10 [00:38<00:09,  4.60s/it]

memory allocated after epoch: 2726323712
max memory during batch: 3332570112
max memory during batch: 3730759168
max memory during batch: 3185777152
max memory during batch: 3408865280
max memory during batch: 3369061888


 90%|█████████ | 9/10 [00:42<00:04,  4.60s/it]

memory allocated after epoch: 2728477184
max memory during batch: 3336590336
max memory during batch: 3730469376
max memory during batch: 3188997632
max memory during batch: 3412296704


100%|██████████| 10/10 [00:47<00:00,  4.74s/it]

max memory during batch: 3369910784
memory allocated after epoch: 2726430208





memory allocated: 1003693056
proportion zeros: 0.1009114608168602


5it [00:07,  1.43s/it]                       


memory allocated before pruning: 502993920


100%|██████████| 196/196 [00:51<00:00,  3.78it/s]


memory allocated before fine tuning: 502228992


  0%|          | 0/10 [00:00<?, ?it/s]

max memory during batch: 1827162112
max memory during batch: 3722540032
max memory during batch: 3186822656
max memory during batch: 3409554432
max memory during batch: 3369180672


 10%|█         | 1/10 [00:04<00:44,  4.93s/it]

memory allocated after epoch: 2727946752
max memory during batch: 3333452800
max memory during batch: 3732768768
max memory during batch: 3186133504
max memory during batch: 3410338304


 20%|██        | 2/10 [00:09<00:38,  4.76s/it]

max memory during batch: 3372814336
memory allocated after epoch: 2729612288
max memory during batch: 3336914432
max memory during batch: 3732587008
max memory during batch: 3187034112
max memory during batch: 3412694528


 30%|███       | 3/10 [00:14<00:32,  4.68s/it]

max memory during batch: 3372651008
memory allocated after epoch: 2729210368
max memory during batch: 3335597568
max memory during batch: 3729841152
max memory during batch: 3187325440
max memory during batch: 3414098944
max memory during batch: 3373518336


 40%|████      | 4/10 [00:18<00:28,  4.68s/it]

memory allocated after epoch: 2730931712
max memory during batch: 3336486912
max memory during batch: 3734632448
max memory during batch: 3187015168
max memory during batch: 3415277568
max memory during batch: 3374604288


 50%|█████     | 5/10 [00:23<00:23,  4.63s/it]

memory allocated after epoch: 2730892288
max memory during batch: 3336123392
max memory during batch: 3730751488
max memory during batch: 3188776448
max memory during batch: 3413292032
max memory during batch: 3369784832


 60%|██████    | 6/10 [00:28<00:18,  4.64s/it]

memory allocated after epoch: 2729885184
max memory during batch: 3334389760
max memory during batch: 3730408448
max memory during batch: 3187119104
max memory during batch: 3414403584


 70%|███████   | 7/10 [00:32<00:13,  4.66s/it]

max memory during batch: 3370396160
memory allocated after epoch: 2728561152
max memory during batch: 3338322944
max memory during batch: 3729020416
max memory during batch: 3186503168
max memory during batch: 3412751872
max memory during batch: 3371387392


 80%|████████  | 8/10 [00:37<00:09,  4.66s/it]

memory allocated after epoch: 2730749440
max memory during batch: 3336848384
max memory during batch: 3731456000
max memory during batch: 3186087424
max memory during batch: 3409763328
max memory during batch: 3371624960


 90%|█████████ | 9/10 [00:42<00:04,  4.65s/it]

memory allocated after epoch: 2730001920
max memory during batch: 3334942720
max memory during batch: 3730339840
max memory during batch: 3186465792
max memory during batch: 3411646464
max memory during batch: 3369982464


100%|██████████| 10/10 [00:46<00:00,  4.66s/it]

memory allocated after epoch: 2728095232





memory allocated: 1005854720
proportion zeros: 0.30078125


5it [00:07,  1.44s/it]                       


memory allocated before pruning: 503236608


 38%|███▊      | 75/196 [00:18<00:27,  4.33it/s]

In [6]:
# Print the name of every layer whose statistics were captured.
# The sweep clears feature_hessians right after pruning in every iteration. So
# this only lists names if the sweep was stopped between calibration and that clear.
for layer_name in feature_hessians:
    print(layer_name)

model.decoder.project_in
model.decoder.layers.0.self_attn.in_proj
model.decoder.layers.0.self_attn.out_proj
model.decoder.layers.0.fc1
model.decoder.layers.0.fc2
model.decoder.layers.1.self_attn.in_proj
model.decoder.layers.1.self_attn.out_proj
model.decoder.layers.1.fc1
model.decoder.layers.1.fc2
model.decoder.layers.2.self_attn.in_proj
model.decoder.layers.2.self_attn.out_proj
model.decoder.layers.2.fc1
model.decoder.layers.2.fc2
model.decoder.layers.3.self_attn.in_proj
model.decoder.layers.3.self_attn.out_proj
model.decoder.layers.3.fc1
model.decoder.layers.3.fc2
model.decoder.layers.4.self_attn.in_proj
model.decoder.layers.4.self_attn.out_proj
model.decoder.layers.4.fc1
model.decoder.layers.4.fc2
model.decoder.layers.5.self_attn.in_proj
model.decoder.layers.5.self_attn.out_proj
model.decoder.layers.5.fc1
model.decoder.layers.5.fc2
model.decoder.layers.6.self_attn.in_proj
model.decoder.layers.6.self_attn.out_proj
model.decoder.layers.6.fc1
model.decoder.layers.6.fc2
model.decoder.la