In [5]:
# %env CUDA_VISIBLE_DEVICES=6
# %env TRANSFORMERS_CACHE=/mnt/LLM/hub
# %env OMP_NUM_THREADS=16

import os
import sys
sys.path.insert(0, '..')

# 设置CUDA_VISIBLE_DEVICES环境变量
os.environ['CUDA_VISIBLE_DEVICES'] = '3'
import time
import random
from tqdm.auto import trange
# import ipynbname  # pip install ipynbname

import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers

from typing import Optional,Union,List

torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False
from src.kmeans import find_nearest_cluster, fit_faiss_kmeans, fit_kmeans, fit_kmeans_1d
from src.utils import ellipsis, maybe_script
from src.aq import QuantizedWeight


  from .autonotebook import tqdm as notebook_tqdm


In [72]:
# quantized_weight = QuantizedWeight(
#     XTX=XTX, reference_weight=reference_weight, num_codebooks=num_codebooks,
#     nbits_per_codebook=nbits_per_codebook, scale_nbits=scale_nbits, 
#     out_group_size=out_group_size, in_group_size=in_group_size,
#     verbose=True, max_iter=init_max_iter,   # faster init, not tested
# )
num_codebooks = 1
nbits_per_codebook = 8
out_group_size = 1
in_group_size = 4
# batch_size = 16384
# beam_size = 1
# beam_search_epochs = 100
# sparsity_regularizer = 0
# print_frequency = 10
scale_nbits = 0    # 0 means no scales, 16 means no compression;
# codebook_values_nbits = 16  # less than 16 means we quantize codebooks as well
init_max_iter = 100


In [7]:
import torch
def fit_kmeans(
    data: torch.Tensor,
    k: int,
    max_iter: int = 1000,
    check_every: int = 10,
    rtol: float = 1e-06,
    atol: float = 1e-08,
    greedy_init: bool = False,
    block_size_vals: int = 2**30,
    devices: Optional[List[torch.device]] = None,
):
    """
    :param data: [nsamples, dim]
    :param k: number of centroids
    :param max_iter: run at most this many iterations
    :param check_every: check for convergence (allclose(new_centroids, old_centroids)) once in this many steps
    :param rtol: early stopping relative tolerance for centroids
    :param atol: early stopping absolute tolerance for centroids
    :param greedy_init: if True, init by greedily selecting the point that is farthest from any cluster
        if False (default), initialize with random points using pytorch global RNG
    :param block_size_vals: how many dot products to compute at a time
    :param devices: if specified, run kmeans in data-parallel mode across these devices
    :return: (clusters float[k, dim], data_indices int[nsamples], reconstructed_data: float[nsamples, dim])
    """
    if devices is None:
        devices = [data.device]

    if greedy_init:
        clusters = _kmeans_greedy_init(data, k)
    else:
        clusters = data[torch.randperm(data.shape[0])[:k], :]  # [k, dim]

    block_size = block_size_vals // k
    shard_size = (len(data) - 1) // len(devices) + 1
    data = [
        data[gi * shard_size : (gi + 1) * shard_size].to(devices[gi], non_blocking=True) for gi in range(len(devices))
    ]
    nearest_indices = [torch.empty(len(data[gi]), dtype=torch.int64, device=devices[gi]) for gi in range(len(devices))]
    clusters = [clusters.to(device, non_blocking=True) for device in devices]

    for i in range(max_iter):
        for block_start in range(0, shard_size, block_size):
            for gi in range(len(devices)):
                nearest_indices[gi][block_start : block_start + block_size] = torch.addmm(
                    torch.bmm(clusters[gi][:, None, :], clusters[gi][:, :, None]).flatten(),
                    data[gi][block_start : block_start + block_size],
                    clusters[gi].T,
                    beta=-0.5,
                ).argmax(1)
            # note: the above formula equals to - 0.5 || data[:, None, :] - clusters[None, :, :] || ^ 2 + const

        if len(devices) == 1:
            new_clusters = [
                clusters[0]
                .clone()
                .index_reduce_(dim=0, index=nearest_indices[0], source=data[0], reduce="mean", include_self=False)
            ]
        else:
            cluster_sums = [
                torch.zeros_like(clusters[gi])
                .index_add(dim=0, index=nearest_indices[gi], source=data[gi])
                .to(devices[0], non_blocking=True)
                for gi in range(len(devices))
            ]
            cluster_counts = [
                torch.bincount(nearest_indices[gi], minlength=k).to(devices[0], non_blocking=True)
                for gi in range(len(devices))
            ]
            for gi in range(1, len(devices)):
                cluster_sums[0] += cluster_sums[gi]
                cluster_counts[0] += cluster_counts[gi]

            new_clusters = [cluster_sums[0] / cluster_counts[0].unsqueeze(1).clamp_min(1)]
            new_clusters[0] += (cluster_counts[0].unsqueeze(1) == 0) * clusters[0]
            for gi in range(1, len(devices)):
                new_clusters.append(new_clusters[0].to(devices[gi], non_blocking=True))

        if i % check_every == 0:
            if torch.allclose(new_clusters[0], clusters[0], rtol=rtol, atol=atol):
                break
        clusters = new_clusters
    for block_start in range(0, shard_size, block_size):
        for gi in range(len(devices)):
            nearest_indices[gi][block_start : block_start + block_size] = torch.addmm(
                torch.bmm(clusters[gi][:, None, :], clusters[gi][:, :, None]).flatten(),
                data[gi][block_start : block_start + block_size],
                clusters[gi].T,
                beta=-0.5,
            ).argmax(1)

    clusters = clusters[0]
    nearest_indices = torch.cat([nearest_indices[gi].to(devices[0]) for gi in range(len(devices))], dim=0)
    reconstructed_data = clusters[nearest_indices]
    return clusters, nearest_indices, reconstructed_data
