In [158]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torch.utils.data as Data
import torchvision.transforms as T
from glob import glob
import os
import cv2
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from collections import OrderedDict
from torchvision import datasets
from torchvision.utils import save_image
import torch.nn.functional as F
import numpy as np
import pandas as pd
import time
import pickle
%matplotlib inline



# Target device for the model and every batch: the GPU when CUDA is available,
# otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class Parser:
    """Plain attribute container holding the notebook's hyperparameters and paths.

    Instantiating it takes no arguments; every setting is a fixed default that
    is edited in place here and read elsewhere through the ``args`` instance.
    """

    def __init__(self):
        # --- training schedule ---
        self.n_epoch, self.batch_size = 50, 5

        # --- network sizes ---
        # z_dim: noise vector length, y_dim: number of class labels,
        # n_dim: base channel width, n_ch: image channels.
        self.z_dim, self.n_dim, self.y_dim = 64, 64, 10
        self.img_size = 32
        self.n_ch = 1

        # --- Adam optimizer settings ---
        # NOTE(review): D_lr is 5x G_lr; the inline hint on G_lr suggests 0.0002
        # was an earlier value. Worth re-checking if training diverges.
        self.D_lr = 0.01
        self.G_lr = 0.002  # 0.0002
        self.b1, self.b2 = 0.5, 0.999

        # --- logging / checkpoint cadence, in training steps ---
        self.show_freq = 50
        self.img_save_freq = 100
        self.model_save_freq = 1000

        # --- output locations ---
        self.model_path = './ACGAN/Model/'
        self.img_path = './ACGAN/Image/'

        # --- settings not referenced by the code visible in this notebook ---
        self.conv_dim, self.d_dim = 64, 64
        self.D_out_dim = 16
        self.train_img_path = "./data/celeba/"
        self.num_res = 5
        self.D_mode, self.G_mode, self.L_mode = '', '', ''
        self.k = 0
        self.lam = 0.001
        self.gamma = 0.5
        
# Single shared settings object read by every model class and cell below.
args = Parser()


# Make sure the checkpoint and sample-image directories exist before training
# tries to write into them.
if not os.path.exists(args.model_path):
    os.makedirs(args.model_path)
if not os.path.exists(args.img_path):
    os.makedirs(args.img_path)

# Tensor <-> PIL converters used for display.
to_img= T.Compose([T.ToPILImage()])
to_tensor = T.Compose([T.ToTensor()])

# Dataset transform: resize to img_size x img_size, convert to a tensor in
# [0, 1], then map to [-1, 1] (matching the generator's Tanh output range).
# The mean/std tuples carry one entry per image channel (args.n_ch): giving
# three entries for single-channel images makes current torchvision fail with
# a shape-broadcast error inside Normalize.
load_norm = T.Compose([T.Resize((args.img_size,args.img_size)),
                       T.ToTensor(),
                       T.Normalize((0.5,) * args.n_ch,
                                   (0.5,) * args.n_ch)])



In [165]:
class To_Image(nn.Module):
    """Reshape a batch of flat feature vectors into square feature maps.

    Args:
        img_size: side length (height and width) of the produced maps.

    ``forward`` takes a tensor whose first dimension is the batch and returns
    a view of shape ``(batch, C, img_size, img_size)``, where ``C`` is inferred
    from the number of remaining elements.
    """
    def __init__(self,img_size):
        super(To_Image, self).__init__()
        self.img_size = img_size
    def forward(self,x):
        # Use the batch size of the incoming tensor rather than a global
        # setting, so the layer works for any batch (partial batches, sampling).
        return x.view(x.size(0), -1, self.img_size, self.img_size)
    
