In [1]:
import torch
import matplotlib.pyplot as plt
import os
import struct
from collections import  defaultdict

from transformers import T5ForConditionalGeneration

In [2]:
def get_exponent(model, nbits):
    exponent = {}
    for idx, (name, param) in enumerate(model.named_parameters()):
        # assert param.dtype is torch.float16
        if param.ndim == 2 and param.shape[0] != 1 and param.shape[1] != 1:
            r, c = param.shape
            tensor_data = list(
                map(lambda x: int(
                        "{}".format(
                                bin(
                                    int.from_bytes(
                                    struct.pack('>e', x), byteorder='big'
                                    )
                                )[3: 3+nbits]      # torch.float16 [3: 8] torch.bfloat16[3: 11]
                            ), 
                            base=2
                    ),
                    param.abs().neg().reshape(-1).tolist()
                )
            )
            tensor = torch.tensor(tensor_data, dtype=torch.int8).reshape(r, c)
            exponent[name] = tensor
    return exponent

In [3]:
### Get the exponent ###
models_hub = {
    "t5": {
        "path": "/home/styaeng/project/delta-compress/pretrained_model/t5",       ### 这里要写成下载后的模型权重文件所在的路径
        "hdlr": T5ForConditionalGeneration.from_pretrained
    }
}
t5_model = models_hub['t5']['hdlr'](models_hub['t5']['path'])
t5_exponent = get_exponent(t5_model, nbits=5)

In [4]:
'''
2 bits
    [0 | 1 | 2 | 3]
3 bits
    [0 | 1 | 2 | 3 | 4 | 5 | 6 | 7]
4 bits
    [0-15]
5 bits
    [0]
'''

pattern = [
    [0, 1, 2, 3],
    [0, 1, 2, 3, 4, 5, 6, 7],
    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
    [0],
]

In [5]:
### Get the percentage in tile granularity ###
tile_gran_pattern = []
PageSize = 4 * 1024 * 8 # bits
FP16 = 16   # bits
for k, v in t5_exponent.items():
    tile_pattern = defaultdict(dict)
    row, col = v.shape
    if col * FP16 <= PageSize:
        tile_pattern[k] = "LTP"     # a.k.a of "Less Than PageSize"
    else:
        for row_idx in range(row):
            tile = v[row_idx]
            for shift in range(3, -1, -1):
                tile_pattern[k][row_idx] = defaultdict(dict)
                idx = 3 - shift
                pat = pattern[idx]
                tmp = tile >> shift
                for elem in pat:
                    tile_pattern[k][row_idx][shift][elem] = torch.count_nonzero(tmp == elem).item()
    tile_gran_pattern.append(tile_pattern)

In [50]:
tile_pattern = {}
for shift in range(3, -1, -1):
    idx = 3 - shift
    pat = pattern[idx]
    tmp = tile >> shift
    tile_pattern[idx+2] = {}
    for elem in pat:
        tile_pattern[idx+2][elem] = torch.count_nonzero(tmp == elem).item()


In [132]:
row, col = v.shape
factor = 5

higher_cur = 0
lower_cur = higher_cur + factor * PageSize // (col * 16)

tile = v[higher_cur: lower_cur]

tile_pattern = {}
for shift in range(3, -1, -1):
    idx = 3 - shift
    pat = pattern[idx]
    tmp = tile >> shift
    tile_pattern[idx+2] = {}
    for elem in pat:
        tile_pattern[idx+2][elem] = torch.count_nonzero(tmp == elem).item()



In [133]:
# for bit in range(2, 6):
#     for count in tile_pattern[bit].keys():
#         print(tile_pattern[bit][count], end=" ")
#     print()
bit = 4
values = torch.tensor(list(tile_pattern[bit].values())).sort()[0]
values, tile.numel()

print(f"original space occupied: {(tile.numel() * FP16 + PageSize - 1)// PageSize}")
### if bit == 2 ###
values = torch.sort(values, descending=True)[0]
### GROUP 1 ###
wasted_grp1 = factor *PageSize - (values[0] * (FP16 - bit) + bit + (values.sum() - values[0]) * FP16)
### GROUP 2 ###
wasted_grp2 = factor *PageSize - (values[0] * (FP16 - bit) + bit + values[1] * (FP16 - bit) + bit + (values.sum() - values[0] - values[1]) * FP16)
### GROUP 2 ###
wasted_grp3 = factor *PageSize - (values[0] * (FP16 - bit) + bit + \
                                values[1] * (FP16 - bit) + bit + \
                                values[2] * (FP16 - bit) + bit + \
                                (values.sum() - values[0] - values[1] - values[2]) * FP16)
### GROUP 4 ###
wasted_grp4 = factor *PageSize - (values[0] * (FP16 - bit) + bit + \
                        values[1] * (FP16 - bit) + bit + \
                        values[2] * (FP16 - bit) + bit + \
                        values[3] * (FP16 - bit) + bit + \
                        (values.sum() - values[0] - values[1] - values[2] - values[3]) * FP16)

print(wasted_grp1, wasted_grp2, wasted_grp3, wasted_grp4)

original space occupied: 5
tensor(20756) tensor(30712) tensor(38044) tensor(39924)


In [141]:
wasted_bits = factor * PageSize - values.sum() * FP16
wasted_bits

tensor(0)

In [165]:
for bit in range(2, 5):
    values = torch.tensor(list(tile_pattern[bit].values())).sort()[0]
    values = torch.sort(values, descending=True)[0]
    wasted_bits = factor * PageSize - values.sum() * FP16
    for grp_count, v in enumerate(values):
        wasted_bits -= (v * (FP16 - bit) + bit - v * FP16)
        # print(f"bit={bit}, group_counts={grp_count+1}, wasted_bits={wasted_bits} bits")
        if wasted_bits >= PageSize:
            print(f"Compress the tile successfully and free 1 page storage space!\
            \nWith bits={bit}, group_counts={grp_count}, wasted_bits={wasted_bits} ({(wasted_bits / PageSize):.2f} page)!")
            break

Compress the tile successfully and free 1 page storage space!            
With bits=4, group_counts=2, wasted_bits=38044 (1.16 page)!


In [46]:
### Get the percentage in matrix granularity ###
mat_gran_pattern = []
PageSize = 4 * 1024 * 8 # bits
FP16 = 16   # bits

for k, v in t5_exponent.items():
    row, col = v.shape
    higher_cur = 0
    lower_cur = higher_cur + PageSize // (col * FP16)
    while lower_cur <= row:
        """
        # 优化点:
            1. 没必要一行一行地进行计算，在初始tile存储占用还没有到PageSize的时候是没有必要计算压缩后的存储收益的；
        if tile[higher_cur]_compressed_size < PageSize:
            lower_cur = lower_cur + 1
        if tile[]
        """ 
        tile = v[higher_cur: lower_cur]
        

torch.Size([1, 512])
torch.Size([2, 512])
