In [1]:
# Names of the trained models to run. Each name is used as the folder under
# "models/" that holds the weights, and its substrings ("resnet"/"convnext",
# "unetplusplus"/"unet", "espcn") select which architecture gets built.
modelNames = [
    "convnext-unetplusplus-espcn-modified",
]

In [2]:
import os
import torch
import cv2
import tqdm
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np
import PIL.Image as pil
from torchvision import transforms
from collections import OrderedDict
from Models.EncoderModel import EncoderModelResNet, EncoderModelConvNeXt
from Models.DecoderModel import DepthDecoderModelUNET, DepthDecoderModelUNETPlusPlus, PoseDecoderModel
from Models.CameraNet import CameraNet
from utils import dispToDepth, transformParameters

In [3]:
# Inference is pinned to the CPU. To use a GPU when one is present, replace
# the value with: torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = "cpu"

In [4]:
# Build and load one (encoder, depth decoder) pair per entry in modelNames.
# The architecture is chosen from substrings of the model name; weights are
# read from "models/<modelName>/<WEIGHTS_SUBDIR>/".
WEIGHTS_SUBDIR = "weights_19"  # checkpoint folder used for every model

encoder_models = []
decoder_models = []
for modelName in modelNames:
    path = os.path.join("models/{}".format(modelName), WEIGHTS_SUBDIR)

    # --- encoder ---------------------------------------------------------
    if "resnet" in modelName:
        enc = EncoderModelResNet(50)
    elif "convnext" in modelName:
        enc = EncoderModelConvNeXt()
    else:
        raise Exception("Encoder not found.")

    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    encoderDict = torch.load(os.path.join(path, "encoder.pth"), map_location=device)
    # The encoder checkpoint may hold extra entries besides the weights.
    # Remove each one independently so that a single missing key no longer
    # leaves the others behind to break load_state_dict.
    height = encoderDict.pop("height", None)
    width = encoderDict.pop("width", None)
    use_stereo = encoderDict.pop("use_stereo", None)
    enc.load_state_dict(encoderDict)
    enc.to(device)
    enc.eval()

    # --- depth decoder ---------------------------------------------------
    espcn = "espcn" in modelName
    # "unetplusplus" must be tested first because it also contains "unet".
    if "unetplusplus" in modelName:
        depthDecoder = DepthDecoderModelUNETPlusPlus(enc.numChannels, espcn)
    elif "unet" in modelName:
        depthDecoder = DepthDecoderModelUNET(enc.numChannels, espcn)
    else:
        raise Exception("Decoder not found.")

    try:
        depthDecoder.load_state_dict(
            torch.load(os.path.join(path, "decoder.pth"), map_location=device)
        )
    except (OSError, RuntimeError):
        # "decoder.pth" is missing/unreadable (OSError) or its keys do not
        # match the module (RuntimeError): fall back to "depth.pth", whose
        # "conv.conv" key segments are renamed to "conv" before loading.
        odict = torch.load(os.path.join(path, "depth.pth"), map_location=device)
        odict_compat = OrderedDict(
            (key.replace("conv.conv", "conv"), value) for key, value in odict.items()
        )
        del odict
        depthDecoder.load_state_dict(odict_compat)
    depthDecoder.to(device)
    depthDecoder.eval()

    # Pose decoder loading is currently disabled; kept here for reference:
    # if "camnet" in modelName:
    #     poseDecoder = CameraNet(enc.numChannels[-1], 192//8, 640//8)
    # else:
    #     poseDecoder = PoseDecoderModel(enc.numChannels, 2, 1)
    # poseDecoder.load_state_dict(torch.load(os.path.join(path, "pose.pth"), map_location=device))
    # poseDecoder.to(device)
    # poseDecoder.eval()

    encoder_models.append(enc)
    decoder_models.append(depthDecoder)

