In [None]:
from utils import *
import boto3
import pickle
import random

# ---------------------------------------------------------------------------
# Run configuration: every switch the later cells read is set here.
# ---------------------------------------------------------------------------

# S3 session handed to the download/processing helpers; configured as
# requester-pays.
aws_session = rasterio.session.AWSSession(boto3.Session(), requester_pays=True)

# Optional path component placed before "<platform>_<path/row>" in data paths
# (a trailing "/" is appended at the bottom of this cell when it is non-empty).
dir_suffix = ""

# Index into `bbox_list` (built in a later cell); -1 picks its last entry.
aoi_index = -1

keep_original_band_scenes = False
one_per_month = True
remove_duplicate_times = True
# Master switch: the query/download/processing cells only run when True.
query_and_process = True

# Restrict processing to one path/row; several may be joined with "_".
force_path_row = True
forced_pr = "089080"

fill_nodata = False

# Monthly thinning is switched off whenever a path/row is forced.
if force_path_row:
    one_per_month = False

combined_path_rows = False
alternate_pairs = False

# Alternate pairing is forced off unless path/rows are combined.
if not combined_path_rows:
    alternate_pairs = False

# `platform` may be a single name or a list of names.
# platform = ["LANDSAT_4", "LANDSAT_5"]
platform = "LANDSAT_5"
# platform = "LANDSAT_7"
# platform = "LANDSAT_8"

reference_month = "01"
reference_month_1 = "01"
reference_month_2 = "01"

enhance_image = True

bands = ["red", "green", "blue"]  # "swir16"]

subdir = "rgb_enhanced"

# Explicit scene IDs to query. Set to None to fall back to the IDs stored in
# an existing pairs.csv (or to no ID restriction when that file is missing).
forced_ids = [
    "LT05_L2SP_089080_20040504_20200903_02_T1_SR",
    "LT05_L2SP_089080_20110711_20200822_02_T2_SR",
    "LT05_L2SP_089080_20111015_20200820_02_T1_SR",
]

id_filter = ""
# A list-valued `platform` never equals one of the string names, so the
# membership test has to look at each requested platform individually;
# otherwise `l_4_5_collections` stays undefined and the Landsat 4/5 query
# branch fails with a NameError.
_platform_names = platform if isinstance(platform, (list, tuple)) else [platform]
if any(name in ("LANDSAT_4", "LANDSAT_5") for name in _platform_names):
    if "_SR" in id_filter:
        l_4_5_collections = ["landsat-c2l2-sr"]
    else:
        l_4_5_collections = ["landsat-c2l1"]
    # The ID filter (underscores stripped) becomes part of the directory name.
    dir_suffix = id_filter.replace("_", "") + dir_suffix

outputs_folder = (
    f"outputs_coreg/outputs_RGB_enhanced{'_infilled' if fill_nodata else ''}"
)

force_reprocess = False

filename_suffix = "PROC"

inputs_dir = "inputs_coreg"

# Normalise the suffix so it can be spliced directly into path f-strings.
if dir_suffix and not dir_suffix.endswith("/"):
    dir_suffix += "/"

In [None]:
# pr_dir = [
#     dir
#     for dir in glob.glob(
#         "data/inputs_old_coreg_outputs/outputs_coreg/outputs_RGB_enhanced/*"
#     )
#     if forced_pr in dir
# ][0]
# landsat_dirnames = glob.glob(f"{pr_dir}/Karios/LE*")
# ref_name = landsat_dirnames[0].split("PROC")[1][1:-1]
# tgt_0_name = os.path.basename(landsat_dirnames[0].split("PROC")[0][:-1])
# tgt_1_name = os.path.basename(landsat_dirnames[1].split("PROC")[0][:-1])
# forced_ids = [ref_name, tgt_0_name, tgt_1_name]
# print("Forced IDs:", forced_ids)

In [None]:
# Candidate areas of interest; `aoi_index` selects one of them.
wa_bbox = resize_bbox(BoundingBox(*kml_to_poly("data/inputs_old/WA.kml").bounds), 0.1)
aoi_polys = kml_to_poly("data/inputs_old/aois.kml").geoms

