# Install Library

In [None]:
!pip install -q tqdm
!pip install -q tensorboard
!pip install -q segmentation-models-pytorch
!pip install -q torchinfo

In [None]:
from __future__ import division, print_function

import collections
import datetime
import logging
import os
import random

import numpy as np
import segmentation_models_pytorch as smp
import torch
import torchvision.transforms.functional as F  # TODO Fにするのは, transforms.functionalなのか、nn.functionalなのか、両方しないのか
from PIL import Image
from matplotlib import pyplot as plt
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
from torchinfo import summary
from torchvision import transforms
from torchvision.utils import draw_segmentation_masks
from tqdm import tqdm

# Logging Configuration

In [1]:
# Notebook-wide logger. `root_logger` is kept as an alias of `log`: both names
# always referred to the same `logging.getLogger(__name__)` object.
log = logging.getLogger(__name__)
root_logger = log

# Remove handlers left over from a previous run of this cell so that
# re-executing it does not duplicate every log line.
for handler in list(log.handlers):
    log.removeHandler(handler)

logfmt_str = "%(asctime)s %(levelname)-8s pid:%(process)d %(name)s:%(lineno)03d:%(funcName)s %(message)s"
formatter = logging.Formatter(logfmt_str)

streameHandler = logging.StreamHandler()
streameHandler.setFormatter(formatter)
streameHandler.setLevel(logging.DEBUG)
log.addHandler(streameHandler)

# Effective level of the notebook logger. It is set exactly once so the value
# in force is obvious (DEBUG lets `log.debug(...)` calls through).
log.setLevel(logging.DEBUG)

# Configuration

In [2]:
# Device selection: use CUDA when available, otherwise fall back to the CPU.
# Without a GPU the notebook switches itself into debug mode.
if torch.cuda.is_available():
    device = torch.device('cuda')
    DEBUG_MODE = None
    print(f"Using {device}; {torch.cuda.device_count()} devices.")
else:
    device = torch.device('cpu')
    DEBUG_MODE = True
    print(f"Using {device}")

# Model selection.
# MODEL_STR must be one of:
# "debug", "Init", "UNet", "UNet_Pad", "UNet_BN", "UNet_with_library", "MANet"
MODEL_STR = "debug" if DEBUG_MODE else "MANet"

# Whether samples are padded by the Resize32Multiple transform.
RESIZE = True

# Training hyper-parameters.
EPOCHS = 30
BATCH_SIZE = 8
NUM_WORKERS = 1

# CUDA memory-allocator setting (left disabled):
# !env PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:


Using cuda; 1 devices.


# Tensorboard Writer

In [3]:
def init_tensor_board_writers(train_writer=None, val_writer=None):
    """Return the ``(train_writer, val_writer)`` pair, creating any writer that is missing.

    Args:
        train_writer: Existing writer for training metrics, or ``None`` to create one.
        val_writer: Existing writer for validation metrics, or ``None`` to create one.

    Returns:
        tuple: ``(train_writer, val_writer)``. A tuple is always returned, so
        callers can unpack the result even when they pass writers in.

    Notes:
        New writers log under ``<parent of cwd>/working/log/<run name>/<timestamp>``
        with a ``_train_seg_`` / ``_val_seg_`` suffix appended to the timestamp.
        The run name is built from the globals ``model``, ``EPOCHS`` and
        ``BATCH_SIZE``, which are read only when a writer has to be created.
    """
    if train_writer is not None and val_writer is not None:
        return train_writer, val_writer

    model_name = type(model).__name__
    base_dir = os.path.dirname(os.path.abspath('.'))
    time_str = datetime.datetime.now().strftime('%Y_%m_%d_%H.%M.%S')
    run_name = '{model}_{epochs}epochs_{batch_size}batches'.format(model=model_name, epochs=EPOCHS, batch_size=BATCH_SIZE)
    log_dir = os.path.join(base_dir, 'working', 'log', run_name, time_str)

    if train_writer is None:
        train_writer = SummaryWriter(log_dir=log_dir + '_train_seg_')
    if val_writer is None:
        val_writer = SummaryWriter(log_dir=log_dir + '_val_seg_')
    return train_writer, val_writer

# データの読み込み

In [4]:
def get_subject_image_idx(path: str) -> tuple:
    """Parse ``(subject, image_index)`` from a file name such as ``12_34.tif`` or ``12_34_mask.tif``.

    Args:
        path: File path; only the base name without its extension is inspected.

    Returns:
        tuple: ``(subject, image_idx)`` as integers, taken from the first two
        ``_``-separated fields of the file name.

    Raises:
        ValueError: If the name has fewer than two fields or they are not integers.
    """
    stem = os.path.splitext(os.path.basename(path))[0]
    subject, image_idx = map(int, stem.split('_')[:2])
    return subject, image_idx


def get_data_path(train_dir: str = '../input/ultrasound-nerve-segmentation/train'):
    """List ``(image_path, mask_path)`` pairs found in ``train_dir``.

    Files ending in ``mask.tif`` are masks; every other ``.tif`` file is an
    input image. An image and a mask belong together when they share the same
    ``(subject, image_index)`` key, so a missing file cannot shift the pairing
    of the files that follow it. Files without a partner are skipped with a
    warning.

    Args:
        train_dir: Directory to scan. Defaults to the training directory used
            by this notebook.

    Returns:
        list: ``(image_path, mask_path)`` tuples sorted by ``(subject, image_index)``.

    Raises:
        ValueError: If a ``.tif`` file name cannot be parsed by
            ``get_subject_image_idx``.

    Notes:
        Assumes at most one image and one mask per ``(subject, image_index)``;
        if several files map to the same key, only one of them is kept.
    """
    images_by_key = {}
    masks_by_key = {}

    for filename in os.listdir(train_dir):
        if not filename.endswith(".tif"):
            continue
        full_path = os.path.join(train_dir, filename)
        bucket = masks_by_key if filename.endswith("mask.tif") else images_by_key
        bucket[get_subject_image_idx(full_path)] = full_path

    unpaired_keys = sorted(images_by_key.keys() ^ masks_by_key.keys())
    if unpaired_keys:
        logging.getLogger(__name__).warning(
            "Skipping %d file(s) without a matching image/mask, e.g. %s",
            len(unpaired_keys), unpaired_keys[:5],
        )

    paired_keys = sorted(images_by_key.keys() & masks_by_key.keys())
    data_paths = [(images_by_key[key], masks_by_key[key]) for key in paired_keys]

    return data_paths

