# In [None]:
import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
import requests
from datetime import datetime

from transformers import (
    SegformerImageProcessor,
    AutoModelForSemanticSegmentation,
    AutoImageProcessor,
    AutoModelForImageClassification,
    SiglipForImageClassification
)

# =================== MODELS =====================
# Three pretrained Hugging Face checkpoints, loaded once when this code runs:
#   * SegFormer-B2 clothes segmenter -> per-pixel class-id map of the photo
#   * wargon clothing classifier      -> garment type for each non-shoe crop
#   * SigLIP shoe classifier          -> shoe type for each shoe crop
_SEG_CHECKPOINT = "mattmdjaga/segformer_b2_clothes"
_CLOTHING_CHECKPOINT = "wargoninnovation/wargon-clothing-classifier"
_SHOES_CHECKPOINT = "prithivMLmods/shoe-type-detection"

seg_processor = SegformerImageProcessor.from_pretrained(_SEG_CHECKPOINT)
seg_model = AutoModelForSemanticSegmentation.from_pretrained(_SEG_CHECKPOINT)

cls_processor = AutoImageProcessor.from_pretrained(_CLOTHING_CHECKPOINT)
cls_model = AutoModelForImageClassification.from_pretrained(_CLOTHING_CHECKPOINT)

shoes_processor = AutoImageProcessor.from_pretrained(_SHOES_CHECKPOINT)
shoes_model = SiglipForImageClassification.from_pretrained(_SHOES_CHECKPOINT)

# =================== LABELS =====================
# Spanish display names for the segmentation class ids this app cares about.
SEG_LABELS = {
    0: "Fondo",
    1: "Sombrero",
    3: "Gafas de sol",
    4: "Parte superior de la ropa",
    5: "Falda",
    6: "Pantalones",
    7: "Vestido",
    8: "Cinturón",
    9: "Zapato izquierdo",
    10: "Zapato derecho",
    17: "Bufanda",
}

# Segmentation ids that are cropped and classified: upper clothes, skirt,
# pants, dress, left shoe, right shoe (see SEG_LABELS).
CLOTHING_SEG_IDS = [4, 5, 6, 7, 9, 10]

# Spanish names for the clothing classifier outputs, indexed by predicted id.
# NOTE(review): assumes this order matches cls_model.config.id2label — confirm.
CLOTHING_LABELS = [
    "Blazer",
    "Blusa",
    "Cárdigan",
    "Vestido",
    "Sudadera con capucha",
    "Chaqueta",
    "Vaqueros",
    "Camisón",
    "Ropa de abrigo",
    "Pijama",
    "Chaqueta de lluvia",
    "Pantalones de lluvia",
    "Bata",
    "Camisa",
    "Pantalones cortos",
    "Falda",
    "Suéter",
    "Camiseta",
    "Camiseta sin mangas",
    "Medias",
    "Top",
    "Top de entrenamiento",
    "Pantalones",
    "Túnica",
    "Chaleco",
    "Chaqueta de invierno",
    "Pantalones de invierno",
]

# Spanish display names for the shoe classifier outputs, indexed by class id.
# Source (English) labels: 0 Ballet Flat, 1 Boat, 2 Brogue, 3 Clog, 4 Sneaker.
# "Boat" is a boat shoe (náutico), not a boot, and a brogue is a perforated
# lace-up shoe rather than a moccasin.
shoes_labels = {
    0: "Cholas/Bailarinas",
    1: "Náuticos",
    2: "Brogues",
    3: "Zuecos",
    4: "Deportivas",
}

# RGB tint painted over each segmentation class id in the overlay image.
COLORS = {
    4: (255, 0, 0),     # upper clothes: red
    5: (0, 255, 0),     # skirt: green
    6: (0, 0, 255),     # pants: blue
    7: (255, 255, 0),   # dress: yellow
    9: (0, 255, 255),   # left shoe: cyan
    10: (255, 0, 255),  # right shoe: magenta
}

