# Local Aggregation Loss

In [14]:
import numpy as np

import torch

from sklearn.neighbors import NearestNeighbors
from sklearn.cluster import KMeans
from sklearn.preprocessing import normalize
from scipy.spatial.distance import cosine as cosine_distance

## Teste

In [8]:
nn = NearestNeighbors(n_neighbors=3)
nn.fit([[1, 1, 1], [2, 2, 2], [3, 3, 3]])

NearestNeighbors(n_neighbors=3)

In [13]:
indices_nearest = nn.kneighbors([[1, 1, 2],[3, 3, 2.5]], return_distance=True)
indices_nearest

(array([[1.        , 1.41421356, 3.        ],
        [0.5       , 1.5       , 3.20156212]]), array([[0, 1, 2],
        [2, 1, 0]]))

## Base

### marsaglia

In [None]:
def marsaglia(sphere_dim):
    '''Method to generate a point uniformly distributed on the (N-1) sphere by Marsaglia
    Args:
        sphere_dim (int): dimension of the sphere on which to generate the point
    '''
    norm_vals = np.random.standard_normal(sphere_dim)
    print(norm_vals)
    print(np.linalg.norm(norm_vals))
    print(torch.linalg.norm(torch.from_numpy(norm_vals)))
    #torch.linalg.norm
    return norm_vals / np.linalg.norm(norm_vals)

In [None]:
sphere_dim = 10
t = marsaglia(sphere_dim)

[-0.05454884  1.0800896   0.89597025 -0.02810328  0.39605392 -0.90975487
 -0.63833395  0.2830118   0.54265654 -0.62901078]
2.0335512784319
tensor(2.0336, dtype=torch.float64)


In [None]:
t[0]

-0.026824420777907435

### memory bank

In [None]:
class MemoryBank(object):

    def __init__(self, n_vectors, dim_vector, memory_mixing_rate=None):

        self.dim_vector = dim_vector
        self.vectors = np.array([marsaglia(dim_vector) for _ in range(n_vectors)])
        self.memory_mixing_rate = memory_mixing_rate
        self.mask_init = np.array([False] * n_vectors)

    def update_memory(self, vectors, index):

        if isinstance(index, int):
            self.vectors[index] = self._update_(vectors, self.vectors[index])

        elif isinstance(index, np.ndarray):
            for ind, vector in zip(index, vectors):
                # print(ind)
                # print(vector)
                self.vectors[ind] = self._update_(vector, self.vectors[ind])
                print(self.vectors)

        else:
            raise RuntimeError('Index must be of type integer or NumPy array, not {}'.format(type(index)))

    def mask(self, inds_int):
        ret_mask = []
        
        for row in inds_int:
            row_mask = np.full(self.vectors.shape[0], False)
            print(row_mask)
            row_mask[row.astype(int)] = True
            print(row_mask)
            ret_mask.append(row_mask)
            print(row_mask)

        return np.array(ret_mask)

    def _update_(self, vector_new, vector_recall):
        v_add = vector_new * self.memory_mixing_rate + vector_recall * (1.0 - self.memory_mixing_rate)
        print(v_add / np.linalg.norm(v_add))
        return v_add / np.linalg.norm(v_add)

    def _verify_dim_(self, vector_new):
        if len(vector_new) != self.dim_vector:
            raise VectorUpdateError('Update vector of dimension size {}, '.format(len(vector_new)) + \
                                    'but memory of dimension size {}'.format(self.dim_vector))

In [None]:
mb = MemoryBank(4, 3, 0.1)

[-1.35326955  1.05563127  0.64875714]
1.8348247020678476
tensor(1.8348, dtype=torch.float64)
[ 0.30169195 -0.49381836  0.92747806]
1.0932017909518499
tensor(1.0932, dtype=torch.float64)
[ 0.71598801 -1.13570389 -1.47589053]
1.995172933848223
tensor(1.9952, dtype=torch.float64)
[-1.5349975  -0.02669274 -0.71568599]
1.693852492660835
tensor(1.6939, dtype=torch.float64)


In [None]:
mb.mask(np.array([0, 3]))

[False False False False]
[ True False False False]
[ True False False False]
[False False False False]
[False False False  True]
[False False False  True]


array([[ True, False, False, False],
       [False, False, False,  True]])

In [None]:
mb.vectors

array([[ 0.42802264, -0.64532653, -0.6327324 ],
       [-0.71927906, -0.69011609, -0.07985872],
       [-0.67763074,  0.72475706,  0.12467469],
       [-0.00710815, -0.77451269, -0.63251843]])

In [None]:
mb.mask(np.array([1, 3]))

[False False False False]
[False  True False False]
[False  True False False]
[False False False False]
[False False False  True]
[False False False  True]


array([[False,  True, False, False],
       [False, False, False,  True]])