x = torch.randn(4096*4096//16//4, 4, device='cuda')
print(x.shape)
clusters, nearest_indices, reconstructed_data = fit_kmeans(x,256,1000)
reconstructed_data.shape

torch.Size([262144, 4])


AttributeError: 'tuple' object has no attribute 'shape'

In [76]:
device = torch.device('cuda:0') 
x = torch.load("/home/quant/test.pth",map_location=device)
reference_weight = x
num_codebooks = 1
nbits_per_codebook = 8
out_group_size = 1
in_group_size = 4
batch_size = 16384
beam_size = 1
beam_search_epochs = 100
sparsity_regularizer = 0
print_frequency = 10
scale_nbits = 0    # 0 means no scales, 16 means no compression;
# codebook_values_nbits = 16  # less than 16 means we quantize codebooks as well
init_max_iter = 500
quantized_weight_list = []
slip_x = x.split(x.shape[0]//16, dim=0)
for tensor in slip_x:
    XTX =torch.eye(tensor.shape[-1])
    quantized_weight = QuantizedWeight(
        XTX = XTX,
        reference_weight=tensor, num_codebooks=num_codebooks, 
        nbits_per_codebook=nbits_per_codebook, scale_nbits=scale_nbits, 
        out_group_size=out_group_size, in_group_size=in_group_size, 
        verbose=True, max_iter=init_max_iter,  
    )
    quantized_weight_list.append(quantized_weight)

initializing with kmeans: 100%|██████████| 1/1 [00:00<00:00,  1.73it/s]
initializing with kmeans: 100%|██████████| 1/1 [00:00<00:00,  1.85it/s]
initializing with kmeans: 100%|██████████| 1/1 [00:00<00:00,  1.89it/s]
initializing with kmeans: 100%|██████████| 1/1 [00:00<00:00,  1.85it/s]
initializing with kmeans: 100%|██████████| 1/1 [00:00<00:00,  1.82it/s]
initializing with kmeans: 100%|██████████| 1/1 [00:00<00:00,  1.82it/s]
initializing with kmeans: 100%|██████████| 1/1 [00:00<00:00,  1.87it/s]
initializing with kmeans: 100%|██████████| 1/1 [00:00<00:00,  1.83it/s]
initializing with kmeans: 100%|██████████| 1/1 [00:00<00:00,  1.81it/s]
initializing with kmeans: 100%|██████████| 1/1 [00:00<00:00,  1.65it/s]
initializing with kmeans: 100%|██████████| 1/1 [00:00<00:00,  1.81it/s]
initializing with kmeans: 100%|██████████| 1/1 [00:00<00:00,  1.87it/s]
initializing with kmeans: 100%|██████████| 1/1 [00:00<00:00,  1.92it/s]
initializing with kmeans: 100%|██████████| 1/1 [00:00<00:00,  1.

In [74]:
# run.log({"Avg_bits": quantized_weight.estimate_nbits_per_parameter()})
# print("AVG bits:", quantized_weight.estimate_nbits_per_parameter())
# print(quantized_weight.parameters())
param_list = [param for model in quantized_weight_list for param in model.parameters()]

opt = torch.optim.Adam(param_list, lr=1e-4, betas=(0.0, 0.95), amsgrad=True)

In [75]:
print_frequency = 20
for epoch in range(1000):
    XTX = torch.eye(4096).cuda()
    start = time.perf_counter()
    now_weight = torch.cat([cnt_weight() for cnt_weight in quantized_weight_list],dim = 0)
    # print(now_weight.shape)
    delta_weight = (now_weight - reference_weight).double()
    loss = (delta_weight).flatten() @ delta_weight.flatten() / len(delta_weight)
    opt.zero_grad()
    loss.backward()
    opt.step()
    
    # if epoch % print_frequency == 0:
    print(f"loss={loss.item():.10f}\t",
            f"time_on_epoch {epoch} = {time.perf_counter() - start}")
    # if (epoch + 1) % beam_search_epochs == 0:
    #     for quantized_weight in quantized_weight_list:
    #         quantized_weight.beam_search_update_codes_(
    #             XTX, reference_weight, beam_size=beam_size, sparsity_regularizer=sparsity_regularizer,
    #             dim_rng=random.Random(), verbose=True)

    #         if sparsity_regularizer != 0:
    #             sparsity_rate = ((quantized_weight.codes == 0).sum() / quantized_weight.codes.numel()).item()
    #             print(f"Sparsity rate {sparsity_rate:.5f}")

loss=0.0735402556	 time_on_epoch 0 = 0.011769775301218033
loss=nan	 time_on_epoch 1 = 0.011287960223853588
loss=nan	 time_on_epoch 2 = 0.011282090097665787
loss=nan	 time_on_epoch 3 = 0.011256150901317596
loss=nan	 time_on_epoch 4 = 0.011364913545548916
loss=nan	 time_on_epoch 5 = 0.01133693102747202
loss=nan	 time_on_epoch 6 = 0.01136601623147726
loss=nan	 time_on_epoch 7 = 0.011247285641729832
loss=nan	 time_on_epoch 8 = 0.009272407740354538
loss=nan	 time_on_epoch 9 = 0.009311560541391373
loss=nan	 time_on_epoch 10 = 0.009423838928341866
loss=nan	 time_on_epoch 11 = 0.009344280697405338
loss=nan	 time_on_epoch 12 = 0.009240148589015007
loss=nan	 time_on_epoch 13 = 0.009282426908612251
loss=nan	 time_on_epoch 14 = 0.009432234801352024
loss=nan	 time_on_epoch 15 = 0.009226513095200062
loss=nan	 time_on_epoch 16 = 0.00909573957324028
loss=nan	 time_on_epoch 17 = 0.00908541027456522
loss=nan	 time_on_epoch 18 = 0.009235299192368984
loss=nan	 time_on_epoch 19 = 0.009088948369026184
loss=

KeyboardInterrupt: 

In [29]:
quantized_weight.codes.shape

torch.Size([16, 4096, 64, 1])