PREDICT, INFERENCE

In [None]:
import torch
import torch.nn as nn
import numpy as np
from pathlib import Path
from PIL import Image
from tqdm import tqdm
from transformers import AutoImageProcessor, SegformerForSemanticSegmentation
import rasterio
from rasterio.merge import merge
import shutil

# Run on the first CUDA device when available, otherwise fall back to the CPU
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Inputs: PNG crops (one sub-folder per session) and the georeferenced TIFF crops
UNLABELED_IMAGES_DIR = Path("/home/bioeos/Documents/project_hub/segment_upscaling/output/crooped_ortho_png")
CROPPED_ORTHO_IMG_DIR = Path("/home/bioeos/Documents/project_hub/segment_upscaling/output/cropped_ortho/")

# Outputs: per-tile prediction TIFFs and the merged rasters
PREDICTION_TIFF_OUTPUT_DIR = Path("/home/bioeos/Documents/project_hub/segment_upscaling/output/predictions_tiff")
MERGED_PREDICTIONS_FOLDER = Path("/home/bioeos/Documents/project_hub/segment_upscaling/output/merged_predictions")

# WARNING: both output folders are wiped and recreated every time this cell runs
for _output_dir in (PREDICTION_TIFF_OUTPUT_DIR, MERGED_PREDICTIONS_FOLDER):
    if _output_dir.exists():
        shutil.rmtree(_output_dir)
    _output_dir.mkdir(exist_ok=True, parents=True)


# Fine-tuned checkpoint for the model; image processor taken from "nvidia/mit-b0"
MODEL_PATH = "./segmentation_model/checkpoint-3807/"
model = SegformerForSemanticSegmentation.from_pretrained(MODEL_PATH)
model = model.to(DEVICE)
processor = AutoImageProcessor.from_pretrained("nvidia/mit-b0", do_reduce_labels=False)

# Function to preprocess images
def preprocess_image(image_path):
    """Load an image as RGB and build the model inputs for it.

    Args:
        image_path: Path of the image file to open.

    Returns:
        A ``(image, inputs)`` tuple: the RGB PIL image and the processor output
        (PyTorch tensors) moved to ``DEVICE``.
    """
    rgb_image = Image.open(image_path).convert("RGB")
    model_inputs = processor(rgb_image, return_tensors="pt")
    return rgb_image, model_inputs.to(DEVICE)

# Function to perform inference
def predict_mask(image_path):
    """Predict a per-pixel class mask for one image.

    Args:
        image_path: Path of the image to segment.

    Returns:
        A 2-D ``np.uint8`` array of shape (image height, image width) holding,
        for every pixel, the arg-max class index shifted by +1.
    """
    image, inputs = preprocess_image(image_path)

    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits  # Shape: (1, num_labels, height, width)

    # PIL reports size as (width, height) whereas interpolate expects (height, width);
    # passing image.size unchanged would transpose the mask of any non-square image.
    width, height = image.size

    # Segformer size is 1/4 need to resize to get mask on image
    upsampled_logits = nn.functional.interpolate(
        logits,
        size=(height, width),
        mode='bilinear',
        align_corners=False,
    )
    class_indices = upsampled_logits.argmax(dim=1)[0].cpu().numpy().astype(np.uint8)
    return class_indices + 1

# Process all images


