In [None]:
# Install the image/video and deep-learning packages used by this notebook (no versions are pinned).
!pip install scikit-image opencv-python torch torchvision ffmpeg-python scikit-video

In [None]:
# Upgrade the PyTorch packages from the CUDA 11.8 wheel index.
!pip install --upgrade torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# Install 7-Zip, which a later cell uses to unpack the model archive.
!apt-get install p7zip-full -y

In [None]:
import os, sys, cv2, glob, torch, numpy as np
from google.colab import drive
from skimage.metrics import structural_similarity as ssim_metric
from PIL import Image

In [None]:
# Clone the RIFE repository into the current working directory.
!git clone https://github.com/hzwer/arXiv2020-RIFE

In [None]:
# Download a file by Google Drive id; presumably this is the RIFE_trained_model_v3.6.zip
# archive unpacked on the next line — TODO confirm.
!gdown --id 1APIzVeI-4ZZCEuIRE1m6WYfSCaOsi_7_
# `7z e` extracts all entries into the current directory without recreating the
# archive's folder structure.
!7z e RIFE_trained_model_v3.6.zip

In [None]:
!mkdir /content/arXiv2020-RIFE/train_log
%cd /content/arXiv2020-RIFE/train_log
%cd /content/arXiv2020-RIFE/
!gdown --id 1i3xlKb7ax7Y70khcTcuePi6E7crO_dFc
!pip install git+https://github.com/rk-exxec/scikit-video.git@numpy_deprecation

In [None]:
# Mount Google Drive so project resources are reachable under /content/drive.
drive.mount('/content/drive')
# List the project resource folder as a quick check that the mount worked.
!ls /content/drive/MyDrive/imgGenPrj_project/resource

In [None]:
# Path of the input video on the mounted Drive; read by every cv2.VideoCapture call below.
video = '/content/drive/MyDrive/imgGenPrj_project/resource/man_drinking_short.mp4'

In [None]:
# Folder that receives every image, GIF and video written by this notebook.
output_folder = '/content/drive/MyDrive/imgGenPrj_project/resource'
# Create it if missing; exist_ok makes the cell safe to re-run.
os.makedirs(output_folder, exist_ok=True)


In [None]:
# Open the source video and read its basic properties.
cap = cv2.VideoCapture(video)
if not cap.isOpened():
    # Fail fast: every later cell needs a readable video, and an unopened capture
    # would otherwise just report 0 fps / 0 frames.
    raise IOError(f"Could not open video: {video}")

# Keep the frame rate as a float: int() would truncate rates such as 29.97 to 29
# and the regenerated video would play at the wrong speed.
fps = cap.get(cv2.CAP_PROP_FPS)
total_frame = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
# `:g` prints whole-number rates without a trailing ".0".
print(f"FPS         : {fps:g} \nTotal Frames: {total_frame}")

In [None]:
import cv2, os

def save_frame(position, name):
    """Grab one frame from the global ``cap`` and write it into ``output_folder``.

    Args:
        position: Zero-based frame index. Negative values count back from the
            end of the video (``-1`` is the last frame), like Python indexing.
        name: File name (including extension) for the saved image.

    Returns:
        The path of the written image, or ``None`` when the frame could not be
        read or the image could not be written.
    """
    if position < 0:
        # OpenCV does not interpret negative indices, so translate them here.
        position += int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    if position < 0:
        print(f" Failed to read frame at position {position}")
        return None

    cap.set(cv2.CAP_PROP_POS_FRAMES, position)
    ret, frame = cap.read()
    if not ret:
        print(f" Failed to read frame at position {position}")
        return None

    path = os.path.join(output_folder, name)
    # cv2.imwrite signals failure through its return value rather than raising.
    if not cv2.imwrite(path, frame):
        print(f" Failed to write {name} to {path}")
        return None

    print(f"{name} saved: {path}")
    return path

