In [50]:
import argparse
import os
import glob
import numpy as np
import cv2
import torch
import sklearn.feature_extraction.image

import sys
sys.path.insert(1,'/mnt/data/home/pjl54/WSI_handling')
import wsi

from torchvision.models import DenseNet

#-----helper function to split data into batches
def divide_batch(l, n): 
    for i in range(0, len(l), n):  
        yield l[i:i + n]

# ----- parse command line arguments

input_pattern=['/mnt/ccipd_data/UH_Bladder_Cancer_Project/Blad170830/Blad_2.tif','/mnt/ccipd_data/UH_Bladder_Cancer_Project/Blad170830/Blad_3.tif']
resolution = 0.5
model = '/mnt/data/home/pjl54/bladder/densenet/bladder_densenet_best_model.pth'
color=None
annotation='largest'
patchsize=256
batchsize=10
outdir='./output/'
gpuid=0
force=False
basepath=''
desired_mask_mpp=4

In [None]:
from unet import UNet

te_model = '/mnt/data/home/pjl54/bladder/bladderTE_1mpp_256p.pth'
te_mpp = 1;
te_device = 2
checkpoint = torch.load(te_model, map_location=lambda storage, loc: storage) #load checkpoint to CPU and then put to device https://discuss.pytorch.org/t/saving-and-loading-torch-models-on-2-machines-with-different-number-of-gpu-devices/6666
model = UNet(n_classes=checkpoint["n_classes"], in_channels=checkpoint["in_channels"],
             padding=checkpoint["padding"], depth=checkpoint["depth"], wf=checkpoint["wf"],
             up_mode=checkpoint["up_mode"], batch_norm=checkpoint["batch_norm"]).to(te_device)
model.load_state_dict(checkpoint["model_dict"])
model.eval()

In [51]:

OUTPUT_DIR = outdir

batch_size = batchsize
patch_size = patchsize
base_stride_size = patch_size//2


# ----- load network
device = torch.device(gpuid if gpuid!=-2 and torch.cuda.is_available() else 'cpu')

checkpoint = torch.load(model, map_location=lambda storage, loc: storage) #load checkpoint to CPU and then put to device https://discuss.pytorch.org/t/saving-and-loading-torch-models-on-2-machines-with-different-number-of-gpu-devices/6666

model = DenseNet(growth_rate=checkpoint["growth_rate"], block_config=checkpoint["block_config"],
                 num_init_features=checkpoint["num_init_features"], bn_size=checkpoint["bn_size"],
                 drop_rate=checkpoint["drop_rate"], num_classes=checkpoint["num_classes"]).to(device)

model.load_state_dict(checkpoint["model_dict"])
model.eval()

print(f"total params: \t{sum([np.prod(p.size()) for p in model.parameters()])}")


total params: 	415554


In [52]:
# ----- get file list

if not os.path.exists(OUTPUT_DIR):
    os.makedirs(OUTPUT_DIR)

files = []
basepath = basepath  #
basepath = basepath + os.sep if len(
    basepath) > 0 else ""  # if the user supplied a different basepath, make sure it ends with an os.sep

if len(input_pattern) > 1:  # bash has sent us a list of files
    files = input_pattern
elif input_pattern[0].endswith("tsv"):  # user sent us an input file
    # load first column here and store into files
    with open(input_pattern[0], 'r') as f:
        for line in f:
            if line[0] == "#":
                continue
            files.append(basepath + line.strip().split("\t")[0])
else:  # user sent us a wildcard, need to use glob to find files
    files = glob.glob(basepath + input_pattern[0])

In [54]:
# ------ work on files
for fname in files:

    fname = fname.strip()
    newfname_class = "%s/%s_class.png" % (OUTPUT_DIR, os.path.basename(fname)[0:-4])

    print(f"working on file: \t {fname}")

    # if not force and os.path.exists(newfname_class):
    #     print("Skipping as output file exists")
    #     continue
    #
    # cv2.imwrite(newfname_class, np.zeros(shape=(1, 1)))

    output_fname = outdir + fname + '_grade_results'
    xml_dir = fname[0:fname.rfind(os.path.sep)]
    xml_fname = xml_dir + os.path.sep + os.path.basename(fname)[0:os.path.basename(fname).rfind('.')] + '.xml'

    img = wsi.wsi(fname,xml_fname)
    
    stride_size = int(base_stride_size * (resolution/img["mpp"]))
    
