In [None]:
import glob
import os
import shlex
import sys
import warnings
from subprocess import run

import cv2 as cv
import numpy as np
import pandas as pd
import rasterio

sys.path.insert(1, "../")
from utils import *

In [None]:
# Collect the true-colour scenes for this product: the first one is the reference,
# the rest are targets to be co-registered against it.
product_id = "LANDSAT_8_108074"
scene_dir = f"../data/inputs/{product_id}/true_color"
# glob returns entries in arbitrary (filesystem-dependent) order; sort so the choice of
# reference scene and the target numbering are reproducible across runs and machines.
scenes = sorted(glob.glob(f"{scene_dir}/**"))
if not scenes:
    raise FileNotFoundError(f"No scenes found under {scene_dir}/")
ref_image = scenes[0]
tgt_images = scenes[1:]
# File name -> index in `tgt_images`; used later to order aligned outputs and to
# title them "target_<index>".
process_ids = {os.path.basename(tgt): i for i, tgt in enumerate(tgt_images)}

In [None]:
# Run KARIOS once per target against the shared reference; every run appends its
# results to the same log file, which the next cell parses.
KARIOS_PYTHON = "python"
KARIOS_SCRIPT = "/home/ubuntu/Coreg/karios/karios/karios.py"  # machine-specific path

output_dir = f"../data/outputs/KARIOS/{product_id}"
os.makedirs(output_dir, exist_ok=True)
log_file = f"{output_dir}/karios.log"

# Start from an empty log so the parsing cell only sees this run's entries
# (harmless if KARIOS overwrites the file itself).
if os.path.exists(log_file):
    os.remove(log_file)

for tgt_image in tgt_images:
    # Pass an argument list rather than a split string so paths containing spaces
    # or shell metacharacters reach KARIOS intact.
    karios_args = [
        KARIOS_PYTHON,
        KARIOS_SCRIPT,
        tgt_image,
        ref_image,
        "--out",
        output_dir,
        "--log-file-path",
        log_file,
    ]
    printable_cmd = " ".join(shlex.quote(arg) for arg in karios_args)
    print(f"Running {printable_cmd}")
    completed = run(karios_args)
    # Keep going on failure (later cells only process scenes that produced shifts),
    # but make the failure visible instead of ignoring the exit status.
    if completed.returncode != 0:
        warnings.warn(
            f"KARIOS exited with code {completed.returncode} for {tgt_image}"
        )

In [None]:
# Extract per-scene shifts from the KARIOS log:
# - a line containing "_T1_SR_TC.TIF" contributes its last space-separated token as a scene path;
# - a line containing "DX/DY(KLT) MEAN" contributes [token[-3], token[-1]] as floats (x, y shift).
# The i-th scene is paired with the i-th shift.
scene_names = []
shifts = []
with open(log_file, "r") as log:
    for line in log:
        # split(" ") (not split()) on purpose: token positions must match the original parsing.
        tokens = line.strip().split(" ")
        if "_T1_SR_TC.TIF" in line:
            scene_names.append(tokens[-1])
        if "DX/DY(KLT) MEAN" in line:
            shifts.append([float(tokens[-3]), float(tokens[-1])])

# zip() silently drops unmatched entries, so a count mismatch would otherwise go
# unnoticed even though positional pairing may then be wrong.
if len(scene_names) != len(shifts):
    warnings.warn(
        f"Found {len(scene_names)} scene names but {len(shifts)} shifts in {log_file}; "
        "positional pairing may be wrong, check the log."
    )

shifts_dict = dict(zip(scene_names, shifts))

In [None]:
# Per-scene shifts parsed from the KARIOS log, shown as a labelled table
# (column order matches how the alignment cell unpacks each pair).
pd.DataFrame.from_dict(shifts_dict, orient="index", columns=["shift_x", "shift_y"])

In [None]:
# Apply each KARIOS shift to its scene with `warp_affine_dataset` and write the
# result into Aligned/ under the scene's original file name, so the evaluation
# cell can map it back to its target index via `process_ids`.
aligned_dir = f"../data/outputs/KARIOS/{product_id}/Aligned"
os.makedirs(aligned_dir, exist_ok=True)
for scene_path, (shift_x, shift_y) in shifts_dict.items():
    warp_affine_dataset(
        scene_path,
        os.path.join(aligned_dir, os.path.basename(scene_path)),
        translation_x=shift_x,
        translation_y=shift_y,
    )

