In [None]:
from huggingface_hub import snapshot_download

# Local directory (relative to the notebook's working directory) that receives the checkpoint files.
PATH = "../QuEST-800M-sparse-INT4"
# Download the files of the Hugging Face repository into PATH. The returned
# local folder path is what the cell displays as its output.
snapshot_download(repo_id="ISTA-DASLab/QuEST-800M-sparse-INT4", local_dir=PATH)

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

.gitattributes:   0%|          | 0.00/1.52k [00:00<?, ?B/s]

summary.json:   0%|          | 0.00/1.87k [00:00<?, ?B/s]

main.pt:   0%|          | 0.00/4.11G [00:00<?, ?B/s]

'/nfs/scistore19/alistgrp/apanfero/QuEST/QuEST-800M-sparse-INT4'

In [2]:
import json

from optim.utils import load_checkpoint
from models.utils import get_model


class DotDict(dict):
    """Dictionary whose items can also be read, written and deleted as attributes.

    ``d.key`` reads ``d["key"]``, ``d.key = v`` stores ``d["key"] = v`` and
    ``del d.key`` removes the item. A missing key is reported as
    ``AttributeError`` so that ``getattr(d, "key", default)`` and ``hasattr``
    behave as they do for ordinary objects.
    """

    def __getattr__(self, key):
        # __getattr__ is only consulted after normal attribute lookup fails,
        # so regular dict methods (``items``, ``get``, ...) are never shadowed.
        try:
            return self[key]
        except KeyError:
            # ``from None`` hides the internal KeyError from the traceback;
            # the caller asked for an attribute, not a dictionary key.
            raise AttributeError(f"'DotDict' object has no attribute '{key}'") from None

    def __setattr__(self, key, value):
        # Every attribute assignment is stored as a dictionary item.
        self[key] = value

    def __delattr__(self, key):
        # Mirror of __setattr__: attributes live in the dictionary, so delete
        # the item and translate a missing key into AttributeError.
        try:
            del self[key]
        except KeyError:
            raise AttributeError(f"'DotDict' object has no attribute '{key}'") from None


# Read the run summary shipped next to the checkpoint. The encoding is pinned
# to UTF-8 so that decoding does not depend on the platform's default locale.
with open(f"{PATH}/summary.json", "r", encoding="utf-8") as f:
    config = json.load(f)


In [8]:
import torch
from torch import nn
import torch.nn.functional as F

from fast_hadamard_transform import hadamard_transform

from models.quantization.base_linear import OPTIMAL_GAUSSIAN_SCALES, HalfHadamardFourEightTrustQuantizer, HalfHadamardTrustQuantizer


