In [1]:
import os
import time

import torch
import torch.nn as nn
import matplotlib.pyplot as plt

# Dark matplotlib theme for every figure drawn later in this notebook.
plt.style.use(style='dark_background')

In [2]:
from gan.waifu_data import Waifu
import torch.utils.data as data

batch_size = 64

# drop_last=True discards the trailing partial batch at the source: train()
# can only use full batches because its label tensors are sized by batch_size.
# pin_memory=True puts batches in page-locked host memory, which speeds up the
# per-batch host-to-GPU copy in the training loop.
WaifuLoader = data.DataLoader(
    dataset=Waifu(),
    batch_size=batch_size,
    num_workers=8,
    shuffle=True,
    drop_last=True,
    pin_memory=True,
)

In [3]:
from gan.dcgan import dcgan
import torch.optim as optim

# Length of the latent vector fed to the generator.
z_dim = 128
# Per-layer architecture settings handed to gan.dcgan.dcgan. Presumably entry k
# of each list configures layer k and channels[-1] == 3 is the RGB output —
# confirm against gan/dcgan.py.
channels = [512, 256, 128, 64, 3]
kernel_size = [4, 5, 4, 4, 4]
stride = [2, 3, 2, 2, 2]
padding = [0, 1, 1, 1, 1]
# Adam hyperparameters, shared by the generator and discriminator optimizers.
lr = 2e-4
betas = (0.5, 0.999)
# Seed torch's global RNG before building the models so that weight
# initialisation and latent-noise sampling are repeatable across runs
# (cuDNN kernels may still introduce some nondeterminism).
seed = 0
torch.manual_seed(seed)

generator, discriminator = dcgan(kernel_size, z_dim, channels, stride, padding)

# BCELoss expects probabilities, so the discriminator presumably ends in a
# sigmoid — confirm in gan/dcgan.py.
loss_func = nn.BCELoss()
def get_noise(size, device='cuda'):
    """Sample a batch of standard-normal latent vectors for the generator.

    Args:
        size: number of latent vectors (the batch dimension).
        device: device to allocate the tensor on; defaults to 'cuda', matching
            the previous behaviour.

    Returns:
        A tensor of shape (size, z_dim, 1, 1).
    """
    # Sampling directly on the target device avoids a CPU allocation plus a
    # host-to-device copy on every call (this runs once per training step).
    return torch.randn(size, z_dim, 1, 1, device=device)


In [4]:
# One Adam optimizer per network so each can be stepped independently; both
# use the lr / betas from the configuration cell.
g_optimizer = optim.Adam(params=generator.parameters(), lr=lr, betas=betas)
d_optimizer = optim.Adam(params=discriminator.parameters(), lr=lr, betas=betas)


In [5]:
from torchvision.utils import make_grid
from matplotlib.image import imsave

def sample_grid(numbers):
    """Generate `numbers` images and tile them into one (H, W, C) numpy array.

    The (H, W, C) layout is what matplotlib's imshow / imsave expect.
    """
    noise = get_noise(numbers)
    # Inference only: no_grad avoids building an autograd graph that would be
    # thrown away immediately.
    # NOTE(review): the generator is left in its current mode (train mode while
    # training), so any BatchNorm layers still update their running stats from
    # this sample batch — confirm whether generator.eval() is wanted here.
    with torch.no_grad():
        images = generator(noise).cpu()
    # nrow=4 -> 4 images per grid row; permute CHW -> HWC for matplotlib.
    return make_grid(images, 4).permute(1, 2, 0).numpy()


def example_filename(file_suffix=None, timestamp_ms=None, directory='generated'):
    """Build the output path for a saved sample grid.

    The stamp is the time in milliseconds modulo 1e6, zero-padded to six
    digits so the filename never contains spaces. The suffix is omitted
    when `file_suffix` is None.

    Args:
        file_suffix: optional tag (e.g. 'epoch001') appended after the stamp.
        timestamp_ms: override for the millisecond timestamp; defaults to now.
        directory: folder the file is placed in.
    """
    if timestamp_ms is None:
        timestamp_ms = int(time.time() * 1e3)
    name = f'waifu_{timestamp_ms % 1000000:06d}'
    if file_suffix is not None:
        name += f'_{file_suffix}'
    return os.path.join(directory, name + '.png')


def show_example(numbers=20, save_img=True, filename=None):
    """Display a grid of `numbers` freshly generated images inline.

    Args:
        numbers: how many latent vectors to sample.
        save_img: when True and `filename` is given, also write the grid there.
        filename: optional output path. Nothing is saved when it is None (the
            default), so the default call only displays.
    """
    grid = sample_grid(numbers)
    _, ax = plt.subplots(figsize=(10, 10))
    ax.axis('off')
    ax.imshow(grid)
    if save_img and filename is not None:
        imsave(filename, grid)
    plt.show()


