In [None]:
%pip install segment-anything-py

In [45]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import torch
import torchvision
import sys
from tqdm import  tqdm
import torchvision.transforms as transforms

In [3]:
def show_mask(mask, ax, random_color=False):
    """Draw a segmentation mask on ``ax`` as a translucent RGBA overlay.

    Args:
        mask: Array whose last two dimensions are (height, width); it is
            reshaped to (height, width, 1), so it must hold exactly
            height * width elements.
        ax: Object exposing ``imshow`` (e.g. a matplotlib ``Axes``).
        random_color: If True, the RGB part of the overlay colour is drawn
            from ``np.random.random``; otherwise a fixed blue is used.
            The alpha channel is 0.6 in both cases.
    """
    alpha = np.array([0.6])
    if random_color:
        rgb = np.random.random(3)
    else:
        rgb = np.array([30/255, 144/255, 255/255])
    rgba = np.concatenate([rgb, alpha], axis=0)

    height, width = mask.shape[-2:]
    # Broadcasting (h, w, 1) against (1, 1, 4) yields an (h, w, 4) image
    # that is zero (fully transparent) wherever the mask is zero/False.
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)

def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on ``ax`` as stars: label 1 in green, label 0 in red.

    Args:
        coords: Array indexed as ``coords[:, 0]`` (x) and ``coords[:, 1]`` (y).
        labels: Array aligned with ``coords``; compared against 1 and 0 to
            build boolean selectors.
        ax: Object exposing ``scatter`` (e.g. a matplotlib ``Axes``).
        marker_size: Forwarded to ``scatter`` as ``s``.
    """
    # Select both groups up front, then draw label-1 points before label-0 points.
    groups = [
        (coords[labels == flag], colour)
        for flag, colour in ((1, 'green'), (0, 'red'))
    ]
    for selected, colour in groups:
        ax.scatter(
            selected[:, 0],
            selected[:, 1],
            color=colour,
            marker='*',
            s=marker_size,
            edgecolor='white',
            linewidth=1.25,
        )

def show_box(box, ax):
    """Outline a box on ``ax`` with an unfilled green rectangle.

    Args:
        box: Indexable of four numbers; ``box[0], box[1]`` is the anchor
            corner and ``box[2], box[3]`` the opposite corner (width and
            height are computed as their differences).
        ax: Object exposing ``add_patch`` (e.g. a matplotlib ``Axes``).
    """
    anchor = (box[0], box[1])
    width = box[2] - box[0]
    height = box[3] - box[1]
    outline = plt.Rectangle(
        anchor,
        width,
        height,
        edgecolor='green',
        facecolor=(0,0,0,0),  # fully transparent fill
        lw=2,
    )
    ax.add_patch(outline)

In [47]:
import os

## organizing the data in a format that SAM expects
# NOTE(review): these are absolute paths on one developer's machine; point them at your
# own checkout of the dataset before running.
path_to_labels = [
    'C:\\Users\\itsjo\\Documents\\repos\\assembly_glovebox_dataset\\data\\Labels\\Test_Subject_1\\id\\J\\Side_View',
    'C:\\Users\\itsjo\\Documents\\repos\\assembly_glovebox_dataset\\data\\Labels\\Test_Subject_1\\id\\J\\Top_View',
]
path_to_images = [
    'C:\\Users\\itsjo\\Documents\\repos\\assembly_glovebox_dataset\\data\\images\\Test_Subject_1\\id\\J\\Side_View',
    'C:\\Users\\itsjo\\Documents\\repos\\assembly_glovebox_dataset\\data\\images\\Test_Subject_1\\id\\J\\Top_View',
]

# os.listdir() returns entries in arbitrary order. The training loop pairs images[i] with
# masks[i] by position (zip), so each directory listing is sorted to make that pairing
# deterministic.
# NOTE(review): this assumes image and label file names sort in the same order within each
# view directory — confirm against the dataset layout.
images = [
    os.path.join(path, file)
    for path in path_to_images
    for file in sorted(os.listdir(path))
    if file.endswith('.png')
]
masks = [
    os.path.join(path, file)
    for path in path_to_labels
    for file in sorted(os.listdir(path))
]

# zip() silently truncates to the shorter list, so surface a count mismatch instead of hiding it.
if len(images) != len(masks):
    print(f"warning: found {len(images)} images but {len(masks)} label files; pairing will be truncated/misaligned")

In [None]:
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"

# Use the GPU when one is visible, otherwise fall back to CPU instead of failing in .to().
# NOTE(review): the recorded run of this notebook ended in a CUDA out-of-memory error on a
# 6 GiB GPU; a smaller registry entry (e.g. "vit_b" with its matching checkpoint) may be needed.
device = "cuda" if torch.cuda.is_available() else "cpu"

# here is the model
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
# Trailing semicolon keeps the notebook from printing the whole module tree as cell output.
sam.to(device=device);

In [15]:
# Wrap the loaded SAM model in a SamPredictor. Later cells call predictor.set_image(...)
# and reach the prompt encoder / mask decoder through predictor.model.
predictor = SamPredictor(sam)

In [52]:
from PIL import Image

image = cv2.imread("J_GL_0.0.png")
# cv2.imread returns None (it does not raise) when the file is missing or unreadable;
# fail here with a clear message rather than with an AttributeError on .shape below.
if image is None:
    raise FileNotFoundError("could not read image 'J_GL_0.0.png'")
image.shape[0]

1080

In [23]:
# not using the predictor because that doesn't allow us to compute gradients
def get_embeds(predictor, points=None, mask_input=None, return_logits=False, input_boxes=None):
    """Run the prompt encoder of ``predictor.model`` on the given prompts.

    An image must already have been set on ``predictor`` so that
    ``predictor.original_size`` is populated.

    Args:
        predictor: Object exposing ``original_size``,
            ``transform.apply_boxes_torch`` and ``model.prompt_encoder``.
        points: Forwarded unchanged to the prompt encoder as ``points``.
        mask_input: Forwarded unchanged to the prompt encoder as ``masks``.
        return_logits: Unused here; kept so existing call sites keep working.
        input_boxes: Optional box prompts. When given, they are rescaled with
            ``predictor.transform.apply_boxes_torch`` using the size of the
            image currently set on the predictor. When None, no box prompt
            is passed.

    Returns:
        Tuple ``(sparse_embeddings, dense_embeddings)`` as returned by the
        prompt encoder.
    """
    # Use the size recorded by the predictor rather than a notebook-global `image`,
    # so the result cannot silently depend on whichever image a previous cell loaded.
    transformed_boxes = None
    if input_boxes is not None:
        transformed_boxes = predictor.transform.apply_boxes_torch(
            input_boxes, predictor.original_size
        )

    # Embed prompts
    sparse_embeddings, dense_embeddings = predictor.model.prompt_encoder(
        points=points,
        boxes=transformed_boxes,
        masks=mask_input,
    )

    return sparse_embeddings, dense_embeddings

def get_prediction(predictor, sparse_embeddings, dense_embeddings, multimask_output=False, return_logits=False):
    """Decode masks from prompt embeddings using ``predictor.model``.

    Args:
        predictor: Object exposing ``features``, ``input_size``,
            ``original_size`` and ``model`` (with ``mask_decoder``,
            ``prompt_encoder.get_dense_pe``, ``postprocess_masks`` and
            ``mask_threshold``).
        sparse_embeddings: Passed to the decoder as ``sparse_prompt_embeddings``.
        dense_embeddings: Passed to the decoder as ``dense_prompt_embeddings``.
        multimask_output: Forwarded to the mask decoder.
        return_logits: If True, return the post-processed decoder output
            as-is; otherwise return it compared against
            ``predictor.model.mask_threshold``.

    Returns:
        The post-processed masks, thresholded unless ``return_logits`` is True.
    """
    model = predictor.model

    # Predict masks; the second decoder output (IoU predictions) is not used.
    coarse_masks, _iou = model.mask_decoder(
        image_embeddings=predictor.features,
        image_pe=model.prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse_embeddings,
        dense_prompt_embeddings=dense_embeddings,
        multimask_output=multimask_output,
    )

    # Upscale the masks to the original image resolution
    upscaled = model.postprocess_masks(coarse_masks, predictor.input_size, predictor.original_size)

    if return_logits:
        return upscaled
    return upscaled > model.mask_threshold

In [24]:
num_epochs = 100
# Only the mask decoder's parameters are handed to the optimizer.
optimizer = torch.optim.Adam(sam.mask_decoder.parameters())
loss_fn = torch.nn.CrossEntropyLoss() # MSE was used in the example
# Use the GPU when one is visible, otherwise fall back to CPU rather than hard-coding "cuda".
device = "cuda" if torch.cuda.is_available() else "cpu"

In [51]:
# Fine-tuning loop: for every (image, label) pair, embed the image with the frozen encoder,
# decode masks with gradients enabled, and take one optimizer step on the decoder.
losses = []
to_tensor = transforms.ToTensor()  # loop-invariant, so build it once

for epoch in range(num_epochs):
    epoch_losses = []
    # zip() has no length, so give tqdm an explicit total to get a real progress bar.
    for (x, y) in tqdm(zip(images, masks), total=min(len(images), len(masks))):

        # you'll have to adjust this method for allowing multiclass
        # one box per class prediction?
        # here they did backprop on each image - https://colab.research.google.com/drive/1F6uRommb3GswcRlPZWpkAQRMVNdVH7Ww?usp=sharing
        # how did MedSAM do it on the batch - https://github.com/bowang-lab/MedSAM/blob/main/finetune_and_inference_tutorial_auto_seg.ipynb
        # in the code of SAM they allow you to compute on the batch - https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/mask_decoder.py#L129

        # here, we see that SAM (even when batching) has to run an individual prediction per image anyways - https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/sam.py#L101-L131

        # Image embedding and prompt encoding are done without gradients.
        with torch.no_grad():
            # Read the image BEFORE building the boxes: the full-frame boxes depend on its size.
            image = cv2.imread(x)
            if image is None:
                raise FileNotFoundError(f"could not read image {x!r}")
            # cv2 loads BGR while the predictor defaults to RGB, so state the format explicitly.
            predictor.set_image(image, image_format="BGR")

            # define the boxes here
            # SAM box prompts are XYXY, so a full-frame box is [0, 0, width, height].
            height, width = image.shape[:2]
            input_boxes = torch.from_numpy(
                np.array(
                    [
                        [0, 0, width, height],
                        [0, 0, width, height],
                    ]
                )
            ).float().to(device)

            sparse_embeddings, dense_embeddings = get_embeds(predictor, input_boxes=input_boxes)

        # Ask for raw logits: the thresholded boolean masks carry no gradient.
        mask_preds = get_prediction(predictor, sparse_embeddings, dense_embeddings, return_logits=True)

        # NOTE(review): assumes the label image has at least 3 channels — confirm.
        mask = to_tensor(Image.open(y))
        mask = mask[0, :, :] + mask[1, :, :] + torch.mul(mask[2, :, :], 2)
        target = mask.long().unsqueeze(0).to(device)  # (1, H, W), same device as the logits

        # Treat each box's single-mask output as one class channel:
        # (num_boxes, 1, H, W) -> (1, num_boxes, H, W), the layout CrossEntropyLoss expects.
        # NOTE(review): this requires every target value to be < num_boxes; verify the label
        # encoding above against the number of boxes/classes before trusting the loss.
        logits = mask_preds.squeeze(1).unsqueeze(0)

        loss = loss_fn(logits, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_losses.append(loss.item())

    # Per-epoch bookkeeping belongs inside the epoch loop.
    losses.append(epoch_losses)
    print(f"mean loss is {np.average(epoch_losses)}")
    print(f'EPOCH: {epoch}')



0it [00:00, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 1024.00 MiB (GPU 0; 6.00 GiB total capacity; 3.60 GiB already allocated; 80.38 MiB free; 4.66 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF