### Import Packages

In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import copy
import random
import torch.nn.functional as F
import time
import argparse
import datetime
import sys
from torch.utils.data import DataLoader
import math
from torch.optim import Adam
from torch.utils.tensorboard import SummaryWriter
import shutil
import hashlib
from torch import nn as nn
from torch.utils.data import Dataset
import os
from PIL import Image, ImageDraw

In [2]:
import sys
sys.path.append('./')
from dsetsFullCT import TrainingLuna2dSegmentationDataset, Luna2dSegmentationDataset, PrepcacheLunaDataset, getCt
from util import logging, enumerateWithEstimate
from UDet_3layer import UDet

# Model

In [3]:
class UNet(nn.Module):
    """2D U-Net style encoder/decoder built from UNetConvBlock and UNetUpBlock."""

    def __init__(self, in_channels=1, n_classes=2, depth=5, wf=6, padding=False,
                 batch_norm=False, up_mode='upconv'):
        """
        Implementation of
        U-Net: Convolutional Networks for Biomedical Image Segmentation
        (Ronneberger et al., 2015)
        https://arxiv.org/abs/1505.04597

        The default arguments follow the layout of the original paper
        (the first level has 2**6 = 64 channels). Note that forward()
        downsamples with average pooling.

        Args:
            in_channels (int): number of input channels
            n_classes (int): number of output channels
            depth (int): depth of the network
            wf (int): number of filters in the first layer is 2**wf
            padding (bool): if True, apply padding such that the input shape
                            is the same as the output.
                            This may introduce artifacts
            batch_norm (bool): Use BatchNorm after layers with an
                               activation function
            up_mode (str): one of 'upconv' or 'upsample'.
                           'upconv' will use transposed convolutions for
                           learned upsampling.
                           'upsample' will use bilinear upsampling.
        """
        super(UNet, self).__init__()
        assert up_mode in ('upconv', 'upsample')
        self.padding = padding
        self.depth = depth

        # The channel count doubles at every encoder level: 2**wf, 2**(wf+1), ...
        level_widths = [2 ** (wf + level) for level in range(depth)]

        channels = in_channels
        self.down_path = nn.ModuleList()
        for width in level_widths:
            self.down_path.append(
                UNetConvBlock(channels, width, padding, batch_norm))
            channels = width

        # The decoder mirrors the encoder, skipping the deepest level.
        self.up_path = nn.ModuleList()
        for width in reversed(level_widths[:-1]):
            self.up_path.append(
                UNetUpBlock(channels, width, up_mode, padding, batch_norm))
            channels = width

        self.last = nn.Conv2d(channels, n_classes, kernel_size=1)

    def forward(self, x):
        """Run the encoder, then the decoder with skip connections, then the 1x1 head."""
        skips = []
        deepest = len(self.down_path) - 1
        for level, encoder in enumerate(self.down_path):
            x = encoder(x)
            if level == deepest:
                continue
            # Keep the feature map as the bridge to the matching decoder level,
            # then halve the resolution (average pooling, kernel 2 / stride 2).
            skips.append(x)
            x = F.avg_pool2d(x, 2)

        # Decoder level k consumes the k-th most recent skip tensor.
        for decoder, skip in zip(self.up_path, reversed(skips)):
            x = decoder(x, skip)

        return self.last(x)

