In [None]:
import segmentation_models_pytorch as smp
import openslide
import torch
from torch.utils.data import DataLoader,Dataset
from torchvision.transforms import ToTensor
import torch.nn.functional as F

import pyvips
import matplotlib.pyplot as plt
import numpy as np
import cv2
from glob import glob
from tqdm import tqdm
import os
from PIL import Image,ImageOps
import json
# Always a torch.device, on GPU and CPU alike, so code that inspects
# `device.type` / `device.index` behaves the same on every machine.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:

# Class index -> [label]. The number of entries sets the number of output
# channels of the segmentation head; index 0 is the background class.
class_list = {
    0: ["Background"],
    1: ["NT_stroma"],
    2: ["NT_epithelial"],
    3: ["NT_immune"],
    4: ["Tumor"],
    5: ["TP_invasive"],
    6: ["TP_in_situ"],
}

model_path = '../../model/areaSeg/breast/ST_callback.pt'

# encoder_weights=None: every parameter and buffer is replaced by the strict
# load_state_dict() below, so fetching the ImageNet encoder weights first would
# only add a pointless download (and break offline runs).
model = smp.DeepLabV3Plus(
    encoder_name="efficientnet-b7",
    encoder_weights=None,
    in_channels=3,
    classes=len(class_list),
)
# NOTE: torch.load unpickles the checkpoint; only load files from a trusted source.
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
# Inference mode: freezes batch-norm statistics and disables dropout.
model.eval()

def Image2RGBA(src_img, slide):
    """Convert a colour-coded image into a semi-transparent RGBA overlay.

    The image is first padded with white on its right and bottom edges by
    ``(slide dimension - image dimension * 32) // 32`` pixels, then converted
    to RGBA.

    Args:
        src_img: PIL image to convert.
        slide: Object exposing ``width`` and ``height`` attributes.

    Returns:
        A PIL ``RGBA`` image in which pure-white pixels are fully transparent
        (alpha 0) and every other pixel has alpha 75 (about 29 % opaque).
    """
    src_w, src_h = src_img.size
    pad_right = (slide.width - src_w * 32) // 32
    pad_bottom = (slide.height - src_h * 32) // 32
    padded = ImageOps.expand(
        src_img,
        border=(0, 0, pad_right, pad_bottom),
        fill=(255, 255, 255),
    )

    # Work on a numpy copy so the alpha channel can be set per pixel.
    rgba = np.array(padded.convert('RGBA'))

    # A pixel is "white" only when all three colour channels equal 255.
    is_white = np.all(rgba[:, :, :3] == 255, axis=-1)

    # White -> invisible; anything else -> alpha 75.
    rgba[:, :, 3] = np.where(is_white, 0, 75)

    return Image.fromarray(rgba, 'RGBA')

In [None]:
# RGB colour painted for each predicted class index.
color_map = {
    0: (255, 255, 255),    # Background - White
    1: (0, 255, 0),        # NT_stroma - Green
    2: (0, 0, 255),        # NT_epithelial - Blue
    3: (255, 255, 0),      # NT_immune - Yellow
    4: (255, 0, 0),        # Tumor - Red
    5: (255, 165, 0),      # TP_invasive - Orange
    6: (128, 0, 128),      # TP_in_situ - Purple
}
# Number of full-resolution pixels shared by neighbouring tiles.
overlap = 128
# Output folder; the trailing slash matters because file names are appended to it.
save_path = '../../results/BR_HnE/'
# glob() returns matches in arbitrary order; sort so slides are processed in the
# same order on every run.
slide_path = sorted(glob('../../data/BR_HnE/*.ndpi'))
# Edge length used to turn predicted pixel counts into areas.
# NOTE(review): presumably a 0.226 um/px scan resolution times the 2x
# down-sampling applied before prediction - confirm against the slide metadata.
mpp = 0.226 * 2
def create_directory(path):
    """Create ``path`` together with any missing parent directories.

    Calling it for a directory that already exists is a no-op.

    Args:
        path: Directory to create.

    Raises:
        FileExistsError: If ``path`` exists but is not a directory.
    """
    # exist_ok=True closes the gap between "check" and "create" in which
    # another process could create the directory and make makedirs() fail.
    os.makedirs(path, exist_ok=True)
# The output folder has to exist before anything is written into it.
create_directory(save_path)

# Class name -> RGB triple (lists, so the mapping serialises directly to JSON).
# The background class is not included.
color_map_name = dict(
    NT_stroma=[0, 255, 0],        # Green
    NT_epithelial=[0, 0, 255],    # Blue
    NT_immune=[255, 255, 0],      # Yellow
    Tumor=[255, 0, 0],            # Red
    TP_invasive=[255, 165, 0],    # Orange
    TP_in_situ=[128, 0, 128],     # Purple
)

# Store the legend once, alongside the per-slide results.
with open(save_path + 'class_color.json', 'w') as legend_file:
    json.dump(color_map_name, legend_file, indent=4)
