In [1]:
%load_ext autoreload
%autoreload 2

import torch
import os
import sys
sys.path.append("/home/msst/repo/Quantization")
from qlib import HaarWavelet, batch_gathering
from qlib.utils.pack_effective import unpack_bool_tensor
from qlib.vector_quantization.nn_search.faiss_nn_search import reassign

# Precision and placement used for every tensor created in this notebook.
DTYPE = torch.float32
DEVICE = 'cuda'

# Let float32 matmuls use TF32 on GPUs that support it (faster, lower mantissa precision).
torch.backends.cuda.matmul.allow_tf32 = True

# Model whose stored artifacts (FP blocks, k-means results, X^T X) are analysed.
model_name = 'Llama2-7b-hf'

In [2]:
def L2LossWithHess(w, wq, H=None):
    """Mean squared reconstruction error between ``w`` and ``wq``.

    If ``H`` is given, every squared error is multiplied elementwise by ``H``
    (which must broadcast against ``w``) before the mean is taken.
    """
    squared_error = (w - wq) ** 2
    weighted_error = squared_error if H is None else H * squared_error
    return weighted_error.mean()

In [3]:
# All per-model artifacts live under one root keyed by model_name
# (previously the k-means path hard-coded 'Llama2-7b-hf' instead of using model_name).
VC_DATA_ROOT = f'/mnt/ssd_storage/ml/weights/vc_data/{model_name}'
# Name of the k-means run (codebook size, vector dim, ... encoded in the directory name).
KMEANS_RUN_NAME = 'cb256_vecdim8_weightPERCOORD_scaleOUTL2_distMSE_blocksizeNone_iters10_abscoords_haar2'

path_to_per_block_fp = f'{VC_DATA_ROOT}/per_block_fp'
path_to_kmeans = f'{VC_DATA_ROOT}/kmeans/{KMEANS_RUN_NAME}'
path_to_hess = f'{VC_DATA_ROOT}/xtx/{model_name}_xtx.pth'

# Mapping from full module name to its X^T X statistics (indexed by module name below).
xtx = torch.load(path_to_hess, weights_only=True)

In [4]:
# Block checkpoint names start with 'model.layers' and end in '.<layer index>';
# sort numerically so that e.g. '...10' comes after '...9'.
block_names = sorted(
    (name for name in os.listdir(path_to_per_block_fp) if name.startswith('model.layers')),
    key=lambda name: int(name.split('.')[-1]),
)

# Only the first block is explored here.
block_name = block_names[0]
# The checkpoint is a pickled nn.Module, which weights_only=True cannot load.
# Pass weights_only=False explicitly (the default flipped to True in torch 2.6).
# This unpickles arbitrary objects, so only load trusted files.
block = torch.load(os.path.join(path_to_per_block_fp, block_name), weights_only=False)

# full module name -> {'kmeans_data', 'weight', 'hess'} for every plain nn.Linear in the block.
init_data = {}
for module_name, module in block.named_modules():
    # Exact type match (not isinstance): subclasses of Linear are deliberately skipped.
    if type(module) is not torch.nn.Linear:
        continue
    full_module_name = f'{block_name}.{module_name}'
    # k-means artifacts are pickled dicts: trusted local files only.
    kmeans_data = torch.load(f'{path_to_kmeans}/{full_module_name}.pth', weights_only=False)
    weight = module.weight.data
    hess = xtx[full_module_name]
    init_data[full_module_name] = {
        'kmeans_data': kmeans_data,
        'weight': weight,
        'hess': hess,
    }


  block = torch.load(os.path.join(path_to_per_block_fp, block_name))
  kmeans_data = torch.load(f'{path_to_kmeans}/{full_module_name}.pth')


