In [None]:
from utils import *
import boto3
import pickle

# Rasterio AWS session wrapping a default-constructed boto3 session; it is
# handed to `download_and_process_pairs` further down the notebook.
# The requester-pays variant is kept commented out for reference:
# aws_session = rasterio.session.AWSSession(boto3.Session(), requester_pays=True)
aws_session = rasterio.session.AWSSession(boto3.Session())
# Optional sub-directory inserted into every data path below. A non-empty
# value is normalised so that it always ends with "/".
dir_suffix = ""
if dir_suffix and not dir_suffix.endswith("/"):
    dir_suffix += "/"

# skip aoi numbers: 8, 9, 10, 18
# did 14
# Index into `bbox_list` selecting the area of interest to query.
aoi_index = 0

# --- Run switches ------------------------------------------------------------
keep_original_band_scenes = False
query_and_process = True  # False: reuse cached features / existing pairs.csv
force_reprocess = False
enhance_image = False  # forwarded as `stretch_contrast` when processing pairs
use_pystac = True  # selects the Earth Search endpoint and pystac collections

# --- Path/row selection ------------------------------------------------------
force_path_row = True
forced_pr = "48DXL"
forced_ids = None  # explicit scene ids; takes precedence over pairs.csv lookup

# One-scene-per-month filtering is switched off whenever a path/row is forced.
one_per_month = True
one_per_month = one_per_month and not force_path_row

# Alternate pairing is only allowed together with combined path/rows.
# (The misspelled name is referenced by later cells, so it is kept.)
combined_path_rows = False
aletrnate_pairs = False
aletrnate_pairs = aletrnate_pairs and combined_path_rows

# --- Data source -------------------------------------------------------------
platform = "SENTINEL-2"
collections = ["SENTINEL-2"]
pystac_collections = ["sentinel-2-l2a"]
bands = ["red", "green", "blue"]

# --- Pair selection ----------------------------------------------------------
reference_month = "01"
reference_month_1 = "01"
reference_month_2 = "01"
reference_band_number = 3

# --- Output naming -----------------------------------------------------------
subdir = "true_color"
outputs_folder = "outputs_RGB"
filename_suffix = "TC"

In [None]:
# Bounding box of the WA polygon after `resize_bbox(..., 0.1)`.
# NOTE(review): the meaning of the 0.1 argument (fraction vs. degrees) is
# defined in `utils` — confirm there.
wa_bbox = resize_bbox(
    read_kml_polygon("data/inputs_old/LANDSAT_8_127111/WA.kml")[1], 0.1
)
aoi_polys = kml_to_poly("data/inputs_old/aois.kml").geoms

# Candidate areas of interest; `aoi_index` selects one of them.
bbox_list = [
    [67.45, -72.55, 67.55, -72.45],  # Amery bed rock
    [69.2, -68.1, 69.4, -67.9],  # Amery top
    [wa_bbox.left, wa_bbox.bottom, wa_bbox.right, wa_bbox.top],  # WA sand dunes
]
bbox_list.extend(list(poly.bounds) for poly in aoi_polys)  # AOI polygons

# Show the selected bounding box rounded to two decimals.
print(np.round(bbox_list[aoi_index], 2).tolist())

# A list of platforms is collapsed into a single name used in directory paths.
if type(platform) is list:
    platform = "_".join(
        [platform[0].split("_")[0], "_".join(platform).replace("SENTINEL_", "")]
    )
    print("Using platform:", platform)

# With a forced path/row the search is restricted to explicit scene ids:
# either `forced_ids`, or the ids recovered from the first row of an existing
# pairs.csv. `extra_query` stays None when neither is available.
if force_path_row:
    extra_query = None
    if forced_ids is not None:
        extra_query = {"ids": forced_ids}
    else:
        scene_df_file = f"data/inputs/{dir_suffix}{platform}_{forced_pr}/pairs.csv"
        if os.path.exists(scene_df_file):
            scene_df = pd.read_csv(scene_df_file)
            scene_ids = [
                os.path.basename(scene_path)
                .replace("_TC.tif", "")
                .replace("_PROC.tif", "")
                for scene_path in scene_df.iloc[0, :].tolist()
            ]
            extra_query = {"ids": scene_ids}
    if extra_query is not None:
        print("Using extra query:", extra_query)

