In [None]:
from tinyllava.eval.run_tiny_llava import *

## Configuration
## Set model_path to the model you want to run, and set conv_mode to the
## conversation type that matches the model size:
##   0.55B, 0.89B -> llama
##   2.4B         -> gemma
##   3.1B         -> phi
model_path = "<MODEL_PATH>"
conv_mode = "llama"  # one of: llama, gemma, phi

## Generation settings, exposed as attributes of a throwaway Args instance.
## Change "max_new_tokens" if the model can't deal with long tokens.
## possible values: 1024, 1152, 1536, 2048, 3072
args = type(
    "Args",
    (),
    dict(
        model_path=model_path,
        model_base=None,
        conv_mode=conv_mode,
        sep=",",
        temperature=0,
        top_p=None,
        num_beams=1,
        max_new_tokens=2048,
    ),
)()

# Model
disable_torch_init()

if args.model_path is not None:
    # Normal path: load model, tokenizer and image processor from model_path.
    model, tokenizer, image_processor, context_len = load_pretrained_model(args.model_path)
else:
    # Fallback: use an already-instantiated model supplied as args.model.
    # getattr is used because Args may not define a `model` attribute at all;
    # a plain `args.model` would then raise AttributeError and hide the
    # intended error message below.
    model = getattr(args, "model", None)
    if model is None:
        raise ValueError('model_path or model must be provided')
    # Use the model's own max_sequence_length when it declares one, else 2048.
    context_len = getattr(model.config, "max_sequence_length", 2048)
    tokenizer = model.tokenizer
    image_processor = model.vision_tower._image_processor


# Build the text / image preprocessing helpers used by later cells.
# Note: `image_processor` is rebound here from the raw processor to the
# ImagePreprocess wrapper constructed from it.
text_processor = TextPreprocess(tokenizer, args.conv_mode)
data_args = model.config
image_processor = ImagePreprocess(image_processor, data_args)
model.cuda()

In [None]:
import os

def ensure_dir(path):
    """
    Create the directory ``path``, including any missing parents, if it does
    not already exist.

    :param path: directory path to create
    :return: None
    :raises FileExistsError: if ``path`` exists but is not a directory
    """
    # exist_ok=True replaces the separate "exists?" check, so there is no
    # window in which another process can create the directory between the
    # check and the creation and make makedirs fail.
    os.makedirs(path, exist_ok=True)


import signal

class TimeoutException(Exception):
    """Raised to signal that a guarded call ran longer than its time limit."""

def handler(signum, frame):
    """
    Signal handler that aborts the running call by raising TimeoutException.

    :param signum: number of the delivered signal (not used)
    :param frame: stack frame that was active when the signal arrived (not used)
    :raises TimeoutException: always
    """
    raise TimeoutException

# Time limit in whole seconds for each call of a function wrapped by
# timeout_decorator; the wrapper reads this value every time it is called.
timeout = 300

def timeout_decorator(func):
    """
    Wrap ``func`` so that each call is aborted after ``timeout`` seconds.

    The wrapper installs ``handler`` for SIGALRM, arms an alarm for the
    module-level ``timeout`` (read at call time), and runs ``func``. On exit,
    whether normal or exceptional, the alarm is cancelled and the previously
    installed SIGALRM handler is restored.

    SIGALRM is only available on Unix, and ``signal.signal`` may only be
    called from the main thread, so wrapped functions must be called there.

    :param func: callable to guard
    :return: wrapper with the same call signature and metadata as ``func``
    :raises TimeoutException: from the wrapper, when the call times out
    """
    import functools  # local import: keeps the block self-contained

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        previous_handler = signal.signal(signal.SIGALRM, handler)
        signal.alarm(timeout)
        try:
            return func(*args, **kwargs)
        except TimeoutException:
            print("Function timed out!")
            # Re-raise the original exception so its traceback is preserved.
            raise
        finally:
            # Cancel any pending alarm first so it cannot fire after the
            # handler has been swapped back.
            signal.alarm(0)
            # signal.signal() returns None when the previous handler was not
            # installed from Python; None cannot be passed back in.
            if previous_handler is not None:
                signal.signal(signal.SIGALRM, previous_handler)
    return wrapper

