# Step 1 : Import Libraries and Set up Working path

In [15]:
# Set up working directory and project directories

# Define function to check if the code is run in Google Colab
def in_colab():
    try:
        import google.colab
        return True
    except ImportError:
        return False

import os

# Set working directory to Google Drive or Local based on usage
if in_colab():
    ## GOOGLE COLAB USERS ONLY
    ## Mount Google Drive for data retrieval
    print("Running in Google Colab")
    from google.colab import drive
    drive.mount('/content/drive')

    project_path = '/content/drive/MyDrive/TheNavySeals/'

    !pip install torch rasterio torchvision tifffile segmentation-models-pytorch -q
else:
    ## LOCAL USERS ONLY
    ## Change the path to your project directory
    print("Running Locally")
    os.chdir('D:\E_2024_P6\SEAL')
    project_path = ''

## Import packages
import os
import rasterio
import numpy as np
import torch
import torchvision
import matplotlib.pyplot as plt
from PIL import Image, ImageTk, ImageEnhance
import tifffile
from osgeo import gdal
import tkinter as tk
from tkinter import filedialog, Toplevel, Label
from torchvision.transforms import v2 as transforms
from scipy.ndimage import convolve
import random


## Define paths
source_image = os.path.join(project_path, 'data/part_2_11.TIF')
source_model = os.path.join(project_path, 'data/2_deep_learning/2x_output_model/Model_49')

reduced_image_path = os.path.join(project_path, 'data/3_postprocessing/3a_reduced_image')
os.makedirs(reduced_image_path, exist_ok=True)
reduced_image = os.path.join(reduced_image_path, '3a_reduced_image.tif')

tiled_images_path = os.path.join(project_path, 'data/3_postprocessing/3b_tiled_images')
os.makedirs(tiled_images_path, exist_ok=True)

predicted_masks_path = os.path.join(project_path, 'data/3_postprocessing/3c_predicted_masks')
os.makedirs(predicted_masks_path, exist_ok=True)

predicted_masks_georef_path = os.path.join(project_path, 'data/3_postprocessing/3d_predicted_masks_georef')
os.makedirs(predicted_masks_path, exist_ok=True)

predicted_mask_path = os.path.join(project_path, 'data/3_postprocessing/3e_predicted_mask_final')
os.makedirs(predicted_mask_path, exist_ok=True)
predicted_mask = os.path.join(predicted_mask_path, '3e_predicted_mask.tif')

mask_heatmap_path = os.path.join(project_path, 'data/3_postprocessing/3f_mask_heatmap')
os.makedirs(mask_heatmap_path, exist_ok=True)
mask_heatmap = os.path.join(mask_heatmap_path, '3f_mask_heatmap.tif')

Running Locally


# Step 2 :Definition all the fuctions

### Function 1 : Reduce Radiometric Resolution

In [16]:
# Reduce image for increased processing speed
def reduce_radiometric_resolution(input_path, output_path, input_res=11):
    '''
    Reduce the radiometric resolution of the input raster and save the output raster.

    Args:
    - input_path (string): Path to the input raster.
    - output_path (string): Path to the output raster.
    - input_res (int): Radiometric resolution of the input raster in bits.
    '''
    # Ensure the output directory exists
    output_dir = os.path.dirname(output_path)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    with rasterio.open(input_path) as src:
        # Read the number of bands
        num_bands = src.count

        # Initialize an array to store the scaled bands
        scaled_arrays = []

        for band in range(1, num_bands + 1):
            # Read the image band as a numpy array
            image_array = src.read(band, masked=True)

            # Rescale the pixel values to fit within 8-bit range (0-255)
            scaled_array = (image_array / (2**input_res - 1) * 255).astype(np.uint8)

            # Append the scaled array to the list
            scaled_arrays.append(scaled_array)

        # Stack the scaled arrays along the first axis to create a 3D array
        scaled_arrays = np.stack(scaled_arrays, axis=0)

        # Create a new raster profile with 8-bit pixel depth
        profile = src.profile
        profile.update(dtype=rasterio.uint8, count=num_bands)

        # Write the scaled arrays to a new raster file
        with rasterio.open(output_path, 'w', **profile) as dst:
            dst.write(scaled_arrays)

