In [2]:
import matplotlib.pyplot as plt
import numpy as np
import helper

import torch.nn as nn
import torchvision.models
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets, models
import torchvision.utils
import torch
import pandas as pd
from torchinfo import summary
from PIL import Image
from torchvision.transforms import ToTensor
from glob import glob
from torch.utils.data import Dataset, DataLoader, random_split
from copy import copy
from collections import defaultdict
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
import time
from sklearn.metrics import classification_report
from tqdm.notebook import tqdm
import math
from torcheval.metrics import BinaryAccuracy
import os
import torchmetrics
import timm
import segmentation_models_pytorch as smp
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
batch_size = 8
image_count = 50  # NOTE(review): not referenced by any visible cell
img_size = 512  # images and masks are padded to square, then resized to img_size x img_size
tf = ToTensor()  # shared torchvision ToTensor converter used by the loading cell

# Seed the RNGs behind weight initialisation and DataLoader shuffling so runs are repeatable.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
def expand2square(pil_img, background_color):
    """Pad an image onto a square canvas, centring it along its shorter axis.

    The canvas side equals the larger of the image's width and height and is
    filled with ``background_color``. A square input is returned unchanged
    (the very same object, not a copy).
    """
    w, h = pil_img.size
    if w == h:
        return pil_img
    side = max(w, h)
    canvas = Image.new(pil_img.mode, (side, side), background_color)
    # Exactly one of the two offsets is non-zero: the longer axis starts at 0,
    # the shorter one is shifted by half the size difference.
    offset = ((side - w) // 2, (side - h) // 2)
    canvas.paste(pil_img, offset)
    return canvas

In [6]:
def derive_mask_paths(image_paths, mask_dir):
    """Map each image path to its mask path by swapping the '/image/' directory for ``mask_dir``.

    ``mask_dir`` must include the surrounding slashes, e.g. '/polygon/TP_tumor/'.
    The output keeps the order of ``image_paths``, so image i and mask i stay paired.
    """
    return [path.replace('/image/', mask_dir) for path in image_paths]


# glob() result order depends on the filesystem; sort so index i refers to the same file on every run.
test_image_list = sorted(glob('../../data/1-cycle_30%_중간데이터/segmentation/test/image/*.tiff'))
test_tumor_mask_list = derive_mask_paths(test_image_list, '/polygon/TP_tumor/')
test_normal_mask_list = derive_mask_paths(test_image_list, '/polygon/NT_normal/')
train_image_list = sorted(glob('../../data/1-cycle_30%_중간데이터/segmentation/train/image/*.tiff'))
train_tumor_mask_list = derive_mask_paths(train_image_list, '/polygon/TP_tumor/')
train_normal_mask_list = derive_mask_paths(train_image_list, '/polygon/NT_normal/')

class CustomDataset(Dataset):
    """Map-style Dataset pairing two equally indexable sequences.

    Items are returned exactly as stored, with no copying or transformation.
    Despite the attribute name ``img_path``, the notebook passes already-loaded
    image tensors here. The dataset length comes from the label sequence.
    """

    def __init__(self, image_list, label_list):
        self.img_path = image_list
        self.label = label_list

    def __len__(self):
        return len(self.label)

    def __getitem__(self, idx):
        return self.img_path[idx], self.label[idx]
# Mask channel order: 0 = background, 1 = tumor (TP_tumor), 2 = normal (NT_normal).
# Three channels, matching the three channels written per sample below.
NUM_MASK_CHANNELS = 3


def load_mask(path):
    """Read a mask file, pad it to square with black, convert to 8-bit grayscale ('L') and resize.

    Returns a uint8 numpy array of shape (img_size, img_size).
    """
    with Image.open(path) as im:
        return np.array(expand2square(im, (0, 0, 0)).convert('L').resize((img_size, img_size)))


def load_sample(image_path, tumor_mask_path, normal_mask_path):
    """Load one image and its stacked (background, tumor, normal) mask.

    The image is padded to square with white, resized, converted with ``tf``
    (ToTensor) and inverted as ``1 - x``. The mask has shape
    (NUM_MASK_CHANNELS, img_size, img_size) with every channel scaled to [0, 1].
    """
    with Image.open(image_path) as im:
        image = 1 - tf(np.array(expand2square(im, (255, 255, 255)).resize((img_size, img_size))))
    tumor = load_mask(tumor_mask_path)
    normal = load_mask(normal_mask_path)
    # Background = pixels where both masks are 0. Test each mask against 0 instead of
    # summing them: uint8 addition wraps (e.g. 128 + 128 == 0) and mislabelled
    # interpolated edge pixels. Cast to uint8 because ToTensor only rescales uint8
    # input to [0, 1]; the int64 array np.where returns would stay at 255.
    background = np.where((tumor == 0) & (normal == 0), 255, 0).astype(np.uint8)
    mask = torch.cat([tf(background), tf(tumor), tf(normal)], dim=0)
    return image, mask


def load_split(image_list, tumor_mask_list, normal_mask_list):
    """Load every sample of one split into preallocated float tensors; returns (images, masks)."""
    images = torch.zeros((len(image_list), 3, img_size, img_size))
    masks = torch.zeros((len(image_list), NUM_MASK_CHANNELS, img_size, img_size))
    samples = zip(image_list, tumor_mask_list, normal_mask_list)
    for i, (img_path, tumor_path, normal_path) in enumerate(tqdm(samples, total=len(image_list))):
        images[i], masks[i] = load_sample(img_path, tumor_path, normal_path)
    return images, masks


train_image, train_mask = load_split(train_image_list, train_tumor_mask_list, train_normal_mask_list)
test_image, test_mask = load_split(test_image_list, test_tumor_mask_list, test_normal_mask_list)
    
train_dataset = CustomDataset(train_image, train_mask)
test_dataset = CustomDataset(test_image, test_mask)

# Training: shuffle and drop the ragged last batch.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
# Validation must see every sample exactly once in a fixed order, so it neither
# shuffles nor drops the last partial batch.
validation_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False)

