Trying out different depth-estimation models and comparing their performance for live depth mapping.

In [None]:
%pip install transformers torch torchvision accelerate aiortc aiohttp Pillow --quiet


import asyncio, cv2, numpy as np, aiohttp, logging, base64, torch
from aiortc import RTCPeerConnection, RTCSessionDescription
from aiortc.mediastreams import MediaStreamError
from IPython.display import display, HTML
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForDepthEstimation
# from transformers import GLPNImageProcessor, GLPNForDepthEstimation      # For glpn models

In [None]:
# WHEP endpoint that serves the video stream (placeholder - fill in your own host).
WHEP_URL = "https://(          Your Url          )/whep"
# Target number of display refreshes per second (15 is the lighter alternative).
DISPLAY_FPS = 30

# Raise the log threshold of the aiortc H.264 codec logger and the libav logger.
logging.getLogger("aiortc.codecs.h264").setLevel(logging.ERROR)
logging.getLogger("libav").setLevel(logging.CRITICAL)

# Hugging Face checkpoint under test. Other candidates noted for comparison:
#   "vinvino02/glpn-kitti"   "vinvino02/glpn-nyu"   "apple/DepthPro-hf"
#   "LiheYoung/depth-anything-large-hf"   "LiheYoung/depth-anything-small-hf"
#   "Intel/dpt-large"
checkpoint = "LiheYoung/depth-anything-small-hf"

try:
    # float16 is selected only together with CUDA; the CPU path stays float32.
    device, dtype = (
        ("cuda", torch.float16)
        if torch.cuda.is_available()
        else ("cpu", torch.float32)
    )
    print(f"Loading model on device: '{device}' with dtype: {dtype}")

    image_processor = AutoImageProcessor.from_pretrained(checkpoint)
    depth_model = AutoModelForDepthEstimation.from_pretrained(
        checkpoint, torch_dtype=dtype, device_map="auto"
    )
    print("Model and processor loaded successfully.")
except Exception as e:
    # Report the failure in the cell output, then let the original error propagate.
    print(f"Failed to load model: {e}")
    raise

