In [None]:
from osgeo import gdal, ogr
import glob
import numpy as np

In [None]:
import sys, os
sys.path.insert(0, os.path.dirname(os.getcwd()))
sys.path.insert(0, os.getcwd())


import utils.gdal_processing as gp


### Reproject map to grid without overlap

In [None]:
# Grid shapefile from the descals2020 palm-oil maps folder.
# NOTE(review): grid_file is reassigned by later cells ("Create Grid" and the
# Sentinel-2 matching section) before it is first read, so this value is unused.
grid_file = '/home/pf/pfstaff/projects/andresro/barry_palm/palmoilmaps/descals2020/grid/grid_withOP.shp'

In [None]:
# Reference gdal_retile.py invocation kept as a string (not executed); the input
# rasters would have to be appended, as in the commands below. The option is
# "-levels" (a stray space, "- levels", is not parsed as that option).
'gdal_retile.py -targetDir /scratch/andresro/leon_work/sparse/inference/palm4_act_simpleA9_soft_ens5/tiles '\
                '-of vrt -ps 10 10 -overlap 0 -levels 1 -v -r bilinear'

In [None]:
# Reference command (string only, not executed): retile the T51MTP/T51MTQ/T51MUP/T51MUQ
# rasters into 10000 x 10000 px tiles under tiles/ (no overlap, one pyramid level,
# bilinear resampling) and write the tile footprints to the index grid_10k.shp.
'gdal_retile.py -targetDir tiles -ps 10000 10000 -overlap 0 -levels 1 -v -r bilinear T51MTP*.tif T51MTQ*.tif T51MUP*.tif T51MUQ*.tif -tileIndex grid_10k.shp'

In [None]:
# Reference command (string only, not executed): same tiling for all T*.tif, but with
# -pyramidOnly, which per the gdal_retile docs builds only the pyramid level(s) and
# not the base-resolution tiles.
'gdal_retile.py -targetDir tiles -ps 10000 10000 -overlap 0 -levels 1 -v -tileIndex grid_10k.shp -pyramidOnly -r bilinear T*.tif'

## Create Grid

In [None]:
# Root folder holding the per-tile inference rasters (T*.tif) to be gridded.
# Uncomment the first line (and comment out the second) to process the
# palm4_act_simpleA9_soft_ens5 run instead.
# folder_inference='/scratch/andresro/leon_work/sparse/inference/palm4_act_simpleA9_soft_ens5'
folder_inference='/scratch/andresro/leon_work/sparse/inference/palm2019_simpleA9_soft_ens5'

In [None]:
import math
EARTH_RADIUS = 6371000  # Mean radius of the Earth, in meters


def haversine_distance(lon1, lat1, lon2, lat2):
    """Great-circle distance in meters between (lon1, lat1) and (lon2, lat2).

    Uses the haversine formula on a sphere of radius ``EARTH_RADIUS``.

    Args:
        lon1, lat1: first point, in decimal degrees.
        lon2, lat2: second point, in decimal degrees.

    Returns:
        Distance in meters (float, >= 0).
    """
    a = math.sin(math.radians((lat2 - lat1) / 2.0)) ** 2 + math.cos(math.radians(lat1)) * math.cos(
        math.radians(lat2)) * math.sin(math.radians((lon2 - lon1) / 2.0)) ** 2
    # Floating-point rounding can push `a` a hair outside [0, 1] for (near-)antipodal
    # points, which would make math.sqrt(1 - a) raise "math domain error".
    a = min(1.0, max(0.0, a))
    c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))
    return EARTH_RADIUS * c

