In [None]:
# This program uses the weights produced by the trainer to process videos frame by frame ('live').

In [3]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import os
import cv2
from PIL import Image
from tqdm import tqdm # Keep tqdm for progress
import time
from collections import deque

# --- Configuration (adjust paths and display options as needed) ---
VIDEO_OUTPUT_CONFIG = dict(
    # Input video, rendered output video, and trained model weights.
    INPUT_VIDEO_PATH="2023-08-01_152123_VID003.mp4",
    OUTPUT_VIDEO_PATH="output_demo_video3.mp4",
    MODEL_PATH="specular_removal_unet_perceptual_emphasis.pth",
    # Size every frame is resized to before it is fed to the model.
    TARGET_IMG_SIZE=(512, 512),
    # Factor applied to the source width/height for each panel of the output.
    OUTPUT_RESOLUTION_SCALE=0.75,
    # Number of most recent frames used for the averaged FPS readout.
    FPS_AVERAGE_WINDOW=30,
    # Optional per-channel normalization of the model input (off by default).
    USE_NORMALIZATION=False,
    NORM_MEAN=[0.485, 0.456, 0.406],
    NORM_STD=[0.229, 0.224, 0.225],
    # 1 -> original and processed frames side by side; anything else -> stacked.
    CONCAT_AXIS=1,
    # Styling of the text overlay drawn with cv2.putText.
    TEXT_COLOR=(0, 255, 0),
    FONT_SCALE=0.7,
    FONT_THICKNESS=2,
)

# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"Using device: {DEVICE}")

# --- U-Net Model Definition (UNet, DoubleConv, Down, Up, OutConv; must match the architecture used by the trainer so the saved weights load) ---
class DoubleConv(nn.Module):
    """Two consecutive 3x3 convolutions, each followed by BatchNorm and ReLU.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the second convolution.
        mid_channels: Number of channels between the two convolutions; when
            omitted (or falsy) it defaults to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels if mid_channels else out_channels
        # The attribute name and the layer order determine the state_dict
        # keys, so both must stay as they are for saved weights to load.
        layers = [
            nn.Conv2d(in_channels, hidden, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(hidden),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply both conv/BN/ReLU stages to ``x``."""
        return self.double_conv(x)

class Down(nn.Module):
    """Encoder stage: 2x2 max-pooling followed by a ``DoubleConv``.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the ``DoubleConv``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        conv = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, conv)

    def forward(self, x):
        """Pool ``x`` and then run the double convolution on the result."""
        return self.maxpool_conv(x)

class Up(nn.Module):
    """Decoder stage: upsample, pad to the skip connection's size, concat, convolve.

    Args:
        in_channels: Number of channels after concatenating the upsampled
            tensor with the skip connection.
        out_channels: Number of channels produced by the stage.
        bilinear: When true, upsample with ``nn.Upsample`` (bilinear);
            otherwise use a learned ``nn.ConvTranspose2d``.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        if not bilinear:
            # Learned upsampling halves the channel count itself.
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
        else:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)

    def forward(self, x1, x2):
        """Upsample ``x1``, align it with skip tensor ``x2`` and fuse both."""
        upsampled = self.up(x1)
        # Pad the upsampled map so its last two dimensions match x2's.
        gap_h = x2.size(2) - upsampled.size(2)
        gap_w = x2.size(3) - upsampled.size(3)
        left, top = gap_w // 2, gap_h // 2
        upsampled = F.pad(upsampled, [left, gap_w - left, top, gap_h - top])
        fused = torch.cat([x2, upsampled], dim=1)
        return self.conv(fused)

class OutConv(nn.Module):
    """Final 1x1 convolution mapping features to the output channels.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels of the produced map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Apply the 1x1 convolution to ``x``."""
        projected = self.conv(x)
        return projected