bbox_list = [
    [67.45, -72.55, 67.55, -72.45],  # Amery bed rock
    [69.2, -68.1, 69.4, -67.9],  # Amery top
    [wa_bbox.left, wa_bbox.bottom, wa_bbox.right, wa_bbox.top],  # WA sand dunes
]
bbox_list.extend(list(poly.bounds) for poly in aoi_polys)  # AOI polygons
bbox_list.append(
    list(
        resize_bbox(BoundingBox(*kml_to_poly("data/inputs_old/TAS.kml").bounds), 0.1)
    )
)  # TAS
bbox_list.append([152.12, -28.37, 154.4, -26.48])  # QLD

print("Using AOI index:", aoi_index)
print("Using AOI bbox:")
print([np.round(coord, 2).tolist() for coord in bbox_list[aoi_index]])

# The search query keeps the original platform value; `platform` itself is
# collapsed to a single label (e.g. "LANDSAT_4_5") for use in directory names.
query_platform = platform
if isinstance(platform, list):
    family = platform[0].split("_")[0]
    numbers = "_".join(platform).replace("LANDSAT_", "")
    platform = family + "_" + numbers
    print("Using platform:", platform)

# With a forced path/row the query is narrowed by scene ID: either the
# explicit `forced_ids`, or the IDs recovered from the first row of an
# existing pairs.csv. Without either, no ID restriction is applied.
if force_path_row:
    extra_query = None
    if forced_ids is not None:
        extra_query = {"ids": forced_ids}
    else:
        scene_df_file = (
            f"data/{inputs_dir}/{dir_suffix}{platform}_{forced_pr}/pairs.csv"
        )
        if os.path.exists(scene_df_file):
            scene_df = pd.read_csv(scene_df_file)
            scene_ids = []
            for scene_path in scene_df.iloc[0, :].tolist():
                scene_name = os.path.basename(scene_path)
                # Strip the processed-file suffixes to recover the scene ID.
                scene_name = scene_name.replace("_TC.TIF", "").replace("_PROC.TIF", "")
                scene_ids.append(scene_name)
            extra_query = {"ids": scene_ids}
    if extra_query is not None:
        print("Using extra query:", extra_query)

In [None]:
if query_and_process:
    # No bbox is sent when explicit scene IDs are forced.
    search_bbox = bbox_list[aoi_index] if forced_ids is None else None

    # Arguments common to every platform.
    shared_kwargs = {
        "platform": query_platform,
        "collection_category": None,
        "extra_query": extra_query if force_path_row else None,
    }

    # Date window, collection and cloud-cover limit differ per platform.
    if platform == "LANDSAT_8":
        platform_kwargs = {
            "start_date": "2013-01-01T00:00:00",
            "end_date": "2017-01-01T00:00:00",
            "collections": ["landsat-c2l2-sr"],
            "cloud_cover": 5,
        }
    elif platform == "LANDSAT_7":
        # No cloud_cover is passed here (a cloud_cover=50 setting is disabled).
        platform_kwargs = {
            "start_date": "2003-01-01T00:00:00",
            "end_date": "2005-12-31T00:00:00",
            "collections": ["landsat-c2l2-sr"],
        }
    else:  # LANDSAT 4, 5
        platform_kwargs = {
            "start_date": "2003-12-31T00:00:00",
            # alternative start_date: "1985-01-01T00:00:00"
            "end_date": "2012-12-31T00:00:00",
            "collections": l_4_5_collections,
            "cloud_cover": 5,
        }

    query = get_search_query(search_bbox, **shared_kwargs, **platform_kwargs)

    # if force_path_row:
    #     del query["bbox"]

    print("Search query:", query)

    server_url = "https://landsatlook.usgs.gov/stac-server/search"
    features = query_stac_server(query, server_url, id_filter=id_filter)
    print(len(features), "features found")

    # The raw search result is cached on disk unless the AOI yielded too few
    # scenes; the minimum of 12 is waived when a path/row is forced.
    if force_path_row or len(features) >= 12:
        features_dir = f"data/{inputs_dir}/features/{platform}"
        os.makedirs(features_dir, exist_ok=True)
        with open(f"{features_dir}/{aoi_index}.pkl", "wb") as f:
            pickle.dump(features, f)
    else:
        print(f"Not enough features found: {len(features)}, skipping AOI {aoi_index}")

