# Demo for video enhancement

In [None]:
import gc
import os
import warnings

import cv2
import numpy as np
import torch
from basicsr.archs.rrdbnet_arch import RRDBNet
from moviepy.editor import VideoFileClip
from tqdm import tqdm
warnings.filterwarnings("ignore")

In [17]:
class VideoEnhancer:
    """Enhance videos frame by frame with a Real-ESRGAN (RRDBNet) model.

    Each source frame is downscaled, passed through the network, resized to
    the requested output resolution and written with OpenCV's ``VideoWriter``.

    Attributes:
        device: ``torch.device`` used for inference (CUDA when available).
        batch_size: Source-frame stride; see ``enhance_video``.
        sr_model: The loaded ``RRDBNet`` in eval mode.
    """

    def __init__(self, batch_size=1):
        """Select the device and load the super-resolution model.

        Args:
            batch_size: Enhance only every ``batch_size``-th source frame
                (1 = every frame). Must be >= 1.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
            Exception: Whatever ``setup_models`` raises (download/load errors).
        """
        # Validate first so a bad argument fails before the weights download.
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        # Set CUDA memory management
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            # Set memory allocator settings to reduce fragmentation.
            # NOTE(review): the allocator may only honour this variable if it
            # is set before the first CUDA allocation - confirm it is early
            # enough when other code has already used the GPU.
            os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'

        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.batch_size = batch_size
        print(f"Using device: {self.device}")
        self.setup_models()

    def setup_models(self):
        """Build the RRDBNet and load the RealESRGAN_x4plus weights into it.

        The checkpoint is fetched with ``torch.hub.load_state_dict_from_url``.
        Checkpoints that wrap the weights under ``params_ema`` or ``params``
        are unwrapped before loading.

        Raises:
            Exception: Any download or loading error is printed and re-raised.
        """
        try:
            print("Initializing Real-ESRGAN model...")
            # Initialize model architecture
            self.sr_model = RRDBNet(
                num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23,
                num_grow_ch=32)

            print("Downloading model weights...")
            model_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth"
            model_weights = torch.hub.load_state_dict_from_url(
                model_url,
                progress=True,
                map_location=self.device
            )

            # Prefer the EMA weights when the checkpoint carries both sets.
            if 'params_ema' in model_weights:
                model_weights = model_weights['params_ema']
            elif 'params' in model_weights:
                model_weights = model_weights['params']

            self.sr_model.load_state_dict(model_weights)
            self.sr_model.to(self.device)
            self.sr_model.eval()
            print("Model loaded successfully!")

        except Exception as e:
            print(f"Error during model setup: {e}")
            raise

    def process_frame(self, frame):
        """Run one frame through the super-resolution model.

        The frame is first halved in size to save memory, then converted to a
        float tensor in [0, 1], passed through ``sr_model`` and converted back
        to a ``uint8`` array.

        Args:
            frame: Image array indexed as (height, width, channels).
                # assumes uint8 RGB as delivered by MoviePy - TODO confirm

        Returns:
            The enhanced frame as a ``uint8`` array. On any error the message
            is printed and the un-enhanced frame is returned instead
            (best-effort; it may already have been downscaled).
        """
        try:
            # Reduce input size to save memory
            h, w = frame.shape[:2]
            scale_factor = 0.5  # Reduce input size by half
            # Never ask cv2 for a zero-sized image on tiny inputs.
            frame = cv2.resize(
                frame,
                (max(1, int(w * scale_factor)), max(1, int(h * scale_factor))))

            # Convert to tensor
            frame_tensor = torch.from_numpy(frame).float().div(255.)
            if frame_tensor.shape[2] == 3:
                frame_tensor = frame_tensor.permute(2, 0, 1)
            frame_tensor = frame_tensor.unsqueeze(0).to(self.device)

            # Process
            with torch.no_grad():
                output_tensor = self.sr_model(frame_tensor)

            # Convert back to numpy. Round before the uint8 cast: a plain cast
            # truncates and biases every pixel towards darker values.
            output_frame = output_tensor.squeeze(0).permute(1, 2, 0)
            output_frame = output_frame.mul(255.).round().clamp(0, 255)
            output_frame = output_frame.cpu().numpy().astype(np.uint8)

            # Clear CUDA cache
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            return output_frame

        except Exception as e:
            print(f"Error processing frame: {e}")
            return frame

    def enhance_video(self, input_path, output_path, target_fps=15,
                      output_size=(1280, 720)):
        """Enhance ``input_path`` and write the result to ``output_path``.

        Output frames are sampled from the source at ``1 / target_fps``
        intervals, so the written video keeps the source duration whether
        ``target_fps`` is above or below the source frame rate. A source frame
        that maps to several consecutive output frames is enhanced once and
        written repeatedly.

        When ``self.batch_size`` is greater than 1, only every
        ``batch_size``-th source frame is enhanced and it is held in place of
        the frames in between.

        Args:
            input_path: Path of the video to read.
            output_path: Path of the video to write (``mp4v`` codec).
            target_fps: Frame rate of the written video. Must be > 0.
            output_size: ``(width, height)`` of the written video; defaults to
                720p. Frames are resized to it without preserving the aspect
                ratio.

        Raises:
            ValueError: If ``target_fps`` is not positive or the source
                reports no usable frame rate.
            IOError: If the output video cannot be opened for writing.
            Exception: Any other error is printed and re-raised.
        """
        if target_fps <= 0:
            raise ValueError(f"target_fps must be > 0, got {target_fps}")

        output_width, output_height = output_size
        stride = self.batch_size

        try:
            print(f"Processing video: {input_path}")
            clip = VideoFileClip(input_path)
            out = None

            try:
                original_fps = clip.fps
                if not original_fps or original_fps <= 0:
                    raise ValueError(
                        f"Could not determine FPS of: {input_path}")

                total_frames = int(clip.duration * original_fps)
                total_out_frames = int(clip.duration * target_fps)
                last_source_idx = max(total_frames - 1, 0)

                fourcc = cv2.VideoWriter_fourcc(*'mp4v')
                out = cv2.VideoWriter(output_path,
                                      fourcc,
                                      target_fps,
                                      (output_width, output_height))
                if not out.isOpened():
                    raise IOError(
                        f"Could not open video writer for: {output_path}")

                print(f"Original FPS: {original_fps}, "
                      f"Target FPS: {target_fps}")
                print(f"Total frames to process: {total_frames}")
                print(f"Total frames to write: {total_out_frames}")

                cached_idx = None
                enhanced_frame = None
                processed = 0

                for out_idx in tqdm(range(total_out_frames)):
                    # Source frame shown at this output timestamp. Multiply
                    # before dividing and add an epsilon so exact ratios
                    # (e.g. 30 -> 15 fps) do not round down by float error.
                    source_idx = int(
                        out_idx * original_fps / target_fps + 1e-6)
                    source_idx = min(source_idx, last_source_idx)
                    # Snap to the stride so only every N-th frame is enhanced.
                    source_idx -= source_idx % stride

                    if source_idx != cached_idx:
                        frame = clip.get_frame(source_idx / original_fps)

                        # Process frame
                        enhanced_frame = self.process_frame(frame)

                        # Resize to the output resolution
                        enhanced_frame = cv2.resize(
                            enhanced_frame, (output_width, output_height))

                        # MoviePy yields RGB frames, VideoWriter expects BGR.
                        enhanced_frame = cv2.cvtColor(
                            enhanced_frame, cv2.COLOR_RGB2BGR)

                        cached_idx = source_idx
                        processed += 1

                        # Garbage collection
                        if processed % 100 == 0:
                            gc.collect()
                            if torch.cuda.is_available():
                                torch.cuda.empty_cache()

                    out.write(enhanced_frame)

            finally:
                # Release both handles even when setup or the loop fails.
                clip.close()
                if out is not None:
                    out.release()
                print(f"Finished processing: {input_path}")

        except Exception as e:
            print(f"Error enhancing video: {e}")
            raise

