In [1]:
import wandb
# Authenticate with Weights & Biases so the `wandb.init` call in the next cell can log runs;
# when credentials are already stored, this just reports the current login (see cell output).
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33maaquib111[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [2]:
from tqdm import tqdm
from save_pruned_model import load_into_model
import torch
from torch.nn.utils import prune
from transformers import AutoTokenizer, OPTForCausalLM
from datasets import load_dataset
import gc

# Constants
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_name = "opt-125m"
# Evaluation window in tokens. OPT checkpoints use learned positional embeddings with
# max_position_embeddings=2048, so a 4096-token window would index past that table.
token_length = 2048
# How far the window advances each step; equal to token_length => non-overlapping windows.
stride = token_length

wandb.init(project="ICLR", name=f'{model_name} Wikitext Test', config={'token_length': token_length,
                                                                       'model_name': model_name,
                                                                       'stride': stride})
# Load tokenizer (slow/Python tokenizer, left padding)
tokenizer = AutoTokenizer.from_pretrained(f'facebook/{model_name}',
                                          padding_side='left',
                                          use_fast=False)
# Load the raw WikiText-2 test split and tokenize it as ONE long sequence
# (entries joined by blank lines) for sliding-window perplexity evaluation.
test_set = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
encodings = tokenizer("\n\n".join(test_set['text']), return_tensors='pt')

Found cached dataset wikitext (/gs/gsfs0/users/asyed/.cache/huggingface/datasets/wikitext/wikitext-2-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)
  0%|                                                    | 0/19 [00:00<?, ?it/s]

tokenized_test = tokenizer(test_set['text'])


def _flatten_to_chunks(per_line, chunk_len):
    """Concatenate per-line token lists and reshape them into (num_chunks, chunk_len).

    Trailing tokens that do not fill a whole chunk are dropped. The tensor is built
    directly as int64 so token ids are never rounded by a float32 round-trip.

    Args:
        per_line: iterable of integer lists (e.g. tokenizer `input_ids` per text line).
        chunk_len: number of tokens per output row; must be > 0.

    Returns:
        LongTensor of shape (num_chunks, chunk_len) on `device`.
    """
    flat = [tok for line in per_line for tok in line]
    usable = len(flat) - len(flat) % chunk_len
    return torch.tensor(flat[:usable], dtype=torch.long).reshape(-1, chunk_len).to(device=device)


flattened_input_ids = _flatten_to_chunks(tokenized_test.input_ids, token_length)
flattened_masks = _flatten_to_chunks(tokenized_test.attention_mask, token_length)

test_dict = {'input_ids': flattened_input_ids, 'attention_mask': flattened_masks}

from torch.utils.data import Dataset, DataLoader, SequentialSampler

# Create dataset
class WikiSet(Dataset):
    '''
    Map-style dataset over pre-chunked token windows.

    Wraps a dict holding two indexable sequences of equal length under the keys
    ``'input_ids'`` and ``'attention_mask'``; item ``idx`` is the pair
    ``(input_ids[idx], attention_mask[idx])``.

    Raises:
        ValueError: if the two sequences have different lengths.
    '''

    def __init__(self, test_dict):
        n_ids = len(test_dict['input_ids'])
        n_mask = len(test_dict['attention_mask'])
        # Fail fast here rather than with an IndexError partway through iteration.
        if n_ids != n_mask:
            raise ValueError(
                f"input_ids and attention_mask must have the same length, got {n_ids} and {n_mask}"
            )
        self.test_dict = test_dict

    def __len__(self):
        return len(self.test_dict['input_ids'])

    def __getitem__(self, idx):
        input_ids = self.test_dict['input_ids'][idx]
        attention_mask = self.test_dict['attention_mask'][idx]
        return input_ids, attention_mask
    
# `batch_size` was referenced here but never defined in this notebook (NameError on a
# fresh kernel). One token_length window per batch.
batch_size = 1
wikiset = WikiSet(test_dict)
sampler = SequentialSampler(wikiset)
loader = DataLoader(wikiset, batch_size=batch_size, sampler=sampler)
loop = tqdm(loader)

In [None]:
import numpy as np

import warnings

seq_len = encodings.input_ids.size(1)
# Checkpoints to evaluate. Each value is logged to wandb as 'density' and selects
# pruned_models/{model_name}-{value}.pt; 1 means the unpruned pretrained weights.
SPARSITIES = [0.1, 0.2, 0.3, 0.4, 0.5, 0.7, 0.9, 1]


def _sliding_window_perplexity(model, all_input_ids, window, step):
    """Perplexity of a causal LM over one long token sequence, via a sliding window.

    Follows the Hugging Face "Perplexity of fixed-length models" recipe: each forward
    pass sees at most `window` tokens, and only tokens not already scored by a previous
    window contribute to the loss (the overlapping prefix serves as context only).

    Args:
        model: Hugging Face causal LM that returns `.loss` when given `labels`.
        all_input_ids: LongTensor of shape (1, seq_len), e.g. `encodings.input_ids`.
        window: maximum number of tokens fed to the model per forward pass.
        step: how far the window advances each iteration; must be in [1, window]
            so that no tokens are skipped.

    Returns:
        Perplexity as a Python float.

    Raises:
        ValueError: if the sequence is empty or `step` is outside [1, window].
    """
    total_len = all_input_ids.size(1)
    if total_len == 0:
        raise ValueError("cannot compute perplexity of an empty sequence")
    if not 0 < step <= window:
        raise ValueError(f"step must be in [1, {window}], got {step}")

    nlls = []
    prev_end_loc = 0
    end_loc = 0
    for begin_loc in tqdm(range(0, total_len, step)):
        end_loc = min(begin_loc + window, total_len)
        trg_len = end_loc - prev_end_loc  # tokens not yet scored by an earlier window
        input_ids = all_input_ids[:, begin_loc:end_loc].to(model.device)
        target_ids = input_ids.clone()
        target_ids[:, :-trg_len] = -100  # context-only prefix is ignored by the loss

        with torch.no_grad():
            outputs = model(input_ids, labels=target_ids)
        # `loss` is a mean over the scored targets; rescale it to a summed NLL.
        nlls.append(outputs.loss * trg_len)

        prev_end_loc = end_loc
        if end_loc == total_len:
            break

    return torch.exp(torch.stack(nlls).sum() / end_loc).item()


for SPARSITY in SPARSITIES:
    # Attentions / hidden states are not needed for the loss; requesting them would keep
    # every layer's attention maps alive in memory for each window.
    loaded_model = OPTForCausalLM.from_pretrained(f'facebook/{model_name}').to(device=device) # type: ignore

    if SPARSITY != 1:
        load_into_model(loaded_model, f'pruned_models/{model_name}-{SPARSITY}.pt')
    # Not wrapped in DataParallel: each forward pass is a single window (batch of 1), so
    # extra GPUs would only hold idle replicas, and fixed device_ids break on smaller hosts.
    loaded_model.eval()
    with torch.no_grad():
        # Warm-up forward pass; presumably needed after load_into_model's pruning
        # setup — TODO confirm and drop if not.
        _ = loaded_model(torch.randint(high=20, size=(1, 10), device=device))

    # The model cannot accept more positions than its positional-embedding table holds.
    max_positions = loaded_model.config.max_position_embeddings
    window = min(token_length, max_positions)
    if window < token_length:
        warnings.warn(f"token_length={token_length} exceeds max_position_embeddings="
                      f"{max_positions}; evaluating with {window}-token windows instead")

    ppl = _sliding_window_perplexity(loaded_model, encodings.input_ids, window, min(stride, window))
    wandb.log({"perplexity": ppl, 'density': SPARSITY})

    # Release this checkpoint before loading the next one.
    del loaded_model
    gc.collect()
    torch.cuda.empty_cache()