In [None]:
from utils import *
import getpass

# from landsatxplore.earthexplorer import EarthExplorer
import tarfile
import boto3
from arosics import COREG, COREG_LOCAL
from subprocess import run
import shlex

# Run-time switches for the download / pre-processing stage.
# When keep_original_band_scenes is False the per-band "Originals" folder is
# removed once the true-color scene has been built.
keep_original_band_scenes = False
one_per_month = True

# Optional sub-folder inserted into every data path; a non-empty value is
# normalised so that it always ends with "/".
dir_suffix = ""
if dir_suffix and not dir_suffix.endswith("/"):
    dir_suffix += "/"

# rasterio AWS session built on the default boto3 session, with
# requester-pays enabled.
boto_session = boto3.Session()
aws_session = rasterio.session.AWSSession(boto_session, requester_pays=True)
# aws_session = rasterio.session.AWSSession(boto3.Session())

# Master switch: the query / download / true-color cells only run when True.
query_and_process = False

In [None]:
# Bounding box derived from the second element returned by read_kml_polygon
# for the WA KML file, passed through resize_bbox with a factor of 0.1.
wa_polygon = read_kml_polygon("data/inputs_old/LANDSAT_8_127111/WA.kml")[1]
wa_bbox = resize_bbox(wa_polygon, 0.1)

aoi_polys = kml_to_poly("data/inputs_old/aois.kml").geoms

# Candidate areas of interest; `aoi_index` selects one of them for the query.
amery_bed_rock = [67.45, -72.55, 67.55, -72.45]
amery_top = [69.2, -68.1, 69.4, -67.9]
wa_sand_dunes = [wa_bbox.left, wa_bbox.bottom, wa_bbox.right, wa_bbox.top]

bbox_list = [amery_bed_rock, amery_top, wa_sand_dunes]
# One extra entry per AOI polygon, using the polygon's bounds.
bbox_list.extend(list(poly.bounds) for poly in aoi_polys)

In [None]:
# Platform name and the three asset keys used as red / green / blue bands.
platform = "LANDSAT_8"
r_channel, g_channel, b_channel = "red", "green", "blue"

# Index into `bbox_list` of the area of interest to query.
aoi_index = 0

if query_and_process:
    server_url = "https://landsatlook.usgs.gov/stac-server/search"

    # Search parameters handed to get_search_query alongside the bounding box.
    search_kwargs = dict(
        start_date="2014-01-01T00:00:00",
        end_date="2016-01-01T00:00:00",
        platform=platform,
    )
    query = get_search_query(bbox_list[aoi_index], **search_kwargs)

    features = query_stac_server(query, server_url)

In [None]:
if query_and_process:
    # Group the queried STAC features by path/row and date. The selection
    # honours the notebook-level `one_per_month` flag and the configured
    # R/G/B channel names instead of repeating their values here, so the
    # config cell stays the single place to change them.
    scene_dict, scene_list = find_scenes_dict(
        features,
        one_per_month=one_per_month,
        acceptance_list=[r_channel, g_channel, b_channel],
    )
    # Only the second path/row key of `scene_dict` is kept for processing.
    path_rows = list(scene_dict.keys())[1:2]
    print(path_rows)
    # For each kept path/row, the list of date keys available in `scene_dict`.
    dates = [list(scene_dict[pr].keys()) for pr in path_rows]
    print(dates)

In [None]:
# sn = combine_scene_dicts([scene_dict['126111'], scene_dict['127111']])
# sn_closest_pair = get_pair_dict(sn, "closest")
# sn_farthest_pair = get_pair_dict(sn, "farthest")

In [None]:
# Path/row identifier of the scene stack to work on: it selects which
# pairs.csv is loaded and names the output folders used by the cells below.
path_row = "127111"
# closest_pair = get_pair_dict(scene_dict[path_row], "closest")
# farthest_pair = get_pair_dict(scene_dict[path_row], "farthest")

