In [19]:
# Setup commands (run manually if needed): %pip install ffmpeg-python   |   %pip uninstall gradio

In [20]:
import os
import cv2
import torch
import argparse
from torch.nn import functional as F
import warnings
import numpy as np
from PIL import Image
from tqdm import tqdm 
warnings.filterwarnings("ignore")

In [None]:
# Load the interpolation model. The model type (e.g. a context-based model) is
# presumably selected in `choose_model`: `Model` and `pretrained_model_path` are not
# defined in this notebook and come from its star import.
from choose_model import *
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Model()
# NOTE: Change the path accordingly
# map_location remaps tensors saved on another device (e.g. a GPU checkpoint loaded
# on a CPU-only machine) onto `device`. torch.load unpickles the file, so only load
# trusted checkpoints.
# NOTE(review): this only maps the checkpoint tensors; which device `model` itself
# lives on is decided by Model() -- confirm it matches `device`.
model.flownet.load_state_dict(torch.load(pretrained_model_path, map_location=device))
model.eval()


In [22]:
# Text-overlay settings read by write_text() when stamping frames.
font = cv2.FONT_HERSHEY_SIMPLEX
fontScale = 0.8  # OpenCV font scale factor
thickness = 2  # line thickness of the drawn text

def save_video(frames, video_path, fps=30):
    """Write RGB frames to ``video_path`` with OpenCV's VideoWriter (XVID codec).

    Args:
        frames: non-empty sequence of HxWx3 RGB arrays, all the size of the first
            frame (the writer's frame size is taken from ``frames[0]``).
        video_path: output file path (the caller uses an ``.avi`` extension).
        fps: frame rate written into the container (default 30, as before).

    Raises:
        ValueError: if ``frames`` is empty.
        IOError: if OpenCV cannot open a writer for ``video_path``.
    """
    if len(frames) == 0:
        raise ValueError("save_video needs at least one frame")
    print(frames[0].shape)
    H, W = frames[0].shape[:2]
    # "XVID" is the conventional upper-case FourCC for this codec.
    out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*"XVID"), fps, (W, H))
    # VideoWriter does not raise when the codec/backend is unavailable; it just
    # writes nothing, so fail loudly instead of producing an empty file.
    if not out.isOpened():
        raise IOError(f"Could not open a video writer for {video_path!r}")
    try:
        for frame in frames:
            # Frames are RGB; OpenCV expects BGR.
            out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    finally:
        out.release()
    # No cv2.destroyAllWindows(): no windows are opened here, and on headless
    # OpenCV builds that call raises "The function is not implemented".

def write_text(image, is_gen=True):
    """Return a copy of ``image`` with a "GENERATED"/"ORIGINAL" tag in the lower-left corner.

    The input array is left untouched. Colours are given in the frame's own channel
    order (frames in this notebook are RGB): red for generated frames, teal for
    originals. Font settings come from the module-level ``font``, ``fontScale`` and
    ``thickness``.
    """
    if is_gen:
        label, colour = "GENERATED", (255, 41, 41)
    else:
        label, colour = "ORIGINAL", (21, 245, 186)
    # 15 px in from the left edge, baseline 15 px above the bottom edge.
    anchor = (15, image.shape[0] - 15)
    canvas = image.copy()
    return cv2.putText(canvas, label, anchor, font, fontScale, colour, thickness, cv2.LINE_AA)
    

