In [33]:
import torch
import os
import torch.nn as nn
import torchvision
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
import torchvision.transforms as transforms
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

In [2]:
def D_loss(G, D, X, Y):
    return torch.sum(torch.pow((D(Y) - 1), 2) + torch.pow(D(G(X)), 2)).item()


def G_loss(G, D, X, Y):
    return torch.pow((D(G(X)) - 1), 2).item()


def CC_loss(F, G, X, Y):
    return torch.sum(torch.abs(F(G(X)) - X)).item() + torch.sum(torch.abs(G(F(Y)) - Y)).item()

In [3]:
def get_norm_module(name):
    if name == "batch":
        return nn.BatchNorm2d
    elif name == "instance":
        return nn.InstanceNorm2d
    else:
        return None

In [4]:
class ConvNormRelu(nn.Module):
    
    def __init__(self, in_channels, out_channels, kernel_size, padding=(1, "zeros"),
                 stride=1, norm="batch", leaky=True, conv_type="forward"):
        super(ConvNormRelu, self).__init__()
        if padding[1] == "zeros":
            self.pad = None
            if conv_type == "forward":
                self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                      kernel_size=kernel_size, stride=stride, padding=padding[0])
            elif conv_type == "transpose":
                self.conv = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels,
                                      kernel_size=kernel_size, stride=stride, padding=padding[0], output_padding=padding[0])
        elif padding[1] == "reflection":
            if conv_type == "forward":
                self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                      kernel_size=kernel_size, stride=stride)
                self.pad = nn.ReflectionPad2d(padding[0])
            elif conv_type == "transpose":
                self.conv = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels,
                                      kernel_size=kernel_size, stride=stride, padding=padding[0], output_padding=padding[0])
                self.pad = None
            
            
        self.leaky = leaky
        if norm:
            self.norm = get_norm_module(norm)(out_channels)
        else:
            self.norm = None
        
    def forward(self, inputs):
        out = inputs
        if self.pad is not None:
            out = self.pad(out)
        print(out.shape)
        out = self.conv(out)
        print(out.shape)
        if self.norm is not None:
            out = self.norm(out)
        if self.leaky:
            return F.leaky_relu(out, negative_slope=0.2)
        else:
            return F.relu(out)

In [5]:
def init_weights(m):
    if type(m) == nn.Conv2d:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif type(m) == nn.InstanceNorm2d:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.normal_(m.bias.data, 0)

In [6]:
#Figure out real PatchGan
class PatchGan(nn.Module):
    
    def __init__(self, input_channels):
        super(PatchGan, self).__init__()
        
        self.layer1 = ConvNormRelu(in_channels=input_channels, out_channels=64, kernel_size=4,
                                        padding=(1, "zeros"), stride=2, norm=None)
        self.layer2 = ConvNormRelu(in_channels=64, out_channels=128, kernel_size=4,
                                        padding=(1, "zeros"), stride=2, norm="instance")
        self.layer3 = ConvNormRelu(in_channels=128, out_channels=256, kernel_size=4,
                                        padding=(1, "zeros"), stride=2, norm="instance")
        #self.layer4 = ConvBatchNormRelu(in_channels=256, out_channels=512, kernel_size=4,
         #                               padding=1, stride=2, batch_norm=True)
        self.layer4 = ConvNormRelu(in_channels=256, out_channels=512, kernel_size=4,
                                        padding=(1, "zeros"), stride=1, norm="instance")
        
        self.conv_fc = nn.Conv2d(in_channels=512, out_channels=1, kernel_size=4,
                                 padding=1, stride=1)
    
    def forward(self, inputs):
        out = self.layer1(inputs)
        print(out.shape)
        out = self.layer2(out)
        print(out.shape)
        out = self.layer3(out)
        print(out.shape)
        out = self.layer4(out)
        print(out.shape)
        out = self.conv_fc(out)
        print(out.shape)
        return F.sigmoid(out)

In [7]:
data = torch.rand((1, 64, 70, 70))
#model = PatchGan(64)
layer = ConvNormRelu(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=(1, "reflection"),
                     conv_type="transpose")
layer(data).shape

torch.Size([1, 64, 70, 70])
torch.Size([1, 128, 140, 140])


torch.Size([1, 128, 140, 140])

In [8]:
class ResBlock(nn.Module):
    
    def __init__(self, in_planes, norm="batch"):
        super(ResBlock, self).__init__()
        self.pad1 = nn.ReflectionPad2d(1)
        self.pad2 = nn.ReflectionPad2d(1)
        self.norm1 = get_norm_module(norm)(in_planes)
        self.norm2 = get_norm_module(norm)(in_planes)
        self.conv1 = nn.Conv2d(in_channels=in_planes, out_channels=in_planes, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=in_planes, out_channels=in_planes, kernel_size=3)
        
    def forward(self, inputs):
        out = self.conv1(self.pad1(inputs))
        out = F.relu(self.norm1(out))
        out = self.conv2(self.pad2(out))
        out = self.norm2(out)
        return out + inputs

