In [None]:
import os
from itertools import islice

import torch
from torch.nn.utils import prune

from tqdm import tqdm

from transformers import AutoTokenizer, OPTForCausalLM, pipeline
from datasets import load_dataset

from calculate_mask import calculate_mask
from inverse_hessian import inverse_hessian
from input_prehooks import put_input_hooks
from testing_module import calculate_perp

In [2]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Checkpoint shared by the tokenizer, the model and the generation pipeline below.
model_name = "facebook/opt-125m"

# Calibration corpus, opened in streaming mode.
dataset = load_dataset('c4', 'en', streaming=True)

# Tokenizer configured to pad on the left.
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')

# Causal LM that also returns attentions and hidden states, placed on `device`.
model = OPTForCausalLM.from_pretrained(
    model_name,
    output_attentions=True,
    output_hidden_states=True,
).to(device=device)  # type: ignore

# Text-generation pipeline built from the checkpoint name (it does not wrap `model` above).
generator = pipeline('text-generation', model=model_name)

In [3]:
# Calibrate model (get inputs to each layer with calibration data)

calibration_size=4
token_length=512
calibrate_on_cpu = False
calibration_batch_size=2

# First, put in forward hooks
features = {}
put_input_hooks(model=model, features=features, feature_storage_device='cpu')


# run model on batches of calibration data, then concatenate inputs
def split_model_calibration(model):
    """Run `model` over the first `calibration_size` training texts in mini-batches.

    The forward passes exist only to trigger the input hooks registered above,
    which fill `features`; the model outputs themselves are discarded.

    Reads the notebook globals `dataset`, `tokenizer`, `token_length`,
    `calibration_size` and `calibration_batch_size`.

    Args:
        model: the causal LM to calibrate; batches are moved to whatever device
            its parameters live on.
    """
    # Tokenizer output is created on the CPU, so it must follow the model's device.
    model_device = next(model.parameters()).device

    def run_batch(sentences):
        encoded_input = tokenizer(
            sentences,
            return_tensors="pt",
            padding="max_length",
            max_length=token_length,
            truncation=True,
        ).to(model_device)
        model(**encoded_input, labels=encoded_input.input_ids)

    batch_sentences = []
    # Take exactly `calibration_size` samples from the stream.
    samples = islice(dataset['train'], calibration_size)
    # Only forward activations are needed, so skip building the autograd graph.
    with torch.no_grad():
        for data in tqdm(samples, total=calibration_size):
            batch_sentences.append(data['text'])
            if len(batch_sentences) == calibration_batch_size:
                run_batch(batch_sentences)
                batch_sentences = []
        # Flush the remainder when calibration_size is not a multiple of the batch size,
        # so the requested number of samples is always used.
        if batch_sentences:
            run_batch(batch_sentences)

split_model_calibration(model)


model
model.decoder
model.decoder.embed_tokens
model.decoder.embed_positions
model.decoder.final_layer_norm
model.decoder.layers
model.decoder.layers.0
model.decoder.layers.0.self_attn
model.decoder.layers.0.self_attn.k_proj
model.decoder.layers.0.self_attn.v_proj
model.decoder.layers.0.self_attn.q_proj
model.decoder.layers.0.self_attn.out_proj
model.decoder.layers.0.activation_fn
model.decoder.layers.0.self_attn_layer_norm
model.decoder.layers.0.fc1
model.decoder.layers.0.fc2
model.decoder.layers.0.final_layer_norm
model.decoder.layers.1
model.decoder.layers.1.self_attn
model.decoder.layers.1.self_attn.k_proj
model.decoder.layers.1.self_attn.v_proj
model.decoder.layers.1.self_attn.q_proj
model.decoder.layers.1.self_attn.out_proj
model.decoder.layers.1.activation_fn
model.decoder.layers.1.self_attn_layer_norm
model.decoder.layers.1.fc1
model.decoder.layers.1.fc2
model.decoder.layers.1.final_layer_norm
model.decoder.layers.2
model.decoder.layers.2.self_attn
model.decoder.layers.2.self_

