In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import cv2
from tqdm import tqdm
import time
import numpy as np
from pathlib import Path
import shutil
import gc

In [12]:
def create_folders():
    """创建必要的文件夹结构"""
    folders = [
        'temp_frames',
        'output'
    ]
    
    for folder in folders:
        Path(folder).mkdir(parents=True, exist_ok=True)

def extract_frames(video_path, start_frame, frame_interval):
    """
    从指定帧开始提取帧
    Args:
        video_path: 视频路径
        start_frame: 开始帧的索引
        frame_interval: 帧间隔
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError(f"无法打开视频文件: {video_path}")
    
    try:
        frame_count = 0
        saved_count = 0
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        
        print(f"总帧数: {total_frames}")
        print(f"从第 {start_frame} 帧开始，间隔 {frame_interval} 帧")
        print(f"预计处理帧数: {(total_frames - start_frame) // frame_interval}")
        
        # 跳到指定的开始帧
        cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
        
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            
            if frame_count % frame_interval == 0:
                cv2.imwrite(f'temp_frames/frame_{saved_count}.jpg', frame)
                print(f"\r提取帧进度: {saved_count + 1}/{(total_frames - start_frame) // frame_interval}", end="", flush=True)
                saved_count += 1
                
            frame_count += 1
            del frame
            gc.collect()
        
        print("\n帧提取完成")
        return saved_count
        
    finally:
        cap.release()
        gc.collect()
        
def cleanup():
    """清理临时文件"""
    try:
        shutil.rmtree('temp_frames', ignore_errors=True)
    except Exception as e:
        print(f"清理临时文件时出错: {str(e)}")       

# 创建必要的文件夹
create_folders()
Path('checkpoints').mkdir(exist_ok=True)

video_path = "video/7.mp4"
start_frame = 3
frame_interval = 60
# 提取帧
print("正在提取帧...")
total_frames = extract_frames(video_path, start_frame, frame_interval)
         

正在提取帧...
总帧数: 1476
从第 3 帧开始，间隔 60 帧
预计处理帧数: 24
提取帧进度: 25/24
帧提取完成


In [8]:
class FeatureExtractor(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=4, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=4, padding=1)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=4, padding=1)
        
        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(128)
        self.bn3 = nn.BatchNorm2d(256)
        
        self.pool = nn.MaxPool2d(2, 2)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # 保持特征图的空间信息
        f1 = self.relu(self.bn1(self.conv1(x)))
        p1 = self.pool(f1)
        
        f2 = self.relu(self.bn2(self.conv2(p1)))
        p2 = self.pool(f2)
        
        f3 = self.relu(self.bn3(self.conv3(p2)))
        
        return [f1, f2, f3]

class CorrelationLayer(nn.Module):
    def __init__(self):
        super().__init__()
        
    def forward(self, feat1, feat2):
        b, c, h, w = feat1.size()
        # 计算特征相关性
        feat1_flat = feat1.view(b, c, -1)
        feat2_flat = feat2.view(b, c, -1)
        
        correlation = torch.bmm(feat1_flat.permute(0,2,1), feat2_flat)
        correlation = correlation.view(b, h*w, h, w)
        return F.softmax(correlation, dim=1)

class StitchingNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.feature_extractor = FeatureExtractor()
        self.correlation = CorrelationLayer()
        
        # 自适应特征融合
        self.fusion = nn.Sequential(
            nn.Conv2d(512, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )
        
        # 融合权重预测
        self.blend_weights = nn.Conv2d(64, 1, 1)
        
    def forward(self, img1, img2):
        # 提取多尺度特征
        feats1 = self.feature_extractor(img1)
        feats2 = self.feature_extractor(img2)
        
        # 计算特征相关性
        correlations = []
        for f1, f2 in zip(feats1, feats2):
            corr = self.correlation(f1, f2)
            correlations.append(corr)
        
        # 特征融合
        fusion_feats = torch.cat([feats1[-1], feats2[-1]], dim=1)
        fused = self.fusion(fusion_feats)
        
        # 预测融合权重
        weights = torch.sigmoid(self.blend_weights(fused))
        
        # 上采样权重到原始图像大小
        weights = F.interpolate(weights, size=img1.shape[2:], mode='bilinear', align_corners=False)
        
        # 生成最终结果
        result = weights * img1 + (1 - weights) * img2
        
        return result, weights, correlations

class StitchingLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1_loss = nn.L1Loss()
        self.mse_loss = nn.MSELoss()
        
    def forward(self, result, img1, img2, weights, correlations):
        # 重建损失
        reconstruct_loss = self.l1_loss(result, img2)
        
        # 相关性一致性损失
        correlation_loss = sum(self.mse_loss(corr, torch.ones_like(corr)/corr.shape[1]) 
                             for corr in correlations)
        
        # 平滑度损失
        smoothness_loss = self.l1_loss(weights[:,:,1:,:], weights[:,:,:-1,:]) + \
                         self.l1_loss(weights[:,:,:,1:], weights[:,:,:,:-1])
        
        total_loss = reconstruct_loss + 0.1 * correlation_loss + 0.01 * smoothness_loss
        return total_loss

def train_step(model, optimizer, img1, img2):
    """单步训练函数"""
    optimizer.zero_grad()
    
    # 前向传播
    result, weights, correlations = model(img1, img2)
    
    # 计算损失
    loss = StitchingLoss()(result, img1, img2, weights, correlations)
    
    # 反向传播
    loss.backward()
    optimizer.step()
    
    return loss.item()

In [15]:
class FramePairsDataset(Dataset):
    def __init__(self, frames_dir, size=(128, 128)):
        self.frames_dir = Path(frames_dir)
        self.frame_pairs = self._get_frame_pairs()
        self.size = size
        
        # 使用PIL Image进行转换
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(size),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x[:3] if x.size(0) > 3 else x),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])
    
    def __getitem__(self, idx):
        try:
            frame1_path, frame2_path = self.frame_pairs[idx]
            
            # 读取并预处理图像
            frame1 = self._safe_read_image(frame1_path)
            frame2 = self._safe_read_image(frame2_path)
            
            # 转换为张量
            frame1_tensor = self.transform(frame1)
            frame2_tensor = self.transform(frame2)
            
            # 打印张量形状用于调试
            print(f"Tensor shapes - frame1: {frame1_tensor.shape}, frame2: {frame2_tensor.shape}")
            
            return frame1_tensor, frame2_tensor
            
        except Exception as e:
            print(f"Error processing item {idx}: {str(e)}")
            # 返回全零张量
            return torch.zeros(3, *self.size), torch.zeros(3, *self.size)
        
    def _get_frame_pairs(self):
        """获取相邻帧对，确保文件存在且可读"""
        frames = []
        for frame in sorted(list(self.frames_dir.glob('frame_*.jpg'))):
            if frame.exists() and frame.stat().st_size > 0:
                frames.append(frame)
                
        if not frames:
            raise RuntimeError(f"No valid frames found in {self.frames_dir}")
            
        return [(frames[i], frames[i+1]) for i in range(len(frames)-1)]
    
    def __len__(self):
        return len(self.frame_pairs)
    
    def _safe_read_image(self, path):
        """安全地读取和处理图像，确保输出为3通道RGB图像"""
        try:
            # 使用IMREAD_COLOR确保读取为3通道
            img = cv2.imread(str(path), cv2.IMREAD_COLOR)
            if img is None:
                raise ValueError(f"Failed to read image: {path}")
            
            # 调整图像大小以确保一致性
            img = cv2.resize(img, self.size, interpolation=cv2.INTER_AREA)
            
            # 确保是3通道RGB图像
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            
            # 确保图像为uint8类型
            img = img.astype(np.uint8)
            
            # 打印图像形状和类型信息（用于调试）
            # print(f"Image shape: {img.shape}, dtype: {img.dtype}")
            
            return img
            
        except Exception as e:
            print(f"Error reading image {path}: {str(e)}")
            # 返回一个空白图像而不是失败
            return np.zeros((*self.size, 3), dtype=np.uint8)
    
    def __getitem__(self, idx):
        try:
            frame1_path, frame2_path = self.frame_pairs[idx]
            
            # 安全读取图像
            frame1 = self._safe_read_image(frame1_path)
            frame2 = self._safe_read_image(frame2_path)
            
            # 应用变换
            if self.transform:
                try:
                    frame1 = self.transform(frame1)
                    frame2 = self.transform(frame2)
                except Exception as e:
                    print(f"Transform error for index {idx}: {str(e)}")
                    # 返回零张量而不是失败
                    frame1 = torch.zeros((3, *self.size))
                    frame2 = torch.zeros((3, *self.size))
            
            return frame1, frame2
            
        except Exception as e:
            print(f"Error processing item {idx}: {str(e)}")
            # 返回零张量而不是失败
            return torch.zeros((3, *self.size)), torch.zeros((3, *self.size))

def train_model(model, train_loader, num_epochs, device):
    """训练模型"""
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = StitchingLoss()

    # 用于提前停止的变量
    best_loss = float('inf')
    patience = 5
    patience_counter = 0
    avg_loss = float('inf')  # 初始化avg_loss

    # 创建进度条
    epoch_pbar = tqdm(total=num_epochs, desc="Training Progress")

    try:
        for epoch in range(num_epochs):
            model.train()
            total_loss = 0
            batch_count = 0

            # 添加梯度累积
            accumulation_steps = 4  # 累积4次更新一次
            optimizer.zero_grad()

            # 创建每个epoch的batch进度条
            batch_pbar = tqdm(total=len(train_loader),
                            desc=f"Epoch {epoch+1}/{num_epochs}",
                            leave=False)

            start_time = time.time()

            for batch_idx, (img1, img2) in enumerate(train_loader):
                try:
                    img1, img2 = img1.to(device), img2.to(device)

                    # 清除GPU缓存
                    if batch_idx % 5 == 0:
                        torch.cuda.empty_cache()

                    # 前向传播和损失计算
                    result, weights, correlations = model(img1, img2)
                    loss = criterion(result, img1, img2, weights, correlations)
                    loss = loss / accumulation_steps  # 缩放loss
                    loss.backward()

                    # 累积梯度
                    if (batch_idx + 1) % accumulation_steps == 0:
                        optimizer.step()
                        optimizer.zero_grad()

                    total_loss += loss.item()
                    batch_count += 1

                    # 更新batch进度条
                    batch_pbar.update(1)
                    batch_pbar.set_postfix({
                        'loss': f'{loss.item():.4f}',
                        'avg_loss': f'{total_loss/batch_count:.4f}'
                    })

                except Exception as e:
                    print(f"\\nError in batch {batch_idx}: {str(e)}")
                    continue

            batch_pbar.close()

            # 计算平均损失
            avg_loss = total_loss / batch_count if batch_count > 0 else float('inf')
            epoch_time = time.time() - start_time

            # 更新epoch进度条
            epoch_pbar.update(1)
            epoch_pbar.set_postfix({
                'avg_loss': f'{avg_loss:.4f}',
                'time': f'{epoch_time:.1f}s'
            })

            # 提前停止检查
            if avg_loss < best_loss:
                best_loss = avg_loss
                patience_counter = 0
                # 保存最佳模型
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': avg_loss,
                }, 'checkpoints/best_model.pt')
            else:
                patience_counter += 1

            # 定期保存检查点
            if (epoch + 1) % 5 == 0:
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': avg_loss,
                }, f'checkpoints/model_epoch_{epoch+1}.pt')

            # 如果连续多个epoch没有改善，提前停止
            if patience_counter >= patience:
                print(f"\\nEarly stopping after {epoch+1} epochs")
                break

    except KeyboardInterrupt:
        print("\\nTraining interrupted by user")
    except Exception as e:
        print(f"\\nTraining error: {str(e)}")
    finally:
        epoch_pbar.close()
        # 确保最后一个模型状态被保存
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_loss,
        }, 'checkpoints/final_model.pt')


def stitch_with_model(model, img1, img2, device):
    """使用训练好的模型进行图像拼接"""
    model.eval()
    with torch.no_grad():
        img1 = img1.unsqueeze(0).to(device)
        img2 = img2.unsqueeze(0).to(device)
        
        result, _, _ = model(img1, img2)
        
        # 转换回numpy格式
        result = result.squeeze(0).cpu().numpy()
        result = (result * 255).astype(np.uint8)
        result = cv2.cvtColor(result, cv2.COLOR_RGB2BGR)
        
        return result

In [16]:
def main():
    """主函数"""
    try:
        # 检查CUDA是否可用
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {device}")
        
        
        # 创建数据集和数据加载器，使用更保守的设置
        try:
            print("初始化数据集...")
            dataset = FramePairsDataset('temp_frames', size=(128, 128))
            
            # 使用单进程模式加载数据
            train_loader = DataLoader(
                dataset, 
                batch_size=2,
                shuffle=True,
                num_workers=0,  # 使用单进程
                pin_memory=False
                # True if torch.cuda.is_available() else False
            )
            print(f"数据集初始化完成，共有 {len(dataset)} 对图像")
            
        except Exception as e:
            print(f"创建数据加载器时出错: {str(e)}")
            raise
        
        # 创建模型
        model = StitchingNet()
        
        # 训练模型
        print("开始训练模型...")
        train_model(model, train_loader, num_epochs=50, device=device)
        
        # 使用训练好的模型进行拼接
        print("开始拼接全景图...")
        result = None
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                              std=[0.229, 0.224, 0.225])
        ])
        
        for i in range(total_frames - 1):
            img1 = cv2.imread(f'temp_frames/frame_{i}.jpg')
            img2 = cv2.imread(f'temp_frames/frame_{i+1}.jpg')
            
            if img1 is None or img2 is None:
                continue
                
            # 确保输入图像是3通道
            if img1 is None or img2 is None:
                print("Error: Failed to read images")
                continue
                
            # 转换为RGB并应用预处理
            img1 = cv2.cvtColor(img1, cv2.COLOR_BGR2RGB)
            img2 = cv2.cvtColor(img2, cv2.COLOR_BGR2RGB)
            
            # 确保图像形状正确
            print(f"Image shapes before transform - img1: {img1.shape}, img2: {img2.shape}")
            
            try:
                img1 = transform(img1)
                img2 = transform(img2)
                
                # 验证转换后的张量
                assert img1.size(0) == 3, f"Expected 3 channels, got {img1.size(0)} for img1"
                assert img2.size(0) == 3, f"Expected 3 channels, got {img2.size(0)} for img2"
                
            except Exception as e:
                print(f"Error during transform: {str(e)}")
                continue
            
            # 如果是第一帧
            if result is None:
                result = img1
                continue
            
            # 拼接
            result = stitch_with_model(model, result, img2, device)
            
            print(f"\r拼接进度: {i+1}/{total_frames-1}", end="", flush=True)
            
            # 定期保存检查点
            if (i + 1) % 5 == 0:
                cv2.imwrite(f'temp_frames/checkpoint_{i+1}.jpg', result)
        
        if result is not None:
            print("\n拼接完成")
            cv2.imwrite('output/panorama_dl.jpg', result)
        
    except Exception as e:
        print(f"发生错误: {str(e)}")
    finally:
        # cleanup()
        torch.cuda.empty_cache()

if __name__ == "__main__":
    main()

Using device: cuda
初始化数据集...
数据集初始化完成，共有 24 对图像
开始训练模型...


Training Progress: 100%|██████████| 50/50 [01:43<00:00,  2.07s/it, avg_loss=0.0026, time=2.0s]


开始拼接全景图...
Image shapes before transform - img1: (1080, 1920, 3), img2: (1080, 1920, 3)
Image shapes before transform - img1: (1080, 1920, 3), img2: (1080, 1920, 3)
发生错误: CUDA out of memory. Tried to allocate 15971.77 GiB. GPU 0 has a total capacity of 23.99 GiB of which 18.47 GiB is free. Of the allocated memory 1.79 GiB is allocated by PyTorch, and 2.11 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
