In [None]:
import os
import time

import torch
torch.set_num_threads(8)
from torch import nn
import torch.nn.functional as F
import transformers

from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM
from transformers.models.llama.configuration_llama import LlamaConfig

from quant import quantize, Quantizer
from gptq import GPTQ
from modelutils import find_layers
from datautils import get_loaders

from tqdm import tqdm, trange


def skip(*args, **kwargs):
    """No-op replacement for a weight initializer: accepts anything, returns None."""
    return None


# Route the stock initializers to the no-op. Modules built afterwards keep
# whatever their freshly allocated tensors contain until a state dict is loaded.
for _init_name in ("kaiming_uniform_", "uniform_", "normal_"):
    setattr(torch.nn.init, _init_name, skip)
del _init_name


def initialize_layerless_llama(checkpoint_path):
    """Create a LLaMA causal LM that has no decoder layers.

    The configuration is fetched for "decapoda-research/llama-7b-hf" and its
    ``num_hidden_layers`` is forced to 0, so only the non-layer modules are
    instantiated. Their weights are read from the last checkpoint shard
    (``pytorch_model-00033-of-00033.bin``) found in ``checkpoint_path``.

    Args:
        checkpoint_path: directory holding the sharded checkpoint files.

    Returns:
        The model converted to float16, with ``model.seqlen`` set to 2048.
    """
    layerless_config = LlamaConfig.from_pretrained("decapoda-research/llama-7b-hf")
    layerless_config.num_hidden_layers = 0

    model = LlamaForCausalLM(layerless_config)

    last_shard = os.path.join(checkpoint_path, "pytorch_model-00033-of-00033.bin")
    model.load_state_dict(torch.load(last_shard))

    # Sequence length consumed by the quantization / evaluation routines.
    model.seqlen = 2048

    return model.to(torch.float16)


def load_and_dispatch_a_layer(layer_idx, checkpoint_path, model: LlamaForCausalLM):
    """Load decoder layer ``layer_idx`` from its shard and append it to ``model``.

    The layer is read from ``pytorch_model-{layer_idx+1:05}-of-00033.bin`` inside
    ``checkpoint_path``. Each key in that shard has its first
    ``len("model.layers.{layer_idx}.")`` characters dropped before loading, and
    loading is non-strict, so keys that do not match are ignored silently.

    Args:
        layer_idx: zero-based index of the decoder layer.
        checkpoint_path: directory holding the sharded checkpoint files.
        model: model whose ``model.layers`` list receives the new float16 layer.
    """
    config = transformers.AutoConfig.from_pretrained("decapoda-research/llama-7b-hf")
    layer = LlamaDecoderLayer(config)

    shard_name = f"pytorch_model-{layer_idx+1:05}-of-00033.bin"
    prefix_length = len(f"model.layers.{layer_idx}.")

    layer_state_dict = {}
    for name, tensor in torch.load(os.path.join(checkpoint_path, shard_name)).items():
        layer_state_dict[name[prefix_length:]] = tensor

    layer.load_state_dict(layer_state_dict, strict=False)
    del layer_state_dict

    model.model.layers.append(layer.to(torch.float16))


def get_scale_and_zero(x, maxq):
    """Compute per-row asymmetric quantization parameters for ``x``.

    For every row (dimension 0) the range ``[min(row_min, 0), max(row_max, 0)]``
    is mapped onto the integer grid ``0..maxq``. Rows whose range collapses to
    exactly ``[0, 0]`` use ``[-1, 1]`` instead so the scale is never zero.

    Args:
        x: tensor reduced along dimension 1 (``x.min(1)`` / ``x.max(1)``).
        maxq: largest quantized level (``2 ** bits - 1``).

    Returns:
        ``(scale, zero)``, each reshaped to ``[-1, 1, ...]`` (one entry per row,
        broadcastable against ``x``) and cast to ``x.dtype``.
    """
    zeros = torch.zeros(x.shape[0], device=x.device)
    row_min = torch.minimum(x.min(1)[0], zeros)
    row_max = torch.maximum(x.max(1)[0], zeros)

    # Rows with an empty range would give scale == 0; widen them to [-1, 1].
    empty_range = (row_min == 0) & (row_max == 0)
    row_min[empty_range] = -1
    row_max[empty_range] = +1

    scale = (row_max - row_min) / maxq
    zero = torch.round(-row_min / scale)

    broadcast_shape = [-1] + [1] * (len(x.shape) - 1)
    scale = scale.reshape(broadcast_shape).to(x.dtype)
    zero = zero.reshape(broadcast_shape).to(x.dtype)
    return scale, zero


