In [3]:
import os
os.environ['KAGGLE_CONFIG_DIR'] = '/content'

In [4]:
!chmod 600 /content/kaggle.json

In [5]:
!kaggle kernels output aritra20datta/patch-ggo -p /content

Output file downloaded to /content/_output_.zip
Kernel log downloaded to /content/patch-ggo.log 


In [None]:
!unzip \*.zip && rm *.zip

In [7]:
import pickle

with open('ggo_mask.pkl', 'rb') as fp:
    img_mask = pickle.load(fp)
    print(img_mask)

{'LIDC-IDRI-0386_386_img.png': 0, 'LIDC-IDRI-0770_102_img.png': 1, 'LIDC-IDRI-0153_59_img.png': 2, 'LIDC-IDRI-0039_175_img.png': 3, 'LIDC-IDRI-0508_86_img.png': 4, 'LIDC-IDRI-0939_389_img.png': 5, 'LIDC-IDRI-0884_189_img.png': 6, 'LIDC-IDRI-0496_79_img.png': 7, 'LIDC-IDRI-0510_135_img.png': 8, 'LIDC-IDRI-1003_113_img.png': 9, 'LIDC-IDRI-0019_265_img.png': 10, 'LIDC-IDRI-0192_165_img.png': 11, 'LIDC-IDRI-0908_406_img.png': 12, 'LIDC-IDRI-1003_95_img.png': 13, 'LIDC-IDRI-0138_186_img.png': 14, 'LIDC-IDRI-0641_71_img.png': 15, 'LIDC-IDRI-0663_239_img.png': 16, 'LIDC-IDRI-0491_166_img.png': 17, 'LIDC-IDRI-0340_53_img.png': 18, 'LIDC-IDRI-0138_197_img.png': 19, 'LIDC-IDRI-0842_90_img.png': 20, 'LIDC-IDRI-0215_147_img.png': 21, 'LIDC-IDRI-0377_173_img.png': 22, 'LIDC-IDRI-0415_160_img.png': 23, 'LIDC-IDRI-0078_28_img.png': 24, 'LIDC-IDRI-0969_177_img.png': 25, 'LIDC-IDRI-0466_73_img.png': 26, 'LIDC-IDRI-0972_73_img.png': 27, 'LIDC-IDRI-0500_77_img.png': 28, 'LIDC-IDRI-0433_190_img.png': 29, 

In [None]:
!pip install torchmetrics

In [None]:
!pip install segmentation_models_pytorch

In [10]:
import torch
from torch import nn
from torchvision import transforms
import torchvision.transforms.functional as TF
from torch.utils.data import DataLoader, Dataset
import numpy as np
import matplotlib.pyplot as plt
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torchmetrics.classification import Dice, BinaryJaccardIndex, BinaryAccuracy, BinaryF1Score, BinaryPrecision, BinaryRecall
from google.colab.patches import cv2_imshow
from tqdm import tqdm
import segmentation_models_pytorch as smp
import re
import time

In [11]:
class SegmentationDataset(Dataset):

    def __init__(self, mask_paths, img_mask_map, transform=None) -> None:
        super().__init__()
        self.transform = transform
        self.img_mask_map = img_mask_map
        self.masks = []

        for path in mask_paths:
            img_list = os.listdir(path)
            imgs = [(p,path) for p in img_list]
            self.masks.extend(imgs)

    def __len__(self):
        return len(self.masks)

    def __getitem__(self, index):
        mask_name = self.masks[index][0]
        mask_path = self.masks[index][1]

        base_img, patch = self.extract_base_filename(mask_name)
        ggo_name = f"{self.img_mask_map[base_img]}{patch}.jpg"
        ggo_path = self.extract_base_path(mask_path)

        image = cv2.imread( os.path.join(ggo_path,ggo_name),0).astype(np.float32)
        mask =  cv2.imread(os.path.join(mask_path,mask_name), 0).astype(np.float32)
        image = image/255.0
        mask[mask == 255] = 1.0

        if self.transform:
          augmentations = self.transform(image=image,mask=mask)
          image = augmentations["image"]
          mask = augmentations["mask"]

        return image, mask

    def extract_base_filename(self, full_filename):
      # Define the pattern to match filenames like 'LIDC-IDRI-0001_86_mask_patch_i_j.png'
      pattern = re.compile(r'^(.+)_((?:mask|img)_patch_\d+_\d+)\.(.+)$')

      # Use the pattern to extract the base filename and _patch_i_j
      match = pattern.match(full_filename)
      if match:
          base_filename = f"{match.group(1)}_img.{match.group(3)}"
          patch_info = match.group(2).replace('_img', '').replace('mask', '')
          return base_filename, patch_info
      else:
          return None, None  # No match

    def extract_base_path(self, path):
      directory, filename = os.path.split(path)
      new_filename = filename.replace('mask', 'ggo')
      new_path = os.path.join(directory, new_filename)
      return new_path


In [12]:
train_transform = A.Compose([
    A.Rotate(limit=50,p=0.5,border_mode=cv2.BORDER_CONSTANT),
    A.VerticalFlip(p=0.10),
    ToTensorV2(),
])

test_transform = A.Compose([
    ToTensorV2(),
])

In [22]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
LR = 0.001
EPOCHS = 10
BATCH_SIZE = 16
PIN_MEMORY = True
NUM_WORKERS = 2

In [14]:
class MixedBCELoss(nn.Module):
    def __init__(self, weight=None,size_average=True):
        super(MixedBCELoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):

        #flatten label and prediction tensors
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        loss = nn.BCEWithLogitsLoss()
        BCE = loss(inputs, targets)

        inputs = torch.sigmoid(inputs)
        intersection = (inputs * targets).sum()
        union = inputs.sum() + targets.sum()
        dice_loss = 1 - (2.*intersection + smooth)/(union + smooth)
        jaccard_loss = 1 - ((intersection + smooth)/(union - intersection + smooth))

        Mixed_BCE = BCE + jaccard_loss
        return Mixed_BCE

In [15]:
train_paths = ['/content/mask-patches-1','/content/mask-patches-2','/content/mask-patches-3']
test_paths = ['/content/mask-patches-8','/content/mask-patches-9']

In [16]:
# train and test datasets
train_dataset = SegmentationDataset(mask_paths = train_paths,
                                    img_mask_map = img_mask, transform = train_transform)
test_dataset = SegmentationDataset(mask_paths = test_paths,
                                    img_mask_map = img_mask, transform = test_transform)

# training and test data loaders
trainLoader = DataLoader(train_dataset, shuffle=False, batch_size=BATCH_SIZE,
                         pin_memory=PIN_MEMORY, num_workers=NUM_WORKERS)
testLoader = DataLoader(test_dataset, shuffle=False, batch_size=BATCH_SIZE,
                        pin_memory=PIN_MEMORY, num_workers=NUM_WORKERS)

In [17]:
model = smp.DeepLabV3('resnet101', encoder_weights=None,in_channels=1, classes=1)
model.load_state_dict(torch.load('MODEL_epochs_20__state_dict.pth'))
model = model.to(DEVICE)

In [18]:
loss = MixedBCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
scaler = torch.cuda.amp.GradScaler()

In [19]:
def train_model():
  running_train_loss = 0
  for (images, masks) in trainLoader:

      images = images.to(DEVICE)
      masks = masks.to(DEVICE).unsqueeze(1)
      masks = (masks > 0.5).type('torch.ByteTensor').to(DEVICE)

      with torch.cuda.amp.autocast():
        preds = model(images)
        train_loss = loss(preds, masks.float())

      optimizer.zero_grad()
      scaler.scale(train_loss).backward()
      scaler.step(optimizer)
      scaler.update()

      running_train_loss += train_loss.item()


  return running_train_loss / len(trainLoader)

def eval_model():
  running_test_loss = 0
  intersection_over_unions, dice_scores, accuracy, f1_Score, precision, recall = [], [], [], [], [], []
  iou = BinaryJaccardIndex().to(DEVICE)
  dice = Dice().to(DEVICE)
  for (images, masks) in testLoader:

    images = images.to(DEVICE)
    masks = masks.to(DEVICE).unsqueeze(1)
    masks = (masks > 0.5).type('torch.ByteTensor').to(DEVICE)

    preds = model(images)
    if preds.isnan().any():
      print(f"Nan Batch")
      continue

    test_loss = loss(preds, masks.float())

    preds = torch.sigmoid(preds)
    preds = (preds > 0.5).type('torch.ByteTensor').to(DEVICE)

    dice_scores.append(dice(preds,masks).item())
    if not torch.any(preds) and not torch.any(masks):
      intersection_over_unions.append(1.0)
    else:
      intersection_over_unions.append(iou(preds,masks).item())

    running_test_loss += test_loss.item()

  return running_test_loss / len(testLoader), intersection_over_unions, dice_scores

def run_model():
  train_losses, test_losses = [], []
  for e in tqdm(range(EPOCHS)):
    model.train()
    start_time = time.time()
    training_loss = train_model()
    test_loss = 0
    train_losses.append(training_loss)

    with torch.no_grad():
      model.eval()
      test_loss, iou, dice = eval_model()
      test_losses.append(test_loss)

    end_time = time.time()
    epoch_time = end_time - start_time

    print(f"\nEpoch : [{e}] | Train Loss: {training_loss} | Test Loss: {test_loss}")
    print(f"Mean Dice: {np.mean(dice)} | Mean Iou: {np.mean(iou)}")
    print(f"Epoch finished in {epoch_time:.2f} seconds")

In [20]:
torch.cuda.empty_cache()

In [23]:
run_model()

 10%|█         | 1/10 [09:40<1:27:06, 580.69s/it]


Epoch : [0] | Train Loss: 0.6393178832134904 | Test Loss: 0.6639186459623145
Mean Dice: 0.9995806257106519 | Mean Iou: 0.30816535702084047
Epoch finished in 580.69 seconds


 20%|██        | 2/10 [19:33<1:18:22, 587.75s/it]


Epoch : [1] | Train Loss: 0.6342800960194086 | Test Loss: 0.6609387835361246
Mean Dice: 0.9995806257106519 | Mean Iou: 0.30816535702084047
Epoch finished in 592.69 seconds


 30%|███       | 3/10 [29:16<1:08:21, 585.86s/it]


Epoch : [2] | Train Loss: 0.6348896148180478 | Test Loss: 0.6737388292799866
Mean Dice: 0.9995806257106519 | Mean Iou: 0.30816535702084047
Epoch finished in 583.60 seconds


 40%|████      | 4/10 [39:09<58:51, 588.53s/it]  


Epoch : [3] | Train Loss: 0.6364140277746425 | Test Loss: 0.6942668088303251
Mean Dice: 0.9995806257106519 | Mean Iou: 0.30816535702084047
Epoch finished in 592.61 seconds


 50%|█████     | 5/10 [48:45<48:40, 584.04s/it]


Epoch : [4] | Train Loss: 0.6358907262731442 | Test Loss: 0.660001094374008
Mean Dice: 0.9995806257106519 | Mean Iou: 0.30816535702084047
Epoch finished in 576.06 seconds


 60%|██████    | 6/10 [58:23<38:48, 582.06s/it]


Epoch : [5] | Train Loss: 0.632556507738412 | Test Loss: 0.6602957538342746
Mean Dice: 0.9995806257106519 | Mean Iou: 0.30816535702084047
Epoch finished in 578.20 seconds


 70%|███████   | 7/10 [1:07:57<28:58, 579.38s/it]


Epoch : [6] | Train Loss: 0.6371364741339729 | Test Loss: 0.6597894235545286
Mean Dice: 0.9995806257106519 | Mean Iou: 0.30816535702084047
Epoch finished in 573.87 seconds


 80%|████████  | 8/10 [1:17:53<19:28, 584.44s/it]


Epoch : [7] | Train Loss: 0.6351169007696122 | Test Loss: 0.6863750044022903
Mean Dice: 0.9995806257106519 | Mean Iou: 0.30816535702084047
Epoch finished in 595.26 seconds


 90%|█████████ | 9/10 [1:27:25<09:40, 580.64s/it]


Epoch : [8] | Train Loss: 0.6370537846830334 | Test Loss: 0.6607532270692513
Mean Dice: 0.9995806257106519 | Mean Iou: 0.30816535702084047
Epoch finished in 572.29 seconds


100%|██████████| 10/10 [1:36:59<00:00, 581.98s/it]


Epoch : [9] | Train Loss: 0.6342383473331 | Test Loss: 0.6643908500678466
Mean Dice: 0.9995806257106519 | Mean Iou: 0.30816535702084047
Epoch finished in 574.41 seconds





In [24]:
torch.save(model.state_dict(), f'MODEL_epochs_{EPOCHS}__state_dict.pth')