In [1]:
from __future__ import absolute_import, division, print_function, unicode_literals
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import os, shutil, glob, sys, math, cv2
import pydicom

import segmentation_models_pytorch as smp
import albumentations as albu
from matplotlib.colors import ListedColormap, LinearSegmentedColormap
import pickle

In [2]:
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from collections import OrderedDict
from torch.utils.model_zoo import load_url as load_state_dict_from_url
from torch import Tensor
from torch.jit.annotations import List
from torchsummary import summary
from torchvision import models

In [3]:
train_dicom_folder = '/work/tsung1271232/ncku-dataset/train/dicom'
train_label_folder = '/work/tsung1271232/train/label'
valid_dicom_folder = '/work/tsung1271232/ncku-dataset/test/dicom'
valid_label_folder = '/work/tsung1271232/ncku-dataset/test/label'
normal_folder = '/work/tsung1271232/ncku-dataset/normal'

In [4]:
train_dicom_fp = []
train_label_fp = []
valid_dicom_fp = []
valid_label_fp = []
for i in sorted(os.listdir(train_dicom_folder)):
    name, extension = os.path.splitext(i)
    train_dicom_fp.append(os.path.join(train_dicom_folder, i))
    if os.path.exists(os.path.join(train_label_folder, name, 'label.png')):
        train_label_fp.append(os.path.join(train_label_folder, name, 'label.png'))
    else:
        train_label_fp.append(None)
        
for i in sorted(os.listdir(valid_dicom_folder)):
    name, extension = os.path.splitext(i)
    valid_dicom_fp.append(os.path.join(valid_dicom_folder, i))
    if os.path.exists(os.path.join(valid_label_folder, name, 'label.png')):
        valid_label_fp.append(os.path.join(valid_label_folder, name, 'label.png'))
    else:
        valid_label_fp.append(None)
print(len(train_dicom_fp), len(train_label_fp), len(valid_dicom_fp), len(valid_label_fp) )

1656 1656 415 415


In [5]:
from sklearn.model_selection import train_test_split
normal = sorted(os.listdir(normal_folder))
train_normal, valid_normal = train_test_split(normal, test_size=0.33, random_state=42)
print(len(train_normal), len(valid_normal))
for i in train_normal:
    train_dicom_fp.append(os.path.join(normal_folder, i))
    train_label_fp.append(None)

for i in valid_normal:
    valid_dicom_fp.append(os.path.join(normal_folder, i))
    valid_label_fp.append(None)

1340 661


In [6]:
train_dicom_fp = np.array(train_dicom_fp)
train_label_fp = np.array(train_label_fp)
valid_dicom_fp = np.array(valid_dicom_fp)
valid_label_fp = np.array(valid_label_fp)
print(len(np.where(train_label_fp == None)[0]), len(np.where(valid_label_fp == None)[0]))
print(len(train_dicom_fp), len(train_label_fp), len(valid_dicom_fp), len(valid_label_fp) )

2996 729
2996 2996 1076 1076


# DataLoader

In [7]:
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as BaseDataset

def image_mask_preprocessing(image, mask, height = 512, width = 512, **kwargs):
    larger_side = max(image.shape[0], image.shape[1])
    
    aug = albu.Compose([
        albu.PadIfNeeded(min_height=larger_side, min_width=larger_side, always_apply=True, border_mode=0),
        albu.Resize(height=height, width=width , always_apply=True,)
    ])
    
    sample = aug(image=image, mask=mask)
    image, mask = sample['image'], sample['mask']
    
    # normalize
    if 'is_norm' in kwargs and kwargs['is_norm'] == False:
        pass
    else:
        image = (image - image.min()) / (image.max() - image.min()) * (255 - 0) + 0
        image = image.astype('uint8')
    
    mask = np.where(mask > 0, 1, 0)
    # convert to 3 channel
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    return image, mask

In [8]:
class Dataset(BaseDataset):
    def __init__(
            self, 
            images_fps, 
            masks_fps, 
            augmentation=None, 
            preprocessing=None,
            **kwargs,
    ):
        self.images_fps = images_fps
        self.masks_fps = masks_fps
        self.augmentation = augmentation
        self.preprocessing = preprocessing
        self.kwargs = kwargs
    
    def __getitem__(self, i):
        # dicom
        dcm = pydicom.dcmread(self.images_fps[i])
        image = dcm.pixel_array
        
        if(self.masks_fps[i] == None):
            mask = np.zeros_like(image)
        else:
            mask = cv2.imread(self.masks_fps[i], cv2.IMREAD_GRAYSCALE)
        
        image, mask = image_mask_preprocessing(image, mask, **self.kwargs)
        mask = np.expand_dims(mask, axis=-1).astype('float')
        # apply augmentations
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']
        # apply preprocessing
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']
        return image, mask
        
    def __len__(self):
        return len(self.images_fps)

# encoder/resnet

In [9]:
# class BasicBlock(nn.Module):
#     expansion = 1

