# LOAD IMAGE

In [2]:
from utils.utils import load_img_cache
# Load one keyframe image from the local save directory via the project's cached loader.
path = "./save/keyframes/frame_0013.webp"
image = load_img_cache(path)

In [5]:
# Inspect the loaded image's dimensions (the loader returns an object exposing `.shape`).
image.shape

(480, 854, 3)

# BEIT

In [61]:
import os

from transformers import XLMRobertaTokenizer

# Directory holding the BEiT-3 tokenizer and checkpoints. Set the BEIT_MODEL_DIR
# environment variable to point somewhere else; the fallback is the original
# machine-specific location so the existing setup keeps working unchanged.
BEIT_MODEL_DIR = os.environ.get(
    "BEIT_MODEL_DIR",
    "F:\\UNIVERSITY\\Contest\\AIC\\workspace\\myworkspace\\model",
)

beit_tokenizer_path = os.path.join(BEIT_MODEL_DIR, "beit3_semantic.spm")
beit_model_semantic_path = os.path.join(BEIT_MODEL_DIR, "beit3_large_patch16_224_nlvr2.pth")
beit_model_retrieval_path = os.path.join(BEIT_MODEL_DIR, "beit3_large_patch16_384_f30k_retrieval.pth")

# SentencePiece-based tokenizer built from the local .spm vocabulary file.
tokenizer = XLMRobertaTokenizer(beit_tokenizer_path)

In [64]:
import torch
from utils.beit.unilm.beit3.modeling_finetune import beit3_large_patch16_384_retrieval, beit3_large_patch16_224_nlvr2

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# Deserialize on the CPU: load_state_dict copies the weights into the model's own
# parameters, so mapping the raw checkpoints to the GPU would only keep a second,
# unused copy of each (large) model in GPU memory.
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
checkpoint_semantic = torch.load(beit_model_semantic_path, map_location="cpu")
checkpoint_retrieval = torch.load(beit_model_retrieval_path, map_location="cpu")

In [None]:
model_semantic = beit3_large_patch16_224_nlvr2(pretrained=True)
# Inference only: place the model on the compute device and switch off
# training-only behaviour (e.g. dropout) before it is used.
model_semantic.to(device).eval()
# Kept as the last expression so the cell still displays the key-matching report.
model_semantic.load_state_dict(checkpoint_semantic['model'])

In [65]:
model_retrieval = beit3_large_patch16_384_retrieval(pretrained=True)
# Inference only: place the model on the compute device (the encoder sends its
# batches to CUDA when available) and switch off training-only behaviour.
model_retrieval.to(device).eval()
# Kept as the last expression so the cell still displays the key-matching report.
model_retrieval.load_state_dict(checkpoint_retrieval['model'])

<All keys matched successfully>

# SPLIT FRAME

In [66]:
import cv2
import argparse
import numpy as np
import os

from utils.registry import registry

class FrameSplitter:
    """
        Sample frames from a video at a fixed time interval.

        Parameters:
        -----------
        - interval: number of seconds between two sampled frames
    """
    def __init__(self, interval: int):
        self.writer = registry.get_writer("common")
        self.interval = interval

    def split_frames(
            self,
            source  : str,
            save_dir: str = None, 
            is_saved: bool = False
        ):
        """
            Split a video into frames sampled every `self.interval` seconds.

            The final frame of the video is added as well when the regular
            sampling did not already land on it.

            Parameters:
            -----------
            - source: mp4 video path
            - save_dir: directory where frames are saved (required when is_saved)
            - is_saved: also write every sampled frame to save_dir as .webp

            Returns:
            --------
            - list of sampled frames, in video order (empty if the video
              cannot be opened)

            Raises:
            -------
            - ValueError: is_saved is True but save_dir is None
        """
        #-- Validate the output location before opening the video
        if is_saved:
            if save_dir is None:
                raise ValueError("Please provide valid save dir")
            if not os.path.exists(save_dir):
                print("Create save directory")
                os.makedirs(save_dir, exist_ok=True)

        cap = cv2.VideoCapture(source)
        frames = []
        try:
            if not cap.isOpened():
                print('Cap is not open')
            else:
                #-- Setup config
                fps = cap.get(cv2.CAP_PROP_FPS) # frame per second
                total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
                # Whole-frame step, never below 1: a backend reporting fps == 0
                # would otherwise re-read frame 0 forever.
                num_skip_frames = 1
                if fps and fps > 0:
                    num_skip_frames = max(1, int(round(self.interval * fps)))

                #-- Split frame
                frame_id = 0
                last_read_id = -1
                while True:
                    ret, frame = self.cap_frame(cap, frame_id=frame_id)
                    if not ret or frame is None:
                        break
                    self._keep_frame(frame, frames, save_dir, is_saved)
                    last_read_id = frame_id
                    frame_id += num_skip_frames

                #-- Keep the last frame unless sampling already ended on it
                final_id = total_frames - 1
                if 0 <= last_read_id < final_id:
                    ret, frame = self.cap_frame(cap, frame_id=final_id)
                    if ret and frame is not None:
                        self._keep_frame(frame, frames, save_dir, is_saved)
        finally:
            # Release the capture even when reading or saving fails.
            cap.release()

        print(f"Splitting {len(frames)} frames from video")
        return frames

    def _keep_frame(self, frame, frames, save_dir, is_saved):
        # Append one sampled frame to `frames`, writing it to disk first when
        # requested; the file index is the frame's position in the result list.
        if is_saved:
            save_path = os.path.join(save_dir, f'frame_{len(frames):04d}.webp')
            self.save_frame(save_path, frame)
        frames.append(frame)

    def cap_frame(self, cap, frame_id):
        """Seek `cap` to `frame_id` and read one frame; returns (ret, frame)."""
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_id)
        ret, frame = cap.read()
        return ret, frame

    def save_frame(self, save_path, frame):
        """Write `frame` to `save_path` as WebP with quality 80."""
        cv2.imwrite(save_path, frame, [int(cv2.IMWRITE_WEBP_QUALITY), 80])

