# **1. Semantic Segmentation**

### 균열 데이터셋을 활용한 Segmentation

In [1]:
import torch
import torchvision.transforms as transforms
import os
import torch.nn.functional as F
import tqdm
import random
import shutil
import numpy as np
import scipy.ndimage as ndimage
import matplotlib.pyplot as plt

from PIL import Image
from pathlib import Path
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset, random_split

In [2]:
# Fix every random seed used by the notebook so that runs are reproducible.
random_seed = 2024
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)
# Request deterministic cuDNN kernels and switch off the cuDNN auto-tuner.
# The flag is spelled `benchmark`; the previous `mbenchmark` spelling only
# created an unrelated attribute and left the real setting untouched.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(random_seed)
random.seed(random_seed)

In [3]:
# 데이터셋 클래스 정의
class CrackDatasets(Dataset):
    """Paired image / mask dataset for crack segmentation.

    Every entry of ``img_dir`` whose extension is ``png`` or ``jpg`` is treated
    as one sample; its mask is looked up under the same file name inside
    ``mask_dir``.

    Args:
        img_dir: Directory containing the input images.
        mask_dir: Directory containing the masks (same file names as the images).
        img_transform: Callable applied to each opened image, or ``None``.
        mask_transform: Callable applied to each opened mask, or ``None``.
    """

    def __init__(self, img_dir, mask_dir, img_transform, mask_transform):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.img_transform = img_transform
        self.mask_transform = mask_transform
        self.img_files = []
        # Was initialised as `mask_file`, which made the append below fail.
        self.mask_files = []
        self.seed = np.random.randint(2024)
        # os.listdir() returns entries in arbitrary order; sorting keeps the
        # index -> file mapping stable between runs and machines.
        for img_name in sorted(os.listdir(self.img_dir)):
            if img_name.split('.')[-1] in ('png', 'jpg'):
                self.img_files.append(os.path.join(self.img_dir, img_name))
                self.mask_files.append(os.path.join(self.mask_dir, img_name))

    def __len__(self):
        """Return the number of image / mask pairs."""
        return len(self.img_files)

    def __getitem__(self, i):
        """Return the ``(image, mask)`` pair at index ``i`` after the transforms."""
        img = Image.open(self.img_files[i])
        if self.img_transform is not None:
            random.seed(self.seed)
            img = self.img_transform(img)
        mask = Image.open(self.mask_files[i]).convert('L')  # load the mask as greyscale
        if self.mask_transform is not None:
            # Re-seed so the mask transform starts from the same `random` state
            # as the image transform did; seeding only once left the two out
            # of step.
            # NOTE(review): this only synchronises transforms that draw from
            # Python's `random` module - confirm which RNG the transforms use.
            random.seed(self.seed)
            mask = self.mask_transform(mask)
        return img, mask

# **2. UNET**
* 의료 영상 분석에 자주 사용되는 딥러닝 아키텍처

<img src='https://miro.medium.com/v2/resize:fit:1400/1*qNdglJ1ORP3Gq77MmBLhHQ.png' width=600>

* 인코딩(축소)
  * 일반적인 컨볼루션 신경망(CNN)
  * 컨볼루션과 풀링을 통해 이미지가 점점 작아짐
  * 고수준의 특징을 추출
* 디코딩(확장)
  * 축소 부분에서 얻은 특징을 사용하여 이미지를 원래 크기로 다시 확장
  * 업샘플링 연산을 사용하여 이미지의 크기를 증가
  * 확장 과정에서 축소 부분에서의 특징을 연결하여 정보를 유지할 수 있도록 함
* Skip Connections
  * U-Net의 핵심 특징
  * 축소 부분의 각 레이어에서의 출력은 확장 부분의 해당 레이어와 연결
  * 고수준의 정보와 저수준의 정보가 결합되어 세부적인 부분까지 정확한 이미지 분할이 가능

