In [1]:
# helper functions

def open_image(image_path,ncdf_layer='fsc'):
    """Open a raster image and read its pixels and metadata.

    Parameters
    ----------
    image_path : str
        Path to an image. Files with a ``.nc`` extension are read through
        GDAL's ``NETCDF:<file>:<variable>`` syntax; anything else is opened
        directly with ``gdal.Open``.
    ncdf_layer : str, optional
        NetCDF input only: prefix of the variable to read. The first variable
        whose name starts with this prefix is used. Default ``'fsc'``.

    Returns
    -------
    image_output : numpy.ndarray
        Full-raster pixel values as returned by ``ReadAsArray``.
    information : dict
        ``'geotransform'``, ``'extent'`` (``[minx, miny, maxx, maxy]``),
        ``'X_Y_raster_size'`` (``[cols, rows]``) and ``'projection'`` (WKT).
        For NetCDF input, geotransform and extent are rounded to 4 decimals
        and stored as tuples. For other input, ``'EPSG'`` (str) is added.

    If GDAL cannot open the file, a message is printed and ``None`` is
    returned instead of the tuple.
    """

    ext = os.path.basename(image_path).split('.')[-1]

    if ext == 'nc':
        # Only the variable names are needed from netCDF4; close it right away.
        with netCDF4.Dataset(image_path, 'r') as nc_data:
            vars_nc = list(nc_data.variables)
        scf_name = list(filter(lambda x: x.startswith(ncdf_layer), vars_nc))[0]
        image = gdal.Open("NETCDF:{0}:{1}".format(image_path, scf_name))
    else:
        image = gdal.Open(image_path)

    # gdal.Open returns None on failure: check before touching the dataset.
    if image is None:
        print('could not open ' + image_path)
        return

    cols = image.RasterXSize
    rows = image.RasterYSize
    geotransform = image.GetGeoTransform()
    proj = image.GetProjection()
    minx = geotransform[0]
    maxy = geotransform[3]
    maxx = minx + geotransform[1] * cols
    miny = maxy + geotransform[5] * rows

    information = {
        'geotransform': geotransform,
        'extent': [minx, miny, maxx, maxy],
        'X_Y_raster_size': [cols, rows],
        'projection': proj,
    }

    if ext == 'nc':
        # `round(x, 4) or x` keeps the unrounded value when rounding yields 0.
        information['geotransform'] = tuple(round(x, 4) or x for x in geotransform)
        information['extent'] = tuple(round(x, 4) or x for x in information['extent'])
    else:
        # Read-only access is enough to query the CRS.
        with rasterio.open(image_path, 'r') as rds:
            information['EPSG'] = str(rds.crs).split(':')[1]

    image_output = np.array(image.ReadAsArray(0, 0, cols, rows))

    return image_output, information



def get_sensor(acquisition_name):
    """Return the short mission code encoded in an acquisition name.

    Only the basename of ``acquisition_name`` is inspected. Rules are tried
    in order and the first match wins. A rule matches when the name contains
    one of its substrings, or when its first three characters equal one of
    its legacy prefixes.

    Landsat 9 ('LC09') names are reported as 'L8'. Presumably this is because
    downstream code treats both missions alike; confirm before changing.

    Raises
    ------
    ValueError
        If no rule matches.
    """
    acquisition_name = os.path.basename(acquisition_name)
    leading = acquisition_name[:3]

    # (mission code, substrings anywhere in the name, exact 3-char prefixes)
    rules = (
        ('L4', ('LT04',), ()),
        ('L5', ('LT05',), ('LT5',)),
        ('L7', ('LE07',), ('LE7',)),
        ('L8', ('LC08',), ('LC8',)),
        ('L8', ('LC09',), ()),
        ('S2', ('S2',), ()),
        ('PRISMA', ('PRS',), ()),
    )
    for code, substrings, prefixes in rules:
        if any(token in acquisition_name for token in substrings) or leading in prefixes:
            return code

    raise ValueError(f"Invalid acquisition name: {acquisition_name}")


