In [1]:
import os

import torch
torch.set_num_threads(8)
from torch import nn
import torch.nn.functional as F
import transformers

from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM
from transformers.models.llama.configuration_llama import LlamaConfig

from quant import quantize, Quantizer
from gptq import GPTQ
from modelutils import find_layers
from datautils import get_loaders

from tqdm import tqdm, trange


def skip(*args, **kwargs):
    """No-op replacement for a weight initialiser: accepts anything, returns None."""
    return None


# Swap torch's random initialisers for the no-op above so that constructing
# modules does not run random weight initialisation. The loader functions in
# this notebook fill the parameters from checkpoint files afterwards.
torch.nn.init.kaiming_uniform_ = torch.nn.init.uniform_ = torch.nn.init.normal_ = skip


def initialize_layerless_llama(checkpoint_path):
    """Build a LLaMA causal-LM whose decoder stack is empty.

    The config is fetched for "decapoda-research/llama-7b-hf" and its
    ``num_hidden_layers`` is forced to 0, so the constructed model has no
    decoder layers. Its weights are then loaded from the last shard of the
    checkpoint (presumably the embeddings, final norm and LM head -- the
    shard contents are not visible here; ``load_state_dict`` is strict, so the
    shard must match the layerless model exactly).

    Args:
        checkpoint_path: Directory containing the sharded ``pytorch_model-*``
            checkpoint files.

    Returns:
        The layerless ``LlamaForCausalLM`` with an extra ``seqlen`` attribute
        set to 2048.
    """
    layerless_config = LlamaConfig.from_pretrained("decapoda-research/llama-7b-hf")
    layerless_config.num_hidden_layers = 0

    model = LlamaForCausalLM(layerless_config)

    last_shard = os.path.join(checkpoint_path, "pytorch_model-00033-of-00033.bin")
    model.load_state_dict(torch.load(last_shard))

    # Sequence length read by the calibration / evaluation code in this notebook.
    model.seqlen = 2048

    return model


def load_and_dispatch_a_layer(layer_idx, checkpoint_path, model: LlamaForCausalLM):
    """Load one decoder layer from its checkpoint shard and append it to ``model``.

    The shard for layer ``layer_idx`` is expected at
    ``pytorch_model-{layer_idx+1:05}-of-00033.bin`` inside ``checkpoint_path``.
    Only tensors whose name starts with ``model.layers.{layer_idx}.`` are used;
    that prefix is stripped before loading them into a fresh
    ``LlamaDecoderLayer``.

    Args:
        layer_idx: Zero-based index of the decoder layer to load.
        checkpoint_path: Directory containing the sharded checkpoint files.
        model: Model whose ``model.layers`` list receives the new layer.

    Raises:
        KeyError: If the shard contains no tensor for the requested layer.
    """
    config = transformers.AutoConfig.from_pretrained("decapoda-research/llama-7b-hf")
    layer = LlamaDecoderLayer(config)

    prefix = f"model.layers.{layer_idx}."
    shard_path = os.path.join(checkpoint_path, f"pytorch_model-{layer_idx+1:05}-of-00033.bin")
    shard = torch.load(shard_path)

    # Keep only this layer's tensors. Slicing every key by the prefix length
    # (without checking the prefix) would turn unrelated keys into bogus names,
    # or worse, into valid names belonging to a different layer.
    layer_state_dict = {
        name[len(prefix):]: tensor
        for name, tensor in shard.items()
        if name.startswith(prefix)
    }
    del shard

    if not layer_state_dict:
        # strict=False below would otherwise silently leave the layer with
        # uninitialised weights (random init is disabled in this notebook).
        raise KeyError(f"no tensors with prefix {prefix!r} found in {shard_path}")

    # strict=False is kept from the original: tolerates entries present in only
    # one side (e.g. buffers that may or may not be stored in the checkpoint).
    layer.load_state_dict(layer_state_dict, strict=False)
    del layer_state_dict

    model.model.layers.append(layer)


