In [20]:
import geopandas as gpd
import numpy as np
import rasterio as rio
import pandas as pd
import matplotlib.pyplot as plt
from pprint import pp

from matplotlib.colors import LinearSegmentedColormap
from rasterio.io import MemoryFile
from scipy.signal import savgol_filter
from datetime import datetime

from os import listdir, path, walk
from os.path import expanduser, isfile, join, splitext
from re import findall, search, compile, IGNORECASE
from typing import List, Tuple, Union, Dict, Callable

from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler

import warnings
warnings.simplefilter('ignore')

In [2]:
BADBANDS = [[300, 400],
            [1310, 1470],
            [1750, 2000],
            [2450, 2600]]
# GOODBANDS = [[400, 1310],
#              [1470, 1750],
#              [2000, 2450]]

In [3]:
### Helper functions

@staticmethod
def get_images(raster_filepath : str,
               image_pattern : str = '.tif') -> list[str] :
    '''Gets a list of images from a directory.
    input:
        - image_dir: 
        - image_pattern: string
    output:
        - image_list: list of strings
    '''
    image_list = []
    for dp, dn, fn in walk(expanduser(raster_filepath)) :
        for f in fn :
            if search(image_pattern, f) :
                image_list.append(join(dp, f))
    return image_list


@staticmethod
def extract_timestamp(filepath : str,
                      regex_pattern : str = r'(\d{8})') -> datetime :
    '''Extracts the timestamp from a file path.
    input:
        - filepath: string
    output:
        - timestamp: datetime.datetime
    '''
    timestamp_pattern = regex_pattern
    match = search(timestamp_pattern, filepath)
    if match:
        timestamp_str = match.group(0)
        timestamp = datetime.strptime(timestamp_str, '%Y%m%d')
        return timestamp
    else:
        return None


@staticmethod
def extract_metric(filepath : str,
                   regex_pattern : str = r'_(PE|Fapar|E).tif') -> str :
    '''Extracts the vegetation index from a file path.
    input:
        - filepath: string
    output:
        - vegetation_index: string
    '''
    vi_pattern = regex_pattern 
    vegetation_index_pattern = compile(vi_pattern, IGNORECASE)
    match = search(vegetation_index_pattern, filepath)
    if match:
        vegetation_index = match.group(0)
        vegetation_index = vegetation_index[1:-4]
        return vegetation_index
    else:
        return None


@staticmethod
def extract_module(filepath : str,
                   regex_pattern : str = r'(\w+_\d{1,3})') -> str :
    '''Extracts the module id from a file path.
    input:
        - filepath: string
    output:
        - mod: string
    '''
    ### TODO:
    # 1. Change re.pattern to look for unknown count of letters_unknown count of numbers
    mod_pattern = regex_pattern
    mod_pattern = compile(mod_pattern, IGNORECASE)
    match = search(mod_pattern, filepath)
    if match:
        mod = match.group(0)
        return mod
    else:
        return None


@staticmethod
def image_collection(images: list[str], key_extractors: list[Callable]) -> dict:
    '''Returns a nested dictionary of images.
    input:
        - images: list of strings
        - key_extractors: list of functions to extract key(s) from the image
    output:
        - image_dict: dict
    '''
    def recursive_dict_creation(keys, image, current_dict):
        if len(keys) == 1:
            current_dict.setdefault(keys[0], []).append(image)
        else:
            # Recursive case: process the next key in the list
            key = keys[0]
            if key not in current_dict:
                current_dict[key] = {}
            recursive_dict_creation(keys[1:], image, current_dict[key])
        return current_dict

    image_dict = {}

    for image in images:
        keys = [key_func(image) for key_func in key_extractors]
        recursive_dict_creation(keys, image, image_dict)

    return image_dict


@staticmethod
def custom_cmap(colors : list, n = 100, reverse = False) -> LinearSegmentedColormap :
    '''Creates a custom colormap.
    input:
        - colors: list of tuples
    output:
        - cmap: matplotlib.colors.LinearSegmentedColormap
    '''
    if reverse:
        colors = colors[::-1]
    cmap = LinearSegmentedColormap.from_list("", colors, N=n)
    return cmap


@staticmethod
def get_wls(wavelengths : List[str]) -> np.array :
    ''' Function to extract wavelengths from the list of bands
    Input:
        - wavelengths_list : list, list of bands to extract wavelengths
    Output:
        - numpy array, list of wavelengths
    '''
    wls = []
    for value in wavelengths.values() :
        regex = r"\d+\.\d+"
        regex_list = findall(regex, value)
        if len(regex_list) > 0 :
            wls.append(float(regex_list[0]))
    return np.array(sorted(wls))


