In [323]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torch.utils.data as Data
import torchvision.transforms as T
from glob import glob
import os
import cv2
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from collections import OrderedDict
from torchvision.utils import save_image
import pandas as pd
import time
import pickle
%matplotlib inline

to_img= T.Compose([T.ToPILImage()])
to_tensor = T.Compose([T.ToTensor()])
load_norm = T.Compose([T.ToTensor(),T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class Parser():
    #hyperparameters
    def __init__(self):
        #image setting
        self.n_epoch = 50
        self.batch_size = 5
        self.h_dim = 64
        self.n_dim = 128
        self.y_dim = 2
        self.D_lr = 0.01
        self.G_lr = 0.002 # 0.0002
        self.b1 = 0.5
        self.b2 = 0.999
        self.img_size = 64
        self.model_save_freq = 1000
        self.img_save_freq = 100
        self.show_freq = 50
        self.model_path = './BEGAN/Model/'
        self.img_path = './BEGAN/Image/' 
        self.conv_dim = 64
        self.d_dim = 64
        self.D_out_dim = 16
        self.train_img_path = "./data/celeba/"
        self.num_res = 5
        self.D_mode = ''
        self.G_mode = ''
        self.L_mode = ''
        self.k = 0
        self.lam = 0.001
        self.gamma = 0.5
        
args = Parser()  

if not os.path.exists(args.model_path):
    os.makedirs(args.model_path)
if not os.path.exists(args.img_path):
    os.makedirs(args.img_path)

# For creating the data
#######
# df = pd.read_csv('./data/celeba/Anno/list_attr_celeba.txt',delim_whitespace=True)
# path = list(df['Path'])
# is_male = list(df['Male'])
# is_male = [1 if i == 1 else 0 for i in is_male]
# with open("Path_Male.pickle", "wb") as fp:   
#     pickle.dump([path,is_male], fp) 

class CelebADataset(Data.Dataset):
    def __init__(self, mode='train', args=None):
        
        self.image_transform = T.Compose([
            T.Resize((args.img_size,args.img_size)),
            #T.RandomResizedCrop(args.img_size, scale=(1.0,1.0)),
            T.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        
        with open("Path_Male.pickle", 'rb') as f:
            self.path ,self.male= pickle.load(f)
            print('Loaded')
        self.path = self.path[:-2000] if mode == 'train' else self.path[-2000:]
        self.male = self.male[:-2000] if mode == 'train' else self.male[-2000:]
        
    def __getitem__(self, index):
        
        idx = index % len(self.path)
        img = self.image_transform(Image.open(os.path.join(args.train_img_path,
                                                             self.path[idx])))
        
        is_male = self.male[idx]
        return img, is_male

    def __len__(self):
        return len(self.path)


In [324]:
dataset = CelebADataset(mode='train',args= args)
training_loader = DataLoader(dataset,batch_size=args.batch_size,shuffle=True,drop_last=True,pin_memory=True)

Loaded


In [117]:
class To_Image(nn.Module):
    def __init__(self,img_size):
        super(To_Image, self).__init__()
        self.img_size = img_size
    def forward(self,x):
        return x.view(args.batch_size,-1, self.img_size, self.img_size)
    
class Conv3x3Block(nn.Module):
    def __init__(self, dim):
        super(Conv3x3Block, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(dim, dim, 3, 1, 1),
            nn.BatchNorm2d(dim),
            nn.ELU(inplace = True)
        )
    def forward(self, x):
        return self.model(x)


class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(args.h_dim, args.n_dim*8*8),
            To_Image(8),
            Conv3x3Block(args.n_dim),
            Conv3x3Block(args.n_dim),
            nn.Upsample(scale_factor = 2),
            Conv3x3Block(args.n_dim),
            Conv3x3Block(args.n_dim),
            nn.Upsample(scale_factor = 2),
            Conv3x3Block(args.n_dim),
            Conv3x3Block(args.n_dim),
            nn.Upsample(scale_factor = 2),
            Conv3x3Block(args.n_dim),
            Conv3x3Block(args.n_dim),
            nn.Conv2d(args.n_dim, 3, 3, 1, 1),
            nn.Tanh() # Output = 64x64 3Channels Image
        )
        
    def forward(self, x):
        return self.model(x)
    
h = torch.Tensor(args.batch_size, args.h_dim)

class Subsmapling(nn.Module):
    def __init__(self, in_dim, out_dim):
        super(Subsmapling, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_dim, out_dim, 1, 1, 0),
            nn.AvgPool2d(2,2),
        )
    def forward(self, x):
        return self.model(x)
    
class Flatten(nn.Module):
    def __init__(self,):
        super(Flatten, self).__init__()
        
    def forward(self, x):
        return x.view(args.batch_size,-1)
    

class Encoder(nn.Module):
    def __init__(self,):
        super(Encoder, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, args.n_dim, 3, 1, 1),
            Conv3x3Block(args.n_dim),
            Conv3x3Block(args.n_dim),
            Subsmapling(args.n_dim, args.n_dim),
            Conv3x3Block(args.n_dim),
            Conv3x3Block(args.n_dim),
            Subsmapling(args.n_dim, args.n_dim*2),
            Conv3x3Block(args.n_dim*2),
            Conv3x3Block(args.n_dim*2),
            Subsmapling(args.n_dim*2, args.n_dim*3),
            Conv3x3Block(args.n_dim*3),
            Conv3x3Block(args.n_dim*3),
            Flatten(),
            nn.Linear(8*8*3*args.n_dim, 64),
        )
        
    def forward(self, x):
        return self.model(x)
        
class Discriminator(nn.Module): ## Photo to Photo reconstruct
    def __init__(self,):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            Encoder(),
            Decoder(),
        )
    def forward(self, x):
        return self.model(x)

