In [None]:
import cv2
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import torch
from glob import glob
from tqdm import tqdm
import yaml
from nets import nn
import torchvision
# Prefer the first CUDA GPU when one is visible; otherwise run everything on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("device:", device)

## Tumor-cell binary detection model

In [None]:
# Load the training config; params['names'] sets the detector's class count below.
# Explicit UTF-8: without it the platform locale encoding is used, and together with
# errors='ignore' non-ASCII class names would be silently mangled on non-UTF-8 systems.
with open('utils/args.yaml', encoding='utf-8', errors='ignore') as f:
    params = yaml.safe_load(f)
def collate_fn1(batch):
    """Collate ``(image, cls, box, idx)`` samples into one batched image tensor and a targets dict.

    Each sample's ``idx`` tensor is offset by the sample's position in the batch before
    concatenation, so every target row records which batch image it came from
    (presumably each sample's ``idx`` is all zeros — verify against the dataset).

    Returns:
        ``(torch.stack(images), {'cls': cat(cls), 'box': cat(box), 'idx': cat(offset idx)})``.
    """
    samples, cls, box, indices = zip(*batch)

    cls = torch.cat(cls, dim=0)
    box = torch.cat(box, dim=0)

    # Out-of-place add: `t += i` would modify the caller's per-sample tensors in place.
    indices = torch.cat([idx + i for i, idx in enumerate(indices)], dim=0)

    targets = {'cls': cls,
               'box': box,
               'idx': indices}
    return torch.stack(samples, dim=0), targets
def wh2xy(x):
    """Convert boxes from (centre_x, centre_y, width, height) to (x1, y1, x2, y2).

    Accepts a 2-D torch tensor or NumPy array whose first four columns hold the box;
    any further columns are copied through unchanged. The input is not modified —
    a converted copy of the same type and dtype is returned.
    """
    if isinstance(x, torch.Tensor):
        out = x.clone()
    else:
        out = np.copy(x)
    cx, cy = x[:, 0], x[:, 1]
    half_w, half_h = x[:, 2] / 2, x[:, 3] / 2
    out[:, 0] = cx - half_w  # left
    out[:, 1] = cy - half_h  # top
    out[:, 2] = cx + half_w  # right
    out[:, 3] = cy + half_h  # bottom
    return out
   
def non_max_suppression(outputs, confidence_threshold=0.001, iou_threshold=0.65, class_thresholds=None,
                        max_det=300, max_nms=30000, max_wh=7680):
    """Class-aware NMS with optional per-class confidence thresholds.

    Args:
        outputs: raw detector output indexed (batch, 4 + nc, candidates). Along dim 1 the
            first 4 rows are boxes in the centre/width/height form converted by ``wh2xy``;
            the remaining ``nc`` rows are class scores.
        confidence_threshold: minimum score (strictly greater) when ``class_thresholds`` is
            not given; also the fallback for classes missing from ``class_thresholds``.
        iou_threshold: IoU threshold for ``torchvision.ops.nms``.
        class_thresholds: optional ``{class_id: threshold}``; a detection is kept when its
            best-class score is >= that class's threshold.
        max_det: maximum detections returned per image.
        max_nms: maximum number of highest-scoring candidates fed to NMS per image.
        max_wh: offset multiplied by the class id and added to the boxes, so boxes of
            different classes never suppress each other.

    Returns:
        A list with one ``(n, 6)`` tensor per image: ``[x1, y1, x2, y2, score, class_id]``.
    """
    bs = outputs.shape[0]
    nc = outputs.shape[1] - 4

    # Cheap pre-filter with the loosest threshold in play; exact filtering happens per image.
    min_conf = confidence_threshold
    if class_thresholds:
        min_conf = min(min(class_thresholds.values()), confidence_threshold)
    xc = outputs[:, 4:4 + nc].amax(1) > min_conf

    # Per-class threshold table built once, so filtering is one vectorised comparison instead
    # of a Python loop doing an .item() (host/device sync) per detection. Classes without an
    # entry fall back to confidence_threshold; float64 reproduces the Python-float comparison.
    class_thr = None
    if class_thresholds:
        class_thr = torch.tensor(
            [float(class_thresholds.get(c, confidence_threshold)) for c in range(nc)],
            dtype=torch.float64, device=outputs.device)

    # One empty tensor per image (list multiplication would alias a single tensor).
    output = [torch.zeros((0, 6), device=outputs.device) for _ in range(bs)]

    for xi, x in enumerate(outputs):  # image index, image inference
        x = x.transpose(0, -1)[xc[xi]]  # (candidates, 4 + nc) after the pre-filter
        if not x.shape[0]:
            continue

        # Split boxes from class scores and convert boxes to corner form.
        box, cls = x.split((4, nc), 1)
        box = wh2xy(box)

        # Best class and its score for every candidate.
        conf, j = cls.max(1, keepdim=True)
        x = torch.cat((box, conf, j.float()), 1)

        if class_thr is not None:
            x = x[x[:, 4].double() >= class_thr[x[:, 5].long()]]
        else:
            x = x[x[:, 4] > confidence_threshold]
        if not x.shape[0]:
            continue

        # Keep only the max_nms highest-scoring candidates.
        x = x[x[:, 4].argsort(descending=True)[:max_nms]]

        # Offset boxes by class so a single NMS call acts per class.
        boxes = x[:, :4] + x[:, 5:6] * max_wh
        keep = torchvision.ops.nms(boxes, x[:, 4], iou_threshold)[:max_det]
        output[xi] = x[keep]

    return output
    
