In [1]:
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import DataLoader, TensorDataset

import gzip
import pickle
import matplotlib.pyplot as plt
import cv2
from PIL import Image
from pathlib import Path

from CGAN_MNIST import NetG, NetD
import config

In [2]:
# Run-wide settings: compute device, dataset location, and hyper-parameters
# read from the project-local `config` module.
PREFERRED_GPU = 5
if not torch.cuda.is_available():
    device = torch.device('cpu')
elif torch.cuda.device_count() > PREFERRED_GPU:
    device = torch.device('cuda:%d' % PREFERRED_GPU)
else:
    # is_available() only says that at least one GPU is visible; it does not
    # guarantee that the preferred index exists, so fall back to the first GPU.
    device = torch.device('cuda:0')
DATA_PATH = '../data/MNIST/mnist.pkl.gz'
BATCH_SIZE = config.BATCH_SIZE
MAX_EPOCH = config.MAX_EPOCH
latent_dim = config.latent_dim

In [3]:
def get_sample_image(G, n_noise=latent_dim):
    """
        Build a 10x10 grid of generated digits as a single 280x280 array.

        Row-band j holds ten samples conditioned on the one-hot label for
        class j, each drawn from fresh Gaussian noise. Nothing is written to
        disk here; the caller decides how to save or display the grid.

        Args:
            G: generator, called with a (10, n_noise + 10) tensor; its output
               is reshaped to (10, 28, 28).
            n_noise: width of the noise vector placed in front of the label.

        Returns:
            numpy array of shape (280, 280) holding the raw generator outputs
            (no rescaling is applied in this function).
    """
    img = np.zeros([280, 280])
    # Pure sampling: disable autograd so no graph is built for these forwards.
    with torch.no_grad():
        for j in range(10):
            # ten identical one-hot condition vectors selecting class j
            c = torch.zeros(10, 10, device=device)
            c[:, j] = 1
            z = torch.randn(10, n_noise).to(device)

            y_hat = G(torch.cat((z, c), dim=1)).view(10, 28, 28)
            result = y_hat.cpu().numpy()
            # lay the ten 28x28 tiles side by side to fill row-band j
            img[j*28:(j+1)*28] = np.concatenate(list(result), axis=-1)
    return img

In [4]:
# Read the gzip-compressed pickle at DATA_PATH. It unpacks into three parts;
# the first two are (inputs, labels) pairs and the third is discarded.
# NOTE: pickle can execute arbitrary code while loading - only open trusted files.
with gzip.open(DATA_PATH, 'rb') as mnist_file:
    train_split, valid_split, _ = pickle.load(mnist_file, encoding='latin-1')
x_train, y_train = train_split
x_valid, y_valid = valid_split

In [5]:
# Convert the four loaded arrays to detached tensors. torch.Tensor(...) is the
# legacy constructor, so the labels are converted the same way as the images;
# fit() casts the labels back to long before one-hot encoding them.
x_train, y_train, x_valid, y_valid = (
    torch.Tensor(array).detach()
    for array in (x_train, y_train, x_valid, y_valid)
)

In [6]:
# Shift and scale the inputs with (x - 0.5) / 0.5, then pair each split with
# its labels. NOTE: the inputs are overwritten in place, so re-running this
# cell without reloading the data applies the rescaling a second time.
x_train, x_valid = ((pixels - 0.5) / 0.5 for pixels in (x_train, x_valid))
ds_train, ds_valid = TensorDataset(x_train, y_train), TensorDataset(x_valid, y_valid)

In [7]:
# Training loader: reshuffled every epoch, 8 worker processes, and the final
# incomplete batch is dropped so every batch has exactly BATCH_SIZE rows.
dl_train = DataLoader(ds_train, batch_size=BATCH_SIZE, shuffle=True,
                      num_workers=8, drop_last=True)

In [8]:
# Validation loader: shuffled, loaded in the main process, and the final
# (possibly smaller) batch is kept.
dl_valid = DataLoader(ds_valid, batch_size=BATCH_SIZE, shuffle=True, drop_last=False)

