In [2]:
# Large finite stand-ins for -inf / +inf in float32 computations.
# _NEG_INF_FP32 is added to logits to mask out cells whose categories differ
# (see outfit_metric_loss); _INF_FP32 fills padding targets so that they lie far
# from every prediction (see outfit_distance_loss).
_NEG_INF_FP32 = -1e9
_INF_FP32 = 1e9

import tensorflow as tf

# Enable on-demand GPU memory growth. The loop covers every visible GPU and is
# a no-op on CPU-only machines, where the device list is empty (indexing [0]
# there would raise IndexError).
physical_devices = tf.config.list_physical_devices('GPU')
for _gpu in physical_devices:
    tf.config.experimental.set_memory_growth(_gpu, True)

def outfit_metric_loss(y_pred, y_true, categories, mask_positions, acc=None, debug=False, categorywise_only=False):
    """Softmax cross-entropy of masked-token predictions against all targets in the batch.

    Every prediction vector is scored (dot product) against every target vector
    of the flattened batch; the correct "class" of a prediction is the target at
    the same flattened index (identity label matrix). Only the positions listed
    in ``mask_positions`` contribute to the loss.

    Args:
        y_pred: predictions, rank-3 tensor (batch, seq_length, feature_dim).
        y_true: targets, reshaped to the same (batch * seq_length, feature_dim) layout.
        categories: rank-2 tensor (batch, seq_length) of category ids; its shape
            defines the shape of the loss-weight grid.
        mask_positions: integer tensor (batch, 1, 1) holding, per batch row, the
            sequence position of the masked token.
        acc: optional callable metric, invoked as
            ``acc(labels, logits, sample_weight=weights)``.
        debug: when True, intermediate tensors are written to the TF logger at DEBUG level.
        categorywise_only: when True, logits between items of different
            categories are pushed towards -inf so that only same-category
            targets compete in the softmax.

    Returns:
        Scalar tensor: weighted cross-entropy summed over the masked positions
        and divided by the number of masked positions.
    """
    logger = tf.get_logger()

    feature_dim = y_pred.shape[2]

    # Compute loss only from mask token: build a 0/1 weight per (batch, position)
    # cell that is 1 exactly at [row, mask_positions[row]].
    r = tf.range(0, limit=tf.shape(mask_positions)[0])
    r = tf.reshape(r, shape=[tf.shape(r)[0], -1, 1])
    indices = tf.squeeze(tf.concat([r, mask_positions], axis=-1), axis=[1])
    updates = tf.ones(shape=(tf.shape(mask_positions)[0]))
    weights = tf.scatter_nd(indices, updates, tf.shape(categories))
    weights = tf.cast(weights, dtype="float32")
    weights = tf.reshape(weights, [-1])
    weights_sum = tf.reduce_sum(weights)

    # Reshape to batch (size * seq length, feature dim)
    pred_batch = tf.reshape(y_pred, [-1, feature_dim])
    true_batch = tf.reshape(y_true, [-1, feature_dim])
    item_count = tf.shape(true_batch)[0]

    # Dot product of every prediction with all labels
    logits = tf.matmul(pred_batch, true_batch, transpose_b=True)

    # Compute logits only within categories
    if categorywise_only:
        flat_categories = tf.reshape(categories, [-1])
        cat_mask = tf.equal(flat_categories[:, tf.newaxis], flat_categories[tf.newaxis, :])
        cat_mask = tf.logical_not(cat_mask)
        cat_mask = tf.cast(cat_mask, dtype="float32")
        cat_mask = cat_mask * _NEG_INF_FP32  # -inf on cells when categories don't match
        logits = tf.add(logits, cat_mask)
        if debug:
            logger.debug("Category mask")
            logger.debug(cat_mask)

    if debug:
        logger.debug("Loss weights")
        logger.debug(weights)
        logger.debug("Item Count")
        logger.debug(item_count)
        logger.debug("Logits")
        logger.debug(logits)

    # One-hot labels (the identity matrix)
    labels = tf.eye(item_count, item_count)

    if acc is not None:
        acc(labels, logits, sample_weight=weights)

    cross_entropy = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)
    if debug:
        logger.debug(cross_entropy)
    # Weighted sum over all items: only masked positions have non-zero weight.
    cross_entropy = tf.tensordot(cross_entropy, weights, 1)
    if debug:
        logger.debug("Cross Entropy")
        logger.debug(cross_entropy)
    return tf.reduce_sum(cross_entropy) / weights_sum

In [26]:
import logging
import src.models.encoder.utils as utils

# pred = tf.random.uniform((2,10))
# targets = pred  # (pred_count,targets_count,hidden_dim)
# print(pred)
# print(targets)
    


