# Merge Code

## Import Dependencies

In [None]:
import glob
import os
import traceback
import rasterio
import numpy as np
import matplotlib.pyplot as plt
import csv
import warnings
import math

from contextlib import contextmanager
from rasterio.transform import Affine
from rasterio.coords import disjoint_bounds
from rasterio.enums import Resampling
from rasterio import windows

## Configuration variables
Set up the configuration variables: dataset paths, raster search pattern, CSV header, output file name and the maximum number of rows per CSV file.

In [None]:
# Folder layout: PATH_IMAGES and PATH_RESULT are sub-folders of PATH_DATASET.
PATH_DATASET = "../dataset"
PATH_IMAGES = "images"
PATH_RESULT = "results"

# Glob pattern used to find the raster files.
SEARCH_CRITERIA = "*.tif"

# CSV header: column titles written as the first row of every CSV file.
# (Spelling is kept exactly as-is because it ends up in the output files.)
HEADER = [
    "Coordinate X",
    "Coordinate Y",
    "Clorophill - a",
]

# Base name (without extension) of the final files.
MERGE_OUTNAME = "merge"

# Maximum number of values written per CSV file; None means a single file.
MAX_ROWS_PER_CSV = None

## Helper Functions

### Sum method

In [None]:
def copy_nansum(merged_data, new_data):
    """Adds new_data onto merged_data in place, counting NaNs as zero

    Args:
        merged_data : destination array, overwritten with the element-wise sum
        new_data : source array added to the destination
    """
    # Stack both operands so a single nansum along the new axis adds them.
    stacked = np.array([merged_data, new_data])
    totals = np.nansum(stacked, axis = 0)

    # "unsafe" casting allows writing the sum back whatever its dtype is.
    np.copyto(merged_data, totals, casting = "unsafe")

### Raster to CSV

In [None]:
def write_value_to_csv(band, lons, lats, writer, width, offset = 0):
  """Writes one CSV row (longitude, latitude, value) per element of band

  The position of each element in the full flattened matrix is its index in
  band plus offset; that position is split into a row index (used to look up
  lats) and a column index (used to look up lons).

  Args:
      band : Array 1D
      lons : Longitudes, indexed by column
      lats : Latitudes, indexed by row
      writer : CSV writer
      width : Matrix width
      offset (int, optional): Offset of band inside the full 1D array. Defaults to 0.
  """
  for local_index, value in enumerate(band):
    row, col = divmod(local_index + offset, width)
    writer.writerow([lons[col], lats[row], value])

def band_to_csv(band, transform, filename, max_rows = None):
  """Writes a band into one or more CSV files, one row per pixel

  Every file starts with HEADER, followed by rows of X coordinate,
  Y coordinate and pixel value in row-major order.
  If max_rows is None a single file ``filename + '.csv'`` is written.
  Otherwise the pixels are split into consecutive chunks of at most
  max_rows values and chunk i is written to ``filename + str(i) + '.csv'``.

  Args:
      band : Band 2D
      transform : Georeference
      filename : Destination filename, without extension
      max_rows : Max values to write in a CSV file. Defaults to None.

  Raises:
      ValueError: if max_rows is given but is smaller than 1.
  """
  if max_rows is not None and max_rows < 1:
    raise ValueError("max_rows must be a positive integer or None")

  height, width = band.shape[0], band.shape[1]

  # Only the X of every column (taken on row 0) and the Y of every row
  # (taken on column 0) are needed, so transform just those two 1D index
  # vectors instead of a full height x width grid.
  col_idx = np.arange(width)
  row_idx = np.arange(height)
  xs, _ = rasterio.transform.xy(transform, np.zeros_like(col_idx), col_idx)
  _, ys = rasterio.transform.xy(transform, row_idx, np.zeros_like(row_idx))
  lons, lats = np.asarray(xs), np.asarray(ys)

  ext = '.csv'
  band = np.ravel(band)

  # newline = '' is required by the csv module; without it every row gets an
  # extra carriage return on platforms that translate '\n' on write.
  if max_rows is None:
    with open(filename + ext, 'w', encoding = 'UTF8', newline = '') as f:
      writer = csv.writer(f)
      writer.writerow(HEADER)
      write_value_to_csv(band, lons, lats, writer, width)
  else:
    for i in range( math.ceil((width * height) / max_rows) ):
      offset = i * max_rows
      with open(filename + str(i) + ext, 'w', encoding = 'UTF8', newline = '') as f:
        writer = csv.writer(f)
        writer.writerow(HEADER)
        write_value_to_csv(band = band[offset : offset + max_rows ], lons = lons, lats = lats, writer = writer, width = width, offset = offset)


