# Contrastive Loss 

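Contrastive losses train an encoder so that similar samples (for example, two views of the same image) are mapped to nearby embeddings and dissimilar samples are mapped to distant ones. The variants below differ in how they pick pairs or triplets and in how they measure distance.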

## 1. Pairwise Contrastive Loss:

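Pairwise contrastive loss is the simplest form of contrastive loss: it encourages similar pairs to have small distances and dissimilar pairs to be at least a margin $m$ apart. It is often used in Siamese networks.

$$ L_{\text{pairwise}} = \frac{1}{N} \sum_{i=1}^{N} \left( y_i \, d_i^2 + (1 - y_i) \max(0,\ m - d_i)^2 \right) $$

where $y_i = 1$ for a similar pair and $y_i = 0$ for a dissimilar pair, and $d_i$ is the predicted distance between the two embeddings of pair $i$ (`y_pred` in the code).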


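```python
from tensorflow.keras import backend as K

# Define margin for the contrastive loss
margin = 1.0

# Define the pairwise contrastive loss function
def pairwise_contrastive_loss(y_true, y_pred):
    # y_true: 1 for similar pairs, 0 for dissimilar pairs
    # y_pred: predicted distance between the two embeddings of each pair
    y_true = K.cast(y_true, y_pred.dtype)
    square_pred = K.square(y_pred)  # pulls similar pairs together
    margin_square = K.square(K.maximum(margin - y_pred, 0.0))  # pushes dissimilar pairs past the margin
    return K.mean(y_true * square_pred + (1 - y_true) * margin_square)
```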
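As a sanity check of the formula, here is a minimal framework-free sketch in plain Python. The helper `pairwise_contrastive_loss_ref` is a hypothetical name introduced for illustration; it assumes the distances $d_i$ are already computed and uses the same margin $m = 1.0$ as the Keras function.

```python
def pairwise_contrastive_loss_ref(labels, distances, margin=1.0):
    """Hypothetical plain-Python reference for the pairwise contrastive loss.

    labels:    1 for a similar pair, 0 for a dissimilar pair
    distances: embedding distance d_i for each pair
    """
    total = 0.0
    for y, d in zip(labels, distances):
        similar_term = y * d ** 2                               # pull similar pairs together
        dissimilar_term = (1 - y) * max(0.0, margin - d) ** 2   # push apart until d >= margin
        total += similar_term + dissimilar_term
    return total / len(labels)


# One similar pair, one dissimilar pair inside the margin, one dissimilar pair beyond it.
labels = [1, 0, 0]
distances = [0.5, 0.4, 1.5]
print(pairwise_contrastive_loss_ref(labels, distances))  # ≈ 0.2033, i.e. (0.25 + 0.36 + 0.0) / 3
```

Only the dissimilar pair that sits inside the margin (distance 0.4 < 1.0) is penalized on the dissimilar side; the pair at distance 1.5 already clears the margin and contributes nothing.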


## 2. Triplet Loss:

Triplet loss extends pairwise contrastive loss by considering triplets of samples: an anchor, a positive sample (similar to the anchor), and a negative sample (dissimilar to the anchor). It encourages the positive sample to be closer to the anchor than the negative sample by a margin.

$$ L_{\text{triplet}} = \frac{1}{N} \sum_{i=1}^{N} \max\left(d(a_i, p_i) - d(a_i, n_i) + \alpha,\ 0\right) $$

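- $a$: anchor embedding
- $p$: positive embedding
- $n$: negative embedding
- $d(\cdot, \cdot)$: distance function (squared Euclidean distance in the code below)
- $\alpha$: margin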

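```python
from tensorflow.keras import backend as K
from tensorflow.keras.losses import Loss

class TripletLoss(Loss):
    """Triplet loss with squared Euclidean distance.

    Expects y_pred of shape (batch, 3, embedding_dim), stacking the anchor,
    positive and negative embeddings; y_true is unused (pass dummy labels).
    """

    def __init__(self, margin=0.2, **kwargs):
        super().__init__(**kwargs)
        self.margin = margin

    def call(self, y_true, y_pred):
        anchor, positive, negative = y_pred[:, 0], y_pred[:, 1], y_pred[:, 2]
        positive_distance = K.sum(K.square(anchor - positive), axis=-1)
        negative_distance = K.sum(K.square(anchor - negative), axis=-1)
        # Per-sample loss; Keras applies the batch reduction (a mean by default).
        return K.maximum(positive_distance - negative_distance + self.margin, 0.0)
```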
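To make the hinge concrete, here is a plain-Python sketch that computes the same quantity as `TripletLoss` on hand-written 2-D embeddings. `squared_euclidean` and `triplet_loss_ref` are hypothetical helper names; like the Keras class, the sketch assumes squared Euclidean distance and a margin of 0.2.

```python
def squared_euclidean(u, v):
    """Squared Euclidean distance, matching K.sum(K.square(...)) in TripletLoss."""
    return sum((a - b) ** 2 for a, b in zip(u, v))


def triplet_loss_ref(anchors, positives, negatives, margin=0.2):
    """Hypothetical reference: mean of max(d(a,p) - d(a,n) + margin, 0) over triplets."""
    losses = []
    for a, p, n in zip(anchors, positives, negatives):
        d_ap = squared_euclidean(a, p)
        d_an = squared_euclidean(a, n)
        losses.append(max(d_ap - d_an + margin, 0.0))
    return sum(losses) / len(losses)


# Triplet 1 is already "easy" (negative far away), so it contributes zero.
# Triplet 2 violates the margin: 0.25 - 0.36 + 0.2 = 0.09.
anchors   = [[0.0, 0.0], [0.0, 0.0]]
positives = [[0.1, 0.0], [0.5, 0.0]]
negatives = [[1.0, 0.0], [0.0, 0.6]]
print(triplet_loss_ref(anchors, positives, negatives))  # ≈ 0.045
```

Easy triplets produce zero loss and therefore zero gradient, which is why the choice of negatives (the topic of the next section) matters in practice.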


## 3. Online Contrastive Loss:

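Online contrastive loss applies hard mining within each mini-batch: for every anchor it keeps only the hardest positive (the farthest one) and the hardest negative (the closest one). Concentrating on these informative pairs can speed up convergence, since easy pairs that already satisfy the margin contribute no gradient.

$$ L_{\text{online}} = \frac{1}{N} \sum_{i=1}^{N} \max\left(0,\ \max_{p \in \mathcal{P}_i} d_{ip}^2 - \min_{n \in \mathcal{N}_i} d_{in}^2 + m\right) $$

where $\mathcal{P}_i$ and $\mathcal{N}_i$ are the positives and negatives of anchor $i$ in the batch, $d_{ij}$ is the distance from anchor $i$ to sample $j$, and $m$ is the margin.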


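```python
from tensorflow.keras import backend as K

# Define online contrastive loss function with hard mining
def online_contrastive_loss(y_true, y_pred, margin=1.0):
    # y_true: (batch, k) mask, 1 where candidate j is a positive for anchor i, else 0
    # y_pred: (batch, k) distances from each anchor to each candidate
    # Assumes every row contains at least one positive and one negative.
    y_true = K.cast(y_true, y_pred.dtype)
    square_pred = K.square(y_pred)
    # Hardest positive: largest squared distance among the positives
    pos_pred = K.max(square_pred * y_true, axis=1)
    # Hardest negative: smallest squared distance among the negatives
    # (positives are shifted by the row maximum so they never fall below a negative)
    row_max = K.max(square_pred, axis=1, keepdims=True)
    neg_pred = K.min(square_pred + y_true * row_max, axis=1)
    loss = K.maximum(0.0, pos_pred - neg_pred + margin)
    return K.mean(loss)
```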
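The mining step is easier to follow without tensors. This plain-Python sketch (hypothetical helper `batch_hard_loss_ref`) assumes the same layout as the Keras function: a 0/1 mask marking each anchor's positives and a matrix of anchor-to-candidate distances, with at least one positive and one negative per row.

```python
def batch_hard_loss_ref(mask, distances, margin=1.0):
    """Hypothetical plain-Python reference for hard-mined (online) contrastive loss.

    mask[i][j]      = 1 if candidate j is a positive for anchor i, else 0
    distances[i][j] = distance from anchor i to candidate j
    """
    losses = []
    for row_mask, row_dist in zip(mask, distances):
        sq = [d ** 2 for d in row_dist]
        # Hardest positive: the positive FARTHEST from the anchor.
        hardest_pos = max(s for s, m in zip(sq, row_mask) if m == 1)
        # Hardest negative: the negative CLOSEST to the anchor.
        hardest_neg = min(s for s, m in zip(sq, row_mask) if m == 0)
        losses.append(max(0.0, hardest_pos - hardest_neg + margin))
    return sum(losses) / len(losses)


mask      = [[1, 0, 0],
             [0, 1, 0]]
distances = [[0.5, 2.0, 0.8],
             [1.5, 0.2, 1.2]]
print(batch_hard_loss_ref(mask, distances))  # ≈ 0.305, i.e. (0.61 + 0.0) / 2
```

In the first row, mining selects the negative at distance 0.8 rather than the one at 2.0, and that close negative is what produces the loss; the second row already satisfies the margin.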


## 4. Margin-based Contrastive Loss:

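Margin-based contrastive loss introduces a margin hyperparameter $\alpha$ that sets how much closer the positive sample must be to the anchor than the negative sample: the loss reaches zero only once $d_{\text{neg}} \ge d_{\text{pos}} + \alpha$. It has the same form as the triplet loss, but here the model outputs the two distances directly.

$$ L_{\text{margin-based}} = \frac{1}{N} \sum_{i=1}^{N} \max\left(0,\ d_{\text{pos},i} - d_{\text{neg},i} + \alpha\right) $$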

```python
from tensorflow.keras import backend as K

# Define margin-based contrastive loss function
def margin_based_contrastive_loss(y_true, y_pred, margin=0.2):
    # y_pred: (batch, 2) distances; column 0 = anchor-positive, column 1 = anchor-negative
    # y_true is unused but required by the Keras loss signature
    square_pred = K.square(y_pred)
    positive_distance = square_pred[:, 0]
    negative_distance = square_pred[:, 1]
    loss = K.maximum(0.0, positive_distance - negative_distance + margin)
    return K.mean(loss)
```
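The role of $\alpha$ can be seen by holding the two distances fixed and varying the margin. This is a plain-Python sketch with a hypothetical helper `margin_based_loss_ref`; it takes $d_{\text{pos}}$ and $d_{\text{neg}}$ as given, whereas the Keras function first squares the raw distances in `y_pred`.

```python
def margin_based_loss_ref(pairs, alpha=0.2):
    """Hypothetical reference: mean over rows of max(0, d_pos - d_neg + alpha).

    pairs: list of (d_pos, d_neg) tuples, mirroring columns 0 and 1 of y_pred.
    """
    return sum(max(0.0, d_pos - d_neg + alpha) for d_pos, d_neg in pairs) / len(pairs)


# The negative is only 0.1 farther from the anchor than the positive.
pairs = [(0.3, 0.4)]
for alpha in (0.0, 0.05, 0.2):
    # Zero loss while d_neg >= d_pos + alpha; positive once alpha exceeds the 0.1 gap.
    print(alpha, margin_based_loss_ref(pairs, alpha))
```

With $\alpha = 0.2$ the pair is penalized (loss ≈ 0.1) even though the negative is already farther than the positive, which is exactly what the margin is for.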

## 5. Angular Contrastive Loss:

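Angular contrastive loss compares embeddings by the angle between them rather than by their Euclidean distance: it encourages similar samples to have a small angle (cosine similarity close to 1) and dissimilar samples to have a large angle. The form below covers only the positive pair $(a, p)$; pushing dissimilar samples apart would require an additional term for negatives.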

$$ L_{\text{angular}} = \frac{1}{N} \sum_{i=1}^{N} \left(1 - \cos(a_i, p_i)\right), \qquad \cos(a, p) = \frac{a \cdot p}{\lVert a \rVert \, \lVert p \rVert} $$

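```python
from tensorflow.keras.losses import CosineSimilarity

# Keras' CosineSimilarity *loss* returns the NEGATIVE mean cosine similarity
# (values near -1 mean "very similar"), so its sign is flipped below.
cosine_loss = CosineSimilarity(axis=-1)

# Define angular contrastive loss function
def angular_contrastive_loss(y_true, y_pred):
    # y_pred: (batch, 2, embedding_dim) stacking the anchor and positive embeddings
    similarity = -cosine_loss(y_pred[:, 0], y_pred[:, 1])
    return 1.0 - similarity
```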
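The plain-Python sketch below (hypothetical helpers `cosine_similarity` and `angular_loss_ref`) evaluates $1 - \cos(a, p)$ for three hand-picked pairs. It shows that the loss depends only on direction, not on vector length, and ranges from 0 (same direction) to 2 (opposite directions).

```python
import math


def cosine_similarity(u, v):
    """cos(u, v) = (u . v) / (|u| |v|)."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def angular_loss_ref(anchors, positives):
    """Hypothetical reference: mean of 1 - cos(a_i, p_i) over the batch."""
    return sum(1.0 - cosine_similarity(a, p) for a, p in zip(anchors, positives)) / len(anchors)


print(angular_loss_ref([[1.0, 0.0]], [[3.0, 0.0]]))   # same direction -> 0.0
print(angular_loss_ref([[1.0, 0.0]], [[0.0, 2.0]]))   # orthogonal     -> 1.0
print(angular_loss_ref([[1.0, 0.0]], [[-1.0, 0.0]]))  # opposite       -> 2.0
```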

## 6. Multi-modal Contrastive Loss:

Multi-modal contrastive loss learns a joint embedding space across modalities (for example, images and text) in which matching items from different modalities lie close together and non-matching items lie far apart. The function below has exactly the same form as the margin-based contrastive loss; what changes is the input: the anchor embedding comes from one modality and the positive and negative embeddings from the other.

```python
from tensorflow.keras import backend as K

# Define multi-modal contrastive loss function
def multi_modal_contrastive_loss(y_true, y_pred, margin=0.2):
    # y_pred: (batch, 2) cross-modal distances, e.g. column 0 = image to matching caption,
    # column 1 = image to non-matching caption; y_true is unused
    square_pred = K.square(y_pred)
    positive_distance = square_pred[:, 0]
    negative_distance = square_pred[:, 1]
    loss = K.maximum(0.0, positive_distance - negative_distance + margin)
    return K.mean(loss)
```
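In practice the negatives often come from the other items in the same batch. The sketch below shows one common setup, assumed here rather than taken from the text: image $i$ in a batch is paired with caption $i$, and the closest non-matching caption in the batch serves as the negative. All names (`cross_modal_margin_loss_ref`, the toy embeddings) are hypothetical, and the embeddings stand in for the outputs of two separate encoders.

```python
def squared_euclidean(u, v):
    """Squared Euclidean distance between two embedding vectors."""
    return sum((a - b) ** 2 for a, b in zip(u, v))


def cross_modal_margin_loss_ref(image_embs, text_embs, margin=0.2):
    """Hypothetical plain-Python sketch of a cross-modal margin loss.

    image_embs[i] and text_embs[i] are assumed to describe the same item
    (a matching image/caption pair); every other caption in the batch is
    treated as a negative for image i.
    """
    losses = []
    for i, img in enumerate(image_embs):
        d_pos = squared_euclidean(img, text_embs[i])
        # Closest non-matching caption = hardest negative for this image.
        d_neg = min(squared_euclidean(img, txt)
                    for j, txt in enumerate(text_embs) if j != i)
        losses.append(max(0.0, d_pos - d_neg + margin))
    return sum(losses) / len(losses)


# Toy 2-D embeddings from two hypothetical encoders (image and text).
image_embs = [[1.0, 0.0], [0.0, 1.0]]
text_embs  = [[0.9, 0.1], [0.8, 0.3]]
print(cross_modal_margin_loss_ref(image_embs, text_embs))  # ≈ 0.045
```

Caption 1 lies close to image 0, so image 0 incurs a loss of about 0.09 (0.02 - 0.13 + 0.2); image 1 is far enough from caption 0 that it contributes nothing.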