# Code to render the stream

In [None]:
!pip install aiortc aiohttp --quiet

In [None]:
import aiortc, aiohttp
import asyncio, cv2, numpy as np, aiohttp, logging, base64
from aiortc import RTCPeerConnection, RTCSessionDescription
from aiortc.mediastreams import MediaStreamError
from IPython.display import display, HTML


# WHEP endpoint of the stream and how often the notebook output is refreshed.
WHEP_URL = "your stream url/whep"
DISPLAY_FPS = 15

# Hide everything below ERROR from aiortc's H.264 codec logger.
# For more diagnostics while debugging, configure logging with e.g.
# logging.basicConfig(level=logging.WARNING).
logging.getLogger(name="aiortc.codecs.h264").setLevel(level=logging.ERROR)

class WebRTCStreamViewer:
    """Receive a video stream over WHEP (WebRTC) and render it inline in a notebook.

    The viewer negotiates a receive-only peer connection with a WHEP endpoint,
    keeps only the most recently decoded frame, and periodically pushes it to an
    IPython display handle as a JPEG data URL.

    Args:
        whep_url: URL of the WHEP endpoint that accepts the SDP offer.
    """

    # Used to report background-task failures instead of losing them as
    # "Task exception was never retrieved".
    _log = logging.getLogger(__name__)

    def __init__(self, whep_url):
        self.whep_url = whep_url
        self.pc = RTCPeerConnection()
        self.done = asyncio.Event()   # set when the session should end
        self.lock = asyncio.Lock()    # guards latest_frame
        self.latest_frame = None      # newest decoded frame (BGR ndarray) or None
        # Strong references to background tasks: the event loop only keeps weak
        # references, so an unreferenced task may be garbage-collected mid-run.
        self._tasks = set()

        @self.pc.on("track")
        async def on_track(track):
            if track.kind == "video":
                self._spawn(self._frame_receiver_task(track))

        @self.pc.on("connectionstatechange")
        async def on_connectionstatechange():
            print(f"Connection state is {self.pc.connectionState}")
            if self.pc.connectionState in ["failed", "closed", "disconnected"]:
                self.done.set()

    def _spawn(self, coro):
        """Start *coro* as a tracked background task and return the task."""
        task = asyncio.create_task(coro)
        self._tasks.add(task)
        task.add_done_callback(self._on_task_done)
        return task

    def _on_task_done(self, task):
        """Forget a finished task; if it crashed, report the error and end the session."""
        self._tasks.discard(task)
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            # Without this a crash (e.g. in the display loop) would leave run()
            # waiting on `done` forever with nothing on screen.
            self._log.error("Background task failed: %r", exc, exc_info=exc)
            self.done.set()

    async def _frame_receiver_task(self, track):
        """Decode frames from *track* into `latest_frame` until the track ends or the session stops."""
        while not self.done.is_set():
            try:
                frame = await track.recv()
            except MediaStreamError:
                # Raised by aiortc once the remote track has ended.
                return
            img = frame.to_ndarray(format="bgr24")
            async with self.lock:
                self.latest_frame = img

    async def _display_loop_task(self):
        """Push the newest frame to the notebook output about DISPLAY_FPS times per second."""
        display_handle = display(HTML('<img>'), display_id=True)

        while not self.done.is_set():
            frame_to_display = None
            async with self.lock:
                if self.latest_frame is not None:
                    frame_to_display = self.latest_frame.copy()

            if frame_to_display is not None:
                # --- DEPTH ANALYSIS CODE GOES HERE ---
                processed_frame = frame_to_display
                # ---

                ok, buffer = cv2.imencode('.jpg', processed_frame)
                if ok:
                    b64_str = base64.b64encode(buffer).decode('utf-8')
                    data_url = f"data:image/jpeg;base64,{b64_str}"
                    display_handle.update(HTML(f'<img src="{data_url}" style="width: 80%;" />'))

            await asyncio.sleep(1 / DISPLAY_FPS)

    async def run(self):
        """Negotiate the WHEP session, stream until it ends, then release all resources."""
        self.pc.addTransceiver("video", direction="recvonly")
        offer = await self.pc.createOffer()
        await self.pc.setLocalDescription(offer)

        try:
            self._spawn(self._display_loop_task())
            # aiohttp documents ClientTimeout as the timeout type; a bare number is
            # only accepted for backward compatibility.
            timeout = aiohttp.ClientTimeout(total=15)
            async with aiohttp.ClientSession(timeout=timeout) as session:
                print(f"Connecting to {self.whep_url}...")
                async with session.post(self.whep_url, data=self.pc.localDescription.sdp,
                                        headers={"Content-Type": "application/sdp"}) as resp:
                    if resp.status == 201:
                        answer_sdp = await resp.text()
                        await self.pc.setRemoteDescription(RTCSessionDescription(sdp=answer_sdp, type="answer"))
                    else:
                        # Only 201 is treated as success; tell the user why nothing shows up.
                        print(f"Server error: {resp.status} {await resp.text()}")
                        self.done.set()

            await self.done.wait()

        finally:
            # Stop the loops, cancel whatever is still running, close the connection.
            self.done.set()
            for task in list(self._tasks):
                task.cancel()
            if self.pc.connectionState != "closed":
                await self.pc.close()

