In [1]:
import tensorflow as tf

In [2]:
te1 = tf.random.normal((16,28,28,10))
te1 = tf.nn.softmax( te1 )

te2 = tf.random.normal((16,28,28,10))
te2 = tf.nn.softmax( te2 )

print(te1.shape, te2.shape)

(16, 28, 28, 10) (16, 28, 28, 10)


In [3]:
# 1 - directly
outer_product = \
    tf.einsum('bxyi, bxyj -> bxyij', te1, te2) # BATCH * WIDTH * HEIGHT * CLASSES * CLASSES

joint_p = tf.math.reduce_mean(outer_product, axis=0)            # осталось WIDTH * HEIGHT * CLASSES * CLASSES

# transpose last two axes and sum to get symmetric matrix
joint_p_T = tf.transpose(joint_p, perm=[0, 1, 3, 2])
#joint_p = (joint_p + joint_p_T) / 2

P_i = tf.math.reduce_sum(joint_p, axis=-1, keepdims=True) # WIDTH * HEIGHT * CLASSES * 1
P_j = tf.math.reduce_sum(joint_p, axis=-2, keepdims=True) # WIDTH * HEIGHT * 1 * CLASSES

eps = tf.keras.backend.epsilon()

per_pixel_mutual_inf = tf.math.reduce_sum(
    joint_p * tf.math.log((P_i * P_j + eps) / (joint_p + eps)), # +eps, чтобы избежать деления на 0
    axis=(-1, -2)
)

print(per_pixel_mutual_inf.shape)

mean_mi = tf.math.reduce_mean(per_pixel_mutual_inf)

(28, 28)


In [4]:
# 2 - via con
te1T = tf.transpose( te1, (3,1,2,0) )
te2T = tf.transpose( te2, (1,2,0,3) )
te1T = tf.pad( te1T, [[0,0],[1,1],[1,1],[0,0]] )
print(te1T.shape,te2T.shape)

(10, 30, 30, 16) (28, 28, 16, 10)


In [5]:
re_conv = tf.nn.conv2d( te1T, te2T, strides=1, padding='VALID'  )
print(re_conv.shape)

(10, 3, 3, 10)


In [9]:
red_P_dir = tf.math.reduce_sum(joint_p, axis=(0,1) )
red_P_conv = tf.squeeze(re_conv[:,1,1,:])/16
print( red_P_dir )
print( red_P_conv )
print( red_P_conv/red_P_dir )

