# Image Segmentation for Lung CT

In [None]:
import matplotlib.pyplot as plt
import os, glob

from torchvision.utils import make_grid

from torch.utils.data import Dataset, DataLoader
from skimage.io import imread
from skimage.transform import pyramid_reduce, resize
from torch.utils.data import random_split

from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import LightningModule
from pytorch_lightning.loggers import NeptuneLogger
from pytorch_lightning import Trainer
from torchvision import transforms

import torch
import torch.nn.functional as F
from torch import nn
import numpy as np

import random
# Seed every random number generator the notebook relies on. Seeding only
# Python's `random` module is not enough: `random_split` and the layer weight
# initialisation draw from torch's generator, so without `torch.manual_seed`
# the train/validation split and the initial weights differ on every run.
random.seed(17)
np.random.seed(17)
torch.manual_seed(17)

### Build dataset

In [None]:
import numpy as np

class LungCTscan(Dataset):
    """Paired lung CT slices and segmentation masks stored as ``.tif`` files.

    Images are read from ``<data_dir>/2d_images`` and masks from
    ``<data_dir>/2d_masks``. Both listings are sorted and then paired by
    position, so the two folders are expected to contain matching file names.

    Args:
        data_dir: Root folder holding the ``2d_images`` and ``2d_masks``
            sub-folders.
        transform: Optional callable applied to the image and to the mask
            (one call each) after they have been converted to ``float32``
            arrays with a trailing channel axis.

    Raises:
        ValueError: If the number of image files and mask files differ,
            because index-based pairing would then be wrong or incomplete.
    """

    def __init__(self, data_dir, transform=None):
        self.img_list = sorted(glob.glob(os.path.join(data_dir, '2d_images', '*.tif')))
        self.mask_list = sorted(glob.glob(os.path.join(data_dir, '2d_masks', '*.tif')))
        if len(self.img_list) != len(self.mask_list):
            raise ValueError(
                f"Found {len(self.img_list)} images but {len(self.mask_list)} masks "
                f"under {data_dir!r}; every image needs exactly one mask."
            )
        self.transform = transform
        # Side length (in pixels) every slice and mask is resized to.
        self.image_size = 256

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.img_list)

    def _load_resized(self, path):
        """Read one ``.tif``, divide by 255, resize it and add a channel axis."""
        data = imread(path) / 255.0
        data = resize(data, output_shape=(self.image_size, self.image_size), preserve_range=True)
        return np.array(data[..., np.newaxis], dtype=np.float32)

    def __getitem__(self, idx):
        """Return the ``(image, mask)`` pair stored at position ``idx``."""
        image = self._load_resized(self.img_list[idx])
        mask = self._load_resized(self.mask_list[idx])

        if self.transform is not None:
            # NOTE: the transform is invoked once per array, so a randomised
            # augmentation would not stay aligned between image and mask.
            image = self.transform(image)
            mask = self.transform(mask)
        return image, mask

In [None]:
# Load the dataset without a transform so samples stay NumPy arrays that
# matplotlib can display directly.
full_data = LungCTscan(data_dir="CT-scan-dataset")

# 75 % / 25 % split computed from the actual dataset size. With 267 samples
# this gives the same 200 / 67 split that used to be hard-coded, but it no
# longer fails when the number of files changes.
n_train = int(len(full_data) * 0.75)
train_data, test_data = random_split(full_data, [n_train, len(full_data) - n_train])

# Show one held-out slice next to its mask as a visual sanity check.
image, mask = test_data[0]
fig, ax = plt.subplots(1, 2)
ax[0].imshow(image.squeeze())
ax[1].imshow(mask.squeeze())

In [None]:
# Check whether the mask sampled in the previous cell contains at least one
# pixel that is exactly zero.
(mask == 0).any()

# Build model

