In [1]:
import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image
from torch.autograd import Variable
from torch.utils.data import Dataset,DataLoader
import time
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
import glob
# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

plt.ion()   # interactive mode

In [2]:
saved_dir="../saved_data/pix2pix_2/"
sample_dir="../samples/pix2pix_2/"
# Create a directory if not exists
if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)
    print("folders created!")
else:
    print("folders not created!")
if not os.path.exists(saved_dir):
    os.makedirs(saved_dir)
    print("folders created!")
else:
    print("folders not created!")

folders created!
folders created!


In [3]:
batch_size=128
n_epochs=200
image_width=80
learning_rate=2e-4

In [4]:
#load imagedataset
class ImageDataset(Dataset):
    
    def __init__(self, root_dir, transform=None):
        self.root_dir = glob.glob(root_dir)
        self.transform = transform

    def __len__(self):
        return len(self.root_dir)

    def __getitem__(self,idx):
        img_name = self.root_dir[idx]
        image = io.imread(img_name)
        sample = image
        if self.transform:
            sample = self.transform(sample)

        return sample

    
#train data load
train_data=ImageDataset(root_dir="D:/pix2pix_dataset/cityscapes/train/*.jpg",
                                     transform=transforms.Compose([
                                         transforms.ToPILImage(),
                                         transforms.Resize((image_width,image_width*2)),
                                         transforms.ToTensor()
                                     ]))
train_loader=torch.utils.data.DataLoader(train_data,batch_size=batch_size,shuffle=True)

In [5]:
# image = io.imread("D:/pix2pix_dataset/cityscapes/train/1.jpg")
# width=image.shape[0]
# x_image=image[:,0:width,:] # x_data
# y_image=image[:,width:width*2,:]# y_data
# print("image.shape : ",image.shape,"\nx_image.shape : ",x_image.shape,
#       "\ny_image.shape : ",y_image.shape)
# #reshaped x,y data
# reshaped_x_image=x_image
# reshaped_x_image=np.resize(reshaped_x_image,((image_width-44),(image_width-44),3))
# reshaped_y_image=y_image
# reshaped_y_image=np.resize(reshaped_y_image,((image_width-44),(image_width-44),3))
# print("reshaped_x_image.shape : ",reshaped_x_image.shape,"reshaped_y_image.shape : ",
#      reshaped_y_image.shape)

In [6]:
for i in range(len(train_data)):
    sample=train_data[i]
    print(sample.shape)
    if i==0:
        break

torch.Size([3, 80, 160])


In [7]:
from Generator_model import *

In [8]:
class Unet_model(nn.Module):
    def __init__(self,in_channel=3):
        super(Unet_model,self).__init__()
        self.down1=UNetDown(in_channel,64)
        self.down2=UNetDown(64,128)
        self.down3=UNetDown(128,256)
        
        self.max_pool=nn.MaxPool2d(kernel_size=2,stride=2)
        
        self.up1=UNetUp(256,128)
        self.up1_cnn=UnetUpConv(256,128)
        
        self.up2=UNetUp(128,64)
        self.up2_cnn=UnetUpConv(128,64)
        
        self.up3=UNetUp(64,32)
        self.up3_cnn=UnetUpConv(64,32)
        
        self.final=UnetUpConv(64,3)
        
    def forward(self,x):
        o1=self.down1(x)
        o1_max=self.max_pool(o1)
        o2=self.down2(o1_max)
        o2_max=self.max_pool(o2)
        o3=self.down3(o2_max)
        
        u1=self.up1(o3) # 128 특징 가지고 있다.
        u1_concat=center_crop_concat(o2,u1)
        u1_conv=self.up1_cnn(u1_concat)
        
        u2=self.up2(u1_conv)
        u2_concat=center_crop_concat(o1,u2)
        u2_conv=self.up2_cnn(u2_concat)
        
        output=self.final(u2_conv)
        
#         print("************Unet down model************")
#         print("layer1 : ",o1.shape)
#         print("layer2 : ",o2.shape)
#         print("layer3 : ",o3.shape)   
#         print("************Unet up model************")
#         print("layer2 : ",u1_conv.shape)
#         print("layer1 : ",u2_conv.shape)
#         print("final : ",output.shape)
        return output 