In [None]:
if query_and_process:
    # Build the STAC search for the selected AOI between 2016-01-01 and
    # 2021-01-01. `extra_query` is only defined when `force_path_row` is set,
    # hence the conditional.
    query = get_search_query(
        bbox_list[aoi_index],
        collections=collections,
        start_date="2016-01-01T00:00:00",
        end_date="2021-01-01T00:00:00",
        is_landsat=False,
        extra_query=(extra_query if force_path_row else None),
    )

    # With a forced path/row the AOI bounding box is removed from the query.
    # NOTE(review): when `extra_query` is None the search then appears to have
    # neither a spatial nor an id constraint — confirm this is intended.
    if force_path_row:
        del query["bbox"]

    # Choose the STAC endpoint; the pystac route swaps in its own collection
    # names and drops the "page" key from the query.
    if use_pystac:
        server_url = "https://earth-search.aws.element84.com/v1"
        query["collections"] = pystac_collections
        del query["page"]
    else:
        server_url = "https://catalogue.dataspace.copernicus.eu/stac/search?"

    features = query_stac_server(query, server_url, pystac=use_pystac)
    print(len(features), "features found")

    # Cache the features unless the AOI returned fewer than 12 of them
    # (the threshold is ignored for a forced path/row).
    if force_path_row or len(features) >= 12:
        os.makedirs(f"data/inputs/features/{platform}", exist_ok=True)
        with open(f"data/inputs/features/{platform}/{aoi_index}.pkl", "wb") as pkl_file:
            pickle.dump(features, pkl_file)
    else:
        print(f"Not enough features found: {len(features)}, skipping AOI {aoi_index}")

In [None]:
if query_and_process:
    # Group the queried features by path/row and date.
    scene_dict, scene_list = find_scenes_dict(
        features,
        one_per_month=one_per_month,
        # start_end_years=[2009, 2010],
        acceptance_list=bands + ["thumbnail"],
        remove_duplicate_times=True,
        duplicate_idx=1,
    )
    path_rows = list(scene_dict)
    dates = [list(scene_dict[tile]) for tile in path_rows]
    date_len = list(map(len, dates))

    # Path/rows ordered from most to fewest dates.
    print([path_rows[idx] for idx in np.argsort(date_len)[::-1]])

    # Primary path/row: the one with the most dates. Its partner (`up`) is the
    # path/row whose position in `path_rows` is farthest from the primary one.
    path_row = path_rows[np.argmax(date_len)]
    p_indexes = list(range(len(path_rows)))
    diffs = [abs(pi - path_rows.index(path_row)) for pi in p_indexes]
    up = path_rows[np.argmax(diffs)]
    pr_list = [path_row, up]

    print(path_row)
    print(pr_list)
else:
    # Reuse the features cached by an earlier run for this AOI.
    # NOTE(review): pickle.load executes arbitrary code from the file; only
    # load caches this notebook wrote itself.
    with open(f"data/inputs/features/{platform}/{aoi_index}.pkl", "rb") as pkl_file:
        features = pickle.load(pkl_file)

In [None]:
# pr_list = ['106106', '102107']
# path_row = "48DXL"
# reference_month = "01"
# aletrnate_pairs = True

In [None]:
# Resolve `pr` (the path/row label used in every output path) and, when
# querying, `data_dict` (the scenes belonging to that path/row).
pr = None
if force_path_row:
    assert forced_pr is not None, "Forced path row must be provided"
    if query_and_process:
        if "_" in forced_pr:
            # Several path/rows joined with "_": merge their scene dicts.
            sn = combine_scene_dicts(
                [scene_dict[part] for part in forced_pr.split("_")]
            )
            data_dict = sn.copy()
        else:
            data_dict = scene_dict[forced_pr].copy()
    pr = forced_pr