In [None]:
def convrelu(in_channels, out_channels, kernel, padding):
    """Build a 2-D convolution followed by an in-place LeakyReLU.

    Args:
        in_channels: Number of input feature channels.
        out_channels: Number of output feature channels.
        kernel: Convolution kernel size.
        padding: Padding passed to the convolution.

    Returns:
        An ``nn.Sequential`` holding the convolution and the activation.
    """
    conv = nn.Conv2d(in_channels, out_channels, kernel, padding=padding)
    activation = nn.LeakyReLU(inplace=True)
    return nn.Sequential(conv, activation)

def deconvrelu(in_channels, out_channels, kernel, stride, padding):
    """Build a 2-D transposed convolution followed by an in-place LeakyReLU.

    The transposed convolution always uses ``output_padding=1``.

    Args:
        in_channels: Number of input feature channels.
        out_channels: Number of output feature channels.
        kernel: Kernel size of the transposed convolution.
        stride: Stride of the transposed convolution.
        padding: Padding of the transposed convolution.

    Returns:
        An ``nn.Sequential`` holding the transposed convolution and the
        activation.
    """
    upsample = nn.ConvTranspose2d(
        in_channels,
        out_channels,
        kernel,
        stride=stride,
        padding=padding,
        output_padding=1,
    )
    activation = nn.LeakyReLU(inplace=True)
    return nn.Sequential(upsample, activation)

class UNet(nn.Module):
    """
    Encoder/decoder segmentation network with skip connections.

    The encoder runs five ``convrelu`` stages, separated by 2x2 max pooling,
    that widen the features from ``n_filters`` to ``16 * n_filters`` channels.
    The decoder mirrors it: each ``deconvrelu`` stage is concatenated
    channel-wise with the encoder output of the same level and fused by a
    ``convrelu`` stage. A final 1x1 convolution produces one output channel.

    The network takes single-channel input and returns the output of that last
    convolution unchanged; no sigmoid is applied inside ``forward``.
    """
    def __init__(self, n_filters=16):
        super().__init__()
        w1, w2, w4, w8, w16 = [n_filters * mult for mult in (1, 2, 4, 8, 16)]

        # Encoder (contracting path)
        self.conv1 = convrelu(1, w1, 3, 1)
        self.conv2 = convrelu(w1, w2, 3, 1)
        self.conv3 = convrelu(w2, w4, 3, 1)
        self.conv4 = convrelu(w4, w8, 3, 1)
        self.conv5 = convrelu(w8, w16, 3, 1)

        # Decoder (expansive path): the fusing convolutions see twice the
        # channels because the skip connection is concatenated first.
        self.conv_up6 = deconvrelu(w16, w8, 3, 2, 1)
        self.conv6 = convrelu(2 * w8, w8, 3, 1)
        self.conv_up7 = deconvrelu(w8, w4, 3, 2, 1)
        self.conv7 = convrelu(2 * w4, w4, 3, 1)
        self.conv_up8 = deconvrelu(w4, w2, 3, 2, 1)
        self.conv8 = convrelu(2 * w2, w2, 3, 1)
        self.conv_up9 = deconvrelu(w2, w1, 3, 2, 1)
        self.conv9 = convrelu(2 * w1, w1, 3, 1)

        # 1x1 convolution mapping the last feature map to one channel
        self.out = nn.Conv2d(w1, 1, kernel_size=1, padding=0)

    def forward(self, x, y=None):
        """Run the network on ``x``; ``y`` is accepted but not used."""
        downsample = nn.MaxPool2d(2)

        # Encoder: keep every level's output for the skip connections.
        enc1 = self.conv1(x)
        enc2 = self.conv2(downsample(enc1))
        enc3 = self.conv3(downsample(enc2))
        enc4 = self.conv4(downsample(enc3))
        bottleneck = self.conv5(downsample(enc4))

        # Decoder: upsample, concatenate the matching encoder output, fuse.
        decoder_levels = (
            (self.conv_up6, self.conv6, enc4),
            (self.conv_up7, self.conv7, enc3),
            (self.conv_up8, self.conv8, enc2),
            (self.conv_up9, self.conv9, enc1),
        )
        features = bottleneck
        for upsample, fuse, skip in decoder_levels:
            features = fuse(torch.cat([upsample(features), skip], dim=1))

        return self.out(features)

