In [1]:
import os
import rasterio
import numpy as np
import cv2

### Pansharpening

In [None]:
def check_band_coverage(directory):
    """Print a table of resolution, shape and non-zero pixel share per band file.

    Every ``.jp2`` file in ``directory`` whose name does not contain ``_TCI``
    is opened in sorted filename order. Band 1 is read in full and the
    percentage of non-zero pixels is reported as "valid".

    Args:
        directory: Folder containing the band ``.jp2`` files.
    """
    print(f"\n{'Band':<6} {'Resolution':<10} {'Shape':<15} {'Valid Pixels (%)':<18}")
    print("-" * 60)

    band_files = [
        name for name in sorted(os.listdir(directory))
        if name.endswith('.jp2') and '_TCI' not in name
    ]

    for name in band_files:
        # The band code is the last underscore-separated token of the file name.
        code = name.split('_')[-1].replace('.jp2', '')

        with rasterio.open(os.path.join(directory, name)) as dataset:
            pixels = dataset.read(1)
            nonzero_share = 100 * np.count_nonzero(pixels) / pixels.size

            print(f"{code:<6} {dataset.res[0]:<10.1f} {pixels.shape!s:<15} {nonzero_share:<18.2f}")


# Print the coverage table for each of the four Sentinel-2 L1C granules.
for granule_dir in (
    "data/S2A_MSIL1C_20210925T092031_N0500_R093_T34SEJ_20230118T233535.SAFE/GRANULE/L1C_T34SEJ_A032694_20210925T092343/IMG_DATA",
    "data/S2A_MSIL1C_20210925T092031_N0500_R093_T34SFJ_20230118T233535.SAFE/GRANULE/L1C_T34SFJ_A032694_20210925T092343/IMG_DATA",
    "data/S2A_MSIL1C_20210925T092031_N0500_R093_T34TEK_20230118T233535.SAFE/GRANULE/L1C_T34TEK_A032694_20210925T092343/IMG_DATA",
    "data/S2A_MSIL1C_20210925T092031_N0500_R093_T34TFK_20230118T233535.SAFE/GRANULE/L1C_T34TFK_A032694_20210925T092343/IMG_DATA",
):
    check_band_coverage(granule_dir)


Band   Resolution Shape           Valid Pixels (%)  
------------------------------------------------------------
B01    60.0       (1830, 1830)    100.00            
B02    10.0       (10980, 10980)  100.00            


KeyboardInterrupt: 

In [2]:
# Sentinel-2 band codes grouped by native ground resolution.
BANDS_10M = ['B02', 'B03', 'B04', 'B08']
BANDS_20M = ['B05', 'B06', 'B07', 'B8A', 'B11', 'B12']
BANDS_60M = ['B01', 'B09', 'B10']
ALL_BANDS = BANDS_10M + BANDS_20M + BANDS_60M
# (height, width) every band is brought to before being written.
TARGET_SHAPE = (10980, 10980)

def pansharpen_to_10m_and_save(directory, output_tiff="pansharpened.tif"):
    """Stack all 13 Sentinel-2 bands of one granule into a single GeoTIFF.

    Despite the name, no pan-sharpening is performed. Bands listed in
    ``BANDS_10M`` are written unchanged; every other band is upsampled to
    ``TARGET_SHAPE`` with bicubic interpolation (``cv2.resize``). CRS,
    transform and dtype are copied from the first 10 m band.

    Args:
        directory: Folder holding the per-band ``.jp2`` files. The band code
            is taken from the last underscore-separated token of each name;
            ``_TCI`` files are ignored.
        output_tiff: Path of the multi-band GeoTIFF to create.

    Raises:
        ValueError: If any band in ``ALL_BANDS`` has no file in ``directory``.
    """
    band_paths = {}
    for fname in os.listdir(directory):
        if fname.endswith('.jp2') and '_TCI' not in fname:
            band_code = fname.split('_')[-1].replace('.jp2', '')
            if band_code in ALL_BANDS:
                band_paths[band_code] = os.path.join(directory, fname)

    missing = [b for b in ALL_BANDS if b not in band_paths]
    if missing:
        raise ValueError(f"Missing bands: {missing}")

    # Output band order (by band number, not by resolution group).
    ordered_bands = ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12']
    print(ordered_bands)

    # Use metadata from a 10m band as reference
    with rasterio.open(band_paths[BANDS_10M[0]]) as ref_src:
        reference_meta = ref_src.meta.copy()
    reference_meta.update({
        # Derived from the list that is actually written, so the band count
        # cannot drift from the write loop below.
        "count": len(ordered_bands),
        "height": TARGET_SHAPE[0],
        "width": TARGET_SHAPE[1],
        "driver": "GTiff"
    })

    # Create output TIFF and write one band at a time to limit memory use
    with rasterio.open(output_tiff, "w", **reference_meta) as dst:
        for idx, band in enumerate(ordered_bands, start=1):
            print(idx, band)
            with rasterio.open(band_paths[band]) as src:
                img = src.read(1)

            if band in BANDS_10M:
                resampled = img  # No resizing needed
            else:
                # cv2 expects the target size as (width, height).
                resampled = cv2.resize(
                    img,
                    (TARGET_SHAPE[1], TARGET_SHAPE[0]),
                    interpolation=cv2.INTER_CUBIC
                )

            dst.write(resampled, idx)

    print(f"\nSaved pansharpened image to: {output_tiff}")


In [3]:
# Build one stacked 13-band GeoTIFF per granule: (band folder, output file).
for img_data_dir, stacked_tiff in (
    (
        "data/S2A_MSIL1C_20210925T092031_N0500_R093_T34SEJ_20230118T233535.SAFE/GRANULE/L1C_T34SEJ_A032694_20210925T092343/IMG_DATA",
        "data/T34SEJ_pansharpened.tif",
    ),
    (
        "data/S2A_MSIL1C_20210925T092031_N0500_R093_T34SFJ_20230118T233535.SAFE/GRANULE/L1C_T34SFJ_A032694_20210925T092343/IMG_DATA",
        "data/T34SFJ_pansharpened.tif",
    ),
    (
        "data/S2A_MSIL1C_20210925T092031_N0500_R093_T34TEK_20230118T233535.SAFE/GRANULE/L1C_T34TEK_A032694_20210925T092343/IMG_DATA",
        "data/T34TEK_pansharpened.tif",
    ),
    (
        "data/S2A_MSIL1C_20210925T092031_N0500_R093_T34TFK_20230118T233535.SAFE/GRANULE/L1C_T34TFK_A032694_20210925T092343/IMG_DATA",
        "data/T34TFK_pansharpened.tif",
    ),
):
    pansharpen_to_10m_and_save(img_data_dir, output_tiff=stacked_tiff)


