# Import libraries

In [None]:
import rasterio
import numpy as np
import matplotlib.pyplot as plt
from rasterio.warp import reproject, Resampling, transform_bounds
from rasterio.mask import mask
from matplotlib.colors import ListedColormap, BoundaryNorm
import matplotlib.cm as cm
from sklearn.model_selection import train_test_split
import os

# Define constants

In [None]:
# Input locations (absolute, machine-specific paths)
landsat_dir = "C:/Users/elise/Master/Landsat/Bands/"  # folder with the per-band .TIF files
label_file = "C:/Users/elise/Master/sr16+background_cut.tif"  # label raster; its extent is the AOI
qa_band_file = "C:/Users/elise/Master/Landsat/LC08_L2SP_199016_20210810_20210819_02_T1_QA_PIXEL.tif"

# Linear rescaling coefficients: value = DN * scale_factor + additive_offset.
# NOTE(review): these look like the USGS Collection 2 Level-2 surface-reflectance
# coefficients — confirm before relying on them.
scale_factor = 2.75e-5
additive_offset = -0.2

# Read multispectral data and labels

In [None]:
import os
import numpy as np
import rasterio
from rasterio.mask import mask
from rasterio.warp import reproject, Resampling, transform_bounds

def extract_landsat_and_labels(landsat_dir, label_file):
    """
    Crop all Landsat bands to the label raster's extent and align the labels to them.

    The label raster's bounds define the area of interest (AOI). Each ``.TIF`` file in
    ``landsat_dir`` is cropped to that AOI (after transforming the bounds into the
    Landsat CRS). The bands are then stacked and jointly min-max normalised. Finally,
    the labels are reprojected onto the cropped Landsat grid with nearest-neighbour
    resampling.

    Parameters:
        landsat_dir (str): Directory containing one single-band ``.TIF`` file per band.
            Files are processed in sorted filename order, which defines the band order.
        label_file (str): Path to the label raster that defines the AOI.

    Returns:
        tuple:
            landsat_array (np.ndarray): (bands, height, width) array scaled to [0, 1]
                using the global min/max over all bands (all zeros if the input is
                constant).
            label_data (np.ndarray): int32 array of shape (1, height, width) with the
                aligned labels. Pixels without a source label hold the label nodata
                value (the file's nodata, or -9999 if it defines none).
            transformed_bounds (tuple): AOI bounds (minX, minY, maxX, maxY) in the
                Landsat CRS.

    Raises:
        ValueError: If no ``.TIF`` file is found, no band intersects the AOI, or the
            cropped bands do not all share the same shape.
    """
    # Step 1: Read label file metadata (its extent defines the AOI)
    with rasterio.open(label_file) as src_labels:
        label_bounds = src_labels.bounds  # (minX, minY, maxX, maxY)
        label_crs = src_labels.crs  # CRS of the label file
        label_nodata = src_labels.nodata if src_labels.nodata is not None else -9999

        print(f"\n📌 Label File: {label_file}")
        print(f"Label CRS: {label_crs}")
        print(f"Label Bounds: {label_bounds}")
        print(f"Label NoData: {label_nodata}")

    # Step 2: Find all Landsat band files (sorted so the band order is deterministic)
    band_files = sorted(
        os.path.join(landsat_dir, f) for f in os.listdir(landsat_dir) if f.endswith('.TIF')
    )
    if not band_files:
        raise ValueError("No Landsat .TIF files found in the directory.")

    cropped_bands = []
    landsat_transform = None
    landsat_crs = None
    landsat_shape = None

    # Step 3: Process each Landsat band
    for band_path in band_files:
        with rasterio.open(band_path) as src:
            if landsat_crs is None:
                landsat_crs = src.crs
                print(f"\n📌 Landsat CRS: {landsat_crs}")

            # Reproject AOI bounds to Landsat CRS if needed
            if label_crs != landsat_crs:
                transformed_bounds = transform_bounds(label_crs, landsat_crs, *label_bounds)
                print(f"Transformed AOI Bounds in Landsat CRS: {transformed_bounds}")
            else:
                transformed_bounds = label_bounds

            # Skip bands whose footprint does not intersect the AOI
            min_x, min_y, max_x, max_y = transformed_bounds
            left, bottom, right, top = src.bounds
            if max_x < left or min_x > right or max_y < bottom or min_y > top:
                print("🚨 AOI does not intersect with this Landsat image. Skipping...")
                continue

            # Crop the band to the AOI rectangle
            aoi_polygon = {
                "type": "Polygon",
                "coordinates": [[
                    (min_x, min_y),  # Lower-left
                    (min_x, max_y),  # Upper-left
                    (max_x, max_y),  # Upper-right
                    (max_x, min_y),  # Lower-right
                    (min_x, min_y),  # Close polygon
                ]],
            }
            cropped_band, cropped_transform = mask(src, [aoi_polygon], crop=True)

            # The first cropped band defines the output grid; later bands must match it
            if landsat_transform is None:
                landsat_transform = cropped_transform
                landsat_shape = cropped_band.shape  # (1, height, width): mask() keeps the band axis
            elif cropped_band.shape != landsat_shape:
                raise ValueError(
                    f"Band {band_path} cropped to shape {cropped_band.shape}, expected "
                    f"{landsat_shape}; all bands must share the same grid."
                )

            cropped_bands.append(cropped_band[0])  # Drop the band axis -> (height, width)

    # Ensure at least one valid band was found
    if not cropped_bands:
        raise ValueError("No Landsat bands were successfully cropped. Check AOI and CRS.")

    # Stack bands into 3D NumPy array (bands, height, width)
    landsat_array = np.stack(cropped_bands, axis=0)
    print(f"\n✅ Landsat Data Shape: {landsat_array.shape} (bands, height, width)")

    # Min-max normalise jointly over all bands to [0, 1].
    # (Physical rescaling would instead be `landsat_array * scale_factor + additive_offset`.)
    data_min, data_max = landsat_array.min(), landsat_array.max()
    print(f"Before scaling: Min = {data_min}, Max = {data_max}")
    data_range = float(data_max) - float(data_min)
    if data_range > 0:
        landsat_array = (landsat_array - data_min) / data_range
    else:
        # Constant input: avoid 0/0, which would turn the whole array into NaN
        landsat_array = np.zeros(landsat_array.shape, dtype=np.float64)
    print(f"After scaling: Min = {landsat_array.min()}, Max = {landsat_array.max()}")

    # Step 4: Reproject and resample labels onto the cropped Landsat grid.
    # The destination array's shape defines the output size, so no explicit
    # width/height arguments are passed.
    with rasterio.open(label_file) as src_labels:
        label_data = np.full(landsat_shape, label_nodata, dtype=np.int32)

        reproject(
            source=rasterio.band(src_labels, 1),
            destination=label_data,
            src_transform=src_labels.transform,
            src_crs=src_labels.crs,
            dst_transform=landsat_transform,
            dst_crs=landsat_crs,
            resampling=Resampling.nearest,  # Preserve categorical values
            src_nodata=label_nodata,
            dst_nodata=label_nodata,
        )

    print(f"✅ Label Data Shape: {label_data.shape} (1, height, width)")
    print(f"✅ Label Unique Values: {np.unique(label_data)}")
    print(f"✅ Label Value Counts: {np.unique(label_data, return_counts=True)}")

    return landsat_array, label_data, transformed_bounds


