In [39]:
import torch
import clip
from PIL import Image
import os
import random
from tqdm import tqdm
import numpy as np
import pandas as pd
import json
import math

In [40]:
def get_frame_number(filename):
    # This extracts numbers from a filename like 'frame_00001.jpg'
    return int(filename.replace('frame_', '').replace('.jpg', ''))

In [41]:
def save_json(content, save_path):
    with open(save_path, 'w') as f:
        f.write(json.dumps(content))
def load_jsonl(filename):
    with open(filename, "r") as f:
        return [json.loads(l.strip("\n")) for l in f.readlines()]

In [25]:
extract_from_full = True

In [34]:
random.seed(10)
if extract_from_full:
    file_dict = {'image_path':[], 'video':[] }
    
    for root, dirs, files in os.walk("/home/hlpark/REDUCE/REDUCE_benchmarks/HiREST/data/medvidqa/raw_frames"):
        for dir in dirs:
            for r, d, file in os.walk(os.path.join(root, dir)):
                file_dict['image_path'].append(os.path.join(r, file[random.randrange(len(file))])) 
                file_dict['video'].append(dir)
else:
    unmatching_frames_cnt, unavailable_cnt = 0, 0
    list_of_bad_videos = []
    vid_list = []
    file_dict = {'image_path':[], 'video':[], 'query':[] }
    train_path = '/home/hlpark/shared/MedVidQA/train.json'
    val_path = '/home/hlpark/shared/MedVidQA/val.json'
    test_path = '/home/hlpark/shared/MedVidQA/test.json'
    val = json.load(open(val_path))
    test = json.load(open(test_path))
    train = json.load(open(train_path))
    root = "/home/hlpark/REDUCE/REDUCE_benchmarks/HiREST/data/medvidqa/raw_frames"
    for idx, v in enumerate(train):
        #some videos were deleted from youtube, missing features
        if not os.path.exists(os.path.join(root, v['video_id'] +  ".mp4")):
            unavailable_cnt += 1
            continue
        filenames = os.listdir(os.path.join(root, v['video_id'] +  ".mp4"))
        sorted_filenames = sorted(filenames, key=get_frame_number)
        #print(sorted_filenames)
        if math.isnan(v['answer_start_second']) or math.isnan(v['answer_end_second']):
            print("nan detected")
            k = random.randint(0, len(filenames) - 1)
        elif int(v['answer_end_second']) > v['video_length'] or int(v['answer_start_second']) > v['video_length']:
            print(f"TIMESTAMP ERROR in original json {v['video_id']} with video length {v['video_length']}: {v['answer_start_second'], v['answer_end_second']}")
            k = random.randint(0, len(filenames) - 1)
        else:
            k = random.randint(int(v['answer_start_second']), int(v['answer_end_second']))
        if abs(v['video_length'] - len(filenames)) > 2:
            unmatching_frames_cnt += 1
            if v['video_id'] not in list_of_bad_videos:
                list_of_bad_videos.append(v['video_id'])
            print(v['video_id'], k, v['video_length'], len(filenames))
        else:
            if v['video_id'] not in vid_list:
                vid_list.append(v['video_id'])
            file_dict['image_path'].append(os.path.join(root, v['video_id'] +  '.mp4', sorted_filenames[k]))
            file_dict['video'].append(v['video_id'])
            file_dict['query'].append(v['question'])
    print(unmatching_frames_cnt, unavailable_cnt, len(train), len(test), len(val), len(file_dict['image_path']))
    print(list_of_bad_videos)
    print(f"total number of videos used : {len(vid_list)}")

SztsZNp-jDM 24 97 322325
OKXoHwkx55c 78 98 90
V9j5JkWGwI8 66 306 303
V9j5JkWGwI8 205 306 303
ehl2MPczYoQ 58 92 285458
ehl2MPczYoQ 76 92 285458
mMNloo140pU 267 477 474
QwhD5UTUW60 125 364 430963
QwhD5UTUW60 193 364 430963
QwhD5UTUW60 298 364 430963
Ehan_VI7p4c 258 723 696
Ehan_VI7p4c 346 723 696
Ehan_VI7p4c 354 723 696
Ehan_VI7p4c 411 723 696
Ehan_VI7p4c 465 723 696
Ehan_VI7p4c 511 723 696
Ehan_VI7p4c 587 723 696
Ehan_VI7p4c 595 723 696
Ehan_VI7p4c 658 723 696
19 66 2710 155 145 2625
['SztsZNp-jDM', 'OKXoHwkx55c', 'V9j5JkWGwI8', 'ehl2MPczYoQ', 'mMNloo140pU', 'QwhD5UTUW60', 'Ehan_VI7p4c']
total number of videos used : 762


In [35]:
data = pd.DataFrame.from_dict(file_dict)

In [36]:
if not os.path.exists("/home/hlpark/REDUCE/REDUCE_benchmarks/HiREST/data/evaluation/analysis/MedVidQA"):
    os.makedirs("/home/hlpark/REDUCE/REDUCE_benchmarks/HiREST/data/evaluation/analysis/MedVidQA")

In [37]:
## Caption embeddings
# texts = ['hospital', 'office', 'home', 'school', 'outside', 'not hospital']        # FIXME fill in classes here (e.g., hospital, not hospital)
# name = "six_labels"
texts = ['hospital', 'studio', 'home', 'school', 'outside']        # FIXME fill in classes here (e.g., hospital, not hospital)
name = "five_labels_train"
# texts = ['hospital', 'not hospital']        # FIXME fill in classes here (e.g., hospital, not hospital)
# name = "two_labels"