In [None]:
if query_and_process:
    # For every selected path/row: stream the R, G and B bands of three scenes
    # from AWS, write them to disk, build a true-color scene plus a downsampled
    # copy, and record the resulting paths in pairs.csv.
    #
    # The band files handed to make_true_color_scene are the paths written in
    # this loop. Earlier revisions re-discovered them with a glob keyed on a
    # band suffix taken from `scene_dict[path_row]`, which raised ValueError
    # whenever `path_row` was not part of `path_rows`.

    # Column order of pairs.csv; it mirrors the order of `pr_date_list`.
    cols = ["Reference", "Closest_target", "Farthest_target"]
    # Keys of the "<channel>_alternate" asset URLs, in R, G, B order. These
    # are the URLs streamed through `aws_session`.
    band_keys = [
        f"{channel}_alternate" for channel in (r_channel, g_channel, b_channel)
    ]

    for pr in path_rows:
        pr_root = f"data/inputs/{dir_suffix}{platform}_{pr}"
        originals_dir = f"{pr_root}/Originals"

        true_color_dir = f"{pr_root}/true_color"
        os.makedirs(true_color_dir, exist_ok=True)

        true_color_ds_dir = f"{pr_root}/true_color_ds"
        os.makedirs(true_color_ds_dir, exist_ok=True)

        pr_dict = scene_dict[pr]
        closest_pair = get_pair_dict(pr_dict, "closest")
        farthest_pair = get_pair_dict(pr_dict, "farthest")

        # Both scenes of the closest pair followed by the second scene of the
        # farthest pair.
        # NOTE(review): assumes closest_pair holds exactly two scenes so the
        # list lines up with the three columns in `cols` - confirm.
        pr_date_list = closest_pair + [farthest_pair[1]]
        total_scenes = len(pr_date_list)

        for counter, el in enumerate(pr_date_list, start=1):
            print(
                f"Now downloading and processing pairs for {el['scene_name']} and path_row: {pr}, scene {counter} from total of {total_scenes}.",
                end="\r",
            )

            output_dir = f"{originals_dir}/{el['scene_name']}"
            os.makedirs(output_dir, exist_ok=True)

            # Stream each band and write its first layer as a single-band
            # raster, using the profile returned together with the pixel data.
            true_bands = []
            for band_key in band_keys:
                band_url = el[band_key]
                band_path = os.path.join(output_dir, os.path.basename(band_url))
                band_img, band_meta = stream_scene_from_aws(band_url, aws_session)
                with rasterio.open(band_path, "w", **band_meta["profile"]) as ds:
                    ds.write(band_img[0, :, :], 1)
                true_bands.append(band_path)

            scene_stem = os.path.join(true_color_dir, os.path.basename(output_dir))
            tc_file = f"{scene_stem}_TC.TIF"
            tc_file_ds = os.path.join(true_color_ds_dir, os.path.basename(tc_file))
            make_true_color_scene(true_bands, tc_file, gray_scale=True, averaging=True)
            downsample_dataset(tc_file, 0.2, tc_file_ds)

            # Remember where the products live so pairs.csv can be written.
            el["local_path"] = tc_file
            el["local_path_ds"] = tc_file_ds

            # Drop the raw bands right away to keep disk usage low.
            if not keep_original_band_scenes:
                shutil.rmtree(originals_dir, ignore_errors=True)

        # One column per scene; row 0 is the full-resolution true-color path,
        # row 1 the downsampled one.
        df = pd.DataFrame(
            {
                col: [el["local_path"], el["local_path_ds"]]
                for col, el in zip(cols, pr_date_list)
            },
            columns=cols,
        )
        df.to_csv(f"{pr_root}/pairs.csv", index=False)

In [None]:
# Load the pairs table written for `path_row` by the processing cell.
pairs_csv = f"data/inputs/{dir_suffix}{platform}_{path_row}/pairs.csv"
scene_df = pd.read_csv(pairs_csv)