class UpsampleBlock(nn.Module):
    """Double the spatial size, then apply a 3x3 convolution.

    The input is enlarged by a factor of two with ``nn.Upsample`` (default
    interpolation mode) and passed through a stride-1, padding-1 3x3
    convolution that maps ``in_dim`` channels to ``out_dim`` channels, so the
    convolution itself leaves the enlarged height and width unchanged.
    """

    def __init__(self, in_dim, out_dim):
        super(UpsampleBlock, self).__init__()
        enlarge = nn.Upsample(scale_factor=2)
        mix_channels = nn.Conv2d(in_dim, out_dim, kernel_size=3, stride=1, padding=1)
        # Kept as a two-element Sequential named `model` so parameter names
        # in the state dict stay `model.1.*`.
        self.model = nn.Sequential(enlarge, mix_channels)

    def forward(self, x):
        upsampled = self.model(x)
        return upsampled
    
class Generator(nn.Module):
    """Class-conditional generator: label batch -> image batch.

    The network maps the concatenation of a random noise vector
    (``args.z_dim``) and a one-hot label (``args.y_dim``) through two linear
    layers, reshapes the result into ``2 * args.n_dim`` feature maps of side
    ``args.img_size // 4``, and upsamples twice (x2 each) down to
    ``args.n_ch`` channels. The final ``Tanh`` keeps outputs in ``[-1, 1]``.
    """
    def __init__(self,):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(args.z_dim + args.y_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2, inplace = True),
            nn.Linear(1024, 2*args.n_dim*(args.img_size//4)*(args.img_size//4)),
            To_Image((args.img_size//4)),
            UpsampleBlock(2*args.n_dim, args.n_dim),
            nn.BatchNorm2d(args.n_dim),
            nn.LeakyReLU(0.2, inplace=True),
            UpsampleBlock(args.n_dim, args.n_ch),
            nn.Tanh() 
        )
    def forward(self,y):
        """Generate one image per label in ``y``.

        Args:
            y: 1-D integer tensor of class indices in ``[0, args.y_dim)``.

        Returns:
            The output of ``self.model`` for freshly sampled noise combined
            with the one-hot encoding of ``y``.
        """
        # Size everything from the label batch and build it directly on the
        # device the weights live on, instead of relying on the global batch
        # size / device and a CPU -> device copy.
        dev = next(self.parameters()).device
        y = y.to(dev)
        n_samples = y.size(0)

        one_hot_y = torch.zeros(n_samples, args.y_dim, device=dev)
        one_hot_y.scatter_(1, y.unsqueeze(1), 1)  # set column y[i] of row i to 1
        z = torch.randn(n_samples, args.z_dim, device=dev)
        return self.model(torch.cat([z,one_hot_y],1))
    
class DownSampleBlock(nn.Module):
    """Strided convolution -> batch norm -> LeakyReLU(0.2).

    The 4x4 convolution uses stride 2 and padding 1 and maps ``in_dim``
    channels to ``out_dim`` channels.
    """

    def __init__(self, in_dim, out_dim):
        super(DownSampleBlock, self).__init__()
        stages = [
            nn.Conv2d(in_dim, out_dim, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(out_dim),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        ]
        # Same three-stage Sequential under the same attribute name, so the
        # state-dict keys (`model.0.*`, `model.1.*`) are unchanged.
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        reduced = self.model(x)
        return reduced
    
    
class Discriminator(nn.Module):
    """Shared convolutional trunk with a real/fake head and a class head.

    The trunk applies three ``DownSampleBlock`` stages followed by a 4x4
    convolution without padding, ending with ``args.n_dim * 8`` channels.
    The flattened trunk output feeds two linear heads:

    * ``C_``: one sigmoid unit (probability that the input is real),
    * ``Q_``: ``args.y_dim`` raw class logits.

    NOTE: the heads expect exactly ``args.n_dim * 8`` flattened features, i.e.
    the trunk must reduce the image to 1x1. That holds for 32x32 inputs
    (32 -> 16 -> 8 -> 4 -> 1); other image sizes need a different head width.
    """
    def __init__(self,):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            DownSampleBlock(args.n_ch, args.n_dim),
            DownSampleBlock(args.n_dim, args.n_dim*2),
            DownSampleBlock(args.n_dim*2, args.n_dim*4),
            nn.Conv2d(args.n_dim*4,args.n_dim*8,4,1,0)
        )

        # Head sizes follow the configuration instead of literal 512 / 10
        # (identical values for n_dim=64, y_dim=10).
        feat_dim = args.n_dim * 8
        self.C_ = nn.Linear(feat_dim, 1)
        self.Q_ = nn.Linear(feat_dim, args.y_dim)

    def forward(self,x):
        """Return ``(C, Q)`` for an image batch ``x``.

        ``C`` has shape ``(batch, 1)`` with values in ``(0, 1)``; ``Q`` has
        shape ``(batch, args.y_dim)`` and holds unnormalised class logits.
        """
        # Flatten per sample using the actual batch size of `x`.
        shared_tensor = self.model(x).view(x.size(0), -1)
        # torch.sigmoid replaces the deprecated F.sigmoid.
        C = torch.sigmoid(self.C_(shared_tensor))
        Q = self.Q_(shared_tensor)

        return C,Q
    
    
    
    
class ACGAN(nn.Module):
    """Auxiliary-classifier GAN: networks, optimizers, schedulers and bookkeeping.

    Calling the module (``gan(img, y)``) performs one discriminator update
    followed by one generator update on the given batch and records both
    losses. Nothing is returned; results are exposed as attributes:

    * ``D_loss`` / ``G_loss``: loss tensors of the latest step,
    * ``G_img``: the generator batch produced during the generator update,
    * ``train_hist``: per-step loss history (lists of floats),
    * ``progress_photo``: the first generated image of each recent step
      (at most ``args.img_save_freq`` entries).
    """
    def __init__(self,):
        super(ACGAN, self).__init__()
        
        self.D = Discriminator()
        self.G = Generator()
        
        self.G_optim = optim.Adam(self.G.parameters(), lr = args.G_lr , betas= (args.b1 ,args.b2))
        self.D_optim = optim.Adam(self.D.parameters(), lr = args.D_lr , betas= (args.b1 ,args.b2))

        # forward() steps these after every update, so they must exist on the
        # instance: halve each learning rate every 10000 steps.
        self.G_scheduler = optim.lr_scheduler.StepLR(self.G_optim, 10000, 0.5)
        self.D_scheduler = optim.lr_scheduler.StepLR(self.D_optim, 10000, 0.5)
        
        self.BCE = nn.BCELoss()          # real/fake loss on the sigmoid output
        self.CE = nn.CrossEntropyLoss()  # class loss on the raw logits
        
        self.train_hist = {}

        self.train_hist['G_loss'] = []
        self.train_hist['D_loss'] = []
        
        self.apply(self.weight_init)
        self.progress_photo = []
        
        # Fixed-size targets kept for code that reads these attributes;
        # forward() builds its targets from the discriminator output instead.
        self.real_label = torch.ones(args.batch_size,1).to(device)
        self.fake_label = torch.zeros(args.batch_size,1).to(device)
        
        
    def forward(self,img,y):
        """Run one discriminator step and one generator step.

        Args:
            img: batch of real images.
            y: integer class labels for ``img``; also used to condition the
               generated images.
        """
        img = img.to(device)
        y = y.to(device)
        
        
        ######### Train D #########
        
        self.D_optim.zero_grad()
        
        D_real, C_real = self.D(img)
        # Targets are shaped like the discriminator output so any batch size works.
        real_target = torch.ones_like(D_real)
        D_real_loss = self.BCE(D_real, real_target)
        C_real_loss = self.CE(C_real, y)
        
        # detach(): the discriminator loss must not back-propagate into G.
        self.G_img = self.G(y).detach()
        D_fake, C_fake = self.D(self.G_img)
        D_fake_loss = self.BCE(D_fake, torch.zeros_like(D_fake))
        C_fake_loss = self.CE(C_fake, y)
        
        self.D_loss = D_real_loss + C_real_loss + D_fake_loss + C_fake_loss
        self.train_hist['D_loss'].append(self.D_loss.item())
        self.D_loss.backward()
        self.D_optim.step()
        
        ######### Train G #########
        
        self.G_optim.zero_grad()
        
        self.G_img = self.G(y)
        D_fake, C_fake = self.D(self.G_img)
        
        # The generator is rewarded when D labels its images as real and
        # assigns them the requested class.
        D_fake_loss = self.BCE(D_fake, torch.ones_like(D_fake))
        C_fake_loss = self.CE(C_fake, y)
        
        self.G_loss = D_fake_loss + C_fake_loss
        self.train_hist['G_loss'].append(self.G_loss.item())
        self.G_loss.backward()
        self.G_optim.step()
        
        ######### Adjust lr ######### 
        
        self.G_scheduler.step()
        self.D_scheduler.step()
        
        
        # Keep only the most recent img_save_freq samples for image_save().
        self.progress_photo.append(self.G_img[0].detach())
        self.progress_photo = self.progress_photo[-args.img_save_freq:]
        
        
    def weight_init(self,m):
        """Kaiming-normal init (leaky-ReLU gain, slope 0.2) for conv/linear weights."""
        if type(m) in [nn.Conv2d, nn.ConvTranspose2d, nn.Linear]:
            #nn.init.xavier_normal_(m.weight,nn.init.calculate_gain('leaky_relu',param=0.02))
            nn.init.kaiming_normal_(m.weight,0.2,nonlinearity='leaky_relu')
            
    def image_save(self, step):
        """Write the collected progress images as one grid PNG under ``args.img_path``."""
        img_save_path = args.img_path + "ACGAN_Step_"+str(step)+".png"
        grid_input = torch.stack(self.progress_photo[:args.img_save_freq])
        try:
            # Current torchvision spells the normalisation bounds `value_range`.
            save_image(grid_input, img_save_path, nrow=10, normalize=True, value_range=(-1,1))
        except TypeError:
            # Older torchvision releases only know the former `range` keyword.
            save_image(grid_input, img_save_path, nrow=10, normalize=True, range=(-1,1))
        print('Image saved')  
        
    def model_save(self,step):
        """Save this module's ``state_dict`` (under key ``'ACGAN'``) to ``args.model_path``."""
        path = args.model_path + 'ACGAN_Step_' + str(step) + '.pth'
        torch.save({'ACGAN':self.state_dict()}, path)
        print('Model saved')
        
    def load_step_dict(self, step):
        """Load the weights written by ``model_save(step)``; tensors are mapped to CPU storage first."""
        path = args.model_path + 'ACGAN_Step_' + str(step) + '.pth'
        self.load_state_dict(torch.load(path, map_location=lambda storage, loc: storage)['ACGAN'])
 
    def plot_all_loss(self,step):
        """Plot every series in ``train_hist`` and save it as ``ACGAN_Loss_<step>.png`` in the working directory."""
        fig, ax = plt.subplots(figsize= (20,8))
        for k in self.train_hist.keys():
            plt.plot(self.train_hist[k], label= k)
        plt.ylabel('Loss',fontsize=15)
        plt.xlabel('Number of Steps',fontsize=15)
        plt.title('Loss',fontsize=30,fontweight ="bold")
        plt.legend(loc = 'upper left')
        fig.savefig("ACGAN_Loss_"+str(step)+".png")
        
    def num_all_params(self,):
        """Return the total number of scalar parameters in the module."""
        return sum([param.nelement() for param in self.parameters()])


In [166]:
# MNIST training split, downloaded on first use and preprocessed by `load_norm`.
# drop_last=True discards a trailing partial batch so every step sees exactly
# args.batch_size samples; BatchNorm cannot train on a batch of one, which a
# leftover batch could otherwise produce when batch_size does not divide the
# dataset size.
training_loader= DataLoader(datasets.MNIST('./data/mnist',
                                        train=True, download=True,
                                        transform=load_norm),
                         batch_size= args.batch_size, shuffle=True,
                         drop_last=True)

In [169]:
# Build the full model and move its parameters to the selected device.
gan = ACGAN().to(device)

# Counters shared with the training-loop cell: epochs completed so far and the
# 1-based index of the next optimisation step.
epoch, all_steps = 0, 1

In [None]:
# Halve each learning rate every 10000 scheduler steps.
G_scheduler = optim.lr_scheduler.StepLR(gan.G_optim, step_size=10000, gamma=0.5)
D_scheduler = optim.lr_scheduler.StepLR(gan.D_optim, step_size=10000, gamma=0.5)

# ACGAN.forward() steps the schedulers through `self.G_scheduler` /
# `self.D_scheduler`, so they must be reachable as attributes of the model
# instance, not only as notebook globals.
gan.G_scheduler = G_scheduler
gan.D_scheduler = D_scheduler

In [None]:
# Hard cap on the number of optimisation steps for this run. Reaching it ends
# the loops cleanly instead of raising an exception out of the cell.
max_steps = 5000

while epoch < args.n_epoch and all_steps <= max_steps:
    for img, y in training_loader:

        # One discriminator update + one generator update (also steps the
        # learning-rate schedulers inside the model).
        start_t = time.time()
        gan(img,y)
        end_t = time.time()

        print('| Step [%d] | lr [%.6f] | D Loss: [%.4f] | G Loss: [%.4f] | Time: %.1fs' %\
              ( all_steps, gan.G_optim.param_groups[0]['lr'], gan.D_loss.item(), gan.G_loss.item(),
               end_t - start_t))

        # Each cadence is checked on its own so the three frequencies do not
        # have to be multiples of one another.
        if all_steps % args.show_freq == 0:
            fig=plt.figure(figsize=(8, 8))
            fig.add_subplot(1,3,1)
            # detach() drops the autograd graph before conversion for display;
            # *0.5+0.5 maps the Tanh range [-1, 1] back to [0, 1].
            plt.imshow(to_img(gan.G_img[0].detach().cpu()*0.5+0.5))
            plt.show()
        if all_steps % args.img_save_freq == 0:
            gan.image_save(all_steps)
            gan.plot_all_loss('Training')
        if all_steps % args.model_save_freq == 0:
            gan.model_save(all_steps)

        all_steps += 1
        if all_steps > max_steps:
            break
    else:
        # Only count the epoch when the loader was fully consumed.
        epoch +=1


| Step [1] | lr [0.002000] | D Loss: [6.9294] | G Loss: [49.2353] | Time: 0.6s
| Step [2] | lr [0.002000] | D Loss: [166.3793] | G Loss: [74.9197] | Time: 0.3s
| Step [3] | lr [0.002000] | D Loss: [91.3975] | G Loss: [78.4208] | Time: 0.4s
| Step [4] | lr [0.002000] | D Loss: [136.8543] | G Loss: [57.6949] | Time: 0.4s
| Step [5] | lr [0.002000] | D Loss: [98.7267] | G Loss: [93.1205] | Time: 0.3s
| Step [6] | lr [0.002000] | D Loss: [116.3149] | G Loss: [42.0056] | Time: 0.4s
| Step [7] | lr [0.002000] | D Loss: [101.4831] | G Loss: [65.8817] | Time: 0.4s
| Step [8] | lr [0.002000] | D Loss: [54.7411] | G Loss: [56.8676] | Time: 0.3s
| Step [9] | lr [0.002000] | D Loss: [195.0187] | G Loss: [42.7245] | Time: 0.3s
| Step [10] | lr [0.002000] | D Loss: [87.7361] | G Loss: [53.9311] | Time: 0.4s
| Step [11] | lr [0.002000] | D Loss: [111.3234] | G Loss: [52.9931] | Time: 0.4s
| Step [12] | lr [0.002000] | D Loss: [101.7921] | G Loss: [106.1222] | Time: 0.4s
| Step [13] | lr [0.002000] | 

In [None]:
# Manual learning-rate decay: halve the rate of the first parameter group of
# both optimizers. Every execution of this cell halves them again.
gan.G_optim.param_groups[0]['lr'] = gan.G_optim.param_groups[0]['lr'] * 0.5
gan.D_optim.param_groups[0]['lr'] = gan.D_optim.param_groups[0]['lr'] * 0.5