### Function 2:  Split and Save Raster

In [17]:
# Tile the image into 224x224 px tiles
def split_and_save_raster(input_raster_path, part_width, part_height, output_folder):
    '''
    Split a raster into multiple tiles of length part_width and height part_height, and save them in output_folder.

    Args:
    - input_raster_path: path to the input raster.
    - part_width (int): Width of each tile.
    - part_height (int): Height of each tile.
    - output_folder (str): Directory to save the rasters.
    '''
    # Open the raster
    dataset = gdal.Open(input_raster_path)

    # Get raster dimensions
    width = dataset.RasterXSize
    height = dataset.RasterYSize

    # Calculate the number of parts
    num_parts_x = width // part_width
    num_parts_y = height // part_height

    # Get the number of bands
    bands = dataset.RasterCount

    # Split the raster and save
    for i in range(num_parts_x):
        for j in range(num_parts_y):
            x_offset = i * part_width
            y_offset = j * part_height

            # Read the split region
            part = dataset.ReadAsArray(x_offset, y_offset, part_width, part_height)

            # Expand dimensions if there's only one band
            if bands == 1:
               part = np.expand_dims(part, axis=0)

            # Create a new GDAL dataset to save the split part
            driver = gdal.GetDriverByName('GTiff')
            output_path = os.path.join(output_folder, f'part_{i}_{j}.tif')
            out_dataset = driver.Create(output_path, part_width, part_height, bands, gdal.GDT_UInt16)

            # Write data to the new dataset
            for band in range(bands):
                out_band = out_dataset.GetRasterBand(band + 1)
                out_band.WriteArray(part[band])

            # Set georeference and projection
            geo_transform = list(dataset.GetGeoTransform())
            geo_transform[0] += x_offset * geo_transform[1]
            geo_transform[3] += y_offset * geo_transform[5]
            out_dataset.SetGeoTransform(tuple(geo_transform))
            out_dataset.SetProjection(dataset.GetProjection())

            # Save and close
            out_dataset.FlushCache()
            del out_dataset

    # Close the original dataset
    del dataset

### Function 3:  Check Is multiband or not?

In [18]:
def is_multiband_pil(image_path):
    """
    Check if an image is multiband or singleband using Pillow.

    Args:
        image_path (str): Path to the image file.

    Returns:
        bool: True if the image is multiband (color), False if it is singleband (grayscale).
    """
    try:
        with Image.open(image_path) as img:
            return img.mode in ("RGB", "RGBA", "CMYK", "YCbCr")
    except Exception as e:
        print(f"Error opening image: {e}")
        return False

### Function 4 : Trans and Denormalize

In [19]:
def trans(image):
    image = torchvision.transforms.functional.to_tensor(image)
    image = image.float() / 255.0 #get pixel values between 0 and 1 for uint8
    print(image)

    if image.shape[0] != 1:
        image = image[:3, :, :] #Assumes that first three channels are RGB
        image = torchvision.transforms.functional.normalize(image, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    else:
        image = torchvision.transforms.functional.normalize(image, mean=0.445, std=0.269)
    return image

def denormalize(image_tensor):
    # This is for denormalization for visualization purposes
    if image_tensor.shape[0] != 1:
        mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1).to(image_tensor.device)
        std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1).to(image_tensor.device)
    else:
        mean = torch.tensor([0.445]).view(1, 1, 1).to(image_tensor.device)
        std = torch.tensor([0.269]).view(1, 1, 1).to(image_tensor.device)
    image_tensor = image_tensor * std + mean
    return image_tensor


### Function 5 : Segmentation Dataset