def plot_valid_pixels_percentage(ranges, percentage_per_angles_list, svm_folder_path):
    """
    Draw a bar chart of valid pixels per solar incidence angle range and save it as PNG.

    Parameters:
    - ranges (sequence of pairs): angle ranges; each becomes the x label "low-high".
    - percentage_per_angles_list (sequence of numbers): bar heights, one per range.
    - svm_folder_path (str): directory where 'valid_pixels_per_angle.png' is written.

    Raises:
    - ValueError: if the two sequences have different lengths.
    """
    # Validate before creating any figure.
    if len(percentage_per_angles_list) != len(ranges):
        raise ValueError("Length of ranges and percentage_per_angles_list must match.")

    labels = ["{}-{}".format(r[0], r[1]) for r in ranges]

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.bar(labels, percentage_per_angles_list, color='skyblue')
    ax.set_title("Percentage of Valid Pixels per Solar Incidence Angle Range", fontsize=14)
    ax.set_xlabel("Angle Ranges (degrees)", fontsize=12)
    ax.set_ylabel("Percentage (%)", fontsize=12)
    ax.tick_params(axis='x', labelsize=10)
    ax.tick_params(axis='y', labelsize=10)
    ax.grid(axis='y', linestyle='--', alpha=0.7)

    output_path = os.path.join(svm_folder_path, 'valid_pixels_per_angle.png')
    fig.tight_layout()
    fig.savefig(output_path, dpi=300)
    # Release the figure explicitly so repeated calls do not accumulate open figures.
    plt.close(fig)
    print(f"Plot saved to: {output_path}")




def get_representative_pixels(bands_data, valid_mask, sample_count = 50, k='auto', n_closest='auto'):
    """
    Select representative pixels by K-means clustering and proximity to centroids.

    The candidate pixels (where ``valid_mask`` is True) are standardised and
    clustered with K-means. In each cluster, the ``n_closest`` candidates
    nearest to the centroid (Euclidean distance) are flagged.

    Parameters
    ----------
    bands_data : numpy.ndarray
        Spectral values laid out so that ``bands_data[valid_mask, :]`` yields a
        (candidates, bands) array. Two layouts work:
        - a (pixels, bands) table with a 1-D mask, as built by read_masked_values;
        - a (height, width, bands) cube with a 2-D mask.
    valid_mask : numpy.ndarray of bool
        Candidate pixels. The returned mask has this shape.
    sample_count : int, optional
        Total number of pixels aimed for when ``n_closest='auto'``. Default 50.
    k : int or 'auto', optional
        Number of K-means clusters. With 'auto', ``find_optimal_k`` (elbow
        method, up to 10) chooses it. Always capped at the number of
        candidates, since K-means cannot fit more clusters than samples.
        Default 'auto'.
    n_closest : int or 'auto', optional
        Number of pixels kept per cluster. With 'auto', ``int(sample_count / k)``
        is used. Default 'auto'.

    Returns
    -------
    representative_pixels_mask : numpy.ndarray of uint8
        Same shape as ``valid_mask``: 1 at selected pixels, 0 elsewhere. It is
        all zeros when there are no candidates.
    """
    representative_pixels_mask = np.zeros(valid_mask.shape, dtype='uint8')

    # (candidates, bands), rows in C order of valid_mask
    valid_pixels = bands_data[valid_mask, :]
    n_candidates = valid_pixels.shape[0]
    if n_candidates == 0:
        return representative_pixels_mask

    normalized_pixels = StandardScaler().fit_transform(valid_pixels)

    if k == 'auto':
        k = find_optimal_k(normalized_pixels, max_k=10, method="elbow")
    k = min(k, n_candidates)
    if n_closest == 'auto':
        n_closest = int(sample_count / k)

    kmeans = KMeans(n_clusters=k, random_state=0)
    kmeans.fit(normalized_pixels)
    labels = kmeans.labels_
    centroids = kmeans.cluster_centers_

    # Mask coordinates of every candidate. argwhere also uses C order, so its
    # rows line up with normalized_pixels. Computed once, outside the loop.
    candidate_coords = np.argwhere(valid_mask)

    for cluster_idx in range(k):
        cluster_indices = np.flatnonzero(labels == cluster_idx)
        distances = distance.cdist(normalized_pixels[cluster_indices], [centroids[cluster_idx]], 'euclidean').flatten()
        closest_indices = np.argsort(distances)[:n_closest]
        selected_coords = candidate_coords[cluster_indices[closest_indices]]
        # Use one index array per axis so single pixels are set for both
        # 1-D and 2-D masks. Indexing with the (n, ndim) array itself would
        # select whole rows of a 2-D mask.
        representative_pixels_mask[tuple(selected_coords.T)] = 1

    return representative_pixels_mask



