# 自监督学习

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import cv2
import numpy as np
from pathlib import Path
import random

class DepthNet(nn.Module):
    """Encoder-decoder network that predicts a single-channel map in (0, 1).

    The base encoder downsamples by 8 (three stride-2 stages) and the base
    decoder upsamples back to the input resolution. When built with
    ``training=True`` two extra encoder/decoder stages (down to 1/32) are
    created; they are only executed while the module is in train mode.

    Input height and width should be multiples of 8 (32 when the auxiliary
    branch runs); otherwise the skip connections do not line up spatially.

    Args:
        training: Build the auxiliary (training-only) stages and start in
            train mode. With ``False`` the module starts in eval mode.
    """

    def __init__(self, training=False):
        super(DepthNet, self).__init__()

        # Base encoder-decoder (used for both inference and training).
        self.encoder = nn.ModuleList([
            self._make_encoder_layer(3, 32),      # 1/2
            self._make_encoder_layer(32, 64),     # 1/4
            self._make_encoder_layer(64, 128),    # 1/8
        ])

        self.decoder = nn.ModuleList([
            self._make_decoder_layer(128, 64),    # 1/4
            self._make_decoder_layer(64, 32),     # 1/2
            self._make_decoder_layer(32, 1),      # 1
        ])

        # Extra components that only exist on training instances.
        if training:
            self.auxiliary_encoder = nn.ModuleList([
                self._make_encoder_layer(128, 256),   # 1/16
                self._make_encoder_layer(256, 512),   # 1/32
            ])
            self.auxiliary_decoder = nn.ModuleList([
                self._make_decoder_layer(512, 256),   # 1/16
                self._make_decoder_layer(256, 128),   # 1/8
            ])

        # Use nn.Module's own mode switch so the flag is propagated to every
        # child (BatchNorm in particular) instead of only being set on self.
        self.train(bool(training))

    def _make_encoder_layer(self, in_channels, out_channels):
        """Return a stride-2 conv + BatchNorm + ReLU stage (halves H and W)."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def _make_decoder_layer(self, in_channels, out_channels):
        """Return a stride-2 transposed conv + BatchNorm + ReLU stage (doubles H and W)."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, 3, stride=2,
                             padding=1, output_padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        """Run the network on a ``(B, 3, H, W)`` batch and return ``(B, 1, H, W)``."""
        # Base feature extraction (shared by inference and training).
        features = []
        for encoder in self.encoder:
            x = encoder(x)
            features.append(x)

        # The auxiliary branch runs only in train mode AND only if this
        # instance was built with it, so an inference-only model never
        # touches layers it does not have, whatever mode it is put in.
        if self.training and hasattr(self, 'auxiliary_encoder'):
            aux_features = []
            for aux_encoder in self.auxiliary_encoder:
                x = aux_encoder(x)
                aux_features.append(x)

            # Skip connections are added to each decoder OUTPUT, where
            # channels and resolution match the stored encoder feature:
            # 256@1/16 <- aux_features[0], 128@1/8 <- features[-1].
            skips = features + aux_features[:-1]
            for i, aux_decoder in enumerate(self.auxiliary_decoder):
                x = aux_decoder(x) + skips[-(i + 1)]

        # Base decoder: 64@1/4 <- features[-2], 32@1/2 <- features[-3];
        # the final full-resolution output has no matching encoder feature.
        last = len(self.decoder) - 1
        for i, decoder in enumerate(self.decoder):
            x = decoder(x)
            if i < last:
                x = x + features[-(i + 2)]

        # NOTE(review): the last decoder stage ends in ReLU, so the sigmoid
        # only ever sees non-negative inputs and the output lies in
        # [0.5, 1). Left as is to keep checkpoint keys unchanged.
        return torch.sigmoid(x)

