<a href="https://colab.research.google.com/github/aashrithresearch/pytorch_pathology/blob/main/lymphoma_output.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Mount Google Drive into the Colab runtime at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import argparse
import os
import glob
import numpy as np
import cv2
import torch
import numbers
from numpy.lib.stride_tricks import as_strided
from torchvision.models import DenseNet

In [3]:
def extract_patches(arr, patch_shape=8, extraction_step=1):
    """Return a strided view of ``arr`` holding every patch on a regular grid.

    Parameters
    ----------
    arr : numpy.ndarray
        Array the patches are taken from.
    patch_shape : int or tuple of int
        Shape of one patch. A single number is used for every dimension of
        ``arr``.
    extraction_step : int or tuple of int
        Distance between the origins of neighbouring patches. A single number
        is used for every dimension of ``arr``.

    Returns
    -------
    numpy.ndarray
        Array with ``2 * arr.ndim`` dimensions: the leading ``arr.ndim`` axes
        index the patch position on the grid, the trailing ``arr.ndim`` axes
        index the elements inside a patch. It is built with ``as_strided``,
        so it shares memory with ``arr`` rather than copying it.
    """
    ndim = arr.ndim

    # Expand scalar arguments to one entry per array dimension.
    if isinstance(patch_shape, numbers.Number):
        patch_shape = (patch_shape,) * ndim
    if isinstance(extraction_step, numbers.Number):
        extraction_step = (extraction_step,) * ndim

    # Moving from one patch origin to the next uses the strides of the array
    # sub-sampled by the extraction step; moving inside a patch uses the
    # array's own strides.
    stepped_view = arr[tuple(slice(None, None, step) for step in extraction_step)]

    # Number of patch origins that fit along each axis.
    grid_shape = (
        (np.array(arr.shape) - np.array(patch_shape)) // np.array(extraction_step)
    ) + 1

    out_shape = tuple(grid_shape) + tuple(patch_shape)
    out_strides = tuple(stepped_view.strides) + tuple(arr.strides)

    return as_strided(arr, shape=out_shape, strides=out_strides)

In [4]:
def divide_batch(l, n):
    """Yield consecutive slices of ``l`` holding at most ``n`` rows each.

    ``l`` is sliced along its first axis in steps of ``n``; the last slice is
    shorter when ``l.shape[0]`` is not a multiple of ``n``.
    """
    yield from (l[start:start + n, :] for start in range(0, l.shape[0], n))

In [7]:
! mkdir /content/drive/MyDrive/lymphoma.tar/data/

In [15]:
! mkdir /content/drive/MyDrive/lymphoma.tar/output/

In [18]:
# Run configuration, expressed as a command-line parser. In the notebook the
# argument list is passed explicitly to parse_args() below instead of being
# read from sys.argv.
parser = argparse.ArgumentParser(description='Make output for entire image using DenseNet')
parser.add_argument('input_pattern',
                    help="input filename pattern. try: *.png, or tsv file containing list of files to analyze",
                    nargs="*")

parser.add_argument('-p', '--patchsize', help="patchsize, default 256", default=256, type=int)
parser.add_argument('-s', '--batchsize', help="batchsize for controlling GPU memory usage, default 10", default=10, type=int)
parser.add_argument('-o', '--outdir', help="outputdir, default ./output/", default="./output/", type=str)
parser.add_argument('-r', '--resize', help="resize factor 1=1x, 2=2x, .5 = .5x", default=1, type=float)
parser.add_argument('-m', '--model', help="model", default="/content/drive/MyDrive/lymphoma.tar/lymphoma_densenet_best_model.pth", type=str)
parser.add_argument('-i', '--gpuid', help="id of gpu to use", default=0, type=int)
parser.add_argument('-f', '--force', help="force regeneration of output even if it exists", default=False,
                    action="store_true")
parser.add_argument('-b', '--basepath',
                    help="base path to add to file names, helps when producing data using tsv file as input",
                    default="", type=str)

args = parser.parse_args(["*tif", "-p", "256", "-m", "/content/drive/MyDrive/lymphoma.tar/lymphoma_densenet_best_model.pth"])

if not args.input_pattern:
    parser.error('No images selected with input pattern')

# Reject values that would only fail later with an unrelated error: the stride
# is patchsize // 2, so a patch size below 2 gives a zero stride.
if args.patchsize < 2:
    parser.error('patchsize must be at least 2')
if args.batchsize < 1:
    parser.error('batchsize must be at least 1')
if args.resize <= 0:
    parser.error('resize must be greater than 0')

OUTPUT_DIR = args.outdir
resize = args.resize

batch_size = args.batchsize
patch_size = args.patchsize
stride_size = patch_size // 2  # neighbouring patches overlap by half a patch

In [7]:
# Choose the compute device: the requested GPU when CUDA is available,
# otherwise the CPU. A gpuid of -2 always selects the CPU.
gpu_requested = args.gpuid != -2
if gpu_requested and torch.cuda.is_available():
    device = torch.device(args.gpuid)
else:
    device = torch.device('cpu')

