# 基于SD的特征追踪

## 这个示例展示一个在AOP上运行的特征点追踪算法

调用接口：
- tianmoucv.proc.features.(HarrisCorner,sift,hog)
- tianmoucv.proc.tracking.(feature_matching,mini_l2_cost_matching,align_images)

In [None]:
%load_ext autoreload

## 必要的包

In [None]:
%autoreload
import sys,os
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

from tianmoucv.isp import lyncam_raw_comp,demosaicing_npy,SD2XY
from tianmoucv.proc.features import HarrisCorner,hog
from tianmoucv.proc.tracking import feature_matching,mini_l2_cost_matching,align_images
import cv2
from tianmoucv.data import TianmoucDataReader

In [None]:
# Dataset root. Every entry found directly inside it is printed and collected
# into key_list, which later cells iterate over.
val = '/data/lyh/tianmoucData/20240930_tobi_sup_exp/data/10klux_5000rpm/tianmouc/'
vallist = os.listdir(val)
valdata = [val]
key_list = []

print('---------------------------------------------------')
for root_dir in valdata:
    entries = os.listdir(root_dir)
    print('---->', root_dir, '有：', len(entries), '个样本')
    for entry in entries:
        print(entry, end=" ")
    key_list += entries

all_data = valdata

def images_to_video(frame_list,name,Val_size=(512,256),Flip=False,fps=30):
    """Write a list of RGB frames to an MP4 ('mp4v') file using OpenCV.

    Args:
        frame_list: sequence of H x W x 3 RGB image arrays. Each frame is
            reordered to BGR (the channel order cv2.VideoWriter expects) and
            cast to uint8; the caller's arrays are not modified.
        name: output file path.
        Val_size: (width, height) of one display panel. The video frame size
            is (width, 2 * height) because the frames built in this notebook
            stack two panels vertically. NOTE(review): OpenCV drops frames
            whose size differs from this without raising, so callers must
            pass matching frames.
        Flip: accepted for call-site compatibility; currently ignored.
        fps: output frame rate (default 30, the previously hard-coded value).
    """
    size = (Val_size[0], Val_size[1] * 2)  # (width, height) of each video frame
    fourcc = 0x7634706d  # FourCC 'mp4v', i.e. cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(name, fourcc, fps, size)
    try:
        if not out.isOpened():
            # For example, the output directory does not exist. Previously this
            # case produced no file and no message.
            print('images_to_video: could not open video writer for', name)
            return
        for frame in frame_list:
            # Fancy indexing returns a new array, so the source frame is untouched.
            bgr = frame[:, :, [2, 1, 0]].astype(np.uint8)
            out.write(bgr)
    finally:
        # Always release the writer, even if a frame fails to convert or write.
        out.release()


In [None]:
%autoreload
import time
from tianmoucv.isp import SD2XY
from tianmoucv.proc.features import HarrisCorner,steadyHarrisCornerForSIFT
from tianmoucv.proc.tracking import feature_matching,mini_l2_cost_matching,align_images
from tianmoucv.proc.reconstruct import laplacian_blending
from IPython.display import clear_output
import cv2


gap  = 5  # NOTE(review): not referenced by the tracking loop below
# The current strategy is crude: after a fixed number of frames the tracking
# targets are simply re-detected from scratch, to keep errors from accumulating.
fix_update_frame = 250
time_begin = time.time()
startID = 20
endID = startID + 20  # inclusive: samples startID..endID are processed
# A match is accepted only if the point moved less than sqrt(1600) = 40 px
# (squared distance, measured in the 2x-scaled keypoint coordinates below).
max_match_dist2 = 1600

# cv2.VideoWriter does not create missing directories, so make sure the output
# folder passed to images_to_video at the end of each key exists.
os.makedirs('./realviz', exist_ok=True)

# Reuse one SIFT detector for all frames rather than constructing one per frame.
sift = cv2.SIFT_create()


def _draw_trajectories(history, canvases):
    """Draw every trajectory in ``history`` onto each image in ``canvases`` in place.

    ``history`` maps a target index to its list of (row, col) positions, oldest
    first. Consecutive positions are joined by a (0, 255, 0) line. Every
    position after the first gets a (0, 0, 255) circle of radius 2, and the
    latest position gets a (255, 0, 0) circle.
    """
    for traj in history.values():
        prev = None
        for row, col in traj:
            cur = (int(col), int(row))  # OpenCV points are (x, y) = (col, row)
            if prev is not None:
                for canvas in canvases:
                    cv2.line(canvas, cur, prev, (0, 255, 0))
                    cv2.circle(canvas, cur, 2, (0, 0, 255))
            prev = cur
        if prev is not None:
            for canvas in canvases:
                cv2.circle(canvas, prev, 2, (255, 0, 0))


