In [None]:
from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
from PIL import Image
import requests
import matplotlib.pyplot as plt
import torch.nn as nn
from torchinfo import summary
import torch
import numpy as np
import torch.optim as optim
import torchvision.transforms as T
import torch.nn.functional as F
from collections import defaultdict
import cv2
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import os
from glob import glob
import pandas as pd
# Hard-coded to GPU index 3 when CUDA is available; falls back to CPU otherwise.
device = torch.device("cuda:3") if torch.cuda.is_available() else torch.device("cpu")
# Shared ToTensor transform used by the dataset classes below.
tf = T.ToTensor()


In [None]:
# Run configuration. Only batch_size (dataloaders) and lr/beta1/beta2 (optimizer)
# are read in this notebook; the 512 image size is hard-coded where it is used.
params = dict(
    image_size=512,
    lr=2e-3,
    beta1=0.5,
    beta2=0.999,
    batch_size=1,
    epochs=500,
)

In [None]:
# Glob patterns for the external test set: input images and class-1 masks.
# Class-2/3 mask paths are derived from the class-1 path inside CustomDataset.
image_path, mask1_path = (
    '../../data/external/ori/*.png',
    '../../data/external/mask/class1/*.png',
)

In [None]:
class CustomDataset(Dataset):
    """Paired image / 3-class mask dataset read from PNG files on disk.

    Each sample's label is assembled from three mask files: the path in
    ``label_list`` (expected to contain ``/class1``) plus the same path with
    ``/class1`` replaced by ``/class2`` and ``/class3``. Channel 0 of each mask
    file becomes one channel of a stacked (H, W, 3) label, which is resized to
    512x512 before tensor conversion. The image keeps its native size; the
    evaluation loop upsamples predictions to 512x512 instead.

    :param image_list: image file paths; entry ``i`` must correspond to ``label_list[i]``
    :param label_list: class-1 mask file paths
    """

    def __init__(self, image_list, label_list):
        self.img_path = image_list
        self.label = label_list

    def __len__(self):
        return len(self.img_path)

    @staticmethod
    def _read_mask_channel(path):
        """Return channel 0 of the mask stored at ``path`` as an (H, W, 1) array."""
        with Image.open(path) as mask_img:
            mask = np.array(mask_img)
        # Single-channel masks load as 2-D arrays; multi-channel ones keep only channel 0.
        if mask.ndim == 3:
            mask = mask[:, :, 0]
        return mask[:, :, np.newaxis]

    def __getitem__(self, idx):
        """Return ``(image_tensor, label_tensor, image_basename)`` for sample ``idx``."""
        image_file = self.img_path[idx]
        with Image.open(image_file) as img:
            # Force 3 channels so grayscale / RGBA / palette PNGs match the model input;
            # RGB images are unaffected.
            image = tf(img.convert('RGB'))
        file_name = os.path.basename(image_file)

        class1_file = self.label[idx]
        label = np.concatenate(
            [self._read_mask_channel(class1_file.replace('/class1', f'/class{k}'))
             for k in (1, 2, 3)],
            axis=2,
        )
        # NOTE(review): cv2.resize defaults to bilinear interpolation, so mask edges get
        # intermediate values; cv2.INTER_NEAREST would keep them binary. Left unchanged
        # so reported metrics stay comparable with earlier runs.
        label = tf(cv2.resize(label, (512, 512)))

        return image, label, file_name

# glob() returns paths in arbitrary, filesystem-dependent order, while CustomDataset
# pairs images and masks purely by list index, so both lists are sorted to line them up.
# NOTE(review): assumes image and class-1 mask files have matching sortable names
# (e.g. identical basenames) — confirm against the data layout.
external_images = sorted(glob(image_path))
external_masks = sorted(glob(mask1_path))
if len(external_images) != len(external_masks):
    raise ValueError(
        f"found {len(external_images)} images but {len(external_masks)} class-1 masks; "
        "CustomDataset pairs them by index, so the counts must match")

test_dataset = CustomDataset(external_images, external_masks)
test_dataloader = DataLoader(
    test_dataset, batch_size=params['batch_size'], shuffle=False, drop_last=True)


