In [None]:
from __future__ import absolute_import, division, print_function, unicode_literals
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, LinearSegmentedColormap
%matplotlib inline
import numpy as np
import pandas as pd
import os, shutil, glob, sys, math, cv2, re

import segmentation_models_pytorch as smp
import albumentations as albu
from torchsummary import summary

from tqdm import tqdm

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from collections import OrderedDict
from torch.utils.model_zoo import load_url as load_state_dict_from_url
from torch import Tensor
from torch.jit.annotations import List
from torchvision import models

from torch.utils.data import DataLoader
from torch.utils.data import Dataset as BaseDataset

from sklearn.model_selection import train_test_split

# model

In [None]:
# Fine-tuned tumor / non-tumor ResNet-18 checkpoint (a torch state_dict despite the .h5 name).
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
MODEL_CHECKPOINT = "2020_09_22_18_52_46_Resnet18-tumor-or-nonTumor_Dataset_Zenodo.h5"

# Build the architecture only: ImageNet weights are not needed because the strict
# load_state_dict below overwrites every parameter and buffer anyway, and skipping
# them avoids a network download.
model = models.resnet18(pretrained=False)
# Replace the 1000-way ImageNet head with the 2-way head stored in the checkpoint.
model.fc = nn.Linear(in_features=512, out_features=2, bias=True)
# Load onto CPU first so the checkpoint does not depend on the GPU it was saved from.
model.load_state_dict(torch.load(MODEL_CHECKPOINT, map_location="cpu"))
model.cuda()
# Evaluation mode: batch-norm layers use their stored running statistics.
# Trailing ';' suppresses the (very long) module repr in the notebook output.
model.eval();

# dataset

In [None]:
def get_augmentation(height=224, width=224):
    """Build the test-time augmentation pipeline: a fixed-size resize.

    Every patch is resized to ``height`` x ``width`` (224 x 224 by default) so
    that patches of any source size can be stacked into batches by the
    DataLoader. No padding or random augmentation is applied.

    Args:
        height: Target image height in pixels.
        width: Target image width in pixels.

    Returns:
        An ``albu.Compose`` transform, called as ``transform(image=...)``.
    """
    test_transform = [
        albu.Resize(height=height, width=width, always_apply=True),
    ]
    return albu.Compose(test_transform)


def to_tensor(x, **kwargs):
    """Reorder an HWC image array to CHW and cast it to float32.

    Extra keyword arguments are accepted and ignored so the function can be
    registered as an albumentations ``Lambda`` target.
    """
    channels_first = x.transpose((2, 0, 1))
    return channels_first.astype(np.float32)

# Mirrors the divide-by-255 scaling done in torchvision's to_tensor, see
# https://github.com/pytorch/vision/blob/master/torchvision/transforms/functional.py
def to0_1(x, **kwargs):
    """Divide ``x`` by 255, mapping 8-bit pixel values into [0, 1].

    Extra keyword arguments are accepted and ignored so the function can be
    registered as an albumentations ``Lambda`` target.
    """
    scaled = x / 255
    return scaled

def get_preprocessing():
    """Build the preprocessing pipeline applied after augmentation.

    Step 1 (``to_tensor``) turns the HWC image into a CHW float32 array; step 2
    (``to0_1``) rescales it by 1/255. The same functions are registered for
    masks as well, although ``Dataset.__getitem__`` only passes ``image=``.
    """
    to_chw = albu.Lambda(image=to_tensor, mask=to_tensor)
    rescale = albu.Lambda(image=to0_1, mask=to0_1)
    return albu.Compose([to_chw, rescale])

In [None]:
class Dataset(BaseDataset):
    """Map-style dataset that reads image patches from disk for inference.

    Each item is a ``(file_path, image)`` tuple so every prediction can be
    traced back to the patch file it came from.

    Args:
        image_array: Sequence of image file paths readable by ``cv2.imread``.
        augmentation: Optional callable applied first, invoked as
            ``augmentation(image=...)`` and expected to return a dict with an
            ``'image'`` key (albumentations convention).
        preprocessing: Optional callable with the same convention, applied
            after ``augmentation``.
    """

    def __init__(self, image_array, augmentation=None, preprocessing=None):
        self.image_array = image_array
        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def __getitem__(self, i):
        fp = self.image_array[i]

        image = cv2.imread(fp)
        # cv2.imread reports a missing/unreadable file by returning None rather
        # than raising; fail here with the path instead of an opaque cv2.error
        # from cvtColor (which, inside DataLoader workers, hides the bad file).
        if image is None:
            raise OSError("Could not read image file: {}".format(fp))
        # OpenCV decodes to BGR channel order; convert to RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # apply augmentations
        if self.augmentation:
            sample = self.augmentation(image=image)
            image = sample['image']

        # apply preprocessing
        if self.preprocessing:
            sample = self.preprocessing(image=image)
            image = sample['image']

        return fp, image

    def __len__(self):
        return len(self.image_array)

In [None]:
# Root directory of the patch folders. Each entry listed here is later globbed
# for *.jpg files (presumably one sub-folder per slide — TODO confirm).
patches_folder = '/nfs/Shared/data/tcga/patches'
all_folders = sorted(entry for entry in os.listdir(patches_folder))

In [None]:
# Classify every patch folder and save the paths predicted as tumor, one .npy per folder.
tumor_output_folder = '/nfs/Shared/data/tcga/tumor'
# np.save does not create parent directories; make sure the output folder exists.
os.makedirs(tumor_output_folder, exist_ok=True)

BATCH_SIZE = 32
NUM_WORKERS = 32

# The transform pipelines do not depend on the folder, so build them once.
augmentation = get_augmentation()
preprocessing = get_preprocessing()

for number, folder_name in enumerate(all_folders):

    all_images = sorted(glob.glob(os.path.join(patches_folder, folder_name, "*.jpg")))

    test_dataset = Dataset(
        all_images,
        augmentation=augmentation,
        preprocessing=preprocessing
    )
    test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

    # Collect names in a plain list: concatenating strings onto np.array([])
    # (a float64 array) relies on deprecated number->string promotion and
    # re-copies the whole array on every batch.
    tumor_names = []
    with torch.no_grad():
        for image_names, images in test_dataloader:
            logits = model(images.cuda())
            # softmax is monotonic, so argmax over the raw logits gives the same
            # class as argmax over probabilities; result shape is (batch,).
            pred = torch.argmax(logits, dim=1).cpu().numpy()
            # Class index 1 is treated as "tumor" — assumes the checkpoint's
            # label order; TODO confirm against the training code.
            tumor_names.extend(name for name, label in zip(image_names, pred) if label == 1)

    # dtype=str keeps the saved array a string array even when no patch is tumor.
    is_tumor_image_name = np.array(tumor_names, dtype=str)
    print("{}, all_images: {}, tumor_images: {}".format(number, len(all_images), len(is_tumor_image_name)))
    np.save(os.path.join(tumor_output_folder, "{}.npy".format(folder_name)), is_tumor_image_name)


In [None]:
# Sanity check: total number of patches classified as tumor across all saved folders.
npy = glob.glob("/nfs/Shared/data/tcga/tumor/*.npy")
c = sum(len(np.load(path)) for path in npy)
# Show the total as the cell's output (the original cell computed it but never displayed it).
c