def process_inference(session_images_path: Path, predictions_output_dir: Path, merged_predictions_dir: Path):
    """Predict every image of a session, write GeoTIFF masks and merge them by majority vote.

    For each image in ``session_images_path`` the predicted mask is written to
    ``predictions_output_dir`` as ``<stem>_prediction.tif``, using the metadata of
    the TIFF with the same stem under ``CROPPED_ORTHO_IMG_DIR/<session name>``.
    Images without a matching TIFF are skipped with a warning. All files found in
    ``predictions_output_dir`` are then merged: each output pixel takes the value
    seen most often among the tiles covering it, and 0 where no tile covers it.

    Args:
        session_images_path: Folder holding the images of one session.
        predictions_output_dir: Folder receiving the per-image prediction TIFFs.
        merged_predictions_dir: Folder receiving
            ``<session name>_merged_predictions.tif``.

    Returns:
        None. Results are written to disk.
    """
    from contextlib import ExitStack

    session_name = session_images_path.name

    # Predict and write each mask immediately so they are not all held in memory
    for img_path in tqdm(sorted(session_images_path.iterdir()), desc="Processing images"):
        corresponding_tiff = Path(CROPPED_ORTHO_IMG_DIR, session_name, f"{img_path.stem}.tif")
        if not corresponding_tiff.exists():
            print(f"Warning: No matching TIFF file found for {img_path.name}, skipping...")
            continue

        mask = predict_mask(img_path)

        # Reuse the georeferencing of the source tile, stored as a single uint8 band
        with rasterio.open(corresponding_tiff) as src:
            meta = src.meta.copy()
        meta.update({"dtype": 'uint8', "count": 1, "nodata": 255})

        output_tiff_path = Path(predictions_output_dir, f"{img_path.stem}_prediction.tif")
        with rasterio.open(output_tiff_path, 'w', **meta) as dst:
            dst.write(np.where(mask == 0, 255, mask).astype(np.uint8), 1)

    # Merge all TIFF predictions into a single raster
    merged_tiff_path = Path(merged_predictions_dir, f"{session_name}_merged_predictions.tif")
    prediction_tiff_files = sorted(predictions_output_dir.iterdir())
    if not prediction_tiff_files:
        print(f"Warning: No prediction TIFF found in {predictions_output_dir}, nothing to merge.")
        return

    # ExitStack closes every opened tile, even if the merge fails half-way
    with ExitStack() as stack:
        src_files_to_mosaic = [stack.enter_context(rasterio.open(f)) for f in prediction_tiff_files]
        crs = src_files_to_mosaic[0].crs

        # Step 1: Find min and max values dynamically (plain ints avoid uint8 overflow)
        tile_ranges = []
        for src in tqdm(src_files_to_mosaic, desc="Analyzing raster values", unit="file"):
            tile = src.read(1)
            tile_ranges.append((int(tile.min()), int(tile.max())))
        global_min = min(low for low, _ in tile_ranges)
        global_max = max(high for _, high in tile_ranges)

        num_classes = global_max - global_min + 1
        print(f"✅ Detected class range: {global_min} to {global_max} ({num_classes} classes)")

        # Step 2: Merge rasters only to learn the full-size shape and transform
        mosaic, out_trans = merge(src_files_to_mosaic, method="first")
        canvas_height, canvas_width = mosaic.shape[1], mosaic.shape[2]
        del mosaic  # only its shape was needed; free it before allocating the counters

        # Step 3: One counter layer per class value, same size as the mosaic
        value_counts = np.zeros((num_classes, canvas_height, canvas_width), dtype=np.ubyte)

        # Step 4: Accumulate the votes of every tile at its position in the full raster
        for src in tqdm(src_files_to_mosaic, desc="Processing tiles", unit="file"):
            window = rasterio.windows.from_bounds(*src.bounds, transform=out_trans)
            window = window.round_offsets().round_lengths()
            row_off, col_off = int(window.row_off), int(window.col_off)

            # Clip the footprint to the canvas: rounding may push a tile slightly outside
            row_start, col_start = max(row_off, 0), max(col_off, 0)
            row_end = min(row_off + src.height, canvas_height)
            col_end = min(col_off + src.width, canvas_width)
            if row_end <= row_start or col_end <= col_start:
                continue

            tile_data = src.read(1)[
                row_start - row_off:row_end - row_off,
                col_start - col_off:col_end - col_off,
            ]
            for v in range(global_min, global_max + 1):
                value_counts[v - global_min, row_start:row_end, col_start:col_end] += (tile_data == v)

    # Step 5: Most common value per pixel (index converted back to the actual value)
    most_common_values = np.argmax(value_counts, axis=0) + global_min

    # Step 6: Pixels that received no vote become NoData (0)
    valid_pixel_mask = value_counts.any(axis=0)
    # Cast explicitly: argmax yields int64 while the output band is uint8
    final_raster = np.where(valid_pixel_mask, most_common_values, 0).astype(np.uint8)

    # Step 7: Save the final raster at `merged_tiff_path`
    with rasterio.open(
        merged_tiff_path,
        "w",
        driver="GTiff",
        height=final_raster.shape[0],
        width=final_raster.shape[1],
        count=1,  # Single-band output
        dtype=np.uint8,
        crs=crs,
        transform=out_trans,
        nodata=0,  # Explicitly set NoData to 0
    ) as dst:
        dst.write(final_raster, 1)

    # Step 8: Print summary
    print(f"✅ Merged raster saved at {merged_tiff_path}")


