In [None]:
%matplotlib inline

# 美少女無窮生成 pytorch

### 測試於pytorch 1.0

![md_images](../Images/gan.jpg)

![md_images](../Images/rasgan.png)

In [None]:
import matplotlib
import matplotlib.pyplot as plt
import pylab
import PIL
from PIL import Image
import numpy as np
import os
import datetime
import time
import glob
import pylab
import cv2
import math
import string
import random
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Image geometry constants: height, width and channel count.
img_h = 64
img_w = 64
img_c = 3

# Draw Gaussian noise vectors
def noise_sample(num_samples, g_input_dim=100):
    """Sample a batch of standard-normal noise vectors.

    Args:
        num_samples: number of rows (vectors) to draw.
        g_input_dim: length of each vector.

    Returns:
        A float32 ``numpy`` array of shape ``(num_samples, g_input_dim)``.
    """
    latent_shape = [num_samples, g_input_dim]
    latent = np.random.normal(size=latent_shape)
    return latent.astype(np.float32)

# Collect the paths of every .jpg file in the training image directory.
s =  glob.glob('../Data/ex08_train/resized_images/' + '*.jpg')
print('{0}張圖片'.format(len(s)))
# Shuffle once up front; idx is the read cursor into s used by next_minibatch.
random.shuffle(s)
idx = 0

In [None]:
# Convert an image into a channel-first float array
def img2array(img: "Image.Image"):
    """Convert an image to a float32 array in C x H x W layout with reversed channels.

    Args:
        img: anything ``np.array`` can turn into a 3-D array; expected to be
            an H x W x C image (for an RGB PIL image, C is in R, G, B order).

    Returns:
        A C-contiguous float32 array of shape ``(C, H, W)`` whose channel axis
        is reversed relative to the input (RGB becomes BGR).

    Raises:
        ValueError: if the input is not 3-dimensional (e.g. a grayscale image).
    """
    arr = np.array(img).astype(np.float32)
    arr = arr.transpose(2, 0, 1)  # HWC -> CHW
    # Reverse the channel axis first and force contiguity afterwards, so the
    # caller gets a plain contiguous array instead of a negative-stride view.
    return np.ascontiguousarray(arr[::-1])

# Convert a channel-first array back into an image
def array2img(arr: np.ndarray):
    """Build a PIL image from a C x H x W array.

    The channel axis is reversed (undoing the BGR ordering produced by
    ``img2array``), the array is rearranged to H x W x C, values are limited
    to the 0..255 range and finally truncated to ``uint8``.

    Args:
        arr: array in C x H x W layout.

    Returns:
        The image created by ``Image.fromarray``.
    """
    rgb_first = arr[::-1]  # back to RGB channel order
    hwc = np.transpose(rgb_first, (1, 2, 0))  # CHW -> HWC
    capped = np.minimum(255, hwc)
    floored = np.maximum(0, capped)
    return Image.fromarray(floored.astype(np.uint8))

# Add randomly scaled standard-normal noise
def add_noise(image):
    """Perturb ``image`` with Gaussian noise and clip the result to [0, 255].

    A standard-normal field of the same shape as ``image`` is multiplied by a
    single integer factor drawn uniformly from ``-5 .. 4`` (a factor of 0
    leaves the image untouched apart from clipping).

    Args:
        image: numeric ``numpy`` array.

    Returns:
        The noisy array, clipped to the 0..255 range.
    """
    gaussian = np.random.standard_normal(image.shape)
    strength = np.random.choice(np.arange(-5, 5))
    noisy = image + gaussian * strength
    return np.clip(noisy, 0, 255)

