In [None]:
from utils import *
from utils import _get_band_suffixes
import boto3
import pickle
import random

# Requester-pays AWS session; it is handed to the download/processing step below.
aws_session = rasterio.session.AWSSession(boto3.Session(), requester_pays=True)

# --- Area of interest, sensors and bands -------------------------------------
aoi_index = 6  # 13 and 6
platform = ["LANDSAT_4", "LANDSAT_5"]
bands = ["red", "green", "blue"]  # "swir16"]

# --- Scene selection / processing switches -----------------------------------
keep_original_band_scenes = False
one_per_month = False
remove_duplicate_times = True
query_and_process = True
force_reprocess = False
enhance_image = True

# --- Pairing and co-registration switches ------------------------------------
combined_path_rows = False
alternate_pairs = False
# Alternating pairs is only honoured when path/rows are combined.
alternate_pairs = alternate_pairs and combined_path_rows

scan_big_shifts = True

reference_month = "01"
reference_month_1 = "01"
reference_month_2 = "01"

# --- File-system layout -------------------------------------------------------
inputs_dir = "inputs_coreg_spatial"
outputs_folder = "outputs_coreg_spatial/outputs_RGB_enhanced"
subdir = "rgb_enhanced"
filename_suffix = "PROC"

# --- Product filter -> STAC collection and directory prefix ------------------
# An "_SR" marker in the id filter selects the level-2 surface-reflectance
# collection; anything else selects level-1.
id_filter = "T2"
dir_suffix = ""

uses_surface_reflectance = "_SR" in id_filter
collections = ["landsat-c2l2-sr"] if uses_surface_reflectance else ["landsat-c2l1"]
level_prefix = "L2" if uses_surface_reflectance else "L1"

# e.g. id_filter "T2" yields "L1T2/": level prefix + filter without "_SR"/"_".
dir_suffix = level_prefix + dir_suffix + id_filter.replace("_SR", "").replace("_", "")
if dir_suffix and not dir_suffix.endswith("/"):
    dir_suffix += "/"

In [None]:
# Bounding boxes for the candidate areas of interest. Several are derived from
# KML files under data/inputs_old; the rest are literal coordinate lists.
wa_bbox = resize_bbox(BoundingBox(*kml_to_poly("data/inputs_old/WA.kml").bounds), 0.1)
aoi_polys = kml_to_poly("data/inputs_old/aois.kml").geoms
white_island_bbox = read_kml_polygon("data/inputs_old/White_island.kml")[1]
inland_bbox = read_kml_polygon("data/inputs_old/inland3.kml")[1]

# AOI 5 and 6 L1T2 have visible shifts in their series. AOI 6 co_register with dist_thresh 30
bbox_list = [
    [67.45, -72.55, 67.55, -72.45],  # Amery bed rock
    [69.2, -68.1, 69.4, -67.9],  # Amery top, There are no T1 products for this AOI
    wa_bbox,  # WA sand dunes, L4-5 L1T2 co-reg only works for first target with dist thresh = 50 and min dist thresh = 10 and directional filtering on
    *[list(p.bounds) for p in aoi_polys],  # AOI polygons
    list(
        resize_bbox(BoundingBox(*kml_to_poly("data/inputs_old/TAS.kml").bounds), 0.1)
    ),  # TAS
    [152.12, -28.37, 154.4, -26.48],  # QLD
    white_island_bbox,  # White Island
    inland_bbox,  # inland site
]

print("Using AOI index:", aoi_index)
print("Using AOI bbox:")
# Each element of the selected bbox is rounded to 2 decimals for display only.
print([np.round(bb, 2).tolist() for bb in bbox_list[aoi_index]])

# `query_platform` keeps the value used for the STAC search, while `platform`
# is folded into a single label (e.g. ["LANDSAT_4", "LANDSAT_5"] -> "LANDSAT_4_5")
# that later cells embed in directory names.
# NOTE(review): this cell overwrites `platform`. Re-running it without first
# re-running the configuration cell sets `query_platform` to the folded string.
query_platform = platform
if isinstance(platform, list):
    platform = (
        platform[0].split("_")[0] + "_" + "_".join(platform).replace("LANDSAT_", "")
    )
    print("Using platform:", platform)

In [None]:
# Build the STAC search for the selected AOI and fetch the matching features.
if query_and_process:
    server_url = "https://landsatlook.usgs.gov/stac-server/search"

    search_options = dict(
        # start_date="2008-12-31T00:00:00",
        start_date="1985-01-01T00:00:00",
        end_date="2012-12-31T00:00:00",
        platform=query_platform,
        collection_category=None,
        collections=collections,
        cloud_cover=25,
        extra_query=None,
    )
    query = get_search_query(bbox_list[aoi_index], **search_options)

    print("Search query:", query)

    features = query_stac_server(query, server_url, id_filter=id_filter)
    print(len(features), "features found")

    # NOTE(review): the message below says the AOI is skipped, but nothing in
    # this cell prevents the following cells from running.
    too_few_features = len(features) < 12
    if too_few_features:
        print(f"Not enough features found: {len(features)}, skipping AOI {aoi_index}")
    # else:
    #     os.makedirs(f"data/{inputs_dir}/features/{platform}", exist_ok=True)
    #     with open(
    #         f"data/{inputs_dir}/features/{platform}/" + str(aoi_index) + ".pkl", "wb"
    #     ) as f:
    #         pickle.dump(features, f)