@staticmethod
def get_bands(wls : List[str]) -> np.array :
    ''' Function to round the wavelengths to the nearest 5 and convert to string.
    Input:
        - wls : list, list of bands to extract bands
    Output:
        - np.array: list of bands as strings.
    '''
    wls = np.array(wls)
    wls = np.round(wls / 5) * 5
    # return [str(band) for band in wls]
    return wls


@staticmethod
def get_badbands(wls : List[str],
                 badbands : List[int] = BADBANDS) -> np.array :
    '''Function to remove bad bands from the list of bands
    Input:
        - wls : list, list of bands to remove bad bands, numpy array
        - badbands : list, list of bad bands
    Output:
        - numpy array, list of bands without bad bands
    '''
    good_bands = get_goodbands(wls, badbands)
    return wls[~np.isin(wls, good_bands)]


@staticmethod
def get_goodbands(wls : List[str],
                  badbands : List[int] = BADBANDS) -> np.array :
    '''Function to remove bad bands from the list of bands
    Input:
        - bands : list, list of bands to remove bad bands, numpy array
    Output:
        - numpy array, list of bands without bad bands
    '''
    for bb in badbands :
        wls = wls[(wls < bb[0]) | (wls > bb[1])]
    return wls


@staticmethod
def get_band_by_wl(rio_obj : rio.io.DatasetReader,
                   wl : int) -> np.array :
    '''Function to select a band from the image by wl
    Input:
        - wl : str, wavelength to select
    Output:
        - numpy array, band selected
    '''
    wls = get_wls(rio_obj.tags())
    band_idx = (np.abs(wls - wl)).argmin()
    return rio_obj.read(band_idx)


@staticmethod
def vector_normalize_raster(data: np.array) -> np.array :
    '''Function to normalize the nd-vector'''
    original_shape = data.shape

    data = data.reshape(data.shape[0], data.shape[1] * data.shape[2])
    norm = np.linalg.norm(data, axis = 0)
    data = np.divide(data,
                     norm,
                     where = norm != 0)

    data = data.reshape(original_shape)

    return data


@staticmethod
def memory_raster(raster : np.array,
                  meta_data : Dict,
                  tags : Dict = None,
                  band_names : Union[List[str], List[int]] = None) -> rio.io.DatasetReader :
    '''Creates a raster in memory.
    input:
        - raster: numpy.ndarray
        - tags: dict
        - meta_data: dict
        - band_names: list
    output:
        - raster: rasterio.io.DatasetReader
    '''
    with MemoryFile() as memfile:
        with memfile.open(**meta_data) as dataset :
            if band_names is not None :
                dataset.descriptions = band_names
            if tags is not None :
                dataset.update_tags(**tags)
            dataset.write(raster)
        return memfile.open()


@staticmethod
def clip_image(rio_obj : rio.io.DatasetReader,
               geometry : str) :
    '''Function to clip the image
    Input:
        - geometry : str, geometry to clip the image
    Output:
        - rasterio.io.DatasetReader
    '''
    from rasterio.mask import mask

    geom = gpd.read_file(geometry)
    geom = geom.to_crs(rio_obj.crs)
    geom = geom.geometry[0]

    clipped, transform = mask(rio_obj, [geom], crop=True)
    meta = rio_obj.meta
    descriptions = rio_obj.descriptions
    tags = rio_obj.tags()
    
    meta.update({"height": clipped.shape[1],
                 "width": clipped.shape[2],
                 "transform": transform})
    
    mem_raster = memory_raster(clipped, meta, band_names=descriptions, tags=tags)

    return mem_raster

@staticmethod
def AVIRIS_Spatial_Resample(rio_obj : rio.io.DatasetReader,
                            factor : int) -> rio.io.DatasetReader :
    '''Function to resample the image
    Input:
        - factor : int, factor to resample the image
    Output:
        - rasterio.io.DatasetReader
    '''
    from rasterio.enums import Resampling

    data = rio_obj.read(out_shape=(rio_obj.count,
                                   int(rio_obj.height / factor),
                                   int(rio_obj.width / factor)),
                        resampling=Resampling.bilinear)
    
    transform = rio_obj.transform * rio.Affine.scale(factor)
    
    meta = rio_obj.meta
    meta.update({"height": data.shape[1],
                 "width": data.shape[2],
                 "transform": transform})
    
    return memory_raster(data, meta, band_names=rio_obj.descriptions, tags=rio_obj.tags())

@staticmethod
def graph_spectra(ax : plt.Axes,
                  statistics : Dict,
                  styles : Dict,
                  wavelengths : np.array) -> None :

    ax.plot(wavelengths, statistics['mean'], **styles['mean'])
    ax.fill_between(wavelengths,
                    statistics['upper'], statistics['lower'], # update statistics['q25'] to parameter based filling
                    **styles['fill'])