# Adjust brightness via gamma correction
def adjust_gamma(image,gamma=1.2):
    """Apply gamma correction to an image with values in the 0..255 range.

    The previous implementation computed the lookup but threw the result
    away and returned the input unchanged. The lookup is now done with NumPy
    indexing (element-wise, so no layout change is needed) and its result is
    returned.

    Args:
        image: ``numpy`` array of pixel values; values are clipped to
            0..255 and truncated to integers before the lookup.
        gamma: correction factor; must be positive. Values above 1 brighten,
            values below 1 darken.

    Returns:
        An array with the same shape and dtype as ``image`` holding the
        gamma-corrected (integer-valued) pixel levels.

    Raises:
        ValueError: if ``gamma`` is not positive.
    """
    if gamma <= 0:
        raise ValueError('gamma must be positive, got {}'.format(gamma))
    inv_gamma = 1.0 / gamma
    # 256-entry lookup table mapping each input level to its corrected level.
    table = (((np.arange(256) / 255.0) ** inv_gamma) * 255).astype(np.uint8)
    # Clip before the integer cast so out-of-range values cannot wrap around.
    levels = np.clip(image, 0, 255).astype(np.uint8)
    return table[levels].astype(image.dtype)

# Blur
def adjust_blur(image):
    """Blur a C x H x W image with ``cv2.blur`` using a 3x3 kernel.

    Args:
        image: ``numpy`` array in C x H x W layout.

    Returns:
        The blurred image, rearranged back to C x H x W layout.
    """
    hwc = image.transpose([1, 2, 0])  # cv2 is given the channel-last layout
    blurred = cv2.blur(hwc, (3, 3))
    return blurred.transpose([2, 0, 1])



def next_minibatch(minibatch_size,is_train=True):
    """Load the next ``minibatch_size`` images from the shuffled file list.

    Reads files from the module-level list ``s`` starting at the cursor
    ``idx``; when the end of the list is reached the list is reshuffled and
    the cursor restarts at 0. Files that raise ``OSError`` are reported and
    skipped.

    Args:
        minibatch_size: number of images to return.
        is_train: when true, apply data augmentation (noise always, gamma
            and blur at random) so the discriminator sees varied inputs and
            is less likely to memorise the samples.

    Returns:
        A float32 array of shape ``(minibatch_size, C, H, W)`` scaled from
        0..255 to roughly -1..1.

    Raises:
        ValueError: if the file list is empty.
        RuntimeError: if every file in the list failed to load in a row.
    """
    global s, idx
    if not s:
        raise ValueError('no training images available: the file list is empty')
    features = []
    consecutive_failures = 0
    while len(features) < minibatch_size:
        try:
            # The context manager closes the file handle once pixels are read.
            with Image.open(s[idx]) as handle:
                # Image.LANCZOS is the filter formerly aliased as ANTIALIAS
                # (the alias was removed in Pillow 10).
                im = handle.convert('RGB').resize((img_w, img_h), Image.LANCZOS)
            im = img2array(im).astype(np.float32)
            # Augment the data to keep the inputs varied and stop the
            # discriminator from overfitting to memorised samples.
            if is_train:
                im=add_noise(im)
                if random.randint(0,10)%2==0:
                    gamma=np.random.choice(np.arange(0.6, 1.5, 0.05))
                    # Keep the result (it used to be stored in an unused name).
                    im=adjust_gamma(im,gamma)
                if random.randint(0,10)%5<=1:
                    im=adjust_blur(im)

            features.append((im-127.5)/127.5)
            consecutive_failures = 0
        except OSError as e:
            print(e)
            consecutive_failures += 1
            if consecutive_failures >= len(s):
                # Every file failed in a row: stop instead of looping forever.
                raise RuntimeError('none of the {} image files could be read'.format(len(s)))
        idx += 1
        if idx >= len(s):
            random.shuffle(s)
            idx = 0
    return np.asarray(features).astype(np.float32)

![md_images](../Images/self_attention_module.png)

