Commit

Merge pull request #8 from avanetten/master

Updates to readme and inference scripts
avanetten committed Oct 17, 2019
2 parents ac24a85 + bfafa5e commit a3751b5
Showing 10 changed files with 322 additions and 705 deletions.
6 changes: 3 additions & 3 deletions README.md
@@ -12,9 +12,9 @@ For further details see:
1. [Large Road Networks](https://medium.com/the-downlinq/extracting-road-networks-at-scale-with-spacenet-b63d995be52d)
2. [Road Speeds](https://medium.com/the-downlinq/inferring-route-travel-times-with-spacenet-7f55e1afdd6d)
3. [OSM+Google Imagery](https://medium.com/the-downlinq/computer-vision-with-openstreetmap-and-spacenet-a-comparison-cc70353d0ace)
-4. [Data Prep](https://medium.com/the-downlinq/the-spacenet-5-baseline-part-1-imagery-and-label-preparation-598af46d485e)
-5. [Training A Road Speed Segmentation Model](https://medium.com/the-downlinq/the-spacenet-5-baseline-part-2-training-a-road-speed-segmentation-model-2bc93de564d7)
+4. [SpaceNet 5 Baseline Part 1 - Data Prep](https://medium.com/the-downlinq/the-spacenet-5-baseline-part-1-imagery-and-label-preparation-598af46d485e)
+5. [SpaceNet 5 Baseline Part 2 - Segmentation](https://medium.com/the-downlinq/the-spacenet-5-baseline-part-2-training-a-road-speed-segmentation-model-2bc93de564d7)
+6. [SpaceNet 5 Baseline Part 3 - Road Graph + Speed](https://medium.com/the-downlinq/the-spacenet-5-baseline-part-3-extracting-road-speed-vectors-from-satellite-imagery-5d07cd5e1d21)

____
### Install ###
81 changes: 1 addition & 80 deletions cresi/03a_merge_preds.py
@@ -13,8 +13,6 @@
# skimage gives really annoying warnings
import warnings
warnings.filterwarnings("ignore")
#with warnings.catch_warnings():
# warnings.simplefilter("ignore")


############
@@ -24,36 +22,6 @@
tqdm.monitor_interval = 0
############


'''
Surprisingly, merge_tiffs works fine with multichannel arrays, e.g.:
# explore taking mean of multiple images (relates to merge_preds.py)
import numpy as np
# set first channel of a to 5, rest to 0
a = np.zeros((7,20,20))
a[0,:,:] = 5
# set b as ones
b = np.ones((7,20,20))
# set c as 3, except set 7th channel as -10
c = 3 * np.ones((7,20,20))
c[6,:,:] = -10
probs = []
for prob_arr in [a,b,c]:
print ("prob_arr.shape:", prob_arr.shape)
probs.append(prob_arr)
prob_arr = np.mean(probs, axis=0)
print ("prob_arr.shape:", prob_arr.shape)
# channel 0 should be the mean of (5, 1, 3) = 3
print ("prob_arr[0,:,:]:", prob_arr[0,:,:])
# channel 3 should be the mean of (0, 1, 3) = 4/3
print ("prob_arr[3,:,:]:", prob_arr[3,:,:])
# channel 6 should be the mean of (0, 1, -10) = -3
print ("prob_arr[6,:,:]:", prob_arr[6,:,:])
#print ("prob_arr:", prob_arr)
'''

def merge_tiffs(root, out_dir, out_dir_gdal=None, num_classes=1,
                verbose=False):

@@ -119,8 +87,6 @@ def merge_tiffs(root, out_dir, out_dir_gdal=None, num_classes=1,
CreateMultiBandGeoTiff(outpath_gdal, mask_gdal)




def merge_tiffs_defferent_folders(roots, res):
'''Need to update to handle multiple bands!'''
os.makedirs(os.path.join(res), exist_ok=True)
@@ -137,52 +103,9 @@ def merge_tiffs_defferent_folders(roots, res):
res_path_geo = os.path.join(res, prob_file)
cv2.imwrite(res_path_geo, prob_arr)

#def all_dice(pred_path, gt_path):
# all_d= []
# for im in os.listdir(pred_path):
# img_ds = gdal.Open(os.path.join(pred_path, im), gdal.GA_ReadOnly)
# img = img_ds.GetRasterBand(1).ReadAsArray()
# gt_ds = gdal.Open(os.path.join(gt_path, im.replace('RGB', "GTI")), gdal.GA_ReadOnly)
# gt = gt_ds.GetRasterBand(1).ReadAsArray()
# dsm_ds = gdal.Open(os.path.join(gt_path, im.replace('RGB', 'DSM')), gdal.GA_ReadOnly)
# band_dsm = dsm_ds.GetRasterBand(1)
# nodata = band_dsm.GetNoDataValue()
# dsm = band_dsm.ReadAsArray()
# img[dsm==nodata] = 0
# gt[dsm==nodata] = 0
#
# d = 1 - dice(img.flatten() > .4, gt.flatten() >= 1)
# print(im, d)
# all_d.append(d)
# print(np.mean(all_d))



def execute():
# # if using argparse
# parser = argparse.ArgumentParser()
# parser.add_argument('--folds_save_dir', type=str, default='/raid/local/src/apls/albu_inference_mod/results',
# help="path to predicted folds")
# parser.add_argument('--out_dir', type=str, default='/raid/local/src/apls/albu_inference_mod/results',
# help="path to merged predictions")
# args = parser.parse_args()
# #out_dir = os.path.join(os.path.dirname(root), 'merged')
# os.makedirs(args.out_dir, exist_ok=True) #os.path.join(root, 'merged'), exist_ok=True)
#
# t0 = time.time()
# merge_tiffs(args.folds_save_dir, args.out_dir)
# t1 = time.time()
# print ("Time to merge", len(os.listdir(args.folds_save_dir)), "files:", t1-t0, "seconds")
#
# # compress original folds
# output_filename = args.folds_save_dir
# print ("output_filename:", output_filename)
# shutil.make_archive(output_filename, 'gztar', args.folds_save_dir) #'zip', res_dir)
# # remove folds
# #shutil.rmtree(args.folds_save_dir, ignore_errors=True)
#



# if using config instead of argparse
parser = argparse.ArgumentParser()
parser.add_argument('config_path')
@@ -203,7 +126,6 @@ def execute():
merge_dir_gdal = merge_dir + '_gdal'
#merge_dir_gdal = None


verbose = False
#res_dir = config.folds_save_dir
#res_dir = os.path.join(config.results_dir, config.folder + config.out_suff + '/folds')
@@ -235,7 +157,6 @@ def execute():
# remove folds
shutil.rmtree(folds_dir, ignore_errors=True)


print ("Compress original gdal folds...")
output_filename = folds_dir + '_gdal'
if os.path.exists(output_filename):
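For orientation, the retained merge_tiffs logic reduces to a per-pixel, channel-wise mean over identically named prediction masks from each fold directory. A minimal sketch of that idea, assuming (channels, h, w) masks readable by skimage; the function name and directory handling here are illustrative, not the script's actual API:

import os
import numpy as np
import skimage.io

def merge_fold_preds(fold_dirs, out_dir):
    # Average the same-named prediction mask across all fold directories.
    os.makedirs(out_dir, exist_ok=True)
    for name in sorted(os.listdir(fold_dirs[0])):
        arrs = [skimage.io.imread(os.path.join(d, name)) for d in fold_dirs]
        # np.mean over a list stacks along a new axis 0, so the
        # (channels, h, w) shape of each mask is preserved.
        mean_arr = np.mean(arrs, axis=0).astype(np.uint8)
        skimage.io.imsave(os.path.join(out_dir, name), mean_arr)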
72 changes: 1 addition & 71 deletions cresi/03b_stitch.py
@@ -16,13 +16,9 @@
import cv2
import time
import logging
-from json.config import Config
+from jsons.config import Config
from utils import make_logger

#import sys
#path_basiss = os.path.dirname(os.path.realpath(__file__))
#sys.path.append(path_basiss)
#import basiss

###############################################################################
def post_process_image(df_pos_, data_dir, num_classes=1, im_prefix='',
@@ -89,17 +85,6 @@ def post_process_image(df_pos_, data_dir, num_classes=1, im_prefix='',
# print("reorder mask_slice_refine.shape", mask_slice_refine.shape)
mask_slice_refine = np.moveaxis(mask_slice_refine, 0, -1)

#print ("mask_slice_refine.shape:", mask_slice_refine.shape)
#print ("Time to read image:", time.time() - t01, "seconds")

# # we want skimage to read in (channels, h, w) for multi-channel
# # assume less than 20 channels
# #print ("mask_channels.shape:", mask_channels.shape)
# if prob_arr_tmp.shape[0] > 20:
# #print ("mask_channels.shape:", mask_channels.shape)
# prob_arr = np.moveaxis(prob_arr_tmp, 0, -1)
# #print ("mask.shape:", mask.shape)

# rescale mask slice?
if rescale_factor != 1:
    mask_slice_refine = (mask_slice_refine / rescale_factor).astype(np.uint8)
@@ -129,9 +114,6 @@ def post_process_image(df_pos_, data_dir, num_classes=1, im_prefix='',
overlay_count[np.where(overlay_count == 0)] = 1
if rescale_factor != 1:
    mask_raw = mask_raw.astype(np.uint8)

#print ("np.max(overlay_count):", np.max(overlay_count))
#print ("np.min(overlay_count):", np.min(overlay_count))

# np.divide over two full-scene arrays throws a memory error on very large images, so divide row-wise instead
if (w < 60000) and (h < 60000):
@@ -144,16 +126,6 @@
for j in range(h):
    #print ("j:", j)
    mask_norm[j] = (mask_raw[j] / overlay_count[j]).astype(np.uint8)

# # throws a memory error if using np.divide...
# if (w < 60000) and (h < 60000):
# mask_norm = np.divide(mask_raw, overlay_count).astype(np.uint8)
# else:
# for j in range(h):
# #print ("j:", j)
# mask_norm[j] = (mask_raw[j] / overlay_count[j]).astype(np.uint8)
# #for k in range(w):
# # mask_norm[j,k] = (mask_raw[j,k] / overlay_count[j,k]).astype(np.uint8)

# rescale mask_norm
if rescale_factor != 1:
@@ -232,9 +204,6 @@ def post_process_image_3band(df_pos_, data_dir, n_bands=3, im_prefix='',
overlay_count[np.where(overlay_count == 0)] = 1
if rescale_factor != 1:
    im_raw = im_raw.astype(np.uint8)

#print ("np.max(overlay_count):", np.max(overlay_count))
#print ("np.min(overlay_count):", np.min(overlay_count))

# np.divide over two full-scene arrays throws a memory error on very large images, so divide row-wise instead
if h < 60000:
@@ -249,9 +218,6 @@
if rescale_factor != 1:
    im_norm = (im_norm * rescale_factor).astype(np.uint8)

#print ("im_norm.shape:", im_norm.shape)
#print ("im_norm.dtype:", im_norm.dtype)

return name, im_norm, im_raw, overlay_count
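Both post-processing routines above use the same normalization fallback: np.divide on two full-scene arrays allocates a third full-size temporary, which can exhaust memory on very large scenes, so they divide one row at a time instead. A standalone sketch of that fallback (the function name is illustrative):

import numpy as np

def normalize_rowwise(mask_raw, overlay_count):
    # Divide row by row so no full-size temporary array is allocated,
    # mirroring the large-image branch of 03b_stitch.py.
    mask_norm = np.zeros(mask_raw.shape, dtype=np.uint8)
    for j in range(mask_raw.shape[0]):
        mask_norm[j] = (mask_raw[j] / overlay_count[j]).astype(np.uint8)
    return mask_norm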

###############################################################################
@@ -298,17 +264,6 @@ def main():

# assume tile csv is in data dir, not root dir
path_tile_df_csv = os.path.join(config.path_data_root, os.path.dirname(config.test_sliced_dir), config.tile_df_csv)
# try tile_df_csv in results path
#path_tile_df_csv = os.path.join(config.path_results_root, config.test_results_dir, config.tile_df_csv)

#out_dir_mask_norm = config.stitched_dir_norm #os.path.join(config.stitched_dir ,'mask_norm')
#out_dir_mask_raw = config.stitched_dir_raw #os.path.join(config.stitched_dir, 'mask_raw')
#out_dir_count = config.stitched_dir_count #os.path.join(config.stitched_dir, 'mask_count')
#res_root_dir = os.path.dirname(config.merged_dir)
##out_dir_root = os.path.join(res_root_dir, 'stitched')
#out_dir_mask_norm = os.path.join(res_root_dir, 'stitched/mask_norm')
#out_dir_mask_raw = os.path.join(res_root_dir, 'stitched/mask_raw')
#out_dir_count = os.path.join(res_root_dir, 'stitched/mask_count')

# make dirs
os.makedirs(out_dir_mask_norm, exist_ok=True)
@@ -318,30 +273,6 @@
res_root_dir = os.path.join(config.path_results_root, config.test_results_dir)
log_file = os.path.join(res_root_dir, 'stitch.log')
console, logger1 = make_logger.make_logger(log_file, logger_name='log')
# ###############################################################################
# # https://docs.python.org/3/howto/logging-cookbook.html#logging-to-multiple-destinations
# # set up logging to file - see previous section for more details
# res_root_dir = os.path.join(config.path_results_root, config.test_results_dir)
# log_file = os.path.join(res_root_dir, 'stitch.log')
# logging.basicConfig(level=logging.DEBUG,
# format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s',
# datefmt='%m-%d %H:%M',
# filename=log_file,
# filemode='w')
# # define a Handler which writes INFO messages or higher to the sys.stderr
# console = logging.StreamHandler()
# console.setLevel(logging.INFO)
# # set a format which is simpler for console use
# formatter = logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s')
# #formatter = logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s')
# # tell the handler to use this format
# console.setFormatter(formatter)
# # add the handler to the root logger
# logging.getLogger('').addHandler(console)
# logger1 = logging.getLogger('log')
# logger1.info("log file: {x}".format(x=log_file))
# ###############################################################################


# read in df_pos
#df_file = os.path.join(out_dir_root, 'tile_df.csv')
@@ -427,6 +358,5 @@ def main():
return



if __name__ == "__main__":
    main()
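Taken together, 03b_stitch.py reassembles a full-scene mask from overlapping tile predictions: each tile is summed into a scene-sized accumulator, per-pixel coverage is counted, and the sum is divided by the count. A compact sketch of that scheme, assuming single-channel (h, w) tiles and (row, col) upper-left offsets; all names are illustrative:

import numpy as np

def stitch_tiles(tiles, offsets, out_shape):
    # Sum overlapping tile predictions and count per-pixel coverage.
    mask_raw = np.zeros(out_shape, dtype=np.float64)
    overlay_count = np.zeros(out_shape, dtype=np.float64)
    for tile, (r, c) in zip(tiles, offsets):
        h, w = tile.shape
        mask_raw[r:r + h, c:c + w] += tile
        overlay_count[r:r + h, c:c + w] += 1
    # Avoid divide-by-zero where no tile landed, as in post_process_image.
    overlay_count[overlay_count == 0] = 1
    return (mask_raw / overlay_count).astype(np.uint8)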
