In [None]:
import segmentation_models_pytorch as smp
import openslide
import torch
from torch.utils.data import DataLoader,Dataset
from torchvision.transforms import ToTensor
import torch.nn.functional as F

import pyvips
import matplotlib.pyplot as plt
import numpy as np
import cv2
from glob import glob
from tqdm import tqdm
import os
from PIL import Image
# Use the first GPU when available. Always a torch.device (never a bare string),
# so attributes like `device.type` work the same on CPU-only hosts.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Class index -> class name. Below, only len(class_list) is used (number of output channels).
class_list = {
    0: ['stroma'],
    1: ['immune'],
    2: ['Normal'],
    3: ['Tumor'],
}
model_path = '../../model/NIPA/best_seg_ST_class.pt'

# encoder_weights=None: load_state_dict (strict by default) overwrites every parameter
# below, so downloading ImageNet weights first is wasted work and requires network access.
model = smp.Unet(
    encoder_name="efficientnet-b7",
    encoder_weights=None,
    in_channels=3,
    classes=len(class_list),
).to(device)
# NOTE(review): torch.load unpickles the checkpoint and can execute arbitrary code;
# only load trusted files (torch>=1.13 also accepts weights_only=True for plain state_dicts).
model.load_state_dict(torch.load(model_path, map_location=device))
# Trailing ';' suppresses the (very long) model repr in the cell output.
model.eval();

In [None]:

color_map = {
    0: [255, 255, 255],      # stroma - white
    1: [0, 255, 0],      # immune - green
    2: [0, 0, 255],      # Normal - blue
    3: [255, 0, 0],    # Tumor - red
    
}
overlap=64
slide_path=glob('../../data/BR_HnE/*.ndpi')
for i in range(len(slide_path)):
    slide_image=openslide.OpenSlide(slide_path[i])
    image_size=2024
    predict_size=1024
    thumbnail=slide_image.get_thumbnail((slide_image.dimensions[0]//128, slide_image.dimensions[1]//128))
    slide = pyvips.Image.new_from_file(slide_path[i])

    thumb_mask=cv2.threshold(255-np.array(thumbnail)[:,:,1],30,255,cv2.THRESH_OTSU)[1]
    thumb_mask=cv2.morphologyEx(thumb_mask,cv2.MORPH_CLOSE,np.ones((15,15),np.uint8))
    thumb_mask=cv2.morphologyEx(thumb_mask,cv2.MORPH_OPEN,np.ones((5,5),np.uint8))
    total_patches = (slide.width//image_size-1) * (slide.height//image_size-1)
    predict_mask = np.ones((1024*(slide.height//image_size), 1024*(slide.width//image_size),3), dtype=np.uint8)*255
    for row in tqdm(range(0,slide.height//((image_size-overlap)-1))):
        for col in range(0,slide.width//((image_size-overlap)-1)):
            if thumb_mask[row*(image_size-overlap)//128:(row+1)*(image_size-overlap)//128,col*(image_size-overlap)//128:(col+1)*(image_size-overlap)//128].sum()==0:
                predict_mask[row*(predict_size-overlap//2)+overlap//4:row*(predict_size-overlap//2)+overlap//4+predict_size-overlap//4,col*(predict_size-overlap//2)+overlap//4:col*(predict_size-overlap//2)+overlap//4+predict_size-overlap//4]=[255,255,255]
                continue
            patch = slide.crop(col*(image_size-overlap), row*(image_size-overlap), image_size, image_size)
            patch = np.ndarray(buffer=patch.write_to_memory(),
                                dtype=np.uint8,
                                shape=[patch.height, patch.width, patch.bands])
            patch=cv2.resize(patch[:,:,:3],(predict_size,predict_size),interpolation=cv2.INTER_NEAREST)
            torch_patch=ToTensor()(patch).unsqueeze(0).to(device)
            with torch.no_grad():
                output=model(torch_patch)
                pr_mask=F.softmax(output,dim=1)
                pr_mask = torch.where(pr_mask>0.3,1,0).cpu().numpy()
            pr_mask=np.argmax(pr_mask,axis=1)
            pr_mask_rgb=np.zeros((predict_size,predict_size,3),dtype=np.uint8)
            for k,v in color_map.items():
                pr_mask_rgb[pr_mask[0]==k]=v
            predict_mask[row*(predict_size-overlap//2)+overlap//4:row*(predict_size-overlap//2)+overlap//4+predict_size-overlap//4,col*(predict_size-overlap//2)+overlap//4:col*(predict_size-overlap//2)+overlap//4+predict_size-overlap//4]=pr_mask_rgb[overlap//4:,overlap//4:]

    predict_mask=cv2.resize(predict_mask,(slide.width//32,slide.height//32),interpolation=cv2.INTER_NEAREST)
    Image.fromarray(predict_mask).save(f'../../results/{os.path.basename(slide_path[i])[:-5]}.png')

In [None]:
# Sanity value: distance (prediction px) from a tile's write start to the next tile's origin.
# Equivalent to the old ((row+1)*s) - (row*s + overlap//4) with s = predict_size - overlap//2,
# but does not depend on the loop variable `row` leaking out of the inference loop.
(predict_size - overlap // 2) - overlap // 4