# This notebook pre-processes SPOT images for further fine-tuning of the Prithvi model


## Learning Objectives

* Learn how to pre-process SPOT images so they can be used to fine-tune the Prithvi model for land use classification

## Tasks

1. Import the necessary libraries
2. Calculate vegetation and water indices (NDVI and NDWI) from a raster image
3. Process a set of multiband images and stack all the bands into a single raster file
4. Reproject and resample a land use raster image (landuse_image) to a specified resolution and spatial reference system (SRS)
5. Clip to the reference extent
6. Create patches from raster datasets
7. Organize the outputs for the specified region



## Prithvi model

Prithvi-EO-1.0 is a first-of-its-kind temporal Vision transformer pre-trained by the IBM and NASA team on contiguous US Harmonised Landsat Sentinel 2 (HLS) data. The model adopts a self-supervised encoder developed with a ViT architecture and Masked AutoEncoder (MAE) learning strategy, with an MSE loss function. The model includes spatial attention across multiple patches and also temporal attention for each patch.

The Prithvi model is a pre-trained deep learning model designed specifically for geospatial and remote sensing applications. It is tailored to work with satellite imagery and supports tasks such as land use and land cover (LULC) classification, segmentation, and other image analysis objectives.


**How the Prithvi Model Works in a Land Use Classification Workflow**

* Input Data: Multi-band satellite imagery (e.g., SPOT, Sentinel-2) is pre-processed (clipping, upscaling, stacking, patching).
* Model Fine-Tuning: The pre-trained Prithvi model is fine-tuned on the target dataset using labeled training samples for the specific classification task.
* Prediction: The fine-tuned model predicts land use classes for the input imagery, generating detailed maps.
* Output: Classified land use maps that can be visualized and analyzed for decision-making.



## Importing Necessary Libraries

#### rasterio:
- A Python library for reading and writing geospatial raster data.
- It provides tools for working with raster datasets in formats like GeoTIFF.
- Rasterio enables georeferencing and allows easy access to metadata and bands of raster files.

#### numpy:
- A library for numerical computing in Python.
- Often used for array manipulation, mathematical operations, and working with raster data as numerical arrays.

#### glob:
- A Python module for finding files and directories that match a specified pattern.
- Useful for searching for raster files in a directory.

#### os:
- A module that provides functions for interacting with the operating system.
- Used for file path manipulations, creating directories, or listing files.
- Used here to create output directories for storing processed data.

In [2]:
import rasterio
import numpy as np
import glob
import os


In [3]:
!pip install geopandas
!pip install tqdm



## Settings needed for preprocessing

In [18]:
place_name = 'khon-kaen'  # Place or experiment name we are working on
# output_folder = f'./outputs_spotb_22/{place_name}/'  # Keep it as it is
output_folder = "/home/jovyan/shared/PCN/Prithvi/notebook/outputs_spotb_22"
spot_images_folder_path = "/home/jovyan/shared/2025_THA_LDD_training/SPOT/B/SPOT"  # Folder containing the three SPOT images from different dates
landuse_path = '/home/jovyan/shared/PCN/Prithvi/01_preprocessing/kon-kaen_lu_raster/gt_LU2022B.tif'  # Rasterized land use path

## Calculate indices and stack all three images with the calculated indices

In [19]:
def calculate_indices(image_path):
    """Return NDVI, NDWI and the raster metadata for one 4-band image."""
    with rasterio.open(image_path) as src:
        # Band order assumed by this notebook: 1=red, 2=green, 3=blue, 4=NIR
        red = src.read(1).astype(float)
        green = src.read(2).astype(float)
        nir = src.read(4).astype(float)

        meta = src.meta
    epsilon = 1e-10  # Avoids division by zero where both bands are 0
    ndvi = (nir - red) / (nir + red + epsilon)
    ndwi = (green - nir) / (green + nir + epsilon)
    
    return ndvi, ndwi, meta

def process_and_stack_images(folder_path, output_path):
    # Get all multiband images
    image_files = sorted(glob.glob(os.path.join(folder_path, "*.tif")))
    
    if len(image_files) != 3:
        raise ValueError(f"Expected 3 images, found {len(image_files)}")
    
    print(f"Found {len(image_files)} images: {image_files}")
    all_bands = []
    for image_path in image_files:
        print(f"Processing {image_path}...")
        with rasterio.open(image_path) as src:
            bands = [src.read(i) for i in range(1, 5)]  
            meta = src.meta
        ndvi, ndwi, _ = calculate_indices(image_path)
        all_bands.extend(bands + [ndvi, ndwi])
    
    meta.update({
        'count': len(all_bands),  # 18 bands total (6 bands × 3 images)
        'dtype': 'float32'
    })
    os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)
    print(f"Saving stacked image to {output_path}...")
    with rasterio.open(output_path, 'w', **meta) as dst:
        for idx, band in enumerate(all_bands, start=1):
            dst.write(band.astype(np.float32), idx)
    print("Done with calculating and stacking all bands!")
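
As a small worked example of the two index formulas used in `calculate_indices`, the sketch below applies them to a toy two-pixel scene. The pixel values are invented for illustration (one vegetation-like pixel with high NIR, one water-like pixel with low NIR), and `ndvi_ndwi` is a hypothetical helper name, not part of the notebook.

```python
import numpy as np

def ndvi_ndwi(red, green, nir, epsilon=1e-10):
    """Compute NDVI and NDWI from band arrays (hypothetical helper)."""
    # Cast to float first: subtracting unsigned integers would wrap around
    red, green, nir = (band.astype(float) for band in (red, green, nir))
    ndvi = (nir - red) / (nir + red + epsilon)
    ndwi = (green - nir) / (green + nir + epsilon)
    return ndvi, ndwi

# Toy 1x2 scene: left pixel vegetation-like, right pixel water-like (made-up values)
red = np.array([[50, 200]], dtype=np.uint16)
green = np.array([[60, 180]], dtype=np.uint16)
nir = np.array([[250, 40]], dtype=np.uint16)

ndvi, ndwi = ndvi_ndwi(red, green, nir)
print("NDVI:", ndvi.round(3))
print("NDWI:", ndwi.round(3))
```

With these values the vegetation-like pixel has a positive NDVI and a negative NDWI, and the water-like pixel shows the opposite signs.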



## Define the output path and run process_and_stack_images to create the stacked raster

In [20]:
# It may take up to 5 minutes to calculate the indices and stack the images
stacked_spot_image = os.path.join(output_folder, "stack/18_band_spot_b.tif")  # File name for the stacked image
#stacked_spot_image= "/home/jovyan/shared/PCN/Prithvi/notebook/outputs_spotb_22/khon-kaen/stack/18_band_spot_b.tif"
process_and_stack_images(spot_images_folder_path, stacked_spot_image)
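
Because `process_and_stack_images` appends, for each image in sorted file order, the four source bands followed by NDVI and NDWI, the position of any layer in the 18-band stack can be worked out from its band index. The sketch below illustrates that layout. The layer names follow the band order assumed in `calculate_indices`, and `describe_band` is a hypothetical helper, not part of the notebook.

```python
# Layer order written for each image: bands 1-4, then NDVI, then NDWI
LAYERS_PER_IMAGE = ["red", "green", "blue", "nir", "ndvi", "ndwi"]

def describe_band(band_index, n_images=3):
    """Map a 1-based band index in the stack to (image number, layer name)."""
    n_layers = len(LAYERS_PER_IMAGE)
    if not 1 <= band_index <= n_images * n_layers:
        raise ValueError(f"band_index must be between 1 and {n_images * n_layers}")
    image_no, layer = divmod(band_index - 1, n_layers)
    return image_no + 1, LAYERS_PER_IMAGE[layer]

for index in (1, 5, 7, 18):
    print(index, describe_band(index))
```

A lookup like this can help when selecting bands or listing per-band statistics for the fine-tuning configuration.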

## Clip the land use image to the extent of the SPOT image, then create 224x224 patches for both images

The code is organized into three main functions and a main workflow.

#### Importing Necessary Libraries

In [None]:
import rasterio as rio
from rasterio.mask import mask
from rasterio.windows import Window
from rasterio.warp import reproject, Resampling
import geopandas as gpd
from shapely.geometry import box
import os
import numpy as np
from tqdm import tqdm

### Extract the bounding box (extent) from the reference raster and clip the source raster to that extent

In [None]:
def clip_to_reference_extent(source_path, reference_path, output_path):
    print(f"Clipping {os.path.basename(source_path)} to reference extent...")
    # Get reference image extent
    with rio.open(reference_path) as ref:
        ref_bounds = ref.bounds
        ref_bbox = box(ref_bounds.left, ref_bounds.bottom, ref_bounds.right, ref_bounds.top)
        ref_gdf = gpd.GeoDataFrame({'geometry': [ref_bbox]}, crs=ref.crs)
    
    # Clip source to reference extent
    with rio.open(source_path) as src:
        ref_gdf = ref_gdf.to_crs(src.crs)
        out_image, out_transform = mask(src, ref_gdf.geometry, crop=True)
        
        out_meta = src.meta.copy()
        out_meta.update({
            "height": out_image.shape[1],
            "width": out_image.shape[2],
            "transform": out_transform
        })
        
        with rio.open(output_path, "w", **out_meta) as dest:
            dest.write(out_image)
    print("Clipping completed!")

## Create training patches of size 224x224 for both the SPOT image and the land use image

**Calculate the number of patches required along the width (x) and height (y) of the raster.**
- Divide the raster dimensions by the patch size and round up to ensure full coverage.
- Calculate the total number of patches (total_patches).

**Fine-Tuning Pre-trained Models**:

Using 224x224 patches ensures that the satellite image patches can directly serve as input for models like the Prithvi model or other pre-trained architectures without additional resizing, preserving details.
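
The patch-count calculation described above can be sketched in plain Python. The raster size used here (15300 x 6600 pixels) is a hypothetical example, since the real dimensions are not stated in this notebook, and `patch_grid` is an illustrative helper name.

```python
import math

def patch_grid(width, height, patch_size=224):
    """Count the patches needed to cover a raster, and how many are full size."""
    cols = math.ceil(width / patch_size)   # Round up so the right edge is covered
    rows = math.ceil(height / patch_size)  # Round up so the bottom edge is covered
    full = (width // patch_size) * (height // patch_size)
    return {
        "cols": cols,
        "rows": rows,
        "total": cols * rows,
        "full": full,
        "partial": cols * rows - full,  # Edge patches smaller than patch_size
    }

grid = patch_grid(15300, 6600)  # Hypothetical raster size
print(grid)
```

Rounding up guarantees full coverage, but the last column and last row are then usually narrower or shorter than 224 pixels, which is why a later size filter is needed.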

In [21]:

def create_patches(image_path, output_dir, patch_size=224):
    with rio.open(image_path) as src:
        # Calculate number of patches
        num_patches_x = int(np.ceil(src.width / patch_size))
        num_patches_y = int(np.ceil(src.height / patch_size))
        total_patches = num_patches_y * num_patches_x
        
        print(f"Creating {total_patches} patches for {os.path.basename(image_path)}...")
        
        # Create progress bar
        pbar = tqdm(total=total_patches, desc="Creating patches")
        
        for y in range(num_patches_y):
            for x in range(num_patches_x):
                # Define window for current patch
                window = Window(x * patch_size, y * patch_size, patch_size, patch_size)
                transform = rio.windows.transform(window, src.transform)
                patch = src.read(window=window)
                profile = src.profile.copy()
                profile.update({
                    'height': patch.shape[1],
                    'width': patch.shape[2],
                    'transform': transform
                })
                
                # Use same naming pattern for both landuse and spot patches
                patch_name = f"patch_{x:03d}_{y:03d}.tif"
                output_path = os.path.join(output_dir, patch_name)
                
                with rio.open(output_path, 'w', **profile) as dest:
                    dest.write(patch)
                
                pbar.update(1)
        
        pbar.close()
    print(f"Patch creation completed for {os.path.basename(image_path)}!")

### Process images

- Create directories for storing land use and SPOT image patches.
- Align the land use raster to the extent of the SPOT imagery using the clip_to_reference_extent function.
- Create patches for both the clipped land use raster and SPOT imagery using the create_patches function.

In [None]:
def process_images(landuse_path, spot_path, output_base_dir):
    print("Starting image processing...")
    
    # Create output directories
    landuse_patches_dir = os.path.join(output_base_dir, 'landuse_patches_big_extent')
    spot_patches_dir = os.path.join(output_base_dir, 'spot_patches_big_extent')
    os.makedirs(landuse_patches_dir, exist_ok=True)
    os.makedirs(spot_patches_dir, exist_ok=True)
    
    # 1. First clip landuse to spot extent
    clipped_landuse_path = os.path.join(output_base_dir, 'clipped_landuse_big_extent.tif')
    clip_to_reference_extent(landuse_path, spot_path, clipped_landuse_path)
    
    # 2. Create patches for both images
    create_patches(clipped_landuse_path, landuse_patches_dir)
    create_patches(spot_path, spot_patches_dir)
    
    print("Image processing completed successfully!")

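This script processes land use and SPOT imagery for the khon-kaen study area. It clips the land use raster to the SPOT imagery's extent and then divides both datasets into 224x224 patches for further analysis. Note that clip_to_reference_extent only crops; it does not resample, so the two rasters share a patch grid only if they already have the same pixel size and alignment. In the run recorded below, the land use raster yields 2516 patches and the SPOT stack 2070, which suggests that the two grids differ.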

In [22]:
process_images(
    landuse_path=landuse_path, # File path of the landuse image
    spot_path=stacked_spot_image,
    output_base_dir=output_folder
)

Starting image processing...
Clipping gt_LU2022B.tif to reference extent...
Clipping completed!
Creating 2516 patches for clipped_landuse_big_extent.tif...


Creating patches: 100%|██████████| 2516/2516 [08:41<00:00,  4.83it/s] 


Patch creation completed for clipped_landuse_big_extent.tif!
Creating 2070 patches for 18_band_spot_b.tif...


Creating patches: 100%|██████████| 2070/2070 [23:46<00:00,  1.45it/s]

Patch creation completed for 18_band_spot_b.tif!
Image processing completed successfully!





## Filter the patches created above, since edge patches may be smaller than 224x224

Ensure all images in a folder are of the required size (224x224 pixels). Any patch with a different size is deleted.

In [23]:


def filter_patch_images(image_folder):
    desired_width, desired_height = 224, 224
    deleted_files = []
    image_files = [f for f in os.listdir(image_folder) if f.endswith('.tif')]
    
    for image_file in tqdm(image_files, desc="Processing Images", unit="image"):
        image_path = os.path.join(image_folder, image_file)
        try:
            with rasterio.open(image_path) as src:
                width, height = src.width, src.height
                if (width, height) != (desired_width, desired_height):
                    os.remove(image_path)
                    deleted_files.append(image_file)
        except Exception as e:
            print(f"Error processing {image_file}: {e}")
    
    print("\nSummary:")
    print(f"Total images processed: {len(image_files)}")
    print(f"Images deleted: {len(deleted_files)}")
    if deleted_files:
        print("Deleted files:")
        for file in deleted_files:
            print(f"- {file}")

In [24]:
landuse_patch_folder = os.path.join(output_folder,'landuse_patches_big_extent')
spot_patch_folder = os.path.join(output_folder,'spot_patches_big_extent')
filter_patch_images(spot_patch_folder)
filter_patch_images(landuse_patch_folder)

Processing Images: 100%|██████████| 2070/2070 [00:57<00:00, 35.80image/s] 



Summary:
Total images processed: 2070
Images deleted: 98
Deleted files:
- patch_068_007.tif
- patch_068_021.tif
- patch_009_029.tif
- patch_024_029.tif
- patch_029_029.tif
- patch_050_029.tif
- patch_056_029.tif
- patch_068_027.tif
- patch_003_029.tif
- patch_013_029.tif
- patch_019_029.tif
- patch_039_029.tif
- patch_047_029.tif
- patch_060_029.tif
- patch_064_029.tif
- patch_068_000.tif
- patch_017_029.tif
- patch_031_029.tif
- patch_068_028.tif
- patch_016_029.tif
- patch_033_029.tif
- patch_063_029.tif
- patch_068_002.tif
- patch_018_029.tif
- patch_045_029.tif
- patch_054_029.tif
- patch_068_004.tif
- patch_068_013.tif
- patch_022_029.tif
- patch_028_029.tif
- patch_055_029.tif
- patch_059_029.tif
- patch_066_029.tif
- patch_067_029.tif
- patch_002_029.tif
- patch_011_029.tif
- patch_027_029.tif
- patch_038_029.tif
- patch_040_029.tif
- patch_068_003.tif
- patch_068_014.tif
- patch_068_016.tif
- patch_068_018.tif
- patch_007_029.tif
- patch_043_029.tif
- patch_052_029.tif
- patch

Processing Images: 100%|██████████| 2516/2516 [00:56<00:00, 44.78image/s] 


Summary:
Total images processed: 2516
Images deleted: 107
Deleted files:
- patch_073_017.tif
- patch_073_021.tif
- patch_073_032.tif
- patch_012_033.tif
- patch_028_033.tif
- patch_030_033.tif
- patch_042_033.tif
- patch_060_033.tif
- patch_031_033.tif
- patch_071_033.tif
- patch_038_033.tif
- patch_073_008.tif
- patch_073_022.tif
- patch_073_031.tif
- patch_041_033.tif
- patch_043_033.tif
- patch_045_033.tif
- patch_062_033.tif
- patch_051_033.tif
- patch_073_018.tif
- patch_008_033.tif
- patch_017_033.tif
- patch_033_033.tif
- patch_055_033.tif
- patch_061_033.tif
- patch_073_027.tif
- patch_029_033.tif
- patch_037_033.tif
- patch_073_004.tif
- patch_073_003.tif
- patch_073_007.tif
- patch_073_002.tif
- patch_023_033.tif
- patch_032_033.tif
- patch_036_033.tif
- patch_070_033.tif
- patch_073_000.tif
- patch_002_033.tif
- patch_010_033.tif
- patch_035_033.tif
- patch_034_033.tif
- patch_069_033.tif
- patch_006_033.tif
- patch_013_033.tif
- patch_018_033.tif
- patch_040_033.tif
- patc




## Split the patches into train, validation, and test sets

This process is essential for training machine learning models, where datasets are divided into different subsets for training, validating, and evaluating the model's performance.
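
Because the split copies each mask by reusing the image file name, it can be worth confirming beforehand that every image patch has a mask with the same name. The sketch below is one possible check using only the standard library. `find_unpaired` is a hypothetical helper, and the temporary folders and file names exist only to make the example self-contained.

```python
import os
import tempfile

def find_unpaired(image_dir, mask_dir, ext=".tif"):
    """Return (images without a mask, masks without an image), matched by file name."""
    images = {f for f in os.listdir(image_dir) if f.endswith(ext)}
    masks = {f for f in os.listdir(mask_dir) if f.endswith(ext)}
    return sorted(images - masks), sorted(masks - images)

# Build two tiny throwaway folders with empty placeholder files
with tempfile.TemporaryDirectory() as root:
    image_dir = os.path.join(root, "images")
    mask_dir = os.path.join(root, "masks")
    os.makedirs(image_dir)
    os.makedirs(mask_dir)
    for name in ("patch_000_000.tif", "patch_001_000.tif"):
        open(os.path.join(image_dir, name), "w").close()
    for name in ("patch_000_000.tif", "patch_002_000.tif"):
        open(os.path.join(mask_dir, name), "w").close()

    images_without_mask, masks_without_image = find_unpaired(image_dir, mask_dir)

print("Images without mask:", images_without_mask)
print("Masks without image:", masks_without_image)
```

An image without a mask would make the copy step fail, while a mask without an image is simply left unused by the split.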



In [25]:
import os
import shutil
import random
from pathlib import Path
import numpy as np

def create_train_val_test_splits(
    image_dir, 
    mask_dir, 
    output_dir,
    train_ratio=0.70,  # 70% of the data will be used for training.
    val_ratio=0.10,    # 10% of the data will be used for validation.
    test_ratio=0.20,   # 20% of the data will be used for testing.
    random_seed=42
):

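    # The three ratios must sum to 1; the test split takes whatever remains after train and val
    if abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-6:
        raise ValueError("train_ratio, val_ratio and test_ratio must sum to 1")

    # Set random seeds so the shuffle is repeatable
    random.seed(random_seed)
    np.random.seed(random_seed)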
    splits = ['train', 'val', 'test']
    for split in splits:
        for subdir in ['images', 'masks']:
            os.makedirs(os.path.join(output_dir, split, subdir), exist_ok=True)
    
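    # Sort first: os.listdir order is arbitrary, so sorting keeps the seeded shuffle reproducible
    image_files = sorted(f for f in os.listdir(image_dir) if f.endswith('.tif'))
    
    random.shuffle(image_files)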
    
    total_size = len(image_files)
    train_size = int(total_size * train_ratio)
    val_size = int(total_size * val_ratio)

    train_files = image_files[:train_size]
    val_files = image_files[train_size:train_size + val_size]
    test_files = image_files[train_size + val_size:]
    
    def copy_files(file_list, split_name):
        print(f"\nCopying {split_name} files...")
        for filename in file_list:
            src_image = os.path.join(image_dir, filename)
            dst_image = os.path.join(output_dir, split_name, 'images', filename)
            shutil.copy2(src_image, dst_image)
            src_mask = os.path.join(mask_dir, filename)
            dst_mask = os.path.join(output_dir, split_name, 'masks', filename)
            shutil.copy2(src_mask, dst_mask)
    
    split_files = {
        'train': train_files,
        'val': val_files,
        'test': test_files
    }
    
    for split_name, files in split_files.items():
        copy_files(files, split_name)
        print(f"{split_name} split: {len(files)} images")

    return {
        'train_size': len(train_files),
        'val_size': len(val_files),
        'test_size': len(test_files)
    }



In [26]:
  
final_training_data_folder = os.path.join(output_folder,'final_training_data_big_extent')

stats = create_train_val_test_splits(
    image_dir=spot_patch_folder,
    mask_dir=landuse_patch_folder,
    output_dir=final_training_data_folder,
    train_ratio=0.70,
    val_ratio=0.10,
    test_ratio=0.20,
    random_seed=42
)

print("\nData split summary:")
print(f"Total images: {sum(stats.values())}")
for split_name, size in stats.items():
    print(f"{split_name}: {size} images")


Copying train files...
train split: 1380 images

Copying val files...
val split: 197 images

Copying test files...
test split: 395 images

Data split summary:
Total images: 1972
train_size: 1380 images
val_size: 197 images
test_size: 395 images
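
The split sizes printed above follow from the integer truncation in `create_train_val_test_splits`: the train and validation counts are rounded down, and the test set receives whatever remains. The short sketch below reproduces that arithmetic for the 1972 patches; `split_sizes` is an illustrative helper name.

```python
def split_sizes(total, train_ratio=0.70, val_ratio=0.10):
    """Reproduce the split arithmetic: truncate train and val, give the rest to test."""
    train = int(total * train_ratio)
    val = int(total * val_ratio)
    test = total - train - val
    return train, val, test

sizes = split_sizes(1972)
print(sizes)  # (1380, 197, 395)
```

The test share therefore ends up marginally above 20% here (395 of 1972), because it absorbs the fractions dropped from the other two splits.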


In [1]:
!pwd

/home/jovyan/shared/megharaj/prithvi_model/code_and_data


## Calculate the per-channel mean and standard deviation of pixel values for the training patches

Compute the mean and standard deviation for each channel (e.g., RGB or multispectral bands) across all images in the dataset.

In [None]:
import os
import glob
import numpy as np
import rasterio

def calculate_dataset_statistics(image_paths):
    """Return per-channel means and stds, averaged over the per-image statistics."""
    means = []
    stds = []
    
    for path in image_paths:
        with rasterio.open(path) as src:
            img = src.read()  # Shape: (channels, height, width)
            means.append(img.mean(axis=(1, 2)))
            stds.append(img.std(axis=(1, 2)))
    
    channel_means = np.mean(means, axis=0)
    # Note: averaging per-image stds approximates, but does not equal, the std over all pixels
    channel_stds = np.mean(stds, axis=0)
    
    return channel_means, channel_stds


# Using `import glob` (not `from glob import glob`) avoids shadowing the glob module used by process_and_stack_images
image_paths = glob.glob(os.path.join(final_training_data_folder, 'train', 'images', '*.tif'))

means, stds = calculate_dataset_statistics(image_paths)
means_list = ", ".join(f"{mean:.5f}" for mean in means)
stds_list = ", ".join(f"{std:.5f}" for std in stds)

print("Channel means:", means_list)
print("Channel stds:", stds_list)
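
`calculate_dataset_statistics` averages the per-image standard deviations, which can underestimate the spread when image means differ from one another. If the exact standard deviation over all pixels is preferred, one possible alternative is to accumulate sums and sums of squares per channel. The sketch below shows the idea on in-memory arrays; `pooled_channel_stats` is a hypothetical helper, it assumes every image has the same number of channels, and it does not handle nodata values.

```python
import numpy as np

def pooled_channel_stats(images):
    """Exact per-channel mean and std over all pixels of all images."""
    n_pixels = 0
    total = None
    total_sq = None
    for img in images:  # Each img has shape (channels, height, width)
        img = img.astype(np.float64)
        s = img.sum(axis=(1, 2))
        sq = (img ** 2).sum(axis=(1, 2))
        total = s if total is None else total + s
        total_sq = sq if total_sq is None else total_sq + sq
        n_pixels += img.shape[1] * img.shape[2]
    mean = total / n_pixels
    # Var[x] = E[x^2] - E[x]^2; clip tiny negative values caused by rounding
    std = np.sqrt(np.maximum(total_sq / n_pixels - mean ** 2, 0.0))
    return mean, std

# Two constant toy images: each has std 0, but the pooled std is 5
img_a = np.zeros((1, 2, 2))
img_b = np.full((1, 2, 2), 10.0)

pooled_mean, pooled_std = pooled_channel_stats([img_a, img_b])
mean_of_stds = np.mean([img_a.std(axis=(1, 2)), img_b.std(axis=(1, 2))], axis=0)
print("Pooled mean:", pooled_mean, "pooled std:", pooled_std)
print("Mean of per-image stds:", mean_of_stds)
```

In this toy case the averaged per-image value is 0 while the pooled value is 5, which illustrates how far the two can diverge when patches differ strongly in brightness.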