In [39]:
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import cv2
from tqdm import tqdm
from scipy.ndimage import binary_dilation

# Fonction pour vérifier si un point est dans les limites spécifiées
def is_in_bound(y, x, bounds_list, patch_size):
    """Tell whether (y, x) lies inside the box ``bounds_list`` grown by half a patch.

    ``bounds_list`` is read as ``[x_min, y_min, x_max, y_max]`` (inclusive
    limits); the box is enlarged by ``patch_size // 2`` on every side before
    the test. Returns a bool.
    """
    margin = patch_size // 2
    left, top = bounds_list[0], bounds_list[1]
    right, bottom = bounds_list[2], bounds_list[3]
    inside_columns = left - margin <= x <= right + margin
    inside_rows = top - margin <= y <= bottom + margin
    return inside_columns and inside_rows

# Fonction pour vérifier si un point est à l'intérieur de l'image
def is_in_image(y, x, source_region, patch_size):
    """Tell whether a patch centred on (y, x) fits entirely inside ``source_region``.

    ``source_region`` must be 3-D (rows, cols, channels): the shape unpacking
    below raises ValueError otherwise. Returns a bool.
    """
    margin = patch_size // 2
    rows, cols, _channels = source_region.shape
    row_ok = margin <= y <= rows - margin - 1
    col_ok = margin <= x <= cols - margin - 1
    return row_ok and col_ok

# Fonction pour calculer les centres de patches valides
def compute_valid_patch_centers(filled_region_mask, patch_size):
    """Return a boolean map of the pixels usable as source-patch centres.

    With ``delta = patch_size // 2``, a centre (y, x) is valid when the whole
    ``(2 * delta + 1)``-square patch around it lies inside the image and
    contains no hole pixel.

    Parameters
    ----------
    filled_region_mask : ndarray, shape (H, W)
        0 marks hole pixels; any other value marks known pixels.
    patch_size : int
        Patch side.

    Returns
    -------
    ndarray of bool, shape (H, W)
        True where a source patch may be centred.
    """
    delta = patch_size // 2
    H, W = filled_region_mask.shape
    hole = filled_region_mask == 0
    if delta == 0:
        # 1x1 patches: every known pixel is a valid centre. Handled apart
        # because ``[-0:]`` would select the whole axis and
        # ``binary_dilation(..., iterations=0)`` dilates until convergence.
        return ~hole
    valid_centers = np.ones((H, W), dtype=bool)
    # Exclude centres whose patch would stick out of the image.
    valid_centers[:delta, :] = False
    valid_centers[-delta:, :] = False
    valid_centers[:, :delta] = False
    valid_centers[:, -delta:] = False
    # A patch covers a square of Chebyshev radius delta around its centre, so
    # grow the hole with a full 3x3 (8-connected) element. The default
    # cross-shaped element grows a diamond and misses diagonal overlaps.
    dilated_hole = binary_dilation(
        hole, structure=np.ones((3, 3), dtype=bool), iterations=delta
    )
    valid_centers &= ~dilated_hole
    return valid_centers

