In [None]:
import cv2
import torch
import torch.nn as nn

import os
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import Dataset

In [None]:
class ImageDataset(Dataset):
    """Image/mask pairs read from disk with a fixed, reproducible 60/20/20 split.

    Images are listed from ``<root_dir>/images``. The mask belonging to an
    image is looked up in ``<root_dir>/masks`` under the name
    ``mask_<suffix>``, where ``<suffix>`` is the second ``_``-separated field
    of the image file name (e.g. ``<prefix>_12.png`` -> ``mask_12.png``).
    The part of ``<suffix>`` before the first ``.`` must parse as an integer;
    it is returned as the sample's ``idx``.
    """

    def __init__(self, root_dir, set = 'train'):
        """
        Args:
            root_dir (string): Directory containing the ``images`` and
                ``masks`` sub-directories.
            set (string): Which split to expose: ``'train'`` (first 60%),
                ``'val'`` (next 20%) or ``'test'`` (last 20%) of the shuffled
                file list.

        Raises:
            ValueError: If ``set`` is not ``'train'``, ``'val'`` or ``'test'``.
        """
        self.root_dir = root_dir
        self.images_dir = os.path.join(root_dir, 'images')
        self.masks_dir = os.path.join(root_dir, 'masks')
        # Sort first so the shuffle below does not depend on the arbitrary
        # order in which os.listdir returns the entries.
        self.image_names = sorted(os.listdir(self.images_dir))
        # NOTE: this reseeds NumPy's *global* RNG on every construction. It is
        # kept as-is so the split stays identical to the one used so far.
        np.random.seed(0)
        np.random.shuffle(self.image_names)
        self.set = set
        # 60% train, 20% val and 20% test
        n_total = len(self.image_names)
        train_end = int(n_total * 0.6)
        val_end = int(n_total * 0.8)
        if set == 'train':
            self.image_names = self.image_names[:train_end]
        elif set == 'val':
            self.image_names = self.image_names[train_end:val_end]
        elif set == 'test':
            self.image_names = self.image_names[val_end:]
        else:
            raise ValueError('set must be "train", "val" or "test"')

    def __len__(self):
        """Return the number of images in the selected split."""
        return len(self.image_names)

    def __getitem__(self, idx):
        """Load one image/mask pair, both resized to 256x256.

        Args:
            idx (int): Position of the sample inside the selected split.

        Returns:
            dict: ``'image'`` is the resized image as read by ``cv2.imread``,
            ``'mask'`` is the first channel of the resized mask (returned
            as-is, not binarised), and ``'idx'`` is the integer id parsed
            from the image file name.

        Raises:
            FileNotFoundError: If the image or its mask file does not exist.
            OSError: If OpenCV cannot decode the image or the mask file.
        """
        file_name = self.image_names[idx]
        suffix = file_name.split("_")[1]  # Adjust based on your naming convention
        img_name = os.path.join(self.images_dir, file_name)
        mask_name = os.path.join(self.masks_dir, f'mask_{suffix}')

        # cv2.imread returns None instead of raising for a missing or
        # unreadable file; check explicitly so the failure names the file
        # rather than surfacing later as an opaque cv2.resize error.
        for required_path in (img_name, mask_name):
            if not os.path.isfile(required_path):
                raise FileNotFoundError(f'Missing dataset file: {required_path}')
        image = cv2.imread(img_name)
        if image is None:
            raise OSError(f'OpenCV could not decode image file: {img_name}')
        mask = cv2.imread(mask_name)
        if mask is None:
            raise OSError(f'OpenCV could not decode mask file: {mask_name}')

        sample_id = int(suffix.split('.')[0])

        # resize both image and mask to 256x256
        image = cv2.resize(image, (256, 256))[:, :, :3]
        mask = cv2.resize(mask, (256, 256))
        # keep a single channel: mask of size 256x256
        mask = mask[:, :, 0]

        sample = {'image': image, 'mask': mask, 'idx': sample_id}

        return sample

In [None]:
# Evaluate on the 'test' split (last 20% of the shuffled file list) of the
# image/mask pairs stored under 'new_dataset'.
dataset = ImageDataset('new_dataset', set='test')

In [None]:
# Sanity-check the first sample. Index the dataset only once: every
# dataset[0] access re-reads and resizes the image and mask from disk.
first_sample = dataset[0]
print(first_sample['image'].shape)
print(first_sample['mask'].shape)
print(first_sample['idx'])

In [None]:
def get_bounding_box(ground_truth_mask):
    """Return the tight bounding box around the non-zero pixels of a 2-D mask.

    No random perturbation is applied, so the box is deterministic.

    Args:
        ground_truth_mask (np.ndarray): 2-D array; pixels ``> 0`` are
            treated as foreground.

    Returns:
        list[int]: ``[x_min, y_min, x_max, y_max]`` with inclusive pixel
        coordinates, as plain Python ints.

    Raises:
        ValueError: If the mask contains no pixel ``> 0``.
    """
    y_indices, x_indices = np.where(ground_truth_mask > 0)
    # np.min / np.max on an empty array fail with an opaque "zero-size array"
    # error; raise the same exception type with a message that names the cause.
    if y_indices.size == 0:
        raise ValueError(
            'Cannot compute a bounding box: ground-truth mask is empty (no pixel > 0)'
        )
    bbox = [
        int(x_indices.min()),
        int(y_indices.min()),
        int(x_indices.max()),
        int(y_indices.max()),
    ]

    return bbox

