# Reconstructor UNet in 《A vision chip with complementary pathways for open-world sensing》

## 展示该论文所用的初版重建算法，以及SD增强HDR

调用接口：
- tianmoucv.proc.reconstruct.TianmoucRecon_Original


In [None]:
%load_ext autoreload

## 引入必要的库

In [None]:
%autoreload
import sys,os, math,time
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from tianmoucv.data import TianmoucDataReader
import torch.nn.functional as F
import cv2

## 准备数据

In [None]:
# Path handed to TianmoucDataReader as its first argument; this is a
# machine-specific absolute path and must be changed to the local dataset.
all_data = '/home/lyh/tunnel9_hdr'
# Each entry is passed, one at a time, as `matchkey` to TianmoucDataReader
# and is also used in the name of the exported video.
key_list = ['tunnel9_hdr']

## 引入网络

In [None]:
%autoreload
from tianmoucv.proc.reconstruct import TianmoucRecon_Original
# Keep using the second GPU when the machine has one, but fall back to the
# first GPU or the CPU instead of failing with an invalid device ordinal.
# NOTE(review): CPU execution of the reconstructor itself is untested here.
if torch.cuda.is_available():
    device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')
else:
    device = torch.device('cpu')
# _optim=False: per the original note, the PyTorch optimisation can be
# switched on when the environment supports it.
reconstructor = TianmoucRecon_Original(ckpt_path=None,_optim=False).to(device)

# 融合图像

In [None]:
%autoreload
import torch.nn as nn
import math,time
from tianmoucv.isp import vizDiff
import torch.nn.functional as F
from IPython.display import clear_output
from tianmoucv.isp import SD2XY
from tianmoucv.proc.reconstruct import poisson_blending

def images_to_video(frame_list,name,size=(640,320),Flip=True):
    """Write a sequence of frames to a video file and dump each frame as a PNG.

    Every frame is clipped to [0, 1], scaled to 8-bit and written twice: once
    into the video ``name`` and once as ``<index:06d>.png`` (1-based) inside a
    folder named after ``name`` with its extension removed.

    Args:
        frame_list: iterable of NumPy arrays with values nominally in [0, 1];
            values outside that range are clipped. Frames are handed to
            OpenCV unchanged, so they presumably need to match ``size`` and
            be in BGR order -- TODO confirm against callers.
        name: output video path, e.g. ``'out.mp4'``.
        size: ``(width, height)`` passed to ``cv2.VideoWriter``.
        Flip: unused; kept so existing callers keep working.
    """
    fps = 30
    ftmax = 1
    ftmin = 0
    fourcc = 0x7634706d  # the bytes 'm','p','4','v' packed little-endian

    # splitext strips only the final extension, so './out/clip.mp4' maps to
    # './out/clip' (name.split('.')[0] would give '' and crash in mkdir) and
    # 'run.v2.mp4' maps to 'run.v2' rather than 'run'.
    output_folder = os.path.splitext(name)[0]
    # makedirs also creates missing parents and tolerates an existing folder.
    os.makedirs(output_folder, exist_ok=True)

    out = cv2.VideoWriter(name, fourcc, fps, size)
    try:
        for count, ft in enumerate(frame_list, start=1):
            ft = np.clip((ft - ftmin) / (ftmax - ftmin), 0, 1)
            ft2 = (ft * 255).astype(np.uint8)
            out.write(ft2)
            file_path = os.path.join(output_folder, f"{count:06d}.png")
            cv2.imwrite(file_path, ft2)
    finally:
        # Always release the writer, even if a frame fails to convert.
        out.release()


# Size (in pixels) of the region that is reconstructed and visualised.
w = 640
h = 320

