In [1]:
import torch
import torchvision.transforms as transforms
import os
import torch.nn.functional as F
# 학습 중 진행상황 print(진행률 막대)
import tqdm
import random
import cv2
from ipywidgets import interact
# 파일 및 디렉토리 작업
import shutil
import numpy as np
from PIL import Image
import re
# 다차원 이미지 처리
import scipy.ndimage as ndimage
import matplotlib.pyplot as plt
# 파일경로 다루기
from pathlib import Path
from torch import nn
# 텐서를 감싸기
from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset, random_split

# 0. 0_환경세팅

In [2]:
# from google.colab import drive
# drive.mount('/content/drive')

In [3]:
# # zip압축해제
# !unzip '/content/drive/MyDrive/03.Breast Cancer Segmentation/02. Unet_Dataset.zip' -d '/content/02. Unet_Dataset/'

# 1. 1_ 데이터셋 생성

In [4]:
# 데이터셋 클래스 정의
class BreastDatasets(Dataset):
    def __init__(self, img_dir, mask_dir, img_transform, mask_transform):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.img_transform = img_transform
        self.mask_transform = mask_transform
        self.img_files = []
        self.mask_files = []
        self.seed = np.random.randint(2024)

        for x in ['benign','malignant','normal']:
            file_img_dir = os.path.join(self.img_dir,x)
            file_mask_dir = os.path.join(self.mask_dir,x)
            for img_name in os.listdir(file_img_dir):
                if img_name.split('.')[-1] in ('png', 'jpg'):
                    self.img_files.append(os.path.join(file_img_dir, img_name))
                    self.mask_files.append(os.path.join(file_mask_dir, img_name))

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, i):
        img = Image.open(self.img_files[i]).convert('RGB')
        if self.img_transform is not None:
            random.seed(self.seed)
            img = self.img_transform(img)

        mask = Image.open(self.mask_files[i]).convert('RGB')
        if self.mask_transform is not None:
            mask = self.mask_transform(mask)

        return img, mask

In [5]:
# # 데이터셋 클래스 정의
# class BreastDatasets(Dataset):
#     def __init__(self, img_dir, mask_dir, img_transform, mask_transform):
#         self.img_dir = img_dir
#         self.mask_dir = mask_dir
#         self.img_transform = img_transform
#         self.mask_transform = mask_transform
#         self.img_files = []
#         self.mask_files = []
#         self.seed = np.random.randint(2024)

#         for img_name in os.listdir(self.img_dir):
#             if img_name.split('.')[-1] in ('png', 'jpg'):
#                 self.img_files.append(os.path.join(self.img_dir, img_name))
#                 self.mask_files.append(os.path.join(self.mask_dir, img_name))

#     def __len__(self):
#         return len(self.img_files)

#     def __getitem__(self, i):
#         img = Image.open(self.img_files[i])
#         if self.img_transform is not None:
#             random.seed(self.seed)
#             img = self.img_transform(img)

#         mask = Image.open(self.mask_files[i]).convert('L') # 그레이스케일로 변환
#         if self.mask_transform is not None:
#             mask = self.mask_transform(mask)

#         return img, mask

# 2. 2_Unet 구현

