## Load Library

In [75]:
## 라이브러리 불러오기
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import torch.nn.functional as F

## 라이브러리 불러오기
import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import glob
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torchvision.transforms as transforms
from torchvision import models

## Data Preprocessing

In [76]:
# Path setup.
# glob.glob() returns matches in arbitrary order, so both lists are sorted
# before being zipped together; otherwise a mask can be paired with the wrong
# image. This assumes the mask/ and original/ folders use matching file names.
train_mask_paths = sorted(glob.glob("../Attention_Unet_gray/Comprehensive data/train/mask/*.png"))
train_original_paths = sorted(glob.glob('../Attention_Unet_gray/Comprehensive data/train/original/*.png'))
train_mask_list = []
train_original_list = []
dir_save_train_np = '../Attention_Unet_gray/Comprehensive data/train_np'

# Create the output folder if it does not exist yet.
os.makedirs(dir_save_train_np, exist_ok=True)

# Resize every image to 256x256, convert it to 8-bit grayscale and store it as
# an .npy array in the output folder.
for train_mask_path, train_original_path in zip(train_mask_paths, train_original_paths):
    train_mask_list.append(np.array(Image.open(train_mask_path).resize((256, 256)).convert('L')))
    train_original_list.append(np.array(Image.open(train_original_path).resize((256, 256)).convert('L')))

for i, (train_mask, train_original) in enumerate(zip(train_mask_list, train_original_list)):

    label_ = np.asarray(train_mask)
    input_ = np.asarray(train_original)

    np.save(os.path.join(dir_save_train_np, 'label_%03d.npy' % i), label_)
    np.save(os.path.join(dir_save_train_np, 'input_%03d.npy' % i), input_)

# Same procedure for the validation split (sorted for the same reason).
val_mask_paths = sorted(glob.glob('../Attention_Unet_gray/Comprehensive data/val/mask/*.png'))
val_original_paths = sorted(glob.glob('../Attention_Unet_gray/Comprehensive data/val/original/*.png'))
val_mask_list = []
val_original_list = []
dir_save_val_np = '../Attention_Unet_gray/Comprehensive data/val_np'
os.makedirs(dir_save_val_np, exist_ok=True)

for val_mask_path, val_original_path in zip(val_mask_paths, val_original_paths):
    val_mask_list.append(np.array(Image.open(val_mask_path).resize((256, 256)).convert('L')))
    val_original_list.append(np.array(Image.open(val_original_path).resize((256, 256)).convert('L')))

for i, (val_mask, val_original) in enumerate(zip(val_mask_list, val_original_list)):

    label_ = np.asarray(val_mask)
    input_ = np.asarray(val_original)

    np.save(os.path.join(dir_save_val_np, 'label_%03d.npy' % i), label_)
    np.save(os.path.join(dir_save_val_np, 'input_%03d.npy' % i), input_)

## Network

In [77]:
class AttentionGate(nn.Module):
    """Additive attention gate that re-weights a skip-connection feature map.

    Both inputs go through their own 1x1 convolution + batch norm, the two
    projections are added and passed through ReLU, and a further 1x1
    convolution followed by a sigmoid turns the sum into a mask. The mask is
    multiplied element-wise with the skip features.

    Args:
        in_c: Channel count expected for both the gating signal and the skip.
        out_c: Channel count of the projections and of the mask.
    """

    def __init__(self, in_c, out_c):
        super().__init__()

        def make_projection():
            # 1x1 convolution followed by batch normalisation.
            return nn.Sequential(
                nn.Conv2d(in_c, out_c, kernel_size=1, padding=0),
                nn.BatchNorm2d(out_c),
            )

        self.Wg = make_projection()
        self.Ws = make_projection()
        self.relu = nn.ReLU(inplace=True)
        self.output = nn.Sequential(
            nn.Conv2d(out_c, out_c, kernel_size=1, padding=0),
            nn.Sigmoid(),
        )

    def forward(self, g, s):
        """Return ``s`` scaled by the attention mask computed from ``g`` and ``s``."""
        gate_proj = self.Wg(g)
        skip_proj = self.Ws(s)
        mask = self.output(self.relu(gate_proj + skip_proj))
        return mask * s

