In [5]:
    import tensorflow as tf
# Ask TensorFlow which GPU devices it can see; the recorded output `[]` means none were found.
tf.config.list_physical_devices("GPU")

[]

In [3]:
# Standard library
import dataclasses as dc
import json
import os
import re
import shutil
import zipfile
from dataclasses import field, asdict
from datetime import date, datetime
from pathlib import Path
from pprint import pprint
from tempfile import TemporaryDirectory
from typing import Optional

# Third party
import numpy as np
import pandas as pd
from asf_hyp3 import API
from shapely import wkt
from shapely.geometry import Polygon

# Project
from src.api_functions import hyp3_login, grab_subscription
from src.config import NETWORK_DEMS as dems
from src.model import load_model

# NOTE(review): `gdal` is used by the raster helpers below but is not imported in
# any visible cell -- confirm the intended import (e.g. `from osgeo import gdal`).

In [24]:
@dc.dataclass()
class Product:
    """A product record plus the acquisition window parsed from its granule name.

    When ``granule`` is given, ``start`` and ``end`` are always derived from the
    timestamps embedded in it (any values passed for those two fields are
    overwritten). When ``granule`` is ``None`` the two fields are left exactly
    as supplied instead of failing during construction.

    Raises:
        ValueError: if ``granule`` is given but contains no recognisable
            start/end timestamps.
    """
    name: str
    granule: Optional[str] = None
    url: Optional[str] = None
    # Forward reference: the annotation is not evaluated when the class is created.
    shape: Optional["Polygon"] = None
    start: Optional[datetime] = None
    end: Optional[datetime] = None

    # Deliberately unannotated so that it is a plain class attribute (compiled
    # once, shared by every instance) rather than a dataclass field.
    product_time_regex = re.compile(
        r"S.*1SDV_(?P<start_year>\d{4})(?P<start_month>\d{2})(?P<start_day>\d{2})T(?P<start_hour>\d{2})("
        r"?P<start_minute>\d{2})(?P<start_second>\d{2})_(?P<end_year>\d{4})(?P<end_month>\d{2})(?P<end_day>\d{2})T("
        r"?P<end_hour>\d{2})(?P<end_minute>\d{2})(?P<end_second>\d{2})_*")

    def __post_init__(self):
        # Nothing to parse without a granule; keep whatever start/end were passed in.
        if self.granule is None:
            return
        self.start = self.make_start(self.granule)
        self.end = self.make_end(self.granule)

    def _parse_time(self, product_name, prefix) -> datetime:
        """Build a datetime from the ``<prefix>_year`` ... ``<prefix>_second`` groups."""
        regex_match = self.product_time_regex.match(product_name)
        if regex_match is None:
            raise ValueError(
                f"Cannot parse acquisition times from granule name {product_name!r}"
            )
        units = ("year", "month", "day", "hour", "minute", "second")
        return datetime(*(int(regex_match[f"{prefix}_{unit}"]) for unit in units))

    def make_start(self, product_name) -> datetime:
        """Return the acquisition start time embedded in ``product_name``."""
        return self._parse_time(product_name, "start")

    def make_end(self, product_name) -> datetime:
        """Return the acquisition end time embedded in ``product_name``."""
        return self._parse_time(product_name, "end")

    def to_json(self):
        """Serialise the dataclass fields to a JSON string.

        ``start`` and ``end`` are written in ISO-8601 form and ``shape`` as its
        ``str()`` representation; any of the three that is ``None`` becomes
        JSON ``null``.
        """
        # Every field is flat, so a shallow field dict is enough; this avoids
        # the deep copy of ``shape`` that ``asdict`` would perform.
        metadata = {f.name: getattr(self, f.name) for f in dc.fields(self)}
        metadata['start'] = None if self.start is None else self.start.isoformat()
        metadata['end'] = None if self.end is None else self.end.isoformat()
        metadata['shape'] = None if self.shape is None else str(self.shape)
        return json.dumps(metadata)
        
        
    