#     def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
#                  base_width=64, dilation=1, norm_layer=None):
#         super(BasicBlock, self).__init__()
#         if norm_layer is None:
#             norm_layer = nn.BatchNorm2d
#         if groups != 1 or base_width != 64:
#             raise ValueError('BasicBlock only supports groups=1 and base_width=64')
#         if dilation > 1:
#             raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
#         # Both self.conv1 and self.downsample layers downsample the input when stride != 1
#         self.conv1 = conv3x3(inplanes, planes, stride)
#         self.bn1 = norm_layer(planes)
#         self.relu = nn.ReLU(inplace=True)
#         self.conv2 = conv3x3(planes, planes)
#         self.bn2 = norm_layer(planes)
#         self.downsample = downsample
#         self.stride = stride
        
#     def forward(self, x):
#         identity = x

#         out = self.conv1(x)
#         out = self.bn1(out)
#         out = self.relu(out)

#         out = self.conv2(out)
#         out = self.bn2(out)

#         if self.downsample is not None:
#             identity = self.downsample(x)

#         out += identity
#         out = self.relu(out)

#         return out

In [10]:
# import torch.nn as nn

# from torchvision.models.resnet import ResNet
# from torchvision.models.resnet import Bottleneck
# from pretrainedmodels.models.torchvision_models import pretrained_settings

# from segmentation_models_pytorch.encoders._base import EncoderMixin

# class ResNetEncoder(ResNet, EncoderMixin):
#     def __init__(self, out_channels, depth=5, **kwargs):
#         super().__init__(**kwargs)
#         self._depth = depth
#         self._out_channels = out_channels
#         self._in_channels = 3

#         del self.fc
#         del self.avgpool

#     def get_stages(self):
#         return [
#             nn.Identity(),
#             nn.Sequential(self.conv1, self.bn1, self.relu),
#             nn.Sequential(self.maxpool, self.layer1),
#             self.layer2,
#             self.layer3,
#             self.layer4,
#         ]

#     def forward(self, x):
#         stages = self.get_stages()

#         features = []
#         for i in range(self._depth + 1):
#             x = stages[i](x)
#             features.append(x)

#         return features

#     def load_state_dict(self, state_dict, **kwargs):
#         state_dict.pop("fc.bias")
#         state_dict.pop("fc.weight")
#         super().load_state_dict(state_dict, **kwargs)


# resnet_encoders = {
#     "resnet34": {
#         "encoder": ResNetEncoder,
#         "pretrained_settings": pretrained_settings["resnet34"],
#         "params": {
#             "out_channels": (3, 64, 64, 128, 256, 512),
#             "block": BasicBlock,
#             "layers": [3, 4, 6, 3],
#         },
#     },
# }

# PSP

In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from segmentation_models_pytorch.base import modules

class PSPBlock(nn.Module):
    def __init__(self, in_channels, out_channels, pool_size, use_bathcnorm=True):
        super().__init__()
        if pool_size == 1:
            use_bathcnorm = False  # PyTorch does not support BatchNorm for 1x1 shape
        self.pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=(pool_size, pool_size)),
            modules.Conv2dReLU(in_channels, out_channels, (1, 1), use_batchnorm=use_bathcnorm)
        )

    def forward(self, x):
        h, w = x.size(2), x.size(3)
        x = self.pool(x)
        x = F.interpolate(x, size=(h, w), mode='bilinear', align_corners=True)
        return x


