In [1]:
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw

In [2]:
# Generate random points with Poisson-disc sampling (see https://www.jasondavies.com/poisson-disc/) and build a bounding box around each point
def generate_random_bounding_boxes(image_width, image_height, number_boxes, min_distance, box_percentage):
    """Sample box centres with Poisson-disc sampling and build one box per centre.

    Args:
        image_width: Image width in pixels.
        image_height: Image height in pixels.
        number_boxes: Maximum number of boxes to return. Fewer are returned when
            the sampler runs out of room before reaching this count.
        min_distance: Minimum distance (pixels) enforced between box centres.
        box_percentage: Box side as a fraction of the image side, applied to
            the width and the height separately.

    Returns:
        A list of ``[x_min, y_min, x_max, y_max]`` boxes. Each box is shifted so
        it stays inside the image whenever the box is not larger than the image.

    Raises:
        ValueError: If ``min_distance`` is not positive.
    """
    def generate_random_points(image_width, image_height, min_distance, num_points, k=30):
        """Return at most ``num_points`` points that are pairwise >= ``min_distance`` apart.

        ``k`` is the number of candidates tried around an active point before
        that point is retired from the active list.
        """
        if num_points <= 0:
            # Nothing requested: do not even seed the first point.
            return []
        if min_distance <= 0:
            raise ValueError("min_distance must be positive")

        # With this cell size a grid cell can hold at most one accepted point.
        cell_size = min_distance / np.sqrt(2)
        grid_width = int(np.ceil(image_width / cell_size))
        grid_height = int(np.ceil(image_height / cell_size))
        # grid[i, j] is the index into `points` of the point in that cell, or -1.
        grid = np.full((grid_width, grid_height), -1, dtype=np.int32)

        points = []
        active_points = []

        def generate_random_point():
            return np.random.uniform(0, image_width), np.random.uniform(0, image_height)

        def get_neighboring_cells(point):
            """Cells within two grid steps of the point's own cell, clipped to the grid."""
            x, y = point
            x_index = int(x / cell_size)
            y_index = int(y / cell_size)
            return [
                (i, j)
                for i in range(max(0, x_index - 2), min(grid_width, x_index + 3))
                for j in range(max(0, y_index - 2), min(grid_height, y_index + 3))
            ]

        def is_point_valid(point):
            """True when the point is inside the image and far enough from its neighbours."""
            x, y = point
            if x < 0 or y < 0 or x >= image_width or y >= image_height:
                return False

            for cell in get_neighboring_cells(point):
                neighbor_index = grid[cell]
                if neighbor_index != -1:
                    offset = np.array(points[neighbor_index]) - np.array(point)
                    if np.linalg.norm(offset) < min_distance:
                        return False
            return True

        def add_point(point):
            x, y = point
            points.append(point)
            grid[int(x / cell_size), int(y / cell_size)] = len(points) - 1
            active_points.append(point)

        add_point(generate_random_point())

        while active_points and len(points) < num_points:
            random_index = np.random.randint(len(active_points))
            random_point = active_points[random_index]
            added_new_point = False

            for _ in range(k):
                # Several of the k candidates can be valid; stop as soon as the
                # requested count is reached so the result never overshoots it.
                if len(points) >= num_points:
                    break
                angle = 2 * np.pi * np.random.random()
                # Candidates lie in the ring [min_distance, 2 * min_distance) around the active point.
                radius = min_distance + min_distance * np.random.random()
                new_point = (random_point[0] + radius * np.cos(angle), random_point[1] + radius * np.sin(angle))
                if is_point_valid(new_point):
                    add_point(new_point)
                    added_new_point = True

            if not added_new_point:
                # No room left around this point: retire it.
                active_points.pop(random_index)

        return points

    points = generate_random_points(image_width, image_height, min_distance, number_boxes)

    box_width = int(image_width * box_percentage)
    box_height = int(image_height * box_percentage)

    bounding_boxes = []
    for point in points:
        # Centre the box on the sampled point.
        x = int(point[0] - box_width / 2)
        y = int(point[1] - box_height / 2)

        # Adjust the coordinates to keep the bounding box within the image
        x = max(0, min(x, image_width - box_width))
        y = max(0, min(y, image_height - box_height))

        bounding_boxes.append([x, y, x + box_width, y + box_height])

    return bounding_boxes