In [None]:
# Load the cropped, normalised Landsat bands and the labels aligned to them
landsat_array, label_data, transformed_bounds = extract_landsat_and_labels(
    landsat_dir=landsat_dir,
    label_file=label_file,
)

# Plot bands

In [None]:
# Display names for the bands, assuming the band files sort as B1..B7
features = [
    "B1 Coastal/Aerosol", "B2 Blue", "B3 Green", "B4 Red",
    "B5 NIR", "B6 SWIR-1", "B7 SWIR-2"
]

# Define the number of bands
num_bands = landsat_array.shape[0]

# Grid with at most 4 columns. squeeze=False keeps `axes` a 2D array even for a
# single subplot, so .flatten() always works.
rows = int(np.ceil(num_bands / 4))
cols = min(num_bands, 4)
fig, axes = plt.subplots(rows, cols, figsize=(15, 3 * rows), squeeze=False)
axes = axes.flatten()

# Plot each band with contrast stretching; remove grid cells without a band
for i, ax in enumerate(axes):
    if i >= num_bands:
        fig.delaxes(ax)  # Hide unused axes
        continue

    band_data = landsat_array[i, :, :]

    # Contrast stretching using 2nd and 98th percentiles (NaN-safe)
    vmin, vmax = np.nanpercentile(band_data, [2, 98])

    ax.imshow(band_data, cmap='gray', vmin=vmin, vmax=vmax, aspect='auto')
    # Fall back to a generic name when there are more bands than known labels
    ax.set_title(features[i] if i < len(features) else f"Band {i + 1}")
    ax.axis("off")

