<a href="https://colab.research.google.com/github/arimax32/Lung-CT-GGO-Semantic-segmentation/blob/main/Segment_GGO_Patches.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os

# Set KAGGLE_CONFIG_DIR so Kaggle tooling looks in /content for its
# configuration (presumably where kaggle.json is uploaded -- verify).
os.environ.update({'KAGGLE_CONFIG_DIR': '/content'})

In [None]:
# Restrict /content/kaggle.json (presumably the Kaggle API token) to owner read/write only.
!chmod 600 /content/kaggle.json

In [None]:
# Download the output files of the Kaggle kernel aritra20datta/patch-ggo into /content.
!kaggle kernels output aritra20datta/patch-ggo -p /content

Output file downloaded to /content/_output_.zip
Kernel log downloaded to /content/patch-ggo.log 


In [None]:
# Extract every zip archive in the working directory (the escaped wildcard is expanded
# by unzip itself), then delete the archives only if extraction succeeded.
!unzip \*.zip && rm *.zip

In [None]:
import pickle

# NOTE(security): pickle.load can execute arbitrary code embedded in the file.
# 'ggo_mask.pkl' is presumably part of the downloaded kernel output -- only load it
# when that source is trusted.
with open('ggo_mask.pkl', 'rb') as fp:
    img_mask = pickle.load(fp)

# Print a short summary instead of dumping the whole object into the cell output.
# assumes img_mask is a sized container (it is indexed by filename later) -- TODO confirm
print(f"Loaded {type(img_mask).__name__} with {len(img_mask)} entries")

In [None]:
# Install torchmetrics (used for the segmentation metrics); no version is pinned here.
!pip install torchmetrics

In [None]:
# Install segmentation_models_pytorch (provides the DeepLabV3 model); no version is pinned here.
!pip install segmentation_models_pytorch

In [None]:
import torch
from torch import nn
from torchvision import transforms
import torchvision.transforms.functional as TF
from torch.utils.data import DataLoader, Dataset
import numpy as np
import matplotlib.pyplot as plt
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torchmetrics.classification import Dice, BinaryJaccardIndex, BinaryAccuracy, BinaryF1Score, BinaryPrecision, BinaryRecall
from google.colab.patches import cv2_imshow
from tqdm import tqdm
import segmentation_models_pytorch as smp
import re
import time

In [None]:
class SegmentationDataset(Dataset):
    """Pairs grayscale GGO image patches with their mask patches.

    Every file found in the given mask directories is treated as one sample.
    The matching image patch is located by:

    * turning the mask filename into its base image name
      (see ``extract_base_filename``),
    * looking that base name up in ``img_mask_map`` to get the image name prefix,
    * appending the ``_patch_i_j`` suffix and ``.jpg``,
    * reading it from the sibling directory whose name has ``mask`` replaced
      by ``ggo`` (see ``extract_base_path``).

    Args:
        mask_paths: Iterable of directories containing mask patch files.
        img_mask_map: Mapping from base image filename to the image name
            prefix used for the GGO patch files.
        transform: Optional callable invoked as ``transform(image=..., mask=...)``
            and returning a mapping with ``"image"`` and ``"mask"`` entries.
    """

    # Matches names such as 'LIDC-IDRI-0001_86_mask_patch_3_4.png':
    # group 1 = base name, group 2 = '_patch_i_j', group 3 = file extension.
    _PATCH_NAME_RE = re.compile(r'^(.+)_(?:mask|img)(_patch_\d+_\d+)\.(.+)$')

    def __init__(self, mask_paths, img_mask_map, transform=None) -> None:
        super().__init__()
        self.transform = transform
        self.img_mask_map = img_mask_map
        self.masks = []

        for path in mask_paths:
            # os.listdir returns entries in arbitrary order; sorting keeps the
            # index -> sample mapping stable across runs and machines.
            self.masks.extend((name, path) for name in sorted(os.listdir(path)))

    def __len__(self):
        """Return the number of mask patch files collected at construction."""
        return len(self.masks)

    def __getitem__(self, index):
        """Load the image/mask pair at ``index``.

        Returns:
            ``(image, mask)``. The image is scaled to ``[0, 1]``; mask pixels
            equal to 255 are mapped to 1.0 (other values are left unchanged).
            If a transform was supplied, its ``"image"`` and ``"mask"`` outputs
            are returned instead.

        Raises:
            ValueError: If the mask filename does not follow the patch naming scheme.
            KeyError: If the base image name is missing from ``img_mask_map``.
            FileNotFoundError: If either file cannot be read by OpenCV.
        """
        mask_name, mask_dir = self.masks[index]

        base_img, patch = self.extract_base_filename(mask_name)
        if base_img is None:
            # Fail with a clear message instead of an opaque lookup of ``None``.
            raise ValueError(f"Unrecognised mask patch filename: {mask_name!r}")

        ggo_name = f"{self.img_mask_map[base_img]}{patch}.jpg"
        ggo_dir = self.extract_base_path(mask_dir)

        image = self._read_grayscale(os.path.join(ggo_dir, ggo_name))
        mask = self._read_grayscale(os.path.join(mask_dir, mask_name))
        image = image / 255.0
        mask[mask == 255] = 1.0

        if self.transform:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations["image"]
            mask = augmentations["mask"]

        return image, mask

    @staticmethod
    def _read_grayscale(path):
        """Read ``path`` as a single-channel float32 array, failing loudly if unreadable."""
        pixels = cv2.imread(path, 0)
        if pixels is None:
            # cv2.imread reports a missing or undecodable file by returning None.
            raise FileNotFoundError(f"Could not read image: {path}")
        return pixels.astype(np.float32)

    def extract_base_filename(self, full_filename):
        """Split a patch filename into its base image filename and patch suffix.

        For ``'LIDC-IDRI-0001_86_mask_patch_3_4.png'`` this returns
        ``('LIDC-IDRI-0001_86_img.png', '_patch_3_4')``. Names using ``img``
        instead of ``mask`` yield the same ``_patch_i_j`` suffix.

        Returns:
            ``(base_filename, patch_info)``, or ``(None, None)`` when the name
            does not match the expected pattern.
        """
        match = self._PATCH_NAME_RE.match(full_filename)
        if match is None:
            return None, None  # No match

        base_filename = f"{match.group(1)}_img.{match.group(3)}"
        return base_filename, match.group(2)

    def extract_base_path(self, path):
        """Return ``path`` with ``mask`` replaced by ``ggo`` in its final component."""
        directory, filename = os.path.split(path)
        return os.path.join(directory, filename.replace('mask', 'ggo'))


