# 轻量化视频重建

## 这个示例展示如何使用一个端到端网络融合两个数据通路重建原始场景

调用接口：
- from tianmoucv.proc.reconstruct import TianmoucRecon_tiny


In [None]:
%load_ext autoreload

## 引入必要的库

In [None]:
%autoreload
import sys,os, math,time
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from tianmoucv.data import TianmoucDataReader
import torch.nn.functional as F
import cv2

## 准备数据

In [None]:
# Collect the sample-set directories of the train and test splits, print what
# each one contains, and choose the clip key used by the rest of the notebook.
train = '/data/lyh/tianmoucData/tianmoucReconDataset/train/'
dirlist = os.listdir(train)
traindata = [train + entry for entry in dirlist]

val = '/data/lyh/tianmoucData/tianmoucReconDataset/test/'
vallist = os.listdir(val)
valdata = [val + entry for entry in vallist]

key_list = []
# Same report for both splits: a separator line, then for every sample set its
# path, its sample count, and the sample names on a single line.
for split in (traindata, valdata):
    print('---------------------------------------------------')
    for sampleset in split:
        sample_names = os.listdir(sampleset)
        print('---->',sampleset,'有：',len(sample_names),'个样本')
        for sample_name in sample_names:
            print(sample_name,end=" ")
        key_list.extend(sample_names)

all_data = valdata + traindata
# Only this clip is selected; it replaces the names gathered above.
key_list = ['underbridge_hdr_4']

## TinyUNet重建网络调用示例

In [None]:
%autoreload
from tianmoucv.proc.reconstruct import TianmoucRecon_tiny

# Inference runs on the first CUDA device.
device = torch.device('cuda:0')
# Build the reconstructor first, then move it to the device.
# Original note: some Python / PyTorch versions cannot use _optim.
reconstructor = TianmoucRecon_tiny(ckpt_path=None, _optim=False)
reconstructor = reconstructor.to(device)

## 视频输出

In [None]:
def images_to_video(frame_list,name,size=(640,320),Flip=True):
    """Write a sequence of frames to a video file at 30 fps.

    Every frame is normalised with the fixed range [0, 1], clamped to that
    range, scaled by 255 and converted to ``uint8`` before being written.

    Args:
        frame_list: Iterable of numpy arrays, one per frame, with values
            expected in [0, 1] (values outside are clamped).
            NOTE(review): presumably H x W x 3 in the channel order the
            writer expects, with (W, H) equal to ``size`` -- confirm.
        name: Output path handed to ``cv2.VideoWriter``.
        size: Frame size handed to ``cv2.VideoWriter``.
        Flip: Unused; kept so existing callers keep working.
    """
    fps = 30
    ftmax = 1
    ftmin = 0
    fourcc_mp4v = 0x7634706d  # the ASCII bytes 'm','p','4','v', low byte first
    out = cv2.VideoWriter(name, fourcc_mp4v, fps, size)
    try:
        for frame in frame_list:
            normalised = np.clip((frame - ftmin) / (ftmax - ftmin), 0, 1)
            out.write((normalised * 255).astype(np.uint8))
    finally:
        # Release the writer even when a frame fails to convert, so the
        # output file handle is not left open.
        out.release()

## 融合图像

In [None]:
%autoreload
from IPython.display import clear_output
from tianmoucv.isp import vizDiff

# Region to reconstruct, grown outward from the image centre (must not exceed
# the size of F0).
w = 640
h = 320

key_list = ['test_driving_night_light1']
for key in key_list:
    dataset = TianmoucDataReader(all_data,MAXLEN=500,matchkey=key)
    dataLoader = torch.utils.data.DataLoader(dataset, batch_size=1,
                                             num_workers=4, pin_memory=False, drop_last=False)
    img_list = []
    for index,sample in enumerate(dataLoader,0):
        # Only samples with index 0..10 (eleven in total) are reconstructed.
        if index > 10:
            break

        # Crop the visualisation inputs up front. F0 is indexed as
        # (row, col, channel) here.
        F0 = sample['F0'][0,...].clone()
        biasw = (F0.shape[1]-w)//2
        biash = (F0.shape[0]-h)//2
        # tsdiff is indexed below as [channel, t, row, col], so the spatial
        # crop has to be applied to the LAST two axes, not the first two.
        # Assumes tsdiff has the same spatial size as F0 -- TODO confirm.
        tsdiff = sample['tsdiff'][0,...][...,biash:h+biash,biasw:w+biasw]
        F0 = F0[biash:h+biash,biasw:w+biasw,:]

        # Move the channel axis to dim 1 for inference.
        sample['F0'] = sample['F0'].permute(0,3,1,2)
        sample['F1'] = sample['F1'].permute(0,3,1,2)

        # The reconstructor receives the lightly pre-processed sample and
        # returns every frame reconstructed from it. Notes carried over from
        # the original cell:
        #   F0, F1: 0~1
        #   tsdiff: -1~1
        #   ifSingleDirection: described as "whether to reconstruct in both
        #       directions and average" -- exact semantics not visible here
        #   w, h: region of interest; use F0's size for a full-frame result
        reconstructed_b = reconstructor(sample,
                                        w=w,
                                        h=h,
                                        bs=26,
                                        ifSingleDirection=False).float()

        timelen = tsdiff.shape[1]

        # The last frame can be dropped, or averaged with frame 0 of the next
        # reconstruction to reduce flicker.
        for t in range(timelen-1):
            tsd_rgb = tsdiff[:,t,...].cpu().permute(1,2,0)*255
            td = tsd_rgb[:,:,0]
            sd = tsd_rgb[:,:,1]
            rgb_sd = vizDiff(sd,thresh=3).permute(2,0,1)
            rgb_td = vizDiff(td,thresh=3).permute(2,0,1)
            # Stack SD above TD, then halve both dimensions so the stacked
            # panel is as tall as a single frame and half as wide.
            rgb_cat = torch.cat([rgb_sd,rgb_td],dim=1)
            rgb_tsd = F.interpolate(rgb_cat.unsqueeze(0), scale_factor=0.5, mode='bilinear', align_corners=True).squeeze(0).permute(1,2,0)
            reconstructed = reconstructed_b[t,...].cpu()
            # Side by side: F0 | SD over TD | reconstruction.
            showim = torch.cat([F0,rgb_tsd,reconstructed.permute(1,2,0)],dim=1).numpy()
            # Text labels for each panel.
            cv2.putText(showim,"e-GT:"+str(t),(int(w*1.5)+12,36),cv2.FONT_HERSHEY_SIMPLEX,0.75,(0,0,0),2)
            cv2.putText(showim,"SD:"+str(t),(int(w)+12,24),cv2.FONT_HERSHEY_SIMPLEX,0.6,(0,0,0),2)
            # The TD panel starts halfway down the frame.
            cv2.putText(showim,"TD:"+str(t),(int(w)+12,h//2+24),cv2.FONT_HERSHEY_SIMPLEX,0.6,(0,0,0),2)
            cv2.putText(showim,"COP:0",(12,36),cv2.FONT_HERSHEY_SIMPLEX,0.75,(0,0,0),2)

            # Preview one frame per sample inline.
            if t==12:
                clear_output(wait=True)
                plt.figure(figsize=(8,3))
                plt.subplot(1,1,1)
                plt.imshow(showim)
                plt.show()
            # Reverse the channel order before handing the frame to OpenCV.
            img_list.append(showim[...,[2,1,0]])
    # Canvas width is F0 (w) + SD/TD panel (w//2) + reconstruction (w).
    images_to_video(img_list,'./viz_'+key+'.mp4',size=(2*w+w//2,h),Flip=True)