## データの確認

In [5]:
def count_num_imgs_for_subject():
    """Count the (image, mask) pairs per subject.

    Returns:
        collections.Counter: Maps each subject number to the number of pairs
        returned by ``get_data_path()`` for that subject.
    """
    return collections.Counter(
        [get_subject_image_idx(img_path)[0] for img_path, _mask_path in get_data_path()]
    )

def show_sample_counter():
    """Log the per-subject image counts as a list of ``{'subject', 'Image Num'}`` records."""
    records = []
    for subject, image_count in count_num_imgs_for_subject().items():
        records.append({'subject': subject, 'Image Num': image_count})
    log.info(records)
    
def in_nerve(target_path):
    """Tell whether the mask at ``target_path`` marks any nerve pixel.

    Args:
        target_path: Path of the mask image.

    Returns:
        bool: ``True`` when the second value reported by ``getextrema()`` is
        greater than zero (for a single-band mask that value is the maximum
        pixel — assumes masks are single-band; TODO confirm).
    """
    # The context manager closes the file as soon as the extrema are known;
    # this helper may be called once per mask when filtering the dataset.
    with Image.open(target_path) as target:
        return target.getextrema()[1] > 0

# Log how many images each subject contributes.
show_sample_counter()

# Total number of (image, mask) pairs found by get_data_path().
num_data = len(get_data_path())
log.info('{} samples in total.'.format(num_data))

2023-02-16 06:34:52,505 INFO     pid:17999 __main__:028:<cell line: 28> 5635 samples in total.


    subject  Image Num
0         1        120
1         2        120
2         3        119
3         4        120
4         5        120
5         6        120
6         7        119
7         8        120
8         9        120
9        10        120
10       11        120
11       12        120
12       13        120
13       14        120
14       15        120
15       16        120
16       17        119
17       18        120
18       19        120
19       20        120
20       21        120
21       22        120
22       23        120
23       24        120
24       25        120
25       26        120
26       27        120
27       28        120
28       29        120
29       30        120
30       31        120
31       32        120
32       33        120
33       34        119
34       35        120
35       36        120
36       37        120
37       38        120
38       39        120
39       40        120
40       41        120
41       42        120
42       43

## Datasetの作成

In [7]:
class UltrasoundNerveDataset(Dataset):
    """Dataset of ultrasound images paired with their nerve masks.

    Each item is a dict ``{'img': PIL image, 'target': PIL image}``; when a
    ``transform`` is given, the item is whatever the transform returns for
    that dict.

    Args:
        is_val: Truthy to build the validation split (every ``val_stride``-th pair).
        val_stride: Stride of the split. With ``is_val`` falsy and a positive
            stride, every ``val_stride``-th pair is removed to form the
            training split; with ``0`` no split is made.
        only_nerve_imgs: Keep only pairs for which ``in_nerve(mask)`` is true.
        subject_img_idx: Keep only the pair whose ``(subject, image_index)``
            equals this tuple.
        is_random: Shuffle the pairs with a fixed seed before filtering.
        transform: Optional callable applied to each sample dict.
        data_num: Keep at most this many pairs (applied before the split).
    """

    def __init__(self, is_val: bool=None, val_stride: int=0, only_nerve_imgs: bool=False, subject_img_idx: tuple=None, is_random: bool=None, transform=None, data_num: int=None):
        pairs = get_data_path()

        if is_random:
            # Fixed seed: the shuffle, and therefore the train/validation split, is reproducible.
            random.Random(111).shuffle(pairs)

        if only_nerve_imgs:
            pairs = [(image_path, mask_path) for image_path, mask_path in pairs if in_nerve(mask_path)]

        if subject_img_idx:
            pairs = [pair for pair in pairs if get_subject_image_idx(pair[0]) == subject_img_idx]

        if data_num:
            pairs = pairs[:data_num]

        if is_val:
            assert val_stride > 0, val_stride
            # Validation split: positions 0, val_stride, 2 * val_stride, ...
            pairs = pairs[::val_stride]
            assert pairs
        elif val_stride > 0:
            # Training split: everything the validation split does not take.
            pairs = [pair for position, pair in enumerate(pairs) if position % val_stride != 0]
            assert pairs

        self.data_paths = pairs
        self.transform = transform
        log.info("{!r}: {} {} samples".format(self, len(self.data_paths), "validation" if is_val else "training"))

    def __len__(self):
        """Return the number of (image, mask) pairs in this split."""
        return len(self.data_paths)

    def __getitem__(self, idx):
        """Open the image/mask pair at ``idx`` and apply the transform, if any."""
        index = idx.tolist() if torch.is_tensor(idx) else idx

        image_path, mask_path = self.data_paths[index]
        sample = {'img': Image.open(image_path), 'target': Image.open(mask_path)}

        return self.transform(sample) if self.transform else sample

## 前処理

In [None]:
### テンソル化する（imageとtargetで処理を変える）
# - Image: Tensorにして、MaxMinScale
# - Target: そのままTensorにする

class PILToTensor:
    """Turn the PIL image/mask pair of a sample into tensors.

    The returned dict holds:
        ``'img'``: ``F.to_tensor`` of the image, used as model input.
        ``'img_vis'``: ``F.pil_to_tensor`` of the image, kept for visualisation.
        ``'target'``: the mask as a float64 tensor, floor-divided by 255
        (a 0/255 mask becomes 0/1 — assumes mask pixels are 0 or 255; TODO confirm).
    """

    def __call__(self, sample):
        img_pil, target = sample['img'], sample['target']
        img = F.to_tensor(img_pil)
        img_vis = F.pil_to_tensor(img_pil)
        target = torch.as_tensor(np.array(target), dtype=torch.float64)
        target = torch.div(target, 255, rounding_mode='floor')
        sample = {'img': img, 'target': target, 'img_vis': img_vis}
        return sample


