In [None]:
import numpy as np 
from PIL import Image 
import matplotlib.pyplot as plt 
import os 
import torch as T 
import torch.nn as nn 
import torch.nn.functional as F 
import torchvision.transforms as trans
from torch.utils.data import Dataset, DataLoader 
import time 

# Global Variables

In [None]:
# Adversarial objective selected in trainer_d / trainer_g.
mode = 'bgan' # 'gan', 'lsgan', 'rsgan', 'bgan'

# Directory that holds the training images.
path_image = "/home/wwang/datasets/celeba/img_align_celeba/"
crop_size = 178     # side length of the centre crop taken before resizing
image_size = 64     # side length the cropped images are resized to
channel_g = 64      # base channel width of the generator
channel_d = 64      # base channel width of the discriminator
row = 2             # rows of the preview grid
col = 4             # columns of the preview grid
latent_dim = 100    # number of channels of the latent noise input
# Use the first GPU when there is one; otherwise fall back to the CPU so the
# notebook still runs (slowly) instead of failing on the next line.
device = T.device("cuda:0" if T.cuda.is_available() else "cpu")
# Noise that is reused for every preview so successive grids are comparable.
fixed_noise = T.randn([row*col, latent_dim, 1, 1], device=device)
batch_size = 128
epochs = 200        # index of the last epoch + 1 (training resumes from the saved epoch)
learning_rate_d = 2e-4
learning_rate_g = 2e-4
step_g = 1          # the generator is updated once every `step_g` discriminator updates
# Checkpoints are written to / read from a per-mode directory.
path_work = '/home/wwang/wwgeneration/work/' + mode + '/'
num_workers = 4     # DataLoader worker processes used during training

# Data

In [None]:
class Dataset_gan(Dataset):
    """Unlabelled image dataset made of every file found in one directory.

    Each sample is centre-cropped to ``crop_size``, resized to ``image_size``
    with bicubic interpolation, randomly flipped horizontally, converted to a
    tensor and normalised with mean 0.5 / std 0.5 on each of three channels.
    """
    #
    def __init__(self, path=path_image):
        """List the files under ``path`` and build the preprocessing pipeline.

        Args:
            path: directory containing the image files.
        """
        self.path = path
        # os.listdir returns entries in arbitrary order; sorting makes an index
        # refer to the same file on every run.
        self.dirs = sorted(os.listdir(path))
        self.transform = trans.Compose([trans.CenterCrop(crop_size),
                                        trans.Resize(image_size, Image.BICUBIC),
                                        trans.RandomHorizontalFlip(),
                                        trans.ToTensor(),
                                        trans.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                                        ])
    #
    def __getitem__(self, idx):
        """Return the preprocessed image tensor stored at position ``idx``."""
        # os.path.join works whether or not ``path`` ends with a separator.
        with Image.open(os.path.join(self.path, self.dirs[idx])) as handle:
            # Normalize above expects exactly three channels, so grayscale or
            # RGBA files are converted; convert() also loads the pixel data
            # before the file is closed.
            image = handle.convert('RGB')
        return self.transform(image)
    #
    def __len__(self):
        """Return the number of files listed in the directory."""
        return len(self.dirs)

In [None]:
# Preview the first two preprocessed samples side by side in one figure.
dataloader = DataLoader(Dataset_gan(), batch_size=1)
plt.figure(figsize=(8, 4))
for n, image in enumerate(dataloader):
    if n == 2:
        break
    # Drop the batch axis and move channels last, which is what imshow expects.
    image = image.numpy().squeeze().transpose([1,2,0])
    # Each sample gets its own slot (the counter selects the subplot).
    plt.subplot(1, 2, n+1)
    # Undo the mean-0.5 / std-0.5 normalisation applied by the dataset.
    plt.imshow(image/2+0.5)
plt.show()

# Network Structure

## Generator