def custom_quantize(x, bits: int):
    """Round-to-nearest quantization of ``x`` with per-row scale and zero point.

    Args:
        x: tensor to quantize; it is copied first, so the input is left untouched.
        bits: bit width; the quantized levels span ``0 .. 2 ** bits - 1``.

    Returns:
        ``(q, scale, zero)`` where ``q`` holds the levels as ``torch.uint8`` and
        ``scale`` / ``zero`` are the per-row parameters from ``get_scale_and_zero``.
    """
    values = x.clone().detach()
    maxq = torch.tensor(2 ** bits - 1)
    scale, zero = get_scale_and_zero(values, maxq)

    levels = torch.round(values / scale) + zero
    levels = torch.clamp(levels, 0, maxq)
    return levels.to(torch.uint8), scale, zero


class QuantizedLinear(nn.Module):
    """Linear layer that stores integer levels and dequantizes them on every call.

    The effective weight is ``scale * q - scale * zero``; it is rebuilt inside
    ``forward`` and never cached.
    """

    def __init__(self, q, scale, zero, bias):
        """Store the quantization state.

        Args:
            q: quantized levels (kept as a frozen parameter).
            scale: scale, broadcastable against ``q`` (frozen parameter).
            zero: zero point, broadcastable against ``q`` (frozen parameter).
            bias: optional bias; when given, a detached copy is stored.
        """
        super().__init__()
        # Frozen parameters, so they follow the module through .to()/.cpu() and
        # show up in the state dict.
        self.q = nn.Parameter(q, requires_grad=False)
        self.scale = nn.Parameter(scale, requires_grad=False)
        self.zero = nn.Parameter(zero, requires_grad=False)

        self.bias = None if bias is None else nn.Parameter(bias.data.clone().detach())

    def forward(self, input):
        """Apply the linear map using the dequantized weight."""
        dequantized_weight = self.scale * self.q - self.scale * self.zero
        return F.linear(input, dequantized_weight, self.bias)