async def main():
    viewer = WebRTCStreamViewer(WHEP_URL)
    await viewer.run()

try:
    await main()
    # asyncio.run(main())
except KeyboardInterrupt:
    print("\nStream stopped.")
except asyncio.CancelledError:
    pass

# Applying depth estimation to the stream

In [None]:
%pip install transformers torch torchvision accelerate aiortc aiohttp accelerate Pillow --quiet

import asyncio, cv2, numpy as np, aiohttp, logging, base64, torch
from aiortc import RTCPeerConnection, RTCSessionDescription
from aiortc.mediastreams import MediaStreamError
from IPython.display import display, HTML
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForDepthEstimation


In [None]:

# Load the Depth Pro checkpoint and its image processor once; the viewer reuses them.
try:
    checkpoint = "apple/DepthPro-hf"  # slower on GPU, but gives great outputs
    if torch.cuda.is_available():
        device, dtype = "cuda", torch.float16
    else:
        device, dtype = "cpu", torch.float32
    print(f"Loading model on device: '{device}' with dtype: {dtype}")

    image_processor = AutoImageProcessor.from_pretrained(checkpoint)
    depth_model = AutoModelForDepthEstimation.from_pretrained(
        checkpoint, torch_dtype=dtype, device_map="auto"
    )
    print("Depth Pro model and processor loaded successfully.")
except Exception as err:
    print(f"Failed to load model: {err}")
    raise

In [None]:
# WHEP endpoint of the stream to analyse.
WHEP_URL = "Your url route/whep"
# Refresh rate of the notebook output; kept low because Colab starts to lag above it.
DISPLAY_FPS = 15

# Hide everything below CRITICAL from the "libav" logger.
logging.getLogger(name="libav").setLevel(level=logging.CRITICAL)




