In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig

import sys
sys.path.append("/home/msst/repo/Quantization")
import qlib

from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
BITS = 2
GROUP_SIZE = 64

In [3]:
# FP
path_to_model = "/media/msst/ssd_storage1/ml/llm/pretrained_models/Llama2-7B"
tokenizer = AutoTokenizer.from_pretrained(path_to_model)
# model = AutoModelForCausalLM.from_pretrained(
model = qlib._modeling.modeling_llama.LlamaForCausalLM.from_pretrained(
    path_to_model,
    device_map="cpu",
    dtype="auto"
)

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 85.13it/s]


In [4]:
class QuantLinear(nn.Module):
    def __init__(self, weight_shape, group_size):
        super().__init__()
        self.weight_shape = weight_shape
        self.group_size = group_size
        self.scale_size = [weight_shape[0], weight_shape[1] // group_size]
        self.scale = nn.Parameter(torch.empty(self.scale_size))
        self.offset = nn.Parameter(torch.empty(self.scale_size))
        self.register_buffer(
            "compressed_weight",
            torch.empty(
                weight_shape,
                dtype=torch.uint8,
                requires_grad=False
            )
        )

    def reshape_weight_for_scaling(self, w):
        return w.reshape(
            self.weight_shape[0], self.weight_shape[1] // self.group_size, -1
        )


    @torch.compile()
    def reconstruct_weight(self):
        w = self.reshape_weight_for_scaling(self.compressed_weight)
        w = w * self.scale[..., None] - self.offset[..., None]
        w = w.reshape(self.weight_shape)
        return w


    def forward(self, x):
        w = self.reconstruct_weight()
        return torch.nn.functional.linear(x, w.to(x.dtype))

In [5]:
# model = AutoModelForCausalLM.from_pretrained(
qmodel = qlib._modeling.modeling_llama.LlamaForCausalLM.from_pretrained(
    path_to_model,
    device_map="cpu",
    dtype="auto"
)

def wrap_model(current_module, prefix=''):
    for module_name, module in current_module.named_children():
        full_name = f"{prefix}.{module_name}" if prefix else module_name
        
        if "proj" in module_name:
            weight_shape = module.weight.data.shape
            setattr(current_module, module_name, QuantLinear(weight_shape, GROUP_SIZE))
        else:
            wrap_model(module, full_name)

wrap_model(qmodel)

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 109.44it/s]


In [6]:
class MinMaxInitializer:
    def __init__(self):
        pass

    @torch.no_grad()
    def __call__(self, x_grouped, negative_clip, positive_clip):
        x_min = x_grouped.min(axis=-1)[0].unsqueeze(-1).float()
        x_max = x_grouped.max(axis=-1)[0].unsqueeze(-1).float()

        offset = (x_max * negative_clip - x_min * positive_clip) / (positive_clip - negative_clip)
        scale = (x_max + offset) / positive_clip
        scale = torch.abs(scale)

        scale = scale.reshape(x_grouped.shape[0], x_grouped.shape[1])
        offset = offset.reshape(x_grouped.shape[0], x_grouped.shape[1])

        return scale.contiguous(), offset.contiguous()

initializer = MinMaxInitializer()

@torch.no_grad()
def configure_single_layer(qlayer, layer, bits, C=None):
    max_int_val = 2**bits - 1

    orig_weight = layer.weight.data
    if C is not None:
        orig_weight = C.float() @ orig_weight.float()

    orig_weight_reshaped = qlayer.reshape_weight_for_scaling(orig_weight)
    scale, offset = initializer(orig_weight_reshaped, negative_clip=0, positive_clip=max_int_val)
    
    quant_weight = (orig_weight_reshaped + offset[..., None]) / scale[..., None]
    quant_weight = quant_weight.reshape_as(orig_weight)

    quant_weight = torch.clamp(torch.round(quant_weight), 0, max_int_val).to(torch.uint8)

    qlayer.compressed_weight.copy_(quant_weight)
    qlayer.scale.copy_(scale)
    qlayer.offset.copy_(offset)


@torch.no_grad()
def init_quant_model(qmodel, model, bits):
    for qmodule_name, qmodule in qmodel.named_modules():
        if isinstance(qmodule, QuantLinear):
            orig_module = model.get_submodule(qmodule_name)
            configure_single_layer(qmodule, orig_module, bits)

            err = torch.mean(((orig_module.weight.data.cpu() - qmodule.reconstruct_weight().cpu()) / (orig_module.weight.data.cpu().std() + 1e-8))**2)
            print(err, qmodule_name)


