<a href="https://colab.research.google.com/github/RoyMusango/edl-starter/blob/main/Recalage_final_IRPA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# app.py
import os
import io
import base64
from datetime import datetime

from flask import Flask, render_template, request, jsonify
import numpy as np
from PIL import Image

import cv2
from skimage import exposure

app = Flask(__name__, template_folder='templates', static_folder='static')

# Raw uploads are archived in an "uploads" directory next to this module; it is
# created at import time when missing.
# NOTE(review): __file__ is undefined if this cell is run inside a notebook
# instead of being saved as app.py.
UPLOAD_DIR = os.path.join(os.path.dirname(__file__), 'uploads')
os.makedirs(UPLOAD_DIR, exist_ok=True)

# ---------- Utilitaires I/O ----------
def pil_to_u8_rgb(pil_img):
    """Convert a PIL image into a uint8 numpy array.

    'L' images give an (H, W) array and 'RGB' images an (H, W, 3) array.
    Every other mode (e.g. RGBA, P) is first converted to 'RGB' by Pillow.
    """
    source = pil_img if pil_img.mode in ('RGB', 'L') else pil_img.convert('RGB')
    pixels = np.array(source)
    if pixels.ndim == 2 or pixels.shape[2] == 3:
        return pixels
    # Defensive only: a non-3-channel colour array is reduced by dropping alpha.
    return cv2.cvtColor(pixels, cv2.COLOR_RGBA2RGB)

def imread_file_storage(fs):
    """Decode an uploaded file into a uint8 numpy array via Pillow.

    *fs* must expose a binary ``.stream`` (here: the objects taken from
    ``request.files``). The result follows pil_to_u8_rgb: (H, W) for 8-bit
    grayscale images, (H, W, 3) otherwise.
    """
    # The context manager releases Pillow's image resources as soon as the
    # pixels are copied into the array instead of waiting for garbage
    # collection. Pillow does not close a file object the caller passed in.
    with Image.open(fs.stream) as img:
        img.load()
        return pil_to_u8_rgb(img)

def to_base64_png(img_u8):
    """Encode a uint8 image array as a base64 PNG string (ASCII).

    2-D arrays become grayscale PNGs and (H, W, 3) arrays RGB PNGs. A trailing
    singleton channel (H, W, 1) is squeezed to grayscale first.
    """
    arr = np.asarray(img_u8)
    if arr.ndim == 3 and arr.shape[2] == 1:
        arr = arr[:, :, 0]
    # Pillow infers 'L' / 'RGB' from a uint8 array's shape; passing `mode=` to
    # Image.fromarray is deprecated in recent Pillow releases.
    pil = Image.fromarray(arr)
    buf = io.BytesIO()
    pil.save(buf, format='PNG')
    return base64.b64encode(buf.getvalue()).decode('ascii')

# ---------- Blocs de traitement (reprennent tes fonctions) ----------
def _to_uint8(img):
    from PIL import Image as _PILImage
    if isinstance(img, _PILImage.Image):
        if img.mode in ('RGBA', 'P'):
            img = img.convert('RGBA').convert('RGB')
        elif img.mode not in ('RGB', 'L'):
            img = img.convert('RGB')
        img = np.array(img)
    if isinstance(img, (list, tuple)):
        img = np.array(img)
    if not isinstance(img, np.ndarray):
        raise TypeError("Format d'image non supporté.")
    if img.dtype == np.uint8:
        return img
    if img.dtype.kind == 'f':
        img = img if img.max() > 1.0 else (img * 255.0)
    else:
        img = np.clip(img, 0, 255)
    return img.astype(np.uint8)

def _ensure_gray_u8(img):
    """Return *img* as a single-channel uint8 array (3-D inputs go RGB -> gray)."""
    as_u8 = _to_uint8(img)
    return cv2.cvtColor(as_u8, cv2.COLOR_RGB2GRAY) if as_u8.ndim == 3 else as_u8

def _img_shape_wh(img_u8):
    h, w = img_u8.shape[:2]
    return w, h

