In [4]:
!pip install git+https://github.com/facebookresearch/segment-anything.git
#Transformers
!pip install -q git+https://github.com/huggingface/transformers.git
#Datasets to prepare data and monai if you want to use special loss functions
!pip install datasets
!pip install numpy
!pip install -q monai
#Patchify to divide large images into smaller patches for training. (Not necessary for smaller images)
!pip install patchify
!pip install scipy

Collecting git+https://github.com/facebookresearch/segment-anything.git
  Cloning https://github.com/facebookresearch/segment-anything.git to /private/var/folders/b2/jmn732l52_bfvbd4bjpcmydm0000gn/T/pip-req-build-4jkuwyyy
  Running command git clone --filter=blob:none --quiet https://github.com/facebookresearch/segment-anything.git /private/var/folders/b2/jmn732l52_bfvbd4bjpcmydm0000gn/T/pip-req-build-4jkuwyyy
  Resolved https://github.com/facebookresearch/segment-anything.git to commit 6fdee8f2727f4506cfbbe553e23b895e27956588
  Preparing metadata (setup.py) ... [?25ldone
Collecting scipy
  Using cached scipy-1.13.1-cp312-cp312-macosx_12_0_arm64.whl.metadata (60 kB)
Using cached scipy-1.13.1-cp312-cp312-macosx_12_0_arm64.whl (30.4 MB)
Installing collected packages: scipy
Successfully installed scipy-1.13.1


In [72]:
import numpy as np
import matplotlib.pyplot as plt
import os
from PIL import Image 
from patchify import patchify  #Only to handle large images
import random
from scipy import ndimage

In [73]:
large_images = Image.open("/Users/amin/Desktop/higharc/Datasets/unlabeled/data_pulte/pulte/floorplans/4a89fe58-ffb8-4f01-bbc6-dc22d6b74b86_1_.png")
large_masks = Image.open("/4a89fe58-ffb8-4f01-bbc6-dc22d6b74b86_1_.png")

FileNotFoundError: [Errno 2] No such file or directory: '/4a89fe58-ffb8-4f01-bbc6-dc22d6b74b86_1_.png'

In [183]:
from torch.utils.data import Dataset
import cv2
import json


class SAMDataset(Dataset):
    """
    This class is used to create a dataset that serves input images and masks.
    It takes a dataset and a processor as input and overrides the __len__ and __getitem__ methods of the Dataset class.
    """
    def __init__(self, images_base_url, masks_base_url, json_url, processor):
        self.images_base_url = images_base_url
        self.masks_base_url = masks_base_url
        self.processor = processor
        self.json_url = json_url
        with open(self.json_url) as file:
            data = json.load(file)
        self.categories = data['categories']
        self.images = data['images']
        self.annotations = data['annotations']

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        annotation = self.annotations[idx]
        image_file_name = ''
        for image_info in self.images:
            if image_info['id'] == annotation['image_id']:
                image_file_name = image_info['file_name']
                
        
        image = Image.open(os.path.join(self.images_base_url, image_file_name))
        mask = Image.open(os.path.join(self.masks_base_url, image_file_name))

        image = np.array(image)
        ground_truth_mask = np.array(mask)
        
        original_width, original_height, _ = image.shape
        
        new_height = 1024
        new_width = 1024

        # Scaling factors
        scale_x = new_width / original_width
        scale_y = new_height / original_height
        
        scaled_bbox = annotation['bbox']
        scaled_bbox[0] *= scale_x  # Scale x-coordinates
        scaled_bbox[2] *= scale_x  # Scale x-coordinates
        scaled_bbox[1] *= scale_y  # Scale y-coordinates
        scaled_bbox[3] *= scale_y  # Scale y-coordinates
        # get bounding box prompt
        scaled_bbox = np.array(scaled_bbox)

        prompt = scaled_bbox
    
        # prepare image and prompt for the model
        image = cv2.resize(image, (new_width, new_height))
        ground_truth_mask = cv2.resize(ground_truth_mask, (new_width, new_height))

        inputs = self.processor(image, input_boxes=[[[prompt]]], return_tensors="pt")

        # remove batch dimension which the processor adds by default
        # for k,v in inputs.items():
        #     print(k,v.shape)
        inputs = {k:v.squeeze(0) for k,v in inputs.items()}

        # add ground truth segmentation
        inputs["ground_truth_mask"] = ground_truth_mask

        return inputs

In [184]:
from transformers import SamProcessor
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

In [185]:
images_base_url = "/Users/amin/Desktop/higharc/Datasets/unlabeled/data_pulte/pulte/floorplans"
masks_base_url = "/Users/amin/Desktop/higharc/Datasets/unlabeled/data_pulte/pulte/panoptic_semseg_maskdino_augmented_floorplans"
json_url = "/Users/amin/Desktop/higharc/Datasets/unlabeled/data_pulte/pulte/floorplans/_annotations.coco.json"

In [186]:
train_dataset = SAMDataset(images_base_url=images_base_url, masks_base_url=masks_base_url, json_url=json_url, processor=processor)

In [189]:
from torch.utils.data import DataLoader
train_dataloader = DataLoader(train_dataset, batch_size=3, shuffle=True, drop_last=False)

In [190]:
batch = next(iter(train_dataloader))
for k,v in batch.items():
  print(k,v.shape)
     

pixel_values torch.Size([3, 3, 1024, 1024])
original_sizes torch.Size([3, 2])
reshaped_input_sizes torch.Size([3, 2])
input_boxes torch.Size([3, 1, 4])
ground_truth_mask torch.Size([3, 1024, 1024])


In [191]:
from transformers import SamModel
model = SamModel.from_pretrained("facebook/sam-vit-base")

# make sure we only compute gradients for mask decoder
for name, param in model.named_parameters():
    if name.startswith("vision_encoder") or name.startswith("prompt_encoder"):
        param.requires_grad_(False)

In [192]:
from torch.optim import AdamW
import monai
# Initialize the optimizer and the loss function
optimizer = AdamW(model.mask_decoder.parameters(), lr=1e-5, weight_decay=0)
#Try DiceFocalLoss, FocalLoss, DiceCELoss
seg_loss = monai.losses.DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')

In [193]:
from tqdm import tqdm
from statistics import mean
import torch
from torch.nn.functional import threshold, normalize

#Training loop
num_epochs = 1

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

model.train()
for epoch in range(num_epochs):
    epoch_losses = []
    for batch in tqdm(train_dataloader):
        # forward pass
        outputs = model(pixel_values=batch["pixel_values"].to(device),
                        input_boxes=batch["input_boxes"].to(device),
                        multimask_output=False)

        # compute loss
        predicted_masks = outputs.pred_masks.squeeze(1)
        ground_truth_masks = batch["ground_truth_mask"].float().to(device)
        loss = seg_loss(predicted_masks, ground_truth_masks.unsqueeze(1))

        # backward pass (compute gradients of parameters w.r.t. loss)
        optimizer.zero_grad()
        loss.backward()

        # optimize
        optimizer.step()
        epoch_losses.append(loss.item())

    print(f'EPOCH: {epoch}')
    print(f'Mean loss: {mean(epoch_losses)}')

  0%|          | 0/8688 [00:00<?, ?it/s]