for key in key_list:
    fl_aim = []
    kp_aim = []
    imlist = []
    history = {}
    tracking_count = 0
    imshow = None  # last rendered frame for this key; None until something is drawn
    print(key)
    if key == '.ipynb_checkpoints':
        continue
    dataset = TianmoucDataReader(all_data,showList=True,
                                 matchkey = key,
                                 MAXLEN=-1,
                                 print_info=True)

    # Only samples startID..endID (inclusive) that exist in the dataset are processed.
    for index in range(startID, min(endID + 1, len(dataset))):
        sample = dataset[index]
        F1 = sample['F1']
        # Convert the raw data once per sample rather than once per time step.
        # NOTE(review): the indexing below reads channel 0 as TD and channels 1:
        # as SD, with dim 1 iterated per step; confirm against TianmoucDataReader.
        tsdiff = torch.Tensor(sample['rawDiff'])
        for t in range(tsdiff.shape[1]):
            SD = tsdiff[1:, t, ...]
            Ix, Iy = SD2XY(SD)

            # Reconstruct a grayscale image from the SD gradients and rescale it to [0, 255].
            image = laplacian_blending(Ix, Iy, iteration=10)
            lo, hi = torch.min(image), torch.max(image)
            if hi > lo:
                image = (image - lo) / (hi - lo)
            else:
                # A flat reconstruction would otherwise divide by zero and yield NaNs.
                image = torch.zeros_like(image)
            image *= 255
            hdr_show = F.interpolate(image.unsqueeze(0).unsqueeze(0), size=(320, 640),
                                     mode='bilinear').squeeze(0).squeeze(0)
            image = image.numpy().astype(np.uint8)
            hdr_show = np.stack([hdr_show.numpy().astype(np.uint8)] * 3, axis=-1)
            F_show = (F1.clone().numpy() * 255).astype(np.uint8)

            # Step 1: detect SIFT keypoints and descriptors on the reconstructed image.
            good_kp, fl = sift.detectAndCompute(image, None)
            if fl is None:
                # OpenCV returns None instead of an empty array when nothing is
                # detected, and the len() checks below would raise TypeError.
                fl = np.zeros((0, 128), dtype=np.float32)
            # Keypoints stored as (row, col) scaled by 2, presumably to match the
            # 320x640 display panels. TODO: confirm the reconstruction resolution.
            kp = [(p.pt[1] * 2, p.pt[0] * 2) for p in good_kp]

            # Step 2: (re)initialise the tracking targets, or match against them.
            if tracking_count % fix_update_frame == 0 or len(fl_aim) == 0:
                print('update tracking target')
                kp_aim = kp
                fl_aim = fl
                history = {i: [pt] for i, pt in enumerate(kp)}
            elif len(fl) > 0:
                matches = feature_matching(fl_aim, fl, ratio=0.7)
                # When a target is matched, move it to the new position and adopt
                # the new descriptor, so that gradual scene changes do not break tracking.
                for m in matches:
                    query_idx = m[0].queryIdx
                    train_idx = m[0].trainIdx
                    src_pt = kp_aim[query_idx]
                    dst_pt = kp[train_idx]
                    dist2 = (src_pt[0] - dst_pt[0]) ** 2 + (src_pt[1] - dst_pt[1]) ** 2
                    if dist2 < max_match_dist2:
                        history[query_idx].append(dst_pt)
                        kp_aim[query_idx] = dst_pt
                        fl_aim[query_idx, :] = fl[train_idx, :]

                # Draw the tracking result on both panels and stack them vertically.
                _draw_trajectories(history, (F_show, hdr_show))
                imshow = np.concatenate([F_show, hdr_show], axis=0)
                imlist.append(imshow)
            else:
                print('no useable new feature')
            tracking_count += 1

        # After each sample, report throughput and show the latest rendered frame.
        now = time.time()
        clear_output()
        # NOTE(review): the numerator is a hard-coded 10, regardless of how many
        # steps this sample contained. Confirm the intended rate.
        print('avg sfps:', 10 / (now - time_begin))
        time_begin = time.time()
        if imshow is not None:
            # Nothing may have been drawn yet for this key (e.g. only target updates so far).
            plt.imshow(imshow / 255.0)
            plt.show()

    images_to_video(imlist, './realviz/feature_tracking_' + key + '.mp4', Val_size=(640, 320), Flip=False)