5it [00:02,  2.01it/s]


In [8]:
# Report the shape and dtype of every captured layer input.
# (The message used to say "shape" while actually printing the dtype.)
for k in features.keys():
    print(f"{k}, shape is {tuple(features[k].shape)}, dtype is {features[k].dtype}")


model
model.decoder
model.decoder.embed_tokens
model.decoder.embed_positions
model.decoder.final_layer_norm
model.decoder.layers
model.decoder.layers.0
model.decoder.layers.0.self_attn
model.decoder.layers.0.self_attn.k_proj
model.decoder.layers.0.self_attn.v_proj
model.decoder.layers.0.self_attn.q_proj
model.decoder.layers.0.self_attn.out_proj
model.decoder.layers.0.activation_fn
model.decoder.layers.0.self_attn_layer_norm
model.decoder.layers.0.fc1
model.decoder.layers.0.fc2
model.decoder.layers.0.final_layer_norm
model.decoder.layers.1
model.decoder.layers.1.self_attn
model.decoder.layers.1.self_attn.k_proj
model.decoder.layers.1.self_attn.v_proj
model.decoder.layers.1.self_attn.q_proj
model.decoder.layers.1.self_attn.out_proj
model.decoder.layers.1.activation_fn
model.decoder.layers.1.self_attn_layer_norm
model.decoder.layers.1.fc1
model.decoder.layers.1.fc2
model.decoder.layers.1.final_layer_norm
model.decoder.layers.2
model.decoder.layers.2.self_attn
model.decoder.layers.2.self_

In [5]:
# print(features["model.decoder.layers.1.self_attn.out_proj"])
# print(features["model.decoder.layers.1.self_attn_layer_norm"])
# print(torch.equal(features["model.decoder.layers.5.self_attn.k_proj"], features["model.decoder.layers.5.self_attn.q_proj"]))
# print(torch.equal(features["model.decoder.layers.1.final_layer_norm"], features["model.decoder.layers.1.fc1"]))
# print(features["model.decoder.layers.1.final_layer_norm"])
# print(features["model.decoder.layers.1.fc1"])

KeyError: 'model.decoder.layers.1.self_attn.out_proj'

# Sparsify Model

In [None]:
# Map every dotted module path to its module object so modules can be fetched by name.
module_lookup_dict = dict(model.named_modules())

# Pruning hyper-parameters: EPSILON is passed to inverse_hessian; SPARSENESS, B and Bs
# are passed to calculate_mask (as p, B and Bs) in the pruning cell.
EPSILON = 1e-8
SPARSENESS = 0.5
B = 2
Bs = 2

# Split a parameter name into (owning module name, parameter kind).
def get_module_name(param_name):
    """Return `(module_name, kind)` for a name ending in `.bias` or `.weight`.

    `kind` is the string "bias" or "weight" and `module_name` is everything
    before that suffix. Any other name yields `(None, None)`.
    """
    for kind in ("bias", "weight"):
        suffix = "." + kind
        if param_name.endswith(suffix):
            return param_name[:-len(suffix)], kind
    return None, None

In [None]:
from input_prehooks import get_feature_storage_name

# Parameters that are never pruned (embedding tables).
layer_blacklist = [
    'model.decoder.embed_tokens.weight',
    'model.decoder.embed_tokens.bias',
    'model.decoder.embed_positions.weight',
]

# For every remaining weight matrix: take the calibration inputs captured for its
# module, compute the inverse Hessian from them, derive a mask and apply it.

# Snapshot the parameters before the loop. prune.custom_from_mask re-registers the
# pruned parameter on its module, so the set of named parameters changes while the
# loop runs; iterating a fixed list of names avoids walking a changing collection.
param_lookup_dict = dict(model.named_parameters())
param_names = list(param_lookup_dict)

