In [1]:
import torch
import numpy as np

# Compute device for all tensors below: the current CUDA device when one is available, else the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def inverse_hessian(X, l):
    """Return the inverse of the damped layer Hessian, ``(2 X X^T + l I)^-1``.

    Args:
        X: (d, n) tensor -- d = num_features (rows), n = num_examples (columns).
        l: scalar damping added to the diagonal. ``l > 0`` keeps the matrix
            invertible when ``X X^T`` is rank-deficient (e.g. n < d).

    Returns:
        (d, d) inverse Hessian on ``device``.
    """
    X = X.to(device)
    # Build the identity on X's own device/dtype: a plain torch.eye() lives on
    # the CPU and cannot be added to a CUDA tensor.
    eye = torch.eye(X.shape[0], device=X.device, dtype=X.dtype)
    return torch.inverse(2 * X @ X.T + l * eye)

def _select_keep_mask(w_block, h_diag, p):
    """Per-row 0/1 keep-mask (1 = keep, 0 = prune) for one mask-selection block.

    In every row, keeps the floor((1 - p) * width) weights with the largest
    saliency w_c^2 / [H_inv_T]_cc^2, where h_diag holds [H_inv_T]_cc for the
    block's columns.
    """
    width = w_block.shape[1]
    # The epsilon guards against float round-off, e.g. (1 - 0.9) * 10 == 0.999...
    k = int((1 - p) * width + 1e-9)
    mask = torch.zeros_like(w_block)
    if k > 0:
        saliency = w_block ** 2 / h_diag[None, :] ** 2
        keep_idx = torch.topk(saliency, k, dim=1).indices
        mask.scatter_(1, keep_idx, 1.0)
    return mask


def prune_model(W, H_inv, p, B, Bs):
    """Prune a layer's weights with SparseGPT-style OBS error compensation.

    Columns are processed left to right in blocks of ``B``. Every ``Bs``
    columns, a fresh per-row mask is chosen, keeping the weights with the
    largest ``w_c^2 / [U]_cc^2``. Here ``U`` is the upper Cholesky factor of
    ``H_inv``, so that ``H_inv = U^T U``. The error from zeroing each pruned
    weight is propagated to the not-yet-processed columns through the
    matching row of ``U``.

    Args:
        W: (d_row, d_col) layer weight matrix. Not modified; a copy is pruned.
        H_inv: (d_col, d_col) inverse Hessian. The Cholesky factorization
            raises if it is not positive definite.
        p: fraction of weights to prune, in [0, 1] (e.g. 0.5 for 50%).
        B: lazy batch-update block size, in columns.
        Bs: adaptive mask-selection block size, in columns.

    Returns:
        Pruned copy of W on ``device``, in W's original dtype.

    Raises:
        ValueError: if p is outside [0, 1].
    """
    if not 0.0 <= p <= 1.0:
        raise ValueError(f"p must be a fraction in [0, 1], got {p}")

    out_dtype = W.dtype
    # Work on a detached float32 copy: the caller's tensor (possibly an
    # nn.Parameter) is never mutated in place.
    W = W.detach().to(device=device, dtype=torch.float32, copy=True)
    d_row, d_col = W.shape
    M = torch.ones(d_row, d_col, dtype=torch.float32, device=device)  # 1 = keep, 0 = prune

    # Upper-triangular U with H_inv = U^T U (Hessian inverse information).
    H_inv_T = torch.linalg.cholesky(H_inv.to(device=device, dtype=torch.float32)).T
    h_diag = torch.diagonal(H_inv_T)

    for i in range(0, d_col, B):
        i_end = min(i + B, d_col)  # the last block may be narrower than B
        E = torch.zeros(d_row, i_end - i, dtype=torch.float32, device=device)  # block pruning errors
        for j in range(i, i_end):
            if j % Bs == 0:
                j_end = min(j + Bs, d_col)  # the last mask block may be narrower than Bs
                M[:, j:j_end] = _select_keep_mask(W[:, j:j_end], h_diag[j:j_end], p)
            # OBS error of zeroing column j's pruned entries; kept entries contribute 0.
            E[:, j - i] = W[:, j] * (1 - M[:, j]) / H_inv_T[j, j]
            # Compensate the rest of the current block (this also zeroes pruned W[:, j]).
            W[:, j:i_end] -= torch.outer(E[:, j - i], H_inv_T[j, j:i_end])
        # Lazy batch update of every column to the right of the block.
        W[:, i_end:] -= E @ H_inv_T[i:i_end, i_end:]

    W *= M  # force pruned weights to exactly 0
    return W.to(out_dtype)

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Sanity check on a toy input. `H` is passed as the data matrix X (d=2, n=2),
# not as a Hessian, so the expected result is (2 H H^T)^-1 (no damping: l=0).
H = torch.tensor([[2, 0.5], [0.5, 3]], dtype=torch.float32)
H_inv_check = inverse_hessian(H, 0)
assert torch.allclose(H_inv_check.cpu() @ (2 * H @ H.T), torch.eye(2), atol=1e-5)
H_inv_check

In [None]:
%%capture
from transformers import AutoTokenizer, AutoModel, pipeline

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

# Load model with pre-trained head
model = AutoModel.from_pretrained("facebook/opt-125m", output_attentions=True, output_hidden_states=True)

generator = pipeline('text-generation', model="facebook/opt-125m")
