# Notebook: fine-tune SAM (Segment Anything Model) on a custom dataset

## Load dataset

<!-- Here we load a small dataset of 130 (image, ground truth mask) pairs.

To load your own images and masks, refer to the bottom of my [SAM inference notebook](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/SAM/Run_inference_with_MedSAM_using_HuggingFace_Transformers.ipynb).

See also [this guide](https://huggingface.co/docs/datasets/image_dataset). -->

In [None]:
import os
from torch.utils.data import Dataset, DataLoader
import skimage.io as io
import matplotlib.pyplot as plt
import cv2
import numpy as np

In [None]:
class ImageDataset(Dataset):
    """(image, mask) pairs read from ``<root_dir>/images`` and ``<root_dir>/masks``.

    The image file names are sorted and then shuffled with a fixed seed (0), so
    the 60% / 20% / 20% train / val / test split is reproducible and the three
    splits are disjoint.
    """

    def __init__(self, root_dir, set='train'):
        """
        Args:
            root_dir (string): Directory containing an ``images`` and a ``masks``
                sub-directory.
            set (string): Which split to expose: ``'train'`` (first 60% of the
                shuffled names), ``'val'`` (next 20%) or ``'test'`` (last 20%).

        Raises:
            ValueError: if ``set`` is not ``'train'``, ``'val'`` or ``'test'``.
        """
        self.root_dir = root_dir
        self.images_dir = os.path.join(root_dir, 'images')
        self.masks_dir = os.path.join(root_dir, 'masks')
        # Sort first so the shuffle below does not depend on os.listdir order.
        self.image_names = sorted(os.listdir(self.images_dir))
        # A private seeded generator yields the same permutation as
        # np.random.seed(0) + np.random.shuffle, without resetting the global
        # NumPy RNG every time a dataset is constructed.
        np.random.RandomState(0).shuffle(self.image_names)
        self.set = set
        n_images = len(self.image_names)
        train_end, val_end = int(n_images * 0.6), int(n_images * 0.8)
        if set == 'train':
            self.image_names = self.image_names[:train_end]
        elif set == 'val':
            self.image_names = self.image_names[train_end:val_end]
        elif set == 'test':
            self.image_names = self.image_names[val_end:]
        else:
            raise ValueError('set must be "train", "val" or "test"')

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.image_names)

    def __getitem__(self, idx):
        """Load one sample and resize it to 256x256.

        Returns:
            dict with
              ``'image'``: uint8 RGB array of shape (256, 256, 3);
              ``'mask'``: uint8 array of shape (256, 256) with values {0, 1},
              where 1 marks every non-black mask pixel;
              ``'idx'``: the integer parsed from the second ``_``-separated
              field of the file name (e.g. ``img_12.png`` -> 12).

        Raises:
            FileNotFoundError: if the image or its mask cannot be read.
        """
        file_name = self.image_names[idx]
        img_name = os.path.join(self.images_dir, file_name)
        # The mask is expected at masks/mask_<field>, where <field> is the second
        # "_"-separated field of the image name. Adjust to your naming convention.
        mask_name = os.path.join(self.masks_dir, f'mask_{file_name.split("_")[1]}')
        image = cv2.imread(img_name)
        if image is None:
            raise FileNotFoundError(f'could not read image {img_name}')
        mask = cv2.imread(mask_name, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f'could not read mask {mask_name}')
        sample_id = int(file_name.split('_')[1].split('.')[0])

        # cv2 decodes to BGR; convert to RGB, the order matplotlib displays and
        # the order the training images (plt.imread) were in.
        image = cv2.cvtColor(cv2.resize(image, (256, 256)), cv2.COLOR_BGR2RGB)
        # Nearest-neighbour keeps mask labels crisp (no interpolated edge values).
        mask = cv2.resize(mask, (256, 256), interpolation=cv2.INTER_NEAREST)
        # Convert mask to binary {0, 1}: works for 0/1 and 0/255 encodings alike.
        mask = (mask > 0).astype(np.uint8)

        sample = {'image': image, 'mask': mask, 'idx': sample_id}

        return sample

