In [None]:
#from tqdm import tqdm_notebook
from tqdm.notebook import tqdm, tnrange
import time

In [None]:
%matplotlib inline
from osgeo import gdal, ogr, osr, gdalconst
import os, sys
import glob
import simplekml
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns
import pandas as pd

sys.path.insert(0, os.path.dirname(os.getcwd()))
sys.path.insert(0, os.getcwd())


from utils.plots import plot_heatmap
import utils.gdal_processing as gp
from utils.read_geoTiff import readHR
from utils.data_reader import interpPatches


def p2ha(x):
    """Area in hectares of a square ``x`` pixels on a side.

    Assumes 10 m pixels (side = 10*x metres; 1 ha = 100 m x 100 m) — TODO confirm.
    """
    side_m = x * 10
    return side_m ** 2 / 100 ** 2



def no_output(func):
    """Decorator that runs ``func`` with ``sys.stdout`` redirected to the null device.

    The wrapped function's return value is passed through. ``sys.stdout`` is
    always restored and the null-device handle closed, even if ``func`` raises.
    The previous wrapper dropped the return value, leaked the handle and left
    stdout silenced after an exception.
    """
    import contextlib
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        with open(os.devnull, "w") as devnull, contextlib.redirect_stdout(devnull):
            return func(*args, **kwargs)
    return wrapper



In [None]:
def zonal_stats_old(FID, input_zone_polygon, input_value_raster, fn, is_return_numpoints = False, refband=1):
    """Legacy zonal statistic: apply ``fn`` to the raster values inside feature ``FID``.

    Parameters
    ----------
    FID : int
        Feature id in the first layer of ``input_zone_polygon``.
    input_zone_polygon : str
        Path to an OGR-readable vector file.
    input_value_raster : str
        Path to a GDAL-readable raster holding the values.
    fn : callable
        Reduction applied to a 2-D float array whose values are clipped to
        [0.01, 1e9] and set to NaN outside the feature (use NaN-aware
        reductions such as ``np.nansum``).
    is_return_numpoints : bool
        If True, return ``geom.GetPointCount()`` instead of a statistic.
        NOTE(review): OGR only counts points of point/line geometries, so this
        is 0 for (multi)polygons — confirm this is intended.
    refband : int
        1-based band of ``input_value_raster`` to read.

    Returns
    -------
    ``fn(values)``, or ``np.nan`` in these cases:
    - the feature's bounding box starts left of / above the raster;
    - the band cannot be read (e.g. the window passes the right/bottom edge);
    - the feature covers no pixel;
    - ``fn`` raises ``ValueError``.

    Raises
    ------
    ValueError
        If the geometry is not a POLYGON, MULTIPOLYGON or LINESTRING.
    """
    # Open data
    raster = gdal.Open(input_value_raster)
    shp = ogr.Open(input_zone_polygon)
    lyr = shp.GetLayer()

    # Raster georeference; rotation terms are ignored (pixelHeight is negative for north-up rasters)
    xOrigin, pixelWidth, _, yOrigin, _, pixelHeight = raster.GetGeoTransform()

    # Reproject the feature geometry into the raster's coordinate system
    targetSR = osr.SpatialReference()
    targetSR.ImportFromWkt(raster.GetProjectionRef())
    coordTrans = osr.CoordinateTransformation(lyr.GetSpatialRef(), targetSR)
    feat = lyr.GetFeature(FID)
    geom = feat.GetGeometryRef()
    geom.Transform(coordTrans)

    if geom.GetGeometryName() not in ('MULTIPOLYGON', 'POLYGON', 'LINESTRING'):
        # Raise instead of sys.exit so callers/notebooks get a normal exception
        raise ValueError("ERROR: Geometry needs to be either Polygon or Multipolygon")
    # (minX, maxX, minY, maxY): same extent as looping over the exterior-ring vertices
    xmin, xmax, ymin, ymax = geom.GetEnvelope()

    # Pixel window to read; y uses pixelHeight (previously pixelWidth, wrong for non-square pixels)
    xoff = int((xmin - xOrigin) / pixelWidth)
    yoff = int((ymax - yOrigin) / pixelHeight)
    if xoff < 0 or yoff < 0:
        return np.nan
    xcount = int((xmax - xmin) / pixelWidth) + 1
    ycount = int((ymax - ymin) / abs(pixelHeight)) + 1

    if is_return_numpoints:
        # TODO check that all the points are inside the region of interest
        return geom.GetPointCount()

    # In-memory mask aligned with the pixel window read below (it used to be
    # anchored at (xmin, ymax), up to one pixel off the data window)
    target_ds = gdal.GetDriverByName('MEM').Create('', xcount, ycount, 1, gdal.GDT_Byte)
    target_ds.SetGeoTransform((
        xOrigin + xoff * pixelWidth, pixelWidth, 0,
        yOrigin + yoff * pixelHeight, 0, pixelHeight,
    ))
    target_ds.SetProjection(targetSR.ExportToWkt())

    # Burn only this feature. Rasterizing the whole layer also marked pixels of
    # other features inside this bounding box. The geometry is already in raster
    # coordinates, so the temporary layer has no SRS and no reprojection happens.
    mem_ds = ogr.GetDriverByName('Memory').CreateDataSource('')
    mem_lyr = mem_ds.CreateLayer('zone', geom_type=geom.GetGeometryType())
    mem_feat = ogr.Feature(mem_lyr.GetLayerDefn())
    mem_feat.SetGeometry(geom)
    mem_lyr.CreateFeature(mem_feat)
    gdal.RasterizeLayer(target_ds, [1], mem_lyr, burn_values=[1])

    # Read raster as arrays. np.float was removed in NumPy 1.24; there it raised
    # AttributeError, which this handler turned into a NaN for every feature.
    try:
        dataraster = raster.GetRasterBand(refband).ReadAsArray(xoff, yoff, xcount, ycount).astype(np.float64)
    except AttributeError:
        # missing band, or ReadAsArray returned None (window outside the raster)
        return np.nan
    datamask = target_ds.GetRasterBand(1).ReadAsArray(0, 0, xcount, ycount).astype(np.float64)

    dataraster = np.clip(dataraster, 0.01, 1e9)
    if not np.any(datamask):
        print('datamask empty')
        return np.nan
    # Mask out pixels outside the zone
    dataraster[np.logical_not(datamask)] = np.nan

    try:
        return fn(dataraster)
    except ValueError:
        print('fix')
        return np.nan
    
def loop_zonal_stats_update_old(input_zone_polygon, input_value_raster, fieldname, fn, is_update=True, refband=1, is_pos_only=False):
    """Legacy loop: compute a zonal value per feature and optionally store it.

    The layer's 'Name' field controls what is computed for each feature:
    - If the name contains 'pos', or ``is_pos_only`` is False, the value is
      ``zonal_stats_old(..., fn)`` on the raster.
    - Otherwise it is ``zonal_stats_old(..., is_return_numpoints=True)``.
    This previously called ``zonal_stats``, which ignores
    ``is_return_numpoints``, so the "Ref points" branch returned ``fn`` results.

    Parameters
    ----------
    input_zone_polygon : str
        Vector file, opened in update mode; must have a 'Name' field.
    input_value_raster : str
        Raster with the values.
    fieldname : str
        Real-valued field that receives the per-feature value (created if missing).
    fn : callable
        Reduction passed to ``zonal_stats_old``.
    is_update : bool
        If False, nothing is written to the vector file.
    refband : int
        1-based raster band.
    is_pos_only : bool
        See the 'Name' rule above.

    Returns
    -------
    ``np.sum`` of the per-feature values (NaN if any feature yields NaN).
    """
    shp = ogr.Open(input_zone_polygon, update=1)
    lyr = shp.GetLayer()
    lyrdf = lyr.GetLayerDefn()

    if is_update:
        id_ = lyrdf.GetFieldIndex(fieldname)
        if id_ == -1:
            lyr.CreateField(ogr.FieldDefn(fieldname, ogr.OFTReal))
            id_ = lyrdf.GetFieldIndex(fieldname)
        else:
            print('Field {} already exists, may overwrite'.format(fieldname))

    outVals = []
    id_Name = lyrdf.GetFieldIndex('Name')
    for FID in range(lyr.GetFeatureCount()):
        feat = lyr.GetFeature(FID)
        if feat is None:
            continue
        # A NULL name used to raise TypeError in the 'pos' test below
        name_ = feat.GetField(id_Name) or ''
        if 'pos' in name_ or not is_pos_only:
            meanValue = zonal_stats_old(FID, input_zone_polygon, input_value_raster, fn, refband=refband)
            print(f' {meanValue:.2f} Trees in {name_}')
        else:
            meanValue = zonal_stats_old(FID, input_zone_polygon, input_value_raster, fn, is_return_numpoints=True, refband=refband)
            print(f' {meanValue:.2f} Ref points in {name_}')
        outVals.append(meanValue)
        if np.isnan(meanValue):
            print(meanValue, FID)
        if is_update:
            feat.SetField(id_, meanValue)
            lyr.SetFeature(feat)
    return np.sum(outVals)


In [None]:
def loop_zonal_stats_update(input_zone_polygon, input_value_raster, fieldname, fn, is_update=True, refband=1, is_pos_only=False,bias=1, field_name = 'Name'):
    """Compute ``zonal_stats`` for every feature and optionally store it in ``fieldname``.

    Parameters
    ----------
    input_zone_polygon : str
        Vector file with the zones, opened in update mode.
    input_value_raster : str
        Raster with the values.
    fieldname : str
        Real-valued field that receives the per-feature value (created if missing).
    fn : callable
        Reduction passed to ``zonal_stats``.
    is_update : bool
        If False, the vector file is not modified (this flag used to be ignored).
    refband : int
        1-based raster band.
    is_pos_only : bool
        Accepted for compatibility with the legacy loop; unused here.
    bias : float
        Multiplicative factor forwarded to ``zonal_stats``.
    field_name : str
        Accepted for compatibility; unused. Its lookup only fed an unused
        variable and raised KeyError when the field did not exist.

    Returns
    -------
    ``np.sum`` of the per-feature values (NaN if any feature yields NaN).
    """
    shp = ogr.Open(input_zone_polygon, update=1)
    lyr = shp.GetLayer()
    lyrdf = lyr.GetLayerDefn()

    if is_update:
        id_ = lyrdf.GetFieldIndex(fieldname)
        if id_ == -1:
            lyr.CreateField(ogr.FieldDefn(fieldname, ogr.OFTReal))
            id_ = lyrdf.GetFieldIndex(fieldname)
        else:
            print('Field {} already exists, may overwrite'.format(fieldname))

    outVals = []
    for FID in range(lyr.GetFeatureCount()):
        feat = lyr.GetFeature(FID)
        if feat is None:
            continue
        meanValue = zonal_stats(FID, input_zone_polygon, input_value_raster, fn, refband=refband, bias=bias)
        outVals.append(meanValue)
        if is_update:
            feat.SetField(id_, meanValue)
            lyr.SetFeature(feat)
    return np.sum(outVals)

