In [1]:
import os
import torch
from tqdm import tqdm, trange
import json
from huggingface_hub import snapshot_download

snapshot_download(repo_id="yahma/llama-7b-hf", local_dir="./model")

def repack_llama(path):
    non_layers = {}
    first_state_dict = torch.load(os.path.join(path, "pytorch_model-00001-of-00002.bin"))
    non_layers["model.embed_tokens.weight"] = first_state_dict["model.embed_tokens.weight"]
    for i in trange(24):
        layer = {key: value for key, value in first_state_dict.items() if f"layers.{i}." in key}
        torch.save(layer, os.path.join(path, f"layer_{i}.bin"))
    del first_state_dict
    
    second_state_dict = torch.load(os.path.join(path, "pytorch_model-00002-of-00002.bin"))
    non_layers["lm_head.weight"] = second_state_dict["lm_head.weight"]
    non_layers["model.norm.weight"] = second_state_dict["model.norm.weight"]
    for i in trange(24, 32):
        layer = {key: value for key, value in second_state_dict.items() if f"layers.{i}." in key}
        torch.save(layer, os.path.join(path, f"layer_{i}.bin"))
    del second_state_dict
    
    torch.save(non_layers, os.path.join(path, f"non_layers.bin"))

repack_llama("./model")

  from .autonotebook import tqdm as notebook_tqdm
Downloading (…)27ca2ada75/README.md: 100%|██████████| 8.84k/8.84k [00:00<00:00, 3.62MB/s]

Downloading (…)ada75/.gitattributes: 100%|██████████| 1.48k/1.48k [00:00<00:00, 6.28MB/s]
Fetching 10 files:  10%|█         | 1/10 [00:00<00:04,  1.84it/s]
Downloading (…)cial_tokens_map.json: 100%|██████████| 72.0/72.0 [00:00<00:00, 22.3kB/s]
Downloading (…)model.bin.index.json: 100%|██████████| 26.8k/26.8k [00:00<00:00, 19.6MB/s]
Downloading (…)neration_config.json: 100%|██████████| 137/137 [00:00<00:00, 588kB/s]
Downloading (…)okenizer_config.json: 100%|██████████| 207/207 [00:00<00:00, 857kB/s]