# init_quant_model(qmodel, model, bits=BITS)

In [7]:
path = f"/home/msst/repo/Quantization/nb/adaptive_rounding/init_w{BITS}gs{GROUP_SIZE}.pth"
# torch.save(qmodel.state_dict(), path)
qmodel.load_state_dict(torch.load(path))

<All keys matched successfully>

In [8]:
from transformers.masking_utils import create_causal_mask


@torch.no_grad()
def prepare_hessian(activations):
    hidden_size = activations[0].shape[-1]
    H = torch.zeros(hidden_size, hidden_size).cuda()
    for act in activations:
        act = act.cuda().view(-1, act.shape[-1]).float() / hidden_size ** 0.5
        H += act.T @ act
    return H


@torch.no_grad()
def prepare_hessian_q(activations, activations_q):
    hidden_size = activations[0].shape[-1]
    H_q = torch.zeros(hidden_size, hidden_size).cuda()
    for act_id in range(len(activations)):
        act = activations[act_id].cuda().view(-1, hidden_size).float() / hidden_size ** 0.5
        act_q = activations_q[act_id].cuda().view(-1, hidden_size).float() / hidden_size ** 0.5
        H_q += act.T @ act_q
    return H_q


def prepare_C(H, Hq):
    I = torch.eye(H.shape[0]).cuda() * 0 #* H.std()
    C = torch.linalg.inv(H + I) @ Hq
    return C


def hessian_loss(layer_q, layer_fp, H, C):
    w_q = layer_q.reconstruct_weight()
    if C is not None:
        w_q = w_q @ C.T
    w = layer_fp.weight
    delta_w = layer_q.reconstruct_weight() - layer_fp.weight
    return torch.trace(delta_w @ H @ delta_w.T)


def optimize_quant_params(
        layer_q,
        layer_fp,
        bits,
        H,
        C=None,
    ):
    trainable_params = [layer_q.scale, layer_q.offset]    
    optim = torch.optim.Adam(trainable_params, lr=1e-3)
    n_steps = 100
    
    with torch.amp.autocast('cuda', dtype=torch.bfloat16):
        for i in range(n_steps):
            optim.zero_grad()
            loss = hessian_loss(layer_q, layer_fp, H, C)
            if i == 0:
                init_loss =  loss.item()
            loss.backward()
            optim.step()

        print(f"{init_loss} -> {loss}")


def hessian_loss_ste(layer_q, layer_fp, H, bits):
    max_int_val = 2**bits - 1

    latent_weight_reshaped = layer_q.reshape_weight_for_scaling(layer_fp.weight + layer_q.weight_addition)
    latent_weight_scaled = (latent_weight_reshaped + layer_q.offset[..., None]) / layer_q.scale[..., None]

    quant_weight = torch.clamp(torch.round(latent_weight_scaled), 0, max_int_val).to(torch.uint8)
    quant_weight_ste = quant_weight + latent_weight_scaled

    layer_q.compressed_weight.copy_(quant_weight.reshape_as(layer_fp.weight))

    weight_reco = quant_weight_ste * layer_q.scale[..., None] - layer_q.offset[..., None]
    weight_reco = weight_reco.reshape_as(layer_fp.weight)

    delta_w = weight_reco - layer_fp.weight

    C = 1e-6
    return torch.trace(delta_w @ H @ delta_w.T) + C * torch.sum(layer_q.weight_addition ** 2)


def optimize_quant_params_ste(
        layer_q,
        layer_fp,
        bits,
        H
    ):
    layer_q.weight_addition = nn.Parameter(torch.zeros_like(layer_fp.weight.data).float())

    # trainable_params = [layer_q.scale, layer_q.offset, layer_q.latent_weight]        
    # trainable_params = [layer_q.scale, layer_q.offset]
    trainable_params = [layer_q.weight_addition]        
    optim = torch.optim.Adam(trainable_params, lr=1e-4)
    n_steps = 100

    with torch.amp.autocast('cuda', dtype=torch.bfloat16):
        for i in range(n_steps):
            optim.zero_grad()

            loss = hessian_loss_ste(layer_q, layer_fp, H, bits)
            if i == 0:
                init_loss =  loss.item()
            loss.backward()
            optim.step()

        print(f"{init_loss} -> {loss}")

    del layer_q.weight_addition