def make_orb():
    """Create the ORB detector/descriptor shared by the registration pipeline.

    Up to 5000 features over an 8-level pyramid (scale step 1.2), ranked by
    Harris score, with 31x31 patches and 2-point (WTA_K=2) binary tests.
    """
    orb_params = dict(
        nfeatures=5000,
        scaleFactor=1.2,
        nlevels=8,
        edgeThreshold=19,
        firstLevel=0,
        WTA_K=2,
        scoreType=cv2.ORB_HARRIS_SCORE,
        patchSize=31,
        fastThreshold=12,
    )
    return cv2.ORB_create(**orb_params)

def detect_describe(gray_u8, orb=None):
    """Detect keypoints and compute descriptors on a grayscale image.

    Uses *orb* when given, otherwise a fresh detector from make_orb().
    Returns ``(keypoints, descriptors)``. An empty or None keypoint result is
    normalised to ``[]``. Descriptors are passed through unchanged and may be
    None when nothing is found (callers check for that).
    """
    detector = make_orb() if orb is None else orb
    keypoints, descriptors = detector.detectAndCompute(gray_u8, None)
    if not keypoints:
        keypoints = []
    return keypoints, descriptors

def match_descriptors(des_ref, des_mov, ratio=0.75, matcher=None):
    """Match binary descriptors with Lowe's ratio test plus a mutual check.

    A forward match ref -> mov is kept only if both of these hold:
    - it passes the ratio test (best distance < ratio * second-best distance);
    - the reverse search mov -> ref pairs the same two descriptors and also
      passes the ratio test.

    Args:
        des_ref, des_mov: descriptor arrays, or None (then no matches).
        ratio: Lowe ratio threshold.
        matcher: optional object exposing ``knnMatch(query, train, k=2)``
            (e.g. an LSH/FLANN matcher). Defaults to a brute-force Hamming
            matcher without cross-check.

    Returns:
        List of forward matches (queryIdx = ref index, trainIdx = mov index)
        sorted by ascending distance.
    """
    if des_ref is None or des_mov is None:
        return []
    if matcher is None:
        matcher = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=False)

    def _ratio_passed(knn_pairs):
        # knnMatch may return fewer than k neighbours for a query (e.g. when
        # the train set holds a single descriptor). Such entries cannot be
        # ratio-tested, so they are skipped instead of crashing the unpacking.
        for pair in knn_pairs:
            if len(pair) < 2:
                continue
            best, second = pair[0], pair[1]
            if best.distance < ratio * second.distance:
                yield best

    forward = list(_ratio_passed(matcher.knnMatch(des_ref, des_mov, k=2)))
    # Reverse matches are keyed (ref_idx, mov_idx) so they line up with the
    # forward (queryIdx, trainIdx) pairs.
    reverse_keys = {
        (m.trainIdx, m.queryIdx)
        for m in _ratio_passed(matcher.knnMatch(des_mov, des_ref, k=2))
    }
    symmetric = [m for m in forward if (m.queryIdx, m.trainIdx) in reverse_keys]
    symmetric.sort(key=lambda m: m.distance)
    return symmetric