class ConvBNRelu(nn.Module):
    """Convolution -> batch normalisation -> ReLU, applied in that order.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the convolution.
        kernel_size: Convolution kernel size (default 3).
        stride: Convolution stride (default 1).
        padding: Convolution padding (default 1).
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        norm = nn.BatchNorm2d(out_channels)
        activation = nn.ReLU(inplace=True)
        self.layers = nn.Sequential(conv, norm, activation)

    def forward(self, x):
        """Run ``x`` through the three stacked layers."""
        return self.layers(x)

class Attention_Unet(nn.Module):
    """U-Net with attention-gated skip connections for 1-channel images.

    The encoder has four (conv, conv, max-pool) stages followed by a 1024
    channel bottleneck. Each decoder stage upsamples with a transposed
    convolution, gates the matching encoder feature map with an
    ``AttentionGate`` and concatenates the gated skip with the upsampled
    features before the next convolutions. The last layer is a 1x1
    convolution with one output channel and no activation, so the network
    returns raw logits.
    """

    def __init__(self):
        super(Attention_Unet, self).__init__()

        self.encoders = nn.ModuleList([
            ConvBNRelu(1, 64), ConvBNRelu(64, 64), nn.MaxPool2d(kernel_size=2),
            ConvBNRelu(64, 128), ConvBNRelu(128, 128), nn.MaxPool2d(kernel_size=2),
            ConvBNRelu(128, 256), ConvBNRelu(256, 256), nn.MaxPool2d(kernel_size=2),
            ConvBNRelu(256, 512), ConvBNRelu(512, 512), nn.MaxPool2d(kernel_size=2),
            ConvBNRelu(512, 1024)
        ])

        self.decoders = nn.ModuleList([
            ConvBNRelu(1024, 512), nn.ConvTranspose2d(512, 512, kernel_size=2, stride=2), AttentionGate(in_c=512, out_c=512),
            ConvBNRelu(2 * 512, 512), ConvBNRelu(512, 256), nn.ConvTranspose2d(256, 256, kernel_size=2, stride=2), AttentionGate(in_c=256, out_c=256),
            ConvBNRelu(2 * 256, 256), ConvBNRelu(256, 128), nn.ConvTranspose2d(128, 128, kernel_size=2, stride=2), AttentionGate(in_c=128, out_c=128),
            ConvBNRelu(2 * 128, 128), ConvBNRelu(128, 64), nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2), AttentionGate(in_c=64, out_c=64),
            ConvBNRelu(2 * 64, 64), nn.Conv2d(64, 1, kernel_size=1)
        ])

    def forward(self, x):
        """Compute segmentation logits.

        Args:
            x: Tensor of shape (N, 1, H, W). H and W must be divisible by 16
                so that the four 2x poolings and the four 2x upsamplings
                produce matching skip sizes.

        Returns:
            Tensor of shape (N, 1, H, W) holding raw logits.
        """
        # Remember the feature map that enters each pooling layer; these are
        # the skip connections, stored from shallowest to deepest.
        skips = []
        for encoder in self.encoders:
            if isinstance(encoder, nn.MaxPool2d):
                skips.append(x)
            x = encoder(x)

        # Dispatch on the layer type instead of on the list index: the
        # decoder list is not laid out with a fixed period, so index
        # arithmetic cannot identify the attention gates reliably.
        for decoder in self.decoders:
            if isinstance(decoder, AttentionGate):
                # Deepest skip first; x is the gating signal.
                attended_skip = decoder(x, skips.pop())
                # Concatenation doubles the channels, which is what the
                # following ConvBNRelu(2 * C, C) layer expects.
                x = torch.cat([attended_skip, x], dim=1)
            else:
                x = decoder(x)

        return x


## State_dict

In [78]:
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# net = Attention_Unet().to(device)
# print("Model's state_dict:")
# for param_tensor in net.state_dict():
#     print(param_tensor, "\t", net.state_dict()[param_tensor].size())

## DataLoader

In [79]:
class Dataset(torch.utils.data.Dataset):
    """Paired input/label dataset backed by .npy files in one folder.

    Files whose names start with ``label`` and ``input`` are collected and
    sorted by name separately; index ``i`` pairs the i-th entry of each list.

    Args:
        data_dir: Folder that contains the ``label*`` and ``input*`` files.
        transform: Optional callable applied to the sample dict.
    """

    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform

        file_names = os.listdir(self.data_dir)
        self.lst_label = sorted(name for name in file_names if name.startswith('label'))
        self.lst_input = sorted(name for name in file_names if name.startswith('input'))

    def __len__(self):
        """Number of samples, taken from the label file count."""
        return len(self.lst_label)

    def _read_scaled(self, file_name):
        # Load one array, divide by 255 and append a channel axis to
        # 2-D (channel-less, grayscale) arrays.
        array = np.load(os.path.join(self.data_dir, file_name)) / 255.0
        return array[:, :, np.newaxis] if array.ndim == 2 else array

    def __getitem__(self, index):
        """Return ``{'input': ..., 'label': ...}``, transformed when a transform is set."""
        label = self._read_scaled(self.lst_label[index])
        image = self._read_scaled(self.lst_input[index])

        sample = {'input': image, 'label': label}

        # When a transform is defined, the transformed sample is returned.
        return self.transform(sample) if self.transform else sample

## Transform

In [80]:
class ToTensor(object):
    """Convert the ``label`` and ``input`` arrays of a sample to float32 tensors.

    Each array has its last axis moved to the front (axes reordered with
    ``(2, 0, 1)``) before it is wrapped in a tensor.
    """

    @staticmethod
    def _as_channel_first_tensor(array):
        # Reorder axes, cast to float32, then wrap without copying again.
        return torch.from_numpy(array.transpose((2, 0, 1)).astype(np.float32))

    def __call__(self, data):
        return {key: self._as_channel_first_tensor(data[key]) for key in ('label', 'input')}

class Normalization(object):
    """Standardise the ``input`` entry of a sample; the label passes through.

    Args:
        mean: Value subtracted from the input.
        std: Value the centred input is divided by.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, data):
        untouched_label = data['label']
        standardised = (data['input'] - self.mean) / self.std
        return {'label': untouched_label, 'input': standardised}

