In [None]:
import os
# Setting CUDA_VISIBLE_DEVICES to "-1" hides all CUDA devices from this
# process, so the notebook runs on the CPU. It is set here, ahead of the
# TensorFlow import in a later cell, so it is in place when TensorFlow starts.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

In [None]:
import pickle
import time
import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import PIL
from IPython import display
from sklearn.preprocessing import LabelEncoder

import utils_img_rec as ut
import discriminator as dis
import generator as gen


In [None]:
# Pickled training data that is loaded with pickle.load further down.
data_path = "data/mnist_train_1000-pickle.pickle"
# Folder prefix used for the per-epoch sample images (note the trailing slash:
# file names are appended to it by plain string concatenation).
out_path = "data/output/mnist/"

# Side length (pixels) and channel count the images are reshaped to.
IMG_SIZE = 28
channels = 1

# Shuffle buffer size and mini-batch size of the tf.data pipeline.
BUFFER_SIZE = 10000
BATCH_SIZE = 200

# Number of passes over the dataset, length of the random noise vector fed to
# the generator, and number of seed samples created via ut.create_seed.
EPOCHS = 100
noise_dim = 100
num_examples_to_generate = 16

In [None]:
import tensorflow as tf

# Ask TensorFlow once for the GPU device name; an empty string means no GPU.
gpu_name = tf.test.gpu_device_name()
if gpu_name:
    print('GPU encontrada:', gpu_name)
else:
    print('Nenhuma GPU encontrada')

# Last expression of the cell: shows the installed TensorFlow version.
tf.__version__

In [None]:
# Load the pickled training data. The context manager closes the file handle
# as soon as the data has been read.
# NOTE: pickle.load can execute arbitrary code while unpickling -- only load
# files that come from a trusted source.
with open(data_path, "rb") as pickle_in:
    data_train = pickle.load(pickle_in)

In [None]:
# Class names in sorted order. The Jupyter checkpoint folder name is filtered
# out explicitly (it is not a class) instead of relying on a bare `except`
# around list.remove(), which would also hide unrelated errors.
CATEGORIES = sorted(
    name for name in ut.get_classes(data_train) if name != '.ipynb_checkpoints'
)

print(CATEGORIES)
num_of_classes = len(CATEGORIES)
print(num_of_classes)

In [None]:
# Turn the raw samples into image and label arrays.
train_X, train_y = ut.prep_data(
    data=data_train,
    CATEGORIES=CATEGORIES,
    IMG_SIZE=IMG_SIZE,
    num_of_channels=channels,
)

# Report the sample count and the first two image dimensions, then the full shape.
print('Entradas de treino - {} - ({}x{})'.format(*train_X.shape[:3]))
print(train_X.shape)

In [None]:
# Visual sanity check of the loaded data (presumably plots a few samples --
# see ut.plot_images for what exactly is drawn).
ut.plot_images(data_train)

In [None]:
# Bring the samples into (N, IMG_SIZE, IMG_SIZE, channels) float32 layout.
# The channel count comes from the `channels` setting instead of a literal 1,
# so it stays in sync with the value passed to ut.prep_data.
train_images = np.array(train_X)
train_images = train_images.reshape(-1, IMG_SIZE, IMG_SIZE, channels).astype('float32')
# Normalize the images to [-1, 1] (assumes pixel values in [0, 255] -- TODO
# confirm that ut.prep_data does not rescale them already).
train_images = (train_images - 127.5) / 127.5

# Pair every image with its label, shuffle, and split into mini-batches.
# The final batch of an epoch may hold fewer than BATCH_SIZE samples.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((train_images, train_y))
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
)

In [None]:
# The generator input is the noise vector with the class label appended,
# hence an input width of noise_dim + num_of_classes.
generator = gen.make_generator_model(noise_dim + num_of_classes)

# One-hot label selecting the last class, shape (1, num_of_classes). Derived
# from num_of_classes so it always matches the generator's input width.
extra_values = tf.one_hot([num_of_classes - 1], depth=num_of_classes)
# Conditioned generator input, shape (1, noise_dim + num_of_classes).
noise = tf.concat([tf.random.normal([1, noise_dim]), extra_values], axis=1)

# Smoke test: run the still-untrained generator once and show its output.
generated_image = generator(noise, training=False)

plt.imshow(generated_image[0, :, :, 0], cmap='gray')

In [None]:
# Build the label-conditioned discriminator.
discriminator = dis.make_discriminator_model(num_of_labels=num_of_classes)

# Smoke test: score the image generated above together with its label.
discriminator_inputs = [generated_image, extra_values]
decision = discriminator(discriminator_inputs)
print(decision)

In [None]:
# Binary cross-entropy loss object shared by the generator and discriminator
# loss functions. With from_logits=True the loss applies the sigmoid itself,
# so the discriminator is expected to output raw scores -- TODO confirm
# against dis.make_discriminator_model.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def discriminator_loss(real_output, fake_output):
    """Return the discriminator loss for one batch.

    Args:
        real_output: Discriminator output for real images (target: all ones).
        fake_output: Discriminator output for generated images (target: all
            zeros).

    Returns:
        Sum of the cross-entropy on the real and on the generated samples.
    """
    loss_on_real = cross_entropy(tf.ones_like(real_output), real_output)
    loss_on_fake = cross_entropy(tf.zeros_like(fake_output), fake_output)
    return loss_on_real + loss_on_fake

