# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# 1. Patch Embedding and Position Encoding
class PatchEmbed(nn.Module):
    """Turn an image into a token sequence with a class token and positions.

    The image is cut into non-overlapping ``patch_size`` x ``patch_size``
    patches, each patch is linearly projected to ``embed_dim`` channels, a
    learnable class token is prepended at position 0, and a learnable position
    embedding is added to every token.

    Args:
        img_size: Side length used to size the position embedding.
        patch_size: Side length of one square patch.
        in_channels: Number of channels of the input image.
        embed_dim: Channel width of every produced token.
    """

    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        patches_per_side = img_size // patch_size
        self.num_patches = patches_per_side * patches_per_side

        # A convolution whose stride equals its kernel projects each patch
        # independently; the rearrange flattens the patch grid row by row.
        patch_conv = nn.Conv2d(
            in_channels, embed_dim, kernel_size=patch_size, stride=patch_size
        )
        self.proj = nn.Sequential(patch_conv, Rearrange('b c h w -> b (h w) c'))

        # Learnable position embeddings (one extra slot for the class token)
        # and the class token itself, both truncated-normal initialised.
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        for parameter in (self.pos_embed, self.cls_token):
            nn.init.trunc_normal_(parameter, std=0.02)

    def forward(self, x):
        """Return tokens of shape (batch, num_patches + 1, embed_dim)."""
        batch_size = x.shape[0]
        patch_tokens = self.proj(x)

        # Prepend one copy of the class token per sample, then add positions.
        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b=batch_size)
        tokens = torch.cat([cls_tokens, patch_tokens], dim=1)
        return tokens + self.pos_embed

# 2. Multi-Head Self Attention
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        dim: Channel width of the input tokens; must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the fused query/key/value projection has a bias.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        assert dim % num_heads == 0
        self.num_heads = num_heads
        # Dot products are scaled by 1 / sqrt(head_dim).
        self.scale = (dim // num_heads) ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Attend over the token axis of ``x`` (batch, tokens, dim)."""
        batch, tokens, channels = x.shape
        per_head = channels // self.num_heads

        # One projection yields q, k and v; split them and move heads forward
        # so each of the three is (batch, heads, tokens, per_head).
        packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, per_head)
        queries, keys, values = packed.permute(2, 0, 3, 1, 4).unbind(0)

        weights = torch.matmul(queries, keys.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(F.softmax(weights, dim=-1))

        # Merge the heads back into the channel axis.
        mixed = torch.matmul(weights, values).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(mixed))

# 3. Transformer Encoder Block
class TransformerBlock(nn.Module):
    """Pre-norm transformer encoder block: attention then MLP, both residual.

    Args:
        dim: Channel width of the tokens.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        qkv_bias: Whether the attention q/k/v projection has a bias.
        drop: Dropout probability for the attention output and the MLP.
        attn_drop: Dropout probability for the attention weights.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = MultiHeadAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.norm2 = nn.LayerNorm(dim)

        hidden_width = int(dim * mlp_ratio)
        feed_forward = [
            nn.Linear(dim, hidden_width),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(hidden_width, dim),
            nn.Dropout(drop),
        ]
        self.mlp = nn.Sequential(*feed_forward)

    def forward(self, x):
        """Apply the attention and MLP sub-layers with residual connections."""
        # LayerNorm is applied before each sub-layer (pre-norm).
        attended = self.attn(self.norm1(x))
        x = x + attended
        transformed = self.mlp(self.norm2(x))
        return x + transformed

# 4. Vision Transformer Feature Extractor
class ViTFeatureExtractor(nn.Module):
    """Vision-transformer backbone returning one feature vector per token.

    The image is patch-embedded, passed through ``depth`` transformer blocks
    and finally layer-normalised. The output keeps every token produced by the
    patch embedding.
    """

    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0.):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size, patch_size, in_channels, embed_dim)

        encoder_layers = [
            TransformerBlock(embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate)
            for _ in range(depth)
        ]
        self.blocks = nn.Sequential(*encoder_layers)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Return normalised token features for the image batch ``x``."""
        tokens = self.patch_embed(x)
        encoded = self.blocks(tokens)
        return self.norm(encoded)