In [18]:
def product_time(product_name):
    """Parse the acquisition start and end times embedded in a granule name.

    Args:
        product_name: name containing ``..1SDV_<YYYYMMDD>T<HHMMSS>_<YYYYMMDD>T<HHMMSS>``.

    Returns:
        ``(start, end)`` as two ``datetime`` objects.

    Raises:
        ValueError: if the name does not contain the expected timestamps.
    """
    product_time_regex = re.compile(
        r"S.*1SDV_(?P<start_year>\d{4})(?P<start_month>\d{2})(?P<start_day>\d{2})T(?P<start_hour>\d{2})("
        r"?P<start_minute>\d{2})(?P<start_second>\d{2})_(?P<end_year>\d{4})(?P<end_month>\d{2})(?P<end_day>\d{2})T("
        r"?P<end_hour>\d{2})(?P<end_minute>\d{2})(?P<end_second>\d{2})_*")

    regex_match = product_time_regex.match(product_name)
    if regex_match is None:
        raise ValueError(
            f"Cannot parse acquisition times from product name {product_name!r}"
        )

    # Every named group is a run of digits: convert str -> int for datetime().
    time_dict = {k: int(v) for k, v in regex_match.groupdict().items()}

    start = datetime(time_dict["start_year"], time_dict["start_month"], time_dict["start_day"],
                     time_dict["start_hour"], time_dict["start_minute"], time_dict["start_second"])

    end = datetime(time_dict["end_year"], time_dict["end_month"], time_dict["end_day"],
                   time_dict["end_hour"], time_dict["end_minute"], time_dict["end_second"])

    return start, end

In [5]:
def product_in_time_bounds(product_name, start, end):
    """Tell whether a product was acquired strictly inside the ``start``..``end`` window.

    The acquisition times come from ``product_time(product_name)``; both
    comparisons are strict, so a product touching either bound is rejected.
    """
    acquired_from, acquired_until = product_time(product_name)

    # Guard clause: the acquisition must begin after the window opens.
    if not acquired_from > start:
        return False
    return acquired_until < end

In [22]:
def get_sub_products(api, sub_id, start, end):
    """Collect the subscription's products whose granule falls inside the time window.

    Calls ``api.get_products(sub_id=sub_id)`` and wraps every entry accepted by
    ``product_in_time_bounds`` in a ``Product`` built from the entry's
    ``'name'``, ``'granule'`` and ``'url'`` values.
    """
    return [
        Product(entry['name'], entry['granule'], entry['url'])
        for entry in api.get_products(sub_id=sub_id)
        if product_in_time_bounds(entry['granule'], start, end)
    ]
    


In [None]:
def stride_tile_image(

        image: np.ndarray, width: int = dems, height: int = dems
) -> np.ndarray:
    """Cut a 2-D image into non-overlapping ``height`` x ``width`` tiles.

    Tiles are returned in row-major order (left to right, then top to bottom),
    i.e. tile ``i * ncols + j`` covers
    ``image[i*height:(i+1)*height, j*width:(j+1)*width]``.

    Args:
        image: 2-D array whose sides are exact multiples of the tile size.
        width: tile width in pixels.
        height: tile height in pixels.

    Returns:
        A freshly allocated array of shape ``(nrows * ncols, height, width)``
        with the dtype of ``image``.

    Raises:
        AssertionError: if the image cannot be tiled evenly.
    """
    _nrows, _ncols = image.shape

    nrows, _m = divmod(_nrows, height)
    ncols, _n = divmod(_ncols, width)

    assert _m == 0, "Image must be evenly tileable. Please pad it first"
    assert _n == 0, "Image must be evenly tileable. Please pad it first"

    # Pure reshape/swapaxes instead of as_strided: the former version applied
    # the strides of `image` to np.ravel(image), which is only valid when
    # `image` is C-contiguous (ravel copies anything else into a new layout).
    blocks = image.reshape(nrows, height, ncols, width).swapaxes(1, 2)

    # Force one C-ordered copy so the result never aliases the input.
    return np.array(blocks, order="C").reshape(nrows * ncols, height, width)


def get_tile_dimensions(height: int, width: int, tile_size: int):
    """Return ``(rows, cols)``: how many tiles of ``tile_size`` cover each axis.

    Both counts are rounded up, so a partially filled tile still counts.
    """
    tiles_down = int(np.ceil(height / tile_size))
    tiles_across = int(np.ceil(width / tile_size))
    return tiles_down, tiles_across


