In [1]:
import torch
import os
import torch.nn as nn
import torchvision
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
# Expose only the GPU with index 0 to CUDA for this process.
# NOTE(review): this is set after `import torch`; it presumably still takes effect
# because no CUDA call has been made yet at this point - confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [2]:
def get_norm_module(name):
    """Map a normalization name to the matching ``torch.nn`` layer class.

    Args:
        name: ``"batch"`` or ``"instance"``; any other value is treated as unknown.

    Returns:
        ``nn.BatchNorm2d`` for ``"batch"``, ``nn.InstanceNorm2d`` for ``"instance"``,
        and ``None`` otherwise. The class itself is returned (not an instance), so
        the caller still has to call it with the channel count.
    """
    known_layers = (("batch", nn.BatchNorm2d), ("instance", nn.InstanceNorm2d))
    for key, layer_cls in known_layers:
        if name == key:
            return layer_cls
    return None

In [3]:
class ConvNormRelu(nn.Module):
    """Convolution (or transposed convolution) -> optional norm -> (leaky) ReLU.

    Args:
        in_channels: number of input feature channels.
        out_channels: number of output feature channels.
        kernel_size: kernel size handed to the convolution.
        padding: ``(amount, mode)`` pair where ``mode`` is ``"zeros"`` or
            ``"reflection"``. For forward convolutions ``"zeros"`` uses the
            convolution's own padding and ``"reflection"`` inserts an
            ``nn.ReflectionPad2d(amount)`` in front of an unpadded convolution.
            For transposed convolutions the mode is ignored and ``amount`` is
            used as both ``padding`` and ``output_padding``.
        stride: convolution stride.
        norm: ``"batch"``, ``"instance"``, or a falsy value for no normalization.
        leaky: if true use ``leaky_relu`` with slope 0.2, otherwise plain ``relu``.
        conv_type: ``"forward"`` (``nn.Conv2d``) or ``"transpose"``
            (``nn.ConvTranspose2d``).

    Raises:
        ValueError: if the padding mode, ``conv_type`` or ``norm`` name is not
            recognised. Raising here avoids a half-initialised module that would
            only fail later inside ``forward``.
    """

    _PADDING_MODES = ("zeros", "reflection")
    _CONV_TYPES = ("forward", "transpose")

    def __init__(self, in_channels, out_channels, kernel_size, padding=(1, "zeros"),
                 stride=1, norm="batch", leaky=True, conv_type="forward"):
        super(ConvNormRelu, self).__init__()
        pad_amount, pad_mode = padding[0], padding[1]
        if pad_mode not in self._PADDING_MODES:
            raise ValueError("unknown padding mode %r; expected one of %r"
                             % (pad_mode, self._PADDING_MODES))
        if conv_type not in self._CONV_TYPES:
            raise ValueError("unknown conv_type %r; expected one of %r"
                             % (conv_type, self._CONV_TYPES))

        self.pad = None
        if conv_type == "transpose":
            # Both padding modes build the same transposed convolution. With
            # kernel 3 / padding 1 / stride 2 this exactly doubles the spatial size.
            self.conv = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels,
                                           kernel_size=kernel_size, stride=stride,
                                           padding=pad_amount, output_padding=pad_amount)
        elif pad_mode == "reflection":
            # Padding is applied by a separate reflection layer, so the conv is unpadded.
            self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                  kernel_size=kernel_size, stride=stride)
            self.pad = nn.ReflectionPad2d(pad_amount)
        else:
            self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                  kernel_size=kernel_size, stride=stride, padding=pad_amount)

        self.leaky = leaky
        if norm:
            norm_cls = get_norm_module(norm)
            if norm_cls is None:
                raise ValueError("unknown norm %r; expected 'batch' or 'instance'" % (norm,))
            self.norm = norm_cls(out_channels)
        else:
            self.norm = None

    def forward(self, inputs):
        """Apply pad (if any) -> conv -> norm (if any) -> activation."""
        out = inputs
        if self.pad is not None:
            out = self.pad(out)
        out = self.conv(out)
        if self.norm is not None:
            out = self.norm(out)
        if self.leaky:
            return F.leaky_relu(out, negative_slope=0.2)
        return F.relu(out)