[A

Downloading tokenizer.model: 100%|██████████| 500k/500k [00:00<00:00, 2.95MB/s]

[A

[A[A
[A

[A[A
[A
[A

[A[A
[A

[A[A
[A
[A

[A[A
[A

[A[A
[A

[A[A
[A

[A[A
[A

[A[A
[A

[A[A
[A

[A[A
[A

[A[A

[A[A
[A

[A[A
[A

[A[A

[A[A
[A
[A

[A[A
[A

[A[A
[A

[A[A
[A

[A[A
[A

[A[A
[A

[A[A


In [1]:
import os
import math
import time
import random
from tqdm import tqdm, trange

import torch
torch.set_num_threads(8)
from torch import nn
import torch.nn.functional as F

import transformers
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers import AutoTokenizer 
from datasets import load_dataset


def get_wikitext2(model, seed, seqlen, nsamples=128):
    traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
    testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')

    tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
    trainenc = tokenizer("\n\n".join(traindata['text']), return_tensors='pt')
    testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt')

    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        tar[:, :-1] = -100
        trainloader.append((inp, tar))
    return trainloader, testenc


def skip(*args, **kwargs):
    pass
torch.nn.init.kaiming_uniform_ = skip
torch.nn.init.uniform_ = skip
torch.nn.init.normal_ = skip


def initialize_layerless_llama(checkpoint_path):
    config = LlamaConfig.from_pretrained("yahma/llama-7b-hf")
    config.num_hidden_layers=0
    
    model = LlamaForCausalLM(config)
    model.load_state_dict(torch.load(os.path.join(checkpoint_path, "non_layers.bin")))
    model.seqlen = 2048
    
    return model.to(torch.float16)


def load_and_dispatch_a_layer(layer_idx, checkpoint_path, model: LlamaForCausalLM):
    config = transformers.AutoConfig.from_pretrained("yahma/llama-7b-hf")

    layer = LlamaDecoderLayer(config)
    layer_state_dict = torch.load(os.path.join(checkpoint_path, f"layer_{layer_idx}.bin"))
    layer_state_dict = {name[len(f"model.layers.{layer_idx}."):]: tensor for name, tensor in layer_state_dict.items()}
    layer.load_state_dict(layer_state_dict, strict=False)
    del layer_state_dict    
    
    model.model.layers.append(layer.to(torch.float16))


def get_scale_and_zero(x, maxq):
    tmp = torch.zeros(x.shape[0], device=x.device)
    xmin = torch.minimum(x.min(1)[0], tmp)
    xmax = torch.maximum(x.max(1)[0], tmp)

    shape = x.shape
    tmp = (xmin == 0) & (xmax == 0)
    xmin[tmp] = -1
    xmax[tmp] = +1

    scale = (xmax - xmin) / maxq
    zero = torch.round(-xmin / scale)

    shape = [-1] + [1] * (len(shape) - 1)
    scale = scale.reshape(shape)
    zero = zero.reshape(shape)
    return scale.to(x.dtype), zero.to(x.dtype)


def custom_quantize(x, bits: int):
    x = x.clone().detach()
    maxq = torch.tensor(2 ** bits - 1)
    scale, zero = get_scale_and_zero(x, maxq)
    
    q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
    return q.to(torch.uint8), scale, zero


class QuantizedLinear(nn.Module):
    def __init__(self, q, scale, zero, bias):
        super().__init__()
        self.q = nn.Parameter(q, requires_grad=False)
        self.scale = nn.Parameter(scale, requires_grad=False)
        self.zero = nn.Parameter(zero, requires_grad=False)
        
        if bias is not None:
            self.bias = nn.Parameter(bias.data.clone().detach())
        else:
            self.bias = None
    
    def forward(self, input):
        return F.linear(input, self.scale * self.q - self.scale * self.zero, self.bias)


  from .autonotebook import tqdm as notebook_tqdm


In [531]:
def gptq(x: torch.Tensor, bits: int, hessian: torch.Tensor, blocksize:int=128, percdamp:float=.01):
    dtype = x.dtype
    W = x.clone().detach()
    W = W.float()
    columns = W.shape[1]

    maxq = torch.tensor(2 ** bits - 1)
    scale, zero = get_scale_and_zero(W, maxq)

    H = hessian
    dead = torch.diag(H) == 0
    H[dead, dead] = 1
    W[:, dead] = 0

    # decrasing activation size
    perm = torch.argsort(torch.diag(H), descending=True)
    W = W[:, perm]
    H = H[perm, :][:, perm]
    invperm = torch.argsort(perm)

    Losses = torch.zeros_like(W)
    Q = torch.zeros(W.shape, dtype=torch.uint8, device=W.device)

    damp = percdamp * torch.mean(torch.diag(H))
    diag = torch.arange(columns, device=W.device)
    H[diag, diag] += damp
    H = torch.linalg.cholesky(H)
    H = torch.cholesky_inverse(H)
    H = torch.linalg.cholesky(H, upper=True)
    Hinv = H

    for i1 in range(0, columns, blocksize):
        i2 = min(i1 + blocksize, columns)
        count = i2 - i1

        W1 = W[:, i1:i2].clone()
        Q1 = torch.zeros(W1.shape, dtype=torch.uint8, device=W1.device)
        Err1 = torch.zeros_like(W1)
        Losses1 = torch.zeros_like(W1)
        Hinv1 = Hinv[i1:i2, i1:i2]

        for i in range(count):
            w = W1[:, i]
            d = Hinv1[i, i]

            q = torch.clamp(torch.round(w.unsqueeze(1) / scale) + zero, 0, maxq).flatten()
            Q1[:, i] = q.to(torch.uint8)
            q = scale.flatten() * q - scale.flatten() * zero.flatten()

            Losses1[:, i] = (w - q) ** 2 / d ** 2

            err1 = (w - q) / d
            W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
            Err1[:, i] = err1

        Q[:, i1:i2] = Q1
        Losses[:, i1:i2] = Losses1 / 2

        W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])

    torch.cuda.synchronize()
    return Q[:, invperm], scale.to(dtype), zero.to(dtype)

In [13]:
@torch.no_grad()
def llama_no_compression(checkpoint_path, model):
    print('Starting ...')
    # Load all the layers
    layers = model.model.layers
    for i in trange(32):
        load_and_dispatch_a_layer(i, checkpoint_path, model)
        
        
def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
    if type(module) in layers:
        return {name: module}
    res = {}
    for name1, child in module.named_children():
        res.update(find_layers(
            child, layers=layers, name=name + '.' + name1 if name != '' else name1
        ))
    return res


def replace_submodule(module, submodule_path, new_submodule):
    submodule_names = submodule_path.split(".")
    for submodule in submodule_names[:-1]:
        module = getattr(module, submodule)
    setattr(module, submodule_names[-1], new_submodule)


@torch.no_grad()
def llama_nearest(checkpoint_path, model, bits: int, device):
    print('Starting ...')
    
    # Load and quantize all the layers
    layers = model.model.layers
    for i in trange(32):
        load_and_dispatch_a_layer(i, checkpoint_path, model)

        layer = layers[i]
        linear_submodules = find_layers(layer)

        sequential_groups = [
            ['self_attn.k_proj', 'self_attn.v_proj', 'self_attn.q_proj'],
            ['self_attn.o_proj'],
            ['mlp.up_proj', 'mlp.gate_proj'],
            ['mlp.down_proj']
        ]
       
        for group_names in sequential_groups:
            current_linears_to_quantize = {n: linear_submodules[n] for n in group_names}
            for name, linear in current_linears_to_quantize.items():                
                q, scale, zero = custom_quantize(linear.weight.data.to(device), bits=bits)
                replace_submodule(layer, name, QuantizedLinear(q.cpu(), scale, zero, linear.bias))

        layers[i] = layer
        del layer
        torch.cuda.empty_cache()


In [14]:
def replace_submodule(module, submodule_path, new_submodule):
    submodule_names = submodule_path.split(".")
    for submodule in submodule_names[:-1]:
        module = getattr(module, submodule)
    setattr(module, submodule_names[-1], new_submodule)

        
@torch.no_grad()
def llama_gptq(checkpoint_path, model, dataloader, bits, device, n_samples=128):
    print('Starting ...')
    load_and_dispatch_a_layer(0, checkpoint_path, model)

    use_cache = model.config.use_cache
    model.config.use_cache = False
    layers = model.model.layers

    dtype = next(iter(model.parameters())).dtype
    inps = torch.zeros(
        (n_samples, model.seqlen, model.config.hidden_size), dtype=dtype
    )
    cache = {'i': 0, 'attention_mask': None}

    class Catcher(nn.Module):
        def __init__(self, module):
            super().__init__()
            self.module = module
        def forward(self, inp, **kwargs):
            inps[cache['i']] = inp
            cache['i'] += 1
            cache['attention_mask'] = kwargs['attention_mask']
            cache['position_ids'] = kwargs['position_ids']
            raise ValueError
    layers[0] = Catcher(layers[0])
    for batch in dataloader:
        try:
            model(batch[0])
        except ValueError:
            pass
    layers[0] = layers[0].module

    outs = torch.zeros_like(inps)
    attention_mask = cache['attention_mask'].to(device)
    position_ids = cache['position_ids'].to(device)

    print('Ready.')

    quantizers = {}
    for i in trange(32):
        if i != 0:
            load_and_dispatch_a_layer(i, checkpoint_path, model)
        layer = layers[i].to(device)
        linear_layers = find_layers(layer)
        
        sequential_groups = [
            ['self_attn.k_proj', 'self_attn.v_proj', 'self_attn.q_proj'],
            ['self_attn.o_proj'],
            ['mlp.up_proj', 'mlp.gate_proj'],
            ['mlp.down_proj']
        ]

       
        for names in sequential_groups:
            subset = {name: linear_layers[name] for name in names}

            hessians = {name: None for name in subset}
            num_samples = {name: 0 for name in subset}
            def accumulate_input(name):
                def tmp(_, inp, out):
                    inp = inp[0].data # ... x hidden_size
                    inp = inp.reshape((-1, inp.shape[-1])) # inputs x hidden_size
                    inp = inp.t().float() # hidden_size x inputs
                    num_samples[name] += 1
                    if hessians[name] is None:
                        hessians[name] = inp.matmul(inp.t())
                    else:
                        hessians[name] += inp.matmul(inp.t())
                return tmp
            handles = []
            for name in subset:
                handles.append(subset[name].register_forward_hook(accumulate_input(name)))
            for j in range(n_samples):
                outs[j] = layer(inps[j].unsqueeze(0).to(device), attention_mask=attention_mask, position_ids=position_ids)[0].cpu()
            for h in handles:
                h.remove()

            for name in subset:
                x = subset[name].weight.data
                bias = subset[name].bias
                q, scale, zero = gptq(x, bits, 2 * hessians[name] / num_samples[name])
                replace_submodule(layer, name, QuantizedLinear(q, scale, zero, bias))

        for j in range(n_samples):
            outs[j] = layer(inps[j].unsqueeze(0).to(device), attention_mask=attention_mask, position_ids=position_ids)[0].cpu()

        layers[i] = layer.cpu()
        del layer
        torch.cuda.empty_cache()

        inps, outs = outs, inps

    model.config.use_cache = use_cache
    
    return quantizers

In [16]:
class Catcher(nn.Module):
    def __init__(self, inps_dest):
        super().__init__()
        self.i = 0
        self.attention_mask = None
        self.position_ids = None
        self.inps_dest = inps_dest

    def forward(self, inp, **kwargs):
        self.inps_dest[self.i] = inp
        self.attention_mask = kwargs['attention_mask']
        self.position_ids = kwargs['position_ids']
        self.i += 1
        raise ValueError
    
    def get_the_catch(self):
        return self.attention_mask, self.position_ids

@torch.no_grad()
def llama_eval(model, testenc, device):
    print('Evaluating ...')
    
    input_ids = testenc.input_ids
    input_ids = input_ids[:, :(input_ids.shape[1] // model.seqlen) *  model.seqlen]
    input_ids = input_ids.reshape(input_ids.shape[1] // model.seqlen, model.seqlen)

    use_cache = model.config.use_cache
    model.config.use_cache = False
    
    total_nll = 0
    for batch in torch.tensor_split(input_ids, 4):
        n_samples = batch.shape[0]
        dtype = next(iter(model.parameters())).dtype
        inps = torch.zeros(
            (n_samples, model.seqlen, model.config.hidden_size), dtype=dtype
        ).to(device)
        outs = torch.zeros_like(inps)

        # Collect the first layer inputs
        catcher = Catcher(inps)
        original_layers = model.model.layers
        model.model.layers = nn.ModuleList((catcher,))
        for sample in batch:
            try:
                model(sample.unsqueeze(0))
            except ValueError:
                pass
        attention_mask, position_ids = catcher.get_the_catch()
        model.model.layers = original_layers

        # Forward pass through the layers
        layers = model.model.layers
        attention_mask = attention_mask.to(device)
        position_ids = position_ids.to(device)
        for i in trange(len(layers)):
            layer = layers[i].to(device)

            for j in range(n_samples):
                outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]

            layers[i] = layer.cpu()
            del layer
            torch.cuda.empty_cache()

            inps, outs = outs, inps

        # Calculate PPL
        testenc = testenc.to(device)
        nlls = []
        for i in range(n_samples):
            hidden_states = inps[i].unsqueeze(0)
            if model.model.norm is not None:
                model.model.norm = model.model.norm.to(device)
                hidden_states = model.model.norm(hidden_states)
                model.model.norm = model.model.norm.cpu()

            model.lm_head = model.lm_head.to(device)
            lm_logits = model.lm_head(hidden_states)
            model.lm_head = model.lm_head.cpu()

            shift_logits = lm_logits[:, :-1, :]
            shift_labels = batch[i, 1:]
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(device))
            total_nll += float(loss) * model.seqlen
    
    ppl = math.exp(total_nll / input_ids.numel())
    print(ppl)
    model.config.use_cache = use_cache


In [17]:
DEVICE = "cuda:0"
MODEL = "../model/"
SEED = 0
BITS = 4

In [18]:
model = initialize_layerless_llama(MODEL)
model.eval()

Downloading (…)lve/main/config.json: 100%|██████████| 472/472 [00:00<00:00, 977kB/s]


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096, padding_idx=0)
    (layers): ModuleList()
    (norm): LlamaRMSNorm()
  )
  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)

In [19]:
dataloader, testloader = get_wikitext2(MODEL, SEED, model.seqlen)

You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. If you see this, DO NOT PANIC! This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=True`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [None]:
# llama_gptq(MODEL, model, dataloader, BITS, DEVICE)
# llama_nearest(MODEL, model, BITS, DEVICE)
# llama_no_compression(MODEL, model)

In [None]:
# llama_eval(model, testloader, DEVICE)

FP18: 5.67

GPTQx4: 5.94

NEARESTx4: 6.29

## BONUS: QUIK

In [1272]:
def quik(x: torch.Tensor, bits: int, hessian: torch.Tensor, blocksize:int=128, percdamp:float=.01, n_outliers=128):
    dtype = x.dtype
    W = x.clone().detach()
    W = W.float()

    H = hessian
    dead = torch.diag(H) == 0
    H[dead, dead] = 1
    W[:, dead] = 0
    # decrasing activation size
    perm = torch.argsort(torch.diag(H), descending=True)
    W = W[:, perm]
    H = H[perm, :][:, perm]

    # Process outliers
    outlier_weight = W[:,:n_outliers]
    W = W[:,n_outliers:]
    columns = W.shape[1]
    H = H[n_outliers:,:][:,n_outliers:]

    maxq = torch.tensor(2 ** bits - 1)
    scale, zero = get_scale_and_zero(W, maxq)

    Losses = torch.zeros_like(W)
    Q = torch.zeros(W.shape, dtype=torch.uint8, device=W.device)

    damp = percdamp * torch.mean(torch.diag(H))
    diag = torch.arange(columns, device=W.device)
    H[diag, diag] += damp
    H = torch.linalg.cholesky(H)
    H = torch.cholesky_inverse(H)
    H = torch.linalg.cholesky(H, upper=True)
    Hinv = H

    for i1 in range(0, columns, blocksize):
        i2 = min(i1 + blocksize, columns)
        count = i2 - i1

        W1 = W[:, i1:i2].clone()
        Q1 = torch.zeros(W1.shape, dtype=torch.uint8, device=W1.device)
        Err1 = torch.zeros_like(W1)
        Losses1 = torch.zeros_like(W1)
        Hinv1 = Hinv[i1:i2, i1:i2]

        for i in range(count):
            w = W1[:, i]
            d = Hinv1[i, i]

            q = torch.clamp(torch.round(w.unsqueeze(1) / scale) + zero, 0, maxq).flatten()
            Q1[:, i] = q.to(torch.uint8)
            q = scale.flatten() * q - scale.flatten() * zero.flatten()

            Losses1[:, i] = (w - q) ** 2 / d ** 2

            err1 = (w - q) / d
            W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
            Err1[:, i] = err1

        Q[:, i1:i2] = Q1
        Losses[:, i1:i2] = Losses1 / 2

        W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
    return Q, scale.to(dtype), zero.to(dtype), outlier_weight.to(dtype), perm


class QuikLinear(nn.Module):
    def __init__(self, quantized_weight, weight_scale, weight_zero, outlier_weight, bias, bits: int, perm):
        super().__init__()
        self.bits = bits
        self.perm = perm
        self.n_outliers = outlier_weight.shape[1]

        self.quantized_weight = nn.Parameter(quantized_weight, requires_grad=False)
        self.weight_scale = nn.Parameter(weight_scale, requires_grad=False)
        self.weight_zero = nn.Parameter(weight_zero, requires_grad=False)

        self.outlier_weight = nn.Parameter(outlier_weight, requires_grad=False)
        self.weights_reduced = (self.quantized_weight * self.weight_scale).sum(axis=1)

        if bias is not None:
            self.bias = nn.Parameter(bias.data.clone().detach())
        else:
            self.bias = None

    def forward(self, input):
        input = input[...,self.perm]
        input_quantized, input_scale, input_zero = custom_quantize(input[...,self.n_outliers:], self.bits)
        inputs_reduced = (input_quantized * input_scale).sum(axis=-1)

        quantized_result = F.linear(input_quantized.to(input.dtype), self.quantized_weight.to(input.dtype))
        quantized_result = quantized_result * self.weight_scale.T * input_scale

        quantized_result += (input_zero * input_scale) @ (self.weight_zero * self.weight_scale).T
        quantized_result -= (input_zero * input_scale) @ self.weights_reduced.unsqueeze(0)
        quantized_result -= inputs_reduced.unsqueeze(-1) @ (self.weight_zero * self.weight_scale).T

        outliers_result = F.linear(input[...,:self.n_outliers], self.outlier_weight, self.bias)

        return quantized_result + outliers_result


@torch.no_grad()
def llama_quik(checkpoint_path, model, dataloader, bits, device, n_samples=128):
    print('Starting ...')
    load_and_dispatch_a_layer(0, checkpoint_path, model)

    use_cache = model.config.use_cache
    model.config.use_cache = False
    layers = model.model.layers

    dtype = next(iter(model.parameters())).dtype
    inps = torch.zeros(
        (n_samples, model.seqlen, model.config.hidden_size), dtype=dtype, device=device
    )
    cache = {'i': 0, 'attention_mask': None}

    class Catcher(nn.Module):
        def __init__(self, module):
            super().__init__()
            self.module = module
        def forward(self, inp, **kwargs):
            inps[cache['i']] = inp
            cache['i'] += 1
            cache['attention_mask'] = kwargs['attention_mask']
            cache['position_ids'] = kwargs['position_ids']
            raise ValueError
    layers[0] = Catcher(layers[0])
    for batch in dataloader:
        try:
            model(batch[0])
        except ValueError:
            pass
    layers[0] = layers[0].module

    outs = torch.zeros_like(inps)
    attention_mask = cache['attention_mask'].to(device)
    position_ids = cache['position_ids'].to(device)

    print('Ready.')

    quantizers = {}
    for i in trange(32):
        if i != 0:
            load_and_dispatch_a_layer(i, checkpoint_path, model)
        layer = layers[i].to(device)
        linear_layers = find_layers(layer)

        sequential_groups = [
            ['self_attn.k_proj', 'self_attn.v_proj', 'self_attn.q_proj'],
            ['self_attn.o_proj'],
            ['mlp.up_proj', 'mlp.gate_proj'],
            ['mlp.down_proj']
        ]


        for names in sequential_groups:
            subset = {name: linear_layers[name] for name in names}

            hessians = {name: None for name in subset}
            num_samples = {name: 0 for name in subset}
            def accumulate_input(name):
                def tmp(_, inp, out):
                    inp = inp[0].data # ... x hidden_size
                    inp = inp.reshape((-1, inp.shape[-1])) # inputs x hidden_size
                    inp = inp.t().float() # hidden_size x inputs
                    num_samples[name] += 1
                    if hessians[name] is None:
                        hessians[name] = inp.matmul(inp.t())
                    else:
                        hessians[name] += inp.matmul(inp.t())
                return tmp
            handles = []
            for name in subset:
                handles.append(subset[name].register_forward_hook(accumulate_input(name)))
            for j in range(n_samples):
                outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
            for h in handles:
                h.remove()

            for name in subset:
                x = subset[name].weight.data
                bias = subset[name].bias
                q, scale, zero, outlier_weight, perm = quik(x, bits, 2 * hessians[name] / num_samples[name])
                replace_submodule(layer, name, QuikLinear(q, scale, zero, outlier_weight, bias, bits, perm))

        for j in range(n_samples):
            outs[j] = layer(inps[j].unsqueeze(0).to(device), attention_mask=attention_mask, position_ids=position_ids)[0].cpu()

        layers[i] = layer.cpu()
        del layer
        torch.cuda.empty_cache()

        inps, outs = outs, inps

    model.config.use_cache = use_cache


In [None]:
llama_quik(MODEL, model, dataloader, BITS, DEVICE)
llama_eval(model, testloader, DEVICE)

In [1273]:
N_OUTLIERS = 128

layer = nn.Linear(256, 512)
# layer.weight.data = torch.rand(size=(512, 256))
inputs = [torch.rand((1, 2, 256)) for _ in range(128)]
unquantized = layer.weight.data.clone()

In [1274]:
def get_accumulate_input_fn(name, hessians, num_samples):
    def tmp(_, inp, out):
        inp = inp[0].data # ... x hidden_size
        inp = inp.reshape((-1, inp.shape[-1])) # inputs x hidden_size
        inp = inp.t().float() # hidden_size x inputs
        num_samples[name] += 1
        if hessians[name] is None:
            hessians[name] = inp.matmul(inp.t())
        else:
            hessians[name] += inp.matmul(inp.t())
    return tmp

hessians = {"": None}
num_samples = {"": 0}
fn = get_accumulate_input_fn("", hessians, num_samples)
for input in inputs:
    fn(None, input, None)
    

In [1275]:
quantized_weight, scale, zero, outlier_weight, perm = quik(layer.weight.data, 4, 2 * hessians[""] / num_samples[""], n_outliers=N_OUTLIERS)
shit = QuikLinear(quantized_weight, scale, zero, outlier_weight, layer.bias.data.clone(), 8, perm)
torch.pow(sum(shit(input) for input in inputs) - sum(layer(input) for input in inputs), 2).mean() ** (1/2)

tensor(0.1627, grad_fn=<PowBackward0>)

In [1276]:
gptq_weight, gptq_scale, gptq_zero = gptq(layer.weight.data, 4, 2 * hessians[""] / num_samples[""])
piss = QuantizedLinear(gptq_weight, gptq_scale, gptq_zero, layer.bias.data.clone())
torch.pow(sum(piss(input) for input in inputs) - sum(layer(input) for input in inputs), 2).mean() **(1/2)

tensor(0.1708, grad_fn=<PowBackward0>)

In [1277]:
stupid_weight, stupid_scale, stupid_zero = custom_quantize(layer.weight.data, 4)
torch.pow(sum(F.linear(input, stupid_weight * stupid_scale - stupid_zero * stupid_scale) for input in inputs) - sum(layer(input) for input in inputs), 2).mean() ** (1/2)

tensor(5.2337, grad_fn=<PowBackward0>)