In [None]:
%matplotlib inline
from osgeo import gdal, ogr, osr
import os, sys
import glob
import simplekml
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns
import pandas as pd

sys.path.insert(0, os.path.dirname(os.getcwd()))
sys.path.insert(0, os.getcwd())


def no_output(func):
    """Decorator that discards everything ``func`` writes to ``sys.stdout``.

    While ``func`` runs, ``sys.stdout`` is pointed at ``os.devnull``. The
    previous stream is restored afterwards, even when ``func`` raises, and the
    devnull handle is closed.

    Args:
        func: Callable to silence.

    Returns:
        A wrapper that forwards all arguments to ``func`` and returns its result.
    """
    def wrapper(*args, **kwargs):
        saved_stdout = sys.stdout
        with open(os.devnull, "w") as devnull:
            sys.stdout = devnull
            try:
                # Propagate the result so wrapped readers remain usable.
                return func(*args, **kwargs)
            finally:
                sys.stdout = saved_stdout
    return wrapper

from utils.plots import plot_heatmap
import utils.gdal_processing as gp
from utils.gdal_processing import get_positive_area_folder, to_xy_box, rasterize_points_pos_neg_folder

from utils.read_geoTiff import readHR
# readHR = no_output(readHR)
# get_positive_area_folder = no_output(get_positive_area_folder)
# to_xy_box = no_output(to_xy_box)
# rasterize_points_pos_neg_folder = no_output(rasterize_points_pos_neg_folder)

            
def p2ha(x):
    """Return the area of a square block that is ``x`` pixels on a side.

    Evaluates ``(x * 10) ** 2 / 100 ** 2``; presumably pixels are 10 m wide and
    the result is in hectares -- TODO confirm the pixel size.
    """
    side = x * 10
    return side ** 2 / 100 ** 2

In [None]:
def zonal_stats_old(FID, input_zone_polygon, input_value_raster, fn, is_return_numpoints = False, refband=1):
    """Apply ``fn`` to the raster pixels under one vector feature (legacy variant).

    The feature's bounding box is converted to a pixel window of the raster,
    the vector layer is rasterized into an in-memory mask of the same size,
    pixels outside the mask are set to NaN and ``fn`` is called on the result.

    Args:
        FID: Feature id passed to ``lyr.GetFeature``.
        input_zone_polygon: Path of the vector file, opened with ``ogr.Open``.
        input_value_raster: Path of the raster file, opened with ``gdal.Open``.
        fn: Callable receiving the masked float window (NaN outside the mask).
        is_return_numpoints: If True, skip the raster work and return
            ``geom.GetPointCount()`` instead.
        refband: Index of the raster band handed to ``GetRasterBand``.

    Returns:
        ``fn(window)``, or ``np.nan`` when the window starts left of / above the
        raster origin, the window cannot be read, the mask is empty, or ``fn``
        raises ``ValueError``.

    Raises:
        SystemExit: for geometries other than MULTIPOLYGON, POLYGON or LINESTRING.
    """
    # Open data
    raster = gdal.Open(input_value_raster)
    shp = ogr.Open(input_zone_polygon)
    lyr = shp.GetLayer()

    # Get raster georeference info
    transform = raster.GetGeoTransform()
    xOrigin = transform[0]
    yOrigin = transform[3]
    pixelWidth = transform[1]
    pixelHeight = transform[5]

    # Reproject vector geometry to same projection as raster
    sourceSR = lyr.GetSpatialRef()
    targetSR = osr.SpatialReference()
    targetSR.ImportFromWkt(raster.GetProjectionRef())
    coordTrans = osr.CoordinateTransformation(sourceSR,targetSR)
    feat = lyr.GetFeature(FID)
    geom = feat.GetGeometryRef()
    geom.Transform(coordTrans)

    # Collect the vertex sequences that define the extent of the feature:
    # the first ring of every polygon, or the line itself.
    geom_name = geom.GetGeometryName()
    if geom_name == 'MULTIPOLYGON':
        outlines = [geom.GetGeometryRef(idx).GetGeometryRef(0)
                    for idx in range(geom.GetGeometryCount())]
    elif geom_name == 'POLYGON':
        outlines = [geom.GetGeometryRef(0)]
    elif geom_name == 'LINESTRING':
        outlines = [geom]
    else:
        sys.exit("ERROR: Geometry needs to be either Polygon or Multipolygon")

    pointsX = []; pointsY = []
    for outline in outlines:
        for p in range(outline.GetPointCount()):
            lon, lat, z = outline.GetPoint(p)
            pointsX.append(lon)
            pointsY.append(lat)

    xmin = min(pointsX)
    xmax = max(pointsX)
    ymin = min(pointsY)
    ymax = max(pointsY)

    # Specify offset and rows and columns to read.
    # NOTE(review): the y direction also divides by pixelWidth, which assumes
    # square pixels -- confirm for the rasters in use.
    xoff = int((xmin - xOrigin)/pixelWidth)
    yoff = int((yOrigin - ymax)/pixelWidth)
    if xoff < 0 or yoff < 0:
        return np.nan
    xcount = int((xmax - xmin)/pixelWidth)+1
    ycount = int((ymax - ymin)/pixelWidth)+1

    if is_return_numpoints:
        # TODO check that all the points are inside the region of interest
        return geom.GetPointCount()

    # Create memory target raster
    target_ds = gdal.GetDriverByName('MEM').Create('', xcount, ycount, 1, gdal.GDT_Byte)
    target_ds.SetGeoTransform((
        xmin, pixelWidth, 0,
        ymax, 0, pixelHeight,
    ))

    # Create for target raster the same projection as for the value raster
    raster_srs = osr.SpatialReference()
    raster_srs.ImportFromWkt(raster.GetProjectionRef())
    target_ds.SetProjection(raster_srs.ExportToWkt())

    # Rasterize zone polygon to raster.
    # NOTE(review): the whole layer is burned, so other features of the layer
    # that fall inside this window are part of the mask too.
    gdal.RasterizeLayer(target_ds, [1], lyr, burn_values=[1])

    # Read raster as arrays. ReadAsArray yields None when the window cannot be
    # read; test for that explicitly instead of catching AttributeError, which
    # also hid the removal of the ``np.float`` alias in NumPy >= 1.24.
    banddataraster = raster.GetRasterBand(refband)
    window = banddataraster.ReadAsArray(xoff, yoff, xcount, ycount)
    if window is None:
        return np.nan
    dataraster = window.astype(float)

    bandmask = target_ds.GetRasterBand(1)
    datamask = bandmask.ReadAsArray(0, 0, xcount, ycount).astype(float)
    print(datamask.mean())

    # Clamp the values to the range [0.01, 1e9].
    dataraster = np.clip(dataraster,0.01,1e9)
    if not np.any(datamask):
        print('datamask empty')
        return np.nan

    # Mask zone of raster
    dataraster[np.logical_not(datamask)] = np.nan

    # Calculate statistics of zonal raster
    try:
        return fn(dataraster)
    except ValueError:
        print('fix')
        return np.nan
    
