# Feed loop for VAE network

This notebook is a development prototype for the feed loop that provides input data to the VAE.

Given the data volume, it is not possible to hold the entire data set in memory. Instead, the adopted approach splits the full data set (composed of tiles of 100 km × 100 km retrieved from Google Earth Engine) into subsets of `sizeSubSet` tiles (~30). Each tile in a subset is divided into cutouts of `sizeCutOut` × `sizeCutOut` pixels. These cutouts are stacked and passed to the VAE as input.

This process repeats until all tiles have been exhausted, which constitutes one epoch. The trained network is saved and the early-stopping criteria are evaluated after each epoch.

If further training is needed, the next epoch employs cutouts shifted by `nShift` for data augmentation.

In [3]:
import rasterio
import numpy as np
from rasterio import windows
from itertools import product
import os
import yaml
import random
from pathlib import Path

Input can (and should) be provided as a YAML file, which maps to a dictionary. For demo purposes, the dictionary is specified by hand here.

In [1]:
#TODO yaml input reader

In [None]:
# Hand-written stand-in for the YAML configuration described above.
configDict = dict(
    inputDirectory='',  # location of the input tiles (left empty in this demo)
    nBands=4,           # number of bands per tile
    sizeSubSet=2,       # number of tiles per subset
    sizeCutOut=20,      # edge length of a square cutout
    nEpochMax=2,        # presumably the maximum number of epochs - TODO confirm
    nShift=5,           # cutout shift applied between epochs for augmentation
)

In [43]:
class subSet(object):
    """Create a data subset for ingestion by the VAE.

    A subset holds a random selection of input tiles. Each selected tile is
    divided into square cutouts, and the cutouts of all tiles are stacked
    into a single array (``vaeInputArray``).

    Attributes are documented on ``__init__``.
    """

    def __init__(self, config=None, offset=None):
        """Initialise the subset from a configuration mapping.

        Parameters
        ----------
        config : mapping
            Must provide 'nBands', 'sizeCutOut' and 'sizeSubSet'. The legacy
            spelling 'sizeSubset' is accepted as a fallback.
        offset : int, optional
            Shift applied to the origin of the cutout grid. ``None`` is
            treated as 0 when the cutouts are loaded.

        Raises
        ------
        KeyError
            If a required key is missing from ``config``.
        TypeError
            If ``config`` is None.
        """
        self.nBands = config['nBands']
        # The notebook's configuration spells this key 'sizeSubSet'; keep the
        # older 'sizeSubset' spelling working for existing configurations.
        if 'sizeSubSet' in config:
            self.maxSize = config['sizeSubSet']
        else:
            self.maxSize = config['sizeSubset']
        self.sizeCutOut = config['sizeCutOut']
        self.offset = offset
        self.inputTiles = []        # tiles drawn for this subset
        self.window_list = []       # (tile, window) pair for every cutout
        self.vaeInputArray = None   # stacked cutouts, set by load_cutouts

    def draw_input(self, inputTileList):
        """Draw a random selection of input tiles for this subset.

        At most ``maxSize`` tiles are drawn; if fewer tiles are available,
        all of them are taken. The drawn tiles are removed from
        ``inputTileList`` in place and stored in ``self.inputTiles``.

        Returns
        -------
        list
            ``inputTileList`` itself, with the selected tiles removed.
        """
        nDraw = min(self.maxSize, len(inputTileList))
        self.inputTiles = self._draw(nDraw, inputTileList)
        return inputTileList

    def load_cutouts(self):
        """Load the cutouts of all drawn tiles and stack them for the VAE.

        Loops over ``self.inputTiles``, creating the windows and reading the
        associated cutouts. The cutouts of all tiles are collected in one
        flat list and stacked along a new leading axis into
        ``self.vaeInputArray``; the matching (tile, window) pairs are stored
        in ``self.window_list`` in the same order. If there are no cutouts,
        ``self.vaeInputArray`` is set to None.
        """
        shift = 0 if self.offset is None else self.offset
        window_list = []
        arrays = []
        for tile in self.inputTiles:
            wl, al = self._window_and_load(
                tile, shift, self.nBands, self.sizeCutOut, self.sizeCutOut)
            # extend (not append) so that every cutout is one entry of the
            # stack, independent of the tile it came from
            window_list.extend(wl)
            arrays.extend(al)

        self.window_list = window_list
        # np.stack raises on an empty sequence
        self.vaeInputArray = self._create_VAE_input(arrays) if arrays else None

    ##############

    @staticmethod
    def _create_VAE_input(arrays, axis=0):
        """Stack equally shaped ``arrays`` along a new axis ``axis``."""
        return np.stack(arrays, axis=axis)

    @staticmethod
    def _window_and_load(source, shift, nBands, width, height):
        """Read all cutouts of one tile.

        Opens ``source`` with rasterio, builds the windows for the given
        ``shift`` and cutout size, and reads bands 1..``nBands`` for each
        window.

        Returns
        -------
        tuple
            ``(window_list, arrays)`` where ``window_list`` holds one
            ``(source, window)`` tuple per cutout and ``arrays`` holds the
            corresponding arrays, in the same order.
        """
        # rasterio band indexes are 1-based
        bands = list(range(1, nBands + 1))
        with rasterio.open(source) as src:
            tileWindows = list(subSet._get_windows(src, shift, width, height))
            arrays = [src.read(bands, window=w) for w in tileWindows]

        window_list = [(source, w) for w in tileWindows]
        return window_list, arrays

    # adapted from Ou Ku's window_loading_list_shuffle.ipynb
    @staticmethod
    def _get_windows(src, shift, width, height, boundless=True):
        """Yield the cutout windows of an open dataset.

        Windows should not extend beyond the source image. Wrapping is
        sub-optimal as boundaries are non-periodic in principle, so only the
        eligible window range for the given shift is used. BEWARE: windows
        repeat for shift >= window size. This may constrain the maximum
        number of epochs.

        The eligible extent is determined as

            eWidth = tileWidth - mod(tileWidth - shift, windowWidth)

        and the column offsets are then ``range(shift, eWidth, windowWidth)``
        (likewise for rows, using the heights).

        ``boundless`` is accepted for interface compatibility and is unused.
        """
        tileWidth = src.meta['width']
        tileHeight = src.meta['height']
        eWidth = tileWidth - (tileWidth - shift) % width
        eHeight = tileHeight - (tileHeight - shift) % height

        offsets = product(range(shift, eWidth, width),
                          range(shift, eHeight, height))
        for col_off, row_off in offsets:
            yield windows.Window(
                col_off=col_off,
                row_off=row_off,
                width=width,
                height=height)

    @staticmethod
    def _draw(n, input):
        """Pop ``n`` randomly chosen items from the list ``input``.

        ``input`` is modified in place; the popped items are returned as a
        list in the order they were drawn.
        """
        return [input.pop(random.randrange(len(input))) for _ in range(n)]