# Movie Recap Pipeline - Real-ESRGAN 4K Upscaling

GPU-accelerated video upscaling toward 4K with Real-ESRGAN (×4 model), processed in checkpointed chunks and integrated with Google Drive for job input and output.

In [None]:
# Install dependencies
# Base Python packages used by the notebook cells below.
!pip install -q torch torchvision opencv-python-headless Pillow requests tqdm
# Fetch the Real-ESRGAN sources and make the checkout the working directory.
# Later cells refer to the weights by the relative path 'weights/...', so they
# rely on this %cd having been executed.
!git clone https://github.com/xinntao/Real-ESRGAN.git
%cd Real-ESRGAN
# Install Real-ESRGAN's companion packages plus the checkout itself ('.').
!pip install -q basicsr facexlib gfpgan .
# Download the pretrained x4plus weights into ./weights/.
# NOTE(review): re-running this cell repeats the clone and the download;
# confirm whether that is acceptable for the intended workflow.
!wget -q https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P weights/

In [None]:
# Mount Drive and setup paths
from google.colab import drive
drive.mount('/content/drive')

import os, json, time, shutil, tempfile, subprocess, requests
from pathlib import Path
import torch, cv2, numpy as np
from tqdm import tqdm
from realesrgan import RealESRGANer
from basicsr.archs.rrdbnet_arch import RRDBNet

# Google Drive folder layout shared by every cell of the pipeline.
DRIVE_ROOT = '/content/drive/MyDrive/movie-recap-pipeline'
UPSCALE_INPUTS = DRIVE_ROOT + '/upscale_inputs'      # job descriptions (*.json) are read from here
UPSCALE_OUTPUTS = DRIVE_ROOT + '/upscale_outputs'    # one sub-folder per finished job
CHECKPOINTS_DIR = DRIVE_ROOT + '/checkpoints'        # resume information for interrupted jobs

# Create the working folders up front so later cells can assume they exist.
for path in (UPSCALE_INPUTS, UPSCALE_OUTPUTS, CHECKPOINTS_DIR):
    os.makedirs(path, exist_ok=True)

# Report which accelerator (if any) the runtime exposes.
if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
else:
    gpu_name = 'None'
print(f"GPU: {gpu_name}")

