In [2]:
import numpy as np
import copy
import os
import pickle
import warnings
import matplotlib.pyplot as plt
from . import get_functions
from . import tools

# TODO: Decide the amount of checking and control in the class


class SarImage:
    """ Class to contain SAR image, relevant meta data and methods.
    Attributes:
        bands(list of numpy arrays): The measurements, one 2-D array per band.
        mission(str): Mission name.
        time(datetime): Start time of acquisition.
        footprint(dict): Dictionary with footprint of image
                        footprint = {'latitude': np.array
                                    'longitude': np.array}
        product_meta(dict): Dictionary with meta data.
        band_names(list of str): Names of the bands. Normally the polarisation.
        calibration_tables(list of dict): Calibration table information for each band.
        geo_tie_point(list of dict): Geo tie points for each band.
        band_meta(list of dict): Meta data for each band.
        unit(str): Unit of the band values, e.g. the calibration mode set by
                   ``calibrate`` or a string ending in ' dB' after ``to_db``.
    """

    def __init__(self, bands, mission=None, time=None, footprint=None, product_meta=None,
                 band_names=None, calibration_tables=None, geo_tie_point=None, band_meta=None, unit=None):

        # assign values
        self.bands = bands
        self.mission = mission
        self.time = time
        self.footprint = footprint
        self.product_meta = product_meta
        self.band_names = band_names
        self.calibration_tables = calibration_tables
        self.geo_tie_point = geo_tie_point
        self.band_meta = band_meta
        self.unit = unit
        # Note that SLC is in strips. Maybe load as list of images

    def __repr__(self):
        return "Mission: %s \n Bands: %s" % (self.mission, str(self.band_names))

    def _derive(self, bands, **overrides):
        """Build a new SarImage holding ``bands`` and this image's meta data.

        Keyword arguments override individual meta data fields. The per-band
        lists are shallow-copied so that mutating the new image (e.g. with
        ``pop``) does not change the lists of this image.
        """
        fields = {
            'mission': self.mission,
            'time': self.time,
            'footprint': self.footprint,
            'product_meta': self.product_meta,
            'band_names': self.band_names,
            'calibration_tables': self.calibration_tables,
            'geo_tie_point': self.geo_tie_point,
            'band_meta': self.band_meta,
            'unit': self.unit,
        }
        fields.update(overrides)
        for name in ('band_names', 'calibration_tables', 'geo_tie_point', 'band_meta'):
            if fields[name] is not None:
                fields[name] = list(fields[name])
        return SarImage(bands, **fields)

    def __getitem__(self, key):
        """Index or slice all bands at once.

        ``image[r, c]`` with two integers returns a list with the value of each
        band at that pixel. ``image[r0:r1:s, c0:c1:t]`` returns a new SarImage
        whose footprint, geo tie points and calibration tables are adjusted to
        the window.
        Raises:
            ValueError: The key is not a pair of integers or a pair of slices.
        """
        # Both a row and a column index/slice must be given
        if len(key) != 2:
            raise ValueError('Need to slice both column and row like test_image[:,:]')
        row_key, column_key = key

        # Point access: return the value of each band
        if isinstance(row_key, (int, np.integer)) and isinstance(column_key, (int, np.integer)):
            return [band[key] for band in self.bands]

        if not (isinstance(row_key, slice) and isinstance(column_key, slice)):
            raise ValueError('Only get at slice is supported: a[2,4] a[:10,3:34]')

        # Resolve None/negative bounds and clamp stops to the image edge,
        # the same way numpy interprets the slices when slicing the bands.
        n_rows, n_columns = self.bands[0].shape[:2]
        row_start, row_stop, row_step = row_key.indices(n_rows)
        column_start, column_stop, column_step = column_key.indices(n_columns)

        # Footprint: coordinates of the window corners, ordered
        # (start, start), (start, stop), (stop, start), (stop, stop)
        footprint_lat = np.zeros(4)
        footprint_long = np.zeros(4)
        corners = [(r, c) for r in (row_start, row_stop) for c in (column_start, column_stop)]
        for k, (r, c) in enumerate(corners):
            footprint_lat[k], footprint_long[k] = self.get_coordinate(r, c)
        footprint = {'latitude': footprint_lat, 'longitude': footprint_long}

        # Express the grid positions of geo tie points and calibration tables
        # in the index system of the window.
        geo_tie_point = copy.deepcopy(self.geo_tie_point)
        calibration_tables = copy.deepcopy(self.calibration_tables)
        for tie_points, cal_table in zip(geo_tie_point, calibration_tables):
            tie_points['row'] = (tie_points['row'] - row_start) / row_step
            tie_points['column'] = (tie_points['column'] - column_start) / column_step
            cal_table['row'] = (cal_table['row'] - row_start) / row_step
            cal_table['column'] = (cal_table['column'] - column_start) / column_step

        # slice the bands
        bands = [band[key] for band in self.bands]

        return self._derive(bands, footprint=footprint,
                            calibration_tables=calibration_tables,
                            geo_tie_point=geo_tie_point)

    def get_index(self, lat, long):
        """Get index of a location by interpolating grid-points
        Args:
            lat(number): Latitude of the location
            long(number): Longitude of location
        Returns:
            row(int): The row index of the location
            column(int): The column index of the location
        Warns:
            UserWarning: The bands give different indices; the first band's index is returned.
        """
        geo_tie_point = self.geo_tie_point
        row = np.zeros(len(geo_tie_point), dtype=int)
        column = np.zeros(len(geo_tie_point), dtype=int)

        # find index for each band
        for i, tie_points in enumerate(geo_tie_point):
            row[i], column[i] = get_functions.get_index_v2(lat, long,
                                                           tie_points['latitude'], tie_points['longitude'],
                                                           tie_points['row'], tie_points['column'])

        # check that the results are the same
        if (abs(row.max() - row.min()) > 0.5) or (abs(column.max() - column.min()) > 0.5):
            warnings.warn('Warning different index found for each band. First index returned')

        return row[0], column[0]

    def get_coordinate(self, row, column):
        """Get coordinate from index by interpolating grid-points
            Args:
                row(number): index of the row of interest position
                column(number): index of the column of interest position
            Returns:
                lat(float): Latitude of the position (mean over the bands)
                long(float): longitude of the position (mean over the bands)
            Warns:
                UserWarning: The bands disagree by more than 0.001 degrees.
            """
        geo_tie_point = self.geo_tie_point
        lat = np.zeros(len(geo_tie_point), dtype=float)
        long = np.zeros(len(geo_tie_point), dtype=float)

        # find coordinate for each band
        for i, tie_points in enumerate(geo_tie_point):
            lat[i], long[i] = get_functions.get_coordinate(row, column,
                                                           tie_points['latitude'], tie_points['longitude'],
                                                           tie_points['row'], tie_points['column'])

        # check that the results are the same
        if (abs(lat.max() - lat.min()) > 0.001) or (abs(long.max() - long.min()) > 0.001):
            warnings.warn('Warning different coordinates found for each band. Mean returned')

        return lat.mean(), long.mean()

    def simple_plot(self, band_index=0, q_max=0.95, stride=1, **kwargs):
        """ Makes a simple image of band and a color bar.
            Args:
                band_index(int): index of the band to plot.
                q_max(number): q_max is the quantile used to set the max of the color range for example
                                q_max = 0.95 shows the lowest 95 percent of pixel values in the color range
                stride(int): Used to skip pixels when showing. Good for large images.
                **kwargs: Passed on to matplotlib imshow
            """
        # The quantile is taken over the full band, not only the strided pixels
        v_max = np.quantile(self.bands[band_index].reshape(-1), q_max)

        plt.imshow(self.bands[band_index][::stride, ::stride], vmax=v_max, **kwargs)
        plt.colorbar()
        plt.show()

        return

    def calibrate(self, mode='gamma', tiles=4):
        """Calibrate every band with its calibration table.
        Args:
            mode(string): Key of the calibration table to use:
                          'sigma_0', 'beta_0', 'gamma' or 'dn'
            tiles(int): number of tiles the image is divided into. This saves memory but reduce speed a bit
        Returns:
            Calibrated image as (SarImage) with unit set to mode
        Raises:
            KeyError: mode is not a key of the calibration tables
        """
        # A missing unit gives no information, so only warn on a known non-raw unit
        if self.unit is not None and 'raw' not in self.unit:
            warnings.warn('Raw is not in units. The image have all ready been calibrated')

        calibrated_bands = []
        for band, cal_table in zip(self.bands, self.calibration_tables):
            calibrated_bands.append(tools.calibration(band, cal_table['row'], cal_table['column'],
                                                      cal_table[mode], tiles=tiles))

        return self._derive(calibrated_bands, unit=mode)

    def to_db(self):
        """Convert the bands to decibel.
        Uses 20*log10 when the unit contains 'amplitude', otherwise 10*log10.
        Returns:
            SarImage with unit '<unit> dB' (or 'dB' when no unit is set)
        """
        is_amplitude = self.unit is not None and 'amplitude' in self.unit
        factor = 20 if is_amplitude else 10
        db_bands = [factor * np.log10(band) for band in self.bands]
        unit = 'dB' if self.unit is None else self.unit + ' dB'

        return self._derive(db_bands, unit=unit)

    def boxcar(self, kernel_size, **kwargs):
        """Simple (kernel_size x kernel_size) boxcar filter.
            Args:
                kernel_size(int): size of kernel
                **kwargs: Additional arguments passed on to tools.boxcar
            Returns:
                Filtered image
        """
        filter_bands = [tools.boxcar(band, kernel_size, **kwargs) for band in self.bands]
        return self._derive(filter_bands)

    @staticmethod
    def _grow_mask(mask):
        """Return a copy of a 1-D boolean mask extended by one element on each side of every True run."""
        grown = mask.copy()
        grown[1:] |= mask[:-1]
        grown[:-1] |= mask[1:]
        return grown

    @staticmethod
    def _reduce_calibration_table(cal, shape):
        """Keep only the calibration grid covering an image of ``shape``.
        One extra grid row/column is kept on each side to ensure interpolation.
        """
        index_row = SarImage._grow_mask((0 < cal['row']) & (cal['row'] < shape[0]))
        index_column = SarImage._grow_mask((0 < cal['column']) & (cal['column'] < shape[1]))

        reduced = {
            "abs_calibration_const": cal["abs_calibration_const"],
            "row": cal["row"][index_row],
            "column": cal["column"][index_column],
        }
        for name in ("azimuth_time", "sigma_0", "beta_0", "gamma", "dn"):
            reduced[name] = cal[name][index_row, :][:, index_column]
        return reduced

    def save(self, path):
        """Save the SarImage object in a folder at path.
            Each attribute is pickled to its own file in the folder. The
            calibration tables are reduced to the part covering the image.
            NOTE: mission and time are not written.
            Args:
                path(str): Path of the folder where the the SarImage is saved.
                        Note that the folder is created and must not exist in advance
            Raises:
                ValueError: There already exist a folder at path
        """
        # Check if folder exists
        if os.path.exists(path):
            raise ValueError('please give a path that is not used')

        os.makedirs(path)

        def dump(obj, file_name):
            # Close each file deterministically
            with open(os.path.join(path, file_name), "wb") as f:
                pickle.dump(obj, f)

        dump(self.product_meta, 'product_meta.pkl')
        dump(self.unit, 'unit.pkl')
        dump(self.footprint, 'footprint.pkl')
        dump(self.geo_tie_point, 'geo_tie_point.pkl')
        dump(self.band_names, 'band_names.pkl')
        dump(self.band_meta, 'band_meta.pkl')
        dump(self.bands, 'bands.pkl')

        reduced_calibration = [self._reduce_calibration_table(self.calibration_tables[i], band.shape)
                               for i, band in enumerate(self.bands)]
        dump(reduced_calibration, 'calibration_tables.pkl')

        return

    def pop(self, index=-1):
        """
        Remove and return band at index (default last) as a single-band SarImage.
        Raises IndexError if list is empty or index is out of range.
        """
        band = self.bands.pop(index)
        name = self.band_names.pop(index)
        calibration_table = self.calibration_tables.pop(index)
        tie_points = self.geo_tie_point.pop(index)
        meta = self.band_meta.pop(index)

        return self._derive([band], band_names=[name], calibration_tables=[calibration_table],
                            geo_tie_point=[tie_points], band_meta=[meta])

    def get_band(self, index):
        """
        Return a single-band SarImage of the band at index. The image itself is unchanged.
        """
        return self._derive([self.bands[index]], band_names=[self.band_names[index]],
                            calibration_tables=[self.calibration_tables[index]],
                            geo_tie_point=[self.geo_tie_point[index]],
                            band_meta=[self.band_meta[index]])

