In [2]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
from trl import SFTTrainer
import itertools
import pandas as pd
import os
from dataset_preprocessing import TokenInfo
import torch
from tqdm import tqdm
import itertools

import os
from os import listdir

# Hugging Face Hub identifier of the checkpoint to analyse and prune.
model_id = "microsoft/phi-1_5"
# Pin an exact commit so that both the weights and the remote modelling code stay
# reproducible. NOTE(review): "feb 1st" has no year attached — confirm which one.
model_revision = "349cf8b5e81fd5f791d1740da5de1313a0419bbd" # latest as of feb 1st

# Load the causal LM with default dtype / attention settings (the overrides below
# are deliberately left commented out).
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    revision=model_revision,
    # Executes modelling code shipped in the model repository; the pinned
    # revision above keeps that code fixed.
    trust_remote_code=True,
    # be careful with this?
    # torch_dtype=torch.float16,
    # attn_implementation="flash_attention_2",
)

  from .autonotebook import tqdm as notebook_tqdm
  return self.fget.__get__(instance, owner)()


In [3]:
def get_all_modules(model):
    """Collect the prunable projections of every transformer layer.

    Walks ``model`` -> ``layers`` via ``get_submodule`` and, for each layer,
    reads ``q_proj`` / ``k_proj`` / ``v_proj`` / ``dense`` from its
    ``self_attn`` submodule and ``fc1`` / ``fc2`` from its ``mlp`` submodule.

    Returns:
        A list of six lists, each with one entry per layer, ordered as
        ``[q_projs, k_projs, v_projs, denses, fc1s, fc2s]``.
    """
    decoder_layers = model.get_submodule("model").get_submodule("layers")

    attn_names = ("q_proj", "k_proj", "v_proj", "dense")
    mlp_names = ("fc1", "fc2")
    collected = {name: [] for name in attn_names + mlp_names}

    for block in decoder_layers:
        attention = block.get_submodule("self_attn")
        feed_forward = block.get_submodule("mlp")
        for name in attn_names:
            collected[name].append(getattr(attention, name))
        for name in mlp_names:
            collected[name].append(getattr(feed_forward, name))

    # Preserve the documented group order of the return value.
    return [collected[name] for name in attn_names + mlp_names]

In [4]:
def custom_loss(logits, labels, model):
    """Per-token next-token cross-entropy, with no reduction applied.

    Args:
        logits: Model outputs whose last two axes are (sequence, vocabulary).
        labels: Target token ids whose last axis is the sequence axis.
        model: Object exposing ``config.vocab_size``; used to flatten logits.

    Returns:
        A tensor of losses shaped like ``labels`` with the sequence axis
        shortened by one (the first token has nothing predicting it).
    """
    # Position t predicts token t + 1: drop the final logit and the first label.
    pred_logits = logits[..., :-1, :].contiguous()
    targets = labels[..., 1:].contiguous()
    out_shape = targets.shape

    flat_logits = pred_logits.view(-1, model.config.vocab_size)
    # Targets must live on the logits' device (relevant under model parallelism).
    flat_targets = targets.view(-1).to(flat_logits.device)

    criterion = torch.nn.CrossEntropyLoss(reduction='none')
    per_token = criterion(flat_logits, flat_targets)
    return per_token.view(out_shape)

In [5]:
def compute_acc_grad(model, examples, modules):
    """Accumulate the per-example squared gradients used by the delta-loss estimate.

    Runs one forward pass over ``examples``, then back-propagates each example's
    mean token loss separately. After every backward pass, ``grad ** 2 /
    num_examples`` is added to a ``param.acc_grad`` attribute on each parameter
    of the given modules.

    Args:
        model: Causal LM called as ``model(examples, labels=examples)``; its
            output must expose ``.logits``.
        examples: Batch of token ids; ``examples.shape[0]`` is the number of
            examples.
        modules: Iterable of module groups (e.g. the six lists returned by
            ``get_all_modules``); every parameter of every module is tracked.

    Returns:
        The unreduced per-token loss tensor from ``custom_loss``. The autograd
        graph is retained so the caller can back-propagate through it again.

    Notes:
        ``acc_grad`` is added to if it already exists, so calling this on
        several batches accumulates; delete the attribute to start afresh.
    """
    # Materialise the parameters once as a list: the loop below walks them once
    # per example, which a one-shot itertools.chain iterator could not support.
    tracked_params = [
        param
        for module in itertools.chain.from_iterable(modules)
        for param in module.parameters()
    ]
    num_examples = examples.shape[0]

    res = model(examples, labels=examples)
    losses_tens = custom_loss(res.logits, examples, model)

    # Gradients left over from earlier work would otherwise be squared into the
    # first example's contribution.
    model.zero_grad()

    for example_losses in losses_tens:
        # retain_graph: every example shares the single forward graph, and the
        # caller back-propagates through the returned tensor once more.
        example_losses.mean().backward(retain_graph=True)
        with torch.no_grad():
            for param in tracked_params:
                if param.grad is None:
                    # Frozen or unused parameter: nothing to accumulate.
                    continue
                grad = param.grad.detach()
                sq_grad = grad * grad / num_examples
                if hasattr(param, "acc_grad"):
                    param.acc_grad += sq_grad
                else:
                    param.acc_grad = sq_grad
        model.zero_grad()
        torch.cuda.empty_cache()
    return losses_tens