In [None]:
class VideoUpscaler:
    """Upscale a video with Real-ESRGAN in resumable, fixed-size chunks.

    Frames are decoded with OpenCV, enhanced one at a time, written as short
    ``mp4v`` chunk files and finally joined with ``ffmpeg``'s concat demuxer.
    Finished chunks are stored next to the checkpoint file (under
    ``CHECKPOINTS_DIR``) so that an interrupted job can actually be resumed.

    Only video frames are written to the chunks, so the result carries no
    audio track.
    """

    def __init__(self, model_path='weights/RealESRGAN_x4plus.pth', scale=4):
        """Store the configuration; the model is loaded on first use.

        Args:
            model_path: Path to the Real-ESRGAN weights file.
            scale: Upscaling factor used for the network and for ``enhance``.
        """
        self.scale = scale
        self.model_path = model_path
        self.upsampler = None

    def initialize_model(self):
        """Build the RRDBNet and wrap it in a tiled ``RealESRGANer``."""
        use_gpu = torch.cuda.is_available()
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=self.scale)
        self.upsampler = RealESRGANer(
            scale=self.scale, model_path=self.model_path, model=model,
            tile=400, tile_pad=10, pre_pad=0,
            # Half precision is only requested when a CUDA device is present;
            # forcing it on a CPU-only runtime is not reliably supported.
            half=use_gpu,
            gpu_id=0 if use_gpu else None,
        )
        print("Model loaded")

    def extract_frames(self, video_path, start_frame=0, chunk_size=50):
        """Decode up to ``chunk_size`` frames starting at ``start_frame``.

        Args:
            video_path: Path of the video to read.
            start_frame: Index of the first frame to decode.
            chunk_size: Maximum number of frames to return.

        Returns:
            ``(frames, total_frames, fps)`` where ``frames`` is a list of RGB
            arrays and the other two values are what OpenCV reports for the
            container. ``frames`` is empty when nothing could be read.
        """
        cap = cv2.VideoCapture(video_path)
        try:
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            fps = cap.get(cv2.CAP_PROP_FPS)
            cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)

            frames = []
            for _ in range(chunk_size):
                ret, frame = cap.read()
                if not ret:
                    break
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        finally:
            # Release the decoder even if reading or conversion raised.
            cap.release()
        return frames, total_frames, fps

    def upscale_frames(self, frames):
        """Upscale RGB frames and return RGB frames of uniform size.

        A frame that fails to enhance is replaced by a bicubic resize to the
        same target size, so every frame of a chunk matches the dimensions the
        video writer was opened with.

        Args:
            frames: List of RGB ``uint8`` arrays as produced by
                ``extract_frames``.

        Returns:
            List with one upscaled RGB array per input frame.
        """
        if self.upsampler is None:
            self.initialize_model()

        upscaled = []
        for i, frame in enumerate(tqdm(frames, desc="Upscaling")):
            try:
                # RealESRGANer.enhance follows the OpenCV convention (BGR in,
                # BGR out), so convert around the call instead of feeding it
                # channel-swapped data.
                bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
                result, _ = self.upsampler.enhance(bgr, outscale=self.scale)
                upscaled.append(cv2.cvtColor(result, cv2.COLOR_BGR2RGB))
            except Exception as e:
                print(f"Frame {i} error: {e}")
                height, width = frame.shape[:2]
                target = (int(width * self.scale), int(height * self.scale))
                upscaled.append(cv2.resize(frame, target, interpolation=cv2.INTER_CUBIC))
            if i % 10 == 0:
                torch.cuda.empty_cache()
        return upscaled

    @staticmethod
    def _load_checkpoint(checkpoint_path):
        """Return ``(start_frame, chunk_paths)`` of a resumable checkpoint.

        ``(0, [])`` is returned when the checkpoint is missing, unreadable,
        malformed, or references chunk files that no longer exist; the job
        then simply starts over.
        """
        try:
            with open(checkpoint_path, 'r') as f:
                data = json.load(f)
        except (OSError, ValueError):
            return 0, []
        if not isinstance(data, dict):
            return 0, []

        last_frame = data.get('last_frame', 0)
        chunks = data.get('processed_chunks', [])
        if not isinstance(last_frame, int) or not isinstance(chunks, list) or not chunks:
            return 0, []
        if not all(isinstance(c, str) and os.path.isfile(c) for c in chunks):
            return 0, []
        return last_frame, list(chunks)

    @staticmethod
    def _concat_manifest(chunk_paths):
        """Render the ffmpeg concat-demuxer file list for ``chunk_paths``."""
        lines = []
        for chunk in chunk_paths:
            # Inside single quotes ffmpeg cannot escape anything, so a literal
            # quote is written by closing, escaping and reopening: '\''
            escaped = str(chunk).replace("'", "'\\''")
            lines.append(f"file '{escaped}'\n")
        return ''.join(lines)

    def process_video_chunked(self, input_path, output_path, job_id, webhook_url=None, chunk_size=50):
        """Upscale ``input_path`` chunk by chunk and write ``output_path``.

        Args:
            input_path: Source video.
            output_path: Destination of the concatenated, upscaled video.
            job_id: Identifier used for the checkpoint file and chunk folder.
            webhook_url: Optional URL that receives best-effort progress posts.
            chunk_size: Number of frames handled per chunk.

        Returns:
            ``{'status': 'success', 'output_path': ..., 'total_frames': ...}``
            or ``{'status': 'error', 'error': <message>}``. On error the
            checkpoint and the finished chunks are kept so that a later call
            with the same ``job_id`` continues where this one stopped.
        """
        if not self.upsampler:
            self.initialize_model()

        checkpoint_path = f"{CHECKPOINTS_DIR}/checkpoint_{job_id}.json"
        # Chunks must outlive this call (and the runtime) for resume to work,
        # so they are kept beside the checkpoint rather than in a temp folder.
        chunk_dir = f"{CHECKPOINTS_DIR}/chunks_{job_id}"
        start_frame, processed_chunks = self._load_checkpoint(checkpoint_path)

        # Scratch space for the chunk being encoded and the concat manifest.
        temp_dir = tempfile.mkdtemp()

        try:
            cap = cv2.VideoCapture(input_path)
            try:
                opened = cap.isOpened()
                total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
                fps = cap.get(cv2.CAP_PROP_FPS)
            finally:
                cap.release()
            if not opened or total_frames <= 0 or fps <= 0:
                raise ValueError(f"Cannot read video: {input_path}")

            if not processed_chunks:
                # Fresh start: drop leftovers of an earlier, unusable attempt.
                shutil.rmtree(chunk_dir, ignore_errors=True)
            os.makedirs(chunk_dir, exist_ok=True)

            chunk_num = len(processed_chunks)
            current_frame = start_frame

            while current_frame < total_frames:
                frames, _, _ = self.extract_frames(input_path, current_frame, chunk_size)
                if not frames:
                    break

                upscaled = self.upscale_frames(frames)

                # Encode locally first, then move the finished file into the
                # persistent chunk folder.
                chunk_name = f"chunk_{chunk_num:04d}.mp4"
                local_chunk = f"{temp_dir}/{chunk_name}"
                height, width = upscaled[0].shape[:2]
                out = cv2.VideoWriter(local_chunk, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
                try:
                    if not out.isOpened():
                        raise RuntimeError(f"Cannot open video writer for {local_chunk}")
                    for frame in upscaled:
                        out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
                finally:
                    out.release()

                chunk_path = f"{chunk_dir}/{chunk_name}"
                shutil.move(local_chunk, chunk_path)

                processed_chunks.append(chunk_path)
                current_frame += len(frames)

                # Save checkpoint
                with open(checkpoint_path, 'w') as f:
                    json.dump({
                        'job_id': job_id, 'last_frame': current_frame,
                        'processed_chunks': processed_chunks, 'total_frames': total_frames,
                        'fps': fps, 'timestamp': time.time()
                    }, f)

                progress = (current_frame / total_frames) * 100
                print(f"Progress: {progress:.1f}%")

                if webhook_url:
                    # Progress reporting is best effort and must never abort
                    # the job, but failures are at least made visible.
                    try:
                        requests.post(webhook_url, json={'job_id': job_id, 'progress': progress}, timeout=10)
                    except Exception as e:
                        print(f"Progress webhook failed: {e}")

                chunk_num += 1
                torch.cuda.empty_cache()

            if not processed_chunks:
                raise RuntimeError(f"No frames could be decoded from {input_path}")

            # Concatenate chunks
            filelist = f"{temp_dir}/filelist.txt"
            with open(filelist, 'w') as f:
                f.write(self._concat_manifest(processed_chunks))

            cmd = ['ffmpeg', '-f', 'concat', '-safe', '0', '-i', filelist, '-c', 'copy', '-y', output_path]
            proc = subprocess.run(cmd, capture_output=True)
            if proc.returncode != 0:
                stderr_tail = proc.stderr.decode(errors='replace')[-500:]
                raise RuntimeError(f"ffmpeg concat failed (exit {proc.returncode}): {stderr_tail}")

            # Resume data is only discarded once the final file was written.
            if os.path.exists(checkpoint_path):
                os.remove(checkpoint_path)
            shutil.rmtree(chunk_dir, ignore_errors=True)

            return {'status': 'success', 'output_path': output_path, 'total_frames': total_frames}

        except Exception as e:
            return {'status': 'error', 'error': str(e)}
        finally:
            shutil.rmtree(temp_dir, ignore_errors=True)

# Shared module-level instance used by process_pending_jobs(); constructing it
# only stores settings, the model is loaded when the first video is processed.
upscaler = VideoUpscaler()
print("VideoUpscaler ready")

In [None]:
def _archive_job_file(job_path, dest_dir):
    """Move a job file into ``dest_dir`` (created on demand); no-op if it is gone."""
    if not os.path.exists(job_path):
        return
    os.makedirs(dest_dir, exist_ok=True)
    shutil.move(job_path, f"{dest_dir}/{os.path.basename(job_path)}")


def process_pending_jobs():
    """Run every ``*.json`` job found in ``UPSCALE_INPUTS`` once.

    Each job file must provide ``job_id`` and ``video_path`` and may provide
    ``webhook_url``. The result dictionary of the upscaler is written to
    ``<UPSCALE_OUTPUTS>/<job_id>/result.json``. Job files are then moved to
    ``completed_jobs`` when the upscaler reported success and to
    ``failed_jobs`` otherwise (error result or any exception), so a failed job
    is never recorded as completed.
    """
    # Sorted for a deterministic processing order; os.listdir() gives none.
    job_files = sorted(f for f in os.listdir(UPSCALE_INPUTS) if f.endswith('.json'))
    if not job_files:
        print("No jobs found")
        return

    completed_dir = f"{DRIVE_ROOT}/completed_jobs"
    failed_dir = f"{DRIVE_ROOT}/failed_jobs"

    for job_file in job_files:
        job_path = f"{UPSCALE_INPUTS}/{job_file}"
        try:
            with open(job_path, 'r') as f:
                job = json.load(f)

            job_id = job['job_id']
            video_path = job['video_path']
            webhook_url = job.get('webhook_url')

            print(f"Processing {job_id}")

            output_dir = f"{UPSCALE_OUTPUTS}/{job_id}"
            os.makedirs(output_dir, exist_ok=True)
            output_path = f"{output_dir}/upscaled_video.mp4"

            result = upscaler.process_video_chunked(video_path, output_path, job_id, webhook_url)

            with open(f"{output_dir}/result.json", 'w') as f:
                json.dump(result, f, indent=2)

            if result.get('status') != 'success':
                # The upscaler reports failures through its return value, not
                # by raising, so they have to be routed explicitly.
                print(f"Job {job_id} failed: {result.get('error')}")
                _archive_job_file(job_path, failed_dir)
                continue

            if webhook_url:
                completion_data = {
                    'job_id': job_id, 'status': 'completed',
                    'upscaled_video_path': output_path, 'timestamp': time.time()
                }
                try:
                    requests.post(webhook_url, json=completion_data, timeout=30)
                    print(f"Notified completion for {job_id}")
                except Exception as e:
                    print(f"Notification failed: {e}")

            # Move to completed
            _archive_job_file(job_path, completed_dir)
            print(f"Job {job_id} completed")

        except Exception as e:
            print(f"Error processing {job_file}: {e}")
            # A failing move must not abort the remaining jobs.
            try:
                _archive_job_file(job_path, failed_dir)
            except OSError as move_error:
                print(f"Could not move {job_file} to failed_jobs: {move_error}")

def monitor_jobs(check_interval=60, max_runtime=3600):
    """Poll for pending jobs until ``max_runtime`` seconds have elapsed.

    Args:
        check_interval: Seconds to wait between two polls.
        max_runtime: Overall time budget in seconds; the pause after a poll is
            shortened so the loop does not run past it.

    A ``KeyboardInterrupt`` stops the loop; any other exception raised while
    polling is printed and the loop continues.
    """
    print(f"Monitoring started (check every {check_interval}s, max {max_runtime}s)")
    # Monotonic clock: elapsed time is unaffected by wall-clock adjustments.
    deadline = time.monotonic() + max_runtime

    while time.monotonic() < deadline:
        try:
            process_pending_jobs()
        except KeyboardInterrupt:
            break
        except Exception as e:
            print(f"Monitor error: {e}")

        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        try:
            # Never sleep beyond the deadline.
            time.sleep(max(0, min(check_interval, remaining)))
        except KeyboardInterrupt:
            break

# Printed when this cell has been run, i.e. the two job functions above are defined.
print("Ready to monitor jobs")

In [None]:
# Start monitoring
# Polls for job files every 60 seconds with an overall budget of 3600 seconds.
monitor_jobs(check_interval=60, max_runtime=3600)