def loop_zonal_stats_update_old(input_zone_polygon, input_value_raster, fieldname, fn, is_update=True, refband=1, is_pos_only=False):
    """Run ``zonal_stats`` for every feature of a vector file and sum the results.

    Args:
        input_zone_polygon: Path of the vector file, opened in update mode.
        input_value_raster: Path of the raster forwarded to ``zonal_stats``.
        fieldname: Attribute field that receives the per-feature value when
            ``is_update`` is True; it is created as a real field if missing.
        fn: Statistic callable forwarded to ``zonal_stats``.
        is_update: Write each value back to the feature when True.
        refband: Raster band index forwarded to ``zonal_stats``.
        is_pos_only: When True, features whose ``Name`` lacks ``'pos'`` are
            evaluated with ``is_return_numpoints=True`` instead.

    Returns:
        ``np.sum`` over the per-feature values.
    """
    datasource = ogr.Open(input_zone_polygon, update=1)
    layer = datasource.GetLayer()
    layer_defn = layer.GetLayerDefn()

    if is_update:
        field_idx = layer_defn.GetFieldIndex(fieldname)
        if field_idx != -1:
            print('Field {} already exists, may overwrite'.format(fieldname))
        else:
            layer.CreateField(ogr.FieldDefn(fieldname, ogr.OFTReal))
            field_idx = layer_defn.GetFieldIndex(fieldname)

    per_feature = []
    name_idx = layer_defn.GetFieldIndex('Name')
    for FID in range(layer.GetFeatureCount()):
        feat = layer.GetFeature(FID)
        if feat is None:
            continue

        name_ = feat.GetField(name_idx)
        is_positive = 'pos' in name_
        if is_positive or not is_pos_only:
            meanValue = zonal_stats(FID, input_zone_polygon, input_value_raster, fn, refband=refband)
            print(f' {meanValue:.2f} Trees in {name_}')
        else:
            meanValue = zonal_stats(FID, input_zone_polygon, input_value_raster, fn, is_return_numpoints=True, refband=refband)
            print(f' {meanValue:.2f} Ref points in {name_}')

        per_feature.append(meanValue)
        if np.isnan(meanValue):
            print(meanValue,FID)

        if is_update:
            # The feature is written before and after the field is set,
            # exactly as the layer was updated previously.
            layer.SetFeature(feat)
            feat.SetField(field_idx,meanValue)
            layer.SetFeature(feat)

    return np.sum(per_feature)


