In [8]:
import matplotlib.pyplot as plt
import numpy as np
import helper
import time
import datetime
import torch.nn as nn
import torchvision.models
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets, models
import torchvision.utils
import torch
import pandas as pd
from torchinfo import summary
from PIL import Image
from torchvision.transforms import ToTensor
from glob import glob
from torch.utils.data import Dataset, DataLoader, random_split
from copy import copy
from collections import defaultdict
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
import time
from sklearn.metrics import classification_report
from tqdm import tqdm
import math
from torcheval.metrics import BinaryAccuracy
import os
import timm
import segmentation_models_pytorch as smp
import random
from sklearn.model_selection import train_test_split
from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
import pandas as pd
from shapely.geometry import Polygon, MultiPolygon
import cv2
import xml.etree.ElementTree as ET

# ---- Inference configuration ----
# Preferred CUDA device. torch.cuda.is_available() alone does not guarantee that
# this ordinal exists on the machine, so fall back to the first GPU (or the CPU).
GPU_INDEX = 4
if torch.cuda.is_available():
    device = torch.device(
        f"cuda:{GPU_INDEX}" if GPU_INDEX < torch.cuda.device_count() else "cuda:0"
    )
else:
    device = torch.device("cpu")
batch_size = 1  # the inference loop unwraps label[0] / path[0], i.e. it assumes one image per batch
img_size = 1024
class_list = ['NT_epithelial', 'NT_immune', 'NT_stroma', 'TP_in_situ', 'TP_invasive']
csv_path = "../../data/BRIL 2차 정제 완료 리스트.csv"