@staticmethod
def hsi_custom_colorscale() -> LinearSegmentedColormap :
    '''Function to create a custom colorscale for hyperspectral images
    '''
    viswir = [[380, 850], # VIS: violet 380, blue 400, cyan 450, yellow 500, orange 550, red 700, dark-red 850
              [850, 900], # NIR: dark-red -> dark-red + black
              [900, 2500]] # SWIR: black -> black (full opacity)
    colors = [(380, 'violet'), (400, 'blue'), (450, 'cyan'), (500, 'yellow'),
              (550, 'orange'), (700, 'red'), (850, 'darkred'), (900, 'black')]

    color_transitions = []

    for wavelength, color in colors :
        if viswir[0][0] <= wavelength <= viswir[0][1]:
            normalized_wavelength = (wavelength - viswir[0][0]) / (viswir[2][1] - viswir[0][0])
            color_transitions.append((normalized_wavelength, color))

    normalized_nir_start = (viswir[1][0] - viswir[0][0]) / (viswir[2][1] - viswir[0][0])
    normalized_nir_end = (viswir[1][1] - viswir[0][0]) / (viswir[2][1] - viswir[0][0])
    color_transitions.append((normalized_nir_start, 'darkred'))
    color_transitions.append((normalized_nir_end, 'black'))

    normalized_swir_start = (viswir[2][0] - viswir[0][0]) / (viswir[2][1] - viswir[0][0])
    color_transitions.append((normalized_swir_start, 'black'))
    color_transitions.append((1.0, 'black'))

    color_transitions = sorted(color_transitions, key=lambda x: x[0])

    colormap = LinearSegmentedColormap.from_list('custom_spectrum', color_transitions)
    return colormap

In [37]:
def plot_savgol_results_with_pixels(results: dict, labels: list) -> None:
    """
    Plot the results of Savitzky-Golay filter optimization, including individual pixel lines and mean lines.
    
    Args:
        results (dict): Dictionary of results from calculate_savgol_results function.
        labels (list): Labels for the time axis.
    """
    ncols = 3
    nrows = len(results.keys())
    subplot_height = 6
    subplot_width = (4/3) * subplot_height  # Aspect ratio 4:3 for each subplot
    fig_width = ncols * subplot_width
    fig_height = nrows * subplot_height

    fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=(fig_width, fig_height))
    plt.style.use('grayscale')

    for i, (window_length, polynomial_order) in enumerate(results.keys()):
        og = results[(window_length, polynomial_order)]['og']
        smoothed = results[(window_length, polynomial_order)]['smoothed']
        res = results[(window_length, polynomial_order)]['res']

        mean_og = np.nanmean(og, axis=1)
        mean_smoothed = np.nanmean(smoothed, axis=1)
        mean_res = np.nanmean(res, axis=1)

        for pixel in range(og.shape[1]):  # Plot each pixel (pixels are along axis 0)
            # ax[i, 0].plot(labels, og[:,pixel], alpha=0.1, color='grey')  # Original with transparency
            ax[i, 0].plot(labels, smoothed[:,pixel], alpha=0.05, color='black')  # Smoothed with transparency
        
        # Overlay mean lines
        ax[i, 0].plot(labels, mean_og, color='grey', label='Mean Original', linewidth=2)
        ax[i, 0].plot(labels, mean_smoothed, color='black', label='Mean Smoothed', linewidth=2)
        ax[i, 0].set_title(f'Window: {window_length}, Poly: {polynomial_order}, w/p Ratio: {(window_length / polynomial_order):.2f}, Mean Res: {mean_res.mean()}')

        ### Power Spectrum
        power_og = np.abs(np.fft.fftshift(np.fft.fft(mean_og)))**2
        power_smoothed = np.abs(np.fft.fftshift(np.fft.fft(mean_smoothed)))**2
        #### Define frequency axis
        fpix = np.arange(og.shape[0]) - og.shape[0] // 2

        ax[i, 1].semilogy(fpix, power_og, alpha=0.5, color='grey')
        ax[i, 1].semilogy(fpix, power_smoothed, color='black')
        ax[i, 1].set_title('Power Spectrum')
        ax[i, 1].set_xlabel('Pixel Frequency')

        ### Standard Errors by Predictions
        stnd_res = StandardScaler().fit_transform(mean_res.reshape(-1, 1))
        stnd_pred = StandardScaler().fit_transform(mean_smoothed.reshape(-1, 1))
        ax[i, 2].scatter(stnd_pred, stnd_res, label=f'{window_length}, {polynomial_order}', color='black')
        ax[i, 2].axhline(0, color='red', linestyle='--')
        ax[i, 2].set_title('Standard Errors by Predictions')
        ax[i, 2].set_xlabel('Standardized Predictions')
        ax[i, 2].set_ylabel('Standardized Residuals')

    # Add grid to each subplot
    for axes in ax.flat:
        axes.grid(True)

    plt.tight_layout()
    plt.show()


