In [1]:
import torch
import matplotlib.pyplot as plt
import os
import struct
from collections import  defaultdict

from transformers import T5ForConditionalGeneration

In [2]:
def get_exponent(model, nbits):
    """Extract the leading bits after the sign bit of every 2-D weight in ``model``.

    Each weight is rounded to IEEE-754 half precision (fp16: 1 sign bit,
    5 exponent bits, 10 mantissa bits). The ``nbits`` bits immediately after
    the sign bit are kept. With ``nbits=5`` this is exactly the fp16 exponent
    field; larger ``nbits`` also pull in leading mantissa bits. Note that this
    is always an fp16 view — it is NOT the bfloat16 exponent for ``nbits=8``.

    Only parameters with ``ndim == 2`` and both dimensions != 1 are processed,
    so vectors and single-row/single-column matrices are skipped.

    Values too large for fp16 round to +inf and report an all-ones exponent
    (31 for ``nbits=5``) rather than raising.

    Args:
        model: object exposing ``named_parameters()`` (e.g. a ``torch.nn.Module``).
        nbits: number of bits to keep after the sign bit, in ``[1, 15]``.

    Returns:
        dict mapping parameter name -> CPU integer tensor with the parameter's
        shape. The dtype is ``torch.int8`` when ``nbits <= 7`` (all values fit)
        and ``torch.int16`` otherwise.

    Raises:
        ValueError: if ``nbits`` is outside ``[1, 15]``.
    """
    if not 1 <= nbits <= 15:
        raise ValueError(f"nbits must be in [1, 15], got {nbits}")
    out_dtype = torch.int8 if nbits <= 7 else torch.int16

    exponent = {}
    for name, param in model.named_parameters():
        if param.ndim == 2 and param.shape[0] != 1 and param.shape[1] != 1:
            # Round to fp16 and reinterpret the 16-bit pattern as an integer.
            # This is a vectorised equivalent of per-element struct.pack('>e', x).
            half = param.detach().cpu().to(torch.float16)
            bits = half.view(torch.int16).to(torch.int32)
            # Clear the sign bit (bit 15), then keep the top `nbits` of the
            # remaining 15 bits.
            prefix = (bits & 0x7FFF) >> (15 - nbits)
            exponent[name] = prefix.to(out_dtype)
    return exponent

In [3]:
### Get the exponent ###
# Path to the downloaded T5 model weights. Set the T5_MODEL_PATH environment
# variable to point elsewhere instead of editing this cell.
T5_MODEL_PATH = os.environ.get(
    "T5_MODEL_PATH",
    "/home/styaeng/project/delta-compress/pretrained_model/t5",
)
models_hub = {
    "t5": {
        "path": T5_MODEL_PATH,
        "hdlr": T5ForConditionalGeneration.from_pretrained,
    }
}
t5_model = models_hub['t5']['hdlr'](models_hub['t5']['path'])
# 5 bits = the full fp16 exponent field of every 2-D weight matrix.
t5_exponent = get_exponent(t5_model, nbits=5)

In [4]:
'''
2 bits
    [0 | 1 | 2 | 3]
3 bits
    [0 | 1 | 2 | 3 | 4 | 5 | 6 | 7]
4 bits
    [0-15]
5 bits
    [0]
'''

# Candidate values to count for each exponent-prefix width.
# Entry i lists every value a (2 + i)-bit prefix can take, for widths 2, 3
# and 4. For the full 5-bit exponent only the value 0 is tracked.
pattern = [list(range(1 << width)) for width in (2, 3, 4)] + [[0]]

In [11]:
### Count exponent prefixes at tile (row) granularity ###
# For every exponent matrix whose rows are wider than one page this records
#   tile_gran_pattern[i][name][row_idx][shift][elem]
#     = number of entries in row `row_idx` with (exponent >> shift) == elem,
# for shift = 3, 2, 1, 0 and elem drawn from pattern[3 - shift].
# These are raw counts; divide by the row width `col` to get fractions.
# Matrices whose rows fit in a single page are tagged "LTP" instead.
tile_gran_pattern = []
PageSize = 4 * 1024 * 8  # bits (4 KiB page)
FP16 = 16                # bits per fp16 element

for k, v in t5_exponent.items():
    tile_pattern = defaultdict(dict)
    row, col = v.shape
    if col * FP16 <= PageSize:
        tile_pattern[k] = "LTP"     # a.k.a. "Less Than PageSize"
    else:
        # Build the nested dicts explicitly. Keys are per row, then per shift
        # in order 3, 2, 1, 0.
        row_stats = {row_idx: {} for row_idx in range(row)}
        for shift in range(3, -1, -1):
            shifted = v >> shift
            pat = pattern[3 - shift]
            # counts[i][r] = occurrences of pat[i] in row r, for all rows at once.
            counts = [(shifted == elem).sum(dim=1).tolist() for elem in pat]
            for row_idx in range(row):
                row_stats[row_idx][shift] = {
                    elem: counts[i][row_idx] for i, elem in enumerate(pat)
                }
        tile_pattern[k] = row_stats
    tile_gran_pattern.append(tile_pattern)

In [28]:
# Display the statistics collected for the first 2-D weight matrix in t5_exponent.
tile_gran_pattern[0]

defaultdict(dict, {'shared.weight': 'LTP'})

In [46]:
### Get the percentage in matrix granularity ###
# Work in progress: the goal is to group consecutive rows of each exponent
# matrix into tiles whose compressed size fits in one page.
mat_gran_pattern = []
PageSize = 4 * 1024 * 8  # bits (4 KiB page)
FP16 = 16                # bits per fp16 element

for k, v in t5_exponent.items():
    row, col = v.shape
    higher_cur = 0              # first row of the current candidate tile
    lower_cur = higher_cur + 1  # one past the last row of the current candidate tile
    while lower_cur <= row:
        # TODO: implement the grouping, e.g.
        #   if compressed_size(v[higher_cur:lower_cur]) < PageSize:
        #       extend the tile (lower_cur += 1)
        #   else:
        #       close the tile v[higher_cur:lower_cur - 1] and start a new one
        # Until then, always advance the cursor so the loop terminates.
        lower_cur += 1

torch.Size([1, 512])
torch.Size([2, 512])


In [48]:
# NOTE(review): `v` is the loop variable left over from a previous
# `for k, v in t5_exponent.items()` cell (presumably the last matrix visited).
v.size()

torch.Size([512, 2048])

In [53]:
# Shape of rows 1..510 of `v`. `v` is the loop variable left over from an
# earlier `for k, v in t5_exponent.items()` cell.
v[1:511].size()

torch.Size([510, 2048])