In [None]:
import os
import cv2
import monai
import numpy as np
from PIL import Image
from tqdm import tqdm
from statistics import mean
from typing import Dict, List, Tuple

import torch
from torch.optim import Adam
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.nn.functional import threshold, normalize

from transformers import SamProcessor
from transformers import SamModel

# Default this notebook to GPU 2, but respect a device selection the caller
# already exported (e.g. from a job scheduler) instead of overwriting it.
# Only takes effect if it runs before torch initialises CUDA.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "2")


In [None]:
from typing import Any


def get_disturbed_bounding_box(ground_truth_map: np.ndarray,
                               max_perturbation: int = 20) -> List:
    """Return a randomly enlarged box ``[x_min, y_min, x_max, y_max]`` around
    the pixels of ``ground_truth_map`` that are greater than zero.

    Each side is pushed outward by an independent random offset drawn from
    ``[0, max_perturbation)`` and clipped to ``[0, W]`` (x) / ``[0, H]`` (y).

    Args:
        ground_truth_map: 2-D mask of shape (H, W); values > 0 are the object.
        max_perturbation: exclusive upper bound of the per-side jitter in
            pixels. ``0`` disables the jitter.

    Returns:
        ``[x_min, y_min, x_max, y_max]`` as Python ints.

    Raises:
        ValueError: if the mask has no pixel greater than zero.
    """
    y_indices, x_indices = np.nonzero(ground_truth_map > 0)
    if x_indices.size == 0:
        raise ValueError("ground_truth_map has no foreground pixels")

    def jitter() -> int:
        # np.random.randint(0, 0) raises, so treat a zero bound as "no jitter".
        return int(np.random.randint(0, max_perturbation)) if max_perturbation > 0 else 0

    H, W = ground_truth_map.shape
    # Draw order (x_min, x_max, y_min, y_max) matches the previous version so
    # seeded runs produce the same boxes.
    x_min = max(0, int(x_indices.min()) - jitter())
    x_max = min(W, int(x_indices.max()) + jitter())
    y_min = max(0, int(y_indices.min()) - jitter())
    y_max = min(H, int(y_indices.max()) + jitter())
    return [x_min, y_min, x_max, y_max]


def erode_mask(mask: np.ndarray, kernel_size: int = 3) -> np.ndarray:
    """Apply a single morphological erosion to ``mask``.

    Args:
        mask: array accepted by ``cv2.erode``.
        kernel_size: side length of the square, all-ones structuring element.

    Returns:
        The eroded mask.
    """
    structuring_element = np.ones((kernel_size,) * 2, dtype=np.uint8)
    return cv2.erode(mask, structuring_element, iterations=1)


def _sample_indices(n_candidates: int, n_samples: int) -> np.ndarray:
    """Draw exactly ``n_samples`` indices from ``range(n_candidates)``,
    without replacement when enough candidates exist, with replacement otherwise."""
    return np.random.choice(n_candidates, size=n_samples,
                            replace=n_candidates < n_samples)


def get_point_prompt_bymask(mask: np.ndarray, nb: int = 100) -> Dict[str, List]:
    """Sample positive/negative point prompts for SAM from a binary mask.

    Foreground points (label 1) are drawn from an eroded copy of ``mask`` so
    they stay away from object borders. Background points (label 0) are drawn
    from the pixels where the *original* ``mask`` is zero, so no negative
    prompt lands on a true foreground pixel (sampling them from the eroded
    mask would put some on the eroded-away object rim).

    Args:
        mask: 2-D array, non-zero where the object is.
        nb: number of points sampled per class.

    Returns:
        ``{'input_points': [[points]], 'input_labels': [[labels]]}`` holding
        ``2 * nb`` ``[x, y]`` points with integer labels, foreground first.
        If the eroded mask has no foreground, the background points are used
        twice (all labelled 0) so every sample keeps the same point count.

    Raises:
        ValueError: if ``mask`` has no background pixel.
    """
    # np.argwhere yields (row, col) = (y, x); reverse the columns to get (x, y).
    bg_coords = np.argwhere(mask == 0)[:, ::-1]
    if len(bg_coords) == 0:
        raise ValueError("mask has no background pixels to sample prompts from")
    bg_points = bg_coords[_sample_indices(len(bg_coords), nb)]
    bg_labels = np.zeros(nb, dtype=np.int64)

    fg_coords = np.argwhere(erode_mask(mask) != 0)[:, ::-1]
    if len(fg_coords) > 0:
        fg_points = fg_coords[_sample_indices(len(fg_coords), nb)]
        fg_labels = np.ones(nb, dtype=np.int64)
    else:
        # No usable foreground: repeat the negatives to keep the prompt size fixed.
        fg_points, fg_labels = bg_points, bg_labels

    # Nested lists = [image batch][point batch][points], as passed to SamProcessor.
    return {'input_points': [[np.vstack([fg_points, bg_points]).tolist()]],
            'input_labels': [[np.hstack([fg_labels, bg_labels]).tolist()]]}

