<a href="https://colab.research.google.com/github/ShawnDong98/GAN/blob/master/VAE-GAN/VAE-GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from __future__ import print_function
import argparse
import h5py
import numpy as np
import os
import time
import torch
import torch.utils.data
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data.dataset import Dataset
from torch.autograd import Variable
from torchvision import datasets, transforms
from torchvision.utils import make_grid , save_image
import torchvision.utils as vutils

In [2]:
torch.cuda.empty_cache()

In [3]:
%matplotlib inline
import matplotlib.pyplot as plt

def show_and_save(file_name,img):
    npimg = np.transpose(img.numpy(),(1,2,0))
    f = "./%s.png" % file_name
    fig = plt.figure(dpi=300)
    fig.suptitle(file_name, fontsize=14, fontweight='bold')
    #plt.imshow(npimg)
    plt.imsave(f,npimg)
    
def save_model(epoch, encoder, decoder, D):
    state = {'encoder': encoder.cpu().state_dict(),
             'decoder': decoder.cpu().state_dict(),
             'D':D.cpu().state_dict(),
             'epoch': epoch,}
    
    torch.save(state, 'VAE_GAN_%d.pth' % epoch)
    
    encoder.cuda()
    decoder.cuda()
    D.cuda()
    
def load_model(G, D, state):
    #  restore models
    G.decoder.load_state_dict(state['decoder'])
    G.decoder.cuda()
    G.encoder.load_state_dict(state['encoder'])
    G.encoder.cuda()
    D.load_state_dict(state['D'])
    D.cuda()

In [4]:
# define constant
input_channels = 3
hidden_size = 64
max_epochs = 250

In [72]:
batch_size = 4
T = transforms.Compose([transforms.Resize(64),
                        transforms.ToTensor(),
                        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
                       ])

dataset = datasets.ImageFolder(
    root = "../data/own_data/", 
    transform=T,
)
    
    


train_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size, shuffle=True)

# train_loader = torch.utils.data.DataLoader(
#     datasets.MNIST(
#         "../data",
#         train=True,
#         download=True,
#         transform=transforms.Compose(
#             [transforms.Resize(64), 
#              transforms.ToTensor(),
#              transforms.Normalize([0.5], [0.5])
#             ]
#         ),
#     ),
#     batch_size=batch_size,
#     shuffle=True,
# )

In [6]:
class Encoder(nn.Module):
    def __init__(self, input_channels, output_channels, representation_size = 64):
        super(Encoder, self).__init__()
        # input channels: 3
        # output channels = hidden size : 64
        self.input_channels = input_channels
        self.output_channels = output_channels
        
        self.features = nn.Sequential(
            # nc x 64 x 64
            # 3 -> 64
            nn.Conv2d(self.input_channels, representation_size, 5, stride=2, padding=2),
            nn.BatchNorm2d(representation_size),
            nn.ReLU(),
            # hidden_size x 32 x 32
            # 64 -> 128
            nn.Conv2d(representation_size, representation_size*2, 5, stride=2, padding=2),
            nn.BatchNorm2d(representation_size * 2),
            nn.ReLU(),
            # hidden_size*2 x 16 x 16
            # 128 -> 256
            nn.Conv2d(representation_size*2, representation_size*4, 5, stride=2, padding=2),
            nn.BatchNorm2d(representation_size * 4),
            nn.ReLU())
            # hidden_size*4 x 8 x 8
            
        self.mean = nn.Sequential(
            nn.Linear(representation_size*4*8*8, output_channels)
        )
        
        self.logvar = nn.Sequential(
            nn.Linear(representation_size*4*8*8, output_channels)
        )
        
    def forward(self, x):
        batch_size = x.size()[0]

        hidden_representation = self.features(x)

        mean = self.mean(hidden_representation.view(batch_size, -1))
        logvar = self.logvar(hidden_representation.view(batch_size, -1))

        return mean, logvar
    
    def hidden_layer(self, x):
        batch_size = x.size()[0]
        output = self.features(x)
        return output

