In [None]:
import os

import numpy as np
import tensorflow as tf
import keras.api as keras
from keras.api.layers import *
import matplotlib.pyplot as plt

In [None]:
# Load CIFAR-10 and keep only the training images whose label is 8
# (labels come back as an (N, 1) array, hence the [:, 0]).
# NOTE(review): in the CIFAR-10 class list label 8 is "ship", while the save cell
# names the files "*_cars" -- confirm which class is intended.
(x_train,y_train),(_,_) = keras.datasets.cifar10.load_data()

x_train = x_train[y_train[:,0] == 8]
# Scale uint8 pixels from [0, 255] to [-1, 1], the range of the generator's tanh output.
x_train = (tf.cast(x_train,tf.float32)-127.5) / 127.5

# Random index of one image to preview below.
n = np.random.randint(0,x_train.shape[0])

batch_size = 128

# Shuffle over the whole subset; drop_remainder=True keeps every batch at exactly batch_size.
data = tf.data.Dataset.from_tensor_slices(x_train).shuffle(x_train.shape[0]).batch(batch_size,drop_remainder=True)

# Per-epoch loss history filled by the training loop.
complete_hist = {
    'loss_dis': [],
    'loss_gen': [],
}

# imshow clips float images to [0, 1]; map the [-1, 1] data back before displaying it.
plt.imshow(x_train[n] * 0.5 + 0.5)
plt.axis(False)
plt.show()

In [None]:
class miniBatch(keras.layers.Layer):
    """Minibatch discrimination (Salimans et al., 2016, "Improved Techniques for Training GANs").

    Each sample's feature vector is projected by the learned matrix ``T`` into
    ``num_kernels`` rows of length ``kernel_dim``. For every row the layer computes
    ``exp(-mean |M_i - M_j|)`` against every sample ``j`` in the same batch (including
    ``i`` itself), sums over ``j``, and appends the resulting ``num_kernels`` values to
    the input features. The mean is used instead of the paper's sum of absolute
    differences.

    Because the appended features are sums over the batch, the output for a sample
    depends on the batch size and on the other samples in the batch.

    Expects 2-D input ``(batch, features)``; output is ``(batch, features + num_kernels)``.

    Args:
        num_kernels: number of similarity features appended per sample.
        kernel_dim: length of each projected row compared with the L1-style distance.
        batch_size: unused by the computation (the batch comes from the input); kept so
            existing calls such as ``miniBatch(100, 5, 128)`` keep working.
        **kwargs: forwarded to ``keras.layers.Layer`` (e.g. ``name``).
    """
    def __init__(self,num_kernels,kernel_dim,batch_size=None,**kwargs):
        super().__init__(**kwargs)
        self.num_kernels = num_kernels
        self.kernel_dim = kernel_dim
        # Stored only for backward compatibility; call() never reads it.
        self.batch_size = batch_size

    def build(self, input_shape):
        # T maps each feature vector to num_kernels * kernel_dim values.
        self.T = self.add_weight(
            shape=(input_shape[-1],self.num_kernels*self.kernel_dim),
            initializer='random_normal',
            trainable=True,
        )

    def call(self, x):
        M = tf.matmul(x,self.T)                                    # (B, K*D)
        M = tf.reshape(M,(-1,self.num_kernels,self.kernel_dim))    # (B, K, D)
        # Broadcast over all sample pairs: (1, B, K, D) - (B, 1, K, D) -> (B, B, K, D).
        # Memory therefore grows quadratically with the batch size.
        diff = tf.abs(tf.expand_dims(M,0) - tf.expand_dims(M,1))
        exp_diff = tf.exp(-tf.reduce_mean(diff,-1))                # (B, B, K)
        miniBatch_features = tf.reduce_sum(exp_diff,1)             # (B, K)
        return tf.concat([x,miniBatch_features],-1)

    def compute_output_shape(self, input_shape):
        # num_kernels similarity features are appended to the feature axis.
        return (input_shape[0], input_shape[1] + self.num_kernels)