In [20]:
class GDInit(torch.nn.Module):
    """Re-quantizes a weight in the Haar domain with trainable codebook and scales.

    Each forward pass:
      1. (no grad) Haar-transforms ``w / scales``, records the coefficient signs
         and re-assigns every vector of absolute coefficients to its nearest
         codeword via ``reassign``, overwriting ``self.indices`` band by band.
      2. (differentiable) gathers the selected codewords, restores the signs,
         applies the inverse Haar transform and multiplies by ``scales``.

    Step 1 only produces signs (zero gradient) and integer indices, so running it
    without autograd does not change the gradients. ``scales`` receives its
    gradient through the final rescale. ``codebook`` receives one through the
    gather, assuming ``batch_gathering`` is differentiable (TODO confirm).

    Args:
        codebook: codebooks indexed as ``codebook[freq_id]``; the last dim is the
            vector dimension.
        indices: initial codeword indices, one entry per frequency band
            (``indices[freq_id]``); stored as int32.
        scales: scaling tensor; must broadcast against the weight ``w``.
        haar_transform: object exposing ``forward`` and ``inverse``.
    """

    def __init__(self, codebook, indices, scales, haar_transform):
        super().__init__()
        # Non-persistent buffer: follows module.to(device) like the parameters,
        # but stays out of state_dict (as before, when it was a plain attribute).
        self.register_buffer('indices', indices.to(torch.int32), persistent=False)
        self.codebook = torch.nn.Parameter(codebook)
        self.scales = torch.nn.Parameter(scales)
        self.haar_transform = haar_transform

        self.vector_dim = self.codebook.shape[-1]

    @torch.no_grad()
    def _reassign_indices(self, orig_freq):
        """Overwrite ``self.indices[freq_id]`` with the nearest codewords of ``|orig_freq[freq_id]|``."""
        codebook = self.codebook.detach()
        for freq_id in range(orig_freq.shape[0]):
            band_vectors = orig_freq[freq_id].abs().reshape(-1, self.vector_dim)
            new_indices = reassign(
                band_vectors,
                codebook[freq_id],
                # Keyword spelling kept as-is; presumably matches reassign's signature.
                reassine_params={"batch_size": 2**10},
            )
            # as_tensor (not torch.tensor) avoids the copy-construct warning when
            # reassign already returns a tensor.
            self.indices[freq_id] = torch.as_tensor(
                new_indices, dtype=self.indices.dtype, device=self.indices.device
            )

    def forward(self, w):
        # Encode: no gradient can flow through sign() or the index assignment,
        # so skip building an autograd graph for it.
        with torch.no_grad():
            orig_freq = self.haar_transform.forward(w / self.scales)
            orig_freq_signs = torch.sign(orig_freq)
            self._reassign_indices(orig_freq)

        # Decode: differentiable w.r.t. codebook and scales.
        q_freq = batch_gathering(self.codebook, self.indices).reshape_as(orig_freq_signs)
        q_freq = q_freq * orig_freq_signs  # out-of-place to keep autograd safe
        wq = self.haar_transform.inverse(q_freq) * self.scales
        return wq

In [29]:
# Layer being re-quantized (alternative tried earlier: 'model.layers.0.self_attn.q_proj').
data = init_data['model.layers.0.mlp.down_proj']

# Single .to(device=, dtype=) call per tensor instead of two chained conversions.
H = data['hess'].to(device=DEVICE, dtype=DTYPE)
w = data['weight'].to(device=DEVICE, dtype=DTYPE)
kmeans_data = data['kmeans_data']
codebook = kmeans_data['codebook'].to(device=DEVICE, dtype=DTYPE)
vector_dim = codebook.shape[-1]
# Codeword indices are integers: keep them integral rather than casting to
# float32, which is only exact for values up to 2**24.
indices = kmeans_data['indices'].to(device=DEVICE, dtype=torch.int32)
scales = kmeans_data['scales'].to(device=DEVICE, dtype=DTYPE)
# Stored sign bits of the Haar coefficients, mapped from {0, 1} to {-1, +1}.
haar_freq_shape = kmeans_data['signs']['shape']
haar_freq_signs = 2 * unpack_bool_tensor(kmeans_data['signs']['packed'].to(DEVICE), haar_freq_shape) - 1

In [37]:
gd_init = GDInit(
    codebook=codebook.clone(),
    indices=indices.clone(),
    scales=scales.clone(),
    haar_transform=HaarWavelet(
        level=data['kmeans_data']['metadata']['haar_decomposition_level']
    ).to(DTYPE).to(DEVICE),
)

# Only the scales are optimized. gd_init.codebook and
# gd_init.haar_transform.forward_conv.weight can be appended to experiment further.
trainable_params = [gd_init.scales]

lr = 1e-1
n_steps = 50
log_every = 5  # print the loss every `log_every` steps

optimizer = torch.optim.Adam(params=trainable_params, lr=lr)

for i in range(n_steps):
    wq = gd_init(w)  # call the module (not .forward) so registered hooks run
    # Unweighted MSE. A Hessian-weighted variant was tried by passing H**0.5 as
    # the third argument (H must broadcast against w).
    loss = L2LossWithHess(w, wq)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if i % log_every == 0:
        # Same value as recomputing L2LossWithHess(w, wq) under no_grad.
        print(loss.detach())

  new_indices = torch.tensor(


tensor(2.7685e-05, device='cuda:0')
tensor(2.7681e-05, device='cuda:0')
tensor(2.7682e-05, device='cuda:0')
tensor(2.7686e-05, device='cuda:0')
tensor(2.7688e-05, device='cuda:0')
tensor(2.7689e-05, device='cuda:0')
tensor(2.7689e-05, device='cuda:0')
tensor(2.7689e-05, device='cuda:0')
tensor(2.7688e-05, device='cuda:0')
tensor(2.7688e-05, device='cuda:0')
