<a href="https://colab.research.google.com/github/Alisoltan82/Small-Cell-lung-Cancer/blob/main/lung_tumour_unet_(1).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
# !pip install torch==2.1.1

In [None]:
!pip install nibabel
!pip install monai

In [None]:
import os
from glob import glob

In [None]:
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
import monai

from monai.config import print_config
from monai.apps import download_and_extract , DecathlonDataset

from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.metrics import DiceMetric , get_confusion_matrix , LabelQualityScore , label_quality_score , LossMetric
from monai.losses import DiceLoss , DiceCELoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch , ArrayDataset

from monai.config import print_config , KeysCollection
from monai.utils import first , set_determinism
from monai.transforms import (
    Compose,
    LoadImage,
    LoadImaged,
    EnsureChannelFirst,
    EnsureChannelFirstd,
    ToTensor,
    ToTensord,
    ScaleIntensityRange,
    ScaleIntensityRanged,
    ThresholdIntensity,
    ThresholdIntensityd,
    SaveImaged,
    Spacingd,
    CropForegroundd,
    Orientationd,
    AsDiscrete,
    RandCropByPosNegLabeld,
    DivisiblePadd,
    Resized,
    RandFlipd,
    RandRotate90d,
    RandShiftIntensityd




)

print_config()

In [None]:
# Fetch the Task06_Lung archive into the working directory,
# skipping the download when the extracted folder is already present.
HOME = os.getcwd()
data_dir = os.path.join(HOME, "Task06_Lung")
compressed_file = os.path.join(HOME, "Task06_Lung.tar")

resource = "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task06_Lung.tar"
md5 = "8afd997733c7fc0432f71255ba4e52dc"

already_extracted = os.path.exists(data_dir)
if not already_extracted:
    download_and_extract(
        url=resource,
        filepath=compressed_file,
        output_dir=HOME,
        hash_val=md5,
    )

In [None]:
# Collect image/label paths; both lists are sorted so that index k in one list
# is paired with index k in the other.
train_images = sorted(glob(os.path.join(data_dir, "imagesTr", "*.nii.gz")))
train_labels = sorted(glob(os.path.join(data_dir, "labelsTr", "*.nii.gz")))

# zip() silently stops at the shorter list, which would hide a missing file and
# could pair images with the wrong labels, so fail loudly instead.
if len(train_images) != len(train_labels):
    raise ValueError(
        f"found {len(train_images)} images but {len(train_labels)} labels in {data_dir}"
    )

data_dicts = [{"image": image_name, "label": label_name} for image_name, label_name in zip(train_images, train_labels)]

# Hold out the last 10 cases for validation and train on (at most) the first 50 of
# the remaining ones. Slicing the training set from the remainder guarantees the two
# splits never overlap, even when fewer than 60 cases are available.
val_files = data_dicts[-10:]
train_files = data_dicts[:-10][:50]

In [None]:
# Number of cases in the (train, validation) splits.
tuple(len(split) for split in (train_files, val_files))

In [None]:
# from tqdm.auto import tqdm

# for i in tqdm(train_labels):
#     label = nib.load(i).get_fdata()
#     # print(i,len(np.unique(label)))
#     if len(np.unique(label)) > 2 :
#         print(f'default file {i}')
#     elif len(np.unique(label))< 2:
#         print(f'no segment {i}')




In [None]:
# for i in range(len(train_images)):

#     image  = train_images[i]
#     imag  = nib.load(image).get_fdata()
#     print(imag.shape)



In [None]:
# for i in range(len(train_labels)):

#     image  = train_labels[i]
#     imag  = nib.load(image).get_fdata()
#     print(imag.shape)

In [None]:
# Load one random raw case and show a single slice of the image next to its label.
import random

# randrange() excludes the upper bound; randint(0, len(...)) includes it and could
# produce an index one past the end of the list.
r = random.randrange(len(train_images))

r_img = nib.load(train_images[r]).get_fdata()
print(r_img.shape , np.max(r_img) , np.min(r_img))
r_label = nib.load(train_labels[r]).get_fdata()
print(r_label.shape , r_label.min() , r_label.max())

# Show slice 20 along the last axis, clamped so that short volumes do not raise.
slice_idx = min(20, r_img.shape[2] - 1)

plt.figure(figsize = (8,5))
plt.subplot(121)
plt.imshow(r_img[:, :, slice_idx], cmap = 'gray')
plt.colorbar()
plt.subplot(122)
plt.imshow(r_label[:, :, slice_idx])
plt.show()