In [None]:
# Training pipeline: random rotation and an occasional vertical flip, then
# conversion to a torch tensor.
train_transform = A.Compose(
    transforms=[
        A.Rotate(
            limit=50,
            p=0.5,
            border_mode=cv2.BORDER_CONSTANT,
        ),
        A.VerticalFlip(p=0.10),
        ToTensorV2(),
    ]
)

# Evaluation pipeline: no augmentation, only the tensor conversion.
test_transform = A.Compose(transforms=[ToTensorV2()])

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Optimisation hyper-parameters.
LR = 1e-3
EPOCHS = 20
BATCH_SIZE = 16

# DataLoader settings.
NUM_WORKERS = 2
PIN_MEMORY = True

In [None]:
class MixedBCELoss(nn.Module):
    """Binary cross-entropy on logits plus a smoothed soft-Jaccard (IoU) loss.

    Args:
        weight: Accepted for signature compatibility; not used.
        size_average: Accepted for signature compatibility; not used.
    """

    def __init__(self, weight=None, size_average=True):
        super().__init__()

    def forward(self, inputs, targets, smooth=1):
        """Compute ``BCEWithLogits(inputs, targets) + (1 - soft_jaccard)``.

        Args:
            inputs: Raw logits of any shape.
            targets: Float tensor with the same number of elements as ``inputs``.
            smooth: Additive smoothing applied to the numerator and denominator
                of the Jaccard term to avoid division by zero.

        Returns:
            Scalar tensor holding the combined loss.
        """
        # reshape (rather than view) also accepts non-contiguous tensors.
        logits = inputs.reshape(-1)
        targets = targets.reshape(-1)

        # Mean-reduced BCE on logits, same as nn.BCEWithLogitsLoss() with defaults.
        bce = nn.functional.binary_cross_entropy_with_logits(logits, targets)

        probs = torch.sigmoid(logits)
        intersection = (probs * targets).sum()
        # Sum of both "areas"; the true union subtracts the intersection once.
        total = probs.sum() + targets.sum()
        jaccard_loss = 1 - (intersection + smooth) / (total - intersection + smooth)

        return bce + jaccard_loss

In [None]:
# Directories of mask patches used for training and for evaluation. The dataset
# class derives the matching image directory by replacing 'mask' with 'ggo'.
train_paths = ['/content/mask-patches-7','/content/mask-patches-6']
test_paths = ['/content/mask-patches-5','/content/mask-patches-4']

In [None]:
# train and test datasets
train_dataset = SegmentationDataset(mask_paths = train_paths,
                                    img_mask_map = img_mask, transform = train_transform)
test_dataset = SegmentationDataset(mask_paths = test_paths,
                                    img_mask_map = img_mask, transform = test_transform)

