In [1]:
import random
import openslide
import os
import torch
import pandas as pd
import numpy as np
from tqdm.notebook import tqdm
from urllib.request import urlopen
from PIL import Image
import timm
from torch import Tensor
from torchvision import transforms
from torch.utils.data import DataLoader,Dataset
from glob import glob
from sklearn.model_selection import train_test_split
import torch.nn as nn
import random
import torchmetrics
from torch.nn.modules.batchnorm import _BatchNorm
import matplotlib.pyplot as plt
import torch.nn.functional as F
import pyvips
import json
import cv2
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [2]:
# Tissue class name -> integer label; the position in this list is the label index.
# NOTE(review): presumably these are the indices the checkpoint was trained with — confirm.
class_list = {
    class_name: class_index
    for class_index, class_name in enumerate(
        ['background', 'Tumor', 'Stroma', 'Immune cells', 'Necrosis', 'alveoli', 'Other']
    )
}

In [3]:
class FeatureExtractor(nn.Module):
    """Feature extractor: a timm CNN backbone with its last top-level child removed.

    Args:
        model_name: timm model identifier of the backbone.
        pretrained: whether timm should load its pretrained weights. When a full
            fine-tuned checkpoint is loaded afterwards these weights are overwritten,
            so ``pretrained=False`` skips an unnecessary download in that case.
    """
    def __init__(self, model_name='tf_efficientnetv2_s', pretrained=True):
        super(FeatureExtractor, self).__init__()
        backbone = timm.create_model(model_name, pretrained=pretrained)
        # Drop the final child (the classification head) and keep everything before it.
        # NOTE(review): for timm EfficientNets the last remaining child is the global
        # pooling layer, so the output is expected to be (batch, num_features) — confirm.
        self.feature_ex = nn.Sequential(*list(backbone.children())[:-1])

    def forward(self, inputs):
        """Return backbone features for ``inputs`` (assumed (batch, 3, H, W) images)."""
        return self.feature_ex(inputs)
class custom_model(nn.Module):
    """Feature extractor followed by a single linear classification layer.

    Args:
        num_classes: number of output classes (length of the logit vector).
        image_feature_dim: number of features the extractor yields per image.
        feature_extractor_scale1: module mapping (batch, C, H, W) images to features,
            e.g. a ``FeatureExtractor`` whose classification head was removed.
    """
    def __init__(self, num_classes, image_feature_dim, feature_extractor_scale1: "FeatureExtractor"):
        super(custom_model, self).__init__()
        self.num_classes = num_classes
        self.image_feature_dim = image_feature_dim

        # Backbone with its classification head already removed.
        self.feature_extractor = feature_extractor_scale1
        # Classification layer on top of the extracted features.
        self.classification_layer = nn.Linear(image_feature_dim, num_classes)

    def forward(self, inputs):
        """Return class logits of shape (batch, num_classes).

        Raises:
            ValueError: if ``inputs`` is not a 4-D (batch, C, H, W) tensor.
        """
        if inputs.dim() != 4:
            raise ValueError(
                f"expected a 4-D (batch, C, H, W) tensor, got shape {tuple(inputs.shape)}"
            )

        features = self.feature_extractor(inputs)
        # Accept extractors that return (batch, C) as well as (batch, C, 1, 1);
        # flattening is a no-op for already-flat features.
        features = torch.flatten(features, start_dim=1)

        return self.classification_layer(features)  # (batch, num_classes)

# Build the classifier (one output per entry of class_list, 1280-d backbone
# features) and load the fine-tuned checkpoint.
CHECKPOINT_PATH = '../../model/IHC_Tissue_Region_classification/check.pt'

Feature_Extractor = FeatureExtractor()
model = custom_model(len(class_list), 1280, Feature_Extractor).to(device)
# This notebook only runs inference, so put the model in eval mode once here.
model.eval()
# map_location: a checkpoint saved on the GPU can still be loaded on a CPU-only machine.
# weights_only: do not unpickle arbitrary objects; the file is a plain state_dict.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device, weights_only=True))

<All keys matched successfully>

In [None]:
# --- Inference configuration ---------------------------------------------------
# glob() returns files in filesystem-dependent order; sort so slides are processed
# (and indexed, e.g. slide_path[0]) in the same order on every run.
slide_path = sorted(glob('../../data/WSI_test/BR/*KI*.tiff'))
image_size = 256        # side (px) of one classified patch at full resolution
image_resize = 128      # side (px) each patch is resized to before it is fed to the model
label_img_resize = 8    # side (px) each patch occupies in the output label map
count = 0               # NOTE(review): not used by any visible cell
pixel_size = 0.25       # microns per pixel (mpp); assumed to match the scanned slides
# Area (µm²) covered by one image_size x image_size patch.
pixel_extend = pixel_size * pixel_size * image_size * image_size

# RGB colour used to paint each predicted class index in the label map.
# Necrosis (4) and Other (6) are painted white, i.e. the same as background.
color_map = {
    0: [255, 255, 255],  # background - white
    1: [255, 0, 0],      # Tumor - red
    2: [0, 255, 0],      # Stroma - green
    3: [0, 0, 255],      # Immune cells - blue
    4: [255, 255, 255],  # Necrosis - white (not shown)
    5: [0, 255, 255],    # alveoli - cyan
    6: [255, 255, 255],  # Other - white (not shown)
}

# Template of accumulated tissue areas in µm²; 'total' counts every non-background patch.
label_size = {
    'total': 0,
    'Tumor': 0,
    'Stroma': 0,
    'Immune cells': 0,
    'alveoli': 0,
}