def zonal_stats(FID, input_zone_polygon, input_value_raster, fn, is_return_numpoints = False, refband=1, bias = 1.0):
    """Apply ``fn`` to the raster values that fall inside vector feature ``FID``.

    Parameters
    ----------
    FID : int
        Feature id in the first layer of ``input_zone_polygon``.
    input_zone_polygon : str
        Path to an OGR-readable vector file.
    input_value_raster : str
        Path to a GDAL-readable raster holding the values.
    fn : callable
        Reduction applied to a 2-D float array in which pixels that are outside
        the feature, below 0.01 or equal to 99 are NaN, scaled by ``bias``
        (use NaN-aware reductions such as ``np.nansum``).
    is_return_numpoints : bool
        Accepted for signature compatibility with ``zonal_stats_old``; unused.
    refband : int
        1-based band of ``input_value_raster`` to read.
    bias : float
        Multiplicative factor applied to the values before ``fn``.

    Returns
    -------
    ``fn(values)``, or ``np.nan`` in these cases:
    - the feature does not overlap the raster;
    - the band cannot be read;
    - no pixel is covered;
    - ``fn`` raises ``ValueError``.
    Zones that extend past the raster edge are clipped to it.

    Raises
    ------
    ValueError
        If the geometry is neither a POLYGON nor a MULTIPOLYGON.
    """
    # Open data
    raster = gdal.Open(input_value_raster)
    shp = ogr.Open(input_zone_polygon)
    lyr = shp.GetLayer()

    # Raster georeference; rotation terms are ignored (pixelHeight is negative for north-up rasters)
    xOrigin, pixelWidth, _, yOrigin, _, pixelHeight = raster.GetGeoTransform()

    # Reproject the feature geometry into the raster's coordinate system
    targetSR = osr.SpatialReference()
    targetSR.ImportFromWkt(raster.GetProjectionRef())
    coordTrans = osr.CoordinateTransformation(lyr.GetSpatialRef(), targetSR)
    feat = lyr.GetFeature(FID)
    geom = feat.GetGeometryRef()
    geom.Transform(coordTrans)

    if geom.GetGeometryName() not in ('MULTIPOLYGON', 'POLYGON'):
        # Raise instead of sys.exit so callers/notebooks get a normal exception
        raise ValueError("ERROR: Geometry needs to be a Polygon")
    # (minX, maxX, minY, maxY): same extent as looping over the exterior-ring vertices
    xmin, xmax, ymin, ymax = geom.GetEnvelope()

    # Half-open pixel window [xoff, xend) x [yoff, yend) covering the envelope
    xoff = int(np.floor((xmin - xOrigin) / pixelWidth))
    yoff = int(np.floor((ymax - yOrigin) / pixelHeight))
    xend = int(np.floor((xmax - xOrigin) / pixelWidth)) + 1
    yend = int(np.floor((ymin - yOrigin) / pixelHeight)) + 1

    # Clip the window to the raster. The previous clamp had three problems:
    # it forced xoff >= 1, stopped one pixel short of the right/bottom edge,
    # and never clipped yoff.
    xoff, yoff = max(xoff, 0), max(yoff, 0)
    xcount = min(xend, raster.RasterXSize) - xoff
    ycount = min(yend, raster.RasterYSize) - yoff
    if xcount <= 0 or ycount <= 0:
        print('zone outside raster', FID)
        return np.nan

    # In-memory mask aligned with the pixel window read below (it used to be
    # anchored at (xmin, ymax), up to one pixel off the data window)
    target_ds = gdal.GetDriverByName('MEM').Create('', xcount, ycount, 1, gdal.GDT_Byte)
    target_ds.SetGeoTransform((
        xOrigin + xoff * pixelWidth, pixelWidth, 0,
        yOrigin + yoff * pixelHeight, 0, pixelHeight,
    ))
    target_ds.SetProjection(targetSR.ExportToWkt())

    # Burn only this feature. Rasterizing the whole layer also marked pixels of
    # other features inside this bounding box. The geometry is already in raster
    # coordinates, so the temporary layer has no SRS and no reprojection happens.
    mem_ds = ogr.GetDriverByName('Memory').CreateDataSource('')
    mem_lyr = mem_ds.CreateLayer('zone', geom_type=geom.GetGeometryType())
    mem_feat = ogr.Feature(mem_lyr.GetLayerDefn())
    mem_feat.SetGeometry(geom)
    mem_lyr.CreateFeature(mem_feat)
    gdal.RasterizeLayer(target_ds, [1], mem_lyr, burn_values=[1])

    # Read raster as arrays. np.float was removed in NumPy 1.24; there it raised
    # AttributeError, which this handler turned into a NaN for every feature.
    try:
        dataraster = raster.GetRasterBand(refband).ReadAsArray(xoff, yoff, xcount, ycount).astype(np.float64)
    except AttributeError:
        print('dataraster wrong')
        print(xoff, yoff, xcount, ycount)
        print(raster.RasterXSize, raster.RasterYSize, 'xmax,ymax:', xoff + xcount, yoff + ycount)
        return np.nan
    datamask = target_ds.GetRasterBand(1).ReadAsArray(0, 0, xcount, ycount).astype(bool)

    # Values below 0.01 and the value 99 are treated as no-data
    dataraster[dataraster < 0.01] = np.nan
    dataraster[dataraster == 99] = np.nan

    if not datamask.any():
        print('datamask empty')
        return np.nan
    # Mask out pixels outside the zone
    dataraster[~datamask] = np.nan
    dataraster *= bias

    try:
        return fn(dataraster)
    except ValueError:
        print('fix')
        return np.nan
    

In [None]:
# Object class -> index (presumably the prediction band used for that class; TODO confirm)
object_dict = dict(palm=0, coco=1)


In [None]:
# ROI per tile for the automatic-GT path (presumably 'lon_min,lat_min,lon_max,lat_max'; confirm)
data_config = dict(T47NQA='101.45,0.53,101.62,0.55')


In [None]:
def get_median_sentinel2(path, ref_tile, roi_lon_lat=None, resolution=60):
    """Cloud-masked per-pixel median of the Sentinel-2 TCI images of one tile.

    Parameters
    ----------
    path : str
        Folder with the L2A products; every entry matching ``*{ref_tile}*`` is used.
    ref_tile : str
        Tile id substring, e.g. 'T47NQA'.
    roi_lon_lat : optional
        Region of interest forwarded to ``readHR``.
    resolution : int
        Resolution (m) of the TCI image read from ``IMG_DATA/R{resolution}m``.
        The SCL and CLD layers are read at 20 m when ``resolution == 10``,
        otherwise at ``resolution``.

    Returns
    -------
    numpy.ma.MaskedArray
        Median over dates divided by 255. A pixel is excluded from a date when
        any of these hold:
        - all TCI channels are 0;
        - its SCL value is 3, 6 or 11;
        - its CLD value is > 5.

    Raises
    ------
    ValueError
        If no TCI image is found for ``ref_tile`` under ``path``. Previously this
        surfaced as an opaque ``np.stack`` error.
    """
    products = glob.glob(path + '/*{}*'.format(ref_tile))
    tci_files = [glob.glob(f"{x}/GRANULE/*/IMG_DATA/R{resolution}m/*_TCI_{resolution}m.jp2") for x in products]
    # Mask layers are read at 20 m when the TCI is read at 10 m
    resolution_ = 20 if resolution == 10 else resolution

    medians = []
    masks = []
    for count, id_ in enumerate(tci_files):
        if not id_:
            continue

        array = readHR(roi_lon_lat, data_file=id_[0], scale=10, as_float=False, is_verbose=False, is_assert_blank=False)
        # Pixels that are 0 in every channel carry no data
        mask_ = (array == 0).all(axis=-1)
        medians.append(array)

        try:
            id_scl = id_[0].replace(f'TCI_{resolution}m', f'SCL_{resolution_}m')
            id_scl = id_scl.replace(f'/R{resolution}m/', f'/R{resolution_}m/')
            arrayscl = readHR(roi_lon_lat, data_file=id_scl, scale=resolution_, as_float=False, is_assert_blank=False, is_verbose=False)
            arrayscl = interpPatches(arrayscl, array.shape[0:2], squeeze=True, mode='edge').squeeze()
            # Sen2Cor SCL legend: 3 cloud shadow, 11 snow/ice, 6 water
            for scl_class in (3, 11, 6):
                mask_ = np.logical_or(mask_, arrayscl == scl_class)

            id_cld = glob.glob(id_[0].split('/IMG_DATA/')[0] + f'/QI_DATA/*CLD*{resolution_}m.jp2')[0]
            array_cld = readHR(roi_lon_lat, data_file=id_cld, scale=resolution_, as_float=False, is_verbose=False, is_assert_blank=False)
            array_cld = interpPatches(array_cld, array.shape[0:2], squeeze=True, mode='edge').squeeze()
            mask_ = np.logical_or(mask_, array_cld > 5)
        except IndexError:
            # e.g. no CLD file found: keep whatever mask was built so far
            print('error in cld, or SCL', count)

        masks.append(np.repeat(mask_[..., np.newaxis], 3, axis=-1))

    if not medians:
        raise ValueError(f'no TCI_{resolution}m images found for tile {ref_tile} in {path}')

    stack = np.ma.masked_array(np.stack(medians), np.stack(masks))
    return np.ma.median(stack, axis=0) / 255.0
    

#### Plot predictions

In [None]:
def p2ha(x):
    """Hectares covered by an ``x``-by-``x`` pixel block.

    10 pixels per 100 m side, i.e. 10 m pixels — TODO confirm.
    """
    side_hm = x / 10
    return side_hm ** 2


# @no_output
def get_rasters(folder_inference, tile, folder_annotations, group='group1', preds_axis=0, sq_kernel=2, scale=10, clip_min=0.2):
    """Load rasterised annotations and the predicted density for one tile/group.

    NOTE: a later cell redefines ``get_rasters(ref_folder)`` with a different
    signature; once that cell runs, this version is shadowed.

    Parameters
    ----------
    folder_inference : str
        Folder with exactly one ``{tile}*_preds_reg*.tif``.
    tile, group : str
        Tile id and annotation group; annotations are read from
        ``{folder_annotations}/{tile}/{group}``.
    folder_annotations : str
        Root folder of the manual annotations.
    preds_axis : int
        Channel used when the prediction raster has two channels.
    sq_kernel : int
        Forwarded to ``gp.rasterize_points_pos_neg_folder``.
    scale, clip_min :
        Unused; kept for signature parity with ``plot_counts``.

    Returns
    -------
    dict or None
        ``{'gt', 'preds', 'roi_'}``, with -1 cells of 'gt' set to NaN. Returns
        None when the annotation folder or the prediction raster is missing.
    """
    ref_folder = f'{folder_annotations}/{tile}/{group}'
    if not os.path.isdir(ref_folder):
        print(ref_folder, 'does not exist')
        return None

    ref_rasters = glob.glob(f'{folder_inference}/{tile}*_preds_reg*.tif')
    if not ref_rasters:
        print(f' no files found in {folder_inference}/{tile}*_preds_reg*.tif skipping...')
        return None
    assert len(ref_rasters) == 1, len(ref_rasters)
    ref_raster = ref_rasters[0]

    ds = gdal.Open(ref_raster)
    roi_ = gp.get_positive_area_folder(ref_folder)
    lims = gp.to_xy_box(roi_, ds, enlarge=10)

    gt = gp.rasterize_points_pos_neg_folder(folder=ref_folder, refDataset=ref_raster, lims=lims, lims_with_labels=lims, sq_kernel=sq_kernel)
    gt[gt == -1] = np.nan

    preds = readHR(data_file=ref_raster, roi_lon_lat=roi_, scale=10, as_float=False, is_assert_blank=False)
    if preds.shape[-1] == 2:
        preds = preds[..., preds_axis]

    return {'gt': gt,
            'preds': preds,
            'roi_': roi_}
    
    
def plot_preds(scale,raster,preds, tile, group,density = (0,3)):
    """Plot GT density, predicted density and their block-wise count difference.

    Parameters
    ----------
    scale : int
        Block size in pixels used to aggregate counts with ``gp.block_reduce``.
    raster, preds : 2-D numpy.ndarray
        Ground-truth and predicted per-pixel densities on the same grid; NaN
        marks pixels without data.
    tile, group : str
        Accepted for interface compatibility; not used in the figure.
    density : (float, float)
        Colour limits of the two density panels.

    Returns
    -------
    (pandas.DataFrame, matplotlib.figure.Figure)
        One row per block with the block sums 'GT' and 'Pred' (pixels that are
        NaN in either input are excluded), and the figure.
    """
    dens_min, dens_max = density
    # Compare only pixels that are valid in both inputs
    mask_out = np.isnan(raster) | np.isnan(preds)
    preds1 = preds.copy()
    preds1[mask_out] = np.nan
    raster1 = raster.copy()
    raster1[mask_out] = np.nan

    r1 = gp.block_reduce(raster1, (scale, scale), np.nansum)
    p1 = gp.block_reduce(preds1, (scale, scale), np.nansum)
    diff = p1 - r1

    fig = plt.figure(figsize=(10, 5))
    gs = gridspec.GridSpec(nrows=1, ncols=3, left=0.02, bottom=0.06, right=0.95, top=0.94, wspace=0.1, figure=fig)

    def _density_panel(cell, image, title):
        # Shared layout of the GT and prediction panels. Extents in km assume
        # 100 pixels per km (10 m pixels) — TODO confirm for other inputs.
        ax = fig.add_subplot(cell)
        im = ax.imshow(image, vmin=dens_min, vmax=dens_max)
        cbar = fig.colorbar(im, ax=ax, orientation='horizontal')
        cbar.ax.set_xlabel(f'Trees/ {p2ha(1):.2f}ha blocks')
        ax.set_title(title)
        ax.set_xticks([])
        ax.set_yticks([])
        # Columns (shape[1]) run along x and rows (shape[0]) along y; these were swapped
        ax.set_xlabel(f' {preds1.shape[1]/100:.2f}km')
        ax.set_ylabel(f' {preds1.shape[0]/100:.2f}km')

    _density_panel(gs[0], raster, f'Trees GT {np.nansum(raster1):.2f}')
    _density_panel(gs[1], preds, f'Trees Pred {np.nansum(preds1):.2f}')

    # Block-wise count difference (Pred - GT)
    ax = fig.add_subplot(gs[2])
    lim_ = dens_max * ((scale / 2) ** 2)
    im = ax.imshow(diff, cmap='bwr', vmin=-lim_, vmax=lim_)
    cbar = fig.colorbar(im, ax=ax, orientation='horizontal')
    cbar.ax.set_xlabel(f'Trees/ {p2ha(scale)}ha blocks')
    ax.set_title(f'Error per {p2ha(scale)}ha blocks')
    ax.set_xticks([])
    ax.set_yticks([])

    ds = pd.DataFrame({'GT': r1.ravel(), 'Pred': p1.ravel()})
    return ds, fig