In [None]:
# Training split of new_dataset/, used below for a visual check and for the
# inference example.
image_dataset = ImageDataset(root_dir='new_dataset', set='train')

In [None]:
# Overlay the ground-truth mask of sample 2 on its image (sample loaded once).
sample = image_dataset[2]
image, mask = sample['image'], sample['mask']

plt.imshow(image)
plt.imshow(mask, alpha=0.5)

In [None]:
import numpy as np

def get_bounding_box(ground_truth_map, max_perturbation=20):
    """Return a jittered ``[x_min, y_min, x_max, y_max]`` box around a mask.

    The tight box around all pixels ``> 0`` is computed first. Each side is then
    pushed outwards by an independent random offset in ``[0, max_perturbation)``,
    drawn from the global NumPy RNG, and clamped to the image. The resulting
    prompt is slightly loose, like a hand-drawn box.

    Args:
        ground_truth_map: 2-D array of shape (H, W); pixels ``> 0`` belong to
            the object.
        max_perturbation: exclusive upper bound (pixels) of the random outward
            shift of each side. ``0`` or less disables the jitter.

    Returns:
        list of four Python ints ``[x_min, y_min, x_max, y_max]``. The minima are
        clamped at 0, ``x_max`` at W and ``y_max`` at H.

    Raises:
        ValueError: if the mask has no pixel ``> 0`` (no box can be derived).
    """
    ground_truth_map = np.asarray(ground_truth_map)
    y_indices, x_indices = np.where(ground_truth_map > 0)
    if x_indices.size == 0:
        raise ValueError('get_bounding_box: mask has no foreground pixel (> 0)')
    x_min, x_max = x_indices.min(), x_indices.max()
    y_min, y_max = y_indices.min(), y_indices.max()

    H, W = ground_truth_map.shape
    if max_perturbation > 0:
        # Random draws happen in the order x_min, x_max, y_min, y_max.
        x_min = max(0, x_min - np.random.randint(0, max_perturbation))
        x_max = min(W, x_max + np.random.randint(0, max_perturbation))
        y_min = max(0, y_min - np.random.randint(0, max_perturbation))
        y_max = min(H, y_max + np.random.randint(0, max_perturbation))

    return [int(x_min), int(y_min), int(x_max), int(y_max)]

In [None]:
from torch.utils.data import Dataset
import os
import rasterio
from PIL import Image  # For handling image file operations
from torchvision.transforms import Resize  # For resizing images

