In [27]:
import os
import cv2
import numpy as np
from moviepy.editor import VideoFileClip
import torch
from torch.nn import functional as F
import glob
from tqdm import tqdm
import requests
from pathlib import Path
import gdown

In [30]:
class VideoEnhancer:
    def __init__(self, input_dir, output_dir, target_resolution=(1080, 1920), target_fps=15):
        self.input_dir = input_dir
        self.output_dir = output_dir
        self.target_resolution = target_resolution
        self.target_fps = target_fps
        self.device = 'mps' if torch.backends.mps.is_available() else 'cpu'
        print(f"Using device: {self.device}")
        
        # Create output directory if it doesn't exist
        os.makedirs(output_dir, exist_ok=True)
        
        # Initialize frame interpolation
        try:
            # Ensure torch-rife is installed
            import torch_rife
            self.rife_model = torch_rife.RIFE(model_version='4.6')  # Explicitly specify version
            print(f"CUDA available: {torch.cuda.is_available()}")
            print(f"MPS available: {torch.backends.mps.is_available()}")
            print(f"Using device: {self.device}")
            
            if self.device == 'mps':
                # For M1/M2 Macs
                self.rife_model.to('cpu')  # Use CPU for more stable operation on M1/M2
                print("Using CPU for frame interpolation on Apple Silicon")
            else:
                self.rife_model.to(self.device)
                print(f"RIFE model moved to {self.device}")
            
            # Test the model with a small dummy input
            with torch.no_grad():
                dummy1 = torch.zeros(1, 3, 64, 64).to(self.device if self.device != 'mps' else 'cpu')
                dummy2 = torch.zeros(1, 3, 64, 64).to(self.device if self.device != 'mps' else 'cpu')
                _ = self.rife_model.inference(dummy1, dummy2)
                print("Successfully tested RIFE model")
                
        except ImportError:
            print("torch-rife not found. Installing...")
            import subprocess
            try:
                subprocess.check_call(['pip', 'install', 'torch-rife'])
                import torch_rife
                self.rife_model = torch_rife.RIFE(model_version='4.6')
                print("Successfully installed and loaded torch-rife")
            except Exception as e:
                print(f"Error installing torch-rife: {e}")
                self.rife_model = None
                
        except Exception as e:
            print(f"Error initializing frame interpolation: {e}")
            print("Detailed error information:")
            import traceback
            traceback.print_exc()
            self.rife_model = None

    def enhance_frame(self, frame):
        """Upscale a single frame using Lanczos interpolation"""
        # Convert frame to BGR for OpenCV
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        
        # Upscale using Lanczos interpolation
        upscaled = cv2.resize(frame, 
                            (self.target_resolution[1], self.target_resolution[0]),
                            interpolation=cv2.INTER_LANCZOS4)
        
        # Convert back to RGB
        return cv2.cvtColor(upscaled, cv2.COLOR_BGR2RGB)

    def interpolate_frames(self, frame1, frame2):
        """Generate intermediate frame using RIFE or frame blending"""
        if self.rife_model is None:
            # Simple frame blending as fallback
            return cv2.addWeighted(frame1, 0.5, frame2, 0.5, 0)
            
        try:
            # Convert frames to tensors
            frame1 = torch.from_numpy(frame1).permute(2, 0, 1).float() / 255.0
            frame2 = torch.from_numpy(frame2).permute(2, 0, 1).float() / 255.0
            
            # Add batch dimension and move to device
            frame1 = frame1.unsqueeze(0)
            frame2 = frame2.unsqueeze(0)
            
            if self.device != 'cpu':
                frame1 = frame1.to(self.device)
                frame2 = frame2.to(self.device)
            
            # Generate intermediate frame
            with torch.no_grad():
                middle = self.rife_model.inference(frame1, frame2)
            
            # Convert back to numpy array
            if self.device != 'cpu':
                middle = middle.cpu()
            middle = (middle[0].numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
            return middle
        except Exception as e:
            print(f"Frame interpolation failed, falling back to frame blending: {e}")
            return cv2.addWeighted(frame1, 0.5, frame2, 0.5, 0)

    def process_video(self, input_path):
        """Process a single video file"""
        filename = os.path.basename(input_path)
        output_path = os.path.join(self.output_dir, f"{filename}")
        
        print(f"\nProcessing {filename}")
        
        # Load video
        clip = VideoFileClip(input_path)
        
        # Calculate frames needed for target fps
        original_duration = clip.duration
        original_frames = list(clip.iter_frames())
        
        enhanced_frames = []
        total_frames = len(original_frames)
        
        # Process frames with progress bar
        with tqdm(total=total_frames, desc="Processing frames") as pbar:
            # Enhance resolution of original frames
            for i, frame in enumerate(original_frames):
                enhanced = self.enhance_frame(frame)
                enhanced_frames.append(enhanced)
                
                # Interpolate frames if needed and RIFE is available
                if i < total_frames - 1 and self.rife_model is not None:
                    try:
                        next_frame = self.enhance_frame(original_frames[i + 1])
                        middle_frame = self.interpolate_frames(enhanced, next_frame)
                        if middle_frame is not None:
                            enhanced_frames.append(middle_frame)
                    except Exception as e:
                        print(f"\nFrame interpolation error at frame {i}: {e}")
                        print("Continuing without interpolation for this frame pair")
                
                pbar.update(1)
                
        print(f"\nTotal frames after processing: {len(enhanced_frames)}")
        print(f"Final FPS will be: {len(enhanced_frames)/original_duration:.2f}")
        
        # Create new video clip
        from moviepy.editor import ImageSequenceClip
        enhanced_clip = ImageSequenceClip(enhanced_frames, fps=self.target_fps)
        
        print(f"Writing enhanced video to {output_path}")
        
        # Write output video
        enhanced_clip.write_videofile(output_path, 
                                    codec='libx264', 
                                    audio=False,  # Remove if you want to keep audio
                                    threads=8,    # Utilize multiple CPU cores
                                    preset='medium',
                                    verbose=False,
                                    logger=None)
        
        # Clean up
        clip.close()
        enhanced_clip.close()

    def process_directory(self):
        """Process all videos in the input directory"""
        video_files = glob.glob(os.path.join(self.input_dir, "*.[mM][pP]4"))
        video_files.extend(glob.glob(os.path.join(self.input_dir, "*.[mM][oO][vV]")))
        
        print(f"Found {len(video_files)} videos to process")
        for video_file in video_files:
            try:
                self.process_video(video_file)
                print(f"Successfully processed {video_file}")
            except Exception as e:
                print(f"Error processing {video_file}: {e}")

In [31]:
input_dir = os.path.join(os.path.expanduser("~"), "Library", "CloudStorage", 
                              "OneDrive-UniversityofExeter", "Documents", "VISIONARY", 
                              "Durham Experiment", "test_data")
output_dir = os.path.join(os.path.expanduser("~"), "Library", "CloudStorage", 
                              "OneDrive-UniversityofExeter", "Documents", "VISIONARY", 
                              "Durham Experiment", "enh_data")

target_resolution = (1080, 1920)  # Height, Width
target_fps = 15

# Initialize and run enhancer
enhancer = VideoEnhancer(input_dir, output_dir, target_resolution, target_fps)
enhancer.process_directory()

Using device: mps
Error initializing frame interpolation: No module named 'torch_rife'
Frame interpolation will be disabled - using simple frame duplication
To enable frame interpolation, install torch-rife: pip install torch-rife
Found 2 videos to process

Processing Camera_2_20241101.mp4


Processing frames:  61%|██████    | 2187/3602 [00:57<00:38, 37.08it/s]

: 