class WebRTCStreamViewer:
    """Receive a WHEP (WebRTC) video stream, run a depth model on it and show both side by side.

    Decoded frames are kept in ``latest_frame``. A display loop picks up each new
    frame, runs ``model`` on it in a worker thread, and renders the original frame
    next to a JET-colormapped depth map in the notebook output.

    Args:
        whep_url: URL of the WHEP endpoint that accepts the SDP offer.
        processor: Hugging Face image processor paired with ``model``.
        model: depth-estimation model whose output exposes ``predicted_depth``.
        device: device the inputs are moved to (e.g. ``"cuda"`` or ``"cpu"``).
        dtype: floating-point dtype ``pixel_values`` are cast to before inference.
    """

    # Used to report background-task failures instead of losing them silently.
    _log = logging.getLogger(__name__)

    def __init__(self, whep_url, processor, model, device, dtype):
        self.whep_url = whep_url
        self.processor = processor
        self.model = model
        self.device = device
        self.dtype = dtype
        self.pc = RTCPeerConnection()
        self.done = asyncio.Event()   # set when the session should end
        self.lock = asyncio.Lock()    # guards latest_frame
        self.latest_frame = None      # newest decoded frame (BGR ndarray) or None
        self.video_ssrc = None        # SSRC of the incoming video, once RTP has arrived
        self.video_track = None
        # Strong references to background tasks (the loop only keeps weak ones).
        self._tasks = set()

        @self.pc.on("track")
        async def on_track(track):
            if track.kind == "video":
                self.video_track = track
                # No RTP has arrived when the track is announced, so the SSRC is not
                # known yet; the requester polls for it instead of reading it here.
                self._spawn(self._keyframe_requester_task())
                self._spawn(self._frame_receiver_task(track))

        @self.pc.on("connectionstatechange")
        async def on_connectionstatechange():
            if self.pc.connectionState in ["failed", "closed", "disconnected"]:
                self.done.set()

    def _spawn(self, coro):
        """Start *coro* as a tracked background task and return the task."""
        task = asyncio.create_task(coro)
        self._tasks.add(task)
        task.add_done_callback(self._on_task_done)
        return task

    def _on_task_done(self, task):
        """Forget a finished task; if it crashed, report the error and end the session."""
        self._tasks.discard(task)
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            # Otherwise e.g. a model error in the display loop would leave run()
            # waiting on `done` forever with nothing on screen.
            self._log.error("Background task failed: %r", exc, exc_info=exc)
            self.done.set()

    async def _keyframe_requester_task(self):
        """Send an RTCP Picture Loss Indication once per second until the first frame decodes.

        The PLI asks the sender for a keyframe so decoding can start promptly.
        ``RTCPeerConnection`` has no public feedback API, so this uses aiortc's
        private ``RTCRtpReceiver._send_rtcp_pli``; if that helper is missing the
        request is simply skipped.
        """
        while not self.done.is_set():
            async with self.lock:
                if self.latest_frame is not None:
                    return
            receiver = next(
                (r for r in self.pc.getReceivers() if r.track is self.video_track), None
            )
            if receiver is not None and self.video_ssrc is None:
                sources = receiver.getSynchronizationSources()
                if sources:
                    # RTCRtpSynchronizationSource stores the SSRC in `.source`.
                    self.video_ssrc = sources[0].source
            send_pli = getattr(receiver, "_send_rtcp_pli", None)
            if send_pli is not None and self.video_ssrc is not None:
                try:
                    await send_pli(self.video_ssrc)
                except Exception as exc:
                    # Best effort: a failed keyframe request must not stop the stream.
                    self._log.debug("Keyframe request failed: %r", exc)
            await asyncio.sleep(1)

    async def _frame_receiver_task(self, track):
        """Decode frames from *track* into ``latest_frame`` until the track ends or the session stops."""
        while not self.done.is_set():
            try:
                frame = await track.recv()
            except MediaStreamError:
                # aiortc raises this once the track has ended and every later recv()
                # raises again immediately, so retrying would only spin.
                return
            img = frame.to_ndarray(format="bgr24")
            async with self.lock:
                self.latest_frame = img

    @staticmethod
    def _depth_to_uint8(depth):
        """Scale a depth map to uint8 as ``depth * 255 / max(depth)``, clipped to [0, 255].

        Works in float32 because float16 model output overflows when depths above
        ~257 are multiplied by 255. Negative values (e.g. bicubic overshoot) clip to
        0 instead of wrapping. Returns an all-zero image when the map is empty or
        has no positive, finite maximum, instead of dividing by zero.
        """
        depth = np.asarray(depth, dtype=np.float32)
        if depth.size == 0:
            return np.zeros(depth.shape, dtype=np.uint8)
        peak = float(depth.max())
        if not np.isfinite(peak) or peak <= 0.0:
            return np.zeros(depth.shape, dtype=np.uint8)
        return np.clip(depth * 255.0 / peak, 0, 255).astype(np.uint8)

    def _render_depth_frame(self, frame_bgr):
        """Run the depth model on one BGR frame and return the side-by-side JPEG as a data URL.

        Called from a worker thread. Returns None if JPEG encoding fails.
        """
        original_h, original_w = frame_bgr.shape[:2]
        rgb_image = Image.fromarray(cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB))
        inputs = self.processor(images=rgb_image, return_tensors="pt")
        pixel_values = inputs.pixel_values.to(self.device, dtype=self.dtype)

        with torch.no_grad():
            predicted_depth = self.model(pixel_values).predicted_depth

        # Upsample the depth map back to the frame size; [0, 0] drops the batch and
        # channel axes without also dropping a spatial axis of size 1.
        prediction = torch.nn.functional.interpolate(
            predicted_depth.unsqueeze(1),
            size=(original_h, original_w),
            mode="bicubic",
            align_corners=False,
        )[0, 0]

        depth = prediction.float().cpu().numpy()
        depth_colormap = cv2.applyColorMap(self._depth_to_uint8(depth), cv2.COLORMAP_JET)
        combined_frame = np.hstack((frame_bgr, depth_colormap))

        ok, buffer = cv2.imencode('.jpg', combined_frame)
        if not ok:
            return None
        return "data:image/jpeg;base64," + base64.b64encode(buffer).decode('utf-8')

    async def _display_loop_task(self):
        """Render each new frame plus its depth map, polling about DISPLAY_FPS times per second."""
        display_handle = display(HTML('<img>'), display_id=True)
        loop = asyncio.get_running_loop()
        last_rendered = None  # the latest_frame object most recently sent through the model

        while not self.done.is_set():
            frame_to_display = None
            async with self.lock:
                # Skip inference when no new frame has arrived since the last render.
                if self.latest_frame is not None and self.latest_frame is not last_rendered:
                    last_rendered = self.latest_frame
                    frame_to_display = self.latest_frame.copy()

            if frame_to_display is not None:
                # Inference is heavy CPU/GPU work; running it in the default executor
                # keeps the event loop free for aiortc's network handling meanwhile.
                data_url = await loop.run_in_executor(
                    None, self._render_depth_frame, frame_to_display
                )
                if data_url is not None:
                    display_handle.update(HTML(f'<img src="{data_url}" style="width: 80%;" />'))

            await asyncio.sleep(1 / DISPLAY_FPS)

    async def run(self):
        """Negotiate the WHEP session, stream until it ends, then release all resources."""
        self.pc.addTransceiver("video", direction="recvonly")
        offer = await self.pc.createOffer()
        await self.pc.setLocalDescription(offer)

        try:
            self._spawn(self._display_loop_task())
            # aiohttp documents ClientTimeout as the timeout type; a bare number is
            # only accepted for backward compatibility.
            timeout = aiohttp.ClientTimeout(total=15)
            async with aiohttp.ClientSession(timeout=timeout) as session:
                async with session.post(self.whep_url, data=self.pc.localDescription.sdp,
                                        headers={"Content-Type": "application/sdp"}) as resp:
                    if resp.status == 201:
                        answer_sdp = await resp.text()
                        await self.pc.setRemoteDescription(RTCSessionDescription(sdp=answer_sdp, type="answer"))
                    else:
                        print(f"WHEP request failed: {resp.status} {await resp.text()}")
                        self.done.set()
            await self.done.wait()
        finally:
            # Stop the loops, cancel whatever is still running, close the connection.
            self.done.set()
            for task in list(self._tasks):
                task.cancel()
            if self.pc.connectionState != "closed":
                await self.pc.close()

