In [232]:
import os
import glob
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms.functional as F
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
from cv2 import waitKey, destroyAllWindows
from torchvision.models.optical_flow import raft_large, Raft_Large_Weights, raft_small, Raft_Small_Weights
from torchvision.utils import flow_to_image
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Configuration
image_folder = 'overDataSet2'  # directory holding the input images
output_folder = 'lightFlowOutput'  # directory for the saved optical-flow images
output_PRE_folder = 'lightFlowOutputPre'  # directory for the saved predicted frames

# Create both output directories up front; an existing directory is not an error.
for _out_dir in (output_folder, output_PRE_folder):
    os.makedirs(_out_dir, exist_ok=True)

# Load the pretrained model.
# Two weight enums are selected; only `weights1` (RAFT-large) is used to build
# the model here. `weights2` (RAFT-small) is kept as an available alternative.
weights1 = Raft_Large_Weights.C_T_SKHT_K_V2
weights2 = Raft_Small_Weights.DEFAULT

model = raft_large(weights=weights1)
model = model.to(device)  # move the parameters to the selected device
print(model)

# Collect every PNG in the image folder, ordered by frame number.
# The frame number is the integer after the last '_' in the part of the file
# name that precedes the first '.'. The full path is the secondary key, so files
# sharing a number keep a deterministic (lexicographic) order.
image_files = sorted(
    glob.glob(os.path.join(image_folder, '*.png')),
    key=lambda path: (
        int(os.path.basename(path).split('.')[0].split('_')[-1]),
        path,
    ),
)
num_frames = len(image_files)