In [None]:
def p2ha(x):
    """Hectares covered by an ``x``-by-``x`` pixel block.

    10 pixels per 100 m side, i.e. 10 m pixels — TODO confirm.
    """
    side_hm = x / 10
    return side_hm ** 2

def plot_preds(scale,raster,preds, tile, group,density = (0,3)):
    """Plot GT density, predicted density and their block-wise count difference.

    Parameters
    ----------
    scale : int
        Block size in pixels used to aggregate counts with ``gp.block_reduce``.
    raster, preds : 2-D numpy.ndarray
        Ground-truth and predicted per-pixel densities on the same grid; NaN
        marks pixels without data.
    tile, group : str
        Accepted for interface compatibility; not used in the figure.
    density : (float, float)
        Colour limits of the two density panels.

    Returns
    -------
    (pandas.DataFrame, matplotlib.figure.Figure)
        One row per block with the block sums 'GT' and 'Pred' (pixels that are
        NaN in either input are excluded), and the figure.
    """
    dens_min, dens_max = density
    # Compare only pixels that are valid in both inputs
    mask_out = np.isnan(raster) | np.isnan(preds)
    preds1 = preds.copy()
    preds1[mask_out] = np.nan
    raster1 = raster.copy()
    raster1[mask_out] = np.nan

    r1 = gp.block_reduce(raster1, (scale, scale), np.nansum)
    p1 = gp.block_reduce(preds1, (scale, scale), np.nansum)
    diff = p1 - r1

    fig = plt.figure(figsize=(10, 5))
    gs = gridspec.GridSpec(nrows=1, ncols=3, left=0.02, bottom=0.06, right=0.95, top=0.94, wspace=0.1, figure=fig)

    def _density_panel(cell, image, title):
        # Shared layout of the GT and prediction panels. Extents in km assume
        # 100 pixels per km (10 m pixels) — TODO confirm for other inputs.
        ax = fig.add_subplot(cell)
        im = ax.imshow(image, vmin=dens_min, vmax=dens_max)
        cbar = fig.colorbar(im, ax=ax, orientation='horizontal')
        cbar.ax.set_xlabel(f'Trees/ {p2ha(1):.2f}ha blocks')
        ax.set_title(title)
        ax.set_xticks([])
        ax.set_yticks([])
        # Columns (shape[1]) run along x and rows (shape[0]) along y; these were swapped
        ax.set_xlabel(f' {preds1.shape[1]/100:.2f}km')
        ax.set_ylabel(f' {preds1.shape[0]/100:.2f}km')

    _density_panel(gs[0], raster, f'Trees GT {np.nansum(raster1):.2f}')
    _density_panel(gs[1], preds, f'Trees Pred {np.nansum(preds1):.2f}')

    # Block-wise count difference (Pred - GT)
    ax = fig.add_subplot(gs[2])
    lim_ = dens_max * ((scale / 2) ** 2)
    im = ax.imshow(diff, cmap='bwr', vmin=-lim_, vmax=lim_)
    cbar = fig.colorbar(im, ax=ax, orientation='horizontal')
    cbar.ax.set_xlabel(f'Trees/ {p2ha(scale)}ha blocks')
    ax.set_title(f'Error per {p2ha(scale)}ha blocks')
    ax.set_xticks([])
    ax.set_yticks([])

    ds = pd.DataFrame({'GT': r1.ravel(), 'Pred': p1.ravel()})
    return ds, fig


# @no_output
def plot_counts(folder_inference, tile, folder_annotations, group='group1', preds_axis=0, sq_kernel=2, scale=10, clip_min=0.2):
    """Compare predicted tree counts with the rasterised annotations of one tile.

    For each annotation folder matching ``{folder_annotations}/{tile}/{group}``:
    1. The points are rasterised on the grid of the tile's
       ``*_preds_reg*.tif`` prediction.
    2. The prediction is read over the same ROI.
    3. Both are passed to ``plot_preds``.

    Parameters
    ----------
    folder_inference : str
        Folder holding the ``{tile}*_preds_reg*.tif`` predictions.
    tile, group : str
        Tile id and annotation group (``group`` may be a glob pattern).
    folder_annotations : str
        Root folder of the manual annotations.
    preds_axis : int
        Channel used when the prediction raster has two channels.
    sq_kernel : int
        Forwarded to ``gp.rasterize_points_pos_neg_folder``.
    scale : int
        Block size in pixels forwarded to ``plot_preds``.
    clip_min : float
        Predictions below this value are set to 0.

    Returns
    -------
    (pandas.DataFrame, matplotlib.figure.Figure or None)
        The first element is the block-level 'GT'/'Pred' table, with the ROI
        centre in 'lon'/'lat', over all matched folders. The second is the
        figure of the last folder. If no annotation folder or no prediction
        raster is found, the table is empty and the figure is None. Every path
        now returns this pair; previously a missing folder returned a bare
        DataFrame, which broke callers that index ``[0]``/``[1]``.
    """
    ds_out = pd.DataFrame(columns=['GT','Pred','lon','lat'])
    fig = None

    ref_folders = glob.glob(f'{folder_annotations}/{tile}/{group}')
    if not ref_folders:
        print(f'no folders in {folder_annotations}/{tile}/{group}')
        return ds_out, fig

    ref_rasters = glob.glob(f'{folder_inference}/{tile}*_preds_reg*.tif')
    if not ref_rasters:
        print(f' no files found in {folder_inference}/{tile}*_preds_reg*.tif skipping...')
        return ds_out, fig
    ref_raster = ref_rasters[0]
    raster_ds = gdal.Open(ref_raster)

    tables = []
    for ref_folder in ref_folders:
        group = ref_folder.split('/')[-1]
        print(ref_folder, group)

        roi_ = gp.get_positive_area_folder(ref_folder)
        lims = gp.to_xy_box(roi_, raster_ds, enlarge=10)

        raster = gp.rasterize_points_pos_neg_folder(folder=ref_folder, refDataset=ref_raster, lims=lims, lims_with_labels=lims, sq_kernel=sq_kernel)
        raster[raster == -1] = np.nan

        preds = readHR(data_file=ref_raster, roi_lon_lat=roi_, scale=10, as_float=False, is_assert_blank=False)
        if preds.shape[-1] == 2:
            preds = preds[..., preds_axis]
        preds[preds < clip_min] = 0
        # 99 is treated as no-data
        preds[preds == 99] = np.nan

        table, fig = plot_preds(scale, raster, preds, tile, group)
        table['lon'] = (roi_[0] + roi_[2]) / 2
        table['lat'] = (roi_[1] + roi_[3]) / 2
        tables.append(table)

    # DataFrame.append was removed in pandas 2.0; concatenate once instead
    if tables:
        ds_out = pd.concat(tables)
    return ds_out, fig

In [None]:
# Output folder for exported figures (used by the commented-out savefig calls further down).
# NOTE(review): absolute, user-specific path — consider making it configurable.
save_path_figs = '/scratch2/Dropbox/Dropbox/Apps/Overleaf/activelearning_remotesensing/figures/'

## Palm4748a

In [None]:
import json

# Dataset definition; its 'val' entry lists the validation annotation folders used below.
filename = '/home/pf/pfstaff/projects/andresro/barry_palm/data/2A/datasets/palm4748a_base.json'
with open(filename, 'rb') as fp:
    out_dict = json.loads(fp.read())

In [None]:
# Peek at the last validation entry
out_dict['val'][-1:]

In [None]:
# Inference outputs to evaluate, and the block size (pixels) used to aggregate counts
folder_inference, scale = '/scratch/andresro/leon_work/sparse/inference/palm4748a_simpleA9_soft_ens5', 20


In [None]:
# ds_out is only built by the loop in the next cell. Display it only if it already
# exists, so Restart & Run All does not stop here with a NameError.
globals().get('ds_out')

In [None]:
#folder_inference = '/home/pf/pfstaff/projects/andresro/sparse/inference_leon/borneo_simpleA9_mc10'
ds_out_parts = []
fig_list = []
for x in out_dict['val']:
    print(x)
    # Entries are split as '/<tile>/<group>...' (presumably a leading '/'; confirm)
    tile = x.split('/')[1]
    group = x.split('/')[2]

    ds_ = plot_counts(folder_inference=folder_inference,
                      tile=tile,
                      folder_annotations='/home/pf/pfstaff/projects/andresro/barry_palm/data/labels/manual_annotations',
                      group=group, scale=scale)
    if ds_ is not None:
        ds_out_parts.append(ds_[0])
        fig_list.append(ds_[1])
    else:
        print('error in tile', tile)

# DataFrame.append was removed in pandas 2.0: collect the per-tile tables and
# concatenate once. Empty tables are skipped so the numeric dtypes are kept.
ds_out_parts = [part for part in ds_out_parts if len(part)]
ds_out = (pd.concat(ds_out_parts, ignore_index=True) if ds_out_parts
          else pd.DataFrame(columns=['GT', 'Pred', 'lon', 'lat']))

In [None]:
# Display the sixth validation figure. Assumes the loop above produced at least six
# entries; an entry is None when plot_counts found no prediction raster.
fig_list[5]

In [None]:
#fig_list[5].savefig(save_path_figs+'Density-validation-sumatra-5.pdf', bbox_inches='tight',dpi=300)
#fig_list[8].savefig(save_path_figs+'Density-validation-sumatra-8.pdf', bbox_inches='tight',dpi=300)

In [None]:
# Drop blocks where the GT is (near) empty and nothing was predicted
zeros_ = np.logical_and(ds_out.GT <= 0.1, ds_out.Pred == 0)
ds = ds_out[~zeros_].copy()

g = sns.jointplot(x='Pred', y='GT', data=ds, cmap="Reds")

# Same range on both axes so the 1:1 line is the diagonal. Include the
# predictions, otherwise Pred values above the largest GT fall outside the x-axis.
lims = (0, np.nanmax(ds[['GT', 'Pred']].to_numpy(dtype=float)))
g.ax_marg_x.set_xlim(lims)
g.ax_marg_y.set_ylim(lims)
g.ax_joint.plot(lims, lims, ':k')

block_ha = p2ha(scale)
mae = np.nanmean(np.abs(ds.Pred - ds.GT)) / block_ha  # trees per ha
total_gt = np.nansum(ds.GT)
total_pred = np.nansum(ds.Pred)
plt.title(f' MAE {mae:.2f} Trees/ha in {block_ha}ha Blocks  \n '
          f'total trees GT:{total_gt:.2f} Pred:{total_pred:.2f} ({100*(total_pred-total_gt)/total_gt:.2f}%) \n',
          x=-0.1, y=0.5, fontsize=12)


## Comparison to Descals 2019

In [None]:
# Root of the manual annotations. Entries of out_dict['val'] are appended to it directly
# (presumably they start with '/', e.g. '/<tile>/<group>' — confirm).
folder_annotations='/home/pf/pfstaff/projects/andresro/barry_palm/data/labels/manual_annotations'