In [None]:
# if not query_and_process:
#     with open(f"data/inputs/features/{platform}/" + str(aoi_index) + ".pkl", "rb") as f:
#         features = pickle.load(f)

if query_and_process:
    # Group the query results; `scene_dict` is keyed by path/row, and each
    # entry is itself keyed by date.
    scene_dict, scene_list = find_scenes_dict(
        features,
        acceptance_list=[*bands, "thumbnail"],
        one_per_month=one_per_month,
        remove_duplicate_times=remove_duplicate_times,
        duplicate_idx=1,
        # start_end_years=[2009, 2010],
    )

    path_rows = list(scene_dict)
    if not path_rows:
        raise ValueError("No scenes found, cannot continue")
    print(path_rows)

    dates = [list(scene_dict[key]) for key in path_rows]
    date_len = list(map(len, dates))

    # Primary path/row: the one with the most dates.
    path_row = path_rows[np.argmax(date_len)]

    # Secondary path/row: the one whose numeric code is farthest from the
    # primary (this is the primary itself when only one path/row exists).
    primary_code = int(path_row)
    diffs = [abs(int(key) - primary_code) for key in path_rows]
    up = path_rows[np.argmax(diffs)]
    pr_list = [path_row, up]

    print(path_row)
    print(pr_list)

In [None]:
# full_scene_dict, full_scene_list = find_scenes_dict(
#     features,
#     one_per_month=False,
#     # start_end_years=[2009, 2010],
#     acceptance_list=bands + ["thumbnail"],
#     remove_duplicate_times=False,
#     duplicate_idx=1,
# )
# full_scene_list = [f for f in full_scene_list if "089080" in f["scene_name"]]
# print(len(full_scene_list), "scenes in full list")
# pr_date_list_processed = download_and_process_series(
#     full_scene_list,
#     bands,
#     ["_B3", "_B2", "_B1"],
#     "templ5/l5_input",
#     "templ5/l5_process",
#     "templ5/l5_ds",
#     aws_session,
#     keep_original_band_scenes,
#     stretch_contrast=True,
#     gray_scale=True,
#     preserve_depth=True,
#     min_max_scaling=False,
# )
# files = sorted(glob.glob("templ5/l5_ds/*"))
# make_difference_gif(
#     files,
#     "templ5/out_3.gif",
#     mosaic_scenes=True,
#     fps=10,
# )
# files = sorted(glob.glob("templ5/l5_ds/*"))
# methods = ["Co_Register", "Arosics", "Karios"]
# for method in methods:
#     output_dir = f"templ5/{method}_L2T1"
#     coreg(
#         files[0],
#         files[1:],
#         output_dir,
#         method,
#         fps=10,
#     )

In [None]:
if query_and_process:
    # Cap the number of thumbnails fetched for visual inspection.
    max_thumbnails = 25
    if len(features) > max_thumbnails:
        # A private, fixed-seed generator draws the same sample as seeding the
        # module-level RNG with 42, without clobbering the global random state.
        features = random.Random(42).sample(features, max_thumbnails)
        print(f"Randomly sampled {max_thumbnails} features for download")

    # Rebuild the scene list without monthly thinning or duplicate removal so
    # that every (sampled) feature contributes a thumbnail.
    _, full_scene_list = find_scenes_dict(
        features,
        one_per_month=False,
        # start_end_years=[2009, 2010],
        acceptance_list=bands + ["thumbnail"],
        remove_duplicate_times=False,
        duplicate_idx=1,
    )

    # Start from an empty scratch directory on every run.
    shutil.rmtree("temp_data", ignore_errors=True)
    os.makedirs("temp_data", exist_ok=True)

    s3_list = [scene["thumbnail_alternate"] for scene in full_scene_list]
    outputs = [f"temp_data/{os.path.basename(url)}" for url in s3_list]
    bucket = "usgs-landsat"
    download_files(bucket, s3_list, outputs, -1, is_async_download=False)