In [None]:
# if not query_and_process:
#     with open(f"data/inputs/features/{platform}/" + str(aoi_index) + ".pkl", "rb") as f:
#         features = pickle.load(f)

# Group the queried features by path/row and pick two path/rows of interest.
# NOTE(review): `scene_dict`, `scene_list` and `path_rows` are only defined when
# `query_and_process` is True, yet later cells use them unconditionally.
if query_and_process:
    scene_dict, scene_list = find_scenes_dict(
        features,
        one_per_month=one_per_month,
        # start_end_years=[2009, 2010],
        acceptance_list=[*bands, "thumbnail"],
        remove_duplicate_times=remove_duplicate_times,
        duplicate_idx=1,
    )

    path_rows = list(scene_dict)
    if not path_rows:
        raise ValueError("No scenes found, cannot continue")
    print(path_rows)

    dates = [list(scene_dict[pr]) for pr in path_rows]
    date_len = [len(d) for d in dates]

    # `path_row`: the path/row with the most dates.
    path_row = path_rows[np.argmax(date_len)]
    # `up`: the path/row whose integer id is farthest from `path_row`.
    diffs = [abs(int(pr) - int(path_row)) for pr in path_rows]
    up = path_rows[np.argmax(diffs)]
    pr_list = [path_row, up]

    print(path_row)
    print(pr_list)

In [None]:
# Table of how many scenes mention each path/row in their scene name.
path_row_list = []
for pr_key in path_rows:
    n_scenes = sum(pr_key in scene["scene_name"] for scene in scene_list)
    path_row_list.append((pr_key, n_scenes))
pd.DataFrame(path_row_list, columns=["path_row", "count"])

In [None]:
# Download and pre-process the scene series of every path/row.
# The processing options are identical for all path/rows.
processing_options = dict(
    gray_scale=True,
    averaging=True,
    stretch_contrast=enhance_image,
    force_reprocess=force_reprocess,
    filename_suffix=filename_suffix,
    preserve_depth=True,
)

for pr in path_rows:
    scenes_by_date = scene_dict[pr]
    output_dir = f"data/{inputs_dir}/{dir_suffix}{platform}_{pr}"
    data_dict = scenes_by_date.copy()
    pr_date_list = flatten(list(scenes_by_date.values()))
    if not query_and_process:
        continue

    process_dir = f"{output_dir}/{subdir}"
    process_ds_dir = f"{output_dir}/{subdir}_ds"
    # Only the first three configured bands are used to resolve suffixes.
    bands_suffixes = _get_band_suffixes(data_dict, bands[:3])
    download_and_process_series(
        pr_date_list,
        bands,
        bands_suffixes,
        output_dir,
        process_dir,
        process_ds_dir,
        aws_session,
        keep_original_band_scenes,
        **processing_options,
    )

In [None]:
# For every path/row, take the first processed image (in sorted order) as the
# reference and all remaining images as co-registration targets.
ref_list = []
tgt_list = []
pr_list = []
for pr in path_rows:
    input_dir = f"data/{inputs_dir}/{dir_suffix}{platform}_{pr}"
    # Without recursive=True, "**" matches like "*" (direct children only).
    pr_date_list = sorted(glob.glob(input_dir + f"/{subdir}/**"))

    # A reference plus at least one target is required. Checking the length
    # before indexing also covers an empty directory, which would otherwise
    # raise IndexError on the reference lookup.
    if len(pr_date_list) < 2:
        print(f"Not enough target images for path_row {pr}, skipping")
        continue

    ref_image, *tgt_images = pr_date_list
    ref_list.append(ref_image)
    tgt_list.append(tgt_images)
    pr_list.append(pr)

In [None]:
# Order the references as a greedy chain: start from the first reference and
# repeatedly append the remaining reference that overlaps most with the one
# added last.
if not ref_list:
    raise ValueError("No reference images found, cannot build the reference chain")

ref_list_copy = ref_list.copy()
ref = ref_list_copy.pop(0)
overlaps_per_ref = [ref]
while ref_list_copy:
    overlap_areas = []
    for tgt in ref_list_copy:
        try:
            area = box(*find_overlap(ref, tgt)[0]).area / 1000000
        except Exception as exc:
            # Best effort: a pair whose overlap cannot be computed counts as
            # zero overlap, but the failure is reported instead of hidden.
            print(f"Could not compute overlap between {ref} and {tgt}: {exc}")
            area = 0.0
        overlap_areas.append(area)
    best_idx = int(np.argmax(overlap_areas))
    ref = ref_list_copy.pop(best_idx)
    overlaps_per_ref.append(ref)