# Only the second session (index 1 of the sorted listing) is processed here
for session in sorted(UNLABELED_IMAGES_DIR.iterdir())[1:2]:
    prediction_dir_output = PREDICTION_TIFF_OUTPUT_DIR / session.name
    prediction_dir_output.mkdir(exist_ok=True)

    process_inference(session, prediction_dir_output, MERGED_PREDICTIONS_FOLDER)




# TODO

Utiliser l'ortho de chateaux pour faire des tests :
- Refaire avec le workflow normal
- Refaire en sauvegardant value_counts sans faire l'argmax
- Demander à Matteo où ses rasters se mergent proprement

In [None]:
import torch
import torch.nn as nn
import numpy as np
from pathlib import Path
from PIL import Image
from tqdm import tqdm
from transformers import AutoImageProcessor, SegformerForSemanticSegmentation
import rasterio
from rasterio.merge import merge
from rasterio.transform import Affine
import shutil
# Run on the first CUDA device when available, otherwise fall back to the CPU
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Inputs: PNG crops (one sub-folder per session) and the georeferenced TIFF crops
UNLABELED_IMAGES_DIR = Path("/home/bioeos/Documents/project_hub/segment_upscaling/output/crooped_ortho_png")
CROPPED_ORTHO_IMG_DIR = Path("/home/bioeos/Documents/project_hub/segment_upscaling/output/cropped_ortho/")

# Outputs: per-tile prediction TIFFs and the merged rasters
PREDICTION_TIFF_OUTPUT_DIR = Path("/home/bioeos/Documents/project_hub/segment_upscaling/output/predictions_tiff")
MERGED_PREDICTIONS_FOLDER = Path("/home/bioeos/Documents/project_hub/segment_upscaling/output/merged_predictions")

# WARNING: both output folders are wiped and recreated every time this cell runs
for _output_dir in (PREDICTION_TIFF_OUTPUT_DIR, MERGED_PREDICTIONS_FOLDER):
    if _output_dir.exists():
        shutil.rmtree(_output_dir)
    _output_dir.mkdir(exist_ok=True, parents=True)

# Fine-tuned checkpoint for the model; image processor taken from "nvidia/mit-b0"
MODEL_PATH = "./segmentation_model/checkpoint-3807/"
model = SegformerForSemanticSegmentation.from_pretrained(MODEL_PATH)
model = model.to(DEVICE)
processor = AutoImageProcessor.from_pretrained("nvidia/mit-b0", do_reduce_labels=False)

# Function to preprocess images
def preprocess_image(image_path):
    """Load an image as RGB and build the model inputs for it.

    Args:
        image_path: Path of the image file to open.

    Returns:
        A ``(image, inputs)`` tuple: the RGB PIL image and the processor output
        (PyTorch tensors) moved to ``DEVICE``.
    """
    rgb_image = Image.open(image_path).convert("RGB")
    model_inputs = processor(rgb_image, return_tensors="pt")
    return rgb_image, model_inputs.to(DEVICE)

