In [2]:
import torch
import clip
from PIL import Image
import os
import random
from tqdm import tqdm
import numpy as np
import pandas as pd
import json
import math

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def get_frame_number(filename):
    # This extracts numbers from a filename like 'frame_00001.jpg'
    return int(filename.replace('frame_', '').replace('.jpg', ''))

In [4]:
def save_json(content, save_path):
    with open(save_path, 'w') as f:
        f.write(json.dumps(content))
def load_jsonl(filename):
    with open(filename, "r") as f:
        return [json.loads(l.strip("\n")) for l in f.readlines()]

In [5]:
extract_from_full = False

In [6]:
random.seed(10)
if extract_from_full:
    file_dict = {'image_path':[], 'video':[] }
    
    for root, dirs, files in os.walk("/home/hlpark/REDUCE/REDUCE_benchmarks/HiREST/data/splits/tvqa/raw_frames"):
        for dir in dirs:
            for r, d, file in os.walk(os.path.join(root, dir)):
                file_dict['image_path'].append(os.path.join(r, file[random.randrange(len(file))])) 
                file_dict['video'].append(dir)
else:
    file_dict = {'image_path':[], 'video':[], 'query':[] }
    train_path = '/home/hlpark/shared/TVQA/tvqa_qa_release/tvqa_train.jsonl'
    vid_duration_json = "/home/hlpark/REDUCE/REDUCE_benchmarks/HiREST/data/splits/tvqa/video_duration.json"
    train = load_jsonl(train_path)
    vid_duration = load_jsonl(vid_duration_json)
    root = "/home/hlpark/REDUCE/REDUCE_benchmarks/HiREST/data/splits/tvqa/raw_frames"
    for idx, v in enumerate(train):
        # for vid_idx, vid in enumerate(file_dict['video']):
        #     if file_dict['video'][vid_idx] == v['vid_name'] and file_dict['query'][vid_idx] == v['q']:
        #         continue
        filenames = os.listdir(os.path.join(root, v['vid_name'] +  ".mp4"))
        sorted_filenames = sorted(filenames, key=get_frame_number)
        #print(sorted_filenames)
        if math.isnan(float(v['ts'].split('-')[0])) or math.isnan(float(v['ts'].split('-')[1])):
            k = random.randint(0, len(filenames) - 1)
            # k = len(filenames) // 2
        elif int(float(v['ts'].split('-')[1])) > vid_duration[0][v['vid_name'] + ".mp4"] or int(float(v['ts'].split('-')[0])) > vid_duration[0][v['vid_name'] + ".mp4"]:
            print(f"TIMESTAMP ERROR in original json {v['vid_name']} with video length {vid_duration[0][v['vid_name'] + '.mp4']}: {v['ts']}")
            k = random.randint(0, len(filenames) - 1)
        else:
            k = random.randint(int(float(v['ts'].split('-')[0])), int(float(v['ts'].split('-')[1])))
            print(k)
            #k = (int(float(v['ts'].split('-')[0])) + int(float(v['ts'].split('-')[1]))) // 2
        print(vid_duration[0][v['vid_name'] + ".mp4"],v['vid_name'], len(sorted_filenames), k, v['ts'])
        file_dict['image_path'].append(os.path.join(root, v['vid_name'] +  '.mp4', sorted_filenames[k - 1]))
        file_dict['video'].append(v['vid_name'])
        file_dict['query'].append(v['q'])

