# Land Cover Mapping with Prithvi EO Model
This notebook demonstrates the workflow for downloading satellite data, preprocessing, running inference with the Prithvi EO segmentation model, and merging results into a single GeoTIFF.

<a href="https://colab.research.google.com/github/easare377/Prithvi-EO-Segmentation/blob/main/create_landmap_prithvi.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install terratorch

# Install Dependencies and Mount Google Drive
Install required packages and mount Google Drive for data access and storage.

In [None]:
import os
import numpy as np
import torch
from osgeo import gdal
from pathlib import Path

In [None]:
# Mount the user's Google Drive at /content/drive so the other cells can read
# and write files under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:

import requests
from tqdm import tqdm
import pathlib


def download_file(url, dest_path, chunk_size=1024*1024, timeout=60):
    """
    Download a large file from a URL with a progress bar.

    The body is streamed to disk chunk by chunk, so the whole file is never
    held in memory.

    Args:
        url (str): File URL.
        dest_path (str): Destination file path. It is opened directly for
            writing, so its parent directory must already exist.
        chunk_size (int): Download chunk size in bytes.
        timeout (float or tuple): Seconds to wait for the server (connect and
            per-read) before giving up, as accepted by ``requests.get``.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
    """
    # The context manager releases the connection even if writing fails.
    with requests.get(url, stream=True, timeout=timeout) as response:
        # Check the status before touching the destination so an HTTP error
        # page is never saved as if it were the requested file.
        response.raise_for_status()
        # Without a content-length header the size is unknown; None makes tqdm
        # show a plain byte counter instead of a bar with a total of 0.
        total = int(response.headers.get('content-length', 0)) or None
        with open(dest_path, 'wb') as file, tqdm(
            desc=f"Downloading {dest_path}",
            total=total,
            unit='B',
            unit_scale=True,
            unit_divisor=1024,
        ) as bar:
            for data in response.iter_content(chunk_size=chunk_size):
                size = file.write(data)
                bar.update(size)

# Download and Prepare Satellite Data
Download large satellite images and prepare them for processing and inference.

In [None]:
def zero_pad_array(input_array, new_shape):
    """
    Grow the first two axes of an array to a target size by appending zeros.

    Zeros are only added after the existing data (bottom and right); axes
    beyond the first two are left untouched. An axis that is already at least
    as large as requested is not changed, so the array is never shrunk.

    Args:
        input_array (numpy.ndarray): Input array of shape (height, width, ...).
        new_shape (tuple): Desired (new_height, new_width, ...); only the
            first two entries are used.
    Returns:
        numpy.ndarray: Array whose first two axes are at least
        (new_height, new_width).
    """
    cur_rows, cur_cols = input_array.shape[:2]
    want_rows, want_cols = new_shape[:2]
    # A negative difference means the axis is already big enough.
    extra_rows = max(want_rows - cur_rows, 0)
    extra_cols = max(want_cols - cur_cols, 0)
    trailing_axes = input_array.ndim - 2
    padding = [(0, extra_rows), (0, extra_cols)] + [(0, 0)] * trailing_axes
    return np.pad(input_array, padding, mode='constant', constant_values=0)

def crop_np_array(arr, left, top, right, bottom):
    """
    Return the sub-array covered by a pixel bounding box.

    The box is half-open: rows ``top`` to ``bottom - 1`` and columns ``left``
    to ``right - 1`` are kept. Trailing axes are preserved.

    Args:
        arr (numpy.ndarray): Input array of shape (height, width, ...).
        left (int): First column to keep.
        top (int): First row to keep.
        right (int): Column just past the last one to keep.
        bottom (int): Row just past the last one to keep.

    Returns:
        numpy.ndarray: ``arr[top:bottom, left:right]``.

    Raises:
        ValueError: If the box is negative, empty or inverted, or if it
            reaches beyond the array.
    """
    box_is_well_formed = left >= 0 and top >= 0 and right > left and bottom > top
    if not box_is_well_formed:
        raise ValueError("Invalid bounding box coordinates.")

    n_rows, n_cols = arr.shape[:2]
    if bottom > n_rows or right > n_cols:
        raise ValueError("Bounding box exceeds the array size.")

    return arr[top:bottom, left:right]

