We now have trained models that produce segmentation maps of open surface water from SAR data alone, or from SAR combined with optical data. Let's see how these models perform on a previously unseen scene.

In [None]:
import os
from collections import defaultdict
from pathlib import Path

# CUDA_VISIBLE_DEVICES is only honoured if it is set before CUDA is initialised, so set it
# before importing torch or the project's model module (either may touch CUDA on import).
os.environ["CUDA_VISIBLE_DEVICES"] = "0" # Force kernel to use this GPU

import numpy as np
import pandas as pd
import rasterio
from rasterio.merge import merge
from rasterio.warp import transform_bounds
import torch

from tools import retrieve_hansen_mosaic, return_windowed_merge, denoise, return_slice_list, get_cropped_profile, return_nodata_mask, retrieve_hand_data
from model import sarDataLoader, sarInferenceModel

In [None]:
# Directory holding the HH, HV and DEM rasters of the scene to run inference on
scene_path = Path('../data') / 'inference_scenes' / 'AP_13874_FBD_F0580_RT1'

For a newly obtained scene, we still need to download the corresponding Global Forest Watch (Hansen) and HAND data. We then split all of these layers into 512x512-pixel chips.

In [None]:
# Chips live under <scene>/chips/<layer>/ ; chips of the same location share one
# file name across all layer folders.
output_path = scene_path / 'chips'
output_path.mkdir(parents=True, exist_ok=True)

chip_types = ['hh', 'hv', 'red', 'nir', 'swir1', 'swir2', 'dem', 'hand']
chip_path_dict = {layer: output_path / layer for layer in chip_types}
chip_paths = list(chip_path_dict.values())
for layer_dir in chip_paths:
    layer_dir.mkdir(exist_ok=True)

In [None]:
# Set to False to reuse the chips listed in chips.csv by a previous run. This skips the
# Hansen/HAND downloads, the nodata masking and the chip writing below.
remake_chips = True

chips_csv_path = output_path / 'chips.csv'

# Chip geometry. The stride of 128 (a quarter of the chip size) makes neighbouring chips
# overlap, so that bad inferences near chip edges can be minimized.
chip_shape = (512, 512)
chip_stride = 128


def _find_scene_file(scene_dir, pattern):
    """Return the alphabetically first file in `scene_dir` whose name matches `pattern`.

    Sorting makes the choice deterministic when several files match (raw glob order is
    filesystem-dependent). A missing file raises a descriptive FileNotFoundError instead
    of a bare IndexError.
    """
    matches = sorted(scene_dir.glob(pattern))
    if not matches:
        raise FileNotFoundError(f"No file matching {pattern!r} found in {scene_dir}")
    return matches[0]