class PoseNet(nn.Module):
    """Relative pose regressor for a pair of frames.

    The two frames are concatenated along the channel axis and reduced to a
    6-vector per sample. When built with ``training=True`` an auxiliary
    deeper head is created; in train mode its prediction is averaged with
    the base prediction, in eval mode only the base head is used.

    Args:
        training: Build the auxiliary head and start in train mode. With
            ``False`` the module starts in eval mode.
    """

    def __init__(self, training=False):
        super(PoseNet, self).__init__()

        # Base pose estimator (shared by inference and training).
        self.base_encoder = nn.Sequential(
            self._make_layer(6, 32),    # input is the two frames concatenated
            self._make_layer(32, 64),
            self._make_layer(64, 128),
        )

        self.base_pose_pred = nn.Conv2d(128, 6, 1)

        # Extra components that only exist on training instances.
        if training:
            self.aux_encoder = nn.Sequential(
                self._make_layer(128, 256),
                self._make_layer(256, 512),
            )
            self.aux_pose_pred = nn.Conv2d(512, 6, 1)

        # Propagate the mode to all children via nn.Module.train instead of
        # assigning self.training directly (which leaves BatchNorm untouched).
        self.train(bool(training))

    def _make_layer(self, in_channels, out_channels):
        """Return a stride-2 conv + BatchNorm + ReLU stage."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, target_frame, source_frame):
        """Return a ``(B, 6)`` tensor predicted from two ``(B, 3, H, W)`` frames."""
        x = torch.cat([target_frame, source_frame], dim=1)

        # Base pose estimate from globally pooled features.
        feat = self.base_encoder(x)
        base_feat = F.adaptive_avg_pool2d(feat, 1)
        base_pose = self.base_pose_pred(base_feat)

        # Only use the auxiliary head when in train mode AND it exists, so
        # calling .train() on an inference-only model cannot fail.
        if not (self.training and hasattr(self, 'aux_encoder')):
            return base_pose.view(-1, 6)

        aux_feat = self.aux_encoder(feat)
        aux_feat = F.adaptive_avg_pool2d(aux_feat, 1)
        aux_pose = self.aux_pose_pred(aux_feat)

        # Fuse the base and auxiliary predictions.
        final_pose = (base_pose + aux_pose) / 2
        return final_pose.view(-1, 6)

class SelfSupervisedTrainer:
    """Jointly trains DepthNet and PoseNet on consecutive video frames.

    The pose vector is interpreted as ``[rx, ry, rz, tx, ty, tz]``: an
    axis-angle rotation followed by a translation, mapping points from the
    first (target) frame's camera coordinates into the second (source)
    frame's camera coordinates.

    Args:
        height: Network input height in pixels.
        width: Network input width in pixels.
    """

    def __init__(self, height=256, width=384):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.height = height
        self.width = width

        # Networks are built with their training-only branches.
        self.depth_net = DepthNet(training=True).to(self.device)
        self.pose_net = PoseNet(training=True).to(self.device)

        # Mixed precision is only meaningful on CUDA; on CPU the scaler is a
        # disabled pass-through instead of emitting warnings.
        self.scaler = torch.amp.GradScaler('cuda', enabled=self.device.type == 'cuda')

        self.optimizer = torch.optim.Adam([
            {'params': self.depth_net.parameters(), 'lr': 1e-4},
            {'params': self.pose_net.parameters(), 'lr': 1e-4}
        ])

        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='min', factor=0.5, patience=5
        )

    def train(self, video_path, num_epochs=50, batch_size=4):
        """Train both networks on frame pairs read from ``video_path``.

        Args:
            video_path: Path of the video file passed to ``VideoDataset``.
            num_epochs: Number of passes over the dataset.
            batch_size: Frame pairs per optimisation step.

        Raises:
            ValueError: If no frame pair could be read from the video.
        """
        # Only deterministic preprocessing here. Random flips / colour jitter
        # in this per-frame transform would be drawn independently for the
        # two frames of a pair and break their geometric correspondence.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

        # The loop below consumes pairs, so request sequences of length 2
        # (the dataset default of 3 cannot be unpacked into two frames).
        dataset = VideoDataset(video_path, transform=transform,
                               max_frames=2000,
                               target_size=(self.width, self.height),
                               sequence_length=2)
        if len(dataset) == 0:
            raise ValueError(f'No frame pairs could be read from {video_path!r}')

        use_cuda = self.device.type == 'cuda'
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=2,
            pin_memory=use_cuda
        )

        for epoch in range(num_epochs):
            total_loss = 0.0
            self.depth_net.train()
            self.pose_net.train()

            for batch_idx, (frame1, frame2) in enumerate(dataloader):
                frame1 = frame1.to(self.device, non_blocking=True)
                frame2 = frame2.to(self.device, non_blocking=True)

                with torch.amp.autocast(self.device.type, enabled=use_cuda):
                    depth1 = self.depth_net(frame1)
                    depth2 = self.depth_net(frame2)
                    pose = self.pose_net(frame1, frame2)

                    loss = (
                        self.photometric_loss(frame1, frame2, depth1, depth2, pose) +
                        0.1 * self.smoothness_loss(depth1, frame1) +
                        0.05 * self.geometric_consistency_loss(depth1, depth2, pose)
                    )

                self.optimizer.zero_grad(set_to_none=True)
                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()

                total_loss += loss.item()

                if batch_idx % 10 == 0:
                    avg_loss = total_loss / (batch_idx + 1)
                    print(f'Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}, '
                          f'Loss: {avg_loss:.4f}')

            epoch_loss = total_loss / len(dataloader)
            self.scheduler.step(epoch_loss)

            if (epoch + 1) % 5 == 0:
                self.save_models(f'checkpoint_epoch_{epoch+1}')

    def get_camera_matrix(self, height, width):
        """Return a 3x3 float32 pinhole intrinsics matrix for an image size.

        NOTE(review): no calibration is available in this file, so this is
        an uncalibrated default (focal length = longer image side, principal
        point at the image centre). Replace with real intrinsics if known.
        """
        focal = float(max(height, width))
        return torch.tensor([
            [focal, 0.0, (width - 1) / 2.0],
            [0.0, focal, (height - 1) / 2.0],
            [0.0, 0.0, 1.0],
        ], dtype=torch.float32)

    @staticmethod
    def _pose_to_rt(pose):
        """Convert ``(B, 6)`` axis-angle + translation into ``R (B,3,3)``, ``t (B,3,1)``."""
        # Geometry is done in float32 even inside an autocast region.
        with torch.autocast(device_type=pose.device.type, enabled=False):
            pose = pose.float()
            rvec = pose[:, :3]
            t = pose[:, 3:].unsqueeze(-1)

            # Rodrigues' formula; the epsilon keeps the division and its
            # gradient finite for a zero rotation.
            theta = torch.sqrt((rvec ** 2).sum(dim=1, keepdim=True) + 1e-12)
            ax, ay, az = (rvec / theta).unbind(dim=1)
            zero = torch.zeros_like(ax)
            skew = torch.stack(
                [zero, -az, ay, az, zero, -ax, -ay, ax, zero], dim=1
            ).reshape(-1, 3, 3)
            theta = theta.reshape(-1, 1, 1)
            eye = torch.eye(3, device=pose.device, dtype=pose.dtype).unsqueeze(0)
            R = eye + torch.sin(theta) * skew + (1.0 - torch.cos(theta)) * (skew @ skew)
        return R, t

    def _warp(self, source, depth, R, t):
        """Resample ``source`` into the view that ``depth`` belongs to.

        Every pixel of the depth map's view is back-projected with its
        depth, moved by ``(R, t)`` into the source camera, projected, and
        ``source`` is sampled bilinearly at that location.

        Returns:
            ``(warped, valid)`` where ``valid`` is a ``(B, 1, H, W)`` float
            mask of pixels that project inside the source image with
            positive depth.
        """
        with torch.autocast(device_type=depth.device.type, enabled=False):
            source = source.float()
            depth = depth.float()
            batch_size, _, height, width = depth.shape
            device = depth.device

            # Row-major pixel grid so the flattened order matches depth.reshape.
            ys, xs = torch.meshgrid(
                torch.arange(height, device=device, dtype=torch.float32),
                torch.arange(width, device=device, dtype=torch.float32),
                indexing='ij'
            )
            pixels = torch.stack([xs, ys, torch.ones_like(xs)]).reshape(3, -1)

            K = self.get_camera_matrix(height, width).to(device)
            rays = torch.linalg.inv(K) @ pixels                       # (3, HW)
            points = rays.unsqueeze(0) * depth.reshape(batch_size, 1, -1)
            points = R @ points + t                                   # (B, 3, HW)

            proj = K @ points
            z = proj[:, 2]
            uv = proj[:, :2] / z.clamp(min=1e-6).unsqueeze(1)

            # grid_sample expects coordinates normalised to [-1, 1].
            u = 2.0 * uv[:, 0] / max(width - 1, 1) - 1.0
            v = 2.0 * uv[:, 1] / max(height - 1, 1) - 1.0
            grid = torch.stack([u, v], dim=-1).reshape(batch_size, height, width, 2)

            inside = (grid.abs() <= 1.0).all(dim=-1)
            in_front = (z > 1e-6).reshape(batch_size, height, width)
            valid = (inside & in_front).unsqueeze(1).float()

            warped = F.grid_sample(source, grid, padding_mode='border',
                                   align_corners=True)
        return warped, valid

    @staticmethod
    def _masked_l1(prediction, target, valid):
        """Mean absolute error over the pixels selected by ``valid``."""
        error = (prediction - target.float()).abs().mean(dim=1, keepdim=True)
        return (error * valid).sum() / valid.sum().clamp(min=1.0)

    def photometric_loss(self, frame1, frame2, depth1, depth2, pose):
        """Symmetric L1 reconstruction error between the two frames.

        Frame 2 is warped into view 1 using ``depth1`` and ``pose``; frame 1
        is warped into view 2 using ``depth2`` and the inverse pose.
        """
        with torch.autocast(device_type=pose.device.type, enabled=False):
            R, t = self._pose_to_rt(pose)
            R_inv = R.transpose(1, 2)
            t_inv = -(R_inv @ t)

            recon1, valid1 = self._warp(frame2, depth1, R, t)
            recon2, valid2 = self._warp(frame1, depth2, R_inv, t_inv)
            loss = 0.5 * (self._masked_l1(recon1, frame1, valid1) +
                          self._masked_l1(recon2, frame2, valid2))
        return loss

    def smoothness_loss(self, depth, image):
        """Edge-aware smoothness: depth gradients are penalised less at image edges."""
        depth = depth.float()
        image = image.float()
        # Normalise by the mean so the penalty does not favour shrinking depth.
        depth = depth / (depth.mean(dim=(2, 3), keepdim=True) + 1e-7)

        depth_dx = (depth[:, :, :, :-1] - depth[:, :, :, 1:]).abs()
        depth_dy = (depth[:, :, :-1, :] - depth[:, :, 1:, :]).abs()
        image_dx = (image[:, :, :, :-1] - image[:, :, :, 1:]).abs().mean(dim=1, keepdim=True)
        image_dy = (image[:, :, :-1, :] - image[:, :, 1:, :]).abs().mean(dim=1, keepdim=True)

        return ((depth_dx * torch.exp(-image_dx)).mean() +
                (depth_dy * torch.exp(-image_dy)).mean())

    def geometric_consistency_loss(self, depth1, depth2, pose):
        """Mean absolute difference between ``depth1`` and ``depth2`` warped into view 1.

        Only pixels whose difference is below the consistency threshold
        contribute (see ``compute_consistency_mask``).
        """
        depth1 = depth1.float()
        proj_depth2 = self.warp_depth(depth2, depth1, pose)
        consistency_mask = self.compute_consistency_mask(depth1, proj_depth2)

        loss = torch.abs(depth1 - proj_depth2) * consistency_mask
        return torch.mean(loss)

    def warp_depth(self, depth2, depth1, pose):
        """Resample ``depth2`` into the viewpoint of ``depth1``.

        Args:
            depth2: ``(B, 1, H, W)`` depth of the second frame (sampled).
            depth1: ``(B, 1, H, W)`` depth of the first frame (defines the
                3D points that are projected into the second view).
            pose: ``(B, 6)`` axis-angle + translation from view 1 to view 2.

        Returns:
            ``(B, 1, H, W)`` float32 tensor.
        """
        R, t = self._pose_to_rt(pose)
        warped_depth, _ = self._warp(depth2, depth1, R, t)
        return warped_depth

    def compute_consistency_mask(self, depth1, proj_depth2, threshold=0.1):
        """Return a float mask that is 1 where the two depths differ by less than ``threshold``."""
        diff = torch.abs(depth1 - proj_depth2)
        mask = (diff < threshold).float()
        return mask

    def save_models(self, name):
        """Write a training checkpoint to ``<name>.pth`` in the working directory.

        The file holds the state dicts of both networks and the optimizer.
        """
        torch.save({
            'depth_net': self.depth_net.state_dict(),
            'pose_net': self.pose_net.state_dict(),
            'optimizer': self.optimizer.state_dict(),
        }, Path(f'{name}.pth'))

    def save_inference_model(self, output_dir):
        """Export CPU, eval-mode copies of both networks without auxiliary branches.

        Writes ``inference_depth_net.pth`` and ``inference_pose_net.pth``
        (whole pickled modules) into ``output_dir``, creating it if needed.
        """
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        inference_depth_net = DepthNet(training=False)
        inference_pose_net = PoseNet(training=False)

        # Keep only the weights of the base components.
        inference_depth_dict = {k: v for k, v in self.depth_net.state_dict().items()
                                if not k.startswith('auxiliary')}
        inference_pose_dict = {k: v for k, v in self.pose_net.state_dict().items()
                               if not k.startswith('aux')}

        inference_depth_net.load_state_dict(inference_depth_dict)
        inference_pose_net.load_state_dict(inference_pose_dict)

        inference_depth_net = inference_depth_net.cpu().eval()
        inference_pose_net = inference_pose_net.cpu().eval()

        torch.save(inference_depth_net, output_path / 'inference_depth_net.pth')
        torch.save(inference_pose_net, output_path / 'inference_pose_net.pth')

    @staticmethod
    def load_inference_models(model_dir):
        """Load the modules written by ``save_inference_model`` onto the CPU.

        Returns:
            ``(depth_net, pose_net)``.
        """
        model_path = Path(model_dir)
        # The files contain whole pickled modules, which the weights-only
        # loader (the default in recent PyTorch) refuses to load.
        # SECURITY: unpickling executes arbitrary code; only load files
        # produced by this program.
        depth_net = torch.load(model_path / 'inference_depth_net.pth',
                               map_location='cpu', weights_only=False)
        pose_net = torch.load(model_path / 'inference_pose_net.pth',
                              map_location='cpu', weights_only=False)
        return depth_net, pose_net

    def preprocess_frame(self, frame):
        """Convert one ``HxWx3`` uint8 image into a normalised ``(1, 3, H, W)`` CPU tensor.

        No channel reordering is performed; pass the frame in the same
        channel order the networks were trained on.
        """
        frame_tensor = transforms.ToTensor()(frame).unsqueeze(0)
        frame_tensor = transforms.Resize((self.height, self.width))(frame_tensor)
        frame_tensor = transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])(frame_tensor)
        return frame_tensor

    def generate_depth_map(self, frame, inference_depth_net=None):
        """Predict a depth map for one frame.

        Args:
            frame: ``HxWx3`` uint8 image (see ``preprocess_frame``).
            inference_depth_net: Optional exported model; when omitted the
                training network is switched to eval mode and used.

        Returns:
            2D NumPy array of shape ``(self.height, self.width)``.
        """
        if inference_depth_net is None:
            self.depth_net.eval()
            model = self.depth_net
        else:
            model = inference_depth_net

        frame_tensor = self.preprocess_frame(frame)

        # Run on whatever device the chosen model lives on.
        first_param = next(model.parameters(), None)
        if first_param is not None:
            frame_tensor = frame_tensor.to(first_param.device)

        with torch.no_grad():
            depth = model(frame_tensor)
        return depth.cpu().numpy()[0, 0]
            
class VideoDataset(Dataset):
    """Sliding-window dataset of consecutive frames read from a video file.

    All frames are decoded, resized and converted to RGB up front; each item
    is a tuple of ``sequence_length`` consecutive frames.

    Args:
        video_path: Path of the video file.
        transform: Optional callable applied to every (augmented) frame.
            Random augmentation is only performed when a transform is given.
        max_frames: Stop after this many frames have been kept (``None`` or
            0 means no limit).
        target_size: Size passed to ``cv2.resize``, i.e. ``(width, height)``.
        sequence_length: Number of consecutive frames per item.

    Raises:
        ValueError: If ``sequence_length`` is smaller than 1.
    """

    def __init__(self, video_path, transform=None, max_frames=None, target_size=(256, 384),
                 sequence_length=3):
        if sequence_length < 1:
            raise ValueError('sequence_length must be at least 1')

        self.frames = []
        self.transform = transform
        self.target_size = target_size
        self.sequence_length = sequence_length

        # A bounded deque gives the sliding window for free. The windows
        # share references to the frame arrays, so memory grows with the
        # number of frames, not frames * sequence_length.
        from collections import deque
        frame_buffer = deque(maxlen=sequence_length)

        cap = cv2.VideoCapture(str(video_path))
        frame_count = 0
        try:
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret or (max_frames and frame_count >= max_frames):
                    break

                frame = cv2.resize(frame, target_size)
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

                frame_buffer.append(frame)

                # Once the window is full, every new frame yields a sequence.
                if len(frame_buffer) == sequence_length:
                    self.frames.append(list(frame_buffer))

                frame_count += 1

                # Temporal subsampling: drop one frame after every second kept frame.
                if frame_count % 2 == 0:
                    cap.grab()
        finally:
            # Release the capture even if decoding/resizing raised.
            cap.release()
        print(f"Loaded {len(self.frames)} frame sequences from video")

    def __len__(self):
        """Return the number of frame sequences."""
        return len(self.frames)

    @staticmethod
    def _jitter_sequence(frame_sequence):
        """Apply one randomly drawn colour jitter to all frames of a sequence.

        The frames are jittered as a single uint8 batch, so the same
        brightness/contrast/saturation/hue factors are used for each frame.
        """
        color_jitter = transforms.ColorJitter(
            brightness=0.2,
            contrast=0.2,
            saturation=0.2,
            hue=0.1
        )
        batch = torch.from_numpy(np.stack(frame_sequence)).permute(0, 3, 1, 2)
        jittered = color_jitter(batch).permute(0, 2, 3, 1).contiguous().numpy()
        return list(jittered)

    def __getitem__(self, idx):
        """Return the ``idx``-th sequence as a tuple of frames.

        With a transform set, the whole sequence is flipped horizontally
        with probability 0.5 and colour-jittered with probability 0.5
        before the transform is applied to each frame.
        """
        frame_sequence = self.frames[idx]

        if self.transform:
            # Same flip for the whole sequence. Copy to a contiguous array:
            # a reversed view has negative strides, which torch.from_numpy
            # (and therefore ToTensor) rejects.
            if random.random() > 0.5:
                frame_sequence = [np.ascontiguousarray(frame[:, ::-1])
                                  for frame in frame_sequence]

            # Jitter the raw uint8 frames BEFORE the transform, so it is not
            # applied to already-normalised values.
            if random.random() > 0.5:
                frame_sequence = self._jitter_sequence(frame_sequence)

            frame_sequence = [self.transform(frame) for frame in frame_sequence]

        return tuple(frame_sequence)

class ViewConfidenceEstimator:
    """Per-pixel confidence derived from the angle between viewing ray and camera axis.

    Args:
        threshold_angle: Angle in degrees at which the confidence has
            dropped to ``exp(-1)``.
    """

    def __init__(self, threshold_angle=45.0):
        self.threshold_angle = threshold_angle

    @torch.no_grad()
    def compute_view_angles(self, depth_map, K, pose):
        """Compute the viewing angle of every pixel.

        Args:
            depth_map: ``(B, 1, H, W)`` tensor.
            K: ``3x3`` intrinsics tensor.
            pose: Batched pose matrices; column 2 of the upper-left 3x3
                block (``pose[:, :3, 2]``) is used as the camera axis.

        Returns:
            ``(B, 1, H, W)`` tensor of angles in radians.
        """
        device = depth_map.device
        batch_size, _, height, width = depth_map.shape

        # Cache the homogeneous pixel grid, but rebuild it whenever the
        # resolution or device changes (a stale grid has the wrong size).
        grid_key = (height, width, device)
        if getattr(self, '_pixel_grid_key', None) != grid_key:
            y_coords, x_coords = torch.meshgrid(
                torch.arange(height, device=device),
                torch.arange(width, device=device),
                indexing='ij'
            )
            pixels = torch.stack([x_coords, y_coords, torch.ones_like(x_coords)]).float()
            self.pixel_grid = pixels.reshape(3, -1)
            self._pixel_grid_key = grid_key

        # Back-project: rays (1, 3, HW) scaled by depth (B, 1, HW).
        rays = torch.linalg.inv(K.to(device).float()) @ self.pixel_grid
        depths = depth_map.reshape(batch_size, 1, -1).float()
        points_3d = rays.unsqueeze(0) * depths

        # Unit viewing directions along the coordinate axis (dim=1).
        view_directions = F.normalize(points_3d, p=2, dim=1)
        camera_normal = pose[:, :3, 2].to(device).float()   # camera axis, (B, 3)

        # Batched dot product over the coordinate axis -> (B, HW).
        cosines = torch.sum(view_directions * camera_normal.unsqueeze(-1), dim=1)
        angles = torch.arccos(torch.clamp(cosines, -1.0, 1.0))

        return angles.reshape(batch_size, 1, height, width)

    @torch.no_grad()
    def get_confidence_mask(self, view_angles):
        """Map angles in radians to a Gaussian confidence in (0, 1]."""
        angles_deg = view_angles * 180.0 / np.pi
        confidence = torch.exp(-(angles_deg / self.threshold_angle) ** 2)
        return confidence

class DSMGenerator:
    """Accumulates depth maps into a 2D height grid (digital surface model).

    Each world point is binned by ``floor(x / cell_size), floor(y / cell_size)``
    and its ``z`` value is averaged per cell, weighted by confidence.

    Args:
        cell_size: Edge length of one grid cell in world units.
        grid_size: ``(rows, cols)`` of the grid.
    """

    def __init__(self, cell_size=0.1, grid_size=(1000, 1000)):
        self.cell_size = cell_size
        self.grid_size = grid_size
        self.dsm = np.zeros(grid_size)
        self.weight_sum = np.zeros(grid_size)

    @staticmethod
    def _default_intrinsics(height, width):
        """Return fallback pinhole intrinsics for an image size.

        NOTE(review): uncalibrated guess (focal length = longer image side,
        principal point at the image centre); pass a real ``K`` if known.
        """
        focal = float(max(height, width))
        return np.array([
            [focal, 0.0, (width - 1) / 2.0],
            [0.0, focal, (height - 1) / 2.0],
            [0.0, 0.0, 1.0],
        ])

    @torch.no_grad()
    def update_dsm(self, depth_map, pose, confidence_map=None, K=None):
        """Accumulate one depth map into the grid.

        Args:
            depth_map: 2D array of depths.
            pose: ``4x4`` (or ``3x4``) matrix applied to the camera-frame points.
            confidence_map: Optional per-pixel weights with as many elements
                as ``depth_map``; defaults to 1 everywhere.
            K: ``3x3`` intrinsics; a default pinhole model is used when omitted.
        """
        depth_map = np.asarray(depth_map, dtype=np.float64)
        pose = np.asarray(pose, dtype=np.float64)
        if K is None:
            K = self._default_intrinsics(*depth_map.shape)

        points_3d = self.depth_to_points(depth_map, np.asarray(K, dtype=np.float64))
        points_world = self.transform_points(points_3d, pose)

        # NaN/inf points cannot be binned; zero them for the integer cast
        # and exclude them through the mask.
        finite = np.isfinite(points_world).all(axis=1)
        grid_coords = self.project_to_grid(
            np.where(finite[:, None], points_world, 0.0))

        valid_mask = (
            finite &
            (grid_coords[:, 0] >= 0) &
            (grid_coords[:, 0] < self.grid_size[0]) &
            (grid_coords[:, 1] >= 0) &
            (grid_coords[:, 1] < self.grid_size[1])
        )

        grid_coords = grid_coords[valid_mask]
        heights = points_world[valid_mask, 2]
        if confidence_map is not None:
            weights = np.asarray(confidence_map, dtype=np.float64).reshape(-1)[valid_mask]
        else:
            weights = np.ones_like(heights)

        # np.add.at accumulates correctly when several points hit one cell.
        np.add.at(self.dsm, (grid_coords[:, 0], grid_coords[:, 1]),
                  heights * weights)
        np.add.at(self.weight_sum, (grid_coords[:, 0], grid_coords[:, 1]),
                  weights)

    def finalize_dsm(self):
        """Normalise the accumulated heights, fill empty cells and return the grid.

        Safe to call more than once: after normalisation each observed cell
        keeps weight 1, so a repeated call does not divide again.
        """
        mask = self.weight_sum > 0
        self.dsm[mask] /= self.weight_sum[mask]
        self.weight_sum[mask] = 1.0

        self._fill_unobserved_regions()

        return self.dsm

    @staticmethod
    def depth_to_points(depth_map, K):
        """Back-project a 2D depth map into an ``(H*W, 3)`` array of camera-frame points."""
        height, width = depth_map.shape
        y_coords, x_coords = np.meshgrid(np.arange(height), np.arange(width),
                                         indexing='ij')
        pixels = np.stack([x_coords, y_coords, np.ones_like(x_coords)]).reshape(3, -1)

        points_3d = np.linalg.inv(K) @ pixels
        points_3d *= depth_map.reshape(1, -1)

        return points_3d.T

    @staticmethod
    def transform_points(points_3d, pose):
        """Apply the rotation and translation of ``pose`` to ``(N, 3)`` points."""
        R = pose[:3, :3]
        t = pose[:3, 3]
        return (R @ points_3d.T).T + t

    def project_to_grid(self, points_world):
        """Return integer ``(N, 2)`` grid indices for the x/y of each point."""
        grid_coords = np.floor(points_world[:, :2] / self.cell_size).astype(int)
        return grid_coords

    def _fill_unobserved_regions(self):
        """Fill cells without observations with the value of the nearest observed cell."""
        from scipy.ndimage import distance_transform_edt

        unobserved = self.weight_sum == 0
        # Nothing to fill, or nothing to fill from.
        if not unobserved.any() or unobserved.all():
            return

        # For every unobserved cell the transform yields the indices of the
        # closest observed (zero-valued) cell, i.e. nearest-neighbour lookup.
        nearest = distance_transform_edt(unobserved, return_distances=False,
                                         return_indices=True)
        rows = nearest[0][unobserved]
        cols = nearest[1][unobserved]

        self.dsm[unobserved] = self.dsm[rows, cols]
        self.weight_sum[unobserved] = 1  # mark as filled

def main(video_path, num_epochs):
    """Train on a video, export inference models and build a DSM from the same video.

    Results are written to the ``output`` directory: the exported models,
    ``dsm.npy`` and ``dsm_visualization.png``.

    Args:
        video_path: Path of the input video.
        num_epochs: Number of training epochs.
    """
    output_dir = "output"

    trainer = SelfSupervisedTrainer()
    trainer.train(video_path, num_epochs, batch_size=4)

    trainer.save_inference_model(output_dir)
    inference_depth_net, inference_pose_net = trainer.load_inference_models(output_dir)

    def to_network_input(rgb_frame):
        """Resize and normalise one RGB frame into a (1, 3, H, W) CPU tensor."""
        tensor = transforms.ToTensor()(rgb_frame).unsqueeze(0)
        tensor = transforms.Resize((trainer.height, trainer.width))(tensor)
        return transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                    std=[0.5, 0.5, 0.5])(tensor)

    dsm_generator = DSMGenerator()
    cap = cv2.VideoCapture(video_path)

    prev_rgb = None
    pose = np.eye(4)  # accumulated camera pose, starts at identity

    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            # OpenCV decodes to BGR; the training data was converted to RGB.
            rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

            depth = trainer.generate_depth_map(rgb, inference_depth_net)

            if prev_rgb is not None:
                with torch.no_grad():
                    # The network returns shape (1, 6); flatten to a 6-vector.
                    pose_update = inference_pose_net(
                        to_network_input(rgb),
                        to_network_input(prev_rgb)
                    ).cpu().numpy().reshape(-1)

                pose = update_pose(pose, pose_update)

                # Intrinsics must match the depth map's resolution.
                # NOTE(review): uncalibrated default (focal = longer side,
                # principal point at the centre); use real intrinsics if known.
                depth_h, depth_w = depth.shape
                focal = float(max(depth_h, depth_w))
                K = np.array([
                    [focal, 0.0, (depth_w - 1) / 2.0],
                    [0.0, focal, (depth_h - 1) / 2.0],
                    [0.0, 0.0, 1.0],
                ])
                dsm_generator.update_dsm(depth, pose, K=K)

            prev_rgb = rgb
    finally:
        # Release the capture even if inference fails midway.
        cap.release()

    final_dsm = dsm_generator.finalize_dsm()

    np.save(Path(output_dir) / 'dsm.npy', final_dsm)

    visualize_dsm(final_dsm, Path(output_dir) / 'dsm_visualization.png')

def update_pose(prev_pose, pose_update):
    """Compose an accumulated pose with a relative pose update.

    Args:
        prev_pose: ``4x4`` homogeneous pose matrix.
        pose_update: Six values ``[rx, ry, rz, tx, ty, tz]`` (axis-angle
            rotation followed by translation). Any array-like holding
            exactly six elements is accepted, e.g. a ``(1, 6)`` network output.

    Returns:
        New ``4x4`` pose equal to ``prev_pose @ update``.

    Raises:
        ValueError: If ``pose_update`` does not contain exactly six values.
    """
    from scipy.spatial.transform import Rotation

    # Flatten so a batched (1, 6) float32 prediction is handled like a 6-vector.
    update = np.asarray(pose_update, dtype=np.float64).reshape(-1)
    if update.size != 6:
        raise ValueError(f'pose_update must contain 6 values, got {update.size}')

    prev_pose = np.asarray(prev_pose, dtype=np.float64)
    R_update = Rotation.from_rotvec(update[:3]).as_matrix()
    t_update = update[3:]

    # Compose rotation and translation.
    new_pose = np.eye(4)
    new_pose[:3, :3] = prev_pose[:3, :3] @ R_update
    new_pose[:3, 3] = prev_pose[:3, :3] @ t_update + prev_pose[:3, 3]

    return new_pose

def visualize_dsm(dsm, output_path):
    """Render the DSM as a height image and a contour plot and save the figure.

    Args:
        dsm: 2D array of heights.
        output_path: Destination file for the rendered figure.
    """
    import matplotlib.pyplot as plt

    fig = plt.figure(figsize=(12, 8))

    # Left panel: height image.
    height_ax = fig.add_subplot(1, 2, 1)
    height_image = height_ax.imshow(dsm, cmap='terrain')
    fig.colorbar(height_image, ax=height_ax, label='Height (m)')
    height_ax.set_title('Digital Surface Model')

    # Right panel: contour lines.
    contour_ax = fig.add_subplot(1, 2, 2)
    contour_set = contour_ax.contour(dsm, levels=20, cmap='viridis')
    fig.colorbar(contour_set, ax=contour_ax, label='Height (m)')
    contour_ax.set_title('Contour Map')

    fig.tight_layout()
    fig.savefig(output_path)
    plt.close(fig)

if __name__ == "__main__":
    # Script entry point: train on the sample video, then build its DSM.
    video_path, num_epochs = "video/7.mp4", 50
    main(video_path, num_epochs)

Loaded 982 frame sequences from video
