In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Conv2d_bn_ReLU(nn.Module):
    """Unpadded Conv2d -> BatchNorm2d -> in-place ReLU, applied in sequence.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution (and normalized by BN).
        kernel_size: convolution kernel size (no padding is applied).
        stride: convolution stride.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride)
        norm = nn.BatchNorm2d(out_channels)
        act = nn.ReLU(inplace=True)
        # Kept as one nn.Sequential named `layers` so parameter names
        # (layers.0.*, layers.1.*) stay stable for existing state_dicts.
        # NOTE(review): the conv keeps its default bias even though the
        # following BatchNorm re-centres the output; left as-is for
        # checkpoint compatibility.
        self.layers = nn.Sequential(conv, norm, act)

    def forward(self, x):
        """Run the conv / BN / ReLU stack on ``x`` and return the activations."""
        return self.layers(x)

class GenderDiscriminator(nn.Module):
    """Small CNN that maps an RGB image batch to 2-class probabilities.

    Feature extractor: a stride-2 3x3 conv (3 -> 32 channels) followed by two
    ``Conv2d_bn_ReLU`` stages (32 -> 64, stride 1; 64 -> 128, stride 2). None of
    the convs are padded. The resulting feature map is global-average-pooled to a
    128-vector and passed through ``Linear(128, 2)`` + softmax.
    """

    def __init__(self):
        super(GenderDiscriminator, self).__init__()
        layers = [
            # NOTE(review): this stem conv has no BN/ReLU after it, unlike the
            # following stages; confirm that is intentional.
            nn.Conv2d(3, 32, kernel_size=3, stride=2),
            Conv2d_bn_ReLU(32, 64, 3, 1),
            Conv2d_bn_ReLU(64, 128, 3, 2),
            # Conv2d_bn_ReLU(128, 256, 3, 1),
        ]
        self.layers = nn.Sequential(*layers)
        self.classifier = nn.Sequential(
            nn.Linear(128, 2),
            # Normalize over the class axis of the (N, 2) logits. Passing dim
            # explicitly avoids the deprecated implicit-dim warning; for 2-D
            # input the implicit choice was already dim=1, so results match.
            # NOTE(review): forward() returns probabilities. If training uses
            # nn.CrossEntropyLoss (which applies log-softmax itself), softmax
            # would effectively be applied twice; confirm against the caller.
            nn.Softmax(dim=1),
        )
        self._init_weights()

    def _init_weights(self):
        """Initialize every Conv2d, BatchNorm2d and Linear submodule in place.

        - Conv2d: Kaiming-normal weights (fan_out, ReLU gain), zero bias if present.
        - BatchNorm2d: weight 1, bias 0.
        - Linear: N(0, 0.01) weights, zero bias if present.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return class probabilities of shape (N, 2) for the input batch ``x``.

        ``x`` must be a 4-D batch with 3 channels (the stem conv takes 3 input
        channels, and pooling reads dims 2 and 3). Each spatial side needs to be
        at least 11 pixels to survive the three unpadded 3x3 convs.
        """
        features = self.layers(x)
        # Global average pool over the full spatial extent -> (N, 128, 1, 1).
        pooled = F.avg_pool2d(features, [features.size(2), features.size(3)])
        # Drop the singleton spatial dims -> (N, 128).
        pooled = pooled.reshape(pooled.shape[0], pooled.shape[1])
        return self.classifier(pooled)