In [16]:
_NEG_INF_FP32 = -1e9
_INF_FP32 = 1e9

import tensorflow as tf

# Allocate GPU memory on demand instead of reserving it all at start-up.
# TensorFlow requires the same memory-growth setting on every GPU and only accepts it
# before the GPUs are initialized, so configure all of them and tolerate re-runs.
# On a CPU-only machine the list is empty and the loop is simply skipped.
physical_devices = tf.config.list_physical_devices('GPU')
try:
    for gpu in physical_devices:
        tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError as err:
    # Raised when the GPUs were already initialized (e.g. this cell was re-run).
    print(err)

def outfit_metric_loss(y_pred, y_true, categories, mask_positions, acc=None, debug=False, categorywise_only=False):
    """Softmax contrastive loss over items, counted only at the masked positions.

    Every flattened prediction is scored (dot product) against every flattened
    ground-truth item in the batch. The positive for prediction k is ground-truth
    item k (identity labels); all other items act as negatives. Only the rows at
    ``mask_positions`` contribute to the loss.

    Args:
        y_pred: predicted embeddings; the last axis is the feature dimension.
            Presumably (batch, seq_len, feature_dim) -- TODO confirm with caller.
        y_true: ground-truth embeddings, same layout as ``y_pred``.
        categories: per-item category ids. Its shape defines the grid on which the
            loss weights are scattered (presumably (batch, seq_len)); its values are
            only used when ``categorywise_only`` is True.
        mask_positions: integer tensor with one masked sequence position per
            example; shapes (batch,), (batch, 1) or (batch, 1, 1) are accepted.
        acc: optional metric, called as ``acc(labels, logits, sample_weight=weights)``.
        debug: if True, log intermediate tensors through ``tf.get_logger()``.
        categorywise_only: if True, add ``_NEG_INF_FP32`` to the logits of item
            pairs whose categories differ, so only same-category items compete.

    Returns:
        Scalar tensor: the sum of cross-entropies at the masked positions divided by
        the number of masked positions (0 instead of NaN when there are none).
    """
    logger = tf.get_logger()

    feature_dim = y_pred.shape[-1]

    # Per-item loss weights: 1.0 at (example i, mask_positions[i]), 0.0 elsewhere,
    # scattered on the grid tf.shape(categories) and flattened to line up with the
    # flattened items below.
    batch_idx = tf.range(tf.shape(mask_positions)[0])
    masked_pos = tf.cast(tf.reshape(mask_positions, [-1]), batch_idx.dtype)
    indices = tf.stack([batch_idx, masked_pos], axis=-1)  # (batch, 2)
    updates = tf.ones_like(batch_idx, dtype=tf.float32)
    weights = tf.reshape(tf.scatter_nd(indices, updates, tf.shape(categories)), [-1])
    weights_sum = tf.reduce_sum(weights)

    # Flatten to (items, feature_dim).
    pred_batch = tf.reshape(y_pred, [-1, feature_dim])
    true_batch = tf.reshape(y_true, [-1, feature_dim])
    item_count = tf.shape(true_batch)[0]

    # logits[k, j] = <prediction k, ground-truth item j>
    logits = tf.matmul(pred_batch, true_batch, transpose_b=True)

    if categorywise_only:
        # Large negative offset on cells where the categories don't match.
        flat_categories = tf.reshape(categories, [-1])
        cat_mismatch = tf.not_equal(flat_categories[:, tf.newaxis], flat_categories[tf.newaxis, :])
        cat_mask = tf.cast(cat_mismatch, dtype="float32") * _NEG_INF_FP32
        logits = logits + cat_mask
        if debug:
            logger.debug("Category mask")
            logger.debug(cat_mask)

    if debug:
        logger.debug("Loss weights")
        logger.debug(weights)
        logger.debug("Loss weights sum")
        logger.debug(weights_sum)
        logger.debug("Item Count")
        logger.debug(item_count)
        logger.debug("Logits")
        logger.debug(logits)

    # One-hot labels: item k is the positive for prediction k (the identity matrix).
    labels = tf.eye(item_count, item_count)

    if acc is not None:
        acc(labels, logits, sample_weight=weights)

    cross_entropy = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)
    if debug:
        logger.debug(cross_entropy)
    # Weighted sum over items: only the masked items' cross-entropy survives (scalar).
    cross_entropy = tf.tensordot(cross_entropy, weights, 1)
    if debug:
        logger.debug("Cross Entropy")
        logger.debug(cross_entropy)
    # Mean over masked items; divide_no_nan returns 0 rather than NaN for an empty batch.
    return tf.math.divide_no_nan(cross_entropy, weights_sum)

