# Install Library

In [1]:
# !pip install -q tqdm
!pip install -q tensorboard
!pip install -Uq segmentation-models-pytorch
!pip install -q pytorch-lightning
!pip install -q monai
!pip install -q timm
!pip install -q transformers
!pip install -q torchmetrics
!pip install -q kornia
!pip install -q opencv-python
# !pip install -Uq openmim
# !mim install -q mmcv-full
# !pip install -q mmsegmentation

In [74]:
from __future__ import division, print_function

import datetime
import os
import random

import kornia as K
import kornia.augmentation as KA
import numpy as np
import cv2
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.encoders import get_preprocessing_fn
import torch
import torchmetrics
import torchvision.transforms.functional as TF
from PIL import Image
from matplotlib import pyplot as plt
from torch import nn
from torch.utils.data import DataLoader, Dataset
import torchmetrics
from torchvision import transforms
from torchvision.utils import draw_segmentation_masks
from tqdm import tqdm
from transformers import ViTModel
import pdb
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, ModelSummary
from pytorch_lightning.loggers import TensorBoardLogger

# データの読み込み

In [75]:
def get_subject_image_idx(path: str) -> tuple:
    filename = os.path.splitext(os.path.basename(path))[0]
    subject, image_idx = map(int, filename.split('_')[:2])
    return subject, image_idx


def get_data_path():
    train_dir = '../input/ultrasound-nerve-segmentation/train'
    data_paths = []

    for filename in os.listdir(train_dir):
        if filename.endswith("_mask.tif"):
            continue
        elif filename.endswith(".tif"):
            mask_file_name = os.path.splitext(filename)[0] + '_mask.tif'
            input_img_path = os.path.join(train_dir, filename)
            mask_path = os.path.join(train_dir, mask_file_name)
            data_paths.append((input_img_path, mask_path))

    return data_paths

## データの確認

In [162]:
class UltrasoundNerveDataset(Dataset):
    """Ultrasound image and Nerve structure dataset return the PIL image."""

    def __init__(self, is_val: bool = None, val_stride: int = 0, transform=None, preprocess=None):

        self.data_paths = get_data_path()
        random.Random(111).shuffle(self.data_paths)

        if is_val:
            assert val_stride > 0, val_stride
            self.data_paths = self.data_paths[::val_stride]
            assert self.data_paths
        elif val_stride > 0:
            del self.data_paths[::val_stride]
            assert self.data_paths

        print("{} {} samples".format(len(self.data_paths), "validation" if is_val else "training"))

        self.transform = transform
        self.preprocess = preprocess

    def __len__(self):
        return len(self.data_paths)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        input_img_path, mask_path = self.data_paths[idx]
        img = cv2.imread(input_img_path, cv2.IMREAD_GRAYSCALE)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        
        if self.transform:
            img, mask = self.transform(img, mask)
            
        if self.preprocess:
            sample = self.preprocess(image=img, mask=mask)
            img, mask = sample['image'], sample['mask']

        try:
            assert img.shape == mask.shape
        except AssertionError:
            print("img.shape: {}, mask.shape: {}".format(img.shape, mask.shape))
            raise
        

        try:
            assert mask.max().item() <= 1.
        except AssertionError:
            print("mask.max().item(): {}".format(mask.max().item()))
            raise
            
        sample = {'img': img, 'mask': mask}
        return sample


# PyTorch Lightning Data Module

In [163]:
class UltrasoundDataModule(pl.LightningDataModule):
    def __init__(self, batch_size: int = 16, val_stride: int = 5, num_workers: int = None, transform=None, preprocess=None):
        super().__init__()
#         self.save_hyperparameters()
        self.batch_size = batch_size
        self.val_stride = val_stride
        self.transform = transform
        self.preprocess = preprocess

        if num_workers is None:
            self.num_workers = os.cpu_count()
        else:
            self.num_workers = num_workers

    def prepare_data(self) -> None:
        UltrasoundNerveDataset(is_val=False, val_stride=self.val_stride, transform=self.transform, preprocess=self.preprocess)
        UltrasoundNerveDataset(is_val=True, val_stride=self.val_stride, transform=self.transform, preprocess=self.preprocess)

    def setup(self, stage=None):
        self.train_dataset = UltrasoundNerveDataset(is_val=False, val_stride=self.val_stride, transform=self.transform, preprocess=self.preprocess)
        self.val_dataset = UltrasoundNerveDataset(is_val=True, val_stride=self.val_stride, transform=self.transform, preprocess=self.preprocess)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)