def _load_scene_layers(scene_dir):
    """Read the scene's SAR and DEM rasters and fetch the overlapping Hansen and HAND data.

    Returns
    -------
    image_dict : dict
        Maps each chip type to an (array, nodata, dtype) tuple used when writing chips.
        Every array is zeroed wherever HH/HV is 0 or the Hansen red band is nodata.
    sar_profile : rasterio profile of the HH raster; all chips are cut on its grid.
    """
    with rasterio.open(_find_scene_file(scene_dir, '*HH*')) as ds:
        hh_img = ds.read(1)
        sar_bounds = ds.bounds
        sar_profile = ds.profile
        sar_crs = ds.crs

    with rasterio.open(_find_scene_file(scene_dir, '*HV*')) as ds:
        hv_img = ds.read(1)

    with rasterio.open(_find_scene_file(scene_dir, '*dem*')) as ds:
        dem_img = ds.read(1)
        dem_profile = ds.profile

    # Hansen and HAND tiles are requested with the scene footprint in EPSG:4326 and then
    # merged/windowed onto the SAR grid
    sar_bounds_4326 = transform_bounds(sar_crs.to_epsg(), 4326, *sar_bounds)
    hansen_files = retrieve_hansen_mosaic(sar_bounds_4326, data_product='first', download_path=Path('../data/hansen_mosaics/'))
    hansen_img, hansen_profile = return_windowed_merge(hansen_files, sar_bounds_4326, sar_profile)

    hand_files = retrieve_hand_data(sar_bounds_4326, download_path=Path('../data/hand_data/'))
    hand_img, hand_profile = return_windowed_merge(hand_files, sar_bounds_4326, sar_profile)
    hand_img = np.squeeze(hand_img)

    # Combine the nodata masks, then invert them: 1 = valid pixel, 0 = nodata
    nodata_mask = return_nodata_mask([hh_img, hv_img], nodata=0)
    nodata_mask += return_nodata_mask([hansen_img[0]], hansen_profile['nodata'])
    valid_mask = np.where(nodata_mask > 0, 0, 1).astype('uint8')

    # Zero out nodata regions in place (valid_mask broadcasts over the Hansen band axis)
    for layer in (hh_img, hv_img, hansen_img, dem_img, hand_img):
        layer *= valid_mask

    image_dict = {
        'hh': (hh_img, sar_profile['nodata'], sar_profile['dtype']),
        'hv': (hv_img, sar_profile['nodata'], sar_profile['dtype']),
        'red': (hansen_img[0, ...], 0, 'int16'),
        'nir': (hansen_img[1, ...], 0, 'int16'),
        'swir1': (hansen_img[2, ...], 0, 'int16'),
        'swir2': (hansen_img[3, ...], 0, 'int16'),
        'dem': (dem_img, dem_profile['nodata'], dem_img.dtype),
        'hand': (hand_img, hand_profile['nodata'], hand_img.dtype),
    }
    return image_dict, sar_profile


def _write_chips(image_dict, sar_profile, chip_dirs, chip_prefix):
    """Cut every layer into overlapping chips and write each chip as a single-band GeoTIFF.

    Chip `k` of a layer is written to `<chip_dirs[layer]>/<chip_prefix>_<k:05d>.tif`.
    Returns a DataFrame with one column per chip type and one row per chip location.
    """
    slice_list = return_slice_list(image_dict['hh'][0].shape, chip_shape, x_stride=chip_stride, y_stride=chip_stride)
    chips_list = defaultdict(list)

    for count, (y_slice, x_slice) in enumerate(slice_list):
        current_filename = f"{chip_prefix}_{str(count).zfill(5)}.tif"
        chip_profile = get_cropped_profile(sar_profile, x_slice, y_slice)

        for chip_type, chip_dir in chip_dirs.items():
            img, nodata, dtype = image_dict[chip_type]
            chip_profile['nodata'] = nodata
            chip_profile['dtype'] = dtype
            chip = img[y_slice, x_slice]
            chip_file = chip_dir / current_filename
            with rasterio.open(chip_file, 'w', **chip_profile) as ds:
                ds.write(chip.reshape(1, *chip.shape))
            chips_list[chip_type].append(chip_file)

    return pd.DataFrame(chips_list)


if remake_chips:
    image_dict, sar_profile = _load_scene_layers(scene_path)
    # e.g. 'AP_13874_FBD_F0580_RT1' -> 'AP_138740580'
    scene_name = scene_path.name
    chip_prefix = f"AP_{scene_name[3:8]}{scene_name[14:18]}"
    df = _write_chips(image_dict, sar_profile, chip_path_dict, chip_prefix)
    df.to_csv(chips_csv_path)
    print(f"Number of chips: {len(df)}")
else:
    # to_csv above also writes the index, so read it back as the index. Otherwise it becomes
    # an extra 'Unnamed: 0' column that is handed to the data loader with the chip paths.
    df = pd.read_csv(chips_csv_path, index_col=0)

## Generate inferences

### SAR-only model

In [None]:
model_params = {
    "sarData":True,
    "denoisingWeight":0.35,
    "opticalData":False,    
    "output_classes":2,
    "backbone" : "resnet50",
    "gpu": True,
    "ngpus":1,
    "experiment_name" : "sar_only_model"
}