In [None]:
# Pre-processing / augmentation pipelines for training and validation.


def _preprocessing_transforms():
    """Return fresh instances of the deterministic pre-processing steps.

    Both pipelines are built from this single list so that training and validation
    data go through the same steps in the same order (in particular: re-orient to
    RAS first, then resample to the target voxel spacing).
    """
    return [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        # Clip intensities to [-1024, 3071] and rescale them linearly to [0, 1].
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-1024.0,
            a_max=3071.0,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        # Bilinear interpolation for the image, nearest-neighbour for the label map.
        Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
    ]


train_transforms = Compose(
    _preprocessing_transforms()
    + [
        # Draw 6 patches of 64^3 voxels per volume, centred on label / non-label
        # voxels with a 1:1 ratio.
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=(64, 64, 64),
            pos=1,
            neg=1,
            num_samples=6,
            image_key="image",
            image_threshold=0,
        ),
        # Random flips along each spatial axis, applied to image and label together.
        RandFlipd(keys=["image", "label"], spatial_axis=[0], prob=0.10),
        RandFlipd(keys=["image", "label"], spatial_axis=[1], prob=0.10),
        RandFlipd(keys=["image", "label"], spatial_axis=[2], prob=0.10),
        RandRotate90d(keys=["image", "label"], prob=0.10, max_k=3),
        # Intensity jitter is applied to the image only.
        RandShiftIntensityd(keys=["image"], offsets=0.10, prob=0.50),
    ]
)

# Validation uses the deterministic pre-processing only (no cropping, no augmentation).
val_transforms = Compose(_preprocessing_transforms())
# Resized(keys=["image", "label"], spatial_size = (128,128,128))
#  DivisiblePadd(keys = ['image' , 'label'] , k = 32)

In [None]:
# Sanity check: run the first training case through the validation pipeline and
# plot one slice of the pre-processed image next to its label.
check_ds = Dataset(data=train_files, transform=val_transforms)
check_loader = DataLoader(check_ds, batch_size=1)
check_data = first(check_loader)
# Drop the batch and channel dimensions.
image, label = (check_data["image"][0][0], check_data["label"][0][0])
print(f"image shape: {image.shape}, label shape: {label.shape}")

# Show slice 60 along the last axis, clamped so that a shorter volume does not raise.
slice_idx = min(60, image.shape[-1] - 1)

plt.figure("check", (12, 6))
plt.subplot(1, 2, 1)
plt.title("image")
plt.imshow(image[:, :, slice_idx], cmap="gray")
plt.subplot(1, 2, 2)
plt.title("label")
plt.imshow(label[:, :, slice_idx])
plt.show()

In [None]:
# Sanity check: run the first validation case through the validation pipeline and
# plot one slice of the pre-processed image next to its label.
check_ds = Dataset(data=val_files, transform=val_transforms)
check_loader = DataLoader(check_ds, batch_size=1)
check_data = first(check_loader)
# Drop the batch and channel dimensions.
image, label = (check_data["image"][0][0], check_data["label"][0][0])
print(f"image shape: {image.shape}, label shape: {label.shape}")

# Show slice 70 along the last axis, clamped so that a shorter volume does not raise.
slice_idx = min(70, image.shape[-1] - 1)

plt.figure("check", (12, 6))
plt.subplot(1, 2, 1)
plt.title("image")
plt.imshow(image[:, :, slice_idx], cmap="gray")
plt.subplot(1, 2, 2)
plt.title("label")
plt.imshow(label[:, :, slice_idx])
plt.show()

In [None]:
# Cached datasets and data loaders for the training and validation splits.
cache_options = {"cache_rate": 1.0, "num_workers": 1}

# Training split: shuffled, two volumes per batch.
train_ds = CacheDataset(data=train_files, transform=train_transforms, **cache_options)
train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=1)

# Validation split: one volume per batch, original order.
val_ds = CacheDataset(data=val_files, transform=val_transforms, **cache_options)
val_loader = DataLoader(val_ds, batch_size=1, num_workers=2)

In [None]:
# Pull one training batch and display the shapes of its image and label tensors.
data = first(train_loader)
tuple(data[key].shape for key in ("image", "label"))

In [None]:

# Model, loss, optimizer and validation metric.
import torch

# Fall back to the CPU when no GPU is visible instead of failing on "cuda:0".
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 3D residual UNet: 1 input channel, 2 output channels.
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)

