In [None]:
import os
import numpy as np
import torch
import hydra
from open_clip.tokenizer import HFTokenizer
import av

from videoechoclip.model import create_model_and_transforms
from videoechoclip.utils import random_seed, pt_load

In [None]:
# Compose the Hydra config named "config" found under the "config" path.
# The initialization context only needs to be active while composing.
with hydra.initialize(version_base=None, config_path="config"):
    args = hydra.compose(config_name="config")

In [None]:
# Prefer the first CUDA GPU, but fall back to the CPU when CUDA is not available so
# the notebook still runs end-to-end on machines without a GPU.
args.device = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(args.device)

# `random_seed` is a project helper; presumably it seeds the RNGs used downstream
# with the seed taken from the config (rank 0 = single-process run) -- TODO confirm.
random_seed(args.seed, rank=0)

In [None]:
# Create model & preprocess functions.
# The helper returns three objects. `preprocess_val` is the one handed to
# `video_preprocess` in the inference cell; `preprocess_train` is not used by any
# of the cells visible in this notebook.
model, preprocess_train, preprocess_val = create_model_and_transforms(args)

In [None]:
# Load a checkpoint
checkpoint_path = "./weights/checkpoint.pt"
# Deserialize onto the CPU; the weights are copied into the model's own parameters below.
checkpoint = pt_load(checkpoint_path, map_location="cpu")

sd = checkpoint["state_dict"]
# Checkpoints saved from a wrapped model (typically DataParallel / DistributedDataParallel)
# carry a "module." prefix on every key. Strip it only when *all* keys have exactly that
# prefix, so an empty state dict or an unrelated key that merely starts with "module"
# is left untouched.
if sd and all(k.startswith("module.") for k in sd):
    sd = {k[len("module.") :]: v for k, v in sd.items()}

model.load_state_dict(sd)

# The preprocessing helpers move every input to `device`, so the model has to live on
# the same device. This is a no-op when the model is already there.
model.to(device)

model.eval()

In [None]:
# Create tokenizer: both the HF tokenizer name and the context length come from the
# text-model section of the config.
tokenizer = HFTokenizer(
    args.model.text.hf_tokenizer_name,
    context_length=args.model.text.context_length,
)

In [None]:
def video_preprocess(video_path, transforms, num_frames):
    """Decode a video, keep every second frame, and turn it into a batched model input.

    The first ``2 * num_frames`` frames of the first video stream are decoded and
    every other frame is kept (stride 2), yielding ``num_frames`` frames. If the
    video is shorter than that, the last decoded frame is repeated as padding.

    Args:
        video_path: Path to the video file; ``~`` is expanded and the path is made absolute.
        transforms: Callable invoked as ``transforms(list_of_frames, return_tensors="pt")``
            whose result is indexed with ``["pixel_values"][0]``.
        num_frames: Number of frames to keep after the stride-2 subsampling, or ``None``
            to decode the whole video (still subsampled with stride 2, no padding).

    Returns:
        The transformed frames with a leading batch dimension of size 1, moved to the
        notebook-level ``device``.

    Raises:
        ValueError: If ``num_frames`` is smaller than 1, or if no frame can be decoded.
    """
    if num_frames is not None and num_frames < 1:
        raise ValueError(f"num_frames must be a positive integer or None, got {num_frames!r}")

    # Twice as many frames are decoded as requested because every other one is dropped below.
    max_decoded = None if num_frames is None else num_frames * 2  # NOTE stride 2

    frames = []
    # The context manager closes the container even when decoding raises.
    with av.open(os.path.abspath(os.path.expanduser(video_path)), mode="r") as container:
        for frame in container.decode(video=0):
            frames.append(frame.to_rgb().to_ndarray())  # (H, W, 3)

            # Stop early: frames beyond the sampling window are never used.
            if max_decoded is not None and len(frames) >= max_decoded:
                break

    if not frames:
        raise ValueError(f"No video frames could be decoded from {video_path!r}")

    if max_decoded is not None and len(frames) < max_decoded:
        # if video is too short, pad last frame
        frames.extend([frames[-1]] * (max_decoded - len(frames)))

    rgb_frames = np.stack(frames, axis=0)[::2]  # (N, H, W, 3) # NOTE stride 2

    video_tensor = transforms(list(rgb_frames), return_tensors="pt")["pixel_values"][0]  # (N, 3, H', W') per the original note

    # unsqueeze(0) only prepends a batch axis, so given the shape noted above the result
    # is (1, N, 3, H', W') -- TODO confirm this is the layout model.encode_image expects.
    return video_tensor.unsqueeze(0).to(device)

In [None]:
def text_preprocess(text_path, tokenizer):
    """Read a report from a text file and tokenize it as a single-item batch.

    Args:
        text_path: Path to the text file; ``~`` is expanded and the path is made absolute.
        tokenizer: Callable that takes a list of strings and returns an object with a
            ``.to(device)`` method (the original note gives its shape as ``(1, L)``).

    Returns:
        The tokenized report, moved to the notebook-level ``device``.
    """
    # Read as UTF-8 explicitly so the result does not depend on the platform's default
    # locale encoding.
    with open(os.path.abspath(os.path.expanduser(text_path)), "r", encoding="utf-8") as f:
        report_text = f.read().strip()

    text_tensor = tokenizer([report_text])  # (1, L)

    return text_tensor.to(device)

In [None]:
# Number of frames fed to the vision tower, taken from the config once.
num_frames = args.model.vision.num_frames

with torch.inference_mode():
    # --- Preprocess both example (video, report) pairs -----------------------------
    image1 = video_preprocess("./example1.mp4", preprocess_val, num_frames=num_frames)  # batched video tensor
    text1 = text_preprocess("./example1.txt", tokenizer)  # (1, L)

    image2 = video_preprocess("./example2.mp4", preprocess_val, num_frames=num_frames)  # batched video tensor
    text2 = text_preprocess("./example2.txt", tokenizer)  # (1, L)

    # --- Encode each modality; normalize=True is requested from the model, so the ---
    # --- dot products below are presumably cosine similarities (TODO confirm).    ---
    # Per the original annotations each feature is (1, 512).
    image_feature1 = model.encode_image(image1, normalize=True)
    text_feature1 = model.encode_text(text1, normalize=True)

    image_feature2 = model.encode_image(image2, normalize=True)
    text_feature2 = model.encode_text(text2, normalize=True)

# Stack the per-example features along the batch axis: (2, 512) each.
image_features = torch.cat((image_feature1, image_feature2), dim=0)
text_features = torch.cat((text_feature1, text_feature2), dim=0)

# Entry [i, j] scores video i against report j; matching pairs sit on the diagonal.
similarity = torch.matmul(image_features, text_features.T)  # (2, 2)
print(
    "similarity score:\n",
    np.round(similarity.detach().cpu().numpy().squeeze(), decimals=2),
)