In [115]:
import torch
from torch.nn.utils import prune

from transformers import AutoTokenizer, OPTForCausalLM, pipeline
from datasets import load_dataset

# from calculate_mask import calculate_mask
# from inverse_hessian import inverse_hessian
from input_prehooks import put_input_hooks

In [121]:
#DEVICE
device = 'cuda' if torch.cuda.is_available() else 'cpu'


#Load dataset
dataset = load_dataset('c4', 'en', streaming=True)
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
# Load model with pre-trained head
model = OPTForCausalLM.from_pretrained("facebook/opt-125m", output_attentions=True, output_hidden_states=True)
# Load genrator
generator = pipeline('text-generation', model="facebook/opt-125m")
# Create calibration data
calibration_data = []
for i, data in enumerate(iter(dataset['train'])):
    if i > 7:
        break
    tokenized = tokenizer.encode(data['text'], return_tensors="pt", padding="max_length", truncation=True, max_length=512)
    calibration_data.append(tokenized)
#calibration_data = torch.transpose(torch.squeeze(torch.stack(calibration_data)),0,1).to(device=device)
calibration_data = torch.squeeze(torch.stack(calibration_data)).to(device=device)
calibration_data.double()

tensor([[2.0000e+00, 4.8290e+04, 7.1300e+03,  ..., 1.0000e+00, 1.0000e+00,
         1.0000e+00],
        [2.0000e+00, 4.8763e+04, 1.1000e+01,  ..., 1.0000e+00, 1.0000e+00,
         1.0000e+00],
        [2.0000e+00, 5.9700e+02, 1.4189e+04,  ..., 1.0000e+00, 1.0000e+00,
         1.0000e+00],
        ...,
        [2.0000e+00, 3.8700e+02, 9.2980e+03,  ..., 1.0000e+00, 1.0000e+00,
         1.0000e+00],
        [2.0000e+00, 1.0000e+02, 8.0200e+02,  ..., 1.0000e+00, 1.0000e+00,
         1.0000e+00],
        [2.0000e+00, 1.3300e+02, 4.0660e+03,  ..., 1.0000e+00, 1.0000e+00,
         1.0000e+00]], dtype=torch.float64)

In [122]:
calibration_data.shape

torch.Size([8, 512])

In [123]:
# First, put in forward hooks
features = {}
put_input_hooks(model=model, features=features)

# Run calibration data through model at first to calculate features dictionary with
# input tensors to each intermediate layer
model(calibration_data)

# function to get module name from parameter name
def get_module_name(param_name):
    if param_name[-5:] == ".bias":
        return param_name[:-5], "bias"
    elif param_name[-7:] == ".weight":
        return param_name[:-7], "weight"
    else:
        return None, None

In [124]:
# make a dictionary to access module by name
# model_param_lookup_dict = {}
# for param_name, param_iter in model.named_parameters():
#     model_param_lookup_dict[param_name] = param_iter
    
model_lookup_dict = {}
for module_name, module_iter in model.named_modules():
    model_lookup_dict[module_name] = module_iter

for k in features.keys():
    # print(k)
    try:
        # print(f"for {k}")
        # print(f"{k} shape: {features[k][0].shape}")
        # print(f"weight shape: {model_lookup_dict[k].weight.shape}")
        # if 512 in model_lookup_dict[k].weight.shape:
        print(k)
        print(f"input shape: {features[k][0].shape}")
        print(f"weight shape: {model_lookup_dict[k].weight.shape}")
    except:
        continue
    # except:
        # continue



