# 基于SpyNet的光流网络

## 这个示例展示一个在AOP上运行的，推理快速的端到端光流网络

调用接口：
- tianmoucv.proc.opticalflow.TianmoucOF_SpyNet


In [None]:
%load_ext autoreload

## 必要的库

In [None]:
%autoreload
import sys,os, math,time
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from tianmoucv.data import TianmoucDataReader
import torch.nn.functional as F
import cv2

## 准备数据

In [None]:
# Root folders of the training and test splits; every entry inside a root is
# treated as one sample set. The roots end with '/' so plain string
# concatenation yields the sample-set paths.
train='/data/lyh/tianmoucData/tianmoucReconDataset/train/'
dirlist = os.listdir(train)
traindata = [train + set_name for set_name in dirlist]

val='/data/lyh/tianmoucData/tianmoucReconDataset/test/'
vallist = os.listdir(val)
valdata = [val + set_name for set_name in vallist]

# Print an inventory of both splits (training first, then test) and collect
# every sample name that was seen.
key_list = []
for split_paths in (traindata, valdata):
    print('---------------------------------------------------')
    for set_path in split_paths:
        sample_names = os.listdir(set_path)
        print('---->',set_path,'有：',len(sample_names),'个样本')
        for sample_name in sample_names:
            print(sample_name,end=" ")
        key_list.extend(sample_names)

# The reader is given the test sets first, then the training sets.
all_data = valdata + traindata
# The collected names are discarded here in favour of one hand-picked key.
key_list = ['underbridge_hdr_4']

## 光流网络初始化

In [None]:
%autoreload
from tianmoucv.proc.opticalflow import TianmoucOF_SpyNet

# Select the CUDA device by index.
local_rank = 0
device = torch.device(f'cuda:{local_rank}')

# Build the SpyNet-based optical flow network, move it to the selected device
# and switch it to evaluation mode.
# NOTE(review): (320, 640) looks like the (height, width) working resolution —
# confirm against TianmoucOF_SpyNet's signature.
input_size = (320, 640)
OFNet = TianmoucOF_SpyNet(input_size, _optim=False)
OFNet.to(device)
OFNet.eval()

# 光流计算

In [None]:
%autoreload
from IPython.display import clear_output
from tianmoucv.isp import *
from tianmoucv.proc.opticalflow import interpolate_image,flow_to_image
import time
import cv2

# Estimate optical flow for a handful of samples and build visualisation
# frames. For every processed sample the flow is computed step by step over
# the `tsdiff` sequence, rendered as a colour image with arrows overlaid, and
# stacked with the temporal-difference image into one frame appended to
# `show_list`.

imlist = []
noiseThresh = 0
W = 640   # visualisation width in pixels
H = 320   # visualisation height in pixels
acctime= 1   # number of tsdiff time steps covered by one flow estimate
gridX, gridY = np.meshgrid(np.arange(W), np.arange(H))

# Arrow overlay settings: one arrow every `sparsity` pixels, with the flow
# vector stretched by `scale` so that it is visible.
sparsity = 8
scale = 10

dataset = TianmoucDataReader(all_data,MAXLEN=400,matchkey = 'test_exam_fan4')