with torch.no_grad():
    # for name in tqdm(param_names):
    for name in param_names:
        param = param_lookup_dict[name]

        # Skip blacklisted embeddings and anything with fewer than 2 dimensions
        # (norm weights, biases).
        if name in layer_blacklist or param.dim() < 2:
            continue

        module_name, param_type = get_module_name(name)

        # Only weight parameters are masked.
        if param_type != "weight":
            continue

        print(f"Doing layer {name}")

        # get_feature_storage_name(module_name) is the key under which this module's
        # captured input lives; k_proj, v_proj and q_proj share one entry because
        # they receive the same input.
        layer_input = features[get_feature_storage_name(module_name)].to(device=device)

        # A 2-D input is already flattened (e.g. 4096,768 instead of 8,512,768).
        # In both cases the last two dimensions are swapped before the call.
        flattened = layer_input.dim() == 2
        transpose_dims = (0, 1) if flattened else (1, 2)
        inv_hess = inverse_hessian(
            torch.transpose(layer_input, *transpose_dims),
            epsilon=EPSILON,
            flattened=flattened,
        )

        # Compute the mask for this weight matrix.
        mask = calculate_mask(W=param, H_inv=inv_hess, p=SPARSENESS, B=B, Bs=Bs)

        # Look the owning module up by name and attach the mask to its weight.
        module = module_lookup_dict[module_name]
        prune.custom_from_mask(module=module, name=param_type, mask=mask)

# Save Pruned Model

In [None]:
# SAVE PRUNED MODEL
pruned_model_name = f'opt-125m-{SPARSENESS}'
# torch.save(model,'pruned_models/' + pruned_model_name)
# model.save_pretrained(save_directory = 'pruned_models/' + pruned_model_name)

# torch.save does not create missing directories, so make sure the target exists.
os.makedirs('pruned_models', exist_ok=True)
torch.save(model.state_dict(), f'pruned_models/{pruned_model_name}.pt')

In [None]:
# LOAD SAVED MODEL

from save_pruned_model import load_into_model
import torch
from torch.nn.utils import prune
from transformers import AutoTokenizer, OPTForCausalLM, pipeline
from datasets import load_dataset
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Fresh dense checkpoint that receives the saved pruned weights.
loaded_model = OPTForCausalLM.from_pretrained('facebook/opt-125m', output_attentions=True, output_hidden_states=True).to(device=device) # type: ignore

# The path must match the file written by the save cell, which is
# f'pruned_models/opt-125m-{SPARSENESS}.pt' with SPARSENESS = .5.
# It is spelled out here so this cell does not depend on variables from earlier cells;
# update it if SPARSENESS changes.
load_into_model(loaded_model, 'pruned_models/opt-125m-0.5.pt')

In [None]:
model_name = "facebook/opt-125m"

def get_prop_zeros(model):
    """Return the fraction of exactly-zero entries in one weight matrix of `model`.

    Only `model.get_decoder().layers[0].self_attn.k_proj.weight` is inspected, so
    this is a spot check of a single matrix rather than a whole-model figure.
    The result is a 0-dim tensor.
    """
    k_proj_weight = model.get_decoder().layers[0].self_attn.k_proj.weight
    zero_count = torch.sum(k_proj_weight == 0)
    nonzero_count = torch.sum(k_proj_weight != 0)
    return zero_count / (zero_count + nonzero_count)

# Fraction of zeros in the layer-0 k_proj weight: first for the model reloaded from disk,
# then for the model pruned in this session.
print(get_prop_zeros(loaded_model))
print(get_prop_zeros(model))

# Sample Testing