In [20]:
def segmentation_dataset(data_path, transform=None):
   image_files = os.listdir(data_path)
   dataset = []

   for img_name in image_files:
         image = tifffile.imread(os.path.join(data_path, img_name)) #read .tif file

         if transform:
           image = trans(image.astype(float))

         dataset.append((image, img_name)) #append the image-mask pair in the dataset

   return dataset

### Function 6 : Predict and Save Masks

In [21]:
def predict_and_save_masks(dataset, model, device, predicted_masks_path):
    pred_masks = []
    
    for i in range(len(dataset)):
        image, img_name = dataset[i]
        image = image.to(device, dtype=torch.float).unsqueeze(0)  # Add batch dimension

        model.eval()
        with torch.no_grad():
            pred = model(image)

        pred = torch.sigmoid(pred).squeeze().cpu().numpy()  # Convert prediction to numpy array
        pred_print = (pred >= 0.5).astype(np.float32)

        pred_masks.append(pred_print)

        save_path = os.path.join(predicted_masks_path, img_name)
        tifffile.imwrite(save_path, pred_print)

        # Denormalize the image for future use if needed
        ##image_np = denormalize(image.squeeze()).cpu().numpy()

    return pred_masks

### Function 7 : Set Georeference for Masks

In [22]:
def georeference_masks(predicted_masks_path, tiled_images_path, output_path):
    """
    Georeferences mask images in predicted_masks_path to match corresponding images
    in tiled_images_path and saves them in output_path.
    Args:
    - predicted_masks_path (str): Path to directory containing mask images in TIFF format.
    - tiled_images_path (str): Path to directory containing georeferenced images in TIFF format.
    - output_path (str): Path to directory where georeferenced mask images will be saved.
    Returns:
    - None
    """
    # Ensure the output directory exists
    os.makedirs(output_path, exist_ok=True)

    # Function to georeference the mask image
    def georeference_mask(mask_path, image_path, output_path):
        with rasterio.open(image_path) as src:
            # Read metadata from the georeferenced image
            metadata = src.meta.copy()

        with rasterio.open(mask_path) as mask:
            mask_data = mask.read(1)

            # Update metadata for the mask
            metadata.update({
                'count': 1,
                'dtype': 'uint8',
                'nodata': 0
            })

            output_file = os.path.join(output_path, os.path.basename(mask_path))

            with rasterio.open(output_file, 'w', **metadata) as dst:
                dst.write(mask_data, 1)

    # Loop over all mask files and georeference them
    for mask_file in os.listdir(predicted_masks_path):
        if mask_file.endswith('.tif'):
            mask_path = os.path.join(predicted_masks_path, mask_file)
            image_path = os.path.join(tiled_images_path, mask_file)

            if os.path.exists(image_path):
                georeference_mask(mask_path, image_path, output_path)
            else:
                print(f"Warning: Corresponding image for {mask_file} not found.")

    print("Georeferencing completed.")


### Function 8 : Create the Heatmap

In [23]:
def mask_to_heatmap(input_raster_path, output_heatmap_path, window_size=5):
    """
    Converts a mask raster to a heatmap raster and saves it to a specified file.

    Parameters:
    - input_raster_path: str, path to the input mask raster file
    - output_heatmap_path: str, path to the output heatmap raster file
    - window_size: int, size of the window to calculate density (default is 5)
    """
    # Step 1: Load the mask raster
    with rasterio.open(input_raster_path) as src:
        mask_data = src.read(1)  # Assuming the mask is in the first band
        transform = src.transform
        crs = src.crs
        width = src.width
        height = src.height

    # Step 2: Identify regions with at least one '1'
    non_zero_mask = mask_data > 0

    # Step 3: Process the raster data to create a heatmap
    def calculate_density(data, window_size=5):
        # Create a window of ones
        window = np.ones((2 * window_size + 1, 2 * window_size + 1), dtype=np.float32)
        # Use convolution to calculate the density
        density = convolve(data, window, mode='constant', cval=0.0)
        return density

    # Apply the density calculation only to non-zero regions
    density_data = np.zeros_like(mask_data, dtype=np.float32)
    density_data[non_zero_mask] = calculate_density(mask_data, window_size)[non_zero_mask]

    # Normalize density data for better visualization
    if np.max(density_data) > 0:
        density_data = density_data / np.max(density_data)

    # Step 4: Save the heatmap to a new file
    out_meta = {
        'driver': 'GTiff',
        'height': height,
        'width': width,
        'count': 1,
        'dtype': 'float32',
        'crs': crs,
        'transform': transform
    }

    with rasterio.open(output_heatmap_path, 'w', **out_meta) as dst:
        dst.write(density_data, 1)

    print(f"Heatmap saved to {output_heatmap_path}")
    heatmap = Image.open(output_heatmap_path)
    print(type(heatmap))
    return heatmap