76
91.67 grey_s03e20_seg02_clip_14 92 76 76.01-84.2
58
62.0 met_s06e05_seg02_clip_09 62 58 45.05-61.29
22
59.0 friends_s03e11_seg02_clip_06 59 22 15.09-24.37
43
55.0 s04e10_seg01_clip_03 55 43 43.42-46.3
120
152.67000000000002 house_s01e03_seg02_clip_25 153 120 114.78-135.3
14
62.0 s09e08_seg02_clip_08 62 14 11.29-15.87
4
60.0 met_s03e18_seg02_clip_09 60 4 1.18-4.73
5
78.67 met_s06e22_seg01_clip_02 79 5 1.21-8.49
61
96.67 castle_s07e13_seg02_clip_07 97 61 56.18-72.02
0
62.0 friends_s05e21_seg02_clip_03 62 0 0-5.51
19
58.0 friends_s09e03_seg02_clip_08 58 19 11.98-21.67
23
47.0 friends_s01e14_seg02_clip_22 47 23 8.28-29.69
41
94.67 grey_s02e25_seg02_clip_19 95 41 39.96-42.78
42
91.67 grey_s01e06_seg02_clip_12 92 42 41.33-49.13
6
89.67 house_s04e05_seg02_clip_10 90 6 5.8-9.82
51
56.0 friends_s06e12_seg02_clip_16 56 51 46.54-55.73
5
60.0 met_s04e03_seg02_clip_10 60 5 5.95-11.01
27
60.0 friends_s08e23-24_seg02_clip_14 60 27 21.89-31.36
23
59.0 friends_s09e19_seg02_clip_19 59 23 22.34-25.53


In [7]:
print(len(file_dict['video']))

122039


In [8]:
data = pd.DataFrame.from_dict(file_dict)

In [20]:
if not os.path.exists("/home/hlpark/REDUCE/REDUCE_benchmarks/HiREST/data/evaluation/analysis/TVQA"):
    os.makedirs("/home/hlpark/REDUCE/REDUCE_benchmarks/HiREST/data/evaluation/analysis/TVQA")

In [21]:
## Caption embeddings
# texts = ['hospital', 'office', 'home', 'school', 'outside', 'not hospital']        # FIXME fill in classes here (e.g., hospital, not hospital)
# name = "six_labels"
texts = ['hospital', 'office', 'home', 'school', 'outside']        # FIXME fill in classes here (e.g., hospital, not hospital)
name = "five_labels_train"
# texts = ['hospital', 'not hospital']        # FIXME fill in classes here (e.g., hospital, not hospital)
# name = "two_labels"

if not os.path.exists(f"./TVQA/txt_probs_{name}.pt"):
    model, preprocess = clip.load("ViT-B/32")
    model.cuda().eval()
    input_resolution = model.visual.input_resolution
    context_length = model.context_length
    vocab_size = model.vocab_size
    prob_dict = {}
    
    captions = ["This is a " + desc for desc in texts]
    print(captions)
    text_tokens = clip.tokenize(captions).cuda()
    with torch.no_grad():
        text_features = model.encode_text(text_tokens).float()
        text_features /= text_features.norm(dim=-1, keepdim=True)

    batch_size = 1000
    images, img_feats, txt_probs = [], [], []
    for i, x in tqdm(data.iterrows(), total=data.shape[0]):     # FIXME make this iterate over your data

        path = x['image_path']        # FIXME replace with the frame from the appropriate video
        image = Image.open(path).convert("RGB")
        images.append(preprocess(image))
        
        if ((i + 1) % batch_size == 0) or (i + 1 == data.shape[0]):
            ## Image embeddings
            image_input = torch.tensor(np.stack(images)).cuda()
            with torch.no_grad():
                image_features = model.encode_image(image_input).float()
                image_features /= image_features.norm(dim=-1, keepdim=True)
                img_feats.append(image_features.cpu())

            ## Caption probabilities
            text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
            txt_probs.append(text_probs.cpu())
            images = []

    img_feats = torch.cat(img_feats, dim=0)
    txt_probs = torch.cat(txt_probs, dim=0)
    torch.save(img_feats, f'./TVQA/img_feats_{name}.pt')
    torch.save(txt_probs, f'./TVQA/txt_probs_{name}.pt')

['This is a hospital', 'This is a office', 'This is a home', 'This is a school', 'This is a outside']


100%|██████████| 122039/122039 [15:33<00:00, 130.70it/s]


In [9]:
# from IPython.display import Image, display
from PIL import Image
label  = "five_labels_train"
probs = torch.load(f"./TVQA/txt_probs_{label}.pt")
cnt = 0
med_vid_list, non_med_vid_list = [], []
label_dict = {}
print(probs)
for p in probs[0]:
    if np.argmax(p) == 0 and p > 0.8:
        cnt += 1

print(cnt)

