In [0]:
import pyspark.sql.functions as f


# This notebook contains all the functions needed for the pyspark notebook.
def get_kernel_for_x_y(index_x,index_y, data):
    """
    Get the band values of a single pixel, followed by the two indexes of that pixel.

    Despite the name, no neighbouring pixels are collected: only the value at
    band[index_x][index_y] is read from every band in data.
    The indexes are array indexes into each band, not map coordinates.

    @param index_x: first index into each band (band[index_x]).
    @param index_y: second index into each band (band[index_x][index_y]).
    @param data: an iterable of 2D bands, all indexable as band[index_x][index_y].
    @return a plain list of zeros (one per band plus two for the indexes) when the band
            values of the pixel sum to 0, otherwise an integer numpy array holding the
            band values followed by index_x and index_y.
    """
    # Read the band values once; they are needed for both the emptiness check and the result.
    spot_kernel = [band[index_x][index_y] for band in data]

    if sum(spot_kernel) == 0:
        # Size the empty row on the actual band count so every row has the same length
        # (8 entries for the usual 6 bands).
        return [0] * (len(spot_kernel) + 2)

    spot_kernel.append(index_x)
    spot_kernel.append(index_y)
    return np.array(spot_kernel).astype(int)
          
class scaler_class_all:
    """
    Applies six independently fitted scalers, one per band, to columns of a pandas dataframe.

    Each scaler is loaded from its own file with joblib and must therefore have been
    created and saved beforehand.
    """
    def __init__(self, scaler_file_band1 = "", scaler_file_band2 = "", scaler_file_band3 = "", scaler_file_band4 = "", scaler_file_band5 = "", scaler_file_band6 = "") :
        """
        Load the scaler of every band from disk.

        @param scaler_file_band1: Path to a file which contains the scaler for band 1.
        @param scaler_file_band2: Path to a file which contains the scaler for band 2.
        @param scaler_file_band3: Path to a file which contains the scaler for band 3.
        @param scaler_file_band4: Path to a file which contains the scaler for band 4.
        @param scaler_file_band5: Path to a file which contains the scaler for band 5.
        @param scaler_file_band6: Path to a file which contains the scaler for band 6.
        """
        scaler_files = (
            scaler_file_band1,
            scaler_file_band2,
            scaler_file_band3,
            scaler_file_band4,
            scaler_file_band5,
            scaler_file_band6,
        )

        # Stored as self.scaler_band1 ... self.scaler_band6, loaded in band order.
        for band_number, scaler_file in enumerate(scaler_files, start=1):
            setattr(self, "scaler_band" + str(band_number), joblib.load(scaler_file))


    def transform(self,pixel_df, col_names = ['band1','band2','band3','band4','band5',"band6"]):
        """
        Scale six columns of a pandas dataframe, each with the scaler of its band.

        The dataframe is modified in place: every listed column is overwritten with
        the output of the matching scaler.

        @param pixel_df: dataframe in which the band columns have to be scaled.
        @param col_names: the six column names, ordered band 1 to band 6.
        @return: the same dataframe object, with the band columns scaled.
        """
        for band_number in range(1, 7):
            band_scaler = getattr(self, "scaler_band" + str(band_number))
            column = col_names[band_number - 1]
            # The scalers are fed a single-feature 2D array of shape (n, 1).
            column_values = pixel_df[column].values.reshape(-1, 1)
            pixel_df[column] = band_scaler.transform(column_values)

        return pixel_df

def func_cor_square(input_x_y):
    """
    Build a 2 by 2 square polygon for a point, snapped to a grid of even coordinates.

    Both coordinates are first rounded to the nearest multiple of 2 (using Python's
    round); the square then extends 2 units in the positive x and y direction from there.

    @param input_x_y: a sequence whose first two items are the x and y of the point.
    @return a Polygon describing the square, with the first corner repeated to close the ring.
    """
    min_x = round(input_x_y[0]/2)*2
    min_y = round(input_x_y[1]/2)*2
    max_x = min_x + 2
    max_y = min_y + 2

    corners = [
        (min_x, min_y),
        (max_x, min_y),
        (max_x, max_y),
        (min_x, max_y),
        (min_x, min_y),
    ]
    return Polygon(corners)

def func_cor_square_50cm(input_x_y):
    """
    Build a 0.5 by 0.5 square polygon for a point.

    The square runs from (x - 0.5, y - 0.5) up to (x, y), so the input point is a
    corner of the square rather than its centre.
    NOTE(review): if the input is a pixel centre this square is shifted - confirm intended.

    @param input_x_y: a sequence whose first two items are the x and y of the point.
    @return a Polygon describing the square, with the first corner repeated to close the ring.
    """
    min_x = input_x_y[0] - 0.5
    min_y = input_x_y[1] - 0.5
    max_x = min_x + 0.5
    max_y = min_y + 0.5

    corners = [
        (min_x, min_y),
        (max_x, min_y),
        (max_x, max_y),
        (min_x, max_y),
        (min_x, min_y),
    ]
    return Polygon(corners)
      
