<a href="https://colab.research.google.com/github/Seongjin1225/AI_School_9th_Final_Project_TEAM_3/blob/main/Unet_%EC%BD%94%EB%93%9C_ver2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -q segmentation_models_pytorch
# !pip install -qU wandb
!pip install -q scikit-learn==1.0

# 📚 Import Libraries

In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
import pandas as pd
pd.options.plotting.backend = "plotly"
import random
from glob import glob
import os, shutil
from tqdm import tqdm
tqdm.pandas()
import time
import copy
import joblib
from collections import defaultdict
import gc
from IPython import display as ipd
from PIL import Image

# visualization
import cv2
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

# Sklearn
from sklearn.model_selection import StratifiedKFold, KFold, StratifiedGroupKFold
from sklearn.model_selection import train_test_split

# PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torch.cuda import amp

import timm

# Albumentations for augmentations
import albumentations as A
from albumentations.pytorch import ToTensorV2

import rasterio
from joblib import Parallel, delayed

# For colored terminal text
from colorama import Fore, Back, Style
c_  = Fore.GREEN
sr_ = Style.RESET_ALL

import warnings
warnings.filterwarnings("ignore")

# For descriptive error messages
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

# 📖 Data

In [None]:
# Train Data
directory_path_train = '/content/drive/MyDrive/Final_Project/blood-vessel-segmentation/train'

# Collect the sub-folders (one per volume) under the train directory.
def list_files_in_directory(directory_path = directory_path_train):
    """Return every sub-directory of ``directory_path`` except ``images``/``labels`` ones.

    The walk does not descend into ``images`` or ``labels`` folders and visits
    sibling directories in sorted order, so the result is deterministic.

    Args:
        directory_path: root directory to scan.

    Returns:
        List of ``os.path.join(root, name)`` paths in (sorted) walk order.
    """
    direc = []
    for root, dirs, files in os.walk(directory_path):
        # Prune in place: os.walk skips images/labels and walks the rest in sorted order.
        dirs[:] = sorted(d for d in dirs if d not in ("labels", "images"))
        for dire in dirs:
            direc.append(os.path.join(root, dire))
    return direc

train_folders = list_files_in_directory(directory_path_train)
train_folders  # notebook cell output: show the discovered training sub-folders

def count_total_img(folders = None):
    """Count the files in each folder's ``images`` and ``labels`` sub-directory.

    For a folder named ``kidney_3_dense`` only ``labels`` is counted; its images are
    taken from ``kidney_3_sparse`` by the path-collection code that follows.

    Args:
        folders: folder paths to inspect; when None, the module-level
            ``train_folders`` is used (looked up at call time, not definition time).

    Returns:
        dict with ``"path"`` (list of sub-directory paths) and ``"total_files"``
        (entry count of each sub-directory, same order).

    Raises:
        FileNotFoundError: if an expected sub-directory does not exist.
    """
    if folders is None:
        folders = train_folders
    path = []
    total_files = []
    for dire in folders:
        # Match by folder name so the check does not depend on the absolute Drive path.
        is_k3_dense = os.path.basename(os.path.normpath(dire)) == "kidney_3_dense"
        for subf in ("images", "labels"):
            if is_k3_dense and subf != "labels":
                continue
            _dir = os.path.join(dire, subf)
            total_sample = len(os.listdir(_dir))
            path.append(_dir)
            total_files.append(total_sample)
    return {"path": path, "total_files": total_files}

train_file_dir = count_total_img()

folders = train_file_dir
_paths = list(zip(folders['path'], folders['total_files']))


train_images = []
train_labels = []

# Images and masks are paired purely by list position later on (DataFrame rows), so
# every folder's files are sorted: glob() returns them in arbitrary order.
for path, total_files in _paths:
    split_text = path.split("/")

    if 'images' in split_text:
        train_images.extend(sorted(glob(f'{path}/*.tif')))

    if 'labels' in split_text:
        label_files = sorted(glob(f'{path}/*.tif'))
        train_labels.extend(label_files)

        if 'kidney_3_dense' in split_text:
            # kidney_3_dense has labels only. Pair each dense label with the
            # kidney_3_sparse image of the same file name (instead of an arbitrary
            # first-501 slice of the unsorted sparse images).
            sparse_image_dir = os.path.join(os.path.dirname(os.path.dirname(path)),
                                            'kidney_3_sparse', 'images')
            dense_images = [os.path.join(sparse_image_dir, os.path.basename(f)) for f in label_files]
            missing = [p for p in dense_images if not os.path.exists(p)]
            if missing:
                raise FileNotFoundError(
                    f'{len(missing)} kidney_3_dense labels have no matching '
                    f'kidney_3_sparse image, e.g. {missing[0]}')
            train_images.extend(dense_images)

