In [1]:
# Stage everything under lora-model/: the "task-full" checkpoint goes to the
# directory root, the "task-lora" checkpoint goes to adapter_model/, and code/
# is created for the inference script written in a later cell.
!mkdir -p lora-model/code lora-model/adapter_model

!aws s3 cp --recursive s3://llava-ue1/llava-v15-7b-task-full-2024-02-23-01-55-53/checkpoints/ lora-model/
!aws s3 cp --recursive s3://llava-ue1/llava-v15-7b-task-lora-2024-02-17-07-33-50/checkpoints/ lora-model/adapter_model


download: s3://llava-ue1/llava-v15-7b-task-full-2024-02-23-01-55-53/checkpoints/generation_config.json to lora-model/generation_config.json
download: s3://llava-ue1/llava-v15-7b-task-full-2024-02-23-01-55-53/checkpoints/config.json to lora-model/config.json
download: s3://llava-ue1/llava-v15-7b-task-full-2024-02-23-01-55-53/checkpoints/special_tokens_map.json to lora-model/special_tokens_map.json
download: s3://llava-ue1/llava-v15-7b-task-full-2024-02-23-01-55-53/checkpoints/tokenizer_config.json to lora-model/tokenizer_config.json
download: s3://llava-ue1/llava-v15-7b-task-full-2024-02-23-01-55-53/checkpoints/model.safetensors.index.json to lora-model/model.safetensors.index.json
download: s3://llava-ue1/llava-v15-7b-task-full-2024-02-23-01-55-53/checkpoints/tokenizer.model to lora-model/tokenizer.model
download: s3://llava-ue1/llava-v15-7b-task-full-2024-02-23-01-55-53/checkpoints/trainer_state.json to lora-model/trainer_state.json
download: s3://llava-ue1/llava-v15-7b-task-full-2024

In [1]:
# Delete any archive left over from a previous run before the tar cell recreates it.
!rm -rf llava-lora.tar.gz

In [3]:
%%writefile lora-model/code/inference.py
import os
import requests
from PIL import Image
from io import BytesIO
import torch
from transformers import AutoTokenizer

from llava.model import LlavaLlamaForCausalLM
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, KeywordsStoppingCriteria

from llava.conversation import conv_templates, SeparatorStyle
from llava.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
)


def model_fn(model_dir):
    """Load the LLaVA model from ``model_dir`` and merge its LoRA adapter.

    Expected layout (as packaged by this notebook):
        model_dir/                 base checkpoint + tokenizer files
        model_dir/adapter_model/   LoRA adapter, its config, and optionally
                                   ``non_lora_trainables.bin``

    Args:
        model_dir: Directory the model archive was extracted to.

    Returns:
        A ``(model, tokenizer, image_processor)`` tuple; this is the object
        ``predict_fn`` receives as its second argument.

    Raises:
        FileNotFoundError: if ``model_dir/adapter_model`` does not exist.
    """
    adapter_path = os.path.join(model_dir, "adapter_model")
    # Fail early with a clear message instead of letting from_pretrained()
    # interpret a missing local path in some other way.
    if not os.path.isdir(adapter_path):
        raise FileNotFoundError(f"LoRA adapter directory not found: {adapter_path}")

    from llava.model.language_model.llava_llama import LlavaConfig
    from peft import PeftModel

    print("loading lora config")
    # The model config is read from the adapter directory, not from model_dir.
    lora_cfg_pretrained = LlavaConfig.from_pretrained(adapter_path)
    print("loading model")
    model = LlavaLlamaForCausalLM.from_pretrained(
        model_dir,
        low_cpu_mem_usage=True,
        config=lora_cfg_pretrained,
        device_map="auto",
        torch_dtype=torch.float16,
    )

    # If the head's declared output size disagrees with the rows of its weight
    # tensor, re-allocate the head and the input embeddings at the declared size.
    token_num, token_dim = model.lm_head.out_features, model.lm_head.in_features
    if model.lm_head.weight.shape[0] != token_num:
        model.lm_head.weight = torch.nn.Parameter(
            torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype)
        )
        model.model.embed_tokens.weight = torch.nn.Parameter(
            torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype)
        )

    # Weights trained alongside the LoRA adapter but stored outside it. The file
    # is optional: everything that touches it stays inside this branch so that a
    # missing file no longer triggers a NameError.
    non_lora_path = os.path.join(adapter_path, "non_lora_trainables.bin")
    if os.path.exists(non_lora_path):
        # NOTE: torch.load unpickles the file; only load trusted artifacts.
        non_lora_trainables = torch.load(non_lora_path, map_location="cpu")

        # Strip the wrapper prefixes so the keys match this model's state dict.
        outer_prefix = "base_model."
        non_lora_trainables = {
            (k[len(outer_prefix):] if k.startswith(outer_prefix) else k): v
            for k, v in non_lora_trainables.items()
        }
        if any(k.startswith("model.model.") for k in non_lora_trainables):
            inner_prefix = "model."
            non_lora_trainables = {
                (k[len(inner_prefix):] if k.startswith(inner_prefix) else k): v
                for k, v in non_lora_trainables.items()
            }
        # strict=False: the file only holds a subset of the model's parameters.
        model.load_state_dict(non_lora_trainables, strict=False)
    else:
        print("non_lora_trainables.bin not found; skipping non-LoRA weights")

    print('Loading LoRA weights...')
    model = PeftModel.from_pretrained(model, adapter_path)
    print('Merging LoRA weights...')
    model = model.merge_and_unload()
    print('Model is loaded...')

    print("loading tokenizer")
    tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=False)

    # Make sure the vision tower is loaded and on the GPU in fp16, matching the
    # image tensors predict_fn builds with .half().cuda().
    vision_tower = model.get_vision_tower()
    if not vision_tower.is_loaded:
        vision_tower.load_model()
    vision_tower.to(device="cuda", dtype=torch.float16)
    image_processor = vision_tower.image_processor
    return model, tokenizer, image_processor