## 前処理

In [164]:
class ImageToTensor(nn.Module):
    """Transform module which convert image to tensor, normalize, and resize to 224, 224."""
    def __init__(self, imsize: tuple=(224, 224)):
        super().__init__()
        self.resize = KA.Resize(imsize, keepdim=True, p=1, resample='nearest')
    
    @torch.no_grad()
    def forward(self, x, y) -> torch.Tensor:

        x_out = K.image_to_tensor(x, keepdim=True)
        x_out = x_out.float()
        x_out = torch.div(x_out, 255.)
#         x_out = self.normalize(x_out)
        x_out = self.resize(x_out)
        x_out = torch.squeeze(x_out, 1)
        
        y_out = K.image_to_tensor(y, keepdim=True)
        y_out = y_out.float()
        y_out = torch.div(y_out, 255, rounding_mode='trunc')
        y_out = self.resize(y_out)
        
        return x_out, y_out

In [165]:
class Augmentation(nn.Module):
    def __init__(self):
        super().__init__()
        self.normalize = KA.Normalize(mean=(0.3890,), std=(0.2223,), p=1)
        
    def forward(self, x):
        out = self.normalize(x)
        return out

## Create Model

### Segmentation Models Pytorch

In [166]:
class UNet(nn.Module):
    def __init__(self, encoder_name='resnet34', encoder_weights='imagenet', classes=1, activation='sigmoid', augmentation=None):
        super().__init__()
        
        self.encoder = smp.Unet(encoder_name=encoder_name, encoder_weights=encoder_weights, in_channels=1, classes=classes, activation=activation)
        
        self.augmentation = augmentation

    def forward(self, x):
        if self.augmentation:
            x = self.augmentation(x)
        out = self.encoder(x)
        return out
        


### MONAI

In [167]:
# def vit():
#     model = nets.ViT(in_channels=1, img_size=420*580, patch_size=8)
#     return model

### Timm

In [168]:
# def swin():
#     return timm.create_model('swin_base_patch4_window7_448')

## モデルを選択する

In [169]:
def select_model(model_str: str):
    model = None
    model_str = model_str.lower()
    if model_str == "debug":
        model = ModelForDebug()
    elif model_str == 'unet':
        model = UNet(augmentation=Augmentation())
    elif model_str == 'unet_bn':
        model = UNetBatchNorm(imsize=420*580)
    elif model_str == 'unet_plus_plus':
        model = UNetPlusPlus(imsize=420*580)
    elif model_str == 'manet':
        model = MAnet(imsize=420*580)
    elif model_str == 'pspnet':
        model = PSPNet(imsize=420*580)
    elif model_str == 'linknet':
        model = LinkNet(imsize=420*580)
    elif model_str == 'deep_lab_v3_plus':
        model = DeepLabV3Plus(imsize=420*580)
    # TODO ここから下は未実装
    # elif model_str == 'vit':
    #     model = Vit()
    # elif model_str == 'swin':
    #     model = swin()

    return model

# 損失関数と指標定義

In [170]:
class DiceLoss():
    def __call__(self, prediction, target):
        loss = self.calculate_loss(prediction, target)
        return loss

    @staticmethod
    def calculate_loss(prediction, target):
        epsilon = 1

        dice_prediction = prediction.sum(dim=[1, 2, 3])    # 要素iには、バッチiの、各ピクセルについて陽性である確率の総和
        dice_label = target.sum(dim=[1, 2, 3])    # 要素iには、バッチiの陽性ラベルの数
        dice_correct = (target * prediction).sum(dim=[1, 2, 3])
        
        dice_ratio = (2 * dice_correct + epsilon) / (dice_prediction + dice_label + epsilon)

        return 1 - dice_ratio
    
    
class SegmentationMetrics():
    
    def __init__(self):
        self.f1_cal = torchmetrics.F1Score(task='binary')
        self.dice_coefficient

