<a href="https://colab.research.google.com/github/SohyunKang/asd_video/blob/main/gazelle_demo_edited_250915.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageFont
import requests
from io import BytesIO
import numpy as np

In [2]:
# load a simple face detector
!pip install retina-face
from retinaface import RetinaFace



In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# load Gaze-LLE model
model, transform = torch.hub.load('fkryan/gazelle', 'gazelle_dinov2_vitl14_inout')
model.eval()
model.to(device)

Using cache found in /root/.cache/torch/hub/fkryan_gazelle_main
Using cache found in /root/.cache/torch/hub/facebookresearch_dinov2_main


GazeLLE(
  (backbone): DinoV2Backbone(
    (model): DinoVisionTransformer(
      (patch_embed): PatchEmbed(
        (proj): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14))
        (norm): Identity()
      )
      (blocks): ModuleList(
        (0-23): 24 x NestedTensorBlock(
          (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
          (attn): MemEffAttention(
            (qkv): Linear(in_features=1024, out_features=3072, bias=True)
            (proj): Linear(in_features=1024, out_features=1024, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
          )
          (ls1): LayerScale()
          (drop_path1): Identity()
          (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=1024, out_features=4096, bias=True)
            (act): GELU(approximate='none')
            (fc2): Linear(in_features=4096, out_features=1024, bias=True)
            (drop): Dropout(p=0.0, inpla

In [4]:
import cv2
from google.colab.patches import cv2_imshow
from IPython.display import HTML
from base64 import b64encode

input_video_path = "/content/IF2001_1_1_1023041312_1.mp4"
cap = cv2.VideoCapture(input_video_path)

frames = []
while True:
    ret, frame = cap.read()
    if not ret:
        break
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    width, height, _ = frame.shape
    frames.append(frame)

cap.release()

# Colab에서 바로 재생할 수 있도록 HTML video 태그로 출력
mp4 = open(input_video_path,'rb').read()
data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
HTML(f"""
<video width=640 controls>
      <source src="{data_url}" type="video/mp4">
</video>
""")


In [5]:
# detect faces
bboxes_all = []
valid_frame_num = []
for n, frame in enumerate(frames):
  image = frame
  resp = RetinaFace.detect_faces(np.array(image))
  bboxes = [resp[key]['facial_area'] for key in resp.keys()]
  if bboxes:
    bboxes_all.append(bboxes)
    valid_frame_num.append(n)
    # print(n, bboxes)
print(bboxes_all, valid_frame_num)

[[[np.int64(405), np.int64(603), np.int64(502), np.int64(661)]], [[np.int64(408), np.int64(605), np.int64(497), np.int64(663)]], [[np.int64(407), np.int64(604), np.int64(497), np.int64(664)]], [[np.int64(408), np.int64(604), np.int64(498), np.int64(665)]], [[np.int64(407), np.int64(603), np.int64(497), np.int64(664)]], [[np.int64(407), np.int64(603), np.int64(499), np.int64(665)]], [[np.int64(407), np.int64(603), np.int64(498), np.int64(666)]], [[np.int64(406), np.int64(602), np.int64(501), np.int64(666)]], [[np.int64(407), np.int64(602), np.int64(500), np.int64(667)]], [[np.int64(406), np.int64(602), np.int64(502), np.int64(667)]], [[np.int64(407), np.int64(602), np.int64(500), np.int64(667)]], [[np.int64(403), np.int64(602), np.int64(502), np.int64(669)]], [[np.int64(401), np.int64(599), np.int64(499), np.int64(670)]], [[np.int64(399), np.int64(599), np.int64(501), np.int64(671)]], [[np.int64(400), np.int64(596), np.int64(502), np.int64(671)]], [[np.int64(401), np.int64(592), np.int6

In [6]:
# prepare gazelle input
outputs = []
for bboxes, n in zip(bboxes_all, valid_frame_num):
  image = frames[n]
  img_tensor = transform(image).unsqueeze(0).to(device)
  norm_bboxes = [[np.array(bbox) / np.array([width, height, width, height]) for bbox in bboxes]]

  input = {
      "images": img_tensor, # [num_images, 3, 448, 448]
      "bboxes": norm_bboxes # [[img1_bbox1, img1_bbox2...], [img2_bbox1, img2_bbox2]...]
  }

  with torch.no_grad():
      output = model(input)

  outputs.append(output)

  img1_person1_heatmap = output['heatmap'][0][0] # [64, 64] heatmap
  # print(img1_person1_heatmap.shape)
  if model.inout:
    img1_person1_inout = output['inout'][0][0] # gaze in frame score (if model supports inout prediction)
    # print(img1_person1_inout.item())


In [7]:
# visualize predicted gaze heatmap for each person and gaze in/out of frame score

def visualize_heatmap(pil_image, heatmap, bbox=None, inout_score=None):
    if isinstance(heatmap, torch.Tensor):
        heatmap = heatmap.detach().cpu().numpy()
    heatmap = Image.fromarray((heatmap * 255).astype(np.uint8)).resize(pil_image.size, Image.Resampling.BILINEAR)
    heatmap = plt.cm.jet(np.array(heatmap) / 255.)
    heatmap = (heatmap[:, :, :3] * 255).astype(np.uint8)
    heatmap = Image.fromarray(heatmap).convert("RGBA")
    heatmap.putalpha(90)
    overlay_image = Image.alpha_composite(pil_image.convert("RGBA"), heatmap)

    if bbox is not None:
        width, height = pil_image.size
        xmin, ymin, xmax, ymax = bbox
        draw = ImageDraw.Draw(overlay_image)
        draw.rectangle([xmin * width, ymin * height, xmax * width, ymax * height], outline="lime", width=int(min(width, height) * 0.01))

        if inout_score is not None:
          text = f"in-frame: {inout_score:.2f}"
          text_width = draw.textlength(text)
          text_height = int(height * 0.01)
          text_x = xmin * width
          text_y = ymax * height + text_height
          draw.text((text_x, text_y), text, fill="lime", font=ImageFont.load_default(size=int(min(width, height) * 0.05)))
    return overlay_image

output_video_path = "output_video.mp4"
height, width = frames[0].shape[:2]  # (H, W)

fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(output_video_path, fourcc, 20.0, (width, height))  # (W, H)

# valid_frame_num을 빠르게 lookup 할 수 있도록 dict 생성
frame_to_output = {n: outputs[idx] for idx, n in enumerate(valid_frame_num)}
idx = 0
for f_idx in range(len(frames)):
    pil_img = Image.fromarray(frames[f_idx]) # 원본 프레임

    if f_idx in frame_to_output:  # 결과가 있는 프레임이면 heatmap 시각화
        output = frame_to_output[f_idx]
        bboxes = bboxes_all[idx]
        idx += 1
        heatmaps = output['heatmap'][0]
        inouts = output['inout'][0] if model.inout else [None]*len(bboxes)

        overlay = pil_img
        for i in range(len(bboxes)):
            overlay = visualize_heatmap(
                overlay,
                heatmaps[i],
                np.array(bboxes[i]) / np.array([width, height, width, height]),  # 정규화
                inout_score=inouts[i] if inouts is not None else None
            )
    else:  # 결과 없는 프레임은 원본만 사용
        overlay = pil_img

    overlay_rgb = np.array(overlay.convert("RGB"), dtype=np.uint8)   # (H, W, 3) uint8 보장
    frame_bgr = cv2.cvtColor(overlay_rgb, cv2.COLOR_RGB2BGR)
    out.write(frame_bgr)


out.release()

# Colab에서 재생
from IPython.display import HTML
from base64 import b64encode

mp4 = open(output_video_path,'rb').read()
data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
HTML(f"""
<video width=640 controls>
      <source src="{data_url}" type="video/mp4">
</video>
""")

# 오디오 포함 최종 영상
final_video = "final_with_audio.mp4"

# ffmpeg로 원본 오디오 트랙을 복사해서 합치기
!ffmpeg -y -i {output_video_path} -i {input_video_path} -c copy -map 0:v:0 -map 1:a:0 {final_video}


ffmpeg version 4.4.2-0ubuntu0.22.04.1 Copyright (c) 2000-2021 the FFmpeg developers
  built with gcc 11 (Ubuntu 11.2.0-19ubuntu1)
  configuration: --prefix=/usr --extra-version=0ubuntu0.22.04.1 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --enable-gnutls --enable-ladspa --enable-libaom --enable-libass --enable-libbluray --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libdav1d --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libgme --enable-libgsm --enable-libjack --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-libpulse --enable-librabbitmq --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libsrt --enable-libssh --enable-libtheora --enable-libtwolame --enable-libvidstab --enable-libvorbis --enable-libvpx --enab

In [41]:
# combined visualization with maximal gaze points for each person

def visualize_all(pil_image, heatmaps, bboxes, inout_scores, inout_thresh=0.5):
    colors = ['lime', 'tomato', 'cyan', 'fuchsia', 'yellow']
    overlay_image = pil_image.convert("RGBA")
    draw = ImageDraw.Draw(overlay_image)
    width, height = pil_image.size

    for i in range(len(bboxes)):
        bbox = bboxes[i]
        xmin, ymin, xmax, ymax = bbox
        color = colors[i % len(colors)]
        draw.rectangle([xmin * width, ymin * height, xmax * width, ymax * height], outline=color, width=int(min(width, height) * 0.01))

        if inout_scores is not None:
            inout_score = inout_scores[i]
            text = f"in-frame: {inout_score:.2f}"
            text_width = draw.textlength(text)
            text_height = int(height * 0.01)
            text_x = xmin * width
            text_y = ymax * height + text_height
            draw.text((text_x, text_y), text, fill=color, font=ImageFont.load_default(size=int(min(width, height) * 0.05)))

        if inout_scores is not None and inout_score > inout_thresh:
            heatmap = heatmaps[i]
            heatmap_np = heatmap.detach().cpu().numpy()
            max_index = np.unravel_index(np.argmax(heatmap_np), heatmap_np.shape)
            gaze_target_x = max_index[1] / heatmap_np.shape[1] * width
            gaze_target_y = max_index[0] / heatmap_np.shape[0] * height
            bbox_center_x = ((xmin + xmax) / 2) * width
            bbox_center_y = ((ymin + ymax) / 2) * height

            draw.ellipse([(gaze_target_x-5, gaze_target_y-5), (gaze_target_x+5, gaze_target_y+5)], fill=color, width=int(0.005*min(width, height)))
            draw.line([(bbox_center_x, bbox_center_y), (gaze_target_x, gaze_target_y)], fill=color, width=int(0.005*min(width, height)))

    return overlay_image

plt.figure(figsize=(10,10))
plt.imshow(visualize_all(image, output['heatmap'][0], norm_bboxes[0], output['inout'][0] if output['inout'] is not None else None, inout_thresh=0.5))
plt.axis('off')
plt.show()

AttributeError: 'numpy.ndarray' object has no attribute 'convert'

<Figure size 1000x1000 with 0 Axes>