In [4]:
def init_weights(m):
    """Initialise one module in place; intended for use with ``Module.apply``.

    * ``nn.Conv2d`` and ``nn.ConvTranspose2d`` weights are drawn from N(0, 0.02).
      ``ConvTranspose2d`` is listed explicitly because it is not a subclass of
      ``Conv2d`` and would otherwise keep its default initialisation.
    * ``nn.BatchNorm2d`` weights are drawn from N(1, 0.02) and biases are set
      to zero. Parameters that are ``None`` (``affine=False``) are skipped.

    Any other module type is left untouched.

    Args:
        m: the module handed in by ``Module.apply``.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d):
        if m.weight is not None:
            torch.nn.init.normal_(m.weight, 1.0, 0.02)
        if m.bias is not None:
            # The shift term starts at exactly zero; sampling it with the default
            # std of 1.0 would inject large random offsets into every channel.
            torch.nn.init.constant_(m.bias, 0.0)

In [5]:
# TODO: verify this matches the reference PatchGAN architecture.
class PatchGan(nn.Module):
    """Patch discriminator: four strided 4x4 conv stages plus a 1-channel conv head.

    Channel progression is ``input_channels -> 64 -> 128 -> 256 -> 512 -> 1``.
    The first three stages use stride 2, the fourth and the head use stride 1,
    so a 256x256 input yields a 30x30 map of per-patch scores.

    Args:
        input_channels: number of channels of the input image.
        norm_type: normalization name passed to ``ConvNormRelu`` for stages 2-4
            (the first stage is built without normalization).
    """

    def __init__(self, input_channels, norm_type):
        super(PatchGan, self).__init__()

        self.layer1 = ConvNormRelu(in_channels=input_channels, out_channels=64, kernel_size=4,
                                   padding=(1, "zeros"), stride=2, norm=None)
        self.layer2 = ConvNormRelu(in_channels=64, out_channels=128, kernel_size=4,
                                   padding=(1, "zeros"), stride=2, norm=norm_type)
        self.layer3 = ConvNormRelu(in_channels=128, out_channels=256, kernel_size=4,
                                   padding=(1, "zeros"), stride=2, norm=norm_type)
        self.layer4 = ConvNormRelu(in_channels=256, out_channels=512, kernel_size=4,
                                   padding=(1, "zeros"), stride=1, norm=norm_type)

        self.conv_fc = nn.Conv2d(in_channels=512, out_channels=1, kernel_size=4,
                                 padding=1, stride=1)

    def forward(self, inputs):
        """Return a map of per-patch scores squashed into (0, 1)."""
        out = self.layer1(inputs)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.conv_fc(out)
        # torch.sigmoid replaces the deprecated F.sigmoid; the values are identical.
        # NOTE(review): the losses in this notebook apply MSE to these sigmoid
        # outputs - confirm that combination is intended.
        return torch.sigmoid(out)

In [6]:
# Smoke test: run a stride-2 transposed ConvNormRelu on a random 70x70 feature map
# and display the resulting shape.
data = torch.rand(1, 64, 70, 70)
layer = ConvNormRelu(
    in_channels=64,
    out_channels=128,
    kernel_size=3,
    padding=(1, "reflection"),
    stride=2,
    conv_type="transpose",
)
layer(data).shape

torch.Size([1, 128, 140, 140])

In [7]:
class ResBlock(nn.Module):
    """Residual block: two reflection-padded 3x3 convolutions plus an identity skip.

    The channel count is preserved, and each convolution is preceded by a
    1-pixel reflection pad, so the output has the same shape as the input.

    Args:
        in_planes: number of input (and output) channels.
        norm: normalization name resolved through ``get_norm_module``.
    """

    def __init__(self, in_planes, norm="batch"):
        super().__init__()
        norm_layer = get_norm_module(norm)
        self.pad1, self.pad2 = nn.ReflectionPad2d(1), nn.ReflectionPad2d(1)
        self.norm1, self.norm2 = norm_layer(in_planes), norm_layer(in_planes)
        conv_kwargs = dict(in_channels=in_planes, out_channels=in_planes, kernel_size=3)
        self.conv1 = nn.Conv2d(**conv_kwargs)
        self.conv2 = nn.Conv2d(**conv_kwargs)

    def forward(self, inputs):
        """Return ``inputs + F(inputs)`` with F = pad-conv-norm-relu-pad-conv-norm."""
        hidden = self.pad1(inputs)
        hidden = self.conv1(hidden)
        hidden = self.norm1(hidden)
        hidden = F.relu(hidden)
        hidden = self.norm2(self.conv2(self.pad2(hidden)))
        return hidden + inputs

In [8]:
class ResnetGenerator(nn.Module):
    """Encoder -> residual blocks -> decoder image-to-image generator.

    Layout: a 7x7 stem (64 channels), two stride-2 downsampling convs
    (128, 256 channels), ``n_blocks`` blocks at 256 channels, two stride-2
    transposed convs (128, 64 channels) and a 7x7 output conv followed by tanh.
    The output has the same spatial size as the input and values in [-1, 1].

    Args:
        in_channels: number of channels of the input image.
        n_blocks: number of blocks in the 256-channel middle stage.
        norm_type: normalization name (``"batch"`` or ``"instance"``) used by
            every normalised layer.
        out_channels: number of channels of the generated image (default 3).
    """

    def __init__(self, in_channels, n_blocks, norm_type='batch', out_channels=3):
        super(ResnetGenerator, self).__init__()

        self.conv1 = ConvNormRelu(in_channels=in_channels, out_channels=64, kernel_size=7,
                                  padding=(3, "reflection"), stride=1, norm=norm_type, leaky=False)
        self.conv2 = ConvNormRelu(in_channels=64, out_channels=128, kernel_size=3,
                                  padding=(1, "reflection"), stride=2, norm=norm_type, leaky=False)
        self.conv3 = ConvNormRelu(in_channels=128, out_channels=256, kernel_size=3,
                                  padding=(1, "reflection"), stride=2, norm=norm_type, leaky=False)

        self.blocks = nn.ModuleList()
        for _ in range(n_blocks):
            if norm_type:
                # Real residual blocks (with skip connection); a plain conv stack
                # here would make the "Resnet" middle stage non-residual.
                block = ResBlock(256, norm=norm_type)
            else:
                # ResBlock always needs a norm layer, so keep the plain conv
                # block when normalization is disabled.
                block = ConvNormRelu(in_channels=256, out_channels=256, kernel_size=3,
                                     padding=(1, "reflection"), stride=1, norm=norm_type,
                                     leaky=False)
            self.blocks.append(block)

        self.conv4 = ConvNormRelu(in_channels=256, out_channels=128, kernel_size=3,
                                  padding=(1, "reflection"), stride=2, norm=norm_type,
                                  leaky=False, conv_type="transpose")
        self.conv5 = ConvNormRelu(in_channels=128, out_channels=64, kernel_size=3,
                                  padding=(1, "reflection"), stride=2, norm=norm_type,
                                  leaky=False, conv_type="transpose")
        # Output head is a bare padded convolution: putting norm + ReLU in front of
        # tanh would clamp every output to [0, 1) and make negative pixel values
        # unreachable.
        self.conv6 = nn.Sequential(nn.ReflectionPad2d(3),
                                   nn.Conv2d(in_channels=64, out_channels=out_channels,
                                             kernel_size=7, stride=1))

    def forward(self, inputs):
        """Translate ``inputs`` and return an image tensor with values in [-1, 1]."""
        out = self.conv1(inputs)
        out = self.conv2(out)
        out = self.conv3(out)
        for block in self.blocks:
            out = block(out)
        out = self.conv4(out)
        out = self.conv5(out)
        out = self.conv6(out)
        # torch.tanh replaces the deprecated F.tanh.
        return torch.tanh(out)

In [9]:
# Stand-alone transposed-conv layer (48 -> 128 channels, instance norm, plain ReLU)
# used by the shape check in the next cell.
layer = ConvNormRelu(
    in_channels=48,
    out_channels=128,
    kernel_size=3,
    stride=2,
    padding=(1, "reflection"),
    norm="instance",
    leaky=False,
    conv_type="transpose",
)

In [10]:
# Feed a random 256x256 input through `layer` and print the output shape.
data = torch.rand(1, 48, 256, 256)
print(layer(data).shape)

torch.Size([1, 128, 512, 512])


In [11]:
def calc_mse_loss(inputs, value=0):
    """Mean squared error between ``inputs`` and a constant target.

    Args:
        inputs: tensor of predictions (any shape).
        value: scalar every element of the target is filled with (default 0).

    Returns:
        Scalar tensor ``mean((inputs - value) ** 2)``.
    """
    # full_like builds the target with the same shape, dtype and device as the
    # predictions, so this works on CPU and GPU without a host-to-device copy.
    target = torch.full_like(inputs, value)
    return F.mse_loss(inputs, target)

def calc_Gs_outputs(G1, G2, real_A, real_B):
    """Run both translation cycles through the two generators.

    ``G1`` is applied to domain-A inputs and ``G2`` to domain-B inputs. Each
    translated result is then passed through the other generator.

    Args:
        G1: callable translating A -> B.
        G2: callable translating B -> A.
        real_A: input from domain A.
        real_B: input from domain B.

    Returns:
        Tuple ``(fake_B, cycle_BA, fake_A, cycle_AB)`` where
        ``fake_B = G1(real_A)``, ``cycle_BA = G2(fake_B)``,
        ``fake_A = G2(real_B)`` and ``cycle_AB = G1(fake_A)``.
    """
    def translate_then_reconstruct(forward_net, backward_net, real):
        translated = forward_net(real)
        return translated, backward_net(translated)

    fake_B, cycle_BA = translate_then_reconstruct(G1, G2, real_A)
    fake_A, cycle_AB = translate_then_reconstruct(G2, G1, real_B)
    return fake_B, cycle_BA, fake_A, cycle_AB

def backward_D(real, fake, D, real_label=0.9):
    """Compute the discriminator loss for one domain and backpropagate it.

    The loss is the average of two MSE terms: the scores for ``real`` against
    ``real_label`` and the scores for ``fake`` against zero. ``fake`` is detached
    first so no gradient flows back into the generator that produced it.

    Args:
        real: batch of real images.
        fake: batch of generated images.
        D: discriminator; called as ``D(images)``.
        real_label: target value for real samples (default 0.9).

    Returns:
        The scalar loss tensor, after ``backward()`` has been called on it.
    """
    real_output = D(real)
    # *_like targets inherit shape, dtype and device from the scores, so no
    # hard-coded .cuda() transfer is needed.
    d_real_loss = F.mse_loss(real_output, torch.full_like(real_output, real_label))

    fake_output = D(fake.detach())
    d_fake_loss = F.mse_loss(fake_output, torch.zeros_like(fake_output))

    loss = (d_fake_loss + d_real_loss) * 0.5
    print("Discr loss: ", loss)
    loss.backward()
    return loss

    
def backward_Gs(fake_B, cycle_BA, fake_A, cycle_AB, real_A, real_B, G1, G2, D1, D2,
                lambda_identity=5, lambda_cycle=10):
    """Compute the combined loss of both generators and backpropagate it.

    The total is
    ``adv(G1) + adv(G2) + lambda_identity * identity + lambda_cycle * cycle``:

    * adversarial: MSE of ``D2(fake_B)`` and of ``D1(fake_A)`` against 1.0;
    * identity: L1 between ``G1(real_B)`` and ``real_B``, and between
      ``G2(real_A)`` and ``real_A``;
    * cycle: L1 between ``cycle_BA`` and ``real_A``, and between ``cycle_AB``
      and ``real_B``.

    Args:
        fake_B, cycle_BA, fake_A, cycle_AB: outputs of ``calc_Gs_outputs``.
        real_A, real_B: real batches from the two domains.
        G1, G2: generators (A -> B and B -> A), used here for the identity terms.
        D1, D2: discriminators scoring domain-A and domain-B images.
        lambda_identity: weight of the summed identity losses (default 5).
        lambda_cycle: weight of the summed cycle losses (default 10).

    Returns:
        The scalar total loss tensor, after ``backward()`` has been called on it.
    """
    identity_A = G2(real_A)
    identity_B = G1(real_B)

    g1_adv_loss = calc_mse_loss(D2(fake_B), 1.0)
    g2_adv_loss = calc_mse_loss(D1(fake_A), 1.0)
    print("Adv loss: ", g1_adv_loss, g2_adv_loss)

    g1_identity_loss = F.l1_loss(identity_B, real_B)
    g2_identity_loss = F.l1_loss(identity_A, real_A)
    print("Identity loss: ", g1_identity_loss, g2_identity_loss)

    fwd_cycle_loss = F.l1_loss(cycle_BA, real_A)
    bwd_cycle_loss = F.l1_loss(cycle_AB, real_B)
    print("Cycle losses: ", fwd_cycle_loss, bwd_cycle_loss)

    adversarial = g1_adv_loss + g2_adv_loss
    identity = g1_identity_loss + g2_identity_loss
    cycle = fwd_cycle_loss + bwd_cycle_loss
    loss = adversarial + lambda_identity * identity + lambda_cycle * cycle
    print("Gen loss: ", loss)
    loss.backward()
    return loss

In [12]:
class Apple2OrangeDataset(Dataset):
    """Unpaired two-domain image dataset read from two folders under ``root``.

    Every entry of each folder is treated as an image file. Item ``idx`` pairs
    entry ``idx % A_size`` of the first folder with entry ``idx % B_size`` of the
    second one, so the shorter folder wraps around.

    Args:
        root: directory containing both image folders.
        folder_names: sequence whose first two items are the folder names of
            domain A and domain B.
        transform: optional callable applied to each loaded PIL image.

    Raises:
        ValueError: if either folder contains no entries.
    """

    def __init__(self, root, folder_names, transform=None):
        super(Apple2OrangeDataset, self).__init__()

        self.root = root
        self.transform = transform
        self.folder_names = folder_names
        # List each folder once; sorting keeps the index -> file mapping stable.
        self.A_paths = sorted(os.listdir(os.path.join(root, folder_names[0])))
        self.B_paths = sorted(os.listdir(os.path.join(root, folder_names[1])))
        self.A_size = len(self.A_paths)
        self.B_size = len(self.B_paths)
        if self.A_size == 0 or self.B_size == 0:
            # Fail early instead of hitting a modulo-by-zero inside __getitem__.
            raise ValueError("both image folders must be non-empty, got %d and %d entries"
                             % (self.A_size, self.B_size))

    def __len__(self):
        """Number of items: the size of the larger folder."""
        return max(self.A_size, self.B_size)

    def _load_rgb(self, folder, file_name):
        """Load one image as RGB and close the underlying file handle."""
        path = os.path.join(self.root, folder, file_name)
        with Image.open(path) as image:
            return image.convert("RGB")

    def __getitem__(self, idx):
        """Return ``{"A": image_A, "B": image_B}`` for index ``idx``."""
        image_A = self._load_rgb(self.folder_names[0], self.A_paths[idx % self.A_size])
        image_B = self._load_rgb(self.folder_names[1], self.B_paths[idx % self.B_size])
        if self.transform is not None:
            image_A = self.transform(image_A)
            image_B = self.transform(image_B)

        return {"A": image_A, "B": image_B}

In [13]:
# Preprocessing applied to every image: resize to 256x256, convert to a float
# tensor, then normalise each channel with the mean/std below.
# NOTE(review): ResnetGenerator ends in tanh, so its outputs lie in [-1, 1], while
# this normalisation maps [0, 1] pixels to roughly [-2.1, 2.6] - confirm the
# mismatch is intended.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

In [14]:
# Dataset location: override with the CYCLEGAN_DATA_ROOT environment variable;
# the default is the original machine-specific path.
DATA_ROOT = os.environ.get("CYCLEGAN_DATA_ROOT",
                           '/home/dpakhom1/Cycle_gan_pytorch/datasets/horse2zebra/')
dataset = Apple2OrangeDataset(DATA_ROOT, ["trainA", "trainB"], transform=transform)

In [15]:
# Shuffled single-sample batches, loaded by two worker processes.
dataloader = DataLoader(dataset=dataset, batch_size=1, shuffle=True, num_workers=2)

In [16]:
%matplotlib notebook
f, (real_pic, gen_pic, loss_axis) = plt.subplots(3, 1)
real_pic.set_title("Real Domain A picture")
gen_pic.set_title("Generated Domain B picture")
loss_axis.set_title("Loss G and D")
real_pic.plot()
gen_pic.plot()
loss_d = []
loss_g = []
loss_axis.plot(loss_d, list(range(len(loss_d))), 'b',
               loss_g, list(range(len(loss_g))), 'r')
plt.tight_layout()

<IPython.core.display.Javascript object>

In [17]:
def tensor_to_image(tensor):
    """Undo the dataset normalisation and convert a tensor to a PIL image.

    Args:
        tensor: 3-channel image tensor of shape ``(3, H, W)`` that was normalised
            with mean ``[0.485, 0.456, 0.406]`` and std ``[0.229, 0.224, 0.225]``.
            It may live on any device and is not modified.

    Returns:
        The de-normalised image as a PIL image.
    """
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    # Inverse of Normalize: x * std + mean, computed out of place on the CPU.
    image = tensor.detach().cpu().float() * std + mean
    # Clamp to the displayable range; values outside [0, 1] would otherwise wrap
    # around when converted to 8-bit pixels.
    return transforms.ToPILImage()(image.clamp(0.0, 1.0))

In [18]:
def _show_image(axis, artist, image):
    """Draw ``image`` on ``axis``, reusing ``artist`` when one already exists."""
    pixels = np.asarray(image)
    if artist is None:
        return axis.imshow(pixels)
    artist.set_data(pixels)
    return artist


def train_loop(num_epochs, dataloader, G1, G2, D1, D2):
    """Train both generators and both discriminators, refreshing the live plots.

    Each iteration runs one generator update (``backward_Gs``) followed by one
    discriminator update (``backward_D`` for each domain). Losses are appended to
    the module-level ``loss_g`` / ``loss_d`` lists, and every 10th iteration the
    module-level axes ``real_pic``, ``gen_pic`` and ``loss_axis`` are redrawn.

    Args:
        num_epochs: number of passes over ``dataloader``.
        dataloader: yields dicts with image batches under ``"A"`` and ``"B"``.
        G1: generator A -> B.
        G2: generator B -> A.
        D1: discriminator for domain A.
        D2: discriminator for domain B.
    """
    optimizer_G = optim.Adam(list(G1.parameters()) + list(G2.parameters()), lr=0.0002, betas=(0.5, 0.999))
    optimizer_D = optim.Adam(list(D1.parameters()) + list(D2.parameters()), lr=0.0002, betas=(0.5, 0.999))
    # Follow the device the generator already lives on instead of assuming CUDA.
    device = next(G1.parameters()).device
    # Image artists are created once and then updated in place, so the axes do
    # not accumulate one new image per refresh.
    real_artist = None
    gen_artist = None
    for epoch in range(num_epochs):
        for idx, data in enumerate(dataloader):
            domain_A, domain_B = data["A"].to(device), data["B"].to(device)

            fake_B, cycle_BA, fake_A, cycle_AB = calc_Gs_outputs(G1, G2, domain_A, domain_B)

            # Generator step.
            optimizer_G.zero_grad()
            loss_G = backward_Gs(fake_B, cycle_BA, fake_A, cycle_AB, domain_A, domain_B, G1, G2, D1, D2)
            loss_g.append(loss_G.item())
            optimizer_G.step()

            # Discriminator step. zero_grad also discards the gradients the
            # generator step left on the discriminator parameters.
            optimizer_D.zero_grad()
            loss_D1 = backward_D(domain_A, fake_A, D1)
            loss_D2 = backward_D(domain_B, fake_B, D2)
            loss_d.append((loss_D1.item() + loss_D2.item())/2)
            optimizer_D.step()

            if idx % 10 == 9:
                with torch.no_grad():
                    G1.eval()
                    preview = G1(domain_A[:1])
                    G1.train()
                # Show the first sample of the batch. The generated image is
                # de-normalised the same way as the real one so both panels
                # are comparable.
                real_artist = _show_image(real_pic, real_artist, tensor_to_image(data["A"][0]))
                gen_artist = _show_image(gen_pic, gen_artist,
                                         tensor_to_image(preview[0].detach().cpu()))

                loss_axis.lines[0].set_xdata(list(range(len(loss_d))))
                loss_axis.lines[0].set_ydata(loss_d)
                loss_axis.lines[1].set_xdata(list(range(len(loss_g))))
                loss_axis.lines[1].set_ydata(loss_g)
                loss_axis.relim()
                loss_axis.autoscale_view()
                # All three axes share one figure, so a single redraw is enough.
                loss_axis.figure.canvas.draw()

In [19]:
# Use the GPU when one is available and fall back to the CPU otherwise, so the
# cell no longer crashes on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Two generators (3 input channels, 9 middle blocks) and two patch discriminators,
# all with instance normalisation, in training mode on the selected device.
G1 = ResnetGenerator(3, 9, norm_type='instance').to(device).train()
G2 = ResnetGenerator(3, 9, norm_type='instance').to(device).train()
D1 = PatchGan(3, norm_type='instance').to(device).train()
D2 = PatchGan(3, norm_type='instance').to(device).train()

In [20]:
# Re-initialise every network with init_weights, in the order G1, G2, D1, D2.
for net in (G1, G2, D1):
    net.apply(init_weights)
# Kept as the bare last expression so the cell still echoes D2's structure.
D2.apply(init_weights)

PatchGan(
  (layer1): ConvNormRelu(
    (conv): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  )
  (layer2): ConvNormRelu(
    (conv): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
  )
  (layer3): ConvNormRelu(
    (conv): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (norm): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
  )
  (layer4): ConvNormRelu(
    (conv): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
    (norm): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
  )
  (conv_fc): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
)

In [None]:
# Train for 10 epochs with the networks and data loader defined above.
train_loop(num_epochs=10, dataloader=dataloader, G1=G1, G2=G2, D1=D1, D2=D2)