In [1]:
# Re-import edited modules (e.g. the LLM-Pruner helpers imported below) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
import sys
# Make the local LLM-Pruner checkout importable (it provides `LLMPruner.datasets`, used below).
# NOTE(review): relative path — assumes the notebook runs from a directory next to LLM-Pruner/.
sys.path.append("../LLM-Pruner/")

In [39]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import itertools

## Model and Tokenizer Setup

In [4]:
model_id = "microsoft/phi-1_5"
# Hub commit to load; pinning it keeps results reproducible and fixes which remote code runs.
model_revision = "349cf8b5e81fd5f791d1740da5de1313a0419bbd" # latest as of feb 1st

In [5]:
# Pin the tokenizer to the same commit as the model so its vocabulary/config files (and any
# remote code permitted by trust_remote_code) cannot drift away from the pinned model.
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    revision=model_revision,
    trust_remote_code=True,
)

In [6]:
# Token-string -> id mapping; display its size (50295 entries for this tokenizer).
vocab = tokenizer.get_vocab()
len(vocab)

50295

In [7]:
# Load phi-1_5 at the pinned commit; trust_remote_code allows executing Python code shipped in the Hub repo.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    revision=model_revision,
    trust_remote_code=True,
)

## Dataset (BookCorpus samples used as model inputs)

In [8]:
from LLMPruner.datasets.example_samples import get_examples

In [9]:
# Draw 10 BookCorpus samples, tokenized with `tokenizer`; they are the inputs for the forward passes below.
# NOTE(review): if get_examples samples randomly, seed it first for reproducibility — TODO confirm.
examples = get_examples("bookcorpus", tokenizer, n_samples=10)

In [10]:
# Token ids, one sample per row.
examples

tensor([[ 1946,   262,  1772,  ...,  4490,  7541,   837],
        [ 3872,   837,   339,  ...,   340,   714,   307],
        [   82,  3656,   730,  ...,   329,   257,   845],
        ...,
        [  484,   537,   499,  ...,   502,  7263,   287],
        [  286,   683,   290,  ...,  7812,   465, 32870],
        [13033,   286,   262,  ...,   262, 40445,   837]])

In [11]:
# Sanity-check forward pass. Call the module itself (not `.forward`) so registered hooks run,
# and disable autograd: only the logits/loss are inspected here, so no graph needs to be kept.
with torch.no_grad():
    res = model(
        examples,
        # This is ok, labels are shifted in forward function
        labels=examples,
    )

In [12]:
# Decode sample 3 back to text, to compare with the model's predictions in the next cell.
tokenizer.decode(examples[3])

" bucked in my hand as i emptied the clip, and from a great distance i heard someone screaming, but it wasn't him screaming, it was me screaming, me and everybody else who was left, if there was anybody left, all of us helpless, hopeless, stupid humans screaming, because we got it wrong, we got it all wrong, there was no alien swarm descending from the sky in their flying saucers or big metal walkers like something out of star wars or cute little wrinkly e.t.s who just wanted to pluck a couple of leaves, eat some reese's pieces, and go"

In [13]:
# Greedy (argmax) prediction at every position of sample 3. Position t predicts token t+1,
# so this text reads one token ahead of the decoded input above.
tokenizer.decode(res.logits[3].argmax(dim=-1))

"led off the seatbag I tried the bagboard and the the distance distance, could the say. and i was tooigh me,, it was me.. and screaming my else in was in in and you was a left, we of us were, helpless, and,,. and we were caught,, and got it wrong wrong, and was no one,, on the sky, the spaceship machineucers, their flying birdsers, in out of a Trek, something little alienly aliensw\n.c.'s wanted to beop our few of hairs from or a berriesind'ss crack of and go back"

## Importance computation for MLP hidden neurons (Taylor expansion)

In [14]:
# res = model.forward(
#     examples, 
#     # This is ok, labels are shifted in forward function
#     labels=examples
# )

In [15]:
# res.loss.backward()

In [16]:
def get_mlps(model):
    """Return the ``mlp`` submodule of every decoder layer, in layer order.

    Expects the module hierarchy ``model.model.layers[i].mlp`` used by the
    phi-1_5 model in this notebook.
    """
    decoder_layers = model.get_submodule("model.layers")
    return list(map(lambda layer: layer.get_submodule("mlp"), decoder_layers))

