In [6]:
import os
from IPython.display import display, Image, clear_output
import ipywidgets as widgets
from PIL import Image as PILImage
from image_creator import ImageVisualizer
import matplotlib.pyplot as plt
import numpy as np
import re
import rasterio
import sys 
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))
import utils

In [7]:
class ImageVisualizer:
    """Turn multi-band Landsat rasters into 8-bit RGB, NDVI and thermal previews.

    NOTE: this definition shadows the ``ImageVisualizer`` imported from
    ``image_creator`` in the imports cell; the notebook uses this one.
    """

    def __init__(self):
        # 1-based band positions: callers subtract 1 when indexing the band
        # list. Assumes each GeoTIFF stores its bands in Landsat band-number
        # order — TODO confirm against the export that produced the files.
        self.band_mappings = {
            "Landsat_4_5_7": {"RED": 3, "GREEN": 2, "BLUE": 1, "NIR": 4, "TIR": 6},
            "Landsat_8_9": {"RED": 4, "GREEN": 3, "BLUE": 2, "NIR": 5, "TIR": 10}
        }

    def rescale_to_8bit(self, band_array, min_value=None, max_value=None):
        """Linearly stretch an array to 8-bit (0-255).

        Args:
            band_array: Numeric array; NaN marks invalid / nodata pixels.
            min_value: Value mapped to 0. Defaults to the 2nd percentile of the
                non-NaN pixels.
            max_value: Value mapped to 255. Defaults to the 98th percentile of
                the non-NaN pixels.

        Returns:
            np.ndarray of dtype uint8 with the same shape as ``band_array``.
            Values outside [min_value, max_value] are clipped; NaN pixels
            become 0. If the stretch range is empty (min == max) every pixel
            is 0.
        """
        band_array = np.asarray(band_array)
        valid_pixels = band_array[~np.isnan(band_array)]
        if valid_pixels.size == 0:
            # Nothing to stretch (every pixel becomes 0 below); only fill in
            # bounds the caller did not supply.
            if min_value is None:
                min_value = 0
            if max_value is None:
                max_value = 255
        else:
            if min_value is None:
                min_value = np.percentile(valid_pixels, 2)
            if max_value is None:
                max_value = np.percentile(valid_pixels, 98)

        value_range = max_value - min_value
        if value_range == 0:
            # Constant band: avoid 0/0, which would produce NaN before the cast.
            return np.zeros(band_array.shape, dtype=np.uint8)

        clipped = np.clip(band_array, min_value, max_value)
        scaled = (clipped - min_value) / value_range * 255
        # Replace NaN *before* casting: casting NaN to uint8 is undefined.
        return np.nan_to_num(scaled, nan=0.0).astype(np.uint8)

    def extract_landsat_version(self, filename):
        """Return the first ``Landsat<digits>`` token found in ``filename``.

        Raises:
            ValueError: if the filename contains no such token.
        """
        match = re.search(r"Landsat\d+", filename)
        if match:
            return match.group(0)
        else:
            raise ValueError(f"Landsat version could not be determined from filename: {filename}")

    def get_band_mapping(self, landsat_version):
        """Return the band-name -> 1-based band-position dict for a version string.

        Args:
            landsat_version: A string such as ``"Landsat8"`` (see
                ``extract_landsat_version``).

        Raises:
            ValueError: for versions other than Landsat 4, 5, 7, 8 or 9.
        """
        if landsat_version in ["Landsat4", "Landsat5", "Landsat7"]:
            return self.band_mappings["Landsat_4_5_7"]
        elif landsat_version in ["Landsat8", "Landsat9"]:
            return self.band_mappings["Landsat_8_9"]
        else:
            raise ValueError(f"Unsupported Landsat version: {landsat_version}")

    def generate_rgb_image(self, bands, band_mapping):
        """Build a true-colour composite.

        Args:
            bands: Sequence of band arrays, index 0 = band 1.
            band_mapping: Dict from ``get_band_mapping``.

        Returns:
            uint8 array of shape ``(*band_shape, 3)`` in R, G, B order, each
            channel independently percentile-stretched.
        """
        red = self.rescale_to_8bit(bands[band_mapping["RED"] - 1])
        green = self.rescale_to_8bit(bands[band_mapping["GREEN"] - 1])
        blue = self.rescale_to_8bit(bands[band_mapping["BLUE"] - 1])
        return np.stack((red, green, blue), axis=-1)

    def generate_combined_image(self, bands, band_mapping):
        """Build RGB, NDVI and thermal preview images.

        Args:
            bands: Sequence of band arrays, index 0 = band 1.
            band_mapping: Dict from ``get_band_mapping``.

        Returns:
            Tuple ``(rgb_image, ndvi_image, lst_image)`` of uint8 arrays.
            ``ndvi_image`` maps NDVI -1..1 to 0..255; pixels where NDVI is
            undefined (NaN input or NIR + RED == 0) become 0. ``lst_image`` is
            the raw TIR band percentile-stretched — no conversion to surface
            temperature is applied here.
        """
        rgb_image = self.generate_rgb_image(bands, band_mapping)

        # NDVI = (NIR - RED) / (NIR + RED)
        nir = bands[band_mapping["NIR"] - 1]
        red = bands[band_mapping["RED"] - 1]
        # Promote integer bands to float so `nir - red` cannot wrap around for
        # unsigned dtypes; float inputs keep their own dtype.
        nir = np.asarray(nir, dtype=np.result_type(nir, np.float32))
        red = np.asarray(red, dtype=np.result_type(red, np.float32))
        with np.errstate(divide="ignore", invalid="ignore"):
            ndvi = (nir - red) / (nir + red)
        # A zero denominator has no defined NDVI: mark it NaN (-> 0 after
        # rescaling) rather than letting ±inf clip to the ends of the range.
        ndvi = np.where(np.isfinite(ndvi), ndvi, np.nan)
        ndvi_image = self.rescale_to_8bit(ndvi, -1, 1)

        # Thermal band, stretched for display only.
        tir = bands[band_mapping["TIR"] - 1]
        lst_image = self.rescale_to_8bit(tir)

        return rgb_image, ndvi_image, lst_image