def dissolve_gpd_output(agpd, path_out):
    """
    Merge all geometries that share a label into one geometry per label and write the result to disk.

    This reduces the number of features, and thus the amount of space needed to save the file.

    @param agpd: a geopandas dataframe with a 'label' column and a geometry column.
    @param path_out: Location of where to store the dissolved output. When the path contains
                     '.geojson' the GeoJSON driver is used, otherwise the default driver of to_file.
    """
    # Collect one record per label and build the frame once; DataFrame.append was removed
    # in pandas 2.0 and re-copied the whole frame on every iteration.
    records = [
        {"label": label, "geometry": agpd[agpd['label'] == label].unary_union}
        for label in agpd['label'].unique()
    ]
    dissolved = gpd.GeoDataFrame(records, columns=['label', 'geometry'], crs=agpd.crs)

    if '.geojson' in path_out:
        dissolved.to_file(path_out, driver="GeoJSON")
    else:
        dissolved.to_file(path_out)

def check_done_files():
    """
    List the leading name part of every output file that belongs to the current model.

    Relies on the names path_to_output and model_path being defined outside this function.
    The search pattern is built from the last '_'-separated piece of the model file name
    (without '.sav'), with dots replaced by underscores.

    @return a list with, for each matching file, the part of its file name before the first '_'.
            NOTE(review): presumably this is the date of the image - confirm against the file naming.
    """
    model_file_name = model_path.split("/")[-1]
    model_tag = model_file_name.split(".sav")[0].split("_")[-1].replace(".","_")
    search_pattern = path_to_output+"*"+model_tag+"*"

    return [found.split("/")[-1].split("_")[0] for found in glob.glob(search_pattern)]

## pyspark udf functions

# Pandas UDFs are currently the fastest way to do distributed predictions.
@f.pandas_udf('string')
def predict_pandas_udf(*cols):
    """
    Pyspark wrapper around the predict function of a model.

    Relies on the name loaded_model being defined outside this function.

    @param cols: the feature columns; each one arrives as a pandas.Series.
    @return a pandas.Series with the prediction of loaded_model for every row.
    """
    # Put the separate series side by side so every row holds the features of one pixel.
    features = pd.concat(cols, axis=1)
    predictions = loaded_model.predict(features)
    return pd.Series(predictions)
    
# UDF for calculating the mode, used for 2m aggregations.
@f.udf
def mode(x):
    """
    Return the most frequently occurring value of a collection.

    @param x: an iterable of hashable values; must not be empty.
    @return the value with the highest count according to collections.Counter.
    """
    from collections import Counter
    most_common_entries = Counter(x).most_common(1)
    return most_common_entries[0][0]
  