In [None]:
class WebRTCStreamViewer:
    """Show a WHEP video stream side by side with a colour-mapped depth estimate.

    One background task stores the most recently received frame; a second
    task periodically runs the depth model on that frame and updates an
    inline ``<img>`` element in the notebook output.

    Args:
        whep_url: URL the SDP offer is POSTed to.
        processor: Image processor called as ``processor(images=..., return_tensors="pt")``.
        model: Depth model whose output exposes ``predicted_depth``.
        device: Device the input tensor is moved to before inference.
        dtype: Dtype the input tensor is cast to before inference.
    """

    def __init__(self, whep_url, processor, model, device, dtype):
        self.whep_url = whep_url
        self.processor = processor
        self.model = model
        self.device = device
        self.dtype = dtype
        self.pc = RTCPeerConnection()
        self.done = asyncio.Event()
        self.lock = asyncio.Lock()
        self.latest_frame = None
        # The event loop keeps only weak references to tasks, so hold strong
        # ones here to stop a running task from being garbage-collected.
        self._tasks = set()

        @self.pc.on("track")
        async def on_track(track):
            if track.kind == "video":
                self._spawn(self._frame_receiver_task(track))

        @self.pc.on("connectionstatechange")
        async def on_connectionstatechange():
            if self.pc.connectionState in ["failed", "closed", "disconnected"]:
                self.done.set()

    def _spawn(self, coro):
        """Schedule *coro* as a task and keep a reference until it finishes."""
        task = asyncio.create_task(coro)
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)
        return task

    @staticmethod
    def _normalize_depth(depth):
        """Scale a depth map to uint8 so that its maximum maps to 255.

        The computation is done in float32: multiplying a float16 map by 255
        can overflow to ``inf``. A map whose maximum is zero, negative or
        non-finite yields an all-zero image instead of dividing by zero.
        """
        depth = np.asarray(depth, dtype=np.float32)
        peak = float(depth.max()) if depth.size else 0.0
        if not np.isfinite(peak) or peak <= 0:
            return np.zeros(depth.shape, dtype=np.uint8)
        # Divide first so the maximum becomes exactly 1.0 and maps to 255.
        # Clip because bicubic resampling can undershoot below zero.
        return np.clip(depth / peak * 255.0, 0, 255).astype(np.uint8)

    def _render_side_by_side(self, frame_bgr):
        """Return *frame_bgr* horizontally stacked with its depth colour map."""
        original_h, original_w = frame_bgr.shape[:2]

        rgb_image = Image.fromarray(cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB))
        inputs = self.processor(images=rgb_image, return_tensors="pt")
        pixel_values = inputs.pixel_values.to(self.device, dtype=self.dtype)

        with torch.no_grad():
            predicted_depth = self.model(pixel_values).predicted_depth

        # Resample the prediction back to the resolution of the video frame.
        prediction = torch.nn.functional.interpolate(
            predicted_depth.unsqueeze(1),
            size=(original_h, original_w),
            mode="bicubic",
            align_corners=False,
        )[0, 0]

        formatted = self._normalize_depth(prediction.cpu().numpy())
        depth_colormap = cv2.applyColorMap(formatted, cv2.COLORMAP_JET)
        return np.hstack((frame_bgr, depth_colormap))

    async def _frame_receiver_task(self, track):
        """Keep ``latest_frame`` pointing at the newest decoded frame of *track*."""
        while not self.done.is_set():
            try:
                frame = await track.recv()
                img = frame.to_ndarray(format="bgr24")
                async with self.lock:
                    self.latest_frame = img
            except MediaStreamError:
                return

    async def _display_loop_task(self):
        """Render the latest frame roughly ``DISPLAY_FPS`` times per second."""
        display_handle = display(HTML('<img>'), display_id=True)

        try:
            while not self.done.is_set():
                frame_to_display = None
                async with self.lock:
                    if self.latest_frame is not None:
                        frame_to_display = self.latest_frame.copy()

                if frame_to_display is not None:
                    combined_frame = self._render_side_by_side(frame_to_display)

                    ok, buffer = cv2.imencode('.jpg', combined_frame)
                    if ok:
                        b64_str = base64.b64encode(buffer).decode('utf-8')
                        data_url = f"data:image/jpeg;base64,{b64_str}"
                        display_handle.update(HTML(f'<img src="{data_url}" style="width: 80%;" />'))

                await asyncio.sleep(1 / DISPLAY_FPS)
        except asyncio.CancelledError:
            raise
        except Exception as e:
            # Without this, a rendering error would end the task unnoticed and
            # run() would keep waiting on ``done`` indefinitely.
            print(f"Error in display task: {e}")
            self.done.set()

    async def run(self):
        """Negotiate the WHEP session and block until the viewer is done.

        Sends the local SDP offer to ``whep_url``; a 201 response body is
        applied as the remote answer, any other status ends the session.
        Background tasks are cancelled and the peer connection is closed on
        the way out.
        """
        self.pc.addTransceiver("video", direction="recvonly")
        offer = await self.pc.createOffer()
        await self.pc.setLocalDescription(offer)

        try:
            self._spawn(self._display_loop_task())
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    self.whep_url,
                    data=self.pc.localDescription.sdp,
                    headers={"Content-Type": "application/sdp"},
                    timeout=aiohttp.ClientTimeout(total=15),
                ) as resp:
                    if resp.status == 201:
                        answer_sdp = await resp.text()
                        await self.pc.setRemoteDescription(
                            RTCSessionDescription(sdp=answer_sdp, type="answer")
                        )
                    else:
                        self.done.set()
            await self.done.wait()
        finally:
            # Make every loop condition false, then stop whatever is still running.
            self.done.set()
            for task in list(self._tasks):
                if not task.done():
                    task.cancel()
            if self.pc.connectionState != "closed":
                await self.pc.close()

async def main():
    viewer = WebRTCStreamViewer(WHEP_URL, image_processor, depth_model, device, dtype)
    await viewer.run()

try:
    await main()
except (KeyboardInterrupt, asyncio.CancelledError):
    pass


More optimized version:

In [None]:
%pip install transformers torch torchvision accelerate aiortc aiohttp Pillow --quiet

import asyncio, cv2, numpy as np, aiohttp, logging, base64, torch, time
from aiortc import RTCPeerConnection, RTCSessionDescription
from aiortc.mediastreams import MediaStreamError
from IPython.display import display, HTML
from transformers import AutoImageProcessor, AutoModelForDepthEstimation

In [None]:

# WHEP endpoint that serves the video stream (placeholder - fill in your own host).
WHEP_URL = "https://(     your Url       )/video/whep"
# Target number of display refreshes per second.
DISPLAY_FPS = 30

# Raise the log threshold of the libav logger and the aiortc H.264 codec logger.
logging.getLogger("libav").setLevel(logging.CRITICAL)
logging.getLogger("aiortc.codecs.h264").setLevel(logging.ERROR)