def distances(a, b):
    """Pairwise Euclidean distances between the rows of ``a`` and the rows of ``b``.

    Args:
        a: rank-2 tensor of shape (n, dim).
        b: rank-2 tensor of shape (m, dim).

    Returns:
        Tensor of shape (n, m) whose entry [i, j] is the L2 norm of a[i] - b[j].
    """
    # (n, 1, dim) - (1, m, dim) broadcasts to (n, m, dim); no explicit tiling is
    # needed, which avoids materialising two extra (n, m, dim) tensors and does
    # not depend on statically known shapes.
    diff = tf.expand_dims(a, 1) - tf.expand_dims(b, 0)
    return tf.math.reduce_euclidean_norm(diff, -1)


def get_distances_to_targets(y_pred, y_true, mask_positions, categories, debug=False):
    """Distances from each masked prediction to its own target and to the nearest other target.

    Args:
        y_pred: predictions, (batch_size, seq_length, feature_dim).
        y_true: targets, (batch_size, seq_length, feature_dim).
        mask_positions: integer tensor (batch_size, 1, 1) with the masked
            sequence position of every batch row.
        categories: not used by this function; kept for call compatibility.
        debug: when True, intermediate tensors are logged at DEBUG level.

    Returns:
        Tuple ``(dist_to_pos, dist_to_neg)``, both of shape (batch_size,):
        the distance to the target at the masked position of the same row, and
        the minimum distance to any other target in the whole batch.
    """
    logger = tf.get_logger()

    batch_ids = tf.range(0, limit=tf.shape(mask_positions)[0])
    batch_ids = tf.reshape(batch_ids, shape=[tf.shape(batch_ids)[0], -1, 1])

    # (batch_size, 2) rows of [batch index, masked position]
    mask_indices = tf.concat([batch_ids, mask_positions], axis=-1)
    mask_indices = tf.squeeze(mask_indices, axis=[1])

    predictions = tf.gather_nd(y_pred, mask_indices)

    # Distance from every masked prediction to every target of the batch,
    # viewed as (batch_size, batch_size, seq_length).
    flat_targets = tf.reshape(y_true, [-1, y_true.shape[-1]])
    dist = distances(predictions, flat_targets)
    dist = tf.reshape(dist, (dist.shape[0], y_true.shape[0], -1))

    # Row b picks dist[b, b, position_b]  -> (batch_size,)
    dist_to_pos = tf.gather_nd(dist, mask_indices, batch_dims=1)

    # Overwrite each prediction's own target with the largest float so that it
    # cannot win the minimum taken over the remaining (negative) targets.
    own_target_indices = tf.concat(
        [tf.squeeze(batch_ids, axis=[1]), mask_indices], axis=-1
    )  # (batch_size, 3)
    max_updates = tf.repeat(tf.constant(tf.float32.max, shape=1), own_target_indices.shape[0])
    negatives_only = tf.tensor_scatter_nd_update(dist, own_target_indices, max_updates)
    dist_to_neg = tf.math.reduce_min(negatives_only, axis=[-2, -1])  # (batch_size,)
    # TODO: Mean aggregation?

    if debug:
        for title, tensor in (
            ("Predictions", predictions),
            ("Distances", dist),
            ("Distances to positives", dist_to_pos),
            ("Distances to negatives", dist_to_neg),
        ):
            logger.debug(title)
            logger.debug(tensor)

    return dist_to_pos, dist_to_neg


def outfit_distance_loss(y_pred, y_true, categories, mask_positions, margin, acc: tf.metrics.Accuracy = None,
                         debug=False):
    """Margin loss on Euclidean distances between masked predictions and targets.

    For every batch row the loss is ``max(0, d_pos - d_neg + margin)``, where
    ``d_pos`` is the distance from the masked prediction to its own target and
    ``d_neg`` the distance to the closest other target in the batch.

    Args:
        y_pred: predictions, (batch_size, seq_length, feature_dim).
        y_true: targets, (batch_size, seq_length, feature_dim).
        categories: (batch_size, seq_length) category ids; entries equal to 0
            are treated as padding.
        mask_positions: integer tensor (batch_size, 1, 1) with the masked
            position of every batch row.
        margin: scalar margin added to the distance difference.
        acc: optional accuracy metric; a prediction counts as correct when its
            own target is at least as close as the nearest other target.
        debug: when True, intermediate tensors are logged at DEBUG level.

    Returns:
        Scalar tensor: mean of the per-row margin losses.
    """
    logger = tf.get_logger()

    # Replace padding targets with "infinitely" distant vectors. A broadcast
    # select is used instead of scatter-by-indices because the number of padding
    # cells is only known at run time, so it cannot be read from a static shape
    # when this function is traced as a graph.
    padding_mask = tf.math.equal(categories, tf.zeros_like(categories))
    far_away = tf.fill(tf.shape(y_true), tf.cast(_INF_FP32, y_true.dtype))
    y_true = tf.where(padding_mask[..., tf.newaxis], far_away, y_true)

    dist_to_pos, dist_to_neg = get_distances_to_targets(y_pred, y_true, mask_positions, categories, debug)

    if acc is not None:
        # The predictions are correct, when minimal distance is to positives
        dist_delta = dist_to_pos - dist_to_neg
        correct = tf.math.less_equal(dist_delta, tf.zeros_like(dist_delta))
        correct = tf.cast(correct, dtype=tf.int32)
        correct = tf.expand_dims(correct, axis=-1)
        if debug:
            logger.debug("Distance delta")
            logger.debug(dist_delta)
            logger.debug("Correct predictions")
            logger.debug(correct)

        # Compare against all-ones labels, so accuracy is the fraction of correct rows.
        acc.update_state(tf.ones_like(correct), correct)

    # The scalar margin broadcasts over the batch; no explicit repeat is required.
    loss = dist_to_pos - dist_to_neg + tf.cast(margin, dist_to_pos.dtype)
    loss = tf.clip_by_value(loss, 0, tf.float32.max)

    if debug:
        logger.debug("Distance losses")
        logger.debug(loss)

    return tf.reduce_mean(loss)


