In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append("../LLM-Pruner/")

In [3]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import itertools
import pandas as pd
import os

In [4]:
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

## Model and Tokenizer Setup

In [5]:
model_id = "microsoft/phi-1_5"
model_revision = "349cf8b5e81fd5f791d1740da5de1313a0419bbd" # latest as of feb 1st

In [6]:
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

In [7]:
vocab = tokenizer.get_vocab()
len(vocab)

50295

In [8]:
# tokenizer.decode(token_info.get_prefixes(top_tokens[1000][0], 9, 10)[0])

In [9]:
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    revision=model_revision,
    trust_remote_code=True,
    # be careful with this?
    torch_dtype=torch.float16,
    # attn_implementation="flash_attention_2",
).cuda()

In [10]:
model

PhiForCausalLM(
  (model): PhiModel(
    (embed_tokens): Embedding(51200, 2048)
    (embed_dropout): Dropout(p=0.0, inplace=False)
    (layers): ModuleList(
      (0-23): 24 x PhiDecoderLayer(
        (self_attn): PhiAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (k_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (v_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (dense): Linear(in_features=2048, out_features=2048, bias=True)
          (rotary_emb): PhiRotaryEmbedding()
        )
        (mlp): PhiMLP(
          (activation_fn): NewGELUActivation()
          (fc1): Linear(in_features=2048, out_features=8192, bias=True)
          (fc2): Linear(in_features=8192, out_features=2048, bias=True)
        )
        (input_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
        (resid_dropout): Dropout(p=0.0, inplace=False)
      )
    )
    (final_layernorm): LayerNorm((2048,), e

## Dataset

In [11]:
from LLMPruner.datasets.example_samples import get_examples

In [12]:
# examples = get_examples("bookcorpus", tokenizer, n_samples=10)
# pd.to_pickle(examples.cpu(), "./examples.pkl")
examples = pd.read_pickle("./examples.pkl")

In [13]:
examples

tensor([[ 8182,   479,   672,  ...,   287, 10905,   306],
        [12284,   284,   766,  ...,  1257,   618,   673],
        [ 1058,   257,   380,  ...,   613,  6051,   837],
        ...,
        [  988,   290, 32627,  ...,   329,   262,  9546],
        [ 4868,   584,  3835,  ...,   279,  2518,    78],
        [ 7091,   714,  5368,  ...,   674, 15876,  1267]])

In [14]:
# examples = examples[:, :9]

In [15]:
examples = examples.cuda()

In [16]:
tokenizer.decode(examples[3])

', escaped 2/46, dunblane john ferguson jnr, glengyles, ardno, escaped 2/46, dunblane peter king, glengyles, ardno, died 22/5/46 john livingston, glengyles, ardno, freed 21/8/47, ardnamurchan donald mgrigor, glengyles, ardno, escaped grigor mgrigor, glengyles, ardno, freed 47, argyll piper james mgregor, glengyles, carlisle, transported - john m'

## Model importance computation

In [17]:
def get_mlps(model):
    layers = model.get_submodule("model").get_submodule("layers")
    return [layer.get_submodule("mlp") for layer in layers]

In [18]:
def custom_loss(logits, labels, model):
    # Returns crossentropy loss per token, w/o reduction
    
    # Shift so that tokens < n predict n
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    orig_shape = shift_labels.shape
    # Flatten the tokens
    loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
    shift_logits = shift_logits.view(-1, model.config.vocab_size)
    shift_labels = shift_labels.view(-1)
    # Enable model parallelism
    shift_labels = shift_labels.to(shift_logits.device)
    loss = loss_fct(shift_logits, shift_labels).view(orig_shape)
    return loss

In [19]:
def compute_acc_grad_2(model, examples, mlps):
    # Computes the second term of the taylor expansion
    params = [list(mlp.parameters()) for mlp in mlps]
    params = itertools.chain.from_iterable(params) # flatten list
    res = model(examples, labels=examples)
    losses_tens = custom_loss(res.logits, examples, model)
    losses = [loss.mean() for loss in losses_tens]
    # import pdb; pdb.set_trace()
    for example_loss in losses:
        example_loss.backward(retain_graph=True)
        for param in params: # for all the weights
            num_examples = examples.shape[0]
            with torch.no_grad():
                grad = param.grad.detach()
                sq_grad = grad * grad / num_examples
                if hasattr(param, "acc_grad"):
                    param.acc_grad += sq_grad
                else:
                    param.acc_grad = sq_grad
        model.zero_grad()
        del example_loss
        torch.cuda.empty_cache()
    return losses_tens

In [20]:
def compute_acc_grad(model, examples, mlps):
    # Computes the second term of the taylor expansion
    params = [list(mlp.parameters()) for mlp in mlps]
    params = itertools.chain.from_iterable(params) # flatten list
    
    for example in examples:
        example = example.unsqueeze(0)
        # import pdb; pdb.set_trace()
        loss = model(example, labels=example).loss
        # import pdb; pdb.set_trace()
        loss.backward()
        for param in params: # for all the weights
            num_examples = examples.shape[0]
            with torch.no_grad():
                grad = param.grad.detach()
                sq_grad = grad * grad / num_examples
                if hasattr(param, "acc_grad"):
                    param.acc_grad += sq_grad
                else:
                    param.acc_grad = sq_grad
        model.zero_grad()
        del loss
        torch.cuda.empty_cache()

In [21]:
@torch.no_grad()
def compute_mlp_importance(mlps):
    # computes importance for hidden layers of MLP
    importances = []
    for mlp in mlps:
        fc1 = mlp.fc1
        fc2 = mlp.fc2
        # compute importance of inputs to hidden
        salience_w1 = fc1.weight * fc1.weight.grad
        salience_w2 = fc2.weight * fc2.weight.grad
        
        salience_w1 = salience_w1 - 0.5 * fc1.weight * fc1.weight.acc_grad * fc1.weight
        salience_w2 = salience_w2 - 0.5 * fc2.weight * fc2.weight.acc_grad * fc2.weight

        importance_w1_component =  salience_w1.abs().sum(dim=1)
        importance_w2_component =  salience_w2.abs().sum(dim=0)

        # analogous to group reduction?
        importance = importance_w1_component + importance_w2_component
        importances.append(importance.detach().cpu())
    return importances

In [22]:
def compute_importances(model, examples, use_fast=False, few_mlps=True):
    # start = time()
    # returns importance for mlp examples
    mlps = get_mlps(model)
    if few_mlps:
        mlps = [mlps[0], mlps[len(mlps)//2], mlps[-1]]
    # torch.cuda.synchronize()
    # start_2 = time()
    if use_fast:
        loss = compute_acc_grad_2(model, examples, mlps) # estimates hessian term of importance
    else:
        compute_acc_grad(model, examples, mlps) # estimates hessian term of importance
    torch.cuda.synchronize()
    # end_2 = time()
    # print(f"compute acc grad: {end_2 - start_2}")
    # import pdb; pdb.set_trace()
    if use_fast:
        # print("using fast")
        loss = loss.mean()
        loss.backward()
    else:
        loss = model(examples, labels=examples).loss
        loss.backward()
    importances = compute_mlp_importance(mlps)
    # torch.cuda.synchronize()
    # end = time()
    # print(f"compute ttl: {end - start}")
    return importances

In [23]:
importances = compute_importances(model, examples, few_mlps=False)

In [25]:
pd.to_pickle(importances, "imps_llm_prunner_style.pkl")

## Scratch code
feel free to ignore

In [25]:
importances = _

In [29]:
importances[0]

tensor([0.0172, 0.0293, 0.0263,  ..., 0.0334, 0.0138, 0.0241], device='cuda:0',
       dtype=torch.float16)

In [27]:
mlps = get_mlps(model)

In [61]:
examples.shape

torch.Size([10, 128])

In [17]:
mlps = get_mlps(model)

In [20]:
mlp = mlps[0]

In [21]:
fc1 = mlp.get_submodule("fc1")
fc2 = mlp.get_submodule("fc2")

In [38]:
fc1

Linear(in_features=2048, out_features=8192, bias=True)

In [None]:
def compute_importance(mlp):
    pass

In [78]:
modules = list(model.modules())

In [101]:
x = model.get_submodule("model").get_submodule("layers")[0].get_submodule("mlp").get_submodule("fc1")

In [113]:
list(x.parameters())[0].grad.shape

torch.Size([8192, 2048])