ImportError: attempted relative import with no known parent package

In [3]:
import numpy as np
import xml.etree.ElementTree
import warnings
import datetime
import lxml.etree

def _load_calibration(path):
    """Load sentinel 1 calibration_table file as dictionary from PATH.
    The calibration_table file should be as included in .SAFE format
    retrieved from: https://scihub.copernicus.eu/
    Args:
        path: The path to the calibration_table file
    Returns:
        calibration_table: A dictionary with calibration_table constants, or None
            (with a warning) when the calibration vector list is missing or empty.
            {"abs_calibration_const": float(),
            "row": np.array(int),
            "column": np.array(int),
            "azimuth_time": np.array(datetime64[us]),
            "sigma_0": np.array(float),
            "beta_0": np.array(float),
            "gamma": np.array(float),
            "dn": np.array(float),}
        info: A dictionary with the meta data given in 'adsHeader', or None if not found
            {child[0].tag: child[0].text,
             child[1].tag: child[1].text,
             ...}
    """
    # open xml file
    tree = xml.etree.ElementTree.parse(path)
    root = tree.getroot()

    # Find info
    info_xml = root.findall('adsHeader')
    if len(info_xml) == 1:
        info = {child.tag: child.text for child in info_xml[0]}
    else:
        warnings.warn('Warning adsHeader not found')
        info = None

    # Find calibration_table list
    cal_lists = root.findall('calibrationVectorList')
    if len(cal_lists) != 1:
        warnings.warn('Error loading calibration_table list')
        return None, info
    cal_vectors = cal_lists[0]
    if len(cal_vectors) == 0:
        warnings.warn('Error loading calibration_table list: the list is empty')
        return None, info

    def floats(element):
        # Whitespace separated numbers -> float array
        return np.array([float(value) for value in element.text.split()])

    # The pixel grid of the first vector defines the columns of all tables.
    # Children of each vector are read by position:
    # 0 azimuth time, 1 line, 2 pixel, 3 sigma_0, 4 beta_0, 5 gamma, 6 dn
    pixel = np.array([int(value) for value in cal_vectors[0][2].text.split()])
    shape = (len(cal_vectors), len(pixel))
    azimuth_time = np.empty(shape, dtype='datetime64[us]')
    line = np.empty(len(cal_vectors), dtype=int)
    sigma_0 = np.empty(shape, dtype=float)
    beta_0 = np.empty(shape, dtype=float)
    gamma = np.empty(shape, dtype=float)
    dn = np.empty(shape, dtype=float)

    for i, cal_vec in enumerate(cal_vectors):
        pixel_i = np.array([int(value) for value in cal_vec[2].text.split()])
        if not np.array_equal(pixel, pixel_i):
            warnings.warn('Warning in _load_calibration. The calibration_table data is not on a proper grid')
        azimuth_time[i, :] = np.datetime64(cal_vec[0].text)
        line[i] = int(cal_vec[1].text)
        sigma_0[i, :] = floats(cal_vec[3])
        beta_0[i, :] = floats(cal_vec[4])
        gamma[i, :] = floats(cal_vec[5])
        dn[i, :] = floats(cal_vec[6])

    # Look the absolute calibration constant up by tag; fall back to the
    # positional lookup (second child of the root, first grandchild).
    abs_const = root.find('calibrationInformation/absoluteCalibrationConstant')
    if abs_const is None:
        abs_const = root[1][0]

    # Combine calibration_table info
    calibration_table = {
        "abs_calibration_const": float(abs_const.text),
        "row": line,
        "column": pixel,
        "azimuth_time": azimuth_time,
        "sigma_0": sigma_0,
        "beta_0": beta_0,
        "gamma": gamma,
        "dn": dn,
    }

    return calibration_table, info