# Each key appears once: a duplicated key in a dict literal silently keeps only the last value
dataloader_params = {
    "return_sar":True,
    "return_optical":False,
    "denoising_weight":.35,
    "return_dem":True,
    "return_hand":True
}

model = sarInferenceModel(model_params)
state_dict = torch.load("model-weights/sar_only_model.pt")
# strict=False tolerates checkpoint/model key mismatches, but any mismatch means those weights
# were NOT loaded. Report them instead of silently running a partly untrained model.
# (torch.nn.Module.load_state_dict returns missing/unexpected keys; getattr keeps this safe
# in case sarInferenceModel overrides it with a different return value.)
load_result = model.load_state_dict(state_dict, strict=False)
missing_keys = list(getattr(load_result, "missing_keys", []))
unexpected_keys = list(getattr(load_result, "unexpected_keys", []))
if missing_keys or unexpected_keys:
    print(f"WARNING: weights only partially loaded ({len(missing_keys)} missing, {len(unexpected_keys)} unexpected keys)")
    print("  missing (first 10):", missing_keys[:10])
    print("  unexpected (first 10):", unexpected_keys[:10])
model.model.cuda() # Load model into the GPU so that it can do batch processing of chips efficiently
model.model.eval() # Put model in evaluation mode

In [None]:
dataset = sarDataLoader(x_paths=df, y_paths=None, **dataloader_params)

inference_path = scene_path / 'inferences'
inference_path.mkdir(exist_ok=True)

exp_name = 'sar_only_inferences'

output_path = inference_path / exp_name
output_path.mkdir(exist_ok=True)

batch_size = 12
current_inferences = []

# Predict one batch of chips at a time and write each chip's class map (argmax over the class
# axis) to its own GeoTIFF, named by the chip's row index in df. no_grad() skips autograd
# bookkeeping, which inference never needs and which otherwise holds activation memory on the GPU.
with torch.no_grad():
    for batch_start in range(0, len(df), batch_size):
        batch_idx = range(batch_start, min(batch_start + batch_size, len(df)))
        batch = [dataset.__getitem__(i, inference=True) for i in batch_idx]

        img_batch = torch.Tensor(np.stack([item[0] for item in batch], axis=0)).cuda(non_blocking=True)
        profile_batch = [item[2] for item in batch]
        inferences = model.forward(img_batch).detach().cpu().numpy()

        for chip_idx, scores, chip_profile in zip(batch_idx, inferences, profile_batch):
            inference = np.argmax(np.squeeze(scores), axis=0)
            inference_filename = output_path/f'inference_{str(chip_idx).zfill(5)}.tiff'

            # Class maps are written as int8, with -1 reserved for nodata
            chip_profile['nodata'] = -1
            chip_profile['dtype'] = 'int8'

            with rasterio.open(inference_filename, 'w', **chip_profile) as ds:
                ds.write(inference.reshape(1, *inference.shape).astype('int8'))
            current_inferences.append(inference_filename)

with rasterio.open(sorted(scene_path.glob("*HH*"))[0]) as ds:
    sar_profile = ds.profile
    sar_bounds = ds.bounds
    nodata_mask = ds.read() == ds.profile['nodata']

# Mosaic the overlapping chips onto the full SAR extent, so the result has the same shape and
# grid as the HH raster (and nodata_mask). With method='last', a pixel covered by several chips
# takes the value of the chip merged last.
merged_inference, out_trans = merge(current_inferences, bounds=sar_bounds, method='last')

merged_inference[nodata_mask] = -1
# Copy rather than alias the SAR profile so it is not mutated
inference_profile = dict(sar_profile, nodata=-1, dtype='int8')
with rasterio.open(output_path/f'{exp_name}_merged_inferences.tif', 'w', **inference_profile) as ds:
    ds.write(merged_inference.astype('int8'))

### SAR and optical inferences

