In [5]:
#%pip install sympy
import xarray as xr
#import dask.array as da
import numpy as np
import matplotlib.pyplot as plt
import torch
from tqdm import tqdm
import time
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
import os

DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
MODEL_TYPE = "vit_h"
print(f"Using device: {DEVICE}")

sam = sam_model_registry[MODEL_TYPE](checkpoint='../model/sam_vit_h_4b8939.pth')
sam.to(device=DEVICE)

print('Sucessfully imported SAM')



Using device: cpu
Sucessfully imported SAM


In [2]:
## check which files are already processed and which still need processing: 

input_path = 'input_sam/'
output_path = 'output_sam/'

processed_files = os.listdir(output_path)
processed_files.sort()
processed_files

to_do = os.listdir(input_path)
to_do.sort()

to_do = [os.path.join('~/input_sam', f) for f in to_do]

NameError: name 'os' is not defined

In [3]:

def timing_wrapper(func):
    def wrapper(*args, **kwargs):
        start = time.time()
        result = func(*args, **kwargs)
        end = time.time()
        print(f"Executed {func.__name__} in {(end - start):.2f} seconds")
        return result
    return wrapper

@timing_wrapper
def predict_sam(image, generator):
    return generator.generate(image)


@timing_wrapper
def sam_result_to_ds(ds, sam_result):
    
    ds_ny, ds_nx = ds.dims['y'], ds.dims['x']
    
    ny, nx = sam_result[0]['segmentation'].shape
    
    nseg = len(sam_result)
    
    da_segmentation = da.zeros((ny, nx)).astype('int')
    da_area = da.zeros(nseg)
    da_bbox = da.zeros((nseg, 4))
    da_iou = da.zeros(nseg)
    da_stabscore = da.zeros(nseg)
    da_coords =  da.zeros(nseg).astype('int')
    da_pcoords = da.zeros((nseg, 2))   
    
    output_ds = xr.Dataset(
        data_vars = {
            'segmentation' : (('y', 'x'), da_segmentation),
            'area' : (('seg'), da_area),
            'bbox' : (('seg', 'bbox_coords'), da_bbox),
            'predicted_iou' : (('seg'), da_iou),
            'stability_score' : (('seg'), da_stabscore),
            'point_coords' : (('seg', 'p_coords'), da_pcoords)
        },
        coords = {
            'seg' : np.arange(nseg),
            'x' : ds.x.values,
            'y' : ds.y.values,
            'bbox_coords' : [0, 1, 2, 3],
            'p_coords' : [0, 1]
        },
    )

    for i, seg in enumerate(sam_result):
        for key, value in seg.items():
            
            if key in ['segmentation', 'area', 'predicted_iou', 'stability_score', 'point_coords', 'bbox']:
                if key == 'segmentation':
                    output_ds[key] += value * (i+1)
                else:
                    output_ds[key][i] = np.array(value).squeeze()

    return output_ds 


@timing_wrapper
def main_pushbroom(infile, outfile, variable='bt_1', xindex=None, sam=sam, mask_gen_params=None):
    
    if xindex.any():
        ds = xr.open_dataset(infile).sel(x=xindex)
    else: 
        ds = xr.open_dataset(infile)
        
    #outfile = infile.replace('input', 'output')
    
    if mask_gen_params is None:
        mask_gen_params = {
            'model': sam,
            'points_per_side': 4,
            'pred_iou_thresh': 0.86,
            'stability_score_thresh': 0.92,
            'crop_n_layers': 1,
            'crop_n_points_downscale_factor': 2,
            'min_mask_region_area': 100  # Requires open-cv to run post-processing
        }
        
    mask_generator = SamAutomaticMaskGenerator(**mask_gen_params)
        
    if variable == 'bt_1':
    
    ### scale the data to [0, 1] interval, scale by 256 for image representation and convert to uint 
    
        bt_1_scaled = (ds.bt_1 - ds.bt_1.min()) /  (ds.bt_1.max() - ds.bt_1.min())


        RGB = xr.concat([bt_1_scaled, bt_1_scaled, bt_1_scaled], dim='variable') * 255
        RGB = RGB.astype('uint8').transpose('y', 'x', 'variable').values

    elif variable == 'proba':

        RGB = (ds[['cl_0', 'cl_1', 'cl_2']].to_array().transpose('y', 'x', 'variable') * 255).astype('uint8').values

    sam_result = predict_sam(RGB, mask_generator)
    ds_result = sam_result_to_ds(ds, sam_result)
    
    ds_result['pred_proba'] = (('surface_class', 'y', 'x'),  (ds.drop_vars(['bt_1', 'lat', 'lon', 'skin_t']).to_array().values * 100).astype('uint8'))
    #outfile = outfile.replace('concat', f'sam_predict_{xindex[0]}_{xindex[-1]}')
    
    attrs = {}
    
    for key, value in mask_gen_params.items():
        if key != 'model':
            attrs[f'sam_param:{key}'] = value
            
    ds_result.attrs = attrs
    
    print(outfile)
    ds_result.to_netcdf(outfile, mode='w')
    
    return ds_result