### Merge method based on rasterio.merge.merge()

In [None]:
def merge(
    datasets, #Raster data: paths or already opened datasets
    method = copy_nansum, #Sum method, called as method(region, temp)
    bounds = None, #Limit if we want to cut the merge: (west, south, east, north)
    res = None, #Pixel size like 1m, 1km, ...
    nodata = None, #Overrides the nodata value of the first dataset
    dtype = None, #Int, float, ...
    precision = None, #Not referenced anywhere in this function body
    indexes = None, #Bands to read, forwarded to src.read()
    output_count = None, #Number of bands of the output array
    resampling = Resampling.nearest, #Forwarded to src.read()
    target_aligned_pixels = False, #Snap the output bounds to multiples of res
    dst_path = None, #If given, the result is written there instead of returned
    dst_kwds = None, #Extra entries merged into the output profile
):
    """Merges several rasters into one array; overlaps are combined with method

    Two arrays of identical shape are accumulated: dest, which receives the
    data of every dataset through method(region, temp), and dest_count, which
    receives the boolean "is not NaN" mask of the same data through the same
    method. At the end dest is divided element-wise by dest_count, so with
    the default copy_nansum the result is the per-pixel mean of the non-NaN
    contributions.

    Resolution, dtype, nodata value, profile and colormap default to those of
    the first dataset.

    Args:
        datasets : sequence of dataset paths (str / os.PathLike) or of already
            opened dataset objects; the type of the first element decides
            how all of them are opened
        method : callable(region, temp) that updates region in place
        bounds : (west, south, east, north) of the output; defaults to the
            union of the bounds of all datasets
        res : pixel size, a single number or an (x, y) iterable; defaults to
            the resolution of the first dataset
        nodata : nodata value of the output; defaults to the first dataset's
        dtype : dtype of the output; defaults to the first dataset's
        precision : unused by this implementation
        indexes : band index or indexes passed to src.read()
        output_count : number of output bands; defaults to the source count
        resampling : resampling method passed to src.read()
        target_aligned_pixels : if True, expand the bounds outwards to
            multiples of the resolution
        dst_path : if not None, write the result to this path
        dst_kwds : dict used to update the output profile

    Returns:
        (dest, output_transform) when dst_path is None; nothing otherwise.
    """

    # Create a dataset_opener object to use in several places in this function.
    if isinstance(datasets[0], (str, os.PathLike)):
        dataset_opener = rasterio.open
    else:

        # Already opened datasets: hand them back unchanged and leave them
        # open when the "with" block ends.
        @contextmanager
        def nullcontext(obj):
            try:
                yield obj
            finally:
                pass

        dataset_opener = nullcontext

    # The first dataset provides the defaults for the output.
    with dataset_opener(datasets[0]) as first:
        first_profile = first.profile
        first_res = first.res
        nodataval = first.nodatavals[0]
        dt = first.dtypes[0]

        if indexes is None:
            src_count = first.count
        elif isinstance(indexes, int):
            # NOTE(review): the band *index* is used as the band *count* here;
            # confirm this is intended for indexes > 1.
            src_count = indexes
        else:
            src_count = len(indexes)

        # A ValueError from colormap() is treated as "no colormap".
        try:
            first_colormap = first.colormap(1)
        except ValueError:
            first_colormap = None

    if not output_count:
        output_count = src_count

    # Extent from option or extent of all inputs
    if bounds:
        dst_w, dst_s, dst_e, dst_n = bounds
    else:
        # scan input files
        xs = []
        ys = []
        for dataset in datasets:
            with dataset_opener(dataset) as src:
                left, bottom, right, top = src.bounds
            xs.extend([left, right])
            ys.extend([bottom, top])
        dst_w, dst_s, dst_e, dst_n = min(xs), min(ys), max(xs), max(ys)

    # Resolution/pixel size, normalised to an (x, y) pair
    if not res:
        res = first_res
    elif not np.iterable(res):
        res = (res, res)
    elif len(res) == 1:
        res = (res[0], res[0])

    # Expand the bounds outwards to whole multiples of the pixel size
    if target_aligned_pixels:
        dst_w = math.floor(dst_w / res[0]) * res[0]
        dst_e = math.ceil(dst_e / res[0]) * res[0]
        dst_s = math.floor(dst_s / res[1]) * res[1]
        dst_n = math.ceil(dst_n / res[1]) * res[1]

    # Compute output array shape: the extent divided by the pixel size,
    # rounded to the nearest whole number of pixels.
    output_width = int(round((dst_e - dst_w) / res[0]))
    output_height = int(round((dst_n - dst_s) / res[1]))

    # Origin at the top-left (west, north) corner; the negative Y scale makes
    # rows advance southwards.
    output_transform = Affine.translation(dst_w, dst_n) * Affine.scale(res[0], -res[1])

    if dtype is not None:
        dt = dtype

    # Output profile: the first dataset's profile, then dst_kwds, then the
    # values computed above (which therefore win over dst_kwds).
    out_profile = first_profile
    out_profile.update(**(dst_kwds or {}))

    out_profile["transform"] = output_transform
    out_profile["height"] = output_height
    out_profile["width"] = output_width
    out_profile["count"] = output_count
    out_profile["dtype"] = dt
    if nodata is not None:
        out_profile["nodata"] = nodata

    # create destination array (accumulated values) and its companion array
    # (accumulated "is not NaN" flags), both with the output dtype
    dest = np.zeros((output_count, output_height, output_width), dtype=dt)
    dest_count = np.zeros((output_count, output_height, output_width), dtype=dt)

    if nodata is not None:
        nodataval = nodata

    if nodataval is not None:
        # Only fill if the nodataval is within dtype's range
        inrange = False
        if np.issubdtype(dt, np.integer):
            info = np.iinfo(dt)
            inrange = (info.min <= nodataval <= info.max)
        elif np.issubdtype(dt, np.floating):
            if math.isnan(nodataval):
                inrange = True
            else:
                info = np.finfo(dt)
                inrange = (info.min <= nodataval <= info.max)
        if inrange:
            # NOTE(review): method later accumulates on top of this fill.
            # The default copy_nansum ignores a NaN fill, but a non-NaN
            # nodata value would be added into the sums - confirm inputs
            # always use NaN as nodata.
            dest.fill(nodataval)
        else:
            warnings.warn(
                "The nodata value, %s, is beyond the valid "
                "range of the chosen data type, %s. Consider overriding it "
                "using the --nodata option for better results." % (
                    nodataval, dt))
    else:
        nodataval = 0

    for idx, dataset in enumerate(datasets):
        with dataset_opener(dataset) as src:
            # Real World (tm) use of boundless reads.
            # This approach uses the maximum amount of memory to solve the
            # problem. Making it more efficient is a TODO.

            # Datasets that do not touch the output extent are skipped
            if disjoint_bounds((dst_w, dst_s, dst_e, dst_n), src.bounds):
                continue

            # 1. Compute spatial intersection of destination and source
            src_w, src_s, src_e, src_n = src.bounds

            int_w = src_w if src_w > dst_w else dst_w
            int_s = src_s if src_s > dst_s else dst_s
            int_e = src_e if src_e < dst_e else dst_e
            int_n = src_n if src_n < dst_n else dst_n

            # 2. Compute the source window
            src_window = windows.from_bounds(int_w, int_s, int_e, int_n, src.transform)

            # 3. Compute the destination window
            dst_window = windows.from_bounds(
                int_w, int_s, int_e, int_n, output_transform
            )

            # 4. Read data in source window into temp
            src_window_rnd_shp = src_window.round_lengths()
            dst_window_rnd_shp = dst_window.round_lengths()
            dst_window_rnd_off = dst_window_rnd_shp.round_offsets()

            # The data is read directly at the size of the destination
            # window, so src.read() resamples it when the sizes differ.
            temp_height, temp_width = (
                dst_window_rnd_off.height,
                dst_window_rnd_off.width,
            )
            temp_shape = (src_count, temp_height, temp_width)

            temp_src = src.read(
                out_shape=temp_shape,
                window=src_window_rnd_shp,
                boundless=False,
                masked=True,
                indexes=indexes,
                resampling=resampling,
            )

        # 5. Copy elements of temp into dest
        # (runs after the dataset context has exited; only the array read
        # above is needed from here on)
        roff, coff = (max(0, dst_window_rnd_off.row_off), max(0, dst_window_rnd_off.col_off), )

        # region / region_count are slices of dest / dest_count, so method
        # updates the full arrays in place through them.
        region = dest[:, roff : roff + temp_height, coff : coff + temp_width]
        region_count = dest_count[:, roff : roff + temp_height, coff : coff + temp_width]
        # Trim temp in case the region was clipped by the edge of dest
        temp = temp_src[:, : region.shape[1], : region.shape[2]]

        # NOTE(review): temp was read with masked=True, but only NaN is tested
        # below; the mask itself does not appear to be consulted - confirm.
        method(region, temp)
        method(region_count, ~np.isnan(temp))

    # Accumulated values divided by accumulated flags (a mean with the
    # default method).
    # NOTE(review): pixels whose count is 0 are divided by zero here.
    np.divide(dest, dest_count, out = dest)

    if dst_path is None:
        return dest, output_transform

    else:
        with rasterio.open(dst_path, "w", **out_profile) as dst:
            dst.write(dest)
            if first_colormap:
                dst.write_colormap(1, first_colormap)