# The two output channels are mutually exclusive classes and predictions are decoded
# with argmax (see the AsDiscrete(argmax=True) post-processing and the softmax at
# inference time), so the Dice term must use softmax rather than an independent
# per-channel sigmoid.
loss_function = DiceCELoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-4)
# NOTE(review): include_background=True averages the background class into the
# reported Dice; presumably foreground-only Dice is the number of interest — confirm.
dice_metric = DiceMetric(include_background=True, reduction="mean")


In [None]:
# def dice_metric(y_pred, y):
#     loss = DiceLoss(to_onehot_y=True, sigmoid=True , squared_pred = True)
#     value = 1-loss
#     return value


In [None]:
# #loss / opt

# loss_function = DiceLoss(to_onehot_y=True, sigmoid=True)
# optimizer = torch.optim.Adam(model.parameters(), 1e-4)
# dice_metric = DiceMetric(include_background=True, reduction="mean")

In [None]:
# Train the UNet; every `val_interval` epochs run sliding-window validation and
# checkpoint the weights whenever the mean Dice improves.
torch.manual_seed(42)

max_epochs = 500
val_interval = 5
best_metric, best_metric_epoch = -1, -1
epoch_loss_values, metric_values = [], []

# Post-processing applied per sample before the Dice computation:
# predictions -> argmax -> one-hot, labels -> one-hot.
post_pred = Compose([AsDiscrete(argmax=True , to_onehot=2)])
post_label = Compose([AsDiscrete(to_onehot=2)])

# Sliding-window settings used for validation.
roi_size = (64, 64, 64)
sw_batch_size = 4

for epoch in range(max_epochs):
    epoch_number = epoch + 1
    print("-" * 10)
    print(f"epoch {epoch_number}/{max_epochs}")

    # ---- training pass -------------------------------------------------
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        inputs = batch_data["image"].to(device)
        labels = batch_data["label"].to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    # Average the per-batch losses over the epoch.
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch_number} average loss: {epoch_loss:.4f}")

    # Only validate on every `val_interval`-th epoch.
    if epoch_number % val_interval != 0:
        continue

    # ---- validation pass -----------------------------------------------
    model.eval()
    with torch.inference_mode():
        for val_data in val_loader:
            val_inputs = val_data["image"].to(device)
            val_labels = val_data["label"].to(device)

            val_outputs = sliding_window_inference(val_inputs, roi_size, sw_batch_size, model)
            val_outputs = [post_pred(sample) for sample in decollate_batch(val_outputs)]
            val_labels = [post_label(sample) for sample in decollate_batch(val_labels)]

            # Accumulate the metric for this volume.
            dice_metric(y_pred=val_outputs, y=val_labels)

        # Reduce over all validation volumes, then clear the buffer for the next round.
        metric = dice_metric.aggregate().item()
        dice_metric.reset()
        metric_values.append(metric)

        if metric > best_metric:
            best_metric = metric
            best_metric_epoch = epoch_number
            torch.save(model.state_dict(), os.path.join(HOME, "best_metric_model.pth"))
            print("saved new best metric model")

        print(
            f"current epoch: {epoch_number} current mean dice: {metric:.4f}"
            f"\nbest mean dice: {best_metric:.4f} "
            f"at epoch: {best_metric_epoch}"
        )

In [None]:
# Training curves: average loss per epoch (left) and validation mean Dice (right).
plt.figure("train", (12, 6))

# Left panel: one point per training epoch, epochs numbered from 1.
plt.subplot(1, 2, 1)
plt.title("Epoch Average Loss")
x = list(range(1, len(epoch_loss_values) + 1))
y = epoch_loss_values
plt.xlabel("epoch")
plt.plot(x, y)

# Right panel: one point per validation round, placed at the epoch it ran in.
plt.subplot(1, 2, 2)
plt.title("Val Mean Dice")
x = [val_interval * round_number for round_number in range(1, len(metric_values) + 1)]
y = metric_values
plt.xlabel("epoch")
plt.plot(x, y)

plt.show()

In [None]:
# train_images = sorted(glob(os.path.join(data_dir, "imagesTr", "*.nii.gz")))
# train_labels = sorted(glob(os.path.join(data_dir, "labelsTr", "*.nii.gz")))
# data_dicts = [{"image": image_name, "label": label_name} for image_name, label_name in zip(train_images, train_labels)]
# train_files, val_files = data_dicts[20:40], data_dicts[41:46]

In [None]:
# Number of cases in the (train, validation) splits.
tuple(len(split) for split in (train_files, val_files))

In [None]:
# # Dataloaders - Train , val

# #train
# train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0, num_workers=4 )
# train_loader = DataLoader(train_ds, batch_size=2, shuffle=False, num_workers=4 )