def init_quant_block_hessian(
        block_q,
        block_fp,
        bits,
        activations,
        causal_mask,
        position_embeddings,
        with_opt=True,
        ):

    ##### Attention #####

    # Copy activations for the residual stream
    residual_activations = [x.clone() for x in activations]

    # Collect activations after input_layernorm
    with torch.no_grad():
        for act_id, act in enumerate(activations):
            act = act.cuda()
            act = block_fp.input_layernorm(act)
            activations[act_id] = act.cpu()

    # Initialize q,k,v-projs
    H = prepare_hessian(activations)
    block_q_attn = block_q.self_attn
    block_fp_attn = block_fp.self_attn
    for layer_name in ["q_proj", "k_proj", "v_proj"]:
        layer_q = getattr(block_q_attn, layer_name)
        layer_fp = getattr(block_fp_attn, layer_name)
        configure_single_layer(layer_q, layer_fp, bits)
        if with_opt:
            optimize_quant_params(layer_q, layer_fp, bits, H)

    # Collect attention-out activations
    with torch.no_grad():
        for act_id, act in enumerate(activations):
            act = act.cuda()
            act = block_fp_attn.compute_attention(
                hidden_states=act, 
                position_embeddings=position_embeddings,
                attention_mask=causal_mask
            )[0]
            activations[act_id] = act.cpu()

    # Initialize o_proj
    layer_q = block_q.self_attn.o_proj
    layer_fp = block_fp.self_attn.o_proj
    H = prepare_hessian(activations)
    configure_single_layer(layer_q, layer_fp, bits)
    if with_opt:
        optimize_quant_params(layer_q, layer_fp, bits, H)

    # Collect self_attn outs
    with torch.no_grad():
        for act_id, act in enumerate(activations):
            act = act.cuda()
            res_act = residual_activations[act_id].cuda()
            act = block_fp.self_attn.o_proj(act)
            activations[act_id] = (act + res_act).cpu()

    ##### MLP #####

    # Copy activations for the residual stream
    residual_activations = [x.clone() for x in activations]

    # Collect activations after post_attention_layernorm
    with torch.no_grad():
        for act_id, act in enumerate(activations):
            act = act.cuda()
            act = block_fp.post_attention_layernorm(act)
            activations[act_id] = act.cpu()

    # Initialize gate_proj and up_proj
    H = prepare_hessian(activations)
    block_q_mlp = block_q.mlp
    block_fp_mlp = block_fp.mlp
    for layer_name in ["gate_proj", "up_proj"]:
        layer_q = getattr(block_q_mlp, layer_name)
        layer_fp = getattr(block_fp_mlp, layer_name)
        configure_single_layer(layer_q, layer_fp, bits)
        if with_opt:
            optimize_quant_params(layer_q, layer_fp, bits, H)

    # Collect internal mlp activations
    with torch.no_grad():
        for act_id, act in enumerate(activations):
            act = act.cuda()
            act = block_fp_mlp.act_fn(block_fp.mlp.gate_proj(act)) * block_fp.mlp.up_proj(act)
            activations[act_id] = act.cpu()

    # Initialize down_proj
    layer_q = block_q.mlp.down_proj
    layer_fp = block_fp.mlp.down_proj
    H = prepare_hessian(activations)
    configure_single_layer(layer_q, layer_fp, bits)
    if with_opt:
        optimize_quant_params(layer_q, layer_fp, bits, H)

    # Collect mlp outs
    with torch.no_grad():
        for act_id, act in enumerate(activations):
            act = act.cuda()
            res_act = residual_activations[act_id].cuda()
            act = block_fp.mlp.down_proj(act)
            activations[act_id] = (act + res_act).cpu()


