In [1]:
import torch
from torch import nn

In [18]:
class ConvLReLU(nn.Module):
    def __init__(self,
                 in_channels, out_channels, kernel_size=3,
                 stride=2, padding=1, padding_mode='reflect'):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                          stride=stride, padding=padding, padding_mode=padding_mode)
        self.lrelu = nn.LeakyReLU(0.2)
        
    def forward(self, x):
        x = self.conv(x)
        x = self.lrelu(x)
        return x

In [2]:
class UpConvInReLU(nn.Module):
    def __init__(self, in_channels, out_channels,
                 kernel_size=3, stride=2,
                 padding=1, output_padding=1):
        super().__init__()
        self.upconv = nn.ConvTranspose2d(in_channels, out_channels,
                                   kernel_size=kernel_size, stride=stride,
                                   padding=padding, output_padding=output_padding)
        self.insnorm = nn.InstanceNorm2d(out_channels)
        self.relu = nn.ReLU()
        
    def forward(self, x):
        x = self.upconv(x)
        x = self.insnorm(x)
        x = self.relu(x)
        return x

In [3]:
class ConvTanh(nn.Module):
    def __init__(self, in_channels, out_channels,
                       kernel_size=7, stride=1,
                       padding=3, padding_mode='reflect'):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                             stride=stride, padding=padding, padding_mode=padding_mode)
        self.tanh = nn.Tanh()
        
    def forward(self, x):
        x = self.conv(x)
        x = self.tanh(x)
        return x

In [4]:
class ConvIn(nn.Module):
    def __init__(self,
                 in_channels, out_channels, kernel_size=3,
                 stride=2, padding=1, padding_mode='reflect'):
        super().__init__()
        self.conv=nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                          stride=stride, padding=padding, padding_mode=padding_mode)
        self.insnorm=nn.InstanceNorm2d(out_channels)
        
    def forward(self, x):
        x = self.conv(x)
        x = self.insnorm(x)
        return x

In [5]:
class ConvInReLU(nn.Module):
    def __init__(self,
                 in_channels, out_channels, kernel_size=3,
                 stride=2, padding=1, padding_mode='reflect'):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                          stride=stride, padding=padding, padding_mode=padding_mode)
        self.insnorm = nn.InstanceNorm2d(out_channels, affine=False)
        self.relu = nn.ReLU(inplace=True)
        
    def forward(self, x):
        x = self.conv(x)
        x = self.insnorm(x)
        x = self.relu(x)
        return x

In [6]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.conv_in_relu = ConvInReLU(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
        self.conv_in = ConvIn(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        fx = self.conv_in_relu(x)
        fx = self.conv_in(fx)
        return fx + x

In [7]:
class GlobalGenerator(nn.Module):
    def __init__(self, in_channels=3, hid_channels=64, out_channels=3):
        super().__init__()
        
        self.Input= ConvInReLU(in_channels, hid_channels, kernel_size=7,
                               stride=1, padding=3, padding_mode='reflect')
        
        self.DownStage=nn.Sequential()
        self.DownStage.add_module('block0', ConvInReLU(1*hid_channels, 2*hid_channels))
        self.DownStage.add_module('block1', ConvInReLU(2*hid_channels, 4*hid_channels))
        self.DownStage.add_module('block2', ConvInReLU(4*hid_channels, 8*hid_channels))
        
        self.ResStage=nn.Sequential()
        self.ResStage.add_module('block0', ResidualBlock(8*hid_channels))
        self.ResStage.add_module('block1', ResidualBlock(8*hid_channels))
        self.ResStage.add_module('block2', ResidualBlock(8*hid_channels))
        self.ResStage.add_module('block3', ResidualBlock(8*hid_channels))
        self.ResStage.add_module('block4', ResidualBlock(8*hid_channels))
        self.ResStage.add_module('block5', ResidualBlock(8*hid_channels))
        self.ResStage.add_module('block6', ResidualBlock(8*hid_channels))
        self.ResStage.add_module('block7', ResidualBlock(8*hid_channels))
        self.ResStage.add_module('block8', ResidualBlock(8*hid_channels))
        
        self.UpStage = nn.Sequential()
        self.UpStage.add_module('block0', UpConvInReLU(8*hid_channels, 4*hid_channels))
        self.UpStage.add_module('block1', UpConvInReLU(4*hid_channels, 2*hid_channels))
        self.UpStage.add_module('block2', UpConvInReLU(2*hid_channels, 1*hid_channels))
        
        self.Output=ConvTanh(hid_channels, out_channels)
        
    def forward(self, x):
        x = self.Input(x)
        x = self.DownStage(x)
        x = self.ResStage(x)
        x = self.UpStage(x)
        x = self.Output(x)
        return x

In [13]:
gen = GlobalGenerator()
list(gen.children())[:-1]

[ConvInReLU(
   (conv): Conv2d(3, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), padding_mode=reflect)
   (insnorm): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
   (relu): ReLU(inplace=True)
 ),
 Sequential(
   (block0): ConvInReLU(
     (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), padding_mode=reflect)
     (insnorm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
     (relu): ReLU(inplace=True)
   )
   (block1): ConvInReLU(
     (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), padding_mode=reflect)
     (insnorm): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
     (relu): ReLU(inplace=True)
   )
   (block2): ConvInReLU(
     (conv): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), padding_mode=reflect)
     (insnorm): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_r

In [14]:
class LocalEnhancer(nn.Module):
    def __init__(self, in_channels=3, hid_channels=32, out_channels=3):
        super().__init__()
        
        self.Input=ConvInReLU(in_channels, hid_channels, kernel_size=7,
                              stride=1, padding=3, padding_mode='reflect')
        
        self.DownStage=nn.Sequential()
        self.DownStage.add_module('block0', ConvInReLU(1*hid_channels, 2*hid_channels))
        
        self.AvgPool=nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False)
        generator=GlobalGenerator(in_channels, 2*hid_channels, out_channels)
        self.InsideGenerator=nn.Sequential(*list(generator.children())[:-1])
        
        self.ResStage=nn.Sequential()
        self.ResStage.add_module('block0', ResidualBlock(2*hid_channels))
        self.ResStage.add_module('block1', ResidualBlock(2*hid_channels))
        self.ResStage.add_module('block2', ResidualBlock(2*hid_channels))
        
        self.UpStage = nn.Sequential()
        self.UpStage.add_module('block0', UpConvInReLU(2*hid_channels, 1*hid_channels))
        
        self.Output=ConvTanh(hid_channels, out_channels)
        
    def forward(self, x):
        y1 = self.Input(x)
        y1 = self.DownStage(y1)
        
        y0 = self.AvgPool(x)
        y0 = self.InsideGenerator(y0)
        
        y = self.ResStage(y0+y1)
        y = self.UpStage(y)
        y = self.Output(y)
        return y

