In [None]:
import torch
import torchvision
import sys
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import os
import json
from whuvid_dataset import WhuvidDataset
import torchvision.transforms as transforms
import time

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import tqdm.notebook as tqdm

def show_mask(mask, ax, random_color=False):
    """Overlay ``mask`` on ``ax`` as a semi-transparent RGBA image.

    Args:
        mask: Array whose last two dimensions are (height, width); it is
            reshaped to ``(height, width, 1)`` before being coloured.
        ax: Object exposing ``imshow`` (e.g. a matplotlib Axes).
        random_color: When true, a random RGB colour is drawn; otherwise a
            fixed blue is used. Alpha is 0.6 in both cases.
    """
    if not random_color:
        rgba = np.array([30/255, 144/255, 255/255, 0.6])
    else:
        rgba = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    rows, cols = mask.shape[-2:]
    # Broadcasting (rows, cols, 1) against (1, 1, 4) yields an RGBA image.
    overlay = mask.reshape(rows, cols, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)
    
def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on ``ax``: label 1 in green, label 0 in red.

    Args:
        coords: Array indexed as ``coords[:, 0]`` (x) and ``coords[:, 1]`` (y).
        labels: Array compared element-wise against 1 and 0 to select rows.
        ax: Object exposing ``scatter`` (e.g. a matplotlib Axes).
        marker_size: Marker size forwarded as ``s``.
    """
    star_style = dict(marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    # Select both groups up front, then draw positives before negatives.
    groups = [(coords[labels == 1], 'green'), (coords[labels == 0], 'red')]
    for pts, colour in groups:
        ax.scatter(pts[:, 0], pts[:, 1], color=colour, **star_style)
    
def show_box(box, ax, color='green'):
    """Draw the outline of ``box`` (``x0, y0, x1, y1``) on ``ax``.

    The rectangle has a fully transparent face and a 2-point edge in ``color``.
    """
    left, top = box[0], box[1]
    span_x = box[2] - box[0]
    span_y = box[3] - box[1]
    outline = plt.Rectangle(
        (left, top),
        span_x,
        span_y,
        edgecolor=color,
        facecolor=(0,0,0,0),
        lw=2,
    )
    ax.add_patch(outline)

# torch.backends.cuda.matmul.allow_tf32 = True
# torch.backends.cudnn.allow_tf32 = True

import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamPredictor

# SAM checkpoint selection. The larger backbones are kept as commented-out
# alternatives; the active configuration is the ViT-B model.
# sam_checkpoint = "/home/thiago/Workspace/motion-segmentation/models/sam_vit_h_4b8939.pth"
# model_type = "vit_h"
# sam_checkpoint = "/home/thiago/Workspace/motion-segmentation/models/sam_vit_l_0b3195.pth"
# model_type = "vit_l"
sam_checkpoint = "/home/thiago/Workspace/motion-segmentation/models/sam_vit_b_01ec64.pth"
model_type = "vit_b"

# Prefer the GPU, but fall back to the CPU so this cell does not fail on
# machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the model from the registry entry matching `model_type` and move it to
# the selected device.
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

# Shared predictor used by the mask helpers defined below (they read
# `predictor.device` and call `predictor.set_image` / `predict_torch`).
predictor = SamPredictor(sam)

In [None]:
# WHUVID sequence ids processed by the mask-generation loop.
# Larger selection kept for reference:
# sequences = ["01", "17", "18", "19", "20", "22", "25", "30", "31", "32"]
sequences = ["03"]
# NOTE: sequence "16" is missing (presumably from the local dataset copy — TODO confirm).
# sequence = sequences[0]
# Root directory of the WHUVID dataset on the local machine.
root_path = "/home/thiago/Workspace/motion-segmentation/datasets/WHUVID"

In [None]:
# Frame size in pixels. These globals are read by the bounding-box helpers
# (scaling / clipping), by the example cell's resize, and for the all-zero
# masks written for frames without annotations.
# Half-resolution alternative:
# width = int(1280 * 0.5)
# height = int(720 * 0.5)
width = 1280
height = 720

In [None]:
def to_list_of_boxes(bb_raw):
    """Convert relative "car" detections into pixel corner boxes.

    Each entry of ``bb_raw`` is a mapping with a ``'name'`` and a
    ``'relative_coordinates'`` mapping holding ``center_x``, ``center_y``,
    ``width`` and ``height``. Entries whose name is not ``"car"`` are dropped;
    the rest are scaled by the module-level ``width`` / ``height``.

    Returns:
        List of ``[x0, y0, x1, y1]`` boxes in pixels.
    """
    boxes = []
    for detection in bb_raw:
        if detection['name'] != "car":
            continue
        rel = detection['relative_coordinates']
        # to pixels
        cx, cy = rel['center_x'] * width, rel['center_y'] * height
        box_w, box_h = rel['width'] * width, rel['height'] * height
        # centre/size -> top-left and bottom-right corners
        left, top = cx - box_w / 2, cy - box_h / 2
        boxes.append([left, top, left + box_w, top + box_h])
    return boxes

In [None]:
def get_data(sequence):
    """Load the images and per-frame annotations of one WHUVID sequence.

    Args:
        sequence: Sequence identifier (e.g. ``"22"``); it is passed to the
            dataset and interpolated into the ``objects_inferred.json`` path.

    Returns:
        Tuple ``(images, bounding_boxes)`` where ``images`` is the dataset's
        ``images`` attribute and ``bounding_boxes`` maps an integer frame id to
        the list of annotations of every object present in that frame.
    """
    whuvid_base_path = "/home/thiago/Workspace/motion-segmentation/datasets/WHUVID"
    to_tensor = transforms.Compose([transforms.ToTensor()])
    dataset = WhuvidDataset(whuvid_base_path, [sequence], to_tensor, segmentation=False, flow=False)

    objects_inferred_path = f"/home/thiago/Workspace/motion-segmentation/datasets/WHUVID/{sequence}/other_files/objects_inferred.json"
    with open(objects_inferred_path) as f:
        per_object = json.load(f)

    # The JSON is organised object -> frame -> annotation; regroup it by frame.
    bounding_boxes = {}
    for frames in per_object.values():
        for frame_key, ann in frames.items():
            bounding_boxes.setdefault(int(frame_key), []).append(ann)
    return dataset.images, bounding_boxes

In [None]:
def gen_mask(image, boxes):
    """Segment every box prompt with SAM and merge the results into one mask.

    Args:
        image: Image array of shape ``(H, W, ...)`` handed to
            ``predictor.set_image``.
        boxes: Sequence of ``[x0, y0, x1, y1]`` boxes in pixel coordinates of
            ``image``. May be empty.

    Returns:
        Tensor of shape ``(1, H, W)`` on ``predictor.device`` whose values are
        0 (background) or 1 (covered by at least one box mask). For empty
        ``boxes`` it is the default float tensor from ``torch.zeros``;
        otherwise it has the dtype produced by summing the predicted masks.
    """
    if len(boxes) == 0:
        # Nothing to segment: skip the (expensive) image embedding entirely.
        return torch.zeros(1, image.shape[0], image.shape[1], device=predictor.device)
    input_boxes = torch.tensor(boxes, device=predictor.device)
    transformed_boxes = predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])
    predictor.set_image(image)
    with torch.no_grad():
        masks, _, _ = predictor.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes,
            multimask_output=False,
        )
    # `masks` has shape (n, 1, h, w) with boolean values, n being the number of
    # boxes. Summing over the box dimension WITHOUT keepdim yields (1, h, w),
    # the same shape as the empty-box branch (keepdim=True used to return
    # (1, 1, h, w), making the two branches disagree).
    mask = masks.sum(dim=0)
    # Overlapping boxes sum to more than 1; limit to 0 or 1.
    return mask.clamp(0, 1)

def to_mask(img_path, boxes_and_labels):
    """Build the two-level annotation mask for one frame.

    Args:
        img_path: Path of the frame, read with OpenCV and converted to RGB.
        boxes_and_labels: Sequence of ``(box, flag)`` pairs where ``box`` is
            ``[x0, y0, x1, y1]`` in pixels. Boxes with ``flag == True`` are
            encoded as 255 and boxes with ``flag == False`` as 128.

    Returns:
        Tensor produced from the two ``gen_mask`` results with values in
        {0, 128, 255}. Where both groups overlap, the sum is clamped to 255,
        so the ``True`` group wins.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``img_path``.
    """
    image = cv2.imread(img_path)
    if image is None:
        # cv2.imread reports failure by returning None rather than raising;
        # fail here with a message that names the offending file.
        raise FileNotFoundError(f"Could not read image: {img_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # The image is used at its native resolution; the former resize to its own
    # (width, height) was a no-op and has been removed.
    boxes1 = [x[0] for x in boxes_and_labels if x[1] == True]
    boxes2 = [x[0] for x in boxes_and_labels if x[1] == False]
    # set any coordinate lower than 0 to 0
    boxes1 = [[max(0, x) for x in box] for box in boxes1]
    boxes2 = [[max(0, x) for x in box] for box in boxes2]
    mask1 = gen_mask(image, boxes1)
    mask2 = gen_mask(image, boxes2)
    # Overlapping pixels add up to 383 and are clamped back to 255 below.
    mask = 255 * mask1 + 128 * mask2
    # clamp to 0-255
    mask = mask.clamp(0, 255)
    return mask

In [None]:
def shift_bb(bb, horizontal, vertical):
    """Shift every ``[x0, y0, x1, y1]`` box in ``bb`` by the given pixel offsets.

    ``horizontal`` is added to both x coordinates and ``vertical`` to both y
    coordinates; a new list of boxes is returned.
    """
    shifted = []
    for box in bb:
        shifted.append([
            box[0] + horizontal,
            box[1] + vertical,
            box[2] + horizontal,
            box[3] + vertical,
        ])
    return shifted

In [None]:
def increase_bb(bb, horizontal, vertical):
    """Grow each ``[x0, y0, x1, y1]`` box by a margin and clip it to the frame.

    ``horizontal`` and ``vertical`` are truncated with ``int`` before use. The
    grown boxes are clipped to ``[0, width]`` x ``[0, height]`` using the
    module-level frame size.
    """
    dx, dy = int(horizontal), int(vertical)
    grown = []
    for box in bb:
        left, top = box[0] - dx, box[1] - dy
        right, bottom = box[2] + dx, box[3] + dy
        # limit to image size
        grown.append([max(0, left), max(0, top), min(width, right), min(height, bottom)])
    return grown

In [None]:
# Visual sanity check: run SAM on one annotated frame and show the frame with
# its bounding boxes next to the two label masks.
i = 30
seq_example = "22"
images, bounding_boxes = get_data(seq_example)
img_path = images[i]
# Each annotation is a (box, flag) pair: item[0] is [x0, y0, x1, y1] and
# item[1] is the flag used below as "is moving".
boxes = bounding_boxes[i]
percent = 0.01
image = cv2.imread(img_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# NOTE(review): only the displayed image is resized; the boxes and to_mask()
# keep the file's native resolution — confirm frames are already width x height.
image = cv2.resize(image, (width, height))
mask = to_mask(img_path, boxes)
# `mask` is already a tensor: compare and cast directly instead of wrapping it
# in torch.tensor(), which copy-constructs and emits a warning.
label1 = (mask >= 255).to(torch.uint8)
# lower than 255 and at least 128
label2 = ((mask < 255) & (mask >= 128)).to(torch.uint8)
image_bb = image.copy()
for item in boxes:
    is_moving = item[1]
    # Truncate to plain Python ints on the host; no GPU round-trip is needed
    # just to draw a rectangle.
    x0, y0, x1, y1 = np.asarray(item[0]).astype(int).tolist()
    # image_bb is RGB here, so flagged boxes are drawn blue and the rest green.
    color = (0, 0, 255) if is_moving else (0, 255, 0)
    cv2.rectangle(image_bb, (x0, y0), (x1, y1), color, 3)
# show image_bb, label1, label2 side by side
fig, ax = plt.subplots(1, 3, figsize=(20, 10))
ax[0].set_title("Imagem com bounding box")
ax[1].set_title("Mascara de objetos em movimento")
ax[2].set_title("Mascara de objetos estaticos")
ax[0].imshow(image_bb)
ax[1].imshow(label1.cpu().numpy().squeeze(), cmap='gray')
ax[2].imshow(label2.cpu().numpy().squeeze(), cmap='gray')
plt.show()

# Reload sequence "22" and inspect the annotations of a single frame.
path = os.path.join(root_path, '22')
cam_path = os.path.join(path, "cam0")
# get_data() expects the sequence id and builds the dataset / annotation paths
# itself; passing the full directory path produced an invalid annotation path.
images, bounding_boxes = get_data('22')

# Frames without annotations are absent from the dict, so use .get() to show
# an empty list instead of raising KeyError.
bounding_boxes.get(200, [])

In [None]:
# Generate and save one annotation mask per frame for every selected sequence.
# Masks are written next to the camera folder, in "<cam0>_masks_ann".
for seq in tqdm.tqdm(sequences):
    path = os.path.join(root_path, seq)
    cam_path = os.path.join(path, "cam0")
    mask_dir = cam_path + "_masks_ann"
    os.makedirs(mask_dir, exist_ok=True)
    images, bounding_boxes = get_data(seq)
    for i in tqdm.tqdm(range(len(images))):
        # NOTE(review): the mask reuses the frame's file name, extension
        # included; with a lossy format the 0/128/255 levels would not survive
        # exactly — confirm the frames use a lossless format.
        basename = os.path.basename(images[i])
        mask_path = os.path.join(mask_dir, basename)
        if i in bounding_boxes:
            try:
                mask = to_mask(images[i], bounding_boxes[i])
            except Exception as e:
                # Best effort: report the frame and keep processing the rest.
                print("Skipping image ", images[i], " because of error:", e)
                continue
            # (…, h, w) tensor -> (h, w) uint8 array for OpenCV
            mask = mask.cpu().numpy().astype(np.uint8).squeeze()
        else:
            # Frames without annotations get an all-background mask.
            mask = np.zeros((height, width), dtype=np.uint8)
        # cv2.imwrite signals failure through its return value; surface it
        # instead of silently leaving the mask file missing.
        if not cv2.imwrite(mask_path, mask):
            print("Failed to write mask ", mask_path)