class PixelShuffle(keras.layers.Layer):
    """Sub-pixel upsampling (depth-to-space) for NHWC tensors.

    Turns ``(B, H, W, C)`` into ``(B, H*r, W*r, C // r**2)`` with ``r = upscale_factor``.
    Channel group ``(i*r + j)`` of input pixel ``(h, w)`` lands at output pixel
    ``(h*r + i, w*r + j)``, the same layout as ``tf.nn.depth_to_space``.
    ``C`` must be divisible by ``r**2`` for the reshape to succeed.

    Args:
        upscale_factor: integer spatial upscaling factor ``r``.
        **kwargs: forwarded to ``keras.layers.Layer`` (e.g. ``name``).
    """
    def __init__(self, upscale_factor, **kwargs):
        super().__init__(**kwargs)
        self.upscale_factor = upscale_factor

    def call(self, inputs):
        r = self.upscale_factor
        input_shape = tf.shape(inputs)
        batch_size = input_shape[0]
        h = input_shape[1]
        w = input_shape[2]
        c = input_shape[3]
        out_c = c // (r ** 2)
        # Split the channel axis into sub-pixel offsets (i, j) and the output channels.
        x = tf.reshape(inputs, (batch_size, h, w, r, r, out_c))
        # Interleave the offsets with the spatial axes -> (B, h, i, w, j, out_c), so the
        # reshape below yields row h*r + i and column w*r + j. Keeping (w, j, i) adjacent
        # instead would scramble pixels along each row.
        x = tf.transpose(x, [0, 1, 3, 2, 4, 5])
        x = tf.reshape(x, (batch_size, h * r, w * r, out_c))
        return x

    def compute_output_shape(self, input_shape):
        r = self.upscale_factor
        h, w, c = input_shape[1], input_shape[2], input_shape[3]
        # Unknown (None) dimensions stay unknown instead of raising a TypeError.
        out_h = None if h is None else h * r
        out_w = None if w is None else w * r
        out_c = None if c is None else c // (r ** 2)
        return (input_shape[0], out_h, out_w, out_c)
    