In [6]:
@torch.no_grad()
def compute_paired_importance(modules1, modules2, dim1=1, dim2=0):
    """Score the neurons shared by each (module1, module2) pair.

    Shared by the query/key, value/dense and fc1/fc2 pairs. For every weight
    the salience ``w * grad - 0.5 * w * acc_grad * w`` is computed, i.e. a
    Taylor-style estimate of the loss change from zeroing that weight, using
    the stored ``weight.grad`` and ``weight.acc_grad``. Absolute saliences are
    then summed over ``dim1`` for the first module and ``dim2`` for the second,
    and the two resulting vectors are added.

    Args:
        modules1: Modules exposing ``weight`` with ``grad`` and ``acc_grad`` set.
        modules2: Partner modules, matched to ``modules1`` by position.
        dim1: Axis summed over for ``modules1`` saliences.
        dim2: Axis summed over for ``modules2`` saliences. With
            ``torch.nn.Linear`` weights (``out_features x in_features``),
            ``dim1=1, dim2=0`` scores the outputs of module1 against the
            inputs of module2.

    Returns:
        A list with one detached CPU importance tensor per pair.
    """
    def _salience(weight):
        # First-order term minus half the (squared-gradient) second-order term.
        return weight * weight.grad - 0.5 * weight * weight.acc_grad * weight

    importances = []
    for m1, m2 in zip(modules1, modules2):
        # Each module's salience uses its OWN weight and gradient.
        importance_w1_component = _salience(m1.weight).abs().sum(dim=dim1)
        importance_w2_component = _salience(m2.weight).abs().sum(dim=dim2)

        importance = importance_w1_component + importance_w2_component
        importances.append(importance.detach().cpu())
    return importances

In [7]:
def get_input(storage, key):
    """Build a forward hook that records a module's first positional input.

    Args:
        storage: Mutable mapping that receives the captured tensor.
        key: Key under which the tensor is stored in ``storage``.

    Returns:
        A ``(module, inputs, output)`` callable suitable for
        ``register_forward_hook``. Each call overwrites ``storage[key]`` with a
        detached view of the first input.
    """
    def _record_first_input(_module, inputs, _output):
        # Only the first positional input is kept; it is detached from autograd.
        first_input = inputs[0]
        storage[key] = first_input.detach()

    return _record_first_input


In [8]:
def get_sample_input_trailing_token(mlp_input_dict):
    """Pick one sample embedding per layer: the last token of the first sequence.

    Args:
        mlp_input_dict: Mapping of layer key -> captured input tensor, indexed
            here as ``[sequence_in_batch, token, feature]``.

    Returns:
        A list of CPU tensors, one per entry of ``mlp_input_dict`` in its
        iteration order.
    """
    # Always take the first sequence: the batch is assumed to be randomly sampled
    # upstream, so no further sampling is done here. Importances cover the
    # whole batch; only this sample input comes from a single sequence.
    trailing_inputs = []
    for captured in mlp_input_dict.values():
        trailing_inputs.append(captured[0, -1, :].cpu())
    return trailing_inputs