In [None]:
class CNA(nn.Module):
    """Convolution, optional normalisation and optional activation, applied in that order.

    Args:
        in_channels, out_channels, kernel_size, stride, padding: forwarded to
            the convolution layer, which is created without a bias term.
        conv: ``'conv'`` for ``nn.Conv2d`` or ``'tconv'`` for ``nn.ConvTranspose2d``.
        norm: ``'bn'`` appends ``nn.BatchNorm2d(out_channels)``; any other value
            (the file passes ``False``) appends nothing.
        act: ``'relu'`` appends ``nn.ReLU()``, ``'lrelu'`` appends
            ``nn.LeakyReLU(0.2)``; any other value appends nothing.

    Raises:
        ValueError: if ``conv`` is neither ``'conv'`` nor ``'tconv'``.
    """
    #
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, conv, norm, act):
        super().__init__()
        if conv == 'conv':
            conv_layer = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
        elif conv == 'tconv':
            conv_layer = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
        else:
            # Fail here with a clear message instead of later with a missing
            # ``layers`` attribute.
            raise ValueError("conv must be 'conv' or 'tconv', got {!r}".format(conv))
        # The attribute name and the layer order determine the state_dict keys,
        # so they must stay as they are for saved checkpoints to load.
        self.layers = nn.ModuleList([conv_layer])
        if norm == 'bn':
            self.layers.append(nn.BatchNorm2d(out_channels))
        if act == 'relu':
            self.layers.append(nn.ReLU())
        elif act == 'lrelu':
            self.layers.append(nn.LeakyReLU(0.2))
    #
    def forward(self, x):
        """Pass ``x`` through the configured layers one after another."""
        for layer in self.layers:
            x = layer(x)
        return x

In [None]:
class Generator(nn.Module):
    """Stack of transposed-convolution blocks that turns latent noise into an image.

    The first block uses a 4x4 kernel with stride 1 and no padding; each later
    block uses a 4x4 kernel with stride 2 and padding 1. The last block outputs
    3 channels with no normalisation or activation and is followed by ``Tanh``.

    Args:
        latent_dim: number of channels of the noise input.
        channel: base channel width; hidden blocks use 8x, 4x, 2x and 1x of it.
    """
    #
    def __init__(self, latent_dim=latent_dim, channel=channel_g):
        super().__init__()
        # Channel counts at the boundaries of the hidden blocks.
        widths = [latent_dim, channel*8, channel*4, channel*2, channel]
        # Entry block: kernel 4, stride 1, padding 0.
        blocks = [CNA(widths[0], widths[1], 4, 1, 0, 'tconv', 'bn', 'relu')]
        # Remaining hidden blocks: kernel 4, stride 2, padding 1.
        for c_in, c_out in zip(widths[1:-1], widths[2:]):
            blocks.append(CNA(c_in, c_out, 4, 2, 1, 'tconv', 'bn', 'relu'))
        # Output block: 3 channels, bare transposed convolution, then Tanh.
        blocks.append(CNA(widths[-1], 3, 4, 2, 1, 'tconv', False, False))
        blocks.append(nn.Tanh())
        self.layers = nn.Sequential(*blocks)
    #
    def forward(self, x):
        """Map the noise tensor ``x`` to the generated image tensor."""
        out = self.layers(x)
        return out

## Discriminator

In [None]:
class Discriminator(nn.Module):
    """Stack of strided convolution blocks that scores an image with one raw value.

    Four blocks with 4x4 kernels, stride 2 and padding 1 (LeakyReLU activations,
    batch norm on all but the first) are followed by a 4x4 convolution with
    stride 1 and no padding that outputs a single channel. No sigmoid is applied.

    Args:
        channel: base channel width; the blocks use 1x, 2x, 4x and 8x of it.
    """
    #
    def __init__(self, channel=channel_d):
        super().__init__()
        # Channel counts at the boundaries of the strided blocks.
        widths = [3, channel, channel*2, channel*4, channel*8]
        # The first block has no normalisation layer.
        blocks = [CNA(widths[0], widths[1], 4, 2, 1, 'conv', False, 'lrelu')]
        for c_in, c_out in zip(widths[1:-1], widths[2:]):
            blocks.append(CNA(c_in, c_out, 4, 2, 1, 'conv', 'bn', 'lrelu'))
        # Final block: single output channel, no normalisation, no activation.
        blocks.append(CNA(widths[-1], 1, 4, 1, 0, 'conv', False, False))
        self.layers = nn.Sequential(*blocks)
    #
    def forward(self, x):
        """Score ``x`` and return the result flattened to a single column."""
        scores = self.layers(x)
        return scores.reshape(-1, 1)

# Show

