In [1]:
import torch
from torchvision import transforms
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data import RandomSampler
import numpy as np
import matplotlib.pyplot as plt
import cv2
import pandas as pd
from torch_snippets import *
from torchsummary import summary
# Pick the compute device once for the whole notebook: the GPU when PyTorch
# reports one as available, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# models

In [2]:
def weights_init(m):
    """Initialise convolution and batch-norm parameters of ``m`` in place.

    Meant to be passed to ``model.apply(weights_init)``. The layer kind is
    decided by a substring match on the module's class name:

    * name contains ``'Conv'``: weight drawn from a normal distribution with
      mean 0.0 and std 0.02;
    * name contains ``'BatchNorm'``: weight drawn from a normal distribution
      with mean 1.0 and std 0.02, bias set to 0.

    Any other module is left untouched. A matching module whose ``weight`` or
    ``bias`` is absent or ``None`` (for example batch norm created with
    ``affine=False``) is skipped instead of raising ``AttributeError``.

    Args:
        m: the module to initialise.
    """
    classname = m.__class__.__name__
    # getattr with a default covers both "attribute missing" and "registered as None".
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if 'Conv' in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias.data, 0)

In [157]:
import math
class resDiscriminator(nn.Module):
    """Residual convolutional discriminator with a minibatch-discrimination head.

    Data flow (see ``forward``):

    1. ``initial_conv``: a 7x7 conv followed by an 8x8 stride-2 conv.
    2. One residual stage at constant resolution (``initial_block_conv`` as the
       shortcut, ``initial_block_bn`` as the bottleneck branch) producing
       ``4 * n_map`` channels.
    3. ``reduction - 1`` residual stages; each uses stride-2 convolutions on
       both branches and doubles the channel count.
    4. ``reduction_p2`` extra stride-2 1x1-conv stages with a constant channel
       count, then a convolution whose kernel size equals
       ``img_size // 2**(reduction + reduction_p2)`` and which outputs
       ``minibatch_input_dim`` channels.
    5. The flattened vector is extended with minibatch-discrimination features
       and mapped by a linear layer to one raw (un-activated) score per sample.

    Args:
        reduction: number of stride-2 stages in steps 1-3 (the stride-2 conv in
            ``initial_conv`` counts as one). Clamped to ``int(log2(img_size))``;
            must be at least 1 after clamping.
        repeat: number of additional stride-1 bottleneck blocks chained onto the
            bottleneck branch of every stage in step 3.
        reduction_p2: number of extra stride-2 stages in step 4. Clamped to the
            range ``[0, int(log2(img_size)) - reduction]``.
        divisor: a bottleneck block's inner width is ``out_c // divisor``.
        n_map: base number of feature maps.
        img_size: side length used to size the final convolution's kernel.
            NOTE(review): the size arithmetic appears to assume square inputs
            whose side is a power of two - confirm against callers.
        relu: negative slope of every ``LeakyReLU``.
        n_features: number of minibatch-discrimination kernels; contributes
            ``2 * n_features`` extra inputs to the final linear layer.
        feature_dim: dimension of each minibatch-discrimination kernel.
        minibatch_input_dim: output channels of the final convolution, i.e. the
            length of the per-sample vector fed to minibatch discrimination.
        dropout: probability passed to ``nn.Dropout2d``.

    Raises:
        ValueError: if the clamped ``reduction`` is below 1. Such a
            configuration builds layers whose channel counts cannot match in
            ``forward``.
    """

    def __init__(self,reduction=5,repeat=0,reduction_p2=1,divisor=4,n_map=64,img_size=128,
                 relu=0.2,n_features=32,feature_dim=8,minibatch_input_dim=512,dropout=0.1):
        super().__init__()
        # Upper bound on the number of halvings implied by img_size.
        max_reduction = int(math.log2(img_size))
        self.reduction = min(reduction, max_reduction)
        if self.reduction < 1:
            # initial_conv always halves once and the first residual stage always
            # outputs 4*n_map channels, so fewer than one reduction is inconsistent
            # with the channel/size bookkeeping below.
            raise ValueError(
                f"reduction must be at least 1 after clamping to log2(img_size); got "
                f"reduction={reduction}, img_size={img_size}"
            )
        self.reduction_p2 = max(min(reduction_p2, max_reduction - self.reduction), 0)
        # Side length of the feature map that reaches the last convolution.
        self.final_map_size = img_size // 2**(self.reduction + self.reduction_p2)
        # Channel count after the residual stages: 4*n_map doubled (reduction-1) times.
        self.final_n_map = n_map * 2**(self.reduction + 1)
        self.minibatch_input_dim = minibatch_input_dim
        self.dropout = nn.Dropout2d(dropout)
        self.n_features = n_features
        self.feature_dim = feature_dim
        self.relu = relu
        self.repeat = repeat
        self.divisor = divisor
        self.activation = nn.LeakyReLU(self.relu)
        self.minibatch_linear = nn.Linear(minibatch_input_dim, n_features * feature_dim)
        # Input = flattened conv vector + two groups of n_features minibatch statistics.
        self.final_dense = nn.Linear(self.minibatch_input_dim + n_features * 2, 1)
        self.initial_conv = nn.Sequential(
            nn.Conv2d(3,n_map,7,1,3,bias=False), # =
            nn.BatchNorm2d(n_map),
            nn.LeakyReLU(self.relu),
            nn.Conv2d(n_map,n_map*2,8,2,3,bias=False), # /2
            nn.BatchNorm2d(n_map*2),
            nn.LeakyReLU(self.relu)
        )
        # Shortcut branch of the first residual stage (1x1 projection to 4*n_map).
        self.initial_block_conv = nn.Sequential(
            nn.Conv2d(n_map*2,n_map*4,1,1,0,bias=False), # =
            nn.BatchNorm2d(n_map*4)
        )
        # Bottleneck branch of the first residual stage.
        self.initial_block_bn = nn.Sequential(
            nn.Conv2d(n_map*2,n_map,1,1,0,bias=False), # =
            nn.BatchNorm2d(n_map),
            nn.LeakyReLU(self.relu),
            nn.Conv2d(n_map,n_map,3,1,1), # =
            nn.BatchNorm2d(n_map),
            nn.LeakyReLU(self.relu),
            nn.Conv2d(n_map,n_map*4,1,1,0,bias=False), # =
            nn.BatchNorm2d(n_map*4)
        )
        n = n_map * 4
        n_stages = self.reduction - 1
        # Per-stage shortcut branches: stride-2 1x1 conv, channels doubled.
        self.conv_blocks = nn.ModuleList(
            [self._normal_block(n*2**i, n*2**(i+1)) for i in range(n_stages)]
        )
        # Per-stage bottleneck branches that also halve the resolution.
        self.bn_shrink_blocks = nn.ModuleList(
            [self._bottleneck_block(n*2**i, n*2**(i+1), True) for i in range(n_stages)]
        )
        # `repeat` stride-1 bottlenecks per stage, stored flat: stage i owns
        # indices [i*repeat, (i+1)*repeat).
        self.bn_repeat_blocks = nn.ModuleList(
            [self._bottleneck_block(n*2**(i+1), n*2**(i+1), False)
             for i in range(n_stages) for _ in range(repeat)]
        )
        # reduction_p2 stride-2 blocks followed by the full-map convolution (last entry).
        self.final_conv = nn.ModuleList(
            [self._normal_block(self.final_n_map, self.final_n_map) for _ in range(self.reduction_p2)]
            + [nn.Conv2d(self.final_n_map, minibatch_input_dim, self.final_map_size, 1, 0)]
        )

    def _normal_block(self,in_c,out_c):
        """Return a 1x1 stride-2 convolution followed by batch norm (halves H and W)."""
        return nn.Sequential(nn.Conv2d(in_c,out_c,1,2,0,bias=False),nn.BatchNorm2d(out_c)) # /2

    def _bottleneck_block(self,in_c,out_c,shrink=True):
        """Return a 1x1 -> 3x3 -> projection bottleneck with ``out_c // divisor`` inner channels.

        Args:
            in_c: input channels.
            out_c: output channels.
            shrink: if True the last convolution uses kernel 2 / stride 2 and
                halves the resolution; otherwise kernel 1 / stride 1 keeps it.
        """
        # if shrink, output = (64 - 2)/2 + 1 = 32
        # if not shrink, output = (64 - 1) / 1 + 1 = 64
        final_config = 2 if shrink else 1
        middle_c = out_c // self.divisor
        block = nn.Sequential(
            nn.Conv2d(in_c,middle_c,1,1,0,bias=False), # =
            nn.BatchNorm2d(middle_c),
            nn.LeakyReLU(self.relu),
            nn.Conv2d(middle_c,middle_c,3,1,1), # =
            nn.BatchNorm2d(middle_c),
            nn.LeakyReLU(self.relu),
            nn.Conv2d(middle_c,out_c,final_config,final_config,0,bias=False), # /2 or =
            nn.BatchNorm2d(out_c)
        )
        return block

    def _minibatch_discrimination(self,x):
        """Append cross-sample similarity statistics to each sample's vector.

        Each sample is projected to ``n_features`` kernels of size
        ``feature_dim``. For every kernel, ``exp(-L1 distance)`` is computed
        between all pairs of samples, self-pairs are zeroed, and the batch axis
        is split into two halves. Each half is summed per (sample, kernel) and
        divided by that half's grand total.

        Args:
            x: tensor of shape ``(batch, minibatch_input_dim)``.

        Returns:
            Tensor of shape ``(batch, minibatch_input_dim + 2 * n_features)``:
            the input followed by the first-half and second-half statistics.
        """
        batch_size = x.shape[0]
        projected = self.minibatch_linear(x).view(-1, self.n_features, self.feature_dim)  # (B, F, D)
        # Broadcast (B, F, D, 1) against (1, F, D, B), then reduce D -> (B, F, B).
        l1 = torch.abs(projected.unsqueeze(3) - projected.permute(1, 2, 0).unsqueeze(0)).sum(dim=2)
        similarity = torch.exp(-l1)
        # Zero out each sample's similarity with itself. The mask is created on the
        # input's device so the module works on CPU and on any GPU index.
        mask = 1 - torch.eye(batch_size, device=x.device, dtype=similarity.dtype).unsqueeze(1)  # (B, 1, B)
        similarity = similarity * mask
        # Floor for the normaliser: with batch size 1 (or when every term underflows)
        # the total is 0 and an unguarded division would produce NaN.
        eps = torch.finfo(similarity.dtype).eps
        mid = batch_size // 2
        stats = []
        for part in (similarity[:, :, :mid], similarity[:, :, mid:]):
            stats.append(part.sum(dim=2) / part.sum().clamp_min(eps))  # (B, F)
        features = torch.cat(stats, dim=1)  # (B, 2F)
        return torch.cat([x, features], dim=1)

    def forward(self,x):
        """Score a batch of images.

        Args:
            x: image batch with 3 channels; spatial size is expected to match the
                ``img_size`` given at construction so the last convolution
                reduces the map to 1x1.

        Returns:
            Tensor of shape ``(batch, 1)`` with the raw output of the final
            linear layer (no activation is applied here).
        """
        x = self.dropout(self.initial_conv(x))
        shortcut = self.initial_block_conv(x)
        residual = self.initial_block_bn(x)
        x = self.dropout(self.activation(shortcut + residual))
        for i in range(self.reduction - 1):
            shortcut = self.conv_blocks[i](x)
            residual = self.bn_shrink_blocks[i](x)
            for j in range(self.repeat):
                residual = self.bn_repeat_blocks[i * self.repeat + j](residual)
            x = self.dropout(self.activation(shortcut + residual))
        # Extra downsampling stages are activated; the last conv is not.
        for i in range(self.reduction_p2):
            x = self.dropout(self.activation(self.final_conv[i](x)))
        x = self.dropout(self.final_conv[self.reduction_p2](x))
        x = torch.flatten(x, start_dim=1)
        x = self._minibatch_discrimination(x)
        return self.final_dense(x)
        


