In [7]:
import gradio as gr
import torch
import torchvision
from torchvision import models, transforms
from PIL import Image
import os
import warnings
from datetime import datetime
import architectures.zero_dce_model as zero
import architectures.risk_model as risk

# Configuration: use the GPU when CUDA is available, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Suppress every warning process-wide from this point on.
warnings.filterwarnings("ignore")

# Model loaders
def load_quality_model(num_classes=1, device="cuda",
                       weights_path=r'architectures/weights/quality_model.pth'):
    """Build the ResNet-18 image-quality classifier and load its trained weights.

    Args:
        num_classes: Output size of the replacement ``fc`` head.
        device: Device the model is moved to; also used as ``map_location``
            when reading the checkpoint.
        weights_path: Path to the saved ``state_dict``. Defaults to the
            checkpoint path this loader has always used.

    Returns:
        The ResNet-18 model with the checkpoint weights loaded, in eval mode.
    """
    # Architecture only; all weights come from the checkpoint below.
    model = models.resnet18(pretrained=False)
    # The head is wrapped in Sequential, so its parameters are named "fc.0.*".
    # load_state_dict is strict by default, so this exact structure has to
    # match the checkpoint. Do not "simplify" it to a bare Linear.
    model.fc = torch.nn.Sequential(torch.nn.Linear(model.fc.in_features, num_classes))
    model = model.to(device)
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.eval()  # inference behaviour (BatchNorm uses running statistics)
    return model

def load_zero_dce(device="cuda", weights_path='architectures/weights/zero_dce.pth'):
    """Build the Zero-DCE enhancement network and load its trained weights.

    Args:
        device: Device the model is moved to; also used as ``map_location``
            when reading the checkpoint.
        weights_path: Path to the saved ``state_dict``. Defaults to the
            checkpoint path this loader has always used.

    Returns:
        The ``zero.zero_dce`` network with weights loaded, in eval mode.
    """
    model = zero.zero_dce().to(device)
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.eval()
    return model

def load_risk_model(input_dim=64, device="cuda",
                    weights_path='architectures/weights/risk_detection_model.pth'):
    """Build the risk-detection model and load its trained weights.

    Args:
        input_dim: Forwarded to ``risk.RiskDetectionModel``. Presumably the
            input image side length, since it matches the 64x64 resize used for
            risk preprocessing. TODO confirm against the model definition.
        device: Device the model is moved to; also used as ``map_location``
            when reading the checkpoint.
        weights_path: Path to the saved ``state_dict``. Defaults to the
            checkpoint path this loader has always used.

    Returns:
        The ``risk.RiskDetectionModel`` with weights loaded, in eval mode.
    """
    model = risk.RiskDetectionModel(input_dim).to(device)
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.eval()
    return model

# Instantiate all three models once at import time, on the configured device
# (evaluated left to right: quality, enhancement, risk).
quality_model, DCE_net, risk_model = (
    load_quality_model(device=device),
    load_zero_dce(device=device),
    load_risk_model(device=device),
)

# Preprocessing pipelines, one per model.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Quality classifier: 224x224, converted to a tensor, then normalised with
# the ImageNet per-channel statistics.
quality_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)

# Zero-DCE enhancement: 224x224 tensor with no normalisation applied.
enhancement_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)

# Risk model: 64x64 tensor; the single-element mean/std (0.5, 0.5) is applied
# to every channel.
risk_transform = transforms.Compose(
    [
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)

# Prediction function
def predict_image(image):
    """Classify an image's quality, enhance it if needed, and classify its risk.

    Pipeline:
      1. The quality model scores ``image``. A sigmoid output above 0.5 means
         "bad_quality"; anything else means "good_quality".
      2. The original image is saved as a ``.jpg`` under
         ``./output/<quality_category>/``.
      3. A bad-quality image is enhanced with Zero-DCE (at 224x224). The result
         is saved under ``./output/good_quality/`` with the same file name, and
         the reported category becomes "enhancement".
      4. The risk model is run on the original (not the enhanced) image.

    Args:
        image: PIL image from the Gradio ``Image(type="pil")`` input, or
            ``None`` when nothing was provided.

    Returns:
        Tuple ``(quality_category, risk_class, image_path)``:

        * ``quality_category`` is "good_quality" or "enhancement".
        * ``risk_class`` is "risky" or "no risky".
        * ``image_path`` is the saved original (good quality) or the saved
          enhanced image.

        If ``image`` is ``None``, an error message is returned first and the
        other two entries are ``None``.
    """
    if image is None:
        return "Error: No se ha recibido ninguna imagen.", None, None

    # Force 3-channel RGB. quality_transform normalises with 3-element mean/std,
    # and JPEG cannot store alpha, so RGBA, palette or grayscale inputs would
    # otherwise raise. For RGB inputs this is just a copy.
    image = image.convert("RGB")

    # Predict quality: sigmoid of the model output, thresholded at 0.5.
    image_tensor = quality_transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        output = quality_model(image_tensor)
        output_s = torch.sigmoid(output)
        predicted_class = (output_s > 0.5).float().cpu().numpy()

    quality_category = "bad_quality" if predicted_class[-1][-1] == 1 else "good_quality"

    # Save the original image in its predicted-quality folder. The timestamp
    # includes microseconds because the interface runs with live=True.
    # Several predictions can then fall in the same second, and
    # second-resolution names would overwrite each other.
    class_folder = f'./output/{quality_category}'
    os.makedirs(class_folder, exist_ok=True)
    image_name = f"image_{datetime.now().strftime('%Y%m%d-%H%M%S-%f')}.jpg"
    image_path = os.path.join(class_folder, image_name)
    image.save(image_path)

    # Enhance bad-quality images and report the enhanced file instead.
    if quality_category == "bad_quality":
        image_to_enhancement = enhancement_transform(image).unsqueeze(0).to(device)
        with torch.no_grad():
            # Only the second of the three outputs (the enhanced image) is used.
            _, enhanced_image, _ = DCE_net(image_to_enhancement)

        enhanced_folder = './output/good_quality'
        os.makedirs(enhanced_folder, exist_ok=True)
        enhanced_image_path = os.path.join(enhanced_folder, image_name)
        torchvision.utils.save_image(enhanced_image, enhanced_image_path)
        image_path = enhanced_image_path
        quality_category = "enhancement"

    # NOTE(review): risk is deliberately predicted on the original image, not
    # on the enhanced one, which is only returned for display. Confirm which
    # input the risk model was trained on.
    risk_image = risk_transform(image).unsqueeze(0).to(device)

    # Predict risk. The raw output is thresholded at 0.5 with no sigmoid here,
    # so this presumably expects RiskDetectionModel to output probabilities.
    # TODO confirm.
    with torch.no_grad():
        risk_prediction = risk_model(risk_image)
        risk_flags = (risk_prediction > 0.5).float().cpu().numpy()
    risk_class = "risky" if risk_flags[-1][-1] == 1 else "no risky"

    return quality_category, risk_class, image_path

# Gradio interface: one PIL image in; quality label, risk label and the final
# image out. live=True re-runs predict_image whenever the input changes.
_input_component = gr.Image(type="pil")
_output_components = [
    gr.Text(label="Quality"),
    gr.Text(label="Risk Category"),
    gr.Image(label="Final Image"),
]
iface = gr.Interface(
    fn=predict_image,
    inputs=_input_component,
    outputs=_output_components,
    live=True,
)

# Start the web app.
iface.launch()


* Running on local URL:  http://127.0.0.1:7866

To create a public link, set `share=True` in `launch()`.


