In [None]:
import numpy as np
import os
join = os.path.join
from tqdm import tqdm
import torch
import monai
import cv2
from PIL import Image
import torchvision
from torchvision.transforms import functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor


In [None]:
def get_bbox_from_mask(mask, max_perturbation=20, rng=None):
    '''Return a randomly enlarged bounding box around the positive pixels of a mask.

    Args:
        mask: 2-D array (H, W) of label values. A 3-D channels-last array
            (H, W, C), e.g. an RGB label image, is reduced to "any channel > 0".
        max_perturbation: each box edge is pushed outward by a random integer in
            ``[0, max_perturbation)`` pixels, then clipped to the image. Use 0
            (or 1) to disable the jitter.
        rng: optional ``numpy.random.Generator`` for reproducible jitter. When
            None, the global ``np.random`` state is used.

    Returns:
        Integer array ``[x_min, y_min, x_max, y_max]`` with ``x_max > x_min`` and
        ``y_max > y_min``. A mask with no positive pixels yields the placeholder
        box ``[0, 0, 1, 1]``.

    Raises:
        ValueError: if a non-empty mask is not 2-D or 3-D.
    '''
    mask = np.asarray(mask)
    # Use one definition of "foreground" for both the emptiness check and the
    # box; otherwise a mask with only negative values passes the check and then
    # crashes on min() of an empty index array.
    if mask.ndim == 3:
        foreground = (mask > 0).any(axis=-1)
    else:
        foreground = mask > 0

    if not foreground.any():
        # If the mask is entirely background, return a default bounding box
        return np.array([0, 0, 1, 1])  # Default bounding box with width and height of 1

    if foreground.ndim != 2:
        raise ValueError(f"expected a 2-D or 3-D mask, got shape {mask.shape}")

    def _jitter():
        # Random outward shift in [0, max_perturbation); non-positive disables it.
        if max_perturbation <= 0:
            return 0
        if rng is not None:
            return int(rng.integers(0, max_perturbation))
        return int(np.random.randint(0, max_perturbation))

    y_indices, x_indices = np.nonzero(foreground)
    H, W = foreground.shape
    # Draw order (x_min, x_max, y_min, y_max) matches the original implementation.
    x_min = max(0, int(x_indices.min()) - _jitter())
    x_max = min(W, int(x_indices.max()) + _jitter())
    y_min = max(0, int(y_indices.min()) - _jitter())
    y_max = min(H, int(y_indices.max()) + _jitter())

    # Ensure positive width and height (a single-pixel object would otherwise
    # give a degenerate zero-area box).
    width = max(1, x_max - x_min)
    height = max(1, y_max - y_min)

    return np.array([x_min, y_min, x_min + width, y_min + height], dtype=np.int64)


In [None]:
class CustomDataset(Dataset):
    """Detection dataset pairing each image with a bounding box derived from its mask.

    Expects ``root/images`` and ``root/labels`` directories. Images and masks are
    paired by position after sorting the file names, so both directories must
    hold the same number of files whose names sort in the same order.

    Each item is ``(img, target)``:
      - ``img``: the RGB image, passed through ``transforms`` if one was given.
      - ``target["boxes"]``: float32 tensor of shape (1, 4) holding the box from
        ``get_bbox_from_mask`` as ``[x_min, y_min, x_max, y_max]``.
      - ``target["labels"]``: int64 tensor ``[1]`` (a single foreground class).
    """

    def __init__(self, root, transforms=None):
        self.root = root
        self.transforms = transforms
        self.imgs = sorted(os.listdir(os.path.join(root, "images")))
        self.masks = sorted(os.listdir(os.path.join(root, "labels")))
        print(f'imgs len: {len(self.imgs)}, masks len: {len(self.masks)}')
        # Pairing is by sorted index, so a count mismatch would silently pair
        # images with the wrong masks (or IndexError late in training).
        if len(self.imgs) != len(self.masks):
            raise ValueError(
                f"Found {len(self.imgs)} images but {len(self.masks)} masks under "
                f"{root!r}; images and labels must be paired one-to-one."
            )

    def __getitem__(self, idx):
        # Load images and masks; the context managers close the files once the
        # pixel data has been copied out.
        img_path = os.path.join(self.root, "images", self.imgs[idx])
        mask_path = os.path.join(self.root, "labels", self.masks[idx])
        with Image.open(img_path) as img_file:
            img = img_file.convert("RGB")
        with Image.open(mask_path) as mask_file:
            mask = np.array(mask_file)

        bbox = get_bbox_from_mask(mask)
        # Add the leading "number of boxes" axis on the ndarray itself: building a
        # tensor from a Python list of ndarrays is slow and emits a warning.
        boxes = torch.as_tensor(np.asarray(bbox)[np.newaxis, :], dtype=torch.float32)
        labels = torch.ones((1,), dtype=torch.int64)  # assuming only one class

        if self.transforms is not None:
            img = self.transforms(img)

        target = {"boxes": boxes, "labels": labels}
        return img, target

    def __len__(self):
        return len(self.imgs)

