## Imports

In [1]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import models 
from torch.utils.data import DataLoader
import os
from torch.optim import Adam
import matplotlib.pyplot as plt
from patchify import patchify
import numpy as np
import random

ModuleNotFoundError: No module named 'patchify'

## Importing Data

In [2]:
#Download Dataset
import opendatasets as od

od.download("https://www.kaggle.com/datasets/aysendegerli/qatacov19-dataset/data")

Skipping, found downloaded files in "./qatacov19-dataset" (use force=True to force download)


In [12]:
#!/usr/bin/python                                                  
from PIL import Image                                              
import os, sys                       

train_img_path = "/home/cahsi/Josh/SAM/work/SAM/qatacov19-dataset/QaTa-COV19/QaTa-COV19-v2/Train Set/Images/*.png"
test_img_path = "/home/cahsi/Josh/SAM/work/SAM/qatacov19-dataset/QaTa-COV19/QaTa-COV19-v2/Test Set/Images/*.png"

train_mask_path = "/home/cahsi/Josh/SAM/work/SAM/qatacov19-dataset/QaTa-COV19/QaTa-COV19-v2/Train Set/Ground-truths/*.png"
test_mask_path = "/home/cahsi/Josh/SAM/work/SAM/qatacov19-dataset/QaTa-COV19/QaTa-COV19-v2/Test Set/Ground-truths/*.png"

In [13]:
from skimage import io

# Load all images in the current folder that end with .png
train_img = io.imread_collection(train_img_path)
test_img = io.imread_collection(test_img_path)
train_mask = io.imread_collection(train_mask_path)
test_mask = io.imread_collection(test_mask_path)

In [14]:
#Images Info
print("Length of training raw images: " + str(len(train_images_raw)) + "      Shape of an training raw image: " + str(train_images_raw[0].shape))
print("Length of training mask images: " + str(len(train_images_masks)) + "     Shape of an training mask image: " + str(train_images_masks[0].shape))

print("\nLength of test raw images: " + str(len(test_images_raw)) + "       Shape of an test raw image: " + str(test_images_raw[0].shape))
print("Length of test mask images: " + str(len(test_images_masks)) + "      Shape of an test mask image: " + str(test_images_masks[0].shape))

Length of training raw images: 7145      Shape of an training raw image: (224, 224)
Length of training mask images: 2113     Shape of an training mask image: (224, 224)

Length of test raw images: 7145       Shape of an test raw image: (224, 224)
Length of test mask images: 2113      Shape of an test mask image: (224, 224)


## Resizing Images


In [15]:
from skimage.transform import resize

def resize_images(images, mask):
    output = []
    if mask:
        for mask in images:
            # Perform resizing with nearest neighbor interpolation to maintain binary values
            resized_mask = (resize(mask, (256, 256), order=0, anti_aliasing=False) > 0.5).astype(np.uint8)
            output.append(resized_mask)
    else:
        for image in images:
            resized_image = resize(image, (256, 256), anti_aliasing=False)
            output.append(resized_image)
    
    return output

# Resize training images to 256x256
train_img = resize_images(train_img, False)

# Resize testing images to 256x256
test_img = resize_images(test_img, False)

# Resize training masks to 256x256
train_mask = resize_images(train_mask, True)

# Resize testing masks to 256x256
test_mask = resize_images(test_mask, True)

#Convert to np array
train_img = np.array(train_img)
test_img = np.array(test_img)
train_mask = np.array(train_mask)
test_mask = np.array(test_mask)

In [16]:
#Print Shape of resized images
print("Shape of resized training raw image: " + str(train_images.shape))
print("Shape of resized testing raw image: " + str(test_images.shape))
print("Shape of resized training mask image: " + str(train_masks.shape))
print("Shape of resized testing mask image: " + str(test_masks.shape))

Shape of resized training raw image: (7145, 256, 256)
Shape of resized testing raw image: (7145, 256, 256)
Shape of resized training mask image: (2113, 256, 256)
Shape of resized testing mask image: (2113, 256, 256)


In [17]:
#Combine images and masks respectively
images = np.concatenate((train_images, test_images))
masks = np.concatenate((train_masks, test_masks))

print("Shape of resized images: " + str(images.shape))
print("Shape of resized masks: " + str(masks.shape))