# 5. Cross Attention Matching Module
class CrossAttentionMatcher(nn.Module):
    """Multi-head cross-attention from the tokens of ``x1`` to those of ``x2``.

    Queries come from ``x1``; keys and values come from ``x2``. The two
    sequences may have different lengths.

    Args:
        dim: Channel width of both token sequences; must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.
    """

    def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        assert dim % num_heads == 0
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def _split_heads(self, tensor, batch, length):
        """Reshape (batch, length, dim) into (batch, heads, length, head_dim)."""
        return tensor.reshape(batch, length, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

    def forward(self, x1, x2):
        """Attend from ``x1`` to ``x2``.

        Args:
            x1: Query tokens, shape (batch, n_query, dim).
            x2: Key/value tokens, shape (batch, n_key, dim).

        Returns:
            tuple: ``(x, attn)`` where ``x`` has shape (batch, n_query, dim)
            and ``attn`` holds the post-dropout attention weights with shape
            (batch, heads, n_query, n_key).
        """
        B, n_query, C = x1.shape
        # Keys/values take their length from x2; using the query length here
        # would make the reshape fail whenever the two sequences differ.
        n_key = x2.shape[1]

        q = self._split_heads(self.q_proj(x1), B, n_query)
        k = self._split_heads(self.k_proj(x2), B, n_key)
        v = self._split_heads(self.v_proj(x2), B, n_key)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, n_query, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        # Return both the attended values and the attention weights.
        return x, attn

# 6. Complete Image Stitching Transformer
class ImageStitchingTransformer(nn.Module):
    """Siamese ViT encoder followed by cross-attention matching.

    Both images go through the same feature extractor. The tokens of the
    first image then attend to the tokens of the second, and a small MLP head
    turns every attended token into a single matching logit.
    """

    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0.):
        super().__init__()
        # Shared backbone applied to both inputs.
        self.feature_extractor = ViTFeatureExtractor(
            img_size, patch_size, in_channels, embed_dim, depth,
            num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate
        )
        self.matcher = CrossAttentionMatcher(embed_dim, num_heads, attn_drop_rate, drop_rate)

        # Matching head: one logit per token.
        half_width = embed_dim // 2
        self.match_head = nn.Sequential(
            nn.Linear(embed_dim, half_width),
            nn.ReLU(),
            nn.Linear(half_width, 1),
        )

    def forward(self, img1, img2):
        """Encode both images and score every token of ``img1``.

        Returns:
            dict: ``features1`` / ``features2`` (encoder outputs),
            ``matched_features`` and ``attention_weights`` (matcher outputs)
            and ``matching_scores`` (head output with the trailing singleton
            dimension removed).
        """
        feat1, feat2 = (self.feature_extractor(image) for image in (img1, img2))

        matched_features, attention_weights = self.matcher(feat1, feat2)
        matching_scores = self.match_head(matched_features).squeeze(-1)

        return dict(
            features1=feat1,
            features2=feat2,
            matched_features=matched_features,
            attention_weights=attention_weights,
            matching_scores=matching_scores,
        )

# 7. Training utilities
def matching_loss(pred_scores, gt_matches, temperature=0.1):
    """Temperature-scaled cross-entropy over predicted matching scores.

    Args:
        pred_scores: Predicted matching logits, documented by callers as
            shape [B, N]; the last axis is treated as the class axis.
        gt_matches: Targets handed straight to ``F.cross_entropy``. For
            [B, N] logits that function accepts class indices of shape [B]
            or class probabilities of shape [B, N].
            NOTE(review): the previous docstring described this as indices of
            shape [B, N] - confirm the intended target format with callers.
        temperature: Divisor applied to the logits before the softmax.

    Returns:
        Scalar loss tensor.
    """
    scaled_logits = pred_scores / temperature
    return F.cross_entropy(scaled_logits, gt_matches)

# Example usage
def train_step(model, optimizer, img1, img2, gt_matches):
    """Run a single optimisation step and return the loss as a Python float.

    Args:
        model: Callable returning a dict with a ``'matching_scores'`` entry.
        optimizer: Optimizer that owns the model parameters.
        img1, img2: The two image batches fed to the model.
        gt_matches: Targets forwarded to ``matching_loss``.
    """
    # Clear gradients left over from the previous step.
    optimizer.zero_grad()

    predictions = model(img1, img2)
    step_loss = matching_loss(predictions['matching_scores'], gt_matches)

    step_loss.backward()
    optimizer.step()
    return step_loss.item()

# Testing/inference function
@torch.no_grad()
def find_matches(model, img1, img2, threshold=0.5):
    """Select the tokens whose match probability exceeds ``threshold``.

    Args:
        model: Callable returning a dict with a ``'matching_scores'`` logits
            tensor.
        img1, img2: Inputs forwarded to ``model``.
        threshold: Probability above which a token counts as matched.

    Returns:
        tuple: ``(matched_indices, confidence_scores)`` as NumPy arrays.
        ``matched_indices`` has one row per selected entry holding its index
        along every axis of the score tensor; ``confidence_scores`` holds the
        corresponding probabilities in the same order.
    """
    logits = model(img1, img2)['matching_scores']
    probabilities = logits.sigmoid()
    selected = probabilities.gt(threshold)

    # Index rows and probabilities are both produced in row-major order.
    matched_indices = selected.nonzero().cpu().numpy()
    confidence_scores = probabilities[selected].cpu().numpy()
    return matched_indices, confidence_scores

# In [None]:
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import albumentations as A
from albumentations.pytorch import ToTensorV2
from pathlib import Path
import random

class VideoFrameExtractor:
    """Extract frames from a video and pair consecutive saved frames."""

    def __init__(self, overlap_ratio=0.3, frame_interval=5):
        """
        Args:
            overlap_ratio (float): Expected overlap between consecutive
                frames. Stored on the instance; ``extract_frames`` does not
                read it.
            frame_interval (int): A frame is saved whenever its index is a
                multiple of this value.

        Raises:
            ValueError: If ``frame_interval`` is not positive (a zero interval
                would otherwise fail later with a ZeroDivisionError).
        """
        if frame_interval <= 0:
            raise ValueError(f"frame_interval must be positive, got {frame_interval}")
        self.overlap_ratio = overlap_ratio
        self.frame_interval = frame_interval

    def extract_frames(self, video_path, output_dir):
        """
        Extract frames from video and save them as JPEG files.

        Args:
            video_path (str): Path to input video
            output_dir (str): Directory to save extracted frames; created
                (including parents) when missing

        Returns:
            list: ``(previous_frame_path, frame_path)`` tuples for every pair
            of consecutively saved frames

        Raises:
            ValueError: If the video cannot be opened.
            OSError: If a frame cannot be written to disk.
        """
        cap = cv2.VideoCapture(str(video_path))
        if not cap.isOpened():
            cap.release()
            raise ValueError(f"Error opening video file: {video_path}")

        frame_pairs = []
        frame_count = 0
        last_saved_frame = None

        # Release the capture even when directory creation or a write fails.
        try:
            output_dir = Path(output_dir)
            output_dir.mkdir(parents=True, exist_ok=True)

            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                if frame_count % self.frame_interval == 0:
                    frame_path = output_dir / f"frame_{frame_count:06d}.jpg"
                    # imwrite signals failure through its return value; a
                    # silently missing file would only surface at load time.
                    if not cv2.imwrite(str(frame_path), frame):
                        raise OSError(f"Failed to write frame to {frame_path}")

                    if last_saved_frame is not None:
                        frame_pairs.append((last_saved_frame, frame_path))
                    last_saved_frame = frame_path

                frame_count += 1
        finally:
            cap.release()

        return frame_pairs