# Fonction pour initialiser le NNField
def initialize_NNField(source_region, target_region, bounds_list, patch_size, valid_centers):
    """Build a random initial nearest-neighbour field (NNField) for the target region.

    With ``delta = patch_size // 2``, every target pixel (i, j) whose patch
    fits inside ``target_region`` (``delta <= i < h - delta``, likewise for
    j) gets a source centre drawn uniformly from the True entries of
    ``valid_centers``. It also gets the sum of squared differences (SSD)
    between the two patches.

    Parameters
    ----------
    source_region : ndarray, shape (H, W, C)
        Image the source patches are read from.
    target_region : ndarray, shape (h, w, C)
        Region whose pixels need a match.
    bounds_list : sequence
        Not used here; kept so the call signature is unchanged.
    patch_size : int
        Patch side; the compared patch is ``2 * delta + 1`` pixels wide.
    valid_centers : ndarray of bool, shape (H, W)
        True where a source patch may be centred.

    Returns
    -------
    ndarray of int32, shape (h, w, 3)
        ``[y_source, x_source, ssd]`` per pixel. Entries closer than ``delta``
        to the border of ``target_region`` stay 0.

    Raises
    ------
    ValueError
        If some target pixel needs a match but ``valid_centers`` has no True
        entry.
    """
    delta = patch_size // 2
    h, w, _ = target_region.shape
    NNField = np.zeros((h, w, 3), dtype="int32")
    rows = range(delta, h - delta)
    cols = range(delta, w - delta)
    if len(rows) == 0 or len(cols) == 0:
        return NNField
    valid_indices = np.column_stack(np.where(valid_centers))
    if len(valid_indices) == 0:
        raise ValueError(
            "initialize_NNField: valid_centers contains no valid source patch centre"
        )
    for i in rows:
        for j in cols:
            # One uniform draw per pixel, in scan order.
            y_source, x_source = valid_indices[np.random.randint(len(valid_indices))]
            NNField[i, j, :2] = [y_source, x_source]
            target_patch = target_region[i - delta:i + delta + 1, j - delta:j + delta + 1, :]
            source_patch = source_region[y_source - delta:y_source + delta + 1,
                                         x_source - delta:x_source + delta + 1, :]
            # NOTE(review): NNField is int32, so the float SSD is truncated
            # toward zero when stored. For images scaled to [0, 1] (as
            # patchmatch_inpainting does) this discards most of the cost's
            # precision. Fixing it needs a float field that every reader of
            # NNField accepts.
            NNField[i, j, 2] = np.sum((source_patch - target_patch) ** 2)
    return NNField

# Fonction pour mettre à jour la propagation
def update_propagation(source_region, target_region, i, j, NNField, patch_size, dx, dy, valid_centers):
    """Try to improve ``NNField[i, j]`` using the match of the neighbour (i + dy, j + dx).

    NNField stores absolute source centres. If the neighbour matches source
    centre (ys, xs), coherence suggests (i, j) matches ``(ys - dy, xs - dx)``:
    the neighbour's match shifted back by the same displacement. The candidate
    is scored by SSD and replaces the current entry only if strictly better.

    Parameters
    ----------
    source_region : ndarray, shape (H, W, C)
    target_region : ndarray, shape (h, w, C)
    i, j : int
        Target pixel to update; its patch must fit inside ``target_region``.
    NNField : ndarray, shape (h, w, 3)
        ``[y_source, x_source, ssd]`` per target pixel; modified in place.
    patch_size : int
    dx, dy : int
        Offset from (i, j) to the neighbour whose match is propagated.
    valid_centers : ndarray of bool, shape (H, W)

    Returns
    -------
    None
    """
    delta = patch_size // 2
    h, w, _ = target_region.shape
    H, W, _ = source_region.shape
    ni, nj = i + dy, j + dx
    if not (0 <= ni < h and 0 <= nj < w):
        return
    # Shift the neighbour's match back by the neighbour offset.
    y_candidate = int(NNField[ni, nj, 0]) - dy
    x_candidate = int(NNField[ni, nj, 1]) - dx
    if not (delta <= y_candidate < H - delta and delta <= x_candidate < W - delta):
        return
    if not valid_centers[y_candidate, x_candidate]:
        return
    target_patch = target_region[i - delta:i + delta + 1, j - delta:j + delta + 1, :]
    candidate_patch = source_region[y_candidate - delta:y_candidate + delta + 1,
                                    x_candidate - delta:x_candidate + delta + 1, :]
    ssd = np.sum((candidate_patch - target_patch) ** 2)
    if ssd < NNField[i, j, 2]:
        NNField[i, j, :2] = [y_candidate, x_candidate]
        NNField[i, j, 2] = ssd