# --- Tiling geometry (identical for every slide) ------------------------------
predict_size = 1024                  # edge of the model input, in pixels
# A tile is read at full resolution and shrunk to predict_size. It has to be
# exactly 2 * predict_size because the stitched mask below is allocated at half
# the slide resolution (the former value 2024 squeezed 2024 px into a mask slot
# that stands for 2048 px and skewed both the overlay and the areas).
image_size = 2 * predict_size
tile_stride = image_size - overlap   # full-resolution step between neighbouring tiles
mask_stride = tile_stride // 2       # the same step measured in mask pixels
margin = overlap // 4                # leading rows/cols of each prediction that are discarded
thumb_scale = 128                    # down-sampling factor of the tissue-detection thumbnail
overlay_scale = 32                   # down-sampling factor of the saved overlay PNG

n_classes = len(class_list)
# Look-up table: class index -> RGB colour.
palette = np.array([color_map[c] for c in range(n_classes)], dtype=np.uint8)
to_tensor = ToTensor()

for path in slide_path:
    # splitext keeps dots that are part of the name ("case.1.ndpi" -> "case.1").
    slide_name = os.path.splitext(os.path.basename(path))[0]

    # OpenSlide is only needed for the thumbnail; the context manager closes the
    # handle so file descriptors do not pile up over many slides.
    with openslide.OpenSlide(path) as slide_image:
        full_w, full_h = slide_image.dimensions
        thumbnail = slide_image.get_thumbnail((full_w // thumb_scale, full_h // thumb_scale))
    slide = pyvips.Image.new_from_file(path)

    # Tissue detection on the inverted green channel: Otsu threshold, then
    # closing to fill small holes and opening to remove specks.
    thumb_mask = cv2.threshold(255 - np.array(thumbnail)[:, :, 1], 30, 255, cv2.THRESH_OTSU)[1]
    thumb_mask = cv2.morphologyEx(thumb_mask, cv2.MORPH_CLOSE, np.ones((15, 15), np.uint8))
    thumb_mask = cv2.morphologyEx(thumb_mask, cv2.MORPH_OPEN, np.ones((5, 5), np.uint8))

    # Half-resolution class-index mask (0 = background). Storing indices instead
    # of RGB needs a third of the memory; colours are applied after shrinking.
    label_mask = np.zeros((slide.height // 2, slide.width // 2), dtype=np.uint8)
    # Pixels per class that end up in the stitched mask.
    class_pixels = np.zeros(n_classes, dtype=np.int64)

    # Number of tiles that fit entirely inside the slide. (The earlier bound
    # `dimension // stride - 1` could drop a final row/column that still fits.)
    # NOTE(review): a trailing strip narrower than one tile is still left as
    # background.
    n_rows = max(0, (slide.height - image_size) // tile_stride + 1)
    n_cols = max(0, (slide.width - image_size) // tile_stride + 1)

    for row in tqdm(range(n_rows), desc=slide_name):
        top = row * mask_stride
        for col in range(n_cols):
            left = col * mask_stride
            # View on the part of the mask this tile is responsible for.
            target = label_mask[top + margin:top + predict_size,
                                left + margin:left + predict_size]

            tissue = thumb_mask[row * tile_stride // thumb_scale:(row + 1) * tile_stride // thumb_scale,
                                col * tile_stride // thumb_scale:(col + 1) * tile_stride // thumb_scale]
            if not tissue.any():
                # No tissue: mark as background (also clears the overlap a
                # previous tile may have written here) and skip inference.
                target[...] = 0
                continue

            patch = slide.crop(col * tile_stride, row * tile_stride, image_size, image_size)
            patch = np.ndarray(buffer=patch.write_to_memory(),
                               dtype=np.uint8,
                               shape=[patch.height, patch.width, patch.bands])
            # Drop a possible alpha band and shrink to the model input size.
            patch = cv2.resize(patch[:, :, :3], (predict_size, predict_size),
                               interpolation=cv2.INTER_NEAREST)
            torch_patch = to_tensor(patch).unsqueeze(0).to(device)
            with torch.no_grad():
                output = model(torch_patch)
            # argmax of the logits equals argmax of their softmax; taking it on
            # the device avoids copying the full probability tensor to the host.
            labels = output.argmax(dim=1)[0].to(torch.uint8).cpu().numpy()

            target[...] = labels[margin:, margin:]

            # Count only the pixels that no later tile overwrites, so overlaps
            # are not counted twice. The last row/column keeps its tail because
            # nothing is stitched over it.
            row_end = predict_size if row == n_rows - 1 else margin + mask_stride
            col_end = predict_size if col == n_cols - 1 else margin + mask_stride
            class_pixels += np.bincount(labels[margin:row_end, margin:col_end].ravel(),
                                        minlength=n_classes)

    # Pixel counts -> areas; `mpp` is the edge length of one mask pixel.
    pixel_area = mpp * mpp
    mpp_size = {class_list[cls][0]: float(class_pixels[cls] * pixel_area)
                for cls in range(1, n_classes)}
    mpp_size["total"] = float(class_pixels[1:].sum() * pixel_area)

    # Shrink first (nearest neighbour keeps labels intact), then colourise.
    label_small = cv2.resize(label_mask,
                             (slide.width // overlay_scale, slide.height // overlay_scale),
                             interpolation=cv2.INTER_NEAREST)
    predict_mask = palette[label_small]
    Image2RGBA(Image.fromarray(predict_mask), slide).save(f'{save_path}{slide_name}.png')
    with open(f'{save_path}{slide_name}.json', 'w') as f:
        json.dump(mpp_size, f, indent=4)
    