# NOTE(review): not used by the visible cells (CustomDataset creates its own ToTensor).
tf = ToTensor()
def expand2square(pil_img, background_color):
    """Pad ``pil_img`` onto a square canvas filled with ``background_color``.

    The image is centred along its shorter side. An image that is already
    square is returned as-is (the same object, no copy).
    """
    w, h = pil_img.size
    if w == h:
        return pil_img
    side = max(w, h)
    # Shift only along the shorter axis so the content ends up centred.
    offset = (0, (w - h) // 2) if w > h else ((h - w) // 2, 0)
    canvas = Image.new(pil_img.mode, (side, side), background_color)
    canvas.paste(pil_img, offset)
    return canvas
    

def binary_mask_to_polygon(binary_mask):
    """Vectorize the outer contours of a single-channel mask.

    Args:
        binary_mask: 2-D numpy array (uint8 in this notebook); OpenCV treats
            every non-zero pixel as foreground.

    Returns:
        A shapely ``MultiPolygon`` with one ``Polygon`` per external contour that
        has at least 3 vertices, or ``None`` when there is no such contour. A
        single polygon is still wrapped in a ``MultiPolygon`` so callers can
        always iterate ``.geoms``.
    """
    # Callers pass channel slices such as mask[..., 2], which are strided views;
    # give OpenCV a contiguous array (no copy if it already is contiguous).
    contours, _ = cv2.findContours(
        np.ascontiguousarray(binary_mask), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )

    # Each contour is an (N, 1, 2) point array; a valid polygon needs at least 3 vertices.
    polygons = [
        Polygon(shell=[(point[0][0], point[0][1]) for point in contour])
        for contour in contours
        if len(contour) >= 3
    ]
    return MultiPolygon(polygons) if polygons else None
    
def mask2polygon(mask, num_channels=3):
    """Convert the leading channels of a mask into lists of polygon coordinates.

    Args:
        mask: numpy array indexed as ``mask[..., c]``; each channel is a binary
            mask handed to ``binary_mask_to_polygon``.
        num_channels: number of leading channels to convert (default 3).

    Returns:
        A tuple with one list per channel, in channel order (``mask[..., 0]``
        first). Each list holds one ``(K, 2)`` array of exterior (x, y)
        coordinates per polygon; a channel without polygons yields ``[]``.
        The order is by channel index only. The caller decides which class
        lives in which channel.
    """
    per_channel = []
    for channel in range(num_channels):
        multipolygon = binary_mask_to_polygon(mask[..., channel])
        if multipolygon is None:
            per_channel.append([])
        else:
            per_channel.append(
                [np.array(poly.exterior.coords) for poly in multipolygon.geoms]
            )
    return tuple(per_channel)

def polygon2asap(label_polygon, class_list, save_path, color="#F4FA58"):
    """Write per-class polygons to an ASAP annotation XML file.

    Args:
        label_polygon: sequence where ``label_polygon[i]`` is a list of polygons
            belonging to ``class_list[i]``; each polygon is an array indexable as
            ``polygon[k, 0]`` (x) and ``polygon[k, 1]`` (y).
        class_list: class names; entry ``i`` becomes the ``Name`` attribute of
            every annotation built from ``label_polygon[i]``.
        save_path: output XML path. Missing parent directories are created.
        color: value of the ``Color`` attribute on every annotation
            (default ``"#F4FA58"``).
    """
    root = ET.Element("ASAP_Annotations")
    annotations = ET.SubElement(root, "Annotations")
    for class_idx, polygons in enumerate(label_polygon):
        for polygon in polygons:
            annotation = ET.SubElement(
                annotations, "Annotation",
                Name=class_list[class_idx], Type="Polygon", PartOfGroup="None", Color=color,
            )
            coordinates = ET.SubElement(annotation, "Coordinates")
            for order in range(len(polygon)):
                ET.SubElement(
                    coordinates, "Coordinate",
                    Order=str(order),
                    X=str(float(polygon[order, 0])),
                    Y=str(float(polygon[order, 1])),
                )

    # ElementTree.write does not create directories, so ensure the target folder exists.
    parent_dir = os.path.dirname(save_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    ET.ElementTree(root).write(save_path)

In [9]:
# Build the inference file list from the CSV at csv_path: one image path and
# one category label per row, kept in CSV order.
pd_data = pd.read_csv(csv_path)
img_path = '../../data/NIA/BRNT/'
xml_path = '../../result/area_segmentation/BRNT/'
image_list = [img_path + file_name for file_name in pd_data['file_name'].to_numpy()]
category_list = list(pd_data['category'].to_numpy())

    
class CustomDataset(Dataset):
    """Dataset yielding ``(image_tensor, label, path)`` for each image file.

    Args:
        image_list: image file paths.
        label_list: per-image labels aligned with ``image_list``.
    """

    def __init__(self, image_list, label_list):
        self.img_path = image_list
        self.label = label_list
        self.tf = ToTensor()

    def __len__(self):
        return len(self.img_path)

    def __getitem__(self, idx):
        path = self.img_path[idx]
        # Context manager releases the file handle promptly; convert("RGB")
        # guarantees the 3 channels the model is built for (in_channels=3),
        # even for grayscale / RGBA / palette files.
        with Image.open(path) as img:
            image = self.tf(img.convert("RGB"))
        label = self.label[idx]
        return image, label, path

# Wrap the file list in a Dataset and iterate it in CSV order (no shuffling).
dataset = CustomDataset(image_list=image_list, label_list=category_list)
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False)

In [10]:
# MAnet segmentation model with an EfficientNet-B7 encoder. It has one output
# channel per entry of class_list plus one extra channel (presumably background,
# TODO confirm against the training setup).
# encoder_weights=None: every parameter is overwritten by the strict
# load_state_dict of the trained checkpoint in this cell, so fetching ImageNet
# weights here would only cost time and require network access.
BRNT_model = smp.MAnet(
    encoder_name="efficientnet-b7",
    encoder_weights=None,
    in_channels=3,                  # RGB input
    classes=len(class_list)+1,
).to(device)

def dice_loss(pred, target, num_classes=len(class_list)+1):
    """Soft Dice loss for every (sample, class) pair.

    Args:
        pred: raw logits of shape (N, C, ...); softmax is taken over dim 1.
        target: tensor with the same shape as ``pred`` holding per-class target
            maps (presumably one-hot, TODO confirm against the training code).
        num_classes: number of leading channels to score (default: one per
            entry of ``class_list`` plus one).

    Returns:
        Tensor of shape (N, num_classes) with ``1 - dice`` for each sample and
        class, where ``dice = (2*sum(p*t) + eps) / (sum(p*p) + sum(t*t) + eps)``.
        Its dtype follows the arithmetic on ``pred`` (float32 for float32 logits).

    Raises:
        IndexError: if ``num_classes`` exceeds the channel count of ``pred``.
    """
    smooth = 1e-6
    if num_classes > pred.shape[1]:
        raise IndexError(
            f"num_classes={num_classes} exceeds the {pred.shape[1]} channels of pred"
        )
    prob = F.softmax(pred, dim=1)[:, :num_classes]
    tgt = target[:, :num_classes]
    if prob.dim() == 2:
        # (N, C) input without spatial dims: add a singleton axis so flatten(2) works.
        prob, tgt = prob.unsqueeze(-1), tgt.unsqueeze(-1)
    # Collapse all spatial dims so each (sample, class) pair reduces in one sum,
    # replacing the former Python loop over samples and classes.
    prob = prob.flatten(start_dim=2)
    tgt = tgt.flatten(start_dim=2)
    intersection = (prob * tgt).sum(dim=-1)
    pred_sq_sum = (prob * prob).sum(dim=-1)
    target_sq_sum = (tgt * tgt).sum(dim=-1)
    dice = (2. * intersection + smooth) / (pred_sq_sum + target_sq_sum + smooth)
    return 1 - dice

# map_location lets the checkpoint load onto `device` even if it was saved from a
# different GPU (or when only a CPU is available).
# NOTE(review): torch.load unpickles the file. Only load trusted checkpoints.
BRNT_model.load_state_dict(torch.load('../../model/areaSeg/BRNT_callback.pt', map_location=device))


<All keys matched successfully>

In [15]:
import cv2
# 15x15 rectangular structuring element for the morphological opening that
# removes small isolated specks from each predicted class mask.
k = cv2.getStructuringElement(cv2.MORPH_RECT, (15,15))

topilimage = torchvision.transforms.ToPILImage()

# NOTE(review): these training-style accumulators are not used by this inference
# cell; they are kept only in case other cells reference them.
train_loss_list=[]
val_loss_list=[]
train_acc_list=[]
val_acc_list=[]
MIN_loss=5000
metrics = defaultdict(float)
BRNT_model.eval()

count=0
val_running_loss=0.0
acc_loss=0

overlay_dir = xml_path + 'overlay/'
os.makedirs(overlay_dir, exist_ok=True)


def _class_mask(label_map, class_id, kernel):
    """Return a uint8 0/255 mask of pixels predicted as ``class_id``, opened with ``kernel``."""
    binary = np.where(label_map == class_id, 255, 0).astype(np.uint8)
    return cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel)