# #val
# val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=4)
# val_loader = DataLoader(val_ds, batch_size=1, num_workers=4 )

In [None]:
# import torch

# device = torch.device("cuda")
# model_2 = UNet(
#     spatial_dims=3,
#     in_channels=1,
#     out_channels=2,
#     channels=(16, 32, 64, 128, 256),
#     strides=(2, 2, 2, 2),
#     num_res_units=2,
#     norm=Norm.BATCH,
# ).to(device)

# #loss / opt

# loss_function = DiceLoss(to_onehot_y=True, sigmoid=True)
# optimizer = torch.optim.Adam(model.parameters(), 1e-4)
# dice_metric = DiceMetric(include_background=True, reduction="mean")

# model.load_state_dict(torch.load('/content/best_metric_model.pth'))

In [None]:
# max_epochs = 120
# val_interval = 2
# best_metric = -1
# best_metric_epoch = -1
# epoch_loss_values = []
# metric_values = []
# post_pred = Compose([AsDiscrete(argmax=True, to_onehot=2)])
# post_label = Compose([AsDiscrete(to_onehot=2)])

# for epoch in range(max_epochs):
#     print("-" * 10)
#     print(f"epoch {epoch + 1}/{max_epochs}")
#     model_2.train()
#     epoch_loss = 0
#     step = 0
#     for batch_data in train_loader:
#         step += 1
#         inputs, labels = (
#             batch_data["image"].to(device),
#             batch_data["label"].to(device),
#         )
#         optimizer.zero_grad()
#         outputs = model(inputs)
#         loss = loss_function(outputs, labels)
#         loss.backward()
#         optimizer.step()
#         epoch_loss += loss.item()
# #         print(f"{step}/{len(train_ds) // train_loader.batch_size}, " f"train_loss: {loss.item():.4f}")
#     epoch_loss /= step
#     epoch_loss_values.append(epoch_loss)
#     print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

#     if (epoch + 1) % val_interval == 0:
#         model_2.eval()
#         with torch.inference_mode():
#             for val_data in val_loader:
#                 val_inputs, val_labels = (
#                     val_data["image"].to(device),
#                     val_data["label"].to(device),
#                 )
#                 roi_size = (64, 64, 64)
#                 sw_batch_size = 1
#                 val_outputs = sliding_window_inference(val_inputs, roi_size, sw_batch_size, model)
#                 val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
#                 val_labels = [post_label(i) for i in decollate_batch(val_labels)]
# #                 print(val_labels.shape , val_outputs.shape)

#                 # compute metric for current iteration
#                 dice_metric(y_pred=val_outputs, y=val_labels)

#             # aggregate the final mean dice result
#             metric = dice_metric.aggregate().item()
#             # reset the status for next validation round
#             dice_metric.reset()

#             metric_values.append(metric)
#             if metric > best_metric:
#                 best_metric = metric
#                 best_metric_epoch = epoch + 1
#                 torch.save(model.state_dict(), os.path.join(HOME, "best_metric_model.pth"))
#                 print("saved new best metric model")
#             print(
#                 f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
#                 f"\nbest mean dice: {best_metric:.4f} "
#                 f"at epoch: {best_metric_epoch}"
#             )

In [None]:
import torchvision
from torchvision.models.segmentation import fcn_resnet50 , FCN_ResNet50_Weights

In [None]:
# Instantiate torchvision's FCN-ResNet50 with its default pretrained weights.
# NOTE(review): the model is not moved to `device` here — confirm that is intended.
weights = FCN_ResNet50_Weights.DEFAULT
model2 = fcn_resnet50(weights=weights)

In [None]:
# Second training stage: keep optimising the 3D UNet (`model`).
#
# The forward pass, the loss and the optimizer in this loop all operate on `model`,
# so `model` is also the network that has to be switched between train/eval mode and
# checkpointed. Toggling or saving `model2` instead would leave the UNet in whatever
# mode the previous cell left it in and would overwrite "best_metric_model.pth" with
# weights that `model.load_state_dict(...)` cannot load afterwards.
# NOTE(review): `model2` (torchvision FCN-ResNet50) looks like a 2D network and is not
# wired into this 3D patch pipeline — confirm before trying to train it here.
max_epochs = 200
val_interval = 2
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = []
metric_values = []

# The network has 2 output channels, so one-hot encode to 2 classes.
post_pred = Compose([AsDiscrete(argmax=True, to_onehot=2)])
post_label = Compose([AsDiscrete(to_onehot=2)])

# Sliding-window settings used for validation.
roi_size = (64, 64, 64)
sw_batch_size = 1

