In [None]:
import sys
from pathlib import Path

PROJECT_ROOT = Path.cwd().parent
sys.path.insert(0, str(PROJECT_ROOT))

import tensorflow as tf
import matplotlib.pyplot as plt
from architectures import SRGAN
from image_processing import downscale, load_image_paths, load_and_preprocess

# -----------------------------
# Configuration
# -----------------------------
UP_RATIO = 4                    # upscaling factor passed to SRGAN and to load_and_preprocess
BATCH_SIZE = 8
EPOCHS = 30                     # can be raised once a GPU is available
HR_SIZE = (256, 256)            # HR image size used for training
DATA_FOLDER = "DIV2K_train_HR"  # folder with the HR training images (presumably relative to the working directory)
SEED = 42                       # global seed for reproducible runs

# Seed Python's `random`, NumPy and TensorFlow in one call so that
# Restart & Run All reproduces the same shuffle order and initial weights
# (some GPU kernels may still be non-deterministic).
tf.keras.utils.set_random_seed(SEED)

In [None]:
# Report which GPUs TensorFlow can see; an empty list means no GPU is available.
visible_gpus = tf.config.list_physical_devices("GPU")
print("GPUs visibles:", visible_gpus)

In [None]:
# -----------------------------
# Input pipeline
# -----------------------------
def build_dataset(image_paths, hr_size, up_ratio, batch_size, training=True, seed=None):
    """Build a batched, prefetched ``tf.data`` pipeline of (LR, HR) pairs.

    Args:
        image_paths: Sequence of image file paths (must support ``len``).
        hr_size: HR size forwarded to ``load_and_preprocess``.
        up_ratio: Upscaling factor forwarded to ``load_and_preprocess``.
        batch_size: Number of pairs per batch. Incomplete final batches are
            dropped, so one pass yields ``len(image_paths) // batch_size``
            batches.
        training: If True, shuffle the file list, reshuffling on every new
            iteration (i.e. every epoch).
        seed: Optional shuffle seed for a reproducible image order. The
            default ``None`` keeps the previous behaviour.

    Returns:
        A ``tf.data.Dataset`` yielding batches of whatever
        ``load_and_preprocess`` returns for each path.

    Raises:
        ValueError: If ``image_paths`` is empty.
    """
    num_images = len(image_paths)
    if num_images == 0:
        # Fail fast with a clear message instead of a less obvious TF error
        # later in the pipeline (e.g. from shuffle with a zero-sized buffer).
        raise ValueError("image_paths is empty: no images to build a dataset from")

    ds = tf.data.Dataset.from_tensor_slices(image_paths)

    if training:
        # Shuffle the (cheap) path strings rather than decoded images; a buffer
        # spanning the whole list gives a full uniform shuffle.
        ds = ds.shuffle(buffer_size=num_images, seed=seed, reshuffle_each_iteration=True)

    ds = ds.map(
        lambda p: load_and_preprocess(p, hr_size, up_ratio),
        num_parallel_calls=tf.data.AUTOTUNE,
    )

    ds = ds.batch(batch_size, drop_remainder=True)
    return ds.prefetch(tf.data.AUTOTUNE)

# -----------------------------
# Visualize model predictions
# -----------------------------

def show_example(model, dataset, image_path=None, hr_size=(256, 256), up_ratio=4):
    """Plot one LR -> SR -> HR triplet side by side.

    Args:
        model: Called as ``model(lr_batch, training=False)``; element 0 of the
            output is shown as the SR prediction.
        dataset: Batched dataset of ``(lr, hr)`` pairs. Used only when
            ``image_path`` is None: the first pair of a freshly drawn batch is
            shown (with a shuffled dataset this varies between calls).
        image_path: Optional path to a specific image, loaded with
            ``load_and_preprocess`` instead of sampling ``dataset``.
        hr_size: HR size for ``load_and_preprocess`` (only used with
            ``image_path``).
        up_ratio: Upscaling factor for ``load_and_preprocess`` (only used with
            ``image_path``).

    Raises:
        ValueError: If ``image_path`` is None and ``dataset`` yields no batch.
    """
    if image_path is not None:
        lr, hr = load_and_preprocess(image_path, hr_size, up_ratio)
        # Add a batch dimension so the input becomes (1, h, w, 3).
        lr = tf.expand_dims(lr, axis=0)
    else:
        first_batch = next(iter(dataset.take(1)), None)
        if first_batch is None:
            # Without this guard, lr/hr would be unbound below (UnboundLocalError).
            raise ValueError("dataset is empty: cannot pick an example to show")
        lr_batch, hr_batch = first_batch
        lr = lr_batch[0:1]  # slice (not index) to keep the batch axis for the model
        hr = hr_batch[0]

    sr = model(lr, training=False)[0]

    panels = (
        ("Low Res (Input)", lr[0].numpy()),
        ("Super Res SRGAN (Prediction)", sr.numpy()),
        ("High Res (Target)", hr.numpy()),
    )

    _, axes = plt.subplots(1, len(panels), figsize=(12, 4))
    for ax, (title, image) in zip(axes, panels):
        ax.set_title(title)
        ax.imshow(image)
        ax.axis("off")

    plt.show()

In [None]:
# -----------------------------
# Build and compile the model
# -----------------------------
NUM_RES_BLOCKS = 8
NUM_FILTERS = 64
LAMBDA_ADV = 0.05       # presumably the adversarial-loss weight inside SRGAN
LEARNING_RATE = 1e-4
PRETRAINED_GENERATOR = "generator.keras"  # pre-trained generator checkpoint

# NOTE(review): num_blocks/filters presumably configure SRGAN's own generator,
# which is replaced below by the pre-trained one; keep them consistent with it.
model = SRGAN(up_ratio=UP_RATIO, num_blocks=NUM_RES_BLOCKS, filters=NUM_FILTERS, lambda_adv=LAMBDA_ADV)

# Swap in the pre-trained generator *before* compile(), so that anything
# SRGAN.compile sets up refers to the final generator (SRGAN internals are not
# visible from this notebook).
model.generator = tf.keras.models.load_model(PRETRAINED_GENERATOR, compile=False)

gen_optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE, beta_1=0.9)
disc_optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE, beta_1=0.9)
model.compile(gen_optimizer=gen_optimizer, disc_optimizer=disc_optimizer, jit_compile=False)

model.summary()

# -----------------------------
# Training
# -----------------------------
image_paths = load_image_paths(DATA_FOLDER)

# build_dataset drops incomplete batches, so fewer images than BATCH_SIZE
# would mean zero steps per epoch, i.e. nothing to train on.
steps_per_epoch = len(image_paths) // BATCH_SIZE
if steps_per_epoch == 0:
    raise ValueError(
        f"Found {len(image_paths)} images in {DATA_FOLDER!r}; "
        f"need at least BATCH_SIZE={BATCH_SIZE} to train."
    )

train_ds = build_dataset(image_paths, HR_SIZE, UP_RATIO, BATCH_SIZE, training=True)

history = model.fit(train_ds, epochs=EPOCHS, steps_per_epoch=steps_per_epoch)

In [None]:
# Show a few predictions. NOTE: they are drawn from the training set, so they
# illustrate fit on seen data rather than generalization.
N_EXAMPLES = 5  # change to show as many examples as you like
for _ in range(N_EXAMPLES):
    show_example(model, train_ds, hr_size=HR_SIZE, up_ratio=UP_RATIO)