async def main():
    viewer = WebRTCStreamViewer(WHEP_URL, image_processor, depth_model, device, dtype)
    await viewer.run()

try:
    await main()
except (KeyboardInterrupt, asyncio.CancelledError):
    pass

# Using GLPN models

In [None]:
%pip install transformers torch torchvision accelerate aiortc aiohttp Pillow --quiet

import asyncio, cv2, numpy as np, aiohttp, logging, base64, torch
from aiortc import RTCPeerConnection, RTCSessionDescription
from aiortc.mediastreams import MediaStreamError
from IPython.display import display, HTML
from PIL import Image
from transformers import GLPNImageProcessor, GLPNForDepthEstimation

In [None]:
WHEP_URL = "Your stream url/whep"
DISPLAY_FPS = 30      # GLPN is fast enough to refresh at a higher rate

# Checkpoint options:
#   vinvino02/glpn-kitti   - runs really fast on GPU, lower output quality (used here)
#   vinvino02/glpn-nyu     - also really fast, better output than glpn-kitti
#   Intel/intel-dpt-large  - great output but runs slower
# NOTE(review): a DPT checkpoint is not a GLPN model, so it presumably needs
# AutoImageProcessor / AutoModelForDepthEstimation instead of the GLPN classes below.
CHECKPOINT = "vinvino02/glpn-kitti"

# Hide everything below CRITICAL from the "libav" logger.
logging.getLogger(name="libav").setLevel(level=logging.CRITICAL)

# Half precision on GPU, full precision on CPU.
if torch.cuda.is_available():
    device, dtype = "cuda", torch.float16
else:
    device, dtype = "cpu", torch.float32

try:
    print(f"Loading model '{CHECKPOINT}' on device: '{device}' with dtype: {dtype}")
    image_processor = GLPNImageProcessor.from_pretrained(CHECKPOINT)
    depth_model = GLPNForDepthEstimation.from_pretrained(CHECKPOINT).to(device=device, dtype=dtype)
    print("Model and processor loaded successfully.")
except Exception as err:
    print(f"Failed to load model: {err}")
    raise