def split_roi_to_rois(lon1_, lat1_, lon2_, lat2_, meters_split = 1500):
    """Split a lon/lat bounding box into a regular grid of sub-boxes.

    The box is first normalized, so the order of the two corners does not matter.
    The number of columns/rows is ceil(extent / meters_split), where the extents
    are measured with ``haversine_distance`` along the southern edge (east-west)
    and the western edge (north-south).

    Args:
        lon1_, lat1_, lon2_, lat2_: two opposite corners, in decimal degrees.
        meters_split: target edge length of each sub-box, in meters (must be > 0).

    Returns:
        List of dicts ``{"roi": (lat_min, lon_min, lat_max, lon_max), "name": str}``
        in row-major order starting at the south-west corner. Note the roi tuple
        is lat-first, unlike the lon-first arguments. At least one roi is always
        returned, even for a degenerate (zero-width or zero-height) box.

    Raises:
        ValueError: if meters_split is not positive.
    """
    if meters_split <= 0:
        raise ValueError(f"meters_split must be positive, got {meters_split}")

    lon1, lat1, lon2, lat2 = min(lon1_, lon2_), min(lat1_, lat2_), max(lon1_, lon2_), max(lat1_, lat2_)

    # NOTE(review): the east-west extent is measured along lat1 (the southern edge).
    # For boxes in the southern hemisphere that is the narrowest parallel, so cells
    # can end up slightly wider than meters_split towards the northern edge.
    delta_lon_m = haversine_distance(lon1=lon1, lat1=lat1, lon2=lon2, lat2=lat1)
    delta_lat_m = haversine_distance(lon1=lon1, lat1=lat1, lon2=lon1, lat2=lat2)

    # At least one cell per axis: a zero extent would otherwise give ceil(0) == 0
    # cells and silently produce an empty grid.
    N_lon, N_lat = (max(1, int(math.ceil(d / meters_split))) for d in (delta_lon_m, delta_lat_m))

    delta_lon, delta_lat = (lon2 - lon1, lat2 - lat1)
    rois = []
    for i in range(N_lat):
        for j in range(N_lon):
            rois.append({"roi": (
                                lat1 + (delta_lat) * i / N_lat,
                                lon1 + (delta_lon) * j / N_lon,
                                lat1 + (delta_lat) * (i + 1) / N_lat,
                                lon1 + (delta_lon) * (j + 1) / N_lon),
                        "name": str(i * N_lon + j)})

    return rois

def to_bbox(roi_lon_lat):
    """Return the four corners of a lon/lat box as (lon, lat) tuples.

    Args:
        roi_lon_lat: either a "lon1,lat1,lon2,lat2" string or a 4-item sequence
            (lon1, lat1, lon2, lat2).

    Returns:
        [(lon1, lat1), (lon1, lat2), (lon2, lat2), (lon2, lat1)]: an open ring.
        Append the first point to close it.
    """
    if isinstance(roi_lon_lat, str):
        # A literal comma needs no regex; this also avoids depending on `re`,
        # which this notebook never imports.
        roi_lon1, roi_lat1, roi_lon2, roi_lat2 = map(float, roi_lon_lat.split(','))
    else:
        roi_lon1, roi_lat1, roi_lon2, roi_lat2 = roi_lon_lat

    geo_pts_ref = [(roi_lon1, roi_lat1), (roi_lon1, roi_lat2), (roi_lon2, roi_lat2), (roi_lon2, roi_lat1)]
    return geo_pts_ref



def convert_to_shp(points):
    """Convert a .kml vector file to an ESRI Shapefile written next to it.

    Args:
        points: path to a vector file.

    Returns:
        The path of the written .shp when `points` ends with '.kml'; otherwise
        `points` unchanged (after printing 'not converted').

    Raises:
        RuntimeError: if GDAL cannot open the KML or write the shapefile.
    """
    if not points.endswith('.kml'):
        print('not converted')
        return points

    # Swap only the extension; str.replace would also rewrite any '.kml' that
    # appears earlier in the path (e.g. in a folder name).
    new_points = os.path.splitext(points)[0] + '.shp'
    srcDS = gdal.OpenEx(points)
    if srcDS is None:
        raise RuntimeError(f'could not open {points}')
    ds = gdal.VectorTranslate(new_points, srcDS, format='ESRI Shapefile')
    if ds is None:
        raise RuntimeError(f'could not convert {points} to {new_points}')
    # Dereference both datasets so GDAL flushes the shapefile and closes the KML.
    ds = None
    srcDS = None
    return new_points


In [None]:
# Per-tile inference rasters; sorted so the processing order is reproducible
# (glob returns files in arbitrary filesystem order).
list_tif = sorted(glob.glob(f'{folder_inference}/T*.tif'))
# Sentinel-2 tile ids, e.g. 'T51MTP_5_preds_reg_12_12.tif' -> '51MTP'
# (text before the first '_', without the leading 'T').
list_names = {os.path.basename(x).split('_')[0][1:] for x in list_tif}