def _parse_safe_time(text):
    """Parse a SAFE timestamp such as '2019-01-17T19:15:36.268585' into a datetime.

    Date and time fields are read from fixed character positions, so the
    separators are not checked. Fractional seconds are read as a digit string
    and truncated to microseconds. This avoids the float rounding that
    ``int(float('.ffffff') * 10 ** 6)`` can suffer from, and also accepts
    timestamps without a fractional part.
    """
    fraction = text[20:] if text[19:20] == '.' else ''
    digits = ''
    for char in fraction:
        if not char.isdigit():
            break
        digits += char
    microsecond = int((digits + '000000')[:6])
    return datetime.datetime(int(text[:4]), int(text[5:7]), int(text[8:10]), int(text[11:13]),
                             int(text[14:16]), int(text[17:19]), microsecond)


def _load_meta(SAFE_path):
    """Load manifest.safe as dictionary from SAFE_path.
    The manifest.safe file should be as included in .SAFE format
    retrieved from: https://scihub.copernicus.eu/
    Args:
        SAFE_path: The path to the manifest.safe file
    Returns:
        metadata: A dictionary with meta_data
            example:
            {'mode': 'EW',
             'swath': ['EW'],
             'instrument_config': 1,
             'mission_data_ID': '110917',
             'polarisation': ['HH', 'HV'],
             'product_class': 'S',
             'product_composition': 'Slice',
             'product_type': 'GRD',
             'product_timeliness': 'Fast-24h',
             'slice_product_flag': 'true',
             'segment_start_time': datetime.datetime(2019, 1, 17, 19, 12, 32, 164986),
             'slice_number': 4,
             'total_slices': 4,
             'footprint': {'latitude': array([69.219566, 69.219566, 69.219566, 69.219566]),
                            'longitude': array([-35.149223, -35.149223, -35.149223, -35.149223])},
             'nssdc_identifier': '2016-025A',
             'mission': 'SENTINEL-1B',
             'orbit_number': array([14538, 14538]),
             'relative_orbit_number': array([162, 162]),
             'cycle_number': 89,
             'phase_identifier': 1,
             'start_time': datetime.datetime(2019, 1, 17, 19, 15, 36, 268585),
             'stop_time': datetime.datetime(2019, 1, 17, 19, 16, 25, 598196),
             'pass': 'ASCENDING',
             'ascending_node_time': datetime.datetime(2019, 1, 17, 18, 57, 16, 851007),
             'start_time_ANX': 1099418.0,
             'stop_time_ANX': 1148747.0}
        error: List of dictionary keys that was not found.
    """
    # Every tag is looked up on its own and validated by count, because the
    # manifest layout is not guaranteed to stay stable. Failed lookups are
    # recorded in `error` instead of raising.

    # Open the xml like file; characters that cannot be re-encoded are dropped
    with open(SAFE_path) as f:
        safe_text = f.read()
    safe_xml = lxml.etree.fromstring(safe_text.encode(errors='ignore'))

    metadata = {}
    error = []

    # Namespace prefixes used in the tags of the file
    prefix1 = '{http://www.esa.int/safe/sentinel-1.0}'
    prefix2 = '{http://www.esa.int/safe/sentinel-1.0/sentinel-1}'
    prefix3 = '{http://www.esa.int/safe/sentinel-1.0/sentinel-1/sar/level-1}'

    def find(prefix, tag):
        # All elements anywhere in the manifest with the namespaced tag
        return list(safe_xml.iterfind(".//" + prefix + tag))

    def read_single(key, prefix, tag, convert=None, error_name=None):
        # Tag must occur exactly once; store its (converted) text
        values = find(prefix, tag)
        if len(values) == 1:
            text = values[0].text
            metadata[key] = text if convert is None else convert(text)
        else:
            error.append(key if error_name is None else error_name)

    def read_pair(key, tag):
        # Tag must occur exactly twice; store both values as an int array
        values = find(prefix1, tag)
        if len(values) == 2:
            metadata[key] = np.array([int(values[0].text), int(values[1].text)])
        else:
            error.append(key)

    def read_list(key, tag):
        # Tag may occur any number of times (at least once); store all texts
        values = find(prefix3, tag)
        if len(values) > 0:
            metadata[key] = [child.text for child in values]
        else:
            error.append(key)

    read_single('nssdc_identifier', prefix1, 'nssdcIdentifier', error_name='nssdcIdentifier')

    # Mission name is the family name followed by the satellite number
    family = find(prefix1, 'familyName')
    number = find(prefix1, 'number')
    if len(family) > 0 and len(number) == 1:
        metadata['mission'] = family[0].text + number[0].text
    else:
        error.append('mission')

    read_pair('orbit_number', 'orbitNumber')
    read_pair('relative_orbit_number', 'relativeOrbitNumber')
    read_single('cycle_number', prefix1, 'cycleNumber', int)
    read_single('phase_identifier', prefix1, 'phaseIdentifier', int)
    read_single('start_time', prefix1, 'startTime', _parse_safe_time)
    read_single('stop_time', prefix1, 'stopTime', _parse_safe_time)
    read_single('pass', prefix2, 'pass')
    read_single('ascending_node_time', prefix2, 'ascendingNodeTime', _parse_safe_time)
    read_single('start_time_ANX', prefix2, 'startTimeANX', float)
    read_single('stop_time_ANX', prefix2, 'stopTimeANX', float)
    read_single('mode', prefix3, 'mode')
    read_list('swath', 'swath')
    read_single('instrument_config', prefix3, 'instrumentConfigurationID', int)
    read_single('mission_data_ID', prefix3, 'missionDataTakeID')
    read_list('polarisation', 'transmitterReceiverPolarisation')
    read_single('product_class', prefix3, 'productClass')
    read_single('product_composition', prefix3, 'productComposition')
    read_single('product_type', prefix3, 'productType')
    read_single('product_timeliness', prefix3, 'productTimelinessCategory')
    read_single('slice_product_flag', prefix3, 'sliceProductFlag')
    read_single('segment_start_time', prefix3, 'segmentStartTime', _parse_safe_time)
    read_single('slice_number', prefix3, 'sliceNumber', int)
    read_single('total_slices', prefix3, 'totalSlices', int)

    # Footprint: whitespace separated "lat,lon" pairs. At least four entries
    # are allocated as before; additional points no longer raise IndexError.
    values = find('{http://www.opengis.net/gml}', 'coordinates')
    if len(values) == 1:
        coordinates = values[0].text.split()
        n_points = max(4, len(coordinates))
        lat = np.zeros(n_points)
        lon = np.zeros(n_points)
        for i, coordinate in enumerate(coordinates):
            parts = coordinate.split(',')
            lat[i] = float(parts[0])
            lon[i] = float(parts[1])
        metadata['footprint'] = {'latitude': lat, 'longitude': lon}
    else:
        error.append('footprint')

    return metadata, error