elif combined_path_rows:
    if query_and_process:
        sn = combine_scene_dicts([scene_dict[part] for part in pr_list])
        data_dict = sn.copy()
    pr = "_".join(pr_list)
else:
    if query_and_process:
        data_dict = scene_dict[path_row].copy()
    pr = path_row
print(pr)

In [None]:
if query_and_process:
    # Select the "closest" and "farthest" scene pairs. In alternate mode the
    # two pairs are drawn across the two path/rows in `pr_list`; otherwise
    # both come from `data_dict`.
    if aletrnate_pairs:
        closest_pair = get_pair_dict_alternate(
            scene_dict[pr_list[0]],
            scene_dict[pr_list[1]],
            "closest",
            reference_month_1=reference_month_1,
            reference_month_2=reference_month_2,
        )
        farthest_pair = get_pair_dict_alternate(
            scene_dict[pr_list[0]],
            scene_dict[pr_list[1]],
            "farthest",
            reference_month_1=reference_month_1,
            reference_month_2=reference_month_2,
        )
    else:
        closest_pair = get_pair_dict(
            data_dict, "closest", reference_month=reference_month
        )
        farthest_pair = get_pair_dict(
            data_dict, "farthest", reference_month=reference_month
        )

    print("Closest pair:")
    print(closest_pair[0])
    print(closest_pair[1])
    print("Farthest pair:")
    print(farthest_pair[0])
    print(farthest_pair[1])

    # Thumbnails to preview: both scenes of the closest pair plus the second
    # scene of the farthest pair.
    # NOTE(review): presumably element 0 is the reference shared by both pairs.
    s3_list = [
        closest_pair[0]["thumbnail_alternate"],
        closest_pair[1]["thumbnail_alternate"],
        farthest_pair[1]["thumbnail_alternate"],
    ]

    # Name each thumbnail after the second-to-last component of its URL.
    # The directory is keyed by `pr` — the label every other output path in
    # this notebook uses — rather than the auto-selected `path_row`, which can
    # differ from it when a path/row is forced or path/rows are combined.
    thumbnail_dir = f"data/inputs/thumbnails/{dir_suffix}{platform}_{pr}"
    outputs = [f"{thumbnail_dir}/{url.split('/')[-2]}.jpg" for url in s3_list]

    bucket = "sentinel-cogs"
    download_files(bucket, s3_list, outputs, -1, is_async_download=False)

In [None]:
if query_and_process:
    # Show the three downloaded thumbnails side by side, titled by scene name.
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for ax, thumbnail_path in zip(axes, outputs):
        ax.imshow(plt.imread(thumbnail_path))
        # The thumbnails are saved as "<scene>.jpg", so stripping only the
        # "_thumb_small.jpeg" suffix left the extension in the title. Drop the
        # extension first, then the legacy suffix if it is present.
        scene_name = os.path.splitext(os.path.basename(thumbnail_path))[0]
        ax.set_title(scene_name.replace("_thumb_small", ""))
    plt.tight_layout()

In [None]:
# Download the selected scenes and write the processed pairs under `output_dir`.
output_dir = f"data/inputs/{dir_suffix}{platform}_{pr}"
if query_and_process:
    # Alternate mode passes one scene dict and one reference month per
    # path/row; otherwise a single dict and a single month are used.
    if aletrnate_pairs:
        scenes_to_process = [scene_dict[pr_list[0]], scene_dict[pr_list[1]]]
        months_to_process = [reference_month_1, reference_month_2]
    else:
        scenes_to_process = data_dict
        months_to_process = reference_month

    download_and_process_pairs(
        scenes_to_process,
        bands,
        output_dir,
        aws_session,
        keep_original_band_scenes,
        reference_month=months_to_process,
        gray_scale=True,
        averaging=True,
        subdir=subdir,
        stretch_contrast=enhance_image,
        force_reprocess=force_reprocess,
        reference_band_number=reference_band_number,
        filename_suffix=filename_suffix,
    )