class Discriminator(nn.Module):
    def __init__(self,in_size=6,out_size=64):
        super(Discriminator, self).__init__()
        self.layer1 = nn.Sequential(nn.Conv2d(in_size, out_size, 3, 1, 0, bias=False),
                                    nn.ReLU(),
                                    nn.BatchNorm2d(out_size),
                                    nn.Conv2d(out_size, out_size*2,
                                              3, 1, 0, bias=False),
                                    nn.ReLU(),
                                    nn.BatchNorm2d(out_size*2),
                                    
                                    nn.Conv2d(out_size*2, out_size*4,
                                              3, 1, 0, bias=False),
                                    nn.ReLU(),
                                    nn.BatchNorm2d(out_size*4),
                                    nn.Conv2d(out_size*4, out_size*2,
                                              3, 1, 0, bias=False),
                                    nn.ReLU(),
                                    nn.BatchNorm2d(out_size*2),
                                    
                                    nn.MaxPool2d(2, 2),
                                    nn.Conv2d(out_size*2, 1,
                                              3, 1, 0, bias=False)
                                    )

    def forward(self, x1,x2):
#         print(x1.shape,x2.shape)
        x_cat=torch.cat((x1,x2),1)
#         print("x1.shape : ",x1.shape,"x2.shape : ",x2.shape,"cat.shape : ",x_cat.shape)
        out = self.layer1(x_cat)
        return out

In [9]:

generator_model=Unet_model().cuda()
discriminator_model=Discriminator().cuda()

optimizer_G=torch.optim.Adam(generator_model.parameters(),lr=learning_rate)
optimizer_D=torch.optim.Adam(discriminator_model.parameters(),lr=learning_rate)
gen_loss=nn.MSELoss()
l1_loss=nn.L1Loss()

In [10]:
print(torch.typename(train_loader))

torch.utils.data.dataloader.DataLoader


In [11]:
start=time.time()
total_step=len(train_loader)
for epochs in range(n_epochs):
    for i ,image in enumerate(train_loader):
        image_num=int(len(image))
#         print("image_num : ",image_num)
        #image 분리
        y_image=image[:,:,:,0:image_width] # x_data
        x_image=image[:,:,:,image_width:image_width*2]# y_data

        #reshaped x,y data
        reshaped_x_image=torch.FloatTensor(image_num,3,1,1)
    #     reshaped_x_image=Variable(reshaped_x_image)
        reshaped_x_image.data.resize_(x_image.size()).copy_(x_image)
        reshaped_x_image.data.resize_((image_num,3,(image_width-44),(image_width-44)))
#         print("reshaped x shape : ",reshaped_x_image.shape)
        reshaped_y_image=torch.FloatTensor(image_num,3,1,1)
    #     reshaped_y_image=Variable(reshaped_y_image)
        reshaped_y_image.data.resize_(y_image.size()).copy_(y_image)
        reshaped_y_image.data.resize_((image_num,3,(image_width-44),(image_width-44)))
#         print("reshaped_y shape : ",reshaped_y_image.shape)

        true_label=torch.ones(image_num,1,12,12).cuda() 
        false_label=torch.zeros(image_num,1,12,12).cuda()


        x_image=Variable(x_image).cuda()
        y_image=Variable(y_image).cuda()
        reshaped_x_image=Variable(reshaped_x_image).cuda()
        reshaped_y_image=Variable(reshaped_y_image).cuda()
        #------------------
        #Generator training
        #------------------
        optimizer_G.zero_grad()
        gen_out=generator_model.forward(x_image)
#         print("gen_out.shape : ",gen_out.shape,"reshaped_x_image.shape:",reshaped_x_image.shape)
        gen_dis_out=discriminator_model(gen_out,reshaped_x_image)

