# Création de tuiles .npy à partir de fichiers Ortho .jp2 et BDTopo .shp

Suivez ces étapes pour traiter correctement vos fichiers :

> * Choisissez les fichiers .jp2 que vous souhaitez traiter
>   - Commencez par sélectionner les fichiers .jp2 avec lesquels vous souhaitez travailler.
>   - Une fois sélectionnés, stockez ces fichiers dans le répertoire approprié, qui est 'DATA/Selected/jp2'.

> * Renommez les fichiers .shp et .shx liés à la région de l'image sélectionnée
>   - Les fichiers .shp et .shx correspondant à la région couverte par votre image sélectionnée doivent être renommés.
>   - Ajoutez le numéro du département au début de ces noms de fichier, en suivant ce format : "{numéro_département}_BATIMENT.shp".
>   - Après avoir renommé les fichiers, placez-les dans le dans le répertoire approprié, qui est 'DATA/Selected/shp'

> * Répétez le processus pour chaque image et région souhaitées
>   - Vous pouvez répéter le même processus pour autant d'images et de régions que vous le souhaitez.
>   - Si vous suivez correctement les étapes, la structure de votre répertoire devrait ressembler à ceci :


<img src='Screenshot 2023-05-22 at 09.10.58.png'>

In [1]:
import os
import glob
import rasterio
import numpy as np
import geopandas as gpd
import pandas as pd
from rasterio.mask import mask
from PIL import Image, ImageDraw
from skimage.draw import polygon
from pathlib import Path
import copy

class RGBImage:
    """
    Class that represents an RGB Image file.

    Attributes:
        file_path (str): Path to the image file.
        department_number (str): Department number.
        bounds (list): Bounding coordinates [left, bottom, right, top].
        rgb (np.array): Numpy array representing the RGB image.

    """
    def __init__(self, path):
        """
        Initialize an RGBImage instance.

        Args:
            file_path (str): Path to the image file.

        """
        self.path = path
        self.rgb = None
        self.xrange = None
        self.yrange = None
        self.resolution = None
        self.bounds = None
        self.crs = None
        self.transform = None
        self.department_number = os.path.basename(path).split("-")[0]
        self.rgb_id = os.path.basename(path).split(".")[0]
        

    def load(self):
        """
        Load the image file and stores the data in the rgb attribute.
        
        This function reads a rasterio object from the file at self.path, 
        extracts the RGB channels, the resolution, bounds, CRS, transform, 
        width, and height, and stores them as attributes of the object.
        """
        with rasterio.open(self.path) as src:
            self.rgb = src.read([1,2,3]).astype(np.float32)
            self.resolution = src.res
            self.bounds = src.bounds
            self.crs = src.crs
            self.transform = src.transform
            self.xrange = [self.bounds.left, self.bounds.right]
            self.yrange = [self.bounds.bottom, self.bounds.top]
            self.img_size = (src.height, src.width)

    def save(self, destination_path):
        """
        Save the RGB image to the specified directory.

        This function writes the RGB channels of the image stored in self.rgb 
        into a new TIFF file at destination_path, using the metadata stored in 
        the object's attributes.

        Args:
            destination_path (str): Path to the directory where the image should be saved.
        """
        # Define metadata of the output file
        profile = {
            'driver': 'GTiff',
            'height': self.rgb.shape[1],
            'width': self.rgb.shape[2],
            'count': 3,
            'dtype': self.rgb.dtype,
            'crs': self.crs,
            'transform': self.transform,
        }
        rgb_filename = f"rgb_{self.rgb_id}.tif"
        with rasterio.open(os.path.join(destination_path, rgb_filename), 'w', **profile) as dst:
            dst.write(self.rgb)


