# 필요한 모듈 불러오기

In [596]:
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import torchvision.datasets as dsets
from torch.autograd import Variable
import torch.nn.functional as F

from tqdm import tqdm_notebook
import random
import numpy as np

# 데이터 불러오기

In [558]:
transform = transforms.Compose([          
        transforms.ToTensor(),                     
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])

train_dataset = dsets.MNIST(root='./data/', train=True, download=True, transform=transform) 
data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=100, shuffle=True)

Files already downloaded


In [586]:
class _netG(nn.Module):
    def __init__(self):
        super(_netG, self).__init__()
        
        self.fc1 = nn.Linear(74, 1024)
        self.fc1_bn = nn.BatchNorm1d(1024)
        self.fc2 = nn.Linear(1024, 7*7*128)
        self.fc2_bn = nn.BatchNorm1d(7*7*128)
        
        self.conv_tr_1 = nn.ConvTranspose2d(128, 64, kernel_size=4,stride=2,padding=1)
        self.conv_tr_1_bn = nn.BatchNorm2d(64)
        self.conv_tr_2 = nn.ConvTranspose2d(64, 1, kernel_size=4,stride=2,padding=1)


        
    def forward(self, x):
        # (100,74) -> (100,1024)
        x = self.fc1_bn(F.leaky_relu(self.fc1(x),0.1))
        
        # (100,1024) -> (100,6272)
        x = self.fc2_bn(F.leaky_relu(self.fc2(x),0.1))
        
        # (100,6272) -> (100,64,4,4)        
        x = x.view(batch_size,128,7,7)
        x = self.conv_tr_1_bn(F.leaky_relu(self.conv_tr_1(x),0.1))

        # (100,64,4,4) -> (100,1,28,28)
        x = self.conv_tr_2(x)
        
        return x

Hout=(Hin−1)∗stride[0]−2∗padding[0]+kernel_size[0]+output_padding[0]

In [584]:
class _netD(nn.Module):
    def __init__(self):
        super(_netD, self).__init__()
        
        self.conv1 = nn.Conv2d(1, 64, kernel_size=4,stride=2,padding=3)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=4,stride=2,padding=1)
        self.conv2_bn = nn.BatchNorm2d(128)
        
        self.fc1 = nn.Linear(8192, 1024)
        self.fc1_bn = nn.BatchNorm1d(1024)
        self.fc2 = nn.Linear(1024, 1)
        
    def forward(self, x):
        # (100,1,28,28) -> (100,64,16,16)
        x = F.leaky_relu(self.conv1(x),0.1)

        # (100,64,16,16) -> (100,128,8,8)
        x = self.conv2_bn(F.leaky_relu(self.conv2(x),0.1))
        
        # (100,128,8,8) -> (100,8192)
        x = x.view(batch_size,-1)
        
        # (100,8192) -> (100,1024)
        x = self.fc1_bn(F.leaky_relu(self.fc1(x),0.1))
        
        # (100,1024) -> (100,1)
        x = self.fc2(x)
        
        return F.sigmoid(x)

In [589]:
class _netQ(nn.Module):
    def __init__(self):
        super(_netQ, self).__init__()

        self.fc2 = nn.Linear(1024, 128)
        self.fc2_bn = nn.BatchNorm1d(128)
        
        self.fc3 = nn.Linear(128, 10)
        self.fc4 = nn.Linear(128,2)
        
    def setType(c_type):
        self.type = c_type

    def forward(self, x):
        # (100,1,28,28) -> (100,64,16,16)
        x = F.leaky_relu(netD.conv1(x),0.1)
        
        # (100,64,16,16) -> (100,128,8,8)
        x = netD.conv2_bn(F.leaky_relu(netD(x),0.1))
        
        # (100,128,8,8) -> (100,8192)
        x = x.view(batch_size,-1)
        
        # (100,8192) -> (100,1024)
        x = netD.fc1_bn(F.leaky_relu(netD.fc1(x),0.1))
        
        # (100,1024) -> (100,128)
        x = F.leaky_relu(self.fc2_bn(netD.fc2(x)),0.1)
        
        print(x.size())


        
        if self.type == "disc":
            x = self.fc3(x)
            return F.softmax(x)
        else:
            # (100,128) -> (100,2)
            x = self.fc4(x)
            return x.mean(dim = 0),x.std()

# 파라미터 초기값 설정

In [581]:
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

# 네트워크 초기화

In [587]:
netG = _netG()
netG.apply(weights_init)

print(netG)