### Normalization

# The statistics below are hard-coded so that the memory-hungry computation is
# not repeated on every run. They were obtained with:
#
# dataset_to_calculate_mean_std = UltrasoundNerveDataset(transform=PILToTensor())
# imgs = [sample['img'] for sample in dataset_to_calculate_mean_std]
# imgs_t = torch.stack(imgs, dim=3)
# MEAN = imgs_t.view(1, -1).mean(dim=1)
# STD = imgs_t.view(1, -1).std(dim=1)
# log.info(f"Mean: {MEAN}, Standard deviation: {STD}")

MEAN = 0.3898
STD = 0.2219


class Normalize:
    """Normalize ``sample['img']`` with ``F.normalize`` using the given mean and std.

    Args:
        mean: Per-channel means passed to ``F.normalize``.
        std: Per-channel standard deviations passed to ``F.normalize``.
    """

    def __init__(self, mean: list, std: list):
        self.mean = mean
        self.std = std

    def __call__(self, sample):
        # Only 'img' is replaced; the other entries of the sample are left as they are.
        sample['img'] = F.normalize(sample['img'], self.mean, self.std)
        return sample


### Resize

class Resize32Multiple:
    """Zero-pad image, mask and visualisation image by the same margin on every side.

    With the default margin of 14 a 420x580 sample becomes 448x608, i.e. both
    sides become multiples of 32.

    Args:
        padding: Number of zero pixels added on each of the four sides.
    """

    def __init__(self, padding: int = 14):
        self.pad = nn.ConstantPad2d(padding, 0)

    def __call__(self, sample):
        return {key: self.pad(sample[key]) for key in ('img', 'target', 'img_vis')}


def select_transform(resize: bool):
    """Build the preprocessing pipeline applied to every dataset sample.

    Args:
        resize: When true, ``Resize32Multiple`` is appended after
            tensor conversion and normalization.

    Returns:
        transforms.Compose: ``PILToTensor`` -> ``Normalize`` (-> ``Resize32Multiple``).
    """
    steps = [
        PILToTensor(),
        Normalize(mean=[MEAN], std=[STD]),
    ]
    if resize:
        steps.append(Resize32Multiple())
    return transforms.Compose(steps)


# Built last: select_transform() instantiates the classes and reads the
# constants defined above, so they must exist before this line runs.
TRANSFORM = select_transform(RESIZE)

# Data loaderの作成

In [15]:
def init_dataloader(is_val: bool=None, only_nerve_imgs: bool=None, transform=TRANSFORM, data_num=None):
    """Create the DataLoader for the training or the validation split.

    The dataset is shuffled with a fixed seed and split with a stride of 3:
    every third pair goes to validation, the rest to training.

    Args:
        is_val: Truthy for the validation split, falsy for the training split.
        only_nerve_imgs: Keep only pairs whose mask contains a nerve.
        transform: Callable applied to every sample. Defaults to ``TRANSFORM``.
        data_num: Optional cap on the number of pairs taken before the split.

    Returns:
        DataLoader: Shuffling loader (for both splits) whose batch size is
        ``BATCH_SIZE`` multiplied by the number of CUDA devices when CUDA is
        available.
    """
    batch_size = BATCH_SIZE
    if torch.cuda.is_available():
        batch_size *= torch.cuda.device_count()

    # Pass the caller's transform through; the module-level TRANSFORM is only the default.
    dataset = UltrasoundNerveDataset(
        is_val=is_val,
        val_stride=3,
        is_random=True,
        only_nerve_imgs=only_nerve_imgs,
        transform=transform,
        data_num=data_num,
    )

    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=torch.cuda.is_available(),
    )

# 損失関数定義

## 指標を定義

In [13]:
# Row indices of the per-sample metrics tensor; METRICS_SIZE is the number of rows.
(
    METRICS_DICE_LOSS_NDX,
    METRICS_DICE_COEFFICIENT_NDX,
    METRICS_TP_NDX,
    METRICS_FN_NDX,
    METRICS_FP_NDX,
    METRICS_SIZE,
) = range(6)

## 損失関数とバッチ損失

In [14]:
def calculate_dice_loss(prediction, target, epsilon=1):
    """Per-sample Dice loss, ``1 - (2|P∩T| + eps) / (|P| + |T| + eps)``.

    Args:
        prediction: 4-D tensor ``(N, C, H, W)`` of per-pixel positive scores.
        target: Tensor of the same shape holding the labels.
        epsilon: Smoothing term added to numerator and denominator. It keeps
            the ratio defined (and equal to 1) when both prediction and target
            are empty. Defaults to 1.

    Returns:
        Tensor of shape ``(N,)`` with one loss value per sample.
    """
    per_sample_dims = [1, 2, 3]

    dice_prediction = prediction.sum(dim=per_sample_dims)    # per sample: sum of the predicted positive scores
    dice_label = target.sum(dim=per_sample_dims)    # per sample: sum of the labels
    dice_correct = (target * prediction).sum(dim=per_sample_dims)    # per sample: overlap of the two

    dice_ratio = (2 * dice_correct + epsilon) / (dice_prediction + dice_label + epsilon)

    return 1 - dice_ratio