def get_scale_and_zero(x, maxq):
    """Compute per-row asymmetric quantization parameters for ``x``.

    For every row (index along dim 0) the value range is taken over dim 1 and
    widened to include 0. Rows whose range is exactly [0, 0] are given the
    range [-1, 1] so the scale is never zero.

    Args:
        x: Tensor to be quantized; the minimum/maximum are reduced over dim 1.
        maxq: Largest quantized code (e.g. ``2 ** bits - 1``).

    Returns:
        ``(scale, zero)``, each reshaped to ``[-1, 1, ..., 1]`` (one entry per
        row, broadcastable against ``x``) and cast to ``x.dtype``.
    """
    zeros = torch.zeros(x.shape[0], device=x.device)
    row_min = torch.minimum(x.min(dim=1).values, zeros)
    row_max = torch.maximum(x.max(dim=1).values, zeros)

    # All-zero rows would give scale == 0; use a symmetric unit range instead.
    degenerate = (row_min == 0) & (row_max == 0)
    row_min[degenerate] = -1
    row_max[degenerate] = +1

    scale = (row_max - row_min) / maxq
    zero = torch.round(-row_min / scale)

    # One leading "row" axis followed by singleton axes for broadcasting.
    broadcast_shape = [-1] + [1] * (x.dim() - 1)
    scale = scale.reshape(broadcast_shape).to(x.dtype)
    zero = zero.reshape(broadcast_shape).to(x.dtype)
    return scale, zero


def custom_quantize(x, bits: int):
    """Round-to-nearest asymmetric quantization of ``x`` with per-row parameters.

    Args:
        x: Tensor to quantize (not modified; a detached copy is used).
        bits: Bit width of the codes. Must be in ``1..8`` because the codes are
            returned as ``torch.uint8``.

    Returns:
        ``(q, scale, zero)`` where ``q`` holds the uint8 codes clamped to
        ``[0, 2 ** bits - 1]`` and ``scale``/``zero`` come from
        ``get_scale_and_zero``.

    Raises:
        ValueError: If ``bits`` is outside ``1..8``.
    """
    # Codes are stored as uint8: more than 8 bits would wrap around silently,
    # and fewer than 1 bit gives maxq <= 0 (division by zero in the scale).
    if not 1 <= bits <= 8:
        raise ValueError(f"bits must be between 1 and 8 for uint8 storage, got {bits}")

    x = x.clone().detach()
    maxq = torch.tensor(2 ** bits - 1)
    scale, zero = get_scale_and_zero(x, maxq)

    q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
    return q.to(torch.uint8), scale, zero


class QuantizedLinear(nn.Module):
    """Linear layer that stores its weight as integer codes plus scale/zero.

    The weight is reconstructed on every forward pass as
    ``scale * q - scale * zero`` and applied with ``F.linear``.
    """

    def __init__(self, q, scale, zero, bias):
        """Wrap the quantization tensors as frozen parameters.

        Args:
            q: Quantized codes of the weight.
            scale: Scale tensor, broadcastable against ``q``.
            zero: Zero-point tensor, broadcastable against ``q``.
            bias: Optional bias of the layer being replaced; copied when given.
        """
        super().__init__()
        # Registered in the same order as before: q, scale, zero.
        for attr_name, tensor in (("q", q), ("scale", scale), ("zero", zero)):
            setattr(self, attr_name, nn.Parameter(tensor, requires_grad=False))

        self.bias = None if bias is None else nn.Parameter(bias.data.clone().detach())

    def forward(self, input):
        """Dequantize the weight and apply the affine map to ``input``."""
        scaled_codes = self.scale * self.q
        scaled_offset = self.scale * self.zero
        return F.linear(input, scaled_codes - scaled_offset, self.bias)


  from .autonotebook import tqdm as notebook_tqdm


CUDA extension not installed.