In [None]:
def compute_acc_grad(model, examples, mlps):
    """Store the mean squared per-example gradient of every MLP parameter in ``param.acc_grad``.

    This is the diagonal (empirical-Fisher-style) estimate of the Hessian used
    as the second-order term of the Taylor-expansion importance. For every row
    of ``examples`` a separate forward/backward pass is run, with the row used
    as its own language-modelling labels, and ``grad**2 / num_examples`` is
    added to ``acc_grad``.

    Args:
        model: causal LM whose call accepts ``(input_ids, labels=...)`` and
            returns an object with a ``.loss`` attribute.
        examples: tensor of token ids, one sample per row.
        mlps: modules whose parameters receive the ``acc_grad`` attribute.

    Side effects:
        Overwrites ``acc_grad`` on each MLP parameter (repeated calls do not
        double-count) and leaves the model's gradients zeroed.
    """
    # Materialize the parameters: a bare itertools.chain is a one-shot iterator and
    # would be exhausted after the first example.
    params = list(itertools.chain.from_iterable(mlp.parameters() for mlp in mlps))
    num_examples = examples.shape[0]

    with torch.no_grad():
        for param in params:
            param.acc_grad = torch.zeros_like(param)
    # Drop gradients left over from any earlier backward pass so they don't leak into
    # the first example's contribution.
    model.zero_grad()

    for example in examples:
        example = example.unsqueeze(0)  # keep a batch dimension of 1
        # Labels are this single example (the model shifts them internally).
        loss = model(example, labels=example).loss
        loss.backward()
        with torch.no_grad():
            for param in params:
                if param.grad is None:  # parameter did not take part in this loss
                    continue
                param.acc_grad += param.grad * param.grad / num_examples
        model.zero_grad()

In [58]:
@torch.no_grad()
def compute_mlp_importance(mlps):
    """Score every hidden neuron of each MLP with a first- plus second-order Taylor estimate.

    For a weight matrix ``W`` the element-wise salience is
    ``W * W.grad - 0.5 * W * W.acc_grad * W``. Hidden neuron ``i`` owns row ``i``
    of ``fc1.weight`` (shape ``(hidden, in)``) and column ``i`` of ``fc2.weight``
    (shape ``(out, hidden)``); its importance is the sum of the absolute
    saliences over both.

    Args:
        mlps: iterable of modules exposing ``fc1`` and ``fc2`` whose weights
            already carry ``.grad`` and ``.acc_grad``.

    Returns:
        list with one 1-D tensor per MLP, one entry per hidden neuron.
    """
    def taylor_salience(weight):
        first_order = weight * weight.grad
        return first_order - 0.5 * weight * weight.acc_grad * weight

    scores = []
    for block in mlps:
        incoming = taylor_salience(block.fc1.weight).abs().sum(dim=1)  # reduce over inputs
        outgoing = taylor_salience(block.fc2.weight).abs().sum(dim=0)  # reduce over outputs
        # analogous to group reduction: combine both halves of each hidden neuron
        scores.append(incoming + outgoing)
    return scores

In [59]:
def compute_importances(model, examples):
    """Compute a Taylor-expansion importance score for every MLP hidden neuron in ``model``.

    Steps:
      1. ``compute_acc_grad`` stores the mean squared per-example gradient in
         ``acc_grad`` (second-order term).
      2. One forward/backward pass over the whole batch populates ``.grad``
         (first-order term).
      3. ``compute_mlp_importance`` combines both into one score per hidden neuron.

    Args:
        model: causal LM with the ``model.layers[i].mlp`` layout expected by ``get_mlps``.
        examples: token-id tensor, one sample per row; also used as the labels.

    Returns:
        list of 1-D tensors, one per MLP layer.
    """
    mlps = get_mlps(model)
    compute_acc_grad(model, examples, mlps)  # estimates hessian term of importance
    # Make sure the first-order gradients come only from the full-batch pass below.
    model.zero_grad()
    loss = model(examples, labels=examples).loss
    loss.backward()
    return compute_mlp_importance(mlps)

In [61]:
# (n_samples, tokens per sample)
examples.shape

torch.Size([10, 128])

In [17]:
# MLP submodule of every decoder layer, in layer order.
mlps = get_mlps(model)

In [20]:
# Inspect the first layer's MLP.
mlp = mlps[0]

In [21]:
# The two linear layers of the MLP: fc1 expands into the hidden layer, fc2 projects back.
fc1, fc2 = (mlp.get_submodule(name) for name in ("fc1", "fc2"))

In [38]:
# First MLP's fc1: 2048 inputs -> 8192 hidden units.
fc1

Linear(in_features=2048, out_features=8192, bias=True)

In [None]:
def compute_importance(mlp):
    """Return importance scores for the hidden neurons of a single MLP.

    Convenience wrapper around ``compute_mlp_importance`` for one module.
    ``mlp`` must expose ``fc1``/``fc2`` whose weights already carry ``.grad``
    and ``.acc_grad`` (e.g. after running ``compute_importances``).

    Returns:
        1-D tensor with one score per hidden neuron.
    """
    return compute_mlp_importance([mlp])[0]

In [78]:
# Flat list of the model and all of its submodules (model.modules() recurses).
modules = list(model.modules())

In [101]:
# fc1 of the first decoder layer's MLP, located through the same path get_mlps uses.
x = get_mlps(model)[0].get_submodule("fc1")

In [113]:
# The weight's gradient exists only after a backward pass (e.g. compute_importances); on a fresh
# kernel it is None. Fall back to the weight itself, whose shape the gradient always matches.
(x.weight.grad if x.weight.grad is not None else x.weight).shape

torch.Size([8192, 2048])