_netG (
  (fc1): Linear (74 -> 1024)
  (fc1_bn): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True)
  (fc2): Linear (1024 -> 6272)
  (fc2_bn): BatchNorm1d(6272, eps=1e-05, momentum=0.1, affine=True)
  (conv_tr_1): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (conv_tr_1_bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
  (conv_tr_2): ConvTranspose2d(64, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
)


In [588]:
netD = _netD()
netD.apply(weights_init)

print(netD)

_netD (
  (conv1): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(3, 3))
  (conv2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (conv2_bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
  (fc1): Linear (8192 -> 1024)
  (fc1_bn): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True)
  (fc2): Linear (1024 -> 1)
)


In [590]:
netQ = _netQ()
netQ.apply(weights_init)

print(netQ)

_netQ (
  (fc2): Linear (1024 -> 128)
  (fc2_bn): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True)
  (fc3): Linear (128 -> 10)
  (fc4): Linear (128 -> 2)
)


# 기타 다른 초기값 세팅

In [593]:
batchSize = 100
imageSize = 28

input = torch.FloatTensor(batchSize, 3, imageSize,imageSize)
noise = torch.FloatTensor(batchSize, nz, 1, 1)

label = torch.FloatTensor(batchSize)
real_label = 0.9
fake_label = 0

# Loss 기준 및 Optimizer

In [595]:
criterion = nn.BCELoss()


optimizerD = optim.Adam(netD.parameters(), lr=2e-4, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=1e-3, betas=(beta1, 0.999))
optimizerQ = optim.Adam(netQ.parameters(), lr=2e-4, betas=(beta1, 0.999))

# Sampling 함수들

In [567]:
def sample_disc(size):
    return np.random.multinomial(n=1,pvals=[0.1]*10,size=size).astype(float)

In [568]:
def sample_cont(size):
    return np.random.uniform(-1,1,size=size).astype(float)

# loss / score 담을 변수

In [601]:
result_dict = {}
loss_D = []
loss_G = []
score_D = []
score_G = []

In [604]:
## niter = 200
import pickle
for epoch in range(niter):
    for i, (data,_) in enumerate(data_loader):
        
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ############################
        # train with real
        netD.zero_grad()
        label.resize_(batch_size).fill_(real_label)
        
        inputv = Variable(data)
        labelv = Variable(label)
        output = netD(inputv)
        
        errD_real = criterion(output, labelv)
        D_x = output.data.mean()

        # train with fake
        noise.resize_(batch_size, nz).normal_(0, 1)
        
        
        c1 = sample_disc(batch_size).reshape(batch_size,10)
        c1 = torch.FloatTensor(c1)
        c2 = sample_cont(batch_size).reshape(batch_size,1)
        c2 = torch.FloatTensor(c2)
        c3 = sample_cont(batch_size).reshape(batch_size,1)
        c3 = torch.FloatTensor(c3)
        
        c_disc = c1
        c_cont = torch.cat([c2,c3],1)
        c = torch.cat([c_disc,c_cont],1)
        
        noisev = Variable(torch.cat([noise,c],1))
        
        fake = netG(noisev)

        labelv = Variable(label.fill_(fake_label))
        output = netD(fake.detach())
        errD_fake = criterion(output, labelv)

        D_G_z = output.data.mean()
        
        errD = errD_real + errD_fake
        
        errD.backward()
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ############################
        
        netG.zero_grad()
        labelv = Variable(label.fill_(real_label))  # fake labels are real for generator cost
        output = netD(fake)
        errG = criterion(output, labelv)
        errG.backward()
        optimizerG.step()
        
        
        ############################
        # (3) Update G network again to ensure that errD doesn't go zero
        ############################
        
        fake = netG(noisev)
        
        netG.zero_grad()
        labelv = Variable(label.fill_(real_label))  # fake labels are real for generator cost
        output = netD(fake)
        errG = criterion(output, labelv)
        errG.backward()
        optimizerG.step()

         
        ############################
        # (4) Update Q network       ==> 이 부분과 관련하여 아직 
        ############################
        fake = netG(noisev)
        netQ.setType = "disc"
        Q_c_given_x_cont = netQ(fake)
        
        crossent_loss = torch.mean(-torch.sum(c_cont * torch.log(Q_c_given_x_cont + 1e-8), dim=1))
        ent_loss = torch.mean(-torch.sum(c_cont * torch.log(c_cont + 1e-8), dim=1))
        mi_loss = crossent_loss + ent_loss

        
        

        if i % 100 == 0:
            loss_D.append(errD.data[0])
            loss_G.append(errG.data[0])
            score_D.append(D_x)
            score_G.append(D_G_z1)
            result_dict = {"loss_D":loss_D,"loss_G":loss_G,"score_D":score_D,"score_G":score_G}
            pickle.dump(result_dict,open("result_dict.p","wb"))

    # do checkpointing
    torch.save(netG.state_dict(), '%s/netG.pth' % ("results"))
    torch.save(netD.state_dict(), '%s/netD.pth' % ("results"))

RuntimeError: Need input of dimension 4 and input.size[1] == 1 but got input to be of shape: [100 x 64 x 16 x 16] at /Users/soumith/code/builder/wheel/pytorch-src/torch/lib/THNN/generic/SpatialConvolutionMM.c:47