def _load_annotation(path):
    """Load sentinel 1 annotation file as dictionary from PATH.
     The annotation file should be as included in .SAFE format
     retrieved from: https://scihub.copernicus.eu/
     Note that the file contains more information. Only the relevant have been chosen
     Args:
         path: The path to the annotation file
     Returns:
         geo_locations: A dictionary with geo location tie-points, or None (with a
             warning) when the geolocation grid is missing or empty
             {'azimuth_time': np.array(datetime64[us]),
            'slant_range_time': np.array(float),
            'row': np.array(int),
            'column': np.array(int),
            'latitude': np.array(float),
            'longitude': np.array(float),
            'height': np.array(float),
            'incidence_angle': np.array(float),
            'elevation_angle': np.array(float)}
         info: A dictionary with the meta data given in 'adsHeader', or None if not found.
             It is returned even when the geolocation grid is missing.
             {child[0].tag: child[0].text,
              child[1].tag: child[1].text,
              ...}
     """
    # open xml file
    tree = xml.etree.ElementTree.parse(path)
    root = tree.getroot()

    # Find info
    info_xml = root.findall('adsHeader')
    if len(info_xml) == 1:
        info = {child.tag: child.text for child in info_xml[0]}
    else:
        warnings.warn('Warning adsHeader not found')
        info = None

    # Find geo location list: the first child of the geolocationGrid element
    geo_grid = root.findall('geolocationGrid')
    if len(geo_grid) != 1 or len(geo_grid[0]) == 0:
        warnings.warn('Warning geolocationGrid not found')
        return None, info
    geo_points = geo_grid[0][0]

    # initialize arrays
    n_points = len(geo_points)
    azimuth_time = np.empty(n_points, dtype='datetime64[us]')
    slant_range_time = np.zeros(n_points, dtype=float)
    line = np.zeros(n_points, dtype=int)
    pixel = np.zeros(n_points, dtype=int)
    latitude = np.zeros(n_points, dtype=float)
    longitude = np.zeros(n_points, dtype=float)
    height = np.zeros(n_points, dtype=float)
    incidence_angle = np.zeros(n_points, dtype=float)
    elevation_angle = np.zeros(n_points, dtype=float)

    # get the data; the children of each point are read by position
    for i, point in enumerate(geo_points):
        azimuth_time[i] = np.datetime64(point[0].text)
        slant_range_time[i] = float(point[1].text)
        line[i] = int(point[2].text)
        pixel[i] = int(point[3].text)
        latitude[i] = float(point[4].text)
        longitude[i] = float(point[5].text)
        height[i] = float(point[6].text)
        incidence_angle[i] = float(point[7].text)
        elevation_angle[i] = float(point[8].text)

    # Combine geo_locations info
    geo_locations = {
        'azimuth_time': azimuth_time,
        'slant_range_time': slant_range_time,
        'row': line,
        'column': pixel,
        'latitude': latitude,
        'longitude': longitude,
        'height': height,
        'incidence_angle': incidence_angle,
        'elevation_angle': elevation_angle
    }

    return geo_locations, info