In [None]:
def gptq(weight: torch.Tensor, bits: int, hessian: torch.Tensor, blocksize:int=128, percdamp:float=.01):
    """Quantize ``weight`` column by column with GPTQ error compensation.

    Columns are processed in order of decreasing Hessian diagonal. After a
    column is quantized, its rounding error is propagated to the not yet
    quantized columns using the upper Cholesky factor of the inverse (damped)
    Hessian, first inside the current block and then to all later blocks.

    Args:
        weight: 2-D weight matrix; columns (dimension 1) are the input features.
            The caller's tensor is not modified.
        bits: bit width; quantized levels span ``0 .. 2 ** bits - 1``.
        hessian: square matrix over the input features. NOTE: it is modified in
            place (dead diagonal entries are set to 1), so pass a tensor you do
            not need afterwards.
        blocksize: number of columns handled per block.
        percdamp: damping added to the Hessian diagonal, as a fraction of the
            mean diagonal value.

    Returns:
        ``(dequantized_weight, quantized_weight, scale, zero)``:
        the dequantized weights in the original dtype and column order, the
        ``uint8`` levels in the original column order, and the per-row scale and
        zero point cast to the original dtype.
    """
    original_dtype = weight.dtype
    # Always work on a private float32 copy: several steps below write in place.
    weight = weight.to(torch.float32, copy=True)
    columns = weight.shape[1]

    # Per-row scale and zero point such that, for every element w of a row,
    #   q = clamp(round(w / scale) + zero, 0, maxq)
    maxq = torch.tensor(2 ** bits - 1)
    scale, zero = get_scale_and_zero(weight, maxq)

    # Input features that were never active carry no information: make the
    # Hessian invertible for them and drop the corresponding weights.
    dead = torch.diag(hessian) == 0
    hessian[dead, dead] = 1
    weight[:, dead] = 0

    # Process columns in order of decreasing Hessian diagonal.
    perm = torch.argsort(torch.diag(hessian), descending=True)
    weight = weight[:, perm]
    hessian = hessian[perm, :][:, perm]
    invperm = torch.argsort(perm)

    # Allocate the quantized tensor
    quantized_weight = torch.zeros(weight.shape, dtype=torch.uint8, device=weight.device)

    # Upper Cholesky factor of the inverse of the damped Hessian
    damp = percdamp * torch.mean(torch.diag(hessian))
    diag = torch.arange(columns, device=weight.device)
    hessian[diag, diag] += damp
    hessian = torch.linalg.cholesky(hessian)
    hessian = torch.cholesky_inverse(hessian)
    inverse_hessian = torch.linalg.cholesky(hessian, upper=True)

    flat_scale = scale.flatten()
    flat_zero = zero.flatten()

    # Iterate over blocks of columns
    for block_start in range(0, columns, blocksize):
        block_end = min(block_start + blocksize, columns)
        num_columns_in_block = block_end - block_start

        # ``block`` is a view: writing to it updates ``weight`` directly.
        block = weight[:, block_start:block_end]
        quantized_block = torch.zeros(block.shape, dtype=torch.uint8, device=block.device)
        block_error = torch.zeros_like(block)
        inverse_block_hessian = inverse_hessian[block_start:block_end, block_start:block_end]

        # Iterate over the block's columns
        for i in range(num_columns_in_block):
            column = block[:, i]  # view into ``block``
            inverse_hessian_diag_value = inverse_block_hessian[i, i]

            # Quantize the column
            quantized_column = torch.clamp(torch.round(column.unsqueeze(1) / scale) + zero, 0, maxq).flatten()
            quantized_block[:, i] = quantized_column.to(torch.uint8)
            dequantized_column = flat_scale * quantized_column - flat_scale * flat_zero

            # The error must be taken BEFORE anything is written back into
            # ``block``: ``column`` aliases that memory, so overwriting it first
            # would make the error identically zero and disable compensation.
            column_error = (column - dequantized_column) / inverse_hessian_diag_value

            # Update the remaining columns within the block
            block[:, i:] -= column_error.unsqueeze(1).matmul(inverse_block_hessian[i, i:].unsqueeze(0))
            # Store the exact dequantized values (the update above only matches
            # them up to rounding).
            block[:, i] = dequantized_column
            block_error[:, i] = column_error
        quantized_weight[:, block_start:block_end] = quantized_block

        # Update the remaining columns outside the block
        weight[:, block_end:] -= block_error.matmul(inverse_hessian[block_start:block_end, block_end:])

    # Only meaningful (and only possible) when the work ran on a CUDA device.
    if weight.is_cuda:
        torch.cuda.synchronize()
    return weight[:, invperm].to(original_dtype), quantized_weight[:, invperm], scale.to(original_dtype), zero.to(original_dtype)

In [None]:
def replace_submodule(module, submodule_path, new_submodule):
    """Replace the attribute at a dotted path below ``module``.

    Args:
        module: root object to start from.
        submodule_path: dot-separated attribute path, e.g. ``"self_attn.q_proj"``.
        new_submodule: object assigned to the last name of the path.
    """
    *parent_names, leaf_name = submodule_path.split(".")
    parent = module
    for parent_name in parent_names:
        parent = getattr(parent, parent_name)
    setattr(parent, leaf_name, new_submodule)


