In [1]:
import torch
from torch.nn.utils import prune

from tqdm import tqdm

from transformers import AutoTokenizer, OPTForCausalLM, pipeline
from datasets import load_dataset

from calculate_mask import calculate_mask
from inverse_hessian import calc_inverse_hessian
from input_prehooks import put_input_hooks
import gc

In [2]:
model = OPTForCausalLM.from_pretrained(f'facebook/opt-125m', output_attentions=True,
                                           output_hidden_states=True) # type: ignore
model = torch.nn.DataParallel(model, device_ids=[0,1,2,3])

In [3]:
!nvidia-smi

Sat Feb 25 00:47:45 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.47.03    Driver Version: 510.47.03    CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100-PCI...  On   | 00000000:01:00.0 Off |                    0 |
| N/A   27C    P0    35W / 250W |      2MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A100-PCI...  On   | 00000000:41:00.0 Off |                    0 |
| N/A   28C    P0    37W / 250W |      2MiB / 40960MiB |      0%      Default |
|       

# Functions

In [4]:
# run model on batches of calibration data, then concatenate inputs
def split_model_calibration(model):
    batch_sentences = []
    for i, data in tqdm(enumerate(iter(dataset['train'])), total=calibration_size):
        if i < calibration_size + 1:
            if len(batch_sentences) >= calibration_batch_size:
                with torch.no_grad():
                    encoded_input = tokenizer(batch_sentences, return_tensors="pt",
                                              padding="max_length", max_length=token_length,
                                              truncation=True).to(device=device)
                    model(**encoded_input, labels=encoded_input.input_ids)
                    del encoded_input
                    torch.cuda.empty_cache()
                    batch_sentences = []
            batch_sentences.append(data['text'])
        else:
            break

In [6]:
# SparseGPT fine tune loop
import torch

def get_prop_zeros(model):
    return torch.sum(model.module.get_decoder().layers[0].self_attn.k_proj.weight == 0) / (torch.numel(model.module.get_decoder().layers[0].self_attn.k_proj.weight))

# now, fine tune loop
# sparseness is defined as proportion of nonzeros (opposite of intuitive)
# sparseness_sequence = [.9, .8, .7, .6, .5, .4, .3, .2]
# sparseness_sequence = [.9, .5, .4]

# model is model to iteratively tune and prune, will do so in place
# model_size is for naming the save files (like opt-125m)
# sparseness sequence is sequence of sparsities (.8 sparseness = 20% proportion of zeros)
# training_data and tokenizer are for fine_tuning (should already be preprocessed with torch.format and stuff)
def iterative_sparsegpt_prune_tune(model, model_size, sparseness_sequence, feature_hessians, EPSILON, B, Bs, tokenizer, EPOCH_COUNT):
    # for sparseness_index in range(len(sparseness_sequence)):
    for sparseness in sparseness_sequence:
        sparsegpt_prune(model=model, model_name=model_size, feature_hessians=feature_hessians, SPARSENESS=sparseness, EPSILON=EPSILON, B=B, Bs=Bs)
        #del model
        #torch.cuda.empty_cache()
        #fine_tune(model_name=model_name, EPOCH_COUNT=EPOCH_COUNT, tokenizer=tokenizer)

        # deactivate masks
        #unmask_model(model=model)
        #torch.cuda.empty_cache()
        #print(f"proportion of zeros: {get_prop_zeros(model)}")

        #pruned_model_name = f'{model_size}-finetuned-{sparseness}'
        #torch.save(model.state_dict(), f'pruned_models/{pruned_model_name}.pt')
        #del model
