In [20]:
# !pip install ffmpeg-python, !pip uninstall gradio

In [21]:
from model.RIFE import Model
import os
import cv2
import torch
import argparse
from torch.nn import functional as F
import warnings
import numpy as np
from PIL import Image
from tqdm import tqdm 
warnings.filterwarnings("ignore")

In [None]:
# Build the RIFE interpolation model and load its flow-network weights onto the CPU.
model = Model()
# NOTE: Change the checkpoint directory ('./new_train_log') to wherever
# 'flownet.pkl' lives on your machine.
# NOTE(review): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source.
model.flownet.load_state_dict(torch.load(os.path.join('./new_train_log', 'flownet.pkl'), map_location=torch.device('cpu')))
# Presumably switches the network to evaluation mode -- confirm against model.RIFE.
model.eval()

def save_video(frames, video_path, fps=30):
    """Write a sequence of RGB frames to ``video_path`` using the XVID codec.

    Args:
        frames: Non-empty sequence of image arrays in RGB channel order. The
            first frame's height and width define the size of the video.
        video_path: Destination file path handed to ``cv2.VideoWriter``.
        fps: Frame rate of the written video. Defaults to 30.

    Raises:
        ValueError: If ``frames`` is empty.
        RuntimeError: If OpenCV could not open a writer for ``video_path``.
    """
    if len(frames) == 0:
        raise ValueError("save_video() needs at least one frame")

    height, width = frames[0].shape[:2]
    # VideoWriter expects the frame size as (width, height).
    writer = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*"xvid"), fps, (width, height))
    try:
        if not writer.isOpened():
            raise RuntimeError(f"Could not open video writer for {video_path!r}")
        for frame in frames:
            # OpenCV writes BGR, while the pipeline works in RGB.
            writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    finally:
        # Always finalize the file, even if a frame failed to convert.
        writer.release()