class StitchingDataset(Dataset):
    """Dataset of frame pairs for training an image stitching model."""

    def __init__(self, frame_pairs, img_size=224, is_train=True):
        """
        Args:
            frame_pairs (list): List of (frame1_path, frame2_path) tuples
            img_size (int): Size of input images
            is_train (bool): Whether to use training augmentations
        """
        self.frame_pairs = frame_pairs
        self.img_size = img_size
        self.is_train = is_train

        # Deterministic preprocessing applied to every image (train and test).
        self.basic_transform = A.Compose([
            A.Resize(img_size, img_size),
            A.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            ),
            ToTensorV2()
        ])

        # Random augmentations for training. Registering the second image as
        # an additional target makes one call apply the same sampled
        # parameters to both images. Reseeding the global ``random`` module
        # between two separate calls is not guaranteed to reproduce the same
        # augmentation and also disturbs other users of that generator.
        if is_train:
            self.train_transform = A.Compose(
                [
                    A.RandomBrightnessContrast(p=0.5),
                    A.HueSaturationValue(p=0.3),
                    A.GaussNoise(p=0.2),
                    A.RandomRotate90(p=0.2),
                    A.HorizontalFlip(p=0.5),
                ],
                additional_targets={'image2': 'image'},
            )

    def __len__(self):
        return len(self.frame_pairs)

    @staticmethod
    def _load_rgb(path):
        """Read the image at ``path`` and return it in RGB channel order.

        Raises:
            FileNotFoundError: If ``path`` is not an existing file.
            OSError: If the file exists but cannot be decoded.
        """
        if not Path(path).is_file():
            raise FileNotFoundError(f"Frame not found: {path}")
        image = cv2.imread(str(path))
        # imread reports decode failures by returning None.
        if image is None:
            raise OSError(f"Could not decode image: {path}")
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def __getitem__(self, idx):
        """Return the preprocessed pair at ``idx`` together with its paths.

        Returns:
            dict: ``image1`` / ``image2`` (outputs of the basic transform) and
            ``path1`` / ``path2`` (source paths as strings).
        """
        img1_path, img2_path = self.frame_pairs[idx]

        img1 = self._load_rgb(img1_path)
        img2 = self._load_rgb(img2_path)

        # Apply identical random augmentations to both images.
        if self.is_train:
            augmented = self.train_transform(image=img1, image2=img2)
            img1, img2 = augmented["image"], augmented["image2"]

        # Apply basic transforms
        img1 = self.basic_transform(image=img1)["image"]
        img2 = self.basic_transform(image=img2)["image"]

        return {
            'image1': img1,
            'image2': img2,
            'path1': str(img1_path),
            'path2': str(img2_path)
        }

def create_dataloaders(video_path, output_dir, batch_size=8, img_size=224,
                      num_workers=4, frame_interval=5):
    """
    Create training and validation dataloaders from video

    Frames are extracted to ``output_dir``, paired, and split in order: the
    last fifth of the pairs (rounded up) is used for validation and the rest
    for training.

    Args:
        video_path (str): Path to input video
        output_dir (str): Directory to save extracted frames
        batch_size (int): Batch size for dataloaders
        img_size (int): Size of input images
        num_workers (int): Number of workers for dataloaders
        frame_interval (int): Number of frames to skip during extraction

    Returns:
        tuple: (train_loader, val_loader)

    Raises:
        ValueError: If the video yields too few frame pairs to leave at least
            one pair for training.
    """
    # Extract frames
    extractor = VideoFrameExtractor(frame_interval=frame_interval)
    frame_pairs = extractor.extract_frames(video_path, output_dir)

    # Split into train/val: ceil(len / 5) pairs from the end are held out.
    num_val = -(-len(frame_pairs) // 5)
    num_train = len(frame_pairs) - num_val
    if num_train < 1:
        # A shuffled DataLoader over an empty dataset fails with an opaque
        # sampler error; report the actual cause instead.
        raise ValueError(
            f"Need at least 2 frame pairs to build train/val splits, "
            f"got {len(frame_pairs)} from {video_path}"
        )
    train_pairs = frame_pairs[:num_train]
    val_pairs = frame_pairs[num_train:]

    # Create datasets
    train_dataset = StitchingDataset(train_pairs, img_size=img_size, is_train=True)
    val_dataset = StitchingDataset(val_pairs, img_size=img_size, is_train=False)

    # Create dataloaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True
    )

    return train_loader, val_loader

