In [None]:
import numpy as np
import keras
import keras.layers as lay
import tensorflow as tf
import matplotlib.pyplot as plt

# Preparação do Dataset

In [None]:
# Keep only CIFAR-10 class 1 ("automobile") training images.
(x, y), (_, _) = keras.datasets.cifar10.load_data()
is_car = y[:, 0] == 1
x = tf.cast(x[is_car], tf.float32)
# Map pixel values [0, 255] -> [-1, 1], the range of the generator's tanh output.
x = (x - 127.5) / 127.5

batch_size = 32
# Full batches only: the training step pairs each real batch with exactly
# `batch_size` generated images.
cars = (
    tf.data.Dataset.from_tensor_slices(x)
    .shuffle(1000)
    .batch(batch_size, drop_remainder=True)
    .prefetch(tf.data.AUTOTUNE)
)

# Regularização R1

In [None]:
@tf.function
def r1_regularization(discriminator, batch):
    """R1 gradient penalty evaluated on a batch of (real) images.

    Computes the gradient of the summed discriminator output with respect to
    the input images, takes each sample's squared L2 norm over axes 1-3, and
    averages those norms over the batch.

    Args:
        discriminator: callable model; invoked with ``training=True``.
        batch: rank-4 image tensor (the per-sample reduction runs over axes
            1, 2 and 3).

    Returns:
        Scalar tensor: mean squared input-gradient norm.

    NOTE(review): R1 is normally taken on pre-sigmoid logits; here it uses
    whatever ``discriminator`` returns — confirm that is intended.
    """
    with tf.GradientTape() as input_tape:
        # `batch` is a plain tensor, not a variable, so it must be watched.
        input_tape.watch(batch)
        total_score = tf.reduce_sum(discriminator(batch, training=True))
    input_grads = input_tape.gradient(total_score, batch)
    per_sample_sq_norm = tf.reduce_sum(tf.square(input_grads), axis=[1, 2, 3])
    return tf.reduce_mean(per_sample_sq_norm)

# Declaração da Classe AdaIN

In [None]:
@keras.saving.register_keras_serializable()
class AdaIN(lay.Layer):
    """Adaptive Instance Normalization.

    Each channel of every sample is normalized with its own mean/variance over
    axes 1 and 2. It is then scaled and shifted by per-sample coefficients that
    two Dense layers project from a style vector.

    Args:
        channels: width of the two style projections; the coefficients are
            reshaped to ``(-1, 1, 1, channels)``, so the feature map's last
            axis must have this size (presumably NHWC — TODO confirm for new
            callers).
        **kargs: forwarded unchanged to ``keras.layers.Layer``.

    Call input:
        ``(features_map, style_w)`` pair.
    """

    def __init__(self, channels, **kargs):
        super().__init__(**kargs)
        self.channels = channels
        # Independent projections of the style vector: one for the scale,
        # one for the shift.
        self.dense_gamma = lay.Dense(channels, name='Dense_Gamma')
        self.dense_bias = lay.Dense(channels, name='Dense_Bias')

    def _style_coefficients(self, style_w):
        """Project ``style_w`` to (scale, shift), each shaped (-1, 1, 1, channels)."""
        target_shape = (-1, 1, 1, self.channels)
        scale = tf.reshape(self.dense_gamma(style_w), target_shape)
        shift = tf.reshape(self.dense_bias(style_w), target_shape)
        return scale, shift

    def call(self, inputs, *args, **kwargs):
        content, style = inputs
        scale, shift = self._style_coefficients(style)
        # Instance statistics: per sample and per channel, over axes 1 and 2.
        mu, sigma_sq = tf.nn.moments(content, axes=(1, 2), keepdims=True)
        return tf.nn.batch_normalization(
            content,
            mu,
            sigma_sq,
            offset=shift,
            scale=scale,
            variance_epsilon=1e-6,
        )

# Declaração da Classe StyleGAN

