In [2]:
import os
import pandas as pd
import numpy as np
import json
from osgeo import gdal, osr
import warnings
from datetime import datetime
from multiprocessing import Pool
from joblib import Parallel, delayed
from functools import partial
from datetime import datetime
import matplotlib.pyplot as plt
import psutil
from PIL import Image
# Suppress every Python warning for the rest of the kernel session.
# NOTE(review): a blanket "ignore" also hides any warnings emitted by later
# cells (e.g. from NumPy or GDAL calls); consider narrowing it to specific
# categories or modules.
warnings.filterwarnings("ignore")

def read_tif(tif_file):
    """Read a raster into a (rows, cols, bands) array plus its georeferencing.

    Parameters
    ----------
    tif_file : str
        Path to a raster readable by GDAL.

    Returns
    -------
    tuple
        ``(im_data, im_Geotrans, im_proj, rows, cols)``:
        ``im_data`` has the band axis last (a single-band raster is returned
        as shape (rows, cols, 1)); ``im_Geotrans`` is the geotransform from
        ``GetGeoTransform()``; ``im_proj`` is the projection string from
        ``GetProjection()``; ``rows``/``cols`` are the raster dimensions.

    Raises
    ------
    IOError
        If GDAL cannot open ``tif_file``.
    """
    dataset = gdal.Open(tif_file)
    # Without gdal.UseExceptions(), Open() signals failure by returning None.
    if dataset is None:
        raise IOError(f"GDAL could not open raster: {tif_file}")
    try:
        cols = dataset.RasterXSize
        rows = dataset.RasterYSize
        im_proj = dataset.GetProjection()
        im_Geotrans = dataset.GetGeoTransform()
        raw = dataset.ReadAsArray(0, 0, cols, rows)
    finally:
        # Dropping the last reference closes the GDAL dataset.
        dataset = None

    if raw.ndim == 2:
        # Single-band rasters come back as (rows, cols); moving axis 0 to the
        # end would transpose them, so add an explicit band axis instead.
        im_data = raw[:, :, np.newaxis]
    else:
        # Multi-band rasters come back band-first: (bands, rows, cols).
        im_data = np.moveaxis(raw, 0, -1)
    return im_data, im_Geotrans, im_proj, rows, cols

def array_to_geotiff(array, output_path, geo_transform, projection, band_names=None):
    """Write an array to a Float32 GeoTIFF, one raster band per slice of the last axis.

    Parameters
    ----------
    array : numpy.ndarray
        Shape (rows, cols, bands). A 2-D (rows, cols) array is written as a
        single band.
    output_path : str
        Destination file path (overwritten if it exists).
    geo_transform : sequence
        Geotransform passed to ``SetGeoTransform`` (e.g. as returned by read_tif).
    projection : str
        Projection string passed to ``SetProjection``.
    band_names : sequence of str, optional
        Per-band descriptions; when given it must have at least one entry per band.

    Raises
    ------
    ValueError
        If ``array`` is not 2-D or 3-D, or ``band_names`` is shorter than the
        number of bands.
    IOError
        If the GTiff driver cannot create ``output_path``.
    """
    data = np.asarray(array)
    if data.ndim == 2:
        data = data[:, :, np.newaxis]
    if data.ndim != 3:
        raise ValueError(f"expected a 2-D or 3-D array, got shape {data.shape}")
    rows, cols, num_bands = data.shape
    # Validate up front so a short name list cannot leave a half-written file.
    if band_names and len(band_names) < num_bands:
        raise ValueError(
            f"band_names has {len(band_names)} entries but array has {num_bands} bands"
        )

    driver = gdal.GetDriverByName('GTiff')
    dataset = driver.Create(output_path, cols, rows, num_bands, gdal.GDT_Float32)
    # Without gdal.UseExceptions(), Create() signals failure by returning None.
    if dataset is None:
        raise IOError(f"GDAL could not create raster: {output_path}")
    try:
        dataset.SetGeoTransform(geo_transform)
        dataset.SetProjection(projection)

        for band_index in range(num_bands):
            band = dataset.GetRasterBand(band_index + 1)  # GDAL bands are 1-based
            band.WriteArray(data[:, :, band_index])
            if band_names:
                band.SetDescription(band_names[band_index])
            band.FlushCache()
            # Release the band before the dataset that owns it.
            band = None

        dataset.FlushCache()
    finally:
        # Dropping the last reference closes the dataset and finalises the file.
        dataset = None
    return