In [4]:
import numpy as np
import os
import pickle
import warnings
from itertools import compress
import rasterio

from .sarpy_class import SarImage
from . import s1
from .get_functions import get_index_v2, get_coordinate


def s1_load(path, polarisation='all', location=None, size=None):
    """Function to load SAR image into SarImage python object.
        Currently supports: unzipped Sentinel 1 GRDH products
        Args:
            path(str): Path to the folder containing the SAR image
                    as retrieved from: https://scihub.copernicus.eu/
            polarisation(str or list of str): 'all' to load every polarisation listed in the
                    product meta data, a single polarisation (e.g. 'vv') or a list of
                    polarisations. Polarisation names are case-insensitive.
            location(array/list): [latitude,longitude] Location to center the image.
                                    If None the entire Image is loaded
            size(array/list): [rows, columns] Extent of the window to load: size[0] is the
                                    number of rows (height) and size[1] the number of
                                    columns (width). If None the entire Image is loaded
        Returns:
            SarImage: object with the SAR measurements and meta data from path. Meta data index
                    and foot print are adjusted to the window
        Raises:
            ValueError: A requested polarisation has no measurement file, the location is not
                    inside the footprint, or the window is not inside the image.
        """
    # manifest.safe (the second value returned by _load_meta is not used here)
    meta, _ = s1._load_meta(os.path.join(path, 'manifest.safe'))

    # annotation
    annotation_path = os.path.join(path, 'annotation')
    xml_files = [file for file in os.listdir(annotation_path) if file.endswith('xml')]
    annotation_temp = [s1._load_annotation(os.path.join(annotation_path, file)) for file in xml_files]

    # calibration_tables
    path_cal = os.path.join(annotation_path, 'calibration')
    cal_files = [file for file in os.listdir(path_cal) if file.startswith('calibration')]
    calibration_temp = [s1._load_calibration(os.path.join(path_cal, file)) for file in cal_files]

    # Normalise the requested polarisations. A plain string other than 'all' is treated as a
    # single polarisation instead of being iterated character by character.
    if isinstance(polarisation, str):
        polarisation = meta['polarisation'] if polarisation == 'all' else [polarisation.upper()]
    else:
        polarisation = [elem.upper() for elem in polarisation]
    n_bands = len(polarisation)

    # measurement
    measurement_path = os.path.join(path, 'measurement')
    tiff_files = [file for file in os.listdir(measurement_path) if file.endswith('tiff')]

    measurement_temp = []
    try:
        with warnings.catch_warnings():  # Ignore the "NotGeoreferencedWarning" when opening the tiff
            warnings.simplefilter("ignore")
            for file in tiff_files:
                measurement_temp.append(rasterio.open(os.path.join(measurement_path, file)))

        # only take bands of interest and sort them in the requested order
        calibration_tables = [None] * n_bands
        geo_tie_point = [None] * n_bands
        band_meta = [None] * n_bands
        measurement = [None] * n_bands

        for i, pol in enumerate(polarisation):
            # The 4th '-'-separated field of the tiff file name is compared with the polarisation.
            for file, dataset in zip(tiff_files, measurement_temp):
                if file.split('-')[3].upper() == pol:
                    measurement[i] = dataset

            for band in calibration_temp:
                if band[1]['polarisation'] == pol:
                    calibration_tables[i] = band[0]

            for band in annotation_temp:
                if band[1]['polarisation'] == pol:
                    geo_tie_point[i] = band[0]
                    band_meta[i] = band[1]

            if measurement[i] is None:
                raise ValueError('No measurement file found for polarisation %s' % pol)

        # Check that there is one band in each tiff
        for dataset in measurement:
            if dataset.count != 1:
                warnings.warn('Warning tiff file contains several bands. First band read from each tiff file')

        if (location is None) or (size is None):
            bands = [image.read(1) for image in measurement]
        else:
            # Check location is in foot print
            maxlat = meta['footprint']['latitude'].max()
            minlat = meta['footprint']['latitude'].min()
            maxlong = meta['footprint']['longitude'].max()
            minlong = meta['footprint']['longitude'].min()

            if not (minlat < location[0] < maxlat and minlong < location[1] < maxlong):
                raise ValueError('Location not inside the footprint')

            # get the index of the location in every band
            row = np.zeros(n_bands, dtype=int)
            column = np.zeros(n_bands, dtype=int)
            for i, tie_point in enumerate(geo_tie_point):
                row[i], column[i] = get_index_v2(location[0], location[1], tie_point['latitude'],
                                                 tie_point['longitude'], tie_point['row'],
                                                 tie_point['column'])
            # check if index are the same for all bands
            if (abs(row.max() - row.min()) > 0.5) or (abs(column.max() - column.min()) > 0.5):
                warnings.warn('Warning different index found for each band. First index returned')

            # Find the window centred on the index found for the first band
            half_rows = int(size[0] / 2)
            half_columns = int(size[1] / 2)
            row_index_min = row[0] - half_rows
            row_index_max = row[0] + half_rows
            column_index_min = column[0] - half_columns
            column_index_max = column[0] + half_columns

            # Check if window is in image
            if row_index_max < 0 or column_index_max < 0:
                raise ValueError('Error window not in image ')

            if row_index_min < 0:
                warnings.warn('Extend out of image. Window constrained ')
                row_index_min = 0

            if column_index_min < 0:
                warnings.warn('Extend out of image. Window constrained ')
                column_index_min = 0

            for image in measurement:
                # A window starting at (or beyond) the image size would be empty.
                if row_index_min >= image.height or column_index_min >= image.width:
                    raise ValueError('Error window not in image')

                if row_index_max > image.height:
                    warnings.warn('Extend out of image. Window constrained ')
                    row_index_max = image.height

                if column_index_max > image.width:
                    warnings.warn('Extend out of image. Window constrained ')
                    column_index_max = image.width

            # Adjust footprint to the corners of the window (before the tie points are shifted)
            footprint_lat = np.zeros(4)
            footprint_long = np.zeros(4)
            window = ((row_index_min, row_index_max), (column_index_min, column_index_max))
            first = geo_tie_point[0]

            for i in range(2):
                for j in range(2):
                    lat_i, long_i = get_coordinate(window[0][i], window[1][j], first['latitude'],
                                                   first['longitude'], first['row'], first['column'])
                    footprint_lat[2 * i + j] = lat_i
                    footprint_long[2 * i + j] = long_i

            meta['footprint']['latitude'] = footprint_lat
            meta['footprint']['longitude'] = footprint_long

            # Shift geo_tie_point and calibration_tables indices so they are relative to the window
            for i in range(n_bands):
                geo_tie_point[i]['row'] = geo_tie_point[i]['row'] - row_index_min
                geo_tie_point[i]['column'] = geo_tie_point[i]['column'] - column_index_min

                calibration_tables[i]['row'] = calibration_tables[i]['row'] - row_index_min
                calibration_tables[i]['column'] = calibration_tables[i]['column'] - column_index_min

            # load the data window
            bands = [image.read(1, window=window) for image in measurement]
    finally:
        # read() returns in-memory arrays, so every opened tiff (including those of
        # polarisations that were not requested) can be closed here.
        for dataset in measurement_temp:
            dataset.close()

    return SarImage(bands, mission=meta['mission'], time=meta['start_time'],
                    footprint=meta['footprint'], product_meta=meta,
                    band_names=polarisation, calibration_tables=calibration_tables,
                    geo_tie_point=geo_tie_point, band_meta=band_meta, unit='raw amplitude')