In [None]:
def dice_loss(pred, target, num_classes=3, smooth=1e-6):
    """Per-class soft Dice coefficient between ``pred`` and ``target``.

    Despite the name, this returns the Dice *score* (1.0 = perfect overlap),
    not ``1 - Dice``; the evaluation loops average and log it as a score.

    For each class channel ``c`` it computes
    ``(2 * sum(p * t) + smooth) / (sum(p * p) + sum(t * t) + smooth)``,
    where the sums run over every element of ``pred[:, c, ...]`` and
    ``target[:, c, ...]`` (pooled over the whole batch, not averaged per sample).

    :param pred: predictions, channel dimension at index 1
    :param target: ground truth with the same shape as ``pred``
    :param num_classes: number of leading channels to score
    :param smooth: additive smoothing; makes an empty-vs-empty class score 1.0
    :return: float tensor of shape (num_classes,) on ``pred.device``
    :raises ValueError: if ``pred`` and ``target`` shapes differ (instead of silently broadcasting)
    """
    if pred.shape != target.shape:
        raise ValueError(
            f"pred and target must have the same shape, got {tuple(pred.shape)} and {tuple(target.shape)}")

    dice_per_class = torch.zeros(num_classes).to(pred.device)
    for class_id in range(num_classes):
        pred_class = pred[:, class_id, ...]
        target_class = target[:, class_id, ...]

        intersection = torch.sum(pred_class * target_class)
        A_sum = torch.sum(pred_class * pred_class)
        B_sum = torch.sum(target_class * target_class)

        dice_per_class[class_id] = (2. * intersection + smooth) / (A_sum + B_sum + smooth)

    return dice_per_class

def compute_iou(pred_mask, true_mask, threshold=0.5, num_classes=3):
    """Per-class IoU (Jaccard index) after binarising both masks.

    :param pred_mask: mask predicted by the model (torch.Tensor), channel dimension at index 1
    :param true_mask: ground-truth mask (torch.Tensor), same layout as ``pred_mask``
    :param threshold: values strictly greater than this count as foreground
    :param num_classes: number of leading channels to score
    :return: float tensor of shape (num_classes,) on ``pred_mask.device``.
        A class that is empty in both masks (union == 0) scores 1.0 instead of
        the NaN that 0/0 would produce. This matches how the smoothing in
        ``dice_loss`` treats that case and keeps the mean IoU finite.
    """
    # Allocate on the input's device rather than a notebook-global device.
    iou_per_class = torch.zeros(num_classes).to(pred_mask.device)
    for class_id in range(num_classes):
        # Binarise prediction and ground truth for this class.
        pred_bin = (pred_mask[:, class_id, ...] > threshold).float()
        true_bin = (true_mask[:, class_id, ...] > threshold).float()

        intersection = torch.sum(pred_bin * true_bin)
        union = torch.sum(pred_bin) + torch.sum(true_bin) - intersection

        if union > 0:
            iou_per_class[class_id] = intersection / union
        else:
            iou_per_class[class_id] = 1.0

    return iou_per_class

def compute_f1(pred_mask, true_mask, threshold=0.5, num_classes=3, device='cpu'):
    """Confusion-matrix metrics per class after binarising both masks.

    Each class channel (dimension 1) of both masks is thresholded; values
    strictly greater than ``threshold`` are positive. TP/FP/FN/TN are then
    counted over every element of that channel. A 1e-8 term in each
    denominator keeps the ratios finite when a count is zero.

    :param pred_mask: mask predicted by the model (torch.Tensor), channel dimension at index 1
    :param true_mask: ground-truth mask (torch.Tensor), same layout
    :param threshold: binarisation cut-off
    :param num_classes: number of leading channels to score
    :param device: device on which the result tensors are allocated (default 'cpu')
    :return: tuple ``(f1, precision, recall, specificity, accuracy)``, each a
        float tensor of shape (num_classes,)
    """
    eps = 1e-8
    f1_scores = torch.zeros(num_classes).to(device)
    precisions = torch.zeros(num_classes).to(device)
    recalls = torch.zeros(num_classes).to(device)
    specificities = torch.zeros(num_classes).to(device)
    accuracies = torch.zeros(num_classes).to(device)

    for c in range(num_classes):
        pred_pos = (pred_mask[:, c, ...] > threshold).float()
        true_pos = (true_mask[:, c, ...] > threshold).float()
        pred_neg = 1 - pred_pos
        true_neg = 1 - true_pos

        tp = (pred_pos * true_pos).sum()
        fp = (pred_pos * true_neg).sum()
        fn = (pred_neg * true_pos).sum()
        tn = (pred_neg * true_neg).sum()

        prec = tp / (tp + fp + eps)
        rec = tp / (tp + fn + eps)
        spec = tn / (tn + fp + eps)
        acc = (tp + tn) / (tp + tn + fp + fn + eps)

        f1_scores[c] = 2 * (prec * rec) / (prec + rec + eps)
        precisions[c] = prec.item()
        recalls[c] = rec.item()
        specificities[c] = spec.item()
        accuracies[c] = acc.item()

    return f1_scores, precisions, recalls, specificities, accuracies