def write_geoTiff_bands(output_raster, np_array, position):
    """
    Write every band of a (height, width, bands) array into a raster.

    Band ``k`` of the array (0-based, last axis) goes to raster band ``k + 1``.
    The raster cache is flushed once after all bands are written.

    Parameters:
    - output_raster (gdal.Dataset): Output GeoTIFF dataset.
    - np_array (numpy.ndarray): Array to write; it must have three axes
      because the band count is read from ``shape[2]``.
    - position (tuple): Top-left pixel position (left, top) at which the
      array is placed in the raster.
    """
    x_offset, y_offset = position
    n_rows, n_cols = np_array.shape[0], np_array.shape[1]
    n_bands = np_array.shape[2]
    for band_number in range(1, n_bands + 1):
        target_band = output_raster.GetRasterBand(band_number)
        plane = np_array[:, :, band_number - 1].reshape((n_rows, n_cols))
        target_band.WriteArray(plane, xoff=x_offset, yoff=y_offset)
    output_raster.FlushCache()


def create_dir_if_not_exists(path):
    """
    Create the directory ``path`` (and any missing parents) if it is absent.

    Calling it on an existing directory is a no-op. The existence check and
    the creation happen inside ``os.makedirs`` itself, so a directory created
    by another process between the two steps no longer causes an error.

    Raises:
        FileExistsError: If ``path`` exists but is not a directory.
    """
    os.makedirs(path, exist_ok=True)


def create_new_geoTiff(dest_path, size, band_count, projection, geo_transform, dtype=gdal.GDT_Byte, compression='DEFLATE', zlevel=9):
    """
    Create a new GeoTIFF file with the specified dimensions, bands, projection, and geotransformation.

    Parameters:
        - dest_path (str): The path where the new GeoTIFF file will be created.
        - size (tuple): A tuple specifying the width and height of the new GeoTIFF in pixels. (width, height)
        - band_count (int): The number of bands in the new GeoTIFF.
        - projection (str): The projection of the GeoTIFF file in Well-Known Text (WKT) format.
        - geo_transform (tuple): A tuple representing the geotransformation parameters for the GeoTIFF.
                               (originX, pixelWidth, 0, originY, 0, pixelHeight)
        - dtype (int, optional): The data type of the pixel values in the new GeoTIFF. Default is gdal.GDT_Byte (8-bit).
        - compression (str, optional): Value passed as the COMPRESS creation option. Default is 'DEFLATE'.
        - zlevel (int, optional): Value passed as the ZLEVEL creation option. Default is 9.
    Returns:
        gdal.Dataset: The GDAL dataset representing the newly created GeoTIFF file.
    Raises:
        RuntimeError: If GDAL could not create the file.
    """
    width, height = size
    driver = gdal.GetDriverByName("GTiff")
    # Both options are always passed, whatever compression method is chosen.
    options = ['COMPRESS=' + compression, 'ZLEVEL=' + str(zlevel)]
    ds = driver.Create(dest_path, width, height,
                       band_count, dtype, options=options)
    # Create returns None on failure unless GDAL exceptions are enabled; fail
    # here with a clear message instead of an AttributeError on the next line.
    if ds is None:
        raise RuntimeError(f"GDAL could not create GeoTIFF at {dest_path}")
    # Define the projection of the file
    ds.SetProjection(projection)
    # Specify its coordinates
    ds.SetGeoTransform(geo_transform)
    return ds