In [15]:
model = LocalEnhancer()
p = model(torch.rand(1, 3, 1024, 2048))

In [51]:
p.shape

torch.Size([1, 3, 1024, 2048])

In [36]:
class Discriminator(nn.Module):

    def __init__(self, in_channels=3, hid_channels=64, out_channels=1):
        super().__init__()
        self.Input = ConvLReLU(in_channels, hid_channels,
                              kernel_size=4, stride=2, padding=2)
        
        self.Base = nn.Sequential()
        self.Base.add_module('block0', ConvInReLU(hid_channels, 2*hid_channels,
                                                 kernel_size=4, stride=2, padding=2))
        self.Base.add_module('block1', ConvInReLU(2*hid_channels, 4*hid_channels,
                                                 kernel_size=4, stride=2, padding=2))
        self.Base.add_module('block2', ConvInReLU(4*hid_channels, 8*hid_channels,
                                                 kernel_size=4, stride=1, padding=2))
        
        self.Output = nn.Conv2d(8*hid_channels, out_channels,
                                kernel_size=4, stride=1, padding=2)
        
    def forward(self, x):
        x0 = self.Input(x)
        x1 = self.Base.block0(x0)
        x2 = self.Base.block1(x1)
        x_ = self.Base.block2(x2)
        x3 = self.Output(x_)
        return [x0, x1, x2, x3]

In [37]:
model = Discriminator()

In [38]:
p = model(torch.rand(1, 3, 256, 256))

In [39]:
len(p)

4

In [40]:
for pi in p:
    print(pi.shape)

torch.Size([1, 64, 129, 129])
torch.Size([1, 128, 65, 65])
torch.Size([1, 256, 33, 33])
torch.Size([1, 1, 35, 35])


In [43]:
class MultiScaleDiscriminator(nn.Module):

    def __init__(self, in_channels=3, hid_channels=64, out_channels=1):
        super().__init__()
        
        self.discriminator0=Discriminator(in_channels, hid_channels,
                                          out_channels)
        self.discriminator1=Discriminator(in_channels, hid_channels,
                                          out_channels)
        self.discriminator2=Discriminator(in_channels, hid_channels,
                                          out_channels)
        
        self.AvgPool=nn.AvgPool2d(3, stride=2, padding=1,
                                  count_include_pad=False)
        
    def forward(self, x):
        x0 = self.discriminator0(x)
        x1_ = self.AvgPool(x)
        x1 = self.discriminator1(x1_)
        x2_ = self.AvgPool(x1_)
        x2 = self.discriminator2(x2_)
        return [x0, x1, x2]

In [44]:
d = MultiScaleDiscriminator()

In [45]:
p = d(torch.rand(1, 3, 256,256))