def pred_patch(torch_patch, model, start_x, start_y, magnification,
               class_thresholds=None, confidence_threshold=0.3, iou_threshold=0.3):
    """Detect cells in one image patch and return their centres in global coordinates.

    Args:
        torch_patch: model input batch; only the first image's detections (``results[0]``)
            are used. In this notebook it is a (1, 3, H, W) float tensor scaled to [0, 1].
        model: the detector; it is switched to eval mode.
        start_x, start_y: offset of the patch origin in the global coordinate frame.
        magnification: scale from patch pixels to global units
            (global = start + patch_centre * magnification).
        class_thresholds: per-class score thresholds for ``non_max_suppression``;
            defaults to 0.3 for classes 0 and 1.
        confidence_threshold: fallback score threshold (default 0.3).
        iou_threshold: NMS IoU threshold (default 0.3).

    Returns:
        A list of dicts ``{'x', 'y', 'cls_id', 'confidence'}``, one per detection.
    """
    model.eval()

    if class_thresholds is None:
        # Labels are from the original notes. NOTE(review): confirm against params['names'];
        # this notebook describes the model as a tumour-cell detector.
        class_thresholds = {
            0: 0.3,  # Neutrophil
            1: 0.3,  # Epithelial
        }

    device_type = torch_patch.device.type
    with torch.no_grad():
        # Mixed precision only on CUDA; a hard-coded autocast('cuda') merely warns and
        # disables itself on CPU-only machines.
        with torch.amp.autocast(device_type=device_type, enabled=device_type == 'cuda'):
            pred = model(torch_patch)
        results = non_max_suppression(pred, confidence_threshold=confidence_threshold,
                                      iou_threshold=iou_threshold, class_thresholds=class_thresholds)

    # Under CUDA autocast the detections may be float16, which cannot hold large slide
    # coordinates (max 65504, step 32 above 32768), so do the arithmetic in float64.
    detections = results[0].double().cpu()
    if detections.shape[0] == 0:
        return []

    centers_x = (detections[:, 0] + detections[:, 2]) / 2
    centers_y = (detections[:, 1] + detections[:, 3]) / 2
    # One bulk .tolist() per column instead of an .item() call per element.
    actual_x = (start_x + centers_x * magnification).tolist()
    actual_y = (start_y + centers_y * magnification).tolist()
    confidences = detections[:, 4].tolist()
    cls_ids = detections[:, 5].long().tolist()

    # Every class goes into one list; downstream code distinguishes cells by cls_id.
    return [
        {'x': x, 'y': y, 'cls_id': cls_id, 'confidence': conf}
        for x, y, cls_id, conf in zip(actual_x, actual_y, cls_ids, confidences)
    ]

# Build the detector and restore its trained weights.
model_path = '../../model/nucleus_marker_yolov11/best_model.pt'

# One output class per entry in params['names'] (read from utils/args.yaml above).
model = nn.yolo_v11_m(len(params['names']))
model = model.to(device)

# NOTE(review): weights_only=False unpickles arbitrary Python objects stored in the
# checkpoint — only load checkpoints from a trusted source.
checkpoint = torch.load(model_path, map_location=device, weights_only=False)
model.load_state_dict(state_dict=checkpoint['model_state_dict'])