In [171]:
class DiceLoss():
    def __call__(self, prediction, target):
        loss = self.calculate_loss(prediction, target)
        return loss

    @staticmethod
    def calculate_loss(prediction, target):
        epsilon = 1

        dice_prediction = prediction.sum(dim=[1, 2, 3])    # 要素iには、バッチiの、各ピクセルについて陽性である確率の総和
        dice_label = target.sum(dim=[1, 2, 3])    # 要素iには、バッチiの陽性ラベルの数
        dice_correct = (target * prediction).sum(dim=[1, 2, 3])
        
        dice_ratio = (2 * dice_correct + epsilon) / (dice_prediction + dice_label + epsilon)

        return 1 - dice_ratio


# PyTorch Lightning Module

In [172]:
class LitSegmentationModule(pl.LightningModule):
    def __init__(self, model, lr=0.05):
        super().__init__()
        self.model = select_model(model_str=model)
        self.lr = lr
        self.save_hyperparameters()
        self.example_input_array = torch.Tensor(8, 1, 224, 224)
        self.loss_func = self.configure_loss_function()
        self.training_step_outputs = {}
        self.validation_step_outputs = {}
        
        self.f1_cal = torchmetrics.F1Score(task='binary')
        self.recall_cal = torchmetrics.Recall(task='binary')
        self.precision_cal = torchmetrics.Precision(task='binary')
        
    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        img, target = batch['img'], batch['mask']
        target = target.to(torch.uint8)
        prediction = self.model(img)
        print(f"Prediction Max: {prediction.max().item()}, Min: {prediction.min().item()}, Mean: {prediction.mean()}, Std: {prediction.std()}")
        loss = self.loss_func(prediction, target)
#         loss_mean = loss.mean()
#         self.log('train_loss', loss_mean, on_epoch=True)
        self.log('train_loss', loss, on_epoch=True)
        output_dict = self.draw_mask_on_image(batch, prediction, target)
        output_dict['loss'] = loss
        self.training_step_outputs = output_dict

        return loss
    
    def validation_step(self, batch, batch_idx):
        img, target = batch['img'], batch['mask']
        target = target.view(img.shape)
        prediction = self.model(img)
        loss = self.loss_func(prediction, target)
#         loss_mean = loss.mean()
        metrics = self.calculate_metrics(prediction, target)
#         metrics['val_loss'] = loss_mean
        metrics['val_loss'] = loss
        self.log_dict(metrics, on_epoch=True)
        output_dict = self.draw_mask_on_image(batch, prediction, target)
        output_dict['loss'] = loss
        self.validation_step_outputs = output_dict
#         return loss_mean
        return loss

    def validation_epoch_end(self, validation_step_outputs):
        if not self.training_step_outputs:
            return
        training_outputs = self.training_step_outputs
        validation_outputs = self.validation_step_outputs
        logger = self.logger.experiment
        
        epochs = self.current_epoch
        
        for mode_str, output in zip(['training', 'val'], [training_outputs, validation_outputs]):
            self.add_image(mode_str, output, epochs, logger)
    
    @torch.no_grad()
    def calculate_metrics(self, prediction, target):
        f1 = self.f1_cal(preds=prediction, target=target)
        precision = self.precision_cal(preds=prediction, target=target)
        recall = self.recall_cal(preds=prediction, target=target)
        
        metrics_dict = {
            'precision': precision,
            'recall': recall,
            'f1': f1
        }

        return metrics_dict
        
    @torch.no_grad()
    def draw_mask_on_image(self, batch, predictions, target):
        self.model.eval()

        # visualize image
        imgs_vis = batch['img'].cpu()
        imgs_vis *= 255
        imgs_vis = imgs_vis.to(torch.uint8)
        imgs_rgb = K.color.grayscale_to_rgb(imgs_vis)

        # target mask
        targets = torch.squeeze(target, 1)
        targets = targets.to(torch.bool).cpu()

        # output mask
        predictions = predictions >= 0.5
        predictions = torch.squeeze(predictions, 1).cpu()

        # mask
        output_with_mask = [draw_segmentation_masks(img_rgb, masks=output, alpha=.3, colors="#FFFFFF") for img_rgb, output in zip(imgs_rgb, predictions)]
        target_with_mask = [draw_segmentation_masks(img_rgb, masks=target, alpha=.3, colors="#FFFFFF") for img_rgb, target in zip(imgs_rgb, targets)]
        
        output_dict = {'output': output_with_mask, 'target': target_with_mask}
        return output_dict

    @staticmethod
    def add_image(mode_str, output, epochs, logger):