In [None]:
# List all validation entries
out_dict['val']

In [None]:
def get_matching_id(pos_file, grid_file):
    """Return the 'ID' of the first grid cell that intersects the positive area.

    Parameters
    ----------
    pos_file : str
        Vector file whose first feature is the positive (annotated) area.
    grid_file : str
        Vector grid with an integer-like 'ID' field.

    Returns
    -------
    The 'ID' value of the first intersecting grid cell, or -1 if none intersects.

    Raises
    ------
    ValueError
        If ``pos_file`` contains no feature.
    """
    pos_vector = ogr.Open(pos_file)
    pos_layer = pos_vector.GetLayer()
    pos_feature = pos_layer.GetNextFeature()
    if pos_feature is None:
        raise ValueError(f'{pos_file} contains no features')
    positiveGeometry = pos_feature.GetGeometryRef().Clone()

    grid_vector = ogr.Open(grid_file)
    grid_layer = grid_vector.GetLayer()

    # Bring the positive area into the grid's CRS; Intersects does not reproject
    pos_srs = pos_layer.GetSpatialRef()
    grid_srs = grid_layer.GetSpatialRef()
    if pos_srs is not None and grid_srs is not None and not pos_srs.IsSame(grid_srs):
        positiveGeometry.Transform(osr.CoordinateTransformation(pos_srs, grid_srs))

    # Iterate the layer instead of GetFeature(i), which assumed FIDs 0..n-1
    for feature in grid_layer:
        if positiveGeometry.Intersects(feature.GetGeometryRef()):
            return feature.GetField('ID')
    return -1
    

In [None]:
# pos_file = '/home/pf/pfstaff/projects/andresro/barry_palm/data/labels/manual_annotations/T48MTD/palm_group2_Bischel/positiv_2.shp'


In [None]:
# @no_output
def get_rasters(ref_folder,
                grid_file='/home/pf/pfstaff/projects/andresro/barry_palm/palmoilmaps/descals2020/grid/grid_withOP.shp',
                descals_dir='/home/pf/pfstaff/projects/andresro/barry_palm/palmoilmaps/descals2020/oil_palm_map',
                s2_path='/home/pf/pfstaff/projects/andresro/barry_palm/data/2A/palmcountries_2017/'):
    """Collect GT, S2 median, the Descals map and our prediction for one annotation folder.

    NOTE: this redefines the earlier ``get_rasters(folder_inference, tile, ...)``
    with a different signature. The prediction folder is read from the global
    ``folder_inference``.

    Parameters
    ----------
    ref_folder : str
        Annotation folder ``.../<tile>/<group>``; must contain exactly one ``*pos*.shp``.
    grid_file : str
        Vector grid of the Descals map tiles, with an 'ID' field.
    descals_dir : str
        Folder with the ``L2_2019b_<ID>.tif`` map tiles.
    s2_path : str
        Folder with the Sentinel-2 products used for the RGB median.

    Returns
    -------
    dict
        Keys:
        - 'gt': rasterised annotations, -1 cells set to NaN;
        - 's2', 'desc', 'ours': rasters read over the ROI;
        - 'roi': the ROI.

    Raises
    ------
    FileNotFoundError
        If there is no prediction raster for the tile, or the Descals tile cannot be warped.
    ValueError
        If the positive area does not intersect any grid cell.
    """
    pos_shp = glob.glob(ref_folder + '/*pos*.shp')
    assert len(pos_shp) == 1
    pos_file = pos_shp[0]

    # Our predictions
    tile = ref_folder.split('/')[-2]
    pred_tifs = glob.glob(f'{folder_inference}/{tile}*preds_reg*.tif')
    if not pred_tifs:
        raise FileNotFoundError(f'no prediction raster matching {folder_inference}/{tile}*preds_reg*.tif')
    pred_tif = pred_tifs[0]

    ds = gdal.Open(pred_tif)
    ref_proj = ds.GetProjectionRef()
    roi_ = gp.get_positive_area_folder(ref_folder)

    # Descals map tile whose grid cell intersects the positive area
    id_ = get_matching_id(pos_file=pos_file, grid_file=grid_file)
    if id_ == -1:
        raise ValueError(f'{pos_file} does not intersect any cell of {grid_file}')
    pred_descals = f'{descals_dir}/L2_2019b_{id_}.tif'

    # Warp the Descals map (virtually, bilinear) to the projection of our prediction
    warp_opts = gdal.WarpOptions(
        format="VRT",
        dstSRS=ref_proj,
        resampleAlg=gdalconst.GRA_Bilinear,
        dstNodata='nan')
    ds_descals_warped = gdal.Warp('', pred_descals, options=warp_opts)
    if ds_descals_warped is None:
        raise FileNotFoundError(f'could not open/warp {pred_descals}')

    preds_descals = readHR(data_file=ds_descals_warped, roi_lon_lat=roi_, scale=10, as_float=False, is_assert_blank=False, is_exit=False)
    preds_ours = readHR(data_file=ds, roi_lon_lat=roi_, scale=10, as_float=False, is_assert_blank=False, is_exit=False)

    lims = gp.to_xy_box(roi_, ds, enlarge=10)
    raster_gt = gp.rasterize_points_pos_neg_folder(folder=ref_folder, refDataset=pred_tif, lims=lims, lims_with_labels=lims, sq_kernel=2)
    raster_gt[raster_gt == -1] = np.nan

    median_s2 = get_median_sentinel2(path=s2_path, ref_tile=tile, roi_lon_lat=roi_, resolution=10)

    return {'gt': raster_gt,
            's2': median_s2,
            'desc': preds_descals,
            'ours': preds_ours,
            'roi': roi_}


In [None]:
all_rasters = {}
for ref_folder in tqdm(out_dict['val']):
    try:
        all_rasters[ref_folder] = get_rasters(folder_annotations + ref_folder)
    except (Exception, SystemExit) as err:
        # Best effort over all validation folders. The former bare ``except:`` also
        # swallowed KeyboardInterrupt, so the loop could not be stopped. SystemExit
        # is still caught because readHR presumably may exit (it takes an is_exit flag).
        print(f'error in {ref_folder}')
        print(f'    {type(err).__name__}: {err}')

In [None]:
import matplotlib as mpl

In [None]:

def plot_rasters(rasters_dict, filename = None):
    """Draw a 2x3 comparison of GT, our predictions and the Descals map for one area.

    Top row: Sentinel-2 median composite, GT tree density, our predicted density.
    Bottom row: GT palm area and our palm area (density > ``threshold_area``),
    and the Descals classification.

    Args:
        rasters_dict: dict with keys 's2', 'gt', 'ours' and 'desc' (as built by
            ``get_rasters``). 'gt', 'ours' and 'desc' are expected to share one
            pixel grid; a mismatch is reported for the Descals panel.
        filename: if given, the figure is also saved there at 300 dpi.

    Returns:
        None. The figure stays open so the notebook backend can display it.
    """
    dens_min = 0.2
    dens_max = 2.0
    threshold_area = 0.5

    median_s2 = rasters_dict['s2']
    raster_gt = rasters_dict['gt']
    preds_ours = np.clip(rasters_dict['ours'],0.1,99)
    preds_descals = rasters_dict['desc']
    gt_labeled = ~np.isnan(raster_gt)

    def _finish_axes(ax, title, shape):
        # Title the panel, hide ticks and label each side with its extent in km
        # (100 px = 1 km, i.e. 10 m pixels). imshow draws columns (shape[1])
        # along x and rows (shape[0]) along y.
        ax.set_title(title)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xlabel(f' {shape[1]/100:.2f}km')
        ax.set_ylabel(f' {shape[0]/100:.2f}km')

    fig = plt.figure(figsize=(10,5))
    n_col = 3
    gs = gridspec.GridSpec(nrows=2,ncols=n_col,left=0.02, bottom=0.06, right=0.95, top=0.94, wspace=0.1,hspace=0.4, figure=fig)

    # Discrete map for the area panels: <0.9 -> white, [0.9, 2) -> green,
    # [2, 3) -> darkgreen; NaN (unlabeled) -> gray.
    cmap = mpl.colors.ListedColormap(['white', 'green', 'darkgreen'])
    cmap.set_bad(color='gray')
    norm = mpl.colors.BoundaryNorm([0,0.9,2,3], cmap.N)
    # NOTE(review): on older matplotlib, get_cmap may return the shared registered
    # colormap, so set_bad would change 'viridis' globally — confirm version.
    vid = plt.get_cmap('viridis')
    vid.set_bad('gray')

    # Row 1: Sentinel-2 composite, GT density, predicted density.
    ax = fig.add_subplot(gs[0])
    ax.imshow(median_s2)

    ax = fig.add_subplot(gs[1])
    ax.imshow(raster_gt,vmin=dens_min,vmax=dens_max, cmap=vid)
    _finish_axes(ax, f'Trees GT {np.nansum(raster_gt):.2f}', raster_gt.shape)

    ax = fig.add_subplot(gs[2])
    ax.imshow(preds_ours,vmin=dens_min,vmax=dens_max)
    _finish_axes(ax, f'Trees Pred {np.nansum(preds_ours):.2f}', preds_ours.shape)

    # Row 2: binary palm area (GT, ours) and the Descals classes.
    raster_gt_sem = (raster_gt > threshold_area)*1.
    raster_gt_sem[~gt_labeled] = np.nan

    ax = fig.add_subplot(gs[3])
    ax.imshow(raster_gt_sem,cmap =cmap,norm=norm)
    _finish_axes(ax, f'Area GT {np.nansum(raster_gt > threshold_area):.0f}', raster_gt.shape)

    ax = fig.add_subplot(gs[4])
    ax.imshow(preds_ours > threshold_area, cmap =cmap, norm=norm)
    _finish_axes(ax, f'Area Pred {np.sum(preds_ours[gt_labeled] >threshold_area):.0f}', preds_ours.shape)

    ax = fig.add_subplot(gs[5])
    # 3 - code maps Descals values 3, 2, 1 onto the white, green, darkgreen bins.
    ax.imshow(3 - preds_descals ,cmap =cmap,norm=norm)
    try:
        desc_title = f'Area Descals {np.sum(preds_descals[gt_labeled] == 2):.0f}'
    except IndexError:
        # The Descals grid differs from the GT grid: report shapes, leave untitled.
        print(preds_descals.shape,preds_ours.shape, raster_gt.shape,)
        desc_title = ''
    _finish_axes(ax, desc_title, preds_descals.shape)

    if filename is not None:
        fig.savefig(filename, bbox_inches='tight',dpi=300)
        print(filename,'saved!')
    return None



In [None]:
# Display the list of validation annotation folders iterated above.
out_dict['val']

In [None]:
# Folders for which get_rasters succeeded.
all_rasters.keys()

In [None]:
# gt = all_rasters['/T48NUG/palm_group2']['gt']

# desc = all_rasters['/T48NUG/palm_group2']['desc']
# gt_sem = 1.*(gt > 0.5)
# gt_sem[np.isnan(gt)] = np.nan



# cmap = mpl.colors.ListedColormap(['white', 'darkgreen','red'])

# #cmap.set_over('0.25')
# #cmap.set_under('0.75')
# cmap.set_bad(color='gray')

# bounds = [0,0.9,2,3]
# norm = mpl.colors.BoundaryNorm(bounds, cmap.N)

# plt.imshow(gt_sem,cmap=cmap, norm=norm)
# plt.colorbar()
# plt.show()
# plt.imshow(3- desc,cmap=cmap,norm=norm)
# plt.colorbar()

In [None]:

# One comparison figure per validation area. plt.show() renders each figure
# immediately (and, with the default inline backend settings, closes it), so
# figures do not accumulate in memory over the loop.
for key, val in all_rasters.items():
    print(key)
    plot_rasters(val)
    plt.show()

In [None]:
# Available area keys, used to pick the figure saved below.
all_rasters.keys()

In [None]:
# Save the comparison figure for one representative area.
# NOTE(review): the output name is spelled 'comparsion'; kept as-is in case it is referenced elsewhere.
plot_rasters(all_rasters['/T48NUG/palm_group2'], filename=save_path_figs+'/comparsion_descals.pdf')

