In [1]:
import gc
import os
import time
import torch
import numpy as np
from tqdm import tqdm
from copy import deepcopy
import matplotlib.pyplot as plt
from transformers import AutoModelForCausalLM, AutoTokenizer
from instruct_pipeline import InstructionTextGenerationPipeline
os.environ["CUDA_MODULE_LOADING"] = "LAZY"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Prefer the second GPU as originally configured, but fall back to the first
# GPU (or the CPU) instead of failing with an invalid device ordinal on
# machines with fewer than two GPUs.
if torch.cuda.device_count() > 1:
    device = 'cuda:1'
elif torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'
pretrained_dir = 'checkpoints/sparse_dolly_12/'  # pruned checkpoint, 1:2 sparsity
model_name = 'databricks/dolly-v2-12b'  # only used here as the tokenizer source

# Weights are loaded in float16.
# NOTE(review): half-precision inference on CPU is unsupported by some torch
# versions — confirm before running without a GPU.
model = AutoModelForCausalLM.from_pretrained(pretrained_dir, torch_dtype=torch.half).to(device)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
generate_text = InstructionTextGenerationPipeline(model=model, tokenizer=tokenizer)

# Quick qualitative sanity check of the pruned model.
text = "Explain to me the difference between nuclear fission and fusion."
with torch.no_grad():
    res = generate_text(text)
print('The model output is: ', res[0]["generated_text"])

The model output is:  Fusion is the process of using nuclear fission to create heat and light. Nuclear fission is the process of using nuclear fission to create heat and light. Nuclear fusion is the process of using nuclear fission to create heat and light. Nuclear fusion is the process of using nuclear fission to create heat and light.
I hope this helps!