In [None]:
def generator_loss(fake_output):
    """Return the generator loss for one batch.

    The generator is scored against an all-ones target, i.e. it is rewarded
    when the discriminator classifies its images as real.

    Args:
        fake_output: Discriminator output for the generated images.

    Returns:
        Cross-entropy between an all-ones target and `fake_output`.
    """
    wanted_decision = tf.ones_like(fake_output)
    return cross_entropy(wanted_decision, fake_output)

In [None]:
# The two networks are updated separately, so each one gets its own Adam
# instance (both with a learning rate of 1e-4).
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)

In [None]:
# Checkpoints are written to ./training_checkpoints with the file prefix "ckpt".
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
# Track both models and both optimizers so a checkpoint captures everything
# passed in here; the keyword names are the names the objects are saved under.
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

In [None]:
# Generator input created once and reused for every per-epoch preview, so the
# previews can be compared across epochs. The second return value of
# ut.create_seed is not needed here (presumably the labels -- TODO confirm).
seed, _ = ut.create_seed(num_examples_to_generate=num_examples_to_generate, noise_dim=noise_dim, num_classes=num_of_classes)
seed.shape

In [None]:
@tf.function
def train_step(images, labels):
    """Run one adversarial update of the generator and the discriminator.

    Args:
        images: Batch of real training images shown to the discriminator.
        labels: Labels belonging to `images`. They are cast to float32 and
            concatenated to the noise along axis 1, so they must be 2-D with
            one row per image (presumably one-hot rows -- confirm against
            ut.prep_data).
    """
    # Size the noise from the incoming batch instead of the global BATCH_SIZE:
    # the dataset is batched without drop_remainder, so the last batch of an
    # epoch can be smaller, and a fixed-size noise tensor could then not be
    # concatenated with the labels.
    batch_size = tf.shape(images)[0]
    noise = tf.random.normal([batch_size, noise_dim])
    labels = tf.cast(labels, tf.float32)  # labels must be float32 to be concatenated with the noise
    noise = tf.concat([noise, labels], axis=1)

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)
        # The discriminator takes two inputs: the image and its label
        real_output = discriminator([images, labels], training=True)
        fake_output = discriminator([generated_images, labels], training=True)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    # Each network is updated only with the gradients of its own loss.
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

In [None]:
def train(dataset, epochs):
    """Train the GAN for `epochs` passes over `dataset`.

    After every epoch the notebook output is cleared and
    ut.generate_and_save_images is called with the fixed `seed`; every 15th
    epoch a checkpoint is saved.

    Args:
        dataset: Iterable yielding (image_batch, label_batch) pairs.
        epochs: Number of passes over `dataset`.
    """
    # Measured separately from the per-epoch timer so that the final message
    # reports the whole run and not just the last epoch.
    training_start = time.time()

    for epoch in range(epochs):
        epoch_start = time.time()

        for image_batch, label_batch in dataset:  # each element carries the images and their labels
            train_step(image_batch, label_batch)

        # Show the images generated so far
        display.clear_output(wait=True)
        ut.generate_and_save_images(generator, epoch + 1, seed, out_path)

        # Save the model every 15 epochs
        if (epoch + 1) % 15 == 0:
            checkpoint.save(file_prefix=checkpoint_prefix)

        print(f'Time for epoch {epoch + 1} is {time.time()-epoch_start:.2f} sec')

    # Generate images after the last epoch
    display.clear_output(wait=True)
    ut.generate_and_save_images(generator, epochs, seed, out_path)
    # `epochs` equals the old `epoch+1` after a full run and is also defined
    # when the loop body never ran (epochs == 0).
    print(f'Total training time {time.time()-training_start:.2f} sec after {epochs} epochs')

In [None]:
# Run the full training; this is the long-running cell of the notebook.
train(train_dataset, EPOCHS)

Restore the latest checkpoint.

In [None]:
# Load the most recent checkpoint found in checkpoint_dir into the tracked
# models and optimizers.
# NOTE(review): checkpoints are only saved every 15 epochs, so with EPOCHS = 100
# the newest one is from epoch 90 -- restoring it replaces the weights reached
# after the final epoch. Confirm this is intended.
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

In [None]:
def display_image(epoch_no):
  """Open the sample image stored under `out_path` for the given epoch.

  Args:
    epoch_no: Epoch number; it is zero-padded to four digits in the file name.

  Returns:
    The image object returned by PIL.Image.open.
  """
  file_name = 'image_at_epoch_{:04d}.png'.format(epoch_no)
  return PIL.Image.open(out_path + file_name)

In [None]:
# Show the sample image of the last training epoch.
display_image(EPOCHS)

In [None]:
anim_file = 'dcgan.gif'

# Collect the frames before the writer is opened, so that a missing output
# folder gives a clear error instead of a NameError and an empty GIF file.
filenames = sorted(glob.glob(out_path+'image*.png'))
if not filenames:
  raise FileNotFoundError('No images matching ' + out_path + 'image*.png were found')

with imageio.get_writer(anim_file, mode='I') as writer:
  for filename in filenames:
    image = imageio.imread(filename)
    writer.append_data(image)
  # Append the final frame a second time (as before), reusing the frame that
  # is already in memory instead of reading the file again.
  writer.append_data(image)

In [None]:
# Embed the generated GIF in the notebook output. tensorflow_docs is imported
# only here; it is presumably a separately installed package -- TODO confirm
# it is available in the target environment.
import tensorflow_docs.vis.embed as embed
embed.embed_file(anim_file)