In [6]:
# UNET 클래스로 구현
class UNet(nn.Module):
    # 클래스의 갯수 상속
    def __init__(self, num_classes):
        # nn.Module(부모) 클래스의 기능을 상속
        super(UNet, self).__init__()
        self.num_classes = num_classes
        # 3채널을 받아서 64로 반환(1번째 레이어)
        self.contracting_11 = self.conv_block(in_channels=3, out_channels=64)
        self.contracting_12 = nn.MaxPool2d(kernel_size=2, stride=2)
        # 64채널을 받아서 128채널로 반환(2번째 레이어)
        self.contracting_21 = self.conv_block(in_channels=64, out_channels=128)
        self.contracting_22 = nn.MaxPool2d(kernel_size=2, stride=2)
        # 128채널을 받아서 256채널로 반환(3번째 레이어)
        self.contracting_31 = self.conv_block(in_channels=128, out_channels=256)
        self.contracting_32 = nn.MaxPool2d(kernel_size=2, stride=2)
        # 64채널을 받아서 128채널로 반환(4번째 레이어)
        self.contracting_41 = self.conv_block(in_channels=256, out_channels=512)
        self.contracting_42 = nn.MaxPool2d(kernel_size=2, stride=2)
        # 512채널을 받아서 1024채널로 반환(5번째 레이어)
        self.middle = self.conv_block(in_channels=512, out_channels=1024)
        self.expansive_11 = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.expansive_12 = self.conv_block(in_channels=1024, out_channels=512)
        # 디코딩(확장)
        self.expansive_21 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.expansive_22 = self.conv_block(in_channels=512, out_channels=256)
        # 3
        self.expansive_31 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.expansive_32 = self.conv_block(in_channels=256, out_channels=128)
        # 4
        self.expansive_41 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.expansive_42 = self.conv_block(in_channels=128, out_channels=64)
        # 결과물
        self.output = nn.Conv2d(in_channels=64, out_channels=num_classes, kernel_size=1, stride=1)


    # n개의 채널을 받아서 n개의 채널로 내보내는 함수
    def conv_block(self, in_channels, out_channels):
        block = nn.Sequential(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1),
                             nn.ReLU(),
                             nn.BatchNorm2d(num_features=out_channels),
                             nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1),
                             nn.ReLU(),
                             nn.BatchNorm2d(num_features=out_channels))
        return block

    def forward(self, X):
        # 컨볼루션 연산 -> 채널 상승
        contracting_11_out = self.contracting_11(X) # [-1, 64, 256, 256] : 배치 / 채널 / 이미지사이즈
        # 맥스풀링 -> 이미지 크기 감소
        contracting_12_out = self.contracting_12(contracting_11_out) #[-1, 64, 128, 128]
        # 2
        contracting_21_out = self.contracting_21(contracting_12_out) #[-1, 128, 128, 128]
        contracting_22_out = self.contracting_22(contracting_21_out) #[-1, 128, 64, 64]
        # 3
        contracting_31_out = self.contracting_31(contracting_22_out) #[-1, 256, 64, 64]
        contracting_32_out = self.contracting_32(contracting_31_out) #[-1, 256, 32, 32]
        # 4
        contracting_41_out = self.contracting_41(contracting_32_out) #[-1, 512, 32, 32]
        contracting_42_out = self.contracting_42(contracting_41_out) #[-1, 512, 16, 16]
        middle_out = self.middle(contracting_42_out) # [-1, 1024, 16, 16]

    def forward(self, X):
        contracting_11_out = self.contracting_11(X) # [-1, 64, 256, 256]
        contracting_12_out = self.contracting_12(contracting_11_out) # [-1, 64, 128, 128]
        contracting_21_out = self.contracting_21(contracting_12_out) # [-1, 128, 128, 128]
        contracting_22_out = self.contracting_22(contracting_21_out) # [-1, 128, 64, 64]
        contracting_31_out = self.contracting_31(contracting_22_out) # [-1, 256, 64, 64]
        contracting_32_out = self.contracting_32(contracting_31_out) # [-1, 256, 32, 32]
        contracting_41_out = self.contracting_41(contracting_32_out) # [-1, 512, 32, 32]
        contracting_42_out = self.contracting_42(contracting_41_out) # [-1, 512, 16, 16]
        middle_out = self.middle(contracting_42_out) # [-1, 1024, 16, 16]
        expansive_11_out = self.expansive_11(middle_out) # [-1,512, 32, 32]
        expansive_12_out = self.expansive_12(torch.cat((expansive_11_out, contracting_41_out), dim=1)) # [-1, 1024,32,32] -> [-1, 512,32,32]
        expansive_21_out = self.expansive_21(expansive_12_out) # [-1, 256, 64, 64]
        expansive_22_out = self.expansive_22(torch.cat((expansive_21_out, contracting_31_out), dim=1)) # [-1, 512, 64, 64] -> [-1, 256, 64, 64]
        expansive_31_out = self.expansive_31(expansive_22_out) # [-1, 128, 128, 128]
        expansive_32_out = self.expansive_32(torch.cat((expansive_31_out, contracting_21_out), dim=1)) # [-1, 256, 128, 128] -> [-1, 128, 128, 128]
        expansive_41_out = self.expansive_41(expansive_32_out) # [-1, 64, 256, 256]
        expansive_42_out = self.expansive_42(torch.cat((expansive_41_out, contracting_11_out), dim=1)) # [-1, 128, 256, 256] -> [-1, 64, 256, 256]
        output_out = self.output(expansive_42_out) # [-1, num_classes, 256, 256]
        return output_out