# Fonction de propagation globale
def propagation(source_region, target_region, bounds_list, NNField, patch_size, valid_centers):
    """Run one forward and one backward propagation sweep over ``NNField``.

    The forward sweep visits columns left to right (rows top to bottom inside
    each column). It lets every pixel try the match of its neighbour above,
    then of its neighbour on the left. The backward sweep visits columns
    right to left (rows bottom to top) and tries the neighbour below, then
    the neighbour on the right. ``NNField`` is updated in place through
    ``update_propagation``; ``bounds_list`` is not used.
    """
    half = patch_size // 2
    # The unpacking also requires both arrays to be 3-D.
    height, width, _channels = target_region.shape
    _src_h, _src_w, _src_c = source_region.shape

    sweeps = (
        (
            range(half + 1, width - half),
            range(half + 1, height - half),
            ((0, -1), (-1, 0)),  # (dx, dy): neighbour above, then left
        ),
        (
            range(width - half - 2, half - 1, -1),
            range(height - half - 2, half - 1, -1),
            ((0, 1), (1, 0)),  # (dx, dy): neighbour below, then right
        ),
    )
    for columns, rows, steps in sweeps:
        for j in columns:
            for i in rows:
                for dx, dy in steps:
                    update_propagation(source_region, target_region, i, j, NNField,
                                       patch_size, dx, dy, valid_centers)

# Fonction de recherche aléatoire
def random_search(source_region, target_region, NNField, patch_size, valid_centers, n_iter=5):
    """Randomly probe around each pixel's current match for a lower-SSD source patch.

    For each target pixel whose patch fits inside ``target_region``, up to
    ``n_iter`` probes are made. Probe ``t`` draws a candidate inside a window
    of radius ``int(max(H, W) * 0.5 ** t)`` centred on the match the pixel had
    before the search. Each probe tries at most 10 draws to hit a valid
    source centre and skips the probe if none does. Better candidates
    overwrite ``NNField[i, j]`` in place. Uses the global ``np.random`` state.
    """
    half = patch_size // 2
    rows, cols, _ = target_region.shape
    src_rows, src_cols, _ = source_region.shape
    widest = max(src_rows, src_cols)

    def fits_in_source(cy, cx):
        """True when a patch centred on (cy, cx) stays inside source_region."""
        return (half <= cy < src_rows - half) and (half <= cx < src_cols - half)

    for i in range(half, rows - half):
        for j in range(half, cols - half):
            anchor_y, anchor_x = NNField[i, j, :2]
            best = NNField[i, j, 2]
            window = target_region[i - half:i + half + 1, j - half:j + half + 1, :]

            for step in range(n_iter):
                radius = int(widest * (0.5 ** step))
                if radius < 1:
                    break  # the window has shrunk below one pixel

                candidate = None
                for _attempt in range(10):
                    cy = anchor_y + np.random.randint(-radius, radius + 1)
                    cx = anchor_x + np.random.randint(-radius, radius + 1)
                    if fits_in_source(cy, cx) and valid_centers[cy, cx]:
                        candidate = (cy, cx)
                        break
                if candidate is None:
                    continue  # no valid draw at this radius

                cy, cx = candidate
                patch = source_region[cy - half:cy + half + 1, cx - half:cx + half + 1, :]
                score = np.sum((patch - window) ** 2)
                if score < best:
                    NNField[i, j, :2] = [cy, cx]
                    NNField[i, j, 2] = score
                    best = score


# # Fonction pour reconstruire la région cible
# def construct_target_region(source_region, target_region, bounds_list, NNField, patch_size, valid_centers):
#     h, w, c = target_region.shape
#     delta = patch_size // 2
#     result = np.copy(source_region)

#     # Remplacer la région cible par des pixels issus d'un seul patch
#     for i in range(delta, h - delta):
#         for j in range(delta, w - delta):
#             source_y, source_x = NNField[i, j, :2]

