In [2]:
!pip install tensorflow==2.0.0



In [0]:
import numpy as np
import tensorflow as tf

# Toy batch: random embeddings with fixed seeds so the outputs below are reproducible.
batch = 64
emb_dim = 1024
margin = 0.3
assert batch % 2 == 0, "labels below pair row i with row i + batch // 2"

np.random.seed(1234)
emb1 = np.random.rand(batch, emb_dim).astype(np.float32)
np.random.seed(2345)
emb2 = np.random.rand(batch, emb_dim).astype(np.float32)

# Integer class ids 0 .. batch//2 - 1, repeated twice: every class has exactly
# two members (rows i and i + batch//2). `batch / 2` would produce float ids.
labels = np.tile(np.arange(batch // 2), 2)

In [19]:
# Expect (batch,): one label per row of emb1 / emb2
np.shape(labels)

(64,)

In [20]:
# Should match the version pinned by the install cell
tf.version.VERSION

'2.0.0'

In [0]:
def _distance_metric(x, y):
    """
    Pairwise squared Euclidean distance between the rows of x and the rows of y.

    Uses the expansion |x-y|^2 = |x|^2 - 2<x,y> + |y|^2 so the whole matrix
    needs only one matrix product.

    Args:
        x: tensor, with shape [m, d], (batch_size, d)
        y: tensor, with shape [n, d], (batch_size, d)
    Returns:
        dist: tensor, with shape [m, n], (batch_size, batch_size);
              entry [i, j] is |x_i - y_j|^2, clamped to be >= 0.
    """
    # <x_i, y_j> for every pair: [m, n]
    xy = tf.matmul(x, y, transpose_b=True)
    # |x_i|^2 and |y_j|^2 as row-wise sums of squares: [m] and [n].
    # (Taking the diagonal of x @ x^T gives the same values at O(m^2 d) cost.)
    xx = tf.reduce_sum(tf.square(x), axis=1)
    yy = tf.reduce_sum(tf.square(y), axis=1)
    # [m, 1] - [m, n] + [1, n] broadcasts to [m, n]:
    # xx varies down the rows, yy varies across the columns.
    distances = tf.expand_dims(xx, 1) - 2.0 * xy + tf.expand_dims(yy, 0)
    # Floating-point cancellation can leave tiny negative values for
    # near-identical rows; a squared distance is never negative.
    return tf.maximum(distances, 0.0)

In [0]:
def _label_mask(labels):
    '''
    Pairwise "same label" mask: entry [i, j] is True when labels[i] == labels[j].
    ------------------------------------
    Args:
        labels:     Label Data, shape = (batch_size,) or (batch_size, 1);
                    flattened to 1-D before comparing.
    Returns:
        label_mask: bool tensor, with shape [m, n], (batch_size, batch_size)
        ex.
            labels = [1,0,1]
            label_mask = [[1, 0, 1],
                          [0, 1, 0],
                          [1, 0, 1]]
    '''
    # A (batch_size, 1) input would otherwise broadcast to a 3-D
    # (batch_size, batch_size, 1) mask and break the [m, n] contract.
    labels = tf.reshape(labels, [-1])
    # [1, n] == [n, 1] broadcasts to [n, n]
    return tf.equal(tf.expand_dims(labels, 0), tf.expand_dims(labels, 1))

In [0]:
def batch_all(labels, emb1, emb2, margin):
    '''
    Triplet loss over every (anchor, negative) pair of a batch.

    Anchor i is row i of emb1 and candidate j is row j of emb2; labels[k] is the
    class of row k in both. The positive term of anchor i is its MEAN distance to
    all same-label rows of emb2, and every different-label row j is a negative.
    The loss of a pair is max(mean_pos_i - d(i, j) + margin, 0), where d is
    _distance_metric(emb1, emb2).
    ------------------------------------
    Args:
        labels:     Label Data, shape = (batch_size,)
        emb1, emb2: Embedding Feature, shape = (batch_size, vector_size)
        margin:     margin, scalar
    Returns:
        triplet_loss:       scalar, mean loss over the pairs with positive loss
                            (0 when there are none)
        num_valid_triplets: scalar float, number of (anchor, negative) pairs with
                            positive loss (0 when there are none)
    '''
    dist_mat = _distance_metric(emb1, emb2)
    same_label = _label_mask(labels)
    ap_mask = tf.cast(same_label, tf.float32)
    an_mask = tf.cast(tf.logical_not(same_label), tf.float32)
    # Mean anchor-positive distance per anchor. labels[i] == labels[i], so each
    # row has at least one positive and the denominator is never zero.
    dist_ap = tf.reduce_sum(dist_mat * ap_mask, axis=1) / tf.reduce_sum(ap_mask, axis=1)
    # [batch, 1] - [batch, batch]: loss of every (anchor, candidate) pair,
    # zeroed for same-label candidates so only negatives remain.
    mat = (tf.expand_dims(dist_ap, 1) - dist_mat + margin) * an_mask
    # Pairs that actually violate the margin
    mask = tf.cast(tf.greater(mat, 0.0), tf.float32)
    num_valid_triplets = tf.reduce_sum(mask)
    # Average over violating pairs only. Clamp the divisor (not the reported
    # count) so a batch with no violations yields loss 0 instead of 0/0.
    triplet_loss = tf.reduce_sum(mat * mask) / tf.maximum(num_valid_triplets, 1.0)
    return triplet_loss, num_valid_triplets

In [0]:
# Batch-all loss on the random toy batch defined above
triplet_loss, num_valid_triplets = batch_all(
    labels=labels, emb1=emb1, emb2=emb2, margin=margin
)

In [28]:
# Loss returned by the batch_all call above
triplet_loss

<tf.Tensor: id=237, shape=(), dtype=float32, numpy=5.733356>

In [29]:
# Number of (anchor, negative) pairs with positive loss, as counted by batch_all
num_valid_triplets

<tf.Tensor: id=232, shape=(), dtype=float32, numpy=1930.0>

In [0]:
def batch_hard(labels, emb1, emb2, margin):
    '''
    Batch-hard triplet loss: each anchor is paired with its farthest positive
    and its closest negative.

    Anchor i is row i of emb1 and candidate j is row j of emb2; labels[k] is the
    class of row k in both. Distances come from _distance_metric(emb1, emb2).
    ------------------------------------
    Args:
        labels:     Label Data, shape = (batch_size,)
        emb1, emb2: Embedding Feature, shape = (batch_size, vector_size)
        margin:     margin, scalar
    Returns:
        triplet_loss:       scalar, mean over anchors of
                            max(hardest_pos_i - hardest_neg_i + margin, 0)
        num_valid_triplets: scalar float, number of (anchor, negative) pairs with
                            mean_pos_i - d(i, j) + margin > 0. This is the same
                            criterion batch_all uses; it does NOT count hard triplets.
    '''
    dist_mat = _distance_metric(emb1, emb2)
    same_label = _label_mask(labels)
    ap_mask = tf.cast(same_label, tf.float32)
    an_mask = tf.cast(tf.logical_not(same_label), tf.float32)

    # --- diagnostic count (mean-positive criterion, as in batch_all) ---
    dist_ap = tf.reduce_sum(dist_mat * ap_mask, axis=1) / tf.reduce_sum(ap_mask, axis=1)
    pair_loss = (tf.expand_dims(dist_ap, 1) - dist_mat + margin) * an_mask
    num_valid_triplets = tf.reduce_sum(tf.cast(tf.greater(pair_loss, 0.0), tf.float32))

    # --- batch-hard loss ---
    # Hardest positive: largest same-label distance in each anchor's row.
    # Masked-out entries become 0, which never wins as long as distances are >= 0.
    hardest_dist_ap = tf.reduce_max(dist_mat * ap_mask, axis=1)
    # Hardest negative: SMALLEST different-label distance in each row. Positives
    # are raised to the anchor's own row maximum so they cannot win the min.
    # keepdims=True keeps that maximum as [batch, 1] so it broadcasts along row i;
    # a plain [batch] vector would broadcast across columns and pad row i with
    # other anchors' maxima.
    row_max = tf.reduce_max(dist_mat, axis=1, keepdims=True)
    hardest_dist_an = tf.reduce_min(dist_mat * an_mask + ap_mask * row_max, axis=1)
    # Hinge per anchor, then average over the batch
    triplet_loss = tf.reduce_mean(tf.maximum(hardest_dist_ap - hardest_dist_an + margin, 0.0))
    return triplet_loss, num_valid_triplets

In [0]:
# Batch-hard loss on the same toy batch (overwrites the batch_all results)
triplet_loss, num_valid_triplets = batch_hard(
    labels=labels, emb1=emb1, emb2=emb2, margin=margin
)

In [32]:
# Loss returned by the batch_hard call above
triplet_loss

<tf.Tensor: id=306, shape=(), dtype=float32, numpy=16.31944>

In [33]:
# batch_hard computes this count with batch_all's mean-positive criterion, not
# from the hard triplets; that is why it matches batch_all's count
num_valid_triplets

<tf.Tensor: id=289, shape=(), dtype=float32, numpy=1930.0>