In [None]:
# Resolve `pr` (the path/row label used in every directory name below) and,
# when querying is enabled, `data_dict` (the scenes for that path/row).
pr = None
if force_path_row:
    assert forced_pr is not None, "Forced path row must be provided"
    if query_and_process:
        forced_parts = forced_pr.split("_")
        if len(forced_parts) > 1:
            # Several path/rows joined with "_": merge their scene dicts.
            sn = combine_scene_dicts([scene_dict[part] for part in forced_parts])
            data_dict = sn.copy()
        else:
            data_dict = scene_dict[forced_pr].copy()
    pr = forced_pr
elif combined_path_rows:
    if query_and_process:
        sn = combine_scene_dicts([scene_dict[key] for key in pr_list])
        data_dict = sn.copy()
    pr = "_".join(pr_list)
else:
    if query_and_process:
        data_dict = scene_dict[path_row].copy()
    pr = path_row
print(pr)

In [None]:
if query_and_process:
    # Pick the reference/target pairs that are closest and farthest apart.
    pair_modes = ("closest", "farthest")
    if alternate_pairs:
        closest_pair, farthest_pair = [
            get_pair_dict_alternate(
                scene_dict[pr_list[0]],
                scene_dict[pr_list[1]],
                mode,
                reference_month_1=reference_month_1,
                reference_month_2=reference_month_2,
            )
            for mode in pair_modes
        ]
    else:
        closest_pair, farthest_pair = [
            get_pair_dict(data_dict, mode, reference_month=reference_month)
            for mode in pair_modes
        ]

    for heading, pair in (
        ("Closest pair:", closest_pair),
        ("Farthest pair:", farthest_pair),
    ):
        print(heading)
        print(pair[0])
        print(pair[1])

    # Thumbnails to fetch: first element of the closest pair, then the second
    # element of each pair.
    # NOTE(review): farthest_pair[0] is not downloaded — presumably it is the
    # same scene as closest_pair[0]; confirm.
    s3_list = [
        closest_pair[0]["thumbnail_alternate"],
        closest_pair[1]["thumbnail_alternate"],
        farthest_pair[1]["thumbnail_alternate"],
    ]

    # Use `pr` (not `path_row`) so the thumbnail directory carries the same
    # label as every other directory; the two differ when a path/row is forced
    # or path/rows are combined.
    thumbnail_dir = f"data/{inputs_dir}/thumbnails/{dir_suffix}{platform}_{pr}"
    outputs = [f"{thumbnail_dir}/{os.path.basename(url)}" for url in s3_list]

    bucket = "usgs-landsat"
    download_files(bucket, s3_list, outputs, -1, is_async_download=False)

In [None]:
if query_and_process:
    # Show the three downloaded thumbnails side by side, titled by file name
    # with the thumbnail suffix removed.
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for panel_idx, ax in enumerate(axes):
        thumb_path = outputs[panel_idx]
        ax.imshow(plt.imread(thumb_path))
        ax.set_title(os.path.basename(thumb_path).replace("_thumb_small.jpeg", ""))
    plt.tight_layout()

