In [1]:
import torch
import os
import torch.nn as nn
import torchvision
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
import torch.nn.functional as F
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

In [2]:
def D_loss(G, D, X, Y):
    return torch.sum(torch.pow((D(Y) - 1), 2) + torch.pow(D(G(X)), 2)).item()


def G_loss(G, D, X, Y):
    return torch.pow((D(G(X)) - 1), 2).item()


def CC_loss(F, G, X, Y):
    return torch.sum(torch.abs(F(G(X)) - X)).item() + torch.sum(torch.abs(G(F(Y)) - Y)).item()

In [3]:
def get_norm_module(name):
    if name == "batch":
        return nn.BatchNorm2d
    elif name == "instance":
        return nn.InstanceNorm2d
    else:
        return None

In [4]:
class ConvNormRelu(nn.Module):
    
    def __init__(self, in_channels, out_channels, kernel_size, padding, stride, norm="batch"):
        super(ConvNormRelu, self).__init__()
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                              kernel_size=kernel_size, stride=stride, padding=padding)
        if norm:
            self.norm = get_norm_module(norm)(out_channels)
        else:
            self.norm = None
        
    def forward(self, inputs):
        out = self.conv(inputs)
        if self.norm is not None:
            out = self.norm(out)
        return F.leaky_relu(out, negative_slope=0.2)

In [5]:
class PatchGan(nn.Module):
    
    def __init__(self, input_channels):
        super(PatchGan, self).__init__()
        
        self.layer1 = ConvNormRelu(in_channels=input_channels, out_channels=64, kernel_size=4,
                                        padding=1, stride=2, norm=None)
        self.layer2 = ConvNormRelu(in_channels=64, out_channels=128, kernel_size=4,
                                        padding=1, stride=2, norm="instance")
        self.layer3 = ConvNormRelu(in_channels=128, out_channels=256, kernel_size=4,
                                        padding=1, stride=2, norm="instance")
        #self.layer4 = ConvBatchNormRelu(in_channels=256, out_channels=512, kernel_size=4,
         #                               padding=1, stride=2, batch_norm=True)
        self.layer4 = ConvNormRelu(in_channels=256, out_channels=512, kernel_size=4,
                                        padding=1, stride=1, norm="instance")
        
        self.conv_fc = nn.Conv2d(in_channels=512, out_channels=1, kernel_size=4,
                                 padding=1, stride=1)
    
    def forward(self, inputs):
        out = self.layer1(inputs)
        print(out.shape)
        out = self.layer2(out)
        print(out.shape)
        out = self.layer3(out)
        print(out.shape)
        out = self.layer4(out)
        print(out.shape)
        out = self.conv_fc(out)
        print(out.shape)
        return F.sigmoid(out)

In [6]:
data = torch.rand((1, 64, 70, 70))
model = PatchGan(64)
model(data)

torch.Size([1, 64, 35, 35])
torch.Size([1, 128, 17, 17])
torch.Size([1, 256, 8, 8])
torch.Size([1, 512, 7, 7])
torch.Size([1, 1, 6, 6])




tensor([[[[0.5182, 0.4908, 0.4183, 0.5620, 0.4957, 0.4948],
          [0.4155, 0.5736, 0.5308, 0.4960, 0.3603, 0.4201],
          [0.5741, 0.3917, 0.6676, 0.4266, 0.4624, 0.4744],
          [0.6440, 0.5472, 0.4961, 0.4115, 0.4901, 0.5817],
          [0.6026, 0.5980, 0.3749, 0.3671, 0.4852, 0.3585],
          [0.4968, 0.4225, 0.4962, 0.4581, 0.4968, 0.5066]]]],
       grad_fn=<SigmoidBackward>)

In [10]:
class ResBlock(nn.Module):
    
    def __init__(self, in_planes, norm="batch"):
        super(ResBlock, self).__init__()
        self.pad1 = nn.ReflectionPad2d(1)
        self.pad2 = nn.ReflectionPad2d(1)
        self.norm1 = get_norm_module(norm)(in_planes)
        self.norm2 = get_norm_module(norm)(in_planes)
        self.conv1 = nn.Conv2d(in_channels=in_planes, out_channels=in_planes, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=in_planes, out_channels=in_planes, kernel_size=3)
        
    def forward(self, inputs):
        out = self.conv1(self.pad1(inputs))
        out = F.relu(self.norm1(out))
        out = self.conv2(self.pad2(out))
        out = self.norm2(out)
        return out + inputs

In [11]:
model = ResBlock(10)
data = torch.rand(1, 10, 5, 5)
print(model(data).shape)

torch.Size([1, 10, 5, 5])