def quantize_pack_hadamard_four_eight(x: torch.Tensor, quantizer: HalfHadamardFourEightTrustQuantizer):
    """Quantize ``x`` in Hadamard-rotated space and keep 2 of every 4 value pairs.

    Steps performed:
      1. ``x`` is rotated with ``hadamard_transform`` in blocks of 128 values
         (``2 ** (-7/2)`` equals ``1 / sqrt(128)``).
      2. A clipping range ``scale`` is derived per slice along the last
         dimension from the root-mean-square of the rotated values and
         ``OPTIMAL_GAUSSIAN_SCALES[quantizer.bits]``.
      3. The rotated values are clipped to ``[-scale, scale]`` and rounded onto
         a grid with spacing ``step``.
      4. Every 8 consecutive values are viewed as 4 pairs; the 2 pairs with the
         largest ``quantizer.p``-norm (measured on the unquantized rotated
         values) are kept, the rest are dropped.
      5. The kept values are converted to level indices in
         ``[0, quantizer.n_levels)``.

    Args:
        x: Tensor whose element count is a multiple of 128 and whose last
            dimension is even (required by the reshapes below).
        quantizer: Object providing ``bits``, ``n_levels`` and ``p``.

    Returns:
        ``(xq_sparse, val_idx, scale, step)`` where ``xq_sparse`` holds the
        level indices of the kept values with shape
        ``x.shape[:-1] + (x.shape[-1] // 2,)``, ``val_idx`` holds, per group of
        8 values, the two kept pair positions, and ``scale`` / ``step`` are the
        per-slice clipping range and grid spacing.
        Note that ``xq_sparse`` is expressed in the rotated space.

    Raises:
        AssertionError: if a level index falls outside ``[0, n_levels)``.
    """
    x_had = hadamard_transform(x.reshape(-1, 128), scale=2 ** (-7/2)).reshape(x.shape)

    # Root-mean-square along the last dimension; the epsilon keeps step non-zero.
    std = torch.sqrt(torch.mean(x_had**2, dim=-1, keepdim=True)) + 1e-8
    scale = OPTIMAL_GAUSSIAN_SCALES[quantizer.bits] * std

    step = 2 * scale / (quantizer.n_levels - 1)
    x_clip = torch.clamp(x_had, -scale, scale)
    # Round to the nearest grid point; the grid is offset by half a step.
    xq = torch.round(x_clip / step + 1/2) * step - step / 2

    # Pick, per group of 4 pairs, the 2 pairs with the largest norm.
    _, val_idx = x_had.reshape(-1, 4, 2).norm(p=quantizer.p, dim=-1).topk(k=2, dim=-1, largest=True)
    xq = xq.reshape(-1, 4, 2)
    # Row index created directly on the data's device (the original built it on
    # the CPU, forcing an implicit transfer) and broadcast against val_idx.
    group_idx = torch.arange(xq.size(0), device=xq.device).unsqueeze(-1)
    xq_sparse = xq[group_idx, val_idx]
    xq_sparse = xq_sparse.reshape(x.shape[:-1] + (x.shape[-1] // 2,))

    # Map grid values to level indices 0 .. n_levels - 1.
    xq_sparse = torch.round((xq_sparse + scale) / step)
    assert xq_sparse.min() >= 0 and xq_sparse.max() < quantizer.n_levels
    return xq_sparse, val_idx, scale, step
    # ^ note: xq_sparse is in rotated space!


def dequantize_four_eight(xq_sparse, val_idx, scale, step):
    """Expand kept level indices back into a dense tensor with zeros elsewhere.

    Args:
        xq_sparse: Level indices of the kept values; its last dimension is half
            the size of the dense last dimension.
        val_idx: For each group of 4 pairs, the positions of the 2 kept pairs,
            shape ``(xq_sparse.numel() // 4, 2)``.
        scale: Clipping range, broadcastable against ``xq_sparse``.
        step: Grid spacing, broadcastable against ``xq_sparse``.

    Returns:
        A ``float32`` tensor on ``xq_sparse``'s device with shape
        ``xq_sparse.shape[:-1] + (xq_sparse.shape[-1] * 2,)``. Kept positions
        hold ``index * step - scale``; dropped positions hold zero.
    """
    weight = torch.zeros((xq_sparse.numel() // 4, 4, 2), dtype=torch.float32, device=xq_sparse.device)

    # Row index created on the same device as the data (the original built it
    # on the CPU) and broadcast against the (groups, 2) pair positions.
    group_idx = torch.arange(weight.size(0), device=weight.device).unsqueeze(-1)
    values = (xq_sparse.to(torch.float32) * step - scale).reshape(-1, 2, 2)
    weight[group_idx, val_idx] = values

    return weight.reshape(xq_sparse.shape[:-1] + (xq_sparse.shape[-1] * 2,))


# `HadamardFourEightTrustQuantizer` is not among the names imported above
# (only the Half* variants are), so on a fresh kernel this cell failed with
# NameError. Import it explicitly.
# NOTE(review): assumed to be exported by models.quantization.base_linear
# alongside the Half* variants — confirm.
from models.quantization.base_linear import HadamardFourEightTrustQuantizer

# Round-trip sanity check on a small random tensor (requires a CUDA device).
weight = torch.rand(2, 128).cuda()
quantizer = HadamardFourEightTrustQuantizer(bits=4)

ref = quantizer(weight)
xq_sparse, idx, scale, step = quantize_pack_hadamard_four_eight(weight, quantizer)
deq = dequantize_four_eight(xq_sparse, idx, scale, step)

# `deq` is expressed in rotated space, so the reference output is rotated with
# the same transform before the comparison.
torch.testing.assert_close(hadamard_transform(ref.reshape(-1, 128), scale=2 ** (-7/2)).reshape(ref.shape), deq, rtol=1e-3, atol=1e-3)

In [9]:
def quantize_pack_hadamard_dense(x: torch.Tensor, quantizer: HalfHadamardTrustQuantizer):
    """Rotate ``x`` blockwise with a Hadamard transform and quantize every value.

    Args:
        x: Tensor whose element count is a multiple of 128 (it is reshaped to
            rows of 128 for the transform).
        quantizer: Object providing ``centered`` (must be truthy), ``bits`` and
            ``n_levels``.

    Returns:
        ``(codes, scale, step)``: the level indices in ``[0, n_levels)`` with
        the shape of ``x``, plus the clipping range and grid spacing computed
        per slice along the last dimension. The codes describe the rotated
        tensor, not ``x`` itself.

    Raises:
        AssertionError: if ``quantizer.centered`` is falsy or a level index
            falls outside ``[0, n_levels)``.
    """
    assert quantizer.centered

    # Blocks of 128 values; 2 ** (-7/2) equals 1 / sqrt(128).
    rotated = hadamard_transform(x.reshape(-1, 128), scale=2 ** (-7/2)).reshape(x.shape)

    # Root-mean-square along the last dimension, nudged away from zero.
    rms = torch.sqrt(torch.mean(rotated**2, dim=-1, keepdim=True)) + 1e-8
    scale = OPTIMAL_GAUSSIAN_SCALES[quantizer.bits] * rms
    step = 2 * scale / (quantizer.n_levels - 1)

    # Clip to the representable range, then shift and divide to get indices.
    clipped = torch.clamp(rotated, -scale, scale)
    codes = torch.round((clipped + scale) / step)

    assert codes.min() >= 0 and codes.max() < quantizer.n_levels
    # Reminder: the codes are in rotated space.
    return codes, scale, step

def dequantize_dense(xq, scale, step):
    """Turn level indices back into values: ``xq * step - scale``.

    Args:
        xq: Level indices.
        scale: Clipping range that was added before dividing by ``step``.
        step: Grid spacing.

    Returns:
        The dequantized values, broadcast according to the operands.
    """
    on_grid = xq * step
    return on_grid - scale


# `HadamardTrustQuantizer` is never imported in the notebook (only the Half*
# variant is), so on a fresh kernel this cell failed with NameError.
# NOTE(review): assumed to be exported by models.quantization.base_linear
# alongside HalfHadamardTrustQuantizer — confirm.
from models.quantization.base_linear import HadamardTrustQuantizer

# Round-trip sanity check; reuses `weight` created by the 4:8 check cell.
quantizer = HadamardTrustQuantizer(bits=4)
ref = quantizer(weight)
xq, scale, step = quantize_pack_hadamard_dense(weight, quantizer)
deq = dequantize_dense(xq, scale, step)

# `deq` is expressed in rotated space, so the reference output is rotated with
# the same transform before the comparison.
torch.testing.assert_close(hadamard_transform(ref.reshape(-1, 128), scale=2 ** (-7/2)).reshape(ref.shape), deq, rtol=1e-3, atol=1e-3)

In [10]:
from models.quantization.base_linear import QuantizedLinear

class Linear4bit(nn.Module):
    """Inference-time replacement for a quantized linear layer.

    At construction the source layer's weight is quantized with
    ``quantize_pack_hadamard_four_eight`` and immediately dequantized again;
    the result (expressed in Hadamard-rotated space) is stored as the buffer
    ``wq``. In ``forward`` the input goes through the dense
    quantize/dequantize round trip (also in rotated space) and is then
    multiplied with ``wq``.

    NOTE(review): both operands are rotated by the same transform, which
    presumably cancels in the matrix product — confirm.

    Args:
        quantizer_linear: Layer exposing ``weight``, ``bias``,
            ``weight_quantizer`` (a ``HalfHadamardFourEightTrustQuantizer``)
            and ``activation_quantizer`` (a ``HalfHadamardTrustQuantizer``).
    """

    def __init__(self, quantizer_linear):
        super().__init__()

        assert isinstance(quantizer_linear.weight_quantizer, HalfHadamardFourEightTrustQuantizer)
        assert isinstance(quantizer_linear.activation_quantizer, HalfHadamardTrustQuantizer)

        self.activation_quantizer = quantizer_linear.activation_quantizer

        # Quantize without autograd tracking. Otherwise the buffer would be a
        # non-leaf tensor whose graph keeps the original full-precision weight
        # alive and makes the module impossible to deepcopy.
        with torch.no_grad():
            packed = quantize_pack_hadamard_four_eight(quantizer_linear.weight, quantizer_linear.weight_quantizer)
            wq = dequantize_four_eight(*packed)
        self.register_buffer("wq", wq)
        self.bias = quantizer_linear.bias

    def forward(self, x):
        """Quantize/dequantize ``x`` in rotated space and apply the linear map."""
        x = dequantize_dense(*quantize_pack_hadamard_dense(x, self.activation_quantizer))
        return F.linear(x, self.wq, self.bias)


def replace_linears(model):
    """Swap every ``QuantizedLinear`` below ``model`` for a ``Linear4bit``.

    The module tree is walked recursively and modified in place; children that
    are not ``QuantizedLinear`` are descended into. ``model`` itself is
    returned for convenience.
    """
    for child_name, child in model.named_children():
        if not isinstance(child, QuantizedLinear):
            replace_linears(child)
            continue
        model._modules[child_name] = Linear4bit(child)
    return model

In [11]:
class PseudoDdp(nn.Module):
    """Wrapper that nests ``model`` under ``_orig_mod["module"]``.

    The nesting makes the wrapped model's state-dict keys start with
    ``_orig_mod.module.``.
    """

    def __init__(self, model):
        super().__init__()
        holder = nn.ModuleDict()
        holder["module"] = model
        self._orig_mod = holder
        
class PseudoLoader:
    """Stand-in object whose ``load_state_dict`` accepts anything and does nothing."""

    def load_state_dict(self, *args, **kwargs):
        """Ignore all arguments and return ``None``."""
        return None

# Build the model from the saved run arguments and wrap it so that its
# state-dict keys carry the `_orig_mod.module.` prefix.
# NOTE(review): presumably this matches the key layout of the checkpoint — confirm.
model = PseudoDdp(get_model(DotDict(config['args'])))
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code
# for untrusted checkpoints. Consider `weights_only=True` once it is confirmed
# that the checkpoint only holds tensors and plain containers.
model.load_state_dict(torch.load(f"{PATH}/main.pt", map_location="cpu")["model"])
# Move to the GPU first so the weight quantization below runs on the GPU.
model = model.cuda()
# Drop the wrapper again and keep only the inner model.
model = model._orig_mod["module"]
# Replace every QuantizedLinear with its Linear4bit counterpart (in place).
model = replace_linears(model)


  model.load_state_dict(torch.load(f"{PATH}/main.pt", map_location="cpu")["model"])


In [12]:
from transformers import AutoTokenizer

# Load the tokenizer published under this Hugging Face identifier.
# NOTE(review): presumably the tokenizer the checkpoint was trained with — confirm.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

In [13]:
def generate_text_greedily(model, tokenizer, prompt, max_length=50, device='cuda'):
    """Extend ``prompt`` token by token, always picking the most likely token.

    The model is put into eval mode and called on the full sequence at every
    step (no cache is used). Generation stops after ``max_length`` new tokens
    or as soon as the tokenizer's end-of-sequence token is produced, whichever
    comes first.

    Args:
        model: Callable as ``model(input_ids, get_logits=True)`` returning a
            mapping whose ``'logits'`` entry is indexed as
            ``[batch, position, vocab]``.
        tokenizer: Object providing ``encode(prompt, return_tensors='pt')`` and
            ``decode(ids, skip_special_tokens=True)``; ``eos_token_id`` is used
            for early stopping when present and not ``None``.
        prompt: Text to continue.
        max_length: Maximum number of NEW tokens to generate (the prompt
            tokens are not counted).
        device: Device the token ids are moved to.

    Returns:
        The decoded text of the first sequence (prompt plus continuation).
    """
    model.eval()
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    # Tokenizers without an end-of-sequence token simply never stop early.
    eos_token_id = getattr(tokenizer, "eos_token_id", None)

    # Inference only: one no_grad scope around the whole loop.
    with torch.no_grad():
        for _ in range(max_length):
            outputs = model(input_ids, get_logits=True)
            # Logits of the last position decide the next token.
            logits = outputs['logits'][:, -1, :]
            next_token_id = torch.argmax(logits, dim=-1).unsqueeze(-1)
            input_ids = torch.cat([input_ids, next_token_id], dim=-1)

            # Stop once the sequence has ended; continuing would only append
            # tokens generated after the model signalled the end.
            if eos_token_id is not None and bool((next_token_id == eos_token_id).all()):
                break

    return tokenizer.decode(input_ids[0], skip_special_tokens=True)

# Smoke test of the converted model: greedily continue a short prompt.
# Here max_length is the number of tokens generated after the prompt.
generated_text = generate_text_greedily(model, tokenizer, "Hi!", max_length=20)
print(generated_text)


Hi! I am a 20 year old student from the United States. I am currently studying at the


In [14]:
# Count the elements held in the model's registered buffers. Only buffers are
# visited (named_buffers), not parameters; the loop variable is called `param`
# but refers to a buffer.
numel = 0
for name, param in model.named_buffers():
    numel += param.numel()
    print(name, param.numel())
    
# Total number of buffer elements, in millions.
print(numel/1e6)

transformer.h.0.attn.c_attn.wq 12582912
transformer.h.0.attn.c_proj.wq 4194304
transformer.h.0.mlp.w1.wq 11534336
transformer.h.0.mlp.w2.wq 11534336
transformer.h.0.mlp.c_proj.wq 11534336
transformer.h.1.attn.c_attn.wq 12582912
transformer.h.1.attn.c_proj.wq 4194304
transformer.h.1.mlp.w1.wq 11534336
transformer.h.1.mlp.w2.wq 11534336
transformer.h.1.mlp.c_proj.wq 11534336
transformer.h.2.attn.c_attn.wq 12582912
transformer.h.2.attn.c_proj.wq 4194304
transformer.h.2.mlp.w1.wq 11534336
transformer.h.2.mlp.w2.wq 11534336
transformer.h.2.mlp.c_proj.wq 11534336
transformer.h.3.attn.c_attn.wq 12582912
transformer.h.3.attn.c_proj.wq 4194304
transformer.h.3.mlp.w1.wq 11534336
transformer.h.3.mlp.w2.wq 11534336
transformer.h.3.mlp.c_proj.wq 11534336
transformer.h.4.attn.c_attn.wq 12582912
transformer.h.4.attn.c_proj.wq 4194304
transformer.h.4.mlp.w1.wq 11534336
transformer.h.4.mlp.w2.wq 11534336
transformer.h.4.mlp.c_proj.wq 11534336
transformer.h.5.attn.c_attn.wq 12582912
transformer.h.5.attn