tf.Tensor(
[[1.         0.99999994 1.         1.         1.         1.
  1.         1.         1.         1.         1.         1.
  1.         1.         1.         1.         1.         1.
  1.         1.         0.99999994 1.         1.         1.
  1.         1.         1.         0.99999994]
 [1.         1.         1.         1.         1.         1.
  0.99999994 1.         1.         1.         1.         1.
  1.         1.         1.         0.99999994 1.         1.
  1.         1.         0.99999994 1.         1.         1.
  1.         1.         1.         1.        ]
 [1.         1.         0.9999999  1.         1.         0.99999994
  1.         1.0000001  1.         1.         1.         1.
  1.         1.         1.         1.         1.         0.99999994
  1.         1.         1.         1.         1.         1.
  1.         1.         0.99999994 1.        ]
 [1.         1.         1.         0.99999994 1.         1.
  1.         1.         1.         1.         1.    

InvalidArgumentError: {{function_node __wrapped__RealDiv_device_/job:localhost/replica:0/task:0/device:GPU:0}} required broadcastable shapes [Op:RealDiv]

In [7]:
mean_mi

<tf.Tensor: shape=(), dtype=float32, numpy=-0.02175135>

In [8]:
P_i

<tf.Tensor: shape=(28, 28, 10, 1), dtype=float32, numpy=
array([[[[0.09889098],
         [0.11546433],
         [0.10938349],
         ...,
         [0.09735816],
         [0.08905205],
         [0.08005883]],

        [[0.10646892],
         [0.1207222 ],
         [0.10706271],
         ...,
         [0.05770566],
         [0.13026455],
         [0.06445151]],

        [[0.12049067],
         [0.07907608],
         [0.10022384],
         ...,
         [0.07530724],
         [0.13147008],
         [0.09464107]],

        ...,

        [[0.10759808],
         [0.08001567],
         [0.11190833],
         ...,
         [0.14674664],
         [0.11685097],
         [0.07652178]],

        [[0.12089074],
         [0.11033066],
         [0.09316619],
         ...,
         [0.08520522],
         [0.09340656],
         [0.0935605 ]],

        [[0.11825372],
         [0.11446823],
         [0.08219303],
         ...,
         [0.09562606],
         [0.08193766],
         [0.10925566]]],


   

In [25]:
def conv_loss(te1: tf.Tensor, te2: tf.Tensor):
    #print(te1, te2)
    te1T = tf.transpose( te1, (3,1,2,0) )
    te2T = tf.transpose( te2, (1,2,0,3) )
    te1T = tf.pad( te1T, [[0,0],[1,1],[1,1],[0,0]] )
    re_conv = tf.nn.conv2d( te1T, te2T, strides=1, padding='VALID'  )
    joint_p = tf.squeeze(re_conv[:,1,1,:])/16/28/28
    
    # transpose and sum to get symmetric matrix
    joint_p_T = tf.transpose(joint_p, perm=[1, 0])
    joint_p = (joint_p + joint_p_T) / 2

    P_i = tf.math.reduce_sum(joint_p, axis=-1, keepdims=True) # CLASSES * 1
    P_j = tf.math.reduce_sum(joint_p, axis=-2, keepdims=True) # 1 * CLASSES

    eps = tf.keras.backend.epsilon()

    per_pixel_mutual_inf = tf.math.reduce_sum(
        joint_p * tf.math.log((P_i * P_j + eps) / (joint_p + eps)), # +eps, чтобы избежать деления на 0
        axis=(-1, -2)
    )
    return tf.math.reduce_mean(per_pixel_mutual_inf)

In [26]:
conv_loss(te1, te2)

tf.Tensor(
[[0.00998408 0.00993702 0.01016819 0.00995807 0.01014788 0.00999915
  0.01007969 0.01009627 0.01012776 0.00996185]
 [0.00975654 0.00980369 0.00984628 0.00996711 0.00991098 0.00975422
  0.00987861 0.01019149 0.0098987  0.01002124]
 [0.01002306 0.01001921 0.00973457 0.00990755 0.01026458 0.00989603
  0.01011785 0.01016392 0.01004293 0.00997903]
 [0.00983996 0.00994958 0.00973935 0.00983507 0.01000422 0.00995556
  0.00977747 0.00996091 0.00972078 0.00994237]
 [0.00999482 0.00995146 0.00993355 0.00999395 0.01007077 0.01001008
  0.01012508 0.01013936 0.00988353 0.01015522]
 [0.00990851 0.00998177 0.00994581 0.01016514 0.01015332 0.00995566
  0.01002532 0.01022292 0.01022278 0.01004158]
 [0.01014345 0.01017935 0.00970935 0.01012404 0.01024052 0.00987536
  0.01012599 0.01007097 0.01019522 0.01014043]
 [0.0098561  0.00998967 0.00997244 0.01008971 0.01018165 0.00987839
  0.01005443 0.00999466 0.01014842 0.00997118]
 [0.00989098 0.01014195 0.00999058 0.00997368 0.01002982 0.01009271
 

<tf.Tensor: shape=(), dtype=float32, numpy=-2.496715e-05>

In [23]:
def negative_mutual_inf_without_shifts(outp: tf.Tensor, inv_transformed_outp: tf.Tensor):
    """Принимает на вход два тензора с размерностью BATCH * X * Y * CLASSES:
    - shift(pred(img)), где shift - произвольный сдвиг
    - shift_inv(T_inv(pred(T(img)))), где T - произвольная трансформация

    Собирая статистику по батчу и по пикселям, мы вычисляем отрицательную взаимную информацию.
    """

    joint_p = \
        tf.einsum('bxyi, bxyj -> ij', outp, inv_transformed_outp) / 16 / 28/ 28
        # CLASSES * CLASSES
    print(joint_p)

    # transpose and sum to get symmetric matrix
    joint_p_T = tf.transpose(joint_p, perm=[1, 0])
    joint_p = (joint_p + joint_p_T) / 2

    P_i = tf.math.reduce_sum(joint_p, axis=-1, keepdims=True) # CLASSES * 1
    P_j = tf.math.reduce_sum(joint_p, axis=-2, keepdims=True) # 1 * CLASSES

    eps = tf.keras.backend.epsilon()

    per_pixel_mutual_inf = tf.math.reduce_sum(
        joint_p * tf.math.log((P_i * P_j + eps) / (joint_p + eps)), # +eps, чтобы избежать деления на 0
        axis=(-1, -2)
    )
    return tf.math.reduce_mean(per_pixel_mutual_inf)

In [24]:
negative_mutual_inf_without_shifts(te1, te2)

tf.Tensor(
[[0.00998408 0.00993702 0.0101682  0.00995807 0.01014788 0.00999915
  0.0100797  0.01009627 0.01012776 0.00996186]
 [0.00975654 0.00980368 0.00984628 0.00996711 0.00991098 0.00975422
  0.00987861 0.01019149 0.0098987  0.01002125]
 [0.01002306 0.0100192  0.00973457 0.00990755 0.01026458 0.00989603
  0.01011785 0.01016392 0.01004293 0.00997903]
 [0.00983997 0.00994958 0.00973935 0.00983507 0.01000422 0.00995556
  0.00977748 0.00996091 0.00972078 0.00994237]
 [0.00999482 0.00995146 0.00993355 0.00999395 0.01007077 0.01001008
  0.01012508 0.01013936 0.00988353 0.01015522]
 [0.00990851 0.00998177 0.00994581 0.01016514 0.01015332 0.00995566
  0.01002532 0.01022292 0.01022278 0.01004159]
 [0.01014345 0.01017936 0.00970935 0.01012404 0.01024052 0.00987536
  0.01012599 0.01007098 0.01019522 0.01014043]
 [0.0098561  0.00998967 0.00997244 0.01008971 0.01018165 0.00987839
  0.01005443 0.00999466 0.01014842 0.00997118]
 [0.00989098 0.01014195 0.00999058 0.00997367 0.01002982 0.0100927
  

<tf.Tensor: shape=(), dtype=float32, numpy=-2.4985522e-05>