KeyboardInterrupt: 

In [5]:
# UNet++ with an EfficientNet-B6 encoder: 3 input channels and 3 output channels
# (one per mask channel: background, tumor, normal).
model = smp.UnetPlusPlus(encoder_name='efficientnet-b6', in_channels=3, classes=3).to(device)

def diceloss(pred, target):
    """Soft Dice loss computed over all elements of ``pred`` and ``target`` at once.

    Both tensors are flattened, so batch and channel dimensions are pooled into a
    single score:

        loss = 1 - (2 * sum(p * t) + 1) / (sum(p * p) + sum(t * t) + 1)

    The ``+ 1`` smoothing keeps the ratio defined (loss 0) when both inputs are all zero.
    """
    eps = 1.
    p_flat = pred.contiguous().view(-1)
    t_flat = target.contiguous().view(-1)
    overlap = (p_flat * t_flat).sum()
    p_energy = (p_flat * p_flat).sum()
    t_energy = (t_flat * t_flat).sum()
    dice_coeff = (2. * overlap + eps) / (p_energy + t_energy + eps)
    return 1 - dice_coeff
def calc_loss(pred, target, metrics, bce_weight=0.5):
    """Weighted sum of BCE-with-logits and soft Dice loss, with running-metric bookkeeping.

    Args:
        pred: raw model logits; sigmoid is applied here before the Dice term.
        target: tensor of the same shape as ``pred``, expected in [0, 1].
        metrics: mapping (e.g. ``defaultdict(float)``) whose 'bce', 'dice' and 'loss'
            entries are incremented by the batch value times ``target.size(0)``.
        bce_weight: weight of the BCE term; the Dice term gets ``1 - bce_weight``.

    Returns:
        Scalar loss tensor, still attached to the autograd graph.
    """
    bce = F.binary_cross_entropy_with_logits(pred, target)
    dice = diceloss(torch.sigmoid(pred), target)
    loss = bce * bce_weight + dice * (1 - bce_weight)

    # Accumulate detached Python floats via .item(). The previous `.data` bypasses
    # autograd's safety checks and is discouraged by PyTorch.
    batch = target.size(0)
    metrics['bce'] += bce.item() * batch
    metrics['dice'] += dice.item() * batch
    metrics['loss'] += loss.item() * batch

    return loss
# Layer-by-layer summary of the model for one batch of 3-channel img_size x img_size inputs.
summary(model, input_size=(batch_size, 3, img_size, img_size))

Layer (type:depth-idx)                                  Output Shape              Param #
UnetPlusPlus                                            [8, 2, 512, 512]          --
├─EfficientNetEncoder: 1-1                              [8, 3, 512, 512]          1,331,712
│    └─Conv2dStaticSamePadding: 2-1                     [8, 56, 256, 256]         1,512
│    │    └─ZeroPad2d: 3-1                              [8, 3, 513, 513]          --
│    └─BatchNorm2d: 2-2                                 [8, 56, 256, 256]         112
│    └─MemoryEfficientSwish: 2-3                        [8, 56, 256, 256]         --
│    └─ModuleList: 2-4                                  --                        --
│    │    └─MBConvBlock: 3-2                            [8, 32, 256, 256]         4,110
│    │    └─MBConvBlock: 3-3                            [8, 32, 256, 256]         1,992
│    │    └─MBConvBlock: 3-4                            [8, 32, 256, 256]         1,992
│    │    └─MBConvBlock: 3-5            