### Validation areas from Descals

In [None]:
# Validation points of the Descals global oil palm layer (fields 'Class' and 'predClass' are read below).
shp_val = '/home/pf/pfstaff/projects/andresro/barry_palm/palmoilmaps/descals2020/Validation_points_GlobalOilPalmLayer_2019/Validation_points_GlobalOilPalmLayer_2019.shp'

In [None]:

from sklearn.metrics import confusion_matrix, plot_confusion_matrix


In [None]:
# Leftover inspection of the imported sklearn function; has no effect on results.
confusion_matrix

In [None]:
# Footprints of the Sentinel-2 tiles; the 'NAME' field is matched against our tile ids below.
sentinel_2tiles = '/home/pf/pfstaff/projects/andresro/barry_palm/data/labels/sentinel2_tiles_world/sentinel2_tiles_world.shp'

In [None]:
# Inference folder whose tifs are evaluated in the following cells.
folder_inference

In [None]:
# Tile ids of all inference tifs with the leading 'T' removed,
# e.g. 'T48NUG_preds_reg.tif' -> '48NUG'.
tif_paths = glob.glob(folder_inference+'/T*.tif')
list_tif = {os.path.basename(path).split('_')[0][1:] for path in tif_paths}


In [None]:
vector_s2 = ogr.Open(sentinel_2tiles)
layer_s2 = vector_s2.GetLayer()
nfeat_s2 = layer_s2.GetFeatureCount()

# Footprints of the S2 tiles we have inference tifs for. (name, geometry) pairs
# are read once here instead of re-reading every feature from disk on each
# get_s2_name call. Geometries are cloned because a GetGeometryRef() result is
# owned by its feature and becomes invalid once that feature is released.
feat_list = []
s2_footprints = []
for i in range(nfeat_s2):
    feature_s2 = layer_s2.GetFeature(i)
    tile_name = feature_s2.GetField('NAME')
    if tile_name in list_tif:
        feat_list.append(i)
        s2_footprints.append((tile_name, feature_s2.GetGeometryRef().Clone()))


def get_s2_name(ref_geom):
    """Return the NAME of the first cached S2 footprint intersecting ``ref_geom``, or None.

    Footprints are checked in layer order; no reprojection is done (TODO confirm
    the validation points and tile footprints share a CRS).
    """
    for name, geom_s2 in s2_footprints:
        if ref_geom.Intersects(geom_s2):
            return name
    return None
    

In [None]:
vector = ogr.Open(shp_val)
layer = vector.GetLayer()

# Keep only the validation points that fall inside one of our inference tiles,
# recording feature index, reference class, Descals prediction and tile name.
ids, gt, pred_desc, s2_names = [], [], [], []
for fid in range(layer.GetFeatureCount()):
    point_feature = layer.GetFeature(fid)
    tile_name = get_s2_name(point_feature.GetGeometryRef())
    if tile_name is None:
        continue
    ids.append(fid)
    gt.append(point_feature.GetField('Class'))
    pred_desc.append(point_feature.GetField('predClass'))
    s2_names.append(tile_name)


In [None]:
# Number of validation points that fall inside one of our inference tiles.
len(gt)

In [None]:
# Descals et al. scores on the matched validation points. Classes are clipped to
# {0, 1} (codes >= 1 count as palm). labels=[0, 1] keeps the matrix 2x2 even if
# one class is absent, so ravel() always unpacks into four counts.
tn, fp, fn, tp = confusion_matrix(y_true=np.clip(gt,0,1),y_pred=np.clip(pred_desc,0,1),labels=[0,1]).ravel()

output_ = {}
output_['acc'] = (tp + tn) / (tp +tn + fn + fp)
output_['prec'] = (tp) / (tp + fp)
output_['rec'] = (tp) / (tp + fn)
output_

In [None]:
# Sample our prediction raster at each matched validation point.
# Every entry has shape (1, 1) (NaN when no tif exists or the point falls
# outside the raster), so the list stacks cleanly into an array later.
pred_ours = []
lonlats = []
opened_rasters = {}  # one gdal dataset per tile instead of one per point
for id_, s2_name in zip(ids,s2_names):
    if s2_name not in opened_rasters:
        # NOTE(review): glob order is arbitrary; if several tifs match the tile
        # (e.g. different prediction products) an arbitrary one is used.
        candidates = glob.glob(f'{folder_inference}/T{s2_name}*.tif')
        opened_rasters[s2_name] = gdal.Open(candidates[0]) if candidates else None
    ds = opened_rasters[s2_name]
    if ds is None:
        pred_ours.append(np.full((1, 1), np.nan))
        lonlats.append((np.nan,np.nan))
        continue

    feature = layer.GetFeature(id_)
    lon, lat, _ = feature.GetGeometryRef().GetPoint(0)
    x, y = gp.to_xy(lon,lat,ds =ds)
    x, y = int(x), int(y)  # ReadAsArray needs integer pixel offsets
    lonlats.append((lon,lat))
    if 0 <= x < ds.RasterXSize and 0 <= y < ds.RasterYSize:
        pred_ours.append(ds.ReadAsArray(xoff=x,yoff=y,xsize=1,ysize=1))
    else:
        pred_ours.append(np.full((1, 1), np.nan))

        

In [None]:
# Stack the per-point samples and drop singleton axes -> one value per point.
# NOTE(review): this only yields a flat float array if every entry has the same shape.
pred_ours = np.array(pred_ours).squeeze()

In [None]:
# Binarize our per-point predictions (threshold 0.4) and score them against the
# clipped reference classes. Points without a prediction (NaN) are excluded
# instead of silently counting as "no palm"; labels=[0, 1] keeps the matrix 2x2.
# NOTE: the Descals scores above use all matched points, so the populations
# differ whenever some predictions are NaN.
has_pred = ~np.isnan(pred_ours)
pred_sem = (pred_ours > 0.4)*1.
tn, fp, fn, tp = confusion_matrix(y_true=np.clip(gt,0,1)[has_pred],y_pred=pred_sem[has_pred],labels=[0,1]).ravel()

output_ = {}
output_['acc'] = (tp + tn) / (tp +tn + fn + fp)
output_['prec'] = (tp) / (tp + fp)
output_['rec'] = (tp) / (tp + fn)
output_

## PALM 4

In [None]:
import json

# Dataset split definition; out_dict['val'] lists the validation annotation folders.
filename = '/home/pf/pfstaff/projects/andresro/barry_palm/data/2A/datasets/palm4_base.json'
with open(filename, 'rb') as json_file:
    out_dict = json.load(json_file)

In [None]:
# Peek at the last validation folder.
out_dict['val'][-1:]

In [None]:
# Inference run evaluated in this section and the scale passed to get_rasters
# below (a later cell resets `scale` to 10 for block aggregation).
folder_inference = '/scratch/andresro/leon_work/sparse/inference/palm4_simpleA9_soft_ens5'
scale=20


In [None]:
# Build rasters for every palm4 validation folder ('/<tile>/<group>').
# NOTE(review): the get_rasters visible in this notebook takes a single folder
# path; if the signature does not match this keyword call, the printed error
# shows the TypeError instead of a bare "error in ...".
all_rasters = {}
for ref_folder in tqdm(out_dict['val']):
    try:
        tile = ref_folder.split('/')[1]
        group= ref_folder.split('/')[2]
        all_rasters[ref_folder] = get_rasters(folder_inference=folder_inference,
                tile=tile,
                folder_annotations='/home/pf/pfstaff/projects/andresro/barry_palm/data/labels/manual_annotations',
                group=group, scale=scale)
    except Exception as err:
        print(f'error in {ref_folder}: {err!r}')

In [None]:
scale = 10
dens_min = 0.3
base_columns = ['GT','Pred','lon','lat']

# Per area: mask unusable pixels, sum GT and predictions over scale x scale
# blocks and tag every block with the ROI centre.
# NOTE(review): expects value keys 'gt', 'preds' and 'roi_'; the get_rasters
# visible in this notebook returns 'ours' and 'roi' instead — confirm.
block_frames = []
for key, val in all_rasters.items():
    raster = val['gt']
    preds = val['preds']
    roi_ = val['roi_']

    # Ignore pixels without GT, without prediction, or flagged 99 (nodata).
    mask_out = np.isnan(raster) | np.isnan(preds) | (preds == 99)
    preds1 = preds.copy()
    preds1[preds1 < dens_min] = 0.0  # suppress low-density noise
    preds1[mask_out] = np.nan
    raster1 = raster.copy()
    raster1[mask_out] = np.nan

    r1 = gp.block_reduce(raster1,(scale,scale),np.nansum)
    p1 = gp.block_reduce(preds1,(scale,scale),np.nansum)

    block_frames.append(pd.DataFrame({'GT':r1.ravel(),'Pred':p1.ravel(),
                                      'lon':(roi_[0]+roi_[2])/2,
                                      'lat':(roi_[1]+roi_[3])/2}))

# DataFrame.append was removed in pandas 2.0; concatenate once instead, which
# also avoids the quadratic cost of appending inside the loop.
ds_out = (pd.concat(block_frames, ignore_index=True) if block_frames
          else pd.DataFrame(columns=base_columns))
    
    

In [None]:
min_trees = 1
# Drop blocks where either GT or prediction has fewer than `min_trees` trees.
zeros_ = np.logical_or(ds_out.GT < min_trees,ds_out.Pred < min_trees)
ds = ds_out.copy() # [~zeros_].copy()
ds = ds[~zeros_]

# NOTE(review): predictions are scaled by 1.2 before plotting and MAE —
# presumably a bias-correction factor; confirm and document its origin.
ds['Pred'] = ds.Pred*1.2

# Regression fit only; the points are drawn separately below in light gray.
g = sns.jointplot(x='Pred',y='GT',data=ds,
#                   cmap="Reds",
                  kind='reg',scatter=False,
#                    kind="hex",
                 ) #, clip=(dens_min,dens_max))

g.ax_joint.scatter(x=ds.Pred,y=ds.GT, color='lightgray')
lims = (0,np.nanmax(ds.GT))
g.ax_marg_x.set_xlim(lims)
g.ax_marg_y.set_ylim(lims)

# lims = [max(x0, y0), min(x1, y1)]
# Identity line y = x for reference.
g.ax_joint.plot(lims, lims, ':k')    
# Block MAE divided by the block area p2ha(scale) gives trees per hectare.
plt.title(f' MAE {np.nanmean(np.abs(ds.Pred -ds.GT))/p2ha(scale):.2f} Trees/ha in {p2ha(scale)}ha Blocks  \n ' \
#           f'total trees GT:{np.nansum(ds.GT):.2f} Pred:{np.nansum(ds.Pred):.2f} ({100*(np.nansum(ds.Pred)-np.nansum(ds.GT))/np.nansum(ds.GT):.2f}%) \n' \
          ' \n\n\n',
           x=0.7,
           fontsize = 14)



In [None]:
# Per-block GT / prediction table.
ds_out

In [None]:
#folder_inference = '/home/pf/pfstaff/projects/andresro/sparse/inference_leon/borneo_simpleA9_mc10'
ds_out = pd.DataFrame(columns=['GT','Pred','lon','lat'])

fig_list = []
# tilenames = [x.split('/')[1] for x in out_dict['val']]
for x in out_dict['val']:
    print(x)
#for tile in tilenames:
    tile = x.split('/')[1]
    group= x.split('/')[2]
            
    ds_ = plot_counts(folder_inference=folder_inference,
                tile=tile,
                folder_annotations='/home/pf/pfstaff/projects/andresro/barry_palm/data/labels/manual_annotations',
                group=group, scale=scale, clip_min=0.2)
    if ds_ is not None:
        ds_out = ds_out.append(ds_[0],ignore_index=True)
        fig_list.append(ds_[1])
    else:
        print('error in tile',tile)

In [None]:
#folder_inference = '/home/pf/pfstaff/projects/andresro/sparse/inference_leon/borneo_simpleA9_mc10'
ds_out = pd.DataFrame(columns=['GT','Pred','lon','lat'])

fig_list = []
# tilenames = [x.split('/')[1] for x in out_dict['val']]
for x in out_dict['val']:
    print(x)