In [7]:
# # UNET 클래스로 구현
# class UNet(nn.Module):
#     # 클래스의 갯수 상속
#     def __init__(self, num_classes):
#         # nn.Module(부모) 클래스의 기능을 상속
#         super(UNet, self).__init__()
#         self.num_classes = num_classes
#         # 3채널을 받아서 64로 반환(1번째 레이어)
#         self.contracting_11 = self.conv_block(in_channels=3, out_channels=64)
#         self.contracting_12 = nn.MaxPool2d(kernel_size=2, stride=2)
#         # 64채널을 받아서 128채널로 반환(2번째 레이어)
#         self.contracting_21 = self.conv_block(in_channels=64, out_channels=128)
#         self.contracting_22 = nn.MaxPool2d(kernel_size=2, stride=2)
#         # 128채널을 받아서 256채널로 반환(3번째 레이어)
#         self.contracting_31 = self.conv_block(in_channels=128, out_channels=256)
#         self.contracting_32 = nn.MaxPool2d(kernel_size=2, stride=2)
#         # 64채널을 받아서 128채널로 반환(4번째 레이어)
#         self.contracting_41 = self.conv_block(in_channels=256, out_channels=512)
#         self.contracting_42 = nn.MaxPool2d(kernel_size=2, stride=2)
#         # 512채널을 받아서 1024채널로 반환(5번째 레이어)
#         self.middle = self.conv_block(in_channels=512, out_channels=1024)
#         self.expansive_11 = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=3, stride=2, padding=1, output_padding=1)
#         self.expansive_12 = self.conv_block(in_channels=1024, out_channels=512)
#         # 디코딩(확장)
#         self.expansive_21 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=3, stride=2, padding=1, output_padding=1)
#         self.expansive_22 = self.conv_block(in_channels=512, out_channels=256)
#         # 3
#         self.expansive_31 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=3, stride=2, padding=1, output_padding=1)
#         self.expansive_32 = self.conv_block(in_channels=256, out_channels=128)
#         # 4
#         self.expansive_41 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=3, stride=2, padding=1, output_padding=1)
#         self.expansive_42 = self.conv_block(in_channels=128, out_channels=64)
#         # 결과물
#         self.output = nn.Conv2d(in_channels=64, out_channels=num_classes, kernel_size=3, stride=2)


#     # n개의 채널을 받아서 n개의 채널로 내보내는 함수
#     def conv_block(self, in_channels, out_channels):
#         block = nn.Sequential(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1),
#                              nn.ReLU(),
#                              nn.BatchNorm2d(num_features=out_channels),
#                              nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1),
#                              nn.ReLU(),
#                              nn.BatchNorm2d(num_features=out_channels))
#         return block