In [53]:
# Build the frame splitter; interval=2 is stored on the instance as `self.interval`.
splitter = FrameSplitter(interval=2)

In [52]:
# Split the sample video into in-memory frames; nothing is written to disk
# because `is_saved` is left at its default of False.
source = "data/video.mp4"
frames = splitter.split_frames(source=source)

Splitting 36 frames from video


# EXTRACT FEATURES

In [7]:
import numpy as np
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
import torch

In [26]:
def process_image_from_path(image_path: str, image_size=224, device="cpu"):
    """
    Load an image from disk and turn it into a model-ready tensor.

    Parameters:
    -----------
    - image_path: path of the image, read through `load_img_cache`
    - image_size: the image is resized to (image_size, image_size)
    - device: device the resulting tensor is moved to

    Returns:
    --------
    - tensor with a leading batch dimension of 1, or None when loading or
      preprocessing fails (the error is printed)
    """
    try:
        img = load_img_cache(image_path)

        steps = []
        if hasattr(img, "convert"):
            # PIL-style image: normalise the mode directly.
            img = img.convert('RGB')
        else:
            # Array-style image (the loader's result exposes `.shape` elsewhere in
            # this notebook): convert to PIL before resizing.
            # NOTE(review): channel order of the loaded array is not visible here --
            # confirm it is RGB, not BGR.
            steps.append(transforms.ToPILImage())
        steps.append(
            transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC)
        )
        steps.append(transforms.ToTensor())

        return transforms.Compose(steps)(img).unsqueeze(0).to(device)
    except Exception as e:
        print(f"Failed to process image {image_path}: {e}")
        return None

def process_image(img: np.ndarray, image_size=224):
    """
    Turn one in-memory image into a resized float tensor on the CPU.

    Parameters:
    -----------
    - img: image array; when its first axis is larger than 3 it is treated as
      channel-last and transposed to channel-first before conversion
    - image_size: the image is resized to (image_size, image_size)

    Returns:
    --------
    - the transformed tensor (no batch dimension), or None when preprocessing
      fails (the error is printed)

    NOTE(review): no channel reordering happens here; frames read with OpenCV
    are BGR -- confirm whether the model expects RGB.
    """
    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
    ])
    try:
        # Layout normalisation lives inside the try so that a malformed array
        # (e.g. 2-D input) yields None like every other preprocessing failure.
        if img.shape[0] > 3:
            img = img.transpose(2, 0, 1)
        return transform(torch.tensor(img)).to("cpu")
    except Exception as e:
        print(f"Failed to process image: {e}")
        return None
    

In [54]:
# Smoke-test the preprocessing on a single sampled frame; the result is the
# transformed tensor, or None if preprocessing failed.
img = frames[10]
process_images = process_image(img)

## ENCODE IMAGE USING BEIT

In [39]:
import torch
from torch import nn
from utils.registry import registry
import torch.nn.functional as F
import numpy as np
from transformers import XLMRobertaTokenizer
from torchvision import transforms
from utils.beit.unilm.beit3.modeling_finetune import beit3_base_patch16_224_retrieval, beit3_large_patch16_224_nlvr2
from torchvision.transforms.functional import InterpolationMode
from utils.utils import load_img_cache
from tqdm import tqdm