In [None]:
mb.vectors

tensor([[ 0.3871, -0.7581, -0.5249],
        [ 0.3998, -0.3902,  0.8294]])

In [None]:
t = np.array([[1, 2, 3], [4, 5, 6]])
i = np.array([0, 1])

In [None]:
mb.update_memory(t, i)

[ 0.38707319 -0.75805997 -0.52489945]
[[ 0.38707319 -0.75805997 -0.52489945]
 [-0.01517683 -0.97448762  0.22392754]]
[ 0.39978466 -0.39015898  0.82942643]
[[ 0.38707319 -0.75805997 -0.52489945]
 [ 0.39978466 -0.39015898  0.82942643]]


## LLA

In [None]:
class LocalAggregationLoss(nn.Module):

    def __init__(self, temperature,
                 k_nearest_neighbours, clustering_repeats, number_of_centroids,
                 memory_bank,
                 kmeans_n_init=1, nn_metric=cosine_distance, nn_metric_params={},
                 include_self_index=True, force_stacking=False):
        super(LocalAggregationLoss, self).__init__()

        self.temperature = temperature
        self.memory_bank = memory_bank
        self.include_self_index = include_self_index
        self.force_stacking = force_stacking

        self.background_neighbours = None
        self.close_neighbours = None

        self.neighbour_finder = NearestNeighbors(n_neighbors=k_nearest_neighbours + 1,
                                                 algorithm='ball_tree',
                                                 metric=nn_metric, metric_params=nn_metric_params)
        self.clusterer = []
        for k_clusterer in range(clustering_repeats):

            self.clusterer.append(KMeans(n_clusters=number_of_centroids, init='random', n_init=kmeans_n_init))

    def _nearest_neighbours(self, codes_data, indices):

        self.neighbour_finder.fit(self.memory_bank.vectors)
        indices_nearest = self.neighbour_finder.kneighbors(codes_data, return_distance=False)

        if not self.include_self_index:
            self_neighbour_masks = [np.where(indices_nearest[k] == indices[k]) for k in range(indices_nearest.shape[0])]
            if any([len(x) != 1 for x in self_neighbour_masks]):
                raise RuntimeError('Self neighbours not correctly shaped')
            indices_nearest = np.delete(indices_nearest, self_neighbour_masks, axis=1)

        return self.memory_bank.mask(indices_nearest)

    def _close_grouper(self, indices):
      
        memberships = [[]] * len(indices)
        for clusterer in self.clusterer:
            clusterer.fit(self.memory_bank.vectors)
            
            for k_index, cluster_index in enumerate(clusterer.labels_[indices]):
                other_members = np.where(clusterer.labels_ == cluster_index)[0]
                other_members_union = np.union1d(memberships[k_index], other_members)
                memberships[k_index] = other_members_union.astype(int)

        return self.memory_bank.mask(np.array(memberships, dtype=object))

    def _intersecter(self, n1, n2):
      
        ret = [[v1 and v2 for v1, v2 in zip(n1_x, n2_x)] for n1_x, n2_x in zip(n1, n2)]
        return np.array(ret)

    def _prob_density(self, codes, indices, force_stack=False):
      
        ragged = len(set([np.count_nonzero(ind) for ind in indices])) != 1

        if not ragged and not force_stack:
            vals = torch.tensor([np.compress(ind, self.memory_bank.vectors, axis=0) for ind in indices],
                                requires_grad=False)
            v_dots = torch.matmul(vals, codes.unsqueeze(-1))
            exp_values = torch.exp(torch.div(v_dots, self.temperature))
            xx = torch.sum(exp_values, dim=1).squeeze(-1)

        else:
            xx_container = []
            for k_item in range(codes.size(0)):
                vals = torch.tensor(np.compress(indices[k_item], self.memory_bank.vectors, axis=0),
                                    requires_grad=False)
                v_dots_prime = torch.mv(vals, codes[k_item])
                exp_values_prime = torch.exp(torch.div(v_dots_prime, self.temperature))
                xx_prime = torch.sum(exp_values_prime, dim=0)
                xx_container.append(xx_prime)
            xx = torch.stack(xx_container, dim=0)

        return xx

    def forward(self, codes, indices):

        assert codes.shape[0] == len(indices)

        # print(indices)

        codes = codes.type(torch.DoubleTensor)
        code_data = normalize(codes.detach().numpy(), axis=1)

        self.memory_bank.update_memory(code_data, indices)
        self.background_neighbours = self._nearest_neighbours(code_data, indices)
        self.close_neighbours = self._close_grouper(indices)
        self.neighbour_intersect = self._intersecter(self.background_neighbours, self.close_neighbours)

        v = F.normalize(codes, p=2, dim=1)
        d1 = self._prob_density(v, self.background_neighbours, self.force_stacking)
        d2 = self._prob_density(v, self.neighbour_intersect, self.force_stacking)
        loss_cluster = torch.sum(torch.log(d1) - torch.log(d2)) / codes.shape[0]

        return loss_cluster

