In [13]:
import json
from pathlib import Path

p_cap_json = Path('/data/gunsbrother/prjs/ltvu/llms/Video-LLaVA/ltvu/captions/test/step1/0b20e242-a496-4662-a3e7-645bcecdbe55/cbdc37c7-820a-5bb3-a597-53ca31a13a6f.json')
with p_cap_json.open() as f:
    cap_json = json.load(f)
query = cap_json['query']
prompts = cap_json['prompts']
for caps in cap_json['captions']:
    s, e = caps['interval']['start_sec'], caps['interval']['end_sec']
    print(f'{s}s ~ {e}s')
    for prompt, cap in zip(prompts, caps['captions']):
        prompt = prompt.format(query=query)
        print(f'{prompt:40s}  {cap}')
    print()

FileNotFoundError: [Errno 2] No such file or directory: '/data/gunsbrother/prjs/ltvu/llms/Video-LLaVA/ltvu/captions/test/step1/0b20e242-a496-4662-a3e7-645bcecdbe55/cbdc37c7-820a-5bb3-a597-53ca31a13a6f.json'

In [11]:
import re
import json
import math
from pathlib import Path
import numpy as np
%cd /data/gunsbrother/prjs/ltvu/llms/Video-LLaVA/ltvu
from intervaltree import IntervalTree, Interval
from IPython.display import HTML


PATTERNS_NO = [
    r'^no(?=[,\.!\s])',
    # r' cannot ',
    # r' sorry',
    # r' unable ',
    # r' not[^\w]',
]

def can_answer(caption: str, no_is_positive: bool = False):
    if not no_is_positive:
        if re.search(r'yes[^\w]', caption, flags=re.IGNORECASE):
            return True
        return False

    else:
        for pattern in PATTERNS_NO:
            if re.search(pattern, caption, flags=re.IGNORECASE):
                return True
        return False

def merge_if_close(t: IntervalTree, stride_sec=3., threshold=1, eps=1e-12):
    precision = -int(math.log10(eps))
    threshold_sec = stride_sec * threshold
    t = IntervalTree([Interval(iv.begin, iv.end+threshold_sec+eps) for iv in sorted(t)])
    t.merge_overlaps()
    t = IntervalTree([Interval(iv.begin, round(iv.end-threshold_sec, precision-1)) for iv in sorted(t)])  # suppress eps
    t_sliced = t.copy()
    for sec in np.arange(t.begin(), t.end(), stride_sec):
        t_sliced.slice(sec)
    return t_sliced, t

def get_length(t: None|IntervalTree):
    if t is None:
        return 0
    return sum([iv.end - iv.begin for iv in t])

def stylize(s: str, bold=False, italic=False, underline=False, color=None):
    style = ''
    if bold:
        style += 'font-weight: bold;'
    if italic:
        style += 'font-style: italic;'
    if underline:
        style += 'text-decoration: underline;'
    if color:
        style += f'color: {color};'
    return f'<span style="{style}">{s}</span>'

def get_caption_scores(p_cap_json, stride_sec=3.):
    with p_cap_json.open() as f:
        cap_json = json.load(f)

    t = IntervalTree()
    for caps in cap_json['captions']:
        s, e = caps['interval']['start_sec'], caps['interval']['end_sec']
        last_caption = caps['captions'][-1]
        if can_answer(last_caption, no_is_positive=p_cap_json.parent.parent.parent.stem in ['20240308v4']):
            t |= IntervalTree.from_tuples([(s, e+1e-12)])
    duration_sec = e

    t_pred, t_pred_merged = merge_if_close(t, stride_sec=stride_sec, threshold=0)

    s_sec, e_sec = cap_json['gt']['start_sec'], cap_json['gt']['end_sec']
    s_sec = s_sec // stride_sec * stride_sec
    e_sec = e_sec // stride_sec * stride_sec + stride_sec
    t_gt = IntervalTree.from_tuples([(s_sec, e_sec)])
    tp = t_pred.overlap(s_sec, e_sec)  # slots having overlap with gt
    recall = round(get_length(tp) / (e_sec - s_sec + 1e-7), 6)
    precision = round(get_length(tp) / (get_length(t_pred) + 1e-7), 6)
    TP = get_length(tp)
    PRED = get_length(t_pred)

    query = cap_json['query']
    html_text = f'<a href="vscode-remote://{p_cap_json}">{p_cap_json.stem.split("-")[0]}</a>'
    html_text += ' | ' + ' / '.join([
        stylize(f'{get_length(t_gt):3.0f}', color='#00FF00'),
        stylize(f'{TP:3.0f}', color='cyan'),
        stylize(f'{PRED:4.0f}', color='orange'),
        f'{duration_sec:4.0f}',])
    msg = f'{p_cap_json.parent.stem.split("-")[0]} | {html_text} | {recall:6.1%} {precision:6.1%} | {PRED/duration_sec:6.1%} | {query}'
    return recall, precision, msg

