<a href="https://colab.research.google.com/github/PeterJackson61/U-net_study/blob/main/Liver_tumor_3D_segmentation_study_eval_added.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Download the dataset

In [1]:
!pip install --upgrade --no-cache-dir gdown

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting gdown
  Downloading gdown-4.6.0-py3-none-any.whl (14 kB)
Installing collected packages: gdown
  Attempting uninstall: gdown
    Found existing installation: gdown 4.4.0
    Uninstalling gdown-4.4.0:
      Successfully uninstalled gdown-4.4.0
Successfully installed gdown-4.6.0


In [2]:
!gdown https://drive.google.com/u/0/uc?id=1LmYkBOG7TyW894asLF6OSqdaB-cfA4go&export=download


Downloading...
From: https://drive.google.com/u/0/uc?id=1LmYkBOG7TyW894asLF6OSqdaB-cfA4go
To: /content/08-3D-Liver-Tumor-Segmentation.zip
100% 2.91G/2.91G [00:27<00:00, 104MB/s]


In [3]:
!unzip /content/08-3D-Liver-Tumor-Segmentation.zip -d /content/my_data

Archive:  /content/08-3D-Liver-Tumor-Segmentation.zip
   creating: /content/my_data/08-3D-Liver-Tumor-Segmentation/
  inflating: /content/my_data/08-3D-Liver-Tumor-Segmentation/01-Data.ipynb  
  inflating: /content/my_data/08-3D-Liver-Tumor-Segmentation/02-Model.ipynb  
  inflating: /content/my_data/08-3D-Liver-Tumor-Segmentation/03-Train.ipynb  
  inflating: /content/my_data/08-3D-Liver-Tumor-Segmentation/model.py  
   creating: /content/my_data/08-3D-Liver-Tumor-Segmentation/.ipynb_checkpoints/
  inflating: /content/my_data/08-3D-Liver-Tumor-Segmentation/.ipynb_checkpoints/03-Train-checkpoint.ipynb  
   creating: /content/my_data/08-3D-Liver-Tumor-Segmentation/Task03_Liver_rs/
  inflating: /content/my_data/08-3D-Liver-Tumor-Segmentation/Task03_Liver_rs/LICENSE.txt  
   creating: /content/my_data/08-3D-Liver-Tumor-Segmentation/Task03_Liver_rs/.ipynb_checkpoints/
   creating: /content/my_data/08-3D-Liver-Tumor-Segmentation/Task03_Liver_rs/imagesTr/
  inflating: /content/my_data/08-3D-L

In [4]:
!pip install celluloid

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting celluloid
  Downloading celluloid-0.2.0-py3-none-any.whl (5.4 kB)
Installing collected packages: celluloid
Successfully installed celluloid-0.2.0


### Display the data

In [5]:
%matplotlib notebook
from pathlib import Path
import nibabel as nib
import matplotlib.pyplot as plt
import numpy as np
from celluloid import Camera
from IPython.display import HTML

In [6]:
root = Path("/content/my_data/08-3D-Liver-Tumor-Segmentation/Task03_Liver_rs/imagesTr/")
label = Path("/content/my_data/08-3D-Liver-Tumor-Segmentation/Task03_Liver_rsrs/labelsTr/")

In [7]:
def change_img_to_label_path(path):
    """
    Replaces imagesTr with labelsTr
    """
    parts = list(path.parts)  # get all directories within the path
    parts[parts.index("imagesTr")] = "labelsTr"  # Replace imagesTr with labelsTr
    return Path(*parts)  # Combine list back into a Path object

In [8]:
sample_path = list(root.glob("liver*"))[0]  # Choose a subject
sample_path_label = change_img_to_label_path(sample_path)


In [9]:
data = nib.load(sample_path)
label = nib.load(sample_path_label)

ct = data.get_fdata()
mask = label.get_fdata().astype(int)

In [10]:
nib.aff2axcodes(data.affine)

('R', 'A', 'S')

In [11]:
fig = plt.figure()
camera = Camera(fig)  # Create the camera object from celluloid

for i in range(ct.shape[2]):  # Axial view
    plt.imshow(ct[:,:,i], cmap="bone")
    mask_ = np.ma.masked_where(mask[:,:,i]==0, mask[:,:,i])
    plt.imshow(mask_, alpha=0.5)
    # plt.axis("off")
    camera.snap()  # Store the current slice