In [10]:
# pred = tf.random.uniform((2,10))
# targets = pred  # (pred_count,targets_count,hidden_dim)
# print(pred)
# print(targets)

def distances(a, b):
    """Pairwise Euclidean distances between the rows of ``a`` and ``b``.

    Args:
        a: tensor of shape (n, d).
        b: tensor of shape (m, d).

    Returns:
        Tensor of shape (n, m) whose entry [i, j] is ||a[i] - b[j]||_2.
    """
    # Broadcasting (n, 1, d) - (1, m, d) -> (n, m, d) pairs every row of `a` with every
    # row of `b` using only the arguments (no explicit tiling, no notebook globals).
    diff = tf.expand_dims(a, 1) - tf.expand_dims(b, 0)
    return tf.math.reduce_euclidean_norm(diff, axis=-1)

In [21]:
tf.config.experimental_run_functions_eagerly(True)
tf.random.set_seed(0)  # fixed seed so Restart & Run All reproduces the same toy batch

MARGIN = 0.3  # hinge margin between the positive and the hardest-negative distance

# Toy batch: 5 outfits x 3 items x 10-dim embeddings, one masked item per outfit.
x = tf.random.uniform((5, 3, 10))
y = tf.random.uniform((5, 3, 10))
categories = tf.random.uniform((5, 3), dtype=tf.int64, maxval=10)  # input for outfit_metric_loss
mask_positions = tf.constant([0, 2, 2, 1, 0], dtype=tf.int32, shape=(5, 1, 1))
print(x)
print(mask_positions)

# indices[i] = (i, mask_position_i): location of the masked item of outfit i.
batch_idx = tf.range(tf.shape(mask_positions)[0])
indices = tf.stack([batch_idx, tf.reshape(mask_positions, [-1])], axis=-1)  # (batch, 2)

pred = tf.gather_nd(x, indices)  # (batch, feature_dim)
print(pred)
targets = tf.reshape(y, [-1, y.shape[-1]])  # (batch * seq_length, feature_dim)
dist = distances(pred, targets)  # (batch, batch * seq_length)
print(dist)

dist = tf.reshape(dist, (dist.shape[0], y.shape[0], -1))  # (batch_size, batch_size, seq_length)
print(dist)

# Positive: distance from pred i to its own target y[i, mask_position_i].
true_dist = tf.gather_nd(dist, indices, batch_dims=1)
print(true_dist)

# Negatives: hide each positive behind float32.max, then keep the closest remaining item.
positive_indices = tf.concat([batch_idx[:, tf.newaxis], indices], axis=-1)  # rows (i, i, m_i)
print(positive_indices)
dist = tf.tensor_scatter_nd_update(
    dist, positive_indices, tf.fill(tf.shape(batch_idx), tf.float32.max)
)
print(dist)
hardest_negative_dist = tf.math.reduce_min(dist, axis=[-2, -1])
print(hardest_negative_dist)

# Triplet-style hinge: the positive should be closer than the hardest negative by MARGIN.
loss = tf.reduce_mean(tf.maximum(true_dist - hardest_negative_dist + MARGIN, 0.0))
print(loss)
# outfit_metric_loss(x, y, categories, mask_positions)

tf.Tensor(
[[[0.9993987  0.46449792 0.92946947 0.6358464  0.40524757 0.38769674
   0.74511075 0.3057344  0.90844727 0.45114195]
  [0.01258767 0.67991424 0.7020533  0.760991   0.38382316 0.7346399
   0.05640066 0.55158293 0.02948117 0.52735734]
  [0.21707976 0.2439655  0.5829599  0.41761196 0.9259398  0.61132014
   0.7485541  0.6093919  0.34434843 0.14940023]]

 [[0.73144996 0.39506054 0.16059327 0.7643317  0.7535453  0.23402584
   0.5283179  0.87390506 0.4676646  0.4282725 ]
  [0.81178796 0.87135506 0.1088239  0.21052313 0.8355006  0.17974329
   0.6357639  0.29977334 0.88418937 0.835163  ]
  [0.8404137  0.67142844 0.2886945  0.31312466 0.55817723 0.98303986
   0.6188991  0.21973157 0.205114   0.5254713 ]]

 [[0.22903907 0.05853724 0.552331   0.8723451  0.31157577 0.619758
   0.28508592 0.14667821 0.8354666  0.82019794]
  [0.07968247 0.46111286 0.0777427  0.5082909  0.44312954 0.13500571
   0.94962823 0.34331    0.07990849 0.4066987 ]
  [0.39795744 0.6146102  0.28037405 0.43636453 0.544