# Row 0 holds the full-resolution true-color paths (row 1 the downsampled ones).
first_row = scene_df.iloc[0]
ref_image = first_row["Reference"]
tgt_images = [first_row[col] for col in ("Closest_target", "Farthest_target")]

#### Co_Register

In [None]:
# Run the project's co_register helper for the reference image against both
# targets; only the second element of its return value (the shifts) is kept.
output_path = f"data/outputs/{dir_suffix}{platform}_{path_row}/Co_Register"
co_register_options = dict(
    output_path=output_path,
    return_shifted_images=True,
    use_overlap=True,
    phase_corr_filter=True,
)
_, shifts = co_register(ref_image, tgt_images, **co_register_options)

#### AROSICS

In [None]:
# Co-register every target against the reference with AROSICS COREG_LOCAL and
# build the result products from the targets that were corrected successfully.
arosics_root = f"data/outputs/{dir_suffix}{platform}_{path_row}/AROSICS"
output_dir = f"{arosics_root}/Aligned"
os.makedirs(output_dir, exist_ok=True)

tgt_images_copy = tgt_images.copy()
# Each aligned image keeps the file name of its target.
local_outputs = [
    os.path.join(output_dir, os.path.basename(tgt)) for tgt in tgt_images_copy
]

processed_output_images = []
print(f"Reference image: {ref_image}")
for tgt_image, aligned_path in zip(tgt_images_copy, local_outputs):
    print(f"Coregistering {tgt_image}")
    coreg_local = COREG_LOCAL(
        im_ref=ref_image,
        im_tgt=tgt_image,
        grid_res=250,
        path_out=aligned_path,
        fmt_out="GTIFF",
        nodata=(0.0, 0.0),
        align_grids=True,
        ignore_errors=True,
        min_reliability=30,
    )
    res = coreg_local.correct_shifts()

    if coreg_local.success:
        processed_output_images.append(aligned_path)
        continue

    # Failed run: report it and delete whatever output file was left behind.
    print(
        f"Coregistration not successfull for {tgt_image}. Removing the corresponding output: {aligned_path}"
    )
    if os.path.isfile(aligned_path):
        os.remove(aligned_path)


generate_results_from_raw_inputs(
    ref_image,
    processed_output_images,
    output_dir=arosics_root,
)

#### Karios

In [None]:
# Run KARIOS for every target against the reference, read the mean KLT shifts
# back from its log file, apply them to the original targets as a translation
# and build the result products.
tgt_images_copy = tgt_images.copy()
output_dir = f"data/outputs/{dir_suffix}{platform}_{path_row}/Karios"
os.makedirs(output_dir, exist_ok=True)

# Scratch folder for targets that must be resized to the reference shape.
temp_dir = os.path.join(output_dir, "temp")
os.makedirs(temp_dir, exist_ok=True)

# Read the raster profiles inside context managers so the dataset handles are
# closed again instead of being left open for the lifetime of the kernel.
with rasterio.open(ref_image) as ref_ds:
    ref_profile = ref_ds.profile
tgt_profiles = []
for tgt in tgt_images_copy:
    with rasterio.open(tgt) as tgt_ds:
        tgt_profiles.append(tgt_ds.profile)

for i, tgt_profile in enumerate(tgt_profiles):
    downsample = False
    if tgt_profile["height"] != ref_profile["height"]:
        print(
            f"Target image {tgt_images_copy[i]} has different height than reference image {ref_image}"
        )
        downsample = True
    if tgt_profile["width"] != ref_profile["width"]:
        print(
            f"Target image {tgt_images_copy[i]} has different width than reference image {ref_image}"
        )
        downsample = True
    if downsample:
        # Resize the target to the reference shape and hand the resized copy
        # to KARIOS in place of the original.
        resized_path = os.path.join(temp_dir, os.path.basename(tgt_images_copy[i]))
        downsample_dataset(
            tgt_images_copy[i],
            force_shape=(ref_profile["height"], ref_profile["width"]),
            output_file=resized_path,
        )
        tgt_images_copy[i] = resized_path