In [4]:
def list_expected_output_files(infile):
    
    ds = xr.open_dataset(infile)


    dirname = infile.split('/')[-1].split('_concat.nc')[0]
    dirname = os.path.join('/home/sc.uni-leipzig.de/jn906hluu/output_sam/', dirname)
    if not os.path.isdir(dirname):
        os.mkdir(dirname)

    else:
        print(f'{dirname} already exists \nNumber of files: {len(os.listdir(dirname))} ')

    print(f'dataset x-size: {ds.x.size}')

    splits =np.ceil(ds.x.size / 5000)
    x_index = np.array_split(ds['x'].values, splits)

    date = infile.split('/')[-1].split('T')[0]

    expected_outputs = [f'{date}_sam_predict_{x[0]}_{x[-1]}.nc' for x in x_index]
    expected_outputs = [os.path.join(dirname, f) for f in expected_outputs]
    expected_outputs

    ds.close()
    
    return expected_outputs, x_index



In [6]:
for infile in to_do:
    
    expected_outputs, x_index = list_expected_output_files(infile)
    
    for output, index in zip(expected_outputs, x_index):
        
        if not os.path.isfile(output):
            print(f'Processing {output}')
            
            ds_result = main_pushbroom(infile, output, xindex=index,
                 mask_gen_params = {
                    'model': sam,
                    'points_per_side': 256,
                    'pred_iou_thresh': 0.86,
                    'stability_score_thresh': 0.92,
                    'crop_n_layers': 3,
                    'crop_n_points_downscale_factor': 2,
                    'min_mask_region_area': 100  # Requires open-cv to run post-processing
                    }     
                )

            
        else: 
            print('Skipping')

/home/sc.uni-leipzig.de/jn906hluu/output_sam/2022-03-20T10:35:00_2022-03-20T10:50:00 already exists 
Number of files: 11 
dataset x-size: 20302
Skipping
Skipping
Skipping
Skipping
Skipping
/home/sc.uni-leipzig.de/jn906hluu/output_sam/2022-03-20T11:20:00_2022-03-20T11:26:00 already exists 
Number of files: 4 
dataset x-size: 7561
Skipping
Skipping
/home/sc.uni-leipzig.de/jn906hluu/output_sam/2022-03-21T11:39:00_2022-03-21T11:44:00 already exists 
Number of files: 3 
dataset x-size: 6199
Skipping
Skipping
/home/sc.uni-leipzig.de/jn906hluu/output_sam/2022-03-28T10:28:00_2022-03-28T11:05:00 already exists 
Number of files: 12 
dataset x-size: 57413
Skipping
Skipping
Skipping
Skipping
Skipping
Skipping
Skipping
Skipping
Skipping
Skipping
Skipping
Skipping
/home/sc.uni-leipzig.de/jn906hluu/output_sam/2022-03-28T13:13:30_2022-03-28T13:21:30 already exists 
Number of files: 3 
dataset x-size: 12929
Skipping
Skipping
Skipping
/home/sc.uni-leipzig.de/jn906hluu/output_sam/2022-03-28T14:12:30_2022