#             # Vérifier que le patch source est valide
#             if not valid_centers[source_y, source_x]:
#                 #print("Erreur : un patch invalide est utilisé")
#                 continue

#             # Copier le patch correspondant dans la région cible
#             result[
#                 bounds_list[1] + i - delta:bounds_list[1] + i + delta + 1,
#                 bounds_list[0] + j - delta:bounds_list[0] + j + delta + 1,
#                 :
#             ] = source_region[
#                 source_y - delta:source_y + delta + 1,
#                 source_x - delta:source_x + delta + 1,
#                 :
#             ]

#     return result

def construct_target_region(source_region, target_region, bounds_list, NNField, patch_size, valid_centers):
    """Paste the matched source patches over the target bounding box.

    With ``delta = patch_size // 2``, patches are centred every ``2 * delta``
    pixels (step at least 1). For ``delta > 0`` adjacent patches therefore
    overlap by one row/column, and the later one wins. The last centre on
    each axis is forced to ``length - delta - 1`` so the far edge of the box
    is covered too.

    Parameters
    ----------
    source_region : ndarray, shape (H, W, C)
        Full image; a copy of it is returned with the patches pasted in.
    target_region : ndarray, shape (h, w, C)
        Only its shape is used.
    bounds_list : sequence
        ``[x_min, y_min, x_max, y_max]``; ``(y_min, x_min)`` is the offset of
        ``target_region`` inside the image.
    NNField : ndarray, shape (h, w, 3)
        ``[y_source, x_source, ssd]`` per target pixel.
    patch_size : int
    valid_centers : ndarray of bool, shape (H, W)
        Entries whose source centre is out of range or not valid are skipped.

    Returns
    -------
    ndarray
        New image. Every pixel covered by a pasted patch is overwritten,
        including known pixels that fall inside the bounding box.
    """
    h, w, _ = target_region.shape
    delta = patch_size // 2
    result = np.copy(source_region)
    H, W = result.shape[0], result.shape[1]
    step = max(2 * delta, 1)  # patch_size 1 would otherwise give a zero step

    def patch_centres(length):
        """Centres from delta to length - delta - 1, always including the last one."""
        centres = list(range(delta, length - delta, step))
        if centres and centres[-1] != length - delta - 1:
            centres.append(length - delta - 1)
        return centres

    row_centres = patch_centres(h)
    col_centres = patch_centres(w)
    for i in row_centres:
        for j in col_centres:
            source_y, source_x = int(NNField[i, j, 0]), int(NNField[i, j, 1])
            # Bounds first: a negative index would silently wrap in valid_centers.
            if not (source_y - delta >= 0 and source_y + delta + 1 <= H and
                    source_x - delta >= 0 and source_x + delta + 1 <= W):
                continue
            if not valid_centers[source_y, source_x]:
                continue

            # Position of the patch centre in the full image.
            global_i = bounds_list[1] + i
            global_j = bounds_list[0] + j
            if not (global_i - delta >= 0 and global_i + delta + 1 <= H and
                    global_j - delta >= 0 and global_j + delta + 1 <= W):
                continue

            result[
                global_i - delta:global_i + delta + 1,
                global_j - delta:global_j + delta + 1,
                :
            ] = source_region[
                source_y - delta:source_y + delta + 1,
                source_x - delta:source_x + delta + 1,
                :
            ]

    return result


