In [None]:
import os
import numpy as np
from glob import glob
from datetime import datetime

In [None]:
# Per-date *_mask.npy files are read from input_folder; aggregated results are
# written to a sub-folder of it, created on first run (re-runs are harmless).
input_folder = "data/all_data"
output_folder = os.path.join(input_folder, "aggregated-temporal")
os.makedirs(name=output_folder, exist_ok=True)

In [None]:
# load files
def extract_date(filename: str) -> datetime:
    basename = os.path.basename(filename)
    date_str = basename.split("_")[0]
    return datetime.strptime(date_str, "%Y-%m-%d")

# Only masks acquired in this year are aggregated.
TARGET_YEAR = 2024

# Collect the per-date *_mask.npy files for TARGET_YEAR, ordered by date.
# The year filter is applied to the same listing that is used downstream, so it
# is not discarded by a later re-glob.
file_paths = sorted(
    (
        fp
        for fp in glob(os.path.join(input_folder, "*_mask.npy"))
        if extract_date(fp).year == TARGET_YEAR
    ),
    key=extract_date,
)

# set priotity colors
priority_values = [
    np.array([255, 0, 0]),  # max 
    np.array([0, 0, 255]),
    np.array([0, 255, 0])   # min 
]

# function to check if pixel is valid
def is_valid(pixel: np.ndarray) -> bool:
    return not (np.all(pixel == [0, 0, 0]) or np.all(pixel == [255, 255, 255]))

def match_priority(value: np.ndarray) -> bool:
    """Return True if ``value`` equals any of the colours in ``priority_values``."""
    return any(np.all(value == colour) for colour in priority_values)

def fill_pixels(base: np.ndarray, candidate: np.ndarray) -> np.ndarray:
    """Fill invalid pixels of ``base`` with priority colours taken from ``candidate``.

    A pixel of ``base`` is invalid when it is pure black or pure white. Such a
    pixel is replaced by the candidate's pixel only if that pixel equals one of
    ``priority_values``. ``base`` itself is not modified; a filled copy is returned.
    """
    result = base.copy()
    black = np.all(base == [0, 0, 0], axis=-1)
    white = np.all(base == [255, 255, 255], axis=-1)
    still_missing = black | white

    for colour in priority_values:
        hits = still_missing & np.all(candidate == colour, axis=-1)
        result[hits] = colour
        # pixels filled now are no longer candidates for lower-priority colours
        still_missing &= ~hits
    return result

# RGB colour -> land-cover class name, and the inverse lookup.
COLOR_TO_CLASS = dict([
    ((255, 0, 0), "built_up"),
    ((0, 0, 255), "water"),
    ((0, 255, 0), "green"),
    ((255, 255, 255), "cloud"),
    ((0, 0, 0), "unclassified"),
])

CLASS_TO_COLOR = dict(zip(COLOR_TO_CLASS.values(), COLOR_TO_CLASS.keys()))

# Class priority (higher number = higher priority); temporal_correction uses it
# to break ties between equally frequent classes.
CLASS_PRIORITY = dict(
    built_up=4,
    water=3,
    green=2,
    cloud=1,
    unclassified=0,
)



def temporal_correction(mask_stack: np.ndarray, window: int = 2,
                        color_to_class=None, class_priority=None) -> np.ndarray:
    """Apply temporal consistency correction to a time-ordered stack of RGB masks.

    For every pixel and time step ``t``, the classes seen at that pixel in
    frames ``[t - window, t + window]`` (clipped to the stack) are counted. The
    pixel is relabelled with the most frequent class; ties in frequency go to
    the class with the higher priority.

    Parameters
    ----------
    mask_stack : array-like, shape (T, H, W, 3)
        RGB masks ordered in time. Colours missing from ``color_to_class`` are
        treated as ``"unclassified"``.
    window : int, default 2
        Number of time steps to look backward and forward (2 means +-2).
    color_to_class : dict, optional
        Mapping ``(r, g, b) -> class name``. Defaults to ``COLOR_TO_CLASS``.
    class_priority : dict, optional
        Mapping ``class name -> priority`` (higher wins ties). Defaults to
        ``CLASS_PRIORITY``. Must contain ``"unclassified"``; priorities are
        assumed to be distinct.

    Returns
    -------
    np.ndarray
        Corrected stack with the same shape and dtype as ``mask_stack``.

    Raises
    ------
    ValueError
        If ``window`` is negative or ``mask_stack`` is not shaped (T, H, W, 3).
    """
    if color_to_class is None:
        color_to_class = COLOR_TO_CLASS
    if class_priority is None:
        class_priority = CLASS_PRIORITY
    if window < 0:
        raise ValueError(f"window must be non-negative, got {window}")

    stack = np.asarray(mask_stack)
    if stack.ndim != 4 or stack.shape[-1] != 3:
        raise ValueError(f"mask_stack must have shape (T, H, W, 3), got {stack.shape}")
    n_frames = stack.shape[0]

    # Order classes by ascending priority so that a class's index is its rank.
    # Maximising votes * n_classes + rank then picks "most votes, ties -> higher
    # priority", which is what the per-pixel comparison loop used to compute.
    classes = sorted(class_priority, key=class_priority.get)
    rank = {cls: idx for idx, cls in enumerate(classes)}
    n_classes = len(classes)
    class_to_color = {cls: color for color, cls in color_to_class.items()}
    palette = np.array([class_to_color[cls] for cls in classes], dtype=stack.dtype)

    # Encode every pixel as a class rank once; unknown colours -> "unclassified".
    labels = np.full(stack.shape[:3], rank["unclassified"], dtype=np.intp)
    for color, cls in color_to_class.items():
        labels[np.all(stack == color, axis=-1)] = rank[cls]

    class_ranks = np.arange(n_classes)[:, None, None]
    corrected = np.empty_like(stack)
    for t in range(n_frames):
        # Slicing clips the upper bound at n_frames automatically.
        context = labels[max(0, t - window): t + window + 1]
        # votes[k, h, w] = number of frames in the window where pixel (h, w) has class k
        votes = np.stack([(context == k).sum(axis=0) for k in range(n_classes)])
        # Classes absent from the window score < n_classes and can never win,
        # because at least one class has >= 1 vote (the window includes t itself).
        score = votes * n_classes + class_ranks
        corrected[t] = palette[score.argmax(axis=0)]

    return corrected


