In [None]:
import argparse
from utils import add_logging_arguments
import logging
import sys
import torch
import torchvision
import torch.nn.functional as F
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tensorboardX import SummaryWriter
import imageio
import shutil
import cv2
import time

import data
import models
import utils
import os
import warnings
warnings.filterwarnings("ignore")
    

In [None]:
args = lambda: None
args.gpu = 0 # gpu_id
args.angular = 7 # number of angular views to generate; changing this will affect quantitative evaluation
args.model = 'unet_icip'
args.disp_model = 'dispnetC'
args.flow_model = 'flownetC'
args.display = 'multilayer'

args.layers = 3
args.rank = 12
args.restore_file = 'weights/checkpoint.pt'


args.dataset = 'test_st' # dataloader specifically for case without GT LF images
args.h5_file = 'stereo_data.h5'
args.data_path = 'data'
args.batch_size = 1
args.save_dir = 'results'

args.inph = 270 # please specify the input height and width values of your data
args.inpw = 470
args.seq_len = 4

args.seed = 42
args.dry_run = True

In [None]:
def run_batch(inputs, outputs):
    """Run the per-frame prediction over one video batch.

    Moves ``inputs['video']`` to the module-level ``device``, resets the
    bookkeeping keys in ``inputs`` and then calls ``run_instance`` once for
    every time index from 1 up to (but excluding) ``inputs['video'].size(1)``.
    Index 0 is never passed as ``curr_step``; it only appears as ``prev_idx``
    of the first call.

    Args:
        inputs: mutable mapping holding at least ``'video'``; it is updated in
            place with ``'lf_states'``, ``'prev_idx'``, ``'curr_step'`` and,
            once the loop runs at least once, ``'flow_loss'``.
        outputs: mutable mapping; ``outputs['pred_lf']`` is reset to an empty
            list before the loop (``run_instance`` appends to it).

    Returns:
        None. Results are communicated through ``inputs`` / ``outputs``.
    """
    video = inputs['video'].to(device)
    inputs.update({
        'video': video,
        'lf_states': None,
        'prev_idx': 0,
        'curr_step': 0,
    })
    outputs['pred_lf'] = []

    # `previous` is always `step - 1`: enumerate() counts from 0 while the
    # range itself starts at 1.
    for previous, step in enumerate(range(1, video.size(1))):
        inputs['flow_loss'] = True
        inputs['prev_idx'] = previous
        inputs['curr_step'] = step
        run_instance(inputs, outputs)

def run_instance(inputs, outputs):
    """Predict one light-field frame for the current time step.

    Slices ``inputs['video']`` along axis 1 from ``inputs['prev_idx']`` to
    ``inputs['curr_step']`` (inclusive), takes the last frame of that window,
    feeds it to the module-level ``lf_model`` and passes the result through the
    module-level ``tensor_display``. The output is appended to
    ``outputs['pred_lf']``.

    Args:
        inputs: mapping with ``'video'``, ``'prev_idx'`` and ``'curr_step'``;
            it is also handed to ``lf_model`` unchanged.
        outputs: mapping whose ``'pred_lf'`` list receives the prediction.

    Returns:
        None.
    """
    # Original author's note: video is [batch, time, view, rgb, height, width]
    # and the window is [N,2,2,3,H,W] -- not verifiable from this function.
    first, last = inputs['prev_idx'], inputs['curr_step']
    frame_window = inputs['video'][:, first:last + 1, ...]
    stereo_frame = frame_window[:, -1, ...]

    # Original author's note: the model output is [N,layers,rank,3,h,w]
    # (TODO confirm against the model definition).
    layer_decomposition = lf_model(stereo_frame, inputs)
    outputs['pred_lf'].append(tensor_display(layer_decomposition))


    