@torch.no_grad()
def llama_nearest(checkpoint_path, model, bits: int, device):
    """Quantize all 32 decoder layers with plain round-to-nearest.

    Each layer is loaded from the checkpoint, appended to ``model`` and every
    listed linear projection is replaced by a ``QuantizedLinear``. The
    quantization arithmetic runs on ``device``; all resulting tensors are moved
    back to the CPU so the stored module lives on a single device.

    Args:
        checkpoint_path: directory holding the sharded checkpoint files.
        model: layerless model that receives the quantized layers.
        bits: bit width used for every linear projection.
        device: device on which the quantization arithmetic is executed.
    """
    print('Starting ...')

    # Same grouping as the GPTQ path; the order is irrelevant for
    # round-to-nearest because no activations are involved.
    sequential_groups = [
        ['self_attn.k_proj', 'self_attn.v_proj', 'self_attn.q_proj'],
        ['self_attn.o_proj'],
        ['mlp.up_proj', 'mlp.gate_proj'],
        ['mlp.down_proj']
    ]

    # Load and quantize all the layers
    layers = model.model.layers
    for i in trange(32):
        load_and_dispatch_a_layer(i, checkpoint_path, model)

        layer = layers[i]
        linear_submodules = find_layers(layer)

        for group_names in sequential_groups:
            current_linears_to_quantize = {n: linear_submodules[n] for n in group_names}
            for name, linear in current_linears_to_quantize.items():
                q, scale, zero = custom_quantize(linear.weight.data.to(device), bits=bits)
                # Move *all* three tensors back: leaving scale/zero on ``device``
                # would give a module whose parameters sit on two devices.
                quantized_linear = QuantizedLinear(q.cpu(), scale.cpu(), zero.cpu(), linear.bias)
                # The replacement happens in place on ``layer``, which is the
                # very object stored in ``layers[i]``.
                replace_submodule(layer, name, quantized_linear)

        del layer
        torch.cuda.empty_cache()


In [None]:
class Catcher(nn.Module):
    """Stand-in layer that records what it is called with and aborts the forward pass.

    Every call stores the positional input in the next free slot of
    ``inps_dest``, remembers the ``attention_mask`` and ``position_ids`` keyword
    arguments, and then raises ``ValueError`` so the caller can stop early.
    """

    def __init__(self, inps_dest):
        """Args:
            inps_dest: indexable buffer; slot ``k`` receives the input of call ``k``.
        """
        super().__init__()
        self.inps_dest = inps_dest
        self.i = 0  # index of the next free slot in inps_dest
        self.attention_mask = None
        self.position_ids = None

    def forward(self, inp, **kwargs):
        """Record ``inp`` and the two keyword arguments, then raise ``ValueError``."""
        slot = self.i
        self.inps_dest[slot] = inp
        self.attention_mask = kwargs['attention_mask']
        self.position_ids = kwargs['position_ids']
        self.i = slot + 1
        raise ValueError

    def get_the_catch(self):
        """Return the ``(attention_mask, position_ids)`` seen by the latest call."""
        return self.attention_mask, self.position_ids


def get_accumulate_input_fn(name, hessians, num_samples):
    """Build a forward hook that accumulates the input Gram matrix of a module.

    On every call the hook flattens the first positional input to
    ``(inputs, hidden_size)``, converts it to float32 and adds
    ``X^T X`` (shape ``hidden_size x hidden_size``) to ``hessians[name]``.
    ``num_samples[name]`` counts the hook invocations.

    Args:
        name: key under which this module's statistics are stored.
        hessians: dict of running sums; an entry of ``None`` means "nothing yet".
        num_samples: dict of call counters.

    Returns:
        A callable with the ``(module, inputs, output)`` forward-hook signature.
    """
    def accumulate(_module, inputs, _output):
        activations = inputs[0].data  # ... x hidden_size
        flat = activations.reshape((-1, activations.shape[-1]))  # inputs x hidden_size
        features = flat.t().float()  # hidden_size x inputs
        num_samples[name] += 1
        gram = features.matmul(features.t())
        if hessians[name] is None:
            hessians[name] = gram
        else:
            # In place, so the running sum keeps its identity.
            hessians[name] += gram
    return accumulate


