# In [None]:  (Jupyter cell marker left over from notebook export)
from . import utils

class Predictor:
    '''
    Class for predicting land use maps based on historical data and driver maps.

    Attributes:
        land_cover_map (xarray): xarray dataset with pixel values 0, 1, and 2, where 0 is masked pixels
        beginning_time (int): year of the beginning time step
        ending_time (int): year of the ending time step
        time_step (int): number of years between time steps 1 and 2
        suitability_map (xarray): suitability map with float values from 0-100; leaving it as None
            triggers calculation of the suitability map from the driver maps
        driver_maps (xarray): stack of categorical driver maps that the module will use to calculate
            a suitability map if suitability_map is None
        validation_map (xarray): xarray dataset with pixel values 0, 1, and 2, where 0 is masked pixels,
            for the ending time step
        pixel_quantities (tuple, int): projected quantity of land use states 1 and 2 at the ending time.
            It is passed unchanged to utils.reclassify_suitability. NOTE(review): the intended
            "if None, calculate from validation_map" behaviour is not implemented in this class;
            confirm whether utils handles None.
        mask_image (xarray): binary mask of 0 and 1 to define the study area.
            NOTE(review): stored but not currently used by any method of this class.
        strata_map (xarray): integer image dividing the study area into regions
        constrain_to_neighborhood (int, tuple): tuple of values to constrain change to the pixel
            neighborhood, None if no constraint
        driver_map_weights (dict): weights for each driver map, {map_name: weight, map_name2: weight2};
            if None, equal weights
        predicted_map (xarray or None): result of the most recent predict() call; None until then

    Methods:
        predict(): build and return the predicted classification map (also stored in predicted_map).
        validate(): compare the prediction against validation_map; runs predict() first if needed.
    '''
    def __init__(
        self,
        land_cover_map,
        beginning_time,
        ending_time,
        time_step,
        suitability_map=None,
        driver_maps=None,
        validation_map=None,
        pixel_quantities=None,
        mask_image=None,
        strata_map=None,
        constrain_to_neighborhood=None,
        driver_map_weights=None
        ):
        self.land_cover_map = land_cover_map
        # The time parameters were previously accepted but silently discarded.
        self.beginning_time = beginning_time
        self.ending_time = ending_time
        self.time_step = time_step
        self.driver_maps = driver_maps
        self.suitability_map = suitability_map
        self.validation_map = validation_map
        self.pixel_quantities = pixel_quantities
        self.mask_image = mask_image
        self.strata_map = strata_map
        self.constrain_to_neighborhood = constrain_to_neighborhood
        self.driver_map_weights = driver_map_weights
        # Populated by predict(); validate() depends on it.
        self.predicted_map = None

        # Derive the suitability map from the driver maps only when none was supplied.
        if self.suitability_map is None:
            self.suitability_map = utils.create_suitability_map(
                self.land_cover_map, self.driver_maps, self.driver_map_weights, self.strata_map
            )

    def predict(self):
        '''
        Predict land use maps based on the provided parameters.

        The suitability map is masked, optionally restricted to the neighborhood
        edges, reclassified using pixel_quantities, and finally combined with the
        original land cover map. The result is also stored in self.predicted_map
        so that validate() can use it.

        Returns:
            xarray: Predicted classification map.
        '''
        # Step 1: Mask pixels already developed in the land cover map out of the suitability map
        candidates = utils.mask_developed_pixels(self.suitability_map, self.land_cover_map)

        # Step 2: Optionally restrict candidates to the neighborhood edges, then reclassify
        if self.constrain_to_neighborhood is not None:
            candidates = utils.get_edges(candidates, self.constrain_to_neighborhood)
        reclassified_suitability = utils.reclassify_suitability(
            candidates, pixel_quantity=self.pixel_quantities
        )

        # Step 3: Combine the reclassified suitability map with the original land cover map
        self.predicted_map = utils.combine_maps(reclassified_suitability, self.land_cover_map)
        return self.predicted_map

    def validate(self):
        '''
        Validate the predicted land use maps against self.validation_map.

        If predict() has not been called yet, it is run first so that a
        prediction is available.

        Returns:
            dict: Table of performance metrics from utils.CROSSTAB.
            xarray: Validation map from utils.validate.
        '''
        if self.predicted_map is None:
            self.predict()

        performance_table = utils.CROSSTAB(self.land_cover_map, self.predicted_map, self.validation_map)

        # Named differently from the self.validation_map input to avoid confusion.
        validation_result = utils.validate(self.predicted_map, self.validation_map)

        return performance_table, validation_result
