# 轻量化视频重建(TD)

## 这个示例展示使用灰度图累加TD，轻量级的unet做时间积分

调用接口：
- from tianmoucv.proc.reconstruct import TianmoucRecon_recurrent


In [None]:
%load_ext autoreload

## 引入必要的库

In [None]:
%autoreload
import sys,os, math,time
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from tianmoucv.data import TianmoucDataReader
import torch.nn.functional as F
import cv2

## 准备数据

In [None]:
# Root folders of the training and test splits (trailing slash required,
# because sample-set paths are built by plain string concatenation).
train = '/data/lyh/tianmoucData/tianmoucReconDataset/train/'
val = '/data/lyh/tianmoucData/tianmoucReconDataset/test/'

dirlist = os.listdir(train)
traindata = [train + name for name in dirlist]

vallist = os.listdir(val)
valdata = [val + name for name in vallist]

# Walk both splits, report how many samples each sample set holds and
# collect every sample name that was found.
key_list = []
separator = '---------------------------------------------------'
for split in (traindata, valdata):
    print(separator)
    for sample_dir in split:
        entries = os.listdir(sample_dir)
        print('---->',sample_dir,'有：',len(entries),'个样本')
        for entry in entries:
            print(entry,end=" ")
            key_list.append(entry)

all_data = valdata + traindata
# The collected names are only informative: the list is replaced here by
# a single hand-picked sample key.
key_list = ['train_exam_fan5']

## TD2VID重建网络调用示例

In [None]:
%autoreload
from tianmoucv.proc.reconstruct import TianmoucRecon_recurrent

# Inference runs on the first CUDA device.
device = torch.device('cuda:0')

# ckpt_path=None makes the reconstructor download its weights automatically.
# _optim stays off because some Python / PyTorch versions cannot use it.
reconstructor = TianmoucRecon_recurrent(ckpt_path=None, _optim=False)
reconstructor = reconstructor.to(device)

# Print the structure of the wrapped reconstruction network.
print(reconstructor.reconNet)

# 视频输出

In [None]:
def images_to_video(frame_list,name,size=(640,320),Flip=True):
    """Write a sequence of frames to a video file at 30 fps.

    Args:
        frame_list: iterable of numpy arrays. Values are expected to lie in
            [0, 1]; anything outside that range is clipped before the frame
            is scaled to 0..255 and cast to uint8.
        name: output path handed to ``cv2.VideoWriter``.
        size: frame size handed to ``cv2.VideoWriter``.
        Flip: currently unused; kept so existing callers keep working.
    """
    fps = 30
    # Value range the incoming frames are normalised from.
    ftmin, ftmax = 0, 1
    # FourCC code: the bytes 'm', 'p', '4', 'v' packed little-endian.
    fourcc = 0x7634706d
    out = cv2.VideoWriter(name, fourcc, fps, size)
    try:
        for ft in frame_list:
            ft = np.clip((ft-ftmin)/(ftmax-ftmin), 0, 1)
            out.write((ft*255).astype(np.uint8))
    finally:
        # Always release the writer so the file is finalised and the handle
        # is freed even if a frame fails to convert or write.
        out.release()

# 融合图像

In [None]:
%autoreload
%matplotlib inline
from IPython.display import clear_output
from tianmoucv.isp import vizDiff

# Region to reconstruct, expanding outwards from the image centre
# (must not exceed the size of F0).
w = 640
h = 320

# Number of leading samples of each sequence that are reconstructed.
MAX_SAMPLES = 61
# Index (within one sample) of the reconstructed frame shown inline.
PREVIEW_T = 12

key_list = ['test_man_play_ball3']
for key in key_list:
    dataset = TianmoucDataReader(all_data,MAXLEN=500,matchkey=key)

    frame_list = []
    # State returned by the reconstructor and passed back in on the next
    # call; None for the first sample.
    states = None
    for index in range(min(len(dataset)-1, MAX_SAMPLES)):
        sample0 = dataset[index]

        # Centre-crop F0 up front; it is only used for visualisation.
        F0 = sample0['F0'][...].clone()
        biasw = (F0.shape[1]-w)//2
        biash = (F0.shape[0]-h)//2
        F0 = F0[biash:h+biash,biasw:w+biasw,:]
        # Same for every frame of this sample, so convert it only once.
        f0_rgb = F0.cpu().numpy()*255

        # tsdiff is indexed below as [channel, time, row, col]
        # (see tsdiff[:,t,...].permute(1,2,0)), so the spatial crop has to
        # be applied to the last two axes.
        tsdiff = sample0['tsdiff'][...,biash:h+biash,biasw:w+biasw]

        # Channel 0 only, skipping the first time step, plus a batch axis.
        td = tsdiff[0:1,1:,...].to(device).unsqueeze(0)

        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            reconstructed_b,states = reconstructor(td, states)

        print(reconstructed_b.shape)

        for t in range(reconstructed_b.size(0)):
            # Channel-last layout, channels repeated 3x for display, 0..255.
            reconstructed = reconstructed_b[t,...].cpu().permute(1,2,0)
            reconstructed = torch.cat([reconstructed]*3,dim=2)
            reconstructed *= 255

            tsd_rgb = tsdiff[:,t,...].cpu().permute(1,2,0)*255
            td_ = tsd_rgb[:,:,0]
            sd_ = tsd_rgb[:,:,1]
            rgb_sd = vizDiff(sd_,thresh=0,bg_color='black',gain=16)
            rgb_td = vizDiff(td_,thresh=0,bg_color='black',gain=16)

            # 2x2 mosaic: F0 | TD on top, SD | reconstruction below.
            canvas = np.zeros([h*2,w*2,3])
            canvas[:h,:w,:] = f0_rgb
            canvas[:h,w:,:] = rgb_td
            canvas[h:2*h,:w,:] = rgb_sd
            canvas[h:2*h,w:,:] = reconstructed

            # (text, tile column, tile row) for the label of each tile.
            labels = (
                ("CONE", 0, 0),
                ("TD", 1, 0),
                ("SD:"+str(t), 0, 1),
                ("td-gray+"+str(t), 1, 1),
            )
            for text, col, row in labels:
                cv2.putText(canvas,text,(12+col*w,24+row*h),cv2.FONT_HERSHEY_SIMPLEX,0.75,(255,0,0),2)

            frame = canvas/255
            if t == PREVIEW_T:
                plt.imshow(frame)
                plt.show()
            frame_list.append(frame)

    images_to_video(frame_list,'./viz_'+key+'.mp4',size=(w*2,h*2),Flip=True)