In [None]:
# load masks (file_paths is sorted by date, so the stack is in time order)
if not file_paths:
    raise FileNotFoundError(f"No *_mask.npy files selected in {input_folder!r}")

all_masks = [np.load(path) for path in file_paths]
dates = [extract_date(path) for path in file_paths]

mask_stack = np.stack(all_masks)  # shape: (time, height, width, 3)

# temporal correction
corrected_stack = temporal_correction(mask_stack)

# Share of valid (non-black, non-white) pixels an aggregate must reach.
MIN_VALID_FRACTION = 0.95


def _valid_fraction(mask: np.ndarray) -> float:
    """Return the share of pixels in an (H, W, 3) mask that are neither black nor white.

    Vectorised equivalent of averaging ``is_valid`` over every pixel; the
    per-pixel Python loop was far too slow on full rasters.
    """
    invalid = np.all(mask == [0, 0, 0], axis=-1) | np.all(mask == [255, 255, 255], axis=-1)
    return 1.0 - float(invalid.mean())


# Greedy aggregation: start at frame i and fill its invalid pixels from the
# following frames until MIN_VALID_FRACTION is reached or frames run out. A frame
# that is already complete is saved on its own instead of swallowing the next one.
# The next aggregate starts right after the last frame consumed.
i = 0
n = len(corrected_stack)

while i < n:
    base = corrected_stack[i]
    j = i
    while _valid_fraction(base) < MIN_VALID_FRACTION and j + 1 < n:
        j += 1
        base = fill_pixels(base, corrected_stack[j])
    last_used = dates[j]

    out_name = f"{last_used.strftime('%Y-%m-%d')}_aggregated_mask.npy"
    out_path = os.path.join(output_folder, out_name)
    np.save(out_path, base)
    print(f"Saved aggregated file: {out_path} "
          f"(frames {i}-{j}, {_valid_fraction(base):.1%} valid)")

    i = j + 1


> [[[  0 255   0]
  [  0 255   0]
  [  0 255   0]
  ...
  [  0   0   0]
  [  0   0   0]
  [  0   0   0]]

 [[  0 255   0]
  [  0 255   0]
  [  0 255   0]
  ...
  [  0   0   0]
  [  0   0   0]
  [  0   0   0]]

 [[  0 255   0]
  [  0 255   0]
  [255   0   0]
  ...
  [  0   0   0]
  [  0   0   0]
  [  0   0   0]]

 ...

 [[255   0   0]
  [  0 255   0]
  [  0 255   0]
  ...
  [  0 255   0]
  [  0 255   0]
  [  0 255   0]]

 [[  0   0   0]
  [  0 255   0]
  [  0 255   0]
  ...
  [  0 255   0]
  [  0 255   0]
  [  0 255   0]]

 [[  0 255   0]
  [  0 255   0]
  [  0   0   0]
  ...
  [  0 255   0]
  [  0 255   0]
  [  0 255   0]]]
>> [[[  0 255   0]
  [  0 255   0]
  [  0 255   0]
  ...
  [  0   0   0]
  [  0   0   0]
  [  0   0   0]]

 [[  0 255   0]
  [  0 255   0]
  [  0 255   0]
  ...
  [  0   0   0]
  [  0   0   0]
  [  0   0   0]]

 [[  0 255   0]
  [255   0   0]
  [255   0   0]
  ...
  [  0   0   0]
  [  0   0   0]
  [  0   0   0]]

 ...

 [[255   0   0]
  [  0 255   0]
  [  0 255   0]

KeyboardInterrupt: 