In [1]:
import torch
import matplotlib.pyplot as plt
import os
import struct
from collections import  defaultdict

from transformers import T5ForConditionalGeneration

In [2]:
def get_exponent(model, nbits):
    exponent = {}
    for idx, (name, param) in enumerate(model.named_parameters()):
        # assert param.dtype is torch.float16
        if param.ndim == 2 and param.shape[0] != 1 and param.shape[1] != 1:
            r, c = param.shape
            tensor_data = list(
                map(lambda x: int(
                        "{}".format(
                                bin(
                                    int.from_bytes(
                                    struct.pack('>e', x), byteorder='big'
                                    )
                                )[3: 3+nbits]      # torch.float16 [3: 8] torch.bfloat16[3: 11]
                            ), 
                            base=2
                    ),
                    param.abs().neg().reshape(-1).tolist()
                )
            )
            tensor = torch.tensor(tensor_data, dtype=torch.int8).reshape(r, c)
            exponent[name] = tensor
    return exponent

In [3]:
### Get the exponent ###
models_hub = {
    "t5": {
        "path": "/home/styaeng/project/delta-compress/pretrained_model/t5",       ### 这里要写成下载后的模型权重文件所在的路径
        "hdlr": T5ForConditionalGeneration.from_pretrained
    }
}
t5_model = models_hub['t5']['hdlr'](models_hub['t5']['path'])
t5_exponent = get_exponent(t5_model, nbits=5)

In [4]:
'''
2 bits
    [0 | 1 | 2 | 3]
3 bits
    [0 | 1 | 2 | 3 | 4 | 5 | 6 | 7]
4 bits
    [0-15]
5 bits
    [0]
'''

pattern = [
    [0, 1, 2, 3],
    [0, 1, 2, 3, 4, 5, 6, 7],
    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
    [0],
]

In [5]:
### Get the percentage in matrix granularity ###
tensor_pattern = []
for k, v in t5_exponent.items():
    # tensor_self_pattern["name"] = k
    tensor_self_pattern = defaultdict(dict)
    tensor_self_pattern['name'] = k
    row, col = v.shape
    tensor_self_pattern['size'] = v.numel()
    t = v
    for shift in range(3, -1, -1):
        idx = 3 - shift
        pat = pattern[idx]
        tmp = t >> shift
        tensor_self_pattern[f'{idx+2}_bits'] = {}
        for elem in pat:
            tensor_self_pattern[f"{idx+2}_bits"][elem] = torch.count_nonzero(tmp == elem)
        # tensor_self_pattern['compress_num'][value] = torch.count_nonzero(t == value) / (row * col)
    tensor_pattern.append(tensor_self_pattern)

In [None]:
bitwid = 5
compress_bit = 2
tensor_id = 1
PageSize = 4 * 1024 * 8

compression_ratio_average = 0
for tensor_id in range(0, len(tensor_pattern)):
    final_compression_ratio = -1
    final_bitwise = -1
    for compress_bit in range(2, 6):
        compress_cnt = torch.tensor(list(tensor_pattern[tensor_id][f'{compress_bit}_bits'].values())).max().item()
        before_compress_bit = tensor_pattern[tensor_id]['size'] * 5

        compressed_part = compress_cnt * (bitwid - compress_bit) + compress_bit
        uncompressed_part = (tensor_pattern[tensor_id]['size'] - compress_cnt) * 5

        compressed_page_count = (compressed_part + PageSize - 1) // PageSize + (uncompressed_part + PageSize - 1) // PageSize
        uncompressed_page_count = (before_compress_bit + PageSize - 1) // PageSize
        compression_ratio = (uncompressed_page_count - compressed_page_count) / uncompressed_page_count

        if final_compression_ratio < compression_ratio:
            final_compression_ratio = compression_ratio
            final_bitwise = compress_bit
    compression_ratio_average += final_compression_ratio / len(tensor_pattern)
    print(f"{tensor_pattern[tensor_id]['name']}", \
            "\t", uncompressed_page_count, \
            "\t", compressed_page_count, \
            "\t", final_bitwise, \
            "\t", f"{final_compression_ratio * 100:.2f}%")
print(f"compression_ratio = {compression_ratio_average * 100 :.2f}%")

In [7]:
# ### Get the percentage in tile granularity ###
tensor_pattern_tile_gran = []
for k, v in t5_exponent.items():
    tensor_self_pattern = defaultdict(dict)
    tensor_self_pattern["name"] = k
    row, col = v.shape
    t = v
    for row_id in range(t.shape[0]):
        tensor_self_pattern[f"row_{row_id}"]['size'] = t[row_id].numel()
        for shift in range(3, -1, -1):
            idx = 3 - shift
            pat = pattern[idx]
            tmp = t[row_id] >> shift
            tensor_self_pattern[f"row_{row_id}"][f'{idx+2}_bits'] = {}
            for elem in pat:
                tensor_self_pattern[f"row_{row_id}"][f"{idx+2}_bits"][elem] = torch.count_nonzero(tmp == elem)
    tensor_pattern_tile_gran.append(tensor_self_pattern)

In [31]:
for i, (k, v) in enumerate(t5_exponent.items()):
    assert tensor_pattern_tile_gran[i]['name'] == k
    for row_id in range(v.shape[0]):
        assert tensor_pattern_tile_gran[i][f'row_{row_id}']['size'] == v.shape[1]
        for bit_count in range(2, 5):
            assert torch.tensor(
                list(
                    tensor_pattern_tile_gran[i][f'row_{row_id}'][f'{bit_count}_bits'].values()
                    )
                ).sum() == v.shape[1]