## Otimizado

In [None]:
def marsaglia(sphere_dim):

    normal_values = torch.randn(sphere_dim)
    norma         = torch.linalg.norm(normal_values)

    return normal_values/norma

In [None]:
class MemoryBank(object):

    def __init__(self, n_vectors, dimension_vector, memory_mixing_rate=None):

        self.dimension_vector   = dimension_vector
        self.memory_mixing_rate = memory_mixing_rate

        self.vectors            = torch.stack([marsaglia(dimension_vector) for _ in range(n_vectors)], dim=0)
        self.mask_init          = torch.zeros(n_vectors, dtype=torch.bool)

    def update_memory(self, vectors, index):

        if isinstance(index, int):
            self.vectors[index] = self._update_(vectors, self.vectors[index])

        elif isinstance(index, torch.Tensor):
            for ind, vector in zip(index, vectors):                    
                self.vectors[ind] = self._update_(vector, self.vectors[ind])

        else:
            raise RuntimeError('Index must be of type integer or torch.Tensor, not {}'.format(type(index)))  

    def mask(self, inds_int):

        ret_mask = torch.zeros((len(inds_int), self.vectors.shape[0]), dtype=torch.bool)
        
        for idx, row in enumerate(inds_int):
            ret_mask[idx, row] = True

        return ret_mask

    def _update_(self, vector_new, vector_recall):
        v_add = vector_new * self.memory_mixing_rate + vector_recall * (1.0 - self.memory_mixing_rate)
        return v_add / torch.linalg.norm(v_add)

    def _verify_dim_(self, vector_new):
        if len(vector_new) != self.dim_vector:
            raise VectorUpdateError('Update vector of dimension size {}, '.format(len(vector_new)) + \
                                    'but memory of dimension size {}'.format(self.dim_vector))

In [None]:
class LALoss(nn.Module):

    def __init__(self, temperature,
                 k_nearest_neighbours, clustering_repeats, number_of_centroids,
                 memory_bank,
                 kmeans_n_init=1, nn_metric=cosine_distance, nn_metric_params={},
                 include_self_index=True, force_stacking=False):
      
        super(LocalAggregationLoss, self).__init__()

        self.temperature        = temperature
        self.memory_bank        = memory_bank
        self.include_self_index = include_self_index
        self.force_stacking     = force_stacking

        self.background_neighbours = None
        self.close_neighbours      = None

        self.neighbour_finder      = NearestNeighbors(n_neighbors=k_nearest_neighbours + 1,
                                                      algorithm='ball_tree',
                                                      metric=nn_metric, metric_params=nn_metric_params)
        self.clusterer = []

        for k_clusterer in range(clustering_repeats):
            self.clusterer.append(KMeans(n_clusters=number_of_centroids, init='random', n_init=kmeans_n_init))

    # Nearest Neighbours =======================================================================================================

    def _nearest_neighbours(self, codes_data, indices):     # Neighbour Finder

        self.neighbour_finder.fit(self.memory_bank.vectors)
        indices_nearest = self.neighbour_finder.kneighbors(codes_data, return_distance=False)

        if not self.include_self_index:
            self_neighbour_masks = [np.where(indices_nearest[k] == indices[k]) for k in range(indices_nearest.shape[0])]
            if any([len(x) != 1 for x in self_neighbour_masks]):
                raise RuntimeError('Self neighbours not correctly shaped')
            indices_nearest = np.delete(indices_nearest, self_neighbour_masks, axis=1)

        return self.memory_bank.mask(indices_nearest)

    # Close Grouper ============================================================================================================

    def _close_grouper(self, indices):  # self.clusterer
      
        memberships = [[]] * len(indices)
        
        for clusterer in self.clusterer:
            clusterer.fit(self.memory_bank.vectors)
            
            for k_index, cluster_index in enumerate(clusterer.labels_[indices]):
                other_members        = np.where(clusterer.labels_ == cluster_index)[0]
                other_members_union  = np.union1d(memberships[k_index], other_members)
                memberships[k_index] = other_members_union.astype(int)

        return self.memory_bank.mask(np.array(memberships, dtype=object))

    # Intersection =============================================================================================================

    def _intersecter(self, n1, n2):
      
        ret = [[v1 and v2 for v1, v2 in zip(n1_x, n2_x)] for n1_x, n2_x in zip(n1, n2)]
        return np.array(ret)

    # Probability Density ======================================================================================================

    def _prob_density(self, codes, indices, force_stack=False):
      
        ragged = len(set([np.count_nonzero(ind) for ind in indices])) != 1

        if not ragged and not force_stack:
            vals = torch.tensor([np.compress(ind, self.memory_bank.vectors, axis=0) for ind in indices],
                                requires_grad=False)
            v_dots = torch.matmul(vals, codes.unsqueeze(-1))
            exp_values = torch.exp(torch.div(v_dots, self.temperature))
            xx = torch.sum(exp_values, dim=1).squeeze(-1)

        else:
            xx_container = []
            for k_item in range(codes.size(0)):
                vals = torch.tensor(np.compress(indices[k_item], self.memory_bank.vectors, axis=0),
                                    requires_grad=False)
                v_dots_prime = torch.mv(vals, codes[k_item])
                exp_values_prime = torch.exp(torch.div(v_dots_prime, self.temperature))
                xx_prime = torch.sum(exp_values_prime, dim=0)
                xx_container.append(xx_prime)
            xx = torch.stack(xx_container, dim=0)

        return xx

    # Forward ==================================================================================================================

    def forward(self, codes, indices):

        assert codes.shape[0] == len(indices)

        codes     = codes.type(torch.DoubleTensor)
        code_data = normalize(codes.detach().numpy(), axis=1)

        self.memory_bank.update_memory(code_data, indices)

        self.background_neighbours = self._nearest_neighbours(code_data, indices)
        self.close_neighbours      = self._close_grouper(indices)
        self.neighbour_intersect   = self._intersecter(self.background_neighbours, self.close_neighbours)

        v            = F.normalize(codes, p=2, dim=1)
        d1           = self._prob_density(v, self.background_neighbours, self.force_stacking)
        d2           = self._prob_density(v, self.neighbour_intersect, self.force_stacking)
        loss_cluster = torch.sum(torch.log(d1) - torch.log(d2)) / codes.shape[0]

        return loss_cluster