#     def forward(self, X):
#         # 컨볼루션 연산 -> 채널 상승
#         contracting_11_out = self.contracting_11(X) # [-1, 64, 256, 256] : 배치 / 채널 / 이미지사이즈
#         # 맥스풀링 -> 이미지 크기 감소
#         contracting_12_out = self.contracting_12(contracting_11_out) #[-1, 64, 128, 128]
#         # 2
#         contracting_21_out = self.contracting_21(contracting_12_out) #[-1, 128, 128, 128]
#         contracting_22_out = self.contracting_22(contracting_21_out) #[-1, 128, 64, 64]
#         # 3
#         contracting_31_out = self.contracting_31(contracting_22_out) #[-1, 256, 64, 64]
#         contracting_32_out = self.contracting_32(contracting_31_out) #[-1, 256, 32, 32]
#         # 4
#         contracting_41_out = self.contracting_41(contracting_32_out) #[-1, 512, 32, 32]
#         contracting_42_out = self.contracting_42(contracting_41_out) #[-1, 512, 16, 16]
#         middle_out = self.middle(contracting_42_out) # [-1, 1024, 16, 16]

#     def forward(self, X):
#         contracting_11_out = self.contracting_11(X) # [-1, 64, 256, 256]
#         contracting_12_out = self.contracting_12(contracting_11_out) # [-1, 64, 128, 128]
#         contracting_21_out = self.contracting_21(contracting_12_out) # [-1, 128, 128, 128]
#         contracting_22_out = self.contracting_22(contracting_21_out) # [-1, 128, 64, 64]
#         contracting_31_out = self.contracting_31(contracting_22_out) # [-1, 256, 64, 64]
#         contracting_32_out = self.contracting_32(contracting_31_out) # [-1, 256, 32, 32]
#         contracting_41_out = self.contracting_41(contracting_32_out) # [-1, 512, 32, 32]
#         contracting_42_out = self.contracting_42(contracting_41_out) # [-1, 512, 16, 16]
#         middle_out = self.middle(contracting_42_out) # [-1, 1024, 16, 16]
#         expansive_11_out = self.expansive_11(middle_out) # [-1,512, 32, 32]
#         expansive_12_out = self.expansive_12(torch.cat((expansive_11_out, contracting_41_out), dim=1)) # [-1, 1024,32,32] -> [-1, 512,32,32]
#         expansive_21_out = self.expansive_21(expansive_12_out) # [-1, 256, 64, 64]
#         expansive_22_out = self.expansive_22(torch.cat((expansive_21_out, contracting_31_out), dim=1)) # [-1, 512, 64, 64] -> [-1, 256, 64, 64]
#         expansive_31_out = self.expansive_31(expansive_22_out) # [-1, 128, 128, 128]
#         expansive_32_out = self.expansive_32(torch.cat((expansive_31_out, contracting_21_out), dim=1)) # [-1, 256, 128, 128] -> [-1, 128, 128, 128]
#         expansive_41_out = self.expansive_41(expansive_32_out) # [-1, 64, 256, 256]
#         expansive_42_out = self.expansive_42(torch.cat((expansive_41_out, contracting_11_out), dim=1)) # [-1, 128, 256, 256] -> [-1, 64, 256, 256]
#         output_out = self.output(expansive_42_out) # [-1, num_classes, 256, 256]
#         return output_out

# 3. 3_Dice Score와 IoU 메트릭 함수 정의하기

In [8]:
# Dice coefficient: 두 개의 집합 간의 유사성을 측정하는 데 사용되는 통계적인 지표
# 2×∣A∩B∣ / ∣A∣+∣B∣
# 0에서 1까지의 값을 가지며, 1에 가까울수록 두 집합이 유사
# 1e-6: 0.000001
def dice_coeff(input, target, reduce_batch_first: bool = False, epsilon: float = 1e-6):
    # assert 조건, "에러 메세지"
    assert input.size() == target.size()
    assert input.dim() == 3 or not reduce_batch_first

    sum_dim = (-1, -2) if input.dim() == 2 or not reduce_batch_first else (-1, -2, -3)
    inter = 2*(input * target).sum(dim=sum_dim)
    sets_sum = input.sum(dim=sum_dim) + target.sum(dim=sum_dim)
    # torch.where(): 조건에 따른 연산함수
    sets_sum = torch.where(sets_sum == 0, inter, sets_sum)

    dice = (inter + epsilon) / (sets_sum + epsilon)
    return dice.mean()