def load(path):
    """ Load SarImage saved with the SarImage save method (img.save(path)).
        Each attribute is read from its own '<attribute>.pkl' file inside path.
        Note: unpickling can execute arbitrary code, so only load folders from trusted sources.
        Args:
            path(str): Path to the folder where SarImage is saved.
        Returns:
            SarImage
        Raises:
            FileNotFoundError: If one of the expected .pkl files is missing.
        """

    def _read_pickle(name):
        # Context manager closes the file handle once the object has been unpickled.
        with open(os.path.join(path, name + '.pkl'), 'rb') as file:
            return pickle.load(file)

    product_meta = _read_pickle('product_meta')
    unit = _read_pickle('unit')
    footprint = _read_pickle('footprint')
    geo_tie_point = _read_pickle('geo_tie_point')
    band_names = _read_pickle('band_names')
    band_meta = _read_pickle('band_meta')
    bands = _read_pickle('bands')
    calibration_tables = _read_pickle('calibration_tables')

    return SarImage(bands, mission=product_meta['mission'], time=product_meta['start_time'],
                    footprint=footprint, product_meta=product_meta,
                    band_names=band_names, calibration_tables=calibration_tables,
                    geo_tie_point=geo_tie_point, band_meta=band_meta, unit=unit)

ModuleNotFoundError: No module named 'rasterio'