class Building():
    """
    Load all shapefiles corresponding to the given department number.
        
    This function reads all shapefiles found in the paths in self.paths_shp, 
    adds a 'Category' and 'Type' column to each GeoDataFrame, and concatenates 
    all GeoDataFrames into a single one.
        
    Returns:
        GeoDataFrame: The concatenated GeoDataFrame containing all shapefiles data.
    """    
    def __init__(self, department_number, shp_root_dir, xrange, yrange, img_size, centroid_loaded_path, image):
        self.paths_shp = glob.glob(os.path.join(shp_root_dir, f"{department_number}-*.shp"))
        self.xrange = xrange
        self.yrange = yrange
        self.map_size = img_size
        self.bdtopo = self.load_all_shapefiles()
        self.bdtopo_batiment = self.bdtopo[(self.bdtopo['Category'] == 'BATI') & (self.bdtopo['Type'] == 'BATIMENT')].copy()
        self.centroids = pd.DataFrame()
        
        self.preprocess_data(centroid_loaded_path, image)
        
        

    def load_all_shapefiles(self):
        bdtopo = []
        for path in self.paths_shp:
            data = gpd.read_file(path)
            category = path.split('/')[-2] 
            file_info = path.split('/')[-1][:-4]
            data.insert(0, 'Category', category)
            data.insert(1, 'Type', file_info)
            bdtopo.append(data)
            
        bdtopo = pd.concat(bdtopo).reset_index(drop=True)
        return bdtopo


    def isInMap(self):
        """
        Generate a function that checks if a given polygon is within the bounds of the image.
        
        This function returns a function that takes a polygon and returns True if the polygon's 
        centroid is within the image bounds, and False otherwise.
        
        Returns:
            Function: The generated function.
        """
        def my_function(polygon):
            x, y = polygon.centroid.x, polygon.centroid.y
            if self.xrange[0] < x and self.xrange[1] > x and self.yrange[0] < y and self.yrange[1] > y:
                return True
            else:
                return False
        return my_function
    

    def convert_centroid(self):
        """
        Generate a function that converts a polygon's centroid coordinates to image coordinates.
        
        This function returns a function that takes a polygon and calculates its centroid's 
        coordinates in the image coordinate system.
        
        Returns:
            Function: The generated function.
        """
        def my_function(polygon):
            x, y = polygon.centroid.x, polygon.centroid.y
            x_new = (x - self.xrange[0]) / (self.xrange[1] - self.xrange[0]) * self.map_size[0]
            y_new = self.map_size[0] - (y - self.yrange[0]) / (self.yrange[1] - self.yrange[0]) * self.map_size[0]
            return [x_new, y_new]
        return my_function


    def convert_polygon(self):
        """
        Generate a function that converts a polygon's coordinates to image coordinates.
        
        This function returns a function that takes a polygon and calculates its 
        coordinates in the image coordinate system.
        
        Returns:
            Function: The generated function.
        """
        def my_function(polygon):
            if polygon.wkt.lower()[:7] == "polygon":
                x, y = polygon.exterior.coords.xy
                x = x.tolist()
                y = y.tolist()
            elif polygon.wkt[:10].lower() == "linestring":
                x, y = polygon.coords.xy
                x = x.tolist()
                x += x[::-1]
                y = y.tolist()
                y += y[::-1]
            else:
                x = [1, 2]
                y = [1, 2]
            x = np.array(x)
            y = np.array(y)
            x_new = (x - self.xrange[0]) / (self.xrange[1] - self.xrange[0]) * self.map_size[0]
            y_new = self.map_size[0] - (y - self.yrange[0]) / (self.yrange[1] - self.yrange[0]) * self.map_size[0]
            
            return list(zip(x_new, y_new))
            
        return my_function

    def save_centroids(self, destination_path, rgb_image):
        """
        Save the centroids of the building polygons to a CSV file.

        Args:
            destination_path (str): The path to the directory where the CSV file should be saved.
            rgb_image (RGBImage): The RGBImage object that corresponds to the buildings.
        """
        print("save_centroids method called!")
        if not self.centroids.empty:
            centroids_df = pd.DataFrame({
                'X': self.centroids['xcentroid'],
                'Y': self.centroids['ycentroid']
            })
            centroids_path = os.path.join(destination_path, f"centroids_{rgb_image.rgb_id}.csv")

            centroids_df.to_csv(centroids_path, index=False)

            print('Checking if centroids_df is still filled after saving:')
            print(not centroids_df.empty)


    def preprocess_data(self, centroid_loaded_path, image):
        """
        Preprocess the loaded data.

        This function filters the GeoDataFrame to include only the polygons that are within 
        the image bounds and converts their coordinates and centroids to image coordinates. 
        It also saves the centroids to a CSV file if a path is provided.

        Args:
            centroid_loaded_path (str): The path to the directory where the centroids should be saved.
            image (RGBImage): The RGBImage object that corresponds to the buildings.
        """
        self.centroid_loaded_path = centroid_loaded_path
        self.image = image

        self.bdtopo = self.bdtopo[self.bdtopo['geometry'].apply(self.isInMap())].copy()
        self.bdtopo['centroid'] = self.bdtopo['geometry'].apply(self.convert_centroid())
        self.bdtopo['xcentroid'] = self.bdtopo['centroid'].apply(lambda x: x[0])
        self.bdtopo['ycentroid'] = self.bdtopo['centroid'].apply(lambda x: x[1])
        
        self.centroids = copy.deepcopy(self.bdtopo[['xcentroid', 'ycentroid']])
        if self.centroid_loaded_path is not None and self.image is not None:
            self.save_centroids(self.centroid_loaded_path, self.image)

        bdtopo_point_mask = self.bdtopo[self.bdtopo['geometry'].apply(lambda x: x.wkt.lower()[:5] == "point")]
        self.bdtopo = self.bdtopo[self.bdtopo['geometry'].apply(lambda x: x.wkt.lower()[:7] == "polygon" or x.wkt[:10].lower() == "linestring")]
        self.bdtopo['polygon'] = self.bdtopo['geometry'].apply(self.convert_polygon())
        self.bdtopo = self.bdtopo.groupby('Type').agg({'polygon': list})
        print(self.bdtopo.shape)


        # Concatenate all the polygon coordinates into one list
        self.polygon_coord = []
        for _, row in self.bdtopo.iterrows():
            self.polygon_coord.extend(row['polygon'])
    


    def create_mask(self):
        """
        Create a mask of the building polygons.

        This function creates a binary image of the same size as the original image, 
        where the pixels inside the building polygons are set to 1, and the rest to 0.
        
        Returns:
            np.array: The created mask.
        """
        width, height = self.map_size
        
        img = Image.new('L', (width, height), 0)
        for p in self.polygon_coord:
            ImageDraw.Draw(img).polygon(p, outline=1, fill=1)
        mask = np.array(img)
        return mask


    def save(self, rgb_image, mask, destination_path):
        """
        Save the building mask to a TIFF file.

        This function writes the given mask into a new TIFF file at destination_path, 
        using the metadata from the given RGBImage object.

        Args:
            rgb_image (RGBImage): The RGBImage object that corresponds to the buildings.
            mask (np.array): The mask to save.
            destination_path (str): The path to the directory where the TIFF file should be saved.
        """
        # Define the properties of the image
        profile = {
            'driver': 'GTiff',
            'height': mask.shape[0],
            'width': mask.shape[1],
            'count': 1,
            'dtype': mask.dtype,
            'crs': rgb_image.crs,
            'transform': rgb_image.transform
        }

        # Define the output file name including the department_number and the rgb_id
        mask_filename = f"mask_{rgb_image.rgb_id}.tif"
        
        # Save the mask as a TIFF file
        with rasterio.open(os.path.join(destination_path, mask_filename), 'w', **profile) as dst:
            dst.write(mask, 1)




