## Set up

In [None]:
# Disable PyTorch's CUDA caching allocator (useful while debugging GPU memory usage below).
# This only takes effect as an environment variable set before CUDA is initialised; a plain
# Python assignment such as `PYTORCH_NO_CUDA_MEMORY_CACHING=1` would just create an unused variable.
import os
os.environ["PYTORCH_NO_CUDA_MEMORY_CACHING"] = "1"

from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import cv2
from torch import cuda
import os
import torch
import numpy as np
import random
np.set_printoptions(precision=15)

# Ensure deterministic behavior.
# A fixed integer seed is used on purpose: `hash(str)` is salted per interpreter process
# (PYTHONHASHSEED), so seeds derived from it change on every kernel restart, and
# `hash(...) % 2**32 - 1` can even evaluate to -1, which np.random.seed rejects.
SEED = 42
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False  # autotuning may pick different algorithms between runs
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Print the GPU status report from the nvidia-smi CLI (IPython shell escape, not Python).
!nvidia-smi

In [None]:
# import wandb
# # wandb.login()
# !wandb login --relogin
# try:
#     run = wandb.init(project="OM_AI_v1", name="fine-tuning2")
# except wandb.CommError as e:
#     print(f"Error: {e}")

In [None]:
# masks_dir = '/workspace/raid/OM_DeepLearning/XMM_OM_code/one_image_mask_test1/'

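This notebook runs a single-image fine-tuning experiment. (For debugging, the image and its masks are currently replaced by all-ones placeholders in the next cell.)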

The annotations are exported from Roboflow in COCO format (`_annotations.coco.json`). Each annotation is a dict with the keys:

* `id` - annotation id
* `image_id` - corresponding image id
* `category_id` - class id
* `bbox` - `[x_min, y_min, width, height]` of the bounding box (COCO order, i.e. XYWH; the code below names it `xyhw`)
* `area` - area of the bbox
* `segmentation` `[List[float]]` - polygon coordinates of the mask
* `iscrowd`

In [None]:
import json 
import cv2
import numpy as np
import plotly.express as px
from matplotlib.path import Path
import math

# Roboflow COCO export for the validation split.
input_dir = '/workspace/raid/OM_DeepLearning/XMM_OM_code/OM_sky_images-6/valid/'
json_file_path = f'{input_dir}_annotations.coco.json'

with open(json_file_path, 'r') as f:
    data = json.load(f)

# Real image loading is switched off for now; a constant 256x256x3 image of ones is used instead.
# image = cv2.imread(input_dir+data['images'][1]['file_name'], cv2.IMREAD_GRAYSCALE)
# image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# print(np.max(image))
image = np.ones((256, 256, 3))

# Work on the second entry of the export's image list.
image_id = data['images'][1]['id']
bbox_coords = {}

# Every annotation that belongs to the selected image.
masks = [annotation for annotation in data['annotations'] if annotation['image_id'] == image_id]
ground_truth_masks = {}

# Roboflow segmentations are polygon points, and should be converted to masks
def create_mask(points, image_size):
    """Rasterise a flat polygon into a binary mask.

    Args:
        points: flat coordinate list ``[x0, y0, x1, y1, ...]`` (one COCO segmentation polygon).
        image_size: shape of the output mask, e.g. ``image.shape[:2]`` (rows, cols).

    Returns:
        ``np.uint8`` array of shape ``image_size`` with 1 inside/on the polygon and 0 elsewhere.
        An empty ``points`` list yields an all-zero mask.

    Raises:
        ValueError: if ``points`` holds an odd number of coordinates.
    """
    mask = np.zeros(image_size, dtype=np.uint8)
    if len(points) == 0:
        return mask
    if len(points) % 2:
        raise ValueError(f"polygon needs an even number of coordinates, got {len(points)}")

    # (N, 2) vertex array; float -> int32 truncates toward zero exactly like the previous
    # np.array(list_of_tuples, dtype=np.int32) conversion.
    polygon = np.asarray(points, dtype=np.float64).reshape(-1, 2).astype(np.int32)
    cv2.fillPoly(mask, [polygon], 1)
    return mask

# Build the per-annotation training targets, keyed by annotation id.
for annotation in masks:
    ann_id = annotation['id']
    xyhw = annotation['bbox']  # currently unused while the placeholders below are active
    points = annotation['segmentation'][0]
    mask = create_mask(points, image.shape[:2])  # real polygon mask, not stored for now
    # Debug placeholders: a full-image mask and a full-image prompt box for every annotation.
    ground_truth_masks[ann_id] = np.ones((256, 256))
    bbox_coords[ann_id] = [0, 0, 255,255]

In [None]:
# Restrict training to the annotation with id 0 (raises KeyError if the image has no such annotation).
ground_truth_masks = {0:ground_truth_masks[0]}
bbox_coords =  {0:bbox_coords[0]}

In [None]:
# Inspect the ground-truth masks (annotation id -> mask array).
ground_truth_masks

In [None]:
import matplotlib.colors as mcolors
import numpy.ma as ma

def display_masks(masks):
    """Overlay every mask in the dict ``masks`` on the current pyplot axes.

    Zero pixels are masked out so only the foreground is tinted light blue, and a thin
    dark-violet contour outlines each mask.
    """
    overlay_cmap = mcolors.ListedColormap(['lightblue'])
    for single_mask in masks.values():
        foreground = ma.masked_where(single_mask == 0, single_mask)
        plt.imshow(foreground, alpha=0.4, cmap=overlay_cmap)
        plt.contour(single_mask, colors='darkviolet', linewidths=0.1)  # contour color

# Left panel: the image alone. Right panel: the same image with the ground-truth masks overlaid.
plt.figure(figsize=(10, 5))
for panel in (1, 2):
    plt.subplot(1, 2, panel)
    plt.imshow(image)
    if panel == 2:
        display_masks(ground_truth_masks)
plt.show()

In [None]:
# if xyhw[2] >2 or xyhw[3] >2:
#         fig = px.imshow(mask)
#         fig.update_layout(
#             title_text=f'{image_id}<br>bbox (x, y, h, w):({xyhw[0]}, {xyhw[1]}, {xyhw[2]}, {xyhw[3]})',
#             title_x=0.5, 
#             autosize=False,
#             width=700,
#             height=500
#         )        
#         fig.show()
#         # cv2.imwrite(masks_dir+f'{k}_mask{a}.png',  np.array(mask_).astype(int))
#         # # x1, y1, x2, y2 = xyhw[0]-2, xyhw[1]-2, xyhw[2]+ xyhw[0]+2, xyhw[3]+xyhw[1]+2
#         # x1, y1, x2, y2 = xyhw[0], xyhw[1], xyhw[2]+ xyhw[0], xyhw[3]+xyhw[1]
        
#         # bbox_coords[f'{k}_mask{a}'] = np.array([x1, y1, x2, y2])
#         # keys.append(f'{k}_mask{a}')

## Preprocess data

In [None]:
# ground_truth_masks = {}
# for k in bbox_coords.keys():
#   gt_grayscale = cv2.imread(f'{masks_dir}{k}.png', cv2.IMREAD_GRAYSCALE)
#   ground_truth_masks[k] = (gt_grayscale != 0) # was ==0

In [None]:
# Inspect the prompt boxes (annotation id -> [x_min, y_min, x_max, y_max], as drawn by show_box).
bbox_coords

In [None]:
# Distinct values in the mask of annotation 0 (the placeholder mask contains only ones).
np.unique(ground_truth_masks[0])

In [None]:
# test the masks
from PIL import Image
# for keyy in ground_truth_masks.keys():
#     img = Image.fromarray(ground_truth_masks[keyy])
#     print(keyy)
#     img.show() 

# masks 4,5,6,9 and 16 are bad. 
# set them to 0 and check the prediction.
# bad_masks = [4,5,6,9,16]

# for i in bad_masks:
#     ground_truth_masks[f"S0720251301_L_mask{i}"][:] = 0

In [None]:
def show_mask(mask, ax, random_color=False):
    """Draw one mask on ``ax`` as a semi-transparent RGBA overlay.

    Args:
        mask: array whose last two dimensions are (H, W); it is reshaped to (H, W, 1).
        ax: matplotlib axes to draw on.
        random_color: use a random RGB colour (alpha 0.6) instead of the default blue.
    """
    color = (
        np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
        if random_color
        else np.array([30/255, 144/255, 255/255, 0.6])
    )
    height, width = mask.shape[-2:]
    ax.imshow(mask.reshape(height, width, 1) * color.reshape(1, 1, -1))

def show_box(box, ax):
    """Draw an ``[x_min, y_min, x_max, y_max]`` box on ``ax`` as a green, unfilled rectangle."""
    x_min, y_min, x_max, y_max = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle(
        (x_min, y_min), x_max - x_min, y_max - y_min,
        edgecolor='green', facecolor=(0,0,0,0), lw=2,
    )
    ax.add_patch(outline)

In [None]:
# Show each ground-truth mask with its prompt box on top of the training image.
# Iterate over the dict keys (annotation ids) rather than range(len(...)): the ids are
# not guaranteed to be 0..n-1, in which case indexing by position raises KeyError.
for name in ground_truth_masks:
    plt.figure(figsize=(5,5))
    plt.imshow(image)
    show_box(bbox_coords[name], plt.gca())
    show_mask(ground_truth_masks[name], plt.gca())
    plt.axis('off')
    plt.show()

## 🚀 Prepare Mobile SAM Fine Tuning

In [None]:
import sys
import PIL
from PIL import Image

sys.path.append('/workspace/raid/OM_DeepLearning/XMM_OM_code/MobileSAM/')
import mobile_sam
from mobile_sam import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

model_type = "vit_t" # tiny version
mobile_sam_checkpoint = "/workspace/raid/OM_DeepLearning/XMM_OM_code/MobileSAM/weights/mobile_sam.pt"
# Use the GPU when PyTorch can see one.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("device:", device)

# Build MobileSAM from the pretrained checkpoint, move it to `device` and enable training mode
# (this also puts the image and prompt encoders in training mode).
mobile_sam_model = sam_model_registry[model_type](checkpoint=mobile_sam_checkpoint)
mobile_sam_model.to(device)
mobile_sam_model.train();

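**The pixel mean and std must be changed: OM images do not have the same per-channel mean/std as the natural images SAM's defaults were computed for, so they are recomputed from the training image below.**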

In [None]:
# Replace SAM's default normalisation statistics with per-channel statistics of the training image.
# `image` is (H, W, C): flatten the spatial axes and reduce over them for each channel.
channel_values = image.reshape(-1, image.shape[-1])
channel_mean = channel_values.mean(axis=0)
channel_std = channel_values.std(axis=0)
# A constant channel (e.g. the all-ones placeholder image) has std == 0, which would make
# the model's preprocess step divide by zero (inf/NaN); fall back to 1 for such channels.
channel_std = np.where(channel_std > 0, channel_std, 1.0)

pixel_mean = torch.as_tensor(channel_mean, dtype=torch.float, device=device)
pixel_std = torch.as_tensor(channel_std, dtype=torch.float, device=device)

# Overwrite the model's pixel_mean / pixel_std buffers with (C, 1, 1) tensors.
# persistent=False means they are not included in state_dict() (and so not in saved checkpoints).
mobile_sam_model.register_buffer("pixel_mean", pixel_mean.view(-1, 1, 1), False)
mobile_sam_model.register_buffer("pixel_std", pixel_std.view(-1, 1, 1), False)

In [None]:
# np.mean(image_T[0]), np.std(image_T[0])

In [None]:
def count_parameters(model):
    """Return the total number of scalar elements in ``model``'s trainable parameters."""
    trainable = filter(lambda tensor: tensor.requires_grad, model.parameters())
    return sum(map(lambda tensor: tensor.numel(), trainable))

# Note: the image/prompt encoders are only frozen in a later cell, so this counts parameters before freezing.
num_parameters = count_parameters(mobile_sam_model)
print(f"The model has {num_parameters} trainable parameters")

In [None]:
# wandb.watch(mobile_sam_model)

In [None]:
# def get_bounding_box(ground_truth_map):
#   # get bounding box from mask
#   y_indices, x_indices = np.where(ground_truth_map > 0)
#   x_min, x_max = np.min(x_indices), np.max(x_indices)
#   y_min, y_max = np.min(y_indices), np.max(y_indices)
#   # add perturbation to bounding box coordinates
#   H, W = ground_truth_map.shape
#   x_min = max(0, x_min - np.random.randint(0, 20))
#   # x_max = min(W, x_max + np.random.randint(0, 20))
#   # y_min = max(0, y_min - np.random.randint(0, 20))
#   # y_max = min(H, y_max + np.random.randint(0, 20))
#   w = x_max - x_min
#   h = y_max - y_min
#   bbox = [x_min, y_min, w, h]

#   return np.array(bbox)


Convert the input image into the format SAM's internal functions expect (resized by its longest side, normalized, and padded to 1024×1024).

In [None]:
# Shape of the training image, (H, W, C).
image.shape

In [None]:
# Preprocess the images
import os
from collections import defaultdict
import torch
from segment_anything.utils.transforms import ResizeLongestSide
from torchvision.transforms.functional import resize

# Directory of the raw images (not read in this cell while the placeholder image is active).
images_dir = "/workspace/raid/OM_DeepLearning/XMM_OM_dataset/scaled_raw/"
# transformed_data[k] -> {'image', 'input_size', 'original_image_size'} for image index k.
transformed_data = defaultdict(dict)
# NOTE(review): ResizeLongestSide.apply_image is documented to expect HxWxC uint8 input — confirm.
image = image.astype(np.uint8)

# Resize transform targeting the encoder's input resolution (image_encoder.img_size).
transform = ResizeLongestSide(mobile_sam_model.image_encoder.img_size)
k = 0
# Despite the name, this is True where the image is strictly positive; `~negative_mask` selects
# the non-positive pixels. Shape: (H, W, C) -> (C, H, W) -> (C, 1024, 1024) -> (1, C, 1024, 1024).
negative_mask = np.where(image > 0, True, False)
negative_mask = torch.from_numpy(negative_mask)  
negative_mask = negative_mask.permute(2, 0, 1)
negative_mask = resize(negative_mask, [1024, 1024], antialias=True) 
negative_mask = negative_mask.unsqueeze(0)

# scales the image to 1024x1024 by the longest side (it doesn't matter in my case because images are square)
input_image = transform.apply_image(image)
print(input_image.shape)
input_image_torch = torch.as_tensor(input_image, dtype=torch.float32, device=device)
# HWC -> CHW, plus a leading batch axis: (1, C, H, W).
transformed_image = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]

# normalization and padding
input_image = mobile_sam_model.preprocess(transformed_image)
# DEBUG: the preprocessed image is discarded and replaced by a constant tensor of ones.
input_image = torch.ones_like(input_image)
print(input_image)

original_image_size = image.shape[:2]
# (H, W) after resizing but before preprocess pads it.
input_size = tuple(transformed_image.shape[-2:])
# Zero the non-positive pixels. NOTE(review): negative_mask is on the CPU while input_image is on
# `device`; confirm mixed-device boolean indexing behaves as intended.
input_image[~negative_mask] = 0
# DEBUG: this overwrites the whole tensor with 1, undoing the masking on the previous line.
input_image[:] = 1
plt.imshow(input_image[0][0].cpu(), cmap='gray')
plt.show()
transformed_data[k]['image'] = input_image
transformed_data[k]['input_size'] = input_size
transformed_data[k]['original_image_size'] = original_image_size

# note: `del` only removes the names; e.g. `transformed_image` is a view of `input_image_torch`, so the GPU memory stays allocated while any reference remains
# del input_image_torch
# del transformed_image
# del input_image

In [None]:
# Inspect the preprocessed image and its size metadata.
transformed_data

In [None]:
# Release cached, unused GPU blocks, then print PyTorch's detailed CUDA allocator report.
torch.cuda.empty_cache() 

print(torch.cuda.memory_summary(device=device, abbreviated=False))

In [None]:
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR

lr = 3e-4
wd = 0.0
# Only the mask decoder's parameters are handed to the optimizer.
optimizer = torch.optim.Adam(mobile_sam_model.mask_decoder.parameters(), lr=lr, weight_decay=wd)

# NOTE(review): scheduler.step() is never called in this notebook, so the learning rate stays at `lr`.
scheduler = CosineAnnealingLR(optimizer, T_max=30, eta_min=1e-8) # not very helpful

def dice_loss(pred, target, area=None, smooth=1):
    """Soft Dice loss between predicted probabilities and a binary target.

    The Dice score is computed separately for every (batch, channel) slice over all trailing
    (spatial) dimensions, and ``1 - dice`` is averaged over those slices. ``pred`` and
    ``target`` only need to broadcast, e.g. a (1, 3, H, W) prediction against a (1, 1, H, W)
    target.

    Args:
        pred: tensor of shape (N, C, *spatial), expected to hold values in [0, 1].
        target: tensor broadcastable to ``pred``, expected to hold {0, 1} values.
        area: unused. It is kept (now optional) so callers that pass the prompt-box area keep
            working; an area-based penalty was tried and disabled.
        smooth: additive smoothing that avoids division by zero for empty masks.

    Returns:
        Scalar tensor with the mean Dice loss.

    Raises:
        ValueError: if ``pred`` has fewer than 3 dimensions.
    """
    if pred.dim() < 3:
        raise ValueError(f"dice_loss expects (N, C, *spatial) inputs, got shape {tuple(pred.shape)}")
    # Reduce over every spatial dimension at once, so 3-D and 5-D inputs work as well as 4-D ones.
    spatial_dims = tuple(range(2, pred.dim()))
    intersection = (pred * target).sum(dim=spatial_dims)
    total = pred.sum(dim=spatial_dims) + target.sum(dim=spatial_dims)
    loss = 1 - (2. * intersection + smooth) / (total + smooth)
    return loss.mean()

In [None]:
# GPU memory currently allocated / reserved by PyTorch, in MB.
torch.cuda.empty_cache()
print(torch.cuda.memory_allocated()/(1024**2)) #MB
print(torch.cuda.memory_reserved()/(1024**2))

## Snapshot model weights before tuning

In [None]:
# Copy every state_dict entry (parameters and buffers) so fine-tuning changes can be diffed later.
weights_before = {
    entry_name: tensor.clone()
    for entry_name, tensor in mobile_sam_model.state_dict().items()
}

In [None]:
# Freeze the image and prompt encoders; every other parameter (the mask decoder) stays trainable.
for name, param in mobile_sam_model.named_parameters():
    frozen = "image_encoder" in name or "prompt_encoder" in name
    param.requires_grad = not frozen

In [None]:
def check_requires_grad(model, show=False):
    """Print, for every named parameter of ``model``, whether it requires gradients.

    ``show`` is accepted for call compatibility but currently has no effect.
    """
    for param_name, tensor in model.named_parameters():
        if not tensor.requires_grad:
            print("❌ Param", param_name, " doesn't require grad.")
            continue
        print("✅ Param", param_name, " requires grad.")

In [None]:
# List which parameters will be updated (✅) and which are frozen (❌).
check_requires_grad(mobile_sam_model)

## Run fine tuning

In [None]:
from statistics import mean
from tqdm import tqdm
from torch.nn.functional import threshold, normalize
from matplotlib.colors import ListedColormap
import matplotlib.patches as patches
import copy

num_epochs = 45
losses = []              # mean loss of every epoch
mask_epoch_losses = {}
mask_loss = {}           # annotation id -> list of per-epoch losses

# Used during training only for its box-coordinate transform.
predictor = SamPredictor(mobile_sam_model)

cached = transformed_data[0]
input_image = cached['image'].clone().to(device)
input_size = cached['input_size']
original_image_size = cached['original_image_size']

# First-channel map of strictly positive pixels, shaped (1, 1, H, W) and moved to the device.
negative_mask = torch.from_numpy(image > 0).permute(2, 0, 1)[0][None, None].to(device)

# Train only the mask decoder.
# `mobile_sam_model.train()` also puts the frozen image and prompt encoders in training mode.
# Normalisation layers such as BatchNorm then use batch statistics and keep overwriting their
# running-stat buffers, which `requires_grad = False` does not protect.
# NOTE(review): the MobileSAM TinyViT encoder is expected to contain BatchNorm layers — confirm.
# Keep the frozen modules in eval mode so they really stay frozen.
mobile_sam_model.train()
mobile_sam_model.image_encoder.eval()
mobile_sam_model.prompt_encoder.eval()

# The input image never changes and the encoder is frozen (deterministic in eval mode), so
# the embedding is computed once instead of every epoch. No autograd graph is needed for it.
with torch.no_grad():
    image_embedding = mobile_sam_model.image_encoder(input_image)

for epoch in range(num_epochs):
    epoch_losses = []
    for k in tqdm(ground_truth_masks):
        prompt_box = np.array(bbox_coords[k])

        # Prompt encoding only uses frozen modules, so it needs no gradients either.
        with torch.no_grad():
            box = predictor.transform.apply_boxes(prompt_box, original_image_size)
            box_torch = torch.as_tensor(box, dtype=torch.float, device=device)
            box_torch = box_torch[None, :]

            # NOTE(review): the ground-truth mask itself is passed as a mask prompt, but the
            # inference cells below prompt with the box only. The decoder is therefore trained
            # on information it will not get at test time; consider masks=None.
            mask_input_torch = torch.as_tensor(ground_truth_masks[k], dtype=torch.float, device=device).unsqueeze(0)

            sparse_embeddings, dense_embeddings = mobile_sam_model.prompt_encoder(
                points=None,
                boxes=box_torch,
                masks=mask_input_torch,
            )

        low_res_masks, iou_predictions = mobile_sam_model.mask_decoder(
            image_embeddings=image_embedding,
            image_pe=mobile_sam_model.prompt_encoder.get_dense_pe(),  # positional encoding of the dense image grid
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=True,
        )

        # postprocess_masks resizes the low-res logits back to original_image_size
        # (despite the variable name, this is an upscaling step).
        downscaled_masks = mobile_sam_model.postprocess_masks(low_res_masks, input_size, original_image_size).to(device)
        # Differentiable soft mask: sigmoid of the logits relative to the model's threshold.
        binary_mask = torch.sigmoid(downscaled_masks - mobile_sam_model.mask_threshold)

        # Ground truth as a (1, 1, H, W) float {0, 1} tensor on the device.
        gt = ground_truth_masks[k]
        gt_mask_resized = torch.as_tensor(gt, device=device).reshape(1, 1, gt.shape[0], gt.shape[1])
        gt_binary_mask = (gt_mask_resized > 0).to(torch.float32)

        area = (prompt_box[2]-prompt_box[0])*(prompt_box[3]-prompt_box[1])
        loss = dice_loss(binary_mask, gt_binary_mask, area)
        mask_loss.setdefault(k, []).append(loss.item())

        epoch_losses.append(loss)
        torch.cuda.empty_cache()

    # One optimisation step per epoch, on the mean loss over all masks.
    one_loss = sum(epoch_losses)*1.0/len(epoch_losses)
    losses.append(one_loss.item())

    optimizer.zero_grad()
    one_loss.backward()
    optimizer.step()

    # Note: np.mean(losses) is the running mean over all epochs so far, not this epoch's loss (losses[-1]).
    print(f'EPOCH: {epoch}. Mean loss: {np.mean(losses)}')
    torch.cuda.empty_cache()
    print("Torch cuda memory allocated:" , torch.cuda.memory_allocated()/(1024**2)) #MB
    print("Torch cuda memory reserved:" , torch.cuda.memory_reserved()/(1024**2))
    print("_________________________")

In [None]:
# Value range of the first predicted soft mask from the last training step.
binary_mask[0][0].min(), binary_mask[0][0].max()

In [None]:
# Sanity check: with the all-ones image and target, the soft-mask values are expected to be (close to) 1.
np.unique(binary_mask.detach().cpu().numpy())

In [None]:
# Check the GPU memory after training: allocated / reserved by PyTorch, in MB.

torch.cuda.empty_cache()
print(torch.cuda.memory_allocated()/(1024**2)) #MB
print(torch.cuda.memory_reserved()/(1024**2))

In [None]:
# Save the fine-tuned weights together with the optimizer state so training can be resumed.
# NOTE(review): pixel_mean / pixel_std were re-registered with persistent=False, so the custom
# normalisation statistics are not part of this state_dict.
checkpoint = dict(
    model_state_dict=mobile_sam_model.state_dict(),
    optimizer_state_dict=optimizer.state_dict(),
)

torch.save(checkpoint, 'mobile_sam_model_checkpoint.pth')

In [None]:
# del mobile_sam_model

## Compare model weights after tuning

In [None]:
# run.finish()

# After fine-tuning
# print("After fine-tuning:")
# for name, param in mobile_sam_model.state_dict().items():
#     if not torch.all(torch.eq(weights_before[name], param)):
#         print(f'{name} has changed')
#         print('Old weights:', weights_before[name])
#         print('New weights:', param)

In [None]:
# bad_masks = []

# for keyy in mask_loss.keys():
#     bad_masks.append(mask_loss[keyy])

# items = list(mask_loss.items())

# fig, axes = plt.subplots(5, 4, figsize=(30, 30))

# for i, ax in enumerate(axes.flatten()):
#     x_values = np.arange(1, len(bad_masks[i]) + 1)
    
#     ax.plot(x_values, bad_masks[i])
#     # ax.set_title(f'{items[i][0].split("_")[2]}')
#     ax.set_xlabel('Index')
#     ax.set_ylabel('Value')
#     # ax.set_ylim(0, 1)

# plt.tight_layout()
# plt.show()

In [None]:
# Mean Dice loss recorded at the end of every epoch.
plt.plot([epoch_idx for epoch_idx in range(len(losses))], losses)
plt.title('Mean epoch loss \n mask with sigmoid')
plt.xlabel('Epoch Number')
plt.ylabel('Loss')
plt.savefig('loss_mask_sigmoid.png')
plt.show()

## Compare the tuned model to the Mobile SAM model

In [None]:
# Load up the model with default weights: an untouched MobileSAM copy in eval mode, used as the baseline.
# NOTE(review): this copy keeps the default pixel_mean/pixel_std, while the tuned model uses the
# statistics recomputed from the training image, so the two models normalise inputs differently.
sam_model_orig = sam_model_registry[model_type](checkpoint=mobile_sam_checkpoint)
sam_model_orig.to(device);
sam_model_orig.eval();

# GPU memory (MB) allocated / reserved after loading the baseline model.
print(torch.cuda.memory_allocated()/(1024**2)) #MB
print(torch.cuda.memory_reserved()/(1024**2))

In [None]:
# Set up predictors for both tuned and original models
# from segment_anything import sam_model_registry, SamPredictor
# (SamPredictor here is the class imported from mobile_sam in the model-loading cell.)
predictor_tuned = SamPredictor(mobile_sam_model)
predictor_original = SamPredictor(sam_model_orig)

In [None]:
# GPU memory (MB) allocated / reserved after creating both predictors.
torch.cuda.empty_cache()
print(torch.cuda.memory_allocated()/(1024**2)) #MB
print(torch.cuda.memory_reserved()/(1024**2))

In [None]:
# # remove gradients requirement on the model

# for name, param in mobile_sam_model.named_parameters():
#         param.requires_grad = False

# mobile_sam_model.eval();

In [None]:
# Predict a mask for the same box prompt with the tuned and the original MobileSAM, converting
# the logits into sigmoid probabilities (as in training), and collect them per image.

torch.cuda.empty_cache()
print('Before inference:')
print(torch.cuda.memory_allocated()/(1024**2)) #MB
print(torch.cuda.memory_reserved()/(1024**2))

images_from_bbox_tuned = []
images_from_bbox_orig = []
k=0

# Training is finished: run the tuned model in eval mode for inference.
mobile_sam_model.eval()

if True:
# for k in bbox_coords.keys():

    predictor_tuned.set_image(image)
    predictor_original.set_image(image)

    input_bbox = np.array(bbox_coords[k])
    with torch.no_grad():
        # return_logits=True: without it predict() returns already-thresholded boolean masks,
        # and a sigmoid of 0/1 values is just 0.5/0.73 rather than a probability map.
        masks_tuned, _, _ = predictor_tuned.predict(
            point_coords=None,
            box=input_bbox,
            multimask_output=False,
            return_logits=True,
        )

        masks_orig, _, _ = predictor_original.predict(
            point_coords=None,
            box=input_bbox,
            multimask_output=False,
            return_logits=True,
        )

    # SamPredictor.predict returns NumPy arrays; torch.sigmoid needs tensors.
    masks_tuned = torch.sigmoid(torch.as_tensor(masks_tuned) - mobile_sam_model.mask_threshold)
    masks_orig = torch.sigmoid(torch.as_tensor(masks_orig) - mobile_sam_model.mask_threshold)

    print('In loop:')

    torch.cuda.empty_cache()
    print(torch.cuda.memory_allocated()/(1024**2)) #MB
    print(torch.cuda.memory_reserved()/(1024**2))

    images_from_bbox_tuned.append(masks_tuned.detach().cpu().numpy())
    images_from_bbox_orig.append(masks_orig.detach().cpu().numpy())

In [None]:
# Masks from the tuned model collected in the previous cell (one entry per prompt box).
images_from_bbox_tuned

In [None]:
# Annotation ids that have a prompt box.
bbox_coords.keys()

In [None]:
# for plotting
# image = 255 - image

In [None]:
%matplotlib inline

import numpy.ma as ma
import matplotlib.colors as mcolors

def show_mask(masks, ax, random_color=False):
    """Overlay one or more masks on ``ax`` as semi-transparent RGBA images.

    This redefinition shadows the single-mask ``show_mask`` defined earlier in the notebook,
    so it accepts both forms to keep earlier cells working when they are re-run.

    Args:
        masks: a single 2-D (H, W) mask array, or an iterable of masks whose last two
            dimensions are (H, W), e.g. a list of (1, H, W) arrays.
        ax: matplotlib axes to draw on.
        random_color: draw each mask in a random RGB colour (alpha 0.6) instead of the default blue.
    """
    # A bare 2-D array would otherwise be iterated row by row, and each 1-D row breaks the reshape.
    if isinstance(masks, np.ndarray) and masks.ndim == 2:
        masks = [masks]
    for mask in masks:
        if random_color:
            color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
        else:
            color = np.array([30/255, 144/255, 255/255, 0.6])
        h, w = mask.shape[-2:]
        mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
        ax.imshow(mask_image)
    
# Tuned vs. untuned MobileSAM masks on the training image, plus the positive-pixel map.
fig_mobile, axs = plt.subplots(1, 3, figsize=(25, 25))

negative_mask_img = np.transpose((image>0).astype(int), (2, 0, 1))
masks_tuned = images_from_bbox_tuned
masks_orig = images_from_bbox_orig

mask_panels = [
    (masks_tuned, f'Mask with Mobile Tuned Model\n #epochs = {num_epochs}, lr={lr}, loss=dice_loss\n mask with sigmoid'),
    (masks_orig, 'Mask with Mobile Untuned Model'),
]
for panel_ax, (panel_masks, panel_title) in zip(axs[:2], mask_panels):
    panel_ax.imshow(image)
    show_mask(panel_masks, panel_ax)
    panel_ax.set_title(panel_title, fontsize=18)

# First channel of the channel-first positive-pixel map.
axs[2].imshow(negative_mask_img[0], cmap='gray')
axs[2].set_title('Negative pixels map', fontsize=18)

plt.savefig('MobileSAM_output_one_image_sigmoid.png', dpi=400)

plt.show()
plt.close()

In [None]:
# %matplotlib inline

# import matplotlib.backends.backend_pdf

# pdf = matplotlib.backends.backend_pdf.PdfPages("fine_tuned_mobile_sam_images.pdf")

# i=0
# for k in images_from_bbox_tuned:
#     if i<50:
#         i+=1
#     else:
#         break
#     _, axs = plt.subplots(1, 2, figsize=(25, 25))
    
#     image = cv2.imread(f'{images_dir}{k}.png')
#     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
#     masks_tuned = images_from_bbox_tuned[k]
#     masks_orig = images_from_bbox_orig[k]
    
#     axs[0].imshow(image)
#     display_masks(axs[0], masks_tuned)
#     # show_mask(mask, axs[0])
#     # show_box(input_bbox, axs[0])
#     axs[0].set_title('Mask with Tuned Model', fontsize=26)
#     axs[0].axis('off')

#     axs[1].imshow(image)
#     display_masks(axs[1],masks_orig)
#     # show_mask(mask, axs[1])
#     # show_box(input_bbox, axs[1])
#     axs[1].set_title('Mask with Untuned Model', fontsize=26)
#     axs[1].axis('off')
#     pdf.savefig(_)
#     plt.close(_)
# pdf.close()

## Compare the tuned model to the original SAM model

In [None]:
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
# !wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

# Import before use: previously `os.getcwd()` ran before this cell's `import os` and only worked
# because an earlier cell had already imported os.
import os

HOME = os.getcwd()

# Original SAM (ViT-H) checkpoint, expected under ./weights/.
origSAM_CHECKPOINT_PATH = os.path.join(HOME, "weights", "sam_vit_h_4b8939.pth")
print(origSAM_CHECKPOINT_PATH, "; exist:", os.path.isfile(origSAM_CHECKPOINT_PATH))

origMODEL_TYPE = "vit_h"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the original SAM with segment_anything's registry (imported at the top of this cell) in eval mode.
sam = sam_model_registry[origMODEL_TYPE](checkpoint=origSAM_CHECKPOINT_PATH).to(device=device)
sam.eval();

In [None]:
# SamPredictor now refers to segment_anything's class (re-imported at the top of this section).
predictor_origSAM = SamPredictor(sam)

In [None]:
# Predict with the original (vit_h) SAM for the same box prompt; masks are stored per annotation id.
images_from_bbox_orig_SAM = {}

# Set k explicitly instead of relying on the value left over from earlier cells.
k = 0
if True:
# for k in keys:
    # image = cv2.imread(f'{images_dir}{k.split("_")[0]+"_"+k.split("_")[1]}.png')
    # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    input_bbox = np.array(bbox_coords[k])
    # Only the original-SAM predictor is used in this cell, so only it needs to embed the image
    # (the previous extra predictor_tuned.set_image call was a wasted encoder pass).
    predictor_origSAM.set_image(image)

    with torch.no_grad():
        masks_orig, _, _ = predictor_origSAM.predict(
            point_coords=None,
            box=input_bbox,
            multimask_output=False,
        )

    images_from_bbox_orig_SAM.setdefault(k, []).append(masks_orig)

In [None]:
# Tuned MobileSAM masks vs. the original SAM (vit_h) masks for annotation `k`, plus the
# positive-pixel map. (Loop over `images_from_bbox_tuned` here to plot several images.)
fig_sam, axs = plt.subplots(1, 3, figsize=(25, 25))

# Channel-first boolean map of positive pixels; only the first channel is shown.
negative_mask_img = np.transpose(image>0, (2, 0, 1))
# `masks_tuned` is the list of tuned MobileSAM masks assigned in the earlier MobileSAM comparison cell.
masks_orig = images_from_bbox_orig_SAM[k]

axs[0].imshow(image)
show_mask(masks_tuned, axs[0])
axs[0].set_title(f'Mask with Mobile Tuned Model\n #epochs = {num_epochs}, lr={lr}, loss=dice_loss\n mask with sigmoid', fontsize=18)

axs[1].imshow(image)
show_mask(masks_orig, axs[1])
axs[1].set_title('Mask with SAM Untuned Model', fontsize=26)

axs[2].imshow(negative_mask_img[0], cmap='viridis')
axs[2].set_title('Negative pixels map', fontsize=26)

plt.savefig('origSAM_output_one_image.png')

plt.show()
plt.close()