In [8]:
# Set directory paths
data_root = utils.get_data_root()
input_dir = os.path.join(data_root, 'test_jpeg_data/landsat_collection_2_request')

# Sorted so that a given "Image Index" always refers to the same file.
tif_files = [
    os.path.join(input_dir, file_name)
    for file_name in sorted(os.listdir(input_dir))
    if file_name.lower().endswith(".tif")
]

# Initialize ImageVisualizer (the class defined in the cell above).
visualizer = ImageVisualizer()

# Helper function to load bands from a TIF file
def load_bands(tif_file):
    """Read every band of a raster file into a list of arrays.

    Masked (nodata) pixels are replaced with NaN so downstream code can skip
    them with ``np.isnan``.

    Args:
        tif_file: Path to the raster file.

    Returns:
        list of np.ndarray: one array per band, in file order (index 0 is
        band 1). Integer rasters are promoted to float64; float rasters keep
        their native dtype.
    """
    bands = []
    with rasterio.open(tif_file) as src:
        for band_number in range(1, src.count + 1):
            band = src.read(band_number, masked=True)
            # NaN cannot be stored in an integer array (filling an integer
            # masked array with NaN raises), so promote integer rasters such
            # as uint16 Landsat DNs to float before filling.
            if not np.issubdtype(band.dtype, np.floating):
                band = band.astype(np.float64)
            bands.append(band.filled(np.nan))
    return bands


In [9]:
# Read-only position in `tif_files`; only the "Next Image" handler changes it.
current_index = widgets.IntText(
    value=0,
    description="Image Index:",
    disabled=True,
)

# Control buttons.
rgb_button = widgets.Button(description="Show RGB")
combined_button = widgets.Button(description="Show Combined")
next_button = widgets.Button(description="Next Image")

# Single area where every plot and status message is rendered.
output_area = widgets.Output()

# Button handlers
def show_rgb(b):
    """Render an RGB composite of the currently selected file in `output_area`.

    Args:
        b: The widget passed in by ``Button.on_click`` (unused).
    """
    index = current_index.value
    if not 0 <= index < len(tif_files):
        return
    tif_file = tif_files[index]

    # Do all the work inside the Output context: exceptions raised in a
    # widget callback (unreadable file, unrecognised Landsat version in the
    # filename, ...) are otherwise never shown to the user.
    with output_area:
        clear_output(wait=True)
        bands = load_bands(tif_file)
        landsat_version = visualizer.extract_landsat_version(os.path.basename(tif_file))
        band_mapping = visualizer.get_band_mapping(landsat_version)
        rgb_image = visualizer.generate_rgb_image(bands, band_mapping)

        fig, ax = plt.subplots()
        ax.imshow(rgb_image)
        ax.axis("off")
        ax.set_title("RGB Visualization")
        plt.show()

def show_combined(b):
    """Render RGB, NDVI and thermal panels for the selected file in `output_area`.

    Args:
        b: The widget passed in by ``Button.on_click`` (unused).
    """
    index = current_index.value
    if not 0 <= index < len(tif_files):
        return
    tif_file = tif_files[index]

    # Do all the work inside the Output context: exceptions raised in a
    # widget callback (unreadable file, unrecognised Landsat version in the
    # filename, ...) are otherwise never shown to the user.
    with output_area:
        clear_output(wait=True)
        bands = load_bands(tif_file)
        landsat_version = visualizer.extract_landsat_version(os.path.basename(tif_file))
        band_mapping = visualizer.get_band_mapping(landsat_version)
        rgb_image, ndvi_image, lst_image = visualizer.generate_combined_image(bands, band_mapping)

        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        # (image, colormap, title); cmap=None is imshow's default.
        panels = (
            (rgb_image, None, "RGB"),
            (ndvi_image, "gray", "NDVI"),
            (lst_image, "gray", "LST"),
        )
        for ax, (image, cmap, title) in zip(axes, panels):
            ax.imshow(image, cmap=cmap)
            ax.set_title(title)
            ax.axis("off")

        plt.show()

def show_next_image(b):
    """Advance `current_index` by one, or report that the list is exhausted.

    The status message is written to `output_area` (clearing the previous,
    now stale, plot) rather than to whatever output the callback happens to
    be attached to.

    Args:
        b: The widget passed in by ``Button.on_click`` (unused).
    """
    next_index = current_index.value + 1
    with output_area:
        clear_output(wait=True)
        if next_index < len(tif_files):
            # Only assign when valid, so no out-of-range value is ever set.
            current_index.value = next_index
            print(f"Processing Image {next_index + 1} of {len(tif_files)}...")
        else:
            print("No more images to display!")

# Assign button handlers. The buttons are created earlier in this same cell,
# so re-running the cell registers each handler on a fresh button only once.
rgb_button.on_click(show_rgb)
combined_button.on_click(show_combined)
next_button.on_click(show_next_image)


Processing Image 4 of 497...


In [10]:
# Assemble the viewer: control buttons, the read-only index, then the plot area.
ui = widgets.VBox(
    children=[
        widgets.HBox(children=[rgb_button, combined_button, next_button]),
        widgets.HBox(children=[current_index]),
        output_area,
    ]
)

display(ui)


VBox(children=(HBox(children=(Button(description='Show RGB', style=ButtonStyle()), Button(description='Show Co…