In [None]:
# Self-attention mechanism
class Self_Attn(nn.Module):
    """Self-attention layer over the spatial positions of a feature map."""

    def __init__(self,in_dim,activation):
        """Create the 1x1 query/key/value projections.

        Args:
            in_dim: number of input channels; queries and keys use
                ``in_dim // 8`` channels, values keep ``in_dim``.
            activation: stored on the instance; not used by ``forward``.
        """
        super().__init__()
        self.chanel_in = in_dim
        self.activation = activation
        reduced_dim = in_dim // 8
        self.query_conv = nn.Conv2d(in_dim, reduced_dim, kernel_size=1)
        self.key_conv = nn.Conv2d(in_dim, reduced_dim, kernel_size=1)
        self.value_conv = nn.Conv2d(in_dim, in_dim, kernel_size=1)
        # Learnable blend factor; starts at 0 so the layer is initially an identity.
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self,x):
        """Apply self-attention and add the result to the input.

        Args:
            x: 4-D feature map ``(B, C, W, H)``.

        Returns:
            A tuple ``(out, attention)`` where ``out`` has the shape of ``x``
            (``gamma * attended + x``) and ``attention`` is ``(B, N, N)`` with
            ``N = W * H``; each row of ``attention`` sums to 1.
        """
        batch, channels, width, height = x.size()
        positions = width * height
        queries = self.query_conv(x).view(batch, -1, positions).transpose(1, 2)  # B x N x C'
        keys = self.key_conv(x).view(batch, -1, positions)  # B x C' x N
        scores = torch.bmm(queries, keys)  # B x N x N
        attention = self.softmax(scores)
        values = self.value_conv(x).view(batch, -1, positions)  # B x C x N
        attended = torch.bmm(values, attention.transpose(1, 2))
        attended = attended.view(batch, channels, width, height)
        return self.gamma * attended + x, attention
    


![md_images](../Images/pixelshuffle.jpg)

In [None]:
def conv5x5(in_planes, out_planes, stride=1,dilation=1,padding=2):
    """Return a bias-free 5x5 ``nn.Conv2d`` (default padding 2)."""
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=5,
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=False,
    )

