In [1]:
import torch
import torch.nn.functional as F
import segmentation_models_pytorch as smp
import cv2
import copy
from IPython import display
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
import numpy as np

In [2]:
model_path = r'C:\STUDY\BileDcutSegmentation\0_62_0.4999_0.5777.pth'
video_path = r'C:\STUDY\BileDcutSegmentation\test_video.mp4'
to_video_path = r'inferenced_test_video.mp4'

In [3]:
class CFG:
    ImageSize = 256
    ImageMean = [0.496, 0.280, 0.313]
    ImageStd = [0.282, 0.259, 0.263]
    
    OutputClass = 3
    
    Device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    # Device = 'cpu'
    

In [4]:
def GetModel():
    model = smp.DeepLabV3Plus(
         encoder_name='resnet50',
         encoder_weights='imagenet',
         in_channels=3,
         classes=3, 
         activation='softmax')
    return model


def ImgPreprocess(img):
    aug = A.Compose([
        A.Normalize(mean=CFG.ImageMean, std=CFG.ImageStd, max_pixel_value=255.0, p=1.0),
        A.Resize(CFG.ImageSize, CFG.ImageSize),
        ToTensorV2(p=1.0)
    ])
    img = aug(image=img)
    img = img['image'].to(CFG.Device)
    img = torch.unsqueeze(img, dim=0)
    return img


def OutputPreprocess(outputs, img_w, img_h):
    outputs = torch.argmax(outputs, dim=1, keepdim=True)
    outputs = outputs.detach().to('cpu').permute(0,2,3,1).numpy()[0].astype(np.uint8)
    outputs = cv2.resize(outputs, (img_w, img_h), interpolation=cv2.INTER_LINEAR)
    outputs = torch.tensor(outputs).to(torch.int64)
    outputs = F.one_hot(outputs, CFG.OutputClass)
    return outputs


def apply_mask(image, mask, color, alpha=0.4):
    for c in range(3):
        image[:, :, c] = np.where(mask == 1,
                                  image[:, :, c] * (1 - alpha) + alpha * color[c] * 255,
                                  image[:, :, c])
    return image  


def ApplyMask(image, masks, num_class, color_random=False):
    if color_random:
        def random_colors(N, bright=True):
            brightness = 1.0 if bright else 0.7
            hsv = [(i / N, 1, brightness) for i in range(N)]
            colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv))
            random.shuffle(colors)
            return colors        
        colors = random_colors(num_class)
        
    else:
        colors = [(1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (0.0, 0.0, 1.0)]

    for i in range(num_class):
        if i == 0:
            continue

        color = colors[i]
        mask = masks[:, :, i]
        masked_image = apply_mask(image, mask, color)        
        
    return masked_image



def BileDuctInference(video_path, to_video_path):
    
    cap = cv2.VideoCapture(video_path)

    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    frameRate = cap.get(cv2.CAP_PROP_FPS)
    length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    print('frame_width =', frame_width)
    print('frame_height =', frame_height)
    print('frame_rate =', frameRate)
    print(f'frame_length: {length}')

    # frameRate = 30
    # length = length//2

    # Setting codec
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    # generate video file to save images
    out = cv2.VideoWriter(to_video_path, fourcc, frameRate, (frame_width, frame_height))

    cnt = 0
    while True:
        retval, frame = cap.read()	
        cnt += 1
        if not retval:
            break   

        img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        img_h, img_w = int(img.shape[0]), int(img.shape[1])
        ori_img = copy.deepcopy(img)
        # print(img.shape)

        img = ImgPreprocess(img)
        # print(img.shape)

        outputs = model(img)
        outputs = OutputPreprocess(outputs, img_w, img_h)
        # print(outputs.shape, np.unique(outputs))

        masked_img = ApplyMask(ori_img, outputs, CFG.OutputClass)
        out.write(cv2.cvtColor(masked_img, cv2.COLOR_RGB2BGR))	
        print(f'{cnt}/{length}')
        display.clear_output(wait=True)    

    cap.release()	
    out.release()
    
    print(f'Done!')
    print(f"Saved at {to_video_path}")

In [5]:
model = GetModel().to(CFG.Device)
model.load_state_dict(torch.load(model_path, map_location=CFG.Device))
model.eval()

DeepLabV3Plus(
  (encoder): ResNetEncoder(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequentia

In [6]:
BileDuctInference(video_path, to_video_path)

Done!
Saved at inferenced_test_video.mp4