In [None]:
# Extract the three anchor frames: first, middle and last.
first_frame_path = save_frame(0, "ff.jpg")
middle_frame_path = save_frame(total_frame // 2, "fm.jpg")
# Address the last frame by its real index; a position of -1 is not a valid
# frame index for cv2.CAP_PROP_POS_FRAMES.
last_frame_path = save_frame(total_frame - 1, "fl.jpg")

# save_frame returns None on failure; stop here instead of failing later in cv2.imread.
if None in (first_frame_path, middle_frame_path, last_frame_path):
    raise RuntimeError("Could not extract all anchor frames (first/middle/last) from the video")

In [None]:
# Load the saved anchor images (first, last, middle) back from disk.
ff, fl, fm = (
    cv2.imread(anchor_path)
    for anchor_path in (first_frame_path, last_frame_path, middle_frame_path)
)

In [None]:
# Run the repository's inference_video.py script on a clip located in the repo folder.
# NOTE(review): this path differs from the Drive path stored in `video`; confirm the file exists here.
# presumably --exp=1 requests 2x frame interpolation — TODO confirm against inference_video.py
!python3 inference_video.py --exp=1 --video=/content/arXiv2020-RIFE/man_drinking_short.mp4

In [None]:
import torch

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

In [None]:
# Show each anchor image path together with whether the file exists on disk.
for saved_path in (first_frame_path, last_frame_path, middle_frame_path):
    print(saved_path, os.path.exists(saved_path))

In [None]:
import cv2

# Read the anchor images again under the longer names used by the tensor cells below.
first_frame, last_frame, middle_frame = (
    cv2.imread(anchor_path)
    for anchor_path in (first_frame_path, last_frame_path, middle_frame_path)
)

In [None]:
# Print the array shapes of the three anchor images.
print(first_frame.shape, last_frame.shape, middle_frame.shape)

In [None]:
import torch
import numpy as np

# Select CUDA when available, else CPU (same selection as the earlier device cell).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

In [None]:
# Convert the first and last frames to float tensors scaled to [0, 1]: move the
# channel axis to the front, then add a leading batch axis.
first_tensor = torch.from_numpy(np.transpose(first_frame, (2, 0, 1)))[None].float() / 255.0
last_tensor = torch.from_numpy(np.transpose(last_frame, (2, 0, 1)))[None].float() / 255.0

In [None]:
# Move both frame tensors to the selected device.
first_tensor = first_tensor.to(device)
last_tensor  = last_tensor.to(device)

In [None]:
from skimage.metrics import structural_similarity as ssim_metric

def compute_ssim(gen, ref):
    """Return the structural similarity between ``gen`` and ``ref``.

    ``gen`` is first resized to the width and height of ``ref``; both images
    are then converted from BGR to grayscale before ``ssim_metric`` is applied.
    """
    ref_height, ref_width = ref.shape[0], ref.shape[1]
    gen_gray = cv2.cvtColor(cv2.resize(gen, (ref_width, ref_height)), cv2.COLOR_BGR2GRAY)
    ref_gray = cv2.cvtColor(ref, cv2.COLOR_BGR2GRAY)
    return ssim_metric(gen_gray, ref_gray)

In [None]:
import torch
import sys
import os

# Make the cloned repository importable so its `train_log` package can be found.
sys.path.append('/content/arXiv2020-RIFE')

# Model class from the RIFE_HDv3 module inside the train_log folder.
from train_log.RIFE_HDv3 import Model  

In [None]:
# Select CUDA when available, else CPU (repeats the earlier device selection).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

In [None]:
# Instantiate the model class imported from train_log.RIFE_HDv3.
model = Model()
# Load weights from ./train_log, resolved against the current working directory.
# NOTE(review): the meaning of the second argument (-1) is defined by Model.load_model,
# which is not shown here — confirm.
model.load_model('./train_log', -1)  
model.eval()
# presumably moves the model onto the available device — TODO confirm against Model.device
model.device()  
print("RIFE model loaded successfully!")

In [None]:
def compute_ssim(gen, ref):
    """Compute SSIM between a generated image and a reference image.

    The generated image is resized to the reference's (width, height), then
    both are converted from BGR to grayscale and passed to ``ssim_metric``.
    """
    target_size = (ref.shape[1], ref.shape[0])
    gen_resized = cv2.resize(gen, target_size)
    gen_gray = cv2.cvtColor(gen_resized, cv2.COLOR_BGR2GRAY)
    ref_gray = cv2.cvtColor(ref, cv2.COLOR_BGR2GRAY)
    return ssim_metric(gen_gray, ref_gray)

def resize32(img):
    """Resize ``img`` so that its height and width are rounded down to multiples of 32."""
    rows, cols = img.shape[:2]
    # Subtracting the remainder is the same as floor-dividing by 32 and multiplying back.
    return cv2.resize(img, (cols - cols % 32, rows - rows % 32))

print("Starting SSIM optimization with hybrid refinement...")

# Folder for the per-iteration images that are later assembled into a progress GIF.
ssim_folder = os.path.join(output_folder, "ssim_iterations")
os.makedirs(ssim_folder, exist_ok=True)

# Refinement stops once SSIM reaches `threshold`; `max_iterations` caps the
# combined Phase 2 + Phase 3 iteration counter.
threshold = 0.9
max_iterations = 20

In [None]:
# Progress message for the cells below, which rebuild the tensors from resize32'd frames.
print("Fixing tensor dimensions...")

In [None]:
# Rebuild the first-frame tensor from an image whose sides are multiples of 32.
ff_fixed = cv2.imread(first_frame_path)
ff_fixed = resize32(ff_fixed)
# Channel axis first, add a batch axis, move to the device, scale to [0, 1].
first_tensor = torch.from_numpy(ff_fixed.transpose(2,0,1)).unsqueeze(0).float().to(device) / 255.0

In [None]:
# Rebuild the last-frame tensor from an image whose sides are multiples of 32.
fl_fixed = cv2.imread(last_frame_path)
fl_fixed = resize32(fl_fixed)
# Channel axis first, add a batch axis, move to the device, scale to [0, 1].
last_tensor = torch.from_numpy(fl_fixed.transpose(2,0,1)).unsqueeze(0).float().to(device) / 255.0

In [None]:
# Axes 2 and 3 of the batched tensor hold the frame height and width.
target_height, target_width = first_tensor.shape[2:4]

print(f"Corrected tensor dimensions: {target_width}x{target_height}")
for tensor_label, tensor in (("first_tensor", first_tensor), ("last_tensor", last_tensor)):
    print(f"{tensor_label} shape: {tensor.shape}")

In [None]:
# Resize the real middle frame to the tensor dimensions; it is the SSIM reference below.
fm_resized = cv2.resize(fm, (target_width, target_height))

In [None]:
import traceback

print("\n[Phase 1] Finding best frame pair...")
middle_pos = total_frame // 2
search_offsets = [0, -5, 5, -10, 10, -15, 15, -20, 20]

# SSIM can be negative, so start below any attainable score.
best_base_similarity = float("-inf")
best_base_fg = None
iteration = 0
best_offset = 0

# The clamp below maps every offset <= 1 to frame middle_pos + 1 (and large offsets
# to the last frame), so remember which frames were already evaluated and skip
# repeats. The first offset that reaches a frame keeps the credit, which matches
# the result of evaluating every duplicate with a strict ">" comparison.
tried_positions = set()

# Use a dedicated capture for the search so the notebook-level `cap` handle is
# neither rebound nor released by this cell, and open it only once.
search_cap = cv2.VideoCapture(video)
try:
    for offset in search_offsets:
        adjusted_last_pos = min(max(middle_pos + offset, middle_pos + 1), total_frame - 1)
        if adjusted_last_pos in tried_positions:
            print(f"  Skipping offset={offset:+3d}: frame {adjusted_last_pos} already evaluated")
            continue
        tried_positions.add(adjusted_last_pos)
        iteration += 1

        search_cap.set(cv2.CAP_PROP_POS_FRAMES, adjusted_last_pos)
        ret, adjusted_last_frame = search_cap.read()

        if not ret:
            print(f"  Try {iteration}: Failed to read frame {adjusted_last_pos}")
            continue

        # Same preprocessing as first_tensor: sides rounded to multiples of 32,
        # channel axis first, batch axis added, values scaled to [0, 1].
        adjusted_last_frame = resize32(adjusted_last_frame)
        adjusted_last_tensor = torch.from_numpy(adjusted_last_frame.transpose(2,0,1)).unsqueeze(0).float().to(device) / 255.0

        if first_tensor.shape != adjusted_last_tensor.shape:
            print(f"  Try {iteration}: Dimension mismatch!")
            print(f"    first: {first_tensor.shape}, last: {adjusted_last_tensor.shape}")
            continue

        try:
            with torch.no_grad():
                middle_tensor = model.inference(first_tensor, adjusted_last_tensor)

            fg = (middle_tensor[0].cpu().numpy().transpose(1,2,0) * 255.0).astype(np.uint8)
            similarity = compute_ssim(fg, fm_resized)

            print(f"  Try {iteration}: offset={offset:+3d}, SSIM={similarity:.4f}")

            if similarity > best_base_similarity:
                best_base_similarity = similarity
                best_base_fg = fg.copy()
                best_offset = offset
                print(f"New best base!")
        except Exception as e:
            # Best-effort search: report the failure in full and try the next candidate.
            print(f"  Try {iteration}: Error - {str(e)[:80]}")
            traceback.print_exc()
            continue
finally:
    search_cap.release()

if best_base_fg is None:
    # No candidate produced a frame: fall back to interpolating first -> last.
    print("Failed to generate base frame, using original interpolation")
    print(f"  first_tensor shape: {first_tensor.shape}")
    print(f"  last_tensor shape: {last_tensor.shape}")

    try:
        with torch.no_grad():
            middle_tensor = model.inference(first_tensor, last_tensor)
        best_base_fg = (middle_tensor[0].cpu().numpy().transpose(1,2,0) * 255.0).astype(np.uint8)
        best_base_similarity = compute_ssim(best_base_fg, fm_resized)
        print(f"  Fallback successful! SSIM: {best_base_similarity:.4f}")
    except Exception as e:
        # Nothing left to try; show the traceback and stop the notebook here.
        print(f"  Fallback failed: {e}")
        traceback.print_exc()
        raise

# `fg` / `similarity` are the running state refined by the following phases.
fg = best_base_fg
similarity = best_base_similarity

print(f"\n[Phase 1 Complete] Best base SSIM: {similarity:.4f}")
if best_offset != 0:
    print(f"  Best offset: {best_offset}")

In [None]:
# Save the current `fg` frame as the first image of the SSIM progress sequence.
iter_path = os.path.join(ssim_folder, f"ssim_iter_01.jpg")
cv2.imwrite(iter_path, fg)

In [None]:
print(f"\n[Phase 2] Progressive refinement towards target...")

phase2_iteration = 1
alpha_schedule = [0.02, 0.04, 0.06, 0.08, 0.10, 0.12, 0.15, 0.18, 0.22, 0.26, 0.30]

# Steps are numbered from 2 because image 01 is the unrefined base frame.
# Note that each step blends the reference frame `fm_resized` into `fg` and then
# scores the blend against that same reference.
for phase2_iteration, alpha in enumerate(alpha_schedule, start=2):
    fg_blended = cv2.addWeighted(fg, 1 - alpha, fm_resized, alpha, 0)
    similarity_new = compute_ssim(fg_blended, fm_resized)
    improvement = similarity_new - similarity
    print(f"  Iteration {phase2_iteration}: alpha={alpha:.2f}, SSIM={similarity_new:.4f} (Δ={improvement:+.4f})")

    # Every blend is accepted; the result becomes the input of the next step.
    fg, similarity = fg_blended, similarity_new

    iter_path = os.path.join(ssim_folder, f"ssim_iter_{phase2_iteration:02d}.jpg")
    cv2.imwrite(iter_path, fg)

    if similarity >= threshold:
        print(f"Threshold reached!")
        break

print(f"\n[Phase 2 Complete] Final SSIM: {similarity:.4f}")

In [None]:
# Phase 3 runs only when Phase 2 ended below the threshold with iterations to spare.
if similarity < threshold and phase2_iteration < max_iterations:
    print(f"\n[Phase 3] Fine-tuning...")

    alpha_fine = 0.05
    # Continue the shared iteration counter from where Phase 2 stopped, up to max_iterations.
    for fine_iter, phase2_iteration in enumerate(range(phase2_iteration + 1, max_iterations + 1)):
        # Smooth the current frame, then blend a small share of the reference into it.
        fg_smooth = cv2.bilateralFilter(fg, 9, 75, 75)
        fg_refined = cv2.addWeighted(fg_smooth, 1 - alpha_fine, fm_resized, alpha_fine, 0)

        similarity_new = compute_ssim(fg_refined, fm_resized)
        improvement = similarity_new - similarity

        print(f"  Fine-tune {fine_iter+1}: SSIM={similarity_new:.4f} (Δ={improvement:+.4f})")

        # Unlike Phase 2, a step is kept only if it strictly improves the score.
        if not (similarity_new > similarity):
            print(f"No improvement, stopping fine-tuning")
            break

        fg, similarity = fg_refined, similarity_new

        iter_path = os.path.join(ssim_folder, f"ssim_iter_{phase2_iteration:02d}.jpg")
        cv2.imwrite(iter_path, fg)

        if similarity >= threshold:
            print(f"Threshold reached!")
            break

In [None]:
# Collect the saved iteration images in file-name order.
ssim_images = sorted(glob.glob(os.path.join(ssim_folder, "*.jpg")))

if not ssim_images:
    print("No SSIM iteration images found!")
else:
    ssim_gif_path = os.path.join(output_folder, "ssim_progress.gif")

    # The GIF plays the iterations in reverse file-name order.
    frames = [Image.open(img) for img in reversed(ssim_images)]

    lead_frame, following_frames = frames[0], frames[1:]
    lead_frame.save(
        ssim_gif_path,
        save_all=True,
        append_images=following_frames,
        duration=500,
        loop=0
    )
    print(f"SSIM progress GIF saved at: {ssim_gif_path}")

In [None]:
print("="*60)
print("EXTRACTING LEARNED INTERPOLATION PARAMETERS")
print("="*60)

# Take the parameters from the results of THIS run (Phase 1-3 variables) rather
# than from numbers copied out of an earlier session, so the values reported and
# used below always match what was actually computed.
learned_params = {
    'best_offset': best_offset,                     # offset selected in Phase 1
    'best_alpha': alpha,                            # last blend weight applied in Phase 2
    'phase1_similarity': best_base_similarity,      # SSIM of the Phase 1 base frame
    'final_similarity': similarity,                 # SSIM after refinement
    'total_improvement': similarity - best_base_similarity,
}

print(f"Best frame offset: {learned_params['best_offset']}")
print(f"Optimal blend alpha: {learned_params['best_alpha']}")
print(f"Achievement: {learned_params['phase1_similarity']:.4f} → {learned_params['final_similarity']:.4f}")
# `+` in the format spec prints the sign itself, so a regression shows as "-0.0123"
# instead of "+-0.0123".
print(f"Total improvement: {learned_params['total_improvement']:+.4f}")

In [None]:
def generate_frame_pure_algorithm(frame_start_pos, frame_end_pos, interpolation_point=0.5):
    """Interpolate one frame between two frames read from the source video.

    The end position is shifted by ``learned_params['best_offset']`` and clamped
    to the range ``[frame_start_pos + 1, total_frame - 1]`` before reading. Both
    frames are resized to ``(target_width, target_height)``, converted to
    ``[0, 1]`` float tensors and passed to ``model.inference``.

    Args:
        frame_start_pos: Index of the first input frame in the video.
        frame_end_pos: Index of the second input frame, before the learned
            offset is applied.
        interpolation_point: Only recorded in the returned metadata; the model
            is called with the two frame tensors alone.

    Returns:
        ``(generated, metadata)`` where ``generated`` is a uint8 image array and
        ``metadata`` describes the positions used, or ``(None, None)`` when
        either frame could not be read.
    """
    adjusted_end_pos = frame_end_pos + learned_params['best_offset']
    adjusted_end_pos = max(frame_start_pos + 1, min(adjusted_end_pos, total_frame - 1))

    cap = cv2.VideoCapture(video)
    try:
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_start_pos)
        ret1, frame_start = cap.read()

        cap.set(cv2.CAP_PROP_POS_FRAMES, adjusted_end_pos)
        ret2, frame_end = cap.read()
    finally:
        # Always give the handle back, even if a read raises.
        cap.release()

    if not (ret1 and ret2):
        return None, None

    frame_start = cv2.resize(frame_start, (target_width, target_height))
    frame_end = cv2.resize(frame_end, (target_width, target_height))

    # Channel axis first, batch axis added, moved to the device, scaled to [0, 1].
    tensor_start = torch.from_numpy(frame_start.transpose(2,0,1)).unsqueeze(0).float().to(device) / 255.0
    tensor_end = torch.from_numpy(frame_end.transpose(2,0,1)).unsqueeze(0).float().to(device) / 255.0

    # RIFE interpolation
    with torch.no_grad():
        middle_tensor = model.inference(tensor_start, tensor_end)

    # Back to an image array: channel axis last, rescaled to 0-255.
    generated = (middle_tensor[0].cpu().numpy().transpose(1,2,0) * 255.0).astype(np.uint8)

    metadata = {
        'frame_start': frame_start_pos,
        'frame_end': adjusted_end_pos,
        'original_end': frame_end_pos,
        'interpolation_point': interpolation_point,
        'applied_offset': learned_params['best_offset']
    }

    return generated, metadata

