In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
from trl import SFTTrainer
import itertools
import pandas as pd
import os
from dataset_preprocessing import TokenInfo
import torch
from tqdm import tqdm
import itertools

import os
from os import listdir

# Checkpoint under analysis, pinned to an exact commit hash so that re-running
# the notebook requests the same revision every time.
model_id = "microsoft/phi-1_5"
model_revision = "349cf8b5e81fd5f791d1740da5de1313a0419bbd" # latest as of feb 1st

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    revision=model_revision,
    # Permits transformers to run modelling code shipped in the checkpoint
    # repository; the pinned revision above keeps that code fixed.
    trust_remote_code=True,
    # Optional precision / attention-kernel settings, deliberately left
    # disabled (the original author flagged them with "be careful with this?").
    # torch_dtype=torch.float16,
    # attn_implementation="flash_attention_2",
)

  from .autonotebook import tqdm as notebook_tqdm
  return self.fget.__get__(instance, owner)()


In [2]:
def get_modules(model):
    """Collect the per-layer projection modules of ``model``.

    Walks ``model.model.layers`` and, for every layer, picks up the two MLP
    projections (``mlp.fc1``, ``mlp.fc2``) and the four attention projections
    (``self_attn.q_proj``, ``k_proj``, ``v_proj``, ``dense``).

    Returns:
        A list of six lists, in the order
        ``[fc1s, fc2s, q_projs, k_projs, v_projs, denses]``; each inner list
        has one entry per layer, in layer order.
    """
    decoder_layers = model.get_submodule("model").get_submodule("layers")

    fc1s, fc2s = [], []
    q_projs, k_projs, v_projs, denses = [], [], [], []

    for block in decoder_layers:
        feed_forward = block.get_submodule("mlp")
        fc1s.append(feed_forward.fc1)
        fc2s.append(feed_forward.fc2)

        attention = block.get_submodule("self_attn")
        q_projs.append(attention.q_proj)
        k_projs.append(attention.k_proj)
        v_projs.append(attention.v_proj)
        denses.append(attention.dense)

    return [fc1s, fc2s, q_projs, k_projs, v_projs, denses]



In [3]:
# Collect the six per-layer module lists once at notebook level.
# NOTE(review): `modules` is not referenced by any later cell visible here
# (compute_delta_loss_importances calls get_modules itself) - confirm it is
# still needed.
modules = get_modules(model)

In [4]:
def custom_loss(logits, labels, model):
    """Per-token next-token cross-entropy, with no reduction applied.

    Args:
        logits: model outputs whose last dimension is the vocabulary and whose
            second-to-last dimension is the sequence position.
        labels: token ids aligned with ``logits`` along the sequence dimension.
        model: only ``model.config.vocab_size`` is read, to flatten the logits.

    Returns:
        A tensor shaped like ``labels[..., 1:]`` holding one loss value per
        predicted token.
    """
    # Position t predicts token t + 1: drop the last logit and the first label.
    predictions = logits[..., :-1, :].contiguous()
    targets = labels[..., 1:].contiguous()
    per_token_shape = targets.shape

    flat_predictions = predictions.view(-1, model.config.vocab_size)
    # Keep the targets on the logits' device (matters under model parallelism).
    flat_targets = targets.view(-1).to(flat_predictions.device)

    flat_loss = torch.nn.functional.cross_entropy(
        flat_predictions, flat_targets, reduction="none"
    )
    return flat_loss.view(per_token_shape)

In [6]:
def compute_acc_grad(model, examples, modules):
    """Accumulate per-example squared gradients for the delta-loss approximation.

    Each example's mean per-token loss is backpropagated on its own, and
    ``grad * grad / num_examples`` is added to an ``acc_grad`` attribute on
    every parameter found in ``modules``.

    Args:
        model: called as ``model(examples, labels=examples)``; the result must
            expose ``.logits``. ``model.zero_grad()`` is invoked after every
            example, so parameter ``.grad`` is cleared on return.
        examples: batch of token ids; the first dimension indexes examples.
        modules: iterable of module groups (e.g. the list returned by
            ``get_modules``); every parameter of every module is tracked.

    Returns:
        The unreduced per-token loss tensor from ``custom_loss``. Its autograd
        graph is retained so the caller can backpropagate through it again.

    Note:
        ``acc_grad`` is added to when it already exists, so calling this
        function repeatedly on the same model accumulates across calls.
    """
    # Materialise one flat list. Lazy itertools.chain objects cannot be joined
    # with `+`, and a one-shot iterator would be exhausted after the first
    # example, silently skipping every later one.
    all_params = [
        param
        for group in modules
        for module in group
        for param in module.parameters()
    ]
    num_examples = examples.shape[0]

    res = model(examples, labels=examples)
    losses_tens = custom_loss(res.logits, examples, model)

    for example_losses in losses_tens:
        # The graph is shared by all examples and reused by the caller, so it
        # must survive every backward pass.
        example_losses.mean().backward(retain_graph=True)
        with torch.no_grad():
            for param in all_params:
                if param.grad is None:
                    # Parameter did not take part in this backward pass.
                    continue
                grad = param.grad.detach()
                sq_grad = grad * grad / num_examples
                if hasattr(param, "acc_grad"):
                    param.acc_grad += sq_grad
                else:
                    param.acc_grad = sq_grad
        model.zero_grad()
        torch.cuda.empty_cache()
    return losses_tens