plt.tight_layout()
animation = camera.animate()

<IPython.core.display.Javascript object>

In [12]:
HTML(animation.to_html5_video())

  dv = np.float64(self.norm.vmax) - np.float64(self.norm.vmin)
  a_min = np.float64(newmin)
  a_max = np.float64(newmax)
  data = np.asarray(value)


In [13]:
import torch

In [14]:
class DoubleConv(torch.nn.Module):
    """
    Helper Class which implements the intermediate Convolutions
    """
    def __init__(self, in_channels, out_channels):
        
        super().__init__()
        self.step = torch.nn.Sequential(torch.nn.Conv3d(in_channels, out_channels, 3, padding=1),
                                        torch.nn.ReLU(),
                                        torch.nn.Conv3d(out_channels, out_channels, 3, padding=1),
                                        torch.nn.ReLU())
        
    def forward(self, X):
        return self.step(X)

class UNet(torch.nn.Module):
    """
    This class implements a UNet for the Segmentation
    We use 3 down- and 3 UpConvolutions and two Convolutions in each step
    """

    def __init__(self):
        """Sets up the U-Net Structure
        """
        super().__init__()
        
        
        ############# DOWN #####################
        self.layer1 = DoubleConv(1, 32)
        self.layer2 = DoubleConv(32, 64)
        self.layer3 = DoubleConv(64, 128)
        self.layer4 = DoubleConv(128, 256)

        #########################################

        ############## UP #######################
        self.layer5 = DoubleConv(256 + 128, 128)
        self.layer6 = DoubleConv(128+64, 64)
        self.layer7 = DoubleConv(64+32, 32)
        self.layer8 = torch.nn.Conv3d(32, 3, 1)  # Output: 3 values -> background, liver, tumor
        #########################################

        self.maxpool = torch.nn.MaxPool3d(2)

    def forward(self, x):
        
        ####### DownConv 1#########
        x1 = self.layer1(x)
        x1m = self.maxpool(x1)
        ###########################
        
        ####### DownConv 2#########        
        x2 = self.layer2(x1m)
        x2m = self.maxpool(x2)
        ###########################

        ####### DownConv 3#########        
        x3 = self.layer3(x2m)
        x3m = self.maxpool(x3)
        ###########################
        
        ##### Intermediate Layer ## 
        x4 = self.layer4(x3m)
        ###########################

        ####### UpCONV 1#########        
        x5 = torch.nn.Upsample(scale_factor=2, mode="trilinear")(x4)  # Upsample with a factor of 2
        x5 = torch.cat([x5, x3], dim=1)  # Skip-Connection
        x5 = self.layer5(x5)
        ###########################

        ####### UpCONV 2#########        
        x6 = torch.nn.Upsample(scale_factor=2, mode="trilinear")(x5)        
        x6 = torch.cat([x6, x2], dim=1)  # Skip-Connection    
        x6 = self.layer6(x6)
        ###########################
        
        ####### UpCONV 3#########        
        x7 = torch.nn.Upsample(scale_factor=2, mode="trilinear")(x6)
        x7 = torch.cat([x7, x1], dim=1)       
        x7 = self.layer7(x7)
        ###########################
        
        ####### Predicted segmentation#########        
        ret = self.layer8(x7)
        return ret

In [15]:
model = UNet()

### Torchio for loading, preprocessing and patching 3D medical images