def generate_intermidiate_frames(img0, img1, exp_num_frames):
    """Interpolate between two frames by repeated midpoint insertion.

    Every pass asks the model for one frame between each adjacent pair, so
    after ``exp_num_frames`` passes the result holds
    ``2 ** exp_num_frames + 1`` images: ``img0``, the interpolated frames,
    and ``img1``.

    Args:
        img0: First frame as a channel-last (H x W x C) uint8 array.
        img1: Second frame; must have the same shape as ``img0``.
        exp_num_frames: Number of doubling passes (converted with ``int``).

    Returns:
        List of PIL RGB images, each at the original H x W size.

    Raises:
        ValueError: If the two frames do not have the same shape.
    """
    if img0.shape != img1.shape:
        raise ValueError(
            f"Both frames must have the same shape, got {img0.shape} and {img1.shape}"
        )
    # UI sliders may deliver the value as a float; range() needs an int.
    passes = int(exp_num_frames)

    # HWC uint8 -> 1 x C x H x W float in [0, 1].
    start = (torch.tensor(img0.transpose(2, 0, 1)) / 255.).unsqueeze(0)
    end = (torch.tensor(img1.transpose(2, 0, 1)) / 255.).unsqueeze(0)

    # Pad the bottom/right edges so height and width are multiples of 32
    # (presumably required by the model's downsampling -- confirm in model.RIFE).
    _, _, h, w = start.shape
    ph = ((h - 1) // 32 + 1) * 32
    pw = ((w - 1) // 32 + 1) * 32
    padding = (0, pw - w, 0, ph - h)
    start = F.pad(start, padding)
    end = F.pad(end, padding)

    img_list = [start, end]
    # Pure inference: skip autograd bookkeeping to save memory.
    with torch.no_grad():
        for _ in range(passes):
            tmp = []
            for left, right in zip(img_list, img_list[1:]):
                tmp.append(left)
                tmp.append(model.inference(left, right))
            tmp.append(end)
            img_list = tmp

    final_images = []
    for img in img_list:
        # Drop the batch axis and crop away the padding so the output matches
        # the source size; round and clamp so the uint8 cast neither truncates
        # (e.g. 254.999 -> 254) nor wraps around on out-of-range values.
        pixels = (img[0, :, :h, :w] * 255).round().clamp(0, 255).permute(1, 2, 0)
        array = pixels.to(torch.uint8).cpu().numpy()
        final_images.append(Image.fromarray(array).convert('RGB'))

    return final_images
            

def inference(type, data, exp_num_frames):
    """Run frame interpolation on a video file or on a pair of images.

    Args:
        type: ``0`` to process a video, ``1`` to process two images.
        data: For ``type == 0`` the path of the input video; for
            ``type == 1`` a two-element sequence of images (PIL images or
            arrays).
        exp_num_frames: Number of interpolation passes applied between each
            pair of consecutive frames.

    Returns:
        For ``type == 0`` a tuple ``(video_path, frame_list)``: the path
        returned by ``convert_to_mp4`` and the list of PIL frames.
        For ``type == 1`` the list of PIL images from
        ``generate_intermidiate_frames``.

    Raises:
        ValueError: If no frame could be read from the video, or if ``type``
            is neither 0 nor 1.
    """
    if type == 0:
        frames = extract_frames(data)
        if not frames:
            raise ValueError(f"No frames could be read from video: {data!r}")

        frame_list = []
        for i in tqdm(range(len(frames) - 1)):
            tmp_list = generate_intermidiate_frames(frames[i], frames[i + 1], exp_num_frames)
            # Drop the trailing frame: it is the first frame of the next pair.
            frame_list += tmp_list[:-1]
        frame_list.append(Image.fromarray(frames[-1]))

        # The result is written next to the input, at save_video's default
        # frame rate rather than the source video's rate.
        video_path = os.path.splitext(data)[0] + "_processed.avi"
        save_video([np.array(img) for img in frame_list], video_path)

        return convert_to_mp4(video_path), frame_list

    if type == 1:
        img0, img1 = data
        # Uploaded PIL images may be RGBA or grayscale; normalize to 3-channel
        # RGB so the channel-last -> channel-first conversion downstream works.
        if isinstance(img0, Image.Image):
            img0 = img0.convert("RGB")
        if isinstance(img1, Image.Image):
            img1 = img1.convert("RGB")
        return generate_intermidiate_frames(np.array(img0), np.array(img1), exp_num_frames)

    raise ValueError(f"Unknown inference type {type!r}; expected 0 (video) or 1 (images)")
        
        

### Gradio App

In [46]:
import gradio as gr
import numpy as np
import cv2
from PIL import Image
import os
import ffmpeg

In [None]:
def convert_to_mp4(video_path):
    """Re-encode ``video_path`` as an H.264/AAC MP4 stored next to the input.

    The output file is named ``<input stem>_converted.mp4``. Conversion is
    best-effort: if anything goes wrong the error is printed and the original
    path is returned unchanged.

    Args:
        video_path: Path of the video to convert.

    Returns:
        Path of the converted file, or ``video_path`` if conversion failed.
    """
    stem = os.path.splitext(video_path)[0]
    output_path = stem + "_converted.mp4"
    try:
        source = ffmpeg.input(video_path)
        job = source.output(output_path, vcodec='libx264', acodec='aac')
        job.run(overwrite_output=True)
    except Exception as e:
        print(f"Error converting video: {e}")
        return video_path
    else:
        return output_path
    
def process_video(video, exp_num_frames):
    """Interpolate a whole video.

    Args:
        video: Path of the input video.
        exp_num_frames: Number of interpolation passes per frame pair.

    Returns:
        Tuple ``(processed video path, list of output frames)`` as produced
        by ``inference`` in video mode.
    """
    result_path, interpolated_frames = inference(0, video, exp_num_frames)
    return (result_path, interpolated_frames)

def extract_frames(video_path):
    """Decode every frame of a video into a list of RGB arrays.

    Args:
        video_path: Path handed to ``cv2.VideoCapture``.

    Returns:
        The decoded frames in reading order, converted from OpenCV's BGR to
        RGB. The list is empty when the capture cannot be opened or yields
        no frame.
    """
    frames = []
    cap = cv2.VideoCapture(video_path)
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        # Release the capture handle even if a frame fails to convert.
        cap.release()
    return frames

def process_images(img1, img2, exp_num_frames):
    """Interpolate between two still images.

    Args:
        img1: First image.
        img2: Second image.
        exp_num_frames: Number of interpolation passes.

    Returns:
        Whatever ``inference`` returns in image mode (type 1) for the pair.
    """
    image_pair = [img1, img2]
    return inference(1, image_pair, exp_num_frames)

def handle_video_upload(video):
    """Convert a freshly uploaded video right away.

    Args:
        video: Path of the uploaded video.

    Returns:
        The path returned by ``convert_to_mp4`` for that video.
    """
    return convert_to_mp4(video)

def interpolate(choice, video=None, img1=None, img2=None, exp_num_frames=None):
    """Dispatch a "Process" request to the video or the image pipeline.

    Always returns a 3-tuple so that the three UI outputs wired to this
    handler receive a value on every path.

    Args:
        choice: ``"Upload a video"`` or ``"Upload two images"``.
        video: Path of the uploaded video (video mode).
        img1: First image (image mode).
        img2: Second image (image mode).
        exp_num_frames: Number of interpolation passes, forwarded unchanged.

    Returns:
        ``(processed_video, generated_images, video_frames)``. Entries that
        do not apply to the selected mode are ``None``; all three are ``None``
        when the choice is unset/unknown or the required input is missing.
    """
    nothing = (None, None, None)

    if choice == "Upload a video":
        if video is None:
            return nothing
        processed_video, frames = process_video(video, exp_num_frames)
        return processed_video, None, frames

    if choice == "Upload two images":
        if img1 is None or img2 is None:
            return nothing
        return None, process_images(img1, img2, exp_num_frames), None

    # No (or an unrecognized) input type selected.
    return nothing

# Build the Gradio UI: an input-type selector, a shared pass-count slider, one
# row of video widgets, one row of image widgets and a single "Process" button.
with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align: center;'>Video and Image Interpolation</h1>")
    choice = gr.Radio(["Upload a video", "Upload two images"], label="Choose Input Type")


    # Forwarded to interpolate() as exp_num_frames: the number of interpolation
    # passes n. Each pass inserts one frame between every adjacent pair, hence
    # the "2^n-1" new frames per input pair mentioned in the label.
    num_frames = gr.Slider(1, 10, value=2, step=1, label="Number of Interpolated Frames (2^n-1)")

    # Video input and processing
    with gr.Row() as video_section:
        video_input = gr.Video(label="Upload Video", interactive=True)
        video_output = gr.Video(label="Processed Video", interactive=True)
        frame_gallery = gr.Gallery(label="Processed Video Frames")

        
    # Image input and processing
    with gr.Row() as image_section:
        img1_input = gr.Image(label="Upload First Image", type="pil")
        img2_input = gr.Image(label="Upload Second Image", type="pil")
        # num_frames = gr.Slider(1, 5, value=2, step=1, label="Number of Interpolated Frames")
        image_output = gr.Gallery(label="Generated Frames")



    
    # Initially hide both sections
    # NOTE(review): visibility is set by attribute assignment after the rows
    # were constructed; confirm the installed Gradio version honors this (the
    # constructor also accepts a `visible=` argument).
    video_section.visible = False
    image_section.visible = False
    
    # Toggle input visibility based on choice
    def toggle_inputs(choice):
        """Return visibility updates for the selected input type.

        The four updates map, in order, to video_section, image_section,
        video_output and frame_gallery. For any value other than the two
        radio options the function falls through and returns None.
        """
        if choice == "Upload a video":
            return gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)
        elif choice == "Upload two images":
            return gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)
    
    choice.change(toggle_inputs, inputs=[choice], outputs=[video_section, image_section, video_output, frame_gallery])

    # Handle video upload to convert and display: the uploaded file is passed
    # through convert_to_mp4 and the returned path is written back into the
    # same video component.
    video_input.upload(convert_to_mp4, inputs=video_input, outputs=video_input)

    # Processing button
    process_button = gr.Button("Process")

    # Ensure `interpolate` always returns outputs for video, gallery, and frames:
    # its three return values map, in order, to video_output, image_output and
    # frame_gallery.
    process_button.click(
        interpolate,
        inputs=[choice, video_input, img1_input, img2_input, num_frames],
        outputs=[video_output, image_output, frame_gallery]
    )

    gr.Markdown("### Instructions")
    gr.Markdown("""
        1. **Select** "Upload a video" to process an entire video offline or "Upload two images" to interpolate images.
        2. **Adjust Settings** (e.g., number of frames) if available.
        3. **Click Process** to view the interpolated video or frames. 
        """)
    
# Start the app server with debug=True.
demo.launch(debug=True)

In [None]:
# Shut down the Gradio server started by demo.launch().
demo.close()

Closing server running on port: 7860
