In [97]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch
from torch import Tensor
import torch.nn as nn
from torch.nn import Parameter 
import torch.nn.functional as F

# import tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

# test generate
inputs = tokenizer("Hello, world is ", return_tensors="pt")
outputs = model.generate(**inputs, max_length=68)
print(tokenizer.decode(outputs[0]))


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Hello, world is  going to be a lot more interesting than it was before.
I'm not sure if I'm going to be able to do this, but I'm going to be able to do it.
I'm going to be able to do it.
I'm going to be able to do it.



In [98]:
inputs

{'input_ids': tensor([[15496,    11,   995,   318,   220]]), 'attention_mask': tensor([[1, 1, 1, 1, 1]])}

In [99]:
for name, mod in model.named_modules():
    print(name, mod.__class__.__name__)

 GPT2LMHeadModel
transformer GPT2Model
transformer.wte Embedding
transformer.wpe Embedding
transformer.drop Dropout
transformer.h ModuleList
transformer.h.0 GPT2Block
transformer.h.0.ln_1 LayerNorm
transformer.h.0.attn GPT2Attention
transformer.h.0.attn.c_attn Conv1D
transformer.h.0.attn.c_proj Conv1D
transformer.h.0.attn.attn_dropout Dropout
transformer.h.0.attn.resid_dropout Dropout
transformer.h.0.ln_2 LayerNorm
transformer.h.0.mlp GPT2MLP
transformer.h.0.mlp.c_fc Conv1D
transformer.h.0.mlp.c_proj Conv1D
transformer.h.0.mlp.act NewGELUActivation
transformer.h.0.mlp.dropout Dropout
transformer.h.1 GPT2Block
transformer.h.1.ln_1 LayerNorm
transformer.h.1.attn GPT2Attention
transformer.h.1.attn.c_attn Conv1D
transformer.h.1.attn.c_proj Conv1D
transformer.h.1.attn.attn_dropout Dropout
transformer.h.1.attn.resid_dropout Dropout
transformer.h.1.ln_2 LayerNorm
transformer.h.1.mlp GPT2MLP
transformer.h.1.mlp.c_fc Conv1D
transformer.h.1.mlp.c_proj Conv1D
transformer.h.1.mlp.act NewGELUActiva

In [100]:
import yaml

with open("../configs/config.yaml", "r") as f:
    cfg = yaml.safe_load(f)

print(cfg["default_w_bits"])       
print(cfg["per_layer_bits"].keys())


