# Imports

In [1]:
import os
import cv2
import utils
import pickle
from utils import MetricLogger
import numpy as np
import torch
from torch.optim.lr_scheduler import StepLR
import torchvision
from torchvision.utils import draw_bounding_boxes
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torchvision.models.detection import maskrcnn_resnet50_fpn, MaskRCNN_ResNet50_FPN_Weights
import torch.utils.data
import transforms as T
import matplotlib.pyplot as plt
from PIL import Image, ImageOps
from engine import train_one_epoch, evaluate
from tqdm.notebook import tqdm

#### define dataset

In [2]:
class Dataset(torch.utils.data.Dataset):
    """Inference dataset pairing each image with a body-label image.

    Images and labels are matched purely by position in the sorted directory
    listings of ``image_path`` and ``body_path``; the two folders are expected
    to contain corresponding files in the same sort order.

    Args:
        image_path: Directory containing the input images.
        body_path: Directory containing the body-label images.
        transform: When truthy, the masked image is passed through
            ``T.ToTensor()`` before being returned; otherwise a PIL image is
            returned.
    """

    def __init__(self, image_path, body_path, transform=True):
        self.transform = transform
        # List the image directory once so filenames and image_paths are
        # guaranteed to describe the same files in the same order.
        self.filenames = sorted(os.listdir(image_path))
        self.image_paths = [os.path.join(image_path, name) for name in self.filenames]
        self.body_paths = [os.path.join(body_path, name) for name in sorted(os.listdir(body_path))]

    def __getitem__(self, idx):
        """Return ``(masked_image, image_no_mask, filename)`` for item ``idx``.

        ``image_no_mask`` is the unmodified RGB image as a NumPy array.

        Raises:
            FileNotFoundError: If the image or label file cannot be decoded.
        """
        filename = self.filenames[idx]
        image_file = self.image_paths[idx]
        body_file = self.body_paths[idx]

        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than deep inside OpenCV.
        image_bgr = cv2.imread(image_file)
        if image_bgr is None:
            raise FileNotFoundError(f"could not read image: {image_file}")
        body = cv2.imread(body_file)
        if body is None:
            raise FileNotFoundError(f"could not read body label: {body_file}")

        image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

        # keep an untouched RGB copy for visualisation
        image_no_mask = image.copy()

        # apply body mask by element-wise multiplication
        # NOTE(review): presumably the label image holds 0/1 values; larger
        # values would wrap around in uint8 arithmetic -- confirm.
        image = Image.fromarray(image * body)

        target = {}

        if self.transform:
            image, target = T.ToTensor()(image, target)

        return image, image_no_mask, filename

    def __len__(self):
        """Return the number of images found in ``image_path``."""
        return len(self.image_paths)

#### define model

In [3]:
def get_instance_segmentation_model(num_classes):
    """Build a Mask R-CNN (ResNet-50 FPN) whose heads predict ``num_classes`` classes.

    The network is created with the default pre-trained weights, then its box
    predictor and mask predictor are swapped for freshly initialised ones
    sized for ``num_classes``.
    """
    net = maskrcnn_resnet50_fpn(weights=MaskRCNN_ResNet50_FPN_Weights.DEFAULT)
    heads = net.roi_heads

    # box head: reuse the existing classifier's input width
    box_in_features = heads.box_predictor.cls_score.in_features
    heads.box_predictor = FastRCNNPredictor(box_in_features, num_classes)

    # mask head: reuse the existing mask conv's input channels, 256 hidden channels
    mask_in_channels = heads.mask_predictor.conv5_mask.in_channels
    mask_hidden_channels = 256
    heads.mask_predictor = MaskRCNNPredictor(mask_in_channels, mask_hidden_channels, num_classes)

    return net

#### settings

In [7]:
# directory layout, relative to the working directory
base_path = ".."
data_path, model_path = (os.path.join(base_path, sub) for sub in ("data", "model"))
# suffix of the checkpoint file name
checkpoint_file = "MRCNN.pth.tar"

