In [23]:
import glob
import os
import re

import matplotlib.pyplot as plt
import torch
import torchvision.transforms as transforms
import torchvision.transforms.functional as F
from PIL import Image
from torchvision.models.optical_flow import raft_large, Raft_Large_Weights, raft_small, Raft_Small_Weights
from torchvision.utils import flow_to_image
device = "cuda" if torch.cuda.is_available() else "cpu"

# --- Configuration -----------------------------------------------------------
image_folder = 'overDataSet'  # folder with the input frames (*.png) -- adjust as needed
output_folder = 'lightFlowOutput'  # where the flow visualisations are saved
output_PRE_folder = 'lightFlowOutputPre'  # where the predicted frames are saved
os.makedirs(output_folder, exist_ok=True)
os.makedirs(output_PRE_folder, exist_ok=True)

# --- Model -------------------------------------------------------------------
# weights1 = Raft_Large_Weights.DEFAULT
weights1 = Raft_Large_Weights.C_T_SKHT_K_V2
weights2 = Raft_Small_Weights.DEFAULT  # NOTE(review): unused in the visible cells
model = raft_large(weights=weights1, progress=False).to(device)
model = model.eval()

# Frame height/width must be divisible by 8 for RAFT. Frames are fed at their
# native size, so presumably the dataset already satisfies this -- TODO confirm
# (otherwise resize/pad both frames to a multiple of 8 before inference).


def _natural_sort_key(path):
    """Sort key that orders 'frame2.png' before 'frame10.png'.

    re.split with a capturing group alternates non-digit / digit pieces, so
    the odd positions are always the digit runs; comparing them as ints gives
    numeric ordering while keeping element types aligned between keys.
    """
    parts = re.split(r'(\d+)', os.path.basename(path))
    return [int(part) if idx % 2 else part for idx, part in enumerate(parts)]


# Collect the frames in numeric order. Plain lexicographic sorting would put
# e.g. "10.png" before "2.png" and pair the wrong frames together.
image_files = sorted(glob.glob(os.path.join(image_folder, '*.png')), key=_natural_sort_key)
num_frames = len(image_files)
if num_frames < 2:
    print(f"Warning: found {num_frames} PNG frame(s) in '{image_folder}'; at least 2 are needed.")



Downloading: "https://download.pytorch.org/models/raft_large_C_T_SKHT_K_V2-b5c70766.pth" to C:\Users\BLKDASH/.cache\torch\hub\checkpoints\raft_large_C_T_SKHT_K_V2-b5c70766.pth


In [24]:
# Frame-warping helper used for frame prediction.
def apply_flow(image, flow, mode='bicubic'):
    """Backward-warp ``image`` by sampling it at ``pixel + flow``.

    For every output pixel (x, y) the result is ``image`` sampled at
    (x + flow[:, 0, y, x], y + flow[:, 1, y, x]); samples that fall outside
    the image are clamped to the border.

    Args:
        image: tensor of shape (N, C, H, W) to sample from.
        flow: tensor of shape (N, 2, H, W) with per-pixel displacements in
            pixels; channel 0 is the x (column) shift, channel 1 the y (row)
            shift.
        mode: interpolation mode for ``grid_sample`` ('bilinear', 'nearest'
            or 'bicubic'). Defaults to 'bicubic'; bicubic can overshoot the
            input value range, so clamp the result if a bounded range is needed.

    Returns:
        Tensor of shape (N, C, H, W) with the warped image.
    """
    flow = flow.permute(0, 2, 3, 1)  # (N, H, W, 2)
    h, w = flow.shape[1:3]
    y_coords, x_coords = torch.meshgrid(
        torch.arange(h, device=flow.device, dtype=flow.dtype),
        torch.arange(w, device=flow.device, dtype=flow.dtype),
        indexing='ij',
    )

    # Absolute sampling positions in pixel units (coords broadcast over N).
    new_x = x_coords + flow[..., 0]
    new_y = y_coords + flow[..., 1]

    # Normalise to [-1, 1] as grid_sample expects with align_corners=True.
    # max(..., 1) avoids 0/0 = NaN for a 1-pixel-wide/high image; in that case
    # every finite grid value maps onto the single pixel anyway.
    new_x = new_x / max(w - 1, 1) * 2 - 1
    new_y = new_y / max(h - 1, 1) * 2 - 1

    # grid_sample requires the grid dtype to match the input dtype.
    grid = torch.stack([new_x, new_y], dim=-1).to(image.dtype)
    return torch.nn.functional.grid_sample(
        image, grid, mode=mode, padding_mode='border', align_corners=True
    )