plt.tight_layout()
plt.show()


In [None]:
import numpy as np
import matplotlib.pyplot as plt

# True-colour composite: B4 (red), B3 (green), B2 (blue).
# Assumes landsat_array holds the bands in B1..B7 order (index = band number - 1).
red_band = landsat_array[3, :, :]   # B4
green_band = landsat_array[2, :, :] # B3
blue_band = landsat_array[1, :, :]  # B2

# Stack bands to form an RGB image (height, width, 3)
rgb_image = np.stack([red_band, green_band, blue_band], axis=-1)

# Stretch to the 0-1 range for display; skip a constant image to avoid 0/0 (NaN)
rgb_min, rgb_max = np.min(rgb_image), np.max(rgb_image)
if rgb_max > rgb_min:
    rgb_image = (rgb_image - rgb_min) / (rgb_max - rgb_min)

# Plot RGB image
plt.figure(figsize=(10, 10))
plt.imshow(rgb_image)
plt.title("Landsat RGB Composite")
plt.axis("off")
plt.show()


# Compare aligned labels and original labels

In [None]:
# Visualize the aligned labels and the original label image with unique value mapping
with rasterio.open(label_file) as src_labels:
    original_labels = src_labels.read(1)

# Get unique labels
unique_labels = np.unique(label_data)
print(f"Unique labels: {unique_labels}")

# label_data carries a leading band axis of length 1 from the reprojection step;
# drop it for plotting, but also accept an array that is already 2D.
aligned_labels = label_data[0] if label_data.ndim == 3 else label_data

print(original_labels.shape)
print(label_data.shape)
print(aligned_labels.shape)

# Create a custom colormap: gray for -9999 (nodata), viridis for labels 0, 1, 2, 3
viridis_colors = cm.viridis([0.2, 0.4, 0.6, 0.8])  # Evenly spaced colors from viridis
colors = ["gray"] + list(viridis_colors)
cmap = ListedColormap(colors)

# One bin per category; values outside [-9999.5, 3.5] fall outside the mapped range
norm = BoundaryNorm(boundaries=[-9999.5, -0.5, 0.5, 1.5, 2.5, 3.5], ncolors=5)

# Plot and save both label images. Colorbar ticks come from each image's own values,
# so a label that occurs in only one of the two images is still ticked correctly.
for title, label_image, filename in (
    ("Original Labels", original_labels, "original_labels.png"),
    ("Aligned Labels", aligned_labels, "aligned_labels_landsat.png"),
):
    plt.figure(figsize=(8, 6))
    plt.title(title)
    image = plt.imshow(label_image, cmap=cmap, norm=norm, interpolation="nearest")
    colorbar = plt.colorbar(image, ticks=np.unique(label_image))
    colorbar.set_label("Label Value")
    plt.savefig(filename)
    plt.show()
    plt.close()

print("Images saved as 'original_labels.png' and 'aligned_labels_landsat.png' in the current working folder.")

