# 基于SpyNet的光流网络

## 这个示例展示一个在AOP上运行的，推理快速的端到端光流网络

调用接口：
- tianmoucv.proc.opticalflow.TianmoucOF_SpyNet
- 输入方式1: 输入sample中完整的tsdiff，并指定要计算的时间区间t1~t2（t1、t2须在该sample的时间范围内）
- 输入方式2（通用）: 输入t1~t2之间累加的td，以及t1对应的sd1和t2对应的sd2

In [None]:
%load_ext autoreload

## 必要的库

In [None]:
%autoreload
import sys,os, math,time
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from tianmoucv.data import TianmoucDataReader
import torch.nn.functional as F
import cv2

## 准备数据

In [None]:
# Dataset split roots; each contains one directory per sample set.
train='/data/lyh/tianmoucData/tianmoucReconDataset/train/'
val='/data/lyh/tianmoucData/tianmoucReconDataset/test/'

dirlist = os.listdir(train)
traindata = [os.path.join(train, e) for e in dirlist]
vallist = os.listdir(val)
valdata = [os.path.join(val, e) for e in vallist]


def _list_sample_keys(samplesets):
    """Print the samples found in each sample-set directory and return all their names.

    Args:
        samplesets: iterable of sample-set directory paths.

    Returns:
        list of every entry name found, in directory-listing order.
    """
    keys = []
    for sampleset in samplesets:
        names = os.listdir(sampleset)  # list once; reused for the count and the names
        print('---->',sampleset,'有：',len(names),'个样本')
        # Terminate the line so the next header does not run onto these names.
        print(' '.join(names))
        keys.extend(names)
    return keys


print('---------------------------------------------------')
available_keys = _list_sample_keys(traindata)
print('---------------------------------------------------')
available_keys += _list_sample_keys(valdata)

all_data = valdata + traindata
# Only this sample is processed below; use `key_list = available_keys`
# to process every sample listed above.
key_list = ['underbridge_hdr_4']

## 光流网络初始化

In [None]:
%autoreload
from tianmoucv.proc.opticalflow import TianmoucOF_SpyNet

# Preferred GPU index. Fall back to the first GPU when that index does not
# exist, and to the CPU when CUDA is unavailable, instead of crashing.
local_rank = 7
if torch.cuda.is_available() and local_rank < torch.cuda.device_count():
    device = torch.device('cuda:'+str(local_rank))
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    # NOTE(review): CPU execution of TianmoucOF_SpyNet is untested here — confirm.
    device = torch.device('cpu')

# (320, 640) is presumably the (H, W) working resolution of the network — verify.
OFNet = TianmoucOF_SpyNet((320,640),_optim=False)
OFNet.to(device)
OFNet.eval()  # inference mode

# 光流计算

In [None]:
%autoreload
from IPython.display import clear_output
from tianmoucv.isp import *
from tianmoucv.proc.opticalflow import interpolate_image,flow_to_image
import time
import cv2

# Canvas size (pixels) used for visualisation.
W, H = 640, 320
# Number of tsdiff frames accumulated for each flow estimate.
acctime = 5
# Kept for compatibility; the flow loop below does not read these.
noiseThresh = 0
imlist = []
# Per-pixel coordinate grids of shape (H, W): gridX[y, x] == x, gridY[y, x] == y.
gridY, gridX = np.indices((H, W))


def images_to_video(frame_list,name,size=(640,320),Flip=True):
    """Encode a sequence of float frames into an MP4 (``mp4v``) video at 25 fps.

    Args:
        frame_list: iterable of (height, width, 3) arrays with values nominally
            in [0, 1]. Values outside that range are clipped, then scaled to
            uint8. Channels are written unchanged; OpenCV treats them as BGR.
        name: output file path.
        size: (width, height) of the video. Every frame must have shape
            (height, width, ...), otherwise OpenCV would silently drop it.
        Flip: accepted for backward compatibility; currently unused.

    Raises:
        OSError: if the video writer cannot be opened (e.g. missing codec).
        ValueError: if a frame's spatial size does not match ``size``.
    """
    fps = 25
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # == 0x7634706d
    out = cv2.VideoWriter(name, fourcc, fps, size)
    if not out.isOpened():
        raise OSError(f'cannot open video writer for {name!r}')
    width, height = size
    try:
        for ft in frame_list:
            ft = np.asarray(ft, dtype=np.float64)
            if ft.shape[:2] != (height, width):
                raise ValueError(
                    f'frame shape {ft.shape[:2]} does not match video size '
                    f'(height={height}, width={width})')
            out.write((np.clip(ft, 0.0, 1.0) * 255).astype(np.uint8))
    finally:
        # Always finalise the container, even if a frame is rejected.
        out.release()

NUM_SAMPLES = 6        # samples visualised per key (original: index <= 5)
ARROW_SPARSITY = 4     # one arrow every N pixels in x and y
ARROW_SCALE = 5        # arrow length multiplier applied to the flow vector
ARROW_MIN_SQ_MAG = 5   # arrows are drawn only where u**2 + v**2 exceeds this


