In [1]:
import os
import json
import numpy as np
from datetime import datetime
from matplotlib import pyplot as plt

from osgeo import gdal

In [2]:
# Root of the raw (uncropped) data tree. Defaults to the original hard-coded
# location; set the FUSION_S1_S2_DATA_ROOT environment variable to use another
# machine's copy. os.path.join(..., "") guarantees the trailing separator that the
# concatenations below rely on.
DATA_ROOT_DIR = os.path.join(
    os.environ.get("FUSION_S1_S2_DATA_ROOT", "/Volumes/X/Data/fusion-s1-s2/"), ""
)
S2_ROOT_PATH = f"{DATA_ROOT_DIR}s2/sre-10m/"
ORBIT = "044"
S1_ROOT_PATH = f"{DATA_ROOT_DIR}s1db/32VNH/threeband/{ORBIT}/"

In [16]:
# Candidate scene pairs; each value carries "cloudy" and "cloud_free" file names
# that the dataset-building cell below consumes.
candidates_path = "data/candidates_cloudy.json"
with open(candidates_path, mode="r") as candidates_file:
    candidates = json.load(candidates_file)

In [4]:
def closest_date(target_date, date_array):
    """Return the entry of ``date_array`` nearest in time to ``target_date``.

    Parameters
    ----------
    target_date : str
        Date formatted as ``YYYYMMDD``.
    date_array : iterable of str
        Candidate dates formatted as ``YYYYMMDD``. Entries that do not parse as
        such a date (e.g. the ``"Store"`` fragment left by a ``.DS_Store`` file in a
        directory listing) are ignored.

    Returns
    -------
    str
        The closest candidate, formatted as ``YYYYMMDD``. Ties are broken in favour
        of the earlier date so the result does not depend on input order (which,
        for ``os.listdir``, is arbitrary).

    Raises
    ------
    ValueError
        If ``target_date`` is malformed or no entry of ``date_array`` is a date.
    """
    target = datetime.strptime(target_date, '%Y%m%d')

    parsed = []
    for entry in date_array:
        try:
            parsed.append(datetime.strptime(entry, '%Y%m%d'))
        except ValueError:
            # Not a date (stray directory entry) -- skip it rather than crash.
            continue

    if not parsed:
        raise ValueError(f"no valid YYYYMMDD dates to compare with {target_date!r}")

    # Secondary key makes ties deterministic (earlier date wins).
    nearest = min(parsed, key=lambda d: (abs(target - d), d))
    return nearest.strftime('%Y%m%d')

In [5]:
# Containers populated by the cells below (full index, then the three splits).
DATASET = {}
DATASET_TRAIN = {}
DATASET_TEST = {}
DATASET_VALIDATION = {}

In [8]:
# Build DATASET: for every candidate pair collect the cropped S2 band tiles of the
# cloudy and the cloud-free scene plus the S1 HV/VV tiles of the S1 acquisition
# closest in time to the cloudy scene. A sample is kept only if every file exists.

# S2 bands to include. Add e.g. "B05", "B8A", "B11" for further bands (same path layout).
S2_BANDS = ["B02", "B03", "B04", "B08"]


def _s2_band_path(scene, band):
    """Path of the cropped ``band`` tile for the S2 tile file name ``scene``.

    ``scene`` is split on ``_``: the first three fields name the scene folder, the
    remaining fields are appended after the band code in the file name.
    """
    parts = scene.split("_")
    scene_name = "_".join(parts[:3])
    tile_suffix = "_".join(parts[3:])
    return f"data/cropped/s2/{scene_name}/{scene_name}_{band}/{scene_name}_{band}_{tile_suffix}"


def _s1_pol_path(s1_date, pol, tile_suffix):
    """Path of the cropped S1 tile with polarisation ``pol`` acquired on ``s1_date``."""
    stem = f"S1_32VNH_{s1_date}"
    return f"data/cropped/s1/{stem}/{stem}_{pol}/{stem}_{pol}_{tile_suffix}"


s1_dates = [entry.split("_")[-1] for entry in os.listdir("data/cropped/s1/")]

# Rebuilt from scratch so re-running the cell is idempotent. Keys are strings so that
# the later `DATASET[str(i)]` lookups work on a fresh kernel, exactly as they do after
# the JSON round-trip (json.dump turns int keys into strings anyway).
DATASET = {}

for pair in candidates.values():
    cloudy = pair["cloudy"]
    cloud_free = pair["cloud_free"]
    # The S1 tiles are matched to the cloudy scene's tile suffix and acquisition date.
    cloudy_tile_suffix = "_".join(cloudy.split("_")[3:])
    s1_date = closest_date(cloudy.split("_")[2], s1_dates)

    sample = {f"s2_cloudy_{band}": _s2_band_path(cloudy, band) for band in S2_BANDS}
    sample.update({f"s2_cloud_free_{band}": _s2_band_path(cloud_free, band) for band in S2_BANDS})
    sample["s1_hv"] = _s1_pol_path(s1_date, "HV", cloudy_tile_suffix)
    sample["s1_vv"] = _s1_pol_path(s1_date, "VV", cloudy_tile_suffix)

    if all(os.path.isfile(path) for path in sample.values()):
        DATASET[str(len(DATASET))] = sample

In [9]:
# Persist the full dataset index (json serialises any int keys as strings).
with open("data/dataset_cloudy.json", mode="w") as index_file:
    json.dump(obj=DATASET, fp=index_file)

