# 演示使用直接计算出的光流，对原始图像做扭曲，对部分特征做追踪

In [None]:
%load_ext autoreload

# 构造数据集

In [None]:
%autoreload
import sys
import numpy as np
import os
import cv2
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import math
from tianmoucv.data import TianmoucDataReader

# Dataset roots; each one holds a sub-directory per sample set.
train = '/data/lyh/tianmoucData/tianmoucReconDataset/train/'
val = '/data/lyh/tianmoucData/tianmoucReconDataset/test/'

dirlist = os.listdir(train)
vallist = os.listdir(val)
# The roots end with '/', so os.path.join yields the same strings as plain concatenation.
traindata = [os.path.join(train, name) for name in dirlist]
valdata = [os.path.join(val, name) for name in vallist]
key_list = []

# Report every sample set (train sets first, then test sets) and collect
# every sample name found in them into key_list.
for group in (traindata, valdata):
    print('---------------------------------------------------')
    for sampleset in group:
        samples = os.listdir(sampleset)
        print('---->', sampleset, '有：', len(samples), '个样本')
        for e in samples:
            print(e)
            key_list.append(e)

# Combined search path for the data reader; the test sets come first.
all_data = valdata + traindata

# 光流计算

In [None]:
%autoreload
sys.path.append("../demo")
from tianmoucv.proc.opticalflow import interpolate_image,flow_to_image
from tianmoucv.proc.opticalflow.estimator import recurrentMultiScaleOF
from tianmoucv.isp import fourdirection2xy

imlist = []
accumTime = 5      # number of consecutive diff frames accumulated per flow estimate
noiseThresh = 8    # diff values with |value| < noiseThresh are zeroed as noise
lambda_of_HS = 25  # smoothness weight passed to the flow solver: bigger -> smoother
# (when inputs are in 0~255 lambda should be > 1; otherwise it must not be too large)

W = 640
H = 320
gridX, gridY = np.meshgrid(np.arange(W), np.arange(H))
import time

show_list = []

key_list = ['test_exam_fan4']

# Only dataset indices start_index .. stop_index - 1 (i.e. 46..75) are processed.
start_index = 46
stop_index = 76
# Arrow overlay: one arrow every `sparsity` pixels, arrow length = flow * `scale`.
sparsity = 8
scale = 10

