# In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.cuda.amp import GradScaler, autocast

import cv2
import numpy as np
from pathlib import Path
import logging
import time
from typing import Dict, List, Tuple
import wandb
from tqdm import tqdm
import random

from einops import rearrange, repeat
from einops.layers.torch import Rearrange
import kornia
from torchvision.transforms.functional import to_tensor
from sklearn.metrics import precision_recall_fscore_support

### 第一部分：模型定义

# In [None]:
class PatchEmbed(nn.Module):
    """
    Patch embedding layer.

    Cuts the input image into non-overlapping ``patch_size`` x ``patch_size``
    patches with a strided convolution, flattens them into a token sequence,
    prepends a learnable class token and adds a learnable position embedding.
    """
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        patches_per_side = img_size // patch_size
        self.num_patches = patches_per_side ** 2

        # kernel == stride, so every patch is projected independently;
        # the rearrange turns the feature map into a (batch, tokens, dim) sequence.
        patch_conv = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        to_sequence = Rearrange('b c h w -> b (h w) c')
        self.proj = nn.Sequential(patch_conv, to_sequence)

        # Learnable position embedding (one extra slot for the class token)
        # and the class token itself, both truncated-normal initialised.
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        for param in (self.pos_embed, self.cls_token):
            nn.init.trunc_normal_(param, std=0.02)

    def forward(self, x):
        """
        Embed a batch of images laid out as (batch, channels, height, width).

        Returns the token sequence with the class token at index 0. The
        input's spatial size must produce exactly ``num_patches`` patches so
        that the position embedding lines up with the tokens.
        """
        batch_size = x.shape[0]
        patch_tokens = self.proj(x)

        # One copy of the class token per sample, placed in front of the patches.
        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b=batch_size)
        sequence = torch.cat([cls_tokens, patch_tokens], dim=1)
        return sequence + self.pos_embed

class MultiHeadAttention(nn.Module):
    """
    Multi-head self-attention over a token sequence.

    Queries, keys and values are produced by a single fused linear layer and
    split into ``num_heads`` heads of size ``dim // num_heads``.

    Args:
        dim: Embedding size of the input tokens; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the fused q/k/v projection has a bias term.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide ``dim``.
    """
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        # Fail at construction time: otherwise the mismatch only shows up as
        # an opaque reshape error on the first forward pass.
        if num_heads <= 0 or dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by a positive num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """
        Apply self-attention to ``x`` of shape (batch, tokens, dim).

        Returns a tensor of the same shape.
        """
        B, N, C = x.shape
        # (B, N, 3*C) -> (3, B, heads, N, head_dim) so q, k, v can be unbound.
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        # Scaled dot-product attention weights, normalised over the key axis.
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # Weighted sum of values, then merge the heads back into the channel axis.
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class TransformerBlock(nn.Module):
    """
    Transformer encoder block.

    Pre-norm layout: layer norm -> multi-head self-attention -> residual add,
    followed by layer norm -> two-layer GELU feed-forward network -> residual add.
    """
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.):
        super().__init__()
        # Width of the feed-forward hidden layer.
        hidden_width = int(dim * mlp_ratio)

        self.norm1 = nn.LayerNorm(dim)
        self.attn = MultiHeadAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.norm2 = nn.LayerNorm(dim)

        feed_forward_layers = [
            nn.Linear(dim, hidden_width),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(hidden_width, dim),
            nn.Dropout(drop),
        ]
        self.mlp = nn.Sequential(*feed_forward_layers)

    def forward(self, x):
        """Run the attention and feed-forward sub-layers, each with a residual connection."""
        attended = self.attn(self.norm1(x))
        x = x + attended
        transformed = self.mlp(self.norm2(x))
        return x + transformed

