In [191]:
import os
import glob
import torch
import torchvision.transforms.functional as F
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
from cv2 import waitKey, destroyAllWindows
from torchvision.models.optical_flow import raft_large, Raft_Large_Weights, raft_small, Raft_Small_Weights
from torchvision.utils import flow_to_image
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Configuration
image_folder = 'overDataSet1'  # folder holding the input frames
output_folder = 'lightFlowOutput'  # folder for the optical-flow visualisations
output_PRE_folder = 'lightFlowOutputPre'  # folder for the predicted frames
os.makedirs(output_folder, exist_ok=True)
os.makedirs(output_PRE_folder, exist_ok=True)

# Load the model (Raft_Large_Weights.DEFAULT is the alternative checkpoint).
weights1 = Raft_Large_Weights.C_T_SKHT_K_V2
weights2 = Raft_Small_Weights.DEFAULT  # not used by the model built below
model = raft_large(weights=weights1, progress=False).to(device)
model = model.eval()

# Note from the original author: image dimensions need to be divisible by 8.
# No resize/preprocessing step is applied here; frames are fed to the model as-is.

# Collect every PNG frame. The extension is compared case-insensitively because
# glob's '*.png' only matches '.PNG' files on case-insensitive filesystems.
# Frames are ordered by the integer after the last '_' in the file stem
# (e.g. 'frame_12.png' -> 12, '7.PNG' -> 7); ties fall back to the path string,
# which is the same order the original sort-then-stable-sort produced.
image_files = sorted(
    (
        path
        for path in glob.glob(os.path.join(image_folder, '*'))
        if os.path.splitext(path)[1].lower() == '.png'
    ),
    key=lambda path: (int(os.path.basename(path).split('.')[0].split('_')[-1]), path),
)
print(image_files )
num_frames = len(image_files)