In [24]:
# Sample indices (keys of DATASET, given as ints) to discard, grouped by the reason
# each list's name describes. NOTE(review): presumably picked by inspecting the
# input/truth figures written to filtering/ -- confirm.
TO_REMOVE_SHADOW_ON_GT = [0,5,7,8,18,84,141,146,148,149,158,159,161,177,250,493,503,564,577,579,
                          584,685,686,687,737,743,745,746,748,913,941,950,966,973,1195,1200,1201,1215,1274,1285,1294,1305,1311,
                          1489,1490,1493,1502,1534,1541,1544]
TO_REMOVE_WRONG_GT = [24,53,54,55,57,58,60,64,65,68,71,79,171,172,173,174,175,176,260,
                      262,270,271,272,273,397,643,677,690,712,1078,1248,1249,1250,1251,1252,1253,1254,1255,1256,1257,1258,1259,
                      1260,1320,1474,1475]
TO_REMOVE_WEIRD_THINGS_ON_GT = [1,185]
TO_REMOVE_CLOUDS_ON_GT = [4,11,156,256,259,268,269,274,275,276,301,319,326,334,478,491,539,549,622,753,759,883,902,
                          911,942,1004,1159,1161,1219,1220,1223,1224,1414,1420,1482,1585]
TO_REMOVE = TO_REMOVE_SHADOW_ON_GT + TO_REMOVE_WRONG_GT + TO_REMOVE_WEIRD_THINGS_ON_GT + TO_REMOVE_CLOUDS_ON_GT

# The deletion cell runs `del DATASET[str(i)]` once per id; a repeated id would raise
# KeyError partway through and leave DATASET only partially filtered.
assert len(set(TO_REMOVE)) == len(TO_REMOVE), "duplicate sample id across TO_REMOVE lists"

In [31]:
# Record the cloud-free B02 path of every rejected sample, one per line.
# NOTE(review): append mode means re-running this cell duplicates lines in the file.
with open("data/to_remove.txt", "a") as removed_log:
    removed_log.writelines(
        f"{DATASET[str(sample_id)]['s2_cloud_free_B02']}\n" for sample_id in TO_REMOVE
    )

In [32]:
# Drop the rejected samples; DATASET is keyed by the string form of the sample index.
for doomed_key in map(str, TO_REMOVE):
    del DATASET[doomed_key]

In [10]:
# Overwrite the index on disk with the filtered dataset.
with open("data/dataset_cloudy.json", mode="w") as filtered_file:
    json.dump(obj=DATASET, fp=filtered_file)

In [28]:
# Separate dataset
with open("data/dataset_cloudy.json", "r") as f:
    DATASET = json.load(f)
    
dataset_keys = list(DATASET.keys())
np.random.shuffle(dataset_keys)
split_index_80 = int(0.8 * len(dataset_keys))
split_index_80 = int(0.9 * len(dataset_keys))

dataset_keys_train = dataset_keys[:split_index_80]
dataset_keys_test = dataset_keys[split_index_80:split_index_80]
dataset_keys_validation = dataset_keys[split_index_80:]

In [29]:
def _write_split(keys, path):
    """Re-index the DATASET samples named by ``keys`` as 0..n-1, dump them to ``path``.

    Returns the new dict. A fresh dict is built for every call, so entries of one
    split can never leak into another (and re-running the cell is idempotent).
    """
    split = {new_id: DATASET[key] for new_id, key in enumerate(keys)}
    with open(path, "w") as split_file:
        json.dump(split, split_file)
    return split


DATASET_TRAIN = _write_split(dataset_keys_train, "data/dataset_cloudy_train.json")
DATASET_TEST = _write_split(dataset_keys_test, "data/dataset_cloudy_test.json")
DATASET_VALIDATION = _write_split(dataset_keys_validation, "data/dataset_cloudy_validation.json")


In [30]:
# Reload the test split for visual inspection; this replaces the in-memory DATASET.
with open("data/dataset_cloudy_test.json", mode="r") as test_split_file:
    test_split_text = test_split_file.read()
DATASET = json.loads(test_split_text)

In [None]:
def _read_rgb(sample, prefix):
    """Return an RGB image built from the B04/B03/B02 tiles of one scene.

    ``sample`` is one DATASET entry and ``prefix`` is ``"s2_cloudy"`` or
    ``"s2_cloud_free"``; the tile for band ``B0x`` is ``sample[f"{prefix}_B0x"]``.
    Values are divided by 2000 (display stretch) and clipped to [0, 1], the range
    ``imshow`` expects for float RGB. Assumes each tile is single-band -- TODO confirm.

    Raises FileNotFoundError if GDAL cannot open a tile.
    """
    channels = []
    for band in ("B04", "B03", "B02"):
        path = sample[f"{prefix}_{band}"]
        raster = gdal.Open(path)
        if raster is None:  # GDAL returns None instead of raising when exceptions are off
            raise FileNotFoundError(f"GDAL could not open {path}")
        channels.append(raster.ReadAsArray())
        raster = None  # release the GDAL dataset handle
    return np.clip(np.stack(channels, axis=-1) / 2000, 0, 1)


FIGURE_DIR = "filtering"
os.makedirs(FIGURE_DIR, exist_ok=True)  # savefig does not create missing directories

for k, v in DATASET.items():
    print(k)  # progress indicator
    panels = (
        ("input", _read_rgb(v, "s2_cloudy")),
        ("truth", _read_rgb(v, "s2_cloud_free")),
    )

    fig, axes = plt.subplots(1, 2)
    for ax, (title, image) in zip(axes, panels):
        ax.set_title(title)
        ax.imshow(image)
        ax.axis("off")

    # File name: sample key + stem of the cloudy B02 tile (directory and extension dropped).
    stem = os.path.splitext(os.path.basename(v["s2_cloudy_B02"]))[0]
    fig.savefig(os.path.join(FIGURE_DIR, f"{k}_{stem}.png"))
    plt.close(fig)  # avoid accumulating open figures across the loop
    