def outfit_distance_fitb(y_pred, y_true, pred_positions, target_position, categories, acc: tf.metrics.Accuracy,
                         debug=False):
    """Update ``acc`` with distance-based hit/miss results for the given predictions.

    A prediction is a hit when the distance to its own target is not larger
    than the distance to the nearest other target. ``target_position`` is not
    used by this function. Nothing is returned; the only effect is the metric
    update (skipped when ``acc`` is None).
    """
    pos_dist, neg_dist = get_distances_to_targets(y_pred, y_true, pred_positions, categories, debug)

    if acc is None:
        return

    # The predictions are correct, when minimal distance is to positives
    gap = pos_dist - neg_dist
    is_hit = tf.math.less_equal(gap, tf.zeros_like(gap))
    hits = tf.cast(is_hit, dtype=tf.int32)[..., tf.newaxis]
    acc.update_state(tf.ones_like(hits), hits)




# --- Smoke test of outfit_distance_loss on random data ---
# Route the debug output of the loss functions to the TF logger at DEBUG level.
logger = tf.get_logger()
logger.setLevel(logging.DEBUG)
# Run tf.functions eagerly — presumably so the debug logs show concrete tensor values.
tf.config.experimental_run_functions_eagerly(True)
# Random predictions / targets: batch of 5, sequence length 3, feature dim 10.
# NOTE(review): no random seed is set, so the values differ between runs.
x = tf.random.uniform((5,3,10))
y = tf.random.uniform((5,3,10))
# Random category ids in [0, 10); id 0 is treated as padding by outfit_distance_loss.
categories = tf.random.uniform((5,3), dtype=tf.int32, maxval=10)
# One masked position per batch row, shaped (batch, 1, 1).
mask_positions = tf.constant([0,2,2,1,0], dtype=tf.int32, shape=(5,1,1))
acc = tf.metrics.Accuracy()
print(categories)
print(x)
print(mask_positions)

# margin=0.3, accuracy tracking on, debug logging on
loss = outfit_distance_loss(x, y, categories, mask_positions, 0.3, acc, True)

print(loss)
print(acc.result())
# NOTE(review): `mask` is not defined in this cell; `mask_positions` is probably what is meant.
# outfit_metric_loss(x,y, categories, mask)


tf.Tensor(
[[0 7 8]
 [8 8 6]
 [3 6 4]
 [5 8 8]
 [1 4 0]], shape=(5, 3), dtype=int32)
tf.Tensor(
[[[0.5967723  0.5554106  0.9064776  0.9156077  0.27187347 0.4802996
   0.99415004 0.25797164 0.1838733  0.46029675]
  [0.8987533  0.77158034 0.85079825 0.22337222 0.24118328 0.61456895
   0.49452412 0.11292434 0.6120714  0.06572187]
  [0.6880684  0.50898004 0.21556902 0.622622   0.59759843 0.72719824
   0.4446032  0.64208543 0.5643567  0.9708276 ]]

 [[0.22983992 0.49004972 0.899488   0.3994311  0.9147587  0.35180414
   0.96081185 0.49318624 0.01536858 0.4317062 ]
  [0.45768368 0.49648523 0.63508797 0.25599325 0.8320546  0.16188025
   0.97645664 0.58394694 0.18825042 0.21743977]
  [0.19358087 0.21551955 0.9478748  0.3367859  0.06275141 0.00188279
   0.26903653 0.5136992  0.801062   0.68568754]]

 [[0.39514256 0.6507938  0.4341234  0.61445653 0.86800027 0.31206632
   0.9667369  0.7841456  0.2793628  0.93566287]
  [0.616403   0.745919   0.7116568  0.00232315 0.01109266 0.27215922
   0.74591446