#for tile in tilenames:
    tile = x.split('/')[1]
    group= x.split('/')[2]
            
    ds_ = plot_counts(folder_inference=folder_inference,
                tile=tile,
                folder_annotations='/home/pf/pfstaff/projects/andresro/barry_palm/data/labels/manual_annotations',
                group=group, scale=scale, clip_min=0.4)
    if ds_ is not None:
        ds_out = ds_out.append(ds_[0],ignore_index=True)
        fig_list.append(ds_[1])
    else:
        print('error in tile',tile)

In [None]:
# Inspect one of the per-tile figures collected from plot_counts.
fig_list[2]

In [None]:
#fig_list[5].savefig(save_path_figs+'Density-validation-sumatra-5.pdf', bbox_inches='tight',dpi=300)
#fig_list[8].savefig(save_path_figs+'Density-validation-sumatra-8.pdf', bbox_inches='tight',dpi=300)

In [None]:
import pandas as pd
import geopandas
import matplotlib.pyplot as plt

import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import cartopy.crs as ccrs

from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER

In [None]:
# Average GT and Pred over all blocks sharing a lon/lat (presumably one
# centre coordinate per validation area — confirm against plot_counts).
d_ = ds_out.groupby(['lon','lat']).mean().reset_index()

In [None]:
# Per-area mean GT / Pred table.
d_

In [None]:

# Map of the absolute per-area error |mean Pred - mean GT| at each validation area.
gdf = geopandas.GeoDataFrame(
    geometry=geopandas.points_from_xy(d_['lon'],  d_['lat']))

error_ = np.abs(d_.Pred -d_.GT)

# NOTE(review): geopandas.datasets is deprecated (0.14) and removed in geopandas 1.0;
# newer versions need a separately downloaded Natural Earth file here.
world = geopandas.read_file(geopandas.datasets.get_path('naturalearth_lowres'))
# Only used by the (disabled) country overlay below.
countries = geopandas.read_file("/scratch/andresro/leon_work/barry_palm/data/labels/countries/3countries.shp")


plt.figure(figsize=(6,6))
ax = plt.axes(projection=ccrs.PlateCarree())

world.plot(ax=ax,
    color='lightgray', edgecolor='white')

#countries.plot(ax=ax,
#    color='lightgray', edgecolor='white')

sc = ax.scatter(d_['lon'], d_['lat'],
            c=error_,
            cmap='magma',
           marker='.')

cbar = plt.colorbar(sc,ax=ax,orientation="horizontal")
cbar.set_label('|mean Pred - mean GT| per block')

# Zoom to the validation points with a 2 degree margin on each side.
minx, miny, maxx, maxy = gdf.total_bounds
margin_x = 2
margin_y = 2

ax.set_xlim(minx-margin_x, maxx+margin_x)
ax.set_ylim(miny-margin_y, maxy+margin_y)

gl = ax.gridlines(
   crs=ccrs.PlateCarree(),
    draw_labels=True)
gl.xformatter = LONGITUDE_FORMATTER
gl.yformatter = LATITUDE_FORMATTER
gl.xlabel_style = {'rotation':'vertical'}


In [None]:
# Display the cartopy axes object from the map cell above.
ax

In [None]:
# Quick scatter of per-block absolute error at each block's lon/lat.
abs_error = np.abs(ds_out.Pred -ds_out.GT)
plt.scatter(ds_out['lon'], ds_out['lat'], c=abs_error, cmap='magma_r', marker='.')

In [None]:
# Fraction of blocks whose GT is exactly 0.1 (TODO: presumably a placeholder value — confirm).
(ds.GT == 0.1).mean()

In [None]:
# Same comparison as above, on all blocks (no minimum-tree filter).
# Copy so the scaling below does not modify ds_out in place; an alias would
# compound the 1.2 factor on every re-run and leak into later cells.
ds = ds_out.copy()

# NOTE(review): 1.2 is presumably a bias-correction factor — confirm its origin.
ds['Pred'] = ds.Pred*1.2

g = sns.jointplot(x='Pred',y='GT',data=ds, cmap="Reds")

lims = (0,np.nanmax(ds.GT))
g.ax_marg_x.set_xlim(lims)
g.ax_marg_y.set_ylim(lims)

# Identity line y = x for reference.
g.ax_joint.plot(lims, lims, ':k')
plt.title(f' MAE {np.nanmean(np.abs(ds.Pred -ds.GT))/p2ha(scale):.2f} Trees/ha in {p2ha(scale)}ha Blocks  \n '
          f'total trees GT:{np.nansum(ds.GT):.2f} Pred:{np.nansum(ds.Pred):.2f} ({100*(np.nansum(ds.Pred)-np.nansum(ds.GT))/np.nansum(ds.GT):.2f}%) \n',x=-0.1,y=0.5, fontsize = 12)


## Comparison to Descals 2019

In [None]:
# Root of the manual annotations; validation folder names (e.g. '/T48NUG/palm_group2') are appended to it.
folder_annotations='/home/pf/pfstaff/projects/andresro/barry_palm/data/labels/manual_annotations'

In [None]:
def get_matching_id(pos_file, grid_file):
    """Return the 'ID' of the grid cell that intersects an annotation area.

    Uses the geometry of the first feature of ``pos_file`` and scans the
    features of ``grid_file`` in order, returning the 'ID' field of the first
    intersecting cell, or -1 if none intersects. Geometries are compared as
    stored (no reprojection) — TODO confirm both layers share a CRS.
    """
    # Datasource, layer and feature stay referenced while their geometry is used.
    pos_vector = ogr.Open(pos_file)
    pos_layer = pos_vector.GetLayer()
    pos_feature = pos_layer.GetFeature(0)
    area_geom = pos_feature.geometry()

    grid_vector = ogr.Open(grid_file)
    grid_layer = grid_vector.GetLayer()
    n_cells = grid_layer.GetFeatureCount()

    cells = (grid_layer.GetFeature(idx) for idx in range(n_cells))
    for cell in cells:
        if area_geom.Intersects(cell.GetGeometryRef()):
            return cell.GetField('ID')
    return -1
    

In [None]:
# pos_file = '/home/pf/pfstaff/projects/andresro/barry_palm/data/labels/manual_annotations/T48MTD/palm_group2_Bischel/positiv_2.shp'


In [None]:
# @no_output
def get_rasters(ref_folder):
    """Gather GT, Sentinel-2, Descals and our rasters for one annotation folder.

    Args:
        ref_folder: annotation folder whose second-to-last path component is the
            tile name (e.g. '<annotations>/T48NUG/palm_group2') and which holds
            exactly one '*pos*.shp' file.

    Returns:
        dict with
          'gt'   : rasterized annotations on our prediction grid (-1 -> NaN),
          's2'   : median Sentinel-2 composite over the ROI,
          'desc' : Descals et al. 2019 map warped to our prediction's projection,
          'ours' : our regression predictions over the ROI,
          'roi'  : the ROI returned by ``gp.get_positive_area_folder``.

    Raises:
        AssertionError: the folder does not hold exactly one positive shapefile.
        FileNotFoundError: no prediction tif or Descals map could be opened.
        ValueError: no Descals grid cell intersects the annotations.

    Uses the notebook globals ``folder_inference`` and ``get_median_sentinel2``.
    """
    pos_shp = glob.glob(ref_folder+'/*pos*.shp')
    assert len(pos_shp) == 1, f'expected one *pos*.shp in {ref_folder}, found {len(pos_shp)}'
    pos_file = pos_shp[0]

    grid_file = '/home/pf/pfstaff/projects/andresro/barry_palm/palmoilmaps/descals2020/grid/grid_withOP.shp'

    # Reading Predictions
    tile = ref_folder.split('/')[-2]
    pred_tifs = glob.glob(f'{folder_inference}/{tile}*preds_reg*.tif')
    if not pred_tifs:
        raise FileNotFoundError(f'no *preds_reg*.tif for {tile} in {folder_inference}')
    pred_tif = pred_tifs[0]

    ds = gdal.Open(pred_tif)
    ref_proj = ds.GetProjectionRef()

    roi_ = gp.get_positive_area_folder(ref_folder)

    # The Descals map is split into grid cells; find the cell covering the annotations.
    id_ = get_matching_id(pos_file=pos_file,grid_file=grid_file)
    if id_ == -1:
        raise ValueError(f'no Descals grid cell intersects {pos_file}')

    base_path = '/home/pf/pfstaff/projects/andresro/barry_palm/palmoilmaps/descals2020/oil_palm_map'
    pred_descals = f'{base_path}/L2_2019b_{id_}.tif'
    ds_descals = gdal.Open(pred_descals)
    if ds_descals is None:
        raise FileNotFoundError(f'cannot open Descals map {pred_descals}')

    # Warp into our projection as an in-memory VRT. The Descals map is used as
    # class codes downstream (plot_rasters counts `== 2` and uses a 3-class
    # colormap), so resample with nearest neighbour: bilinear interpolation
    # would blend codes at class borders into values that are not valid classes.
    warp_opts = gdal.WarpOptions(
        format="VRT",  # format='GTiff',
        dstSRS=ref_proj,
        resampleAlg=gdalconst.GRA_NearestNeighbour,
        # srcNodata=99,
        dstNodata='nan')

    ds_descals_warped = gdal.Warp('', ds_descals, options=warp_opts)

    preds_descals = readHR(data_file=ds_descals_warped,roi_lon_lat=roi_,scale=10, as_float=False, is_assert_blank=False,is_exit=False)
    preds_ours = readHR(data_file=ds,roi_lon_lat=roi_,scale=10, as_float=False, is_assert_blank=False,is_exit=False)

    lims = gp.to_xy_box(roi_, ds, enlarge=10)

    raster_gt = gp.rasterize_points_pos_neg_folder(folder=ref_folder,refDataset=pred_tif,lims=lims,lims_with_labels=lims,sq_kernel=2)
    raster_gt[raster_gt == -1] = np.nan  # -1 marks unlabeled pixels

    median_s2 = get_median_sentinel2(path='/home/pf/pfstaff/projects/andresro/barry_palm/data/2A/palmcountries_2017/',
    ref_tile=tile, roi_lon_lat=roi_, resolution=10)

    return {'gt': raster_gt,
            's2': median_s2,
            'desc': preds_descals,
            'ours':preds_ours,
            'roi':roi_}


In [None]:
# Build the comparison rasters for every validation annotation folder.
# Best effort: folders that fail are skipped, but the reason is printed
# (a bare `except:` would also hide TypeErrors and swallow KeyboardInterrupt).
all_rasters = {}
for ref_folder in tqdm(out_dict['val']):
    try:
        all_rasters[ref_folder] = get_rasters(folder_annotations+ref_folder)
    except Exception as err:
        print(f'error in {ref_folder}: {err!r}')