input shape: torch.Size([8, 512])
model.decoder
model.decoder.embed_tokens
input shape: torch.Size([8, 512])
weight shape: torch.Size([50272, 768])
model.decoder.embed_positions
input shape: torch.Size([8, 512])
weight shape: torch.Size([2050, 768])
model.decoder.layers.0
input shape: torch.Size([8, 512, 768])
model.decoder.layers.0.self_attn_layer_norm
input shape: torch.Size([8, 512, 768])
weight shape: torch.Size([768])
model.decoder.layers.0.self_attn
model.decoder.layers.0.self_attn.q_proj
input shape: torch.Size([8, 512, 768])
weight shape: torch.Size([768, 768])
model.decoder.layers.0.self_attn.k_proj
input shape: torch.Size([8, 512, 768])
weight shape: torch.Size([768, 768])
model.decoder.layers.0.self_attn.v_proj
input shape: torch.Size([8, 512, 768])
weight shape: torch.Size([768, 768])
model.decoder.layers.0.self_attn.out_proj
input shape: torch.Size([8, 512, 768])
weight shape: torch.Size([768, 768])
model.decoder.layers.0.final_layer_norm
input shape: torch.Size([4096, 76

In [132]:
def calculate_mask(
    W,
    H_inv,
    p,
    B,
    Bs,
    ):

    # Get the number of rows and columns in W

    (d_row, d_col) = W.shape

    # Initialize the pruning mask M and block quantization errors E to all zeros

    M = torch.zeros(d_row, d_col, dtype=torch.bool)
    E = torch.zeros(d_row, B, dtype=torch.float64)

    # only need to calculate w_square and h_square once
    # Loop over blocks of columns of W (as specified by B)

    for i in range(0, d_col, B):

        # Loop over columns within a block

        for j in range(i, min(i + B, d_col)):

            # If j is a multiple of Bs, prune a portion of the weights

            if j % Bs == 0:

                # Get the mask for the largest (1 - p)% of weights based on squared value and inverse hessian

                # ASTERISK: prune_values is matrix of w^2/H^(-1)_cc

                # Finding respective sections of hessian and weights matrix
                w_square_section = torch.square(W[:, j:j + Bs])
                h_square_section = torch.square(H_inv[j:j + Bs, j:j
                        + Bs]).diag()  # 1 dimensional vector

                # getting the prune values matrix from W and H^-1 sections
                prune_values = w_square_section \
                    / h_square_section.unsqueeze(0)

                #calulating cutoff for the weights
                cutoff_value = torch.kthvalue(prune_values, int((1 - p)
                        * d_row), dim=0)[0]

                #getting the final mask
                mask = prune_values > cutoff_value

                #masking
                M[:, j:j + Bs] = mask

            # Calculate the pruning error for this column

            E[:, j - i] = W[:, j] / H_inv[j, j]

            # Freeze the weights that are not pruned by multiplying by the pruning mask
            # Invert mask (~M equivalent to 1 - M < might be -(M + 1))

            E[:, j - i] = ~M[:, j] * E[:, j - i]

            # Update the weights in this block based on the pruning error and inverse hessian information

            W[:, j:i + B] -= torch.ger(E[:, j - i], H_inv[j, j:i + B])

        # Update all remaining weights

        # print(f"this weight shape: {W[:, i + B:].shape}")
        # print(f"e shape: {E.shape}")
        # print(f"Hessian shape: {H_inv[i:i + B, i + B:].shape}")
        W[:, i + B:] -= torch.matmul(E, H_inv[i:i + B, i + B:])

    # return mask

    return M


def inverse_hessian(X, epsilon=0.01):
    """
    Calculate the inverse of a positive-definite matrix using the Cholesky decomposition.
    Args:
    - X (torch.Tensor): dxn tensor
    - epsilon (float): small constant to prevent Hessian from being singular
    Returns:
    - torch.Tensor: inverted matrix
    """
    X = X.double()
    print(f"input shape: {X.shape}")
    X_T = torch.transpose(X, 1, 2)
    identity = torch.eye(X.shape[1], dtype=torch.float64)
    # print(f"shape of x @ x_t: {torch.sum(X @ X_T, dim=0).shape}")
    H = 2 * (torch.sum(X @ X_T, dim=0) + (epsilon * identity))
    # print(torch.linalg.eig(H)[0])
    print(f"H SHAPE: {H.shape}")
    # print(f"num zeros in hessian: {torch.sum(H == 0)}")
    # print(f"Determinant is {torch.linalg.det(H)}")
    # print(f"Hessian Diagonal is {H.diag()}")
    H_inv = torch.inverse(H)
    
    # H_inv = torch.cholesky(H_inv).T
    H_inv = torch.lu(H_inv)[0].T
    
    return H_inv


# Re-load model with pre-trained head
model = OPTForCausalLM.from_pretrained("facebook/opt-125m", output_attentions=True, output_hidden_states=True)

# make a dictionary to access module by name
model_lookup_dict = {}
for module_name, module_iter in model.named_modules():
    model_lookup_dict[module_name] = module_iter

EPSILON = 1e-8
SPARSENESS = .2
B = 32
Bs = 16

layer_blacklist = ['model.decoder.embed_tokens.weight', 'model.decoder.embed_tokens.bias',
'model.decoder.embed_positions.weight', 'model.decoder.final_layer_norm.weight',
'model.decoder.final_layer_norm.bias']

# Using calibration data (inputs to each intermediate weight layer)
# Iterate through named parameters, calculate inverse hessian and calculate mask
with torch.no_grad():
    for name, param in model.named_parameters():

        # skip the embed layer
        if name in layer_blacklist:
            continue
        
        # skip norms which have 1 dimension
        if len(param.shape) < 2:
            continue

        module_name, param_type = get_module_name(name)

        # apply to weight and bias layers
        if param_type == "weight" or param_type == "bias":
            # input to parameter
            layer_input = features[module_name][0]
            print(name)
            print(f"layer input shape: {layer_input.shape}")
            # print(f"weight shape: {param.shape}")
            
            # calculate inverse hessian
            inv_hess = inverse_hessian(torch.transpose(layer_input, 1, 2), epsilon=EPSILON)
            # inv_hess = inverse_hessian(layer_input, epsilon=EPSILON)
            # print(f"hessian shape: {inv_hess.shape}")

            # calculate mask
            mask = calculate_mask(W=param, H_inv=inv_hess, p=SPARSENESS, B=B, Bs=Bs)
            
            # get module from lookup dictionary by module name
            module = model_lookup_dict[module_name]
            # apply mask
            prune.custom_from_mask(module=module, name=param_type, mask=mask)
        # break

model.decoder.layers.0.self_attn.k_proj.weight
layer input shape: torch.Size([8, 512, 768])
input shape: torch.Size([8, 768, 512])
shape of x @ x_t: torch.Size([768, 768])


RuntimeError: The size of tensor a (768) must match the size of tensor b (8) at non-singleton dimension 1

In [None]:
for n, m in model.named_parameters():
    print(n)

model.decoder.embed_tokens.weight
model.decoder.embed_positions.weight_orig
model.decoder.final_layer_norm.weight
model.decoder.final_layer_norm.bias
model.decoder.layers.0.self_attn.k_proj.weight
model.decoder.layers.0.self_attn.k_proj.bias
model.decoder.layers.0.self_attn.v_proj.weight
model.decoder.layers.0.self_attn.v_proj.bias
model.decoder.layers.0.self_attn.q_proj.weight
model.decoder.layers.0.self_attn.q_proj.bias
model.decoder.layers.0.self_attn.out_proj.weight
model.decoder.layers.0.self_attn.out_proj.bias
model.decoder.layers.0.self_attn_layer_norm.weight
model.decoder.layers.0.self_attn_layer_norm.bias
model.decoder.layers.0.fc1.weight
model.decoder.layers.0.fc1.bias
model.decoder.layers.0.fc2.weight
model.decoder.layers.0.fc2.bias
model.decoder.layers.0.final_layer_norm.weight
model.decoder.layers.0.final_layer_norm.bias
model.decoder.layers.1.self_attn.k_proj.weight
model.decoder.layers.1.self_attn.k_proj.bias
model.decoder.layers.1.self_attn.v_proj.weight
model.decoder.l