RAFT(
  (feature_encoder): FeatureEncoder(
    (convnormrelu): Conv2dNormActivation(
      (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
      (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (2): ReLU(inplace=True)
    )
    (layer1): Sequential(
      (0): ResidualBlock(
        (convnormrelu1): Conv2dNormActivation(
          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (2): ReLU(inplace=True)
        )
        (convnormrelu2): Conv2dNormActivation(
          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (2): ReLU(inplace=True)
        )
        (downsample): Identity()
        (relu): ReLU(inplace=True)
      )
      (1): ResidualBlock(
        (

特征编码器的主要作用是从图像中提取有效的特征。由于雷达回波图像与自然图像不同，你可以重点再训练该部分，以便模型更好地捕捉雷达图像的独特模式。具体建议：

第一个卷积层：这一层接收输入图像，因此可以重新训练 convnormrelu 层，使其适应雷达图像中的颜色或纹理信息。
后续的残差块（ResidualBlock）：这些层通过不同的层次学习到越来越抽象的特征。再训练这些层，可以帮助模型更好地提取雷达回波的高层次特征。

In [233]:
# Select the layers to fine-tune.
# Every parameter whose name contains 'convnormrelu' stays trainable. This is a
# substring match, so it covers the stem `convnormrelu` as well as the
# `convnormrelu1` / `convnormrelu2` layers inside the residual blocks.
# All remaining parameters are frozen.
trainable_params = []
for name, param in model.named_parameters():
    param.requires_grad = 'convnormrelu' in name
    if param.requires_grad:
        trainable_params.append(param)

LR = 0.0003
# NOTE: despite its name this criterion is a mean absolute error (L1) loss; the
# name is kept because other cells refer to it.
entropy_loss = nn.L1Loss()
entropy_loss.to(device)
# Hand the optimizer only the parameters that are actually trainable, so frozen
# weights are neither tracked nor given optimizer state.
optimizer = optim.Adam(trainable_params, lr=LR)
# TODO: wrap the training loop in a function



In [226]:
# Frame-prediction helper
def apply_flow(image, flow):
    """Warp ``image`` by sampling it at positions displaced by ``-flow``.

    For every output pixel ``(y, x)`` the image is sampled at
    ``(y - flow_y, x - flow_x)`` using bicubic interpolation with reflection
    padding (``align_corners=True``).

    Args:
        image: batched image tensor passed to ``grid_sample``
            (assumed ``(N, C, H, W)`` — TODO confirm against callers).
        flow: tensor of shape ``(N, 2, H, W)``; channel 0 is the x
            displacement and channel 1 the y displacement, in pixels.

    Returns:
        The warped image, with the spatial size of ``flow``.

    Side effects:
        Prints the mean x and y flow components.
    """
    flow = flow.permute(0, 2, 3, 1)  # (N, H, W, 2)
    h, w = flow.shape[1:3]

    # Build the pixel coordinate grid directly on the flow's device.
    rows = torch.arange(h, dtype=torch.float32, device=flow.device)
    cols = torch.arange(w, dtype=torch.float32, device=flow.device)
    y_coords, x_coords = torch.meshgrid(rows, cols, indexing='ij')

    # Source coordinates: (H, W) broadcasts against (N, H, W).
    new_x_coords = x_coords - flow[:, :, :, 0]
    new_y_coords = y_coords - flow[:, :, :, 1]

    # Normalise to [-1, 1] as grid_sample expects. A one-pixel-wide or
    # one-pixel-high input would divide by zero (NaN grid), so the divisor is
    # clamped to 1; the single coordinate 0 then maps to -1, i.e. pixel 0.
    new_x_coords = (new_x_coords / max(w - 1, 1)) * 2 - 1
    new_y_coords = (new_y_coords / max(h - 1, 1)) * 2 - 1

    # Bicubic sampling. grid_sample needs the grid on the same device and with
    # the same dtype as the image, so align it explicitly.
    grid = torch.stack([new_x_coords, new_y_coords], dim=-1)
    grid = grid.to(device=image.device, dtype=image.dtype)
    warped_image = torch.nn.functional.grid_sample(
        image, grid, mode='bicubic', padding_mode='reflection', align_corners=True
    )

    mean_flow_x = flow[:, :, :, 0].mean()
    mean_flow_y = flow[:, :, :, 1].mean()
    print(f'mean_flow_x={mean_flow_x}, mean_flow_y={mean_flow_y}')
    return warped_image



# Bound intended for limit_flow_step (largest allowed flow component magnitude).
max_flow_step: float = 5.0
# Number of consecutive frame pairs whose flows are combined per prediction.
step: int = 5

# Flow accumulation helper
def accumulate_flow(accumulated_flow, new_flow):
    """Return the running flow with ``new_flow`` added to it.

    Neither argument is modified; the sum is returned as a new value.
    """
    combined = accumulated_flow + new_flow
    return combined

# Flow step limiting helper
def limit_flow_step(flow, max_step):
    """Clamp every element of ``flow`` to the range ``[-max_step, max_step]``."""
    lower_bound, upper_bound = -max_step, max_step
    return torch.clamp(flow, min=lower_bound, max=upper_bound)



241
img1= 0
img2= 1
img1= 1
img2= 2
img1= 2
img2= 3
img1= 3
img2= 4
img1= 4
img2= 5
shape= torch.Size([1, 2, 200, 200])
Processed frame 0 to 5
mean_flow_x=3.3959906101226807, mean_flow_y=1.3247820138931274
---
img1= 1
img2= 2
img1= 2
img2= 3
img1= 3
img2= 4
img1= 4
img2= 5
img1= 5
img2= 6
shape= torch.Size([1, 2, 200, 200])
Processed frame 1 to 6
mean_flow_x=2.5793817043304443, mean_flow_y=1.3253446817398071
---
img1= 2
img2= 3
img1= 3
img2= 4
img1= 4
img2= 5
img1= 5
img2= 6
img1= 6
img2= 7
shape= torch.Size([1, 2, 200, 200])
Processed frame 2 to 7
mean_flow_x=2.054450035095215, mean_flow_y=1.1756590604782104
---
img1= 3
img2= 4
img1= 4
img2= 5
img1= 5
img2= 6
img1= 6
img2= 7
img1= 7
img2= 8
shape= torch.Size([1, 2, 200, 200])
Processed frame 3 to 8
mean_flow_x=0.7431689500808716, mean_flow_y=2.158536195755005
---
img1= 4
img2= 5
img1= 5
img2= 6
img1= 6
img2= 7
img1= 7
img2= 8
img1= 8
img2= 9
shape= torch.Size([1, 2, 200, 200])
Processed frame 4 to 9
mean_flow_x=-0.2740216851234436, me

In [None]:
# Training: fine-tune the model on sliding windows of `step` frame pairs.
# For window i the flows of pairs (i, i+1) ... (i+step-1, i+step) are combined
# into one flow, frame i+step is warped with it, and the result is compared
# with the real next frame i+step+1.

# Make the mode explicit so that re-running this cell after the model has been
# switched to eval mode still trains in training mode.
model.train()

# Pair j is weighted by j / (0 + 1 + ... + (step - 1)), so the weights sum to 1.
# For step == 5 the divisor is 10, matching the previously hard-coded value.
# `or 1.0` avoids a division by zero when step == 1.
flow_weight_total = float(sum(range(step))) or 1.0

# Each window reads frames i .. i+step+1, so the last usable start index is
# num_frames - step - 2; iterating further would index past the end of
# image_files when loading the target frame.
for i in range(num_frames - step - 1):
    accumulated_flow = None  # reset the combined flow for this window
    optimizer.zero_grad()  # clear gradients left over from the previous window

    for j in range(step):
        # Load one consecutive frame pair.
        img1 = Image.open(image_files[i + j]).convert('RGB')
        img2 = Image.open(image_files[i + j + 1]).convert('RGB')

        img1_tensor = F.to_tensor(img1).unsqueeze(0).to(device)
        img2_tensor = F.to_tensor(img2).unsqueeze(0).to(device)

        # The model returns a list of flow estimates; the last entry is used.
        list_of_flows = model(img1_tensor, img2_tensor)
        predicted_flows = list_of_flows[-1]

        weighted_flow = predicted_flows * (j / flow_weight_total)
        if accumulated_flow is None:
            accumulated_flow = weighted_flow  # first pair initialises the sum
        else:
            accumulated_flow = accumulate_flow(accumulated_flow, weighted_flow)

    # Ground-truth frame that follows the last frame of the window.
    img3 = Image.open(image_files[i + step + 1]).convert('RGB')
    img3_tensor = F.to_tensor(img3).unsqueeze(0).to(device)

    # Predict the next frame by warping the last frame of the window.
    predicted_next_frame = apply_flow(img2_tensor, accumulated_flow)

    # Loss: difference between the predicted frame and the real next frame.
    # Comparing against img2_tensor (the frame being warped) would reward a
    # zero flow instead of a correct prediction.
    loss = entropy_loss(predicted_next_frame, img3_tensor)

    # Back-propagate and update the trainable weights.
    loss.backward()
    optimizer.step()

    print(f"Processed frame {i} to {i+step}, loss: {loss.item()}")


In [None]:
# Switch the model to evaluation mode before the inference cell that follows.
# The result is assigned back to `model`, so the cell ends in an assignment
# rather than an expression and does not echo a value.
model = model.eval()

In [None]:
# Optical-flow inference and next-frame prediction.
# For window i the flows of `step` consecutive frame pairs are combined, the
# combined flow is saved as an image, and frame i+step is warped with it to
# predict frame i+step+1.
print(num_frames)

# Pair j is weighted by j / (0 + 1 + ... + (step - 1)), so the weights sum to 1.
# For step == 5 the divisor is 10, matching the previously hard-coded value.
# `or 1.0` avoids a division by zero when step == 1.
flow_weight_total = float(sum(range(step))) or 1.0

for i in range(num_frames - step):
    accumulated_flow = None  # reset the combined flow for every window
    for j in range(step):
        print('img1=',i + j)
        print('img2=',i + j + 1)
        img1 = Image.open(image_files[i + j]).convert('RGB')
        img2 = Image.open(image_files[i + j + 1]).convert('RGB')

        # Tensors stay on the CPU here; they are moved to `device` only for the
        # forward pass so the later warping and saving run on CPU tensors.
        img1_tensor = F.to_tensor(img1).unsqueeze(0)
        img2_tensor = F.to_tensor(img2).unsqueeze(0)

        with torch.no_grad():
            list_of_flows = model(img1_tensor.to(device), img2_tensor.to(device))
            # The model returns a list of flow estimates; the last entry is used.
            predicted_flows = list_of_flows[-1].cpu()

        weighted_flow = predicted_flows * (j / flow_weight_total)
        if accumulated_flow is None:
            accumulated_flow = weighted_flow  # first pair initialises the sum
        else:
            accumulated_flow = accumulate_flow(accumulated_flow, weighted_flow)

    print('shape=',accumulated_flow.shape)

    # Visualise the accumulated flow, which is what the file name promises;
    # previously only the flow of the last frame pair was rendered.
    flow_imgs = flow_to_image(accumulated_flow)

    # Save the flow visualisation. squeeze(0) drops only the batch dimension.
    output_file = os.path.join(output_folder, f'accumulated_flow_{i:03d}_to_{i+step:03d}.png')
    plt.imsave(output_file, flow_imgs.squeeze(0).permute(1, 2, 0).numpy())

    print(f"Processed frame {i} to {i+step}")

    # Predict the next frame by warping the last frame of the window.
    predicted_next_frame = apply_flow(img2_tensor, accumulated_flow)

    # Save the prediction as an H x W x C image with values limited to [0, 1].
    output_file = os.path.join(output_PRE_folder, f'predicted{i+step+1:03d}using{i:03d}to{i+step:03d}.png')
    predicted_next_frame = predicted_next_frame.squeeze(0).permute(1, 2, 0).clamp(0, 1)
    plt.imsave(output_file, predicted_next_frame.numpy())
    print("---")

print("All frames processed.")