def calculate_savgol_results(data : np.array, 
                             window_lengths : list, 
                             polynomial_orders : list, 
                             axis : int = 0) -> dict:
    """
    Apply Savitzky-Golay filter for different window lengths and polynomial orders.

    Args:
        data (np.array): Input data array of shape (days, rows, columns).
        window_lengths (list): List of window lengths to test.
        polynomial_orders (list): List of polynomial orders to test.
        axis (int): Axis along which to apply the filter (0 = time axis).

    Returns:
        dict: Dictionary with keys as (window_length, polynomial_order) and values as 
              a dictionary with original data, smoothed data, and residuals.
    """
    # Flatten data to (days, pixels)
    new_shape = (data.shape[0], data.shape[1] * data.shape[2])
    data_flat = data.reshape(new_shape)

    results = {}

    def optimal_savgol_params(window_length: int, polynomial_order: int) -> dict:
        savgol_smoothed = savgol_filter(data_flat, window_length, polynomial_order, axis=axis)
        return {
            'og': data_flat,
            'smoothed': savgol_smoothed,
            'res': data_flat - savgol_smoothed,
        }
    for window_length in window_lengths:
        for polynomial_order in polynomial_orders:
            if window_length > polynomial_order:
                results[(window_length, polynomial_order)] = optimal_savgol_params(window_length, polynomial_order)
    
    return results


def calculate_savgol(data : np.array,
                     window_length : int,
                     polynomial_order : int,
                     axis : int = 0) :
    
    new_shape = (data.shape[0], data.shape[1] * data.shape[2])
    data_flat = data.reshape(new_shape)
    savgol_smoothed = savgol_filter(data_flat, window_length, polynomial_order, axis=axis)
    return {
        'mask' : np.ma.masked_array(data_flat, mask=np.isnan(data_flat)),
        'og' : data_flat,
        'smoothed': savgol_smoothed,
        'res': data_flat - savgol_smoothed,
    }


def savgol_optimization(data: np.array,
                        labels: list,
                        window_lengths: list,
                        polynomial_orders: list,
                        axis: int = 0) -> None:
    """
    Apply Savitzky-Golay filter optimization on the input data and visualize the results.
    
    Args:
        data (np.array): Input data array of shape (days, rows, columns).
        labels (List[str]): Labels for each data point.
        window_lengths (List[int]): List of window lengths to test.
        polynomial_orders (List[int]): List of polynomial orders to test.
        axis (int): Axis along which to apply the filter (0 = time axis).
    """
    # Flatten data to (days, pixels)
    new_shape = (data.shape[0], data.shape[1] * data.shape[2])
    data_flat = data.reshape(new_shape)

    # Store results for different parameters
    results = {}
    
    def optimal_savgol_params(window_length: int, polynomial_order: int) -> dict:
        savgol_smoothed = savgol_filter(data_flat, window_length, polynomial_order, axis=0)
        return {
            'og': data_flat,
            'smoothed': savgol_smoothed,
            'res': data_flat - savgol_smoothed,
        }

    # Apply Savitzky-Golay filter for different window lengths and polynomial orders
    for window_length in window_lengths:
        for polynomial_order in polynomial_orders:
            if window_length > polynomial_order:
                results[(window_length, polynomial_order)] = optimal_savgol_params(window_length, polynomial_order)
    # pp(results)
    # Set up subplots
    ncols = 3
    nrows = len(results.keys())
    subplot_height = 6
    subplot_width = (4/3) * subplot_height # Aspect ratio 4:3 for each subplot
    fig_width = ncols * subplot_width
    fig_height = nrows * subplot_height

    fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=(fig_width, fig_height))
    plt.style.use('grayscale')

    for i, (window_length, polynomial_order) in enumerate(results.keys()):
        og = results[(window_length, polynomial_order)]['og']
        smoothed = results[(window_length, polynomial_order)]['smoothed']
        res = results[(window_length, polynomial_order)]['res']

        mean_og = np.nanmean(og, axis=1)
        mean_smoothed = np.nanmean(smoothed, axis=1)
        mean_res = np.nanmean(res, axis=1)

        ### Time Series OG vs Smoothed
        ax[i, 0].plot(labels, mean_og, alpha=0.5, color='grey')
        ax[i, 0].plot(labels, mean_smoothed, color='black',)
        ax[i, 0].set_title(f'Window: {window_length}, Poly: {polynomial_order}, w/p Ratio: {(window_length / polynomial_order):.2f}, Mean Res: {mean_res.mean()}')

        ### Power Spectrum
        power_og = np.abs(np.fft.fftshift(np.fft.fft(mean_og)))**2
        power_smoothed = np.abs(np.fft.fftshift(np.fft.fft(mean_smoothed)))**2
        #### Define frequency axis
        fpix = np.arange(og.shape[0]) - og.shape[0]//2

        ax[i, 1].semilogy(fpix, power_og, alpha=0.5, color='grey')
        ax[i, 1].semilogy(fpix, power_smoothed, color='black')
        ### Limit to 50px
        # ax[i, 1].set_xlim(-50, 50)
        ax[i, 1].set_title('Power Spectrum')

        ### Standard Errors by Predictions
        stnd_res = StandardScaler().fit_transform(mean_res.reshape(-1, 1))
        stnd_pred = StandardScaler().fit_transform(mean_smoothed.reshape(-1, 1))
        ax[i, 2].scatter(stnd_pred, stnd_res, label=f'{window_length}, {polynomial_order}', color='black')
        ax[i, 2].axhline(0, color='red', linestyle='--')

    # Add grid to each subplot
    for axes in ax.flat:
        axes.grid(True)

    plt.tight_layout()
    plt.show()