# Function to print label information and count of each label
def print_label_info(label_array, label_name, nodata_value=-9999):
    """
    Print a summary of a label array.

    Output, in order: a heading with ``label_name``, the array's shape, dtype,
    minimum and maximum, every distinct value, a value -> pixel-count mapping,
    and each value's share of the pixels that are not ``nodata_value`` (an empty
    dict when every pixel is nodata), followed by a blank line.

    Parameters:
        label_array (np.ndarray): Label array of any shape.
        label_name (str): Heading printed above the summary.
        nodata_value (int): Value excluded from the share calculation.

    Returns:
        None
    """
    print(f"--- {label_name} ---")
    print(f"Shape: {label_array.shape}")
    print(f"Data Type: {label_array.dtype}")
    print(f"Min Value: {np.min(label_array)}")
    print(f"Max Value: {np.max(label_array)}")

    values, pixel_counts = np.unique(label_array, return_counts=True)

    # Shares are computed over non-nodata pixels only
    is_data = values != nodata_value
    data_values = values[is_data]
    data_counts = pixel_counts[is_data]
    n_data_pixels = np.sum(data_counts)

    shares = {}
    if n_data_pixels > 0:
        shares = dict(zip(data_values, data_counts / n_data_pixels))

    print(f"Unique Values: {values}")
    print(f"Counts: {dict(zip(values, pixel_counts))}")
    print(f"Percentage of each label (excluding {nodata_value}): {shares}")
    print()
    print()


# Summarise the original and the aligned label rasters (nodata excluded from shares)
for _label_array, _label_name in (
    (original_labels, "Original Labels"),
    (aligned_labels, "Aligned Labels"),
):
    print_label_info(_label_array, _label_name)




# Apply the QA mask (remove low-quality pixels)

In [None]:
def apply_qa_mask(qa_band_file, landsat_array, label_data, transformed_bounds, quality_threshold=21850,
                  label_nodata=-9999, mask_fill=True):
    """
    Mask low-quality pixels in the Landsat array and labels using the QA_PIXEL band.

    The QA band is cropped to ``transformed_bounds`` using the same AOI rectangle as
    the spectral bands. A pixel is kept when its raw QA value is ``<= quality_threshold``
    and, if ``mask_fill`` is true, its lowest bit (the QA fill flag) is not set.
    Rejected pixels become NaN in every Landsat band and ``label_nodata`` in the
    labels. The input arrays are not modified; masked copies are returned.

    NOTE(review): per the USGS Collection 2 product guide, QA_PIXEL is a bit field
    (bit 0 = fill; higher bits flag cloud/shadow/snow/water and their confidences).
    A single numeric threshold is therefore only a heuristic; for example, clear-water
    codes appear to lie above the default threshold. Consider testing individual bits.

    Parameters:
        qa_band_file (str): Path to the Landsat QA_PIXEL GeoTIFF.
        landsat_array (np.ndarray): (bands, height, width) cropped Landsat data.
        label_data (np.ndarray): Labels of shape (height, width) or (1, height, width).
        transformed_bounds (tuple): AOI bounds (minX, minY, maxX, maxY) in the QA
            raster's CRS.
        quality_threshold (int): Largest raw QA value that is kept.
        label_nodata (int): Value written into rejected label pixels.
        mask_fill (bool): Also reject pixels whose QA fill bit (bit 0) is set. The fill
            code is below the threshold, so without this flag fill pixels would be
            kept as valid samples.

    Returns:
        tuple:
            qa_landsat_array (np.ndarray): Floating-point copy of ``landsat_array``
                with rejected pixels set to NaN.
            qa_label_data (np.ndarray): Copy of ``label_data`` with rejected pixels
                set to ``label_nodata``.
            quality_mask (np.ndarray): (height, width) boolean array, True for kept
                pixels.

    Raises:
        ValueError: If the cropped QA band does not match the spatial shape of the
            Landsat array or the labels.
    """
    unique_labels, counts = np.unique(label_data, return_counts=True)
    print(f"Unique Labels before QA mask: {unique_labels}")
    print(f"Label Counts before QA mask: {dict(zip(unique_labels, counts))}")
    print(f"Nans in Landsat array: {np.sum(np.isnan(landsat_array))}")

    print(f"\n📌 Loading Landsat QA Band: {qa_band_file}")

    # Work on copies so the caller's arrays stay untouched. The Landsat copy is promoted
    # to a floating dtype because rejected pixels are marked with NaN.
    qa_landsat_array = landsat_array.astype(np.result_type(landsat_array.dtype, np.float32))
    qa_label_data = label_data.copy()

    # Crop the QA band with the same AOI rectangle used for the spectral bands
    min_x, min_y, max_x, max_y = transformed_bounds
    aoi_polygon = {
        "type": "Polygon",
        "coordinates": [[
            (min_x, min_y),  # Lower-left
            (min_x, max_y),  # Upper-left
            (max_x, max_y),  # Upper-right
            (max_x, min_y),  # Lower-right
            (min_x, min_y),  # Close polygon
        ]],
    }
    with rasterio.open(qa_band_file) as src_qa:
        cropped_qa, _ = mask(src_qa, [aoi_polygon], crop=True)
    qa_band = cropped_qa[0]  # Drop the band axis -> (height, width)

    print(f"✅ QA Band Shape: {qa_band.shape} (height, width)")
    qa_values, qa_counts = np.unique(qa_band, return_counts=True)
    print(f"✅ Unique QA Values: {dict(zip(qa_values, qa_counts))}")

    spatial_shape = landsat_array.shape[-2:]
    if qa_band.shape != spatial_shape or label_data.shape[-2:] != spatial_shape:
        raise ValueError(
            f"QA band shape {qa_band.shape} does not match the spatial shape of the "
            f"Landsat array {landsat_array.shape} / labels {label_data.shape}; "
            "check the AOI bounds and raster grids."
        )

    # Keep pixels at or below the threshold and, optionally, not flagged as fill
    quality_mask = qa_band <= quality_threshold
    if mask_fill:
        quality_mask &= (qa_band & 1) == 0
    print(f"✅ Masked {np.sum(~quality_mask)} low-quality pixels.")

    # Ellipsis indexing applies the 2D mask to the trailing (height, width) axes,
    # so this works for both 2D and 3D inputs.
    rejected = ~quality_mask
    qa_landsat_array[..., rejected] = np.nan  # Set low-quality pixels to NaN
    qa_label_data[..., rejected] = label_nodata  # Set low-quality label pixels to NoData

    print(f"✅ Applied QA mask: Landsat shape {qa_landsat_array.shape}, Labels shape {qa_label_data.shape}")

    unique_labels, counts = np.unique(qa_label_data, return_counts=True)
    print(f"Unique Labels after QA mask: {unique_labels}")
    print(f"Label Counts after QA mask: {dict(zip(unique_labels, counts))}")
    print(f"Nans in Landsat array: {np.sum(np.isnan(qa_landsat_array))}")

    return qa_landsat_array, qa_label_data, quality_mask