In [None]:
def loop_zonal_stats_update(input_zone_polygon, input_value_raster, fieldname, fn, is_update=True, refband=1, is_pos_only=False,bias=1, field_name = 'Name'):
    """Compute ``zonal_stats`` per feature, store it on the layer and sum it.

    Args:
        input_zone_polygon: Path of the vector file, opened in update mode.
        input_value_raster: Path of the raster forwarded to ``zonal_stats``.
        fieldname: Attribute field receiving the per-feature value; created as
            a real field when it does not exist yet.
        fn: Statistic callable forwarded to ``zonal_stats``.
        is_update: Accepted for signature compatibility; not read here.
        refband: Raster band index forwarded to ``zonal_stats``.
        is_pos_only: Accepted for signature compatibility; not read here.
        bias: Multiplier forwarded to ``zonal_stats``.
        field_name: Name of the attribute looked up on every feature.

    Returns:
        ``np.sum`` over the per-feature values.
    """
    datasource = ogr.Open(input_zone_polygon, update=1)
    layer = datasource.GetLayer()
    layer_defn = layer.GetLayerDefn()

    value_idx = layer_defn.GetFieldIndex(fieldname)
    if value_idx != -1:
        print('Field {} already exists, may overwrite'.format(fieldname))
    else:
        layer.CreateField(ogr.FieldDefn(fieldname, ogr.OFTReal))
        value_idx = layer_defn.GetFieldIndex(fieldname)

    per_feature = []
    name_idx = layer_defn.GetFieldIndex(field_name)
    for FID in range(layer.GetFeatureCount()):
        feat = layer.GetFeature(FID)
        if feat is None:
            continue

        # The name is looked up but currently only useful for debugging.
        name_ = feat.GetField(name_idx)
        value = zonal_stats(FID, input_zone_polygon, input_value_raster, fn, refband=refband,bias=bias)
        per_feature.append(value)

        # The feature is written before and after the field is set.
        layer.SetFeature(feat)
        feat.SetField(value_idx,value)
        layer.SetFeature(feat)

    return np.sum(per_feature)

def zonal_stats(FID, input_zone_polygon, input_value_raster, fn, is_return_numpoints = False, refband=1, bias = 1.0):
    """Apply ``fn`` to the raster pixels under one polygon feature.

    The feature's bounding box is converted to a pixel window of the raster,
    the vector layer is rasterized into an in-memory mask of the same size,
    values below 0.01 or equal to 99 and pixels outside the mask are set to
    NaN, the window is scaled by ``bias`` and ``fn`` is called on the result.

    Args:
        FID: Feature id passed to ``lyr.GetFeature``.
        input_zone_polygon: Path of the vector file, opened with ``ogr.Open``.
        input_value_raster: Path of the raster file, opened with ``gdal.Open``.
        fn: Callable receiving the masked float window.
        is_return_numpoints: Accepted for signature compatibility; not read here.
        refband: Index of the raster band handed to ``GetRasterBand``.
        bias: Factor the masked window is multiplied with before ``fn`` runs.

    Returns:
        ``fn(window)``, or ``np.nan`` when the window cannot be read, the mask
        is empty, or ``fn`` raises ``ValueError``.

    Raises:
        SystemExit: for geometries other than MULTIPOLYGON or POLYGON.
    """
    # Open data
    raster = gdal.Open(input_value_raster)
    shp = ogr.Open(input_zone_polygon)
    lyr = shp.GetLayer()

    # Get raster georeference info
    transform = raster.GetGeoTransform()
    xOrigin = transform[0]
    yOrigin = transform[3]
    pixelWidth = transform[1]
    pixelHeight = transform[5]

    # Reproject vector geometry to same projection as raster
    sourceSR = lyr.GetSpatialRef()
    targetSR = osr.SpatialReference()
    targetSR.ImportFromWkt(raster.GetProjectionRef())
    coordTrans = osr.CoordinateTransformation(sourceSR,targetSR)
    feat = lyr.GetFeature(FID)
    geom = feat.GetGeometryRef()
    geom.Transform(coordTrans)

    # The extent is taken from the first ring of every polygon.
    geom_name = geom.GetGeometryName()
    if geom_name == 'MULTIPOLYGON':
        rings = [geom.GetGeometryRef(idx).GetGeometryRef(0)
                 for idx in range(geom.GetGeometryCount())]
    elif geom_name == 'POLYGON':
        rings = [geom.GetGeometryRef(0)]
    else:
        sys.exit("ERROR: Geometry needs to be a Polygon")

    pointsX = []; pointsY = []
    for ring in rings:
        for p in range(ring.GetPointCount()):
            lon, lat, z = ring.GetPoint(p)
            pointsX.append(lon)
            pointsY.append(lat)

    xmin = min(pointsX)
    xmax = max(pointsX)
    ymin = min(pointsY)
    ymax = max(pointsY)

    # Specify offset and rows and columns to read.
    # NOTE(review): the y direction also divides by pixelWidth, which assumes
    # square pixels -- confirm for the rasters in use.
    xoff = int((xmin - xOrigin)/pixelWidth)
    yoff = int((yOrigin - ymax)/pixelWidth)

    xcount = int((xmax - xmin)/pixelWidth)+1
    ycount = int((ymax - ymin)/pixelWidth)+1

    # Keep the window inside the raster.
    # NOTE(review): only xoff is clamped; yoff is used as computed.
    xoff = min(xoff,raster.RasterXSize -1)
    xoff = max(xoff,1)

    xcount = min(xcount,raster.RasterXSize -1 - xoff)
    ycount = min(ycount,raster.RasterYSize -1 - yoff)

    # Create memory target raster
    target_ds = gdal.GetDriverByName('MEM').Create('', xcount, ycount, 1, gdal.GDT_Byte)
    target_ds.SetGeoTransform((
        xmin, pixelWidth, 0,
        ymax, 0, pixelHeight,
    ))

    # Create for target raster the same projection as for the value raster
    raster_srs = osr.SpatialReference()
    raster_srs.ImportFromWkt(raster.GetProjectionRef())
    target_ds.SetProjection(raster_srs.ExportToWkt())

    # Rasterize zone polygon to raster.
    # NOTE(review): the whole layer is burned, so other features of the layer
    # that fall inside this window are part of the mask too.
    gdal.RasterizeLayer(target_ds, [1], lyr, burn_values=[1])

    # Read raster as arrays. ReadAsArray yields None when the window cannot be
    # read; test for that explicitly instead of catching AttributeError, which
    # also hid the removal of the ``np.float`` alias in NumPy >= 1.24.
    banddataraster = raster.GetRasterBand(refband)
    window = banddataraster.ReadAsArray(xoff, yoff, xcount, ycount)
    if window is None:
        print('dataraster wrong')
        print(xoff,yoff,xcount,ycount)
        print(raster.RasterXSize,raster.RasterYSize, 'xmax,ymax:',xoff+xcount,yoff+ycount)
        return np.nan
    dataraster = window.astype(float)

    bandmask = target_ds.GetRasterBand(1)
    datamask = bandmask.ReadAsArray(0, 0, xcount, ycount).astype(float)

    # Values below 0.01 and the value 99 are excluded from the statistic.
    dataraster[dataraster < 0.01] = np.nan
    dataraster[dataraster == 99] = np.nan

    if not np.any(datamask):
        print('datamask empty')
        return np.nan

    # Mask zone of raster
    dataraster[np.logical_not(datamask)] = np.nan
    dataraster *=bias

    # Calculate statistics of zonal raster
    try:
        return fn(dataraster)
    except ValueError:
        print('fix')
        return np.nan
    