def init_quant_block_hessian_2(
        block_q,
        block_fp,
        bits,
        activations,
        activations_q,
        causal_mask,
        position_embeddings,
        with_opt=True,
        ):

    ##### Attention #####

    # Copy activations for the residual stream
    residual_activations = [x.clone() for x in activations]
    residual_activations_q = [x.clone() for x in activations_q]

    # Collect activations after input_layernorm
    with torch.no_grad():
        for act_id in range(len(activations)):
            act = block_fp.input_layernorm(activations[act_id].cuda())            
            activations[act_id] = act.cpu()
            
            act_q = block_q.input_layernorm(activations_q[act_id].cuda())            
            activations_q[act_id] = act_q.cpu()

    # Initialize q,k,v-projs
    H = prepare_hessian(activations)
    Hq = prepare_hessian_q(activations, activations_q)
    C = prepare_C(H, Hq)

    block_q_attn = block_q.self_attn
    block_fp_attn = block_fp.self_attn
    for layer_name in ["q_proj", "k_proj", "v_proj"]:
        layer_q = getattr(block_q_attn, layer_name)
        layer_fp = getattr(block_fp_attn, layer_name)
        configure_single_layer(layer_q, layer_fp, bits)
        if with_opt:
            optimize_quant_params(layer_q, layer_fp, bits, H, C)

    # Collect attention-out activations
    with torch.no_grad():
        for act_id in range(len(activations)):
            act = block_fp_attn.compute_attention(
                hidden_states=act.cuda(), 
                position_embeddings=position_embeddings,
                attention_mask=causal_mask
            )[0]
            activations[act_id] = act.cpu()
            
            act_q = block_q_attn.compute_attention(
                hidden_states=act_q.cuda(), 
                position_embeddings=position_embeddings,
                attention_mask=causal_mask
            )[0]
            activations_q[act_id] = act_q.cpu()

    # Initialize o_proj
    layer_q = block_q.self_attn.o_proj
    layer_fp = block_fp.self_attn.o_proj
    
    H = prepare_hessian(activations)
    Hq = prepare_hessian_q(activations, activations_q)
    C = prepare_C(H, Hq)

    configure_single_layer(layer_q, layer_fp, bits)
    if with_opt:
        optimize_quant_params(layer_q, layer_fp, bits, H, C)

    # Collect self_attn outs
    with torch.no_grad():
        for act_id in range(len(activations)):
            act = activations[act_id].cuda()
            res_act = residual_activations[act_id].cuda()
            activations[act_id] = (block_fp.self_attn.o_proj(act) + res_act).cpu()

            act_q = activations_q[act_id].cuda()
            res_act_q = residual_activations_q[act_id].cuda()
            activations_q[act_id] = (block_q.self_attn.o_proj(act_q) + res_act_q).cpu()


    ##### MLP #####

    # Copy activations for the residual stream
    residual_activations = [x.clone() for x in activations]
    residual_activations_q = [x.clone() for x in activations_q]

    # Collect activations after post_attention_layernorm
    with torch.no_grad():
        for act_id in range(len(activations)):
            act = block_fp.post_attention_layernorm(activations[act_id].cuda())
            activations[act_id] = act.cpu()

            act_q = block_q.post_attention_layernorm(activations_q[act_id].cuda())
            activations_q[act_id] = act_q.cpu()

    # Initialize gate_proj and up_proj
    H = prepare_hessian(activations)
    Hq = prepare_hessian_q(activations, activations_q)
    C = prepare_C(H, Hq)

    block_q_mlp = block_q.mlp
    block_fp_mlp = block_fp.mlp
    for layer_name in ["gate_proj", "up_proj"]:
        layer_q = getattr(block_q_mlp, layer_name)
        layer_fp = getattr(block_fp_mlp, layer_name)
        configure_single_layer(layer_q, layer_fp, bits)
        if with_opt:
            optimize_quant_params(layer_q, layer_fp, bits, H, C)

    # Collect internal mlp activations
    with torch.no_grad():
        for act_id in range(len(activations)):
            act = activations[act_id].cuda()
            act = block_fp_mlp.act_fn(block_fp.mlp.gate_proj(act)) * block_fp.mlp.up_proj(act)
            activations[act_id] = act.cpu()

            act_q = activations_q[act_id].cuda()
            act_q = block_q_mlp.act_fn(block_q.mlp.gate_proj(act_q)) * block_q.mlp.up_proj(act_q)
            activations_q[act_id] = act_q.cpu()

    # Initialize down_proj
    layer_q = block_q.mlp.down_proj
    layer_fp = block_fp.mlp.down_proj
    
    H = prepare_hessian(activations)
    Hq = prepare_hessian_q(activations, activations_q)
    C = prepare_C(H, Hq)

    configure_single_layer(layer_q, layer_fp, bits)
    if with_opt:
        optimize_quant_params(layer_q, layer_fp, bits, H, C)

    # Collect mlp outs
    with torch.no_grad():
        for act_id in range(len(activations)):
            act = activations[act_id].cuda()
            res_act = residual_activations[act_id].cuda()
            activations[act_id] = (block_fp.mlp.down_proj(act) + res_act).cpu()

            act_q = activations_q[act_id].cuda()
            res_act_q = residual_activations_q[act_id].cuda()
            activations_q[act_id] = (block_q.mlp.down_proj(act) + res_act_q).cpu()