for key in key_list:
    dataset = TianmoucDataReader(all_data,MAXLEN=500,matchkey=key)
    dataLoader = torch.utils.data.DataLoader(dataset, batch_size=1,
                                             num_workers=4, pin_memory=False, drop_last=False)
    # Rebuilt for every key: after the loop only the frames of the last key remain.
    img_list = []
    for index,sample in enumerate(dataLoader,0):
        # Only the first 11 samples are processed.
        if index > 10:
            break

        # Centre-crop window used for visualisation.
        F0 = sample['F0'][0,...].clone()
        full_h, full_w = F0.shape[0], F0.shape[1]
        biasw = (full_w-w)//2
        biash = (full_h-h)//2
        rows = slice(biash, biash+h)
        cols = slice(biasw, biasw+w)
        # tsdiff is indexed below as [channel, t, row, col], so the spatial
        # crop has to act on the last two axes (cropping the first two only
        # worked while both offsets were 0).
        tsdiff = sample['tsdiff'][0, ..., rows, cols]
        F0 = F0[rows, cols, :]

        # Move the channel axis to dim 1 for inference.
        sample['F0'] = sample['F0'].permute(0,3,1,2)
        sample['F1'] = sample['F1'].permute(0,3,1,2)

        # Input: the lightly pre-processed sample packet.
        # Output: every frame reconstructed from this packet.
        #   F0, F1: 0~1
        #   tsdiff: -1~1
        #   ifSingleDirection: whether to reconstruct in both directions and average
        #   w, h: region of interest; set to the F0 size for full-frame reconstruction
        #   bs: inference batch size; lower it if GPU memory is short, 26 is
        #       recommended when memory allows
        reconstructed_b = reconstructor(sample,
                                        bs=26,
                                        h=h,
                                        w=w).float()

        timelen = reconstructed_b.shape[0]
        # The last frame can be dropped, or averaged with frame 0 of the next
        # reconstruction to reduce flicker.
        for t in range(timelen-1):

            tsd_rgb = tsdiff[:,t,...].permute(1,2,0)*64
            sd = tsd_rgb.cpu()[:,:,1:]
            rgb_sd = vizDiff(sd[...,0],thresh=1)

            rawDiff = sample['rawDiff'][0,:,t,...].cpu().permute(1,2,0)
            sd = rawDiff.cpu()[:,:,1:]
            Ix,Iy= SD2XY(sd)
            gray = poisson_blending(Ix,Iy,iteration=20)
            gray = torch.stack([gray]*3,dim=0)
            # Upsample to the full F0 size, then take the same centre window as F0.
            # NOTE(review): assumes rawDiff covers the same field of view as F0
            # and that the reconstructor's h/w region is this centre window -- confirm.
            gray = F.interpolate(gray.unsqueeze(0), size=(full_h,full_w), mode='bilinear').squeeze(0).permute(1,2,0)
            gray = gray[rows, cols, :]
            gray_min = torch.min(gray)
            # Guard against a constant image, which would otherwise divide by zero.
            gray_range = (torch.max(gray)-gray_min).clamp_min(1e-12)
            gray = (gray-gray_min)/gray_range

            # SD on top, integrated gray below, then shrink the column to half size.
            rgb_cat = torch.cat([rgb_sd,gray],dim=0).permute(2,0,1)
            rgb_tsd = F.interpolate(rgb_cat.unsqueeze(0), scale_factor=0.5, mode='bilinear', align_corners=True).squeeze(0).permute(1,2,0)

            reconstructed = reconstructed_b[t,...].cpu().permute(1,2,0)

            Ix = F.interpolate(Ix.unsqueeze(0).unsqueeze(0), size=(full_h,full_w), mode='bilinear').squeeze(0).squeeze(0)/128
            Iy = F.interpolate(Iy.unsqueeze(0).unsqueeze(0), size=(full_h,full_w), mode='bilinear').squeeze(0).squeeze(0)/128
            Ix = Ix[rows, cols].contiguous()
            Iy = Iy[rows, cols].contiguous()

            reconstructed = poisson_blending(Ix,Iy, srcimg= reconstructed,iteration=20, mask_rgb=True,mask_th=36)

            # Canvas layout: [F0 | SD/gray column | reconstruction].
            showim = torch.cat([F0,rgb_tsd,reconstructed],dim=1).numpy()

            cv2.putText(showim,"e-GT:"+str(t),(int(w*1.5)+12,36),cv2.FONT_HERSHEY_SIMPLEX,0.75,(0,0,0),2)
            cv2.putText(showim,"SD:"+str(t),(int(w)+12,24),cv2.FONT_HERSHEY_SIMPLEX,0.6,(0,0,0),2)
            # The gray image occupies the lower half of the middle column.
            cv2.putText(showim,"gray:"+str(t),(int(w)+12,h//2+24),cv2.FONT_HERSHEY_SIMPLEX,0.6,(0,0,0),2)
            cv2.putText(showim,"COP id:"+str(index),(12,36),cv2.FONT_HERSHEY_SIMPLEX,0.75,(0,0,0),2)

            # Preview a single frame per sample inside the notebook.
            if t==12:
                clear_output()
                plt.figure(figsize=(8,3))
                plt.subplot(1,1,1)
                plt.imshow(showim)
                plt.show()
            # Reverse the channel order before handing the frame to OpenCV.
            img_list.append(showim[...,[2,1,0]])

In [None]:
# Each frame is [F0 | half-width SD/gray column | reconstruction], so the
# canvas is 2*w + w//2 pixels wide and h pixels tall (1600x320 for w=640, h=320).
# Uses `img_list` and `key` as left by the last iteration of the loop above.
images_to_video(img_list,'has_sd_'+key+'.mp4',size=(2*w+w//2,h),Flip=True)