In [1]:
import torch
from torch.nn.utils import prune

from tqdm import tqdm

from transformers import AutoTokenizer, OPTForCausalLM, pipeline
from datasets import load_dataset

from utils.prune_utils import sparsegpt_prune

In [None]:
# Compute device: prefer the GPU whenever CUDA is usable, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Open the English ("en") configuration of C4 in streaming mode.
# NOTE(review): recent `datasets` releases may expect the id 'allenai/c4' instead
# of the bare 'c4' -- confirm this id still resolves with the installed version.
dataset = load_dataset('c4', 'en', streaming=True)

In [None]:
# Calibration settings (calibration data is used to get the inputs to each layer).
# NOTE(review): presumably sample count, tokens per sample, and batch size used
# while collecting those inputs -- confirm against the code that consumes them.
calibration_size, token_length, calibration_batch_size = 128, 2048, 2

# Pruning hyperparameters.
# NOTE(review): these look like the dampening term and the two block sizes passed
# to sparsegpt_prune -- nothing in this notebook section reads them; verify.
EPSILON = 1e-2
B = Bs = 128

In [None]:
from iterative_prune_finetune import iterative_sparsegpt_prune_tune
from utils.prehook_utils import put_input_hooks,remove_all_hooks
from utils.prune_utils import sparsegpt_prune
from utils.finetune_utils import finetune_model_inplace
from utils.save_utils import unmask_model, load_masked_model_single
from utils.fsdp_finetune import fsdp_finetune

# Fine-tune every (model size, sparsity) combination, one model at a time.
#
# Each sparsity value selects the checkpoint
# 'pruned_models/<model_size>-<sparsity>.pt', which must already exist on disk.
# Other values noted by the author: 0.1, 0.2, 0.3, 0.5, 0.7, 0.9, 1
SPARSITIES = [1, 0.9, 0.7, 0.5, 0.3, 0.2]

for model_size in ['opt-1.3b']:
    for SPARSENESS in SPARSITIES:
        # Start from the stock pretrained checkpoint; attention maps and hidden
        # states are requested in the model outputs.
        model = OPTForCausalLM.from_pretrained(
            f'facebook/{model_size}',
            output_attentions=True,
            output_hidden_states=True,
        )

        # Restore the saved pruned state for this sparsity into `model`.
        # NOTE(review): presumably re-applies the pruning masks in place -- confirm
        # in utils.save_utils.
        load_masked_model_single(model, f'pruned_models/{model_size}-{SPARSENESS}.pt')
        torch.cuda.empty_cache()

        # Alternative fine-tuning path, kept for reference (currently disabled):
        #finetune_model_inplace(model=model, tokenizer=tokenizer,
        #                       SPARSITY=SPARSENESS, device=device, EPOCH_COUNT=1, max_step=1000)
        #unmask_model(model=model)

        # Settings handed to fsdp_finetune. Both 'train_steps' and 'max_step' are
        # set to 10000; which one the trainer actually honours is not visible here.
        config = {
            "model": model,
            "lr": 2e-5,
            "num_epochs": 1,
            "seed": 1,
            "batch_size": 16,
            'model_name': model_size,
            'sparsity': SPARSENESS,
            "train_steps": 10000,
            'max_step': 10000,
            'save_model': True,
        }
        fsdp_finetune(config)

        # `config["model"]` is a second strong reference to the model. Deleting only
        # `model` would keep the whole network alive until `config` is rebound in
        # the next iteration -- i.e. while the next checkpoint is already being
        # loaded -- and empty_cache() cannot release memory that is still in use.
        del config, model
        torch.cuda.empty_cache()