In [2]:
# directories
jp2_directory = 'DATA/Selected/jp2'
shp_root_dir = 'DATA/Selected/shp'
rgb_loaded_path = "./processed_data/rgbs"
mask_loaded_path = "./processed_data/masks"
centroid_loaded_path = "./processed_data/centroids"  

if not os.path.exists(rgb_loaded_path):
    os.makedirs(rgb_loaded_path)
if not os.path.exists(mask_loaded_path): 
    os.makedirs(mask_loaded_path)
if not os.path.exists(centroid_loaded_path):
    os.makedirs(centroid_loaded_path)



# Load each jp2 file and the corresponding shapefile
for jp2_file in Path(jp2_directory).rglob("*.jp2"):
    image = RGBImage(str(jp2_file))
    image.load()
    
    building = Building(image.department_number, shp_root_dir, image.xrange, image.yrange, image.img_size, centroid_loaded_path, image)
    image.save(rgb_loaded_path)

    mask = building.create_mask()
    building.save(image, mask, mask_loaded_path)

save_centroids method called!
Checking if centroids_df is still filled after saving:
True
(1, 1)
save_centroids method called!
Checking if centroids_df is still filled after saving:
True
(1, 1)
save_centroids method called!
Checking if centroids_df is still filled after saving:
True
(1, 1)
save_centroids method called!
Checking if centroids_df is still filled after saving:
True
(1, 1)


