In [None]:
# Restrict this kernel to physical GPU 6; must run before anything initializes CUDA.
# NOTE(review): machine-specific device index - adjust or remove on other hosts.
%env CUDA_VISIBLE_DEVICES 6

In [None]:
import functools
import transformers

# Checkpoint identifier handed to `from_pretrained` in `get_model`.
MODEL_PATH = 'meta-llama/Llama-2-7b-hf'
# Sequence length constant; not referenced anywhere else in the visible notebook.
MODEL_SEQLEN = 4096


@functools.cache
def get_model():
    """Load the causal-LM checkpoint at ``MODEL_PATH`` once and reuse it.

    ``functools.cache`` memoizes the zero-argument call, so every later call
    returns the very same model object instead of loading it again.
    """
    load_kwargs = {'torch_dtype': 'auto'}
    loader = transformers.AutoModelForCausalLM.from_pretrained
    return loader(MODEL_PATH, **load_kwargs)

In [None]:
import os
import sys
import time
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers

from src.aq import QuantizedWeight, QuantizedLinear

import torch
import quiptools_cuda
from matmul_had import get_hadK
from fast_hadamard_transform import hadamard_transform

# Cap the number of CPU threads torch uses for intra-op parallelism.
torch.set_num_threads(16)
# Turn TF32 off for both cuDNN and CUDA matmuls.
# NOTE(review): presumably so float32 math is not silently run at reduced precision - confirm.
torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False