# #Fonction de diffusion pour l'inpainting initial
# def diffusion_inpainting(source_region, filled_region_mask):
#     H, W, C = source_region.shape
#     filled_region_mask_copy = np.copy(filled_region_mask)
#     source_region_copy = np.copy(source_region)
#     while np.sum(filled_region_mask_copy == 0) > 0:
#         unfilled_points = np.column_stack(np.where(filled_region_mask_copy == 0))
#         for y, x in unfilled_points:
#             sum_around = []
#             if x + 1 < W and filled_region_mask_copy[y, x + 1] == 1:
#                 sum_around.append(source_region_copy[y, x + 1, :])
#             if x - 1 >= 0 and filled_region_mask_copy[y, x - 1] == 1:
#                 sum_around.append(source_region_copy[y, x - 1, :])
#             if y + 1 < H and filled_region_mask_copy[y + 1, x] == 1:
#                 sum_around.append(source_region_copy[y + 1, x, :])
#             if y - 1 >= 0 and filled_region_mask_copy[y - 1, x] == 1:
#                 sum_around.append(source_region_copy[y - 1, x, :])
#             if len(sum_around) > 0:
#                 filled_region_mask_copy[y, x] = 1
#                 source_region_copy[y, x, :] = np.mean(sum_around, axis=0)
#     return source_region_copy

import numpy as np
from scipy.ndimage import convolve

def diffusion_inpainting(source_region, filled_region_mask, max_iterations=10):
    """Pre-fill unknown pixels by averaging their known 4-neighbours.

    Each iteration takes every unknown pixel that has at least one known
    up/down/left/right neighbour, sets it to the mean of those known
    neighbours and marks it known. The filled area therefore grows by one
    pixel per iteration, from the border of the hole inward.

    Parameters
    ----------
    source_region : ndarray, shape (H, W) or (H, W, C)
        Image to fill. Values at unknown pixels are ignored.
    filled_region_mask : ndarray, shape (H, W)
        Non-zero where the pixel is known, 0 where it must be filled.
        Not modified.
    max_iterations : int or None, default 10
        Maximum number of fill steps. ``None`` keeps going until no further
        pixel can be filled.

    Returns
    -------
    ndarray
        New array with the shape and dtype of ``source_region``. Pixels the
        fill front did not reach keep their input value.
    """
    result = np.copy(source_region)
    known = filled_region_mask.astype(bool)  # astype copies: caller's mask untouched
    kernel = np.array([[0, 1, 0],
                       [1, 0, 1],
                       [0, 1, 0]], dtype=np.float32)
    # View with an explicit channel axis so 2-D and 3-D inputs share one path;
    # writes through `planes` land in `result`.
    planes = result[:, :, np.newaxis] if result.ndim == 2 else result

    iteration = 0
    while max_iterations is None or iteration < max_iterations:
        iteration += 1
        unfilled = ~known
        if not unfilled.any():
            break
        # Number of known 4-neighbours of every pixel.
        count = convolve(known.astype(np.float64), kernel, mode='constant', cval=0.0)
        fill = (count > 0) & unfilled
        if not fill.any():
            break  # no unknown pixel touches a known one: nothing more can be filled
        for c in range(planes.shape[2]):
            # Zero the unknown pixels before summing so their values do not
            # leak into the average.
            known_values = np.where(known, planes[:, :, c], 0.0).astype(np.float64)
            neighbour_sum = convolve(known_values, kernel, mode='constant', cval=0.0)
            planes[:, :, c][fill] = neighbour_sum[fill] / count[fill]
        known |= fill

    return result


# Fonction pour construire la pyramide d'images
def build_image_pyramid(image, mask, num_levels):
    """Build image and mask pyramids ordered from coarsest to finest.

    Each coarser image is ``cv2.pyrDown`` of the previous one (about half the
    size). Each coarser mask is the previous mask resized with
    nearest-neighbour interpolation to that image's size, so mask values are
    not blended. Returns ``(images, masks)`` as lists of ``num_levels``
    entries (at least the input pair); the last entry is the input itself.
    """
    images = [image]
    masks = [mask]
    level_image, level_mask = image, mask
    for _ in range(num_levels - 1):
        level_image = cv2.pyrDown(level_image)
        height, width = level_image.shape[:2]
        level_mask = cv2.resize(level_mask, (width, height), interpolation=cv2.INTER_NEAREST)
        images.append(level_image)
        masks.append(level_mask)
    # Coarsest scale first, as the multi-resolution loop expects.
    images.reverse()
    masks.reverse()
    return images, masks

