In [None]:
num_features = X.shape[1]
kernel_dims = 1

features = tf.reshape(X, (-1, num_features, kernel_dims)) # NxBxC

We add a dimension so that when we subtract this
with the rest of the batch we can take advantage
of broadcasting to calculate all substractions
at once in a big tensor. So it'll go from 1xNxBxC
to a NxNxBxC which is basically saying that we want
our NxBxC input tensor copied N times. So we'll
end up with N identical NxBxC tensors. So for example
let's assume that
```
[
  [ [a], [b], [c] ],
  [ [d], [e], [f] ]
]
```
is our input (NxBxC=2x3x1) once we do Mi - Mj, then
Mi will be broadcasted to be a (NxNxBxC=2x2x3x1)
```
[
  [
    [ [a], [b], [c] ],
    [ [d], [e], [f] ]
  ],
  [
    [ [a], [b], [c] ],
    [ [d], [e], [f] ]
  ]
]
```
which is the same tensor Mi copied N times.

In [None]:
Mi = tf.expand_dims(features, axis=0) # 1xNxBxC
# Use this to visualize what it will look like after broadcasting
# Mi = tf.broadcast_to(Mi, (X.shape[1], num_samples, num_features, kernel_dims)) # NxNxBxC

This is the same idea as above but this time we
want the final tensor to be arranged slightly differently.
By adding a dimension in axis 1 we are basically saying
that we want N tensors but that in each tensor we want
one matrix that corresponds to 1 sample in the batch.
So if you had for example the same as above:
```
[
  [ [a], [b], [c] ],
  [ [d], [e], [f] ]
]
```
after this exansion you would now have
```
[
  [
    [ [a], [b], [c] ]
  ],
  [
    [ [d], [e], [f] ]
  ]
] # (Nx1xBxC=2x1x3x1)
```
you can think of it as having separated each input in
the batch in its own tensor. So once we do Mi - Mj
it will be broadcasted to a (NxNxBxC=2x2x3x1) giving you
```
[
  [
    [ [a], [b], [c] ],
    [ [a], [b], [c] ]
  ],
  [
    [ [d], [e], [f] ],
    [ [d], [e], [f] ]
  ]
]
```
which is basically copying each input N times for all
the samples in the batch

In [None]:
Mj = tf.expand_dims(features, axis=1) # Nx1xBxC
# Use this to visualize what it will look like after broadcasting
# Mj = tf.broadcast_to(Mj, (num_samples, Mj.shape[0], num_features, kernel_dims)) # NxNxBxC

In [None]:
abs_diff = tnp.abs(Mi - Mj) # NxNxBxC

# Sum each tube
norm = tnp.sum(abs_diff, axis=3) # NxNxB

Each tube contains one image that has been subtracted with
another image in the batch, so all the subtractions of one
image with the others in the batch are stored in the 1st
dimension (`shape[0]`) i.e All image_1 calcs are stored in
`[0,:,:]`, image_2 in `[1,:,:]` etc So we sum accross the 1st
dimensation to get the sum of all image subtraction for all
image_i

In [None]:
outputs = tnp.sum(tnp.exp(-norm), axis=0) # NxB

Stack those differences with your input. So for example if
your input is
```
[
  [ [z], [y] ],
  [ [x], [w] ]
] # (NxB=2x2)
```
you would now have something like this
```
[
  [ [z], [y], [a], [b], [c] ],
  [ [x], [w], [d], [e], [f] ]
] # (NxB1+B2=2x5)
```

In [None]:
inputs_to_next_layer = Concatenate(axis=1)((X, outputs)) # Nx(B+X.shape[-1])

### Imports

In [2]:
import tensorflow as tf
import tensorflow.experimental.numpy as tnp
from tensorflow.keras.layers import Concatenate



### With comments (Lambda)