def read_masked_values(geotiff_path, mask, bands=None):
    """
    Sample a raster's bands at the pixels selected by a boolean mask.

    Parameters
    ----------
    geotiff_path : str
        Path to a raster readable by rasterio (e.g. GeoTIFF or VRT).
    mask : numpy.ndarray
        2-D boolean mask with the raster's (rows, cols) shape; True marks the
        pixels to keep.
    bands : list of int, optional
        1-based band indices to read, in the order wanted in the output.
        When None, every band (1..count) is read.

    Returns
    -------
    masked_values : numpy.ndarray
        Array of shape (num_selected_pixels, num_bands). Column j holds the
        j-th requested band at the masked pixels.
    """
    with rasterio.open(geotiff_path) as src:
        band_ids = range(1, src.count + 1) if bands is None else bands
        # One 1-D column of masked values per band, read band by band.
        columns = [src.read(band_id)[mask] for band_id in band_ids]
        return np.stack(columns, axis=-1)



In [None]:
def _robust_normalize(values, low_q, high_q):
    """Rescale ``values`` linearly to [0, 1] between two NaN-ignoring percentiles.

    NaN inputs stay NaN. When the percentile span is zero or undefined
    (constant or all-NaN input), every non-NaN value maps to 0. This avoids
    inf/NaN from dividing by zero.
    """
    low, high = np.nanpercentile(values, [low_q, high_q])
    span = high - low
    if not np.isfinite(span) or span <= 0:
        return np.where(np.isnan(values), np.nan, 0.0)
    return np.clip((values - low) / span, 0, 1)


def _pick_representatives(bands_values, candidate_mask, sample_count, k):
    """Run cluster-based selection on more than 10 candidates, otherwise return an all-zero mask."""
    if np.sum(candidate_mask) > 10:
        return get_representative_pixels(bands_values, candidate_mask, sample_count=sample_count, k=k, n_closest='auto')
    return np.zeros(candidate_mask.shape[0], dtype='uint8')