In [None]:
def matmul_hadU_cuda(X, hadK, K):
    """Apply a scaled Hadamard-style transform over the last dimension of ``X``.

    Args:
        X: input tensor; its last dimension has size ``n``.
        hadK: matrix multiplied from the left onto the ``K`` blocks; only used
            when ``K != 1``.
        K: number of blocks the last dimension is split into.

    Returns:
        A tensor with the shape of ``X``. For ``K == 1`` this is just
        ``hadamard_transform`` of ``X`` scaled by ``1/sqrt(n)``. Otherwise ``X``
        is viewed as ``(-1, K, n // K)`` in float32, transformed along the last
        axis, mixed across the ``K`` axis by ``hadK`` and cast back to the
        device and dtype of ``X``.
    """
    n = X.shape[-1]
    scale = 1/(n**0.5)

    if K != 1:
        blocks = X.float().view(-1, K, n // K)
        blocks = hadamard_transform(blocks.contiguous(), scale)
        # Mix the K blocks with the small dense factor.
        mixer = hadK.to(blocks.device).to(blocks.dtype)
        blocks = mixer @ blocks
        return blocks.to(X.device).to(X.dtype).reshape(X.shape)

    return hadamard_transform(X.contiguous(), scale)


def vec_to_tuple(inp):
    """Convert an iterable of single-element tensors into a tuple of Python scalars.

    Each element is converted with ``.item()``, so iterating a 1-D tensor
    yields one Python number per entry.
    """
    scalars = []
    for element in inp:
        scalars.append(element.item())
    return tuple(scalars)


class HadamardWrapper(nn.Module):
    """Wrap a linear-like module with sign/scale vectors and Hadamard transforms.

    The forward pass computes, on inputs flattened to ``(-1, in_dim)``::

        x -> x * SU -> matmul_hadU_cuda(., had_left_T, K_left) / 32
          -> inner(.)
          -> matmul_hadU_cuda(., had_right, K_right) -> . * SV * 32

    Args:
        SU: 1-D tensor of length ``in_dim``, multiplied onto the input.
        SV: 1-D tensor of length ``out_dim``, multiplied onto the output.
        inner: module applied between the two transforms.
        device: device the ``SU``/``SV`` copies are placed on.

    Both vectors are stored as float32 trainable parameters named ``SU`` and
    ``SV`` (so both appear in ``state_dict()``).
    """

    def __init__(self, SU, SV, inner, device='cuda'):
        super().__init__()

        self.out_dim, self.in_dim = len(SV), len(SU)

        # Work on private float32 copies so the caller's tensors are untouched.
        SU = SU.detach().clone().to(device).float()
        SV = SV.detach().clone().to(device).float()

        self.SU = torch.nn.Parameter(SU, requires_grad=True)
        # The original called `register_buffer(SV, requires_grad=True)`, which
        # raises (the first argument must be a string name and there is no
        # `requires_grad` keyword) and left `self.SV` undefined for `forward`.
        # A buffer cannot be trainable, so SV is registered like SU.
        self.SV = torch.nn.Parameter(SV, requires_grad=True)
        self.inner = inner

    def forward(self, x):
        """Apply the wrapped transform; returns shape ``x.shape[:-1] + (out_dim,)``."""
        out_dim, in_dim = self.out_dim, self.in_dim

        had_left_T, K_left = get_hadK(in_dim)
        if had_left_T is not None:
            had_left_T = had_left_T.T.contiguous()
            assert not had_left_T.requires_grad
        had_right, K_right = get_hadK(out_dim)
        if had_right is not None:
            assert not had_right.requires_grad

        input_shape = x.shape
        assert input_shape[-1] == in_dim
        # reshape (not view) so non-contiguous activations are accepted too.
        x = x.reshape(-1, in_dim)

        x = x * self.SU
        # The /32 here and the *32 below cancel out.
        # NOTE(review): presumably keeps values in a safe range for `inner` - confirm.
        x = matmul_hadU_cuda(x, had_left_T, K_left) / 32
        x = self.inner(x)
        x = matmul_hadU_cuda(x, had_right, K_right)
        x = x * self.SV * 32

        x = x.reshape(tuple(input_shape[:-1]) + (out_dim,))
        return x


def replace_submodule(module, submodule_path, new_submodule):
    """Set the attribute addressed by a dotted path, starting from ``module``.

    All path components but the last are resolved with ``getattr``; the last
    one is then assigned ``new_submodule`` on the object reached that way.
    """
    *parent_names, leaf_name = submodule_path.split(".")
    parent = module
    for name in parent_names:
        parent = getattr(parent, name)
    setattr(parent, leaf_name, new_submodule)

In [None]:
# Codebook grid; assumes 65536 rows of 8 values each (see the views below) - TODO confirm.
grid_tensor = torch.load('./e8p_grid.pt', map_location='cpu', weights_only=True)

# The rows must already be in lexicographically sorted order.
grid_rows = [vec_to_tuple(row) for row in grid_tensor]
assert sorted(grid_rows) == grid_rows

# Hash every row: scale by 4, round, store as 8 int8 values and reinterpret
# those 8 bytes as a single int64.
grid_hashed = (grid_tensor * 4).round()
assert grid_hashed.abs().max() < 127
grid_hashed = grid_hashed.to(torch.int8).view(torch.int64).view(65536)

# Map each row hash to the index of that row in the grid.
idx_by_hash = {
    row_hash: row_idx for row_idx, row_hash in enumerate(grid_hashed.tolist())
}

In [None]:
import grid

# Grid passed as the second argument of `quiptools_cuda.decompress_packed_e8p` in `get_codes`.
packed_abs_grid = grid.get_packed_abs_grid()

In [None]:
def get_codes(SU, SV, Qidxs):
    """Translate packed E8P indices into row indices of ``grid_tensor``.

    Args:
        SU: per-input vector; only its length (``in_dim``) is used.
        SV: per-output vector; only its length (``out_dim``) is used.
        Qidxs: packed indices, viewable as ``(out_dim // 16, in_dim // 64, 8, 4)``.

    Returns:
        An int64 tensor of shape ``(out_dim, in_dim // 8)`` with one grid-row
        index per group of 8 decompressed values.

    Raises:
        KeyError: if a decompressed group does not match any grid row.
    """
    in_dim, out_dim = len(SU), len(SV)
    num_groups = out_dim * in_dim // 8

    # presumably the dense decoded weight with out_dim * in_dim values - TODO confirm
    W = quiptools_cuda.decompress_packed_e8p(
        Qidxs.view(out_dim // 16, in_dim // 64, 8, 4).cuda(),
        packed_abs_grid.cuda(),
    ).cpu()

    # Hash each group of 8 values exactly like the grid rows were hashed:
    # scale by 4, round, cast to int8 and reinterpret the 8 bytes as one int64.
    # Rounding before the cast matters: a bare `.to(torch.int8)` truncates, so
    # any float error would produce a hash that differs from the grid's.
    W_hashed = (W.reshape(num_groups, 8) * 4).round()
    assert W_hashed.abs().max() < 127
    W_hashed = W_hashed.to(torch.int8).view(torch.int64).view(num_groups)

    # One bulk `.tolist()` instead of a `.item()` call per element.
    W_codes = torch.tensor(
        [idx_by_hash[group_hash] for group_hash in W_hashed.tolist()]
    ).reshape(out_dim, in_dim // 8)

    return W_codes


def get_quantized_weight(SU, SV, codes):
    """Build a ``QuantizedWeight`` whose codebook is ``grid_tensor`` and whose codes are ``codes``.

    Args:
        SU: per-input vector; only its length (``in_dim``) is used.
        SV: per-output vector; only its length (``out_dim``) is used.
        codes: grid-row indices, reshaped to the shape of the weight's ``codes``.

    Returns:
        A ``QuantizedWeight`` with a single 16-bit codebook, input groups of 8,
        all-ones non-trainable scales, and codebooks/codes overwritten on CUDA.
    """
    out_dim = len(SV)
    in_dim = len(SU)

    # NOTE(review): the all-ones reference weight and max_iter=0 are presumably
    # only there to allocate correctly shaped tensors - confirm against QuantizedWeight.
    placeholder = torch.ones((out_dim, in_dim), dtype=torch.float16).cuda()
    qweight = QuantizedWeight(
        reference_weight=placeholder,
        num_codebooks=1,
        nbits_per_codebook=16,
        scale_nbits=0,
        out_group_size=1,
        in_group_size=8,
        verbose=False,
        max_iter=0,
    )

    # Neutral, frozen scales.
    qweight.scales.data = torch.ones_like(qweight.scales.data)
    qweight.scales.requires_grad = False

    # Overwrite the codebook with the grid and the codes with the given indices.
    codebook_values = grid_tensor.reshape(qweight.codebooks.shape).detach().clone().cuda()
    qweight.codebooks.data = codebook_values
    code_values = codes.clone().reshape(qweight.codes.shape).cuda().to(torch.int32)
    qweight.codes.data = code_values

    return qweight

In [None]:
from safetensors import safe_open

# Read every tensor of the checkpoint into a plain dict on the CPU.
tensors = {}
with safe_open("model.safetensors", framework="pt", device="cpu") as f:
    tensors.update((key, f.get_tensor(key)) for key in f.keys())

In [None]:
import quiptools_cuda

def load_and_convert(layer):
    """Fetch one layer's tensors from ``tensors`` and convert its indices to codes.

    Args:
        layer: key prefix of the layer, e.g. ``'model.layers.0.mlp.down_proj'``.

    Returns:
        ``(SU, SV, codes, fuse_scales)``: ``SU`` as float32; ``SV`` as float32
        multiplied by the layer's ``Wscale``; ``codes`` from ``get_codes``;
        ``fuse_scales`` as float32, or ``None`` when the layer has none.
    """
    Qidxs = tensors[f'{layer}.Qidxs']
    SU = tensors[f'{layer}.SU'].float()

    # Fold the global weight scale into the output-side vector.
    Wscale = tensors[f'{layer}.Wscale'].float()
    SV = tensors[f'{layer}.SV'].float() * Wscale

    fuse_scales = tensors.get(f'{layer}.fuse_scales', None)
    fuse_scales = None if fuse_scales is None else fuse_scales.float()

    # Only codebook id 7 is accepted.
    # NOTE(review): presumably the E8P codebook that get_codes decodes - confirm.
    codebook_id = tensors[f'{layer}.codebook_id']
    assert codebook_id.item() == 7

    codes = get_codes(SU, SV, Qidxs)
    return SU, SV, codes, fuse_scales

In [None]:
# NOTE(review): this value is never read - the conversion loop further down rebinds `layer_idx`.
layer_idx = 0

In [None]:
# `get_model()` is cached and `.cuda()` moves the module in place, so later
# `get_model()` calls return this same CUDA-resident model.
model = get_model().cuda()

import copy

# Independent copy whose projections get replaced below; the original model is
# still read afterwards for the projection output sizes.
model_changed = copy.deepcopy(model)

In [None]:
import tqdm

def _expand_fused_scales(fuse_scales, out_dims):
    """Repeat each fused projection's scale over that projection's output rows."""
    if len(fuse_scales) != len(out_dims):
        raise ValueError(
            f'expected {len(out_dims)} fused scales, got {len(fuse_scales)}'
        )
    return torch.cat([
        scale * torch.ones((out_dim,), dtype=torch.float32)
        for scale, out_dim in zip(fuse_scales, out_dims)
    ], dim=0)


def _convert_projection(name, fused_out_dims=None):
    """Load the projection stored under `name` and wrap it as HadamardWrapper(QuantizedLinear).

    `fused_out_dims` lists the output sizes of the projections fused into this
    layer, in storage order; leave it None for a layer that is not fused.
    """
    SU, SV, codes, fuse_scales = load_and_convert(name)
    if fused_out_dims is None:
        assert fuse_scales is None
    else:
        if fuse_scales is None:
            raise ValueError(f'{name} is expected to carry fuse_scales')
        # Fold the per-projection scales into the output-side vector.
        SV = SV * _expand_fused_scales(fuse_scales, fused_out_dims)
    return HadamardWrapper(SU, SV, QuantizedLinear(get_quantized_weight(SU, SV, codes), bias=None))


# Replace the projections of every decoder layer of `model_changed`. The layer
# count comes from the model itself instead of a hard-coded 32.
for layer_idx in tqdm.tqdm(range(len(model_changed.model.layers))):
    prefix = f'model.layers.{layer_idx}'
    # The untouched original still has the separate projections and is only
    # read for their output sizes.
    reference = get_model().model.layers[layer_idx]
    target = model_changed.model.layers[layer_idx]

    down_linear = _convert_projection(f'{prefix}.mlp.down_proj')

    # Fused order is up, then gate.
    upgate_linear = _convert_projection(
        f'{prefix}.mlp.upgate_proj',
        fused_out_dims=[
            reference.mlp.up_proj.weight.shape[0],
            reference.mlp.gate_proj.weight.shape[0],
        ],
    )

    o_linear = _convert_projection(f'{prefix}.self_attn.o_proj')

    # Fused order is q, k, v.
    qkv_linear = _convert_projection(
        f'{prefix}.self_attn.qkv_proj',
        fused_out_dims=[
            reference.self_attn.q_proj.weight.shape[0],
            reference.self_attn.k_proj.weight.shape[0],
            reference.self_attn.v_proj.weight.shape[0],
        ],
    )

    # Swap the attention projections: drop the dense ones, attach the
    # converted ones (fused qkv first, then o).
    del target.self_attn.q_proj
    del target.self_attn.k_proj
    del target.self_attn.v_proj
    del target.self_attn.o_proj

    target.self_attn.qkv_proj = qkv_linear
    target.self_attn.o_proj = o_linear

    # Same for the MLP projections (fused upgate first, then down).
    del target.mlp.gate_proj
    del target.mlp.up_proj
    del target.mlp.down_proj

    target.mlp.upgate_proj = upgate_linear
    target.mlp.down_proj = down_linear

In [None]:
# Serialize the whole converted module object (not just a state dict).
# NOTE(review): loading this file presumably needs the notebook-defined HadamardWrapper
# class to be importable - confirm before relying on it elsewhere.
torch.save(model_changed, 'model.pt')

In [None]:
# Create the export directory. Unlike a bare `mkdir`, this does not fail when
# the cell is re-run and the directory already exists.
os.makedirs('./quip-sharp-model-aqlm-format', exist_ok=True)

In [None]:
# Write each converted decoder layer to its own file, named after its index.
for layer_idx, layer in enumerate(model_changed.model.layers):
    layer_path = f'./quip-sharp-model-aqlm-format/{layer_idx}.pth'
    torch.save(layer, layer_path)

In [None]:
# Everything outside the decoder layers goes into a single state-dict file.
not_quantized_weights = {}
for name, tensor in model_changed.state_dict().items():
    if 'model.layers' in name:
        continue
    not_quantized_weights[name] = tensor

torch.save(
    not_quantized_weights,
    './quip-sharp-model-aqlm-format/not_quantized_weights.pt',
)

In [None]:
# Store an empty args dict next to the other exported files.
# The original wrote to a hard-coded absolute path inside one user's home
# directory, which fails on any other machine and lands outside the export dir.
# NOTE(review): confirm the consumer expects args.pt inside the export directory.
torch.save(
    dict(),
    './quip-sharp-model-aqlm-format/args.pt',
)