In [None]:
# Disabled dense-vs-sparse generation comparison, kept as a string literal so it does not run.
# (The literal was previously closed with four quotes, which is a SyntaxError.)
'''# REGULAR OUTPUT
dense_model = OPTForCausalLM.from_pretrained("facebook/opt-125m", output_attentions=True, output_hidden_states=True).to(device=device)
encoded_test_input = tokenizer('What did you just say to me? I will have you know', return_tensors="pt",
                                                                                    padding="max_length", 
                                                                                    max_length=token_length, 
                                                                                    truncation=True)
#print(encoded_test_input)
print('DENSE MODEL:')
with torch.no_grad():
    generated_ids = dense_model.generate(**encoded_test_input, max_new_tokens=30, num_beams=5, do_sample=True)
dense_output = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
print(f'\tOutput: {dense_output}')

print('SPARSE MODEL: ')
with torch.no_grad():
    generated_ids = model.generate(**encoded_test_input, max_new_tokens=30, num_beams=5, do_sample=True)
sparse_output = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print(f'\tOutput: {sparse_output}')'''

In [None]:
# Unpruned baseline to compare against the pruned and the reloaded model.
dense_model = OPTForCausalLM.from_pretrained("facebook/opt-125m", output_attentions=True, output_hidden_states=True).to(device=device)

# Tokenizer output is created on the CPU; move it to the device the models live on.
encoded_test_input = tokenizer('What did you just say to me? I will have you know', return_tensors="pt",
                                                                                    padding="max_length", 
                                                                                    max_length=token_length, 
                                                                                    truncation=True).to(device)

# The prompt is padded up to token_length. Padded positions are set to -100 so the
# loss ignores them; otherwise the padding dominates the reported perplexity.
sample_labels = encoded_test_input.input_ids.masked_fill(encoded_test_input.attention_mask == 0, -100)

# Evaluation only: no gradients needed.
with torch.no_grad():
    print(torch.exp(dense_model(**encoded_test_input, labels = sample_labels).loss))
    print(torch.exp(model(**encoded_test_input, labels = sample_labels).loss))
    print(torch.exp(loaded_model(**encoded_test_input, labels = sample_labels).loss))

# Perplexity Testing

In [None]:
test_set = load_dataset('wikitext', 'wikitext-2-v1', split='test[:10%]')
tokenized_test = tokenizer(test_set['text'])


def _pack_into_blocks(token_lists, block_len):
    """Concatenate per-line token lists and cut them into rows of `block_len`.

    Trailing tokens that do not fill a complete row are dropped.

    Args:
        token_lists: iterable of lists of ints (one list per tokenized line).
        block_len: number of tokens per row.

    Returns:
        An int64 tensor of shape (number_of_full_rows, block_len).
    """
    flat = [item for sublist in token_lists for item in sublist]
    n_rows = len(flat) // block_len
    # Build an integer tensor directly: torch.Tensor(...) would go through float32,
    # and int32 ids are rejected as labels by the loss, which expects int64.
    # The row count is given explicitly so that zero full rows yields an empty
    # (0, block_len) tensor instead of an ambiguous reshape(-1, ...).
    return torch.tensor(flat[:n_rows * block_len], dtype=torch.long).reshape(n_rows, block_len)


flattened_input_ids = _pack_into_blocks(tokenized_test.input_ids, token_length)
flattened_masks = _pack_into_blocks(tokenized_test.attention_mask, token_length)

test_dict = {'input_ids': flattened_input_ids, 'attention_mask': flattened_masks}

In [None]:
model.eval()
dense_model.eval()

# Move the packed test blocks to the models' device and make them int64:
# the labels must be int64 for the loss, and the models are not on the CPU when CUDA is used.
eval_inputs = {key: tensor.to(device=device, dtype=torch.long) for key, tensor in test_dict.items()}

# Evaluation only: skip autograd bookkeeping for the whole test batch.
with torch.no_grad():
    output = model(**eval_inputs, labels=eval_inputs['input_ids'])
    output2 = dense_model(**eval_inputs, labels=eval_inputs['input_ids'])

In [None]:
# Exponential of the loss returned by each forward pass:
# `output` comes from `model` (pruned), `output2` from `dense_model` (baseline).
print(torch.exp(output.loss))
print(torch.exp(output2.loss))