# Implémentation finale de PatchMatch
def patchmatch_inpainting(image_path, mask_path, num_levels=3, patch_size=9, n_iter=5):
    """Fill the hole of an image with multi-resolution PatchMatch and save the result.

    Steps:
    1. Load the image and the mask, then pre-fill the hole by diffusion.
    2. Build image/mask pyramids.
    3. From the coarsest level to the finest:
       - seed the hole with the upsampled result of the previous level;
       - run ``n_iter`` rounds of propagation + random search on the hole's
         bounding box;
       - paste the matched patches back into the hole only.

    Parameters
    ----------
    image_path : str or path-like
        Image to repair. Modes other than RGB/RGBA are converted first; an
        alpha channel is carried through untouched.
    mask_path : str or path-like
        Mask of the same size, read as 8-bit grayscale. 0 marks hole pixels;
        any other value marks known pixels.
    num_levels : int
        Number of pyramid levels.
    patch_size : int
        Patch side; the compared patch is ``2 * (patch_size // 2) + 1`` wide.
    n_iter : int
        PatchMatch rounds (propagation + random search) per level.

    Returns
    -------
    None
        The result is written to ``image_patchmatch_multires.png`` in the
        current working directory.
    """
    image, alpha_channel, filled_region_mask = _load_image_and_mask(image_path, mask_path)

    # Zero the hole, then pre-fill it from its border.
    source_region = np.copy(image)
    source_region[filled_region_mask == 0] = 0
    source_region = diffusion_inpainting(source_region, filled_region_mask)

    image_pyramid, mask_pyramid = build_image_pyramid(source_region, filled_region_mask, num_levels)
    prev_NNField = None
    prev_shape = None  # (rows, cols) of the level prev_NNField was computed on
    for level in range(num_levels):
        print(f"Traitement du niveau {level + 1}/{num_levels}")
        current_image = image_pyramid[level]
        current_mask = mask_pyramid[level]
        hole = current_mask == 0
        if level > 0:
            # Seed the hole with the upsampled result of the coarser level.
            prev_image = cv2.pyrUp(image_pyramid[level - 1])
            if prev_image.shape != current_image.shape:
                prev_image = cv2.resize(prev_image, (current_image.shape[1], current_image.shape[0]))
            current_image[hole] = prev_image[hole]

        hole_rows, hole_cols = np.nonzero(hole)
        if hole_rows.size == 0:
            continue  # No hole at this scale.
        y_min, y_max = hole_rows.min(), hole_rows.max()
        x_min, x_max = hole_cols.min(), hole_cols.max()
        bounds_list = [x_min, y_min, x_max, y_max]
        target_region = current_image[y_min:y_max + 1, x_min:x_max + 1, :]
        valid_centers = compute_valid_patch_centers(current_mask, patch_size)

        if prev_NNField is None:
            NNField = initialize_NNField(current_image, target_region, bounds_list, patch_size, valid_centers)
        else:
            NNField = _upscale_NNField(prev_NNField, prev_shape, current_image,
                                       target_region, patch_size, valid_centers)

        for _ in range(n_iter):
            propagation(current_image, target_region, bounds_list, NNField, patch_size, valid_centers)
            print("propag")
            random_search(current_image, target_region, NNField, patch_size, valid_centers)
            print("random")

        reconstructed = construct_target_region(current_image, target_region, bounds_list,
                                                NNField, patch_size, valid_centers)
        # construct_target_region pastes whole patches over the bounding box;
        # take only the hole pixels from it so known pixels are never altered.
        current_image = np.where(hole[:, :, None], reconstructed, current_image)
        image_pyramid[level] = current_image
        prev_NNField = NNField
        prev_shape = current_image.shape[:2]

    final_image = image_pyramid[-1]
    print("fini")
    if alpha_channel is not None:
        final_image = np.dstack([final_image, alpha_channel])
    # Round (not truncate) so known pixels round-trip exactly through /255 and *255.
    final_image_uint8 = np.clip(np.rint(final_image * 255), 0, 255).astype(np.uint8)
    Image.fromarray(final_image_uint8).save("image_patchmatch_multires.png")
    print("Inpainting terminé. Image sauvegardée sous 'image_patchmatch_multires.png'.")