# =================== UTILS =====================
def overlay_mask(image, seg_map, alpha=0.5):
    """Tint the pixels of every class listed in ``COLORS``.

    Args:
        image: PIL image or array; converted with ``np.array``.
        seg_map: class-id array with the same height/width as *image*.
        alpha: weight of the class colour (0 keeps the photo, 1 paints it solid).

    Returns:
        A new array in which each pixel whose class id is a key of ``COLORS``
        becomes ``(1 - alpha) * pixel + alpha * colour`` cast to uint8; all
        other pixels are copied unchanged.
    """
    base = np.array(image)
    painted = base.copy()
    for class_id in COLORS:
        rgb = np.array(COLORS[class_id])
        where = seg_map == class_id
        painted[where] = ((1 - alpha) * base[where] + alpha * rgb).astype(np.uint8)
    return painted

def get_clothing_text(results):
    """Join the labels of ``(crop, label)`` pairs into one comma-separated string.

    Returns "No se detectó ropa" when *results* is empty.
    """
    labels = [text for _crop, text in results]
    if not labels:
        return "No se detectó ropa"
    return ", ".join(labels)

# =================== SHOES =====================
def classify_shoes(image: np.ndarray):
    """Score an image crop against every shoe type.

    Args:
        image: array accepted by ``Image.fromarray``; converted to RGB.

    Returns:
        dict mapping each label from ``shoes_labels`` to its softmax
        probability, rounded to 3 decimals.
    """
    pil_image = Image.fromarray(image).convert("RGB")
    model_inputs = shoes_processor(images=pil_image, return_tensors="pt")

    with torch.no_grad():
        logits = shoes_model(**model_inputs).logits
        probabilities = F.softmax(logits, dim=1).squeeze().tolist()

    scores = {}
    for class_id, p in enumerate(probabilities):
        scores[shoes_labels[class_id]] = round(p, 3)
    return scores

# =================== RECOMMENDATION =====================
def get_recommendation(localizacion, ropa, hora, temperatura):
    """Ask the local recommendation service for outfit advice.

    POSTs the four inputs as JSON to http://localhost:8000/recommend with a
    10-second timeout. Returns the response's "recommendation" field, or
    "No disponible" if that field is absent. Any exception is printed and
    reported as "No se pudo obtener recomendación".
    """
    endpoint = "http://localhost:8000/recommend"
    body = dict(
        localizacion=localizacion,
        ropa=ropa,
        hora=hora,
        temperatura=temperatura,
    )
    try:
        response = requests.post(endpoint, json=body, timeout=10)
        response.raise_for_status()
        return response.json().get("recommendation", "No disponible")
    except Exception as exc:
        print(exc)
        return "No se pudo obtener recomendación"

# =================== WEATHER =====================
# Spanish descriptions of the WMO weather interpretation codes that Open-Meteo
# returns in ``current_weather.weathercode``. Codes 1-3 are cloud cover
# (mainly clear / partly cloudy / overcast), not precipitation.
WEATHER_CODES = {
    0: "Despejado",
    1: "Mayormente despejado",
    2: "Parcialmente nublado",
    3: "Nublado",
    45: "Niebla",
    48: "Niebla con escarcha",
    51: "Llovizna ligera",
    53: "Llovizna moderada",
    55: "Llovizna intensa",
    56: "Llovizna helada ligera",
    57: "Llovizna helada intensa",
    61: "Lluvia ligera",
    63: "Lluvia moderada",
    65: "Lluvia fuerte",
    66: "Lluvia helada ligera",
    67: "Lluvia helada fuerte",
    71: "Nevada ligera",
    73: "Nevada moderada",
    75: "Nevada fuerte",
    77: "Granos de nieve",
    80: "Chubascos ligeros",
    81: "Chubascos moderados",
    82: "Chubascos violentos",
    85: "Chubascos de nieve ligeros",
    86: "Chubascos de nieve fuertes",
    95: "Tormenta",
    96: "Tormenta con granizo ligero",
    99: "Tormenta con granizo fuerte",
}