def init_quant_model_hessian(model_q, model_fp, bits, dataloader):
    embed_tokens = model_fp.get_decoder().embed_tokens.cuda()
    embed_tokens_device = embed_tokens.weight.device

    _batch = next(iter(dataloader))
    _inputs_embeds = embed_tokens(_batch.to(embed_tokens_device))

    cache_position = torch.arange(_inputs_embeds.shape[1], device=_inputs_embeds.device)
    position_ids = cache_position.unsqueeze(0)
    causal_mask = create_causal_mask(
        config=model_fp.config,
        input_embeds=_inputs_embeds,
        attention_mask=None,
        cache_position=cache_position,
        past_key_values=None,
        position_ids=position_ids,
    )

    position_embeddings = model_fp.get_decoder().rotary_emb(_inputs_embeds, position_ids)

    # Prepare activations
    activations = []
    with torch.no_grad():
        for batch in dataloader:
            activations.append(embed_tokens(batch.to(embed_tokens_device)).cpu())
    activations_q = [a.clone() for a in activations]

    for decoder_layer_id in tqdm(range(len(model_q.get_decoder().layers))):
        # if decoder_layer_id > 2:
        #     break

        block_q = model_q.get_decoder().layers[decoder_layer_id].cuda()
        block_fp = model_fp.get_decoder().layers[decoder_layer_id].cuda()

        # init_quant_block_hessian(
        #     block_q,
        #     block_fp,
        #     bits,
        #     activations,
        #     causal_mask,
        #     position_embeddings,
        #     with_opt=True
        # )

        init_quant_block_hessian_2(
            block_q,
            block_fp,
            bits,
            activations,
            activations_q,
            causal_mask,
            position_embeddings,
            with_opt=True
        )

        block_q = block_q.cpu()
        block_fp = block_fp.cpu()


config = {
    "dataset_name" : "wiki",
    "split": "train[:10000]",
    # "dataset_name" : "slim_pajama",
    # "split": "train[:5000]",
    "seq_length": 4096,
    "n_seq" : 64, #128,
    "batch_size": 8,
    "random_seed": 'no_rand'
}
dataloader = qlib.QATDataset(
    config=config,
    tokenizer=tokenizer
).get_dataloader()
print(len(dataloader))
init_quant_model_hessian(qmodel, model, BITS, dataloader)

8


  0%|          | 0/32 [00:00<?, ?it/s]

220.8610076904297 -> 16.006084442138672
270.2167053222656 -> 17.31378936767578
158.06936645507812 -> 5.9089813232421875
0.3702833652496338 -> 0.03462900221347809
571.605224609375 -> 234.9871826171875
557.256103515625 -> 226.6083984375
5.793052673339844 -> 1.3856124877929688


  3%|▎         | 1/32 [00:52<27:15, 52.77s/it]

2806.46923828125 -> 742.6286010742188
2754.2802734375 -> 750.1689453125
506.09649658203125 -> 89.96592712402344
2.8629579544067383 -> 0.28932833671569824
1511.875 -> 805.513671875
1342.026611328125 -> 710.8385009765625
2970.34033203125 -> 8.761390686035156


  6%|▋         | 2/32 [01:43<25:43, 51.45s/it]

5471.4736328125 -> 2247.2216796875
5874.99609375 -> 2489.27880859375
1697.52685546875 -> 627.808837890625
142.06756591796875 -> 5.227794647216797
3214.8486328125 -> 1136.830810546875
2784.61083984375 -> 985.773193359375
24.076446533203125 -> 8.934795379638672


  9%|▉         | 3/32 [02:33<24:37, 50.95s/it]