@torch.no_grad()
def llama_gptq(checkpoint_path, model, dataloader, bits: int, device, n_samples=128):
    """Quantize the 32 decoder layers one at a time with GPTQ.

    The inputs of the first decoder layer are captured by temporarily installing
    a ``Catcher`` as the only layer. Then, for each layer index: the layer is
    loaded from the checkpoint, its linear projections are quantized group by
    group (a forward pass over all samples collects the input statistics for the
    group before it is quantized), the projections are replaced by
    ``QuantizedLinear`` modules, and the layer outputs become the inputs of the
    next layer. The finished layer is moved back to the CPU.

    Args:
        checkpoint_path: directory holding the sharded checkpoint files.
        model: model whose ``model.layers`` is replaced and then filled here.
        dataloader: iterable of calibration batches; ``batch[0]`` is fed to the
            model (presumably token ids — confirm against ``get_loaders``). It
            must not yield more than ``n_samples`` batches, because each one
            fills a slot of the capture buffer.
        bits: bit width passed to ``gptq``.
        device: device used for the forward passes and for quantization.
        n_samples: number of slots in the activation buffers.

    Side effects:
        ``model.config.use_cache`` is disabled for the duration of the call and
        restored at the end (not restored if an exception is raised).
    """
    print('Starting ...')

    use_cache = model.config.use_cache
    model.config.use_cache = False

    # Two activation buffers on ``device``: inputs and outputs of the current layer.
    dtype = next(iter(model.parameters())).dtype
    inps = torch.zeros(
        (n_samples, model.seqlen, model.config.hidden_size), dtype=dtype, device=device
    )
    outs = torch.zeros_like(inps)

    # Collect the first layer inputs
    catcher = Catcher(inps)
    model.model.layers = nn.ModuleList((catcher,))
    for batch in dataloader:
        try:
            model(batch[0])
        except ValueError:
            # Raised on purpose by Catcher once it has stored the input.
            pass
    attention_mask, position_ids = catcher.get_the_catch()
    # Start from an empty layer list; load_and_dispatch_a_layer appends to it.
    model.model.layers = nn.ModuleList()
    
    # Load and quantize all the layers
    layers = model.model.layers
    attention_mask = attention_mask.to(device)
    position_ids = position_ids.to(device)
    for i in trange(32):
        load_and_dispatch_a_layer(i, checkpoint_path, model)
        layer = layers[i].to(device)
        linear_submodules = find_layers(layer)
        quantized_linear_submodules = {}
        
        # Groups are handled in this order; the statistics of a group are
        # collected after the earlier groups already carry dequantized weights.
        sequential_groups = [
            ['self_attn.k_proj', 'self_attn.v_proj', 'self_attn.q_proj'],
            ['self_attn.o_proj'],
            ['mlp.up_proj', 'mlp.gate_proj'],
            ['mlp.down_proj']
        ]
       
        for group_names in sequential_groups:
            current_linears_to_quantize = {name: linear_submodules[name] for name in group_names}

            # Hook every linear of the group and run all samples through the
            # layer to accumulate the input Gram matrices.
            hessians = {name: None for name in current_linears_to_quantize}
            num_samples = {name: 0 for name in current_linears_to_quantize}
            handles = [linear.register_forward_hook(get_accumulate_input_fn(name, hessians, num_samples)) for name, linear in current_linears_to_quantize.items()]
#             tick = time.time()
            for j in range(n_samples):
                outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
#             print('Forward pass time %.2f' % (time.time() - tick))
            for h in handles:
                h.remove()

            for name, linear in current_linears_to_quantize.items():
                # The Hessian passed in is a fresh tensor (2 * sum / count).
                dequantized_weight, quantized_weight, scale, zero = gptq(linear.weight.data, bits, 2 * hessians[name] / num_samples[name])
                # Keep the original nn.Linear, now with dequantized weights, so
                # the following groups see the effect of this quantization.
                linear.weight.data = dequantized_weight
                quantized_linear_submodules[name] = QuantizedLinear(quantized_weight, scale, zero, linear.bias)
        
        # Only now swap in the QuantizedLinear modules.
        for name, quantized_linear_submodule in quantized_linear_submodules.items():
            replace_submodule(layer, name, quantized_linear_submodule)
        torch.cuda.empty_cache()

        # Final pass through the fully quantized layer: its outputs feed the next layer.
        for j in range(n_samples):
            outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]

        layers[i] = layer.cpu()
        del layer
        torch.cuda.empty_cache()

        # Swap buffers instead of copying.
        inps, outs = outs, inps

    model.config.use_cache = use_cache

