In [None]:
import numpy as np
import math
import scipy
import pandas as pd
import PIL
import gdal
import matplotlib.pyplot as plt
plt.style.use('ggplot')
import sys, os
from pathlib import Path
import time
import random
import collections, functools, operator
import csv
import subprocess
import datetime

from osgeo import gdal,osr
from gdalconst import *
import subprocess
from osgeo.gdalconst import GA_Update

import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.autograd import Variable
from torch.nn import Linear, ReLU, CrossEntropyLoss, MSELoss, Sequential, Conv2d, MaxPool2d, Module, Softmax, BatchNorm2d, Dropout, Sigmoid
from torch.optim import Adam, SGD
from torchvision import transforms, utils

import skimage
from skimage import io, transform
import sklearn
import sklearn.metrics
from sklearn.feature_extraction import image
from sklearn import svm

# Satellite dataset class

In [None]:
class SatelliteDataset(Dataset):
    """Satellite dataset."""

    def __init__(self, dirs, normalization_options = None, patches_options = None, readWhileRunning=False, readFromPatches=False, cropClassification=False):
        """
        Args:
            dirs (dict): Dictionary of lists of directories of scene "input" and "target" bands.
            normalization_options (obj): Object normalization_options with information about normalization technique to be applied.
            patches_options (obj): Object patches_options with information about the patching options (size, step, etc)
            readWhileRunning (bool): Whether you want to read online files or local files
            readFromPatches (bool): Whether you want to read from patches in a numpy array data structure.
            cropClassification (bool): Whether you want to perform crop classification. It will vary the reading workflow.
        """
        self.dirs = dirs
        self.normalization_options = normalization_options
        self.patches_options = patches_options
        self.readWhileRunning = readWhileRunning
        self.readFromPatches = readFromPatches
        self.data_augm = False
        self.cropClassification = cropClassification

        self.bandsToGet = ['band1', 'band2', 'band3', 'band4', 'band5', 'band7',
                           '_B02_', '_B03_', '_B04_', '_B05_', '_B06_', '_B07_', '_B08_', '_B09_', '_B10_',
                           
                           'B008', 'B009', 'B010', 'B011', 'B012', 'B013', 'B014', 'B015', 'B016', 'B017',
                           'B018', 'B019', 'B020', 'B021', 'B022', 'B023', 'B024', 'B025', 'B026', 'B027',
                           'B028', 'B029', 'B030', 'B031', 'B032', 'B033', 'B034', 'B035', 'B036', 'B037',
                           'B038', 'B039', 'B040', 'B041', 'B042', 'B043', 'B044', 'B045', 'B046', 'B047',
                           'B048', 'B049', 'B050', 'B051', 'B052', 'B053', 'B054', 'B055', 'B056', 'B057',
                           
                           'B079', 'B080', 'B081', 'B082', 'B083', 'B084', 'B085', 'B086', 'B087', 
                           'B088', 'B089', 'B090', 'B091', 'B092', 'B093', 'B094', 'B095', 'B096', 'B097', 
                           'B098', 'B099', 'B100', 'B101', 'B102', 'B103', 'B104', 'B105', 'B106', 'B107', 
                           'B108', 'B109', 'B110', 'B111', 'B112', 'B113', 'B114', 'B115', 'B116', 'B117', 
                           'B118', 'B119', 'B120',
                           
                           'B131', 'B132', 'B133', 'B134', 'B135', 'B136', 'B137', 'B138', 'B139', 'B140', 
                           'B141', 'B142', 'B143', 'B144', 'B145', 'B146', 'B147', 'B148', 'B149', 'B150', 
                           'B151', 'B152', 'B153', 'B154', 'B155', 'B156', 'B157', 'B158', 'B159', 'B160', 
                           'B161', 'B162', 'B163', 'B164', 'B165',
                           
                           'B181', 'B182', 'B183', 'B184', 'B185', 'B186', 'B187', 'B188', 'B189', 'B190', 
                           'B191', 'B192', 'B193', 'B194', 'B195', 'B196', 'B197', 'B198', 'B199', 'B200', 
                           'B201', 'B202', 'B203', 'B204', 'B205', 'B206', 'B207', 'B208', 'B209', 'B210', 
                           'B211', 'B212', 'B213', 'B214', 'B215', 'B216', 'B217', 'B218', 'B219', 'B220', 
                           'B221', 'B222', 'B223']

        if self.readWhileRunning == False:
            self.dataset = self.read()
        
    def readAllPatchesFromScene(self):
        """
        Read one scene and cut every selected band into 64x64 tiles.

        Only the first directory of ``self.dirs['input']`` and of
        ``self.dirs['target']`` is read. Each selected band is scaled by
        1/65535 and split into non-overlapping 64x64 windows. The stacked
        bands are then reshaped to a fixed layout (9 input bands, 170 target
        bands, 55*18 tiles) and transposed so the tile index comes first.

        Returns:
            dict: ``{'input': array, 'target': array}`` of float32 tiles.
        """
        def load_tiles(directory, require_modified):
            # Tiled, scaled version of every selected band file in `directory`.
            tiles = []
            for fname in sorted(os.listdir(directory)):
                if not any(tag in fname for tag in self.bandsToGet):
                    continue
                if require_modified and 'modified3' not in fname:
                    continue
                raster = gdal.Open(os.path.join(directory, fname)).ReadAsArray()
                raster = raster.astype('float32')
                raster /= 65535.0
                tiles.append(skimage.util.shape.view_as_windows(raster, (64,64), step=64))
            return np.array(tiles).astype('float32')

        scene = {}
        # (dataset key, only 'modified3' files?, expected number of bands)
        layout = (('input', True, 9), ('target', False, 170))
        for key, require_modified, n_bands in layout:
            stacked = load_tiles(self.dirs[key][0], require_modified)
            print(stacked.shape)
            # (bands, tiles, 64, 64) -> (tiles, bands, 64, 64)
            stacked = stacked.reshape(n_bands,55*18,64,64).transpose(1,0,2,3)
            print(stacked.shape)
            scene[key] = stacked

        return scene


    def __len__(self):
        if self.readWhileRunning == False and self.readFromPatches == False:
            return self.dataset['input'].shape[0]
        else:
            return len(os.listdir(self.dirs['input']))

    def __getitem__(self, idx): # Now it returns whole batch
        if self.readWhileRunning:
            dataset = self.getitem_online(idx)
        else:
            dataset = self.getitem_local(idx)

        dataset['input'] = torch.from_numpy(dataset['input'])
        dataset['target'] = torch.from_numpy(dataset['target'])
        
        if self.data_augm:
            if bool(random.getrandbits(1)):
                dataset['input'] = transforms.RandomHorizontalFlip(1).forward(dataset['input'])
                dataset['target'] = transforms.RandomHorizontalFlip(1).forward(dataset['target'])
            if bool(random.getrandbits(1)):
                dataset['input'] = transforms.RandomVerticalFlip(1).forward(dataset['input'])
                dataset['target'] = transforms.RandomVerticalFlip(1).forward(dataset['target'])
        return dataset
    
    def getitem_online(self, idx):
        fnameInput = sorted(os.listdir(self.dirs['input']))[idx]
        fnameTarget = sorted(os.listdir(self.dirs['target']))[idx]
        dataset = {'input': np.load(os.path.join(self.dirs['input'], fnameInput)),
                  'target': np.load(os.path.join(self.dirs['target'], fnameTarget))}
        if 'crop' in self.dirs.keys():
            dataset['crop'] = 'Not loaded'
        return dataset
    def getitem_local(self, idx):
        dataset = {'input': self.dataset['input'][idx],
                'target': self.dataset['target'][idx]}
        if 'crop' in self.dirs.keys():
            dataset['crop'] = self.dataset['crop'][idx]
        return dataset
    


    def read(self):
        if self.cropClassification == True:
            dataset = self.readAllPatchesFromScene()
        elif self.readFromPatches == True:
            dataset = self.read_patches()
        else:
            dataset = self.read_scenes()
        return dataset

    def read_patches(self):
        input=[]
        target=[]
        for fnameInput, fnameTarget in zip(sorted(os.listdir(self.dirs['input'])), sorted(os.listdir(self.dirs['target']))):
            input.append(np.load(os.path.join(self.dirs['input'],fnameInput)))
            target.append(np.load(os.path.join(self.dirs['target'],fnameTarget)))
        dataset = {'input': input, 'target': target}
        return dataset

    def read_scenes(self):
        dataset = {}
        dataset['input'] = np.array([])
        dataset['target'] = np.array([])
        dataset['crop'] = np.array([])

        for dirNameInput, dirNameTarget in zip(dirs['input'], dirs['target']):
            scene = {}
            scene['input'] = []
            scene['target'] = []
            scene['crop'] = []
            for fname in sorted(os.listdir(dirNameInput)):
                if any(b in fname for b in self.bandsToGet) and 'modified3' in fname:
                    band = gdal.Open(os.path.join(dirNameInput, fname)).ReadAsArray()
                    scene['input'].append(band)
            for fname in sorted(os.listdir(dirNameTarget)):
                if any(b in fname for b in self.bandsToGet):
                    band = gdal.Open(os.path.join(dirNameTarget, fname)).ReadAsArray()
                    scene['target'].append(band)
            sceneCropClass = None
            if 'crop' in self.dirs.keys():
                sceneCropClass = gdal.Open(self.dirs['crop']).ReadAsArray()

            scene['input'] = np.array(scene['input']).astype('float32')
            scene['target'] = np.array(scene['target']).astype('float32')

            cloudMask = None
            if self.patches_options and self.patches_options.evadeClouds:
                cloudMask = self.get_cloudMask(scene)
            if self.normalization_options:
                scene = self.normalize(scene)
            if self.patches_options:
                scene = self.extract_patches(scene, cloudMask, sceneCropClass)
                cloudMask = None
            
            if len(scene['input'].shape) >= 3 and len(scene['target'].shape) >= 3:
                dataset['input'] = np.concatenate((dataset['input'], scene['input'])) if dataset['input'].size else scene['input']
                dataset['target'] = np.concatenate((dataset['target'], scene['target'])) if dataset['target'].size else scene['target']
                if 'crop' in self.dirs.keys():
                    dataset['crop'] = np.concatenate((dataset['crop'], scene['crop'])) if dataset['crop'].size else scene['crop']


            print(dataset['input'].shape)
            print(dataset['target'].shape)
            if 'crop' in self.dirs.keys():
                print(dataset['crop'].shape)
            print('')

        return dataset





    def visualize(self, datasetName, rBand=None, gBand=None, bBand=None, patched=False):
        """
        Show an RGB composite built from three bands of the loaded dataset.

        Args:
            datasetName (str): 'input', 'target', or 'all'. With 'all', the
                input and target sets are drawn side by side using fixed band
                choices and the band arguments are ignored.
            rBand, gBand, bBand (int): Band numbers for the red, green and
                blue channels. They are converted to array indices by
                subtracting 1 for 'input' and 8 for 'target'.
            patched (bool): When True the dataset is indexed as a stack of
                patches and 10 randomly chosen patches are shown.

        Raises:
            TypeError: If a band number is missing for 'input' or 'target'.
        """
        if datasetName == 'all':
            # Array indices of the (R, G, B) bands for each displayable set.
            rgb_by_set = {
                'input': (3-1, 2-1, 1-1),
                'target': (29-8, 20-8, 12-8),
            }
            fig, axs = plt.subplots(1,2)
            fig.set_figwidth(24)
            fig.set_figheight(15)
            # Only the image sets are drawn. Other keys (e.g. 'crop') have no
            # band mapping and would overflow the two available axes.
            namesets = [nameset for nameset in self.dataset if nameset in rgb_by_set]
            for ax, nameset in zip(axs, namesets):
                channels = [self.dataset[nameset][band] for band in rgb_by_set[nameset]]
                ax.imshow(np.dstack(tuple(channels)))
            plt.show()
            return

        # Band number -> array index offset for each set.
        offsets = {'input': 1, 'target': 7+1}
        if datasetName in offsets:
            if rBand is None or gBand is None or bBand is None:
                raise TypeError('rBand, gBand and bBand are required to visualize ' + repr(datasetName))
            shift = offsets[datasetName]
            rBand -= shift
            gBand -= shift
            bBand -= shift

        if not patched:
            R = self.dataset[datasetName][rBand]
            G = self.dataset[datasetName][gBand]
            B = self.dataset[datasetName][bBand]
            img = np.dstack((R,G,B))
            plt.figure(num=None, figsize=(15, 12), dpi=80)
            plt.imshow(img)
            plt.show()
        else:
            fig, axs = plt.subplots(1,10)
            fig.suptitle(datasetName)
            fig.set_figwidth(15)
            for i in range(10):
                # Random patch; the same patch may be drawn more than once.
                n = random.randint(0, self.dataset[datasetName].shape[0]-1)
                R = self.dataset[datasetName][n][rBand]
                G = self.dataset[datasetName][n][gBand]
                B = self.dataset[datasetName][n][bBand]

                img = np.dstack((R, G, B))
                axs[i].imshow(img)
            plt.show()

    def get_cloudMask(self, scene):
        """
        Build the per-set reference bands handed to patch selection as a cloud mask.

        One band is taken from the input (index 3-1) and one from the target
        (index 29-8); both are then normalized with ``bandAdaptive_norm``.

        Returns:
            dict: ``{'input': band, 'target': band}`` after normalization.
        """
        reference_bands = {
            'input': scene['input'][3-1],
            'target': scene['target'][29-8],
        }
        return self.bandAdaptive_norm(reference_bands)

    #===================================================#
    #============== EXTRACT PATCHES UTILS ==============#
    #===================================================#

    def extract_patches(self, scene, cloudMask=None, sceneCropClass=None):
        patchesInput = []
        patchesTarget = []
        patchesCrop = []

        window = self.patches_options.window
        step = self.patches_options.step

        stepNextRow = False
        i = 0
        while i < (scene['input'][4-2].shape[0] - window): # 3-1 for landsat7
            j=0
            while j < (scene['input'][4-2].shape[1] - window): # Rband - 2 because of indexing properly (indexing starts at 0 but first band is the 2nd as n1 is PAN)
                patchInput = scene['input'][:, i:i+window, j:j+window] 
                patchTarget = scene['target'][29-8][i:i+window, j:j+window] # Rband - 8 because of indexing properly (same as above and 8 or 7 initial bands from hyperion are removed)
                if self.meaningful_patch(patchInput, patchTarget, cloudMask):
                    if not self.patches_options.miniPatches:
                        patchesInput.append(patchInput)
                        patchesTarget.append(scene['target'][:, i:i+window, j:j+window])
                        if 'crop' in self.dirs.keys():
                            patchesCrop.append(sceneCropClass[i:i+window, j:j+window])
                    elif self.patches_options.miniPatches:
                        patchesInput.append(self.sliding_window(patchInput).transpose(1,2,0,3,4))
                        patchesTarget.append(scene['target'][:, i+2:i+window-2, j+2:j+window-2].transpose(1,2,0))
                        if 'crop' in self.dirs.keys():
                            patchesCrop.append(sceneCropClass[i+2:i+window-2, j+2:j+window-2])
                    j += step-1
                    stepNextRow = True
                j += 1
            if stepNextRow:
                i += step-1
                stepNextRow = False
            i += 1
        
        scene['input'] = np.array(patchesInput).astype('float32')
        scene['target'] = np.array(patchesTarget).astype('float32')
        if 'crop' in self.dirs.keys():
            scene['crop'] = np.array(patchesCrop)
        return scene
    
    def sliding_window(self, patch):
        """
        Return every 5x5 window (stride 1) of each band in ``patch``.

        The per-band window arrays are stacked along a new leading axis, one
        entry per band.
        """
        windows_per_band = [
            skimage.util.shape.view_as_windows(bandPatch, (5,5), step=1)
            for bandPatch in patch
        ]
        return np.array(windows_per_band)

    def meaningful_patch(self, patchInput, patchTarget, cloudMask=None):
        if self.patches_options.evadeClouds == False:
            if np.any(patchTarget == 0.0) or np.any(patchTarget == 1.0) or np.any(patchInput == 0.0) or np.any(patchInput == 1.0):
                return False
        elif self.patches_options.evadeClouds == True:
            if 'nan' in patchTarget or 0. in patchTarget or patchTarget.mean() < 0.3 or patchTarget.mean() > 0.7 or patchTarget.std() > 0.25:
                #print('target', patchTarget.mean(), patchTarget.std())
                return False
            if 'nan' in patchInput or patchInput.mean() > 0.8 or patchInput.mean() < 0.3 or patchInput.std() > 0.25:
                #print('input', patchInput.mean(), patchInput.std())
                return False
        
        return True


    #===================================================#
    #=============== NORMALIZATION UTILS ===============#
    #===================================================#
    def normalize(self, dataset):
        norm_type = self.normalization_options.norm_type
        if norm_type == 'bandAdaptive_norm':
            return self.bandAdaptive_norm(dataset)
        elif norm_type == 'sceneAdaptive_norm':
            return self.sceneAdaptive_norm(dataset)
        elif norm_type == 'sensorAdaptive_norm':
            return self.sensorAdaptive_norm(dataset)
        elif norm_type == 'fullRange_norm':
            return self.fullRange_norm(dataset)
        elif norm_type == 'none':
            return dataset
        else:
            return Exception('No norm_type set')

    def sensorAdaptive_norm(self, dataset):
        dataset['input'] = self.band_max_values(dataset['input'], 8)
        dataset['input'] /= 3000.0

        dataset['target'] = self.band_max_values(dataset['target'], 8)
        #dataset['target'] -= 100.0
        #dataset['target'] /= 2000.0 - 100.0
        dataset['target'] /= 2000.0
        dataset['target'][dataset['target'] < 0] = 0.0
        dataset['target'][dataset['target'] > 1] = 1.0
        
        return dataset


    def sceneAdaptive_norm(self, dataset):
        pass


    def fullRange_norm(self, dataset):
        dataset['input'] = self.band_max_values(dataset['input'], 8)
        dataset['input'] /= 65535.0
        dataset['target'] = self.band_max_values(dataset['target'], 8)
        dataset['target'] /= 65535.0
        '''
        for nameset, scenes in dataset.items():
            if isinstance(dataset[nameset], list): # Dataset with a list of scenes
                for scene_i, scene in enumerate(scenes):
                    dataset[nameset][scene_i] = self.band_max_values(scene, 8)
                    dataset[nameset][scene_i] = dataset[nameset][scene_i].astype('float') / 65535.0

            elif len(dataset[nameset].shape) == 3: # Dataset with only 1 scene (with some .tif bands)
                dataset[nameset] = self.band_max_values(scenes, 8)
                dataset[nameset] = dataset[nameset].astype('float') / 65535.0

            elif len(dataset[nameset].shape) == 2: # Dataset with only 1 band of 1 scene (one .tif)
                dataset[nameset] = self.stretch_hist(scenes)
                dataset[nameset] = dataset[nameset].astype('float') / 65535.0
        '''
        
        return dataset


    def bandAdaptive_norm(self, dataset):
        for nameset, scenes in dataset.items():
            if isinstance(scenes, list): # Dataset with a list of scenes
                for scene_i, scene in enumerate(scenes):
                    scene = self.band_max_values(scene, 8)
                    for band_i, band in enumerate(scene):
                        scene[band_i] = self.stretch_hist(band) # this should be the same for different scenes
                    dataset[nameset][scene_i] = scene

            elif len(dataset[nameset].shape) == 3: # Dataset with only 1 scene (with some .tif bands)
                scenes = self.band_max_values(scenes, 8)
                for i, band in enumerate(scenes):
                    scenes[i] = self.stretch_hist(band)
                dataset[nameset] = scenes

            elif len(dataset[nameset].shape) == 2: # Dataset with only 1 band of 1 scene (one .tif)
                dataset[nameset] = self.stretch_hist(scenes)
                
        return dataset


    def band_max_values(self, bandList, m = 8., verbose = False):
        '''
        Finds outliers value pixels from bands and limits the max possible value of that band to the median of its neighbours
        '''
        # Get max values from each band
        maxs = []
        for band in bandList:
            if isinstance(band, gdal.Dataset):
                band = band.ReadAsArray() # To numpy array

            maxs.append(np.max(band))
        
        # Fix outlier pixels
        if verbose:
            print('Input:')
            plt.figure(1)
            plt.plot(maxs)
            plt.show()
        
        maxs = np.array(maxs)
        d = np.abs(maxs - np.median(maxs))
        mdev = np.median(d)
        s = d/mdev if mdev else 0.
        outlierBands = np.where(s>m)[0]

        for band in outlierBands:
            neighbours = np.append(maxs[band-2:band], maxs[band+1:band+3])
            newMaxValueBand = np.median(neighbours)
            bandList[band][bandList[band]>newMaxValueBand] = newMaxValueBand
            maxs[band] = newMaxValueBand

        if verbose:
            print('Output:')
            plt.figure(2)
            plt.plot(maxs)
            plt.show()

        return bandList

    def stretch_hist(self, band, verbose = False):
        '''
        Removes the {percentage} least frequent pixel values from a band in order to decrease noise
        Admits a single band

        band = band.flatten()
        bs = math.ceil(band.max() - band.min())
        hist = plt.hist(band, bins=bs)[0]
        
        hist_s = np.array(sorted(hist))
        hist_sc = hist_s[hist_s>0]
        plt.figure()
        plt.hist(hist_sc)
        print(hist_sc)
        '''

        if isinstance(band, gdal.Dataset):
            band.ReadAsArray() # To numpy array

        band_no0 = band[band>0] # Mask will not be taken into account for the histogram stretching
        min_percent = 5   # Low percentile
        max_percent = 95  # High percentile

        # Find lower and upper percentile using percentile function, and "stretch" pixel range linearly between lower and upper percentiles
        lo, hi = np.percentile(band_no0, (min_percent, max_percent)) 

        # Apply linear "stretch" - lo goes to 0, and hi goes to 1
        res_img = (band.astype(float) - lo) / (hi-lo)

        #Multiply by 1, clamp range to [0, 1] and convert to float32
        res_img = np.maximum(np.minimum(res_img*1, 1), 0).astype(np.float32)

        if verbose:
            print('Input:')
            plt.figure(1)
            plt.imshow(band, cmap='gray', vmin=0, vmax=np.max(band))
            plt.show()
            print('Output (Normalized):')
            plt.figure(2)
            plt.imshow(res_img, cmap='gray', vmin=0, vmax=1)
            plt.show()

        return res_img


class NormalizeOptions:
    """
    Settings holder naming the normalization technique to apply.

    ``norm_type`` is read by ``SatelliteDataset.normalize``, which recognizes
    'bandAdaptive_norm', 'sceneAdaptive_norm', 'sensorAdaptive_norm',
    'fullRange_norm' and 'none'.
    """

    def __init__(self, norm_type):
        # Name of the normalization; stored as given, without validation.
        self.norm_type = norm_type

class ToPatchesOptions:
    """
    Settings for converting a dataset into image patches.

    Attributes:
        window: Side length of each square patch.
        step: Advance applied after a patch has been extracted.
        evadeClouds: Selects the patch-filtering rules used during extraction.
        miniPatches: Whether each patch is further split into mini patches.
    """

    def __init__(self, window, step, evadeClouds, miniPatches):
        # Patch geometry.
        self.window, self.step = window, step
        # Patch filtering / splitting switches.
        self.evadeClouds, self.miniPatches = evadeClouds, miniPatches


class ToTensor:
    """Callable transform turning a numpy ndarray sample into a torch Tensor."""

    def __call__(self, sample):
        # Axis order is left untouched: the data is already stored as
        # C x H x W (torch layout), never as H x W x C.
        return torch.from_numpy(sample)