# SegFormer-B5 (human-parsing checkpoint) re-headed for 3 labels. With
# ignore_mismatched_sizes=True, layers whose shapes no longer match (presumably the
# classifier) are re-initialised instead of raising. All weights are overwritten by
# the per-fold checkpoints loaded in the evaluation cells.
model = AutoModelForSemanticSegmentation.from_pretrained(
    "matei-dorian/segformer-b5-finetuned-human-parsing",
    num_labels=3,
    ignore_mismatched_sizes=True,
).to(device)

# NOTE(review): this optimizer is never stepped in this (evaluation-only) notebook.
optimizer = optim.Adam(
    (p for p in model.parameters() if p.requires_grad),
    lr=params['lr'],
    betas=(params['beta1'], params['beta2']),
)


In [None]:
transform = T.ToPILImage()
# 'sensitivity' holds the recall returned by compute_f1.
result_columns = ['file_name', 'Dice1', 'Dice2', 'Dice3', 'mDice', 'IoU1', 'IoU2', 'IoU3', 'mIoU',
                  'f1', 'precision', 'sensitivity', 'specificity', 'accuracy']
external_result_dir = '../../data/external/result/segformer'

# Inference mode; set once here rather than on every batch.
model.eval()
for i in range(5):
    fold = i + 1
    model.load_state_dict(torch.load(
        '../../model/segformer/seg_former_' + str(fold) + '_check.pth', map_location=device))

    # Create the output folders up front so the per-image save() calls cannot fail on a missing directory.
    fold_dir = external_result_dir + '/k_' + str(fold)
    os.makedirs(fold_dir + '/label', exist_ok=True)
    os.makedirs(fold_dir + '/pred', exist_ok=True)

    rows = []
    dice_sum = 0.0
    with torch.no_grad():
        progress = tqdm(test_dataloader)
        for count, (x, y, file_path) in enumerate(progress, start=1):
            x = x.to(device).float()
            y = y.to(device).float()
            # Upsample logits to the 512x512 label resolution, staying on the same device.
            # NOTE(review): no sigmoid/softmax is applied; the metrics below score and
            # threshold these raw outputs directly — confirm this matches training.
            predict = nn.functional.interpolate(
                model(x).logits, size=(512, 512), mode="bilinear", align_corners=False)

            dice = dice_loss(predict, y)
            iou = compute_iou(predict, y)
            f1, precision, recall, specificity, accuracy = compute_f1(predict, y)
            mean_dice = dice.mean().item()
            dice_sum += mean_dice

            name = file_path[0]
            rows.append([name, dice[0].item(), dice[1].item(), dice[2].item(), mean_dice,
                         iou[0].item(), iou[1].item(), iou[2].item(), iou.mean().item(),
                         f1.mean().item(), precision.mean().item(), recall.mean().item(),
                         specificity.mean().item(), accuracy.mean().item()])
            transform(y[0].cpu()).save(fold_dir + '/label/' + name)
            transform(torch.where(predict[0] > 0.5, 1, 0).cpu().float()).save(fold_dir + '/pred/' + name)
            progress.set_description(f"val_Step: {count} dice_score : {dice_sum / count:.4f}")

    # Build the table once; appending with df.loc[len(df)] is quadratic in the row count.
    df = pd.DataFrame(rows, columns=result_columns)
    df.to_csv(external_result_dir + '/segformer_' + str(fold) + '_result.csv', index=False)