class RandomFlip(object):
    """Randomly mirror a sample left-right and/or up-down.

    Two independent draws from ``np.random.rand()`` decide, in this order,
    whether ``np.fliplr`` and ``np.flipud`` are applied. Label and input
    always receive the same flips.
    """

    def __call__(self, data):
        label, image = data['label'], data['input']

        for flip in (np.fliplr, np.flipud):
            if np.random.rand() > 0.5:
                label, image = flip(label), flip(image)

        return {'label': label, 'input': image}
    
class RandomHorizontalFlip(object):
    """Mirror label and input left-right when ``np.random.rand()`` exceeds 0.5."""

    def __call__(self, data):
        label, image = data['label'], data['input']

        should_flip = np.random.rand() > 0.5
        if should_flip:
            label, image = np.fliplr(label), np.fliplr(image)

        return {'label': label, 'input': image}

class RandomVerticalFlip(object):
    """Mirror label and input up-down when ``np.random.rand()`` exceeds 0.5."""

    def __call__(self, data):
        label, image = data['label'], data['input']

        should_flip = np.random.rand() > 0.5
        if should_flip:
            label, image = np.flipud(label), np.flipud(image)

        return {'label': label, 'input': image}
    
class RandomRotation(object):
    """Rotate label and input together by one shared random angle.

    Args:
        degrees: ``(min, max)`` range, in degrees, the angle is drawn from.
            Defaults to ``(-45, 45)``, the range that used to be hard-coded.

    The sample values may be tensors or NumPy arrays (anything accepted by
    ``torch.as_tensor``); the rotated results are returned as NumPy arrays.
    """

    def __init__(self, degrees=(-45, 45)):
        self.degrees = degrees

    def __call__(self, data):
        label, input = data['label'], data['input']

        # Draw a single angle so that the image and its mask stay aligned.
        angle = transforms.RandomRotation.get_params(self.degrees)

        # rotate() works directly on tensors, so no detour through NumPy is
        # needed before the call.
        rotated_input = transforms.functional.rotate(torch.as_tensor(input), angle)
        rotated_label = transforms.functional.rotate(torch.as_tensor(label), angle)

        # Hand NumPy arrays back, as before.
        data = {'label': rotated_label.numpy(), 'input': rotated_input.numpy()}

        return data