11847.787109375 -> 4335.6171875
12579.689453125 -> 4693.1611328125
3501.369140625 -> 1208.107666015625
649.212646484375 -> 9.239688873291016
5669.966796875 -> 1388.126953125
4854.236328125 -> 1169.9423828125
43.729705810546875 -> 12.924774169921875


 12%|█▎        | 4/32 [03:23<23:38, 50.64s/it]

10924.50390625 -> 2956.976806640625
11276.0791015625 -> 3084.78564453125
3292.0634765625 -> 828.7548828125
858.423828125 -> 9.749809265136719
7359.904296875 -> 1426.61376953125
5979.35546875 -> 1132.34130859375
79.19644165039062 -> 13.4339599609375


 16%|█▌        | 5/32 [04:14<22:43, 50.48s/it]

11384.982421875 -> 2455.312744140625
12361.57421875 -> 2681.96044921875
3615.353515625 -> 715.705810546875
1402.65478515625 -> 13.901622772216797
8288.8447265625 -> 1329.260986328125
6643.712890625 -> 1042.3768310546875
71.44436645507812 -> 11.808341979980469


 19%|█▉        | 6/32 [05:04<21:48, 50.33s/it]

15521.015625 -> 2752.78271484375
15990.3837890625 -> 2813.870361328125
4686.244140625 -> 760.693359375
1144.681640625 -> 11.400745391845703
9771.666015625 -> 1471.75
7529.671875 -> 1105.85595703125
85.7706298828125 -> 12.789131164550781


 22%|██▏       | 7/32 [05:54<20:55, 50.21s/it]

16341.0390625 -> 2730.373779296875
16529.10546875 -> 2723.862060546875
5180.05859375 -> 783.30126953125
1229.0068359375 -> 12.0103759765625
10990.603515625 -> 1559.81005859375
8555.4326171875 -> 1183.5902099609375
97.76409912109375 -> 13.63885498046875


 25%|██▌       | 8/32 [06:44<20:04, 50.19s/it]

16469.703125 -> 2562.67724609375
16481.388671875 -> 2558.045166015625
5285.39453125 -> 751.55810546875
1547.0048828125 -> 15.452003479003906
11029.259765625 -> 1458.03125
9076.654296875 -> 1176.7666015625
115.91778564453125 -> 14.345466613769531


 28%|██▊       | 9/32 [07:34<19:12, 50.12s/it]

17014.119140625 -> 2457.2744140625
17623.171875 -> 2579.177978515625
5697.400390625 -> 757.6943359375
1881.4873046875 -> 16.124359130859375
11402.88671875 -> 1403.650146484375
9688.4482421875 -> 1171.519287109375
126.17559814453125 -> 14.72317886352539


 31%|███▏      | 10/32 [08:25<18:28, 50.37s/it]

17100.52734375 -> 2307.962890625
18084.234375 -> 2482.820068359375
5681.75390625 -> 705.1201171875
1785.65771484375 -> 16.426197052001953
11669.6103515625 -> 1378.6876220703125
10200.80078125 -> 1180.32666015625
148.6103515625 -> 15.987907409667969


 34%|███▍      | 11/32 [09:16<17:42, 50.59s/it]

18698.392578125 -> 2483.412109375
18449.69140625 -> 2449.148193359375
7440.83984375 -> 909.1630859375
3432.060546875 -> 28.09123992919922
12391.56640625 -> 1309.3408203125
11052.015625 -> 1150.602783203125
163.071044921875 -> 15.855094909667969


 38%|███▊      | 12/32 [10:07<16:54, 50.74s/it]

19516.361328125 -> 2326.371337890625
20551.869140625 -> 2487.814697265625
7211.8828125 -> 798.41357421875
3238.806640625 -> 23.90423583984375
12879.123046875 -> 1285.11083984375
11936.880859375 -> 1169.932373046875
166.0516357421875 -> 16.424331665039062


 41%|████      | 13/32 [10:58<16:06, 50.87s/it]

19559.25390625 -> 2212.381103515625
20295.169921875 -> 2290.927978515625
7935.2734375 -> 825.95068359375
4940.703125 -> 31.362037658691406
13569.0478515625 -> 1197.6319580078125
12871.7001953125 -> 1118.5137939453125
195.8963623046875 -> 17.583099365234375


 44%|████▍     | 14/32 [11:50<15:21, 51.21s/it]