In [5]:
def getFrame(video_in, frame_idx, max_attempts=5, wait_seconds=1):
    """Seek to ``frame_idx`` and read that frame, retrying on failure.

    Args:
        video_in: Capture object exposing ``set(prop, value)`` and ``read()``
            (``read()`` returns an ``(ok, frame)`` pair).
        frame_idx: Frame position passed to ``cv2.CAP_PROP_POS_FRAMES``.
        max_attempts: Maximum number of ``read()`` calls before giving up.
        wait_seconds: Pause after every failed read, in seconds.

    Returns:
        The frame from the first successful read, or ``None`` when every
        attempt failed.
    """
    # Local import: the notebook's import cell does not bring ``sleep`` into
    # scope, so the original retry path raised NameError on the first failure.
    from time import sleep

    video_in.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
    for _ in range(max_attempts):
        ok, frame = video_in.read()
        if ok:
            return frame
        print("Waiting ... ")
        sleep(wait_seconds)
    return None

In [6]:
# Run every loaded model over the input video and write one colour-mapped
# disparity video per model next to it.
MODEL_INPUT_SIZE = (640, 192)  # (width, height) each frame is resized to before inference

video_in = cv2.VideoCapture("predictions/testVideo.mp4")
fps = video_in.get(cv2.CAP_PROP_FPS)
width = int(video_in.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(video_in.get(cv2.CAP_PROP_FRAME_HEIGHT))
length = int(video_in.get(cv2.CAP_PROP_FRAME_COUNT))
print("Total frames : {}".format(length))

to_tensor = transforms.ToTensor()  # built once instead of once per frame
video_outs = []
checks = 0  # consecutive failed reads
fc = 0      # frames processed so far
try:
    for modelName in modelNames:
        video_outs.append(cv2.VideoWriter("predictions/testVideo-{}.mp4".format(modelName), cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height)))

    with torch.no_grad():
        while video_in.isOpened():
            status, frame_T = video_in.read()
            if not status:
                # Tolerate a few failed reads; stop after 10 in a row.
                checks += 1
                if checks == 10:
                    break
                continue
            checks = 0

            # OpenCV delivers BGR; convert to an RGB PIL image for the model.
            frame_T = pil.fromarray(cv2.cvtColor(frame_T, cv2.COLOR_BGR2RGB))
            frame_T = frame_T.resize(MODEL_INPUT_SIZE, pil.LANCZOS)
            frame_T = to_tensor(frame_T).unsqueeze(0).to(device)

            for enc, depthDecoder, video_out in zip(encoder_models, decoder_models, video_outs):
                main_features = enc(frame_T)
                outputs = depthDecoder(main_features)
                disp = outputs[("disp", 0)]
                # Upsample the prediction back to the source video resolution.
                disp_resized_np = torch.nn.functional.interpolate(disp, (height, width), mode="bilinear", align_corners=True).squeeze().cpu().numpy()
                # Clip the colour scale at the 95th percentile of the frame.
                normalizer = mpl.colors.Normalize(vmin=disp_resized_np.min(), vmax=np.percentile(disp_resized_np, 95))
                mapper = cm.ScalarMappable(norm=normalizer, cmap='magma')
                im = (mapper.to_rgba(disp_resized_np)[:, :, :3] * 255).astype(np.uint8)
                # The colour map yields RGB; VideoWriter expects BGR.
                video_out.write(cv2.cvtColor(im, cv2.COLOR_RGB2BGR))

            fc += 1
            if fc % 50 == 0:
                print("Frames completed: {}".format(fc))
finally:
    # Always release the capture and the writers, even if inference fails
    # part-way, so the output files are finalised and handles are freed.
    video_in.release()
    for video_out in video_outs:
        video_out.release()

Total frames : 601
Frames completed: 50
Frames completed: 100
Frames completed: 150
Frames completed: 200
Frames completed: 250
Frames completed: 300
Frames completed: 350
Frames completed: 400
Frames completed: 450
Frames completed: 500
Frames completed: 550