print(len(train_images))  # 7429
print(len(train_labels))  # 7429


# Test Data
# NOTE(review): previously spelled 'tetst'; the sibling of 'train' is assumed to be
# 'test' — confirm against the Drive folder name.
test_directory = '/content/drive/MyDrive/Final_Project/blood-vessel-segmentation/test'

def list_files_in_directorys(directory_path = test_directory):
    """Return every sub-directory of ``directory_path`` except ``images`` ones.

    The walk does not descend into ``images`` folders and visits siblings in sorted
    order. A missing ``directory_path`` prints a warning, because ``os.walk`` would
    otherwise silently yield nothing.

    Args:
        directory_path: root directory to scan.

    Returns:
        List of ``os.path.join(root, name)`` paths in (sorted) walk order.
    """
    if not os.path.isdir(directory_path):
        print(f'Warning: directory not found: {directory_path}')
    direc = []
    for root, dirs, files in os.walk(directory_path):
        dirs[:] = sorted(d for d in dirs if d != "images")
        for dire in dirs:
            direc.append(os.path.join(root, dire))
    return direc

test_folders = list_files_in_directorys(test_directory)
test_folders  # notebook cell output: show the discovered test sub-folders

def count_total_imgs(folders = None):
    """Count the files in each folder's ``images`` sub-directory (test split).

    Args:
        folders: folder paths to inspect; when None, the module-level
            ``test_folders`` is used (looked up at call time).

    Returns:
        dict with ``"path"`` (list of ``<folder>/images`` paths) and ``"total_files"``
        (entry count of each, same order).

    Raises:
        FileNotFoundError: if a folder has no ``images`` sub-directory.
    """
    if folders is None:
        folders = test_folders
    path = []
    total_files = []
    for dire in folders:
        _dir = os.path.join(dire, "images")
        total_sample = len(os.listdir(_dir))
        path.append(_dir)
        total_files.append(total_sample)
    return {"path": path, "total_files": total_files}

test_file_dir = count_total_imgs()
folders = test_file_dir
_paths = list(zip(folders['path'], folders['total_files']))

test_images = []

for path, total_files in _paths:
    split_text = path.split("/")

    if 'images' in split_text:
        # Sorted so test slices come out in file-name order; glob() order is arbitrary.
        test_images.extend(sorted(glob(f'{path}/*.tif')))

print(len(test_images))

# One row per training sample; images and masks are paired by list position.
df = pd.DataFrame(data={"images": train_images, 'masks' : train_labels})

# ⚙️ Configuration