root = "/home/hlpark/REDUCE/REDUCE_benchmarks/HiREST/data/evaluation/analysis/TVQA"
if not os.path.exists(os.path.join(root, f"{label}_classification_images/")):
    os.makedirs(os.path.join(root, f"{label}_classification_images"))
    os.makedirs(os.path.join(root, f"{label}_classification_images/{label}_mm"))
    os.makedirs(os.path.join(root, f"{label}_classification_images/{label}_nmm"))
    os.makedirs(os.path.join(root, f"{label}_classification_images/{label}_mnm"))
    os.makedirs(os.path.join(root, f"{label}_classification_images/{label}_nmnm"))
max_idx = np.argmax(probs, 1)
max_idx = max_idx.tolist()
print(len(max_idx))
print(max_idx.count(0))
medical_tv_class_med, medical_tv_class_nonmed = 0, 0
nonmedical_tv_class_med, nonmedical_tv_class_nonmed = 0, 0
for i,x in enumerate(max_idx):
    img = Image.open(data.iloc[i]['image_path'])
    #print(data.iloc[i])
    if data.iloc[i]['video'] not in label_dict:
        label_dict[data.iloc[i]['video']] = []
    if x == 0:
        label_dict[data.iloc[i]['video']].append({data.iloc[i]['query'] : "med"})
        if data.iloc[i]['video'] not in med_vid_list:
            med_vid_list.append(data.iloc[i]['video'])
        if "house" in data.iloc[i]['video'] or "grey" in data.iloc[i]['video']:
            medical_tv_class_med += 1
            #img.save(f"{root}/{label}_classification_images/{label}_mm/" + data.iloc[i]['video'] + ".jpg", 'JPEG')
        else:
            nonmedical_tv_class_med += 1
            #img.save(f"{root}/{label}_classification_images/{label}_nmm/" + data.iloc[i]['video'] + ".jpg", 'JPEG')

    else:
        label_dict[data.iloc[i]['video']].append({data.iloc[i]['query'] : "nonmed"})
        if data.iloc[i]['video'] not in non_med_vid_list:
            non_med_vid_list.append(data.iloc[i]['video'])
        if "house" in data.iloc[i]['video'] or "grey" in data.iloc[i]['video']:
            medical_tv_class_nonmed += 1
            #img.save(f"{root}/{label}_classification_images/{label}_mnm/" + data.iloc[i]['video'] + ".jpg", 'JPEG')

        else:
            nonmedical_tv_class_nonmed += 1
            #img.save(f"{root}/{label}_classification_images/{label}_nmnm/" + data.iloc[i]['video'] + ".jpg", 'JPEG')


print(medical_tv_class_med, medical_tv_class_nonmed, nonmedical_tv_class_med, nonmedical_tv_class_nonmed)
print("medical video numbers: ", len(med_vid_list), "non-medical video numbers: ", len(non_med_vid_list))
#save_json(label_dict, "/home/hlpark/REDUCE/REDUCE_benchmarks/HiREST/data/splits/medvidqa/five_labeled_pred_med_from_gt_vid_dict.json")

tensor([[9.8451e-01, 5.6557e-03, 7.0822e-03, 1.8338e-03, 9.1915e-04],
        [9.3945e-02, 6.8580e-02, 2.1284e-01, 1.1338e-02, 6.1330e-01],
        [1.0884e-02, 3.6090e-01, 4.8646e-01, 6.8264e-02, 7.3499e-02],
        ...,
        [2.9553e-02, 1.4153e-01, 1.9083e-02, 7.5248e-01, 5.7357e-02],
        [1.5525e-01, 7.3239e-01, 9.3422e-02, 6.3217e-03, 1.2626e-02],
        [1.2508e-02, 1.9955e-01, 6.2415e-01, 9.4049e-03, 1.5439e-01]])
1
122039
23097
18798 15206 4299 83736
medical video numbers:  6758 non-medical video numbers:  16602


In [10]:
print(len(label_dict))

17435


In [23]:
save_json(label_dict, "/home/hlpark/REDUCE/REDUCE_benchmarks/HiREST/data/splits/tvqa/five_labeled_pred_med_train_from_gt_vid_dict.json")

