# Import libraries

In [None]:
import rasterio
import numpy as np
import matplotlib.pyplot as plt
from rasterio.warp import reproject, Resampling
from matplotlib.colors import ListedColormap, BoundaryNorm
import matplotlib.cm as cm
from sklearn.model_selection import train_test_split

# Define constants

In [None]:
# Input rasters: the multi-band Sentinel file and the label file
sentinel_file = "C:/Users/elise/Master/sentinel_reprojected.tif"
label_file = "C:/Users/elise/Master/sr16+background_cut.tif"

# Multiplier applied to the raw band values when they are read
scaling_factor = 1 / 10000

# Define functions

In [None]:
# Load all bands of the Sentinel raster together with its georeferencing
def read_sentinel_bands(sentinel_file):
    """Read every band of a raster and scale it by the module-level ``scaling_factor``.

    Args:
        sentinel_file: Path of the raster to open with rasterio.

    Returns:
        A ``(bands, transform)`` tuple: the scaled 3D array
        (bands, height, width) and the dataset's affine transform.
    """
    with rasterio.open(sentinel_file) as dataset:
        affine = dataset.transform
        # Multiply the raw values by the scaling factor to get reflectance values
        scaled = dataset.read() * scaling_factor
    return scaled, affine

# Warp the label raster onto the Sentinel grid
def align_labels_to_sentinel(label_file, sentinel_transform, sentinel_crs, sentinel_shape):
    """Reproject band 1 of the label raster so it matches the Sentinel grid.

    Args:
        label_file: Path of the label raster.
        sentinel_transform: Affine transform of the target grid.
        sentinel_crs: Coordinate reference system of the target grid.
        sentinel_shape: ``(height, width)`` of the target grid.

    Returns:
        Array of shape ``sentinel_shape`` with the label raster's dtype,
        filled by nearest-neighbour resampling.
    """
    with rasterio.open(label_file) as labels_src:
        # Output buffer on the target grid, keeping the label raster's dtype
        aligned = np.empty(sentinel_shape, dtype=labels_src.meta["dtype"])
        target_height, target_width = sentinel_shape[0], sentinel_shape[1]

        reproject(
            source=rasterio.band(labels_src, 1),
            destination=aligned,
            src_transform=labels_src.transform,
            src_crs=labels_src.crs,
            dst_transform=sentinel_transform,
            dst_crs=sentinel_crs,
            dst_width=target_width,
            dst_height=target_height,
            # Sentinel is 10 m and SR16 is 16 m, so this upsamples; nearest
            # neighbour copies existing label values instead of blending them.
            resampling=Resampling.nearest,
        )
    return aligned

# Read bands

In [None]:
# Read the Sentinel bands and their transform.
# read_sentinel_bands returns the band stack already multiplied by
# scaling_factor; the transform is kept for aligning the labels later on.
sentinel_bands, sentinel_transform = read_sentinel_bands(sentinel_file)
print(f"Sentinel bands shape: {sentinel_bands.shape}")

In [None]:
# Compute the global value range once; each min()/max() is a pass over the
# whole band stack.
band_min = sentinel_bands.min()
band_max = sentinel_bands.max()
print(f"Min = {band_min}, Max = {band_max}")

# Normalize the bands to [0, 1] range for reflectance.
# A single range shared by all bands is used, as in a plain min-max rescale.
value_range = band_max - band_min
if value_range == 0:
    # Constant raster: dividing by zero would fill the array with NaN.
    sentinel_bands = np.zeros_like(sentinel_bands, dtype=float)
else:
    sentinel_bands = (sentinel_bands - band_min) / value_range
print(f"Min = {sentinel_bands.min()}, Max = {sentinel_bands.max()}")

# Plot bands

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Feature keys in the form <band>_<season>_<statistic>_<mask flag>.
# They are used by position to title the band plots, so the order here is
# presumably the band order of the Sentinel raster -- verify against the file.
features = [
    "B02_winter_min_nomask",
    "B02_spring_min_nomask",
    "B02_autumn_min_nomask",
    "B03_winter_min_nomask",
    "B04_winter_min_nomask",
    "B04_spring_min_nomask",
    "B05_winter_min_nomask",
    "B06_winter_min_nomask",
    "B06_summer_median_mask",
    "B07_summer_median_nomask",
    "B07_summer_median_mask",
    "B08_summer_median_nomask",
    "B08_summer_median_mask",
    "B11_winter_min_nomask",
    "B12_winter_min_nomask",
    "B8A_summer_median_nomask",
    "B8A_summer_median_mask",
    "evi_winter_max_nomask",
    "ndvi_winter_max_nomask",
    "savi_winter_max_nomask",
]