# Utility functions for visualization
def visualize_pair(img1, img2, matched_points=None, mean=None, std=None):
    """
    Visualize image pair with optional matched points

    Args:
        img1, img2 (torch.Tensor): Channel-first image tensors of equal size
        matched_points (numpy.ndarray): Matched point pairs; each entry is
            ``(pt1, pt2)`` with ``(x, y)`` coordinates in the respective image
        mean, std (sequence of float, optional): Per-channel statistics. When
            both are given the tensors are de-normalized (``x * std + mean``)
            before conversion, which is needed for inputs that went through a
            mean/std normalization.

    Returns:
        numpy.ndarray: Side-by-side uint8 visualization image in the same
        channel order as the inputs
    """
    def to_uint8(image):
        """Convert a channel-first tensor to a channel-last uint8 array."""
        array = image.detach().permute(1, 2, 0).cpu().numpy()
        if mean is not None and std is not None:
            array = array * np.asarray(std, dtype=array.dtype) + np.asarray(mean, dtype=array.dtype)
        # Clip before casting: out-of-range values would otherwise wrap
        # around in uint8 instead of saturating.
        return np.clip(array * 255, 0, 255).astype(np.uint8)

    img1 = to_uint8(img1)
    img2 = to_uint8(img2)

    # Create side-by-side visualization
    h, w = img1.shape[:2]
    vis_img = np.zeros((h, w*2, 3), dtype=np.uint8)
    vis_img[:, :w] = img1
    vis_img[:, w:] = img2

    # Draw matches if provided
    if matched_points is not None:
        for pt1, pt2 in matched_points:
            start = tuple(map(int, pt1))
            # Shift the x-coordinate into the right-hand half of the canvas.
            end = tuple(map(int, (pt2[0] + w, pt2[1])))
            cv2.line(vis_img, start, end, (0, 255, 0), 1)
            cv2.circle(vis_img, start, 3, (255, 0, 0), -1)
            cv2.circle(vis_img, end, 3, (255, 0, 0), -1)

    return vis_img

# Example usage
if __name__ == "__main__":
    # Create dataloaders from video
    video_path = "input_video.mp4"
    output_dir = "extracted_frames"

    train_loader, val_loader = create_dataloaders(
        video_path=video_path,
        output_dir=output_dir,
        batch_size=8,
        img_size=224
    )

    # Show the first pair of every batch; press 'q' to stop early.
    try:
        for batch in train_loader:
            img1 = batch['image1'][0]  # Get first image from batch
            img2 = batch['image2'][0]

            # Visualize pair
            vis_img = visualize_pair(img1, img2)
            # The dataset delivers RGB images while cv2.imshow expects BGR.
            cv2.imshow('Image Pair', cv2.cvtColor(vis_img, cv2.COLOR_RGB2BGR))
            if cv2.waitKey(0) & 0xFF == ord('q'):
                break
    finally:
        # Close the preview window even if loading or drawing fails.
        cv2.destroyAllWindows()

# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.cuda.amp import GradScaler, autocast
import wandb
import numpy as np
from tqdm import tqdm
from pathlib import Path
import time
import logging
from typing import Dict, Tuple