### Function 9 : Mosaic Rasters

In [24]:
def mosaic_rasters(input_folder, output_path):
    # List to hold the file paths of the rasters to be merged
    input_files = []

    # Loop through the folder and add all .tif files to the list
    for file_name in os.listdir(input_folder):
        print(file_name)
        if file_name.endswith('.tif'):
            input_files.append(os.path.join(input_folder, file_name))

    # Check if we have any input files
    if not input_files:
        raise FileNotFoundError("No .tif files found in the specified folder.")

    # Open the input files
    src_files_to_mosaic = []
    for file in input_files:
        src = gdal.Open(file)
        if src:
            src_files_to_mosaic.append(src)
        else:
            print(f"Failed to open {file}")

    # Create a virtual raster from the input files
    vrt = gdal.BuildVRT('temporary.vrt', src_files_to_mosaic)
    if vrt is None:
        raise ValueError("Failed to create virtual raster (VRT).")

    # Write the virtual raster to a new file
    gdal.Translate(output_path, vrt)

    # Cleanup
    vrt = None
    for src in src_files_to_mosaic:
        src = None

    print(f"Mosaic raster saved as {output_path}")

### Function 10 : Process and Predict

In [25]:
def process_and_predict(source_image, reduced_image, source_model, tiled_images_path, predicted_masks_path, predicted_masks_georef_path, predicted_mask, mask_heatmap, seed=0):
    # Step 1: Reduce radiometric resolution
    reduce_radiometric_resolution(source_image, reduced_image)
    
    # Step 2: Split and save raster
    split_and_save_raster(reduced_image, 224, 224, tiled_images_path)
    
    # Step 3: Set random seed for reproducibility
    torch.manual_seed(seed)
    
    # Step 4: Prepare dataset
    dataset = segmentation_dataset(tiled_images_path, transform=trans)
    
    # Step 5: Import and load the model
    model = torch.load(source_model)
    
    # Step 6: Transfer model to the appropriate device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    
    # Step 7: Predict and save masks
    predict_and_save_masks(dataset, model, device, predicted_masks_path)
    
    # Step 8: Georeference masks
    georeference_masks(predicted_masks_path, tiled_images_path, predicted_masks_georef_path)
    
    # Step 9: Mosaic the masks back together
    mosaic_rasters(predicted_masks_georef_path, predicted_mask)
    
    # Step 10: Convert the mask to a heatmap
    heatmap = mask_to_heatmap(predicted_mask, mask_heatmap, 5)
        
    return heatmap

# Step 3 : Call all functions

In [26]:
process_and_predict(
    source_image,
    reduced_image,
    source_model,
    tiled_images_path,
    predicted_masks_path,
    predicted_masks_georef_path,
    predicted_mask,
    mask_heatmap
)

tensor([[[0.2314, 0.3176, 0.3961,  ..., 0.1569, 0.1569, 0.1529],
         [0.4549, 0.3098, 0.1804,  ..., 0.1529, 0.1333, 0.1137],
         [0.4824, 0.4314, 0.2863,  ..., 0.1333, 0.1373, 0.1176],
         ...,
         [0.2549, 0.2431, 0.2588,  ..., 0.1843, 0.2667, 0.3176],
         [0.2314, 0.2627, 0.2667,  ..., 0.2902, 0.2235, 0.2039],
         [0.2667, 0.2784, 0.2588,  ..., 0.4588, 0.3529, 0.3255]]])