#         print("gen_dis_out.shape : ",gen_dis_out.shape)
        g_loss=gen_loss(gen_dis_out,true_label)
        g_l1=l1_loss(gen_out,reshaped_y_image)

        loss_G=g_loss+g_l1*100

        loss_G.backward(retain_graph=True)
        optimizer_G.step()


        #----------------------
        #Discriminator training
        #----------------------

        optimizer_D.zero_grad()
        #d fake loss
        dis_fake_out=discriminator_model.forward(gen_out,reshaped_y_image)
        d_fake_loss=gen_loss(dis_fake_out,false_label)

        #d real loss
        dis_real_out=discriminator_model.forward(reshaped_x_image,reshaped_y_image)
        d_real_loss=gen_loss(dis_real_out,true_label)

        loss_D=(d_fake_loss+d_real_loss)*0.5
        loss_D.backward()
        optimizer_D.step()

    save_image(gen_out,os.path.join(
        sample_dir,'fake_images-{}.png'.format(epochs+1)))
    print ("[Epoch %d/%d] [D loss: %f] [G loss: %f]" % (epochs, n_epochs,
                                                        loss_D.item(), loss_G.item()))
    if(epochs%50==0):
        # Save the model checkpoints
        torch.save(optimizer_G.state_dict(), saved_dir +
                   '/G_pix2pix_{}.ckpt'.format(epochs+1))
        torch.save(optimizer_D.state_dict(), saved_dir +
                   '/D_pix2pix_{}.ckpt'.format(epochs+1))
finished = time.time()
hours = finished-start
print("training finished! %d minutes" % hours)       

[Epoch 0/200] [D loss: 0.048125] [G loss: 26.854290]
[Epoch 1/200] [D loss: 0.092475] [G loss: 16.437357]
[Epoch 2/200] [D loss: 0.024986] [G loss: 14.506286]
[Epoch 3/200] [D loss: 0.015373] [G loss: 14.663190]
[Epoch 4/200] [D loss: 0.015067] [G loss: 13.945982]
[Epoch 5/200] [D loss: 0.013373] [G loss: 13.362953]
[Epoch 6/200] [D loss: 0.008679] [G loss: 15.653759]
[Epoch 7/200] [D loss: 0.009990] [G loss: 14.416258]
[Epoch 8/200] [D loss: 0.007615] [G loss: 15.764406]
[Epoch 9/200] [D loss: 0.007918] [G loss: 15.834070]
[Epoch 10/200] [D loss: 0.008239] [G loss: 13.852053]
[Epoch 11/200] [D loss: 0.007020] [G loss: 15.781377]
[Epoch 12/200] [D loss: 0.006337] [G loss: 14.067452]
[Epoch 13/200] [D loss: 0.005608] [G loss: 13.370281]
[Epoch 14/200] [D loss: 0.006897] [G loss: 13.532378]
[Epoch 15/200] [D loss: 0.004236] [G loss: 14.590778]
[Epoch 16/200] [D loss: 0.008143] [G loss: 14.652566]
[Epoch 17/200] [D loss: 0.007013] [G loss: 12.903384]
[Epoch 18/200] [D loss: 0.010851] [G l

[Epoch 151/200] [D loss: 0.000375] [G loss: 17.084087]
[Epoch 152/200] [D loss: 0.000473] [G loss: 15.293009]
[Epoch 153/200] [D loss: 0.000368] [G loss: 15.931844]
[Epoch 154/200] [D loss: 0.000374] [G loss: 14.440784]
[Epoch 155/200] [D loss: 0.000382] [G loss: 13.773494]
[Epoch 156/200] [D loss: 0.000444] [G loss: 13.729418]
[Epoch 157/200] [D loss: 0.000383] [G loss: 15.169604]
[Epoch 158/200] [D loss: 0.000601] [G loss: 12.548455]
[Epoch 159/200] [D loss: 0.000310] [G loss: 14.998318]
[Epoch 160/200] [D loss: 0.000473] [G loss: 14.356856]
[Epoch 161/200] [D loss: 0.000368] [G loss: 13.958835]
[Epoch 162/200] [D loss: 0.000407] [G loss: 15.414515]
[Epoch 163/200] [D loss: 0.000416] [G loss: 15.048975]
[Epoch 164/200] [D loss: 0.000426] [G loss: 15.135411]
[Epoch 165/200] [D loss: 0.000490] [G loss: 16.413946]
[Epoch 166/200] [D loss: 0.000309] [G loss: 14.780694]
[Epoch 167/200] [D loss: 0.000531] [G loss: 14.324646]
[Epoch 168/200] [D loss: 0.000440] [G loss: 13.728860]
[Epoch 169