In [None]:
# Score each target against the reference (SSIM / MSE over the overlapping area),
# both after KARIOS alignment and on the raw inputs, render a difference GIF for
# each case, and save all scores to CSV.
karios_dir = f"../data/outputs/KARIOS/{product_id}"


def _score_against_reference(ref_path, tgt_paths, titles, gif_path):
    """Score targets against the reference and render a difference GIF.

    For each target, ``find_overlap`` (from ``utils``) supplies the reference and
    target overlap arrays; both are converted to grayscale uint8 and compared with
    ``ssim`` (``win_size=3``) and ``mse``. Scores are rounded to 3 decimals.

    Args:
        ref_path: Path to the reference scene.
        tgt_paths: Paths to the target scenes, in display order.
        titles: One title per target, used for the GIF frames.
        gif_path: Output path passed to ``make_difference_gif``.

    Returns:
        ``(ssim_scores, mse_scores)``, one entry per target.
    """
    ref_grays, tgt_grays = [], []
    for tgt_path in tgt_paths:
        _, (_, _, ref_overlap, tgt_overlap) = find_overlap(ref_path, tgt_path, True)
        # NOTE(review): COLOR_BGR2GRAY assumes BGR channel order, and astype("uint8")
        # does not clip values above 255 — confirm the true-colour inputs are 8-bit.
        ref_grays.append(cv.cvtColor(ref_overlap, cv.COLOR_BGR2GRAY).astype("uint8"))
        tgt_grays.append(cv.cvtColor(tgt_overlap, cv.COLOR_BGR2GRAY).astype("uint8"))

    ssim_scores = [
        np.round(ssim(ref, tgt, win_size=3), 3) for ref, tgt in zip(ref_grays, tgt_grays)
    ]
    mse_scores = [np.round(mse(ref, tgt), 3) for ref, tgt in zip(ref_grays, tgt_grays)]

    frame_titles = ["Reference"] + [
        f"{title}, ssim:{ssim_score}, mse:{mse_score}"
        for title, ssim_score, mse_score in zip(titles, ssim_scores, mse_scores)
    ]
    make_difference_gif(
        [ref_path] + list(tgt_paths), gif_path, frame_titles, mosaic_scenes=True
    )
    return ssim_scores, mse_scores


# Aligned outputs ordered by target index. Only files named after a target are kept:
# anything else in Aligned/ (leftovers, sidecar files) has no `process_ids` entry and
# would otherwise raise KeyError.
processed_output_images = sorted(
    (
        path
        for path in glob.glob(f"{karios_dir}/Aligned/**")
        if os.path.basename(path) in process_ids
    ),
    key=lambda path: process_ids[os.path.basename(path)],
)
aligned_process_ids = [
    process_ids[os.path.basename(path)] for path in processed_output_images
]
target_titles = [f"target_{idx}" for idx in aligned_process_ids]

ssims_aligned, mse_aligned = _score_against_reference(
    ref_image, processed_output_images, target_titles, f"{karios_dir}/karios.gif"
)

# Raw (unaligned) counterparts of the same targets, in the same order. Kept under a
# new name: overwriting `tgt_images` made a re-run of this cell index the already
# filtered list with the original ids.
raw_tgt_images = [tgt_images[idx] for idx in aligned_process_ids]
ssims_aligned_raw, mse_aligned_raw = _score_against_reference(
    ref_image, raw_tgt_images, target_titles, f"{karios_dir}/karios_raw.gif"
)

out_ssim = f"{karios_dir}/karios.csv"
out_ssim_df = pd.DataFrame(
    {
        "Title": target_titles,
        "SSIM Raw": ssims_aligned_raw,
        "SSIM Aligned": ssims_aligned,
        "MSE Raw": mse_aligned_raw,
        "MSE Aligned": mse_aligned,
    }
)
out_ssim_df.to_csv(out_ssim, encoding="utf-8")