# Turn a raw feature key into a readable plot title
def generate_descriptive_name(feature):
    """Build a readable label from a ``band_season_stat[_mask]`` feature key.

    The first three underscore-separated fields are required. The fourth
    decides the suffix: ``"mask"`` gives "Masked"; any other value, or a
    missing field, gives "Unmasked".

    Args:
        feature: Feature key such as ``"B02_winter_min_nomask"``.

    Returns:
        A string such as ``"B02 (Winter Min Unmasked)"``.
    """
    tokens = feature.split("_")
    band, season, stat = tokens[0], tokens[1], tokens[2]
    is_masked = len(tokens) > 3 and tokens[3] == "mask"
    suffix = "Masked" if is_masked else "Unmasked"
    return f"{band} ({season.capitalize()} {stat.capitalize()} {suffix})"

# Generate descriptive names
descriptive_names = [generate_descriptive_name(f) for f in features]

# Define the number of bands
num_bands = sentinel_bands.shape[0]

# Grid layout: up to 4 columns per row, and always at least one row/column
# so the figure can be created even for a very small band count.
cols = max(1, min(num_bands, 4))
rows = max(1, int(np.ceil(num_bands / 4)))

# squeeze=False makes subplots() always return a 2-D array of Axes; without
# it a 1x1 grid yields a bare Axes object that has no flatten().
fig, axes = plt.subplots(rows, cols, figsize=(15, 3 * rows), squeeze=False)

# Flatten axes array for easy iteration
axes = axes.flatten()

for i in range(num_bands):
    ax = axes[i]
    band_data = sentinel_bands[i, :, :]

    # Mask zero values (e.g. no-data areas): only strictly positive pixels
    # drive the contrast stretch.
    masked_data = band_data[band_data > 0]

    if masked_data.size > 0:
        # Use percentile stretching only on valid data
        vmin, vmax = np.percentile(masked_data, [2, 98])
    else:
        vmin, vmax = 0, 1  # fallback if there are no positive values

    # Fall back to the bare band number when the raster has more bands than
    # there are feature names.
    if i < len(descriptive_names):
        band_title = f"Band {i+1}: {descriptive_names[i]}"
    else:
        band_title = f"Band {i+1}"

    # Plot original data with adjusted limits
    ax.imshow(band_data, cmap='gray', vmin=vmin, vmax=vmax, aspect='auto')
    ax.set_title(band_title)
    ax.axis("off")


# Hide unused subplots if bands < total subplots. Slicing by num_bands avoids
# relying on the loop variable, which is undefined when there are no bands.
for unused_ax in axes[num_bands:]:
    fig.delaxes(unused_ax)

plt.tight_layout()
plt.show()