# Image transformation used by the dataset
def get_transform():
    """Build the per-image transform: a thin wrapper around torchvision's ``F.to_tensor``.

    Returns:
        A new callable that takes a PIL image and returns ``F.to_tensor(image)``.
    """
    def _pil_to_tensor(image):
        return F.to_tensor(image)

    return _pil_to_tensor

# Initialize dataset and dataloader
DATA_ROOT = "data/MedSAMDemo_2D/train"


def collate_detection_batch(batch):
    """Regroup a list of ``(img, target)`` pairs into ``(imgs, targets)`` tuples.

    This keeps images as a tuple rather than letting the default collate try to
    stack them into one tensor. It is a named module-level function, unlike a
    lambda, so it stays picklable if the loader is later given worker processes.
    """
    return tuple(zip(*batch))


dataset = CustomDataset(root=DATA_ROOT, transforms=get_transform())
data_loader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_detection_batch)

# Sanity check: inspect one batch instead of loading the entire dataset.
first_batch = next(iter(data_loader), None)
if first_batch is None:
    print(f"No samples found under {DATA_ROOT!r}")
else:
    sample_imgs, sample_targets = first_batch
    print(sample_imgs[0].shape)


In [None]:
def get_model(num_classes):
    """Create a COCO-pretrained Faster R-CNN (ResNet-50 FPN) with a fresh box head.

    The pretrained box predictor is replaced by a ``FastRCNNPredictor`` with
    ``num_classes`` outputs; as the caller below notes, this count includes the
    background class.
    """
    detector = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
    # The new head must accept the same feature width the old classifier consumed.
    head_in_features = detector.roi_heads.box_predictor.cls_score.in_features
    new_head = FastRCNNPredictor(head_in_features, num_classes)
    detector.roi_heads.box_predictor = new_head
    return detector

# One foreground class plus background -> two outputs for the box predictor.
model = get_model(2)

In [None]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Move model to the right device
model.to(device)
model.train()


# Parameters
num_epochs = 10
# Parameters with requires_grad=False never receive gradients, so only the
# trainable ones are handed to the optimizer.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable_params, lr=0.005, momentum=0.9, weight_decay=0.0005)
# NOTE(review): seg_loss is not used by the detection loop below (the detector's
# own loss_dict is optimized); kept only in case later cells reference it.
seg_loss = monai.losses.DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')

lowest_loss = float('inf')
best_model_path = 'bbox_model/best_bbox_model.pth'  # Path to save the best model
# torch.save does not create missing parent directories.
os.makedirs(os.path.dirname(best_model_path) or '.', exist_ok=True)

for epoch in range(num_epochs):
    running_loss = 0.0
    num_batches = 0
    for imgs, targets in tqdm(data_loader, desc=f"Epoch {epoch}"):
        imgs = [img.to(device) for img in imgs]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        # In train mode the detector returns a dict of loss tensors; optimize their sum.
        loss_dict = model(imgs, targets)
        losses = sum(loss for loss in loss_dict.values())

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        running_loss += losses.item()
        num_batches += 1

    if num_batches == 0:
        raise RuntimeError("data_loader produced no batches; check the dataset root.")

    # Select the checkpoint on the mean loss over the whole epoch rather than the
    # (noisy) loss of whichever batch happened to come last.
    epoch_loss = running_loss / num_batches
    if epoch_loss < lowest_loss:
        lowest_loss = epoch_loss
        # Save the model's parameters
        torch.save(model.state_dict(), best_model_path)

    print(f"Epoch {epoch}: Loss: {epoch_loss}")