# multiple_DR_GAN

ジェネレータに同一人物の複数の画像（異なるポーズ・シーンなど）を入力し，出力されたそれぞれの特徴量を
重み付けして足し合わせた特徴量を元に画像を生成する


In [1]:
import os
import numpy as np
import torch
from torch import nn, optim
from torch.autograd import Variable
import pdb

## Discriminator の定義

single-image_DR_GAN と同じ
> - 論文で用いられている  TensorFlow のConv オプション padding="SAME"と同じ挙動を再現するために padding layer を間に追加
- 入力は バッチ数(B) x 3 x 96 x 96（PyTorch のチャネル先頭形式）
- 個人の識別（Nd 人 + 生成画像を表す 1 クラス = Nd+1 クラス）と姿勢の推定（Np クラス）を同時に行う

In [None]:
class Discriminator(nn.Module):
    """DR-GAN discriminator (identical to the single-image DR-GAN one).

    Maps a B x 3 x 96 x 96 batch to B x (Nd + 1 + Np) logits: the first
    ``Nd + 1`` columns classify identity, the last ``Np`` columns classify pose.

    Every stride-2 convolution is preceded by ``nn.ZeroPad2d((0, 1, 0, 1))`` so
    the spatial size halves exactly as TensorFlow's ``padding="SAME"`` would
    (96 -> 48 -> 24 -> 12 -> 6), after which a 6x6 average pool yields 1x1.

    Args:
        Nd: number of identities.
        Np: number of discrete poses.
    """

    def __init__(self, Nd, Np):
        super().__init__()

        # (in_channels, out_channels, stride) of each conv -> BatchNorm -> ELU unit
        spec = [
            (3, 32, 1), (32, 64, 1), (64, 64, 2),        # 96x96 -> 48x48
            (64, 64, 1), (64, 128, 1), (128, 128, 2),    # 48x48 -> 24x24
            (128, 96, 1), (96, 192, 1), (192, 192, 2),   # 24x24 -> 12x12
            (192, 128, 1), (128, 256, 1), (256, 256, 2), # 12x12 -> 6x6
            (256, 160, 1), (160, 320, 1),                # 6x6
        ]
        layers = []
        for c_in, c_out, stride in spec:
            if stride == 2:
                # pad right/bottom by one, then an unpadded stride-2 conv ("SAME")
                layers.append(nn.ZeroPad2d((0, 1, 0, 1)))
                layers.append(nn.Conv2d(c_in, c_out, 3, 2, 0, bias=False))
            else:
                layers.append(nn.Conv2d(c_in, c_out, 3, 1, 1, bias=False))
            layers += [nn.BatchNorm2d(c_out), nn.ELU()]
        layers.append(nn.AvgPool2d(6, stride=1))  # Bx320x6x6 -> Bx320x1x1

        self.convLayers = nn.Sequential(*layers)
        self.fc = nn.Linear(320, Nd+1+Np)

        # All conv / linear weights are drawn from N(0, 0.02)
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                module.weight.data.normal_(0, 0.02)

    def forward(self, input):
        # conv stack + average pooling -> B x 320 x 1 x 1; drop the two spatial
        # dims by index so a batch of size 1 keeps its batch dimension
        pooled = self.convLayers(input).squeeze(2).squeeze(2)  # B x 320
        return self.fc(pooled)  # B x (Nd+1+Np)
    

# Generator の定義

- G_enc は 同一人物に n 枚の画像があるとして nB x 3 x 96 x 96 -> nB x 321 x 1 x 1 -> B x 320 x 1 x 1 と特徴量を encode（321 チャネル目は各画像の重み w で，これを用いて n 枚分の特徴量を重み付き平均する）
- G_dec は single-image DR_GANと同じ
> single-image の時
    - G_enc は Discriminator と最後の全結合層が無い以外同じ構造
    - G_dec のアップサンプリング時は，ダウンサンプリング時に ZeroPad2d でパディングしたことの逆として，stride 2 の ConvTranspose2d（+ BatchNorm + ELU）の後に Crop で下端・右端を 1 画素ずつ切り取る（負のパディングに相当）