def get_weather_by_coords(lat, lon):
    """Return a short Spanish weather report for the given coordinates.

    Queries the Open-Meteo forecast API for the current weather and formats
    four lines: today's date, temperature, wind speed and weather state. The
    "Temperatura: <t> °C" line is parsed by the pipeline, so its format must
    stay stable.

    Request or parse failures are printed and reported as unknown values
    ("?" for numbers, "Desconocido" for the state) instead of propagating.
    """
    url = "https://api.open-meteo.com/v1/forecast"
    params = {
        "latitude": lat,
        "longitude": lon,
        "current_weather": True,
        "timezone": "auto",
    }
    try:
        # Bounded timeout so a slow API cannot hang the UI callback forever.
        r = requests.get(url, params=params, timeout=10)
        r.raise_for_status()
        data = r.json().get("current_weather", {})
    except Exception as e:
        print(f"Error obteniendo el clima: {e}")
        data = {}
    temp = data.get("temperature", "?")
    wind = data.get("windspeed", "?")
    # No default of 0: code 0 means "clear sky" and would misreport missing data.
    code = data.get("weathercode")
    today = datetime.now().strftime("%d/%m/%Y")
    return (
        f" {today}\n"
        f" Temperatura: {temp} °C\n"
        f" Viento: {wind} km/h\n"
        f" Estado: {WEATHER_CODES.get(code, 'Desconocido')}"
    )


def traducir_texto(texto, target="es"):
    """Translate *texto* into *target* using deep_translator's GoogleTranslator.

    Best effort: if the package is missing or translation raises, the error
    is printed and *texto* is returned unchanged.
    """
    try:
        from deep_translator import GoogleTranslator

        translator = GoogleTranslator(source='auto', target=target)
        return translator.translate(texto)
    except Exception as err:
        print("Error al traducir:", err)
        return texto
    

# =================== LOCATION =====================
def obtener_ubicacion():
    """Geolocate via ipinfo.io and return ``(city, region, latitude, longitude)``.

    City and region are passed through ``traducir_texto``. Latitude and
    longitude are the two strings of ipinfo's "lat,lon" ``loc`` field. Any
    missing field, or any request/parse failure, yields "Desconocida".

    NOTE(review): ipinfo.io/json without an explicit IP presumably describes
    the public IP of the machine making the request, i.e. the server running
    this app rather than the browser user. Confirm this is the intended location.
    """
    url = "https://ipinfo.io/json"
    try:
        # Bounded timeout so a slow lookup cannot hang the UI callback forever;
        # HTTP errors (e.g. rate limiting) fall through to the fallback below.
        respuesta = requests.get(url, timeout=10)
        respuesta.raise_for_status()
        datos = respuesta.json()
        ciudad = datos.get("city", "Desconocida")
        region = datos.get("region", "Desconocida")
        loc = datos.get("loc", "")
        if loc:
            latitud, longitud = loc.split(',')
        else:
            latitud = longitud = "Desconocida"
        ciudad_es = traducir_texto(ciudad)
        region_es = traducir_texto(region)
        return ciudad_es, region_es, latitud, longitud
    except Exception as e:
        print(f"Error obteniendo ubicación: {e}")
        return "Desconocida", "Desconocida", "Desconocida", "Desconocida"

# =================== PIPELINE =====================
def _segment_clothes(image):
    """Return the SegFormer class-id map for *image*, aligned to its (H, W)."""
    inputs = seg_processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = seg_model(**inputs)
    # Resize the logits to the image size so the map lines up pixel-for-pixel
    # with *image* (PIL's size is (W, H); interpolate expects (H, W)).
    logits = F.interpolate(
        outputs.logits,
        size=image.size[::-1],
        mode="bilinear",
        align_corners=False,
    )
    return logits.argmax(dim=1)[0].cpu().numpy()


def _classify_garment(crop):
    """Classify a non-shoe garment crop; return (label, confidence in percent)."""
    inputs = cls_processor(images=crop, return_tensors="pt")
    with torch.no_grad():
        probs = torch.softmax(cls_model(**inputs).logits, dim=-1)
    pred_id = probs.argmax(dim=-1).item()
    if pred_id < len(CLOTHING_LABELS):
        label = CLOTHING_LABELS[pred_id]
    else:
        # Guard against a checkpoint with more classes than the translated list.
        label = cls_model.config.id2label.get(pred_id, str(pred_id))
    return label, probs[0, pred_id].item() * 100


def _classify_shoe(crop_np):
    """Classify a shoe crop; return (best label, confidence in percent)."""
    shoe_preds = classify_shoes(crop_np)
    best = max(shoe_preds, key=shoe_preds.get)
    return best, shoe_preds[best] * 100


def _temperature_from_weather(weather_info):
    """Extract "<value> °C" from the "Temperatura: <value> °C" line of *weather_info*.

    Returns "? °C" when that line is absent instead of raising.
    """
    _, found, rest = weather_info.partition("Temperatura: ")
    if not found:
        return "? °C"
    # strip() avoids the doubled space the report's "<value> °C" spacing would leave.
    return rest.split("°C", 1)[0].strip() + " °C"