def iterative_cerebras_prune_tune(model, model_size, sparseness_sequence, training_data, tokenizer, EPOCH_COUNT):
    for sparseness in sparseness_sequence:

        mask_lowest(model=model, amount=1-sparseness)

        # activate masks
        mask_from_pruned(model=model)

        fine_tune(model=model, training_data=training_data, EPOCH_COUNT=EPOCH_COUNT, tokenizer=tokenizer)

        # deactivate masks
        unmask_model(model=model)

        print(f"proportion of zeros: {get_prop_zeros(model)}")

        pruned_model_name = f'{model_size}-cerebras-tune-and-prune-{sparseness}'
        torch.save(model.state_dict(), f'pruned_models/{pruned_model_name}.pt')

In [7]:
# LOAD PRUNED MODEL

import torch
from torch.nn.utils import prune
from tqdm import tqdm

device = 'cuda' if torch.cuda.is_available() else 'cpu'


# function to get module name from parameter name
def get_module_name(param_name):
    if param_name[-5:] == ".bias":
        return param_name[:-5], "bias"
    elif param_name[-7:] == ".weight":
        return param_name[:-7], "weight"
    
    elif param_name[-10:] == ".bias_orig":
        return param_name[:-10], "bias"
    elif param_name[-12:] == ".weight_orig":
        return param_name[:-12], "weight"
    else:
        return None, None

# load model without masks
def load_unmasked_model(existing_model, state_dict_path):
    existing_model.load_state_dict(torch.load(state_dict_path))

# prune 0s to a mask, to make training easier (ostensibly)
class ZeroPruning(prune.BasePruningMethod):
    PRUNING_TYPE = "unstructured"

    # default threshold is 0, prunes weights that are already 0 (for training)
    def __init__(self):
        pass

    def compute_mask(self, tensor, default_mask):
        return torch.abs(tensor) != 0

# apply pytorch mask in place of 0 weights to make backpropagation easier for training
default_opt_blacklist = ['model.decoder.embed_tokens', 'model.decoder.embed_positions']
def mask_from_pruned(model, module_blacklist=default_opt_blacklist):
    module_dict = {}
    for n, m in model.named_modules():
        module_dict[n] = m
    # print(module_dict.keys())
    
    parameter_list = []
    param_dict = {}
    for n, m in model.named_parameters():
        parameter_list.append(n)
        param_dict[n] = m
    # print(parameter_list)

    for n in parameter_list:
        module_name, param_type = get_module_name(n)

        # skip bias, embed, etc parameters
        if module_name in module_blacklist or module_name is None \
            or param_type is None or param_type!="weight":
            continue

        if len(param_dict[n].shape) < 2:
            continue

        ZeroPruning.apply(module=module_dict[module_name], name=param_type)
# unmask model with 0s in place
def unmask_model(model, module_blacklist=default_opt_blacklist):
    module_dict = {}
    for n, m in model.named_modules():
        module_dict[n] = m
    # print(module_dict.keys())
    
    parameter_list = []
    param_dict = {}
    for n, m in model.named_parameters():
        parameter_list.append(n)
        param_dict[n] = m
    # print(parameter_list)

    for n in parameter_list:
        module_name, param_type = get_module_name(n)

        # skip bias, embed, etc parameters
        if module_name in module_blacklist or module_name is None \
            or param_type is None or param_type!="weight":
            continue

        if len(param_dict[n].shape) < 2:
            continue
            
        prune.remove(module=module_dict[module_name], name=param_type)
        torch.cuda.clear_cache()

# load model with masks
def load_masked_model(existing_model, state_dict_path):

    # first load like normal
    load_unmasked_model(existing_model, state_dict_path)
    
    # then reapply the (previously removed) masks
    mask_from_pruned(model=existing_model)

    # prune.global_unstructured(
    #     existing_model.parameters(), pruning_method=ThresholdPruning, threshold=0
    # )

In [8]:
# manage imports
import torch
from transformers import DataCollatorForLanguageModeling,AutoTokenizer, OPTForCausalLM, pipeline
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
import gc

device = 'cuda' if torch.cuda.is_available() else 'cpu' 