# IoU(Intersection over Union): 객체 검출 및 객체 분할과 같은 컴퓨터 비전 작업에서 사용되는 평가 지표
# 두 개의 영역 또는 객체가 주어졌을 때, IoU는 교집합을 합집합으로 나눈 것을 나타냄. 0에서 1사이의 값을 갖음
def iou(y_true, y_pred, epsilon: float = 1e-6):
    intersection = (y_true * y_pred).sum()
    union = y_true.sum() + y_pred.sum() - intersection
    return (intersection + epsilon) / (union + epsilon)

In [9]:
class AverageMeter(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def find_latest_model_path(dir):
    model_paths = []
    epochs = []
    for path in Path(dir).glob('*.pth'):
        if 'epoch' not in path.stem:
            continue
        model_paths.append(path)
        parts = path.stem.split('_')
        epoch = int(parts[-1])
        epochs.append(epoch)

    if len(epochs) > 0:
        epochs = np.array(epochs)
        max_idx = np.argmax(epochs)
        return model_paths[max_idx]
    else:
        return None

# param_groups: 옵티마이저 객체 속성. 요소는 딕셔너리이며 매개변수 그룹에 대한 정보를 저장
def adjust_learning_rate(optimizer, epoch, lr):
    lr = lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

# 4. 4_train/valid 함수정의

In [10]:
def visualize_images(input_img, target_mask, pred_mask, alpha):
    input_img = input_img.cpu().numpy().transpose(1, 2, 0)
    target_mask = target_mask.cpu().numpy().transpose(1, 2, 0)
    pred_mask = pred_mask.cpu().numpy().squeeze()

    fig, axs = plt.subplots(1, 2, figsize=(15, 5))
    axs[0].imshow(input_img)
    axs[0].imshow(target_mask,alpha=alpha)
    axs[0].set_title('Origianl Mask')
    axs[0].axis('off')

    axs[1].imshow(input_img)
    axs[1].imshow(pred_mask,alpha=alpha)
    axs[1].set_title('Predicted Mask')
    axs[1].axis('off')

In [11]:
def train(train_loader, model, criterion, optimizer, valid_loader, model_dir, n_epoch, batch_size, lr, device):
    latest_model_path = find_latest_model_path(model_dir)
    best_model_path = os.path.join(*[model_dir, 'model_best.pth'])

    if latest_model_path is not None:
        state = torch.load(latest_model_path)
        epoch = state['epoch']
        model.load_state_dict(state['model'])
        assert Path(best_model_path).exists() == True, f'best model path {best_model_path} does not exist!'
        best_state = torch.load(latest_model_path)
        min_val_los = best_state['valid_loss']

        print(f'Restored model: {epoch}, Min validation loss: {min_val_los}')
        epoch += 1
        print(f'{epoch}')
    else:
        print('epoch: 0')
        epoch = 0
        min_val_los = 9999

    valid_losses = []

    for epoch in range(epoch, n_epoch):
        adjust_learning_rate(optimizer, epoch, lr)
        tq = tqdm.tqdm(total=(len(train_loader) * batch_size), dynamic_ncols=True)
        tq.set_description(f'Epoch {epoch}')

        losses = AverageMeter()
        t_iou = 0
        t_dice = 0

        model.train()
        for i, (input, target) in enumerate(train_loader):
            # Variable: 텐서를 대체하기 위해 사용되던 클래스(과거 버전의 텐서)
            input_var = Variable(input).to(device)
            target_var = Variable(target).to(device)

            masks_pred = model(input_var)
            pred = torch.softmax(masks_pred, dim=1)
            target_mask = target_var.squeeze(1)
            # 예측 활성화 함수
            pred_classes = torch.argmax(pred, dim=1)
            # target_classes = torch.argmax(target_mask, dim=1).unsqueeze(1)

            t_dice += dice_coeff(pred, target_mask)
            t_iou += iou(pred, target_mask)

            masks_probs_flat = masks_pred.view(-1)
            true_masks_flat = target_var.view(-1)

            loss = criterion(masks_probs_flat, true_masks_flat)
            losses.update(loss)
            tq.set_postfix(loss='{:.5f}'.format(losses.avg))
            tq.update(batch_size)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f'train miou : {t_iou/len(train_loader):.5f} train dice score : {t_dice/len(train_loader):.5f}')
        valid_metrics = valid(model, valid_loader, criterion)
        valid_loss = valid_metrics['valid_loss']
        valid_dice = valid_metrics['v_dice']
        valid_iou = valid_metrics['v_iou']
        valid_losses.append(valid_loss)
        print(f'valid_loss = {valid_loss:.5f}')
        print(f'valid miou : {valid_iou/len(valid_loader):.5f} valid dice score : {valid_dice/len(valid_loader):.5f}')
        tq.close()

        visualize_images(input[epoch], target[epoch], pred[epoch], 0.4)

        epoch_model_path = os.path.join(*[model_dir, f'model_epoch_{epoch}.pth'])
        torch.save({
            'model': model.state_dict(),
            'epoch': epoch,
            'valid_loss': valid_loss,
            'train_loss': losses.avg
        }, epoch_model_path)

        if valid_loss < min_val_los:
            min_val_los = valid_loss

            torch.save({
                'model': model.state_dict(),
                'epoch': epoch,
                'valid_loss': valid_loss,
                'train_loss': losses.avg
            }, best_model_path)

