In [3]:
# Large finite sentinels used in place of +/- infinity for float32 tensors.
_INF_FP32 = 1e9
_NEG_INF_FP32 = -_INF_FP32

import tensorflow as tf

# Let TensorFlow allocate GPU memory on demand instead of reserving it all up front.
# Iterate instead of indexing [0]: on a CPU-only machine the list is empty (indexing
# raised IndexError), and on multi-GPU hosts every GPU gets the same setting.
physical_devices = tf.config.list_physical_devices('GPU')
for _gpu in physical_devices:
    tf.config.experimental.set_memory_growth(_gpu, True)


import logging
import src.models.encoder.utils as utils

# pred = tf.random.uniform((2,10))
# targets = pred  # (pred_count,targets_count,hidden_dim)
# print(pred)
# print(targets)
    


def distances(a, b):
    """Return the pairwise Euclidean distances between the rows of ``a`` and ``b``.

    Args:
        a: Rank-2 tensor of shape ``[n, d]``.
        b: Rank-2 tensor of shape ``[m, d]`` (same ``d`` and dtype as ``a``).

    Returns:
        Tensor of shape ``[n, m]`` whose entry ``[i, j]`` is ``||a[i] - b[j]||_2``.
    """
    logger = tf.get_logger()
    # [n, 1, d] - [1, m, d] broadcasts to [n, m, d]. Broadcasting replaces explicit
    # tf.tile calls, which depended on static shapes (None inside a tf.function
    # with an unknown batch size) and materialized two extra [n, m, d] copies.
    a = tf.expand_dims(a, 1)
    b = tf.expand_dims(b, 0)

    logger.debug("A")
    logger.debug(a)
    logger.debug("B")
    logger.debug(b)

    sub = a - b
    logger.debug(sub)
    return tf.math.reduce_euclidean_norm(sub, -1)


def get_distances_to_targets(y_pred, y_true, pred_positions, target_positions, categories, debug=False):
    """Distance from each prediction to its target and to its nearest negative.

    For example ``i`` the prediction is ``y_pred[i, pred_positions[i]]``. Its
    positive is ``y_true[i, target_positions[i]]``. Every other item of ``y_true``,
    across the whole batch and including the other items of example ``i``, is a
    negative.

    Args:
        y_pred: Rank-3 tensor of predicted item vectors; assumed
            ``[batch_size, seq_length, dim]`` — TODO confirm.
        y_true: Target item vectors; assumed the same shape as ``y_pred``.
        pred_positions: int32 tensor ``[batch_size, 1, 1]`` holding the
            prediction position of each example.
        target_positions: int32 tensor ``[batch_size, 1, 1]`` holding the
            target position of each example.
        categories: Not used by this function; kept for signature compatibility.
        debug: If True, log intermediate tensors at DEBUG level.

    Returns:
        ``(dist_to_pos, dist_to_neg)``, each of shape ``[batch_size]``.
    """
    logger = tf.get_logger()

    # Dynamic shapes throughout, so the function also works inside a tf.function
    # whose batch dimension is unknown (static .shape values would be None).
    batch_size = tf.shape(pred_positions)[0]
    # Example index of every row, shaped [batch_size, 1, 1] to concat with positions.
    rows = tf.reshape(tf.range(batch_size), [-1, 1, 1])
    # (example, position) pairs of the predictions, [batch_size, 2].
    pred_indices = tf.squeeze(tf.concat([rows, pred_positions], axis=-1), axis=[1])
    predictions = tf.gather_nd(y_pred, pred_indices)  # [batch_size, dim]

    # All target items flattened to [num_items, dim].
    targets = tf.reshape(y_true, [-1, y_true.shape[-1]])
    dist = distances(predictions, targets)
    # dist[i, j, k] = distance from prediction i to y_true[j, k].
    dist = tf.reshape(dist, [batch_size, tf.shape(y_true)[0], -1])

    # (example, target position) pairs, [batch_size, 2]. With batch_dims=1 the
    # lookup for example i reads dist[i, i, target_positions[i]].
    target_indices = tf.squeeze(tf.concat([rows, target_positions], axis=-1), axis=[1])
    dist_to_pos = tf.gather_nd(dist, target_indices, batch_dims=1)  # [batch_size]

    # Overwrite each positive with float32.max so the min below only sees negatives.
    positive_indices = tf.concat([tf.squeeze(rows, axis=[1]), target_indices], axis=-1)  # [batch_size, 3]
    updates = tf.fill([batch_size], tf.float32.max)
    dist_to_neg = tf.tensor_scatter_nd_update(dist, positive_indices, updates)
    dist_to_neg = tf.math.reduce_min(dist_to_neg, axis=[-2, -1])  # [batch_size]

    if debug:
        logger.debug("Predictions")
        logger.debug(predictions)
        logger.debug("Targets")
        logger.debug(targets)
        logger.debug("Distances")
        logger.debug(dist)
        logger.debug("Distances to positives")
        logger.debug(dist_to_pos)
        logger.debug("Distances to negatives")
        logger.debug(dist_to_neg)

    return dist_to_pos, dist_to_neg