def compute_batch_loss(model, batch_ndx, batch_sample: dict, batch_size, metrics, classification_threshold=0.5):
    """Run the model on one batch, record per-sample metrics and return the mean Dice loss.

    Args:
        model: Callable mapping an ``N x 1 x H x W`` image batch to predictions
            of the same shape.
        batch_ndx: Index of the batch within the epoch; together with
            ``batch_size`` it selects the columns of ``metrics`` to fill.
        batch_sample: Dict with the keys ``'img'`` and ``'target'``.
        batch_size: Nominal batch size of the loader.
        metrics: ``(METRICS_SIZE, num_samples)`` tensor, written in place.
        classification_threshold: Predictions above this value count as positive.

    Returns:
        Scalar tensor: mean Dice loss of the batch (differentiable).
    """
    img, target = batch_sample['img'], batch_sample['target']

    img = img.to(device=device)
    target = target.view(img.shape)
    target = target.to(device=device)

    prediction = model(img)    # N x 1 x H x W

    dice_loss = calculate_dice_loss(prediction, target)

    # Columns of `metrics` that belong to the samples of this batch.
    start_ndx = batch_ndx * batch_size
    end_ndx = start_ndx + img.size(0)

    with torch.no_grad():
        prediction_bool = (prediction[:, 0:1] > classification_threshold).to(torch.float32)

        target = target.to(torch.uint8)

        dice_coefficient = 1 - calculate_dice_loss(prediction_bool, target)
        true_positive = (prediction_bool * target).sum(dim=[1, 2, 3])
        false_negative = ((1 - prediction_bool) * target).sum(dim=[1, 2, 3])
        # `1 - target` is the logical complement of a 0/1 mask. `~target` must
        # not be used here: on a uint8 tensor it is a bitwise NOT (0 -> 255, 1 -> 254).
        false_positive = (prediction_bool * (1 - target)).sum(dim=[1, 2, 3])

        metrics[METRICS_DICE_LOSS_NDX, start_ndx:end_ndx] = dice_loss
        metrics[METRICS_DICE_COEFFICIENT_NDX, start_ndx:end_ndx] = dice_coefficient
        metrics[METRICS_TP_NDX, start_ndx:end_ndx] = true_positive
        metrics[METRICS_FN_NDX, start_ndx:end_ndx] = false_negative
        metrics[METRICS_FP_NDX, start_ndx:end_ndx] = false_positive

    return dice_loss.mean()


## 損失関数のDebug

In [16]:
# Sanity check of the loss: the Dice loss of a mask against itself must be exactly 0.
debug_loss_loader = DataLoader(dataset=UltrasoundNerveDataset(is_val=True, val_stride=1000, only_nerve_imgs=False, transform=TRANSFORM), batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
# Spatial size expected from the preprocessing pipeline (padded or original).
expected_hw = (448, 608) if RESIZE else (420, 580)
for sample in debug_loss_loader:
    target = sample['target']
    assert tuple(target.shape[-2:]) == expected_hw, target.shape
    # Add the channel axis without hard-coding the batch size, so the check
    # works for any number of samples in the batch.
    target = target.unsqueeze(1)
    log.debug(target.shape)
    losses = calculate_dice_loss(target, target)
    assert losses.max().item() == 0.0
    assert losses.min().item() == 0.0

2023-02-16 06:34:52,624 INFO     pid:17999 __main__:028:__init__ <__main__.UltrasoundNerveDataset object at 0x7f018d2098e0>: 6 validation samples
2023-02-16 06:34:52,723 DEBUG    pid:17999 __main__:008:<cell line: 2> torch.Size([6, 1, 448, 608])


## Create Model

### Model for debug

In [22]:
class ModelForDebug(nn.Module):
    """Tiny two-convolution network used to smoke-test the training pipeline.

    Both convolutions use "same" zero padding, so the output keeps the spatial
    size of the input; the final sigmoid maps it to the range [0, 1].
    """

    def __init__(self):
        super(ModelForDebug, self).__init__()
        same_padded = dict(kernel_size=3, padding="same", padding_mode="zeros")
        self.conv1 = nn.Conv2d(1, 64, **same_padded)
        self.conv2 = nn.Conv2d(64, 1, **same_padded)
        self.batch_norm = nn.BatchNorm2d(64)

    def forward(self, x):
        """Map an ``N x 1 x H x W`` batch to ``N x 1 x H x W`` sigmoid outputs."""
        features = self.conv1(x)
        features = self.batch_norm(features)
        scores = self.conv2(features)
        return torch.sigmoid(scores)


### UNet with Batch Normalization

In [20]:
class UNetBatchNormConvBlock(nn.Module):
    """Two unpadded convolutions, each followed by batch normalization and an activation.

    Args:
        in_size: Number of input channels.
        out_size: Number of output channels of both convolutions.
        kernel_size: Kernel size of both convolutions. No padding is applied,
            so each convolution shrinks the spatial size by ``kernel_size - 1``.
        activation: Function applied after each batch normalization.
    """

    def __init__(self, in_size, out_size, kernel_size=3, activation=torch.relu):
        super().__init__()
        self.batch_norm = nn.BatchNorm2d(out_size)
        self.batch_norm2 = nn.BatchNorm2d(out_size)
        self.conv = nn.Conv2d(in_size, out_size, kernel_size)
        self.conv2 = nn.Conv2d(out_size, out_size, kernel_size)
        self.activation = activation

    def forward(self, x):
        """Apply conv -> batch norm -> activation twice."""
        stages = ((self.conv, self.batch_norm), (self.conv2, self.batch_norm2))
        out = x
        for conv, norm in stages:
            out = self.activation(norm(conv(out)))
        return out


class UNetBatchNormUpBlock(nn.Module):
    """Decoder step: upsample, concatenate the cropped skip connection, then two conv stages.

    Args:
        in_size: Channels of the incoming feature map; also the channel count
            after concatenating the upsampled map with the skip connection.
        out_size: Channels produced by the transposed convolution and by both
            convolutions.
        kernel_size: Kernel size of the two unpadded convolutions.
        activation: Function applied after each batch normalization.
        space_dropout: Accepted but not used by this block.
    """

    def __init__(self, in_size, out_size, kernel_size=3, activation=torch.relu, space_dropout=False):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_size, out_size, 2, stride=2)
        self.conv = nn.Conv2d(in_size, out_size, kernel_size)
        self.conv2 = nn.Conv2d(out_size, out_size, kernel_size)
        self.batch_norm = nn.BatchNorm2d(out_size)
        self.batch_norm2 = nn.BatchNorm2d(out_size)
        self.activation = activation

    def center_crop(self, layer: torch.Tensor, target_height, target_width):
        """Cut a ``target_height x target_width`` window out of the middle of a 4-D tensor."""
        _, _, layer_height, layer_width = layer.size()
        top = (layer_height - target_height) // 2
        left = (layer_width - target_width) // 2
        rows = slice(top, top + target_height)
        cols = slice(left, left + target_width)
        return layer[:, :, rows, cols]

    def forward(self, x, bridge):
        """Upsample ``x``, merge it with the centre of ``bridge`` and convolve twice."""
        upsampled = self.up(x)
        # The skip connection is larger than the upsampled map; crop it to match.
        skip = self.center_crop(bridge, upsampled.size(2), upsampled.size(3))
        merged = torch.cat([upsampled, skip], 1)
        merged = self.activation(self.batch_norm(self.conv(merged)))
        return self.activation(self.batch_norm2(self.conv2(merged)))


