In [None]:
!pip install datasets

In [None]:
!pip install monai

In [5]:
import random
import os
import glob
import time
import warnings
import io

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import seaborn as sns
# import opendatasets as od
import datasets as dts

from PIL import Image
from IPython.display import clear_output
from tqdm.notebook import tqdm
from typing import Dict, List, Tuple
from statistics import mean


import torch
import monai
import cv2
import torchvision
import torch.optim.lr_scheduler as lr_scheduler
from torch import nn
from torch.utils.data import DataLoader
from torch.optim import Adam
from torchvision import transforms
from monai.transforms import Compose, NormalizeIntensityd
from monai.metrics import compute_iou
from sklearn.model_selection import train_test_split

from transformers import SamProcessor
from transformers import SamModel

# CONFIGURATION

In [6]:
plt.style.use("dark_background")
warnings.filterwarnings("ignore", "is_categorical_dtype")
warnings.filterwarnings("ignore", "use_inf_as_na")

In [7]:
class CFG:
    # define paths
    DATASET_PATH = "/kaggle/input/lgg-mri-segmentation/"
    TRAIN_PATH = "/kaggle/input/lgg-mri-segmentation/kaggle_3m/"
    
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    TRAIN_BATCH_SIZE = 2
    TEST_BATCH_SIZE = 1
    LEARNING_RATE = 1e-3
    WEIGHT_DECAY = 0
    EPOCH = 10

# DATA PREPARATION

In [8]:
dataset_images = glob.glob(f"{CFG.TRAIN_PATH}**/*.tif")
dataset_images[:5]

[]

In [9]:
# TRIAL CELL

dataset_images = glob.glob("t0001_Channel 3.tif")
dataset_images[:5]

['t0001_Channel 3.tif']

In [10]:
def get_sample_patient_id(image_paths):
    return [(_.split('/')[-2:][0]) for _ in image_paths]

def get_sample_number(image_paths):
    sample_numbers = []
    is_mask = []

    for path in image_paths:
        path_list = path.split('/')[-2:][1].split('_')

        if 'mask.tif' in path_list:
            sample_numbers.append(int(path_list[-2]))
            is_mask.append(1)
        else:
            sample_numbers.append(int(path_list[-1].replace('.tif', '')))
            is_mask.append(0)

    return sample_numbers, is_mask

def build_df(image_paths):
    sample_numbers, mask_label = get_sample_number(image_paths)
    # create dataframe
    df = pd.DataFrame({
        'id'        : sample_numbers,
        'patient'   : get_sample_patient_id(image_paths),
        'image_path': image_paths,
        'is_mask'   : mask_label
    })

    # return df
    return df

In [11]:
# 0 = image
# 1 = mask

dataset_df = (
    build_df(dataset_images)
    .sort_values(by=['id', 'patient', 'image_path'])
    .reset_index(drop=True)
)

dataset_df

IndexError: list index out of range

In [None]:
# images_df: for images
# mask_df: for masks

grouped_df = dataset_df.groupby(by='is_mask')
images_df, mask_df = (
    grouped_df.get_group(0).drop('is_mask', axis=1).reset_index(drop=True),
    grouped_df.get_group(1).drop('is_mask', axis=1).reset_index(drop=True)
)

mask_df = mask_df.rename({'image_path': 'mask_path'}, axis=1)

mask_df.head()

In [None]:
def _load(image_path, as_tensor=True):
    image = Image.open(image_path)
    return np.array(image).astype(np.float32) / 255.

def generate_label(mask_path, load_fn):
    mask = load_fn(mask_path)
    if mask.max() > 0:
        return 1 # Brain Tumor Present
    return 0 # Normal

In [None]:
# merge images dataframe and masks dataframe
ds = images_df.merge(
    mask_df,
    on=['id', 'patient'],
    how='left'
)

# generate MRI Label
ds['diagnosis'] = [generate_label(_, _load) for _ in tqdm(ds['mask_path'])]
ds.head()

In [None]:
# filter valid masks and choose only 1360
ds = ds[ds['diagnosis']==1]
ds = ds.head(1360)

In [None]:
image_train, image_test, mask_train, mask_test = train_test_split(
ds['image_path'], ds['mask_path'], test_size = 0.10)

In [None]:
train_df = pd.concat([image_train, mask_train], axis=1).reset_index(drop=True)
train_dataset = dts.Dataset.from_pandas(train_df)

In [None]:
test_df = pd.concat([image_test, mask_test], axis=1).reset_index(drop=True)
test_dataset = dts.Dataset.from_pandas(test_df)

