# In[1]:
def preprocessing(imagename, listpaths, resultpath, *, percentile=99.7,
                  filter_size=3, intensity_fraction=0.75, dbscan_eps=3,
                  dbscan_min_samples=5, n_sigma=3.0, inpaint_radius=3):
    """Stack several FITS images and inpaint bright off-centre clusters.

    Pipeline:

    1. Each file in ``listpaths`` is read from its primary HDU, clipped at its
       ``percentile``-th percentile and median-filtered (``filter_size``).
    2. The filtered images are averaged pixel-wise.
    3. Pixels brighter than ``intensity_fraction * max(stack)`` are clustered
       with DBSCAN. The cluster owning the clustered pixel closest to the image
       centre is left untouched.
    4. Every other cluster is covered by an ``n_sigma`` covariance ellipse and
       the pixels inside those ellipses are replaced by OpenCV Telea
       inpainting.
    5. The result is written to ``<resultpath>/<imagename>_fix.fits``,
       overwriting any existing file.

    Parameters
    ----------
    imagename : str
        Base name of the output file.
    listpaths : sequence of str
        FITS files to stack; all must hold 2-D images of the same shape.
    resultpath : str
        Directory the output file is written to.
    percentile : float
        Percentile above which pixel values are clipped.
    filter_size : int
        Median-filter window size.
    intensity_fraction : float
        Fraction of the stack maximum a pixel must exceed to be clustered.
    dbscan_eps, dbscan_min_samples
        DBSCAN neighbourhood radius (pixels) and minimum core-point count.
    n_sigma : float
        Half-extent of each masking ellipse, in standard deviations.
    inpaint_radius : int
        Neighbourhood radius passed to ``cv2.inpaint``.

    Returns
    -------
    numpy.ndarray
        The stacked, inpainted image (the array written to disk).

    Raises
    ------
    ValueError
        If ``listpaths`` is empty, a primary HDU holds no data, an image is
        not 2-D, or the images do not share one shape.
    """
    import os

    import cv2
    import numpy as np
    from astropy.io import fits
    from matplotlib.patches import Ellipse
    from scipy.ndimage import median_filter
    from sklearn.cluster import DBSCAN

    # ---------------- helper functions ----------------

    def apply_percentile_threshold(image_data, percentile=99.7):
        """Clip every value above the ``percentile``-th percentile to that value."""
        threshold_value = np.percentile(image_data, percentile)
        return np.clip(image_data, None, threshold_value)

    def apply_median_filter(image_data, size=3):
        """Median-filter ``image_data`` with a ``size``-wide window."""
        return median_filter(image_data, size=size)

    def preprocess_image(raw_image, percentile=99.7, size=3):
        """Percentile-clip then median-filter one image; returns a new array."""
        clipped = apply_percentile_threshold(raw_image, percentile)
        return apply_median_filter(clipped, size)

    def load_image(path):
        """Read the primary-HDU image of ``path`` as a float64 2-D array."""
        with fits.open(path) as hdu_list:
            data = hdu_list[0].data
            if data is None:
                raise ValueError(f"{path}: primary HDU contains no image data")
            # np.array copies, so the result stays valid once the file is closed;
            # float64 gives a common dtype for accumulating several bands.
            image = np.array(data, dtype=np.float64)
        if image.ndim != 2:
            raise ValueError(f"{path}: expected a 2-D image, got shape {image.shape}")
        return image

    def select_central_label(shape, coords, labels):
        """Return the label of the non-noise (x, y) point closest to the centre.

        Ties go to the first such point; returns None if every point is noise.
        """
        center_y, center_x = np.array(shape) // 2
        clustered = labels != -1  # -1 marks DBSCAN noise points
        if not np.any(clustered):
            return None
        distances = np.hypot(coords[clustered, 0] - center_x,
                             coords[clustered, 1] - center_y)
        return labels[clustered][np.argmin(distances)]

    def detect_high_intensity_labels(image, fraction, eps, min_samples):
        """Cluster the brightest pixels of ``image`` with DBSCAN.

        Returns ``(labels, central_label, coords)``: ``coords`` is an (N, 2)
        array of (x, y) positions of pixels above ``fraction * max(image)``,
        ``labels`` their DBSCAN labels, and ``central_label`` the label chosen
        by ``select_central_label`` (None when nothing was clustered).
        """
        y_coords, x_coords = np.nonzero(image > np.max(image) * fraction)
        coords = np.column_stack((x_coords, y_coords))
        if len(coords) == 0:
            # DBSCAN rejects empty input; there is simply nothing to mask.
            return np.empty(0, dtype=int), None, coords
        labels = DBSCAN(eps=eps, min_samples=min_samples).fit(coords).labels_
        return labels, select_central_label(image.shape, coords, labels), coords

    def fit_gaussian_ellipse(points, n_sigma):
        """Return ``(mean, width, height, angle_deg)`` of the covariance ellipse.

        ``width``/``height`` are full axis lengths spanning +/- ``n_sigma``
        standard deviations along the major/minor principal axes. Returns None
        for fewer than 5 points.
        """
        if len(points) < 5:
            return None
        mean = np.mean(points, axis=0)
        cov = np.cov(points, rowvar=False)
        eigenvalues, eigenvectors = np.linalg.eigh(cov)
        # Largest eigenvalue first so that index 0 is the major axis.
        order = eigenvalues.argsort()[::-1]
        eigenvalues = eigenvalues[order]
        eigenvectors = eigenvectors[:, order]
        major_x, major_y = eigenvectors[:, 0]
        angle = np.degrees(np.arctan2(major_y, major_x))
        # eigh may return tiny negative eigenvalues for degenerate clusters.
        width, height = 2 * n_sigma * np.sqrt(np.clip(eigenvalues, 0, None))
        return mean, width, height, angle

    def build_cluster_ellipses(coords, labels, central_label, n_sigma):
        """Build one Ellipse patch per cluster other than noise and the central one."""
        ellipses = []
        for k in np.unique(labels):
            if k == -1 or k == central_label:
                continue
            params = fit_gaussian_ellipse(coords[labels == k], n_sigma)
            if params is None:
                continue
            mean, width, height, angle = params
            if not (width > 0 and height > 0):
                # A zero-size axis (e.g. collinear pixels) would cover no pixel
                # and only cause divisions by zero in the mask.
                continue
            ellipses.append(Ellipse(xy=mean, width=width, height=height, angle=angle,
                                    edgecolor='red', fc='None', lw=2))
        return ellipses

    def fade_ellipses(image, ellipses, inpaint_radius):
        """Return a copy of ``image`` with the pixels inside ``ellipses`` inpainted."""
        mask = np.zeros(image.shape, dtype=bool)
        y, x = np.ogrid[:image.shape[0], :image.shape[1]]  # loop-invariant grid
        for ellipse in ellipses:
            center_x, center_y = ellipse.center
            angle = np.deg2rad(ellipse.angle)
            cos_angle, sin_angle = np.cos(angle), np.sin(angle)
            x_centered = x - center_x
            y_centered = y - center_y
            # Rotate into the ellipse frame and apply the canonical equation.
            mask |= ((x_centered * cos_angle + y_centered * sin_angle) ** 2 / (ellipse.width / 2) ** 2 +
                     (x_centered * sin_angle - y_centered * cos_angle) ** 2 / (ellipse.height / 2) ** 2 <= 1)

        dimmed_image = image.copy()
        if not mask.any():
            return dimmed_image
        # cv2.inpaint only supports 8-bit, 16-bit unsigned or 32-bit float
        # single-channel input, so pass a contiguous float32 copy of the stack.
        source = np.ascontiguousarray(image, dtype=np.float32)
        inpainted = cv2.inpaint(source, mask.astype(np.uint8),
                                inpaintRadius=inpaint_radius, flags=cv2.INPAINT_TELEA)
        dimmed_image[mask] = inpainted[mask]
        return dimmed_image

    # ---------------- main code ----------------

    if len(listpaths) == 0:
        raise ValueError("listpaths must contain at least one FITS file")

    # Sum the preprocessed bands in a float64 accumulator, then average.
    stacked = None
    for path in listpaths:
        band = preprocess_image(load_image(path), percentile, filter_size)
        if stacked is None:
            stacked = band  # fresh array from median_filter, safe to mutate
        elif band.shape != stacked.shape:
            raise ValueError(f"{path}: shape {band.shape} differs from {stacked.shape}")
        else:
            stacked += band
    superposed_image = stacked / len(listpaths)

    labels, central_label, coords = detect_high_intensity_labels(
        superposed_image, intensity_fraction, dbscan_eps, dbscan_min_samples)
    ellipses = build_cluster_ellipses(coords, labels, central_label, n_sigma)
    dimmed_image = fade_ellipses(superposed_image, ellipses, inpaint_radius)

    final_path = os.path.join(resultpath, imagename + '_fix.fits')
    fits.writeto(final_path, dimmed_image, overwrite=True)  # replaces an existing file
    return dimmed_image


# In[2]:
import os

# Build the per-band input paths with os.path.join so they resolve on any OS
# (a literal backslash is only a separator on Windows).
test_dir = '../test'
listimg = [
    os.path.join(test_dir, 'img-g', 'PGC0000282_g.fits'),
    os.path.join(test_dir, 'img-i', 'PGC0000282_i.fits'),
    os.path.join(test_dir, 'img-r', 'PGC0000282_r.fits'),
]
resultpath = '../test/'
test_image = preprocessing('PGC0000282', listimg, resultpath)