class PSVDataset(Dataset):
    """Parking-slot (PSV) segmentation dataset yielding RGB images and binary masks.

    Layout read by this class::

        <root>/<split>.txt               one sample id per line
        <root>/images/<split>/<id>.jpg
        <root>/labels/<split>/<id>.png

    Each item is ``{'image': RGB uint8 image, 'label': uint8 mask in {0, 1}}``,
    both resized to ``image_size`` x ``image_size``.
    """

    # Fallback dataset location; export PSV_ROOT to point at another copy.
    DEFAULT_ROOT = os.environ.get(
        'PSV_ROOT', '/home/jiashuo/workspace/datasets/parking_slots/PSV dataset/')

    def __init__(self,
                 root=None,
                 split='test',
                 image_size: int = 256) -> None:
        """
        Args:
            root: dataset root directory; ``None`` uses ``DEFAULT_ROOT``.
            split: split name, selecting ``<split>.txt`` and the sub-folders.
            image_size: side length images and masks are resized to.
        """
        super().__init__()
        root = self.DEFAULT_ROOT if root is None else root
        label_path = os.path.join(root, f'{split}.txt')
        with open(label_path, 'r') as f:
            # Skip blank lines (e.g. a trailing newline) that would give an empty id.
            self.samples = [line.strip() for line in f if line.strip()]
        self.image_root = os.path.join(root, 'images', split)
        self.label_root = os.path.join(root, 'labels', split)
        self.image_size = image_size

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index: int) -> Dict[str, np.ndarray]:
        sample_id = self.samples[index]
        image_path = os.path.join(self.image_root, f'{sample_id}.jpg')
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread reports failure by returning None rather than raising.
            raise FileNotFoundError(f'could not read image: {image_path}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = self.get_mask(sample_id)
        size = (self.image_size, self.image_size)  # cv2 expects (width, height)
        image = cv2.resize(image, size, interpolation=cv2.INTER_LINEAR)
        # Nearest-neighbour keeps label values exact instead of blending them at edges.
        mask = cv2.resize(mask, size, interpolation=cv2.INTER_NEAREST)
        return {'image': image, 'label': mask}

    def get_mask(self, sample_id):
        """Load the label PNG of ``sample_id`` as a uint8 mask where every
        non-zero label value (any class) becomes 1."""
        mask_path = os.path.join(self.label_root, f'{sample_id}.png')
        with Image.open(mask_path) as label_img:
            label = np.array(label_img)
        # Binarize before casting: casting first would wrap values such as 256
        # (possible in 16-bit label images) to 0 and silently drop them.
        return (label != 0).astype(np.uint8)


class SAMDataset(Dataset):
    """Turn an image/mask dataset into SAM model inputs.

    Every item of ``dataset`` must expose ``"image"`` and ``"label"`` keys.
    Point prompts are sampled from the label mask and encoded together with
    the image by ``processor``; the raw mask is attached as
    ``"ground_truth_mask"``.
    """

    def __init__(self, dataset, processor):
        self.dataset, self.processor = dataset, processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample = self.dataset[idx]
        gt = np.array(sample["label"])
        encoded = self.processor(sample["image"],
                                 **get_point_prompt_bymask(gt),
                                 return_tensors="pt")
        # The processor adds a leading batch axis of size 1; drop it so the
        # DataLoader can stack samples into its own batch dimension.
        result = {key: tensor.squeeze(0) for key, tensor in encoded.items()}
        result["ground_truth_mask"] = gt
        return result


In [None]:
# Prefer a GPU when one is visible, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Pretrained SAM ViT-H processor and model (downloaded on first use).
# NOTE(review): `mirror=` is a legacy download option; confirm the installed
# transformers version still accepts it.
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge", mirror='tuna')
model = SamModel.from_pretrained("facebook/sam-vit-huge", mirror='tuna').to(device)

In [None]:
# Training split wrapped so each item is a SAM-ready dict of tensors.
psv_set = PSVDataset(split='train')
train_dataset = SAMDataset(dataset=psv_set, processor=processor)
assert len(train_dataset) > 0, "training split is empty"

In [None]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True)

# Pull one batch to sanity-check tensor shapes end to end.
batch = next(iter(train_dataloader))
for k, v in batch.items():
    print(k, v.shape)

# This cell only inspects the forward pass (no backward), so skip building the
# autograd graph; otherwise the ViT-H encoder's activations are kept in memory.
with torch.no_grad():
    outputs = model(pixel_values=batch["pixel_values"].to(device),
                    input_points=batch["input_points"].to(device),
                    input_labels=batch["input_labels"].to(device),
                    multimask_output=False)
