In [1]:
import torch
from torch.nn.utils import prune

from tqdm import tqdm

from transformers import AutoTokenizer, OPTForCausalLM, pipeline
from datasets import load_dataset

from calculate_mask import calculate_mask
from inverse_hessian import calc_inverse_hessian
from input_prehooks import put_input_hooks
import gc

In [2]:
# Run the model on batches of calibration data so the registered input hooks can
# accumulate their statistics (the model outputs themselves are discarded).
def split_model_calibration(model):
    """Feed the first `calibration_size` training texts through `model` in batches.

    The forward passes exist only to trigger the input hooks installed
    beforehand (via `put_input_hooks`); nothing is returned.

    Reads notebook-level globals: `dataset`, `tokenizer`, `device`,
    `calibration_size`, `calibration_batch_size`, `token_length`.

    Exactly `calibration_size` texts are processed (fewer if the dataset is
    shorter). The final batch may be smaller than `calibration_batch_size`
    when `calibration_size` is not a multiple of it, but it is never dropped.

    Args:
        model: callable causal LM (possibly wrapped in DataParallel) accepting
            the tokenizer's encoded batch plus `labels`.
    """
    from itertools import islice

    def _run_batch(texts):
        # Forward only: calibration needs no gradients.
        with torch.no_grad():
            encoded_input = tokenizer(texts, return_tensors="pt",
                                      padding="max_length", max_length=token_length,
                                      truncation=True).to(device=device)
            model(**encoded_input, labels=encoded_input.input_ids)
        torch.cuda.empty_cache()

    # islice bounds the (possibly streaming) dataset to exactly calibration_size items.
    samples = islice(iter(dataset['train']), calibration_size)
    batch_sentences = []
    for data in tqdm(samples, total=calibration_size):
        batch_sentences.append(data['text'])
        if len(batch_sentences) >= calibration_batch_size:
            _run_batch(batch_sentences)
            batch_sentences = []
    # Flush the trailing partial batch instead of silently discarding it.
    if batch_sentences:
        _run_batch(batch_sentences)

# Sparsify and Finetune Model

In [3]:
# --- Calibration settings -------------------------------------------------
# Number of C4 training texts pushed through the model to collect statistics,
# how many texts go into each forward pass, and the padded/truncated length.
calibration_size = 128
calibration_batch_size = 1
token_length = 1024

# --- Pruning hyperparameters ----------------------------------------------
# Forwarded unchanged to iterative_sparsegpt_prune_tune.
# NOTE(review): B / Bs presumably are SparseGPT block sizes — confirm there.
EPSILON = 1e-8
B = 128
Bs = 128

# TODO: temporary hyperparameter test — remove later.
EPOCH_COUNT = 10

# --- Runtime / model ------------------------------------------------------
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

model_name = "opt-2.7b"
tokenizer = AutoTokenizer.from_pretrained(f'facebook/{model_name}')

# English C4, opened in streaming mode so it is read lazily, not downloaded up front.
dataset = load_dataset('c4', 'en', streaming=True)

In [None]:
from testing_module import finetune_model
from iterative_prune_finetune import iterative_sparsegpt_prune_tune

# Target sparsities; each is passed to iterative_sparsegpt_prune_tune on a
# freshly loaded copy of the pretrained model.
# Full sweep previously listed: 0.1, 0.2, 0.3, 0.5, 0.7, 0.9, 1
SPARSITIES = [0.2, 0.3, 0.5, 0.7, 0.9, 1]

# Replicate across every visible GPU instead of a hard-coded [0, 1, 2, 3]
# (which fails on machines with fewer than four GPUs). On a CPU-only machine
# this is None, in which case DataParallel just wraps the module.
GPU_IDS = list(range(torch.cuda.device_count())) or None

