In [None]:
import os
import pandas as pd
import torch
from PIL import Image
from diffusers import StableDiffusionXLImg2ImgPipeline
import gc
from tqdm import tqdm

class SDXLImageProcessor:
    """Refine a range of CSV-listed images with the SDXL img2img refiner, in batches.

    The CSV at ``csv_path`` must contain a ``file_name`` column (path relative to
    ``input_dir``) and a ``processed_prompt`` column. Every row in the positional
    range ``[start_index, end_index)`` (``end_index=-1`` means "through the last
    row") is refined with its prompt and saved as
    ``<output_dir>/<stem>_generated.png``.

    Args:
        gpu_id: CUDA device index used for inference.
        csv_path: CSV file listing images and prompts.
        input_dir: Directory the ``file_name`` entries are relative to.
        output_dir: Directory for generated PNGs (created if missing).
        num_inference_steps, strength, guidance_scale: Forwarded unchanged to
            the pipeline call.
        start_index, end_index: Positional row range to process.
        batch_size: Number of images sent to the pipeline per call (must be >= 1).
        cpu_offload: If True (default, the previous behavior), enable diffusers
            model CPU offloading on ``gpu_id`` to save VRAM; if False, move the
            whole pipeline onto ``cuda:<gpu_id>`` instead.
    """

    # Columns every CSV row must provide.
    REQUIRED_COLUMNS = ("file_name", "processed_prompt")
    # Run gc + empty the CUDA cache after roughly this many images.
    CLEANUP_EVERY_IMAGES = 20

    def __init__(self,
                 gpu_id=0,
                 csv_path="datos.csv",
                 input_dir="imagenes/resize768",
                 output_dir="imagenes/output",
                 num_inference_steps=25,
                 strength=0.4,
                 guidance_scale=7.5,
                 start_index=0,
                 end_index=-1,
                 batch_size=3,
                 cpu_offload=True):
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        self.gpu_id = gpu_id
        self.csv_path = csv_path
        self.input_dir = input_dir
        self.output_dir = output_dir
        self.num_inference_steps = num_inference_steps
        self.strength = strength
        self.guidance_scale = guidance_scale
        self.start_index = start_index
        self.end_index = end_index
        # Previously hard-coded to 3, silently ignoring the argument.
        self.batch_size = batch_size
        self.cpu_offload = cpu_offload
        self.pipe = None

    def setup_pipeline(self):
        """Load the SDXL refiner in fp16 and place it on GPU ``gpu_id``."""
        torch.cuda.set_device(self.gpu_id)

        pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-refiner-1.0",
            torch_dtype=torch.float16,
            variant="fp16",
            use_safetensors=True,
        )

        if self.cpu_offload:
            # Offload hooks manage device placement themselves, so a prior .to("cuda")
            # is wasted work. Without an explicit gpu_id they target device 0.
            pipe.enable_model_cpu_offload(gpu_id=self.gpu_id)
        else:
            pipe = pipe.to(f"cuda:{self.gpu_id}")

        # Decode the VAE one image at a time to lower peak memory for batches.
        pipe.enable_vae_slicing()
        try:
            pipe.enable_xformers_memory_efficient_attention()
        except Exception as e:
            # xformers is an optional optimization; fall back to the default attention.
            print(f"Advertencia: xformers no disponible ({e}); se usa la atención por defecto")

        self.pipe = pipe
        print(f"Pipeline configurado en GPU {self.gpu_id}")

    def process_images_batch(self):
        """Process the configured row range in batches and save the results.

        Rows whose image is missing or unreadable are skipped with a message. A
        failing pipeline call skips only that batch. The pipeline is released
        when the run ends, including on errors and KeyboardInterrupt.

        Raises:
            KeyError: if the CSV lacks one of ``REQUIRED_COLUMNS``.
        """
        # Read the CSV before loading the (large) model so bad input fails fast.
        df = pd.read_csv(self.csv_path)
        missing = [col for col in self.REQUIRED_COLUMNS if col not in df.columns]
        if missing:
            raise KeyError(f"Columnas faltantes en {self.csv_path}: {missing}")

        end_idx = len(df) if self.end_index == -1 else self.end_index
        df_subset = df.iloc[self.start_index:end_idx]

        os.makedirs(self.output_dir, exist_ok=True)
        if self.pipe is None:
            self.setup_pipeline()

        batch_images, batch_prompts, batch_paths = [], [], []
        images_since_cleanup = 0
        try:
            rows = tqdm(df_subset.iterrows(), total=len(df_subset), desc=f"GPU {self.gpu_id}")
            for _, row in rows:
                file_path = row["file_name"]
                image = self._load_image(file_path)
                if image is None:
                    continue
                batch_images.append(image)
                batch_prompts.append(row["processed_prompt"])
                batch_paths.append(file_path)

                if len(batch_images) < self.batch_size:
                    continue
                self._run_batch(batch_images, batch_prompts, batch_paths)
                images_since_cleanup += len(batch_images)
                batch_images, batch_prompts, batch_paths = [], [], []

                if images_since_cleanup >= self.CLEANUP_EVERY_IMAGES:
                    self._free_memory()
                    images_since_cleanup = 0

            # Flush the trailing partial batch unconditionally. Keying this off the
            # last row meant it was dropped whenever that row was skipped.
            if batch_images:
                self._run_batch(batch_images, batch_prompts, batch_paths)
        finally:
            self._cleanup()

        print(f"Procesamiento completado en GPU {self.gpu_id}")

    def _load_image(self, file_path):
        """Return the RGB image for ``file_path`` under ``input_dir``, or None if unusable."""
        try:
            image_path = os.path.join(self.input_dir, file_path)
            if not os.path.exists(image_path):
                print(f"Advertencia: Imagen no encontrada - {image_path}")
                return None
            # Context manager closes the file handle; convert() returns a loaded copy.
            with Image.open(image_path) as img:
                return img.convert("RGB")
        except Exception as e:
            # Skip just this row (bad path value, corrupt file, ...) instead of the batch.
            print(f"Error procesando {file_path}: {str(e)}")
            return None

    def _run_batch(self, images, prompts, paths):
        """Run the pipeline on one batch and save each result under ``output_dir``.

        Errors are reported and the batch is skipped so the run can continue.
        """
        try:
            outputs = self.pipe(
                prompt=prompts,
                image=images,
                num_inference_steps=self.num_inference_steps,
                strength=self.strength,
                guidance_scale=self.guidance_scale,
            )
            for path, output_img in zip(paths, outputs.images):
                output_img.save(self._output_path(path))
        except Exception as e:
            print(f"Error procesando {', '.join(map(str, paths))}: {str(e)}")

    def _output_path(self, file_path):
        """Map an input file name to ``<output_dir>/<stem>_generated.png``."""
        stem = os.path.splitext(os.path.basename(file_path))[0]
        return os.path.join(self.output_dir, f"{stem}_generated.png")

    @staticmethod
    def _free_memory():
        """Collect Python garbage, then return cached CUDA blocks to the driver."""
        # gc first: tensors kept alive only by reference cycles must be freed
        # before empty_cache() can release their memory.
        gc.collect()
        torch.cuda.empty_cache()

    def _cleanup(self):
        """Drop the pipeline reference and release cached GPU memory."""
        self.pipe = None
        self._free_memory()