m = None
p_caps_dirs = sorted(Path('/data/gunsbrother/prjs/ltvu/llms/Video-LLaVA/ltvu/captions/test/step1').glob('*'))
p_caps_dirs = [sorted(list(p.glob('*')), key=lambda p: (p.parent.stem, float(p.stem.replace('s', '')))) for p in p_caps_dirs]
p_caps_dirs = sum(p_caps_dirs, [])
for i, p_caps_dir in enumerate(p_caps_dirs):
    stride_sec = float(p_caps_dir.stem.replace('s', ''))
    recalls, precisions, msgs = [], [], []
    for p_cap_json in p_caps_dir.glob('**/*.json'):
        recall, precision, msg = get_caption_scores(p_cap_json, stride_sec=stride_sec)
        recalls.append(recall)
        precisions.append(precision)
        msgs.append(msg)
    msgs = [f'clip uid |   q_uid  |  GT /  TP / Pred /  Dur | Recall   Prec |Pred/Dur| Query'] + msgs
    print(f'{p_caps_dir.parent.stem} {float(p_caps_dir.stem):4.1f}s, Recall: {np.mean(recalls):.3%}, Precision: {np.mean(precisions):.3%}')

display(HTML('<pre style="font-family: Consolas;">' + '<br/>'.join(msgs) + '</pre>'))

/data/gunsbrother/prjs/ltvu/llms/Video-LLaVA/ltvu
20240304v1  3.0s, Recall: 58.940%, Precision: 2.526%
20240304v1  6.0s, Recall: 65.317%, Precision: 3.075%
20240304v1 12.0s, Recall: 57.897%, Precision: 3.712%
20240307v1  3.0s, Recall: 75.368%, Precision: 2.024%
20240307v1  6.0s, Recall: 74.321%, Precision: 2.470%
20240307v1 12.0s, Recall: 72.484%, Precision: 3.514%
20240308v1  3.0s, Recall: 90.895%, Precision: 2.022%
20240308v1  6.0s, Recall: 91.514%, Precision: 2.595%
20240308v2  3.0s, Recall: 68.057%, Precision: 1.991%
20240308v4  3.0s, Recall: 27.610%, Precision: 1.138%


In [35]:
import re
import json
from pathlib import Path
import numpy as np
%cd /data/gunsbrother/prjs/ltvu/llms/Video-LLaVA/ltvu
from intervaltree import IntervalTree
from IPython.display import HTML

p_caps_dir = Path('/data/gunsbrother/prjs/ltvu/llms/Video-LLaVA/ltvu/captions/test/step1/20240304v1/3.0s')
# p_caps_dir = Path('/data/gunsbrother/prjs/ltvu/llms/Video-LLaVA/ltvu/captions/test/step1/20240304v1/6.0s')
# p_caps_dir = Path('/data/gunsbrother/prjs/ltvu/llms/Video-LLaVA/ltvu/captions/test/step1/20240304v1/12.0s')
# p_caps_dir = Path('/data/gunsbrother/prjs/ltvu/llms/Video-LLaVA/ltvu/captions/test/step1/20240307v1/3.0s')

# PATTERNS_YES = [
#     r'^yes(?=,)?(?=\.)?(?=!)?(?=\s)?',
# ]

PATTERNS_NO = [
    r'^no(?=[,\.!\s])',
    r' cannot ',
    r' sorry',
    r' unable ',
    # r' not[^\w]',
]

def get_caption_scores(p_cap_json):
    with p_cap_json.open() as f:
        cap_json = json.load(f)

    t = IntervalTree()
    for caps in cap_json['captions']:
        s, e = caps['interval']['start_sec'], caps['interval']['end_sec']
        last_caption = caps['captions'][-1]
        no = False
        for pattern in PATTERNS_NO:
            if re.search(pattern, last_caption, flags=re.IGNORECASE):
                no = True
                break
        yes = not no
        t[s:e] = yes
        # has_yes = bool(re.findall(r'^yes(?=,)', last_caption, flags=re.IGNORECASE))
        # has_no = bool(re.findall(r'^no(?=,)', last_caption, flags=re.IGNORECASE))
        # assert has_yes ^ has_no, f'{s:5.1f}s ~ {e:5.1f}s {last_caption}'
    s, e = cap_json['gt']['start_sec'], cap_json['gt']['end_sec']
    t_gt = t[s:e]
    TP = sum([iv.data for iv in t_gt])
    FP = sum([iv.data for iv in t - t_gt])
    FN = sum([not iv.data for iv in t_gt])
    recall = round(TP / (TP + FN + 1e-7), 6)
    precision = round(TP / (TP + FP + 1e-7), 6)
    query = cap_json['query']
    html_text = f'<a href="vscode-remote://{p_cap_json}">{p_cap_json.stem.split("-")[0]}</a>'
    msg = f'{p_cap_json.parent.stem.split("-")[0]} | {html_text} | {TP:3d}/{FP:3d}/{FN:3d} | {recall:6.1%} {precision:6.1%} | {query}'
    return recall, precision, msg

recalls, precisions, msgs = [], [], []
for p_cap_json in p_caps_dir.glob('**/*.json'):
    recall, precision, msg = get_caption_scores(p_cap_json)
    recalls.append(recall)
    precisions.append(precision)
    msgs.append(msg)
print(f'Recall: {np.mean(recalls):.3%}, Precision: {np.mean(precisions):.3%}')
display(HTML('<pre style="font-family: Consolas;">' + '<br/>'.join(msgs) + '</pre>'))

/data/gunsbrother/prjs/ltvu/llms/Video-LLaVA/ltvu
Recall: 74.727%, Precision: 2.161%
