In [1]:
import gc
import os
from itertools import islice

import torch
from torch.nn.utils import prune

from tqdm import tqdm

from transformers import AutoTokenizer, OPTForCausalLM, pipeline
from datasets import load_dataset

from utils.prune_utils import sparsegpt_prune

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Device selection: use the GPU when PyTorch reports CUDA as available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Text corpus: the "en" configuration of C4, opened with streaming=True.
dataset = load_dataset('c4', "en", streaming=True)

In [3]:
# Calibration settings: the model is run over calibration data so that the
# inputs to each layer can be captured.
calibration_size = 4         # how many calibration samples to use
token_length = 512           # max_length each sample is padded/truncated to
calibration_batch_size = 2   # sentences per calibration forward pass

# Values forwarded to sparsegpt_prune as its EPSILON / B / Bs arguments.
EPSILON = 0.01
B = 4
Bs = 2

# Run the model over the calibration data in mini-batches. Nothing is returned:
# the forward passes exist only for their side effects (e.g. hooks registered on
# the model by the caller).
def split_model_calibration(model, tokenizer):
    """Feed ``calibration_size`` samples of ``dataset['train']`` through ``model``.

    Samples are grouped into batches of ``calibration_batch_size`` sentences,
    tokenized with padding/truncation to ``token_length`` tokens, moved to
    ``device`` and passed through the model under ``torch.no_grad()``.

    Exactly ``calibration_size`` samples are consumed from the dataset, and a
    trailing batch smaller than ``calibration_batch_size`` is still run rather
    than being silently dropped.

    Reads the notebook-level globals ``dataset``, ``device``,
    ``calibration_size``, ``calibration_batch_size`` and ``token_length``.

    Args:
        model: callable accepting the tokenizer output as keyword arguments
            plus a ``labels`` keyword.
        tokenizer: callable turning a list of strings into an encoding that
            supports ``.to(device=...)``, ``**`` unpacking and ``.input_ids``.

    Returns:
        None.
    """

    def run_batch(sentences):
        # One forward pass over a batch of raw sentences; the output is discarded.
        with torch.no_grad():
            encoded_input = tokenizer(sentences, return_tensors="pt",
                                      padding="max_length", max_length=token_length,
                                      truncation=True).to(device=device)
            model(**encoded_input, labels=encoded_input.input_ids)

    # islice stops after calibration_size samples, so no extra sample is pulled
    # from the dataset just to trigger the final flush.
    samples = islice(iter(dataset['train']), calibration_size)

    batch_sentences = []
    for data in tqdm(samples, total=calibration_size):
        batch_sentences.append(data['text'])
        if len(batch_sentences) >= calibration_batch_size:
            run_batch(batch_sentences)
            batch_sentences = []

    # Flush the remainder when calibration_size is not a multiple of the batch size.
    if batch_sentences:
        run_batch(batch_sentences)

In [4]:
from collections import OrderedDict
from typing import Dict, Callable
import torch

def remove_all_hooks(model: torch.nn.Module) -> None:
    """Drop every forward, forward-pre and backward hook from ``model``.

    Walks ``model.modules()`` (the module itself plus all of its descendants)
    and replaces each of the three hook registries with a fresh, empty
    ``OrderedDict``.

    The three registries are reset independently. Chaining them with ``elif``
    would only ever reset ``_forward_hooks``, because every ``nn.Module`` has
    that attribute, leaving pre-hooks and backward hooks installed.

    Args:
        model: root module whose hooks (and whose submodules' hooks) are cleared.

    Returns:
        None. The module tree is modified in place.
    """
    hook_registries = ("_forward_hooks", "_forward_pre_hooks", "_backward_hooks")
    for module in model.modules():
        for registry in hook_registries:
            if hasattr(module, registry):
                setattr(module, registry, OrderedDict())

In [5]:
from utils.iterative_prune_finetune import iterative_sparsegpt_prune_tune
from utils.prehook_utils import put_input_hooks,remove_all_hooks,  put_backward_hooks, check_whitelist
from utils.prune_utils import sparsegpt_prune, mask_lowest
from utils.finetune_utils import finetune_model_inplace
from utils.save_utils import mask_from_pruned, unmask_model
from fsdp_finetune import fsdp_finetune