8
dict_keys(['transformer.h.0.attn.c_attn', 'transformer.h.0.attn.c_proj', 'transformer.h.0.mlp.c_fc', 'transformer.h.0.mlp.c_proj', 'transformer.h.1.attn.c_attn', 'transformer.h.1.attn.c_proj', 'transformer.h.1.mlp.c_fc', 'transformer.h.1.mlp.c_proj', 'transformer.h.2.attn.c_attn', 'transformer.h.2.attn.c_proj', 'transformer.h.2.mlp.c_fc', 'transformer.h.2.mlp.c_proj', 'transformer.h.3.attn.c_attn', 'transformer.h.3.attn.c_proj', 'transformer.h.3.mlp.c_fc', 'transformer.h.3.mlp.c_proj', 'transformer.h.4.attn.c_attn', 'transformer.h.4.attn.c_proj', 'transformer.h.4.mlp.c_fc', 'transformer.h.4.mlp.c_proj', 'transformer.h.5.attn.c_attn', 'transformer.h.5.attn.c_proj', 'transformer.h.5.mlp.c_fc', 'transformer.h.5.mlp.c_proj', 'transformer.h.6.attn.c_attn', 'transformer.h.6.attn.c_proj', 'transformer.h.6.mlp.c_fc', 'transformer.h.6.mlp.c_proj', 'transformer.h.7.attn.c_attn', 'transformer.h.7.attn.c_proj', 'transformer.h.7.mlp.c_fc', 'transformer.h.7.mlp.c_proj', 'transformer.h.8.attn.c_att

In [101]:
import re
import torch.nn as nn

def want_quant(name, mod, cfg):
    # skip embedding / norm / lm_head
    # only Linear or Conv1D
    if name == "lm_head": 
        return False
    if isinstance(mod, nn.Linear) or mod.__class__.__name__ == "Conv1D":
        return True
    return False

In [102]:
list(model.named_modules())[10][1].weight.shape
a,b = list(model.named_modules())[9][1].weight.shape
a,b

(768, 2304)

In [103]:
class QuantLinear(nn.Module):
    r"""Quantized version of nn.Linear.

    This layer works just like nn.Linear, but the weights are stored
    in int8 format to save memory and improve efficiency.

    Instead of a single global scale, each output channel has its own
    scale factor. This makes the quantization more accurate because
    different output channels can have very different weight ranges.

    Input shape:  (*, in_features)
    Output shape: (*, out_features)
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.register_buffer("qweight",
            torch.empty(out_features, in_features, dtype=torch.int8, device=device))
        self.register_buffer("w_scale",
            torch.ones(out_features, dtype=torch.float32, device=device))
        self.register_buffer("w_zp",
            torch.zeros(out_features, dtype=torch.int32, device=device))
        if bias:
            self.bias = Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            self.register_parameter("bias", None)


    def forward(self, input: Tensor) -> Tensor:
        if torch.any(self.w_zp != 0):
            # non-zero： (q - zp) * scale
            W = (self.qweight.int() - self.w_zp.view(-1, 1)).float() * self.w_scale.view(-1, 1)
        else:
            # zero w zp： q * scale
            W = self.qweight.float() * self.w_scale.view(-1, 1)
        return F.linear(input, W, self.bias)

    def extra_repr(self) -> str:
        """
        Return the extra representation of the quant module.
        """
        return (f"in_features={self.in_features}, out_features={self.out_features}, "
                f"bias={self.bias is not None}, dtype=int8, per_channel=True")
                
    @staticmethod
    def get_bits_for_layer(name: str, cfg: dict) -> int:
        return cfg["per_layer_bits"][name]

    def quantize_from_float(self, weight: torch.Tensor, bits: int = 8):
        # support 2-8 bits
        qmin, qmax = -(2**(bits-1)), 2**(bits-1) - 1 
        w_max_abs = weight.abs().max(dim=1, keepdim=True)[0]
        w_max_abs = torch.clamp(w_max_abs, min=1e-8)
        scale = w_max_abs / qmax
        qweight = torch.clamp(torch.round(weight / scale), qmin, qmax).to(torch.int8)
        zero_point = torch.zeros(weight.size(0), dtype=torch.int32, device=weight.device)
        self.qweight.copy_(qweight)
        self.w_scale.copy_(scale.squeeze())
        self.w_zp.copy_(zero_point)

    @classmethod
    def from_linear(cls, base: nn.Linear, name: str, cfg: dict):
        bits = cls.get_bits_for_layer(name, cfg)
        q = cls(base.in_features, base.out_features,
                bias=(base.bias is not None),
                device=base.weight.device, dtype=base.weight.dtype)
        with torch.no_grad():
            q.quantize_from_float(base.weight, bits=bits)
            if base.bias is not None:
                q.bias.copy_(base.bias)
        return q


In [104]:
print(QuantLinear.get_bits_for_layer("transformer.h.0.attn.c_attn", cfg))   


8


In [105]:
def replace_with_quant(model, cfg):
    name_to_module = dict(model.named_modules())
    for name, mod in list(name_to_module.items()):
        print(name, mod.__class__.__name__)
        if not want_quant(name, mod, cfg):
            continue

        # parent module location
        if '.' in name:
            parent_name, child_name = name.rsplit('.', 1)
            parent = name_to_module[parent_name]
        else:
            parent, child_name = model, name

        # convert Conv1D to Linear
        if mod.__class__.__name__ == "Conv1D":
                in_f, out_f = mod.weight.shape          # Conv1D weight is (out, in)
                base = nn.Linear(in_f, out_f, bias=(mod.bias is not None))
                base.to(mod.weight.device, dtype=mod.weight.dtype)
                with torch.no_grad():
                    base.weight.copy_(mod.weight.T)       # copy 
                    if mod.bias is not None:
                        base.bias.copy_(mod.bias)
        else:
            base = mod

        # construct quantization wrapper
        qcfg = cfg.copy()
        qmod = QuantLinear.from_linear(base, name,cfg=qcfg)

        # replace the original layer with setattr
        setattr(parent, child_name, qmod)

replace_with_quant(model, cfg)


 GPT2LMHeadModel
transformer GPT2Model
transformer.wte Embedding
transformer.wpe Embedding
transformer.drop Dropout
transformer.h ModuleList
transformer.h.0 GPT2Block
transformer.h.0.ln_1 LayerNorm
transformer.h.0.attn GPT2Attention
transformer.h.0.attn.c_attn Conv1D
transformer.h.0.attn.c_proj Conv1D
transformer.h.0.attn.attn_dropout Dropout
transformer.h.0.attn.resid_dropout Dropout
transformer.h.0.ln_2 LayerNorm
transformer.h.0.mlp GPT2MLP
transformer.h.0.mlp.c_fc Conv1D
transformer.h.0.mlp.c_proj Conv1D
transformer.h.0.mlp.act NewGELUActivation
transformer.h.0.mlp.dropout Dropout
transformer.h.1 GPT2Block
transformer.h.1.ln_1 LayerNorm
transformer.h.1.attn GPT2Attention
transformer.h.1.attn.c_attn Conv1D
transformer.h.1.attn.c_proj Conv1D
transformer.h.1.attn.attn_dropout Dropout
transformer.h.1.attn.resid_dropout Dropout
transformer.h.1.ln_2 LayerNorm
transformer.h.1.mlp GPT2MLP
transformer.h.1.mlp.c_fc Conv1D
transformer.h.1.mlp.c_proj Conv1D
transformer.h.1.mlp.act NewGELUActiva

In [106]:

inputs = tokenizer("Hello, world is ", return_tensors="pt")
with torch.no_grad():
    outputs = model.generate(**inputs, max_length=68)

print(tokenizer.decode(outputs[0]))



Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Hello, world is  going to be a lot more interesting than it was before.
I'm not sure if I'm going to be able to do this, but I'm going to be able to do it.
I'm going to be able to do it.
I'm going to be able to do it.

