In [None]:
!git clone https://github.com/xinyu1205/recognize-anything.git
!pip install timm transformers fairscale pycocoevalcap
%cd recognize-anything

In [None]:
import os
import glob
import json
import torch
import numpy as np
from PIL import Image
from ram.models import ram
from ram import get_transform
from tqdm import tqdm

# Parse data path

In [3]:
keyframes_dir = './Keyframes'


def collect_keyframe_paths(root_dir):
    """Index every keyframe image below ``root_dir``.

    The expected layout is ``root_dir/<partition dir>/<video dir>/*.jpg``.
    Partition and video directories may carry a prefix separated by
    underscores; only the text after the last underscore is used as the
    dictionary key, while the real directory name is always used to build
    the path on disk (so a prefixed directory name still resolves).

    Args:
        root_dir: Directory that contains one sub-directory per data partition.

    Returns:
        dict: ``{partition_key: {video_id: [sorted .jpg paths]}}``.
        Entries that are not directories are ignored at both levels.
    """
    collected = dict()
    for part in sorted(os.listdir(root_dir)):
        part_path = f'{root_dir}/{part}'
        if not os.path.isdir(part_path):
            # Skip stray files (notebook checkpoints, archives, ...).
            continue
        data_part = part.split('_')[-1]  # L01, L02 for ex
        videos = collected.setdefault(data_part, dict())
        for video_dir in sorted(os.listdir(part_path)):
            video_path = f'{part_path}/{video_dir}'
            if not os.path.isdir(video_path):
                continue
            video_id = video_dir.split('_')[-1]
            videos[video_id] = sorted(glob.glob(f'{video_path}/*.jpg'))
    return collected


all_keyframe_paths = collect_keyframe_paths(keyframes_dir)

# Download Checkpoint


In [None]:
def download_checkpoints(model):
    """Make sure the pretrained weights for ``model`` exist under ``pretrained/``.

    Nothing is downloaded when the weight file is already present. Otherwise
    the file is fetched with ``torch.hub.download_url_to_file``, which writes
    to a temporary file and only moves it into place once the transfer has
    finished, so an interrupted download is not mistaken for a valid
    checkpoint on the next run.

    Args:
        model: Name of the checkpoint to fetch. Only ``"RAM"`` is supported.

    Raises:
        ValueError: If ``model`` is not a supported checkpoint name.
    """
    print('You selected', model)
    if model != "RAM":
        # Previously an unknown name was silently ignored, which made the
        # caller believe the weights were available.
        raise ValueError(f"Unsupported model {model!r}; expected 'RAM'")

    os.makedirs('pretrained', exist_ok=True)

    ram_weights_path = 'pretrained/ram_swin_large_14m.pth'
    if os.path.exists(ram_weights_path):
        print("RAM weights already downloaded!")
        return

    ram_weights_url = (
        'https://huggingface.co/spaces/xinyu1205/Recognize_Anything-Tag2Text'
        '/resolve/main/ram_swin_large_14m.pth'
    )
    torch.hub.download_url_to_file(ram_weights_url, ram_weights_path, progress=True)

# Fetch the RAM checkpoint before the network is constructed.
# NOTE(review): `model` holds the checkpoint *name* here; a later cell rebinds the
# same variable to the loaded network, so this string is gone once that cell has run.
model = "RAM"
download_checkpoints(model)
# This message is printed whenever the call above returns; it is not by itself
# proof that a download took place.
print(model, 'weights are downloaded!')

# Helper Function

In [6]:
device = "cuda" if torch.cuda.is_available() else "cpu"


