In [125]:
import torch
import time as time

In [123]:
class Chunk:
    def __init__(self, X,Y=None):
        self.X = X
        self.Y = Y

        if self.Y == None:
            self.Y = self.X

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def sim_matrix(self, a, b, eps=1e-8):
        """
        Compute the cosine similarity between two matrices of vectors
        :param a: matrix of vectors (n x d)
        :param b: matrix of vectors (m x d)
        :param eps: added eps for numerical stability
        :return: scalar product between each vector of a and each vector of b (n x m)
        """
        a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
        a_norm = a / torch.clamp(a_n, min=eps)
        b_norm = b / torch.clamp(b_n, min=eps)
        sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1))
        return sim_mt
    
    def compute_sim_matrix(self, keep_n=10, chunk_size=100, verbose=True):
        """
        Compute the similarity matrix between X and Y and return the indices of the top-n elements as well as the distances
        Args:
        keep_n: number of elements to keep
        chunk_size: size of the chunks to split the data. This is useful to avoid memory issues
        """
        assert keep_n <= chunk_size, "keep_n should be less than or equal to chunk_size"
        assert keep_n <= self.Y.shape[0], "keep_n should be less or equal to the number of elements in Y"
        if self.device == "cuda":
            torch.cuda.empty_cache()
        indices = torch.zeros(self.X.shape[0], keep_n)
        distances = torch.zeros(self.X.shape[0], keep_n)

        splits_X = self.X.split(chunk_size,dim=0)
        split_lenght_X = [i.shape[0] for i in splits_X]

        splits_Y = self.Y.split(chunk_size,dim=0)
        split_lenghts_Y = [i.shape[0] for i in splits_Y]

        print(f"Number of chunks for X: {len(splits_X)}")
        print(f"Number of chunks for Y: {len(splits_Y)}")

        start = time.time()

        for k,i in enumerate(splits_X):
            top_n_all_fused = []
            top_n_all_fused_values = []
            y_dim = i.shape[0]
            for l,j in enumerate(splits_Y):
                if self.device == "cuda":
                    with torch.cuda.amp.autocast():
                        inter = self.sim_matrix(i.to(self.device),j.to(self.device))
                        inter, top_n = torch.topk(inter, k=keep_n, dim=1)
                        top_n_all_fused_values.append(inter)
                        top_n_all_fused.append(top_n+sum(split_lenghts_Y[:l]))
                else:
                    inter = self.sim_matrix(i,j)
                    inter, top_n = torch.topk(inter, k=keep_n, dim=1)
                    top_n_all_fused_values.append(inter)
                    top_n_all_fused.append(top_n+sum(split_lenghts_Y[:l]))

                if verbose == True:
                    print(f"Processing of chunk {k+1}/{len(splits_X)} with chunk {l+1}/{len(splits_Y)} done in {time.time()-start:2.3f}s")

            top_n_all_fused = torch.cat(top_n_all_fused,dim=1)
            top_n_all_fused_values = torch.cat(top_n_all_fused_values,dim=1)

            if self.device == "cuda":    
                with torch.cuda.amp.autocast():
                    val, top_n_all_fused_values = torch.topk(top_n_all_fused_values,k=keep_n,dim=1)
            else:
                val, top_n_all_fused_values = torch.topk(top_n_all_fused_values,k=keep_n,dim=1)

            comb = torch.cat([a[i].reshape(1,-1) for a,i in zip(top_n_all_fused,top_n_all_fused_values)],dim=0)

            indices[sum(split_lenght_X[:k]):sum(split_lenght_X[:k])+y_dim] = comb.cpu()
            distances[sum(split_lenght_X[:k]):sum(split_lenght_X[:k])+y_dim] = val.cpu()

            del val, comb, top_n_all_fused, top_n_all_fused_values, inter, top_n
            if self.device == "cuda":
                torch.cuda.empty_cache()
        return indices, distances
    
    def get_chunk_size(self):
        pass

    def verbose(self, *args):
        pass
        
        
    


In [124]:
a = torch.randn(5000,32)
b = torch.randn(10000,32)
c = Chunk(a,b)
k = c.compute_sim_matrix(keep_n=30, chunk_size=1000, verbose=True)