In [None]:
# Read the first row of pairs.csv: one reference image and its two targets.
scene_df = pd.read_csv(f"data/inputs/{dir_suffix}{platform}_{pr}/pairs.csv")
ref_image = scene_df.loc[0, "Reference"]
tgt_images = [
    scene_df.loc[0, column] for column in ("Closest_target", "Farthest_target")
]
print("Reference image:", ref_image)
print("Closest target image:", tgt_images[0])
print("Farthest target image:", tgt_images[1])

# The third "_"-separated token of each file name is parsed as a YYYYMMDD date.
date_tokens = [
    os.path.basename(image_path).split("_")[2]
    for image_path in [ref_image, *tgt_images]
]
ref_time, *tgt_times = [datetime.strptime(token, "%Y%m%d") for token in date_tokens]
print("Time differences:", [(tgt_time - ref_time).days for tgt_time in tgt_times])

In [None]:
# Plot the reference and both targets from their "<subdir>_ds" counterparts
# (the path is derived by replacing `subdir` with `f"{subdir}_ds"`).
fig, axes = plt.subplots(1, 3, figsize=(20, 5))
for ax, image_path in zip(axes, [ref_image, tgt_images[0], tgt_images[1]]):
    ds_path = str(image_path).replace(subdir, f"{subdir}_ds")
    # Open inside a context manager so the dataset handle is closed once the
    # pixels have been drawn (the original left all three files open).
    with rasterio.open(ds_path) as src:
        show(src, ax=ax, cmap="gray", title=os.path.basename(image_path))
plt.tight_layout()

In [None]:
# Run Canny edge detection on the already-processed reference/target images
# and load the resulting edge-image paths from the new pairs.csv.
output_dir = f"data/inputs/{dir_suffix}{platform}_{pr}_edge_detection"
process_existing_outputs(
    [ref_image, *tgt_images],
    output_dir,
    edge_detection=True,
    edge_detection_mode="canny",
    subdir=subdir,
    force_reprocess=force_reprocess,
)

# Note: `scene_df` now refers to the edge-detection pairs table.
scene_df = pd.read_csv(f"{output_dir}/pairs.csv")
ref_image_edge = scene_df.loc[0, "Reference"]
tgt_images_edge = [
    scene_df.loc[0, column] for column in ("Closest_target", "Farthest_target")
]

In [None]:
# Plot the edge-detected reference and targets from their "<subdir>_ds"
# counterparts (the path is derived by replacing `subdir` with `f"{subdir}_ds"`).
fig, axes = plt.subplots(1, 3, figsize=(20, 5))
edge_paths = [ref_image_edge, tgt_images_edge[0], tgt_images_edge[1]]
for ax, image_path in zip(axes, edge_paths):
    ds_path = str(image_path).replace(subdir, f"{subdir}_ds")
    # Open inside a context manager so the dataset handle is closed once the
    # pixels have been drawn (the original left all three files open).
    with rasterio.open(ds_path) as src:
        show(src, ax=ax, cmap="gray", title=os.path.basename(image_path))
plt.tight_layout()

#### Co_Register

In [None]:
# Overlap handling is only needed when the three images do not share the same
# geotransform. Each dataset is opened in a context manager so its handle is
# released as soon as the transform has been read (the original never closed
# them).
transforms = []
for image_path in [ref_image, tgt_images[0], tgt_images[1]]:
    with rasterio.open(image_path) as src:
        transforms.append(src.transform)
use_overlap = not (transforms[0] == transforms[1] == transforms[2])
print("Using overlap:", use_overlap)
print()
output_path = f"data/{outputs_folder}/{dir_suffix}{platform}_{pr}/Co_Register"

# First attempt with default parameters.
_, shifts, target_ids = co_register(
    ref_image,
    tgt_images,
    output_path=output_path,
    return_shifted_images=True,
    use_overlap=use_overlap,
    # phase_corr_filter=False,
    # phase_corr_valid_num_points=1,
    # of_dist_thresh=10,
    # band_number=2,
)