for key in key_list:
    pathList = all_data
    dataset = TianmoucDataReader(pathList, showList=True,
                                 matchkey=key,
                                 MAXLEN=-1,
                                 speedUpRate=1,
                                 print_info=True)
    # Reset per key: after the loop only the last key's frames remain for export.
    show_list = []
    for index in range(start_index, min(stop_index, len(dataset))):
        print('progress:', index, '/', len(dataset))
        sample = dataset[index]
        F0 = sample['F0']
        F1 = sample['F1']
        tsdiff = sample['rawDiff']
        F0show = F0.copy()
        show_img = F0show.copy()
        # NOTE(review): assumes 25 diff frames per sample, as the original did;
        # confirm against tsdiff.shape[1].
        for b in range(25 // accumTime):
            sd = 0
            td = 0
            TD = 0
            # Accumulate accumTime diff frames.
            for t in range(accumTime):
                # clone(): indexing + permute return a view, so thresholding in place
                # would otherwise write back into sample['rawDiff'].
                threshed_tsdiff = tsdiff[:, b * accumTime + t, ...].permute(1, 2, 0).clone()
                threshed_tsdiff[abs(threshed_tsdiff) < noiseThresh] = 0
                SD = threshed_tsdiff[..., 1:]   # channels 1.. : spatial diff
                TD = threshed_tsdiff[..., 0]    # channel 0: temporal diff
                Ix, Iy = fourdirection2xy(SD)
                sd += torch.FloatTensor(np.stack([Ix, Iy], axis=0))
                td += -(TD)

            # AOP preprocessing: average the spatial gradients, add a channel dim to td.
            sd = sd / accumTime
            td = td.unsqueeze(0)

            # Compute the optical flow.
            rawflow = recurrentMultiScaleOF(sd, td, ifInterploted=True, epsilon=1e-8,
                                            maxIteration=50, labmda=lambda_of_HS, scales=3)
            u = rawflow[0, :, :].numpy()
            v = rawflow[1, :, :].numpy()
            u = torch.Tensor(cv2.resize(u, (W, H))).unsqueeze(0)
            v = torch.Tensor(cv2.resize(v, (W, H))).unsqueeze(0)
            flow_show = flow_to_image(rawflow.permute(1, 2, 0).numpy())
            flow_show = cv2.resize(flow_show, (W, H)).astype(np.uint8)

            # Only the last accumulated TD frame is visualised.
            tdshow = TD.unsqueeze(0).unsqueeze(0)
            tdshow = F.interpolate(tdshow, (H, W), mode='bilinear')

            # Blank out near-white pixels of the colour-coded flow (presumably
            # low-motion regions in flow_to_image's encoding).
            mask = np.mean(flow_show, axis=-1) > 225
            flow_show[np.stack([mask] * 3, axis=-1)] = 0

            # Warp the running image by the flow.
            # NOTE(review): show_img is never displayed or exported in this cell.
            show_img = interpolate_image(show_img, u, v)
            # NOTE(review): TD * 255 is cast straight to uint8, which wraps for TD
            # outside [0, 1]; confirm the intended TD range/normalisation.
            tdiff_show = np.stack([tdshow[0, 0, ...].cpu() * 255] * 3, axis=2).astype(np.uint8)

            # Sample arrow colours from a snapshot so arrows drawn earlier do not
            # tint the colour of arrows drawn later.
            flow_colors = flow_show.copy()
            for x in range(0, W, sparsity):
                for y in range(0, H, sparsity):
                    u_ij = -u[0, y, x]
                    v_ij = -v[0, y, x]
                    if (u_ij ** 2 + v_ij ** 2) > 5:
                        # +20 brightens the arrow; clamp so uint8 channels near 255
                        # do not wrap around to dark values.
                        color = tuple(min(int(c) + 20, 255) for c in flow_colors[y, x, :])
                        cv2.arrowedLine(flow_show, (x, y),
                                        (int(x + u_ij * scale), int(y + v_ij * scale)),
                                        color, 2, tipLength=0.15)

            # Keep all three panels in [0, 1]: plt.imshow and images_to_video
            # (which multiplies by 255) both expect that range.
            tdiff_show_unit = tdiff_show / 255.0
            tdiff_show_tensor = torch.Tensor(tdiff_show_unit)
            flow_show_tensor = torch.Tensor(flow_show)
            mask = torch.stack([torch.mean(flow_show_tensor, dim=-1) > 0] * 3, dim=-1)
            tdiff_show_tensor[mask] = flow_show_tensor[mask] / 255.0
            tdiff_show_merge = tdiff_show_tensor.numpy()
            # Panels stacked vertically: flow | TD | TD with flow overlaid.
            imshow = np.concatenate([flow_show / 255.0, tdiff_show_unit, tdiff_show_merge], axis=0)
            show_list.append(imshow)

            # With 25 // accumTime == 5 iterations, only b == 0 is plotted.
            if b % 10 == 0:
                plt.figure(figsize=(18, 10))
                plt.axis('off')
                plt.subplot(2, 3, 1)
                plt.imshow(Ix, cmap='gray')
                plt.subplot(2, 3, 2)
                plt.imshow(TD, cmap='gray')
                plt.axis('off')
                plt.subplot(2, 3, 4)
                plt.imshow(F0show)
                plt.subplot(2, 3, 5)
                plt.imshow(flow_show)
                plt.subplot(2, 3, 6)
                plt.imshow(imshow)
                plt.show()

In [None]:
# 导出视频

In [None]:
    
def _to_uint8_frame(frame):
    """Return a uint8 copy of ``frame`` scaled from [0, 1] to [0, 255].

    Values outside [0, 1] are clipped rather than wrapping around, and the
    input array is never modified.
    """
    return np.clip(np.asarray(frame, dtype=np.float64) * 255, 0, 255).astype(np.uint8)


def images_to_video(frame_list, name, Val_size=(512, 256), Flip=False):
    """Write a sequence of float frames to a video file.

    Every 20th frame is also previewed with matplotlib.

    Args:
        frame_list: iterable of 3-channel arrays with values in [0, 1]. They are
            written as-is to cv2.VideoWriter, which interprets channels as BGR.
            The list itself is not modified.
        name: output file path. Its parent directory is created if missing.
        Val_size: (width, height) of the video; it must match every frame.
        Flip: unused; kept for backward compatibility with existing callers.

    Raises:
        RuntimeError: if the video writer cannot be opened for ``name``.
    """
    fps = 30
    size = (Val_size[0], Val_size[1])  # size of the frames to be written
    os.makedirs(os.path.dirname(name) or '.', exist_ok=True)
    # FOURCC 'mp4v' (the same value as the literal 0x7634706d).
    out = cv2.VideoWriter(name, cv2.VideoWriter_fourcc(*'mp4v'), fps, size)
    if not out.isOpened():
        raise RuntimeError(f'could not open video writer for {name!r}')
    try:
        for count, frame in enumerate(frame_list, start=1):
            # Convert on a copy so the caller's frames are not rescaled in place.
            frame = _to_uint8_frame(frame)
            if count % 20 == 0:
                # Preview: reorder channels (BGR -> RGB) for matplotlib.
                plt.figure(figsize=(16, 8))
                plt.imshow(frame[:, :, [2, 1, 0]])
                plt.axis('off')
                plt.show()
            out.write(frame)
    finally:
        out.release()
    
# Each frame in show_list is three 320x640 panels stacked vertically, so the
# video size (width, height) is (640, 320*3).
images_to_video(show_list,'./opticalFlow/tianmouc_OF_direct_HS.mp4',Val_size=(640,320*3),Flip=False)