In [9]:
class ResnetGenerator(nn.Module):
    
    def __init__(self, in_channels, n_blocks):
        super(ResnetGenerator, self).__init__()
        
        self.conv1 = ConvNormRelu(in_channels=in_channels, out_channels=64, kernel_size=7,
                                       padding=(3, "reflection"), stride=1, norm="instance", leaky=False)
        self.conv2 = ConvNormRelu(in_channels=64, out_channels=128, kernel_size=3,
                                  padding=(1, "reflection"), stride=2, norm="instance", leaky=False)
        self.conv3 = ConvNormRelu(in_channels=128, out_channels=256, kernel_size=3,
                                  padding=(1, "reflection"), stride=2, norm="instance", leaky=False)
        self.blocks = []
        for i in range(n_blocks):
            self.blocks.append(ConvNormRelu(in_channels=256, out_channels=256, kernel_size=3,
                                            padding=(1, "reflection"), stride=1, norm="instance", leaky=False))
        self.conv4 = ConvNormRelu(in_channels=256, out_channels=128, kernel_size=3, 
                                  padding=(1, "reflection"), stride=2, norm="instance", leaky=False, conv_type="transpose")
        self.conv5 = ConvNormRelu(in_channels=128, out_channels=64, kernel_size=3, 
                                  padding=(1, "reflection"), stride=2, norm="instance", leaky=False, conv_type="transpose")
        self.conv6 = ConvNormRelu(in_channels=64, out_channels=3, kernel_size=7, 
                                  padding=(3, "reflection"), stride=1, norm="instance", leaky=False)
        
    def forward(self, inputs):
        out = self.conv1(inputs)
        print(out.shape)
        out = self.conv2(out)
        print(out.shape)
        out = self.conv3(out)
        print(out.shape)
        for block in self.blocks:
            out = block(out)
            print(out.shape)
        out = self.conv4(out)
        print(out.shape)
        out = self.conv5(out)
        print(out.shape)
        out = self.conv6(out)
        print(out.shape)
        return F.tanh(out)

In [10]:
layer = ConvNormRelu(in_channels=48, out_channels=128, kernel_size=3, 
                                  padding=(1, "reflection"), stride=2, norm="instance", leaky=False, conv_type="transpose")

In [11]:
data = torch.rand((1, 48, 256, 256))
print(layer(data).shape)

torch.Size([1, 48, 256, 256])
torch.Size([1, 128, 512, 512])
torch.Size([1, 128, 512, 512])


In [12]:
def generator_optimize_step(G1, G2, D2, inputs):
    optimizer = optim.Adam(G1.parameters(), lr=0.0002, betas=(0.5, 0.999))
    discr_output = D2(G1(inputs))
    adv_loss = F.mse_loss(discr_output, torch.ones(discr_output.shape))
    identity_loss = F.l1_loss(G2(inputs), inputs)
    fwd_cycle_loss = F.l1_loss(G2(G1(inputs)), inputs)
    bwd_cycle_loss = F.l1_loss(G1(G2(inputs)), inputs)
    loss = adv_loss + 5 * identity_loss + 10 * (fwd_cycle_loss + bwd_cycle_loss)
    loss.backward()
    optimizer.step()
    
    optimizer.zero_grad()

In [13]:
def discriminator_optimize_step(G, D, inputs_domain, other_domain):
    optimizer = optim.Adam(D.parameters, lr=0.0002, betas=(0.5, 0.999))
    discr_output = D(inputs_domain)
    d_loss = F.mse_loss(discr_output, torch.ones(discr_output.shape))
    gen_output = G(other_domain)
    g_loss = F.mse_loss(gen_output, torch.zeros(gen_output.shape))
    loss = d_loss + g_loss
    loss.backward()
    optimizer.step()
    
    optimizer.zero_grad()

In [46]:
class Apple2OrangeDataset(Dataset):
    
    def __init__(self, root, folder_names, transform=None):
        super(Apple2OrangeDataset, self).__init__()
        
        self.root = root
        self.transform = transform
        self.folder_names = folder_names
        self.A_size = len(os.listdir(os.path.join(root, folder_names[0])))
        self.B_size = len(os.listdir(os.path.join(root, folder_names[1])))
        self.A_paths = sorted(os.listdir(os.path.join(root, folder_names[0])))
        self.B_paths = sorted(os.listdir(os.path.join(root, folder_names[1])))
        #print(self.A_paths)
        
    def __len__(self):
        return max(self.A_size, self.B_size)
    
    def __getitem__(self, idx):
        idx_A = idx % self.A_size
        idx_B = idx % self.B_size
        image_A = Image.open(os.path.join(self.root, self.folder_names[0], self.A_paths[idx_A]))
        image_B = Image.open(os.path.join(self.root, self.folder_names[1], self.B_paths[idx_B]))
        if self.transform is not None:
            image_A = self.transform(image_A)
            image_B = self.transform(image_B)
            
        return {"A": image_A, "B": image_B}