class ViTFeatureExtractor(nn.Module):
    """
    Vision Transformer feature extractor.

    Embeds the image into patch tokens, runs them through ``depth`` stacked
    transformer blocks and applies a final layer norm.
    """
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768, depth=12, 
                 num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0.):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size, patch_size, in_channels, embed_dim)

        # Build the encoder stack one block at a time.
        encoder_layers = []
        for _ in range(depth):
            encoder_layers.append(
                TransformerBlock(
                    embed_dim,
                    num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                )
            )
        self.blocks = nn.Sequential(*encoder_layers)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Return the normalised token features produced for the image batch ``x``."""
        tokens = self.patch_embed(x)
        encoded = self.blocks(tokens)
        return self.norm(encoded)

class CrossAttentionMatcher(nn.Module):
    """
    Cross-attention matching module.

    Tokens of the first input act as queries; tokens of the second input
    provide keys and values. The two inputs may have different sequence
    lengths but must share batch size and embedding size.

    Args:
        dim: Embedding size of both inputs; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide ``dim``.
    """
    def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0.):
        super().__init__()
        # Catch an impossible head split here instead of as a reshape error later.
        if num_heads <= 0 or dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by a positive num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x1, x2):
        """
        Attend from ``x1`` (queries) to ``x2`` (keys/values).

        Args:
            x1: Tensor of shape (batch, n_query_tokens, dim).
            x2: Tensor of shape (batch, n_key_tokens, dim).

        Returns:
            A tuple ``(x, attn)``: ``x`` has the shape of ``x1``; ``attn`` has
            shape (batch, num_heads, n_query_tokens, n_key_tokens) and holds
            the attention weights after dropout has been applied.

        Raises:
            ValueError: If ``x2`` is not 3-D or its batch/embedding size
                differs from ``x1``.
        """
        B, N, C = x1.shape
        if x2.dim() != 3 or x2.shape[0] != B or x2.shape[2] != C:
            raise ValueError(
                f"x2 must have shape ({B}, M, {C}) to match x1, got {tuple(x2.shape)}"
            )
        # Keys/values use x2's own token count; reusing N would either fail or
        # silently scramble tokens whenever the element counts happen to agree.
        M = x2.shape[1]

        # Project and split into heads: (B, heads, tokens, head_dim).
        q = self.q_proj(x1).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        k = self.k_proj(x2).reshape(B, M, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        v = self.v_proj(x2).reshape(B, M, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        # Scaled dot-product scores, normalised over the key axis.
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # Weighted sum of values, then merge heads back into the channel axis.
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        
        return x, attn

class ImageStitchingTransformer(nn.Module):
    """
    Complete image-stitching transformer.

    A single ViT feature extractor is applied to both images, a
    cross-attention matcher relates the first image's tokens to the second
    image's tokens, and a small MLP head turns each matched token into a
    scalar matching score.
    """
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768, depth=12, 
                 num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0.):
        super().__init__()
        # Backbone used for both input images.
        self.feature_extractor = ViTFeatureExtractor(
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
        )
        # Cross-attention between the two feature sequences.
        self.matcher = CrossAttentionMatcher(
            embed_dim,
            num_heads=num_heads,
            attn_drop=attn_drop_rate,
            proj_drop=drop_rate,
        )

        # Scoring head: embed_dim -> embed_dim // 2 -> 1.
        bottleneck = embed_dim // 2
        self.match_head = nn.Sequential(
            nn.Linear(embed_dim, bottleneck),
            nn.ReLU(),
            nn.Linear(bottleneck, 1),
        )

    def forward(self, img1, img2):
        """
        Compute features, cross-attention and matching scores for an image pair.

        Returns a dict with the keys ``features1``, ``features2``,
        ``matched_features``, ``attention_weights`` and ``matching_scores``
        (one scalar per token of the matched features).
        """
        # Extract features for the first image, then the second.
        first_features = self.feature_extractor(img1)
        second_features = self.feature_extractor(img2)

        # Queries come from the first image, keys/values from the second.
        matched, weights = self.matcher(first_features, second_features)

        # Drop the trailing singleton dimension left by the scoring head.
        scores = self.match_head(matched).squeeze(-1)

        outputs = {}
        outputs['features1'] = first_features
        outputs['features2'] = second_features
        outputs['matched_features'] = matched
        outputs['attention_weights'] = weights
        outputs['matching_scores'] = scores
        return outputs

### 第二部分：数据处理

# In [None]:
class VideoFrameExtractor:
    """
    Video frame extractor.

    Samples every ``frame_interval``-th frame of a video, writes it to disk
    as a JPEG and pairs each saved frame with the previously saved one.
    """
    def __init__(self, overlap_ratio=0.3, frame_interval=5):
        # A zero interval would make the modulo test in extract_frames fail
        # with ZeroDivisionError on the first frame.
        if not frame_interval:
            raise ValueError(f"frame_interval must be non-zero, got {frame_interval!r}")
        # Stored for callers; not read by any method of this class.
        self.overlap_ratio = overlap_ratio
        self.frame_interval = frame_interval

    def extract_frames(self, video_path, output_dir):
        """
        Extract sampled frames from ``video_path`` and save them in ``output_dir``.

        Args:
            video_path: Path of the input video.
            output_dir: Directory that receives ``frame_XXXXXX.jpg`` files;
                created if it does not exist.

        Returns:
            List of ``(previous_frame_path, frame_path)`` tuples for
            consecutive saved frames.

        Raises:
            ValueError: If the video cannot be opened.
        """
        # Hand OpenCV a plain string when a pathlib.Path is supplied.
        source = str(video_path) if isinstance(video_path, Path) else video_path
        cap = cv2.VideoCapture(source)
        try:
            if not cap.isOpened():
                raise ValueError(f"无法打开视频文件: {video_path}")

            output_dir = Path(output_dir)
            output_dir.mkdir(parents=True, exist_ok=True)

            frame_pairs = []
            frame_count = 0
            last_saved_frame = None

            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                if frame_count % self.frame_interval == 0:
                    frame_path = output_dir / f"frame_{frame_count:06d}.jpg"
                    if cv2.imwrite(str(frame_path), frame):
                        if last_saved_frame is not None:
                            frame_pairs.append((last_saved_frame, frame_path))
                        last_saved_frame = frame_path
                    else:
                        # Do not record a pair that points at a file that was never written.
                        logging.getLogger(__name__).warning(
                            "Failed to write frame %d to %s; skipping it",
                            frame_count, frame_path,
                        )

                frame_count += 1

            return frame_pairs
        finally:
            # Release the capture handle on every exit path, including errors.
            cap.release()


class StitchingDataset(Dataset):
    """
    Image-stitching dataset.

    Each item is a pair of images loaded from disk, optionally augmented
    (training only), then resized, normalised and converted to tensors.

    NOTE(review): this class uses ``A`` and ``ToTensorV2``, which are not
    imported in the visible part of the file - presumably
    ``import albumentations as A`` and
    ``from albumentations.pytorch import ToTensorV2``; confirm.
    """
    def __init__(self, frame_pairs, img_size=224, is_train=True):
        self.frame_pairs = frame_pairs
        self.img_size = img_size
        self.is_train = is_train

        # Base transform applied to every image: resize, normalise, to tensor.
        self.basic_transform = A.Compose([
            A.Resize(img_size, img_size),
            A.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            ),
            ToTensorV2()
        ])

        # Training-only augmentation. ``additional_targets`` makes a single call
        # apply the same sampled parameters to both images of a pair, without
        # relying on (or resetting) the global ``random`` state.
        # Assumes both images of a pair have the same height and width - TODO confirm.
        if is_train:
            self.train_transform = A.Compose(
                [
                    A.RandomBrightnessContrast(p=0.5),
                    A.HueSaturationValue(p=0.3),
                    A.GaussNoise(p=0.2),
                    A.RandomRotate90(p=0.2),
                    A.HorizontalFlip(p=0.5),
                ],
                additional_targets={'image2': 'image'},
            )

    def __len__(self):
        return len(self.frame_pairs)

    @staticmethod
    def _load_rgb(path):
        """Read the image at ``path`` and return it in RGB channel order."""
        image = cv2.imread(str(path))
        # cv2.imread reports failure by returning None rather than raising.
        if image is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def __getitem__(self, idx):
        """
        Load and transform the ``idx``-th image pair.

        Returns a dict with the transformed images under ``image1`` /
        ``image2`` and their source paths (as strings) under ``path1`` / ``path2``.

        Raises:
            FileNotFoundError: If either image cannot be read.
        """
        img1_path, img2_path = self.frame_pairs[idx]

        img1 = self._load_rgb(img1_path)
        img2 = self._load_rgb(img2_path)

        # Apply one shared random augmentation to both images during training.
        if self.is_train:
            augmented = self.train_transform(image=img1, image2=img2)
            img1 = augmented["image"]
            img2 = augmented["image2"]

        # Apply the base transform.
        img1 = self.basic_transform(image=img1)["image"]
        img2 = self.basic_transform(image=img2)["image"]

        return {
            'image1': img1,
            'image2': img2,
            'path1': str(img1_path),
            'path2': str(img2_path)
        }

def create_dataloaders(video_path, output_dir, batch_size=8, img_size=224, 
                      num_workers=4, frame_interval=5):
    """
    Create the training and validation data loaders for one video.

    Frames are extracted from the video, paired, and split in order: the
    last fifth of the pairs (rounded up) is held out for validation and the
    rest is used for training.

    Args:
        video_path: Path of the input video.
        output_dir: Directory where extracted frames are saved.
        batch_size: Batch size for both loaders.
        img_size: Side length images are resized to.
        num_workers: Number of data-loading worker processes.
        frame_interval: Frame sampling interval.

    Returns:
        Tuple ``(train_loader, val_loader)``.

    Raises:
        ValueError: If the video yields too few frame pairs to leave at
            least one pair for training.
    """
    # Extract frames from the video.
    extractor = VideoFrameExtractor(frame_interval=frame_interval)
    frame_pairs = extractor.extract_frames(video_path, output_dir)

    # Split explicitly instead of slicing with ``-len(x)//5``: the sizes are
    # the same (validation = ceil(n / 5)), but the intent no longer hinges on
    # operator precedence and the degenerate cases are caught up front.
    num_pairs = len(frame_pairs)
    num_val = (num_pairs + 4) // 5
    num_train = num_pairs - num_val
    if num_train < 1:
        raise ValueError(
            f"Not enough frame pairs to build a training set: got {num_pairs} "
            f"pair(s) from {video_path}"
        )
    train_pairs = frame_pairs[:num_train]
    val_pairs = frame_pairs[num_train:]

    # Build the datasets; augmentation is enabled for training only.
    train_dataset = StitchingDataset(train_pairs, img_size=img_size, is_train=True)
    val_dataset = StitchingDataset(val_pairs, img_size=img_size, is_train=False)

    # Build the loaders; only the training loader shuffles.
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True
    )

    return train_loader, val_loader

class Trainer:
    """
    模型训练器
    管理整个训练过程，包括日志记录、检查点保存等
    """
    def __init__(
        self,
        model: nn.Module,
        train_loader: torch.utils.data.DataLoader,
        val_loader: torch.utils.data.DataLoader,
        config: Dict,
    ):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config
        
        # 设置设备
        self.device = torch.device(config.get('device', 'cuda' if torch.cuda.is_available() else 'cpu'))
        self.model = self.model.to(self.device)
        
        # 设置优化器
        self.optimizer = optim.AdamW(
            self.model.parameters(),
            lr=config['learning_rate'],
            weight_decay=config.get('weight_decay', 0.01)
        )
        
        # 设置学习率调度器
        self.scheduler = CosineAnnealingLR(
            self.optimizer,
            T_max=config['epochs'],
            eta_min=config.get('min_lr', 1e-6)
        )
        
        # 设置混合精度训练
        self.scaler = GradScaler()
        
        # 设置日志
        self.setup_logging()
        
        # 设置wandb
        if config.get('use_wandb', False):
            wandb.init(
                project=config.get('wandb_project', 'image-stitching'),
                name=config.get('wandb_run_name', time.strftime('%Y%m%d_%H%M%S')),
                config=config
            )