In [1]:
import datacube
import rasterio.merge
import os
import lccs_l3
import numpy
import xarray
import scipy

from datacube.storage import masking
from rio_cogeo.cogeo import cog_translate
from rio_cogeo.profiles import cog_profiles

from datacube.model import Measurement
from datacube.helpers import write_geotiff
from fc.fractional_cover import fractional_cover
from shapely.geometry import box

import osr
import gdal


In [2]:
def fishnet(geometry, threshold):
    """Cut ``geometry`` into pieces along a regular square grid.

    The grid is anchored at the coordinate origin and has a cell size of
    ``threshold`` in both directions. Every grid cell overlapping the bounding
    box of ``geometry`` is intersected with it and the non-empty pieces are
    returned.

    Parameters
    ----------
    geometry : shapely geometry
        Shape to split; only ``bounds``, ``area`` and ``intersection`` are used.
    threshold : int or float
        Edge length of one grid cell, in the units of ``geometry``. Must be > 0.

    Returns
    -------
    list
        Intersection pieces, ordered column by column (x outer, y inner).

    Raises
    ------
    ValueError
        If ``threshold`` is not strictly positive.
    """
    if threshold <= 0:
        raise ValueError('threshold must be a positive cell size, got %r' % (threshold,))

    min_x, min_y, max_x, max_y = geometry.bounds
    first_col = int(min_x // threshold)
    last_col = int(max_x // threshold)
    first_row = int(min_y // threshold)
    last_row = int(max_y // threshold)

    # When the geometry's upper/right edge sits exactly on a grid line, the
    # neighbouring cell only touches it: the intersection is a line or a point,
    # which is not "empty" but has no area and is useless as a tile.  Only
    # filter these out for inputs that have an area themselves, so that lines
    # and points passed in on purpose are still split as before.
    drop_slivers = geometry.area > 0

    result = []
    for col in range(first_col, last_col + 1):
        for row in range(first_row, last_row + 1):
            cell = box(col * threshold, row * threshold,
                       (col + 1) * threshold, (row + 1) * threshold)
            piece = geometry.intersection(cell)
            if piece.is_empty:
                continue
            if drop_slivers and piece.area == 0:
                continue
            result.append(piece)
    return result

In [3]:
def create_cog(input, output, overview_resampling, bidx):
    """Translate ``input`` into a Cloud Optimized GeoTIFF written to ``output``.

    Uses rio-cogeo's ``'deflate'`` profile, a nodata value of ``-1``, six
    overview levels and eight threads.

    Parameters
    ----------
    input : str
        Path of the raster to convert.
    output : str
        Path of the COG to create.
    overview_resampling : str
        Resampling method name handed to ``cog_translate`` for the overviews.
    bidx
        Band index selection handed to ``cog_translate`` as ``indexes``.
    """
    # Destination creation options: the stock profile plus a BIGTIFF policy that
    # can be overridden from the environment.
    dst_profile = cog_profiles.get('deflate')
    dst_profile.update({"BIGTIFF": os.environ.get("BIGTIFF", "IF_SAFER")})

    # Overviews use the smaller of the two block dimensions unless the
    # environment says otherwise.
    smallest_block = min(
        int(dst_profile["blockxsize"]), int(dst_profile["blockysize"])
    )
    gdal_config = {
        "NUM_THREADS": 8,
        "GDAL_TIFF_INTERNAL_MASK": os.environ.get("GDAL_TIFF_INTERNAL_MASK", True),
        "GDAL_TIFF_OVR_BLOCKSIZE": os.environ.get(
            "GDAL_TIFF_OVR_BLOCKSIZE", smallest_block
        ),
    }

    cog_translate(
        src_path=input,
        dst_path=output,
        dst_kwargs=dst_profile,
        indexes=bidx,
        nodata=-1,
        web_optimized=False,
        add_mask=False,
        overview_level=6,
        overview_resampling=overview_resampling,
        config=gdal_config,
        quiet=False
    )

In [4]:
def combine_tiles(files,filename):
    """Mosaic tile rasters into one temporary GeoTIFF and delete the tiles.

    Parameters
    ----------
    files : sequence of str
        Paths of the tile rasters to merge. They are removed from disk once the
        mosaic has been written.
    filename : str
        Output path prefix; the mosaic is written to ``filename + '_TEMP.tif'``.

    Returns
    -------
    str
        Path of the mosaic that was written.

    Raises
    ------
    ValueError
        If ``files`` is empty.
    """
    if not files:
        raise ValueError('combine_tiles needs at least one tile to merge')

    output_filename = filename+'_TEMP'+'.tif'

    # Open every tile for the merge and make sure each handle is closed again,
    # even if opening or merging fails part-way.  The tiles are deleted below,
    # which must not happen while they are still open.
    sources = []
    try:
        for path in files:
            sources.append(rasterio.open(path))
        dest, out_transform = rasterio.merge.merge(sources)
        # The first tile supplies the creation profile of the mosaic.
        profile = sources[0].profile
    finally:
        for src in sources:
            src.close()

    # ``dest`` is indexed (band, row, column).
    _, height, width = dest.shape
    profile['transform'] = out_transform
    profile['width'] = width
    profile['height'] = height

    with rasterio.open(output_filename, 'w', **profile) as dst:
        dst.write(dest.astype(rasterio.int16))

    for path in files:
        os.remove(path)
    return output_filename

In [5]:
def array_to_geotiff(fname, data, transform, crs, nodata_val=0, dtype=gdal.GDT_Int16):
    """Write a 2-D array to a single-band GeoTIFF.

    Parameters
    ----------
    fname : str
        Path of the GeoTIFF to create.
    data : 2-D array
        Values to write; ``data.shape`` is read as ``(rows, cols)``.
    transform : sequence of 6 floats
        Geotransform handed to ``SetGeoTransform``.
    crs : int
        EPSG code of the coordinate reference system.
    nodata_val : number, optional
        Value registered as nodata on the band (default ``0``).
    dtype : GDAL data type, optional
        Pixel type of the created raster (default ``gdal.GDT_Int16``).

    Returns
    -------
    str
        ``fname``, so calls can be chained.

    Raises
    ------
    RuntimeError
        If GDAL could not create ``fname``.
    """
    srs = osr.SpatialReference()
    srs.ImportFromEPSG(crs)
    prj_wkt = srs.ExportToWkt()

    # Set up driver
    driver = gdal.GetDriverByName('GTiff')

    # Create raster of given size and projection
    rows, cols = data.shape
    dataset = driver.Create(fname, cols, rows, 1, dtype)
    if dataset is None:
        # Without GDAL exceptions enabled a failed Create returns None; fail
        # here with the path instead of an AttributeError on the next line.
        raise RuntimeError('GDAL could not create %r (does the folder exist and is it writable?)' % (fname,))
    dataset.SetGeoTransform(transform)
    dataset.SetProjection(prj_wkt)

    # Write data to array and set nodata values
    band = dataset.GetRasterBand(1)
    band.WriteArray(data)
    band.SetNoDataValue(nodata_val)
    band.FlushCache()

    # Close file: drop the band before the dataset that owns it
    band = None
    dataset = None

    return fname

In [6]:
def _geotransform_from_centres(x_coords, y_coords):
    """Build a north-up GDAL geotransform from pixel-centre coordinate axes."""
    if len(x_coords) < 2 or len(y_coords) < 2:
        raise ValueError('need at least two pixels along x and y to infer the pixel size')

    pixel_width = float(x_coords[1] - x_coords[0])
    pixel_height = float(y_coords[1] - y_coords[0])

    # GDAL's origin is the outer corner of the first pixel, whereas the axis
    # values are pixel centres, so step back half a pixel along each axis.
    origin_x = float(x_coords[0]) - pixel_width / 2.0
    origin_y = float(y_coords[0]) - pixel_height / 2.0

    # Rotation terms are zero: the image is 'north up'.
    return (origin_x, pixel_width, 0.0, origin_y, 0.0, pixel_height)


def create_level3(dc, query):
    """Classify one tile to LCCS level 3 and return it with its geotransform.

    Loads fractional-cover percentiles, the WOfS ``count_clear`` band and
    Sentinel-1 backscatter for ``query``, derives the five input layers expected
    by ``lccs_l3.classify_lccs_level3`` and sets the result to ``-1`` wherever
    ``count_clear`` is not greater than zero.

    Parameters
    ----------
    dc : datacube.Datacube
        Datacube connection used for all three loads.
    query : dict
        Keyword arguments forwarded unchanged to every ``dc.load`` call.

    Returns
    -------
    tuple
        ``(level3, transform)`` where ``level3`` is a numpy array and
        ``transform`` a 6-element GDAL geotransform, or ``(None, None)`` when
        any of the three products returns no data for ``query``.

    Raises
    ------
    ValueError
        If the loaded grid has fewer than two pixels along x or y.
    """
    fcp = dc.load(product='ls_usgs_fcp_fiji', measurements= ['PV_PC_50', 'NPV_PC_50', 'BS_PC_50'], **query)
    if not fcp:
        return None, None
    fcp = masking.mask_invalid_data(fcp).squeeze()

    wofs = dc.load(product='ls_usgs_wofs_fiji', measurements= ['count_clear'], **query)
    if not wofs:
        return None, None
    wofs = masking.mask_invalid_data(wofs).squeeze()

    # Vegetated where either cover fraction reaches 55; pixels without a
    # PV value stay undefined.
    veg = ((fcp.PV_PC_50 >= 55) | (fcp.NPV_PC_50 >= 55)).where(fcp.PV_PC_50.notnull())
    vegetat_veg_cat_ds = veg.to_dataset(name="vegetat_veg_cat")

    # Load data from datacube
    s1 = dc.load(product="s1_gamma0_scene", **query)
    if not s1:
        return None, None
    s1 = masking.mask_invalid_data(s1)

    # presumably the backscatter is stored in dB and this converts it to
    # linear units — TODO confirm against the product definition
    s1 = (10**(s1/10))

    # Water: low backscatter in both polarisations in more than 20% of the
    # observations.
    water = ((s1.vv <= 0.07) & (s1.vh <= 0.01))
    aquatic_wat_cat_ds = ((water.sum(dim='time') / water.count(dim="time")) >.2).to_dataset(name="aquatic_wat_cat")

    # All-zero placeholder layer on the same y/x grid as the Sentinel-1 data.
    dummy = xarray.DataArray(numpy.zeros(water[0,:,:].shape), coords=[water[0,:,:].y.data, water[0,:,:].x.data], dims=['y', 'x'] )

    # Cultivated: compare the mean VV backscatter of two fixed date windows.
    time1 = ("2018-04-02", "2018-07-31")
    time2 = ("2018-08-01", "2018-12-31")
    s1_mean_1 = s1.sel(time=slice(time1[0], time1[1])).mean(dim='time')
    s1_mean_2 = s1.sel(time=slice(time2[0], time2[1])).mean(dim='time')

    crop = ((s1_mean_2.vv < 0.18) & (s1_mean_1.vv > 0.05))
    cultman_agr_cat_ds = crop.to_dataset(name="cultman_agr_cat")

    # Urban: high median backscatter in either polarisation.
    urban = ((s1.vv.median(dim='time') > .5) | (s1.vh.median(dim='time') > .1))
    artific_urb_cat_ds = urban.to_dataset(name="artific_urb_cat")

    # No artificial-water input is available, so the placeholder is used.
    artwatr_wat_cat_ds = dummy.to_dataset(name="artwatr_wat_cat")

    classification_data = xarray.merge([
        artwatr_wat_cat_ds,
        aquatic_wat_cat_ds,
        vegetat_veg_cat_ds,
        cultman_agr_cat_ds,
        artific_urb_cat_ds,
    ])

    # Apply Level 3 classification using separate function. Works through in three stages
    level1, level2, level3 = lccs_l3.classify_lccs_level3(classification_data)
    level3_clean = numpy.where(wofs.count_clear > 0,level3, -1)

    transform = _geotransform_from_centres(s1.x.values, s1.y.values)

    return level3_clean, transform

In [7]:
# Area of interest as (minx, miny, maxx, maxy), expressed in the CRS below.
footprint = box(444915.0, 7959375.0, 706155.0, 8121225.0)

# Edge length of one processing tile, in CRS units.
fishnet_size = 40000

# The EPSG code is kept both as an integer and as an 'EPSG:<code>' string.
crs = 32760
crs_string = 'EPSG:' + str(crs)

# Prefix shared by every file this notebook writes.
out_filename = 'lccs/lccs_fiji_'

# Load parameters common to every tile; the x/y extent is added per tile.
query = {
    'time': ('2018-03-01', '2019-03-01'),
    'crs': crs_string,
    'resolution': (-30, 30),
    'output_crs': crs_string,
}

# Resampling used when building the COG overviews.
overview_resampling = 'nearest'

dc = datacube.Datacube(app='lccs')


In [8]:
tile_no = 0
tile_locations = []
bounds_list = fishnet(footprint,fishnet_size)

# GDAL does not create missing folders, so make sure the output folder exists
# before the first tile is written.
out_dir = os.path.dirname(out_filename)
if out_dir:
    os.makedirs(out_dir, exist_ok=True)

for bb in bounds_list:
    minx, miny, maxx, maxy = bb.bounds
    # Per-tile copy: the shared ``query`` stays free of tile-specific extents,
    # so re-running this cell (or others that use ``query``) starts clean.
    tile_query = dict(query, x=(minx, maxx), y=(miny, maxy))

    lccs, transform = create_level3(dc, tile_query)
    if lccs is not None:
        tile_path = out_filename+str(tile_no)+'_TEMP.tif'
        # create_level3 fills masked pixels with -1 and the final COG declares
        # -1 as nodata, so tag the tiles with the same value; otherwise the
        # merge treats -1 as valid data and 0 as nodata.
        tile_locations.append(array_to_geotiff(tile_path,lccs,transform,crs,nodata_val=-1))
    tile_no = tile_no + 1
    # Release the tile arrays before the next load / the merge.
    del lccs, transform

if not tile_locations:
    raise RuntimeError('No LCCS tiles were produced; check the footprint, time range and product availability.')

uncogged_output_file = combine_tiles(tile_locations,out_filename)
target_filename = out_filename+'.tif'
# NOTE(review): rasterio band indexes start at 1, so 0 is not a real band; being
# falsy it presumably makes cog_translate fall back to all bands (the mosaic
# only has one) — confirm against the installed rio-cogeo version.
bidx = 0
create_cog(uncogged_output_file, target_filename, overview_resampling, bidx)
os.remove(uncogged_output_file)

  return np.nanmean(a, axis=axis, dtype=dtype)
  r = func(a, **kwargs)
  return np.nanmean(a, axis=axis, dtype=dtype)
  r = func(a, **kwargs)
  return np.nanmean(a, axis=axis, dtype=dtype)
  r = func(a, **kwargs)
  return np.nanmean(a, axis=axis, dtype=dtype)
  r = func(a, **kwargs)
  return np.nanmean(a, axis=axis, dtype=dtype)
  r = func(a, **kwargs)
  return np.nanmean(a, axis=axis, dtype=dtype)
  r = func(a, **kwargs)
  return np.nanmean(a, axis=axis, dtype=dtype)
  r = func(a, **kwargs)
  return np.nanmean(a, axis=axis, dtype=dtype)
  r = func(a, **kwargs)
  return np.nanmean(a, axis=axis, dtype=dtype)
  r = func(a, **kwargs)
  return np.nanmean(a, axis=axis, dtype=dtype)
  r = func(a, **kwargs)
  return np.nanmean(a, axis=axis, dtype=dtype)
  r = func(a, **kwargs)
  return np.nanmean(a, axis=axis, dtype=dtype)
  r = func(a, **kwargs)
  return np.nanmean(a, axis=axis, dtype=dtype)
  r = func(a, **kwargs)
  return np.nanmean(a, axis=axis, dtype=dtype)
  r = func(a, **kwargs)
  retu