In [7]:
def _load_cv_fold(fold):
    """Load one cross-validation fold's ``(images, masks, names)`` from ../../data.

    Images and masks are cast to uint8. Masks keep only the first three entries of
    their last axis.
    """
    prefix = f'../../data/cv{fold}'
    images = np.load(prefix + '_ori.npy').astype(np.uint8)
    masks = np.load(prefix + '_mask.npy')[:, :, :, :3].astype(np.uint8)
    names = np.load(prefix + '_name.npy')
    return images, masks, names


# Fold files are 0-indexed on disk (cv0..cv4) while the variables are 1-indexed.
image1, mask1, name1 = _load_cv_fold(0)
image2, mask2, name2 = _load_cv_fold(1)
image3, mask3, name3 = _load_cv_fold(2)
image4, mask4, name4 = _load_cv_fold(3)
image5, mask5, name5 = _load_cv_fold(4)

In [10]:
# Fold-indexed lookup used by the internal evaluation loop ('image{k}', 'mask{k}', 'name{k}', k = 1..5).
np_data = dict(
    image1=image1, image2=image2, image3=image3, image4=image4, image5=image5,
    mask1=mask1, mask2=mask2, mask3=mask3, mask4=mask4, mask5=mask5,
    name1=name1, name2=name2, name3=name3, name4=name4, name5=name5,
)
# NOTE(review): this redefinition shadows the dice_loss defined in the earlier metrics
# cell; the two are kept identical so both evaluation loops score the same way.
def dice_loss(pred, target, num_classes=3, smooth=1e-6):
    """Per-class soft Dice coefficient between ``pred`` and ``target``.

    Despite the name, this returns the Dice *score* (1.0 = perfect overlap),
    not ``1 - Dice``; the evaluation loops average and log it as a score.

    For each class channel ``c`` it computes
    ``(2 * sum(p * t) + smooth) / (sum(p * p) + sum(t * t) + smooth)``,
    where the sums run over every element of ``pred[:, c, ...]`` and
    ``target[:, c, ...]`` (pooled over the whole batch, not averaged per sample).

    :param pred: predictions, channel dimension at index 1
    :param target: ground truth with the same shape as ``pred``
    :param num_classes: number of leading channels to score
    :param smooth: additive smoothing; makes an empty-vs-empty class score 1.0
    :return: float tensor of shape (num_classes,) on ``pred.device``
    :raises ValueError: if ``pred`` and ``target`` shapes differ (instead of silently broadcasting)
    """
    if pred.shape != target.shape:
        raise ValueError(
            f"pred and target must have the same shape, got {tuple(pred.shape)} and {tuple(target.shape)}")

    dice_per_class = torch.zeros(num_classes).to(pred.device)
    for class_id in range(num_classes):
        pred_class = pred[:, class_id, ...]
        target_class = target[:, class_id, ...]

        intersection = torch.sum(pred_class * target_class)
        A_sum = torch.sum(pred_class * pred_class)
        B_sum = torch.sum(target_class * target_class)

        dice_per_class[class_id] = (2. * intersection + smooth) / (A_sum + B_sum + smooth)

    return dice_per_class
class CustomDataset(Dataset):
    """Dataset over pre-loaded arrays of grayscale images, masks and sample names.

    NOTE(review): this redefinition shadows the file-based CustomDataset from an
    earlier cell and takes an extra ``name_list`` argument.

    :param image_list: indexable of grayscale images accepted by
        ``cv2.cvtColor(..., cv2.COLOR_GRAY2RGB)``
    :param label_list: indexable of 3-D masks; channels beyond the third are dropped
    :param name_list: indexable of per-sample identifiers, returned unchanged
    """

    def __init__(self, image_list, label_list, name_list):
        self.img_path = image_list
        self.label = label_list
        self.name = name_list

    def __len__(self):
        return len(self.img_path)

    def __getitem__(self, idx):
        """Return ``(image_tensor, label_tensor, name)`` for sample ``idx``."""
        # Replicate the single gray channel into three so the model gets RGB-shaped input.
        rgb = cv2.cvtColor(self.img_path[idx], cv2.COLOR_GRAY2RGB)
        mask = self.label[idx][:, :, :3]
        return (
            tf(rgb),
            tf(cv2.resize(mask, (512, 512))),
            self.name[idx],
        )
    