In [None]:
def gptq(x: torch.Tensor, bits: int, hessian: torch.Tensor, blocksize:int=128, percdamp:float=.01):
    """Quantize a 2-D weight matrix column by column with error compensation.

    Columns are processed in order of decreasing Hessian diagonal. After each
    column is rounded to its code, the rounding error is propagated to the
    not-yet-quantized columns using the upper Cholesky factor of the inverse
    (damped) Hessian, first inside the current block and then to all later
    blocks.

    Args:
        x: Weight matrix; a detached float32 copy is quantized, ``x`` itself is
            left untouched.
        bits: Bit width; codes lie in ``[0, 2 ** bits - 1]`` and are stored as
            uint8.
        hessian: Square matrix with one row/column per column of ``x``. It is
            not modified.
        blocksize: Number of columns handled per block.
        percdamp: Fraction of the mean Hessian diagonal added to the diagonal
            before the Cholesky factorisation.

    Returns:
        ``(Q, scale, zero)``: uint8 codes in the original column order, and the
        per-row scale / zero point computed by ``get_scale_and_zero`` on the
        float32 copy of ``x``.
    """
    W = x.clone().detach()
    W = W.float()
    columns = W.shape[1]

    maxq = torch.tensor(2 ** bits - 1)
    scale, zero = get_scale_and_zero(W, maxq)
    flat_scale = scale.flatten()
    flat_zero = zero.flatten()

    # Work on a private copy: the dead-column fix below writes into H and must
    # not leak into the caller's tensor.
    H = hessian.clone()
    dead = torch.diag(H) == 0
    H[dead, dead] = 1
    W[:, dead] = 0

    # Process columns in order of decreasing activation size.
    perm = torch.argsort(torch.diag(H), descending=True)
    W = W[:, perm]
    H = H[perm, :][:, perm]
    invperm = torch.argsort(perm)

    Q = torch.zeros(W.shape, dtype=torch.uint8, device=W.device)

    # Dampen the diagonal, then take the upper Cholesky factor of H^-1.
    damp = percdamp * torch.mean(torch.diag(H))
    diag = torch.arange(columns, device=W.device)
    H[diag, diag] += damp
    H = torch.linalg.cholesky(H)
    H = torch.cholesky_inverse(H)
    Hinv = torch.linalg.cholesky(H, upper=True)

    for i1 in range(0, columns, blocksize):
        i2 = min(i1 + blocksize, columns)
        count = i2 - i1

        W1 = W[:, i1:i2].clone()
        Q1 = torch.zeros(W1.shape, dtype=torch.uint8, device=W1.device)
        Err1 = torch.zeros_like(W1)
        Hinv1 = Hinv[i1:i2, i1:i2]

        for i in range(count):
            w = W1[:, i]
            d = Hinv1[i, i]

            # Quantize the column to integer codes, then dequantize it again
            # to measure the rounding error.
            codes = torch.clamp(torch.round(w / flat_scale) + flat_zero, 0, maxq)
            Q1[:, i] = codes.to(torch.uint8)
            dequantized = flat_scale * codes - flat_scale * flat_zero

            # Spread the error over the remaining columns of this block.
            err1 = (w - dequantized) / d
            W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
            Err1[:, i] = err1

        Q[:, i1:i2] = Q1

        # Propagate this block's accumulated error to all later blocks.
        W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])

    # Only synchronise when the work actually ran on a CUDA device; the
    # unconditional call fails on machines without CUDA.
    if W.is_cuda:
        torch.cuda.synchronize()
    return Q[:, invperm], scale, zero

In [2]:
def replace_submodule(module, submodule_path, new_submodule):
    """Replace the attribute at a dotted path below ``module``.

    Args:
        module: Root object to start the lookup from.
        submodule_path: Dotted attribute path, e.g. ``"self_attn.q_proj"``.
        new_submodule: Object to store at the final path component.
    """
    parent_path, _, leaf_name = submodule_path.rpartition(".")

    parent = module
    if parent_path:
        # Walk down to the object that owns the attribute being replaced.
        for attr_name in parent_path.split("."):
            parent = getattr(parent, attr_name)

    setattr(parent, leaf_name, new_submodule)