if not os.path.exists(f"./MedVidQA/txt_probs_{name}.pt"):

    model, preprocess = clip.load("ViT-B/32")
    model.cuda().eval()
    input_resolution = model.visual.input_resolution
    context_length = model.context_length
    vocab_size = model.vocab_size
    prob_dict = {}

    captions = ["This is a " + desc for desc in texts]
    print(captions)
    text_tokens = clip.tokenize(captions).cuda()
    with torch.no_grad():
        text_features = model.encode_text(text_tokens).float()
        text_features /= text_features.norm(dim=-1, keepdim=True)

    batch_size = 500
    images, img_feats, txt_probs = [], [], []
    for i, x in tqdm(data.iterrows(), total=data.shape[0]):     # FIXME make this iterate over your data

        path = x['image_path']        # FIXME replace with the frame from the appropriate video
        image = Image.open(path).convert("RGB")
        images.append(preprocess(image))
        
        if ((i + 1) % batch_size == 0) or (i + 1 == data.shape[0]):
            ## Image embeddings
            image_input = torch.tensor(np.stack(images)).cuda()
            with torch.no_grad():
                image_features = model.encode_image(image_input).float()
                image_features /= image_features.norm(dim=-1, keepdim=True)
                img_feats.append(image_features.cpu())

            ## Caption probabilities
            text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
            txt_probs.append(text_probs.cpu())
            images = []

    img_feats = torch.cat(img_feats, dim=0)
    txt_probs = torch.cat(txt_probs, dim=0)
    torch.save(img_feats, f'MedVidQA/img_feats_{name}.pt')
    torch.save(txt_probs, f'MedVidQA/txt_probs_{name}.pt')

['This is a hospital', 'This is a studio', 'This is a home', 'This is a school', 'This is a outside']


100%|██████████| 2625/2625 [00:34<00:00, 75.29it/s] 


In [38]:
# from IPython.display import Image, display
from PIL import Image
label  = "five_labels_train"
probs = torch.load(f"./MedVidQA/txt_probs_{label}.pt")
cnt = 0
med_vid_list, non_med_vid_list = [], []
label_dict = {}
print(probs)
for p in probs[0]:
    if np.argmax(p) == 0 and p > 0.8:
        cnt += 1

print(cnt)

root = "/home/hlpark/REDUCE/REDUCE_benchmarks/HiREST/data/evaluation/analysis/MedVidQA"
if not os.path.exists(os.path.join(root, f"{label}_classification_images/")):
    os.makedirs(os.path.join(root, f"{label}_classification_images"))
    os.makedirs(os.path.join(root, f"{label}_classification_images/{label}_mm"))
    os.makedirs(os.path.join(root, f"{label}_classification_images/{label}_nmm"))
    os.makedirs(os.path.join(root, f"{label}_classification_images/{label}_mnm"))
    os.makedirs(os.path.join(root, f"{label}_classification_images/{label}_nmnm"))
max_idx = np.argmax(probs, 1)
max_idx = max_idx.tolist()
print(len(max_idx))
print(max_idx.count(0))
medical_tv_class_med, medical_tv_class_nonmed = 0, 0
nonmedical_tv_class_med, nonmedical_tv_class_nonmed = 0, 0
for i,x in enumerate(max_idx):
    img = Image.open(data.iloc[i]['image_path'])
    if data.iloc[i]['video'] not in label_dict:
        label_dict[data.iloc[i]['video']] = []
    if x == 0:
        label_dict[data.iloc[i]['video']].append({data.iloc[i]['query'] : "med"})
        if data.iloc[i]['video'] not in med_vid_list:
            med_vid_list.append(data.iloc[i]['video'])
        if "house" in data.iloc[i]['video'] or "grey" in data.iloc[i]['video']:
            medical_tv_class_med += 1
            img.save(f"{root}/{label}_classification_images/{label}_mm/" + data.iloc[i]['video'] + ".jpg", 'JPEG')
        else:
            nonmedical_tv_class_med += 1
            img.save(f"{root}/{label}_classification_images/{label}_nmm/" + data.iloc[i]['video'] + ".jpg", 'JPEG')

    else:
        label_dict[data.iloc[i]['video']].append({data.iloc[i]['query'] : "nonmed"})
        if data.iloc[i]['video'] not in non_med_vid_list:
            non_med_vid_list.append(data.iloc[i]['video'])
        if "house" in data.iloc[i]['video'] or "grey" in data.iloc[i]['video']:
            medical_tv_class_nonmed += 1
            img.save(f"{root}/{label}_classification_images/{label}_mnm/" + data.iloc[i]['video'] + ".jpg", 'JPEG')

        else:
            nonmedical_tv_class_nonmed += 1
            img.save(f"{root}/{label}_classification_images/{label}_nmnm/" + data.iloc[i]['video'] + ".jpg", 'JPEG')


print(medical_tv_class_med, medical_tv_class_nonmed, nonmedical_tv_class_med, nonmedical_tv_class_nonmed)
print("medical video numbers: ", len(med_vid_list), "non-medical video numbers: ", len(non_med_vid_list))
save_json(label_dict, "/home/hlpark/REDUCE/REDUCE_benchmarks/HiREST/data/splits/medvidqa/five_labeled_pred_med_from_gt_vid_dict.json")

tensor([[0.8294, 0.0583, 0.0635, 0.0153, 0.0336],
        [0.2339, 0.4441, 0.0530, 0.0247, 0.2442],
        [0.2683, 0.5058, 0.0616, 0.0579, 0.1064],
        ...,
        [0.3533, 0.3429, 0.1499, 0.0216, 0.1322],
        [0.7830, 0.1075, 0.0420, 0.0262, 0.0413],
        [0.6810, 0.1176, 0.0598, 0.0695, 0.0721]])
1
2625
1198
0 0 1198 1427
medical video numbers:  526 non-medical video numbers:  449