In [None]:
# obj='palm'

# Maps an object class name to an integer index; the commented-out
# `ref_band = object_dict[obj]` below suggests it selects a band -- TODO confirm.
object_dict= {'palm':0,'coco':1}

# ref_band = object_dict[obj]

# points ='/home/pf/pfstud/andresro/tree_annotationsAug2019/annotations/Jan/palm/49MCV/Palm_Jan_1.kml'


In [None]:
def p2ha(x):
    """Return ``(x / 10) ** 2`` for a block that is ``x`` pixels on a side.

    Presumably pixels are 10 m wide so that the result is the block area in
    hectares -- TODO confirm the pixel size.
    """
    tenth = x / 10
    return tenth ** 2
def plot_preds(scale,raster,preds, tile, group,density = (0,3), std = None):
    """Plot ground truth, prediction and block-wise error maps and tabulate them.

    Pixels that are NaN in either ``raster`` or ``preds`` are excluded from
    both maps before they are summed over ``scale`` x ``scale`` blocks.

    Args:
        scale: Side length, in pixels, of the aggregation blocks.
        raster: Ground-truth density array (NaN marks pixels without labels).
        preds: Predicted density array with the same shape as ``raster``.
        tile: Tile name, only used in the figure caption.
        group: Annotation group name, only used in the figure caption.
        density: ``(vmin, vmax)`` colour limits of the density panels.
        std: Optional per-pixel uncertainty array; adds a fourth panel and a
            ``std`` column holding the block-wise ``nanmean``.

    Returns:
        ``pandas.DataFrame`` with one row per block and the columns ``GT`` and
        ``Pred`` (block sums), plus ``std`` when ``std`` is given.
    """
    dens_min,dens_max = density

    # Only compare pixels for which both maps carry a value.
    mask_out = np.isnan(raster) | np.isnan(preds)
    preds1 = preds.copy()
    preds1[mask_out] = np.nan
    raster1 = raster.copy()
    raster1[mask_out] = np.nan

    print(p2ha(scale))
    r1 = gp.block_reduce(raster1,(scale,scale),np.nansum)
    p1 = gp.block_reduce(preds1,(scale,scale),np.nansum)
    s1 = None
    if std is not None:
        std1 = std.copy()
        std1[mask_out] = np.nan
        s1 = gp.block_reduce(std1,(scale,scale),np.nanmean)
    diff = p1 - r1

    # imshow draws axis 0 (rows) vertically and axis 1 (columns) horizontally,
    # so the horizontal extent comes from shape[1] and the vertical from shape[0].
    # Dividing by 100 presumably converts 10 m pixels to km -- TODO confirm.
    width_km = f' {preds1.shape[1]/100:.2f}km'
    height_km = f' {preds1.shape[0]/100:.2f}km'
    txt = f' {tile} {group} {p2ha(scale)}ha'
    pixel_label = f'Trees/ {p2ha(1):.2f}ha blocks'

    def _strip_ticks(axes):
        """Remove tick marks from an image panel."""
        axes.set_xticks([])
        axes.set_yticks([])

    fig = plt.figure(figsize=(10,5))
    n_col = 3 if std is None else 4
    gs = gridspec.GridSpec(nrows=1,ncols=n_col,left=0.02, bottom=0.06, right=0.95, top=0.94, wspace=0.1)

    # Ground truth
    ax = plt.subplot(gs[0])
    im = ax.imshow(raster,vmin=dens_min,vmax=dens_max)
    cbar = fig.colorbar(im,ax=ax,orientation='horizontal')
    cbar.ax.set_xlabel(pixel_label)
    ax.set_title(f'Trees GT {np.nansum(raster1):.2f} ({np.nanmax(raster):.2f})')
    _strip_ticks(ax)
    ax.set_xlabel(width_km)
    ax.set_ylabel(height_km)

    # Predictions
    ax = plt.subplot(gs[1])
    im = ax.imshow(preds,vmin=dens_min,vmax=dens_max)
    cbar = fig.colorbar(im,ax=ax,orientation='horizontal')
    cbar.ax.set_xlabel(pixel_label)
    ax.set_title(f'Trees Pred {np.nansum(preds1):.2f} ({np.nanmax(preds):.2f})')
    _strip_ticks(ax)
    ax.set_xlabel(width_km + ' \n\n' + txt)
    ax.set_ylabel(height_km)

    # Block-wise difference (prediction minus ground truth)
    ax = plt.subplot(gs[2])
    lim_ = dens_max * ((scale/2)**2)
    im = ax.imshow(diff,cmap = 'bwr', vmin = -lim_,vmax=lim_ )
    cbar = fig.colorbar(im,ax=ax,orientation='horizontal')
    cbar.ax.set_xlabel(f'Trees/ {p2ha(scale)}ha blocks')
    ax.set_title(f'Error per {p2ha(scale)}ha blocks')
    _strip_ticks(ax)

    # Uncertainty
    if std is not None:
        ax = plt.subplot(gs[3])
        im = ax.imshow(std ,vmin=0,vmax=1, cmap='inferno')
        cbar = fig.colorbar(im,ax=ax,orientation='horizontal')
        cbar.ax.set_xlabel(pixel_label)
        ax.set_title(f'STD  mean (max) {np.nanmean(std):.2f} ({np.nanmax(std):.2f})')
        _strip_ticks(ax)
        ax.set_xlabel(width_km + ' \n\n' + txt)
        ax.set_ylabel(height_km)

    # One row per block for the scatter / calibration analysis
    ds = pd.DataFrame({'GT':r1.ravel(),'Pred':p1.ravel()})
    if std is not None:
        ds['std'] = s1.ravel()
    return ds