class SAMDataset(Dataset):
    """Training dataset that pairs each image with a jittered box prompt.

    Every file in ``<root_dir>/images`` is used; there is no train/val/test split.
    Its mask is read from ``<root_dir>/masks``. A sample is the processor output
    for the 256x256 RGB image and its box prompt, plus the 256x256 binary
    ground-truth mask.
    """

    def __init__(self, root_dir, processor):
        """
        Args:
            root_dir (string): Directory containing an ``images`` and a ``masks``
                sub-directory.
            processor: callable invoked as
                ``processor(image, input_boxes=[[box]], return_tensors="pt")``
                (e.g. a ``SamProcessor``).
        """
        self.root_dir = root_dir
        self.images_dir = os.path.join(root_dir, 'images')
        self.masks_dir = os.path.join(root_dir, 'masks')
        # Sorted so that sample indices do not depend on os.listdir order.
        self.image_names = sorted(os.listdir(self.images_dir))
        self.processor = processor

    def __len__(self):
        """Return the number of images found under ``images/``."""
        return len(self.image_names)

    def __getitem__(self, idx):
        """Load one image/mask pair and build the model inputs.

        Returns:
            dict of the processor outputs with their leading batch dimension
            squeezed, plus ``"ground_truth_mask"``. That key holds a uint8
            (256, 256) array with values {0, 1}, where 1 marks every non-black
            mask pixel.

        Raises:
            FileNotFoundError: if the image or its mask cannot be read.
        """
        file_name = self.image_names[idx]
        img_name = os.path.join(self.images_dir, file_name)
        # The mask is expected at masks/mask_<field>, where <field> is the second
        # "_"-separated field of the image name. Adjust to your naming convention.
        mask_name = os.path.join(self.masks_dir, f'mask_{file_name.split("_")[1]}')
        # cv2 always decodes to 8-bit. plt.imread would return floats in [0, 1]
        # for PNGs, which the processor then rescales by 1/255 a second time,
        # and 2-D arrays for grayscale files.
        image = cv2.imread(img_name)
        if image is None:
            raise FileNotFoundError(f'could not read image {img_name}')
        mask = cv2.imread(mask_name, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f'could not read mask {mask_name}')

        # Resize to 256x256 and convert cv2's BGR order to RGB.
        image = cv2.cvtColor(cv2.resize(image, (256, 256)), cv2.COLOR_BGR2RGB)
        # Nearest-neighbour keeps mask labels crisp; then binarize to {0, 1}.
        mask = cv2.resize(mask, (256, 256), interpolation=cv2.INTER_NEAREST)
        mask = (mask > 0).astype(np.uint8)

        prompt = get_bounding_box(mask)
        inputs = self.processor(image, input_boxes=[[prompt]], return_tensors="pt")
        # Drop the batch dimension added by the processor; the DataLoader re-adds it.
        inputs = {k: v.squeeze(0) for k, v in inputs.items()}
        inputs["ground_truth_mask"] = mask

        return inputs

In [None]:
from transformers import SamProcessor

# Processor for the same checkpoint that is fine-tuned below.
SAM_CHECKPOINT = "facebook/sam-vit-base"
processor = SamProcessor.from_pretrained(SAM_CHECKPOINT)

In [None]:
# SAMDataset does not split, so this uses every image under new_dataset/.
train_dataset = SAMDataset("new_dataset", processor)

In [None]:
# Inspect one processed sample: the shape stored under each key.
example = train_dataset[0]
for key, value in example.items():
    print(key, value.shape)

## Create PyTorch DataLoader

Next, we define a PyTorch `DataLoader`, which lets us iterate over the dataset in shuffled batches.



In [None]:
from torch.utils.data import DataLoader

# Batches of 2 samples, reshuffled every epoch.
train_dataloader = DataLoader(dataset=train_dataset, batch_size=2, shuffle=True)

In [None]:
# Pull a single batch to check the collated shapes.
batch = next(iter(train_dataloader))
for key, tensor in batch.items():
    print(key, tensor.shape)

In [None]:
# Expected (batch_size, 256, 256): SAMDataset resizes every mask to 256x256.
batch["ground_truth_mask"].shape

## Load the model

In [None]:
from transformers import SamModel

model = SamModel.from_pretrained("facebook/sam-vit-base")

# Freeze every parameter of the vision and prompt encoders; the optimizer
# below only updates the mask decoder.
frozen_prefixes = ("vision_encoder", "prompt_encoder")
for name, param in model.named_parameters():
    if name.startswith(frozen_prefixes):
        param.requires_grad_(False)

## Train the model

In [None]:
from torch.optim import Adam
import monai

# Only the mask decoder's parameters are optimized.
# Note: hyperparameter tuning could improve performance here.
optimizer = Adam(params=model.mask_decoder.parameters(), lr=1e-5, weight_decay=0)

# Dice + cross-entropy computed on logits (sigmoid=True applies the sigmoid
# inside the loss), mean-reduced.
seg_loss = monai.losses.DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')

In [None]:
from tqdm import tqdm
from statistics import mean
import torch
from torch.nn.functional import threshold, normalize

num_epochs = 20

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