['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12']
1 B01
2 B02
3 B03
4 B04
5 B05
6 B06
7 B07
8 B08
9 B8A
10 B09
11 B10
12 B11
13 B12

Saved pansharpened image to: data/T34SEJ_pansharpened.tif
['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12']
1 B01
2 B02
3 B03
4 B04
5 B05
6 B06
7 B07
8 B08
9 B8A
10 B09
11 B10
12 B11
13 B12

Saved pansharpened image to: data/T34SFJ_pansharpened.tif
['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12']
1 B01
2 B02
3 B03
4 B04
5 B05
6 B06
7 B07
8 B08
9 B8A
10 B09
11 B10
12 B11
13 B12

Saved pansharpened image to: data/T34TEK_pansharpened.tif
['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12']
1 B01
2 B02
3 B03
4 B04
5 B05
6 B06
7 B07
8 B08
9 B8A
10 B09
11 B10
12 B11
13 B12

Saved pansharpened image to: data/T34TFK_pansharpened.tif


In [3]:
def read_pansharpened_tiff(tiff_path):
    """Read every band of a raster, print the array shape, and return the array.

    Args:
        tiff_path: Path of the raster to open with rasterio.

    Returns:
        The array returned by ``read()`` on the opened dataset.
    """
    with rasterio.open(tiff_path) as dataset:
        bands = dataset.read()  # indexed as (bands, height, width)
    print(f"Stack shape: {bands.shape}")
    return bands

# Sanity check: re-open each stacked tile and print its array shape.
for stacked_path in (
    "data/T34SEJ_pansharpened.tif",
    "data/T34SFJ_pansharpened.tif",
    "data/T34TEK_pansharpened.tif",
    "data/T34TFK_pansharpened.tif",
):
    _ = read_pansharpened_tiff(stacked_path)

Stack shape: (13, 10980, 10980)
Stack shape: (13, 10980, 10980)
Stack shape: (13, 10980, 10980)
Stack shape: (13, 10980, 10980)


### Alignment

In [6]:
import rasterio
from rasterio.plot import show
from rasterio.coords import BoundingBox

def get_bounds_and_crs(tif_path):
    """Return the ``(bounds, crs)`` pair of the raster stored at ``tif_path``."""
    with rasterio.open(tif_path) as dataset:
        extent = dataset.bounds
        reference_system = dataset.crs
    return extent, reference_system

from rasterio.warp import transform_bounds

# Paths to your pansharpened tiles
tile_paths = [
    "data/T34SEJ_pansharpened.tif",
    "data/T34SFJ_pansharpened.tif",
    "data/T34TEK_pansharpened.tif",
    "data/T34TFK_pansharpened.tif",
]

# Path to your ground truth .tif file
ground_truth_path = "data/GBDA24_ex2_ref_data.tif"

# Get bounds and CRS of the ground truth
gt_bounds, gt_crs = get_bounds_and_crs(ground_truth_path)

print(f"Ground truth bounds:\n{gt_bounds}")
print(f"Ground truth CRS: {gt_crs}\n")

# Check coverage for each tile
for path in tile_paths:
    tile_bounds, tile_crs = get_bounds_and_crs(path)

    if tile_crs != gt_crs:
        print(f"WARNING: CRS mismatch between {path} and ground truth!")
        # Bounds in different coordinate systems (e.g. UTM metres vs. degrees)
        # cannot be compared directly, so express the tile extent in the
        # ground-truth CRS before testing containment.
        tile_bounds = BoundingBox(*transform_bounds(tile_crs, gt_crs, *tile_bounds))

    # Check if tile is fully within ground truth
    covered = (
        tile_bounds.left   >= gt_bounds.left and
        tile_bounds.right  <= gt_bounds.right and
        tile_bounds.bottom >= gt_bounds.bottom and
        tile_bounds.top    <= gt_bounds.top
    )

    status = "✅ Covered" if covered else "❌ Not Covered"
    print(f"{path}: {status}")


Ground truth bounds:
BoundingBox(left=21.331833333333332, bottom=39.12975, right=23.117166666666666, top=40.26075)
Ground truth CRS: EPSG:4326

data/T34SEJ_pansharpened.tif: ❌ Not Covered
data/T34SFJ_pansharpened.tif: ❌ Not Covered
data/T34TEK_pansharpened.tif: ❌ Not Covered
data/T34TFK_pansharpened.tif: ❌ Not Covered


In [4]:
import rasterio

def get_tiff_info(filepath):
    """Collect basic georeferencing facts about a raster file.

    Args:
        filepath: Path of the raster to open with rasterio.

    Returns:
        dict with the keys ``size`` (width, height), ``crs``, ``gsd``
        (transform element 0 and the negated transform element 4),
        ``bounds``, ``transform`` and ``band_count``.
    """
    with rasterio.open(filepath) as dataset:
        affine = dataset.transform
        summary = {
            'size': (dataset.width, dataset.height),
            'crs': dataset.crs,
            # Element 4 is negated so the reported pixel height is positive
            # for a north-up raster.
            'gsd': (affine[0], -affine[4]),
            'bounds': dataset.bounds,
            'transform': affine,
            'band_count': dataset.count,
        }
    return summary

# Inputs: the ground-truth raster and the four stacked Sentinel-2 tiles.
ground_truth_path = "data/GBDA24_ex2_ref_data.tif"
sentinel_paths = [
    "data/T34SEJ_pansharpened.tif",
    "data/T34SFJ_pansharpened.tif",
    "data/T34TEK_pansharpened.tif",
    "data/T34TFK_pansharpened.tif"
]

# Gather the georeferencing summary of every file, then print them.
ground_truth_info = get_tiff_info(ground_truth_path)
sentinel_infos = list(map(get_tiff_info, sentinel_paths))

print("Ground Truth Info:", ground_truth_info)
for tile_path, tile_info in zip(sentinel_paths, sentinel_infos):
    print(f"Sentinel-2 Info for file {tile_path}:", tile_info)