print(outputs.keys())
# Print summaries rather than dumping whole mask tensors into the output.
print('pred_masks:', outputs.pred_masks.shape, outputs.pred_masks.dtype)
print('ground_truth_mask:', batch["ground_truth_mask"].shape,
      batch["ground_truth_mask"].dtype)

In [None]:
import segmentation_models_pytorch as smp

from transformers.models.maskformer.modeling_maskformer import dice_loss, sigmoid_focal_loss


def postprocess_masks(masks: torch.Tensor,
                      input_size: Tuple[int, ...],
                      original_size: Tuple[int, ...], image_size=1024) -> torch.Tensor:
    """Undo SAM's resize-and-pad so masks align with the original image.

    The masks are first scaled to the padded square model input
    (``image_size`` x ``image_size``), the padding beyond ``input_size`` is
    cropped off, and the result is resized to ``original_size``. All
    resampling is bilinear with ``align_corners=False``.

    Args:
        masks: batched masks (e.g. mask-decoder logits) in B x C x H x W layout.
        input_size: (H', W') of the resized image before padding.
        original_size: (H, W) of the image before any resizing.
        image_size: side length of the padded square model input.

    Returns:
        Masks in B x C x H x W layout with (H, W) equal to ``original_size``.
    """
    def _resize(tensor: torch.Tensor, size) -> torch.Tensor:
        return F.interpolate(tensor, size, mode="bilinear", align_corners=False)

    padded = _resize(masks, (image_size, image_size))
    crop_h, crop_w = input_size[0], input_size[1]
    return _resize(padded[..., :crop_h, :crop_w], original_size)

print(outputs.keys())
# pred_masks: [bt_size, nb_predictions, nb_per_pred, H, W]; squeeze(1) gives
# [bt_size, nb_per_pred, H, W], i.e. [bt_size, 1, H, W] with multimask_output=False.
low_res_masks = outputs.pred_masks
upscaled_masks = postprocess_masks(
    low_res_masks.squeeze(1),
    batch["reshaped_input_sizes"][0].tolist(),
    batch["original_sizes"][0].tolist()).to(device)

# Process upscaled masks: IoU of the thresholded sigmoid probabilities vs. ground truth.
predicted_masks = torch.sigmoid(upscaled_masks)
gt_mask = batch["ground_truth_mask"].to(device)
print('mask.shape:    ', predicted_masks.shape)
print('gt_mask.shape: ', gt_mask.shape)
batch_tp, batch_fp, batch_fn, batch_tn = smp.metrics.get_stats(
    predicted_masks,
    gt_mask.unsqueeze(1),
    mode='binary',
    threshold=0.5,
)
batch_iou = smp.metrics.iou_score(batch_tp, batch_fp, batch_fn, batch_tn)
print('iou_scores.shape: ', outputs.iou_scores.shape)
print('batch_iou.shape:  ', batch_iou.shape)
# The IoU-prediction head is regressed onto the IoU the mask actually achieves.
loss_iou = F.mse_loss(outputs.iou_scores.squeeze(1),
                      batch_iou, reduction='mean')
print('batch_tp : ', batch_tp.data)
print('batch_fp : ', batch_fp.data)
print('batch_fn : ', batch_fn.data)
print('batch_tn : ', batch_tn.data)
print('batch_iou: ', batch_iou.data)
print('loss_iou : ', loss_iou.data)

# Focal and dice losses on the flattened per-image mask logits.
mask_logits = upscaled_masks.flatten(1)
gt_mask_flat = gt_mask.flatten(1).float()
nb_masks = mask_logits.shape[0]
print('mask_logits.shape : ', mask_logits.shape)
print('gt_mask_flat.shape: ', gt_mask_flat.shape)
loss_focal = sigmoid_focal_loss(mask_logits, gt_mask_flat, nb_masks)
loss_dice = dice_loss(mask_logits, gt_mask_flat, nb_masks)
print('loss_focal: ', loss_focal.data)
print('loss_dice : ', loss_dice.data)

def criterion_mse(outputs, gt_mask, batch):
    """Mean squared error between predicted mask probabilities and the ground truth.

    Args:
        outputs: SamModel output; only ``pred_masks`` is used.
        gt_mask: ground-truth masks, expected (B, H, W) at the original image
            size; integer masks are cast to the prediction's float dtype.
        batch: collated batch providing ``reshaped_input_sizes`` and
            ``original_sizes``.

    Returns:
        Scalar loss tensor.
    """
    low_res_masks = outputs.pred_masks
    # Uses the first sample's sizes for the whole batch (assumes they are shared).
    upscaled_masks = postprocess_masks(
        low_res_masks.squeeze(1),
        batch["reshaped_input_sizes"][0].tolist(),
        batch["original_sizes"][0].tolist()).to(gt_mask.device)
    # A sigmoid keeps the prediction differentiable. The former
    # normalize(threshold(x, 0.0, 0)) over the single mask channel sent every
    # positive logit to 1 and everything else to 0, so its gradient was zero
    # almost everywhere and the loss could not train the model.
    predicted_masks = torch.sigmoid(upscaled_masks)
    # Float target: an integer target triggers a dtype mismatch in MSE's backward.
    target = gt_mask.unsqueeze(1).to(predicted_masks.dtype)
    return F.mse_loss(predicted_masks, target, reduction='mean')