@torch.no_grad()
def llama_nearest(checkpoint_path, model, bits: int, device, n_layers: int = 32):
    """Load the decoder layers one by one and quantize them round-to-nearest.

    Each layer is loaded from its checkpoint shard, appended to
    ``model.model.layers``, and its attention / MLP projections are replaced by
    ``QuantizedLinear`` modules built with ``custom_quantize``.

    Args:
        checkpoint_path: Directory containing the sharded checkpoint files.
        model: Model that receives the quantized layers.
        bits: Bit width passed to ``custom_quantize``.
        device: Device on which the quantization arithmetic runs; the results
            are stored on the CPU.
        n_layers: Number of decoder layers to load (default 32, the value that
            was previously hard-coded).
    """
    print('Starting ...')

    # Projections to quantize, grouped the same way as in ``llama_gptq``.
    sequential_groups = [
        ['self_attn.k_proj', 'self_attn.v_proj', 'self_attn.q_proj'],
        ['self_attn.o_proj'],
        ['mlp.up_proj', 'mlp.gate_proj'],
        ['mlp.down_proj']
    ]

    # Load and quantize all the layers
    layers = model.model.layers
    for i in trange(n_layers):
        load_and_dispatch_a_layer(i, checkpoint_path, model)

        # The loader appends, so the layer just loaded is always the last one.
        layer = layers[-1]
        linear_submodules = find_layers(layer)

        for group_names in sequential_groups:
            for name in group_names:
                linear = linear_submodules[name]
                q, scale, zero = custom_quantize(linear.weight.data.to(device), bits=bits)
                # Bring *all* tensors back to the CPU. Previously only ``q``
                # was moved, leaving each module with mixed devices and
                # ``scale``/``zero`` pinned on ``device``.
                replace_submodule(layer, name, QuantizedLinear(q.cpu(), scale.cpu(), zero.cpu(), linear.bias))

        del layer
        torch.cuda.empty_cache()


In [1]:
class Catcher(nn.Module):
    """Stand-in decoder layer that records its input and aborts the forward pass.

    Every call stores the incoming hidden states in the next free slot of
    ``inps_dest``, remembers the ``attention_mask`` / ``position_ids`` keyword
    arguments, and then raises ``ValueError`` so the caller can stop the
    surrounding model's forward pass early.
    """

    def __init__(self, inps_dest):
        """Remember the destination buffer and reset the capture state.

        Args:
            inps_dest: Indexable buffer; slot ``k`` receives the input of the
                ``k``-th call.
        """
        super().__init__()
        self.i = 0  # index of the next free slot in ``inps_dest``
        self.attention_mask = None
        self.position_ids = None
        self.inps_dest = inps_dest

    def forward(self, inp, **kwargs):
        """Record ``inp`` and the mask/position kwargs, then raise ``ValueError``."""
        slot = self.i
        self.inps_dest[slot] = inp
        self.attention_mask = kwargs['attention_mask']
        self.position_ids = kwargs['position_ids']
        self.i = slot + 1
        # Deliberate: signals "input captured, stop here" to the caller.
        raise ValueError

    def get_the_catch(self):
        """Return the most recently captured ``(attention_mask, position_ids)``."""
        return self.attention_mask, self.position_ids


def get_accumulate_input_fn(name, hessians, num_samples):
    """Create a forward hook that accumulates the input Gram matrix of a module.

    Args:
        name: Key under which this module's statistics are stored.
        hessians: Dict mapping names to the running sum of ``X^T X`` (or
            ``None`` before the first call); updated in place.
        num_samples: Dict mapping names to the number of hook calls so far;
            updated in place.

    Returns:
        A callable with the ``(module, inputs, output)`` forward-hook signature.
    """
    def accumulate_gram(_module, inputs, _output):
        activations = inputs[0].data  # ... x hidden_size
        hidden_size = activations.shape[-1]
        # Flatten all leading axes, then transpose: hidden_size x inputs.
        features = activations.reshape((-1, hidden_size)).t().float()

        num_samples[name] += 1

        gram = features.matmul(features.t())
        if hessians[name] is None:
            hessians[name] = gram
        else:
            hessians[name] += gram

    return accumulate_gram