# Random change of brightness, contrast, saturation and hue
class RandomColorJitter(object):
    """Apply torchvision's ``ColorJitter`` to the input only.

    The ``input`` entry must be a NumPy array (it is wrapped with
    ``torch.from_numpy``) and is returned as a NumPy array again. The label is
    passed through unchanged.

    NOTE(review): torchvision's colour ops appear to treat float tensors as
    lying in [0, 1] and clamp to that range - confirm against the torchvision
    docs before feeding already-standardised data.
    """

    def __init__(self):
        jitter_ranges = dict(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5)
        self.color_jitter = transforms.ColorJitter(**jitter_ranges)

    def __call__(self, data):
        original_label = data['label']
        jittered = self.color_jitter(torch.from_numpy(data['input']))

        # The label keeps its original values.
        return {'label': original_label, 'input': jittered.numpy()}

## Network save & load

In [81]:
## Save the network
def save(ckpt_dir, net, optim, epoch, best=False):
    """Write the model and optimizer state dicts to ``ckpt_dir``.

    Args:
        ckpt_dir: Checkpoint folder; created when it does not exist.
        net: Module whose ``state_dict()`` is stored under ``'net'``.
        optim: Optimizer whose ``state_dict()`` is stored under ``'optim'``.
        epoch: Only referenced by the commented-out per-epoch variant below.
        best: When true the checkpoint is written; otherwise nothing is saved.
    """
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)

    if not best:
        # Uncommenting the lines below would write an epoch-numbered file
        # whenever ``best`` is False.
        # torch.save({'net': net.state_dict(), 'optim': optim.state_dict()},
        #            "%s/best_Attention_Unet_gray_epoch%d.pth" % (ckpt_dir, epoch))
        return

    checkpoint = {'net': net.state_dict(), 'optim': optim.state_dict()}
    torch.save(checkpoint, "%s/best_Attention_Unet_gray.pth" % ckpt_dir)

def load(ckpt_dir, net, optim):
    """Restore ``net`` and ``optim`` from the newest checkpoint in ``ckpt_dir``.

    Checkpoints are the ``*.pth`` files of the folder. They are ordered by the
    integer that ends the file stem (``..._epoch12.pth`` -> 12, ``7.pth`` -> 7;
    files without one count as -1), ties broken by name, and the last one is
    loaded. The selection is therefore deterministic and independent of the
    order returned by ``os.listdir``.

    Args:
        ckpt_dir: Folder to search for checkpoints.
        net: Module that receives the stored ``'net'`` state dict.
        optim: Optimizer that receives the stored ``'optim'`` state dict.

    Returns:
        ``(net, optim, epoch)``. ``epoch`` is the number following ``epoch``
        in the loaded file name, or 0 when the folder is missing, holds no
        checkpoint, or the name carries no epoch number.
    """
    import re  # only needed here; kept local so the file's import block is untouched

    if not os.path.exists(ckpt_dir):
        return net, optim, 0

    def trailing_number(file_name):
        found = re.search(r'(\d+)\.pth$', file_name)
        return int(found.group(1)) if found else -1

    ckpt_lst = sorted(
        (f for f in os.listdir(ckpt_dir) if f.endswith(".pth")),
        key=lambda f: (trailing_number(f), f),
    )

    if not ckpt_lst:
        return net, optim, 0

    newest = ckpt_lst[-1]

    # Map the stored tensors onto the device the model currently lives on so
    # that a checkpoint written on a GPU can be opened on a CPU-only machine.
    first_param = next(net.parameters(), None)
    map_location = first_param.device if first_param is not None else None
    dict_model = torch.load(os.path.join(ckpt_dir, newest), map_location=map_location)

    net.load_state_dict(dict_model['net'])
    optim.load_state_dict(dict_model['optim'])

    # Extract the epoch from the file name; names without one yield 0.
    found = re.search(r'epoch(\d+)\.pth$', newest)
    epoch = int(found.group(1)) if found else 0

    return net, optim, epoch

## Set Training

In [82]:
# Set training parameters
lr = 1e-4
batch_size = 4
num_epoch = 1000

