# In [None]:
import cv2
import numpy as np
from pathlib import Path
import shutil

def create_folders():
    """Ensure the working directory tree used by the pipeline exists.

    Creates one frame folder per region under ``temp_frames`` plus the
    ``output/split`` and ``output`` result folders. Folders that already
    exist are left untouched.
    """
    region_dirs = [
        'temp_frames/up',
        'temp_frames/down',
        'temp_frames/left',
        'temp_frames/right',
    ]
    result_dirs = ['output/split', 'output']
    for target in map(Path, region_dirs + result_dirs):
        target.mkdir(parents=True, exist_ok=True)

def check_and_resize_image(img, max_dimension=3000):
    """Downscale ``img`` so that its longer side is at most ``max_dimension``.

    The aspect ratio is preserved. An image that already fits is returned
    unchanged (the very same object, not a copy).

    Args:
        img: image array whose first two axes are (height, width).
        max_dimension: maximum allowed size, in pixels, of either side.

    Returns:
        ``img`` itself, or a resized copy.
    """
    height, width = img.shape[:2]
    longest = max(height, width)
    if longest <= max_dimension:
        return img

    scale = max_dimension / longest
    # Clamp each side to at least 1 px: a very thin strip would otherwise round
    # down to a zero-sized target, which cv2.resize rejects.
    new_size = (max(1, int(width * scale)), max(1, int(height * scale)))
    # INTER_AREA is OpenCV's recommended interpolation for shrinking images.
    return cv2.resize(img, new_size, interpolation=cv2.INTER_AREA)

def _split_frame(frame, helicopter_ratio):
    """Cut one frame into top/bottom strips and left/right halves of the middle band.

    Returns a dict mapping 'up', 'down', 'left', 'right' to slices (views) of
    ``frame``. Raises ValueError if any region would be empty, because
    cv2.imwrite cannot encode an empty image.
    """
    height, width = frame.shape[:2]
    strip = int(height * helicopter_ratio)
    middle = frame[strip:height - strip, :]
    half = width // 2
    regions = {
        'up': frame[0:strip, :],
        'down': frame[height - strip:height, :],
        'left': middle[:, 0:half],
        'right': middle[:, half:width],
    }
    if any(region.size == 0 for region in regions.values()):
        raise ValueError(f"帧尺寸 {width}x{height} 过小，无法按 helicopter_ratio={helicopter_ratio} 分割")
    return regions


def extract_and_split_frames(video_path, frame_interval, helicopter_ratio=0.1):
    """Sample frames from a video and save each one split into four regions.

    Frames 0, ``frame_interval``, 2*``frame_interval``, ... are read. Each is
    cut into a top strip and a bottom strip, each ``helicopter_ratio`` of the
    frame height, plus the left and right halves of the remaining middle band.
    The regions are written to ``temp_frames/<region>/frame_<n>_<region>.jpg``,
    where ``n`` counts the saved frames from 0.

    Args:
        video_path: path of the video file to read.
        frame_interval: sampling step in frames; must be >= 1.
        helicopter_ratio: fraction of the frame height taken by each of the
            top and bottom strips; must lie in (0, 0.5) so that the middle
            band is non-empty.

    Returns:
        The number of frames sampled and saved.

    Raises:
        ValueError: on invalid arguments, if the video cannot be opened, or if
            a frame is too small to split.
        OSError: if a region image cannot be written.
    """
    # Validate before touching the video so bad arguments fail fast.
    if frame_interval < 1:
        raise ValueError(f"frame_interval 必须 >= 1: {frame_interval}")
    if not 0 < helicopter_ratio < 0.5:
        raise ValueError(f"helicopter_ratio 必须在 (0, 0.5) 之间: {helicopter_ratio}")

    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise ValueError(f"无法打开视频文件: {video_path}")

        # CAP_PROP_FRAME_COUNT can be 0 or negative when the container does not
        # report it; clamp so the progress estimate stays sane.
        total_frames = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0)
        # Indices 0, k, 2k, ... below total are sampled: ceil(total / k) frames.
        expected = (total_frames + frame_interval - 1) // frame_interval

        print(f"总帧数: {total_frames}")
        print(f"预计处理帧数: {expected}")

        frame_count = 0
        saved_count = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            if frame_count % frame_interval == 0:
                for name, region in _split_frame(frame, helicopter_ratio).items():
                    path = f'temp_frames/{name}/frame_{saved_count}_{name}.jpg'
                    if not cv2.imwrite(path, region):
                        raise OSError(f"无法写入帧文件: {path}")

                print(f"\r提取帧进度: {saved_count + 1}/{expected}", end="", flush=True)
                saved_count += 1

            frame_count += 1
    finally:
        # Always release the capture, including on the error paths above.
        cap.release()

    print("\n帧提取完成")
    return saved_count

