In [34]:
import gzip
import json
import sys
from collections import defaultdict
from pathlib import Path

sys.path.append('..')  # must run before the `ltvu` imports below

import torch
import numpy as np
import decord
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns

from ltvu.structures import ResponseTrack
from ltvu.metrics import compute_tious, compute_stious

In [29]:
p_pred = Path("43634_results.json.gz")
p_ann = Path("../data/vq_v2_val_anno.json")


def read_json(path):
    """Load a JSON document from `path`, decompressing it first if it is gzipped.

    Compression is detected from the two-byte gzip magic number instead of the
    file suffix, so both a real gzip archive and a plain-JSON file that merely
    carries a ``.gz`` name are handled. Files are opened in binary mode
    (``json.load`` decodes bytes itself) and always closed.
    """
    with path.open('rb') as raw:
        is_gzip = raw.read(2) == b'\x1f\x8b'
    opener = gzip.open if is_gzip else open
    with opener(path, 'rb') as fp:
        return json.load(fp)


all_anns = read_json(p_ann)
all_preds_stratified = read_json(p_pred)

# Lookup tables keyed by "<annotation_uid>_<query_set>":
#   quid2pos  -> ground-truth ResponseTrack built from the annotation
#   quid2ann  -> the annotation dict itself (with the track stored under 'bboxes')
#   quid2pred -> list of predicted ResponseTracks for that query
quid2pred, quid2pos, quid2ann = defaultdict(list), {}, {}
for ann in all_anns:
    qset_uuid = f'{ann["annotation_uid"]}_{ann["query_set"]}'
    ann['bboxes'] = ResponseTrack.from_json(ann)
    quid2pos[qset_uuid] = ann['bboxes']
    quid2ann[qset_uuid] = ann

# Predictions are nested videos -> clips -> predictions -> query_sets; flatten
# them into quid2pred, keeping only queries that have a ground-truth annotation.
for pred_video in all_preds_stratified['results']['videos']:
    for pred_clip in pred_video['clips']:
        for pred_qsets in pred_clip['predictions']:
            annotation_uid = pred_qsets['annotation_uid']
            for qset_id, qset in pred_qsets['query_sets'].items():
                qset_uuid = f'{annotation_uid}_{qset_id}'
                if qset_uuid in quid2pos:
                    quid2pred[qset_uuid].append(ResponseTrack.from_json(qset))

# Sanity check: the three counts should agree when every annotation has a prediction.
len(all_anns), len(quid2pred), len(quid2pos)

(4504, 4504, 4504)

In [18]:
# Per-query scores from the project metrics, computed from the predicted and
# ground-truth tracks (presumably temporal IoU and spatio-temporal IoU -- confirm
# against ltvu.metrics).
stious: dict[str, float]
tious: dict[str, float]
tious = compute_tious(quid2pred, quid2pos)
stious = compute_stious(quid2pred, quid2pos)

# Preview the first ten (query id, score) pairs of each metric.
tuple(list(scores.items())[:10] for scores in (tious, stious))