class BEiTImangeEncoder:
    """
        Image encoder built around a BEiT-3 model.

        Params:
        ------
            feat_type: str
                - Key into the "beit" config ("retrieval" or "semantic")
    """
    def __init__(self, feat_type):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.config = registry.get_config("beit")
        self.writer = registry.get_writer("common")
        self.feat_type = feat_type

        # self.build_task()

    #-- BUILD
    def build_task(self):
        """
            Build the model and tokenizer described by the config entry of
            `self.feat_type`, load the checkpoint and switch to eval mode.

            Raises:
            -------
                ValueError: the feature type has no config entry or no model
        """
        beit_type_config = self.config.get(self.feat_type, None)
        if beit_type_config is None:
            raise ValueError(f"Feature type {self.feat_type} unavailable")

        beit_model_path = beit_type_config["model_path"]
        beit_tokenizer_path = beit_type_config["tokenizer_path"]
        if self.feat_type == "retrieval":
            # NOTE(review): the architecture must match the checkpoint referenced by
            # the config -- confirm the base/224 retrieval variant is the intended one.
            self.model = beit3_base_patch16_224_retrieval(pretrained=True)
        elif self.feat_type == "semantic":
            self.model = beit3_large_patch16_224_nlvr2(pretrained=True)
        else:
            raise ValueError(f"No BEiT model is defined for feature type {self.feat_type}")

        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        checkpoint = torch.load(beit_model_path, map_location=self.device)
        self.tokenizer = XLMRobertaTokenizer(beit_tokenizer_path)
        self.model.load_state_dict(checkpoint['model'])
        self.model.to(self.device)
        self.model.eval()

    #-- FUNCTION
    def process_image_from_path(self, image_path: str, image_size=224):
        """
            Load an image from disk and turn it into a batched tensor on
            `self.device`. Returns None (and prints the error) on failure.
        """
        try:
            img = load_img_cache(image_path)

            steps = []
            if hasattr(img, "convert"):
                # PIL-style image: normalise the mode directly.
                img = img.convert('RGB')
            else:
                # Array-style image: convert to PIL before resizing.
                # NOTE(review): confirm the loaded array is RGB, not BGR.
                steps.append(transforms.ToPILImage())
            steps.append(
                transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC)
            )
            steps.append(transforms.ToTensor())

            return transforms.Compose(steps)(img).unsqueeze(0).to(self.device)
        except Exception as e:
            print(f"Failed to process image {image_path}: {e}")
            return None

    def process_image(self, img: np.ndarray, image_size=224):
        """
            Turn one in-memory image into a resized tensor on `self.device`.

            When the first axis is larger than 3 the array is treated as
            channel-last and transposed to channel-first. Returns None (and
            prints the error) on failure.

            NOTE(review): no channel reordering happens here; OpenCV frames are
            BGR -- confirm whether the model expects RGB.
        """
        transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
        ])
        try:
            # Inside the try so malformed arrays yield None like other failures.
            if img.shape[0] > 3:
                img = img.transpose(2, 0, 1)
            return transform(torch.tensor(img)).to(self.device)
        except Exception as e:
            print(f"Failed to process image: {e}")
            return None

    def _model_device(self, model):
        # Device the model's weights live on; falls back to self.device for
        # callables that expose no parameters.
        try:
            return next(model.parameters()).device
        except (AttributeError, StopIteration, TypeError):
            return self.device
        
    #-- Encode frame
    def encode_frames(self, model, frames, batch_size=4, image_size=224):
        """
            Function:
            ---------
                Encode all frames in one single video shot

            Params:
            ------
                model: callable returning (image_features, _) for
                    model(image=batch, only_infer=True)
                frames: List[np.ndarray]
                    - Frame from frame splitting modules
                batch_size: number of frames encoded per forward pass
                image_size: side length frames are resized to
                    - NOTE(review): must match the resolution the model was
                      built for (e.g. 384 for a patch16_384 model) -- confirm

            Returns:
            --------
                List[np.ndarray] - one L2-normalised float32 vector per frame
                that was preprocessed and encoded successfully
        """
        device = self._model_device(model)
        encoding_list = []
        with torch.no_grad():
            for start_idx in tqdm(range(0, len(frames), batch_size), desc="Processing and encoding images"):
                batch_frames = frames[start_idx:start_idx + batch_size]
                batch_tensors = []

                # Preprocess only the frames of this batch
                for frame_id, frame in enumerate(batch_frames):
                    try:
                        image_tensor = self.process_image(frame, image_size=image_size)
                        if image_tensor is not None:
                            batch_tensors.append(image_tensor)
                    except Exception as e:
                        print(f"Failed to process frame {start_idx + frame_id}: {e}")

                if not batch_tensors:
                    print(f"No valid images in batch {start_idx}-{start_idx + batch_size}. Skipping.")
                    continue

                # Stack tensors and move them to where the model lives
                batch_images = torch.stack(batch_tensors).to(device)

                # Encode images
                try:
                    image_features, _ = model(image=batch_images, only_infer=True)
                    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
                    encoding_list.extend(image_features.cpu().numpy().astype(np.float32))
                except Exception as e:
                    print(f"Error during encoding batch {start_idx}-{start_idx + batch_size}: {e}")

        return encoding_list


In [None]:
# Encode every sampled frame with the retrieval model loaded earlier in the notebook.
# NOTE(review): `model_retrieval` is the patch16_384 variant while `process_image`
# resizes to 224 by default -- confirm the expected input resolution.
beit_encoder = BEiTImangeEncoder("retrieval")
embed = beit_encoder.encode_frames(model=model_retrieval, frames=frames)

Processing and encoding images:   0%|          | 0/9 [00:28<?, ?it/s]