# --- Whole-slide tissue classification ------------------------------------------
# Each slide is read in temp_patch_size x temp_patch_size tiles. Every tile is cut
# into a temp_patch_count x temp_patch_count grid of image_size patches, which are
# classified in one batch. Per slide this writes a down-scaled colour label map (PNG)
# and the accumulated area per class (JSON). Partial tiles at the right/bottom edge
# are not processed.
temp_patch_size = 2048
temp_patch_count = temp_patch_size // image_size  # patches per tile side
RESULT_DIR = '../../results/BR/WSI_classification/'

# uint8 HWC image -> float CHW tensor in [-1, 1].
trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Predicted class index -> area counter in label_size.
# background (0), Necrosis (4) and Other (6) have no counter of their own.
AREA_KEY_BY_CLASS = {1: 'Tumor', 2: 'Stroma', 3: 'Immune cells', 5: 'alveoli'}


def create_dir(path):
    """Create ``path`` (including parent directories) if it does not exist yet."""
    os.makedirs(path, exist_ok=True)


def read_tile(slide, row, col):
    """Return tile (row, col) of a pyvips image as an (H, W, bands) uint8 array.

    Only this tile is decoded, so memory use is bounded by one tile rather than
    the whole slide. The pixel buffer is interpreted as uint8 (8-bit slide assumed).
    """
    tile = slide.crop(col * temp_patch_size, row * temp_patch_size,
                      temp_patch_size, temp_patch_size)
    return np.ndarray(buffer=tile.write_to_memory(),
                      dtype=np.uint8,
                      shape=[tile.height, tile.width, tile.bands])


def process_large_patch(patch_data):
    """Classify every image_size x image_size patch of one tile in a single batch.

    Args:
        patch_data: ``(row, col, img)`` - tile grid position and an (H, W, bands)
            uint8 array; only the first three bands are used.

    Returns:
        ``(row, col, predictions, positions)`` where ``predictions[k]`` is the
        predicted class index of the patch at ``positions[k] == (detail_row, detail_col)``,
        enumerated in row-major order.
    """
    row, col, img = patch_data
    img_tensor = trans(img[:, :, :3]).to(device)  # (C, H, W)
    channels = img_tensor.shape[0]

    # (C, n*s, n*s) -> (n*n, C, s, s), patch index running row-major over the grid.
    patches_batch = (
        img_tensor
        .reshape(channels, temp_patch_count, image_size, temp_patch_count, image_size)
        .permute(1, 3, 0, 2, 4)
        .reshape(-1, channels, image_size, image_size)
    )
    # Interpolation acts per sample, so resizing the batch equals resizing each patch.
    patches_batch = F.interpolate(patches_batch, size=(image_resize, image_resize),
                                  mode='bilinear', align_corners=False)
    positions = [(detail_row, detail_col)
                 for detail_row in range(temp_patch_count)
                 for detail_col in range(temp_patch_count)]

    model.eval()
    with torch.no_grad():
        predictions = model(patches_batch).argmax(dim=1).cpu().numpy()

    return row, col, predictions, positions


create_dir(RESULT_DIR)

for slide_file in tqdm(slide_path):
    file_name = os.path.splitext(os.path.basename(slide_file))[0]
    slide = pyvips.Image.new_from_file(slide_file)
    n_tile_rows = slide.height // temp_patch_size
    n_tile_cols = slide.width // temp_patch_size
    if n_tile_rows == 0 or n_tile_cols == 0:
        print(f"{file_name}: smaller than one {temp_patch_size}px tile, skipped")
        continue

    # Fresh area counters for every slide, so each JSON reports only that slide.
    label_size = dict.fromkeys(label_size, 0)

    # Label map: every patch is painted as a label_img_resize x label_img_resize square.
    label_img = np.zeros((label_img_resize * temp_patch_count * n_tile_rows,
                          label_img_resize * temp_patch_count * n_tile_cols, 3), dtype=np.uint8)

    tile_positions = [(row, col) for row in range(n_tile_rows) for col in range(n_tile_cols)]
    print(f"총 {len(tile_positions)}개의 큰 패치를 처리합니다...")

    # Tiles are processed sequentially so only one batch is on the GPU at a time.
    for row, col in tqdm(tile_positions, desc="Processing patches"):
        _, _, predictions, positions = process_large_patch((row, col, read_tile(slide, row, col)))

        for pred, (detail_row, detail_col) in zip(predictions, positions):
            pred = int(pred)
            y_start = (row * temp_patch_count + detail_row) * label_img_resize
            x_start = (col * temp_patch_count + detail_col) * label_img_resize
            label_img[y_start:y_start + label_img_resize,
                      x_start:x_start + label_img_resize] = color_map[pred]

            if pred != 0:
                label_size['total'] += pixel_extend
            area_key = AREA_KEY_BY_CLASS.get(pred)
            if area_key is not None:
                label_size[area_key] += pixel_extend

    print("세그멘테이션 완료. 결과를 저장합니다...")
    Image.fromarray(label_img).save(f'{RESULT_DIR}{file_name}_segmentation.png')
    with open(f'{RESULT_DIR}{file_name}_segmentation_size.json', 'w') as size_file:
        json.dump(label_size, size_file)

  0%|          | 0/5 [00:00<?, ?it/s]

Exception ignored from cffi callback <function _log_handler_callback at 0x7fd4be197e20>:
Traceback (most recent call last):
  File "/home/work/.local/lib/python3.12/site-packages/pyvips/__init__.py", line 149, in _log_handler_callback
    def _log_handler_callback(domain, level, message, user_data):

KeyboardInterrupt: 


In [None]:
# Inspect the number of patches per tile side defined by the segmentation cell above.
temp_patch_count

8

In [None]:
# Inspect the first slide file matched by the glob in the configuration cell.
slide_path[0]