def run_processor(params):
    """Construct an ``SDXLImageProcessor`` from ``params`` and run it.

    ``params`` is unpacked as keyword arguments to the constructor, so its keys
    must be ``SDXLImageProcessor.__init__`` parameter names.
    """
    SDXLImageProcessor(**params).process_images_batch()

# Ejemplo de uso para GPU 0
if __name__ == "__main__":
    # Example run: rows [0, 100) of the CSV on GPU 0, written to a per-GPU folder.
    gpu0_settings = {
        "gpu_id": 0,
        "csv_path": "compressPromptLite_resize768.csv",
        "input_dir": "imagenes/resize768",
        "output_dir": "imagenes/output_gpu0",
        "start_index": 0,
        "end_index": 100,
    }
    processor_gpu0 = SDXLImageProcessor(**gpu0_settings)
    processor_gpu0.process_images_batch()


Loading pipeline components...:   0%|          | 0/5 [00:00<?, ?it/s]

Pipeline configurado en GPU 0


GPU 0:   0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

GPU 0:   4%|▍         | 4/100 [00:11<04:32,  2.84s/it]

  0%|          | 0/10 [00:00<?, ?it/s]

GPU 0:   8%|▊         | 8/100 [00:22<04:13,  2.75s/it]

  0%|          | 0/10 [00:00<?, ?it/s]

GPU 0:  12%|█▏        | 12/100 [00:33<04:00,  2.74s/it]

  0%|          | 0/10 [00:00<?, ?it/s]

GPU 0:  16%|█▌        | 16/100 [00:43<03:48,  2.72s/it]

  0%|          | 0/10 [00:00<?, ?it/s]

GPU 0:  20%|██        | 20/100 [00:54<03:37,  2.71s/it]

  0%|          | 0/10 [00:00<?, ?it/s]

GPU 0:  23%|██▎       | 23/100 [01:01<03:26,  2.68s/it]


KeyboardInterrupt: 

In [None]:
import torch
import gc
# Free GPU memory after an interrupted run. The interrupt skips the processor's
# own cleanup, so the pipeline is still referenced by processor_gpu0. That
# reference must be dropped, or gc.collect() cannot reclaim the model.
if getattr(globals().get("processor_gpu0"), "pipe", None) is not None:
    processor_gpu0.pipe = None