In [None]:
NUM_EPOCHS = 100
PLOT_EVERY = 20
CHECKPOINT_PATH = '../../model/NestedUNet_callback.pt'

train_loss_list = []
val_loss_list = []
train_acc_list = []  # per-epoch mean of (1 - combined loss): a proxy score, not pixel accuracy
val_acc_list = []
MIN_loss = float('inf')  # best mean validation loss so far; a checkpoint is saved when it improves
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=2e-5)
metrics = defaultdict(float)  # running bce/dice/loss sums updated by calc_loss

# A split smaller than batch_size with drop_last=True yields no batches; fail early with a
# clear message instead of a ZeroDivisionError when averaging.
if len(train_dataloader) == 0 or len(validation_dataloader) == 0:
    raise ValueError('train or validation dataloader yields no batches; check data paths and batch_size')
os.makedirs(os.path.dirname(CHECKPOINT_PATH), exist_ok=True)

for epoch in range(NUM_EPOCHS):
    # ---- training ----
    model.train()
    train = tqdm(train_dataloader)
    count = 0
    running_loss = 0.0
    acc_loss = 0.0
    for x, y in train:
        x = x.to(device).float()
        y = y.to(device).float()
        count += 1
        optimizer.zero_grad()  # reset gradients left over from the previous step
        predict = model(x)
        cost = calc_loss(predict, y, metrics)  # combined BCE + Dice loss
        cost.backward()
        optimizer.step()
        # Reuse this step's loss as a plain float. A second calc_loss call doubled the work,
        # double-counted `metrics`, and summed a graph-attached tensor that kept every step's
        # activations alive for the whole epoch.
        step_loss = cost.item()
        running_loss += step_loss
        acc_loss += 1 - step_loss
        train.set_description(
            f"epoch: {epoch+1}/{NUM_EPOCHS} Step: {count} loss: {running_loss/count:.4f} "
            f"score(1-loss): {acc_loss/count:.4f}")
    train_loss_list.append(running_loss / count)
    train_acc_list.append(acc_loss / count)

    # ---- validation ----
    model.eval()
    val = tqdm(validation_dataloader)
    count = 0
    val_running_loss = 0.0
    acc_loss = 0.0
    with torch.no_grad():
        for x, y in val:
            x = x.to(device).float()
            y = y.to(device).float()
            count += 1
            predict = model(x)
            step_loss = calc_loss(predict, y, metrics).item()
            val_running_loss += step_loss
            acc_loss += 1 - step_loss
            val.set_description(
                f"Validation epoch: {epoch+1}/{NUM_EPOCHS} Step: {count} loss: {val_running_loss/count:.4f}  "
                f"score(1-loss): {acc_loss/count:.4f}")
    val_loss = val_running_loss / count
    val_loss_list.append(val_loss)
    val_acc_list.append(acc_loss / count)

    # Keep only the weights with the lowest validation loss seen so far.
    if val_loss < MIN_loss:
        torch.save(model.state_dict(), CHECKPOINT_PATH)
        MIN_loss = val_loss

    # Periodic learning curves, plus one final plot so the last epochs are shown too.
    if epoch % PLOT_EVERY == 1 or epoch == NUM_EPOCHS - 1:
        epochs_seen = np.arange(epoch + 1)
        fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(10, 5))
        ax_loss.plot(epochs_seen, train_loss_list, label='train_loss')
        ax_loss.plot(epochs_seen, val_loss_list, label='validation_loss')
        ax_loss.set(title='loss_graph', xlabel='epoch', ylabel='loss', ylim=[0, 1])
        ax_loss.legend()
        ax_acc.plot(epochs_seen, train_acc_list, label='train_acc')
        ax_acc.plot(epochs_seen, val_acc_list, label='validation_acc')
        ax_acc.set(title='acc_graph', xlabel='epoch', ylabel='accuracy', ylim=[0, 1])
        ax_acc.legend()
        plt.show()
        plt.close(fig)

In [None]:
# Show the background channel (index 0) of one ground-truth mask from the last validation batch.
# `y` is left over from the training cell and lives on `device`, so move it to CPU for matplotlib.
# Clamp the index in case that batch has fewer than 6 samples.
sample_idx = min(5, y.shape[0] - 1)
plt.imshow(y[sample_idx, 0].cpu().numpy(), cmap='gray');