In [1]:
import torch
import os
import sys
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
import time
import random
from tqdm.auto import trange
# import ipynbname  # pip install ipynbname

import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers

from typing import Optional,Union,List

torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False
def find_nearest_cluster(data, clusters, block_size_vals: int = 2**30, devices: Optional[List[torch.device]] = None):
    """Find nearest clusters for each batch of data and return their indices"""
    if devices is None:
        devices = [data.device]
    block_size = block_size_vals // len(clusters)
    shard_size = (len(data) - 1) // len(devices) + 1
    data = [
        data[gi * shard_size : (gi + 1) * shard_size].to(devices[gi], non_blocking=True) for gi in range(len(devices))
    ]
    nearest_indices = [torch.empty(len(data[gi]), dtype=torch.int64, device=devices[gi]) for gi in range(len(devices))]
    clusters = [clusters.to(device, non_blocking=True) for device in devices]

    for block_start in range(0, shard_size, block_size):
        for gi in range(len(devices)):
            nearest_indices[gi][block_start : block_start + block_size] = torch.addmm(
                torch.bmm(clusters[gi][:, None, :], clusters[gi][:, :, None]).flatten(),
                data[gi][block_start : block_start + block_size],
                clusters[gi].T,
                beta=-0.5,
            ).argmax(1)
    clusters = clusters[0]
    nearest_indices = torch.cat([nearest_indices[gi].to(devices[0]) for gi in range(len(devices))], dim=0)
    reconstructed_data = clusters[nearest_indices]
    return nearest_indices, reconstructed_data
def fit_kmeans(
    data: torch.Tensor,
    k: int,
    max_iter: int = 1000,
    check_every: int = 10,
    rtol: float = 1e-06,
    atol: float = 1e-08,
    greedy_init: bool = False,
    block_size_vals: int = 2**30,
    devices: Optional[List[torch.device]] = None,
):
    """
    :param data: [nsamples, dim]
    :param k: number of centroids
    :param max_iter: run at most this many iterations
    :param check_every: check for convergence (allclose(new_centroids, old_centroids)) once in this many steps
    :param rtol: early stopping relative tolerance for centroids
    :param atol: early stopping absolute tolerance for centroids
    :param greedy_init: if True, init by greedily selecting the point that is farthest from any cluster
        if False (default), initialize with random points using pytorch global RNG
    :param block_size_vals: how many dot products to compute at a time
    :param devices: if specified, run kmeans in data-parallel mode across these devices
    :return: (clusters float[k, dim], data_indices int[nsamples], reconstructed_data: float[nsamples, dim])
    """
    if devices is None:
        devices = [data.device]

    if greedy_init:
        clusters = _kmeans_greedy_init(data, k)
    else:
        clusters = data[torch.randperm(data.shape[0])[:k], :]  # [k, dim]

    block_size = block_size_vals // k
    shard_size = (len(data) - 1) // len(devices) + 1
    data = [
        data[gi * shard_size : (gi + 1) * shard_size].to(devices[gi], non_blocking=True) for gi in range(len(devices))
    ]
    nearest_indices = [torch.empty(len(data[gi]), dtype=torch.int64, device=devices[gi]) for gi in range(len(devices))]
    clusters = [clusters.to(device, non_blocking=True) for device in devices]

    for i in range(max_iter):
        for block_start in range(0, shard_size, block_size):
            for gi in range(len(devices)):
                nearest_indices[gi][block_start : block_start + block_size] = torch.addmm(
                    torch.bmm(clusters[gi][:, None, :], clusters[gi][:, :, None]).flatten(),
                    data[gi][block_start : block_start + block_size],
                    clusters[gi].T,
                    beta=-0.5,
                ).argmax(1)
            # note: the above formula equals to - 0.5 || data[:, None, :] - clusters[None, :, :] || ^ 2 + const

        if len(devices) == 1:
            new_clusters = [
                clusters[0]
                .clone()
                .index_reduce_(dim=0, index=nearest_indices[0], source=data[0], reduce="mean", include_self=False)
            ]
        else:
            cluster_sums = [
                torch.zeros_like(clusters[gi])
                .index_add(dim=0, index=nearest_indices[gi], source=data[gi])
                .to(devices[0], non_blocking=True)
                for gi in range(len(devices))
            ]
            cluster_counts = [
                torch.bincount(nearest_indices[gi], minlength=k).to(devices[0], non_blocking=True)
                for gi in range(len(devices))
            ]
            for gi in range(1, len(devices)):
                cluster_sums[0] += cluster_sums[gi]
                cluster_counts[0] += cluster_counts[gi]

            new_clusters = [cluster_sums[0] / cluster_counts[0].unsqueeze(1).clamp_min(1)]
            new_clusters[0] += (cluster_counts[0].unsqueeze(1) == 0) * clusters[0]
            for gi in range(1, len(devices)):
                new_clusters.append(new_clusters[0].to(devices[gi], non_blocking=True))

        if i % check_every == 0:
            if torch.allclose(new_clusters[0], clusters[0], rtol=rtol, atol=atol):
                break
        clusters = new_clusters
    for block_start in range(0, shard_size, block_size):
        for gi in range(len(devices)):
            nearest_indices[gi][block_start : block_start + block_size] = torch.addmm(
                torch.bmm(clusters[gi][:, None, :], clusters[gi][:, :, None]).flatten(),
                data[gi][block_start : block_start + block_size],
                clusters[gi].T,
                beta=-0.5,
            ).argmax(1)

    clusters = clusters[0]
    nearest_indices = torch.cat([nearest_indices[gi].to(devices[0]) for gi in range(len(devices))], dim=0)
    reconstructed_data = clusters[nearest_indices]
    return clusters, nearest_indices, reconstructed_data