In [7]:
# import folium, io
# import plotly.express as px
# import numpy as np
# import helperfunctions as hf
# import rasterio as rio
# import pandas as pd
# import geopandas as gpd
# from shapely.geometry import box
# from branca.element import Figure
# from jinja2 import Template
# from matplotlib.colors import rgb2hex, LinearSegmentedColormap
# from matplotlib.cm import get_cmap
# from PIL import Image

class AutoMap:
    '''Class to generate maps from raster files.'''
    def __init__(self,
                 mapbox_api_key : str,
                 vi : str,
                 rio_obj : rio.DatasetReader) :
        
        self.mapbox_api_key = mapbox_api_key
        self.raster_array = rio_obj.read(masked = True)

        self.bbox = AutoMap.extract_boundingbox(rio_obj.bounds)
        self.bbox = AutoMap.reproject_boundingbox(self.bbox, rio_obj.crs, 'EPSG:4326')
        self.centroid = AutoMap.extract_bbox_centroid(self.bbox.values[0])

        self.vi = vi
        self.COLOR_RAMPS = {
                        'NDVI' : AutoMap.custom_cmap(['#D7bE69',
                                                '#A47E4f',
                                                '#5C6B28',
                                                '#424A26',
                                                '#2A3019'],
                                               n = 100),
                        'CVI' : 'viridis',
                        'NNI' : 'Greens',
                        'ET' : AutoMap.custom_cmap(['#B27116',
                                              '#FFB957',
                                              '#FFC370',
                                              '#1978B2',
                                              '#57BFFF'],
                                              n = 100)
                        }

    ### Color Stuff
    @staticmethod
    def custom_cmap(colors : list,
                    n : int = 100,
                    reverse : bool = False) -> LinearSegmentedColormap :
        '''Creates a custom colormap.
        input:
            - colors: list of tuples
        output:
            - cmap: matplotlib.colors.LinearSegmentedColormap
        '''
        if reverse:
            colors = colors[::-1]
        cmap = LinearSegmentedColormap.from_list("", colors, N=n)
        return cmap


    @staticmethod
    def normalize_minmax(array : np.ndarray) -> np.ndarray :
        '''Normalizes an array between 0 and 1.
        input:
            - array: np.array
        output:
            - normalized_array: np.array
        '''
        mask = np.ma.array(array, mask=np.isnan(array))
        min_val = np.nanmin(mask.data)
        max_val = np.nanmax(mask.data)
        # Perform normalization
        normalized_array = (mask.data - min_val) / (max_val - min_val)

        return normalized_array


    @staticmethod
    def get_color(x : float,
                  ramp : str = 'viridis') -> tuple[int]:
        '''Returns a color from a color ramp.
        input:
            - x: float
            - ramp: string
        output:
            - color: tuple
        '''
        decimals = 2
        x = np.around(x, decimals=decimals)
        ls = np.linspace(0,1,10**decimals+1)
        if 0 <= x <= 1:
            color = get_cmap(ramp)(x)
            return tuple((int(255 * color[0]), int(255 * color[1]), int(255 * color[2]), int(255 * color[3])))
        elif np.isnan(x):
            return (0, 0, 0, 0)
        else:
            raise ValueError()


    ### Bounding Box Stuff
    @staticmethod
    def reproject_boundingbox(bbox : list[float],
                              source_crs : 'str' ,
                              target_crs : str = 'EPSG:4326') -> tuple [float]:
        '''Reprojects a bounding box.
        Input:
            - bbox: list
            - source_crs: string
            - target_crs: string
        Output:
            - bbox: tuple
        '''
        bbox_df = pd.DataFrame({'geometry': [box(*bbox)]})
        bbox_gdf = gpd.GeoDataFrame(bbox_df, geometry='geometry')
        bbox_gdf.crs = {'init' : source_crs}
        bbox_gdf = bbox_gdf.to_crs(target_crs)
        bbox = bbox_gdf.geometry.bounds

        return bbox


    @staticmethod
    def extract_boundingbox(bounds : list[float]) -> tuple[float] :
        '''Extracts the bounding box from a raster file.
        Input:
            - bounds: list
            - source_crs: string
            - target_crs: string
            - reproject: boolean
        Output:
            - bbox: tuple
        '''
        minx, miny, maxx, maxy = bounds
        return minx, miny, maxx, maxy


    @staticmethod
    def extract_bbox_centroid(bbox : list[float]) -> tuple[float, float] :
        '''Extracts the centroid of a bounding box.
        Input:
            - bbox: list
        Output:
            - centroid: tuple'''
        return (bbox[0] + bbox[2])/2, (bbox[1] + bbox[3])/2


    def custom_colorbar(self,
                        colorbar_params : dict) -> str :
        '''Generates a custom colorbar for the map.
        Input:
            - colorbar_params: dict
        Output:
            - html: string
        '''
        html = """
            <head>
                <meta charset="UTF-8">
                <style>
                    .label-container {
                        width: {{ width }}px;
                        height: {{ height }}px;

                        background: linear-gradient(to right, 
                        {% for hex in hex_codes %}
                            {{ hex }}
                            {% if not loop.last %},{% endif %}
                        {% endfor %}
                        );
                        border-radius: 5px;
                        border: 0px solid #000;
                        justify-content: space-between;
                        align-items: center;
                        display: flex;
                        padding: 0px;
                        margin: 0 auto;
                        position: fixed;
                        z-index: 1000;
                        bottom: 0px;
                        left: 50%;
                        transform: translate(-50%, -50%);
                    }

                    .label {
                        color: {{ font_color }};
                        font-size: 22px;
                        font-weight: bold;
                        padding: 0px 10px;
                        margin: 0px;
                        border: none;
                    }
                </style>
            </head>
            <body>
                <div class="label-container">
                    <div class="label">{{ min }}</div>
                    <div class="label">{{ label }}</div>
                    <div class="label">{{ max }}</div>
                </div>
            </body>
        """
        my_templ = Template(html)
        with open('temp.html', 'w') as f:
            return my_templ.render(**colorbar_params)


    # TODO: Complete
    def make_histogram(self) -> px.histogram :
        pass


    # TODO: Complete
    def make_timeseries_graph(self) -> px.scatter_mapbox :
        pass


    def make_map(self) -> folium.Map :
        '''Generates a map from a raster file.
        Input:
            - raster_filepath: string, path to raster file
            - output_dir: string, path to output directory
        Output:
            - folium.Map
        '''

        palette = self.COLOR_RAMPS[self.vi.upper()]
        if not isinstance(palette, list) :
            cmap = get_cmap(palette, 100)

        px.set_mapbox_access_token(open(self.mapbox_api_key).read())
        mapboxTilesetId = 'mapbox.satellite'

        fig = Figure(width=800, height=500)

        m = folium.Map(
            location = [self.centroid[1], self.centroid[0]],
            zoom_start=18,
            zoom_control=False,
            tiles='https://api.tiles.mapbox.com/v4/' + mapboxTilesetId + '/{z}/{x}/{y}.png?access_token=' + open(".mapbox_token").read(),
            attr='mapbox.com'
        )

        fig.add_child(m)

        min, max = np.min(self.raster_array), np.max(self.raster_array)
        raster_arr = self.raster_array.reshape(-1, self.raster_array.shape[-1])
        norm_raster_arr = AutoMap.normalize_minmax(raster_arr)

        folium.raster_layers.ImageOverlay(norm_raster_arr,
                            bounds = [[self.bbox.miny.min(),
                                       self.bbox.minx.min()],
                                      [self.bbox.maxy.max(),
                                       self.bbox.maxx.max()]],
                            colormap = lambda x: AutoMap.get_color(x, ramp=palette),
                            ).add_to(m)

        colorbar_params = {
            'label' : self.vi,
            'max': "{0:.2f}".format(max),
            'min': "{0:.2f}".format(min),
            'hex_codes' : [rgb2hex(c) for c in cmap(np.linspace(0, 1, 100))],
            'width': 600,
            'height': 50,
            'font_color' : 'black'
        }

        colorbar_html = self.custom_colorbar(colorbar_params)
        m.get_root().html.add_child(folium.Element(colorbar_html))

        map_data = m._to_png()
        map = Image.open(io.BytesIO(map_data))
        map = map.convert('RGB')
        return map