def get_geoTiff_extent(geoTiff, image_bounds=None):
    """
    Convert a pixel window of a GeoTIFF into coordinates of its CRS.

    Parameters:
    - geoTiff (GDAL Dataset): Input GeoTIFF dataset object.
    - image_bounds (tuple, optional): Pixel window as
                                      (xmin, xmax, ymin, ymax), i.e. column
                                      range followed by row range. Default is
                                      None, which means the whole raster.

    Returns:
    - tuple: (left, right, top, bottom) computed from the geotransform origin
      and pixel size.

    Note:
    - Only the origin and the two pixel-size terms of the geotransform are
      used; the skew/rotation terms are ignored.

    Exceptions:
    - ValueError: Raised if the image bounds are outside the valid range of the GeoTIFF.

    """
    origin_x, pixel_w, _, origin_y, _, pixel_h = geoTiff.GetGeoTransform()
    n_cols, n_rows = geoTiff.RasterXSize, geoTiff.RasterYSize

    if image_bounds is not None:
        col_min, col_max, row_min, row_max = image_bounds
        window_inside = (col_min >= 0 and col_max <= n_cols
                         and row_min >= 0 and row_max <= n_rows)
        if not window_inside:
            raise ValueError(
                "Image bounds are outside the valid range of the GeoTIFF.")
    else:
        col_min, col_max, row_min, row_max = 0, n_cols, 0, n_rows

    return (
        origin_x + (col_min * pixel_w),
        origin_x + (col_max * pixel_w),
        origin_y + (row_min * pixel_h),
        origin_y + (row_max * pixel_h),
    )

def get_geoTiff_datatype(geoTiff):
    """
    Return the GDAL data type code of the dataset's first band.

    Only band 1 is inspected; other bands are not checked.

    Parameters:
    - geoTiff (gdal.Dataset): Input GeoTIFF dataset.

    Returns:
    - int: ``DataType`` attribute of raster band 1.

    """
    return geoTiff.GetRasterBand(1).DataType

def get_geoTiff_numpy_datatype(geoTiff):
    """
    Translate the dataset's GDAL data type into a NumPy dtype name.

    Parameters:
    - geoTiff (gdal.Dataset): Input GeoTIFF dataset.

    Returns:
    - str or None: NumPy dtype name for the type reported by
      get_geoTiff_datatype, or None when that type is not in the lookup table
      below.

    """
    gdal_code = get_geoTiff_datatype(geoTiff)
    # Only these seven GDAL types are mapped; anything else yields None.
    numpy_name_by_gdal_code = {
        gdal.GDT_Byte: 'uint8',
        gdal.GDT_UInt16: 'uint16',
        gdal.GDT_Int16: 'int16',
        gdal.GDT_UInt32: 'uint32',
        gdal.GDT_Int32: 'int32',
        gdal.GDT_Float32: 'float32',
        gdal.GDT_Float64: 'float64',
    }
    return numpy_name_by_gdal_code.get(gdal_code)


def crop_geoTiff(geoTiff, left, top, right, bottom, dtype=None):
    """
    Read a rectangular pixel window from every band of a GeoTIFF.

    The window is half-open: columns ``left`` to ``right - 1`` and rows
    ``top`` to ``bottom - 1`` are read.

    Parameters:
    - geoTiff (gdal.Dataset): Input GeoTIFF dataset.
    - left (int): Left coordinate of the crop region.
    - top (int): Top coordinate of the crop region.
    - right (int): Right coordinate of the crop region.
    - bottom (int): Bottom coordinate of the crop region.
    - dtype (str or None, optional): Desired data type of the output array.
    If set to None, the data type is inferred from the band. Defaults to None.

    Returns:
    - output (numpy.ndarray): Cropped array with dimensions (height, width, bands).

    Raises:
    - ValueError: If the crop region is negative, empty or inverted, or if it
      exceeds the size of the GeoTIFF.
    - RuntimeError: If GDAL returns no data for a band.
    """
    # Convert once so the checks, the buffer shape and the read all use the
    # same integer coordinates.
    left, top, right, bottom = int(left), int(top), int(right), int(bottom)
    # Reject malformed windows up front rather than hiding them behind abs(),
    # which would silently read a different region than the one requested.
    if left < 0 or top < 0 or right <= left or bottom <= top:
        raise ValueError(
            'Invalid crop region: expected 0 <= left < right and 0 <= top < bottom.')
    if right > geoTiff.RasterXSize or bottom > geoTiff.RasterYSize:
        raise ValueError('Crop dimensions exceed the size of the GeoTIFF.')
    if dtype is None:
        dtype = get_geoTiff_numpy_datatype(geoTiff)
    width = right - left
    height = bottom - top
    output = np.zeros((height, width, geoTiff.RasterCount), dtype)
    for band_index in range(geoTiff.RasterCount):
        band = geoTiff.GetRasterBand(band_index + 1).ReadAsArray(
            left, top, width, height)
        if band is None:
            raise RuntimeError(
                f'GDAL returned no data for band {band_index + 1}.')
        output[..., band_index] = band
    return output

