In [23]:
%%capture
import torch
from transformers import AutoTokenizer, OPTForCausalLM, pipeline
from datasets import load_dataset
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset = load_dataset('c4', 'en', streaming=True)
model_text = "facebook/opt-125m"
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_text)

# Load model with pre-trained head
model = OPTForCausalLM.from_pretrained(model_text, output_attentions=True, output_hidden_states=True)

generator = pipeline('text-generation', model=model_text)

calibration_data = []
for i, data in enumerate(iter(dataset['train'])):
    if i > 128:
        break
    tokenized = tokenizer.encode(data['text'], return_tensors="pt", padding="max_length", truncation=True, max_length=768)
    calibration_data.append(tokenized)
calibration_data = torch.squeeze(torch.stack(calibration_data)).to(device=device)

In [24]:
import torch
import numpy as np

def inverse_hessian(X, epsilon=0.01):
    """
    Calculate the inverse of a positive-definite matrix using the Cholesky decomposition.
    Args:
    - X (torch.Tensor): dxn tensor
    - epsilon (float): small constant to prevent Hessian from being singular
    Returns:
    - torch.Tensor: inverted matrix
    """
    X = X.float()
    X_T = torch.transpose(X, 0, 1)
    identity = torch.eye(X.shape[0], dtype=torch.float32)
    H_inv = torch.inverse(2 * (X @ X_T + epsilon * identity))
    #H_inv = torch.cholesky(H_inv).T
    H_inv = torch.lu(H_inv)[0].T
    
    return H_inv

# W is weights matrix for one layer
# H_inv is inverse hessian for one layer
# p is proportion of weights to 0
# B is lazy block size, low B helps to reduce memory use
# Bs is inverse of how often to make masks (e.g. when Bs is 4, make new masks with 20% sparseness every 4 columns)
def calculate_mask(W, H_inv, p, B, Bs):
    # Get the number of rows and columns in W
    d_row, d_col = W.shape
    
    # Initialize the pruning mask M and block quantization errors E to all zeros
    M = torch.zeros(d_row, d_col, dtype=torch.bool)
    E = torch.zeros(d_row, B)

    # only need to calculate w_square and h_square once
    # w_square = torch.square(W)
    # h_square = torch.square(H_inv)

    # Loop over blocks of columns of W
    for i in range(0, d_col, B):
        # Loop over columns within a block
        for j in range(i, min(i + B, d_col)):
            # If j is a multiple of Bs, prune a portion of the weights
            if j % Bs == 0:
                # Get the mask for the largest (1 - p)% of weights based on squared value and inverse hessian

                # prune_values is matrix of w^2/H^(-1)_cc
                
                w_square_section = torch.square(W[:, j:j+Bs])
                h_square_section = torch.square(H_inv[j:j+Bs, j:j+Bs]).diag() # 1 dimensional vector

                # print("weights squared and h_inv:")
                # print(w_square_section)
                # print(h_square_section)

                prune_values = w_square_section / h_square_section.unsqueeze(0)
                # print("prune values: ")
                # print(prune_values)

                cutoff_value = torch.kthvalue(prune_values, int((1 - p) * d_row), dim=0)[0]
                # print("cutoff value: ")
                # print(cutoff_value)
    
                # print("mask: ")
                mask = prune_values > cutoff_value
            
                M[:, j:j+Bs] = mask

            # Calculate the pruning error for this column
            E[:, j-i] = W[:, j] / H_inv[j, j]
            # Freeze the weights that are not pruned by multiplying by the pruning mask
            # Invert mask (~M equivalent to 1 - M)
            E[:, j-i] = ~M[:, j] * E[:, j-i]
            # Update the weights in this block based on the pruning error and inverse hessian information
            W[:, j:i+B] -= torch.ger(E[:, j-i], H_inv[j, j:i+B])
        # Update all remaining weights
        W[:, i+B:] -= torch.matmul(E, H_inv[i:i+B, i+B:])
    
    # return mask
    return M

In [25]:

X = torch.randn(512, 512, dtype=torch.float32)
lmbda = 0.1
#print(torch.transpose(calibration_data,0,1).shape)
H_inv = inverse_hessian(torch.transpose(calibration_data,0,1), lmbda)
#H_inv = inverse_hessian(X, lmbda)
print(H_inv.shape)

torch.Size([768, 768])


In [16]:
H_inv.isnan().sum()

tensor(0)

# Testing model

In [26]:
model

OPTForCausalLM(
  (model): OPTModel(
    (decoder): OPTDecoder(
      (embed_tokens): Embedding(50272, 768, padding_idx=1)
      (embed_positions): OPTLearnedPositionalEmbedding(2050, 768)
      (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (layers): ModuleList(
        (0): OPTDecoderLayer(
          (self_attn): OPTAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (activation_fn): ReLU()
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (final_layer_norm): LayerNorm((768,), eps=1e-05,

# Prune Model

In [27]:
inv_hess = inverse_hessian(torch.transpose(calibration_data,0,1), 0.2)
print(inv_hess.shape)
for name, param in model.named_parameters():
    print(name)
    param = calculate_mask(param, inv_hess, 0.5, 32, 32)
    break

torch.Size([768, 768])
model.decoder.embed_tokens.weight
torch.Size([50272, 768])


RuntimeError: a view of a leaf Variable that requires grad is being used in an in-place operation.

In [56]:
input1 = tokenizer("Hello, my dog is cute", return_tensors="pt", padding="max_length", truncation=True)
input2 = tokenizer("What the fuck did you just fucking say about me, you little bitch?", return_tensors="pt", padding="max_length", truncation=True)
output = model.generate(input1.input_ids, max_length=100, num_return_sequences=1, temperature=0.5, top_p=0.95)
tokenizer.decode(output[0], skip_special_tokens=True)

"Hello, my dog is cute and I love her. I'm a little nervous about her because she's a little bit shy and I'm not sure if she's going to be able to handle it. I'm not sure if she's going to be able to handle it. I'm not sure if she's going to be able to handle it. I'm not sure if she's going to be able to handle it. I'm not sure if she's going to be able to handle"

In [119]:
output.shape

torch.Size([1, 100])