In [None]:
#!wget https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/summer2winter_yosemite.zip
#!unzip summer2winter_yosemite.zip
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms 
import torch.utils.data as Data
import torchvision.transforms as T
from glob import glob
import os
import cv2
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from collections import OrderedDict
from torchvision.utils import save_image
import time
%matplotlib inline

# from torchsummary import summary
# Train on the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Tensor <-> image helper pipelines.
to_img = T.Compose([T.ToPILImage()])     # tensor -> PIL image (used for plotting below)
to_tensor = T.Compose([T.ToTensor()])    # PIL image -> float tensor
# PIL image -> tensor, then (x - 0.5) / 0.5 per channel, i.e. [0, 1] -> [-1, 1].
load_norm = T.Compose([
    T.ToTensor(),
    T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

class Parser():
    """Plain attribute container holding the notebook's hyper-parameters and paths."""

    def __init__(self):
        # --- optimisation ---
        self.n_epoch = 50
        self.batch_size = 1
        self.lr = 0.0002
        self.b1, self.b2 = 0.9, 0.999        # Adam betas

        # --- data / architecture ---
        self.img_size = 128                  # only used by the commented-out random crop
        self.conv_dim = 64                   # width passed to the discriminators
        self.n_res = 9                       # residual blocks per generator
        self.D_out_dim = 1                   # spatial size of the label tensors

        # --- loss weights ---
        self.lam1 = 10                       # cycle-reconstruction L1 weight
        self.lam2 = 5                        # identity L1 weight

        # --- cadence, in training steps ---
        self.model_save_freq = 1000
        self.img_save_freq = 100
        self.show_freq = 50

        # --- output locations ---
        self.model_path = './CycleGAN0704-Normal/Model/'
        self.img_path = './CycleGAN0704-Normal/Image/'
        
args = Parser()

# Create the checkpoint and sample-image folders. exist_ok=True makes this
# idempotent (safe to re-run the cell) without a separate exists() check.
os.makedirs(args.model_path, exist_ok=True)
os.makedirs(args.img_path, exist_ok=True)

class WeatherDataset(Data.Dataset):
    """Unpaired summer/winter images for CycleGAN.

    Each item is a ``(summer_img, winter_img)`` pair of 3-channel tensors,
    randomly horizontally flipped and normalised from [0, 1] to [-1, 1].
    The two images are chosen independently (``index`` modulo each domain's
    file count), so they are not a matched pair.

    Args:
        mode: ``'train'`` reads ``trainA``/``trainB``; any other value reads
            ``testA``/``testB`` under ``./data/summer2winter/``.
        args: hyper-parameter object; only referenced by the commented-out
            random-resized-crop transform.

    Raises:
        FileNotFoundError: if either domain has no ``*.jpg`` files.
    """

    def __init__(self, mode='train', args=None):

        self.image_transform = T.Compose([
            #T.RandomResizedCrop(args.img_size, scale=(0.3,1.0)),
            T.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        split = 'train' if mode == 'train' else 'test'
        self.summer_files = sorted(glob('./data/summer2winter/' + split + 'A/*.jpg'))
        self.winter_files = sorted(glob('./data/summer2winter/' + split + 'B/*.jpg'))
        # Fail early with a clear message instead of a ZeroDivisionError in
        # __getitem__ (index % 0) or an empty-sampler error in the DataLoader.
        for name, files in (('summer', self.summer_files), ('winter', self.winter_files)):
            if not files:
                raise FileNotFoundError('No %s images found for mode=%r' % (name, mode))
        print('Loaded')

    def __getitem__(self, index):
        summer_img = self._load(self.summer_files[index % len(self.summer_files)])
        winter_img = self._load(self.winter_files[index % len(self.winter_files)])
        return summer_img, winter_img

    def _load(self, path):
        """Open *path*, force 3-channel RGB, and apply the image transform."""
        # The context manager closes the file handle; convert('RGB') both
        # loads the pixels and guarantees 3 channels for Normalize (grayscale
        # or RGBA files would otherwise produce a 1- or 4-channel tensor).
        with Image.open(path) as img:
            return self.image_transform(img.convert('RGB'))

    def __len__(self):
        # One epoch covers every image of the larger domain; the smaller one
        # wraps around via the modulo in __getitem__.
        return max(len(self.summer_files), len(self.winter_files))

class LayerNorm(nn.Module):
    """Per-sample normalisation over all non-batch elements, with optional affine.

    Each sample is shifted by its mean and divided by (unbiased std + ``eps``),
    both computed over every element except the batch dimension. When
    ``affine`` is true, learnable ``gamma``/``beta`` vectors of length
    ``num_features`` are broadcast along dim 1.

    Args:
        num_features: size of dim 1 of the input (used only for the affine params).
        eps: added to the standard deviation (not the variance).
        affine: learn ``gamma`` (initialised U[0, 1)) and ``beta`` (initialised 0).
    """

    def __init__(self, num_features, eps=1e-5, affine=True):
        super(LayerNorm, self).__init__()
        self.num_features = num_features
        self.affine = affine
        self.eps = eps

        if self.affine:
            self.gamma = nn.Parameter(torch.Tensor(num_features).uniform_())
            self.beta = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        # Statistics are (N,) vectors reshaped to (N, 1, 1, ...) for broadcasting.
        shape = [-1] + [1] * (x.dim() - 1)
        # reshape (not view) so non-contiguous inputs, e.g. after permute or
        # transpose, are accepted instead of raising.
        flat = x.reshape(x.size(0), -1)
        mean = flat.mean(1).view(shape)
        std = flat.std(1).view(shape)
        y = (x - mean) / (std + self.eps)
        if self.affine:
            # Broadcast the per-feature params along dim 1: (1, C, 1, 1, ...).
            a_shape = [1, -1] + [1] * (x.dim() - 2)
            y = self.gamma.view(a_shape) * y + self.beta.view(a_shape)
        return y
    
class Flatten(nn.Module):
    """Collapse every dimension after the first: (N, d1, d2, ...) -> (N, d1*d2*...)."""

    def __init__(self,):
        super(Flatten, self).__init__()

    def forward(self, x):
        batch = x.shape[0]
        return x.view(batch, -1)
    
class ResBlock(nn.Module):
    """Residual block: two reflection-padded 3x3 convs, each followed by InstanceNorm.

    The first conv stage ends in LeakyReLU(0.2); the input is then added to
    the branch output and the sum goes through LeakyReLU(0.2). Channel count
    and spatial size are preserved.
    """

    def __init__(self, dim):
        super(ResBlock, self).__init__()
        branch = []
        for with_activation in (True, False):
            branch += [
                nn.ReflectionPad2d(1),
                nn.Conv2d(dim, dim, 3, 1, 0),
                nn.InstanceNorm2d(dim),
            ]
            if with_activation:
                branch.append(nn.LeakyReLU(0.2, inplace=True))
        self.model = nn.Sequential(*branch)

    def forward(self, x):
        summed = self.model(x) + x
        # In-place is safe here: `summed` is a fresh tensor owned by this call.
        return nn.functional.leaky_relu(summed, 0.2, inplace=True)
        
        
        
        
        
        
class UpBlock(nn.Module):
    """2x upsampling stage: Upsample(x2) -> 3x3 conv -> InstanceNorm -> LeakyReLU(0.2)."""

    def __init__(self,in_dim,out_dim):
        super(UpBlock, self).__init__()
        stages = [
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_dim, out_dim, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(out_dim),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        upsampled = self.model(x)
        return upsampled
        
        

    
class Generator(nn.Module):
    """Encoder / residual / decoder image-to-image generator.

    Layout: reflection-padded 7x7 conv stem -> two stride-2 3x3 convs
    (dim -> 2*dim -> 4*dim) -> ``num_res`` ResBlocks at 4*dim channels ->
    two UpBlocks back to ``dim`` channels -> reflection-padded 7x7 conv to
    3 channels -> tanh, so outputs lie in [-1, 1].

    Args:
        dim: channel width of the stem conv (default 32).
        num_res: number of residual blocks (default 6).
    """

    def __init__(self, dim = 32, num_res = 6):
        super(Generator, self).__init__()

        # The residual blocks are instantiated before the other layers; keep
        # this order so seeded runs initialise parameters identically.
        bottleneck = [ResBlock(dim * 4) for _ in range(num_res)]

        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(3, dim, 7, 1, 0),
            nn.InstanceNorm2d(dim),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Two stride-2 downsampling stages.
        for c_in, c_out in ((dim, dim * 2), (dim * 2, dim * 4)):
            layers += [
                nn.Conv2d(c_in, c_out, 3, 2, 1),
                nn.InstanceNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        layers += bottleneck
        layers += [UpBlock(dim * 4, dim * 2), UpBlock(dim * 2, dim)]
        layers += [nn.ReflectionPad2d(3), nn.Conv2d(dim, 3, 7, 1, 0), nn.Tanh()]

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


class Discirminator(nn.Module):
    """Convolutional real/fake classifier, used once per image domain.

    Seven 4x4 stride-2 convs. Each of the first six is followed by
    LeakyReLU(0.2), with InstanceNorm in between for all but the first. A
    final sigmoid yields probabilities for ``nn.BCELoss``. A 256x256 input
    gives an (N, 1, 1, 1) output. A 128x128 input leaves a 2x2 map before the
    last 4x4 / padding-0 conv, which is smaller than the kernel and errors.

    The (misspelled) class name is kept because other code refers to it.

    Args:
        dim: channel width of the first conv (default 64).
    """

    def __init__(self,dim = 64):
        super(Discirminator, self).__init__()

        self.model = nn.Sequential(
            nn.Conv2d(3, dim, 4, 2, 1),
            # Previously nn.LeakyReLU(dim), which sets negative_slope=dim
            # (e.g. 64) and amplifies negative activations instead of leaking them.
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(dim, dim*2, 4, 2, 1),
            nn.InstanceNorm2d(dim*2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(dim*2, dim*4, 4, 2, 1),
            nn.InstanceNorm2d(dim*4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(dim*4, dim*8, 4, 2, 1),
            nn.InstanceNorm2d(dim*8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(dim*8, dim*4, 4, 2, 1),
            nn.InstanceNorm2d(dim*4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(dim*4, dim*2, 4, 2, 1),
            nn.InstanceNorm2d(dim*2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(dim*2, 1, 4, 2, 0),
            nn.Sigmoid(),
        )

    def forward(self,x):
        return self.model(x)

    
class CycleGAN(nn.Module):
    """Two generators and two discriminators trained jointly.

    ``G_SW`` maps summer->winter and ``G_WS`` maps winter->summer.
    ``D_S``/``D_W`` score summer/winter images. Calling the module
    (``forward``) performs ONE training step: a discriminator update followed
    by a generator update with adversarial, cycle-reconstruction
    (weight ``args.lam1``) and identity (weight ``args.lam2``) losses.
    Intermediate images and the latest losses are kept as attributes
    (``s_img``, ``g_w_img``, ``g_re_s_img``, ``D_loss``, ``G_loss``, ...)
    for logging and visualisation.

    Reads hyper-parameters from the module-level ``args`` and places tensors
    on the module-level ``device``.
    """

    # How many recent generated batches each domain's history pool keeps.
    POOL_SIZE = 50

    def __init__(self,):
        super(CycleGAN, self).__init__()

        self.G_SW = Generator(num_res=args.n_res).to(device)   # summer -> winter
        self.G_WS = Generator(num_res=args.n_res).to(device)   # winter -> summer
        self.D_S = Discirminator(args.conv_dim).to(device)     # scores summer images
        self.D_W = Discirminator(args.conv_dim).to(device)     # scores winter images

        # Fixed-shape targets, kept as attributes for backward compatibility.
        # The losses build targets shaped like the discriminator output instead
        # (see _adv_loss), so a short final batch or a different D output size
        # cannot cause a BCELoss shape mismatch.
        self.real_labels = torch.ones(args.batch_size, 1, args.D_out_dim, args.D_out_dim).to(device)
        self.fake_labels = torch.zeros(args.batch_size, 1, args.D_out_dim, args.D_out_dim).to(device)

        g_params = list(self.G_SW.parameters()) + list(self.G_WS.parameters())
        d_params = list(self.D_S.parameters()) + list(self.D_W.parameters())

        # Use the configured learning rate rather than a hard-coded value.
        self.G_optim = optim.Adam(g_params, lr=args.lr, betas=(args.b1, args.b2))
        self.D_optim = optim.Adam(d_params, lr=args.lr, betas=(args.b1, args.b2))

        # The discriminators end in a sigmoid, so BCE on probabilities.
        self.Loss = nn.BCELoss().to(device)
        self.L1 = nn.L1Loss()

        self.D_loss_hist = []
        self.G_loss_hist = []

        # Histories of detached generated batches, sampled when training D.
        self.W_img_pool = []
        self.S_img_pool = []

        self.apply(self.weight_init)

    def forward(self, s_img, w_img):
        """Run one training step (D update, then G update) on a batch pair.

        Args:
            s_img: batch of summer images; presumably normalised to [-1, 1]
                like WeatherDataset output (the generators end in tanh).
            w_img: batch of winter images, same convention.
        """
        self.s_img = s_img.to(device)
        self.w_img = w_img.to(device)

        self._update_discriminators()
        self._update_generators()

        torch.cuda.empty_cache()

    def _adv_loss(self, pred, real):
        """BCE between D output *pred* and an all-ones (real) / all-zeros target of its shape."""
        target = torch.ones_like(pred) if real else torch.zeros_like(pred)
        return self.Loss(pred, target)

    def _update_discriminators(self):
        """One optimiser step for D_W and D_S on real images and pooled fakes."""
        self.D_optim.zero_grad()

        # Detached: this step must not push gradients into the generators.
        self.g_w_img = self.G_SW(self.s_img).detach()   # generated winter image
        self.g_s_img = self.G_WS(self.w_img).detach()   # generated summer image

        self.W_img_pool.append(self.g_w_img)
        self.S_img_pool.append(self.g_s_img)
        self.W_img_pool = self.W_img_pool[-self.POOL_SIZE:]
        self.S_img_pool = self.S_img_pool[-self.POOL_SIZE:]

        # Score a random batch from each history pool rather than only the
        # newest fakes. Both pools are appended together, so lengths match.
        w_idx, s_idx = torch.randint(0, len(self.W_img_pool), (2,)).tolist()
        fake_w = self.W_img_pool[w_idx]
        fake_s = self.S_img_pool[s_idx]

        D_W_loss = self._adv_loss(self.D_W(fake_w), real=False)
        D_S_loss = self._adv_loss(self.D_S(fake_s), real=False)
        D_W_r_loss = self._adv_loss(self.D_W(self.w_img), real=True)
        D_S_r_loss = self._adv_loss(self.D_S(self.s_img), real=True)

        self.D_loss = D_W_r_loss + D_S_r_loss + D_W_loss + D_S_loss
        self.D_loss_hist.append(self.D_loss.item())
        self.D_loss.backward()
        self.D_optim.step()

    def _update_generators(self):
        """One optimiser step for G_SW and G_WS (adversarial + cycle + identity)."""
        self.G_optim.zero_grad()

        # Regenerate with gradients enabled (the D step used detached copies).
        self.g_w_img = self.G_SW(self.s_img)   # generated winter image
        self.g_s_img = self.G_WS(self.w_img)   # generated summer image

        # Adversarial: the generators want their fakes scored as real. This
        # also accumulates gradients in the D parameters; D_optim.zero_grad()
        # at the start of the next D step clears them.
        G_SW_loss = self._adv_loss(self.D_W(self.g_w_img), real=True)
        G_WS_loss = self._adv_loss(self.D_S(self.g_s_img), real=True)

        # Identity: a generator fed an image already in its output domain
        # should return it unchanged.
        self.g_s_s_img = self.G_WS(self.s_img)
        self.g_w_w_img = self.G_SW(self.w_img)
        identity_loss = self.L1(self.g_s_s_img, self.s_img) + self.L1(self.g_w_w_img, self.w_img)

        # Cycle reconstruction: S -> W -> S and W -> S -> W should recover the input.
        self.g_re_s_img = self.G_WS(self.g_w_img)
        self.g_re_w_img = self.G_SW(self.g_s_img)
        recon_loss = self.L1(self.g_re_s_img, self.s_img) + self.L1(self.g_re_w_img, self.w_img)

        self.G_loss = G_SW_loss + G_WS_loss + (recon_loss * args.lam1) + (identity_loss * args.lam2)
        self.G_loss_hist.append(self.G_loss.item())
        self.G_loss.backward()
        self.G_optim.step()

    def image_save(self, step):
        """Save a 2x3 grid (real / translated / reconstructed for each domain) to args.img_path."""
        training_img_path = args.img_path + "CycleGAN_Step_" + str(step) + ".png"
        with torch.no_grad():
            grid = torch.cat([self.s_img, self.g_w_img, self.g_re_s_img,
                              self.w_img, self.g_s_img, self.g_re_w_img], 0).detach()
            # Map [-1, 1] -> [0, 1] explicitly. This matches save_image's
            # normalize=True with a (-1, 1) range, but avoids its `range`
            # keyword, which newer torchvision renamed to `value_range`.
            grid = (grid.clamp(-1, 1) + 1) / 2
        save_image(grid, training_img_path, nrow=3)
        print('Image saved')

    def model_save(self, step):
        """Write this module's state_dict under key 'CycleGAN' to args.model_path."""
        path = args.model_path + 'CycleGAN_Step_' + str(step) + '.pth'
        torch.save({'CycleGAN': self.state_dict()}, path)
        print('Model saved')

    def load_step_dict(self, step):
        """Load the checkpoint written by model_save(step), mapping tensors to CPU storage first."""
        path = args.model_path + 'CycleGAN_Step_' + str(step) + '.pth'
        # torch.load unpickles the file: only load checkpoints you trust.
        self.load_state_dict(torch.load(path, map_location=lambda storage, loc: storage)['CycleGAN'])

    def plot_all_loss(self, step):
        """Plot the G and D loss histories and save them to CycleGAN_Loss_<step>.png (cwd)."""
        fig, ax = plt.subplots(figsize=(20, 8))
        plt.plot(self.G_loss_hist, label='G_loss')
        plt.plot(self.D_loss_hist, label='D_loss')
        plt.ylabel('Loss', fontsize=15)
        plt.xlabel('Number of Steps', fontsize=15)
        plt.title('Loss', fontsize=30, fontweight="bold")
        plt.legend(loc='upper left')
        fig.savefig("CycleGAN_Loss_" + str(step) + ".png")

    def num_all_params(self,):
        """Total number of scalar parameters across all four networks."""
        return sum(param.nelement() for param in self.parameters())

    def weight_init(self, m):
        """Kaiming-normal init (a=0.2, leaky_relu gain) for conv weights; biases left as-is."""
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.kaiming_normal_(m.weight, 0.2, nonlinearity='leaky_relu')
            
# Model first, then the unpaired summer/winter training split and its loader.
cycleGAN = CycleGAN().to(device)

training_set = WeatherDataset('train', args=args)
training_loader = DataLoader(
    training_set,
    batch_size=args.batch_size,
    shuffle=True,
    pin_memory=True,
)
       





In [None]:
# Main training loop. `all_steps` counts optimisation steps across epochs
# (starting at 1) and drives the display / image-save / checkpoint cadence.
all_steps = 1
for epoch in range(args.n_epoch):
    for s_img, w_img in training_loader:
        start_t = time.time()
        cycleGAN(s_img, w_img)   # one discriminator step + one generator step
        end_t = time.time()

        print('| Epoch [%d] | Step [%d] | D Loss: [%.4f] | G Loss: [%.4f] | Time: %.1fs' %
              (epoch, all_steps, cycleGAN.D_loss.item(), cycleGAN.G_loss.item(),
               end_t - start_t))

        # The three cadences are checked independently. Nested checks would
        # silently skip saves whenever a save frequency is not a multiple of
        # show_freq (and img_save_freq).
        if all_steps % args.show_freq == 0:
            fig = plt.figure(figsize=(8, 8))
            # input summer | generated winter | reconstructed summer
            panels = (cycleGAN.s_img, cycleGAN.g_w_img, cycleGAN.g_re_s_img)
            for pos, img in enumerate(panels, start=1):
                fig.add_subplot(1, 3, pos)
                # detach: these tensors belong to the autograd graph.
                # *0.5+0.5 maps [-1, 1] back to [0, 1] for display.
                plt.imshow(to_img(img[0].detach().cpu() * 0.5 + 0.5))
            plt.show()
        if all_steps % args.img_save_freq == 0:
            cycleGAN.image_save(all_steps)
        if all_steps % args.model_save_freq == 0:
            cycleGAN.model_save(all_steps)
        all_steps += 1