In [11]:
# Project-local star imports stay first so that the explicit imports below
# win if any name is exported by both.
from Net import *
from VideoUtils import *

import gc
import glob
import os

import torch
from torch.autograd.variable import Variable
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.colors as co
import cv2
from torchvision import transforms
# Start from a clean slate when the notebook is re-run in a live kernel:
# collect unreachable Python objects, then release PyTorch's unused cached
# CUDA memory.
gc.collect()
torch.cuda.empty_cache()

In [12]:
# Build the network: instantiate HydraNet, then call its encoder, decoder and
# weight-initialisation methods in this order.
# NOTE(review): HydraNet comes from a star import; presumably the two calls
# attach a MobileNet encoder and a RefineNet decoder to the instance -- confirm
# in the defining module.
model = HydraNet()
model.mobilenet_encoder()
model.refinenet_decoder()
model.initialize_weights()

In [13]:
# Load the trained weights. `map_location='cpu'` lets a checkpoint that was
# saved from GPU tensors be opened on a machine without CUDA; the model is
# moved to the GPU later in the notebook when one is available.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
ckpt = torch.load('KITTI.ckpt', map_location='cpu')
# Left as the last expression so the key-matching report is shown as output.
model.load_state_dict(ckpt['state_dict'])

<All keys matched successfully>

In [14]:
# Colour lookup table: `pipeline` indexes it with per-pixel class ids
# (CMAP[class_id]) to paint the segmentation map.
CMAP = np.load('cmap.npy')
# Only the first NUM_CLASSES channels of the segmentation output are used.
NUM_CLASSES = 6

In [15]:
# Select the input clip; the alternatives are kept commented out.
# NOTE(review): write_video comes from a star import; presumably it dumps the
# clip's frames as PNGs into video_output/, which the rendering loop later
# reads -- confirm in the defining module.
# NOTE(review): the backslash separators make these paths Windows-only.
# write_video('video_input\\daytraffic.mp4')
# write_video('video_input\\rainy.mp4')
write_video('video_input\\night.mp4')

In [16]:
# Run on the GPU when CUDA is available; otherwise the model stays on the CPU.
if torch.cuda.is_available():
    _ = model.cuda() # gpu
# Assigning the result to `_` keeps the returned module from being echoed as
# the cell's output.
_ = model.eval() # eval mode

In [17]:
# Normalisation constants. MEAN and STD have shape (1, 1, 3) so they broadcast
# across the trailing channel axis of an H x W x 3 image.
# NOTE(review): these look like the usual ImageNet channel statistics --
# confirm they match how the checkpoint was trained.
IMG_SCALE = 1.0 / 255
IMG_MEAN = np.array([[[0.485, 0.456, 0.406]]])
IMG_STD = np.array([[[0.229, 0.224, 0.225]]])


def prepare_img(img):
    """Normalise an image for the network.

    The pixel values are multiplied by IMG_SCALE, then IMG_MEAN is subtracted
    and the result is divided by IMG_STD, channel by channel.

    Args:
        img: array whose last axis holds the 3 colour channels.

    Returns:
        A float array with the same shape as `img`.
    """
    scaled = img * IMG_SCALE
    centred = scaled - IMG_MEAN
    return centred / IMG_STD

In [18]:
def pipeline(img):
    """Run the network on one frame and return depth and coloured segmentation.

    Args:
        img: H x W x 3 image array (it is normalised with `prepare_img` and
            transposed to channels-first before being fed to the model).

    Returns:
        (depth, segm) where
        depth is a 2-D float array resized to the H x W of `img`, holding the
            absolute value of the network's depth output, and
        segm is a uint8 array obtained by looking up the per-pixel arg-max
            class id in CMAP, also at the H x W of `img`.
    """
    # Send the input to wherever the model's weights live, so the two can never
    # end up on different devices.
    device = next(model.parameters()).device
    # Channels-first plus a leading batch axis of size 1.
    batch = torch.from_numpy(prepare_img(img).transpose(2, 0, 1)[None]).float().to(device)

    with torch.no_grad():
        depth, segm = model(batch)

    # cv2.resize expects (width, height), i.e. the reverse of shape[:2].
    out_size = img.shape[:2][::-1]
    # Resize the per-class scores (channels-last for OpenCV) before taking the
    # arg-max, so class boundaries are decided at full resolution.
    segm = cv2.resize(segm[0, :NUM_CLASSES].detach().cpu().numpy().transpose(1, 2, 0),
                      out_size,
                      interpolation=cv2.INTER_LANCZOS4)
    depth = cv2.resize(depth[0, 0].detach().cpu().numpy(),
                       out_size,
                       interpolation=cv2.INTER_LANCZOS4)
    segm = CMAP[segm.argmax(axis=2)].astype(np.uint8)
    # Discard the sign of the depth values (the resampled network output may
    # contain negatives).
    depth = np.abs(depth)
    return depth, segm

def depth_to_rgb(depth, vmin=0, vmax=80, cmap='plasma'):
    """Colour-code a 2-D depth map as a uint8 RGB image.

    Args:
        depth: 2-D array of depth values.
        vmin: depth value mapped to the low end of the colormap.
        vmax: depth value mapped to the high end of the colormap.
            NOTE(review): the default of 80 is presumably the depth range of
            the training data -- confirm.
        cmap: name of the matplotlib colormap to use.

    Returns:
        H x W x 3 uint8 array.
    """
    normalizer = co.Normalize(vmin=vmin, vmax=vmax)
    mapper = cm.ScalarMappable(norm=normalizer, cmap=cmap)
    # to_rgba yields RGBA floats; drop alpha and rescale to 0..255.
    colormapped_im = (mapper.to_rgba(depth)[:, :, :3] * 255).astype(np.uint8)
    return colormapped_im

In [19]:
# Free memory left over from earlier cells before the inference loop.
gc.collect()
torch.cuda.empty_cache()

# Frames are processed in lexicographic filename order.
video_files = sorted(glob.glob('video_output/*png'))
if not video_files:
    raise FileNotFoundError("no frames found matching 'video_output/*png'")

# cv2.VideoWriter fails silently when the target directory does not exist.
os.makedirs('predictions', exist_ok=True)

# Each frame is written as soon as it is computed instead of buffering the
# whole clip in memory; the writer is created once the frame size is known.
out = None
try:
    for img_path in video_files:
        # Force 3 channels so RGBA / grayscale PNGs can be concatenated with
        # the 3-channel prediction images.
        image = np.array(Image.open(img_path).convert('RGB'))
        depth, seg = pipeline(image)
        # Side-by-side panel: input | segmentation | colour-coded depth.
        panel = cv2.hconcat([image, seg, depth_to_rgb(depth)])

        if out is None:
            h, w, _ = image.shape
            # 'mp4v' (lower case) is the tag the MP4 container expects.
            out = cv2.VideoWriter('predictions/out.mp4',
                                  cv2.VideoWriter_fourcc(*'mp4v'),
                                  15,
                                  (3 * w, h))
            if not out.isOpened():
                raise OSError("could not open 'predictions/out.mp4' for writing")

        # PIL delivers RGB, OpenCV writes BGR.
        out.write(cv2.cvtColor(panel, cv2.COLOR_RGB2BGR))
finally:
    # Always finalise the file, even if a frame fails part-way through.
    if out is not None:
        out.release()

In [20]:
# Same path the rendering cell writes to. A forward slash avoids the invalid
# '\o' escape sequence and also works outside Windows.
play_video('predictions/out.mp4')