@torch.no_grad()
def llama_gptq(checkpoint_path, model, dataloader, bits: int, device, n_samples=128, n_layers: int = 32):
    """Load the decoder layers one by one and quantize them with GPTQ.

    The first-layer inputs for the calibration batches are captured with a
    ``Catcher``. Then, for every layer, the linear projections are quantized
    group by group: forward hooks collect the input Gram matrix of each
    projection, ``gptq`` turns the weights into codes, and the projection is
    replaced by a ``QuantizedLinear``. The outputs of the fully quantized layer
    become the inputs of the next one.

    Args:
        checkpoint_path: Directory containing the sharded checkpoint files.
        model: Layerless model; ``model.model.layers`` is replaced and then
            filled with the quantized layers.
        dataloader: Iterable of calibration batches; ``batch[0]`` is fed to the
            model.
        bits: Bit width passed to ``gptq``.
        device: Device on which each layer is run and quantized.
        n_samples: Number of calibration sequences to capture and use.
        n_layers: Number of decoder layers to load (default 32, the value that
            was previously hard-coded).
    """
    print('Starting ...')

    use_cache = model.config.use_cache
    model.config.use_cache = False

    dtype = next(iter(model.parameters())).dtype
    inps = torch.zeros(
        (n_samples, model.seqlen, model.config.hidden_size), dtype=dtype
    )
    outs = torch.zeros_like(inps)

    # Collect the first layer inputs
    catcher = Catcher(inps)
    model.model.layers = nn.ModuleList((catcher,))
    for batch in dataloader:
        # The capture buffer only has room for ``n_samples`` sequences.
        if catcher.i >= n_samples:
            break
        try:
            model(batch[0])
        except ValueError:
            # Raised on purpose by the Catcher once the input is recorded.
            pass
    attention_mask, position_ids = catcher.get_the_catch()
    model.model.layers = nn.ModuleList()

    # The capture ran wherever the embedding lives, but the layers below run
    # on ``device``; move the captured tensors so the devices match.
    if attention_mask is not None:
        attention_mask = attention_mask.to(device)
    if position_ids is not None:
        position_ids = position_ids.to(device)

    # Projections are quantized in this order; every group sees activations
    # produced with the previously quantized groups already in place.
    sequential_groups = [
        ['self_attn.k_proj', 'self_attn.v_proj', 'self_attn.q_proj'],
        ['self_attn.o_proj'],
        ['mlp.up_proj', 'mlp.gate_proj'],
        ['mlp.down_proj']
    ]

    # Load and quantize all the layers
    layers = model.model.layers
    for i in trange(n_layers):
        load_and_dispatch_a_layer(i, checkpoint_path, model)
        layer = layers[i].to(device)
        linear_submodules = find_layers(layer)

        for group_names in sequential_groups:
            current_linears_to_quantize = {name: linear_submodules[name] for name in group_names}

            # Accumulate sum(X^T X) for every projection in the group.
            hessians = {name: None for name in current_linears_to_quantize}
            num_samples = {name: 0 for name in current_linears_to_quantize}
            handles = [
                linear.register_forward_hook(get_accumulate_input_fn(name, hessians, num_samples))
                for name, linear in current_linears_to_quantize.items()
            ]
            for j in range(n_samples):
                outs[j] = layer(inps[j].unsqueeze(0).to(device), attention_mask=attention_mask, position_ids=position_ids)[0].cpu()
            for h in handles:
                h.remove()

            for name, linear in current_linears_to_quantize.items():
                q, scale, zero = gptq(linear.weight.data, bits, 2 * hessians[name] / num_samples[name])
                replace_submodule(layer, name, QuantizedLinear(q, scale, zero, linear.bias))

        # Recompute the outputs with the fully quantized layer; they become
        # the inputs of the next layer.
        for j in range(n_samples):
            outs[j] = layer(inps[j].unsqueeze(0).to(device), attention_mask=attention_mask, position_ids=position_ids)[0].cpu()

        layers[i] = layer.cpu()
        del layer
        torch.cuda.empty_cache()

        inps, outs = outs, inps

    model.config.use_cache = use_cache

NameError: name 'nn' is not defined