def write_mask_to_file(
        mask: np.ndarray, file_name: str, projection: str, geo_transform: str
) -> None:
    """Write a 2-D mask to ``file_name`` as a single-band GeoTIFF.

    Args:
        mask: 2-D array; axis 0 is rows, axis 1 is columns.
        file_name: path of the GeoTIFF to create.
        projection: value handed to ``SetProjection``.
        geo_transform: value handed to ``SetGeoTransform``. NOTE(review): the
            only caller in this notebook passes the result of
            ``GetGeoTransform()`` straight through, so the ``str`` annotation
            is probably inaccurate -- confirm.

    Raises:
        RuntimeError: if GDAL returns no dataset for ``file_name``.
    """
    n_rows, n_cols = mask.shape

    # Create() takes the raster size as (x size, y size) = (columns, rows).
    out_image = gdal.GetDriverByName('GTiff').Create(
        file_name, n_cols, n_rows, bands=1
    )
    if out_image is None:
        # Without gdal.UseExceptions() a failed Create() returns None.
        raise RuntimeError(f"GDAL could not create '{file_name}'")

    out_image.SetProjection(projection)
    out_image.SetGeoTransform(geo_transform)

    band = out_image.GetRasterBand(1)
    band.WriteArray(mask)
    band.SetNoDataValue(0)  # 0 is declared as the no-data value
    out_image.FlushCache()


def pad_image(image: np.ndarray, to: int) -> np.ndarray:
    """Zero-pad a 2-D image on its bottom/right so both sides are multiples of ``to``.

    The original pixels stay in the top-left corner. The result is allocated
    with ``np.zeros`` using its default dtype.
    """
    rows, cols = image.shape
    tiles_down, tiles_across = get_tile_dimensions(rows, cols, to)

    canvas = np.zeros((tiles_down * to, tiles_across * to))
    canvas[:rows, :cols] = image
    return canvas


# TODO: Cut the edge fill from the final mask (make it prettier!)
# TODO: FIX VV/VH ISSUE. ONLY WORKS WITH VV RIGHT NOW!
# TODO: Split get vv/vh tiles into functions
# TODO: Try a different tiling method (not strided)
def create_water_mask(
        model_path: str, vv_path: str, vh_path: str, outfile: str, verbose: int = 0
):
    """Predict a water mask from a VV/VH raster pair and write it to ``outfile``.

    Both rasters are zero-padded to a multiple of ``dems``, cut into
    ``dems`` x ``dems`` tiles, stacked as two channels and passed to the model
    loaded from ``model_path``. The rounded predictions are stitched back
    together, pixels that are 0 in the padded VV raster are forced to 0, and
    the result is written with the VV raster's projection and geotransform.

    Args:
        model_path: passed to ``load_model``.
        vv_path: path of the VV raster.
        vh_path: path of the VH raster.
        outfile: path of the mask to write.
        verbose: forwarded to ``model.predict``.

    Raises:
        FileNotFoundError: if either input raster does not exist.
        RuntimeError: if GDAL cannot open an input raster.
        ValueError: if the two rasters do not produce identically shaped tiles.
    """
    if not os.path.isfile(vv_path):
        raise FileNotFoundError(f"Tiff '{vv_path}' does not exist")

    if not os.path.isfile(vh_path):
        raise FileNotFoundError(f"Tiff '{vh_path}' does not exist")

    def get_tiles(img_path):
        """Read one raster; return its tiles, tile grid, zero pixels and georeferencing."""
        dataset = gdal.Open(img_path)
        if dataset is None:
            # Without gdal.UseExceptions() a failed Open() returns None.
            raise RuntimeError(f"GDAL could not open '{img_path}'")
        img_array = dataset.ReadAsArray()
        n_rows, n_cols = get_tile_dimensions(*img_array.shape, tile_size=dems)
        padded_img_array = pad_image(img_array, dems)
        # Pixels equal to 0 (including the padding) are treated as invalid.
        invalid_pixels = np.nonzero(padded_img_array == 0.0)
        img_tiles = stride_tile_image(padded_img_array)
        # `dataset` is local, so the handle is dropped when this helper returns.
        return (img_tiles, n_rows, n_cols, invalid_pixels,
                dataset.GetProjection(), dataset.GetGeoTransform())

    # Get vv tiles
    vv_tiles, vv_rows, vv_cols, vv_pixels, vv_projection, vv_transform = get_tiles(vv_path)

    # Get vh tiles; only the tiles are used, georeferencing comes from VV.
    vh_tiles, *_ = get_tiles(vh_path)

    # Fail early with a readable message instead of inside np.stack.
    if vv_tiles.shape != vh_tiles.shape:
        raise ValueError(
            f"VV tiles {vv_tiles.shape} and VH tiles {vh_tiles.shape} differ; "
            "both rasters must have the same dimensions"
        )

    model = load_model(model_path)

    # Predict masks: one (dems, dems, 2) sample per tile.
    masks = model.predict(
        np.stack((vv_tiles, vh_tiles), axis=3), batch_size=1, verbose=verbose
    )

    # Round in place to whole numbers.
    masks.round(decimals=0, out=masks)

    # Stitch masks together (assumes one dems x dems prediction per tile -- TODO confirm)
    mask = masks.reshape((vv_rows, vv_cols, dems, dems)) \
        .swapaxes(1, 2) \
        .reshape(vv_rows * dems, vv_cols * dems)  # yapf: disable

    mask[vv_pixels] = 0
    write_mask_to_file(mask, outfile, vv_projection, vv_transform)