def generate_intermidiate_frames(img0_0, img0_1, img1_0, img1_1, exp_num_frames):
    """Generate ``2**exp_num_frames - 1`` frames between ``img0_1`` and ``img1_0``.

    The model is context based: each ``model.inference`` call receives the two
    frames before a gap and the two frames after it (each pair concatenated along
    the channel axis), and its output is inserted in the middle of that gap. Every
    pass fills each interior gap of the current sequence, so after
    ``exp_num_frames`` passes there are ``2**exp_num_frames - 1`` generated frames
    between the two inner originals.

    Args:
        img0_0, img0_1: the two frames before the gap; HxWx3 arrays with values in
            0..255 (assumed -- all four inputs must have the same size).
        img1_0, img1_1: the two frames after the gap, same format.
        exp_num_frames: number of refinement passes (0 returns only the originals).

    Returns:
        list of PIL RGB images at the input resolution: the two leading originals,
        the generated frames, then the two trailing originals, each stamped with
        "ORIGINAL" or "GENERATED" by ``write_text``.
    """
    # Run on whatever device the model's weights live on (previously the inputs
    # always stayed on the CPU, which fails when the model is on CUDA).
    first_param = next(model.flownet.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    def _to_tensor(img):
        # HxWxC (0..255) -> 1xCxHxW float in [0, 1]
        return (torch.tensor(img.transpose(2, 0, 1)) / 255.).unsqueeze(0).to(run_device)

    frames = [_to_tensor(img) for img in (img0_0, img0_1, img1_0, img1_1)]

    # Pad H and W up to the next multiple of 32, presumably a network requirement.
    # The padding is cropped off again before the frames are returned.
    _, _, h, w = frames[0].shape
    ph = ((h - 1) // 32 + 1) * 32
    pw = ((w - 1) // 32 + 1) * 32
    padding = (0, pw - w, 0, ph - h)
    frames = [F.pad(f, padding) for f in frames]

    with torch.no_grad():  # inference only; avoids building autograd graphs
        for _ in range(int(exp_num_frames)):
            # mids[k] is predicted from frames[k:k+4] and belongs between
            # frames[k+1] and frames[k+2]; all are computed from the old sequence.
            mids = [
                model.inference(torch.cat(frames[k:k + 2], 1), torch.cat(frames[k + 2:k + 4], 1))
                for k in range(len(frames) - 3)
            ]
            refined = frames[:2]
            for k, mid in enumerate(mids):
                refined.append(mid)
                refined.append(frames[k + 2])
            refined.append(frames[-1])
            frames = refined

    final_images = []
    n_frames = len(frames)
    for idx, frame in enumerate(frames):
        if frame.dim() == 4:  # drop the batch dimension
            frame = frame[0]
        # Crop the padding, clamp to the valid range and round (plain truncation
        # turned e.g. 2.9999998 into 2 and altered the ORIGINAL frames).
        arr = frame[:, :h, :w].detach().clamp(0, 1).mul(255).round().byte().cpu().numpy()
        arr = np.ascontiguousarray(arr.transpose(1, 2, 0))
        # The first two and last two frames are the inputs; everything between is generated.
        is_original = idx < 2 or idx >= n_frames - 2
        arr = write_text(arr, is_gen=not is_original)
        final_images.append(Image.fromarray(np.uint8(arr)).convert('RGB'))

    return final_images

def compute_avg_l1_diff(images):
    """Average mean-absolute pixel difference between consecutive frames.

    Args:
        images: sequence of equally sized frames (PIL images or numpy arrays).

    Returns:
        float: mean over adjacent pairs of ``mean(|frame[i] - frame[i+1]|)``, or
        0.0 when fewer than two frames are given.
    """
    # Cast to float before subtracting: frames are usually uint8, where a - b
    # wraps around modulo 256 (10 - 20 == 246) and inflates the metric.
    frames = [np.asarray(img, dtype=np.float64) for img in images]
    if len(frames) < 2:
        return 0.0

    pair_diffs = [np.mean(np.abs(cur - nxt)) for cur, nxt in zip(frames, frames[1:])]
    return float(sum(pair_diffs) / len(pair_diffs))


def inference(type, data, exp_num_frames):
    """Run frame interpolation on a video file or on four still images.

    Args:
        type: 0 -> ``data`` is a path to a video file; 1 -> ``data`` is a sequence
            of four images (PIL images or HxWx3 arrays), two before and two after
            the gap. (The name shadows the builtin but is kept for compatibility.)
        data: see ``type``.
        exp_num_frames: number of refinement passes; each window gets
            ``2**exp_num_frames - 1`` generated frames.

    Returns:
        type 0: (path returned by ``convert_to_mp4`` -- the converted .mp4, or the
            .avi if conversion failed --, list of PIL frames, average L1 difference)
        type 1: (list of PIL frames, average L1 difference)

    Raises:
        ValueError: for an unknown ``type``, a video with fewer than 4 readable
            frames, or an image sequence that does not have exactly 4 items.
    """
    if type == 0:
        frames = extract_frames(data)
        print(len(frames))
        if len(frames) < 4:
            # Every window needs 4 context frames; previously this case crashed
            # with an UnboundLocalError on `tmp_list`.
            raise ValueError(
                f"Need at least 4 readable frames to interpolate, got {len(frames)} from {data!r}"
            )
        frame_list = []
        print('Starting')
        # Windows of 4 frames start every 2 frames. Each window contributes its two
        # leading originals plus its generated frames; the two trailing originals
        # of the last window are appended once after the loop.
        # NOTE(review): with this stride only the gap in the middle of each window
        # (frames i+1 -> i+2) is filled, so every other gap of the source video gets
        # no generated frames, and for an odd frame count the final frame is
        # dropped. Confirm this is intended.
        window_starts = range(0, len(frames) - 3, 2)
        tmp_list = []
        for i in tqdm(window_starts):
            tmp_list = generate_intermidiate_frames(frames[i], frames[i + 1], frames[i + 2], frames[i + 3], exp_num_frames)
            frame_list += tmp_list[:-2]
        frame_list += tmp_list[-2:]
        print(len(frame_list))

        video_path = os.path.splitext(data)[0] + "_processed.avi"
        save_video([np.array(img) for img in frame_list], video_path)
        print(video_path)

        average_l1_loss = compute_avg_l1_diff(frame_list)
        print("Avg l1 difference =", average_l1_loss)

        return convert_to_mp4(video_path), frame_list, average_l1_loss

    if type == 1:  # four still images
        print("IMAGE")
        img0_0, img0_1, img1_0, img1_1 = data

        def to_rgb_array(img):
            # Normalise PIL inputs (e.g. RGBA or greyscale uploads) to 3-channel
            # RGB, which generate_intermidiate_frames expects.
            if isinstance(img, Image.Image):
                img = img.convert('RGB')
            return np.array(img)

        img0_0, img0_1, img1_0, img1_1 = (to_rgb_array(img) for img in (img0_0, img0_1, img1_0, img1_1))

        frame_list = generate_intermidiate_frames(img0_0, img0_1, img1_0, img1_1, exp_num_frames)

        average_l1_loss = compute_avg_l1_diff(frame_list)
        print("Avg l1 difference =", average_l1_loss)

        return frame_list, average_l1_loss

    raise ValueError(f"Unknown inference type {type!r}; expected 0 (video) or 1 (images)")
        
        

### Gradio App

In [23]:
import gradio as gr
import numpy as np
import cv2
from PIL import Image
import os
import ffmpeg

In [None]:
    
def convert_to_mp4(video_path):
    """Transcode ``video_path`` to an H.264/AAC MP4 next to the source file.

    Returns the new ``<stem>_converted.mp4`` path on success (an existing file at
    that path is overwritten). If ffmpeg fails for any reason, the error is printed
    and the original path is returned unchanged, so callers can keep using the
    source video.
    """
    stem, _ext = os.path.splitext(video_path)
    target = stem + "_converted.mp4"
    try:
        job = ffmpeg.input(video_path).output(target, vcodec='libx264', acodec='aac')
        job.run(overwrite_output=True)
    except Exception as err:
        print(f"Error converting video: {err}")
        return video_path
    return target
    
def process_video(video, exp_num_frames):
    """Interpolate an uploaded video file.

    Returns a 3-tuple: (path of the processed video, list of output frames,
    average L1 difference between adjacent frames).
    """
    result = inference(0, video, exp_num_frames)
    out_path, gallery_frames, mean_l1 = result  # enforce the 3-tuple contract
    return out_path, gallery_frames, mean_l1

def extract_frames(video_path):
    """Decode every frame of ``video_path`` into a list of RGB arrays.

    OpenCV decodes to BGR, so each frame is converted to RGB. Reading stops at the
    first failed read (end of stream); if the capture cannot be opened, the
    result is an empty list.
    """
    capture = cv2.VideoCapture(video_path)
    rgb_frames = []
    while capture.isOpened():
        ok, bgr = capture.read()
        if not ok:
            break
        rgb_frames.append(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
    capture.release()
    return rgb_frames

def process_images(img1, img1_1, img2, img2_1, exp_num_frames):
    """Interpolate between four still images (two before and two after the gap).

    Returns ``(frames, average_l1_difference)`` as produced by ``inference``.
    """
    context_images = [img1, img1_1, img2, img2_1]
    return inference(1, context_images, exp_num_frames)

def handle_video_upload(video):
    """Convert a freshly uploaded video to MP4 right away and return the path to use.

    NOTE(review): currently unused -- the upload event below is wired to
    ``convert_to_mp4`` directly.
    """
    return convert_to_mp4(video)

def interpolate(choice, video=None, img1=None, img1_1=None, img2=None, img2_1=None, exp_num_frames=None):
    """Gradio callback for the "Process" button.

    Always returns a 4-tuple matching the wired outputs
    ``(video_output, image_output, frame_gallery, l1_loss_output)``. When no mode is
    selected or a required input is missing, every slot is ``None``. This gives
    Gradio the expected number of values instead of a bare ``None``, and avoids
    crashing inside the model.
    """
    empty = (None, None, None, None)
    if choice == "Upload a video":
        if video is None:
            return empty
        processed_video, frames, average_l1_loss = process_video(video, exp_num_frames)
        return processed_video, None, frames, average_l1_loss  # video + frames gallery, no images
    if choice == "Upload two images":
        # The context-based model needs all four frames, not just two of them.
        if any(img is None for img in (img1, img1_1, img2, img2_1)):
            return empty
        generated_frames, average_l1_loss = process_images(img1, img1_1, img2, img2_1, exp_num_frames)
        return None, generated_frames, None, average_l1_loss  # images only
    return empty

with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align: center;'>Missing Frames Interpolation</h1>")
    choice = gr.Radio(["Upload a video", "Upload two images"], label="Choose Input Type")

    # n refinement passes -> 2**n - 1 generated frames per gap.
    num_frames = gr.Slider(1, 10, value=2, step=1, label="Number of Interpolated Frames (2^n-1)")

    # Both input sections and the L1 read-out start hidden; toggle_inputs reveals
    # the right ones once a mode is chosen. Initial visibility is passed to the
    # constructors instead of mutating `.visible` after the components exist.

    # Video input and processing
    with gr.Row(visible=False) as video_section:
        video_input = gr.Video(label="Upload Video", interactive=True)
        video_output = gr.Video(label="Processed Video", interactive=True)
        frame_gallery = gr.Gallery(label="Processed Video Frames")

    # Image input and processing (two frames before and two after the gap)
    with gr.Row(visible=False) as image_section:
        img1_input = gr.Image(label="Upload 1st Image", type="pil")
        img1_1_input = gr.Image(label="Upload 2nd Image", type="pil")
        img2_input = gr.Image(label="Upload 3rd Image", type="pil")
        img2_1_input = gr.Image(label="Upload 4th Image", type="pil")
        image_output = gr.Gallery(label="Generated Frames")

    l1_loss_output = gr.Textbox(label="Average L1 Difference between adjacent frames", interactive=False, visible=False)

    # Toggle input visibility based on choice
    def toggle_inputs(choice):
        """Visibility updates for (video_section, image_section, video_output, frame_gallery, l1_loss_output)."""
        if choice == "Upload a video":
            return gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
        elif choice == "Upload two images":
            return gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
        # Unknown/cleared choice: leave everything as is, but still return one
        # update per wired output.
        return tuple(gr.update() for _ in range(5))

    choice.change(toggle_inputs, inputs=[choice], outputs=[video_section, image_section, video_output, frame_gallery, l1_loss_output])

    # Convert uploads to MP4 immediately so the player can display them
    video_input.upload(convert_to_mp4, inputs=video_input, outputs=video_input)

    # Processing button
    process_button = gr.Button("Process")

    # `interpolate` always returns one value per output: video, images gallery, frames gallery, L1
    process_button.click(
        interpolate,
        inputs=[choice, video_input, img1_input, img1_1_input, img2_input, img2_1_input, num_frames],
        outputs=[video_output, image_output, frame_gallery, l1_loss_output]
    )

    gr.Markdown("### Instructions")
    gr.Markdown("""
        1. **Select** "Upload a video" to process an entire video offline or "Upload two images" to interpolate images.
        2. **Adjust Settings** (e.g., number of frames) if available.
        3. **Click Process** to view the interpolated video or frames. 
        """)

demo.launch(debug=True)

In [None]:
# Shut down the Gradio server started by demo.launch(debug=True).
demo.close()

Closing server running on port: 7860