def test_model(model_name, encodings, token_length, seq_len, stride, wandb, SPARSITY, is_finetuned=False, device=device):
    loaded_model = OPTForCausalLM.from_pretrained(f'facebook/{model_name}',
                                                  output_attentions=True,
                                                  output_hidden_states=True).to(device=device) # type: ignore
    loaded_model = torch.nn.DataParallel(loaded_model, device_ids=[0,1,2,3])
    load_unmasked_model(loaded_model, f'pruned_models/{model_name}-{SPARSITY}.pt')
    if is_finetuned:
        load_unmasked_model(loaded_model, 
                            f'pruned_models/{model_name}-finetuned-{SPARSITY}.pt')
    else:
        if SPARSITY != 1:
            load_unmasked_model(loaded_model, 
                            f'pruned_models/{model_name}-{SPARSITY}.pt')
    loaded_model.eval()
    _ = loaded_model(torch.randint(high=20, size=(1,10)))
    
    nlls = []
    prev_end_loc = 0
    for begin_loc in tqdm(range(0, seq_len, stride)):
        end_loc = min(begin_loc + token_length, seq_len)
        trg_len = end_loc - prev_end_loc
        input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device=device)
        target_ids = input_ids.clone()
        target_ids[:,:-trg_len] = -100
        
        with torch.no_grad():
            outputs = loaded_model(input_ids, labels=target_ids)
            neg_log_likelihood = outputs.loss * trg_len
            
        nlls.append(neg_log_likelihood)
        
        prev_end_loc = end_loc
        if end_loc == seq_len:
            break
            
    ppl = torch.exp(torch.stack(nlls).sum() / end_loc)
    wandb.log({"perplexity": ppl, 'density': SPARSITY})
    
    del loaded_model
    gc.collect()
    torch.cuda.empty_cache()


def finetune_model(model_name, tokenizer, SPARSITY, device=device, EPOCH_COUNT=10):
    #encode tokens
    def encode_tok(examples):
        return tokenizer(examples['text'], truncation=True, padding='max_length')

    #stream c4, training split
    training_data = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train', streaming=True)
    #IMPORTANT: process data while streaming -> remove unnecessary columns in batches
    training_data = training_data.map(encode_tok, batched=True, remove_columns=["text"])

    #set data to tensor mode
    training_data = training_data.with_format("torch")

    #dataloader from dataloader (mlm=False when training without mask)
    reformatted_data = DataLoader(training_data, collate_fn=DataCollatorForLanguageModeling(tokenizer, mlm=False))
    
    loaded_model = OPTForCausalLM.from_pretrained(f'facebook/{model_name}',
                                                  output_attentions=True,
                                                  output_hidden_states=True).to(device=device) # type: ignore
    loaded_model = torch.nn.DataParallel(loaded_model, device_ids=[0,1,2,3])# activate masks
    
    if SPARSITY != 1:
        load_masked_model(loaded_model, f'pruned_models/{model_name}-{SPARSITY}.pt')
    loaded_model.eval()
    _ = loaded_model(torch.randint(high=20, size=(1,10)))
    mask_from_pruned(model=loaded_model)
    #training loop
    loaded_model.train().to(device)
    t_optim = torch.optim.AdamW(params=loaded_model.parameters(), lr=1e-5)
    for epoch in tqdm(range(EPOCH_COUNT)):
        training_data.set_epoch(epoch)
        for i, batch in enumerate(reformatted_data):
            if i == 5:
                break
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = loaded_model(**batch)
            loss = outputs[0]
            loss.backward()
            t_optim.step()
            t_optim.zero_grad()
    unmask_model(loaded_model)
    torch.save(loaded_model.state_dict(), f'pruned_models/{model_name}-{SPARSITY}-finetuned.pt')



In [9]:
from input_prehooks import get_feature_storage_name
import gc
import torch
from tqdm import tqdm
from torch.nn.utils import prune
import calculate_mask
import iterative_calculate_mask

opt_blacklist = ['module.model.decoder.embed_tokens', 'module.model.decoder.embed_positions']