Shape of resized images: (14290, 256, 256)
Shape of resized masks: (4226, 256, 256)


## Create Datasets

In [18]:
from datasets import Dataset
from PIL import Image

# Convert the NumPy arrays to Pillow images and store them in a dictionary
dataset_dict = {
    "image": [Image.fromarray(img) for img in images],
    "mask": [Image.fromarray(mask) for mask in masks],
}

# Create the dataset using the datasets.Dataset class
dataset = Dataset.from_dict(dataset_dict)

ArrowInvalid: Column 1 named mask expected length 14290 but got length 4226

In [None]:
#Check dataset info
print("Dataset Info:\n" + str(dataset))

## Visualize Data

In [None]:
#Get random image from training set and plot
img_num = random.randint(0, images.shape[0]-1)
example_image = dataset[img_num]["image"]
example_mask = dataset[img_num]["mask"]

fig, axes = plt.subplots(1, 2, figsize=(10, 5))

# Plot the first image on the left
axes[0].imshow(np.array(example_image), cmap='gray')  # Assuming the first image is grayscale
axes[0].set_title("Image")

# Plot the second image on the right
axes[1].imshow(example_mask, cmap='gray')  # Assuming the second image is grayscale
axes[1].set_title("Mask")

# Hide axis ticks and labels
for ax in axes:
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xticklabels([])
    ax.set_yticklabels([])

# Display the images side by side
plt.show()

## Get Bounding Boxes

In [None]:
#Function that gets bounding boxes from masks
def get_bounding_box(ground_truth_map):
  # get bounding box from mask
  y_indices, x_indices = np.where(ground_truth_map > 0)
  x_min, x_max = np.min(x_indices), np.max(x_indices)
  y_min, y_max = np.min(y_indices), np.max(y_indices)
  # add perturbation to bounding box coordinates
  H, W = ground_truth_map.shape
  x_min = max(0, x_min - np.random.randint(0, 20))
  x_max = min(W, x_max + np.random.randint(0, 20))
  y_min = max(0, y_min - np.random.randint(0, 20))
  y_max = min(H, y_max + np.random.randint(0, 20))
  bbox = [x_min, y_min, x_max, y_max]

  return bbox

In [None]:
from torch.utils.data import Dataset

class SAMDataset(Dataset):
  """
  This class is used to create a dataset that serves input images and masks.
  It takes a dataset and a processor as input and overrides the __len__ and __getitem__ methods of the Dataset class.
  """
  def __init__(self, dataset, processor):
    self.dataset = dataset
    self.processor = processor

  def __len__(self):
    return len(self.dataset)

  def __getitem__(self, idx):
    item = self.dataset[idx]
    image = item["image"]
    ground_truth_mask = np.array(item["mask"])

    image = np.array(image.convert("RGB"))

    # get bounding box prompt
    prompt = get_bounding_box(ground_truth_mask)

    # prepare image and prompt for the model
    inputs = self.processor(image, input_boxes=[[prompt]], return_tensors="pt")

    # remove batch dimension which the processor adds by default
    inputs = {k:v.squeeze(0) for k,v in inputs.items()}

    # add ground truth segmentation
    inputs["ground_truth_mask"] = ground_truth_mask

    return inputs

In [None]:
# Initialize the processor
from transformers import SamProcessor
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

In [None]:
# Create an instance of the SAMDataset
train_dataset = SAMDataset(dataset=dataset, processor=processor)

In [None]:
print("Training: \n")
example = train_dataset[0]
for k,v in example.items():
  print(k,v.shape)

In [None]:
# Create a DataLoader instance for the training dataset
from torch.utils.data import DataLoader
train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True, drop_last=False)

In [None]:
print("Training: \n")
train_batch = next(iter(train_dataloader))
for k,v in train_batch.items():
  print(k,v.shape)

In [None]:
print("Training Batch Shape: " + str(train_batch["ground_truth_mask"].shape))

## Train Model

In [None]:
# Load the model
from transformers import SamModel
model = SamModel.from_pretrained("facebook/sam-vit-base")

# make sure we only compute gradients for mask decoder
for name, param in model.named_parameters():
  if name.startswith("vision_encoder") or name.startswith("prompt_encoder"):
    param.requires_grad_(False)

