In [2]:
from __future__ import print_function 
import cv2
import torch 
import torch.nn.functional as F
from torchvision import transforms 
import time

In [3]:
# Output channel treated as "person" by makeSegMask (index 15 in the Pascal
# VOC label order used by torchvision's segmentation models — confirm if the
# model is swapped).
people_class = 15

# Fetch DeepLabV3-ResNet101 with pretrained weights from the torchvision
# v0.6.0 hub entry point and switch it to inference mode in the same
# expression (nn.Module.eval() returns the module itself).
model = torch.hub.load(
    'pytorch/vision:v0.6.0',
    'deeplabv3_resnet101',
    pretrained=True,
).eval()

print("Model has been loaded.")

Downloading: "https://github.com/pytorch/vision/zipball/v0.6.0" to /home/kinlo/.cache/torch/hub/v0.6.0.zip
Downloading: "https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth" to /home/kinlo/.cache/torch/hub/checkpoints/deeplabv3_resnet101_coco-586e9e4e.pth
100.0%


Model has been loaded.


In [4]:
# 3x3 binomial smoothing kernel: the outer product of [1, 2, 1] with itself,
# divided by 16 so the weights sum to 1. It is shaped (out_ch, in_ch, kH, kW)
# = (1, 1, 3, 3) as F.conv2d expects for a single-channel image.
_blur_taps = torch.tensor([1.0, 2.0, 1.0], dtype=torch.float32)
blur = (_blur_taps[:, None] * _blur_taps[None, :] / 16.0).reshape(1, 1, 3, 3)
del _blur_taps

# Keep the model and the kernel on the GPU when one is available.
_gpu_available = torch.cuda.is_available()
if _gpu_available:
    model.to('cuda')
    blur = blur.to('cuda')

# Per-channel normalisation with the ImageNet mean/std used by torchvision's
# pretrained models; applied to a float CHW tensor scaled to [0, 1].
preprocess = transforms.Compose(
    [transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]
)

# Function to create segmentation mask
def makeSegMask(img):
    """Return a soft 3-channel "person" mask for a single image.

    The image goes through the segmentation ``model``, and two cues are combined:

    * a binary map of pixels whose arg-max class is ``people_class``, softened
      by three passes of the 3x3 ``blur`` kernel;
    * a per-pixel weight derived from the background logit (channel 0). The
      weight is 2.0 while that logit is <= 1 / 0.30 and falls towards 0 as the
      logit grows, i.e. as the model grows more confident the pixel is background.

    Args:
        img: H x W x 3 array with values in [0, 255] (e.g. a uint8 frame).
            ``preprocess`` normalises with ImageNet RGB statistics, so the
            channels should be in RGB order (convert OpenCV BGR frames first).

    Returns:
        C-contiguous numpy uint8 array of shape (H, W, 3). All three channels
        hold the same value, from 0 (background) to 255 (person).

    Relies on the module-level ``model``, ``preprocess``, ``blur`` and
    ``people_class`` defined in earlier cells.
    """
    # Scale to [0, 1], reorder HWC -> CHW, normalise, and add a batch dimension.
    frame = torch.as_tensor(img, dtype=torch.float32) / 255.0
    input_batch = preprocess(frame.permute(2, 0, 1)).unsqueeze(0)

    # Use the GPU if supported, for better performance.
    if torch.cuda.is_available():
        input_batch = input_batch.to('cuda')

    with torch.no_grad():
        logits = model(input_batch)['out'][0]  # (num_classes, H, W)

    # Background-confidence weight; the 0:1 slice keeps a channel dim: (1, H, W).
    bg_logit = logits[0:1]
    bg_weight = (1.0 - F.relu(torch.tanh(bg_logit * 0.30 - 1.0))).pow(0.5) * 2.0

    # Binary person map shaped (1, 1, H, W) for conv2d. The 3x3 kernel with
    # padding=1 keeps H x W on every blur pass.
    people = (logits.argmax(0) == people_class).float()[None, None]
    for _ in range(3):
        people = F.conv2d(people, blur, stride=1, padding=1)

    # Combine the cues and clamp to [0, 1]: relu(hardtanh(x)) == clamp(x, 0, 1).
    # Explicit [0, 0] indexing (instead of squeeze()) keeps 1-pixel-wide or
    # 1-pixel-tall images from losing a spatial dimension.
    combined = F.relu(F.hardtanh(bg_weight * people[0, 0].pow(1.5)))  # (1, H, W)

    # Replicate to 3 channels and convert to an H x W x 3 uint8 image. The final
    # contiguous() makes the array safe for OpenCV functions that draw in place.
    rgb = (combined * 255.0).expand(3, -1, -1)
    return rgb.cpu().byte().permute(1, 2, 0).contiguous().numpy()



In [6]:
# Settings for this run.
VIDEO_IN = '../Persona8/persona8_9_f.webm'
VIDEO_OUT = 'Persona_fcn.avi'
OUTPUT_FPS = 30
SEG_SIZE = (256, 256)  # (width, height) of the frames fed to makeSegMask
ESC_KEY = 27

video = cv2.VideoCapture(VIDEO_IN)

# Source frame size. The writer is opened with exactly this size; frames
# passed to out.write() with a different size are not written.
frame_width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))

# Creates the Motion-JPEG output video file.
out = cv2.VideoWriter(VIDEO_OUT, cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'),
                      OUTPUT_FPS, (frame_width, frame_height))

try:
    # isOpened must be *called*; the bare bound method is always truthy.
    while video.isOpened():
        # Read the next frame; stop at end of stream or on a read error.
        success, img = video.read()
        if not success:
            break

        # Downscale for the model, and convert OpenCV's BGR frame to the RGB
        # channel order the pretrained segmentation network expects.
        img_resized = cv2.resize(img, SEG_SIZE)
        mask = makeSegMask(cv2.cvtColor(img_resized, cv2.COLOR_BGR2RGB))

        # Show the working-resolution mask. Write it back at the source
        # resolution so it matches the size the writer was opened with.
        cv2.imshow('Silhouette Mask', mask)
        out.write(cv2.resize(mask, (frame_width, frame_height)))

        # Allow early termination with the Esc key.
        if cv2.waitKey(10) == ESC_KEY:
            break
finally:
    # Release resources even if processing a frame raised.
    cv2.destroyAllWindows()
    video.release()
    out.release()