bitwid = 5
compress_bit = 2
tensor_id = 1
PageSize = 4 * 1024 * 8

compression_ratio_average = 0
for tensor_id in range(0, len(tensor_pattern_tile_gran)):
    final_compression_ratio = -1
    final_pages = 0
    t = t5_exponent[tensor_pattern_tile_gran[tensor_id]['name']]
    row, col = t.shape
    origin_pages = 0
    compressed_tensor_bits = 0
    uncompressed_tensor_bits = 0
    bits_of_tensor = 0
    for row_id in range(row):
        compressed_part = 0
        uncompressed_part = 0
        max_compression_ratio = -1
        max_bits = -1
        final_bits = []
        before_compress_bits = tensor_pattern_tile_gran[tensor_id][f"row_{row_id}"]['size'] * 5
        for bit_idx in range(2, 6):
            compress_cnt = torch.tensor(
                list(
                    tensor_pattern_tile_gran[tensor_id][f"row_{row_id}"][f"{bit_idx}_bits"].values()
                    )
                ).max().item()
            compressed_bit = compress_cnt * (bitwid - bit_idx) + bit_idx
            uncompressed_bit = (tensor_pattern_tile_gran[tensor_id][f"row_{row_id}"]['size'] - compress_cnt) * 5
            compression_ratio = (before_compress_bits - compress_bit + uncompressed_bit) / before_compress_bits
            if max_compression_ratio < compression_ratio:
                max_bits = bit_idx
                final_bits = [compress_bit , uncompressed_bit]
        compressed_tensor_bits += final_bits[0]
        uncompressed_tensor_bits += final_bits[1]
        bits_of_tensor += tensor_pattern_tile_gran[tensor_id][f"row_{row_id}"]['size'] * 5
    final_pages = (compressed_tensor_bits + PageSize - 1) // PageSize + (compressed_tensor_bits + PageSize - 1) // PageSize
    origin_pages = (bits_of_tensor + PageSize - 1) // PageSize
    final_compression_ratio = (origin_pages - final_pages) / origin_pages
    compression_ratio_average += (final_compression_ratio / len(tensor_pattern_tile_gran))
    print(f"{tensor_pattern_tile_gran[tensor_id]['name']}", \
            "\t", origin_pages, \
            "\t", final_pages, \
            "\t", max_bits, \
            "\t", f"{final_compression_ratio * 100:.2f}%")
print(f"compression_ratio = {compression_ratio_average * 100:.2f}%")

shared.weight 	 2510 	 4 	 5 	 99.84%
encoder.block.0.layer.0.SelfAttention.q.weight 	 40 	 2 	 5 	 95.00%
encoder.block.0.layer.0.SelfAttention.k.weight 	 40 	 2 	 5 	 95.00%
encoder.block.0.layer.0.SelfAttention.v.weight 	 40 	 2 	 5 	 95.00%
encoder.block.0.layer.0.SelfAttention.o.weight 	 40 	 2 	 5 	 95.00%
encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight 	 1 	 2 	 5 	 -100.00%
encoder.block.0.layer.1.DenseReluDense.wi.weight 	 160 	 2 	 5 	 98.75%
encoder.block.0.layer.1.DenseReluDense.wo.weight 	 160 	 2 	 5 	 98.75%
encoder.block.1.layer.0.SelfAttention.q.weight 	 40 	 2 	 5 	 95.00%
encoder.block.1.layer.0.SelfAttention.k.weight 	 40 	 2 	 5 	 95.00%
encoder.block.1.layer.0.SelfAttention.v.weight 	 40 	 2 	 5 	 95.00%
encoder.block.1.layer.0.SelfAttention.o.weight 	 40 	 2 	 5 	 95.00%
encoder.block.1.layer.1.DenseReluDense.wi.weight 	 160 	 2 	 5 	 98.75%
encoder.block.1.layer.1.DenseReluDense.wo.weight 	 160 	 2 	 5 	 98.75%
encoder.block.2.layer.0.SelfAt

In [47]:
print((t[0:512].numel() * 5 + PageSize - 1) // PageSize)

160


In [72]:
origin_bits = 0
after_compressed_bits = 0
for row in t[0:512]:
    max_count = torch.bincount(row).max().item()
    max_idx = torch.log2(torch.bincount(row).argmax())
    max_bit = torch.ceil(max_idx).to(torch.uint8).item()
    compressed_bits = max_count * (5 - 4) + 4
    uncompressed_bits = (row.shape[0] - max_count) * 5
    origin_bits += row.shape[0] * 5
    after_compressed_bits += (compressed_bits + uncompressed_bits)
    # print(compressed_bits, uncompressed_bits, original_bits)
    # original_pages = (origin_bits + PageSize - 1) // PageSize
    # after_compressed_pages = (compressed_bits + PageSize - 1) // PageSize \
    #     + (uncompressed_bits + PageSize - 1) // PageSize
    # print(after_compressed_pages, original_pages)
origin_pages = (original_bits + PageSize - 1) // PageSize
after_compressed_pages = (after_compressed_bits + PageSize - 1) // PageSize
print(origin_pages, after_compressed_pages)

161 123


In [73]:
max_bit

4