# Early stopping
patience_limit = 15
patience_check = 0

base_dir = "../Attention_Unet_gray"
data_dir = "../Attention_Unet_gray/Comprehensive data"
ckpt_dir = os.path.join(base_dir, "checkpoint")
os.makedirs(ckpt_dir, exist_ok=True)

# Transform (training): random augmentation first, standardisation last.
# RandomColorJitter has to see the image while it is still in [0, 1] (the
# Dataset divides by 255); torchvision's colour ops clamp float images to that
# range, so running them on already standardised data would clip it.
transform = transforms.Compose([
    RandomHorizontalFlip(),
    RandomVerticalFlip(),
    RandomFlip(),
    ToTensor(),
    RandomRotation(),  # random rotation
    RandomColorJitter(),  # random colour change
    Normalization(mean=(0.5), std=(0.225)),
])

# Transform (validation): no random augmentation, so that the validation loss
# driving the scheduler and early stopping is comparable between epochs.
transform_val = transforms.Compose([
    Normalization(mean=(0.5), std=(0.225)),
    ToTensor(),
])

# DataLoader
dataset_train = Dataset(data_dir=os.path.join(data_dir, 'train_np'), transform=transform)
loader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=0)

dataset_val = Dataset(data_dir=os.path.join(data_dir, 'val_np'), transform=transform_val)
loader_val = DataLoader(dataset_val, batch_size=batch_size, shuffle=False, num_workers=0)

# Create Network
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = Attention_Unet().to(device)

# Define loss function (expects raw logits from the network)
fn_loss = nn.BCEWithLogitsLoss().to(device)

# Set Optimizer with Elastic Net regularization
l1_lambda = 0  # L1 regularisation strength (hyper-parameter)
l2_lambda = 0.0001  # L2 regularisation strength (hyper-parameter), applied as weight decay
# NOTE: this name shadows the `torch.optim as optim` module alias imported at
# the top of the file; it is kept because later cells refer to `optim`.
optim = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=l2_lambda)

# ReduceLROnPlateau
scheduler = ReduceLROnPlateau(optim, mode='min', factor=0.5, patience=5, min_lr=1e-6)

# Set other ancillary variables
num_data_train = len(dataset_train)
num_data_val = len(dataset_val)

num_batch_train = np.ceil(num_data_train / batch_size)
num_batch_val = np.ceil(num_data_val / batch_size)

# Set other ancillary functions
fn_tonumpy = lambda x: x.to('cpu').detach().numpy().transpose(0, 2, 3, 1)
fn_denorm = lambda x, mean, std: (x * std) + mean
# The network emits logits, so they are turned into probabilities before the
# 0.5 threshold is applied.
fn_class = lambda x: 1.0 * (torch.sigmoid(x) > 0.5)

# Load the trained model, if any
st_epoch = 0
#net, optim, st_epoch = load(ckpt_dir=ckpt_dir, net=net, optim=optim) 

best_loss = float('inf')

# Helper for monitoring the learning rate
def get_lr(optimizer):
    """Return the ``'lr'`` of the first parameter group, or None when there is none."""
    first_group = next(iter(optimizer.param_groups), None)
    if first_group is None:
        return None
    return first_group['lr']

# IoU (intersection over union) helper
def calculate_iou(prediction, ground_truth, threshold=0.5):
    """Return the intersection-over-union of two masks given as tensors.

    Both tensors are binarised with ``> threshold`` first, so soft label
    values (for example produced by interpolation while resizing) are not
    counted as foreground merely because they are non-zero.

    Args:
        prediction: Tensor holding the predicted mask.
        ground_truth: Tensor holding the reference mask, same shape.
        threshold: Values strictly above it count as foreground (default 0.5).

    Returns:
        IoU as a Python float. When neither mask contains foreground the
        masks agree completely and 1.0 is returned instead of dividing by zero.
    """
    predicted_mask = prediction.detach().cpu().numpy() > threshold
    reference_mask = ground_truth.detach().cpu().numpy() > threshold

    union = np.logical_or(predicted_mask, reference_mask).sum()
    if union == 0:
        return 1.0

    intersection = np.logical_and(predicted_mask, reference_mask).sum()
    return float(intersection / union)