def predict_fn(data, model_and_tokenizer):
    """Answer a question about a single image with the loaded LLaVA model.

    Args:
        data: Request dict. Required keys: ``"image"`` (http(s) URL or local
            file path) and ``"question"`` (prompt text). Optional keys:
            ``"max_new_tokens"`` (default 1024), ``"temperature"`` (default
            0.2; a value <= 0 selects greedy decoding) and ``"conv_mode"``
            (default ``"llava_v1"``; ``"raw"`` uses the question verbatim as
            the prompt). The dict is not modified.
        model_and_tokenizer: The ``(model, tokenizer, image_processor)`` tuple
            returned by ``model_fn``.

    Returns:
        The decoded completion as a stripped string, with a trailing stop
        string removed if present.

    Raises:
        ValueError: if ``data`` is not a dict, or ``"image"`` / ``"question"``
            is missing, empty, or not a string.
        requests.HTTPError: if downloading the image returns an error status.
    """
    # unpack model and tokenizer
    model, tokenizer, image_processor = model_and_tokenizer

    # Validate the request up front so a bad payload produces a clear error
    # instead of an AttributeError deep inside image loading.
    if not isinstance(data, dict):
        raise ValueError("request payload must be a JSON object")
    image_file = data.get("image")
    raw_prompt = data.get("question")
    if not isinstance(image_file, str) or not image_file:
        raise ValueError('request must contain a non-empty string "image" (URL or file path)')
    if not isinstance(raw_prompt, str) or not raw_prompt:
        raise ValueError('request must contain a non-empty string "question"')

    max_new_tokens = int(data.get("max_new_tokens", 1024))
    temperature = float(data.get("temperature", 0.2))
    conv_mode = data.get("conv_mode", "llava_v1")

    if conv_mode == "raw":
        # use raw_prompt as prompt
        prompt = raw_prompt
        stop_str = "###"
    else:
        conv = conv_templates[conv_mode].copy()
        roles = conv.roles
        # NOTE(review): the role label is embedded in the message text and is
        # also passed as the message role below; confirm get_prompt() does not
        # render it twice. Kept as-is to preserve the existing prompt format.
        inp = f"{roles[0]}: {raw_prompt}"
        inp = (
            DEFAULT_IM_START_TOKEN
            + DEFAULT_IMAGE_TOKEN
            + DEFAULT_IM_END_TOKEN
            + "\n"
            + inp
        )
        conv.append_message(conv.roles[0], inp)
        # A None message leaves the second role's turn open for generation.
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()
        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2

    # SECURITY NOTE: the URL / path comes straight from the request, so callers
    # can make this process fetch arbitrary URLs or open arbitrary local files.
    # Restrict the sources here if the endpoint is exposed to untrusted clients.
    if image_file.startswith(("http://", "https://")):
        fetch_timeout_s = 30  # avoid hanging a worker on an unresponsive host
        response = requests.get(image_file, timeout=fetch_timeout_s)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content)).convert("RGB")
    else:
        image = Image.open(image_file).convert("RGB")

    disable_torch_init()
    image_tensor = (
        image_processor.preprocess(image, return_tensors="pt")["pixel_values"]
        .half()
        .cuda()
    )

    keywords = [stop_str]
    input_ids = (
        tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
        .unsqueeze(0)
        .cuda()
    )
    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)

    # Sampling requires a strictly positive temperature; fall back to greedy
    # decoding otherwise and do not forward the temperature at all.
    generate_kwargs = {
        "images": image_tensor,
        "do_sample": temperature > 0,
        "max_new_tokens": max_new_tokens,
        "use_cache": True,
        "stopping_criteria": [stopping_criteria],
    }
    if generate_kwargs["do_sample"]:
        generate_kwargs["temperature"] = temperature

    with torch.inference_mode():
        output_ids = model.generate(input_ids, **generate_kwargs)

    # Decode only the tokens after the prompt.
    outputs = tokenizer.decode(
        output_ids[0, input_ids.shape[1] :], skip_special_tokens=True
    ).strip()
    # Generation stops once the stop string has been produced, so trim it.
    if stop_str and outputs.endswith(stop_str):
        outputs = outputs[: -len(stop_str)].strip()
    return outputs