# The directory name joins one token per chained reference, taken from the 4th
# "/"-separated path component and its 4th "_"-separated field.
# NOTE(review): presumably the path/row id; this relies on the current layout of
# `inputs_dir`, `dir_suffix` and `platform` — confirm if those change.
overlap_dir_name = f"data/{outputs_folder}/{dir_suffix}refs_{platform}/" + "_".join(
    [ref.split("/")[3].split("_")[3] for ref in overlaps_per_ref]
)
os.makedirs(overlap_dir_name, exist_ok=True)

# Permutation that maps the chain order back to positions in `ref_list`.
# Every chained reference is an element of `ref_list`, so the full path is
# looked up directly (basenames may repeat across path/rows).
ref_sortperm = [ref_list.index(chained_ref) for chained_ref in overlaps_per_ref]
tgt_list = [tgt_list[i] for i in ref_sortperm]
pr_list = [pr_list[i] for i in ref_sortperm]

In [None]:
def _print_shifts(method, shifts, target_ids):
    """Print the shifts returned by ``coreg`` for each target, rounded to 3 decimals.

    For "KARIOS" the items iterated from ``shifts`` are used to index ``shifts``
    itself; for every other method each item is the sequence of shift values.
    """
    print(f"\n{method} shifts:")
    for idx, shift in enumerate(shifts):
        if method == "KARIOS":
            values = [np.round(el, 3).tolist() for el in shifts[shift]]
        else:
            values = [np.round(el.tolist(), 3).tolist() for el in shift]
        print(f"Target {target_ids[idx]}: {tuple(values)} pixels")


# methods = ["Co_Register", "KARIOS", "AROSICS"]
methods = ["AROSICS"]
for method in methods:
    print(f"\n=== Using method: {method} ===\n")
    print("Processing reference-target chain.")
    print("Co-registering references. The first reference image is fixed.")

    # Output sub-directory shared by both stages for this method.
    method_dir = f"{method}{'_bigShifts' if scan_big_shifts else ''}"

    # --- Stage 1: align each reference to its predecessor in the chain -------
    ordered_refs = overlaps_per_ref.copy()
    for ref_id in range(len(ordered_refs) - 1):
        tgt_id = ref_id + 1
        ref_token = overlaps_per_ref[ref_id].split("/")[3].split("_")[3]
        tgt_token = overlaps_per_ref[tgt_id].split("/")[3].split("_")[3]
        output_dir = f"data/{outputs_folder}/{dir_suffix}refs_{platform}_{ref_token}_{tgt_token}/{method_dir}"
        # NOTE(review): the directory carries the "_bigShifts" suffix, but no
        # big-shift option is passed to this call (unlike stage 2) — confirm
        # whether that is intended.
        shifts, target_ids = coreg(
            ordered_refs[ref_id],
            [ordered_refs[tgt_id]],
            output_dir,
            method=method,
        )
        aligned_ref = (
            f"{output_dir}/Aligned/{os.path.basename(overlaps_per_ref[tgt_id])}"
        )
        if os.path.exists(aligned_ref):
            ordered_refs[tgt_id] = aligned_ref
        else:
            # Fall back to the un-aligned reference right away, so the next
            # link in the chain is not given a path that does not exist.
            print(f"{tgt_id}: MISSING FILE {aligned_ref}")
        _print_shifts(method, shifts, target_ids)

    make_difference_gif(
        ordered_refs, overlap_dir_name + f"/{method}.gif", mosaic_scenes=True
    )

    # --- Stage 2: align every path/row's targets to its chained reference ----
    print(
        "\nCo-registering targets for each pathrow to their corresponding references.\n"
    )
    # Each method takes its big-shift switch under a different keyword.
    kwargs = {"method": method}
    if method == "AROSICS":
        kwargs["max_shift"] = 200 if scan_big_shifts else 5
    elif method == "KARIOS":
        kwargs["scan_big_shifts"] = scan_big_shifts
    elif method == "Co_Register":
        kwargs["big_shifts_mode"] = scan_big_shifts

    for pr_idx, pr in enumerate(pr_list):
        output_dir = f"data/{outputs_folder}/{dir_suffix}{platform}_{pr}/{method_dir}"
        print(
            f"Co-registering targets for pathrow {pr} against reference {ordered_refs[pr_idx]}"
        )
        shifts, target_ids = coreg(
            ordered_refs[pr_idx],
            tgt_list[pr_idx],
            output_dir,
            **kwargs,
        )
        _print_shifts(method, shifts, target_ids)
    print(f"\n=== Finished using method: {method} ===\n")

In [None]:
# Difference GIF built from the chained references as listed in
# `overlaps_per_ref`, written next to the per-method GIFs.
raw_gif_path = f"{overlap_dir_name}/raw.gif"
make_difference_gif(overlaps_per_ref, raw_gif_path, mosaic_scenes=True)

In [None]:
# Combine the comparison results of all methods for every processed path/row.
# The second argument mirrors the "_bigShifts" suffix of the output directories.
results_variant = "bigShifts" if scan_big_shifts else None
for pr in pr_list:
    root_output = f"data/{outputs_folder}/{dir_suffix}{platform}_{pr}"
    combine_comparison_results(root_output, results_variant, coreg_methods=methods)