# In [None]:
import cv2
import numpy as np
import torch
from diffusers import StableDiffusionInpaintPipeline
from PIL import Image
from tqdm import tqdm

def _read_video_frames(path, color_conversion):
    """Decode every frame of the video at *path*, converting each with *color_conversion*.

    Args:
        path: File path handed to ``cv2.VideoCapture``.
        color_conversion: An OpenCV ``cv2.COLOR_*`` code applied to each
            decoded BGR frame.

    Returns:
        A tuple ``(frames, fps)``: the list of converted frames and the frame
        rate OpenCV reports for the container.

    Raises:
        OSError: If OpenCV cannot open *path*.
    """
    capture = cv2.VideoCapture(path)
    try:
        # VideoCapture does not raise on a bad path; it just yields no frames.
        if not capture.isOpened():
            raise OSError(f"Could not open video file: {path}")
        fps = capture.get(cv2.CAP_PROP_FPS)
        frames = []
        while True:
            ok, frame = capture.read()
            if not ok:
                break
            frames.append(cv2.cvtColor(frame, color_conversion))
    finally:
        capture.release()
    return frames, fps


def _write_video(path, frames, fps, frame_size):
    """Encode *frames* (BGR or grayscale uint8-compatible arrays) to an XVID video.

    Args:
        path: Output file path.
        frames: Iterable of frames whose size is *frame_size*.
        fps: Frame rate written into the container.
        frame_size: ``(width, height)`` tuple, as ``cv2.VideoWriter`` expects.

    Raises:
        OSError: If OpenCV cannot open a writer for *path*.
    """
    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    writer = cv2.VideoWriter(path, fourcc, fps, frame_size)
    try:
        if not writer.isOpened():
            raise OSError(f"Could not open video writer for: {path}")
        for frame in frames:
            frame = frame.astype(np.uint8)
            if frame.ndim == 2:  # convert grayscale to BGR
                frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR)
            writer.write(frame)
    finally:
        writer.release()


def inpaint_video_borders(video_path, mask_path, output_path, prompt, batch_size=64):
    """Inpaint the masked regions of every frame of a video with Stable Diffusion 2.

    Args:
        video_path: Path to the source video.
        mask_path: Path to a mask video with the same frame count. Frames are
            converted to grayscale; in the diffusers inpainting convention,
            white pixels mark the area to repaint.
        output_path: Where the inpainted XVID video is written.
        prompt: Text prompt passed to the inpainting pipeline for every frame.
        batch_size: Number of frames grouped per progress-bar step. Frames are
            still sent to the pipeline one at a time.

    Raises:
        ValueError: If ``batch_size`` is not positive, the two videos have
            different frame counts, or the video contains no frames.
        OSError: If an input video cannot be opened or the output cannot be
            written.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer.")

    frames, fps = _read_video_frames(video_path, cv2.COLOR_BGR2RGB)
    mask_frames, _ = _read_video_frames(mask_path, cv2.COLOR_BGR2GRAY)

    if len(frames) != len(mask_frames):
        raise ValueError("The number of frames in the video and mask must be the same.")
    if not frames:
        raise ValueError(f"No frames could be read from {video_path}.")

    # Use the decoded frame size so the writer matches what we actually feed it.
    height, width = frames[0].shape[:2]

    # `torch_dtype` is the documented from_pretrained argument for half precision.
    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-inpainting",
        torch_dtype=torch.float16,
    )
    pipe = pipe.to("cuda")

    output_frames = []
    for start in tqdm(range(0, len(frames), batch_size), desc="Processing batches"):
        stop = start + batch_size
        for frame, mask in zip(frames[start:stop], mask_frames[start:stop]):
            result = pipe(
                prompt=prompt,
                image=Image.fromarray(frame),
                mask_image=Image.fromarray(mask),
            ).images[0]
            inpainted = np.asarray(result)
            # The pipeline renders at its own default resolution, which need not
            # match the input; cv2.VideoWriter drops frames of the wrong size.
            if inpainted.shape[:2] != (height, width):
                inpainted = cv2.resize(inpainted, (width, height), interpolation=cv2.INTER_LINEAR)
            output_frames.append(cv2.cvtColor(inpainted, cv2.COLOR_RGB2BGR))

    _write_video(output_path, output_frames, fps, (width, height))
    print(f"Inpainted video saved to {output_path}")

if __name__ == '__main__':
    # Demo: synthesize a noise video plus a border mask, then inpaint the border.
    import os
    import tempfile

    output_video_path = "output_video.avi"
    frame_width, frame_height = 320, 240
    demo_fps = 30.0
    num_frames = 90  # 3 second video at 30 fps
    border_px = 10  # width of the white (to-be-inpainted) band on each edge

    # Synthetic inputs go to a scratch directory so the demo never overwrites
    # real footage (e.g. the dataset files under oneVideo/).
    with tempfile.TemporaryDirectory() as scratch_dir:
        dummy_video_path = os.path.join(scratch_dir, "dummy_video.avi")
        dummy_mask_path = os.path.join(scratch_dir, "dummy_mask.avi")

        # Create a dummy video of uniform random noise in the full 0..255 range.
        rng = np.random.default_rng(0)
        out = cv2.VideoWriter(
            dummy_video_path, cv2.VideoWriter_fourcc(*'XVID'), demo_fps, (frame_width, frame_height)
        )
        if not out.isOpened():
            raise OSError(f"Could not open video writer for: {dummy_video_path}")
        try:
            for _ in range(num_frames):
                out.write(rng.integers(0, 256, (frame_height, frame_width, 3), dtype=np.uint8))
        finally:
            out.release()

        # Create a dummy mask video: black interior, white band along every edge.
        mask_frame = np.zeros((frame_height, frame_width), dtype=np.uint8)
        mask_frame[:border_px, :] = 255
        mask_frame[-border_px:, :] = 255
        mask_frame[:, :border_px] = 255
        mask_frame[:, -border_px:] = 255
        out_mask = cv2.VideoWriter(
            dummy_mask_path,
            cv2.VideoWriter_fourcc(*'XVID'),
            demo_fps,
            (frame_width, frame_height),
            isColor=False,
        )
        if not out_mask.isOpened():
            raise OSError(f"Could not open video writer for: {dummy_mask_path}")
        try:
            for _ in range(num_frames):
                out_mask.write(mask_frame)
        finally:
            out_mask.release()

        inpaint_video_borders(
            video_path=dummy_video_path,
            mask_path=dummy_mask_path,
            output_path=output_video_path,
            prompt="inpaint the borders",
            batch_size=8,
        )