Number of chunks for X: 5
Number of chunks for Y: 10
Processing of chunk 1/5 with chunk 1/10 done in 0.006s
Processing of chunk 1/5 with chunk 2/10 done in 0.011s
Processing of chunk 1/5 with chunk 3/10 done in 0.016s
Processing of chunk 1/5 with chunk 4/10 done in 0.020s
Processing of chunk 1/5 with chunk 5/10 done in 0.025s
Processing of chunk 1/5 with chunk 6/10 done in 0.028s
Processing of chunk 1/5 with chunk 7/10 done in 0.032s
Processing of chunk 1/5 with chunk 8/10 done in 0.036s
Processing of chunk 1/5 with chunk 9/10 done in 0.040s
Processing of chunk 1/5 with chunk 10/10 done in 0.043s
Processing of chunk 2/5 with chunk 1/10 done in 0.058s
Processing of chunk 2/5 with chunk 2/10 done in 0.062s
Processing of chunk 2/5 with chunk 3/10 done in 0.065s
Processing of chunk 2/5 with chunk 4/10 done in 0.068s
Processing of chunk 2/5 with chunk 5/10 done in 0.071s
Processing of chunk 2/5 with chunk 6/10 done in 0.074s
Processing of chunk 2/5 with chunk 7/10 done in 0.078s
Processing 

In [102]:
k


(tensor([[0.0000e+00, 2.4620e+03, 3.4170e+03,  ..., 2.9430e+03, 2.3330e+03,
          4.9170e+03],
         [1.0000e+00, 2.5680e+03, 2.3110e+03,  ..., 3.7400e+03, 3.6840e+03,
          3.1560e+03],
         [2.0000e+00, 2.7620e+03, 3.3780e+03,  ..., 1.7100e+02, 4.9770e+03,
          3.7700e+02],
         ...,
         [4.9970e+03, 1.7640e+03, 5.3100e+02,  ..., 1.8880e+03, 1.5390e+03,
          8.3300e+02],
         [4.9980e+03, 3.3560e+03, 2.7890e+03,  ..., 4.4920e+03, 3.2230e+03,
          3.3140e+03],
         [4.9990e+03, 6.3400e+02, 3.4480e+03,  ..., 2.2960e+03, 2.9650e+03,
          1.9400e+02]]),
 tensor([[1.0000, 0.5896, 0.5683,  ..., 0.4407, 0.4396, 0.4383],
         [1.0000, 0.7273, 0.5525,  ..., 0.4221, 0.4193, 0.4191],
         [1.0000, 0.5572, 0.5523,  ..., 0.4383, 0.4376, 0.4369],
         ...,
         [1.0000, 0.5746, 0.5151,  ..., 0.4330, 0.4310, 0.4292],
         [1.0000, 0.6351, 0.6035,  ..., 0.4323, 0.4293, 0.4280],
         [1.0000, 0.5834, 0.5682,  ..., 0.4276, 0.4

In [60]:
b

tensor([[ 0.4237,  0.4108,  0.3189,  0.1641,  0.0533, -0.0518, -0.1069, -0.1566,
         -0.1658, -0.1773],
        [ 0.2841,  0.1914, -0.0286, -0.0307, -0.0742, -0.0869, -0.0958, -0.2334,
         -0.3007, -0.4513],
        [ 0.2100,  0.1110,  0.0564,  0.0510, -0.0655, -0.1925, -0.2202, -0.2214,
         -0.2465, -0.2558],
        [ 0.2392,  0.1191,  0.0660, -0.0197, -0.0996, -0.1177, -0.1702, -0.1759,
         -0.3280, -0.3513],
        [ 0.1260,  0.0768,  0.0401, -0.0016, -0.0432, -0.0542, -0.1386, -0.1932,
         -0.2641, -0.4437],
        [ 0.2901,  0.1198, -0.0116, -0.0966, -0.1017, -0.2160, -0.2720, -0.2991,
         -0.3387, -0.3425],
        [ 0.2860,  0.2406,  0.1618,  0.0637,  0.0618, -0.0146, -0.0677, -0.1033,
         -0.1410, -0.1941],
        [ 0.3961,  0.2570,  0.1627,  0.0603, -0.0013, -0.0378, -0.0475, -0.0796,
         -0.1045, -0.1971],
        [ 0.4101,  0.0323,  0.0321,  0.0282,  0.0243, -0.1340, -0.1390, -0.1404,
         -0.1476, -0.1952],
        [ 0.0783,  