In [None]:
class LossFunction(nn.Module):
    """
    Segmentation loss: binary cross-entropy computed on raw logits.

    ``forward`` applies ``F.binary_cross_entropy_with_logits``, so the
    prediction passed in must not have been passed through a sigmoid.
    """
    def __init__(self):
        super().__init__()
        # Kept as an attribute for callers that may reference it; it is not
        # used by ``forward``. ``reduction="mean"`` replaces the deprecated
        # ``size_average=True``, which emitted a warning on construction.
        self.MSE = nn.MSELoss(reduction="mean")

    def forward(self, labels, seg, bce_weight=0.5):
        """Return the mean binary cross-entropy between ``seg`` and ``labels``.

        Args:
            labels: Target tensor.
            seg: Predicted logits, compared element-wise with ``labels``.
            bce_weight: Unused; kept so existing call sites keep working.

        Returns:
            The loss value returned by ``F.binary_cross_entropy_with_logits``.
        """
        return F.binary_cross_entropy_with_logits(seg, labels)

In [None]:
# Hyper-parameters shared by the data loaders, the optimiser and the trainer.
PARAMS = dict(
    batch_size=16,   # samples per mini-batch in both data loaders
    lr=1e-4,         # learning rate handed to Adam
    max_epochs=100,  # upper bound on training epochs
    lr_step=10,      # StepLR ``step_size``
    lr_decay=0.8,    # StepLR ``gamma``
)

In [None]:
from neptune.types import File