In [None]:
## Note: a subclass of nn.Module must call the parent constructor via super();
## otherwise self._modules is never created and the later weight-initialisation
## loop fails. self._modules is nn.Module's registry of child modules.

class Crop(nn.Module):
    """Cut pixels off the borders of an N x C x H x W tensor.

    Acts as the inverse of ``nn.ZeroPad2d``; e.g. ``Crop([0, 1, 0, 1])`` drops
    the last row and the last column.

    Args:
        crop_list: ``[crop_top, crop_bottom, crop_left, crop_right]`` pixel counts.
    """

    def __init__(self, crop_list):
        super().__init__()
        self.crop_list = crop_list

    def forward(self, x):
        _, _, height, width = x.size()
        top, bottom = self.crop_list[0], self.crop_list[1]
        left, right = self.crop_list[2], self.crop_list[3]
        return x[:, :, top:height - bottom, left:width - right]

In [None]:
# Apply a sigmoid only to the weight channel, then use those weights to combine
# the outputs of each group of n images.
# Input: nB x 321 x 1 x 1 -> Output: B x 320 x 1 x 1

class WSum_feature(nn.Module):
    """Fuse each group of ``n`` consecutive encoder outputs into one feature.

    The last channel of the input is a per-image weight logit. After a sigmoid,
    the remaining channels of each group are combined as a weighted average:
    ``sum_i(w_i * f_i) / sum_i(w_i)``.

    Args:
        n: number of images supplied per person (group size along dim 0).
    """

    def __init__(self, n):
        super().__init__()
        # how many images are given per person
        self.n = n

    def forward(self, x):
        # separate features from the weight channel, then chunk dim 0 into groups
        # of n rows (torch.split keeps a shorter trailing group if any)
        feat_groups = x[:, :-1].split(self.n, 0)
        weight_groups = x[:, -1].unsqueeze(1).sigmoid().split(self.n, 0)

        # nB x 320 (x1x1) -> B x 320 (x1x1)
        fused = [
            (feat * weight).sum(0, keepdim=True) / weight.sum(0)
            for feat, weight in zip(feat_groups, weight_groups)
        ]
        return torch.cat(fused)
        

In [110]:
# Toy input for checking the weighted sum: n images for each of B people with
# 321 channels (320 features + 1 weight); every value in row r equals r.
n = 6
B = 50
w = np.arange(n * B).reshape(n * B, 1)  # derived from n and B instead of a hard-coded 300
tmp = np.ones((n * B, 321))
x = Variable(torch.FloatTensor(w * tmp), requires_grad=True)

In [127]:
# Scratch check of the weighted sum. The sigmoid is commented out here, so the
# raw weight channel is used, which keeps the expected numbers easy to verify.
weight = x[:, -1].unsqueeze(1)  # .sigmoid()
features = x * weight
features = features[:, :-1].split(n, 0)  # groups of n rows
features = torch.cat(features, 1)        # n x (320 * B)
features = features.sum(0, keepdim=True) # 1 x (320 * B)
features = features.view(B, -1)          # B x 320 (was hard-coded to 50)


torch.Size([300, 321])
torch.Size([300, 1])


In [125]:
# NOTE(review): execution counts show this cell ran before In [127], i.e. on the
# un-reduced `features`; re-running it after L165-170 sums an already B x 320 tensor.
z = torch.sum(features, dim=0, keepdim=True)

In [126]:
# reshape to one row per person; use B instead of a hard-coded 50
z.view(B, -1)