model.train()
for epoch in range(num_epochs):
    epoch_losses = []
    for batch in tqdm(train_dataloader):
        # Forward pass: a single mask per box prompt.
        outputs = model(
            pixel_values=batch["pixel_values"].to(device),
            input_boxes=batch["input_boxes"].to(device),
            multimask_output=False,
        )

        # Loss on logits. Presumably pred_masks is (batch, boxes, masks, H, W);
        # squeezing the box axis gives (batch, 1, H, W) to match the
        # unsqueezed ground truth. Verify against the model docs.
        predicted_masks = outputs.pred_masks.squeeze(1)
        ground_truth_masks = batch["ground_truth_mask"].float().to(device)
        loss = seg_loss(predicted_masks, ground_truth_masks.unsqueeze(1))

        # Backward pass, then parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_losses.append(loss.item())

    print(f'EPOCH: {epoch}')
    print(f'Mean loss: {mean(epoch_losses)}')

In [None]:
# Save the fine-tuned model under finetuned_sam_model/.
model.save_pretrained(save_directory="finetuned_sam_model")

## Inference

Important: because we trained with the Dice loss with `sigmoid=True`, the model outputs logits, so we apply a sigmoid to the predicted masks ourselves. We therefore don't use the processor's `post_process_masks` method here. The 256x256 predicted mask already matches the 256x256 resized input image.

In [None]:
# Take the image and its ground-truth mask from the same sample, so the box
# prompt always belongs to the image being segmented (and the cell does not
# depend on an `image` variable set many cells earlier).
sample = image_dataset[2]
image = sample["image"]

# get box prompt based on ground truth segmentation map
ground_truth_mask = np.array(sample["mask"])
prompt = get_bounding_box(ground_truth_mask)

# prepare image + box prompt for the model
inputs = processor(image, input_boxes=[[prompt]], return_tensors="pt").to(device)
for k, v in inputs.items():
    print(k, v.shape)

In [None]:
# Switch layers to evaluation behaviour and skip autograd bookkeeping.
model.eval()

with torch.no_grad():
    outputs = model(**inputs, multimask_output=False)

In [None]:
# The model outputs logits (it was trained with DiceCELoss(sigmoid=True)), so
# map them to probabilities, move them to NumPy and drop singleton axes.
medsam_seg_prob = (
    torch.sigmoid(outputs.pred_masks.squeeze(1)).cpu().numpy().squeeze()
)
# Hard mask: 1 where the probability exceeds 0.5.
medsam_seg = np.greater(medsam_seg_prob, 0.5).astype(np.uint8)

In [None]:
import matplotlib.pyplot as plt

In [None]:
def show_mask(mask, ax, random_color=False):
    """Draw ``mask`` on ``ax`` as a colored, semi-transparent overlay.

    Args:
        mask: array-like holding H*W values whose last two dimensions are
            (H, W), e.g. shape (H, W) or (1, H, W). Values are clipped to
            [0, 1], so boolean, 0/1, probability and 0/255 masks all render
            with the intended transparency.
        ax: matplotlib Axes to draw on.
        random_color: if True, use a random RGB color (alpha 0.6) instead of
            the default blue.
    """
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    mask = np.asarray(mask, dtype=float)
    h, w = mask.shape[-2:]
    # Without clipping, a 0/255 mask pushes the RGBA values above 1; imshow then
    # clips them and the overlay is drawn opaque instead of at alpha 0.6.
    mask_image = np.clip(mask.reshape(h, w, 1), 0.0, 1.0) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

# Predicted mask overlaid on the input image.
fig, axes = plt.subplots()
axes.imshow(image)
show_mask(medsam_seg, axes)
axes.title.set_text("Predicted mask")
axes.axis("off")

Compare this to the ground truth segmentation:

In [None]:
# Ground-truth mask of the same sample, for comparison with the prediction.
fig, axes = plt.subplots()
axes.imshow(image)
show_mask(ground_truth_mask, axes)
axes.title.set_text("Ground truth mask")
axes.axis("off")