In [5]:
import os
import glob
import rasterio
import numpy as np
import pandas as pd
import random
from collections import Counter
from typing import Tuple
from shapely.geometry import Point
from rasterio.windows import from_bounds

import pandas as pd

class ImageSplitter:
    def __init__(self, nbPxX: int, nbPxY: int, white_threshold: float = 0.15):
        """
        Initialize an ImageSplitter instance.

        Args:
            nbPxX (int): The number of pixels in the X direction for each subimage.
            nbPxY (int): The number of pixels in the Y direction for each subimage.
            white_threshold (float, optional): The fraction of white pixels allowed in a subimage. 
                                               Subimages with more white pixels are ignored. Defaults to 0.15.
        """

        self.nbPxX = nbPxX
        self.nbPxY = nbPxY
        self.white_threshold = white_threshold

    def split_img(self, img_path: str, dest_dir: str) -> dict:
        """
        Split the image into subimages and save them as .npy files.

        Args:
            img_path (str): The path to the original image file.
            dest_dir (str): The directory where the subimages will be saved.

        Returns:
            dict: A dictionary mapping the filepaths of the saved subimages to their category.
        """
        filename = os.path.splitext(os.path.basename(img_path))[0]
        img_type, image_id = filename.split("_", 1)
        
        if not os.path.exists(dest_dir):
            os.makedirs(dest_dir)

        with rasterio.open(img_path) as src:
            width = src.width
            height = src.height

            num_subimages_x = width // self.nbPxX
            num_subimages_y = height // self.nbPxY

            tile_info = {}

            for i in range(num_subimages_y):
                for j in range(num_subimages_x):
                    if (i+1) * self.nbPxY > height or (j+1) * self.nbPxX > width:
                        continue

                    window = rasterio.windows.Window(j * self.nbPxX, i * self.nbPxY, self.nbPxX, self.nbPxY)
                    subimage = src.read(window=window)

                    white_fraction = np.mean(subimage == 255)
                    if white_fraction > self.white_threshold:
                        continue

                    subimage_filename = f"{img_type}_{image_id}_{i:03d}_{j:03d}.npy"
                    subimage_filepath = os.path.join(dest_dir, subimage_filename)

                    subimage = np.transpose(subimage, (1, 2, 0))
                    np.save(subimage_filepath, subimage)

                    tile_info[subimage_filepath] = self.classify_tile(j * self.nbPxX, i * self.nbPxY, self.nbPxX, self.nbPxY, image_id) 

            return tile_info

    # Method for classifying tiles
    def classify_tile(self, start_x, start_y, width, height, image_id) -> int:
        """
        Classify a tile based on the number of centroids it contains.

        Args:
            start_x (int): The X-coordinate of the top-left corner of the tile.
            start_y (int): The Y-coordinate of the top-left corner of the tile.
            width (int): The width of the tile.
            height (int): The height of the tile.
            image_id (str): The identifier of the image.

        Returns:
            int: The category of the tile (0 if it contains no centroids, 1 if it contains less than 6 centroids, 2 otherwise).
        """
        num_centroids = self.count_centroids(start_x, start_y, width, height, image_id)
        
        if num_centroids == 0:
            return 0
        elif num_centroids < 6:
            return 1
        else:
            return 2

    # Method for counting centroids
    def count_centroids(self, start_x, start_y, width, height, image_id) -> int:
        """
        Count the number of centroids within a tile.

        Args:
            start_x (int): The X-coordinate of the top-left corner of the tile.
            start_y (int): The Y-coordinate of the top-left corner of the tile.
            width (int): The width of the tile.
            height (int): The height of the tile.
            image_id (str): The identifier of the image.

        Returns:
            int: The number of centroids within the tile.
        """

        # Load the centroid data from CSV file
        centroid_data = pd.read_csv(f"processed_data/centroids/centroids_{image_id}.csv")

        # Calculate how many centroids are within the tile's boundaries
        num_centroids = ((centroid_data['X'] >= start_x) & 
                         (centroid_data['X'] < start_x + width) & 
                         (centroid_data['Y'] >= start_y) & 
                         (centroid_data['Y'] < start_y + height)).sum()

        return num_centroids

    # Modify balance_categories
    def balance_categories(self, all_tiles: dict):
        """
        Balance the number of tiles in each category by deleting tiles from overrepresented categories.

        Args:
            all_tiles (dict): A dictionary mapping the filepaths of all tiles to their category.

        Returns:
            dict: A dictionary mapping the filepaths of the remaining tiles to their category.
        """
        tile_categories = list(all_tiles.values())
        category_counts = Counter(tile_categories)
        min_count = min(category_counts.values())

        selected_tiles = {tile: category for tile, category in all_tiles.items() if category_counts[category] > min_count}
        unselected_tiles = {tile: category for tile, category in all_tiles.items() if category_counts[category] <= min_count}

        for tile in unselected_tiles.keys():
            os.remove(tile)

        print(f"Deleted {len(unselected_tiles)} tiles. {len(selected_tiles)} tiles remaining.")
        
        return selected_tiles



