In [1]:
import torch
import torchvision
from torchvision import transforms as T
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
import cv2
from collections import defaultdict
import numpy as np
import math
from operator import itemgetter
import copy

In [2]:
# from google.colab.patches import cv2_imshow

In [3]:
# from google.colab import drive
# drive.mount('/content/drive')

In [4]:
model = torchvision.models.detection.maskrcnn_resnet50_fpn()
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features , 2)
in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
hidden_layer = 256
model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask , hidden_layer , 2)

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [6]:
model.to(device)

MaskRCNN(
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): FrozenBatchNorm2d(64, eps=1e-05)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64, eps=1e-05)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64, eps=1e-05)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256, eps=1e-05)
          (relu):

In [7]:
params = [p for p in model.parameters() if p.requires_grad]

In [8]:
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)

In [9]:
PATH = f'./model_120.pth'
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch'] + 1
all_train_losses = checkpoint['all_train_losses']
all_val_losses = checkpoint['all_val_losses']
model.eval()

MaskRCNN(
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): FrozenBatchNorm2d(64, eps=1e-05)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64, eps=1e-05)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64, eps=1e-05)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256, eps=1e-05)
          (relu):

In [10]:
def minimum_distance(item, class_items):
    min_dist = 200
    min_item = []
    for class_item in class_items:
        if math.sqrt(math.pow(class_item["x"] - item["x"], 2) + math.pow(class_item["y"] - item["y"], 2)) < min_dist:
            min_dist = math.sqrt(math.pow(class_item["x"] - item["x"], 2) + math.pow(class_item["y"] - item["y"], 2))
            min_item = copy.deepcopy(class_item)
    return min_item

In [11]:
def find_track_id(item, class_id_item, track_history):
    for key, class_items in track_history.items():
        if key == class_id_item:
            min_item = minimum_distance(item, class_items)
            return min_item

In [12]:
video_path = "WIN_20231022_17_21_16_Pro.mp4"

cap = cv2.VideoCapture(video_path)

if cap.isOpened():
    success, frame = cap.read()

    (h, w) = frame.shape[:2]
    writer = cv2.VideoWriter('Video_output_mask.mp4', -1, 15.0, (w, h), True)
    writer1 = cv2.VideoWriter('Video_output_color_mask.mp4', -1, 15.0, (w, h), True)

    track_history = defaultdict(lambda: [])

    count = 0
    track_id_count = 0

    while cap.isOpened():
        success, frame = cap.read()

        if success:
            count +=1
            transform = T.ToTensor()
            ig = transform(frame)
            with torch.no_grad():
                pred = model([ig.to(device)])

            mask = (pred[0]["masks"][0].cpu().detach().numpy() * 255).astype("uint8").squeeze()
            mask[mask >= 150] = 150
            mask[mask < 150] = 0
            mask[mask == 150] = 255
            cv2.imshow("Mask" , mask)
            writer.write(mask)

        
            mask1 = copy.deepcopy(mask)
            mask1 = np.dstack((mask1, mask1, mask1))
            mask1[mask1 < 255] = 0
            mask1[mask1 >= 255] = 1
            color_mask = np.multiply(frame, mask1)
            cv2.imshow("Color mask" , color_mask)
            writer1.write(color_mask)

            if cv2.waitKey(30) & 0xFF == ord("q"):
                break
        else:
            break

cap.release()
cv2.destroyAllWindows()
writer.release()
writer1.release()

In [13]:
# mask2 = mask1.reshape((mask1.shape[0], mask1.shape[1], 1)).shape
# cv2.imshow("Color mask" , mask2)

In [14]:
# print(mask.shape)
# mask2 = np.dstack((mask, mask, mask))
# print(mask2.shape)

In [15]:
# mask2 = np.expand_dims(mask, axis=-1)
# print(mask2.shape)
# mask2[mask2 < 255] = 0
# mask2[mask2 >= 255] = 1
# color_mask = np.multiply(frame, mask2)
# cv2.imshow("Color mask" , color_mask)
# cv2.waitKey(0)
# cv2.destroyAllWindows()

In [16]:
# import matplotlib.pyplot as plt

# plt.hist(mask, bins=2)
# plt.show()

In [17]:
# pred[0]['boxes'], pred[0]['labels'], pred[0]['masks']