In [None]:
import torch
# Make fp16 the default floating-point dtype for tensors created from here on.
torch.set_default_dtype(torch.float16)
sparse12_model_path = 'checkpoints/sparse_checkpoints/sparse_llama12_alpaca_sd.pth' # SparseGPT model with 1:2 sparsity pattern applied to 7B llama with alpaca as a fine-tuning dataset.
device = torch.device('cuda:0')
# NOTE(review): torch.load unpickles arbitrary Python objects -- only point this at a trusted checkpoint.
# map_location deserializes the tensors straight onto `device` instead of first restoring them on
# whatever device they were saved from and copying afterwards. The trailing .to(device) is kept as a
# safety net; it is a no-op when everything is already on `device`.
model = torch.load(sparse12_model_path, map_location=device).to(device)

In [None]:
def bit_packing(mask):
    """Pack a 2-D boolean (or 0/1) mask into bytes, eight entries per ``uint8``.

    The mask is read in row-major order and cut into groups of eight
    consecutive entries; entry ``k`` of a group becomes bit ``k`` of the
    resulting byte (least-significant bit first).

    Args:
        mask: 2-D tensor whose second dimension is a multiple of 8.

    Returns:
        ``uint8`` tensor of shape ``(mask.shape[0], mask.shape[1] // 8)`` on
        the same device as ``mask``.
    """
    groups_of_eight = mask.to(torch.uint8).flatten().view(-1, 8)
    packed = torch.zeros(groups_of_eight.shape[0], dtype=torch.uint8, device=mask.device)
    # Column k of the (N, 8) view holds the k-th entry of every group.
    for bit, column in enumerate(groups_of_eight.unbind(dim=1)):
        packed += column << bit
    return packed.view(mask.shape[0], mask.shape[1] // 8)

def compress_weights(param):
    """Split a 1:2-sparse 2-D weight into its kept values and a bit-packed mask.

    Columns are treated as consecutive pairs ``(2j, 2j + 1)``. Exactly one slot
    per pair is kept: the non-zero one, or the first slot when both entries are
    zero (a kept weight may legitimately be exactly zero). Deciding the mask per
    pair -- rather than from ``W != 0`` alone -- guarantees every row yields
    exactly ``cols // 2`` values, so values can never be silently shifted
    between rows.

    Args:
        param: 2-D weight tensor with an even number of columns; the column
            count must also be a multiple of 8 for the mask to be bit-packed.

    Returns:
        Tuple ``(V, packed_mask)`` where ``V`` has shape ``(rows, cols // 2)``
        and the dtype of ``param``, and ``packed_mask`` is the result of
        ``bit_packing`` applied to the ``(rows, cols)`` boolean keep-mask.

    Raises:
        AssertionError: if the column count is odd or some pair has two
            non-zero entries (i.e. the tensor is not 1:2 sparse).
    """
    W = param.detach()
    rows, cols = W.shape
    # Raised explicitly (not via `assert`) so the check survives `python -O`;
    # the exception type stays AssertionError for backward compatibility.
    if cols % 2 != 0:
        raise AssertionError("Incorrect sparsity pattern")
    nonzero = (W != 0).reshape(rows, cols // 2, 2)
    if bool((nonzero[..., 0] & nonzero[..., 1]).any()):
        raise AssertionError("Incorrect sparsity pattern")
    # Keep slot 1 only when it is the non-zero one; otherwise keep slot 0.
    keep_second = nonzero[..., 1]
    sparsity_mask = torch.stack((~keep_second, keep_second), dim=-1).reshape(rows, cols)
    # masked_select copies the selected entries in row-major order, so V never aliases param.
    V = W.masked_select(sparsity_mask).view(rows, cols // 2)
    return V.to(param.dtype), bit_packing(sparsity_mask)

cur_model = model
# Substrings that identify the projection weights carrying the 1:2 sparsity pattern.
sparse_layer_keys = ("wq", "wk", "wv", "wo", "w1", "w2", "w3", "query_w", "key_w", "value_w", "attn_w")
new_state_dict = {}
for name, p in cur_model.named_parameters():
    if len(p.shape) > 1 and any(key in name for key in sparse_layer_keys):
        V, sparsity_mask = compress_weights(p)
        # The kept values replace the dense weight under its original key.
        new_state_dict[name] = V
        # The packed mask is stored next to it: "<prefix>.w1.weight" -> "<prefix>.w1_mask".
        new_name_ls = name.split('.')
        new_name_ls[-2] = new_name_ls[-2]+'_mask'
        new_name = '.'.join(new_name_ls[:-1])
        new_state_dict[new_name] = sparsity_mask
    else:
        new_state_dict[name] = p

# Compare storage in bytes rather than element counts: the uint8 masks occupy
# fewer bytes per element than the floating-point weights, so counting elements
# would misreport the compression ratio.
old_size = sum(p.numel() * p.element_size() for p in cur_model.parameters())
new_size = sum(t.numel() * t.element_size() for t in new_state_dict.values())
print('Compressed model is {:.2f}% of the original model size'.format(new_size / old_size * 100))
torch.save(new_state_dict, 'checkpoints/sparse_checkpoints/compressed_llama_alpaca.pth')

In [None]:
# At the time of inference, use the following two functions to decompress the weights of each layer independently.
def bit_unpacking(packed_mask):
    """Expand a bit-packed ``uint8`` mask back into one 0/1 entry per bit.

    Bit ``k`` (least-significant first) of every byte becomes the ``k``-th of
    the eight consecutive output entries produced for that byte.

    Args:
        packed_mask: 2-D ``uint8`` tensor of shape ``(rows, cols)``.

    Returns:
        ``uint8`` tensor of zeros and ones with shape ``(rows, cols * 8)`` on
        the same device as ``packed_mask``.
    """
    rows, cols = packed_mask.shape
    flat_bytes = packed_mask.view(-1)
    # One (N,) tensor per bit position; stacking along dim 1 lays the eight
    # bits of each byte out next to each other.
    bit_planes = [(flat_bytes >> bit) & 1 for bit in range(8)]
    return torch.stack(bit_planes, dim=1).reshape(rows, cols * 8)

def decompress_weights(V, sparsity_mask, x):
    """Apply a compressed sparse weight to ``x`` and return the product.

    Computes ``b[n, r] = sum_c W[r, c] * x[n, c]`` where the dense weight ``W``
    is rebuilt from the kept values: ``W[r, 2j + s] = V[r, j] * mask[r, 2j + s]``.

    The dense ``(rows, cols)`` weight is materialised once and applied with a
    single matmul, so temporary memory no longer scales with the number of
    input rows. Results match the element-wise formulation up to
    floating-point rounding.

    Args:
        V: kept values, shape ``(rows, cols // 2)``.
        sparsity_mask: bit-packed mask of shape ``(rows, cols // 8)`` as
            produced by ``bit_packing``.
        x: input whose elements are interpreted as rows of length ``cols``;
            all leading dimensions are flattened.

    Returns:
        Tensor of shape ``(x.numel() // cols, rows)``.
    """
    sparsity_mask = bit_unpacking(sparsity_mask)
    rows, cols = sparsity_mask.shape
    # Same dtype the element-wise product of V and x would have had.
    out_dtype = torch.result_type(V, x)
    # Slot s of pair j receives V[r, j] if the mask marks it, otherwise 0.
    pair_mask = sparsity_mask.view(rows, cols // 2, 2)
    dense_weight = (V.unsqueeze(-1) * pair_mask).reshape(rows, cols)
    # No torch.cuda.empty_cache() here: the only temporary is one dense weight
    # matrix, and flushing the allocator cache on every call is costly.
    return x.reshape(-1, cols).to(out_dtype) @ dense_weight.to(out_dtype).t()

In [None]:
# An example of an internal layer transformation at inference time: The first class is the implementation for the standard (non-compressed) weights. The second class is the implementation for the same layer when processing the compressed weights.
import torch.nn as nn
import torch.nn.functional as F

class ProjLayerSiluMatMul(nn.Module):
    """Gated projection block with dense weights.

    Computes ``w2((w1(x) * sigmoid(w1(x))) * w3(x))`` using three bias-free
    linear layers.

    Args:
        in_feature_size: size of the last dimension of the input and output.
        hidden_feature_size: size of the intermediate projection.
        device: device passed through to the ``nn.Linear`` constructors.
    """

    def __init__(
        self,
        in_feature_size: int,
        hidden_feature_size: int,
        device: torch.device = None,
    ):
        super().__init__()
        self.hidden_feature_size = hidden_feature_size
        self.in_feature_size = in_feature_size

        linear_kwargs = {"bias": False, "device": device}
        # Creation order (w1, w2, w3) determines parameter order and the
        # sequence in which the random initialisers are drawn.
        self.w1 = nn.Linear(in_feature_size, hidden_feature_size, **linear_kwargs)
        self.w2 = nn.Linear(hidden_feature_size, in_feature_size, **linear_kwargs)
        self.w3 = nn.Linear(in_feature_size, hidden_feature_size, **linear_kwargs)

    def forward(self, x):
        """Return the block output; the shape of ``x`` is preserved."""
        gate = self.w1(x)
        activated_gate = gate * torch.sigmoid(gate)
        return self.w2(activated_gate * self.w3(x))
    
class ProjLayerSiluMatMul_Compressed(nn.Module):
    """Gated projection block that operates on 1:2-compressed weights.

    Each ``nn.Linear`` holds only the kept half of its weight (its input
    dimension is halved), and a matching ``*_mask`` parameter holds the
    bit-packed keep-mask (one byte per eight dense columns). The forward pass
    mirrors the dense block: ``w2((w1(x) * sigmoid(w1(x))) * w3(x))``.

    Args:
        in_feature_size: size of the last dimension of the input and output.
        hidden_feature_size: size of the intermediate projection.
        device: device on which the weights and masks are created.
    """

    def __init__(
        self,
        in_feature_size: int,
        hidden_feature_size: int,
        device: torch.device = None,
    ):
        super().__init__()
        self.hidden_feature_size = hidden_feature_size
        self.in_feature_size = in_feature_size

        # Kept values only: weight shapes are (out_features, in_features // 2).
        self.w1 = nn.Linear(
            in_feature_size//2, hidden_feature_size, bias=False, device=device
        )
        self.w2 = nn.Linear(
            hidden_feature_size//2, in_feature_size, bias=False, device=device
        )
        self.w3 = nn.Linear(
            in_feature_size//2, hidden_feature_size, bias=False, device=device
        )
        # Bit-packed masks, shape (out_features, in_features // 8); left uninitialised
        # here and expected to be filled when the compressed state dict is loaded.
        self.w1_mask = nn.Parameter(torch.empty((hidden_feature_size, in_feature_size//8), dtype=torch.uint8, device=device), requires_grad=False)
        self.w2_mask = nn.Parameter(torch.empty((in_feature_size, hidden_feature_size//8), dtype=torch.uint8, device=device), requires_grad=False)
        self.w3_mask = nn.Parameter(torch.empty((hidden_feature_size, in_feature_size//8), dtype=torch.uint8, device=device), requires_grad=False)

    def forward(self, x):
        """Return the block output with the same leading dimensions as ``x``.

        ``decompress_weights`` flattens every leading dimension of its input
        into one, so the final result is reshaped back to
        ``(*x.shape[:-1], in_feature_size)`` to match the dense block.
        """
        leading_shape = x.shape[:-1]
        w1x = decompress_weights(self.w1.weight, self.w1_mask, x)
        w3x = decompress_weights(self.w3.weight, self.w3_mask, x)
        w2x = decompress_weights(self.w2.weight, self.w2_mask, w1x * F.sigmoid(w1x) * w3x)
        return w2x.reshape(*leading_shape, self.in_feature_size)