In [None]:
model_params = {
    "sarData":True,
    "denoisingWeight":0.35,
    "opticalData":True,    
    "output_classes":2,
    "backbone" : "resnet50",
    "gpu": True,
    "ngpus":1,
    "experiment_name" : "sar_and_optical_model"
}

dataloader_params = {
    "return_sar":True,
    "return_optical":True,
    "denoising_weight":.35,
    "return_dem":True,
    "return_hand":True
}

model = sarInferenceModel(model_params)
state_dict = torch.load("model-weights/sar_and_optical_model.pt")
# strict=False tolerates checkpoint/model key mismatches, but any mismatch means those weights
# were NOT loaded. Report them instead of silently running a partly untrained model.
# (torch.nn.Module.load_state_dict returns missing/unexpected keys; getattr keeps this safe
# in case sarInferenceModel overrides it with a different return value.)
load_result = model.load_state_dict(state_dict, strict=False)
missing_keys = list(getattr(load_result, "missing_keys", []))
unexpected_keys = list(getattr(load_result, "unexpected_keys", []))
if missing_keys or unexpected_keys:
    print(f"WARNING: weights only partially loaded ({len(missing_keys)} missing, {len(unexpected_keys)} unexpected keys)")
    print("  missing (first 10):", missing_keys[:10])
    print("  unexpected (first 10):", unexpected_keys[:10])
model.model.cuda() # Load model into the GPU so that it can do batch processing of chips efficiently
model.model.eval() # Put model in evaluation mode

In [None]:
dataset = sarDataLoader(x_paths=df, y_paths=None, **dataloader_params)

inference_path = scene_path / 'inferences'
inference_path.mkdir(exist_ok=True)

exp_name = 'sar_and_optical_inferences'

output_path = inference_path / exp_name
output_path.mkdir(exist_ok=True)

batch_size = 12
current_inferences = []

# Predict one batch of chips at a time and write each chip's class map (argmax over the class
# axis) to its own GeoTIFF, named by the chip's row index in df. no_grad() skips autograd
# bookkeeping, which inference never needs and which otherwise holds activation memory on the GPU.
with torch.no_grad():
    for batch_start in range(0, len(df), batch_size):
        batch_idx = range(batch_start, min(batch_start + batch_size, len(df)))
        batch = [dataset.__getitem__(i, inference=True) for i in batch_idx]

        img_batch = torch.Tensor(np.stack([item[0] for item in batch], axis=0)).cuda(non_blocking=True)
        profile_batch = [item[2] for item in batch]
        inferences = model.forward(img_batch).detach().cpu().numpy()

        for chip_idx, scores, chip_profile in zip(batch_idx, inferences, profile_batch):
            inference = np.argmax(np.squeeze(scores), axis=0)
            inference_filename = output_path/f'inference_{str(chip_idx).zfill(5)}.tiff'

            # Class maps are written as int8, with -1 reserved for nodata
            chip_profile['nodata'] = -1
            chip_profile['dtype'] = 'int8'

            with rasterio.open(inference_filename, 'w', **chip_profile) as ds:
                ds.write(inference.reshape(1, *inference.shape).astype('int8'))
            current_inferences.append(inference_filename)

with rasterio.open(sorted(scene_path.glob("*HH*"))[0]) as ds:
    sar_profile = ds.profile
    sar_bounds = ds.bounds
    nodata_mask = ds.read() == ds.profile['nodata']

# Mosaic the overlapping chips onto the full SAR extent, so the result has the same shape and
# grid as the HH raster (and nodata_mask). With method='last', a pixel covered by several chips
# takes the value of the chip merged last.
merged_inference, out_trans = merge(current_inferences, bounds=sar_bounds, method='last')

merged_inference[nodata_mask] = -1
# Copy rather than alias the SAR profile so it is not mutated
inference_profile = dict(sar_profile, nodata=-1, dtype='int8')
with rasterio.open(output_path/f'{exp_name}_merged_inferences.tif', 'w', **inference_profile) as ds:
    ds.write(merged_inference.astype('int8'))