# Estimate optical flow for each consecutive frame pair, save a visualisation,
# and extrapolate the following frame from it.
for i in range(num_frames - 1):
    # Context managers make sure the source files are closed promptly.
    with Image.open(image_files[i]) as im1, Image.open(image_files[i + 1]) as im2:
        img1 = im1.convert('RGB')
        img2 = im2.convert('RGB')

    # to_tensor yields values in [0, 1]; these are kept for warping/saving.
    img1_tensor = F.to_tensor(img1).unsqueeze(0)
    img2_tensor = F.to_tensor(img2).unsqueeze(0)

    # torchvision's RAFT expects inputs scaled to [-1, 1] (what
    # weights.transforms() produces); raw [0, 1] input degrades the flow.
    with torch.no_grad():
        list_of_flows = model((img1_tensor * 2 - 1).to(device), (img2_tensor * 2 - 1).to(device))
    # The last element is the final refinement, shape (1, 2, H, W).
    predicted_flows = list_of_flows[-1].cpu()

    # Save the colour-coded flow for pair (i, i+1).
    flow_imgs = flow_to_image(predicted_flows)
    output_file = os.path.join(output_folder, f'flow_{i:03d}.png')
    plt.imsave(output_file, flow_imgs[0].permute(1, 2, 0).numpy())

    print(f"Processed frame {i} to {i+1}")

    # Predict frame i+2 assuming constant motion: a pixel that moved by f
    # between frames i and i+1 moves by f again, so frame i+2 at x is
    # approximately frame i+1 at x - f. Sampling at x + f (un-negated flow)
    # would instead reconstruct frame i.
    predicted_next_frame = apply_flow(img2_tensor, -predicted_flows)
    output_file = os.path.join(output_PRE_folder, f'predicted_frame_{i+2:03d}.png')
    # Clamp because interpolation can overshoot the [0, 1] range.
    predicted_next_frame = predicted_next_frame[0].permute(1, 2, 0).clamp(0, 1)
    plt.imsave(output_file, predicted_next_frame.numpy())

print("All frames processed.")

Processed frame 0 to 1
Processed frame 1 to 2
Processed frame 2 to 3
Processed frame 3 to 4
Processed frame 4 to 5
Processed frame 5 to 6
Processed frame 6 to 7
Processed frame 7 to 8
Processed frame 8 to 9
Processed frame 9 to 10
Processed frame 10 to 11
Processed frame 11 to 12
Processed frame 12 to 13
Processed frame 13 to 14
Processed frame 14 to 15
Processed frame 15 to 16
Processed frame 16 to 17
Processed frame 17 to 18
Processed frame 18 to 19
Processed frame 19 to 20
Processed frame 20 to 21
Processed frame 21 to 22
Processed frame 22 to 23
Processed frame 23 to 24
Processed frame 24 to 25
Processed frame 25 to 26
Processed frame 26 to 27
Processed frame 27 to 28
Processed frame 28 to 29
Processed frame 29 to 30
Processed frame 30 to 31
Processed frame 31 to 32
Processed frame 32 to 33
Processed frame 33 to 34
Processed frame 34 to 35
Processed frame 35 to 36
Processed frame 36 to 37
Processed frame 37 to 38
Processed frame 38 to 39
Processed frame 39 to 40
Processed frame 40 