def match_features(gray1, gray2, ratio=0.7, nfeatures=2000):
    """Detect SIFT keypoints in two grayscale images and match them with FLANN.

    Args:
        gray1: first grayscale image (query side of the matches).
        gray2: second grayscale image (train side of the matches).
        ratio: Lowe ratio-test threshold. A best match is kept only if its
            distance is below ``ratio`` times that of the second-best candidate.
        nfeatures: maximum number of SIFT keypoints retained per image.

    Returns:
        ``(kp1, kp2, good_matches)`` where ``good_matches`` holds ``cv2.DMatch``
        objects whose ``queryIdx``/``trainIdx`` index ``kp1``/``kp2``. Returns
        ``(None, None, [])`` if either image has no descriptors, fewer than two
        keypoints, or FLANN matching fails.
    """
    sift = cv2.SIFT_create(nfeatures=nfeatures)
    kp1, des1 = sift.detectAndCompute(gray1, None)
    kp2, des2 = sift.detectAndCompute(gray2, None)

    # A k=2 nearest-neighbour search needs at least two candidates per side.
    if des1 is None or des2 is None or len(kp1) < 2 or len(kp2) < 2:
        return None, None, []

    FLANN_INDEX_KDTREE = 1
    flann = cv2.FlannBasedMatcher(
        dict(algorithm=FLANN_INDEX_KDTREE, trees=5),
        dict(checks=50),
    )

    try:
        knn_matches = flann.knnMatch(des1, des2, k=2)
    except cv2.error:
        return None, None, []

    # knnMatch can return fewer than k neighbours for some descriptors. Skip
    # those individually; unpacking every entry as a pair would abort on the
    # first short one and throw away all the good matches.
    good_matches = [
        pair[0]
        for pair in knn_matches
        if len(pair) == 2 and pair[0].distance < ratio * pair[1].distance
    ]
    return kp1, kp2, good_matches

