In [1]:
from __future__ import print_function

import itertools
import math
import os
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.datasets as dsets
import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.utils as vutils
from IPython import display
from torch import optim
from torch.autograd import Variable
from torch.utils.tensorboard import SummaryWriter

from MiddleBlock.DiscriminatorMiddleBlock import DiscriminatorMiddleBlock
from MiddleBlock.GeneratorMiddleBlock import GeneratorMiddleBlock
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cuda')

In [2]:
# ---- Model / training hyperparameters ----
nz = 100         # length of the latent noise vector
nc = 1           # number of image channels (MNIST is grayscale)
ngf = 64         # base filter count for the generator
ndf = 64         # base filter count for the discriminator
niter = 200      # intended number of epochs -- NOTE(review): the training loop iterates a fixed range(5)
lr = 0.0001      # learning rate shared by both RMSprop optimizers
beta1 = 0.9      # Adam-style beta1 -- NOTE(review): not used, the optimizers are RMSprop

imageSize = 64   # side length of the (square) images
batchSize = 64   # mini-batch size
outf = "result"  # directory for sample grids and checkpoints

# Reproducibility: seed torch (all devices) and numpy before any later cell
# splits the dataset, initialises weights, or samples noise/labels.
manualSeed = 999
torch.manual_seed(manualSeed)
np.random.seed(manualSeed)

# vutils.save_image / torch.save fail if the output directory does not exist.
os.makedirs(outf, exist_ok=True)

In [3]:
# MNIST resized to imageSize x imageSize and rescaled to [-1, 1], so that real
# images share the value range of the generator's tanh output; otherwise the
# discriminator could separate real from fake by pixel range alone.
transform = transforms.Compose([
    transforms.Resize(imageSize),
    transforms.ToTensor(),                 # -> [0, 1]
    transforms.Normalize((0.5,), (0.5,)),  # -> [-1, 1]
])

dataset = dsets.MNIST(root='./data/', train=True, download=True, transform=transform)
# random_split requires the lengths to sum to len(dataset): 50,000 train / 10,000 validation.
train_set, val_set = torch.utils.data.random_split(dataset, [50000, 10000])
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batchSize, shuffle=True)
# Validation order does not need to be randomised.
val_loader = torch.utils.data.DataLoader(val_set, batch_size=batchSize, shuffle=False)