x = torch.randn(4096*4096//16//4, 4, device='cuda')
clusters, nearest_indices, reconstructed_data = fit_kmeans(x,256,1000)
reconstructed_data.shape

  from .autonotebook import tqdm as notebook_tqdm
  clusters[0]


torch.Size([262144, 4])

In [10]:
device = torch.device('cuda:0') 
x = torch.load("/home/quant/test.pth",map_location=device)

In [29]:
# 码本量化过程
# 1. 按列进行除以scale
# 2. 按列切分，因为权重矩阵是转置后的
# 3. kmeans初始化
org_weight = x
codebook_num_bits = 4
codebook_num = 2**(codebook_num_bits)
centroids_num = 256
# 计算每一行的二范数
scales = org_weight.norm(p=2, dim=1, keepdim=True)
# nn.Parameter(scales, requires_grad=True)
# 每一行除以其对应的范数
normalized_tensor = org_weight / scales
weight_list = org_weight.split(org_weight.shape[0]//codebook_num,dim = 0)
clusters_list = []
nearest_indices_list = []
reconstructed_data_list = []
for weight in weight_list:
    clusters, nearest_indices, reconstructed_data=fit_kmeans(weight.view(-1,4),k = centroids_num)
    clusters_list.append(clusters.unsqueeze(0))
    nearest_indices_list.append(nearest_indices.unsqueeze(0))
    reconstructed_data_list.append(reconstructed_data.view(weight.shape))
clusters_merge = torch.cat(clusters_list,dim = 0)
nearest_indices_merge = torch.cat(nearest_indices_list,dim = 0)
reconstructed_data_merge = torch.cat(reconstructed_data_list,dim = 0)


In [30]:
reconstructed_data_merge.shape

torch.Size([4096, 4096])

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
# 假设我们有一个词汇表大小为10，嵌入维度为3
embedding_dim = 3
num_embeddings = 10

# 创建EmbeddingBag对象
embedding_bag = nn.EmbeddingBag(num_embeddings, embedding_dim, mode='mean')

# 一个示例输入，假设0, 1, 2是词汇表中的索引
input = torch.LongTensor([1, 2, 3, 4, 5])

# offsets用于告诉EmbeddingBag每个“bag”（即句子或文档）的起始位置
# 在这个例子中，我们只有一个bag，它从索引0开始
offsets = torch.LongTensor([0])
offsets = torch.arange(0, codes.size(0), dtype=torch.long)
# 通过EmbeddingBag获取嵌入结果
output = F.embedding_bag(offsets,input)

print(output) 

ValueError: weight has to be a 2D Tensor, but got Tensor of dimension 1

In [27]:
import torch
import torch.nn.functional as F
num_embeddings = 10
embedding_dim =  5
embedding_matrix = torch.randn(num_embeddings, embedding_dim)
input_ids = torch.LongTensor([2, 1, 2])  # 示例输入
offsets = torch.LongTensor([0, 1,2])  # 每个序列的开始位置
output = F.embedding_bag(input_ids, embedding_matrix,offsets, mode='sum')
print(embedding_matrix)
print(output)

tensor([[-0.9365,  1.9752,  1.7521, -0.4886, -0.7203],
        [ 0.6167,  0.8720,  1.4587,  0.0234,  0.4352],
        [-0.1475,  0.0564, -0.3602,  0.8113,  0.5882],
        [ 0.1575, -1.2660,  1.6896, -0.3704, -0.1198],
        [-1.0532,  0.2494,  0.5768, -0.1848, -0.2209],
        [ 0.7181,  0.0438, -0.7713,  0.2479, -0.6923],
        [-1.6777, -1.3979,  0.4613, -0.1491, -0.4602],
        [ 0.9126, -0.0039, -1.0053,  0.0780,  1.3526],
        [-0.3347, -0.5506,  0.4579, -0.0217, -0.9277],
        [ 0.5397,  0.9202,  0.6600,  0.8914, -0.2042]])
tensor([[-0.1475,  0.0564, -0.3602,  0.8113,  0.5882],
        [ 0.6167,  0.8720,  1.4587,  0.0234,  0.4352],
        [-0.1475,  0.0564, -0.3602,  0.8113,  0.5882]])


In [26]:
embedding_matrix = torch.rand((2,2,2))
print(embedding_matrix)
embedding_matrix.flatten(0,1)

tensor([[[0.6272, 0.2958],
         [0.2129, 0.6498]],

        [[0.5710, 0.5451],
         [0.1002, 0.7316]]])


tensor([[0.6272, 0.2958],
        [0.2129, 0.6498],
        [0.5710, 0.5451],
        [0.1002, 0.7316]])

In [76]:
import torch.nn as nn
import math
import torch
import os
import sys
# os.environ['CUDA_VISIBLE_DEVICES'] = '3'
import time
import random
import math
from tqdm.auto import trange
# import ipynbname  # pip install ipynbname

import torch.nn as nn
import torch.nn.functional as F
import transformers
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

from typing import Optional,Union,List
def fit_kmeans(
    data: torch.Tensor,
    k: int,
    max_iter: int = 100,
    check_every: int = 10,
    rtol: float = 1e-06,
    atol: float = 1e-08,
    greedy_init: bool = False,
    block_size_vals: int = 2**30,
    devices: Optional[List[torch.device]] = None,
):
    """
    :param data: [nsamples, dim]
    :param k: number of centroids
    :param max_iter: run at most this many iterations
    :param check_every: check for convergence (allclose(new_centroids, old_centroids)) once in this many steps
    :param rtol: early stopping relative tolerance for centroids
    :param atol: early stopping absolute tolerance for centroids
    :param greedy_init: if True, init by greedily selecting the point that is farthest from any cluster
        if False (default), initialize with random points using pytorch global RNG
    :param block_size_vals: how many dot products to compute at a time
    :param devices: if specified, run kmeans in data-parallel mode across these devices
    :return: (clusters float[k, dim], data_indices int[nsamples], reconstructed_data: float[nsamples, dim])
    """
    if devices is None:
        devices = [data.device]

    if greedy_init:
        clusters = _kmeans_greedy_init(data, k)
    else:
        clusters = data[torch.randperm(data.shape[0])[:k], :]  # [k, dim]

    block_size = block_size_vals // k
    shard_size = (len(data) - 1) // len(devices) + 1
    data = [
        data[gi * shard_size : (gi + 1) * shard_size].to(devices[gi], non_blocking=True) for gi in range(len(devices))
    ]
    nearest_indices = [torch.empty(len(data[gi]), dtype=torch.int64, device=devices[gi]) for gi in range(len(devices))]
    clusters = [clusters.to(device, non_blocking=True) for device in devices]

    for i in range(max_iter):
        for block_start in range(0, shard_size, block_size):
            for gi in range(len(devices)):
                nearest_indices[gi][block_start : block_start + block_size] = torch.addmm(
                    torch.bmm(clusters[gi][:, None, :], clusters[gi][:, :, None]).flatten(),
                    data[gi][block_start : block_start + block_size],
                    clusters[gi].T,
                    beta=-0.5,
                ).argmax(1)
            # note: the above formula equals to - 0.5 || data[:, None, :] - clusters[None, :, :] || ^ 2 + const

        if len(devices) == 1:
            new_clusters = [
                clusters[0]
                .clone()
                .index_reduce_(dim=0, index=nearest_indices[0], source=data[0], reduce="mean", include_self=False)
            ]
        else:
            cluster_sums = [
                torch.zeros_like(clusters[gi])
                .index_add(dim=0, index=nearest_indices[gi], source=data[gi])
                .to(devices[0], non_blocking=True)
                for gi in range(len(devices))
            ]
            cluster_counts = [
                torch.bincount(nearest_indices[gi], minlength=k).to(devices[0], non_blocking=True)
                for gi in range(len(devices))
            ]
            for gi in range(1, len(devices)):
                cluster_sums[0] += cluster_sums[gi]
                cluster_counts[0] += cluster_counts[gi]

            new_clusters = [cluster_sums[0] / cluster_counts[0].unsqueeze(1).clamp_min(1)]
            new_clusters[0] += (cluster_counts[0].unsqueeze(1) == 0) * clusters[0]
            for gi in range(1, len(devices)):
                new_clusters.append(new_clusters[0].to(devices[gi], non_blocking=True))

        if i % check_every == 0:
            if torch.allclose(new_clusters[0], clusters[0], rtol=rtol, atol=atol):
                break
        clusters = new_clusters
    for block_start in range(0, shard_size, block_size):
        for gi in range(len(devices)):
            nearest_indices[gi][block_start : block_start + block_size] = torch.addmm(
                torch.bmm(clusters[gi][:, None, :], clusters[gi][:, :, None]).flatten(),
                data[gi][block_start : block_start + block_size],
                clusters[gi].T,
                beta=-0.5,
            ).argmax(1)

    clusters = clusters[0]
    nearest_indices = torch.cat([nearest_indices[gi].to(devices[0]) for gi in range(len(devices))], dim=0)
    reconstructed_data = clusters[nearest_indices]
    return clusters, nearest_indices, reconstructed_data

def get_nearest_indices(
    S: torch.Tensor, #重要性
    W,
    shape, # 权重的原始形状
    centroids,
    devices: Optional[List[torch.device]] = None,
):
    if S is None:
        S = torch.zeros(shape[0]).to(W.device)
        S[0] = 1
        # S[0] = 1
    # if devices is None:
    #     devices = [data.device]
    # W  N*D
    # centroids n_centroids*D
    a1 = W.view(-1,centroids.shape[-1]).unsqueeze(1)
    # S为每一行的重要性权重，将其扩展成矩阵形式，方便计算
    s1 = S.repeat_interleave(shape[1]).view(a1.shape)
    b1 = centroids.unsqueeze(0)
    print(a1.shape)
    print(b1.shape)
    print(s1.shape)
    dist = ((a1-b1)**2*s1).sum(-1)

    # assignments = []
    assignments = dist.argmin(-1)
    return assignments
def quantize(org_weight,codebook_num = 2,centroids_num = 256,block_size = 64,centroid_len = 8):
    # 计算每一行的二范数
    # max_matrix = get_max(org_weight)
    reshspe_weight = org_weight.view(-1,block_size)
    scales = reshspe_weight.norm(p=2, dim=1, keepdim=True).float()
    # nn.Parameter(scales, requires_grad=True)
    # 每一行除以其对应的范数
    normalized_tensor = (reshspe_weight / scales)
    weight_list = normalized_tensor.split(normalized_tensor.shape[0]//codebook_num,dim = 0)
    print(len(weight_list))
    clusters_list = []
    nearest_indices_list = []
    reconstructed_data_list = []
    for weight in weight_list:
        clusters, nearest_indices, reconstructed_data=fit_kmeans(weight.view(-1,centroid_len),k = centroids_num,max_iter= 100)
        clusters_list.append(clusters.unsqueeze(0))
        nearest_indices_list.append(nearest_indices.unsqueeze(0))
        reconstructed_data_list.append(reconstructed_data.view(weight.shape))
        print(nearest_indices.max())
    clusters_merge = torch.cat(clusters_list,dim = 0)
    nearest_indices_merge = torch.cat(nearest_indices_list,dim = 0)
    reconstructed_data_merge = (torch.cat(reconstructed_data_list,dim = 0)*scales)
    reconstructed_data_merge = reconstructed_data_merge.view(org_weight.shape)
    print(reconstructed_data_merge.max())
    return clusters_merge,nearest_indices_merge,reconstructed_data_merge,scales
def col_wise_class(org_weight,class_num,max_iter = 100):
    clusters, nearest_indices, reconstructed_data=fit_kmeans(org_weight,k = class_num,max_iter= 500)
    return nearest_indices
class Quantization(nn.Module):
    
    def __init__(self,layer,codebook_num = 2,centroids_num = 256,bolck_size = 128,centroid_len = 4) -> None:
        super().__init__()
        self.layer = layer.float()
        self.dev = self.layer.weight.device 
        W = self.layer.weight.data.clone().cuda()
        self.rows = W.shape[0]
        self.columns = W.shape[1]
        self.H = torch.zeros((self.columns, self.columns), device=self.dev,dtype= torch.float64)
        self.nsamples = 0
        self.codebook_num = codebook_num
        self.centroids_num = centroids_num
        self.centroid_len = centroid_len
        clusters_merge,nearest_indices_merge,reconstructed_data_merge,scales \
            = quantize(W.float(),codebook_num=codebook_num,block_size=bolck_size,centroid_len=centroid_len)
        self.codebooks = nn.Parameter(clusters_merge,requires_grad=True)
        self.scales = nn.Parameter(scales,requires_grad=True)
        self.scaler_row = torch.zeros((self.columns), device=self.dev)
        self.codes = nn.Parameter(nearest_indices_merge,requires_grad=False)
        self.reconstructed_data_merge = reconstructed_data_merge.to(self.dev)
        self.bolck_size=bolck_size
    def add_batch(self, inp, out):
        if len(inp.shape) == 2:
            inp = inp.unsqueeze(0)
        tmp = inp.shape[0]
        if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D):
            if len(inp.shape) == 3:
                inp = inp.reshape((-1, inp.shape[-1]))
            inp = inp.t()
        self.H *= self.nsamples / (self.nsamples + tmp)
        self.nsamples += tmp
        self.scaler_row *= self.nsamples / (self.nsamples+tmp)
        inp = inp.type(torch.float32)
        self.scaler_row += torch.norm(inp, p=2, dim=1) ** 2  / self.nsamples
        inp = math.sqrt(2 / self.nsamples) * inp.float()
        self.H += inp.matmul(inp.t())
        
    def differentiable_dequantize(self):
        codebook_num = self.codebook_num
        codes = self.codes.clone()
        for i in range(codebook_num):
            codes[i,:]+=self.centroids_num*i
        codebook_offsets = torch.arange(0,self.layer.weight.data.numel()//self.centroid_len).to(self.dev)
        reconstruct_weight = F.embedding_bag(codes.flatten(),self.codebooks.flatten(0,1),codebook_offsets,mode="sum")
        return (reconstruct_weight.view(-1,self.bolck_size)*self.scales).view((self.rows, self.columns))
    def update_index(self):
        reshspe_weight = self.layer.weight.data.clone().cuda().view(-1,self.bolck_size)
        # scales = reshspe_weight.norm(p=2, dim=1, keepdim=True).float()
        # nn.Parameter(scales, requires_grad=True)
        # 每一行除以其对应的范数
        normalized_tensor = (reshspe_weight / self.scales)
        weight_list = normalized_tensor.split(normalized_tensor.shape[0]//self.codebook_num,dim = 0)
        nearest_indices_list = []
        index = 0
        for weight in weight_list:
            nearest_indices=get_nearest_indices(S=None,W = weight.view(-1,self.centroid_len),shape = weight.shape,centroids=self.codebooks[index])
            nearest_indices_list.append(nearest_indices.unsqueeze(0))
            print("shape:",nearest_indices.shape)
            index +=1
        nearest_indices_merge = torch.cat(nearest_indices_list,dim = 0)
        print(nearest_indices_merge.shape)
        self.codes.data  =nearest_indices_merge
    def forward(self):
        weight = self.differentiable_dequantize()
        return weight.to(self.dev)
    def prune_wanda(self,sparsity_ratio = 0.2):
        W_metric = torch.abs(self.weight.data) * torch.sqrt(self.scaler_row.reshape((1,-1)))
        W_mask = (torch.zeros_like(W_metric) == 1)
        sort_res = torch.sort(W_metric, dim=-1, stable=True)
        indices = sort_res[1][:,:int(W_metric.shape[1]*sparsity_ratio)]
        W_mask.scatter_(1, indices, True)
        self.layer.weight.data[W_mask] = 0

In [77]:
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
inp = torch.nn.Linear(1024,1024).cuda()

In [78]:
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
a  = Quantization(layer=inp,bolck_size=128,codebook_num=2)
cnt_codes = a.codes.clone()
print(cnt_codes)

2
tensor(255, device='cuda:0')
tensor(255, device='cuda:0')
tensor(0.0293, device='cuda:0')
tensor([[ 38, 192, 214,  ..., 114, 144, 183],
        [ 66, 221, 174,  ...,  87, 230,  48]], device='cuda:0')


In [79]:
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
a.update_index()
print((a.codes!=cnt_codes).sum())

torch.Size([131072, 1, 4])
torch.Size([1, 256, 4])
torch.Size([131072, 1, 4])
shape: torch.Size([131072])
torch.Size([131072, 1, 4])
torch.Size([1, 256, 4])
torch.Size([131072, 1, 4])
shape: torch.Size([131072])
torch.Size([2, 131072])
tensor(261121, device='cuda:0')


In [70]:
cnt_codes.shape

torch.Size([2, 131072])

In [91]:
print(a.codebooks)

Parameter containing:
tensor([[[ 0.1130,  0.1182, -0.0421, -0.1128],
         [-0.1178,  0.1117,  0.0520, -0.0526],
         [ 0.0144, -0.0559,  0.0112, -0.0201],
         ...,
         [-0.1221, -0.1143, -0.0078,  0.1190],
         [ 0.0592, -0.0468, -0.0631, -0.1185],
         [ 0.0466, -0.0425, -0.1206, -0.0551]],

        [[ 0.1216, -0.0103,  0.0480,  0.1091],
         [ 0.1179, -0.0255, -0.0401, -0.1157],
         [ 0.0006,  0.0678, -0.0354,  0.0707],
         ...,
         [-0.0338, -0.1109,  0.1213,  0.1103],
         [-0.1174,  0.1184, -0.0505,  0.1202],
         [-0.1136,  0.1188,  0.1160,  0.1093]]], device='cuda:0',
       requires_grad=True)


In [93]:
a.update_index()

torch.Size([2097152, 1, 4])
torch.Size([1, 256, 4])
torch.Size([2097152, 1, 4])
torch.Size([2097152, 1, 4])
torch.Size([1, 256, 4])
torch.Size([2097152, 1, 4])


In [94]:
a.codes-cnt_codes

tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]], device='cuda:0')

In [92]:
codebook_num = a.codebook_num
codebook_flatten = a.codebooks.flatten(0,1)
codes = a.codes
print(a.codebooks.shape)
for i in range(1,codebook_num):
    codes[i,:]+=a.centroids_num
codebook_offsets = torch.arange(0,a.layer.weight.data.numel()//4)
print(codes)
print(codebook_offsets.shape)
reconstruct_weight = F.embedding_bag(codes.flatten(),a.codebooks.flatten(0,1),codebook_offsets,mode="sum")
print(reconstruct_weight.view(a.layer.weight.data.shape)*a.scales)

torch.Size([2, 256, 4])
Parameter containing:
tensor([[  2, 123, 203,  ..., 254, 114, 165],
        [435, 318, 304,  ..., 489, 418, 379]])
torch.Size([262144])
tensor([[ 0.0065,  0.0081,  0.0252,  ..., -0.0248,  0.0091, -0.0238],
        [-0.0232,  0.0039, -0.0068,  ...,  0.0117,  0.0238, -0.0078],
        [-0.0062, -0.0201, -0.0252,  ...,  0.0076, -0.0246, -0.0067],
        ...,
        [ 0.0053,  0.0101,  0.0072,  ...,  0.0139, -0.0224,  0.0253],
        [-0.0261,  0.0053, -0.0069,  ...,  0.0244, -0.0058, -0.0122],
        [ 0.0257,  0.0243,  0.0063,  ..., -0.0098,  0.0127,  0.0253]],
       grad_fn=<MulBackward0>)


In [4]:
import torch

# 假设的数据
N = 5
A = torch.tensor([0, 1, 1, 0, 0])  # 类别Tensor
B = torch.randn(N, 10)  # 假设的矩阵B

# 根据A中的类别对B进行重排
sorted_indices = torch.argsort(A)
B_sorted = torch.index_select(B, 0, sorted_indices)

# 复原到原始顺序
# 首先获取复原时的索引，即对sorted_indices进行再次排序的索引
restore_indices = torch.argsort(sorted_indices)
B_restored = torch.index_select(B_sorted, 0, restore_indices)

print("Original B:", B)
print("Sorted B:", B_sorted)
print("Restored B:", B_restored)

Original B: tensor([[-1.7957, -0.5694,  1.1150, -1.5595,  1.2649,  0.5069, -1.8813,  0.7021,
          1.8290,  0.5245],
        [-1.5341, -0.6747,  0.2805,  1.1678,  2.6085, -0.4850,  0.1465,  0.5565,
         -0.7984, -0.4044],
        [ 0.4942, -1.0950, -0.8076, -0.5437,  0.9076, -0.8676,  0.4262, -0.4891,
         -1.2001, -0.6039],
        [-0.9421,  0.8060, -0.5594,  0.2004,  0.2072,  0.0192, -1.1937,  1.2918,
         -0.1507, -0.2783],
        [-0.5567, -1.7914, -0.4568, -0.2659, -0.5638, -0.9680, -0.4975,  0.5015,
         -1.4597,  0.1531]])
Sorted B: tensor([[-1.7957, -0.5694,  1.1150, -1.5595,  1.2649,  0.5069, -1.8813,  0.7021,
          1.8290,  0.5245],
        [-0.9421,  0.8060, -0.5594,  0.2004,  0.2072,  0.0192, -1.1937,  1.2918,
         -0.1507, -0.2783],
        [-0.5567, -1.7914, -0.4568, -0.2659, -0.5638, -0.9680, -0.4975,  0.5015,
         -1.4597,  0.1531],
        [-1.5341, -0.6747,  0.2805,  1.1678,  2.6085, -0.4850,  0.1465,  0.5565,
         -0.7984, -0.404

In [155]:
import torch
l = torch.nn.Linear(1024,1024,bias=False)
x = torch.rand((2,1024))
res1 = l(x)

In [156]:
res1

tensor([[ 0.3014,  0.3794,  0.0551,  ..., -0.1906, -0.1463,  0.5322],
        [ 0.2058,  0.1400,  0.3062,  ..., -0.0755,  0.1639,  0.8108]],
       grad_fn=<MmBackward0>)

In [170]:
res2 = x@l.weight.data.t()
print(res2)

tensor([[ 0.3014,  0.2058],
        [ 0.3794,  0.1400],
        [ 0.0551,  0.3062],
        ...,
        [-0.1906, -0.0755],
        [-0.1463,  0.1639],
        [ 0.5322,  0.8108]])


In [165]:
res2

tensor([[ 0.0025, -0.0523],
        [ 0.1277,  0.2450],
        [ 0.1829, -0.1117],
        ...,
        [-0.0798, -0.0165],
        [ 0.0814,  0.2337],
        [-0.2363, -0.3942]])

In [69]:
a= torch.rand(5,6).cpu().double()*100
b= torch.rand(5,6).cpu().double()*666
c=torch.rand(6,7).cpu().double()

In [82]:
scales = a.norm(p=2, dim=1, keepdim=True).float().double()*10000

In [83]:
((a-b)*scales)@c

tensor([[-1.1561e+09, -1.2428e+09, -1.0333e+09, -1.3948e+09, -1.3496e+09,
         -1.4515e+09, -1.3786e+09],
        [-2.1541e+08, -1.7501e+08, -5.1016e+08, -2.0091e+08, -5.9879e+08,
         -6.4112e+08, -3.6177e+08],
        [-1.6216e+09, -8.0214e+08, -1.6077e+09, -7.1721e+08, -2.2656e+09,
         -2.1320e+09, -1.4620e+09],
        [-1.2804e+09, -9.6570e+08, -1.8153e+09, -1.5185e+09, -1.8825e+09,
         -1.7938e+09, -1.3736e+09],
        [-1.6828e+09, -9.2218e+08, -1.2531e+09, -1.0505e+09, -1.7726e+09,
         -1.5350e+09, -1.3210e+09]], dtype=torch.float64)

In [92]:
scales = a.norm(p=2, dim=1, keepdim=True).float().double()*10
(((a-b)@c)*scales) - (((a*scales-b*scales))@c)
# scale 的运算可以提到最外面

tensor([[-2.3283e-10,  0.0000e+00,  0.0000e+00, -2.3283e-10,  2.3283e-10,
         -2.3283e-10,  2.3283e-10],
        [ 2.9104e-11, -2.9104e-11,  5.8208e-11,  2.9104e-11,  0.0000e+00,
          0.0000e+00, -5.8208e-11],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -4.6566e-10,
          4.6566e-10, -2.3283e-10],
        [ 2.3283e-10,  1.1642e-10,  0.0000e+00, -2.3283e-10,  0.0000e+00,
          2.3283e-10, -2.3283e-10],
        [ 2.3283e-10,  1.1642e-10, -2.3283e-10,  0.0000e+00,  0.0000e+00,
         -2.3283e-10, -2.3283e-10]], dtype=torch.float64)

In [93]:
(((a-b)@c)*scales).norm(p=2)

tensor(7849233.9964, dtype=torch.float64)

In [95]:
(((a-b)@c)).norm(p=2)*scales

tensor([[7634229.0188],
        [5799491.6243],
        [9532574.3193],
        [7950872.9884],
        [6794302.0308]], dtype=torch.float64)

In [7]:
import torch
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
# 输入参数保持不变
M = 256
L_C = 4
L_V = 4096*4096

# 创建随机输入
C = torch.rand(M, L_C).cuda()
V = torch.rand(L_V).cuda()
inp = torch.rand(L_V).cuda()

# 重塑
num_segments = L_V // L_C
V_reshaped = V.view(num_segments, L_C)
inp_reshaped = inp.view(num_segments, L_C).cuda()

# 定义分批处理的大小
batch_size = 1024*1024  # 根据可用内存调整

# 初始化最佳匹配索引的列表
min_distance_indices = []

# 分批处理
for i in range(0, num_segments, batch_size):
    end_i = min(i + batch_size, num_segments)
    V_batch = V_reshaped[i:end_i].unsqueeze(1).expand(-1, M, -1)
    inp_batch = inp_reshaped[i:end_i].unsqueeze(1).expand(-1, M, -1)
    
    # 计算加权欧几里得距离
    weighted_diff = (C.unsqueeze(0).expand(end_i-i, -1, -1) - V_batch) ** 2 * inp_batch
    distances = torch.sqrt(weighted_diff.sum(dim=2))
    
    # 找到最佳匹配的索引
    batch_min_indices = torch.argmin(distances, dim=1)
    min_distance_indices.append(batch_min_indices)

# 合并结果
min_distance_indices1 = torch.cat(min_distance_indices)

print(f"每个子向量最佳匹配的码本向量索引: {min_distance_indices1}")

每个子向量最佳匹配的码本向量索引: tensor([ 99, 191, 234,  ..., 221, 153,  32], device='cuda:0')


In [8]:
def update_index(batch_size,important,)

OutOfMemoryError: CUDA out of memory. Tried to allocate 16.00 GiB. GPU 0 has a total capacty of 39.39 GiB of which 14.47 GiB is free. Including non-PyTorch memory, this process has 24.92 GiB memory in use. Of the allocated memory 21.31 GiB is allocated by PyTorch, and 3.13 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [13]:
min_distance_indices1-min_distance_indices2

tensor([0, 0, 0,  ..., 0, 0, 0])

In [32]:
lin = torch.nn.Linear(3,5,bias=False)

In [33]:
test = torch.rand(7,3)

In [34]:
lin(test) 

tensor([[ 0.2496, -0.2233,  0.3125,  0.1016, -0.2145],
        [ 0.2090, -0.4283,  0.2185,  0.1516, -0.2442],
        [ 0.1277, -0.4315,  0.1069,  0.1331, -0.2003],
        [ 0.4721, -0.5911,  0.5819,  0.2041, -0.4811],
        [ 0.2023, -0.2271,  0.2325,  0.1157, -0.1680],
        [ 0.2309, -0.3092,  0.2754,  0.1145, -0.2326],
        [ 0.1323, -0.4062,  0.1227,  0.1190, -0.2032]], grad_fn=<MmBackward0>)

In [38]:
test@lin.weight.data.T

tensor([[ 0.2496, -0.2233,  0.3125,  0.1016, -0.2145],
        [ 0.2090, -0.4283,  0.2185,  0.1516, -0.2442],
        [ 0.1277, -0.4315,  0.1069,  0.1331, -0.2003],
        [ 0.4721, -0.5911,  0.5819,  0.2041, -0.4811],
        [ 0.2023, -0.2271,  0.2325,  0.1157, -0.1680],
        [ 0.2309, -0.3092,  0.2754,  0.1145, -0.2326],
        [ 0.1323, -0.4062,  0.1227,  0.1190, -0.2032]])

In [2]:
import torch
a= torch.rand((2,6))
s= torch.ones(2)
s[1]=2
b= torch.rand((3,2))

In [8]:
(s.unsqueeze(0).expand(10, -1)).shape

torch.Size([10, 2])

In [33]:
# print(s1.shape)
a1 = a.view(-1,b.shape[1])
a1 = a1.unsqueeze(1)
s1 = s.repeat_interleave(a.shape[1]).view(a1.shape)
print(a1.shape)
b1 = b.unsqueeze(0)
print(b1.shape)
print(s1.shape)
print(((a1-b1)**2*s1).sum(-1))

torch.Size([6, 1, 2])
torch.Size([1, 3, 2])
torch.Size([6, 1, 2])
tensor([[0.1453, 0.7746, 0.1524],
        [0.0386, 0.5657, 0.0353],
        [0.4287, 0.1449, 0.2916],
        [1.6086, 0.8168, 1.2795],
        [1.1933, 0.3701, 0.8724],
        [0.0426, 1.7083, 0.1436]])


In [24]:
s2 = s1*2
print(((a1-b1)**2*s2))

tensor([[[3.9769e-01, 7.0685e-04, 3.0995e-01, 1.9730e-01],
         [7.3857e-03, 1.0877e+00, 1.7051e-03, 1.5028e-01],
         [1.3659e+00, 1.7461e-02, 1.0101e-01, 5.4490e-01],
         [8.7841e-01, 1.0223e-03, 4.7654e-01, 2.7275e-01],
         [1.6983e+00, 2.5771e-02, 6.7909e-02, 3.9137e-02],
         [2.8442e-01, 1.9815e-02, 1.9276e-04, 2.3848e-01]],

        [[1.3231e-02, 1.8867e-02, 1.7542e-01, 1.2636e-01],
         [1.8461e-01, 8.6892e-01, 8.7285e-01, 8.9370e-02],
         [4.2654e-01, 4.5669e-04, 1.6728e+00, 4.2179e-01],
         [1.7777e-01, 2.0375e-02, 8.1362e-02, 1.8796e-01],
         [6.2027e-01, 2.4766e-03, 5.1117e-01, 1.1907e-02],
         [3.1361e-04, 8.9980e-04, 9.7899e-01, 1.5970e-01]],

        [[3.8673e-01, 1.4961e+00, 7.6097e-01, 4.0581e-02],
         [5.9584e-03, 2.3602e-02, 1.2738e-01, 6.6549e-02],
         [1.3455e+00, 1.1330e+00, 4.8886e-06, 8.5630e-03],
         [8.6208e-01, 1.5093e+00, 1.0119e+00, 1.5222e-02],
         [1.6755e+00, 1.0733e+00, 3.3201e-01, 2.0053

In [30]:
H = torch.randn(4, 4)
H= H @ H.T
epsilon = 1e-6
A_reg = H + epsilon * torch.eye(H.size(0))
# H[diag, diag] += damp
print(H<0)
H = torch.linalg.cholesky(H)
print(H.shape)
H = torch.cholesky_inverse(H)
print(H.shape)
H = torch.linalg.cholesky(H, upper=True)
print(H.shape)

tensor([[False,  True,  True,  True],
        [ True, False,  True,  True],
        [ True,  True, False, False],
        [ True,  True, False, False]])
torch.Size([4, 4])
torch.Size([4, 4])
torch.Size([4, 4])


In [26]:
torch.eye(H.size(0))

tensor([[1., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 1.]])

In [1]:
s1

NameError: name 's1' is not defined

In [4]:
from datasets import load_dataset
import os
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
load_dataset("ivanzhouyq/RedPajama-Tiny", split="train")

ConnectionError: Couldn't reach 'ivanzhouyq/RedPajama-Tiny' on the Hub (ConnectionError)

In [38]:
import torch
x = torch.randn(10, 3, 4)  # 例如，形状为[2, 3, 4]
y = torch.randn(3, 4)
# 计算x的整体二范数
((x@y.t())**2).mean()

tensor(8.3071)

In [41]:
((x@y.t())**2).mean()

tensor(0.6681)

In [9]:
(x**2).mean()

tensor(1.3585)