## Main
Note:
The georeference must always run North -> South and West -> East.

In [None]:
#Search raster filenames
query = os.path.join(PATH_DATASET, PATH_IMAGES, SEARCH_CRITERIA)

#Destination folder and base name (without extension) of the output files
result_dir = os.path.join(PATH_DATASET, PATH_RESULT)
result_base = os.path.join(result_dir, MERGE_OUTNAME)

rasters = []
try:
    #Open every raster. They are appended one by one so that, if one of them
    #fails to open, the ones already opened are still closed below.
    #Sorted so that the "first" raster, whose metadata is reused, is the same
    #on every run.
    for filename in sorted(glob.glob(query)):
        rasters.append(rasterio.open(filename))

    if not rasters:
        raise FileNotFoundError("No raster files match %r" % query)

    #Make sure the destination folder exists before writing into it
    os.makedirs(result_dir, exist_ok = True)

    #Get merged raster with georeference
    merge_data, merge_trans = merge(rasters)

    #Save raster in a CSV file
    band_to_csv(merge_data[0], merge_trans, result_base, MAX_ROWS_PER_CSV)

    # Get first raster metadata and modify the necessary data
    merge_meta = rasters[0].meta.copy()
    merge_meta.update({"driver": "GTiff", "height": merge_data.shape[1], "width": merge_data.shape[2], "transform": merge_trans,})

    #Save raster in a tif file. The extension is appended to the base name;
    #passing it to os.path.join would create a file named '.tif' inside a
    #'merge' folder instead.
    with rasterio.open(result_base + '.tif', 'w', **merge_meta) as dst:
        dst.write(merge_data[0], 1)

except Exception:
    #Report the failure without stopping the notebook; KeyboardInterrupt and
    #SystemExit are deliberately not caught.
    traceback.print_exc()
finally:
    #Release the file handles of the input rasters
    for raster in rasters:
        raster.close()