class Trainer:
    """Training/validation loop with AdamW, cosine LR decay and checkpointing.

    Mixed precision (autocast + gradient scaling) is enabled only when the
    selected device is CUDA. Metrics can optionally be sent to Weights &
    Biases via ``config['use_wandb']``.
    """

    def __init__(
        self,
        model: nn.Module,
        train_loader: torch.utils.data.DataLoader,
        val_loader: torch.utils.data.DataLoader,
        config: Dict,
    ):
        """
        Args:
            model: Model to train
            train_loader: Training data loader
            val_loader: Validation data loader
            config: Training configuration. ``learning_rate`` and ``epochs``
                are required; ``device``, ``weight_decay``, ``min_lr``,
                ``log_dir``, ``checkpoint_dir``, ``use_wandb``,
                ``wandb_project`` and ``wandb_run_name`` are optional.
        """
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config

        # Setup device
        self.device = torch.device(config.get('device', 'cuda' if torch.cuda.is_available() else 'cpu'))
        self.model = self.model.to(self.device)

        # Setup optimizer
        self.optimizer = optim.AdamW(
            self.model.parameters(),
            lr=config['learning_rate'],
            weight_decay=config.get('weight_decay', 0.01)
        )

        # Setup scheduler (stepped once per epoch)
        self.scheduler = CosineAnnealingLR(
            self.optimizer,
            T_max=config['epochs'],
            eta_min=config.get('min_lr', 1e-6)
        )

        # Setup mixed precision training; it only applies on CUDA devices, so
        # it is switched off explicitly elsewhere.
        self.use_amp = self.device.type == 'cuda'
        self.scaler = GradScaler(enabled=self.use_amp)

        # Setup logging
        self.setup_logging()

        # Setup wandb
        if config.get('use_wandb', False):
            wandb.init(
                project=config.get('wandb_project', 'image-stitching'),
                name=config.get('wandb_run_name', time.strftime('%Y%m%d_%H%M%S')),
                config=config
            )

    def setup_logging(self):
        """Configure file + console logging under ``config['log_dir']``."""
        log_dir = Path(self.config.get('log_dir', 'logs'))
        # parents=True so nested log directories can be created in one go.
        log_dir.mkdir(parents=True, exist_ok=True)

        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            handlers=[
                logging.FileHandler(log_dir / f'training_{time.strftime("%Y%m%d_%H%M%S")}.log'),
                logging.StreamHandler()
            ]
        )
        self.logger = logging.getLogger(__name__)

    def save_checkpoint(
        self,
        epoch: int,
        model_state: Dict,
        optimizer_state: Dict,
        scheduler_state: Dict,
        best_metric: float,
        is_best: bool
    ):
        """Write ``latest_checkpoint.pth`` and, when ``is_best``, ``best_checkpoint.pth``.

        Both files are stored in ``config['checkpoint_dir']`` (default
        ``checkpoints``), which is created together with any missing parents.
        """
        checkpoint_dir = Path(self.config.get('checkpoint_dir', 'checkpoints'))
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

        checkpoint = {
            'epoch': epoch,
            'model_state': model_state,
            'optimizer_state': optimizer_state,
            'scheduler_state': scheduler_state,
            'best_metric': best_metric,
            'config': self.config
        }

        # Save latest checkpoint
        torch.save(checkpoint, checkpoint_dir / 'latest_checkpoint.pth')

        # Save best checkpoint
        if is_best:
            torch.save(checkpoint, checkpoint_dir / 'best_checkpoint.pth')

    def train_epoch(self) -> Tuple[float, float]:
        """Train for one epoch and return ``(mean_loss, mean_accuracy)``."""
        self.model.train()
        total_loss = 0.0
        total_accuracy = 0.0
        num_batches = 0

        pbar = tqdm(self.train_loader, desc='Training')
        # Count batches explicitly: the progress bar's own counter is only
        # refreshed at display intervals, so it cannot be used for averaging.
        for num_batches, batch in enumerate(pbar, start=1):
            # Move data to device
            img1 = batch['image1'].to(self.device)
            img2 = batch['image2'].to(self.device)

            self.optimizer.zero_grad()

            # Forward pass with mixed precision
            with autocast(enabled=self.use_amp):
                outputs = self.model(img1, img2)
                loss = self.compute_loss(outputs)

            # Backward pass with gradient scaling
            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()

            # Compute metrics
            accuracy = self.compute_accuracy(outputs)

            # Update progress bar with running averages
            total_loss += loss.item()
            total_accuracy += accuracy
            pbar.set_postfix({
                'loss': total_loss / num_batches,
                'acc': total_accuracy / num_batches
            })

        # An empty loader yields zeros instead of a division by zero.
        denominator = max(num_batches, 1)
        return total_loss / denominator, total_accuracy / denominator

    @torch.no_grad()
    def validate(self) -> Tuple[float, float]:
        """Evaluate on the validation loader; return ``(mean_loss, mean_accuracy)``."""
        self.model.eval()
        total_loss = 0.0
        total_accuracy = 0.0
        num_batches = 0

        pbar = tqdm(self.val_loader, desc='Validation')
        for num_batches, batch in enumerate(pbar, start=1):
            # Move data to device
            img1 = batch['image1'].to(self.device)
            img2 = batch['image2'].to(self.device)

            # Forward pass
            outputs = self.model(img1, img2)
            loss = self.compute_loss(outputs)

            # Compute metrics
            accuracy = self.compute_accuracy(outputs)

            # Update metrics
            total_loss += loss.item()
            total_accuracy += accuracy
            pbar.set_postfix({
                'loss': total_loss / num_batches,
                'acc': total_accuracy / num_batches
            })

        denominator = max(num_batches, 1)
        return total_loss / denominator, total_accuracy / denominator

    def compute_loss(self, outputs: Dict) -> torch.Tensor:
        """
        Compute training loss

        The loss function can be customized as needed. This example combines
        a binary cross-entropy term that pushes every matching logit towards
        "match" with 0.1 times the mean entropy of the attention rows.
        """
        # Work in float32 so the small epsilon below is representable even
        # when the tensors were produced under autocast.
        matching_scores = outputs['matching_scores'].float()
        attention_weights = outputs['attention_weights'].float()

        # Matching loss
        matching_loss = nn.functional.binary_cross_entropy_with_logits(
            matching_scores,
            torch.ones_like(matching_scores)  # Assume all pairs should match
        )

        # Attention regularization: mean entropy over the last axis
        attention_loss = -torch.mean(
            torch.sum(attention_weights * torch.log(attention_weights + 1e-10), dim=-1)
        )

        return matching_loss + 0.1 * attention_loss

    def compute_accuracy(self, outputs: Dict) -> float:
        """
        Compute accuracy metric

        The accuracy computation can be customized as needed. This example
        returns the fraction of matching scores whose sigmoid exceeds 0.5,
        consistent with the all-pairs-match assumption in ``compute_loss``.
        """
        matching_scores = outputs['matching_scores']
        predictions = (torch.sigmoid(matching_scores) > 0.5).float()
        return predictions.mean().item()

    def train(self):
        """Run the full training loop, checkpointing after every epoch.

        The checkpoint with the lowest validation loss so far is additionally
        saved as the best checkpoint.
        """
        best_metric = float('inf')
        total_epochs = self.config['epochs']

        for epoch in range(total_epochs):
            self.logger.info("Epoch %d/%d", epoch + 1, total_epochs)

            # Training
            train_loss, train_acc = self.train_epoch()
            self.logger.info("Training - Loss: %.4f, Accuracy: %.4f", train_loss, train_acc)

            # Validation
            val_loss, val_acc = self.validate()
            self.logger.info("Validation - Loss: %.4f, Accuracy: %.4f", val_loss, val_acc)

            # Read the learning rate used during this epoch before the
            # scheduler advances it to the next epoch's value.
            current_lr = self.optimizer.param_groups[0]['lr']
            self.scheduler.step()

            # Log metrics
            if self.config.get('use_wandb', False):
                wandb.log({
                    'epoch': epoch + 1,
                    'train_loss': train_loss,
                    'train_acc': train_acc,
                    'val_loss': val_loss,
                    'val_acc': val_acc,
                    'learning_rate': current_lr
                })

            # Save checkpoint
            is_best = val_loss < best_metric
            if is_best:
                best_metric = val_loss

            self.save_checkpoint(
                epoch=epoch,
                model_state=self.model.state_dict(),
                optimizer_state=self.optimizer.state_dict(),
                scheduler_state=self.scheduler.state_dict(),
                best_metric=best_metric,
                is_best=is_best
            )