In [None]:
class CFG:
    """Experiment configuration, read as class attributes (e.g. ``CFG.lr``)."""

    # ---- run identification ----
    seed          = 101                             # random seed passed to set_seed()
    debug         = False                           # True -> small subset; False for the full run
    exp_name      = 'Baselinev2'                    # experiment name
    comment       = 'unet-efficientnet_b1-512x512'  # free-text run description
    model_name    = 'Unet'                          # model architecture name
    backbone      = 'efficientnet-b1'               # encoder (backbone) name

    # ---- data / batching ----
    train_bs      = 16                              # training batch size
    valid_bs      = 2 * train_bs                    # validation batch size
    val_split     = 0.2                             # fraction of rows held out for validation
    random_state  = 42                              # seed of the train/validation split
    img_size      = [512, 512]                      # resize target (height, width)

    # ---- optimisation ----
    epochs        = 5                               # number of training epochs
    lr            = 2e-3                            # learning rate
    scheduler     = 'CosineAnnealingLR'
    optimizers    = 'adam'
    min_lr        = 1e-6                            # lower bound on the learning rate
    # CosineAnnealingLR period. NOTE(review): 30000 is a hard-coded sample count, and the
    # scheduler is stepped once per n_accumulate batches, so this may be far longer than
    # the real number of scheduler steps — check against len(train_img_paths).
    T_max         = int(30000 / train_bs * epochs) + 50
    T_0           = 25                              # CosineAnnealingWarmRestarts period
    warmup_epochs = 0
    wd            = 1e-6                            # weight decay
    n_accumulate  = max(1, 32 // train_bs)          # batches per optimizer step (grad accumulation)

    # ---- misc ----
    n_fold        = 5                               # number of K-fold splits
    num_classes   = 1                               # output channels (1 for binary segmentation)
    device        = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    loss_func     = "DiceLoss"


# ❗ Reproducibility

In [None]:
# Seed every random number generator the notebook relies on.

def set_seed(seed = 42):
    '''Make results repeatable across runs by seeding all RNGs.

    Seeds NumPy, Python's ``random``, PyTorch and the current CUDA device, forces
    deterministic cuDNN kernels (autotuning off) and exports ``PYTHONHASHSEED``.
    The environment variable only affects Python processes started afterwards;
    the running interpreter's hash seed is fixed at startup.
    '''
    for seeder in (np.random.seed, random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)

    # cuDNN: no benchmark-based algorithm search, deterministic kernels only
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    os.environ['PYTHONHASHSEED'] = str(seed)
    print('> SEEDING DONE')

set_seed(CFG.seed)  # seed all RNGs with the configured seed before data/model setup

# 🔨 Utility

In [None]:
def load_img(path):
    """Load a slice as a float32 3-channel image scaled to [0, 1].

    The file is read unchanged (bit depth preserved). A 2-D grayscale slice is
    replicated into 3 identical channels to match the 3-channel model input. Values
    are divided by the image maximum; an all-zero image is left as zeros.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path`` (``cv2.imread`` returns None).
    """
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if img is None:
        raise FileNotFoundError(f"cv2.imread could not read image: {path}")
    if img.ndim == 2:
        img = np.tile(img[..., None], [1, 1, 3])  # gray -> 3 identical channels
    img = img.astype('float32')  # source is presumably uint16
    mx = np.max(img)
    if mx:
        img /= mx  # scale image to [0, 1]
    return img

def load_msk(path):
    """Load a mask as float32 divided by 255 (a 0/255 mask becomes 0/1).

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path`` (``cv2.imread`` returns None).
    """
    msk = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if msk is None:
        raise FileNotFoundError(f"cv2.imread could not read mask: {path}")
    msk = msk.astype('float32')
    msk /= 255.0
    return msk

In [None]:
def rle_decode(mask_rle, shape):
    '''
    Decode a run-length string into a binary mask.

    mask_rle: whitespace-separated "start length" pairs; starts are 1-based
              positions in the flattened (row-major) mask
    shape: (height, width) of the array to return
    Returns a uint8 numpy array of `shape`: 1 inside runs, 0 elsewhere.
    '''
    tokens = mask_rle.split()
    run_starts = np.asarray(tokens[::2], dtype=int) - 1  # convert to 0-based
    run_lengths = np.asarray(tokens[1::2], dtype=int)
    run_ends = run_starts + run_lengths

    flat = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for first, stop in zip(run_starts, run_ends):
        flat[first:stop] = 1
    return flat.reshape(shape)  # row-major, matching the RLE direction


# ref.: https://www.kaggle.com/stainsby/fast-tested-rle
def rle_encode(img):
    '''
    Run-length encode a binary mask.

    img: numpy array, 1 = mask, 0 = background (flattened in row-major order)
    Returns space-separated "start length" pairs with 1-based starts; an
    all-background mask gives an empty string.
    '''
    padded = np.concatenate([[0], img.flatten(), [0]])
    # 1-based positions where the value changes; run starts alternate with run ends.
    runs = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    runs[1::2] -= runs[::2]  # turn each end position into a run length
    return ' '.join(map(str, runs))

# 🌈 Augmentations

In [None]:
data_transforms = dict(
    # Training: resize, random flip, small affine jitter, an occasional warp, cutout.
    train=A.Compose(
        [
            A.Resize(*CFG.img_size, interpolation=cv2.INTER_NEAREST),
            A.HorizontalFlip(p=0.5),
            # A.VerticalFlip(p=0.5),  # disabled
            A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.05, rotate_limit=10, p=0.5),
            # With probability 0.25 apply exactly one of these warps.
            A.OneOf(
                [
                    A.GridDistortion(num_steps=5, distort_limit=0.05, p=1.0),
                    # A.OpticalDistortion(distort_limit=0.05, shift_limit=0.05, p=1.0),  # disabled
                    # NOTE(review): newer albumentations releases dropped `alpha_affine`
                    # from ElasticTransform — confirm against the installed version.
                    A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=1.0),
                ],
                p=0.25,
            ),
            A.CoarseDropout(
                max_holes=8,
                max_height=CFG.img_size[0] // 20,
                max_width=CFG.img_size[1] // 20,
                min_holes=5,
                fill_value=0,
                mask_fill_value=0,
                p=0.5,
            ),
        ],
        p=1.0,
    ),
    # Validation: deterministic resize only.
    valid=A.Compose([A.Resize(*CFG.img_size, interpolation=cv2.INTER_NEAREST)], p=1.0),
)

# 🍚 Dataset

In [None]:
# Same construction as in the data cell, repeated so this cell can be re-run on its own.
df = pd.DataFrame(data={"images": train_images, 'masks' : train_labels})