def outfit_distance_loss(y_pred, y_true, categories, mask_positions, margin, acc: tf.metrics.Accuracy = None,
                         debug=False):
    """Margin loss pulling each masked-item prediction toward its target item.

    Padding items (``categories == 0``) in ``y_true`` are first replaced by
    vectors filled with ``_INF_FP32``. They then lie far from every prediction
    and so never become the nearest negative. The per-example loss is
    ``max(0, d(pred, positive) - d(pred, nearest negative) + margin)``, and the
    result is its mean over the batch.

    Args:
        y_pred: Predicted item vectors; assumed ``[batch_size, seq_length, dim]``
            — TODO confirm.
        y_true: Target item vectors; assumed the same shape as ``y_pred``.
        categories: Category ids with one fewer dimension than ``y_true``
            (one id per item); ``0`` marks padding.
        mask_positions: int32 ``[batch_size, 1, 1]`` masked positions, used as
            both the prediction and the target position.
        margin: Scalar margin (Python number or scalar tensor).
        acc: Optional accuracy metric. Each example counts as correct when its
            positive is at least as close as its nearest negative.
        debug: If True, log intermediate tensors at DEBUG level.

    Returns:
        Scalar tensor: the mean hinge loss over the batch.
    """
    logger = tf.get_logger()

    # Replace padding items with far-away vectors. tf.fill with dynamic sizes
    # replaces a tile on the static padding count, which is None in graph mode.
    padding_indices = tf.where(tf.math.equal(categories, tf.zeros_like(categories)))
    far_away = tf.fill([tf.shape(padding_indices)[0], tf.shape(y_true)[-1]],
                       tf.constant(_INF_FP32, dtype=y_true.dtype))
    y_true = tf.tensor_scatter_nd_update(y_true, padding_indices, far_away)

    dist_to_pos, dist_to_neg = get_distances_to_targets(y_pred, y_true, mask_positions, mask_positions, categories, debug)

    if acc is not None:
        # The prediction is correct when its minimal distance is to the positive (ties count).
        dist_delta = dist_to_pos - dist_to_neg
        correct = tf.math.less_equal(dist_delta, tf.zeros_like(dist_delta))
        correct = tf.cast(correct, dtype=tf.int32)
        correct = tf.expand_dims(correct, axis=-1)
        if debug:
            logger.debug("Distance delta")
            logger.debug(dist_delta)
            logger.debug("Correct predictions")
            logger.debug(correct)

        acc.update_state(tf.ones_like(correct), correct)

    # A scalar margin broadcasts over the batch. Casting also accepts an integer
    # margin, which tf.constant would have turned into an incompatible int32 tensor.
    margin = tf.cast(margin, dist_to_pos.dtype)
    loss = dist_to_pos - dist_to_neg + margin
    # Hinge: only examples that violate the margin contribute.
    loss = tf.clip_by_value(loss, 0, tf.float32.max)

    if debug:
        logger.debug("Distance losses")
        logger.debug(loss)

    return tf.reduce_mean(loss)


