In [1]:
!pip install -q git+https://github.com/huggingface/transformers.git
!pip install -q datasets
!pip install -q monai

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  Building wheel for transformers (pyproject.toml) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m536.6/536.6 kB[0m [31m9.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m38.3/38.3 MB[0m [31m18.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m16.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m17.9 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
ibis-framework 7.1.0 requires pyarrow<15,>=2, but you have pyarrow 15.0.0 which is

In [None]:
from google.colab import drive
drive.mount('/content/drive')
! unzip /content/drive/MyDrive/YOSAW/data_sam_wound.zip

In [None]:
import torch
import numpy as np
import os
from PIL import Image
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import SamProcessor, SamModel
from torch.optim import Adam
import monai
from monai.losses import DiceCELoss
from torchvision.transforms import functional as F

In [None]:
def get_bounding_box(ground_truth_map):
    """Return a randomly padded box ``[x_min, y_min, x_max, y_max]`` around the mask.

    Args:
        ground_truth_map: 2-D array; entries greater than 0 count as foreground.

    Returns:
        A four-element list. When the map has no foreground the whole map is
        returned as ``[0, 0, width, height]``. Otherwise each side of the tight
        foreground box is pushed outwards by an independent random amount drawn
        from ``np.random.randint(0, 10)`` and clipped to the map extent.
    """
    rows, cols = np.where(ground_truth_map > 0)

    # Both index arrays always have the same length, so one emptiness check
    # covers the "no foreground" case.
    if cols.size == 0 or rows.size == 0:
        return [0, 0, ground_truth_map.shape[1], ground_truth_map.shape[0]]

    height, width = ground_truth_map.shape

    # Draw order matters for reproducibility under a fixed NumPy seed:
    # left, right, top, bottom.
    pad_left, pad_right, pad_top, pad_bottom = (
        np.random.randint(0, 10) for _ in range(4)
    )

    left = max(0, cols.min() - pad_left)
    right = min(width, cols.max() + pad_right)
    top = max(0, rows.min() - pad_top)
    bottom = min(height, rows.max() + pad_bottom)

    return [left, top, right, bottom]

In [None]:
class CustomTransform:
    """Apply one (possibly random) transform identically to an image and its mask.

    The wrapped ``transform`` is called once on the image and once on the mask.
    To keep random augmentations spatially aligned, the torch RNG state captured
    before the image pass is restored before the mask pass, so both passes see
    the same random draws.

    Unlike re-seeding with a constant, this does not reset the global RNG:
    successive calls get different augmentations, and a seed set by the caller
    (``torch.manual_seed``) still makes the whole sequence reproducible.
    """

    def __init__(self, transform):
        """Store the callable that will be applied to both inputs."""
        self.transform = transform

    def __call__(self, image, mask):
        """Transform ``image`` and ``mask`` with identical random parameters.

        Args:
            image: input for the wrapped transform, or ``None`` to skip it.
            mask: input for the wrapped transform, or ``None`` to skip it.

        Returns:
            Tuple ``(image, mask)`` with the transform applied to each
            non-``None`` element.
        """
        # Snapshot the RNG so the mask pass can replay the image pass's draws.
        rng_state = torch.get_rng_state()

        if image is not None:
            image = self.transform(image)

        if mask is not None:
            # Rewind to the snapshot: same flips / affine parameters as the image.
            torch.set_rng_state(rng_state)
            mask = self.transform(mask)

        return image, mask

class SAMDataset(Dataset):
    """Image / binary-mask pairs prepared as SAM inputs with a box prompt.

    Images and masks are matched by file name: both folders must contain
    exactly the same set of names.

    Args:
        images_folder: directory holding the RGB images.
        masks_folder: directory holding the masks (read as 8-bit grayscale).
        processor: SAM processor used to build ``pixel_values`` and
            ``input_boxes``.
        img_size: target size. An ``int`` is treated as a square
            ``(img_size, img_size)`` target; anything else is handed to
            ``transforms.Resize`` unchanged.
        train: when True, random flips and affine jitter are applied after
            resizing; when False only the resize is applied, so evaluation
            samples are not augmented.
    """

    def __init__(self, images_folder, masks_folder, processor, img_size, train=True):
        self.images_folder = images_folder
        self.masks_folder = masks_folder
        self.processor = processor
        self.train = train
        self.img_size = img_size

        # Resize(int) only scales the shorter edge, which leaves non-square
        # inputs non-square; the training code compares the mask against a
        # square low-resolution decoder output, so force a square target.
        size = (img_size, img_size) if isinstance(img_size, int) else img_size

        steps = [transforms.Resize(size)]
        if self.train:
            # Geometric augmentation only: it is applied to the mask as well.
            steps += [
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomVerticalFlip(p=0.2),
                transforms.RandomAffine(degrees=30, translate=(0.1, 0.1), scale=(0.8, 1.2)),
            ]
        # No ToTensor here: the image must reach the processor as a PIL image
        # (see __getitem__); only the mask is converted to a tensor.
        self.transform = CustomTransform(transforms.Compose(steps))
        self.to_tensor = transforms.ToTensor()

        self.image_files = os.listdir(images_folder)
        self.mask_files = os.listdir(masks_folder)

        assert set(self.image_files) == set(self.mask_files), "Images - masks names issues"

        # Identical name sets + identical ordering => index i pairs image i with mask i.
        self.image_files.sort()
        self.mask_files.sort()

    def __len__(self):
        """Number of image/mask pairs."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Build the model inputs for sample ``idx``.

        Returns:
            Dict with the processor outputs (leading batch dimension removed)
            plus ``"ground_truth_mask"``, the binarised float mask.
        """
        image_path = os.path.join(self.images_folder, self.image_files[idx])
        mask_path = os.path.join(self.masks_folder, self.mask_files[idx])

        image = Image.open(image_path).convert("RGB")
        mask = Image.open(mask_path).convert("L")

        image, mask = self.transform(image, mask)

        # ToTensor maps the 8-bit mask to [0, 1]; threshold to a hard 0/1 mask.
        mask = (self.to_tensor(mask) > 0.5).float()

        # Box prompt derived from the already-transformed mask so that it
        # lines up with the transformed image.
        bbox = get_bounding_box(mask.squeeze().numpy())

        # The image is passed as a PIL image (0-255). The processor rescales
        # and normalises by default, so feeding it a [0, 1] float tensor would
        # rescale the pixel values a second time.
        inputs = self.processor(image, input_boxes=[[bbox]], return_tensors="pt")
        inputs = {k: v.squeeze(0) for k, v in inputs.items()}
        inputs["ground_truth_mask"] = mask.squeeze()

        return inputs

In [None]:
# Locations of the extracted training / validation data.
images_folder = "/content/data/train_images/"
masks_folder = "/content/data/train_masks/"
val_images_folder = "/content/data/val_images/"
val_masks_folder = "/content/data/val_masks/"

# Pretrained SAM processor and weights (same checkpoint for both).
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
model = SamModel.from_pretrained("facebook/sam-vit-base")

# Datasets and single-sample loaders; only the training loader shuffles.
train_dataset = SAMDataset(
    images_folder, masks_folder, processor, img_size=256, train=True
)
val_dataset = SAMDataset(
    val_images_folder, val_masks_folder, processor, img_size=256, train=False
)
train_dataloader = DataLoader(
    train_dataset, batch_size=1, shuffle=True, drop_last=False
)
val_dataloader = DataLoader(
    val_dataset, batch_size=1, shuffle=False, drop_last=False
)

# Only the mask decoder is trained: turn off gradients for every parameter
# whose name sits under the vision encoder or the prompt encoder.
for name, param in model.named_parameters():
    if name.startswith(("vision_encoder", "prompt_encoder")):
        param.requires_grad_(False)

# The optimizer sees the mask decoder parameters only.
optimizer = Adam(model.mask_decoder.parameters(), lr=1e-5, weight_decay=0)

# Dice + cross-entropy on raw logits (the loss applies the sigmoid itself).
# Alternative worth trying: DiceFocalLoss.
seg_loss = DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')

In [None]:
# Define a function to calculate IoU
def calculate_iou(predicted_masks, ground_truth_masks):
    """Intersection-over-union of two masks, pooled over all elements.

    Args:
        predicted_masks: tensor whose non-zero / True entries are foreground.
        ground_truth_masks: tensor broadcastable against ``predicted_masks``.

    Returns:
        ``intersection / union`` as a Python float, or ``0`` when the union
        is empty (neither mask has any foreground).
    """
    union = torch.logical_or(predicted_masks, ground_truth_masks).sum().item()
    if not union > 0:
        # Nothing marked in either mask: define the score as 0.
        return 0

    overlap = torch.logical_and(predicted_masks, ground_truth_masks).sum().item()
    return overlap / union

In [None]:
# Fine-tuning loop: one training pass and one validation pass per epoch.
# Per-epoch mean loss / IoU are kept in the four history lists, and the
# weights with the best validation IoU so far are written to best_model_path.
num_epochs = 15
best_model_path = "best_model.pth"
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

best_iou = 0.0
train_losses = []
val_losses = []
train_ious = []
val_ious = []

for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")

    # The validation pass below switches the model to eval mode; switch back
    # at the start of every epoch so epochs after the first train in train mode.
    model.train()
    epoch_losses = []
    epoch_ious = []
    for batch in tqdm(train_dataloader):
        # forward pass
        outputs = model(pixel_values=batch["pixel_values"].to(device),
                        input_boxes=batch["input_boxes"].to(device),
                        multimask_output=False)

        # compute loss; the target gets a channel axis so it has the same
        # number of dimensions as the prediction
        predicted_masks = outputs.pred_masks.squeeze(1)
        ground_truth_masks = batch["ground_truth_mask"].float().to(device).unsqueeze(1)
        loss = seg_loss(predicted_masks, ground_truth_masks)

        # backward pass
        optimizer.zero_grad()
        loss.backward()

        # optimize
        optimizer.step()
        epoch_losses.append(loss.item())

        # calculate IoU. The predictions are raw logits (seg_loss applies the
        # sigmoid itself), so binarise probabilities, not logits, at 0.5.
        # Both operands carry the channel axis, so nothing broadcasts across
        # the batch dimension.
        iou = calculate_iou(torch.sigmoid(predicted_masks.detach()) > 0.5,
                            ground_truth_masks > 0.5)
        epoch_ious.append(iou)

    # Store training loss and IoU
    avg_train_loss = np.mean(epoch_losses)
    avg_train_iou = np.mean(epoch_ious)
    train_losses.append(avg_train_loss)
    train_ious.append(avg_train_iou)
    print(f"Training Loss: {avg_train_loss:.3f}, Training IoU: {avg_train_iou:.3f}")

    # Validation
    model.eval()
    val_losses_epoch = []
    val_ious_epoch = []
    with torch.no_grad():
        for val_batch in tqdm(val_dataloader):
            # forward pass
            val_outputs = model(pixel_values=val_batch["pixel_values"].to(device),
                                input_boxes=val_batch["input_boxes"].to(device),
                                multimask_output=False)

            # compute loss
            val_predicted_masks = val_outputs.pred_masks.squeeze(1)
            val_ground_truth_masks = val_batch["ground_truth_mask"].float().to(device).unsqueeze(1)
            val_loss = seg_loss(val_predicted_masks, val_ground_truth_masks)
            val_losses_epoch.append(val_loss.item())

            # calculate IoU (same binarisation as in training)
            iou = calculate_iou(torch.sigmoid(val_predicted_masks) > 0.5,
                                val_ground_truth_masks > 0.5)
            val_ious_epoch.append(iou)

    # Store validation loss and IoU
    avg_val_loss = np.mean(val_losses_epoch)
    avg_val_iou = np.mean(val_ious_epoch)
    val_losses.append(avg_val_loss)
    val_ious.append(avg_val_iou)
    print(f"Validation Loss: {avg_val_loss:.3f}, Validation IoU: {avg_val_iou:.3f}")

    # Save the model if it has the best IoU
    if avg_val_iou > best_iou:
        best_iou = avg_val_iou
        torch.save(model.state_dict(), best_model_path)
        print("Best model saved.")