v2
7/11/2018

In [1]:
import argparse
import os
import glob
import numpy as np
import cv2
import torch
import sklearn.feature_extraction.image
import sys

sys.path.insert(1,'/mnt/data/home/pjl54/WSI_handling')
import wsi

In [2]:
from unet import UNet

In [3]:
#-----helper function to split data into batches
def divide_batch(l, n): 
    for i in range(0, len(l), n):  
        yield l[i:i + n] 

In [4]:
# ----- parse command line arguments
parser = argparse.ArgumentParser(description='Make output for entire image using Unet')
parser.add_argument('input_pattern',
                    help="input filename pattern. try: *.png, or tsv file containing list of files to analyze",
                    nargs="*")

_StoreAction(option_strings=[], dest='input_pattern', nargs='*', const=None, default=None, type=None, choices=None, help='input filename pattern. try: *.png, or tsv file containing list of files to analyze', metavar=None)

In [5]:
parser.add_argument('-r', '--resolution', help="image resolution in microns per pixel", default=1, type=float)
parser.add_argument('-c', '--color', help="annotation color to use, default None", default='green', type=str)
parser.add_argument('-a', '--annotation', help="annotation index to use, default largest", default='wsi', type=str)

parser.add_argument('-p', '--patchsize', help="patchsize, default 256", default=256, type=int)
parser.add_argument('-s', '--batchsize', help="batchsize for controlling GPU memory usage, default 10", default=10, type=int)
parser.add_argument('-o', '--outdir', help="outputdir, default ./output/", default="./output/", type=str)
parser.add_argument('-m', '--model', help="model", default="best_model.pth", type=str)
parser.add_argument('-i', '--gpuid', help="id of gpu to use", default=0, type=int)
parser.add_argument('-f', '--force', help="force regeneration of output even if it exists", default=False,
                    action="store_true")
parser.add_argument('-b', '--basepath',
                    help="base path to add to file names, helps when producing data using tsv file as input",
                    default="", type=str)

_StoreAction(option_strings=['-b', '--basepath'], dest='basepath', nargs=None, const=None, default='', type=<class 'str'>, choices=None, help='base path to add to file names, helps when producing data using tsv file as input', metavar=None)

In [27]:
args = parser.parse_args(['-s20','-o/mnt/data/home/pjl54/bladder/test','-r1','-a','largest','-m','/mnt/data/home/pjl54/models/models/bladderTE_1mpp_256p.pth','-i2','/mnt/ccipd_data/UH_Bladder_Cancer_Project/Blad170830/Blad_1.tif'])

In [7]:
if not (args.input_pattern):
    parser.error('No images selected with input pattern')

In [8]:
OUTPUT_DIR = args.outdir


In [28]:
batch_size = args.batchsize
patch_size = args.patchsize
base_stride_size = patch_size//2

In [10]:
# ----- load network
device = torch.device(args.gpuid if torch.cuda.is_available() else 'cpu')

In [11]:
checkpoint = torch.load(args.model, map_location=lambda storage, loc: storage) #load checkpoint to CPU and then put to device https://discuss.pytorch.org/t/saving-and-loading-torch-models-on-2-machines-with-different-number-of-gpu-devices/6666
model = UNet(n_classes=checkpoint["n_classes"], in_channels=checkpoint["in_channels"],
             padding=checkpoint["padding"], depth=checkpoint["depth"], wf=checkpoint["wf"],
             up_mode=checkpoint["up_mode"], batch_norm=checkpoint["batch_norm"]).to(device)
model.load_state_dict(checkpoint["model_dict"])
model.eval()