In [None]:
class WebRTCStreamViewer:
    """Receive a WHEP (WebRTC) video stream and show it next to a depth map from ``model``.

    Decoded frames are kept in ``latest_frame``. A display loop picks up each new
    frame, runs the depth model (GLPN in this notebook) in a worker thread, and
    renders the original frame next to a JET-colormapped, min-max-normalised depth
    map in the notebook output.

    Args:
        whep_url: URL of the WHEP endpoint that accepts the SDP offer.
        processor: Hugging Face image processor paired with ``model``.
        model: depth-estimation model whose output exposes ``predicted_depth``.
        device: device the inputs are moved to (e.g. ``"cuda"`` or ``"cpu"``).
        dtype: floating-point dtype ``pixel_values`` are cast to before inference.
    """

    # Used to report background-task failures instead of losing them silently.
    _log = logging.getLogger(__name__)

    def __init__(self, whep_url, processor, model, device, dtype):
        self.whep_url = whep_url
        self.processor = processor
        self.model = model
        self.device = device
        self.dtype = dtype
        self.pc = RTCPeerConnection()
        self.done = asyncio.Event()   # set when the session should end
        self.lock = asyncio.Lock()    # guards latest_frame
        self.latest_frame = None      # newest decoded frame (BGR ndarray) or None
        self.video_track = None
        # Strong references to background tasks (the loop only keeps weak ones).
        self._tasks = set()

        @self.pc.on("track")
        async def on_track(track):
            if track.kind == "video":
                self.video_track = track
                # Started here rather than in __init__, so constructing the viewer
                # does not require a running event loop.
                self._spawn(self._keyframe_watchdog_task())
                self._spawn(self._frame_receiver_task(track))

        @self.pc.on("connectionstatechange")
        async def on_connectionstatechange():
            if self.pc.connectionState in ["failed", "closed", "disconnected"]:
                self.done.set()

    def _spawn(self, coro):
        """Start *coro* as a tracked background task and return the task."""
        task = asyncio.create_task(coro)
        self._tasks.add(task)
        task.add_done_callback(self._on_task_done)
        return task

    def _on_task_done(self, task):
        """Forget a finished task; if it crashed, report the error and end the session."""
        self._tasks.discard(task)
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            # Otherwise e.g. a model error in the display loop would leave run()
            # waiting on `done` forever with nothing on screen.
            self._log.error("Background task failed: %r", exc, exc_info=exc)
            self.done.set()

    async def _keyframe_watchdog_task(self):
        """Send an RTCP Picture Loss Indication once per second until the first frame decodes.

        The SSRC only becomes known once RTP arrives, so it is polled for first.
        ``RTCPeerConnection`` has no public feedback API, so this uses aiortc's
        private ``RTCRtpReceiver._send_rtcp_pli``; if that helper is missing the
        request is simply skipped.
        """
        video_ssrc = None
        while not self.done.is_set():
            async with self.lock:
                if self.latest_frame is not None:
                    return
            receiver = next(
                (r for r in self.pc.getReceivers() if r.track is self.video_track), None
            )
            if receiver is not None and video_ssrc is None:
                sources = receiver.getSynchronizationSources()
                if sources:
                    # RTCRtpSynchronizationSource stores the SSRC in `.source`.
                    video_ssrc = sources[0].source
            send_pli = getattr(receiver, "_send_rtcp_pli", None)
            if send_pli is not None and video_ssrc is not None:
                try:
                    await send_pli(video_ssrc)
                except Exception as exc:
                    # Best effort: a failed keyframe request must not stop the stream.
                    self._log.debug("Keyframe request failed: %r", exc)
            await asyncio.sleep(1 if video_ssrc is not None else 0.5)

    async def _frame_receiver_task(self, track):
        """Decode frames from *track* into ``latest_frame`` until the track ends or the session stops."""
        while not self.done.is_set():
            try:
                frame = await track.recv()
            except MediaStreamError:
                # aiortc raises this once the track has ended and every later recv()
                # raises again immediately, so retrying would only spin.
                return
            img = frame.to_ndarray(format="bgr24")
            async with self.lock:
                self.latest_frame = img

    @staticmethod
    def _depth_to_uint8(depth):
        """Min-max normalise a depth map to uint8 in [0, 255].

        Works in float32 to avoid float16 overflow. Returns an all-zero image when
        the map is empty, constant or non-finite, instead of dividing by zero.
        """
        depth = np.asarray(depth, dtype=np.float32)
        if depth.size == 0:
            return np.zeros(depth.shape, dtype=np.uint8)
        lo = float(depth.min())
        span = float(depth.max()) - lo
        if not np.isfinite(span) or span <= 0.0:
            return np.zeros(depth.shape, dtype=np.uint8)
        return np.clip((depth - lo) * 255.0 / span, 0, 255).astype(np.uint8)

    def _render_depth_frame(self, frame_bgr):
        """Run the depth model on one BGR frame and return the side-by-side JPEG as a data URL.

        Called from a worker thread. Returns None if JPEG encoding fails.
        """
        original_h, original_w = frame_bgr.shape[:2]
        rgb_image = Image.fromarray(cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB))
        inputs = self.processor(images=rgb_image, return_tensors="pt")
        # Cast as well as move: float32 pixel_values fed to a float16 model (the CUDA
        # setup of this notebook) fail with a dtype mismatch in the first layer.
        pixel_values = inputs.pixel_values.to(self.device, dtype=self.dtype)

        with torch.no_grad():
            predicted_depth = self.model(pixel_values=pixel_values).predicted_depth

        # Upsample the depth map back to the frame size; [0, 0] drops the batch and
        # channel axes without also dropping a spatial axis of size 1.
        prediction = torch.nn.functional.interpolate(
            predicted_depth.unsqueeze(1),
            size=(original_h, original_w),
            mode="bicubic",
            align_corners=False,
        )[0, 0]

        depth = prediction.float().cpu().numpy()
        depth_colormap = cv2.applyColorMap(self._depth_to_uint8(depth), cv2.COLORMAP_JET)
        combined_frame = np.hstack((frame_bgr, depth_colormap))

        ok, buffer = cv2.imencode('.jpg', combined_frame)
        if not ok:
            return None
        return "data:image/jpeg;base64," + base64.b64encode(buffer).decode('utf-8')

    async def _display_loop_task(self):
        """Render each new frame plus its depth map, polling about DISPLAY_FPS times per second."""
        display_handle = display(HTML('<img>'), display_id=True)
        loop = asyncio.get_running_loop()
        last_rendered = None  # the latest_frame object most recently sent through the model

        while not self.done.is_set():
            frame_to_display = None
            async with self.lock:
                # Skip inference when no new frame has arrived since the last render.
                if self.latest_frame is not None and self.latest_frame is not last_rendered:
                    last_rendered = self.latest_frame
                    frame_to_display = self.latest_frame.copy()

            if frame_to_display is not None:
                # Inference is heavy CPU/GPU work; running it in the default executor
                # keeps the event loop free for aiortc's network handling meanwhile.
                data_url = await loop.run_in_executor(
                    None, self._render_depth_frame, frame_to_display
                )
                if data_url is not None:
                    display_handle.update(HTML(f'<img src="{data_url}" style="width: 80%;" />'))

            await asyncio.sleep(1 / DISPLAY_FPS)

    async def run(self):
        """Negotiate the WHEP session, stream until it ends, then release all resources."""
        self.pc.addTransceiver("video", direction="recvonly")
        offer = await self.pc.createOffer()
        await self.pc.setLocalDescription(offer)
        try:
            self._spawn(self._display_loop_task())
            # aiohttp documents ClientTimeout as the timeout type; a bare number is
            # only accepted for backward compatibility.
            timeout = aiohttp.ClientTimeout(total=15)
            async with aiohttp.ClientSession(timeout=timeout) as session:
                async with session.post(self.whep_url, data=self.pc.localDescription.sdp,
                                        headers={"Content-Type": "application/sdp"}) as resp:
                    if resp.status == 201:
                        answer_sdp = await resp.text()
                        await self.pc.setRemoteDescription(RTCSessionDescription(sdp=answer_sdp, type="answer"))
                    else:
                        print(f"WHEP request failed: {resp.status} {await resp.text()}")
                        self.done.set()
            await self.done.wait()
        finally:
            # Stop the loops, cancel whatever is still running, close the connection.
            self.done.set()
            for task in list(self._tasks):
                task.cancel()
            if self.pc.connectionState != "closed":
                await self.pc.close()

async def main():
    viewer = WebRTCStreamViewer(WHEP_URL, image_processor, depth_model, device, dtype)
    await viewer.run()

try:
    await main()
except (KeyboardInterrupt, asyncio.CancelledError):
    pass