In [158]:
# Build a discriminator for 128x128 inputs and show its per-layer summary for one dummy image.
discriminator = resDiscriminator(reduction=4,repeat=0,n_map=32,reduction_p2=1)
# NOTE(review): the captured output shows the table twice, which looks like `summary`
# printing it and the notebook then echoing the returned object. The trailing `;`
# suppresses that echo so the table is shown once.
summary(discriminator,torch.zeros(1,3,128,128));

Layer (type:depth-idx)                   Output Shape              Param #
├─Sequential: 1-1                        [-1, 64, 64, 64]          --
|    └─Conv2d: 2-1                       [-1, 32, 128, 128]        4,704
|    └─BatchNorm2d: 2-2                  [-1, 32, 128, 128]        64
|    └─LeakyReLU: 2-3                    [-1, 32, 128, 128]        --
|    └─Conv2d: 2-4                       [-1, 64, 64, 64]          131,072
|    └─BatchNorm2d: 2-5                  [-1, 64, 64, 64]          128
|    └─LeakyReLU: 2-6                    [-1, 64, 64, 64]          --
├─Dropout2d: 1-2                         [-1, 64, 64, 64]          --
├─Sequential: 1-3                        [-1, 128, 64, 64]         --
|    └─Conv2d: 2-7                       [-1, 128, 64, 64]         8,192
|    └─BatchNorm2d: 2-8                  [-1, 128, 64, 64]         256
├─Sequential: 1-4                        [-1, 128, 64, 64]         --
|    └─Conv2d: 2-9                       [-1, 32, 64, 64]          2,048

Layer (type:depth-idx)                   Output Shape              Param #
├─Sequential: 1-1                        [-1, 64, 64, 64]          --
|    └─Conv2d: 2-1                       [-1, 32, 128, 128]        4,704
|    └─BatchNorm2d: 2-2                  [-1, 32, 128, 128]        64
|    └─LeakyReLU: 2-3                    [-1, 32, 128, 128]        --
|    └─Conv2d: 2-4                       [-1, 64, 64, 64]          131,072
|    └─BatchNorm2d: 2-5                  [-1, 64, 64, 64]          128
|    └─LeakyReLU: 2-6                    [-1, 64, 64, 64]          --
├─Dropout2d: 1-2                         [-1, 64, 64, 64]          --
├─Sequential: 1-3                        [-1, 128, 64, 64]         --
|    └─Conv2d: 2-7                       [-1, 128, 64, 64]         8,192
|    └─BatchNorm2d: 2-8                  [-1, 128, 64, 64]         256
├─Sequential: 1-4                        [-1, 128, 64, 64]         --
|    └─Conv2d: 2-9                       [-1, 32, 64, 64]          2,048

In [None]:
class resGenerator(nn.Module):
    