In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Band positions follow the `features` list (assuming it matches the raster's
# band order): 5 = B04_spring_min, 3 = B03_winter_min, 1 = B02_spring_min.
red_band = sentinel_bands[5, :, :]   # B4
green_band = sentinel_bands[3, :, :] # B3
blue_band = sentinel_bands[1, :, :]  # B2

# Stack bands to form an RGB image
rgb_image = np.stack([red_band, green_band, blue_band], axis=-1)

# Normalize to 0-1 range for display, using one range for all three channels.
# Min and max are computed once instead of on every use.
rgb_min = np.min(rgb_image)
rgb_range = np.max(rgb_image) - rgb_min
if rgb_range == 0:
    # Constant image: avoid a division by zero that would produce NaN pixels.
    rgb_image = np.zeros_like(rgb_image, dtype=float)
else:
    rgb_image = (rgb_image - rgb_min) / rgb_range

# Plot RGB image
plt.figure(figsize=(10, 10))
plt.imshow(rgb_image)
plt.title("Sentinel-2 RGB Composite (Bands 4-3-2)")
plt.axis("off")
plt.show()


# Read labels

In [None]:
# Read and align the labels to the Sentinel image.
# The CRS is read inside a context manager so the dataset handle is closed
# again instead of staying open for the rest of the session.
with rasterio.open(sentinel_file) as sentinel_src:
    sentinel_crs = sentinel_src.crs

# The target grid is the spatial part (height, width) of the band stack.
aligned_labels = align_labels_to_sentinel(
    label_file, sentinel_transform, sentinel_crs, sentinel_bands.shape[1:]
)
print(f"Aligned labels shape: {aligned_labels.shape}")

In [None]:
# Print the NoData value declared in the label raster's metadata;
# this shows whether -9999 is set as NoData.
with rasterio.open(label_file) as label_src:
    declared_nodata = label_src.nodata
print(declared_nodata)

# Compare aligned labels and original labels

In [None]:
# Load band 1 of the untouched label raster so it can be shown next to the
# aligned version.
with rasterio.open(label_file) as src_labels:
    original_labels = src_labels.read(1)

# Label values present after alignment (also used as colorbar ticks)
unique_labels = np.unique(aligned_labels)
print(f"Unique labels: {unique_labels}")

# Five-colour map: gray for -9999, then four viridis shades for 0, 1, 2, 3
viridis_colors = cm.viridis([0.2, 0.4, 0.6, 0.8])
colors = ["gray", *viridis_colors]
cmap = ListedColormap(colors)

# One colour bin per label value
norm = BoundaryNorm(boundaries=[-9999.5, -0.5, 0.5, 1.5, 2.5, 3.5], ncolors=5)

# Draw, save and display both rasters with identical styling
for title, raster, out_png in (
    ("Original Labels", original_labels, "original_labels.png"),
    ("Aligned Labels", aligned_labels, "aligned_labels.png"),
):
    plt.figure(figsize=(8, 6))
    plt.title(title)
    image = plt.imshow(raster, cmap=cmap, norm=norm, interpolation="nearest")
    colorbar = plt.colorbar(image, ticks=unique_labels)
    colorbar.set_label("Label Value")
    plt.savefig(out_png)
    plt.show()
    plt.close()

print("Images saved as 'original_labels.png' and 'aligned_labels.png' in the current working folder.")

# Function to print label information and count of each label
def print_label_info(label_array, label_name):
    """Print shape, dtype, value range and per-label statistics of an array.

    Args:
        label_array: Non-empty NumPy array of label values (any shape).
        label_name: Heading printed above the summary.

    Returns:
        None. Everything is written to standard output.
    """
    print(f"--- {label_name} ---")
    print(f"Shape: {label_array.shape}")
    print(f"Data Type: {label_array.dtype}")
    print(f"Min Value: {np.min(label_array)}")
    print(f"Max Value: {np.max(label_array)}")
    unique, counts = np.unique(label_array, return_counts=True)
    print(f"Unique Values: {unique}")

    # tolist() converts NumPy scalars to plain Python numbers so the dicts
    # print as {0: 1, ...} rather than with np.int64(...) wrappers.
    label_values = unique.tolist()
    print(f"Counts: {dict(zip(label_values, counts.tolist()))}")

    # The line is labelled "Percentage", so report real percentages (0-100)
    # rather than fractions of one.
    percentages = (100.0 * counts / counts.sum()).tolist()
    print(f"Percentage of each label: {dict(zip(label_values, percentages))}")
    print()

# Print information for the original labels first, then for the aligned labels
for heading, label_raster in (
    ("Original Labels", original_labels),
    ("Aligned Labels", aligned_labels),
):
    print_label_info(label_raster, heading)




# Save data

In [None]:
# Flatten and prepare for machine learning: one row per pixel, one column per band.
# NOTE(review): `input` shadows the Python builtin of the same name; it is kept
# because the following cell reads it.
n_features = sentinel_bands.shape[0]
input = sentinel_bands.reshape(n_features, -1).T  # Shape: (n_samples, n_features)
labels = aligned_labels.flatten()  # Shape: (n_samples,)

print(f"Prepared input shape: {input.shape}, labels shape: {labels.shape}")

In [None]:
# Step 1: Keep only pixels with a usable label.
# -9999 labels are removed and label 0 is ignored as well.
valid_mask = (labels != -9999) & (labels != 0)
X_valid = input[valid_mask]
y_valid = labels[valid_mask]

# Step 2: Stratified split into training and test sets (test_size=0.2).
# Passing the positions along returns, for every sample, its index within
# X_valid / y_valid.
X_train, X_test, y_train, y_test, train_indices, test_indices = train_test_split(
    X_valid, y_valid, np.arange(len(y_valid)), test_size=0.2, random_state=42, stratify=y_valid
)

# Step 3: Create boolean masks over the flattened image for each split
train_mask = np.zeros_like(labels, dtype=bool)
test_mask = np.zeros_like(labels, dtype=bool)
ignored_mask = labels == -9999  # Mask for ignored pixels

# Map the split indices (relative to X_valid) back to flat pixel positions.
# Direct indexing replaces the np.isin membership test over all valid pixels.
valid_positions = np.flatnonzero(valid_mask)
train_mask[valid_positions[train_indices]] = True
test_mask[valid_positions[test_indices]] = True

# Step 4: Create a single mask image to visualize all categories.
# 0 = not part of any split (-9999 and label-0 pixels), 1 = training, 2 = testing.
mask_image = np.zeros_like(labels, dtype=np.int32)
mask_image[train_mask] = 1    # Training
mask_image[test_mask] = 2     # Testing

# Reshape the mask image to the original image dimensions. The shape is taken
# from the aligned label raster instead of being hard-coded, so the cell keeps
# working for rasters of any size.
mask_image = mask_image.reshape(aligned_labels.shape)

# Save training and test datasets plus the split mask
np.save('train_sentinel_input.npy', X_train)
np.save('train_sentinel_labels.npy', y_train)
np.save('test_sentinel_input.npy', X_test)
np.save('test_sentinel_labels.npy', y_test)
np.save('mask_image.npy', mask_image)


In [None]:
# Check if stratification is correct
unique_labels = np.unique(y_valid, return_counts=False)

# Take labels and counts from the same array so each count is always paired
# with the label it belongs to.
train_labels, train_counts = np.unique(y_train, return_counts=True)
test_labels, test_counts = np.unique(y_test, return_counts=True)

print("Training set label distribution:", dict(zip(train_labels, train_counts)))
print("Testing set label distribution:", dict(zip(test_labels, test_counts)))

# Verify proportions: the two vectors should be nearly identical
print("\nProportion in training set:")
print(train_counts / train_counts.sum())
print("\nProportion in testing set:")
print(test_counts / test_counts.sum())
print()

# Overview of masks: number of pixels in each split and their share of all
# valid pixels. The two lines are labelled so they can be told apart.
n_valid = len(y_valid)
print(f"Training samples: {train_mask.sum()}, proportion of total: {train_mask.sum() / n_valid}")
print(f"Testing samples: {test_mask.sum()}, proportion of total: {test_mask.sum() / n_valid}")

In [None]:
# Create a custom colormap for the split image:
# gray for 0 (pixels in neither split), two viridis shades for 1 (training)
# and 2 (testing).
viridis_colors = cm.viridis([0.3, 0.6])  # Choose evenly spaced colors from viridis
colors = ["gray"] + list(viridis_colors)
cmap = ListedColormap(colors)

# Give every integer code its own colour bin. Without an explicit norm the
# colours depend on the autoscaled data range, which shifts if a code is
# absent and leaves the colorbar ticks off-centre.
split_norm = BoundaryNorm(boundaries=[-0.5, 0.5, 1.5, 2.5], ncolors=cmap.N)

# Step 5: Plot the mask image
plt.figure(figsize=(8, 8))
plt.title("Dataset Splits")
split_image = plt.imshow(mask_image, cmap=cmap, norm=split_norm, interpolation="nearest")
cbar = plt.colorbar(split_image)
cbar.set_ticks([0, 1, 2])
# NOTE(review): code 0 also covers label-0 pixels, which are excluded from the
# split together with -9999; the tick text only mentions -9999.
cbar.set_ticklabels(["Ignored (-9999)", "Training", "Testing"])
plt.axis("off")
plt.tight_layout()
plt.savefig("dataset_splits.png")
plt.show()
plt.close()  # release the figure, as done for the label plots