#@no_output
def plot_counts(folder_inference, tile, folder_annotations, group='group1', preds_axis=0, sq_kernel=2, scale=10):
    """Compare predicted and annotated densities for every annotation folder of a tile.

    Args:
        folder_inference: Directory searched for ``{tile}*_preds_reg*.tif`` and,
            optionally, ``std/{tile}*preds_reg_0_std_0.tif``.
        tile: Tile name used to select rasters and annotation folders.
        folder_annotations: Either a list of annotation folders (those whose
            path contains ``tile`` are used) or a root directory, in which case
            ``{folder_annotations}/{tile}/{group}`` is globbed.
        group: Glob component for the non-list case; inside the loop it is
            replaced by the last path component of the folder being processed.
        preds_axis: Channel taken from the prediction when its last axis has
            length 2.
        sq_kernel: Forwarded to ``rasterize_points_pos_neg_folder``.
        scale: Block size in pixels forwarded to ``plot_preds``.

    Returns:
        The row-wise concatenation of the ``plot_preds`` tables of all
        processed folders; an empty ``GT``/``Pred`` frame when nothing could be
        processed; ``None`` when ``folder_annotations`` is a list without an
        entry for ``tile`` (callers test for ``None``).
    """
    ds_out = pd.DataFrame(columns=['GT','Pred'])

    if isinstance(folder_annotations,list):
        ref_folders = [x for x in folder_annotations if tile in x]
        if not ref_folders:
            # Historical contract: no matching annotation folder yields None.
            return None
    else:
        ref_folders = glob.glob(f'{folder_annotations}/{tile}/{group}')
        if not ref_folders:
            print(f'no folders in {folder_annotations}/{tile}/{group}')
            return ds_out

    frames = []
    for ref_folder in ref_folders:
        group = ref_folder.split('/')[-1]
        print(ref_folder, group)
        ref_rasters = glob.glob(f'{folder_inference}/{tile}*_preds_reg*.tif')
        if len(ref_rasters) == 0:
            print(f' no files found in {folder_inference}/{tile}*_preds_reg*.tif skipping...')
            continue
        ref_raster = ref_rasters[0]

        raster_ds = gdal.Open(ref_raster)
        roi_ = get_positive_area_folder(ref_folder)
        lims = to_xy_box(roi_, raster_ds, enlarge=10)

        raster = rasterize_points_pos_neg_folder(folder=ref_folder,refDataset=ref_raster,lims=lims,lims_with_labels=lims,sq_kernel=sq_kernel)
        # -1 is replaced by NaN so that plot_preds leaves those pixels out.
        raster[raster == -1] = np.nan

        preds = readHR(data_file=ref_raster,roi_lon_lat=roi_,scale=10, as_float=False, is_assert_blank=False)
        if preds.shape[-1] == 2:
            preds = preds[...,preds_axis]
        preds[preds <0] = 0
        preds[preds==99] = np.nan

        # NOTE(review): preds_sem is loaded but not used further down.
        ref_raster_sem = ref_raster.replace('reg.tif','semA.tif')
        if os.path.isfile(ref_raster_sem):
            preds_sem = readHR(data_file=ref_raster_sem,roi_lon_lat=roi_,scale=10, as_float=False, is_assert_blank=False)
        else:
            preds_sem = preds > 0.5

        ref_raster_std = glob.glob(f'{folder_inference}/std/{tile}*preds_reg_0_std_0.tif')
        if len(ref_raster_std) > 0:
            ref_raster_std = ref_raster_std[0]
            print('adding std')
            preds_std = readHR(data_file=ref_raster_std,roi_lon_lat=roi_,scale=10, as_float=False, is_assert_blank=False)
        else:
            preds_std = None
            print(f'{folder_inference}/std/{tile}*preds_reg_0_std_0.tif')

        frames.append(plot_preds(scale,raster,preds, tile, group, std=preds_std))

    # Return only after every folder was visited; DataFrame.append no longer
    # exists in pandas >= 2.0, so the tables are concatenated once.
    if not frames:
        return ds_out
    return pd.concat(frames)
        