class UNetBatchNorm(nn.Module):
    """Three-level U-Net with batch normalization and unpadded convolutions.

    The unpadded convolutions make the raw output smaller than the input, so
    the sigmoid output is zero-padded back to the spatial size of the input.
    The padding is derived from the actual input size, which keeps the output
    aligned with the target for any input resolution the network accepts
    (for a 420x580 input this is 44 pixels on every side).

    Args:
        imsize: Stored on the instance; not used by the layers.
    """

    def __init__(self, imsize):
        super().__init__()
        self.imsize = imsize

        self.activation = torch.relu

        self.pool1 = nn.MaxPool2d(2)
        self.pool2 = nn.MaxPool2d(2)
        self.pool3 = nn.MaxPool2d(2)

        self.conv_block1_64 = UNetBatchNormConvBlock(1, 64)
        self.conv_block64_128 = UNetBatchNormConvBlock(64, 128)
        self.conv_block128_256 = UNetBatchNormConvBlock(128, 256)
        self.conv_block256_512 = UNetBatchNormConvBlock(256, 512)

        self.up_block512_256 = UNetBatchNormUpBlock(512, 256)
        self.up_block256_128 = UNetBatchNormUpBlock(256, 128)
        self.up_block128_64 = UNetBatchNormUpBlock(128, 64)

        self.last = nn.Conv2d(64, 1, 1)

    def forward(self, x):
        """Map an ``N x 1 x H x W`` batch to ``N x 1 x H x W`` sigmoid outputs."""
        # Encoder
        block1 = self.conv_block1_64(x)
        pool1 = self.pool1(block1)

        block2 = self.conv_block64_128(pool1)
        pool2 = self.pool2(block2)

        block3 = self.conv_block128_256(pool2)
        pool3 = self.pool3(block3)

        block4 = self.conv_block256_512(pool3)

        # Decoder with skip connections
        up1 = self.up_block512_256(block4, block3)
        up2 = self.up_block256_128(up1, block2)
        up3 = self.up_block128_64(up2, block1)

        probabilities = torch.sigmoid(self.last(up3))

        # Zero-pad symmetrically back to the input size; an odd difference
        # puts the extra pixel on the bottom/right.
        pad_h = x.size(2) - probabilities.size(2)
        pad_w = x.size(3) - probabilities.size(3)
        top, left = pad_h // 2, pad_w // 2
        return nn.functional.pad(probabilities, (left, pad_w - left, top, pad_h - top), mode='constant', value=0)

### Segmentation Models Pytorch

In [21]:
def unet_plus_plus():
    """Build a U-Net++ (ResNet-34 encoder, ImageNet weights) for 1-channel input and 1 output class.

    The model ends with a sigmoid, like the other library models in this
    notebook: the Dice loss and the 0.5 classification threshold used during
    training both expect outputs in [0, 1] rather than raw logits.
    """
    model = smp.UnetPlusPlus(
        encoder_name="resnet34",
        encoder_weights="imagenet",
        in_channels=1,
        classes=1,
        activation="sigmoid",
    )
    return model

def ma_net():
    """Build an ``smp.MAnet`` for 1-channel input and 1 output class with a sigmoid output."""
    options = {"in_channels": 1, "classes": 1, "activation": "sigmoid"}
    return smp.MAnet(**options)

def deep_lab_v3_plus(encoder_weights="imagenet"):
    """Build an ``smp.DeepLabV3Plus`` for 1-channel input and 1 output class with a sigmoid output.

    Args:
        encoder_weights: Passed through to ``smp.DeepLabV3Plus``.
    """
    return smp.DeepLabV3Plus(
        encoder_weights=encoder_weights,
        in_channels=1,
        classes=1,
        activation="sigmoid",
    )

## モデルを選択する

In [23]:
def select_model(model_str: str):
    """Instantiate the model named by ``model_str`` and move it to ``device``.

    Args:
        model_str: Case-insensitive model name. One of ``"debug"``, ``"init"``,
            ``"unet"``, ``"unet_pad"``, ``"unet_bn"``, ``"unet_with_library"``,
            ``"manet"`` or ``"deep_lab_v3_plus"``.

    Returns:
        The model on ``device``, wrapped in ``nn.DataParallel`` when more than
        one CUDA device is available.

    Raises:
        ValueError: If ``model_str`` is not a known model name.
    """
    # Factories are wrapped in lambdas so that a model class is looked up only
    # when it is actually selected.
    # NOTE(review): InitNeuralNetwork, UNet and UNetPadEach are not defined in
    # the part of the notebook visible here — confirm they exist before selecting them.
    factories = {
        "debug": lambda: ModelForDebug(),
        "init": lambda: InitNeuralNetwork(),
        "unet": lambda: UNet(imsize=420*580),
        "unet_pad": lambda: UNetPadEach(imsize=420*580),
        "unet_bn": lambda: UNetBatchNorm(imsize=420*580),
        "unet_with_library": lambda: unet_plus_plus(),
        "manet": lambda: ma_net(),
        "deep_lab_v3_plus": lambda: deep_lab_v3_plus(None),
    }

    key = model_str.lower()
    if key not in factories:
        raise ValueError(
            "Unknown model {!r}; expected one of {}".format(model_str, sorted(factories))
        )
    model = factories[key]()

    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    model = model.to(device=device)

    return model

# Build the model chosen by MODEL_STR; later cells use `model` as a global.
model = select_model(model_str=MODEL_STR)

## モデルを確認する

