## Training / Fine-Tuning the Segment Anything Model (SAM)

Based on the code from the Encord tutorial: https://encord.com/blog/learn-how-to-fine-tune-the-segment-anything-model-sam/

### Mount Google Drive in Colab and import the required libraries

In [1]:
from google.colab import drive
from google.colab.patches import cv2_imshow

# Mount Google Drive so the project folders under MyDrive are reachable from Colab.
drive.mount('/content/drive')
# NOTE(review): the path-configuration cell reassigns current_directory to the parent
# FibreAnalysis folder (without '/Data') before it is first used.
current_directory = '/content/drive/MyDrive/Colab Notebooks/FibreAnalysis/Data'

Mounted at /content/drive


In [2]:
using_colab = True
if using_colab:
    import torch
    import torchvision
    print("PyTorch version:", torch.__version__)
    print("Torchvision version:", torchvision.__version__)
    print("CUDA is available:", torch.cuda.is_available())
    import sys
    !{sys.executable} -m pip install opencv-python matplotlib
    !{sys.executable} -m pip install 'git+https://github.com/facebookresearch/segment-anything.git'

    !mkdir images
    !wget -P images https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/dog.jpg

    !wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

PyTorch version: 2.0.1+cu118
Torchvision version: 0.15.2+cu118
CUDA is available: True
Collecting git+https://github.com/facebookresearch/segment-anything.git
  Cloning https://github.com/facebookresearch/segment-anything.git to /tmp/pip-req-build-qayi155l
  Running command git clone --filter=blob:none --quiet https://github.com/facebookresearch/segment-anything.git /tmp/pip-req-build-qayi155l
  Resolved https://github.com/facebookresearch/segment-anything.git to commit 6fdee8f2727f4506cfbbe553e23b895e27956588
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: segment-anything
  Building wheel for segment-anything (setup.py) ... [?25l[?25hdone
  Created wheel for segment-anything: filename=segment_anything-1.0-py3-none-any.whl size=36588 sha256=82f4903ad9a86259d8a5b340e22942495f7213cfa5a40785d06b657c9a6c8597
  Stored in directory: /tmp/pip-ephem-wheel-cache-vz5zkoe5/wheels/10/cf/59/9ccb2f0a1bcc81d4fbd0e501680b5d088d690c6cfbc02dc99d
Successful

In [3]:
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import cv2
import os
import torch.nn.functional as F
from statistics import mean
from tqdm import tqdm
from torch.nn.functional import threshold, normalize
from segment_anything import SamPredictor, sam_model_registry, SamAutomaticMaskGenerator
import torch
from matplotlib.patches import Rectangle
from collections import defaultdict
import torch
from segment_anything.utils.transforms import ResizeLongestSide
import random

In [4]:
# Project root on Google Drive (use os.getcwd() instead when running outside Colab).
current_directory = r'/content/drive/MyDrive/Colab Notebooks/FibreAnalysis'
# Prepared training split. The trailing '' component leaves a trailing path separator,
# so a file name can be appended directly to these strings.
_prepared_set = os.path.join(current_directory, 'Data', 'Prepared', 'Set')
OutPreparedImages = os.path.join(_prepared_set, 'images', 'train', '')
OutPreparedMasks = os.path.join(_prepared_set, 'masks', 'train', '')

In [5]:
# Helper functions provided in https://github.com/facebookresearch/segment-anything/blob/9e8f1309c94f1128a6e5c047a10fdcb02fc8d651/notebooks/predictor_example.ipynb

def show_anns(anns):
    """Overlay automatic-mask-generator annotations on the current matplotlib axes.

    Adapted from SAM's automatic_mask_generator_example notebook.

    Args:
        anns: list of dicts, each with an 'area' and a 'segmentation' mask array that is
            used as a boolean index over the first two (H, W) dimensions.

    Masks are painted largest-area first, each with a random RGB colour at 35% opacity,
    so smaller masks end up on top. An empty list draws nothing.
    """
    if not len(anns):
        return
    by_area = sorted(anns, key=lambda ann: ann['area'], reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)

    height, width = by_area[0]['segmentation'].shape[:2]
    overlay = np.ones((height, width, 4))
    overlay[:, :, 3] = 0  # fully transparent until a mask paints a pixel
    for ann in by_area:
        overlay[ann['segmentation']] = np.concatenate([np.random.random(3), [0.35]])
    ax.imshow(overlay)


def show_masks(masks, ax, random_color=False):
    """Draw every mask in ``masks`` on ``ax`` as a semi-transparent RGBA overlay.

    Each mask is reshaped to (h, w) using its last two dimensions. With
    ``random_color`` every mask gets its own random RGB at 60% opacity; otherwise all
    masks share the blue (30, 144, 255) at 60% opacity.
    """
    default_rgba = np.array([30/255, 144/255, 255/255, 0.6])
    for m in masks:
        rgba = (np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
                if random_color else default_rgba)
        rows, cols = m.shape[-2:]
        ax.imshow(m.reshape(rows, cols, 1) * rgba.reshape(1, 1, -1))

def show_mask(mask, ax, random_color=False):
    """Draw a single mask on ``ax`` as a semi-transparent RGBA overlay.

    The mask is reshaped to (h, w) from its last two dimensions. ``random_color``
    picks a random RGB at 60% opacity; otherwise blue (30, 144, 255) at 60% opacity.
    """
    rgba = (np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
            if random_color
            else np.array([30/255, 144/255, 255/255, 0.6]))
    rows, cols = mask.shape[-2:]
    overlay = mask.reshape(rows, cols, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)

def show_box(box, ax):
    """Outline an XYXY box (x_min, y_min, x_max, y_max) on ``ax`` in green with a transparent face."""
    x_min, y_min, x_max, y_max = box[0], box[1], box[2], box[3]
    ax.add_patch(plt.Rectangle((x_min, y_min), x_max - x_min, y_max - y_min,
                               edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))

def show_boxes(bboxes, ax):
    """Outline each XYXY box in ``bboxes`` on ``ax`` as an unfilled red rectangle."""
    for x_min, y_min, x_max, y_max in bboxes:
        ax.add_patch(plt.Rectangle((x_min, y_min), x_max - x_min, y_max - y_min,
                                   fill=False, edgecolor='red', linewidth=2))

### Prepare Model

In [8]:
# SAM checkpoints: https://github.com/facebookresearch/segment-anything#model-checkpoints
# model_type must match the checkpoint: 'vit_b' (base), 'vit_l' (large) or 'vit_h' (huge);
# 'vit_h' is already the largest of the three.
model_type = 'vit_h'
# Downloaded by the setup cell's wget into the working directory. A Drive copy can be used
# instead: r'/content/drive/MyDrive/Colab Notebooks/FibreAnalysis/Data/sam_vit_h_4b8939.pth'
checkpoint = r'/content/sam_vit_h_4b8939.pth'

# Prefer the GPU; training on the CPU works but is very slow.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

sam_model = sam_model_registry[model_type](checkpoint=checkpoint).to(device)
sam_model.train()

Sam(
  (image_encoder): ImageEncoderViT(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 1280, kernel_size=(16, 16), stride=(16, 16))
    )
    (blocks): ModuleList(
      (0-31): 32 x Block(
        (norm1): LayerNorm((1280,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=1280, out_features=3840, bias=True)
          (proj): Linear(in_features=1280, out_features=1280, bias=True)
        )
        (norm2): LayerNorm((1280,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (lin1): Linear(in_features=1280, out_features=5120, bias=True)
          (lin2): Linear(in_features=5120, out_features=1280, bias=True)
          (act): GELU(approximate='none')
        )
      )
    )
    (neck): Sequential(
      (0): Conv2d(1280, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): LayerNorm2d()
      (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (3): LayerNorm2d

### Build ground truth bounding boxes and object masks from the image masks

In [11]:

# For each sampled mask file, build one XYXY bounding box per connected object (outer
# contour) plus a matching per-object ground-truth mask. bbox_coords[k][i] and
# ground_truth_masks[k][i] describe the same object.
num_mask_samples = 100  # mask files sampled per run; set to None to use every file

bbox_coords = {}
ground_truth_masks = {}
mask_files = sorted(Path(OutPreparedMasks).iterdir())
if num_mask_samples is not None:
    # Clamp so a folder with fewer files than requested does not make random.sample raise.
    mask_files = random.sample(mask_files, min(num_mask_samples, len(mask_files)))

for f in mask_files:
    k = f.stem
    mask = cv2.imread(f.as_posix(), cv2.IMREAD_GRAYSCALE)
    if mask is None:  # cv2.imread returns None for unreadable / non-image files; skip them
        continue
    # Any non-zero pixel is foreground.
    foreground = (mask > 0).astype(np.uint8)
    contours, _ = cv2.findContours(foreground, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        continue

    boxes = []
    object_masks = []
    for contour in contours:
        x, y, w, h = cv2.boundingRect(contour)
        boxes.append(np.array([x, y, x + w, y + h]))
        # The target for this box prompt is only the object the box was drawn around,
        # not every object in the image: fill this contour, then intersect with the
        # foreground so holes inside the object stay background.
        filled = np.zeros_like(foreground)
        cv2.drawContours(filled, [contour], -1, 1, thickness=cv2.FILLED)
        object_masks.append((filled > 0) & (foreground > 0))

    bbox_coords[k] = boxes
    ground_truth_masks[k] = object_masks


In [12]:
############################################################################################
# Preprocess the images for use in the SAM Model
############################################################################################

# Built once and kept at module level: the training loop also uses `transform` to rescale
# the box prompts into the same resized-image frame.
transform = ResizeLongestSide(sam_model.image_encoder.img_size)

transformed_data = defaultdict(dict)
for k in bbox_coords.keys():
    # The image file name is the mask key with its first 4 characters replaced by 'image'
    # (presumably 'maskNNN' -> 'imageNNN.png'; confirm against the data set).
    image_path = os.path.join(OutPreparedImages, f'image{k[4:]}.png')
    image = cv2.imread(image_path)
    if image is None:  # cv2.imread returns None instead of raising on a missing file
        raise FileNotFoundError(f'Could not read image for mask {k!r}: {image_path}')
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Resize the longest side to the encoder size, then HWC -> 1 x C x H x W on `device`.
    input_image = transform.apply_image(image)
    input_image_torch = torch.as_tensor(input_image, device=device)
    transformed_image = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]

    # SAM's own preprocessing of the encoder input.
    input_image = sam_model.preprocess(transformed_image)
    original_image_size = image.shape[:2]
    input_size = tuple(transformed_image.shape[-2:])

    transformed_data[k]['image'] = input_image
    transformed_data[k]['input_size'] = input_size
    transformed_data[k]['original_image_size'] = original_image_size

############################################################################################

############################################################################################

### Set up hyperparameters

In [13]:
# Only the mask decoder's parameters are handed to the optimizer, so only the decoder is
# fine-tuned. Hyperparameter tuning will improve performance here.
# NOTE(review): lr=1e-2 looks high for Adam on pretrained weights; consider starting
# around 1e-4 and tuning from there.
lr = 1e-2
wd = 0
optimizer = torch.optim.Adam(sam_model.mask_decoder.parameters(), lr=lr, weight_decay=wd)

# BCEWithLogitsLoss applies the sigmoid internally, so it must be given raw mask logits.
loss_fn = torch.nn.BCEWithLogitsLoss()

# bbox_coords keys are already unique. Sorting (instead of list(set(...))) gives a stable
# order, so random.sample(keys, ...) is reproducible under a fixed seed; the iteration
# order of a set of strings changes between Python processes (hash randomisation).
keys = sorted(bbox_coords.keys())

## Fine-tune the model

In [16]:

# Fine-tune the mask decoder. The image and prompt encoders run under torch.no_grad(), and
# the optimizer only holds mask_decoder parameters.
num_epochs = 100
samples_per_epoch = 10  # images drawn (without replacement) per epoch
checkpoint_every = 10   # save the full state_dict every N epochs (epoch 0 included)
checkpoint_dir = r'/content/drive/MyDrive/Colab Notebooks/FibreAnalysis/SAMTraining'
os.makedirs(checkpoint_dir, exist_ok=True)

if not keys:
    raise ValueError('No training samples: bbox_coords is empty.')

losses = []
# The image encoder is never updated, so an image's embedding is the same every epoch:
# compute it once per image and reuse it instead of re-running the encoder each epoch.
image_embeddings = {}

for epoch in range(num_epochs):
    epoch_losses = []
    for k in random.sample(keys, min(samples_per_epoch, len(keys))):
        input_size = transformed_data[k]['input_size']
        original_image_size = transformed_data[k]['original_image_size']

        # No grad here as we don't want to optimize the encoders
        with torch.no_grad():
            if k not in image_embeddings:
                image_embeddings[k] = sam_model.image_encoder(transformed_data[k]['image'].to(device))
            image_embedding = image_embeddings[k]

            # One XYXY box prompt per object, rescaled into the resized-image frame.
            boxes = transform.apply_boxes(np.array(bbox_coords[k]), original_image_size)
            boxes_torch = torch.as_tensor(boxes, dtype=torch.float, device=device)
            sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
                points=None,
                boxes=boxes_torch,
                masks=None,
            )

        # One mask per box so each prediction pairs with exactly one ground-truth mask.
        low_res_masks, _ = sam_model.mask_decoder(
            image_embeddings=image_embedding,
            image_pe=sam_model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=False,
        )
        # Raw mask logits upscaled back to the original image resolution.
        pred_logits = sam_model.postprocess_masks(low_res_masks, input_size, original_image_size)

        # (num_boxes, 1, H, W) float targets in {0, 1}.
        gt_masks_tensor = torch.as_tensor(
            np.stack(ground_truth_masks[k])[:, None] > 0, dtype=torch.float32, device=device
        )
        if gt_masks_tensor.shape != pred_logits.shape:
            raise ValueError(
                f'{k}: ground-truth masks {tuple(gt_masks_tensor.shape)} do not match '
                f'predictions {tuple(pred_logits.shape)}'
            )

        # loss_fn (BCEWithLogitsLoss) needs the raw logits. Thresholding or normalizing them
        # first makes the prediction piecewise constant and leaves ~zero gradient.
        loss = loss_fn(pred_logits, gt_masks_tensor)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_losses.append(loss.item())

    losses.append(epoch_losses)
    print(f'EPOCH: {epoch}')
    print(f'Loss: {mean(epoch_losses)}')
    if epoch % checkpoint_every == 0:
        filename = os.path.join(checkpoint_dir, f'SAMModel_epoch_{epoch}.pt')
        torch.save(sam_model.state_dict(), filename)


RuntimeError: ignored

In [None]:
  # One-off manual checkpoint; `epoch` is only used to label the file name.
  epoch = 20
  PATH = r'/content/drive/MyDrive/Colab Notebooks/FibreAnalysis/SAMTraining'
  filename = os.path.join(PATH, f'SAMModel_epoch_test_{epoch}.pt')
  torch.save(sam_model.state_dict(), filename)

In [None]:
# Persist the full fine-tuned state_dict (every sub-module) to Drive.
# NOTE(review): the file has no extension; '.pt' / '.pth' would make its format obvious.
PATH = r'/content/drive/MyDrive/Colab Notebooks/FibreAnalysis/SAMModel1'
torch.save(obj=sam_model.state_dict(), f=PATH)

## Display losses

In [None]:
# One point per epoch: the mean of that epoch's per-sample losses.
mean_losses = list(map(mean, losses))

fig, ax = plt.subplots()
ax.plot(list(range(len(mean_losses))), mean_losses)
ax.set_title('Mean epoch loss')
ax.set_xlabel('Epoch Number')
ax.set_ylabel('Loss')
plt.show()

### Display a test image with the predicted masks

In [None]:

# Run SAM's automatic mask generator with the fine-tuned weights on one random image.
# NOTE(review): confirm this folder; the prepared training images live under
# Data/Prepared/Set/images/train.
img_dir = os.path.join(current_directory, 'train/val')
image_extensions = ('.png', '.jpg', '.jpeg', '.tif', '.tiff', '.bmp')
img_files = sorted(name for name in os.listdir(img_dir) if name.lower().endswith(image_extensions))
if not img_files:
    raise FileNotFoundError(f'No images found in {img_dir}')

mask_generator = SamAutomaticMaskGenerator(sam_model)  # build once, reuse for every image
for imageFile in random.sample(img_files, 1):  # Take one example
    # os.listdir returns bare names, so join them with the directory before reading.
    image_path = os.path.join(img_dir, imageFile)
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(f'Could not read image: {image_path}')
    # OpenCV loads BGR; SAM's predictor (default image_format 'RGB') and matplotlib expect RGB.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    masks = mask_generator.generate(image)

    plt.figure(figsize=(20, 20))
    plt.imshow(image)
    show_anns(masks)
    plt.axis('off')
    plt.show()