In [31]:
# from IPython.display import Image, display
from PIL import Image
probs = torch.load("./txt_probs_two_labels.pt")
# cnt = 0
# for i, x in data.iterrows():
#     if probs[i][0] > 0.80:
#         #print(x['image_path'], probs[i])
#         cnt += 1
#         img = Image.open(x['image_path'])
#         img.save("image_08_two_labels/" + x['video'] + ".jpg", 'JPEG')
#         # if "house" not in x['image_path'] and "grey" not in x['image_path']:
#         #     display(Image(filename=x['image_path']))
#         # display(Image(filename=x['image_path']))
    

# print(cnt)
max_idx = np.argmax(probs, 1)
max_idx = max_idx.tolist()
print(len(max_idx))
print(max_idx.count(0))
medical_tv_class_med, medical_tv_class_nonmed = 0, 0
nonmedical_tv_class_med, nonmedical_tv_class_nonmed = 0, 0
for i,x in enumerate(max_idx):
    img = Image.open(data.iloc[i]['image_path'])
    if x == 0:
        if "house" in data.iloc[i]['video'] or "grey" in data.iloc[i]['video']:
            medical_tv_class_med += 1
            img.save("two_label_classification_images/two_label_mm/" + data.iloc[i]['video'] + ".jpg", 'JPEG')
        else:
            nonmedical_tv_class_med += 1
            img.save("two_label_classification_images/two_label_nmm/" + data.iloc[i]['video'] + ".jpg", 'JPEG')

    else:
        if "house" in data.iloc[i]['video'] or "grey" in data.iloc[i]['video']:
            medical_tv_class_nonmed += 1
            img.save("two_label_classification_images/two_label_mnm/" + data.iloc[i]['video'] + ".jpg", 'JPEG')
        else:
            nonmedical_tv_class_nonmed += 1
            img.save("two_label_classification_images/two_label_nmnm/" + data.iloc[i]['video'] + ".jpg", 'JPEG')

print(medical_tv_class_med, medical_tv_class_nonmed, nonmedical_tv_class_med, nonmedical_tv_class_nonmed)


3268
1564
565 287 999 1417


In [52]:
# from IPython.display import Image, display
from PIL import Image
probs = torch.load("./txt_probs_five_labels.pt")
five_label_dict = {}
cnt = 0
for i, x in data.iterrows():
    # if np.argmax(probs[i]) == 0 and "house" in x['image_path']:
    #     print(x['image_path'], probs[i])
    if probs[i][0] > 0.50:
        #print(x['image_path'], probs[i])
        cnt += 1
        img = Image.open(x['image_path'])

        img.save("image_05_five_labels/" + x['video'] + ".jpg", 'JPEG')
        # if "house" not in x['image_path'] and "grey" not in x['image_path']:
        #     display(Image(filename=x['image_path']))
        # display(Image(filename=x['image_path']))
print(cnt)
max_idx = np.argmax(probs, 1)
max_idx = max_idx.tolist()
print(len(max_idx))
print(max_idx.count(0))
medical_tv_class_med, medical_tv_class_nonmed = 0, 0
nonmedical_tv_class_med, nonmedical_tv_class_nonmed = 0, 0
for i,x in enumerate(max_idx):
    img = Image.open(data.iloc[i]['image_path'])
    if x == 0:
        five_label_dict[data.iloc[i]['video']] = "med"
        if "house" in data.iloc[i]['video'] or "grey" in data.iloc[i]['video']:
            medical_tv_class_med += 1
            img.save("five_label_classification_images/five_label_mm/" + data.iloc[i]['video'] + ".jpg", 'JPEG')
        else:
            nonmedical_tv_class_med += 1
            img.save("five_label_classification_images/five_label_nmm/" + data.iloc[i]['video'] + ".jpg", 'JPEG')

    else:
        five_label_dict[data.iloc[i]['video']] = "nonmed"
        if "house" in data.iloc[i]['video'] or "grey" in data.iloc[i]['video']:
            medical_tv_class_nonmed += 1
            img.save("five_label_classification_images/five_label_mnm/" + data.iloc[i]['video'] + ".jpg", 'JPEG')

        else:
            nonmedical_tv_class_nonmed += 1
            img.save("five_label_classification_images/five_label_nmnm/" + data.iloc[i]['video'] + ".jpg", 'JPEG')