In [4]:
def weights_init(m):
    """DCGAN-style initialiser, intended to be passed to ``nn.Module.apply``.

    * Modules whose class name contains ``Conv`` (``Conv2d`` and
      ``ConvTranspose2d`` alike) get weights drawn from N(0, 0.02).
    * Modules whose class name contains ``BatchNorm`` get weights drawn from
      N(1, 0.02) and a zero bias.

    Parameters that are absent (e.g. ``BatchNorm2d(affine=False)``, whose
    ``weight`` and ``bias`` are ``None``) are skipped instead of raising.
    Convolution biases keep PyTorch's default initialisation.

    Args:
        m: the module currently being visited by ``apply``.
    """
    classname = type(m).__name__
    if 'Conv' in classname:
        if getattr(m, 'weight', None) is not None:
            nn.init.normal_(m.weight, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        if getattr(m, 'weight', None) is not None:
            nn.init.normal_(m.weight, 1.0, 0.02)
        if getattr(m, 'bias', None) is not None:
            nn.init.zeros_(m.bias)

In [8]:
class _netG(nn.Module):
    """Label-conditioned generator whose conv layers mirror the discriminator.

    ``copyWeight`` overwrites ``firstConv`` and every ``main[i].conv`` with
    discriminator weights, which is why the block channel arithmetic below is
    written in terms of ``ndf`` rather than ``ngf``.
    """

    def __init__(self):
        super().__init__()
        hidden = ngf * 4
        num_classes_plus_fake = 10 + 1
        # Same shape as the discriminator's label head. Only its weight tensor
        # is used (indexed by label in forward); its own forward is never called.
        self.firstConv = nn.Conv2d(ndf * 4, num_classes_plus_fake, 4, 1)
        # Projects the noise input to a `hidden`-channel feature map.
        self.randomVec = nn.ConvTranspose2d(nz, hidden, 4, 1, 0)
        # Consumes the channel-wise concatenation of the noise map and the label map.
        self.catConv = nn.ConvTranspose2d(hidden + hidden, hidden, 4, 2, 1)
        # NOTE(review): conv2 is never used in forward(); removing it would
        # change the state_dict keys of saved netG checkpoints.
        self.conv2 = nn.Conv2d(ngf, nc, 3, padding=1)
        ksp = (4, 2, 1)  # kernel, stride, padding shared by every middle block
        self.main = nn.Sequential(
            GeneratorMiddleBlock((ndf * 4, ndf * 4) + ksp, (ndf * 2, ndf * 4) + ksp),
            GeneratorMiddleBlock((ndf * 2, ndf * 4) + ksp, (ndf, ndf * 2) + ksp, batchNorm=True),
            GeneratorMiddleBlock((ndf, ndf * 2) + ksp, (nc, ndf) + ksp, batchNorm=True),
            GeneratorMiddleBlock((nc, ndf) + ksp, last=True, img_channel=1),
        )

    def forward(self, input, label):
        """Generate images in [-1, 1] from noise ``input`` conditioned on ``label``.

        ``label`` indexes the first dimension of ``firstConv.weight`` (size 11),
        giving one weight slice per sample that is concatenated with the
        projected noise along the channel dimension.
        """
        label_map = self.firstConv.weight[label]
        noise_map = F.relu(self.randomVec(input))
        fused = self.catConv(torch.cat([noise_map, label_map], dim=1))
        return torch.tanh(self.main(fused))
        

In [11]:
class _netD(nn.Module):
    """Discriminator that classifies into 11 classes.

    Real MNIST digits keep their label (0-9). The training loop assigns
    class 10 to generated images.
    """

    def __init__(self):
        super().__init__()
        # 11-way head applied to the last feature map of `main`.
        self.labelClassify = nn.Conv2d(ndf * 4, 11, 4, 1, 0)
        # Each block is given kernel=4, stride=2, padding=1; the shape comments
        # give each block's expected input (assumes a 64x64 input image).
        blocks = [
            # in: (nc) x 64 x 64
            DiscriminatorMiddleBlock(nc, ndf, 4, 2, 1, dropout_ratio=0.5, batchNorm=True),
            # in: (ndf) x 32 x 32
            DiscriminatorMiddleBlock(ndf, ndf * 2, 4, 2, 1, dropout_ratio=0.5, batchNorm=True),
            # in: (ndf*2) x 16 x 16
            DiscriminatorMiddleBlock(ndf * 2, ndf * 4, 4, 2, 1, dropout_ratio=0.5),
            # in: (ndf*4) x 8 x 8
            DiscriminatorMiddleBlock(ndf * 4, ndf * 4, 4, 2, 1),
        ]
        self.main = nn.Sequential(*blocks)

    def forward(self, input):
        """Return per-sample logits, flattened to shape (batch, -1)."""
        features = self.main(input)
        logits = self.labelClassify(features)
        return logits.view(features.shape[0], -1)

In [27]:
def copyWeight(netD, netG, tau):
    """Hard-copy discriminator conv weights into the mirrored generator layers.

    ``netD.labelClassify`` is copied into ``netG.firstConv``. The i-th generator
    block's ``conv`` receives the weights of the discriminator block at the
    mirrored position (last discriminator block -> first generator block).

    Args:
        netD: source network; must expose ``labelClassify`` and ``main``
            (an indexable sequence of blocks that each have ``.conv``).
        netG: destination network; must expose ``firstConv`` and ``main``
            (blocks with ``.conv``).
        tau: currently unused. The copy is always a full overwrite, not a soft
            (Polyak) blend. The parameter is kept for call compatibility.

    Raises:
        ValueError: if ``netG.main`` has more blocks than ``netD.main``. The
            mirrored index would otherwise go negative and silently pair the
            wrong layers.
    """
    if len(netG.main) > len(netD.main):
        raise ValueError(
            'copyWeight: netG.main has %d blocks but netD.main only has %d'
            % (len(netG.main), len(netD.main)))
    netG.firstConv.load_state_dict(netD.labelClassify.state_dict())
    mirrored_d_blocks = list(netD.main)[::-1]
    for g_block, d_block in zip(netG.main, mirrored_d_blocks):
        g_block.conv.load_state_dict(d_block.conv.state_dict())
        
# Build both networks with DCGAN-style initialisation (generator first, then
# discriminator), printing each architecture as it is created.
netG = _netG().to(device).apply(weights_init)
print(netG)
netD = _netD().to(device).apply(weights_init)
print(netD)

# 11-way classification loss: digits 0-9 for real images, class 10 for fakes.
criterion = nn.CrossEntropyLoss()
# Start the generator's mirrored conv layers from the discriminator's weights.
copyWeight(netD, netG, tau=0.05)

# Fixed latent batch and digit labels, reused for every periodic sample grid.
fixed_noise = torch.randn(batchSize, nz, 1, 1, device=device)
fixed_noise_label = torch.randint(10, size=(batchSize,), device=device)

_netG(
  (firstConv): Conv2d(256, 11, kernel_size=(4, 4), stride=(1, 1))
  (randomVec): ConvTranspose2d(100, 256, kernel_size=(4, 4), stride=(1, 1))
  (catConv): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (conv2): Conv2d(64, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (main): Sequential(
    (0): GeneratorMiddleBlock(
      (branch): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv): Conv2d(256, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (firstTransposed): ConvTranspose2d(256, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (secondTransposed): ConvTranspose2d(512, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    )
    (1): GeneratorMiddleBlock(
      (branch): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv): Conv2d(128, 256, kernel_size=(4, 

In [28]:
# RMSprop for both players with the shared learning rate.
# NOTE(review): `beta1` from the config cell is not consumed here.
optimizerD, optimizerG = [optim.RMSprop(net.parameters(), lr=lr) for net in (netD, netG)]
writer = SummaryWriter()  # TensorBoard logger; no log_dir given, so PyTorch's default is used
total_step = 0            # global iteration counter, used as the TensorBoard step

In [29]:
# Training loop. Each iteration:
#   1. Discriminator: 11-way classification -- real images -> their digit label,
#      generated images -> class 10 ("fake").
#   2. Mirror the updated discriminator conv weights into the generator.
#      NOTE(review): this also discards whatever optimizerG did to those shared
#      layers in the previous iteration.
#   3. Generator: make netD classify its output as the requested digit, plus a
#      diversity term (-log of the smooth-L1 distance between two samples drawn
#      with different noise for the same labels).
DIVERSITY_EPS = 1e-8  # floor for the diversity distance so -log() stays finite

for epoch in range(5):
    data = None
    for i, (data, label) in enumerate(train_loader):
        # ---- (1) discriminator on the real batch ----
        netD.zero_grad()
        batch_size = data.shape[0]
        label = label.to(device)
        output = netD(data.to(device))
        errD_real = criterion(output, label)
        errD_real.backward()

        # ---- (1) discriminator on a fake batch (generator frozen via detach) ----
        noise = torch.randn(batch_size, nz, 1, 1, device=device)
        output = netD(netG(noise, label).detach())
        fake_labels = torch.full_like(label, 10)
        errD_fake = criterion(output, fake_labels)
        errD_fake.backward()
        errD = errD_real + errD_fake
        writer.add_scalar('Discriminator total loss', errD.item(), total_step)
        optimizerD.step()

        # ---- (2) sync generator's mirrored layers from the discriminator ----
        copyWeight(netD, netG, 0.05)

        # ---- (3) generator: fool netD into predicting the requested digit ----
        netG.zero_grad()
        fake = netG(noise, label)
        output = netD(fake)
        errG = criterion(output, label)
        errG.backward()
        writer.add_scalar('Generator total loss', errG.item(), total_step)

        # Diversity regulariser: a second noise draw for the same labels should
        # yield a different image. Clamped so a collapsed generator (distance 0)
        # cannot produce an infinite loss and NaN gradients.
        sampleNoise = torch.randn(batch_size, nz, 1, 1, device=device)
        sampleLoss = F.smooth_l1_loss(netG(sampleNoise, label).detach(), netG(noise, label))
        (-torch.log(sampleLoss.clamp_min(DIVERSITY_EPS))).backward()
        optimizerG.step()
        total_step += 1

        if (i + 1) % 100 == 0:
            print(i, "step")
            # Snapshot only: no autograd graph is needed.
            with torch.no_grad():
                fake = netG(fixed_noise, fixed_noise_label)
            vutils.save_image(fake,
                '%s/fake_samples_epoch_%s.png' % (outf, str(epoch) + " " + str(i + 1)),
                normalize=True)

    if data is None:
        raise RuntimeError('train_loader yielded no batches')
    # `data` is the last real batch of the epoch.
    vutils.save_image(data,
            '%s/real_samples.png' % outf,
            normalize=True)
    with torch.no_grad():
        fake = netG(fixed_noise, fixed_noise_label)
    vutils.save_image(fake,
            '%s/fake_samples_epoch_%s.png' % (outf, epoch),
            normalize=True)

    # Checkpoint (overwritten every epoch).
    torch.save(netG.state_dict(), '%s/netG.pth' % (outf))
    torch.save(netD.state_dict(), '%s/netD.pth' % (outf))

99 step
199 step
299 step
399 step
499 step
599 step
699 step
99 step
199 step
299 step
399 step
499 step
599 step
699 step
99 step
199 step
299 step
399 step
499 step
599 step
699 step
99 step
199 step
299 step
399 step
499 step
599 step
699 step
99 step
199 step
299 step
399 step
499 step
599 step
699 step


In [31]:
# Render the fixed noise batch, every sample conditioned on one random digit.
# eval() switches BatchNorm (and any dropout) to inference behaviour. Train mode
# is restored in `finally` so an exception cannot leave netG in eval mode.
# The file name uses the chosen digit rather than `epoch`/`i` leaked from the
# training loop, so this cell also works on a fresh kernel.
digit = int(np.random.randint(10))
randomLabel = torch.full((batchSize,), digit, dtype=torch.long, device=device)
netG.eval()
try:
    with torch.no_grad():
        fake = netG(fixed_noise, randomLabel)
finally:
    netG.train()
vutils.save_image(fake,
                  '%s/test_digit_%d.png' % (outf, digit),
                  normalize=True)