In [1]:
import os
import json
import pandas as pd
from moviepy.editor import *
from decord import VideoReader
import numpy as np
import math

In [2]:
downsampling_rate= 1
def verify_frame_len(video_path, frame_idx):
    if video_path.endswith(".mp4"):
        video = VideoFileClip(video_path)
    else:
        video = VideoFileClip(video_path + ".mp4")
    n_frames = video.reader.nframes
    if n_frames // (downsampling_rate * video.fps) != max(frame_idx) + 1:
        print("ERROR", video.fps, n_frames, max(frame_idx))

In [3]:
def find_consecutive_timestamps(timestamps):
    """
    Function to find consecutive timestamps in a list and record the start and end time.
    """
    if not timestamps:
        return []

    # Initialize the first start time and the result list
    start = timestamps[0]
    result = []
    
    for i in range(1, len(timestamps)):
        # Check if the current timestamp is not consecutive
        if timestamps[i] != timestamps[i-1] + 1:
            # Record the previous consecutive sequence
            result.append([start, timestamps[i-1]])
            # Update the start for the new sequence
            start = timestamps[i]

    # Add the last sequence
    result.append([start, timestamps[-1]])

    return result

# Example usage
# timestamps = [0, 1, 2, 3, 4, 8, 10, 11, 12]
# find_consecutive_timestamps(timestamps)

In [4]:
def calculate_iou(ground_truth, predictions):
    """
    Calculate the Intersection over Union (IoU) for video moment retrieval.
    
    :param ground_truth: A tuple representing the ground truth interval (start, end).
    :param predictions: A list of tuples representing predicted intervals [(start1, end1), (start2, end2), ...].
    :return: IoU score.
    """
    GT_start, GT_end = ground_truth
    total_intersection = 0
    total_union = 0

    for (P_start, P_end) in predictions:
        # Calculate intersection
        intersection = max(0, min(GT_end, P_end) - max(GT_start, P_start))
        total_intersection += intersection

        # Calculate union for this predicted interval
        union = (P_end - P_start)  - intersection
        total_union += union
    total_union += (GT_end - GT_start)
    # Avoid division by zero
    if total_union == 0:
        return 0

    # Calculate IoU
    iou = total_intersection / total_union
    return iou

In [5]:
def save_json(content, save_path):
    with open(save_path, 'w') as f:
        f.write(json.dumps(content))
def load_jsonl(filename):
    with open(filename, "r") as f:
        return [json.loads(l.strip("\n")) for l in f.readlines()]
        # return json.loads(f)

In [6]:
#set folder path
root_path = "/home/hlpark/REDUCE/REDUCE_benchmarks/SeViLA/sevila_data/tvqa"
video_root = "/home/hlpark/shared/TVQA/video/video_files"
eval_path = "/home/hlpark/REDUCE/REDUCE_benchmarks/SeViLA/sevila_data/tvqa_evaluation_json"

In [7]:
test_med_queries_path = "/home/hlpark/shared/TVQA/output_gpt_tvqa_test_queries.txt"
val_med_queries_path = "/home/hlpark/shared/TVQA/output_gpt_tvqa_val_queries.txt"
test_med_query_f_metamap_list, val_med_query_f_metamap_list = [], []
with open(test_med_queries_path, "r") as f:
    lines = f.read()
    for line in lines.split("\n"):
        if line == "\n" or line == "":
            continue
        if line.startswith(" ") or line.startswith("-"):
            line = line.lstrip("-").lstrip(" ")
        test_med_query_f_metamap_list.append(line.lower().replace(" ", "").replace("?", "").replace("\n", "").replace("-", "").replace("'", "").replace("\"", "").replace(",", "").replace(".", "").replace("/", "").replace(">", ""))