class UNet(nn.Module):
    """U-Net with four down stages, four up stages and a sigmoid output.

    Args:
        n_channels_in: Number of channels of the input tensor.
        n_channels_out: Number of channels of the output tensor.
        bilinear: Forwarded to every ``Up`` stage; also halves the channel
            count of the bottleneck and of the first three decoder stages.
        base_features: Channel count of the first stage; deeper stages use
            multiples of it.
    """

    def __init__(self, n_channels_in, n_channels_out, bilinear=True, base_features=64):
        super().__init__()
        width = base_features
        factor = 2 if bilinear else 1

        # Encoder (submodule names and order define the checkpoint keys).
        self.inc = DoubleConv(n_channels_in, width)
        self.down1 = Down(width, width * 2)
        self.down2 = Down(width * 2, width * 4)
        self.down3 = Down(width * 4, width * 8)
        self.down4 = Down(width * 8, width * 16 // factor)

        # Decoder.
        self.up1 = Up(width * 16, width * 8 // factor, bilinear)
        self.up2 = Up(width * 8, width * 4 // factor, bilinear)
        self.up3 = Up(width * 4, width * 2 // factor, bilinear)
        self.up4 = Up(width * 2, width, bilinear)
        self.outc = OutConv(width, n_channels_out)

    def forward(self, x):
        """Run the encoder/decoder and squash the result with a sigmoid."""
        skip1 = self.inc(x)
        skip2 = self.down1(skip1)
        skip3 = self.down2(skip2)
        skip4 = self.down3(skip3)
        bottleneck = self.down4(skip4)

        # Each decoder stage consumes the matching encoder output as its skip.
        decoded = self.up1(bottleneck, skip4)
        decoded = self.up2(decoded, skip3)
        decoded = self.up3(decoded, skip2)
        decoded = self.up4(decoded, skip1)
        return torch.sigmoid(self.outc(decoded))

# --- Helper Functions ---
def load_model(model_path, n_channels_in=3, n_channels_out=3, base_features=64):
    """Build a ``UNet`` and load its weights from ``model_path``.

    Args:
        model_path: Path of the saved state dict.
        n_channels_in: Input channels of the network.
        n_channels_out: Output channels of the network.
        base_features: Channel count of the first U-Net stage.

    Returns:
        The model on ``DEVICE`` in evaluation mode, or ``None`` when loading
        failed (the reason is printed).
    """
    net = UNet(
        n_channels_in=n_channels_in,
        n_channels_out=n_channels_out,
        base_features=base_features,
    )
    try:
        weights = torch.load(model_path, map_location=DEVICE)
        net.load_state_dict(weights)
        net.to(DEVICE)
        net.eval()
        print(f"Model loaded from {model_path} and set to evaluation mode on {DEVICE}.")
        return net
    except FileNotFoundError:
        print(f"Error: Model file not found at {model_path}")
    except Exception as e:
        print(f"Error loading model: {e}")
    # Reached only when one of the handlers above ran.
    return None


def preprocess_frame(frame_bgr, target_size, use_normalization=False, norm_mean=None, norm_std=None):
    """Convert an OpenCV BGR frame into a batched model input tensor.

    Args:
        frame_bgr: Frame as returned by OpenCV (BGR channel order).
        target_size: Size passed to ``transforms.Resize``.
        use_normalization: Whether to append ``transforms.Normalize``.
        norm_mean: Per-channel mean; required when normalizing.
        norm_std: Per-channel std; required when normalizing.

    Returns:
        The transformed frame with a leading batch dimension of size 1.

    Raises:
        ValueError: If normalization is requested without mean or std.
    """
    rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
    image = Image.fromarray(rgb)

    steps = [transforms.Resize(target_size), transforms.ToTensor()]
    if use_normalization:
        stats_missing = norm_mean is None or norm_std is None
        if stats_missing:
            raise ValueError("Normalization mean and std must be provided if use_normalization is True.")
        steps.append(transforms.Normalize(mean=norm_mean, std=norm_std))

    pipeline = transforms.Compose(steps)
    tensor = pipeline(image)
    return tensor.unsqueeze(0)

def postprocess_output(tensor_output, use_normalization=False, norm_mean=None, norm_std=None):
    """Turn a model output tensor back into a uint8 BGR image for OpenCV.

    Args:
        tensor_output: Model output; its leading dimension is squeezed away
            and the remaining axes are permuted with ``(1, 2, 0)``.
        use_normalization: Whether to undo the input normalization
            (``x * std + mean``) before scaling.
        norm_mean: Per-channel mean; required when ``use_normalization``.
        norm_std: Per-channel std; required when ``use_normalization``.

    Returns:
        A ``numpy.uint8`` image in BGR channel order.

    Raises:
        ValueError: If denormalization is requested without mean or std.
    """
    channels_last = tensor_output.squeeze(0).cpu().permute(1, 2, 0)

    if use_normalization:
        if norm_mean is None or norm_std is None:
            raise ValueError("Normalization mean and std must be provided for postprocessing.")
        # Shape the statistics so they broadcast over the last (channel) axis.
        mean = torch.tensor(norm_mean).view(1, 1, -1)
        std = torch.tensor(norm_std).view(1, 1, -1)
        channels_last = channels_last * std + mean

    scaled = channels_last.numpy() * 255
    rgb_u8 = scaled.clip(0, 255).astype(np.uint8)
    return cv2.cvtColor(rgb_u8, cv2.COLOR_RGB2BGR)

# --- Video Processing Function (Revised Loop and Cleanup) ---
def process_video_to_file(config):
    """Run the model on every frame of a video and write a comparison video.

    Each frame read from ``config["INPUT_VIDEO_PATH"]`` is passed through the
    model loaded from ``config["MODEL_PATH"]``. The original and the processed
    frame are resized, concatenated (side by side when
    ``config["CONCAT_AXIS"] == 1``, stacked vertically otherwise), annotated
    with timing text and written to ``config["OUTPUT_VIDEO_PATH"]``.

    Args:
        config: Mapping providing the keys read below: ``MODEL_PATH``,
            ``INPUT_VIDEO_PATH``, ``OUTPUT_VIDEO_PATH``, ``TARGET_IMG_SIZE``,
            ``OUTPUT_RESOLUTION_SCALE``, ``FPS_AVERAGE_WINDOW``,
            ``USE_NORMALIZATION``, ``NORM_MEAN``, ``NORM_STD``,
            ``CONCAT_AXIS``, ``TEXT_COLOR``, ``FONT_SCALE`` and
            ``FONT_THICKNESS``.

    Returns:
        None. Failures (model/video/writer could not be opened, or an error
        while processing) are reported with ``print`` instead of being raised.
    """
    model = load_model(config["MODEL_PATH"])
    if model is None: return # Exit if model loading failed

    cap = cv2.VideoCapture(config["INPUT_VIDEO_PATH"])
    if not cap.isOpened():
        print(f"Error: Could not open video {config['INPUT_VIDEO_PATH']}")
        return

    out_video = None # Initialize to None
    # Must exist before the try block: the finally clause runs on the early
    # return below and on any exception raised before the bar is created.
    pbar = None

    try:
        source_fps = cap.get(cv2.CAP_PROP_FPS)
        source_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        source_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames_metadata = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) # Metadata count

        # Size of ONE panel (original or processed) in the output video.
        display_w = int(source_w * config["OUTPUT_RESOLUTION_SCALE"])
        display_h = int(source_h * config["OUTPUT_RESOLUTION_SCALE"])

        if config["CONCAT_AXIS"] == 1: out_w, out_h = display_w * 2, display_h
        else: out_w, out_h = display_w, display_h * 2

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out_video = cv2.VideoWriter(config["OUTPUT_VIDEO_PATH"], fourcc, source_fps, (out_w, out_h))
        if not out_video.isOpened():
            print(f"Error: Could not open VideoWriter for {config['OUTPUT_VIDEO_PATH']}")
            return # No cap.release() here as it's in finally

        print(f"Processing video: {config['INPUT_VIDEO_PATH']}")
        print(f"Source: {source_w}x{source_h} @ {source_fps:.2f} FPS (Metadata total frames: {total_frames_metadata})")
        print(f"Outputting combined video to: {config['OUTPUT_VIDEO_PATH']} ({out_w}x{out_h})")

        # Sliding window of the most recent per-frame FPS values.
        fps_deque = deque(maxlen=config["FPS_AVERAGE_WINDOW"])

        # Use tqdm without specifying total initially, update manually
        pbar = tqdm(desc="Processing Video")
        frame_count_processed = 0

        while True: # More robust loop
            ret, original_frame_full_res = cap.read()
            if not ret:
                pbar.set_description("End of video or error reading frame.")
                break # Exit loop if no frame or error

            # perf_counter is monotonic and high resolution, so very short
            # frames do not measure as zero elapsed time.
            process_start_time = time.perf_counter()
            input_tensor = preprocess_frame(
                original_frame_full_res, config["TARGET_IMG_SIZE"],
                config["USE_NORMALIZATION"], config["NORM_MEAN"], config["NORM_STD"]
            ).to(DEVICE)

            with torch.no_grad():
                output_tensor = model(input_tensor)

            processed_frame_model_size = postprocess_output(
                output_tensor, config["USE_NORMALIZATION"],
                config["NORM_MEAN"], config["NORM_STD"]
            )
            # Covers preprocessing, the forward pass and postprocessing.
            process_time = time.perf_counter() - process_start_time

            frame_display_orig = cv2.resize(original_frame_full_res, (display_w, display_h))
            frame_display_processed = cv2.resize(processed_frame_model_size, (display_w, display_h))

            if config["CONCAT_AXIS"] == 1:
                combined_frame = np.concatenate((frame_display_orig, frame_display_processed), axis=1)
            else:
                combined_frame = np.concatenate((frame_display_orig, frame_display_processed), axis=0)

            current_fps = 1.0 / process_time if process_time > 0 else float('inf')
            fps_deque.append(current_fps)
            avg_fps = np.mean(fps_deque) if fps_deque else 0

            fps_text = f"FPS: {current_fps:.1f} (Avg: {avg_fps:.1f})"
            infer_text = f"Infer: {process_time*1000:.1f}ms"
            cv2.putText(combined_frame, fps_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX,
                        config["FONT_SCALE"], config["TEXT_COLOR"], config["FONT_THICKNESS"], cv2.LINE_AA)
            cv2.putText(combined_frame, infer_text, (10, 60), cv2.FONT_HERSHEY_SIMPLEX,
                        config["FONT_SCALE"], config["TEXT_COLOR"], config["FONT_THICKNESS"], cv2.LINE_AA)

            out_video.write(combined_frame)

            frame_count_processed += 1
            pbar.update(1) # Manually update tqdm progress
            pbar.set_postfix_str(f"Frames: {frame_count_processed}, CurFPS: {current_fps:.1f}")

    except Exception as e:
        # Best effort: report the problem and still run the cleanup below.
        print(f"An error occurred during processing: {e}")
    finally:
        if pbar is not None:
            pbar.close() # Ensure tqdm progress bar is closed
        if cap.isOpened():
            cap.release()
            print("Input video capture released.")
        if out_video is not None and out_video.isOpened(): # Check if out_video was successfully opened
            out_video.release()
            print(f"Output video writer released. Video saved to: {config['OUTPUT_VIDEO_PATH']}")
        cv2.destroyAllWindows() # Good practice
        print("Video processing finished.")


if __name__ == "__main__":
    # Verify both required files exist before starting the processing run.
    model_file = VIDEO_OUTPUT_CONFIG["MODEL_PATH"]
    video_file = VIDEO_OUTPUT_CONFIG["INPUT_VIDEO_PATH"]

    model_present = os.path.exists(model_file)
    if not model_present:
        print(f"Error: Model file not found at {model_file}")
    elif not os.path.exists(video_file):
        print(f"Error: Input video file not found at {video_file}")
    else:
        process_video_to_file(VIDEO_OUTPUT_CONFIG)

Using device: cuda
Model loaded from specular_removal_unet_perceptual_emphasis.pth and set to evaluation mode on cuda.
Processing video: 2023-08-01_152123_VID003.mp4
Source: 512x512 @ 29.97 FPS (Metadata total frames: 19454)
Outputting combined video to: output_demo_video3.mp4 (768x384)


End of video or error reading frame.: : 19454it [12:43, 25.49it/s, Frames: 19454, CurFPS: 32.5]

Input video capture released.
Output video writer released. Video saved to: output_demo_video3.mp4
Video processing finished.