In [None]:
def minibatch_discrimination(X):
    num_features = X.shape[1]
    kernel_dims = 1

    features = tf.reshape(X, (-1, num_features, kernel_dims)) # NxBxC
    
    # We add a dimension so that when we subtract this
    # with the rest of the batch we can take advantage
    # of broadcasting to calculate all substractions
    # at once in a big tensor. So it'll go from 1xNxBxC
    # to a NxNxBxC which is basically saying that we want
    # our NxBxC input tensor copied N times. So we'll
    # end up with N identical NxBxC tensors. So for example
    # let's assume that
    # [
    #   [ [a], [b], [c] ],
    #   [ [d], [e], [f] ]
    # ]
    # is our input (NxBxC=2x3x1) once we do Mi - Mj, then
    # Mi will be broadcasted to be a (NxNxBxC=2x2x3x1)
    # [
    #  [
    #   [ [a], [b], [c] ],
    #   [ [d], [e], [f] ]
    #  ],
    #  [
    #   [ [a], [b], [c] ],
    #   [ [d], [e], [f] ]
    #  ]
    # ]
    # which is the same tensor Mi copied N times.
    Mi = tf.expand_dims(features, axis=0) # 1xNxBxC
    # Use this to visualize what it will look like after broadcasting
    # Mi = tf.broadcast_to(Mi, (X.shape[1], num_samples, num_features, kernel_dims)) # NxNxBxC

    # This is the same idea as above but this time we
    # want the final tensor to be arranged slightly differently.
    # By adding a dimension in axis 1 we are basically saying
    # that we want N tensors but that in each tensor we want
    # one matrix that corresponds to 1 sample in the batch.
    # So if you had for example the same as above:
    # [
    #   [ [a], [b], [c] ],
    #   [ [d], [e], [f] ]
    # ] after this exansion you would now have
    # [
    #  [
    #   [ [a], [b], [c] ]
    #  ],
    #  [
    #   [ [d], [e], [f] ]
    #  ]
    # ] (Nx1xBxC=2x1x3x1)
    # you can think of it as having separated each input in
    # the batch in its own tensor. So once we do Mi - Mj
    # it will be broadcasted to a (NxNxBxC=2x2x3x1) giving you
    # [
    #  [
    #   [ [a], [b], [c] ],
    #   [ [a], [b], [c] ]
    #  ],
    #  [
    #   [ [d], [e], [f] ],
    #   [ [d], [e], [f] ]
    #  ]
    # ] which is basically copying each input N times for all
    # the samples in the batch
    Mj = tf.expand_dims(features, axis=1) # Nx1xBxC
    # Use this to visualize what it will look like after broadcasting
    # Mj = tf.broadcast_to(Mj, (num_samples, Mj.shape[0], num_features, kernel_dims)) # NxNxBxC

    abs_diff = tnp.abs(Mi - Mj) # NxNxBxC

    # Sum each tube
    norm = tnp.sum(abs_diff, axis=3) # NxNxB

    # Each tube contains one image that has been subtracted with
    # another image in the batch, so all the subtractions of one
    # image with the others in the batch are stored in the 1st
    # dimension (shape[0]) i.e All image_1 calcs are stored in
    # [0,:,:], image_2 in [1,:,:] etc So we sum accross the 1st
    # dimensation to get the sum of all image subtraction for all
    # image_i
    outputs = tnp.sum(tnp.exp(-norm), axis=0) # NxB

    # Stack those differences with your input. So for example if
    # your input is
    # [
    #   [ [z], [y] ],
    #   [ [x], [w] ]
    # ] (NxB=2x2)
    # you would now have something like this
    # [
    #   [ [z], [y], [a], [b], [c] ],
    #   [ [x], [w], [d], [e], [f] ]
    # ] (NxB1+B2=2x5)
    return Concatenate(axis=1)((X, outputs)) # Nx(B+X.shape[-1])

### Without comments (Lambda)

In [None]:
def minibatch_discrimination(X):
    num_features = X.shape[1]
    kernel_dims = 1

    features = tf.reshape(X, (-1, num_features, kernel_dims)) # NxBxC

    Mi = tf.expand_dims(features, axis=0) # 1xNxBxC
  
    Mj = tf.expand_dims(features, axis=1) # Nx1xBxC

    abs_diff = tnp.abs(Mi - Mj) # NxNxBxC

    norm = tnp.sum(abs_diff, axis=3) # NxNxB

    outputs = tnp.sum(tnp.exp(-norm), axis=0) # NxB

    return Concatenate(axis=1)((X, outputs)) # Nx(B+X.shape[-1])

### Implemented as Layer

In [46]:
import tensorflow as tf
import tensorflow.experimental.numpy as tnp

from tensorflow.keras.layers import Layer

class MinibatchDiscrimination(Layer):
    def __init__(self, kernel_dims, **kwargs):
        super(MinibatchDiscrimination, self).__init__(**kwargs)
        self.kernel_dims = kernel_dims

    def build(self, input_shape):
        self.in_features = input_shape[-1]
    
    def call(self, X):
        features = tf.reshape(X, (-1, self.in_features, self.kernel_dims)) # NxBxC

        Mi = tf.expand_dims(features, axis=0) # 1xNxBxC

        Mj = tf.expand_dims(features, axis=1) # Nx1xBxC

        abs_diff = tnp.abs(Mi - Mj) # NxNxBxC
        print('diff', abs_diff)

        norm = tnp.sum(abs_diff, axis=3) # NxNxB
        print('norm', norm)
    
        print(tnp.exp(-norm))
        outputs = tnp.sum(tnp.exp(-norm), axis=0) # NxB

        return tnp.concatenate((X, outputs), axis=1) # Nx(B+X.shape[-1])
    
    def get_config(self):
        config = super(MinibatchDiscrimination, self).get_config()
        config.update({ 'kernel_dims': self.kernel_dims })
        return config
    
layer = MinibatchDiscrimination(1)

outputs = layer(tf.constant([[1],[2]]))

print(outputs)

diff ndarray<tf.Tensor(
[[[[0]]

  [[1]]]


 [[[1]]

  [[0]]]], shape=(2, 2, 1, 1), dtype=int32)>
norm ndarray<tf.Tensor(
[[[0]
  [1]]

 [[1]
  [0]]], shape=(2, 2, 1), dtype=int64)>
ndarray<tf.Tensor(
[[[1.        ]
  [0.36787944]]

 [[0.36787944]
  [1.        ]]], shape=(2, 2, 1), dtype=float64)>
ndarray<tf.Tensor(
[[1.         0.36787944]
 [2.         0.36787944]], shape=(2, 2), dtype=float64)>