f.close()
with open(val_med_queries_path, "r") as f:
    lines = f.read()
    for line in lines.split("\n"):
        if line == "\n" or line == "":
            continue
        # if line.startswith(" ") or line.startswith("-"):
        #     line = line.lstrip("-").lstrip(" ")
        val_med_query_f_metamap_list.append(line.lower().replace(" ", "").replace("?", "").replace("\n", "").replace("-", "").replace("'", "").replace("\"", "").replace(",", "").replace(".", "").replace("/", "").replace(">", ""))
        #val_med_query_f_metamap_list.append(line.strip(" ").strip("\n").replace("  ", " "))
f.close()
print(len(test_med_query_f_metamap_list), len(val_med_query_f_metamap_list))

246 569


In [16]:
# TVQA
tvqa_val_json = []
tvqa_list = []
video_list = []
#tvqa_result_root = "/home/hlpark/REDUCE/REDUCE_benchmarks/SeViLA/sevila_result_32_tvqa_val"
tvqa_result_root = "/home/hlpark/REDUCE/REDUCE_benchmarks/SeViLA/sevila_result_full_tvqa"
for _, dir, _ in os.walk(tvqa_result_root):
    for f in dir:
        tvqa_val_json.append(os.path.join(root_path, f + "_val_gt.json"))
        video_name = f
        video_list.append(os.path.join(video_root, video_name))
        if os.path.exists(os.path.join(tvqa_result_root, video_name, "result", "test_epochbest.json")):
            tvqa_list.append(os.path.join(tvqa_result_root, video_name, "result", "test_epochbest.json"))
        elif os.path.exists(os.path.join(tvqa_result_root, video_name, "result", "val_epochbest.json")):
            tvqa_list.append(os.path.join(tvqa_result_root, video_name, "result", "val_epochbest.json"))
        elif os.path.exists(os.path.join(tvqa_result_root, video_name, "result", "train_epochbest.json")):
            tvqa_list.append(os.path.join(tvqa_result_root, video_name, "result", "train_epochbest.json"))
            print("train file")
        else: 
            tvqa_list.append("")
assert len(tvqa_list) == len(tvqa_val_json)
al_json_val = {}
fileerr = 0
for idx, val_json in enumerate(tvqa_val_json):
    if tvqa_list[idx] == '':
        continue
    try:
        tvqa = load_jsonl(tvqa_list[idx])
        val = load_jsonl(val_json)
        tvqa_video = video_list[idx]
        
        
        for i, qa in enumerate(val[0]):
            dic = {}
            dic['ground_truth'] = [float(qa['start']), float(qa['end'])]
            dic['time_span_len'] = float(qa['end']) - float(qa['start'])
            dic['vid_name'] = qa['video']
            dic['q'] = qa['question']
            if  math.isnan(dic['time_span_len']):
                print(float(qa['end']) , float(qa['start']), qa)
            
            al_json_val[qa['qid']] = dic

        max_frame_num = 0

        for i, qa in enumerate(tvqa[0]):
            max_frame_num = max(qa['frame_idx'])
            if qa['qid'] not in al_json_val:
                print("QID doesnt exist", qa['qid'])
            else:
                dic = al_json_val[qa['qid']] 
                if dic['time_span_len'] == float(np.nan):
                    print("nan")
                try:
                    pred = qa['frame_idx'][:int(np.ceil(dic['time_span_len']))]
                    pred = sorted(pred)
                    # print(pred)
                    pred_time_span = find_consecutive_timestamps(pred)
                    #print(dic['ground_truth'],pred_time_span)
                    dic['pred'] = pred_time_span
                    iou = calculate_iou(dic['ground_truth'], pred_time_span)
                    dic['iou'] = iou
                except ValueError:
                    print(dic['time_span_len'])
                    print("ValueError ", qa['qid'], qa['frame_idx'])
                    dic['pred'] = [0, 0]
                    dic['iou'] = 0
                #uncomment this only when you want to check frame length for verification purpose
                #verify_frame_len(tvqa_video,  qa['frame_idx'])
                
                
                
    except FileNotFoundError as e: 
        print("file not found ", e)
        fileerr +=1