class UNetConvBlock(nn.Module):
    """Two consecutive 3x3 convolutions, each followed by ReLU and optional BatchNorm.

    Args:
        in_size (int): channels of the input tensor.
        out_size (int): channels produced by both convolutions.
        padding (bool): converted with int(); True pads by 1 so the spatial
            size is preserved, False shrinks each side by 2 per convolution.
        batch_norm (bool): append BatchNorm2d after each ReLU.
    """

    def __init__(self, in_size, out_size, padding, batch_norm):
        super(UNetConvBlock, self).__init__()

        layers = []
        # First stage maps in_size -> out_size, second stage out_size -> out_size.
        for stage_in in (in_size, out_size):
            layers.append(nn.Conv2d(stage_in, out_size, kernel_size=3,
                                    padding=int(padding)))
            layers.append(nn.ReLU())
            if batch_norm:
                layers.append(nn.BatchNorm2d(out_size))

        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the stacked conv/activation layers to x."""
        return self.block(x)

class UNetUpBlock(nn.Module):
    """Decoder stage: upsample, concatenate the cropped skip tensor, then convolve.

    Args:
        in_size (int): channels of the tensor coming from the deeper level.
        out_size (int): channels produced by the upsampling step and by the block.
        up_mode (str): 'upconv' (transposed convolution) or 'upsample'
            (bilinear upsampling followed by a 1x1 convolution).
        padding (bool): forwarded to UNetConvBlock.
        batch_norm (bool): forwarded to UNetConvBlock.

    Raises:
        ValueError: if up_mode is not one of the two supported values.
    """

    def __init__(self, in_size, out_size, up_mode, padding, batch_norm):
        super(UNetUpBlock, self).__init__()
        if up_mode == 'upconv':
            self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2,
                                         stride=2)
        elif up_mode == 'upsample':
            self.up = nn.Sequential(nn.Upsample(mode='bilinear', scale_factor=2),
                                    nn.Conv2d(in_size, out_size, kernel_size=1))
        else:
            # Fail at construction time instead of leaving self.up undefined
            # and crashing later inside forward().
            raise ValueError(
                "up_mode must be 'upconv' or 'upsample', got {!r}".format(up_mode))

        # The conv block receives the concatenation of the upsampled tensor and
        # the skip tensor; it is declared with in_size input channels, which
        # assumes the skip tensor contributes in_size - out_size channels.
        self.conv_block = UNetConvBlock(in_size, out_size, padding, batch_norm)

    def center_crop(self, layer, target_size):
        """Return the central (target_size[0], target_size[1]) window of an NCHW tensor."""
        _, _, layer_height, layer_width = layer.size()
        diff_y = (layer_height - target_size[0]) // 2
        diff_x = (layer_width - target_size[1]) // 2
        return layer[:, :, diff_y:(diff_y + target_size[0]), diff_x:(diff_x + target_size[1])]

    def forward(self, x, bridge):
        """Upsample x, crop bridge to the same spatial size, concatenate and convolve."""
        up = self.up(x)
        crop1 = self.center_crop(bridge, up.shape[2:])
        out = torch.cat([up, crop1], 1)
        out = self.conv_block(out)

        return out

In [4]:
class UNetWrapper(nn.Module):
    """Wrap UNet with an input BatchNorm2d in front and a Sigmoid behind.

    All keyword arguments are forwarded unchanged to UNet; 'in_channels'
    is additionally used to size the input batch normalisation.
    """

    def __init__(self, **kwargs):
        super().__init__()

        # Normalise the raw input before it reaches the network.
        self.input_batchnorm = nn.BatchNorm2d(kwargs['in_channels'])
        self.unet = UNet(**kwargs)
        # Squash the network output into the (0, 1) range.
        self.final = nn.Sigmoid()

        self._init_weights()

    def _init_weights(self):
        """Re-initialise the weights and biases of every conv / linear sub-module."""
        initialised_types = (
            nn.Conv2d,
            nn.Conv3d,
            nn.ConvTranspose2d,
            nn.ConvTranspose3d,
            nn.Linear,
        )
        for module in self.modules():
            # Exact type match on purpose: subclasses are left untouched.
            if type(module) not in initialised_types:
                continue

            nn.init.kaiming_normal_(
                module.weight.data, mode='fan_out', nonlinearity='relu', a=0
            )
            if module.bias is None:
                continue

            _, fan_out = nn.init._calculate_fan_in_and_fan_out(module.weight.data)
            bound = 1 / math.sqrt(fan_out)
            # NOTE(review): normal_ takes (mean, std), so this samples the bias
            # from N(-bound, bound) -- confirm this is the intended distribution.
            nn.init.normal_(module.bias, -bound, bound)

    def forward(self, input_batch):
        """Return sigmoid(unet(batchnorm(input_batch)))."""
        normalised = self.input_batchnorm(input_batch)
        raw_output = self.unet(normalised)
        return self.final(raw_output)
    
class UDetWrapper(nn.Module):
    """Wrap the imported UDet model with an input BatchNorm2d and an output Sigmoid.

    All keyword arguments are forwarded unchanged to UDet; 'in_channels'
    is additionally used to size the input batch normalisation.
    """

    def __init__(self, **kwargs):
        super().__init__()

        # Normalise the raw input before it reaches the network.
        self.input_batchnorm = nn.BatchNorm2d(kwargs['in_channels'])
        self.udet = UDet(**kwargs)
        # Squash the network output into the (0, 1) range.
        self.final = nn.Sigmoid()

        self._init_weights()

    def _init_weights(self):
        """Re-initialise the weights and biases of every conv / linear sub-module."""
        initialised_types = (
            nn.Conv2d,
            nn.Conv3d,
            nn.ConvTranspose2d,
            nn.ConvTranspose3d,
            nn.Linear,
        )
        # Exact type match on purpose: subclasses are left untouched.
        targets = (m for m in self.modules() if type(m) in initialised_types)
        for module in targets:
            nn.init.kaiming_normal_(
                module.weight.data, mode='fan_out', nonlinearity='relu', a=0
            )
            if module.bias is not None:
                _, fan_out = nn.init._calculate_fan_in_and_fan_out(
                    module.weight.data)
                bound = 1 / math.sqrt(fan_out)
                # NOTE(review): normal_ takes (mean, std), so this samples the
                # bias from N(-bound, bound) -- confirm this is intended.
                nn.init.normal_(module.bias, -bound, bound)

    def forward(self, input_batch):
        """Return sigmoid(udet(batchnorm(input_batch)))."""
        normalised = self.input_batchnorm(input_batch)
        raw_output = self.udet(normalised)
        return self.final(raw_output)

class SegmentationAugmentation(nn.Module):
    """Apply one random 2D affine transform (plus optional noise) to a batch.

    The same transform is applied to the input and to the label so they
    stay aligned.

    Args:
        flip: if truthy, each axis is mirrored with probability 0.5.
        offset: maximum translation per axis, in the normalised [-1, 1]
            grid coordinates used by F.affine_grid.
        scale: maximum relative scale change per axis.
        rotate: if truthy, rotate by a random angle in [0, 2*pi).
        noise: standard deviation multiplier of additive Gaussian noise
            applied to the input only.
    """

    def __init__(
            self, flip=None, offset=None, scale=None, rotate=None, noise=None
    ):
        super().__init__()

        self.flip = flip
        self.offset = offset
        self.scale = scale
        self.rotate = rotate
        self.noise = noise

    def forward(self, input_g, label_g):
        """Return (augmented_input, augmented_label_as_bool) for a batch.

        Args:
            input_g: floating point batch accepted by F.grid_sample.
            label_g: mask batch with the same spatial size as input_g.
        """
        transform_t = self._build2dTransformMatrix()
        # One shared transform, broadcast over the batch dimension.
        transform_t = transform_t.expand(input_g.shape[0], -1, -1)
        transform_t = transform_t.to(input_g.device, torch.float32)
        # affine_grid expects an (N, 2, 3) theta: the first two rows of the
        # homogeneous matrix, whose last COLUMN holds the translation.
        affine_t = F.affine_grid(transform_t[:, :2],
                                 input_g.size(), align_corners=False)

        augmented_input_g = F.grid_sample(input_g,
                                          affine_t, padding_mode='border',
                                          align_corners=False)
        # grid_sample only accepts floating point tensors, so the label is
        # cast to float and resampled with the very same grid.
        augmented_label_g = F.grid_sample(label_g.to(torch.float32),
                                          affine_t, padding_mode='border',
                                          align_corners=False)

        if self.noise:
            noise_t = torch.randn_like(augmented_input_g)
            noise_t *= self.noise

            augmented_input_g += noise_t

        # Threshold the interpolated label back to a boolean mask.
        return augmented_input_g, augmented_label_g > 0.5

    def _build2dTransformMatrix(self):
        """Build a random 3x3 homogeneous 2D transform from the enabled options."""
        transform_t = torch.eye(3)

        for i in range(2):  # two spatial axes
            if self.flip:
                if random.random() > 0.5:
                    transform_t[i, i] *= -1

            if self.offset:
                offset_float = self.offset
                random_float = (random.random() * 2 - 1)
                # Translation belongs in the last column. Writing it to the
                # bottom row (as before) made it vanish when forward() sliced
                # the matrix down to its first two rows.
                transform_t[i, 2] = offset_float * random_float

            if self.scale:
                scale_float = self.scale
                random_float = (random.random() * 2 - 1)
                transform_t[i, i] *= 1.0 + scale_float * random_float

        if self.rotate:
            angle_rad = random.random() * math.pi * 2
            s = math.sin(angle_rad)
            c = math.cos(angle_rad)

            rotation_t = torch.tensor([
                [c, -s, 0],
                [s, c, 0],
                [0, 0, 1]])

            # Rotate first, then apply the flip / scale / offset built above.
            transform_t = transform_t @ rotation_t

        return transform_t

# Prepcache

In [5]:
class LunaPrepCacheApp:
    """Command-line style app that iterates PrepcacheLunaDataset once to fill its cache."""

    def __init__(self, sys_argv=None):
        """Parse the command-line options.

        This used to be decorated with @classmethod, which bound `self` to
        the class itself and stored `cli_args` as a class attribute shared
        (and overwritten) by every instance. It is now a regular initialiser.

        Args:
            sys_argv: list of argument strings; defaults to sys.argv[1:].
        """
        if sys_argv is None:
            sys_argv = sys.argv[1:]

        parser = argparse.ArgumentParser()
        parser.add_argument('--batch-size',
            help='Batch size to use for training',
            default=1024,
            type=int,
        )
        parser.add_argument('--num-workers',
            help='Number of worker processes for background data loading',
            default=1,
            type=int,
        )

        self.cli_args = parser.parse_args(sys_argv)

    def main(self):
        """Walk the whole dataset through a DataLoader, discarding every batch.

        Relies on the module-level `log` object existing by the time this
        method is called.
        """
        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))

        self.prep_dl = DataLoader(
            PrepcacheLunaDataset(),
            batch_size=self.cli_args.batch_size,
            num_workers=self.cli_args.num_workers,
        )

        batch_iter = enumerateWithEstimate(
            self.prep_dl,
            "Stuffing cache",
            start_ndx=self.prep_dl.num_workers,
        )
        # Only the side effect of loading each sample matters here.
        for _batch_ndx, _batch_tup in batch_iter:
            pass

# Training

In [6]:
class DiceLoss(nn.Module):
    """Soft Dice loss computed over the whole batch from raw logits.

    forward() applies a sigmoid to `predict` itself, so it must be given
    logits, not probabilities.
    """

    def __init__(self):
        super(DiceLoss, self).__init__()
        # Smoothing term that keeps the ratio defined when both masks are empty.
        self.epsilon = 1e-5

    def forward(self, predict, target):
        """Return 1 - Dice coefficient as a scalar tensor.

        Args:
            predict: logits, same shape as target.
            target: ground-truth mask (0/1 values).

        Raises:
            AssertionError: if the two shapes differ.
        """
        assert predict.size() == target.size(), "the size of predict and target must be equal."
        num = predict.size(0)

        # reshape (not view) so non-contiguous inputs are accepted too.
        pre = torch.sigmoid(predict).reshape(num, -1)
        tar = target.reshape(num, -1).to(pre.dtype)

        # The element-wise product acts as the soft intersection.
        intersection = (pre * tar).sum(-1).sum()
        union = (pre + tar).sum(-1).sum()

        # Dice = (2*I + eps) / (U + eps). The previous form 2*(I + eps)/(U + eps)
        # evaluated to 2 for empty masks, i.e. a loss of -1 instead of 0.
        score = 1 - (2 * intersection + self.epsilon) / (union + self.epsilon)

        return score

In [7]:
# Module-level logger used by the app classes in this notebook.
log = logging.getLogger(__name__)
# log.setLevel(logging.WARN)
# log.setLevel(logging.INFO)
log.setLevel(logging.DEBUG)
# Row indices into the per-sample metrics tensor filled by
# SegmentationTrainingApp.computeBatchLoss (one column per sample).
METRICS_LOSS_NDX = 1
METRICS_TP_NDX = 7
METRICS_FN_NDX = 8
METRICS_FP_NDX = 9

# Number of rows allocated for the metrics tensor.
METRICS_SIZE = 10
class SegmentationTrainingApp:
    def __init__(self, sys_argv=None):
        """Parse the command-line options and build the model, augmentation and optimizer.

        Args:
            sys_argv: list of argument strings; defaults to sys.argv[1:].
        """
        if sys_argv is None:
            sys_argv = sys.argv[1:]

        parser = argparse.ArgumentParser()

        # Integer-valued options, registered in a fixed order.
        int_options = (
            ('--batch-size', 'Batch size to use for training', 16),
            ('--num-workers', 'Number of worker processes for background data loading', 8),
            ('--epochs', 'Number of epochs to train for', 1),
        )
        for flag, help_text, default_value in int_options:
            parser.add_argument(flag, help=help_text, default=default_value, type=int)

        # Boolean switches controlling data augmentation.
        switch_options = (
            ('--augmented',
             "Augment the training data."),
            ('--augment-flip',
             "Augment the training data by randomly flipping the data left-right, up-down, and front-back."),
            ('--augment-offset',
             "Augment the training data by randomly offsetting the data slightly along the X and Y axes."),
            ('--augment-scale',
             "Augment the training data by randomly increasing or decreasing the size of the candidate."),
            ('--augment-rotate',
             "Augment the training data by randomly rotating the data around the head-foot axis."),
            ('--augment-noise',
             "Augment the training data by randomly adding noise to the data."),
        )
        for flag, help_text in switch_options:
            parser.add_argument(flag, help=help_text, action='store_true', default=False)

        parser.add_argument(
            '--tb-prefix',
            default='udet',
            help="Data prefix to use for Tensorboard run. Defaults to chapter.",
        )
        parser.add_argument(
            'comment',
            help="Comment suffix for Tensorboard run.",
            nargs='?',
            default='none',
        )

        self.cli_args = parser.parse_args(sys_argv)
        self.time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H.%M.%S')
        self.totalTrainingSamples_count = 0
        self.trn_writer = None
        self.val_writer = None

        # Augmentation settings: (CLI attribute, SegmentationAugmentation
        # keyword, value). '--augmented' switches all of them on at once.
        augmentation_settings = (
            ('augment_flip', 'flip', True),
            ('augment_offset', 'offset', 0.03),
            ('augment_scale', 'scale', 0.2),
            ('augment_rotate', 'rotate', True),
            ('augment_noise', 'noise', 25.0),
        )
        self.augmentation_dict = {
            keyword: value
            for cli_attr, keyword, value in augmentation_settings
            if self.cli_args.augmented or getattr(self.cli_args, cli_attr)
        }

        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if self.use_cuda else "cpu")

        self.segmentation_model, self.augmentation_model = self.initModel()
        self.optimizer = self.initOptimizer()


    def initModel(self):
        """Create the segmentation model and the augmentation module.

        Both are moved to the CUDA device when one is available, and are
        wrapped in nn.DataParallel when more than one GPU is visible.

        Returns:
            (segmentation_model, augmentation_model)
        """
        model_kwargs = dict(
            in_channels=7,
            n_classes=1,
            depth=2,           # number of levels of the "U"
            wf=6,              # 2**6 = 64 filters in the first level
            padding=True,      # keep the output the same size as the input
            batch_norm=True,
            up_mode='upconv',  # upsample with nn.ConvTranspose2d
        )
        segmentation_model = UDetWrapper(**model_kwargs)
        augmentation_model = SegmentationAugmentation(**self.augmentation_dict)

        if not self.use_cuda:
            return segmentation_model, augmentation_model

        log.info("Using CUDA; {} devices.".format(torch.cuda.device_count()))
        if torch.cuda.device_count() > 1:
            # Split each batch across all visible GPUs.
            segmentation_model = nn.DataParallel(segmentation_model)
            augmentation_model = nn.DataParallel(augmentation_model)

        return (
            segmentation_model.to(self.device),
            augmentation_model.to(self.device),
        )

    def initOptimizer(self):
        return Adam(self.segmentation_model.parameters(), lr=0.001, betas=(0.99,0.999), weight_decay=1e-6)
        # return SGD(self.segmentation_model.parameters(), lr=0.001, momentum=0.99)


    def initTrainDl(self):
        """Build the DataLoader over the training split.

        The batch size from the command line is multiplied by the number of
        CUDA devices when CUDA is in use.
        """
        per_step_batch = self.cli_args.batch_size
        if self.use_cuda:
            per_step_batch *= torch.cuda.device_count()

        dataset = TrainingLuna2dSegmentationDataset(
            val_stride=10,
            isValSet_bool=False,
            contextSlices_count=3,
        )

        return DataLoader(
            dataset,
            batch_size=per_step_batch,
            num_workers=self.cli_args.num_workers,
            pin_memory=self.use_cuda,
        )

    def initValDl(self):
        """Build the DataLoader over the validation split.

        The batch size from the command line is multiplied by the number of
        CUDA devices when CUDA is in use.
        """
        per_step_batch = self.cli_args.batch_size
        if self.use_cuda:
            per_step_batch *= torch.cuda.device_count()

        dataset = Luna2dSegmentationDataset(
            val_stride=10,
            isValSet_bool=True,
            contextSlices_count=3,
        )

        return DataLoader(
            dataset,
            batch_size=per_step_batch,
            num_workers=self.cli_args.num_workers,
            pin_memory=self.use_cuda,
        )

    def initTensorboardWriters(self):
        """Create the training and validation SummaryWriters on first use."""
        if self.trn_writer is not None:
            # Writers already exist; nothing to do.
            return

        log_dir = os.path.join('runs', self.cli_args.tb_prefix, self.time_str)
        run_suffix = self.cli_args.comment

        self.trn_writer = SummaryWriter(log_dir=log_dir + '_trn_seg_' + run_suffix)
        self.val_writer = SummaryWriter(log_dir=log_dir + '_val_seg_' + run_suffix)

    def main(self):
        """Run the full training loop.

        Every epoch trains once and logs training metrics. On the first
        epoch and every `validation_cadence`-th epoch it also validates,
        saves a checkpoint (flagged as best when the validation score
        matches the best seen so far) and logs sample images.
        """
        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))

        train_dl = self.initTrainDl()
        val_dl = self.initValDl()

        best_score = 0.0
        self.validation_cadence = 5
        for epoch_ndx in range(1, self.cli_args.epochs + 1):
            log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format(
                epoch_ndx,
                self.cli_args.epochs,
                len(train_dl),
                len(val_dl),
                self.cli_args.batch_size,
                (torch.cuda.device_count() if self.use_cuda else 1),
            ))

            trnMetrics_t = self.doTraining(epoch_ndx, train_dl)
            self.logMetrics(epoch_ndx, 'trn', trnMetrics_t)

            if epoch_ndx == 1 or epoch_ndx % self.validation_cadence == 0:
                valMetrics_t = self.doValidation(epoch_ndx, val_dl)
                score = self.logMetrics(epoch_ndx, 'val', valMetrics_t)
                best_score = max(score, best_score)

                self.saveModel('seg', epoch_ndx, score == best_score)

                self.logImages(epoch_ndx, 'trn', train_dl)
                self.logImages(epoch_ndx, 'val', val_dl)

        # The writers are created lazily by logMetrics, so they are still None
        # when no epoch ran (e.g. --epochs 0); closing None used to crash here.
        for writer in (self.trn_writer, self.val_writer):
            if writer is not None:
                writer.close()

    def doTraining(self, epoch_ndx, train_dl):
        """Train for one epoch and return the per-sample metrics on the CPU.

        Args:
            epoch_ndx: epoch number, used only for the progress message.
            train_dl: DataLoader whose dataset provides shuffleSamples().

        Returns:
            Tensor with METRICS_SIZE rows and one column per dataset sample.
        """
        trnMetrics_g = torch.zeros(
            METRICS_SIZE, len(train_dl.dataset), device=self.device)
        self.segmentation_model.train()
        train_dl.dataset.shuffleSamples()

        progress_iter = enumerateWithEstimate(
            train_dl,
            "E{} Training".format(epoch_ndx),
            start_ndx=train_dl.num_workers,
        )
        for batch_ndx, batch_tup in progress_iter:
            self.optimizer.zero_grad()
            # computeBatchLoss also records this batch's metrics in trnMetrics_g.
            self.computeBatchLoss(
                batch_ndx, batch_tup, train_dl.batch_size, trnMetrics_g,
            ).backward()
            self.optimizer.step()

        self.totalTrainingSamples_count += trnMetrics_g.size(1)

        return trnMetrics_g.to('cpu')

    def doValidation(self, epoch_ndx, val_dl):
        """Evaluate the model on val_dl without gradients; return metrics on the CPU.

        Returns:
            Tensor with METRICS_SIZE rows and one column per dataset sample.
        """
        with torch.no_grad():
            valMetrics_g = torch.zeros(
                METRICS_SIZE, len(val_dl.dataset), device=self.device)
            self.segmentation_model.eval()

            progress_iter = enumerateWithEstimate(
                val_dl,
                "E{} Validation ".format(epoch_ndx),
                start_ndx=val_dl.num_workers,
            )
            for batch_ndx, batch_tup in progress_iter:
                # Only the metrics side effect is needed; the loss is discarded.
                self.computeBatchLoss(
                    batch_ndx, batch_tup, val_dl.batch_size, valMetrics_g)

        return valMetrics_g.to('cpu')

    def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_g,
                         classificationThreshold=0.5):
        """Run the model on one batch, record its metrics and return the loss.

        The segmentation model built by initModel ends with a Sigmoid, so
        its output is already a probability. The previous implementation
        fed it to BCEWithLogitsLoss, which applies a second sigmoid; the
        loss is now a positively-weighted binary cross entropy computed
        directly on probabilities.

        Args:
            batch_ndx: index of the batch within the epoch.
            batch_tup: (input, label, series list, slice index list).
            batch_size: nominal DataLoader batch size, used to locate this
                batch's columns in metrics_g.
            metrics_g: tensor updated in place with loss / TP / FN / FP rows.
            classificationThreshold: probability above which a pixel counts
                as predicted-positive.

        Returns:
            Scalar loss tensor (mean over all pixels of the batch).
        """
        input_t, label_t, series_list, _slice_ndx_list = batch_tup

        input_g = input_t.to(self.device, non_blocking=True)
        label_g = label_t.to(self.device, non_blocking=True)

        if self.segmentation_model.training and self.augmentation_dict:
            input_g, label_g = self.augmentation_model(input_g, label_g)

        prediction_g = self.segmentation_model(input_g)

        # Positive pixels are weighted 1000x. With weight = 1 + (pw - 1) * y
        # this matches the pos_weight semantics of BCEWithLogitsLoss, applied
        # to probabilities instead of logits.
        pos_weight = 1000.0
        target_g = label_g.to(torch.float)
        weight_g = 1.0 + (pos_weight - 1.0) * target_g
        bce_loss = F.binary_cross_entropy(prediction_g, target_g, weight=weight_g)

        # Columns of metrics_g that belong to this batch; the last batch may
        # be shorter than batch_size, hence input_t.size(0).
        start_ndx = batch_ndx * batch_size
        end_ndx = start_ndx + input_t.size(0)

        with torch.no_grad():
            predictionBool_g = (prediction_g[:, 0:1]
                                > classificationThreshold).to(torch.float32)

            # Per-sample pixel counts; `~label_g` requires a boolean label.
            tp = (     predictionBool_g *  label_g).sum(dim=[1,2,3])
            fn = ((1 - predictionBool_g) *  label_g).sum(dim=[1,2,3])
            fp = (     predictionBool_g * (~label_g)).sum(dim=[1,2,3])

            # The scalar batch loss is broadcast to every sample of the batch.
            metrics_g[METRICS_LOSS_NDX, start_ndx:end_ndx] = bce_loss
            metrics_g[METRICS_TP_NDX, start_ndx:end_ndx] = tp
            metrics_g[METRICS_FN_NDX, start_ndx:end_ndx] = fn
            metrics_g[METRICS_FP_NDX, start_ndx:end_ndx] = fp

        return bce_loss
        

    def diceLoss(self, prediction_g, label_g, epsilon=1): #如果大部分的pixel是false，用dice會比較精準
        diceLabel_g = label_g.sum(dim=[1,2,3])  #將所有mask裡計為nodule的點加起來(我們的dataset是4維, 後3維是index, row, column)
        dicePrediction_g = prediction_g.sum(dim=[1,2,3])
        diceCorrect_g = (prediction_g * label_g).sum(dim=[1,2,3]) #預測正確的總量

        diceRatio_g = (2 * diceCorrect_g + epsilon) \
            / (dicePrediction_g + diceLabel_g + epsilon)  #epsilon避免其值為0

        return 1 - diceRatio_g  #為了最小化，要用1去扣


    def logImages(self, epoch_ndx, mode_str, dl):
        """Write prediction overlays for a fixed sample of slices to TensorBoard.

        Six evenly spaced slices are rendered for each of the first 12
        series (sorted by UID). On the first epoch the ground-truth
        overlay is logged as well.

        Changes from the previous version: inference runs under
        torch.no_grad(), the writers are created on demand instead of
        assuming logMetrics ran first, the image size follows the slice
        instead of a hard-coded 512x512, and an unused alias was dropped.

        Args:
            epoch_ndx: current epoch number.
            mode_str: 'trn' or 'val'; selects the `<mode_str>_writer` attribute.
            dl: DataLoader whose dataset provides series_list,
                getitem_fullSlice() and contextSlices_count.
        """
        self.segmentation_model.eval()

        self.initTensorboardWriters()
        writer = getattr(self, mode_str + '_writer')

        images = sorted(dl.dataset.series_list)[:12]
        for series_ndx, series_uid in enumerate(images):
            ct = getCt(series_uid)

            for slice_ndx in range(6):
                # Spread the six slices evenly from the first to the last index.
                ct_ndx = slice_ndx * (ct.hu_a.shape[0] - 1) // 5
                sample_tup = dl.dataset.getitem_fullSlice(series_uid, ct_ndx)

                ct_t, label_t, series_uid, ct_ndx = sample_tup

                input_g = ct_t.to(self.device).unsqueeze(0)
                label_g = label_t.to(self.device).unsqueeze(0)

                # Pure inference: no autograd graph is needed.
                with torch.no_grad():
                    prediction_g = self.segmentation_model(input_g)[0]
                prediction_a = prediction_g.to('cpu').detach().numpy()[0] > 0.5
                label_a = label_g.cpu().numpy()[0][0] > 0.5

                # Rescale the slice values for display (done after inference).
                ct_t[:-1,:,:] /= 2000
                ct_t[:-1,:,:] += 0.5

                # The centre slice of the context window is the one shown.
                ctSlice_a = ct_t[dl.dataset.contextSlices_count].numpy()
                grey_a = ctSlice_a[:, :, np.newaxis]

                image_a = np.zeros(ctSlice_a.shape + (3,), dtype=np.float32)
                image_a[:,:,:] = grey_a
                # Red: false positives and false negatives.
                image_a[:,:,0] += prediction_a & (1 - label_a)
                image_a[:,:,0] += (1 - prediction_a) & label_a
                # Half green added to false negatives.
                image_a[:,:,1] += ((1 - prediction_a) & label_a) * 0.5

                # Green: true positives.
                image_a[:,:,1] += prediction_a & label_a
                image_a *= 0.5
                image_a.clip(0, 1, image_a)

                writer.add_image(
                    f'{mode_str}/{series_ndx}_prediction_{slice_ndx}',
                    image_a,
                    self.totalTrainingSamples_count,
                    dataformats='HWC',
                )

                if epoch_ndx == 1:
                    # The ground truth never changes, so log it only once.
                    image_a = np.zeros(ctSlice_a.shape + (3,), dtype=np.float32)
                    image_a[:,:,:] = grey_a
                    image_a[:,:,1] += label_a  # Green

                    image_a *= 0.5
                    image_a[image_a < 0] = 0
                    image_a[image_a > 1] = 1
                    writer.add_image(
                        '{}/{}_label_{}'.format(
                            mode_str,
                            series_ndx,
                            slice_ndx,
                        ),
                        image_a,
                        self.totalTrainingSamples_count,
                        dataformats='HWC',
                    )
                # This flush prevents TB from getting confused about which
                # data item belongs where.
                writer.flush()

    def logMetrics(self, epoch_ndx, mode_str, metrics_t):
        """Summarise one epoch's metrics, log them and write them to TensorBoard.

        Args:
            epoch_ndx: current epoch number.
            mode_str: 'trn' or 'val'; selects the `<mode_str>_writer` attribute.
            metrics_t: CPU tensor with METRICS_SIZE rows, one column per sample.

        Returns:
            The recall, which the caller uses as the model score.
        """
        log.info("E{} {}".format(epoch_ndx, type(self).__name__))

        metrics_a = metrics_t.detach().numpy()
        sum_a = metrics_a.sum(axis=1)
        assert np.isfinite(metrics_a).all()

        tp_total = sum_a[METRICS_TP_NDX]
        fn_total = sum_a[METRICS_FN_NDX]
        fp_total = sum_a[METRICS_FP_NDX]

        # `x or 1` replaces a zero denominator by 1 to avoid dividing by zero.
        allLabel_count = tp_total + fn_total
        precision = tp_total / ((tp_total + fp_total) or 1)
        recall = tp_total / ((tp_total + fn_total) or 1)
        f1_score = 2 * (precision * recall) / ((precision + recall) or 1)

        # Key order determines the order in which scalars are written below.
        metrics_dict = {
            'loss/all': metrics_a[METRICS_LOSS_NDX].mean(),
            'percent_all/tp': tp_total / (allLabel_count or 1) * 100,
            'percent_all/fn': fn_total / (allLabel_count or 1) * 100,
            'percent_all/fp': fp_total / (allLabel_count or 1) * 100,
            'pr/precision': precision,
            'pr/recall': recall,
            'pr/f1_score': f1_score,
        }

        pr_template = (
            "E{} {:8} "
            "{loss/all:.4f} loss, "
            "{pr/precision:.4f} precision, "
            "{pr/recall:.4f} recall, "
            "{pr/f1_score:.4f} f1 score"
        )
        log.info(pr_template.format(epoch_ndx, mode_str, **metrics_dict))

        percent_template = (
            "E{} {:8} "
            "{loss/all:.4f} loss, "
            "{percent_all/tp:-5.1f}% tp, {percent_all/fn:-5.1f}% fn, {percent_all/fp:-9.1f}% fp"
        )
        log.info(percent_template.format(
            epoch_ndx, mode_str + '_all', **metrics_dict))

        self.initTensorboardWriters()
        writer = getattr(self, mode_str + '_writer')

        for key, value in metrics_dict.items():
            writer.add_scalar('seg_' + key, value, self.totalTrainingSamples_count)

        writer.flush()

        return metrics_dict['pr/recall']

    def saveModel(self, type_str, epoch_ndx, isBest=False):
        """Save model and optimizer state to disk and log the file's SHA1.

        Args:
            type_str: prefix of the checkpoint file name (e.g. 'seg').
            epoch_ndx: epoch number stored inside the checkpoint.
            isBest: when true, the checkpoint is also copied to a
                '<...>.best.state' file next to it.
        """
        model_dir = os.path.join('models', self.cli_args.tb_prefix)
        file_name = '{}_{}_{}.{}.state'.format(
            type_str,
            self.time_str,
            self.cli_args.comment,
            self.totalTrainingSamples_count,
        )
        file_path = os.path.join(model_dir, file_name)

        os.makedirs(os.path.dirname(file_path), mode=0o755, exist_ok=True)

        # Unwrap DataParallel so the saved keys carry no 'module.' prefix.
        model = self.segmentation_model
        if isinstance(model, torch.nn.DataParallel):
            model = model.module

        torch.save(
            {
                'sys_argv': sys.argv,
                'time': str(datetime.datetime.now()),
                'model_state': model.state_dict(),
                'model_name': type(model).__name__,
                'optimizer_state': self.optimizer.state_dict(),
                'optimizer_name': type(self.optimizer).__name__,
                'epoch': epoch_ndx,
                'totalTrainingSamples_count': self.totalTrainingSamples_count,
            },
            file_path,
        )
        log.info("Saved model params to {}".format(file_path))

        if isBest:
            best_path = os.path.join(
                model_dir,
                f'{type_str}_{self.time_str}_{self.cli_args.comment}.best.state')
            shutil.copyfile(file_path, best_path)
            log.info("Saved model params to {}".format(best_path))

        with open(file_path, 'rb') as checkpoint_file:
            digest = hashlib.sha1(checkpoint_file.read()).hexdigest()
            log.info("SHA1: " + digest)

In [8]:
# train_ds = TrainingLuna2dSegmentationDataset(
#             val_stride=10,
#             isValSet_bool=False,
#             contextSlices_count=2,
#         )
# # print(len(train_ds))
# total_rate = 0
# total = (512 * 512)
# i = 0
# average_rate = 0
# for i in range(8001):
#     # if (i % 10 == 0):
#     #     # print("i = ", i)
#     #     average_rate += (total_rate) / 10
#     #     # print("avg = ", average_rate)
#     #     total_rate = 0
#     if (i % 1000 == 0):
#         print("i = ", i)
#         print("avg = ", average_rate)
#     i += 1
#     csum = (train_ds[i][1].sum())
#     # print("sum = ", csum)
#     total_rate += csum / total
# average_rate = (total_rate) / 8000
# print(average_rate)

In [9]:
# LunaPrepCacheApp(sys_argv=["--num-workers=4"]).main()

In [10]:
# torch.autograd.set_detect_anomaly(False)

In [11]:
# import pandas as pd
# import glob
# import SimpleITK as sitk
# import numpy as np
# import collections
# from PIL import Image, ImageDraw

# train_ds = TrainingLuna2dSegmentationDataset(
#             series_uid="1.3.6.1.4.1.14519.5.2.1.6279.6001.100225287222365663678666836860"
#         )
# hu_a = train_ds[0][0].numpy()
# hu_mask = train_ds[0][1].numpy().astype(int)
# print(hu_a.shape)
# print(hu_mask.shape)
# min_value = np.min(hu_a[3])
# max_value = np.max(hu_a[3])
# scaled_hu_a = (hu_a[3] - min_value) / (max_value - min_value) * 255
# scaled_hu_a = scaled_hu_a.astype(np.uint8)
# slice_ori = Image.fromarray(scaled_hu_a, mode='L')
# slice_ori.save("origin.png")
# min_value_mask = np.min(hu_mask[0])
# max_value_mask = np.max(hu_mask[0])
# # print(min_value_mask)
# # print(max_value_mask)
# scaled_mask = hu_mask[0] * 255
# slice_mask = Image.fromarray(scaled_mask, mode='L')
# slice_mask.save("mask.png")

In [12]:
# train_ds = TrainingLuna2dSegmentationDataset(
#             series_uid="1.3.6.1.4.1.14519.5.2.1.6279.6001.100225287222365663678666836860"
#         )
# print(train_ds[0][2])

In [13]:
# ct_slice = Image.fromarray(scaled_hu_a, mode='L')

# # Create a drawing context on the image
# draw = ImageDraw.Draw(ct_slice)

# # Define the coordinates to mark (row 212, column 45) as a red rectangle
# x1, y1, x2, y2 = 44, 211, 46, 213  # Adjust these coordinates as needed

# # Define the outline color as "red"
# outline_color = (0, 0, 255)  # Use grayscale value 255 for white outline

# # Draw a red rectangle on the image to mark the specific row and column
# draw.rectangle([x1, y1, x2, y2], outline="white", width=3)  # Increase width for better visibility

# # Save the marked slice as a PNG
# ct_slice.save("marked_slice.png")

In [14]:
def np2Png(np_arr, target_name):
    """Save the first slice of ``np_arr`` as an 8-bit grayscale PNG.

    The slice ``np_arr[0]`` is min-max normalised to the range 0..255,
    truncated to ``uint8`` and written to ``target_name`` through PIL.

    Args:
        np_arr: Array-like whose first element along axis 0 is the 2D slice
            to export.
        target_name: Destination path handed to ``Image.save``.

    Notes:
        A constant slice (for example an all-zero mask) has a zero value
        range. It is written as an all-black image instead of dividing by
        zero and casting NaN to ``uint8``.
    """
    # Work in float64 so narrow integer dtypes cannot overflow on subtraction.
    slice_arr = np.asarray(np_arr[0], dtype=np.float64)
    min_value = slice_arr.min()
    value_range = slice_arr.max() - min_value

    if value_range > 0:
        scaled_np_arr = (slice_arr - min_value) / value_range * 255
    else:
        # Flat (or NaN-ranged) slice: there is no contrast to stretch.
        scaled_np_arr = np.zeros_like(slice_arr)

    scaled_np_arr = scaled_np_arr.astype(np.uint8)
    slice_ori = Image.fromarray(scaled_np_arr, mode='L')
    slice_ori.save(target_name)

In [15]:
# segmentation_model = UNetWrapper(
#             in_channels=7,
#             n_classes=1,
#             depth=2,  #how deep the U go
#             wf=6,   #2^6 filters in the first layer
#             padding=True, #padding so that we get the output size as input size
#             batch_norm=True,
#             up_mode='upconv', #use  nn.ConvTranspose2d
#         )
# # model_state
# # torch.load("F:\\udet\\models\\udet\\seg_2023-10-19_08.28.18_final-cls.best.state")["model_state"]
# segmentation_model.load_state_dict(torch.load("F:\\udet\\models\\udet\\u_net_depth2_200epcoch_f1score0.2.state")["model_state"])
# device = torch.device("cuda")
# segmentation_model.to(device)
# segmentation_model.eval()
# val_ds = Luna2dSegmentationDataset(
#             val_stride=10,
#             isValSet_bool=True,
#             contextSlices_count=3,
#         )

# batch_size = 8

# val_dl = DataLoader(
#     val_ds,
#     batch_size=batch_size,
#     num_workers=4,
#     pin_memory=True,
# )
# batch_iter = enumerateWithEstimate(
#     val_dl,
#     "E{} Validation ".format(1),
#     start_ndx=val_dl.num_workers,
# )

In [16]:
# for batch_ndx, batch_tup in batch_iter:
#     input_t, label_t, series_list, _slice_ndx_list = batch_tup

#     input_g = input_t.to(device, non_blocking=True)
#     label_g = label_t.to(device, non_blocking=True)

#     prediction_g = segmentation_model(input_g)
#     # np2Png(input_g.cpu().numpy().astype(int), "./test/test.png")
#     # np2Png(label_g.cpu().numpy().astype(int), "./test/label.png")
#     # np2Png(prediction_g.cpu().detach().numpy().astype(float), "./test/predict.png")
#     break

In [17]:
# Build a Luna2dSegmentationDataset with isValSet_bool=True (presumably the
# validation split -- confirm in dsetsFullCT), val_stride=1 and
# contextSlices_count=0, then print how many items it contains.
val_ds = Luna2dSegmentationDataset(
            val_stride=1,
            isValSet_bool=True,
            contextSlices_count=0,
        )
print(len(val_ds))

In [18]:
# Show the dataset item at index 2 as the cell output, to inspect its structure.
val_ds[2]

In [19]:
# Print the size of the first element of the first dataset item
# (presumably the input tensor -- confirm against the dataset's __getitem__).
print(val_ds[0][0].size())

In [20]:
# for i in range(len(val_ds)):
#     origin_n = val_ds[i][0].numpy()
#     mask_n = val_ds[i][1].numpy()
#     mask_ori_n = ((val_ds[i][0].float() + 1001) * val_ds[i][1]).numpy()
#     # print(type(origin_n[0]))
#     np2Png(origin_n.astype(int), "./origin_new/{}_{}_{}.png".format(i, val_ds[i][2], val_ds[i][3]))
#     np2Png(mask_n.astype(int), "./mask_new/{}_{}_{}.png".format(i, val_ds[i][2], val_ds[i][3]))
#     np2Png(mask_ori_n.astype(int), "./mask_origin_new/{}_{}_{}.png".format(i, val_ds[i][2], val_ds[i][3]))

In [21]:
# origin_n = input_g.cpu().numpy()
# np2Png(origin_n[0].astype(int), "./test/test.png")

In [22]:
# prediction_g.size()
# label_g.size()
# label_n = label_g.cpu().numpy()
# np2Png(label_n[0].astype(int), "./test/label.png")

In [23]:
# prediction_n = prediction_g.cpu().detach().numpy()
# np2Png(prediction_n[0].astype(float), "./test/predict.png")

In [24]:
# np.sum(prediction_n)

In [25]:
# print(prediction_n)

In [26]:
# torch.cuda.is_available()

In [27]:
# Launch segmentation training through the app's CLI-style argument list:
# 400 epochs, batch size 8, 8 workers, the --augmented flag, and the
# positional run comment 'final-cls'.
SegmentationTrainingApp(sys_argv=['--epochs=400','--augmented', 'final-cls',"--num-workers=8", "--batch-size=8"]).main()

2023-10-22 20:26:07,000 INFO     pid:6164 __main__:115:initModel Using CUDA; 1 devices.
2023-10-22 20:26:07,073 INFO     pid:6164 __main__:179:main Starting SegmentationTrainingApp, Namespace(batch_size=8, num_workers=8, epochs=400, augmented=True, augment_flip=False, augment_offset=False, augment_scale=False, augment_rotate=False, augment_noise=False, tb_prefix='udet', comment='final-cls')
2023-10-22 20:26:10,023 INFO     pid:6164 dsetsFullCT:333:__init__ <dsetsFullCT.TrainingLuna2dSegmentationDataset object at 0x000001783E4026D0>: 80 training series, 1622 slices, 101 nodules
2023-10-22 20:26:10,047 INFO     pid:6164 dsetsFullCT:333:__init__ <dsetsFullCT.Luna2dSegmentationDataset object at 0x000001783E43E610>: 9 validation series, 131 slices, 11 nodules
2023-10-22 20:26:10,048 INFO     pid:6164 __main__:187:main Epoch 1 of 400, 203/17 batches of size 8*1
2023-10-22 20:26:35,104 INFO     pid:6164 util:126:enumerateWithEstimate E1 Training   64/203, done at 2023-10-22 20:26:55, 0:00:29


KeyboardInterrupt: 

In [None]:
# %tensorboard --logdir E:\LUNA\nodule_detection\runs\p2ch13\2023-08-17_12.51.58_val_seg_final-cls

In [None]:
# import os
# os.getcwd()

In [None]:
# run('p2ch13.training.SegmentationTrainingApp', f'--epochs={final_epochs}', '--augmented', 'final-cls')

In [None]:
# SegmentationTrainingApp(sys_argv=['--epochs=1','--augmented', 'final-cls']).main()