## main runner function  
def _split_rows(total_rows, parts):
    """
    Split the row range [0, total_rows) into consecutive half-open (begin, end) intervals.

    @param total_rows: the number of rows to divide.
    @param parts: the number of intervals to produce.
    @return a list of `parts` (begin, end) tuples which together cover every row exactly once.
    """
    return [((part * total_rows) // parts, ((part + 1) * total_rows) // parts) for part in range(parts)]


def run_tif_model_implementer(a_path_to_tif_file, path_to_output, path_to_scalers, a_parts, a_model_path, aggregate_to_2m = True):
    """
    Predict a label for every non-empty pixel of a raster file and store the result as a shapefile.

    The raster is processed in horizontal strips (parts) to reduce memory load. Every strip is
    filtered, normalized, predicted through Spark, turned into square polygons, dissolved per
    label and written to a local part file. Strips for which a part file already exists are
    skipped. Afterwards all part files are merged, dissolved per label once more and moved to
    path_to_output.

    The normalization scalers have to be made beforehand: the scalers for band 1 to 5 are looked
    up in path_to_scalers by the name of the raster file, the height scaler is chosen on the
    first four characters of the raster file name, which are read as a year.

    Relies on the names spark, timer and the raster/geo libraries being available in the notebook.

    @param a_path_to_tif_file: Path to the raster file; it is expected to hold six bands.
    @param path_to_output: Prefix of the location where the final shapefile is moved to.
    @param path_to_scalers: Prefix of the location in which the band scalers are searched.
    @param a_parts: The number of parts to divide the dataset into to reduce memory issues, at least 1.
    @param a_model_path: Path to a pickled model; its file name becomes part of the output name.
    @param aggregate_to_2m: When true the pixels are drawn as 2 by 2 squares on an even grid,
                            otherwise as 0.5 by 0.5 squares.
    @raise ValueError: when a_parts is smaller than 1.
    @raise FileNotFoundError: when no part file is available to merge.
    """
    # Set up parameters
    path_to_tif_file = a_path_to_tif_file
    parts = a_parts
    model_path = a_model_path

    if parts < 1:
        raise ValueError("a_parts has to be at least 1, got "+str(parts))

    # NOTE(review): pickle executes arbitrary code while loading; only load trusted model files.
    # NOTE(review): predict_pandas_udf reads a name loaded_model from outside this function, not
    # this local variable - verify that the model used for predicting is the one loaded here.
    with open(model_path, 'rb') as model_file:
        loaded_model = pickle.load(model_file)

    start_run = timer()

    tif_file_name = path_to_tif_file.split("/")[-1]

    # Bug still in databricks which does not let us directly write to azure blob
    output_location_local = "/home/"+tif_file_name.split(".")[0]+".shp"
    part_pattern = output_location_local.replace(".","_part_*.")

    use_kernels = False

    # Read everything that is needed from the raster and close the file again.
    with rasterio.open(path_to_tif_file) as dataset:
        meta = dataset.meta.copy()
        data = dataset.read()
        raster_transform = dataset.transform
    width, height = meta["width"], meta["height"]

    # Determine which version of the AHN can be used
    if int(tif_file_name[0:4]) <= 2019:
        ahn_type = "/dbfs/mnt/satellite-images-nso/SV_50cm/coepelduynen/scalers//ahn3.save"
    else:
        ahn_type = "/dbfs/mnt/satellite-images-nso/SV_50cm/coepelduynen/scalers//ahn4.save"

    # Load scalers for RGBI NDVI and Height, these scalers have to be premade.
    # We have to normalize values because of the differences between satellite images.
    band_scaler_files = [
        glob.glob(path_to_scalers+tif_file_name+"*band"+str(band_number)+"*")[0].replace("\\","/")
        for band_number in range(1, 6)
    ]
    a_normalize_scaler_class_all = scaler_class_all(scaler_file_band1 = band_scaler_files[0],
                                                    scaler_file_band2 = band_scaler_files[1],
                                                    scaler_file_band3 = band_scaler_files[2],
                                                    scaler_file_band4 = band_scaler_files[3],
                                                    scaler_file_band5 = band_scaler_files[4],
                                                    scaler_file_band6 = ahn_type)

    # Engage a loop which divides the pixels in parts. The strips are half-open row ranges that
    # tile the raster exactly, so no row is skipped between two parts or at the bottom edge.
    for part, (begin_height, end_height) in enumerate(_split_rows(height, parts), start=1):
        part_location = output_location_local.replace(".","_part_"+str(part)+".")

        # Check if a previous run already finished this part.
        if os.path.isfile(part_location):
            print("Part file already exists")
            continue

        rows_in_part = end_height - begin_height
        print("Total permutations this step: "+str(rows_in_part * width))

        start = timer()
        if use_kernels:
            permutations = itertools.product(range(begin_height, end_height), range(0, width))
            seg_df = [get_kernel_for_x_y(permutation[0],permutation[1],data) for permutation in permutations]
            seg_df = pd.DataFrame(seg_df, columns= ["r","g","b","i","ndvi","height","x","y"])
        else:
            seg_df = [band[begin_height:end_height].flatten() for band in data]
            seg_df = pd.DataFrame(np.array(seg_df).T, columns= ["r","g","b","i","ndvi","height"])
            # Row and column index of every flattened pixel, in the same row-major order as flatten().
            seg_df['x'] = np.repeat(np.arange(begin_height, end_height), width)
            seg_df['y'] = np.tile(np.arange(width), rows_in_part)
        print("Done with extracting dataframe in "+str(timer()-start)+" second(s)")

        # Filter out empty pixels.
        start = timer()
        seg_df = seg_df[(seg_df["r"] != 0) & (seg_df["g"] != 0) & (seg_df["b"] != 0) &  (seg_df["i"] != 0)].reset_index(drop=True)
        if seg_df.empty:
            # A Spark dataframe can not be made out of zero rows, and there is nothing to predict.
            print("No non-empty pixels in part "+str(part)+", skipping")
            continue
        # 'x' holds the row index and 'y' the column index, which is the (rows, cols) order xy() expects.
        seg_df['rd_x'],seg_df['rd_y'] = rasterio.transform.xy(raster_transform, seg_df['x'], seg_df['y'])
        print("Filtering done in "+str(timer()-start)+" second(s)")
        print("Filtered length of dataframe: "+str(len(seg_df.index)))

        # Scale/Normalize the pixels of this part
        start = timer()
        a_normalize_scaler_class_all.transform(seg_df, col_names=["r","g","b","i","ndvi","height"])
        print("Normalization done in "+str(timer()-start)+" second(s)")

        # Create a Spark DataFrame and start predicting
        start = timer()
        sdf = spark.createDataFrame(seg_df.values.tolist(), ["r","g","b","i","ndvi","height","x","y","rd_x","rd_y"])
        list_of_columns = ['r', 'g', 'b', 'i', 'ndvi', 'height']
        print("Finished making spark dataframe in "+str(timer()-start)+" second(s)")

        start = timer()
        # NOTE(review): withColumn only adds the prediction to the Spark plan; the work itself
        # presumably happens in toPandas() below, so this timing says little - confirm.
        sdf = sdf.withColumn('label', predict_pandas_udf(*list_of_columns))

        print("Predicting finished in: "+str(timer()-start)+" second(s)")

        if aggregate_to_2m:
            # Aggregate to 2m for data reduction.
            sdf = sdf.withColumn("group_x",f.round(f.col("x")/2)*2)
            sdf = sdf.withColumn("group_y",f.round(f.col("y")/2)*2)
            cols = ['label']
            agg_expr = [mode(f.collect_list(col)).alias(col) for col in cols]
            # NOTE(review): the grouped result is not assigned, so the labels are NOT reduced to
            # their mode per group; every pixel keeps its own label. Assigning it would drop
            # rd_x/rd_y which the geometry step needs - decide on the intended grouping first.
            sdf.groupBy(['group_x','group_y']).agg(*agg_expr)

            seg_df = sdf.toPandas()
            print("Grouping to 2m finished in: "+str(timer()-start)+" second(s)")

            seg_df['geometry'] = [func_cor_square(permutation) for permutation in seg_df[["rd_x","rd_y"] ].to_numpy().tolist()]
        else:
            seg_df = sdf.toPandas()
            seg_df.columns = ['r', 'g', 'b', 'i', 'ndvi', 'height', 'x', 'y', 'rd_x', 'rd_y','label']
            seg_df = seg_df[["rd_x","rd_y","label"]]

            start = timer()
            seg_df['geometry'] = [func_cor_square_50cm(permutation) for permutation in seg_df[["rd_x","rd_y"] ].to_numpy().tolist()]
            print("Grouping labels finished in: "+str(timer()-start)+" second(s)")

        start = timer()
        seg_df= seg_df[["geometry","label"]]

        seg_df = gpd.GeoDataFrame(seg_df, geometry=seg_df.geometry)
        seg_df = seg_df.set_crs(epsg = 28992)

        dissolve_gpd_output(seg_df, part_location)
        print("Dissolving done in: "+str(timer()-start)+" second(s)")
        print(part_location)

    # Merge the different parts into one.
    start = timer()
    part_frames = []
    for file in glob.glob(part_pattern):
        print(file)
        if part_frames:
            print("Append")
        part_frames.append(gpd.read_file(file))

    if not part_frames:
        raise FileNotFoundError("No part files found to merge for pattern: "+part_pattern)

    # Concatenate once; DataFrame.append was removed in pandas 2.0.
    if len(part_frames) == 1:
        all_part = part_frames[0]
    else:
        all_part = pd.concat(part_frames, ignore_index=True)

    all_part.dissolve(by='label').to_file(output_location_local)

    print("Done with merging files in: "+str(timer()-start)+" second(s)")
    # Dropping the extension from the pattern also removes the sidecar files of every part.
    for file in glob.glob(part_pattern.split(".")[0]):
        os.remove(file)

    # Move the file from databricks vm to azure blob storage.
    output_location = path_to_output+tif_file_name.split(".")[0]+".shp"
    output_location = output_location.replace(".shp","_"+model_path.split("/")[-1]+".shp")
    print("Writing to:"+output_location)
    shutil.move(output_location_local,output_location)
    for extension in (".cpg", ".dbf", ".shx"):
        shutil.move(output_location_local.replace(".shp",extension),output_location.replace(".shp",extension))

    # The projection file travels along when it was written, so the output keeps its CRS.
    local_projection_file = output_location_local.replace(".shp",".prj")
    if os.path.isfile(local_projection_file):
        shutil.move(local_projection_file,output_location.replace(".shp",".prj"))

    print("Done with whole run in: "+str(timer()-start_run)+" second(s)")