In [1]:
import os
import json
import pandas as pd
from moviepy.editor import *
from decord import VideoReader
import numpy as np
import math

In [20]:
downsampling_rate= 1
def verify_frame_len(video_path, frame_idx):
    if video_path.endswith(".mp4"):
        video = VideoFileClip(video_path)
    else:
        video = VideoFileClip(video_path + ".mp4")
    n_frames = video.reader.nframes
    if n_frames // (downsampling_rate * video.fps) != max(frame_idx) + 1:
        print("ERROR", video.fps, n_frames, max(frame_idx))

In [21]:
def find_consecutive_timestamps(timestamps):
    """
    Function to find consecutive timestamps in a list and record the start and end time.
    """
    if not timestamps:
        return []

    # Initialize the first start time and the result list
    start = timestamps[0]
    result = []
    
    for i in range(1, len(timestamps)):
        # Check if the current timestamp is not consecutive
        if timestamps[i] != timestamps[i-1] + 1:
            # Record the previous consecutive sequence
            result.append([start, timestamps[i-1]])
            # Update the start for the new sequence
            start = timestamps[i]

    # Add the last sequence
    result.append([start, timestamps[-1]])

    return result

# Example usage
# timestamps = [0, 1, 2, 3, 4, 8, 10, 11, 12]
# find_consecutive_timestamps(timestamps)

In [22]:
def calculate_iou(ground_truth, predictions):
    """
    Calculate the Intersection over Union (IoU) for video moment retrieval.
    
    :param ground_truth: A tuple representing the ground truth interval (start, end).
    :param predictions: A list of tuples representing predicted intervals [(start1, end1), (start2, end2), ...].
    :return: IoU score.
    """
    GT_start, GT_end = ground_truth
    total_intersection = 0
    total_union = 0

    for (P_start, P_end) in predictions:
        # Calculate intersection
        intersection = max(0, min(GT_end, P_end) - max(GT_start, P_start))
        total_intersection += intersection

        # Calculate union for this predicted interval
        union = (P_end - P_start)  - intersection
        total_union += union
    total_union += (GT_end - GT_start)
    # Avoid division by zero
    if total_union == 0:
        return 0

    # Calculate IoU
    iou = total_intersection / total_union
    return iou

In [23]:
def save_json(content, save_path):
    with open(save_path, 'w') as f:
        f.write(json.dumps(content))
def load_jsonl(filename):
    with open(filename, "r") as f:
        return [json.loads(l.strip("\n")) for l in f.readlines()]
        # return json.loads(f)

In [24]:
#set folder path
root_path = "/home/hlpark/shared/REDUCE_benchmarks/SeViLA/sevila_data/tvqa"
video_root = "/home/hlpark/shared/TVQA/video/video_files"
tvqa_result_root = "/home/hlpark/shared/REDUCE_benchmarks/SeViLA/sevila_result"
eval_path = "/home/hlpark/shared/REDUCE_benchmarks/SeViLA/sevila_data/tvqa_evaluation_json"

In [25]:
# TVQA
tvqa_val_json = []
tvqa_list = []
video_list = []

for _, dir, _ in os.walk(tvqa_result_root):
    for f in dir:
        tvqa_val_json.append(os.path.join(root_path, f + "_val.json"))
        video_name = f
        video_list.append(os.path.join(video_root, video_name))
        if os.path.exists(os.path.join(tvqa_result_root, video_name, "result", "test_epochbest.json")):
            tvqa_list.append(os.path.join(tvqa_result_root, video_name, "result", "test_epochbest.json"))
        elif os.path.exists(os.path.join(tvqa_result_root, video_name, "result", "val_epochbest.json")):
            tvqa_list.append(os.path.join(tvqa_result_root, video_name, "result", "val_epochbest.json"))
        elif os.path.exists(os.path.join(tvqa_result_root, video_name, "result", "train_epochbest.json")):
            tvqa_list.append(os.path.join(tvqa_result_root, video_name, "result", "train_epochbest.json"))
            print("train file")
        else: 
            tvqa_list.append("")
assert len(tvqa_list) == len(tvqa_val_json)
al_json = {}
fileerr = 0
for idx, val_json in enumerate(tvqa_val_json):
    if tvqa_list[idx] == '':
        continue
    try:
        tvqa = load_jsonl(tvqa_list[idx])
        val = load_jsonl(val_json)
        tvqa_video = video_list[idx]
        
        
        for i, qa in enumerate(val[0]):
            dic = {}
            dic['ground_truth'] = [float(qa['start']), float(qa['end'])]
            dic['time_span_len'] = float(qa['end']) - float(qa['start'])
            al_json[qa['qid']] = dic

        max_frame_num = 0

        for i, qa in enumerate(tvqa[0]):
            max_frame_num = max(qa['frame_idx'])
            if qa['qid'] not in al_json:
                print("QID doesnt exist", qa['qid'])
            else:
                dic = al_json[qa['qid']] 
                pred = qa['frame_idx'][:int(np.ceil(dic['time_span_len']))]
                pred = sorted(pred)
                # print(pred)
                pred_time_span = find_consecutive_timestamps(pred)
                #print(dic['ground_truth'],pred_time_span)
                dic['pred'] = pred_time_span
                iou = calculate_iou(dic['ground_truth'], pred_time_span)
                dic['iou'] = iou
                verify_frame_len(tvqa_video,  qa['frame_idx'])
                #print(iou)
                
                
    except FileNotFoundError as e: 
        print("file not found ", e)
        fileerr +=1


In [26]:
train_path = f'{eval_path}/train.json'
val_path = f'{eval_path}/val.json'
test_path = f'{eval_path}/test.json'

In [27]:
train = load_jsonl(train_path)
val = load_jsonl(val_path)
test = load_jsonl(test_path)

In [28]:
val_total_cnt = 0
val_dict = {}
for i, qa in enumerate(val[0]):
    val_total_cnt +=1
    val_dict[str(qa['qid'])] = 1
print(idx)
train_dict = {}
for i, qa in enumerate(train[0]):
    train_dict[str(qa['qid'])] = 1
test_dict = {}
for i, qa in enumerate(test[0]):
    test_dict[str(qa['qid'])] = 1

4357


In [29]:
val_cnt, train_cnt, test_cnt = 0, 0, 0
cnt_3_thresh, cnt_5_thresh, cnt_7_thresh = 0, 0, 0
for key, val in al_json.items():
    if 'iou' not in val:
        continue
    #print(key)
    if key in val_dict:
        val_cnt += 1
        if val['iou'] > 0.3:
            cnt_3_thresh += 1
        if val['iou'] > 0.5:
            cnt_5_thresh += 1
        if val['iou'] > 0.7:
            cnt_7_thresh += 1
        
if val_cnt > 0:
    # assert len(val_dict) == val_cnt
    print(f"IoU=0.3: {cnt_3_thresh/(val_cnt) * 100}\nIoU=0.5: {cnt_5_thresh/(val_cnt) * 100}\nIoU=0.7: {cnt_7_thresh/(val_cnt) * 100}\ntotal queries:{(val_cnt)}/{val_total_cnt}")


IoU=0.3: 12.613156306578151
IoU=0.5: 6.430630993093274
IoU=0.7: 3.051029303292429
total queries:14913/15253