In [None]:
for epoch in range(st_epoch + 1, num_epoch + 1):
    # The learning rate only changes when scheduler.step() runs at the end of
    # an epoch, so it is read once here. This also keeps current_lr defined
    # for the validation log when the training loader yields no batches.
    current_lr = get_lr(optim)

    # ---------------- training ----------------
    net.train()
    loss_arr = []
    iou_arr = []

    for batch, data in enumerate(loader_train, 1):
        # forward pass
        label = data['label'].to(device)
        input = data['input'].to(device)

        output = net(input)

        # backward pass
        optim.zero_grad()

        loss = fn_loss(output, label)

        # Elastic-net style regularisation: the L2 part is already applied by
        # the optimizer through weight_decay=l2_lambda, so only the L1 part is
        # added here, and the parameter norms are only computed when enabled.
        if l1_lambda > 0:
            l1_loss = sum(param.abs().sum() for param in net.parameters())
            loss = loss + l1_lambda * l1_loss

        loss.backward()
        optim.step()

        # Calculate the loss function
        loss_arr += [loss.item()]

        # Calculate the IoU
        prediction = fn_class(output)
        iou = calculate_iou(prediction, label)
        iou_arr += [iou]

        print("TRAIN: EPOCH %04d / %04d | BATCH %04d / %04d | LOSS %.4f | IOU %.4f | Patience %d | lr %.7f" %
                (epoch, num_epoch, batch, num_batch_train, np.mean(loss_arr), np.mean(iou_arr), patience_check, current_lr))

    # ---------------- validation ----------------
    with torch.no_grad():
        net.eval()
        val_loss_arr = []
        iou_arr = []

        for batch, data in enumerate(loader_val, 1):
            # forward pass
            label = data['label'].to(device)
            input = data['input'].to(device)

            output = net(input)

            # Validation reports the plain data loss; the regularisation
            # penalty is a training device and would distort model selection.
            loss = fn_loss(output, label)

            val_loss_arr += [loss.item()]

            # Calculate the IoU
            prediction = fn_class(output)
            iou = calculate_iou(prediction, label)
            iou_arr += [iou]

            print("VAILD: EPOCH %04d / %04d | BATCH %04d / %04d | LOSS %.4f | IOU %.4f | Patience %d | lr %.7f" %
                (epoch, num_epoch, batch, num_batch_val, np.mean(val_loss_arr), np.mean(iou_arr), patience_check, current_lr))

    mean_loss = np.mean(val_loss_arr)

    # Step the ReduceLROnPlateau scheduler with the validation loss
    val_loss = mean_loss
    scheduler.step(val_loss)

    # Early stopping
    if mean_loss < best_loss:  # the loss improved
        best_loss = mean_loss
        patience_check = 0  # reset the patience counter on every improvement
        save(ckpt_dir=ckpt_dir, net=net, optim=optim, epoch=epoch, best=True)
    else:
        patience_check += 1  # no improvement: increase the patience counter

    if patience_check >= patience_limit:  # early-stopping condition met
        break
        

## Inference

폴더에 있는 모든 이미지 파일을 inference하여 npy, png 파일로 저장

데이터 로더, 트랜스폼 Label만 할 수 있게 수정하여 다시 실행시켜줘야함.

In [83]:
# Pre-process the inference images into .npy files

# glob.glob() returns matches in arbitrary order; sorting makes the index used
# in the output file names reproducible, so that input_NNN.npy always refers to
# the same source image.
infer_data_paths = sorted(glob.glob('../Attention_Unet_gray/Thyroid_image/*png'))

dir_save_infer_np = '../Attention_Unet_gray/Infer_npy_data'

os.makedirs(dir_save_infer_np, exist_ok=True)

for i, infer_data_path in enumerate(infer_data_paths):
    # Resize to 256x256 and convert to 8-bit grayscale, as done for training.
    infer_data = np.array(Image.open(infer_data_path).resize((256, 256)).convert('L'))
    input_ = np.asarray(infer_data)
    np.save(os.path.join(dir_save_infer_np, 'input_%03d.npy' % i), input_)