In [345]:
class BEGAN(nn.Module):
    def __init__(self,):
        super(BEGAN, self).__init__()
        
        self.D = Discriminator()
        self.G = Decoder()
        
        self.k = 0
        
        self.G_optim = optim.Adam(self.G.parameters(), lr = args.G_lr , betas= (args.b1 ,args.b2))
        self.D_optim = optim.Adam(self.D.parameters(), lr = args.D_lr , betas= (args.b1 ,args.b2))
        
        self.train_hist = {}

        self.train_hist['G_loss'] = []
        self.train_hist['D_loss'] = []
        self.train_hist['M_global'] = []
        
        self.apply(self.weight_init)
        self.progress_photo = []
        
        
    def forward(self,img):
        
        img = img.to(device)
        self.h = torch.Tensor(args.batch_size, args.h_dim).uniform_(-1,1).to(device)
        
        ########## Train D ##########
        
        self.D_optim.zero_grad()
        self.G_img = self.G(self.h).detach()
        D_fake_loss = self.Loss(self.D(self.G_img),self.G_img)
        D_real_loss = self.Loss(self.D(img),img)
        self.D_loss = D_real_loss - self.k * D_fake_loss
        self.train_hist['D_loss'].append(self.D_loss.item())
        self.D_loss.backward()
        self.D_optim.step()
        
        ########## Train G ##########
        
        self.G_optim.zero_grad()
        self.G_img = self.G(self.h)
        self.G_loss = self.Loss(self.D(self.G_img), self.G_img)
        self.train_hist['G_loss'].append(self.G_loss.item())
        self.G_loss.backward()
        self.G_optim.step()
        
        ########## Update K ##########
        
        balance = (args.gamma*D_real_loss - D_fake_loss).item()
        self.k += (args.lam* balance)
        self.k = max(min(self.k,1),0)
        
        ########## Calculate M_global ##########
        
        self.m_global = D_real_loss.item() + (args.gamma*D_real_loss - D_fake_loss)
        self.train_hist['M_global'].append(self.m_global)
        
        self.progress_photo.append(self.G_img[0].detach())
        self.progress_photo = self.progress_photo[-args.img_save_freq:]
        
    def weight_init(self,m):
        if type(m) in [nn.Conv2d, nn.ConvTranspose2d, nn.Linear]:
            #nn.init.xavier_normal_(m.weight,nn.init.calculate_gain('leaky_relu',param=0.02))
            nn.init.kaiming_normal_(m.weight,0.2,nonlinearity='relu')
            
    def Loss(self, x, y):
        return torch.norm(y-x)/ (x.nelement())
    
    def image_save(self, step):
        
        img_save_path = args.img_path + "BEGAN_Step_"+str(step)+".png"
        save_image( torch.stack(self.progress_photo[:args.img_save_freq]), img_save_path , nrow=10, normalize=True, range=(-1,1))
        print('Image saved')  
        
    def model_save(self,step):
        path = args.model_path + 'BEGAN_Step_' + str(step) + '.pth'
        torch.save({'BEGAN':self.state_dict()}, path)
        print('Model saved')
        
    def load_step_dict(self, step):
        
        path = args.model_path + 'BEGAN_Step_' + str(step) + '.pth'
        self.load_state_dict(torch.load(path, map_location=lambda storage, loc: storage)['BEGAN'])
 
    def plot_all_loss(self,step):
        
        fig, ax = plt.subplots(figsize= (20,8))
        for k in self.train_hist.keys():
            plt.plot(self.train_hist[k], label= k)
        plt.ylabel('Loss',fontsize=15)
        plt.xlabel('Number of Steps',fontsize=15)
        plt.title('Loss',fontsize=30,fontweight ="bold")
        plt.legend(loc = 'upper left')
        fig.savefig("BEGAN_Loss_"+str(step)+".png")
        
    def num_all_params(self,):
        return sum([param.nelement() for param in self.parameters()])
    
        

In [346]:
gan = BEGAN().to(device)
epoch = 0
all_steps = 1

In [347]:
while epoch < args.n_epoch:
    for i, (img, y) in enumerate(training_loader):    
        
        start_t = time.time()
        gan(img)
        end_t = time.time()
        
        print('| Step [%d] | lr [%.6f] | D Loss: [%.4f] | G Loss: [%.4f] | M_Global: [%.4f] | Time: %.1fs' %\
              ( all_steps, gan.G_optim.param_groups[0]['lr'], gan.D_loss.item(), gan.G_loss.item(), gan.m_global,
               end_t - start_t))

        if all_steps % args.show_freq == 0: #args.show_freq
            fig=plt.figure(figsize=(8, 8))
            fig.add_subplot(1,3,1)
            plt.imshow(to_img(gan.G_img[0].cpu()*0.5+0.5))
            plt.show()
            if all_steps % args.img_save_freq ==0: # args.img_save_freq
                gan.image_save(all_steps)
                if all_steps % args.model_save_freq == 0: #args.model_save_freq
                    gan.model_save(all_steps)
        all_steps += 1
        if all_steps > 3:
            raise StopIteration
    epoch +=1


| Step [1] | lr [0.002000] | D Loss: [0.0031] | G Loss: [0.0033] | M_Global: [0.0019] | Time: 8.5s
| Step [2] | lr [0.002000] | D Loss: [0.0035] | G Loss: [0.0038] | M_Global: [0.0017] | Time: 8.4s
| Step [3] | lr [0.002000] | D Loss: [0.0038] | G Loss: [0.0039] | M_Global: [0.0018] | Time: 8.4s


StopIteration: 

In [None]:
gan.G_optim.param_groups[0]['lr'] *= 0.5 
gan.D_optim.param_groups[0]['lr'] *= 0.5