gc.collect()  # collect first so freed tensors can be returned by empty_cache()
torch.cuda.empty_cache()

In [None]:
# import pandas as pd
# import torch
# from diffusers import StableDiffusionXLImg2ImgPipeline
# from PIL import Image
# import matplotlib.pyplot as plt
# import os
# from IPython.display import display

# # Cargar el dataset
# df = pd.read_csv('too.csv')  

# example_row = df.iloc[2]
# file_path = "imagenes/resizeSD/" + example_row['file_name']
# prompt = example_row['generated_prompt']

# print(f"Procesando: {file_path}")
# print(f"Prompt: {prompt}")

# # Cargar el modelo
# pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
#     "stabilityai/stable-diffusion-xl-refiner-1.0",
#     torch_dtype=torch.float16,
#     variant="fp16",
#     use_safetensors=True
# )

# # Mover a GPU si está disponible
# device = "cuda" if torch.cuda.is_available() else "cpu"
# pipe = pipe.to(device)

# # Función para procesar una imagen
# def process_image(image_path, prompt, strength=0.75, guidance_scale=7.5, num_inference_steps=50):
#     """
#     Procesa una imagen con SDXL usando image-to-image.
    
#     Parámetros:
#     - image_path: Ruta a la imagen de entrada
#     - prompt: Texto que guía la generación
#     - strength: Qué tanto se modificará la imagen (0-1)
#     - guidance_scale: Cuánto seguir el prompt (valores típicos: 7-9)
#     - num_inference_steps: Número de pasos de inferencia
    
#     Devuelve la imagen generada
#     """
#     try:
#         # Cargar y preparar la imagen
#         init_image = Image.open(image_path).convert("RGB")
        
#         # Mostrar imagen original
#         plt.figure(figsize=(10, 10))
#         plt.imshow(init_image)
#         plt.title("Imagen Original")
#         plt.axis('off')
#         plt.show()
        
#         # Generar la nueva imagen
#         image = pipe(
#             prompt=prompt,
#             image=init_image,
#             strength=strength,
#             guidance_scale=guidance_scale,
#             num_inference_steps=num_inference_steps,
#         ).images[0]
        
#         # Mostrar la imagen generada
#         plt.figure(figsize=(10, 10))
#         plt.imshow(image)
#         plt.title(f"Imagen Generada\nStrength: {strength}, Guidance: {guidance_scale}, Steps: {num_inference_steps}")
#         plt.axis('off')
#         plt.show()
        
#         return image
    
#     except Exception as e:
#         print(f"Error al procesar la imagen: {e}")
#         return None

# # Probar con diferentes parámetros para ver el efecto
# strengths = [0.4, 0.5]  # Cuánto modificar la imagen original
# guidance_scales = [7.5, 9.0]  # Qué tanto seguir el prompt

# for strength in strengths:
#     for guidance in guidance_scales:
#         print(f"\nPrueba con strength={strength}, guidance_scale={guidance}")
#         result = process_image(
#             file_path, 
#             prompt, 
#             strength=strength,
#             guidance_scale=guidance,
#             num_inference_steps=50
#         )

# # Función para procesar varias imágenes del dataset
# def process_dataset_sample(df, num_samples=3, strength=0.75, guidance_scale=8.5, steps=50):
#     """
#     Procesa varias imágenes del dataset
#     """
#     # Tomar algunas muestras aleatorias
#     sample_df = df.sample(num_samples) if len(df) > num_samples else df
    
#     for idx, row in sample_df.iterrows():
#         print(f"\nProcesando imagen {idx+1}/{len(sample_df)}")
#         file_path = row['file_name']
#         prompt = row['generated_prompt']
        
#         print(f"Archivo: {file_path}")
#         print(f"Prompt: {prompt}")
        
#         result = process_image(
#             file_path,
#             prompt,
#             strength=strength,
#             guidance_scale=guidance_scale,
#             num_inference_steps=steps
#         )

# recommended_params = {
#     'strength': 0.85,         # Alto para permitir cambios significativos
#     'guidance_scale': 9.0,    # Fuerte adherencia al prompt
#     'num_steps': 60           # Más pasos para mejor calidad
# }

# print("\n--- Prueba con parámetros recomendados ---")
# print(f"Strength: {recommended_params['strength']}")
# print(f"Guidance Scale: {recommended_params['guidance_scale']}")
# print(f"Steps: {recommended_params['num_steps']}")

# result = process_image(
#     file_path,
#     prompt,
#     strength=recommended_params['strength'],
#     guidance_scale=recommended_params['guidance_scale'],
#     num_inference_steps=recommended_params['num_steps']
# )