In [None]:
# inference.py
# -----------------------------------------------------------------------------
# Handler de inferencia para SageMaker (TensorFlow 2.x)
# Entradas:
#   POST /invocations
#   Content-Type: application/json
#   {
#       "frames": ["<b64_jpeg_1>", "<b64_jpeg_2>", ...]   # cualquier nº de frames
#   }
# Salidas:
#   {
#       "prediction": 2,
#       "probs": [0.01, 0.12, 0.80, 0.07]                 # soft-max (orden 0-3)
#   }
# -----------------------------------------------------------------------------
import os, io, json, base64
import numpy as np
import tensorflow as tf
from PIL import Image

# -------------  Hyperparameters (must match training) ------------------------
# Fixed clip length: predict_fn subsamples or pads every request to this many
# frames before feature extraction.
NUM_FRAMES   = 32
# Size every decoded frame is resized to; also used (with 3 channels appended)
# as the input_shape of the EfficientNet-B0 feature extractor in model_fn.
TARGET_SIZE  = (224, 224)
# NOTE(review): not referenced anywhere in the visible code; presumably the
# per-frame feature width produced by the extractor — confirm before relying on it.
FEAT_DIM     = 1280                     # EfficientNet-B0
# -----------------------------------------------------------------------------

# ----------------------------- Utilidades ------------------------------------
def _sample_indices(n_total, n_target):
    """Devuelve índices equiespaciados para escoger n_target de n_total."""
    if n_total <= n_target:
        return list(range(n_total))
    step = n_total / n_target
    return [int(i * step) for i in range(n_target)]

def _decode_and_resize(b64_str):
    """Decode a base64-encoded image into a float32 RGB array of TARGET_SIZE.

    The image is converted to RGB and resized with bilinear resampling.
    Pixel values are only cast to float32 (no rescaling is applied here), so
    they stay in the 0-255 range.
    """
    buffer = io.BytesIO(base64.b64decode(b64_str))
    rgb_image = Image.open(buffer).convert("RGB")
    resized = rgb_image.resize(TARGET_SIZE, resample=Image.BILINEAR)
    return np.asarray(resized, dtype=np.float32)

# ----------------------------- SageMaker API ---------------------------------
def model_fn(model_dir):
    """Load the two models used at inference time.

    Args:
        model_dir: Directory that contains ``best_model.h5``.

    Returns:
        A dict with three entries:
          - "temporal": the Keras model loaded from ``best_model.h5``
            (loaded with ``compile=False``).
          - "base": EfficientNet-B0 without its classification top, with
            global average pooling and ImageNet weights.
          - "pre": the EfficientNet ``preprocess_input`` function.
        predict_fn reads exactly these three keys from its ``models`` argument.
    """
    checkpoint_path = os.path.join(model_dir, "best_model.h5")
    sequence_model = tf.keras.models.load_model(checkpoint_path, compile=False)

    # NOTE(review): weights="imagenet" may require downloading the pretrained
    # weights when they are not cached locally — confirm the serving
    # environment allows that.
    backbone = tf.keras.applications.EfficientNetB0(
        input_shape=(*TARGET_SIZE, 3),
        include_top=False,
        pooling="avg",
        weights="imagenet",
    )

    return {
        "temporal": sequence_model,
        "base": backbone,
        "pre": tf.keras.applications.efficientnet.preprocess_input,
    }