with torch.no_grad():
    for x, label, path in tqdm(dataloader):
        # Everything below handles one image at a time; the loader must use batch_size=1.
        if x.shape[0] != 1:
            raise ValueError(f"this loop expects batch_size=1, got a batch of {x.shape[0]}")
        label = label[0]
        path = path[0]
        count += 1

        x = x.float()  # stays on the CPU for the overlay; only a copy goes to the model's device
        BRNT_predict = BRNT_model(x.to(device))
        # Softmax + argmax on the model's device, so only the (H, W) label map is
        # copied back instead of the full per-class probability volume.
        BRNT_pred_label = F.softmax(BRNT_predict, dim=1).argmax(dim=1)[0].cpu().numpy()

        # NOTE(review): argmax class 1 goes to channel 0 and class 2 to channel 1.
        # polygon2asap names them class_list[0] ('NT_epithelial') and class_list[1]
        # ('NT_immune'), yet this code calls the second mask "stroma". Confirm which
        # tissue the model's class index 2 really is.
        NT_epithelial_mask = _class_mask(BRNT_pred_label, 1, k)
        NT_stroma_mask = _class_mask(BRNT_pred_label, 2, k)
        # Size the mask from the prediction instead of assuming img_size x img_size inputs.
        mask = np.zeros((*BRNT_pred_label.shape, 3), dtype=np.uint8)
        mask[..., 0] = NT_epithelial_mask
        mask[..., 1] = NT_stroma_mask

        # 70/30 blend of the input image and the class mask, saved as a preview.
        image = x.squeeze().permute(1, 2, 0).numpy() * 255
        image = (image * 0.7 + mask * 0.3).astype(np.uint8)
        Image.fromarray(image).save(overlay_dir + os.path.basename(path))

        NT_epithelial_polygons, NT_stroma_polygons, NT_immune_polygons = mask2polygon(mask)
        label_polygon = [NT_epithelial_polygons, NT_stroma_polygons]
        # splitext keeps dots inside the stem ("a.b.png" -> "a.b"), unlike split('.')[0].
        stem = os.path.splitext(os.path.basename(path))[0]
        label_dir = xml_path + label + '/'
        os.makedirs(label_dir, exist_ok=True)
        save_path = label_dir + stem + '.xml'
        polygon2asap(label_polygon, class_list, save_path)

100%|██████████| 2561/2561 [19:14<00:00,  2.22it/s] 