In [5]:
import numpy as np
from scipy.interpolate import RegularGridInterpolator
from scipy import ndimage


def calibration(band, rows, columns, calibration_values, tiles=4):
    """Calibrates image using linear interpolation.
    See https://sentinel.esa.int/documents/247904/685163/S1-Radiometric-Calibration-V1.0.pdf
    The calibration values are linearly interpolated to every pixel and the result is
    (band / interpolated_calibration_value) ** 2.
    Args:
        band(2d numpy array): The non calibrated image
        rows(1d numpy array): row indices of the calibration grid (ascending)
        columns(1d numpy array): column indices of the calibration grid (ascending)
        calibration_values(2d numpy array): grid of calibration values,
            shape (len(rows), len(columns))
        tiles(int): number of column tiles the image is divided into.
            This saves memory but reduce speed a bit
    Returns:
        calibrated image (2d numpy array of float)
    Raises:
        ValueError: (from RegularGridInterpolator) if a pixel lies outside the calibration grid.
    """

    # Create interpolation function
    f = RegularGridInterpolator((rows, columns), calibration_values)

    n_rows, n_columns = band.shape
    result = np.zeros(band.shape)
    row_index = np.arange(n_rows)

    # Calibrate one tile of columns at the time. Integer arithmetic guarantees the last tile
    # ends exactly at n_columns; float division could round down and leave the last
    # column(s) uncalibrated.
    column_start = 0
    for i in range(tiles):
        column_end = n_columns * (i + 1) // tiles
        if column_end == column_start:
            continue  # empty tile (more tiles than columns)
        # Create array of points where calibration is needed
        column_mesh, row_mesh = np.meshgrid(np.arange(column_start, column_end), row_index)
        points = np.column_stack([row_mesh.ravel(), column_mesh.ravel()])
        # Get the image tile and the calibration values for it
        img_tile = band[:, column_start:column_end]
        img_cal = f(points).reshape(img_tile.shape)
        # Set in result
        result[:, column_start:column_end] = img_tile / img_cal

        column_start = column_end
    return result ** 2


