In [None]:
import segmentation_models_pytorch as smp
import openslide
import torch
from torch.utils.data import DataLoader,Dataset
from torchvision.transforms import ToTensor
import torch.nn.functional as F
import pytorch_model_summary
import pyvips
import matplotlib.pyplot as plt
import numpy as np
import cv2
from glob import glob
from tqdm import tqdm
import os
# Run on the first CUDA GPU when available, otherwise on the CPU. Both branches
# produce a torch.device so attributes such as device.type work either way.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Model class index -> name(s). Only len(class_list) is used here, to size the
# U-Net output layer; color_map below uses these indices shifted by +1.
class_list = {
    0: ['stroma'],
    1: ['immune'],
    2: ['Normal'],
    3: ['Tumor'],
}
model_path = '../../model/NIPA/best_seg_ST_class.pt'

# encoder_weights=None: every parameter is replaced by the checkpoint loaded
# below, so fetching the ImageNet encoder weights first is wasted work (and
# needs network access when they are not cached).
model = smp.Unet(
    encoder_name="efficientnet-b7",
    encoder_weights=None,
    in_channels=3,
    classes=len(class_list),
)
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted
# checkpoints (on torch>=1.13 consider weights_only=True for a plain state_dict).
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
model.eval()

# Layer summary for a single 1024x1024 RGB input. no_grad avoids keeping
# activations for backprop during this dummy forward pass.
with torch.no_grad():
    pytorch_model_summary.summary(model, torch.zeros(1, 3, 1024, 1024).to(device), print_summary=True)

In [None]:

# Label -> RGB colour for the predicted mask. Label 0 is background; labels
# 1..4 are the model's class indices 0..3 (class_list) shifted by +1.
color_map = {
    0: [255, 255, 255],  # background - white
    1: [255, 0, 0],      # stroma - red
    2: [0, 255, 0],      # immune - green
    3: [0, 0, 255],      # Normal - blue
    4: [255, 255, 0],    # Tumor - yellow
}

THUMB_DOWNSAMPLE = 64    # thumbnail is requested at 1/64 of the OpenSlide dimensions
MODEL_INPUT_SIZE = 1024  # spatial size the patch is resized to before inference

# glob() returns files in file-system order; sort so that index i selects the
# same slide on every machine and every run.
slide_path = sorted(glob('../../data/BR_HnE/*.ndpi'))
assert slide_path, "no .ndpi slides found under ../../data/BR_HnE/"
i = 0
slide_image = openslide.OpenSlide(slide_path[i])
# NOTE(review): 2024 may be a typo for 2048 (an exact 2x downsample to
# MODEL_INPUT_SIZE) — confirm against the training pipeline.
image_size = 2024
thumbnail = slide_image.get_thumbnail(
    (slide_image.dimensions[0] // THUMB_DOWNSAMPLE, slide_image.dimensions[1] // THUMB_DOWNSAMPLE)
)
# Same slide opened with pyvips for fast region reads; assumes it opens level 0,
# i.e. the same coordinate frame as slide_image.dimensions — TODO confirm.
slide = pyvips.Image.new_from_file(slide_path[i])

# Foreground mask on the thumbnail: invert the green channel so darker
# (presumably tissue) pixels become bright, threshold at 30, then close small
# holes (15x15) and remove small specks (5x5).
thumb_mask = cv2.threshold(255 - np.array(thumbnail)[:, :, 1], 30, 255, cv2.THRESH_BINARY)[1]
thumb_mask = cv2.morphologyEx(thumb_mask, cv2.MORPH_CLOSE, np.ones((15, 15), np.uint8))
thumb_mask = cv2.morphologyEx(thumb_mask, cv2.MORPH_OPEN, np.ones((5, 5), np.uint8))
# Count of image_size tiles per axis minus one, multiplied (formula kept as-is).
total_patches = (slide.width // image_size - 1) * (slide.height // image_size - 1)

# Demo patch: thumbnail pixel (x=300, y=500) mapped back to full-resolution coordinates.
patch_x, patch_y = 300 * THUMB_DOWNSAMPLE, 500 * THUMB_DOWNSAMPLE
if patch_x + image_size > slide.width or patch_y + image_size > slide.height:
    raise ValueError(
        f"demo patch at ({patch_x}, {patch_y}) of size {image_size} "
        f"exceeds slide bounds {slide.width}x{slide.height}"
    )
patch = slide.crop(patch_x, patch_y, image_size, image_size)
# Wrap the raw pixel buffer; assumes 8-bit bands (uchar) — TODO confirm format.
patch = np.ndarray(
    buffer=patch.write_to_memory(),
    dtype=np.uint8,
    shape=[patch.height, patch.width, patch.bands],
)
# Keep the first three bands (drops any alpha) and resize to the model input size.
patch = cv2.resize(patch[:, :, :3], (MODEL_INPUT_SIZE, MODEL_INPUT_SIZE))

In [None]:
# Segment the demo patch and visualise the result.
# Label convention (matches color_map): 0 = background, meaning no class has a
# softmax probability above CONFIDENCE_THRESHOLD; k + 1 = class k of class_list.
CONFIDENCE_THRESHOLD = 0.5

torch_patch = ToTensor()(patch).unsqueeze(0).to(device)  # (1, 3, H, W), float in [0, 1]
with torch.no_grad():
    output = model(torch_patch)               # (1, n_classes, H, W)
    probs = F.softmax(output, dim=1)
    max_prob, best_class = probs.max(dim=1)   # each (1, H, W)

# Decide confidence on the raw probabilities. Binarising first and then taking
# argmax turns "no confident class" pixels into all-zero vectors whose argmax is
# 0, so they were drawn as stroma and background (label 0) never appeared.
max_prob = max_prob.cpu().numpy()
best_class = best_class.cpu().numpy()
pr_mask = np.where(max_prob > CONFIDENCE_THRESHOLD, best_class + 1, 0)  # (1, H, W)

# Colour the label map; size follows the prediction instead of a fixed 1024.
pr_mask_rgb = np.zeros(pr_mask.shape[1:] + (3,), dtype=np.uint8)
for k, v in color_map.items():
    pr_mask_rgb[pr_mask[0] == k] = v

fig, axes = plt.subplots(1, 3, figsize=(12, 8))
panels = [
    (patch, 'H&E patch'),
    (pr_mask_rgb, 'Predicted mask'),
    # 50/50 blend as floats in [0, 1]; cmap/vmin/vmax would be ignored for RGB data.
    (patch / 255 * 0.5 + pr_mask_rgb / 255 * 0.5, 'Predicted classes'),
]
for ax, (img, title) in zip(axes, panels):
    ax.imshow(img)
    ax.set_title(title)
    ax.axis('off')

# Legend mapping mask colours to labels; swatches get a black edge so the white
# background entry stays visible.
label_names = {0: 'background', **{k + 1: names[0] for k, names in class_list.items()}}
handles = [
    plt.Rectangle((0, 0), 1, 1, facecolor=np.array(rgb) / 255, edgecolor='black',
                  label=label_names.get(k, str(k)))
    for k, rgb in color_map.items()
]
fig.legend(handles=handles, loc='lower center', ncol=len(handles))
plt.show()

In [None]:
# 50/50 blend of the H&E patch and the colour mask, kept as an image-ready
# uint8 array. Summing in uint16 avoids uint8 overflow before halving. Only the
# shape and dtype are displayed, not a repr of the full-resolution array.
overlay = ((patch.astype(np.uint16) + pr_mask_rgb) // 2).astype(np.uint8)
overlay.shape, overlay.dtype