In [7]:
@torch.no_grad()
def compute_paired_importance(modules1, modules2, dim1=1, dim2=0):
    """Taylor-expansion importance for units shared by paired weight matrices.

    Shared by the mlp / q-k / v-dense pairs. For every pair ``(m1, m2)`` the
    per-weight salience is ``w * w.grad - 0.5 * w * w.acc_grad * w``. The
    absolute saliences are summed over ``dim1`` for ``m1`` and over ``dim2``
    for ``m2``, and the two results are added.

    Args:
        modules1: sequence of modules whose ``weight`` carries both ``.grad``
            and ``.acc_grad``.
        modules2: sequence paired element-wise with ``modules1`` (via ``zip``,
            so any unmatched trailing modules are ignored).
        dim1: dimension of ``m1.weight`` that is summed away.
        dim2: dimension of ``m2.weight`` that is summed away.

    Returns:
        A list with one detached CPU tensor of importances per module pair.
    """
    importances = []

    for m1, m2 in zip(modules1, modules2):
        # First-order term minus half the squared-gradient (second-order) term.
        salience_w1 = m1.weight * m1.weight.grad
        salience_w2 = m2.weight * m2.weight.grad

        salience_w1 = salience_w1 - 0.5 * m1.weight * m1.weight.acc_grad * m1.weight
        salience_w2 = salience_w2 - 0.5 * m2.weight * m2.weight.acc_grad * m2.weight

        importance_w1_component = salience_w1.abs().sum(dim=dim1)
        importance_w2_component = salience_w2.abs().sum(dim=dim2)

        importance = importance_w1_component + importance_w2_component
        importances.append(importance.detach().cpu())
    return importances

In [8]:
def get_input(storage, key):
    """Build a forward hook that records a module's first positional input.

    Every time the hook fires, the detached first input is written to
    ``storage[key]``, replacing whatever was stored there before.
    """
    def _record_first_input(module, inputs, output):
        # The layer is assumed to receive a single tensor; keep only that one.
        first_arg = inputs[0]
        storage[key] = first_arg.detach()

    return _record_first_input


In [9]:
def get_sample_input_trailing_token(mlp_input_dict):
    """Pick one representative MLP input per layer.

    For every stored tensor, returns the slice ``[0, -1, :]`` moved to the
    CPU, i.e. the last position of the first sequence, in the dictionary's
    iteration order.
    """
    # The batch was already sampled at random upstream, so taking the first
    # sequence is enough; importances use the whole batch, this sample does not.
    trailing_inputs = []
    for layer_inputs in mlp_input_dict.values():
        trailing_inputs.append(layer_inputs[0, -1, :].cpu())
    return trailing_inputs

In [10]:
def compute_delta_loss_importances(model, examples):
    """Compute delta-loss importances for the model's MLP and attention units.

    Importance is defined as the change of the loss if the inbound and
    outbound weights of a neuron were set to 0, approximated with a Taylor
    expansion that uses the gradient and the accumulated squared gradients.

    Args:
        model: model accepted by ``get_modules``; it is moved to CUDA.
        examples: batch of token ids; it is moved to CUDA.

    Returns:
        A tuple ``(all_imps, sample_inputs_trailing_token)`` where
        ``all_imps`` is ``[importances_mlp, importances_qk, importances_vd]``
        (each a list with one tensor per layer) and
        ``sample_inputs_trailing_token`` holds, per layer, the captured fc1
        input of the last token of the first sequence.

    Note:
        ``acc_grad`` attributes written by ``compute_acc_grad`` persist on the
        parameters, so calling this repeatedly on the same model accumulates
        the squared-gradient term across calls.
    """
    mlp1s, mlp2s, qs, ks, vs, ds = get_modules(model)

    mlp_input_dict = {}
    hooks = []

    try:
        # get_modules already returns the fc1 layers themselves, so the hooks
        # go directly on them to capture the input of every MLP.
        for i, fc1 in enumerate(mlp1s):
            hooks.append(fc1.register_forward_hook(get_input(mlp_input_dict, i)))

        # Drop gradients left over from an earlier run; otherwise they would
        # be added to the first example's gradient before it is squared.
        model.zero_grad()

        # compute and store first derivative squared
        loss = compute_acc_grad(model.cuda(), examples.cuda(), [mlp1s, mlp2s, qs, ks, vs, ds])
        torch.cuda.synchronize()

        # compute and store first derivative
        loss = loss.mean()
        loss.backward()

        # Once first derivative and second derivative squared are stored,
        # compute the importances.
        importances_mlp = compute_paired_importance(mlp1s, mlp2s, dim1=1, dim2=0)
        importances_qk = compute_paired_importance(qs, ks, dim1=1, dim2=1)
        importances_vd = compute_paired_importance(vs, ds, dim1=1, dim2=0)
    finally:
        # Always detach the hooks, even when the computation above fails, so
        # they do not keep firing on later forward passes.
        for hook in hooks:
            hook.remove()

    all_imps = [importances_mlp, importances_qk, importances_vd]
    sample_inputs_trailing_token = get_sample_input_trailing_token(mlp_input_dict)

    return all_imps, sample_inputs_trailing_token

In [None]:
# Run the importance computation on one batch.
# NOTE(review): `examples` is not defined in any cell visible here - confirm
# that a cell building the token-id batch exists above, otherwise this raises
# NameError after Restart & Run All.
all_imps, sample_inputs_trailing_token = compute_delta_loss_importances(model, examples)