In [None]:
from pytorch_metric_learning import losses, miners
from pytorch_metric_learning.distances import CosineSimilarity, DotProductSimilarity
from pytorch_metric_learning.miners import BaseMiner
from pytorch_metric_learning.utils import loss_and_miner_utils as lmu
from pytorch_metric_learning.utils import common_functions as c_f
import torch

import numpy as np
from tqdm import tqdm
import time

##########
# This file contains helper functions to retrieve loss functions and miners from a list of notable ones.
# It is always possible to use different loss functions or miners by providing them manually to the network model.
##########


#Loss functions
def get_loss(loss_name):
    """Instantiate one of the supported pytorch-metric-learning losses by name.

    Args:
        loss_name: key selecting the loss, e.g. ``'SupConLoss'`` or ``'Lifted'``.

    Returns:
        A newly constructed loss object configured with the hyper-parameters below.

    Raises:
        NotImplementedError: if ``loss_name`` matches none of the supported keys.
    """
    # Ordered (name, factory) pairs; factories are lazy so only the selected
    # loss is ever constructed.
    registry = (
        ('SupConLoss', lambda: losses.SupConLoss(temperature=0.07)),
        # Circle-loss parameters used for image retrieval.
        ('CircleLoss', lambda: losses.CircleLoss(m=0.4, gamma=80)),
        ('MultiSimilarityLoss', lambda: losses.MultiSimilarityLoss(
            alpha=1.0, beta=50, base=0.0, distance=DotProductSimilarity())),
        ('ContrastiveLoss', lambda: losses.ContrastiveLoss(pos_margin=0, neg_margin=1)),
        ('Lifted', lambda: losses.GeneralizedLiftedStructureLoss(
            neg_margin=0, pos_margin=1, distance=DotProductSimilarity())),
        ('FastAPLoss', lambda: losses.FastAPLoss(num_bins=10)),
        # The MoCo paper uses a temperature of 0.07, while SimCLR uses 0.5.
        ('NTXentLoss', lambda: losses.NTXentLoss(temperature=0.07)),
        # triplets_per_anchor may also be an int, for example 100.
        ('TripletMarginLoss', lambda: losses.TripletMarginLoss(
            margin=0.1, swap=False, smooth_loss=False, triplets_per_anchor='all')),
        ('CentroidTripletLoss', lambda: losses.CentroidTripletLoss(
            margin=0.05, swap=False, smooth_loss=False, triplets_per_anchor="all")),
    )

    for candidate, build in registry:
        if loss_name == candidate:
            return build()

    raise NotImplementedError(f'Sorry, <{loss_name}> loss function is not implemented!')

#Miners
def get_miner(miner_name, margin=0.1):
    """Instantiate a miner by name, or return ``None`` when no miner matches.

    Args:
        miner_name: key selecting the miner. The ``New*`` variants are the
            custom miners defined in this file.
        margin: forwarded as ``margin`` to the triplet-margin miners and as
            ``epsilon`` to the multi-similarity miners; the pair-margin miners
            use fixed margins instead.

    Returns:
        A miner instance, or ``None`` if ``miner_name`` is not recognised.
    """
    # Ordered (name, factory) pairs; factories are lazy so only the selected
    # miner is constructed. type_of_triplets options: all, hard, semihard, easy.
    registry = (
        ('TripletMarginMiner',
         lambda: miners.TripletMarginMiner(margin=margin, type_of_triplets="semihard")),
        ('MultiSimilarityMiner',
         lambda: miners.MultiSimilarityMiner(epsilon=margin, distance=CosineSimilarity())),
        ('PairMarginMiner',
         lambda: miners.PairMarginMiner(pos_margin=0.7, neg_margin=0.3, distance=DotProductSimilarity())),
        ('NewTripletMarginMiner',
         lambda: NewTripletMarginMiner(margin=margin, type_of_triplets="semihard")),
        ('NewMultiSimilarityMiner',
         lambda: NewMultiSimilarityMiner(epsilon=margin, distance=CosineSimilarity())),
        ('NewPairMarginMiner',
         lambda: NewPairMarginMiner(pos_margin=0.7, neg_margin=0.3, distance=DotProductSimilarity())),
    )

    for candidate, build in registry:
        if miner_name == candidate:
            return build()
    return None