def set_nan_around_nan(arr, radius=3):
    """Grow the NaN regions of a 2-D array by ``radius`` pixels, in place.

    Every element inside a (2*radius+1) x (2*radius+1) square window centred
    on a NaN that was present on entry (window clipped at the array edges) is
    set to NaN. NaNs created by this call do not propagate further.

    Parameters
    ----------
    arr : numpy.ndarray
        2-D float array. It is modified in place.
    radius : int, default 3
        Half-width of the square window in pixels. ``radius <= 0`` leaves
        ``arr`` unchanged.

    Returns
    -------
    numpy.ndarray
        ``arr`` itself (the same object that was passed in).
    """
    nan_mask = np.isnan(arr)
    rows, cols = nan_mask.shape
    if radius <= 0 or not nan_mask.any():
        return arr

    window = 2 * radius + 1
    # Pad with False so windows that run off the edge are effectively clipped.
    padded = np.pad(nan_mask, radius, mode="constant", constant_values=False)

    # A square window is separable: OR along rows first, then along columns.
    # This replaces the per-NaN Python loop (very slow when most pixels are
    # NaN) with 2 * window whole-array boolean operations.
    grown_rows = np.zeros((rows, cols + 2 * radius), dtype=bool)
    for offset in range(window):
        grown_rows |= padded[offset:offset + rows, :]
    grown = np.zeros((rows, cols), dtype=bool)
    for offset in range(window):
        grown |= grown_rows[:, offset:offset + cols]

    arr[grown] = np.nan
    return arr

In [3]:
prisma_path = "/Users/fji/Desktop/data/9_PRISMA_imagery_smoothed_tif/"
# NOTE(review): "could" looks like a typo for "cloud", but this must match the
# directory name on disk, so it is left as-is.
mask_path = "/Users/fji/Desktop/data/11_PRISMA_could_mask/"

folders = ['D01_BART','D01_HARV','D02_SCBI','D03_OSBS','D07_MLBS','D07_ORNL','D08_TALL',
           'D10_CPER','D13_MOAB','D14_JORN','D14_SRER','D16_WREF','D19_BONA','D19_HEAL']

# Spectral test: a pixel is kept (mask = 1) only when the mean over LOW_BANDS is
# below LOW_BANDS_MAX_MEAN AND the mean over HIGH_BANDS is above HIGH_BANDS_MIN_MEAN.
# NOTE(review): which wavelengths these band indices map to depends on the band
# order of the *_FULL.tif stacks -- confirm before tuning.
LOW_BANDS = slice(0, 20)
LOW_BANDS_MAX_MEAN = 0.15
HIGH_BANDS = slice(55, 75)
HIGH_BANDS_MIN_MEAN = 0.2
# Pixels within this many pixels of a rejected pixel are rejected as well.
NAN_BUFFER_RADIUS = 3


def build_mask(im_data):
    """Return a (rows, cols, 1) float mask for a (rows, cols, bands) image.

    Pixels passing both spectral tests are 1, all others NaN; the NaN region
    is then grown by NAN_BUFFER_RADIUS pixels via set_nan_around_nan.
    """
    keep = (
        (im_data[:, :, LOW_BANDS].mean(axis=2) < LOW_BANDS_MAX_MEAN)
        & (im_data[:, :, HIGH_BANDS].mean(axis=2) > HIGH_BANDS_MIN_MEAN)
    )
    mask = np.where(keep, 1.0, np.nan)
    mask = set_nan_around_nan(mask, radius=NAN_BUFFER_RADIUS)
    return mask[:, :, np.newaxis]


for folder in folders:
    os.makedirs(os.path.join(mask_path, folder), exist_ok=True)

for folder in folders:
    print(folder)
    prisma_tif_folder = os.path.join(prisma_path, folder)
    mask_tif_path = os.path.join(mask_path, folder)

    # Only the *_FULL.tif stacks, skipping macOS "._" resource-fork files.
    # GDAL ".aux.xml" sidecars are excluded by the endswith test. Sorted so
    # the processing order is deterministic.
    file_name = sorted(
        x for x in os.listdir(prisma_tif_folder)
        if x.endswith("_FULL.tif") and not x.startswith("._")
    )

    for file in file_name:
        # splitext strips only ".tif"; split(".")[0] truncated names at the first dot.
        basename = os.path.splitext(file)[0]
        prisma_tif = os.path.join(prisma_tif_folder, file)
        im_data, im_Geotrans, im_proj, rows, cols = read_tif(prisma_tif)

        mask = build_mask(im_data)

        output_path = os.path.join(mask_tif_path, f"{basename}_mask.tif")
        array_to_geotiff(mask, output_path, im_Geotrans, im_proj, band_names=None)

D01_BART
D01_HARV
D02_SCBI
D03_OSBS
D07_MLBS
D07_ORNL
D08_TALL
D10_CPER
D13_MOAB
D14_JORN
D14_SRER
D16_WREF
D19_BONA
D19_HEAL