### extract_from_product

In [1]:
def extract_from_product(product_path, output_dir):
    """Extract the VV and VH tifs from a product zip into ``output_dir``.

    A member is extracted when its full archive name matches
    ``S1A|S1B _ <2 chars> _ ... _ VV|VH .tif``. Each match is written directly
    to ``output_dir/<member basename>``; the directory is created if needed
    and existing files are overwritten, so the call can be repeated.

    Args:
        product_path: path of the product ``.zip``.
        output_dir: directory that receives the tifs.

    Returns:
        The list of written file paths (``pathlib.Path``), in archive order.
    """
    # `[AB]` rather than `[A|B]` (which also accepts a literal '|'), and an
    # escaped dot so that only a real ".tif" suffix matches.
    sar_regex = re.compile(r"(S1[AB])_(.{2})_(.*)_(VV|VH)\.tif")

    destination = Path(output_dir)
    destination.mkdir(parents=True, exist_ok=True)

    extracted = []
    with zipfile.ZipFile(product_path, "r") as zip_ref:
        for file_info in zip_ref.infolist():
            if not sar_regex.fullmatch(file_info.filename):
                continue
            # Keep only the basename: any folder inside the archive is dropped.
            target = destination / Path(file_info.filename).name
            with zip_ref.open(file_info) as source, open(target, "wb") as sink:
                shutil.copyfileobj(source, sink)
            extracted.append(target)
    return extracted

In [None]:
def create_water_masks(model, products: list, name: str):
    """Extract the VV/VH tifs of every product into a scratch directory.

    Work in progress: ``model`` and ``name`` are not used yet, and nothing is
    done with the extracted files before the scratch directory is removed.

    Args:
        model: currently unused.
        products: items handed to ``extract_from_product`` as its
            ``product_path`` (presumably zip paths -- TODO confirm).
        name: currently unused.
    """
    for product in products:
        with TemporaryDirectory() as tmpdir_name:
            # extract_from_product requires the destination directory as well.
            extract_from_product(product, tmpdir_name)
        

In [25]:
# Time window of interest: the whole of 2019-12-03.
start = datetime(year=2019, month=12, day=3, hour=0, minute=0, second=0)
end = datetime(year=2019, month=12, day=4, hour=0, minute=0, second=0)

api = hyp3_login()

# Hard-coded subscription id. Named `sub_id` (not `id`) so the builtin id()
# is not shadowed for the rest of the kernel session.
# Alternative lookup: sub_id = grab_subscription(api)['id']
sub_id = 2810

products = get_sub_products(api, sub_id, start, end)
pprint(products)







 login successful!
 Welcome jmherning
[Product(name='S1A_IW_20191203T224518_DVP_RTC10_G_gpuned_54B9.zip', granule='S1A_IW_GRDH_1SDV_20191203T224518_20191203T224543_030190_037348_297D', url='https://hyp3-download.asf.alaska.edu/asf/data/S1A_IW_20191203T224518_DVP_RTC10_G_gpuned_54B9.zip', shape=None, start=datetime.datetime(2019, 12, 3, 22, 45, 18), end=datetime.datetime(2019, 12, 3, 22, 45, 43)),
 Product(name='S1A_IW_20191203T224543_DVP_RTC10_G_gpuned_3A53.zip', granule='S1A_IW_GRDH_1SDV_20191203T224543_20191203T224608_030190_037348_8C67', url='https://hyp3-download.asf.alaska.edu/asf/data/S1A_IW_20191203T224543_DVP_RTC10_G_gpuned_3A53.zip', shape=None, start=datetime.datetime(2019, 12, 3, 22, 45, 43), end=datetime.datetime(2019, 12, 3, 22, 46, 8))]