In [None]:
# Remove low-quality pixels (per the QA_PIXEL band) from the bands and the labels
qa_landsat_array, qa_label_data, quality_mask = apply_qa_mask(
    qa_band_file=qa_band_file,
    landsat_array=landsat_array,
    label_data=label_data,
    transformed_bounds=transformed_bounds,
)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Extract RGB bands (B4, B3, B2), assuming landsat bands are stored in B1..B7 order
red_band = qa_landsat_array[3, :, :]
green_band = qa_landsat_array[2, :, :]
blue_band = qa_landsat_array[1, :, :]

# Stack bands to form an RGB image (height, width, 3)
rgb_image = np.stack([red_band, green_band, blue_band], axis=-1)

# True where a channel value is valid (QA-masked pixels are NaN)
valid_mask = ~np.isnan(rgb_image)

# Normalize using min/max of valid pixels only. Guard against an AOI that is fully
# masked (empty selection) and against a constant image (0/0).
if valid_mask.any():
    min_val = rgb_image[valid_mask].min()
    max_val = rgb_image[valid_mask].max()
    if max_val > min_val:
        rgb_image = (rgb_image - min_val) / (max_val - min_val)

# Paint masked (NaN) pixels mid-grey
rgb_image[~valid_mask] = 0.5

# Plot the image
plt.figure(figsize=(10, 10))
plt.imshow(rgb_image)
plt.title("Landsat RGB Composite with Grey NaN Mask")
plt.axis("off")
plt.show()


In [None]:
# Reshape (bands, height, width) -> (num_pixels, num_bands): one feature row per pixel
flattened_landsat = qa_landsat_array.reshape(len(qa_landsat_array), -1).T
# Labels in the same row-major pixel order -> (num_pixels,)
flattened_labels = qa_label_data.flatten()

