In [None]:
# credits: https://github.com/theAIGuysCode/colab-webcam/blob/main/yolov4_webcam.ipynb

In [None]:
import io
import json
from base64 import b64decode, b64encode

import cv2
import numpy as np
import PIL
from PIL import Image
from IPython.display import display, Javascript
from google.colab.output import eval_js

In [None]:
# JavaScript to properly create our live video stream using our webcam as input
def video_stream():
  """Display the JavaScript that drives the live webcam stream in this cell's output.

  The script defines a global ``stream_frame(label, imgData)`` which
  ``video_frame`` invokes through ``eval_js``. Each call builds the DOM on
  first use: a status label, a <video>, a clickable <img> overlay stretched to
  twice the video's width, and a "click here to stop" line. It then updates
  the label/overlay when non-empty values are given, waits for the next
  animation frame and resolves with a dict:

    create, show, capture : duration of each phase in milliseconds
    x, y                  : last click on the overlay, in pixels of the
                            overlay image, or -1, -1 when there was no click
    shutdown              : true if the stop line was clicked while waiting
                            (the camera is then released and the DOM removed)
    img                   : JPEG data URL of the top-left 640x480 region of
                            the video ('' on shutdown)

  If the stop line was clicked while Python was busy, the next call tears the
  DOM down and returns '' instead of a dict.
  """
  js = Javascript('''
    var video;
    var div = null;
    var stream;
    var captureCanvas;
    var imgElement;
    var labelElement;

    var pendingResolve = null;
    var shutdown = false;
    // Last click on the overlay, already converted to overlay-image pixels;
    // -1 means "no click since the previous frame".
    var clickedPositionX = -1;
    var clickedPositionY = -1;

    // Stop the camera and remove every element created by createDom().
    function removeDom() {
       stream.getTracks().forEach(function(track) { track.stop(); });
       video.remove();
       div.remove();
       video = null;
       div = null;
       stream = null;
       imgElement = null;
       captureCanvas = null;
       labelElement = null;
    }

    // Re-schedules itself every animation frame until shutdown. When
    // stream_frame is waiting, captures the current frame and hands it over
    // ("" once shutdown has been requested).
    function onAnimationFrame() {
      if (!shutdown) {
        window.requestAnimationFrame(onAnimationFrame);
      }
      if (pendingResolve) {
        var result = "";
        if (!shutdown) {
          captureCanvas.getContext('2d').drawImage(video, 0, 0, 640, 480, 0, 0, 640, 480);
          result = captureCanvas.toDataURL('image/jpeg', 0.8);
        }
        var lp = pendingResolve;
        pendingResolve = null;
        lp(result);
      }
    }

    // Build the UI and start the camera once; later calls reuse it.
    async function createDom() {
      if (div !== null) {
        return stream;
      }

      div = document.createElement('div');
      div.style.border = '2px solid black';
      div.style.padding = '3px';
      div.style.width = '100%';
      div.style.maxWidth = '600px';
      document.body.appendChild(div);

      const modelOut = document.createElement('div');
      modelOut.innerHTML = "Status:";
      labelElement = document.createElement('span');
      labelElement.innerText = 'No data';
      labelElement.style.fontWeight = 'bold';
      modelOut.appendChild(labelElement);
      div.appendChild(modelOut);

      video = document.createElement('video');
      video.style.display = 'block';
      video.width = div.clientWidth - 6;
      video.setAttribute('playsinline', '');
      stream = await navigator.mediaDevices.getUserMedia(
          {video: { facingMode: "environment"}});
      div.appendChild(video);

      imgElement = document.createElement('img');
      imgElement.style.position = 'absolute';
      imgElement.style.zIndex = 1;
      imgElement.onclick = (event) => {
        // The overlay is displayed stretched over the video, so convert the
        // click from displayed pixels (offsetX/Y) to pixels of the image that
        // Python sent, using the actual natural/displayed size ratio.
        const shownW = imgElement.clientWidth;
        const shownH = imgElement.clientHeight;
        if (shownW > 0 && shownH > 0) {
          clickedPositionX = Math.floor(event.offsetX * imgElement.naturalWidth / shownW);
          clickedPositionY = Math.floor(event.offsetY * imgElement.naturalHeight / shownH);
        }
      };
      div.appendChild(imgElement);

      const instruction = document.createElement('div');
      instruction.innerHTML =
          '' +
          'When finished, click here to stop this demo';
      div.appendChild(instruction);
      instruction.onclick = () => { shutdown = true; };

      video.srcObject = stream;
      await video.play();

      captureCanvas = document.createElement('canvas');
      // Exactly the 640x480 region drawn in onAnimationFrame: a wider canvas
      // would pad the JPEG sent to Python with black (JPEG has no alpha).
      captureCanvas.width = 640;
      captureCanvas.height = 480;
      window.requestAnimationFrame(onAnimationFrame);

      return stream;
    }

    async function stream_frame(label, imgData) {
      // A stop click that arrived while Python was busy: tear down and
      // signal the caller with an empty reply.
      if (shutdown) {
        removeDom();
        shutdown = false;
        return '';
      }

      var preCreate = Date.now();
      stream = await createDom();

      var preShow = Date.now();
      if (label != "") {
        labelElement.innerHTML = label;
      }

      if (imgData != "") {
        var videoRect = video.getClientRects()[0];
        imgElement.style.top = videoRect.top + "px";
        imgElement.style.left = videoRect.left + "px";
        imgElement.style.width = videoRect.width*2 + "px";
        imgElement.style.height = videoRect.height + "px";
        imgElement.src = imgData;
      }

      var preCapture = Date.now();
      var result = await new Promise(function(resolve, reject) {
        pendingResolve = resolve;
      });
      // Remember whether this frame ended because of a stop click *before*
      // clearing the flag, so the caller actually sees it.
      var wasShutdown = shutdown;
      shutdown = false;

      var returnX = clickedPositionX;
      var returnY = clickedPositionY;
      clickedPositionX = -1;
      clickedPositionY = -1;

      if (wasShutdown) {
        // Release the camera; the caller stops polling after a shutdown.
        removeDom();
      }

      return {'create': preShow - preCreate,
              'show': preCapture - preShow,
              'capture': Date.now() - preCapture,
              'x': returnX,
              'y': returnY,
              'shutdown': wasShutdown,
              'img': result};
    }
    ''')

  display(js)