## Palm 4748 - Dropout

In [None]:
# MC-dropout run: evaluate every tile against its manual annotations and
# collect the per-block table in `ds_drop`.
#folder_inference = '/home/pf/pfstaff/projects/andresro/sparse/inference_leon/borneo_simpleA9_mc10'
#folder_inference = '/scratch/andresro/leon_work/sparse/inference/cocopreactive_simpleA9_soft_mc5'
folder_inference = '/scratch/andresro/leon_work/sparse/inference/palm4748a_simpleA9_soft_mc5'
folder_annotations = [
    '/home/pf/pfstaff/projects/andresro/barry_palm/data/labels/manual_annotations/T48MTD/palm_group2_Bischel',
    '/home/pf/pfstaff/projects/andresro/barry_palm/data/labels/manual_annotations/T47NNC/palm_group3_Brunner',
    '/home/pf/pfstaff/projects/andresro/barry_palm/data/labels/manual_annotations/T48MVV/palm_group1_julia',
    '/home/pf/pfstaff/projects/andresro/barry_palm/data/labels/manual_annotations/T47NLE/palm_group2_hanlon',
    '/home/pf/pfstaff/projects/andresro/barry_palm/data/labels/manual_annotations/T47NNA/palm_group3_Julia',
    '/home/pf/pfstaff/projects/andresro/barry_palm/data/labels/manual_annotations/T48NTJ/palm_group1',
    '/home/pf/pfstaff/projects/andresro/barry_palm/data/labels/manual_annotations/T47NQE/palm_group3',
    '/home/pf/pfstaff/projects/andresro/barry_palm/data/labels/manual_annotations/T48MVB/palm_group2_Bischel',
    '/home/pf/pfstaff/projects/andresro/barry_palm/data/labels/manual_annotations/T48NUG/palm_group2',
    '/home/pf/pfstaff/projects/andresro/barry_palm/data/labels/manual_annotations/T46NGL/palm_group3_janmathias'
]

scale=20

tilenames = 'T48MVV T48NUG T47NNC T47NNA T47NQE T46NGL T47NLE T48MTD T48NTJ T48MVB'.split(' ')
#tilenames = ['T48MTD']

# Collect the per-tile tables and concatenate once: DataFrame.append was
# removed in pandas 2.0.
tile_frames = []
for tile in tilenames:
    ds_ = plot_counts(folder_inference=folder_inference,
                tile=tile,
                folder_annotations=folder_annotations,
                group=None, scale=scale)
    if ds_ is None:
        print('error in tile',tile)
    elif not ds_.empty:
        tile_frames.append(ds_)

if tile_frames:
    ds_out = pd.concat(tile_frames, ignore_index=True)
else:
    ds_out = pd.DataFrame(columns=['GT','Pred'])

# Absolute block-wise error between ground truth and prediction
ds_out['resid'] = np.abs(ds_out['GT'] - ds_out['Pred'])
ds_drop = ds_out

In [None]:
# Show the per-block table built by the previous cell as the cell output.
ds_out

In [None]:
# Leave out blocks in which both ground truth and prediction are zero.
zeros_ = (ds_drop.GT == 0) & (ds_drop.Pred == 0)
ds = ds_drop.loc[~zeros_].copy()

g = sns.jointplot(x='Pred', y='GT', data=ds, cmap="Reds")

# Same limits on both axes, with the identity line as reference.
lims = (0, np.nanmax(ds.GT))
g.ax_marg_x.set_xlim(lims)
g.ax_marg_y.set_ylim(lims)
g.ax_joint.plot(lims, lims, ':k')

plt.title(
    f' MAE {np.nanmean(np.abs(ds.Pred -ds.GT))/p2ha(scale):.2f} Trees/ha in {p2ha(scale)}ha Blocks  \n '
    f'total trees GT:{np.nansum(ds.GT):.2f} Pred:{np.nansum(ds.Pred):.2f} ({100*(np.nansum(ds.Pred)-np.nansum(ds.GT))/np.nansum(ds.GT):.2f}%) \n',
    x=-0.1,
    y=0.5,
    fontsize=12,
)