#         for image_ndx, (output_mask, target_mask, loss) in enumerate(zip(output['output'], output['target'], output['loss'])):
        for image_ndx, (output_mask, target_mask) in enumerate(zip(output['output'], output['target'])):
            loss = output['loss']
            logger.add_image(f'{mode_str}/E{epochs}_#{image_ndx}_prediction_{loss.item():.3f}loss', output_mask, dataformats='CHW')
            logger.add_image(f'{mode_str}/E{epochs}_#{image_ndx}_label', target_mask, dataformats='CHW')
            logger.flush()

    
    def configure_loss_function(self):
#         loss_func = DiceLoss()
        loss_func = smp.losses.DiceLoss('binary', from_logits=False)
        return loss_func

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

In [173]:
preds = torch.ones(8, 1, 224, 224)
target = torch.randint(2, (8, 1, 224, 224))
loss = smp.losses.DiceLoss('binary', from_logits=False, smooth=1, eps=0)
loss(preds, target)

tensor(0.3329)

# Execution

In [175]:
# root_dir
default_root_dir = os.path.join(os.path.dirname(os.path.abspath(os.getcwd())), 'working')
log_dir = os.path.join(default_root_dir, 'logs')

# Reproducibility
seed_everything = pl.seed_everything(42, workers=True)

# Dataloader
preprocessing_fn = get_preprocessing_fn("resnet34", pretrained="imagenet")
# Default: batch_size: int = 16, val_stride: int = 5, num_workers: int = None, transform=None, preprocess=None
data_module = UltrasoundDataModule(batch_size=32, val_stride=6, num_workers=5, transform=ImageToTensor(), preprocess=preprocessing_fn)

# Callbacks
early_stopping = EarlyStopping(monitor='val_loss', verbose=True, patience=3)
model_checkpoint = ModelCheckpoint(monitor='val_loss', verbose=True)
model_summary = ModelSummary(max_depth=5)

callbacks = [early_stopping, model_checkpoint, model_summary]

# Model
# model = "debug", "Init", "UNet", "UNet_Pad", "UNet_BN", "UNet_with_library", "MANet", "Vit", "Swin"
model = LitSegmentationModule(model="UNet")

# Logger
model_name = type(model.model).__name__
JST = datetime.timezone(datetime.timedelta(hours=+9), 'JST')
now_str = datetime.datetime.now(JST).strftime('%Y%m%d_%H.%M')
log_name = now_str + model_name + '_logs'
tensorboard_logger = TensorBoardLogger(save_dir=log_dir, name=log_name)

# Training

# TODO When submit the final results, change benchmark=False, and deterministic=True to ensure reproducibility.
trainer = pl.Trainer(fast_dev_run=False, devices='auto', accelerator='auto', default_root_dir=default_root_dir, logger=tensorboard_logger,  callbacks=callbacks, max_epochs=40, check_val_every_n_epoch=3, auto_scale_batch_size=False, auto_lr_find=False, benchmark=True, num_sanity_val_steps=0, enable_progress_bar=True)
trainer.tune(model, data_module)
trainer.fit(model, data_module)

Global seed set to 42
Trainer already configured with model summary callbacks: [<class 'pytorch_lightning.callbacks.model_summary.ModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: /home/ec2-user/SageMaker/working/logs/20230312_02.38UNet_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name                                         | Type             | Params | In sizes                                                                                                   | Out sizes                                                                                                 
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

4695 training samples
940 validation samples
4695 training samples
940 validation samples


Training: 0it [00:00, ?it/s]

TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/ec2-user/anaconda3/envs/pytorch_p39/lib/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 302, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/ec2-user/anaconda3/envs/pytorch_p39/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 58, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/ec2-user/anaconda3/envs/pytorch_p39/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 58, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/tmp/ipykernel_22329/3734063520.py", line 37, in __getitem__
    sample = self.preprocess(image=img, mask=mask)
TypeError: preprocess_input() missing 1 required positional argument: 'x'