def stitch_images(img1, img2, max_dimension=3000):
    """Warp ``img1`` into ``img2``'s pixel frame and feather-blend the two.

    Both inputs are first shrunk with ``check_and_resize_image`` so that
    neither side exceeds ``max_dimension``. SIFT matches from
    ``match_features`` feed a RANSAC homography that maps ``img1`` pixels onto
    ``img2``. The output canvas is sized to hold both the warped ``img1`` and
    all of ``img2``, with each direction clamped to ``max_dimension`` beyond
    ``img2``.

    Args:
        img1: image accumulated so far (3-channel BGR, as required by the
            ``cv2.COLOR_BGR2GRAY`` conversion).
        img2: new image to add; it keeps its own geometry in the output.
        max_dimension: size limit used both for downscaling the inputs and for
            bounding the canvas.

    Returns:
        The blended ``uint8`` image. If stitching is not possible (too few
        matches, no homography, oversized canvas, or an error), the possibly
        downscaled ``img1`` is returned instead. A caller that folds frames
        into ``img1`` therefore keeps what it has accumulated rather than
        restarting from ``img2``.
    """
    try:
        img1 = check_and_resize_image(img1, max_dimension)
        img2 = check_and_resize_image(img2, max_dimension)

        gray1 = cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY)
        gray2 = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY)
        kp1, kp2, good_matches = match_features(gray1, gray2)

        # Require a reasonable number of matches before trusting a homography.
        if len(good_matches) < 10:
            print(f"\nWarning: 特征点匹配数量不足 ({len(good_matches)})")
            return img1

        src_pts = np.float32([kp1[m.queryIdx].pt for m in good_matches]).reshape(-1, 1, 2)
        dst_pts = np.float32([kp2[m.trainIdx].pt for m in good_matches]).reshape(-1, 1, 2)
        # H maps img1 coordinates (src) onto img2 coordinates (dst).
        H, _ = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0)
        if H is None:
            print("\nWarning: 无法计算单应性矩阵")
            return img1

        h1, w1 = img1.shape[:2]
        h2, w2 = img2.shape[:2]
        corners = np.float32([[0, 0], [0, h1], [w1, h1], [w1, 0]]).reshape(-1, 1, 2)
        warped_corners = cv2.perspectiveTransform(corners, H).reshape(-1, 2)

        # The canvas must contain img2 (at the origin, w2 x h2) as well as the
        # warped img1. Using only img1's corners made the img2 paste slice
        # invalid whenever img1 did not fully cover img2. floor/ceil avoids
        # losing a pixel to int() truncation toward zero.
        xmin = max(int(np.floor(min(warped_corners[:, 0].min(), 0))), -max_dimension)
        ymin = max(int(np.floor(min(warped_corners[:, 1].min(), 0))), -max_dimension)
        xmax = min(int(np.ceil(max(warped_corners[:, 0].max(), w2))), w2 + max_dimension)
        ymax = min(int(np.ceil(max(warped_corners[:, 1].max(), h2))), h2 + max_dimension)

        if (xmax - xmin) > max_dimension * 2 or (ymax - ymin) > max_dimension * 2:
            print("\nWarning: 输出图像尺寸过大，跳过拼接")
            return img1

        canvas_size = (xmax - xmin, ymax - ymin)  # (width, height), as OpenCV expects
        shift = np.array([[1, 0, -xmin], [0, 1, -ymin], [0, 0, 1]], dtype=np.float64)
        H_canvas = shift @ H

        warped = cv2.warpPerspective(img1, H_canvas, canvas_size)
        # Canvas pixels actually covered by img1; nearest-neighbour keeps it binary.
        coverage = cv2.warpPerspective(
            np.ones((h1, w1), dtype=np.uint8), H_canvas, canvas_size,
            flags=cv2.INTER_NEAREST,
        ) > 0

        # xmin/ymin are <= 0 here, so this slice lies fully inside the canvas.
        img2_region = (slice(-ymin, -ymin + h2), slice(-xmin, -xmin + w2))
        placed = np.zeros_like(warped)
        placed[img2_region] = img2

        inside_img2 = np.zeros(coverage.shape, dtype=np.float32)
        inside_img2[img2_region] = 1.0
        # Feather img2's border into img1, but only where img1 exists. The blur
        # is zeroed outside img2, so no black bleeds outward, and it is forced
        # to 1 where img1 is absent, so img2 is not darkened at its edges.
        feather = cv2.GaussianBlur(inside_img2, (61, 61), 31) * inside_img2
        alpha = np.where(coverage, feather, inside_img2)[..., np.newaxis]

        blended = placed * alpha + warped * (1.0 - alpha)
        return np.clip(np.rint(blended), 0, 255).astype(np.uint8)

    except Exception as e:
        # Best-effort boundary: report and keep the accumulated image.
        print(f"\nError in stitch_images: {str(e)}")
        return img1

def stitch_region(region_name, total_frames):
    """Stitch every saved frame of one region into a single image.

    Reads ``temp_frames/<region>/frame_<i>_<region>.jpg`` for ``i`` in
    ``range(total_frames)`` and folds them in order with ``stitch_images``.
    The result is written to ``output/split/<region>_stitched.jpg``. Frames
    that cannot be read are skipped with a warning.

    Args:
        region_name: one of the region folder names, e.g. 'up' or 'left'.
        total_frames: number of saved frames to consider.

    Raises:
        ValueError: if the first frame of the region cannot be read.
        OSError: if the stitched result cannot be written.
    """
    def frame_path(index):
        return f'temp_frames/{region_name}/frame_{index}_{region_name}.jpg'

    result = cv2.imread(frame_path(0))
    if result is None:
        raise ValueError(f"无法读取{region_name}区域的第一帧")

    print(f"\n开始处理{region_name}区域...")

    for i in range(1, total_frames):
        next_frame = cv2.imread(frame_path(i))
        if next_frame is None:
            print(f"\nWarning: Could not read frame {i} for {region_name}")
            continue

        print(f"\r{region_name}区域进度: {i}/{total_frames-1}", end="", flush=True)

        stitched = stitch_images(result, next_frame)
        # Defensive: only replace the accumulated image when a result came back.
        if stitched is not None:
            result = stitched

    print(f"\n{region_name}区域处理完成")
    out_path = f'output/split/{region_name}_stitched.jpg'
    # cv2.imwrite reports failure via its return value, not an exception.
    if not cv2.imwrite(out_path, result):
        raise OSError(f"无法保存{region_name}区域的拼接结果: {out_path}")