In [None]:
def transform(data):
    # Load the image
    with open(data['image_path'], 'rb') as f:
        image = Image.open(io.BytesIO(f.read())).convert('RGB')
    data['image'] = image

    with open(data['mask_path'], 'rb') as f:
        mask = Image.open(io.BytesIO(f.read())).convert('L') # to grayscale
    data['mask'] = mask

    return data

In [None]:
train_dataset = train_dataset.map(transform, remove_columns=['image_path','mask_path'])

In [None]:
train_dataset

In [None]:
test_dataset = test_dataset.map(transform, remove_columns=['image_path','mask_path'])

In [None]:
test_dataset

In [None]:
example = train_dataset[0]
img = example['image']
msk = example['mask']

In [None]:
np.array(msk).max(), np.array(msk).min(), np.array(img).max(), np.array(img).min()

In [None]:
def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

In [None]:
fig, axes = plt.subplots()

axes.imshow(np.array(img))
ground_truth_seg = np.array(example["mask"])
show_mask(ground_truth_seg, axes)
axes.title.set_text(f"Ground truth mask")
axes.axis("off")

In [None]:
def get_bounding_box(ground_truth_map):
    '''
    This function creates varying bounding box coordinates based on the segmentation contours as prompt for the SAM model
    The padding is random int values between 5 and 20 pixels
    '''

    if len(np.unique(ground_truth_map)) > 1:

        # get bounding box from mask
        y_indices, x_indices = np.where(ground_truth_map > 0)
        x_min, x_max = np.min(x_indices), np.max(x_indices)
        y_min, y_max = np.min(y_indices), np.max(y_indices)

        # add perturbation to bounding box coordinates
        H, W = ground_truth_map.shape
        x_min = max(0, x_min - np.random.randint(5, 20))
        x_max = min(W, x_max + np.random.randint(5, 20))
        y_min = max(0, y_min - np.random.randint(5, 20))
        y_max = min(H, y_max + np.random.randint(5, 20))

        bbox = [x_min, y_min, x_max, y_max]

        return bbox
    else:
        return [0, 0, 256, 256]

In [None]:
class SAMDataset(torch.utils.data.Dataset):
  def __init__(self, dataset, processor):
    self.dataset = dataset
    self.processor = processor
    #self.transforms = transforms = Compose([NormalizeIntensityd(divisor=255, keys=['mask'])])

  def __len__(self):
    return len(self.dataset)

  def __getitem__(self, idx):
    item = self.dataset[idx]
    image = item["image"]
    ground_truth_mask = np.array(item["mask"])
    # ground_truth_mask = item["mask"]

    # get bounding box prompt
    prompt = get_bounding_box(ground_truth_mask)

    # prepare image and prompt for the model
    inputs = self.processor(image, input_boxes=[[prompt]], return_tensors="pt")

    # remove batch dimension which the processor adds by default
    inputs = {k:v.squeeze(0) for k,v in inputs.items()}

    # add ground truth segmentation
    inputs["ground_truth_mask"] = ground_truth_mask/255

    return inputs

In [None]:
processor = SamProcessor.from_pretrained("facebook/sam-vit-base", do_normalize=False)

In [None]:
train_sam_ds = SAMDataset(dataset=train_dataset, processor=processor)

In [None]:
exmpl = train_sam_ds[10]
for k,v in exmpl.items():
  print(k,v.shape)

In [None]:
train_dataloader = DataLoader(train_sam_ds, batch_size=CFG.TRAIN_BATCH_SIZE, shuffle=False)

In [None]:
batch = next(iter(train_dataloader))
for k,v in batch.items():
  print(k,v.shape)

In [None]:
train_df.shape

# Modeling

In [None]:
model = SamModel.from_pretrained("facebook/sam-vit-base")

# make sure we only compute gradients for mask decoder
for name, param in model.named_parameters():
  if name.startswith("vision_encoder") or name.startswith("prompt_encoder"):
    param.requires_grad_(False)

In [None]:
# Note: Hyperparameter tuning could improve performance here
optimizer = Adam(model.mask_decoder.parameters(), lr=CFG.LEARNING_RATE, weight_decay=CFG.WEIGHT_DECAY)
seg_loss = monai.losses.FocalLoss(reduction='mean')

In [None]:
num_epochs = CFG.EPOCH

device = CFG.DEVICE
model.to(device)

model.train()

epoch_losses = []
epoch_ious = []
# epoch_preds = []
# epoch_actls = []