# training and test data loaders
# Training batches are reshuffled every epoch so that consecutive batches are not
# always drawn from the same directory in the same order; evaluation keeps a
# fixed order so its metrics are comparable between epochs.
trainLoader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE,
                         pin_memory=PIN_MEMORY, num_workers=NUM_WORKERS)
testLoader = DataLoader(test_dataset, shuffle=False, batch_size=BATCH_SIZE,
                        pin_memory=PIN_MEMORY, num_workers=NUM_WORKERS)

In [None]:
# DeepLabV3 with an ImageNet-initialised ResNet-101 encoder, taking one input
# channel and producing one output channel of logits.
model = smp.DeepLabV3(
    encoder_name='resnet101',
    encoder_weights='imagenet',
    in_channels=1,
    classes=1,
)
model = model.to(DEVICE)

Downloading: "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth" to /root/.cache/torch/hub/checkpoints/resnet101-5d3b4d8f.pth
100%|██████████| 170M/170M [00:01<00:00, 147MB/s]


In [None]:
# Training objective: BCE-with-logits plus soft-Jaccard.
loss = MixedBCELoss()

# Adam over all model parameters at the configured learning rate.
optimizer = torch.optim.Adam(params=model.parameters(), lr=LR)

# Gradient scaler used together with autocast in the training loop.
scaler = torch.cuda.amp.GradScaler()

In [None]:
def train_model():
  """Run one optimisation pass over ``trainLoader``.

  Uses the notebook-level ``model``, ``loss``, ``optimizer`` and ``scaler``.
  The forward pass and loss run under CUDA autocast, and the backward pass and
  optimizer step go through the gradient scaler.

  Returns:
      Mean training loss per batch (0 when the loader yields no batches).
  """
  running_train_loss = 0
  for (images, masks) in trainLoader:

    images = images.to(DEVICE)
    # Add the channel dimension and binarise directly on the target device.
    # Casting through 'torch.ByteTensor' would copy the masks back to the CPU
    # first, only to move them to the device again.
    masks = (masks.to(DEVICE).unsqueeze(1) > 0.5).float()

    with torch.cuda.amp.autocast():
      preds = model(images)
      train_loss = loss(preds, masks)

    optimizer.zero_grad()
    scaler.scale(train_loss).backward()
    scaler.step(optimizer)
    scaler.update()

    running_train_loss += train_loss.item()

  # max(..., 1) avoids a ZeroDivisionError when the loader is empty.
  return running_train_loss / max(len(trainLoader), 1)

@torch.no_grad()
def eval_model():
  """Evaluate ``model`` on ``testLoader``.

  Uses the notebook-level ``model`` and ``loss``. The caller is expected to have
  put the model into eval mode. Batches whose predictions contain NaN are
  reported and skipped.

  Returns:
      ``(mean_loss, ious, dice_scores)``: the mean loss over the batches that
      were actually evaluated, and per-batch foreground IoU and Dice values.
      A batch where both prediction and mask are empty scores 1.0 for both.
  """
  running_test_loss = 0
  evaluated_batches = 0
  intersection_over_unions, dice_scores = [], []
  iou = BinaryJaccardIndex().to(DEVICE)
  # For binary masks Dice equals F1 of the foreground class. torchmetrics'
  # Dice() with default arguments micro-averages over background and foreground,
  # which behaves like pixel accuracy and stays near 1.0 on sparse masks.
  dice = BinaryF1Score().to(DEVICE)
  for (images, masks) in testLoader:

    images = images.to(DEVICE)
    # Binarise on the device; avoids a CPU round trip through 'torch.ByteTensor'.
    masks = (masks.to(DEVICE).unsqueeze(1) > 0.5).to(torch.uint8)

    preds = model(images)
    if preds.isnan().any():
      print(f"Nan Batch")
      continue

    test_loss = loss(preds, masks.float())

    preds = (torch.sigmoid(preds) > 0.5).to(torch.uint8)

    if not torch.any(preds) and not torch.any(masks):
      # Both empty: the metrics are 0/0, so count the batch as a perfect match.
      intersection_over_unions.append(1.0)
      dice_scores.append(1.0)
    else:
      intersection_over_unions.append(iou(preds,masks).item())
      dice_scores.append(dice(preds,masks).item())

    running_test_loss += test_loss.item()
    evaluated_batches += 1

  # Average over evaluated batches only, so skipped NaN batches do not deflate the loss.
  return running_test_loss / max(evaluated_batches, 1), intersection_over_unions, dice_scores