In [None]:
@T.no_grad()
def show(net=None, row=2, col=4, fix=True):
    """Plot a ``row`` x ``col`` grid of images produced by a generator.

    Also prints the maximum and minimum value of the generated batch.

    Args:
        net: generator to sample from; a new, untrained ``Generator`` is
            created when omitted.
        row, col: grid dimensions.
        fix: if True, use the module-level ``fixed_noise``; otherwise draw
            fresh noise.

    Side effects:
        ``net`` is switched to eval mode and left that way; callers that keep
        training must call ``net.train()`` again.
    """
    if net is None:
        net = Generator().to(device)
    net.eval()
    count = row * col
    if fix:
        # fixed_noise was sized from the module-level grid; take at most
        # ``count`` entries so a smaller grid does not generate unused images.
        noise = fixed_noise[:count]
    else:
        noise = T.randn([count, latent_dim, 1, 1], device=device)
    # Move channels last, which is the layout imshow expects.
    images = net(noise).cpu().numpy().transpose([0,2,3,1])
    print(np.max(images), np.min(images))
    plt.figure(figsize=(16,8))
    # A grid larger than fixed_noise is filled as far as the images go instead
    # of raising IndexError.
    for i in range(min(count, len(images))):
        # The generator ends in Tanh, so shift its output into [0, 1] for display.
        image = (images[i,:,:,:] + 1) / 2
        plt.subplot(row, col, i+1)
        plt.imshow(image)
    plt.show()

In [None]:
# Sanity check: with no arguments, show() builds a new (untrained) Generator and plots its output for fixed_noise.
show()

# Train

In [None]:
def trainer_d(net_d, optimizer_d, net_g, image_real, mode):
    """Perform one optimisation step of the discriminator.

    Args:
        net_d: discriminator; returns one raw score per image.
        optimizer_d: optimiser holding the parameters of ``net_d``.
        net_g: generator used to produce the fake batch (it is not updated here).
        image_real: batch of real images; the fake batch has the same size and
            is created on the same device.
        mode: one of ``'gan'``, ``'bgan'``, ``'lsgan'``, ``'rsgan'``.

    Raises:
        ValueError: if ``mode`` is not one of the supported objectives.
    """
    # Reject bad modes before any forward pass is run.
    if mode not in ('gan', 'bgan', 'lsgan', 'rsgan'):
        raise ValueError("unknown mode {!r}; expected 'gan', 'lsgan', 'rsgan' or 'bgan'".format(mode))
    #
    src_real = net_d(image_real)
    # Only the discriminator is updated, so the generator pass needs no
    # autograd graph.
    with T.no_grad():
        noise = T.randn([image_real.shape[0], latent_dim, 1, 1], device=image_real.device)
        image_fake = net_g(noise)
    src_fake = net_d(image_fake)
    #
    if mode == 'gan' or mode == 'bgan':
        # Real images are labelled 1 and fakes 0.
        loss = F.binary_cross_entropy_with_logits(src_real, T.ones_like(src_real)) + F.binary_cross_entropy_with_logits(src_fake, T.zeros_like(src_fake))
    elif mode == 'lsgan':
        # Squared distance of real scores to 1 and of fake scores to 0.
        loss = T.mean((1-src_real)**2) + T.mean(src_fake**2)
    else:  # 'rsgan'
        # Relativistic objective: the loss is computed on score differences.
        loss = F.binary_cross_entropy_with_logits(src_real-src_fake, T.ones_like(src_real)) + F.binary_cross_entropy_with_logits(src_fake-src_real, T.zeros_like(src_fake))
    #
    optimizer_d.zero_grad()
    loss.backward()
    optimizer_d.step()

In [None]:
def trainer_g(net_g, optimizer_g, net_d, image_real, mode):
    """Perform one optimisation step of the generator.

    Args:
        net_g: generator being updated.
        optimizer_g: optimiser holding the parameters of ``net_g``.
        net_d: discriminator used to score the generated batch.
        image_real: batch of real images; it sets the size and device of the
            fake batch and is scored itself only in ``'rsgan'`` mode.
        mode: one of ``'gan'``, ``'lsgan'``, ``'rsgan'``, ``'bgan'``.

    Raises:
        ValueError: if ``mode`` is not one of the supported objectives.
    """
    # Reject bad modes before any forward pass is run.
    if mode not in ('gan', 'lsgan', 'rsgan', 'bgan'):
        raise ValueError("unknown mode {!r}; expected 'gan', 'lsgan', 'rsgan' or 'bgan'".format(mode))
    #
    noise = T.randn([image_real.shape[0], latent_dim, 1, 1], device=image_real.device)
    image_fake = net_g(noise)
    src_fake = net_d(image_fake)
    #
    if mode == 'gan':
        # Fakes are labelled 1 for the generator update.
        loss = F.binary_cross_entropy_with_logits(src_fake, T.ones_like(src_fake))
    elif mode == 'lsgan':
        loss = T.mean((1-src_fake)**2)
    elif mode == 'rsgan':
        # The real scores do not depend on the generator, so they need no graph.
        with T.no_grad():
            src_real = net_d(image_real)
        loss = F.binary_cross_entropy_with_logits(src_fake-src_real, T.ones_like(src_real)) + F.binary_cross_entropy_with_logits(src_real-src_fake, T.zeros_like(src_fake))
    else:  # 'bgan'
        # 0.5 * (log D - log(1 - D))^2 with D = sigmoid(src_fake).
        # log(1 - sigmoid(s)) equals logsigmoid(-s); using logsigmoid keeps both
        # terms finite and differentiable where sigmoid saturates, which the
        # epsilon-guarded log(sigmoid(.)) form does not.
        log_ratio = F.logsigmoid(src_fake) - F.logsigmoid(-src_fake)
        loss = T.mean(log_ratio ** 2) / 2
    #
    optimizer_g.zero_grad()
    loss.backward()
    optimizer_g.step()