In [7]:
class Decoder(nn.Module):
    def __init__(self, input_size, representation_size):
        super(Decoder, self).__init__()
        self.input_size = input_size
        self.representation_size = representation_size
        dim = representation_size[0] * representation_size[1] * representation_size[2]
        
        self.preprocess = nn.Sequential(
            nn.Linear(input_size, dim),
            nn.BatchNorm1d(dim),
            nn.ReLU())
        
            # 256 x 8 x 8
        self.deconv1 = nn.ConvTranspose2d(representation_size[0], 128, 5, stride=2, padding=2)
        self.act1 = nn.Sequential(nn.BatchNorm2d(128),
                                  nn.ReLU())
            # 256 x 16 x 16
        self.deconv2 = nn.ConvTranspose2d(128, 32, 5, stride=2, padding=2)
        self.act2 = nn.Sequential(nn.BatchNorm2d(32),
                                  nn.ReLU())
            # 128 x 32 x 32
        self.deconv3 = nn.ConvTranspose2d(32, input_channels, 5, stride=2, padding=2)
        self.act3 = nn.Sequential(nn.BatchNorm2d(input_channels),
                                  nn.ReLU())
            # 3 x 64 x 64
        self.activation = nn.Tanh()
            
    
    def forward(self, code):
        bs = code.size()[0]
        preprocessed_codes = self.preprocess(code)
        preprocessed_codes = preprocessed_codes.view(-1,
                                                     self.representation_size[0],
                                                     self.representation_size[1],
                                                     self.representation_size[2])
        output = self.deconv1(preprocessed_codes, output_size=(bs, 256, 16, 16))
        output = self.act1(output)
        output = self.deconv2(output, output_size=(bs, 128, 32, 32))
        output = self.act2(output)
        output = self.deconv3(output, output_size=(bs, 32, 64, 64))
        output = self.act3(output)
        output = self.activation(output)
        return output

In [8]:
class VAE_GAN_Generator(nn.Module):
    def __init__(self, input_channels, hidden_size, representation_size=(256, 8, 8)):
        super(VAE_GAN_Generator, self).__init__()
        self.input_channels = input_channels
        self.hidden_size = hidden_size
        self.representation_size = representation_size
        
        self.encoder = Encoder(input_channels, hidden_size)
        self.decoder = Decoder(hidden_size, representation_size)
        
    def forward(self, x):
        batch_size = x.size()[0]
        mean, logvar = self.encoder(x)
        std = logvar.mul(0.5).exp_()
        
        reparametrized_noise = Variable(torch.randn((batch_size, self.hidden_size))).cuda()

        reparametrized_noise = mean + std * reparametrized_noise

        rec_images = self.decoder(reparametrized_noise)
        
        return mean, logvar, rec_images

In [9]:
class Discriminator(nn.Module):
    def __init__(self, input_channels, representation_size=(256, 8, 8)):  
        super(Discriminator, self).__init__()
        self.representation_size = representation_size
        dim = representation_size[0] * representation_size[1] * representation_size[2]
        
        self.main = nn.Sequential(
            # 3x64x64 -> 32x32x32
            nn.Conv2d(input_channels, 32, 5, stride=2, padding=2),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2),
            # 32x32x32 -> 128x16x16
            nn.Conv2d(32, 128, 5, stride=2, padding=2),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            # 128x16x16 -> 256x8x8
            nn.Conv2d(128, 256, 5, stride=2, padding=2),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
        )
        
        self.lth_features = nn.Sequential(
            nn.Linear(dim, 2048),
            nn.LeakyReLU(0.2))
        
        self.sigmoid_output = nn.Sequential(
            nn.Linear(2048, 1),
            nn.Sigmoid()
        )
        
    def forward(self, x):
        batch_size = x.size()[0]
        features = self.main(x)
        lth_rep = self.lth_features(features.view(batch_size, -1))
        output = self.sigmoid_output(lth_rep)
        return output
    
    def similarity(self, x):
        batch_size = x.size()[0]
        features = self.main(x)
        lth_rep = self.lth_features(features.view(batch_size, -1))
        return lth_rep

In [10]:
lr = 3e-4

beta = 5
alpha = 0.1
gamma = 15

In [11]:
G = VAE_GAN_Generator(input_channels, hidden_size).cuda()
D = Discriminator(input_channels).cuda()

criterion = nn.BCELoss()
criterion.cuda()

opt_enc = optim.RMSprop(G.encoder.parameters(), lr=lr)
opt_dec = optim.RMSprop(G.decoder.parameters(), lr=lr)
opt_dis = optim.RMSprop(D.parameters(), lr=lr * alpha)

# opt_enc =  optim.Adam(G.encoder.parameters(), lr=lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
# opt_dec =  optim.Adam(G.decoder.parameters(), lr=lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
# opt_dis =  optim.Adam(D.parameters(), lr=lr * alpha, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)