In [12]:
# def train(train_loader, model, criterion, optimizer, valid_loader, model_dir, n_epoch, batch_size, lr, device):
#     latest_model_path = find_latest_model_path(model_dir)
#     best_model_path = os.path.join(*[model_dir, 'model_best.pth'])

#     if latest_model_path is not None:
#         state = torch.load(latest_model_path)
#         epoch = state['epoch']
#         model.load_state_dict(state['model'])
#         assert Path(best_model_path).exists() == True, f'best model path {best_model_path} does not exist!'
#         best_state = torch.load(latest_model_path)
#         min_val_los = best_state['valid_loss']

#         print(f'Restored model: {epoch}, Min validation loss: {min_val_los}')
#         epoch += 1
#         print(f'{epoch}')
#     else:
#         print('epoch: 0')
#         epoch = 0
#         min_val_los = 9999

#     valid_losses = []

#     for epoch in range(epoch, n_epoch):
#         adjust_learning_rate(optimizer, epoch, lr)
#         tq = tqdm.tqdm(total=(len(train_loader) * batch_size))
#         tq.set_description(f'Epoch {epoch}')

#         losses = AverageMeter()
#         t_iou = 0
#         t_dice = 0

#         model.train()
#         for i, (input, target) in enumerate(train_loader):
#             # Variable: 텐서를 대체하기 위해 사용되던 클래스(과거 버전의 텐서)
#             input_var = Variable(input).to(device)
#             target_var = Variable(target).to(device)

#             masks_pred = model(input_var)
#             pred = F.sigmoid(masks_pred)
#             target_mask = target_var

#             pred[pred>0.5] = 1
#             pred[pred<=0.5] = 0
#             target_mask[target_mask>0.5] = 1
#             target_mask[target_mask<=0.5] = 0

#             t_dice += dice_coeff(pred, target_mask)
#             t_iou += iou(pred, target_mask)

#             masks_probs_flat = masks_pred.view(-1)
#             true_masks_flat = target_var.view(-1)

#             loss = criterion(masks_probs_flat, true_masks_flat)
#             losses.update(loss)
#             tq.set_postfix(loss='{:.5f}'.format(losses.avg))
#             tq.update(batch_size)