20364.998046875 -> 2039.711669921875
21315.78125 -> 2135.101806640625
8068.9921875 -> 733.38525390625
4177.5986328125 -> 21.258939743041992
14640.392578125 -> 1228.99462890625
13993.69140625 -> 1160.153564453125
204.26043701171875 -> 19.37676239013672


 47%|████▋     | 15/32 [12:40<14:25, 50.91s/it]

18993.72265625 -> 1810.953857421875
20283.794921875 -> 1951.0511474609375
8332.15625 -> 724.45361328125
6175.5859375 -> 28.501867294311523
15894.060546875 -> 1222.3604736328125
15165.75 -> 1152.9384765625
247.1981201171875 -> 22.501449584960938


 50%|█████     | 16/32 [13:31<13:32, 50.79s/it]

20036.75 -> 1763.216064453125
20949.505859375 -> 1867.35888671875
9619.2890625 -> 774.4091796875
10937.03125 -> 37.9432373046875
18312.52734375 -> 1168.333984375
17294.361328125 -> 1087.67529296875
332.896484375 -> 22.882461547851562


 53%|█████▎    | 17/32 [14:21<12:42, 50.81s/it]

20344.9296875 -> 1491.923583984375
21329.31640625 -> 1578.27099609375
9869.3359375 -> 663.63330078125
10867.7578125 -> 24.037832260131836
21006.328125 -> 1124.6241455078125
19100.66796875 -> 1009.219970703125
365.771240234375 -> 22.487564086914062


 56%|█████▋    | 18/32 [15:13<11:54, 51.04s/it]

21710.22265625 -> 1302.7371826171875
22536.431640625 -> 1367.510986328125
12120.8828125 -> 670.5576171875
22723.90625 -> 38.2092399597168
23807.05859375 -> 971.755615234375
21402.03515625 -> 852.85546875
472.50927734375 -> 21.730560302734375


 59%|█████▉    | 19/32 [16:05<11:07, 51.37s/it]

20980.580078125 -> 951.0987548828125
21921.119140625 -> 995.94140625
12025.26953125 -> 503.921875
26103.046875 -> 33.33537673950195
26000.9375 -> 836.9398193359375
22899.6875 -> 728.218505859375
546.87255859375 -> 19.435256958007812


 62%|██████▎   | 20/32 [16:58<10:20, 51.72s/it]

21817.52734375 -> 772.7890625
22411.013671875 -> 806.9041137695312
12655.40625 -> 416.0865478515625
34459.53125 -> 15.12813663482666
28130.76953125 -> 692.50830078125
24724.9140625 -> 596.1787109375
657.082763671875 -> 16.89883041381836


 66%|██████▌   | 21/32 [17:49<09:27, 51.58s/it]

23452.341796875 -> 635.1666259765625
23839.126953125 -> 650.615234375
15137.421875 -> 384.658447265625
55336.0625 -> -13.445680618286133
30928.73828125 -> 609.432373046875
26665.62890625 -> 515.1119384765625
767.855712890625 -> 15.301387786865234


 69%|██████▉   | 22/32 [18:40<08:34, 51.45s/it]

25539.44921875 -> 544.8697509765625
26104.306640625 -> 560.6552734375
15865.9453125 -> 318.125732421875
62410.375 -> -20.11383056640625
33891.34765625 -> 560.614013671875
28638.53125 -> 463.918701171875
944.241455078125 -> 14.788333892822266


 72%|███████▏  | 23/32 [19:32<07:44, 51.65s/it]

28693.119140625 -> 508.0504455566406
29195.3125 -> 520.4442749023438
19973.1796875 -> 338.010498046875
140265.3125 -> -30.648534774780273
36024.203125 -> 453.1529541015625
30902.08203125 -> 382.8240966796875
1145.337646484375 -> 13.142143249511719


 75%|███████▌  | 24/32 [20:23<06:51, 51.50s/it]

26355.951171875 -> 366.8974914550781
26654.15625 -> 374.9798583984375
19038.296875 -> 252.9283447265625
119058.5625 -> -61.53163528442383
38638.0703125 -> 414.4991455078125
33331.78125 -> 349.9017333984375
1322.242431640625 -> 12.711406707763672


 78%|███████▊  | 25/32 [21:15<05:59, 51.40s/it]

31968.38671875 -> 376.3715515136719
31903.12109375 -> 377.45013427734375
24034.265625 -> 273.262939453125
268874.75 -> 122.5596694946289
41490.9921875 -> 376.34515380859375
35833.921875 -> 318.4627685546875
1582.1494140625 -> 11.454811096191406


 81%|████████▏ | 26/32 [22:05<05:07, 51.21s/it]