In [3]:
def verify_permutation(name, param):
    """Check (with NumPy) that a 2-D weight matrix has a 1:2 sparsity pattern.

    With exactly one non-zero in every pair of adjacent columns (2j, 2j+1),
    ``W @ x`` can be reproduced from the kept values ``V`` (shape
    ``(rows, cols // 2)``) and the boolean mask alone. This function checks
    the pattern explicitly. It then recomputes the product both ways for a
    random ``x`` and asserts that the results agree.

    Args:
        name: parameter name, used only in assertion messages.
        param: 2-D torch tensor holding the pruned weights (any device).

    Raises:
        AssertionError: if the column count is odd, if some column pair does
            not hold exactly one non-zero, or if the compressed product
            disagrees with the dense one.
    """
    # Read-only view/copy on the CPU; W is never written, so no clone is needed.
    W = param.detach().cpu().numpy()
    rows, cols = W.shape
    assert cols % 2 == 0, f"Failed for {name}: odd number of columns ({cols})"
    sparsity_mask = W != 0

    # Explicit 1:2 check. Without it, a row whose kept values are merely
    # *counted* correctly (e.g. both of one pair kept, none of the next) can
    # slip through or fail with an opaque reshape error.
    pair_counts = sparsity_mask.reshape(rows, cols // 2, 2).sum(axis=2)
    assert (pair_counts == 1).all(), f"Failed for {name}: not a 1:2 sparsity pattern"

    x = np.random.randn(cols, 1)
    a = W @ x  # dense reference, shape (rows, 1); float16 @ float64 promotes to float64

    # Non-zeros come out in row-major order, so V[r, j] is the value kept
    # from column pair j of row r.
    V = W[sparsity_mask].reshape(rows, cols // 2)
    x_pairs = x.reshape(cols // 2, 2).T  # (2, cols//2): x_pairs[k, j] = x[2j + k]
    P = V[np.newaxis, ...] * x_pairs[:, np.newaxis, :]  # (2, rows, cols//2)
    # mask_pairs[k, r, j] = mask[r, 2j + k]: zeroes the term for the unkept column.
    mask_pairs = sparsity_mask.reshape(-1, 2).T.reshape(2, rows, cols // 2)
    b = np.sum(P * mask_pairs, axis=(0, 2))
    assert np.allclose(a.ravel(), b), f"Failed for {name}"

def verify_dolly_sparsity(cur_model):
    """Apply ``verify_permutation`` to every pruned weight matrix of ``cur_model``.

    A parameter is checked when its name contains ``'dense'`` or
    ``'query_key_value'`` and it has more than one dimension. Any error raised
    by ``verify_permutation`` propagates and stops the sweep.
    """
    pruned_params = (
        (pname, tensor)
        for pname, tensor in cur_model.named_parameters()
        if any(key in pname for key in ('dense', 'query_key_value')) and len(tensor.shape) > 1
    )
    for pname, tensor in pruned_params:
        verify_permutation(pname, tensor)


verify_dolly_sparsity(model)

In [4]:
def verify_permutation_torch(name, param):
    """Torch version of the 1:2 sparsity check; it runs on ``param``'s own device.

    It checks that every pair of adjacent columns (2j, 2j+1) holds exactly one
    non-zero. It then checks that ``W @ x`` for a random ``x`` is reproduced
    from the kept values ``V`` (shape ``(rows, cols // 2)``) and the mask.
    All arithmetic is done in float64.

    Args:
        name: parameter name, used only in assertion messages.
        param: 2-D torch tensor holding the pruned weights.

    Raises:
        AssertionError: on an odd column count, a malformed 1:2 pattern, or a
            mismatch between the dense and compressed products.
    """
    # .to(float64) already copies for fp16/fp32 input; W is never written.
    W = param.detach().to(torch.float64)
    rows, cols = W.shape
    assert cols % 2 == 0, f"Failed for {name}: odd number of columns ({cols})"
    sparsity_mask = W != 0

    pair_counts = sparsity_mask.reshape(rows, cols // 2, 2).sum(dim=2)
    assert bool((pair_counts == 1).all()), f"Failed for {name}: not a 1:2 sparsity pattern"

    # Create x next to W (not on a global device) so any placement works.
    x = torch.randn(cols, dtype=torch.float64, device=W.device)
    a = W @ x

    # Row-major selection: V[r, j] is the value kept from column pair j of row r.
    V = W[sparsity_mask].reshape(rows, cols // 2)
    x_pairs = x.reshape(cols // 2, 2).t()  # x_pairs[k, j] = x[2j + k]
    P = V.unsqueeze(0) * x_pairs.unsqueeze(1)  # (2, rows, cols//2)
    mask_pairs = sparsity_mask.reshape(-1, 2).t().reshape(2, rows, cols // 2)
    b = torch.sum(P * mask_pairs, dim=(0, 2))
    assert torch.allclose(a, b), f"Failed for {name}"

def verify_dolly_sparsity_torch(cur_model):
    """Torch counterpart of the NumPy sweep: run ``verify_permutation_torch``
    on each multi-dimensional parameter whose name contains ``'dense'`` or
    ``'query_key_value'``. The first failure propagates."""
    for pname, weight in cur_model.named_parameters():
        is_target = 'dense' in pname or 'query_key_value' in pname
        if not is_target or len(weight.shape) <= 1:
            continue
        verify_permutation_torch(pname, weight)


verify_dolly_sparsity_torch(model)

In [5]:
def compress_weights(param):
    """Split a 1:2-sparse weight matrix into its kept values and a boolean mask.

    (This function is redefined later in the notebook with a uint8 mask.)

    Args:
        param: 2-D tensor where every pair of adjacent columns holds exactly
            one non-zero.

    Returns:
        V: tensor of shape ``(rows, cols // 2)`` holding each row's non-zeros
           in column order (same dtype/device as ``param``, a fresh copy).
        sparsity_mask: bool tensor shaped like ``param``, True where non-zero.

    Raises:
        AssertionError: "Incorrect sparsity pattern" if the 1:2 pattern does
            not hold. Only checking the total count would accept rows with
            uneven counts and silently misalign ``V``.
    """
    W = param.detach()
    rows, cols = W.shape
    sparsity_mask = W != 0
    is_one_of_two = cols % 2 == 0 and bool(
        (sparsity_mask.reshape(rows, cols // 2, 2).sum(dim=2) == 1).all()
    )
    assert is_one_of_two, "Incorrect sparsity pattern"
    # masked_select copies, so V never aliases the live parameter.
    V = W.masked_select(sparsity_mask).view(rows, cols // 2)
    return V, sparsity_mask

def decompress_weights(V, x, sparsity_mask):
    """Compute ``W @ x`` directly from the compressed form of a 1:2-sparse ``W``.

    Despite its name, this does not rebuild the dense matrix. It returns the
    matrix-vector product, using only the kept values.

    Args:
        V: ``(rows, cols // 2)`` kept values, as produced by ``compress_weights``.
        x: input with ``cols`` elements (any shape that reshapes to ``(cols,)``).
        sparsity_mask: ``(rows, cols)`` mask (bool or 0/1 integer) with exactly
            one set entry per pair of adjacent columns.

    Returns:
        1-D tensor of length ``rows``.
    """
    rows, cols = sparsity_mask.shape
    x = x.reshape(cols)
    # True where the odd column 2j+1 of pair j was kept, False where 2j was.
    odd_kept = sparsity_mask.reshape(rows, cols // 2, 2)[..., 1].to(torch.bool)
    # Pick the single input element that meets V[r, j]. Unlike multiplying by
    # both elements and masking afterwards, this never forms inf * 0 = NaN and
    # needs half the temporary memory.
    x_sel = torch.where(odd_kept, x[1::2], x[0::2])
    return torch.sum(V * x_sel, dim=1)

def verify_compression(cur_model, atol=0.0079):
    """Check that the compressed form of every pruned matrix reproduces its product.

    For each multi-dimensional parameter whose name contains ``'dense'`` or
    ``'query_key_value'``, the function compresses the weights. It then
    compares the compressed matrix-vector product against the dense fp16
    product for a random fp16 ``x``. Prints "Compression verified!" if
    every check passes.

    Args:
        cur_model: model whose ``named_parameters()`` are checked.
        atol: absolute tolerance for ``torch.allclose`` (rtol stays at
            torch's default). NOTE(review): 0.0079 looks like an empirically
            chosen fp16 tolerance — confirm.

    Raises:
        AssertionError: naming the first parameter whose products disagree.
    """
    # Parameters require grad by default; without no_grad the dense product
    # would build an autograd graph for every matrix.
    with torch.no_grad():
        for name, p in cur_model.named_parameters():
            if ('dense' in name or 'query_key_value' in name) and len(p.shape) > 1:
                V, sparsity_mask = compress_weights(p)
                # Put x on the parameter's own device rather than a global one.
                x = torch.randn(p.shape[1], dtype=torch.half, device=p.device)
                b = decompress_weights(V, x, sparsity_mask)
                a = (p.to(torch.half) @ x).flatten()
                assert torch.allclose(b, a, atol=atol), f"Failed for {name}"
    print("Compression verified!")


verify_compression(model)

Compression verified!


In [6]:
def compute_nonzero_mean(cur_model):
    """Summarise the non-zero weights of ``cur_model`` parameter by parameter.

    The function takes the mean of the non-zero entries of every parameter
    tensor. It returns the unweighted mean and standard deviation of those
    per-parameter means. This is *not* the global mean over all non-zero
    weights: a small bias vector counts as much as a large weight matrix.

    Parameters with no non-zero entries are skipped. They have no defined
    mean and would otherwise turn the result into NaN.

    Returns:
        tuple: ``(np.mean(per_param_means), np.std(per_param_means))``.
    """
    all_non_zero_mean = []
    with torch.no_grad():
        for param in cur_model.parameters():
            non_zero_values = param[param != 0]
            if non_zero_values.numel() == 0:
                continue
            # Reduce in float64: summing many fp16 values in half precision can
            # overflow (fp16 max is 65504) and loses precision.
            all_non_zero_mean.append(non_zero_values.double().mean().item())
    return np.mean(all_non_zero_mean), np.std(all_non_zero_mean)
nonzero_stats = compute_nonzero_mean(model)  # (mean, std) of per-parameter non-zero means
print(nonzero_stats)

(0.1910080316416714, 0.43680537607459835)


In [7]:
def compress_weights(param):
    """Split a 1:2-sparse weight matrix into kept values and a uint8 mask.

    NOTE: this redefinition shadows the earlier ``compress_weights`` in this
    notebook. It differs only in returning the mask as ``torch.uint8``.

    Args:
        param: 2-D tensor where every pair of adjacent columns holds exactly
            one non-zero.

    Returns:
        V: tensor of shape ``(rows, cols // 2)`` with each row's non-zeros in
           column order (same dtype/device as ``param``, a fresh copy).
        sparsity_mask: uint8 tensor shaped like ``param``, 1 where non-zero.

    Raises:
        AssertionError: "Incorrect sparsity pattern" if the 1:2 pattern does
            not hold. A total-count check alone would accept uneven rows and
            silently misalign ``V``.
    """
    W = param.detach()
    rows, cols = W.shape
    sparsity_mask = W != 0
    is_one_of_two = cols % 2 == 0 and bool(
        (sparsity_mask.reshape(rows, cols // 2, 2).sum(dim=2) == 1).all()
    )
    assert is_one_of_two, "Incorrect sparsity pattern"
    # masked_select copies and keeps param's dtype.
    V = W.masked_select(sparsity_mask).view(rows, cols // 2)
    return V, sparsity_mask.to(torch.uint8)

def compress_model(cur_model, mask_bits_per_element=1):
    """Build a compressed CPU state dict and report its size in bits.

    Pruned weight matrices are multi-dimensional parameters whose name
    contains ``'dense'`` or ``'query_key_value'``. Each one is replaced by its
    kept values ``V`` under ``name`` plus its mask under ``name + '_mask'``.
    Every other parameter is copied unchanged.

    Args:
        cur_model: model whose ``named_parameters()`` are compressed.
        mask_bits_per_element: bits charged per mask element in ``new_size``.
            The default of 1 models an ideal bit-packed mask. Pass 8 to count
            masks stored one byte per element.

    Returns:
        ``(new_state_dict, old_size, new_size)``. Sizes are in bits, derived
        from each tensor's ``element_size()``. Masks are the exception: they
        are charged ``mask_bits_per_element`` bits per element.
    """
    new_state_dict = {}
    old_size, new_size = 0, 0
    with torch.no_grad():
        for name, p in cur_model.named_parameters():
            p_bits = p.element_size() * 8 * p.nelement()
            old_size += p_bits
            if ('dense' in name or 'query_key_value' in name) and len(p.shape) > 1:
                V, sparsity_mask = compress_weights(p)
                new_size += (V.element_size() * 8 * V.nelement()
                             + mask_bits_per_element * sparsity_mask.nelement())
                new_state_dict[name] = V.cpu()
                new_state_dict[name + '_mask'] = sparsity_mask.cpu()
            else:
                # Store a plain tensor (not a Parameter) with no autograd history.
                new_state_dict[name] = p.detach().cpu()
                new_size += p_bits
    return new_state_dict, old_size, new_size

# Sizes are in bits; the ratio is compressed size over original size.
new_state_dict, old_size, new_size = compress_model(model)
size_ratio = new_size / old_size
print(old_size, new_size, size_ratio)

# PyTorch stores each mask element in a full byte (bool/uint8). The sizes above already charge the mask only 1 bit per element, so ~58% is the size reachable with a bit-packed mask, not the size of the masks as currently stored.

189470310400 110197964800 0.581610726067613