In [84]:
# Data loader for inference
class Dataset_infer(torch.utils.data.Dataset):
    """Input-only dataset backed by the ``input*`` .npy files of one folder.

    Args:
        data_dir: Folder that contains the ``input*`` files.
        transform: Optional callable applied to the sample dict.
    """

    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform

        # Keep only the input files, ordered by name.
        self.lst_input = sorted(
            name for name in os.listdir(self.data_dir) if name.startswith('input')
        )

    def __len__(self):
        """Number of input files found."""
        return len(self.lst_input)

    def __getitem__(self, index):
        """Return ``{'input': ...}``, transformed when a transform is set."""
        file_path = os.path.join(self.data_dir, self.lst_input[index])

        # Scale by 1/255.
        image = np.load(file_path) / 255.0

        # A 2-D array (no channel axis, grayscale) gets a trailing channel axis.
        if image.ndim == 2:
            image = image[:, :, np.newaxis]

        sample = {'input': image}

        # When a transform is defined, the transformed sample is returned.
        return self.transform(sample) if self.transform else sample

In [85]:
# Transforms for inference
class ToTensor_infer(object):
    """Convert the ``input`` array of a sample to a float32 tensor.

    The last axis of the array is moved to the front (axes reordered with
    ``(2, 0, 1)``) before it is wrapped in a tensor.
    """

    def __call__(self, data):
        channel_first = data['input'].transpose((2, 0, 1)).astype(np.float32)
        return {'input': torch.from_numpy(channel_first)}

class Normalization_infer:
    """Standardise the ``input`` entry of an inference sample.

    Args:
        mean: Value subtracted from the input.
        std: Value the centred input is divided by.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, data):
        standardised = (data['input'] - self.mean) / self.std
        return {'input': standardised}

In [None]:
base_dir = '../Attention_Unet_gray/'
infer_base_dir = '../Attention_Unet_gray/'
infer_data_dir = '../Attention_Unet_gray/'
ckpt_dir = os.path.join(base_dir, "checkpoint")

# The statistics must be the ones the network was trained with
# (Normalization(mean=0.5, std=0.225) in the training transform); a different
# std rescales every input and degrades the predictions.
infer_mean, infer_std = 0.5, 0.225
# Batch size of the inference loader; kept in its own variable so that the
# training batch_size is neither reused by accident nor overwritten.
batch_size_test = 8

transform = transforms.Compose([Normalization_infer(mean=infer_mean, std=infer_std), ToTensor_infer()])

dataset_test = Dataset_infer(data_dir=os.path.join(infer_data_dir, 'Infer_npy_data'), transform=transform)
loader_test = DataLoader(dataset_test, batch_size=batch_size_test, shuffle=False, num_workers=0)

# Create the result folders. exist_ok covers the case where 'result' already
# exists but one of its sub-folders does not.
result_dir = os.path.join(infer_base_dir, 'result')
os.makedirs(os.path.join(result_dir, 'png'), exist_ok=True)
os.makedirs(os.path.join(result_dir, 'numpy'), exist_ok=True)

net, optim, st_epoch = load(ckpt_dir=ckpt_dir, net=net, optim=optim)

# Other ancillary variables
num_data_test = len(dataset_test)
num_batch_test = np.ceil(num_data_test / batch_size_test)

with torch.no_grad():
    net.eval()

    for batch, data in enumerate(loader_test, 1):

        # forward pass
        input = data['input'].to(device)
        output = net(input)

        input = fn_tonumpy(fn_denorm(input, mean=infer_mean, std=infer_std))
        output = fn_tonumpy(fn_class(output))

        # Save the results
        for j in range(input.shape[0]):
            # Global index of the sample: full batches seen so far times the
            # loader batch size, plus the position inside this batch. Every
            # sample therefore gets a unique file name.
            sample_id = batch_size_test * (batch - 1) + j

            #plt.imsave(os.path.join(result_dir, 'png', 'input_%04d.png' % sample_id), input[j].squeeze(), cmap='gray')
            plt.imsave(os.path.join(result_dir, 'png', 'output_%04d.png' % sample_id), output[j].squeeze(), cmap='gray')

            #np.save(os.path.join(result_dir, 'numpy', 'input_%04d.npy' % sample_id), input[j].squeeze())
            np.save(os.path.join(result_dir, 'numpy', 'output_%04d.npy' % sample_id), output[j].squeeze())