In [1]:
import wandb
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33maaquib111[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [2]:
from tqdm import tqdm
from save_pruned_model import load_unmasked_model, load_masked_model
import torch
from torch.nn.utils import prune
from transformers import AutoTokenizer, OPTForCausalLM
from datasets import load_dataset
import gc

# Constants
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_name = "opt-125m"
token_length=1024
stride = 512

wandb.init(project="ICLR", name = f'{model_name} Finetuned Wikitext Test', config={'token_length': token_length,
                                                             'model_name': model_name,
                                                             'stride': stride,
                                                             'fine_tuned': 'finetuned'})
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(f'facebook/{model_name}', 
                                          padding_side='left', 
                                          use_fast=False)
# Load dataset
test_set = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
encodings = tokenizer("\n\n".join(test_set['text']), return_tensors='pt')

Found cached dataset wikitext (/gs/gsfs0/users/asyed/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)


In [3]:
import numpy as np

seq_len = encodings.input_ids.size(1)
SPARSITIES = [0.1, 0.2, 0.3, 0.5, 0.7, 0.9, 1]#, 0.4, 0.6, 0.8, 1

for SPARSITY in SPARSITIES:
    loaded_model = OPTForCausalLM.from_pretrained(f'facebook/{model_name}',
                                                  output_attentions=True,
                                                  output_hidden_states=True).to(device=device) # type: ignore
    loaded_model = torch.nn.DataParallel(loaded_model, device_ids=[0,1,2,3])


    load_unmasked_model(loaded_model, 
                            f'pruned_models/{model_name}-{SPARSITY}-finetuned.pt')
    loaded_model.eval()
    _ = loaded_model(torch.randint(high=20, size=(1,10)))
    
    nlls = []
    prev_end_loc = 0
    for begin_loc in tqdm(range(0, seq_len, stride)):
        end_loc = min(begin_loc + token_length, seq_len)
        trg_len = end_loc - prev_end_loc
        input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device=device)
        target_ids = input_ids.clone()
        target_ids[:,:-trg_len] = -100
        
        with torch.no_grad():
            outputs = loaded_model(input_ids, labels=target_ids)
            neg_log_likelihood = outputs.loss * trg_len
            
        nlls.append(neg_log_likelihood)
        
        prev_end_loc = end_loc
        if end_loc == seq_len:
            break
            
    ppl = torch.exp(torch.stack(nlls).sum() / end_loc)
    wandb.log({"perplexity": ppl, 'density': SPARSITY})
    
    del loaded_model
    gc.collect()
    torch.cuda.empty_cache()