def boxcar(img, kernel_size, **kwargs):
    """Simple (kernel_size x kernel_size) boxcar (moving average) filter.
    Integer and boolean images are converted to float64 first: scipy.ndimage.convolve
    stores the window sums in an output of the input dtype by default, which overflows
    for e.g. uint8/uint16 images.
    Args:
        img(2d numpy array): image
        kernel_size(int): size of kernel
        **kwargs: Additional arguments passed to scipy.ndimage.convolve
    Returns:
        Filtered image
    Raises:
    """
    img = np.asarray(img)
    if img.dtype == bool or np.issubdtype(img.dtype, np.integer):
        img = img.astype(float)

    # For small kernels simple convolution
    if kernel_size < 8:
        kernel = np.ones([kernel_size, kernel_size])
        box_img = ndimage.convolve(img, kernel, **kwargs) / kernel_size**2

    # For large kernels use Separable Filters. (https://www.youtube.com/watch?v=SiJpkucGa1o)
    else:
        kernel1 = np.ones([kernel_size, 1])
        kernel2 = np.ones([1, kernel_size])
        box_img = ndimage.convolve(img, kernel1, **kwargs) / kernel_size
        box_img = ndimage.convolve(box_img, kernel2, **kwargs) / kernel_size

    return box_img

In [6]:
import numpy as np
from scipy import interpolate
from scipy.optimize import minimize


def get_coordinate(row, column, lat_gridpoints, long_gridpoints, row_gridpoints, column_gridpoints):
    """Get coordinate from index by interpolating grid-points
    Linear interpolation on the Delaunay triangulation of the (row, column) grid-points, the
    same as scipy.interpolate.griddata(..., method='linear'). Latitude and longitude are
    interpolated with one shared triangulation instead of building it twice.
    Args:
        row(number): index of the row of interest position
        column(number): index of the column of interest position
        lat_gridpoints(numpy array of length n): Latitude of grid-points
        long_gridpoints(numpy array of length n): Longitude of grid-points
        row_gridpoints(numpy array of length n): row of grid-points
        column_gridpoints(numpy array of length n): column of grid-points
    Returns:
        lat(float): Latitude of the position (nan if outside the convex hull of the grid-points)
        long(float): longitude of the position (nan if outside the convex hull of the grid-points)
    Raises:
    """

    points = np.vstack([row_gridpoints, column_gridpoints]).transpose()
    values = np.column_stack([lat_gridpoints, long_gridpoints])
    interpolator = interpolate.LinearNDInterpolator(points, values)
    lat, long = interpolator(np.array([[row, column]], dtype=float))[0]
    return float(lat), float(long)


def get_index_v1(lat, long, lat_gridpoints, long_gridpoints, row_gridpoints, column_gridpoints):
    """Get index of a location by interpolating grid-points
    Linear interpolation on the Delaunay triangulation of the (latitude, longitude)
    grid-points (as scipy.interpolate.griddata(..., method='linear')), rounded to the
    nearest integer index. Row and column share one triangulation.
    Args:
        lat(number): Latitude of the location
        long(number): Longitude of location
        lat_gridpoints(numpy array of length n): Latitude of grid-points
        long_gridpoints(numpy array of length n): Longitude of grid-points
        row_gridpoints(numpy array of length n): row of grid-points
        column_gridpoints(numpy array of length n): column of grid-points
    Returns:
        row(int): The row index of the location
        column(int): The column index of the location
    Raises:
        ValueError: If the location is outside the convex hull of the grid-points.
    """

    points = np.vstack([lat_gridpoints, long_gridpoints]).transpose()
    values = np.column_stack([row_gridpoints, column_gridpoints])
    interpolator = interpolate.LinearNDInterpolator(points, values)
    row, column = interpolator(np.array([[lat, long]], dtype=float))[0]
    # Outside the triangulation the interpolator returns nan, which cannot be an index.
    if np.isnan(row) or np.isnan(column):
        raise ValueError('Location (%s, %s) is outside the grid-points' % (lat, long))
    return int(np.round(row)), int(np.round(column))


def get_index_v2(lat, long, lat_gridpoints, long_gridpoints, row_gridpoints, column_gridpoints):
    """Get index of a location by numerically inverting "get_coordinate".

    Same as "get_index_v1" but consistent with "get_coordinate". Starting from the
    "get_index_v1" estimate, scipy.optimize.minimize searches for the (row, column) whose
    "get_coordinate" result is closest to (lat, long). Drawback is that it is slower.
    Args:
        lat(number): Latitude of the location
        long(number): Longitude of location
        lat_gridpoints(numpy array of length n): Latitude of grid-points
        long_gridpoints(numpy array of length n): Longitude of grid-points
        row_gridpoints(numpy array of length n): row of grid-points
        column_gridpoints(numpy array of length n): column of grid-points
    Returns:
        row(int): The row index of the location
        column(int): The column index of the location
    """
    grid = (lat_gridpoints, long_gridpoints, row_gridpoints, column_gridpoints)

    def squared_misfit(index):
        # Squared coordinate error; the differences are scaled by 100 before squaring.
        lat_fit, long_fit = get_coordinate(index[0], index[1], *grid)
        return ((lat - lat_fit) * 100) ** 2 + ((long - long_fit) * 100) ** 2

    start = list(get_index_v1(lat, long, *grid))
    best = minimize(squared_misfit, start).x
    return int(round(best[0])), int(round(best[1]))