From 805a9f3a4bff2b246024457237c0c438c0e5c849 Mon Sep 17 00:00:00 2001 From: Artem Sokolov <58517203+githubartema@users.noreply.github.com> Date: Sat, 11 May 2024 14:22:31 +0100 Subject: [PATCH] chore: Update processing_video.py --- .../languagebind/video/processing_video.py | 46 ++++++++++++++++++- 1 file changed, 45 insertions(+), 1 deletion(-) diff --git a/videollava/model/multimodal_encoder/languagebind/video/processing_video.py b/videollava/model/multimodal_encoder/languagebind/video/processing_video.py index 2bb2921..c6557b8 100644 --- a/videollava/model/multimodal_encoder/languagebind/video/processing_video.py +++ b/videollava/model/multimodal_encoder/languagebind/video/processing_video.py @@ -1,6 +1,7 @@ import torch import cv2 +import os import decord import numpy as np from PIL import Image @@ -12,12 +13,27 @@ from torchvision.transforms import Compose, Lambda, ToTensor from torchvision.transforms._transforms_video import NormalizeVideo, RandomCropVideo, RandomHorizontalFlipVideo, CenterCropVideo from pytorchvideo.transforms import ApplyTransformToKey, ShortSideScale, UniformTemporalSubsample +from torch.utils.data import DataLoader, Dataset +from .database import * decord.bridge.set_bridge('torch') OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) +class DBFramesDataset(Dataset): + def __init__(self, db_path): + self.db = PILImageDatabase(db_path) + self.keys = self.db.keys + + def __len__(self): + return len(self.keys) - 1 + + def __getitem__(self, index): + key = self.keys[index] + sample = self.db[key] + return key, sample + def make_list_of_images(x): if not isinstance(x, list): return [x] @@ -64,8 +80,20 @@ def get_video_transform(config): RandomHorizontalFlipVideo(p=0.5), ] ) + + elif config.video_decode_backend == 'lmdb': + transform = Compose( + [ + # UniformTemporalSubsample(num_frames), + Lambda(lambda x: x / 255.0), + NormalizeVideo(mean=OPENAI_DATASET_MEAN, std=OPENAI_DATASET_STD), + ShortSideScale(size=224), + CenterCropVideo(224), + RandomHorizontalFlipVideo(p=0.5), + ] + ) else: - raise NameError('video_decode_backend should specify in (pytorchvideo, decord, opencv)') + raise NameError('video_decode_backend should specify in (pytorchvideo, decord, opencv, lmdb)') return transform @@ -86,6 +114,22 @@ def load_and_transform_video( video_data = video.get_clip(start_sec=start_sec, end_sec=end_sec) video_outputs = transform(video_data) + if video_decode_backend == 'lmdb': + # here scene_path == video_path + db_path = os.path.join(video_path, "frames") + dataset = DBFramesDataset(db_path) + duration = len(dataset) + frame_id_list = np.linspace(0, duration-1, num_frames, dtype=int) + + video_data = [] + for indx in frame_id_list: + _, frame = dataset[indx] + frame = np.array(frame) + video_data.append(torch.from_numpy(frame).permute(2, 0, 1)) + + video_data = torch.stack(video_data, dim=1) + video_outputs = transform(video_data) + elif video_decode_backend == 'decord': decord.bridge.set_bridge('torch') decord_vr = VideoReader(video_path, ctx=cpu(0))