In [None]:
# Use the CUDA device selected by `args.gpu` when CUDA is available,
# otherwise fall back to the CPU.
device = torch.device(
    f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')

# Project helpers: experiment bookkeeping first, then logging configuration.
utils.setup_experiment(args)
utils.init_logging(args)

# Positions of the two input stereo views inside the output light field.
left_view_idx, right_view_idx = 0, 1

# View indices the network would use to predict the LF (currently disabled):
# lf_view_idx = [left_view_idx, right_view_idx]
# print(f'using stereo view indices as {lf_view_idx}')

In [None]:
models_list = []
trainable_params = []


def _build_and_register(model_name, n_channels):
    """Build one network, move it to `device` and register it.

    The network is appended to `models_list` and its parameters are added to
    `trainable_params`, in that order of construction.

    Args:
        model_name: name understood by `models.build_model`.
        n_channels: value forwarded as `n_channels` to `models.build_model`.

    Returns:
        The constructed network, already moved to `device`.
    """
    network = models.build_model(
        model_name,
        n_channels=n_channels,
        args=args).to(device)
    trainable_params.extend(list(network.parameters()))
    models_list.append(network)
    return network


# ============== Initialize all network models ===================
# The construction order (V, O, D) fixes the order of `models_list`, which is
# later handed to the checkpoint loader, so it must not be changed.
# lf prediction network V; 2 * 3 input channels -- presumably the two RGB
# stereo views stacked along the channel axis (TODO confirm).
lf_model = _build_and_register(args.model, 2 * 3)
# optical flow prediction network O
flow_model = _build_and_register(args.flow_model, 3)
# disparity map prediction network D
disp_model = _build_and_register(args.disp_model, 3)

# Inference only: no optimizer / scheduler is created.
optimizer = None
scheduler = None
logging.info(
    f"Built {len(models_list)} models consisting of {sum(p.numel() for p in trainable_params):,} parameters")

# ========== Initialize the low-rank display model ==============
# Fail loudly on an unsupported display type. The previous code printed a
# message and called exit(0), i.e. it reported success on a configuration
# error and relied on the interactive-only `exit` helper.
if args.display != 'multilayer':
    raise ValueError(
        f"No valid display type chosen: {args.display!r} "
        "(only 'multilayer' is supported)")
tensor_display = models.multilayer(
    args.angular,
    args.layers,
    args.inph,
    args.inpw,
    args=args).to(device)
logging.info(
    f"Using the {args.display} display with {args.layers} layers and {args.rank} rank")

In [None]:
# Restore the networks in `models_list` through the project's checkpoint
# helper. `optimizer` and `scheduler` are both None in this notebook, so
# presumably only model weights are restored (confirm in utils.load_checkpoint).
state_dict = utils.load_checkpoint(args, models_list, optimizer, scheduler)

# Training progress recorded in the checkpoint.
global_step, start_epoch = state_dict['last_step'], state_dict['epoch']

In [None]:
# Build the evaluation loader through the project's dataset factory.
# The worker count is capped at the number of CPUs reported by the OS so that
# small machines are not asked to spawn 16 loader processes; it never exceeds
# the previous fixed value of 16. `os.cpu_count()` may return None, hence the
# `or 1` fallback.
test_loader = data.build_dataset(
    args.dataset,
    args.data_path,
    args,
    batch_size=args.batch_size,
    num_workers=min(16, os.cpu_count() or 1))

In [None]:
# Start from an empty output directory: whatever already exists at
# `args.save_dir` is deleted (irreversibly) before the directory is recreated.
save_path = args.save_dir
had_previous_results = os.path.exists(save_path)
if had_previous_results:
    shutil.rmtree(save_path)
    print(f'removing the directory tree {save_path}')
os.makedirs(save_path, exist_ok=True)
print(save_path)

# Put every network, and the display model, into evaluation mode.
for model in models_list:
    model.eval()
tensor_display.eval()

# Progress-reporting wrapper around the loader (project helper).
test_bar = utils.ProgressBar(test_loader)
# A sequence is written to disk when `sample_id % save_every == 0`;
# with 1 that is every sequence.
save_every = 1
# Run inference on every item yielded by `test_bar` and write each predicted
# light-field frame to its own video file under `save_path`.
for sample_id, inputs in enumerate(test_bar):
    # NOTE(review): ssim_vid / lpips_vid are created but never filled in the
    # lines visible here -- confirm whether metric computation is still planned.
    ssim_vid = []
    lpips_vid = []
    with torch.no_grad():
        outputs = {}
        run_batch(inputs, outputs)
        # outputs["pred_lf"] now holds one predicted LF frame per time step
        # processed by run_batch; the code below only saves them.
        # Original author's note: PSNR / SSIM / LPIPS could be computed
        # inside run_batch itself.
        # Only a batch size of 1 is supported.
        # NOTE(review): this check runs after the forward pass has already been
        # done, so a wrong batch size is only detected late.
        assert inputs['video'].size(0) == 1
        if sample_id % save_every == 0:
            # each video sequence will be saved in a separate directory
            # (named seq_<sample_id>); all per-frame files of that sequence
            # are written into it
            seq_save_path = f'{save_path}/seq_{sample_id:03d}'
            os.makedirs(f'{save_path}/seq_{sample_id:03d}', exist_ok=True)

            # then save the predicted light field
            for t in range(len(outputs['pred_lf'])):
                # Move the prediction to host memory and drop all size-1 axes.
                pred_lf_np = outputs['pred_lf'][t].data.cpu(
                ).numpy().squeeze()
                # Move axis 1 to the end -- presumably (views, C, H, W) ->
                # (views, H, W, C); requires the squeezed array to be 4-D
                # (TODO confirm against the display model's output shape).
                pred_lf_np = np.transpose(pred_lf_np, [0, 2, 3, 1])
                # Min-max normalise the whole frame stack to [0, 1].
                # NOTE(review): a constant prediction makes the denominator 0.
                pred_lf_np = (pred_lf_np - np.amin(pred_lf_np))/(np.amax(pred_lf_np) - np.amin(pred_lf_np))
                save_lf_path = os.path.join(
                    seq_save_path, f'pred_lf_{sample_id:02d}_{t:02d}.avi')
                # NOTE(review): the 'mp4v' codec is paired with an .avi file
                # name, and the frame rate is the hard-coded literal 7.
                fourcc = cv2.VideoWriter_fourcc(*'mp4v')
                # Frame size is passed as (width, height).
                out = cv2.VideoWriter(
                    save_lf_path, fourcc, 7, (args.inpw, args.inph))
                # One video frame per entry along the first axis; [..., ::-1]
                # reverses the last axis -- presumably RGB -> BGR for OpenCV.
                # NOTE(review): no out.release() is visible in this excerpt;
                # confirm the writer is released after this loop so the file
                # is finalised.
                for k in range(len(pred_lf_np)):
                    out.write(np.uint8(pred_lf_np[k, ..., ::-1] * 255))