def get_prop_zeros(model):
    """Return the fraction of entries equal to zero in one probe weight.

    Only ``model.get_decoder().layers[0].self_attn.k_proj.weight`` is
    inspected, so the value is a spot check of that single matrix rather than
    a whole-model statistic.

    Returns:
        A zero-dimensional tensor: (number of zero entries) / (number of entries).
    """
    k_proj_weight = model.get_decoder().layers[0].self_attn.k_proj.weight
    zero_count = (k_proj_weight == 0).sum()
    return zero_count / k_proj_weight.numel()

# Iterative prune -> fine-tune sweep.
#
# For every OPT checkpoint the same model instance is taken through the whole
# SPARSITIES schedule: calibrate, prune with sparsegpt_prune, fine-tune with
# fsdp_finetune, then save the state dict.

# Value forwarded to sparsegpt_prune as SPARSENESS at each step.
# NOTE(review): the recorded "proportion of zeros" output (~0.10 at 0.9, ~0.30
# at 0.7, ...) suggests this is the fraction of weights KEPT -- confirm against
# utils.prune_utils.
SPARSITIES = [1, 0.9, 0.7, 0.5, 0.3, 0.2]

# torch.save does not create missing directories, so make sure the target exists.
os.makedirs('pruned_models', exist_ok=True)

for model_size in ['opt-125m', 'opt-350m', 'opt-1.3b']:

    # Load tokenizer and model for this checkpoint.
    tokenizer = AutoTokenizer.from_pretrained(f'facebook/{model_size}', padding_side='left')
    # NOTE(review): output_attentions / output_hidden_states make every forward
    # pass return all attention maps and hidden states; confirm they are
    # actually needed, since they are also active during fine-tuning.
    model = OPTForCausalLM.from_pretrained(f'facebook/{model_size}',
                                           output_attentions=True,
                                           output_hidden_states=True).to(device=device)

    for SPARSENESS in SPARSITIES:
        print(f"sparsity {SPARSENESS}")

        # --- Calibration -------------------------------------------------
        # Register the input hooks exactly once per pass. Calling
        # put_input_hooks twice in a row would install every hook twice on the
        # same model before the calibration forward passes run.
        feature_hessians = {}
        all_hooks = put_input_hooks(model=model, features=feature_hessians,
                                    feature_storage_device='cpu')
        split_model_calibration(model, tokenizer)
        remove_all_hooks(model=model)

        # --- Pruning -----------------------------------------------------
        sparsegpt_prune(model=model, model_name=model_size,
                        feature_hessians=feature_hessians,
                        EPSILON=EPSILON, SPARSENESS=SPARSENESS, B=B, Bs=Bs, save_model=False)

        print(f"proportion of zeros: {get_prop_zeros(model)}")

        # The calibration statistics are no longer needed once pruning is done.
        del feature_hessians

        print(f"Memory allocated after pruning: {torch.cuda.memory_allocated()}")

        torch.cuda.empty_cache()

        # --- Fine-tuning -------------------------------------------------
        # presumably installs gradient hooks that keep pruned weights at zero
        # during training -- verify against utils.prehook_utils.
        back_hooks = put_backward_hooks(model=model)

        config = {"model": model, "lr": 2e-5, "num_epochs": 1,
                  "seed": 1, "batch_size": 16,
                  'model_name': model_size,
                  'sparsity': SPARSENESS, "train_steps": 100,
                  'max_step': 100, 'save_model': False}
        fsdp_finetune(config)
        # config holds a reference to the model; drop it so that `del model`
        # below is not defeated by this dict still pointing at the model.
        del config

        remove_all_hooks(model=model)

        print(f"proportion of zeros: {get_prop_zeros(model)}")

        # --- Save --------------------------------------------------------
        pruned_model_name = f'{model_size}-finetuned-{SPARSENESS}'
        torch.save(model.state_dict(), f'pruned_models/{pruned_model_name}-iterative.pt')
        print(f"Memory allocated after fine tuning: {torch.cuda.memory_allocated()}")

    # Release this checkpoint before loading the next, larger one. gc.collect()
    # runs first so objects kept alive only by reference cycles are freed
    # before the CUDA caching allocator is asked to return its unused blocks.
    del model
    gc.collect()
    torch.cuda.empty_cache()