In [None]:
def generate_mask(dataset, idx, processor, model, device):
    """Run the model on one dataset sample, prompted with its ground-truth box.

    Args:
        dataset: Indexable whose items are dicts with ``'image'`` and
            ``'mask'`` entries.
        idx: Position of the sample in ``dataset``.
        processor: Callable that builds the model inputs from an image and
            ``input_boxes`` and exposes
            ``image_processor.post_process_masks`` (a ``SamProcessor`` in
            this notebook).
        model: Model called as ``model(**inputs)``; its output must provide
            ``pred_masks``.
        device: Device the inputs are moved to before the forward pass.

    Returns:
        np.ndarray: float32 array built from the post-processed masks of the
        first image: singleton dimensions are squeezed and the first
        remaining axis is moved last. Presumably ``(H, W, 3)`` when the model
        emits three candidate masks per box — TODO confirm against the model
        configuration.
    """
    # Index the dataset once: every access re-reads the files from disk.
    sample = dataset[idx]
    image = sample['image']

    # The prompt box is derived from the ground-truth mask.
    box = get_bounding_box(sample['mask'])

    inputs = processor(image, input_boxes=[[box]], return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    masks = processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
    # NOTE: no confidence-based selection happens here; every candidate mask
    # is returned (stacked along the last axis) and the caller picks one.
    mask = masks[0].cpu().numpy()
    mask = np.moveaxis(mask.squeeze(), 0, -1).astype(np.float32)
    return mask

In [None]:
def iou(mask1, mask2):
    """
    Calculate the Intersection over Union (IoU) of two binary masks.

    Any non-zero value counts as foreground.

    Parameters:
        mask1 (np.array): First binary mask.
        mask2 (np.array): Second binary mask.

    Returns:
        float: IoU score in [0, 1]. When both masks are empty the union is
        empty as well; the masks then agree perfectly and 1.0 is returned
        instead of the NaN a 0/0 division would produce.
    """
    # Ensure that the masks are boolean arrays
    mask1 = mask1.astype(bool)
    mask2 = mask2.astype(bool)

    union = np.count_nonzero(np.logical_or(mask1, mask2))
    if union == 0:
        # Both masks are empty: avoid 0/0 (NaN plus a RuntimeWarning).
        return 1.0

    intersection = np.count_nonzero(np.logical_and(mask1, mask2))
    return float(intersection / union)

In [None]:
from transformers import SamModel, SamProcessor
import torch
# Model weights are loaded from "finetuned_sam_model" (presumably a local
# checkpoint directory produced by an earlier fine-tuning run — confirm).
model = SamModel.from_pretrained("finetuned_sam_model")
# The processor is loaded from the base "facebook/sam-vit-base" checkpoint,
# not from the fine-tuned one.
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Move the model to the selected device (as the cell's last expression, the
# returned model is also shown as the cell output).
model.to(device)

In [None]:
# Number of samples in the selected split.
print(len(dataset))

In [None]:
# calculate the IoU for all the images
path = 'finetuned_sam_masks/'
# plt.imsave does not create missing directories, so make sure the output
# folder exists before the first mask is written.
os.makedirs(path, exist_ok=True)
ious = []
sam_masks = []

# No torch.no_grad() is needed here: generate_mask already wraps the forward
# pass in it.
for i in range(len(dataset)):
    # Index the dataset once per sample: each access re-reads both files.
    sample = dataset[i]
    idx = sample['idx']
    mask = generate_mask(dataset, i, processor, model, device)
    # Channel 0 of the returned array is used as the predicted mask.
    sam_mask = mask[:,:,0]
    sam_masks.append(sam_mask)
    ious.append(iou(sam_mask, sample['mask']))
    plt.imsave(f'{path}sam_mask_{idx}.png', sam_mask, cmap='gray')

In [None]:
# Sanity check: number of IoU scores collected by the evaluation loop.
len(ious)

In [None]:
# Mean IoU over all evaluated samples.
np.mean(ious)

In [None]:
# make a csv with idx and iou
import pandas as pd

# Parse the ids from the file names exactly as ImageDataset.__getitem__ does,
# instead of indexing dataset[i], which would re-read and resize every image
# and mask from disk only to recover its id.
sample_ids = [int(name.split('_')[1].split('.')[0]) for name in dataset.image_names]
df = pd.DataFrame({'idx': sample_ids, 'iou': ious})
df.to_csv('finetuned_sam_iou.csv', index=False)