In [12]:
fixed_noise = Variable(torch.randn(batch_size, hidden_size)).cuda()
data, _ = next(iter(train_loader))
fixed_batch = Variable(data).cuda()
fixed_batch = fixed_batch.resize_(fixed_batch.shape[0], input_channels, 64,  64)

In [60]:
import torch
from PIL import Image
import matplotlib.pyplot as plt
import cv2

# loader使用torchvision中自带的transforms函数
loader = transforms.Compose([
    transforms.ToTensor()])  

unloader = transforms.ToPILImage()

def imshow(tensor, title=None):
    image = tensor.cpu().clone()  # we clone the tensor to not do changes on it
    print(image.size())
    image = image.squeeze(0)  # remove the fake batch dimension
    print(image.size())
    image = unloader(image)
    plt.imshow(image)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated
    
def toTensor(img):
    assert type(img) == np.ndarray,'the img type is {}, but ndarry expected'.format(type(img))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = torch.from_numpy(img.transpose((2, 0, 1)))
    return img.float().div(255).unsqueeze(0) 
    
def tensor_to_np(tensor):
    img = tensor.mul(255).byte()
    img = img.cpu().squeeze(0).numpy().transpose((1, 2, 0))
    return img
    
def show_from_tensor(tensor, title=None):
    img = tensor.clone()
    img = tensor_to_np(img)
    plt.figure()
    plt.imshow(img)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)

In [74]:
state = {'epoch': 0}
try:
    print("loading model...")
    state = torch.load('VAE_GAN_%d.pth' % 50)
    load_model(G, D, state)
except:
    print("No checkpoint...")
for epoch in range(state['epoch']+1, max_epochs):
    D_real_list, D_rec_enc_list, D_rec_noise_list, D_list = [], [], [], []
    g_loss_list, rec_loss_list, prior_loss_list = [], [], []
    for data, _ in train_loader:
        batch_size = data.size()[0]
        ones_label =  torch.full((batch_size, 1), 1, device=torch.device("cuda:0"))
        zeros_label = torch.full((batch_size, 1), 0, device=torch.device("cuda:0"))
        

        datav = Variable(data).cuda()
        
        mean, logvar, rec_enc = G(datav)
        #print ("The size of rec_enc:", rec_enc.size())
        
        noisev = Variable(torch.randn(batch_size, hidden_size)).cuda()
        rec_noise = G.decoder(noisev)
        
        # train discriminator
        # L_GAN = log Dis(x) + log Dis(Dec(z)) + log D(Dec(Enc(x)))
        output = D(datav)
        # log(Dis(x))
        errD_real =  criterion(output, ones_label)
        D_real_list.append(output.data.mean().cpu())
        output = D(rec_enc)
        # log (Dis(Dec(Enc(x)))
        # print(output)
        errD_rec_enc =  criterion(output, zeros_label)
        D_rec_enc_list.append(output.data.mean().cpu())
        output = D(rec_noise)
        # log (Dis(Dec(z)))
        errD_rec_noise =  criterion(output, zeros_label)
        D_rec_noise_list.append(output.data.mean().cpu())
        
        # L_GAN = log Dis(x) + log Dis(Dec(z)) + log D(Dec(Enc(x)))
        dis_img_loss = errD_real + errD_rec_enc + errD_rec_noise
        D_list.append(dis_img_loss.data.mean().cpu())
        opt_dis.zero_grad()
        dis_img_loss.backward(retain_graph=True)
        opt_dis.step()
        
        # train decoder
        # gamma * L_llike^Disl - L_GAN
        # L_llike^Disl = \sum q(z|x) log p(Dis(x)|z)
        # 也就是x送入Discriminator和recon_x送入Discriminator的feature的MSE
        output = D(datav)
        # log Dis(x)
        errD_real =  criterion(output, ones_label)
        output = D(rec_enc)
        # log Dis(Dec(Enc(x)))
        errD_rec_enc =  criterion(output, zeros_label)
        output = D(rec_noise)
        # log Dis(Dec(z))
        errD_rec_noise =  criterion(output, zeros_label)
                              
        # recon_x
        similarity_rec_enc = D.similarity(rec_enc)
        # x
        similarity_data = D.similarity(datav)
        
        # L_GAN
        dis_img_loss = errD_real + errD_rec_enc + errD_rec_noise
        # -L_GAN
        gen_img_loss = - dis_img_loss
        
        g_loss_list.append(gen_img_loss.data.mean().cpu())
        # L_llike^Disl: MSE loss
        rec_loss = ((similarity_rec_enc - similarity_data) ** 2).mean()
        rec_loss_list.append(rec_loss.data.mean().cpu())
        # L_Dec = gamma * L_llike^Disl - L_GAN
        err_dec = gamma * rec_loss + gen_img_loss
        
        opt_dec.zero_grad()
        err_dec.backward(retain_graph=True)
        opt_dec.step()
        
        # train encoder
        prior_loss = 1 + logvar - mean.pow(2) - logvar.exp()
        prior_loss = (-0.5 * torch.sum(prior_loss))/torch.numel(mean.data)
        #print (prior_loss, mean, std)
        prior_loss_list.append(prior_loss.data.mean().cpu())
        err_enc = prior_loss + beta * rec_loss
        
        opt_enc.zero_grad()
        err_enc.backward()
        opt_enc.step()
        