# Start from a clean log so only shifts from this run are parsed below.
log_file = f"{output_dir}/karios.log"
if os.path.isfile(log_file):
    os.remove(log_file)

# NOTE(review): absolute, machine-specific path to the KARIOS entry script.
karios_script = "/home/ubuntu/Coreg/karios/karios/karios.py"
for tgt_image in tgt_images_copy:
    # Pass the arguments as a list so paths containing spaces or quotes reach
    # KARIOS unchanged.
    cmd = [
        "python",
        karios_script,
        "--out",
        output_dir,
        "--log-file-path",
        log_file,
        tgt_image,
        ref_image,
    ]
    try:
        print(f"Running {' '.join(cmd)}")
        # check=True turns a non-zero exit status into an exception, so a
        # failed KARIOS run is reported below instead of passing silently.
        run(cmd, check=True)
    except Exception as e:
        print(f"Error running karios for {tgt_image}: {e}")
        continue

shutil.rmtree(temp_dir, ignore_errors=True)

# From here on the shifts are matched against (and applied to) the original
# target files, not the temporary resized copies.
tgt_images_copy = tgt_images.copy()
# Position of each target by file name; not read anywhere else in this cell.
process_ids = {os.path.basename(tgt): i for i, tgt in enumerate(tgt_images_copy)}

pattern = "_T\\d+_SR_TC.TIF"
shift_marker = "DX/DY(KLT) MEAN"
scene_names = []
shifts = []
with open(log_file, "r") as log_handle:
    for line in log_handle:
        # Only lines that carry both a scene file name and the shift marker
        # can contribute a shift.
        if shift_marker not in line or not re.search(pattern, line):
            continue
        splits = line.strip().split(" ")
        # NOTE(review): the last token is used as the scene path here and as
        # the Y shift below - confirm this against the actual KARIOS log format.
        scene_basename = os.path.basename(splits[-1])
        for tgt_image in tgt_images_copy:
            if tgt_image.endswith(scene_basename):
                scene_names.append(tgt_image)
                shifts.append([float(splits[-3]), float(splits[-1])])
                break

# Later log entries for the same scene overwrite earlier ones.
shifts_dict = dict(zip(scene_names, shifts))

print(shifts_dict)

output_dir = f"data/outputs/{dir_suffix}{platform}_{path_row}/Karios/Aligned"
os.makedirs(output_dir, exist_ok=True)
processed_output_images = []
for key, (shift_x, shift_y) in shifts_dict.items():
    output_path = os.path.join(output_dir, os.path.basename(key))
    tgt_aligned = warp_affine_dataset(
        key, output_path, translation_x=shift_x, translation_y=shift_y
    )
    processed_output_images.append(output_path)


# NOTE(review): results go to ".../KARIOS" while every other path in this cell
# uses ".../Karios"; on a case-sensitive file system these are two different
# folders - confirm which one is intended.
generate_results_from_raw_inputs(
    ref_image,
    processed_output_images,
    output_dir=f"data/outputs/{dir_suffix}{platform}_{path_row}/KARIOS",
)

In [None]:
# Display the reference image path read from pairs.csv.
ref_image

In [None]:
# Display the two target image paths (closest and farthest) read from pairs.csv.
tgt_images

In [None]:
def _scene_date(scene_path):
    """Parse the 4th "_"-separated field of the file name as a YYYYMMDD date."""
    # NOTE(review): presumably this field is the acquisition date of the
    # scene - confirm against the product naming convention.
    date_field = os.path.basename(scene_path).split("_")[3]
    return datetime.strptime(date_field, "%Y%m%d")


ref_time = _scene_date(ref_image)
tgt_times = list(map(_scene_date, tgt_images))

In [None]:
# Whole-day difference between each target date and the reference date
# (positive when the target is later than the reference).
[(tgt_time - ref_time).days for tgt_time in tgt_times]

In [None]:
import cv2 as cv
import numpy as np
import rasterio as rio