class SelfAttention(Layer):
    """Self-attention over the spatial positions of an NHWC feature map, with a residual add.

    Queries and keys are 1x1-conv projections to ``filters // 8`` channels; values are a
    1x1-conv projection to ``filters`` channels. Every position attends to every other
    position (softmax over the H*W keys), and the attended values are added to the input.

    Args:
        filters: number of value channels. Must equal the input channel count, because the
            attended values are reshaped to ``(B, H, W, C_in)`` and added to the input.
        **kwargs: forwarded to ``keras.layers.Layer`` (e.g. ``name``).

    NOTE(review): SAGAN-style blocks usually scale the attention branch by a learnable
    gamma initialised to 0; this layer adds it unscaled.
    """
    def __init__(self, filters, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.query_conv = Conv2D(filters // 8, kernel_size=1)
        self.key_conv = Conv2D(filters // 8, kernel_size=1)
        self.value_conv = Conv2D(filters, kernel_size=1)
        self.softmax = Softmax(axis=-1)

    def call(self, x):
        shape = tf.shape(x)
        batch, height, width, channels = shape[0], shape[1], shape[2], shape[3]
        positions = height * width

        Q = tf.reshape(self.query_conv(x), (batch, positions, -1))  # [B, HW, C/8]
        K = tf.reshape(self.key_conv(x), (batch, positions, -1))    # [B, HW, C/8]
        V = tf.reshape(self.value_conv(x), (batch, positions, -1))  # [B, HW, C]

        # Q @ K^T. K has to be transposed rather than reshaped to [B, C/8, HW]: the conv
        # output is position-major, so a plain reshape would mix channels and positions.
        attention_map = self.softmax(tf.matmul(Q, K, transpose_b=True))  # [B, HW, HW]

        attention_output = tf.matmul(attention_map, V)  # [B, HW, C]
        attention_output = tf.reshape(attention_output, (batch, height, width, channels))

        return attention_output + x

    def compute_output_shape(self, input_shape):
        # The residual output has exactly the input's shape.
        return input_shape



    
def residual_Gblock(x):
    """Generator residual block: conv3x3 -> BatchNorm -> LeakyReLU(0.2) -> conv3x3, plus identity skip.

    Both convolutions use stride 1, 'same' padding and the input's channel count, so the
    result has the same shape as ``x``. No activation is applied after the addition.
    """
    channels = x.shape[-1]
    h = Conv2D(channels, 3, 1, 'same')(x)
    h = BatchNormalization()(h)
    h = LeakyReLU(.2)(h)
    h = Conv2D(channels, 3, 1, 'same')(h)
    return Add()([h, x])

def residual_Dblock(x):
    """Discriminator residual block: conv3x3 -> LeakyReLU(0.2) -> conv3x3, identity skip, LeakyReLU(0.2).

    Both convolutions use stride 1, 'same' padding and the input's channel count, so the
    result has the same shape as ``x``.
    """
    channels = x.shape[-1]
    h = LeakyReLU(.2)(Conv2D(channels, 3, 1, 'same')(x))
    h = Conv2D(channels, 3, 1, 'same')(h)
    merged = Add()([h, x])
    return LeakyReLU(.2)(merged)


def create_discriminator():
    """Build the discriminator: a 32x32x3 image -> a probability from one sigmoid unit.

    Four stride-2 3x3 convolutions with 512, 256, 128 and 64 filters, each followed by
    LeakyReLU(0.2). Then Flatten, minibatch discrimination (100 kernels of dimension 5),
    Dropout(0.3) and Dense(1) with a sigmoid.
    """
    inputs = Input((32,32,3))
    features = inputs
    for filters in (512, 256, 128, 64):
        features = Conv2D(filters, 3, 2, 'same')(features)
        features = LeakyReLU(.2)(features)

    features = Flatten()(features)
    features = miniBatch(100, 5, 128)(features)
    features = Dropout(.3)(features)
    logit = Dense(1)(features)
    probability = Activation('sigmoid')(logit)

    return keras.Model(inputs, probability, name='Discriminator')

def create_generator():
    """Build the generator: 128-d noise -> a 32x32x3 image in [-1, 1] (tanh output).

    Architecture:
    - a Dense projection to an 8x8x2048 map (LeakyReLU + BatchNorm), then self-attention at 8x8;
    - two stages of UpSampling2D -> Conv2DTranspose(256) -> LeakyReLU -> BatchNorm, reaching 32x32;
    - a final 3-channel Conv2DTranspose with tanh.
    """
    noise = Input((128,))

    h = Dense(8*8*2048)(noise)
    h = LeakyReLU(.2)(h)
    h = BatchNormalization()(h)
    h = Reshape((8,8,2048))(h)
    h = SelfAttention(h.shape[-1])(h)

    for _stage in range(2):
        h = UpSampling2D()(h)
        h = Conv2DTranspose(256, 3, 1, 'same')(h)
        h = LeakyReLU(.2)(h)
        h = BatchNormalization()(h)

    h = Conv2DTranspose(3, 3, 1, 'same')(h)
    image = Activation('tanh')(h)

    return keras.Model(noise, image)

# Build both networks (generator first), then print each architecture, discriminator first.
gen = create_generator()
dis = create_discriminator()
for _net in (dis, gen):
    _net.summary()

In [None]:
# The training step uses BCE. MSE is referenced only by the commented-out feature-matching variant.
bce = keras.losses.BinaryCrossentropy()
MSE = keras.losses.MeanSquaredError()

LR = 1e-4

# Adam with beta_1 = 0.5; the discriminator learns at half the generator's rate.
gen_opt = keras.optimizers.Adam(learning_rate=LR, beta_1=.5)
dis_opt = keras.optimizers.Adam(learning_rate=LR / 2, beta_1=.5)

# Fixed noise for an n x n grid of samples, so periodic snapshots are comparable.
n = 5
noise_out = tf.random.normal((n**2,128))

In [None]:
# Print the discriminator's layers with their indices, to pick a feature-map tap point.
for idx, lyr in enumerate(dis.layers):
    print(f'L = {idx}, layer = {lyr}')

# Sub-model exposing dis.layers[7].output. Presumably this is the fourth Conv2D, with index 0
# being the InputLayer -- confirm with the printout above. It is used only by the
# commented-out feature-matching loss, e.g. maps(np.random.uniform(0,1,(1,32,32,3))).
maps = keras.Model(dis.input, dis.layers[7].output)
# map = maps(n)

In [None]:
def clip_weights(model,clip_value=0.1):
    """Clamp every trainable variable of ``model`` in place to ``[-clip_value, clip_value]``.

    This is WGAN-style weight clipping. It mutates the variables via ``assign`` and returns None.
    """
    low, high = -clip_value, clip_value
    for var in model.trainable_variables:
        clipped = tf.clip_by_value(var, low, high)
        var.assign(clipped)

# def gradient_penalty(discriminator, real_samples, fake_samples):
#     alpha = tf.random.uniform((real_samples.shape[0],), 0, 1)
#     interpolated_samples = alpha[0] * real_samples + (1 - alpha[0]) * fake_samples
#     with tf.GradientTape() as tape:
#         tape.watch(interpolated_samples)
#         predictions = discriminator(interpolated_samples)
#     gradients = tape.gradient(predictions, interpolated_samples)
#     gradients_norm = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=1))
#     penalty = tf.reduce_mean((gradients_norm - 1) ** 2)
#     return penalty

# def wasser_dist(y_true,y_pred):
#     return tf.reduce_mean(y_true * y_pred)




@tf.function
def train_step():
    """Run one epoch of simultaneous generator/discriminator updates over ``data``.

    For each real batch, ``gen`` maps a batch of 128-d standard-normal noise of the same
    size to fake images, and ``dis`` scores both the real and the fake images. Both
    networks then take one optimizer step on their binary cross-entropy losses:
    - the generator uses the non-saturating loss (fake labelled 1);
    - the discriminator labels real as 1 and fake as 0.

    Returns:
        (gen_loss, dis_loss): scalar tensors holding the mean per-batch loss of the
        generator and of the discriminator over the epoch.
    """
    gen_loss_sum = tf.constant(0., tf.float32)
    dis_loss_sum = tf.constant(0., tf.float32)
    steps = tf.constant(0., tf.float32)
    for batch in data:
        # Match the noise batch to the real batch (batch_size when the dataset drops remainders).
        noise = tf.random.normal((tf.shape(batch)[0], 128))

        with tf.GradientTape() as gen_tape, tf.GradientTape() as dis_tape:
            fake_imgs = gen(noise,training=True)
            true_labels = dis(batch,training=True)
            fake_labels = dis(fake_imgs,training=True)

            gen_loss_iter = bce(tf.ones_like(fake_labels),fake_labels)
            dis_loss_iter = (bce(tf.ones_like(true_labels),true_labels)
                             + bce(tf.zeros_like(fake_labels),fake_labels))

        gen_grads = gen_tape.gradient(gen_loss_iter,gen.trainable_variables)
        gen_opt.apply_gradients(zip(gen_grads,gen.trainable_variables))

        dis_grads = dis_tape.gradient(dis_loss_iter,dis.trainable_variables)
        dis_opt.apply_gradients(zip(dis_grads,dis.trainable_variables))

        # BinaryCrossentropy already averages over the batch, so accumulate the values as-is.
        # Dividing by batch_size again would under-report the losses by a factor of batch_size.
        gen_loss_sum += gen_loss_iter
        dis_loss_sum += dis_loss_iter
        steps += 1.

    # Avoid 0/0 if the dataset yielded no batches (fewer samples than one full batch).
    steps = tf.maximum(steps, 1.)
    return gen_loss_sum/steps, dis_loss_sum/steps

In [None]:
EPOCHS = 5000
EPOCH_SAMPLE = 10

# Sample grids are written to imgs/; create it so savefig does not fail on a fresh run.
os.makedirs('imgs', exist_ok=True)

for i in range(EPOCHS):

    # One epoch of training, then record the loss history.
    loss_gen, loss_dis = train_step()
    # NOTE(review): this is WGAN-style weight clipping applied to a sigmoid/BCE discriminator,
    # and it runs only once per epoch rather than after every discriminator step. Confirm intended.
    clip_weights(dis,0.1)
    # Store plain Python floats. The history is re-plotted every epoch, and re-converting a
    # growing list of EagerTensors each time is needlessly slow.
    loss_gen, loss_dis = float(loss_gen), float(loss_dis)
    complete_hist['loss_gen'].append(loss_gen)
    complete_hist['loss_dis'].append(loss_dis)

    # Every EPOCH_SAMPLE epochs: report the losses and save an n x n grid generated from noise_out.
    if i % EPOCH_SAMPLE == 0:
        print(f'Ep = {i} | Loss_gen = {loss_gen:.4f}; Loss_dis = {loss_dis:.4f}')
        img_fake = gen(noise_out)
        fig, ax = plt.subplots(n,n,figsize=(1,1))
        ax = ax.ravel()
        for ii in range(n**2):
            # The generator's tanh output is in [-1, 1]; map it back to uint8 pixels.
            ax[ii].imshow(np.uint8(img_fake[ii]*127.5+127.5))
            ax[ii].set_axis_off()
        fig.savefig(f'imgs/fig{i}.png',dpi=1000)
        plt.close(fig)

    # Refresh the loss curve on disk.
    fig, ax = plt.subplots()
    ax.semilogy(np.array(complete_hist['loss_gen']),label=f'GEN = {loss_gen:.4f}')
    ax.semilogy(np.array(complete_hist['loss_dis']),label=f'DIS = {loss_dis:.4f}')
    ax.legend()
    ax.grid(True,'minor')
    fig.savefig('loss.png')
    plt.close(fig)


print('==================== COMPLETE ====================')

# plt.Figure()
# plt.plot(loss_list[:,0],label='Discriminator')
# plt.plot(loss_list[:,1],label='GAN')
# plt.legend()
# plt.show()

In [None]:
# Save both trained networks in the native Keras format.
# NOTE(review): the data cell keeps CIFAR-10 label 8 ("ship"), but the files are named "*_cars".
# Confirm the intended class and filenames.
for _net, _path in ((gen, 'generator_cars.keras'), (dis, 'discriminator_cars.keras')):
    _net.save(_path)

In [None]:
n = 5
N_EVAL = 1000

# The generator's input is Input((128,)) and training feeds it tf.random.normal noise, so
# evaluation noise must be 128-d normal too (1024-d noise fails the Dense input check).
# A separate name keeps the fixed training-snapshot noise `noise_out` intact.
eval_noise = tf.random.normal((N_EVAL,128))

# Run in batches of the training batch size. The discriminator's miniBatch features are sums
# over the batch, so scoring all samples in one pass would shift them; batching also bounds
# the generator's activation memory.
img_fake = gen.predict(eval_noise, batch_size=batch_size, verbose=0)
out_true = dis.predict(x_train[0:N_EVAL], batch_size=batch_size, verbose=0)
# Score the same fakes that are displayed below.
out_fake = dis.predict(img_fake, batch_size=batch_size, verbose=0)

print(f'Média True = {np.mean(out_true):.4f} \nMédia False = {np.mean(out_fake):.4f}')

# Full training loss history. plt.figure() creates and activates a new figure;
# plt.Figure() only builds a detached object.
plt.figure()
plt.semilogy(np.array(complete_hist['loss_gen']),label='GEN')
plt.semilogy(np.array(complete_hist['loss_dis']),label='DIS')
plt.legend()
plt.grid(True,'minor')
plt.show()


fig, ax = plt.subplots(n,n,figsize=(10,10))
ax = ax.ravel()

for i in range(n**2):
    # The generator's tanh output is in [-1, 1]; imshow expects floats in [0, 1].
    ax[i].imshow(img_fake[i] * 0.5 + 0.5)
    ax[i].set_axis_off()
plt.show()