def outfit_distance_fitb(y_pred, y_true, pred_positions, target_position, categories, acc: tf.metrics.Accuracy,
                         debug=False):
    """Update ``acc`` with fill-in-the-blank accuracy for a batch.

    An example counts as correct when the prediction at ``pred_positions`` is at
    least as close to its target at ``target_position`` as to any other item of
    ``y_true`` in the batch.

    Args:
        y_pred: Predicted item vectors; assumed ``[batch_size, seq_length, dim]``
            — TODO confirm.
        y_true: Candidate/target item vectors; assumed the same shape as ``y_pred``.
        pred_positions: int32 ``[batch_size, 1, 1]`` prediction positions.
        target_position: int32 ``[batch_size, 1, 1]`` target positions.
        categories: Passed through to ``get_distances_to_targets``.
        acc: Accuracy metric to update; nothing is recorded when it is None.
        debug: If True, ``get_distances_to_targets`` logs its intermediates.

    Returns:
        None. The result is accumulated in ``acc``.
    """
    # Bug fix: target_position was previously omitted here. That shifted
    # categories into the target_positions slot and debug into the categories slot.
    # NOTE(review): unlike outfit_distance_loss, padding items in y_true are not
    # pushed away here, so they can count as nearest negatives — confirm intent.
    dist_to_pos, dist_to_neg = get_distances_to_targets(y_pred, y_true, pred_positions, target_position,
                                                        categories, debug)

    if acc is not None:
        # The prediction is correct when its minimal distance is to the positive (ties count).
        dist_delta = dist_to_pos - dist_to_neg
        correct = tf.math.less_equal(dist_delta, tf.zeros_like(dist_delta))
        correct = tf.cast(correct, dtype=tf.int32)
        correct = tf.expand_dims(correct, axis=-1)
        acc.update_state(tf.ones_like(correct), correct)




# Smoke test: run outfit_distance_loss on random inputs with DEBUG logging enabled.
logger = tf.get_logger()
logger.setLevel(logging.DEBUG)
# Force tf.function-wrapped code to execute eagerly.
# NOTE(review): newer TF releases deprecate this alias in favour of
# tf.config.run_functions_eagerly — confirm against the installed TF version.
tf.config.experimental_run_functions_eagerly(True)
# 5 outfits x 3 items x 10-dim vectors. No seed is set in this cell, so the
# values (and the printed output below) may differ between runs.
x = tf.random.uniform((5,3,10))
y = tf.random.uniform((5,3,10))
# Random category ids in [0, 10). outfit_distance_loss treats id 0 as padding,
# so which items count as padding is also random here.
categories = tf.random.uniform((5,3), dtype=tf.int32, maxval=10)
# One masked position per outfit, shaped [batch_size, 1, 1].
mask_positions = tf.constant([0,2,2,1,0], dtype=tf.int32, shape=(5,1,1))
acc = tf.metrics.Accuracy()
print(categories)
print(x)
print(mask_positions)

# margin=0.3, debug=True
loss = outfit_distance_loss(x, y, categories, mask_positions, 0.3, acc, True)

print(loss)
print(acc.result())
# outfit_metric_loss(x,y, categories, mask)


tf.Tensor(
[[9 7 9]
 [5 9 3]
 [6 6 2]
 [8 4 6]
 [8 9 8]], shape=(5, 3), dtype=int32)
tf.Tensor(
[[[0.6332818  0.37437057 0.97834027 0.29615462 0.81051874 0.00326824
   0.30646002 0.30802965 0.81789315 0.9001994 ]
  [0.22230077 0.12858367 0.3751992  0.15147996 0.6637571  0.77481043
   0.0857228  0.10861349 0.4825722  0.41864192]
  [0.9120662  0.04456544 0.19175863 0.0079267  0.7025062  0.5789548
   0.6471     0.36209428 0.308851   0.17166579]]

 [[0.23634124 0.3345909  0.5464374  0.94532716 0.39460695 0.93103766
   0.7921376  0.9655868  0.9963026  0.21122253]
  [0.33498442 0.1518004  0.5120069  0.10390615 0.9349996  0.7113334
   0.8012409  0.2497276  0.3313235  0.35388887]
  [0.37096    0.47547102 0.05785775 0.01490092 0.01865232 0.71636987
   0.67896664 0.5716256  0.29744112 0.61859787]]

 [[0.690277   0.5703813  0.45322597 0.93969595 0.2257464  0.61083555
   0.4119166  0.65223885 0.6922053  0.19512045]
  [0.35838473 0.39219773 0.05804503 0.18354249 0.03886712 0.9149746
   0.48377693 0