# In [None]:
import cv2
import torch
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from typing import Tuple, Optional

class VideoProcessor:
    """Run a segmentation model frame-by-frame over a video and render the masks.

    Frames are resized to ``target_size`` for inference. Predictions are mapped
    back to the frame's original size, colourised via the module-level
    ``IDD_CLASSES`` table (``class_id -> (name, color)``) and annotated with
    class names.
    """

    # Used when the capture reports no usable frame rate (0, negative, NaN or
    # inf); cv2.VideoWriter needs a positive fps.
    _DEFAULT_FPS = 30.0

    def __init__(self, model, device, target_size: Tuple[int, int] = (256, 256)):
        """
        Args:
            model: Segmentation network; it is switched to eval mode here.
            device: Device that input tensors are moved to (should match ``model``).
            target_size: Inference size as ``(width, height)``, the order that
                ``cv2.resize`` expects for its ``dsize`` argument.
        """
        self.model = model
        self.device = device
        self.target_size = target_size
        self.model.eval()

    def preprocess_frame(self, frame: np.ndarray, transform=None) -> torch.Tensor:
        """Turn a BGR frame into a batched model input on ``self.device``.

        The frame is resized to ``target_size`` and converted to RGB first.

        Args:
            frame: Image in OpenCV's BGR channel order (as from ``VideoCapture.read``).
            transform: Optional callable used as ``transform(image=rgb)['image']``
                (albumentations-style). Its result must be a tensor. If the
                transform raises, the frame falls back to ``_basic_preprocessing``.

        Returns:
            A tensor with a leading batch dimension of size 1.
        """
        frame_resized = cv2.resize(frame, self.target_size)
        frame_rgb = cv2.cvtColor(frame_resized, cv2.COLOR_BGR2RGB)

        if transform is None:
            return self._basic_preprocessing(frame_rgb)

        try:
            augmented = transform(image=frame_rgb)
            return augmented['image'].unsqueeze(0).to(self.device)
        except Exception as e:
            # NOTE(review): the fallback only scales to [0, 1]. If ``transform``
            # also normalises (mean/std), fallback frames get differently-scaled
            # inputs. Confirm this degradation is acceptable.
            print(f"Transform error: {e}")
            return self._basic_preprocessing(frame_rgb)

    def _basic_preprocessing(self, frame_rgb: np.ndarray) -> torch.Tensor:
        """Convert an ``(H, W, 3)`` RGB array to a ``(1, 3, H, W)`` float tensor.

        Values are divided by 255, so 8-bit input (presumably uint8 video
        frames) lands in [0, 1].
        """
        frame_tensor = torch.from_numpy(frame_rgb.transpose(2, 0, 1)).float()
        frame_tensor = frame_tensor.unsqueeze(0).to(self.device)
        return frame_tensor / 255.0

    def process_frame(
        self, frame: np.ndarray, transform=None
    ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
        """Predict a per-pixel class-index mask for one BGR frame.

        Returns:
            ``(pred_mask, pred_mask_resized)``: the argmax mask at the model's
            output resolution, and the same mask resized with nearest-neighbour
            interpolation to the frame's original size (int32). Returns
            ``(None, None)`` if inference fails; the error is printed, not raised.
        """
        # cv2 sizes are (width, height): the reverse of numpy's (rows, cols).
        original_size = (frame.shape[1], frame.shape[0])
        frame_tensor = self.preprocess_frame(frame, transform)

        with torch.no_grad():
            try:
                output = self.model(frame_tensor)
                # Assumes logits shaped (N, num_classes, H, W). TODO confirm.
                pred_mask = torch.argmax(output, dim=1).cpu().numpy()[0]

                # Nearest-neighbour keeps labels discrete (no blended class ids).
                # The float32 round-trip is needed because OpenCV 4.x has no
                # 64-bit integer image type.
                pred_mask_resized = cv2.resize(
                    pred_mask.astype(np.float32),
                    original_size,
                    interpolation=cv2.INTER_NEAREST,
                ).astype(np.int32)
            except Exception as e:
                print(f"Prediction error: {e}")
                return None, None

        return pred_mask, pred_mask_resized

    def create_colored_mask(self, pred_mask: np.ndarray) -> np.ndarray:
        """Paint each class id in ``pred_mask`` with its ``IDD_CLASSES`` colour.

        Ids absent from ``IDD_CLASSES`` stay black. The result is an
        ``(H, W, 3)`` uint8 image.
        """
        height, width = pred_mask.shape
        colored_mask = np.zeros((height, width, 3), dtype=np.uint8)
        # NOTE(review): OpenCV later treats this image as BGR (it is blended
        # with the BGR frame and passed to VideoWriter). If IDD_CLASSES stores
        # RGB tuples, red and blue will appear swapped. Confirm.
        for class_id, (_name, color) in IDD_CLASSES.items():
            colored_mask[pred_mask == class_id] = color
        return colored_mask

    def add_labels_to_mask(
        self,
        colored_mask: np.ndarray,
        pred_mask: np.ndarray,
        ignore_class: Optional[int] = 13,
        min_area: int = 100,
    ) -> np.ndarray:
        """Return a copy of ``colored_mask`` with class names drawn on each region.

        Every connected region (8-connectivity) of a class that covers more
        than ``min_area`` pixels gets its ``IDD_CLASSES`` name written in white,
        starting at the region's centroid.

        Args:
            colored_mask: Image to annotate; it is not modified.
            pred_mask: Class-index mask with the same height/width.
            ignore_class: Class id that is never labelled. The default of 13 is
                presumably the background/void class; confirm against IDD_CLASSES.
                Pass ``None`` to label every class.
            min_area: Regions with this many pixels or fewer are skipped.
        """
        mask_with_labels = colored_mask.copy()
        height, width = pred_mask.shape

        # Scale text with the frame size, but keep it readable on small frames.
        font_scale = max(min(width, height) / 1000.0, 0.3)
        thickness = max(1, int(font_scale * 2))

        for class_id in np.unique(pred_mask):
            if ignore_class is not None and class_id == ignore_class:
                continue
            entry = IDD_CLASSES.get(int(class_id))
            if entry is None:
                # Unknown id (also left black by create_colored_mask): nothing
                # to name, so skip it instead of aborting the whole video.
                continue
            class_name = entry[0]

            binary = (pred_mask == class_id).astype(np.uint8)
            # Stats and centroids give every component's area and centre in a
            # single pass, instead of rescanning the image once per component.
            num_labels, _, stats, centroids = cv2.connectedComponentsWithStats(binary)
            for label in range(1, num_labels):  # label 0 is the background
                if stats[label, cv2.CC_STAT_AREA] <= min_area:
                    continue
                center_x, center_y = (int(c) for c in centroids[label])
                cv2.putText(
                    mask_with_labels,
                    class_name,
                    (center_x, center_y),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    font_scale,
                    (255, 255, 255),
                    thickness,
                    cv2.LINE_AA,
                )

        return mask_with_labels

    @staticmethod
    def _resolve_fps(raw_fps: float) -> float:
        """Return ``raw_fps`` as a float, or ``_DEFAULT_FPS`` if it is unusable.

        The value is kept fractional: truncating e.g. 29.97 to 29 makes the
        output play slightly slow and drift from the source's timing.
        """
        fps = float(raw_fps)
        if fps > 0 and np.isfinite(fps):
            return fps
        return VideoProcessor._DEFAULT_FPS

    def _render_frame(
        self, frame: np.ndarray, pred_mask_resized: np.ndarray, overlay: bool, alpha: float
    ) -> np.ndarray:
        """Colourise and label a mask; optionally alpha-blend it over ``frame``."""
        colored_mask = self.create_colored_mask(pred_mask_resized)
        labeled_mask = self.add_labels_to_mask(colored_mask, pred_mask_resized)
        if not overlay:
            return labeled_mask
        return cv2.addWeighted(frame, 1 - alpha, labeled_mask, alpha, 0)

    def process_video(
        self,
        video_path: str,
        output_path: str,
        transform=None,
        overlay: bool = False,
        alpha: float = 0.5,
    ):
        """Segment every frame of ``video_path`` and write the result to ``output_path``.

        The output is encoded with the ``mp4v`` FourCC at the source's size and
        frame rate.

        Args:
            video_path: Input video readable by ``cv2.VideoCapture``.
            output_path: Destination file for ``cv2.VideoWriter``.
            transform: Passed through to ``preprocess_frame``.
            overlay: If True, blend the labelled mask over the original frame;
                otherwise write the labelled mask alone.
            alpha: Mask weight used when ``overlay`` is True.

        This is a best-effort batch entry point. Errors are printed rather than
        raised, and capture/writer handles are always released. Frames whose
        inference fails are still written (the original frame when overlaying,
        black otherwise), so the output keeps the source's frame count and timing.
        """
        cap = None
        out = None
        try:
            cap = cv2.VideoCapture(video_path)
            if not cap.isOpened():
                raise ValueError(f"Could not open video file: {video_path}")

            frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            fps = self._resolve_fps(cap.get(cv2.CAP_PROP_FPS))
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height))
            if not out.isOpened():
                raise ValueError(f"Could not open video writer: {output_path}")

            frame_count = 0
            failed_count = 0
            # Some streams report 0 or -1 frames; tqdm then counts open-ended.
            progress_total = total_frames if total_frames > 0 else None
            with tqdm(total=progress_total, desc="Processing video") as pbar:
                while cap.isOpened():
                    ret, frame = cap.read()
                    if not ret:
                        break

                    _, pred_mask_resized = self.process_frame(frame, transform)
                    if pred_mask_resized is None:
                        failed_count += 1
                        # Placeholder keeps the output in sync with the source.
                        output_frame = frame if overlay else np.zeros_like(frame)
                    else:
                        output_frame = self._render_frame(
                            frame, pred_mask_resized, overlay, alpha
                        )
                    out.write(output_frame)

                    frame_count += 1
                    pbar.update(1)

            print(f"Processed {frame_count - failed_count} frames successfully")
            if failed_count:
                print(f"{failed_count} frames failed and were written as placeholders")

        except Exception as e:
            print(f"Error processing video: {e}")

        finally:
            if cap is not None:
                cap.release()
            if out is not None:
                out.release()


# In [None]:

# Build the network and load trained weights. map_location remaps tensors
# saved on another device (e.g. a GPU checkpoint) onto `device`, so this cell
# also works on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
model = UNet(n_channels=3, n_classes=14).to(device)
model.load_state_dict(torch.load('best_model.pth', map_location=device))

processor = VideoProcessor(
    model=model,
    device=device,
    target_size=(256, 256)  # (width, height) the frames are resized to for inference
)

processor.process_video(
    video_path="any_size_video.mp4",
    output_path="processed_output.mp4",
    transform=get_transforms("val"),
    overlay=True  # Set to False for mask-only output
)