29312.84765625 -> 281.92181396484375
29392.5625 -> 289.034423828125
23519.8046875 -> 221.40872192382812
336784.125 -> 278.508056640625
44568.890625 -> 330.1917724609375
38590.3671875 -> 280.58514404296875
1888.81494140625 -> 11.51028823852539


 84%|████████▍ | 27/32 [22:56<04:15, 51.07s/it]

33550.84375 -> 271.3127746582031
33743.01171875 -> 277.2298889160156
25171.5625 -> 195.06182861328125
376967.125 -> 365.7149963378906
47412.3984375 -> 306.5296325683594
41452.8125 -> 260.11358642578125
2260.80078125 -> 11.16352653503418


 88%|████████▊ | 28/32 [23:47<03:23, 50.93s/it]

32692.255859375 -> 227.25924682617188
32868.18359375 -> 230.88858032226562
27709.0703125 -> 187.49642944335938
636257.5 -> 930.4278564453125
49137.2421875 -> 263.0664978027344
44037.6953125 -> 231.2512969970703
2766.7646484375 -> 11.112064361572266


 91%|█████████ | 29/32 [24:37<02:32, 50.83s/it]

29108.828125 -> 175.52877807617188
29412.76953125 -> 179.634033203125
26255.0859375 -> 156.2301025390625
609797.5 -> 1212.5782470703125
52327.2890625 -> 248.03414916992188
47561.96875 -> 223.94113159179688
4589.9091796875 -> 13.34028434753418


 94%|█████████▍| 30/32 [25:28<01:41, 50.88s/it]

32930.0859375 -> 167.1314697265625
32648.41015625 -> 167.80740356445312
29722.390625 -> 151.41702270507812
922815.5 -> 2012.63427734375
55767.578125 -> 283.00445556640625
50153.7734375 -> 244.27688598632812
579165056.0 -> -360064.375


 97%|█████████▋| 31/32 [26:23<00:51, 51.08s/it]


_LinAlgError: linalg.inv: The diagonal element 2534 is zero, the inversion could not be completed because the input matrix is singular.

In [None]:
config = {
    "dataset_name" : "wiki",
    "split": "test",
    "seq_length": 2048,
    "batch_size": 1,
    "random_seed": 'no_rand'
}

dataloader = qlib.QATDataset(
    config=config,
    tokenizer=tokenizer
).get_dataloader()

qmodel = qmodel.cuda()
# qmodel = model.cuda()

with torch.no_grad():
    with torch.amp.autocast('cuda', dtype=torch.float16):
        res = qlib.evaluate(qmodel, dataloader, print_times=25)
        print(res)

  0%|          | 0/166 [00:00<?, ?it/s]W1228 18:25:51.958000 170414 site-packages/torch/_dynamo/convert_frame.py:1358] [0/8] torch._dynamo hit config.recompile_limit (8)
W1228 18:25:51.958000 170414 site-packages/torch/_dynamo/convert_frame.py:1358] [0/8]    function: 'reconstruct_weight' (/tmp/ipykernel_170414/2410632724.py:24)
W1228 18:25:51.958000 170414 site-packages/torch/_dynamo/convert_frame.py:1358] [0/8]    last reason: 0/7: tensor 'self._buffers['compressed_weight']' size mismatch at index 0. expected 11008, actual 4096
W1228 18:25:51.958000 170414 site-packages/torch/_dynamo/convert_frame.py:1358] [0/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
W1228 18:25:51.958000 170414 site-packages/torch/_dynamo/convert_frame.py:1358] [0/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html
  4%|▍         | 7/166 [00:04<01:37,  1.62it/s]

112621.53783384823


  8%|▊         | 13/166 [00:07<01:32,  1.65it/s]

113769.37736770592


 11%|█▏        | 19/166 [00:10<01:28,  1.65it/s]

109991.26188648363


 15%|█▌        | 25/166 [00:13<01:25,  1.65it/s]

In [None]:
### w3gs128

# base:            8.76
# H (wiki):        6.17 
# H (slim-pajama): 6.31 

In [None]:
### w2gs64

# base:            9583
# H (wiki):        27.7
# H (slim-pajama): 36.6