In [None]:
@timeout_decorator
def process_image(qs, path):
    """
    Generate the model's text output for one image and one question.

    Uses the notebook-level ``text_processor``, ``image_processor``,
    ``tokenizer``, ``model`` and ``args`` objects.

    :param qs: question / instruction text; the image placeholder token and a
        newline are prepended to it
    :param path: path of the image file to load
    :return: decoded model output, stripped of surrounding whitespace and of a
        trailing stop string
    """
    qs = DEFAULT_IMAGE_TOKEN + "\n" + qs

    msg = Message()
    msg.add_message(qs)

    # Tokenize the conversation and add a leading batch dimension.
    result = text_processor(msg.messages, mode='eval')
    input_ids = result['input_ids'].unsqueeze(0).cuda()

    # Load and preprocess the single image, then add a leading batch dimension.
    images = load_images([path])[0]
    images_tensor = image_processor(images)
    images_tensor = images_tensor.unsqueeze(0).half().cuda()

    stop_str = text_processor.template.separator.apply()[1]
    stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=images_tensor,
            # Sample only when a positive temperature is configured.
            do_sample=args.temperature > 0,
            temperature=args.temperature,
            top_p=args.top_p,
            num_beams=args.num_beams,
            pad_token_id=tokenizer.pad_token_id,
            max_new_tokens=args.max_new_tokens,
            use_cache=True,
            stopping_criteria=[stopping_criteria],
        )

    outputs = tokenizer.batch_decode(
        output_ids, skip_special_tokens=True
    )[0]
    outputs = outputs.strip()
    # The `stop_str and` guard matters: with an empty stop string,
    # endswith("") is always True and outputs[:-0] == outputs[:0] == "",
    # which would discard the whole generation.
    if stop_str and outputs.endswith(stop_str):
        outputs = outputs[: -len(stop_str)]
    outputs = outputs.strip()
    return outputs

import os
import glob
import traceback
# Collects "<type>: <name>" for every image whose generation failed.
errors = []

ts = ["Default", "Transparent", "Orthographic"]
# Set src_b to the path of the "Input Pictures and Reference Codes" directory.
# Images are read from src_b/<type>/*.jpg and results are written to
# out_b/<type>/<name>.py, for each <type> in ts.
src_b = "<PATH of Input Pictures and Reference Codes>"
out_b = "<ANY OUTPUT PATH YOU LIKE>"
for tt in ts:
    # os.path.join works whether or not the base paths end with a separator.
    src = os.path.join(src_b, tt)
    out = os.path.join(out_b, tt)
    ensure_dir(out)
    out_paths = sorted(glob.glob(os.path.join(src, "*.{}".format("jpg"))))
    if tt == "Orthographic":
        qs = "This image is 4 views of a 3D model from certain angles. Please try to use Python-style APIs to render this model."
    else:
        qs = "This image is a view of a 3D model from a certain angle. Please try to use Python-style APIs to render this model."

    for i, path in enumerate(out_paths):
        print(f"{tt}: {i + 1}/{len(out_paths)}", end='\r')
        # Strip only the directory and the final extension, so names that
        # contain dots stay distinct and the path separator is not assumed.
        name = os.path.splitext(os.path.basename(path))[0]
        save_path = os.path.join(out, f'{name}.py')
        # Skip images that already have an output file, so the cell can be re-run.
        if os.path.isfile(save_path):
            continue
        try:
            outputs = process_image(qs, path)
            with open(save_path, 'w', encoding='utf-8') as file:
                file.write(outputs)
        except Exception:
            # Record the failure (including timeouts) and keep going.
            # KeyboardInterrupt is not caught, so the run can still be stopped.
            errors.append(f"{tt}: {name}")
            print(f"gen error: {name}")
            traceback.print_exc()
    print()

print(errors)