def video_frame(label, bbox):
  """Fetch one webcam frame from the browser, updating the label and overlay first.

  Args:
    label: HTML shown after "Status:" ('' keeps the current text).
    bbox: data-URL image to overlay on the video ('' keeps the current overlay).

  Returns:
    Whatever the JavaScript ``stream_frame`` resolves with: a dict describing
    the frame, or '' when the stream was shut down between calls.
  """
  # json.dumps yields properly escaped JavaScript string literals, so quotes,
  # backslashes or newlines in label/bbox cannot break (or inject into) the call.
  data = eval_js('stream_frame({}, {})'.format(json.dumps(label), json.dumps(bbox)))
  return data


In [None]:
# function to convert the JavaScript object into an OpenCV image
def js_to_image(js_reply):
  """
  Decode a browser data-URL frame into an OpenCV image.

  Params:
          js_reply: data-URL string (e.g. 'data:image/jpeg;base64,...') holding
                    the webcam frame; falsy when no frame was captured.
  Returns:
          img: OpenCV BGR image, or None when js_reply is empty.
  """
  if js_reply:
    # The base64 payload follows the comma that ends the data-URL header.
    payload = js_reply.split(',')[1]
    raw = np.frombuffer(b64decode(payload), dtype=np.uint8)
    # flags=1 is cv2.IMREAD_COLOR: decode to a 3-channel BGR image.
    return cv2.imdecode(raw, flags=1)
  return None

# function to convert OpenCV Rectangle bounding box image into base64 byte string to be overlayed on video stream
def bbox_to_bytes(bbox_array):
  """
  Encode an image array as a PNG data URL for the browser overlay.

  Params:
          bbox_array: NumPy pixel array to overlay on the video stream.
  Returns:
          str of the form 'data:image/png;base64,<payload>'.
  """
  with io.BytesIO() as buffer:
    PIL.Image.fromarray(bbox_array).save(buffer, format='png')
    encoded = b64encode(buffer.getvalue()).decode('utf-8')
  return 'data:image/png;base64,' + encoded

In [None]:
import torch
# Run on the GPU when one is visible to PyTorch, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

In [None]:
# Feature space of camera images, dimension: 384 features
import torch
import torchvision.transforms as T
import cv2
from PIL import Image