In [46]:
len(p)

3

In [49]:
for pi in p:
    for pj in pi:
        
        print(pj.shape)
    print()

torch.Size([1, 64, 129, 129])
torch.Size([1, 128, 65, 65])
torch.Size([1, 256, 33, 33])
torch.Size([1, 1, 35, 35])

torch.Size([1, 64, 65, 65])
torch.Size([1, 128, 33, 33])
torch.Size([1, 256, 17, 17])
torch.Size([1, 1, 19, 19])

torch.Size([1, 64, 33, 33])
torch.Size([1, 128, 17, 17])
torch.Size([1, 256, 9, 9])
torch.Size([1, 1, 11, 11])



In [51]:
class InstanceWiseAvgPool(nn.Module):
    def __init__(self, n_channels):
        super().__init__()
        self.n_channels=n_channels
    
    def forward(self, x, instance):
        x_mean = torch.zeros_like(x)
        classes = torch.unique(instance, return_inverse=False,
                               return_counts=False)
        nB = x.size(0)
        for c in classes:
            for b in range(nB):
                idx = torch.nonzero(instance[b:b+1]==c, as_tuple=False)
                for k in range(self.n_channels):
                    x_ins = x[idx[:, 0] + b, idx[:, 1] + k, idx[:, 2], idx[:, 3]]
                    mean_feat = torch.mean(x_ins).expand_as(x_ins)
                    x_mean[idx[:, 0] + b, idx[:, 1] + k, idx[:, 2], idx[:, 3]]=mean_feat
        return x_mean

In [65]:
class Encoder(nn.Module):

    def __init__(self, in_channels=3, hid_channels=16, out_channels=1):
        super().__init__()
        self.Input = ConvInReLU(in_channels, hid_channels,
                               kernel_size=7, stride=1, padding=3)
        
        self.DownStage = nn.Sequential()
        self.DownStage.add_module('block0', ConvInReLU(1*hid_channels, 2*hid_channels,
                                               kernel_size=3, stride=2, padding=1))
        self.DownStage.add_module('block1', ConvInReLU(2*hid_channels, 4*hid_channels,
                                               kernel_size=3, stride=2, padding=1))
        self.DownStage.add_module('block2', ConvInReLU(4*hid_channels, 8*hid_channels,
                                               kernel_size=3, stride=2, padding=1))
        self.DownStage.add_module('block3', ConvInReLU(8*hid_channels, 16*hid_channels,
                                               kernel_size=3, stride=2, padding=1))
        
        self.UpStage = nn.Sequential()
        self.UpStage.add_module('block0', UpConvInReLU(16*hid_channels, 8*hid_channels,
                                                      kernel_size=3, stride=2, padding=1))
        self.UpStage.add_module('block1', UpConvInReLU(8*hid_channels, 4*hid_channels,
                                                      kernel_size=3, stride=2, padding=1))
        self.UpStage.add_module('block2', UpConvInReLU(4*hid_channels, 2*hid_channels,
                                                      kernel_size=3, stride=2, padding=1))
        self.UpStage.add_module('block3', UpConvInReLU(2*hid_channels, 1*hid_channels,
                                                      kernel_size=3, stride=2, padding=1))
        
        self.Output=ConvTanh(hid_channels, out_channels,
                            kernel_size=7, stride=1, padding=3)
        
        #self.InstAvgPool=InstanceWiseAvgPool(out_channels)
    def instancewise_average_pooling(self, x, inst):
        '''
        Applies instance-wise average pooling.

        Given a feature map of size (b, c, h, w), the mean is computed for each b, c
        across all h, w of the same instance
        '''
        x_mean = torch.zeros_like(x)
        classes = torch.unique(inst, return_inverse=False, return_counts=False) # gather all unique classes present

        for i in classes:
            for b in range(x.size(0)):
                indices = torch.nonzero(inst[b:b+1] == i, as_tuple=False) # get indices of all positions equal to class i
                for j in range(1):
                    x_ins = x[indices[:, 0] + b, indices[:, 1] + j, indices[:, 2], indices[:, 3]]
                    mean_feat = torch.mean(x_ins).expand_as(x_ins)
                    x_mean[indices[:, 0] + b, indices[:, 1] + j, indices[:, 2], indices[:, 3]] = mean_feat

        return x_mean
        
    def forward(self, x, instance):
        x0 = self.Input(x)
        x1 = self.DownStage(x0)
        x2 = self.UpStage(x1)
        x3 = self.Output(x2)
        #y = self.InstAvgPool(x3, instance)
        y = self.instancewise_average_pooling(x3, instance)
        return y

In [66]:
m = Encoder()

In [67]:
p = m(torch.rand(10, 3, 256, 256), torch.rand(10, 3, 256, 256))

IndexError: index 2 is out of bounds for dimension 1 with size 1