sparsity 1


5it [00:12,  2.42s/it]                       
100%|██████████| 196/196 [00:50<00:00,  3.87it/s]


proportion of zeros: 0.0006510416860692203
Memory allocated after pruning: 502228992
DistributedType.NO


100%|██████████| 1/1 [00:17<00:00, 17.89s/it]

Memory before entering the train : 613
Memory consumed at the end of the train (end-begin): 1757
Peak Memory consumed during the train (max-begin): 2491
Total Peak Memory consumed during the train (max): 3104





proportion of zeros: 0.0006510416860692203
Memory allocated after fine tuning: 2167001088
sparsity 0.9


5it [00:13,  2.79s/it]                       
100%|██████████| 196/196 [00:51<00:00,  3.82it/s]


proportion of zeros: 0.1009114608168602
Memory allocated after pruning: 1164890112
DistributedType.NO


100%|██████████| 1/1 [00:19<00:00, 19.07s/it]

Memory before entering the train : 1247
Memory consumed at the end of the train (end-begin): 1276
Peak Memory consumed during the train (max-begin): 2001
Total Peak Memory consumed during the train (max): 3248





proportion of zeros: 0.0006510416860692203
Memory allocated after fine tuning: 2327171072
sparsity 0.7


5it [00:14,  2.80s/it]                       
100%|██████████| 196/196 [00:49<00:00,  3.95it/s]


proportion of zeros: 0.30078125
Memory allocated after pruning: 1309849600
DistributedType.NO


100%|██████████| 1/1 [00:20<00:00, 20.02s/it]

Memory before entering the train : 1384
Memory consumed at the end of the train (end-begin): 1275
Peak Memory consumed during the train (max-begin): 2002
Total Peak Memory consumed during the train (max): 3386





proportion of zeros: 0.0982937291264534
Memory allocated after fine tuning: 2469509120
sparsity 0.5


5it [00:13,  2.77s/it]                       
100%|██████████| 196/196 [00:49<00:00,  3.96it/s]


proportion of zeros: 0.5006510615348816
Memory allocated after pruning: 1451413504
DistributedType.NO


100%|██████████| 1/1 [00:21<00:00, 21.09s/it]

Memory before entering the train : 1519
Memory consumed at the end of the train (end-begin): 1276
Peak Memory consumed during the train (max-begin): 2002
Total Peak Memory consumed during the train (max): 3521





proportion of zeros: 0.2971649169921875
Memory allocated after fine tuning: 2610810880
sparsity 0.3


5it [00:13,  2.79s/it]                       
100%|██████████| 196/196 [00:49<00:00,  3.96it/s]


proportion of zeros: 0.701171875
Memory allocated after pruning: 1594537984
DistributedType.NO


100%|██████████| 1/1 [00:22<00:00, 22.11s/it]

Memory before entering the train : 1655
Memory consumed at the end of the train (end-begin): 1273
Peak Memory consumed during the train (max-begin): 2004
Total Peak Memory consumed during the train (max): 3659





proportion of zeros: 0.498291015625
Memory allocated after fine tuning: 2753142784
sparsity 0.2


5it [00:13,  2.78s/it]                       
100%|██████████| 196/196 [00:49<00:00,  3.99it/s]


proportion of zeros: 0.80078125
Memory allocated after pruning: 1735315456
DistributedType.NO


100%|██████████| 1/1 [00:23<00:00, 23.18s/it]

Memory before entering the train : 1789
Memory consumed at the end of the train (end-begin): 1275
Peak Memory consumed during the train (max-begin): 2010
Total Peak Memory consumed during the train (max): 3799





proportion of zeros: 0.6992950439453125
Memory allocated after fine tuning: 2893664256
sparsity 1


5it [00:45,  9.12s/it]                       
100%|██████████| 388/388 [02:33<00:00,  2.53it/s]


proportion of zeros: 0.00048828125
Memory allocated after pruning: 3203231744
DistributedType.NO


  0%|          | 0/1 [00:03<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 8.00 GiB total capacity; 7.25 GiB already allocated; 0 bytes free; 7.27 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF