In [42]:
import torch
import torch.nn as nn

In [6]:
from keras.layers import LeakyReLU

In [64]:
class Discriminator(nn.Module):
    """Convolutional GAN discriminator: a PyTorch port of the Keras ``_build_discriminator``.

    Stacks ``len(discriminator_conv_channels)`` blocks of
    Conv2d -> [BatchNorm2d] -> activation -> [Dropout], then flattens and ends with a
    single-unit ``nn.Linear`` (registered as ``"Fully Connected Layer"``). ``forward``
    returns raw scores (no sigmoid), unlike the Keras version whose final Dense layer
    applies a sigmoid; pair it with ``nn.BCEWithLogitsLoss`` or apply ``torch.sigmoid``.

    Args:
        input_dim: Input shape *including* the batch axis, ``(N, C, H, W)``. ``C`` sets
            the first conv's input channels and ``(C, H, W)`` is used to size the
            linear head; ``N`` is not used.
        discriminator_conv_channels: Output channels of each conv layer.
        discriminator_conv_kernel_size: Kernel size of each conv layer (int or pair).
        discriminator_conv_strides: Stride of each conv layer.
        discriminator_activation: Activation name accepted by ``get_activation``.
        discriminator_dropout_rate: Dropout probability after each activation; a falsy
            value (``0`` / ``None``) disables dropout.
        discriminator_batch_norm_momentum: ``momentum`` passed to ``nn.BatchNorm2d``
            (PyTorch convention; ``None`` selects a cumulative moving average).
        discriminator_batch_norm_use: Insert BatchNorm2d after every conv except the first.
    """

    def __init__(self, input_dim, discriminator_conv_channels, discriminator_conv_kernel_size,
                 discriminator_conv_strides, discriminator_activation, discriminator_dropout_rate,
                 discriminator_batch_norm_momentum=0.1, discriminator_batch_norm_use=True):
        super().__init__()
        self.input_dim = input_dim
        self.discriminator_conv_channels = discriminator_conv_channels
        self.discriminator_conv_kernel_size = discriminator_conv_kernel_size
        self.discriminator_conv_strides = discriminator_conv_strides
        self.discriminator_batch_norm_momentum = discriminator_batch_norm_momentum
        self.discriminator_batch_norm_use = discriminator_batch_norm_use
        self.discriminator_activation = discriminator_activation
        self.discriminator_dropout_rate = discriminator_dropout_rate
        self.n_layers_discriminator = len(discriminator_conv_channels)
        self.discriminator_conv_layers = nn.Sequential()

        # input_dim is (N, C, H, W): the channel axis feeds the first convolution.
        in_channels = self.input_dim[1]
        for i in range(self.n_layers_discriminator):
            out_channels = self.discriminator_conv_channels[i]
            kernel_size = self.discriminator_conv_kernel_size[i]
            self.discriminator_conv_layers.add_module(
                "Conv " + str(i),
                nn.Conv2d(in_channels, out_channels, kernel_size,
                          stride=self.discriminator_conv_strides[i],
                          padding=self._same_padding(kernel_size)))
            # As in the Keras version, the first conv is not batch-normalised.
            if self.discriminator_batch_norm_use and i > 0:
                self.discriminator_conv_layers.add_module(
                    "Batchnorm " + str(i),
                    nn.BatchNorm2d(out_channels, momentum=self.discriminator_batch_norm_momentum))
            self.discriminator_conv_layers.add_module(
                "Activation " + str(i), self.get_activation(self.discriminator_activation))
            if self.discriminator_dropout_rate:
                self.discriminator_conv_layers.add_module(
                    "Dropout " + str(i), nn.Dropout(p=self.discriminator_dropout_rate))
            in_channels = out_channels

        self.discriminator_conv_layers.add_module("Flatten", nn.Flatten())

        # Probe the conv stack once to size the linear head. Eval mode + no_grad keeps
        # the probe from updating BatchNorm running statistics or recording an autograd
        # graph; training mode is restored afterwards.
        self.discriminator_conv_layers.eval()
        with torch.no_grad():
            probe = torch.zeros(1, *self.input_dim[1:])
            output_size = self.discriminator_conv_layers(probe).shape[1]
        self.discriminator_conv_layers.train()

        self.discriminator_conv_layers.add_module("Fully Connected Layer", nn.Linear(output_size, 1))

    @staticmethod
    def _same_padding(kernel_size):
        """Per-side padding ``k // 2``; for odd kernels the output size matches Keras ``'same'``."""
        if isinstance(kernel_size, int):
            return kernel_size // 2
        return tuple(k // 2 for k in kernel_size)

    def forward(self, x):
        """Return one raw (pre-sigmoid) score per sample, shape ``(N, 1)``."""
        return self.discriminator_conv_layers(x)

    def get_activation(self, activation='relu'):
        """Build a fresh activation module from its name.

        Supported: ``'relu'``, ``'leaky_relu'`` (negative slope 0.2, as in the Keras
        version), ``'elu'``, ``'selu'``, ``'gelu'``, ``'tanh'`` and ``'sigmoid'``.

        Raises:
            ValueError: for any other name, instead of silently falling back to ReLU.
        """
        factories = {
            'relu': nn.ReLU,
            'leaky_relu': lambda: nn.LeakyReLU(negative_slope=0.2),
            'elu': nn.ELU,
            'selu': nn.SELU,
            'gelu': nn.GELU,
            'tanh': nn.Tanh,
            'sigmoid': nn.Sigmoid,
        }
        try:
            factory = factories[activation]
        except KeyError:
            raise ValueError(
                f"Unsupported activation {activation!r}; expected one of {sorted(factories)}"
            ) from None
        return factory()

In [65]:
# The class is `Discriminator` (a lowercase `discriminator` raises NameError on a fresh
# kernel). The learning rate is an optimizer setting, not a constructor argument
# (passing `discriminator_learning_rate=` raises TypeError), so keep it separately,
# e.g. torch.optim.Adam(test_dcm.parameters(), lr=DISCRIMINATOR_LEARNING_RATE).
DISCRIMINATOR_LEARNING_RATE = 0.0001

test_dcm = Discriminator(
    input_dim=(60, 1, 28, 28),
    discriminator_conv_channels=[64, 64, 128, 128],
    discriminator_conv_kernel_size=[5, 5, 5, 5],
    discriminator_conv_strides=[2, 2, 2, 1],
    # PyTorch BatchNorm momentum convention; None would select a cumulative moving average.
    discriminator_batch_norm_momentum=0.1,
    discriminator_activation='relu',
    discriminator_dropout_rate=0.4,
)

In [66]:
# Last expression in the cell: Jupyter renders the module's repr (its layer listing).
test_dcm

discriminator(
  (discriminator_conv_layers): Sequential(
    (Conv 0): Conv2d(1, 64, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (Activation 0): ReLU()
    (Dropout 0): Dropout(p=0.4, inplace=False)
    (Conv 1): Conv2d(64, 64, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (Batchnorm 1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (Activation 1): ReLU()
    (Dropout 1): Dropout(p=0.4, inplace=False)
    (Conv 2): Conv2d(64, 128, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (Batchnorm 2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (Activation 2): ReLU()
    (Dropout 2): Dropout(p=0.4, inplace=False)
    (Conv 3): Conv2d(128, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (Batchnorm 3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (Activation 3): ReLU()
    (Dropout 3): Dropout(p=0.4, inplace=False)
    (Flatten): Flatten()


In [None]:
def _build_discriminator(self):
    """Keras reference implementation of the discriminator that the PyTorch class mirrors.

    Builds ``n_layers_discriminator`` blocks of Conv2D ('same' padding) ->
    [BatchNormalization, only when ``discriminator_batch_norm_momentum`` is truthy and
    the block is not the first] -> activation -> [Dropout], then Flatten and a sigmoid
    ``Dense(1)``, and stores the model as ``self.discriminator``.

    Written as a method: it reads ``input_dim``, ``n_layers_discriminator``,
    ``discriminator_conv_filters``, ``discriminator_conv_kernel_size``,
    ``discriminator_conv_strides``, ``discriminator_batch_norm_momentum``,
    ``discriminator_activation``, ``discriminator_dropout_rate`` and ``weight_init``
    from ``self``, and calls ``self.get_activation``.
    """
    # Only LeakyReLU is imported at the top of the notebook; bring in the remaining
    # Keras names locally so this cell does not depend on hidden kernel state.
    from keras.layers import BatchNormalization, Conv2D, Dense, Dropout, Flatten, Input
    from keras.models import Model

    ### THE discriminator
    discriminator_input = Input(shape=self.input_dim, name='discriminator_input')

    x = discriminator_input

    for i in range(self.n_layers_discriminator):

        x = Conv2D(
            filters = self.discriminator_conv_filters[i]
            , kernel_size = self.discriminator_conv_kernel_size[i]
            , strides = self.discriminator_conv_strides[i]
            , padding = 'same'
            , name = 'discriminator_conv_' + str(i)
            , kernel_initializer = self.weight_init
            )(x)

        if self.discriminator_batch_norm_momentum and i > 0:
            x = BatchNormalization(momentum = self.discriminator_batch_norm_momentum)(x)

        x = self.get_activation(self.discriminator_activation)(x)

        if self.discriminator_dropout_rate:
            x = Dropout(rate = self.discriminator_dropout_rate)(x)

    x = Flatten()(x)

    discriminator_output = Dense(1, activation='sigmoid', kernel_initializer = self.weight_init)(x)

    self.discriminator = Model(discriminator_input, discriminator_output)


def get_activation(self, activation):
    """Return a Keras activation layer for ``activation``.

    ``'leaky_relu'`` gives ``LeakyReLU(alpha=0.2)``; any other name is passed to
    ``Activation``. Takes ``self`` because it is meant to sit beside
    ``_build_discriminator`` as a method (that function calls ``self.get_activation``);
    ``self`` itself is unused. Previously this def was indented inside
    ``_build_discriminator`` after its last statement, where it was never reachable.
    """
    from keras.layers import Activation, LeakyReLU

    if activation == 'leaky_relu':
        return LeakyReLU(alpha = 0.2)
    return Activation(activation)