def bucket_matches(kp_ref, kp_mov, matches, img_size, grid=(8, 8), max_per_cell=80):
    """Spread matches over the reference image by capping each grid cell.

    The reference image (``img_size`` = (width, height)) is split into
    ``grid`` = (columns, rows) cells. Each match is assigned to the cell of its
    reference keypoint, and only the ``max_per_cell`` lowest-distance matches
    of each cell survive. ``kp_mov`` is not used.

    Returns the surviving matches sorted by ascending distance. Ties keep
    column-major cell order, then input order.
    """
    if len(matches) == 0:
        return []
    width, height = img_size
    n_cols, n_rows = grid
    step_x = width / n_cols
    step_y = height / n_rows

    def by_distance(match):
        return match.distance

    per_cell = {}
    for match in matches:
        px, py = kp_ref[match.queryIdx].pt
        col = min(n_cols - 1, max(0, int(px // step_x)))
        row = min(n_rows - 1, max(0, int(py // step_y)))
        per_cell.setdefault((col, row), []).append(match)

    selected = []
    for col in range(n_cols):
        for row in range(n_rows):
            best_here = sorted(per_cell.get((col, row), []), key=by_distance)
            selected.extend(best_here[:max_per_cell])
    return sorted(selected, key=by_distance)

def _pts_from_matches(kp_ref, kp_mov, matches):
    pts_ref = np.float32([kp_ref[m.queryIdx].pt for m in matches])
    pts_mov = np.float32([kp_mov[m.trainIdx].pt for m in matches])
    return pts_ref, pts_mov

def _err_reproj_mean(pts_src, pts_dst, H, is_homography):
    """Mean Euclidean distance (pixels) between mapped *pts_src* and *pts_dst*.

    *H* is a 3x3 homography when *is_homography* is true (perspective divide
    applied), otherwise a 2x3 affine matrix applied with cv2.transform.
    """
    if not is_homography:
        mapped = cv2.transform(pts_src.reshape(1, -1, 2), H).reshape(-1, 2)
    else:
        homog = cv2.convertPointsToHomogeneous(pts_src)[:, 0, :]
        projected = (H @ homog.T).T
        mapped = projected[:, :2] / projected[:, 2:3]
    residuals = np.linalg.norm(mapped - pts_dst, axis=1)
    return residuals.mean()

def _package_fit(model_name, M, inlier_mask, pts_mov, pts_ref, n_matches,
                 min_inliers, mov_gray, dsize, is_h):
    """Score one RANSAC fit and warp the moving image with it (None if the fit failed)."""
    if M is None or inlier_mask is None:
        return None
    flags = inlier_mask.ravel()
    inliers = int(flags.sum())
    ratio = inliers / max(1, n_matches)
    if inliers > 0:
        keep = flags == 1
        err = _err_reproj_mean(pts_mov[keep], pts_ref[keep], M, is_homography=is_h)
    else:
        err = 1e9  # sentinel: ranks last on reprojection error
    if is_h:
        warped = cv2.warpPerspective(mov_gray, M, dsize, flags=cv2.INTER_LINEAR)
    else:
        warped = cv2.warpAffine(mov_gray, M, dsize, flags=cv2.INTER_LINEAR)
    return dict(model_name=model_name, ok=inliers >= min_inliers, inlier_ratio=ratio,
                err_px=err, warped=warped, H=M, is_h=is_h)


def _estimate_models_and_score(ref_gray, mov_gray, kp_ref, kp_mov, matches, img_size, rth=3.0):
    """Fit similarity, affine and homography models (mov -> ref) with RANSAC.

    Args:
        ref_gray: reference image; only its size is used (warp target).
        mov_gray: moving image, warped with each successful model.
        kp_ref, kp_mov, matches: keypoints and matches (query = ref, train = mov).
        img_size: unused, kept for interface compatibility.
        rth: RANSAC reprojection threshold in pixels. The homography uses
            max(rth, 4.0).

    Returns:
        A list with one dict per model whose estimator returned a result, in
        the order similarity, affine, homography. Each dict has keys
        model_name, ok (>= 6 inliers, or >= 10 for the homography),
        inlier_ratio (inliers / matches), err_px (mean reprojection error on
        inliers, 1e9 if none), warped, H and is_h. The list is empty when
        fewer than 8 matches are given.
    """
    results = []
    if len(matches) < 8:
        return results
    pts_ref, pts_mov = _pts_from_matches(kp_ref, kp_mov, matches)
    n_matches = len(matches)
    dsize = _img_shape_wh(ref_gray)  # (width, height) of the reference frame

    # (name, minimum inliers for ok, is homography, estimator)
    estimators = (
        ('similarity', 6, False, lambda: cv2.estimateAffinePartial2D(
            pts_mov, pts_ref, method=cv2.RANSAC, ransacReprojThreshold=rth,
            confidence=0.999, maxIters=5000)),
        ('affine', 6, False, lambda: cv2.estimateAffine2D(
            pts_mov, pts_ref, method=cv2.RANSAC, ransacReprojThreshold=rth,
            confidence=0.999, maxIters=5000)),
        # The homography gets a looser threshold and needs more inliers.
        ('homography', 10, True, lambda: cv2.findHomography(
            pts_mov, pts_ref, method=cv2.RANSAC,
            ransacReprojThreshold=max(rth, 4.0), maxIters=5000, confidence=0.999)),
    )
    for name, min_inliers, is_h, estimate in estimators:
        M, inlier_mask = estimate()
        fit = _package_fit(name, M, inlier_mask, pts_mov, pts_ref, n_matches,
                           min_inliers, mov_gray, dsize, is_h)
        if fit is not None:
            results.append(fit)
    return results

def votre_pipeline_orb_ransac(img_ref_gray_u8, img_mov_gray_u8, verbose=False):
    """Register a moving grayscale image onto a reference with ORB + RANSAC.

    Steps:
    1. ORB detection on both images.
    2. Symmetric ratio-test matching.
    3. Spatial bucketing (8x8 grid, <= 80 matches per cell, <= 2000 overall).
    4. Similarity / affine / homography fits (_estimate_models_and_score).

    Returns a dict.
    - On failure: ok=False, 'reason' ('peu_de_points',
      'trop_peu_de_matches' or 'estimation_impossible'), plus 'matches',
      'kp_ref' and 'kp_mov'.
    - Otherwise, the selected model's fields: ok, model_name, inlier_ratio,
      err_px, warped (moving image resampled into the reference frame), H and
      is_h, plus matches, matches_count, kp_ref and kp_mov.

    ``verbose`` is accepted for interface compatibility and is unused.
    """
    orb = make_orb()
    kp_ref, des_ref = detect_describe(img_ref_gray_u8, orb)
    kp_mov, des_mov = detect_describe(img_mov_gray_u8, orb)
    if des_ref is None or des_mov is None or len(kp_ref) < 8 or len(kp_mov) < 8:
        return dict(ok=False, reason='peu_de_points', matches=[], kp_ref=kp_ref, kp_mov=kp_mov)
    matches = match_descriptors(des_ref, des_mov, ratio=0.75)
    matches = bucket_matches(kp_ref, kp_mov, matches, _img_shape_wh(img_ref_gray_u8), grid=(8, 8), max_per_cell=80)
    matches = matches[:2000]
    if len(matches) < 12:
        return dict(ok=False, reason='trop_peu_de_matches', matches=matches, kp_ref=kp_ref, kp_mov=kp_mov)
    model_results = _estimate_models_and_score(img_ref_gray_u8, img_mov_gray_u8, kp_ref, kp_mov, matches, _img_shape_wh(img_ref_gray_u8))
    if not model_results:
        return dict(ok=False, reason='estimation_impossible', matches=matches, kp_ref=kp_ref, kp_mov=kp_mov)
    # Acceptable fits first: a model below its inlier minimum reports ok=False
    # and its warp is discarded downstream, so it must not outrank a usable
    # model just because it has a higher inlier ratio (e.g. a homography
    # evaluated with a looser threshold). Then highest inlier ratio, then
    # lowest mean reprojection error.
    model_results.sort(key=lambda r: (not r['ok'], -r['inlier_ratio'], r['err_px']))
    best = model_results[0]
    return dict(ok=best['ok'], model_name=best['model_name'], inlier_ratio=float(best['inlier_ratio']),
                err_px=float(best['err_px']), warped=best['warped'], H=best['H'], is_h=best['is_h'],
                matches=matches, matches_count=len(matches), kp_ref=kp_ref, kp_mov=kp_mov)

def approche_robuste_recalage(img_NB, img_color):
    """Register the NB image onto the colour reference, trying three preprocessings.

    Each variant goes through votre_pipeline_orb_ransac. The variants are:
    A = raw grayscale, B = CLAHE (clip 1.6, 8x8 tiles) and C = 3x3 Gaussian
    blur (sigma 0.6).

    Ranking of the variants, in order of priority:
    1. Successful results (ok truthy) beat failed ones.
    2. Higher inlier_ratio.
    3. Lower err_px.
    4. More matches.
    Ties keep the order A, B, C. Returns the winning result dict unchanged.
    """
    ref_rgb_u8 = _to_uint8(img_color)
    ref_gray_u8 = cv2.cvtColor(ref_rgb_u8, cv2.COLOR_RGB2GRAY) if ref_rgb_u8.ndim == 3 else ref_rgb_u8
    mov_gray_u8 = _ensure_gray_u8(img_NB)

    clahe = cv2.createCLAHE(clipLimit=1.6, tileGridSize=(8, 8))
    candidates = {
        'A': votre_pipeline_orb_ransac(ref_gray_u8, mov_gray_u8, verbose=False),
        'B': votre_pipeline_orb_ransac(clahe.apply(ref_gray_u8), clahe.apply(mov_gray_u8), verbose=False),
        'C': votre_pipeline_orb_ransac(cv2.GaussianBlur(ref_gray_u8, (3, 3), 0.6),
                                       cv2.GaussianBlur(mov_gray_u8, (3, 3), 0.6), verbose=False),
    }

    def score(res):
        # A failed candidate can still report a high inlier_ratio. Ranking on
        # the ratio alone would then hide a usable alignment, because callers
        # ignore 'warped' when ok is falsy.
        return (not res.get('ok', False),
                -res.get('inlier_ratio', 0.0),
                res.get('err_px', 1e9),
                -res.get('matches_count', 0))

    best_key = min(candidates, key=lambda k: score(candidates[k]))
    return candidates[best_key]

def adapter_histogramme(cible, reference):
    """Return *cible* with its intensity histogram matched to *reference*.

    Thin wrapper around ``skimage.exposure.match_histograms``.
    """
    matched = exposure.match_histograms(cible, reference)
    return matched

def normaliser_histogrammes(img_color, img_NB):
    """Globally equalise both images with ``skimage.exposure.equalize_hist``.

    The colour image is first reduced to gray (RGB -> gray) when it is 3-D.
    Returns ``(img_NB_eq, img_color_eq)``. Note that the NB result comes
    first, the reverse of the argument order.
    """
    gray_color = img_color if img_color.ndim != 3 else cv2.cvtColor(img_color, cv2.COLOR_RGB2GRAY)
    equalized_nb = exposure.equalize_hist(img_NB)
    equalized_color = exposure.equalize_hist(gray_color)
    return equalized_nb, equalized_color


def _norm01(x):
    x = x.astype(np.float32)
    mn, mx = x.min(), x.max()
    if mx <= mn:
        return np.zeros_like(x, dtype=np.float32)
    return (x - mn) / (mx - mn)

def _gauss(img, sigma):
    """Gaussian blur with an odd square kernel of roughly 6*sigma (at least 3x3)."""
    half = int(6 * sigma + 1) // 2
    ksize = max(3, 2 * half + 1)
    return cv2.GaussianBlur(img, (ksize, ksize), sigma)

def _grad_energy(gray_u8):
    """Scharr gradient magnitude of a grayscale image, min-max normalised to [0, 1]."""
    # Scharr kernels give more rotation-stable derivatives than 3x3 Sobel.
    dx = cv2.Scharr(gray_u8, cv2.CV_32F, 1, 0)
    dy = cv2.Scharr(gray_u8, cv2.CV_32F, 0, 1)
    return _norm01(np.sqrt(dx * dx + dy * dy))

def _local_contrast(gray_u8, sigma=2.0):
    """Local standard deviation in a Gaussian window, min-max normalised to [0, 1]."""
    img = gray_u8.astype(np.float32) / 255.0
    mean = _gauss(img, sigma)
    mean_sq = _gauss(img * img, sigma)
    # E[x^2] - E[x]^2 can dip slightly below 0 numerically; clamp before sqrt.
    variance = np.clip(mean_sq - mean * mean, 0, 1)
    return _norm01(np.sqrt(variance))

def _build_pyramids(img, levels):
    """Build Gaussian and Laplacian pyramids of *img* (float32).

    Returns ``(gaussian, laplacian)``, both ordered fine -> coarse. The last
    Laplacian entry is the coarsest Gaussian level itself, so collapsing the
    pyramid (pyrUp + add) restores the image.
    """
    gauss = [img.astype(np.float32)]
    for _ in range(levels - 1):
        gauss.append(cv2.pyrDown(gauss[-1]))

    bands = [gauss[-1]]  # coarsest residual first; reversed below
    for lvl in range(levels - 1, 0, -1):
        finer = gauss[lvl - 1]
        expanded = cv2.pyrUp(gauss[lvl], dstsize=(finer.shape[1], finer.shape[0]))
        bands.append(finer - expanded)
    return gauss, bands[::-1]

def _blend_pyramids(LA, LB, GW):
    """Blend two Laplacian pyramids level by level and collapse the result.

    LA, LB: Laplacian pyramids (fine -> coarse). GW: Gaussian pyramid of
    B's weight in [0, 1], with the same levels. Each level is
    ``(1 - w) * A + w * B``; the blended pyramid is then rebuilt from the
    coarsest level with pyrUp + add.
    """
    mixed = []
    for band_a, band_b, weight in zip(LA, LB, GW):
        w = weight.astype(np.float32)
        mixed.append((1.0 - w) * band_a + w * band_b)

    image = mixed[-1]
    for band in reversed(mixed[:-1]):
        image = cv2.pyrUp(image, dstsize=(band.shape[1], band.shape[0])) + band
    return image

def construire_superposition_robuste(img_color_matched, result_align, img_NB_norm_uint8,
                                     prefer_nb_on_whites=True,
                                     levels=4,
                                     softness=1.6,
                                     w_white=0.55, w_lowgrad=0.30, w_lowcont=0.15):
    """Fuse the (histogram-matched) reference image with the aligned NB image.

    The NB image takes over where the reference carries little information:
    bright areas, weak gradients and low local contrast. This avoids "white
    holes" in the result. The blend is done per Laplacian-pyramid level to
    avoid seams and halos.

    Args:
        img_color_matched: reference image, converted to gray when 3-D.
        result_align: registration result; its 'warped' image is used when
            'ok' is truthy and 'warped' is not None.
        img_NB_norm_uint8: NB image used when no usable alignment exists.
        prefer_nb_on_whites: when True, brightness also raises the NB weight.
        levels: pyramid depth (3-5 suggested).
        softness: slope of the sigmoid applied to the weight map.
        w_white, w_lowgrad, w_lowcont: weights of the three cues (should sum to ~1).

    Returns:
        uint8 grayscale image with the reference image's height and width.
    """
    ref = _to_uint8(img_color_matched)
    ref_gray = ref if ref.ndim != 3 else cv2.cvtColor(ref, cv2.COLOR_RGB2GRAY)

    aligned_ok = result_align.get('ok') and result_align.get('warped') is not None
    nb = _to_uint8(result_align['warped'] if aligned_ok else img_NB_norm_uint8)

    target_hw = ref_gray.shape[:2]
    if nb.shape[:2] != target_hw:
        nb = cv2.resize(nb, (target_hw[1], target_hw[0]), interpolation=cv2.INTER_LINEAR)

    # 1) Cues measured on the reference, each in [0, 1].
    brightness = ref_gray.astype(np.float32) / 255.0
    flatness = 1.0 - _grad_energy(ref_gray)
    uniformity = 1.0 - _local_contrast(ref_gray, sigma=2.0)

    # 2) NB weight map: high where the reference is uninformative.
    if prefer_nb_on_whites:
        weight = w_white * brightness + w_lowgrad * flatness + w_lowcont * uniformity
    else:
        weight = w_lowgrad * flatness + w_lowcont * uniformity
    weight = _norm01(weight)

    # Sigmoid around 0.5 to soften transitions, then an extra smoothing pass.
    weight = 1.0 / (1.0 + np.exp(-softness * (weight - 0.5)))
    weight = np.clip(_gauss(weight, 1.2), 0.0, 1.0)

    # 3) Multi-resolution blend.
    ref_unit = ref_gray.astype(np.float32) / 255.0
    nb_unit = nb.astype(np.float32) / 255.0
    _, lap_ref = _build_pyramids(ref_unit, levels)
    _, lap_nb = _build_pyramids(nb_unit, levels)

    weight_pyr = []
    level_w = weight.astype(np.float32)
    for _ in range(levels):
        weight_pyr.append(level_w)
        level_w = cv2.pyrDown(level_w)

    fused = _blend_pyramids(lap_ref, lap_nb, weight_pyr)
    return np.clip(fused * 255.0, 0, 255).astype(np.uint8)

# Version simple, très rapide (fallback)
def construire_superposition_adaptative(img_color_matched, result_align, img_NB_norm_uint8,
                                        t_white=None, blur=5, alpha_min=0.1, alpha_max=0.9):
    """Fast fallback fusion: a per-pixel alpha blend that favours NB on bright areas.

    Pixels of the gray reference at or above *t_white* (default: its 90th
    percentile) form a mask. After a Gaussian blur with an odd kernel derived
    from *blur*, the mask maps linearly to an NB weight in
    [alpha_min, alpha_max].

    The NB image is result_align['warped'] when 'ok' is truthy and a warp
    exists, else *img_NB_norm_uint8*. It is resized to the reference if needed.

    Returns a uint8 grayscale image.
    """
    ref = _to_uint8(img_color_matched)
    ref_gray = ref if ref.ndim != 3 else cv2.cvtColor(ref, cv2.COLOR_RGB2GRAY)
    use_warp = result_align.get('ok') and result_align.get('warped') is not None
    nb = _to_uint8(result_align['warped'] if use_warp else img_NB_norm_uint8)
    target_hw = ref_gray.shape[:2]
    if nb.shape[:2] != target_hw:
        nb = cv2.resize(nb, (target_hw[1], target_hw[0]), interpolation=cv2.INTER_LINEAR)

    threshold = int(np.percentile(ref_gray, 90)) if t_white is None else t_white
    ksize = blur | 1  # GaussianBlur needs an odd kernel size
    bright = cv2.GaussianBlur((ref_gray >= threshold).astype(np.float32), (ksize, ksize), 0)
    alpha = alpha_min + (alpha_max - alpha_min) * bright
    fused = (1.0 - alpha) * ref_gray.astype(np.float32) + alpha * nb.astype(np.float32)
    return np.clip(fused, 0, 255).astype(np.uint8)


# FIN DE LA NOUVELLE FONCTION ------------------******************----------------------------***************************-----------------------------***************************----------------------------

def resize_to(img, ref_shape_hw):
    """Resize *img* to ``ref_shape_hw`` = (height, width) with bilinear interpolation.

    Grayscale and multi-channel arrays take the same path; cv2.resize expects
    the target as (width, height), hence the swap.
    """
    rh, rw = ref_shape_hw
    return cv2.resize(img, (rw, rh), interpolation=cv2.INTER_LINEAR)

# ---------- Routes ----------
@app.route('/')
def home():
    """Serve the front-end page (templates/index.html)."""
    page = 'index.html'
    return render_template(page)

def _safe_upload_name(filename):
    """Reduce a client-supplied filename to a bare basename safe to join onto UPLOAD_DIR.

    Both '/' and '\\' are treated as separators. Returns '' when nothing
    usable remains (empty, '.', '..').
    """
    name = os.path.basename((filename or '').replace('\\', '/')).strip()
    return '' if name in ('', '.', '..') else name


@app.post('/upload')
def upload_compat():
    """Legacy single-file upload: save the 'image' field into UPLOAD_DIR.

    Responds 400 when no file was sent or its name is unusable, and 200 with
    a confirmation message otherwise.
    """
    import html

    file = request.files.get('image')
    if not file:
        return "Aucune image reçue", 400
    # The filename comes from the client: joining it unchanged would allow
    # '../' path traversal outside UPLOAD_DIR.
    filename = _safe_upload_name(file.filename)
    if not filename:
        return "Nom de fichier invalide", 400
    file.save(os.path.join(UPLOAD_DIR, filename))
    # The response is served as HTML, so escape the echoed name.
    return f"Image {html.escape(filename)} uploadée avec succès", 200

@app.post('/uploads')
def uploads_triplet():
    """Process an (RGB, NB[, mask]) upload and return every stage as a base64 PNG.

    Multipart fields: ``imageRGB`` and ``imageNB`` are required. ``Masque`` is
    optional and only archived. Every received file is also saved under
    UPLOAD_DIR as ``<prefix>_<timestamp>_<basename>``.

    Pipeline:
    1. Histogram equalisation of both images.
    2. Matching of the colour image's gray histogram onto the NB one.
    3. ORB/RANSAC registration of NB onto the original colour image.
    4. Multi-resolution fusion.

    Returns:
        200 JSON with 'ok', 'model', 'inlier_ratio', 'err_px' and six base64
        PNGs resized to a common size; 400 JSON when a required field is
        missing; 500 JSON with the error message on any processing failure.
    """
    try:
        rgb = request.files.get('imageRGB')
        nb = request.files.get('imageNB')
        mask = request.files.get('Masque')  # optional

        # The mask is optional; only RGB and NB are required.
        missing = [k for k, v in (('imageRGB', rgb), ('imageNB', nb)) if v is None]
        if missing:
            return jsonify(ok=False, error=f"Champs manquants: {', '.join(missing)}"), 400

        # Raw archive copy. The client-supplied filename is reduced to its
        # basename so separators in it cannot point outside UPLOAD_DIR.
        ts = datetime.now().strftime('%Y%m%d_%H%M%S_%f')
        for prefix, f in (('rgb', rgb), ('nb', nb), ('mask', mask)):
            if f:
                safe_name = os.path.basename((f.filename or '').replace('\\', '/'))
                f.save(os.path.join(UPLOAD_DIR, f"{prefix}_{ts}_{safe_name}"))

        # Decoding
        img_color = imread_file_storage(rgb)  # RGB uint8
        img_nb_any = imread_file_storage(nb)
        img_nb = img_nb_any if img_nb_any.ndim == 2 else cv2.cvtColor(img_nb_any, cv2.COLOR_RGB2GRAY)

        # Histogram equalisation; results are clipped to [0, 1] and scaled to uint8.
        img_NB_norm_f, img_color_norm_gray_f = normaliser_histogrammes(img_color, img_nb)
        nb_eq_u8 = (np.clip(img_NB_norm_f, 0, 1) * 255.0).astype(np.uint8)
        color_gray_eq_u8 = (np.clip(img_color_norm_gray_f, 0, 1) * 255.0).astype(np.uint8)

        # Match the colour image's (gray) histogram to the NB one.
        img_color_matched_f = adapter_histogramme(color_gray_eq_u8.astype(np.float32) / 255.0,
                                                  nb_eq_u8.astype(np.float32) / 255.0)
        img_color_matched_u8 = (np.clip(img_color_matched_f, 0, 1) * 255.0).astype(np.uint8)

        # Registration NB -> colour (reference = original colour image).
        result_align = approche_robuste_recalage(nb_eq_u8, img_color)
        nb_aligned_u8 = _to_uint8(result_align['warped']) if result_align.get('ok') and result_align.get('warped') is not None else nb_eq_u8

        # Final multi-resolution fusion.
        overlay_final_u8 = construire_superposition_robuste(
            img_color_matched_u8, result_align, nb_eq_u8,
            prefer_nb_on_whites=True, levels=4, softness=1.6,
            w_white=0.55, w_lowgrad=0.30, w_lowcont=0.15)

        # Bring every output image to the same size (reference = img_color_matched_u8).
        ref_h, ref_w = img_color_matched_u8.shape[:2]
        ref_shape = (ref_h, ref_w)

        original_color_rgb_u8 = resize_to(img_color, ref_shape)
        original_nb_u8        = resize_to(img_nb, ref_shape)
        color_norm_u8         = resize_to(img_color_matched_u8, ref_shape)  # gray (matched)
        nb_norm_u8            = resize_to(nb_eq_u8, ref_shape)
        nb_norm_aligned_u8    = resize_to(nb_aligned_u8, ref_shape)
        overlay_final_u8      = resize_to(overlay_final_u8, ref_shape)

        out = {
            "ok": True,
            "model": result_align.get('model_name', None),
            "inlier_ratio": float(result_align.get('inlier_ratio', 0.0)),
            "err_px": float(result_align.get('err_px', 0.0)),

            # Six images, all at the same size
            "original_color_rgb_b64": to_base64_png(original_color_rgb_u8),
            "original_nb_b64":        to_base64_png(original_nb_u8),
            "color_normalized_b64":   to_base64_png(color_norm_u8),
            "nb_norm_b64":            to_base64_png(nb_norm_u8),
            "nb_norm_aligned_b64":    to_base64_png(nb_norm_aligned_u8),
            "overlay_final_b64":      to_base64_png(overlay_final_u8)
        }
        return jsonify(out), 200

    except Exception as e:
        # Request boundary: log the full traceback, report the message to the client.
        app.logger.exception("uploads_triplet failed")
        return jsonify(ok=False, error=str(e)), 500

if __name__ == '__main__':
    # Local development server only.
    # NOTE(review): debug=True enables Flask's debug mode and interactive
    # debugger; keep the bind address on localhost.
    app.run(debug=True, host="127.0.0.1", port=8000)