In [None]:
output_dir = f"data/{inputs_dir}/{dir_suffix}{platform}_{pr}"
if query_and_process:
    # Alternate pairing passes both path/row scene dicts and both reference
    # months; otherwise a single scene dict and a single month are passed.
    if alternate_pairs:
        scenes_arg = [scene_dict[pr_list[0]], scene_dict[pr_list[1]]]
        reference_month_arg = [reference_month_1, reference_month_2]
    else:
        scenes_arg = data_dict
        reference_month_arg = reference_month

    # NOTE(review): a list-valued platform has been collapsed to a label such
    # as "LANDSAT_4_5" by this point, which does not match either name below —
    # confirm whether preserve_depth should also be enabled in that case.
    keep_depth = platform in ["LANDSAT_4", "LANDSAT_5"]

    download_and_process_pairs(
        scenes_arg,
        bands,
        output_dir,
        aws_session,
        keep_original_band_scenes,
        reference_month=reference_month_arg,
        gray_scale=True,
        averaging=True,
        subdir=subdir,
        stretch_contrast=enhance_image,
        force_reprocess=force_reprocess,
        filename_suffix=filename_suffix,
        preserve_depth=keep_depth,
    )

In [None]:
# Read the reference/target pair from the first row of pairs.csv.
scene_df = pd.read_csv(f"data/{inputs_dir}/{dir_suffix}{platform}_{pr}/pairs.csv")
ref_image = scene_df.loc[0, "Reference"]
tgt_images = [
    scene_df.loc[0, column] for column in ("Closest_target", "Farthest_target")
]
print("Reference image:", ref_image)
print("Closest target image:", tgt_images[0])
print("Farthest target image:", tgt_images[1])

# The fourth "_"-separated field of each file name is parsed as a YYYYMMDD date.
ref_time, *tgt_times = [
    datetime.strptime(os.path.basename(image_path).split("_")[3], "%Y%m%d")
    for image_path in [ref_image, *tgt_images]
]
day_offsets = [(tgt_time - ref_time).days for tgt_time in tgt_times]
print("Time differences:", day_offsets)

In [None]:
# Nodata infilling only happens for Landsat 7 with `fill_nodata` enabled;
# in every other case the "infill" names simply alias the processed images.
if not (platform == "LANDSAT_7" and fill_nodata):
    ref_image_infill = ref_image
    tgt_images_infill = tgt_images
else:
    output_dir = f"data/{inputs_dir}/{dir_suffix}{platform}_{pr}_infilled"
    process_existing_outputs(
        [ref_image, *tgt_images],
        output_dir,
        subdir=subdir,
        force_reprocess=force_reprocess,
        min_max_scaling=False,
        fill_nodata=True,
    )

    # Pick up the infilled file paths from the pairs.csv in that directory.
    scene_df = pd.read_csv(f"{output_dir}/pairs.csv")
    ref_image_infill = scene_df.loc[0, "Reference"]
    tgt_images_infill = [
        scene_df.loc[0, column] for column in ("Closest_target", "Farthest_target")
    ]

In [None]:
# Plot the reference and both targets side by side. The images are read from
# the sibling "<subdir>_ds" location (presumably downsampled copies — confirm).
fig, axes = plt.subplots(1, 3, figsize=(20, 5))
for ax, image_path in zip(axes, [ref_image, *tgt_images]):
    ds_path = str(image_path).replace(subdir, f"{subdir}_ds")
    # Open inside a context manager so the dataset handle is closed once it
    # has been drawn, instead of staying open for the life of the kernel.
    with rasterio.open(ds_path) as dataset:
        show(dataset, ax=ax, cmap="gray", title=os.path.basename(image_path))
plt.tight_layout()

In [None]:
# Edge-detected variants go to their own directory; the Landsat 7 infilled
# run gets an extra "_infilled" marker in the directory name.
if platform == "LANDSAT_7" and fill_nodata:
    edge_dir_tag = "_infilled_edge_detection"
else:
    edge_dir_tag = "_edge_detection"
output_dir = f"data/{inputs_dir}/{dir_suffix}{platform}_{pr}{edge_dir_tag}"

process_existing_outputs(
    [ref_image_infill, *tgt_images_infill],
    output_dir,
    edge_detection=True,
    edge_detection_mode="canny",
    subdir=subdir,
    force_reprocess=force_reprocess,
    fill_nodata=fill_nodata,
)