# Training configuration
def get_default_config():
    """Build the default training configuration.

    Returns:
        dict: A fresh dictionary on every call. ``device`` is ``'cuda'`` when
        CUDA is available and ``'cpu'`` otherwise; the remaining entries are
        fixed defaults for the schedule, optimizer, logging and checkpoints.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    config = dict(
        device=device,
        epochs=100,
        learning_rate=1e-4,
        weight_decay=0.01,
        min_lr=1e-6,
        use_wandb=False,
        wandb_project='image-stitching',
        checkpoint_dir='checkpoints',
        log_dir='logs',
    )
    return config

# Example usage
if __name__ == "__main__":
    # Get configuration
    config = get_default_config()

    # Create model and data loaders. The paths below are example values;
    # point them at your own video and frame directory.
    model = ImageStitchingTransformer()
    train_loader, val_loader = create_dataloaders(
        video_path="input_video.mp4",
        output_dir="extracted_frames",
    )

    # Create trainer
    trainer = Trainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        config=config
    )

    # Start training
    trainer.train()

# In [None]:
import torch
import torch.nn.functional as F
import cv2
import numpy as np
from sklearn.metrics import precision_recall_fscore_support
from torchvision.transforms.functional import to_tensor
from typing import Dict, List, Tuple
import kornia

class StitchingEvaluator:
    """Evaluate the performance of an image stitching model."""

    def __init__(self, device='cuda'):
        # Stored for callers; none of the metric methods reads it.
        self.device = device

    def compute_metrics(self, outputs: Dict, batch: Dict) -> Dict[str, float]:
        """Compute all metrics that the given outputs allow.

        Args:
            outputs: Model outputs. ``matching_scores`` is required;
                ``matched_features`` enables the consistency metrics and
                ``warped_image`` enables the image quality metrics.
            batch: Must contain ``gt_matches``; ``image2`` is needed when
                ``warped_image`` is present in ``outputs``.

        Returns:
            dict: Metric name -> value.
        """
        metrics = {}

        # 1. Matching accuracy metrics
        matching_metrics = self._compute_matching_metrics(
            outputs['matching_scores'],
            batch['gt_matches']
        )
        metrics.update(matching_metrics)

        # 2. Geometric consistency metrics (only when features are provided)
        if 'matched_features' in outputs:
            geometric_metrics = self._compute_geometric_consistency(
                outputs['matched_features'],
                outputs.get('attention_weights')
            )
            metrics.update(geometric_metrics)

        # 3. Overlap region quality metrics
        if 'warped_image' in outputs:
            quality_metrics = self._compute_image_quality_metrics(
                outputs['warped_image'],
                batch['image2']
            )
            metrics.update(quality_metrics)

        return metrics

    def _compute_matching_metrics(self, pred_scores: torch.Tensor,
                                gt_matches: torch.Tensor) -> Dict[str, float]:
        """Compute precision, recall and F1 of the thresholded match scores.

        Scores are passed through a sigmoid and thresholded at 0.5;
        ``gt_matches`` is expected to hold 0/1 labels of the same shape.
        """
        pred_matches = (torch.sigmoid(pred_scores) > 0.5).cpu().numpy().astype(int)
        # Cast labels to int so float 0./1. targets compare cleanly with the
        # integer predictions.
        gt_matches = gt_matches.cpu().numpy().astype(int)

        precision, recall, f1, _ = precision_recall_fscore_support(
            gt_matches.ravel(),
            pred_matches.ravel(),
            average='binary',
            zero_division=0
        )

        return {
            'matching_precision': float(precision),
            'matching_recall': float(recall),
            'matching_f1': float(f1)
        }

    def _compute_geometric_consistency(self, matched_features: torch.Tensor,
                                    attention_weights: torch.Tensor) -> Dict[str, float]:
        """Compute the local and global consistency scores.

        ``attention_weights`` is accepted for interface compatibility and is
        not used by the current scores.
        """
        # 1. Local consistency score
        local_consistency = self._compute_local_consistency(matched_features)

        # 2. Global consistency score
        global_consistency = self._compute_global_consistency(matched_features)

        return {
            'local_geometric_consistency': local_consistency,
            'global_geometric_consistency': global_consistency
        }

    @staticmethod
    def _compute_local_consistency(matched_features: torch.Tensor) -> float:
        """Mean cosine similarity between features of adjacent tokens.

        NOTE(review): this helper was referenced but not defined. The score
        is a simple feature-space proxy over (batch, tokens, dim) features,
        not a true geometric check - replace it if a specific definition was
        intended.
        """
        if matched_features.shape[1] < 2:
            # A single token has no neighbour to disagree with.
            return 1.0
        similarity = F.cosine_similarity(
            matched_features[:, :-1], matched_features[:, 1:], dim=-1
        )
        return similarity.mean().item()

    @staticmethod
    def _compute_global_consistency(matched_features: torch.Tensor) -> float:
        """Mean cosine similarity between each token and its sample's mean feature.

        NOTE(review): this helper was referenced but not defined; see
        ``_compute_local_consistency`` for the same caveat.
        """
        centroid = matched_features.mean(dim=1, keepdim=True)
        similarity = F.cosine_similarity(
            matched_features, centroid.expand_as(matched_features), dim=-1
        )
        return similarity.mean().item()

    def _compute_image_quality_metrics(self, warped_image: torch.Tensor,
                                     target_image: torch.Tensor) -> Dict[str, float]:
        """Compute PSNR and mean SSIM between the warped and target image."""
        # PSNR, assuming intensities in [0, 1]
        psnr = kornia.metrics.psnr(warped_image, target_image, max_val=1.0)

        # SSIM map with an 11x11 window, averaged to a single number
        ssim = kornia.metrics.ssim(warped_image, target_image, window_size=11)

        return {
            'psnr': psnr.item(),
            'ssim': ssim.mean().item()
        }

class ImageStitcher:
    """Inference wrapper that stitches two images with a trained matching model.

    Config keys read by this class: ``img_size`` (required), ``device``,
    ``matching_threshold`` (default 0.5) and ``patch_size`` (default 16).
    """

    # Same per-channel statistics the training transforms in this file use.
    _MEAN = (0.485, 0.456, 0.406)
    _STD = (0.229, 0.224, 0.225)

    def __init__(self, model: torch.nn.Module, config: Dict):
        self.model = model
        self.config = config
        self.device = config.get('device', 'cuda' if torch.cuda.is_available() else 'cpu')
        self.model = self.model.to(self.device)
        self.model.eval()

    def preprocess_image(self, image: np.ndarray) -> torch.Tensor:
        """Resize, convert to RGB and normalize an OpenCV image.

        Returns:
            Channel-first float tensor of size ``img_size`` x ``img_size``.
        """
        # Resize to the model input size
        image = cv2.resize(image, (self.config['img_size'], self.config['img_size']))

        # Convert to RGB (OpenCV images are grayscale, BGR or BGRA)
        if len(image.shape) == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        elif image.shape[2] == 4:
            image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
        elif image.shape[2] == 3:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Convert to tensor and normalize per channel. torch.nn.functional's
        # ``normalize`` takes no mean/std, so the arithmetic is done directly.
        tensor = to_tensor(image)
        mean = tensor.new_tensor(self._MEAN).view(-1, 1, 1)
        std = tensor.new_tensor(self._STD).view(-1, 1, 1)
        return (tensor - mean) / std

    @torch.no_grad()
    def stitch_images(self, img1: np.ndarray, img2: np.ndarray) -> Tuple[np.ndarray, Dict]:
        """Stitch two images.

        Returns:
            tuple: ``(stitched_image, info)`` where ``info`` contains
            ``matched_points`` (point pairs in the coordinates of the original
            images), ``homography`` (maps ``img2`` into the frame of ``img1``)
            and ``confidence_scores`` (raw matching logits).

        Raises:
            ValueError: If fewer than four matches are found or no homography
                can be estimated.
        """
        # Preprocess
        tensor1 = self.preprocess_image(img1).unsqueeze(0).to(self.device)
        tensor2 = self.preprocess_image(img2).unsqueeze(0).to(self.device)

        # Model inference
        outputs = self.model(tensor1, tensor2)

        # Matched points are expressed in model-input pixels; rescale each
        # side to its source image because blending uses the original images.
        matched_points = self._get_matched_points(outputs)
        if len(matched_points):
            side = float(self.config['img_size'])
            matched_points[:, 0] *= np.array(
                [img1.shape[1] / side, img1.shape[0] / side], dtype=np.float32
            )
            matched_points[:, 1] *= np.array(
                [img2.shape[1] / side, img2.shape[0] / side], dtype=np.float32
            )

        # Estimate the transform
        H = self._estimate_transform(matched_points)

        # Blend the images
        stitched_image = self._blend_images(img1, img2, H)

        return stitched_image, {
            'matched_points': matched_points,
            'homography': H,
            'confidence_scores': outputs['matching_scores'].cpu().numpy()
        }

    def _index_to_coordinate(self, index) -> Tuple[float, float]:
        """Map a token index to the ``(x, y)`` centre of its patch.

        Token 0 is the class token; patch tokens follow in row-major order
        over a grid of ``img_size // patch_size`` patches per side.
        Coordinates are in pixels of the resized model input.

        Raises:
            ValueError: If ``index`` does not refer to a patch token.
        """
        patch_size = self.config.get('patch_size', 16)
        grid = self.config['img_size'] // patch_size
        patch_index = int(index) - 1
        if not 0 <= patch_index < grid * grid:
            raise ValueError(f"Token index {index} does not refer to an image patch")
        row, col = divmod(patch_index, grid)
        return ((col + 0.5) * patch_size, (row + 0.5) * patch_size)

    def _get_matched_points(self, outputs: Dict) -> np.ndarray:
        """Extract matched point pairs from the model outputs.

        A token of image 1 counts as matched when the sigmoid of its score
        exceeds ``matching_threshold``; its partner in image 2 is the patch
        token receiving the highest head-averaged attention. Class tokens are
        ignored on both sides. Only the first sample of the batch is used.

        Returns:
            float32 array of shape (M, 2, 2): for every match the ``(x, y)``
            point in image 1 followed by the ``(x, y)`` point in image 2.
        """
        scores = torch.sigmoid(outputs['matching_scores'][0])
        # (heads, N1, N2) -> (N1, N2)
        attention = outputs['attention_weights'][0].mean(dim=0)
        threshold = self.config.get('matching_threshold', 0.5)

        if scores.numel() < 2 or attention.shape[-1] < 2:
            return np.zeros((0, 2, 2), dtype=np.float32)

        # Best partner among patch tokens (column 0 is the class token).
        partner = attention[:, 1:].argmax(dim=-1) + 1
        matched_tokens = torch.nonzero(scores[1:] > threshold).flatten() + 1

        matched_points = [
            (self._index_to_coordinate(token), self._index_to_coordinate(int(partner[token])))
            for token in matched_tokens.tolist()
        ]
        return np.asarray(matched_points, dtype=np.float32).reshape(-1, 2, 2)

    def _estimate_transform(self, matched_points: np.ndarray) -> np.ndarray:
        """Estimate the homography that maps image 2 onto image 1.

        Raises:
            ValueError: With fewer than four matches, or when RANSAC finds no
                valid homography.
        """
        if len(matched_points) < 4:
            raise ValueError("Not enough matched points for homography estimation")

        # _blend_images warps img2 with H, so image-2 points are the source.
        source = np.asarray(matched_points[:, 1], dtype=np.float32)
        target = np.asarray(matched_points[:, 0], dtype=np.float32)
        H, _ = cv2.findHomography(source, target, cv2.RANSAC, 5.0)
        if H is None:
            raise ValueError("Homography estimation failed for the matched points")

        return H

    def _blend_images(self, img1: np.ndarray, img2: np.ndarray,
                     H: np.ndarray) -> np.ndarray:
        """Warp ``img2`` with ``H`` onto a canvas shared with ``img1`` and blend.

        Overlapping pixels are averaged; elsewhere the image that covers the
        pixel is used.
        """
        # Size of the canvas that contains img1 and the warped img2
        h1, w1 = img1.shape[:2]
        h2, w2 = img2.shape[:2]
        corners1 = np.float32([[0, 0], [0, h1], [w1, h1], [w1, 0]]).reshape(-1, 1, 2)
        corners2 = np.float32([[0, 0], [0, h2], [w2, h2], [w2, 0]]).reshape(-1, 1, 2)
        corners2_warped = cv2.perspectiveTransform(corners2, H)
        corners = np.concatenate([corners1, corners2_warped], axis=0)

        xmin, ymin = np.int32(corners.min(axis=0).ravel() - 0.5)
        xmax, ymax = np.int32(corners.max(axis=0).ravel() + 0.5)
        canvas_size = (int(xmax - xmin), int(ymax - ymin))

        # Translation that moves everything into positive canvas coordinates;
        # warpPerspective requires a floating-point matrix.
        Ht = np.array(
            [[1, 0, -xmin], [0, 1, -ymin], [0, 0, 1]], dtype=np.float64
        )
        H2 = Ht @ H
        img1_warped = cv2.warpPerspective(img1, Ht, canvas_size)
        img2_warped = cv2.warpPerspective(img2, H2, canvas_size)

        # Warp coverage masks instead of testing pixels for zero, so dark
        # image content is not mistaken for empty canvas.
        mask1 = cv2.warpPerspective(np.full((h1, w1), 255, np.uint8), Ht, canvas_size) > 0
        mask2 = cv2.warpPerspective(np.full((h2, w2), 255, np.uint8), H2, canvas_size) > 0
        overlap = mask1 & mask2
        only_second = ~mask1 & mask2

        # Equal-weight blend in the overlap region
        result = img1_warped.copy()
        result[overlap] = (
            img1_warped[overlap] * 0.5 + img2_warped[overlap] * 0.5
        ).astype(img1_warped.dtype)
        result[only_second] = img2_warped[only_second]

        return result

# Example usage
if __name__ == "__main__":
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Load model. Trainer writes its checkpoints into 'checkpoints/' by
    # default; map_location lets a GPU-trained checkpoint load on CPU.
    model = ImageStitchingTransformer()  # Your model
    checkpoint = torch.load('checkpoints/best_checkpoint.pth', map_location=device)
    model.load_state_dict(checkpoint['model_state'])

    # Create stitcher
    config = {
        'device': device,
        'img_size': 224,
        'matching_threshold': 0.5
    }
    stitcher = ImageStitcher(model, config)

    # Read images; imread returns None when a file is missing or unreadable.
    img1 = cv2.imread('image1.jpg')
    img2 = cv2.imread('image2.jpg')
    if img1 is None or img2 is None:
        raise FileNotFoundError("Could not read 'image1.jpg' and/or 'image2.jpg'")

    # Stitch images
    result, info = stitcher.stitch_images(img1, img2)

    # Save result
    cv2.imwrite('stitched_image.jpg', result)

    # Create evaluator and compute matching metrics. This script has no
    # ground-truth labels, so an all-ones target is used as a placeholder;
    # replace it with real labels for meaningful numbers.
    evaluator = StitchingEvaluator(device=device)
    pred_scores = torch.as_tensor(info['confidence_scores'])
    gt_matches = torch.ones_like(pred_scores)
    metrics = evaluator._compute_matching_metrics(pred_scores, gt_matches)
    print("Evaluation metrics:", metrics)