In [80]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import os

In [81]:
# Path to the dataset root; flow_from_directory reads images from the
# sub-folders of this directory (one sub-folder per class).
data_dir = "./simpsons_dataset/"

# Every image is resized to this size on load so all batches share one shape.
img_height = 200
img_width = 200
batch_size = 64

# Seed the NumPy and TensorFlow global RNGs so that shuffling, weight
# initialisation and noise sampling are repeatable across Restart & Run All.
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [82]:
def scale_to_tanh_range(image):
    """Map pixel values from [0, 255] to [-1, 1].

    The generator ends in a ``tanh`` activation, so real images must live in
    the same [-1, 1] range. With a plain 1/255 rescale the real images sit
    in [0, 1], and the discriminator can tell real from fake just by looking
    for negative pixels.
    """
    return image / 127.5 - 1.0


# Data generator: scales pixels to [-1, 1] and holds out 20% of the images
# as a validation subset.
datagen = keras.preprocessing.image.ImageDataGenerator(
    preprocessing_function=scale_to_tanh_range,
    validation_split=0.2
)


In [83]:
# Training subset: stream resized images from data_dir. With
# class_mode='input' the images themselves are used as the target, so each
# batch is an (images, images) pair.
train_generator = datagen.flow_from_directory(
    directory=data_dir,
    subset='training',
    class_mode='input',
    target_size=(img_height, img_width),
    batch_size=batch_size,
)

Found 7902 images belonging to 1 classes.


In [84]:
# Validation subset (the share held out by datagen's validation_split),
# loaded with the same size, batch and target settings as training.
validation_generator = datagen.flow_from_directory(
    directory=data_dir,
    subset='validation',
    class_mode='input',
    target_size=(img_height, img_width),
    batch_size=batch_size,
)


Found 1975 images belonging to 1 classes.


## Construir el generador y el discriminador

In [85]:
# Discriminator for 200x200 RGB images.
def build_discriminator(input_shape=(200, 200, 3)):
    """Build the discriminator: an image in, a single real/fake logit out.

    Args:
        input_shape: (height, width, channels) of the images it scores.
            Defaults to the 200x200 RGB size used in this notebook; it must
            match the generator's output shape.

    Returns:
        An uncompiled ``keras.Sequential`` whose output is a raw logit (no
        activation), intended for a ``from_logits=True`` loss.
    """
    model = keras.Sequential()
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same', input_shape=input_shape))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    # Further convolutional blocks can be added here to strengthen the model.

    model.add(layers.Flatten())
    model.add(layers.Dense(1))  # raw logit; the sigmoid lives in the loss

    return model

# Generator for 200x200 RGB images.
def build_generator(latent_dim=100):
    """Build the generator: a latent noise vector in, a 200x200x3 image out.

    A 25x25x256 seed feature map is upsampled three times by stride-2
    transposed convolutions: 25 -> 50 -> 100 -> 200.

    Args:
        latent_dim: length of the input noise vector. Defaults to 100, the
            size the training steps sample with.

    Returns:
        An uncompiled ``keras.Sequential`` producing images in [-1, 1]
        (``tanh`` output).
    """
    model = keras.Sequential()
    model.add(layers.Dense(25 * 25 * 256, use_bias=False, input_shape=(latent_dim,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Reshape((25, 25, 256)))

    # 25x25 -> 50x50
    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    # 50x50 -> 100x100. Without this block the output is only 100x100, which
    # the 200x200 discriminator cannot accept.
    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    # 100x100 -> 200x200 with 3 channels; tanh keeps pixels in [-1, 1].
    model.add(layers.Conv2DTranspose(3, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))

    assert model.output_shape == (None, 200, 200, 3), model.output_shape
    return model

generator = build_generator()
discriminator = build_discriminator()

# Fail fast if the generated image size differs from what the discriminator
# accepts; otherwise the mismatch only surfaces as an opaque shape error
# inside the first training step.
assert generator.output_shape[1:] == discriminator.input_shape[1:], (
    f"generator produces {generator.output_shape[1:]}, "
    f"discriminator expects {discriminator.input_shape[1:]}"
)


## Funciones de pérdida

In [86]:
# The discriminator ends in a plain Dense(1) (raw logits), so the loss
# applies the sigmoid itself.
cross_entropy = keras.losses.BinaryCrossentropy(from_logits=True)
generator_optimizer = keras.optimizers.Adam(learning_rate=1e-4)
discriminator_optimizer = keras.optimizers.Adam(learning_rate=1e-4)


def discriminator_loss(real_output, fake_output):
    """Discriminator loss: push real logits towards 1 and fake logits towards 0.

    Args:
        real_output: discriminator logits for a batch of real images.
        fake_output: discriminator logits for a batch of generated images.

    Returns:
        The sum of the real-image and fake-image binary cross-entropy terms.
    """
    target_real = tf.ones_like(real_output)
    target_fake = tf.zeros_like(fake_output)
    return cross_entropy(target_real, real_output) + cross_entropy(target_fake, fake_output)


def generator_loss(fake_output):
    """Generator loss: reward fakes that the discriminator labels as real (1)."""
    wanted_labels = tf.ones_like(fake_output)
    return cross_entropy(wanted_labels, fake_output)



## Entrenamiento

In [113]:
@tf.function
def train_discriminator(images):
    noise = tf.random.normal([batch_size, 100])

    with tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)
        
        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)

        disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

@tf.function
def train_generator():
    noise = tf.random.normal([batch_size, 100])

    with tf.GradientTape() as gen_tape:
        generated_images = generator(noise, training=True)
        fake_output = discriminator(generated_images, training=True)

        gen_loss = generator_loss(fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))


In [114]:
steps_per_epoch = 100

num_epochs = 100
for epoch in range(num_epochs):
    for _ in range(steps_per_epoch):
        image_batch, _ = next(train_generator)  # Opcionalmente puedes desempaquetar las etiquetas si es necesario
        train_discriminator(image_batch)
        train_generator()


TypeError: 'Function' object is not an iterator