#             optimizer.zero_grad()
#             loss.backward()
#             optimizer.step()

#         print(f'train miou : {t_iou/len(train_loader):.5f} train dice score : {t_dice/len(train_loader):.5f}')
#         valid_metrics = valid(model, valid_loader, criterion)
#         valid_loss = valid_metrics['valid_loss']
#         valid_dice = valid_metrics['v_dice']
#         valid_iou = valid_metrics['v_iou']
#         valid_losses.append(valid_loss)
#         print(f'valid_loss = {valid_loss:.5f}')
#         print(f'valid miou : {valid_iou/len(valid_loader):.5f} valid dice score : {valid_dice/len(valid_loader):.5f}')
#         tq.close()

#         epoch_model_path = os.path.join(*[model_dir, f'model_epoch_{epoch}.pth'])
#         torch.save({
#             'model': model.state_dict(),
#             'epoch': epoch,
#             'valid_loss': valid_loss,
#             'train_loss': losses.avg
#         }, epoch_model_path)

#         if valid_loss < min_val_los:
#             min_val_los = valid_loss

#             torch.save({
#                 'model': model.state_dict(),
#                 'epoch': epoch,
#                 'valid_loss': valid_loss,
#                 'train_loss': losses.avg
#             }, best_model_path)

In [13]:
def valid(model, val_loader, criterion):
    losses = AverageMeter()
    v_iou = 0
    v_dice = 0
    model.eval()
    with torch.no_grad():
        for i, (input, target) in enumerate(val_loader):
            input_var = Variable(input).to(device)
            target_var = Variable(target).to(device)
            output = model(input_var)
            loss = criterion(output, target_var)
            losses.update(loss.item(), input_var.size(0))
            pred = torch.argmax(output, dim=1)
            target_mask = target_var

            # v_dice += dice_coeff(pred, target_mask)
            v_iou += iou(pred, target_mask)

    return {'valid_loss': losses.avg, 'v_dice': v_dice, 'v_iou': v_iou}

In [14]:
# def valid(model, val_loader, criterion):
#     losses = AverageMeter()
#     v_iou = 0
#     v_dice = 0
#     model.eval()
#     with torch.no_grad():
#         for i, (input, target) in enumerate(val_loader):
#             input_var = Variable(input).to(device)
#             target_var = Variable(target).to(device)
#             output = model(input_var)
#             loss = criterion(output, target_var)
#             losses.update(loss.item(), input_var.size(0))
#             pred = F.sigmoid(output)
#             target_mask = target_var

#             pred[pred>0.5] = 1
#             pred[pred<=0.5] = 0
#             target_mask[target_mask>0.5] = 1
#             target_mask[target_mask<=0.5] = 0

#             # v_dice += dice_coeff(pred, target_mask)
#             v_iou += iou(pred, target_mask)

#     return {'valid_loss': losses.avg, 'v_dice': v_dice, 'v_iou': v_iou}

# 5. 5_학습용 파라미터 정의

In [15]:
# 모델 저장 폴더
model_dir = '/content/model_weights2'
os.makedirs(model_dir, exist_ok=True)

# 데이터 저장 폴더
data_dir = '/content/02. Unet_Dataset'
DIR_IMG = os.path.join(data_dir, 'image')
DIR_MASK = os.path.join(data_dir, 'mask')

# Device 할당
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

lr = 0.001
momentum = 0.9
weight_decay = 1e-4
batch_size = 4
num_workers = 8
n_epoch = 5

In [16]:
device
print(torch.cuda.get_device_name())

Tesla T4


# 6. 6_학습 사전 TEST

In [17]:
# test_size = []
# for i in range(len(dataset)):
#     if dataset[i][0].shape != dataset[i][1].shape:
#         test_size.append(i)
#     else:
#         pass

In [18]:
# test_size

In [19]:
# def tensor_size(dataset, index):
#     print(dataset[index][0].shape)
#     print(dataset[index][1].shape)
# tensor_size(dataset, 647)

