#### **Import Libraries**

In [None]:
import os
import json
import torch
import numpy as np
from PIL import Image
from ram.models import ram
from ram import get_transform
from tqdm import tqdm

#### **Parsing Data Path**

In [None]:
def parse_data_path(feature_dir='./distilled_keyframe'):
    """Index a two-level ``feature_dir/<part>/<entry>`` tree.

    Args:
        feature_dir: Root directory whose sub-directories are the "parts".

    Returns:
        dict: ``{part_name: {entry_id: entry_path}}`` where ``entry_id`` is
        the entry name up to its first ``'.'`` and ``entry_path`` is
        ``f'{feature_dir}/{part_name}/{entry_name}'``. Parts and entries are
        inserted in sorted name order. If two entries of a part share the
        same id, the one that sorts last wins.
    """
    all_feature_paths = dict()
    for feature_part in sorted(os.listdir(feature_dir)):
        feature_part_path = f'{feature_dir}/{feature_part}'
        # Plain files sitting next to the part directories (OS metadata,
        # notes, ...) cannot be listed and used to abort the whole scan
        # with NotADirectoryError; skip them instead.
        if not os.path.isdir(feature_part_path):
            continue
        # Paths are joined with '/' on purpose so the returned strings stay
        # identical across platforms.
        all_feature_paths[feature_part] = {
            feature_path.split('.')[0]: f'{feature_part_path}/{feature_path}'
            for feature_path in sorted(os.listdir(feature_part_path))
        }
    return all_feature_paths

In [None]:
# Change the working directory; the relative `pretrained/...` checkpoint
# paths used in the following cells are resolved against it.
%cd D:/AIC2024/extra/recognize-anything

#### **Download Model**

In [None]:
def download_checkpoints(model):
    """Make sure the pretrained weights for ``model`` exist under ``pretrained/``.

    Args:
        model: Name of the checkpoint to fetch. Only ``"RAM"`` triggers a
            download; any other value just ensures the ``pretrained``
            directory exists.

    The directory and checkpoint paths are relative to the current working
    directory. Nothing is returned.
    """
    print('You selected', model)
    os.makedirs('pretrained', exist_ok=True)

    if model == "RAM":
        ram_weights_path = 'pretrained/ram_swin_large_14m.pth'
        if not os.path.exists(ram_weights_path):
            ram_weights_url = (
                'https://huggingface.co/spaces/xinyu1205/'
                'Recognize_Anything-Tag2Text/resolve/main/'
                'ram_swin_large_14m.pth'
            )
            # Pure-Python download: works outside IPython and without a
            # `wget` binary. The helper writes to a temporary file and only
            # moves it into place once complete, so a failed download does
            # not leave a file that the exists() check above would later
            # mistake for valid weights.
            torch.hub.download_url_to_file(ram_weights_url, ram_weights_path)
        else:
            print("RAM weights already downloaded!")

# Pick the checkpoint by name and make sure its weights are available.
# NOTE(review): `model` is a plain string here; a later cell rebinds the
# same name to the loaded network.
model = "RAM"
download_checkpoints(model)
# This message is printed unconditionally once the call above returns.
print(model, 'weights are downloaded!')

#### **Function Definition**

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"


@torch.no_grad()
def forward_ram(model, imgs):
    """Run the tagging branch of ``model`` on a batch and decode the tags.

    Args:
        model: Object exposing ``visual_encoder``, ``image_proj``,
            ``wordvec_proj``, ``label_embed``, ``tagging_head``, ``fc``,
            ``class_threshold``, ``num_class`` and ``tag_list`` (presumably
            a loaded RAM model -- confirm against the caller).
        imgs: Batched image tensor; its first dimension is the batch size.

    Returns:
        tuple: ``(tag_outputs, tag_logits)``. ``tag_outputs[b]`` is the list
        of tag names whose sigmoid score exceeds the per-class threshold for
        sample ``b`` (spaces replaced by underscores); ``tag_logits[b]`` is a
        NumPy array holding the matching scores in the same order.
    """
    batch_size = imgs.shape[0]

    # Image branch: encode, project, and build an all-ones attention mask
    # covering every leading position of the projected embeddings.
    visual_features = model.visual_encoder(imgs.to(device))
    image_embeds = model.image_proj(visual_features)
    mask_shape = image_embeds.size()[:-1]
    image_atts = torch.ones(mask_shape, dtype=torch.long).to(device)

    # Label branch: one copy of the projected label embeddings per sample.
    projected_labels = torch.nn.functional.relu(
        model.wordvec_proj(model.label_embed))
    label_embed = projected_labels.unsqueeze(0).repeat(batch_size, 1, 1)

    tagging_embed, _ = model.tagging_head(
        encoder_embeds=label_embed,
        encoder_hidden_states=image_embeds,
        encoder_attention_mask=image_atts,
        return_dict=False,
        mode='tagging',
    )

    # Per-class scores, then a 1/0 indicator of which classes pass their
    # individual thresholds.
    logits = torch.sigmoid(model.fc(tagging_embed).squeeze(-1))
    passes_threshold = logits > model.class_threshold.to(device)
    targets = torch.where(
        passes_threshold,
        torch.tensor(1.0).to(device),
        torch.zeros(model.num_class).to(device))
    tag = targets.cpu().numpy()

    tag_outputs, tag_logits = [], []
    for b in range(batch_size):
        # argwhere yields a (k, 1) column of selected class indices.
        index = np.argwhere(tag[b] == 1)
        tokens = model.tag_list[index].squeeze(axis=1)
        scores = logits[b][index[:, 0]]
        names = [token.replace(" ", "_") for token in tokens]
        tag_outputs.append(names)
        tag_logits.append(scores.cpu().numpy())

    return tag_outputs, tag_logits