def _compose_layout(up, down, left, right):
    """Stack ``up`` / (``left`` | ``right``) / ``down`` onto one black canvas.

    ``up``, ``down`` and ``left`` are left-aligned; ``right`` is right-aligned
    in the middle band, whose height is the taller of ``left`` and ``right``.
    Inputs must be 3-channel uint8 images to match the canvas.
    """
    middle_height = max(left.shape[0], right.shape[0])
    total_height = up.shape[0] + middle_height + down.shape[0]
    total_width = max(up.shape[1], left.shape[1] + right.shape[1], down.shape[1])
    canvas = np.zeros((total_height, total_width, 3), dtype=np.uint8)

    canvas[0:up.shape[0], :up.shape[1]] = up

    top = up.shape[0]
    canvas[top:top + left.shape[0], :left.shape[1]] = left
    canvas[top:top + right.shape[0], total_width - right.shape[1]:] = right

    bottom = top + middle_height
    canvas[bottom:bottom + down.shape[0], :down.shape[1]] = down
    return canvas


def merge_final_map(helicopter_ratio):
    """Assemble the four stitched regions into ``output/final_map.jpg``.

    Reads ``output/split/<region>_stitched.jpg`` for up/down/left/right. It
    places ``up`` on top, ``left`` and ``right`` side by side below it, and
    ``down`` at the bottom.

    Args:
        helicopter_ratio: not used; kept so existing callers keep working.

    Raises:
        ValueError: if any stitched region image cannot be read.
        OSError: if the final map cannot be written.
    """
    names = ('up', 'down', 'left', 'right')
    images = [cv2.imread(f'output/split/{name}_stitched.jpg') for name in names]
    if any(img is None for img in images):
        raise ValueError("无法读取某个区域的拼接结果")

    up, down, left, right = images
    final_map = _compose_layout(up, down, left, right)

    # cv2.imwrite signals failure through its return value.
    if not cv2.imwrite('output/final_map.jpg', final_map):
        raise OSError("无法保存最终地图: output/final_map.jpg")

def cleanup():
    """Remove the ``temp_frames`` working directory and everything under it.

    Best-effort: a missing directory or any deletion error is ignored.
    """
    shutil.rmtree(Path('temp_frames'), ignore_errors=True)

def main(video_path, frame_interval=30, helicopter_ratio=0.1, cleanup_temp=False):
    """Run the whole pipeline: extract frames, stitch each region, merge the map.

    Errors are caught here: the message and traceback are printed, and the
    function returns normally.

    Args:
        video_path: path of the input video.
        frame_interval: sample one frame every ``frame_interval`` frames.
        helicopter_ratio: fraction of the frame height taken by EACH of the top
            and bottom strips, so it should lie in (0, 0.5); 0.5 or more leaves
            no middle band for the left/right regions.
        cleanup_temp: if True, delete ``temp_frames`` when the run finishes,
            whether it succeeded or not. Defaults to False, which keeps the
            extracted frames for inspection.
    """
    import traceback

    try:
        create_folders()

        print("正在提取和分割帧...")
        total_frames = extract_and_split_frames(video_path, frame_interval, helicopter_ratio)

        print("正在拼接各个区域...")
        for region in ['up', 'down', 'left', 'right']:
            print(f"处理{region}区域...")
            stitch_region(region, total_frames)

        print("正在合成最终地图...")
        merge_final_map(helicopter_ratio)

        print("处理完成！")
    except Exception as e:
        # Top-level boundary: report, but also keep the traceback for debugging.
        print(f"发生错误: {str(e)}")
        traceback.print_exc()
    finally:
        if cleanup_temp:
            cleanup()

if __name__ == "__main__":
    video_path = "input_video.mp4"  # replace with the actual video path
    frame_interval = 90  # sample one frame every 90 frames
    helicopter_ratio = 0.32  # top and bottom strips each take 32% of the frame height
    main(video_path, frame_interval, helicopter_ratio)