#DEVICE
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Using calibration data (inputs to each intermediate weight layer)
# Iterate through named parameters, calculate inverse hessian and calculate mask

def get_module_name(param_name):
    if param_name[-5:] == ".bias":
        return param_name[:-5], "bias"
    elif param_name[-7:] == ".weight":
        return param_name[:-7], "weight"
    else:
        return None, None


def sparsegpt_prune(model, model_name, feature_hessians, 
EPSILON, SPARSENESS, B, Bs, module_blacklist=opt_blacklist, iterative=True):
    module_dict = {}
    for n, m in model.named_modules():
        module_dict[n] = m
    # print(module_dict.keys())
    
    param_names = []
    param_dict = {}
    for n, m in model.named_parameters():
        param_names.append(n)
        param_dict[n] = m
    # print(parameter_list)

    model.eval()
    with torch.no_grad():
        # for name in tqdm(param_names):
        for param_name in tqdm(param_names, total=len(param_names)):
            module_name, param_type = get_module_name(param_name)

            # skip bias, embed, etc parameters
            if module_name in module_blacklist or module_name is None \
                or param_type is None or param_type!="weight":
                continue

            if len(param_dict[param_name].shape) < 2:
                continue

            param = param_dict[param_name]

            #print(f"Doing layer {name}")
            # get layer input from features, key is get_feature_storage_name(module_name)
            # get_feature_storage_name(module_name) stores k_proj, v_proj, q_proj together
            # since they are the same input
            
            layer_hessian = feature_hessians[get_feature_storage_name(module_name)].to(device=device)

            # calculate inverse hessian
            # check if input is flattened e.g. from 8,512,768 to 4096,768
            inv_hess = calc_inverse_hessian(layer_hessian, epsilon=EPSILON)

            # calculate mask
            if iterative:
                mask = iterative_calculate_mask.calculate_mask(W=param, H_inv=inv_hess, p=SPARSENESS, B=B, Bs=Bs)
            else:
                mask = calculate_mask.calculate_mask(W=param, H_inv=inv_hess, p=SPARSENESS, B=B, Bs=Bs)

            # get module from lookup dictionary by module name
            module = module_dict[module_name]
            # apply mask
            prune.custom_from_mask(module=module, name=param_type, mask=mask)
            prune.remove(module=module, name=param_type)
            gc.collect()
            torch.cuda.empty_cache()  

    pruned_model_name = f'{model_name}-{SPARSENESS}'

    torch.save(model.state_dict(), f'pruned_models/{pruned_model_name}.pt')

# Sparsify and Finetune Model

In [10]:
calibration_size=128
token_length=1024
calibration_batch_size=1

EPSILON = 1e-8
B = 128 
Bs = 128

#hyperparam test, remove later
EPOCH_COUNT = 10

#set device
device = 'cuda' if torch.cuda.is_available() else 'cpu' 
model_name = "opt-1.3b"
#load tokenizer
tokenizer = AutoTokenizer.from_pretrained(f'facebook/{model_name}')
#Load dataset
dataset = load_dataset('c4', 'en', streaming=True)