try:
    checkpoint = "LiheYoung/depth-anything-small-hf"

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # float16 is used only together with CUDA; the CPU path stays float32.
    dtype = torch.float16 if device == "cuda" else torch.float32
    print(f"Loading model on device: '{device}' with dtype: {dtype}")

    image_processor = AutoImageProcessor.from_pretrained(checkpoint)
    depth_model = AutoModelForDepthEstimation.from_pretrained(
        checkpoint,
        torch_dtype=dtype,
        device_map="auto"
    )

    # Optional experiment, left disabled:
    # if hasattr(torch, 'compile'):
    #     print("Compiling model with torch.compile()... This may take a moment.")
    #     # This will JIT-compile the model on the first run for your specific hardware.
    #     depth_model = torch.compile(depth_model, mode="max-autotune")

    # Name the checkpoint that was actually loaded; the previous message said
    # "Depth Pro" regardless of which model was selected above.
    print(f"Model '{checkpoint}' and processor loaded successfully.")
except Exception as e:
    # Report the failure in the cell output, then let the original error propagate.
    print(f"Failed to load model: {e}")
    raise



class WebRTCStreamViewer:
    """Show a WHEP video stream side by side with a colour-mapped depth estimate.

    A receiver task stores the newest decoded frame; a display task picks up
    each new frame at most ``DISPLAY_FPS`` times per second, runs the depth
    model on it in a worker thread and updates an inline ``<img>`` element.

    Args:
        whep_url: URL the SDP offer is POSTed to.
        processor: Image processor called as ``processor(images=..., return_tensors="pt")``.
        model: Depth model whose output exposes ``predicted_depth``.
        device: Device the input tensor is moved to before inference.
        dtype: Dtype the input tensor is cast to before inference.
    """

    def __init__(self, whep_url, processor, model, device, dtype):
        self.whep_url = whep_url
        self.processor = processor
        self.model = model
        self.device = device
        self.dtype = dtype
        self.pc = RTCPeerConnection()
        self.done = asyncio.Event()
        self.lock = asyncio.Lock()
        self.latest_frame = None
        # ``frame.time`` of the newest frame; kept for inspection only.
        self.last_frame_timestamp = 0
        # Count of frames stored so far. Freshness is decided from this
        # counter rather than from ``frame.time``: the old ``time > last``
        # check skipped a frame whose time is 0, raised if the time was None,
        # and stalled for good if the time ever moved backwards.
        self._frame_seq = 0
        # The event loop keeps only weak references to tasks, so hold strong
        # ones here to stop a running task from being garbage-collected.
        self._tasks = set()

        @self.pc.on("track")
        async def on_track(track):
            print(f"Track {track.kind} received")
            if track.kind == "video":
                self._spawn(self._frame_receiver_task(track))

        @self.pc.on("connectionstatechange")
        async def on_connectionstatechange():
            print(f"Connection state is {self.pc.connectionState}")
            if self.pc.connectionState in ["failed", "closed", "disconnected"]:
                self.done.set()

    def _spawn(self, coro):
        """Schedule *coro* as a task and keep a reference until it finishes."""
        task = asyncio.create_task(coro)
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)
        return task

    @staticmethod
    def _normalize_depth(prediction):
        """Min-max scale a ``(1, 1, H, W)`` depth tensor to an ``(H, W)`` uint8 array.

        A constant tensor (max equals min) yields an all-zero image.
        """
        p_min = torch.min(prediction)
        p_max = torch.max(prediction)
        if p_max > p_min:
            normalized_prediction = (prediction - p_min) / (p_max - p_min)
        else:
            normalized_prediction = torch.zeros_like(prediction)
        # Index instead of squeeze() so a 1-pixel-wide/high frame keeps two dims.
        return (normalized_prediction[0, 0] * 255.0).cpu().to(torch.uint8).numpy()

    def _render_data_url(self, frame_bgr):
        """Return a JPEG data URL of *frame_bgr* next to its depth colour map.

        Runs synchronously; intended to be called from a worker thread.
        Returns None if JPEG encoding fails.
        """
        original_h, original_w = frame_bgr.shape[:2]

        rgb_frame = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
        inputs = self.processor(images=rgb_frame, return_tensors="pt")
        pixel_values = inputs.pixel_values.to(self.device, self.dtype)

        # no_grad is entered here, inside the function that runs in the worker
        # thread, because grad mode is a per-thread setting.
        with torch.no_grad():
            predicted_depth = self.model(pixel_values).predicted_depth

            # 'bilinear' for faster interpolation than 'bicubic'.
            prediction = torch.nn.functional.interpolate(
                predicted_depth.unsqueeze(1),
                size=(original_h, original_w),
                mode="bilinear",
                align_corners=False,
            )
            output_normalized = self._normalize_depth(prediction)

        depth_colormap = cv2.applyColorMap(output_normalized, cv2.COLORMAP_JET)
        combined_frame = np.hstack((frame_bgr, depth_colormap))

        ok, buffer = cv2.imencode('.jpg', combined_frame, [int(cv2.IMWRITE_JPEG_QUALITY), 80])
        if not ok:
            return None
        b64_str = base64.b64encode(buffer).decode('utf-8')
        return f"data:image/jpeg;base64,{b64_str}"

    async def _frame_receiver_task(self, track):
        """Receives frames and puts the latest one into a shared variable.

        When the stream ends or receiving fails no further frames can arrive,
        so ``done`` is set to let ``run()`` finish instead of idling.
        """
        try:
            while not self.done.is_set():
                frame = await track.recv()
                img = frame.to_ndarray(format="bgr24")
                async with self.lock:
                    self.latest_frame = img
                    self.last_frame_timestamp = frame.time
                    self._frame_seq += 1
        except MediaStreamError:
            print("Stream ended")
            self.done.set()
        except Exception as e:
            print(f"Error in receiver task: {e}")
            self.done.set()

    async def _display_loop_task(self):
        """Processes and displays frames at a target FPS."""
        display_handle = display(HTML('<img>'), display_id=True)
        target_delay = 1 / DISPLAY_FPS
        last_processed_seq = 0
        loop = asyncio.get_running_loop()

        try:
            while not self.done.is_set():
                start_time = time.perf_counter()

                frame_to_process = None
                async with self.lock:
                    if self.latest_frame is not None and self._frame_seq != last_processed_seq:
                        # No copy needed: the receiver rebinds ``latest_frame``
                        # to a new array and never writes into the old one.
                        frame_to_process = self.latest_frame
                        last_processed_seq = self._frame_seq

                if frame_to_process is not None:
                    # Inference and JPEG encoding are blocking calls; run them
                    # in the default executor so the event loop stays free for
                    # the receiver task in the meantime.
                    data_url = await loop.run_in_executor(
                        None, self._render_data_url, frame_to_process
                    )
                    if data_url is not None:
                        display_handle.update(HTML(f'<img src="{data_url}" style="width: 80%;" />'))

                # Adaptive sleep for stable FPS: only wait out what is left of
                # the frame budget after processing.
                processing_time = time.perf_counter() - start_time
                sleep_duration = max(0, target_delay - processing_time)
                await asyncio.sleep(sleep_duration)
        except asyncio.CancelledError:
            raise
        except Exception as e:
            # Without this, a rendering error would end the task unnoticed and
            # run() would keep waiting on ``done`` indefinitely.
            print(f"Error in display task: {e}")
            self.done.set()

    async def run(self):
        """Main method to start the WebRTC connection and processing loops.

        Sends the local SDP offer to ``whep_url``; a 201 response body is
        applied as the remote answer, any other status ends the session.
        Errors are printed rather than raised. Background tasks are cancelled
        and the peer connection is closed on the way out.
        """
        self.pc.addTransceiver("video", direction="recvonly")
        offer = await self.pc.createOffer()
        await self.pc.setLocalDescription(offer)

        try:
            self._spawn(self._display_loop_task())
            async with aiohttp.ClientSession() as session:
                print("Attempting to connect to WHEP endpoint...")
                async with session.post(
                    self.whep_url,
                    data=self.pc.localDescription.sdp,
                    headers={"Content-Type": "application/sdp"},
                    timeout=aiohttp.ClientTimeout(total=15),
                ) as resp:
                    if resp.status == 201:
                        print("WHEP connection successful. Receiving answer...")
                        answer_sdp = await resp.text()
                        await self.pc.setRemoteDescription(
                            RTCSessionDescription(sdp=answer_sdp, type="answer")
                        )
                    else:
                        print(f"WHEP connection failed with status {resp.status}: {await resp.text()}")
                        self.done.set()

            await self.done.wait()
        except Exception as e:
            print(f"An error occurred during run: {e}")
        finally:
            print("Cleaning up...")
            # Make every loop condition false, then stop whatever is still running.
            self.done.set()
            for task in list(self._tasks):
                if not task.done():
                    task.cancel()
            if self.pc.connectionState != "closed":
                await self.pc.close()
            print("Cleanup complete.")

async def main():
    viewer = WebRTCStreamViewer(WHEP_URL, image_processor, depth_model, device, dtype)
    await viewer.run()


try:
    await main()
except (KeyboardInterrupt, asyncio.CancelledError):
    print("Stream stopped by user.")
finally:
    print("Main task finished.")