tensor([[[0.2863, 0.2824, 0.2902,  ..., 0.3255, 0.3176, 0.3294],
         [0.2863, 0.2824, 0.2863,  ..., 0.1961, 0.2196, 0.2118],
         [0.2471, 0.2706, 0.2706,  ..., 0.2118, 0.1843, 0.1725],
         ...,
         [0.2824, 0.3961, 0.3608,  ..., 0.3020, 0.3294, 0.3529],
         [0.2902, 0.2784, 0.2902,  ..., 0.3137, 0.3255, 0.3333],
         [0.2980, 0.2314, 0.2353,  ..., 0.3373, 0.3294, 0.3137]]])
tensor([[[0.3216, 0.2784, 0.2471,  ..., 0.3412, 0.3255, 0.3294],
         [0.2824, 0.2902, 0.2784,  ..., 0.3216, 0.3098, 0.3255],
         [0.2745, 0.3020, 0.3059,  ..., 0.3412, 0.3137, 0.

<PIL.TiffImagePlugin.TiffImageFile image mode=F size=2240x2240>

# Step 4 : UserInterface

In [27]:
class App(tk.Tk):
    def __init__(self):
        super().__init__()
        self.title("NavySeal")
        self.geometry("800x600")

        # Create a canvas
        self.canvas = tk.Canvas(self, width=800, height=600)
        self.canvas.pack(fill="both", expand=True)

        # Load the background image
        self.bg_image = Image.open("seals.jpg")
        self.bg_image = self.bg_image.resize((800, 600), Image.LANCZOS)

        # Adjust the transparency of the image
        self.bg_image = self.bg_image.convert("RGBA")
        alpha = self.bg_image.split()[3]
        alpha = ImageEnhance.Brightness(alpha).enhance(0.5)  # Adjust transparency to 50%
        self.bg_image.putalpha(alpha)

        # Convert the image with transparency to PhotoImage
        self.bg_photo = ImageTk.PhotoImage(self.bg_image)

        # Add the background image to the canvas
        self.canvas.create_image(0, 0, image=self.bg_photo, anchor="nw")

        # Keep a reference to the image to prevent it from being garbage collected
        self.canvas.image = self.bg_photo

        # Add other widgets on top of the canvas
        self.panel = tk.Label(self)
        self.canvas.create_window(400, 300, window=self.panel)

        self.btn_select = tk.Button(self, text="Select Image", command=self.select_image)
        self.canvas.create_window(400, 450, window=self.btn_select)

        self.btn_process = tk.Button(self, text="Process Data", command=self.process_data)
        self.canvas.create_window(400, 500, window=self.btn_process)

        self.result_label = tk.Label(self, text="File path: ")
        self.canvas.create_window(400, 550, window=self.result_label)

    def read_and_display_tif(self, file_path):
        global selected_image
        try:
            selected_image = Image.open(file_path)
            selected_image = selected_image.resize((500, 500))  # Resize to fit the Tkinter window
            img_tk = ImageTk.PhotoImage(selected_image)
            self.panel.config(image=img_tk)
            self.panel.image = img_tk

            self.result_label.config(text=f"File path: {file_path}")
        except Exception as e:
            print(f"Error reading TIF file: {e}")

    def select_image(self):
        file_path = filedialog.askopenfilename(filetypes=[("TIF Files", "*.tif")])
        if file_path:
            self.read_and_display_tif(file_path)

    def process_data(self):
        # Display message indicating the image is saved
        info_popup = Toplevel(self)
        info_popup.title("Image Saved")
        info_label = Label(info_popup, text="Your image is saved in the folder.")
        info_label.pack(padx=20, pady=20)

if __name__ == "__main__":
    app = App()
    app.mainloop()

In [ ]:
# # Global variable to store the selected image file path
# selected_image = None
# 
# class App(tk.Tk):
#     def __init__(self):
#         super().__init__()
#         self.title("NavySeal")
#         self.geometry("800x600")
# 
#         # Create a canvas
#         self.canvas = tk.Canvas(self, width=800, height=600)
#         self.canvas.pack(fill="both", expand=True)
# 
#         # Load the background image
#         self.bg_image = Image.open("seals.jpg")
#         self.bg_image = self.bg_image.resize((800, 600), Image.LANCZOS)
# 
#         # Adjust the transparency of the image
#         self.bg_image = self.bg_image.convert("RGBA")
#         alpha = self.bg_image.split()[3]
#         alpha = ImageEnhance.Brightness(alpha).enhance(0.5)  # Adjust transparency to 50%
#         self.bg_image.putalpha(alpha)
# 
#         # Convert the image with transparency to PhotoImage
#         self.bg_photo = ImageTk.PhotoImage(self.bg_image)
# 
#         # Add the background image to the canvas
#         self.canvas.create_image(0, 0, image=self.bg_photo, anchor="nw")
# 
#         # Keep a reference to the image to prevent it from being garbage collected
#         self.canvas.image = self.bg_photo
# 
#         # Add other widgets on top of the canvas
#         self.panel = tk.Label(self)
#         self.canvas.create_window(400, 300, window=self.panel)
# 
#         self.btn_select = tk.Button(self, text="Select Image", command=self.select_image)
#         self.canvas.create_window(400, 450, window=self.btn_select)
# 
#         self.btn_process = tk.Button(self, text="Process Data", command=self.process_data)
#         self.canvas.create_window(400, 500, window=self.btn_process)
# 
#         self.result_label = tk.Label(self, text="File path: ")
#         self.canvas.create_window(400, 550, window=self.result_label)
# 
#     def read_and_display_tif(self, file_path):
#         try:
#             img = Image.open(file_path)
#             img = img.resize((500, 500))  # Resize to fit the Tkinter window
#             img_tk = ImageTk.PhotoImage(img)
#             self.panel.config(image=img_tk)
#             self.panel.image = img_tk
# 
#             self.result_label.config(text=f"File path: {file_path}")
#         except Exception as e:
#             print(f"Error reading TIF file: {e}")
# 
#     def select_image(self):
#         global selected_image
#         file_path = filedialog.askopenfilename(filetypes=[("TIF Files", "*.tif")])
#         if file_path:
#             selected_image = file_path
#             self.read_and_display_tif(file_path)
# 
#     def process_data(self):
#         global selected_image
#         if selected_image is None:
#             self.result_label.config(text="No image selected")
#             return
# 
#         # Display "Please Wait" popup
#         please_wait_popup = Toplevel(self)
#         please_wait_popup.title("Processing")
#         please_wait_label = Label(please_wait_popup, text="Currently in Navy SEALs, please wait.")
#         please_wait_label.pack(padx=10, pady=10)
# 
#         # Simulate data processing
#         self.after(2000, lambda: self.finish_processing(please_wait_popup))
# 
#     def finish_processing(self, popup):
#         global selected_image
#         # Close the popup
#         popup.destroy()
# 
#         # Simulate processed data (for now just creating an inverted image)
#         # Replace this with actual data processing logi       
#         processed_data = process_and_predict(
#             selected_image,
#             reduced_image,
#             source_model,
#             tiled_images_path,
#             predicted_masks_path,
#             predicted_masks_georef_path,
#             predicted_mask,
#             mask_heatmap
#             )       
#         
#         img = processed_data
#         img = img.resize((500,500))  # Resize to fit the Tkinter window
#         img_tk = ImageTk.PhotoImage(img)
#         self.panel.config(image=img_tk)
#         self.panel.image = img_tk
# 
#         self.result_label.config(text="Processing complete")
# 
# if __name__ == "__main__":
#     app = App()
#     app.mainloop()