In [1]:
import geopandas as gpd
import rasterio
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import yaml
from datetime import datetime, timedelta
from distmetrics import (compute_transformer_zscore, 
                         load_trained_transformer_model, 
                         compute_mahalonobis_dist_2d, 
                          compute_mahalonobis_dist_1d, compute_log_ratio_decrease_metric)
from tqdm import tqdm
import json

# Parameters

In [20]:
# Event and track selection
EVENT_NAME = 'chile_fire_2024'
TRACK_IDX = 0  # index of the track to process (selects `ts_dir` in the directories cell)

# How many pre-event and post-event groups to generate
MAX_PRE_EVENT_GROUPS = 2
MAX_POST_EVENT_GROUPS = 2

# Pre-event acquisitions must be strictly earlier than EVENT_DATE - LOOKBACK_DELTA_DAYS
LOOKBACK_DELTA_DAYS = 0

# Bounds on the number of pre-images per group
MAX_PRE_IMGS = 8
MIN_PRE_IMGS = 2

# One of: 'transformer', 'mahalanobis_2d', 'mahalanobis_vh', 'mahalanobis_1d_max', 'log_ratio_vh'
DISTMETRIC_NAME = 'transformer'

In [21]:
# Fail early (before any heavy I/O) with a message listing the accepted metric names.
_valid_distmetrics = ['transformer', 'mahalanobis_2d', 'mahalanobis_vh', 'mahalanobis_1d_max', 'log_ratio_vh']
assert DISTMETRIC_NAME in _valid_distmetrics, (
    f'DISTMETRIC_NAME={DISTMETRIC_NAME!r} is not one of {_valid_distmetrics}')

# Directories

In [22]:
DIST_EVENT_REPO = Path('../dist-s1-events')

DATA_DIR = DIST_EVENT_REPO / 'out'
EVENT_DIR = DATA_DIR / EVENT_NAME
EVENT_YAML_DIR = DIST_EVENT_REPO / 'events'
WATER_MASK_DIR = EVENT_DIR / 'water_mask'

# Merged RTC time series: one sub-directory per track, named `track<number>`.
rtc_dir = EVENT_DIR / 'rtc_ts_merged'
# Filter on is_dir() explicitly: before Python 3.11 a trailing-slash glob ('*/') also
# matched regular files, which would make int() below fail.
tracks = sorted(int(d.name.replace('track', '')) for d in rtc_dir.glob('track*') if d.is_dir())
ts_dir = rtc_dir / f'track{tracks[TRACK_IDX]}'

In [23]:
dirs = [DIST_EVENT_REPO,
        DATA_DIR,
        EVENT_DIR,
        EVENT_YAML_DIR,
        WATER_MASK_DIR,
        rtc_dir,
        ts_dir]

# Fail fast with the offending paths instead of relying on reading a list of booleans.
missing_dirs = [p for p in dirs if not p.exists()]
assert not missing_dirs, f'Missing directories: {missing_dirs}'

[p.exists() for p in dirs]

[True, True, True, True, True, True, True, True]

In [24]:
yaml_file = EVENT_YAML_DIR / f'{EVENT_NAME}.yml'
# The event definition lives under the top-level "event" key of the YAML file.
with yaml_file.open() as f:
    event_config = yaml.safe_load(f)
event_dict = event_config["event"]
event_dict

{'event_name': 'chile_fire_2024',
 'bounds': [-71.53071089, -33.20143816, -71.2964628, -32.98270579],
 'event_date': '2024-02-04',
 'pre_event_window_days': 180,
 'post_event_window_days': 100,
 'rtc_track_numbers': [18, 156],
 'mgrs_tiles': ['19HBD'],
 'source_id': 'EMSR715 AOI1',
 'dist_hls_confirmed_change_min_days': 30,
 'links': ['https://earthobservatory.nasa.gov/images/152411/fires-rage-in-central-chile',
  'https://en.wikipedia.org/wiki/2024_Chile_wildfires']}

# RTC

In [25]:
event_date_str = event_dict['event_date']
EVENT_DATE = datetime.strptime(event_date_str, '%Y-%m-%d')
TRACKS = event_dict['rtc_track_numbers']  # track numbers listed in the event YAML
EVENT_DATE, TRACKS

In [27]:
# Acquisition dates are parsed from the second '_'-separated field of each VV file name.
rtc_acq_dts = sorted(
    datetime.strptime(vv_tif.stem.split('_')[1], '%Y-%m-%d')
    for vv_tif in ts_dir.glob('*VV.tif')
)
rtc_acq_dts

[datetime.datetime(2023, 11, 1, 0, 0),
 datetime.datetime(2023, 11, 13, 0, 0),
 datetime.datetime(2023, 11, 25, 0, 0),
 datetime.datetime(2023, 12, 7, 0, 0),
 datetime.datetime(2023, 12, 19, 0, 0),
 datetime.datetime(2023, 12, 31, 0, 0),
 datetime.datetime(2024, 1, 12, 0, 0),
 datetime.datetime(2024, 1, 24, 0, 0),
 datetime.datetime(2024, 2, 5, 0, 0),
 datetime.datetime(2024, 2, 17, 0, 0),
 datetime.datetime(2024, 2, 29, 0, 0),
 datetime.datetime(2024, 3, 12, 0, 0),
 datetime.datetime(2024, 3, 24, 0, 0)]

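At a minimum, we want:

1. A post-event image + {pre-event images} AND a pre-event image + {earlier pre-event images}, mirroring the operational inputs.
2. Each of these twice: the first and second acquisitions after the event as post images, AND two pre-event groups, each using a pre-event acquisition as the "post" image and the acquisitions before it as its pre-images.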

In [28]:
# Pre-event acquisitions: strictly before the event date minus the lookback gap.
pre_cutoff_dt = EVENT_DATE - timedelta(days=LOOKBACK_DELTA_DAYS)
pre_dts_all = [acq_dt for acq_dt in rtc_acq_dts if acq_dt < pre_cutoff_dt]
pre_dts_all, len(pre_dts_all)

([datetime.datetime(2023, 11, 1, 0, 0),
  datetime.datetime(2023, 11, 13, 0, 0),
  datetime.datetime(2023, 11, 25, 0, 0),
  datetime.datetime(2023, 12, 7, 0, 0),
  datetime.datetime(2023, 12, 19, 0, 0),
  datetime.datetime(2023, 12, 31, 0, 0),
  datetime.datetime(2024, 1, 12, 0, 0),
  datetime.datetime(2024, 1, 24, 0, 0)],
 8)

In [29]:
# Post-event acquisitions: on or after the event date.
post_dts_all = list(filter(lambda acq_dt: acq_dt >= EVENT_DATE, rtc_acq_dts))
post_dts_all

[datetime.datetime(2024, 2, 5, 0, 0),
 datetime.datetime(2024, 2, 17, 0, 0),
 datetime.datetime(2024, 2, 29, 0, 0),
 datetime.datetime(2024, 3, 12, 0, 0),
 datetime.datetime(2024, 3, 24, 0, 0)]

# Generate Groups

In [30]:
# Pre-event "control" groups: the i-th most recent pre-event acquisition (i = 1, 2, ...)
# plays the role of the post image, and up to MAX_PRE_IMGS acquisitions immediately
# preceding it form its baseline. Groups with fewer than MIN_PRE_IMGS baseline images
# are skipped. Positive indexing keeps the baseline capped at exactly MAX_PRE_IMGS.
n_pre = len(pre_dts_all)
pre_event_group_dts = []
for i in range(1, MAX_PRE_EVENT_GROUPS + 1):
    post_idx = n_pre - i
    if post_idx < 0:
        break  # not enough pre-event acquisitions for further groups
    baseline_dts = pre_dts_all[max(post_idx - MAX_PRE_IMGS, 0):post_idx]
    if len(baseline_dts) >= MIN_PRE_IMGS:
        pre_event_group_dts.append({'pre': baseline_dts,
                                    'post': [pre_dts_all[post_idx]]})
pre_event_group_dts

[{'pre': [datetime.datetime(2023, 11, 1, 0, 0),
   datetime.datetime(2023, 11, 13, 0, 0),
   datetime.datetime(2023, 11, 25, 0, 0),
   datetime.datetime(2023, 12, 7, 0, 0),
   datetime.datetime(2023, 12, 19, 0, 0),
   datetime.datetime(2023, 12, 31, 0, 0),
   datetime.datetime(2024, 1, 12, 0, 0)],
  'post': [datetime.datetime(2024, 1, 24, 0, 0)]},
 {'pre': [datetime.datetime(2023, 11, 1, 0, 0),
   datetime.datetime(2023, 11, 13, 0, 0),
   datetime.datetime(2023, 11, 25, 0, 0),
   datetime.datetime(2023, 12, 7, 0, 0),
   datetime.datetime(2023, 12, 19, 0, 0),
   datetime.datetime(2023, 12, 31, 0, 0)],
  'post': [datetime.datetime(2024, 1, 12, 0, 0)]}]

In [31]:
# Post-event groups: the i-th acquisition on/after EVENT_DATE is the post image. Every group
# shares the same baseline, the (up to) MAX_PRE_IMGS most recent pre-event acquisitions, so
# the baseline never exceeds MAX_PRE_IMGS. The number of groups is capped by the available
# post-event acquisitions to avoid an IndexError.
post_baseline_dts = pre_dts_all[max(len(pre_dts_all) - MAX_PRE_IMGS, 0):]
n_post_groups = min(MAX_POST_EVENT_GROUPS, len(post_dts_all))
post_event_group_dts = [{'pre': list(post_baseline_dts),
                         'post': [post_dts_all[i]]}
                        for i in range(n_post_groups)
                        if len(post_baseline_dts) >= MIN_PRE_IMGS]
post_event_group_dts

[{'pre': [datetime.datetime(2023, 11, 1, 0, 0),
   datetime.datetime(2023, 11, 13, 0, 0),
   datetime.datetime(2023, 11, 25, 0, 0),
   datetime.datetime(2023, 12, 7, 0, 0),
   datetime.datetime(2023, 12, 19, 0, 0),
   datetime.datetime(2023, 12, 31, 0, 0),
   datetime.datetime(2024, 1, 12, 0, 0),
   datetime.datetime(2024, 1, 24, 0, 0)],
  'post': [datetime.datetime(2024, 2, 5, 0, 0)]},
 {'pre': [datetime.datetime(2023, 11, 1, 0, 0),
   datetime.datetime(2023, 11, 13, 0, 0),
   datetime.datetime(2023, 11, 25, 0, 0),
   datetime.datetime(2023, 12, 7, 0, 0),
   datetime.datetime(2023, 12, 19, 0, 0),
   datetime.datetime(2023, 12, 31, 0, 0),
   datetime.datetime(2024, 1, 12, 0, 0),
   datetime.datetime(2024, 1, 24, 0, 0)],
  'post': [datetime.datetime(2024, 2, 17, 0, 0)]}]

In [32]:
def _group_to_indices(group):
    """Replace each list of datetimes in a group dict with its positions in `rtc_acq_dts`."""
    return {key: [ind for ind, acq_dt in enumerate(rtc_acq_dts) if acq_dt in group_dts]
            for key, group_dts in group.items()}


post_event_groups_ind = [_group_to_indices(group) for group in post_event_group_dts]
pre_event_groups_ind = [_group_to_indices(group) for group in pre_event_group_dts]
pre_event_groups_ind

[{'pre': [0, 1, 2, 3, 4, 5, 6], 'post': [7]},
 {'pre': [0, 1, 2, 3, 4, 5], 'post': [6]}]

In [33]:
vv_paths = sorted(ts_dir.glob('*VV.tif'))
vh_paths = sorted(ts_dir.glob('*VH.tif'))


def _acq_date(path):
    """Parse the acquisition date from the second '_'-separated field of a file stem."""
    return datetime.strptime(path.stem.split('_')[1], '%Y-%m-%d')


# The group index lists refer to positions in `rtc_acq_dts`, so both path lists must
# line up one-to-one with those dates. A missing VH file would otherwise silently
# pair images from different dates.
assert [_acq_date(p) for p in vv_paths] == rtc_acq_dts, 'VV files are not aligned with rtc_acq_dts'
assert [_acq_date(p) for p in vh_paths] == rtc_acq_dts, 'VH files are missing or not aligned with rtc_acq_dts'

# Metric setup

In [38]:
def distmetric(pre_vv=None,
               pre_vh=None,
               post_vv=None,
               post_vh=None,
               distmetric_name=None):
    """Compute a disturbance metric from a pre-image stack and a single post image.

    All polarizations are passed in; each metric uses only what it needs.

    Parameters
    ----------
    pre_vv, pre_vh : list of arrays
        Pre-image stacks, one array per acquisition (as built in `distmetric_serialize`).
    post_vv, post_vh : array
        The single post image for each polarization.
    distmetric_name : str, optional
        Metric to compute. Defaults to the notebook-level ``DISTMETRIC_NAME``.

    Returns
    -------
    The ``.dist`` of the ``distmetrics`` result object. For ``'mahalanobis_1d_max'`` this is
    the elementwise maximum of the VV and VH distances.

    Raises
    ------
    ValueError
        If ``'transformer'`` is requested with fewer than 2 pre images in either polarization.
    NotImplementedError
        If the metric name is not recognized.
    """
    metric_name = DISTMETRIC_NAME if distmetric_name is None else distmetric_name

    if metric_name == 'transformer':
        # Count acquisitions per polarization. Taking len() of each array would count
        # image rows instead, so that check would essentially never fire.
        if len(pre_vv) < 2 or len(pre_vh) < 2:
            raise ValueError('pre images must be at least 2')
        # NOTE(review): the model is reloaded on every call; consider caching it if this is slow.
        model = load_trained_transformer_model()
        dist_ob = compute_transformer_zscore(model,
                                             pre_vv,
                                             pre_vh,
                                             post_vv,
                                             post_vh,
                                             stride=2)
        distance = dist_ob.dist
    elif metric_name == 'mahalanobis_2d':
        dist_ob = compute_mahalonobis_dist_2d(pre_vv,
                                              pre_vh,
                                              post_vv,
                                              post_vh,
                                              eig_lb=.001 * np.sqrt(2),
                                              window_size=3,
                                              logit_transformed=True)
        distance = dist_ob.dist
    elif metric_name == 'mahalanobis_vh':
        dist_ob = compute_mahalonobis_dist_1d(pre_vh,
                                              post_vh,
                                              sigma_lb=.005,
                                              window_size=3,
                                              logit_transformed=True)
        distance = dist_ob.dist
    elif metric_name == 'log_ratio_vh':
        dist_ob = compute_log_ratio_decrease_metric(pre_vh,
                                                    post_vh,
                                                    window_size=1,
                                                    qual_stat_for_pre_imgs='median')
        distance = dist_ob.dist
    elif metric_name == 'mahalanobis_1d_max':
        # Per-polarization 1D distances (VV uses a larger sigma lower bound), combined by max.
        dist_ob_vh = compute_mahalonobis_dist_1d(pre_vh,
                                                 post_vh,
                                                 sigma_lb=.005,
                                                 window_size=3,
                                                 logit_transformed=True)
        dist_ob_vv = compute_mahalonobis_dist_1d(pre_vv,
                                                 post_vv,
                                                 sigma_lb=.01,
                                                 window_size=3,
                                                 logit_transformed=True)
        distance = np.maximum(dist_ob_vv.dist, dist_ob_vh.dist)
    else:
        raise NotImplementedError(f'Unknown distmetric: {metric_name!r}')
    return distance

# Serialize

In [39]:
# Label outputs with the track directory whose images are actually read (`ts_dir`,
# i.e. `track{tracks[TRACK_IDX]}`). The YAML's `TRACKS` list need not be in the same
# order as the sorted on-disk track list.
site_metric_dir = Path('out_metrics') / EVENT_NAME / ts_dir.name / DISTMETRIC_NAME
site_metric_dir.mkdir(exist_ok=True, parents=True)

In [42]:
with rasterio.open(WATER_MASK_DIR / 'water_mask.tif') as ds:
    X_water = ds.read(1).astype(bool)


def open_one_arr(path):
    """Read band 1 of the raster at `path` and set water pixels (where `X_water` is True) to NaN.

    Non-float rasters are promoted to float32 first, because NaN cannot be stored in
    an integer array.

    Raises
    ------
    ValueError
        If the raster's shape differs from the water mask's shape.
    """
    with rasterio.open(path) as ds:
        X = ds.read(1)
    if not np.issubdtype(X.dtype, np.floating):
        X = X.astype(np.float32)
    if X.shape != X_water.shape:
        raise ValueError(f'{path} has shape {X.shape}, but the water mask has shape {X_water.shape}')
    X[X_water] = np.nan
    return X

def datetime_serializer(obj):
    """`default=` hook for `json.dump` that renders datetimes as ISO 8601 strings.

    Any other object is rejected with a TypeError, which is what `json` expects from the hook.
    """
    if not isinstance(obj, datetime):
        raise TypeError("Type not serializable")
    return obj.isoformat()

def distmetric_serialize():
    """Compute the disturbance metric for every post- and pre-event group and write the results.

    For each group, the metric is computed from the group's pre-image stack and its single
    post image, then written to ``site_metric_dir`` as
    ``{post|pre}_{YYYY-MM-DD}_{DISTMETRIC_NAME}_{track}.tif``, where the date is the group's
    post image date. The function also writes ``metric_params.json`` (grouping parameters)
    and ``dates_{track}.json`` (the datetimes used by each group).

    Returns
    -------
    list of Path
        Output GeoTIFF paths, post-event groups first, then pre-event groups.
    """
    indices = post_event_groups_ind + pre_event_groups_ind
    dts = post_event_group_dts + pre_event_group_dts
    event_tokens = ['post'] * len(post_event_groups_ind) + ['pre'] * len(pre_event_groups_ind)
    paths = [{'pre_vv': sorted(vv_paths[i] for i in d['pre']),
              'pre_vh': sorted(vh_paths[i] for i in d['pre']),
              'post_vv': [vv_paths[i] for i in d['post']],
              'post_vh': [vh_paths[i] for i in d['post']]}
             for d in indices]
    # Label outputs with the track directory whose images are actually read.
    track_token = ts_dir.name

    def distmetric_wrapper(zipped_inputs):
        """Compute and write the metric GeoTIFF for one (paths, dates, event token) group."""
        path_dict, date_dict, event_token = zipped_inputs
        array_inputs = {key: [open_one_arr(p) for p in group_paths]
                        for key, group_paths in path_dict.items()}
        distance = distmetric(array_inputs['pre_vv'],
                              array_inputs['pre_vh'],
                              array_inputs['post_vv'][0],
                              array_inputs['post_vh'][0])
        date_token = date_dict['post'][0].strftime('%Y-%m-%d')
        out_path = site_metric_dir / f'{event_token}_{date_token}_{DISTMETRIC_NAME}_{track_token}.tif'
        # Georeference the output with this group's own post image.
        with rasterio.open(path_dict['post_vv'][0]) as ds:
            profile = ds.profile
        # The metric is a single continuous band, whatever the input raster's dtype/band count.
        profile.update(count=1, dtype='float32')
        with rasterio.open(out_path, 'w', **profile) as ds:
            ds.write(np.asarray(distance, dtype=np.float32), 1)
        return out_path

    inputs = list(zip(paths, dts, event_tokens))
    out_paths = [distmetric_wrapper(group_inputs) for group_inputs in tqdm(inputs)]

    # Global params (each key must be distinct, otherwise a value is silently overwritten)
    params = {'lookback_delta_days': LOOKBACK_DELTA_DAYS,
              'max_pre_event_groups': MAX_PRE_EVENT_GROUPS,
              'max_post_event_groups': MAX_POST_EVENT_GROUPS,
              'max_pre_imgs': MAX_PRE_IMGS,
              'min_pre_imgs': MIN_PRE_IMGS,
              'distmetric_name': DISTMETRIC_NAME}
    with open(site_metric_dir / 'metric_params.json', 'w') as json_file:
        json.dump(params, json_file, indent=4)

    # Date inputs, keyed like the output file names: '{post|pre}_{YYYY-MM-DD}'
    data = {f'{t}_{dt_data["post"][0].strftime("%Y-%m-%d")}': dt_data
            for t, dt_data in zip(event_tokens, dts)}
    with open(site_metric_dir / f'dates_{track_token}.json', 'w') as json_file:
        json.dump(data, json_file, default=datetime_serializer, indent=4)
    return out_paths

In [43]:
# Compute, write and list all metric GeoTIFFs for this event/track/metric.
out_paths = distmetric_serialize()
out_paths

  0%|                                                    | 0/4 [00:00<?, ?it/s]
Rows Traversed:   0%|                                   | 0/43 [00:00<?, ?it/s][A
Rows Traversed:   2%|▋                          | 1/43 [00:00<00:05,  7.62it/s][A
Rows Traversed:   5%|█▎                         | 2/43 [00:00<00:05,  7.27it/s][A
Rows Traversed:   7%|█▉                         | 3/43 [00:00<00:05,  7.05it/s][A
Rows Traversed:   9%|██▌                        | 4/43 [00:00<00:05,  6.96it/s][A
Rows Traversed:  12%|███▏                       | 5/43 [00:00<00:05,  6.89it/s][A
Rows Traversed:  14%|███▊                       | 6/43 [00:00<00:05,  6.80it/s][A
Rows Traversed:  16%|████▍                      | 7/43 [00:01<00:05,  6.76it/s][A
Rows Traversed:  19%|█████                      | 8/43 [00:01<00:05,  6.54it/s][A
Rows Traversed:  21%|█████▋                     | 9/43 [00:01<00:05,  6.44it/s][A
Rows Traversed:  23%|██████                    | 10/43 [00:01<00:05,  6.22it/s][A
Rows Tr

[PosixPath('out_metrics/chile_fire_2024/track18/transformer/post_2024-02-05_transformer_track18.tif'),
 PosixPath('out_metrics/chile_fire_2024/track18/transformer/post_2024-02-17_transformer_track18.tif'),
 PosixPath('out_metrics/chile_fire_2024/track18/transformer/pre_2024-01-24_transformer_track18.tif'),
 PosixPath('out_metrics/chile_fire_2024/track18/transformer/pre_2024-01-12_transformer_track18.tif')]