UNet(
  (down_path): ModuleList(
    (0): UNetConvBlock(
      (block): Sequential(
        (0): Conv2d(3, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
        (2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): ReLU()
        (5): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): UNetConvBlock(
      (block): Sequential(
        (0): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
        (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): ReLU()
        (5): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): UNetConvBlock(
      (block): Sequential(
        (0): Conv2d(16, 3

In [12]:
print(f"total params: \t{sum([np.prod(p.size()) for p in model.parameters()])}")

total params: 	121538


----- get file list

In [13]:
if not os.path.exists(OUTPUT_DIR):
    os.makedirs(OUTPUT_DIR)

In [14]:
files = []
basepath = args.basepath  #
basepath = basepath + os.sep if len(
    basepath) > 0 else ""  # if the user supplied a different basepath, make sure it ends with an os.sep

In [15]:
if len(args.input_pattern) > 1:  # bash has sent us a list of files
    files = args.input_pattern
elif args.input_pattern[0].endswith("tsv"):  # user sent us an input file
    # load first column here and store into files
    with open(args.input_pattern[0], 'r') as f:
        for line in f:
            if line[0] == "#":
                continue
            files.append(basepath + line.strip().split("\t")[0])
else:  # user sent us a wildcard, need to use glob to find files
    files = glob.glob(args.basepath + args.input_pattern[0])

In [158]:
# ------ work on files
for fname in files:
    
    fname = fname.strip()
    newfname_class = "%s/%s_class.png" % (OUTPUT_DIR, os.path.basename(fname)[0:fname.rfind('.')])

    print(f"working on file: \t {fname}")
    print(f"saving to : \t {newfname_class}")

#     if not args.force and os.path.exists(newfname_class):
#         print("Skipping as output file exists")
#         continue
        
    cv2.imwrite(newfname_class, np.zeros(shape=(1, 1)))                                            
    
    xml_dir = fname[0:fname.rfind(os.path.sep)]
    xml_fname = xml_dir + os.path.sep + os.path.basename(fname)[0:os.path.basename(fname).rfind('.')] + '.xml'

    img = wsi.wsi(fname,xml_fname)
    
    stride_size = int(base_stride_size * (args.resolution/img["mpp"]))
    
    if(args.annotation.lower() == 'wsi'):
        img_dims = [0,0,w["img_dims"][0],w["img_dims"][1]]
    else:
        img_dims = img.get_dimensions_of_annotation(args.color,args.annotation)
    
    if img_dims:
    
        x_start = int(img_dims[0])
        y_start = int(img_dims[1])
        w_orig = int(img_dims[2]) - x_start
        h_orig = int(img_dims[3]) - y_start


        w = int(w_orig + (patch_size - (w_orig % patch_size)))
        h = int(h_orig + (patch_size - (h_orig % patch_size)))

        x_points = range(x_start-stride_size//2,x_start+w+stride_size//2,stride_size)
        y_points = range(y_start-stride_size//2,y_start+h+stride_size//2,stride_size)

        grid_points = [(x,y) for x in x_points for y in y_points]
        points_split = divide_batch(grid_points,batch_size)

        #in case we have a large network, lets cut the list of tiles into batches
        output = np.zeros((len(grid_points),base_stride_size,base_stride_size),np.uint8)
        for i,batch_points in enumerate(points_split):

            batch_arr = np.array([img.get_tile(args.resolution,coords,(patch_size,patch_size)) for coords in batch_points])

            arr_out_gpu = torch.from_numpy(batch_arr.transpose(0, 3, 1, 2) / 255).type('torch.FloatTensor').to(device)

            # ---- get results
            output_batch = model(arr_out_gpu)

            # --- pull from GPU and append to rest of output 
            output_batch = output_batch.detach().cpu().numpy()
            output_batch = output_batch.argmax(axis=1)
            
            #remove the padding from each tile, we only keep the center
            output_batch = output_batch[:,base_stride_size//2:-base_stride_size//2,base_stride_size//2:-base_stride_size//2]

            output[((i+1)*batch_size - batch_size):((i+1)*batch_size),:,:] = output_batch


        #turn from a single list into a matrix of tiles
        output = output.reshape(len(x_points),len(y_points),base_stride_size,base_stride_size)
        output = np.concatenate(np.concatenate(output.transpose(1,0,2,3),1),1)

        #incase there was extra padding to get a multiple of patch size, remove that as well
        _,mask = img.get_annotated_region(args.resolution,args.color,args.annotation,return_img=False)        

        output = output[0:mask.shape[0], 0:mask.shape[1]] #remove paddind, crop back
        output = np.bitwise_and(output>0,mask>0)*255
        
        # --- save output

        cv2.imwrite(newfname_class, output)

    else:
        print('No annotation of color')

    output = None
    output_batch = None
    mask = None

working on file: 	 /mnt/ccipd_data/UH_Bladder_Cancer_Project/Blad170830/Blad_1.tif
saving to : 	 /mnt/data/home/pjl54/bladder/test/Blad_1.tif_class.png
