In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms 
import torch.utils.data as Data
import torchvision.transforms as T
from glob import glob
import os
import cv2
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from collections import OrderedDict
from torchvision.utils import save_image
import pandas as pd
import time
import pickle
%matplotlib inline

# Small transform helpers for moving between PIL images and tensors.
#   to_img    : CHW tensor -> PIL image
#   to_tensor : PIL image  -> CHW float tensor
#   load_norm : PIL image  -> CHW float tensor, each channel mapped (x - 0.5) / 0.5
to_img = T.Compose([
    T.ToPILImage(),
])
to_tensor = T.Compose([
    T.ToTensor(),
])
load_norm = T.Compose([
    T.ToTensor(),
    T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Run on the first CUDA device when one is visible, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

class Parser():
    """Hyperparameters and paths shared by every cell of this notebook.

    Each setting is a plain attribute, so cells read ``args.<name>`` and may
    overwrite it before building the model (e.g. ``args.G_mode = 'Upsample'``).

    Settings can also be overridden at construction time, e.g.
    ``Parser(batch_size=32, G_mode='Deconv')``. An unknown keyword raises
    TypeError so a misspelt name is not silently ignored.
    """

    def __init__(self, **overrides):
        # --- training schedule / optimisation ---
        self.n_epoch = 50
        self.batch_size = 64
        self.D_lr = 0.01
        self.G_lr = 0.002 # 0.0002
        self.b1 = 0.5      # Adam beta1 (used for both G and D optimizers)
        self.b2 = 0.999    # Adam beta2
        # --- model / data sizes ---
        self.z_dim = 100   # channels of the noise tensor fed to the generator
        self.y_dim = 2     # number of label classes (width of the one-hot code)
        self.img_size = 64 # side length images are resized to
        self.d_dim = 64    # base channel width passed to both G and D
        self.num_res = 5   # NOTE(review): currently has no effect on the model
        # NOTE(review): the next four are not read elsewhere in this notebook.
        self.conv_dim = 64
        self.D_out_dim = 16
        self.lam1 = 10
        self.lam2 = 5
        # --- periodic actions, in training steps ---
        self.model_save_freq = 1000
        self.img_save_freq = 100
        self.show_freq = 50
        # --- paths ---
        self.model_path = './cLSDCGAN5/Model/'
        self.img_path = './cLSDCGAN5/Image/'
        self.train_img_path = "./data/celeba/"
        # --- architecture / loss selectors; set before building cLSDCGAN ---
        self.D_mode = ''
        self.G_mode = ''
        self.L_mode = ''

        # Apply caller overrides, rejecting names that are not settings above.
        for name, value in overrides.items():
            if not hasattr(self, name):
                raise TypeError("Parser() got an unexpected hyperparameter %r" % name)
            setattr(self, name, value)
        
args = Parser()

# Output folders for checkpoints and sample grids. exist_ok makes this cell
# safe to re-run and avoids the exists()/makedirs() check-then-create race.
os.makedirs(args.model_path, exist_ok=True)
os.makedirs(args.img_path, exist_ok=True)

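# One-off preprocessing that creates Path_Male.pickle, the file CelebADataset
# loads below: a pickled [paths, is_male] pair, with is_male mapped to
# 1 (Male == 1) or 0 (otherwise). Uncomment and run once before first use.
#######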
# df = pd.read_csv('./data/celeba/Anno/list_attr_celeba.txt',delim_whitespace=True)
# path = list(df['Path'])
# is_male = list(df['Male'])
# is_male = [1 if i == 1 else 0 for i in is_male]
# with open("Path_Male.pickle", "wb") as fp:   
#     pickle.dump([path,is_male], fp) 

class CelebADataset(Data.Dataset):
    """CelebA images paired with a binary "male" label.

    Image paths and labels come from ``Path_Male.pickle`` (a pickled
    ``[paths, is_male]`` pair). The last 2000 entries are held out:
    ``mode='train'`` uses everything before them, any other mode uses them.

    Args:
        mode: 'train' for the training split; anything else selects the
            final 2000 entries.
        args: settings object; ``img_size`` and ``train_img_path`` are read.

    Items are ``(img, is_male)`` where ``img`` is a resized, randomly
    h-flipped 3-channel tensor normalised to [-1, 1].
    """

    def __init__(self, mode='train', args=None):

        self.image_transform = T.Compose([
            T.Resize((args.img_size,args.img_size)),
            #T.RandomResizedCrop(args.img_size, scale=(1.0,1.0)),
            T.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        # Image root comes from the args given here, not the notebook global.
        self.img_root = args.train_img_path

        # NOTE(review): pickle.load can execute arbitrary code; only load a
        # Path_Male.pickle you generated yourself.
        with open("Path_Male.pickle", 'rb') as f:
            self.path ,self.male= pickle.load(f)
            print('Loaded')
        split = slice(None, -2000) if mode == 'train' else slice(-2000, None)
        self.path = self.path[split]
        self.male = self.male[split]

    def __getitem__(self, index):
        """Return ``(image_tensor, is_male)`` for ``index`` (wrapped modulo len)."""
        idx = index % len(self.path)
        # Context manager closes the file handle; convert('RGB') guarantees
        # three channels for the 3-channel Normalize (grayscale/RGBA safe).
        with Image.open(os.path.join(self.img_root, self.path[idx])) as im:
            img = self.image_transform(im.convert('RGB'))

        is_male = self.male[idx]
        return img, is_male

    def __len__(self):
        return len(self.path)


# dataset = CelebADataset(mode='train',args= args)
# dataloader = DataLoader(dataset,batch_size=args.batch_size,shuffle=True,drop_last=True,pin_memory=True)
# data_t = iter(dataloader).next()
# plt.imshow(to_img(data_t[0][0]*0.5+0.5))
# print('Label: ',data_t[1][0].item())


class ResBlock(nn.Module):
    """Residual block of two 3x3 same-padding convolutions with BatchNorm.

    Output is ``leaky_relu(body(x) + x, 0.2)`` where ``body`` is
    conv -> BN -> LeakyReLU(0.2) -> conv -> BN; channel count and spatial
    size are preserved.
    """

    def __init__(self, dim):
        super(ResBlock, self).__init__()
        body = [
            nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(dim),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(dim),
        ]
        self.model = nn.Sequential(*body)

    def forward(self, x):
        residual_sum = self.model(x) + x
        # Functional activation: no throwaway LeakyReLU module per call.
        return nn.functional.leaky_relu(residual_sum, 0.2, inplace=True)

class ConcatBlock_G(nn.Module):
    """Generator input stem: embeds noise z and the one-hot label y
    separately, then concatenates the two along the channel axis.

    Each branch maps its 1x1 input to ``dim*4`` channels at 4x4, so the
    result has ``dim*8`` channels. ``args.G_mode`` picks the branch type:

    * ``"Upsample"``: two (x2 nearest upsample -> 3x3 conv -> BN -> LeakyReLU) stages
    * ``"Deconv"``: one 4x4 transposed conv (stride 1, no padding) -> BN -> LeakyReLU

    Any other mode raises NotImplementedError.
    """

    def __init__(self,dim = 64):
        super(ConcatBlock_G, self).__init__()

        mode = args.G_mode
        if mode not in ("Upsample", "Deconv"):
            raise NotImplementedError

        def make_branch(in_channels):
            # Identical layout for z and y; only the input width differs.
            if mode == "Upsample":
                return nn.Sequential(
                    nn.Upsample(scale_factor=2),
                    nn.Conv2d(in_channels, dim*2, 3, 1, 1),
                    nn.BatchNorm2d(dim*2),
                    nn.LeakyReLU(0.2, inplace=True),
                    nn.Upsample(scale_factor=2),
                    nn.Conv2d(dim*2, dim*4, 3, 1, 1),
                    nn.BatchNorm2d(dim*4),
                    nn.LeakyReLU(0.2, inplace=True),
                )
            return nn.Sequential(
                nn.ConvTranspose2d(in_channels, dim*4, 4, 1, 0),
                nn.BatchNorm2d(dim*4),
                nn.LeakyReLU(0.2, inplace=True),
            )

        self.model_z = make_branch(args.z_dim)
        self.model_y = make_branch(args.y_dim)

    def forward(self, data):
        """``data`` is a ``[z, y]`` pair; returns their channel-wise concatenation."""
        z_feat = self.model_z(data[0])
        y_feat = self.model_y(data[1])
        return torch.cat((z_feat, y_feat), dim=1)
    

class Generator(nn.Module):
    """Conditional generator that samples its own noise.

    The caller passes only the one-hot label ``y`` of shape
    ``(B, args.y_dim, 1, 1)``. ``forward`` draws ``z ~ N(0, I)`` of shape
    ``(B, args.z_dim, 1, 1)``, stores it on ``self.z``, and returns a
    ``(B, 3, 64, 64)`` image in [-1, 1] (Tanh output).

    ``args.G_mode`` selects the architecture: ``"Upsample"`` or ``"Deconv"``.
    Any other value raises NotImplementedError.
    """
    def __init__(self, dim = 64):
        super(Generator, self).__init__()

        if args.G_mode == "Upsample":
            print("Use Upsample in Generator")
            # NOTE(review): only a single ResBlock is used; args.num_res is
            # not consulted here.
            self.model = nn.Sequential(
                ConcatBlock_G(dim),                  # -> (B, dim*8, 4, 4)
                nn.Upsample(scale_factor=2),         # 8x8
                nn.Conv2d(dim*8,dim*4,3,1,1),
                nn.BatchNorm2d(dim*4),
                nn.LeakyReLU(0.2,inplace=True),
                nn.Upsample(scale_factor=2),         # 16x16
                ResBlock(dim*4),
                nn.Upsample(scale_factor=2),         # 32x32
                nn.Conv2d(dim*4,dim*2,3,1,1),
                nn.BatchNorm2d(dim*2),
                nn.LeakyReLU(0.2,inplace=True),
                nn.Upsample(scale_factor=2),         # 64x64
                nn.Conv2d(dim*2,3,3,1,1),
                nn.Tanh()
            )
        elif args.G_mode == "Deconv":
            print("Use Deconv in Generator")
            self.model = nn.Sequential(
                ConcatBlock_G(dim),                  # -> (B, dim*8, 4, 4)
                nn.ConvTranspose2d(dim*8,dim*4,4,2,1),   # 8x8
                nn.BatchNorm2d(dim*4),
                nn.LeakyReLU(0.2,inplace=True),
                nn.ConvTranspose2d(dim*4,dim*4,4,2,1),   # 16x16
                nn.BatchNorm2d(dim*4),
                nn.LeakyReLU(0.2,inplace=True),
                nn.ConvTranspose2d(dim*4,dim*2,4,2,1),   # 32x32
                nn.BatchNorm2d(dim*2),
                nn.LeakyReLU(0.2,inplace=True),
                nn.ConvTranspose2d(dim*2,3,4,2,1),       # 64x64
                nn.Tanh()
            )
        else:
            raise NotImplementedError

    def forward(self, y):  # Expect y as one hot code
        """Generate one image per row of the one-hot label tensor ``y``."""
        # Batch size follows y (not args.batch_size), so partial batches and
        # single-sample inference work; z is created on y's device.
        self.z = torch.randn(y.size(0), args.z_dim, 1, 1, device=y.device)
        return self.model([self.z,y])
    
    
    
    
class ConcatBlock_D(nn.Module):
    """Discriminator input stem.

    Applies a separate 4x4 stride-2 conv + LeakyReLU(0.2) to the image
    (3 channels) and to the tiled label map (``args.y_dim`` channels). Each
    yields ``dim`` channels at half the (even) input resolution, and the two
    are concatenated into ``2*dim`` channels.
    """

    def __init__(self,dim = 64):
        super(ConcatBlock_D, self).__init__()

        def halve_resolution(in_channels):
            return nn.Sequential(
                nn.Conv2d(in_channels, dim, 4, 2, 1),
                nn.LeakyReLU(0.2, inplace=True),
            )

        self.model_img = halve_resolution(3)
        self.model_y = halve_resolution(args.y_dim)

    def forward(self, data):
        """``data`` is ``[img, y_map]``; y_map must share img's spatial size."""
        img_feat = self.model_img(data[0])
        label_feat = self.model_y(data[1])
        return torch.cat([img_feat, label_feat], dim=1)
        
        

class Discrimator(nn.Module):
    """Conditional discriminator scoring (image, label) pairs.

    The one-hot label ``y`` of shape ``(B, args.y_dim, 1, 1)`` is tiled to the
    image's spatial size and fused with the image by ConcatBlock_D, followed
    by three stride-2 conv/BN/LeakyReLU stages. For 64x64 inputs:

    * ``args.D_mode == "Patch"``: 1x1 conv head, output ``(B, 1, 4, 4)``
    * ``args.D_mode == "Normal"``: 4x4 valid conv head, output ``(B, 1, 1, 1)``

    A Sigmoid is appended unless ``args.L_mode == "LS"``. Any other D_mode
    raises NotImplementedError.
    """
    def __init__(self, dim = 64):
        super(Discrimator, self).__init__()

        mode = args.D_mode
        if mode == "Patch":
            print("Use PatchGAN Architecture in Discriminator")
        elif mode == "Normal":
            print("Use Normal Architecture in Discriminator")
        else:
            raise NotImplementedError

        # Shared trunk (same layer indices as before, so checkpoints load).
        layers = [
            ConcatBlock_D(dim),                 # -> (B, dim*2, H/2, W/2)
            nn.Conv2d(dim*2,dim*4,4,2,1),
            nn.BatchNorm2d(dim*4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(dim*4,dim*8,4,2,1),
            nn.BatchNorm2d(dim*8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(dim*8,dim*8,4,2,1),
            nn.BatchNorm2d(dim*8),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        if mode == "Patch":
            layers.append(nn.Conv2d(dim*8,1,1,1,0))   # one score per patch
        else:
            layers.append(nn.Conv2d(dim*8,1,4,1,0))   # one score per image
        self.model = nn.Sequential(*layers)

        if args.L_mode != "LS":
            self.model.add_module('Sigmoid',nn.Sigmoid())

    def forward(self, img, y): # Expect y as one hot code
        """Score ``img`` against one-hot ``y`` of shape (B, y_dim, 1, 1)."""
        # Tile the label to the image's own H x W so both stems line up.
        y = y.repeat(1, 1, img.size(2), img.size(3))
        return self.model([img,y])
        


class cLSDCGAN(nn.Module):
    """Conditional DCGAN trained with a least-squares ('LS') or BCE loss.

    Owns the generator ``G``, the discriminator ``D``, their Adam optimizers
    and the loss. Calling the module (``forward``) performs ONE training
    step: a D update followed by a G update. Loss values are appended to
    ``D_loss_hist`` / ``G_loss_hist``. The first generated image of each
    step is kept in ``progress_photo``, which is trimmed to the last
    ``args.img_save_freq`` entries.
    """
    def __init__(self,):
        super(cLSDCGAN, self).__init__()

        self.G = Generator(args.d_dim).to(device)
        self.D = Discrimator(args.d_dim).to(device)

        if args.L_mode == 'LS':
            print("Use LS Loss")
            self.Loss = nn.MSELoss().to(device)
        elif args.L_mode == 'BCE':
            print("Use BCE Loss")
            self.Loss = nn.BCELoss().to(device)
        else:
            raise NotImplementedError

        self.G_loss_hist = []
        self.D_loss_hist = []

        self.G_optim = optim.Adam(self.G.parameters(), lr = args.G_lr, betas=(args.b1, args.b2))
        self.D_optim = optim.Adam(self.D.parameters(), lr = args.D_lr, betas=(args.b1, args.b2))

        # Full-batch target tensors, kept for code that reads them directly.
        # forward() builds targets from D's actual output shape instead, so a
        # shape mismatch can never silently broadcast inside MSELoss.
        if args.D_mode== 'Patch':
            self.real_label = torch.ones(args.batch_size,1,4,4).to(device)
            self.fake_label = torch.zeros(args.batch_size,1,4,4).to(device)
        else:
            self.real_label = torch.ones(args.batch_size,1,1,1).to(device)
            self.fake_label = torch.zeros(args.batch_size,1,1,1).to(device)

        self.apply(self.weight_init)

        self.progress_photo = []

    def forward(self, img, y):
        """Run one discriminator step, then one generator step.

        Args:
            img: real images, shape (B, 3, H, W), expected in [-1, 1] to
                match the generator's Tanh output.
            y: integer class labels in [0, args.y_dim), shape (B,).
        """
        img = img.to(device)
        y = y.to(device)

        # One-hot labels as (B, y_dim, 1, 1), sized from the actual batch.
        one_hot_y = torch.zeros(img.size(0), args.y_dim, device=device)
        one_hot_y.scatter_(1, y.long().unsqueeze(1), 1)
        one_hot_y = one_hot_y.view(-1,args.y_dim,1,1)

        ############ Train D ############
        self.D_optim.zero_grad()
        # No graph through G is needed for the D update.
        with torch.no_grad():
            self.G_img = self.G(one_hot_y)
        d_fake = self.D(self.G_img, one_hot_y)
        D_fake_loss = self.Loss(d_fake, torch.zeros_like(d_fake))
        d_real = self.D(img, one_hot_y)
        D_real_loss = self.Loss(d_real, torch.ones_like(d_real))
        self.D_loss = D_fake_loss + D_real_loss
        self.D_loss_hist.append(self.D_loss.item())
        self.D_loss.backward()
        self.D_optim.step()

        ############ Train G ############
        # G_loss.backward() also fills D's grads; they are cleared by
        # D_optim.zero_grad() at the start of the next step.
        self.G_optim.zero_grad()
        self.G_img = self.G(one_hot_y)
        d_gen = self.D(self.G_img, one_hot_y)
        self.G_loss = self.Loss(d_gen, torch.ones_like(d_gen))
        self.G_loss_hist.append(self.G_loss.item())
        self.G_loss.backward()
        self.G_optim.step()

        self.progress_photo.append(self.G_img[0].detach())
        self.progress_photo = self.progress_photo[-args.img_save_freq:]

    def image_save(self, step):
        """Save the buffered progress images as a 10-column grid PNG."""
        img_save_path = args.img_path + "cLSDCAN_Step_"+str(step)+".png"
        photos = self.progress_photo[:args.img_save_freq]
        try:
            # make_grid's keyword is `value_range`; the old `range` name is
            # no longer accepted by current torchvision.
            save_image(photos, img_save_path, nrow=10, normalize=True, value_range=(-1,1))
        except TypeError:
            # Older torchvision only understands the original keyword.
            save_image(photos, img_save_path, nrow=10, normalize=True, range=(-1,1))
        print('Image saved')

    def model_save(self,step):
        """Write this module's state_dict under key 'cLSDCAN'."""
        path = args.model_path + 'cLSDCAN_Step_' + str(step) + '.pth'
        torch.save({'cLSDCAN':self.state_dict()}, path)
        print('Model saved')

    def load_step_dict(self, step):
        """Load a checkpoint written by model_save(step), mapping tensors to CPU storage."""
        path = args.model_path + 'cLSDCAN_Step_' + str(step) + '.pth'
        self.load_state_dict(torch.load(path, map_location=lambda storage, loc: storage)['cLSDCAN'])

    def plot_all_loss(self,step):
        """Plot G and D loss histories and save them to cLSDCAN_Loss_<step>.png."""
        fig, ax = plt.subplots(figsize= (20,8))
        ax.plot(self.G_loss_hist,label='G_loss')
        ax.plot(self.D_loss_hist,label='D_loss')
        ax.set_ylabel('Loss',fontsize=15)
        ax.set_xlabel('Number of Steps',fontsize=15)
        ax.set_title('Loss',fontsize=30,fontweight ="bold")
        ax.legend(loc = 'upper left')
        fig.savefig("cLSDCAN_Loss_"+str(step)+".png")

    def num_all_params(self,):
        """Total number of parameter elements in G and D."""
        return sum(param.numel() for param in self.parameters())

    def weight_init(self,m):
        """Kaiming-normal init (a=0.2, LeakyReLU gain) for conv/deconv weights."""
        if type(m) in [nn.Conv2d, nn.ConvTranspose2d]:
            nn.init.kaiming_normal_(m.weight, a=0.2, nonlinearity='leaky_relu')
            
        
                
# Training split (all but the last 2000 entries) and its shuffled loader;
# drop_last=True means every batch holds exactly args.batch_size samples.
dataset = CelebADataset(mode='train', args=args)
training_loader = DataLoader(
    dataset,
    batch_size=args.batch_size,
    shuffle=True,
    drop_last=True,
    pin_memory=True,
)


In [None]:
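'''
Model configuration options (set on `args` before building cLSDCGAN):
L_mode: 'LS' (least-squares / MSE loss) or 'BCE' (binary cross-entropy)
G_mode: 'Upsample' (upsample + conv) or 'Deconv' (transposed convolutions)
D_mode: 'Normal' (one score per image) or 'Patch' (a 4x4 grid of patch scores
        for 64x64 inputs; the Patch method needs L_mode = 'LS')

'''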
# Pick the loss and architecture variants, then build a fresh model and
# reset the counters used by the training-loop cell.
args.L_mode = 'LS'
args.G_mode = 'Upsample'
args.D_mode = 'Normal'
gan = cLSDCGAN().to(device)
epoch, all_steps = 0, 1

In [None]:
# Resumable training loop: continues from the current `epoch`/`all_steps`.
while epoch < args.n_epoch:
    for img, y in training_loader:

        start_t = time.time()
        gan(img, y)  # one D step + one G step
        end_t = time.time()

        print('| Epoch [%d] | Step [%d] | D Loss: [%.4f] | G Loss: [%.4f] | Time: %.1fs' %
              (epoch, all_steps, gan.D_loss.item(), gan.G_loss.item(),
               end_t - start_t))

        # Each periodic action is checked independently, so the three
        # frequencies do not need to be multiples of one another.
        if all_steps % args.show_freq == 0:
            fig = plt.figure(figsize=(8, 8))
            fig.add_subplot(1, 3, 1)
            plt.imshow(to_img(gan.G_img[0].detach().cpu() * 0.5 + 0.5))
            plt.show()
            plt.close(fig)
        if all_steps % args.img_save_freq == 0:
            gan.image_save(all_steps)
        if all_steps % args.model_save_freq == 0:
            gan.model_save(all_steps)
        all_steps += 1
    epoch += 1