In [4]:
# Instantiate the class
image_splitter = ImageSplitter(512, 512)

# Define the input directories
rgb_dir = "processed_data/npy_tiles/rgbs"
mask_dir = "processed_data/npy_tiles/masks"

# Create output directories if they don't exist
if not os.path.exists(rgb_dir):
    os.makedirs(rgb_dir)
if not os.path.exists(mask_dir): 
    os.makedirs(mask_dir)

# Define the input directories
img_dirs = {
    "processed_data/rgbs": rgb_dir,
    "processed_data/masks": mask_dir,
}

# Split all images
print("Splitting images...")
all_tiles = {}
for img_input_dir, img_output_dir in img_dirs.items():
    img_paths = glob.glob(os.path.join(img_input_dir, "*.tif"))
    for img_path in img_paths:
        tiles = image_splitter.split_img(img_path, img_output_dir)
        all_tiles.update(tiles)  # add the tiles from the current image to the dictionary
print(f"Total tiles created: {len(all_tiles)}")

# Balance the categories
print("Balancing categories...")
balanced_tiles = image_splitter.balance_categories(all_tiles)
print(f"Total tiles after balancing: {len(balanced_tiles)}")

# Inform that the script has finished running
print("Done!")


Splitting images...
Total tiles created: 18432
Balancing categories...
Deleted 3468 tiles. 14964 tiles remaining.
Total tiles after balancing: 14964
Done!