# Pick up the edge-image paths from the pairs.csv in that directory.
scene_df = pd.read_csv(f"{output_dir}/pairs.csv")
ref_image_edge = scene_df.loc[0, "Reference"]
tgt_images_edge = [
    scene_df.loc[0, column] for column in ("Closest_target", "Farthest_target")
]

In [None]:
# Plot the edge-detected reference and targets side by side. The images are
# read from the sibling "<subdir>_ds" location (presumably downsampled copies
# — confirm).
fig, axes = plt.subplots(1, 3, figsize=(20, 5))
for ax, image_path in zip(axes, [ref_image_edge, *tgt_images_edge]):
    ds_path = str(image_path).replace(subdir, f"{subdir}_ds")
    # Open inside a context manager so the dataset handle is closed once it
    # has been drawn, instead of staying open for the life of the kernel.
    with rasterio.open(ds_path) as dataset:
        show(dataset, ax=ax, cmap="gray", title=os.path.basename(image_path))
plt.tight_layout()

#### Co_Register

In [None]:
# Run the "Co_Register" method on the (possibly infilled) reference/targets.
output_dir = f"data/{outputs_folder}/{dir_suffix}{platform}_{pr}/Co_Register"

# Optional tuning keywords left disabled: phase_corr_filter=False,
# phase_corr_valid_num_points=1, of_dist_thresh=5, band_number=2,
# no_ransac=True.
shifts, target_ids = coreg(
    ref_image_infill,
    tgt_images_infill,
    output_dir,
    method="Co_Register",
)

print("\nCo-register shifts:")
for idx, shift in enumerate(shifts):
    # Each component of the shift is rounded to 3 decimals for display.
    rounded = tuple(np.round(el.tolist(), 3).tolist() for el in shift)
    print(f"Target {target_ids[idx]}: {rounded} pixels")

#### Karios

In [None]:
# Run the "Karios" method on the (possibly infilled) reference/targets.
output_dir = f"data/{outputs_folder}/{dir_suffix}{platform}_{pr}/Karios"

# Optional keyword left disabled: scan_big_shifts=True.
shifts, target_ids = coreg(
    ref_image_infill,
    tgt_images_infill,
    output_dir,
    method="Karios",
)

print("\nKarios shifts:")
# NOTE(review): `shifts` is indexed by the items it yields when iterated,
# which implies a mapping here rather than a list — confirm.
for idx, shift_key in enumerate(shifts):
    rounded = tuple(np.round(el, 3).tolist() for el in shifts[shift_key])
    print(f"Target {target_ids[idx]}: {rounded} pixels")

#### AROSICS

In [None]:
# Run the "AROSICS" method on the (possibly infilled) reference/targets.
output_dir = f"data/{outputs_folder}/{dir_suffix}{platform}_{pr}/AROSICS"

# The existing images are only handed over for Landsat 7 runs without
# nodata infilling; otherwise both arguments are None.
pass_existing = platform == "LANDSAT_7" and not fill_nodata
shifts, target_ids = coreg(
    ref_image_infill,
    tgt_images_infill,
    output_dir,
    method="AROSICS",
    existing_ref_image=ref_image_infill if pass_existing else None,
    existing_tgt_images=tgt_images_infill if pass_existing else None,
)

print("\nAROSICS shifts:")
for idx, shift in enumerate(shifts):
    # Each component of the shift is rounded to 3 decimals for display.
    rounded = tuple(np.round(el.tolist(), 3).tolist() for el in shift)
    print(f"Target {target_ids[idx]}: {rounded} pixels")

#### AROSICS EDGE

In [None]:
# Run "AROSICS" on the edge-detected images, while passing the non-edge
# (possibly infilled) images as the existing reference/targets.
output_dir = f"data/{outputs_folder}/{dir_suffix}{platform}_{pr}/AROSICS_edge"
shifts, target_ids = coreg(
    ref_image_edge,
    tgt_images_edge,
    output_dir,
    method="AROSICS",
    existing_ref_image=ref_image_infill,
    existing_tgt_images=tgt_images_infill,
)