class BuildDataset(torch.utils.data.Dataset):
    """Dataset of slice images with optional aligned masks.

    Args:
        img_paths: image file paths, indexed by position.
        msk_paths: optional mask paths aligned with ``img_paths``; None means no masks.
        transforms: optional callable taking ``image=`` (and ``mask=``) keywords and
            returning a dict with ``'image'`` (and ``'mask'``), e.g. an A.Compose.
        label: when False, masks are ignored even if ``msk_paths`` is given.

    ``__getitem__`` returns ``(image, mask)`` tensors when masks are used; otherwise it
    returns ``(image, tensor([orig_height, orig_width]))``. Images are channel-first.
    """
    def __init__(self, img_paths, msk_paths=None, transforms=None, label=True):
        self.img_paths = img_paths
        # None default instead of a shared mutable [] default
        self.msk_paths = [] if msk_paths is None else msk_paths
        self.transforms = transforms
        self.label = label

    def __len__(self):
        return len(self.img_paths)

    def load_img(self, img_path):
        """Average a tensor over dim 1, keeping that dim.

        NOTE(review): despite its name this expects a tensor, not a path, and
        ``__getitem__`` does not call it (it uses the module-level ``load_img``).
        """
        return torch.mean(img_path, dim=1, keepdim=True)

    def _uses_masks(self):
        """True when items should include a mask."""
        return bool(self.label) and len(self.msk_paths) > 0

    def __getitem__(self, index):
        img = load_img(self.img_paths[index])

        if self._uses_masks():
            msk = load_msk(self.msk_paths[index])
            if self.transforms:
                data = self.transforms(image=img, mask=msk)
                img, msk = data['image'], data['mask']
            img = np.transpose(img, (2, 0, 1))  # HWC -> CHW, as PyTorch expects
            return torch.tensor(img), torch.tensor(msk)

        orig_size = img.shape
        if self.transforms:
            img = self.transforms(image=img)['image']
        img = np.transpose(img, (2, 0, 1))
        return torch.tensor(img), torch.tensor(np.array([orig_size[0], orig_size[1]]))

# 🍰 DataLoader

In [None]:
# Hold out CFG.val_split of the rows (shuffled with CFG.random_state) for validation.
# NOTE(review): rows are individual files, so neighbouring slices of one volume can land
# in both splits — consider a per-volume split if validation scores look optimistic.
train_df, valid_df = train_test_split(df, test_size=CFG.val_split, random_state=CFG.random_state)

# Plain Python lists of paths for BuildDataset.
train_img_paths, train_msk_paths = (train_df[col].tolist() for col in ("images", "masks"))
valid_img_paths, valid_msk_paths = (valid_df[col].tolist() for col in ("images", "masks"))

# Debug mode: keep only a few batches' worth of samples.
if CFG.debug:
    n_train = CFG.train_bs * 5
    n_valid = CFG.valid_bs * 3
    train_img_paths, train_msk_paths = train_img_paths[:n_train], train_msk_paths[:n_train]
    valid_img_paths, valid_msk_paths = valid_img_paths[:n_valid], valid_msk_paths[:n_valid]


In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

train_dataset = BuildDataset(train_img_paths, train_msk_paths, transforms=data_transforms['train'])
valid_dataset = BuildDataset(valid_img_paths, valid_msk_paths, transforms=data_transforms['valid'])

train_loader = DataLoader(train_dataset, batch_size=CFG.train_bs, num_workers=3, shuffle=True, pin_memory=True, drop_last=False)

# Batches stay 3-channel: build_model() creates the Unet with in_channels=3, so no
# channel reduction is applied here.
valid_loader = DataLoader(valid_dataset, batch_size=CFG.valid_bs, num_workers=3, shuffle=False, drop_last=False)

# 📈 Visualization

In [None]:
# Peek at one augmented training batch to check tensor shapes.
batch_images, batch_mask = next(iter(train_loader))
batch_images = torch.mean(batch_images, dim=1, keepdim=True)  # 3 identical channels -> 1
print('Batch', 1)
print('Image batch shape:', batch_images.shape)
print('Label batch shape:', batch_mask.shape)

# train_loader shuffles, so its batch order does not follow train_img_paths. Draw the
# displayed samples straight from train_dataset, where index i is train_img_paths[i].
for idx in range(min(CFG.train_bs, len(train_dataset))):
    image, mask = train_dataset[idx]
    image = (torch.mean(image, dim=0).numpy() * 255.0).astype('uint8')

    image_filename = os.path.basename(train_img_paths[idx])
    mask_filename = os.path.basename(train_msk_paths[idx])

    plt.figure(figsize=(15, 10))

    plt.subplot(1, 2, 1)
    plt.imshow(image, cmap='gray')
    plt.title(f'Original Image - {image_filename}')

    plt.subplot(1, 2, 2)
    plt.imshow(mask.numpy(), cmap='gray')
    plt.title(f'Original Mask - {mask_filename}')
    plt.show()