In [None]:
import matplotlib as mpl
def plot_rasters(rasters_dict, filename = None):
    """Draw a 2x3 comparison of GT, our predictions and the Descals map for one area.

    Top row: Sentinel-2 median composite, GT tree density, our predicted density.
    Bottom row: GT palm area and our palm area (density > ``threshold_area``),
    and the Descals classification.

    Args:
        rasters_dict: dict with keys 's2', 'gt', 'ours' and 'desc' (as built by
            ``get_rasters``). 'gt', 'ours' and 'desc' are expected to share one
            pixel grid; a mismatch is reported for the Descals panel.
        filename: if given, the figure is also saved there at 300 dpi.

    Returns:
        None. The figure stays open so the notebook backend can display it.
    """
    dens_min = 0.2
    dens_max = 2.0
    threshold_area = 0.5

    median_s2 = rasters_dict['s2']
    raster_gt = rasters_dict['gt']
    preds_ours = np.clip(rasters_dict['ours'],0.1,99)
    preds_descals = rasters_dict['desc']
    gt_labeled = ~np.isnan(raster_gt)

    def _finish_axes(ax, title, shape):
        # Title the panel, hide ticks and label each side with its extent in km
        # (100 px = 1 km, i.e. 10 m pixels). imshow draws columns (shape[1])
        # along x and rows (shape[0]) along y.
        ax.set_title(title)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xlabel(f' {shape[1]/100:.2f}km')
        ax.set_ylabel(f' {shape[0]/100:.2f}km')

    fig = plt.figure(figsize=(10,5))
    n_col = 3
    gs = gridspec.GridSpec(nrows=2,ncols=n_col,left=0.02, bottom=0.06, right=0.95, top=0.94, wspace=0.1,hspace=0.4, figure=fig)

    # Discrete map for the area panels: <0.9 -> white, [0.9, 2) -> green,
    # [2, 3) -> darkgreen; NaN (unlabeled) -> gray.
    cmap = mpl.colors.ListedColormap(['white', 'green', 'darkgreen'])
    cmap.set_bad(color='gray')
    norm = mpl.colors.BoundaryNorm([0,0.9,2,3], cmap.N)
    # NOTE(review): on older matplotlib, get_cmap may return the shared registered
    # colormap, so set_bad would change 'viridis' globally — confirm version.
    vid = plt.get_cmap('viridis')
    vid.set_bad('gray')

    # Row 1: Sentinel-2 composite, GT density, predicted density.
    ax = fig.add_subplot(gs[0])
    ax.imshow(median_s2)

    ax = fig.add_subplot(gs[1])
    ax.imshow(raster_gt,vmin=dens_min,vmax=dens_max, cmap=vid)
    _finish_axes(ax, f'Trees GT {np.nansum(raster_gt):.2f}', raster_gt.shape)

    ax = fig.add_subplot(gs[2])
    ax.imshow(preds_ours,vmin=dens_min,vmax=dens_max)
    _finish_axes(ax, f'Trees Pred {np.nansum(preds_ours):.2f}', preds_ours.shape)

    # Row 2: binary palm area (GT, ours) and the Descals classes.
    raster_gt_sem = (raster_gt > threshold_area)*1.
    raster_gt_sem[~gt_labeled] = np.nan

    ax = fig.add_subplot(gs[3])
    ax.imshow(raster_gt_sem,cmap =cmap,norm=norm)
    _finish_axes(ax, f'Area GT {np.nansum(raster_gt > threshold_area):.0f}', raster_gt.shape)

    ax = fig.add_subplot(gs[4])
    ax.imshow(preds_ours > threshold_area, cmap =cmap, norm=norm)
    _finish_axes(ax, f'Area Pred {np.sum(preds_ours[gt_labeled] >threshold_area):.0f}', preds_ours.shape)

    ax = fig.add_subplot(gs[5])
    # 3 - code maps Descals values 3, 2, 1 onto the white, green, darkgreen bins.
    ax.imshow(3 - preds_descals ,cmap =cmap,norm=norm)
    try:
        desc_title = f'Area Descals {np.sum(preds_descals[gt_labeled] == 2):.0f}'
    except IndexError:
        # The Descals grid differs from the GT grid: report shapes, leave untitled.
        print(preds_descals.shape,preds_ours.shape, raster_gt.shape,)
        desc_title = ''
    _finish_axes(ax, desc_title, preds_descals.shape)

    if filename is not None:
        fig.savefig(filename, bbox_inches='tight',dpi=300)
        print(filename,'saved!')
    return None



In [None]:
# First ten validation folders.
out_dict['val'][:10]

In [None]:
# all_rasters.keys()

In [None]:
# gt = all_rasters['/T48NUG/palm_group2']['gt']

# desc = all_rasters['/T48NUG/palm_group2']['desc']
# gt_sem = 1.*(gt > 0.5)
# gt_sem[np.isnan(gt)] = np.nan



# cmap = mpl.colors.ListedColormap(['white', 'darkgreen','red'])

# #cmap.set_over('0.25')
# #cmap.set_under('0.75')
# cmap.set_bad(color='gray')

# bounds = [0,0.9,2,3]
# norm = mpl.colors.BoundaryNorm(bounds, cmap.N)

# plt.imshow(gt_sem,cmap=cmap, norm=norm)
# plt.colorbar()
# plt.show()
# plt.imshow(3- desc,cmap=cmap,norm=norm)
# plt.colorbar()

In [None]:

# One comparison figure per validation area. plt.show() renders each figure
# immediately (and, with the default inline backend settings, closes it), so
# figures do not accumulate in memory over the loop.
for key, val in all_rasters.items():
    print(key)
    plot_rasters(val)
    plt.show()

In [None]:
# Available area keys, used to pick the figure saved below.
all_rasters.keys()

In [None]:
# Save the comparison figure for one representative area.
# NOTE(review): the output name is spelled 'comparsion'; kept as-is in case it is referenced elsewhere.
plot_rasters(all_rasters['/T48NUG/palm_group2'], filename=save_path_figs+'/comparsion_descals.pdf')

In [None]:
# NOTE(review): identical to the previous cell; it re-renders and overwrites the same file.
plot_rasters(all_rasters['/T48NUG/palm_group2'], filename=save_path_figs+'/comparsion_descals.pdf')

### Validation areas from Descals

In [None]:
# Validation points of the Descals global oil palm layer (fields 'Class' and 'predClass' are read below).
shp_val = '/home/pf/pfstaff/projects/andresro/barry_palm/palmoilmaps/descals2020/Validation_points_GlobalOilPalmLayer_2019/Validation_points_GlobalOilPalmLayer_2019.shp'

In [None]:

from sklearn.metrics import confusion_matrix, plot_confusion_matrix


In [None]:
# Leftover inspection of the imported sklearn function; has no effect on results.
confusion_matrix

In [None]:
# Footprints of the Sentinel-2 tiles; the 'NAME' field is matched against our tile ids below.
sentinel_2tiles = '/home/pf/pfstaff/projects/andresro/barry_palm/data/labels/sentinel2_tiles_world/sentinel2_tiles_world.shp'

In [None]:
# Inference folder whose tifs are evaluated in the following cells.
folder_inference

In [None]:
# Tile ids of all inference tifs with the leading 'T' removed,
# e.g. 'T48NUG_preds_reg.tif' -> '48NUG'.
tif_paths = glob.glob(folder_inference+'/T*.tif')
list_tif = {os.path.basename(path).split('_')[0][1:] for path in tif_paths}
len(list_tif)

In [None]:
vector_s2 = ogr.Open(sentinel_2tiles)
layer_s2 = vector_s2.GetLayer()
nfeat_s2 = layer_s2.GetFeatureCount()

# Footprints of the S2 tiles we have inference tifs for. (name, geometry) pairs
# are read once here instead of re-reading every feature from disk on each
# get_s2_name call. Geometries are cloned because a GetGeometryRef() result is
# owned by its feature and becomes invalid once that feature is released.
feat_list = []
s2_footprints = []
for i in range(nfeat_s2):
    feature_s2 = layer_s2.GetFeature(i)
    tile_name = feature_s2.GetField('NAME')
    if tile_name in list_tif:
        feat_list.append(i)
        s2_footprints.append((tile_name, feature_s2.GetGeometryRef().Clone()))


def get_s2_name(ref_geom):
    """Return the NAME of the first cached S2 footprint intersecting ``ref_geom``, or None.

    Footprints are checked in layer order; no reprojection is done (TODO confirm
    the validation points and tile footprints share a CRS).
    """
    for name, geom_s2 in s2_footprints:
        if ref_geom.Intersects(geom_s2):
            return name
    return None
    

In [None]:
vector = ogr.Open(shp_val)
layer = vector.GetLayer()

# Keep only the validation points that fall inside one of our inference tiles,
# recording feature index, reference class, Descals prediction and tile name.
ids, gt, pred_desc, s2_names = [], [], [], []
for fid in range(layer.GetFeatureCount()):
    point_feature = layer.GetFeature(fid)
    tile_name = get_s2_name(point_feature.GetGeometryRef())
    if tile_name is None:
        continue
    ids.append(fid)
    gt.append(point_feature.GetField('Class'))
    pred_desc.append(point_feature.GetField('predClass'))
    s2_names.append(tile_name)
print('matched points',len(ids))

In [None]:
# Descals et al. scores on the matched validation points. Classes are clipped to
# {0, 1} (codes >= 1 count as palm). labels=[0, 1] keeps the matrix 2x2 even if
# one class is absent, so ravel() always unpacks into four counts.
tn, fp, fn, tp = confusion_matrix(y_true=np.clip(gt,0,1),y_pred=np.clip(pred_desc,0,1),labels=[0,1]).ravel()

output_ = {}
output_['acc'] = (tp + tn) / (tp +tn + fn + fp)
output_['prec'] = (tp) / (tp + fp)
output_['rec'] = (tp) / (tp + fn)
output_

Producer's accuracy = recall: the number of reference sites of a class that were classified correctly, divided by the total number of reference sites of that class.

User's accuracy = precision: the number of correct classifications for a class, divided by the total number of sites classified as that class (the row total of the confusion matrix when rows hold the classified labels).


In [None]:
# Sample our prediction raster around each validation point: the mean over a
# patch_size x patch_size window. Every entry is a scalar (NaN when no tif
# exists for the tile), so the list converts cleanly to a flat array.
pred_ours = []
lonlats = []
window_size = 3 # each side
patch_size = 2*window_size+1
opened_rasters = {}  # one gdal dataset per tile instead of one per point

for id_, s2_name in tqdm(zip(ids,s2_names), total=len(ids)):
    if s2_name not in opened_rasters:
        # NOTE(review): glob order is arbitrary; if several tifs match the tile
        # (e.g. different prediction products) an arbitrary one is used.
        candidates = glob.glob(f'{folder_inference}/T{s2_name}*.tif')
        opened_rasters[s2_name] = gdal.Open(candidates[0]) if candidates else None
    ds = opened_rasters[s2_name]
    if ds is None:
        pred_ours.append(np.nan)
        lonlats.append((np.nan,np.nan))
        continue

    feature = layer.GetFeature(id_)
    lon, lat, _ = feature.GetGeometryRef().GetPoint(0)
    x, y = gp.to_xy(lon,lat,ds =ds)
    # Top-left corner of the window centred on the point, shifted inwards so
    # the window stays inside the raster (off-centre near the borders).
    x = np.clip(x-window_size, 0,ds.RasterXSize-patch_size)
    y = np.clip(y-window_size, 0,ds.RasterYSize-patch_size)

    lonlats.append((lon,lat))
    patch = ds.ReadAsArray(xoff=int(x),yoff=int(y),xsize=patch_size,ysize=patch_size)
    # NOTE(review): other cells treat 99 as nodata in prediction rasters; such
    # pixels inside the window would inflate this mean — confirm.
    pred_ours.append(patch.mean())

        

In [None]:
# Stack the per-point samples and drop singleton axes -> one value per point.
# NOTE(review): this only yields a flat float array if every entry has the same shape.
pred_ours = np.array(pred_ours).squeeze()

In [None]:
# Binarize the patch-mean predictions (threshold 0.48) and score them against
# the clipped reference classes. Points without a prediction (NaN) are excluded
# instead of silently counting as "no palm"; labels=[0, 1] keeps the matrix 2x2.
# NOTE: the Descals scores above use all matched points, so the populations
# differ whenever some predictions are NaN.
has_pred = ~np.isnan(pred_ours)
pred_sem = (pred_ours > 0.48)*1.
tn, fp, fn, tp = confusion_matrix(y_true=np.clip(gt,0,1)[has_pred],y_pred=pred_sem[has_pred],labels=[0,1]).ravel()

output_ = {}
output_['acc'] = (tp + tn) / (tp +tn + fn + fp)
output_['prec'] = (tp) / (tp + fp)
output_['rec'] = (tp) / (tp + fn)
output_

## OLD

In [None]:
def convert_to_shp(points):
    """Convert a KML file to an ESRI Shapefile beside it and return the path to use.

    Paths that do not end in '.kml' are returned unchanged. For a '.kml' path the
    shapefile is (re)written with '.kml' replaced by '.shp', and that path is returned.
    """
    if not points.endswith('.kml'):
        return points
    shp_path = points.replace('.kml','.shp')
    src_ds = gdal.OpenEx(points)
    out_ds = gdal.VectorTranslate(shp_path, src_ds, format='ESRI Shapefile')
    out_ds = None  # dereference so GDAL flushes and closes the output
    return shp_path