In [None]:
# Number of inference rasters found in folder_inference.
print(len(list_tif))

In [None]:
# Collect footprint points of every raster and take their overall lon/lat bounding box.
# NOTE(review): gp.get_lonlat(ds) is assumed to return (lon, lat) points
# (column 0 = lon, column 1 = lat, matching the names below) - confirm in utils.gdal_processing.
assert list_tif, f'no T*.tif rasters found in {folder_inference}'
geo_pts = []
for tif_ in list_tif:
    ds = gdal.Open(tif_)
    if ds is None:
        raise RuntimeError(f'could not open {tif_}')
    geo_pts.extend(gp.get_lonlat(ds))
    ds = None  # close the raster before opening the next one
geo_pts = np.array(geo_pts)
lon1_, lon2_ = geo_pts[:, 0].min(), geo_pts[:, 0].max()
lat1_, lat2_ = geo_pts[:, 1].min(), geo_pts[:, 1].max()

In [None]:
# Longitude range (min, max) covered by all rasters, in degrees.
print(lon1_,lon2_)

In [None]:
# Target grid-cell edge length in meters (100 km); also used to name the output folder.
size = 100_000
roi_ = split_roi_to_rois(lon1_, lat1_, lon2_, lat2_, meters_split=size)
print(len(roi_))

In [None]:
# Output folder for this grid size (e.g. .../tiles100km), with a vrt/ subfolder
# for the per-raster warped VRTs.
save_folder = f'{folder_inference}/tiles{size//1000}km'
os.makedirs(f'{save_folder}/vrt', exist_ok=True)  # no-op when it already exists


In [None]:
# Display the output folder path.
save_folder

In [None]:
import simplekml

# Write each grid cell as a KML polygon named by its index, then convert the KML
# to a shapefile next to it; grid_file becomes that .shp path.
kmlfile_name = f'{save_folder}/grid.kml'
kml = simplekml.Kml()
for roi in roi_:
    lat1, lon1, lat2, lon2 = roi["roi"]  # split_roi_to_rois returns lat-first tuples

    geo_pts_ref = to_bbox([lon1, lat1, lon2, lat2])
    geo_pts_ref.append(geo_pts_ref[0])  # close the ring
    pol = kml.newpolygon(name=roi['name'])
    pol.outerboundaryis = geo_pts_ref

kml.save(kmlfile_name)
grid_file = convert_to_shp(kmlfile_name)

### Find the Sentinel-2 tiles that overlap each grid cell

In [None]:
# Grid shapefile written by the "Create Grid" cell for the current folder_inference.
# This used to be hardcoded to the palm4_act_simpleA9_soft_ens5 run's grid, which
# does not follow folder_inference. To reuse one fixed grid across runs (e.g. for
# stable cell ids), set this to that grid's path explicitly.
grid_file = f'{save_folder}/grid.shp'

In [None]:
# Sentinel-2 tiling grid shapefile; its 'NAME' field is matched against list_names below.
sentinel_2tiles = '/home/pf/pfstaff/projects/andresro/barry_palm/data/labels/sentinel2_tiles_world/sentinel2_tiles_world.shp'

In [None]:
vector_s2 = ogr.Open(sentinel_2tiles)
assert vector_s2 is not None, f'could not open {sentinel_2tiles}'
layer_s2 = vector_s2.GetLayer()  # vector_s2 must stay referenced while layer_s2 is used

# FIDs of the Sentinel-2 tiles that have an inference raster. Iterating the layer
# (instead of GetFeature(0..count-1)) does not assume contiguous, 0-based FIDs.
feat_list = [feature_s2.GetFID() for feature_s2 in layer_s2
             if feature_s2.GetField('NAME') in list_names]


def get_s2_matches(ref_geom):
    """Return the NAME of every tile in feat_list whose geometry intersects ref_geom.

    Args:
        ref_geom: ogr.Geometry to test. It is assumed to be in the same CRS as the
            Sentinel-2 tiles layer; no reprojection is done here.

    Returns:
        List of tile names in feat_list order (possibly empty).
    """
    out_tiles = []
    for fid in feat_list:
        # Keep the feature referenced while its geometry is used:
        # GetGeometryRef() returns a geometry owned by the feature.
        feature_s2 = layer_s2.GetFeature(fid)
        if ref_geom.Intersects(feature_s2.GetGeometryRef()):
            out_tiles.append(feature_s2.GetField('NAME'))
    return out_tiles
    