In [None]:
import gc
gc.collect()  # explicit garbage-collection pass to release unreferenced objects

# 📦 Model

In [None]:
import segmentation_models_pytorch as smp

def build_model():
    """Create the U-Net and move it to ``CFG.device``.

    Encoder ``CFG.backbone`` initialised from ImageNet weights, 3 input channels,
    ``CFG.num_classes`` output channels and no final activation (raw logits).
    """
    unet_kwargs = dict(
        encoder_name=CFG.backbone,     # e.g. mobilenet_v2 or efficientnet-b7
        encoder_weights="imagenet",    # ImageNet pre-trained encoder weights
        in_channels=3,                 # load_img replicates gray slices to 3 channels
        classes=CFG.num_classes,       # output channels
        activation=None,               # logits; callers apply the sigmoid
    )
    net = smp.Unet(**unet_kwargs)
    net.to(CFG.device)
    return net

def load_model(path):
    """Build a fresh model, load the state dict at ``path`` and set eval mode.

    ``map_location=CFG.device`` lets a checkpoint saved on GPU load on a CPU-only
    runtime (and vice versa) instead of failing on device deserialization.
    """
    model = build_model()
    model.load_state_dict(torch.load(path, map_location=CFG.device))
    model.eval()  # disable dropout / use running batch-norm statistics
    return model

# 🔧 Loss Function

In [None]:
# Loss objects from segmentation_models_pytorch; `criterion` selects DiceLoss or BCELoss
# by CFG.loss_func (JaccardLoss is defined but not used by criterion).
JaccardLoss = smp.losses.JaccardLoss(mode='binary')
DiceLoss    = smp.losses.DiceLoss(mode='binary')
BCELoss     = smp.losses.SoftBCEWithLogitsLoss()  # logits input: the model has activation=None
# LovaszLoss  = smp.losses.LovaszLoss(mode='multilabel', per_image=False)
# verskyLoss = smp.losses.TverskyLoss(mode='multilabel', log_loss=False)

import torch

smooth = 1e-5  # additive smoothing; two empty inputs give a Dice of 1

def dice_coef(y_true, y_pred):
    """Soft Dice coefficient pooled over every element of the batch.

    Both tensors are flattened with ``view(-1)``, so all pixels of all samples are
    scored together (no per-sample averaging). ``y_pred`` may hold probabilities.
    """
    truth = y_true.view(-1)
    guess = y_pred.view(-1)
    overlap = (truth * guess).sum()
    denominator = truth.sum() + guess.sum() + smooth
    return (2 * overlap + smooth) / denominator

def iou_coef(y_true, y_pred, thr=0.5, dim=(2,3), epsilon=0.001):
    """Mean IoU between a target mask and thresholded predictions.

    Args:
        y_true: ground truth, either shaped like ``y_pred`` or lacking its channel
            axis (e.g. ``(B, H, W)`` vs ``(B, 1, H, W)``). In the latter case a channel
            axis is inserted at dim 1, so sample i is compared with prediction i.
            Without it, broadcasting would pair every sample with every prediction.
        y_pred: predicted probabilities; values ``> thr`` count as foreground.
        thr: binarisation threshold.
        dim: axes summed per (sample, channel).
        epsilon: smoothing; an empty target and empty prediction give IoU 1.

    Returns:
        Tensor: IoU averaged over the remaining dims (0-d for 4-D inputs with the
        default ``dim``).
    """
    y_true = y_true.to(torch.float32)
    if y_true.dim() == y_pred.dim() - 1:
        y_true = y_true.unsqueeze(1)
    y_pred = (y_pred > thr).to(torch.float32)
    inter = (y_true * y_pred).sum(dim=dim)
    union = (y_true + y_pred - y_true * y_pred).sum(dim=dim)
    return ((inter + epsilon) / (union + epsilon)).mean(dim=(1, 0))

def criterion(y_pred, y_true):
    """Compute the loss selected by ``CFG.loss_func``.

    ``"DiceLoss"`` passes both tensors to the binary Dice loss. ``"BCELoss"`` first
    adds a channel axis to ``y_true``; masks presumably arrive as ``(B, H, W)`` and
    logits as ``(B, 1, H, W)`` — confirm for new callers.

    Raises:
        ValueError: for an unrecognised ``CFG.loss_func``, instead of silently
            returning None.
    """
    if CFG.loss_func == "DiceLoss":
        return DiceLoss(y_pred, y_true)
    if CFG.loss_func == "BCELoss":
        return BCELoss(y_pred, y_true.unsqueeze(1))
    raise ValueError(f"Unknown CFG.loss_func: {CFG.loss_func!r}")