Ground Truth Info: {'size': (21424, 13572), 'crs': CRS.from_wkt('GEOGCS["WGS 84",DATUM["WGS_1984",SPHEROID["WGS 84",6378137,298.257223563,AUTHORITY["EPSG","7030"]],AUTHORITY["EPSG","6326"]],PRIMEM["Greenwich",0,AUTHORITY["EPSG","8901"]],UNIT["degree",0.0174532925199433,AUTHORITY["EPSG","9122"]],AXIS["Latitude",NORTH],AXIS["Longitude",EAST],AUTHORITY["EPSG","4326"]]'), 'gsd': (8.333333333333333e-05, 8.333333333333333e-05), 'bounds': BoundingBox(left=21.331833333333332, bottom=39.12975, right=23.117166666666666, top=40.26075), 'transform': Affine(8.333333333333333e-05, 0.0, 21.331833333333332,
       0.0, -8.333333333333333e-05, 40.26075), 'band_count': 1}
Sentinel-2 Info for file data/T34SEJ_pansharpened.tif: {'size': (10980, 10980), 'crs': CRS.from_wkt('PROJCS["WGS 84 / UTM zone 34N",GEOGCS["WGS 84",DATUM["WGS_1984",SPHEROID["WGS 84",6378137,298.257223563,AUTHORITY["EPSG","7030"]],AUTHORITY["EPSG","6326"]],PRIMEM["Greenwich",0,AUTHORITY["EPSG","8901"]],UNIT["degree",0.0174532925199433,

In [14]:
from rasterio.warp import calculate_default_transform, reproject, Resampling
import rasterio
from rasterio.crs import CRS

def reproject_raster(src_path, dst_path, dst_crs):
    """Write a copy of a raster warped into another coordinate reference system.

    The output grid (transform, width, height) is whatever
    ``calculate_default_transform`` proposes for the source extent; each band
    is resampled with bilinear interpolation.

    Args:
        src_path: Raster to read.
        dst_path: File to create; it inherits the source metadata except for
            CRS, transform, width and height.
        dst_crs: Target coordinate reference system.
    """
    with rasterio.open(src_path) as source:
        dst_transform, dst_width, dst_height = calculate_default_transform(
            source.crs, dst_crs, source.width, source.height, *source.bounds
        )
        profile = {
            **source.meta,
            'crs': dst_crs,
            'transform': dst_transform,
            'width': dst_width,
            'height': dst_height,
        }

        with rasterio.open(dst_path, 'w', **profile) as target:
            # Warp band by band so only one band is processed at a time.
            for band_index in range(1, source.count + 1):
                reproject(
                    source=rasterio.band(source, band_index),
                    destination=rasterio.band(target, band_index),
                    src_transform=source.transform,
                    src_crs=source.crs,
                    dst_transform=dst_transform,
                    dst_crs=dst_crs,
                    resampling=Resampling.bilinear
                )

# Write an EPSG:4326 copy of every stacked tile next to the original file.
target_crs = CRS.from_epsg(4326)
for path in sentinel_paths:
    output_path = path.replace(".tif", "_reprojected_to_4326.tif")
    reproject_raster(path, output_path, target_crs)
    print(f"Reprojected Sentinel-2 file saved as {output_path}")


Reprojected Sentinel-2 file saved as data/T34SEJ_pansharpened_reprojected_to_4326.tif
Reprojected Sentinel-2 file saved as data/T34SFJ_pansharpened_reprojected_to_4326.tif
Reprojected Sentinel-2 file saved as data/T34TEK_pansharpened_reprojected_to_4326.tif
Reprojected Sentinel-2 file saved as data/T34TFK_pansharpened_reprojected_to_4326.tif


In [8]:
from rasterio.mask import mask

def clip_to_bounds(src_path, bounds, output_path):
    """Crop a raster to a rectangle and save every band of the result as GeoTIFF.

    Args:
        src_path: Raster to read.
        bounds: ``(left, bottom, right, top)`` of the rectangle, in the
            raster's own coordinates.
        output_path: File to create with the cropped bands.
    """
    left, bottom, right, top = bounds
    # Closed ring: the first corner is repeated at the end.
    ring = [
        (left, bottom),
        (right, bottom),
        (right, top),
        (left, top),
        (left, bottom),
    ]
    rectangle = {'type': 'Polygon', 'coordinates': [ring]}

    with rasterio.open(src_path) as source:
        clipped, clipped_transform = mask(source, [rectangle], crop=True)
        band_count, clipped_height, clipped_width = clipped.shape

        profile = source.meta.copy()
        profile.update(
            driver="GTiff",
            count=band_count,
            crs=source.crs,
            transform=clipped_transform,
            height=clipped_height,
            width=clipped_width,
        )

        with rasterio.open(output_path, 'w', **profile) as dest:
            # Band indexes in the output file are 1-based.
            for band_number, band_pixels in enumerate(clipped, start=1):
                dest.write(band_pixels, band_number)

# Crop every reprojected tile to the extent of the ground-truth raster.
gt_extent = ground_truth_info['bounds']
for tile_path in sentinel_paths:
    reprojected_path = tile_path.replace(".tif", "_reprojected_to_4326.tif")
    clipped_path = reprojected_path.replace(".tif", "_clipped.tif")
    print(f"Processing {reprojected_path}")
    clip_to_bounds(reprojected_path, gt_extent, clipped_path)
    print(f"Clipped Sentinel-2 file saved as {clipped_path}")


Processing data/T34SEJ_pansharpened_reprojected_to_4326.tif
Clipped Sentinel-2 file saved as data/T34SEJ_pansharpened_reprojected_to_4326_clipped.tif
Processing data/T34SFJ_pansharpened_reprojected_to_4326.tif
Clipped Sentinel-2 file saved as data/T34SFJ_pansharpened_reprojected_to_4326_clipped.tif
Processing data/T34TEK_pansharpened_reprojected_to_4326.tif
Clipped Sentinel-2 file saved as data/T34TEK_pansharpened_reprojected_to_4326_clipped.tif
Processing data/T34TFK_pansharpened_reprojected_to_4326.tif
Clipped Sentinel-2 file saved as data/T34TFK_pansharpened_reprojected_to_4326_clipped.tif


In [None]:
import rasterio
from rasterio.warp import calculate_default_transform, reproject, Resampling
from rasterio.mask import mask
from shapely.geometry import box
import geopandas as gpd

def clip_to_ground_truth(src_path, ground_truth_bounds, output_path):
    """Crop a raster to the ground-truth extent and save all bands as GeoTIFF.

    Args:
        src_path: Raster to read.
        ground_truth_bounds: ``(min_x, min_y, max_x, max_y)`` of the crop
            rectangle, in the raster's own coordinates. The bounds are used
            exactly as given.
        output_path: File to create with the cropped bands.
    """
    with rasterio.open(src_path) as src:
        # Unpack explicitly so anything other than four values fails here.
        min_x, min_y, max_x, max_y = ground_truth_bounds

        # Clip the image
        geo = box(min_x, min_y, max_x, max_y)
        out_image, out_transform = mask(src, [geo], crop=True)

        # Adjust metadata
        out_meta = src.meta.copy()
        out_meta.update({
            "driver": "GTiff",
            # Must equal the number of bands written below; a fixed 1 makes
            # the write fail for multi-band input.
            "count": out_image.shape[0],
            "crs": src.crs,
            "transform": out_transform,
            "width": out_image.shape[2],
            "height": out_image.shape[1],
        })

        # Save the clipped image
        with rasterio.open(output_path, 'w', **out_meta) as dest:
            dest.write(out_image)

# Tiles already cropped by the previous step.
clipped_tiles_paths = [
    "data/T34SEJ_pansharpened_reprojected_to_4326_clipped.tif",
    "data/T34SFJ_pansharpened_reprojected_to_4326_clipped.tif",
    "data/T34TEK_pansharpened_reprojected_to_4326_clipped.tif",
    "data/T34TFK_pansharpened_reprojected_to_4326_clipped.tif",
]

# Crop each of them once more against the ground-truth bounds.
for clipped_tile in clipped_tiles_paths:
    final_path = clipped_tile.replace(".tif", "_clipped_to_ground_truth.tif")
    clip_to_ground_truth(clipped_tile, ground_truth_info['bounds'], final_path)
    print(f"Clipped Sentinel-2 file saved as {final_path}")


In [None]:
def check_alignment(ground_truth_bounds, clipped_image_path, tolerance=0.0001):
    """Print whether a raster's bounds match the ground-truth bounds.

    Both bounding boxes are printed, followed by a verdict. The boxes count
    as aligned when each of the four edges differs by at most ``tolerance``.

    Args:
        ground_truth_bounds: Indexable ``(left, bottom, right, top)``.
        clipped_image_path: Raster whose bounds are compared.
        tolerance: Largest accepted absolute difference per edge, in the
            units of the bounds (0.0001 by default).
    """
    with rasterio.open(clipped_image_path) as src:
        clipped_bounds = src.bounds  # Get bounds of the clipped Sentinel-2 image

    print(f"Ground Truth Bounds: {ground_truth_bounds}")
    print(f"Clipped Image Bounds: {clipped_bounds}")

    # Compare edge by edge: left, bottom, right, top.
    aligned = all(
        abs(ground_truth_bounds[i] - clipped_bounds[i]) <= tolerance
        for i in range(4)
    )

    if aligned:
        print("Ground truth and clipped image are aligned.")
    else:
        print("Ground truth and clipped image are NOT aligned.")


In [18]:
# Compare the bounds of each reprojected (not yet clipped) tile with the ground truth.
for reprojected_tile in (
    "data/T34SEJ_pansharpened_reprojected_to_4326.tif",
    "data/T34SFJ_pansharpened_reprojected_to_4326.tif",
    "data/T34TEK_pansharpened_reprojected_to_4326.tif",
    "data/T34TFK_pansharpened_reprojected_to_4326.tif",
):
    check_alignment(ground_truth_info['bounds'], reprojected_tile)


Ground Truth Bounds: BoundingBox(left=21.331833333333332, bottom=39.12975, right=23.117166666666666, top=40.26075)
Clipped Image Bounds: BoundingBox(left=20.99976654562274, bottom=38.75400437503255, right=22.281313387914107, top=39.75026792656398)
Ground truth and clipped image are NOT aligned.
Ground Truth Bounds: BoundingBox(left=21.331833333333332, bottom=39.12975, right=23.117166666666666, top=40.26075)
Clipped Image Bounds: BoundingBox(left=22.150872219612868, bottom=38.73595594721768, right=23.44792981139987, top=39.74439989605869)
Ground truth and clipped image are NOT aligned.
Ground Truth Bounds: BoundingBox(left=21.331833333333332, bottom=39.12975, right=23.117166666666666, top=40.26075)
Clipped Image Bounds: BoundingBox(left=20.99976343610849, bottom=39.65458947552921, right=22.29831573625469, top=40.650856515329274)
Ground truth and clipped image are NOT aligned.
Ground Truth Bounds: BoundingBox(left=21.331833333333332, bottom=39.12975, right=23.117166666666666, top=40.2607

In [17]:
# Compare the bounds of each clipped tile with the ground truth.
for clipped_tile in (
    "data/T34SEJ_pansharpened_reprojected_to_4326_clipped.tif",
    "data/T34SFJ_pansharpened_reprojected_to_4326_clipped.tif",
    "data/T34TEK_pansharpened_reprojected_to_4326_clipped.tif",
    "data/T34TFK_pansharpened_reprojected_to_4326_clipped.tif",
):
    check_alignment(ground_truth_info['bounds'], clipped_tile)


Ground Truth Bounds: BoundingBox(left=21.331833333333332, bottom=39.12975, right=23.117166666666666, top=40.26075)
Clipped Image Bounds: BoundingBox(left=21.331785311912185, bottom=39.129649826904966, right=22.28131338791411, top=39.75026792656398)
Ground truth and clipped image are NOT aligned.
Ground Truth Bounds: BoundingBox(left=21.331833333333332, bottom=39.12975, right=23.117166666666666, top=40.26075)
Clipped Image Bounds: BoundingBox(left=22.150872219612868, bottom=39.129688977827016, right=23.117185804832417, top=39.74439989605869)
Ground truth and clipped image are NOT aligned.
Ground Truth Bounds: BoundingBox(left=21.331833333333332, bottom=39.12975, right=23.117166666666666, top=40.26075)
Clipped Image Bounds: BoundingBox(left=21.331817631906816, bottom=39.6545894755292, right=22.29831573625469, top=40.26083123470299)
Ground truth and clipped image are NOT aligned.
Ground Truth Bounds: BoundingBox(left=21.331833333333332, bottom=39.12975, right=23.117166666666666, top=40.26

In [19]:
# Compare the bounds of each twice-clipped tile with the ground truth.
for final_tile in (
    "data/T34SEJ_pansharpened_reprojected_to_4326_clipped_clipped_to_ground_truth.tif",
    "data/T34SFJ_pansharpened_reprojected_to_4326_clipped_clipped_to_ground_truth.tif",
    "data/T34TEK_pansharpened_reprojected_to_4326_clipped_clipped_to_ground_truth.tif",
    "data/T34TFK_pansharpened_reprojected_to_4326_clipped_clipped_to_ground_truth.tif",
):
    check_alignment(ground_truth_info['bounds'], final_tile)


Ground Truth Bounds: BoundingBox(left=21.331833333333332, bottom=39.12975, right=23.117166666666666, top=40.26075)
Clipped Image Bounds: BoundingBox(left=21.331785311912185, bottom=39.129649826904966, right=22.28131338791411, top=39.75026792656398)
Ground truth and clipped image are NOT aligned.
Ground Truth Bounds: BoundingBox(left=21.331833333333332, bottom=39.12975, right=23.117166666666666, top=40.26075)
Clipped Image Bounds: BoundingBox(left=22.150872219612868, bottom=39.129688977827016, right=23.117185804832417, top=39.74439989605869)
Ground truth and clipped image are NOT aligned.
Ground Truth Bounds: BoundingBox(left=21.331833333333332, bottom=39.12975, right=23.117166666666666, top=40.26075)
Clipped Image Bounds: BoundingBox(left=21.331817631906816, bottom=39.6545894755292, right=22.29831573625469, top=40.26083123470299)
Ground truth and clipped image are NOT aligned.
Ground Truth Bounds: BoundingBox(left=21.331833333333332, bottom=39.12975, right=23.117166666666666, top=40.26

In [20]:
import rasterio
import numpy as np
import os

# Function to load a TIFF file and convert it to numpy array
def load_tiff_as_array(tiff_path):
    """Open ``tiff_path`` with rasterio and return all of its bands as one array."""
    with rasterio.open(tiff_path) as dataset:
        pixels = dataset.read()
    return pixels

# Function to generate the dataset
def generate_dataset(tile_paths, ground_truth_path, output_dir):
    """Pair every satellite tile with the ground-truth raster and save both stacks.

    Each tile must have the same height and width as the ground truth. The
    band counts may differ (for example a 13-band image against a 1-band
    label raster). The same ground-truth array is stored once per tile so
    that both output arrays have equal length.

    Args:
        tile_paths: Paths of the satellite rasters.
        ground_truth_path: Path of the label raster.
        output_dir: Folder receiving ``satellite_images.npy`` and
            ``ground_truth_images.npy``; created if it does not exist.

    Raises:
        ValueError: If a tile's height/width differs from the ground truth's.
    """
    # Load the ground truth image
    ground_truth = load_tiff_as_array(ground_truth_path)

    # Initialize lists to store pairs of satellite images and ground truth images
    satellite_images = []
    ground_truth_images = []

    # Loop through each satellite tile
    for tile_path in tile_paths:
        print(f"Processing {tile_path}...")

        # Load the satellite image tile as numpy array
        satellite_image = load_tiff_as_array(tile_path)

        # Only the spatial dimensions have to agree; axis 0 is the band axis
        # and legitimately differs between imagery and labels.
        if satellite_image.shape[1:] != ground_truth.shape[1:]:
            raise ValueError(f"Shape mismatch: Satellite image shape {satellite_image.shape} "
                             f"does not match ground truth shape {ground_truth.shape}")

        # Append the images to their respective lists
        satellite_images.append(satellite_image)
        ground_truth_images.append(ground_truth)

    # Convert the lists of images to numpy arrays
    satellite_images_np = np.array(satellite_images)
    ground_truth_images_np = np.array(ground_truth_images)

    # Save the dataset as numpy arrays
    os.makedirs(output_dir, exist_ok=True)
    np.save(os.path.join(output_dir, 'satellite_images.npy'), satellite_images_np)
    np.save(os.path.join(output_dir, 'ground_truth_images.npy'), ground_truth_images_np)

    print(f"Dataset saved: {output_dir}/satellite_images.npy and {output_dir}/ground_truth_images.npy")

# Inputs and output location for the image/label dataset.
ground_truth_path = "data/GBDA24_ex2_ref_data.tif"
output_dir = "generated_dataset"
tile_paths = [
    "data/T34SEJ_pansharpened_reprojected_to_4326_clipped_clipped_to_ground_truth.tif",
    "data/T34SFJ_pansharpened_reprojected_to_4326_clipped_clipped_to_ground_truth.tif",
    "data/T34TEK_pansharpened_reprojected_to_4326_clipped_clipped_to_ground_truth.tif",
    "data/T34TFK_pansharpened_reprojected_to_4326_clipped_clipped_to_ground_truth.tif",
]

# The output folder has to exist before the arrays are written into it.
os.makedirs(output_dir, exist_ok=True)

# Build the dataset and write it to disk.
generate_dataset(tile_paths, ground_truth_path, output_dir)


Processing data/T34SEJ_pansharpened_reprojected_to_4326_clipped_clipped_to_ground_truth.tif...


ValueError: Shape mismatch: Satellite image shape (13, 5989, 9163) does not match ground truth shape (1, 13572, 21424)

In [None]:
import os
import rasterio
import numpy as np
from rasterio.windows import from_bounds
from glob import glob
import uuid

import shutil

tile_folder = "data/aligned/"           # Folder with your aligned tiles
reference_path = "data/reference_repoj.tif"         # Ground truth raster
output_dataset = "image_label_dataset.npz"  # Output file
temp_dir = "temp_batches"

os.makedirs(temp_dir, exist_ok=True)

tile_paths = glob(os.path.join(tile_folder, "*.tif"))
sample_idx = 0

batch_paths_X = []
batch_paths_y = []

try:
    with rasterio.open(reference_path) as ref_ds:
        # Only metadata is read up front. Label pixels are read per tile
        # through a window below, so the whole reference raster never has to
        # sit in memory.
        ref_nodata = ref_ds.nodata
        ref_bounds = ref_ds.bounds

        for tile_path in tile_paths:
            print(tile_path)
            with rasterio.open(tile_path) as tile_ds:
                tile_nodata = tile_ds.nodata

                # Get intersection bounds
                # NOTE(review): bounds are compared directly, which assumes
                # tile and reference share one CRS — confirm for new inputs.
                intersection_bounds = (
                    max(tile_ds.bounds.left, ref_bounds.left),
                    max(tile_ds.bounds.bottom, ref_bounds.bottom),
                    min(tile_ds.bounds.right, ref_bounds.right),
                    min(tile_ds.bounds.top, ref_bounds.top)
                )

                # Empty intersection: the tile does not overlap the reference.
                if intersection_bounds[0] >= intersection_bounds[2] or intersection_bounds[1] >= intersection_bounds[3]:
                    continue

                # Read overlapping window
                tile_window = from_bounds(*intersection_bounds, transform=tile_ds.transform)
                ref_window = from_bounds(*intersection_bounds, transform=ref_ds.transform)

                tile_data = tile_ds.read(window=tile_window)
                ref_crop = ref_ds.read(1, window=ref_window)

                # Trim both crops to their common shape (the two windows can
                # differ in size).
                H = min(tile_data.shape[1], ref_crop.shape[0])
                W = min(tile_data.shape[2], ref_crop.shape[1])
                tile_data = tile_data[:, :H, :W]
                ref_crop = ref_crop[:H, :W]

                # A pixel is kept only when every band and the label are not nodata.
                tile_mask = np.all(tile_data != tile_nodata, axis=0) if tile_nodata is not None else np.ones((H, W), dtype=bool)
                ref_mask = (ref_crop != ref_nodata) if ref_nodata is not None else np.ones((H, W), dtype=bool)
                valid_mask = tile_mask & ref_mask

                if np.count_nonzero(valid_mask) == 0:
                    continue

                pixels = tile_data[:, valid_mask].T  # (N, bands)
                labels = ref_crop[valid_mask]        # (N,)

                # Save batch
                batch_id = uuid.uuid4().hex
                X_path = os.path.join(temp_dir, f"X_{batch_id}.npy")
                y_path = os.path.join(temp_dir, f"y_{batch_id}.npy")
                np.save(X_path, pixels)
                np.save(y_path, labels)

                batch_paths_X.append(X_path)
                batch_paths_y.append(y_path)

                sample_idx += len(labels)
                print(f"Processed {tile_path}, saved {len(labels)} samples")

    # np.concatenate rejects an empty list with an unhelpful message, so fail
    # early with the actual cause.
    if not batch_paths_X:
        raise ValueError(
            f"No samples extracted: no tile in {tile_folder!r} overlaps "
            f"{reference_path!r} with valid pixels"
        )

    # Combine batches
    print("Merging all batches into final dataset...")
    X_all = np.concatenate([np.load(p) for p in batch_paths_X], axis=0)
    y_all = np.concatenate([np.load(p) for p in batch_paths_y], axis=0)

    np.savez_compressed(output_dataset, X=X_all, y=y_all)
    print(f"✅ Done. Saved final dataset: {output_dataset}")
    print(f"Total samples: {X_all.shape[0]}")
finally:
    # Clean up temp, including after a failure or an interrupted run.
    shutil.rmtree(temp_dir, ignore_errors=True)


data/aligned/T34TEK_pansharpened_aligned.tif
Processed data/aligned/T34TEK_pansharpened_aligned.tif, saved 89328181 samples
data/aligned/T34SFJ_pansharpened_aligned.tif


In [None]:
import os
import rasterio
import numpy as np
from rasterio.windows import from_bounds
from glob import glob
import uuid
import shutil

# CONFIG
tile_folder = "data/aligned/"           # Folder with your aligned tiles
reference_path = "data/reference_repoj.tif"         # Ground truth raster
output_npz = "image_label_dataset.npz"
temp_dir = "temp_patch_batches"
patch_size = 64
stride = 32  # overlapping windows

os.makedirs(temp_dir, exist_ok=True)

tile_paths = glob(os.path.join(tile_folder, "*.tif"))
X_paths, y_paths = [], []

try:
    with rasterio.open(reference_path) as ref_ds:
        ref_bounds = ref_ds.bounds
        ref_nodata = ref_ds.nodata
        ref_transform = ref_ds.transform

        for tile_path in tile_paths:
            print(tile_path)
            with rasterio.open(tile_path) as tile_ds:
                tile_nodata = tile_ds.nodata

                # Find overlap region
                # NOTE(review): bounds are compared directly, which assumes
                # tile and reference share one CRS — confirm for new inputs.
                intersection_bounds = (
                    max(tile_ds.bounds.left, ref_bounds.left),
                    max(tile_ds.bounds.bottom, ref_bounds.bottom),
                    min(tile_ds.bounds.right, ref_bounds.right),
                    min(tile_ds.bounds.top, ref_bounds.top)
                )

                # Empty intersection: the tile does not overlap the reference.
                if intersection_bounds[0] >= intersection_bounds[2] or intersection_bounds[1] >= intersection_bounds[3]:
                    continue

                # Read overlapping region
                tile_window = from_bounds(*intersection_bounds, transform=tile_ds.transform)
                ref_window = from_bounds(*intersection_bounds, transform=ref_ds.transform)

                tile_data = tile_ds.read(window=tile_window)
                ref_crop = ref_ds.read(1, window=ref_window)

                # Trim both crops to their common shape (the two windows can
                # differ in size).
                H = min(tile_data.shape[1], ref_crop.shape[0])
                W = min(tile_data.shape[2], ref_crop.shape[1])
                tile_data = tile_data[:, :H, :W]
                ref_crop = ref_crop[:H, :W]

                # Slide patch window
                for row in range(0, H - patch_size + 1, stride):
                    for col in range(0, W - patch_size + 1, stride):
                        image_patch = tile_data[:, row:row + patch_size, col:col + patch_size]
                        label_patch = ref_crop[row:row + patch_size, col:col + patch_size]

                        if tile_nodata is not None:
                            valid_image = np.all(image_patch != tile_nodata, axis=0)
                        else:
                            valid_image = np.ones((patch_size, patch_size), dtype=bool)

                        if ref_nodata is not None:
                            valid_label = (label_patch != ref_nodata)
                        else:
                            valid_label = np.ones((patch_size, patch_size), dtype=bool)

                        # Keep a patch only when every pixel in it is valid.
                        if np.all(valid_image & valid_label):
                            # Save patch
                            patch_id = uuid.uuid4().hex
                            X_path = os.path.join(temp_dir, f"X_{patch_id}.npy")
                            y_path = os.path.join(temp_dir, f"y_{patch_id}.npy")
                            np.save(X_path, image_patch.astype(np.float32))
                            np.save(y_path, label_patch.astype(np.int16))
                            X_paths.append(X_path)
                            y_paths.append(y_path)

    # np.stack rejects an empty list with an unhelpful message, so fail early
    # with the actual cause.
    if not X_paths:
        raise ValueError(
            f"No fully valid {patch_size}x{patch_size} patch found: check that the tiles in "
            f"{tile_folder!r} overlap {reference_path!r} and the nodata values"
        )

    # Merge all patches
    print("\n🔁 Merging patches into final dataset...")
    X_all = np.stack([np.load(p) for p in X_paths])
    y_all = np.stack([np.load(p) for p in y_paths])
    np.savez_compressed(output_npz, X=X_all, y=y_all)
    print(f"✅ Done. Saved {X_all.shape[0]} patches of size {patch_size}x{patch_size}")
finally:
    # Clean up temp files, including after a failure or an interrupted run.
    shutil.rmtree(temp_dir, ignore_errors=True)


data/aligned/T34TEK_pansharpened_aligned.tif
data/aligned/T34SFJ_pansharpened_aligned.tif
data/aligned/T34SEJ_pansharpened_aligned.tif
data/aligned/T34TFK_pansharpened_aligned.tif

🔁 Merging patches into final dataset...


In [1]:
import os
import rasterio
import numpy as np
from rasterio.windows import from_bounds
from glob import glob
import uuid
import shutil
import random

# CONFIG
tile_folder = "data/aligned/"           # Folder with your aligned tiles
reference_path = "data/reference_repoj.tif"  # Ground truth raster
output_npz = "image_label_dataset.npz"
temp_dir = "temp_patch_batches"
patch_size = 128
stride = 32  # overlapping windows
MAX_PATCHES = 1000  # Cap on number of patches

# Extract up to MAX_PATCHES fully-valid (image, label) patch pairs from the
# overlap between each tile and the reference raster, stage them as .npy files
# in temp_dir, then merge them into a single compressed .npz archive.
os.makedirs(temp_dir, exist_ok=True)

tile_paths = glob(os.path.join(tile_folder, "*.tif"))
random.shuffle(tile_paths)  # optional: randomize tile order
X_paths, y_paths = [], []
patch_counter = 0

try:
    with rasterio.open(reference_path) as ref_ds:
        ref_bounds = ref_ds.bounds
        ref_nodata = ref_ds.nodata
        ref_transform = ref_ds.transform

        for tile_path in tile_paths:
            # Stop before opening (and fully reading) further tiles once the cap is hit.
            if patch_counter >= MAX_PATCHES:
                break

            print(tile_path)
            with rasterio.open(tile_path) as tile_ds:
                tile_nodata = tile_ds.nodata

                # Find overlap region
                intersection_bounds = (
                    max(tile_ds.bounds.left, ref_bounds.left),
                    max(tile_ds.bounds.bottom, ref_bounds.bottom),
                    min(tile_ds.bounds.right, ref_bounds.right),
                    min(tile_ds.bounds.top, ref_bounds.top)
                )

                # Empty intersection: the tile does not overlap the reference at all.
                if intersection_bounds[0] >= intersection_bounds[2] or intersection_bounds[1] >= intersection_bounds[3]:
                    continue

                # Read overlapping region
                # NOTE(review): assumes tile and reference share CRS and pixel grid,
                # so that the two windows cover the same pixels — TODO confirm.
                tile_window = from_bounds(*intersection_bounds, transform=tile_ds.transform)
                ref_window = from_bounds(*intersection_bounds, transform=ref_transform)

                tile_data = tile_ds.read(window=tile_window)
                ref_crop = ref_ds.read(1, window=ref_window)

                # Crop both arrays to their common height/width
                H = min(tile_data.shape[1], ref_crop.shape[0])
                W = min(tile_data.shape[2], ref_crop.shape[1])
                tile_data = tile_data[:, :H, :W]
                ref_crop = ref_crop[:H, :W]

                # Slide patch window
                for row in range(0, H - patch_size + 1, stride):
                    # The inner `break` only leaves the column loop, so the cap
                    # must be re-checked here to stop scanning remaining rows.
                    if patch_counter >= MAX_PATCHES:
                        break

                    for col in range(0, W - patch_size + 1, stride):
                        if patch_counter >= MAX_PATCHES:
                            break

                        image_patch = tile_data[:, row:row + patch_size, col:col + patch_size]
                        label_patch = ref_crop[row:row + patch_size, col:col + patch_size]

                        # Keep a patch only if no pixel (in any band) equals the
                        # declared nodata value; with no declared nodata every
                        # pixel counts as valid.
                        if tile_nodata is not None and np.any(image_patch == tile_nodata):
                            continue
                        if ref_nodata is not None and np.any(label_patch == ref_nodata):
                            continue

                        # Save patch
                        patch_id = uuid.uuid4().hex
                        X_path = os.path.join(temp_dir, f"X_{patch_id}.npy")
                        y_path = os.path.join(temp_dir, f"y_{patch_id}.npy")
                        np.save(X_path, image_patch.astype(np.float32))
                        np.save(y_path, label_patch.astype(np.int16))
                        X_paths.append(X_path)
                        y_paths.append(y_path)
                        patch_counter += 1

    # np.stack cannot stack an empty list; fail with an actionable message instead.
    if not X_paths:
        raise ValueError(
            f"No valid {patch_size}x{patch_size} patches found in {len(tile_paths)} tile(s) "
            f"under {tile_folder!r}; nothing to write to {output_npz!r}"
        )

    # Merge all patches
    print(f"\n🔁 Merging {patch_counter} patches into final dataset...")
    X_all = np.stack([np.load(p) for p in X_paths])
    y_all = np.stack([np.load(p) for p in y_paths])

    np.savez_compressed(output_npz, X=X_all, y=y_all)
    print(f"✅ Done. Saved {X_all.shape[0]} patches of size {patch_size}x{patch_size}")
finally:
    # Clean up temp files, also when the cell fails or is interrupted
    shutil.rmtree(temp_dir, ignore_errors=True)


data/aligned/T34TEK_pansharpened_aligned.tif
data/aligned/T34TFK_pansharpened_aligned.tif
data/aligned/T34SEJ_pansharpened_aligned.tif
data/aligned/T34SFJ_pansharpened_aligned.tif

🔁 Merging 1000 patches into final dataset...
✅ Done. Saved 1000 patches of size 128x128


In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import matplotlib.pyplot as plt

In [2]:
class LandUseDataset(Dataset):
    """Dataset of image patches and per-pixel label maps loaded from an ``.npz`` archive.

    The archive must contain two arrays, ``X`` (images) and ``y`` (labels),
    with one entry per sample along the first axis. Both arrays are read
    fully into memory at construction time.

    Args:
        npz_path: Path to the ``.npz`` file holding the ``X`` and ``y`` arrays.
        transform: Optional callable applied to the image tensor only (the
            label tensor is returned untouched), so it should not change the
            spatial layout (e.g. per-band normalisation). ``None`` disables it.

    Raises:
        ValueError: If ``X`` and ``y`` hold a different number of samples.
    """

    def __init__(self, npz_path, transform=None):
        # np.load on an .npz returns a lazy archive object holding an open file;
        # the context manager closes it once both arrays have been materialised.
        with np.load(npz_path) as data:
            self.X = data['X']  # Shape: (N, 13, H, W)
            self.y = data['y']  # Shape: (N, H, W)

        if len(self.X) != len(self.y):
            raise ValueError(
                f"X and y must have the same number of samples, "
                f"got {len(self.X)} and {len(self.y)}"
            )
        self.transform = transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return ``(image, label)`` as a float32 tensor and an int64 tensor."""
        x = torch.tensor(self.X[idx], dtype=torch.float32)
        y = torch.tensor(self.y[idx], dtype=torch.long)
        # The transform was previously stored but silently ignored.
        if self.transform is not None:
            x = self.transform(x)
        return x, y


In [3]:
class DoubleConv(nn.Module):
    """Two consecutive 3x3 convolution stages, each followed by batch norm and ReLU.

    Both convolutions use ``padding=1``, so the spatial size of the input is
    preserved; only the channel count changes from ``in_channels`` to
    ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # First stage maps in_channels -> out_channels, second keeps out_channels.
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run the input through both conv/norm/activation stages."""
        out = self.net(x)
        return out


class UNet(nn.Module):
    """Two-level U-Net producing per-pixel class logits.

    The encoder applies two ``DoubleConv`` blocks, each followed by 2x2 max
    pooling, then a bottleneck ``DoubleConv``. The decoder upsamples twice with
    stride-2 transposed convolutions, concatenating the matching encoder
    feature map before each decoder ``DoubleConv``. A final 1x1 convolution
    maps to ``num_classes`` channels.

    Args:
        in_channels: Number of channels in the input image.
        num_classes: Number of output channels (one logit map per class).
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.down1 = DoubleConv(in_channels, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.down2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)

        self.bottleneck = DoubleConv(128, 256)

        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.upconv2 = DoubleConv(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.upconv1 = DoubleConv(128, 64)

        self.final = nn.Conv2d(64, num_classes, kernel_size=1)

    @staticmethod
    def _pad_to_match(upsampled, skip):
        """Zero-pad ``upsampled`` so its last two dims equal those of ``skip``.

        Max pooling floors odd sizes, so after the stride-2 upsampling the
        decoder map can be one pixel smaller than the encoder map it is
        concatenated with. When the sizes already agree the tensor is
        returned unchanged.
        """
        dh = skip.shape[-2] - upsampled.shape[-2]
        dw = skip.shape[-1] - upsampled.shape[-1]
        if dh == 0 and dw == 0:
            return upsampled
        # F.pad order for the last two dims is (left, right, top, bottom).
        return nn.functional.pad(
            upsampled, [dw // 2, dw - dw // 2, dh // 2, dh - dh // 2]
        )

    def forward(self, x):
        """Return logits with ``num_classes`` channels and the input's height and width."""
        d1 = self.down1(x)
        d2 = self.down2(self.pool1(d1))
        b = self.bottleneck(self.pool2(d2))

        u2 = self._pad_to_match(self.up2(b), d2)
        u2 = self.upconv2(torch.cat([u2, d2], dim=1))
        u1 = self._pad_to_match(self.up1(u2), d1)
        u1 = self.upconv1(torch.cat([u1, d1], dim=1))
        return self.final(u1)


In [4]:
# Load dataset
dataset = LandUseDataset("image_label_dataset.npz")
train_loader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=4)

# Model setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Derive the model dimensions from the data instead of hard-coding them.
# Axis 1 of X is the channel axis (X is stored as (N, C, H, W)).
in_channels = int(dataset.X.shape[1])
# Labels are treated as class indices 0..K-1, so K = largest label + 1.
# Cast to a plain Python int: .max() returns a NumPy scalar, which would
# otherwise be stored as the layer's channel count.
num_classes = int(dataset.y.max()) + 1

model = UNet(in_channels=in_channels, num_classes=num_classes).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()


In [6]:
# Training loop
def _train_one_epoch(net, loader, loss_fn, opt, target_device):
    """Run one optimisation pass over ``loader`` and return the summed batch losses."""
    net.train()
    running = 0
    for inputs, targets in loader:
        inputs = inputs.to(target_device)
        targets = targets.to(target_device)

        opt.zero_grad()
        batch_loss = loss_fn(net(inputs), targets)
        batch_loss.backward()
        opt.step()

        # .item() copies the scalar loss out of the graph as a Python float
        running += batch_loss.item()
    return running


for epoch in range(10):  # or more
    total_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device)
    print(f"Epoch {epoch+1}, Loss: {total_loss:.4f}")


Epoch 1, Loss: 37.7594
Epoch 2, Loss: 32.4856


KeyboardInterrupt: 