In [9]:
def model_init():
    """Create the generator and discriminator on `device`, each with its own Adam optimizer.

    Returns:
        Tuple (netG, optG, netD, optD), in that order.
    """
    # Adam's momentum term beta1 is lowered from the default 0.9 to 0.5. Per the
    # original author's note, 0.5 focuses on roughly the last 2 iterations,
    # whereas 0.9 focuses on roughly the last 10, which was judged too long.
    adam_settings = dict(lr=0.0001, betas=(0.5, 0.999))

    netG = NetG().to(device)
    netD = NetD().to(device)
    optG = torch.optim.Adam(netG.parameters(), **adam_settings)
    optD = torch.optim.Adam(netD.parameters(), **adam_settings)
    return netG, optG, netD, optD

In [10]:
def fit(G, D, optG, optD, criterion, start_epoch = 0, max_epoch = MAX_EPOCH):
    """
        Train a conditional GAN by alternating one discriminator update and
        one generator update per mini-batch of the notebook-level `dl_train`.

        Relies on notebook globals: `dl_train`, `device`, `latent_dim` and
        `get_sample_image`.

        Args:
            G: generator; receives noise concatenated with a one-hot label.
            D: discriminator; receives an image concatenated with the same label.
            optG: optimizer stepping G's parameters.
            optD: optimizer stepping D's parameters.
            criterion: callable(prediction, target) returning a scalar loss.
            start_epoch: first epoch index to run (0-based).
            max_epoch: exclusive upper bound on the epoch index.

        Side effects:
            Every 10th epoch, prints the losses of the last batch and writes a
            sample grid to samples/<epoch>.jpg.
    """
    # cv2.imwrite does not create missing directories (it just reports failure
    # through its return value), so make sure the output folder exists.
    Path('samples').mkdir(parents=True, exist_ok=True)

    for epoch in range(start_epoch, max_epoch):
        for x_gt, y_gt in dl_train:
            # Move the batch to the device. Labels become float one-hot rows so
            # they can be concatenated with the float noise / image tensors.
            # num_classes is fixed at 10 (the label width get_sample_image
            # uses) so the width cannot shrink when a batch happens to lack
            # the highest digit.
            x_gt = x_gt.to(device)
            y_gt = F.one_hot(y_gt.to(torch.long), num_classes=10).to(torch.float).to(device)

            # Targets are sized from the actual batch instead of relying on
            # fixed-size globals defined in a later cell.
            n_batch = len(y_gt)
            real_labels = torch.ones(n_batch, 1, device=device)
            fake_labels = torch.zeros(n_batch, 1, device=device)

            # ---- Discriminator step (G is not updated) ----
            # D must be back in train mode here: the generator step below puts
            # it in eval mode, and without this call it would stay there for
            # every discriminator update after the first batch.
            G.train()
            D.train()
            # Gaussian noise, one row per sample
            z = torch.randn((n_batch, latent_dim)).to(device)

            out_x = D(torch.cat((x_gt, y_gt), dim=1))
            loss_x = criterion(out_x, real_labels)

            x_z = G(torch.cat((z, y_gt), dim=1))
            # detach() keeps the discriminator loss from back-propagating into G
            out_z = D(torch.cat((x_z.detach(), y_gt), dim=1))
            loss_z = criterion(out_z, fake_labels)
            loss_D = loss_z + loss_x

            optD.zero_grad()
            loss_D.backward()
            optD.step()

            # ---- Generator step (D is not updated) ----
            G.train()
            D.eval()
            # fresh Gaussian noise for the generator update
            z = torch.randn((n_batch, latent_dim)).to(device)
            x_z = G(torch.cat((z, y_gt), dim=1))
            out_z = D(torch.cat((x_z, y_gt), dim=1))
            # G is rewarded when D labels its samples as real
            loss_G = criterion(out_z, real_labels)

            optG.zero_grad()
            loss_G.backward()
            optG.step()

        if (epoch+1) % 10 == 0:
            G.eval()
            print('Epoch:[%d/%d], Loss_G = %f, Loss_D = %f.\n' %(epoch+1, max_epoch, loss_G.item(), loss_D.item()))
            img = get_sample_image(G)
            # The training images are rescaled to [-1, 1] before training, so
            # generator outputs are presumably in that range too (TODO confirm
            # against NetG's final activation). Map to [0, 255] instead of
            # multiplying by 255, which would push every negative value out of
            # the 8-bit range.
            img_u8 = np.clip((img + 1.0) * 127.5, 0, 255).astype(np.uint8)
            cv2.imwrite('samples/%d.jpg' % (epoch+1), img_u8)