In [10]:
def compute_delta_loss_importances(model, examples):
    """Compute importances for the prunable neurons of every layer.

    Importance is the estimated change of the loss if the inbound and outbound
    weights of a neuron were set to 0. Three groups are scored per layer:
    query/key, value/dense and the MLP (fc1/fc2).

    Args:
        model: Causal LM whose layers are reachable via ``model`` -> ``layers``.
            It is moved to the GPU when one is available (otherwise CPU).
        examples: Batch of token ids used for both the forward pass and labels.

    Returns:
        A tuple ``(layer_to_imps, sample_inputs_trailing_token)``:
        ``layer_to_imps`` maps each layer module to
        ``[importance_qk, importance_vd, importance_mlp]`` and
        ``sample_inputs_trailing_token`` holds, per layer, the fc1 input of the
        last token of the first sequence.

    Notes:
        Parameter ``.grad`` and ``.acc_grad`` values are left in place after
        the call.
    """
    qs, ks, vs, ds, fc1s, fc2s = get_all_modules(model)

    # Prefer the GPU but stay usable on CPU-only machines.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    examples = examples.to(device)

    mlp_input_dict = {}
    hooks = [
        fc1.register_forward_hook(get_input(mlp_input_dict, i))
        for i, fc1 in enumerate(fc1s)
    ]

    try:
        # Squared first derivatives, accumulated per example into param.acc_grad.
        loss = compute_acc_grad(model, examples, [qs, ks, vs, ds, fc1s, fc2s])
        if device.type == "cuda":
            torch.cuda.synchronize()

        # First derivative of the batch-mean loss, left in param.grad.
        loss = loss.mean()
        loss.backward()
    finally:
        # Always detach the hooks, even if the forward/backward pass failed.
        for hook in hooks:
            hook.remove()

    # With both gradient terms stored, score each pair of projections.
    importances_qk = compute_paired_importance(qs, ks, dim1=1, dim2=1)
    importances_vd = compute_paired_importance(vs, ds, dim1=1, dim2=0)
    importances_mlp = compute_paired_importance(fc1s, fc2s, dim1=1, dim2=0)

    # The three results are Python lists with one tensor per layer.
    layers = model.get_submodule("model").get_submodule("layers")
    layer_to_imps = {
        layer: [importances_qk[i], importances_vd[i], importances_mlp[i]]
        for i, layer in enumerate(layers)
    }

    sample_inputs_trailing_token = get_sample_input_trailing_token(mlp_input_dict)

    return layer_to_imps, sample_inputs_trailing_token

In [25]:
from dataset_preprocessing import fetch_preprocessed_data
from tqdm import tqdm

# Presumably downloads and tokenizes the "tiny-textbooks" dataset — TODO confirm
# against dataset_preprocessing. NOTE(review): the recorded run of this cell ended
# in KeyboardInterrupt, so its results may be incomplete; re-run to completion.
fetch_preprocessed_data(data="tiny-textbooks")

Resolving data files: 100%|██████████| 42/42 [00:00<00:00, 250941.26it/s]
Token indices sequence length is longer than the specified maximum sequence length for this model (2524 > 2048). Running this sequence through the model will result in indexing errors


KeyboardInterrupt: 

In [16]:
from dataset_preprocessing import TokenInfo

# NOTE(review): the recorded run failed with FileNotFoundError for
# './/dataset_tokenized.pkl'. Presumably that file is produced by the
# preprocessing step — confirm, and make sure it has finished before this cell.
token_info = TokenInfo()

...Loading dataset...


FileNotFoundError: [Errno 2] No such file or directory: './/dataset_tokenized.pkl'

In [None]:
# FIXME: `examples` is not defined in any visible cell, so this raises NameError
# on a fresh kernel. It has to be a batch of token ids: it is moved to the GPU,
# fed to the model as both input and labels, and `examples.shape[0]` is read as
# the number of examples.

all_imps, sample_inputs_trailing_token = compute_delta_loss_importances(model, examples)

In [11]:
def create_pruned_weights(shape0, shape1, dtype, weightdata, biasdata=None):
    """Build a ``torch.nn.Linear`` that carries the given (pruned) parameters.

    Args:
        shape0: Number of input features of the new layer.
        shape1: Number of output features of the new layer.
        dtype: Dtype the layer is constructed with.
        weightdata: Weight tensor of shape ``(shape1, shape0)``.
        biasdata: Optional bias tensor of shape ``(shape1,)``. When omitted the
            layer is created without a bias, rather than keeping the randomly
            initialised bias ``torch.nn.Linear`` would otherwise allocate.

    Returns:
        The new ``torch.nn.Linear`` holding ``weightdata`` (and ``biasdata``).

    Raises:
        ValueError: If ``weightdata`` or ``biasdata`` do not match the
            requested layer shape.
    """
    if tuple(weightdata.shape) != (shape1, shape0):
        raise ValueError(
            f"weightdata has shape {tuple(weightdata.shape)}, "
            f"expected {(shape1, shape0)}"
        )
    has_bias = biasdata is not None
    if has_bias and tuple(biasdata.shape) != (shape1,):
        raise ValueError(
            f"biasdata has shape {tuple(biasdata.shape)}, expected {(shape1,)}"
        )

    layer_pruned = torch.nn.Linear(
        shape0,
        shape1,
        bias=has_bias,
        dtype=dtype
    )

    with torch.no_grad():
        # Swap in the supplied tensors; the parameters then live wherever the
        # supplied tensors live (device and dtype follow the data).
        layer_pruned.weight.data = weightdata
        if has_bias:
            layer_pruned.bias.data = biasdata

    return layer_pruned