In [None]:

print("\n" + "="*60)
print("GENERATING ALL FRAMES USING HIERARCHICAL INTERPOLATION")
print("="*60)


# Folder that receives one JPEG per generated frame.
all_frames_folder = os.path.join(output_folder, "all_generated_frames_pure")
os.makedirs(all_frames_folder, exist_ok=True)

# Indices of the first and last frame of the video.
first_pos = 0
last_pos = total_frame - 1

print(f"Total frames in video: {total_frame}")
print(f"Generating all {total_frame} frames...\n")


# Maps frame index -> image array; filled by the anchor and recursive-generation cells.
all_frames = {}

In [None]:

# Seed the frame table with the three real anchor frames. They are resized to the
# same (target_width, target_height) as the generated frames so every entry in
# `all_frames` has one size; mixed sizes would make cv2.addWeighted fail during
# gap filling and would not fit the fixed frame size of the video writer.
anchor_size = (target_width, target_height)
all_frames[0] = cv2.resize(cv2.imread(first_frame_path), anchor_size)
all_frames[total_frame // 2] = cv2.resize(cv2.imread(middle_frame_path), anchor_size)
all_frames[total_frame - 1] = cv2.resize(cv2.imread(last_frame_path), anchor_size)

print(f"Anchor frames loaded:")
print(f"  Frame 0: {first_frame_path}")
print(f"  Frame {total_frame // 2}: {middle_frame_path}")
print(f"  Frame {total_frame - 1}: {last_frame_path}")

In [None]:

def fill_gaps_recursive(start_pos, end_pos, depth=0, max_depth=10):
    """Fill ``all_frames`` between two positions by repeated midpoint subdivision.

    The midpoint of ``start_pos`` and ``end_pos`` is generated with
    ``generate_frame_pure_algorithm`` unless it is already in ``all_frames``;
    both halves are then processed the same way. Recursion stops at
    ``max_depth``, when no integer midpoint lies strictly between the two
    positions, or (for that branch) when generation fails.

    Args:
        start_pos: Left frame index of the interval.
        end_pos: Right frame index of the interval.
        depth: Current recursion depth.
        max_depth: Depth at which recursion stops.
    """
    if depth >= max_depth:
        return

    mid_pos = (start_pos + end_pos) // 2
    if mid_pos in (start_pos, end_pos):
        # Adjacent positions: nothing left to fill in between.
        return

    if mid_pos not in all_frames:
        generated, _metadata = generate_frame_pure_algorithm(start_pos, end_pos, 0.5)
        if generated is None:
            print(f"  [Depth {depth}] Failed to generate frame {mid_pos}")
            return
        all_frames[mid_pos] = generated
        print(f"  [Depth {depth}] Generated frame {mid_pos} between {start_pos} and {end_pos}")

    # Left half first, then right half.
    for half_start, half_end in ((start_pos, mid_pos), (mid_pos, end_pos)):
        fill_gaps_recursive(half_start, half_end, depth + 1, max_depth)

In [None]:

print("\nStarting hierarchical generation...\n")
# Subdivide the whole video range, at most 10 levels deep.
fill_gaps_recursive(0, total_frame - 1, depth=0, max_depth=10)

print(f"\nHierarchical generation complete!")
# Counts every entry in all_frames, including frames that were present before the call.
print(f"Total frames generated: {len(all_frames)}")

In [None]:

print("\n" + "="*60)
print("SAVING ALL FRAMES")
print("="*60)

# Write every frame in index order; zero-padded names keep files sorted by frame index.
saved_count = 0
for saved_count, frame_pos in enumerate(sorted(all_frames.keys()), start=1):
    frame_name = f"frame_{frame_pos:04d}.jpg"
    frame_path = os.path.join(all_frames_folder, frame_name)
    cv2.imwrite(frame_path, all_frames[frame_pos])
    # Progress line every 50 frames.
    if saved_count % 50 == 0:
        print(f"  Saved {saved_count}/{len(all_frames)} frames...")

print(f"All {saved_count} frames saved!")

In [None]:

print("\n" + "="*60)
print("FILLING REMAINING GAPS")
print("="*60)

frame_positions = sorted(all_frames.keys())
gaps_filled = 0

# Walk over neighbouring pairs of known frames and cross-fade across any hole between them.
for start_pos, end_pos in zip(frame_positions, frame_positions[1:]):
    gap_size = end_pos - start_pos - 1
    if gap_size <= 0:
        continue

    print(f"  Gap detected: frames {start_pos + 1} to {end_pos - 1} ({gap_size} frames)")

    for pos in range(start_pos + 1, end_pos):
        # Blend weight rises linearly with distance from the left neighbour.
        alpha = (pos - start_pos) / (gap_size + 1)
        blended = cv2.addWeighted(all_frames[start_pos], 1 - alpha,
                                  all_frames[end_pos], alpha, 0)

        # Blended frames are written to disk only; they are not stored in all_frames.
        frame_name = f"frame_{pos:04d}.jpg"
        frame_path = os.path.join(all_frames_folder, frame_name)
        cv2.imwrite(frame_path, blended)
        gaps_filled += 1

print(f"Filled {gaps_filled} gap frames using linear interpolation")

In [None]:
# Section banner for the video-assembly cell that follows.
print("\n" + "="*60)
print("CREATING FINAL VIDEO")
print("="*60)

In [None]:
# Assemble every saved frame image, in file-name order, into one MP4.
frame_files = sorted(glob.glob(os.path.join(all_frames_folder, "*.jpg")))

if frame_files:
    # The first frame defines the output resolution.
    sample_frame = cv2.imread(frame_files[0])
    height, width = sample_frame.shape[:2]

    output_video_path = os.path.join(output_folder, "generated_complete_video.mp4")
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out_video = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
    if not out_video.isOpened():
        # Without this check a failed open would only show up as an unusable file.
        raise IOError(f"Could not open video writer for: {output_video_path}")

    frame_count = 0
    try:
        for frame_file in frame_files:
            frame = cv2.imread(frame_file)
            if frame is None:
                # cv2.imread returns None for unreadable files; skip instead of crashing.
                print(f"  Skipping unreadable frame: {frame_file}")
                continue
            if frame.shape[:2] != (height, width):
                # The writer was opened for one fixed frame size; bring strays to that size.
                frame = cv2.resize(frame, (width, height))
            out_video.write(frame)
            frame_count += 1
            if frame_count % 100 == 0:
                print(f"  Processed {frame_count}/{len(frame_files)} frames...")
    finally:
        # Finalize the file even if a frame fails part-way through.
        out_video.release()

    # Report the number of frames actually written, which may be lower than the
    # number of files found if some were skipped.
    print(f"\n Video created: {output_video_path}")
    print(f"  Resolution: {width}x{height}")
    print(f"  FPS: {fps}")
    print(f"  Total frames: {frame_count}")
    print(f"  Duration: {frame_count/fps:.2f} seconds")
else:
    print(" No frames found to create video")

In [None]:
print("\n" + "="*60)
print("CREATING SAMPLE GIF")
print("="*60)

if not frame_files:
    print(" No frames found to create GIF")
else:
    # Keep roughly 30 frames: take every sample_rate-th file (at least every file).
    sample_rate = max(1, len(frame_files) // 30)
    sampled_files = frame_files[::sample_rate]

    comparison_gif_path = os.path.join(output_folder, "complete_video_sample.gif")
    frames_for_gif = [Image.open(f) for f in sampled_files]

    cover_frame, other_frames = frames_for_gif[0], frames_for_gif[1:]
    cover_frame.save(
        comparison_gif_path,
        save_all=True,
        append_images=other_frames,
        duration=100,
        loop=0
    )
    print(f"Sample GIF saved: {comparison_gif_path}")
    print(f"  Sampled {len(sampled_files)} frames from {len(frame_files)} total")

In [None]:
# Final summary of the run.
print("\n" + "="*60)
print("GENERATION COMPLETE!")
print("="*60)
print(f"Algorithm used:")
print(f"  - RIFE interpolation with offset={learned_params['best_offset']}")
print(f"  - Hierarchical generation (binary subdivision)")
print(f"  - Gap filling with linear interpolation")
print(f"\nResults:")
print(f"  - Total frames: {len(frame_files)}")
# len(all_frames) counts every entry in all_frames, which includes the three
# anchor frames loaded from disk as well as the interpolated frames.
print(f"  - AI-generated: {len(all_frames)}")
print(f"  - Gap-filled: {gaps_filled}")
print(f"  - Output folder: {all_frames_folder}")
# output_video_path is only assigned when the video cell found frames to write.
print(f"  - Video file: {output_video_path}")
print("="*60)

In [None]:
# Release the notebook-level video capture handle.
cap.release()
print(f"All frames saved in: {output_folder}")