In [1]:
import os

import numpy as np

import torch
import torch.nn as nn
import torch.nn.parallel
import torchvision.datasets as dsets
import torchvision.transforms as transforms

from visdom import Visdom
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from IPython.display import clear_output

In [2]:
# Visdom client with default connection settings; every plot in this notebook goes through `viz`.
viz = Visdom()
# No window id is passed, so this targets the whole 'main' environment
# (presumably clearing windows left over from earlier runs -- confirm against the Visdom docs).
viz.close(env='main')

def loss_tracker(loss_plot, loss_value, num):
    """Append new data to an existing Visdom line plot.

    Args:
        loss_plot: handle of the Visdom window to extend (passed as ``win``).
        loss_value: y-values to append (passed as ``Y``).
        num: x-values that go with ``loss_value`` (passed as ``X``).
    """
    # update='append' extends the existing traces instead of redrawing the window.
    viz.line(Y=loss_value, X=num, update='append', win=loss_plot)

Setting up a new session...


In [3]:
class G(nn.Module):
    """Fully-connected generator: maps a noise vector to a flattened image.

    The final Sigmoid keeps every output value in [0, 1].

    Args:
        noise: width of the input noise vector.
        hidden: width of the single hidden layer.
        out_features: number of values produced per sample (28*28 by default).
    """

    def __init__(self, noise=128, hidden=256, out_features=28*28):
        super(G, self).__init__()

        self.flatten = torch.nn.Flatten()
        self.G = torch.nn.Sequential(
            torch.nn.Linear(noise, hidden),
            torch.nn.ReLU(True),
            torch.nn.Linear(hidden, out_features),
            torch.nn.Sigmoid()
        )

    def forward(self, x):
        """Generate flattened images from noise.

        Args:
            x: noise of shape (N, noise) or (noise,). Inputs with extra trailing
               dimensions, e.g. (N, noise, 1, 1), are flattened to (N, noise) first.

        Returns:
            Tensor whose last dimension is ``out_features``, with values in [0, 1].
        """
        # Only flatten when there is something to flatten: Flatten() starts at dim 1,
        # so 1-D / 2-D inputs are passed through exactly as before.
        if x.dim() > 2:
            x = self.flatten(x)
        return self.G(x)

class D(nn.Module):
    """Fully-connected discriminator: maps an image to a single value in [0, 1].

    Args:
        in_features: number of values per (flattened) input sample (28*28 by default).
        hidden: width of the single hidden layer.
    """

    def __init__(self, in_features=28*28, hidden=256):
        super(D, self).__init__()

        self.flatten = torch.nn.Flatten()
        self.D = torch.nn.Sequential(
            torch.nn.Linear(in_features, hidden),
            torch.nn.ReLU(True),
            torch.nn.Linear(hidden, 1),
            torch.nn.Sigmoid()
        )

    def forward(self, x):
        """Score a batch of samples.

        Args:
            x: either already-flattened samples of shape (N, in_features) /
               (in_features,), or image batches such as (N, 1, 28, 28), which are
               flattened to (N, in_features) before scoring.

        Returns:
            Tensor whose last dimension is 1, with values in [0, 1].
        """
        # Flatten() starts at dim 1, so it is applied only to inputs with more than
        # two dimensions; flattened or unbatched inputs keep the original behaviour.
        if x.dim() > 2:
            x = self.flatten(x)
        return self.D(x)

In [4]:
# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    # Enable cuDNN benchmark mode on GPU runs.
    torch.backends.cudnn.benchmark = True

# Instantiate both networks with their default sizes and move them to the chosen device.
Gen = G().to(device)
Dis = D().to(device)

# Print a layer-by-layer summary of each network for a batch of 16 samples.
from torchsummaryM import summary
summary(Gen, (16, 128), device=device)
summary(Dis, (16, 28*28), device=device)

# One Adam optimizer per network, both with learning rate 2e-4.
g_optimizer = torch.optim.Adam(Gen.parameters(), lr=2e-4)
d_optimizer = torch.optim.Adam(Dis.parameters(), lr=2e-4)

----------------------------------------------------------------------------
Layer(type)     ||        Kernel Shape         Output Shape         Param #
G Inputs        ||                   -            [16, 128]               -
                ||                                                         
1> G-G-Linear   ||          [128, 256]            [16, 256]          33,024
2> G-G-ReLU     ||                   -            [16, 256]               0
3> G-G-Linear   ||          [256, 784]            [16, 784]         201,488
4> G-G-Sigmoid  ||                   -            [16, 784]               0
Total params: 234,512
Trainable params: 234,512
Non-trainable params: 0
----------------------------------------------------------------------------
Input size (MB): 0.12
Forward/backward pass size (MB): 0.25
Params size (MB): 0.89
Estimated Total Size (MB): 1.27
----------------------------------------------------------------------------

-------------------------------------------------

In [5]:
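# Optional multi-GPU support (currently disabled): wrap both networks in
# nn.DataParallel when more than one CUDA device is visible.
# NOTE: `device` is a torch.device object, so test `device.type` rather than
# comparing the device itself with the string 'cuda'.
# ngpu = torch.cuda.device_count()

# if (device.type == 'cuda') and (ngpu > 1):
#     Dis = nn.DataParallel(Dis, list(range(ngpu)))

# if (device.type == 'cuda') and (ngpu > 1):
#     Gen = nn.DataParallel(Gen, list(range(ngpu)))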

In [6]:
# Prepare the training data: the MNIST training split, downloaded into ./MNIST
# on first use, with every image converted to a tensor by ToTensor().
dataset = dsets.MNIST(
    root='./MNIST',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
    target_transform=None,
)

# Shuffled mini-batches of 512 samples.
data_loader = DataLoader(dataset, batch_size=512, shuffle=True)

In [7]:
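# Alternative dataset (currently disabled): images from a local `img_align_celeba`
# folder, converted to 28x28 grayscale so they match the input size of the networks.
# NOTE: `root` below is an absolute path on the author's machine; change it before enabling.
# train_dataset = dsets.ImageFolder(root="/Users/honeyeob/python_workspace/workspace/Pytorch/GAN/img_align_celeba",
#                 transform=transforms.Compose([
#                     transforms.Grayscale(),
#                     transforms.Resize((28, 28)),
#                     transforms.ToTensor()
#                 ])
# )
# data_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=100, shuffle=True)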

In [8]:
# Fixed Noise for Testing
# The same 25 noise vectors are fed to the generator for every preview, so the
# rendered samples can be compared across training steps.
# (The deprecated Variable wrapper is dropped; tensors work with autograd directly.)
test_noise = torch.randn(25, 128).to(device)

# Placeholder image window (random pixels); the training loop redraws it in place
# by passing the returned window handle back as `win`.
image_window = viz.images(torch.randn(25, 1, 28, 28),
                          opts=dict(title="Generated Imgs",
                                    caption="Generated Image-{}-{}".format(0, 0)))

# Loss plot with two series, seeded with a single all-zero point.
# torch.zeros() replaces randn().zero_(), which drew random numbers only to discard them.
loss_plt = viz.line(Y=torch.zeros(1, 2),
                    opts=dict(title='Tracking Losses',
                              legend=['D_Loss', 'G_Loss'],  # was 'G_aLoss' (typo)
                              showlegend=True))

In [11]:
# GAN training loop with per-step loss tracking.
# For every mini-batch: one discriminator update, then one generator update.
# Every 10th step the losses and a preview grid are pushed to Visdom; every 10th
# epoch the preview grid is also written to ./IMGS.
NUM_EPOCHS = 200
NOISE_DIM = 128   # width of the noise vectors fed to Gen
LOG_EPS = 1e-6    # added INSIDE log() so the losses stay finite if Dis outputs exactly 0 or 1

total_step = 0
total_batch = len(data_loader)

# save_image() does not create missing directories; create the target up front
# rather than failing at the first checkpoint epoch.
os.makedirs('./IMGS', exist_ok=True)

for epoch in range(NUM_EPOCHS):
    for step, data in enumerate(data_loader):
        # data[0] holds the image batch; the rest of `data` is not used.
        images = data[0].to(device)
        batch_size = images.size(0)
        real_flat = images.reshape(batch_size, -1)  # one row per image for the linear Dis

        # ---- Train D: minimise -[log D(x) + log(1 - D(G(z)))] ----
        noise = torch.randn(batch_size, NOISE_DIM).to(device)
        # detach(): the discriminator update must not back-propagate into the generator.
        fake_images = Gen(noise).detach()
        dis_fake_results = Dis(fake_images)
        dis_real_results = Dis(real_flat)

        d_loss = -torch.mean(torch.log(dis_real_results + LOG_EPS)
                             + torch.log(1 - dis_fake_results + LOG_EPS))
        Dis.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # ---- Train G: minimise -log D(G(z)) on a fresh noise batch ----
        noise = torch.randn(batch_size, NOISE_DIM).to(device)
        fake_images = Gen(noise)
        dis_fake_results = Dis(fake_images)
        g_loss = -torch.mean(torch.log(dis_fake_results + LOG_EPS))

        Gen.zero_grad()
        g_loss.backward()
        g_optimizer.step()
        total_step += 1

        # Print & Showing via Visdom
        if (step + 1) % 10 == 0:
            d_loss_value = d_loss.item()
            g_loss_value = g_loss.item()

            clear_output(wait=True)
            print('Epoch [%d/%d], Step[%d/%d], d_loss: %.4f, g_loss: %.4f'
                  % (epoch + 1, NUM_EPOCHS, step + 1, total_batch, d_loss_value, g_loss_value))

            # Preview only: no autograd graph is needed for these samples.
            with torch.no_grad():
                sample_images = Gen(test_noise).reshape(-1, 1, 28, 28).cpu()

            # One (1, 2) row per update: column 0 is the D series, column 1 the G series.
            loss_tracker(loss_plt,
                         np.array([[d_loss_value, g_loss_value]], dtype=np.float32),
                         np.array([[total_step, total_step]], dtype=np.float32))

            image_window = viz.images(sample_images,
                                      opts=dict(title="Generated Imgs",
                                                caption="Generated Image-{}-{}".format(epoch + 1, step + 1)),
                                      win=image_window)

        # Image Save: last step of every 10th epoch.
        if (epoch + 1) % 10 == 0 and (step + 1) == total_batch:
            with torch.no_grad():
                sample_images = Gen(test_noise).reshape(-1, 1, 28, 28)
            save_image(sample_images, './IMGS/generatedimage-%d-%d.png' % (epoch + 1, step + 1))

Epoch [104/200], Step[20/118], d_loss: 0.2935, g_loss: 3.1936


KeyboardInterrupt: 

In [None]:
# GAN training loop with per-epoch loss tracking.
# For every mini-batch: one discriminator update, then one generator update.
# Losses are averaged over each epoch and appended to the Visdom loss plot once
# per epoch; a preview grid is redrawn every 10th step and saved every 10th epoch.
NUM_EPOCHS = 200
NOISE_DIM = 128   # width of the noise vectors fed to Gen
LOG_EPS = 1e-6    # added INSIDE log() so the losses stay finite if Dis outputs exactly 0 or 1

total_step = 0
total_batch = len(data_loader)

# save_image() does not create missing directories; create the target up front
# rather than failing at the first checkpoint epoch.
os.makedirs('./IMGS', exist_ok=True)

for epoch in range(NUM_EPOCHS):
    # Running means of the per-step losses for this epoch.
    epoch_average_g_loss = 0
    epoch_average_d_loss = 0

    for step, data in enumerate(data_loader):
        # data[0] holds the image batch; the rest of `data` is not used.
        images = data[0].to(device)
        batch_size = images.size(0)
        real_flat = images.reshape(batch_size, -1)  # one row per image for the linear Dis

        # ---- Train D: minimise -[log D(x) + log(1 - D(G(z)))] ----
        noise = torch.randn(batch_size, NOISE_DIM).to(device)
        # detach(): the discriminator update must not back-propagate into the generator.
        fake_images = Gen(noise).detach()
        dis_fake_results = Dis(fake_images)
        dis_real_results = Dis(real_flat)

        d_loss = -torch.mean(torch.log(dis_real_results + LOG_EPS)
                             + torch.log(1 - dis_fake_results + LOG_EPS))
        Dis.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # ---- Train G: minimise -log D(G(z)) on a fresh noise batch ----
        noise = torch.randn(batch_size, NOISE_DIM).to(device)
        fake_images = Gen(noise)
        dis_fake_results = Dis(fake_images)
        g_loss = -torch.mean(torch.log(dis_fake_results + LOG_EPS))

        Gen.zero_grad()
        g_loss.backward()
        g_optimizer.step()
        total_step += 1

        # .item() yields plain Python floats, so the accumulators hold no tensors.
        d_loss_value = d_loss.item()
        g_loss_value = g_loss.item()
        epoch_average_d_loss += d_loss_value / total_batch
        epoch_average_g_loss += g_loss_value / total_batch

        # Print & Showing via Visdom
        if (step + 1) % 10 == 0:
            clear_output(wait=True)
            print('Epoch [%d/%d], Step[%d/%d], d_loss: %.4f, g_loss: %.4f'
                  % (epoch + 1, NUM_EPOCHS, step + 1, total_batch, d_loss_value, g_loss_value))

            # Preview only: no autograd graph is needed for these samples.
            # reshape(-1, ...) follows the size of test_noise; the previous hard-coded
            # 10 did not match the 25 fixed noise vectors and raised a RuntimeError.
            with torch.no_grad():
                sample_images = Gen(test_noise).reshape(-1, 1, 28, 28).cpu()

            image_window = viz.images(sample_images,
                                      opts=dict(title="Generated Imgs",
                                                caption="Generated Image-{}-{}".format(epoch + 1, step + 1)),
                                      win=image_window)

        # Image Save: last step of every 10th epoch.
        if (epoch + 1) % 10 == 0 and (step + 1) == total_batch:
            with torch.no_grad():
                sample_images = Gen(test_noise).reshape(-1, 1, 28, 28)
            save_image(sample_images, './IMGS/generatedimage-%d-%d.png' % (epoch + 1, step + 1))

    # One (1, 2) row per epoch: column 0 is the D series, column 1 the G series.
    # x is the 1-based epoch number (matching the printed progress), so the first
    # epoch does not land on top of the plot's initial placeholder point at x=0.
    loss_tracker(loss_plt,
                 np.array([[epoch_average_d_loss, epoch_average_g_loss]], dtype=np.float32),
                 np.array([[epoch + 1, epoch + 1]], dtype=np.float32))