def run_model():
  """Train for ``EPOCHS`` epochs, evaluating after each one.

  Each epoch runs ``train_model`` in training mode and ``eval_model`` in eval
  mode with gradients disabled, then prints the losses, the mean per-batch
  Dice and IoU, and the elapsed time.

  Returns:
      ``(train_losses, test_losses)``: per-epoch mean losses, so the curves can
      be inspected or plotted after training.
  """
  train_losses, test_losses = [], []
  for e in tqdm(range(EPOCHS)):
    model.train()
    start_time = time.time()
    training_loss = train_model()
    train_losses.append(training_loss)

    model.eval()
    with torch.no_grad():
      test_loss, iou, dice = eval_model()
    test_losses.append(test_loss)

    epoch_time = time.time() - start_time

    # Every batch may have been skipped; avoid np.mean's empty-input warning.
    mean_dice = np.mean(dice) if len(dice) else float('nan')
    mean_iou = np.mean(iou) if len(iou) else float('nan')

    print(f"\nEpoch : [{e}] | Train Loss: {training_loss} | Test Loss: {test_loss}")
    print(f"Mean Dice: {mean_dice} | Mean Iou: {mean_iou}")
    print(f"Epoch finished in {epoch_time:.2f} seconds")

  return train_losses, test_losses

In [None]:
# Release cached, currently unused GPU memory held by PyTorch before training starts.
torch.cuda.empty_cache()

In [None]:
# Start the full training/evaluation loop; per-epoch results are printed below.
run_model()

  5%|▌         | 1/20 [07:43<2:26:49, 463.68s/it]


Epoch : [0] | Train Loss: 0.682338795549571 | Test Loss: 0.6618081510749879
Mean Dice: 0.9995633465856457 | Mean Iou: 0.30808773132282385
Epoch finished in 463.67 seconds


 10%|█         | 2/20 [14:58<2:14:04, 446.93s/it]


Epoch : [1] | Train Loss: 0.6338176253258825 | Test Loss: 0.6616888050197755
Mean Dice: 0.9995633465856457 | Mean Iou: 0.30808773132282385
Epoch finished in 435.20 seconds


 15%|█▌        | 3/20 [22:15<2:05:20, 442.38s/it]


Epoch : [2] | Train Loss: 0.6313850963726773 | Test Loss: 0.661690427932165
Mean Dice: 0.9995633465856457 | Mean Iou: 0.30808773132282385
Epoch finished in 436.97 seconds


 20%|██        | 4/20 [29:47<1:58:53, 445.86s/it]


Epoch : [3] | Train Loss: 0.632326867267834 | Test Loss: 0.661725770640645
Mean Dice: 0.9995633465856457 | Mean Iou: 0.30808773132282385
Epoch finished in 451.18 seconds


 25%|██▌       | 5/20 [37:06<1:50:54, 443.62s/it]


Epoch : [4] | Train Loss: 0.6301262211336947 | Test Loss: 0.661681103542796
Mean Dice: 0.9995633465856457 | Mean Iou: 0.30808773132282385
Epoch finished in 439.63 seconds


 30%|███       | 6/20 [44:27<1:43:16, 442.62s/it]


Epoch : [5] | Train Loss: 0.6325392208879673 | Test Loss: 0.6616713943883408
Mean Dice: 0.9995633465856457 | Mean Iou: 0.30808773132282385
Epoch finished in 440.66 seconds


 35%|███▌      | 7/20 [51:57<1:36:25, 445.05s/it]


Epoch : [6] | Train Loss: 0.6318374290468316 | Test Loss: 0.6616054849342553
Mean Dice: 0.9995633465856457 | Mean Iou: 0.30808773132282385
Epoch finished in 450.07 seconds


 40%|████      | 8/20 [59:16<1:28:38, 443.24s/it]


Epoch : [7] | Train Loss: 0.63526669047669 | Test Loss: 0.6617007396862372
Mean Dice: 0.9995633465856457 | Mean Iou: 0.30808773132282385
Epoch finished in 439.34 seconds


 45%|████▌     | 9/20 [1:06:37<1:21:06, 442.44s/it]


Epoch : [8] | Train Loss: 0.6314339806171058 | Test Loss: 0.6616142120633135
Mean Dice: 0.9995633465856457 | Mean Iou: 0.30808773132282385
Epoch finished in 440.69 seconds


 50%|█████     | 10/20 [1:14:08<1:14:12, 445.21s/it]


Epoch : [9] | Train Loss: 0.6302730545322086 | Test Loss: 0.6615946936787577
Mean Dice: 0.9995633465856457 | Mean Iou: 0.30808773132282385
Epoch finished in 451.39 seconds


In [None]:
# Save the model weights (state_dict only). The filename embeds the configured EPOCHS
# value, not the number of epochs that actually completed.
torch.save(model.state_dict(), f'MODEL_epochs_{EPOCHS}__state_dict.pth')