In [None]:
from torch.optim import Adam
import monai
# Initialize the optimizer and the loss function
optimizer = Adam(model.mask_decoder.parameters(), lr=1e-5, weight_decay=0)
#Try DiceFocalLoss, FocalLoss, DiceCELoss
seg_loss = monai.losses.DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')

In [None]:
from tqdm import tqdm
from statistics import mean
import torch
import torch.nn as nn
from torch.nn.parallel import DataParallel
from torch.nn.functional import threshold, normalize

#Training loop
num_epochs = 10

# Move your model to GPU devices
#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda:1")
model = model.to(device)

model.train()
for epoch in range(num_epochs):
    epoch_losses = []
    for batch in tqdm(train_dataloader):
        # forward pass
        outputs = model(pixel_values=batch["pixel_values"].to(device),
                        input_boxes=batch["input_boxes"].to(device),
                        multimask_output=False)

        # compute loss
        predicted_masks = outputs.pred_masks.squeeze(1)
        ground_truth_masks = batch["ground_truth_mask"].float().to(device)
        loss = seg_loss(predicted_masks, ground_truth_masks.unsqueeze(1))

        # backward pass (compute gradients of parameters w.r.t. loss)
        optimizer.zero_grad()
        loss.backward()

        # optimize
        optimizer.step()
        epoch_losses.append(loss.item())

    print(f'EPOCH: {epoch}')
    print(f'Mean loss: {mean(epoch_losses)}')

In [None]:
# Save the model's state dictionary to a file
torch.save(model.state_dict(), "/home/cahsi/Josh/SAM/sam_model_checkpoint.pth")

## Inference

In [None]:
from transformers import SamModel, SamConfig, SamProcessor
import torch

In [None]:
# Load the model configuration
model_config = SamConfig.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

# Create an instance of the model architecture with the loaded configuration
basic_sam_model = SamModel(config=model_config)
#Update the model by loading the weights from saved file.
basic_sam_model.load_state_dict(torch.load("/home/cahsi/Josh/SAM/basic_sam_model_checkpoint.pth"))

In [None]:
# set the device to cuda if available, otherwise use cpu
device = "cuda" if torch.cuda.is_available() else "cpu"
basic_sam_model.to(device)

In [None]:
import numpy as np
import random
import torch
import matplotlib.pyplot as plt

# let's take a random training example
idx = random.randint(0, images.shape[0]-1)

# load image
test_image = dataset[idx]["image"]
test_image = np.array(test_image.convert("RGB"))


# get box prompt based on ground truth segmentation map
ground_truth_mask = np.array(dataset[idx]["mask"])
prompt = get_bounding_box(ground_truth_mask)

# prepare image + box prompt for the model
inputs = processor(test_image, input_boxes=[[prompt]], return_tensors="pt")

# Move the input tensor to the GPU if it's not already there
inputs = {k: v.to(device) for k, v in inputs.items()}

basic_sam_model.eval()

# forward pass
with torch.no_grad():
    outputs = basic_sam_model(**inputs, multimask_output=False)

# apply sigmoid
medsam_seg_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
# convert soft mask to hard mask
medsam_seg_prob = medsam_seg_prob.cpu().numpy().squeeze()
medsam_seg = (medsam_seg_prob > 0.5).astype(np.uint8)


fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Plot the first image on the left
test_image = dataset[idx]["image"]
axes[0].imshow(np.array(test_image), cmap='gray')  # Assuming the first image is grayscale
axes[0].set_title("Image")

# Plot the second image on the right
axes[1].imshow(medsam_seg, cmap='gray')  # Assuming the second image is grayscale
axes[1].set_title("Mask")

# Plot the second image on the right
axes[2].imshow(medsam_seg_prob)  # Assuming the second image is grayscale
axes[2].set_title("Probability Map")

# Hide axis ticks and labels
for ax in axes:
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xticklabels([])
    ax.set_yticklabels([])

# Display the images side by side
plt.show()

##### Resources
<b> https://github.com/hitachinsk/SAMed

<b> https://github.com/MathieuNlp/Sam_LoRA

<b>https://colab.research.google.com/github/bnsreenu/python_for_microscopists/blob/master/331_fine_tune_SAM_mito.ipynb#scrollTo=aTXUX7xyCEGT