def save_example(numbers=20, file_suffix=None):
    """Write a grid of `numbers` generated images into the 'generated' folder.

    Args:
        numbers: how many latent vectors to sample.
        file_suffix: optional tag appended to the filename (e.g. the epoch).
    """
    filename = example_filename(file_suffix)
    # imsave does not create parent directories, so make sure the folder exists.
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    imsave(filename, sample_grid(numbers))
        

In [6]:
import time
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()
# Fixed BCELoss targets: 1 = "real", 0 = "fake". They have a fixed batch
# dimension, so train() only scores full batches.
real_labels = torch.ones(batch_size, 1, device='cuda')
fake_labels = torch.zeros(batch_size, 1, device='cuda')


def train(epochs):
    """Alternately update the generator and the discriminator for `epochs` passes.

    Each iteration makes one generator step, pushing D(G(z)) towards "real".
    It then makes one discriminator step on a real batch and on the detached
    fake batch. Losses are logged to TensorBoard against a running iteration
    counter, and a sample grid is saved after every epoch.

    Args:
        epochs: number of full passes over WaifuLoader.
    """
    print('Training start')
    time_start = time.time()
    global_step = 0
    for e in range(epochs):
        for i, images in enumerate(WaifuLoader):
            # The label tensors are sized by batch_size, so a trailing partial
            # batch cannot be scored against them.
            if images.shape[0] != batch_size:
                continue

            # --- generator step: make the discriminator call fakes real ---
            generated_images = generator(get_noise(batch_size))
            d_fake_discrimination = discriminator(generated_images)
            g_loss = loss_func(d_fake_discrimination, real_labels)
            g_optimizer.zero_grad()
            g_loss.backward()
            g_optimizer.step()

            # --- discriminator step: real -> 1, fake -> 0 ---
            # detach() keeps this loss from backpropagating into the generator.
            # g_loss.backward() also left gradients on the discriminator's
            # parameters; d_optimizer.zero_grad() clears them before this step.
            d_real_discrimination = discriminator(images.cuda(non_blocking=True))
            d_fake_discrimination = discriminator(generated_images.detach())
            d_real_loss = loss_func(d_real_discrimination, real_labels)
            d_fake_loss = loss_func(d_fake_discrimination, fake_labels)
            d_loss = d_real_loss + d_fake_loss
            d_optimizer.zero_grad()
            d_loss.backward()
            d_optimizer.step()

            time_train = time.time() - time_start
            print(f'\repoch: {e+1}, iterations: {i+1}, time: {time_train:.2f} sec', end='')

            # Pass global_step explicitly: without it every point is stored
            # with step 0 and the TensorBoard curves collapse onto one x value.
            writer.add_scalar('Loss/generator', g_loss.item(), global_step)
            writer.add_scalar('Loss/discriminator', d_loss.item(), global_step)
            writer.add_scalar('Loss/discriminator_real', d_real_loss.item(), global_step)
            writer.add_scalar('Loss/discriminator_fake', d_fake_loss.item(), global_step)
            global_step += 1
        print()
        # Push this epoch's scalars to disk so TensorBoard shows them right away.
        writer.flush()
        save_example(file_suffix=f'epoch{e + 1:0>3}')


train(epochs=100)




Training start
epoch: 1, iterations: 800, time: 369.81 sec
epoch: 2, iterations: 800, time: 738.67 sec
epoch: 3, iterations: 800, time: 1106.66 sec
epoch: 4, iterations: 800, time: 1473.52 sec
epoch: 5, iterations: 800, time: 1840.43 sec
epoch: 6, iterations: 800, time: 2207.11 sec
epoch: 7, iterations: 800, time: 2574.65 sec
epoch: 8, iterations: 800, time: 2944.30 sec
epoch: 9, iterations: 800, time: 3312.85 sec
epoch: 10, iterations: 800, time: 3680.98 sec
epoch: 11, iterations: 800, time: 4049.10 sec
epoch: 12, iterations: 800, time: 4415.85 sec
epoch: 13, iterations: 800, time: 4786.10 sec
epoch: 14, iterations: 800, time: 5156.09 sec
epoch: 15, iterations: 800, time: 5526.53 sec
epoch: 16, iterations: 800, time: 5899.98 sec
epoch: 17, iterations: 800, time: 6271.11 sec
epoch: 18, iterations: 800, time: 6641.05 sec
epoch: 19, iterations: 800, time: 7010.43 sec
epoch: 20, iterations: 800, time: 7377.28 sec
epoch: 21, iterations: 800, time: 7745.76 sec
epoch: 22, iterations: 800, ti