def get_geoTiff_part(geoTiff, left, top, right, bottom, dtype='uint16'):
    """
    Read a pixel window from a GeoTIFF together with its CRS extent.

    Parameters:
    - geoTiff (gdal.Dataset): Input GeoTIFF dataset.
    - left, top, right, bottom (int): Pixel window to read.
    - dtype (str, optional): Data type of the returned array. Defaults to 'uint16'.

    Returns:
    - tuple: (array, extent) where array comes from crop_geoTiff and extent is
      the (left, right, top, bottom) tuple from get_geoTiff_extent.
    """
    pixels = crop_geoTiff(geoTiff, left, top, right, bottom, dtype)
    # get_geoTiff_extent takes the window as (xmin, xmax, ymin, ymax).
    pixel_window = (left, right, top, bottom)
    return pixels, get_geoTiff_extent(geoTiff, pixel_window)

def _iter_geoTiff_grids(geoTiff, max_grid_shape, dtype):
    """Lazily yield (grid, extent, pixel_box, (index, total)) for each tile, row by row."""
    rows, cols = geoTiff.RasterYSize, geoTiff.RasterXSize
    max_grid_height, max_grid_width = max_grid_shape
    # Integer ceiling division: tiles per column times tiles per row.
    total_grids = (-(-rows // max_grid_height)) * (-(-cols // max_grid_width))
    progress = 0
    for top in range(0, rows, max_grid_height):
        # Tiles on the last row/column are clipped to the raster edge.
        bottom = min(top + max_grid_height, rows)
        for left in range(0, cols, max_grid_width):
            right = min(left + max_grid_width, cols)
            grid, extent = get_geoTiff_part(
                geoTiff, left, top, right, bottom, dtype)
            progress += 1
            yield (grid, extent, (left, top, right, bottom), (progress, total_grids))


def crop_geoTiff_into_grids(geoTiff, max_grid_shape, yield_results=False, dtype = 'uint16'):
    """
    Split a GeoTIFF into grids of numpy arrays with a specified maximum shape.

    Tiles are produced row by row, left to right. Tiles touching the right or
    bottom edge may be smaller than max_grid_shape.

    Parameters:
    - geoTiff (gdal.Dataset): Input GeoTIFF dataset.
    - max_grid_shape (tuple): Maximum shape of each grid (rows, columns).
    - yield_results (bool, optional): If True, return a lazy iterator over the
                                      tiles; if False, read all tiles and
                                      return them in a list. Defaults to False.
    - dtype (str, optional): Data type of the tile arrays. Defaults to 'uint16'.

    Returns:
    - If yield_results=True: an iterator of
      (grid, extent, (left, top, right, bottom), (index, total)) tuples, where
      index is 1-based and total is the number of tiles.
    - If yield_results=False: a list of
      (grid, extent, (left, top, right, bottom)) tuples.

    Raises:
    - ValueError: If the max_grid_shape is invalid or exceeds the size of the GeoTIFF.
      The check runs when this function is called, before any tile is read.
    """
    if not isinstance(max_grid_shape, tuple) or len(max_grid_shape) != 2 or \
       max_grid_shape[0] <= 0 or max_grid_shape[1] <= 0:
        raise ValueError(
            "Invalid max_grid_shape. It should be a tuple of two positive integers.")

    if max_grid_shape[0] > geoTiff.RasterYSize or max_grid_shape[1] > geoTiff.RasterXSize:
        raise ValueError("max_grid_shape exceeds the size of the GeoTIFF.")

    # The generator lives in a helper: a `yield` in this function's own body
    # would turn every call into a generator, and the list mode could then
    # never hand back its list.
    tiles = _iter_geoTiff_grids(geoTiff, max_grid_shape, dtype)
    if yield_results:
        return tiles
    return [(grid, extent, pixel_box) for grid, extent, pixel_box, _ in tiles]

def crop_and_process_geoTiff(geoTiff, process_func, max_grid_shape=None, dtype = 'uint16'):
    """
    Walk over a GeoTIFF tile by tile and hand each tile to a callback.

    Nothing is written or returned by this function itself; whatever should
    happen to a tile (e.g. saving a result) is up to process_func.

    Parameters:
    - geoTiff (gdal.Dataset): Input GeoTIFF dataset.
    - process_func (callable): Called once per tile as
      process_func(tile_array, extent, (left, top, right, bottom), progress).
      Its return value is ignored.
    - max_grid_shape (tuple, optional): Maximum shape of each tile
      (rows, columns). Defaults to the full raster size, i.e. a single tile.
    - dtype (str, optional): Data type of the tile arrays. Defaults to 'uint16'.

    Raises:
    - ValueError: If the max_grid_shape is invalid or exceeds the size of the GeoTIFF.
    """
    if max_grid_shape is None:
        max_grid_shape = geoTiff.RasterYSize, geoTiff.RasterXSize

    shape_is_valid = (isinstance(max_grid_shape, tuple)
                      and len(max_grid_shape) == 2
                      and max_grid_shape[0] > 0
                      and max_grid_shape[1] > 0)
    if not shape_is_valid:
        raise ValueError(
            "Invalid max_grid_shape. It should be a tuple of two positive integers.")

    grid_rows, grid_cols = max_grid_shape
    if grid_rows > geoTiff.RasterYSize or grid_cols > geoTiff.RasterXSize:
        raise ValueError("max_grid_shape exceeds the size of the GeoTIFF.")

    # Tiles are consumed one at a time so only a single tile is in memory.
    for tile in crop_geoTiff_into_grids(geoTiff, max_grid_shape, True, dtype):
        tile_array, extent, pixel_box, progress = tile
        process_func(tile_array, extent, pixel_box, progress)

def get_file_or_foldername(path):
    """Return the last component of ``path``; empty if ``path`` ends with a separator."""
    _, last_component = os.path.split(path)
    return last_component

def real_image_preprocessing_func(np_gt):
    """
    Standardize a 6-band image band by band.

    The array is converted to float32 and each band is transformed as
    (value - mean) / std using the fixed statistics below. The statistics are
    broadcast along the last axis, so the input needs 6 bands on that axis.
    No clipping is performed, neither on the raw values nor on the result.

    Returns:
        numpy.ndarray: float32 array with the same shape as the input.
    """
    # Per-band statistics, in band order along the last axis.
    band_means = np.array(
        [1635.8452, 1584.4594, 1456.8425, 2926.6663, 2135.001, 1352.7313],
        dtype=np.float32)
    band_stds = np.array(
        [884.3994, 815.4016, 839.0293, 1055.6382, 751.4628, 628.5323],
        dtype=np.float32)

    as_float = np_gt.astype(np.float32)
    centered = as_float - band_means
    return centered / band_stds


# Preprocessing and Utility Functions
Define helper functions for padding, cropping, writing GeoTIFFs, and preprocessing images for inference.

In [None]:
# Folder on the mounted Drive that holds the input imagery, and the local
# file the downloaded image is saved as.
dest_path = '/content/drive/MyDrive/SCO_training/sentinel_images'
geotiff_image_path = os.path.join(dest_path, '2023.tif')
# NOTE(review): directory creation is disabled here, so dest_path presumably
# has to exist on Drive before the download runs — confirm.
#create_dir_if_not_exists(dest_path)
# NOTE(review): the URL points at 2024.tif while the local file is named
# 2023.tif — confirm which year is intended.
download_file('https://sco-training.s3.us-east-2.amazonaws.com/2024.tif', geotiff_image_path)

In [None]:
# Pick the GPU when CUDA is available, otherwise fall back to the CPU.
# `device` is a notebook-level global used by the inference functions.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

In [None]:
def predict_mask(unet, np_img, input_shape):
    """
    Predict a per-pixel class mask for one image tile.

    The tile is normalized, zero-padded up to input_shape when its size
    differs, run through the model, and the arg-max class map is cropped back
    to the tile's original height and width. Padding happens after
    normalization, so padded pixels are 0 in normalized space.

    Parameters:
    - unet: Model called on a [1, C, H, W] float tensor; its result must expose
      an ``output`` attribute holding the logits (presumably
      [1, n_classes, H, W] — confirm against the model).
    - np_img (numpy.ndarray): Tile of shape (height, width, bands).
    - input_shape (tuple): (height, width) the model is fed with.

    Returns:
    - numpy.ndarray: Class indices of shape (height, width, 1).
    """
    orig_height, orig_width = np_img.shape[:2]
    net_height, net_width = input_shape

    normalized = real_image_preprocessing_func(np_img)
    if (orig_height, orig_width) != (net_height, net_width):
        normalized = zero_pad_array(normalized, (net_height, net_width))

    # (H, W, C) -> (1, C, H, W), moved to the notebook-level `device`.
    batch = torch.from_numpy(normalized).permute(2, 0, 1).unsqueeze(0)
    batch = batch.float().to(device)

    unet.eval()
    with torch.no_grad():
        logits = unet(batch).output

    class_map = logits.argmax(dim=1).squeeze(0).cpu().numpy()
    # Drop the rows/columns that only exist because of the padding.
    class_map = crop_np_array(class_map, 0, 0, orig_width, orig_height)
    mask_rows, mask_cols = class_map.shape[0], class_map.shape[1]
    return class_map.reshape((mask_rows, mask_cols, 1))


def pre_process_func(dest_gt, model, max_grid_shape):
    """
    Build the per-tile callback used with crop_and_process_geoTiff.

    The returned function predicts a mask for the tile it receives and writes
    it into dest_gt at the tile's top-left pixel position. The extent and
    progress arguments it receives are not used.

    Parameters:
    - dest_gt (gdal.Dataset): Raster the predicted masks are written to.
    - model: Segmentation model passed on to predict_mask.
    - max_grid_shape (tuple): (height, width) passed to predict_mask as the
      model input shape.

    Returns:
    - callable: predict_grid(np_gt, extent, image_coord, progress).
    """
    def predict_grid(np_gt, extent, image_coord, progress):
        tile_mask = predict_mask(model, np_gt, max_grid_shape)
        # image_coord is (left, top, right, bottom); only the corner is needed.
        top_left = (image_coord[0], image_coord[1])
        write_geoTiff_bands(dest_gt, tile_mask, top_left)

    return predict_grid

In [None]:
# Usage example
def run_inference(gt_paths, prediction_path, mc_unet):
    """
    Predict a single-band class mask for every input GeoTIFF.

    For each input, an output GeoTIFF with the same file name, size,
    projection and geotransform is created in prediction_path and filled tile
    by tile (224 x 224 pixels) with the model's predictions. Inputs whose
    output file already exists are skipped.

    Parameters:
    - gt_paths (iterable of str): Paths of the input GeoTIFFs.
    - prediction_path (str): Folder the predicted masks are written to; it is
      created if missing.
    - mc_unet: Segmentation model; it is moved to the notebook-level `device`
      and put in eval mode.

    Raises:
    - RuntimeError: If an input file cannot be opened by GDAL.
    """
    # Move model to GPU (if available) once
    mc_unet.to(device)
    mc_unet.eval()

    dest_path = prediction_path
    create_dir_if_not_exists(dest_path)
    tile_shape = (224, 224)

    for gt_path in gt_paths:
        dest_file = os.path.join(dest_path, get_file_or_foldername(gt_path))
        if os.path.exists(dest_file):
            print("skipping", dest_file)
            continue

        print("Processing", gt_path)
        gt = gdal.Open(gt_path)
        # gdal.Open returns None on failure unless GDAL exceptions are enabled.
        if gt is None:
            raise RuntimeError(f"Could not open {gt_path} with GDAL.")

        dest_gt = None
        tile_callback = None
        completed = False
        try:
            proj = gt.GetProjection()
            size = gt.RasterXSize, gt.RasterYSize
            geo_transform = gt.GetGeoTransform()

            dest_gt = create_new_geoTiff(dest_file, size, 1, proj, geo_transform, compression="None")
            tile_callback = pre_process_func(dest_gt, mc_unet, tile_shape)

            crop_and_process_geoTiff(gt, tile_callback, tile_shape, 'uint16')
            completed = True
        finally:
            # Drop every reference to the datasets (the callback holds one to
            # dest_gt) so GDAL can close them.
            tile_callback = None
            gt = None
            dest_gt = None
            # A half-written output would be mistaken for a finished one by
            # the "skipping" check on the next run, so remove it on failure.
            if not completed and os.path.exists(dest_file):
                os.remove(dest_file)
        print(dest_file, "saved")


# Run Inference and Save Predictions
Run the Prithvi EO segmentation model on satellite images and save predicted masks as GeoTIFFs.

In [None]:
from terratorch.tasks import SemanticSegmentationTask

# Build the TerraTorch semantic-segmentation task: a Prithvi EO v2 backbone
# combined with an FCN decoder that predicts 3 classes.
model = SemanticSegmentationTask(
    model_factory="EncoderDecoderFactory",
    model_args=dict(
        backbone="terratorch_prithvi_eo_v2_300_tl",
        # No pretrained backbone weights are requested here; a saved state
        # dict is loaded into the model in a later cell.
        backbone_pretrained=False,
        # Matches the 224 x 224 tile size used by run_inference.
        backbone_img_size=224,
        # Six input bands; presumably in the same order as the normalization
        # statistics in real_image_preprocessing_func — TODO confirm.
        backbone_bands=["BLUE","GREEN","RED","NIR_NARROW","SWIR_1","SWIR_2"],
        # NOTE(review): presumably selects intermediate backbone outputs and
        # reshapes tokens into image-like feature maps for the decoder —
        # confirm against the TerraTorch neck documentation.
        necks=[{"name":"SelectIndices", "indices":[1, 5,11,17,23]},
               {"name":"ReshapeTokensToImage"}],
        decoder="FCNDecoder",
        decoder_channels=256,
        num_classes=3,
        head_dropout=0.1,
    ),
    freeze_backbone=False,
    freeze_decoder=False,
)

In [None]:
# Saved weights for the segmentation task, stored on the mounted Drive.
model_save_path  = Path("/content/drive/MyDrive/SCO_training/prithvi_state_dict.pt")
# map_location="cpu" lets the checkpoint load on a CPU-only runtime even if it
# holds CUDA tensors; load_state_dict copies the values into the model's
# existing parameters whichever device they are on.
model.load_state_dict(torch.load(model_save_path, map_location="cpu"))
model.eval()

In [None]:
# Count the scalar values across all encoder parameters as a quick size check.
total_elements = sum(p.numel() for p in model.model.encoder.parameters())
print(f"Total encoder parameter elements: {total_elements}")

In [None]:
def get_all_files(path, pattern='*', get_full_path=False):
    """
    List the files under ``path`` that match a glob pattern.

    Directories that match the pattern are left out.

    Parameters:
    - path (str): Folder to search.
    - pattern (str, optional): Glob pattern relative to ``path``. Recursive
      patterns such as '**/*.tif' are supported. Defaults to '*'.
    - get_full_path (bool, optional): If True, return each file as ``path``
      joined with the file's location relative to ``path``; if False, return
      bare file names. Defaults to False.

    Returns:
    - list of str: Matching files, in the order the glob produced them.
    """
    root = pathlib.Path(path)
    files = [f for f in root.glob(pattern) if f.is_file()]
    if get_full_path:
        # Join with the path relative to the root rather than the bare name,
        # so files found in subfolders keep their subfolder.
        return [os.path.join(path, str(f.relative_to(root))) for f in files]
    return [f.name for f in files]

In [None]:
# Collect the full paths of every .tif directly inside the input folder and
# show how many were found.
gt_paths = get_all_files(
    '/content/drive/MyDrive/SCO_training/sentinel_images', '*.tif', True)
len(gt_paths)

In [None]:
# Folder on Drive that receives one predicted-mask GeoTIFF per input image.
dest = '/content/drive/MyDrive/SCO_training/footprints'
create_dir_if_not_exists(dest)
run_inference(gt_paths, dest, model)

In [None]:
# Collect the predicted masks written by run_inference. get_all_files is the
# helper defined in this notebook; no `iu` module is imported here.
geoTiff_paths = get_all_files(dest, '*.tif', True)
len(geoTiff_paths)

In [None]:
def merge_geoTiffs(geoTiff_paths, output_path, compression='NONE'):
    """
    Merge multiple GeoTIFF files into a single GeoTIFF file.

    Parameters:
    - geoTiff_paths (iterable of str): Paths of the GeoTIFFs to merge.
    - output_path (str): Path of the merged GeoTIFF to create.
    - compression (str, optional): Value for the GTiff COMPRESS creation
      option (e.g. 'NONE', 'LZW', 'DEFLATE'). Defaults to 'NONE'.

    Raises:
    - ValueError: If no input paths are given.
    - RuntimeError: If an input cannot be opened or the merge fails.
    """
    from osgeo import gdal

    geoTiff_paths = list(geoTiff_paths)
    if not geoTiff_paths:
        raise ValueError("No input GeoTIFF paths were provided.")

    # Open each input GeoTIFF file read-only
    geoTiffs = []
    for path in geoTiff_paths:
        ds = gdal.Open(path, gdal.GA_ReadOnly)
        # gdal.Open returns None on failure unless GDAL exceptions are enabled.
        if ds is None:
            raise RuntimeError(f"Could not open {path} with GDAL.")
        geoTiffs.append(ds)

    # COMPRESS is an output-driver creation option, so it goes through
    # `creationOptions`; `warpOptions` is for warp-algorithm options.
    merged = gdal.Warp(output_path, geoTiffs, format='GTiff',
                       creationOptions=['COMPRESS=' + compression])
    if merged is None:
        raise RuntimeError(f"GDAL could not write merged GeoTIFF to {output_path}")

    # Dropping the references closes the datasets and flushes the output.
    merged = None
    geoTiffs = None


# Merge Predicted GeoTIFFs
Combine multiple predicted GeoTIFF files into a single output for visualization or further analysis.

In [None]:
# Merge all predicted masks into one LZW-compressed GeoTIFF.
# NOTE(review): the output path is a local Windows path, while every other
# cell uses /content/drive paths — confirm where this cell is meant to run.
merge_geoTiffs(
    geoTiff_paths, 'C:\\Users\\emmanuelasare\\Downloads\\sentinel-aoi-2025(2025-7-22).tif',  'LZW')