for epoch in range(num_epochs):
    print(f'EPOCH: {epoch}')
    # menyimpan losses di setiap batch
    batch_losses = []
    # menyimpan iou di setiap batch
    batch_ious = []

    for i, batch in enumerate(tqdm(train_dataloader)):

      # forward pass
      outputs = model(pixel_values=batch["pixel_values"].to(device),
                      input_boxes=batch["input_boxes"].to(device),
                      multimask_output=False)

      # process the output
      predicted_masks = outputs.pred_masks.squeeze(1) # perlu diubah ke sigmoid
      ground_truth_masks = batch["ground_truth_mask"].float().to(device)

      sam_masks_prob = torch.sigmoid(predicted_masks)
      sam_masks_prob = sam_masks_prob.squeeze()
      sam_masks = (sam_masks_prob > 0.5)

      # compute loss
      loss = seg_loss(predicted_masks, ground_truth_masks.unsqueeze(1))
      batch_losses.append(loss.item())
      # batch_preds.append(predicted_masks)
      # batch_actls.append(ground_truth_masks.unsqueeze(1))

      # compute iou: return 2 values: 1 for each batch
      ious = compute_iou(sam_masks.unsqueeze(1),
                        ground_truth_masks.unsqueeze(1), ignore_empty=False)
      batch_ious.append(ious.mean())

      '''
      clear_output(wait=True)

      fig, axs = plt.subplots(1, 3)
      xmin, ymin, xmax, ymax = get_bounding_box(batch['ground_truth_mask'][0])
      rect = patches.Rectangle((xmin, ymin), xmax-xmin, ymax-ymin, linewidth=1, edgecolor='r', facecolor='none')

      axs[0].set_title('Input Image')
      axs[0].imshow(batch["pixel_values"][0,1], cmap='gray')
      axs[0].axis('off')

      axs[1].set_title('Actual Mask')
      axs[1].imshow(batch['ground_truth_mask'][0], cmap='copper')
      axs[1].add_patch(rect)
      axs[1].axis('off')

      # apply sigmoid
      medsam_seg_prob = torch.sigmoid(outputs.pred_masks[0].squeeze(1))

      # convert soft mask to hard mask
      medsam_seg_prob = medsam_seg_prob.detach().cpu().numpy().squeeze()
      medsam_seg = (medsam_seg_prob > 0.5).astype(np.uint8)

      axs[2].set_title('Predicted Mask')
      # axs[2].imshow(predicted_masks[0].cpu().detach().numpy().reshape(256,256,1)*255, cmap='copper')
      axs[2].imshow(medsam_seg, cmap='copper')
      axs[2].axis('off')

      plt.tight_layout()
      plt.show()
      '''

      # backward pass (compute gradients of parameters w.r.t. loss)
      optimizer.zero_grad()
      loss.backward()

      # optimize
      optimizer.step()

      if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # epoch_preds.append(batch_preds)
    # epoch_actls.append(batch_actls)

    mean_loss = mean(batch_losses)
    epoch_losses.append(mean_loss)
    print(f'Mean loss: {mean_loss}')

    mean_iou = mean([t.cpu().item() for t in batch_ious])
    epoch_ious.append(mean_iou)
    print(f'Mean IoU: {mean_iou}')

# setiap batch: IoU-nya adalah rata2 dari 2 IoU (ada 2 batch size)
# setiap epoch: rata-rata dari IoU setiap batch

In [None]:
sam_masks.shape