In [24]:
# Layer-by-layer summary for a batch of BATCH_SIZE single-channel 448x608 inputs.
summary(model=model, input_size=(BATCH_SIZE, 1, 448, 608))

Layer (type:depth-idx)                             Output Shape              Param #
MAnet                                              [8, 1, 448, 608]          --
├─ResNetEncoder: 1-1                               [8, 1, 448, 608]          --
│    └─Conv2d: 2-1                                 [8, 64, 224, 304]         3,136
│    └─BatchNorm2d: 2-2                            [8, 64, 224, 304]         128
│    └─ReLU: 2-3                                   [8, 64, 224, 304]         --
│    └─MaxPool2d: 2-4                              [8, 64, 112, 152]         --
│    └─Sequential: 2-5                             [8, 64, 112, 152]         --
│    │    └─BasicBlock: 3-1                        [8, 64, 112, 152]         73,984
│    │    └─BasicBlock: 3-2                        [8, 64, 112, 152]         73,984
│    │    └─BasicBlock: 3-3                        [8, 64, 112, 152]         73,984
│    └─Sequential: 2-6                             [8, 128, 56, 76]          --
│    │    └─BasicBl

# 訓練

In [26]:
# Adam with its default settings; do_train() reads `optimizer` as a global.
optimizer = torch.optim.Adam(model.parameters())
# Training and validation loaders over the full dataset (images without a nerve included).
train_loader= init_dataloader(is_val=False, only_nerve_imgs=False, transform=TRANSFORM)
val_loader = init_dataloader(is_val=True, only_nerve_imgs=False, transform=TRANSFORM)
# One TensorBoard writer per split.
train_writer, val_writer = init_tensor_board_writers()

2023-02-16 06:35:03,160 INFO     pid:17999 __main__:028:__init__ <__main__.UltrasoundNerveDataset object at 0x7f0184a8ad30>: 3756 training samples
2023-02-16 06:35:03,207 INFO     pid:17999 __main__:028:__init__ <__main__.UltrasoundNerveDataset object at 0x7f0184a8acd0>: 1879 validation samples


## Training function

In [27]:
def do_train(model, epoch_ndx, train_dataloader, total_training_samples_count, detect_anomaly=False):
    """Train ``model`` for one epoch using the global ``optimizer``.

    Args:
        model: Model to train; it is switched to training mode.
        epoch_ndx: Epoch number, shown in the progress bar.
        train_dataloader: Loader yielding the training batches.
        total_training_samples_count: Number of samples seen before this epoch.
        detect_anomaly: Enable autograd anomaly detection for this epoch.
            It is a debugging aid that slows training down, so it is off by
            default.

    Returns:
        tuple: ``(train_metrics, total_training_samples_count)`` where
        ``train_metrics`` is the ``(METRICS_SIZE, num_samples)`` metrics tensor
        on the CPU and the count has been increased by the dataset size.
    """
    model.train()
    # One column of metrics per training sample.
    train_metrics = torch.zeros(METRICS_SIZE, len(train_dataloader.dataset), device=device)

    # Used as a context manager so the previous anomaly-detection setting is
    # restored when the epoch ends.
    with torch.autograd.set_detect_anomaly(detect_anomaly):
        with tqdm(train_dataloader) as pbar:
            pbar.set_description("Training #{}".format(epoch_ndx))

            for batch_ndx, batch_samples in enumerate(pbar):
                batch_loss = compute_batch_loss(model, batch_ndx, batch_samples, train_dataloader.batch_size, train_metrics)
                pbar.set_postfix(loss=batch_loss.item())

                optimizer.zero_grad()
                batch_loss.backward()
                optimizer.step()

    total_training_samples_count += train_metrics.size(1)

    return train_metrics.cpu(), total_training_samples_count

def do_validation(model, epoch_ndx, val_dataloader):
    """Evaluate ``model`` on the validation set without tracking gradients.

    Args:
        model: Model to evaluate; it is switched to evaluation mode.
        epoch_ndx: Epoch number, shown in the progress bar.
        val_dataloader: Loader yielding the validation batches.

    Returns:
        The ``(METRICS_SIZE, num_samples)`` metrics tensor, moved to the CPU.
    """
    model.eval()
    num_val_samples = len(val_dataloader.dataset)
    nominal_batch_size = val_dataloader.batch_size

    with torch.inference_mode():
        val_metrics = torch.zeros(METRICS_SIZE, num_val_samples, device=device)

        with tqdm(val_dataloader) as progress:
            for batch_ndx, batch_samples in enumerate(progress):
                progress.set_description("Validation #{}".format(epoch_ndx))
                # Only the metrics side effect matters here; the returned loss is not needed.
                compute_batch_loss(model, batch_ndx, batch_samples, nominal_batch_size, val_metrics)

    return val_metrics.cpu()