# 1. Load DINO (v1) model — ViT-S/16 (small version, 384-dim), Source: https://github.com/facebookresearch/dino
# torch.hub fetches the DINO repository and pretrained weights (network access needed).
model = (
    torch.hub.load('facebookresearch/dino:main', 'dino_vits16')
    .to(device)
    .eval()  # inference mode; nn.Module.eval() returns the module itself
)

# 2. Define preprocessing (standard ImageNet normalization)
transform = T.Compose(transforms=[
    # Shorter image side -> 256 px (bicubic resampling).
    T.Resize(size=256, interpolation=T.InterpolationMode.BICUBIC),
    # Central 224x224 patch.
    T.CenterCrop(size=224),
    # PIL image -> float tensor (C, H, W) in [0, 1].
    T.ToTensor(),
    # Per-channel ImageNet mean / std.
    T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

# 3. Function to encode OpenCV image to 384-dim DINO v1 feature
def encode_image(cv2_img):
    """Embed an OpenCV image with the DINO ViT-S/16 model.

    Args:
        cv2_img: BGR color image, e.g. as returned by ``js_to_image``.

    Returns:
        1-D NumPy array holding the model output for the image (the [CLS]
        embedding, 384 values for ViT-S/16).
    """
    # OpenCV stores channels as BGR; the preprocessing/model expect RGB.
    img_rgb = cv2.cvtColor(cv2_img, cv2.COLOR_BGR2RGB)
    pil_img = Image.fromarray(img_rgb)

    # Preprocess and add a batch dimension: (1, 3, 224, 224).
    img_tensor = transform(pil_img).unsqueeze(0)

    with torch.no_grad():
        # The DINO hub model's forward() returns a tensor (not a dict): the
        # [CLS] token embedding, shape (1, 384) for ViT-S/16.
        features = model(img_tensor.to(device))

    # Remove only the batch axis; a bare squeeze() would also collapse any
    # other axis of size 1.
    return features.squeeze(0).cpu().numpy()

In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np
import cv2

In [None]:
# download variational autoencoder
!wget http://agentspace.org/download/mnist_cvae_decoder.pth
# load variational autoencoder
decoder = torch.load('mnist_cvae_decoder.pth', weights_only=False, map_location=device)
decoder.eval()

In [None]:
def render(x,y):
    """Decode one latent point (x, y) with the CVAE decoder into an image array.

    Returns:
        NumPy array of the decoder output with its two leading singleton axes
        removed (presumably batch and channel, leaving a 28x28 MNIST digit --
        confirm against the decoder).
    """
    latent = torch.tensor([[x, y]], dtype=torch.float32, device=device)
    # Pure inference: skip building an autograd graph for each of the
    # hundreds of calls made when drawing the latent grid.
    with torch.no_grad():
        out = decoder(latent)
    return out.squeeze(0).squeeze(0).cpu().numpy()

def scale(x,y):
    """Map pixel (x, y) on the 640x480 latent grid image to CVAE latent coordinates.

    The grid centre (320, 240) maps to the latent origin; unscale() is the inverse.
    """
    dx = x - 320
    dy = y - 240
    return dx * 1.645 * 240 / 320 / 320, dy * 1.645 / 245

def unscale(x,y):
    """Inverse of scale(): map latent coordinates (x, y) back to grid pixels.

    Rounds to the nearest pixel. int() alone truncates, so a floating-point
    result such as 99.99999999999997 would land one pixel off. Always
    returns Python ints, suitable as a cv2.circle center.
    """
    px = x * 320 * 320 / 240 / 1.645 + 320
    py = y * 245 / 1.645 + 240
    return int(round(px)), int(round(py))

# Visualise the decoder's latent space: decode a 24 x 32 lattice of latent points
# (one every CELL_PX pixels of the 640x480 grid, mapped through scale()), resize
# each decoded image to CELL_PX x CELL_PX and tile them into a 480 x 640 mosaic.
CELL_PX = 20
GRID_ROWS, GRID_COLS = 480 // CELL_PX, 640 // CELL_PX  # 24 x 32 tiles

rows = [
    np.hstack([
        cv2.resize(render(*scale(col * CELL_PX, row * CELL_PX)), (CELL_PX, CELL_PX))
        for col in range(GRID_COLS)
    ])
    for row in range(GRID_ROWS)
]
grid_img = np.vstack(rows)  # 480 x 640, dtype of the decoder output

fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(grid_img, cmap='gray')
ax.set_title('CVAE decoder outputs across the latent grid')
ax.axis('off')
plt.show()

In [None]:
# Sanity check: unscale() inverts scale() up to its integer pixel conversion
# (one pixel is < 0.01 latent units on either axis).
roundtrip = scale(*unscale(100,50))
assert np.allclose(roundtrip, (100, 50), atol=0.01), roundtrip
roundtrip

In [None]:
# Convert the decoded mosaic to 8-bit for the OpenCV / PNG overlay. Clipping
# prevents uint8 wrap-around if a decoded value falls outside [0, 1]. The dtype
# check keeps the cell idempotent, so re-running it does not multiply an
# already-converted image by 255 a second time.
if grid_img.dtype != np.uint8:
    grid_img = np.clip(grid_img * 255, 0, 255).astype(np.uint8)

In [None]:
def softmax(x, axis=0):
    """Numerically stable softmax of ``x`` along ``axis`` (default 0).

    The maximum is subtracted per slice along ``axis`` rather than globally.
    With a global max, a slice whose values all lie far below it underflows
    to 0 / 0 = NaN. For 1-D input the result is unchanged.
    """
    shifted = x - np.max(x, axis=axis, keepdims=True)
    e_x = np.exp(shifted)
    return e_x / e_x.sum(axis=axis, keepdims=True)

In [None]:
def attn(q,K,V):
    """Scaled dot-product attention of query ``q`` over keys ``K`` and values ``V``.

    Weights are softmax(q @ K.T / sqrt(d)) with d = K.shape[1]; the result is
    the correspondingly weighted combination of the rows of ``V``.
    """
    scores = (q @ K.T) / np.sqrt(K.shape[1])
    weights = softmax(scores)
    return weights @ V

In [None]:
# Start streaming video from the webcam and run the click-to-teach loop.
# Clicking on the right-hand latent grid stores a key/value pair: the DINO
# feature of the current frame -> the clicked latent point. Every frame,
# attention over the stored pairs predicts a latent point, drawn as a green dot.
video_stream()
# label for video
label_html = 'Capturing...'
# initialize overlay to empty (nothing drawn on the first frame)
bbox = ''
keys = []    # DINO features of the frames captured at click time
values = []  # latent (x, y) points clicked for those frames

# Loop-invariant overlay pieces, built once instead of on every frame.
# Left half: fully transparent RGBA canvas laid over the live video.
bbox_array = np.zeros([480, 640, 4], dtype=np.uint8)
# Right half: the latent grid as RGBA; copied per frame because cv2.circle draws in place.
grid_rgba = cv2.cvtColor(grid_img, cv2.COLOR_GRAY2RGBA)

while True:
    js_reply = video_frame(label_html, bbox)
    if not js_reply:
        break
    if js_reply["shutdown"]:
        break

    # convert JS response to OpenCV Image
    frame = js_to_image(js_reply["img"])
    if frame is None:
        break

    x = js_reply["x"]
    y = js_reply["y"]

    right_array = grid_rgba.copy()

    # call model on frame
    query = encode_image(frame)

    # A click on the right half (overlay x >= 640) labels the current frame.
    if x != -1 and y != -1 and x >= 640:
        keys.append(query)
        values.append(np.array(scale(x - 640, y)))

    # attention over all stored (feature -> latent point) pairs
    if keys:
        out = attn(query, np.array(keys), np.array(values))
        xx, yy = unscale(*out)
        right_array = cv2.circle(right_array, (xx, yy), 5, (0, 255, 0, 255), -1)

    # Side-by-side overlay, encoded as a PNG data URL for the next frame.
    combined = cv2.hconcat([bbox_array, right_array])
    bbox = bbox_to_bytes(combined)


In [None]:
# Inspect the last browser reply. The base64 frame ('img') is left out because
# it is a large string that would bloat the saved notebook. The reply is ''
# (not a dict) when the stream was stopped between frames.
{k: v for k, v in js_reply.items() if k != 'img'} if isinstance(js_reply, dict) else js_reply

In [None]:
# True when the last reply carried no frame. Guarded because the reply is ''
# rather than a dict when the stream was stopped between frames.
not (isinstance(js_reply, dict) and js_reply["img"])