In [16]:
!pip install torchio

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torchio
  Downloading torchio-0.18.86-py2.py3-none-any.whl (172 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m172.8/172.8 KB[0m [31m5.9 MB/s[0m eta [36m0:00:00[0m
Collecting SimpleITK!=2.0.*,!=2.1.1.1
  Downloading SimpleITK-2.2.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (52.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m52.7/52.7 MB[0m [31m16.7 MB/s[0m eta [36m0:00:00[0m
Collecting Deprecated
  Downloading Deprecated-1.2.13-py2.py3-none-any.whl (9.6 kB)
Collecting colorama<0.5.0,>=0.4.3
  Downloading colorama-0.4.6-py2.py3-none-any.whl (25 kB)
Collecting shellingham<2.0.0,>=1.3.0
  Downloading shellingham-1.5.0.post1-py2.py3-none-any.whl (9.4 kB)
Collecting rich<13.0.0,>=10.11.0
  Downloading rich-12.6.0-py3-none-any.whl (237 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m237.5/237.5 KB[0

In [17]:
!pip install pytorch_lightning

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pytorch_lightning
  Downloading pytorch_lightning-1.8.6-py3-none-any.whl (800 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m800.3/800.3 KB[0m [31m14.9 MB/s[0m eta [36m0:00:00[0m
Collecting lightning-utilities!=0.4.0,>=0.3.0
  Downloading lightning_utilities-0.5.0-py3-none-any.whl (18 kB)
Collecting tensorboardX>=2.2
  Downloading tensorboardX-2.5.1-py2.py3-none-any.whl (125 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m125.4/125.4 KB[0m [31m16.8 MB/s[0m eta [36m0:00:00[0m
Collecting torchmetrics>=0.7.0
  Downloading torchmetrics-0.11.0-py3-none-any.whl (512 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m512.4/512.4 KB[0m [31m51.8 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: tensorboardX, torchmetrics, lightning-utilities, pytorch_lightning
Successfully installed lightning-utilities-0.5.0 pyto

In [18]:
from pathlib import Path

import torchio as tio
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
import matplotlib.pyplot as plt
import numpy as np

In [19]:
def change_img_to_label_path(path):
    """
    Replace data with mask to get the masks
    """
    parts = list(path.parts)
    parts[parts.index("imagesTr")] = "labelsTr"
    return Path(*parts)

In [20]:
path = Path("/content/my_data/08-3D-Liver-Tumor-Segmentation/Task03_Liver_rs/imagesTr/")
subjects_paths = list(path.glob("liver_*"))
subjects = []

for subject_path in subjects_paths:
    label_path = change_img_to_label_path(subject_path)
    subject = tio.Subject({"CT":tio.ScalarImage(subject_path), "Label":tio.LabelMap(label_path)})
    subjects.append(subject)

In [21]:
for subject in subjects:
    assert subject["CT"].orientation == ("R", "A", "S")

In [22]:
process = tio.Compose([
            tio.CropOrPad((256, 256, 200)),
            tio.RescaleIntensity((-1, 1))
            ])


augmentation = tio.RandomAffine(scales=(0.9, 1.1), degrees=(-10, 10))


val_transform = process
train_transform = tio.Compose([process, augmentation])

In [23]:
train_dataset = tio.SubjectsDataset(subjects[:105], transform=train_transform)
val_dataset = tio.SubjectsDataset(subjects[105:], transform=val_transform)

sampler = tio.data.LabelSampler(patch_size=96, label_name="Label", label_probabilities={0:0.2, 1:0.3, 2:0.5})

In [24]:
train_patches_queue = tio.Queue(
     train_dataset,
     max_length=40,
     samples_per_volume=5,
     sampler=sampler,
     num_workers=4,
    )

val_patches_queue = tio.Queue(
     val_dataset,
     max_length=40,
     samples_per_volume=5,
     sampler=sampler,
     num_workers=4,
    )



In [25]:
batch_size = 2

train_loader = torch.utils.data.DataLoader(train_patches_queue, batch_size=batch_size, num_workers=0)
val_loader = torch.utils.data.DataLoader(val_patches_queue, batch_size=batch_size, num_workers=0)

In [26]:
class Segmenter(pl.LightningModule):
    def __init__(self):
        super().__init__()
        
        self.model = UNet()
        
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-4)
        self.loss_fn = torch.nn.CrossEntropyLoss()
    
    def forward(self, data):
        pred = self.model(data)
        return pred
    
    def training_step(self, batch, batch_idx):
        # You can obtain the raw volume arrays by accessing the data attribute of the subject
        img = batch["CT"]["data"]
        mask = batch["Label"]["data"][:,0]  # Remove single channel as CrossEntropyLoss expects NxHxW
        mask = mask.long()
        
        pred = self(img)
        loss = self.loss_fn(pred, mask)
        
        # Logs
        self.log("Train Loss", loss)
        if batch_idx % 50 == 0:
            self.log_images(img.cpu(), pred.cpu(), mask.cpu(), "Train")
        return loss
    
        
    def validation_step(self, batch, batch_idx):
        # You can obtain the raw volume arrays by accessing the data attribute of the subject
        img = batch["CT"]["data"]
        mask = batch["Label"]["data"][:,0]  # Remove single channel as CrossEntropyLoss expects NxHxW
        mask = mask.long()
        
        pred = self(img)
        loss = self.loss_fn(pred, mask)
        
        # Logs
        self.log("Val Loss", loss)
        self.log_images(img.cpu(), pred.cpu(), mask.cpu(), "Val")
        
        return loss

    
    def log_images(self, img, pred, mask, name):
        
        results = []
        pred = torch.argmax(pred, 1) # Take the output with the highest value
        axial_slice = 50  # Always plot slice 50 of the 96 slices
        
        fig, axis = plt.subplots(1, 2)
        axis[0].imshow(img[0][0][:,:,axial_slice], cmap="bone")
        mask_ = np.ma.masked_where(mask[0][:,:,axial_slice]==0, mask[0][:,:,axial_slice])
        axis[0].imshow(mask_, alpha=0.6)
        axis[0].set_title("Ground Truth")
        
        axis[1].imshow(img[0][0][:,:,axial_slice], cmap="bone")
        mask_ = np.ma.masked_where(pred[0][:,:,axial_slice]==0, pred[0][:,:,axial_slice])
        axis[1].imshow(mask_, alpha=0.6, cmap="autumn")
        axis[1].set_title("Pred")

        self.logger.experiment.add_figure(f"{name} Prediction vs Label", fig, self.global_step)

            
    
    def configure_optimizers(self):
        #Caution! You always need to return a list here (just pack your optimizer into one :))
        return [self.optimizer]

In [27]:
model = Segmenter()

# Create the checkpoint callback
checkpoint_callback = ModelCheckpoint(
    monitor='Val Loss',
    save_top_k=10,
    mode='min')

### This part for training will be skip because of the long time, we go straight to the model examination

In [28]:
gpus = 1 #TODO
trainer = pl.Trainer(gpus=gpus, logger=TensorBoardLogger(save_dir="./logs"), log_every_n_steps=1,
                     callbacks=checkpoint_callback,
                     max_epochs=100)

  rank_zero_deprecation(
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [None]:
trainer.fit(model, train_loader, val_loader)

INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name    | Type             | Params
---------------------------------------------
0 | model   | UNet             | 5.8 M 
1 | loss_fn | CrossEntropyLoss | 0     
---------------------------------------------
5.8 M     Trainable params
0         Non-trainable params
5.8 M     Total params
23.344    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]



<IPython.core.display.Javascript object>

  dv = np.float64(self.norm.vmax) - np.float64(self.norm.vmin)
  a_min = np.float64(newmin)
  a_max = np.float64(newmax)
  data = np.asarray(value)


<IPython.core.display.Javascript object>

Training: 0it [00:00, ?it/s]

<IPython.core.display.Javascript object>

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [29]:
from IPython.display import HTML
from celluloid import Camera

In [30]:
model = Segmenter.load_from_checkpoint("/content/my_data/08-3D-Liver-Tumor-Segmentation/weights/epoch=97-step=25773.ckpt")

In [31]:
model = model.eval()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device);

In [32]:
IDX = np.random.randint(0,len(val_dataset)+1)
mask = val_dataset[IDX]["Label"]["data"]
imgs = val_dataset[IDX]["CT"]["data"]

In [33]:
grid_sampler = tio.inference.GridSampler(val_dataset[IDX], 96, (8, 8, 8))

In [34]:
aggregator = tio.inference.GridAggregator(grid_sampler)

In [35]:
patch_loader = torch.utils.data.DataLoader(grid_sampler, batch_size=4)

In [36]:
with torch.no_grad():
    for patches_batch in patch_loader:
        input_tensor = patches_batch['CT']["data"].to(device)  # Get batch of patches
        locations = patches_batch[tio.LOCATION]  # Get locations of patches
        pred = model(input_tensor)  # Compute prediction
        aggregator.add_batch(pred, locations) 

In [37]:
output_tensor = aggregator.get_output_tensor()  

In [44]:
fig = plt.figure(figsize=(10,8))
camera = Camera(fig)  # create the camera object from celluloid
pred = output_tensor.argmax(0)

for i in range(50, output_tensor.shape[3]-50, 1):  # axial view
    plt.subplot(2,2,1)
    plt.imshow(imgs[0,:,:,i], cmap="bone")
    mask_ = np.ma.masked_where(pred[:,:,i]==0, pred[:,:,i])
    plt.imshow(mask_,)
    plt.subplot(2,2,2)
    plt.imshow(imgs[0,:,:,i], cmap="bone")
    label_mask = np.ma.masked_where(mask[0,:,:,i]==0, mask[0,:,:,i])
    plt.imshow(label_mask, alpha=0.5, cmap="jet")  
    # springmment if you want to see the label
    plt.subplot(2,2,3)
    mask_tumor = np.where(pred[:,:,i] > 1, 1, 0)
    plt.imshow(mask_tumor)
    plt.subplot(2,2,4)
    mask_label_tumor = np.where(mask[0,:,:,i] > 1, 1, 0)
    plt.imshow(mask_label_tumor)
    # plt.axis("off")
    camera.snap()  # Store the current slice
animation = camera.animate()

<IPython.core.display.Javascript object>

  plt.subplot(2,2,1)
  plt.subplot(2,2,2)
  plt.subplot(2,2,3)
  plt.subplot(2,2,4)


In [45]:
HTML(animation.to_html5_video())

  dv = np.float64(self.norm.vmax) - np.float64(self.norm.vmin)
  a_min = np.float64(newmin)
  a_max = np.float64(newmax)
  data = np.asarray(value)


In [None]:
from sklearn.metrics import f1_score, jaccard_score, precision_score, recall_score

In [None]:
fig_2 = plt.figure()
camera_2 = Camera(fig_2)  # create the camera object from celluloid
pred = output_tensor.argmax(0)

mask_combined = []
label_mask_combined = []
jac_ind = []
for i in range(0, output_tensor.shape[3], 1):
  plt.subplot(1,2,1)
  mask_ = np.where(pred[:,:,i]>1,1,0)
  mask_sq = mask_.ravel()
  plt.imshow(mask_)
  plt.subplot(1,2,2)
  label_mask = np.where(mask[0,:,:,i]>1,1,0)
  label_mask_sq = label_mask.ravel()
  plt.imshow(label_mask)
  jac_value = jaccard_score(mask_sq, label_mask_sq, labels=[0, 1], average="binary")
  jac_ind.append(jac_value)
  camera_2.snap()  # Store the current slice
animation_2 = camera_2.animate()

  fig_2 = plt.figure()


<IPython.core.display.Javascript object>

  _warn_prf(average, modifier, msg_start, len(result))
  plt.subplot(1,2,1)
  plt.subplot(1,2,2)
  _warn_prf(average, modifier, msg_start, len(result))
  plt.subplot(1,2,1)
  plt.subplot(1,2,2)
  _warn_prf(average, modifier, msg_start, len(result))
  plt.subplot(1,2,1)
  plt.subplot(1,2,2)
  _warn_prf(average, modifier, msg_start, len(result))
  plt.subplot(1,2,1)
  plt.subplot(1,2,2)
  _warn_prf(average, modifier, msg_start, len(result))
  plt.subplot(1,2,1)
  plt.subplot(1,2,2)
  _warn_prf(average, modifier, msg_start, len(result))
  plt.subplot(1,2,1)
  plt.subplot(1,2,2)
  _warn_prf(average, modifier, msg_start, len(result))
  plt.subplot(1,2,1)
  plt.subplot(1,2,2)
  _warn_prf(average, modifier, msg_start, len(result))
  plt.subplot(1,2,1)
  plt.subplot(1,2,2)
  _warn_prf(average, modifier, msg_start, len(result))
  plt.subplot(1,2,1)
  plt.subplot(1,2,2)
  _warn_prf(average, modifier, msg_start, len(result))
  plt.subplot(1,2,1)
  plt.subplot(1,2,2)
  _warn_prf(average, modifier,

In [None]:
print(jac_ind)

[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.9998779278248264, 0.9998168246554013, 0.999664040070857, 0.9997250061109754, 0.9993122420907841, 0.9990515091712944, 0.9982868854966502, 0.9985035197825523, 0.9994195814877043, 0.9992205291231717, 0.9990359454620575, 0.9986192507133871, 0.9949778071293637, 0.995134644105375, 0.9943244035225713, 0.9940553393860787, 0.9936613948023437, 0.9941476102741953, 0.9942196531791907, 0.9945855376452519, 0.9955895142253696, 0.9958497839400628, 0.9961117332337937, 0.9965629325494954, 0.9970763416947888, 0.9974190739761809, 0

In [None]:
HTML(animation_2.to_html5_video())