class PSPModule(nn.Module):
    def __init__(self, in_channels, out_channels, sizes=(1, 2, 3, 6), use_bathcnorm=True):
        super().__init__()

        self.blocks = nn.ModuleList([
            PSPBlock(in_channels, in_channels // len(sizes), size, use_bathcnorm=use_bathcnorm) for size in sizes
        ])
        self.conv = nn.Conv2d(in_channels * 2, out_channels, kernel_size=1)
        
    def forward(self, x):
        xs = [block(x) for block in self.blocks] + [x]
        x = torch.cat(xs, dim=1)
        x = self.conv(x)
        return x

In [12]:
class FAMBlock(nn.Module):
    def __init__(self, in_channels, out_channels, pool_size, use_bathcnorm=True):
        super().__init__()
        if pool_size == 1:
            use_bathcnorm = False  # PyTorch does not support BatchNorm for 1x1 shape
        self.pool = nn.Sequential(
            nn.AvgPool2d(kernel_size=(pool_size, pool_size), stride=(pool_size, pool_size)),
            modules.Conv2dReLU(in_channels, out_channels, (1, 1), use_batchnorm=use_bathcnorm)
        )

    def forward(self, x):
        h, w = x.size(2), x.size(3)
        x = self.pool(x)
        x = F.interpolate(x, size=(h, w), mode='bilinear', align_corners=True)
        return x

class FAMModule(nn.Module):
    def __init__(self, in_channels, out_channel, sizes=(1, 2, 4, 8), use_bathcnorm=True):
        super().__init__()

        self.blocks = nn.ModuleList([
            FAMBlock(in_channels, in_channels // len(sizes), size, use_bathcnorm=use_bathcnorm) for size in sizes
        ])

        self.conv = modules.Conv2dReLU(in_channels // len(sizes), out_channel, (3, 3), 1)
        

    def forward(self, x):
        xs = [block(x) for block in self.blocks] + [x]
        x_add = torch.add(xs[0], xs[1])
        x_add = torch.add(x_add, xs[2])
        x_add = torch.add(x_add, xs[3])
        
        conv_out = self.conv(x_add)
        return x

# FPN

In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Conv3x3GNReLU(nn.Module):
    def __init__(self, in_channels, out_channels, upsample=False):
        super().__init__()
        self.upsample = upsample
        self.block = nn.Sequential(
            nn.Conv2d(
                in_channels, out_channels, (3, 3), stride=1, padding=1, bias=False
            ),
            nn.GroupNorm(32, out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        x = self.block(x)
        if self.upsample:
            x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True)
        return x


class FPNBlock(nn.Module):
    def __init__(self, pyramid_channels, skip_channels):
        super().__init__()
        self.skip_conv = nn.Conv2d(skip_channels, pyramid_channels, kernel_size=1)
        
    def forward(self, x, skip=None, ggf=None):
        x = F.interpolate(x, scale_factor=2, mode="nearest")
        skip = self.skip_conv(skip)
        x = x + skip + ggf
        return x


class SegmentationBlock(nn.Module):
    def __init__(self, in_channels, out_channels, n_upsamples=0):
        super().__init__()

        blocks = [Conv3x3GNReLU(in_channels, out_channels, upsample=bool(n_upsamples))]

        if n_upsamples > 1:
            for _ in range(1, n_upsamples):
                blocks.append(Conv3x3GNReLU(out_channels, out_channels, upsample=True))

        self.block = nn.Sequential(*blocks)

    def forward(self, x):
        return self.block(x)


class MergeBlock(nn.Module):
    def __init__(self, policy):
        super().__init__()
        if policy not in ["add", "cat"]:
            raise ValueError(
                "`merge_policy` must be one of: ['add', 'cat'], got {}".format(
                    policy
                )
            )
        self.policy = policy

    def forward(self, x):
        if self.policy == 'add':
            return sum(x)
        elif self.policy == 'cat':
            return torch.cat(x, dim=1)
        else:
            raise ValueError(
                "`merge_policy` must be one of: ['add', 'cat'], got {}".format(self.policy)
            )

In [14]:
class FPNDecoder(nn.Module):
    def __init__(
            self,
            encoder_channels,
            encoder_depth=5,
            pyramid_channels=256,
            segmentation_channels=128,
            dropout=0.2,
            merge_policy="add",
    ):
        super().__init__()

        self.out_channels = segmentation_channels if merge_policy == "add" else segmentation_channels * 4
        if encoder_depth < 3:
            raise ValueError("Encoder depth for FPN decoder cannot be less than 3, got {}.".format(encoder_depth))

        encoder_channels = encoder_channels[::-1]
        encoder_channels = encoder_channels[:encoder_depth + 1]

        self.p5 = nn.Conv2d(encoder_channels[0], pyramid_channels, kernel_size=1)
        self.p4 = FPNBlock(pyramid_channels, encoder_channels[1])
        self.p3 = FPNBlock(pyramid_channels, encoder_channels[2])
        self.p2 = FPNBlock(pyramid_channels, encoder_channels[3])

        self.seg_blocks = nn.ModuleList([
            SegmentationBlock(pyramid_channels, segmentation_channels, n_upsamples=n_upsamples)
            for n_upsamples in [3, 2, 1, 0]
        ])

        self.merge = MergeBlock(merge_policy)
        self.dropout = nn.Dropout2d(p=dropout, inplace=True)

        self.GGM = PSPModule(encoder_channels[0], pyramid_channels, sizes=(1,2,3,6))
        self.f4 = FAMModule(pyramid_channels, pyramid_channels)
        self.f3 = FAMModule(pyramid_channels, pyramid_channels)
        self.f2 = FAMModule(pyramid_channels, pyramid_channels)

    def forward(self, *features):
        c2, c3, c4, c5 = features[-4:]

        GGM = self.GGM(c5)

        p5 = self.p5(c5)

        GGF4 =  F.interpolate(GGM, scale_factor=2, mode="bilinear", align_corners=True)
        p4 = self.p4(p5, c4, GGF4)
        f4 = self.f4(p4)

        GGF3 =  F.interpolate(GGF4, scale_factor=2, mode="bilinear", align_corners=True)
        p3 = self.p3(f4, c3, GGF3)
        f3 = self.f3(p3)

        GGF2 =  F.interpolate(GGF3, scale_factor=2, mode="bilinear", align_corners=True)
        p2 = self.p2(p3, c2, GGF2)
        f2 = self.f3(p2)
        
        feature_pyramid = [seg_block(p) for seg_block, p in zip(self.seg_blocks, [p5, f4, f3, f2])]

        x = self.merge(feature_pyramid)
        x = self.dropout(x)
        return x

In [15]:
from typing import Optional, Union
from segmentation_models_pytorch.base.heads import SegmentationHead, ClassificationHead
from segmentation_models_pytorch.base.model import SegmentationModel
from segmentation_models_pytorch.encoders import get_encoder
import torch

class FPN(SegmentationModel):
    def __init__(self,
        encoder_name: str = "resnet34",
        in_channels: int = 3,
        encoder_depth: int = 5,
        encoder_weights: Optional[str] = "imagenet",
        decoder_pyramid_channels: int = 256,
        decoder_segmentation_channels: int = 128,
        decoder_merge_policy: str = "add",
        decoder_dropout: float = 0.2,
        classes: int = 1,
        activation: Optional[str] = None,
        upsampling: int = 4,
        aux_params: Optional[dict] = None
    ):
        super(FPN, self).__init__()

        self.encoder = get_encoder(
            encoder_name,
            in_channels=in_channels,
            depth=encoder_depth,
            weights=encoder_weights,
        )

        self.decoder = FPNDecoder(
            encoder_channels=self.encoder.out_channels,
            encoder_depth=encoder_depth,
            pyramid_channels=decoder_pyramid_channels,
            segmentation_channels=decoder_segmentation_channels,
            dropout=decoder_dropout,
            merge_policy=decoder_merge_policy,
        )

        self.segmentation_head = SegmentationHead(
            in_channels=self.decoder.out_channels,
            out_channels=classes,
            activation=activation,
            kernel_size=1,
            upsampling=upsampling,
        )

        if aux_params is not None:
            self.classification_head = ClassificationHead(
                in_channels=self.encoder.out_channels[-1], **aux_params
            )
        else:
            self.classification_head = None

        self.name = "fpn-{}".format(encoder_name)
        self.initialize()

# Epoch - cls loss & seg loss

In [16]:
from tqdm import tqdm as tqdm
from segmentation_models_pytorch.utils.meter import AverageValueMeter
class Epoch:
    def __init__(self, model, loss, metrics, stage_name, device='cpu', verbose=True):
        self.model = model
        self.loss = loss
        self.metrics = metrics
        self.stage_name = stage_name
        self.verbose = verbose
        self.device = device
        self._to_device()
    def _to_device(self):
        self.model.to(self.device)
        for loss in self.loss:
            loss.to(self.device)
        for metric in self.metrics:
            metric.to(self.device)
    def _format_logs(self, logs):
        str_logs = ['{} - {:.4}'.format(k, v) for k, v in logs.items()]
        s = ', '.join(str_logs)
        return s
    def batch_update(self, x, y, cls):
        raise NotImplementedError
    def on_epoch_start(self):
        pass
    def run(self, dataloader):
        self.on_epoch_start()
        logs = {}
        seg_loss_meter = AverageValueMeter()
        cls_loss_meter = AverageValueMeter()
        seg_metrics_meters = AverageValueMeter()
        cls_metrics_meters = AverageValueMeter()
        with tqdm(dataloader, desc=self.stage_name, file=sys.stdout, disable=not (self.verbose)) as iterator:
            for x, y, cls in iterator:
                x, y, cls = x.to(self.device), y.to(self.device), cls.to(self.device)
                seg_loss, cls_loss, seg_pred, cls_pred = self.batch_update(x, y, cls)
                # update loss logs
                seg_loss_value = seg_loss.cpu().detach().numpy()
                seg_loss_meter.add(seg_loss_value)
                cls_loss_value = cls_loss.cpu().detach().numpy()
                cls_loss_meter.add(cls_loss_value)
                loss_logs = {self.loss[0].__name__: seg_loss_meter.mean, 'cls_bce_loss': cls_loss_meter.mean}
                logs.update(loss_logs)
                # update metrics logs
                for metric_fn in self.metrics:
                    seg_metric_value = metric_fn(seg_pred, y).cpu().detach().numpy()
                    seg_metrics_meters.add(seg_metric_value)
                _, cls_pred = torch.max(cls_pred.data, 1)
                cls_metric_value = (cls_pred == cls).sum().float() / cls.shape[0]
                cls_metric_value = cls_metric_value.detach().cpu().numpy()
                cls_metrics_meters.add(cls_metric_value)
                seg_metrics_logs = {'dice' : seg_metrics_meters.mean}
                cls_metrics_logs = {'acc' : cls_metrics_meters.mean}
                logs.update(seg_metrics_logs)
                logs.update(cls_metrics_logs)
                if self.verbose:
                    s = self._format_logs(logs)
                    iterator.set_postfix_str(s)
        return logs
    
class TrainEpoch(Epoch):
    def __init__(self, model, loss, metrics, optimizer, device='cpu', verbose=True):
        super().__init__(
            model=model,
            loss=loss,
            metrics=metrics,
            stage_name='train',
            device=device,
            verbose=verbose,
        )
        self.optimizer = optimizer
    def on_epoch_start(self):
        self.model.train()
    def batch_update(self, x, y, cls):
        self.optimizer.zero_grad()
        seg_prediction, cls_prediction = self.model.forward(x)
        seg_loss = self.loss[0](seg_prediction, y)
        cls_loss = self.loss[1](cls_prediction, cls)
        loss = seg_loss + cls_loss
        loss.backward()
        self.optimizer.step()
        return seg_loss, cls_loss, seg_prediction, cls_prediction
    
class ValidEpoch(Epoch):
    def __init__(self, model, loss, metrics, device='cpu', verbose=True):
        super().__init__(
            model=model,
            loss=loss,
            metrics=metrics,
            stage_name='valid',
            device=device,
            verbose=verbose,
        )
    def on_epoch_start(self):
        self.model.eval()
    def batch_update(self, x, y, cls):
        with torch.no_grad():
            seg_prediction, cls_prediction = self.model.forward(x)
            seg_loss = self.loss[0](seg_prediction, y)
            cls_loss = self.loss[1](cls_prediction, cls)
        return seg_loss, cls_loss, seg_prediction, cls_prediction

# augmentation

In [17]:
def get_training_augmentation():
    train_transform = [

        albu.HorizontalFlip(p=0.5),

        albu.RandomCrop(height = 768, width = 768, always_apply = True),
        
        albu.ShiftScaleRotate(scale_limit=0, rotate_limit=10, shift_limit=0.1, p=1, border_mode=0),

        albu.IAAAdditiveGaussianNoise(p=0.2),
        albu.IAAPerspective(p=0.5),

        albu.OneOf(
            [
                albu.CLAHE(p=1),
                albu.RandomBrightness(p=1),
                albu.RandomGamma(p=1),
            ],
            p=0.9,
        ),

        albu.OneOf(
            [
                albu.IAASharpen(p=1),
                albu.Blur(blur_limit=3, p=1),
                albu.MotionBlur(blur_limit=3, p=1),
            ],
            p=0.9,
        ),

        albu.OneOf(
            [
                albu.RandomContrast(p=1),
                albu.HueSaturationValue(p=1),
            ],
            p=0.9,
        ),
    ]
    return albu.Compose(train_transform)


def to_tensor(x, **kwargs):
    return x.transpose(2, 0, 1).astype('float32')


def get_preprocessing():
    """Construct preprocessing transform
    
    Args:
        preprocessing_fn (callbale): data normalization function 
            (can be specific for each pretrained neural network)
    Return:
        transform: albumentations.Compose
    
    """
    
    _transform = [
#         albu.Lambda(image=preprocessing_fn),
        albu.Lambda(image=to_tensor, mask=to_tensor),
    ]
    return albu.Compose(_transform)

# model setting

In [18]:
ENCODER = 'resnet34'
ENCODER_WEIGHTS = 'imagenet'
CLASSES = ['penu']
ACTIVATION = None # could be None for logits or 'softmax2d' for multicalss segmentation
DEVICE = 'cuda'
kwargs = {'classes': 2}

# create segmentation model with pretrained encoder
model = FPN(
    encoder_name=ENCODER, 
    encoder_weights=ENCODER_WEIGHTS, 
    classes=len(CLASSES), 
    activation=ACTIVATION,
#     aux_params= kwargs
)
print(model)

FPN(
  (encoder): ResNetEncoder(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_st

In [19]:
pytorch_total_params = sum(p.numel() for p in model.parameters())
print(pytorch_total_params)

24322689


In [20]:
import torch.optim.lr_scheduler
epoch = 100
batch_size = 32

# model settings
loss = smp.utils.losses.BCEWithLogitsLoss()
# loss = [
#     smp.utils.losses.BCEWithLogitsLoss(),
#     smp.utils.losses.CrossEntropyLoss()
# ]
metrics = [
    smp.utils.metrics.IoU(),
]
optimizer = torch.optim.Adam([ 
    dict(params=model.parameters(), lr=1e-4),
])

lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer = optimizer, T_max=epoch)

use_cuda = torch.cuda.is_available()
DEVICE = torch.device("cuda:0" if use_cuda else "cpu")
# train_epoch = TrainEpoch(
#     model, 
#     loss=loss, 
#     metrics=metrics, 
#     optimizer=optimizer,
#     device=DEVICE,
#     verbose=True,
# )

# valid_epoch = ValidEpoch(
#     model, 
#     loss=loss, 
#     metrics=metrics, 
#     device=DEVICE,
#     verbose=True,
# )
train_epoch = smp.utils.train.TrainEpoch(
    model, 
    loss=loss, 
    metrics=metrics, 
    optimizer=optimizer,
    device=DEVICE,
    verbose=True,
)

valid_epoch = smp.utils.train.ValidEpoch(
    model, 
    loss=loss, 
    metrics=metrics, 
    device=DEVICE,
    verbose=True,
)

In [21]:
# data settings
train_dataset = Dataset(
    train_dicom_fp, 
    train_label_fp, 
    augmentation=get_training_augmentation(), 
    preprocessing=get_preprocessing(),
    height = 1024,
    width = 1024
)
valid_dataset = Dataset(
    valid_dicom_fp, 
    valid_label_fp, 
    preprocessing=get_preprocessing(),
    height = 1024,
    width = 1024
)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=20)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=20)


In [None]:
# import time
# import warnings
# warnings.filterwarnings('ignore')

# current_time = time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime())

# model_folder_name = '/home/tsung1271232/pneumothorax-segmentation/weight/'
# model_name = model_folder_name + str(current_time) + "_aug:Packagedefault_PoolNet_bs:{}_PoolNet".format(batch_size)

# max_score = 0

# for i in range(0, epoch):
#     print('\nEpoch: {}, batch: {}'.format(i, batch_size))
#     train_logs = train_epoch.run(train_loader)
#     valid_logs = valid_epoch.run(valid_loader)
#     lr_scheduler.step()
#     # do something (save model, change lr, etc.)
#     if max_score < valid_logs['iou_score']:
#         max_score = valid_logs['iou_score']
# #         torch.save(model, model_name+"_epoch:{}-fscore:{:.2f}.pth".format(i, max_score))
#         torch.save(model.state_dict(), model_name+"_epoch:{}-bce:{:.2f}.pth".format(i, max_score))
#         print('Model saved! {}'.format(model_name+"_epoch:{}-bce:{:.2f}.pth".format(i, max_score)))
        
# #     if i == 35:
# #         optimizer.param_groups[0]['lr'] = 5e-5
# #         print('Decrease decoder learning rate to 1e-5!')
        
# #     if i == 75:
# #         optimizer.param_groups[0]['lr'] = 5e-6
# #         print('Decrease decoder learning rate to 1e-6!')

In [None]:
# torch.save(model.state_dict(), model_name+"_epoch:{}-bce:{:.2f}.pth".format(99, 0.9089))

# testing

In [23]:
model = FPN(
    encoder_name=ENCODER, 
    encoder_weights=ENCODER_WEIGHTS, 
    classes=len(CLASSES), 
    activation=ACTIVATION,
#     aux_params= kwargs
)

model.load_state_dict(torch.load('/home/tsung1271232/pneumothorax-segmentation/weight/2020_06_16_12_33_47_aug:Packagedefault_PoolNet_bs:32_PoolNet_epoch:30-iou:0.88.pth'))

model.cuda()

FPN(
  (encoder): ResNetEncoder(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_st

In [24]:
torch.save(model, '/home/tsung1271232/Pneumothorax-Detection/stage1_seg.pth')

  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


In [None]:
valid_epoch = smp.utils.train.ValidEpoch(
    model, 
    loss=loss, 
    metrics=metrics, 
    device=DEVICE,
    verbose=True,
)

In [None]:
valid_logs = valid_epoch.run(valid_loader)

In [None]:
cmap_no_background = LinearSegmentedColormap.from_list("", ["none", "blue", 'cyan', 'green', 'orange', 'red'])
cmap_gt_background2 = LinearSegmentedColormap.from_list("", ["none", 'yellow'])

In [None]:
for batch_idx, (data, mask) in enumerate(valid_loader):
    pr_mask = model.predict(data.to('cuda'))
    act = nn.Sigmoid()
    pr_mask = act(pr_mask)
    pr_mask = pr_mask.cpu().numpy()
    pr_mask[pr_mask < 0.1] = None
    
    gt_mask = mask.numpy()    
    image = data.numpy().astype('uint8')

    plt.figure(figsize=(80, 40))
    for idx in range(len(gt_mask)):
        plt.subplot(4, 8, idx+1).set_title("{}-{}".format(batch_idx, idx))
        plt.imshow(image[idx].transpose(1,2,0).reshape(1024, 1024,3), cmap = "gray")
#         plt.imshow(gt_mask[idx].transpose(1,2,0).reshape(1024, 1024), alpha=0.5, cmap = 'jet')
        plt.imshow(pr_mask[idx].transpose(1,2,0).reshape(1024, 1024), alpha=0.5, cmap = 'jet')
        
    break

# external dataset

In [None]:
image_folder = '/work/tsung1271232/ncku-dataset/16 cases'
# image_folder = '/work/tsung1271232/ncku-dataset/0615testdata'

In [None]:
mask_dict = {}
for test_image in sorted(os.listdir(image_folder)):
    dcm = pydicom.dcmread(os.path.join(image_folder, test_image))
    image = dcm.pixel_array
    
    image, mask = image_mask_preprocessing(image, np.zeros_like(image), 768, 768)
    
    data = image.transpose(2, 0, 1).astype('float32')
    data = np.expand_dims(data, axis = 0)
    
    data = torch.from_numpy(data)
    pr_mask = model.predict(data.to('cuda'))
    act = nn.Sigmoid()
    pr_mask = act(pr_mask)
    pr_mask = pr_mask.cpu().numpy()
    
    plt.figure(figsize=(8, 8))
    plt.imshow(image.reshape(768, 768,3), cmap = "gray")
    plt.imshow(pr_mask.squeeze(), alpha=0.5, cmap= 'Reds', vmin=0, vmax=1)
    plt.colorbar()
    mask_dict[test_image] = pr_mask.squeeze()
    

In [None]:
with open('/work/tsung1271232/16_cases_masks.pickle', 'wb') as handle:
    pickle.dump(mask_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
import pickle
with open('/work/tsung1271232/0615testdata_masks.pickle', 'wb') as handle:
    pickle.dump(mask_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
train_folder = '/work/tsung1271232/siim/train/512_dicom/dicom'
for name in os.listdir(train_folder):
    image = cv2.imread(os.path.join(train_folder, name))

# kaggle

In [None]:
class test_Dataset(BaseDataset):
    def __init__(
            self, 
            images_fps, 
            masks_fps, 
            augmentation=None, 
            preprocessing=None,
            **kwargs,
    ):
        self.images_fps = images_fps
        self.masks_fps = masks_fps
        self.augmentation = augmentation
        self.preprocessing = preprocessing
        self.kwargs = kwargs
    
    def __getitem__(self, i):
        # dicom
        image = cv2.imread(self.images_fps[i], cv2.IMREAD_GRAYSCALE)
        
        if(self.masks_fps[i] == None):
            mask = np.zeros_like(image)
        else:
            mask = cv2.imread(self.masks_fps[i], cv2.IMREAD_GRAYSCALE)
        
        image, mask = image_mask_preprocessing(image, mask, **self.kwargs)
        mask = np.expand_dims(mask, axis=-1).astype('float')
        # apply augmentations
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']
        # apply preprocessing
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']
        return image, mask
        
    def __len__(self):
        return len(self.images_fps)

In [None]:
test_image_fp = []
test_mask_fp = []
for test_image in sorted(os.listdir('/work/tsung1271232/siim/train/images/1024/image')):
    if os.path.exists(os.path.join('/work/tsung1271232/siim/train/images/1024/mask', test_image)):
        test_image_fp.append(os.path.join('/work/tsung1271232/siim/train/images/1024/image', test_image))
        test_mask_fp.append(os.path.join('/work/tsung1271232/siim/train/images/1024/mask', test_image))

test_dataset = test_Dataset(
    test_image_fp,
    test_mask_fp, 
    preprocessing=get_preprocessing(),
    height = 1024,
    width = 1024,
    is_norm = False
)

test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=20)

test_epoch = smp.utils.train.ValidEpoch(
    model, 
    loss=loss, 
    metrics=metrics, 
    device=DEVICE,
    verbose=True,
)

In [None]:
batch_size = 8

# model settings
loss = smp.utils.losses.BCEWithLogitsLoss()

metrics = [
    smp.utils.metrics.IoU(),
]
optimizer = torch.optim.Adam([ 
    dict(params=model.parameters(), lr=1e-4),
])

use_cuda = torch.cuda.is_available()
DEVICE = torch.device("cuda:0" if use_cuda else "cpu")

In [None]:
test_epoch = smp.utils.train.ValidEpoch(
    model, 
    loss=loss, 
    metrics=metrics, 
    device=DEVICE,
    verbose=True,
)

In [None]:
# test_logs = test_epoch.run(test_loader)

In [None]:
for batch_idx, (data, mask) in enumerate(test_loader):
    if batch_idx == 5:
        pr_mask = model.predict(data.to('cuda'))
        act = nn.Sigmoid()
        pr_mask = act(pr_mask)
        pr_mask = pr_mask.cpu().numpy()
        pr_mask[pr_mask < 0.1] = None

        gt_mask = mask.numpy()    
        image = data.numpy().astype('uint8')

        plt.figure(figsize=(80, 40))
        for idx in range(len(gt_mask)):
            plt.subplot(2, 4, idx+1).set_title("{}-{}".format(batch_idx, idx))
            plt.imshow(image[idx].transpose(1,2,0).reshape(1024, 1024,3), cmap = "gray")
            plt.imshow(gt_mask[idx].transpose(1,2,0).reshape(1024, 1024), alpha=0.5, cmap=ListedColormap(['#ffffff00', 'y']))
            plt.imshow(pr_mask[idx].transpose(1,2,0).reshape(1024, 1024), alpha=0.5, cmap=ListedColormap(['#ffffff00', 'm']))
        break

In [None]:
mask = cv2.imread(test_mask_fp[0])

In [None]:
mask.min()

# TTA

In [None]:
import ttach as tta
transforms = tta.Compose(
    [
        tta.HorizontalFlip(),
        tta.Rotate90(angles=[0, 20]),
        tta.Scale(scales=[1, 2, 4]),
        tta.Multiply(factors=[0.9, 1, 1.1]),        
    ]
)

tta_model = tta.SegmentationTTAWrapper(model, transforms)
idx = 0
for test_image in sorted(os.listdir('/work/tsung1271232/ncku-dataset/16 cases')):
    idx += 1
    if(idx != 6):
        continue
    masks = []
    dcm = pydicom.dcmread(os.path.join('/work/tsung1271232/ncku-dataset/16 cases', test_image))
    image = dcm.pixel_array

    image, mask = image_mask_preprocessing(image, np.zeros_like(image), 1024, 1024)
    
    data = image.transpose(2, 0, 1).astype('float32')
    data = np.expand_dims(data, axis = 0)
    
    data = torch.from_numpy(data)
    
    for transformer in transforms: # custom transforms or e.g. tta.aliases.d4_transform() 

        # augment image
        augmented_image = transformer.augment_image(data)
        
        # pass to model
        model_output = model.predict(augmented_image.to('cuda'))
        
        act = nn.Sigmoid()
        model_output = act(model_output)
        
        # reverse augmentation for mask and label
        deaug_mask = transformer.deaugment_mask(model_output)
        
        # save results
        masks.append(deaug_mask)
    results = torch.cat(masks, dim=0)
    mask = torch.mean(results, dim=0)
    # reduce results as you want, e.g mean/max/min
    pr_mask = mask.cpu().numpy()
    
    plt.figure(figsize=(6,6))
    plt.title(test_image)
    plt.imshow(image.reshape(1024, 1024,3), cmap = "gray")
    plt.imshow(pr_mask.squeeze(), alpha=0.5, cmap= 'Reds', vmin=0, vmax=1)
    plt.colorbar()

# segmentation mask

In [None]:
# data settings
train_dataset = Dataset(
    train_dicom_fp, 
    train_label_fp, 
    preprocessing=get_preprocessing(),
    height = 1024,
    width = 1024
)

valid_dataset = Dataset(
    valid_dicom_fp, 
    valid_label_fp, 
    preprocessing=get_preprocessing(),
    height = 1024,
    width = 1024
)
output_folder = '/work/tsung1271232/ncku-dataset/init_seg'

In [None]:
for idx in range(len(train_dataset)):
    print(idx, train_dataset.images_fps[idx])
    fp = train_dataset.images_fps[idx]
    fp = os.path.split(fp)
    name, extension = os.path.splitext(fp[-1])
    
    image, mask = train_dataset[idx]
    
    data = np.expand_dims(image, axis = 0)
    data = torch.from_numpy(data)
    pr_mask = model.predict(data.to('cuda'))
    
    act = nn.Sigmoid()
    pr_mask = act(pr_mask)
    pr_mask = pr_mask.cpu().numpy()
    image = image.astype('uint8')
    
#     plt.figure(figsize=(6,6))
#     plt.subplot(1,2,1).set_title(name)
#     plt.imshow(image.transpose(1,2,0).reshape(1024, 1024,3), cmap = "gray")
#     plt.imshow(pr_mask.squeeze().reshape(1024, 1024), alpha=0.5, cmap = cmap_no_background)
#     plt.subplot(1,2,2)
#     plt.imshow(mask.squeeze().reshape(1024, 1024), alpha=0.5, cmap = cmap_gt_background2)
#     output_seg[name] = pr_mask.squeeze()
    np.save(os.path.join(output_folder, name+'.npy'), pr_mask.squeeze())

In [None]:
for idx in range(len(valid_dataset)):
    fp = valid_dataset.images_fps[idx]
    fp = os.path.split(fp)
    name, extension = os.path.splitext(fp[-1])
    
    image, mask = valid_dataset[idx]
    
    data = np.expand_dims(image, axis = 0)
    data = torch.from_numpy(data)
    pr_mask = model.predict(data.to('cuda'))
    
    act = nn.Sigmoid()
    pr_mask = act(pr_mask)
    pr_mask = pr_mask.cpu().numpy()
    image = image.astype('uint8')
    
#     plt.figure(figsize=(6,6))
#     plt.subplot(1,2,1).set_title(name)
#     plt.imshow(image.transpose(1,2,0).reshape(1024, 1024,3), cmap = "gray")
#     plt.imshow(pr_mask.squeeze().reshape(1024, 1024), alpha=0.5, cmap = cmap_no_background)
#     plt.subplot(1,2,2)
#     plt.imshow(mask.squeeze().reshape(1024, 1024), alpha=0.5, cmap = cmap_gt_background2)
    
    np.save(os.path.join(output_folder, name+'.npy'), pr_mask.squeeze())

# test for py

In [None]:
mask_dict = {}
for test_image in sorted(os.listdir('/work/tsung1271232/ncku-dataset/train/dicom')):
    dcm = pydicom.dcmread(os.path.join('/work/tsung1271232/ncku-dataset/train/dicom', test_image))
    image = dcm.pixel_array
    
    image, mask = image_mask_preprocessing(image, np.zeros_like(image), 512, 512)
    
    data = image.transpose(2, 0, 1).astype('float32')
    data = np.expand_dims(data, axis = 0)
    
    data = torch.from_numpy(data)
    pr_mask = model.predict(data.to('cuda'))
    act = nn.Sigmoid()
    pr_mask = act(pr_mask)
    pr_mask = pr_mask.cpu().numpy()
    
    plt.figure(figsize=(8, 6))
    plt.imshow(image.reshape(512, 512, 3), cmap = "gray")
    plt.axis('off')
    plt.savefig('./original.png', bbox_inches='tight') 
    
    plt.figure(figsize=(8, 6))
    plt.imshow(image.reshape(512, 512, 3), cmap = "gray")
    plt.imshow(pr_mask.squeeze(), alpha=0.5, cmap= cmap_no_background, vmin=0, vmax=1)
    plt.colorbar()
    plt.axis('off')
    plt.savefig('./predict_mask.png', bbox_inches='tight')
    break