def log_metrics(epoch_ndx, mode_str, metrics, tensorboard_writer, total_training_samples_count):
    """Aggregate the per-sample metrics of an epoch, log them and write them to TensorBoard.

    Args:
        epoch_ndx: Epoch number used in the log lines.
        mode_str: Label of the split (e.g. ``'Training'``), used in the log lines.
        metrics: CPU tensor of shape ``(METRICS_SIZE, num_samples)``.
        tensorboard_writer: Writer receiving one scalar per aggregated metric.
        total_training_samples_count: Used as the x-axis position of the scalars.

    Returns:
        The mean Dice coefficient over all samples, used as the score.

    Raises:
        AssertionError: If any metric value is NaN or infinite.
    """
    metrics_numpy = metrics.detach().numpy()
    assert np.isfinite(metrics_numpy).all(), "metrics contain NaN or infinite values"
    sum_metrics = metrics_numpy.sum(axis=1)

    true_positive = sum_metrics[METRICS_TP_NDX]
    false_negative = sum_metrics[METRICS_FN_NDX]
    false_positive = sum_metrics[METRICS_FP_NDX]
    all_positive_count = true_positive + false_negative

    # `x or 1` replaces a zero denominator by 1 so an empty class yields 0 instead of a division error.
    metrics_dict = {}
    metrics_dict['dice_loss'] = metrics_numpy[METRICS_DICE_LOSS_NDX].mean()
    metrics_dict['dice_coefficient'] = metrics_numpy[METRICS_DICE_COEFFICIENT_NDX].mean()
    # All three percentages are relative to the number of labelled-positive pixels.
    metrics_dict['percent_all/true_positive'] = true_positive / (all_positive_count or 1) * 100
    metrics_dict['percent_all/false_negative'] = false_negative / (all_positive_count or 1) * 100
    metrics_dict['percent_all/false_positive'] = false_positive / (all_positive_count or 1) * 100

    precision = metrics_dict['pr/precision'] = true_positive / ((true_positive + false_positive) or 1)
    recall = metrics_dict['pr/recall'] = true_positive / ((true_positive + false_negative) or 1)

    metrics_dict['pr/f1_score'] = 2 * (precision * recall) / ((precision + recall) or 1)

    log.info("E{} {:8} Dice loss: {dice_loss:.4f}, ".format(
        epoch_ndx, mode_str, **metrics_dict
    ))

    log.info("E{} {:8} Dice coefficient: {dice_coefficient:.4f}, ".format(
        epoch_ndx, mode_str, **metrics_dict
    ))

    for key, value in metrics_dict.items():
        tensorboard_writer.add_scalar(key, value, total_training_samples_count, new_style=True)
    # Flush once after all scalars have been queued.
    tensorboard_writer.flush()

    return metrics_dict['dice_coefficient']

def log_images(model, mode_str, dataloader, tensorboard_writer, epochs, total_training_samples_count):
    """Write prediction and label overlays for one batch to TensorBoard.

    The first batch drawn from ``dataloader`` is run through the model; the
    predicted mask (threshold 0.5) and the label mask are each drawn on top of
    the visualisation image and logged as an image pair per sample.

    Args:
        model: Model to run; it is left in evaluation mode.
        mode_str: Label of the split, used as the prefix of the image tags.
        dataloader: Loader whose samples provide ``'img'``, ``'target'`` and ``'img_vis'``.
        tensorboard_writer: Writer receiving the images.
        epochs: Epoch number embedded in the image tags.
        total_training_samples_count: Used as the global step of the images.
    """
    model.eval()

    sample_batch = next(iter(dataloader))
    imgs_model, targets = sample_batch['img'], sample_batch['target']

    # Image used for display
    imgs_vis = sample_batch['img_vis']

    # Label masks
    targets_masks = targets.to(torch.bool)

    # Model outputs; this is logging only, so no autograd graph is needed.
    with torch.no_grad():
        outputs = model(imgs_model.to(device=device))
        targets_for_loss = targets.unsqueeze(1).to(device=device)
        loss = calculate_dice_loss(outputs, targets_for_loss)

    # Predicted masks, moved to the CPU where the images to draw on live.
    outputs_masks = torch.squeeze(outputs >= 0.5, 1).cpu()

    # Build a 3-channel image: two zero channels followed by the grayscale image.
    rg = torch.zeros(imgs_vis.shape[0:1] + (2,) + imgs_vis.shape[2:4], dtype=torch.uint8)
    imgs_rgb = torch.cat((rg, imgs_vis), 1).to(dtype=torch.uint8)

    # Overlays
    outputs_with_masks = [draw_segmentation_masks(img_rgb, masks=output, alpha=.3, colors="#FFFFFF") for img_rgb, output in zip(imgs_rgb, outputs_masks)]
    targets_with_masks = [draw_segmentation_masks(img_rgb, masks=target, alpha=.3, colors="#FFFFFF") for img_rgb, target in zip(imgs_rgb, targets_masks)]

    for image_ndx, (output_mask, target_mask) in enumerate(zip(outputs_with_masks, targets_with_masks)):
        tensorboard_writer.add_image(f'{mode_str}/E{epochs}_#{image_ndx}_prediction_{loss[image_ndx].item():.3f}loss', output_mask, total_training_samples_count, dataformats='CHW')
        tensorboard_writer.add_image(f'{mode_str}/E{epochs}_#{image_ndx}_label', target_mask, total_training_samples_count, dataformats='CHW')
    # Flush once after all images have been queued.
    tensorboard_writer.flush()


## 実行

### Tensorboard利用

In [28]:
# model_name = type(model).__name__
# base_dir = os.path.dirname(os.path.abspath('.'))
# time_str = datetime.datetime.now().strftime('%Y_%m_%d_%H.%M.%S')
# log_dir = os.path.join(base_dir, 'working', 'log', '{model}_{epochs}epochs_{batch_size}batches'.format(model=model_name, epochs=EPOCHS, batch_size=BATCH_SIZE), time_str)

# %load_ext tensorboard
# %tensorboard --logdir log_dir

In [None]:
# Main training loop: train every epoch, validate and log images periodically.
log.info("Starting {}".format(type(model).__name__))

best_score = 0.0
# Validation runs on the first epoch, on every `validation_cadence`-th epoch and on the last epoch.
validation_cadence = 5
epochs = EPOCHS
# Running number of training samples seen; used as the TensorBoard step.
total_training_samples_count = 0
for epoch_ndx in (range(1, epochs + 1)):

    log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format(
        epoch_ndx,
        epochs,
        len(train_loader),
        len(val_loader),
        BATCH_SIZE,
        (torch.cuda.device_count() if torch.cuda.is_available() else 1),
    ))

    train_metrics, total_training_samples_count = do_train(model, epoch_ndx, train_loader, total_training_samples_count)
    log_metrics(epoch_ndx, 'Training', train_metrics, train_writer, total_training_samples_count)

    if epoch_ndx == 1 or epoch_ndx % validation_cadence == 0 or epoch_ndx == epochs:
        val_metrics = do_validation(model, epoch_ndx, val_loader)
        # Score is calculated with Dice Coefficient.
        score = log_metrics(epoch_ndx, 'Validation', val_metrics, val_writer, total_training_samples_count)
        best_score = max(score, best_score)

        # Log example overlays for both splits alongside the validation metrics.
        log_images(model, 'Training', train_loader, train_writer, epoch_ndx, total_training_samples_count)
        log_images(model, 'Validation', val_loader, val_writer, epoch_ndx, total_training_samples_count)