In [3]:
def generate_bounding_boxes(image_width, image_height, box_size):
    """Tile the image with overlapping square boxes on a regular grid.

    Boxes are ``box_size`` pixels on a side and are placed every
    ``int(0.6 * box_size)`` pixels (at least 1), so neighbouring boxes overlap.

    Args:
        image_width: Image width in pixels.
        image_height: Image height in pixels.
        box_size: Side length of each box in pixels; must be at least 1.

    Returns:
        A list of ``(start_x, start_y, end_x, end_y)`` tuples, ordered column
        by column (x outer, y inner). Empty when the box does not fit in the image.

    Raises:
        ValueError: If ``box_size`` is smaller than 1.
    """
    if box_size < 1:
        raise ValueError("box_size must be at least 1")

    # int(0.6 * box_size) is 0 for box_size == 1, which would divide by zero below.
    stride = max(1, int(0.6 * box_size))
    num_boxes_horizontal = (image_width - box_size) // stride + 1
    num_boxes_vertical = (image_height - box_size) // stride + 1
    # NOTE(review): only boxes that fit entirely are counted, so a strip narrower
    # than one stride can stay uncovered at the right/bottom edge — confirm this is intended.

    bounding_boxes = []
    for i in range(num_boxes_horizontal):
        start_x = i * stride
        # Safeguard: never let a box extend beyond the image boundary.
        end_x = min(start_x + box_size, image_width)
        for j in range(num_boxes_vertical):
            start_y = j * stride
            end_y = min(start_y + box_size, image_height)
            bounding_boxes.append((start_x, start_y, end_x, end_y))

    return bounding_boxes

In [4]:
# Read the image with PIL and force 3-channel RGB. A PNG can be RGBA, palette or
# grayscale, but the cells below transpose an H x W x C array and feed it to an
# encoder whose first convolution takes 3 input channels.
image = Image.open('4389.png').convert('RGB')

In [5]:
# Box prompts: a regular grid of boxes whose side is a quarter of the image width.
# Alternative kept for reference (Poisson-disc sampled boxes):
# bbox = generate_random_bounding_boxes(image.size[0], image.size[1], 50, image.size[0]/5, 0.25)
bbox = generate_bounding_boxes(*image.size, box_size=int(image.size[0] * 0.25))

In [6]:
from modeling.segment_anything.utils.transforms import ResizeLongestSide
from modeling.segment_anything import prepare_sam
from config import cfg  # Import the default config file
import torch
import torch.nn.functional as F
from torch.nn.functional import threshold, normalize
import cv2

In [8]:
# Freeze the config (presumably a yacs-style node, so no further edits — confirm) and
# print the effective values so the run is documented.
# NOTE(review): the execution counter jumps from 6 to 8 here, so whatever ran as
# cell 7 (possibly a step merging overrides into cfg) is not in the notebook — confirm
# the printed values are reproducible from a fresh kernel.
cfg.freeze()
print(cfg)

BBOX:
  BOX_LIMITER: 100
  MIN_DISTANCE: 50
  NUMBER: 30
  SIZE_REF: 0.25
DATALOADER:
  NUM_WORKERS: 8
  TRAIN_DATA: 0.8
  VALID_DATA: 0.2
DATASETS:
  ROOT_DIR: /home/aghosh57/Kerner-Lab/all_dataset/
INPUT:
  
LOGGER:
  LEVEL: INFO
LOSS:
  DICE_LOSS_WEIGHT: 1
  FOCAL_LOSS_WEIGHT: 5
MASKS:
  MIN_AREA: 50
