In [None]:
import rasterio
import numpy as np
import os
import glob

In [7]:
class Prob_Image:
    """Crop-probability raster wrapper with slope / threshold mask helpers.

    On construction the raster at ``file_path`` is opened with rasterio and
    kept open; derived arrays are computed on demand:

    * ``crop_stack``  - bands at 0-based odd indices, divided by ``PROB_SCALE``;
    * ``slope_image`` - per-pixel least-squares slope of ``crop_stack``
      regressed on the layer index 0, 1, 2, ...;
    * several 0/1 ``uint16`` masks built by thresholding those arrays.

    Masks are stored on the instance (``mask_slope``, ``mask_endprob``,
    ``mask_iniprob_slope``) so they can later be written with
    :meth:`generic_export`.  Call :meth:`close`, or use the object as a
    context manager, to release the dataset.
    """

    # Class-level defaults, kept for backward compatibility; __init__ gives
    # every instance its own copies so no mutable state is shared.
    out_meta = {}
    mask_iniprob_slope = []  # image

    mask_slope = []  # image
    current_slope_threshold = None

    mask_endprob = []  # image
    current_endprob_threshold = None

    out_image = []
    current_out_image = 'Empty'

    # Divisor applied to the raw band values in ``crop_stack``.
    # NOTE(review): 65532 rather than 65535 (uint16 max) - confirm intended.
    PROB_SCALE = 65532
    # Fixed probability cut-off used by ``init_crop_mask``.
    INIT_CROP_THRESHOLD = 0.8

    # image_type accepted by generic_export -> attribute holding that image.
    _EXPORTABLE = {
        'slope_mask': 'mask_slope',
        'shapelet': 'shapelet_image',
        'endprob_mask': 'mask_endprob',
        'iniprob_slope_mask': 'mask_iniprob_slope',
    }

    def __init__(self, file_path):
        """Open ``file_path`` with rasterio and keep the dataset handle.

        Args:
            file_path: path of the probability raster to read.
        """
        self.file_path = file_path
        self.src = rasterio.open(self.file_path)
        self.meta_data = self.src.meta
        # Per-instance state (shadows the class-level defaults).
        self.out_meta = {}
        self.mask_iniprob_slope = []
        self.mask_slope = []
        self.current_slope_threshold = None
        self.mask_endprob = []
        self.current_endprob_threshold = None
        self.out_image = []
        self.current_out_image = 'Empty'

    def close(self):
        """Close the underlying rasterio dataset."""
        self.src.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

    @property
    def crop_stack(self):
        """Return the bands at 0-based odd indices (1, 3, 5, ...), scaled.

        The whole dataset is re-read on every access, so methods below bind
        the result to a local instead of touching the property repeatedly.

        Returns:
            Array of shape ``(n_bands // 2, rows, cols)`` holding the raw
            values divided by ``PROB_SCALE``.
        """
        array = self.src.read()
        # NOTE(review): presumably the even-indexed bands hold something other
        # than crop probability - confirm against the raster's band layout.
        return array[1::2, :] / self.PROB_SCALE

    @property
    def init_crop_prob(self):
        """Return layer 1 of ``crop_stack``, used as the initial crop probability.

        NOTE(review): this was described as t=0, yet it reads index 1 while
        ``slope_image`` assigns t=0 to index 0 - confirm the intended layer.
        """
        return self.crop_stack[1, :, :]

    @property
    def init_crop_mask(self):
        """Return a uint16 0/1 mask: 1 where ``init_crop_prob`` > 0.8, else 0."""
        return (self.init_crop_prob > self.INIT_CROP_THRESHOLD).astype(np.uint16)

    @property
    def slope_image(self):
        """Per-pixel ordinary-least-squares slope of ``crop_stack`` over time.

        Each pixel series ``crop_stack[:, r, c]`` is regressed on
        t = 0, 1, ..., n-1.  For a straight-line fit the slope is
        ``sum((t - t_mean) * y) / sum((t - t_mean) ** 2)``; the weights below
        encode exactly that (they equal row 1 of ``(X^T X)^-1 X^T``).

        Returns:
            Float array of shape ``(rows, cols)``.

        Raises:
            numpy.linalg.LinAlgError: if ``crop_stack`` has fewer than two
                layers, where the slope is undefined.
        """
        stack = self.crop_stack  # read the raster once
        band_num = stack.shape[0]
        if band_num < 2:
            raise np.linalg.LinAlgError(
                'slope_image needs at least two layers in crop_stack')
        centred = np.arange(band_num) - (band_num - 1) / 2.0
        weights = centred / np.sum(centred ** 2)
        # Contract the layer axis: (n,) . (n, rows, cols) -> (rows, cols).
        return np.tensordot(weights, stack, axes=(0, 0))

    def mask_ini_slope_image(self, init_thresh, slope_thresh):
        """Mask pixels with high initial probability and a slope above a cut-off.

        Intended to refine the GFSAD layer using the initial crop probability
        and how much that probability changes over time.  A pixel is 1 when
        ``init_crop_prob > init_thresh`` and it is NOT true that
        ``slope_image <= slope_thresh``; otherwise 0.

        Args:
            init_thresh: cut-off on ``init_crop_prob`` (e.g. 0.8).
            slope_thresh: cut-off on ``slope_image``; slopes can be negative.

        Returns:
            The uint16 0/1 mask, also stored in ``self.mask_iniprob_slope``.
        """
        high_init = self.init_crop_prob > init_thresh
        # "not <=" rather than ">" so a NaN slope does not clear the pixel.
        slope_ok = ~(self.slope_image <= slope_thresh)
        self.mask_iniprob_slope = (high_init & slope_ok).astype(np.uint16)
        return self.mask_iniprob_slope

    def mask_of_slope(self, slope_thresh, usenan=0):
        """Mask pixels whose slope is at or below ``slope_thresh``.

        Args:
            slope_thresh: slope cut-off.
            usenan: output format.
                0 -> uint16 array, 1 where slope <= slope_thresh, else 0.
                1 -> ``numpy.ma`` masked array of ones, masked where
                     slope > slope_thresh.
                other -> prints a warning and behaves like 0.

        Returns:
            The mask.  It is also stored in ``self.mask_slope``, and the
            threshold in ``self.current_slope_threshold``.
        """
        slope = self.slope_image

        if usenan == 1:
            mask = np.ma.masked_where(slope > slope_thresh, np.ones(slope.shape))
            self.mask_slope = mask
            self.current_slope_threshold = slope_thresh
            return mask

        if usenan != 0:
            print('wrong input, not use nan')
        mask = (slope <= slope_thresh).astype(np.uint16)
        self.mask_slope = mask
        self.current_slope_threshold = slope_thresh
        # Separate copy so a caller editing the result leaves the stored mask intact.
        return mask.copy()

    def mask_of_endprob(self, prob_thresh):
        """Mask pixels whose last-layer probability is at or below ``prob_thresh``.

        Args:
            prob_thresh: cut-off applied to ``crop_stack[-1]``.

        Returns:
            uint16 array, 1 where ``crop_stack[-1] <= prob_thresh``, else 0.
            It is also stored in ``self.mask_endprob``, and the threshold in
            ``self.current_endprob_threshold``.
        """
        end_prob = self.crop_stack[-1, :, :]
        mask = (end_prob <= prob_thresh).astype(np.uint16)
        self.mask_endprob = mask
        self.current_endprob_threshold = prob_thresh
        return mask.copy()

    def update_out_meta(self, update_dict):
        """Reset ``out_meta`` to a copy of the source metadata, then apply ``update_dict``.

        Earlier ``update_out_meta`` calls are discarded because the copy is
        taken from ``meta_data`` each time.
        """
        self.out_meta = self.meta_data.copy()
        self.out_meta.update(update_dict)

    def write_crop_stack(self, out_path, out_name):
        """Write ``crop_stack`` to ``out_path/out_name`` using ``out_meta``.

        NOTE(review): ``crop_stack`` is floating point with ``n_bands // 2``
        layers; ``out_meta`` presumably needs a matching 'count' and a float
        'dtype' (set via ``update_out_meta``) - confirm before relying on it.
        """
        out_file = os.path.join(out_path, out_name)
        with rasterio.open(out_file, "w", **self.out_meta) as dest:
            dest.write(self.crop_stack)

    def generic_export(self, image_type, out_path, out_name):
        """Write one single-band image held by this object to ``out_path/out_name``.

        Args:
            image_type: one of ``'slope_mask'`` (``mask_slope``), ``'shapelet'``
                (``shapelet_image``), ``'endprob_mask'`` (``mask_endprob``) or
                ``'iniprob_slope_mask'`` (``mask_iniprob_slope``).
            out_path: output directory.
            out_name: output file name.

        An unrecognized ``image_type`` prints a message and returns without
        writing anything.  ``out_meta`` is reset from the source metadata
        with ``count=1`` before writing.

        NOTE(review): ``shapelet_image`` is not defined in this class; it must
        be supplied elsewhere (e.g. by a subclass), or 'shapelet' raises
        AttributeError.
        """
        attr_name = self._EXPORTABLE.get(image_type)
        if attr_name is None:
            print('image type not recognized !')
            return

        self.out_image = getattr(self, attr_name)
        self.current_out_image = image_type
        self.update_out_meta({'count': 1})  # the exported image has a single band
        with rasterio.open(os.path.join(out_path, out_name), "w", **self.out_meta) as dest:
            dest.write(np.expand_dims(self.out_image, axis=0))
    
    