# Reconstructor

Note: this model is fairly large and requires about 12 GB of GPU memory.

In [None]:
# Load IPython's autoreload extension (used later via %autoreload to pick up
# edited modules) and print GPU status; the model needs roughly 12 GB of GPU memory.
%load_ext autoreload
! nvidia-smi

# prepare environment and reconstruction model

In [None]:
import time, os, random,sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

sys.path.append("../")
sys.path.append("../datareader")
from tianmoucv.alg import cal_optical_flow,backWarp,flow_to_image,white_balance
from tianmoucv.basic import vizDiff
from tianmoucv.nn import warp_fast,interpolate_preprocess
from tianmoucv.isp import lyncam_raw_comp,demosaicing_npy

import numpy as np
from PIL import Image
import cv2
import matplotlib.pyplot as plt


##################### Step1. Env Preparation #####################
# Everything below runs on GPU 0 (cuda:0).
local_rank = 0
device = torch.device('cuda:'+str(0))
torch.cuda.set_device(0)
# NOTE(review): local_rank, writer and master are not read in the cells shown
# here; they look like leftovers from a (distributed) training script.
writer = None 
master = False 

###################### Step2. model and data Preparation #############

from  model.reconstructor_new import TianmoucRecon

# Only CHECKPOINT_PATH_MODEL is used below; CHECKPOINT_DIR is not read here.
CHECKPOINT_DIR =  '../data/ckpts/'
CHECKPOINT_PATH_MODEL = '../data/ckpts/unet_reconstruction.ckpt'

VALIDATION_BATCH_SIZE = 1
# NOTE(review): TRAINING_CONTINUE is not read in the cells shown here.
TRAINING_CONTINUE = True
# Tile size used throughout the notebook: h rows by w columns.
# Val_size is the (w, h) pair handed to the model and to the video writer.
h = 320
w = 640 
Val_size   = (w,h)
ReconModel = TianmoucRecon(Val_size)
ReconModel.load_model(ckpt=CHECKPOINT_PATH_MODEL)
ReconModel.to(device)
# NOTE(review): if TianmoucRecon is an nn.Module, confirm load_model() switches
# it to eval mode; the inference loop only wraps calls in torch.no_grad().
# NOTE(review): `start` is not read in the cells shown here.
start = time.time()
# Visualisation tiles appended by the reconstruction loop and later written to video.
imlist = []

In [None]:
from tianmoucData import TianmoucDataReader

# Dataset root(s) handed to TianmoucDataReader.
dataset_top1 = "../data/recon_data"
datasetList = [dataset_top1]

# Value passed to TianmoucDataReader as `matchkey` (presumably selects the
# recording). Another recording available here: 'tunnelbyq9_1332ae'.
key = 'openroad_sync1yh'

# Window of sample indices the reconstruction loop processes.
startID = 50
endID = startID + 20

dataset = TianmoucDataReader(datasetList, matchkey=key)

# Default (unshuffled) order, so the loop's enumerate index matches startID/endID.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=VALIDATION_BATCH_SIZE,
    num_workers=4,
    pin_memory=False,
    drop_last=False,
)

# run reconstruction

## Mode 1: ifsingleDirection = True (uses only the first COP frame, cop0)
cop0 + (aop0 + ... + aopn) -> reconstructed_RGB

## Mode 2: ifsingleDirection = False (uses both COP frames, cop0 and cop1)
cop0 + (aop0 + ... + aopn + ... + aop_N-1) + cop1 -> reconstructed_RGB

In [None]:
%autoreload