Variable containing:
 5.5000e+01  5.5000e+01  5.5000e+01  ...   5.5000e+01  5.5000e+01  5.5000e+01
 4.5100e+02  4.5100e+02  4.5100e+02  ...   4.5100e+02  4.5100e+02  4.5100e+02
 1.2790e+03  1.2790e+03  1.2790e+03  ...   1.2790e+03  1.2790e+03  1.2790e+03
                ...                   ⋱                   ...                
 4.8566e+05  4.8566e+05  4.8566e+05  ...   4.8566e+05  4.8566e+05  4.8566e+05
 5.0636e+05  5.0636e+05  5.0636e+05  ...   5.0636e+05  5.0636e+05  5.0636e+05
 5.2749e+05  5.2749e+05  5.2749e+05  ...   5.2749e+05  5.2749e+05  5.2749e+05
[torch.FloatTensor of size 50x320]

In [134]:
import pdb
def WSum_feature(x, n):
    """Fuse the features of each group of ``n`` consecutive images into one row.

    A sigmoid is applied only to the last (weight) channel; the remaining
    channels are multiplied by that weight and summed (not averaged) within each
    group of ``n`` rows. The result is flattened to one row per group.

    Input:  nB x 321 (optionally x 1 x 1)  ->  Output: B x 320

    n : number of images given per person.

    Raises:
        ValueError: if the first dimension of ``x`` is not a multiple of ``n``.

    NOTE: this function shadows the ``WSum_feature`` nn.Module class defined in
    an earlier cell. ``Generator.__init__`` calls ``WSum_feature(n)`` to build that
    module, so constructing a Generator after running this cell fails.
    """
    if x.size(0) % n != 0:
        raise ValueError(
            "first dimension ({}) must be a multiple of n ({})".format(x.size(0), n))
    n_groups = x.size(0) // n  # B, previously hard-coded as 50

    # keep the channel dim so the weight broadcasts over every feature channel
    weight = x[:, -1:].sigmoid()
    weighted = x[:, :-1] * weight

    # rows k*n .. k*n+n-1 form group k; sum each group, then flatten per group
    summed = weighted.view(n_groups, n, *weighted.size()[1:]).sum(1)
    return summed.view(n_groups, -1)

In [133]:
WSum_feature(x, n=n)

Variable containing:
 5.5000e+01  5.5000e+01  5.5000e+01  ...   5.5000e+01  5.5000e+01  5.5000e+01
 4.5100e+02  4.5100e+02  4.5100e+02  ...   4.5100e+02  4.5100e+02  4.5100e+02
 1.2790e+03  1.2790e+03  1.2790e+03  ...   1.2790e+03  1.2790e+03  1.2790e+03
                ...                   ⋱                   ...                
 4.8566e+05  4.8566e+05  4.8566e+05  ...   4.8566e+05  4.8566e+05  4.8566e+05
 5.0636e+05  5.0636e+05  5.0636e+05  ...   5.0636e+05  5.0636e+05  5.0636e+05
 5.2749e+05  5.2749e+05  5.2749e+05  ...   5.2749e+05  5.2749e+05  5.2749e+05
[torch.FloatTensor of size 50x320]