MODEL:
  CHECKPOINT: /home/aghosh57/Kerner-Lab/SAM-FineTuning/logs/Jun26_22-55-27/model_checkpoints/sam_checkpoint_1.pth
  DEVICE: cuda
  SAVE_INTERVAL: 1
  TYPE: base
OUTPUT_DIR: ./logs/
SOLVER:
  ITEMS_PER_BATCH: 1
  MAX_EPOCHS: 1
  MIN_LR: 1e-06
  START_LR: 0.01
  WEIGHT_DECAY: 0.0001
TEST:
  ITEMS_PER_BATCH: 6
VALID:
  ITEMS_PER_BATCH: 8


In [9]:
# Build the SAM model from the configured checkpoint and move it to the configured device.
# (prepare_sam is project code; whether it downloads a missing checkpoint is not visible here.)
device = cfg.MODEL.DEVICE
model = prepare_sam(
    checkpoint=cfg.MODEL.CHECKPOINT,
    model_type=cfg.MODEL.TYPE,
)
model.to(device)

Sam(
  (image_encoder): ImageEncoderViT(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (blocks): ModuleList(
      (0-11): 12 x Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (proj): Linear(in_features=768, out_features=768, bias=True)
        )
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (lin1): Linear(in_features=768, out_features=3072, bias=True)
          (lin2): Linear(in_features=3072, out_features=768, bias=True)
          (act): GELU(approximate='none')
        )
      )
    )
    (neck): Sequential(
      (0): Conv2d(768, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): LayerNorm2d()
      (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (3): LayerNorm2d()
    )


In [10]:
# Freeze every parameter of the model: the cells shown here only run inference,
# so no part of the model (mask decoder included) is left trainable.
for param in model.parameters():
    param.requires_grad = False

In [11]:
# Resize the image so its longest side matches the encoder input size, then let
# model.preprocess build the model input batch.
# The box scale factor is derived from the same encoder size the transform uses
# (instead of a hard-coded 1024) so the prompts cannot drift out of sync with the image.
scale_factor = model.image_encoder.img_size / max(image.size)

sam_transform = ResizeLongestSide(model.image_encoder.img_size)
resize_img = sam_transform.apply_image(np.array(image))
# H x W x C -> C x H x W, as expected by the model
resize_img = torch.as_tensor(resize_img.transpose(2, 0, 1)).to(device)
resize_img = model.preprocess(resize_img[None,:,:,:]) # expected (1, 3, img_size, img_size) — TODO confirm

# Scale the box prompts from original-image pixels to resized-image pixels
bbox_prompts = np.around(np.array(bbox) * scale_factor)

bbox_prompts = torch.as_tensor(bbox_prompts).to(device)

In [12]:
# Switch the model to evaluation mode before running inference.
model.eval()

Sam(
  (image_encoder): ImageEncoderViT(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (blocks): ModuleList(
      (0-11): 12 x Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (proj): Linear(in_features=768, out_features=768, bias=True)
        )
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (lin1): Linear(in_features=768, out_features=3072, bias=True)
          (lin2): Linear(in_features=3072, out_features=768, bias=True)
          (act): GELU(approximate='none')
        )
      )
    )
    (neck): Sequential(
      (0): Conv2d(768, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): LayerNorm2d()
      (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (3): LayerNorm2d()
    )


In [13]:
# Run the image encoder once; the single embedding is passed to the mask decoder
# together with all box prompts. no_grad: inference only, no autograd graph is kept.
with torch.no_grad():
    image_embeddings = model.image_encoder(resize_img)  # (B,256,64,64)

In [14]:
# Encode the scaled box prompts; no point prompts and no mask prompts are supplied.
sparse_embeddings, dense_embeddings = model.prompt_encoder(
    points=None,
    boxes=bbox_prompts,
    masks=None,
)

In [15]:
# Decode one set of masks per box prompt from the shared image embedding.
# NOTE(review): multimask_output=True is requested, yet the saved iou_predictions
# output has a single column per box — confirm how the project's decoder treats this flag.
low_res_masks, iou_predictions = model.mask_decoder(
    image_embeddings=image_embeddings.to(device),  # (B, 256, 64, 64)
    image_pe=model.prompt_encoder.get_dense_pe(),  # (1, 256, 64, 64)
    sparse_prompt_embeddings=sparse_embeddings,  # (B, 2, 256)
    dense_prompt_embeddings=dense_embeddings,  # (B, 256, 64, 64)
    multimask_output=True,
)

In [19]:
# Inspect the predicted quality score for each box prompt.
# NOTE(review): the saved output shows the identical score (0.5409) for all 36 boxes,
# which suggests the prompts barely influence the decoder — worth investigating.
iou_predictions

tensor([[0.5409],
        [0.5409],
        [0.5409],
        [0.5409],
        [0.5409],
        [0.5409],
        [0.5409],
        [0.5409],
        [0.5409],
        [0.5409],
        [0.5409],
        [0.5409],
        [0.5409],
        [0.5409],
        [0.5409],
        [0.5409],
        [0.5409],
        [0.5409],
        [0.5409],
        [0.5409],
        [0.5409],
        [0.5409],
        [0.5409],
        [0.5409],
        [0.5409],
        [0.5409],
        [0.5409],
        [0.5409],
        [0.5409],
        [0.5409],
        [0.5409],
        [0.5409],
        [0.5409],
        [0.5409],
        [0.5409],
        [0.5409]], device='cuda:0')

In [18]:
# Inspect the raw decoder output before upscaling.
# NOTE(review): the saved output shows negative values in the tens of thousands and
# nearly identical tensors across boxes; consider showing .shape / summary stats
# instead of dumping the full tensor.
low_res_masks

tensor([[[[-54420.0977, -56006.5117, -32643.9688,  ..., -56006.5117,
           -32643.8242, -37241.1289],
          [-16110.6406, -33661.3164,  -2513.8547,  ..., -33661.3164,
            -2513.8315, -18808.5117],
          [-51992.5039, -56582.9531, -13255.1260,  ..., -56582.7109,
           -13254.9395, -13576.3770],
          ...,
          [-16110.6406, -33661.3164,  -2513.8682,  ..., -33661.2227,
            -2513.6069, -18808.0977],
          [-51992.5000, -56582.9492, -13254.7637,  ..., -56582.5273,
           -13254.4746, -13575.9336],
          [ -3193.0127, -31927.9219,  -9549.2617,  ..., -31927.8945,
            -9549.1865, -16216.1758]]],


        [[[-54420.0234, -56006.4258, -32643.9199,  ..., -56006.4258,
           -32643.7305, -37241.0664],
          [-16110.6172, -33661.2695,  -2513.8540,  ..., -33661.2695,
            -2513.7273, -18808.4629],
          [-51992.4297, -56582.8867, -13255.1289,  ..., -56582.8242,
           -13254.8418, -13576.2676],
          ...,
   

In [16]:
# Upscale the low-resolution masks back to the original image and binarise them.
# Assumes the project's Sam keeps the upstream segment-anything signature
# postprocess_masks(masks, input_size, original_size) with (height, width) tuples — TODO confirm.
# PIL reports size as (width, height), so swap it into (height, width).
original_size = (image.size[1], image.size[0])
# Size of the resized image (longest side = encoder size) before it was brought to the
# square model input; a fixed (1024, 1024) is only right for square images.
input_size = ResizeLongestSide.get_preprocess_shape(
    original_size[0], original_size[1], model.image_encoder.img_size
)
upscaled_masks = model.postprocess_masks(low_res_masks, input_size, original_size)
# Explicit threshold at 0: 1.0 where the mask value is positive, 0.0 elsewhere.
# (normalize(threshold(...)) gives the same result only when there is a single mask channel.)
high_res_masks = (upscaled_masks > 0.0).float()

In [17]:
# Inspect the upscaled masks (values before thresholding).
# NOTE(review): this dumps the full tensor; .shape or summary statistics would keep the output readable.
upscaled_masks

tensor([[[[-54420.0977, -54986.6758, -55224.9219,  ..., -33721.0078,
           -35599.2852, -37241.1289],
          [-40738.1445, -43340.9883, -46610.0469,  ..., -23445.0781,
           -27524.2402, -30658.0508],
          [-20834.2969, -26418.4453, -34111.7031,  ...,  -8028.7134,
           -15222.4521, -20478.7109],
          ...,
          [-46238.4883, -48848.8477, -51732.9180,  ..., -13834.1641,
           -13485.2227, -13975.3730],
          [-20620.8691, -27803.7812, -37957.8867,  ..., -12104.3496,
           -13701.5957, -15273.2598],
          [ -3193.0127, -13455.4814, -28506.3164,  ..., -10803.3076,
           -13835.1807, -16216.1758]]],


        [[[-54420.0234, -54986.5977, -55224.8359,  ..., -33720.9180,
           -35599.2109, -37241.0664],
          [-40738.0898, -43340.9219, -46609.9805,  ..., -23444.9863,
           -27524.1660, -30657.9922],
          [-20834.2676, -26418.4082, -34111.6523,  ...,  -8028.6167,
           -15222.3809, -20478.6582],
          ...,
   

In [None]:
# Bring the masks to host memory as a NumPy array (expected shape: num_boxes x H x W).
# Guarded so the cell can be re-run: after the first run high_res_masks is already
# an ndarray and has no .cpu().
if isinstance(high_res_masks, torch.Tensor):
    high_res_masks = high_res_masks.cpu().numpy()
# Drop only the singleton mask-channel axis. A blanket np.squeeze would also drop the
# box axis when there is a single box, and the cells below index the masks per box.
if high_res_masks.ndim == 4 and high_res_masks.shape[1] == 1:
    high_res_masks = high_res_masks[:, 0]

In [None]:
# Merge the per-box masks into one image by summing over the box axis.
# (np.stack on an array that already exists only makes a copy, so it is dropped.)
# NOTE(review): the grid boxes overlap, so this is a per-pixel count of masks rather
# than a strict 0/1 mask — clip or threshold it if a binary result is needed.
final_mask = np.sum(high_res_masks, axis=0)

In [None]:
# Outline every box prompt on the image.
# NOTE(review): this draws onto `image` itself, so the panels titled "Original Image"
# below include the outlines, and re-running the preprocessing cell after this one
# would feed the annotated image to the model.
draw = ImageDraw.Draw(image)  # one drawing context is enough; no need to rebuild it per box
for box in bbox:
    draw.rectangle(box, outline='red')

In [None]:
# Sanity check: shape of the per-box mask array used by the plotting cells below.
high_res_masks.shape

In [None]:
# Side-by-side view: the image, the merged mask, and one example per-box mask.
# The example index used to be a bare 21; clamp it so a run with fewer boxes
# does not raise IndexError.
example_index = min(21, len(high_res_masks) - 1)
fig, ax = plt.subplots(1, 3)
ax[0].imshow(image)
ax[0].set_title("Original Image")
ax[1].imshow(final_mask)
ax[1].set_title("Full Predicted Image")
ax[2].imshow(high_res_masks[example_index])
ax[2].set_title("Predicted Mask")
plt.show()

In [None]:
# Plot the masks and the image (Only For Visualization Purposes)
# One figure per box prompt: image, merged mask, and that box's mask.
for mask in high_res_masks:
    fig, ax = plt.subplots(1, 3)
    ax[0].imshow(image)
    ax[0].set_title("Original Image")
    ax[1].imshow(final_mask)
    ax[1].set_title("Full Predicted Image")
    ax[2].imshow(mask)
    ax[2].set_title("Predicted Mask")
    plt.show()
    # Release the figure so dozens of open figures do not pile up in memory.
    plt.close(fig)