log.info('Best Score: {}'.format(best_score))
train_writer.close()
val_writer.close()

2023-02-16 06:35:03,243 INFO     pid:17999 __main__:001:<cell line: 1> Starting MAnet
2023-02-16 06:35:03,245 INFO     pid:17999 __main__:009:<cell line: 7> Epoch 1 of 30, 470/235 batches of size 8*1
Training #1: 100%|██████████| 470/470 [04:04<00:00,  1.92it/s, loss=0.251]  
2023-02-16 06:39:07,957 INFO     pid:17999 __main__:060:log_metrics E1 Training Dice loss: 0.7241, 
2023-02-16 06:39:07,958 INFO     pid:17999 __main__:065:log_metrics E1 Training Dice coefficient: 0.3147, 
Validation #1: 100%|██████████| 235/235 [00:33<00:00,  7.07it/s]
2023-02-16 06:39:41,254 INFO     pid:17999 __main__:060:log_metrics E1 Validation Dice loss: 0.4204, 
2023-02-16 06:39:41,255 INFO     pid:17999 __main__:065:log_metrics E1 Validation Dice coefficient: 0.5823, 
2023-02-16 06:39:46,041 INFO     pid:17999 __main__:009:<cell line: 7> Epoch 2 of 30, 470/235 batches of size 8*1
Training #2: 100%|██████████| 470/470 [04:16<00:00,  1.83it/s, loss=0.5]     
2023-02-16 06:44:02,300 INFO     pid:17999 __mai

### モデルの保存

In [None]:
# Save the trained model to <parent of cwd>/models/<ModelClass>.pth
base_dir = os.path.dirname(os.path.abspath('.'))
model_path = os.path.join(base_dir, 'models')
# Creates the directory if needed and is a no-op when it already exists.
os.makedirs(model_path, exist_ok=True)
model_name = type(model).__name__
# NOTE(review): this pickles the whole model object, so loading it requires
# the same class definitions to be available.
torch.save(model, os.path.join(model_path, '{}.pth'.format(model_name)))

# 訓練結果の出力

In [None]:
class Visualize:
    """Sample transform used when displaying results.

    Produces the normalized model input, the untouched image for display and
    the mask as an int64 tensor. When the global ``RESIZE`` is true, all three
    are zero-padded by 14 pixels on every side.

    Args:
        mean: Per-channel means passed to ``F.normalize``.
        std: Per-channel standard deviations passed to ``F.normalize``.
    """

    def __init__(self, mean: list, std: list):
        self.mean = mean
        self.std = std
        self.pad = nn.ConstantPad2d(14, 0)

    def __call__(self, sample):
        pil_img = sample['img']

        img_vis = F.pil_to_tensor(pil_img)
        img_for_model = F.normalize(F.to_tensor(pil_img), self.mean, self.std)
        target = torch.abs(torch.as_tensor(np.array(sample['target']), dtype=torch.int64))

        if RESIZE:
            img_for_model, target, img_vis = (
                self.pad(tensor) for tensor in (img_for_model, target, img_vis)
            )

        return {'img': img_for_model, 'target': target, 'img_vis': img_vis}
    

In [None]:
def show(imgs):
    """Display the given image tensors side by side in one row, without axis ticks."""
    fig, axes = plt.subplots(ncols=len(imgs), squeeze=False, figsize=(30, 30))
    for ax, img in zip(axes[0], imgs):
        pil_img = F.to_pil_image(img.detach())
        ax.imshow(np.asarray(pil_img))
        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
        

In [None]:
def show_masks(model, classification_threshold=0.5, num_show_imgs=5):
    """Show input images, predicted overlays and label overlays for a few nerve images.

    Args:
        model: Trained model; it is switched to evaluation mode.
        classification_threshold: Outputs at or above this value are drawn as mask.
        num_show_imgs: Upper bound on the number of images shown. The stride is
            derived from the size of the whole dataset while only images that
            contain a nerve are kept, so fewer images may be shown.
    """
    # Stride chosen so that at most `num_show_imgs` samples are picked from the full dataset.
    val_stride = len(get_data_path()) // num_show_imgs + 1
    with torch.no_grad():
        model.eval()
        torch.cuda.empty_cache()
        visualize_dataloader = DataLoader(UltrasoundNerveDataset(is_val=True, val_stride=val_stride, is_random=True, transform=Visualize(mean=[MEAN], std=[STD]), only_nerve_imgs=True), batch_size=num_show_imgs)
        for samples in visualize_dataloader:
            imgs_model, targets = samples['img'], samples['target']

            # Image used for display
            imgs_vis = samples['img_vis']

            # Label masks
            targets_masks = targets.to(torch.bool)

            # Model outputs
            outputs = model(imgs_model.to(device=device))
            log.info("Output shape: {}, Output min: {}, Output Max: {}".format(outputs.shape, outputs.min().item(), outputs.max().item()))

            # Predicted masks, moved to the CPU where the images to draw on live.
            outputs_masks = torch.squeeze(outputs >= classification_threshold, 1).cpu()

            # Build a 3-channel image: two zero channels followed by the grayscale image.
            rg = torch.zeros(imgs_vis.shape[0:1] + (2,) + imgs_vis.shape[2:4], dtype=torch.uint8)
            imgs_rgb = torch.cat((rg, imgs_vis), 1).to(dtype=torch.uint8)

            # Overlays
            outputs_with_masks = [draw_segmentation_masks(img_rgb, masks=output, alpha=.3, colors="#FFFFFF") for img_rgb, output in zip(imgs_rgb, outputs_masks)]
            targets_with_masks = [draw_segmentation_masks(img_rgb, masks=target, alpha=0.3, colors="#FFFFFF") for img_rgb, target in zip(imgs_rgb, targets_masks)]
            show(imgs_vis)
            show(outputs_with_masks)
            show(targets_with_masks)

In [None]:
# Display up to 10 nerve images with their predicted and labelled masks.
show_masks(model, num_show_imgs=10)