In [None]:
class Generator(nn.Module):
    """Multi-image DR-GAN generator.

    Encoder (``G_enc_convLayers``): the Discriminator's conv stack, except the
    last conv emits 321 channels (320 identity features + 1 fusion weight).
    ``WSum_feature(n)`` then fuses every ``n`` consecutive images into one
    feature: nB x 3 x 96 x 96 -> B x 320 x 1 x 1.

    Decoder: ``G_dec_fc`` maps ``[feature, pose code, noise]`` to 320 x 6 x 6,
    then ``G_dec_convLayers`` upsamples to B x 3 x 96 x 96. Each stride-2
    ConvTranspose2d (+ BatchNorm + ELU) is followed by ``Crop([0, 1, 0, 1])``,
    undoing the encoder's right/bottom ZeroPad2d (6 -> 13 -> 12, ... -> 97 -> 96).

    Args:
        Np: number of discrete poses (length of the pose code).
        Nz: noise dimension.
        n: number of images per person given to the encoder.

    NOTE(review): ``WSum_feature`` must refer to the nn.Module class here; a later
    cell redefines that name as a plain function, which breaks construction.
    """

    def __init__(self, Np, Nz, n):
        super().__init__()

        # ---- encoder: (in_channels, out_channels, stride) per conv->BN->ELU unit
        enc_spec = [
            (3, 32, 1), (32, 64, 1), (64, 64, 2),        # 96x96 -> 48x48
            (64, 64, 1), (64, 128, 1), (128, 128, 2),    # 48x48 -> 24x24
            (128, 96, 1), (96, 192, 1), (192, 192, 2),   # 24x24 -> 12x12
            (192, 128, 1), (128, 256, 1), (256, 256, 2), # 12x12 -> 6x6
            (256, 160, 1),
            (160, 321, 1),  # extra channel = per-image weight w for the fusion
        ]
        enc_layers = []
        for c_in, c_out, stride in enc_spec:
            if stride == 2:
                enc_layers.append(nn.ZeroPad2d((0, 1, 0, 1)))
                enc_layers.append(nn.Conv2d(c_in, c_out, 3, 2, 0, bias=False))
            else:
                enc_layers.append(nn.Conv2d(c_in, c_out, 3, 1, 1, bias=False))
            enc_layers += [nn.BatchNorm2d(c_out), nn.ELU()]
        enc_layers.append(nn.AvgPool2d(6, stride=1))  # nBx321x6x6 -> nBx321x1x1
        enc_layers.append(WSum_feature(n))            # nBx321x1x1 -> Bx320x1x1
        self.G_enc_convLayers = nn.Sequential(*enc_layers)

        # ---- decoder: (in_channels, out_channels, stride) per ConvT->BN->ELU unit
        dec_spec = [
            (320, 160, 1), (160, 256, 1), (256, 256, 2), # 6x6 -> 13x13 -> crop 12x12
            (256, 128, 1), (128, 192, 1), (192, 192, 2), # 12x12 -> 25x25 -> crop 24x24
            (192, 96, 1), (96, 128, 1), (128, 128, 2),   # 24x24 -> 49x49 -> crop 48x48
            (128, 64, 1), (64, 64, 1), (64, 64, 2),      # 48x48 -> 97x97 -> crop 96x96
            (64, 32, 1),
        ]
        dec_layers = []
        for c_in, c_out, stride in dec_spec:
            if stride == 2:
                dec_layers += [
                    nn.ConvTranspose2d(c_in, c_out, 3, 2, 0, bias=False),
                    nn.BatchNorm2d(c_out),
                    nn.ELU(),
                    Crop([0, 1, 0, 1]),  # drop the extra bottom row / right column
                ]
            else:
                dec_layers += [
                    nn.ConvTranspose2d(c_in, c_out, 3, 1, 1, bias=False),
                    nn.BatchNorm2d(c_out),
                    nn.ELU(),
                ]
        # final projection to 3 channels: no BatchNorm
        dec_layers += [nn.ConvTranspose2d(32, 3, 3, 1, 1, bias=False), nn.ELU()]
        self.G_dec_convLayers = nn.Sequential(*dec_layers)

        self.G_dec_fc = nn.Linear(320+Np+Nz, 320*6*6)

        # All conv / transposed-conv / linear weights are drawn from N(0, 0.02)
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
                module.weight.data.normal_(0, 0.02)

    def forward(self, input, pose, noise):
        # nBx3x96x96 -> Bx320x1x1 -> Bx320 (squeeze by index keeps B == 1 intact)
        identity = self.G_enc_convLayers(input).squeeze(2).squeeze(2)

        # B x (320 + Np + Nz)
        code = torch.cat([identity, pose, noise], 1)

        # B x (320*6*6) -> B x 320 x 6 x 6
        grid = self.G_dec_fc(code).view(-1, 320, 6, 6)

        # B x 320 x 6 x 6 -> B x 3 x 96 x 96
        return self.G_dec_convLayers(grid)
    