In [18]:
def process_directory(input_dir, output_dir):
    """Enhance every video file in ``input_dir`` into ``output_dir``.

    Files ending in ``.mp4``, ``.avi`` or ``.mov`` (case-insensitive) are
    processed in sorted name order and written under the same file name.
    A failure on one video is printed and the remaining videos are still
    processed. The model is only loaded when there is at least one video.

    Args:
        input_dir: Directory containing the source videos.
        output_dir: Directory for the enhanced videos; created if missing.

    Raises:
        ValueError: If ``input_dir`` and ``output_dir`` are the same
            directory (outputs would overwrite the inputs being read).
        Exception: Any other directory-level error is printed and re-raised.
    """
    video_extensions = ('.mp4', '.avi', '.mov')

    try:
        os.makedirs(output_dir, exist_ok=True)
        print(f"Processing directory: {input_dir}")
        print(f"Output directory: {output_dir}")

        # Outputs reuse the input file name, so writing into the input
        # directory would clobber a file while it is still being read.
        same_dir = (os.path.normcase(os.path.abspath(input_dir))
                    == os.path.normcase(os.path.abspath(output_dir)))
        if same_dir:
            raise ValueError(
                "input_dir and output_dir must be different directories")

        # Sorted for a reproducible order; isfile() skips directories whose
        # names happen to end in a video extension.
        filenames = sorted(
            name for name in os.listdir(input_dir)
            if name.lower().endswith(video_extensions)
            and os.path.isfile(os.path.join(input_dir, name))
        )
        if not filenames:
            # Avoid downloading and loading the model for nothing.
            print(f"No video files found in: {input_dir}")
            return

        enhancer = VideoEnhancer(batch_size=1)

        for filename in filenames:
            input_path = os.path.join(input_dir, filename)
            output_path = os.path.join(output_dir, f'{filename}')

            print(f"\nProcessing {filename}...")
            try:
                enhancer.enhance_video(input_path, output_path)
                print(f"Successfully processed {filename}")
            except Exception as e:
                # Best-effort: report and move on to the next video.
                print(f"Error processing {filename}: {e}")
            finally:
                # Clear memory after each video, including failed ones.
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

    except Exception as e:
        print(f"Error in process_directory: {e}")
        raise

In [None]:
# The source clips ('test_data') and the enhanced results ('enh_data') sit
# side by side in the same experiment folder, so both paths share one prefix.
input_dir, output_dir = (
    os.path.join('C:\\Users', 'mc1159', 'OneDrive - University of Exeter',
                 'Documents', 'VISIONARY', 'Durham Experiment', leaf)
    for leaf in ('test_data', 'enh_data')
)

# Enhance every video found in input_dir and write it to output_dir.
process_directory(input_dir, output_dir)