In [12]:
@torch.no_grad()
def full_pruning(layer_to_imps, prune_ratio):
    """Prune the least important neurons of the whole model in place.

    All importances are ranked together and the ``prune_ratio`` fraction with
    the smallest values is removed. For each layer the surviving indices are
    applied to three pairs of projections:

    * query/key: output rows of ``q_proj`` and ``k_proj``;
    * value/dense: output rows of ``v_proj`` and input columns of ``dense``;
    * mlp: output rows of ``fc1`` and input columns of ``fc2``.

    Args:
        layer_to_imps: Mapping of layer module -> ``[imp_qk, imp_vd, imp_mlp]``
            importance tensors. Each layer must expose ``self_attn`` and ``mlp``
            submodules.
        prune_ratio: Fraction in ``[0, 1]`` of all scored neurons to remove.

    Returns:
        None. The projections of every layer are replaced by pruned copies.

    Raises:
        ValueError: If ``prune_ratio`` is outside ``[0, 1]``.

    Notes:
        NOTE(review): rows of q/k/v are dropped individually. If the attention
        forward reshapes the projections into per-head chunks, the pruned
        widths may no longer fit — confirm against the attention code.
        Ranking is global, so a single layer may lose most (or, at high ratios,
        all) of a group's neurons.
    """
    if not 0 <= prune_ratio <= 1:
        raise ValueError(f"prune_ratio must be within [0, 1], got {prune_ratio}")
    if not layer_to_imps:
        return

    # Flatten every importance vector into one ranking, remembering the size of
    # each (layer, group) segment so the mask can be split back exactly.
    segments = []
    segment_sizes = []
    for importances in layer_to_imps.values():
        imp_qk, imp_vd, imp_mlp = importances
        for group_imps in (imp_qk, imp_vd, imp_mlp):
            flat = group_imps.detach().flatten().cpu()
            segments.append(flat)
            segment_sizes.append(flat.numel())
    all_imps = torch.cat(segments)

    num_prune_cells = int(all_imps.numel() * prune_ratio)

    # True = keep. The smallest `num_prune_cells` importances are marked False.
    mask = torch.ones_like(all_imps, dtype=torch.bool)
    if num_prune_cells > 0:
        _, indices_to_replace = torch.topk(all_imps, num_prune_cells, largest=False)
        mask[indices_to_replace] = False

    segment_masks = torch.split(mask, segment_sizes)

    for layer_idx, layer in enumerate(layer_to_imps.keys()):
        mask_qk, mask_vd, mask_mlp = segment_masks[3 * layer_idx:3 * layer_idx + 3]
        keep_idx_qk = torch.nonzero(mask_qk).flatten()
        keep_idx_vd = torch.nonzero(mask_vd).flatten()
        keep_idx_mlp = torch.nonzero(mask_mlp).flatten()

        attn = layer.get_submodule("self_attn")
        mlp = layer.get_submodule("mlp")

        # Query and key share one index set so their dot product stays aligned.
        attn.q_proj = _prune_linear(attn.q_proj, keep_idx_qk, prune_outputs=True)
        attn.k_proj = _prune_linear(attn.k_proj, keep_idx_qk, prune_outputs=True)
        # Outputs of value feed the inputs of dense.
        attn.v_proj = _prune_linear(attn.v_proj, keep_idx_vd, prune_outputs=True)
        attn.dense = _prune_linear(attn.dense, keep_idx_vd, prune_outputs=False)
        # Outputs of fc1 feed the inputs of fc2.
        mlp.fc1 = _prune_linear(mlp.fc1, keep_idx_mlp, prune_outputs=True)
        mlp.fc2 = _prune_linear(mlp.fc2, keep_idx_mlp, prune_outputs=False)


def _prune_linear(linear, keep_idx, prune_outputs):
    """Copy ``linear`` keeping only ``keep_idx`` output rows or input columns."""
    weight = linear.weight.detach()
    bias = None if linear.bias is None else linear.bias.detach()
    keep_idx = keep_idx.to(weight.device)
    num_kept = keep_idx.shape[0]

    if prune_outputs:
        # Removing outputs drops weight rows and the matching bias entries.
        new_weight = weight[keep_idx].clone()
        new_bias = None if bias is None else bias[keep_idx].clone()
        in_features, out_features = weight.shape[1], num_kept
    else:
        # Removing inputs drops weight columns; the bias is per-output and stays.
        new_weight = weight[:, keep_idx].clone()
        new_bias = None if bias is None else bias.clone()
        in_features, out_features = num_kept, weight.shape[0]

    return create_pruned_weights(
        in_features, out_features, weight.dtype, new_weight, new_bias
    )