In [1]:
import os
import argparse
import struct
import cv2
import numpy as np
import torch
import torch.multiprocessing as mp
from typing import List, Tuple
from tqdm import tqdm
from gen_back_flow import *
# RAFT + helpers --------------------------------------------------------------
from raft import RAFT                      # <‑ ensure RAFT repo in PYTHONPATH
from utils.utils import InputPadder 
from pathlib import Path
import glob


image_paths = sorted(glob.glob('../UniControl_Video_Interpolation/data/jockey/images/*.png'))

gop_size = 8
pairs = []
out_paths = []

for i in range(0, len(image_paths) - gop_size + 1, gop_size):
    gop = image_paths[i:i + gop_size]
    ref_frame = gop[0]
    for j in range(1, len(gop)):
        pairs.append((ref_frame, gop[j]))
        out_paths.append('../UniControl_Video_Interpolation/data/jockey/optical_flow/'+ str(Path(gop[j]).stem + '.flo'))

# ✅ Print result
for pair in pairs:
    print(pair)

print(f"\nTotal flow pairs: {len(pairs)}")
print(out_paths)

('../UniControl_Video_Interpolation/data/jockey/images/im00001.png', '../UniControl_Video_Interpolation/data/jockey/images/im00002.png')
('../UniControl_Video_Interpolation/data/jockey/images/im00001.png', '../UniControl_Video_Interpolation/data/jockey/images/im00003.png')
('../UniControl_Video_Interpolation/data/jockey/images/im00001.png', '../UniControl_Video_Interpolation/data/jockey/images/im00004.png')
('../UniControl_Video_Interpolation/data/jockey/images/im00001.png', '../UniControl_Video_Interpolation/data/jockey/images/im00005.png')
('../UniControl_Video_Interpolation/data/jockey/images/im00001.png', '../UniControl_Video_Interpolation/data/jockey/images/im00006.png')
('../UniControl_Video_Interpolation/data/jockey/images/im00001.png', '../UniControl_Video_Interpolation/data/jockey/images/im00007.png')
('../UniControl_Video_Interpolation/data/jockey/images/im00001.png', '../UniControl_Video_Interpolation/data/jockey/images/im00008.png')
('../UniControl_Video_Interpolation/data/

In [2]:
def load_image(path: str, device: torch.device) -> torch.Tensor:
    img = cv2.imread(path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (512,512))
    tensor = torch.from_numpy(img).permute(2, 0, 1).float().unsqueeze(0) / 255.0
    return tensor.to(device)

# -----------------------------------------------------------------------------
# Core batch routine
# -----------------------------------------------------------------------------

def process_batch(model: torch.nn.Module,
                  pairs: List[Tuple[str, str]],
                  out_paths: List[str],
                  device: torch.device) -> None:
    imgs1, imgs2, unpads = [], [], []
    for ref_path, tgt_path in pairs:
        ref = load_image(ref_path, device)
        tgt = load_image(tgt_path, device)
        padder = InputPadder(ref.shape)
        ref, tgt = padder.pad(ref, tgt)
        imgs1.append(ref)
        imgs2.append(tgt)
        unpads.append(padder.unpad)

    imgs1 = torch.cat(imgs1, dim=0)
    imgs2 = torch.cat(imgs2, dim=0)

    with torch.no_grad():
        _, flows = model(imgs1, imgs2, iters=20, test_mode=True)

    # Write each flow
    for flow, unpad, out_path in zip(flows, unpads, out_paths):
        flow = unpad(flow[None])[0].permute(1, 2, 0).cpu().numpy()
        write_flo_file(flow, out_path)

In [3]:
device = torch.device("cuda:1")

# Load model **after** setting device
raft_args = argparse.Namespace(small=False, mixed_precision=False, alternate_corr=False, model='models/raft-sintel.pth')
model = RAFT(raft_args)
state_dict = torch.load('models/raft-sintel.pth', map_location=device)
# Remove 'module.' prefix if it exists
new_state_dict = {}
for k, v in state_dict.items():
    if k.startswith('module.'):
        new_k = k[len('module.'):]
    else:
        new_k = k
    new_state_dict[new_k] = v

model.load_state_dict(new_state_dict, strict=False) 
# model.load_state_dict(torch.load(ckpt, map_location=device))
model.to(device)
model.eval()
process_batch(model, pairs[0:14], out_paths[0:14], device)

  state_dict = torch.load('models/raft-sintel.pth', map_location=device)
  with autocast(enabled=self.args.mixed_precision):
  with autocast(enabled=self.args.mixed_precision):
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  with autocast(enabled=self.args.mixed_precision):
