# Installing the Dependencies

In [None]:
%%capture
!pip install transformers
!pip install sentencepiece
!pip install datasets

# Quantizing Matrices Row-Wise

In [416]:
import os
import math
import time
import random
from tqdm import tqdm, trange

import torch
from torch import Tensor
from torch import nn
import torch.nn.functional as F

import transformers
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers import AutoTokenizer
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


### 1: Basic Quantization

**Mapping the values to the allowed range**

Quantization is the process of mapping input values from a large set to output values in a smaller set. For instance, if we consider 4-bit
quantization, our values are represented by $4$ bits, meaning we can represent values between 0 and $2^4-1=15$.

 * To produce the quantized representation, we need to be able to map the matrix values to and from this range.
 * For reasons that become important later, we will perform this mapping independently for each matrix row.
 * We will parametrize the mapping like this: $out = \frac{in}{scale} + zero$, where $scale$ and $zero$ are row-wise constants.
 * For a matrix of size `(m, k)` ($m$ rows, $k$ columns) we will aggregate those parameters into two vectors `scale` and `zero` of size `(m, 1)`.

**Task 1.1:** Complete the function below to perform this mapping:

In [248]:
def get_scale_and_zero(x: Tensor, max_abs: float) -> tuple[Tensor, Tensor]:
    """Compute row-wise affine quantization constants for a 2-D tensor.

    For x of shape (m, k) and max_abs > 0, returns tensors `scale` and `zero` of shape
    (m, 1) such that 0 <= x / scale + zero <= max_abs for every element: each row's
    [min, max] interval is mapped linearly onto [0, max_abs].

    Rows whose values are all equal (min == max) would give scale == 0. For those rows
    the interval is widened to [value - 1, value + 1], so the scale stays positive and
    the value still lies inside the representable range.

    Args:
        x: tensor of shape (m, k).
        max_abs: largest quantized value (e.g. 2 ** bits - 1).

    Returns:
        (scale, zero), both of shape (m, 1) and with x's dtype.
    """
    xmin = x.min(-1)[0]
    xmax = x.max(-1)[0]

    # Compute the mask once: re-evaluating `xmin == xmax` after patching xmin would
    # miss exactly the rows that were just patched and leave their xmax untouched.
    degenerate = xmin == xmax
    xmin = torch.where(degenerate, xmin - 1, xmin)
    xmax = torch.where(degenerate, xmax + 1, xmax)

    scale = (xmax - xmin) / max_abs
    zero = -xmin / scale

    return scale.unsqueeze(-1).to(x.dtype), zero.unsqueeze(-1).to(x.dtype)


In [249]:
# Testing your code

x = torch.arange(512 * 1024).reshape(512, 1024).float()
scale, zero = get_scale_and_zero(x, 15)

assert scale.shape == (512, 1), "scale is wrong shape"
assert zero.shape == (512, 1), "zero is wrong shape"

# Every row spans 1023 values, so the mapped interval must be ~1023 wide.
interval_width = scale * 15
assert torch.all(interval_width <= 1023.1), "Scale can't be that large. The resulting interval is too wide"
assert torch.all(interval_width >= 1022.9), "Scale shouldn't be that small. The resulting interval is too narrow"

mapped = x / scale + zero
assert torch.all(-0.001 < mapped) and torch.all(mapped < 15 + 0.001)
print("All tests passed!")

All tests passed!


**Quantization**

Having mapped the values into the allowed range, we can simply round them to obtain the quantized matrix. Complete the functions below to perform row-wise quantization. Note that:
 * `measure_and_quantize` takes the matrix `x` and `bits` - the number of bits to quantize to. Calculate the allowed quantized values range yourself (hint: look a few cells above).
 * Use `get_scale_and_zero` to obtain the row-wise quantization constants, then use them to map the matrix values to the required range.
 * `torch.clamp(...)` the quantized values to ensure that they are in the range.
 * The function returns the quantized matrix, as well as the quantization constants, because we'll need them to dequantize the matrix.

**Task 1.2:** Complete the function below to perform quantization:

In [256]:
def quantize(x: Tensor, scale: Tensor, zero: Tensor, max_abs: float) -> Tensor:
    """Map x onto the integer grid [0, max_abs] using the affine constants scale/zero.

    Computes clamp(round(x / scale + zero), 0, max_abs) and returns it as uint8,
    so max_abs must not exceed 255.
    """
    grid_position = x / scale + zero
    codes = grid_position.round().clamp(0, max_abs)
    return codes.to(torch.uint8)


def dequantize(x: Tensor, scale: Tensor, zero: Tensor) -> Tensor:
    """Map integer grid codes back to real values (inverse of `quantize` up to rounding).

    Evaluates scale * x - scale * zero, with scale/zero broadcast against x
    (row-wise constants of shape (m, 1) for an (m, k) matrix of codes).
    """
    scaled_codes = x * scale
    scaled_offset = zero * scale
    return scaled_codes - scaled_offset


def measure_and_quantize(x: Tensor, bits: float) -> tuple[Tensor, Tensor, Tensor]:
    """Quantize a matrix row-wise to `bits` bits with round-to-nearest.

    The quantized matrix is stored as uint8, but its values lie in the
    uint<bits> range [0, 2 ** bits - 1]. torch lacks support for lower-bit
    integers, which is why bits must not exceed 8.

    Args:
        x: tensor of shape (m, k).
        bits: number of bits per quantized value, 0 < bits <= 8.

    Returns:
        (quantized_x, scale, zero): uint8 codes of shape (m, k) and the row-wise
        constants of shape (m, 1) needed by `dequantize`.

    Raises:
        ValueError: if bits is outside (0, 8]. Larger values would silently wrap
            around when cast to uint8, and bits <= 0 leaves no representable range.
    """
    if not 0 < bits <= 8:
        raise ValueError(f"bits must be in (0, 8], got {bits}")
    max_abs = 2 ** bits - 1
    scale, zero = get_scale_and_zero(x, max_abs)
    # `quantize` already returns uint8 codes.
    return quantize(x, scale, zero, max_abs), scale, zero

In [257]:
# Testing your code

x = torch.arange(512 * 1024).reshape(512, 1024).float()
quantized_x, scale, zero = measure_and_quantize(x, 4)

assert quantized_x.shape == x.shape, "Shape of quantized_x is incorrect"
assert scale.shape == (512, 1), "Shape of scale is incorrect"
assert zero.shape == (512, 1), "Shape of zero is incorrect"
assert torch.all(quantized_x >= 0) and torch.all(quantized_x <= 15) and torch.any(quantized_x == 15), "wrong quantized_x values range"
assert torch.allclose(x, dequantize(quantized_x, scale, zero), atol=50), "Dequantized values are too far from the original values"
print("All tests passed!")

All tests passed!


**Using the quantized matrix**

To actually use the matrix, we'll have to map its values back into their original form. 

In [258]:
class QuantizedLinear(nn.Module):
    """Drop-in replacement for `nn.Linear` that stores its weight quantized.

    The weight is kept as integer codes plus row-wise `scale`/`zero` constants,
    all registered as frozen parameters, and is dequantized on every forward pass.
    The bias (if any) is copied into a regular, trainable parameter.
    """

    def __init__(self, quantized_weight, scale, zero, bias):
        super().__init__()
        frozen = {"quantized_weight": quantized_weight, "scale": scale, "zero": zero}
        for attr_name, tensor in frozen.items():
            setattr(self, attr_name, nn.Parameter(tensor, requires_grad=False))
        if bias is None:
            self.bias = None
        else:
            self.bias = nn.Parameter(bias.data.clone().detach())

    def forward(self, input):
        weight = dequantize(self.quantized_weight, self.scale, self.zero)
        return F.linear(input, weight, self.bias)
    

This class will be used as a replacement for `nn.Linear`. It holds the quantized weight and only dequantizes it during its forward passes.

### Task 2: GPTQ

GPTQ is a state-of-the-art algorithm for post-training quantization of deep learning models. It works by sequentially quantizing the model's linear layer weights.

Although it produces outputs in the same format as Round To Nearest (RTN) quantization, it makes a key observation to boost its end performance:
 * It is layer-input aware, meaning it optimizes the quantized matrix to perform best on the inputs typically encountered by that layer.
More formally, the problem can be formulated as:
$$
W_q = \arg\min_{\widehat{W}}\|XW^T - X\widehat{W}^T\|_2^2,
$$
where
 * $X$ is the input matrix of shape `(..., IN)`.
 * $XW^T$ is the unquantized output of shape `(..., OUT)`. We think of the norm above as taking a sum over those (...) dimensions.
 * $W$ is the unquantized weight of shape `(OUT, IN)`.
 * $\widehat{W}$ is the quantized weight taken from some quantization grid.

One can notice that the expression above decomposes over the rows of $W$ and $\widehat{W}$, meaning we can solve it for each row independently (and in parallel). This is the reason why we're working with row-wise quantization in the first place. Notice that the quantization grid only depends on the min/max values within the row and not on the quantization process, so we can think of it as fixed.

For a single row, however, the dimension of the optimization problem is still `IN`, and because the solution must lie on a discrete grid, it is too expensive to solve exactly. The algorithm therefore solves it iteratively.

Let us consider a vector of full-precision weights $F$ (the not-yet-quantized weights of a row) and the corresponding set of inputs $X_F$ (the columns of $X$ that multiply those weights). The corresponding objective is quadratic with Hessian
$$
H_F = 2X_F^TX_F.
$$
The algorithm can be described like this:
 * Do the following steps until $F$ is fully quantized:
    1. Sample one element from $F$ randomly. Denote it by $F_i$.
    2. Quantize the coordinate by projecting it onto the quantization grid: $Q_i = quant(F_i)$.
    3. Update all of the remaining weights $F_: = F_: - \frac{F_i - quant(F_i)}{\left[H_F^{-1}\right]_{ii}}\cdot\left[H_F^{-1}\right]_{i,:}$.
    4. Exclude $i$ from $F$.

It uses the inverse Hessian to slightly tune the remaining unquantized weights to mitigate the quantization error.

There are a few more ideas that make this algorithm much faster:
 1. We can represent the random order of quantization (sampling of $i$) by permuting the row in advance, and then iterating over the row elements in order.
 $$
   F_{i:} = F_{i:} - \frac{F_{i} - quant(F_{i})}{\left[H_F^{-1}\right]_{ii}}\cdot\left[H_F^{-1}\right]_{i,i:}
 $$
 2. The problem is row-wise independent, meaning that we can use the same permutation for every row and perform those operations in a vectorized fashion for all the rows at the same time (the update is the outer product of the error column and a row of the inverse Hessian).
 $$
   F_{:,i:} = F_{:,i:} - \frac{F_{:,i} - quant(F_{:,i})}{\left[H_F^{-1}\right]_{ii}}\left[H_F^{-1}\right]_{i,i:}\quad\leftarrow\ \textbf{you'll have to code this}
 $$ 
 
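 3. We don't actually need to recompute the inverse Hessian after every step. At the $i$-th step we only need the $i$-th row of the current inverse Hessian. With some linear algebra, all of those rows can be precomputed in advance, up to a scaling that cancels out in the update: they are the rows of the upper-triangular Cholesky factor of $H^{-1}$. So we replace
 $$
  H^{-1} \leftarrow \mathrm{Cholesky}(H^{-1})^T,
 $$
 that is, by the upper-triangular matrix $U$ such that $H^{-1} = U^TU$.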

 4. We don't need to tune all the remaining unquantized values right away. We can only apply the updates for the closest elements right away and accumulate all the other updates to apply them only once in a while. 
 
    We'll do this in blocks of fixed size, applying the updates inside of those blocks and updating the weights outside only when we're done with the block. To accumulate those updates, we'll collect the scaled quantization error
    $$
      Err_{:,i} =\frac{F_{:,i} - quant(F_{:,i})}{\left[H_F^{-1}\right]_{ii}}\text{ for all }i\text{ in block}.
    $$

**GPTQ within blocks**

Implement GPTQ within the block. Iterate over the columns in ordered vector fashion, quantizing them one by one and updating all the remaining columns within the block.

Return the quantized weight as well as the matrix of quantization errors that we'll need to tune the unquantized weights outside of the block.

**Task 2.1:** Implement GPTQ block

In [398]:
def gptq_block(block_weight: Tensor, block_hessian_inverse: Tensor, scale: Tensor, zero: Tensor, max_abs: float) -> tuple[Tensor, Tensor]:
    """Quantize one block of columns with GPTQ, propagating errors inside the block.

    Columns are processed left to right. Each column is rounded onto the row-wise grid
    given by `scale`/`zero`. Its quantization error, divided by the matching diagonal
    entry of `block_hessian_inverse`, is subtracted from the current and following
    columns of the block, weighted by that column's row of `block_hessian_inverse`.

    NOTE: `block_weight` is updated in place; pass a copy if its values are still needed.

    Args:
        block_weight (Tensor): weight to quantize of shape (OUT, BLOCK_SIZE)
        block_hessian_inverse (Tensor): Cholesky inverse Hessian. Upper triangular of shape (BLOCK_SIZE, BLOCK_SIZE)
        scale (Tensor): row-wise quantization constants of shape (OUT, 1)
        zero (Tensor): row-wise quantization constants of shape (OUT, 1)
        max_abs (float): quantized values must lie in [0, max_abs]

    Returns:
        tuple[Tensor, Tensor]: uint8 quantized weight of shape (OUT, BLOCK_SIZE) and the
        scaled quantization error of the same shape, used to update columns outside the block.
    """
    num_columns = block_weight.shape[1]
    codes = torch.zeros(block_weight.shape, dtype=torch.uint8, device=block_weight.device)
    errors = torch.zeros_like(block_weight)

    for col in range(num_columns):
        # Snapshot the column before the in-place update below modifies it.
        current = block_weight[:, col:col + 1].clone()
        pivot = block_hessian_inverse[col, col]

        current_codes = quantize(current, scale, zero, max_abs)
        codes[:, col:col + 1] = current_codes

        error = (current - dequantize(current_codes, scale, zero)) / pivot
        errors[:, col:col + 1] = error
        block_weight[:, col:] -= error @ block_hessian_inverse[col:col + 1, col:]

    return codes, errors


In [399]:
# Testing your code

torch.manual_seed(0)
w = torch.randn(8, 16)
scale, zero = get_scale_and_zero(w, 15)

# With an identity inverse Hessian no error is propagated, so GPTQ reduces to round-to-nearest.
q_block, err_block = gptq_block(w.clone(), torch.eye(16), scale, zero, 15)
assert torch.equal(q_block, quantize(w, scale, zero, 15)), "GPTQ with an identity Hessian should match RTN"
assert torch.allclose(err_block, w - dequantize(q_block, scale, zero)), "Wrong scaled quantization error"

# The error of column 0 (0.3) must reach column 1: 1.6 - 0.3 * 0.5 = 1.45 rounds to 1, not 2.
q_pair, err_pair = gptq_block(
    torch.tensor([[0.3, 1.6]]),
    torch.tensor([[1.0, 0.5], [0.0, 1.0]]),
    torch.ones(1, 1),
    torch.zeros(1, 1),
    15,
)
assert q_pair.tolist() == [[0, 1]], "Quantization error wasn't propagated to the following columns"
assert torch.allclose(err_pair, torch.tensor([[0.3, 0.45]])), "Wrong scaled quantization error"
print("All tests passed!")

Now we can implement the full algorithm: 
 * Get row-wise quantization constants.
 * Randomly permute the weight columns. Think about how you'd have to shuffle the Hessian as well.
 * Process the Hessian to obtain the precomputed inverse Hessian.
 * Iterate over the columns in blocks:
    * Get the next block and quantize it.
    * Tune all the following blocks to mitigate the quantization error.
      $$
         F_{:,block\_end:} = F_{:,block\_end:} - Err_{:,block\_start:block\_end}\left[H_F^{-1}\right]_{block\_start:block\_end,\,block\_end:}\quad\text{(a matrix product)}
      $$

In [400]:
def _gptq_inverse_hessian_factor(hessian: torch.Tensor, percdamp: float) -> torch.Tensor:
    """Return the upper Cholesky factor U of the dampened inverse Hessian.

    Adds `percdamp * mean(diag(hessian))` to the diagonal of `hessian` IN PLACE,
    inverts it via its Cholesky factorization, and returns the upper-triangular U
    with (H + damp * I)^-1 = U^T U.
    """
    damp = percdamp * torch.mean(torch.diag(hessian))
    diag = torch.arange(hessian.shape[0], device=hessian.device)
    hessian[diag, diag] += damp
    hessian_inverse = torch.cholesky_inverse(torch.linalg.cholesky(hessian))
    return torch.linalg.cholesky(hessian_inverse, upper=True)


def gptq(weight: torch.Tensor, bits: int, hessian: torch.Tensor, blocksize:int=128, percdamp:float=.01, generator=None):
    """Quantize `weight` row-wise to `bits` bits with GPTQ.

    Args:
        weight: tensor of shape (OUT, IN); not modified.
        bits: number of bits per quantized value.
        hessian: (IN, IN) Hessian of the layer-wise reconstruction objective; not modified.
        blocksize: number of columns handled per `gptq_block` call. Errors are propagated
            to the columns right of the current block once the block is done.
        percdamp: dampening added to the Hessian diagonal, as a fraction of its mean,
            to keep the matrix positive definite.
        generator: optional CPU `torch.Generator` for the random column order; pass a
            seeded one for reproducible results (default: torch's global generator).

    Returns:
        (quantized_weight, scale, zero): uint8 codes of shape (OUT, IN) in the original
        column order, and row-wise constants of shape (OUT, 1) in weight's dtype.
    """
    dtype = weight.dtype
    weight = weight.detach().clone().float()
    # Work on a private float32 copy so the patches below don't leak into the caller's tensor.
    hessian = hessian.detach().clone().float()
    num_columns = weight.shape[1]

    # Inputs that are always zero carry no information: give them a unit Hessian
    # diagonal (keeps the matrix invertible) and zero weight.
    dead = torch.diag(hessian) == 0
    hessian[dead, dead] = 1
    weight[:, dead] = 0

    # Row-wise quantization grid; it stays fixed for the whole run.
    max_abs = 2 ** bits - 1
    scale, zero = get_scale_and_zero(weight, max_abs)

    # Random column order: permute the weight columns and both Hessian axes consistently.
    perm = torch.randperm(num_columns, generator=generator)
    invperm = torch.argsort(perm)
    weight = weight[:, perm]
    hessian = hessian[perm][:, perm]

    hessian_inverse = _gptq_inverse_hessian_factor(hessian, percdamp)

    quantized_weight = torch.zeros(weight.shape, dtype=torch.uint8, device=weight.device)
    for block_start in range(0, num_columns, blocksize):
        block_end = min(block_start + blocksize, num_columns)

        quantized_block_weight, block_error = gptq_block(
            weight[:, block_start:block_end].clone(),
            hessian_inverse[block_start:block_end, block_start:block_end],
            scale,
            zero,
            max_abs,
        )
        quantized_weight[:, block_start:block_end] = quantized_block_weight

        # Apply this block's accumulated updates to every column to its right.
        weight[:, block_end:] -= block_error.matmul(hessian_inverse[block_start:block_end, block_end:])

    return quantized_weight[:, invperm], scale.to(dtype), zero.to(dtype)

In [401]:
# Testing your code

torch.manual_seed(0)
weight = torch.randn(16, 64)
q_rtn, scale_rtn, zero_rtn = measure_and_quantize(weight, 4)

# A diagonal Hessian means independent inputs: there is nothing to compensate, so GPTQ must match RTN.
q_diag, scale_diag, zero_diag = gptq(weight, 4, torch.eye(64))
assert q_diag.shape == weight.shape and q_diag.dtype == torch.uint8, "Wrong quantized weight shape or dtype"
assert torch.equal(q_diag, q_rtn), "GPTQ with a diagonal Hessian should match round-to-nearest"
assert torch.allclose(scale_diag, scale_rtn) and torch.allclose(zero_diag, zero_rtn), "Wrong quantization constants"

# With correlated calibration inputs GPTQ should beat RTN on the layer output error.
inputs = torch.randn(512, 64) @ torch.randn(64, 64)
hessian = 2 * inputs.t() @ inputs / inputs.shape[0]
q_gptq, scale_gptq, zero_gptq = gptq(weight, 4, hessian.clone())
gptq_error = (inputs @ (weight - dequantize(q_gptq, scale_gptq, zero_gptq)).t()).pow(2).sum()
rtn_error = (inputs @ (weight - dequantize(q_rtn, scale_rtn, zero_rtn)).t()).pow(2).sum()
assert gptq_error < rtn_error, "GPTQ should reduce the output error compared to RTN"
print("All tests passed!")

# LLM Quantization

## Preparations

Run all the cells in this subsection to download and prepare the model and the data

### Download and convert the model

Run the code below to download the model checkpoint and repack it so that we can load the layers one by one.
 * Each layer $i \in [0, 31]$ is saved in a separate file `"./model/layer_{i}.bin"`
 * Everything outside of those layers (embeddings, lm_head, etc.) is saved in `"./model/non_layers.bin"`

In [None]:
# Idempotent and portable, unlike `!mkdir model`, which fails on re-runs.
os.makedirs("model", exist_ok=True)

In [None]:
from huggingface_hub import snapshot_download

snapshot_download(repo_id="yahma/llama-7b-hf", local_dir="./model")

def repack_llama(path):
    """Split the two-shard LLaMA-7B checkpoint in `path` into per-layer files.

    Writes `layer_{i}.bin` for i in 0..31 and `non_layers.bin` (token embeddings,
    final norm and lm_head) next to the shards. The split is hard-coded for this
    checkpoint: layers 0-23 and the embeddings are taken from the first shard;
    layers 24-31, lm_head and the final norm from the second.
    """
    # (shard file, [first, end) layer indices, non-layer keys taken from that shard)
    shard_plan = [
        ("pytorch_model-00001-of-00002.bin", (0, 24), ["model.embed_tokens.weight"]),
        ("pytorch_model-00002-of-00002.bin", (24, 32), ["lm_head.weight", "model.norm.weight"]),
    ]
    non_layers = {}
    for shard_file, (first_layer, end_layer), extra_keys in shard_plan:
        # NOTE(review): torch.load unpickles the file; only use it on trusted checkpoints.
        shard = torch.load(os.path.join(path, shard_file))
        for key in extra_keys:
            non_layers[key] = shard[key]
        for i in trange(first_layer, end_layer):
            prefix = f"layers.{i}."
            layer = {key: value for key, value in shard.items() if prefix in key}
            torch.save(layer, os.path.join(path, f"layer_{i}.bin"))
            # Drop the saved tensors from the shard right away to limit memory use.
            for key in layer:
                del shard[key]
        del shard

    torch.save(non_layers, os.path.join(path, "non_layers.bin"))

repack_llama("./model")

### Dispatching the model

To properly quantize the model we'll need two functions.
 1. `initialize_layerless_llama` creates a llama model without any layers, but correct weights otherwise
 2. `load_and_dispatch_a_layer` loads a layer and inserts it into the model after the last layer

In [435]:
def skip(*args, **kwargs):
    """Accept any arguments and do nothing; used to disable weight initializers."""


# Constructing modules would otherwise randomly initialize weights (e.g. in
# nn.Linear.reset_parameters) that load_state_dict overwrites right afterwards.
for _init_name in ("kaiming_uniform_", "uniform_", "normal_"):
    setattr(torch.nn.init, _init_name, skip)
del _init_name


def initialize_layerless_llama(checkpoint_path, config_name="yahma/llama-7b-hf", seqlen=2048):
    """Build a LLaMA model with zero decoder layers and load its non-layer weights.

    The model is created from the config of `config_name` with `num_hidden_layers`
    set to 0, so it only contains the embeddings, the final norm and lm_head. Their
    weights are loaded from `<checkpoint_path>/non_layers.bin`. Decoder layers are
    appended afterwards with `load_and_dispatch_a_layer`.

    Args:
        checkpoint_path: directory containing `non_layers.bin`.
        config_name: name or path passed to `LlamaConfig.from_pretrained`.
        seqlen: sequence length stored as `model.seqlen`, used by the calibration
            and evaluation code.

    Returns:
        The model, cast to float16.
    """
    config = LlamaConfig.from_pretrained(config_name)
    config.num_hidden_layers = 0

    model = LlamaForCausalLM(config)
    # NOTE(review): torch.load unpickles the file; only use it on trusted checkpoints.
    state_dict = torch.load(os.path.join(checkpoint_path, "non_layers.bin"), map_location="cpu")
    model.load_state_dict(state_dict)
    model.seqlen = seqlen

    return model.to(torch.float16)


def load_and_dispatch_a_layer(layer_idx, checkpoint_path, model: "LlamaForCausalLM", config_name="yahma/llama-7b-hf"):
    """Load decoder layer `layer_idx` from disk and append it to `model.model.layers`.

    Layers must be dispatched in order: the new layer is always appended at the end
    of the list, regardless of `layer_idx`.

    Args:
        layer_idx: index of the layer; selects `<checkpoint_path>/layer_{layer_idx}.bin`.
        checkpoint_path: directory with the per-layer files, or the special value "TEST".
            In that case a 16x16 `nn.Linear` (weight = arange(256)) wrapped in
            `nn.ModuleDict({"submodule": ...})` is appended instead.
        model: model whose `model.model.layers` list is extended in place.
        config_name: name or path passed to `transformers.AutoConfig.from_pretrained`.

    Raises:
        RuntimeError: if the layer file lacks some of the layer's parameters.
    """
    if checkpoint_path == "TEST":
        linear = nn.Linear(16, 16)
        linear.weight.data = torch.arange(16 * 16).reshape(16, 16).float()
        model.model.layers.append(nn.ModuleDict({"submodule": linear}))
        return

    config = transformers.AutoConfig.from_pretrained(config_name)

    layer = LlamaDecoderLayer(config)
    # NOTE(review): torch.load unpickles the file; only use it on trusted checkpoints.
    layer_state_dict = torch.load(os.path.join(checkpoint_path, f"layer_{layer_idx}.bin"), map_location="cpu")
    prefix = f"model.layers.{layer_idx}."
    layer_state_dict = {name[len(prefix):]: tensor for name, tensor in layer_state_dict.items()}
    # strict=False tolerates extra checkpoint entries the module does not declare
    # (presumably buffers such as rotary inv_freq, depending on the transformers
    # version). A missing parameter, however, would silently keep whatever value the
    # constructor produced, so fail loudly in that case.
    result = layer.load_state_dict(layer_state_dict, strict=False)
    if result.missing_keys:
        raise RuntimeError(f"layer_{layer_idx}.bin is missing parameters: {result.missing_keys}")
    del layer_state_dict

    model.model.layers.append(layer.to(torch.float16))


Calling `initialize_layerless_llama` and then calling `load_and_dispatch_a_layer` for each layer in order would fully load the model, but we'll also quantize the layers as we go.

### Getting the data

In [436]:
def get_wikitext2(model, seed, seqlen, nsamples=128):
    """Tokenize WikiText-2 and sample calibration windows from its train split.

    Args:
        model: tokenizer name or path passed to `AutoTokenizer.from_pretrained`.
        seed: seed for Python's global `random` module (reseeded here).
        seqlen: length of every calibration window, in tokens.
        nsamples: number of windows to draw.

    Returns:
        (trainloader, testenc): a list of `nsamples` (input_ids, target) pairs of shape
        (1, seqlen), where target is -100 everywhere except the last token, and the
        tokenized test split (all lines joined with blank lines).
    """
    traindata, testdata = (
        load_dataset('wikitext', 'wikitext-2-raw-v1', split=split) for split in ('train', 'test')
    )

    tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
    trainenc, testenc = (
        tokenizer("\n\n".join(data['text']), return_tensors='pt') for data in (traindata, testdata)
    )

    random.seed(seed)
    last_start = trainenc.input_ids.shape[1] - seqlen - 1

    def draw_window():
        start = random.randint(0, last_start)
        inp = trainenc.input_ids[:, start:start + seqlen]
        tar = inp.clone()
        tar[:, :-1] = -100
        return inp, tar

    trainloader = [draw_window() for _ in range(nsamples)]
    return trainloader, testenc

## RTN Quantization for LLaMA

**Auxiliary functions:**
 * `find_layers` takes a module and returns a dictionary containing all of its *Linear* submodules with their path-names as the keys.
 * `replace_submodule` takes a module, a path-name and a submodule and replaces the module's submodule at path-name with the new submodule.

In [437]:
def find_layers(module: nn.Module, name='') -> dict[str, nn.Module]:
    """Return every `nn.Linear` inside `module`, keyed by its dotted path.

    Only exact `nn.Linear` instances match (subclasses do not), and the search does
    not descend into a matching module. If `module` itself is a Linear, the result is
    `{name: module}`.
    """
    if type(module) is nn.Linear:
        return {name: module}
    found = {}
    for child_name, child in module.named_children():
        child_path = f"{name}.{child_name}" if name else child_name
        found.update(find_layers(child, name=child_path))
    return found


def replace_submodule(module, submodule_path, new_submodule):
    """Replace the submodule at dotted path `submodule_path` inside `module` with `new_submodule`."""
    *parent_names, leaf_name = submodule_path.split(".")
    owner = module
    for attr in parent_names:
        owner = getattr(owner, attr)
    setattr(owner, leaf_name, new_submodule)

**Not quantizing the model**

First, take a look at the function below. It uses the functions above to load the layers one by one and iterate over their `Linear` submodules replacing them. You'll need to quantize those submodules and create the `QuantizedLinear` ones to replace the original ones with.

**Task:** implement RTN quantization for LLaMA

In [454]:
@torch.no_grad()
def llama_nearest(checkpoint_path, model, bits: int, device, num_layers: int = 32):
    """Load the model layer by layer, quantizing every linear submodule with round-to-nearest.

    Each layer is appended with `load_and_dispatch_a_layer`. Its `nn.Linear`
    submodules are quantized on `device` with `measure_and_quantize` and replaced in
    place by `QuantizedLinear` modules whose tensors are stored on the CPU.

    Args:
        checkpoint_path: passed to `load_and_dispatch_a_layer`.
        model: model whose `model.model.layers` list is extended in place.
        bits: number of bits per quantized weight.
        device: device used for the quantization arithmetic.
        num_layers: number of decoder layers to load (32 for LLaMA-7B).
    """
    print('Starting ...')
    layers = model.model.layers
    for i in trange(num_layers):
        load_and_dispatch_a_layer(i, checkpoint_path, model)
        layer = layers[i]

        for name, linear in find_layers(layer).items():
            q, scale, zero = measure_and_quantize(linear.weight.data.to(device), bits=bits)
            # Keep all of the quantized state on the CPU together, like the rest of the
            # layer (llama_eval moves each whole layer to the device before running it).
            quantized_linear = QuantizedLinear(q.cpu(), scale.cpu(), zero.cpu(), linear.bias)
            replace_submodule(layer, name, quantized_linear)

        torch.cuda.empty_cache()


In [455]:
# Testing your code

model = nn.ModuleDict({"model": nn.ModuleDict({"layers": nn.ModuleList([])})})
llama_nearest("TEST", model, 4, "cpu")

assert len(model.model.layers) == 32, "You didn't load all the layers"
assert all(isinstance(layer.submodule, QuantizedLinear) for layer in model.model.layers), "Some Linears weren't properly replaced"
first = model.model.layers[0].submodule
assert torch.all(first.quantized_weight == torch.arange(16).unsqueeze(0).repeat(16, 1)), "Quantized weights are weird"
assert torch.all(first.scale == 1), "Quantized scales are weird"
# Row r of the test weight is 16 * r + [0, ..., 15], so its zero point is -16 * r.
assert torch.all(first.zero == -16 * torch.arange(16).unsqueeze(1)), "Quantized zeros are weird"
print("All tests passed!")

Starting ...


100%|██████████| 32/32 [00:00<00:00, 1848.35it/s]


In [None]:
class _StopForward(Exception):
    """Raised by the input catcher to abort a forward pass once layer 0's inputs are recorded."""


def _capture_first_layer_inputs(model, dataloader, inps):
    """Store the hidden states fed to `model.model.layers[0]` for each batch into `inps`.

    Layer 0 is temporarily wrapped by a module that records its input and aborts the
    forward pass. Returns the `attention_mask` and `position_ids` keyword arguments
    layer 0 received on the last batch, moved to `inps.device` (None stays None).

    Raises:
        ValueError: if `dataloader` yielded fewer batches than `inps` has rows.
    """
    layers = model.model.layers
    state = {'i': 0, 'attention_mask': None, 'position_ids': None}

    class Catcher(nn.Module):
        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, inp, **kwargs):
            inps[state['i']] = inp
            state['i'] += 1
            state['attention_mask'] = kwargs['attention_mask']
            state['position_ids'] = kwargs['position_ids']
            raise _StopForward

    layers[0] = Catcher(layers[0])
    try:
        for batch in dataloader:
            try:
                model(batch[0])
            except _StopForward:
                pass
    finally:
        # Always put the real layer back, even if a forward pass failed.
        layers[0] = layers[0].module

    if state['i'] < inps.shape[0]:
        raise ValueError(f"expected {inps.shape[0]} calibration batches, got {state['i']}")

    attention_mask, position_ids = state['attention_mask'], state['position_ids']
    if attention_mask is not None:
        attention_mask = attention_mask.to(inps.device)
    if position_ids is not None:
        position_ids = position_ids.to(inps.device)
    return attention_mask, position_ids


def _run_layer(layer, inps, outs, attention_mask, position_ids):
    """Apply `layer` to each sample `inps[j]` separately and store the result in `outs[j]`."""
    for j in range(inps.shape[0]):
        outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]


def _accumulate_hessians(layer, subset, inps, outs, attention_mask, position_ids):
    """Run `layer` on all samples and return 2 * sum(x x^T) / n_calls for each linear in `subset`.

    The sum runs over every input row x (one per token) the linear receives; n_calls
    is the number of forward calls of that linear. `outs` is overwritten as a side effect.
    """
    hessians = {name: None for name in subset}
    num_calls = {name: 0 for name in subset}

    def make_hook(name):
        def hook(_, inp, out):
            x = inp[0].data  # ... x hidden_size
            x = x.reshape((-1, x.shape[-1])).t().float()  # hidden_size x inputs
            num_calls[name] += 1
            if hessians[name] is None:
                hessians[name] = x.matmul(x.t())
            else:
                hessians[name] += x.matmul(x.t())
        return hook

    handles = [linear.register_forward_hook(make_hook(name)) for name, linear in subset.items()]
    try:
        _run_layer(layer, inps, outs, attention_mask, position_ids)
    finally:
        for handle in handles:
            handle.remove()

    # In-place (2 * H) / n: same values as `2 * H / n`, without an extra Hessian-sized copy.
    for name in subset:
        hessians[name].mul_(2).div_(num_calls[name])
    return hessians


@torch.no_grad()
def llama_gptq(checkpoint_path, model, dataloader, bits, device, n_samples=128, num_layers=32):
    """Load LLaMA layer by layer and quantize every linear submodule with GPTQ.

    The inputs of the first decoder layer are captured by running the layerless model
    on `dataloader`. Each layer is then moved to `device` and its linear submodules are
    quantized group by group (see `sequential_groups`). The Hessian of every linear in a
    group comes from the inputs it sees while the layer, with the previous groups
    already quantized, runs on all calibration samples. The outputs of the fully
    quantized layer become the inputs of the next one.

    Args:
        checkpoint_path: passed to `load_and_dispatch_a_layer`.
        model: layerless LLaMA (see `initialize_layerless_llama`). Decoder layers are
            appended to `model.model.layers` in place and are moved back to the CPU.
        dataloader: iterable of batches whose first element is the input ids of one
            calibration sequence of length `model.seqlen`.
        bits: number of bits per quantized weight.
        device: device on which each layer is run and quantized.
        n_samples: number of calibration sequences; must match the number of batches
            in `dataloader`.
        num_layers: number of decoder layers (32 for LLaMA-7B).

    Raises:
        ValueError: if `dataloader` yields fewer than `n_samples` batches.
    """
    print('Starting ...')
    load_and_dispatch_a_layer(0, checkpoint_path, model)

    # Linear submodules quantized together. Later groups are calibrated on inputs
    # produced with the earlier groups already quantized.
    sequential_groups = [
        ['self_attn.k_proj', 'self_attn.v_proj', 'self_attn.q_proj'],
        ['self_attn.o_proj'],
        ['mlp.up_proj', 'mlp.gate_proj'],
        ['mlp.down_proj'],
    ]

    use_cache = model.config.use_cache
    model.config.use_cache = False
    try:
        layers = model.model.layers
        dtype = next(iter(model.parameters())).dtype
        inps = torch.zeros(
            (n_samples, model.seqlen, model.config.hidden_size), dtype=dtype, device=device
        )
        attention_mask, position_ids = _capture_first_layer_inputs(model, dataloader, inps)
        outs = torch.zeros_like(inps)

        print('Ready.')

        for i in trange(num_layers):
            if i != 0:
                load_and_dispatch_a_layer(i, checkpoint_path, model)
            layer = layers[i].to(device)
            linear_layers = find_layers(layer)

            for names in sequential_groups:
                subset = {name: linear_layers[name] for name in names}
                hessians = _accumulate_hessians(layer, subset, inps, outs, attention_mask, position_ids)
                for name, linear in subset.items():
                    # pop() releases each Hessian as soon as it has been used.
                    q, scale, zero = gptq(linear.weight.data, bits, hessians.pop(name))
                    replace_submodule(layer, name, QuantizedLinear(q, scale, zero, linear.bias))

            # Outputs of the fully quantized layer are the next layer's inputs.
            _run_layer(layer, inps, outs, attention_mask, position_ids)

            layers[i] = layer.cpu()
            del layer
            torch.cuda.empty_cache()

            inps, outs = outs, inps
    finally:
        model.config.use_cache = use_cache


### Evaluating

In [None]:
class Catcher(nn.Module):
    """Stand-in for the first decoder layer that records its inputs and aborts the forward pass.

    Each call stores its positional input in the next row of `inps_dest`, remembers
    the `attention_mask` and `position_ids` keyword arguments, and raises ValueError
    so the caller can stop the model right after the embeddings.
    """

    def __init__(self, inps_dest):
        super().__init__()
        self.inps_dest = inps_dest
        self.i = 0
        self.attention_mask = None
        self.position_ids = None

    def forward(self, inp, **kwargs):
        self.inps_dest[self.i] = inp
        self.attention_mask = kwargs['attention_mask']
        self.position_ids = kwargs['position_ids']
        self.i += 1
        raise ValueError

    def get_the_catch(self):
        """Return the (attention_mask, position_ids) recorded by the most recent call."""
        caught = (self.attention_mask, self.position_ids)
        return caught

@torch.no_grad()
def llama_eval(model, testenc, device, n_batches=4):
    """Compute, print and return the perplexity of `model` on the tokenized test set.

    The token stream is cut into non-overlapping windows of `model.seqlen` tokens
    (the remainder is dropped). The windows are processed in `n_batches` chunks. For
    each chunk, the inputs of the first decoder layer are captured with `Catcher`;
    then the layers are moved to `device` one at a time and applied to every window,
    followed by the final norm and lm_head. Layers, norm and lm_head are moved back
    to the CPU after use.

    Args:
        model: LLaMA model with a `seqlen` attribute.
        testenc: tokenizer output whose `input_ids` has shape (1, n_tokens); not modified.
        device: device used for the computation.
        n_batches: number of chunks the windows are split into (bounds activation memory).

    Returns:
        The perplexity.
    """
    print('Evaluating ...')

    input_ids = testenc.input_ids
    n_windows = input_ids.shape[1] // model.seqlen
    input_ids = input_ids[:, :n_windows * model.seqlen]
    input_ids = input_ids.reshape(n_windows, model.seqlen)

    use_cache = model.config.use_cache
    model.config.use_cache = False
    try:
        dtype = next(iter(model.parameters())).dtype
        loss_fct = nn.CrossEntropyLoss()
        total_nll = 0
        for batch in torch.tensor_split(input_ids, n_batches):
            n_samples = batch.shape[0]
            inps = torch.zeros(
                (n_samples, model.seqlen, model.config.hidden_size), dtype=dtype, device=device
            )
            outs = torch.zeros_like(inps)

            # Collect the first layer inputs
            catcher = Catcher(inps)
            original_layers = model.model.layers
            model.model.layers = nn.ModuleList((catcher,))
            try:
                for sample in batch:
                    try:
                        model(sample.unsqueeze(0))
                    except ValueError:
                        pass
            finally:
                model.model.layers = original_layers
            attention_mask, position_ids = catcher.get_the_catch()
            if attention_mask is not None:
                attention_mask = attention_mask.to(device)
            if position_ids is not None:
                position_ids = position_ids.to(device)

            # Forward pass through the layers, one layer on the device at a time
            layers = model.model.layers
            for i in trange(len(layers)):
                layer = layers[i].to(device)
                for j in range(n_samples):
                    outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
                layers[i] = layer.cpu()
                del layer
                torch.cuda.empty_cache()
                inps, outs = outs, inps

            # Final norm + lm_head, moved to the device once per batch rather than per sample
            if model.model.norm is not None:
                model.model.norm = model.model.norm.to(device)
            model.lm_head = model.lm_head.to(device)
            try:
                for i in range(n_samples):
                    hidden_states = inps[i].unsqueeze(0)
                    if model.model.norm is not None:
                        hidden_states = model.model.norm(hidden_states)
                    lm_logits = model.lm_head(hidden_states)

                    # Token t predicts token t + 1.
                    shift_logits = lm_logits[:, :-1, :]
                    shift_labels = batch[i, 1:]
                    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(device))
                    total_nll += float(loss) * model.seqlen
            finally:
                if model.model.norm is not None:
                    model.model.norm = model.model.norm.cpu()
                model.lm_head = model.lm_head.cpu()
    finally:
        model.config.use_cache = use_cache

    ppl = math.exp(total_nll / input_ids.numel())
    print(ppl)
    return ppl

### Running the whole thing

In [None]:
# Where the checkpoint lives and where the layers are processed
MODEL = "./model/"
DEVICE = "cuda:0"

# Quantization / calibration settings
BITS = 4
N_SAMPLES = 128
SEED = 0

In [None]:
model = initialize_layerless_llama(MODEL).eval()
model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096, padding_idx=0)
    (layers): ModuleList()
    (norm): LlamaRMSNorm()
  )
  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)

In [None]:
dataloader, testloader = get_wikitext2(MODEL, SEED, model.seqlen, nsamples=N_SAMPLES)

Downloading builder script:   0%|          | 0.00/8.48k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/6.84k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/9.62k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.72M [00:00<?, ?B/s]

Generating test split:   0%|          | 0/4358 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/36718 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3760 [00:00<?, ? examples/s]

You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [None]:
llama_gptq(MODEL, model, dataloader, BITS, DEVICE, n_samples=N_SAMPLES)
# Baselines (run on a freshly initialized `model` instead of the call above):
#  * round-to-nearest: llama_nearest(MODEL, model, BITS, DEVICE)
#  * FP16: call load_and_dispatch_a_layer for every layer without quantizing it

Starting ...
Ready.


100%|██████████| 32/32 [44:07<00:00, 82.74s/it]


In [None]:
llama_eval(model=model, testenc=testloader, device=DEVICE);

Evaluating ...


100%|██████████| 32/32 [01:38<00:00,  3.08s/it]
100%|██████████| 32/32 [01:38<00:00,  3.08s/it]
100%|██████████| 32/32 [01:36<00:00,  3.01s/it]
100%|██████████| 32/32 [01:35<00:00,  2.99s/it]


5.934689232829368


FP16: 5.67

GPTQx4: 5.94

NEARESTx4: 6.29