In [None]:
#bins = np.linspace(0,ds['resid'].max(), num=50)
#bins = np.percentile(ds['resid'],np.arange(0,100))
# Residual thresholds: 100 percentiles spread evenly between the 5th and 100th.
bins = np.percentile(ds['resid'], q=np.linspace(5, 100, 100))

In [None]:
# Mean std over all blocks whose residual does not exceed each threshold.
avg_std = np.array([ds.loc[ds['resid'] <= b_, 'std'].mean() for b_ in bins])

In [None]:
# Mean std against the residual threshold, normalised by the largest threshold.
plt.plot(bins / np.max(bins), avg_std, marker='o', linestyle='-')

plt.ylabel('std')
plt.xlabel('recall based on resid')

In [None]:
#bins = np.linspace(0,ds['resid'].max(), num=50)
#bins = np.percentile(ds['resid'],np.arange(0,100))
# Std thresholds: 50 percentiles spread evenly between the 5th and 100th.
bins = np.percentile(ds['std'], q=np.linspace(5, 100, 50))

In [None]:
# Mean residual over all blocks whose std does not exceed each threshold.
avg_resid = np.array([ds.loc[ds['std'] <= b_, 'resid'].mean() for b_ in bins])

In [None]:
# Keep the MC-dropout curve for the comparison section.
# NOTE(review): this is the std threshold divided by its maximum, not the
# fraction of retained blocks -- confirm which one the x axis should show.
recall_drop = bins / np.max(bins)
avg_resid_drop = avg_resid

In [None]:
# Mean residual against the std threshold, normalised by the largest threshold.
plt.plot(bins/bins.max(), avg_resid,'o-')

# Raw string: "\h" is not a valid escape sequence in a regular string literal.
plt.ylabel(r'$|y - \hat{y}|$')
plt.xlabel('recall based on std')

In [None]:
# Distribution of the absolute block-wise error.
ds['resid'].plot(kind='hist', bins=50)

In [None]:
# Distribution of the block-wise mean std.
ds['std'].plot(kind='hist', bins=50)


In [None]:
# Joint distribution of error and uncertainty. `x` is passed by keyword:
# recent seaborn releases accept the plotting variables as keywords only, so a
# positional 'resid' would collide with `data=`.
g = sns.jointplot(x='resid', y='std', data=ds, cmap="Reds",
                  kind="hex",
                  joint_kws={
                      'gridsize':30,
                  },
                 )

## Palm 4748 - Ensemble

In [None]:
# Ensemble run: evaluate every tile against its manual annotations and
# collect the per-block table in `ds_ens`.
#folder_inference = '/home/pf/pfstaff/projects/andresro/sparse/inference_leon/borneo_simpleA9_mc10'
#folder_inference = '/scratch/andresro/leon_work/sparse/inference/cocopreactive_simpleA9_soft_mc5'
folder_inference = '/scratch/andresro/leon_work/sparse/inference/palm4748a_simpleA9_soft_ens5/'
folder_annotations = [
    '/home/pf/pfstaff/projects/andresro/barry_palm/data/labels/manual_annotations/T48MTD/palm_group2_Bischel',
    '/home/pf/pfstaff/projects/andresro/barry_palm/data/labels/manual_annotations/T47NNC/palm_group3_Brunner',
    '/home/pf/pfstaff/projects/andresro/barry_palm/data/labels/manual_annotations/T48MVV/palm_group1_julia',
    '/home/pf/pfstaff/projects/andresro/barry_palm/data/labels/manual_annotations/T47NLE/palm_group2_hanlon',
    '/home/pf/pfstaff/projects/andresro/barry_palm/data/labels/manual_annotations/T47NNA/palm_group3_Julia',
    '/home/pf/pfstaff/projects/andresro/barry_palm/data/labels/manual_annotations/T48NTJ/palm_group1',
    '/home/pf/pfstaff/projects/andresro/barry_palm/data/labels/manual_annotations/T47NQE/palm_group3',
    '/home/pf/pfstaff/projects/andresro/barry_palm/data/labels/manual_annotations/T48MVB/palm_group2_Bischel',
    '/home/pf/pfstaff/projects/andresro/barry_palm/data/labels/manual_annotations/T48NUG/palm_group2',
    '/home/pf/pfstaff/projects/andresro/barry_palm/data/labels/manual_annotations/T46NGL/palm_group3_janmathias'
]

scale=20

tilenames = 'T48MVV T48NUG T47NNC T47NNA T47NQE T46NGL T47NLE T48MTD T48NTJ T48MVB'.split(' ')
#tilenames = ['T48MTD']

# Collect the per-tile tables and concatenate once: DataFrame.append was
# removed in pandas 2.0.
tile_frames = []
for tile in tilenames:
    ds_ = plot_counts(folder_inference=folder_inference,
                tile=tile,
                folder_annotations=folder_annotations,
                group=None, scale=scale)
    if ds_ is None:
        print('error in tile',tile)
    elif not ds_.empty:
        tile_frames.append(ds_)

if tile_frames:
    ds_out = pd.concat(tile_frames, ignore_index=True)
else:
    ds_out = pd.DataFrame(columns=['GT','Pred'])

# Absolute block-wise error between ground truth and prediction
ds_out['resid'] = np.abs(ds_out['GT'] - ds_out['Pred'])
ds_ens = ds_out

