From a6fe7cd3c93c7892e8203702508cd33121e583c2 Mon Sep 17 00:00:00 2001 From: Ben Schreck Date: Wed, 31 Jan 2024 14:58:17 -0800 Subject: [PATCH 01/21] patched --- modal_inference.py | 181 ++++++++++++++++++ stub.py | 96 ++++++++++ .../mpt/hf_prefixlm_converter.py | 76 ++++---- .../languagebind/audio/modeling_audio.py | 7 +- .../languagebind/depth/modeling_depth.py | 7 +- .../languagebind/image/modeling_image.py | 7 +- .../languagebind/thermal/modeling_thermal.py | 7 +- .../languagebind/video/modeling_video.py | 7 +- 8 files changed, 336 insertions(+), 52 deletions(-) create mode 100644 modal_inference.py create mode 100644 stub.py diff --git a/modal_inference.py b/modal_inference.py new file mode 100644 index 0000000..e7222e7 --- /dev/null +++ b/modal_inference.py @@ -0,0 +1,181 @@ +from modal import ( + Cls, + method, + enter, +) +from pathlib import Path +from .stub import stub, cls_dec, MODELS_DIR, volume, HF_DATASETS_CACHE, REPO_HOME + + +LOCAL_VIDEOS_PATH = Path(REPO_HOME) / "downloaded_videos" +DEFAULT_PROMPT = "describe what is going on in this video" + + +def load_pretrained_from_cache(load_4bit=True, load_8bit=False): + from videollava.utils import disable_torch_init + from transformers import AutoTokenizer, BitsAndBytesConfig + from videollava.model import LlavaLlamaForCausalLM + import torch + disable_torch_init() + + kwargs = { + "device_map": "auto", + "cache_dir": HF_DATASETS_CACHE, + } + video_llava_path = Path(MODELS_DIR) / 'Video-LLaVA-7B' + + vlp_exists = video_llava_path.exists() + if not vlp_exists: + video_llava_path.mkdir(exist_ok=True, parents=True) + + save = False + if not video_llava_path.exists() or len(list(video_llava_path.iterdir())) == 0: + save = True + print("Downloading model") + video_llava_path = 'LanguageBind/Video-LLaVA-7B' + + tokenizer_path = Path(MODELS_DIR) / 'tokenizer' + if not tokenizer_path.exists() or len(list(tokenizer_path.iterdir())) == 0: + print("Downloading tokenizer") + tokenizer_path = 'LanguageBind/Video-LLaVA-7B' + + if load_8bit: + kwargs['load_in_8bit'] = True + elif load_4bit: + kwargs['load_in_4bit'] = True + kwargs['quantization_config'] = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4' + ) + else: + kwargs['torch_dtype'] = torch.float16 + + tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=False, cache_dir=kwargs["cache_dir"]) + model = LlavaLlamaForCausalLM.from_pretrained(video_llava_path, low_cpu_mem_usage=True, **kwargs) + model.generation_config.do_sample = True + + if save: + # save to on-disk paths + video_llava_path = Path(MODELS_DIR) / 'Video-LLaVA-7B' + tokenizer_path = Path(MODELS_DIR) / 'tokenizer' + tokenizer.save_pretrained(str(tokenizer_path)) + model.save_pretrained(str(video_llava_path)) + return model, tokenizer + +def prepare_processor(model, device="cuda"): + import torch + processor = {'image': None, 'video': None} + if model.config.mm_image_tower is not None: + image_tower = model.get_image_tower() + if not image_tower.is_loaded: + image_tower.load_model() + image_tower.to(device=device, dtype=torch.float16) + image_processor = image_tower.image_processor + processor['image'] = image_processor + if model.config.mm_video_tower is not None: + video_tower = model.get_video_tower() + if not video_tower.is_loaded: + video_tower.load_model() + video_tower.to(device=device, dtype=torch.float16) + video_processor = video_tower.video_processor + processor['video'] = video_processor + return processor + 
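+# prepare_special_tokens (below) registers the extra multimodal placeholder
+# tokens with the tokenizer — the image/video patch tokens and, when the model
+# config enables start/end markers, the corresponding start/end tokens — and
+# then resizes the model's token embeddings so the newly added ids have rows.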
+def prepare_special_tokens(model, tokenizer): + from videollava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, \ + DEFAULT_VIDEO_PATCH_TOKEN, DEFAULT_VID_START_TOKEN, DEFAULT_VID_END_TOKEN + mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False) + mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True) + if mm_use_im_patch_token: + tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) + tokenizer.add_tokens([DEFAULT_VIDEO_PATCH_TOKEN], special_tokens=True) + if mm_use_im_start_end: + tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) + tokenizer.add_tokens([DEFAULT_VID_START_TOKEN, DEFAULT_VID_END_TOKEN], special_tokens=True) + model.resize_token_embeddings(len(tokenizer)) + +@cls_dec(container_idle_timeout=30, gpu='L4') +class VideoLlava: + # TODO when they fix + #@build() + @enter() + def load_model(self): + self.model, self.tokenizer = load_pretrained_from_cache() + self.processor = prepare_processor(self.model) + self.video_processor = self.processor['video'] + + def prepare_conv(self): + from videollava.conversation import conv_templates + self.conv = conv_templates["llava_v1"].copy() + self.roles = self.conv.roles + + @method() + def inference(self, video_path: str, inp: str = DEFAULT_PROMPT): + import torch + from videollava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN + from videollava.mm_utils import tokenizer_image_token, KeywordsStoppingCriteria + from videollava.conversation import SeparatorStyle + import requests + from io import BytesIO + + self.prepare_conv() + if video_path.startswith("http"): + print("Downloading video") + video_bytes = requests.get(video_path).content + local_video_path = LOCAL_VIDEOS_PATH / video_path.split("/")[-1] + if not LOCAL_VIDEOS_PATH.exists(): + LOCAL_VIDEOS_PATH.mkdir(exist_ok=True, parents=True) + with open(local_video_path, "wb") as f: + f.write(video_bytes) + video_path = BytesIO(video_bytes) + print(f"Downloaded video and saved to {local_video_path}") + elif not Path(video_path).exists(): + volume.reload() + if not Path(video_path).exists(): + raise FileNotFoundError(f"Video {video_path} not found") + video_tensor = self.video_processor(video_path, return_tensors='pt')['pixel_values'] + if type(video_tensor) is list: + tensor = [video.to(self.model.device, dtype=torch.float16) for video in video_tensor] + else: + tensor = video_tensor.to(self.model.device, dtype=torch.float16) + + inp = ' '.join([DEFAULT_IMAGE_TOKEN] * model.get_video_tower().config.num_frames) + '\n' + inp + self.conv.append_message(self.conv.roles[0], inp) + self.conv.append_message(self.conv.roles[1], None) + prompt = self.conv.get_prompt() + input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX['VIDEO'], return_tensors='pt').unsqueeze(0).cuda() + stop_str = self.conv.sep if self.conv.sep_style != SeparatorStyle.TWO else self.conv.sep2 + keywords = [stop_str] + stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids) + + import time + begin = time.time() + with torch.inference_mode(): + output_ids = self.model.generate( + input_ids, + images=tensor, + do_sample=True, + temperature=0.9, + max_new_tokens=1024, + use_cache=True, + stopping_criteria=[stopping_criteria]) + end = time.time() + print(f"Generate time taken: {end-begin}") + + output = self.tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip() + print(output) + return output + +@stub.local_entrypoint() +def 
main(): + prompt = "describe what is going on in this video" + video_path = '/volume/pika_water_city.mp4' + output = VideoLlava().inference.remote(video_path, prompt) + print(output) + +if __name__ == "__main__": + video_llava = Cls.lookup("video-llava", 'VideoLlava')() + video_path = '/volume/pika_water_city.mp4' + print(video_llava.inference.remote(video_path=video_path, inp="describe what is going on in this video")) diff --git a/stub.py b/stub.py new file mode 100644 index 0000000..a1bf8be --- /dev/null +++ b/stub.py @@ -0,0 +1,96 @@ +from modal import Volume, Image, Stub, Mount, Secret +from pathlib import Path +REPO_HOME = "/app" +VOLUME_DIR = "/volume" +MODELS_DIR = "/root" +HF_DATASETS_CACHE = str(Path(VOLUME_DIR) / "hf_datasets_cache") +mounts = [Mount.from_local_dir("./ai_video_editor/updated_video_llava", remote_path=REPO_HOME)] +volume = Volume.persisted("video-llava-vol") +volumes = {VOLUME_DIR: volume} +stub = Stub("updated-video-llava", mounts=mounts, volumes=volumes, secrets=[Secret.from_dotenv()]) + +image = ( + Image.from_registry( + "nvidia/cuda:12.1.1-cudnn8-devel-ubuntu22.04", add_python="3.10" + ) + .apt_install( + "git", + "curl", + "libgl1-mesa-glx", + "libglib2.0-0", + "libsm6", + "libxrender1", + "libxext6", + "ffmpeg", + "clang", + "libopenmpi-dev", + gpu="any", + ) + + .pip_install( + "torch==2.1.2", + "transformers==4.37.2", + "bitsandbytes==0.42.0", + gpu="any", + ) + .run_commands( + "python -m bitsandbytes", + gpu="A10G" + ) + .pip_install( + "torchvision>=0.15.2", + #"tokenizers>=0.12.1,<0.14", + "sentencepiece==0.1.99", + "shortuuid", + "accelerate==0.21.0", + "peft==0.4.0", + "pydantic<2,>=1", + "markdown2[all]", + "numpy", + "scikit-learn==1.2.2", + "gradio==3.37.0", + "gradio_client==0.7.0", + "requests", + "httpx==0.24.0", + "uvicorn", + "fastapi", + #"einops==0.6.1", + "einops-exts==0.0.4", + "timm==0.6.13", + "deepspeed==0.9.5", + "ninja", + "wandb", + "tensorboardX==2.6.2.2", + + "tenacity", + "torch==2.1.2", + "wheel", + gpu="any", + ) + .run_commands("pip install flash-attn --no-build-isolation", gpu="any") + .env({"PYTHONPATH": REPO_HOME, "HF_DATASETS_CACHE": HF_DATASETS_CACHE}) + .pip_install( + "decord", + "opencv-python", + # TODO try removing or upgrading this to a version + "git+https://github.com/facebookresearch/pytorchvideo.git@28fe037d212663c6a24f373b94cc5d478c8c1a1d", + gpu="any", + ) +) + +def function_dec(**extras): + return stub.function( + image=image, + timeout=80000, + checkpointing_enabled=True, # Enable memory checkpointing for faster cold starts. + _allow_background_volume_commits=True, + **extras, + ) + +def cls_dec(**extras): + return stub.cls( + image=image, + timeout=80000, + checkpointing_enabled=True, # Enable memory checkpointing for faster cold starts. 
+ **extras, + ) diff --git a/videollava/model/language_model/mpt/hf_prefixlm_converter.py b/videollava/model/language_model/mpt/hf_prefixlm_converter.py index 8c1a648..0441b07 100644 --- a/videollava/model/language_model/mpt/hf_prefixlm_converter.py +++ b/videollava/model/language_model/mpt/hf_prefixlm_converter.py @@ -12,16 +12,16 @@ from typing import Any, Dict, List, Optional, Tuple, Union import torch from transformers.models.bloom.modeling_bloom import BaseModelOutputWithPastAndCrossAttentions, BloomForCausalLM, BloomModel, CausalLMOutputWithCrossAttentions, CrossEntropyLoss -from transformers.models.bloom.modeling_bloom import _expand_mask as _expand_mask_bloom -from transformers.models.bloom.modeling_bloom import _make_causal_mask as _make_causal_mask_bloom +#from transformers.models.bloom.modeling_bloom import _expand_mask as _expand_mask_bloom +#from transformers.models.bloom.modeling_bloom import _make_causal_mask as _make_causal_mask_bloom from transformers.models.bloom.modeling_bloom import logging from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoForCausalLM from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM from transformers.models.gptj.modeling_gptj import GPTJForCausalLM from transformers.models.opt.modeling_opt import OPTForCausalLM -from transformers.models.opt.modeling_opt import _expand_mask as _expand_mask_opt -from transformers.models.opt.modeling_opt import _make_causal_mask as _make_causal_mask_opt +#from transformers.models.opt.modeling_opt import _expand_mask as _expand_mask_opt +#from transformers.models.opt.modeling_opt import _make_causal_mask as _make_causal_mask_opt logger = logging.get_logger(__name__) _SUPPORTED_GPT_MODELS = (GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM) CAUSAL_GPT_TYPES = Union[GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM] @@ -123,19 +123,19 @@ def _convert_bloom_causal_lm_to_prefix_lm(model: BloomForCausalLM) -> BloomForCa assert isinstance(model, BloomForCausalLM) assert model.config.add_cross_attention == False, 'Only supports BLOOM decoder-only models' - def _prepare_attn_mask(self: BloomModel, attention_mask: torch.Tensor, bidirectional_mask: Optional[torch.Tensor], input_shape: Tuple[int, int], past_key_values_length: int) -> torch.BoolTensor: - combined_attention_mask = None - device = attention_mask.device - (_, src_length) = input_shape - if src_length > 1: - combined_attention_mask = _make_causal_mask_bloom(input_shape, device=device, past_key_values_length=past_key_values_length) - if bidirectional_mask is not None: - assert attention_mask.shape == bidirectional_mask.shape - expanded_bidirectional_mask = _expand_mask_bloom(bidirectional_mask, tgt_length=src_length) - combined_attention_mask = torch.logical_and(combined_attention_mask, expanded_bidirectional_mask) - expanded_attn_mask = _expand_mask_bloom(attention_mask, tgt_length=src_length) - combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask - return combined_attention_mask + # def _prepare_attn_mask(self: BloomModel, attention_mask: torch.Tensor, bidirectional_mask: Optional[torch.Tensor], input_shape: Tuple[int, int], past_key_values_length: int) -> torch.BoolTensor: + # combined_attention_mask = None + # device = attention_mask.device + # (_, src_length) = input_shape + # if src_length > 1: + # combined_attention_mask = 
_make_causal_mask_bloom(input_shape, device=device, past_key_values_length=past_key_values_length) + # if bidirectional_mask is not None: + # assert attention_mask.shape == bidirectional_mask.shape + # expanded_bidirectional_mask = _expand_mask_bloom(bidirectional_mask, tgt_length=src_length) + # combined_attention_mask = torch.logical_and(combined_attention_mask, expanded_bidirectional_mask) + # expanded_attn_mask = _expand_mask_bloom(attention_mask, tgt_length=src_length) + # combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask + # return combined_attention_mask def _build_alibi_tensor(self: BloomModel, batch_size: int, query_length: int, key_length: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor: num_heads = self.config.n_head @@ -194,7 +194,9 @@ def forward(self: BloomModel, input_ids: Optional[torch.LongTensor]=None, past_k else: attention_mask = attention_mask.to(hidden_states.device) alibi = self._build_alibi_tensor(batch_size=batch_size, query_length=seq_length, key_length=seq_length_with_past, dtype=hidden_states.dtype, device=hidden_states.device) - causal_mask = self._prepare_attn_mask(attention_mask, bidirectional_mask, input_shape=(batch_size, seq_length), past_key_values_length=past_key_values_length) + # causal_mask = self._prepare_attn_mask(attention_mask, bidirectional_mask, input_shape=(batch_size, seq_length), past_key_values_length=past_key_values_length) + causal_mask = _prepare_4d_causal_attention_mask(attention_mask, bidirectional_mask, input_shape=(batch_size, seq_length), inpust_embeds=inputs_embeds, past_key_values_length=past_key_values_length) + causal_mask = causal_mask.bool() for (i, (block, layer_past)) in enumerate(zip(self.h, past_key_values)): if output_hidden_states: hst = (hidden_states,) @@ -225,7 +227,7 @@ def custom_forward(*inputs): if not return_dict: return tuple((v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)) return BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=hidden_states, past_key_values=presents, hidden_states=all_hidden_states, attentions=all_self_attentions) - setattr(model.transformer, '_prepare_attn_mask', MethodType(_prepare_attn_mask, model.transformer)) + # setattr(model.transformer, '_prepare_attn_mask', MethodType(_prepare_attn_mask, model.transformer)) setattr(model.transformer, '_build_alibi_tensor', MethodType(_build_alibi_tensor, model.transformer)) setattr(model.transformer, 'forward', MethodType(forward, model.transformer)) KeyValueT = Tuple[torch.Tensor, torch.Tensor] @@ -282,23 +284,23 @@ def _convert_opt_causal_lm_to_prefix_lm(model: OPTForCausalLM) -> OPTForCausalLM setattr(model, '_original_generate', getattr(model, 'generate')) model.model.decoder.bidirectional_mask = None - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): - combined_attention_mask = None - if input_shape[-1] > 1: - if self.bidirectional_mask == 'g': - (bsz, src_length) = input_shape - combined_attention_mask = torch.zeros((bsz, 1, src_length, src_length + past_key_values_length), dtype=inputs_embeds.dtype, device=inputs_embeds.device) - else: - combined_attention_mask = _make_causal_mask_opt(input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length).to(inputs_embeds.device) - if self.bidirectional_mask is not None: - assert attention_mask.shape == self.bidirectional_mask.shape - expanded_bidirectional_mask = 
_expand_mask_opt(self.bidirectional_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(inputs_embeds.device) - combined_attention_mask = torch.maximum(expanded_bidirectional_mask, combined_attention_mask) - if attention_mask is not None: - expanded_attn_mask = _expand_mask_opt(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(inputs_embeds.device) - combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - return combined_attention_mask - setattr(model.model.decoder, '_prepare_decoder_attention_mask', MethodType(_prepare_decoder_attention_mask, model.model.decoder)) + # def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): + # combined_attention_mask = None + # if input_shape[-1] > 1: + # if self.bidirectional_mask == 'g': + # (bsz, src_length) = input_shape + # combined_attention_mask = torch.zeros((bsz, 1, src_length, src_length + past_key_values_length), dtype=inputs_embeds.dtype, device=inputs_embeds.device) + # else: + # combined_attention_mask = _make_causal_mask_opt(input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length).to(inputs_embeds.device) + # if self.bidirectional_mask is not None: + # assert attention_mask.shape == self.bidirectional_mask.shape + # expanded_bidirectional_mask = _expand_mask_opt(self.bidirectional_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(inputs_embeds.device) + # combined_attention_mask = torch.maximum(expanded_bidirectional_mask, combined_attention_mask) + # if attention_mask is not None: + # expanded_attn_mask = _expand_mask_opt(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(inputs_embeds.device) + # combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + # return combined_attention_mask + setattr(model.model.decoder, '_prepare_decoder_attention_mask', MethodType(_prepare_4d_causal_attention_mask, model.model.decoder)) def forward(self: OPTForCausalLM, input_ids: Optional[torch.LongTensor]=None, attention_mask: Optional[torch.Tensor]=None, bidirectional_mask: Optional[torch.ByteTensor]=None, head_mask: Optional[torch.Tensor]=None, past_key_values: Optional[List[torch.FloatTensor]]=None, inputs_embeds: Optional[torch.FloatTensor]=None, labels: Optional[torch.LongTensor]=None, use_cache: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, return_dict: Optional[bool]=None): @@ -412,4 +414,4 @@ def add_bidirectional_mask_if_missing(batch: Dict[str, Any]): elif 'labels' in batch and 'attention_mask' in batch: batch['bidirectional_mask'] = torch.logical_and(torch.eq(batch['attention_mask'], 1), torch.eq(batch['labels'], -100)).type_as(batch['attention_mask']) else: - raise KeyError('No bidirectional_mask in batch and not sure how to construct one.') \ No newline at end of file + raise KeyError('No bidirectional_mask in batch and not sure how to construct one.') diff --git a/videollava/model/multimodal_encoder/languagebind/audio/modeling_audio.py b/videollava/model/multimodal_encoder/languagebind/audio/modeling_audio.py index 908ab43..c45eac0 100644 --- a/videollava/model/multimodal_encoder/languagebind/audio/modeling_audio.py +++ b/videollava/model/multimodal_encoder/languagebind/audio/modeling_audio.py @@ -9,7 +9,8 @@ from transformers import PreTrainedModel, add_start_docstrings from transformers.modeling_outputs import 
BaseModelOutput, BaseModelOutputWithPooling from transformers.models.clip.modeling_clip import CLIPMLP, CLIPAttention, CLIPTextEmbeddings, CLIPVisionEmbeddings, \ - CLIPVisionModelWithProjection, CLIPTextModelWithProjection, _expand_mask, CLIPOutput, clip_loss + CLIPVisionModelWithProjection, CLIPTextModelWithProjection, CLIPOutput, clip_loss +from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings from .configuration_audio import LanguageBindAudioConfig, CLIPVisionConfig, CLIPTextConfig @@ -499,7 +500,7 @@ def forward( # expand attention_mask if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _expand_mask(attention_mask, hidden_states.dtype) + attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) encoder_outputs = self.encoder( inputs_embeds=hidden_states, @@ -1027,4 +1028,4 @@ def forward( image_embeds=image_embeds, text_model_output=text_outputs, vision_model_output=vision_outputs, - ) \ No newline at end of file + ) diff --git a/videollava/model/multimodal_encoder/languagebind/depth/modeling_depth.py b/videollava/model/multimodal_encoder/languagebind/depth/modeling_depth.py index 849eade..ac304af 100644 --- a/videollava/model/multimodal_encoder/languagebind/depth/modeling_depth.py +++ b/videollava/model/multimodal_encoder/languagebind/depth/modeling_depth.py @@ -9,7 +9,8 @@ from transformers import PreTrainedModel, add_start_docstrings from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from transformers.models.clip.modeling_clip import CLIPMLP, CLIPAttention, CLIPTextEmbeddings, CLIPVisionEmbeddings, \ - CLIPVisionModelWithProjection, CLIPTextModelWithProjection, _expand_mask, CLIPOutput, clip_loss + CLIPVisionModelWithProjection, CLIPTextModelWithProjection, CLIPOutput, clip_loss +from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings from .configuration_depth import LanguageBindDepthConfig, CLIPVisionConfig, CLIPTextConfig @@ -499,7 +500,7 @@ def forward( # expand attention_mask if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _expand_mask(attention_mask, hidden_states.dtype) + attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) encoder_outputs = self.encoder( inputs_embeds=hidden_states, @@ -1027,4 +1028,4 @@ def forward( image_embeds=image_embeds, text_model_output=text_outputs, vision_model_output=vision_outputs, - ) \ No newline at end of file + ) diff --git a/videollava/model/multimodal_encoder/languagebind/image/modeling_image.py b/videollava/model/multimodal_encoder/languagebind/image/modeling_image.py index e95ac47..2d4ef0e 100644 --- a/videollava/model/multimodal_encoder/languagebind/image/modeling_image.py +++ b/videollava/model/multimodal_encoder/languagebind/image/modeling_image.py @@ -9,7 +9,8 @@ from transformers import PreTrainedModel, add_start_docstrings from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from transformers.models.clip.modeling_clip import CLIPMLP, CLIPAttention, CLIPTextEmbeddings, CLIPVisionEmbeddings, \ - CLIPVisionModelWithProjection, CLIPTextModelWithProjection, _expand_mask, CLIPOutput, clip_loss + CLIPVisionModelWithProjection, CLIPTextModelWithProjection, CLIPOutput, clip_loss +from 
transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings from .configuration_image import LanguageBindImageConfig, CLIPVisionConfig, CLIPTextConfig @@ -499,7 +500,7 @@ def forward( # expand attention_mask if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _expand_mask(attention_mask, hidden_states.dtype) + attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) encoder_outputs = self.encoder( inputs_embeds=hidden_states, @@ -1027,4 +1028,4 @@ def forward( image_embeds=image_embeds, text_model_output=text_outputs, vision_model_output=vision_outputs, - ) \ No newline at end of file + ) diff --git a/videollava/model/multimodal_encoder/languagebind/thermal/modeling_thermal.py b/videollava/model/multimodal_encoder/languagebind/thermal/modeling_thermal.py index f0323b3..60272e6 100644 --- a/videollava/model/multimodal_encoder/languagebind/thermal/modeling_thermal.py +++ b/videollava/model/multimodal_encoder/languagebind/thermal/modeling_thermal.py @@ -9,7 +9,8 @@ from transformers import PreTrainedModel, add_start_docstrings from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from transformers.models.clip.modeling_clip import CLIPMLP, CLIPAttention, CLIPTextEmbeddings, CLIPVisionEmbeddings, \ - CLIPVisionModelWithProjection, CLIPTextModelWithProjection, _expand_mask, CLIPOutput, clip_loss + CLIPVisionModelWithProjection, CLIPTextModelWithProjection, CLIPOutput, clip_loss +from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings from .configuration_thermal import LanguageBindThermalConfig, CLIPVisionConfig, CLIPTextConfig @@ -499,7 +500,7 @@ def forward( # expand attention_mask if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _expand_mask(attention_mask, hidden_states.dtype) + attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) encoder_outputs = self.encoder( inputs_embeds=hidden_states, @@ -1027,4 +1028,4 @@ def forward( image_embeds=image_embeds, text_model_output=text_outputs, vision_model_output=vision_outputs, - ) \ No newline at end of file + ) diff --git a/videollava/model/multimodal_encoder/languagebind/video/modeling_video.py b/videollava/model/multimodal_encoder/languagebind/video/modeling_video.py index cb5c621..d7d13d5 100644 --- a/videollava/model/multimodal_encoder/languagebind/video/modeling_video.py +++ b/videollava/model/multimodal_encoder/languagebind/video/modeling_video.py @@ -9,7 +9,8 @@ from transformers import PreTrainedModel, add_start_docstrings from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from transformers.models.clip.modeling_clip import CLIPMLP, CLIPAttention, CLIPTextEmbeddings, CLIPVisionEmbeddings, \ - CLIPVisionModelWithProjection, CLIPTextModelWithProjection, _expand_mask, CLIPOutput, clip_loss + CLIPVisionModelWithProjection, CLIPTextModelWithProjection, CLIPOutput, clip_loss +from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings from .configuration_video import LanguageBindVideoConfig, CLIPVisionConfig, CLIPTextConfig @@ -499,7 +500,7 @@ def forward( # expand attention_mask if 
attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _expand_mask(attention_mask, hidden_states.dtype) + attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) encoder_outputs = self.encoder( inputs_embeds=hidden_states, @@ -1030,4 +1031,4 @@ def forward( image_embeds=image_embeds, text_model_output=text_outputs, vision_model_output=vision_outputs, - ) \ No newline at end of file + ) From 18a01998c537dee4a338b8445c26600a30f047cb Mon Sep 17 00:00:00 2001 From: Ben Schreck Date: Wed, 31 Jan 2024 15:21:51 -0800 Subject: [PATCH 02/21] fix autoconfig gen and typo --- .gitignore | 1 + modal_inference.py | 4 ++-- videollava/model/language_model/llava_llama.py | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..bee8a64 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +__pycache__ diff --git a/modal_inference.py b/modal_inference.py index e7222e7..ef964e0 100644 --- a/modal_inference.py +++ b/modal_inference.py @@ -141,7 +141,7 @@ def inference(self, video_path: str, inp: str = DEFAULT_PROMPT): else: tensor = video_tensor.to(self.model.device, dtype=torch.float16) - inp = ' '.join([DEFAULT_IMAGE_TOKEN] * model.get_video_tower().config.num_frames) + '\n' + inp + inp = ' '.join([DEFAULT_IMAGE_TOKEN] * self.model.get_video_tower().config.num_frames) + '\n' + inp self.conv.append_message(self.conv.roles[0], inp) self.conv.append_message(self.conv.roles[1], None) prompt = self.conv.get_prompt() @@ -176,6 +176,6 @@ def main(): print(output) if __name__ == "__main__": - video_llava = Cls.lookup("video-llava", 'VideoLlava')() + video_llava = Cls.lookup("updated-video-llava", 'VideoLlava')() video_path = '/volume/pika_water_city.mp4' print(video_llava.inference.remote(video_path=video_path, inp="describe what is going on in this video")) diff --git a/videollava/model/language_model/llava_llama.py b/videollava/model/language_model/llava_llama.py index 58ccb30..dd78d5d 100644 --- a/videollava/model/language_model/llava_llama.py +++ b/videollava/model/language_model/llava_llama.py @@ -107,5 +107,5 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_ _inputs['images'] = images return _inputs -AutoConfig.register("llava", LlavaConfig) +#AutoConfig.register("llava", LlavaConfig) AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM) From 3d8d0e4dd12311c966a48e72d57b533bc2012878 Mon Sep 17 00:00:00 2001 From: Ben Schreck Date: Thu, 1 Feb 2024 09:22:29 -0800 Subject: [PATCH 03/21] fix typos --- modal_inference.py | 2 +- videollava/model/language_model/mpt/hf_prefixlm_converter.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/modal_inference.py b/modal_inference.py index ef964e0..e237f25 100644 --- a/modal_inference.py +++ b/modal_inference.py @@ -145,7 +145,7 @@ def inference(self, video_path: str, inp: str = DEFAULT_PROMPT): self.conv.append_message(self.conv.roles[0], inp) self.conv.append_message(self.conv.roles[1], None) prompt = self.conv.get_prompt() - input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX['VIDEO'], return_tensors='pt').unsqueeze(0).cuda() + input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() stop_str = self.conv.sep if self.conv.sep_style != SeparatorStyle.TWO else self.conv.sep2 keywords = [stop_str] stopping_criteria = KeywordsStoppingCriteria(keywords, 
self.tokenizer, input_ids) diff --git a/videollava/model/language_model/mpt/hf_prefixlm_converter.py b/videollava/model/language_model/mpt/hf_prefixlm_converter.py index 0441b07..31b339a 100644 --- a/videollava/model/language_model/mpt/hf_prefixlm_converter.py +++ b/videollava/model/language_model/mpt/hf_prefixlm_converter.py @@ -195,7 +195,7 @@ def forward(self: BloomModel, input_ids: Optional[torch.LongTensor]=None, past_k attention_mask = attention_mask.to(hidden_states.device) alibi = self._build_alibi_tensor(batch_size=batch_size, query_length=seq_length, key_length=seq_length_with_past, dtype=hidden_states.dtype, device=hidden_states.device) # causal_mask = self._prepare_attn_mask(attention_mask, bidirectional_mask, input_shape=(batch_size, seq_length), past_key_values_length=past_key_values_length) - causal_mask = _prepare_4d_causal_attention_mask(attention_mask, bidirectional_mask, input_shape=(batch_size, seq_length), inpust_embeds=inputs_embeds, past_key_values_length=past_key_values_length) + causal_mask = _prepare_4d_causal_attention_mask(attention_mask, bidirectional_mask, input_shape=(batch_size, seq_length), inputs_embeds=inputs_embeds, past_key_values_length=past_key_values_length) causal_mask = causal_mask.bool() for (i, (block, layer_past)) in enumerate(zip(self.h, past_key_values)): if output_hidden_states: From 0d9a7c9c3db091297e4d84f2c2ea284b04166d4b Mon Sep 17 00:00:00 2001 From: Ben Schreck Date: Wed, 7 Feb 2024 17:41:36 -0800 Subject: [PATCH 04/21] modal-gradio-app --- gradio_web_server.py | 303 ++++++++++ modal_inference.py | 146 +++-- stub.py | 148 +++-- .../model/language_model/llava_llama.py | 2 +- .../mpt/hf_prefixlm_converter.py | 76 ++- .../languagebind/audio/modeling_audio.py | 7 +- .../languagebind/depth/modeling_depth.py | 7 +- .../languagebind/image/modeling_image.py | 7 +- .../languagebind/thermal/modeling_thermal.py | 7 +- .../languagebind/video/modeling_video.py | 7 +- videollava/serve/gradio_web_server.py | 530 ++++++++++-------- 11 files changed, 836 insertions(+), 404 deletions(-) create mode 100644 gradio_web_server.py diff --git a/gradio_web_server.py b/gradio_web_server.py new file mode 100644 index 0000000..9103bf5 --- /dev/null +++ b/gradio_web_server.py @@ -0,0 +1,303 @@ +import shutil +import os +import tempfile + +from modal import asgi_app, method, enter +from .stub import stub +from .stub import VOLUME_DIR, MODEL_CACHE, cls_dec, function_dec, REPO_HOME, EXAMPLES_PATH +from pathlib import Path + + +def save_image_to_local(image): + from PIL import Image + filename = os.path.join(VOLUME_DIR, next(tempfile._get_candidate_names()) + '.jpg') + image = Image.open(image) + image.save(filename) + # print(filename) + return filename + + +def save_video_to_local(video_path): + filename = os.path.join(VOLUME_DIR, next(tempfile._get_candidate_names()) + '.mp4') + shutil.copyfile(video_path, filename) + return filename + + +@cls_dec(gpu="any") +class VideoLlavaModel: + @enter() + def load_model(self): + import torch + from videollava.serve.gradio_utils import Chat + self.conv_mode = "llava_v1" + model_path = 'LanguageBind/Video-LLaVA-7B' + device = 'cuda' + load_8bit = True + load_4bit = False + self.dtype = torch.float16 + self.handler = Chat(model_path, conv_mode=self.conv_mode, load_8bit=load_8bit, load_4bit=load_4bit, device=device, cache_dir=MODEL_CACHE) + # self.handler.model.to(dtype=self.dtype) + + @method() + def generate(self, image1, video, textbox_in, first_run, state, state_, images_tensor): + from videollava.conversation import 
conv_templates, Conversation + import gradio as gr + from videollava.constants import DEFAULT_IMAGE_TOKEN + flag = 1 + if not textbox_in: + if len(state_.messages) > 0: + textbox_in = state_.messages[-1][1] + state_.messages.pop(-1) + flag = 0 + else: + return "Please enter instruction" + + print("textbox_in", textbox_in) + print("video path ", video) + print("image path ", image1) + image1 = image1 if image1 else "none" + video = video if video else "none" + print("video path after checking", video) + print("os.path.exists(video)", os.path.exists(video)) + if os.path.exists('/assets'): + print('assets dir', os.listdir('/assets')) + else: + print('no assets dir') + # assert not (os.path.exists(image1) and os.path.exists(video)) + + if type(state) is not Conversation: + state = conv_templates[self.conv_mode].copy() + state_ = conv_templates[self.conv_mode].copy() + images_tensor = [] + + first_run = False if len(state.messages) > 0 else True + + text_en_in = textbox_in.replace("picture", "image") + + # images_tensor = [[], []] + image_processor = self.handler.image_processor + if os.path.exists(image1) and not os.path.exists(video): + tensor = image_processor.preprocess(image1, return_tensors='pt')['pixel_values'][0] + # print(tensor.shape) + tensor = tensor.to(self.handler.model.device, dtype=self.dtype) + images_tensor.append(tensor) + video_processor = self.handler.video_processor + if not os.path.exists(image1) and os.path.exists(video): + tensor = video_processor(video, return_tensors='pt')['pixel_values'][0] + # print(tensor.shape) + tensor = tensor.to(self.handler.model.device, dtype=self.dtype) + images_tensor.append(tensor) + if os.path.exists(image1) and os.path.exists(video): + tensor = video_processor(video, return_tensors='pt')['pixel_values'][0] + # print(tensor.shape) + tensor = tensor.to(self.handler.model.device, dtype=self.dtype) + images_tensor.append(tensor) + + tensor = image_processor.preprocess(image1, return_tensors='pt')['pixel_values'][0] + # print(tensor.shape) + tensor = tensor.to(self.handler.model.device, dtype=self.dtype) + images_tensor.append(tensor) + + if os.path.exists(image1) and not os.path.exists(video): + text_en_in = DEFAULT_IMAGE_TOKEN + '\n' + text_en_in + elif not os.path.exists(image1) and os.path.exists(video): + text_en_in = ''.join([DEFAULT_IMAGE_TOKEN] * self.handler.model.get_video_tower().config.num_frames) + '\n' + text_en_in + elif os.path.exists(image1) and os.path.exists(video): + text_en_in = ''.join([DEFAULT_IMAGE_TOKEN] * self.handler.model.get_video_tower().config.num_frames) + '\n' + text_en_in + '\n' + DEFAULT_IMAGE_TOKEN + else: + print("WARNING: No image or video supplied") + + print(text_en_in) + text_en_out, state_ = self.handler.generate(images_tensor, text_en_in, first_run=first_run, state=state_) + state_.messages[-1] = (state_.roles[1], text_en_out) + + text_en_out = text_en_out.split('#')[0] + textbox_out = text_en_out + + show_images = "" + if os.path.exists(image1): + filename = save_image_to_local(image1) + show_images += f'' + if os.path.exists(video): + filename = save_video_to_local(video) + show_images += f'' + + if flag: + state.append_message(state.roles[0], textbox_in + "\n" + show_images) + state.append_message(state.roles[1], textbox_out) + + return (state, state_, state.to_gradio_chatbot(), False, gr.update(value=None, interactive=True), images_tensor, gr.update(value=image1 if os.path.exists(image1) else None, interactive=True), gr.update(value=video if os.path.exists(video) else None, interactive=True)) + + 
@method() + def clear_history(self, state, state_): + from videollava.conversation import conv_templates + import gradio as gr + state = conv_templates[self.conv_mode].copy() + state_ = conv_templates[self.conv_mode].copy() + return (gr.update(value=None, interactive=True), + gr.update(value=None, interactive=True), \ + gr.update(value=None, interactive=True), \ + True, state, state_, state.to_gradio_chatbot(), []) + + + + +def regenerate(state, state_): + state.messages.pop(-1) + state_.messages.pop(-1) + if len(state.messages) > 0: + return state, state_, state.to_gradio_chatbot(), False + return (state, state_, state.to_gradio_chatbot(), True) + + + + + +def build_gradio_interface(model): + import gradio as gr + from videollava.serve.gradio_utils import tos_markdown, learn_more_markdown, title_markdown, block_css + + # if not os.path.exists("temp"): + # os.makedirs("temp") + + + textbox = gr.Textbox( + show_label=False, placeholder="Enter text and press ENTER", container=False + ) + with gr.Blocks(title='Video-LLaVA๐Ÿš€', theme=gr.themes.Default(), css=block_css) as interface: + gr.Markdown(title_markdown) + state = gr.State() + state_ = gr.State() + first_run = gr.State() + images_tensor = gr.State() + + with gr.Row(): + with gr.Column(scale=3): + image1 = gr.Image(label="Input Image", type="filepath") + video = gr.Video(label="Input Video") + + #cur_dir = Path(REPO_HOME, 'videollava', 'serve') + cur_dir = EXAMPLES_PATH + gr.Examples( + examples=[ + [ + f"{cur_dir}/extreme_ironing.jpg", + "What is unusual about this image?", + ], + [ + f"{cur_dir}/waterview.jpg", + "What are the things I should be cautious about when I visit here?", + ], + [ + f"{cur_dir}/desert.jpg", + "If there are factual errors in the questions, point it out; if not, proceed answering the question. Whatโ€™s happening in the desert?", + ], + ], + inputs=[image1, textbox], + ) + + with gr.Column(scale=7): + chatbot = gr.Chatbot(label="Video-LLaVA", height=750) + with gr.Row(): + with gr.Column(scale=8): + textbox.render() + with gr.Column(scale=1, min_width=50): + submit_btn = gr.Button( + value="Send", variant="primary", interactive=True + ) + with gr.Row(elem_id="buttons") as button_row: + upvote_btn = gr.Button(value="๐Ÿ‘ Upvote", interactive=True) + downvote_btn = gr.Button(value="๐Ÿ‘Ž Downvote", interactive=True) + flag_btn = gr.Button(value="โš ๏ธ Flag", interactive=True) + # stop_btn = gr.Button(value="โน๏ธ Stop Generation", interactive=False) + regenerate_btn = gr.Button(value="๐Ÿ”„ Regenerate", interactive=True) + clear_btn = gr.Button(value="๐Ÿ—‘๏ธ Clear history", interactive=True) + + with gr.Row(): + gr.Examples( + examples=[ + [ + f"{cur_dir}/sample_img_22.png", + f"{cur_dir}/sample_demo_22.mp4", + "Are the instruments in the pictures used in the video?", + ], + [ + f"{cur_dir}/sample_img_13.png", + f"{cur_dir}/sample_demo_13.mp4", + "Does the flag in the image appear in the video?", + ], + [ + f"{cur_dir}/sample_img_8.png", + f"{cur_dir}/sample_demo_8.mp4", + "Are the image and the video depicting the same place?", + ], + ], + inputs=[image1, video, textbox], + ) + gr.Examples( + examples=[ + [ + f"{cur_dir}/sample_demo_1.mp4", + "Why is this video funny?", + ], + [ + f"{cur_dir}/sample_demo_3.mp4", + "Can you identify any safety hazards in this video?" 
+ ], + [ + f"{cur_dir}/sample_demo_9.mp4", + "Describe the video.", + ], + [ + f"{cur_dir}/sample_demo_22.mp4", + "Describe the activity in the video.", + ], + ], + inputs=[video, textbox], + ) + gr.Markdown(tos_markdown) + gr.Markdown(learn_more_markdown) + + submit_btn.click(model.generate.remote, [image1, video, textbox, first_run, state, state_, images_tensor], + [state, state_, chatbot, first_run, textbox, images_tensor, image1, video]) + + regenerate_btn.click(regenerate, [state, state_], [state, state_, chatbot, first_run]).then( + model.generate.remote, [image1, video, textbox, first_run, state, state_, images_tensor], [state, state_, chatbot, first_run, textbox, images_tensor, image1, video]) + + clear_btn.click(model.clear_history.remote, [state, state_], + [image1, video, textbox, first_run, state, state_, chatbot, images_tensor]) + return interface + + +@function_dec(gpu="any") +@asgi_app() +def fastapi_app(): + from gradio.routes import mount_gradio_app + import fastapi.staticfiles + from fastapi import FastAPI + app = FastAPI() + + model = VideoLlavaModel() + # interface = gr.Interface( + # fn=classifier.predict.remote, + # inputs=gr.Image(shape=(224, 224)), + # outputs="label", + # examples=create_demo_examples(), + # css="/assets/index.css", + # ) + # @app.get("/inference") + # async def inference(video_path: str, prompt: str): + # return model.generate.remote(video_path, prompt) + + + app.mount("/assets", fastapi.staticfiles.StaticFiles(directory="/assets")) + return mount_gradio_app( + app=app, + blocks=build_gradio_interface(model), + path="/gradio", + ) +# app = gr.mount_gradio_app(app, demo, path="/") +# demo.launch() + +# uvicorn videollava.serve.gradio_web_server:app +# python -m videollava.serve.gradio_web_server diff --git a/modal_inference.py b/modal_inference.py index e237f25..a3dd4d7 100644 --- a/modal_inference.py +++ b/modal_inference.py @@ -2,67 +2,70 @@ Cls, method, enter, + web_endpoint, ) from pathlib import Path -from .stub import stub, cls_dec, MODELS_DIR, volume, HF_DATASETS_CACHE, REPO_HOME +from .stub import stub, cls_dec, function_dec, MODELS_DIR, volume, HF_DATASETS_CACHE, REPO_HOME, load_pretrained_from_cache LOCAL_VIDEOS_PATH = Path(REPO_HOME) / "downloaded_videos" DEFAULT_PROMPT = "describe what is going on in this video" -def load_pretrained_from_cache(load_4bit=True, load_8bit=False): - from videollava.utils import disable_torch_init - from transformers import AutoTokenizer, BitsAndBytesConfig - from videollava.model import LlavaLlamaForCausalLM - import torch - disable_torch_init() - - kwargs = { - "device_map": "auto", - "cache_dir": HF_DATASETS_CACHE, - } - video_llava_path = Path(MODELS_DIR) / 'Video-LLaVA-7B' - - vlp_exists = video_llava_path.exists() - if not vlp_exists: - video_llava_path.mkdir(exist_ok=True, parents=True) - - save = False - if not video_llava_path.exists() or len(list(video_llava_path.iterdir())) == 0: - save = True - print("Downloading model") - video_llava_path = 'LanguageBind/Video-LLaVA-7B' - - tokenizer_path = Path(MODELS_DIR) / 'tokenizer' - if not tokenizer_path.exists() or len(list(tokenizer_path.iterdir())) == 0: - print("Downloading tokenizer") - tokenizer_path = 'LanguageBind/Video-LLaVA-7B' - - if load_8bit: - kwargs['load_in_8bit'] = True - elif load_4bit: - kwargs['load_in_4bit'] = True - kwargs['quantization_config'] = BitsAndBytesConfig( - load_in_4bit=True, - bnb_4bit_compute_dtype=torch.float16, - bnb_4bit_use_double_quant=True, - bnb_4bit_quant_type='nf4' - ) - else: - kwargs['torch_dtype'] = 
torch.float16 - - tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=False, cache_dir=kwargs["cache_dir"]) - model = LlavaLlamaForCausalLM.from_pretrained(video_llava_path, low_cpu_mem_usage=True, **kwargs) - model.generation_config.do_sample = True - - if save: - # save to on-disk paths - video_llava_path = Path(MODELS_DIR) / 'Video-LLaVA-7B' - tokenizer_path = Path(MODELS_DIR) / 'tokenizer' - tokenizer.save_pretrained(str(tokenizer_path)) - model.save_pretrained(str(video_llava_path)) - return model, tokenizer +# def load_pretrained_from_cache(load_4bit=True, load_8bit=False): + # print("Loading pretrained model") + # from videollava.utils import disable_torch_init + # from transformers import AutoTokenizer, BitsAndBytesConfig + # from videollava.model import LlavaLlamaForCausalLM + # import torch + # disable_torch_init() + # print("imported") + + # kwargs = { + # "device_map": "auto", + # "cache_dir": HF_DATASETS_CACHE, + # } + # video_llava_path = Path(MODELS_DIR) / 'Video-LLaVA-7B' + + # vlp_exists = video_llava_path.exists() + # if not vlp_exists: + # video_llava_path.mkdir(exist_ok=True, parents=True) + + # save = False + # if not video_llava_path.exists() or len(list(video_llava_path.iterdir())) == 0: + # save = True + # print("Downloading model") + # video_llava_path = 'LanguageBind/Video-LLaVA-7B' + + # tokenizer_path = Path(MODELS_DIR) / 'tokenizer' + # if not tokenizer_path.exists() or len(list(tokenizer_path.iterdir())) == 0: + # print("Downloading tokenizer") + # tokenizer_path = 'LanguageBind/Video-LLaVA-7B' + + # if load_8bit: + # kwargs['load_in_8bit'] = True + # elif load_4bit: + # kwargs['load_in_4bit'] = True + # kwargs['quantization_config'] = BitsAndBytesConfig( + # load_in_4bit=True, + # bnb_4bit_compute_dtype=torch.float16, + # bnb_4bit_use_double_quant=True, + # bnb_4bit_quant_type='nf4' + # ) + # else: + # kwargs['torch_dtype'] = torch.float16 + + # tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=False, cache_dir=kwargs["cache_dir"]) + # model = LlavaLlamaForCausalLM.from_pretrained(video_llava_path, low_cpu_mem_usage=True, **kwargs) + # model.generation_config.do_sample = True + + # if save: + # # save to on-disk paths + # video_llava_path = Path(MODELS_DIR) / 'Video-LLaVA-7B' + # tokenizer_path = Path(MODELS_DIR) / 'tokenizer' + # tokenizer.save_pretrained(str(tokenizer_path)) + # model.save_pretrained(str(video_llava_path)) + # return model, tokenizer def prepare_processor(model, device="cuda"): import torch @@ -98,12 +101,16 @@ def prepare_special_tokens(model, tokenizer): @cls_dec(container_idle_timeout=30, gpu='L4') class VideoLlava: + def __init__(self, device='cuda'): + self.device = device + self.model = None # TODO when they fix #@build() @enter() def load_model(self): - self.model, self.tokenizer = load_pretrained_from_cache() - self.processor = prepare_processor(self.model) + self.model, self.tokenizer = load_pretrained_from_cache(load_4bit=False, load_8bit=True) + print("got model") + self.processor = prepare_processor(self.model, device=self.device) self.video_processor = self.processor['video'] def prepare_conv(self): @@ -119,6 +126,7 @@ def inference(self, video_path: str, inp: str = DEFAULT_PROMPT): from videollava.conversation import SeparatorStyle import requests from io import BytesIO + print('preparing conv') self.prepare_conv() if video_path.startswith("http"): @@ -135,6 +143,7 @@ def inference(self, video_path: str, inp: str = DEFAULT_PROMPT): volume.reload() if not Path(video_path).exists(): raise 
FileNotFoundError(f"Video {video_path} not found") + print('processing video') video_tensor = self.video_processor(video_path, return_tensors='pt')['pixel_values'] if type(video_tensor) is list: tensor = [video.to(self.model.device, dtype=torch.float16) for video in video_tensor] @@ -149,7 +158,9 @@ def inference(self, video_path: str, inp: str = DEFAULT_PROMPT): stop_str = self.conv.sep if self.conv.sep_style != SeparatorStyle.TWO else self.conv.sep2 keywords = [stop_str] stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids) + print(input_ids.shape, tensor.shape) + import torch import time begin = time.time() with torch.inference_mode(): @@ -168,14 +179,33 @@ def inference(self, video_path: str, inp: str = DEFAULT_PROMPT): print(output) return output +@function_dec(gpu='L4') +@web_endpoint() +def inference(video_path: str, inp: str): + video_llava = VideoLlava() + if not hasattr(video_llava, 'model') or video_llava.model is None: + print('loading model') + video_llava.load_model() + print('model loaded') + output = VideoLlava().inference(video_path, inp) + print(output) + return output + @stub.local_entrypoint() def main(): prompt = "describe what is going on in this video" video_path = '/volume/pika_water_city.mp4' - output = VideoLlava().inference.remote(video_path, prompt) + input_ids, tensor, stopping_criteria = VideoLlava().get_inputs.remote(video_path, prompt) + output = VideoLlava().inference.remote(input_ids, tensor, stopping_criteria) print(output) if __name__ == "__main__": - video_llava = Cls.lookup("updated-video-llava", 'VideoLlava')() + video_llava = Cls.lookup("updated-video-llava-ephemeral", 'VideoLlava')() video_path = '/volume/pika_water_city.mp4' - print(video_llava.inference.remote(video_path=video_path, inp="describe what is going on in this video")) + #print(video_llava.inference.remote(video_path=video_path, inp="describe what is going on in this video")) + + prompt = "describe what is going on in this video" + video_path = '/volume/pika_water_city.mp4' + output = video_llava.inference.remote(video_path, prompt) + print(output) + diff --git a/stub.py b/stub.py index a1bf8be..5ccefd1 100644 --- a/stub.py +++ b/stub.py @@ -1,18 +1,28 @@ -from modal import Volume, Image, Stub, Mount, Secret +from modal import Volume, Image, Stub, Mount, Secret, build from pathlib import Path REPO_HOME = "/app" VOLUME_DIR = "/volume" MODELS_DIR = "/root" HF_DATASETS_CACHE = str(Path(VOLUME_DIR) / "hf_datasets_cache") -mounts = [Mount.from_local_dir("./ai_video_editor/updated_video_llava", remote_path=REPO_HOME)] +MODEL_CACHE = Path(VOLUME_DIR, "models") +assets_path = Path(__file__).parent / "assets" +local_examples_path = Path(__file__).parent / "videollava" / "serve" / "examples" +EXAMPLES_PATH = "/assets/examples" +mounts = [ + Mount.from_local_dir("./ai_video_editor/updated_video_llava", remote_path=REPO_HOME), + Mount.from_local_dir(assets_path, remote_path="/assets"), + Mount.from_local_dir(local_examples_path, remote_path=EXAMPLES_PATH), +] volume = Volume.persisted("video-llava-vol") volumes = {VOLUME_DIR: volume} stub = Stub("updated-video-llava", mounts=mounts, volumes=volumes, secrets=[Secret.from_dotenv()]) + image = ( Image.from_registry( - "nvidia/cuda:12.1.1-cudnn8-devel-ubuntu22.04", add_python="3.10" + "nvidia/cuda:12.1.1-cudnn8-devel-ubuntu22.04", add_python="3.11" ) + #Image.debian_slim() .apt_install( "git", "curl", @@ -28,54 +38,63 @@ ) .pip_install( - "torch==2.1.2", - "transformers==4.37.2", - "bitsandbytes==0.42.0", + # "torch==2.1.2", 
+ # "transformers==4.37.2", + # "bitsandbytes==0.42.0", + "torch==2.0.1", "torchvision==0.15.2", + "transformers==4.31.0", "tokenizers>=0.12.1,<0.14", "sentencepiece==0.1.99", "shortuuid", + "accelerate==0.21.0", "peft==0.4.0", "bitsandbytes==0.41.0", + "pydantic<2,>=1", "markdown2[all]", "numpy", "scikit-learn==1.2.2", + "requests", "httpx==0.24.0", "uvicorn", "fastapi", + "einops==0.6.1", "einops-exts==0.0.4", "timm==0.6.13", + "tensorboardX==2.6.2.2", "gradio==3.37.0", "gradio_client==0.7.0", + "deepspeed==0.9.5", "ninja", "wandb", + "wheel", gpu="any", ) .run_commands( "python -m bitsandbytes", - gpu="A10G" + gpu="any" ) - .pip_install( - "torchvision>=0.15.2", - #"tokenizers>=0.12.1,<0.14", - "sentencepiece==0.1.99", - "shortuuid", - "accelerate==0.21.0", - "peft==0.4.0", - "pydantic<2,>=1", - "markdown2[all]", - "numpy", - "scikit-learn==1.2.2", - "gradio==3.37.0", - "gradio_client==0.7.0", - "requests", - "httpx==0.24.0", - "uvicorn", - "fastapi", - #"einops==0.6.1", - "einops-exts==0.0.4", - "timm==0.6.13", - "deepspeed==0.9.5", - "ninja", - "wandb", - "tensorboardX==2.6.2.2", + # .pip_install( + # "torchvision>=0.15.2", + # #"tokenizers>=0.12.1,<0.14", + # "sentencepiece==0.1.99", + # "shortuuid", + # "accelerate==0.21.0", + # "peft==0.4.0", + # "pydantic<2,>=1", + # "markdown2[all]", + # "numpy", + # "scikit-learn==1.2.2", + # "gradio==3.37.0", + # "gradio_client==0.7.0", + # "requests", + # "httpx==0.24.0", + # "uvicorn", + # "fastapi", + # #"einops==0.6.1", + # "einops-exts==0.0.4", + # "timm==0.6.13", + # "deepspeed==0.9.5", + # "ninja", + # "wandb", + # "tensorboardX==2.6.2.2", - "tenacity", - "torch==2.1.2", - "wheel", - gpu="any", - ) + # "tenacity", + # "torch==2.1.2", + # "wheel", + # gpu="any", + # ) .run_commands("pip install flash-attn --no-build-isolation", gpu="any") .env({"PYTHONPATH": REPO_HOME, "HF_DATASETS_CACHE": HF_DATASETS_CACHE}) .pip_install( "decord", "opencv-python", - # TODO try removing or upgrading this to a version "git+https://github.com/facebookresearch/pytorchvideo.git@28fe037d212663c6a24f373b94cc5d478c8c1a1d", gpu="any", ) + #.run_function(load_pretrained_from_cache, gpu="any") ) def function_dec(**extras): @@ -94,3 +113,58 @@ def cls_dec(**extras): checkpointing_enabled=True, # Enable memory checkpointing for faster cold starts. 
**extras, ) + +def load_pretrained_from_cache(load_4bit=True, load_8bit=False): + print("Loading pretrained model") + from videollava.utils import disable_torch_init + from transformers import AutoTokenizer, BitsAndBytesConfig + from videollava.model import LlavaLlamaForCausalLM + import torch + disable_torch_init() + print("imported") + + kwargs = { + "device_map": "auto", + "cache_dir": HF_DATASETS_CACHE, + } + video_llava_path = Path(MODELS_DIR) / 'Video-LLaVA-7B' + + vlp_exists = video_llava_path.exists() + if not vlp_exists: + video_llava_path.mkdir(exist_ok=True, parents=True) + + save = False + if not video_llava_path.exists() or len(list(video_llava_path.iterdir())) == 0: + save = True + print("Downloading model") + video_llava_path = 'LanguageBind/Video-LLaVA-7B' + + tokenizer_path = Path(MODELS_DIR) / 'tokenizer' + if not tokenizer_path.exists() or len(list(tokenizer_path.iterdir())) == 0: + print("Downloading tokenizer") + tokenizer_path = 'LanguageBind/Video-LLaVA-7B' + + if load_8bit: + kwargs['load_in_8bit'] = True + elif load_4bit: + kwargs['load_in_4bit'] = True + kwargs['quantization_config'] = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4' + ) + else: + kwargs['torch_dtype'] = torch.float16 + + tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=False, cache_dir=kwargs["cache_dir"]) + model = LlavaLlamaForCausalLM.from_pretrained(video_llava_path, low_cpu_mem_usage=True, **kwargs) + model.generation_config.do_sample = True + + if save: + # save to on-disk paths + video_llava_path = Path(MODELS_DIR) / 'Video-LLaVA-7B' + tokenizer_path = Path(MODELS_DIR) / 'tokenizer' + tokenizer.save_pretrained(str(tokenizer_path)) + model.save_pretrained(str(video_llava_path)) + return model, tokenizer diff --git a/videollava/model/language_model/llava_llama.py b/videollava/model/language_model/llava_llama.py index dd78d5d..58ccb30 100644 --- a/videollava/model/language_model/llava_llama.py +++ b/videollava/model/language_model/llava_llama.py @@ -107,5 +107,5 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_ _inputs['images'] = images return _inputs -#AutoConfig.register("llava", LlavaConfig) +AutoConfig.register("llava", LlavaConfig) AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM) diff --git a/videollava/model/language_model/mpt/hf_prefixlm_converter.py b/videollava/model/language_model/mpt/hf_prefixlm_converter.py index 31b339a..8c1a648 100644 --- a/videollava/model/language_model/mpt/hf_prefixlm_converter.py +++ b/videollava/model/language_model/mpt/hf_prefixlm_converter.py @@ -12,16 +12,16 @@ from typing import Any, Dict, List, Optional, Tuple, Union import torch from transformers.models.bloom.modeling_bloom import BaseModelOutputWithPastAndCrossAttentions, BloomForCausalLM, BloomModel, CausalLMOutputWithCrossAttentions, CrossEntropyLoss -#from transformers.models.bloom.modeling_bloom import _expand_mask as _expand_mask_bloom -#from transformers.models.bloom.modeling_bloom import _make_causal_mask as _make_causal_mask_bloom +from transformers.models.bloom.modeling_bloom import _expand_mask as _expand_mask_bloom +from transformers.models.bloom.modeling_bloom import _make_causal_mask as _make_causal_mask_bloom from transformers.models.bloom.modeling_bloom import logging from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoForCausalLM from 
transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM from transformers.models.gptj.modeling_gptj import GPTJForCausalLM from transformers.models.opt.modeling_opt import OPTForCausalLM -#from transformers.models.opt.modeling_opt import _expand_mask as _expand_mask_opt -#from transformers.models.opt.modeling_opt import _make_causal_mask as _make_causal_mask_opt +from transformers.models.opt.modeling_opt import _expand_mask as _expand_mask_opt +from transformers.models.opt.modeling_opt import _make_causal_mask as _make_causal_mask_opt logger = logging.get_logger(__name__) _SUPPORTED_GPT_MODELS = (GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM) CAUSAL_GPT_TYPES = Union[GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM] @@ -123,19 +123,19 @@ def _convert_bloom_causal_lm_to_prefix_lm(model: BloomForCausalLM) -> BloomForCa assert isinstance(model, BloomForCausalLM) assert model.config.add_cross_attention == False, 'Only supports BLOOM decoder-only models' - # def _prepare_attn_mask(self: BloomModel, attention_mask: torch.Tensor, bidirectional_mask: Optional[torch.Tensor], input_shape: Tuple[int, int], past_key_values_length: int) -> torch.BoolTensor: - # combined_attention_mask = None - # device = attention_mask.device - # (_, src_length) = input_shape - # if src_length > 1: - # combined_attention_mask = _make_causal_mask_bloom(input_shape, device=device, past_key_values_length=past_key_values_length) - # if bidirectional_mask is not None: - # assert attention_mask.shape == bidirectional_mask.shape - # expanded_bidirectional_mask = _expand_mask_bloom(bidirectional_mask, tgt_length=src_length) - # combined_attention_mask = torch.logical_and(combined_attention_mask, expanded_bidirectional_mask) - # expanded_attn_mask = _expand_mask_bloom(attention_mask, tgt_length=src_length) - # combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask - # return combined_attention_mask + def _prepare_attn_mask(self: BloomModel, attention_mask: torch.Tensor, bidirectional_mask: Optional[torch.Tensor], input_shape: Tuple[int, int], past_key_values_length: int) -> torch.BoolTensor: + combined_attention_mask = None + device = attention_mask.device + (_, src_length) = input_shape + if src_length > 1: + combined_attention_mask = _make_causal_mask_bloom(input_shape, device=device, past_key_values_length=past_key_values_length) + if bidirectional_mask is not None: + assert attention_mask.shape == bidirectional_mask.shape + expanded_bidirectional_mask = _expand_mask_bloom(bidirectional_mask, tgt_length=src_length) + combined_attention_mask = torch.logical_and(combined_attention_mask, expanded_bidirectional_mask) + expanded_attn_mask = _expand_mask_bloom(attention_mask, tgt_length=src_length) + combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask + return combined_attention_mask def _build_alibi_tensor(self: BloomModel, batch_size: int, query_length: int, key_length: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor: num_heads = self.config.n_head @@ -194,9 +194,7 @@ def forward(self: BloomModel, input_ids: Optional[torch.LongTensor]=None, past_k else: attention_mask = attention_mask.to(hidden_states.device) alibi = self._build_alibi_tensor(batch_size=batch_size, query_length=seq_length, key_length=seq_length_with_past, dtype=hidden_states.dtype, device=hidden_states.device) - # causal_mask = 
self._prepare_attn_mask(attention_mask, bidirectional_mask, input_shape=(batch_size, seq_length), past_key_values_length=past_key_values_length) - causal_mask = _prepare_4d_causal_attention_mask(attention_mask, bidirectional_mask, input_shape=(batch_size, seq_length), inputs_embeds=inputs_embeds, past_key_values_length=past_key_values_length) - causal_mask = causal_mask.bool() + causal_mask = self._prepare_attn_mask(attention_mask, bidirectional_mask, input_shape=(batch_size, seq_length), past_key_values_length=past_key_values_length) for (i, (block, layer_past)) in enumerate(zip(self.h, past_key_values)): if output_hidden_states: hst = (hidden_states,) @@ -227,7 +225,7 @@ def custom_forward(*inputs): if not return_dict: return tuple((v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)) return BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=hidden_states, past_key_values=presents, hidden_states=all_hidden_states, attentions=all_self_attentions) - # setattr(model.transformer, '_prepare_attn_mask', MethodType(_prepare_attn_mask, model.transformer)) + setattr(model.transformer, '_prepare_attn_mask', MethodType(_prepare_attn_mask, model.transformer)) setattr(model.transformer, '_build_alibi_tensor', MethodType(_build_alibi_tensor, model.transformer)) setattr(model.transformer, 'forward', MethodType(forward, model.transformer)) KeyValueT = Tuple[torch.Tensor, torch.Tensor] @@ -284,23 +282,23 @@ def _convert_opt_causal_lm_to_prefix_lm(model: OPTForCausalLM) -> OPTForCausalLM setattr(model, '_original_generate', getattr(model, 'generate')) model.model.decoder.bidirectional_mask = None - # def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): - # combined_attention_mask = None - # if input_shape[-1] > 1: - # if self.bidirectional_mask == 'g': - # (bsz, src_length) = input_shape - # combined_attention_mask = torch.zeros((bsz, 1, src_length, src_length + past_key_values_length), dtype=inputs_embeds.dtype, device=inputs_embeds.device) - # else: - # combined_attention_mask = _make_causal_mask_opt(input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length).to(inputs_embeds.device) - # if self.bidirectional_mask is not None: - # assert attention_mask.shape == self.bidirectional_mask.shape - # expanded_bidirectional_mask = _expand_mask_opt(self.bidirectional_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(inputs_embeds.device) - # combined_attention_mask = torch.maximum(expanded_bidirectional_mask, combined_attention_mask) - # if attention_mask is not None: - # expanded_attn_mask = _expand_mask_opt(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(inputs_embeds.device) - # combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - # return combined_attention_mask - setattr(model.model.decoder, '_prepare_decoder_attention_mask', MethodType(_prepare_4d_causal_attention_mask, model.model.decoder)) + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): + combined_attention_mask = None + if input_shape[-1] > 1: + if self.bidirectional_mask == 'g': + (bsz, src_length) = input_shape + combined_attention_mask = torch.zeros((bsz, 1, src_length, src_length + past_key_values_length), dtype=inputs_embeds.dtype, device=inputs_embeds.device) + else: + combined_attention_mask = _make_causal_mask_opt(input_shape, inputs_embeds.dtype, 
past_key_values_length=past_key_values_length).to(inputs_embeds.device) + if self.bidirectional_mask is not None: + assert attention_mask.shape == self.bidirectional_mask.shape + expanded_bidirectional_mask = _expand_mask_opt(self.bidirectional_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(inputs_embeds.device) + combined_attention_mask = torch.maximum(expanded_bidirectional_mask, combined_attention_mask) + if attention_mask is not None: + expanded_attn_mask = _expand_mask_opt(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(inputs_embeds.device) + combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + return combined_attention_mask + setattr(model.model.decoder, '_prepare_decoder_attention_mask', MethodType(_prepare_decoder_attention_mask, model.model.decoder)) def forward(self: OPTForCausalLM, input_ids: Optional[torch.LongTensor]=None, attention_mask: Optional[torch.Tensor]=None, bidirectional_mask: Optional[torch.ByteTensor]=None, head_mask: Optional[torch.Tensor]=None, past_key_values: Optional[List[torch.FloatTensor]]=None, inputs_embeds: Optional[torch.FloatTensor]=None, labels: Optional[torch.LongTensor]=None, use_cache: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, return_dict: Optional[bool]=None): @@ -414,4 +412,4 @@ def add_bidirectional_mask_if_missing(batch: Dict[str, Any]): elif 'labels' in batch and 'attention_mask' in batch: batch['bidirectional_mask'] = torch.logical_and(torch.eq(batch['attention_mask'], 1), torch.eq(batch['labels'], -100)).type_as(batch['attention_mask']) else: - raise KeyError('No bidirectional_mask in batch and not sure how to construct one.') + raise KeyError('No bidirectional_mask in batch and not sure how to construct one.') \ No newline at end of file diff --git a/videollava/model/multimodal_encoder/languagebind/audio/modeling_audio.py b/videollava/model/multimodal_encoder/languagebind/audio/modeling_audio.py index c45eac0..908ab43 100644 --- a/videollava/model/multimodal_encoder/languagebind/audio/modeling_audio.py +++ b/videollava/model/multimodal_encoder/languagebind/audio/modeling_audio.py @@ -9,8 +9,7 @@ from transformers import PreTrainedModel, add_start_docstrings from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from transformers.models.clip.modeling_clip import CLIPMLP, CLIPAttention, CLIPTextEmbeddings, CLIPVisionEmbeddings, \ - CLIPVisionModelWithProjection, CLIPTextModelWithProjection, CLIPOutput, clip_loss -from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask + CLIPVisionModelWithProjection, CLIPTextModelWithProjection, _expand_mask, CLIPOutput, clip_loss from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings from .configuration_audio import LanguageBindAudioConfig, CLIPVisionConfig, CLIPTextConfig @@ -500,7 +499,7 @@ def forward( # expand attention_mask if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) + attention_mask = _expand_mask(attention_mask, hidden_states.dtype) encoder_outputs = self.encoder( inputs_embeds=hidden_states, @@ -1028,4 +1027,4 @@ def forward( image_embeds=image_embeds, text_model_output=text_outputs, vision_model_output=vision_outputs, - ) + ) \ No newline at end of file diff --git 
a/videollava/model/multimodal_encoder/languagebind/depth/modeling_depth.py b/videollava/model/multimodal_encoder/languagebind/depth/modeling_depth.py index ac304af..849eade 100644 --- a/videollava/model/multimodal_encoder/languagebind/depth/modeling_depth.py +++ b/videollava/model/multimodal_encoder/languagebind/depth/modeling_depth.py @@ -9,8 +9,7 @@ from transformers import PreTrainedModel, add_start_docstrings from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from transformers.models.clip.modeling_clip import CLIPMLP, CLIPAttention, CLIPTextEmbeddings, CLIPVisionEmbeddings, \ - CLIPVisionModelWithProjection, CLIPTextModelWithProjection, CLIPOutput, clip_loss -from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask + CLIPVisionModelWithProjection, CLIPTextModelWithProjection, _expand_mask, CLIPOutput, clip_loss from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings from .configuration_depth import LanguageBindDepthConfig, CLIPVisionConfig, CLIPTextConfig @@ -500,7 +499,7 @@ def forward( # expand attention_mask if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) + attention_mask = _expand_mask(attention_mask, hidden_states.dtype) encoder_outputs = self.encoder( inputs_embeds=hidden_states, @@ -1028,4 +1027,4 @@ def forward( image_embeds=image_embeds, text_model_output=text_outputs, vision_model_output=vision_outputs, - ) + ) \ No newline at end of file diff --git a/videollava/model/multimodal_encoder/languagebind/image/modeling_image.py b/videollava/model/multimodal_encoder/languagebind/image/modeling_image.py index 2d4ef0e..e95ac47 100644 --- a/videollava/model/multimodal_encoder/languagebind/image/modeling_image.py +++ b/videollava/model/multimodal_encoder/languagebind/image/modeling_image.py @@ -9,8 +9,7 @@ from transformers import PreTrainedModel, add_start_docstrings from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from transformers.models.clip.modeling_clip import CLIPMLP, CLIPAttention, CLIPTextEmbeddings, CLIPVisionEmbeddings, \ - CLIPVisionModelWithProjection, CLIPTextModelWithProjection, CLIPOutput, clip_loss -from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask + CLIPVisionModelWithProjection, CLIPTextModelWithProjection, _expand_mask, CLIPOutput, clip_loss from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings from .configuration_image import LanguageBindImageConfig, CLIPVisionConfig, CLIPTextConfig @@ -500,7 +499,7 @@ def forward( # expand attention_mask if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) + attention_mask = _expand_mask(attention_mask, hidden_states.dtype) encoder_outputs = self.encoder( inputs_embeds=hidden_states, @@ -1028,4 +1027,4 @@ def forward( image_embeds=image_embeds, text_model_output=text_outputs, vision_model_output=vision_outputs, - ) + ) \ No newline at end of file diff --git a/videollava/model/multimodal_encoder/languagebind/thermal/modeling_thermal.py b/videollava/model/multimodal_encoder/languagebind/thermal/modeling_thermal.py index 60272e6..f0323b3 100644 --- a/videollava/model/multimodal_encoder/languagebind/thermal/modeling_thermal.py +++ 
b/videollava/model/multimodal_encoder/languagebind/thermal/modeling_thermal.py @@ -9,8 +9,7 @@ from transformers import PreTrainedModel, add_start_docstrings from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from transformers.models.clip.modeling_clip import CLIPMLP, CLIPAttention, CLIPTextEmbeddings, CLIPVisionEmbeddings, \ - CLIPVisionModelWithProjection, CLIPTextModelWithProjection, CLIPOutput, clip_loss -from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask + CLIPVisionModelWithProjection, CLIPTextModelWithProjection, _expand_mask, CLIPOutput, clip_loss from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings from .configuration_thermal import LanguageBindThermalConfig, CLIPVisionConfig, CLIPTextConfig @@ -500,7 +499,7 @@ def forward( # expand attention_mask if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) + attention_mask = _expand_mask(attention_mask, hidden_states.dtype) encoder_outputs = self.encoder( inputs_embeds=hidden_states, @@ -1028,4 +1027,4 @@ def forward( image_embeds=image_embeds, text_model_output=text_outputs, vision_model_output=vision_outputs, - ) + ) \ No newline at end of file diff --git a/videollava/model/multimodal_encoder/languagebind/video/modeling_video.py b/videollava/model/multimodal_encoder/languagebind/video/modeling_video.py index d7d13d5..cb5c621 100644 --- a/videollava/model/multimodal_encoder/languagebind/video/modeling_video.py +++ b/videollava/model/multimodal_encoder/languagebind/video/modeling_video.py @@ -9,8 +9,7 @@ from transformers import PreTrainedModel, add_start_docstrings from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from transformers.models.clip.modeling_clip import CLIPMLP, CLIPAttention, CLIPTextEmbeddings, CLIPVisionEmbeddings, \ - CLIPVisionModelWithProjection, CLIPTextModelWithProjection, CLIPOutput, clip_loss -from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask + CLIPVisionModelWithProjection, CLIPTextModelWithProjection, _expand_mask, CLIPOutput, clip_loss from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings from .configuration_video import LanguageBindVideoConfig, CLIPVisionConfig, CLIPTextConfig @@ -500,7 +499,7 @@ def forward( # expand attention_mask if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) + attention_mask = _expand_mask(attention_mask, hidden_states.dtype) encoder_outputs = self.encoder( inputs_embeds=hidden_states, @@ -1031,4 +1030,4 @@ def forward( image_embeds=image_embeds, text_model_output=text_outputs, vision_model_output=vision_outputs, - ) + ) \ No newline at end of file diff --git a/videollava/serve/gradio_web_server.py b/videollava/serve/gradio_web_server.py index dfce85a..fc7d7de 100644 --- a/videollava/serve/gradio_web_server.py +++ b/videollava/serve/gradio_web_server.py @@ -1,249 +1,281 @@ -import shutil -import subprocess - -import torch -import gradio as gr -from fastapi import FastAPI -import os -from PIL import Image -import tempfile -from decord import VideoReader, cpu -from transformers import TextStreamer - -from videollava.constants import DEFAULT_IMAGE_TOKEN -from videollava.conversation import conv_templates, SeparatorStyle, Conversation -from 
videollava.serve.gradio_utils import Chat, tos_markdown, learn_more_markdown, title_markdown, block_css - - - -def save_image_to_local(image): - filename = os.path.join('temp', next(tempfile._get_candidate_names()) + '.jpg') - image = Image.open(image) - image.save(filename) - # print(filename) - return filename - - -def save_video_to_local(video_path): - filename = os.path.join('temp', next(tempfile._get_candidate_names()) + '.mp4') - shutil.copyfile(video_path, filename) - return filename - - -def generate(image1, video, textbox_in, first_run, state, state_, images_tensor): - flag = 1 - if not textbox_in: - if len(state_.messages) > 0: - textbox_in = state_.messages[-1][1] - state_.messages.pop(-1) - flag = 0 - else: - return "Please enter instruction" - - image1 = image1 if image1 else "none" - video = video if video else "none" - # assert not (os.path.exists(image1) and os.path.exists(video)) - - if type(state) is not Conversation: - state = conv_templates[conv_mode].copy() - state_ = conv_templates[conv_mode].copy() - images_tensor = [] - - first_run = False if len(state.messages) > 0 else True - - text_en_in = textbox_in.replace("picture", "image") - - # images_tensor = [[], []] - image_processor = handler.image_processor - if os.path.exists(image1) and not os.path.exists(video): - tensor = image_processor.preprocess(image1, return_tensors='pt')['pixel_values'][0] - # print(tensor.shape) - tensor = tensor.to(handler.model.device, dtype=dtype) - images_tensor.append(tensor) - video_processor = handler.video_processor - if not os.path.exists(image1) and os.path.exists(video): - tensor = video_processor(video, return_tensors='pt')['pixel_values'][0] - # print(tensor.shape) - tensor = tensor.to(handler.model.device, dtype=dtype) - images_tensor.append(tensor) - if os.path.exists(image1) and os.path.exists(video): - tensor = video_processor(video, return_tensors='pt')['pixel_values'][0] - # print(tensor.shape) - tensor = tensor.to(handler.model.device, dtype=dtype) - images_tensor.append(tensor) - - tensor = image_processor.preprocess(image1, return_tensors='pt')['pixel_values'][0] - # print(tensor.shape) - tensor = tensor.to(handler.model.device, dtype=dtype) - images_tensor.append(tensor) - - if os.path.exists(image1) and not os.path.exists(video): - text_en_in = DEFAULT_IMAGE_TOKEN + '\n' + text_en_in - if not os.path.exists(image1) and os.path.exists(video): - text_en_in = ''.join([DEFAULT_IMAGE_TOKEN] * handler.model.get_video_tower().config.num_frames) + '\n' + text_en_in - if os.path.exists(image1) and os.path.exists(video): - text_en_in = ''.join([DEFAULT_IMAGE_TOKEN] * handler.model.get_video_tower().config.num_frames) + '\n' + text_en_in + '\n' + DEFAULT_IMAGE_TOKEN - # print(text_en_in) - text_en_out, state_ = handler.generate(images_tensor, text_en_in, first_run=first_run, state=state_) - state_.messages[-1] = (state_.roles[1], text_en_out) - - text_en_out = text_en_out.split('#')[0] - textbox_out = text_en_out - - show_images = "" - if os.path.exists(image1): - filename = save_image_to_local(image1) - show_images += f'' - if os.path.exists(video): - filename = save_video_to_local(video) - show_images += f'' - - if flag: - state.append_message(state.roles[0], textbox_in + "\n" + show_images) - state.append_message(state.roles[1], textbox_out) - - return (state, state_, state.to_gradio_chatbot(), False, gr.update(value=None, interactive=True), images_tensor, gr.update(value=image1 if os.path.exists(image1) else None, interactive=True), gr.update(value=video if 
os.path.exists(video) else None, interactive=True)) - - -def regenerate(state, state_): - state.messages.pop(-1) - state_.messages.pop(-1) - if len(state.messages) > 0: - return state, state_, state.to_gradio_chatbot(), False - return (state, state_, state.to_gradio_chatbot(), True) - - -def clear_history(state, state_): - state = conv_templates[conv_mode].copy() - state_ = conv_templates[conv_mode].copy() - return (gr.update(value=None, interactive=True), - gr.update(value=None, interactive=True), \ - gr.update(value=None, interactive=True), \ - True, state, state_, state.to_gradio_chatbot(), []) - - -conv_mode = "llava_v1" -model_path = 'LanguageBind/Video-LLaVA-7B' -cache_dir = './cache_dir' -device = 'cuda' -load_8bit = True -load_4bit = False -dtype = torch.float16 -handler = Chat(model_path, conv_mode=conv_mode, load_8bit=load_8bit, load_4bit=load_8bit, device=device, cache_dir=cache_dir) -# handler.model.to(dtype=dtype) -if not os.path.exists("temp"): - os.makedirs("temp") - -app = FastAPI() - - -textbox = gr.Textbox( - show_label=False, placeholder="Enter text and press ENTER", container=False -) -with gr.Blocks(title='Video-LLaVA🚀', theme=gr.themes.Default(), css=block_css) as demo: - gr.Markdown(title_markdown) - state = gr.State() - state_ = gr.State() - first_run = gr.State() - images_tensor = gr.State() - - with gr.Row(): - with gr.Column(scale=3): - image1 = gr.Image(label="Input Image", type="filepath") - video = gr.Video(label="Input Video") - - cur_dir = os.path.dirname(os.path.abspath(__file__)) - gr.Examples( - examples=[ - [ - f"{cur_dir}/examples/extreme_ironing.jpg", - "What is unusual about this image?", - ], - [ - f"{cur_dir}/examples/waterview.jpg", - "What are the things I should be cautious about when I visit here?", - ], - [ - f"{cur_dir}/examples/desert.jpg", - "If there are factual errors in the questions, point it out; if not, proceed answering the question. What’s happening in the desert?", - ], - ], - inputs=[image1, textbox], - ) - - with gr.Column(scale=7): - chatbot = gr.Chatbot(label="Video-LLaVA", bubble_full_width=True).style(height=750) - with gr.Row(): - with gr.Column(scale=8): - textbox.render() - with gr.Column(scale=1, min_width=50): - submit_btn = gr.Button( - value="Send", variant="primary", interactive=True - ) - with gr.Row(elem_id="buttons") as button_row: - upvote_btn = gr.Button(value="👍 Upvote", interactive=True) - downvote_btn = gr.Button(value="👎 Downvote", interactive=True) - flag_btn = gr.Button(value="⚠️ Flag", interactive=True) - # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False) - regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=True) - clear_btn = gr.Button(value="🗑️ Clear history", interactive=True) - - with gr.Row(): - gr.Examples( - examples=[ - [ - f"{cur_dir}/examples/sample_img_22.png", - f"{cur_dir}/examples/sample_demo_22.mp4", - "Are the instruments in the pictures used in the video?", - ], - [ - f"{cur_dir}/examples/sample_img_13.png", - f"{cur_dir}/examples/sample_demo_13.mp4", - "Does the flag in the image appear in the video?", - ], - [ - f"{cur_dir}/examples/sample_img_8.png", - f"{cur_dir}/examples/sample_demo_8.mp4", - "Are the image and the video depicting the same place?", - ], - ], - inputs=[image1, video, textbox], - ) - gr.Examples( - examples=[ - [ - f"{cur_dir}/examples/sample_demo_1.mp4", - "Why is this video funny?", - ], - [ - f"{cur_dir}/examples/sample_demo_3.mp4", - "Can you identify any safety hazards in this video?"
- ], - [ - f"{cur_dir}/examples/sample_demo_9.mp4", - "Describe the video.", - ], - [ - f"{cur_dir}/examples/sample_demo_22.mp4", - "Describe the activity in the video.", - ], - ], - inputs=[video, textbox], - ) - gr.Markdown(tos_markdown) - gr.Markdown(learn_more_markdown) - - submit_btn.click(generate, [image1, video, textbox, first_run, state, state_, images_tensor], - [state, state_, chatbot, first_run, textbox, images_tensor, image1, video]) - - regenerate_btn.click(regenerate, [state, state_], [state, state_, chatbot, first_run]).then( - generate, [image1, video, textbox, first_run, state, state_, images_tensor], [state, state_, chatbot, first_run, textbox, images_tensor, image1, video]) - - clear_btn.click(clear_history, [state, state_], - [image1, video, textbox, first_run, state, state_, chatbot, images_tensor]) - -# app = gr.mount_gradio_app(app, demo, path="/") -demo.launch() - -# uvicorn videollava.serve.gradio_web_server:app -# python -m videollava.serve.gradio_web_server +# import shutil +# import os +# import tempfile + +# from modal import asgi_app, method, enter +# from ...stub import VOLUME_DIR, MODEL_CACHE, cls_dec, function_dec + + +# def save_image_to_local(image): + # from PIL import Image + # filename = os.path.join(VOLUME_DIR, next(tempfile._get_candidate_names()) + '.jpg') + # image = Image.open(image) + # image.save(filename) + # # print(filename) + # return filename + + +# def save_video_to_local(video_path): + # filename = os.path.join(VOLUME_DIR, next(tempfile._get_candidate_names()) + '.mp4') + # shutil.copyfile(video_path, filename) + # return filename + + +# @cls_dec(gpu="any") +# class VideoLlavaModel: + # @enter() + # def load_model(self): + # import torch + # from videollava.serve.gradio_utils import Chat + # self.conv_mode = "llava_v1" + # model_path = 'LanguageBind/Video-LLaVA-7B' + # device = 'cuda' + # load_8bit = True + # load_4bit = False + # self.dtype = torch.float16 + # self.handler = Chat(model_path, conv_mode=self.conv_mode, load_8bit=load_8bit, load_4bit=load_4bit, device=device, cache_dir=MODEL_CACHE) + # # self.handler.model.to(dtype=self.dtype) + + # @method() + # def generate(self, image1, video, textbox_in, first_run, state, state_, images_tensor): + # from videollava.conversation import conv_templates, Conversation + # import gradio as gr + # from videollava.constants import DEFAULT_IMAGE_TOKEN + # flag = 1 + # if not textbox_in: + # if len(state_.messages) > 0: + # textbox_in = state_.messages[-1][1] + # state_.messages.pop(-1) + # flag = 0 + # else: + # return "Please enter instruction" + + # image1 = image1 if image1 else "none" + # video = video if video else "none" + # # assert not (os.path.exists(image1) and os.path.exists(video)) + + # if type(state) is not Conversation: + # state = conv_templates[self.conv_mode].copy() + # state_ = conv_templates[self.conv_mode].copy() + # images_tensor = [] + + # first_run = False if len(state.messages) > 0 else True + + # text_en_in = textbox_in.replace("picture", "image") + + # # images_tensor = [[], []] + # image_processor = self.handler.image_processor + # if os.path.exists(image1) and not os.path.exists(video): + # tensor = image_processor.preprocess(image1, return_tensors='pt')['pixel_values'][0] + # # print(tensor.shape) + # tensor = tensor.to(self.handler.model.device, dtype=self.dtype) + # images_tensor.append(tensor) + # video_processor = self.handler.video_processor + # if not os.path.exists(image1) and os.path.exists(video): + # tensor = video_processor(video, 
return_tensors='pt')['pixel_values'][0] + # # print(tensor.shape) + # tensor = tensor.to(self.handler.model.device, dtype=self.dtype) + # images_tensor.append(tensor) + # if os.path.exists(image1) and os.path.exists(video): + # tensor = video_processor(video, return_tensors='pt')['pixel_values'][0] + # # print(tensor.shape) + # tensor = tensor.to(self.handler.model.device, dtype=self.dtype) + # images_tensor.append(tensor) + + # tensor = image_processor.preprocess(image1, return_tensors='pt')['pixel_values'][0] + # # print(tensor.shape) + # tensor = tensor.to(self.handler.model.device, dtype=self.dtype) + # images_tensor.append(tensor) + + # if os.path.exists(image1) and not os.path.exists(video): + # text_en_in = DEFAULT_IMAGE_TOKEN + '\n' + text_en_in + # if not os.path.exists(image1) and os.path.exists(video): + # text_en_in = ''.join([DEFAULT_IMAGE_TOKEN] * self.handler.model.get_video_tower().config.num_frames) + '\n' + text_en_in + # if os.path.exists(image1) and os.path.exists(video): + # text_en_in = ''.join([DEFAULT_IMAGE_TOKEN] * self.handler.model.get_video_tower().config.num_frames) + '\n' + text_en_in + '\n' + DEFAULT_IMAGE_TOKEN + # # print(text_en_in) + # text_en_out, state_ = self.handler.generate(images_tensor, text_en_in, first_run=first_run, state=state_) + # state_.messages[-1] = (state_.roles[1], text_en_out) + + # text_en_out = text_en_out.split('#')[0] + # textbox_out = text_en_out + + # show_images = "" + # if os.path.exists(image1): + # filename = save_image_to_local(image1) + # show_images += f'' + # if os.path.exists(video): + # filename = save_video_to_local(video) + # show_images += f'' + + # if flag: + # state.append_message(state.roles[0], textbox_in + "\n" + show_images) + # state.append_message(state.roles[1], textbox_out) + + # return (state, state_, state.to_gradio_chatbot(), False, gr.update(value=None, interactive=True), images_tensor, gr.update(value=image1 if os.path.exists(image1) else None, interactive=True), gr.update(value=video if os.path.exists(video) else None, interactive=True)) + + # @method() + # def clear_history(self, state, state_): + # from videollava.conversation import conv_templates + # import gradio as gr + # state = conv_templates[self.conv_mode].copy() + # state_ = conv_templates[self.conv_mode].copy() + # return (gr.update(value=None, interactive=True), + # gr.update(value=None, interactive=True), \ + # gr.update(value=None, interactive=True), \ + # True, state, state_, state.to_gradio_chatbot(), []) + + + + +# def regenerate(state, state_): + # state.messages.pop(-1) + # state_.messages.pop(-1) + # if len(state.messages) > 0: + # return state, state_, state.to_gradio_chatbot(), False + # return (state, state_, state.to_gradio_chatbot(), True) + + + + + +# def build_gradio_interface(model): + # import gradio as gr + # from videollava.serve.gradio_utils import tos_markdown, learn_more_markdown, title_markdown, block_css + + # # if not os.path.exists("temp"): + # # os.makedirs("temp") + + + # textbox = gr.Textbox( + # show_label=False, placeholder="Enter text and press ENTER", container=False + # ) + # with gr.Blocks(title='Video-LLaVA🚀', theme=gr.themes.Default(), css=block_css) as interface: + # gr.Markdown(title_markdown) + # state = gr.State() + # state_ = gr.State() + # first_run = gr.State() + # images_tensor = gr.State() + + # with gr.Row(): + # with gr.Column(scale=3): + # image1 = gr.Image(label="Input Image", type="filepath") + # video = gr.Video(label="Input Video") + + # cur_dir =
os.path.dirname(os.path.abspath(__file__)) + # gr.Examples( + # examples=[ + # [ + # f"{cur_dir}/examples/extreme_ironing.jpg", + # "What is unusual about this image?", + # ], + # [ + # f"{cur_dir}/examples/waterview.jpg", + # "What are the things I should be cautious about when I visit here?", + # ], + # [ + # f"{cur_dir}/examples/desert.jpg", + # "If there are factual errors in the questions, point it out; if not, proceed answering the question. What’s happening in the desert?", + # ], + # ], + # inputs=[image1, textbox], + # ) + + # with gr.Column(scale=7): + # chatbot = gr.Chatbot(label="Video-LLaVA", bubble_full_width=True).style(height=750) + # with gr.Row(): + # with gr.Column(scale=8): + # textbox.render() + # with gr.Column(scale=1, min_width=50): + # submit_btn = gr.Button( + # value="Send", variant="primary", interactive=True + # ) + # with gr.Row(elem_id="buttons") as button_row: + # upvote_btn = gr.Button(value="👍 Upvote", interactive=True) + # downvote_btn = gr.Button(value="👎 Downvote", interactive=True) + # flag_btn = gr.Button(value="⚠️ Flag", interactive=True) + # # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False) + # regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=True) + # clear_btn = gr.Button(value="🗑️ Clear history", interactive=True) + + # with gr.Row(): + # gr.Examples( + # examples=[ + # [ + # f"{cur_dir}/examples/sample_img_22.png", + # f"{cur_dir}/examples/sample_demo_22.mp4", + # "Are the instruments in the pictures used in the video?", + # ], + # [ + # f"{cur_dir}/examples/sample_img_13.png", + # f"{cur_dir}/examples/sample_demo_13.mp4", + # "Does the flag in the image appear in the video?", + # ], + # [ + # f"{cur_dir}/examples/sample_img_8.png", + # f"{cur_dir}/examples/sample_demo_8.mp4", + # "Are the image and the video depicting the same place?", + # ], + # ], + # inputs=[image1, video, textbox], + # ) + # gr.Examples( + # examples=[ + # [ + # f"{cur_dir}/examples/sample_demo_1.mp4", + # "Why is this video funny?", + # ], + # [ + # f"{cur_dir}/examples/sample_demo_3.mp4", + # "Can you identify any safety hazards in this video?"
+ # ], + # [ + # f"{cur_dir}/examples/sample_demo_9.mp4", + # "Describe the video.", + # ], + # [ + # f"{cur_dir}/examples/sample_demo_22.mp4", + # "Describe the activity in the video.", + # ], + # ], + # inputs=[video, textbox], + # ) + # gr.Markdown(tos_markdown) + # gr.Markdown(learn_more_markdown) + + # submit_btn.click(model.generate.remote, [image1, video, textbox, first_run, state, state_, images_tensor], + # [state, state_, chatbot, first_run, textbox, images_tensor, image1, video]) + + # regenerate_btn.click(regenerate, [state, state_], [state, state_, chatbot, first_run]).then( + # model.generate.remote, [image1, video, textbox, first_run, state, state_, images_tensor], [state, state_, chatbot, first_run, textbox, images_tensor, image1, video]) + + # clear_btn.click(model.clear_history.remote, [state, state_], + # [image1, video, textbox, first_run, state, state_, chatbot, images_tensor]) + # return interface + + +# @function_dec(gpu="any") +# @asgi_app() +# def fastapi_app(): + # from gradio.routes import mount_gradio_app + # from fastapi import FastAPI + # app = FastAPI() + + # model = VideoLlavaModel() + # # interface = gr.Interface( + # # fn=classifier.predict.remote, + # # inputs=gr.Image(shape=(224, 224)), + # # outputs="label", + # # examples=create_demo_examples(), + # # css="/assets/index.css", + # # ) + # return mount_gradio_app( + # app=app, + # blocks=build_gradio_interface(model), + # path="/", + # ) +# # app = gr.mount_gradio_app(app, demo, path="/") +# # demo.launch() + +# # uvicorn videollava.serve.gradio_web_server:app +# # python -m videollava.serve.gradio_web_server From ed7ac2bad080cf2e928e19019b4e6b8fb1cd1665 Mon Sep 17 00:00:00 2001 From: bschreck Date: Fri, 9 Feb 2024 13:42:39 -0800 Subject: [PATCH 05/21] gradio app running locally --- gradio_web_server.py | 63 ++++++++++++++++++++++++++++---------------- stub.py | 2 +- 2 files changed, 41 insertions(+), 24 deletions(-) diff --git a/gradio_web_server.py b/gradio_web_server.py index 9103bf5..0dec8eb 100644 --- a/gradio_web_server.py +++ b/gradio_web_server.py @@ -3,9 +3,12 @@ import tempfile from modal import asgi_app, method, enter -from .stub import stub -from .stub import VOLUME_DIR, MODEL_CACHE, cls_dec, function_dec, REPO_HOME, EXAMPLES_PATH +from stub import stub, VOLUME_DIR, MODEL_CACHE, cls_dec, function_dec, REPO_HOME, EXAMPLES_PATH from pathlib import Path +VOLUME_DIR = "volume" +MODEL_CACHE = "models" +Path(VOLUME_DIR).mkdir(exist_ok=True, parents=True) +Path(MODEL_CACHE).mkdir(exist_ok=True, parents=True) def save_image_to_local(image): @@ -23,22 +26,24 @@ def save_video_to_local(video_path): return filename -@cls_dec(gpu="any") +#@cls_dec(gpu="any") class VideoLlavaModel: - @enter() + def __init__(self): + self.load_model() + #@enter() def load_model(self): import torch from videollava.serve.gradio_utils import Chat self.conv_mode = "llava_v1" model_path = 'LanguageBind/Video-LLaVA-7B' device = 'cuda' - load_8bit = True - load_4bit = False + load_8bit = False + load_4bit = True self.dtype = torch.float16 self.handler = Chat(model_path, conv_mode=self.conv_mode, load_8bit=load_8bit, load_4bit=load_4bit, device=device, cache_dir=MODEL_CACHE) # self.handler.model.to(dtype=self.dtype) - @method() + #method() def generate(self, image1, video, textbox_in, first_run, state, state_, images_tensor): from videollava.conversation import conv_templates, Conversation import gradio as gr @@ -128,7 +133,7 @@ def generate(self, image1, video, textbox_in, first_run, state, state_, images_t return (state, state_, 
state.to_gradio_chatbot(), False, gr.update(value=None, interactive=True), images_tensor, gr.update(value=image1 if os.path.exists(image1) else None, interactive=True), gr.update(value=video if os.path.exists(video) else None, interactive=True)) - @method() + #@method() def clear_history(self, state, state_): from videollava.conversation import conv_templates import gradio as gr @@ -176,8 +181,8 @@ def build_gradio_interface(model): image1 = gr.Image(label="Input Image", type="filepath") video = gr.Video(label="Input Video") - #cur_dir = Path(REPO_HOME, 'videollava', 'serve') - cur_dir = EXAMPLES_PATH + cur_dir = Path(__file__).parent / 'videollava' / 'serve' / 'examples' + #cur_dir = EXAMPLES_PATH gr.Examples( examples=[ [ @@ -258,26 +263,23 @@ def build_gradio_interface(model): gr.Markdown(tos_markdown) gr.Markdown(learn_more_markdown) - submit_btn.click(model.generate.remote, [image1, video, textbox, first_run, state, state_, images_tensor], + submit_btn.click(model.generate, [image1, video, textbox, first_run, state, state_, images_tensor], [state, state_, chatbot, first_run, textbox, images_tensor, image1, video]) regenerate_btn.click(regenerate, [state, state_], [state, state_, chatbot, first_run]).then( - model.generate.remote, [image1, video, textbox, first_run, state, state_, images_tensor], [state, state_, chatbot, first_run, textbox, images_tensor, image1, video]) + model.generate, [image1, video, textbox, first_run, state, state_, images_tensor], [state, state_, chatbot, first_run, textbox, images_tensor, image1, video]) - clear_btn.click(model.clear_history.remote, [state, state_], + clear_btn.click(model.clear_history, [state, state_], [image1, video, textbox, first_run, state, state_, chatbot, images_tensor]) return interface -@function_dec(gpu="any") -@asgi_app() def fastapi_app(): from gradio.routes import mount_gradio_app import fastapi.staticfiles from fastapi import FastAPI app = FastAPI() - model = VideoLlavaModel() # interface = gr.Interface( # fn=classifier.predict.remote, # inputs=gr.Image(shape=(224, 224)), @@ -290,14 +292,29 @@ def fastapi_app(): # return model.generate.remote(video_path, prompt) - app.mount("/assets", fastapi.staticfiles.StaticFiles(directory="/assets")) - return mount_gradio_app( - app=app, + app.mount("/assets", fastapi.staticfiles.StaticFiles(directory="assets")) + app.mount("/examples", fastapi.staticfiles.StaticFiles(directory="videollava/serve")) + return app + +#if __name__ == '__main__': +fast_api_app = fastapi_app() +model = VideoLlavaModel() + +@function_dec(gpu="any") +@asgi_app() +def fastapi_app_modal(): + mount_gradio_app( + app=fastapi_app(), blocks=build_gradio_interface(model), path="/gradio", ) -# app = gr.mount_gradio_app(app, demo, path="/") -# demo.launch() -# uvicorn videollava.serve.gradio_web_server:app -# python -m videollava.serve.gradio_web_server + + +if __name__ == '__main__': + demo = build_gradio_interface(model) + demo.launch(share=True) + +# poetry shell +# uvicorn gradio_web_server:app +# python -m gradio_web_server diff --git a/stub.py b/stub.py index 5ccefd1..72be948 100644 --- a/stub.py +++ b/stub.py @@ -7,7 +7,7 @@ MODEL_CACHE = Path(VOLUME_DIR, "models") assets_path = Path(__file__).parent / "assets" local_examples_path = Path(__file__).parent / "videollava" / "serve" / "examples" -EXAMPLES_PATH = "/assets/examples" +EXAMPLES_PATH = "/examples" mounts = [ Mount.from_local_dir("./ai_video_editor/updated_video_llava", remote_path=REPO_HOME), Mount.from_local_dir(assets_path, remote_path="/assets"), From 
8137d56615340d2661ddb81916f569d224e2a3a5 Mon Sep 17 00:00:00 2001 From: bschreck Date: Mon, 12 Feb 2024 11:49:05 -0800 Subject: [PATCH 06/21] fastapi app running locally --- gradio_web_server.py | 283 +++++++++---------------------------------- pyproject.toml | 2 +- 2 files changed, 60 insertions(+), 225 deletions(-) diff --git a/gradio_web_server.py b/gradio_web_server.py index 0dec8eb..46a4229 100644 --- a/gradio_web_server.py +++ b/gradio_web_server.py @@ -1,6 +1,8 @@ import shutil import os import tempfile +import urllib +import aiofiles from modal import asgi_app, method, enter from stub import stub, VOLUME_DIR, MODEL_CACHE, cls_dec, function_dec, REPO_HOME, EXAMPLES_PATH @@ -9,21 +11,10 @@ MODEL_CACHE = "models" Path(VOLUME_DIR).mkdir(exist_ok=True, parents=True) Path(MODEL_CACHE).mkdir(exist_ok=True, parents=True) - - -def save_image_to_local(image): - from PIL import Image - filename = os.path.join(VOLUME_DIR, next(tempfile._get_candidate_names()) + '.jpg') - image = Image.open(image) - image.save(filename) - # print(filename) - return filename - - -def save_video_to_local(video_path): - filename = os.path.join(VOLUME_DIR, next(tempfile._get_candidate_names()) + '.mp4') - shutil.copyfile(video_path, filename) - return filename +VIDEOS_DIR = Path(VOLUME_DIR) / "videos" +IMAGES_DIR = Path(VOLUME_DIR) / "images" +VIDEOS_DIR.mkdir(exist_ok=True, parents=True) +IMAGES_DIR.mkdir(exist_ok=True, parents=True) #@cls_dec(gpu="any") @@ -43,39 +34,19 @@ def load_model(self): self.handler = Chat(model_path, conv_mode=self.conv_mode, load_8bit=load_8bit, load_4bit=load_4bit, device=device, cache_dir=MODEL_CACHE) # self.handler.model.to(dtype=self.dtype) - #method() - def generate(self, image1, video, textbox_in, first_run, state, state_, images_tensor): + def generate(self, image1, video, textbox_in): from videollava.conversation import conv_templates, Conversation import gradio as gr from videollava.constants import DEFAULT_IMAGE_TOKEN - flag = 1 if not textbox_in: - if len(state_.messages) > 0: - textbox_in = state_.messages[-1][1] - state_.messages.pop(-1) - flag = 0 - else: - return "Please enter instruction" + raise ValueError("no prompt provided") - print("textbox_in", textbox_in) - print("video path ", video) - print("image path ", image1) image1 = image1 if image1 else "none" video = video if video else "none" - print("video path after checking", video) - print("os.path.exists(video)", os.path.exists(video)) - if os.path.exists('/assets'): - print('assets dir', os.listdir('/assets')) - else: - print('no assets dir') - # assert not (os.path.exists(image1) and os.path.exists(video)) - if type(state) is not Conversation: - state = conv_templates[self.conv_mode].copy() - state_ = conv_templates[self.conv_mode].copy() - images_tensor = [] - - first_run = False if len(state.messages) > 0 else True + state = conv_templates[self.conv_mode].copy() + state_ = conv_templates[self.conv_mode].copy() + images_tensor = [] text_en_in = textbox_in.replace("picture", "image") @@ -113,208 +84,72 @@ def generate(self, image1, video, textbox_in, first_run, state, state_, images_t print("WARNING: No image or video supplied") print(text_en_in) - text_en_out, state_ = self.handler.generate(images_tensor, text_en_in, first_run=first_run, state=state_) - state_.messages[-1] = (state_.roles[1], text_en_out) + text_en_out, _ = self.handler.generate(images_tensor, text_en_in, first_run=True, state=state_) text_en_out = text_en_out.split('#')[0] textbox_out = text_en_out - show_images = "" - if 
os.path.exists(image1): - filename = save_image_to_local(image1) - show_images += f'' - if os.path.exists(video): - filename = save_video_to_local(video) - show_images += f'' - - if flag: - state.append_message(state.roles[0], textbox_in + "\n" + show_images) - state.append_message(state.roles[1], textbox_out) - - return (state, state_, state.to_gradio_chatbot(), False, gr.update(value=None, interactive=True), images_tensor, gr.update(value=image1 if os.path.exists(image1) else None, interactive=True), gr.update(value=video if os.path.exists(video) else None, interactive=True)) - - #@method() - def clear_history(self, state, state_): - from videollava.conversation import conv_templates - import gradio as gr - state = conv_templates[self.conv_mode].copy() - state_ = conv_templates[self.conv_mode].copy() - return (gr.update(value=None, interactive=True), - gr.update(value=None, interactive=True), \ - gr.update(value=None, interactive=True), \ - True, state, state_, state.to_gradio_chatbot(), []) - - - - -def regenerate(state, state_): - state.messages.pop(-1) - state_.messages.pop(-1) - if len(state.messages) > 0: - return state, state_, state.to_gradio_chatbot(), False - return (state, state_, state.to_gradio_chatbot(), True) - - - -def build_gradio_interface(model): - import gradio as gr - from videollava.serve.gradio_utils import tos_markdown, learn_more_markdown, title_markdown, block_css - - # if not os.path.exists("temp"): - # os.makedirs("temp") - - - textbox = gr.Textbox( - show_label=False, placeholder="Enter text and press ENTER", container=False - ) - with gr.Blocks(title='Video-LLaVA🚀', theme=gr.themes.Default(), css=block_css) as interface: - gr.Markdown(title_markdown) - state = gr.State() - state_ = gr.State() - first_run = gr.State() - images_tensor = gr.State() - - with gr.Row(): - with gr.Column(scale=3): - image1 = gr.Image(label="Input Image", type="filepath") - video = gr.Video(label="Input Video") - - cur_dir = Path(__file__).parent / 'videollava' / 'serve' / 'examples' - #cur_dir = EXAMPLES_PATH - gr.Examples( - examples=[ - [ - f"{cur_dir}/extreme_ironing.jpg", - "What is unusual about this image?", - ], - [ - f"{cur_dir}/waterview.jpg", - "What are the things I should be cautious about when I visit here?", - ], - [ - f"{cur_dir}/desert.jpg", - "If there are factual errors in the questions, point it out; if not, proceed answering the question. 
What’s happening in the desert?", - ], - ], - inputs=[image1, textbox], - ) - - with gr.Column(scale=7): - chatbot = gr.Chatbot(label="Video-LLaVA", height=750) - with gr.Row(): - with gr.Column(scale=8): - textbox.render() - with gr.Column(scale=1, min_width=50): - submit_btn = gr.Button( - value="Send", variant="primary", interactive=True - ) - with gr.Row(elem_id="buttons") as button_row: - upvote_btn = gr.Button(value="👍 Upvote", interactive=True) - downvote_btn = gr.Button(value="👎 Downvote", interactive=True) - flag_btn = gr.Button(value="⚠️ Flag", interactive=True) - # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False) - regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=True) - clear_btn = gr.Button(value="🗑️ Clear history", interactive=True) - - with gr.Row(): - gr.Examples( - examples=[ - [ - f"{cur_dir}/sample_img_22.png", - f"{cur_dir}/sample_demo_22.mp4", - "Are the instruments in the pictures used in the video?", - ], - [ - f"{cur_dir}/sample_img_13.png", - f"{cur_dir}/sample_demo_13.mp4", - "Does the flag in the image appear in the video?", - ], - [ - f"{cur_dir}/sample_img_8.png", - f"{cur_dir}/sample_demo_8.mp4", - "Are the image and the video depicting the same place?", - ], - ], - inputs=[image1, video, textbox], - ) - gr.Examples( - examples=[ - [ - f"{cur_dir}/sample_demo_1.mp4", - "Why is this video funny?", - ], - [ - f"{cur_dir}/sample_demo_3.mp4", - "Can you identify any safety hazards in this video?" - ], - [ - f"{cur_dir}/sample_demo_9.mp4", - "Describe the video.", - ], - [ - f"{cur_dir}/sample_demo_22.mp4", - "Describe the activity in the video.", - ], - ], - inputs=[video, textbox], - ) - gr.Markdown(tos_markdown) - gr.Markdown(learn_more_markdown) - - submit_btn.click(model.generate, [image1, video, textbox, first_run, state, state_, images_tensor], - [state, state_, chatbot, first_run, textbox, images_tensor, image1, video]) - - regenerate_btn.click(regenerate, [state, state_], [state, state_, chatbot, first_run]).then( - model.generate, [image1, video, textbox, first_run, state, state_, images_tensor], [state, state_, chatbot, first_run, textbox, images_tensor, image1, video]) - - clear_btn.click(model.clear_history, [state, state_], - [image1, video, textbox, first_run, state, state_, chatbot, images_tensor]) - return interface - - -def fastapi_app(): +def fastapi_app(model): from gradio.routes import mount_gradio_app import fastapi.staticfiles - from fastapi import FastAPI + from fastapi import FastAPI, UploadFile, File, Request, Form, Query, HTTPException + from fastapi.responses import StreamingResponse app = FastAPI() - - # interface = gr.Interface( - # fn=classifier.predict.remote, - # inputs=gr.Image(shape=(224, 224)), - # outputs="label", - # examples=create_demo_examples(), - # css="/assets/index.css", - # ) - # @app.get("/inference") - # async def inference(video_path: str, prompt: str): - # return model.generate.remote(video_path, prompt) + model = VideoLlavaModel() + + @app.post("/upload") + async def upload( + file: UploadFile = File(...), + ): + filename_decoded = urllib.parse.unquote(file.filename) + file_path = str(VIDEOS_DIR / filename_decoded) + async with aiofiles.open(file_path, "wb") as buffer: + while content := await file.read(1024): # Read chunks of 1024 bytes + await buffer.write(content) + return {"file_path": file_path} + + @app.post("/inference") + async def inference( + video_file_name: str = '', + video_file_path: str = '', + image_file_name: str = '', + image_file_path: str = 
'', + prompt: str = '', + ): + video_file_name = urllib.parse.unquote(video_file_name) + video_file_path = urllib.parse.unquote(video_file_path) + if video_file_path is None or video_file_path == '': + if video_file_name is None or video_file_name == '': + raise ValueError("one of video_file_path or video_file_name must be specified") + video_file_path = str(VIDEOS_DIR / video_file_name) + + image_file_name = urllib.parse.unquote(image_file_name) + image_file_path = urllib.parse.unquote(image_file_path) + if image_file_path is None or image_file_path == '': + if image_file_name is not None and image_file_name != '': + image_file_path = str(IMAGES_DIR / image_file_name) + + return model.generate(image_file_path, video_file_path, prompt) app.mount("/assets", fastapi.staticfiles.StaticFiles(directory="assets")) app.mount("/examples", fastapi.staticfiles.StaticFiles(directory="videollava/serve")) return app -#if __name__ == '__main__': -fast_api_app = fastapi_app() -model = VideoLlavaModel() -@function_dec(gpu="any") -@asgi_app() -def fastapi_app_modal(): - mount_gradio_app( - app=fastapi_app(), - blocks=build_gradio_interface(model), - path="/gradio", - ) +# comment this out to deploy +app = fastapi_app() -if __name__ == '__main__': - demo = build_gradio_interface(model) - demo.launch(share=True) +@function_dec(gpu="any") +@asgi_app() +def fastapi_app_modal(): + app = fastapi_app() -# poetry shell -# uvicorn gradio_web_server:app -# python -m gradio_web_server +# conda activate videollava +# uvicorn modal_inference:app diff --git a/pyproject.toml b/pyproject.toml index 7c0b7f1..7f21463 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,7 @@ dependencies = [ "pydantic<2,>=1", "markdown2[all]", "numpy", "scikit-learn==1.2.2", "requests", "httpx==0.24.0", "uvicorn", "fastapi", "einops==0.6.1", "einops-exts==0.0.4", "timm==0.6.13", - "tensorboardX==2.6.2.2", "gradio==3.37.0", "gradio_client==0.7.0" + "tensorboardX==2.6.2.2", "gradio==3.37.0", "gradio_client==0.7.0", "modal" ] [project.optional-dependencies] From 660189cfb59018ad8b41c3cdeb3b9054165e1b28 Mon Sep 17 00:00:00 2001 From: Ben Schreck Date: Mon, 12 Feb 2024 15:34:08 -0800 Subject: [PATCH 07/21] modal working --- gradio_web_server.py | 155 -------------------- modal_inference.py | 339 +++++++++++++++++-------------------------- stub.py | 104 ++----------- 3 files changed, 150 insertions(+), 448 deletions(-) delete mode 100644 gradio_web_server.py diff --git a/gradio_web_server.py b/gradio_web_server.py deleted file mode 100644 index 46a4229..0000000 --- a/gradio_web_server.py +++ /dev/null @@ -1,155 +0,0 @@ -import shutil -import os -import tempfile -import urllib -import aiofiles - -from modal import asgi_app, method, enter -from stub import stub, VOLUME_DIR, MODEL_CACHE, cls_dec, function_dec, REPO_HOME, EXAMPLES_PATH -from pathlib import Path -VOLUME_DIR = "volume" -MODEL_CACHE = "models" -Path(VOLUME_DIR).mkdir(exist_ok=True, parents=True) -Path(MODEL_CACHE).mkdir(exist_ok=True, parents=True) -VIDEOS_DIR = Path(VOLUME_DIR) / "videos" -IMAGES_DIR = Path(VOLUME_DIR) / "images" -VIDEOS_DIR.mkdir(exist_ok=True, parents=True) -IMAGES_DIR.mkdir(exist_ok=True, parents=True) - - -#@cls_dec(gpu="any") -class VideoLlavaModel: - def __init__(self): - self.load_model() - #@enter() - def load_model(self): - import torch - from videollava.serve.gradio_utils import Chat - self.conv_mode = "llava_v1" - model_path = 'LanguageBind/Video-LLaVA-7B' - device = 'cuda' - load_8bit = False - load_4bit = True - self.dtype = torch.float16 - 
self.handler = Chat(model_path, conv_mode=self.conv_mode, load_8bit=load_8bit, load_4bit=load_4bit, device=device, cache_dir=MODEL_CACHE) - # self.handler.model.to(dtype=self.dtype) - - def generate(self, image1, video, textbox_in): - from videollava.conversation import conv_templates, Conversation - import gradio as gr - from videollava.constants import DEFAULT_IMAGE_TOKEN - if not textbox_in: - raise ValueError("no prompt provided") - - image1 = image1 if image1 else "none" - video = video if video else "none" - - state = conv_templates[self.conv_mode].copy() - state_ = conv_templates[self.conv_mode].copy() - images_tensor = [] - - text_en_in = textbox_in.replace("picture", "image") - - # images_tensor = [[], []] - image_processor = self.handler.image_processor - if os.path.exists(image1) and not os.path.exists(video): - tensor = image_processor.preprocess(image1, return_tensors='pt')['pixel_values'][0] - # print(tensor.shape) - tensor = tensor.to(self.handler.model.device, dtype=self.dtype) - images_tensor.append(tensor) - video_processor = self.handler.video_processor - if not os.path.exists(image1) and os.path.exists(video): - tensor = video_processor(video, return_tensors='pt')['pixel_values'][0] - # print(tensor.shape) - tensor = tensor.to(self.handler.model.device, dtype=self.dtype) - images_tensor.append(tensor) - if os.path.exists(image1) and os.path.exists(video): - tensor = video_processor(video, return_tensors='pt')['pixel_values'][0] - # print(tensor.shape) - tensor = tensor.to(self.handler.model.device, dtype=self.dtype) - images_tensor.append(tensor) - - tensor = image_processor.preprocess(image1, return_tensors='pt')['pixel_values'][0] - # print(tensor.shape) - tensor = tensor.to(self.handler.model.device, dtype=self.dtype) - images_tensor.append(tensor) - - if os.path.exists(image1) and not os.path.exists(video): - text_en_in = DEFAULT_IMAGE_TOKEN + '\n' + text_en_in - elif not os.path.exists(image1) and os.path.exists(video): - text_en_in = ''.join([DEFAULT_IMAGE_TOKEN] * self.handler.model.get_video_tower().config.num_frames) + '\n' + text_en_in - elif os.path.exists(image1) and os.path.exists(video): - text_en_in = ''.join([DEFAULT_IMAGE_TOKEN] * self.handler.model.get_video_tower().config.num_frames) + '\n' + text_en_in + '\n' + DEFAULT_IMAGE_TOKEN - else: - print("WARNING: No image or video supplied") - - print(text_en_in) - text_en_out, _ = self.handler.generate(images_tensor, text_en_in, first_run=True, state=state_) - - text_en_out = text_en_out.split('#')[0] - textbox_out = text_en_out - - return textbox_out - - - -def fastapi_app(model): - from gradio.routes import mount_gradio_app - import fastapi.staticfiles - from fastapi import FastAPI, UploadFile, File, Request, Form, Query, HTTPException - from fastapi.responses import StreamingResponse - app = FastAPI() - model = VideoLlavaModel() - - @app.post("/upload") - async def upload( - file: UploadFile = File(...), - ): - filename_decoded = urllib.parse.unquote(file.filename) - file_path = str(VIDEOS_DIR / filename_decoded) - async with aiofiles.open(file_path, "wb") as buffer: - while content := await file.read(1024): # Read chunks of 1024 bytes - await buffer.write(content) - return {"file_path": file_path} - - @app.post("/inference") - async def inference( - video_file_name: str = '', - video_file_path: str = '', - image_file_name: str = '', - image_file_path: str = '', - prompt: str = '', - ): - video_file_name = urllib.parse.unquote(video_file_name) - video_file_path = urllib.parse.unquote(video_file_path) - 
if video_file_path is None or video_file_path == '': - if video_file_name is None or video_file_name == '': - raise ValueError("one of video_file_path or video_file_name must be specified") - video_file_path = str(VIDEOS_DIR / video_file_name) - - image_file_name = urllib.parse.unquote(image_file_name) - image_file_path = urllib.parse.unquote(image_file_path) - if image_file_path is None or image_file_path == '': - if image_file_name is not None and image_file_name != '': - image_file_path = str(IMAGES_DIR / image_file_name) - - return model.generate(image_file_path, video_file_path, prompt) - - - app.mount("/assets", fastapi.staticfiles.StaticFiles(directory="assets")) - app.mount("/examples", fastapi.staticfiles.StaticFiles(directory="videollava/serve")) - return app - - - - -# comment this out to deploy -app = fastapi_app() - -@function_dec(gpu="any") -@asgi_app() -def fastapi_app_modal(): - app = fastapi_app() - -# conda activate videollava -# uvicorn modal_inference:app diff --git a/modal_inference.py b/modal_inference.py index a3dd4d7..937531e 100644 --- a/modal_inference.py +++ b/modal_inference.py @@ -1,211 +1,144 @@ -from modal import ( - Cls, - method, - enter, - web_endpoint, -) +import os +import urllib + +from modal import asgi_app, method, enter, build +from .stub import VOLUME_DIR, MODEL_CACHE, cls_dec, function_dec, volume from pathlib import Path -from .stub import stub, cls_dec, function_dec, MODELS_DIR, volume, HF_DATASETS_CACHE, REPO_HOME, load_pretrained_from_cache - - -LOCAL_VIDEOS_PATH = Path(REPO_HOME) / "downloaded_videos" -DEFAULT_PROMPT = "describe what is going on in this video" - - -# def load_pretrained_from_cache(load_4bit=True, load_8bit=False): - # print("Loading pretrained model") - # from videollava.utils import disable_torch_init - # from transformers import AutoTokenizer, BitsAndBytesConfig - # from videollava.model import LlavaLlamaForCausalLM - # import torch - # disable_torch_init() - # print("imported") - - # kwargs = { - # "device_map": "auto", - # "cache_dir": HF_DATASETS_CACHE, - # } - # video_llava_path = Path(MODELS_DIR) / 'Video-LLaVA-7B' - - # vlp_exists = video_llava_path.exists() - # if not vlp_exists: - # video_llava_path.mkdir(exist_ok=True, parents=True) - - # save = False - # if not video_llava_path.exists() or len(list(video_llava_path.iterdir())) == 0: - # save = True - # print("Downloading model") - # video_llava_path = 'LanguageBind/Video-LLaVA-7B' - - # tokenizer_path = Path(MODELS_DIR) / 'tokenizer' - # if not tokenizer_path.exists() or len(list(tokenizer_path.iterdir())) == 0: - # print("Downloading tokenizer") - # tokenizer_path = 'LanguageBind/Video-LLaVA-7B' - - # if load_8bit: - # kwargs['load_in_8bit'] = True - # elif load_4bit: - # kwargs['load_in_4bit'] = True - # kwargs['quantization_config'] = BitsAndBytesConfig( - # load_in_4bit=True, - # bnb_4bit_compute_dtype=torch.float16, - # bnb_4bit_use_double_quant=True, - # bnb_4bit_quant_type='nf4' - # ) - # else: - # kwargs['torch_dtype'] = torch.float16 - - # tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=False, cache_dir=kwargs["cache_dir"]) - # model = LlavaLlamaForCausalLM.from_pretrained(video_llava_path, low_cpu_mem_usage=True, **kwargs) - # model.generation_config.do_sample = True - - # if save: - # # save to on-disk paths - # video_llava_path = Path(MODELS_DIR) / 'Video-LLaVA-7B' - # tokenizer_path = Path(MODELS_DIR) / 'tokenizer' - # tokenizer.save_pretrained(str(tokenizer_path)) - # model.save_pretrained(str(video_llava_path)) - # return model, 
tokenizer - -def prepare_processor(model, device="cuda"): - import torch - processor = {'image': None, 'video': None} - if model.config.mm_image_tower is not None: - image_tower = model.get_image_tower() - if not image_tower.is_loaded: - image_tower.load_model() - image_tower.to(device=device, dtype=torch.float16) - image_processor = image_tower.image_processor - processor['image'] = image_processor - if model.config.mm_video_tower is not None: - video_tower = model.get_video_tower() - if not video_tower.is_loaded: - video_tower.load_model() - video_tower.to(device=device, dtype=torch.float16) - video_processor = video_tower.video_processor - processor['video'] = video_processor - return processor - -def prepare_special_tokens(model, tokenizer): - from videollava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, \ - DEFAULT_VIDEO_PATCH_TOKEN, DEFAULT_VID_START_TOKEN, DEFAULT_VID_END_TOKEN - mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False) - mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True) - if mm_use_im_patch_token: - tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) - tokenizer.add_tokens([DEFAULT_VIDEO_PATCH_TOKEN], special_tokens=True) - if mm_use_im_start_end: - tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) - tokenizer.add_tokens([DEFAULT_VID_START_TOKEN, DEFAULT_VID_END_TOKEN], special_tokens=True) - model.resize_token_embeddings(len(tokenizer)) - -@cls_dec(container_idle_timeout=30, gpu='L4') -class VideoLlava: - def __init__(self, device='cuda'): - self.device = device - self.model = None - # TODO when they fix - #@build() +# for local testing +#VOLUME_DIR = "volume" +#MODEL_CACHE = "models" +#Path(VOLUME_DIR).mkdir(exist_ok=True, parents=True) +VIDEOS_DIR = Path(VOLUME_DIR) / "videos" +IMAGES_DIR = Path(VOLUME_DIR) / "images" + + +@cls_dec(gpu="any") +class VideoLlavaModel: + @build() @enter() def load_model(self): - self.model, self.tokenizer = load_pretrained_from_cache(load_4bit=False, load_8bit=True) - print("got model") - self.processor = prepare_processor(self.model, device=self.device) - self.video_processor = self.processor['video'] - - def prepare_conv(self): - from videollava.conversation import conv_templates - self.conv = conv_templates["llava_v1"].copy() - self.roles = self.conv.roles + import torch + from videollava.serve.gradio_utils import Chat + self.conv_mode = "llava_v1" + model_path = 'LanguageBind/Video-LLaVA-7B' + device = 'cuda' + load_8bit = False + load_4bit = True + self.dtype = torch.float16 + self.handler = Chat(model_path, conv_mode=self.conv_mode, load_8bit=load_8bit, load_4bit=load_4bit, device=device, cache_dir=str(MODEL_CACHE)) + print("model loaded") + # self.handler.model.to(dtype=self.dtype) @method() - def inference(self, video_path: str, inp: str = DEFAULT_PROMPT): - import torch - from videollava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN - from videollava.mm_utils import tokenizer_image_token, KeywordsStoppingCriteria - from videollava.conversation import SeparatorStyle - import requests - from io import BytesIO - print('preparing conv') - - self.prepare_conv() - if video_path.startswith("http"): - print("Downloading video") - video_bytes = requests.get(video_path).content - local_video_path = LOCAL_VIDEOS_PATH / video_path.split("/")[-1] - if not LOCAL_VIDEOS_PATH.exists(): - LOCAL_VIDEOS_PATH.mkdir(exist_ok=True, parents=True) - with open(local_video_path, "wb") as f: - 
f.write(video_bytes) - video_path = BytesIO(video_bytes) - print(f"Downloaded video and saved to {local_video_path}") - elif not Path(video_path).exists(): - volume.reload() - if not Path(video_path).exists(): - raise FileNotFoundError(f"Video {video_path} not found") - print('processing video') - video_tensor = self.video_processor(video_path, return_tensors='pt')['pixel_values'] - if type(video_tensor) is list: - tensor = [video.to(self.model.device, dtype=torch.float16) for video in video_tensor] + def generate(self, image1, video, textbox_in): + from videollava.conversation import conv_templates + from videollava.constants import DEFAULT_IMAGE_TOKEN + if not textbox_in: + raise ValueError("no prompt provided") + + image1 = image1 if image1 else "none" + video = video if video else "none" + + state_ = conv_templates[self.conv_mode].copy() + images_tensor = [] + + text_en_in = textbox_in.replace("picture", "image") + + image_processor = self.handler.image_processor + if os.path.exists(image1) and not os.path.exists(video): + tensor = image_processor.preprocess(image1, return_tensors='pt')['pixel_values'][0] + tensor = tensor.to(self.handler.model.device, dtype=self.dtype) + images_tensor.append(tensor) + video_processor = self.handler.video_processor + if not os.path.exists(image1) and os.path.exists(video): + tensor = video_processor(video, return_tensors='pt')['pixel_values'][0] + tensor = tensor.to(self.handler.model.device, dtype=self.dtype) + images_tensor.append(tensor) + if os.path.exists(image1) and os.path.exists(video): + tensor = video_processor(video, return_tensors='pt')['pixel_values'][0] + tensor = tensor.to(self.handler.model.device, dtype=self.dtype) + images_tensor.append(tensor) + + tensor = image_processor.preprocess(image1, return_tensors='pt')['pixel_values'][0] + tensor = tensor.to(self.handler.model.device, dtype=self.dtype) + images_tensor.append(tensor) + + if os.path.exists(image1) and not os.path.exists(video): + text_en_in = DEFAULT_IMAGE_TOKEN + '\n' + text_en_in + elif not os.path.exists(image1) and os.path.exists(video): + text_en_in = ''.join([DEFAULT_IMAGE_TOKEN] * self.handler.model.get_video_tower().config.num_frames) + '\n' + text_en_in + elif os.path.exists(image1) and os.path.exists(video): + text_en_in = ''.join([DEFAULT_IMAGE_TOKEN] * self.handler.model.get_video_tower().config.num_frames) + '\n' + text_en_in + '\n' + DEFAULT_IMAGE_TOKEN else: - tensor = video_tensor.to(self.model.device, dtype=torch.float16) - - inp = ' '.join([DEFAULT_IMAGE_TOKEN] * self.model.get_video_tower().config.num_frames) + '\n' + inp - self.conv.append_message(self.conv.roles[0], inp) - self.conv.append_message(self.conv.roles[1], None) - prompt = self.conv.get_prompt() - input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() - stop_str = self.conv.sep if self.conv.sep_style != SeparatorStyle.TWO else self.conv.sep2 - keywords = [stop_str] - stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids) - print(input_ids.shape, tensor.shape) + print("WARNING: No image or video supplied") - import torch - import time - begin = time.time() - with torch.inference_mode(): - output_ids = self.model.generate( - input_ids, - images=tensor, - do_sample=True, - temperature=0.9, - max_new_tokens=1024, - use_cache=True, - stopping_criteria=[stopping_criteria]) - end = time.time() - print(f"Generate time taken: {end-begin}") - - output = self.tokenizer.decode(output_ids[0, 
input_ids.shape[1]:]).strip() - print(output) - return output - -@function_dec(gpu='L4') -@web_endpoint() -def inference(video_path: str, inp: str): - video_llava = VideoLlava() - if not hasattr(video_llava, 'model') or video_llava.model is None: - print('loading model') - video_llava.load_model() - print('model loaded') - output = VideoLlava().inference(video_path, inp) - print(output) - return output - -@stub.local_entrypoint() -def main(): - prompt = "describe what is going on in this video" - video_path = '/volume/pika_water_city.mp4' - input_ids, tensor, stopping_criteria = VideoLlava().get_inputs.remote(video_path, prompt) - output = VideoLlava().inference.remote(input_ids, tensor, stopping_criteria) - print(output) - -if __name__ == "__main__": - video_llava = Cls.lookup("updated-video-llava-ephemeral", 'VideoLlava')() - video_path = '/volume/pika_water_city.mp4' - #print(video_llava.inference.remote(video_path=video_path, inp="describe what is going on in this video")) - - prompt = "describe what is going on in this video" - video_path = '/volume/pika_water_city.mp4' - output = video_llava.inference.remote(video_path, prompt) - print(output) + print(text_en_in) + text_en_out, _ = self.handler.generate(images_tensor, text_en_in, first_run=True, state=state_) + text_en_out = text_en_out.split('#')[0] + textbox_out = text_en_out + + return textbox_out + + + +def fastapi_app(): + from fastapi import FastAPI, UploadFile, File + import aiofiles + + Path(MODEL_CACHE).mkdir(exist_ok=True, parents=True) + VIDEOS_DIR.mkdir(exist_ok=True, parents=True) + IMAGES_DIR.mkdir(exist_ok=True, parents=True) + + app = FastAPI() + model = VideoLlavaModel() + + @app.post("/upload") + async def upload( + file: UploadFile = File(...), + ): + filename_decoded = urllib.parse.unquote(file.filename) + file_path = str(VIDEOS_DIR / filename_decoded) + async with aiofiles.open(file_path, "wb") as buffer: + while content := await file.read(1024): # Read chunks of 1024 bytes + await buffer.write(content) + volume.commit() + return {"file_path": file_path} + + @app.post("/inference") + async def inference( + video_file_name: str = '', + video_file_path: str = '', + image_file_name: str = '', + image_file_path: str = '', + prompt: str = '', + ): + import requests + requests.get('https://huggingface.co/LanguageBind/Video-LLaVA-7B/resolve/main/config.json').raise_for_status() + video_file_name = urllib.parse.unquote(video_file_name) + video_file_path = urllib.parse.unquote(video_file_path) + if video_file_path is None or video_file_path == '': + if video_file_name is None or video_file_name == '': + raise ValueError("one of video_file_path or video_file_name must be specified") + video_file_path = str(VIDEOS_DIR / video_file_name) + + image_file_name = urllib.parse.unquote(image_file_name) + image_file_path = urllib.parse.unquote(image_file_path) + if image_file_path is None or image_file_path == '': + if image_file_name is not None and image_file_name != '': + image_file_path = str(IMAGES_DIR / image_file_name) + + return model.generate.remote(image_file_path, video_file_path, prompt) + return app + + +@function_dec() +@asgi_app() +def fastapi_app_modal(): + return fastapi_app() + +# local testing: +# comment this out to deploy +# app = fastapi_app() +# conda activate videollava +# uvicorn modal_inference:app diff --git a/stub.py b/stub.py index 72be948..8f925d6 100644 --- a/stub.py +++ b/stub.py @@ -1,4 +1,4 @@ -from modal import Volume, Image, Stub, Mount, Secret, build +from modal import Volume, Image, Stub, 
Mount, Secret from pathlib import Path REPO_HOME = "/app" VOLUME_DIR = "/volume" @@ -18,11 +18,14 @@ stub = Stub("updated-video-llava", mounts=mounts, volumes=volumes, secrets=[Secret.from_dotenv()]) +def remove_old_files(): + import shutil + shutil.rmtree('/volume/models', ignore_errors=True) + image = ( Image.from_registry( "nvidia/cuda:12.1.1-cudnn8-devel-ubuntu22.04", add_python="3.11" ) - #Image.debian_slim() .apt_install( "git", "curl", @@ -56,36 +59,6 @@ "python -m bitsandbytes", gpu="any" ) - # .pip_install( - # "torchvision>=0.15.2", - # #"tokenizers>=0.12.1,<0.14", - # "sentencepiece==0.1.99", - # "shortuuid", - # "accelerate==0.21.0", - # "peft==0.4.0", - # "pydantic<2,>=1", - # "markdown2[all]", - # "numpy", - # "scikit-learn==1.2.2", - # "gradio==3.37.0", - # "gradio_client==0.7.0", - # "requests", - # "httpx==0.24.0", - # "uvicorn", - # "fastapi", - # #"einops==0.6.1", - # "einops-exts==0.0.4", - # "timm==0.6.13", - # "deepspeed==0.9.5", - # "ninja", - # "wandb", - # "tensorboardX==2.6.2.2", - - # "tenacity", - # "torch==2.1.2", - # "wheel", - # gpu="any", - # ) .run_commands("pip install flash-attn --no-build-isolation", gpu="any") .env({"PYTHONPATH": REPO_HOME, "HF_DATASETS_CACHE": HF_DATASETS_CACHE}) .pip_install( @@ -94,14 +67,19 @@ "git+https://github.com/facebookresearch/pytorchvideo.git@28fe037d212663c6a24f373b94cc5d478c8c1a1d", gpu="any", ) - #.run_function(load_pretrained_from_cache, gpu="any") + .pip_install( + "aiofiles", + ) + .run_function(remove_old_files) ) +# TODO bitsandbytes seems to not be working with gpu def function_dec(**extras): return stub.function( image=image, timeout=80000, - checkpointing_enabled=True, # Enable memory checkpointing for faster cold starts. + # checkpointing doesn't work because it restricts internet access + #checkpointing_enabled=True, # Enable memory checkpointing for faster cold starts. _allow_background_volume_commits=True, **extras, ) @@ -110,61 +88,7 @@ def cls_dec(**extras): return stub.cls( image=image, timeout=80000, - checkpointing_enabled=True, # Enable memory checkpointing for faster cold starts. + # checkpointing doesn't work because it restricts internet access + #checkpointing_enabled=True, # Enable memory checkpointing for faster cold starts. 
**extras, ) - -def load_pretrained_from_cache(load_4bit=True, load_8bit=False): - print("Loading pretrained model") - from videollava.utils import disable_torch_init - from transformers import AutoTokenizer, BitsAndBytesConfig - from videollava.model import LlavaLlamaForCausalLM - import torch - disable_torch_init() - print("imported") - - kwargs = { - "device_map": "auto", - "cache_dir": HF_DATASETS_CACHE, - } - video_llava_path = Path(MODELS_DIR) / 'Video-LLaVA-7B' - - vlp_exists = video_llava_path.exists() - if not vlp_exists: - video_llava_path.mkdir(exist_ok=True, parents=True) - - save = False - if not video_llava_path.exists() or len(list(video_llava_path.iterdir())) == 0: - save = True - print("Downloading model") - video_llava_path = 'LanguageBind/Video-LLaVA-7B' - - tokenizer_path = Path(MODELS_DIR) / 'tokenizer' - if not tokenizer_path.exists() or len(list(tokenizer_path.iterdir())) == 0: - print("Downloading tokenizer") - tokenizer_path = 'LanguageBind/Video-LLaVA-7B' - - if load_8bit: - kwargs['load_in_8bit'] = True - elif load_4bit: - kwargs['load_in_4bit'] = True - kwargs['quantization_config'] = BitsAndBytesConfig( - load_in_4bit=True, - bnb_4bit_compute_dtype=torch.float16, - bnb_4bit_use_double_quant=True, - bnb_4bit_quant_type='nf4' - ) - else: - kwargs['torch_dtype'] = torch.float16 - - tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=False, cache_dir=kwargs["cache_dir"]) - model = LlavaLlamaForCausalLM.from_pretrained(video_llava_path, low_cpu_mem_usage=True, **kwargs) - model.generation_config.do_sample = True - - if save: - # save to on-disk paths - video_llava_path = Path(MODELS_DIR) / 'Video-LLaVA-7B' - tokenizer_path = Path(MODELS_DIR) / 'tokenizer' - tokenizer.save_pretrained(str(tokenizer_path)) - model.save_pretrained(str(video_llava_path)) - return model, tokenizer From b8b1c44242b59f7f184e0d216ca58707ce7f3a76 Mon Sep 17 00:00:00 2001 From: Ben Schreck Date: Mon, 12 Feb 2024 15:57:31 -0800 Subject: [PATCH 08/21] use s3 mount --- modal_inference.py | 8 ++++---- stub.py | 28 +++++++++++++++++++++------- 2 files changed, 25 insertions(+), 11 deletions(-) diff --git a/modal_inference.py b/modal_inference.py index 937531e..a9fb9c6 100644 --- a/modal_inference.py +++ b/modal_inference.py @@ -2,14 +2,14 @@ import urllib from modal import asgi_app, method, enter, build -from .stub import VOLUME_DIR, MODEL_CACHE, cls_dec, function_dec, volume +from .stub import S3_VIDEO_PATH, MODEL_CACHE, cls_dec, function_dec, volume, stub from pathlib import Path # for local testing -#VOLUME_DIR = "volume" +#S3_VIDEO_PATH= "s3_videos" #MODEL_CACHE = "models" #Path(VOLUME_DIR).mkdir(exist_ok=True, parents=True) -VIDEOS_DIR = Path(VOLUME_DIR) / "videos" -IMAGES_DIR = Path(VOLUME_DIR) / "images" +VIDEOS_DIR = Path(S3_VIDEO_PATH) / "videos" +IMAGES_DIR = Path(S3_VIDEO_PATH) / "images" @cls_dec(gpu="any") diff --git a/stub.py b/stub.py index 8f925d6..008ef69 100644 --- a/stub.py +++ b/stub.py @@ -1,20 +1,34 @@ -from modal import Volume, Image, Stub, Mount, Secret +from modal import Volume, Image, Stub, Mount, Secret, S3Mount +import os from pathlib import Path +try: + from dotenv import load_dotenv + if os.environ.get("ENV") in ["dev", "prod"]: + env_file = Path(__file__).parent.parent.parent / ".env" + else: + env_file = Path(__file__).parent.parent.parent / ".env.local" + load_dotenv(env_file) +except ImportError: + pass + + REPO_HOME = "/app" VOLUME_DIR = "/volume" MODELS_DIR = "/root" HF_DATASETS_CACHE = str(Path(VOLUME_DIR) / "hf_datasets_cache") MODEL_CACHE = 
Path(VOLUME_DIR, "models") -assets_path = Path(__file__).parent / "assets" -local_examples_path = Path(__file__).parent / "videollava" / "serve" / "examples" -EXAMPLES_PATH = "/examples" +S3_VIDEO_PATH = "/s3-videos" mounts = [ Mount.from_local_dir("./ai_video_editor/updated_video_llava", remote_path=REPO_HOME), - Mount.from_local_dir(assets_path, remote_path="/assets"), - Mount.from_local_dir(local_examples_path, remote_path=EXAMPLES_PATH), ] volume = Volume.persisted("video-llava-vol") -volumes = {VOLUME_DIR: volume} +volumes = { + VOLUME_DIR: volume, + S3_VIDEO_PATH: S3Mount( + os.environ["TRIMIT_VIDEO_S3_BUCKET"], + secret=Secret.from_dotenv(), + read_only=True) +} stub = Stub("updated-video-llava", mounts=mounts, volumes=volumes, secrets=[Secret.from_dotenv()]) From ed4ef135acc658fee7c4e74d67707cfde09ec84b Mon Sep 17 00:00:00 2001 From: Ben Schreck Date: Mon, 12 Feb 2024 16:20:39 -0800 Subject: [PATCH 09/21] change name --- stub.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stub.py b/stub.py index 008ef69..dafaf9e 100644 --- a/stub.py +++ b/stub.py @@ -29,7 +29,7 @@ secret=Secret.from_dotenv(), read_only=True) } -stub = Stub("updated-video-llava", mounts=mounts, volumes=volumes, secrets=[Secret.from_dotenv()]) +stub = Stub("video-llava", mounts=mounts, volumes=volumes, secrets=[Secret.from_dotenv()]) def remove_old_files(): From f79cc95f34bc94d7a398c243d75be5b5dda93451 Mon Sep 17 00:00:00 2001 From: Ben Schreck Date: Mon, 12 Feb 2024 16:27:55 -0800 Subject: [PATCH 10/21] fix path --- stub.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stub.py b/stub.py index dafaf9e..f7d8cfa 100644 --- a/stub.py +++ b/stub.py @@ -19,7 +19,7 @@ MODEL_CACHE = Path(VOLUME_DIR, "models") S3_VIDEO_PATH = "/s3-videos" mounts = [ - Mount.from_local_dir("./ai_video_editor/updated_video_llava", remote_path=REPO_HOME), + Mount.from_local_dir("./ai_video_editor/video_llava", remote_path=REPO_HOME), ] volume = Volume.persisted("video-llava-vol") volumes = { From 625ea13a19058972b32ea916d3f1d116c0c5beb1 Mon Sep 17 00:00:00 2001 From: Ben Schreck Date: Fri, 16 Feb 2024 11:21:31 -0800 Subject: [PATCH 11/21] change image to stub, remove extra line --- stub.py => image.py | 3 ++- modal_inference.py | 4 +--- 2 files changed, 3 insertions(+), 4 deletions(-) rename stub.py => image.py (95%) diff --git a/stub.py b/image.py similarity index 95% rename from stub.py rename to image.py index f7d8cfa..507a575 100644 --- a/stub.py +++ b/image.py @@ -18,6 +18,7 @@ HF_DATASETS_CACHE = str(Path(VOLUME_DIR) / "hf_datasets_cache") MODEL_CACHE = Path(VOLUME_DIR, "models") S3_VIDEO_PATH = "/s3-videos" +VIDEO_LLAVA_STUB_NAME = "video-llava" mounts = [ Mount.from_local_dir("./ai_video_editor/video_llava", remote_path=REPO_HOME), ] @@ -29,7 +30,7 @@ secret=Secret.from_dotenv(), read_only=True) } -stub = Stub("video-llava", mounts=mounts, volumes=volumes, secrets=[Secret.from_dotenv()]) +stub = Stub(VIDEO_LLAVA_STUB_NAME, mounts=mounts, volumes=volumes, secrets=[Secret.from_dotenv()]) def remove_old_files(): diff --git a/modal_inference.py b/modal_inference.py index a9fb9c6..199ab02 100644 --- a/modal_inference.py +++ b/modal_inference.py @@ -2,7 +2,7 @@ import urllib from modal import asgi_app, method, enter, build -from .stub import S3_VIDEO_PATH, MODEL_CACHE, cls_dec, function_dec, volume, stub +from .image import S3_VIDEO_PATH, MODEL_CACHE, cls_dec, function_dec, volume, stub from pathlib import Path # for local testing #S3_VIDEO_PATH= "s3_videos" @@ -113,8 +113,6 @@ async def inference( 
image_file_path: str = '', prompt: str = '', ): - import requests - requests.get('https://huggingface.co/LanguageBind/Video-LLaVA-7B/resolve/main/config.json').raise_for_status() video_file_name = urllib.parse.unquote(video_file_name) video_file_path = urllib.parse.unquote(video_file_path) if video_file_path is None or video_file_path == '': From e4c77cd7157c14c1f3f3f179ae007325c141e8a9 Mon Sep 17 00:00:00 2001 From: Ben Schreck Date: Mon, 19 Feb 2024 16:21:13 -0800 Subject: [PATCH 12/21] pull dotenv stuff from conf --- image.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/image.py b/image.py index 507a575..66977c7 100644 --- a/image.py +++ b/image.py @@ -1,16 +1,7 @@ from modal import Volume, Image, Stub, Mount, Secret, S3Mount import os from pathlib import Path -try: - from dotenv import load_dotenv - if os.environ.get("ENV") in ["dev", "prod"]: - env_file = Path(__file__).parent.parent.parent / ".env" - else: - env_file = Path(__file__).parent.parent.parent / ".env.local" - load_dotenv(env_file) -except ImportError: - pass - +from ai_video_editor.utils.conf import ENV REPO_HOME = "/app" VOLUME_DIR = "/volume" @@ -18,7 +9,9 @@ HF_DATASETS_CACHE = str(Path(VOLUME_DIR) / "hf_datasets_cache") MODEL_CACHE = Path(VOLUME_DIR, "models") S3_VIDEO_PATH = "/s3-videos" -VIDEO_LLAVA_STUB_NAME = "video-llava" +#VIDEO_LLAVA_STUB_NAME = f"video-llava-{ENV}" +# TODO once deploys are working +VIDEO_LLAVA_STUB_NAME = f"video-llava" mounts = [ Mount.from_local_dir("./ai_video_editor/video_llava", remote_path=REPO_HOME), ] From 75906424c227f071a5d6d393210c2b6eb1cef9c5 Mon Sep 17 00:00:00 2001 From: Ben Schreck Date: Tue, 20 Feb 2024 16:01:21 -0600 Subject: [PATCH 13/21] optional os env var --- image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/image.py b/image.py index 66977c7..634818d 100644 --- a/image.py +++ b/image.py @@ -19,7 +19,7 @@ volumes = { VOLUME_DIR: volume, S3_VIDEO_PATH: S3Mount( - os.environ["TRIMIT_VIDEO_S3_BUCKET"], + os.environ.get("TRIMIT_VIDEO_S3_BUCKET", ''), secret=Secret.from_dotenv(), read_only=True) } From 2cab2ebb537508ec787b8fb596cf4e17c1a986c8 Mon Sep 17 00:00:00 2001 From: Ben Schreck Date: Tue, 20 Feb 2024 16:09:36 -0600 Subject: [PATCH 14/21] new deployment works --- image.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/image.py b/image.py index 634818d..3f8b0cd 100644 --- a/image.py +++ b/image.py @@ -1,7 +1,7 @@ from modal import Volume, Image, Stub, Mount, Secret, S3Mount import os from pathlib import Path -from ai_video_editor.utils.conf import ENV +from ai_video_editor.utils.conf import DOTENV_FILE, ENV REPO_HOME = "/app" VOLUME_DIR = "/volume" @@ -9,9 +9,7 @@ HF_DATASETS_CACHE = str(Path(VOLUME_DIR) / "hf_datasets_cache") MODEL_CACHE = Path(VOLUME_DIR, "models") S3_VIDEO_PATH = "/s3-videos" -#VIDEO_LLAVA_STUB_NAME = f"video-llava-{ENV}" -# TODO once deploys are working -VIDEO_LLAVA_STUB_NAME = f"video-llava" +VIDEO_LLAVA_STUB_NAME = f"video-llava-{ENV}" mounts = [ Mount.from_local_dir("./ai_video_editor/video_llava", remote_path=REPO_HOME), ] @@ -20,10 +18,10 @@ VOLUME_DIR: volume, S3_VIDEO_PATH: S3Mount( os.environ.get("TRIMIT_VIDEO_S3_BUCKET", ''), - secret=Secret.from_dotenv(), + secret=Secret.from_dotenv(path=DOTENV_FILE), read_only=True) } -stub = Stub(VIDEO_LLAVA_STUB_NAME, mounts=mounts, volumes=volumes, secrets=[Secret.from_dotenv()]) +stub = Stub(VIDEO_LLAVA_STUB_NAME, mounts=mounts, volumes=volumes, secrets=[Secret.from_dotenv(path=DOTENV_FILE)]) def 
remove_old_files(): From 966c48d6e06814e6dea846998f602794219bcbeb Mon Sep 17 00:00:00 2001 From: Ben Schreck Date: Wed, 21 Feb 2024 10:59:48 -0600 Subject: [PATCH 15/21] fix dotenv path --- image.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/image.py b/image.py index 3f8b0cd..80720af 100644 --- a/image.py +++ b/image.py @@ -1,7 +1,7 @@ from modal import Volume, Image, Stub, Mount, Secret, S3Mount import os from pathlib import Path -from ai_video_editor.utils.conf import DOTENV_FILE, ENV +from ai_video_editor.utils.conf import DOTENV_PATH, ENV REPO_HOME = "/app" VOLUME_DIR = "/volume" @@ -18,10 +18,10 @@ VOLUME_DIR: volume, S3_VIDEO_PATH: S3Mount( os.environ.get("TRIMIT_VIDEO_S3_BUCKET", ''), - secret=Secret.from_dotenv(path=DOTENV_FILE), + secret=Secret.from_dotenv(path=DOTENV_PATH), read_only=True) } -stub = Stub(VIDEO_LLAVA_STUB_NAME, mounts=mounts, volumes=volumes, secrets=[Secret.from_dotenv(path=DOTENV_FILE)]) +stub = Stub(VIDEO_LLAVA_STUB_NAME, mounts=mounts, volumes=volumes, secrets=[Secret.from_dotenv(path=DOTENV_PATH)]) def remove_old_files(): From 86a9a1e0de99935bd9ab1d301b689cad805c02f7 Mon Sep 17 00:00:00 2001 From: Ben Schreck Date: Wed, 6 Mar 2024 15:25:22 -0800 Subject: [PATCH 16/21] switch back to volumes and use new modal structure with MODAL_ENVIRONMENT --- __init__.py | 1 + image.py | 40 +++++++++++++++++++++------------------- modal_inference.py | 39 ++++++++++++++++++++++++++++++++++----- 3 files changed, 56 insertions(+), 24 deletions(-) create mode 100644 __init__.py diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..3f9b90b --- /dev/null +++ b/__init__.py @@ -0,0 +1 @@ +from .modal_inference import * diff --git a/image.py b/image.py index 80720af..77c1f7e 100644 --- a/image.py +++ b/image.py @@ -1,27 +1,20 @@ -from modal import Volume, Image, Stub, Mount, Secret, S3Mount +from modal import Volume, Image, Mount import os from pathlib import Path -from ai_video_editor.utils.conf import DOTENV_PATH, ENV +from ai_video_editor.stub import stub, REPO_HOME, LOCAL_CERT_PATH, CERT_PATH, EXTRA_ENV -REPO_HOME = "/app" -VOLUME_DIR = "/volume" -MODELS_DIR = "/root" -HF_DATASETS_CACHE = str(Path(VOLUME_DIR) / "hf_datasets_cache") -MODEL_CACHE = Path(VOLUME_DIR, "models") -S3_VIDEO_PATH = "/s3-videos" -VIDEO_LLAVA_STUB_NAME = f"video-llava-{ENV}" -mounts = [ +LOCAL_VOLUME_DIR = "/video_llava_volume" +HF_DATASETS_CACHE = str(Path(LOCAL_VOLUME_DIR) / "hf_datasets_cache") +MODEL_CACHE = Path(LOCAL_VOLUME_DIR, "models") + +LOCAL_VOLUME_NAME = "video-llava-volume" +local_volume = Volume.from_name(LOCAL_VOLUME_NAME, create_if_missing=True) +local_volumes = { + LOCAL_VOLUME_DIR: local_volume, +} +local_mounts = [ Mount.from_local_dir("./ai_video_editor/video_llava", remote_path=REPO_HOME), ] -volume = Volume.persisted("video-llava-vol") -volumes = { - VOLUME_DIR: volume, - S3_VIDEO_PATH: S3Mount( - os.environ.get("TRIMIT_VIDEO_S3_BUCKET", ''), - secret=Secret.from_dotenv(path=DOTENV_PATH), - read_only=True) -} -stub = Stub(VIDEO_LLAVA_STUB_NAME, mounts=mounts, volumes=volumes, secrets=[Secret.from_dotenv(path=DOTENV_PATH)]) def remove_old_files(): @@ -75,8 +68,12 @@ def remove_old_files(): ) .pip_install( "aiofiles", + "aioboto3", ) .run_function(remove_old_files) + .copy_local_file(LOCAL_CERT_PATH, CERT_PATH) + .pip_install("boto3", "aioboto3") + .env(EXTRA_ENV) ) # TODO bitsandbytes seems to not be working with gpu @@ -87,6 +84,8 @@ def function_dec(**extras): # checkpointing doesn't work because it restricts internet access 
#checkpointing_enabled=True, # Enable memory checkpointing for faster cold starts. _allow_background_volume_commits=True, + volumes=local_volumes, + mounts=local_mounts, **extras, ) @@ -96,5 +95,8 @@ def cls_dec(**extras): timeout=80000, # checkpointing doesn't work because it restricts internet access #checkpointing_enabled=True, # Enable memory checkpointing for faster cold starts. + _allow_background_volume_commits=True, + volumes=local_volumes, + mounts=local_mounts, **extras, ) diff --git a/modal_inference.py b/modal_inference.py index 199ab02..fd115ee 100644 --- a/modal_inference.py +++ b/modal_inference.py @@ -1,8 +1,11 @@ import os +import shutil import urllib from modal import asgi_app, method, enter, build -from .image import S3_VIDEO_PATH, MODEL_CACHE, cls_dec, function_dec, volume, stub +from ai_video_editor.utils.fs_utils import async_copy_from_s3 +from .image import LOCAL_VOLUME_DIR, MODEL_CACHE, cls_dec, function_dec, local_volume +from ai_video_editor.stub import stub, S3_VIDEO_PATH, VOLUME_DIR, volume as remote_volume from pathlib import Path # for local testing #S3_VIDEO_PATH= "s3_videos" @@ -12,11 +15,12 @@ IMAGES_DIR = Path(S3_VIDEO_PATH) / "images" + @cls_dec(gpu="any") class VideoLlavaModel: - @build() @enter() def load_model(self): + local_volume.reload() import torch from videollava.serve.gradio_utils import Chat self.conv_mode = "llava_v1" @@ -29,8 +33,32 @@ def load_model(self): print("model loaded") # self.handler.model.to(dtype=self.dtype) + def copy_file_from_remote_volume(self, filepath): + in_volume_path = filepath.split('/', 2)[-1] + local_volume_path = Path(LOCAL_VOLUME_DIR) / in_volume_path + local_volume_path.parent.mkdir(parents=True, exist_ok=True) + if not local_volume_path.exists(): + shutil.copy(filepath, str(local_volume_path)) + + async def copy_file_from_s3(self, filepath): + bucket, in_bucket_path = filepath.replace('s3://','').split('/', 1) + await async_copy_from_s3(bucket, in_bucket_path, str(Path(VOLUME_DIR) / in_bucket_path)) + + async def copy_file_to_local(self, filepath): + if not filepath: + return + if filepath.startswith('s3://'): + await self.copy_file_from_s3(filepath) + else: + self.copy_file_from_remote_volume(filepath) + @method() - def generate(self, image1, video, textbox_in): + async def generate(self, image1, video, textbox_in): + remote_volume.reload() + local_volume.reload() + await self.copy_file_to_local(image1) + await self.copy_file_to_local(video) + from videollava.conversation import conv_templates from videollava.constants import DEFAULT_IMAGE_TOKEN if not textbox_in: @@ -97,12 +125,13 @@ def fastapi_app(): async def upload( file: UploadFile = File(...), ): + local_volume.reload() filename_decoded = urllib.parse.unquote(file.filename) - file_path = str(VIDEOS_DIR / filename_decoded) + file_path = str(Path(LOCAL_VOLUME_DIR) / filename_decoded) async with aiofiles.open(file_path, "wb") as buffer: while content := await file.read(1024): # Read chunks of 1024 bytes await buffer.write(content) - volume.commit() + local_volume.commit() return {"file_path": file_path} @app.post("/inference") From ba972fdacd462d894e64cc0d5ab016fb2a0c2d24 Mon Sep 17 00:00:00 2001 From: Ben Schreck Date: Wed, 6 Mar 2024 16:37:13 -0800 Subject: [PATCH 17/21] bump idle timeout --- image.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/image.py b/image.py index 77c1f7e..bad9277 100644 --- a/image.py +++ b/image.py @@ -84,6 +84,7 @@ def function_dec(**extras): # checkpointing doesn't work because it restricts internet access 
#checkpointing_enabled=True, # Enable memory checkpointing for faster cold starts. _allow_background_volume_commits=True, + container_idle_timeout=120, volumes=local_volumes, mounts=local_mounts, **extras, @@ -95,6 +96,7 @@ def cls_dec(**extras): timeout=80000, # checkpointing doesn't work because it restricts internet access #checkpointing_enabled=True, # Enable memory checkpointing for faster cold starts. + container_idle_timeout=120, _allow_background_volume_commits=True, volumes=local_volumes, mounts=local_mounts, From 496e6a8cc2e80c88b8681bbc6013195d917168d5 Mon Sep 17 00:00:00 2001 From: Ben Schreck Date: Thu, 7 Mar 2024 16:45:06 -0800 Subject: [PATCH 18/21] add caching --- image.py | 1 + modal_inference.py | 14 +++++++++++--- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/image.py b/image.py index bad9277..415c039 100644 --- a/image.py +++ b/image.py @@ -74,6 +74,7 @@ def remove_old_files(): .copy_local_file(LOCAL_CERT_PATH, CERT_PATH) .pip_install("boto3", "aioboto3") .env(EXTRA_ENV) + .pip_install("diskcache") ) # TODO bitsandbytes seems to not be working with gpu diff --git a/modal_inference.py b/modal_inference.py index fd115ee..1ad208e 100644 --- a/modal_inference.py +++ b/modal_inference.py @@ -6,6 +6,7 @@ from ai_video_editor.utils.fs_utils import async_copy_from_s3 from .image import LOCAL_VOLUME_DIR, MODEL_CACHE, cls_dec, function_dec, local_volume from ai_video_editor.stub import stub, S3_VIDEO_PATH, VOLUME_DIR, volume as remote_volume +import diskcache as dc from pathlib import Path # for local testing #S3_VIDEO_PATH= "s3_videos" @@ -19,7 +20,8 @@ @cls_dec(gpu="any") class VideoLlavaModel: @enter() - def load_model(self): + def load_model(self, cache=None): + self.cache = cache or dc.Cache('.cache') local_volume.reload() import torch from videollava.serve.gradio_utils import Chat @@ -30,7 +32,6 @@ def load_model(self): load_4bit = True self.dtype = torch.float16 self.handler = Chat(model_path, conv_mode=self.conv_mode, load_8bit=load_8bit, load_4bit=load_4bit, device=device, cache_dir=str(MODEL_CACHE)) - print("model loaded") # self.handler.model.to(dtype=self.dtype) def copy_file_from_remote_volume(self, filepath): @@ -53,7 +54,12 @@ async def copy_file_to_local(self, filepath): self.copy_file_from_remote_volume(filepath) @method() - async def generate(self, image1, video, textbox_in): + async def generate(self, image1, video, textbox_in, use_existing_output=True): + inputs = (image1, video, textbox_in) + if inputs in self.cache and use_existing_output: + res = self.cache[inputs] + self.cache.close() + return res remote_volume.reload() local_volume.reload() await self.copy_file_to_local(image1) @@ -106,6 +112,8 @@ async def generate(self, image1, video, textbox_in): text_en_out = text_en_out.split('#')[0] textbox_out = text_en_out + self.cache.set(inputs, textbox_out) + self.cache.close() return textbox_out From aa6d1438398de77c7a125ace78fda8bb084e2d3e Mon Sep 17 00:00:00 2001 From: Ben Schreck Date: Sat, 9 Mar 2024 12:34:18 -0700 Subject: [PATCH 19/21] fixes for more robust backend --- modal_inference.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modal_inference.py b/modal_inference.py index 1ad208e..448a085 100644 --- a/modal_inference.py +++ b/modal_inference.py @@ -20,8 +20,8 @@ @cls_dec(gpu="any") class VideoLlavaModel: @enter() - def load_model(self, cache=None): - self.cache = cache or dc.Cache('.cache') + def load_model(self): + self.cache = dc.Cache('.cache') local_volume.reload() import torch from 
videollava.serve.gradio_utils import Chat From cf9c6f4e5f229e40ee8f7992f3e712cf2696a772 Mon Sep 17 00:00:00 2001 From: Ben Schreck Date: Mon, 11 Mar 2024 11:24:35 -0700 Subject: [PATCH 20/21] increase idle timeout --- image.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/image.py b/image.py index 415c039..dbfff27 100644 --- a/image.py +++ b/image.py @@ -97,7 +97,9 @@ def cls_dec(**extras): timeout=80000, # checkpointing doesn't work because it restricts internet access #checkpointing_enabled=True, # Enable memory checkpointing for faster cold starts. - container_idle_timeout=120, + container_idle_timeout=1200, + # TODO maybe turn on + #allow_concurrent_inputs=2, _allow_background_volume_commits=True, volumes=local_volumes, mounts=local_mounts, From b93cc312a5283a807d2e0392aaf4954cac35d08e Mon Sep 17 00:00:00 2001 From: Ben Schreck Date: Thu, 14 Mar 2024 08:40:33 -0700 Subject: [PATCH 21/21] allow concurrent inputs, raise on no output generated, add retries --- image.py | 3 ++- modal_inference.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/image.py b/image.py index dbfff27..5ccc077 100644 --- a/image.py +++ b/image.py @@ -99,7 +99,8 @@ def cls_dec(**extras): #checkpointing_enabled=True, # Enable memory checkpointing for faster cold starts. container_idle_timeout=1200, # TODO maybe turn on - #allow_concurrent_inputs=2, + allow_concurrent_inputs=4, + retries=3, _allow_background_volume_commits=True, volumes=local_volumes, mounts=local_mounts, diff --git a/modal_inference.py b/modal_inference.py index 448a085..14891c7 100644 --- a/modal_inference.py +++ b/modal_inference.py @@ -106,12 +106,13 @@ async def generate(self, image1, video, textbox_in, use_existing_output=True): else: print("WARNING: No image or video supplied") - print(text_en_in) text_en_out, _ = self.handler.generate(images_tensor, text_en_in, first_run=True, state=state_) text_en_out = text_en_out.split('#')[0] textbox_out = text_en_out + if not textbox_out: + raise ValueError("no text generated") self.cache.set(inputs, textbox_out) self.cache.close() return textbox_out
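
Usage sketch (a minimal client, not part of the patches above): once the stub is deployed, the ASGI app returned by fastapi_app() serves the two endpoints these patches add, POST /upload and POST /inference. The endpoint paths and the file/video_file_path/prompt parameters come from the diffs; the base URL, local file name, and response handling below are assumptions for illustration only.

    import requests

    # Placeholder: substitute the URL Modal assigns to the deployed fastapi_app_modal web endpoint.
    BASE_URL = "https://example--video-llava-dev-fastapi-app-modal.modal.run"

    # 1) Upload a local video; the /upload handler streams it into the volume
    #    and returns the path the inference endpoint should be given.
    with open("clip.mp4", "rb") as f:
        upload_resp = requests.post(f"{BASE_URL}/upload", files={"file": ("clip.mp4", f)})
    upload_resp.raise_for_status()
    video_file_path = upload_resp.json()["file_path"]

    # 2) Ask the model about the uploaded video. The scalar parameters are plain
    #    query parameters, matching the signature of the /inference handler.
    infer_resp = requests.post(
        f"{BASE_URL}/inference",
        params={
            "video_file_path": video_file_path,
            "prompt": "describe what is going on in this video",
        },
    )
    infer_resp.raise_for_status()
    print(infer_resp.json())

Because generate() memoizes on (image, video, prompt) via diskcache, repeating the same request should return the cached answer unless use_existing_output is disabled on the Modal method call.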