for i, SPARSITY in enumerate(tqdm(SPARSITIES, total=len(SPARSITIES))):
    # Drop the previous iteration's model *before* loading the next one, so the
    # old copy can be freed instead of coexisting with the new one in GPU memory.
    model = None
    gc.collect()
    torch.cuda.empty_cache()

    # Load model with pre-trained head
    model = OPTForCausalLM.from_pretrained(f'facebook/{model_name}', output_attentions=True,
                                           output_hidden_states=True).to(device=device) # type: ignore
    model = torch.nn.DataParallel(model, device_ids=GPU_IDS)

    if i == 0:
        # Calibration runs once, on the first (not yet pruned) model; the
        # resulting feature_hessians are reused for every later sparsity level.
        # NOTE(review): the hooks stay attached to this first model while it is
        # pruned/fine-tuned below — confirm those extra forward passes do not
        # alter feature_hessians.
        feature_hessians = {}
        #put_input_hooks(model=model, features=feature_hessians, storage_dir=storage_dir, offload_freq=10000, feature_storage_device='cpu')
        put_input_hooks(model=model, features=feature_hessians, feature_storage_device='cpu')
        split_model_calibration(model)
    torch.cuda.empty_cache()
    iterative_sparsegpt_prune_tune(model=model, model_size=model_name,
                                   sparseness_sequence=[SPARSITY],
                                   feature_hessians=feature_hessians,
                                   EPSILON=EPSILON, B=B, Bs=Bs,
                                   tokenizer=tokenizer,
                                   EPOCH_COUNT=EPOCH_COUNT)

  0%|          | 0/6 [00:00<?, ?it/s]
  0%|          | 0/128 [00:00<?, ?it/s][A

  2%|▏         | 2/128 [00:04<05:21,  2.55s/it][A
  2%|▏         | 3/128 [00:09<07:31,  3.61s/it][A
  3%|▎         | 4/128 [00:14<08:28,  4.10s/it][A
  4%|▍         | 5/128 [00:18<08:58,  4.38s/it][A
  5%|▍         | 6/128 [00:23<09:14,  4.55s/it][A
  5%|▌         | 7/128 [00:28<09:22,  4.65s/it][A
  6%|▋         | 8/128 [00:33<09:25,  4.72s/it][A
  7%|▋         | 9/128 [00:38<09:26,  4.76s/it][A
  8%|▊         | 10/128 [00:43<09:25,  4.79s/it][A
  9%|▊         | 11/128 [00:48<09:23,  4.82s/it][A
  9%|▉         | 12/128 [00:53<09:20,  4.83s/it][A
 10%|█         | 13/128 [00:57<09:16,  4.84s/it][A
 11%|█         | 14/128 [01:02<09:12,  4.85s/it][A
 12%|█▏        | 15/128 [01:07<09:08,  4.85s/it][A
 12%|█▎        | 16/128 [01:12<09:03,  4.86s/it][A
 13%|█▎        | 17/128 [01:17<09:00,  4.87s/it][A
 14%|█▍        | 18/128 [01:22<08:55,  4.87s/it][A
 15%|█▍        | 19/128 [01:27<08:50,  4.8

# Evaluate each sparsity level on WikiText-2

import wandb
wandb.login()

import numpy as np
from testing_module import test_model

# Evaluation settings. eval_* names are used so that running this cell does
# not overwrite the `model_name` / `tokenizer` globals read by the calibration
# and pruning cells above.
# NOTE(review): this evaluates "opt-350m" while the pruning cell uses
# "opt-2.7b" — confirm the mismatch is intentional.
eval_model_name = "opt-350m"
eval_token_length = 1024
stride = 512

# Load tokenizer
eval_tokenizer = AutoTokenizer.from_pretrained(f'facebook/{eval_model_name}',
                                               padding_side='left',
                                               use_fast=False)
# Load dataset: the whole WikiText-2 test split, joined into one token stream.
test_set = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
encodings = eval_tokenizer("\n\n".join(test_set['text']), return_tensors='pt')
seq_len = encodings.input_ids.size(1)

EVAL_SPARSITIES = [0.2, 0.3, 0.5, 0.7, 0.9, 1]  # previously considered: 0.4, 0.6, 0.8


def evaluate_sparsities(is_finetuned):
    """Log one W&B run that calls test_model for every entry in EVAL_SPARSITIES.

    Args:
        is_finetuned: forwarded to test_model, and recorded in the run config
            as 'finetuned' / 'not finetuned'.
    """
    config = {'token_length': eval_token_length,
              'model_name': eval_model_name,
              'stride': stride,
              'fine_tuned': 'finetuned' if is_finetuned else 'not finetuned'}
    # The run is used as a context manager so it is finished before the next
    # run starts (and marked failed if test_model raises).
    with wandb.init(project="ICLR",
                    name=f'{eval_model_name} Wikitext Test',
                    config=config):
        for sparsity in EVAL_SPARSITIES:
            test_model(eval_model_name, encodings, eval_token_length, seq_len, stride,
                       wandb, sparsity, is_finetuned=is_finetuned)


evaluate_sparsities(is_finetuned=False)
### NOW DO FINETUNED
evaluate_sparsities(is_finetuned=True)

# Iteratively Prune

# Run the standalone SparseGPT pruning entry point. It is called with no
# arguments, so whichever model/sparsity it prunes is configured inside the
# SparseGPT_pruning module, not by this notebook's globals.
# NOTE(review): confirm which model it targets.
from SparseGPT_pruning import sparsegpt_prune

sparsegpt_prune()