In [12]:
transform = T.ToPILImage()
# 'sensitivity' holds the recall returned by compute_f1.
result_columns = ['file_name', 'Dice1', 'Dice2', 'Dice3', 'mDice', 'IoU1', 'IoU2', 'IoU3', 'mIoU',
                  'f1', 'precision', 'sensitivity', 'specificity', 'accuracy']
internal_result_dir = '../../data/internal/result/segformer'

# Inference mode; set once here rather than on every batch.
model.eval()
for i in range(5):
    fold = i + 1
    model.load_state_dict(torch.load(
        '../../model/segformer/seg_former_' + str(fold) + '_check.pth', map_location=device))

    test_dataset = CustomDataset(np_data['image' + str(fold)],
                                 np_data['mask' + str(fold)],
                                 np_data['name' + str(fold)])
    # Evaluation needs no shuffling; a fixed order keeps the CSV rows reproducible.
    test_dataloader = DataLoader(
        test_dataset, batch_size=params['batch_size'], shuffle=False, drop_last=True)

    # Create the output folders up front so the per-image save() calls cannot fail on a missing directory.
    fold_dir = internal_result_dir + '/k_' + str(fold)
    os.makedirs(fold_dir + '/label', exist_ok=True)
    os.makedirs(fold_dir + '/pred', exist_ok=True)

    rows = []
    dice_sum = 0.0
    with torch.no_grad():
        progress = tqdm(test_dataloader)
        for count, (x, y, file_path) in enumerate(progress, start=1):
            x = x.to(device).float()
            y = y.to(device).float()
            # Upsample logits to the 512x512 label resolution, staying on the same device.
            # NOTE(review): no sigmoid/softmax is applied; the metrics below score and
            # threshold these raw outputs directly — confirm this matches training.
            predict = nn.functional.interpolate(
                model(x).logits, size=(512, 512), mode="bilinear", align_corners=False)

            dice = dice_loss(predict, y)
            iou = compute_iou(predict, y)
            f1, precision, recall, specificity, accuracy = compute_f1(predict, y)
            mean_dice = dice.mean().item()
            dice_sum += mean_dice

            png_name = file_path[0] + '.png'
            rows.append([png_name, dice[0].item(), dice[1].item(), dice[2].item(), mean_dice,
                         iou[0].item(), iou[1].item(), iou[2].item(), iou.mean().item(),
                         f1.mean().item(), precision.mean().item(), recall.mean().item(),
                         specificity.mean().item(), accuracy.mean().item()])
            transform(y[0].cpu()).save(fold_dir + '/label/' + png_name)
            transform(torch.where(predict[0] > 0.5, 1, 0).cpu().float()).save(fold_dir + '/pred/' + png_name)
            progress.set_description(f"val_Step: {count} dice_score : {dice_sum / count:.4f}")

    # Build the table once; appending with df.loc[len(df)] is quadratic in the row count.
    df = pd.DataFrame(rows, columns=result_columns)
    df.to_csv(internal_result_dir + '/segformer_' + str(fold) + '_result.csv', index=False)

val_Step: 9763 dice_sore : 0.9131: 100%|██████████| 9762/9762 [15:22<00:00, 10.58it/s]
val_Step: 10007 dice_sore : 0.9277: 100%|██████████| 10006/10006 [17:47<00:00,  9.37it/s]
val_Step: 10093 dice_sore : 0.9194: 100%|██████████| 10092/10092 [18:25<00:00,  9.13it/s]
val_Step: 10185 dice_sore : 0.9263: 100%|██████████| 10184/10184 [19:02<00:00,  8.91it/s]
val_Step: 9601 dice_sore : 0.9293: 100%|██████████| 9600/9600 [19:11<00:00,  8.34it/s]


In [None]:
# Peek at the first ground-truth mask of the last batch processed by the loop above.
# NOTE(review): `y` is leaked from that loop, so this cell only works after it has run;
# consider removing it from the final notebook.
last_label = y[0]
last_label