def _load_image_and_mask(image_path, mask_path):
    """Read the image as floats in [0, 1] split into (RGB, alpha or None), plus a 2-D mask in [0, 1]."""
    with Image.open(image_path) as img:
        if img.mode not in ("RGB", "RGBA"):
            # Grayscale/palette images have no channel axis: bring them to RGB(A).
            img = img.convert("RGBA" if img.mode.upper().endswith("A") else "RGB")
        image = np.array(img) / 255
    with Image.open(mask_path) as mask_img:
        # Collapse any mask mode to one 8-bit channel so the mask is always 2-D.
        filled_region_mask = np.array(mask_img.convert("L")) / 255
    alpha_channel = None
    if image.shape[2] == 4:
        alpha_channel = image[:, :, 3]
        image = image[:, :, :3]
    return image, alpha_channel, filled_region_mask


def _upscale_NNField(prev_NNField, prev_shape, current_image, target_region, patch_size, valid_centers):
    """Carry a coarser level's NNField to the current level and recompute its SSDs at this scale."""
    delta = patch_size // 2
    h, w, _ = target_region.shape
    H, W, _ = current_image.shape
    NNField = cv2.resize(prev_NNField, (w, h), interpolation=cv2.INTER_NEAREST)
    # Source coordinates scale with the image size.
    NNField[:, :, 0] = (NNField[:, :, 0] * (H / prev_shape[0])).astype(int)
    NNField[:, :, 1] = (NNField[:, :, 1] * (W / prev_shape[1])).astype(int)
    valid_indices = np.column_stack(np.where(valid_centers))
    for i in range(delta, h - delta):
        for j in range(delta, w - delta):
            y, x = int(NNField[i, j, 0]), int(NNField[i, j, 1])
            usable = delta <= y < H - delta and delta <= x < W - delta and valid_centers[y, x]
            if not usable:
                # Scaling can land in the hole or off the image: draw a fresh centre.
                if len(valid_indices) == 0:
                    raise ValueError("_upscale_NNField: no valid source patch centre at this level")
                y, x = valid_indices[np.random.randint(len(valid_indices))]
                NNField[i, j, :2] = [y, x]
            target_patch = target_region[i - delta:i + delta + 1, j - delta:j + delta + 1, :]
            source_patch = current_image[y - delta:y + delta + 1, x - delta:x + delta + 1, :]
            # The coarser level's SSD is not comparable at this scale, so rescore
            # (stored as int32, like the rest of the field).
            NNField[i, j, 2] = np.sum((source_patch - target_patch) ** 2)
    return NNField

# Exécution de la fonction principale
# Run the pipeline only when executed as a script or notebook cell, not on import.
if __name__ == "__main__":
    patchmatch_inpainting("TSP.png", "filled_region_mask.png", num_levels=5, patch_size=5, n_iter=5)

Traitement du niveau 1/5
propag
random
propag
random
propag
random
propag
random
propag
random
Traitement du niveau 2/5
propag
random
propag
random
propag
random
propag
random
propag
random
Traitement du niveau 3/5
propag
random
propag
random
propag
random
propag
random
propag
random
Traitement du niveau 4/5
propag
random
propag
random
propag
random
propag
random
propag
random
Traitement du niveau 5/5
propag
random
propag
random
propag
random
propag
random
propag
random
fini
Inpainting terminé. Image sauvegardée sous 'image_patchmatch_multires.png'.