nan nan {'video': 'castle_s07e23_seg02_clip_14', 'num_option': 5, 'qid': 'TVQA_999', 'a0': 'Beckett says Castle is her husband.', 'a1': 'Beckett says Castle is her brother.', 'a2': 'Beckett says Castle is her best friend.', 'a3': 'Beckett says Castle is her nephew.', 'a4': 'Beckett says Castle is her step-brother.', 'answer': 0, 'question': 'What relation does Beckett say she has with Castle when she discusses her partner with Collins?', 'start': 'NaN', 'end': 'NaN'}
nan
ValueError  TVQA_999 [60, 11, 12, 10, 69, 42, 43, 13, 61, 59, 44, 27, 57, 45, 86, 25, 70, 41, 5, 56, 72, 26, 55, 14, 74, 84, 68, 67, 4, 32, 49, 40, 0, 48, 46, 71, 83, 73, 64, 2, 8, 3, 9, 47, 66, 58, 78, 29, 19, 28, 82, 65, 50, 1, 85, 24, 31, 7, 39, 16, 91, 15, 77, 76, 30, 75, 88, 81, 21, 90, 89, 20, 6, 87, 51, 33, 17, 36, 38, 53, 80, 52, 79, 34, 35, 37, 63, 62, 23, 54, 18, 22]
nan
ValueError  TVQA_999 [60, 11, 12, 10, 69, 42, 43, 13, 61, 59, 44, 27, 57, 45, 86, 25, 70, 41, 5, 56, 72, 26, 55, 14, 74, 84, 68, 67, 4, 32, 

In [17]:
tvqa_val_json = []
tvqa_list = []
video_list = []
#tvqa_result_root = "/home/hlpark/REDUCE/REDUCE_benchmarks/SeViLA/sevila_result_32_tvqa_test"
tvqa_result_root = "/home/hlpark/REDUCE/REDUCE_benchmarks/SeViLA/sevila_result_full_tvqa_test"
for _, dir, _ in os.walk(tvqa_result_root):
    for f in dir:
        tvqa_val_json.append(os.path.join(root_path, f + "_test_gt.json"))
        video_name = f
        video_list.append(os.path.join(video_root, video_name))
        if os.path.exists(os.path.join(tvqa_result_root, video_name, "result", "test_epochbest.json")):
            tvqa_list.append(os.path.join(tvqa_result_root, video_name, "result", "test_epochbest.json"))
        elif os.path.exists(os.path.join(tvqa_result_root, video_name, "result", "val_epochbest.json")):
            tvqa_list.append(os.path.join(tvqa_result_root, video_name, "result", "val_epochbest.json"))
        elif os.path.exists(os.path.join(tvqa_result_root, video_name, "result", "train_epochbest.json")):
            tvqa_list.append(os.path.join(tvqa_result_root, video_name, "result", "train_epochbest.json"))
            print("train file")
        else: 
            tvqa_list.append("")
assert len(tvqa_list) == len(tvqa_val_json)
al_json_test = {}
fileerr = 0
for idx, val_json in enumerate(tvqa_val_json):
    if tvqa_list[idx] == '':
        continue
    try:
        tvqa = load_jsonl(tvqa_list[idx])
        val = load_jsonl(val_json)
        tvqa_video = video_list[idx]
        
        
        for i, qa in enumerate(val[0]):
            dic = {}
            dic['ground_truth'] = [float(qa['start']), float(qa['end'])]
            dic['time_span_len'] = float(qa['end']) - float(qa['start'])
            dic['vid_name'] = qa['video']
            dic['q'] = qa['question']
            if  math.isnan(dic['time_span_len']):
                print(float(qa['end']) , float(qa['start']), qa)
            
            al_json_test[qa['qid']] = dic

        max_frame_num = 0

        for i, qa in enumerate(tvqa[0]):
            max_frame_num = max(qa['frame_idx'])
            if qa['qid'] not in al_json_test:
                print("QID doesnt exist", qa['qid'])
            else:
                dic = al_json_test[qa['qid']] 
                if dic['time_span_len'] == float(np.nan):
                    print("nan")
                try:
                    pred = qa['frame_idx'][:int(np.ceil(dic['time_span_len']))]
                    pred = sorted(pred)
                    # print(pred)
                    pred_time_span = find_consecutive_timestamps(pred)
                    #print(dic['ground_truth'],pred_time_span)
                    dic['pred'] = pred_time_span
                    iou = calculate_iou(dic['ground_truth'], pred_time_span)
                    dic['iou'] = iou
                except ValueError:
                    print(dic['time_span_len'])
                    print("ValueError ", qa['qid'], qa['frame_idx'])
                    dic['pred'] = [0, 0]
                    dic['iou'] = 0
                #uncomment this only when you want to check frame length for verification purpose
                #verify_frame_len(tvqa_video,  qa['frame_idx'])
                
                
                
    except FileNotFoundError as e: 
        print("file not found ", e)
        fileerr +=1

nan nan {'video': 'castle_s05e08_seg02_clip_03', 'num_option': 5, 'qid': 'TVQA_178', 'a0': 'Mr. Dolan worked at the movie theatre', 'a1': "Mr. Dolan worked at McDonald's.", 'a2': 'Mr Dolan worked at Office Max. ', 'a3': 'Mr Dolan worked for The Gap.', 'a4': "The O'Reilly mobster family as the enforcer. ", 'answer': 1, 'question': 'Who did Micheal Dolan work for according to Esposito when talking with Sister Mary?', 'start': 'NaN', 'end': 'NaN'}
nan
ValueError  TVQA_178 [53, 66, 73, 43, 74, 62, 60, 75, 65, 49, 46, 48, 72, 47, 87, 64, 61, 59, 85, 86, 33, 45, 63, 79, 32, 57, 67, 77, 68, 37, 76, 42, 25, 31, 78, 70, 52, 80, 44, 56, 38, 50, 6, 58, 8, 81, 29, 9, 69, 39, 7, 22, 71, 12, 0, 23, 30, 51, 21, 5, 54, 55, 10, 24, 83, 4, 11, 2, 36, 16, 1, 34, 84, 40, 28, 19, 26, 17, 35, 82, 20, 27, 41, 18, 14, 13, 3, 15]
nan
ValueError  TVQA_178 [53, 66, 73, 43, 74, 62, 60, 75, 65, 49, 46, 48, 72, 47, 87, 64, 61, 59, 85, 86, 33, 45, 63, 79, 32, 57, 67, 77, 68, 37, 76, 42, 25, 31, 78, 70, 52, 80, 44, 5

In [18]:
train_path = f'{eval_path}/train.json'
val_path = f'{eval_path}/val.json'
test_path = f'{eval_path}/test.json'

In [19]:
train = load_jsonl(train_path)
val = load_jsonl(val_path)
test = load_jsonl(test_path)

In [20]:
val_total_cnt = 0
test_total_cnt = 0
val_dict = {}
for i, qa in enumerate(val[0]):
    val_total_cnt +=1
    val_dict[str(qa['qid'])] = 1
print(idx)
train_dict = {}
for i, qa in enumerate(train[0]):
    train_dict[str(qa['qid'])] = 1
test_dict = {}
for i, qa in enumerate(test[0]):
    test_total_cnt += 1
    test_dict[str(qa['qid'])] = 1

2177


In [21]:

vid_json_folder = "/home/hlpark/REDUCE/REDUCE_benchmarks/HiREST/data/splits/tvqa"
clip_pred_med = load_jsonl(f'{vid_json_folder}/five_labeled_pred_med_from_gt_vid_dict.json')

In [22]:
select_option = "med_queries"
# select_option = "med_samples_clip_from_gt"
# select_option = "med_samples_clip_full_rand"
# select_option = "tv shows"

In [23]:
val_cnt, train_cnt, test_cnt = 0, 0, 0
cnt_3_thresh, cnt_5_thresh, cnt_7_thresh = 0, 0, 0
med_cnt_3_thresh, med_cnt_5_thresh, med_cnt_7_thresh = 0, 0, 0
nonmed_cnt_3_thresh, nonmed_cnt_5_thresh, nonmed_cnt_7_thresh = 0, 0, 0
med_cnt, nonmed_cnt = 0, 0
for key, val in al_json_val.items():
    if 'iou' not in val:
        continue
    #print(key)
    if key in val_dict:
        val_cnt += 1
        ismed = False
        if select_option == "tv shows":
            if "house" in val['vid_name'] or "grey" in val['vid_name']:
                ismed = True
        elif select_option == "med_samples_clip_full_rand":
            if clip_pred_med[0][val['vid_name']] == "med":
               ismed = True
        elif select_option == "med_samples_clip_from_gt":
            for qc in clip_pred_med[0][val['vid_name']]:
                #print(val)
                if val['q'] in qc and qc[val['q']] == "med":
                    ismed = True
        elif select_option == "med_queries":
            newkey = val['q'].lower().replace(" ", "").replace("?", "").replace("\n", "").replace("-", "").replace("'", "").replace("\"", "").replace(">", "").replace(",", "").replace(".", "").replace("/", "")
            if newkey  in val_med_query_f_metamap_list:
                ismed = True
        
        if ismed:
            med_cnt += 1
            if val['iou'] > 0.3:
                med_cnt_3_thresh += 1
            if val['iou'] > 0.5:
                med_cnt_5_thresh += 1
            if val['iou'] > 0.7:
                med_cnt_7_thresh += 1
        else:
            nonmed_cnt += 1
            if val['iou'] > 0.3:
                nonmed_cnt_3_thresh += 1
            if val['iou'] > 0.5:
                nonmed_cnt_5_thresh += 1
            if val['iou'] > 0.7:
                nonmed_cnt_7_thresh += 1
        if val['iou'] > 0.3:
            cnt_3_thresh += 1
        if val['iou'] > 0.5:
            cnt_5_thresh += 1
        if val['iou'] > 0.7:
            cnt_7_thresh += 1
        
if val_cnt > 0:
    # assert len(val_dict) == val_cnt
    print(f"VAL IoU=0.3: {cnt_3_thresh/(val_cnt) * 100}\nIoU=0.5: {cnt_5_thresh/(val_cnt) * 100}\nIoU=0.7: {cnt_7_thresh/(val_cnt) * 100}\ntotal queries:{(val_cnt)}/{val_total_cnt}")
    print(f"\nVAL Medical IoU=0.3: {med_cnt_3_thresh/(med_cnt) * 100}\nIoU=0.5: {med_cnt_5_thresh/(med_cnt) * 100}\nIoU=0.7: {med_cnt_7_thresh/(med_cnt) * 100}\ntotal queries:{(med_cnt)}/{val_total_cnt}")
    print(f"\nVAL Non-medical IoU=0.3: {nonmed_cnt_3_thresh/(nonmed_cnt) * 100}\nIoU=0.5: {nonmed_cnt_5_thresh/(nonmed_cnt) * 100}\nIoU=0.7: {nonmed_cnt_7_thresh/(nonmed_cnt) * 100}\ntotal queries:{(nonmed_cnt)}/{val_total_cnt}")

cnt_3_thresh, cnt_5_thresh, cnt_7_thresh = 0, 0, 0
med_cnt_3_thresh, med_cnt_5_thresh, med_cnt_7_thresh = 0, 0, 0
nonmed_cnt_3_thresh, nonmed_cnt_5_thresh, nonmed_cnt_7_thresh = 0, 0, 0
med_cnt, nonmed_cnt = 0, 0
for key, val in al_json_test.items():
    if 'iou' not in val:
        continue
    #print(key)
    if key in test_dict:
        test_cnt += 1
        ismed = False
        # if "house" in val['vid_name'] or "grey" in val['vid_name']:
        # for qc in clip_pred_med[0][val['vid_name']]:
        #         #print(val)
        #         if val['q'] in qc and qc[val['q']] == "med":
        #             ismed = True
        if select_option == "tv shows":
            if "house" in val['vid_name'] or "grey" in val['vid_name']:
                ismed = True
        elif select_option == "med_samples_clip_full_rand":
            if clip_pred_med[0][val['vid_name']] == "med":
               ismed = True
        elif select_option == "med_samples_clip_from_gt":
            for qc in clip_pred_med[0][val['vid_name']]:
                #print(val)
                if val['q'] in qc and qc[val['q']] == "med":
                    ismed = True
        elif select_option == "med_queries":
            #print(val['q'])
            newkey = val['q'].lower().replace(" ", "").replace("?", "").replace("\n", "").replace("-", "").replace("'", "").replace("\"", "").replace(">", "").replace(",", "").replace(".", "").replace("/", "")
            if newkey  in test_med_query_f_metamap_list:
                ismed = True
        if ismed:
            med_cnt += 1
            if val['iou'] > 0.3:
                med_cnt_3_thresh += 1
            if val['iou'] > 0.5:
                med_cnt_5_thresh += 1
            if val['iou'] > 0.7:
                med_cnt_7_thresh += 1
        else:
            nonmed_cnt += 1
            if val['iou'] > 0.3:
                nonmed_cnt_3_thresh += 1
            if val['iou'] > 0.5:
                nonmed_cnt_5_thresh += 1
            if val['iou'] > 0.7:
                nonmed_cnt_7_thresh += 1
        if val['iou'] > 0.3:
            cnt_3_thresh += 1
        if val['iou'] > 0.5:
            cnt_5_thresh += 1
        if val['iou'] > 0.7:
            cnt_7_thresh += 1
        
if test_cnt > 0:
    # assert len(val_dict) == val_cnt
    print(f"TEST IoU=0.3: {cnt_3_thresh/(test_cnt) * 100}\nIoU=0.5: {cnt_5_thresh/(test_cnt) * 100}\nIoU=0.7: {cnt_7_thresh/(test_cnt) * 100}\ntotal queries:{(test_cnt)}/{test_total_cnt}")
    print(f"\nTEST Medical IoU=0.3: {med_cnt_3_thresh/(med_cnt) * 100}\nIoU=0.5: {med_cnt_5_thresh/(med_cnt) * 100}\nIoU=0.7: {med_cnt_7_thresh/(med_cnt) * 100}\ntotal queries:{(med_cnt)}/{test_total_cnt}")
    print(f"\nTEST Non-medical IoU=0.3: {nonmed_cnt_3_thresh/(nonmed_cnt) * 100}\nIoU=0.5: {nonmed_cnt_5_thresh/(nonmed_cnt) * 100}\nIoU=0.7: {nonmed_cnt_7_thresh/(nonmed_cnt) * 100}\ntotal queries:{(nonmed_cnt)}/{test_total_cnt}")


VAL IoU=0.3: 18.1109243697479
IoU=0.5: 7.8924369747899155
IoU=0.7: 3.1731092436974793
total queries:14875/15253

VAL Medical IoU=0.3: 13.405797101449277
IoU=0.5: 4.891304347826087
IoU=0.7: 1.8115942028985508
total queries:552/15253

VAL Non-medical IoU=0.3: 18.29225720868533
IoU=0.5: 8.008098861970257
IoU=0.7: 3.2255812329819173
total queries:14323/15253
TEST IoU=0.3: 18.00483351235231
IoU=0.5: 8.270676691729323
IoU=0.7: 3.7593984962406015
total queries:7448/7623

TEST Medical IoU=0.3: 13.135593220338984
IoU=0.5: 4.661016949152542
IoU=0.7: 2.11864406779661
total queries:236/7623

TEST Non-medical IoU=0.3: 18.164170826400444
IoU=0.5: 8.388796450360509
IoU=0.7: 3.813089295618414
total queries:7212/7623