# Targets missing from `target_ids` failed to co-register. `default_params`
# records, per target, whether the default run succeeded; it is consumed by
# `combine_comparison_results` in the final cell.
failed_tagets = [i for i in range(len(tgt_images)) if i not in target_ids]
default_params = [True] * len(tgt_images)
for i in failed_tagets:
    default_params[i] = False

if len(failed_tagets) > 0:
    print("Failed to co-register targets:", failed_tagets)
    print("Re-running co-registration with Laplacian filter for failed targets")
    print(end="\r")

    # Discard the first attempt's outputs and rerun into a "_lpc" directory,
    # applying the Laplacian filter only to the targets that failed.
    shutil.rmtree(output_path, ignore_errors=True)
    laplacian_filter = True
    laplacian_for_targets_ids = failed_tagets

    output_path += "_lpc"

    _, shifts, target_ids = co_register(
        ref_image,
        tgt_images,
        output_path=output_path,
        return_shifted_images=True,
        use_overlap=use_overlap,
        # phase_corr_filter=False,
        # phase_corr_valid_num_points=1,
        # of_dist_thresh=10,
        # band_number=2,
        laplacian_kernel_size=5,
        laplacian_for_targets_ids=laplacian_for_targets_ids,
    )

print("\nCo-register shifts:")
for i, shift in enumerate(shifts):
    print(
        f"Target {target_ids[i]}: {tuple([np.round(el.tolist(), 3).tolist() for el in shift])} pixels"
    )

#### Karios

In [None]:
# Run the Karios comparison on the same reference/target images and print the
# per-target shifts rounded to three decimals.
output_dir = f"data/{outputs_folder}/{dir_suffix}{platform}_{pr}/Karios"
karios_executable = "/home/ubuntu/Coreg/karios/karios/karios.py"
shift_dict, target_ids = karios(ref_image, tgt_images, output_dir, karios_executable)

print("\nKarios shifts:")
for i, shift in enumerate(shift_dict):
    rounded_shift = tuple(np.round(el, 3).tolist() for el in shift_dict[shift])
    print(f"Target {target_ids[i]}: {rounded_shift} pixels")

#### AROSICS

In [None]:
# Run the AROSICS comparison on the same reference/target images and print the
# per-target shifts rounded to three decimals.
output_dir = f"data/{outputs_folder}/{dir_suffix}{platform}_{pr}/AROSICS"
shifts, target_ids = arosics(ref_image, tgt_images, output_dir)

print("\nAROSICS shifts:")
for i, shift in enumerate(shifts):
    rounded_shift = tuple(np.round(el.tolist(), 3).tolist() for el in shift)
    print(f"Target {target_ids[i]}: {rounded_shift} pixels")

#### AROSICS EDGE

In [None]:
# Run AROSICS on the edge-detected images, passing the original images as
# `existing_ref_image` / `existing_tgt_images`, and print the per-target
# shifts rounded to three decimals.
output_dir = f"data/{outputs_folder}/{dir_suffix}{platform}_{pr}/AROSICS_edge"
shifts, target_ids = arosics(
    ref_image_edge,
    tgt_images_edge,
    output_dir,
    existing_ref_image=ref_image,
    existing_tgt_images=tgt_images,
)

print("\nAROSICS shifts:")
for i, shift in enumerate(shifts):
    rounded_shift = tuple(np.round(el.tolist(), 3).tolist() for el in shift)
    print(f"Target {target_ids[i]}: {rounded_shift} pixels")

In [None]:
# Root directory holding the Co_Register, Karios, AROSICS and AROSICS_edge
# outputs written by the cells above.
root_output = f"data/{outputs_folder}/{dir_suffix}{platform}_{pr}"

# `default_params` comes from the Co_Register cell: one flag per target, False
# where the first (default-parameter) co-registration attempt failed.
combine_comparison_results(root_output, default_params)