In [None]:
@torch.no_grad()
def llama_eval(model, testenc, device):
    """Print the perplexity of ``model`` on the first eighth of ``testenc``.

    The token stream is cut into ``model.seqlen``-long sequences. Their
    first-layer inputs are captured with a ``Catcher``, pushed through the
    decoder layers one layer at a time on ``device``, and finally through the
    model's norm and LM head to compute the cross-entropy.

    Args:
        model: Model whose ``model.model.layers`` holds the layers to evaluate.
        testenc: Tokenized test set exposing ``input_ids``.
        device: Device on which the layers, norm and LM head are run.
    """
    print('Evaluating ...')

    # Only the first 1/8 of the tokens is evaluated.
    testenc = testenc.input_ids[..., :testenc.input_ids.shape[1] // 8]
    n_samples = testenc.numel() // model.seqlen

    use_cache = model.config.use_cache
    model.config.use_cache = False

    dtype = next(iter(model.parameters())).dtype
    inps = torch.zeros(
        (n_samples, model.seqlen, model.config.hidden_size), dtype=dtype
    )
    outs = torch.zeros_like(inps)

    # Collect the first layer inputs. The real layers are stashed while the
    # Catcher is installed and put back afterwards; replacing them with an
    # empty list (as before) threw the quantized layers away, so nothing was
    # actually evaluated.
    layers = model.model.layers
    catcher = Catcher(inps)
    model.model.layers = nn.ModuleList((catcher,))
    try:
        for i in range(n_samples):
            batch = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)]
            try:
                model(batch)
            except ValueError:
                # Raised on purpose by the Catcher once the input is recorded.
                pass
    finally:
        model.model.layers = layers
    attention_mask, position_ids = catcher.get_the_catch()

    # The capture ran wherever the embedding lives, but the layers below run
    # on ``device``; move the captured tensors so the devices match.
    if attention_mask is not None:
        attention_mask = attention_mask.to(device)
    if position_ids is not None:
        position_ids = position_ids.to(device)

    # Forward pass through the layers, one layer resident on ``device`` at a time
    for i in trange(len(layers)):
        layer = layers[i].to(device)

        for j in range(n_samples):
            outs[j] = layer(inps[j].unsqueeze(0).to(device), attention_mask=attention_mask, position_ids=position_ids)[0].cpu()

        layers[i] = layer.cpu()
        del layer
        torch.cuda.empty_cache()

        inps, outs = outs, inps

    # Calculate PPL
    testenc = testenc.to(device)

    # Move the head modules once instead of once per sample.
    norm = model.model.norm
    if norm is not None:
        norm.to(device)
    model.lm_head.to(device)

    loss_fct = nn.CrossEntropyLoss()
    nlls = []
    for i in range(n_samples):
        # Always move to ``device`` so the LM head works even without a norm.
        hidden_states = inps[i].unsqueeze(0).to(device)
        if norm is not None:
            hidden_states = norm(hidden_states)
        lm_logits = model.lm_head(hidden_states)

        # Position t predicts token t + 1.
        shift_logits = lm_logits[:, :-1, :].contiguous()
        shift_labels = testenc[
            :, (i * model.seqlen):((i + 1) * model.seqlen)
        ][:, 1:]
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        neg_log_likelihood = loss.float() * model.seqlen
        nlls.append(neg_log_likelihood)

    if norm is not None:
        norm.cpu()
    model.lm_head.cpu()

    ppl = torch.exp(torch.stack(nlls).sum() / (n_samples * model.seqlen))
    print(ppl.item())

    model.config.use_cache = use_cache


In [4]:
# Run configuration shared by the cells below.
DEVICE = "cuda:0"  # device handed to the quantization and evaluation functions
MODEL = "../model/"  # checkpoint directory; also passed to get_loaders as `model`
SEED = 0  # seed forwarded to get_loaders
BITS = 4  # bit width handed to the quantization function

In [5]:
# Build the model without decoder layers from the checkpoint directory and
# switch it to evaluation mode.
model = initialize_layerless_llama(MODEL)
model.eval()

Starting ...
Ready.


100%|██████████| 32/32 [00:34<00:00,  1.09s/it]


In [6]:
# Load the "wikitext2" loaders with sequences of `model.seqlen` tokens:
# `dataloader` is the input for llama_gptq, `testloader` for llama_eval.
dataloader, testloader = get_loaders(
    "wikitext2", seed=SEED, model=MODEL, seqlen=model.seqlen
)

You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. If you see this, DO NOT PANIC! This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=True`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


wikitext2
Evaluating ...


  0%|          | 0/32 [28:52<?, ?it/s]


RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "/tmp/ipykernel_562573/2934076896.py", line 133, in forward
    def forward(self, x):
        output = x @ (self.scale * self.q - self.scale * self.zero)
                 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
        if self.bias is not None:
            output += self.bias
RuntimeError: mat1 and mat2 shapes cannot be multiplied (2048x4096 and 11008x4096)


In [None]:
# Load the decoder layers into `model` and quantize them. The GPTQ variant is
# kept commented out; the round-to-nearest variant below is the one that runs.
# llama_gptq(MODEL, model, dataloader, BITS, DEVICE)
llama_nearest(MODEL, model, BITS, DEVICE)

In [None]:
# Print the perplexity of the quantized model on the test loader.
llama_eval(model, testloader, DEVICE)