In [None]:
# Preprocessing callable from the RAM package, configured with image_size=384.
transform = get_transform(image_size=384)
# Build the network from the local checkpoint file; `image_size` matches the
# value passed to get_transform above.
model = ram(pretrained='pretrained/ram_swin_large_14m.pth',
            image_size=384,
            vit='swin_l')
# NOTE(review): presumably a torch.nn.Module, so eval() switches it to
# inference behaviour -- confirm against the ram package.
model.eval()
# Place the model on the device selected earlier ("cuda" or "cpu").
model = model.to(device)
# Keep a direct handle on the model's tag vocabulary.
tag_list = model.tag_list

#### **Run Inference**

In [None]:
def load_images(image_paths, transform):
    """Open every path, apply ``transform`` and prepend a batch axis.

    Args:
        image_paths: Iterable of image file paths.
        transform: Callable applied to each opened PIL image; its result
            must support ``.unsqueeze(0)``.

    Returns:
        list: One transformed, unsqueezed item per input path, in input
        order.
    """
    batch = []
    for path in image_paths:
        tensor = transform(Image.open(path))
        batch.append(tensor.unsqueeze(0))
    return batch

def encode_tags(model, images):
    """Turn the predicted tags of each image into a weighted text string.

    Each tag is repeated ``round(score * 10)`` times, where ``score`` is the
    value returned for it by ``forward_ram``, and the repetitions are joined
    with single spaces.

    Args:
        model: Model object forwarded to ``forward_ram``.
        images: Batched image tensor forwarded to ``forward_ram``.

    Returns:
        list[str]: One space-separated tag string per image in the batch.
    """
    tag_outputs, tag_logits = forward_ram(model, images)

    def repeat_by_score(tags, scores):
        # Scale each score to an integer repetition count.
        counts = np.round(scores * 10).astype(int)
        repeated = []
        for tag, count in zip(tags, counts):
            repeated += [tag] * count
        return ' '.join(map(str, repeated))

    return [
        repeat_by_score(tags, scores)
        for tags, scores in zip(tag_outputs, tag_logits)
    ]

def write_json_file(json_data, file_path):
    """Serialise ``json_data`` to ``file_path`` as indented UTF-8 JSON.

    Args:
        json_data: Any JSON-serialisable object.
        file_path: Destination path; an existing file is overwritten.
    """
    # ensure_ascii=False emits non-ASCII characters verbatim, so the file
    # must be opened as UTF-8 explicitly. The platform default codec (e.g.
    # cp1252 on Windows) would raise UnicodeEncodeError or yield invalid JSON.
    with open(file_path, 'w', encoding='utf-8') as f:
        json.dump(json_data, f, ensure_ascii=False, indent=4)

In [None]:
# Change the working directory; the relative keyframe and output paths used
# in the following cells are resolved against it.
%cd D:/AIC2024/dataset

In [None]:
# Index the keyframe tree as {part: {video_id: path}}; each path is later
# listed as a directory of keyframe images.
keyframe_dir='./distilled_keyframe'
all_keyframe_paths = parse_data_path(feature_dir=keyframe_dir)

In [None]:
def sorted_by_id(keyframe_paths):
    """Return ``keyframe_paths`` ordered by the integer id in each file name.

    The id is the part of the file name before its first ``'.'``, so
    ``'dir/10.jpg'`` sorts after ``'dir/2.jpg'``. Paths with equal ids keep
    their input order.

    Args:
        keyframe_paths: Iterable of keyframe file paths.

    Returns:
        list: The same path strings, sorted numerically by id.

    Raises:
        ValueError: If a file name does not start with an integer id.
    """
    def keyframe_id(keyframe_path):
        # os.path.basename honours the platform's separators. Splitting on
        # '/' alone left the parent directory glued to the file name for
        # paths produced by os.path.join on Windows, and int() then failed.
        keyframe_filename = os.path.basename(keyframe_path)
        return int(keyframe_filename.split('.')[0])

    return sorted(keyframe_paths, key=keyframe_id)

In [None]:
# Number of keyframes pushed through the model per forward pass.
batch_size = 4
# Root of the per-part output folders; one JSON file is written per video.
save_dir = './filter/tag/features'
os.makedirs(save_dir, exist_ok=True)

for video_part, video_path_dict in all_keyframe_paths.items():
    full_save_dir = os.path.join(save_dir, video_part)
    os.makedirs(full_save_dir, exist_ok=True)

    progress = tqdm(video_path_dict.items(), desc=f'Encoding Part {video_part}')
    for video_id, video_id_path in progress:
        # Gather this video's keyframes and order them by numeric id.
        keyframe_image_paths = sorted_by_id([
            os.path.join(video_id_path, filename)
            for filename in os.listdir(video_id_path)
        ])

        # Maps each keyframe path to its tag record, in processing order.
        video_id_metadata_records = {}
        for start in range(0, len(keyframe_image_paths), batch_size):
            image_paths = keyframe_image_paths[start:start + batch_size]
            images = torch.cat(load_images(image_paths, transform)).to(device)
            tag_contexts = encode_tags(model, images)
            video_id_metadata_records.update(
                (image_path, {'tag': tag_context})
                for image_path, tag_context in zip(image_paths, tag_contexts)
            )

        output_path = os.path.join(full_save_dir, f'{video_id}.json')
        write_json_file(video_id_metadata_records, output_path)