In [1]:
from utils import quantise_tensor, dequantise_tensor, compute_quantisation_mse,get_tensor_memory_size
import torch

# Seed so the random test tensor (and therefore the printed errors) is reproducible
torch.manual_seed(0)

t = torch.rand((10, 100))
chunk_size = 256
t_q, scales, locations = quantise_tensor(t, chunk_size)
t_approx = dequantise_tensor(t_q, scales, locations, chunk_size)

mse = compute_quantisation_mse(t, t_approx)

# Storage for the original tensor vs. everything needed to reconstruct it
quantised_size = sum(get_tensor_memory_size(x) for x in (t_q, scales, locations))
print(f'Required memory: {get_tensor_memory_size(t)} vs {quantised_size}')

MSE: 0.07002327591180801
Norm: 18.12198829650879
RelativeMSE: 0.3864%
MSE: 0.07093057036399841
Norm: 17.9595890045166
RelativeMSE: 0.39495%
Required memory: 4000 vs 1032


In [2]:
# Load the pretrained BERT checkpoint: its tokenizer and the masked-LM model
from transformers import AutoTokenizer, AutoModelForMaskedLM

MODEL_NAME = "google-bert/bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME)

Some weights of the model checkpoint at google-bert/bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [3]:
# Keep only the first BERT encoder layer; the following cells quantise and run this layer.
# NOTE: this rebinds `model`, so re-running this cell requires re-running the loading cell first.
model = model.get_submodule('bert.encoder.layer.0')

In [4]:
from tqdm import tqdm
import torch

def quantise_model(model, chunk_size):
    """Quantise every parameter of ``model`` in place and record how to undo it.

    Each parameter is passed through ``quantise_tensor``. The quantised values are written
    back with ``load_state_dict``, which copies them into the existing parameter tensors.
    The model's own memory footprint is therefore unchanged. The per-parameter
    ``scales``/``locations`` needed by ``dequantise_tensor`` are collected in a mapping.

    Args:
        model: ``torch.nn.Module`` to quantise; it is mutated in place.
        chunk_size: chunk size forwarded to ``quantise_tensor`` and stored in the mapping.

    Returns:
        ``(model, parameter_mapping)``. ``parameter_mapping[name]`` is a dict with keys
        ``'scales'``, ``'locations'`` and ``'chunk_size'`` for every ``name`` yielded by
        ``model.named_parameters()``.
    """
    parameter_mapping = {}
    state = model.state_dict()
    named_params = list(model.named_parameters())  # a list lets tqdm show a real total

    # Quantisation is not part of any training graph: detach + no_grad keep the
    # scales/locations stored below from carrying autograd history.
    with torch.no_grad():
        for parameter_name, p in tqdm(named_params):
            state[parameter_name], scales, locations = quantise_tensor(p.detach(), chunk_size)
            parameter_mapping[parameter_name] = {'scales': scales, 'locations': locations, 'chunk_size': chunk_size}

    model.load_state_dict(state)

    return model, parameter_mapping

# Chunk size used when quantising every parameter of the layer
QUANT_CHUNK_SIZE = 4
model, parameter_mapping = quantise_model(model, chunk_size=QUANT_CHUNK_SIZE)

# Leaf modules = modules with no child modules. O(1) check per module instead of walking
# each module's whole subtree.
# NOTE(review): parameters registered directly on a non-leaf module are not covered by this list.
leaf_modules = [
    module_name
    for module_name, m in model.named_modules()
    if next(m.children(), None) is None
]

0it [00:00, ?it/s]

16it [00:58,  3.65s/it]