In [4]:
class UNet(nn.Module):
    """U-Net style encoder/decoder for dense per-pixel prediction.

    The encoder runs four convolution blocks (64, 128, 256, 512 channels), each
    followed by a 2x2 max-pool. A 1024-channel block sits at the bottom. The
    decoder mirrors the encoder: every stage doubles the resolution with a
    transposed convolution, concatenates the encoder feature map of the same
    resolution along the channel axis and applies another convolution block.

    Because the input is halved four times and doubled four times, height and
    width need to be multiples of 16 for the concatenations to line up.

    Args:
        num_classes: Number of channels produced by the final convolution.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes

        pool_cfg = dict(kernel_size=2, stride=2)
        # With these settings a transposed convolution exactly doubles H and W.
        up_cfg = dict(kernel_size=3, stride=2, padding=1, output_padding=1)

        # --- contracting path -------------------------------------------
        self.contracting_11 = self.conv_block(3, 64)
        self.contracting_12 = nn.MaxPool2d(**pool_cfg)
        self.contracting_21 = self.conv_block(64, 128)
        self.contracting_22 = nn.MaxPool2d(**pool_cfg)
        self.contracting_31 = self.conv_block(128, 256)
        self.contracting_32 = nn.MaxPool2d(**pool_cfg)
        self.contracting_41 = self.conv_block(256, 512)
        self.contracting_42 = nn.MaxPool2d(**pool_cfg)

        # --- bottleneck ---------------------------------------------------
        self.middle = self.conv_block(512, 1024)

        # --- expansive path -----------------------------------------------
        # Each conv block takes twice the up-sampled channel count because the
        # matching encoder output is concatenated in forward().
        self.expansive_11 = nn.ConvTranspose2d(1024, 512, **up_cfg)
        self.expansive_12 = self.conv_block(1024, 512)
        self.expansive_21 = nn.ConvTranspose2d(512, 256, **up_cfg)
        self.expansive_22 = self.conv_block(512, 256)
        self.expansive_31 = nn.ConvTranspose2d(256, 128, **up_cfg)
        self.expansive_32 = self.conv_block(256, 128)
        self.expansive_41 = nn.ConvTranspose2d(128, 64, **up_cfg)
        self.expansive_42 = self.conv_block(128, 64)

        self.output = nn.Conv2d(64, num_classes, kernel_size=3, stride=1, padding=1)

    def conv_block(self, in_channels, out_channels):
        """Build two 3x3 same-padding convolutions, each followed by ReLU and BatchNorm."""
        conv_cfg = dict(kernel_size=3, stride=1, padding=1)
        layers = [
            nn.Conv2d(in_channels, out_channels, **conv_cfg),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
            nn.Conv2d(out_channels, out_channels, **conv_cfg),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
        ]
        return nn.Sequential(*layers)

    def forward(self, X):
        """Map a ``[N, 3, H, W]`` batch to ``[N, num_classes, H, W]`` raw scores."""
        # Encoder: keep the pre-pooling activations for the skip connections.
        skip_1 = self.contracting_11(X)                             # [N, 64, H, W]
        skip_2 = self.contracting_21(self.contracting_12(skip_1))   # [N, 128, H/2, W/2]
        skip_3 = self.contracting_31(self.contracting_22(skip_2))   # [N, 256, H/4, W/4]
        skip_4 = self.contracting_41(self.contracting_32(skip_3))   # [N, 512, H/8, W/8]

        features = self.middle(self.contracting_42(skip_4))         # [N, 1024, H/16, W/16]

        # Decoder: up-sample, concatenate the matching skip, then convolve.
        upsampled = self.expansive_11(features)
        features = self.expansive_12(torch.cat((upsampled, skip_4), dim=1))  # [N, 512, H/8, W/8]
        upsampled = self.expansive_21(features)
        features = self.expansive_22(torch.cat((upsampled, skip_3), dim=1))  # [N, 256, H/4, W/4]
        upsampled = self.expansive_31(features)
        features = self.expansive_32(torch.cat((upsampled, skip_2), dim=1))  # [N, 128, H/2, W/2]
        upsampled = self.expansive_41(features)
        features = self.expansive_42(torch.cat((upsampled, skip_1), dim=1))  # [N, 64, H, W]

        return self.output(features)

# **3. Dice Score와 IoU 메트릭 함수 정의하기**

In [5]:
# Dice coefficient: 두 개의 집합 간의 유사성을 측정하는 데 사용되는 통계적인 지표
# Dice = 2×|A∩B| / (|A| + |B|)
# 0에서 1까지의 값을 가지며, 1에 가까울수록 두 집합이 유사
# 1e-6: 0.000001
def dice_coeff(input, target, reduce_batch_first: bool = False, epsilon: float = 1e-6):
    """Return the mean Dice coefficient between ``input`` and ``target``.

    Args:
        input: Predicted mask tensor.
        target: Reference mask tensor with the same size as ``input``.
        reduce_batch_first: When true (only allowed for 3-D tensors) the third
            dimension from the end is summed together with the last two;
            otherwise only the last two dimensions are summed and the
            remaining ones are averaged.
        epsilon: Smoothing term added to numerator and denominator.

    Returns:
        A scalar tensor holding the averaged Dice score.

    Raises:
        AssertionError: If the sizes differ, or ``reduce_batch_first`` is set
            for a tensor that is not 3-D.
    """
    assert input.size() == target.size()
    assert input.dim() == 3 or not reduce_batch_first

    if reduce_batch_first and input.dim() != 2:
        axes = (-1, -2, -3)
    else:
        axes = (-1, -2)

    overlap = 2 * (input * target).sum(dim=axes)
    total = input.sum(dim=axes) + target.sum(dim=axes)
    # Where both masks are empty, substitute the (also zero) overlap so the
    # epsilon terms make the ratio evaluate to 1 instead of dividing by zero.
    total = torch.where(total == 0, overlap, total)

    return ((overlap + epsilon) / (total + epsilon)).mean()
    

# IoU(Intersection over Union): 객체 검출 및 객체 분할과 같은 컴퓨터 비전 작업에서 사용되는 평가 지표
# 두 개의 영역 또는 객체가 주어졌을 때, IoU는 교집합을 합집합으로 나눈 값을 나타냄. 0에서 1 사이의 값을 가짐
def iou(y_true, y_pred, epsilon: float = 1e-6):
    """Return the intersection-over-union of two masks, summed over all elements.

    Args:
        y_true: First mask.
        y_pred: Second mask; must be multipliable element-wise with ``y_true``.
        epsilon: Smoothing term added to numerator and denominator so that two
            empty masks do not divide by zero.

    Returns:
        ``(intersection + epsilon) / (union + epsilon)``.
    """
    overlap = (y_true * y_pred).sum()
    combined = y_true.sum() + y_pred.sum()
    # The overlap is counted once in each mask's sum, so remove one copy.
    return (overlap + epsilon) / (combined - overlap + epsilon)

# **4. 학습, 검증 함수 정의하기**

In [6]:
class AverageMeter(object):
    """Keep the latest value and a running weighted average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the latest value, the weighted sum, the count and the average."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the average.

        Args:
            val: Newly observed value.
            n: Weight of the observation (for example the batch size).
        """
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

def find_latest_model_path(dir):
    """Return the ``*.pth`` checkpoint in ``dir`` with the highest epoch number.

    Only files whose stem contains ``epoch`` are considered. The epoch number
    is the integer after the last underscore of the stem.

    Args:
        dir: Directory to search (not recursive).

    Returns:
        The ``Path`` of the newest epoch checkpoint, or ``None`` when there is none.

    Raises:
        ValueError: If a candidate's stem does not end in an integer.
    """
    candidates = [p for p in Path(dir).glob('*.pth') if 'epoch' in p.stem]
    if not candidates:
        return None

    epoch_numbers = [int(p.stem.split('_')[-1]) for p in candidates]
    # index(max(...)) picks the first occurrence of the largest epoch.
    return candidates[epoch_numbers.index(max(epoch_numbers))]

# param_groups: 옵티마이저 객체 속성. 요소는 딕셔너리이며 매개변수 그룹에 대한 정보를 저장
def adjust_learning_rate(optimizer, epoch, lr):
    """Apply step decay: set every param group's rate to ``lr`` cut 10x per 30 epochs.

    Args:
        optimizer: Object exposing ``param_groups``, a list of dicts.
        epoch: Current epoch index.
        lr: Base learning rate before decay.
    """
    decayed_lr = lr * (0.1 ** (epoch // 30))
    for group in optimizer.param_groups:
        group.update(lr=decayed_lr)

In [8]:
def train(train_loader, model, criterion, optimizer, valid_loader, model_dir, n_epoch, batch_size, lr, device):
    """Train ``model``, validating and checkpointing after every epoch.

    If ``model_dir`` already holds ``model_epoch_<n>.pth`` files, training
    resumes from the newest one and the best validation loss so far is read
    from ``model_best.pth``.

    Args:
        train_loader: Iterable of ``(input, target)`` training batches.
        model: Network to optimise; expected to already live on ``device``.
        criterion: Loss called as ``criterion(flat_logits, flat_targets)``.
        optimizer: Optimizer whose learning rate is reset at each epoch.
        valid_loader: Iterable of validation batches, passed on to ``valid``.
        model_dir: Directory where checkpoints are read and written.
        n_epoch: Training stops once the epoch index reaches this value.
        batch_size: Only used to scale the progress bar.
        lr: Base learning rate handed to ``adjust_learning_rate``.
        device: Device the batches are moved to.

    Returns:
        The validation losses of the epochs run by this call, in order.

    Raises:
        AssertionError: When resuming and ``model_best.pth`` is missing.
    """
    latest_model_path = find_latest_model_path(model_dir)
    best_model_path = os.path.join(model_dir, 'model_best.pth')

    if latest_model_path is not None:
        # map_location lets a checkpoint written on another device be restored.
        state = torch.load(latest_model_path, map_location=device)
        epoch = state['epoch']
        model.load_state_dict(state['model'])
        assert Path(best_model_path).exists(), f'best model path {best_model_path} does not exist!'
        # The best loss must come from the best checkpoint; reading the latest
        # one could lift the threshold and overwrite a better model later.
        best_state = torch.load(best_model_path, map_location=device)
        min_val_los = best_state['valid_loss']

        print(f'Restored model: {epoch}, Min validation loss: {min_val_los}')
        epoch += 1
        print(f'{epoch}')
    else:
        print('epoch: 0')
        epoch = 0
        min_val_los = float('inf')

    valid_losses = []

    for epoch in range(epoch, n_epoch):
        adjust_learning_rate(optimizer, epoch, lr)

        losses = AverageMeter()
        t_iou = 0.0
        t_dice = 0.0

        model.train()
        with tqdm.tqdm(total=(len(train_loader) * batch_size)) as tq:
            tq.set_description(f'Epoch {epoch}')
            for images, masks in train_loader:
                images = images.to(device)
                masks = masks.to(device)

                masks_pred = model(images)

                # Binarise the target without touching the loader's tensor.
                # The binarised target also feeds the loss, as before, when
                # the in-place thresholding reached the loss through aliasing.
                # NOTE(review): valid() computes its loss on the un-thresholded
                # target - confirm which of the two is intended.
                target_mask = (masks > 0.5).to(masks.dtype)

                # Metrics are for reporting only; keep them out of autograd.
                with torch.no_grad():
                    pred = (torch.sigmoid(masks_pred) > 0.5).to(masks_pred.dtype)
                    t_dice += dice_coeff(pred, target_mask).item()
                    t_iou += iou(pred, target_mask).item()

                loss = criterion(masks_pred.view(-1), target_mask.view(-1))
                # .item() detaches the value; storing the tensor kept every
                # batch's graph alive and put a tensor into the checkpoint.
                losses.update(loss.item(), images.size(0))
                tq.set_postfix(loss='{:.5f}'.format(losses.avg))
                tq.update(batch_size)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        print(f'train miou : {t_iou/len(train_loader):.5f} train dice score : {t_dice/len(train_loader):.5f}')
        valid_metrics = valid(model, valid_loader, criterion)
        valid_loss = valid_metrics['valid_loss']
        valid_dice = valid_metrics['v_dice']
        valid_iou = valid_metrics['v_iou']
        valid_losses.append(valid_loss)
        print(f'valid_loss = {valid_loss:.5f}')
        print(f'valid miou : {valid_iou/len(valid_loader):.5f} valid dice score : {valid_dice/len(valid_loader):.5f}')

        checkpoint = {
            'model': model.state_dict(),
            'epoch': epoch,
            'valid_loss': valid_loss,
            'train_loss': losses.avg,
        }
        torch.save(checkpoint, os.path.join(model_dir, f'model_epoch_{epoch}.pth'))

        if valid_loss < min_val_los:
            min_val_los = valid_loss
            torch.save(checkpoint, best_model_path)

    return valid_losses

In [9]:
def valid(model, val_loader, criterion, device=None):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        val_loader: Iterable of ``(input, target)`` batches.
        criterion: Loss called as ``criterion(logits, target)``.
        device: Device the batches are moved to. When ``None`` the device of
            the model's first parameter is used (CPU if it has none), so the
            function no longer depends on a module-level ``device`` variable.

    Returns:
        A dict with ``valid_loss`` (sample-weighted mean loss), ``v_dice`` and
        ``v_iou`` (per-batch scores summed over all batches; divide by the
        number of batches to get the mean).
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    losses = AverageMeter()
    v_iou = 0
    v_dice = 0
    model.eval()
    with torch.no_grad():
        for images, masks in val_loader:
            images = images.to(device)
            masks = masks.to(device)

            output = model(images)
            # The loss uses the target as delivered by the loader; only the
            # metrics below work on thresholded masks.
            loss = criterion(output, masks)
            losses.update(loss.item(), images.size(0))

            # Threshold into new tensors instead of overwriting in place.
            pred = (torch.sigmoid(output) > 0.5).to(output.dtype)
            target_mask = (masks > 0.5).to(masks.dtype)

            v_dice += dice_coeff(pred, target_mask)
            v_iou += iou(pred, target_mask)

    return {'valid_loss': losses.avg, 'v_dice': v_dice, 'v_iou': v_iou}

# **5. 파라미터 정의**

In [10]:
# Directory where model checkpoints are stored (created if missing).
model_dir = './model_weights'
os.makedirs(model_dir, exist_ok=True)

# Location of the training data: images and their masks.
data_dir = './crack_segmentation_dataset/train'
DIR_IMG, DIR_MASK = (os.path.join(data_dir, sub) for sub in ('images', 'masks'))

# Use the first CUDA device when available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Optimisation hyper-parameters.
lr, momentum, weight_decay = 0.001, 0.9, 1e-4

# Data loading and schedule.
batch_size = 8
num_workers = 8
n_epoch = 10