# 画像の取得

# ランダム入力データの生成

In [None]:
data_size = 80
imnum_each_person = 5
Nd = 200
Np = 9

images = np.random.randn(data_size, 3, 96, 96)

# IDs are in 0 .. Nd-1. Each run of imnum_each_person consecutive images belongs
# to the same person: the generator fuses such groups and the training loop
# reads one label per group (id_label_tensor[::imnum_each_person]).
n_people = -(-data_size // imnum_each_person)  # ceil division
id_labels = np.repeat(np.random.randint(Nd, size=n_people), imnum_each_person)[:data_size]
pose_labels = np.random.randint(Np, size=data_size)

# 訓練の実行

In [None]:
# Compute the Discriminator's accuracy; if it reaches the given threshold the
# Discriminator is considered strong enough.

def _accuracy(pred, target):
    """Return the fraction of positions where ``pred == target`` as a Python float.

    Going through NumPy works on both Variable-era and current PyTorch
    (``.data[0]`` on a 0-dim tensor raises on current versions).
    """
    return float((pred == target).data.cpu().numpy().mean())


def Is_D_strong(real_output, syn_output, id_label_tensor, pose_label_tensor, syn_id_label_tensor, Nd, batch_size, thresh=0.9):
    """Return True when the Discriminator's mean accuracy is at least ``thresh``.

    The mean is taken over three accuracies: identity on real images, pose on
    real images, and the "synthetic" class (label Nd) on generated images.

    Args:
        real_output: D logits for real images, columns ``[:Nd+1]`` = identity,
            ``[Nd+1:]`` = pose.
        syn_output: D logits for generated images (same column layout).
        id_label_tensor, pose_label_tensor: targets for ``real_output``.
        syn_id_label_tensor: targets for ``syn_output``'s identity columns.
        Nd: number of identities.
        batch_size: kept for backward compatibility and no longer used. Each
            accuracy is averaged over its own rows, because real and synthetic
            outputs can have different row counts (n images per person vs. one
            generated image per person). A single shared denominator could push
            the "accuracy" above 1.
        thresh: accuracy threshold.
    """
    _, id_real_ans = torch.max(real_output[:, :Nd+1], 1)
    _, pose_real_ans = torch.max(real_output[:, Nd+1:], 1)
    _, id_syn_ans = torch.max(syn_output[:, :Nd+1], 1)

    id_real_precision = _accuracy(id_real_ans, id_label_tensor)
    pose_real_precision = _accuracy(pose_real_ans, pose_label_tensor)
    id_syn_precision = _accuracy(id_syn_ans, syn_id_label_tensor)

    total_precision = (id_real_precision + pose_real_precision + id_syn_precision) / 3

    return total_precision >= thresh

In [None]:
batch_size = 10
imnum_each_person = 5  # images per person in the data (consecutive rows share one ID)
epoch = 10000  # number of training epochs
image_size = images.shape[0]
epoch_time = np.ceil(image_size / batch_size).astype(int)  # mini-batches per epoch

Nd = 200  # number of IDs (persons)
Np = 9  # number of discrete poses
Nz = 50  # noise dimension

lr_Adam = 0.0002
m_Adam = 0.5  # Adam beta1
flag_D_strong = False

# G fuses every imnum_each_person consecutive images into one feature and the
# labels are read per group (id_label_tensor[::imnum_each_person]), so every
# mini-batch must consist of whole groups.
if batch_size % imnum_each_person != 0 or image_size % imnum_each_person != 0:
    raise ValueError("batch_size and the number of images must be multiples of imnum_each_person")

D = Discriminator(Nd, Np)
G = Generator(Np, Nz, imnum_each_person)
# lr_Adam / m_Adam were defined but never passed, so Adam ran with its defaults
optimizer_D = optim.Adam(D.parameters(), lr=lr_Adam, betas=(m_Adam, 0.999))
optimizer_G = optim.Adam(G.parameters(), lr=lr_Adam, betas=(m_Adam, 0.999))
loss_criterion = nn.CrossEntropyLoss()


def _loss_value(loss):
    """Return a 1-element loss as a Python float on both old and current PyTorch."""
    return float(loss.data.cpu().numpy().reshape(-1)[0])


for ep in range(epoch):
    for i in range(epoch_time):
        D.zero_grad()
        G.zero_grad()
        start = i * batch_size
        end = start + batch_size
        batch_image = images[start:end]
        batch_id_label = id_labels[start:end]
        batch_pose_label = pose_labels[start:end]
        minibatch_size = len(batch_image)
        # The encoder sees minibatch_size = n x persons images and emits one
        # feature per person, so conditions are built per person.
        condition_batchsize = minibatch_size // imnum_each_person

        # Variables used during training (CrossEntropyLoss targets must be LongTensor)
        img_tensor = Variable(torch.FloatTensor(batch_image))
        id_label_tensor = Variable(torch.LongTensor(batch_id_label))
        id_label_unique_tensor = id_label_tensor[::imnum_each_person]  # one ID per person
        pose_label_tensor = Variable(torch.LongTensor(batch_pose_label))

        # generated images are labelled with the extra "synthetic" class Nd
        syn_id_labels = Nd * np.ones(condition_batchsize).astype(int)
        syn_id_label_tensor = Variable(torch.LongTensor(syn_id_labels))

        # noise and one-hot pose code, one row per person
        fixed_noise_tensor = Variable(torch.FloatTensor(np.random.uniform(-1, 1, (condition_batchsize, Nz))))
        pose_idx = np.random.randint(Np, size=condition_batchsize)
        pose_code = np.zeros((condition_batchsize, Np))
        # one hot per row (`pose_code[:, pose_idx] = 1` set every sampled pose in every row)
        pose_code[np.arange(condition_batchsize), pose_idx] = 1
        pose_code_label_tensor = Variable(torch.LongTensor(pose_idx))  # CrossEntropy target
        pose_code_tensor = Variable(torch.FloatTensor(pose_code))  # condition input for G

        # generate images
        generated = G(img_tensor, pose_code_tensor, fixed_noise_tensor)

        # Alternate D and G every batch; once D is strong (>= 90% accuracy) train D:G at 1:4
        train_D = (i % 5 == 0) if flag_D_strong else (i % 2 == 0)

        if train_D:
            # ---- Discriminator step
            real_output = D(img_tensor)
            syn_output = D(generated.detach())  # detach: this step must not update G

            # cross entropy on real id, real pose and the synthetic id of generated images
            d_loss = (loss_criterion(real_output[:, :Nd+1], id_label_tensor)
                      + loss_criterion(real_output[:, Nd+1:], pose_label_tensor)
                      + loss_criterion(syn_output[:, :Nd+1], syn_id_label_tensor))

            d_loss.backward()
            optimizer_D.step()
            print("EPOCH : {0}, D : {1}".format(ep, _loss_value(d_loss)))

            # decide whether D is strong (same batch argument in both regimes)
            flag_D_strong = Is_D_strong(real_output, syn_output, id_label_tensor, pose_label_tensor,
                                        syn_id_label_tensor, Nd, minibatch_size)
        else:
            # ---- Generator step
            syn_output = D(generated)

            # cross entropy: identity vs. the source person's ID, pose vs. the requested pose code
            g_loss = (loss_criterion(syn_output[:, :Nd+1], id_label_unique_tensor)
                      + loss_criterion(syn_output[:, Nd+1:], pose_code_label_tensor))

            # backward() was missing, so optimizer_G.step() had no gradients to apply
            g_loss.backward()
            optimizer_G.step()
            print("EPOCH : {0}, G : {1}".format(ep, _loss_value(g_loss)))

    # save the models after every epoch
    torch.save(D, "D.model")
    torch.save(G, "G.model")
    