print(medical_tv_class_med, medical_tv_class_nonmed, nonmedical_tv_class_med, nonmedical_tv_class_nonmed)
save_json(five_label_dict, "/home/hlpark/REDUCE/REDUCE_benchmarks/HiREST/data/splits/tvqa/five_labeled_pred_med_vid_dict.json")

459
3268
590
473 379 117 2299


In [11]:
# from IPython.display import Image, display
from PIL import Image
label  = "no_batch_five_labels"
#probs = torch.load(f"./txt_probs_{label}.pt")
cnt = 0
label_dict = {}
for i, x in data.iterrows():
    if "grey_s01e01_seg02_clip_09" in x['image_path'] or "house_s05e16_seg02_clip_02" in x['image_path']:
        print(txt_probs[i])

#     # if np.argmax(probs[i]) == 0 and "house" in x['image_path']:
#     #     print(x['image_path'], probs[i])
#     if probs[i][0] > 0.80:
#         #print(x['image_path'], probs[i])
#         cnt += 1
#         img = Image.open(x['image_path'])

#         img.save("image_08_five_labels/" + x['video'] + ".jpg", 'JPEG')
#         # if "house" not in x['image_path'] and "grey" not in x['image_path']:
#         #     display(Image(filename=x['image_path']))
#         # display(Image(filename=x['image_path']))
# print(cnt)
root = "/home/hlpark/REDUCE/REDUCE_benchmarks/HiREST/data/evaluation/analysis"
if not os.path.exists(os.path.join(root, f"{label}_classification_images/")):
    os.makedirs(os.path.join(root, f"{label}_classification_images"))
    os.makedirs(os.path.join(root, f"{label}_classification_images/{label}_mm"))
    os.makedirs(os.path.join(root, f"{label}_classification_images/{label}_nmm"))
    os.makedirs(os.path.join(root, f"{label}_classification_images/{label}_mnm"))
    os.makedirs(os.path.join(root, f"{label}_classification_images/{label}_nmnm"))
max_idx = np.argmax(probs, 1)
max_idx = max_idx.tolist()
print(len(max_idx))
print(max_idx.count(0))
medical_tv_class_med, medical_tv_class_nonmed = 0, 0
nonmedical_tv_class_med, nonmedical_tv_class_nonmed = 0, 0
for i,x in enumerate(max_idx):
    img = Image.open(data.iloc[i]['image_path'])
    if data.iloc[i]['video'] not in label_dict:
        label_dict[data.iloc[i]['video']] = []
    if x == 0:
        label_dict[data.iloc[i]['video']].append({data.iloc[i]['query'] : "med"})
        if "house" in data.iloc[i]['video'] or "grey" in data.iloc[i]['video']:
            medical_tv_class_med += 1
            img.save(f"{label}_classification_images/{label}_mm/" + data.iloc[i]['video'] + ".jpg", 'JPEG')
        else:
            nonmedical_tv_class_med += 1
            img.save(f"{label}_classification_images/{label}_nmm/" + data.iloc[i]['video'] + ".jpg", 'JPEG')

    else:
        label_dict[data.iloc[i]['video']].append({data.iloc[i]['query'] : "nonmed"})
        if "house" in data.iloc[i]['video'] or "grey" in data.iloc[i]['video']:
            medical_tv_class_nonmed += 1
            img.save(f"{label}_classification_images/{label}_mnm/" + data.iloc[i]['video'] + ".jpg", 'JPEG')

        else:
            nonmedical_tv_class_nonmed += 1
            img.save(f"{label}_classification_images/{label}_nmnm/" + data.iloc[i]['video'] + ".jpg", 'JPEG')


print(medical_tv_class_med, medical_tv_class_nonmed, nonmedical_tv_class_med, nonmedical_tv_class_nonmed)
save_json(label_dict, "/home/hlpark/REDUCE/REDUCE_benchmarks/HiREST/data/splits/tvqa/five_labeled_pred_med_from_gt_vid_dict.json")

tensor([0.9595, 0.0175, 0.0130, 0.0028, 0.0071], grad_fn=<SelectBackward0>)
tensor([0.4154, 0.5116, 0.0274, 0.0141, 0.0314], grad_fn=<SelectBackward0>)