In [11]:
# The networks take the one-hot label as an extra input, but the loss itself
# is plain binary cross-entropy on the discriminator's single output: it only
# scores real versus generated.
criterion = F.binary_cross_entropy

# Fixed-size adversarial targets: a column of ones for "real" and a column of
# zeros for "fake", one row per sample of a full batch.
real_labels = torch.ones((BATCH_SIZE, 1)).to(device)
fake_labels = torch.zeros((BATCH_SIZE, 1)).to(device)

G, optG, D, optD = model_init()

Linear(in_features=110, out_features=128, bias=True)
Linear(in_features=128, out_features=256, bias=True)
Linear(in_features=256, out_features=512, bias=True)
Linear(in_features=512, out_features=1024, bias=True)
Linear(in_features=1024, out_features=784, bias=True)
Linear(in_features=794, out_features=512, bias=True)
Linear(in_features=512, out_features=512, bias=True)
Linear(in_features=512, out_features=512, bias=True)
Linear(in_features=512, out_features=1024, bias=True)
Linear(in_features=1024, out_features=1, bias=True)


In [12]:
# Run the full training schedule with the default epoch range.
fit(G, D, optG=optG, optD=optD, criterion=criterion)

Epoch:[10/500], Loss_G = 1.337309, Loss_D = 0.723467.

Epoch:[20/500], Loss_G = 2.706451, Loss_D = 1.151028.

Epoch:[30/500], Loss_G = 2.192599, Loss_D = 0.692334.

Epoch:[40/500], Loss_G = 2.291567, Loss_D = 0.420434.

Epoch:[50/500], Loss_G = 2.378611, Loss_D = 0.413174.

Epoch:[60/500], Loss_G = 4.038810, Loss_D = 0.592034.

Epoch:[70/500], Loss_G = 2.713987, Loss_D = 0.563897.

Epoch:[80/500], Loss_G = 2.529819, Loss_D = 0.603097.

Epoch:[90/500], Loss_G = 1.953907, Loss_D = 0.608995.

Epoch:[100/500], Loss_G = 2.584807, Loss_D = 0.695384.

Epoch:[110/500], Loss_G = 2.466191, Loss_D = 0.716644.

Epoch:[120/500], Loss_G = 3.073211, Loss_D = 1.051528.

Epoch:[130/500], Loss_G = 1.256303, Loss_D = 0.854432.

Epoch:[140/500], Loss_G = 1.942178, Loss_D = 0.756161.

Epoch:[150/500], Loss_G = 2.227684, Loss_D = 0.854831.

Epoch:[160/500], Loss_G = 1.244714, Loss_D = 0.815726.

Epoch:[170/500], Loss_G = 1.105452, Loss_D = 0.899203.

Epoch:[180/500], Loss_G = 1.027015, Loss_D = 0.960187.

E

In [13]:
# Draw one more 10x10 sample grid from the trained generator.
img = get_sample_image(G=G)

In [14]:
# Reload the dataset from disk and convert it to tensors again, which replaces
# the rescaled inputs from the earlier cells with the values stored on disk.
# NOTE: pickle can execute arbitrary code while loading - only open trusted files.
with gzip.open(DATA_PATH, 'rb') as mnist_file:
    train_split, valid_split, _ = pickle.load(mnist_file, encoding='latin-1')
x_train, y_train = train_split
x_valid, y_valid = valid_split
x_train, y_train, x_valid, y_valid = (
    torch.Tensor(array).detach()
    for array in (x_train, y_train, x_valid, y_valid)
)

In [15]:
# Display the second training sample as stored on disk (before any rescaling).
x_train[1]

tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 