# Calibration Data, Model

In [16]:
%%capture
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, OPTForCausalLM, pipeline

# Open the English C4 corpus in streaming mode (streaming=True), so samples are
# pulled lazily while iterating instead of being materialised first.
dataset = load_dataset('c4', 'en', streaming=True)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

# Load model with pre-trained head
model = OPTForCausalLM.from_pretrained("facebook/opt-125m", output_attentions=True, output_hidden_states=True)

generator = pipeline('text-generation', model="facebook/opt-125m")

# Defined in this cell as well so it runs on a fresh kernel: the original
# relied on `device` leaking in from a cell that appears later in the notebook.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Number of calibration sequences and the fixed length each one is
# padded/truncated to.
NUM_CALIBRATION_SAMPLES = 128
MAX_SEQ_LEN = 2048

calibration_data = []
for i, data in enumerate(dataset['train']):
    # `>=` stops after exactly NUM_CALIBRATION_SAMPLES items; the previous
    # `i > 128` check let a 129th sample through.
    if i >= NUM_CALIBRATION_SAMPLES:
        break
    tokenized = tokenizer.encode(
        data['text'],
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=MAX_SEQ_LEN,
    )
    calibration_data.append(tokenized)

# Each entry is a (1, MAX_SEQ_LEN) tensor, so concatenating along dim 0 yields
# (num_samples, MAX_SEQ_LEN) directly. Unlike squeeze(stack(...)), this cannot
# accidentally drop the sample dimension when only one sample was collected.
calibration_data = torch.cat(calibration_data, dim=0).to(device=device)

In [11]:
import torch
import numpy as np

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def inverse_hessian(X, epsilon=1e-8):
    """
    Build a damped Hessian from X and return the upper Cholesky factor of its inverse.

    The Hessian is formed as H = 2 * X @ X^T + epsilon * I. H is inverted and
    the inverse is factorised as H^-1 = U^T @ U with U upper triangular; U is
    what gets returned.

    Args:
    - X (torch.Tensor): 2-D floating-point matrix. H has shape
      (X.shape[0], X.shape[0]); presumably X is laid out as
      (features, samples) -- TODO confirm against the caller.
    - epsilon (float): small constant added to the diagonal to keep the
      Hessian from being singular
    Returns:
    - torch.Tensor: upper-triangular factor U of H^-1, with the same dtype and
      device as X
    """
    # Create the damping term on X's device/dtype; a default CPU float32
    # identity fails as soon as X lives on the GPU.
    identity = torch.eye(X.shape[0], dtype=X.dtype, device=X.device)
    hessian = 2 * (X @ X.transpose(0, 1)) + epsilon * identity
    hessian_inv = torch.linalg.inv(hessian)

    # upper=True already yields the upper-triangular factor, so no further
    # transpose is applied. The original wrapped this in torch.transpose(...)
    # without dim arguments, which raised TypeError on every call and would
    # otherwise have flipped the factor to lower-triangular.
    return torch.linalg.cholesky(hessian_inv, upper=True)


In [17]:
def sparse_gpt(W, H_inv, p, B, Bs):
    """
    Prune a weight matrix column by column, compensating the remaining weights.

    Columns are processed left to right in blocks of B. Every Bs columns a
    keep/prune mask is chosen for the next Bs columns: in each column the
    int(p * d_row) entries with the smallest saliency
    w^2 / H_inv[c, c]^2 are marked for pruning (ties with the cutoff are pruned
    too). For each column the error of the pruned entries is propagated to the
    later columns of the block through row j of H_inv, and once a block is done
    the accumulated errors are applied to all columns to its right.

    Args:
    - W (torch.Tensor): 2-D weight matrix of shape (d_row, d_col). It is not
      modified; the function works on a detached copy.
    - H_inv (torch.Tensor): (d_col, d_col) matrix whose row j, columns j..end
      drive the updates. Assumed to be the upper Cholesky factor of the inverse
      Hessian as produced by inverse_hessian, on the same device/dtype as W --
      TODO confirm at the call site.
    - p (float): fraction of weights to prune in every column, in [0, 1].
    - B (int): number of columns per lazy-update block (> 0).
    - Bs (int): number of columns sharing one mask-selection step (> 0).
      Presumably B should be a multiple of Bs so masks do not straddle blocks.
    Returns:
    - torch.Tensor: new (d_row, d_col) tensor with pruned entries set to zero
      and kept entries updated to compensate.
    Raises:
    - ValueError: if p is outside [0, 1] or B / Bs is not positive.
    """
    if not 0.0 <= p <= 1.0:
        raise ValueError(f"p must be in [0, 1], got {p}")
    if B <= 0 or Bs <= 0:
        raise ValueError(f"B and Bs must be positive, got B={B}, Bs={Bs}")

    d_row, d_col = W.shape

    # Work on a private copy: the original updated the caller's tensor in place
    # but returned a different (masked) tensor, leaving the input half-processed.
    W = W.detach().clone()

    # M[r, c] is True where the weight is KEPT. Allocate on W's device.
    M = torch.zeros(d_row, d_col, dtype=torch.bool, device=W.device)

    # Number of entries pruned per column. The original used (1 - p) here,
    # which kept a fraction p instead of pruning it.
    n_prune = min(int(p * d_row), d_row)

    for i in range(0, d_col, B):
        block_end = min(i + B, d_col)
        # Per-block error buffer sized to the real block width, so a trailing
        # partial block (d_col % B != 0) still multiplies cleanly below.
        E = torch.zeros(d_row, block_end - i, dtype=W.dtype, device=W.device)

        for j in range(i, block_end):
            if j % Bs == 0:
                group = slice(j, j + Bs)
                # Saliency per weight; the denominator is the squared diagonal
                # of the (assumed) Cholesky factor.
                diag = H_inv[group, group].diagonal().unsqueeze(0)
                scores = W[:, group] ** 2 / diag ** 2
                if n_prune == 0:
                    # kthvalue requires k >= 1; nothing to prune means keep all.
                    keep = torch.ones_like(scores, dtype=torch.bool)
                else:
                    cutoff = torch.kthvalue(scores, n_prune, dim=0).values
                    keep = scores > cutoff
                M[:, group] = keep

            # Error of this column; kept weights are frozen (zero error). The
            # original computed `1 - M[:, j]` on a bool tensor, which raises.
            err = (W[:, j] / H_inv[j, j]).masked_fill(M[:, j], 0)
            E[:, j - i] = err
            # Compensate the remaining columns of the current block.
            W[:, j:block_end] -= torch.outer(err, H_inv[j, j:block_end])

        # Lazy batch update of every column to the right of this block.
        W[:, block_end:] -= E @ H_inv[i:block_end, block_end:]

    # Set pruned weights to exactly zero.
    return W.masked_fill(~M, 0)

In [13]:
# List every parameter of the model: its qualified name, then its shape on the
# following line.
for param_name, weights in model.named_parameters():
    print(param_name, weights.shape, sep="\n")

NameError: name 'model' is not defined