In [None]:
# Number of Sentinel-2 tiles that have an inference raster.
len(feat_list)

In [None]:
# grid_vector = ogr.Open(kmlfile_name)
grid_vector = ogr.Open(grid_file)
grid_layer = grid_vector.GetLayer()
grid_layer.GetFeatureCount()

In [None]:
# Map each grid-cell FID to the Sentinel-2 tiles it overlaps; cells that overlap
# no tile with an inference raster are left out.
dict_features = {}
for fid in range(grid_layer.GetFeatureCount()):
    grid_feature = grid_layer.GetFeature(fid)  # keep alive while its geometry is in use
    overlapping = get_s2_matches(grid_feature.GetGeometryRef())
    if overlapping:
        dict_features[fid] = overlapping
print(len(dict_features))

In [None]:
import os

def rename_(x, suffix= '_nan', folder=None):
    """Build the warped-VRT path for raster `x` inside `<folder>/vrt`.

    The basename's extension is replaced by `suffix` + '.vrt', e.g.
    'T51MTP_5_preds_reg_12_12.tif' -> 'T51MTP_5_preds_reg_12_12_nan.vrt'.

    Args:
        x: path of the source raster.
        suffix: text inserted before the '.vrt' extension.
        folder: output root; defaults to the notebook-level `save_folder`.

    Returns:
        os.path.join(folder, 'vrt', <new basename>).
    """
    if folder is None:
        folder = save_folder
    # splitext swaps only the real extension. str.replace('.tif', ...) also
    # rewrote '.tiff' into '...vrtf' and left non-.tif names without '.vrt'.
    stem = os.path.splitext(os.path.basename(x))[0]
    return os.path.join(folder, 'vrt', f'{stem}{suffix}.vrt')


In [None]:
from tqdm.notebook import tqdm

In [None]:
# Work from the inference folder.
# NOTE(review): the paths used below are built from the absolute folder_inference /
# save_folder, so this chdir looks unnecessary - confirm before removing.
os.chdir(folder_inference)

In [None]:
# Show the current working directory.
os.getcwd()

In [None]:

# For every grid cell: warp each overlapping inference raster to the cell's lon/lat
# bounds as a VRT (source value 99 -> NaN), then mosaic those per-raster VRTs
# into <save_folder>/<cell FID>.vrt.

# Warp to the grid's CRS; other projections do not work if the reference
# shapefile is not in the same projection.
ref_proj = 'EPSG:4326'

for i, tiles in tqdm(dict_features.items()):
    names = [f'{folder_inference}/T{x}_5_preds_reg_12_12.tif' for x in tiles]

    feature = grid_layer.GetFeature(i)  # keep alive while its geometry is in use
    vectorGeometry = feature.GetGeometryRef()
    minX, maxX, minY, maxY = vectorGeometry.GetEnvelope()  # OGR order

    # Grid cells are lon/lat rectangles, so outputBounds alone crops each raster to
    # its cell. The former cutlineLayer=grid_layer was not a valid cutline:
    # WarpOptions' cutlineLayer takes a layer *name* and needs cutlineDSName.
    warp_opts = gdal.WarpOptions(
        format="VRT",
        srcNodata=99,
        dstSRS=ref_proj,
        dstNodata='nan',
        outputBounds=[minX, minY, maxX, maxY],  # GDAL order: minX, minY, maxX, maxY
        )

    new_names = [rename_(x, suffix=f'_tile{i}_{ref_proj}'.replace(':', '_')) for x in names]
    ds_warped = [gdal.Warp(x1, x, options=warp_opts) for (x, x1) in zip(names, new_names)]
    failed = [x for x, ds in zip(names, ds_warped) if ds is None]
    assert not failed, f'gdal.Warp failed for {failed}'
    # Dereference the warped datasets so their VRT files are flushed to disk
    # before BuildVRT opens them.
    ds_warped = None

    my_vrt = gdal.BuildVRT(f'{save_folder}/{i}.vrt', new_names)
    assert my_vrt is not None, f'gdal.BuildVRT failed for grid cell {i}'
    my_vrt = None  # close to write the mosaic VRT

    


In [None]:
# Per-raster warped VRT paths from the last grid cell processed above.
new_names