In [1]:
import numpy
import numpy as np

from model.data_loader import *

In [4]:
# Experiment settings: directory of training images, model input shape as
# (height, width, channels), and number of images per minibatch.
train_pth, input_sz, batch_size = (
    'data/facades_processed/train',
    (256, 256, 3),
    3,
)

In [5]:
# Training-time augmentation and batching. ImageDataGenerator and dataLoader
# come from the wildcard import of model.data_loader (presumably Keras'
# ImageDataGenerator plus a project wrapper -- confirm).
train_generator = ImageDataGenerator(
    data_format='channels_last',
    rescale=1 / 255,            # scale raw pixel values (assumes 0-255 input)
    horizontal_flip=True,
    zoom_range=[0.895, 1.0],    # random zoom factor drawn from this interval
    fill_mode='constant',
    validation_split=0.0,       # keep every image for training
)
train_loader = dataLoader(
    train_pth,
    train_generator,
    img_sz=input_sz[:2],        # (height, width) only; channel count dropped
    batch_sz=batch_size,
)

In [9]:
# Pull a single (inputs, targets) minibatch from the loader to experiment on.
(inputs, outputs) = next(train_loader)

In [10]:
# Batch layout check: expected (batch, height, width, channels).
np.shape(inputs)

(3, 256, 256, 3)

In [12]:
# Step 1: standard deviation across the minibatch (axis 0), giving one value
# for every (row, column, channel) position.
std = inputs.std(axis=0)
std.shape

(256, 256, 3)

In [14]:
# Step 2: collapse those per-position deviations into one scalar by
# averaging over every feature and spatial location.
mean = std.mean()
mean

0.41226003

In [24]:
# Step 3: broadcast the scalar into a constant single-channel map per sample,
# shaped (batch, height, width, 1) -- the extra feature map that would be
# concatenated to the discriminator activations.

op_sz = (4, 4)
op = np.full((batch_size, *op_sz, 1), mean, dtype=np.float64)
op.shape

(3, 4, 4, 1)

In [25]:
def minibatch_std(minibatch_inps, op_sz=None):
    """
    Minibatch standard-deviation feature, a simplified alternative to
    minibatch discrimination described in section 3 of
    "Progressive Growing of GANs" (https://arxiv.org/abs/1710.10196).

    Computes the standard deviation of every feature at every spatial
    location across the minibatch, averages those values into a single
    scalar, and replicates it into one constant feature map per sample.
    Feeding that map to a late discriminator layer exposes batch-level
    variation, which is intended to increase the variety of GAN outputs
    and mitigate mode collapse. Only the new map is returned; concatenating
    it onto the discriminator activations is left to the caller.

    Parameters
    ----------
    minibatch_inps : array_like
        Minibatch with the sample index on axis 0. Assumed channels-last
        (N, H, W, C), matching the loader used in this notebook.
    op_sz : tuple of int, optional
        (height, width) of the output map. Defaults to the spatial size of
        ``minibatch_inps`` (axes 1 and 2), which is what is needed to
        concatenate the result along the channel axis.

    Returns
    -------
    numpy.ndarray
        float64 array of shape (N, op_sz[0], op_sz[1], 1) filled with the
        averaged standard deviation, where N is the number of samples in
        ``minibatch_inps``.

    Raises
    ------
    ValueError
        If the minibatch contains no samples.
    """
    x = np.asarray(minibatch_inps)
    n_samples = x.shape[0]
    if n_samples == 0:
        # np.std over an empty axis would silently produce NaN.
        raise ValueError("minibatch_std requires at least one sample")
    if op_sz is None:
        op_sz = x.shape[1:3]
    # Standard deviation of each feature at each spatial location over
    # the minibatch (axis 0).
    std = np.std(x, axis=0)
    # Average over all features and locations to get a single value.
    mean = np.mean(std)
    # Replicate that value over every location and every sample. N comes
    # from the input itself (not a notebook global), so a final, smaller
    # batch still lines up with its inputs.
    return np.full((n_samples, op_sz[0], op_sz[1], 1), mean, dtype=np.float64)

In [29]:
# Keep the function's return value; printing the global `op` here would only
# show the result of the manual step-by-step cell, not of minibatch_std.
mbstd_map = minibatch_std(inputs, op_sz)
# The packaged function must reproduce the manual computation above.
np.testing.assert_allclose(mbstd_map, op)
print(f'shape:{mbstd_map.shape}')
print(mbstd_map)

shape:(3, 4, 4, 1)
[[[[0.41226003]
   [0.41226003]
   [0.41226003]
   [0.41226003]]

  [[0.41226003]
   [0.41226003]
   [0.41226003]
   [0.41226003]]

  [[0.41226003]
   [0.41226003]
   [0.41226003]
   [0.41226003]]

  [[0.41226003]
   [0.41226003]
   [0.41226003]
   [0.41226003]]]


 [[[0.41226003]
   [0.41226003]
   [0.41226003]
   [0.41226003]]

  [[0.41226003]
   [0.41226003]
   [0.41226003]
   [0.41226003]]

  [[0.41226003]
   [0.41226003]
   [0.41226003]
   [0.41226003]]

  [[0.41226003]
   [0.41226003]
   [0.41226003]
   [0.41226003]]]


 [[[0.41226003]
   [0.41226003]
   [0.41226003]
   [0.41226003]]

  [[0.41226003]
   [0.41226003]
   [0.41226003]
   [0.41226003]]

  [[0.41226003]
   [0.41226003]
   [0.41226003]
   [0.41226003]]

  [[0.41226003]
   [0.41226003]
   [0.41226003]
   [0.41226003]]]]