In [None]:
@torch.no_grad()
def llama_eval(model, testenc, device):
    """Compute and print the perplexity of ``model`` on part of ``testenc``.

    Only the first quarter of the tokens in ``testenc.input_ids`` is used. The
    tokens are cut into segments of ``model.seqlen``; the segments are pushed
    through the decoder layers one layer at a time (each layer is moved to
    ``device`` for its pass and back to the CPU afterwards), then through the
    final norm and the LM head to obtain the loss.

    Args:
        model: model with its decoder layers in ``model.model.layers``.
        testenc: object exposing ``input_ids`` (indexed as ``[batch, position]``).
        device: device used for the layer-wise forward passes.

    Side effects:
        Prints the perplexity. ``model.config.use_cache`` is disabled during the
        call and restored at the end (not restored if an exception is raised).
    """
    print('Evaluating ...')

    # Evaluate on the first quarter of the tokens only.
    testenc = testenc.input_ids[...,:testenc.input_ids.shape[1]//4]
    n_samples = testenc.numel() // model.seqlen

    use_cache = model.config.use_cache
    model.config.use_cache = False

    dtype = next(iter(model.parameters())).dtype
    inps = torch.zeros(
        (n_samples, model.seqlen, model.config.hidden_size), dtype=dtype
    ).to(device)
    outs = torch.zeros_like(inps)

    # Collect the first layer inputs
    catcher = Catcher(inps)
    original_layers = model.model.layers
    model.model.layers = nn.ModuleList((catcher,))
    for i in range(n_samples):
        batch = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)]
        try:
            model(batch)
        except ValueError:
            # Raised on purpose by Catcher once it has stored the input.
            pass
    attention_mask, position_ids = catcher.get_the_catch()
    model.model.layers = original_layers

    # Forward pass through the layers
    layers = model.model.layers
    attention_mask = attention_mask.to(device)
    position_ids = position_ids.to(device)
    for i in trange(len(layers)):
        layer = layers[i].to(device)

        for j in range(n_samples):
            outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]

        layers[i] = layer.cpu()
        del layer
        torch.cuda.empty_cache()

        # Swap buffers instead of copying.
        inps, outs = outs, inps

    # Calculate PPL
    testenc = testenc.to(device)

    # Move the head modules to the device once for the whole loop rather than
    # copying them back and forth for every sample.
    if model.model.norm is not None:
        model.model.norm = model.model.norm.to(device)
    model.lm_head = model.lm_head.to(device)
    loss_fct = nn.CrossEntropyLoss()

    nlls = []
    for i in range(n_samples):
        hidden_states = inps[i].unsqueeze(0)
        if model.model.norm is not None:
            hidden_states = model.model.norm(hidden_states)
        lm_logits = model.lm_head(hidden_states)

        # Position t predicts token t + 1.
        shift_logits = lm_logits[:, :-1, :].contiguous()
        shift_labels = testenc[
            :, (i * model.seqlen):((i + 1) * model.seqlen)
        ][:, 1:]
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        neg_log_likelihood = loss.float() * model.seqlen
        nlls.append(neg_log_likelihood)

    # Leave the head modules on the CPU, as before.
    if model.model.norm is not None:
        model.model.norm = model.model.norm.cpu()
    model.lm_head = model.lm_head.cpu()

    ppl = torch.exp(torch.stack(nlls).sum() / (n_samples * model.seqlen))
    print(ppl.item())

    model.config.use_cache = use_cache


In [None]:
# Run configuration used by the cells below.
DEVICE = "cuda:0"  # device for the quantization and evaluation passes
MODEL = "../model/"  # checkpoint directory; also passed to get_loaders as `model`
SEED = 0  # seed forwarded to get_loaders
BITS = 4  # quantization bit width

In [None]:
# Build the model without decoder layers (they are loaded one by one during
# quantization) and switch it to eval mode.
model = initialize_layerless_llama(MODEL)
model.eval()

In [None]:
# Data from the project's get_loaders helper: the first result is passed to
# llama_gptq as calibration data, the second to llama_eval as test data.
dataloader, testloader = get_loaders(
    "wikitext2", seed=SEED, model=MODEL, seqlen=model.seqlen
)

In [None]:
# Quantize with GPTQ. To use round-to-nearest instead, comment this call out
# and enable the llama_nearest call below.
llama_gptq(MODEL, model, dataloader, BITS, DEVICE)
# llama_nearest(MODEL, model, BITS, DEVICE)

In [None]:
# Print the perplexity of the quantized model.
llama_eval(model, testloader, DEVICE)

GPTQx8: 5.93

NEARESTx8: 5.93

GPTQx4: 6.18

NEARESTx4: 6.60