In [None]:
# Run the detector on one BCData training image.
# glob() order is filesystem-dependent; sort it so index i selects the same image everywhere.
image_list = sorted(glob('../../data/BCData/images/train/*.png'))
if not image_list:
    raise FileNotFoundError('no PNG images found in ../../data/BCData/images/train/')
i = 16
image = np.array(Image.open(image_list[i]).convert('RGB'))
# HWC uint8 -> (1, 3, H, W) float in [0, 1] on `device`; patch origin (0, 0), magnification 1.
cell_list = pred_patch(
    (torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).float() / 255.0).to(device),
    model, 0, 0, 1,
)

In [None]:
def cell_intensity_classification(image, cell_list, size=15, high_threshold=170, mid_threshold=130):
    """Re-grade class-1 detections by the darkness of the patch around each cell centre.

    For each cell with ``cls_id == 1``, a window of up to (2*size) x (2*size) pixels centred
    on (x, y) is taken from the inverted HSV value channel, 255 - max(R, G, B), and its
    mean decides the new class: > high_threshold -> 3, > mid_threshold -> 2, else 1.
    Other classes are left unchanged.

    Args:
        image: RGB image, presumably uint8 (H, W, 3) — the 255 offset assumes 8-bit values.
        cell_list: list of dicts with at least 'x', 'y', 'cls_id' (as built by pred_patch).
        size: half-width of the square sampling window in pixels.
        high_threshold, mid_threshold: mean-darkness cut-offs for classes 3 and 2.

    Returns:
        ``cell_list`` itself, modified in place.
    """
    # OpenCV's RGB->HSV V channel is exactly max(R, G, B); only V is used, so compute it directly.
    value = image.max(axis=2)
    height, width = value.shape
    for cell in cell_list:
        if cell['cls_id'] != 1:
            continue
        # Clamp the window to the image: a negative start index would wrap around in NumPy
        # and give an empty slice (mean -> NaN) for cells near the top/left border.
        y0 = max(int(cell['y'] - size), 0)
        y1 = min(int(cell['y'] + size), height)
        x0 = max(int(cell['x'] - size), 0)
        x1 = min(int(cell['x'] + size), width)
        patch = value[y0:y1, x0:x1]
        if patch.size == 0:
            continue  # centre lies outside the image; keep the detector's class
        intensity_score = np.mean(255 - patch)
        if intensity_score > high_threshold:
            cell['cls_id'] = 3
        elif intensity_score > mid_threshold:
            cell['cls_id'] = 2
    return cell_list
# Re-grade class-1 detections by staining intensity (may produce classes 2 and 3),
# then overlay all detections on the image.
cell_list = cell_intensity_classification(image, cell_list)

# One marker colour per class id; detections with any other class id are not drawn.
CLASS_COLORS = {0: 'g', 1: 'y', 2: 'b', 3: 'r'}

# Group points by class in first-appearance order, so the legend order matches the
# order in which classes occur in cell_list.
points_by_class = {}
for cell in cell_list:
    if cell['cls_id'] in CLASS_COLORS:
        points_by_class.setdefault(cell['cls_id'], []).append((cell['x'], cell['y']))

fig, ax = plt.subplots()
ax.imshow(image)
for cls_id, points in points_by_class.items():
    xs, ys = zip(*points)
    ax.scatter(xs, ys, c=CLASS_COLORS[cls_id], s=30, alpha=0.6, label=f'Class {cls_id}')
if points_by_class:  # legend() warns when no labelled artists exist
    ax.legend()
ax.set_title('Detected Cells by Class')
plt.show()

In [None]:
# Inspect per-channel contrast (RGB and OpenCV HSV) of one training image.
# Sorted so that index i refers to the same file on every machine.
image_list = sorted(glob('../../data/BCData/images/train/*.png'))
i = 10
image = np.array(Image.open(image_list[i]).convert('RGB'))
# Despite the name this is OpenCV HSV, not HSI: channel 2 is V = max(R, G, B).
hsi_image = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)

plt.imshow(image)
plt.show()

channel_panels = [
    (image[:, :, 0], 'R channel'),
    (image[:, :, 1], 'G channel'),
    (image[:, :, 2], 'B channel'),
    (hsi_image[:, :, 0], 'Hue channel'),
    (hsi_image[:, :, 1], 'Saturation channel'),
    (hsi_image[:, :, 2], 'Value (V) channel'),
]
plt.figure(figsize=(15, 10))
for panel_idx, (channel, title) in enumerate(channel_panels, start=1):
    plt.subplot(2, 3, panel_idx)
    plt.imshow(channel)
    plt.title(title)