In [None]:
mb = MemoryBank(4, 3, 0.1)

In [None]:
mb.mask([0, 3])

tensor([ True, False, False,  True])

In [None]:
mb.vectors = torch.Tensor([[ 0.14335052, -0.7205703 , -0.67840185], [-0.01517683, -0.97448762,  0.22392754]])

In [None]:
t = torch.Tensor([[1, 2, 3], [4, 5, 6]])
i = torch.tensor([0, 1])

In [None]:
mb.update_memory(t, i)

0 tensor(0) tensor([1., 2., 3.])
1 tensor(1) tensor([4., 5., 6.])


In [None]:
mb.vectors

tensor([[ 0.3871, -0.7581, -0.5249],
        [ 0.3998, -0.3902,  0.8294]])

In [None]:
row_mask = torch.zeros(3, dtype=torch.bool)
print(row_mask)
# row_mask[inds_int] = True

tensor([False, False, False])


In [None]:
row_mask[[1, 2]] = True

In [None]:
row_mask

tensor([False,  True,  True])

In [None]:
print('pão')

pão


In [22]:
def mask(inds_int):
        ret_mask = []
        
        for row in inds_int:
            row_mask = np.full(5, False)
            row_mask[row] = True
            ret_mask.append(row_mask)

        return np.array(ret_mask)

In [22]:
m = mask([[1, 2], [0,1]])
m

[False False False False False]
[False  True  True False False]
[False  True  True False False]
[False False False False False]
[ True  True False False False]
[ True  True False False False]


array([[False,  True,  True, False, False],
       [ True,  True, False, False, False]])

In [24]:
def mask(inds_int):

        ret_mask = torch.zeros((len(inds_int), 5), dtype=torch.bool)
        
        for idx, row in enumerate(inds_int):
            ret_mask[idx, row] = True

        return ret_mask

In [18]:
def _nearest_neighbours():     # Neighbour Finder

    nn = NearestNeighbors(n_neighbors=3)
    nn.fit([[1, 1, 1], [2, 2, 2], [3, 3, 3]])

    indices_nearest = nn.kneighbors([[1, 1, 2],[3, 3, 2.5]], return_distance=False)

    print(indices_nearest)

    return mask(indices_nearest)

In [25]:
m = _nearest_neighbours()
m

[[0 1 2]
 [2 1 0]]


tensor([[ True,  True,  True, False, False],
        [ True,  True,  True, False, False]])

In [9]:
m = mask([[1, 2], [0,1]])
m

tensor([[False,  True,  True, False, False],
        [ True,  True, False, False, False]])

In [6]:
z = torch.zeros((3, 5), dtype=torch.bool)
z

tensor([[False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False]])

In [7]:
z[0, [1, 2]] = 2
z

tensor([[False,  True,  True, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False]])

In [2]:
import torch

In [None]:
nn = NearestNeighbors(n_neighbors=3)
nn.fit([[1, 1, 1], [2, 2, 2], [3, 3, 3]])

indices_nearest = nn.kneighbors([[1, 1, 2],[3, 3, 2.5]], return_distance=True)
indices_nearest