#New implementation of the MultiSimilarity Miner
class NewMultiSimilarityMiner(BaseMiner):
    """Multi-similarity miner with user-defined additional positive classes.

    Mirrors pytorch-metric-learning's ``MultiSimilarityMiner``. The difference
    is that two samples also count as a positive pair when their labels are
    linked in ``additional_positive_instances``, a dict-like object loaded
    from disk that maps a label to an iterable of labels.

    Args:
        epsilon: margin used when selecting hard positives and negatives.
        positives_path: keyword-only path to a ``.npy`` file holding the pickled
            extra-positives mapping. Defaults to ``DEFAULT_POSITIVES_PATH``.
        **kwargs: forwarded to ``BaseMiner`` (e.g. ``distance``).
    """

    # Default location of the extra-positives mapping. It is environment
    # specific, so pass ``positives_path`` when running elsewhere.
    DEFAULT_POSITIVES_PATH = '/content/drive/MyDrive/Colab_Notebooks/Dataframes/positives_classes.npy'

    def __init__(self, epsilon=0.1, *, positives_path=None, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon
        self.add_to_recordable_attributes(name="epsilon", is_stat=False)

        if positives_path is None:
            positives_path = self.DEFAULT_POSITIVES_PATH
        # NOTE(security): allow_pickle=True unpickles arbitrary objects; only
        # load files from a trusted source.
        self.additional_positive_instances = np.load(positives_path, allow_pickle=True).item()

    def mine(self, embeddings, labels, ref_emb, ref_labels):
        """Select hard positive and hard negative pairs.

        Returns:
            ``(a1, p, a2, n)`` index tensors. ``(a1[k], p[k])`` are mined
            positive pairs and ``(a2[k], n[k])`` are mined negative pairs. Row
            indices refer to ``embeddings`` and column indices to ``ref_emb``.
        """
        mat = self.distance(embeddings, ref_emb)
        a1, p, a2, n = self.get_all_pairs_indices(labels, ref_labels)

        # Nothing can be mined without at least one positive and one negative pair.
        if len(a1) == 0 or len(a2) == 0:
            empty = torch.tensor([], device=labels.device, dtype=torch.long)
            return empty.clone(), empty.clone(), empty.clone(), empty.clone()

        # mat_neg_sorting aliases mat. The positive copy must be independent
        # because both matrices are overwritten in place below.
        mat_neg_sorting = mat
        mat_pos_sorting = mat.clone()

        # Sentinel values written into excluded entries so they fail the
        # hardness tests below.
        dtype = mat.dtype
        pos_ignore = (
            c_f.pos_inf(dtype) if self.distance.is_inverted else c_f.neg_inf(dtype)
        )
        neg_ignore = (
            c_f.neg_inf(dtype) if self.distance.is_inverted else c_f.pos_inf(dtype)
        )

        # Hide negatives from the positive ranking and positives from the negative one.
        mat_pos_sorting[a2, n] = pos_ignore
        mat_neg_sorting[a1, p] = neg_ignore
        if embeddings is ref_emb:
            # A sample compared with itself is neither a positive nor a negative.
            mat_pos_sorting.fill_diagonal_(pos_ignore)
            mat_neg_sorting.fill_diagonal_(neg_ignore)

        pos_sorted, pos_sorted_idx = torch.sort(mat_pos_sorting, dim=1)
        neg_sorted, neg_sorted_idx = torch.sort(mat_neg_sorting, dim=1)

        # Keep positives that are within epsilon of the hardest negative, and
        # negatives that are within epsilon of the hardest positive.
        if self.distance.is_inverted:
            hard_pos_idx = torch.where(
                pos_sorted - self.epsilon < neg_sorted[:, -1].unsqueeze(1)
            )
            hard_neg_idx = torch.where(
                neg_sorted + self.epsilon > pos_sorted[:, 0].unsqueeze(1)
            )
        else:
            hard_pos_idx = torch.where(
                pos_sorted + self.epsilon > neg_sorted[:, 0].unsqueeze(1)
            )
            hard_neg_idx = torch.where(
                neg_sorted - self.epsilon < pos_sorted[:, -1].unsqueeze(1)
            )

        # Map positions in the sorted rows back to reference-column indices.
        a1 = hard_pos_idx[0]
        p = pos_sorted_idx[a1, hard_pos_idx[1]]
        a2 = hard_neg_idx[0]
        n = neg_sorted_idx[a2, hard_neg_idx[1]]

        return a1, p, a2, n

    def get_default_distance(self):
        """Use cosine similarity when no distance is supplied."""
        return CosineSimilarity()

    def get_all_pairs_indices(self, labels, ref_labels=None):
        """
        Given a tensor of labels, this will return 4 tensors.
        The first 2 tensors are the indices which form all positive pairs
        The second 2 tensors are the indices which form all negative pairs
        """
        matches, diffs = self.get_matches_and_diffs(labels, ref_labels)
        a1_idx, p_idx = torch.where(matches)
        a2_idx, n_idx = torch.where(diffs)
        return a1_idx, p_idx, a2_idx, n_idx

    def output_assertion(self, output):
        """
        Args:
            output: the output of self.mine
        This asserts that the mining function is outputting
        properly formatted indices. The default is to require a tuple representing
        a,p,n indices or a1,p,a2,n indices within a batch of embeddings.
        For example, a tuple of (anchors, positives, negatives) will be
        (torch.tensor, torch.tensor, torch.tensor)
        """
        if len(output) == 3:
            self.num_triplets = len(output[0])
            assert self.num_triplets == len(output[1]) == len(output[2])
        elif len(output) == 4:
            self.num_pos_pairs = len(output[0])
            self.num_neg_pairs = len(output[2])
            assert self.num_pos_pairs == len(output[1])
            assert self.num_neg_pairs == len(output[3])
        else:
            raise TypeError

    def get_matches_and_diffs(self, labels, ref_labels=None):
        """Build uint8 masks of positive (``matches``) and negative (``diffs``) pairs.

        Entry ``[i, j]`` compares ``labels[i]`` with ``ref_labels[j]``. A pair
        is positive if the labels are equal, or if ``ref_labels[j]`` is listed
        in ``additional_positive_instances[labels[i]]``. Every other pair is
        negative. When a batch is compared with itself, the diagonal is
        excluded from both masks.
        """
        if ref_labels is None:
            ref_labels = labels
        same_batch = ref_labels is labels

        matches = (labels.unsqueeze(1) == ref_labels.unsqueeze(0)).byte()

        # Map each reference label to the columns where it occurs.
        ref_positions = {}
        for j, ref_label in enumerate(ref_labels.tolist()):
            ref_positions.setdefault(ref_label, []).append(j)

        extra = self.additional_positive_instances
        for i, label in enumerate(labels.tolist()):
            if label not in extra:
                continue
            for extra_label in extra[label]:
                for j in ref_positions.get(extra_label, ()):
                    matches[i, j] = 1
                    if same_batch:
                        # Keep the self-comparison mask symmetric. For distinct
                        # query/reference sets the transposed index is meaningless.
                        matches[j, i] = 1

        # Derive negatives only after the extra positives are in place, so a
        # pair is never simultaneously positive and negative.
        diffs = matches ^ 1

        if same_batch:
            matches.fill_diagonal_(0)

        return matches, diffs

#New implementation of the TripletMargin Miner
class NewTripletMarginMiner(BaseMiner):
    """
    Returns triplets that violate the margin, treating user-defined additional
    label pairs (loaded from disk) as positives too.
    Args:
        margin
        type_of_triplets: options are "all", "hard", "semihard", or "easy".
            "all" means all triplets that violate the margin
            "hard" is a subset of "all", but the negative is closer to the anchor than the positive
            "semihard" is a subset of "all", but the negative is further from the anchor than the positive
            "easy" is all triplets that are not in "all"
        positives_path: keyword-only path to a ``.npy`` file holding the pickled
            mapping label -> iterable of extra positive labels. Defaults to
            ``DEFAULT_POSITIVES_PATH``.
    """

    # Default location of the extra-positives mapping. It is environment
    # specific, so pass ``positives_path`` when running elsewhere.
    DEFAULT_POSITIVES_PATH = '/content/drive/MyDrive/Colab_Notebooks/Dataframes/positives_classes.npy'

    def __init__(self, margin=0.2, type_of_triplets="all", *, positives_path=None, **kwargs):
        super().__init__(**kwargs)
        self.margin = margin
        self.type_of_triplets = type_of_triplets
        self.add_to_recordable_attributes(list_of_names=["margin"], is_stat=False)
        self.add_to_recordable_attributes(
            list_of_names=["avg_triplet_margin", "pos_pair_dist", "neg_pair_dist"],
            is_stat=True,
        )

        if positives_path is None:
            positives_path = self.DEFAULT_POSITIVES_PATH
        # NOTE(security): allow_pickle=True unpickles arbitrary objects; only
        # load files from a trusted source.
        self.additional_positive_instances = np.load(positives_path, allow_pickle=True).item()

    def mine(self, embeddings, labels, ref_emb, ref_labels):
        """Return ``(anchor, positive, negative)`` index tensors selected by ``type_of_triplets``."""
        anchor_idx, positive_idx, negative_idx = self.get_all_triplets_indices(
            labels, ref_labels
        )
        mat = self.distance(embeddings, ref_emb)
        ap_dist = mat[anchor_idx, positive_idx]
        an_dist = mat[anchor_idx, negative_idx]
        # Positive margin means the negative is further from the anchor than the
        # positive, whether ``mat`` holds distances or similarities.
        triplet_margin = (
            ap_dist - an_dist if self.distance.is_inverted else an_dist - ap_dist
        )

        self.set_stats(ap_dist, an_dist, triplet_margin)

        if self.type_of_triplets == "easy":
            threshold_condition = triplet_margin > self.margin
        else:
            threshold_condition = triplet_margin <= self.margin
            if self.type_of_triplets == "hard":
                threshold_condition &= triplet_margin <= 0
            elif self.type_of_triplets == "semihard":
                threshold_condition &= triplet_margin > 0

        return (
            anchor_idx[threshold_condition],
            positive_idx[threshold_condition],
            negative_idx[threshold_condition],
        )

    def get_all_triplets_indices(self, labels, ref_labels=None):
        """Return index tensors of every (anchor, positive, negative) triplet."""
        all_matches, all_diffs = self.get_matches_and_diffs(labels, ref_labels)

        if (
            all_matches.shape[0] * all_matches.shape[1] * all_matches.shape[1]
            < torch.iinfo(torch.int32).max
        ):
            # torch.nonzero is not supported for tensors with more than INT_MAX elements
            return self.get_all_triplets_indices_vectorized_method(all_matches, all_diffs)

        return self.get_all_triplets_indices_loop_method(labels, all_matches, all_diffs)

    def set_stats(self, ap_dist, an_dist, triplet_margin):
        """Record mean anchor-positive, anchor-negative and margin values when stats are collected."""
        if self.collect_stats:
            with torch.no_grad():
                self.pos_pair_dist = torch.mean(ap_dist).item()
                self.neg_pair_dist = torch.mean(an_dist).item()
                self.avg_triplet_margin = torch.mean(triplet_margin).item()

    def get_matches_and_diffs(self, labels, ref_labels=None):
        """Build uint8 masks of positive (``matches``) and negative (``diffs``) pairs.

        Entry ``[i, j]`` compares ``labels[i]`` with ``ref_labels[j]``. A pair
        is positive if the labels are equal, or if ``ref_labels[j]`` is listed
        in ``additional_positive_instances[labels[i]]``. Every other pair is
        negative. When a batch is compared with itself, the diagonal is
        excluded from both masks.
        """
        if ref_labels is None:
            ref_labels = labels
        same_batch = ref_labels is labels

        matches = (labels.unsqueeze(1) == ref_labels.unsqueeze(0)).byte()

        # Map each reference label to the columns where it occurs.
        ref_positions = {}
        for j, ref_label in enumerate(ref_labels.tolist()):
            ref_positions.setdefault(ref_label, []).append(j)

        extra = self.additional_positive_instances
        for i, label in enumerate(labels.tolist()):
            if label not in extra:
                continue
            for extra_label in extra[label]:
                for j in ref_positions.get(extra_label, ()):
                    matches[i, j] = 1
                    if same_batch:
                        # Keep the self-comparison mask symmetric. For distinct
                        # query/reference sets the transposed index is meaningless.
                        matches[j, i] = 1

        # Derive negatives only after the extra positives are in place, so no
        # triplet can use an extra positive as its negative.
        diffs = matches ^ 1

        if same_batch:
            matches.fill_diagonal_(0)

        return matches, diffs

    def get_all_triplets_indices_vectorized_method(self, all_matches, all_diffs):
        """Enumerate triplets via a broadcast (anchor, positive, negative) mask."""
        triplets = all_matches.unsqueeze(2) * all_diffs.unsqueeze(1)
        return torch.where(triplets)

    def get_all_triplets_indices_loop_method(self, labels, all_matches, all_diffs):
        """Enumerate triplets anchor by anchor (used when the 3-D mask would be too large)."""
        all_matches, all_diffs = all_matches.bool(), all_diffs.bool()

        # Find anchors with at least a positive and a negative
        indices = torch.arange(0, len(labels), device=labels.device)
        indices = indices[all_matches.any(dim=1) & all_diffs.any(dim=1)]

        # No triplets found
        if len(indices) == 0:
            return (
                torch.tensor([], device=labels.device, dtype=labels.dtype),
                torch.tensor([], device=labels.device, dtype=labels.dtype),
                torch.tensor([], device=labels.device, dtype=labels.dtype),
            )

        # Compute all triplets: every positive of the anchor paired with every negative.
        anchors = []
        positives = []
        negatives = []
        for i in indices:
            matches = all_matches[i].nonzero(as_tuple=False).squeeze(1)
            diffs = all_diffs[i].nonzero(as_tuple=False).squeeze(1)
            nd = len(diffs)
            nm = len(matches)
            matches = matches.repeat_interleave(nd)
            diffs = diffs.repeat(nm)
            anchors.append(
                torch.full((len(matches),), i, dtype=labels.dtype, device=labels.device)
            )
            positives.append(matches)
            negatives.append(diffs)
        return torch.cat(anchors), torch.cat(positives), torch.cat(negatives)

    def output_assertion(self, output):
        """
        Args:
            output: the output of self.mine
        This asserts that the mining function is outputting
        properly formatted indices. The default is to require a tuple representing
        a,p,n indices or a1,p,a2,n indices within a batch of embeddings.
        For example, a tuple of (anchors, positives, negatives) will be
        (torch.tensor, torch.tensor, torch.tensor)
        """
        if len(output) == 3:
            self.num_triplets = len(output[0])
            assert self.num_triplets == len(output[1]) == len(output[2])
        elif len(output) == 4:
            self.num_pos_pairs = len(output[0])
            self.num_neg_pairs = len(output[2])
            assert self.num_pos_pairs == len(output[1])
            assert self.num_neg_pairs == len(output[3])
        else:
            raise TypeError

#New implementation of the PairMargin Miner
class NewPairMarginMiner(BaseMiner):
    """
    Returns positive pairs that violate ``pos_margin`` and negative pairs that
    violate ``neg_margin``. For a plain distance, these are positives farther
    than ``pos_margin`` and negatives closer than ``neg_margin``; the
    comparisons flip for an inverted distance (similarity). User-defined
    additional label pairs, loaded from disk, are treated as positives too.

    Args:
        pos_margin: threshold for positive pairs.
        neg_margin: threshold for negative pairs.
        positives_path: keyword-only path to a ``.npy`` file holding the pickled
            mapping label -> iterable of extra positive labels. Defaults to
            ``DEFAULT_POSITIVES_PATH``.
        **kwargs: forwarded to ``BaseMiner`` (e.g. ``distance``).
    """

    # Default location of the extra-positives mapping. It is environment
    # specific, so pass ``positives_path`` when running elsewhere.
    DEFAULT_POSITIVES_PATH = '/content/drive/MyDrive/Colab_Notebooks/Dataframes/positives_classes.npy'

    def __init__(self, pos_margin=0.2, neg_margin=0.8, *, positives_path=None, **kwargs):
        super().__init__(**kwargs)
        self.pos_margin = pos_margin
        self.neg_margin = neg_margin
        self.add_to_recordable_attributes(
            list_of_names=["pos_margin", "neg_margin"], is_stat=False
        )
        self.add_to_recordable_attributes(
            list_of_names=["pos_pair_dist", "neg_pair_dist"], is_stat=True
        )

        if positives_path is None:
            positives_path = self.DEFAULT_POSITIVES_PATH
        # NOTE(security): allow_pickle=True unpickles arbitrary objects; only
        # load files from a trusted source.
        self.additional_positive_instances = np.load(positives_path, allow_pickle=True).item()

    def mine(self, embeddings, labels, ref_emb, ref_labels):
        """Return ``(a1, p, a2, n)`` index tensors of margin-violating positive and negative pairs."""
        mat = self.distance(embeddings, ref_emb)
        a1, p, a2, n = self.get_all_pairs_indices(labels, ref_labels)
        pos_pair = mat[a1, p]
        neg_pair = mat[a2, n]
        self.set_stats(pos_pair, neg_pair)
        pos_mask = (
            pos_pair < self.pos_margin
            if self.distance.is_inverted
            else pos_pair > self.pos_margin
        )
        neg_mask = (
            neg_pair > self.neg_margin
            if self.distance.is_inverted
            else neg_pair < self.neg_margin
        )
        return a1[pos_mask], p[pos_mask], a2[neg_mask], n[neg_mask]

    def set_stats(self, pos_pair, neg_pair):
        """Record mean positive/negative pair values (0 when a set is empty) when stats are collected."""
        if self.collect_stats:
            with torch.no_grad():
                self.pos_pair_dist = (
                    torch.mean(pos_pair).item() if len(pos_pair) > 0 else 0
                )
                self.neg_pair_dist = (
                    torch.mean(neg_pair).item() if len(neg_pair) > 0 else 0
                )

    def get_all_pairs_indices(self, labels, ref_labels=None):
        """
        Given a tensor of labels, this will return 4 tensors.
        The first 2 tensors are the indices which form all positive pairs
        The second 2 tensors are the indices which form all negative pairs
        """
        matches, diffs = self.get_matches_and_diffs(labels, ref_labels)
        a1_idx, p_idx = torch.where(matches)
        a2_idx, n_idx = torch.where(diffs)
        return a1_idx, p_idx, a2_idx, n_idx

    def get_matches_and_diffs(self, labels, ref_labels=None):
        """Build uint8 masks of positive (``matches``) and negative (``diffs``) pairs.

        Entry ``[i, j]`` compares ``labels[i]`` with ``ref_labels[j]``. A pair
        is positive if the labels are equal, or if ``ref_labels[j]`` is listed
        in ``additional_positive_instances[labels[i]]``. Every other pair is
        negative. When a batch is compared with itself, the diagonal is
        excluded from both masks.
        """
        if ref_labels is None:
            ref_labels = labels
        same_batch = ref_labels is labels

        matches = (labels.unsqueeze(1) == ref_labels.unsqueeze(0)).byte()

        # Map each reference label to the columns where it occurs.
        ref_positions = {}
        for j, ref_label in enumerate(ref_labels.tolist()):
            ref_positions.setdefault(ref_label, []).append(j)

        extra = self.additional_positive_instances
        for i, label in enumerate(labels.tolist()):
            if label not in extra:
                continue
            for extra_label in extra[label]:
                for j in ref_positions.get(extra_label, ()):
                    matches[i, j] = 1
                    if same_batch:
                        # Keep the self-comparison mask symmetric. For distinct
                        # query/reference sets the transposed index is meaningless.
                        matches[j, i] = 1

        # Derive negatives only after the extra positives are in place, so a
        # pair is never returned as both a positive and a negative.
        diffs = matches ^ 1

        if same_batch:
            matches.fill_diagonal_(0)

        return matches, diffs

    def output_assertion(self, output):
        """
        Args:
            output: the output of self.mine
        This asserts that the mining function is outputting
        properly formatted indices. The default is to require a tuple representing
        a,p,n indices or a1,p,a2,n indices within a batch of embeddings.
        For example, a tuple of (anchors, positives, negatives) will be
        (torch.tensor, torch.tensor, torch.tensor)
        """
        if len(output) == 3:
            self.num_triplets = len(output[0])
            assert self.num_triplets == len(output[1]) == len(output[2])
        elif len(output) == 4:
            self.num_pos_pairs = len(output[0])
            self.num_neg_pairs = len(output[2])
            assert self.num_pos_pairs == len(output[1])
            assert self.num_neg_pairs == len(output[3])
        else:
            raise TypeError