def drop_all_butName(points):
    """Delete every attribute field of a vector file except 'Name', in place.

    Args:
        points: path to a vector file writable by OGR (opened in update mode).

    Prints the field count and the index of 'Name' before deleting, and the
    remaining field count afterwards. If the layer has no 'Name' field, nothing
    is deleted.
    """
    dataSource = ogr.Open(points, 1)

    layer = dataSource.GetLayer()

    lyrdf = layer.GetLayerDefn()

    id_Name = lyrdf.GetFieldIndex('Name')
    attr_N = lyrdf.GetFieldCount()
    print(attr_N, id_Name)
    if id_Name < 0:
        # Without a 'Name' field the loop below would delete every field.
        print(f"no 'Name' field in {points}; nothing deleted")
        dataSource = None
        return
    # Delete from the last index down. Deleting field i shifts every later field
    # one index left, so an ascending loop skips fields and eventually indexes
    # past the end; going backwards keeps all pending indices (and id_Name) valid.
    for i in reversed(range(attr_N)):
        if i != id_Name:
            layer.DeleteField(i)
    attr_N = lyrdf.GetFieldCount()
    print(attr_N)
    dataSource = None  # dereference to flush changes and close the file


In [None]:
# Zonal sums of the 49MCV prediction raster over three palm annotation KMLs
# (each converted to a shapefile first). loop_zonal_stats_update is defined
# elsewhere in this notebook; is_update=False presumably means the shapefiles
# are not modified — confirm.
ref_raster = '/scratch/andresro/leon_work/sparse/inference/palmcoco_kalimA_simpleA5/R132_T49MCV_preds_reg.tif'
for i in range(3):
    points = f'/home/pf/pfstud/andresro/tree_annotationsAug2019/annotations/Jan/palm/49MCV/Palm_Jan_{i+1}.kml'
    points = convert_to_shp(points)
    print(points)
    loop_zonal_stats_update(input_zone_polygon=points,input_value_raster=ref_raster,fieldname='pred1',fn=np.ma.sum, is_update=False, is_pos_only=True)

In [None]:
# Zonal sums of the 47NQA prediction raster over a manually annotated layer.
ref_raster = '/scratch/andresro/leon_work/sparse/inference/palmcoco_kalimA_simpleA5/R018_T47NQA_preds_reg.tif'
points = '/home/pf/pfstud/andresro/tree_annotationsAug2019/annotations/Andres/palm/points_manual_2019.kml'
points = convert_to_shp(points)
print(points)
# `points` may also be a folder of shapefiles; each one is processed separately.
# NOTE(review): after convert_to_shp a .kml input is a file path, so this branch only applies to folder inputs.
if os.path.isdir(points):
    pointsList = glob.glob(points+"/*.shp")
    print(f'processing {len(pointsList)} layers')
else:
    pointsList = [points]

for points in pointsList:
    
    out_val = loop_zonal_stats_update(input_zone_polygon=points,input_value_raster=ref_raster,fieldname='pred1',fn=np.ma.sum, is_update=False)
    print(f"TOTAL {out_val:.2f}", os.path.basename(points))

## Evaluate state-wide predictions

In [None]:
# `scale` selects the downsampled untiled prediction ('..._down{scale}.tif');
# the polygons are the Malaysian admin-1 boundaries.
scale = 2
input_zone_pol = '/scratch2/Dropbox/Dropbox/0_phd/tree_annotationsAug2019/countries/malaysia/MYS_adm1.shp'
ref_raster = f'/scratch/andresro/leon_work/sparse/inference/palmsarawak_simpleA20_allsarawak/0_untiled_down{scale}.tif'

In [None]:

# Rasterize the admin polygons (burning their 'ID_1' value) onto the prediction
# grid, then read the prediction raster itself (scale=1 presumably means no
# further resampling — confirm in readHR).
img = gp.rasterize_polygons(InputVector=input_zone_pol,refDataset=ref_raster,attribute='ID_1')

raster = readHR(roi_lon_lat = None,data_file=ref_raster,scale=1,as_float=False)

In [None]:
# plt.imshow(img==11)
raster[img!=11] = np.nan
raster *=(scale**2) # computing sum instead of average in down op 
raster*=1.4 # bias correction

plt.imshow(raster)
plt.colorbar()

In [None]:
# Pixel area in hectares at this scale: p2ha(x) = (x*10 m)^2 / 10^4 m^2 per ha.
p2ha(scale)

In [None]:
# Range of the corrected per-pixel values.
np.nanmin(raster),np.nanmax(raster)

#### Statistics

In [None]:
# Total Ha
# A pixel counts as planted when it holds more than 30 trees per ha
# (threshold multiplied by the pixel area p2ha(scale) in ha).
ref_ = raster  > 30*p2ha(scale)
plt.imshow(ref_)
area = np.nansum(ref_)*p2ha(scale)
trees = np.nansum(raster[ref_])
# Area in million ha, trees in millions, and the mean density over planted pixels.
f' Planted {area/1e6:.3f}mha , Trees {trees/1e6:.3f}m, Tree/Ha {trees/area} '

In [None]:


# Per-state zonal sums of the prediction raster, labelled by 'NAME_1'.
# NOTE(review): bias=14.0 differs from the 1.4 correction applied to `raster` above — confirm which is intended.
out_val = loop_zonal_stats_update(input_zone_polygon=input_zone_pol,input_value_raster=ref_raster,fieldname='pred1',fn=np.nansum, is_update=False,bias=14.0, field_name='NAME_1')



In [None]:
# NOTE(review): opened in update mode only to inspect the layer; read-only (update=0) would be safer.
shp = ogr.Open(input_zone_pol, update = 1)

lyr = shp.GetLayer()

lyr

### Split the raster extent into ROI blocks and save them as KML / shapefile

In [None]:
import math
EARTH_RADIUS = 6371000  # Radius in meters of Earth


def haversine_distance(lon1, lat1, lon2, lat2):
    """Great-circle distance in meters between (lon1, lat1) and (lon2, lat2).

    Uses the haversine formula on a sphere of radius ``EARTH_RADIUS``;
    coordinates are in decimal degrees.
    """
    half_dlat = math.radians((lat2 - lat1) / 2.0)
    half_dlon = math.radians((lon2 - lon1) / 2.0)
    hav = math.sin(half_dlat) ** 2 + math.cos(math.radians(lat1)) * math.cos(
        math.radians(lat2)) * math.sin(half_dlon) ** 2
    central_angle = 2 * math.atan2(math.sqrt(hav), math.sqrt(1 - hav))
    return EARTH_RADIUS * central_angle

def split_roi_to_rois(lon1_, lat1_, lon2_, lat2_, meters_split = 1500):
    """Split a lon/lat box into a grid of roughly ``meters_split``-sized ROIs.

    The box is normalised so (lon1, lat1) is its lower-left corner. Rows and
    columns number ceil(extent_m / meters_split); the east-west extent is
    measured along the southern edge.

    Returns:
        list of dicts {"roi": (lat_min, lon_min, lat_max, lon_max), "name": "ROI<k>"}
        with k starting at 1, ordered row by row from south to north and west to
        east within a row. Note the tuple is lat-first.
    """
    lon1, lon2 = min(lon1_, lon2_), max(lon1_, lon2_)
    lat1, lat2 = min(lat1_, lat2_), max(lat1_, lat2_)

    width_m = haversine_distance(lon1=lon1,lat1=lat1,lon2=lon2,lat2=lat1)
    height_m = haversine_distance(lon1=lon1,lat1=lat1,lon2=lon1,lat2=lat2)
    n_cols = int(math.ceil(width_m / meters_split))
    n_rows = int(math.ceil(height_m / meters_split))

    span_lon = lon2 - lon1
    span_lat = lat2 - lat1
    return [
        {"roi": (lat1 + span_lat * row / n_rows,
                 lon1 + span_lon * col / n_cols,
                 lat1 + span_lat * (row + 1) / n_rows,
                 lon1 + span_lon * (col + 1) / n_cols),
         "name": "ROI{}".format(row * n_cols + col + 1)}
        for row in range(n_rows)
        for col in range(n_cols)
    ]

def to_bbox(roi_lon_lat):
    """Return the corners of a lon/lat box as (lon, lat) tuples in ring order.

    Args:
        roi_lon_lat: either a 'lon1,lat1,lon2,lat2' string or a 4-sequence in
            that order.

    Returns:
        [(lon1, lat1), (lon1, lat2), (lon2, lat2), (lon2, lat1)].
    """
    if isinstance(roi_lon_lat, str):
        # str.split needs no extra import (the former re.split relied on a
        # never-imported `re` module); float() tolerates surrounding spaces.
        roi_lon1, roi_lat1, roi_lon2, roi_lat2 = map(float, roi_lon_lat.split(','))
    else:
        roi_lon1, roi_lat1, roi_lon2, roi_lat2 = roi_lon_lat

    geo_pts_ref = [(roi_lon1, roi_lat1), (roi_lon1, roi_lat2), (roi_lon2, roi_lat2), (roi_lon2, roi_lat1)]
    return geo_pts_ref



def convert_to_shp(points, is_overwrite=False):
    """Return a shapefile path for ``points``, converting a KML file if needed.

    Paths not ending in '.kml' are returned unchanged. For a '.kml' path the
    shapefile path is '.kml' replaced by '.shp'. It is (re)written from the KML
    only when it does not exist yet or ``is_overwrite`` is True; otherwise the
    existing shapefile is reused. Either way the '.shp' path is returned.
    """
    if points.endswith('.kml'):
        new_points=points.replace('.kml','.shp')
        if not os.path.exists(new_points) or is_overwrite:
            srcDS = gdal.OpenEx(points)
            ds = gdal.VectorTranslate(new_points, srcDS, format='ESRI Shapefile')
            ds = None  # dereference so GDAL flushes and closes the output
        # Return the shapefile even when an existing one was reused.
        points = new_points
    return points


In [None]:
ref_raster = '/scratch/andresro/leon_work/sparse/inference/palmsarawak_simpleA20/T49MCV_preds_reg.tif'
# ref_raster = '/scratch/andresro/leon_work/sparse/inference/palmcoco_kalimA_simpleA5/T49MCV_preds_reg.tif'
# save_dir = '/scratch/andresro/leon_work/sparse/inference/palmcoco_kalimA_simpleA5'
save_dir = '/scratch/andresro/leon_work/sparse/inference/palmsarawak_simpleA20/'
ds = gdal.Open(ref_raster)

# Lon/lat bounding box of the points returned by gp.get_lonlat
# (presumably the raster corners), split into ~2 km ROIs.
geo_pts_ref = gp.get_lonlat(ds)
lons = [pt[0] for pt in geo_pts_ref]
lats = [pt[1] for pt in geo_pts_ref]
lon1_, lat1_, lon2_, lat2_ = min(lons), min(lats), max(lons), max(lats)

roi_ = split_roi_to_rois(lon1_, lat1_, lon2_, lat2_,2000)
print(len(roi_))

In [None]:
# Write the ROI grid as KML polygons; the file is named after the raster and the ROI count.
fname = os.path.basename(ref_raster).replace('.tif','')
kmlfile_name = f"{save_dir}/{fname}_rois_{len(roi_)}.kml"
kml = simplekml.Kml()
for roi in roi_:
    # split_roi_to_rois stores each box lat-first: (lat1, lon1, lat2, lon2).
    lat1, lon1, lat2, lon2 = roi["roi"]
    # print roi

    # to_bbox takes lon-first coordinates and returns the four corners.
    geo_pts_ref = to_bbox([lon1, lat1, lon2, lat2])
    pol = kml.newpolygon(name=roi['name'])
    pol.outerboundaryis = geo_pts_ref
    # Semi-transparent white fill (alpha 100 on simplekml's 0-255 scale).
    pol.style.polystyle.color = simplekml.Color.changealphaint(100, simplekml.Color.white)

kml.save(kmlfile_name)
print(kmlfile_name)

In [None]:
# Convert the ROI KML into a shapefile, overwriting any earlier conversion.
points = convert_to_shp(kmlfile_name, is_overwrite=True)
print(points)
# drop_all_butName(points)

In [None]:
# Store per-ROI predicted palm sums in the shapefile field 'pred_palm' (is_update=True).
# NOTE(review): bias=1.5 differs from the 1.4 / 14.0 factors used elsewhere in this notebook — confirm.
loop_zonal_stats_update(input_zone_polygon=points,input_value_raster=ref_raster,fieldname='pred_palm',fn=np.nansum, is_update=True, refband=1,bias=1.5)

In [None]:
# Path of the ROI shapefile.
points

In [None]:
# Path of the ROI KML file.
kmlfile_name