#     if(annotation.lower() == 'wsi'):
#         img_dims = [0,0,w["img_dims"][0],w["img_dims"][1]]
#     else:
#     img_dims = img.get_dimensions_of_annotation(color,annotation)
    stride_size_converted = img.get_coord_at_mpp(stride_size,input_mpp=img["mpps"][0],output_mpp=desired_mask_mpp)
    [mask_small, resize_factor] = img.mask_out_annotation(desired_mpp=desired_mask_mpp,colors_to_use=color)            
    
    mask_small = mask_small[list(range(0,np.shape(mask_small)[0],stride_size_converted)),:]            
    mask_small = mask_small[:,list(range(0,np.shape(mask_small)[1],stride_size_converted))]            

    [rs,cs]=(mask_small>0).nonzero()
    rs = [r*stride_size_converted for r in rs]
    cs = [c*stride_size_converted for c in cs]

    rs = [img.get_coord_at_mpp(r,img["mpps"][0],desired_mask_mpp) for r in rs]
    cs = [img.get_coord_at_mpp(c,img["mpps"][0],desired_mask_mpp) for c in cs]

    goods = np.ones(np.shape(rs)[0])
    for k in range(0,np.shape(rs)[0]):

        te_tile = wsi_img.get_tile(coords=(cs[k],rs[k]),wh=(int(patch_size*(te_mpp/model_mpp)),int(patch_size*(te_mpp/model_mpp))),desired_mpp=te_mpp)
        if((np.sum(te_tile[:,:,1]>220)/np.size(te_tile[:,:,1]))>0.30):
            goods[k] = False
        else:    
            arr_out_gpu = torch.from_numpy(np.expand_dims(te_tile,axis=0).transpose(0,3,1,2) / 255).type('torch.FloatTensor').to(te_device)
            output_batch = model(arr_out_gpu)
            output = output_batch[0,:,:,:].detach().cpu().numpy()
            te_map = output.argmax(axis=0)==1

            if((np.sum(te_map)/np.size(te_map))<0.50):                
                goods[k] = False


    cs = [c for idx,c in enumerate(cs) if goods[idx]]
    rs = [r for idx,r in enumerate(rs) if goods[idx]]
    
    grid_points = [(x,y) for x in cs for y in rs]

    points_split = divide_batch(grid_points,batch_size)

    #in case we have a large network, lets cut the list of tiles into batches
    output = np.zeros((0,checkpoint["num_classes"]))
    for i,batch_points in enumerate(points_split):

        batch_arr = np.array([img.get_tile(resolution,coords,(patch_size,patch_size)) for coords in batch_points])
#             print(np.shape(arr_out))
#             arr_out = arr_out.reshape(-1,patch_size,patch_size,3)

        arr_out_gpu = torch.from_numpy(batch_arr.transpose(0, 3, 1, 2) / 255).type('torch.FloatTensor').to(device)

        # ---- get results
        output_batch = model(arr_out_gpu)

        # --- pull from GPU and append to rest of output 
        output_batch = output_batch.detach().cpu().numpy()

        output = np.append(output,output_batch,axis=0)


    tileclass = np.argmax(output, axis=1)
    predc,predccounts=np.unique(tileclass, return_counts=True)
    for c,cc in zip(predc,predccounts):
        print(f"class/count: \t{c}\t{cc}")

    print(f"predicted class:\t{predc[np.argmax(predccounts)]}")

working on file: 	 /mnt/ccipd_data/UH_Bladder_Cancer_Project/Blad170830/Blad_2.tif
class/count: 	0	1156
class/count: 	1	59360
predicted class:	1
working on file: 	 /mnt/ccipd_data/UH_Bladder_Cancer_Project/Blad170830/Blad_3.tif


KeyboardInterrupt: 