class UnetModule(LightningModule):
    """Lightning module that trains ``UNet`` with ``LossFunction``.

    Per-batch losses are collected during each epoch and their mean is logged
    when the epoch ends. After every validation epoch the first validation
    batch (targets and sigmoid-activated predictions) is rendered as an image
    grid and sent to the logger's experiment, which is expected to behave like
    a Neptune run (item assignment and ``.append`` on fields).
    """

    def __init__(self):
        super().__init__()
        self.model = UNet()
        self.loss = LossFunction()

        # Per-batch records gathered during an epoch; cleared at its end.
        self.training_step_outputs = []
        self.validation_step_outputs = []

    def forward(self, x):
        """Return the raw network output for batch ``x``."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute, log and return the loss for one training batch."""
        x, y = batch
        y_hat = self(x)
        loss = self.loss(y, y_hat)
        self.log("metrics/batch/loss", loss, prog_bar=False)

        # Only the scalar is needed for the epoch summary. Copying every
        # batch's targets and predictions to the CPU and keeping them for the
        # whole epoch cost time and memory without ever being read.
        self.training_step_outputs.append({"loss": loss.item()})

        return loss

    def on_train_epoch_end(self):
        """Log the mean training loss of the epoch and reset the buffer."""
        losses = [record["loss"] for record in self.training_step_outputs]
        if losses:
            self.log("metrics/epoch/loss", float(np.mean(losses)))
        self.training_step_outputs.clear()

    def validation_step(self, batch, batch_idx):
        """Compute the loss for one validation batch and record it."""
        x, y = batch
        y_hat = self(x)
        loss = self.loss(y, y_hat)

        record = {"loss": loss.item()}
        if not self.validation_step_outputs:
            # Only the first batch of the epoch is rendered afterwards, so
            # only that batch's tensors are copied to the CPU and retained.
            record["y_true"] = y.cpu().detach()
            record["y_pred"] = torch.sigmoid(y_hat).cpu().detach()
        self.validation_step_outputs.append(record)

    def on_validation_epoch_end(self):
        """Log the mean validation loss and preview grids, then reset."""
        outputs = self.validation_step_outputs
        if not outputs:
            # Nothing was validated; indexing outputs[0] below would fail.
            return

        mean_loss = float(np.mean([record["loss"] for record in outputs]))

        # Arrange the first batch in a roughly square grid and convert it
        # from (C, H, W) to (H, W, C) for image logging.
        nrow = int(PARAMS["batch_size"] ** 0.5)
        y_true = make_grid(outputs[0]["y_true"], nrow=nrow)
        y_pred = make_grid(outputs[0]["y_pred"], nrow=nrow)
        y_true = y_true.cpu().numpy().transpose(1, 2, 0)
        y_pred = y_pred.cpu().numpy().transpose(1, 2, 0)

        self.log("val/epoch/loss", mean_loss)
        self.logger.experiment["val/epoch/loss"] = mean_loss
        self.logger.experiment["val/gt_images"].append(File.as_image(y_true))
        self.logger.experiment["val/outputs"].append(File.as_image(y_pred))

        self.validation_step_outputs.clear()

    def configure_optimizers(self):
        """Return Adam plus a StepLR schedule configured from ``PARAMS``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=PARAMS['lr'])
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=PARAMS['lr_step'],
                                                    gamma=PARAMS['lr_decay'])
        return {"optimizer": optimizer, "lr_scheduler": scheduler}

In [None]:
# %%script false --no-raise-error

# Create the Neptune logger. The API token is read from the
# NEPTUNE_API_TOKEN environment variable instead of being hard-coded:
# a token committed to source is exposed to everyone who can read the file.
# NOTE(review): the token that was previously embedded here should be revoked.
neptune_logger = NeptuneLogger(
    project="kaori/Seg",
    api_key=os.environ.get("NEPTUNE_API_TOKEN"),
    log_model_checkpoints=False,
)

# Store the hyper-parameters alongside the run.
neptune_logger.log_hyperparams(params=PARAMS)

In [None]:
# torch.set_float32_matmul_precision('medium')

unet_model = UnetModule()

# ToTensor turns each (H, W, C) array into a (C, H, W) tensor. It only
# rescales uint8 input to [0, 1]; LungCTscan already yields float32 arrays,
# so the values pass through unchanged.
Test_transform = transforms.Compose([
    transforms.ToTensor(),
])
train_ds = LungCTscan(data_dir="CT-scan-dataset", transform=Test_transform)

# 75 % / 25 % split computed from the actual dataset size. With 267 samples
# this gives the same 200 / 67 split that used to be hard-coded, but it no
# longer fails when the number of files changes.
n_train = int(len(train_ds) * 0.75)
train_data, val_data = random_split(train_ds, [n_train, len(train_ds) - n_train])

# Reshuffle the training samples every epoch. The validation loader keeps a
# fixed order so the first batch, which is the one rendered as a preview
# after each validation epoch, shows the same samples throughout training.
train_loader = DataLoader(train_data, batch_size=PARAMS["batch_size"], shuffle=True)
val_loader = DataLoader(val_data, batch_size=PARAMS["batch_size"])

# Keep only the best checkpoint according to the monitored metric.
# NOTE(review): this monitors the training batch loss; consider monitoring
# "val/epoch/loss" if the checkpoint should reflect validation performance.
checkpoint_callback = ModelCheckpoint(
    dirpath=os.path.join("model", "test"),
    save_top_k=1,
    monitor='metrics/batch/loss',
    mode="min",
)

# (neptune) initialize a trainer and pass neptune_logger
trainer = Trainer(
    logger=neptune_logger,
    max_epochs=PARAMS["max_epochs"],
    accelerator="gpu",
    devices=[0],
    callbacks=[checkpoint_callback],
)

# Training and save model
trainer.fit(unet_model, train_dataloaders=train_loader, val_dataloaders=val_loader)

# Validation
trainer.validate(unet_model, val_loader)