def input_fn(request_body, content_type):
    """Parse a JSON request body into a list of decoded frames.

    Args:
        request_body: JSON text (or bytes) of the form
            ``{"frames": ["<b64 image>", ...]}``.
        content_type: Content-Type of the request. Only the media type is
            compared, case-insensitively, so parameters such as
            ``; charset=utf-8`` are allowed.

    Returns:
        A list with one array per frame, as produced by _decode_and_resize.

    Raises:
        ValueError: If the content type is not JSON, the body is not valid
            JSON, the body is not an object holding a ``frames`` list, or the
            list is empty.
    """
    # Compare only the bare media type: "application/json; charset=utf-8" is
    # the same payload format and is already accepted by output_fn.
    media_type = str(content_type or "").split(";", 1)[0].strip().lower()
    if media_type != "application/json":
        raise ValueError(f"Tipo de contenido no soportado: {content_type}")

    body = json.loads(request_body)
    # A non-object payload (number, string, array) would otherwise surface as
    # a TypeError from the key lookup instead of a validation error.
    if not isinstance(body, dict) or not isinstance(body.get("frames"), list):
        raise ValueError("JSON debe contener la clave 'frames' con una lista.")

    encoded_frames = body["frames"]
    # Prediction needs at least one frame (the last frame is repeated as
    # padding), so reject an empty clip here with a clear message.
    if not encoded_frames:
        raise ValueError("La lista 'frames' no puede estar vacía.")

    # Frames are only decoded and resized here; model preprocessing happens
    # in predict_fn.
    return [_decode_and_resize(b64) for b64 in encoded_frames]

def predict_fn(frames, models):
    """Run the two-stage model on one clip and return the class scores.

    Args:
        frames: Non-empty sequence of equally shaped frame arrays, as returned
            by input_fn.
        models: Dict with the keys "base" (per-frame feature extractor),
            "temporal" (sequence model) and "pre" (preprocessing function),
            as returned by model_fn.

    Returns:
        The temporal model's output for the single clip (first row of its
        batch output); per the file header this is the softmax vector.

    Raises:
        ValueError: If ``frames`` is empty.
    """
    # An empty clip has no last frame to repeat as padding; fail with a clear
    # message instead of an IndexError further down.
    if len(frames) == 0:
        raise ValueError("Se requiere al menos un frame para la inferencia.")

    base = models["base"]
    temporal = models["temporal"]
    preprocess = models["pre"]

    # 1) Subsample long clips down to NUM_FRAMES. A new list is built, so the
    #    caller's list is never mutated by the padding below.
    idxs = _sample_indices(len(frames), NUM_FRAMES)
    frames = [frames[i] for i in idxs]

    # Short clips are padded by repeating the last frame.
    if len(frames) < NUM_FRAMES:
        frames.extend([frames[-1]] * (NUM_FRAMES - len(frames)))

    # 2) EfficientNet preprocessing on the stacked clip.
    frames_arr = np.stack(frames, axis=0)             # (NUM_FRAMES, H, W, 3)
    frames_arr = preprocess(frames_arr)

    # 3) Per-frame feature extraction in a single batch.
    feats = base.predict(frames_arr, verbose=0)       # (NUM_FRAMES, feat_dim)

    # 4) Temporal model on a batch of one clip.
    feats = feats.astype(np.float32)[None, ...]       # (1, NUM_FRAMES, feat_dim)
    preds = temporal.predict(feats, verbose=0)[0]

    return preds

def output_fn(prediction, accept):
    """Serialise the score vector as a JSON response.

    Args:
        prediction: 1-D sequence/array of class scores.
        accept: Value of the request's Accept header. It may list several
            media ranges and carry parameters; the response is produced when
            any range is ``application/json``, ``application/*`` or ``*/*``.

    Returns:
        A ``(body, content_type)`` tuple where ``body`` is a JSON string with
        "prediction" (index of the highest score) and "probs" (the scores as
        plain floats) and ``content_type`` is "application/json".

    Raises:
        ValueError: If no JSON-compatible media range is offered (including a
            missing/empty ``accept``).
    """
    # Accept headers may be comma-separated and parameterised
    # (";charset=...", ";q=..."), so compare only the bare, lower-cased media
    # ranges instead of the raw header string.
    offered = {
        media_range.split(";", 1)[0].strip().lower()
        for media_range in str(accept or "").split(",")
    }
    if offered.isdisjoint({"application/json", "application/*", "*/*"}):
        raise ValueError(f"Tipo de salida no soportado: {accept}")

    pred_class = int(np.argmax(prediction))
    # Convert to built-in floats: NumPy scalars are not JSON serialisable.
    resp = json.dumps({"prediction": pred_class,
                       "probs": [float(p) for p in prediction]})
    return resp, "application/json"