#         print("running...")
        
    if epoch % 10 == 0:
        save_model(epoch, G.encoder, G.decoder, D)
    _, _, rec_imgs = G(fixed_batch)
    show_and_save('image/rec_epoch_%d.png' % epoch ,make_grid((rec_imgs.data*0.5+0.5).cpu(),8))
    '''
    vutils.save_image(rec_imgs.data,
            'rec_epoch_%d.png' % epoch,
            normalize=True)
    '''
    samples = G.decoder(fixed_noise)
    vutils.save_image(samples.data,
            'image/sample_epoch_%d.png' % epoch,
            normalize=True)
    localtime = time.asctime( time.localtime(time.time()))
    print (localtime)
    print ('[%d/%d]: D_real:%.4f, D_enc:%.4f, D_noise:%.4f, Loss_D:%.4f, Loss_G:%.4f, rec_loss:%.4f, prior_loss:%.4f' 
           % (epoch, 
              max_epochs, 
              np.mean(D_real_list), 
              np.mean(D_rec_enc_list), 
              np.mean(D_rec_noise_list), 
              np.mean(D_list), 
              np.mean(g_loss_list),
              np.mean(rec_loss_list),
              np.mean(prior_loss_list)))

loading model...
No checkpoint...
Thu Mar 19 00:47:51 2020
[1/250]: D_real:0.9716, D_enc:0.0242, D_noise:0.0069, Loss_D:0.0887, Loss_G:-0.0210, rec_loss:0.7357, prior_loss:16.9756
Thu Mar 19 00:49:09 2020
[2/250]: D_real:0.9773, D_enc:0.0248, D_noise:0.0013, Loss_D:0.1022, Loss_G:-0.0413, rec_loss:0.9387, prior_loss:612.4183
Thu Mar 19 00:50:27 2020
[3/250]: D_real:0.9904, D_enc:0.0117, D_noise:0.0000, Loss_D:0.0370, Loss_G:-0.0059, rec_loss:1.3270, prior_loss:2.5347
Thu Mar 19 00:51:44 2020
[4/250]: D_real:0.9895, D_enc:0.0116, D_noise:0.0000, Loss_D:0.0676, Loss_G:-0.0321, rec_loss:1.5391, prior_loss:3.3592
Thu Mar 19 00:53:01 2020
[5/250]: D_real:0.9961, D_enc:0.0057, D_noise:0.0000, Loss_D:0.0166, Loss_G:-0.0033, rec_loss:1.8123, prior_loss:10.0930
Thu Mar 19 00:54:17 2020
[6/250]: D_real:0.9981, D_enc:0.0019, D_noise:0.0000, Loss_D:0.0040, Loss_G:-0.0011, rec_loss:1.9981, prior_loss:37.7629
Thu Mar 19 00:55:33 2020
[7/250]: D_real:0.9928, D_enc:0.0051, D_noise:0.0000, Loss_D:0.036

RuntimeError: CUDA out of memory. Tried to allocate 128.00 MiB (GPU 0; 2.00 GiB total capacity; 913.75 MiB already allocated; 12.44 MiB free; 128.25 MiB cached)

<Figure size 1800x1200 with 0 Axes>

<Figure size 1800x1200 with 0 Axes>

<Figure size 1800x1200 with 0 Axes>

<Figure size 1800x1200 with 0 Axes>

<Figure size 1800x1200 with 0 Axes>

<Figure size 1800x1200 with 0 Axes>

<Figure size 1800x1200 with 0 Axes>

<Figure size 1800x1200 with 0 Axes>

<Figure size 1800x1200 with 0 Axes>

<Figure size 1800x1200 with 0 Axes>