-
Notifications
You must be signed in to change notification settings - Fork 2
/
RaSGAN.py
104 lines (85 loc) · 4.31 KB
/
RaSGAN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import tensorflow as tf
from models.generative.ops import *
from models.generative.utils import *
from models.generative.gans.GAN import GAN
class RaSGAN(GAN):
def __init__(self,
data, # Dataset class, training and test data.
z_dim, # Latent space dimensions.
use_bn, # Batch Normalization flag to control usage in discriminator.
alpha, # Alpha value for LeakyReLU.
beta_1, # Beta 1 value for Adam Optimizer.
learning_rate_g, # Learning rate generator.
learning_rate_d, # Learning rate discriminator.
n_critic=1, # Number of batch gradient iterations in Discriminator per Generator.
loss_type = 'relativistic standard', # Loss function type: Standard, Least Square, Wasserstein, Wasserstein Gradient Penalty.
model_name='RaSGAN' # Name to give to the model.
):
super().__init__(data=data, z_dim=z_dim, use_bn=use_bn, alpha=alpha, beta_1=beta_1, learning_rate_g=learning_rate_g, learning_rate_d=learning_rate_d,
n_critic=n_critic, loss_type=loss_type, model_name=model_name)
def discriminator(self, images, reuse):
with tf.variable_scope('discriminator', reuse=reuse):
# Padding = 'Same' -> H_new = H_old // Stride
# Conv.
net = convolutional(inputs=images, output_channels=64, filter_size=5, stride=2, padding='SAME', conv_type='convolutional', scope=1)
net = leakyReLU(net, self.alpha)
# Shape = (None, 28, 28, 64)
# Conv.
net = convolutional(inputs=net, output_channels=128, filter_size=5, stride=2, padding='SAME', conv_type='convolutional', scope=2)
if self.use_bn: net = tf.layers.batch_normalization(inputs=net, training=True)
net = leakyReLU(net, self.alpha)
# Shape = (None, 14, 14, 128)
# Conv.
net = convolutional(inputs=net, output_channels=256, filter_size=5, stride=2, padding='SAME', conv_type='convolutional', scope=3)
if self.use_bn: net = tf.layers.batch_normalization(inputs=net, training=True)
net = leakyReLU(net, self.alpha)
# Shape = (None, 7, 7, 256)
# Flatten.
net = tf.layers.flatten(inputs=net)
# Shape = (None, 7*7*256)
# Dense.
net = dense(inputs=net, out_dim=1024, scope=1)
if self.use_bn: net = tf.layers.batch_normalization(inputs=net, training=True)
net = leakyReLU(net, self.alpha)
# Shape = (None, 1024)
# Dense
logits = dense(inputs=net, out_dim=1, scope=2)
# Shape = (None, 1)
output = tf.nn.sigmoid(x=logits)
return output, logits
def generator(self, z_input, reuse, is_train):
with tf.variable_scope('generator', reuse=reuse):
# Doesn't work ReLU, tried.
# Input Shape = (None, 100)
# Dense.
net = dense(inputs=z_input, out_dim=1024, scope=1)
net = tf.layers.batch_normalization(inputs=net, training=is_train)
net = leakyReLU(net, self.alpha)
# Shape = (None, 1024)
# Dense.
net = dense(inputs=net, out_dim=256*7*7, scope=2)
net = tf.layers.batch_normalization(inputs=net, training=is_train)
net = leakyReLU(net, self.alpha)
# Shape = (None, 256*7*7)
# Reshape
net = tf.reshape(tensor=net, shape=(-1, 7, 7, 256), name='reshape')
# Shape = (None, 7, 7, 256)
# Conv.
net = convolutional(inputs=net, output_channels=256, filter_size=4, stride=2, padding='SAME', conv_type='transpose', scope=1)
net = tf.layers.batch_normalization(inputs=net, training=is_train)
net = leakyReLU(net, self.alpha)
# Shape = (None, 14, 14, 256)
# Conv.
net = convolutional(inputs=net, output_channels=128, filter_size=5, stride=1, padding='SAME', conv_type='convolutional', scope=2)
net = tf.layers.batch_normalization(inputs=net, training=is_train)
net = leakyReLU(net, self.alpha)
# Shape = (None, 14, 14, 128)
# Conv.
net = convolutional(inputs=net, output_channels=128, filter_size=4, stride=2, padding='SAME', conv_type='transpose', scope=3)
net = tf.layers.batch_normalization(inputs=net, training=is_train)
net = leakyReLU(net, self.alpha)
# Shape = (None, 28, 28, 128)
logits = convolutional(inputs=net, output_channels=self.image_channels, filter_size=4, stride=2, padding='SAME', conv_type='transpose', scope=4)
# Shape = (None, 56, 56, 3)
output = tf.nn.sigmoid(x=logits, name='output')
return output