for epoch in range(max_epochs):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")

    # ---- training pass -------------------------------------------------
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        inputs, labels = (
            batch_data["image"].to(device),
            batch_data["label"].to(device),
        )
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    # Average the per-batch losses over the epoch.
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    # ---- validation pass (every `val_interval` epochs) -------------------
    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.inference_mode():
            for val_data in val_loader:
                val_inputs, val_labels = (
                    val_data["image"].to(device),
                    val_data["label"].to(device),
                )
                val_outputs = sliding_window_inference(val_inputs, roi_size, sw_batch_size, model)
                val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
                val_labels = [post_label(i) for i in decollate_batch(val_labels)]

                # compute metric for current iteration
                dice_metric(y_pred=val_outputs, y=val_labels)

            # aggregate the final mean dice result
            metric = dice_metric.aggregate().item()
            # reset the status for next validation round
            dice_metric.reset()

            metric_values.append(metric)
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                # Save the weights of the network that was actually trained.
                torch.save(model.state_dict(), os.path.join(HOME, "best_metric_model.pth"))
                print("saved new best metric model")
            print(
                f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                f"\nbest mean dice: {best_metric:.4f} "
                f"at epoch: {best_metric_epoch}"
            )

In [None]:
# Report the best validation score of the last training run and the epoch it occurred in.
summary = f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}"
print(summary)

In [None]:
# Training curves: average loss per epoch (left) and validation mean Dice (right).
plt.figure("train", (12, 6))

# Left panel: one point per training epoch, epochs numbered from 1.
plt.subplot(1, 2, 1)
plt.title("Epoch Average Loss")
x = list(range(1, len(epoch_loss_values) + 1))
y = epoch_loss_values
plt.xlabel("epoch")
plt.plot(x, y)

# Right panel: one point per validation round, placed at the epoch it ran in.
plt.subplot(1, 2, 2)
plt.title("Val Mean Dice")
x = [val_interval * round_number for round_number in range(1, len(metric_values) + 1)]
y = metric_values
plt.xlabel("epoch")
plt.plot(x, y)

plt.show()

In [None]:
# Qualitative check: reload the best checkpoint, predict every validation volume,
# save each prediction as a NIfTI file and plot one slice of image / label / output.

# map_location lets a checkpoint written on the GPU be loaded on whatever `device` is.
model.load_state_dict(torch.load(os.path.join(HOME, "best_metric_model.pth"), map_location=device))
model.eval()

roi_size = (64, 64, 64)
sw_batch_size = 4

with torch.no_grad():
    for i, val_data in enumerate(val_loader):
        val_logits = sliding_window_inference(val_data["image"].to(device), roi_size, sw_batch_size, model)

        # Class index per voxel. Softmax is monotonic, so the argmax over the raw
        # outputs equals the argmax over the softmax probabilities.
        # [0] drops the batch dimension (the validation loader uses batch_size=1).
        val_outputs = torch.argmax(val_logits, dim=1)[0].cpu().numpy().astype(np.uint8)

        # Voxel-to-world affine of the pre-processed image.
        # NOTE(review): assumes MONAI exposes it as `.meta["affine"]` on the batched
        # image (MetaTensor) or, in older releases, under "image_meta_dict" — confirm
        # against the installed MONAI version.
        image_meta = getattr(val_data["image"], "meta", None) or val_data["image_meta_dict"]
        affine = torch.as_tensor(image_meta["affine"][0]).cpu().numpy()

        # Saved as pred_* so the file is not mistaken for a ground-truth label.
        nib.save(nib.Nifti1Image(val_outputs, affine), os.path.join(HOME, f"pred_{i}_.nii.gz"))

        # Plot slice 70 along the last axis, clamped so that short volumes do not raise.
        slice_idx = min(70, val_outputs.shape[-1] - 1)
        plt.figure("check", (18, 6))
        plt.subplot(1, 3, 1)
        plt.title(f"image {i}")
        plt.imshow(val_data["image"][0, 0, :, :, slice_idx], cmap="gray")
        plt.subplot(1, 3, 2)
        plt.title(f"label {i}")
        plt.imshow(val_data["label"][0, 0, :, :, slice_idx])
        plt.subplot(1, 3, 3)
        plt.title(f"output {i}")
        plt.imshow(val_outputs[:, :, slice_idx])
        plt.show()

        # Stop after the 11th volume (index 10) at the latest.
        if i == 10:
            break

In [None]:
# Shape of the last value assigned to `val_outputs`.
val_outputs.shape