NameError: name 'px' is not defined

In [8]:
### TODO:
# 1. repurpose class to be a general class for generating training data for grape variety discrimination
# 2. 

class HSITrainingDataGenerator() :
    '''Class to generate training data for hyperspectral images
    '''
    def __init__(self,
                 data : rio.io.DatasetReader,
                 vineyard_boundaries : gpd.GeoDataFrame,
                 label_col : str,
                 bounds = None,
                 badband_ranges : np.array = np.array([[300,400],
                                                       [1310,1470],
                                                       [1750,2000],
                                                       [2450,2600]])) :
        '''Class to generate
        Input:
            - data_path : str, path to the data file
            - geometry_path : str, path to the geometry file
            - labels : str, name of the column in the geometry file
            - reclassify : bool, reclassify the data based on the strategy
            - reclassify_strategy : str, strategy to reclassify the data
            - badbands : np.array, list of bad bands to remove
        '''
        self.vineyard_boundaries = vineyard_boundaries
        self.label_col = label_col
        self.bounds = bounds
        if bounds is None :
            self.bounds = self.vineyard_boundaries.total_bounds
        self.badbands = badband_ranges

        self.data = data
        self.meta = self.data.meta
        self.image_spatial_resolution = self.data.res[0]

        if self.data.descriptions is None :
            self.column_headers = [str(i) for i in range(self.data.count)]
        else :
            self.full_wls = get_wls(self.data.tags())
            self.goodbands = get_goodbands(self.full_wls)
            if len(self.full_wls) > 0 :
                self.column_headers = self.full_wls
            else :
                self.column_headers = self.data.descriptions

        # # set all bands within badbands to np.nan
        # for bb in self.badbands :

    def zonal_statistics(self) :
        return 1

        # if self.sampling_points.crs != self.data.crs :
        #     self.sampling_points = self.sampling_points.to_crs(self.data.crs)

        # if reclassify :
        #     self.grid = HSITrainingDataGenerator.generate_resampling_grid(self.sampling_points.crs,
        #                                                                   self.bounds,
        #                                                                   self.image_spatial_resolution)
        #     self.grid = HSITrainingDataGenerator.point_grid_count(self.grid,
        #                                                           self.sampling_points,
        #                                                           labels)
        #     self.grid_polygons = self.grid['geometry']
        #     self.grid.index = self.grid['geometry']
        #     self.grid = self.grid.drop(columns='geometry')

        #     if reclassify_strategy == 'majority' :
        #         self.sampling_points = HSITrainingDataGenerator.classify_grid_by_majority(self.grid)
        #         self.labels = 'majority'
        #     elif reclassify_strategy == 'threshold' :
        #         self.sampling_points = HSITrainingDataGenerator.classify_grid_by_threshold(self.grid)
        #     else :
        #         raise ValueError('Invalid reclassify strategy.')

        # self.training_data = HSITrainingDataGenerator.point_spatial_sample(self.sampling_points,
        #                                                                    self.data,
        #                                                                    column_headers=self.column_headers)

    def plot_spectra_by_label(self) :
        return 1

    @staticmethod
    def point_spatial_sample(sampling_points : gpd.GeoDataFrame,
                             data : rio.io.DatasetReader,
                             column_headers : List[Union[int, str]] = None) -> gpd.GeoDataFrame:
        """
        Extracts raster values at point locations from a vector file.
        Input
            - rio_obj : rasterio object, Raster file to extract values from.
            - geopandas_obj : geopandas object
            - column_headers : list, list of column headers for the output dataframe.
        Output:
            - geopandas dataframe, Dataframe with point locations and raster values.
        """
        points = sampling_points.geometry
        xy = [xy for xy in zip(points.x, points.y)]        
        sampler = rio.sample.sample_gen(data,
                                        xy,
                                        indexes=None,
                                        masked=True)
        spec_df = pd.DataFrame(sampler, columns=column_headers)

        spec_df.columns = spec_df.columns.map(str)
        spec_df = sampling_points.merge(spec_df, left_index=True, right_index=True)

        spec_df = spec_df[(spec_df[str(column_headers[0])] != "--") & (spec_df[str(column_headers[0])] != 0)]

        return spec_df