# load checkpoint to CPU and then put to device https://discuss.pytorch.org/t/saving-and-loading-torch-models-on-2-machines-with-different-number-of-gpu-devices/6666
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects,
# so only load checkpoints that come from a trusted source.
checkpoint = torch.load(
    args.model,
    map_location=lambda storage, loc: storage,
    weights_only=False,
)

# The checkpoint stores the DenseNet hyper-parameters next to the weights, so
# the architecture is rebuilt from them before the weights are restored.
densenet_kwargs = {
    key: checkpoint[key]
    for key in ("growth_rate", "block_config", "num_init_features",
                "bn_size", "drop_rate", "num_classes")
}
model = DenseNet(**densenet_kwargs)
model = model.to(device)

model.load_state_dict(checkpoint["model_dict"])
model.eval()

param_count = sum(np.prod(p.size()) for p in model.parameters())
print(f"total params: \t{param_count}")

total params: 	415683


In [8]:
# Create the output directory; exist_ok makes the cell safe to re-run and
# avoids the check-then-create race of a separate os.path.exists() test.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Normalise the base path so it is either empty or ends with a path separator.
basepath = args.basepath
basepath = basepath + os.sep if len(basepath) > 0 else ""

files = []
if len(args.input_pattern) > 1:
    # Several positional arguments: use them directly as the list of files.
    files = args.input_pattern
elif args.input_pattern[0].endswith("tsv"):
    # A tsv file: the first tab-separated column of each line is a file name.
    with open(args.input_pattern[0], 'r') as f:
        for line in f:
            # Skip comment lines and blank lines; a blank line would otherwise
            # add the bare base path as a file name.
            if line.startswith("#") or not line.strip():
                continue
            files.append(basepath + line.strip().split("\t")[0])
else:
    # A single wildcard pattern: expand it relative to the normalised base
    # path (the same prefix the tsv branch uses) and sort the matches so the
    # processing order is reproducible.
    files = sorted(glob.glob(basepath + args.input_pattern[0]))

In [9]:
# Display whether the model checkpoint path given by args.model exists.
os.path.exists(args.model)

True

In [11]:
# Classify every input image: cut it into overlapping patches, run the model
# on each patch, and report the class predicted for the most patches.
for fname in files:

    fname = fname.strip()
    print(fname)
    # splitext removes the extension whatever its length (e.g. .tif and .tiff).
    # This name is not used further inside this loop.
    newfname_class = "%s/%s_class.png" % (OUTPUT_DIR, os.path.splitext(os.path.basename(fname))[0])

    print(f"working on file: \t {fname}")
    img_bgr = cv2.imread(fname)
    if img_bgr is None:
        # cv2.imread signals an unreadable/missing file by returning None
        # rather than raising, so check before converting colours.
        print(f"skipping unreadable file: \t {fname}")
        continue
    io = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    io = cv2.resize(io, (0, 0), fx=args.resize, fy=args.resize)

    io_shape_orig = np.array(io.shape)

    # Reflect-pad both spatial axes by stride_size // 2 pixels on each side.
    io = np.pad(io, [(stride_size//2, stride_size//2), (stride_size//2, stride_size//2), (0, 0)], mode="reflect")

    io_shape_wpad = np.array(io.shape)

    # Zero-pad the bottom/right edges so both spatial sizes become multiples
    # of patch_size.
    npad0 = int(np.ceil(io_shape_wpad[0] / patch_size) * patch_size - io_shape_wpad[0])
    npad1 = int(np.ceil(io_shape_wpad[1] / patch_size) * patch_size - io_shape_wpad[1])

    io = np.pad(io, [(0, npad0), (0, npad1), (0, 0)], mode="constant")

    # Patches of patch_size x patch_size x 3 taken every stride_size pixels,
    # then flattened into one batch axis.
    arr_out = extract_patches(io,(patch_size,patch_size,3),stride_size)
    arr_out_shape = arr_out.shape
    arr_out = arr_out.reshape(-1,patch_size,patch_size,3)

    # Collect per-batch outputs in a list and concatenate once; the leading
    # empty float64 array keeps the result shape/dtype of the accumulated
    # output even when there are no batches.
    batch_outputs = [np.zeros((0,checkpoint["num_classes"]))]
    # Inference only: no_grad avoids building the autograd graph for each batch.
    with torch.no_grad():
        for batch_arr in divide_batch(arr_out,batch_size):

            # NHWC uint8 -> NCHW float in [0, 1]
            arr_out_gpu = torch.from_numpy(batch_arr.transpose(0, 3, 1, 2) / 255).type('torch.FloatTensor').to(device)

            output_batch = model(arr_out_gpu)

            output_batch = output_batch.detach().cpu().numpy()
            batch_outputs.append(output_batch)

    output = np.concatenate(batch_outputs, axis=0)

    # Per-patch predicted class, then a count of patches per class.
    tileclass = np.argmax(output, axis=1)
    predc,predccounts=np.unique(tileclass, return_counts=True)
    for c,cc in zip(predc,predccounts):
        print(f"class/count: \t{c}\t{cc}")

    print(f"predicted class:\t{predc[np.argmax(predccounts)]}")