def criterion_mde(outputs, gt_mask, batch):
    """MONAI ``DiceCELoss`` (sigmoid activation, squared predictions in the
    Dice denominator, mean reduction) on the upscaled mask logits.

    Args:
        outputs: SamModel output; only ``pred_masks`` is used.
        gt_mask: ground-truth masks at the original image size; a channel
            axis is inserted at dim 1 to match the predictions.
        batch: collated batch providing ``reshaped_input_sizes`` and
            ``original_sizes`` (the first sample's sizes are used).

    Returns:
        Scalar loss tensor.
    """
    reshaped_size = batch["reshaped_input_sizes"][0].tolist()
    original_size = batch["original_sizes"][0].tolist()
    logits = postprocess_masks(outputs.pred_masks.squeeze(1),
                               reshaped_size, original_size)
    logits = logits.to(gt_mask.device)
    dice_ce = monai.losses.DiceCELoss(sigmoid=True,
                                      squared_pred=True,
                                      reduction='mean')
    return dice_ce(logits, gt_mask.unsqueeze(1))


def criterion_sam(outputs, gt_mask, batch,
                  focal_weight: float = 20.,
                  dice_weight: float = 1.,
                  iou_weight: float = 1.):
    """SAM-style loss: weighted focal + dice on the mask logits plus an MSE
    term that trains the IoU-prediction head.

    Args:
        outputs: SamModel output providing ``pred_masks`` and ``iou_scores``.
        gt_mask: ground-truth masks at the original image size with an integer
            dtype (required by ``smp.metrics.get_stats``).
        batch: collated batch providing ``reshaped_input_sizes`` and
            ``original_sizes`` (the first sample's sizes are used for all).
        focal_weight, dice_weight, iou_weight: weights of the three terms; the
            defaults reproduce the previous fixed ``iou + 20 * focal + dice``.

    Returns:
        Scalar loss tensor.
    """
    low_res_masks = outputs.pred_masks
    upscaled_masks = postprocess_masks(
        low_res_masks.squeeze(1),
        batch["reshaped_input_sizes"][0].tolist(),
        batch["original_sizes"][0].tolist()).to(gt_mask.device)

    # IoU target: IoU of the thresholded (p > 0.5) prediction vs. ground truth.
    predicted_masks = torch.sigmoid(upscaled_masks)
    batch_tp, batch_fp, batch_fn, batch_tn = smp.metrics.get_stats(
        predicted_masks,
        gt_mask.unsqueeze(1),
        mode='binary',
        threshold=0.5,
    )
    batch_iou = smp.metrics.iou_score(batch_tp, batch_fp, batch_fn, batch_tn)
    loss_iou = F.mse_loss(outputs.iou_scores.squeeze(1),
                          batch_iou, reduction='mean')

    # Focal and dice losses on the flattened per-image logits.
    mask_logits = upscaled_masks.flatten(1)
    gt_flat = gt_mask.flatten(1).float()
    nb_masks = mask_logits.shape[0]
    loss_focal = sigmoid_focal_loss(mask_logits, gt_flat, nb_masks)
    loss_dice = dice_loss(mask_logits, gt_flat, nb_masks)
    return iou_weight * loss_iou + focal_weight * loss_focal + dice_weight * loss_dice

# Evaluate each candidate criterion on the same sample batch, in this order.
for criterion in (criterion_sam, criterion_mde, criterion_mse):
    print(criterion(outputs, gt_mask, batch))


In [None]:
low_res_masks = outputs.pred_masks
upscaled_masks = postprocess_masks(low_res_masks.squeeze(1),
                                   batch["reshaped_input_sizes"][0].tolist(),
                                   batch["original_sizes"][0].tolist())
# Probabilities on the CPU, without autograd history, mask axis dropped.
mask_prob = torch.sigmoid(upscaled_masks).detach().cpu().squeeze(1)
# Hard (0/1) mask: 1 where the predicted probability exceeds 0.5.
sam_mask = mask_prob.gt(0.5).to(torch.uint8)
# TODO: score sam_mask against the ground truth (IoU / Dice); the helpers
# `mask_iou` and `DSC` are not defined in the visible notebook.
mask_prob.shape