# Labelled distinctly from the plain AROSICS run so the two printed result
# blocks can be told apart.
print("\nAROSICS Edge shifts:")
for idx, shift in enumerate(shifts):
    # Each component of the shift is rounded to 3 decimals for display.
    rounded = tuple(np.round(el.tolist(), 3).tolist() for el in shift)
    print(f"Target {target_ids[idx]}: {rounded} pixels")

In [None]:
# Directory that holds one sub-directory per co-registration method.
root_output = (
    f"data/{outputs_folder}/"
    f"{dir_suffix}{platform}_{pr}"
)
combine_comparison_results(root_output)

### Only Landsat 7

In [None]:
def _parse_shift_pair(text):
    """Parse a shift string such as "(1.5, -2.0)" or "50, 50" into two floats.

    The text is split on commas; "(" is stripped from the first field and ")"
    from the second before conversion. Any further fields are ignored.

    Returns:
        tuple[float, float]: the first and second components, in that order.
    """
    fields = text.split(",")
    return (
        float(fields[0].replace("(", "")),
        float(fields[1].replace(")", "")),
    )


def _column_shifts(results, column):
    """Return the parsed shift pairs of one results column as a list."""
    return results[column].apply(_parse_shift_pair).to_list()


if fill_nodata and (platform == "LANDSAT_7"):
    print("\nProcessing filled nodata outputs for comparison...")
    # `platform` is "LANDSAT_7" inside this branch, so this is the same file
    # the co-registration cells above wrote their combined results to.
    results_csv = (
        f"data/{outputs_folder}/{dir_suffix}{platform}_{pr}/co_registration_results.csv"
    )
    results_df = pd.read_csv(results_csv)
    # Failed runs are given a fixed "50, 50" shift so they can still be parsed.
    # NOTE(review): the choice of 50 is not explained here — confirm intent.
    results_df = results_df.replace(to_replace="Failed", value="50, 50")

    coreg_shifts = _column_shifts(results_df, "Co-Register Shifts")
    karios_shifts = _column_shifts(results_df, "Karios Shifts")
    arosics_shifts = _column_shifts(results_df, "AROSICS Shifts")
    arosics_edge_shifts = _column_shifts(results_df, "AROSICS Edge Shifts")

    print("Co-Register shifts:", coreg_shifts)
    print("Karios shifts:", karios_shifts)
    print("AROSICS shifts:", arosics_shifts)
    print("AROSICS Edge shifts:", arosics_edge_shifts)

    tool_names = ["Co_Register", "Karios", "AROSICS", "AROSICS_edge"]
    tool_shifts = [
        coreg_shifts,
        karios_shifts,
        arosics_shifts,
        arosics_edge_shifts,
    ]

    # Apply each tool's shifts (estimated on the infilled images) to the
    # non-infilled targets and regenerate the per-tool results.
    for tool_name, tool_shift in zip(tool_names, tool_shifts):
        output_dir = f"data/{outputs_folder}_L7/{dir_suffix}{platform}_{pr}/{tool_name}"
        aligned_dir = f"{output_dir}/Aligned"
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(aligned_dir, exist_ok=True)

        processed_output_images = []
        processed_tgt_images = []
        for i, tgt_image in enumerate(tgt_images):
            output_path = os.path.join(aligned_dir, os.path.basename(tgt_image))
            shift_x, shift_y = tool_shift[i]
            warp_affine_dataset(
                tgt_image,
                output_path,
                translation_x=shift_x,
                translation_y=shift_y,
            )
            processed_output_images.append(output_path)
            processed_tgt_images.append(tgt_image)

        generate_results_from_raw_inputs(
            ref_image,
            processed_output_images,
            processed_tgt_images,
            output_dir=output_dir,
            shifts=np.array(tool_shift),
            run_time=0.0,  # no timing is measured when re-applying known shifts
            target_ids=list(range(len(processed_tgt_images))),
        )

    root_output = f"data/{outputs_folder}_L7/{dir_suffix}{platform}_{pr}"

    combine_comparison_results(root_output)