In [5]:
def hook_factory(leaf_module_name, mapping=None):
    """Build forward hooks that dequantise one leaf module's parameters during its forward.

    ``dequantise_hook`` is a forward pre-hook. It saves a copy of the module's quantised
    parameters, then replaces them with ``dequantise_tensor(...)``, using the scales,
    locations and chunk size recorded for each parameter.

    ``cleanup_hook`` is a forward hook. It writes the saved quantised values back and drops
    the copy, so the extra copy only exists while the module's forward runs.

    Args:
        leaf_module_name: dotted module name as yielded by ``model.named_modules()``;
            ``''`` denotes the root module itself.
        mapping: per-parameter info keyed by global parameter name, each value a dict with
            ``'scales'``, ``'locations'`` and ``'chunk_size'``. Defaults to the notebook-level
            ``parameter_mapping``, looked up each time a hook runs.

    Returns:
        ``(dequantise_hook, cleanup_hook)`` for ``register_forward_pre_hook`` and
        ``register_forward_hook`` respectively.

    Note:
        Parameters are overwritten in place after the forward pass. Calling ``backward``
        through the output will therefore likely fail with an in-place-modification error.
        Intended for inference (e.g. under ``torch.no_grad()``).
    """
    def _global_name(parameter_name):
        # named_modules() reports the root module as '' -- avoid building '.weight'
        return f'{leaf_module_name}.{parameter_name}' if leaf_module_name else parameter_name

    def _restore(module):
        # Put the saved quantised values back into the parameters and release the copy
        saved = getattr(module, 'quantised_state', None)
        if not saved:
            return
        d = module.state_dict()
        d.update(saved)
        module.load_state_dict(d)
        module.quantised_state = {}

    def dequantise_hook(module, args):
        info_by_name = parameter_mapping if mapping is None else mapping
        # If a previous forward raised, cleanup_hook never ran and the parameters still hold
        # dequantised values; restore the quantised ones first so nothing is dequantised twice.
        _restore(module)

        d = module.state_dict()
        module.quantised_state = {}
        with torch.no_grad():
            for parameter_name, p in module.named_parameters():
                info = info_by_name[_global_name(parameter_name)]
                module.quantised_state[parameter_name] = p.detach().clone()
                d[parameter_name] = dequantise_tensor(p.detach(), info['scales'], info['locations'], info['chunk_size'])

        module.load_state_dict(d)

    def cleanup_hook(module, args, output):
        _restore(module)

    return dequantise_hook, cleanup_hook

# Keep the hook handles so the hooks can be removed later. Re-running this cell first removes
# the hooks from a previous run. Otherwise every leaf module would get a second pre-hook that
# dequantises already-dequantised weights.
for handle in globals().get('hook_handles', []):
    handle.remove()

hook_handles = []
for leaf_module in leaf_modules:
    submodule = model.get_submodule(leaf_module)
    dequantise_hook, cleanup_hook = hook_factory(leaf_module)
    hook_handles.append(submodule.register_forward_pre_hook(dequantise_hook))
    hook_handles.append(submodule.register_forward_hook(cleanup_hook))




In [6]:
# Call the module itself, not `.forward`: calling `.forward` directly skips that module's hooks.
# no_grad because the hooks overwrite parameters in place after the forward (inference only).
# Input shape presumably (batch, seq_len, hidden_size=768 for bert-base) -- TODO confirm.
with torch.no_grad():
    layer_output = model(torch.ones(1, 10, 768))
layer_output

(tensor([[[-0.0471, -0.0135, -0.2095,  ...,  0.1135, -0.6417, -0.0570],
          [-0.0471, -0.0135, -0.2095,  ...,  0.1135, -0.6417, -0.0570],
          [-0.0471, -0.0135, -0.2095,  ...,  0.1135, -0.6417, -0.0570],
          ...,
          [-0.0471, -0.0135, -0.2095,  ...,  0.1135, -0.6417, -0.0570],
          [-0.0471, -0.0135, -0.2095,  ...,  0.1135, -0.6417, -0.0570],
          [-0.0471, -0.0135, -0.2095,  ...,  0.1135, -0.6417, -0.0570]]],
        grad_fn=<NativeLayerNormBackward0>),)