# Function to perform inference
def predict_mask(image_path):
    """Predict a per-pixel class mask for one image.

    Args:
        image_path: Path of the image to segment.

    Returns:
        A 2-D ``np.uint8`` array of shape (image height, image width) holding,
        for every pixel, the arg-max class index shifted by +1.
    """
    image, inputs = preprocess_image(image_path)

    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits  # Shape: (1, num_labels, height, width)

    # PIL reports size as (width, height) whereas interpolate expects (height, width);
    # passing image.size unchanged would transpose the mask of any non-square image.
    width, height = image.size

    # Segformer size is 1/4 need to resize to get mask on image
    upsampled_logits = nn.functional.interpolate(
        logits,
        size=(height, width),
        mode='bilinear',
        align_corners=False,
    )
    class_indices = upsampled_logits.argmax(dim=1)[0].cpu().numpy().astype(np.uint8)
    return class_indices + 1


# Process all images
def process_inference(session_images_path: Path, predictions_output_dir: Path, merged_predictions_dir: Path):
    """Predict every image of a session, write GeoTIFF masks and merge them strip by strip.

    For each image in ``session_images_path`` the predicted mask is written to
    ``predictions_output_dir`` as ``<stem>_prediction.tif``, using the metadata of
    the TIFF with the same stem under ``CROPPED_ORTHO_IMG_DIR/<session name>``.
    Images without a matching TIFF are skipped with a warning. All files found in
    ``predictions_output_dir`` are then merged by majority vote: each output pixel
    takes the value seen most often among the tiles covering it, and 0 (NoData)
    where no tile covers it. The vote is computed on horizontal strips of at most
    ``tile_size`` rows, written as temporary rasters and merged at the end.

    Args:
        session_images_path: Folder holding the images of one session.
        predictions_output_dir: Folder receiving the per-image prediction TIFFs.
        merged_predictions_dir: Folder receiving the temporary strips and
            ``<session name>_merged_predictions.tif``.

    Returns:
        None. Results are written to disk.
    """
    from contextlib import ExitStack

    session_name = session_images_path.name

    # Predict and write each mask immediately so they are not all held in memory
    for img_path in tqdm(sorted(session_images_path.iterdir()), desc="Processing images"):
        corresponding_tiff = Path(CROPPED_ORTHO_IMG_DIR, session_name, f"{img_path.stem}.tif")
        if not corresponding_tiff.exists():
            print(f"Warning: No matching TIFF file found for {img_path.name}, skipping...")
            continue

        mask = predict_mask(img_path)

        # Reuse the georeferencing of the source tile, stored as a single uint8 band
        with rasterio.open(corresponding_tiff) as src:
            meta = src.meta.copy()
        meta.update({"dtype": 'uint8', "count": 1, "nodata": 255})

        output_tiff_path = Path(predictions_output_dir, f"{img_path.stem}_prediction.tif")
        with rasterio.open(output_tiff_path, 'w', **meta) as dst:
            dst.write(np.where(mask == 0, 255, mask).astype(np.uint8), 1)

    # Merge all TIFF predictions into a single raster
    merged_tiff_path = Path(merged_predictions_dir, f"{session_name}_merged_predictions.tif")
    prediction_tiff_files = sorted(predictions_output_dir.iterdir())
    if not prediction_tiff_files:
        print(f"Warning: No prediction TIFF found in {predictions_output_dir}, nothing to merge.")
        return

    tile_size = 20000  # maximum number of rows processed per strip
    tmp_rasters = []
    try:
        # ExitStack closes every opened tile, even if the merge fails half-way
        with ExitStack() as stack:
            src_files_to_mosaic = [stack.enter_context(rasterio.open(f)) for f in prediction_tiff_files]
            crs = src_files_to_mosaic[0].crs

            # Min-max detection (plain ints avoid uint8 overflow on global_max + 1)
            tile_ranges = []
            for src in tqdm(src_files_to_mosaic, desc="Analyzing raster values", unit="file"):
                tile = src.read(1)
                tile_ranges.append((int(tile.min()), int(tile.max())))
            global_min = min(low for low, _ in tile_ranges)
            global_max = max(high for _, high in tile_ranges)

            num_classes = global_max - global_min + 1
            print(f"✅ Detected class range: {global_min} to {global_max} ({num_classes} classes)")

            # Merge rasters only to learn the full-size shape and transform
            origin_mosaic, orig_transform = merge(src_files_to_mosaic, method="first")
            height, width = origin_mosaic.shape[1], origin_mosaic.shape[2]
            del origin_mosaic  # only its shape was needed; free it before the counters

            for strip_index, strip_row in enumerate(range(0, height, tile_size)):
                strip_height = min(tile_size, height - strip_row)

                # A strip starting at row `strip_row` keeps the same x origin:
                # the shift is (0 columns, strip_row rows) in pixel space.
                out_trans = orig_transform * Affine.translation(0, strip_row)

                count_buffer = np.zeros((num_classes, strip_height, width), dtype=np.uint16)

                for src in tqdm(src_files_to_mosaic, desc="Processing tiles", unit="file"):
                    window = rasterio.windows.from_bounds(*src.bounds, transform=out_trans)
                    window = window.round_offsets().round_lengths()
                    row_off, col_off = int(window.row_off), int(window.col_off)

                    # Clip the tile footprint to the strip (offsets are negative when the
                    # tile starts above or left of it)
                    row_start, col_start = max(row_off, 0), max(col_off, 0)
                    row_end = min(row_off + src.height, strip_height)
                    col_end = min(col_off + src.width, width)
                    if row_end <= row_start or col_end <= col_start:
                        continue  # Tile does not touch this strip: skip without reading it

                    tile_data = src.read(1)[
                        row_start - row_off:row_end - row_off,
                        col_start - col_off:col_end - col_off,
                    ]

                    # Update class frequencies
                    for v in range(global_min, global_max + 1):
                        count_buffer[v - global_min, row_start:row_end, col_start:col_end] += (tile_data == v)

                # Majority vote; pixels without any vote stay at 0 (NoData)
                most_common_values = np.zeros((strip_height, width), dtype=np.uint8)
                valid_pixel_mask = count_buffer.any(axis=0)
                most_common_values[valid_pixel_mask] = (
                    count_buffer[:, valid_pixel_mask].argmax(axis=0) + global_min
                )
                del count_buffer

                tmp_path = Path(merged_predictions_dir, f"{strip_index}_{session_name}_merged_predictions.tif")
                tmp_rasters.append(tmp_path)  # registered first so cleanup also covers partial files
                with rasterio.open(
                    tmp_path,
                    "w",
                    driver="GTiff",
                    height=most_common_values.shape[0],
                    width=most_common_values.shape[1],
                    count=1,
                    dtype=np.uint8,
                    crs=crs,
                    transform=out_trans,
                    compress="LZW",
                    nodata=0,
                ) as dst:
                    dst.write(most_common_values, 1)

        # Merge all the strips into a single large raster using rasterio.merge
        with ExitStack() as stack:
            strip_datasets = [stack.enter_context(rasterio.open(tiff)) for tiff in tmp_rasters]
            mosaic, out_trans = merge(strip_datasets, method="first")

        # Save the final merged raster
        with rasterio.open(
            merged_tiff_path,
            "w",
            driver="GTiff",
            height=mosaic.shape[1],
            width=mosaic.shape[2],
            count=1,
            dtype=np.uint8,
            crs=crs,
            transform=out_trans,
            compress="LZW",
            nodata=0,
        ) as dst:
            dst.write(mosaic[0], 1)
    finally:
        # Clean up temporary files, including after a failure
        for temp_tiff in tmp_rasters:
            if temp_tiff.exists():
                temp_tiff.unlink()

    print(f"✅ Merged raster saved at {merged_tiff_path}")

# Process the first six sessions of the sorted listing
for session in sorted(UNLABELED_IMAGES_DIR.iterdir())[:6]:
    prediction_dir_output = PREDICTION_TIFF_OUTPUT_DIR / session.name
    prediction_dir_output.mkdir(exist_ok=True)
    process_inference(session, prediction_dir_output, MERGED_PREDICTIONS_FOLDER)
