In [None]:
import wandb
# Authenticate with Weights & Biases before opening a run.
wandb.login()
# Open the run that the evaluation cell reports its perplexity values to.
wandb.init(
    project="ICLR",
    name="Perplexity Test",
)

In [None]:
from tqdm import tqdm
from save_pruned_model import load_into_model
import torch
from torch.nn.utils import prune
from transformers import AutoTokenizer, OPTForCausalLM
from datasets import load_dataset
import gc

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model_name = "opt-125m"   # checkpoint is resolved as facebook/<model_name>
token_length=1024         # number of tokens in every evaluation sequence

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(f'facebook/{model_name}', padding_side='left')

# Full WikiText-2 test split; switch to split='test[:10%]' for a quick smoke run.
test_set = load_dataset('wikitext', 'wikitext-2-v1', split='test')
# Every row of the split is tokenized on its own, giving one id list per row.
# NOTE(review): presumably the tokenizer prepends its BOS token to each row
# (including empty ones) — confirm this matches the intended evaluation protocol.
tokenized_test = tokenizer(test_set['text'])


def _pack_into_blocks(sequences, block_len):
    """Concatenate ``sequences`` and cut the result into rows of ``block_len``.

    Args:
        sequences: iterable of lists of ints (one list per tokenized row).
        block_len: number of elements per output row.

    Returns:
        A CPU ``torch.long`` tensor of shape ``(n_blocks, block_len)``. The
        trailing elements that do not fill a complete row are dropped; when
        fewer than ``block_len`` elements exist the result has zero rows.
    """
    flat = [token for seq in sequences for token in seq]
    n_blocks = len(flat) // block_len
    # Build the integer tensor directly: going through torch.Tensor(...) would
    # create a float32 intermediate, which is only exact for values below 2**24.
    # Passing n_blocks explicitly (instead of -1) keeps the empty case valid.
    return torch.tensor(flat[:n_blocks * block_len], dtype=torch.long).reshape(n_blocks, block_len)


# Ids and masks are packed with the same routine so their rows stay aligned.
flattened_input_ids = _pack_into_blocks(tokenized_test.input_ids, token_length).to(device=device)
flattened_masks = _pack_into_blocks(tokenized_test.attention_mask, token_length).to(device=device)

test_dict = {'input_ids': flattened_input_ids, 'attention_mask': flattened_masks}

from torch.utils.data import Dataset, DataLoader, SequentialSampler

# Create dataset
class WikiSet(Dataset):
    '''
        Map-style dataset over pre-tokenized, fixed-length token blocks.

        Wraps a dict with two row-aligned, indexable entries, 'input_ids' and
        'attention_mask', and serves one row of each per index.
    '''

    def __init__(self, test_dict):
        # Only a reference is stored; nothing is copied or validated here.
        self.test_dict = test_dict

    def __len__(self):
        # One sample per row of the 'input_ids' entry.
        ids = self.test_dict['input_ids']
        return len(ids)

    def __getitem__(self, idx):
        # Hand back the idx-th (input_ids, attention_mask) pair as a tuple.
        return self.test_dict['input_ids'][idx], self.test_dict['attention_mask'][idx]
    
# Wrap the packed test tensors and walk them in their stored order.
wikiset = WikiSet(test_dict)
sampler = SequentialSampler(wikiset)
loader = DataLoader(
    dataset=wikiset,
    sampler=sampler,
    batch_size=16,
)
# Progress-bar view over the loader, consumed by the evaluation cell.
loop = tqdm(loader)

In [None]:
# Evaluate each checkpoint on the packed test blocks and log its perplexity.
# SPARSITY == 1 is a sentinel: no pruned weights are loaded for it, so that
# iteration measures the unmodified pretrained model.
for SPARSITY in [0.2, 0.4, 0.6, 0.8, 1]:
    # NOTE(review): attentions/hidden states are requested but not read in this
    # cell; kept in case load_into_model relies on the config — confirm and drop.
    loaded_model = OPTForCausalLM.from_pretrained(f'facebook/{model_name}', output_attentions=True, output_hidden_states=True).to(device=device) # type: ignore

    if SPARSITY != 1:
        load_into_model(loaded_model, f'pruned_models/{model_name}-{SPARSITY}.pt')

    # Replicate over the GPUs that actually exist instead of assuming four;
    # with zero or one GPU the plain model is used as-is.
    gpu_count = torch.cuda.device_count()
    if gpu_count > 1:
        loaded_model = torch.nn.DataParallel(loaded_model, device_ids=list(range(gpu_count)))
    loaded_model.eval()

    nll_sum = 0.0   # sum over batches of (mean loss * rows in batch)
    seq_count = 0   # number of sequences evaluated so far
    with torch.no_grad():
        # A fresh progress bar per pass: a tqdm object only displays on its
        # first iteration, so a shared one goes silent after the first sparsity.
        for input_ids, attention_mask in tqdm(loader, desc=f'sparsity={SPARSITY}'):
            # Only the loss is kept; not binding the full output lets the
            # attentions/hidden states be released right away.
            batch_loss = loaded_model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids).loss
            rows = input_ids.size(0)
            # DataParallel returns one loss per replica, hence .mean().
            # Rows all have the same length, so weighting by row count handles
            # a smaller final batch. (If replicas receive unequal chunks the
            # replica mean is a close approximation, not exact.)
            nll_sum += batch_loss.mean().item() * rows
            seq_count += rows

    # Perplexity is exp of the mean negative log-likelihood. The previous code
    # summed exp(loss) per batch and divided by the number of sequences, which
    # under-reported the value by roughly the batch size.
    perp_sparse = torch.exp(torch.tensor(nll_sum / seq_count)).item()
    wandb.log({"sparse_perplexity": perp_sparse, 'sparsity': SPARSITY})

    # Release this checkpoint before the next one is loaded.
    del loaded_model
    gc.collect()
    torch.cuda.empty_cache()