def collect_trainings(curr_acquisition, curr_aux_folder, auxiliary_folder_path, SVM_folder_name, no_data_mask, bands, PCA=False, total_samples = 500):
    """Select snow / no-snow training pixels, stratified by solar incidence angle.

    For each solar-incidence-angle range, the budget returned by
    ``calculate_training_samples`` is split in half between snow and no-snow
    candidates.

    In ranges starting at 90 degrees or more (a "shadow" score is used):
      - score = normalised diffBNIR - normalised shad_idx;
      - snow: score >= its 95th percentile, NDSI > 0.7 and distance != 255;
      - no snow: score <= its 5th percentile.

    In all other ranges (a "sun" score is used):
      - score = normalised NDSI - normalised NDVI + normalised band 2;
      - snow: score >= its 95th percentile, NDSI > 0.7 and distance != 255;
      - no snow: NDSI < 0.

    The final pixels are picked from the candidates with
    ``get_representative_pixels``.

    Written into ``<curr_acquisition>/<SVM_folder_name>``:
      - representative_pixels_for_training_samples.shp: points with 'value'
        1 (snow) or 2 (no snow);
      - representative_pixels_for_training_samples.tif: uint8 mask with the
        same coding, nodata 0, on the NDSI raster's grid;
      - valid_pixels_per_angle.png: share (%) of valid scene pixels per range.

    Parameters
    ----------
    curr_acquisition : str
        Acquisition folder. It holds ``*scf.vrt`` (or a non-PCA ``PRS*.tif``)
        with the spectral bands.
    curr_aux_folder : str
        Folder with the per-scene auxiliary rasters: cloud mask, solar
        incidence angle, NDSI, NDVI, diffBNIR, shad_idx and distance.
    auxiliary_folder_path : str
        Folder with ``*Water_Mask.tif``.
    SVM_folder_name : str
        Output sub-folder name; it is created if missing.
    no_data_mask : numpy.ndarray of bool
        True where there is no data. It must broadcast against the mask rasters.
    bands, PCA
        Currently unused; kept for interface compatibility.
    total_samples : int, optional
        Total sample budget passed to ``calculate_training_samples``.

    Excluded pixels: ``cloud_mask == 2``, ``water_mask == 1`` and no-data pixels.

    Returns
    -------
    (shapefile_path, training_mask_path) : tuple of str

    Raises
    ------
    ValueError
        From ``get_sensor`` for an unrecognised acquisition name.
    IndexError
        If a required input raster is not found.
    """
    svm_folder_path = os.path.join(curr_acquisition, SVM_folder_name)
    os.makedirs(svm_folder_path, exist_ok=True)

    # Validates the acquisition name (raises ValueError for unknown sensors);
    # the returned code is not used further here.
    get_sensor(os.path.basename(curr_acquisition))

    path_cloud_mask = glob.glob(os.path.join(curr_aux_folder, '*cloud_Mask.tif'))[0]
    path_water_mask = glob.glob(os.path.join(auxiliary_folder_path, '*Water_Mask.tif'))[0]
    solar_incidence_angle_path = glob.glob(os.path.join(curr_aux_folder, '*solar_incidence_angle.tif'))[0]
    NDSI_path = glob.glob(os.path.join(curr_aux_folder, '*NDSI.tif'))[0]
    NDVI_path = glob.glob(os.path.join(curr_aux_folder, '*NDVI.tif'))[0]
    diff_B_NIR_path = glob.glob(os.path.join(curr_aux_folder, '*diffBNIR.tif'))[0]
    shad_idx_path = glob.glob(os.path.join(curr_aux_folder, '*shad_idx.tif'))[0]
    distance_index_path = glob.glob(os.path.join(curr_aux_folder, '*distance.tif'))[0]

    bands_path = glob.glob(os.path.join(curr_acquisition, '*scf.vrt'))
    if bands_path == []:
        bands_path = [f for f in glob.glob(curr_acquisition + os.sep + "PRS*.tif") if 'PCA' not in f][0]
    else:
        bands_path = bands_path[0]

    cloud_mask = open_image(path_cloud_mask)[0]
    water_mask = open_image(path_water_mask)[0]
    solar_incidence_angle = open_image(solar_incidence_angle_path)[0]
    curr_scene_valid = np.logical_not(np.logical_or.reduce((cloud_mask == 2, water_mask == 1, no_data_mask)))
    n_scene_valid = np.sum(curr_scene_valid)

    ranges = ((0,20), (20, 45), (45, 70), (70, 90), (90, 180))
    range_samples = calculate_training_samples(solar_incidence_angle, ranges, total_samples)

    training_mask = np.zeros(curr_scene_valid.shape, dtype='uint8')

    # Plot exactly the ranges that are iterated, so labels and bars always line up.
    plotted_ranges = []
    percentage_per_angles_list = []
    for curr_range, sample_count in range_samples.items():
        print(curr_range)
        print(sample_count)

        curr_angle_valid = np.logical_and(curr_scene_valid, np.logical_and(solar_incidence_angle >= curr_range[0], solar_incidence_angle < curr_range[1]))
        n_angle_valid = np.sum(curr_angle_valid)

        # Share of valid scene pixels in this range, in percent: the plot's
        # y-axis is labelled "Percentage (%)".
        percentage_of_scene_valid = 100.0 * n_angle_valid / n_scene_valid if n_scene_valid else 0.0
        plotted_ranges.append(curr_range)
        percentage_per_angles_list.append(percentage_of_scene_valid)

        if n_angle_valid == 0:
            # Nothing to sample. Percentiles of an empty selection would fail.
            continue

        half_count = int(sample_count / 2)
        # All masked reads below have shape (pixels_in_range, n_bands).
        curr_NDSI = read_masked_values(NDSI_path, curr_angle_valid)
        curr_bands = read_masked_values(bands_path, curr_angle_valid)
        curr_distance_idx = read_masked_values(distance_index_path, curr_angle_valid)

        if curr_range[0] >= 90:
            # Shadow score.
            curr_diff_B_NIR = read_masked_values(diff_B_NIR_path, curr_angle_valid)
            curr_shad_idx = read_masked_values(shad_idx_path, curr_angle_valid)
            score = _robust_normalize(curr_diff_B_NIR, 2, 95) - _robust_normalize(curr_shad_idx, 2, 95)
            # nanpercentile: a single NaN must not turn the threshold into NaN
            # (which would reject every pixel).
            snow_candidates = np.logical_and.reduce((score >= np.nanpercentile(score, 95), curr_NDSI > 0.7, curr_distance_idx != 255)).flatten()
            no_snow_candidates = (score <= np.nanpercentile(score, 5)).flatten()
            snow_k, no_snow_k = 5, 5
        else:
            # Sun score. Band 2 is column 1 of the full band read, so it is
            # not read from disk a second time.
            curr_NDVI = read_masked_values(NDVI_path, curr_angle_valid)
            curr_green = curr_bands[:, 1:2]
            score = (_robust_normalize(curr_NDSI, 1, 99)
                     - _robust_normalize(curr_NDVI, 1, 99)
                     + _robust_normalize(curr_green, 1, 99))
            snow_candidates = np.logical_and.reduce((score >= np.nanpercentile(score, 95), curr_NDSI > 0.7, curr_distance_idx != 255)).flatten()
            no_snow_candidates = (curr_NDSI < 0).flatten()
            snow_k, no_snow_k = 5, 10

        # Encoding: snow = 1, no snow = 2.
        representative_pixels_mask_snow = _pick_representatives(curr_bands, snow_candidates, half_count, snow_k)
        representative_pixels_mask_noSnow = _pick_representatives(curr_bands, no_snow_candidates, half_count, no_snow_k) * 2

        training_mask[curr_angle_valid] = representative_pixels_mask_noSnow + representative_pixels_mask_snow

        print(str(np.sum(representative_pixels_mask_snow.flatten())) + ' SNOW PIXELS')
        print(str(np.sum(representative_pixels_mask_noSnow.flatten() / 2)) + ' NO SNOW PIXELS')

    # Read the georeferencing once, while the dataset is open: CRS, pixel
    # centres of the selected pixels, and the output profile.
    with rasterio.open(NDSI_path) as src:
        profile = src.profile
        crs = src.crs
        rows, cols = np.nonzero((training_mask == 1) | (training_mask == 2))
        points = [Point(*src.xy(row, col)) for row, col in zip(rows, cols)]

    gdf = gpd.GeoDataFrame({"value": training_mask[rows, cols]}, geometry=points, crs=crs)

    plot_valid_pixels_percentage(plotted_ranges, percentage_per_angles_list, svm_folder_path)

    shapefile_path = os.path.join(svm_folder_path, 'representative_pixels_for_training_samples.shp')
    gdf.to_file(shapefile_path, driver="ESRI Shapefile")

    training_mask_path = os.path.join(svm_folder_path, 'representative_pixels_for_training_samples.tif')
    profile.update(dtype='uint8', count=1, compress='lzw', nodata=0)
    with rasterio.open(training_mask_path, 'w', **profile) as dst:
        dst.write(training_mask, 1)

    return shapefile_path, training_mask_path