In [11]:
#from testing_module import finetune_model
#from trainingv2 import fine_tune
SPARSITIES = [0.2,0.3,0.5,0.7,0.9,1]#0.1, 0.2,0.3,0.5,0.7,0.9,1
#encode tokens

    
for i, SPARSITY in enumerate(tqdm(SPARSITIES, total=len(SPARSITIES))):
    # Load model with pre-trained head
    '''model = OPTForCausalLM.from_pretrained(f'facebook/{model_name}', output_attentions=True,
                                           output_hidden_states=True).to(device=device) # type: ignore
    model = torch.nn.DataParallel(model, device_ids=[0,1,2,3])
    !nvidia-smi
    
    if i == 0:
        feature_hessians = {}
        #put_input_hooks(model=model, features=feature_hessians, storage_dir=storage_dir, offload_freq=10000, feature_storage_device='cpu')
        put_input_hooks(model=model, features=feature_hessians, feature_storage_device='cpu')
        split_model_calibration(model)
    torch.cuda.empty_cache()
    iterative_sparsegpt_prune_tune(model=model, model_size=model_name,
                                   sparseness_sequence=[SPARSITY],
                                   feature_hessians=feature_hessians,
                                   EPSILON=EPSILON, B=B, Bs=Bs,
                                   tokenizer=tokenizer,
                                   EPOCH_COUNT=EPOCH_COUNT)'''
    finetune_model(model_name=model_name, tokenizer=tokenizer, SPARSITY=SPARSITY)

  0%|          | 0/6 [00:00<?, ?it/s]
  0%|          | 0/10 [00:00<?, ?it/s][AAsking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no padding.
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
  0%|          | 0/10 [00:01<?, ?it/s]
  0%|          | 0/6 [00:27<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 64.00 MiB (GPU 0; 39.41 GiB total capacity; 38.02 GiB already allocated; 32.50 MiB free; 38.08 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

# Testing model

In [9]:
import wandb
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33maaquib111[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [10]:
import numpy as np
from testing_module import test_model

model_name = "opt-350m"
token_length=1024
stride = 512
wandb.init(project="ICLR", 
           name = f'{model_name} Wikitext Test', 
           config={'token_length': token_length,
                 'model_name': model_name,
                 'stride': stride,
                 'fine_tuned': 'not finetuned'})
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(f'facebook/{model_name}', 
                                          padding_side='left', 
                                          use_fast=False)
# Load dataset
test_set = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
encodings = tokenizer("\n\n".join(test_set['text']), return_tensors='pt')

seq_len = encodings.input_ids.size(1)
SPARSITIES = [0.2, 0.3, 0.5, 0.7, 0.9, 1]#, 0.4, 0.6, 0.8, 1

for SPARSITY in SPARSITIES:
    test_model(model_name, encodings, token_length, seq_len, stride, wandb, SPARSITY, is_finetuned=False)
    
### NOW DO FINETUNED
wandb.init(project="ICLR", 
           name = f'{model_name} Wikitext Test', 
           config={'token_length': token_length,
                 'model_name': model_name,
                 'stride': stride,
                 'fine_tuned': 'finetuned'})
for SPARSITY in SPARSITIES:
    test_model(model_name, encodings, token_length, seq_len, stride, wandb, SPARSITY, is_finetuned=True)

Found cached dataset wikitext (/gs/gsfs0/users/asyed/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)
100%|█████████▉| 560/562 [00:48<00:00, 11.55it/s]
100%|█████████▉| 560/562 [00:48<00:00, 11.53it/s]
100%|█████████▉| 560/562 [00:48<00:00, 11.46it/s]
100%|█████████▉| 560/562 [00:49<00:00, 11.40it/s]
100%|█████████▉| 560/562 [00:49<00:00, 11.38it/s]
100%|█████████▉| 560/562 [00:48<00:00, 11.47it/s]


VBox(children=(Label(value='0.005 MB of 0.005 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
density,▁▂▄▅▇█
perplexity,▃█▅▁▁▁

0,1
density,1.0
perplexity,20.87292


100%|█████████▉| 560/562 [00:48<00:00, 11.52it/s]
100%|█████████▉| 560/562 [00:48<00:00, 11.51it/s]
100%|█████████▉| 560/562 [00:48<00:00, 11.44it/s]
100%|█████████▉| 560/562 [00:49<00:00, 11.42it/s]
100%|█████████▉| 560/562 [00:49<00:00, 11.42it/s]
100%|█████████▉| 560/562 [00:49<00:00, 11.37it/s]


# Iteratively Prune

In [9]:
from SparseGPT_pruning import sparsegpt_prune

In [None]:
sparsegpt_prune()