In [None]:
@keras.saving.register_keras_serializable()
class StyleGAN(keras.Model):
    """Style-based generator: maps latent vectors z to 32x32 RGB images.

    Pipeline:
      1. ``mapping_network`` turns z of shape (batch, 128) into a style vector
         w (7 Dense layers of width 128).
      2. A learned constant 4x4x512 map is tiled over the batch.
      3. Four resolution blocks (4x4, 8x8, 16x16, 32x32), each made of two
         "feature map + scaled noise -> AdaIN(w)" stages. In the 4x4 block
         the first stage uses the constant map instead of a convolution.
         Every later block starts with a 2x ``UpSampling2D``.
      4. A 1x1 ``toRGB`` convolution with tanh gives values in [-1, 1].

    Keep attribute names and their creation order in ``__init__`` stable:
    saved ``.keras`` files (e.g. ``style.keras``) are matched against them.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Mapping network z -> w.
        self.mapping_network = keras.Sequential([
            lay.InputLayer(shape=(128,)),
            *[lay.Dense(128, activation='leaky_relu') for i in range(7)]
            ], name='mapping_network')

        # Learned constant starting map, shared by every sample.
        self.latent_map = self.add_weight(
            shape=(1, 4, 4, 512),
            trainable=True
        )

        # conv2d[0] refines the 4x4 map; conv2d[1..6] come in pairs, one pair
        # per upsampled resolution (8, 16, 32); conv2d[-1] is the 1x1 toRGB.
        self.conv2d = [
            lay.Conv2D(512,3,1,'same',activation='leaky_relu'),
            lay.Conv2D(256,3,1,'same',activation='leaky_relu'),
            lay.Conv2D(256,3,1,'same',activation='leaky_relu'),
            lay.Conv2D(128,3,1,'same',activation='leaky_relu'),
            lay.Conv2D(128,3,1,'same',activation='leaky_relu'),
            lay.Conv2D(64,3,1,'same',activation='leaky_relu'),
            lay.Conv2D(64,3,1,'same',activation='leaky_relu'),
            lay.Conv2D(3,1,1,'same',activation='tanh',name='toRGB')
        ]

        # Per-channel noise scales; noises[k] is used by stage k.
        self.noises = [
            self.add_weight(shape=(1,1,1,512),trainable=True),
            self.add_weight(shape=(1,1,1,512),trainable=True),
            self.add_weight(shape=(1,1,1,256),trainable=True),
            self.add_weight(shape=(1,1,1,256),trainable=True),
            self.add_weight(shape=(1,1,1,128),trainable=True),
            self.add_weight(shape=(1,1,1,128),trainable=True),
            self.add_weight(shape=(1,1,1,64),trainable=True),
            self.add_weight(shape=(1,1,1,64),trainable=True),
        ]

        # Adaptive instance normalization; AdaIN[k] closes stage k.
        self.AdaIN = [
            AdaIN(512),
            AdaIN(512),
            AdaIN(256),
            AdaIN(256),
            AdaIN(128),
            AdaIN(128),
            AdaIN(64),
            AdaIN(64),
        ]

        # One 2x upsampling per resolution increase: 4->8->16->32.
        self.upsampling2d = [lay.UpSampling2D() for i in range(3)]

    @staticmethod
    def _add_noise(x, scale):
        """Add single-channel Gaussian noise shaped like ``x``, scaled per channel by ``scale``."""
        x_shape = tf.shape(x)
        noise = tf.random.normal(tf.concat([x_shape[:3], [1]], axis=0))
        return x + scale * noise

    def call(self, z, *args, **kwargs):
        """Generate images from latent batch ``z``; extra args (e.g. ``training``) are ignored."""
        # Dynamic batch size: the static `z.shape[0]` is None for an unknown
        # batch dimension, which would break tf.tile / tf.random.normal.
        batch_size = tf.shape(z)[0]
        w = self.mapping_network(z)

        # 4x4 block: the constant map takes the place of the first conv.
        x = tf.tile(self.latent_map, (batch_size, 1, 1, 1))
        x = self._add_noise(x, self.noises[0])
        x = self.AdaIN[0]([x, w])
        x = self._add_noise(self.conv2d[0](x), self.noises[1])
        x = self.AdaIN[1]([x, w])

        # 8x8, 16x16 and 32x32 blocks: upsample, then two
        # conv -> noise -> AdaIN stages using conv2d[2s+1], conv2d[2s+2].
        for stage, upsample in enumerate(self.upsampling2d):
            x = upsample(x)
            for conv_idx in (2 * stage + 1, 2 * stage + 2):
                x = self._add_noise(self.conv2d[conv_idx](x), self.noises[conv_idx + 1])
                x = self.AdaIN[conv_idx + 1]([x, w])

        # 1x1 toRGB with tanh.
        return self.conv2d[-1](x)

In [None]:
generator = StyleGAN()

# Discriminator: ResNet50 backbone without its classification head, followed
# by a single sigmoid unit scoring how "real" an image looks.
# NOTE(review): the ImageNet weights expect `resnet.preprocess_input`-style
# inputs, while the images used here are scaled to [-1, 1]; consider
# preprocessing or `weights=None` — TODO confirm.
resnet = keras.applications.resnet.ResNet50(
    include_top=False,
    weights='imagenet',
    input_shape=(32, 32, 3),
)
d_features = lay.Flatten()(resnet.output)
d_score = lay.Dense(1, activation='sigmoid')(d_features)
discriminator = keras.Model(resnet.input, d_score)

# The training code indexes both lists as [generator, discriminator].
gan = [generator, discriminator]
opt = [keras.optimizers.RMSprop(1e-4), keras.optimizers.RMSprop(1e-4)]

In [None]:
@tf.function
def _train_epoch(gan, data, opt, batch_size, apply_r1, r1_weight):
    """Run one graph-compiled adversarial training pass over ``data``.

    Args:
        gan: ``[generator, discriminator]``.
        data: dataset of real image batches.
        opt: ``[generator_optimizer, discriminator_optimizer]``.
        batch_size: latent vectors sampled per step; must equal the size of
            every real batch, because real and fake per-sample losses are
            added element-wise.
        apply_r1: Python bool; when True the R1 penalty is added to the
            discriminator loss on every step of this pass.
        r1_weight: multiplier applied to the R1 penalty.

    Returns:
        ``(g_loss, d_loss)`` scalar tensors: the mean over batches of each
        batch's mean loss (0 when ``data`` is empty).
    """
    generator, discriminator = gan[0], gan[1]
    g_opt, d_opt = opt[0], opt[1]

    g_total = tf.constant(0.0)
    d_total = tf.constant(0.0)
    n_batches = tf.constant(0.0)
    for batch in data:
        latent_z = tf.random.normal((batch_size, 128))
        with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
            # `training=True` is the flag Keras layers (e.g. BatchNormalization)
            # read to select training-mode behaviour; `trainable=` is not.
            fake_imgs = generator(latent_z, training=True)
            real_scores = discriminator(batch, training=True)
            fake_scores = discriminator(fake_imgs, training=True)
            # Non-saturating generator loss: fakes labelled as real.
            g_loss = keras.losses.binary_crossentropy(
                tf.ones_like(fake_scores), fake_scores)
            d_loss = (
                keras.losses.binary_crossentropy(tf.ones_like(real_scores), real_scores)
                + keras.losses.binary_crossentropy(tf.zeros_like(fake_scores), fake_scores)
            )
            if apply_r1:
                d_loss += r1_weight * r1_regularization(discriminator, batch)

        g_grads = g_tape.gradient(g_loss, generator.trainable_variables)
        g_opt.apply_gradients(zip(g_grads, generator.trainable_variables))

        d_grads = d_tape.gradient(d_loss, discriminator.trainable_variables)
        d_opt.apply_gradients(zip(d_grads, discriminator.trainable_variables))

        g_total += tf.reduce_mean(g_loss)
        d_total += tf.reduce_mean(d_loss)
        n_batches += 1.0

    # Guard against an empty dataset (returns 0 instead of NaN).
    n_batches = tf.maximum(n_batches, 1.0)
    return g_total / n_batches, d_total / n_batches


def train_step(gan, data, opt, batch_size, epoch, r1_every=8, r1_weight=5.0):
    """Train the GAN for one epoch and return the epoch's mean losses.

    ``epoch`` only decides whether the R1 penalty is applied: on every
    ``r1_every``-th epoch, starting with epoch 0. It is reduced to a Python
    bool before entering the compiled ``_train_epoch``. tf.function retraces
    for every new Python argument value, so passing the raw epoch number
    would rebuild the whole graph each epoch. With a bool there are at most
    two traces for a fixed set of the other arguments.

    Args:
        gan: ``[generator, discriminator]``.
        data: dataset of real image batches.
        opt: ``[generator_optimizer, discriminator_optimizer]``.
        batch_size: must match the size of every batch in ``data``.
        epoch: current epoch index (int).
        r1_every: apply R1 when ``epoch % r1_every == 0``; must be non-zero.
        r1_weight: multiplier of the R1 penalty.

    Returns:
        ``(g_loss, d_loss)`` scalar tensors averaged over the epoch's batches.
    """
    apply_r1 = bool(epoch % r1_every == 0)
    return _train_epoch(gan, data, opt, batch_size, apply_r1, r1_weight)


In [None]:
# Main training loop: log losses every 10 epochs; every 50 epochs save a
# 10x10 sample grid to fig_<epoch>.png and checkpoint the generator.
for i in range(10001):
    g_loss, d_loss = train_step(gan, cars, opt, batch_size, i)
    if not i % 10:
        print(f'{i} - G = {g_loss:.4f}; D = {d_loss:.4f}')

    if i % 50:
        continue

    n = 10
    img = gan[0](tf.random.normal((n**2, 128)), training=False)

    fig, ax = plt.subplots(n, n, figsize=(7, 7))
    ax = ax.ravel()
    for panel, sample in zip(ax, img):
        # tanh output [-1, 1] -> 8-bit pixel values.
        panel.matshow(np.uint8(sample * 127.5 + 127.5))
        panel.set_axis_off()
    plt.tight_layout(pad=0)
    plt.savefig(f'fig_{i}.png')
    plt.close()
    generator.save('style.keras')

In [None]:
# Display a 15x15 grid of samples from the trained generator.
n = 15
img = gan[0](tf.random.normal((n**2, 128)), training=False)
# tanh output [-1, 1] -> 8-bit pixel values, converted once for the whole batch.
pixels = np.uint8(img * 127.5 + 127.5)

fig, ax = plt.subplots(n, n, figsize=(6, 6))
ax = ax.ravel()
for i, panel in enumerate(ax):
    panel.matshow(pixels[i])
    panel.set_axis_off()
plt.tight_layout(pad=0)
plt.show()