# 🚄 Training Function

In [None]:
# Train the model for one epoch.
def train_one_epoch(model, optimizer, scheduler, dataloader, device, epoch):
    """Train ``model`` for one epoch with mixed precision and gradient accumulation.

    Gradients of ``CFG.n_accumulate`` consecutive batches are accumulated before each
    optimizer (and scheduler) step. A trailing partial group at the end of the epoch
    is also stepped, so its gradients do not leak into the next epoch.

    Args:
        model: network to train (switched to train mode).
        optimizer: optimizer over ``model``'s parameters.
        scheduler: LR scheduler stepped after every optimizer step, or None.
        dataloader: yields ``(images, masks)`` batches.
        device: device the batches are moved to.
        epoch: epoch number, shown in the progress bar only.

    Returns:
        Mean per-sample training loss, as returned by ``criterion`` (not divided by
        ``CFG.n_accumulate``, so it is comparable with the validation loss).
    """
    model.train()
    scaler = amp.GradScaler()  # loss scaling for mixed-precision backward

    dataset_size = 0
    running_loss = 0.0
    epoch_loss = 0.0  # stays 0.0 if the dataloader is empty
    num_steps = len(dataloader)

    optimizer.zero_grad()

    pbar = tqdm(enumerate(dataloader), total=num_steps, desc='Train ')
    for step, (images, masks) in pbar:
        images = images.to(device, dtype=torch.float)
        masks  = masks.to(device, dtype=torch.float)

        batch_size = images.size(0)

        with amp.autocast(enabled=True):
            y_pred = model(images)
            loss   = criterion(y_pred, masks)

        # Only the back-propagated loss is divided, so accumulated gradients average over
        # the group while `loss` stays unscaled for reporting.
        scaler.scale(loss / CFG.n_accumulate).backward()

        if (step + 1) % CFG.n_accumulate == 0 or (step + 1) == num_steps:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

            if scheduler is not None:
                scheduler.step()

        running_loss += loss.item() * batch_size
        dataset_size += batch_size
        epoch_loss = running_loss / dataset_size

        mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0
        current_lr = optimizer.param_groups[0]['lr']
        pbar.set_postfix(epoch=f'{epoch}',
                         train_loss=f'{epoch_loss:0.4f}',
                         lr=f'{current_lr:0.5f}',
                         gpu_mem=f'{mem:0.2f} GB')
    torch.cuda.empty_cache()
    gc.collect()
    return epoch_loss

# 👀 Validation Function

In [None]:
@torch.no_grad()  # evaluation only: no autograd graph is recorded
def valid_one_epoch(model, dataloader, device, epoch):
    """Evaluate ``model`` on ``dataloader`` for one epoch.

    Args:
        model: network to evaluate (switched to eval mode).
        dataloader: yields ``(images, masks)`` batches.
        device: device the batches are moved to.
        epoch: epoch number (kept for interface symmetry with ``train_one_epoch``).

    Returns:
        ``(epoch_loss, val_scores)``: mean per-sample loss, and a numpy array
        ``[mean Dice, mean IoU]`` averaged over batches.
    """
    model.eval()

    dataset_size = 0
    running_loss = 0.0
    epoch_loss = 0.0  # stays 0.0 if the dataloader is empty
    val_scores = []

    pbar = tqdm(enumerate(dataloader), total=len(dataloader), desc='Valid')

    for step, (images, masks) in pbar:
        images  = images.to(device, dtype=torch.float)
        masks   = masks.to(device, dtype=torch.float)

        batch_size = images.size(0)

        y_pred  = model(images)
        loss    = criterion(y_pred, masks)

        running_loss += loss.item() * batch_size
        dataset_size += batch_size
        epoch_loss = running_loss / dataset_size

        probs = torch.sigmoid(y_pred)
        # Give the masks the predictions' rank so each sample is scored against its own
        # prediction (a (B, H, W) vs (B, 1, H, W) pair would otherwise broadcast).
        target = masks.unsqueeze(1) if masks.dim() == probs.dim() - 1 else masks
        val_dice = dice_coef(target, probs).cpu().detach().numpy()
        val_jaccard = iou_coef(target, probs).cpu().detach().numpy()
        val_scores.append([val_dice, val_jaccard])

        mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0
        pbar.set_postfix(valid_loss=f'{epoch_loss:0.4f}',
                         gpu_memory=f'{mem:0.2f} GB')
    val_scores  = np.mean(val_scores, axis=0)
    torch.cuda.empty_cache()
    gc.collect()

    return epoch_loss, val_scores

# 🏃 Run Training