Overwriting lora-model/code/inference.py


In [3]:
%%writefile lora-model/code/requirements.txt
# Extra packages for code/inference.py; LLaVA is pinned to the v1.1.1 tag.
# NOTE(review): inference.py also imports peft, PIL, requests, torch and transformers;
# confirm they are provided by this package or by the serving image.
llava @ git+https://github.com/haotian-liu/LLaVA@v1.1.1

Writing lora-model/code/requirements.txt


In [11]:
# Pack the contents of lora-model/ (base weights, tokenizer files, adapter_model/, code/)
# into a pigz-compressed archive. -C makes the archive paths relative to lora-model/.
# NOTE(review): training_args.bin and trainer_state.json are archived too, although
# inference.py does not read them.
!tar -cvf llava-lora.tar.gz --use-compress-program=pigz -C lora-model .


./
./special_tokens_map.json
./generation_config.json
./code/
./code/requirements.txt
./code/inference.py
./config.json
./training_args.bin
./tokenizer_config.json
./adapter_model/
./adapter_model/adapter_model.safetensors


./adapter_model/non_lora_trainables.bin
./adapter_model/config.json
./adapter_model/README.md
./adapter_model/adapter_config.json
./adapter_model/trainer_state.json
./tokenizer.model
./model.safetensors.index.json
./model-00003-of-00003.safetensors
./model-00002-of-00003.safetensors
./trainer_state.json
./model-00001-of-00003.safetensors


In [12]:
# Upload the archive to S3; the deployment cell points model_data at this object.
!aws s3 cp llava-lora.tar.gz s3://llava-ue1/

upload: ./llava-lora.tar.gz to s3://llava-ue1/llava-lora.tar.gz     


In [None]:
import sagemaker
from sagemaker.huggingface.model import HuggingFaceModel

# Execution role (needs permission to create an endpoint) and a SageMaker session.
role = sagemaker.get_execution_role()
sess = sagemaker.Session()

# Model definition: the S3 archive holds both the weights and code/inference.py.
huggingface_model = HuggingFaceModel(
    role=role,
    model_data="s3://llava-ue1/llava-lora.tar.gz",
    py_version='py310',
    pytorch_version="2.0.0",
    transformers_version="4.28.1",
    model_server_workers=1,
)

# Deploy to a single instance; both timeouts are raised to 600 because the model is large.
predictor = huggingface_model.deploy(
    instance_type="ml.g5.xlarge",
    initial_instance_count=1,
    model_data_download_timeout=600,
    container_startup_health_check_timeout=600,
)

In [None]:
# Example request for predict_fn. "image" (URL or path) and "question" are the
# keys it expects; it also reads the optional keys "max_new_tokens" (default 1024),
# "temperature" (default 0.2) and "conv_mode" (default "llava_v1").
data = dict(
    image='https://raw.githubusercontent.com/haotian-liu/LLaVA/main/images/llava_logo.png',
    question="Describe the image and color details.",
)

# Send the request to the deployed endpoint and show the answer.
output = predictor.predict(data)
print(output)