# Keep only pixels with a positive label. This drops the -9999 nodata / QA-masked
# pixels and also any pixel labelled 0.
valid_indices = flattened_labels > 0
filtered_landsat, filtered_labels = flattened_landsat[valid_indices], flattened_labels[valid_indices]

print(f"Filtered Landsat shape: {filtered_landsat.shape}")
print(f"Filtered Labels shape: {filtered_labels.shape}")
print(f"Number of NaNs in filtered Landsat data: {np.isnan(filtered_landsat).sum()}")

# Split into training and test sets and save data

In [None]:
# Stratified 80/20 split into training and test sets (no validation set is created).
# The third array carries each sample's position in filtered_landsat, so the split
# can be mapped back onto the image.
X_train, X_test, y_train, y_test, train_indices, test_indices = train_test_split(
    filtered_landsat,
    filtered_labels,
    np.arange(len(filtered_labels)),
    test_size=0.2,
    random_state=42,
    stratify=filtered_labels,
)

# Same selection as the sample filter: its True pixels, in row-major order,
# correspond one-to-one to the rows of filtered_landsat.
valid_mask = qa_label_data > 0
print(f"Train_indices: {train_indices}")
print(f"Test_indices: {test_indices}")

# Per-sample membership flags
in_train = np.zeros(len(filtered_landsat), dtype=bool)
in_train[train_indices] = True
in_test = np.zeros(len(filtered_landsat), dtype=bool)
in_test[test_indices] = True

# Scatter the flags back into image-shaped masks
train_mask = np.zeros(label_data.shape, dtype=bool)
test_mask = np.zeros(label_data.shape, dtype=bool)
ignored_mask = qa_label_data <= 0  # Pixels excluded from both sets
train_mask[valid_mask] = in_train
test_mask[valid_mask] = in_test

# Category image: 0 = ignored, 1 = training, 2 = testing
mask_image = np.select([test_mask, train_mask], [2, 1], default=0).astype(np.int32)

print(f"Mask Image Shape: {mask_image.shape}")
print(f"Unique Values in Mask Image: {np.unique(mask_image)}")

# Save the training and test datasets plus the split image
for path, array in (
    ('train_landsat_input.npy', X_train),
    ('train_landsat_labels.npy', y_train),
    ('test_landsat_input.npy', X_test),
    ('test_landsat_labels.npy', y_test),
    ('mask_image_landsat.npy', mask_image),
):
    np.save(path, array)

# Check if stratification is correct
train_labels, train_counts = np.unique(y_train, return_counts=True)
test_labels, test_counts = np.unique(y_test, return_counts=True)

print("Training set label distribution:", dict(zip(train_labels, train_counts)))
print("Testing set label distribution:", dict(zip(test_labels, test_counts)))

# Verify proportions
print("\nProportion in training set:")
print(train_counts / train_counts.sum())
print("\nProportion in testing set:")
print(test_counts / test_counts.sum())
print()
print()


In [None]:
# Colormap for the split image: gray = ignored (0), two viridis colours = training (1)
# and testing (2)
viridis_colors = cm.viridis([0.3, 0.6])  # Evenly spaced colors from viridis
colors = ["gray"] + list(viridis_colors)
cmap = ListedColormap(colors)

# Pin each integer category to its own colour. Without a norm, imshow rescales to the
# data's min/max, so colours would shift whenever a category is absent.
split_norm = BoundaryNorm(boundaries=[-0.5, 0.5, 1.5, 2.5], ncolors=3)

# mask_image carries a leading band axis of length 1; drop it for plotting, but also
# accept an array that is already 2D.
split_image = mask_image[0] if mask_image.ndim == 3 else mask_image

# Step 5: Plot the mask image
plt.figure(figsize=(8, 8))
plt.title("Dataset Splits")
plt.imshow(split_image, cmap=cmap, norm=split_norm, interpolation="nearest")
cbar = plt.colorbar()
cbar.set_ticks([0, 1, 2])
cbar.set_ticklabels(["Ignored", "Training", "Testing"])
plt.axis("off")
plt.tight_layout()
plt.savefig("dataset_splits_landsat.png")
plt.show()