RuntimeError: Error(s) in loading state_dict for DataParallel:
	Missing key(s) in state_dict: "module.model.decoder.embed_tokens.weight", "module.model.decoder.embed_positions.weight", "module.model.decoder.final_layer_norm.weight", "module.model.decoder.layers.0.self_attn.k_proj.weight", "module.model.decoder.layers.0.self_attn.v_proj.weight", "module.model.decoder.layers.0.self_attn.q_proj.weight", "module.model.decoder.layers.0.self_attn.out_proj.weight", "module.model.decoder.layers.0.self_attn_layer_norm.weight", "module.model.decoder.layers.0.fc1.weight", "module.model.decoder.layers.0.fc2.weight", "module.model.decoder.layers.0.final_layer_norm.weight", "module.model.decoder.layers.1.self_attn.k_proj.weight", "module.model.decoder.layers.1.self_attn.v_proj.weight", "module.model.decoder.layers.1.self_attn.q_proj.weight", "module.model.decoder.layers.1.self_attn.out_proj.weight", "module.model.decoder.layers.1.self_attn_layer_norm.weight", "module.model.decoder.layers.1.fc1.weight", "module.model.decoder.layers.1.fc2.weight", "module.model.decoder.layers.1.final_layer_norm.weight", "module.model.decoder.layers.2.self_attn.k_proj.weight", "module.model.decoder.layers.2.self_attn.v_proj.weight", "module.model.decoder.layers.2.self_attn.q_proj.weight", "module.model.decoder.layers.2.self_attn.out_proj.weight", "module.model.decoder.layers.2.self_attn_layer_norm.weight", "module.model.decoder.layers.2.fc1.weight", "module.model.decoder.layers.2.fc2.weight", "module.model.decoder.layers.2.final_layer_norm.weight", "module.model.decoder.layers.3.self_attn.k_proj.weight", "module.model.decoder.layers.3.self_attn.v_proj.weight", "module.model.decoder.layers.3.self_attn.q_proj.weight", "module.model.decoder.layers.3.self_attn.out_proj.weight", "module.model.decoder.layers.3.self_attn_layer_norm.weight", "module.model.decoder.layers.3.fc1.weight", "module.model.decoder.layers.3.fc2.weight", "module.model.decoder.layers.3.final_layer_norm.weight", "module.model.decoder.layers.4.self_attn.k_proj.weight", "module.model.decoder.layers.4.self_attn.v_proj.weight", "module.model.decoder.layers.4.self_attn.q_proj.weight", "module.model.decoder.layers.4.self_attn.out_proj.weight", "module.model.decoder.layers.4.self_attn_layer_norm.weight", "module.model.decoder.layers.4.fc1.weight", "module.model.decoder.layers.4.fc2.weight", "module.model.decoder.layers.4.final_layer_norm.weight", "module.model.decoder.layers.5.self_attn.k_proj.weight", "module.model.decoder.layers.5.self_attn.v_proj.weight", "module.model.decoder.layers.5.self_attn.q_proj.weight", "module.model.decoder.layers.5.self_attn.out_proj.weight", "module.model.decoder.layers.5.self_attn_layer_norm.weight", "module.model.decoder.layers.5.fc1.weight", "module.model.decoder.layers.5.fc2.weight", "module.model.decoder.layers.5.final_layer_norm.weight", "module.model.decoder.layers.6.self_attn.k_proj.weight", "module.model.decoder.layers.6.self_attn.v_proj.weight", "module.model.decoder.layers.6.self_attn.q_proj.weight", "module.model.decoder.layers.6.self_attn.out_proj.weight", "module.model.decoder.layers.6.self_attn_layer_norm.weight", "module.model.decoder.layers.6.fc1.weight", "module.model.decoder.layers.6.fc2.weight", "module.model.decoder.layers.6.final_layer_norm.weight", "module.model.decoder.layers.7.self_attn.k_proj.weight", "module.model.decoder.layers.7.self_attn.v_proj.weight", "module.model.decoder.layers.7.self_attn.q_proj.weight", "module.model.decoder.layers.7.self_attn.out_proj.weight", "module.model.decoder.layers.7.self_attn_layer_norm.weight", "module.model.decoder.layers.7.fc1.weight", "module.model.decoder.layers.7.fc2.weight", "module.model.decoder.layers.7.final_layer_norm.weight", "module.model.decoder.layers.8.self_attn.k_proj.weight", "module.model.decoder.layers.8.self_attn.v_proj.weight", "module.model.decoder.layers.8.self_attn.q_proj.weight", "module.model.decoder.layers.8.self_attn.out_proj.weight", "module.model.decoder.layers.8.self_attn_layer_norm.weight", "module.model.decoder.layers.8.fc1.weight", "module.model.decoder.layers.8.fc2.weight", "module.model.decoder.layers.8.final_layer_norm.weight", "module.model.decoder.layers.9.self_attn.k_proj.weight", "module.model.decoder.layers.9.self_attn.v_proj.weight", "module.model.decoder.layers.9.self_attn.q_proj.weight", "module.model.decoder.layers.9.self_attn.out_proj.weight", "module.model.decoder.layers.9.self_attn_layer_norm.weight", "module.model.decoder.layers.9.fc1.weight", "module.model.decoder.layers.9.fc2.weight", "module.model.decoder.layers.9.final_layer_norm.weight", "module.model.decoder.layers.10.self_attn.k_proj.weight", "module.model.decoder.layers.10.self_attn.v_proj.weight", "module.model.decoder.layers.10.self_attn.q_proj.weight", "module.model.decoder.layers.10.self_attn.out_proj.weight", "module.model.decoder.layers.10.self_attn_layer_norm.weight", "module.model.decoder.layers.10.fc1.weight", "module.model.decoder.layers.10.fc2.weight", "module.model.decoder.layers.10.final_layer_norm.weight", "module.model.decoder.layers.11.self_attn.k_proj.weight", "module.model.decoder.layers.11.self_attn.v_proj.weight", "module.model.decoder.layers.11.self_attn.q_proj.weight", "module.model.decoder.layers.11.self_attn.out_proj.weight", "module.model.decoder.layers.11.self_attn_layer_norm.weight", "module.model.decoder.layers.11.fc1.weight", "module.model.decoder.layers.11.fc2.weight", "module.model.decoder.layers.11.final_layer_norm.weight". 
	Unexpected key(s) in state_dict: "module.model.decoder.embed_tokens.weight_orig", "module.model.decoder.embed_tokens.weight_mask", "module.model.decoder.embed_positions.weight_orig", "module.model.decoder.embed_positions.weight_mask", "module.model.decoder.final_layer_norm.weight_orig", "module.model.decoder.final_layer_norm.weight_mask", "module.model.decoder.layers.0.self_attn.k_proj.weight_orig", "module.model.decoder.layers.0.self_attn.k_proj.weight_mask", "module.model.decoder.layers.0.self_attn.v_proj.weight_orig", "module.model.decoder.layers.0.self_attn.v_proj.weight_mask", "module.model.decoder.layers.0.self_attn.q_proj.weight_orig", "module.model.decoder.layers.0.self_attn.q_proj.weight_mask", "module.model.decoder.layers.0.self_attn.out_proj.weight_orig", "module.model.decoder.layers.0.self_attn.out_proj.weight_mask", "module.model.decoder.layers.0.self_attn_layer_norm.weight_orig", "module.model.decoder.layers.0.self_attn_layer_norm.weight_mask", "module.model.decoder.layers.0.fc1.weight_orig", "module.model.decoder.layers.0.fc1.weight_mask", "module.model.decoder.layers.0.fc2.weight_orig", "module.model.decoder.layers.0.fc2.weight_mask", "module.model.decoder.layers.0.final_layer_norm.weight_orig", "module.model.decoder.layers.0.final_layer_norm.weight_mask", "module.model.decoder.layers.1.self_attn.k_proj.weight_orig", "module.model.decoder.layers.1.self_attn.k_proj.weight_mask", "module.model.decoder.layers.1.self_attn.v_proj.weight_orig", "module.model.decoder.layers.1.self_attn.v_proj.weight_mask", "module.model.decoder.layers.1.self_attn.q_proj.weight_orig", "module.model.decoder.layers.1.self_attn.q_proj.weight_mask", "module.model.decoder.layers.1.self_attn.out_proj.weight_orig", "module.model.decoder.layers.1.self_attn.out_proj.weight_mask", "module.model.decoder.layers.1.self_attn_layer_norm.weight_orig", "module.model.decoder.layers.1.self_attn_layer_norm.weight_mask", "module.model.decoder.layers.1.fc1.weight_orig", "module.model.decoder.layers.1.fc1.weight_mask", "module.model.decoder.layers.1.fc2.weight_orig", "module.model.decoder.layers.1.fc2.weight_mask", "module.model.decoder.layers.1.final_layer_norm.weight_orig", "module.model.decoder.layers.1.final_layer_norm.weight_mask", "module.model.decoder.layers.2.self_attn.k_proj.weight_orig", "module.model.decoder.layers.2.self_attn.k_proj.weight_mask", "module.model.decoder.layers.2.self_attn.v_proj.weight_orig", "module.model.decoder.layers.2.self_attn.v_proj.weight_mask", "module.model.decoder.layers.2.self_attn.q_proj.weight_orig", "module.model.decoder.layers.2.self_attn.q_proj.weight_mask", "module.model.decoder.layers.2.self_attn.out_proj.weight_orig", "module.model.decoder.layers.2.self_attn.out_proj.weight_mask", "module.model.decoder.layers.2.self_attn_layer_norm.weight_orig", "module.model.decoder.layers.2.self_attn_layer_norm.weight_mask", "module.model.decoder.layers.2.fc1.weight_orig", "module.model.decoder.layers.2.fc1.weight_mask", "module.model.decoder.layers.2.fc2.weight_orig", "module.model.decoder.layers.2.fc2.weight_mask", "module.model.decoder.layers.2.final_layer_norm.weight_orig", "module.model.decoder.layers.2.final_layer_norm.weight_mask", "module.model.decoder.layers.3.self_attn.k_proj.weight_orig", "module.model.decoder.layers.3.self_attn.k_proj.weight_mask", "module.model.decoder.layers.3.self_attn.v_proj.weight_orig", "module.model.decoder.layers.3.self_attn.v_proj.weight_mask", "module.model.decoder.layers.3.self_attn.q_proj.weight_orig", "module.model.decoder.layers.3.self_attn.q_proj.weight_mask", "module.model.decoder.layers.3.self_attn.out_proj.weight_orig", "module.model.decoder.layers.3.self_attn.out_proj.weight_mask", "module.model.decoder.layers.3.self_attn_layer_norm.weight_orig", "module.model.decoder.layers.3.self_attn_layer_norm.weight_mask", "module.model.decoder.layers.3.fc1.weight_orig", "module.model.decoder.layers.3.fc1.weight_mask", "module.model.decoder.layers.3.fc2.weight_orig", "module.model.decoder.layers.3.fc2.weight_mask", "module.model.decoder.layers.3.final_layer_norm.weight_orig", "module.model.decoder.layers.3.final_layer_norm.weight_mask", "module.model.decoder.layers.4.self_attn.k_proj.weight_orig", "module.model.decoder.layers.4.self_attn.k_proj.weight_mask", "module.model.decoder.layers.4.self_attn.v_proj.weight_orig", "module.model.decoder.layers.4.self_attn.v_proj.weight_mask", "module.model.decoder.layers.4.self_attn.q_proj.weight_orig", "module.model.decoder.layers.4.self_attn.q_proj.weight_mask", "module.model.decoder.layers.4.self_attn.out_proj.weight_orig", "module.model.decoder.layers.4.self_attn.out_proj.weight_mask", "module.model.decoder.layers.4.self_attn_layer_norm.weight_orig", "module.model.decoder.layers.4.self_attn_layer_norm.weight_mask", "module.model.decoder.layers.4.fc1.weight_orig", "module.model.decoder.layers.4.fc1.weight_mask", "module.model.decoder.layers.4.fc2.weight_orig", "module.model.decoder.layers.4.fc2.weight_mask", "module.model.decoder.layers.4.final_layer_norm.weight_orig", "module.model.decoder.layers.4.final_layer_norm.weight_mask", "module.model.decoder.layers.5.self_attn.k_proj.weight_orig", "module.model.decoder.layers.5.self_attn.k_proj.weight_mask", "module.model.decoder.layers.5.self_attn.v_proj.weight_orig", "module.model.decoder.layers.5.self_attn.v_proj.weight_mask", "module.model.decoder.layers.5.self_attn.q_proj.weight_orig", "module.model.decoder.layers.5.self_attn.q_proj.weight_mask", "module.model.decoder.layers.5.self_attn.out_proj.weight_orig", "module.model.decoder.layers.5.self_attn.out_proj.weight_mask", "module.model.decoder.layers.5.self_attn_layer_norm.weight_orig", "module.model.decoder.layers.5.self_attn_layer_norm.weight_mask", "module.model.decoder.layers.5.fc1.weight_orig", "module.model.decoder.layers.5.fc1.weight_mask", "module.model.decoder.layers.5.fc2.weight_orig", "module.model.decoder.layers.5.fc2.weight_mask", "module.model.decoder.layers.5.final_layer_norm.weight_orig", "module.model.decoder.layers.5.final_layer_norm.weight_mask", "module.model.decoder.layers.6.self_attn.k_proj.weight_orig", "module.model.decoder.layers.6.self_attn.k_proj.weight_mask", "module.model.decoder.layers.6.self_attn.v_proj.weight_orig", "module.model.decoder.layers.6.self_attn.v_proj.weight_mask", "module.model.decoder.layers.6.self_attn.q_proj.weight_orig", "module.model.decoder.layers.6.self_attn.q_proj.weight_mask", "module.model.decoder.layers.6.self_attn.out_proj.weight_orig", "module.model.decoder.layers.6.self_attn.out_proj.weight_mask", "module.model.decoder.layers.6.self_attn_layer_norm.weight_orig", "module.model.decoder.layers.6.self_attn_layer_norm.weight_mask", "module.model.decoder.layers.6.fc1.weight_orig", "module.model.decoder.layers.6.fc1.weight_mask", "module.model.decoder.layers.6.fc2.weight_orig", "module.model.decoder.layers.6.fc2.weight_mask", "module.model.decoder.layers.6.final_layer_norm.weight_orig", "module.model.decoder.layers.6.final_layer_norm.weight_mask", "module.model.decoder.layers.7.self_attn.k_proj.weight_orig", "module.model.decoder.layers.7.self_attn.k_proj.weight_mask", "module.model.decoder.layers.7.self_attn.v_proj.weight_orig", "module.model.decoder.layers.7.self_attn.v_proj.weight_mask", "module.model.decoder.layers.7.self_attn.q_proj.weight_orig", "module.model.decoder.layers.7.self_attn.q_proj.weight_mask", "module.model.decoder.layers.7.self_attn.out_proj.weight_orig", "module.model.decoder.layers.7.self_attn.out_proj.weight_mask", "module.model.decoder.layers.7.self_attn_layer_norm.weight_orig", "module.model.decoder.layers.7.self_attn_layer_norm.weight_mask", "module.model.decoder.layers.7.fc1.weight_orig", "module.model.decoder.layers.7.fc1.weight_mask", "module.model.decoder.layers.7.fc2.weight_orig", "module.model.decoder.layers.7.fc2.weight_mask", "module.model.decoder.layers.7.final_layer_norm.weight_orig", "module.model.decoder.layers.7.final_layer_norm.weight_mask", "module.model.decoder.layers.8.self_attn.k_proj.weight_orig", "module.model.decoder.layers.8.self_attn.k_proj.weight_mask", "module.model.decoder.layers.8.self_attn.v_proj.weight_orig", "module.model.decoder.layers.8.self_attn.v_proj.weight_mask", "module.model.decoder.layers.8.self_attn.q_proj.weight_orig", "module.model.decoder.layers.8.self_attn.q_proj.weight_mask", "module.model.decoder.layers.8.self_attn.out_proj.weight_orig", "module.model.decoder.layers.8.self_attn.out_proj.weight_mask", "module.model.decoder.layers.8.self_attn_layer_norm.weight_orig", "module.model.decoder.layers.8.self_attn_layer_norm.weight_mask", "module.model.decoder.layers.8.fc1.weight_orig", "module.model.decoder.layers.8.fc1.weight_mask", "module.model.decoder.layers.8.fc2.weight_orig", "module.model.decoder.layers.8.fc2.weight_mask", "module.model.decoder.layers.8.final_layer_norm.weight_orig", "module.model.decoder.layers.8.final_layer_norm.weight_mask", "module.model.decoder.layers.9.self_attn.k_proj.weight_orig", "module.model.decoder.layers.9.self_attn.k_proj.weight_mask", "module.model.decoder.layers.9.self_attn.v_proj.weight_orig", "module.model.decoder.layers.9.self_attn.v_proj.weight_mask", "module.model.decoder.layers.9.self_attn.q_proj.weight_orig", "module.model.decoder.layers.9.self_attn.q_proj.weight_mask", "module.model.decoder.layers.9.self_attn.out_proj.weight_orig", "module.model.decoder.layers.9.self_attn.out_proj.weight_mask", "module.model.decoder.layers.9.self_attn_layer_norm.weight_orig", "module.model.decoder.layers.9.self_attn_layer_norm.weight_mask", "module.model.decoder.layers.9.fc1.weight_orig", "module.model.decoder.layers.9.fc1.weight_mask", "module.model.decoder.layers.9.fc2.weight_orig", "module.model.decoder.layers.9.fc2.weight_mask", "module.model.decoder.layers.9.final_layer_norm.weight_orig", "module.model.decoder.layers.9.final_layer_norm.weight_mask", "module.model.decoder.layers.10.self_attn.k_proj.weight_orig", "module.model.decoder.layers.10.self_attn.k_proj.weight_mask", "module.model.decoder.layers.10.self_attn.v_proj.weight_orig", "module.model.decoder.layers.10.self_attn.v_proj.weight_mask", "module.model.decoder.layers.10.self_attn.q_proj.weight_orig", "module.model.decoder.layers.10.self_attn.q_proj.weight_mask", "module.model.decoder.layers.10.self_attn.out_proj.weight_orig", "module.model.decoder.layers.10.self_attn.out_proj.weight_mask", "module.model.decoder.layers.10.self_attn_layer_norm.weight_orig", "module.model.decoder.layers.10.self_attn_layer_norm.weight_mask", "module.model.decoder.layers.10.fc1.weight_orig", "module.model.decoder.layers.10.fc1.weight_mask", "module.model.decoder.layers.10.fc2.weight_orig", "module.model.decoder.layers.10.fc2.weight_mask", "module.model.decoder.layers.10.final_layer_norm.weight_orig", "module.model.decoder.layers.10.final_layer_norm.weight_mask", "module.model.decoder.layers.11.self_attn.k_proj.weight_orig", "module.model.decoder.layers.11.self_attn.k_proj.weight_mask", "module.model.decoder.layers.11.self_attn.v_proj.weight_orig", "module.model.decoder.layers.11.self_attn.v_proj.weight_mask", "module.model.decoder.layers.11.self_attn.q_proj.weight_orig", "module.model.decoder.layers.11.self_attn.q_proj.weight_mask", "module.model.decoder.layers.11.self_attn.out_proj.weight_orig", "module.model.decoder.layers.11.self_attn.out_proj.weight_mask", "module.model.decoder.layers.11.self_attn_layer_norm.weight_orig", "module.model.decoder.layers.11.self_attn_layer_norm.weight_mask", "module.model.decoder.layers.11.fc1.weight_orig", "module.model.decoder.layers.11.fc1.weight_mask", "module.model.decoder.layers.11.fc2.weight_orig", "module.model.decoder.layers.11.fc2.weight_mask", "module.model.decoder.layers.11.final_layer_norm.weight_orig", "module.model.decoder.layers.11.final_layer_norm.weight_mask". 