In [None]:
def train(epochs, load_model):
    """Train the generator and the discriminator, checkpointing after every epoch.

    Args:
        epochs: training runs for epoch indices up to ``epochs - 1``.
        load_model: if True, restore both networks and the last finished epoch
            index from ``path_work`` and continue from the following epoch;
            otherwise start from epoch 0 with newly initialised networks.

    Notes:
        Only the network weights and the epoch index are checkpointed; the
        optimiser states and the iteration counter start afresh on resume.
    """
    # Make sure the checkpoint directory exists before anything is saved.
    os.makedirs(path_work, exist_ok=True)
    #
    epoch_last = np.load(path_work + 'epoch.npy', allow_pickle=True).item() if load_model else -1
    net_d = Discriminator().to(device)
    net_d.train()
    net_g = Generator().to(device)
    net_g.train()
    optimizer_d = T.optim.Adam(net_d.parameters(), learning_rate_d, betas=(0.5,0.999))
    optimizer_g = T.optim.Adam(net_g.parameters(), learning_rate_g, betas=(0.5,0.999))
    if load_model:
        # map_location lets a checkpoint saved on another device be restored here.
        net_d.load_state_dict(T.load(path_work + 'net_d.pt', map_location=device))
        net_g.load_state_dict(T.load(path_work + 'net_g.pt', map_location=device))
    # Build the dataset once: it lists the whole image directory, and the
    # loader reshuffles by itself each time it is iterated.
    dataloader = DataLoader(Dataset_gan(), batch_size=batch_size, shuffle=True, num_workers=num_workers)
    #
    iteration = 0
    time_start = time.time()
    for epoch in range(epoch_last+1, epochs):
        for data in dataloader:
            image_real = data.to(device)
            trainer_d(net_d, optimizer_d, net_g, image_real, mode)
            iteration += 1
            # The generator is updated once every ``step_g`` discriminator updates.
            if iteration % step_g == 0:
                trainer_g(net_g, optimizer_g, net_d, image_real, mode)
            # Periodic visual progress report.
            if iteration % 1000 == 0:
                print('Epoch: ', epoch, ', Iteration: ', iteration)
                show(net_g)
                # show() leaves the generator in eval mode; switch back.
                net_g.train()
        # Main checkpoint and a second copy of it, then the epoch index. The
        # index is written last, so its presence implies the weights exist.
        T.save(net_d.state_dict(), path_work + 'net_d.pt')
        T.save(net_g.state_dict(), path_work + 'net_g.pt')
        T.save(net_d.state_dict(), path_work + 'net_d_backup.pt')
        T.save(net_g.state_dict(), path_work + 'net_g_backup.pt')
        np.save(path_work + 'epoch.npy', epoch)
        print('This epoch costs {} seconds.'.format(time.time() - time_start))
        time_start = time.time()

In [None]:
# train(epochs=epochs, load_model=False)

In [None]:
# Resume when a checkpoint is present, otherwise start from scratch, so the
# cell also works on a fresh run. epoch.npy is the last file written per epoch.
train(epochs=epochs, load_model=os.path.exists(path_work + 'epoch.npy'))

# Test

In [None]:
# Restore the saved generator weights and plot samples drawn from fresh noise.
net_g = Generator().to(device)
# map_location lets a checkpoint saved on another device be loaded here.
net_g.load_state_dict(T.load(path_work + 'net_g.pt', map_location=device))
show(net_g, fix=False)