show_list = []
num_samples = len(dataset)
# Sample 0 is skipped and at most samples 1..10 are processed.
for index in range(1, min(num_samples, 11)):
    print('progress:',index,'/',num_samples)
    sample = dataset[index]
    F0 = sample['F0']
    F1 = sample['F1']
    tsdiff = sample['tsdiff']
    F0show = F0.copy()
    show_img = F0show.copy()
    for b in range(25//acctime):
        with torch.no_grad():
            # Original author's note: output values are in 0~1.
            rawflow = OFNet.forward_time_range(tsdiff.unsqueeze(0), t1=b*acctime, t2=(b+1)*acctime)
            rawflow = rawflow.cpu()

        # Slices of tsdiff used for display only.
        # NOTE(review): channel 0 is presumably the temporal difference and
        # channels 1: the spatial differences — confirm against the reader.
        SD0 = tsdiff[1:,b*acctime,...].unsqueeze(0).to(device)
        SD1 = tsdiff[1:,(b+1)*acctime,...].unsqueeze(0).to(device)
        Tdiff= tsdiff[0:1,b*acctime:(b+1)*acctime,...].to(device)
        Tdiff = torch.sum(Tdiff,dim=1).unsqueeze(0)

        # Three-channel uint8 rendering of the negated channel-0 slice.
        td = -tsdiff[0,(b+1)*acctime,...].to(device)
        tdiff_show = np.stack([td.cpu()*255]*3,axis=2).astype(np.uint8)
        # NOTE(review): tdiff_show is already uint8 here, so abs() has no
        # effect and this only zeroes values 0..7 — confirm whether the
        # threshold was meant to be applied before the cast.
        tdiff_show[abs(tdiff_show)<8]=0

        Tdiff = F.interpolate(Tdiff,(H,W),mode='bilinear')
        SD0 = F.interpolate(SD0,(H,W),mode='bilinear')
        SD1 = F.interpolate(SD1,(H,W),mode='bilinear')

        # Split the flow into its two components and render it as colour.
        u = rawflow[0,0:1,:, :]
        v = rawflow[0,1:2,:, :]
        flow_show = flow_to_image(rawflow[0,...].permute(1,2,0).numpy())
        flow_show = torch.Tensor(cv2.resize(flow_show,(W,H)))/255.0
        flow_show = (flow_show*255).numpy().astype(np.uint8)

        # Black out pixels whose mean colour is above 225.
        mask = np.mean(flow_show,axis=-1) > 225
        flow_show[np.stack([mask]*3,axis=-1)]=0

        show_img = interpolate_image(show_img,u,v)
        # Draw arrows on a regular grid. The arrows are drawn into flow_show
        # while colours are still being sampled from it, so the iteration
        # order (columns outer, rows inner) is kept as in the original.
        for w in range(W//sparsity):
            for h in range(H//sparsity):
                x = int(w*sparsity)
                y = int(h*sparsity)
                u_ij = -u[0,y,x]
                v_ij = -v[0,y,x]
                color = flow_show[y,x,:]
                color = tuple([int(e+20) for e in color])
                # Only draw where the squared flow magnitude exceeds 5.
                if (u_ij**2+v_ij**2)>5:
                    cv2.arrowedLine(flow_show, (x,y), (int(x+u_ij*scale),int(y+v_ij*scale)), color,2, tipLength=0.15)

        # Overlay the non-black flow pixels on top of the difference image.
        tdiff_show_tensor = torch.Tensor(tdiff_show.copy())
        flow_show_tensor = torch.Tensor(flow_show)
        mask = torch.stack([torch.mean(flow_show_tensor,dim=-1)>0]*3,dim=-1)
        tdiff_show_tensor[mask] = flow_show_tensor[mask]/255.0
        tdiff_show_merge = tdiff_show_tensor.numpy()
        # Stack the three panels vertically into one frame.
        # NOTE(review): flow_show/255.0 lies in [0, 1] while tdiff_show (and
        # the unmasked part of tdiff_show_merge) holds 0..255 values —
        # confirm that mixing the two scales in one frame is intended.
        imshow = np.concatenate([flow_show/255.0,tdiff_show,tdiff_show_merge],axis=0)
        show_list.append(imshow)

        # Refresh the notebook preview on every 10th step.
        if b%10==0:
            clear_output()
            plt.figure(figsize=(9,5))
            plt.axis('off') 
            plt.subplot(2,3,1)
            plt.imshow(SD0[0,0,...].cpu(),cmap='gray')
            plt.subplot(2,3,2)
            plt.imshow(Tdiff[0,0,...].cpu(),cmap='gray')
            plt.axis('off') 
            plt.subplot(2,3,4)
            plt.imshow(F0show)
            plt.subplot(2,3,5)
            plt.imshow(flow_show/255.0)
            plt.subplot(2,3,6)
            plt.imshow(imshow)
            plt.show()

## 导出视频

In [None]:
def images_to_video(frame_list,name,Val_size=(512,256),Flip=False):
    """Write a sequence of image frames to a video file at 30 fps.

    Args:
        frame_list: Iterable of NumPy arrays. Each frame is multiplied by 255
            and cast to uint8 before being written, so frames are presumably
            expected as floats in [0, 1] (values outside that range are not
            clipped). The input arrays are left unmodified.
        name: Path of the output video file.
        Val_size: (width, height) handed to cv2.VideoWriter.
            NOTE(review): frames presumably have to match this size to be
            written — confirm against the OpenCV documentation.
        Flip: Unused; kept so existing callers keep working.
    """
    fps = 30
    size = (Val_size[0], Val_size[1])  # frame size of the output video
    # 0x7634706d is the FourCC code for the characters 'm', 'p', '4', 'v'.
    out = cv2.VideoWriter(name,0x7634706d , fps, size)
    try:
        for frame in frame_list:
            # Scale into a new array. The previous in-place `frame *= 255`
            # modified the caller's frames, so exporting the same list twice
            # scaled the data a second time.
            out.write((frame * 255).astype(np.uint8))
    finally:
        # Release the writer even if a frame fails to convert or write.
        out.release()
    
# Export the collected visualisation frames as an mp4 file. The size is
# (width, height) = (640, 320*3), i.e. three 320-pixel-high panels stacked
# vertically in each frame of show_list.
images_to_video(show_list,'./spynet_tianmouc_of_multiple_scale_nn.mp4',Val_size=(640,320*3),Flip=False)