In [31]:
import json
from pathlib import Path
import os

def read_image_captions(folder_path):
    """Load every ``*.jsonl`` caption file directly inside ``folder_path``.

    Args:
        folder_path: Directory (str or ``os.PathLike``) to scan. Subdirectories
            are not searched.

    Returns:
        dict[str, list]: Maps each file's name without its ``.jsonl`` extension
        (``Path.stem``) to the ``[image_id, value]`` pairs that
        ``read_jsonl_to_list`` returns for that file. Empty if no file matches.
    """
    # Path.glob yields files in arbitrary filesystem order; sorting makes the
    # dict's iteration order (and anything that iterates it later) reproducible.
    return {
        jsonl_file.stem: read_jsonl_to_list(jsonl_file)
        for jsonl_file in sorted(Path(folder_path).glob("*.jsonl"))
    }

def read_jsonl_to_list(file_path):
    """Read a JSON Lines file of objects into a flat list of ``[image_id, value]`` pairs.

    Every key/value pair of every JSON object becomes one ``[image_id, value]``
    entry, in file order. ``image_id`` is the key truncated at its first ``'.'``
    (e.g. ``"frame_001.jpg"`` -> ``"frame_001"``). Keys are presumably image
    file names and values captions -- confirm against the data.

    Blank lines are ignored. Lines that are not valid JSON, or whose JSON value
    is not an object, are reported with ``print`` and skipped.

    Args:
        file_path: Path (str or ``os.PathLike``) of the ``.jsonl`` file, read as UTF-8.

    Returns:
        list[list]: ``[image_id, value]`` pairs.
    """
    result = []
    with open(file_path, 'r', encoding='utf-8') as file:
        for raw_line in file:
            line = raw_line.strip()
            if not line:
                # Trailing/blank lines are common in JSONL; not worth a warning.
                continue
            try:
                json_obj = json.loads(line)
            except json.JSONDecodeError:
                print(f"Skipping invalid JSON line: {line}")
                continue
            if not isinstance(json_obj, dict):
                # Valid JSON but not an object (list, string, number...): no
                # key/value pairs to read, and .items() would raise AttributeError.
                print(f"Skipping non-object JSON line: {line}")
                continue
            for key, value in json_obj.items():
                # NOTE(review): split('.')[0] truncates at the FIRST dot, so
                # "clip.v2.jpg" -> "clip"; use os.path.splitext if ids contain dots.
                result.append([key.split('.')[0], value])
    return result


# Directory of *.jsonl caption files; every file becomes one key of image_captions.
CAPTIONS_DIR = Path("data/image-caption")

# Globbing a missing directory silently matches nothing, which would surface much
# later as a meaningless "None" search result -- fail here instead.
if not CAPTIONS_DIR.is_dir():
    raise FileNotFoundError(f"Caption directory not found: {CAPTIONS_DIR.resolve()}")

image_captions = read_image_captions(CAPTIONS_DIR)
assert image_captions, f"No *.jsonl caption files found in {CAPTIONS_DIR}"

In [None]:
from transformers import BertTokenizer, BertModel
import torch
import numpy as np

MODEL_NAME = "bert-base-uncased"

# Fall back to CPU so the notebook still runs on machines without a CUDA GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
# Explicit eval() disables dropout so embeddings are deterministic. Assigning the
# result (instead of ending the cell with model.to(...)) also keeps the very long
# model repr out of the cell output.
model = BertModel.from_pretrained(MODEL_NAME).to(device).eval()

In [None]:
# Query to search for -- fill this in before running. An empty string embeds only
# BERT's special tokens, which makes the "best match" meaningless.
text_input = ""
EMBED_BATCH_SIZE = 32  # texts per BERT forward pass


def embed_texts(texts, batch_size=EMBED_BATCH_SIZE):
    """Embed ``texts`` with the global BERT ``tokenizer``/``model`` on ``device``.

    Each text is mean-pooled over its non-padding tokens (weighted by the
    attention mask). Batching texts together therefore gives essentially the same
    vectors (up to floating-point noise) as encoding each one alone.

    Args:
        texts: List of strings.
        batch_size: Number of texts per forward pass.

    Returns:
        np.ndarray of shape ``(len(texts), hidden_size)``.
    """
    pooled_chunks = []
    with torch.no_grad():  # inference only: don't build an autograd graph per batch
        for start in range(0, len(texts), batch_size):
            batch = tokenizer(texts[start:start + batch_size], return_tensors="pt",
                              padding=True, truncation=True)
            batch = {k: v.to(device) for k, v in batch.items()}  # move input tensors to the model's device
            hidden = model(**batch).last_hidden_state  # (batch, seq, hidden)
            mask = batch["attention_mask"].unsqueeze(-1).to(hidden.dtype)
            pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
            pooled_chunks.append(pooled.cpu().numpy())
    if not pooled_chunks:
        return np.empty((0, model.config.hidden_size), dtype=np.float32)
    return np.concatenate(pooled_chunks, axis=0)


embeddings_sample = embed_texts([text_input])  # (1, hidden)

# Flatten to (video, entry) pairs. Each entry is an [image_id, caption] pair built by
# read_jsonl_to_list, and only the caption (entry[1]) is embedded. Passing the whole
# pair to the tokenizer would make it a BATCH of two sequences, scoring the query
# against the image id instead of the caption.
candidates = [
    (video, entry)
    for video, video_entries in image_captions.items()
    for entry in video_entries
]
# NOTE(review): assumes caption values are plain strings -- confirm the JSONL schema.
embeddings_candidates = embed_texts([entry[1] for _, entry in candidates])

# Cosine similarity of the query against every caption in one vectorized step.
query_vec = embeddings_sample[0]
norms = np.linalg.norm(embeddings_candidates, axis=1) * np.linalg.norm(query_vec)
similarities = (embeddings_candidates @ query_vec) / np.maximum(norms, 1e-12)

max_video, max_caption = None, None
max_score = float("-inf")  # not 0: cosine similarity can be negative
if len(similarities):
    best = int(np.argmax(similarities))  # first maximum, like a strict '>' scan
    max_video, max_caption = candidates[best]
    max_score = float(similarities[best])
print(max_video)
print(max_caption)
print(max_score)