# Reconstruct samples startID..endID and append one 2x2 visualisation tile per
# time step to the global `imlist`. Tile layout (C, H, W):
#   [ F0 (COP)   | SD ]
#   [ TD         | F1t (reconstruction) ]
# The tile for time step t == 12 of each sample is also shown inline,
# white-balanced and labelled.
ifsingleDirection = True
with torch.no_grad():
    validationIndex = 0
    for index, sampleRaw in enumerate(dataloader, 0):
        # NOTE(review): the DataLoader still loads every sample before startID;
        # they are only skipped here. endID itself is processed (break on >).
        if index < startID:
            continue
        if index > endID:
            break
        startTime = time.time()
        sample = {}
        F0 = sampleRaw['F0']
        F1 = sampleRaw['F1']
        tsdiff = sampleRaw['tsdiff']
        # Move the last axis to position 1 (presumably NHWC -> NCHW) and upload.
        sample['F0'] = F0.permute(0, 3, 1, 2).to(device)
        sample['F1'] = F1.permute(0, 3, 1, 2).to(device)
        sample['tsdiff'] = tsdiff.to(device)

        middleTime = time.time()
        F1t, F0, tsdiff = warp_fast(sample, ReconModel, None, h, w, device,
                                    ifsingleDirection=ifsingleDirection)
        # NOTE(review): CUDA work is asynchronous, so endTime may be taken before
        # the GPU has finished; also the 'run:' value printed below is
        # middleTime - startTime (input upload), not the warp_fast call.
        endTime = time.time()

        # Copy results to host memory once per sample rather than once per time step.
        tsdiff = tsdiff.cpu()
        F0_cpu = F0.cpu()
        F1t_cpu = F1t.cpu()
        # NOTE(review): assumes warp_fast yields at least 25 time steps — TODO
        # confirm or derive the count from the returned tensors.
        for t in range(25):
            retImg1 = F0_cpu[t, :, :, :]
            retImg2 = F1t_cpu[t, :, :, :]

            # tsdiff channel 1 is visualised as SD and channel 0 as TD.
            sd = tsdiff[0, 1, t, ...] * 255
            rgb_sd = vizDiff(sd, thresh=12)
            td = tsdiff[0, 0, t, ...] * 255
            rgb_td = vizDiff(td, thresh=12)

            # Stack vertically (dim 1) into two columns, then join them side by side (dim 2).
            img_col1 = torch.cat([retImg1, rgb_td], dim=1)
            img_col2 = torch.cat([rgb_sd, retImg2], dim=1)
            img = torch.cat([img_col1, img_col2], dim=2)
            imlist.append(img)

            if t == 12:
                plt.figure(figsize=(16, 8))
                # Clip so out-of-range values cannot wrap in the uint8 cast;
                # .copy() yields a C-contiguous array, which cv2.putText needs.
                canvas = np.clip(img.permute(1, 2, 0).numpy() * 255, 0, 255).astype(np.uint8).copy()
                # White-balance the two image tiles (top-left F0, bottom-right F1t).
                canvas[0:h, 0:w, ...] = white_balance(canvas[0:h, 0:w, ...])
                canvas[h:2 * h, w:2 * w, ...] = white_balance(canvas[h:2 * h, w:2 * w, ...])

                label_font = cv2.FONT_HERSHEY_SIMPLEX
                label_color = (255, 255, 255)
                cv2.putText(canvas, "CONE", (10 + 0 * w, 20 + 0 * h), label_font, 0.75, label_color, 2)
                cv2.putText(canvas, "TD", (10 + 0 * w, 20 + 1 * h), label_font, 0.75, label_color, 2)
                cv2.putText(canvas, "SD", (10 + 1 * w, 20 + 0 * h), label_font, 0.75, label_color, 2)
                cv2.putText(canvas, "reconstructed:cone+12aop(~16ms)", (10 + 1 * w, 20 + 1 * h), label_font, 0.75, label_color, 2)
                plt.imshow(canvas)
                plt.axis('off')
                plt.show()

        print(validationIndex,'/',endID-startID, ' cost:',endTime-startTime,'s',' run:',middleTime-startTime,'s')
        validationIndex += 1

# Dump the results to video

In [None]:
def images_to_video(frame_list, name, Val_size=(512, 256), Flip=False, fps=15):
    """Write 2x2 visualisation tiles to an MP4 (mp4v) video file.

    Each frame is expected to be a tensor of shape (3, 2*h, 2*w), where
    (w, h) = Val_size is the size of a single tile. Values are assumed to be
    in [0, 1]. The top-left and bottom-right tiles are white-balanced before
    writing.

    Args:
        frame_list: iterable of frame tensors as described above.
        name: output file path; its directory must already exist.
        Val_size: (w, h) of one tile; the video is (2*w, 2*h).
        Flip: if True, flip every tile upside down.
        fps: frame rate of the output video (default 15).

    Raises:
        OSError: if OpenCV cannot open `name` for writing.
    """
    w, h = Val_size[0], Val_size[1]
    # Size of the whole 2x2 canvas, given as (width, height) for OpenCV.
    size = (w * 2, h * 2)
    out = cv2.VideoWriter(name, cv2.VideoWriter_fourcc(*'mp4v'), fps, size)
    if not out.isOpened():
        # VideoWriter fails silently (missing directory, unavailable codec);
        # without this check nothing would be written and no error reported.
        raise OSError(f"could not open video writer for {name!r}")
    try:
        for frame in frame_list:
            # Reverse channel order (presumably RGB -> BGR for OpenCV), scale
            # by 255 and convert (C, H, W) -> (H, W, C).
            frame = (frame[[2, 1, 0], :, ] * 255).cpu().permute(1, 2, 0).numpy()
            frame[0:h, 0:w, ...] = white_balance(frame[0:h, 0:w, ...])
            frame[h:2 * h, w:2 * w, ...] = white_balance(frame[h:2 * h, w:2 * w, ...])
            if Flip:
                # Flip each row of tiles vertically: rows h-1..0 and 2h-1..h.
                frame[0:h] = frame[0:h][::-1].copy()
                frame[h:2 * h] = frame[h:2 * h][::-1].copy()
            # Clip before casting so out-of-range values cannot wrap in uint8.
            frame = np.clip(frame, 0, 255).astype(np.uint8)
            out.write(frame)
    finally:
        out.release()
    
# Write two videos to ../../results/: one with every tile, and a "fast" one that
# keeps only every 12th tile (an incomplete trailing group of 12 is dropped).
results_dir = '../../results/'
# cv2.VideoWriter does not create missing directories and fails silently, so
# ensure the output directory exists first.
os.makedirs(results_dir, exist_ok=True)
images_to_video(imlist, results_dir + key + '_' + str(validationIndex) + '.mp4', Val_size=(w, h), Flip=False)
imlist_fast = [imlist[i * 12] for i in range(len(imlist) // 12)]
images_to_video(imlist_fast, results_dir + key + '_' + str(validationIndex) + '_fast.mp4', Val_size=(w, h), Flip=False)