In [None]:
# Install required packages
!pip install torch torchvision torchaudio
!pip install opencv-python numpy Pillow

# Import libraries
import torch
import cv2
import numpy as np
from PIL import Image
from torchvision.transforms import ToTensor

# Fetch the pretrained Robust Video Matting network through torch.hub.
# "mobilenetv3" is used here; "resnet50" can be substituted for better quality.
model = torch.hub.load("PeterL1n/RobustVideoMatting", "mobilenetv3", pretrained=True)

# Switch to inference mode, then move to the GPU when one is available.
model = model.eval()
if torch.cuda.is_available():
    model = model.cuda()

# Ask the user to upload a file via google.colab and take the first key of
# the returned mapping as the path of the video to process.
from google.colab import files

uploaded = files.upload()
video_path = list(uploaded)[0]

# Video processing function
def process_video_rvm(input_path, output_path, background_color=(0, 0, 0),
                      downsample_ratio=0.25):
    """Composite the matted foreground of a video over a solid background.

    Each frame of ``input_path`` is run through the module-level ``model``
    (the recurrent states returned for one frame are fed back in for the
    next), blended over ``background_color`` using the predicted alpha, and
    written to ``output_path`` with the ``mp4v`` codec at the source frame
    rate and resolution.

    Args:
        input_path: Path of the video file to read.
        output_path: Path of the video file to write.
        background_color: Three channel values on a 0-255 scale. The blend
            happens after the frame is converted from BGR to RGB, so the
            values are interpreted in (R, G, B) order.
        downsample_ratio: Value forwarded to the model as its last
            positional argument on every frame. Defaults to 0.25; adjust
            based on video resolution.

    Returns:
        ``output_path``, unchanged.

    Raises:
        IOError: If the input video cannot be opened or the output writer
            cannot be created.
    """
    cap = cv2.VideoCapture(input_path)
    if not cap.isOpened():
        cap.release()
        raise IOError(f"Could not open input video: {input_path}")

    out = None
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        if not out.isOpened():
            raise IOError(
                f"Could not open output video for writing: {output_path} "
                f"(fps={fps}, size={width}x{height})"
            )

        # Decide the device once instead of re-querying CUDA on every frame.
        use_cuda = torch.cuda.is_available()
        to_tensor = ToTensor()

        # The model is called with four recurrent states; None on the first frame.
        rec = [None] * 4

        # Background as a (1, 3, 1, 1) tensor scaled to [0, 1] so it
        # broadcasts against the model outputs.
        bg = torch.tensor(background_color).view(1, 3, 1, 1).float() / 255
        if use_cuda:
            bg = bg.cuda()

        frame_count = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # OpenCV delivers BGR; convert to RGB before building the tensor.
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            src = to_tensor(frame_rgb).unsqueeze(0)
            if use_cuda:
                src = src.cuda()

            with torch.no_grad():
                fgr, pha, *rec = model(src, *rec, downsample_ratio)
                # Alpha-blend the foreground over the flat background.
                composite = fgr * pha + bg * (1 - pha)
                # Keep values in [0, 1] so the uint8 cast below cannot wrap.
                composite = composite.clamp(0, 1)

            out_frame = composite[0].permute(1, 2, 0).cpu().numpy()
            out_frame = (out_frame * 255).astype(np.uint8)
            out_frame = cv2.cvtColor(out_frame, cv2.COLOR_RGB2BGR)

            out.write(out_frame)
            frame_count += 1
            if frame_count % 10 == 0:
                print(f"Processed frame {frame_count}")
    finally:
        # Release the capture and writer even if inference fails mid-video.
        cap.release()
        if out is not None:
            out.release()

    print(f"Processing complete. Saved to {output_path}")
    return output_path

# Run the matting pipeline on the uploaded video; the function hands back
# the path it wrote to, which is then offered as a browser download.
output_path = process_video_rvm(video_path, 'output-final-crowd_rvm.mp4')

files.download(output_path)


