In [1]:
import os
import glob
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import argparse
from RAFT.core.raft import RAFT
from RAFT.core.utils import flow_viz
from RAFT.core.utils.utils import InputPadder

# Run on the GPU when one is available; fall back to CPU so the notebook can
# still execute (slowly) on machines without CUDA.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [37]:
# RAFT is configured through an argparse-style namespace; these are the flags
# this notebook sets explicitly.
args = argparse.Namespace()
args.small = False  # full-size RAFT rather than the "small" variant
# Half precision is only useful on the GPU; presumably RAFT's autocast is a
# no-op (with a warning) on CPU — TODO confirm against RAFT.core.raft.
args.mixed_precision = (DEVICE == 'cuda')
args.alternate_corr = False

# Wrap in DataParallel before loading: the checkpoint's keys presumably carry
# the "module." prefix that DataParallel adds — TODO confirm for raft-things.pth.
model = torch.nn.DataParallel(RAFT(args))
# weights_only=True loads the tensors without unpickling arbitrary Python
# objects (the cell output shows a warning raised from the previous,
# argument-less torch.load call). map_location lets a GPU-saved checkpoint
# load on whatever DEVICE is in use.
state_dict = torch.load("raft-things.pth", map_location=DEVICE, weights_only=True)
model.load_state_dict(state_dict)
model = model.eval().to(DEVICE)

print("RAFT model loaded successfully!")

def load_image(imfile):
    """Load an image file as a batched float RGB tensor on DEVICE.

    Args:
        imfile: path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        Float tensor of shape (1, 3, H, W) holding the raw 0-255 pixel values
        (not rescaled), placed on ``DEVICE``.
    """
    # The context manager closes the underlying file once the pixels are read.
    # convert("RGB") yields 8-bit channels, so the array is already uint8.
    with Image.open(imfile) as img:
        rgb = np.array(img.convert("RGB"))  # (H, W, 3)
    # HWC -> CHW, then prepend a batch axis.
    chw = torch.from_numpy(rgb).permute(2, 0, 1).float()
    return chw[None].to(DEVICE)

def flow_to_mask(flow, threshold):
    """Turn a 2-channel optical-flow tensor into a binary motion mask.

    The per-pixel magnitude sqrt(u^2 + v^2) is min-max normalised to [0, 1]
    *within this one flow field*, then thresholded.

    Args:
        flow: torch tensor of shape (2, H, W) or (N, 2, H, W). Channel 0 is
            taken as u and channel 1 as v. With a batch axis, only item 0 is
            used.
        threshold: cut-off on the normalised magnitude; pixels strictly above
            it are marked as moving.

    Returns:
        np.ndarray of dtype uint8 and shape (H, W): 255 for moving pixels,
        0 elsewhere.

    Note:
        Normalisation is relative to this field's own range, so a pair with
        only noise-level motion still marks its largest magnitudes as moving.
        A constant field (including all-zero flow) gives an all-zero mask.
    """
    if flow.dim() == 4:
        flow = flow[0]
    # detach(): .numpy() refuses tensors that carry autograd history.
    # float(): a half-precision flow would overflow to inf when squared, turning
    # the normalised magnitude into NaN and silently emptying the mask.
    flow_np = flow.detach().float().cpu().numpy()
    u, v = flow_np[0], flow_np[1]
    magnitude = np.hypot(u, v)
    # The 1e-8 guards against division by zero for a constant field.
    magnitude = (magnitude - magnitude.min()) / (magnitude.max() - magnitude.min() + 1e-8)
    return (magnitude > threshold).astype(np.uint8) * 255

# Before/after frame pairs for the "Unstable" split; one motion mask is
# written per pair.
before_path = "Test_Dataset/Unstable/Before/*.png"
after_path = "Test_Dataset/Unstable/After/*.png"
motion_output_dir = "Test_Dataset/UnstableMotion/"

FLOW_ITERS = 20       # RAFT refinement iterations at inference
MASK_THRESHOLD = 0.2  # cut-off on the per-pair min-max normalised flow magnitude

os.makedirs(motion_output_dir, exist_ok=True)

before_images = sorted(glob.glob(before_path))
after_images = sorted(glob.glob(after_path))


def _pair_key(path, tag):
    """Return the file stem of `path` with a trailing `_<tag>` removed."""
    stem = os.path.splitext(os.path.basename(path))[0]
    suffix = f"_{tag}"
    return stem[:-len(suffix)] if stem.endswith(suffix) else stem


# Pair frames by name (render_00_before.png <-> render_00_after.png), not by
# position. zip() over two sorted lists silently mis-pairs every later frame
# when one side is missing a file, and silently drops the surplus.
before_by_key = {_pair_key(p, "before"): p for p in before_images}
after_by_key = {_pair_key(p, "after"): p for p in after_images}
unmatched = sorted(before_by_key.keys() ^ after_by_key.keys())
if unmatched:
    print(f"Warning: {len(unmatched)} image(s) have no before/after partner and are skipped: {unmatched}")
image_pairs = [
    (before_by_key[key], after_by_key[key])
    for key in sorted(before_by_key.keys() & after_by_key.keys())
]
if not image_pairs:
    print(f"Warning: no before/after pairs found for {before_path} and {after_path}")

for before_img, after_img in image_pairs:
    print(f"Processing: {before_img} and {after_img}")
    image1 = load_image(before_img)
    image2 = load_image(after_img)

    # Pad both frames to a size RAFT accepts; unpad() below presumably crops
    # the flow back to the original frame size (see RAFT's InputPadder).
    padder = InputPadder(image1.shape)
    image1, image2 = padder.pad(image1, image2)

    with torch.no_grad():
        # Only the second output (the upsampled flow) is used.
        _, flow_up = model(image1, image2, iters=FLOW_ITERS, test_mode=True)

    flow_up = padder.unpad(flow_up)
    mask = flow_to_mask(flow_up, threshold=MASK_THRESHOLD)

    # The mask is saved under the "after" frame's file name.
    output_filename = os.path.join(motion_output_dir, os.path.basename(after_img))
    Image.fromarray(mask).save(output_filename)

print("All optical flow masks have been processed and saved.")


  model.load_state_dict(torch.load("raft-things.pth"))


RAFT model loaded successfully!
Processing: Test_Dataset/Unstable/Before\render_00_before.png and Test_Dataset/Unstable/After\render_00_after.png
Processing: Test_Dataset/Unstable/Before\render_03_before.png and Test_Dataset/Unstable/After\render_03_after.png
Processing: Test_Dataset/Unstable/Before\render_04_before.png and Test_Dataset/Unstable/After\render_04_after.png
Processing: Test_Dataset/Unstable/Before\render_06_before.png and Test_Dataset/Unstable/After\render_06_after.png
Processing: Test_Dataset/Unstable/Before\render_07_before.png and Test_Dataset/Unstable/After\render_07_after.png
Processing: Test_Dataset/Unstable/Before\render_08_before.png and Test_Dataset/Unstable/After\render_08_after.png
Processing: Test_Dataset/Unstable/Before\render_101_before.png and Test_Dataset/Unstable/After\render_101_after.png
Processing: Test_Dataset/Unstable/Before\render_102_before.png and Test_Dataset/Unstable/After\render_102_after.png
Processing: Test_Dataset/Unstable/Before\render_105_