In [1]:
import os
import pickle
import shutil

from skimage.transform import resize
import rasterio
import numpy as np

In [2]:
# Root folder for this notebook: HPMF tiles are read from its hpmf/ subfolder
# and every output folder (predicted_probs/, predicted_labels_*/) is created
# beneath it.
working_dir = r'\\export.hpc.ut.ee\gis\Ditches\working\deep_learning\data\prediction'
os.chdir(working_dir)

# Display the resulting working directory as a sanity check
os.getcwd()

'\\\\export.hpc.ut.ee\\gis\\Ditches\\working\\deep_learning\\data\\prediction'

In [3]:
%%time

experiment = 'prediction_finetuned_aug_estonia_relu_ks3x3_lr0.0001_250ep_bs4_from_train_aug_sweden_test_unaug_sweden'
experiment_dir = fr'\\export.hpc.ut.ee\gis\Ditches\working\pytorch_unet\experiments\{experiment}'
with open(f'{experiment_dir}/result/prediction_TestSet_20240617_185635.pkl', 'rb') as file:
    predictions = pickle.load(file)

CPU times: total: 234 ms
Wall time: 20.2 s


# Write predicted probabilities to GeoTIFF

In [4]:
# Start from an empty predicted_probs/ folder so tiles left over from an
# earlier run are never mixed with the current predictions.
out_dir = f'{working_dir}/predicted_probs'
already_present = os.path.exists(out_dir)
if already_present:
    shutil.rmtree(out_dir)
os.makedirs(out_dir)

In [5]:
%%time

out_img_size = 500
for prediction in predictions:
    for file, predicted_probs in zip(prediction['file'], prediction['data']):
        map_sheet = file.split('_')[0]
        
        # Resize image
        predicted_probs = resize(predicted_probs, (1, out_img_size, out_img_size), mode="constant", preserve_range=True)
        
        # Read HPMF
        fp_hpmf = f'{working_dir}/hpmf/{file}'
        with rasterio.open(fp_hpmf) as src:
            
            # Output profile
            out_profile = src.profile.copy()
            
            # Write output to GeoTIFF
            out_dir = f'{working_dir}/predicted_probs/{map_sheet}'
            if not os.path.exists(out_dir):
                os.makedirs(out_dir)
            fp_predicted_probs = f'{out_dir}/{file}'
            with rasterio.open(fp_predicted_probs, 'w', **out_profile) as dst:
                dst.write(predicted_probs)

CPU times: total: 39.7 s
Wall time: 2min 37s


# Write predicted labels to GeoTIFF

In [6]:
# Extract labels based on probability threshold
def extract_labels(fp_predicted_probs, fp_predicted_labels, prob_threshold=0.5):
    with rasterio.open(fp_predicted_probs) as src:
        
        # Read predicted probabilities
        predicted_probs = src.read(1)
        
        # Extract predicted labels
        predicted_labels = np.where(predicted_probs > prob_threshold, 1, 0)
        
        # Output profile
        out_profile = src.profile.copy()
        out_profile['dtype'] = 'int32'
        out_profile['nodata'] = -9999
        
        # Write output to GeoTIFF
        with rasterio.open(fp_predicted_labels, 'w', **out_profile) as dst:
            dst.write(predicted_labels, 1)
            
    return

In [7]:
# Remove label folders from earlier runs, one per probability threshold,
# so the next cell starts from a clean slate.
prob_thresholds = [0.5, 0.1]
for prob_threshold in prob_thresholds:
    out_dir = f'{working_dir}/predicted_labels_{prob_threshold}'
    if not os.path.exists(out_dir):
        continue
    shutil.rmtree(out_dir)

In [8]:
%%time

prob_thresholds = [0.5, 0.1]
for prob_threshold in prob_thresholds:
    for map_sheet in os.listdir(f'{working_dir}/predicted_probs'):
        for file in os.listdir(f'{working_dir}/predicted_probs/{map_sheet}'):
            
            # Extract predicted labels
            fp_predicted_probs = f'{working_dir}/predicted_probs/{map_sheet}/{file}'
            out_dir = f'{working_dir}/predicted_labels_{prob_threshold}/{map_sheet}'
            if not os.path.exists(out_dir):
                os.makedirs(out_dir)
            fp_predicted_labels = f'{out_dir}/{file}'
            extract_labels(fp_predicted_probs, fp_predicted_labels, prob_threshold)

CPU times: total: 46.6 s
Wall time: 4min 8s