['overDataSet1\\0.PNG', 'overDataSet1\\1.PNG', 'overDataSet1\\2.PNG', 'overDataSet1\\3.PNG', 'overDataSet1\\4.PNG', 'overDataSet1\\5.PNG', 'overDataSet1\\6.PNG', 'overDataSet1\\7.PNG', 'overDataSet1\\8.PNG', 'overDataSet1\\9.PNG', 'overDataSet1\\10.PNG', 'overDataSet1\\11.PNG', 'overDataSet1\\12.PNG', 'overDataSet1\\13.PNG', 'overDataSet1\\14.PNG', 'overDataSet1\\15.PNG', 'overDataSet1\\16.PNG', 'overDataSet1\\17.PNG', 'overDataSet1\\18.PNG', 'overDataSet1\\19.PNG', 'overDataSet1\\20.PNG', 'overDataSet1\\21.PNG', 'overDataSet1\\22.PNG', 'overDataSet1\\23.PNG', 'overDataSet1\\24.PNG', 'overDataSet1\\25.PNG', 'overDataSet1\\26.PNG', 'overDataSet1\\27.PNG', 'overDataSet1\\28.PNG', 'overDataSet1\\29.PNG', 'overDataSet1\\30.PNG', 'overDataSet1\\31.PNG', 'overDataSet1\\32.PNG', 'overDataSet1\\33.PNG', 'overDataSet1\\34.PNG', 'overDataSet1\\35.PNG', 'overDataSet1\\36.PNG', 'overDataSet1\\37.PNG', 'overDataSet1\\38.PNG', 'overDataSet1\\39.PNG', 'overDataSet1\\40.PNG', 'overDataSet1\\41.PNG', '

In [192]:
# 帧预测函数
def apply_flow(image, flow):
    """Warp ``image`` by sampling it at positions displaced by ``flow``.

    Each output pixel ``(x, y)`` receives the value of ``image`` at
    ``(x + flow_x, y + flow_y)``, with the target position clamped to the image
    bounds and sampled by bicubic interpolation. The mean x/y flow is printed
    as a diagnostic.

    NOTE(review): because the output pixel *reads from* ``p + flow(p)``, this is
    a backward warp. Confirm that this direction matches how the flow passed in
    was estimated.

    Args:
        image: Tensor of shape (N, C, H, W) to sample from.
        flow: Tensor of shape (N, 2, H, W). Channel 0 is the x displacement and
            channel 1 the y displacement, in pixels. The spatial size is assumed
            to match ``image`` because it is used to normalise the coordinates.

    Returns:
        The warped tensor, with the spatial size of ``flow`` and the device and
        dtype of ``image``.
    """
    flow = flow.permute(0, 2, 3, 1)  # (N, H, W, 2)
    h, w = flow.shape[1:3]

    # Build the pixel-coordinate grid directly on the flow's device.
    y_coords, x_coords = torch.meshgrid(
        torch.arange(h, dtype=torch.float32, device=flow.device),
        torch.arange(w, dtype=torch.float32, device=flow.device),
        indexing='ij',
    )

    # Displace the coordinates and keep them inside the image.
    new_x_coords = torch.clamp(x_coords + flow[:, :, :, 0], 0, w - 1)
    new_y_coords = torch.clamp(y_coords + flow[:, :, :, 1], 0, h - 1)

    # Normalise to [-1, 1] as grid_sample expects. max(..., 1) avoids a 0/0
    # (NaN grid) when an axis has a single pixel; that axis then maps to -1,
    # which align_corners=True resolves to index 0.
    new_x_coords = (new_x_coords / max(w - 1, 1)) * 2 - 1
    new_y_coords = (new_y_coords / max(h - 1, 1)) * 2 - 1

    grid = torch.stack([new_x_coords, new_y_coords], dim=-1)
    # grid_sample requires the grid to share the input's device and dtype.
    grid = grid.to(device=image.device, dtype=image.dtype)
    warped_image = torch.nn.functional.grid_sample(
        image, grid, mode='bicubic', padding_mode='border', align_corners=True
    )

    mean_flow_x = flow[:, :, :, 0].mean()
    mean_flow_y = flow[:, :, :, 1].mean()
    print(f'mean_flow_x={mean_flow_x}, mean_flow_y={mean_flow_y}')
    return warped_image



# Sliding-window settings for the flow loop below.
# step: number of consecutive frame pairs whose flows are combined per window.
step: int = 5
# max_flow_step: bound intended for limit_flow_step; the call that would use it
# is currently commented out in the loop.
max_flow_step: float = 5.0

# 累加光流函数
def accumulate_flow(accumulated_flow, new_flow):
    """Add ``new_flow`` to the running total and return the new total.

    Args:
        accumulated_flow: Running sum of flow tensors, or ``None`` when nothing
            has been accumulated yet.
        new_flow: Flow tensor to add.

    Returns:
        A new tensor holding the updated sum. When ``accumulated_flow`` is
        ``None`` this is a copy of ``new_flow``, so the result never aliases
        the caller's tensor.
    """
    if accumulated_flow is None:
        # First contribution: start the total from a copy.
        return new_flow.clone()
    return accumulated_flow + new_flow

# 限制光流步长函数
def limit_flow_step(flow, max_step):
    """Clamp every element of ``flow`` to the closed range [-max_step, max_step]."""
    lower_bound = -max_step
    upper_bound = max_step
    clipped = torch.clamp(flow, min=lower_bound, max=upper_bound)
    return clipped

# 进行光流预测
print(num_frames)

# Divisor of the per-pair weights (j + 1) / 6 used when combining flows.
# NOTE(review): 6 equals step + 1 for the current step = 5; confirm whether the
# weighting is meant to follow `step` before changing either value.
flow_weight_denominator = 6

# Final RAFT flow for each consecutive frame pair (k, k + 1), keyed by k.
# Overlapping windows share all but one pair, so caching runs the model once
# per pair instead of `step` times. Reset on every run of this cell.
pair_flow_cache = {}


def _load_frame_tensor(index):
    """Load frame ``index`` of ``image_files`` as an RGB tensor with a leading batch axis."""
    with Image.open(image_files[index]) as frame:
        return F.to_tensor(frame.convert('RGB')).unsqueeze(0)


def _pair_flow(index):
    """Return the model's last flow estimate from frame ``index`` to ``index + 1`` as a CPU tensor."""
    if index not in pair_flow_cache:
        with torch.no_grad():
            list_of_flows = model(
                _load_frame_tensor(index).to(device),
                _load_frame_tensor(index + 1).to(device),
            )
            # The model returns a list of estimates; the last one is used.
            pair_flow_cache[index] = list_of_flows[-1].cpu()
    return pair_flow_cache[index]


# Slide a window of `step` consecutive frame pairs over the sequence.
for i in range(num_frames - step):
    accumulated_flow = None  # restart the weighted sum for every window
    for j in range(step):
        print('img1=',i + j)
        print('img2=',i + j + 1)

        predicted_flows = _pair_flow(i + j)

        # Optional per-element clamp of the flow (disabled):
        # predicted_flows = limit_flow_step(predicted_flows, max_flow_step)

        # Pairs are weighted 1/6, 2/6, ... so later pairs contribute more.
        weighted_flow = predicted_flows * ((j + 1) / flow_weight_denominator)
        if accumulated_flow is None:
            accumulated_flow = weighted_flow
        else:
            accumulated_flow = accumulate_flow(accumulated_flow, weighted_flow)

    # Pair i is not part of any later window; drop it to bound memory use.
    pair_flow_cache.pop(i, None)

    print('shape=',predicted_flows.shape)

    # NOTE(review): the file is named 'accumulated_flow_...' but the image is
    # rendered from predicted_flows (the window's last pair), not from
    # accumulated_flow. Confirm which one is intended.
    flow_imgs = flow_to_image(predicted_flows)

    # Save the flow visualisation.
    output_file = os.path.join(output_folder, f'accumulated_flow_{i:03d}_to_{i+step:03d}.png')
    plt.imsave(output_file, flow_imgs.squeeze().permute(1, 2, 0).numpy())

    print(f"Processed frame {i} to {i+step}")

    # Predict a frame by warping the window's last frame with the combined flow.
    img2_tensor = _load_frame_tensor(i + step)
    predicted_next_frame = apply_flow(img2_tensor, accumulated_flow)

    # Save the prediction, clamped to the valid [0, 1] image range.
    output_file = os.path.join(output_PRE_folder, f'predicted{i+step+1:03d}using{i:03d}to{i+step:03d}.png')
    predicted_next_frame = predicted_next_frame.squeeze().permute(1, 2, 0).clamp(0, 1)
    plt.imsave(output_file, predicted_next_frame.numpy())
    print("---")

print("All frames processed.")

241
img1= 0
img2= 1
img1= 1
img2= 2
img1= 2
img2= 3
img1= 3
img2= 4
img1= 4
img2= 5
shape= torch.Size([1, 2, 200, 200])
Processed frame 0 to 5
mean_flow_x=16.45806312561035, mean_flow_y=14.311131477355957
---
img1= 1
img2= 2
img1= 2
img2= 3
img1= 3
img2= 4
img1= 4
img2= 5
img1= 5
img2= 6
shape= torch.Size([1, 2, 200, 200])
Processed frame 1 to 6
mean_flow_x=14.06870174407959, mean_flow_y=12.397560119628906
---
img1= 2
img2= 3
img1= 3
img2= 4
img1= 4
img2= 5
img1= 5
img2= 6
img1= 6
img2= 7
shape= torch.Size([1, 2, 200, 200])
Processed frame 2 to 7
mean_flow_x=10.814901351928711, mean_flow_y=8.46880054473877
---
img1= 3
img2= 4
img1= 4
img2= 5
img1= 5
img2= 6
img1= 6
img2= 7
img1= 7
img2= 8
shape= torch.Size([1, 2, 200, 200])
Processed frame 3 to 8
mean_flow_x=10.693026542663574, mean_flow_y=3.6985394954681396
---
img1= 4
img2= 5
img1= 5
img2= 6
img1= 6
img2= 7
img1= 7
img2= 8
img1= 8
img2= 9
shape= torch.Size([1, 2, 200, 200])
Processed frame 4 to 9
mean_flow_x=9.416265487670898, mean_f