In [52]:
transform = transforms.Compose([transforms.Resize((128, 128)),
                                transforms.ToTensor(),
                                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

In [53]:
dataset = Apple2OrangeDataset('/home/dpakhom1/Cycle_gan_pytorch/datasets/apple2orange/',
                              ["trainA", "trainB"], transform=transform)

In [55]:
dataloader = DataLoader(dataset, shuffle=True, batch_size=1, num_workers=2)

In [56]:
for idx, data in enumerate(dataloader):
    print(data)
    if idx == 5:
        break

{'A': tensor([[[[ 2.1633,  2.0777,  2.0263,  ..., -2.0323, -2.0665, -2.1179],
          [ 2.1290,  2.0777,  2.0092,  ..., -1.8439, -2.0323, -2.1008],
          [ 2.0948,  1.9407,  1.9235,  ..., -1.6898, -1.9809, -2.0665],
          ...,
          [-2.1179, -2.1179, -2.1179,  ..., -1.6555, -1.7925, -1.8268],
          [-2.1179, -2.1179, -2.1179,  ..., -1.7925, -1.8953, -1.9295],
          [-2.1179, -2.1179, -2.1008,  ..., -1.9124, -1.8268, -1.8268]],

         [[ 2.1485,  1.6408,  1.4482,  ..., -2.0007, -2.0182, -2.0182],
          [ 1.9559,  1.6408,  1.3431,  ..., -2.0007, -2.0182, -2.0182],
          [ 1.7458,  1.0455,  1.0630,  ..., -1.9832, -2.0007, -2.0182],
          ...,
          [-2.0357, -2.0357, -2.0357,  ..., -1.8431, -1.9132, -1.8606],
          [-2.0357, -2.0357, -2.0357,  ..., -1.8782, -1.9132, -1.9132],
          [-2.0357, -2.0357, -2.0357,  ..., -1.9132, -1.8606, -1.8606]],

         [[ 2.5180,  2.3611,  2.2043,  ..., -1.7870, -1.8044, -1.8044],
          [ 2.4657,  2.3

{'A': tensor([[[[ 1.2043,  1.0159,  0.9988,  ...,  1.8037,  1.8379,  1.7865],
          [ 1.1700,  1.3927,  1.3927,  ...,  1.9920,  1.8550,  1.8037],
          [ 1.4440,  1.4098,  1.4269,  ...,  2.1633,  1.8550,  1.6153],
          ...,
          [-1.1075,  0.0227,  1.0331,  ...,  1.9407,  2.1462,  2.1119],
          [-0.6281, -0.7822,  0.5364,  ...,  2.1633,  1.9920,  1.7009],
          [-0.2856, -1.3473, -0.4568,  ...,  2.0434,  1.6667,  1.5468]],

         [[ 0.9405,  0.7654,  0.7304,  ...,  1.7283,  1.6057,  1.6758],
          [ 0.9230,  1.2031,  1.2206,  ...,  1.9909,  1.6758,  1.6758],
          [ 1.1856,  1.1856,  1.2381,  ...,  2.1835,  1.8508,  1.2381],
          ...,
          [-1.4405, -0.3550,  0.3452,  ...,  2.0609,  2.3060,  2.1134],
          [-1.1078, -1.0378,  0.0301,  ...,  2.2360,  2.0259,  1.3957],
          [-0.7927, -1.5805, -0.7227,  ...,  2.1134,  1.2556,  1.1506]],

         [[ 1.0191,  0.8797,  0.8274,  ...,  1.6814,  1.7511,  1.7511],
          [ 0.9842,  1.3

Exception ignored in: <function _DataLoaderIter.__del__ at 0x7f3082b0c268>
Traceback (most recent call last):
  File "/home/dpakhom1/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 399, in __del__
    self._shutdown_workers()
  File "/home/dpakhom1/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 378, in _shutdown_workers
    self.worker_result_queue.get()
  File "/home/dpakhom1/anaconda3/lib/python3.7/multiprocessing/queues.py", line 354, in get
    return _ForkingPickler.loads(res)
  File "/home/dpakhom1/anaconda3/lib/python3.7/site-packages/torch/multiprocessing/reductions.py", line 151, in rebuild_storage_fd
    fd = df.detach()
  File "/home/dpakhom1/anaconda3/lib/python3.7/multiprocessing/resource_sharer.py", line 57, in detach
    with _resource_sharer.get_connection(self._id) as conn:
  File "/home/dpakhom1/anaconda3/lib/python3.7/multiprocessing/resource_sharer.py", line 87, in get_connection
    c = Client(address, authke