In [None]:
def run_training(model, optimizer, scheduler, device, num_epochs, fold=0):
    """Train for ``num_epochs`` epochs and keep the weights with the best validation Dice.

    Uses the module-level ``train_loader`` / ``valid_loader``. The latest weights are
    written to the Drive path below every epoch. Whenever validation Dice ties or
    beats the best so far, the weights are also saved to ``best_epoch-{fold:02d}.bin``
    in the working directory.

    Args:
        model, optimizer, scheduler: passed through to the train/valid loops.
        device: device used for training and validation.
        num_epochs: number of epochs to run.
        fold: index used in the best-checkpoint file name (default 0 ->
            ``best_epoch-00.bin``).

    Returns:
        ``(model, history)``: the model with the best weights reloaded, and a dict of
        per-epoch lists ``'Train Loss'``, ``'Valid Loss'``, ``'Valid Dice'``,
        ``'Valid Jaccard'``.
    """
    if torch.cuda.is_available():
        print("cuda: {}\n".format(torch.cuda.get_device_name()))

    start = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    # Dice is higher-is-better, so start below every achievable score.
    best_dice      = -np.inf
    best_jaccard   = -np.inf
    best_epoch     = -1
    history = defaultdict(list)

    for epoch in range(1, num_epochs + 1):
        gc.collect()
        print(f'Epoch {epoch}/{num_epochs}', end='')
        train_loss = train_one_epoch(model, optimizer, scheduler,
                                     dataloader=train_loader,
                                     device=device, epoch=epoch)

        val_loss, val_scores = valid_one_epoch(model, valid_loader,
                                               device=device,
                                               epoch=epoch)
        val_dice, val_jaccard = val_scores

        history['Train Loss'].append(train_loss)
        history['Valid Loss'].append(val_loss)
        history['Valid Dice'].append(val_dice)
        history['Valid Jaccard'].append(val_jaccard)

        print(f'Valid Dice: {val_dice:0.4f} | Valid Jaccard: {val_jaccard:0.4f}')

        if val_dice >= best_dice:
            print(f"{c_}Valid Score Improved ({best_dice:0.4f} ---> {val_dice:0.4f})")
            best_dice    = val_dice
            best_jaccard = val_jaccard
            best_epoch   = epoch
            best_model_wts = copy.deepcopy(model.state_dict())
            PATH = f"best_epoch-{fold:02d}.bin"
            torch.save(model.state_dict(), PATH)
            print(f"Model Saved{sr_}")

        # Latest weights, overwritten every epoch.
        PATH = '/content/drive/MyDrive/Final_Project/blood-vessel-segmentation/model_weights.pth'
        torch.save(model.state_dict(), PATH)

        print(); print()

    end = time.time()
    time_elapsed = end - start
    print('Training complete in {:.0f}h {:.0f}m {:.0f}s'.format(
        time_elapsed // 3600, (time_elapsed % 3600) // 60, (time_elapsed % 3600) % 60))
    print("Best Score: {:.4f}".format(best_jaccard))
    print(f"Best Epoch: {best_epoch}")

    # Reload the best weights before returning.
    model.load_state_dict(best_model_wts)

    return model, history

# 🔍 Optimizer

In [None]:
def fetch_scheduler(optimizer):
    """Build the LR scheduler named by ``CFG.scheduler`` for ``optimizer``.

    Supported names: ``'CosineAnnealingLR'``, ``'CosineAnnealingWarmRestarts'``,
    ``'ReduceLROnPlateau'`` and ``'ExponentialLR'``. ``None`` means no scheduler.

    NOTE(review): train_one_epoch calls ``scheduler.step()`` without a metric, which
    ReduceLROnPlateau requires — confirm before selecting it.

    Raises:
        ValueError: for any other ``CFG.scheduler`` value.
    """
    name = CFG.scheduler
    if name is None:
        return None
    if name == 'CosineAnnealingLR':
        return lr_scheduler.CosineAnnealingLR(optimizer, T_max=CFG.T_max,
                                              eta_min=CFG.min_lr)
    if name == 'CosineAnnealingWarmRestarts':
        return lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=CFG.T_0,
                                                        eta_min=CFG.min_lr)
    if name == 'ReduceLROnPlateau':
        return lr_scheduler.ReduceLROnPlateau(optimizer,
                                              mode='min',
                                              factor=0.1,
                                              patience=7,
                                              threshold=0.0001,
                                              min_lr=CFG.min_lr)
    if name == 'ExponentialLR':
        return lr_scheduler.ExponentialLR(optimizer, gamma=0.85)
    raise ValueError(f"Unknown CFG.scheduler: {name!r}")

def select_optimizer(net=None):
    """Create the optimizer named by ``CFG.optimizers`` (lr=CFG.lr, weight_decay=CFG.wd).

    Args:
        net: model whose parameters are optimised. Defaults to the module-level
            ``model``, looked up at call time, which is what a zero-argument call uses.

    Returns:
        The optimizer, or None when ``CFG.optimizers`` is None.

    Raises:
        ValueError: for a name other than ``'adam'``, ``'nadam'``, ``'adamW'``, ``'sgd'``.
    """
    name = CFG.optimizers
    if name is None:
        return None
    if net is None:
        net = model
    factories = {
        'adam': optim.Adam,
        'nadam': optim.NAdam,
        'adamW': optim.AdamW,
        'sgd': optim.SGD,
    }
    if name not in factories:
        raise ValueError(f"Unknown CFG.optimizers: {name!r}")
    return factories[name](net.parameters(), lr=CFG.lr, weight_decay=CFG.wd)

In [None]:
model = build_model()
# select_optimizer() builds the optimizer named by CFG.optimizers over the global `model`.
optimizer = select_optimizer()
scheduler = fetch_scheduler(optimizer)

# 🚅 Training

In [None]:
# Train for CFG.epochs epochs; returns the model with its best-validation-Dice weights
# reloaded, plus the per-epoch history dict.
model, history = run_training(model, optimizer, scheduler,
                                device=CFG.device,
                                num_epochs=CFG.epochs)

# 🔭 Prediction

In [None]:
test_df = pd.DataFrame(data={'test':test_images})
# Without mask paths, BuildDataset yields (image, original [H, W]) pairs.
test_dataset = BuildDataset(test_df['test'].tolist(),
                            transforms=data_transforms['valid'])
test_loader  = DataLoader(test_dataset, batch_size=5,
                          num_workers=4, shuffle=False, pin_memory=True)
imgs, orig_sizes = next(iter(test_loader))
imgs = imgs.to(CFG.device, dtype=torch.float)

preds = []
for fold in range(1):
    # Same file-name pattern run_training uses for the best checkpoint.
    model = load_model(f"best_epoch-{fold:02d}.bin")
    with torch.no_grad():
        pred = model(imgs)
        pred = (nn.Sigmoid()(pred)>0.5).double()  # binarise probabilities at 0.5
    preds.append(pred)

imgs  = imgs.cpu().detach()
# Average the (binary) predictions across folds.
preds = torch.mean(torch.stack(preds, dim=0), dim=0).cpu().detach()

In [None]:
def show_img(img, mask=None):
    """Show a CLAHE contrast-enhanced slice in a new figure, optionally with a mask overlay.

    Args:
        img: uint8 image, 2-D or ``(H, W, C)``. OpenCV's CLAHE only accepts
            single-channel input, so a trailing singleton channel is dropped and a
            3-channel image is converted to grayscale first.
        mask: optional ``(H, W)`` or ``(H, W, 1)`` array; non-zero pixels are drawn
            in semi-transparent red.
    """
    if img.ndim == 3:
        if img.shape[-1] == 1:
            img = img[..., 0]
        else:
            img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    img = clahe.apply(np.ascontiguousarray(img))
    plt.figure(figsize=(10, 10))
    plt.imshow(img, cmap='bone')

    if mask is not None:
        mask = np.asarray(mask)
        if mask.ndim == 3 and mask.shape[-1] == 1:
            mask = mask[..., 0]
        # RGBA overlay: red at 50% opacity where the mask is set, transparent elsewhere.
        overlay = np.zeros(mask.shape + (4,), dtype=np.float32)
        overlay[mask > 0] = (1.0, 0.0, 0.0, 0.5)
        plt.imshow(overlay)
        handles = [Rectangle((0, 0), 1, 1, color=(1.0, 0.0, 0.0))]
        plt.legend(handles, ["Blood Vessel"])
    plt.axis('off')

In [None]:
def plot_batch(imgs, msks, size=3):
    """Show the first ``size`` samples of a batch with their masks, one figure each.

    Args:
        imgs: ``(B, C, H, W)`` tensor, presumably in [0, 1]. Only channel 0 is shown,
            since load_img replicates one gray channel into all three.
        msks: ``(B, 1, H, W)`` tensor of mask values in [0, 1].
        size: number of samples to show, capped at the batch size.

    ``show_img`` opens its own figure per call, so samples are drawn one figure each
    rather than into a shared subplot grid.
    """
    for idx in range(min(size, len(imgs))):
        img = (imgs[idx, 0].numpy() * 255.0).astype('uint8')  # single gray channel for CLAHE
        msk = msks[idx].permute((1, 2, 0)).numpy() * 255.0
        show_img(img, msk)
        plt.tight_layout()
        plt.show()

In [None]:
plot_batch(imgs, preds, size=5)  # the 5 test slices of the loaded batch with their 0.5-thresholded predictions