@torch.no_grad()
def forward_ram(model, imgs):
    """Tag a batch of images with a RAM model and return tags with their scores.

    Args:
        model: Object exposing ``visual_encoder``, ``image_proj``,
            ``wordvec_proj``, ``label_embed``, ``tagging_head``, ``fc``,
            ``class_threshold`` and ``tag_list`` (indexed as a NumPy array).
        imgs: Batch of preprocessed images; the first dimension is the batch.

    Returns:
        tuple: ``(tag_outputs, tag_logits)``, two lists with one entry per image.
        ``tag_outputs[i]`` is the list of tag names whose sigmoid score is
        above the per-class threshold (spaces replaced by underscores) and
        ``tag_logits[i]`` is a NumPy array holding those scores in the same order.
    """
    batch_size = imgs.shape[0]

    # Image branch: encode, project, and build an all-ones attention mask.
    visual_feats = model.visual_encoder(imgs.to(device))
    image_embeds = model.image_proj(visual_feats)
    mask_shape = image_embeds.size()[:-1]
    image_atts = torch.ones(mask_shape, dtype=torch.long).to(device)

    # Label branch: one projected label embedding set, repeated for every image.
    label_queries = torch.nn.functional.relu(model.wordvec_proj(model.label_embed))
    label_queries = label_queries.unsqueeze(0).repeat(batch_size, 1, 1)

    tagging_embed, _ = model.tagging_head(
        encoder_embeds=label_queries,
        encoder_hidden_states=image_embeds,
        encoder_attention_mask=image_atts,
        return_dict=False,
        mode='tagging',
    )

    logits = torch.sigmoid(model.fc(tagging_embed).squeeze(-1))
    # True wherever a class score clears its own threshold.
    selected = (logits > model.class_threshold.to(device)).cpu().numpy()

    tag_outputs, tag_logits = [], []
    for row_mask, row_scores in zip(selected, logits):
        hits = np.flatnonzero(row_mask)
        names = model.tag_list[hits]
        tag_outputs.append([name.replace(" ", "_") for name in names])
        tag_logits.append(row_scores[hits].cpu().numpy())

    return tag_outputs, tag_logits

# Run inference

In [None]:
# Preprocessing pipeline and RAM tagger, both configured for 384px inputs.
transform = get_transform(image_size=384)

# Build the network from the downloaded checkpoint, switch it to inference
# mode and place it on the selected device in one expression.
model = ram(
    pretrained='pretrained/ram_swin_large_14m.pth',
    image_size=384,
    vit='swin_l',
).eval().to(device)

# Keep a handle on the model's tag vocabulary.
tag_list = model.tag_list

In [None]:
bs = 4  # number of keyframes sent through the model per forward pass
save_dir_all = 'context_encoded'
save_dir = f'{save_dir_all}/tags_encoded'
# Creates save_dir_all as well; safe to re-run when the directories already exist.
os.makedirs(save_dir, exist_ok=True)


def _encode_tag_context(tags, scores):
    """Turn one image's tags into a space-separated pseudo-document.

    Each tag is repeated ``round(score * 10)`` times, so higher-scoring tags
    occur more often in the resulting text.

    Args:
        tags: Sequence of tag strings for one image.
        scores: NumPy array with one score per tag, in the same order.

    Returns:
        str: The repeated tags joined by single spaces (empty if there are none).
    """
    repeats = np.round(scores * 10).astype(int)
    words = []
    for tag, count in zip(tags, repeats):
        words.extend([tag] * count)
    return ' '.join(map(str, words))


for key, video_keyframe_paths in all_keyframe_paths.items():
    video_ids = sorted(video_keyframe_paths.keys())
    os.makedirs(os.path.join(save_dir, key), exist_ok=True)

    for video_id in tqdm(video_ids):
        tag_contexts = []
        video_keyframe_path = video_keyframe_paths[video_id]
        for i in tqdm(range(0, len(video_keyframe_path), bs)):
            # Support batchsize inferencing
            batch = []
            for image_path in video_keyframe_path[i:i+bs]:
                # Close the file handle as soon as the tensor has been built;
                # Image.open is lazy and otherwise keeps the file open.
                with Image.open(image_path) as image:
                    batch.append(transform(image))
            images = torch.stack(batch).to(device)

            # Forward ram model
            tag_outputs, tag_logits = forward_ram(model, images)

            # Encode result: one text line per keyframe
            for tag_output, tag_logit in zip(tag_outputs, tag_logits):
                tag_contexts.append(_encode_tag_context(tag_output, tag_logit))

        # Every keyframe must yield exactly one line. Fail loudly instead of
        # silently skipping the remaining videos of this partition.
        if len(tag_contexts) != len(video_keyframe_path):
            raise RuntimeError(
                f"Tag context count mismatch for {key}/{video_id}: "
                f"{len(tag_contexts)} contexts for {len(video_keyframe_path)} keyframes"
            )

        # Saving the video tag context txt
        with open(f"{save_dir}/{key}/{video_id}.txt", "w") as f:
            for item in tag_contexts:
                f.write("%s\n" % item)

In [9]:
# Clean-up: remove the cloned repository checkout.
# NOTE(review): the setup cell runs `%cd recognize-anything` and `context_encoded/` is
# written through a relative path; if this absolute path is that same checkout, the
# generated tag files are deleted as well. Confirm they are copied elsewhere first.
!rm -r /kaggle/working/recognize-anything