In [None]:
plt.plot(np.arange(1,CFG.EPOCH+1), epoch_losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')

In [None]:
plt.plot(np.arange(1,CFG.EPOCH+1), epoch_ious)
plt.xlabel('Epoch')
plt.ylabel('IoU')

# INFERENCE

## Inference with DataLoader

In [None]:
# create test dataloader
test_sam_ds = SAMDataset(test_dataset, processor=processor)
test_dataloader = DataLoader(test_sam_ds, batch_size=CFG.TEST_BATCH_SIZE, shuffle=False)

In [None]:
test_ious = []
model.eval()
# Iteratire through test images
with torch.no_grad():
    for batch in tqdm(test_dataloader):

        # forward pass
        outputs = model(pixel_values=batch["pixel_values"].cuda(),
                      input_boxes=batch["input_boxes"].cuda(),
                      multimask_output=False)

        ground_truth_masks = batch["ground_truth_mask"].float().cuda()

        # apply sigmoid
        sam_mask_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
        sam_mask_prob = sam_mask_prob.cpu().numpy().squeeze()
        sam_mask = (sam_mask_prob > 0.5)

        iou = compute_iou(torch.from_numpy(sam_mask).reshape(1, 256, 256).unsqueeze(1).cpu(),
                          ground_truth_masks.unsqueeze(1).cpu(), ignore_empty=False)

        print(f'IoU: {iou}')
        test_ious.append(iou)

        plt.figure(figsize=(12,4))
        plt.subplot(1,3,1)
        plt.imshow(batch["pixel_values"][0,1], cmap='gray')
        plt.title('MRI Image')
        plt.axis('off')
        
        plt.subplot(1,3,2)
        plt.imshow(batch["ground_truth_mask"][0], cmap='copper')
        plt.title('Actual Mask')
        plt.axis('off')
        
        plt.subplot(1,3,3)
        plt.imshow(sam_mask, cmap='copper')
        plt.title('Predicted Mask')
        plt.axis('off')
        
        plt.tight_layout()
        plt.show()

## Inference with Known Bounding Box

In [None]:
# let's take a random training example
idx = 10

# load image
image = test_dataset[idx]["image"]
plt.imshow(np.array(image))

In [None]:
plt.imshow(np.array(test_dataset[idx]["mask"]))

In [None]:
# get box prompt based on ground truth segmentation map
ground_truth_mask = np.array(test_dataset[idx]["mask"])
prompt = get_bounding_box(ground_truth_mask)

# prepare image + box prompt for the model
inputs = processor(image, input_boxes=[[prompt]], return_tensors="pt").to(device)
for k,v in inputs.items():
  print(k,v.shape)

In [None]:
fig, axes = plt.subplots()

axes.imshow(np.array(image))
show_mask(ground_truth_mask, axes)
axes.title.set_text(f"Ground truth mask")
axes.axis("off")

In [None]:
model.eval()

# forward pass
with torch.no_grad():
  outputs = model(**inputs, multimask_output=False)

In [None]:
# apply sigmoid
medsam_seg_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
# convert soft mask to hard mask
medsam_seg_prob = medsam_seg_prob.cpu().numpy().squeeze()
medsam_seg = (medsam_seg_prob > 0.5)

In [None]:
def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([255, 255, 255, 255])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

fig, axes = plt.subplots()

axes.imshow(np.array(image))
show_mask(medsam_seg, axes)
axes.title.set_text(f"Predicted mask")
axes.axis("off")

In [None]:
predicted_masks = outputs.pred_masks.squeeze(1) # perlu diubah ke sigmoid
ground_truth_masks = batch["ground_truth_mask"].float().to(device)

sam_masks_prob = torch.sigmoid(predicted_masks)
sam_masks_prob = sam_masks_prob.squeeze()
sam_masks = (sam_masks_prob > 0.5)

In [None]:
pred = torch.from_numpy(medsam_seg).float()

In [None]:
act = torch.from_numpy(ground_truth_mask).float()

In [None]:
iou = compute_iou(act/255, pred)

iou

## Inference with Unknown Bounding Box

### Prompt: Grid of Points

In [None]:

"""
input_points (torch.FloatTensor of shape (batch_size, num_points, 2)) —
Input 2D spatial points, this is used by the prompt encoder to encode the prompt.
Generally yields to much better results. The points can be obtained by passing a
list of list of list to the processor that will create corresponding torch tensors
of dimension 4. The first dimension is the image batch size, the second dimension
is the point batch size (i.e. how many segmentation masks do we want the model to
predict per input point), the third dimension is the number of points per segmentation
mask (it is possible to pass multiple points for a single mask), and the last dimension
is the x (vertical) and y (horizontal) coordinates of the point. If a different number
of points is passed either for each image, or for each mask, the processor will create
“PAD” points that will correspond to the (0, 0) coordinate, and the computation of the
embedding will be skipped for these points using the labels.

"""
# Define the size of your array
array_size = 256

# Define the size of your grid
grid_size = 10

# Generate the grid points
x = np.linspace(0, array_size-1, grid_size)
y = np.linspace(0, array_size-1, grid_size)

# Generate a grid of coordinates
xv, yv = np.meshgrid(x, y)

# Convert the numpy arrays to lists
xv_list = xv.tolist()
yv_list = yv.tolist()

# Combine the x and y coordinates into a list of list of lists
input_points = [[[int(x), int(y)] for x, y in zip(x_row, y_row)] for x_row, y_row in zip(xv_list, yv_list)]

#We need to reshape our nxn grid to the expected shape of the input_points tensor
# (batch_size, point_batch_size, num_points_per_image, 2),
# where the last dimension of 2 represents the x and y coordinates of each point.
#batch_size: The number of images you're processing at once.
#point_batch_size: The number of point sets you have for each image.
#num_points_per_image: The number of points in each set.
input_points = torch.tensor(input_points).view(1, 1, grid_size*grid_size, 2)

In [None]:
image = test_dataset[-1]["image"]
ground_truth_mask = np.array(test_dataset[-1]["mask"])

# prepare image + box prompt for the model
inputs = processor(image, input_points=input_points, return_tensors="pt").to(device)
for k,v in inputs.items():
  print(k,v.shape)

In [None]:
model.eval()

# forward pass
with torch.no_grad():
  outputs = model(**inputs, multimask_output=False)

In [None]:
# apply sigmoid
medsam_seg_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
# convert soft mask to hard mask
medsam_seg_prob = medsam_seg_prob.cpu().numpy().squeeze()
medsam_seg = (medsam_seg_prob > 0.5)

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Plot the first image on the left
axes[0].imshow(np.array(image), cmap='gray')  # Assuming the first image is grayscale
axes[0].set_title("Image")

# Plot the second image on the right
axes[1].imshow(ground_truth_mask)  # Assuming the second image is grayscale
axes[1].set_title("Actual Mask")

# Plot the second image on the right
axes[2].imshow(medsam_seg, cmap='gray')  # Assuming the second image is grayscale
axes[2].set_title("Prediction")

# Hide axis ticks and labels
for ax in axes:
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xticklabels([])
    ax.set_yticklabels([])

# Display the images side by side
plt.show()

In [None]:
pred = torch.from_numpy(medsam_seg).float()

In [None]:
act = torch.from_numpy(ground_truth_mask).float()

In [None]:
iou = compute_iou(act/255, pred, ignore_empty=False)

iou

### Prompt: Bounding Box in Image

In [None]:
# Function to draw a bounding box around the object in an image
def draw_bounding_box(image):
    # Load the image
    # image = cv2.imread(image_path)

    # Convert the image to gray scale
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    # Apply a blur to the image to reduce noise
    blurred = cv2.GaussianBlur(gray, (5, 5), 0)

    # Threshold the image to get the object in binary
    _, thresh = cv2.threshold(blurred, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

    # Find contours in the thresholded image
    contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)

    # Find the largest contour, assuming the object to bound is the largest one
    largest_contour = max(contours, key=cv2.contourArea)

    # Draw a bounding box around the largest contour
    x, y, w, h = cv2.boundingRect(largest_contour)
    cv2.rectangle(image, (x, y), (x + w, y + h), (0, 255, 0), 2)

    plt.imshow(image)

    x_min = x
    x_max = x + w
    y_min = y
    y_max = y + h

    return [x_min, y_min, x_max, y_max]


In [None]:
pr = draw_bounding_box(np.array(test_dataset[10]["image"]))

In [None]:
image = test_dataset[10]["image"]
ground_truth_mask = np.array(test_dataset[10]["mask"])

# prepare image + box prompt for the model
inputs = processor(image, input_boxes=[[pr]], return_tensors="pt").to(device)
for k,v in inputs.items():
  print(k,v.shape)

In [None]:
model.eval()

# forward pass
with torch.no_grad():
  outputs = model(**inputs, multimask_output=False)

In [None]:
# apply sigmoid
medsam_seg_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
# convert soft mask to hard mask
medsam_seg_prob = medsam_seg_prob.cpu().numpy().squeeze()
medsam_seg = (medsam_seg_prob > 0.5)

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Plot the first image on the left
axes[0].imshow(np.array(image), cmap='gray')  # Assuming the first image is grayscale
axes[0].set_title("Image")

# Plot the second image on the right
axes[1].imshow(ground_truth_mask)  # Assuming the second image is grayscale
axes[1].set_title("Actual Mask")

# Plot the second image on the right
axes[2].imshow(medsam_seg, cmap='gray')  # Assuming the second image is grayscale
axes[2].set_title("Prediction")

# Hide axis ticks and labels
for ax in axes:
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xticklabels([])
    ax.set_yticklabels([])

# Display the images side by side
plt.show()

In [None]:
pred = torch.from_numpy(medsam_seg).float()
act = torch.from_numpy(ground_truth_mask).float()
iou = compute_iou(act/255, pred, ignore_empty=False)

iou