
Processing imagery with dask/xarray: example application for identifying outlier imagery #154

Closed
dbuscombe-usgs opened this issue Jun 13, 2023 · 8 comments

@dbuscombe-usgs (Member)

Problem: automatically identify bad satellite imagery

Potential solution?

  1. read ms geotiff file collection from a particular sensor e.g. L8 into xarray
  2. filter out images that are not modal size (array shape)
  3. make a mean RGB image
  4. use metrics to compare all images to the reference image
  5. make a movie of images and metrics to see if the approach will work
  6. use xarray and dask to keep computations efficient and in memory

Imports and Dask cluster:

import numpy as np
import rioxarray
import xarray as xr
from dask.distributed import Client
from glob import glob
import os
from statistics import mode
from tqdm import tqdm

dtype = 'float32'
chunksize = ("auto", "auto")

# start client
n_workers = 8
threads_per_worker = 2
memory_limit='10GB'

client = Client(n_workers=n_workers, threads_per_worker=threads_per_worker, memory_limit=memory_limit)

Read in a folder of files from a particular sensor. We expect every image to be the same size; any image that is not the modal size is discarded.

inpath = '/media/marda/TWOTB1/USGS/Doodleverse/github/CoastSeg/data/ID_spl2_datetime06-13-23__08_06_34/L8/ms/'

files = sorted(glob(inpath+os.sep+'*.tif'))
len(files)

## first, find the shape of every file
shapes = []
for f in files:
    im = rioxarray.open_rasterio(f)
    shapes.append(im.shape)

## get rid of any file that is not the modal shape
valid_files = [f for counter,f in enumerate(files) if shapes[counter]==mode(shapes)]

Make xarray

## get timestamps from files
times = [f.split(os.sep)[-1].split('_')[0] for f in valid_files]

# Create variable used for time axis
time_var = xr.Variable('time',times)

# Load in and concatenate all individual GeoTIFFs
geotiffs_da = xr.concat([rioxarray.open_rasterio(i, chunks=chunksize, dtype=dtype) for i in valid_files],
                        dim=time_var)
# Convert our xarray.DataArray into a xarray.Dataset
geotiffs_ds = geotiffs_da.to_dataset('band')

## rename the bands
geotiffs_ds = geotiffs_ds.rename({1: 'red', 2: 'green', 3: 'blue'})
## drop the 4th and 5th bands
geotiffs_ds = geotiffs_ds.drop_vars([4, 5])

Make a reference time-averaged image, ignoring NaNs

## make reference images
ref = geotiffs_ds.mean("time", skipna=True)
## reproject if necessary
# ref = ref.rio.reproject("epsg:4326")
ref.rio.to_raster(raster_path=f"mean_image.tif")#, dtype=dtype)
# del ref

Cycle through each image and compute RMSE and PSNR metrics. Good images should have low RMSE and high PSNR

## compute per-channel max
MAX = ref.max()
## collate metrics per image
PSNR=[]; RMSE=[]; 
for time in tqdm(times):
    im = geotiffs_ds.sel(time=time)

    ## mse
    mse_value = np.mean((ref.astype(np.float64)-im.astype(np.float64))**2)
    # mean_mse_value = np.mean(np.mean(mse_value).to_array()).to_numpy()
    ## rmse
    rmse_value = np.sqrt(mse_value)
    mean_rmse_value = np.mean(np.mean(rmse_value).to_array()).to_numpy()
    RMSE.append(mean_rmse_value)

    ##psnr 
    psnr = 10 * np.log10(MAX**2 /mse_value)
    mean_psnr_value = np.mean(np.mean(psnr).to_array()).to_numpy()
    PSNR.append(mean_psnr_value)
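
As a purely illustrative follow-on (not part of the original script), the collated metrics could then be used to flag likely outliers; the 1.5x-median cutoff below is an arbitrary placeholder that would need tuning against known good and bad images:

## hypothetical: flag images whose RMSE sits well above the median
rmse_arr = np.array(RMSE)
suspect_times = [t for t, r in zip(times, rmse_arr) if r > 1.5 * np.median(rmse_arr)]
print(f"{len(suspect_times)} suspect images")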

Make an animated gif of all the imagery and values

### make an animation to show values on imagery
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import matplotlib.ticker as ticker

# make a function to call
def make_ani(ds, times, values1, values2):
    fig, ax = plt.subplots(1,1)
    ims = []
    for time,value1,value2 in zip(times, values1, values2):
        frame = ds.sel(time=time).to_array()
        im1 = ax.imshow(frame.T, animated=True)
        ax.xaxis.set_major_locator(ticker.NullLocator())
        ax.yaxis.set_major_locator(ticker.NullLocator())
        t1=ax.text(20,20,"PSNR="+str(value1), color='w', animated=True)
        t2=ax.text(20,30,"RMSE="+str(value2), color='w', animated=True)
        ims.append([im1,t1,t2])

    ani = animation.ArtistAnimation(fig, ims, interval=100, blit=True,
                                    repeat_delay=0)
    return ani


ani = make_ani(geotiffs_ds,times, np.array(PSNR), np.array(RMSE))

ani.save("filter_psnr.gif", writer='imagemagick', fps=1)
del ani

Example (truncated) inputs:
ms.zip

Example output:
filter_psnr

@dbuscombe-usgs added the 'V2' (for version 2 of coastseg) and 'Research' (investigate if something is possible, experiment) labels on Jun 13, 2023
@dbuscombe-usgs self-assigned this on Jun 13, 2023
@2320sharon (Collaborator)

Do you think it would be useful to use the ratio of rmse to psnr to filter out images? For example, if the ratio for some images is much larger than the average, those images get tossed?
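
A minimal sketch of that idea, continuing from the RMSE/PSNR lists in the example above (the 2x-mean cutoff is a hypothetical placeholder, not a tested value):

## hypothetical: toss images whose rmse/psnr ratio is much larger than the average
ratio = np.array(RMSE) / np.array(PSNR)
keep_mask = ratio < 2 * ratio.mean()   # placeholder threshold, would need tuning
kept_files = [f for f, k in zip(valid_files, keep_mask) if k]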

@dbuscombe-usgs (Member, Author)

Possibly. We could continue to explore, for V2...

For example, we should probably work out the details with a set of known good and bad images. I think it would probably work in conjunction with the black pixel filter.

Here's a tiny set I just made

bad.zip
good.zip

@2320sharon (Collaborator)

I think that would be a good idea. Also I liked your idea of "It might also be possible to adapt this idea for outputs ... you can imagine loading all the "predseg" images into an array, making the average, and filtering out bad segmentations using a similar approach ... (for v2 maybe)"

@dbuscombe-usgs (Member, Author)

I spent some time working on this idea of taking all the model outputs and comparing them to the average output, then filtering based on the similarity (or dissimilarity) between the time-varying label and the time-average label.

Different sensors have different sized images, so they have to be processed separately. So my first move is to make separate lists of npz files from the 'out' directory

### get lists of files per sensor
inpath = 'RGB/8190958/'
L5_files = sorted(glob(inpath+os.sep+'*L5*.npz'))
L7_files = sorted(glob(inpath+os.sep+'*L7*.npz'))
L8_files = sorted(glob(inpath+os.sep+'*L8*.npz'))
S2_files = sorted(glob(inpath+os.sep+'*S2*.npz'))

Then I filter out any images that deviate from the modal shape, like in the earlier example above

def return_valid_files(files):
    shapes = get_image_shapes(files)
    ## get rid of any file that is not modal shape
    valid_files = [f for counter,f in enumerate(files) if shapes[counter]==mode(shapes)]
    return valid_files

valid_L5_files = return_valid_files(L5_files)
(etc)
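
(get_image_shapes is not shown in the excerpt; a minimal sketch, assuming each npz stores the label raster as one of its arrays, might look like this)

import numpy as np

def get_image_shapes(files):
    ## hypothetical helper: read the shape of the first array stored in each npz
    shapes = []
    for f in files:
        with np.load(f) as data:
            shapes.append(data[list(data.keys())[0]].shape)
    return shapes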

Next we get time vectors and make xarrays out of the labels, as well as a time-average of the label timestack

def get_time_vectors(files):
    times = [f.split(os.sep)[-1].split('_')[0] for f in files]
    time_var = xr.Variable('time',times)
    return times, time_var 

S2_times, S2_time_var  = get_time_vectors(valid_S2_files)
(etc)

def return_timeav(valid_files,time_var):
    da = xr.concat([load_xarray_data(i) for i in valid_files],dim=time_var)
    timeav = da.mean("time", skipna=True)
    return timeav, da

S2_timeav, S2_da = return_timeav(valid_S2_files,S2_time_var)
(etc)
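
(load_xarray_data is also not shown; a minimal sketch, assuming a 2-D label raster stored under a key such as 'grey_label', which may differ from the actual key)

import numpy as np
import xarray as xr

def load_xarray_data(f):
    ## hypothetical helper: wrap the stored label raster in a DataArray
    with np.load(f) as data:
        key = 'grey_label' if 'grey_label' in data else list(data.keys())[0]
        arr = data[key].astype('float32')
    return xr.DataArray(arr, dims=('y', 'x'))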

The process of determining 'good' and 'bad' relies on distance to the average label. The metric I chose is RMSE (others could be explored). Once I have that measured for all images, I use K-means clustering to split them into 2 classes, i.e. good and bad. Bad has a large RMSE. I know which cluster is which because 'bad' files will have an average RMSE that is larger than 'good' files.

import numpy as np
from sklearn.cluster import KMeans

def measure_rmse(da, times, timeav):
    rmse=[]
    for t in times:
        dat = da.sel(time=t)
        rmse.append(float(np.sqrt(np.mean((dat - timeav)**2)).to_numpy()))
    input_rmse = np.array(rmse).reshape(-1, 1)
    return rmse, input_rmse

def get_kmeans_clusters(input_rmse, rmse):
    kmeans = KMeans(n_clusters=2, random_state=0, n_init="auto").fit(input_rmse)
    labels = kmeans.labels_
    scores = [np.mean(np.array(rmse)[labels==0]), np.mean(np.array(rmse)[labels==1])]
    return labels, scores

def get_good_bad_files(files,labels,scores, str_label):
    files_bad = np.array(files)[labels==np.argmax(scores)]
    files_good = np.array(files)[labels==np.argmin(scores)]
    print(f"{len(files_good)} good {str_label} labels")
    print(f"{len(files_bad)} bad {str_label} labels")
    return files_bad, files_good

rmse, input_rmse = measure_rmse(S2_da, S2_times, S2_timeav)
labels, scores = get_kmeans_clusters(input_rmse, rmse)
S2_files_bad, S2_files_good = get_good_bad_files(valid_S2_files,labels,scores,str_label='S2')

(etc)

Finally, I copy the good and bad files to new directories. Both the 'predseg' png and the npz file are copied.

import shutil

def copy_files(files, outdir):
    for f in files:
        shutil.copyfile(f,outdir+os.sep+f.split(os.sep)[-1])
        shutil.copyfile(f.replace('res.npz','predseg.png'),outdir+os.sep+f.split(os.sep)[-1].replace('res.npz','predseg.png'))

os.mkdir('L5_bad')
copy_files(L5_files_bad, 'L5_bad')
(etc)

os.mkdir('S2_good')
copy_files(S2_files_good, 'S2_good')
(etc)

I think it works pretty well! Here are the "good" L5, L7, L8, and S2 images at a site

L5_good
L7_good
L8_good
S2_good

Compare with the "bad" images ... I hope you agree, this seems like an effective approach
L5_bad
L7_bad
L8_bad
S2_bad

@dbuscombe-usgs (Member, Author)

here is my script
filter_good_labels.zip

I can now work on a combined label filter and shoreline detection (from #168) workflow

@dbuscombe-usgs (Member, Author)

new_shoreline_detect_workflow.zip

This contains two scripts: one filters out bad images automatically (it seems to work quite well), and the other is my latest attempt at a shoreline detection algorithm. This approach is working quite well, so we can discuss what steps may be required for its implementation. In the end, the part based on the distance transform was causing more problems than it was solving, and I ended up ditching the boundary tracing algorithm too because of bad results when the coastline is not linear.

I'll be out for most of the next week, but wanted to make sure I updated here first. Some outputs pasted below

ex5_L8_shorelines
ex4_L7_shorelines
ex3_L8_shorelines
ex3_L7_shorelines
ex2_S2_shorelines
ex2_L7_shorelines
ex2_L8_shorelines
ex1_L8_shorelines

@2320sharon (Collaborator)

Thanks for providing your code and testing these algorithms on a variety of sites.
I'll analyze the code, then work on a coastseg prototype that utilizes the new good/bad filtering and the new shoreline detection method. I'll create a new branch called issue/154 where I'll work on the coastseg prototype.

@2320sharon (Collaborator)

This issue is under development. In coastseg I've already implemented the functionality to sort the model outputs into "good" and "bad" directories using the filter_model_outputs function. Once the model outputs have been sorted, get_filtered_files_dict is called to get the names of all the files in the "good" directory. Finally, edit_metadata is called to remove the files that were not good from the metadata dictionary. This metadata dictionary is then used as the input to the extract shorelines function, which reads the filenames and other necessary metadata from it to create the extracted shorelines.

This logic has been incorporated into the main branch.

    for satname in satellites:
        # get all the model_outputs that have the satellite in the filename
        files = glob(f"{session_path}{os.sep}*{satname}*.npz")
        if len(files) != 0:
            filter_model_outputs(satname, files, good_folder, bad_folder)

    # for each satellite get the list of files that were sorted as 'good'
    filtered_files = get_filtered_files_dict(good_folder, "npz", sitename)
    # keep only the metadata for the files that were sorted as 'good'
    metadata = edit_metadata(metadata, filtered_files)

2320sharon added a commit that referenced this issue Oct 4, 2023