def _draw_flow_arrows(canvas, u_np, v_np, sparsity, scale, min_sq_mag):
    """Draw sparse arrows of the negated flow (-u, -v) onto ``canvas`` in place.

    Grid points are visited column by column (x outer, y inner). Each arrow is
    coloured by the canvas pixel at its origin, brightened by 20 and capped at
    255. Because of this visiting order, later arrows can overdraw earlier ones.
    """
    rows = canvas.shape[0] // sparsity
    cols = canvas.shape[1] // sparsity
    for w in range(cols):
        x = w * sparsity
        for h in range(rows):
            y = h * sparsity
            du = -u_np[y, x]
            dv = -v_np[y, x]
            if du ** 2 + dv ** 2 > min_sq_mag:
                color = tuple(min(int(c + 20), 255) for c in canvas[y, x, :])
                tip = (int(x + du * scale), int(y + dv * scale))
                cv2.arrowedLine(canvas, (x, y), tip, color, 2, tipLength=0.15)


def _compose_frame(flow_rgb, frame_rgb, tdiff_rgb):
    """Tile a 2x2 panel: [flow | F0] on top, [TD | TD with flow overlaid] below.

    ``flow_rgb`` is divided by 255 before tiling (assumed 0-255 scale); the
    other two panels are used as-is. Wherever the flow image is non-black it
    replaces the TD pixels in the overlay panel.
    """
    overlay = torch.Tensor(tdiff_rgb.copy())
    flow_t = torch.Tensor(flow_rgb)
    mask = torch.stack([torch.mean(flow_t, dim=-1) > 0] * 3, dim=-1)
    overlay[mask] = flow_t[mask] / 255.0
    top = np.concatenate([flow_rgb / 255.0, frame_rgb], axis=1)
    bottom = np.concatenate([tdiff_rgb, overlay.numpy()], axis=1)
    return np.concatenate([top, bottom], axis=0)


show_list = []

for key in key_list:
    dataset = TianmoucDataReader(all_data,MAXLEN=400,matchkey = key,print_info=False)
    # Fresh frame list per key so each video only contains its own frames.
    show_list = []
    for index in range(min(len(dataset), NUM_SAMPLES)):
        sample = dataset[index]
        F0 = sample['F0'].numpy()
        F1 = sample['F1'].numpy()
        tsdiff = sample['tsdiff']
        F0show = F0.copy()
        show_img = F0show.copy()

        # Each estimate spans `acctime` frames (minimum 1; the unit is a frame index).
        for b in range((tsdiff.shape[1]-1)//acctime):
            t1, t2 = b * acctime, (b + 1) * acctime

            # Optical flow estimation with SpyNet.
            with torch.no_grad():
                # Input mode 1, valid when t1 and t2 lie inside the sample's time range:
                #rawflow = OFNet.forward_time_range(tsdiff.unsqueeze(0), t1=t1, t2=t2) # output values 0~1

                # Input mode 2 (general): TD summed over [t1, t2) plus the SD frames at t1 and t2.
                td = torch.sum(tsdiff[0:1, t1:t2, ...], dim=1)
                sd0 = tsdiff[1:, t1, ...]
                sd1 = tsdiff[1:, t2, ...]
                rawflow = OFNet(td,sd0,sd1,print_fps = True).cpu() # output values 0~1

            # TD visualisation of the last frame in the window.
            td_frame = tsdiff[0, t2, ...] * 128
            tdiff_show = vizDiff(td_frame.cpu(),thresh=3,bg_color='black').numpy()

            # Optical flow visualisation.
            u = rawflow[0,0:1,:, :] #x
            v = rawflow[0,1:2,:, :] #y
            flow_show = flow_to_image(rawflow[0,...].permute(1,2,0).numpy())
            flow_show = torch.Tensor(cv2.resize(flow_show,(W,H)))
            # Blank out near-white pixels (presumably ~zero flow in the colour coding).
            near_white = torch.mean(flow_show, dim=-1) > 225
            flow_show[torch.stack([near_white]*3, dim=-1)] = 0
            flow_show = flow_show.numpy()

            # Warped image accumulated in show_img; it is not displayed or saved in this cell.
            show_img = interpolate_image(show_img,u,v)

            # Index numpy views instead of torch scalars inside the per-pixel loop.
            _draw_flow_arrows(flow_show, u[0].numpy(), v[0].numpy(),
                              ARROW_SPARSITY, ARROW_SCALE, ARROW_MIN_SQ_MAG)

            imshow = _compose_frame(flow_show, F0show, tdiff_show)
            show_list.append(imshow)

            if b %10 ==0:
                clear_output()
                plt.figure(figsize=(6,3))
                plt.axis('off')
                plt.imshow(imshow)
                plt.show()

    # Write the video once per key, regardless of how many samples the dataset has.
    clear_output()
    if show_list:
        output_dir = './output'
        os.makedirs(output_dir, exist_ok=True)
        # NOTE(review): the 'RAFT' prefix looks like a leftover (this demo uses SpyNet);
        # kept so existing output paths do not change.
        output_name = os.path.join(output_dir,'OF_RAFT_viz_'+key+'.mp4')
        images_to_video(show_list,output_name,size=(W*2,H*2),Flip=True)