def segment_and_classify(img: np.ndarray, lat, lon):
    """Analyse an outfit photo and fetch a weather-aware recommendation.

    Args:
        img: photo array from the Gradio image component, or None.
        lat, lon: optional coordinates. When either is None or blank, the
            coordinates from ``obtener_ubicacion`` are used instead.

    Returns:
        ``(overlay, results, weather_info, recommendation)``:
        - ``overlay``: the image with segmentation class colours blended in.
        - ``results``: list of ``(crop array, "label (conf%)")`` pairs.
        - ``weather_info``: region plus the weather summary text.
        - ``recommendation``: the text returned by ``get_recommendation``.
        With no image, returns ``(None, [], "Sin datos", "Sin recomendación")``.
    """
    if img is None:
        return None, [], "Sin datos", "Sin recomendación"

    min_region_pixels = 600  # skip tiny/noisy regions
    shoe_seg_ids = (9, 10)   # left and right shoe

    # The region always comes from IP geolocation; coordinates supplied by the
    # caller take precedence for the weather lookup.
    _, region, ip_lat, ip_lon = obtener_ubicacion()
    if any(v is None or str(v).strip() == "" for v in (lat, lon)):
        lat, lon = ip_lat, ip_lon
    weather_info = f"{region}\n{get_weather_by_coords(lat, lon)}"

    # Normalise to an 8-bit RGB PIL image (also converts RGBA / greyscale arrays).
    image = Image.fromarray(img.astype("uint8")).convert("RGB")

    seg_map = _segment_clothes(image)
    overlay = overlay_mask(image, seg_map)

    results = []
    for seg_id in CLOTHING_SEG_IDS:
        ys, xs = np.nonzero(seg_map == seg_id)
        if xs.size < min_region_pixels:
            continue
        # PIL crop boxes exclude the right/bottom edge; +1 keeps the region's
        # last column and row.
        box = (int(xs.min()), int(ys.min()), int(xs.max()) + 1, int(ys.max()) + 1)
        crop = image.crop(box)
        crop_np = np.array(crop)

        if seg_id in shoe_seg_ids:
            label, conf = _classify_shoe(crop_np)
        else:
            label, conf = _classify_garment(crop)
        results.append((crop_np, f"{label} ({conf:.1f}%)"))

    ropa_texto = get_clothing_text(results)
    hora = datetime.now().strftime("%H:%M")
    temp = _temperature_from_weather(weather_info)
    recomendacion = get_recommendation(region, ropa_texto, hora, temp)

    return overlay, results, weather_info, recomendacion

# =================== GRADIO =====================
# Build the web UI. The order in which components are created inside the
# Blocks/Row/Column context managers determines the on-screen layout.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Asesor de Outfits según Clima")

    # Top row: uploaded photo (passed to the callback as a numpy array) next
    # to the segmentation overlay output.
    with gr.Row():
        input_img = gr.Image(type="numpy", label="Imagen de entrada")
        output_overlay = gr.Image(label="Segmentación")

    # Bottom row: detected-garment gallery (relative width 2) beside the
    # weather text and recommendation (relative width 1).
    with gr.Row():
        with gr.Column(scale=2):
            output_gallery = gr.Gallery(label="Prendas detectadas", columns=3)
        with gr.Column(scale=1):
            output_weather = gr.Textbox(label="Clima", lines=5)
            output_recommendation = gr.Markdown(label="Recomendación")

    btn = gr.Button("Analizar Outfit", variant="primary")
    # Hidden coordinate inputs forwarded to segment_and_classify.
    # NOTE(review): nothing in this block populates them, so they stay empty
    # unless other code fills them in.
    input_lat = gr.Textbox(visible=False)
    input_lon = gr.Textbox(visible=False)

    # Run the whole pipeline on click. The four outputs map positionally to
    # the 4-tuple returned by segment_and_classify.
    btn.click(
        fn=segment_and_classify,
        inputs=[input_img, input_lat, input_lon],
        outputs=[output_overlay, output_gallery, output_weather, output_recommendation]
    )

# Start the Gradio server as soon as this code runs.
demo.launch()