# input split and its image / label sub-folders
input_path = os.path.join(data_path, "joint-ep-of-thu-ego-heiko", "train")
image_path, body_path = (os.path.join(input_path, sub) for sub in ("images", "labels"))

# data loading parameters
batch_size, num_workers = 8, 4

#### Run

In [5]:
def overlay_two_images(image, overlay, ignore_color=(0, 0, 0)):
    """Blend ``overlay`` onto ``image`` at 50% opacity, skipping one colour.

    Args:
        image: Base image array; its dtype is used for the result.
        overlay: Array compared element-wise against ``ignore_color`` along
            its last axis and blended with ``image``.
        ignore_color: Colour (sequence of channel values) marking overlay
            pixels that should leave ``image`` untouched.

    Returns:
        Array where pixels whose overlay equals ``ignore_color`` are copied
        from ``image`` and all other pixels are the 50/50 blend of ``image``
        and ``overlay``, cast back to ``image.dtype``.
    """
    ignore_color = np.asarray(ignore_color)
    # True where every channel of the overlay pixel matches the ignored colour
    keep_original = (overlay == ignore_color).all(-1, keepdims=True)
    # blend in float, then cast back (truncating) to the image dtype
    blended = (image * 0.5 + overlay * 0.5).astype(image.dtype)
    return np.where(keep_original, image, blended)

In [None]:
# dataset
dataset = Dataset(image_path=image_path, body_path=body_path)
torch.manual_seed(1)
data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    collate_fn=utils.collate_fn,
)

# target folders
mask_save_path = os.path.join(input_path, "preds")
overlay_save_path = os.path.join(input_path, "overlays")
os.makedirs(mask_save_path, exist_ok=True)
os.makedirs(overlay_save_path, exist_ok=True)

# run
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# class setup is the same for every checkpoint, so define it once
num_classes = 2
CLASS_NAMES = ['__background__', 'skin']

for epoch in [5]:
    print("epoch: ", epoch)

    # load model: checkpoints are named "<epoch>_<checkpoint_file>"
    checkpoint_path = os.path.join(model_path, f"{epoch}_{checkpoint_file}")

    model = get_instance_segmentation_model(num_classes)
    model.to(device)
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.eval()

    # predict
    with torch.no_grad():
        pbar = tqdm(data_loader)
        for images, images_no_mask, filenames in pbar:
            images = [image.to(device) for image in images]
            predictions = model(images)
            for image_no_mask, pred, filename in zip(images_no_mask, predictions, filenames):
                # Always write PNG (lossless) regardless of the input
                # extension; a plain ".jpg" replace would miss e.g. ".JPG"
                # or ".jpeg" and let OpenCV pick a lossy format for the mask.
                filename = os.path.splitext(filename)[0] + ".png"

                # merge pred: a pixel is foreground if any instance mask
                # exceeds 0.5 there (all zeros when nothing was detected)
                pred_masks = (pred["masks"]).squeeze(dim=1).detach().cpu().numpy()
                merged_pred_mask = (pred_masks > 0.5).any(axis=0).astype(np.uint8)

                # create overlay: paint the predicted region green on the RGB image
                mask = cv2.merge([merged_pred_mask*0, merged_pred_mask*255, merged_pred_mask*0]).astype("uint8")
                overlay = overlay_two_images(image_no_mask, mask)

                # save
                # the mask is stored as a 0/1 label image
                cv2.imwrite(os.path.join(mask_save_path, filename), merged_pred_mask)
                # image_no_mask is RGB, but cv2.imwrite expects BGR; convert
                # so the saved overlay does not have red and blue swapped
                cv2.imwrite(os.path.join(overlay_save_path, filename), cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))

epoch:  5


  0%|          | 0/1001 [00:00<?, ?it/s]