In [None]:
# Leave out blocks in which both ground truth and prediction are zero. This is
# the same filter as in the MC-dropout section, so both methods are evaluated
# on the same kind of subset.
zeros_ = np.logical_and(ds_ens.GT== 0,ds_ens.Pred==0)
ds = ds_ens[~zeros_].copy()

g = sns.jointplot(x='Pred',y='GT',data=ds, cmap="Reds",
                 )

# Same limits on both axes, with the identity line as reference.
lims = (0,np.nanmax(ds.GT))
g.ax_marg_x.set_xlim(lims)
g.ax_marg_y.set_ylim(lims)

g.ax_joint.plot(lims, lims, ':k')
plt.title(f' MAE {np.nanmean(np.abs(ds.Pred -ds.GT))/p2ha(scale):.2f} Trees/ha in {p2ha(scale)}ha Blocks  \n ' \
          f'total trees GT:{np.nansum(ds.GT):.2f} Pred:{np.nansum(ds.Pred):.2f} ({100*(np.nansum(ds.Pred)-np.nansum(ds.GT))/np.nansum(ds.GT):.2f}%) \n',x=-0.1,y=0.5, fontsize = 12)


In [None]:

#bins = np.linspace(0,ds['resid'].max(), num=50)
#bins = np.percentile(ds['resid'],np.arange(0,100))
# Residual thresholds: 100 percentiles spread evenly between the 5th and 100th.
bins = np.percentile(ds['resid'], q=np.linspace(5, 100, 100))

In [None]:
# Mean std over all blocks whose residual does not exceed each threshold.
avg_std = np.array([ds.loc[ds['resid'] <= b_, 'std'].mean() for b_ in bins])

In [None]:
# Mean std against the residual threshold, normalised by the largest threshold.
plt.plot(bins / np.max(bins), avg_std, marker='o', linestyle='-')

plt.ylabel('std')
plt.xlabel('recall based on resid')

In [None]:
#bins = np.linspace(0,ds['resid'].max(), num=50)
#bins = np.percentile(ds['resid'],np.arange(0,100))
# Std thresholds: 50 percentiles spread evenly between the 5th and 100th.
bins = np.percentile(ds['std'], q=np.linspace(5, 100, 50))

In [None]:
# Mean residual over all blocks whose std does not exceed each threshold.
avg_resid = np.array([ds.loc[ds['std'] <= b_, 'resid'].mean() for b_ in bins])

In [None]:
# Keep the ensemble curve for the comparison section.
# NOTE(review): this is the std threshold divided by its maximum, not the
# fraction of retained blocks -- confirm which one the x axis should show.
recall_ens = bins / np.max(bins)
avg_resid_ens = avg_resid

In [None]:
# Mean residual against the std threshold, normalised by the largest threshold.
plt.plot(bins/bins.max(), avg_resid,'o-')

# Raw string: "\h" is not a valid escape sequence in a regular string literal.
plt.ylabel(r'$|y - \hat{y}|$')
plt.xlabel('recall based on std')

In [None]:
# Distribution of the absolute block-wise error.
ds['resid'].plot(kind='hist', bins=50)

In [None]:
# Distribution of the block-wise mean std.
ds['std'].plot(kind='hist', bins=50)


In [None]:

# Joint distribution of error and uncertainty. `x` is passed by keyword:
# recent seaborn releases accept the plotting variables as keywords only, so a
# positional 'resid' would collide with `data=`.
g = sns.jointplot(x='resid', y='std', data=ds, cmap="Reds",
                 )

## Comparison

In [None]:
# Output directory for exported figures; only referenced by the commented-out
# `fig.savefig(...)` calls in the cells below.
save_path = '/scratch2/Dropbox/Dropbox/Apps/Overleaf/activelearning_remotesensing/figures/'

In [None]:
# Mean residual of the retained blocks for both uncertainty estimates, using
# the x values stored with each curve.
fig = plt.figure(1)
plt.title('Uncertainty Calibration')
plt.plot(recall_drop, avg_resid_drop,'o-', label='MC-dropout')
plt.plot(recall_ens, avg_resid_ens,'o-', label='Ensemble')

# Raw string: "\h" is not a valid escape sequence in a regular string literal.
plt.ylabel(r'$|y - \hat{y}|$')
plt.xlabel('recall')

plt.legend()

# fig.savefig(save_path+'uncertainty_calibration.pdf',) # bbox_extra_artists=(lgd,), bbox_inches='tight')


In [None]:
# Shared x axis: 50 fractions between 0.05 and 1, matching the 50 percentiles
# between 5 and 100 that define the std thresholds of both curves.
recall = np.linspace(0.05,1,50)

fig = plt.figure(1)
plt.title('Uncertainty Calibration')
plt.plot(recall, avg_resid_drop,'o-', label='MC-dropout')
plt.plot(recall, avg_resid_ens,'o-', label='Ensemble')

# Raw string: "\h" is not a valid escape sequence in a regular string literal.
plt.ylabel(r'$|y - \hat{y}|$')
plt.xlabel('recall')

plt.legend()

#fig.savefig(save_path+'uncertainty_calibration.pdf',) # bbox_extra_artists=(lgd,), bbox_inches='tight')