In [20]:
# dataiter = iter(train_loader)
# images, labels = next(dataiter)
# images, labels = images.to(device), labels.to(device)
# outputs = model(images)

# outputs.shape

In [21]:
# num_epochs = 2

# for epoch in range(num_epochs):
#     model.train()
#     train_loss = 0.0

#     for images, masks in train_loader:
#         images, masks = images.to(device), masks.to(device)

#         optimizer.zero_grad()
#         outputs = model(images)
#         loss = criterion(outputs, masks)
#         loss.backward()
#         optimizer.step()

#         train_loss += loss.item() * images.size(0)

#     train_loss = train_loss / len(train_loader.dataset)

#     print(f'Epoch {epoch+1}/{num_epochs}, Loss: {train_loss:.4f}')

# 7. 7_본 학습

In [22]:
# 모델 할당
model = UNet(num_classes=3)

# 옵티마이저 정의
# weight_decay: L2정규화, 모델이 과적합되지 않도록 패널티 값
optimizer = torch.optim.SGD(model.parameters(), lr,
                            momentum=momentum, weight_decay=weight_decay)

# 손실함수 정의
criterion = nn.CrossEntropyLoss().to(device)

# transforms
channel_mean = [0.485, 0.456, 0.406]
channel_stds = [0.229, 0.224, 0.225]
train_tfms = transforms.Compose([transforms.ToTensor(), transforms.Resize(256),
                                 transforms.Normalize(channel_mean, channel_stds)])
val_tfms = transforms.Compose([transforms.ToTensor(), transforms.Resize(256),
                                 transforms.Normalize(channel_mean, channel_stds)])
mask_tfms = transforms.Compose([transforms.ToTensor(), transforms.Resize(256)])

# 데이터셋
dataset = BreastDatasets(img_dir=DIR_IMG, img_transform=train_tfms, mask_dir=DIR_MASK,
                        mask_transform=mask_tfms)
train_size = int(0.85*len(dataset))
valid_size = len(dataset) - train_size

# train/ val 데이터셋 분할
train_dataset, valid_dataset = random_split(dataset, [train_size, valid_size])

# 데이터로더
train_loader = DataLoader(train_dataset, batch_size, shuffle=True, num_workers=num_workers)
valid_loader = DataLoader(valid_dataset, batch_size, shuffle=False, num_workers=num_workers)

model.to(device)

train(train_loader, model, criterion, optimizer, valid_loader, model_dir, n_epoch, batch_size, lr, device)



epoch: 0


Epoch 0: 100%|██████████| 664/664 [00:39<00:00, 18.40it/s, loss=nan]

train miou : nan train dice score : nan


RuntimeError: The size of tensor a (4) must match the size of tensor b (3) at non-singleton dimension 1

# 7. test

In [None]:
model = UNet(num_classes=3)  # num_classes는 출력 클래스 수에 따라 조정
state = torch.load('/content/model_weights/model_best.pth')
model.load_state_dict(state['model'])
model.to(device)
model.eval()  # 모델을 평가 모드로 전환

In [None]:
state['train_loss']

In [None]:
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])

image = Image.open('/content/benign18.png')
plt.imshow(image)
image = transform(image).unsqueeze(0)  # 배치 차원을 추가

In [None]:
#예측수행
with torch.no_grad():
    image = image.to(device)  # GPU 사용 시
    output = model(image)
    output = torch.sigmoid(output)  # 바이너리 세그멘테이션의 경우
    output = (output > 0.5).float()  # 임계값 0.5로 이진화

In [None]:
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 5))

plt.subplot(1, 2, 1)
plt.title('Input Image')
plt.imshow(image.squeeze().cpu().numpy().transpose(1, 2, 0))

plt.subplot(1, 2, 2)
plt.title('Segmentation Output')
plt.imshow(output.squeeze().cpu().numpy().transpose(1, 2, 0))

plt.show()