def conv3x3(in_planes, out_planes, stride=1,dilation=1,padding=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` (default padding 1)."""
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=False,
    )
def conv1x1(in_planes, out_planes, stride=1,dilation=1):
    """Return a bias-free 1x1 ``nn.Conv2d`` (no padding)."""
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        dilation=dilation,
        bias=False,
    )

class resnet_basic(nn.Module):
    """Residual block: two 3x3 conv + batch-norm stages with a scaled skip add.

    The output is ``relu(shortcut(x) + 0.2 * branch(x))``. The channel count
    is unchanged.
    """

    def __init__(self, inplanes, stride=1):
        """Build the block.

        Args:
            inplanes: number of input (and output) channels.
            stride: stride of the first convolution. With ``stride != 1`` the
                skip path gets a strided 1x1 conv + batch-norm so both
                branches have the same spatial size. Previously both
                convolutions were strided while the skip path was not, so
                the addition failed with a shape mismatch.
        """
        super(resnet_basic, self).__init__()
        self.conv1 = conv3x3(inplanes, inplanes, stride)
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.relu = nn.ReLU(inplace=True)
        # Only the first convolution downsamples; the second keeps the size.
        self.conv2 = conv3x3(inplanes, inplanes)
        self.bn2 = nn.BatchNorm2d(inplanes)
        self.stride = stride
        if stride != 1:
            # Created only when needed, so stride-1 blocks keep exactly the
            # same parameters (and checkpoint layout) as before.
            self.downsample = nn.Sequential(
                conv1x1(inplanes, inplanes, stride),
                nn.BatchNorm2d(inplanes),
            )

    def forward(self, x):
        """Run the block on a 4-D feature map and return the activated sum."""
        # getattr keeps blocks unpickled from older checkpoints (which have
        # no ``downsample`` attribute) working.
        shortcut = getattr(self, 'downsample', None)
        residual = x if shortcut is None else shortcut(x)
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        # The residual branch is damped by 0.2 before being added to the skip path.
        out = residual+0.2*out
        out = self.relu(out)
        return out

def weights_init(m):
    """Re-initialise the weight of ``nn.Conv2d`` modules with Kaiming-uniform.

    Intended for ``Module.apply``; modules of any other type are left alone.
    """
    if not isinstance(m, nn.Conv2d):
        return
    nn.init.kaiming_uniform_(m.weight)


class generator(nn.Module):
    """Generator: maps a noise vector to a 3-channel image in [-1, 1].

    A linear layer produces a 256-channel map of side ``input_size // 16``;
    four PixelShuffle(2) stages each double the side (and divide the channel
    count by 4), each followed by a convolutional stack. Self-attention is
    applied on the 64-channel map before the last upsampling stage.
    """

    @staticmethod
    def _conv_bn_relu(conv, channels):
        """Return ``[conv, BatchNorm2d(channels), ReLU()]`` as a list."""
        return [conv, nn.BatchNorm2d(channels), nn.ReLU()]

    def __init__(self, input_dim=100,input_size=64):
        """Build the layers.

        Args:
            input_dim: length of the input noise vector.
            input_size: side of the generated image; the initial feature map
                has side ``input_size // 16``.
        """
        super().__init__()
        self.input_dim = input_dim
        self.input_size = input_size
        base_side = self.input_size // 16

        self.att = Self_Attn(64, 'relu')
        self.fc = nn.Linear(self.input_dim, 256 * base_side * base_side)

        # Stage 1: 256 channels -> shuffle -> 64 -> conv -> 256
        self.ps1 = nn.PixelShuffle(2)
        self.tu1 = nn.Sequential(*self._conv_bn_relu(conv5x5(64, 256, stride=1), 256))

        # Stage 2: 256 -> shuffle -> 64 -> 128
        self.ps2 = nn.PixelShuffle(2)
        stage2 = self._conv_bn_relu(conv5x5(64, 128, stride=1), 128)
        stage2.append(resnet_basic(128, 1))
        stage2 += self._conv_bn_relu(conv3x3(128, 128, stride=1), 128)
        self.tu2 = nn.Sequential(*stage2)

        # Stage 3: 128 -> shuffle -> 32 -> 128 -> 64
        self.ps3 = nn.PixelShuffle(2)
        stage3 = self._conv_bn_relu(conv3x3(32, 128, stride=1), 128)
        stage3.append(resnet_basic(128, 1))
        stage3 += self._conv_bn_relu(conv3x3(128, 64, stride=1), 64)
        self.tu3 = nn.Sequential(*stage3)

        # Stage 4: 64 -> shuffle -> 16 -> 64 (dilated convs) -> 3, tanh output
        self.ps4 = nn.PixelShuffle(2)
        stage4 = self._conv_bn_relu(conv3x3(16, 64, stride=1), 64)
        stage4 += self._conv_bn_relu(conv3x3(64, 64, stride=1, dilation=2, padding=2), 64)
        stage4 += self._conv_bn_relu(conv3x3(64, 64, stride=1, dilation=4, padding=4), 64)
        stage4.append(conv1x1(64, 3, stride=1))
        stage4.append(nn.Tanh())
        self.tu4 = nn.Sequential(*stage4)

    def forward(self, input):
        """Generate images from a batch of noise vectors.

        Args:
            input: tensor whose last dimension is ``input_dim``.

        Returns:
            The tensor produced by the final tanh stage (3 channels).
        """
        side = self.input_size // 16
        hidden = self.fc(input)
        hidden = hidden.view(hidden.size(0), 256, side, side)
        hidden = self.tu1(self.ps1(hidden))
        hidden = self.tu2(self.ps2(hidden))
        hidden = self.tu3(self.ps3(hidden))
        hidden, _attention = self.att(hidden)  # the attention map is not used further
        return self.tu4(self.ps4(hidden))

class discriminator(nn.Module):
    """Discriminator (critic): maps a 3-channel image to one unbounded score.

    ``td1`` downsamples twice to a 64-channel map, self-attention is applied,
    ``td2`` continues with dilated convolutions, dropout and residual blocks
    and ends in global average pooling; a linear layer gives the score.
    """

    @staticmethod
    def _conv_bn_lrelu(conv, channels):
        """Return ``[conv, BatchNorm2d(channels), LeakyReLU(0.2)]`` as a list."""
        return [conv, nn.BatchNorm2d(channels), nn.LeakyReLU(0.2)]

    def __init__(self):
        """Build the layers."""
        super().__init__()

        head = self._conv_bn_lrelu(conv5x5(3, 32, stride=2), 32)
        head += self._conv_bn_lrelu(conv3x3(32, 32, stride=1, dilation=2, padding=2), 32)
        head += self._conv_bn_lrelu(conv5x5(32, 64, stride=2), 64)
        self.td1 = nn.Sequential(*head)

        tail = self._conv_bn_lrelu(conv3x3(64, 64, stride=1, dilation=4, padding=4), 64)
        tail += self._conv_bn_lrelu(conv3x3(64, 128, stride=1, dilation=8, padding=8), 128)
        tail.append(nn.Dropout(0.5))
        tail.append(resnet_basic(128, 1))
        tail += self._conv_bn_lrelu(conv3x3(128, 128, stride=2), 128)
        tail.append(resnet_basic(128, 1))
        tail += self._conv_bn_lrelu(conv3x3(128, 128, stride=1), 128)
        tail.append(conv1x1(128, 64, stride=1))
        tail.append(nn.AdaptiveAvgPool2d((1, 1)))  # global average pool -> 64 x 1 x 1
        self.td2 = nn.Sequential(*tail)

        self.fc = nn.Linear(64, 1)
        self.att = Self_Attn(64, 'relu')

    def forward(self, input):
        """Score a batch of images.

        Args:
            input: 4-D image tensor with 3 channels.

        Returns:
            Tensor of shape ``(B, 1)`` with raw (pre-sigmoid) scores.
        """
        feats = self.td1(input)
        feats, _attention = self.att(feats)  # the attention map is not used further
        feats = self.td2(feats)
        flat = feats.view(feats.size(0), -1)
        return self.fc(flat)

In [None]:
def tile_rgb_images(x, row=2, col=2):
    """Draw up to ``row * col`` generated images in a grid and save it as a PNG.

    Args:
        x: sequence of C x H x W arrays scaled to roughly -1..1; each is
            mapped back to 0..255 before being converted to an image.
        row: number of grid rows.
        col: number of grid columns.

    Side effects:
        Clears and redraws the current pylab figure and writes
        ``Pytorch_Results/RaSGAN_<timestamp>.png`` (the directory is created
        if it does not exist).
    """
    fig = pylab.gcf()
    fig.set_size_inches(col * 2, row * 2)
    pylab.clf()
    pylab.ioff()
    # Never index past the end of x when fewer images than cells are given.
    for m in range(min(row * col, len(x))):
        pylab.subplot(row, col, m + 1)
        img = array2img(x[m]*127.5+127.5)
        pylab.imshow(img, interpolation="nearest", animated=True)
        pylab.axis("off")
    # strftime always emits the microseconds, so every file name has the same
    # length (str(datetime) drops them when they are exactly 0).
    stamp = datetime.datetime.now().strftime('%Y%m%d%H%M%S%f')
    filename='Pytorch_Results/RaSGAN_{}.png'.format(stamp)
    # savefig does not create missing directories.
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    pylab.savefig(filename, bbox_inches='tight')

In [None]:
# Build fresh networks on the selected device and re-initialise their conv weights.
G = generator(100).to(device)
D = discriminator().to(device)

G.apply(weights_init)
D.apply(weights_init)

# Resume from checkpoints when they exist. The files hold whole pickled
# modules (see torch.save in the training loop), so only load trusted files.
# map_location lets a checkpoint saved on a GPU be restored on a CPU-only machine.
# NOTE(review): recent PyTorch releases default torch.load to weights_only=True,
# which rejects pickled modules; pass weights_only=False there if needed.
if os.path.exists('Models/RaSGAN_X_fake_pytorch.cnn'):
    G=torch.load('Models/RaSGAN_X_fake_pytorch.cnn', map_location=device).to(device)
    print('G recovered!!')

if os.path.exists('Models/RaSGAN_D_real_pytorch.cnn'):
    # Assign to D: the checkpoint used to be loaded into an unused variable,
    # so the discriminator was never actually restored.
    D=torch.load('Models/RaSGAN_D_real_pytorch.cnn', map_location=device).to(device)
    print('D recovered!!')


# Optimizers are created after the optional restore so they track the
# parameters of the networks that will actually be trained.
G_optimizer = optim.Adam(G.parameters(), lr=2e-4,betas=(0.0, 0.999), weight_decay=5e-5)
D_optimizer = optim.Adam(D.parameters(), lr=2e-4,betas=(0.9, 0.999), weight_decay=5e-5)


In [None]:
# Relativistic average standard GAN (RaSGAN) training loop.
minibatch_size=32
num_epochs=300
print('epoch start')
# torch.save does not create missing directories.
os.makedirs('Models', exist_ok=True)
D.train()
for epoch in range(num_epochs):
    G.train()
    for mbs in range(1000):
        Z_data = torch.from_numpy(noise_sample(minibatch_size)).to(device)
        X_data = torch.from_numpy(next_minibatch(minibatch_size)).to(device)

        # Small constant that keeps log() away from log(0).
        epsilon=1e-10

        # ---- Discriminator step ----
        D_optimizer.zero_grad()

        G_ = G(Z_data)
        D_real = D(X_data)
        # detach(): the discriminator update must not backpropagate into G.
        D_fake = D(G_.detach())

        D_r_tilde = torch.sigmoid(D_real -D_fake.mean())
        D_f_tilde = torch.sigmoid(D_fake - D_real.mean())
        D_loss = - ((D_r_tilde + epsilon).log()).mean() - ((1 - D_f_tilde + epsilon).log()).mean()

        D_loss.backward()
        D_optimizer.step()

        # ---- Generator step ----
        # The scores are recomputed with the *updated* discriminator. Reusing
        # the tensors from the discriminator step would backpropagate through
        # a graph whose weights were just modified in place by
        # D_optimizer.step(), which gives wrong gradients or a RuntimeError.
        G_optimizer.zero_grad()

        with torch.no_grad():
            # Real scores are constants as far as the generator is concerned.
            D_real = D(X_data)
        D_fake = D(G_)

        D_r_tilde = torch.sigmoid(D_real -D_fake.mean())
        D_f_tilde = torch.sigmoid(D_fake - D_real.mean())
        G_loss =- ((D_f_tilde + epsilon).log()).mean() - ((1 - D_r_tilde + epsilon).log()).mean()

        # This also accumulates gradients in D; they are cleared by
        # D_optimizer.zero_grad() at the start of the next iteration.
        G_loss.backward()
        G_optimizer.step()

        if (mbs+1)%50==0:
            # Report the mean relativistic probabilities (values in 0..1), which
            # is what the percentage format expects, rather than raw scores.
            print("Epoch: {}/{} ".format(epoch+1, num_epochs),
                                      "Step: {} ".format(mbs+1),
                                      "D Loss: {:.4f} (D_real:{:.3%})  ".format(D_loss.item(),D_r_tilde.mean().item()),
                                      "G Loss: {:.4f} (D_fake:{:.3%})  ".format(G_loss.item(),D_f_tilde.mean().item()))
            tile_rgb_images(G_.cpu().detach().numpy(), 4, 4)
            # Whole modules are pickled, matching what the restore cell loads.
            torch.save(G, 'Models/RaSGAN_X_fake_pytorch.cnn')
            torch.save(D, 'Models/RaSGAN_D_real_pytorch.cnn')