In [None]:
import json
import sys
import os
import torch
import numpy as np
from tqdm import tqdm
from PIL import Image
from torchvision import transforms
from transformers import XLMRobertaTokenizer
from torchvision.transforms.functional import InterpolationMode
from extra.unilm.beit3.utils import load_model_and_may_interpolate
from extra.unilm.beit3.modeling_finetune import beit3_large_patch16_384_retrieval

In [None]:
# Switch the working directory to the dataset folder so that the relative
# './distilled_keyframe' path used by the parsing cell below resolves against it.
%cd D:/AIC2024/dataset

#### **Parsing Data Path**

In [None]:
def parse_data_path(feature_dir='./keyframe'):
    """Index the two-level directory layout found under ``feature_dir``.

    Args:
        feature_dir: Root directory. Each immediate sub-directory is treated
            as one "part"; every entry inside a part becomes one record.

    Returns:
        dict: ``{part_name: {entry_id: entry_path}}`` where ``entry_id`` is the
        entry name up to its first ``.`` and ``entry_path`` is
        ``f'{feature_dir}/{part_name}/{entry_name}'``. Parts and entries are
        inserted in sorted name order.

    Raises:
        FileNotFoundError: If ``feature_dir`` does not exist.
    """
    all_feature_paths = dict()
    for feature_part in sorted(os.listdir(feature_dir)):
        # Paths are joined with '/' (not os.path.join) because sorted_by_id
        # later splits keyframe paths on '/'.
        feature_part_path = f'{feature_dir}/{feature_part}'
        # A stray file sitting next to the part directories cannot be listed
        # and would abort the whole scan with NotADirectoryError; skip it.
        if not os.path.isdir(feature_part_path):
            continue
        part_paths = dict()
        for feature_path in sorted(os.listdir(feature_part_path)):
            feature_id = feature_path.split('.')[0]
            part_paths[feature_id] = f'{feature_part_path}/{feature_path}'
        all_feature_paths[feature_part] = part_paths
    return all_feature_paths

In [None]:
# Build {part: {video_id: path}} for everything under ./distilled_keyframe.
# NOTE(review): the stored paths are relative to the working directory at this
# point (set by the `%cd D:/AIC2024/dataset` cell above), but a later cell runs
# `%cd D:/AIC2024` before these paths are listed and opened by the encoding
# loop — confirm that they still resolve from there.
all_video_paths = parse_data_path(feature_dir='./distilled_keyframe')

#### **BEiT3**

In [None]:
# Switch the working directory to D:/AIC2024; the relative weight path and the
# relative output directories used by the cells below are resolved from here.
%cd D:/AIC2024

In [None]:
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# NOTE(review): the checkpoint file name says patch16_224 while the architecture
# built here is the patch16_384 retrieval variant — presumably
# load_model_and_may_interpolate bridges the two resolutions; confirm.
model_weight_path = './dict/beit/weights/beit3_large_itc_patch16_224.pth'
beit3_model = beit3_large_patch16_384_retrieval(pretrained=False)
load_model_and_may_interpolate(
    model_weight_path,
    beit3_model,
    model_key='model',
    model_prefix='',
)

# Move the weights to the selected device and switch to evaluation mode. As the
# last expression of the cell this also displays the module.
beit3_model.to(device).eval()

In [None]:
def encode_images(image_paths, batch_size, image_size=384):
    """Encode images with the module-level ``beit3_model`` into unit-norm features.

    Images are opened, resized and moved to ``device`` one batch at a time, so
    at most ``batch_size`` images are held on the device at once regardless of
    how many paths are passed in.

    Args:
        image_paths: Iterable of image file paths. The position of a path in
            the iterable is used as its integer id.
        batch_size: Number of images per forward pass.
        image_size: Side length of the square that every image is resized to
            (bicubic interpolation) before encoding.

    Returns:
        tuple: ``(id2image_fps, video_features)`` where ``id2image_fps`` maps
        each integer id to its image path and ``video_features`` is a list of
        flattened ``float32`` numpy vectors, one per image in the same order,
        each divided by its norm along the last dimension.
    """
    image_paths = list(image_paths)
    id2image_fps = dict(enumerate(image_paths))
    video_features = []
    if not image_paths:
        # Nothing to encode: stacking an empty list of tensors would raise.
        return id2image_fps, video_features

    # The preprocessing pipeline does not depend on the image, so build it once.
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size),
                          interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
    ])

    with torch.no_grad():
        for start_index in range(0, len(image_paths), batch_size):
            batch_images = []
            for image_path in image_paths[start_index:start_index + batch_size]:
                # The context manager releases the file handle as soon as the
                # pixels have been converted and transformed.
                with Image.open(image_path) as raw_image:
                    batch_images.append(transform(raw_image.convert('RGB')))
            batch = torch.stack(batch_images, dim=0).to(device)

            image_features, _ = beit3_model(image=batch, only_infer=True)
            image_features /= image_features.norm(dim=-1, keepdim=True)

            # One device-to-host copy per batch, then one flat vector per image.
            batch_features = image_features.cpu().numpy().astype(np.float32)
            video_features.extend(row.flatten() for row in batch_features)
    return id2image_fps, video_features

#### **Inference**

In [None]:
def sorted_by_id(keyframe_paths):
    """Return the keyframe paths ordered by the integer id in their file names.

    The file name is the text after the last '/' of a path, and the id is the
    part of that name before its first '.', parsed with ``int`` (so a
    non-numeric name raises ``ValueError``). Paths with equal ids keep their
    input order.
    """
    def numeric_id(path):
        filename = path.rsplit('/', 1)[-1]
        return int(filename.partition('.')[0])

    return sorted(keyframe_paths, key=numeric_id)

In [None]:
# Root output directories (relative to the current working directory): one for
# the id -> image-path JSON files and one for the feature arrays.
id2image_save_dir = './beit/large/id2image'
feature_save_dir = './beit/large/features'

# exist_ok=True avoids the check-then-create race and matches how the encoding
# loop below creates its per-part sub-directories.
for save_dir in (id2image_save_dir, feature_save_dir):
    os.makedirs(save_dir, exist_ok=True)

In [None]:
# Encode every video's keyframes and write, per video, one feature array (.npy)
# and one id -> keyframe-path mapping (.json) under the per-part output folders.
batch_size = 32
for video_part, video_path_dict in all_video_paths.items():
    # The output folders only depend on the part, so create them once per part
    # rather than once per video.
    part_feature_dir = f'{feature_save_dir}/{video_part}'
    part_id2image_dir = f'{id2image_save_dir}/{video_part}'
    os.makedirs(part_feature_dir, exist_ok=True)
    os.makedirs(part_id2image_dir, exist_ok=True)

    for video_id, video_id_path in tqdm(video_path_dict.items(), desc=f'Encoding Part {video_part}'):
        # Keyframe paths are joined with '/' because sorted_by_id splits on '/'.
        keyframe_image_paths = [f'{video_id_path}/{keyframe_name}' for keyframe_name in os.listdir(video_id_path)]
        sorted_keyframe_image_paths = sorted_by_id(keyframe_image_paths)
        id2image_fps, video_features = encode_images(sorted_keyframe_image_paths, batch_size)

        np.save(f'{part_feature_dir}/{video_id}.npy', video_features)

        # ensure_ascii=False writes non-ASCII path characters verbatim, so the
        # file encoding must be pinned to UTF-8 instead of the platform default.
        with open(f'{part_id2image_dir}/{video_id}.json', 'w', encoding='utf-8') as f:
            json.dump(id2image_fps, f, ensure_ascii=False, indent=4)