In [9]:
### Anaylsis of ET, PE, and FAPAR datasets for grapevine variety discrimination
data_repo = '/Users/fgalvan/Projects/GLRaV3_ET/data/stacked'
images = get_images(data_repo)
ic = image_collection(images, [extract_module, extract_metric])

In [38]:
### Plotting the results of the savgol optimization by pixel
data = ic['Canepa_239']['E'][0]

with rio.open(data) as src:
    testing_data = src.read(masked=True)
    testing_meta = src.meta
    testing_band_names = {i : extract_timestamp(desc) for i, desc in enumerate(src.descriptions)}

ssd = datetime(2020,4,1) # season start date
sed = datetime(2020,11,1) # season end date
total_days = (sed - ssd).days

nearest_odd_half = int(np.ceil(total_days/2) // 2 * 2 + 1)
window_lengths = [31, 51, nearest_odd_half, 150, 200]
polynomial_orders = [3]

growing_season = {k: v for k, v in testing_band_names.items() if ssd <= v <= sed}
growing_season_data = testing_data[list(growing_season.keys())]

# results = calculate_savgol_results(growing_season_data, window_lengths, polynomial_orders)
# plot_savgol_results_with_pixels(results, labels = growing_season.values())

best_savgol = calculate_savgol(growing_season_data, nearest_odd_half, 3)
display(best_savgol['smoothed'])

kmeans = KMeans(n_clusters=2, random_state=0)
cluster_labels = kmeans.fit(best_savgol['smoothed'])

# # Plot the clusters
# for cluster_id in np.unique(cluster_labels):
#     plt.plot(cluster_labels, best_savgol[cluster_labels == cluster_id].T, 
#              alpha=0.1 if cluster_id == 0 else 0.1, 
#              color='black' if cluster_id == 0 else 'grey')  # Different colors per cluster
# plt.title('K-Means Clustering of Smoothed Curves')
# plt.grid(True)
# plt.show()

array([[3.86258925, 4.00981909,        nan, ..., 3.67158875, 3.68531951,
        3.68249428],
       [3.83595729, 3.96696865,        nan, ..., 3.56133061, 3.57944511,
        3.58582795],
       [3.80895605, 3.92437096,        nan, ..., 3.45501221, 3.47693174,
        3.49172055],
       ...,
       [0.45635317, 0.43007599,        nan, ..., 0.58248301, 0.49760664,
        0.44604409],
       [0.41988432, 0.39353185,        nan, ..., 0.54560294, 0.46226934,
        0.41054305],
       [0.38185909, 0.3554214 ,        nan, ..., 0.50706344, 0.42538006,
        0.37354074]])

ValueError: Input X contains NaN.
KMeans does not accept missing values encoded as NaN natively. For supervised learning, you might want to consider sklearn.ensemble.HistGradientBoostingClassifier and Regressor which accept missing values encoded as NaNs natively. Alternatively, it is possible to preprocess the data, for instance by using an imputer transformer in a pipeline or drop samples with missing values. See https://scikit-learn.org/stable/modules/impute.html You can find a list of all estimators that handle NaN values at the following page: https://scikit-learn.org/stable/modules/impute.html#estimators-that-handle-nan-values

In [None]:
### Plotting multiple years same site
testing = ic['Canepa_239']['E'][0]

with rio.open(testing) as src:
    testing_data = src.read(masked=True)
    testing_meta = src.meta
    testing_band_names = {i : extract_timestamp(desc) for i, desc in enumerate(src.descriptions)}

s1_ssd = datetime(2020,4,1) # season start date
s1_sed = datetime(2020,11,1) # season end date
s2_ssd = datetime(2021,4,1) # season start date
s2_sed = datetime(2021,11,1) # season end date
s3_ssd = datetime(2022,4,1) # season end date
s3_sed = datetime(2022,11,1) # season end date

mean_total_days = ((s1_sed - s1_ssd).days + (s2_sed - s2_ssd).days + (s3_sed - s3_ssd).days) / 3
nearest_odd_half = int(np.ceil(mean_total_days/2) // 2 * 2 + 1)

window_lengths = [31, 51, 71, 91, nearest_odd_half, 150, 200]
polynomial_orders = [3]

growing_season = {k: v for k, v in testing_band_names.items() if ssd <= v <= sed}
growing_season_data = testing_data[list(growing_season.keys())]

savgol_optimization(data = growing_season_data,
                    labels = growing_season.values(),
                    window_lengths = window_lengths,
                    polynomial_orders = polynomial_orders)

In [None]:
### TODO:
# - Download data
# - Clip images to vineyard boundaries
# - Assign labels to the geometry

In [None]:
### TODO:
# Load polygons here
# Prepare the data for the model
# Explore the UMAP embeddings