([('0040187f-4627-4bfa-b141-89c1322fcd22_1', 0.5),
  ('0040187f-4627-4bfa-b141-89c1322fcd22_2', 0.8),
  ('0040187f-4627-4bfa-b141-89c1322fcd22_3', 0.4444444444444444),
  ('00c056f0-09d1-42df-9d5a-9752cb899567_1', 0.0),
  ('00c056f0-09d1-42df-9d5a-9752cb899567_2', 0.16666666666666666),
  ('00c056f0-09d1-42df-9d5a-9752cb899567_3', 0.875),
  ('00c70204-6e59-4a45-a5b3-71446068fa07_1', 0.0),
  ('00c70204-6e59-4a45-a5b3-71446068fa07_2', 0.0),
  ('00c70204-6e59-4a45-a5b3-71446068fa07_3', 0.0),
  ('00cc8ca1-4076-4b51-b313-fc47e8734bf9_1', 0.9230769230769231)],
 [('0040187f-4627-4bfa-b141-89c1322fcd22_1', 0.4711079852709515),
  ('0040187f-4627-4bfa-b141-89c1322fcd22_2', 0.5504340235095683),
  ('0040187f-4627-4bfa-b141-89c1322fcd22_3', 0.323227716257078),
  ('00c056f0-09d1-42df-9d5a-9752cb899567_1', 0.0),
  ('00c056f0-09d1-42df-9d5a-9752cb899567_2', 0.025636377817312563),
  ('00c056f0-09d1-42df-9d5a-9752cb899567_3', 0.5503706742033688),
  ('00c70204-6e59-4a45-a5b3-71446068fa07_1', 0.0),
  ('00c7

In [47]:
# Rank queries from "easy" to "hard" by a combined score: tIoU + STIOU_WEIGHT * stIoU.
STIOU_WEIGHT = 4   # weight of the stIoU term in the combined score
SHUFFLE_SEED = 0   # fixed seed so the ranking is identical on every run

quids = list(tious.keys())
scored = [(quid, tious[quid] + STIOU_WEIGHT * stious[quid]) for quid in quids]

# Shuffle before the (stable) sort so that queries with equal scores -- notably
# the many 0-valued ones at the tail -- are not left in their original key order.
# A seeded generator keeps that tie-breaking reproducible across kernel restarts.
rng = np.random.default_rng(SHUFFLE_SEED)
shuffled = [scored[idx] for idx in rng.permutation(len(scored))]

iousums = sorted(shuffled, key=lambda item: item[1], reverse=True)
iousums[:10], iousums[-10:]

([('e23ed37c-b043-4932-a363-d72458699748_3', 4.725392073387891),
  ('f106a17a-f2dd-4d41-8d00-1bc6a248a90e_2', 4.651167975974567),
  ('e6107bbf-5d5d-46c5-97b2-495e7a660bb8_1', 4.64893896997782),
  ('1a6212d0-b206-421d-8cdf-1715d97d4da0_2', 4.631869892675029),
  ('5db070ab-7e8c-4843-acd4-56807e1465d9_1', 4.6228409751474535),
  ('02c623c1-a84f-4d2b-8b0f-f65cd5e5f52b_1', 4.589912356819039),
  ('e543d7f0-e18e-4ce1-9762-a641442fbb16_2', 4.576529621195824),
  ('8337f606-ea6c-47b3-9c22-b072b621f4cb_1', 4.575890423945444),
  ('38262098-51dd-4217-9e5c-b9ee478c1859_1', 4.571400565236097),
  ('665afaf7-a98c-4ab2-9d25-20107f398d44_3', 4.570170107067291)],
 [('c3dc2a53-91ea-4be9-861e-8c1013e153ef_3', 0.0),
  ('43f78585-bdbf-48e4-8cee-f747db34c0f4_3', 0.0),
  ('bd6f8d17-aa68-4838-a8da-0b7557be11d7_1', 0.0),
  ('fcd8b48f-5f65-4c54-ad3c-dad368790dc4_2', 0.0),
  ('2b21caa8-42e3-46ed-a7b6-7e47c0b03c98_1', 0.0),
  ('468f8fa3-21ee-4831-8cd3-4ee37ca51487_2', 0.0),
  ('ec969346-07e8-4bb8-b7a5-5abb7807989e_3'

In [55]:
# Alternative clip location, kept for reference:
# p_clips_dir = Path("/data/datasets/ego4d_data/v2/clips/")
p_clips_dir = Path("/data/datasets/ego4d_data/v2/vq2d_clips/")
num = 100  # number of queries to sample from each end of the ranking


def load_mid_frames(ranked, clips_dir):
    """Decode one frame per ranked query and return the frames as a list of arrays.

    Args:
        ranked: iterable of ``(query id, score)`` pairs; only the id is used, to
            look the annotation up in the notebook-level ``quid2ann`` mapping.
        clips_dir: directory containing ``<clip_uid>.mp4`` files.

    Returns:
        A list with one decoded frame (NumPy array) per entry of ``ranked``,
        in the same order.
    """
    frames = []
    for quid, _score in ranked:
        ann = quid2ann[quid]
        clip_uid = ann['clip_uid']
        # Midpoint of 'response_track_valid_range' (presumably the first/last
        # frame index of the response track within the clip -- confirm).
        rt_ext = ann['response_track_valid_range']
        rt_mid = (rt_ext[0] + rt_ext[1]) // 2
        vr = decord.VideoReader(str(clips_dir / f'{clip_uid}.mp4'))
        frames.append(vr[rt_mid].asnumpy())
    return frames


# Highest-scoring queries first.
frames_easy = load_mid_frames(iousums[:num], p_clips_dir)

# Lowest-scoring queries, worst first. Reversing and then slicing yields exactly
# `num` entries; the previous `range(-1, -num, -1)` loop stopped one short.
frames_hard = load_mid_frames(iousums[::-1][:num], p_clips_dir)

In [None]:
stride = 10  # number of frames shown side by side in each row
# Walk the list in chunks of `stride` based on its real length, so a trailing
# partial row is still shown. `np.concatenate` is used instead of `np.concat`
# because the latter only exists in NumPy >= 2.0.
# Frames within a row must match in every dimension except axis 1
# (assumed to be width, i.e. frames are (H, W, C) -- confirm).
for start in range(0, len(frames_easy), stride):
    row = np.concatenate(frames_easy[start:start + stride], axis=1)
    display(Image.fromarray(row))


In [None]:

# Show the hard examples in rows of `stride` frames (`stride` is set in the cell
# that displays the easy frames). Iterating over the real list length keeps a
# trailing partial row; `np.concatenate` works on NumPy versions before 2.0,
# where `np.concat` does not exist.
for start in range(0, len(frames_hard), stride):
    row = np.concatenate(frames_hard[start:start + stride], axis=1)
    display(Image.fromarray(row))