# Demonstration of DSen2 Superresolution
*Johannes Mast*


*30.03.2020*





## Introduction
This script is an implementation of *Super-resolution of Sentinel-2 images: Learning a globally applicable deep neural network* [[Link]](https://arxiv.org/abs/1803.04271) by Charis Lanaras, José Bioucas-Dias, Silvano Galliani, Emmanuel Baltsavias and Konrad Schindler. The code was made available online on [GitHub](https://github.com/lanha/DSen2).

This implementation is designed to do three things:

1. **Demonstrate** the functionality of the code. The code published in the above-mentioned GitHub repository is designed to be run from the shell and, while documented, is fairly opaque. In this script, we execute the individual parts step by step to gain a better understanding of the algorithm.

2. **Introduce** the basic building blocks for flexible **semi-automation** of the procedure. We develop functionality for semi-automated data acquisition and batch processing.

3. **Modify** the original code, in which some variables are hardcoded, to be more flexible in execution. This will also help achieve our other goals.

The implementation is designed to run in Google Colab, connected to the user's Google Drive. Google Colab provides a connection to a moderately powerful cloud runtime, which gives the user more flexibility in running a comparatively demanding algorithm.







## Overview


Sentinel-2 imagery comes in three different resolutions: 10m, 20m, and 60m. To obtain a data cube in which all bands are at the highest resolution, the 20m and 60m bands need to be superresolved. In this script, we will see how deep neural networks can be applied to this purpose. We begin with an overview of the model, the training procedure, and the possible outputs.



---
### The Model

The model's network architecture consists mainly of convolutional layers,
ReLU non-linearities and skip connections packed into a series of residual blocks (ResBlocks) and notably a long skip connection from the input directly to the output.


![alt text](https://ars.els-cdn.com/content/image/1-s2.0-S0924271618302636-gr5.sml)

The model exists in a deep version **DSen2** and a very deep **VDSen2** version. Quoting the publication:

>*VDSen2 has a lot
higher capacity, and was designed with maximum accuracy
in mind. It is closer in terms of size and training time to
modern high-end CNNs for other image analysis tasks (Simonyan and Zisserman, 2015; He et al., 2016; Huang et al.,
2017), but is approximately two times slower and five times
slower in both training and prediction respectively, compared to its shallower counterpart (DSen2) [...] On the one hand, the very deep variant is
consistently a bit better, while training and applying it is not
more difficult, if adequate resources (i.e., high-end GPUs)
are available.*

The model comes with two different sets of weights: one trained on the superresolution from 60m to 10m (*030*) and one trained on the superresolution from 20m to 10m (*032*).

The weights for the VDSen2 (very deep) models are meant to be downloadable from [s2_033_lr_1e-04.hdf5](http://n.ethz.ch/~lanarasc/DSen2/s2_033_lr_1e-04.hdf5) (*033*) and [s2_034_lr_1e-04.hdf5](http://n.ethz.ch/~lanarasc/DSen2/s2_034_lr_1e-04.hdf5) (*034*), but these links are offline as of the creation of this notebook.

In this experiment, we will not use the VDSen2 model, as the much lighter DSen2 produces only slightly worse results.

For a detailed description of the models we refer to the publication. We will also look at the model code more closely in the Prediction section in Chapter 3.


---
### Training

The available model comes with already trained weights. Training is beneficial but not necessary. To quote the publication:
 

> *These shall enable out-of-the-box super-resolution of Sentinel-2 images world-wide, with minimal knowledge
of neural network tools. Of course, if a study is focussed
only in a specific geographic location, biome or land-cover
type, even better result can be expected by training the network only with images showing those specific conditions.
The literature suggests that in that case, it may be best
to start from our globally trained network and fine-tune it
through further training iterations on task-specific imagery*

For the training, no labeled data is necessary. This is possible because the developers assume that the superresolution is scale invariant. The network can therefore learn the superresolution from downsampled versions of the image bands back to their initial resolution. Potentially any Sentinel-2 tile can be used for training.
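
The idea of building training pairs without labels can be sketched in a few lines. This is a simplified illustration and not the DSen2 routine: the hypothetical helper `block_mean` only averages non-overlapping blocks, whereas the repository's `downPixelAggr` may do more (for instance, blur before aggregating). The toy array stands in for a real band stack.

```python
import numpy as np

def block_mean(img, scale):
    """Downsample an (H, W, C) array by averaging non-overlapping scale x scale blocks."""
    h, w, c = img.shape
    # Drop trailing rows/columns that do not fill a complete block.
    h_crop, w_crop = h - h % scale, w - w % scale
    blocks = img[:h_crop, :w_crop].reshape(h_crop // scale, scale, w_crop // scale, scale, c)
    return blocks.mean(axis=(1, 3))

# Toy stand-in for a 20m band stack: 8 x 8 pixels, 2 bands (values are arbitrary).
band_20m = np.arange(8 * 8 * 2, dtype=np.float32).reshape(8, 8, 2)

# Training target: the band at its native resolution.
target = band_20m
# Network input: the same band degraded by the super-resolution factor (2 for 20m -> 10m).
degraded = block_mean(band_20m, scale=2)

print(target.shape, degraded.shape)
```

The network is then trained to map `degraded` back to `target`; under the scale-invariance assumption, the same mapping is later applied to the native-resolution bands.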

For a detailed description of the training procedure which produced the provided weights, we again refer to the publication. A list of the Sentinel-2 images which were used for training can be found [as a supplement to the code on the github](https://github.com/lanha/DSen2/blob/master/S2_tiles_training.txt). 

---
### Bands 

The principles of the algorithm can be applied to any data. However, the publication and the DSen2 model deal specifically with the superresolution of Sentinel-2 bands. Let us take a quick look at these bands.

**10m Resolution:**

These are visible and VNIR bands. They are already at the highest possible resolution of 10m and will not be predicted, although they can be copied to the output file. 

![alt text](https://earth.esa.int/image/image_gallery?uuid=c5fa6c3e-2978-4fb8-ac95-3be9c5171be2&groupId=247904&t=1345630320883)


**20m Resolution:**

These are the red-edge and near-infrared bands 5, 6, 7 and 8A, and the SWIR bands 11 and 12. They are prime candidates for the superresolution and can all be predicted at 10m.
![alt text](https://earth.esa.int/image/image_gallery?uuid=15dad96b-be6a-4b04-931d-d8c4db39e9e2&groupId=247904&t=1345630328076)


**60m Resolution:**

These are bands 1, 9 and 10 which are mainly used to measure water vapor, aerosols and clouds. Bands 1 and 9 can optionally be predicted to 10m. Band 10 is too noisy and cannot be superresolved.


![alt text](https://earth.esa.int/image/image_gallery?uuid=f6117fbe-1513-4a84-acc4-845e14e5c876&groupId=247904&t=1345630315020)



Note that because Python indexing starts at zero, the band indices within the code are usually shifted by one position compared to the documentation. The exceptions are GDAL and rasterio, which start indexing bands at one.
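
The off-by-one relationship can be made explicit with a small sketch. The band order in `BAND_STACK` is an assumption chosen for illustration; the actual order in a file depends on the product and the driver, and both helper functions are hypothetical.

```python
# Assumed order of the 20m bands along the channel axis (illustrative only).
BAND_STACK = ["B5", "B6", "B7", "B8A", "B11", "B12"]

def array_index(band):
    """0-based position of a band along the channel axis of a NumPy array."""
    return BAND_STACK.index(band)

def raster_band_number(band):
    """1-based band number, as used by GDAL's GetRasterBand and by rasterio."""
    return array_index(band) + 1

print(array_index("B8A"), raster_band_number("B8A"))  # 3 4
```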


---
### Structure of this script

The process can be clearly structured into a number of tasks:

* **Setup**: First, we have to set up the environment. This must be done every time the VM is started.
* **Data Acquisition**: Second, we acquire imagery for our desired study area by using [sentinelsat](https://pypi.org/project/sentinelsat/) to query and download Sentinel-2 images from the [Copernicus Open Access Hub](https://scihub.copernicus.eu/).
* **Training**: The functionality for training the network is provided by DSen2. It requires, however, a comparatively large amount of data and time, and is therefore **optional**.
* **Superresolution**: We subsequently use the implementation of DSen2 to superresolve the desired bands of the images we have downloaded. The results will be saved on the drive.
* **Visualisation**: Of course, we want to take a quick look at the result.
* **Streamlining**: Finally, to apply the algorithm to many files, we combine its essential parts back into a reduced function. We can apply this function in a loop to process a list of files.





## The Script

### Chapter 1: Setup

Before we begin, we must set up the environment. The VMs come with some preinstalled packages, but other elements must be reestablished anytime we connect to a new VM.

#### Connecting to Google Drive

Every time the notebook runtime is started, it must be reconnected to the Google Drive.
Therefore, this must be done at the beginning of every script. An authorisation code is required as an input by the user.

In [0]:
%%capture
from google.colab import drive
from importlib import reload

drive.mount('/content/gdrive',force_remount=True)

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········



We want to create all required directories, as well as import the DSen2 and the input data.

#### Project Directory

We create a project directory *DSen2Main*, which will contain everything related to this project. Alternatively, we can navigate to an existing directory on the drive (for instance, the directory containing the present notebook).

We further create a subdirectory *Data* in which our data will be saved.

In [0]:
#Change to the google drive main
%cd "gdrive/My Drive"

#Create a new directory
!mkdir DSen2Main

#Enter the new directory
%cd DSen2Main

#make a dir for the data
!mkdir Data
!mkdir Data/Training
#make an output dir
!mkdir Outputs

/content/gdrive/My Drive
mkdir: cannot create directory ‘DSen2Main’: File exists
/content/gdrive/My Drive/DSen2Main
mkdir: cannot create directory ‘Data’: File exists
mkdir: cannot create directory ‘Data/Training’: File exists
mkdir: cannot create directory ‘Outputs’: File exists
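
The `File exists` messages above are harmless, but they can be avoided. A minimal sketch, assuming the same directory layout as in the cell above; the helper name `ensure_dirs` is hypothetical:

```python
import os

def ensure_dirs(base, subdirs=("Data", "Data/Training", "Outputs")):
    """Make sure base and the given subdirectories exist; existing ones are left alone."""
    paths = []
    for sub in subdirs:
        path = os.path.join(base, sub)
        # exist_ok=True makes the call idempotent, so re-running the cell is safe.
        os.makedirs(path, exist_ok=True)
        paths.append(path)
    return paths
```

Calling `ensure_dirs("DSen2Main")` repeatedly should produce no errors, which is convenient because the setup has to be repeated for every new runtime.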


#### Importing the DSen2

The DSen2 is available on github. It forms the backbone of the algorithm, and comes with pretrained weights. We clone it into our project directory.

In [0]:
!git clone https://github.com/lanha/DSen2

#### Installing required packages
Packages can be installed using the *!pip install* command. This needs to be done every time a notebook that requires certain dependencies is loaded. Fortunately, many packages are already installed on the VMs. We install several common packages for handling geospatial data, which should only take a minute.





In [0]:
%%capture
!pip install rasterio
!pip install geopandas
!pip install descartes
!pip install gdal
!pip install sentinelsat
!pip install geojson
!pip install shapely
!pip install Pillow

#### Selecting Tensorflow Version

Colab VMs offer different versions of TensorFlow. Here, we select TensorFlow 1.x (1.15.0 at the time of writing); the default as of 27.03.2020 is TensorFlow 2.

In [0]:
%tensorflow_version 1.x
import tensorflow as tf
print("Tensorflow version " + tf.__version__)

### Chapter 2: Data Acquisition 

We use the [sentinelsat](https://pypi.org/project/sentinelsat/) package to connect to the API and then query and download images. For the spatial query, we use a polygon string of our study area. We will be looking for three images in this example, but the number could be increased or reduced without issue.

In [0]:
n_dl_images=3

#### Defining the AOI

For querying spatially, we use an AOI. The author of this script is German, and thinks his country is very representative of the world, so we choose some German coordinates.

In [0]:
lon_min=9
lon_max=10
lat_min= 48
lat_max= 52

#### Connecting to the API
To query the images, we connect to the API via the sentinelsat package. For this, we need our username and a password.

In [0]:
from sentinelsat.sentinel import SentinelAPI, read_geojson, geojson_to_wkt
from datetime import date
import getpass
# connect to the API
api = SentinelAPI(getpass.getpass(prompt="Please enter Copernicus Open Access Hub username"),
                  getpass.getpass(prompt="Please enter Copernicus Open Access Hub password"),
                  'https://scihub.copernicus.eu/dhus')

We create a polygon from our AOI coordinates and visualize it on a map to make sure we got the coordinates right.

In [0]:
aoi = 'POLYGON((%s %s,%s %s,%s %s,%s %s,%s %s))'  %(lon_min,lat_min,lon_min,lat_max,lon_max,lat_max,lon_max,lat_min,lon_min,lat_min)
import plotly.graph_objects as go
fig = go.Figure(go.Scattermapbox(
    fill = "toself",text="Query Area",
    lon = [lon_min, lon_max, lon_max, lon_min],
    lat = [lat_max, lat_max, lat_min, lat_min],
    marker = { 'size': 10, 'color': "red" }))
fig.update_layout(
     title="AOI for the Query",font=dict(family="Arial",size=18,),
     mapbox = {
        'style': "stamen-terrain",
        'center': {'lon': ((lon_min+lon_max)/2), 'lat': ((lat_min+lat_max)/2) },
        'zoom': 4},
    showlegend = False)
fig.show()

Looks alright!
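
The coordinate bookkeeping behind the `aoi` string can also be wrapped in a small helper, which makes the corner order and the closing of the ring explicit. This is a sketch; the function name `bbox_to_wkt` is hypothetical and not part of sentinelsat.

```python
def bbox_to_wkt(lon_min, lat_min, lon_max, lat_max):
    """Return a closed WKT polygon for a lon/lat bounding box."""
    if lon_min >= lon_max or lat_min >= lat_max:
        raise ValueError("min values must be smaller than max values")
    corners = [
        (lon_min, lat_min),
        (lon_min, lat_max),
        (lon_max, lat_max),
        (lon_max, lat_min),
        (lon_min, lat_min),  # repeat the first corner to close the ring
    ]
    return "POLYGON((%s))" % ",".join("%s %s" % corner for corner in corners)

print(bbox_to_wkt(9, 48, 10, 52))  # POLYGON((9 48,9 52,10 52,10 48,9 48))
```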

#### Querying the images

Now we are ready to query the API.
We use our polygon to query spatially and specify a date range to query temporally. We also only want scenes with low cloud cover. For our example run, we have that liberty.

In [0]:
# search by polygon, time, and SciHub query keywords
products = api.query(aoi,
                     date=('20190819', date(2019, 12, 29)),
                     platformname='Sentinel-2',
                     cloudcoverpercentage=(0, 30))

We convert the query result to a dataframe, sort it by cloud cover and take the best 3 results.

In [0]:
# convert to Pandas DataFrame
products_df = api.to_dataframe(products)

# sort and limit to first n_dl_images sorted products
products_df_sorted = products_df.sort_values(['cloudcoverpercentage', 'ingestiondate'], ascending=[True, True])
products_df_sorted = products_df_sorted.head(n_dl_images)
products_df_sorted
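
The selection logic itself is independent of pandas. The following sketch reproduces it on plain dictionaries; the records are made-up placeholders, not real query results, and `least_cloudy` is a hypothetical helper.

```python
from datetime import datetime

# Hypothetical records standing in for query results; all values are made up.
records = [
    {"id": "a", "cloudcoverpercentage": 12.5, "ingestiondate": datetime(2019, 9, 1)},
    {"id": "b", "cloudcoverpercentage": 0.4, "ingestiondate": datetime(2019, 10, 3)},
    {"id": "c", "cloudcoverpercentage": 0.4, "ingestiondate": datetime(2019, 8, 24)},
    {"id": "d", "cloudcoverpercentage": 27.0, "ingestiondate": datetime(2019, 8, 20)},
]

def least_cloudy(items, n):
    """Sort ascending by cloud cover, break ties by ingestion date, keep the first n."""
    ordered = sorted(items, key=lambda r: (r["cloudcoverpercentage"], r["ingestiondate"]))
    return ordered[:n]

print([r["id"] for r in least_cloudy(records, 3)])  # ['c', 'b', 'a']
```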

#### Download

Then we start the download into the data directory. This should be relatively quick, taking only a few minutes. Note: to hide the progress output, uncomment `%%capture`. As a cell magic, it must be the first line of the cell.

In [0]:
#%%capture
# download sorted and reduced products
api.download_all(products_df_sorted.index,"Data/")

Check the downloaded files. We should have one for each of the selected query results (3 in our example).


In [0]:
import glob
downloaded_files=(glob.glob("Data/*.zip"))
downloaded_files

['Data/S2B_MSIL1C_20190823T103029_N0208_R108_T32UNB_20190823T124349.zip',
 'Data/S2B_MSIL1C_20190823T103029_N0208_R108_T32UMB_20190823T124349.zip',
 'Data/S2B_MSIL1C_20190823T103029_N0208_R108_T32UNA_20190823T124349.zip']
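
The file names encode useful metadata such as the satellite, processing level, sensing time, relative orbit and tile. A minimal sketch for reading these fields from a name like the ones listed above; the pattern and the helper `parse_product_name` are written for this illustration and only cover names of this shape.

```python
import os
import re
from datetime import datetime

PRODUCT_PATTERN = re.compile(
    r"^(?P<mission>S2[A-Z])_MSI(?P<level>L1C|L2A)_(?P<sensing>\d{8}T\d{6})"
    r"_N(?P<baseline>\d{4})_R(?P<orbit>\d{3})_T(?P<tile>[0-9A-Z]{5})_"
)

def parse_product_name(path):
    """Split a Sentinel-2 product file name into its named components."""
    name = os.path.basename(path)
    match = PRODUCT_PATTERN.match(name)
    if match is None:
        raise ValueError("not a recognised Sentinel-2 product name: %s" % name)
    info = match.groupdict()
    # Convert the sensing time stamp into a datetime object.
    info["sensing"] = datetime.strptime(info["sensing"], "%Y%m%dT%H%M%S")
    return info

info = parse_product_name(
    "Data/S2B_MSIL1C_20190823T103029_N0208_R108_T32UNB_20190823T124349.zip"
)
print(info["mission"], info["level"], info["tile"], info["sensing"].date())
```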

Extracting the files is not strictly necessary, as GDAL, and therefore DSen2, can also read zipped files. We do it here regardless, because the training procedure needs the extracted SAFE files.

Since extraction takes a while and can occasionally break things, we do not want to do it if we are not training. In that case, we can comment out the following code block.

If we extract, we then list the MSI XML files, which we use instead of the zip files. If it worked, we see one file for each previous zip file (again, 3 in this example).

In [0]:
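import zipfile
# Extract every downloaded archive into the Data directory
for zip_path in downloaded_files:
  with zipfile.ZipFile(zip_path, 'r') as zip_ref:
      zip_ref.extractall("Data")
# From here on, work with the MSI metadata XML files instead of the zip archives
downloaded_files=(glob.glob("Data/**/*MSI*.xml",recursive=True))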
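
Because extraction is slow, it can be worth skipping archives that were already unpacked in an earlier session. A minimal sketch, assuming each archive contains a single top-level folder (as the SAFE products here appear to); the helper `extract_if_needed` is hypothetical.

```python
import os
import zipfile

def extract_if_needed(zip_path, target_dir):
    """Extract zip_path into target_dir unless its top-level folder already exists.

    Returns True if files were extracted and False if extraction was skipped.
    """
    with zipfile.ZipFile(zip_path, "r") as archive:
        names = archive.namelist()
        if not names:
            return False
        # Assumption: the first entry lives inside the archive's top-level folder.
        top_level = names[0].split("/")[0]
        if os.path.exists(os.path.join(target_dir, top_level)):
            return False
        archive.extractall(target_dir)
    return True
```

Used in place of the plain `extractall` call, this would make re-running the cell cheap, at the cost of not repairing a partially extracted folder.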

### Chapter 3: Training

To perform the superresolution, our network needs some weights. Here, we have a number of options:

* We can use the **pretrained weights** [provided by the developers](https://github.com/lanha/DSen2/tree/f6b8790c28a136b0ee57f8d4aa801348efcb4b74/models), which have been extensively trained on a large number of Sentinel-2 images from many different geographical areas. These are perfectly fine to use, in which case this entire chapter can be skipped.


But maybe we are dealing with a particularly unusual region, for which the network was not trained. Then, it might make sense to train the network to also work well on this region. 
To fine-tune our model to our AOI, we can train it on the images we have just downloaded! 

* We can **train** from scratch, in which case the weights are randomly initialized using the [HeUniform](https://arxiv.org/abs/1502.01852) method.

* It is more advisable, however, to **transfer** from  the pretrained weights, as that will give us a very solid base to start from, and will therefore greatly reduce the required training time.

In this script, we will, as an example, train the 030 model which comes pretrained for the superresolution of 60m bands to 10m resolution, with two of our previously downloaded images.



>  At this point, a note: at the time of writing, the code from the developers' GitHub implements the procedure of the published study rather strictly.
This makes the study reproducible, but also means that many variables, parameters, and especially path names are hardcoded. This conflicts with the goal of our script, which is to execute the procedure flexibly and in a reasonable time. We therefore have to make some changes to the code, and the simplest way to do that is to execute it in code blocks. So this chapter will be fairly bulky.

We need to accomplish a number of tasks:

*   **Creating the patches**: From our previously downloaded files, we create a number of smaller image-patches.
*   **Creating the Train/Test split**: We create an index file, which indexes the previously created patches. Using this file, we can split the data into a training and a testing partition.
*   **Training**: We train the 030 Network using the patches.
*   **Replacing the original file**: We replace the original weight file. We do this to avoid any unnecessary changes in file paths for the remainder of this script.


#### Creating the Patch Dataset

First, we sample a number of patches from our image, which we can then use for training.
To quote the publication:
>Sentinel-2 images are too big to fit them into GPU memory for training and testing, and in fact it is unlikely than
long-range context over distances of a kilometer or more
plays any significant role for super-resolution at the 10 m
level. With this in mind, we train the network on small
patches of w×h = (32×32) for T2×, respectively (96×96)
pixels for S6×. We note that this corresponds to a receptive field of several hundred metres on the ground, sufficient to capture the local low-level texture and potentially
also small semantic structures such as individual buildings
or small waterbodies, but not large-scale topographic features. 


First, we set a couple of parameters.

Essential are:

*   **data_file**: The Sentinel-2 file from which patches are created. This can be either the original ZIP file, or the S2A[...].xml file in a SAFE directory extracted from that ZIP. When running from command line, this is the only necessary argument. However, it only accepts one file at a time. We change it so that it now accepts a list of files (see **data_files** below).

No output file is specified; the code automatically saves into a test folder. A path can be specified with **save_prefix** below.

Optional are:

* **roi_x_y**: A string which sets the region of interest (x_1,y_1,x_2,y_2) to extract as pixels locations on the 10m bands.
* **test_data**: If `true`, stores test patches in a separate directory. Not necessary for us.
* **save_prefix**: We can use this to add a prefix to the output file. This could also be a directory, as in `"Test/Outputs/"`. If we execute from command line, the default is to save into `"../data/"`, but we change this to `Data/Training/`.
* **write_images**: If `true`, writes PNG images for the original and the superresolved bands, together with a composite RGB image (first three 10m bands), all with a quick and dirty clipping to 99% of the original bands' dynamic range and a quantization of the values to 256 levels. Since our space on the drive is limited, let's not do that for now.
* **run_60**: If `true`, creates patches also from the 60m channels. This is what we want to do, so we set this to True.
* **true_data**: If `true`, creates patches for S2 without GT. This option is not really useful here, please check the testing folder for predicting S2 images.
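
How a **roi_x_y** string turns into a pixel window is easier to see in isolation. The sketch below mirrors the clamping and rounding arithmetic of the `readS2fromFile` function used in this chapter; the standalone helper `snap_roi` and the example numbers are assumptions for illustration.

```python
def snap_roi(roi_x_y, width, height, multiple=36):
    """Clamp an 'x1,y1,x2,y2' pixel ROI to the image and align it to a pixel grid.

    Both edges are rounded down to a multiple of `multiple`.
    Returns (xmin, ymin, xmax, ymax) with inclusive max coordinates.
    """
    x1, y1, x2, y2 = [float(v) for v in roi_x_y.split(",")]
    # Clamp to the valid pixel range of the 10m raster.
    xmin = max(min(x1, x2, width - 1), 0)
    xmax = min(max(x1, x2, 0), width - 1)
    ymin = max(min(y1, y2, height - 1), 0)
    ymax = min(max(y1, y2, 0), height - 1)
    # Align the window so that its size is a multiple of `multiple` pixels.
    xmin = int(xmin / multiple) * multiple
    xmax = int((xmax + 1) / multiple) * multiple - 1
    ymin = int(ymin / multiple) * multiple
    ymax = int((ymax + 1) / multiple) * multiple - 1
    return xmin, ymin, xmax, ymax

print(snap_roi("100,50,500,400", width=10980, height=10980))  # (72, 36, 467, 395)
```

In this example the resulting window is 396 pixels wide and 360 pixels high, both multiples of 36.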

We added/changed the following parameters:

*   **data_files**: The Sentinel-2 files from which patches are created. We get the path to the downloaded and extracted SAFE files, but only take the first two.

* **NR_CROP**: In the provided implementation, the number of patches is given by the default of NR_CROP in `save_random_patches`, which is 8000. To control the number of patches, we add this variable, which is passed on towards `save_random_patches`.
The result will be a file containing the patches which is output to the specified folder. We use 1500 here - a comparatively low number, but we want to save time.
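
Conceptually, drawing one random patch means picking a window in the low-resolution band and the matching window in the high-resolution band that covers the same ground area. The following is a sketch of that alignment only, not of `save_random_patches`; the helper name and the 16-pixel low-resolution patch size are assumptions for illustration.

```python
import random

def sample_aligned_windows(lr_height, lr_width, lr_patch, scale, rng):
    """Pick a random low-res window and the high-res window over the same area.

    Returns two (row, col, size) tuples: low-res first, high-res second.
    """
    if lr_patch > lr_height or lr_patch > lr_width:
        raise ValueError("patch is larger than the image")
    row = rng.randrange(0, lr_height - lr_patch + 1)
    col = rng.randrange(0, lr_width - lr_patch + 1)
    # The high-res window starts at the scaled offset and is scale times larger.
    return (row, col, lr_patch), (row * scale, col * scale, lr_patch * scale)

rng = random.Random(42)
lr_win, hr_win = sample_aligned_windows(100, 100, lr_patch=16, scale=2, rng=rng)
print(lr_win, hr_win)
```

With a scale of 2 the high-resolution window is 32 pixels wide, and with a scale of 6 it is 96 pixels wide, which matches the patch sizes quoted from the publication under the 16-pixel assumption.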








In [0]:
training_files=glob.glob("Data/*SAFE",recursive=True)[0:2]
training_files

['Data/S2B_MSIL1C_20190823T103029_N0208_R108_T32UNB_20190823T124349.SAFE',
 'Data/S2B_MSIL1C_20190823T103029_N0208_R108_T32UMB_20190823T124349.SAFE']

In [0]:
data_files=training_files
test_data=False
roi_x_y=""
save_prefix="Data/Training/"
write_images=False
run_60=True
true_data=False
NR_CROP=1500

The bulk of the work is done by the function below. We have modified it slightly compared to the source code, to include the `NR_CROP` argument.
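
One detail worth isolating before reading the full function is how band descriptions are reduced to short names such as `B8A`, which are then matched against the list of selected bands. The sketch below follows the two nested helpers of the function; the example description string is an assumption modelled on the regular expression, and the comma handling is slightly more defensive than in the original.

```python
import re

def validate_description(description):
    """Rewrite 'B8A, central wavelength 865 nm' as 'B8A (865 nm)'."""
    match = re.match(r"(.*?), central wavelength (\d+) nm", description)
    if match:
        return "%s (%s nm)" % (match.group(1), match.group(2))
    # Otherwise drop the first comma, if there is one.
    return description.replace(",", "", 1)

def get_band_short_name(description):
    """Return the text before the first comma or, failing that, the first space."""
    for separator in (",", " "):
        if separator in description:
            return description[:description.find(separator)]
    return description[:3]

print(get_band_short_name(validate_description("B8A, central wavelength 865 nm")))  # B8A
```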

In [0]:
#@title

from __future__ import division
import argparse
import numpy as np
from osgeo import gdal
import sys
from collections import defaultdict
import re
import os
import imageio
import json
sys.path.append('../')
from DSen2.utils.patches import downPixelAggr, save_test_patches, save_random_patches, save_random_patches60, save_test_patches60

data_filename = '/MTD_MSIL1C.xml'

# sleep(randint(0, 20))

def readS2fromFile(data_file,
                   test_data=False,
                   roi_x_y=None,
                   save_prefix="../data/",
                   write_images=False,
                   run_60=False,
                   true_data=False,
                   NR_CROP=8000):

    if run_60:
        select_bands = "B1,B2,B3,B4,B5,B6,B7,B8,B8A,B9,B11,B12"
    else:
        select_bands = "B2,B3,B4,B5,B6,B7,B8,B8A,B11,B12"

    raster = gdal.Open(data_file + data_filename)

    datasets = raster.GetSubDatasets()
    tenMsets = []
    twentyMsets = []
    sixtyMsets = []
    unknownMsets = []
    for (dsname, dsdesc) in datasets:
        if '10m resolution' in dsdesc:
            tenMsets += [ (dsname, dsdesc) ]
        elif '20m resolution' in dsdesc:
            twentyMsets += [ (dsname, dsdesc) ]
        elif '60m resolution' in dsdesc:
            sixtyMsets += [ (dsname, dsdesc) ]
        else:
            unknownMsets += [ (dsname, dsdesc) ]

    if roi_x_y:
        # parse the ROI from the function argument (there is no argparse namespace here)
        roi_x1, roi_y1, roi_x2, roi_y2 = [float(x) for x in re.split(',', roi_x_y)]

    # case where we have several UTM in the data set
    # => select the one with maximal coverage of the study zone
    utm_idx = 0
    utm = ""
    all_utms = defaultdict(int)
    xmin, ymin, xmax, ymax = 0, 0, 0, 0
    largest_area = -1
    # process even if there is only one 10m set, in order to get roi -> pixels
    for (tmidx, (dsname, dsdesc)) in enumerate(tenMsets + unknownMsets):
        ds = gdal.Open(dsname)
        if roi_x_y:
            tmxmin = max(min(roi_x1, roi_x2, ds.RasterXSize - 1), 0)
            tmxmax = min(max(roi_x1, roi_x2, 0), ds.RasterXSize - 1)
            tmymin = max(min(roi_y1, roi_y2, ds.RasterYSize - 1), 0)
            tmymax = min(max(roi_y1, roi_y2, 0), ds.RasterYSize - 1)
            # align to a multiple of 36 pixels for the super-resolution (both edges are rounded down)
            tmxmin = int(tmxmin / 36) * 36
            tmxmax = int((tmxmax + 1) / 36) * 36 - 1
            tmymin = int(tmymin / 36) * 36
            tmymax = int((tmymax + 1) / 36) * 36 - 1
        else:
            tmxmin = 0
            tmxmax = ds.RasterXSize - 1
            tmymin = 0
            tmymax = ds.RasterYSize - 1

        area = (tmxmax - tmxmin + 1) * (tmymax - tmymin + 1)
        current_utm = dsdesc[dsdesc.find("UTM"):]
        if area > all_utms[current_utm]:
            all_utms[current_utm] = area
        if area > largest_area:
            xmin, ymin, xmax, ymax = tmxmin, tmymin, tmxmax, tmymax
            largest_area = area
            utm_idx = tmidx
            utm = dsdesc[dsdesc.find("UTM"):]

    # convert comma separated band list into a list
    select_bands = [x for x in re.split(',',select_bands) ]

    print("Selected UTM Zone: {}".format(utm))
    print("Selected pixel region: xmin=%d, ymin=%d, xmax=%d, ymax=%d:" % (xmin, ymin, xmax, ymax))
    print("Selected pixel region: tmxmin=%d, tmymin=%d, tmxmax=%d, tmymax=%d:" % (tmxmin, tmymin, tmxmax, tmymax))
    print("Image size: width=%d x height=%d" % (xmax - xmin + 1, ymax - ymin + 1))

    if xmax < xmin or ymax < ymin:
        print("Invalid region of interest / UTM Zone combination")
        sys.exit(0)

    selected_10m_data_set = None
    if not tenMsets:
        selected_10m_data_set = unknownMsets[0]
    else:
        selected_10m_data_set = tenMsets[utm_idx]
    selected_20m_data_set = None
    # iterate over the (name, description) pairs directly; enumerate would yield (index, pair)
    for (dsname, dsdesc) in twentyMsets:
        if utm in dsdesc:
            selected_20m_data_set = (dsname, dsdesc)
    # if not found, assume the listing is in the same order
    # => OK if only one set
    if not selected_20m_data_set: selected_20m_data_set = twentyMsets[utm_idx]
    selected_60m_data_set = None
    # iterate over the (name, description) pairs directly; enumerate would yield (index, pair)
    for (dsname, dsdesc) in sixtyMsets:
        if utm in dsdesc:
            selected_60m_data_set = (dsname, dsdesc)
    if not selected_60m_data_set: selected_60m_data_set = sixtyMsets[utm_idx]

    ds10 = gdal.Open(selected_10m_data_set[0])
    ds20 = gdal.Open(selected_20m_data_set[0])
    ds60 = gdal.Open(selected_60m_data_set[0])

    def validate_description(description):
        m = re.match(r"(.*?), central wavelength (\d+) nm", description)
        if m:
            return m.group(1) + " (" + m.group(2) + " nm)"
        # Some HDR restrictions... ENVI band names should not include commas

        # drop the first comma, if any (str.find would return -1 when there is none)
        return description.replace(',', '', 1)

    def get_band_short_name(description):
        if ',' in description:
            return description[:description.find(',')]
        if ' ' in description:
            return description[:description.find(' ')]
        return description[:3]

    validated_10m_bands = []
    validated_10m_indices = []
    validated_20m_bands = []
    validated_20m_indices = []
    validated_60m_bands = []
    validated_60m_indices = []
    validated_descriptions = defaultdict(str)

    sys.stdout.write("Selected 10m bands:")
    for b in range(0, ds10.RasterCount):
        desc = validate_description(ds10.GetRasterBand(b + 1).GetDescription())
        shortname = get_band_short_name(desc)
        if shortname in select_bands:
            sys.stdout.write(" " + shortname)
            select_bands.remove(shortname)
            validated_10m_bands += [shortname]
            validated_10m_indices += [b]
            validated_descriptions[shortname] = desc
    sys.stdout.write("\nSelected 20m bands:")
    for b in range(0, ds20.RasterCount):
        desc = validate_description(ds20.GetRasterBand(b + 1).GetDescription())
        shortname = get_band_short_name(desc)
        if shortname in select_bands:
            sys.stdout.write(" " + shortname)
            select_bands.remove(shortname)
            validated_20m_bands += [shortname]
            validated_20m_indices += [b]
            validated_descriptions[shortname] = desc
    sys.stdout.write("\nSelected 60m bands:")
    for b in range(0, ds60.RasterCount):
        desc = validate_description(ds60.GetRasterBand(b + 1).GetDescription())
        shortname = get_band_short_name(desc)
        if shortname in select_bands:
            sys.stdout.write(" " + shortname)
            select_bands.remove(shortname)
            validated_60m_bands += [shortname]
            validated_60m_indices += [b]
            validated_descriptions[shortname] = desc
    sys.stdout.write("\n")

    if validated_10m_indices:
        print("Loading selected data from: %s" % selected_10m_data_set[1])
        data10 = np.rollaxis(
            ds10.ReadAsArray(xoff=xmin, yoff=ymin, xsize=xmax - xmin + 1, ysize=ymax - ymin + 1, buf_xsize=xmax - xmin + 1,
                             buf_ysize=ymax - ymin + 1), 0, 3)[:, :, validated_10m_indices]

    if validated_20m_indices:
        print("Loading selected data from: %s" % selected_20m_data_set[1])
        data20 = np.rollaxis(
            ds20.ReadAsArray(xoff=xmin // 2, yoff=ymin // 2, xsize=(xmax - xmin + 1) // 2, ysize=(ymax - ymin + 1) // 2,
                             buf_xsize=(xmax - xmin + 1) // 2, buf_ysize=(ymax - ymin + 1) // 2), 0, 3)[:, :,
                 validated_20m_indices]

    if validated_60m_indices:
        print("Loading selected data from: %s" % selected_60m_data_set[1])
        data60 = np.rollaxis(
            ds60.ReadAsArray(xoff=xmin // 6, yoff=ymin // 6, xsize=(xmax - xmin + 1) // 6, ysize=(ymax - ymin + 1) // 6,
                             buf_xsize=(xmax - xmin + 1) // 6, buf_ysize=(ymax - ymin + 1) // 6), 0, 3)[:, :,
                 validated_60m_indices]

    # The percentile_data argument is used to plot superresolved and original data
    # with a comparable black/white scale
    def save_band(data, name, percentile_data=None):
        if percentile_data is None:
            percentile_data = data
        mi, ma = np.percentile(percentile_data, (1, 99))
        band_data = np.clip(data, mi, ma)
        band_data = (band_data - mi) / (ma - mi)
        imageio.imsave(save_prefix + name + ".png", band_data)  # img_as_uint(band_data))

    chan3 = data10[:, :, 0]
    vis = (chan3 < 1).astype(int)  # np.int is deprecated; the builtin int behaves the same here
    if np.sum(vis) > 0:
        print('The selected image has some blank pixels')
        # sys.exit()

    scale20 = 2
    scale60 = 6

    data10_gt = data10
    data20_gt = data20

    if not true_data:
        if run_60:
            data60_gt = data60
            data10_lr = downPixelAggr(data10_gt, SCALE=scale60)
            data20_lr = downPixelAggr(data20_gt, SCALE=scale60)
            data60_lr = downPixelAggr(data60_gt, SCALE=scale60)
        else:
            data10_lr = downPixelAggr(data10_gt, SCALE=scale20)
            data20_lr = downPixelAggr(data20_gt, SCALE=scale20)
            if scale20 > 2:
                data20_lr = downPixelAggr(data20_gt, SCALE=scale20//2)

    if data_file.endswith('/'):
        tmp = os.path.split(data_file)[0]
        data_file = os.path.split(tmp)[1]
    else:
        data_file = os.path.split(data_file)[1]
    print(data_file)

    if test_data:
        if run_60:
            out_per_image0 = save_prefix + 'test60/'
            out_per_image = save_prefix + 'test60/' + data_file + '/'
        else:
            out_per_image0 = save_prefix + 'test/'
            out_per_image = save_prefix + 'test/' + data_file + '/'
        if not os.path.isdir(out_per_image0):
            os.mkdir(out_per_image0)
        if not os.path.isdir(out_per_image):
            os.mkdir(out_per_image)

        print('Writing files for testing to:{}'.format(out_per_image))
        if run_60:
            save_test_patches60(data10_lr, data20_lr, data60_lr, out_per_image)
            with open(out_per_image + 'roi.json', 'w') as f:
                json.dump([tmxmin // scale60, tmymin // scale60, (tmxmax + 1) // scale60, (tmymax + 1) // scale60], f)
        else:
            save_test_patches(data10_lr, data20_lr, out_per_image)
            with open(out_per_image + 'roi.json', 'w') as f:
                json.dump([tmxmin // scale20, tmymin // scale20, (tmxmax+1) // scale20, (tmymax+1) // scale20], f)

        if not os.path.isdir(out_per_image + 'no_tiling/'):
            os.mkdir(out_per_image + 'no_tiling/')

        print("Now saving the whole image without tiling...")
        if run_60:
            np.save(out_per_image + 'no_tiling/' + 'data60_gt', data60_gt.astype(np.float32))
            np.save(out_per_image + 'no_tiling/' + 'data60', data60_lr.astype(np.float32))
        else:
            np.save(out_per_image + 'no_tiling/' + 'data20_gt', data20_gt.astype(np.float32))
            save_band(data10_lr[:, :, 0:3], '/test/' + data_file + '/RGB')
        np.save(out_per_image + 'no_tiling/' + 'data10', data10_lr.astype(np.float32))
        np.save(out_per_image + 'no_tiling/' + 'data20', data20_lr.astype(np.float32))

    elif write_images:
        print('Creating RGB images...')
        save_band(data10_lr[:, :, 0:3], '/raw/rgbs/' + data_file + 'RGB')
        save_band(data20_lr[:, :, 0:3], '/raw/rgbs/' + data_file + 'RGB20')

    elif true_data:
        out_per_image0 = save_prefix + 'true/'
        out_per_image = save_prefix + 'true/' + data_file + '/'
        if not os.path.isdir(out_per_image0):
            os.mkdir(out_per_image0)
        if not os.path.isdir(out_per_image):
            os.mkdir(out_per_image)

        print('Writing files for testing to:{}'.format(out_per_image))
        save_test_patches60(data10_gt, data20_gt, data60_gt, out_per_image, patchSize=384, border=12)

        with open(out_per_image + 'roi.json', 'w') as f:
            json.dump([tmxmin, tmymin, tmxmax+1, tmymax+1], f)

        if not os.path.isdir(out_per_image + 'no_tiling/'):
            os.mkdir(out_per_image + 'no_tiling/')

        print("Now saving the whole image without tiling...")
        np.save(out_per_image + 'no_tiling/' + 'data10', data10_gt.astype(np.float32))
        np.save(out_per_image + 'no_tiling/' + 'data20', data20_gt.astype(np.float32))
        np.save(out_per_image + 'no_tiling/' + 'data60', data60_gt.astype(np.float32))

    else:
        if run_60:
            out_per_image0 = save_prefix + 'train60/'
            out_per_image = save_prefix + 'train60/' + data_file + '/'
        else:
            out_per_image0 = save_prefix + 'train/'
            out_per_image = save_prefix + 'train/' + data_file + '/'
        if not os.path.isdir(out_per_image0):
            os.mkdir(out_per_image0)
        if not os.path.isdir(out_per_image):
            os.mkdir(out_per_image)
        print('Writing files for training to:{}'.format(out_per_image))
        if run_60:
            save_random_patches60(data60_gt, data10_lr, data20_lr, data60_lr, out_per_image,NR_CROP)
        else:
            save_random_patches(data20_gt, data10_lr, data20_lr, out_per_image,NR_CROP)

    print("Success.")
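
The branches above mostly decide which folder the patches are written to. As a minimal sketch of that logic, assuming we ignore the `write_images` branch (which only writes RGB previews): `output_dir_for` and `ensure_dir` are hypothetical helpers and not part of DSen2.

```python
import os

def output_dir_for(save_prefix, data_file, test_data=False, run_60=False, true_data=False):
    """Return the per-image output folder (hypothetical helper, not part of DSen2)."""
    if test_data:
        sub = 'test60' if run_60 else 'test'
    elif true_data:
        sub = 'true'
    else:
        sub = 'train60' if run_60 else 'train'
    # Strip a trailing slash so that basename returns the .SAFE folder name
    name = os.path.basename(data_file.rstrip('/'))
    return os.path.join(save_prefix, sub, name) + '/'

def ensure_dir(path):
    # os.makedirs creates missing parent folders and tolerates existing ones
    os.makedirs(path, exist_ok=True)
    return path

print(output_dir_for('Data/Training/', 'x/S2B_TILE.SAFE/', run_60=True))
```

Using `os.makedirs(..., exist_ok=True)` would replace the repeated `os.path.isdir` / `os.mkdir` pairs with a single call.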



We run the function in a loop over our files. With only two files this is hardly necessary, but it shows that the approach works in principle.

In [0]:
for i in range(0, len(data_files)):
  data_file=data_files[i]
  readS2fromFile(data_file,
                test_data,
                roi_x_y,
                save_prefix,
                write_images,
                run_60,
                true_data,
                NR_CROP)

Selected UTM Zone:
Selected pixel region: xmin=0, ymin=0, xmax=10979, ymax=10979:
Selected pixel region: tmxmin=0, tmymin=0, tmxmax=10979, tmymax=10979:
Image size: width=10980 x height=10980
Selected 10m bands: B4 B3 B2 B8
Selected 20m bands: B5 B6 B7 B8A B11 B12
Selected 60m bands: B1 B9
Loading selected data from: Bands B2, B3, B4, B8 with 10m resolution, UTM 32N
Loading selected data from: Bands B5, B6, B7, B8A, B11, B12 with 20m resolution, UTM 32N
Loading selected data from: Bands B1, B9, B10 with 60m resolution, UTM 32N
S2B_MSIL1C_20190823T103029_N0208_R108_T32UNB_20190823T124349.SAFE
Writing files for training to:Data/Training/train60/S2B_MSIL1C_20190823T103029_N0208_R108_T32UNB_20190823T124349.SAFE/
(1500, 2, 96, 96)
(1500, 4, 96, 96)
(1500, 6, 48, 48)
(1500, 2, 16, 16)
Done!
Success.
Selected UTM Zone:
Selected pixel region: xmin=0, ymin=0, xmax=10979, ymax=10979:
Selected pixel region: tmxmin=0, tmymin=0, tmxmax=10979, tmymax=10979:
Image size: width=10980 x height=10980
Sel

####Indexing and Splitting the Dataset

We create a simple index of the patches we have, and split it randomly.

Normally this is done by `create_random.py`, which uses hardcoded paths. To control the paths, we run its code here as a code block and point it to the location where we saved the patches.

Variables which we must use here are:


*   **size**: The number of files we processed (two, in our example) multiplied by the number of patches per tile (same as NR_CROP in the previous section)
*   **ratio**: The ratio of validation files. We set it to the suggested default of 0.1
* **path**: This is the path to the directory which contains the files we created in the previous block. So for us, it is the same as our previous `save_prefix`, which was `Data/Training/`



In [0]:
from random import randrange
# Size: number of S2 tiles (times) patches per tile
size = len(data_files)*NR_CROP
ratio = .1
nb = int(size * ratio)
path = save_prefix


index = np.zeros(size).astype(bool)  # np.bool was removed from recent NumPy versions
i = 0
while np.sum(index.astype(int)) < nb:
    x = randrange(0, size)
    index[x] = True
    i += 1
    
if run_60:
  path=path+"train60/"
else:
  path=path+"train/"
np.save(path + 'val_index.npy', index)

print('Full no of samples: {}'.format(size))
print('Validation samples: {}'.format(np.sum(index.astype(int))))

print("Number of iterations: {}".format(i))

Full no of samples: 3000
Validation samples: 300
Number of iterations: 320


We can now see how many samples are available for training. In the run shown above, two images with 1500 patches each yield 3000 patches, of which 300 are set aside for validation.
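
The loop above draws positions with replacement, which is why it needed more iterations (320) than validation samples (300) in the run shown. A minimal alternative sketch, assuming a NumPy version that provides `np.random.default_rng`; `make_val_index` is a hypothetical helper and not part of DSen2:

```python
import numpy as np

def make_val_index(size, ratio=0.1, seed=None):
    """Boolean mask with exactly int(size * ratio) True entries (sketch)."""
    rng = np.random.default_rng(seed)
    nb = int(size * ratio)
    index = np.zeros(size, dtype=bool)
    # choice(..., replace=False) never draws the same position twice
    index[rng.choice(size, size=nb, replace=False)] = True
    return index

index = make_val_index(3000, ratio=0.1, seed=0)
print(index.sum())
```

Passing a fixed `seed` makes the split reproducible between runs, which the original loop does not guarantee.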

####Training
Now we have the resources we need for the training itself.
We could run `supres_train.py` directly, but again we open up the code to examine it and to reduce the number of epochs.

Setting variables

*   **predict_file**: If we provide a file here, we can do a single prediction. This is presumably meant for testing.
*   **resume_file**:  Path to the previously created weights file to resume from. We want to train the 030 model for 60m data, which is why we use the `s2_030_lr_1e-05.hdf5` from the model folder - otherwise we would get the 032 model `s2_032_lr_1e-04.hdf5`
*   **true**: Use true scale data? No simulation or different resolutions.
*   **run_60**: Whether to run a 60->10m network. Default 20->10m.
*   **deep**: Train a deep network. This takes too long, so let's not do that here.

We added here one more variable:

* **n_epochs**: The number of epochs used for training. Fixed to 8192 in the provided implementation. For our testing purposes, a handful is enough; the cell below uses 5.

At this point, we could also adjust the following parameters:


*  **model_nr**: The name of the output model. Must be 7 characters. Here, we go with the default of "s2_038_"
*   **SCALE**: A factor by which the raw reflectance values are divided for "numerical stability".
*   **lr**: The base learning rate.

We will leave those unchanged.
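
The 7-character requirement for **model_nr** follows from how the training code recovers the model number from a weights file name: it slices `resume_file[-20:-13]`. A small sketch of that slicing on the two file names used in this notebook; `model_prefix` is a hypothetical helper name:

```python
def model_prefix(weights_path):
    """Return the leading 7-character model number of a DSen2 weights file name."""
    # Same slice as in the training code: the file name is 20 characters long,
    # so [-20:-13] keeps its first 7 characters.
    return weights_path[-20:-13]

print(model_prefix("DSen2/models/s2_030_lr_1e-05.hdf5"))
print(model_prefix("DSen2/models/s2_032_lr_1e-04.hdf5"))
```

The slice only works if the file name keeps this exact length, so renaming the weights files would silently produce a wrong prefix.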

In [0]:
predict_file=None
if(run_60):
  resume_file="DSen2/models/s2_030_lr_1e-05.hdf5"
else:
  resume_file="DSen2/models/s2_032_lr_1e-04.hdf5"
true=True
run_60=run_60 #from previous code block
deep=False
path="Data/Training/"
n_epochs=5

In [0]:
from __future__ import division
import numpy as np
import datetime
import glob
import time
import argparse
import os
import sys
import matplotlib as mpl
mpl.use('Agg')
from keras.optimizers import Nadam
from keras.callbacks import ModelCheckpoint, Callback, ReduceLROnPlateau
from keras.utils import plot_model
import keras.backend as K

sys.path.append('../')
from DSen2.utils.patches import recompose_images, OpenDataFilesTest, OpenDataFiles
from DSen2.utils.DSen2Net import s2model

K.set_image_data_format('channels_first')

# Define file prefix for new training, must be 7 characters of this form:
model_nr = 's2_038_'
SCALE = 2000
lr = 1e-4


path = path
if not os.path.isdir(path):
    os.mkdir(path)
out_path = path+"network_data/"
if not os.path.isdir(out_path):
    os.mkdir(out_path)

We also create a callback which allows us to track how our losses are progressing - this will help us judge the performance of our network.

In [0]:
#@title
class PlotLosses(Callback):
    def __init__(self, model_nr, lr):
        self.model_nr = model_nr
        self.lr = lr

    def on_train_begin(self, logs=None):
        self.losses = []
        self.val_losses = []
        self.i = 0
        self.x = []
        self.filename = out_path + self.model_nr + '_lr_{:.1e}.txt'.format(self.lr)
        open(self.filename, 'w').close()

    def on_epoch_end(self, epoch, logs=None):
        import matplotlib.pyplot as plt
        plt.ioff()

        lr = float(K.get_value(self.model.optimizer.lr))
        # data = np.loadtxt("training.log", skiprows=1, delimiter=',')
        self.losses.append(logs.get('loss'))
        self.val_losses.append(logs.get('val_loss'))
        self.x.append(self.i)
        self.i += 1
        try:
            with open(self.filename, 'a') as self.f:
                self.f.write('Finished epoch {:5d}: loss {:.3e}, valid: {:.3e}, lr: {:.1e}\n'
                             .format(epoch, logs.get('loss'), logs.get('val_loss'), lr))

            if epoch > 500:
                plt.clf()
                plt.plot(self.x[475:], self.losses[475:], label='loss')
                plt.plot(self.x[475:], self.val_losses[475:], label='val_loss')
                plt.legend()
                plt.xlabel('epochs')
                # plt.waitforbuttonpress(0)
                plt.savefig(out_path + self.model_nr + '_loss4.png')
            elif epoch > 250:
                plt.clf()
                plt.plot(self.x[240:], self.losses[240:], label='loss')
                plt.plot(self.x[240:], self.val_losses[240:], label='val_loss')
                plt.legend()
                plt.xlabel('epochs')
                # plt.waitforbuttonpress(0)
                plt.savefig(out_path + self.model_nr + '_loss3.png')
            elif epoch > 100:
                plt.clf()
                plt.plot(self.x[85:], self.losses[85:], label='loss')
                plt.plot(self.x[85:], self.val_losses[85:], label='val_loss')
                plt.legend()
                plt.xlabel('epochs')
                # plt.waitforbuttonpress(0)
                plt.savefig(out_path + self.model_nr + '_loss2.png')
            elif epoch > 50:
                plt.clf()
                plt.plot(self.x[50:], self.losses[50:], label='loss')
                plt.plot(self.x[50:], self.val_losses[50:], label='val_loss')
                plt.legend()
                plt.xlabel('epochs')
                # plt.waitforbuttonpress(0)
                plt.savefig(out_path + self.model_nr + '_loss1.png')
            else:
                plt.clf()
                plt.plot(self.x[0:], self.losses[0:], label='loss')
                plt.plot(self.x[0:], self.val_losses[0:], label='val_loss')
                plt.legend()
                plt.xlabel('epochs')
                # plt.waitforbuttonpress(0)
                plt.savefig(out_path + self.model_nr + '_loss0.png')
        except IOError:
            print('Network drive unavailable.')
            print(datetime.datetime.now().time())
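
The callback also appends one line per epoch to a text file in `out_path`. A minimal sketch for reading such a log back into numbers, for example for custom plots; `parse_loss_log` and `LINE_RE` are hypothetical names and not part of DSen2, and the sample line is generated with the same format string as in the callback:

```python
import re

# Matches lines written by PlotLosses.on_epoch_end
LINE_RE = re.compile(
    r"Finished epoch\s+(\d+): loss ([\d.e+-]+), valid: ([\d.e+-]+), lr: ([\d.e+-]+)"
)

def parse_loss_log(lines):
    """Return a list of (epoch, loss, val_loss, lr) tuples from log lines."""
    rows = []
    for line in lines:
        m = LINE_RE.search(line)
        if m:
            epoch = int(m.group(1))
            loss, val_loss, lr = (float(g) for g in m.groups()[1:])
            rows.append((epoch, loss, val_loss, lr))
    return rows

# Sample line built with the callback's own format string (values are made up)
sample = 'Finished epoch {:5d}: loss {:.3e}, valid: {:.3e}, lr: {:.1e}\n'.format(
    3, 0.0123, 0.0145, 1e-4)
rows = parse_loss_log([sample])
print(rows)
```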

First we create the model using the `s2model` function. 
We choose the [nadam](https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Nadam) optimizer and compile our model.

We load the weights from the previously set `resume_file`, and then the data.

Then we can start the training! This should take about 3-5 minutes with our chosen settings, so now is a good time to take a break.


In [0]:
if path is not None:
    path = path

# input_shape = ((4,32,32),(6,16,16))
if run_60:
    input_shape = ((4, None, None), (6, None, None), (2, None, None))
else:
    input_shape = ((4, None, None), (6, None, None))
# create model
if deep:
    model = s2model(input_shape, num_layers=32, feature_size=256)
    batch_size = 8
else:
    model = s2model(input_shape, num_layers=6, feature_size=128)
    batch_size = 128
print('Symbolic Model Created.')

nadam = Nadam(lr=lr,
              beta_1=0.9,
              beta_2=0.999,
              epsilon=1e-8,
              schedule_decay=0.004)
              # clipvalue=0.000005)

model.compile(optimizer=nadam, loss='mean_absolute_error', metrics=['mean_squared_error'])
print('Model compiled.')
model.count_params()
# model.summary()

if predict_file:
    if true:
        folder = 'true/'
        border = 12
    elif run_60:
        folder = 'test60/'
        border = 12
    else:
        folder = 'test/'
        border = 4
    model_nr = predict_file[-20:-13]
    print('Changing the model number to: {}'.format(model_nr))
    model.load_weights(predict_file)  # `args` does not exist in the notebook; use the variable set above
    print("Predicting using file: {}".format(predict_file))
    fileList = [os.path.basename(x) for x in sorted(glob.glob(path + folder + '*SAFE'))]
    print("Using patches from the files:")
    print(fileList)
    for dset in fileList:
        start = time.time()
        print("Timer started.")
        print("Predicting: {}.".format(dset))
        train, image_size = OpenDataFilesTest(path + folder + dset, run_60, SCALE, true)
        prediction = model.predict(train,
                                    batch_size=8,
                                    verbose=1)
        prediction_file = model_nr + '-predict'
        # np.save(path + 'test/' + dset + '/' + prediction_file + 'pat', prediction * SCALE)
        images = recompose_images(prediction, border=border, size=image_size)
        print('Writing to file...')
        np.save(path + folder + dset + '/' + prediction_file, images * SCALE)
        end = time.time()
        print('Elapsed time: {}.'.format(end - start))
    sys.exit(0)

if resume_file:
    print("Will resume from the weights {}".format(resume_file))
    model.load_weights(resume_file)
    model_nr = resume_file[-20:-13]
    print('Changing the model number to: {}'.format(model_nr))

else:
    print('Model number is {}'.format(model_nr))
    plot_model(model, to_file=out_path + model_nr+'model.png', show_shapes=True, show_layer_names=True)

    model_yaml = model.to_yaml()
    with open(out_path + model_nr + "model.yaml", 'w') as yaml_file:
        yaml_file.write(model_yaml)

filepath = out_path + model_nr + 'lr_{:.0e}.hdf5'.format(lr)
checkpoint = ModelCheckpoint(filepath,
                              monitor='val_loss',
                              verbose=1,
                              save_best_only=True,
                              save_weights_only=False,
                              mode='auto')
plot_losses = PlotLosses(model_nr, lr)
LRreducer = ReduceLROnPlateau(monitor='val_loss',
                              factor=0.5,
                              patience=5,
                              verbose=1,
                              epsilon=1e-6,
                              cooldown=20,
                              min_lr=1e-5)

callbacks_list = [checkpoint, plot_losses, LRreducer]

print('Loading the training data...from...')
print(path)
train, label, val_tr, val_lb = OpenDataFiles(path,  run_60, SCALE)
print('Training starts...')

model.fit(x=train,
          y=label,
          batch_size=batch_size,
          epochs=n_epochs,
          verbose=1,
          callbacks=callbacks_list,
          validation_split=0.,
          validation_data=(val_tr, val_lb),
          shuffle=True,
          class_weight=None,
          sample_weight=None,
          initial_epoch=0,
          validation_steps=None)

Via the callbacks, statistics are collected during training and plotted to an image. We can take a look at it to examine the progress.

In [0]:
%matplotlib inline
from IPython.display import Image
Image(out_path + model_nr + '_loss0.png')  # model_nr is 's2_030_' or 's2_032_', depending on run_60

Hopefully, the validation loss decreases steadily and then slowly flattens out. But since we use few epochs and little training data, the curve may well look erratic.

####Replacing the original

If we are happy with the result, we replace the original weights file with the newly created one. Not very elegant - but it allows us to run the next section without any changes, regardless of whether we skipped the training or not.

In [0]:
from shutil import copyfile
if(run_60):
  copyfile(out_path+"s2_030_lr_1e-04.hdf5", "DSen2/models/s2_030_lr_1e-05.hdf5")
else:
  copyfile(out_path+"s2_032_lr_1e-04.hdf5", "DSen2/models/s2_032_lr_1e-04.hdf5")
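
Since overwriting the shipped weights cannot be undone, it may be worth keeping a backup first. A minimal sketch under that assumption; `replace_weights` and the `.orig` suffix are hypothetical and not part of DSen2:

```python
import os
import shutil

def replace_weights(new_file, target_file, backup_suffix=".orig"):
    """Copy new_file over target_file, keeping one backup of the original (sketch)."""
    backup = target_file + backup_suffix
    # Only back up once, so repeated runs do not overwrite the pristine weights
    if os.path.exists(target_file) and not os.path.exists(backup):
        shutil.copy2(target_file, backup)
    shutil.copy2(new_file, target_file)
    return backup

# Example call (paths as used in this notebook, commented out here):
# replace_weights(out_path + "s2_030_lr_1e-04.hdf5", "DSen2/models/s2_030_lr_1e-05.hdf5")
```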

Now we are done with training.

What we have achieved in this section is this: we have tuned our model to perform better at 60->10m superresolution on our downloaded tiles and, therefore, on our AOI. At least we hope that this is the case. Either way, we can now move on to the superresolution process itself!

### Chapter 4: Superresolution

Now we are ready to execute the superresolution algorithm.

We want to superresolve both the 60m and the 20m bands, as well as keep the original 10m bands, to get a full stack at 10m resolution. To keep the processing time short, we will restrict ourselves to a small subset of the data.

By default, DSen2 is called via the command line, with the command line arguments being passed to the function `s2_tiles_supres`.

This does not tell us much about the inner workings, so we will dissect that function and execute its components step-by-step.

#### Setting up DSen2

We import the DSen2 implementation.

In [0]:
import sys
sys.path.append("DSen2/testing")
sys.path.append("DSen2")

We import a couple of libraries which we need. DSen2-Supres itself depends on tensorflow, so this gets imported as well.

In [0]:
from __future__ import division
import argparse
import numpy as np
import os
import re
import sys
from osgeo import gdal, osr
from collections import defaultdict
from supres import DSen2_20, DSen2_60

#### Setting the parameters

Here, we set a couple of parameters.
Essential are:



*   **data_file**: The file to run the prediction on. For S2 imagery, this filename should point to the XML file in the top-level directory.
*   **output_file**: The output filename. Tif is recommended, but other common types are supposedly implemented as well *(I did not test that so far)*. See **list_output_file_formats** and **output_file_format** below.

Optional are:

* **roi_lon_lat**: Sets the region of interest to extract, WGS84, decimal notation. lon_1,lat_1,lon_2,lat_2.
* **roi_x_y**: Sets the region of interest to extract as pixel locations on the 10m bands. Use this syntax: x_1,y_1,x_2,y_2. To speed things up, we use this parameter to work only on a small subset of the data (1500 by 1500 pixels).
* **list_bands**: If `True`, lists bands in the input file subdata set matching the selected UTM zone. When run via command line, the script then exits, but that is not what we do here.
* **run_60**: If `True`, super-resolve the 20m **and** 60m bands (B1,B2,B3,B4,B5,B6,B7,B8,B8A,B9,B11,B12). Otherwise, only super-resolve the 20m bands (B2,B3,B4,B5,B6,B7,B8,B8A,B11,B12). Note: "Band B10 is too noisy and is not super-resolved".
* **list_UTM**: If `True`, list all UTM zones present in the input file, together with their coverage of the ROI in 10m x 10m pixels.
* **select_UTM**: Manually select a UTM zone to use - otherwise the algorithm will select the UTM zone with the best coverage.
* **list_output_file_formats**: If `True`, lists all output file formats supported by GDAL.
* **output_file_format**: Specifies the name of a GDAL driver that supports file creation, like ENVI or GTiff. We use the proven GTiff here.

* **copy_original_bands**: Copy the original 10m bands into the new, predicted file? We set this to `True`, because then we have more bands to compare, but it obviously means the output files will be much larger.
* **save_prefix**: We can use this to add a prefix to the output file. This could also be a directory, as in `"Test/Outputs/"`. Here, we use the previously created output folder.


In this example, we use mostly defaults to explore the most important aspects of the algorithm, first only for one file.

We also add one more switch that is not provided by the implementation:
* **deep**: If we have acquired the VDSen2 models and placed them in the `DSen2/models` directory, we can use this switch to use them for prediction.

In [0]:
data_file=downloaded_files[0]
output_file="GuterOutput.Tif"
roi_lon_lat=""
roi_x_y="1,1,1500,1500"
list_bands=True
run_60=True
list_UTM=False
select_UTM=""
list_output_file_formats=True
output_file_format="GTiff"
copy_original_bands=True
save_prefix="Outputs/"

deep= False

#### Listing possible output file formats

In the original script, this option lists the formats in which the result can be written and then exits. We can execute the code block to see what options we have. In this experiment we have already chosen GTiff, so this does not matter much, but it is good to know the alternatives.

In [0]:
if list_output_file_formats:
    # List every GDAL driver that is able to create raster files
    for didx in range(gdal.GetDriverCount()):
        driver = gdal.GetDriver(didx)
        if not driver:
            continue
        metadata = driver.GetMetadata()
        if (metadata.get(gdal.DCAP_CREATE) == 'YES' and
                metadata.get(gdal.DCAP_RASTER) == 'YES'):
            name = driver.GetDescription()
            if "DMD_LONGNAME" in metadata:
                name += ": " + metadata["DMD_LONGNAME"]
            if "DMD_EXTENSIONS" in metadata:
                name += " (" + metadata["DMD_EXTENSIONS"] + ")"
            print(name)

As we can see, there is quite a long list of alternatives.

#### Choosing the bands for predictions

By setting the parameter **run_60** to `True`, we have chosen to superresolve the 60m bands in addition to the 20m bands. We create a list of all those bands.



In [0]:
if run_60:
    select_bands = 'B1,B2,B3,B4,B5,B6,B7,B8,B8A,B9,B11,B12'
else:
    select_bands = 'B2,B3,B4,B5,B6,B7,B8,B8A,B11,B12'

# convert comma separated band list into a list
select_bands = [x for x in re.split(',', select_bands)]
select_bands

In [0]:
if roi_lon_lat:
    roi_lon1, roi_lat1, roi_lon2, roi_lat2 = [float(x) for x in re.split(',', roi_lon_lat)]
else:
    roi_lon1, roi_lat1, roi_lon2, roi_lat2 = -180, -90, 180, 90

if roi_x_y:
    roi_x1, roi_y1, roi_x2, roi_y2 = [float(x) for x in re.split(',', roi_x_y)]

#### Loading the Data
We load the subdatasets into four lists, one for each resolution plus one for everything else:

*   10m
*   20m
*   60m
*   Unknown






In [0]:
raster = gdal.Open(data_file)
datasets = raster.GetSubDatasets();
tenMsets = []
twentyMsets = []
sixtyMsets = []
unknownMsets = []
for (dsname, dsdesc) in datasets:
    if '10m resolution' in dsdesc:
        tenMsets += [(dsname, dsdesc)]
    elif '20m resolution' in dsdesc:
        twentyMsets += [(dsname, dsdesc)]
    elif '60m resolution' in dsdesc:
        sixtyMsets += [(dsname, dsdesc)]
    else:
        unknownMsets += [(dsname, dsdesc)]
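
The sorting only relies on the description strings, so it can be tried out without GDAL. A minimal sketch, assuming description strings like the ones printed in the training log above; the subdataset names (`ds10` etc.) and the last entry are made up for illustration, and `group_by_resolution` and `utm_of` are hypothetical helpers:

```python
from collections import defaultdict

def group_by_resolution(subdatasets):
    """Sort (name, description) pairs into '10m', '20m', '60m' or 'unknown'."""
    groups = defaultdict(list)
    for name, desc in subdatasets:
        for res in ('10m', '20m', '60m'):
            if res + ' resolution' in desc:
                groups[res].append((name, desc))
                break
        else:
            # No resolution keyword found in the description
            groups['unknown'].append((name, desc))
    return groups

def utm_of(desc):
    # Same slicing as used later in the notebook: everything from "UTM" onwards
    return desc[desc.find("UTM"):]

subdatasets = [
    ('ds10', 'Bands B2, B3, B4, B8 with 10m resolution, UTM 32N'),
    ('ds20', 'Bands B5, B6, B7, B8A, B11, B12 with 20m resolution, UTM 32N'),
    ('ds60', 'Bands B1, B9, B10 with 60m resolution, UTM 32N'),
    ('other', 'Some other subdataset, UTM 32N'),  # made-up entry
]
groups = group_by_resolution(subdatasets)
print({k: len(v) for k, v in groups.items()})
print(utm_of(subdatasets[0][1]))
```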

#### Choosing the subset and the correct UTM zone

Choosing the UTM zone is more complex than it seems.
First, we must account for a possible restriction of the AOI by the **roi_lon_lat** and **roi_x_y** arguments.
Then, we choose the UTM zone with the best coverage.
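
Because one 60m pixel covers 6 x 6 pixels of the 10m grid, the pixel region is clamped to the raster and then aligned to multiples of 6. A minimal sketch of that arithmetic, following the code in the next cell; `clamp_and_snap` is a hypothetical helper name:

```python
def clamp_and_snap(x1, y1, x2, y2, width, height, multiple=6):
    """Clamp a pixel ROI to the raster and align it to a multiple of 6 pixels."""
    xmin = max(min(x1, x2, width - 1), 0)
    xmax = min(max(x1, x2, 0), width - 1)
    ymin = max(min(y1, y2, height - 1), 0)
    ymax = min(max(y1, y2, 0), height - 1)
    # Round the lower edge and the exclusive upper edge down to the 6-pixel grid
    xmin = int(xmin // multiple) * multiple
    xmax = int((xmax + 1) // multiple) * multiple - 1
    ymin = int(ymin // multiple) * multiple
    ymax = int((ymax + 1) // multiple) * multiple - 1
    return xmin, ymin, xmax, ymax

# The ROI "1,1,1500,1500" used in this notebook, on a 10980 x 10980 tile
roi = clamp_and_snap(1, 1, 1500, 1500, 10980, 10980)
print(roi)
```

For the ROI used here this gives a region of 1500 by 1500 pixels, starting at pixel 0.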


In [0]:
# case where we have several UTM in the data set
# => select the one with maximal coverage of the study zone
utm_idx = 0
utm = select_UTM
all_utms = defaultdict(int)
xmin, ymin, xmax, ymax = 0, 0, 0, 0
largest_area = -1
# process even if there is only one 10m set, in order to get roi -> pixels
for (tmidx, (dsname, dsdesc)) in enumerate(tenMsets + unknownMsets):
    ds = gdal.Open(dsname)
    if roi_x_y:
        tmxmin = max(min(roi_x1, roi_x2, ds.RasterXSize - 1), 0)
        tmxmax = min(max(roi_x1, roi_x2, 0), ds.RasterXSize - 1)
        tmymin = max(min(roi_y1, roi_y2, ds.RasterYSize - 1), 0)
        tmymax = min(max(roi_y1, roi_y2, 0), ds.RasterYSize - 1)
        # enlarge to the nearest 60 pixel boundary for the super-resolution
        tmxmin = int(tmxmin / 6) * 6
        tmxmax = int((tmxmax + 1) / 6) * 6 - 1
        tmymin = int(tmymin / 6) * 6
        tmymax = int((tmymax + 1) / 6) * 6 - 1
    elif not roi_lon_lat:
        tmxmin = 0
        tmxmax = ds.RasterXSize - 1
        tmymin = 0
        tmymax = ds.RasterYSize - 1
    else:
        xoff, a, b, yoff, d, e = ds.GetGeoTransform()
        srs = osr.SpatialReference()
        srs.ImportFromWkt(ds.GetProjection())
        srsLatLon = osr.SpatialReference()
        srsLatLon.SetWellKnownGeogCS("WGS84");
        ct = osr.CoordinateTransformation(srsLatLon, srs)


        def to_xy(lon, lat):
            (xp, yp, h) = ct.TransformPoint(lon, lat, 0.)
            xp -= xoff
            yp -= yoff
            # matrix inversion
            det_inv = 1. / (a * e - d * b)
            x = (e * xp - b * yp) * det_inv
            y = (-d * xp + a * yp) * det_inv
            return (int(x), int(y))


        x1, y1 = to_xy(roi_lon1, roi_lat1)
        x2, y2 = to_xy(roi_lon2, roi_lat2)
        tmxmin = max(min(x1, x2, ds.RasterXSize - 1), 0)
        tmxmax = min(max(x1, x2, 0), ds.RasterXSize - 1)
        tmymin = max(min(y1, y2, ds.RasterYSize - 1), 0)
        tmymax = min(max(y1, y2, 0), ds.RasterYSize - 1)
        # enlarge to the nearest 60 pixel boundary for the super-resolution
        tmxmin = int(tmxmin / 6) * 6
        tmxmax = int((tmxmax + 1) / 6) * 6 - 1
        tmymin = int(tmymin / 6) * 6
        tmymax = int((tmymax + 1) / 6) * 6 - 1
    area = (tmxmax - tmxmin + 1) * (tmymax - tmymin + 1)
    current_utm = dsdesc[dsdesc.find("UTM"):]
    if area > all_utms[current_utm]:
        all_utms[current_utm] = area
    if current_utm == select_UTM:
        xmin, ymin, xmax, ymax = tmxmin, tmymin, tmxmax, tmymax
        utm_idx = tmidx
        utm = current_utm
        break
    if area > largest_area:
        xmin, ymin, xmax, ymax = tmxmin, tmymin, tmxmax, tmymax
        largest_area = area
        utm_idx = tmidx
        utm = dsdesc[dsdesc.find("UTM"):]

if list_UTM:
    print("List of UTM zones (with ROI coverage in pixels):")
    for u in all_utms:
        print("%s (%d)" % (u, all_utms[u]))



If successful, we have now chosen the correct subset of our AOI and, based on that, the best UTM zone for this prediction.
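
The `to_xy` helper in the cell above inverts the GDAL geotransform by hand. A minimal sketch of just that inversion, leaving out the lon/lat reprojection done with `osr`; the geotransform values are an assumed example (north-up raster with 10m pixels), and `make_to_pixel` is a hypothetical name:

```python
def make_to_pixel(geotransform):
    """Return a function mapping projected (x, y) to integer (col, row)."""
    # GDAL order: x = xoff + a*col + b*row, y = yoff + d*col + e*row
    xoff, a, b, yoff, d, e = geotransform
    det_inv = 1.0 / (a * e - d * b)

    def to_pixel(x, y):
        xp, yp = x - xoff, y - yoff
        # Inverse of the 2x2 matrix [[a, b], [d, e]]
        col = (e * xp - b * yp) * det_inv
        row = (-d * xp + a * yp) * det_inv
        return int(col), int(row)

    return to_pixel

# Assumed example values, chosen only for illustration
to_pixel = make_to_pixel((500000.0, 10.0, 0.0, 5400000.0, 0.0, -10.0))
print(to_pixel(500105.0, 5399895.0))
```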

In [0]:
print("Selected UTM Zone:", utm)
print("Selected pixel region: xmin=%d, ymin=%d, xmax=%d, ymax=%d:" % (xmin, ymin, xmax, ymax))
print("Image size: width=%d x height=%d" % (xmax - xmin + 1, ymax - ymin + 1))

if xmax < xmin or ymax < ymin:
    print("Invalid region of interest / UTM Zone combination")
    sys.exit(0)
else:
    print("all good")

#### Opening the datasets

We now use GDAL to open the selected subdatasets as three GDAL datasets, one for each resolution.

In [0]:
selected_10m_data_set = None
if not tenMsets:
    selected_10m_data_set = unknownMsets[0]
else:
    selected_10m_data_set = tenMsets[utm_idx]
selected_20m_data_set = None
for (dsname, dsdesc) in twentyMsets:  # no enumerate: we need the (name, description) pairs
    if utm in dsdesc:
        selected_20m_data_set = (dsname, dsdesc)
# if not found, assume the listing is in the same order
# => OK if only one set
if not selected_20m_data_set: selected_20m_data_set = twentyMsets[utm_idx]
selected_60m_data_set = None
for (dsname, dsdesc) in sixtyMsets:  # no enumerate: we need the (name, description) pairs
    if utm in dsdesc:
        selected_60m_data_set = (dsname, dsdesc)
if not selected_60m_data_set: selected_60m_data_set = sixtyMsets[utm_idx]

ds10 = gdal.Open(selected_10m_data_set[0])
ds20 = gdal.Open(selected_20m_data_set[0])
ds60 = gdal.Open(selected_60m_data_set[0])


Since we set **list_bands** to `True`, we now get a listing of all the bands, not just the ones we selected, together with their central wavelengths. Note the order, which is different from the original structure. This is also the order in which the results will be written into the output file.



In [0]:
def validate_description(description):
    m = re.match(r"(.*?), central wavelength (\d+) nm", description)
    if m:
        return m.group(1) + " (" + m.group(2) + " nm)"
    # Some HDR restrictions... ENVI band names should not include commas
    if output_file_format == 'ENVI' and ',' in description:
        pos = description.find(',')
        return description[:pos] + description[(pos + 1):]
    return description


if list_bands:
    print("\n10m bands:")
    for b in range(0, ds10.RasterCount):
        print("- " + validate_description(ds10.GetRasterBand(b + 1).GetDescription()))
    print("\n20m bands:")
    for b in range(0, ds20.RasterCount):
        print("- " + validate_description(ds20.GetRasterBand(b + 1).GetDescription()))
    print("\n60m bands:")
    for b in range(0, ds60.RasterCount):
        print("- " + validate_description(ds60.GetRasterBand(b + 1).GetDescription()))
    print("")

#### Validating the bands

The function then "validates" the bands by selecting only the bands which we selected via the **select_bands** and **run_60** parameters.



In [0]:
def get_band_short_name(description):
    if ',' in description:
        return description[:description.find(',')]
    if ' ' in description:
        return description[:description.find(' ')]
    return description[:3]


validated_10m_bands = []
validated_10m_indices = []
validated_20m_bands = []
validated_20m_indices = []
validated_60m_bands = []
validated_60m_indices = []
validated_descriptions = defaultdict(str)

sys.stdout.write("Selected 10m bands:")
for b in range(0, ds10.RasterCount):
    desc = validate_description(ds10.GetRasterBand(b + 1).GetDescription())
    shortname = get_band_short_name(desc)
    if shortname in select_bands:
        sys.stdout.write(" " + shortname)
        select_bands.remove(shortname)
        validated_10m_bands += [shortname]
        validated_10m_indices += [b]
        validated_descriptions[shortname] = desc
sys.stdout.write("\nSelected 20m bands:")
for b in range(0, ds20.RasterCount):
    desc = validate_description(ds20.GetRasterBand(b + 1).GetDescription())
    shortname = get_band_short_name(desc)
    if shortname in select_bands:
        sys.stdout.write(" " + shortname)
        select_bands.remove(shortname)
        validated_20m_bands += [shortname]
        validated_20m_indices += [b]
        validated_descriptions[shortname] = desc
sys.stdout.write("\nSelected 60m bands:")
for b in range(0, ds60.RasterCount):
    desc = validate_description(ds60.GetRasterBand(b + 1).GetDescription())
    shortname = get_band_short_name(desc)
    if shortname in select_bands:
        sys.stdout.write(" " + shortname)
        select_bands.remove(shortname)
        validated_60m_bands += [shortname]
        validated_60m_indices += [b]
        validated_descriptions[shortname] = desc
sys.stdout.write("\n")


In this case, since we selected all the bands, the list of the bands is the same as before.
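
To see what the two helper functions do to a band description, they can be tried on plain strings. A minimal sketch, assuming descriptions of the form that the regular expression above expects; the sample strings are illustrative, and `short_description` is a hypothetical, simplified variant of `validate_description` without the ENVI special case:

```python
import re

def short_description(description):
    """Turn 'B4, central wavelength 665 nm' into 'B4 (665 nm)'."""
    m = re.match(r"(.*?), central wavelength (\d+) nm", description)
    if m:
        return m.group(1) + " (" + m.group(2) + " nm)"
    return description

def get_band_short_name(description):
    # Same logic as in the cell above: cut at the first comma or space
    if ',' in description:
        return description[:description.find(',')]
    if ' ' in description:
        return description[:description.find(' ')]
    return description[:3]

# Illustrative descriptions, not read from a real file
samples = ["B4, central wavelength 665 nm", "B8A, central wavelength 865 nm"]
names = [get_band_short_name(short_description(s)) for s in samples]
print([short_description(s) for s in samples])
print(names)
```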

#### Setting the Output Filename

We now set the output filename. If we had not given one, it would now create one here, with the name identical to the input file, and the default extension being .tif.

At this point we also add the prefix, so that in our case, the full path points to the *Outputs* folder.

In [0]:
# All query options are processed, we now require an output file
if not output_file:
    print("Error: you must provide the name of an output file. I will set it identical to the input...")
    output_file = os.path.split(data_file)[1] + '.tif'
    # sys.exit(1)


output_file = save_prefix + output_file
# Some HDR restrictions... ENVI file name should be the .bin, not the .hdr
if output_file_format == 'ENVI' and (output_file[-4:] == '.hdr' or output_file[-4:] == '.HDR'):
    output_file = output_file[:-4] + '.bin'
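
The same naming rules can be wrapped in a small function, which makes them easier to reuse for batch processing. A minimal sketch; `build_output_path` is a hypothetical helper and the file names in the example are illustrative. Unlike the cell above, it matches the `.hdr` extension case-insensitively:

```python
import os

def build_output_path(data_file, output_file, save_prefix, output_file_format="GTiff"):
    """Derive the final output path from the notebook's parameters (sketch)."""
    if not output_file:
        # Fall back to the input file name with a .tif extension
        output_file = os.path.split(data_file)[1] + '.tif'
    output_file = save_prefix + output_file
    # ENVI: the data file should be the .bin, not the .hdr
    if output_file_format == 'ENVI' and output_file.lower().endswith('.hdr'):
        output_file = output_file[:-4] + '.bin'
    return output_file

print(build_output_path("tile.SAFE/MTD_MSIL1C.xml", "GuterOutput.Tif", "Outputs/"))
```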


#### Prediction
All weights are loaded from the DSen2/models folder. If we have placed the VDSen2 weights there, the **deep** switch selects them instead of the default ones.

Now we start the prediction for the desired bands. On the VM, this can take a few hours for a single image.

The core is in the `DSen2_20` and `DSen2_60` functions which can be found in DSen2/testing/supres.py . We will not execute them step-by-step here, but we can take a look at the most important functions.

We begin with `DSen2_20` which takes as inputs the 10m bands and 20m bands and can be used to construct both the deep and very deep versions of the network. 





**DSen2_20**
```
def DSen2_20(d10, d20, deep=False):
    # Input to the function must be of shape:
    #     d10: [x,y,4]      (B2, B3, B4, B8)
    #     d20: [x/2,y/2,6]  (B5, B6, B7, B8a, B11, B12)
    #     deep: specifies whether to use VDSen2 (True), or DSen2 (False)

    border = 8
    p10, p20 = get_test_patches(d10, d20, patchSize=128, border=border)
    p10 /= SCALE
    p20 /= SCALE
    test = [p10, p20]
    input_shape = ((4, None, None), (6, None, None))
    prediction = _predict(test, input_shape, deep=deep)
    images = recompose_images(prediction, border=border, size=d10.shape)
    images *= SCALE
    return images
```



We see that the `DSen2_20` function extracts test patches using `get_test_patches` and passes them to the `_predict` function, which in turn



1.   Builds an `s2model` (see code cell below)
2.   Loads the appropriate model weights from the DSen2/models folder using `load_weights`
3.   Runs the model on the patches using `predict` and returns the prediction


**_predict**
```

def _predict(test, input_shape, deep=False, run_60=False):
    # create model
    if deep:
        model = s2model(input_shape, num_layers=32, feature_size=256)
        predict_file = MDL_PATH+'s2_034_lr_1e-04.hdf5' if run_60 else MDL_PATH+'s2_033_lr_1e-04.hdf5'
    else:
        model = s2model(input_shape, num_layers=6, feature_size=128)
        predict_file = MDL_PATH+'s2_030_lr_1e-05.hdf5' if run_60 else MDL_PATH+'s2_032_lr_1e-04.hdf5'
    print('Symbolic Model Created.')

    model.load_weights(predict_file)
    print("Predicting using file: {}".format(predict_file))
    prediction = model.predict(test, verbose=1)
    return prediction
```

Unfortunately, as we can see here, the path to the model weights (`MDL_PATH`) is hardcoded as a path relative to supres.py, so we have to change the working directory to the testing folder for the weights to be found.

The model itself is built by the `s2model` function from DSen2/utils/DSen2Net.py:


**s2model**
```
def s2model(input_shape, num_layers=32, feature_size=256):

    input10 = Input(shape=input_shape[0])
    input20 = Input(shape=input_shape[1])
    if len(input_shape) == 3:
        input60 = Input(shape=input_shape[2])
        x = Concatenate(axis=1)([input10, input20, input60])
    else:
        x = Concatenate(axis=1)([input10, input20])

    # Treat the concatenation
    x = Conv2D(feature_size, (3, 3), kernel_initializer='he_uniform', activation='relu', padding='same')(x)

    for i in range(num_layers):
        x = resBlock(x, feature_size)

    # One more convolution, and then we add the output of our first conv layer
    x = Conv2D(input_shape[-1][0], (3, 3), kernel_initializer='he_uniform', padding='same')(x)
    # x = Dropout(0.3)(x)
    if len(input_shape) == 3:
        x = Add()([x, input60])
        model = Model(inputs=[input10, input20, input60], outputs=x)
    else:
        x = Add()([x, input20])
        model = Model(inputs=[input10, input20], outputs=x)
    return model
```
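To make the channel bookkeeping in `s2model` concrete, here is a small NumPy-only sketch for the 20m case. The array sizes are arbitrary toy values and the names `toy_10m` and `toy_20m` are hypothetical. It assumes the channels-first layout implied by `input_shape`, and that both inputs already share the same height and width, which the concatenation and the final `Add` both require.

```python
import numpy as np

# Toy batch of patches in channels-first layout: (batch, channels, height, width).
# Sizes are arbitrary; both inputs share the same height and width.
toy_10m = np.zeros((2, 4, 32, 32))   # 4 bands: B2, B3, B4, B8
toy_20m = np.zeros((2, 6, 32, 32))   # 6 bands: B5, B6, B7, B8A, B11, B12

input_shape = ((4, None, None), (6, None, None))

# NumPy equivalent of Concatenate(axis=1)([input10, input20])
stacked = np.concatenate([toy_10m, toy_20m], axis=1)

# The last convolution in s2model uses input_shape[-1][0] filters,
# i.e. one output channel per band of the lowest-resolution input.
n_output_channels = input_shape[-1][0]

print(stacked.shape)
print(n_output_channels)
```

The first convolution therefore sees 10 channels, and the network output has 6 channels, matching `input20` so that the two can be added.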
We can see that the bulk of the architecture consists of a number of `resBlock` units. Each of these residual blocks consists of


1.   Convolutional layer
2.   ReLU layer
3.   Convolutional layer
4.   Residual Scaling

![alt text](https://ars.els-cdn.com/content/image/1-s2.0-S0924271618302636-gr6.sml)

These layers are supplemented by a skip connection straight from the block's input to its output.

**resBlock**
```
def resBlock(x, channels, kernel_size=[3, 3], scale=0.1):
    tmp = Conv2D(channels, kernel_size, kernel_initializer='he_uniform', padding='same')(x)
    tmp = Activation('relu')(tmp)
    tmp = Conv2D(channels, kernel_size, kernel_initializer='he_uniform', padding='same')(tmp)
    tmp = Lambda(lambda x: x * scale)(tmp)

    return Add()([x, tmp])

```
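The arithmetic of one residual block can be traced by hand with a NumPy-only sketch. This is a simplified illustration, not the Keras implementation: it assumes a single channel, and `conv3x3_same` and `res_block_numpy` are hypothetical helper names.

```python
import numpy as np

def conv3x3_same(x, kernel):
    """Naive single-channel 3x3 convolution with zero padding ('same' output size)."""
    padded = np.pad(x, 1)  # one pixel of zeros on every side
    out = np.zeros_like(x, dtype=float)
    for i in range(x.shape[0]):
        for j in range(x.shape[1]):
            out[i, j] = np.sum(padded[i:i + 3, j:j + 3] * kernel)
    return out

def res_block_numpy(x, k1, k2, scale=0.1):
    """conv -> ReLU -> conv -> residual scaling, then add the block input."""
    tmp = conv3x3_same(x, k1)
    tmp = np.maximum(tmp, 0.0)        # ReLU
    tmp = conv3x3_same(tmp, k2)
    return x + scale * tmp            # skip connection

x = np.arange(1.0, 10.0).reshape(3, 3)
identity_kernel = np.zeros((3, 3))
identity_kernel[1, 1] = 1.0           # passes the input through unchanged

y = res_block_numpy(x, identity_kernel, identity_kernel)
print(np.allclose(y, 1.1 * x))
```

With kernels that pass the input through unchanged, the block returns `x + 0.1 * x`. With all-zero kernels it returns `x` itself, which shows how the skip connection lets a block start out close to the identity.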

Now that we have a basic idea of how the code works, we can move towards the prediction!

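First, we read the different bands into the arrays `data10`, `data20` and possibly `data60`. GDAL returns each array with the band axis first, so we use `np.rollaxis` to move that axis to the end, which gives the bands-last layout that the `DSen2_20` and `DSen2_60` functions expect.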



In [0]:
if validated_10m_indices:
    print("Loading selected data from: %s" % selected_10m_data_set[1])
    data10 = np.rollaxis(
        ds10.ReadAsArray(xoff=xmin, yoff=ymin, xsize=xmax - xmin + 1, ysize=ymax - ymin + 1, buf_xsize=xmax - xmin + 1,
                         buf_ysize=ymax - ymin + 1), 0, 3)[:, :, validated_10m_indices]

if validated_20m_indices:
    print("Loading selected data from: %s" % selected_20m_data_set[1])
    data20 = np.rollaxis(
        ds20.ReadAsArray(xoff=xmin // 2, yoff=ymin // 2, xsize=(xmax - xmin + 1) // 2, ysize=(ymax - ymin + 1) // 2,
                         buf_xsize=(xmax - xmin + 1) // 2, buf_ysize=(ymax - ymin + 1) // 2), 0, 3)[:, :,
             validated_20m_indices]

if validated_60m_indices:
    print("Loading selected data from: %s" % selected_60m_data_set[1])
    data60 = np.rollaxis(
        ds60.ReadAsArray(xoff=xmin // 6, yoff=ymin // 6, xsize=(xmax - xmin + 1) // 6, ysize=(ymax - ymin + 1) // 6,
                         buf_xsize=(xmax - xmin + 1) // 6, buf_ysize=(ymax - ymin + 1) // 6), 0, 3)[:, :,
             validated_60m_indices]
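The axis reordering and the window arithmetic of the cell above can be checked on a toy array. This is a sketch with made-up sizes; the window values `xmin` and `xmax` are hypothetical.

```python
import numpy as np

# GDAL's ReadAsArray returns multi-band rasters as (bands, rows, cols).
# A toy array stands in for the raster here.
bands_first = np.arange(4 * 6 * 12).reshape(4, 6, 12)

# np.rollaxis(a, 0, 3) moves axis 0 to the end: (rows, cols, bands).
bands_last = np.rollaxis(bands_first, 0, 3)
print(bands_last.shape)  # (6, 12, 4)

# np.moveaxis expresses the same operation more explicitly.
assert np.array_equal(bands_last, np.moveaxis(bands_first, 0, -1))

# Selecting bands by index on the last axis, as done with validated_*_indices.
selected = bands_last[:, :, [0, 3]]
print(selected.shape)  # (6, 12, 2)

# A 10m pixel window maps to the coarser grids by integer division.
xmin, xmax = 12, 71                 # hypothetical 10m window, 60 pixels wide
width_10m = xmax - xmin + 1
print(xmin // 2, width_10m // 2)    # 6 30   (20m grid)
print(xmin // 6, width_10m // 6)    # 2 10   (60m grid)
```

The integer divisions only line up exactly when the window starts and ends on multiples of 6 pixels, which is why the region of interest is snapped to such boundaries beforehand.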



As mentioned before, we have to change the working directory just while predicting. No big deal.

In [0]:
%cd DSen2/testing/

In [0]:
if validated_60m_bands and validated_20m_bands and validated_10m_bands:
    print("Super-resolving the 60m data into 10m bands")
    sr60 = DSen2_60(data10, data20, data60, deep=deep)
else:
    sr60 = None

if validated_10m_bands and validated_20m_bands:
    print("Super-resolving the 20m data into 10m bands")
    sr20 = DSen2_20(data10, data20, deep=deep)
else:
    sr20 = None


In [0]:
%cd -

Note that at the end we change back to our real working directory, the DSen20 Folder
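The change-and-restore pattern used around the prediction can also be wrapped in a context manager, so that the original directory is restored even if the prediction raises an error. The following is a minimal sketch using only the standard library; `working_directory` is a hypothetical helper name, not part of DSen2.

```python
import os
from contextlib import contextmanager

@contextmanager
def working_directory(path):
    """Temporarily switch the working directory, restoring it afterwards."""
    previous = os.getcwd()
    os.chdir(path)
    try:
        yield
    finally:
        # Runs on normal exit and on exceptions alike.
        os.chdir(previous)

# Intended use in this notebook (not executed here):
# with working_directory("DSen2/testing"):
#     sr20 = DSen2_20(data10, data20, deep=deep)
```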

#### Writing the output

Now that the prediction is done, we can write the outputs. We have created the output folder and set the path to it earlier. 
If we use a geospatial output we use GDAL to create the output file and write the bands into that file.  If we have opted for the numpy-specific *npz* format, we use numpy to write the file.
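For the *npz* case, the bands end up in a dictionary that is saved under the key `bands`. The sketch below shows how such a file could be read back. The dictionary keys and array values are made-up toy data, and the temporary path is only for illustration.

```python
import os
import tempfile
import numpy as np

# Toy stand-in for the `bands` dictionary that is filled band by band.
bands = {"SRB5 (705 nm)": np.ones((2, 2)), "SRB9 (945 nm)": np.zeros((2, 2))}

path = os.path.join(tempfile.mkdtemp(), "toy_output.npz")
np.savez(path, bands=bands)

# The dictionary is stored as a pickled 0-d object array, so loading it
# requires allow_pickle=True, and .item() unwraps the dictionary again.
with np.load(path, allow_pickle=True) as archive:
    restored = archive["bands"].item()

print(sorted(restored))
```

Note that `np.savez` appends `.npz` to the file name if it does not already end with that extension.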


The order in which the bands are written is as follows:


1.   Original 10m bands, if copied with **copy_original_bands**, in order `4 3 2 8`
2.   Superresolved 20m bands, in order `5 6 7 8A 11 12 `
3.   Superresolved 60m bands, in order `1 9`
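Given this order, the position of any band in the output file can be worked out with a few lines of Python. This is a sketch; `output_band_positions` is a hypothetical helper, and the band lists are the ones stated above.

```python
def output_band_positions(bands_10m, bands_20m, bands_60m, copy_original_bands=True):
    """Map band short names to 1-based positions in the output raster."""
    ordered = []
    if copy_original_bands:
        ordered += list(bands_10m)
    # Super-resolved bands get the "SR" prefix, 20m bands before 60m bands.
    ordered += ["SR" + b for b in bands_20m]
    ordered += ["SR" + b for b in bands_60m]
    return {name: position for position, name in enumerate(ordered, start=1)}

positions = output_band_positions(
    ["B4", "B3", "B2", "B8"],
    ["B5", "B6", "B7", "B8A", "B11", "B12"],
    ["B1", "B9"],
)
print(positions["B8"], positions["SRB8A"], positions["SRB9"])  # 4 8 12
```

These 1-based positions are the indices one would pass to a raster reader when picking individual bands from the output file.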

In [0]:
if output_file_format != "npz":
    revert_to_npz = True
    driver = gdal.GetDriverByName(output_file_format)
    if driver:
        metadata = driver.GetMetadata()
        if gdal.DCAP_CREATE in metadata and metadata[gdal.DCAP_CREATE] == 'YES':
            revert_to_npz = False
    if revert_to_npz:
        print("Gdal doesn't support creating %s files" % output_file_format)
        print("Writing to npz as a fallback")
        output_file_format = "npz"

# Decide the container after the fallback check, so that a reverted format
# also gets a dict: npz output collects bands in a dict, GDAL output writes
# into a dataset that is created further below.
bands = dict() if output_file_format == "npz" else None
result_dataset = None

bidx = 0
all_descriptions = []
source_band = dict()

def write_band_data(data, description, shortname=None):
    global all_descriptions
    global bidx
    all_descriptions += [description]
    if output_file_format == "npz":
        bands[description] = data
    else:
        bidx += 1
        result_dataset.GetRasterBand(bidx).SetDescription(description)
        result_dataset.GetRasterBand(bidx).WriteArray(data)


if sr60 is not None:
    sr = np.concatenate((sr20, sr60), axis=2)
    validated_sr_bands = validated_20m_bands + validated_60m_bands
else:
    sr = sr20
    validated_sr_bands = validated_20m_bands

if copy_original_bands:
    out_dims = data10.shape[2] + sr.shape[2]
else:
    out_dims = sr.shape[2]

sys.stdout.write("Writing")
if output_file_format != "npz":
    # Only GDAL formats need a dataset; npz output is saved with NumPy below.
    result_dataset = driver.Create(output_file, data10.shape[1], data10.shape[0], out_dims, gdal.GDT_Float64)
    # Translate the image upper left corner. We multiply x10 to transform from pixel position in the 10m_band to meters.
    geot = list(ds10.GetGeoTransform())
    geot[0] += xmin * 10
    geot[3] -= ymin * 10
    result_dataset.SetGeoTransform(tuple(geot))
    result_dataset.SetProjection(ds10.GetProjection())

if copy_original_bands:
    sys.stdout.write(" the original 10m bands and")
    # Write the original 10m bands
    for bi, bn in enumerate(validated_10m_bands):
        write_band_data(data10[:, :, bi], validated_descriptions[bn])
print(" the super-resolved bands in %s" % output_file)
for bi, bn in enumerate(validated_sr_bands):
    write_band_data(sr[:, :, bi], "SR" + validated_descriptions[bn], "SR" + bn)


for desc in all_descriptions:
    print(desc)

if output_file_format == "npz":
    np.savez(output_file, bands=bands)
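The shift of the upper left corner in the cell above hardcodes a pixel size of 10 m and a north-up image. A slightly more general sketch, which reads the pixel size from the geotransform itself, could look as follows. `shift_geotransform` is a hypothetical helper and the coordinates are toy values.

```python
def shift_geotransform(geotransform, col_offset, row_offset):
    """Return a GDAL-style geotransform whose origin is moved to the given pixel."""
    x0, px_w, rot_x, y0, rot_y, px_h = geotransform
    # GDAL convention: X = x0 + col * px_w + row * rot_x
    #                  Y = y0 + col * rot_y + row * px_h
    new_x0 = x0 + col_offset * px_w + row_offset * rot_x
    new_y0 = y0 + col_offset * rot_y + row_offset * px_h
    return (new_x0, px_w, rot_x, new_y0, rot_y, px_h)

# Toy north-up geotransform with 10 m pixels (px_h is negative).
toy = (600000.0, 10.0, 0.0, 5000000.0, 0.0, -10.0)
print(shift_geotransform(toy, 6, 12))
# (600060.0, 10.0, 0.0, 4999880.0, 0.0, -10.0)
```

For a north-up 10 m raster this gives the same result as adding `xmin * 10` to the x origin and subtracting `ymin * 10` from the y origin.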


Now we close the dataset and finish the writing process.

In [0]:
result_dataset=None

The result can now be downloaded from the drive. In this example, it is only around 200MB in size: although we predicted all possible bands and also included the original 10m bands, we only worked on a subset of the data.

### Chapter 5: Visualisation

After doing all that hard work, we now take a look at the result. We use the rasterio package to read in the file we just created.


In [0]:
import rasterio
from rasterio.plot import show
from rasterio.plot import show_hist

#### Data

We check the file and see that it has 12 bands. If we had decided not to copy the original bands, or not to predict the 60m bands, we would have fewer.

In [0]:
src = rasterio.open(output_file)
src.count

We can also see that the file has the same dimension as the subset we specified at the beginning of Chapter 2. At least, it should have!

In [0]:
src.shape

#### Histogram

We check a histogram of our created image and see that most digital numbers (DN) for all bands are between 0 and 5000.

In [0]:
show_hist(src, bins=50, lw=0.0, stacked=False, alpha=0.3,
      histtype='stepfilled', title="Histogram")

#### True Color Images

First, we enjoy a true color composite of the original bands, to make sure that it worked out and to have an overview of the study area.

For this, we first use a little helper function to normalize the data into a range of 0-10. Since `imshow` clips floating-point RGB values to the range 0-1, the factor of 10 acts as a simple brightness boost.

In [0]:
import matplotlib.pyplot as plt
# The original 10m bands were written in the order 4 3 2 8 (see above),
# so the first three output bands are B4 (red), B3 (green) and B2 (blue).
b4_red = src.read(1)
b3_green = src.read(2)
b2_blue = src.read(3)

# Function to normalize the grid values
def normalize(array):
    """Normalizes numpy arrays into scale 0.0 - 10.0"""
    array_min, array_max = array.min(), array.max()
    return (10*(array - array_min)/(array_max - array_min))

    
# Normalize the bands
b4_redn = normalize(b4_red)
b3_greenn = normalize(b3_green)
b2_bluen = normalize(b2_blue)

print("Normalized bands: \n")

print(b4_redn.min(), '-', b4_redn.max(), 'mean:', b4_redn.mean())
print(b3_greenn.min(), '-', b3_greenn.max(), 'mean:', b3_greenn.mean())
print(b2_bluen.min(), '-', b2_bluen.max(), 'mean:', b2_bluen.mean())

# Create RGB natural color composite
rgb = np.dstack((b4_redn, b3_greenn, b2_bluen))
plt.title('True Color Composite of the entire study area')
print("\n\n")
# Let's see what our color composite looks like
plt.imshow(rgb)
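As an alternative to the min-max normalization with a fixed factor, a percentile stretch maps the data to 0-1 directly and is less sensitive to a few very bright pixels. This is a different technique from the one used in the cell above, sketched here under the assumption that the 2nd and 98th percentiles are reasonable cut-offs; `percentile_normalize` is a hypothetical helper.

```python
import numpy as np

def percentile_normalize(array, lower=2, upper=98):
    """Scale an array to 0.0-1.0 between two percentiles, clipping the tails."""
    low, high = np.percentile(array, [lower, upper])
    if high == low:
        # Constant image: avoid division by zero.
        return np.zeros_like(array, dtype=float)
    return np.clip((array - low) / (high - low), 0.0, 1.0)

toy_band = np.arange(101.0)          # toy values 0..100
stretched = percentile_normalize(toy_band)
print(stretched.min(), stretched.max())
```

The result can be stacked with `np.dstack` and passed to `plt.imshow` in the same way as the normalized bands above.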

#### Superresolved Bands

What we are really interested in are the superresolved bands. After applying a little histogram stretch to the range 0-255, which aids visualisation, we plot B8, B8a and B9 side by side:

*   **B8**, an original 10m band
*   **B8a**, originally 20m, now superresolved to 10m
*   **B9**, originally 60m, now superresolved to 10m

As all of these bands are in the NIR range of around 900nm wavelength, they should be roughly comparable.


In [0]:
b8_nir = src.read(4)
b8a_nir = src.read(8)
b9_wv = src.read(12)

# Stretch the bands between their 10% and 90% quantiles

def stretch(a, lower_thresh, upper_thresh):
    r = 255.0/(upper_thresh-lower_thresh+2) # unit of stretching
    out = np.round(r*(a-lower_thresh+1)).astype(a.dtype) # stretched values
    out[a<lower_thresh] = 0
    out[a>upper_thresh] = 255
    return out

b8_nirn = stretch(b8_nir,np.quantile(b8_nir,0.1),np.quantile(b8_nir,0.9))
b8a_nirn = stretch(b8a_nir,np.quantile(b8a_nir,0.1),np.quantile(b8a_nir,0.9))
b9_wvn = stretch(b9_wv,np.quantile(b9_wv,0.1),np.quantile(b9_wv,0.9))

print("Stretched bands")
print(b8_nirn.min(), '-', b8_nirn.max(), 'mean:', b8_nirn.mean())
print(b8a_nirn.min(), '-', b8a_nirn.max(), 'mean:', b8a_nirn.mean())
print(b9_wvn.min(), '-', b9_wvn.max(), 'mean:', b9_wvn.mean())


from matplotlib import pyplot
fig, (axr, axg, axb) = pyplot.subplots(1,3, figsize=(21,7))
show(b8_nirn, ax=axr, cmap='Reds', title='B8 original: 833 nm (10m)')
show(b8a_nirn, ax=axg, cmap='Greens', title='B8a superresolved: 865 nm (20m -> 10m)')
show(b9_wvn, ax=axb, cmap='Blues', title='B9 superresolved: 945 nm (60m -> 10m)')
pyplot.show()

At this scale the results are visually pleasing, but we see little difference between the bands. We now try the same with a smaller subset, effectively zooming in.

In [0]:
b8_nirn_sel= b8_nirn[1000:1200,1000:1200]
b8a_nirn_sel = b8a_nirn[1000:1200,1000:1200]
b9_wvn_sel = b9_wvn[1000:1200,1000:1200]

fig, (axr, axg, axb) = pyplot.subplots(1,3, figsize=(21,7))
show(b8_nirn_sel, ax=axr, cmap='Oranges', title='B8 original: 833 nm (10m)')
show(b8a_nirn_sel, ax=axg, cmap='Oranges', title='B8a superresolved: 865 nm (20m -> 10m)')
show(b9_wvn_sel, ax=axb, cmap='Oranges', title='B9 superresolved: 945 nm (60m -> 10m)')
pyplot.show()

We see that the superresolved bands do not look too bad compared to the original 10m band! Nice!
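The visual impression could be backed by a simple number, for example the Pearson correlation between two bands over the same window. The sketch below uses toy arrays; `band_correlation` is a hypothetical helper. A high correlation between neighbouring NIR bands is only a plausibility check, not a measure of superresolution quality.

```python
import numpy as np

def band_correlation(band_a, band_b):
    """Pearson correlation between two equally shaped bands."""
    return float(np.corrcoef(band_a.ravel(), band_b.ravel())[0, 1])

# Toy data: the second band is a linear function of the first.
toy_a = np.arange(16.0).reshape(4, 4)
toy_b = 2.0 * toy_a + 3.0
print(round(band_correlation(toy_a, toy_b), 6))  # 1.0
```

In the notebook, the same function could be applied to pairs such as `b8_nirn_sel` and `b8a_nirn_sel`.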

### Chapter 6: Streamlining

So far we have executed the code step by step. This made it easy to understand the code, but it is not practical for performing superresolution on large numbers of files.


#### Wrapping the process into a function

To speed things up for the future, we put the essential parts of the superresolution procedure into a function `DSen2_sl`. We do not include things like query options and intermediate outputs.


In [0]:
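#@title
import os
import re
import sys
from collections import defaultdict

import numpy as np
from osgeo import gdal, osr

def DSen2_sl( data_file,output_file,roi_lon_lat="",roi_x_y="",run_60=True,copy_original_bands=True,output_file_format="GTiff"):
  """Super-resolve one Sentinel-2 file and write the result to output_file.

  Relies on names defined earlier in the notebook: select_UTM, DSen2_20
  and DSen2_60. Returns the path of the output file.
  """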
  print("=========STARTING DSen2==========")
  print("data file:"+data_file)
  print("output file:"+output_file)
  print("roi_lon_lat:"+roi_lon_lat)
  print("roi_x_y:"+roi_x_y)
  print("run_60:"+str(run_60))
  print("output_file_format:"+str(output_file_format))
  print("copy_original_bands file:"+str(copy_original_bands))

  if run_60:
      select_bands = 'B1,B2,B3,B4,B5,B6,B7,B8,B8A,B9,B11,B12'
  else:
      select_bands = 'B2,B3,B4,B5,B6,B7,B8,B8A,B11,B12'

  # convert comma separated band list into a list
  select_bands = [x for x in re.split(',', select_bands)]


  if roi_lon_lat:
      roi_lon1, roi_lat1, roi_lon2, roi_lat2 = [float(x) for x in re.split(',', roi_lon_lat)]
  else:
      roi_lon1, roi_lat1, roi_lon2, roi_lat2 = -180, -90, 180, 90

  if roi_x_y:
      roi_x1, roi_y1, roi_x2, roi_y2 = [float(x) for x in re.split(',', roi_x_y)]

  raster = gdal.Open(data_file)
  datasets = raster.GetSubDatasets();
  tenMsets = []
  twentyMsets = []
  sixtyMsets = []
  unknownMsets = []
  for (dsname, dsdesc) in datasets:
      if '10m resolution' in dsdesc:
          tenMsets += [(dsname, dsdesc)]
      elif '20m resolution' in dsdesc:
          twentyMsets += [(dsname, dsdesc)]
      elif '60m resolution' in dsdesc:
          sixtyMsets += [(dsname, dsdesc)]
      else:
          unknownMsets += [(dsname, dsdesc)]



  # case where we have several UTM in the data set
  # => select the one with maximal coverage of the study zone
  utm_idx = 0
  utm = select_UTM
  all_utms = defaultdict(int)
  xmin, ymin, xmax, ymax = 0, 0, 0, 0
  largest_area = -1
  # process even if there is only one 10m set, in order to get roi -> pixels
  for (tmidx, (dsname, dsdesc)) in enumerate(tenMsets + unknownMsets):
      ds = gdal.Open(dsname)
      if roi_x_y:
          tmxmin = max(min(roi_x1, roi_x2, ds.RasterXSize - 1), 0)
          tmxmax = min(max(roi_x1, roi_x2, 0), ds.RasterXSize - 1)
          tmymin = max(min(roi_y1, roi_y2, ds.RasterYSize - 1), 0)
          tmymax = min(max(roi_y1, roi_y2, 0), ds.RasterYSize - 1)
          # enlarge to the nearest 60 pixel boundary for the super-resolution
          tmxmin = int(tmxmin / 6) * 6
          tmxmax = int((tmxmax + 1) / 6) * 6 - 1
          tmymin = int(tmymin / 6) * 6
          tmymax = int((tmymax + 1) / 6) * 6 - 1
      elif not roi_lon_lat:
          tmxmin = 0
          tmxmax = ds.RasterXSize - 1
          tmymin = 0
          tmymax = ds.RasterYSize - 1
      else:
          xoff, a, b, yoff, d, e = ds.GetGeoTransform()
          srs = osr.SpatialReference()
          srs.ImportFromWkt(ds.GetProjection())
          srsLatLon = osr.SpatialReference()
          srsLatLon.SetWellKnownGeogCS("WGS84");
          ct = osr.CoordinateTransformation(srsLatLon, srs)


          def to_xy(lon, lat):
              (xp, yp, h) = ct.TransformPoint(lon, lat, 0.)
              xp -= xoff
              yp -= yoff
              # matrix inversion
              det_inv = 1. / (a * e - d * b)
              x = (e * xp - b * yp) * det_inv
              y = (-d * xp + a * yp) * det_inv
              return (int(x), int(y))


          x1, y1 = to_xy(roi_lon1, roi_lat1)
          x2, y2 = to_xy(roi_lon2, roi_lat2)
          tmxmin = max(min(x1, x2, ds.RasterXSize - 1), 0)
          tmxmax = min(max(x1, x2, 0), ds.RasterXSize - 1)
          tmymin = max(min(y1, y2, ds.RasterYSize - 1), 0)
          tmymax = min(max(y1, y2, 0), ds.RasterYSize - 1)
          # enlarge to the nearest 60 pixel boundary for the super-resolution
          tmxmin = int(tmxmin / 6) * 6
          tmxmax = int((tmxmax + 1) / 6) * 6 - 1
          tmymin = int(tmymin / 6) * 6
          tmymax = int((tmymax + 1) / 6) * 6 - 1
      area = (tmxmax - tmxmin + 1) * (tmymax - tmymin + 1)
      current_utm = dsdesc[dsdesc.find("UTM"):]
      if area > all_utms[current_utm]:
          all_utms[current_utm] = area
      if current_utm == select_UTM:
          xmin, ymin, xmax, ymax = tmxmin, tmymin, tmxmax, tmymax
          utm_idx = tmidx
          utm = current_utm
          break
      if area > largest_area:
          xmin, ymin, xmax, ymax = tmxmin, tmymin, tmxmax, tmymax
          largest_area = area
          utm_idx = tmidx
          utm = dsdesc[dsdesc.find("UTM"):]

  selected_10m_data_set = None
  if not tenMsets:
      selected_10m_data_set = unknownMsets[0]
  else:
      selected_10m_data_set = tenMsets[utm_idx]
  selected_20m_data_set = None
  # each entry is a (name, description) tuple, so we iterate over it directly
  for (dsname, dsdesc) in twentyMsets:
      if utm in dsdesc:
          selected_20m_data_set = (dsname, dsdesc)
  # if not found, assume the listing is in the same order
  # => OK if only one set
  if not selected_20m_data_set: selected_20m_data_set = twentyMsets[utm_idx]
  selected_60m_data_set = None
  for (dsname, dsdesc) in sixtyMsets:
      if utm in dsdesc:
          selected_60m_data_set = (dsname, dsdesc)
  if not selected_60m_data_set: selected_60m_data_set = sixtyMsets[utm_idx]

  ds10 = gdal.Open(selected_10m_data_set[0])
  ds20 = gdal.Open(selected_20m_data_set[0])
  ds60 = gdal.Open(selected_60m_data_set[0])



  def validate_description(description):
      m = re.match(r"(.*?), central wavelength (\d+) nm", description)
      if m:
          return m.group(1) + " (" + m.group(2) + " nm)"
      # Some HDR restrictions... ENVI band names should not include commas
      if output_file_format == 'ENVI' and ',' in description:
          pos = description.find(',')
          return description[:pos] + description[(pos + 1):]
      return description

  def get_band_short_name(description):
      if ',' in description:
          return description[:description.find(',')]
      if ' ' in description:
          return description[:description.find(' ')]
      return description[:3]


  validated_10m_bands = []
  validated_10m_indices = []
  validated_20m_bands = []
  validated_20m_indices = []
  validated_60m_bands = []
  validated_60m_indices = []
  validated_descriptions = defaultdict(str)

  for b in range(0, ds10.RasterCount):
      desc = validate_description(ds10.GetRasterBand(b + 1).GetDescription())
      shortname = get_band_short_name(desc)
      if shortname in select_bands:
          sys.stdout.write(" " + shortname)
          select_bands.remove(shortname)
          validated_10m_bands += [shortname]
          validated_10m_indices += [b]
          validated_descriptions[shortname] = desc
  for b in range(0, ds20.RasterCount):
      desc = validate_description(ds20.GetRasterBand(b + 1).GetDescription())
      shortname = get_band_short_name(desc)
      if shortname in select_bands:
          sys.stdout.write(" " + shortname)
          select_bands.remove(shortname)
          validated_20m_bands += [shortname]
          validated_20m_indices += [b]
          validated_descriptions[shortname] = desc
  for b in range(0, ds60.RasterCount):
      desc = validate_description(ds60.GetRasterBand(b + 1).GetDescription())
      shortname = get_band_short_name(desc)
      if shortname in select_bands:
          sys.stdout.write(" " + shortname)
          select_bands.remove(shortname)
          validated_60m_bands += [shortname]
          validated_60m_indices += [b]
          validated_descriptions[shortname] = desc
  sys.stdout.write("\n")

  # Some HDR restrictions... ENVI file name should be the .bin, not the .hdr
  if output_file_format == 'ENVI' and (output_file[-4:] == '.hdr' or output_file[-4:] == '.HDR'):
      output_file = output_file[:-4] + '.bin'


  if validated_10m_indices:
      data10 = np.rollaxis(
          ds10.ReadAsArray(xoff=xmin, yoff=ymin, xsize=xmax - xmin + 1, ysize=ymax - ymin + 1, buf_xsize=xmax - xmin + 1,
                          buf_ysize=ymax - ymin + 1), 0, 3)[:, :, validated_10m_indices]

  if validated_20m_indices:
      data20 = np.rollaxis(
          ds20.ReadAsArray(xoff=xmin // 2, yoff=ymin // 2, xsize=(xmax - xmin + 1) // 2, ysize=(ymax - ymin + 1) // 2,
                          buf_xsize=(xmax - xmin + 1) // 2, buf_ysize=(ymax - ymin + 1) // 2), 0, 3)[:, :,
              validated_20m_indices]

  if validated_60m_indices:
      data60 = np.rollaxis(
          ds60.ReadAsArray(xoff=xmin // 6, yoff=ymin // 6, xsize=(xmax - xmin + 1) // 6, ysize=(ymax - ymin + 1) // 6,
                          buf_xsize=(xmax - xmin + 1) // 6, buf_ysize=(ymax - ymin + 1) // 6), 0, 3)[:, :,
              validated_60m_indices]

  maindir = os.getcwd()
  subdir = maindir+"/DSen2/testing/"
  os.chdir(subdir)
  ##CODE START
  print("Predicting.")
  if validated_60m_bands and validated_20m_bands and validated_10m_bands:
      sr60 = DSen2_60(data10, data20, data60, deep=False)
  else:
      sr60 = None

  if validated_10m_bands and validated_20m_bands:
      sr20 = DSen2_20(data10, data20, deep=False)
  else:
      sr20 = None
  ##CODE END
  os.chdir(maindir)



  if output_file_format != "npz":
      revert_to_npz = True
      driver = gdal.GetDriverByName(output_file_format)
      if driver:
          metadata = driver.GetMetadata()
          if gdal.DCAP_CREATE in metadata and metadata[gdal.DCAP_CREATE] == 'YES':
              revert_to_npz = False
      if revert_to_npz:
          print("Gdal doesn't support creating %s files" % output_file_format)
          print("Writing to npz as a fallback")
          output_file_format = "npz"
  # Decide the container after the fallback check: a dict for npz output,
  # otherwise a GDAL dataset that is created further below.
  bands = dict() if output_file_format == "npz" else None
  result_dataset = None
  bidx = 0
  all_descriptions = []
  source_band = dict()
  def write_band_data(data, description, shortname=None):
      # Both counters live in DSen2_sl, so they are nonlocal rather than global.
      nonlocal all_descriptions
      nonlocal bidx
      all_descriptions += [description]
      if output_file_format == "npz":
          bands[description] = data
      else:
          bidx += 1
          result_dataset.GetRasterBand(bidx).SetDescription(description)
          result_dataset.GetRasterBand(bidx).WriteArray(data)


  if sr60 is not None:
      sr = np.concatenate((sr20, sr60), axis=2)
      validated_sr_bands = validated_20m_bands + validated_60m_bands
  else:
      sr = sr20
      validated_sr_bands = validated_20m_bands

  if copy_original_bands:
      out_dims = data10.shape[2] + sr.shape[2]
  else:
      out_dims = sr.shape[2]

  print("Writing to:"+str(output_file))
  if output_file_format != "npz":
      # Only GDAL formats need a dataset; npz output is saved with NumPy below.
      result_dataset = driver.Create(output_file, data10.shape[1], data10.shape[0], out_dims, gdal.GDT_Float64)
      # Translate the image upper left corner. We multiply x10 to transform from pixel position in the 10m_band to meters.
      geot = list(ds10.GetGeoTransform())
      geot[0] += xmin * 10
      geot[3] -= ymin * 10
      result_dataset.SetGeoTransform(tuple(geot))
      result_dataset.SetProjection(ds10.GetProjection())
  if copy_original_bands:
      # Write the original 10m bands
      for bi, bn in enumerate(validated_10m_bands):
        write_band_data(data10[:, :, bi], validated_descriptions[bn])
  for bi, bn in enumerate(validated_sr_bands):
      write_band_data(sr[:, :, bi], "SR" + validated_descriptions[bn], "SR" + bn)


  if output_file_format == "npz":
      np.savez(output_file, bands=bands)

  result_dataset=None
  return(output_file)


We still repeat some steps, such as creating a new model for every file, which is not optimal. Still, the function allows us to semi-automate the superresolution.
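One way to avoid rebuilding the model for every file would be to cache it per configuration. The sketch below shows the pattern with `functools.lru_cache`. `get_model` and `build_model` are hypothetical names; `build_model` is a stand-in for the expensive step, which in this notebook would be the call to `s2model` followed by `load_weights`.

```python
from functools import lru_cache

build_calls = []

def build_model(deep, run_60):
    # Placeholder for the expensive step; records each real build.
    build_calls.append((deep, run_60))
    return {"deep": deep, "run_60": run_60}

@lru_cache(maxsize=None)
def get_model(deep=False, run_60=False):
    """Build one model per configuration and reuse it on later calls."""
    return build_model(deep, run_60)

first = get_model(deep=False, run_60=False)
second = get_model(deep=False, run_60=False)
print(first is second, len(build_calls))  # True 1
```

Applying this to DSen2 would require `_predict` to accept a ready-made model, which the published code does not do, so this remains a design sketch.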

#### Processing multiple files

Now with this function, it is easy to loop over a list of previously downloaded files, applying the superresolution algorithm to all of them and writing their results to the drive.

In [0]:
downloaded_files

In [0]:
for i in range(0, len(downloaded_files)):
    DSen2_sl(downloaded_files[i],"Outputs/DSen2_Output"+str(i)+".Tif",roi_x_y="1,1,300,300")
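The region of interest passed here is not used exactly as given: `DSen2_sl` snaps it to multiples of 6 pixels so that it also aligns with the 20m and 60m grids. The sketch below repeats that rounding for one axis. `snap_roi_to_6` is a hypothetical helper, and the clamping to the raster size is left out.

```python
def snap_roi_to_6(vmin, vmax):
    """Apply the same rounding as DSen2_sl to one axis of a pixel ROI."""
    snapped_min = int(vmin / 6) * 6
    snapped_max = int((vmax + 1) / 6) * 6 - 1
    return snapped_min, snapped_max

print(snap_roi_to_6(1.0, 300.0))   # (0, 299)
```

The window `1,1,300,300` therefore becomes pixels 0 to 299 on both axes: 300 pixels at 10m, 150 at 20m and 50 at 60m.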

That's it! We can now do this for as many files as we can get our hands on. When they are done, we just download them from our drive!
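For longer lists it can help if one failing scene does not stop the whole batch. The following is a minimal sketch of such a loop. `process_all` is a hypothetical helper, and the `process` argument stands for any function taking an input path and an output path, such as a thin wrapper around `DSen2_sl`.

```python
import os

def process_all(files, process, output_dir="Outputs"):
    """Run process(input_path, output_path) on each file, collecting failures."""
    done, failed = [], []
    for index, path in enumerate(files):
        output_path = os.path.join(output_dir, "DSen2_Output%d.tif" % index)
        try:
            done.append(process(path, output_path))
        except Exception as error:
            # Keep going if one scene fails, but remember what went wrong.
            failed.append((path, str(error)))
    return done, failed

# Intended use in this notebook (not executed here):
# done, failed = process_all(
#     downloaded_files,
#     lambda src_path, out_path: DSen2_sl(src_path, out_path, roi_x_y="1,1,300,300"),
# )
```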

## Final Note: Emulating the publication

In this script, we have made a good number of changes to the source code. Nevertheless, it is possible to closely follow the procedure as described in the publication. We just have to make sure a few settings match:

In the Training section:

* **NR_CROP** should be set to `8000`.
* **SCALE** should be set to `2000`.
* **n_epochs** should be set to `8192`.

In the Superresolution section:

* **roi_lon_lat** should be an empty string `""` to predict on whole images.
* **roi_x_y** should be an empty string `""` to predict on whole images.
* **select_UTM** should also be an empty string to let the algorithm choose the optimal UTM zone.
* **run_60** and **copy_original_bands** should be set to `True` to get a full stack of 10m resolution.

The publication compares the performances of the deep DSen2 model and the very deep VDSen2 model. If we want to use the latter, we have to do two things:

* Acquire the weights for the VDSen2 model and place them in the `DSen2/models` directory.
* Set the switch **deep** to `True`.

Using these settings, the algorithm can be executed as described in the publication.
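For reference, the settings listed above can be collected in one place. The dictionary below is only a summary sketch: `publication_settings` is a hypothetical name, and the notebook itself sets these values as individual variables in the respective sections.

```python
publication_settings = {
    # Training section
    "NR_CROP": 8000,
    "SCALE": 2000,
    "n_epochs": 8192,
    # Superresolution section
    "roi_lon_lat": "",
    "roi_x_y": "",
    "select_UTM": "",
    "run_60": True,
    "copy_original_bands": True,
    # False selects DSen2; True selects VDSen2 (requires the VDSen2 weights)
    "deep": False,
}

for name, value in publication_settings.items():
    print("%s = %r" % (name, value))
```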

To recreate the study itself, we can train from scratch and on the same Sentinel-2 images as the study, a [list](https://github.com/lanha/DSen2/blob/master/S2_tiles_training.txt) of which is available on the Github. Another [list](https://github.com/lanha/DSen2/blob/master/S2_tiles_testing.txt) is available for the images which were used for testing. Using these images it should be possible to closely recreate the study.