# ILLUME+ Model Inference

This notebook demonstrates how to use the ILLUME model for three main tasks:
1. **Image Understanding:** Given an image and a text prompt (e.g., a question), the model generates a textual response.
2. **Image Generation:** Given a text prompt, the model generates an image.
3. **Image Editing:** Given a source image and a text instruction, the model generates the edited image.

**Important:** Download the model checkpoints and place them in the folders described in Section 1. Otherwise, you need to modify the model paths in Section 3.

## 1. Setup and Imports

In [None]:
# If you haven't downloaded the checkpoints yet, uncomment the code below to download them.

# from huggingface_hub import snapshot_download
# import os


# if not os.path.exists('../checkpoints/'):
#     os.makedirs('../checkpoints/')

# save_dir="../checkpoints/illume_plus-qwen2_5-3b"
# snapshot_download(local_dir=save_dir, 
#                   repo_id='ILLUME-MLLM/illume_plus-qwen2_5-3b', 
#                   local_dir_use_symlinks=False, 
#                   resume_download=True)
# os.makedirs('./logdir/illume_plus_3b/', exist_ok=True)
# os.symlink(os.path.abspath(save_dir), os.path.abspath('./logdir/illume_plus_3b/illume_plus-qwen2_5-3b_stage3'))

# DUALVITOK_CHECKPOINT_DIR="../checkpoints/dualvitok/"
# snapshot_download(local_dir=DUALVITOK_CHECKPOINT_DIR, 
#                   repo_id='ILLUME-MLLM/dualvitok', 
#                   local_dir_use_symlinks=False, 
#                   resume_download=True)

# DUALVITOK_SDXL_DECODER_DIR="../checkpoints/dualvitok-sdxl-decoder/"
# snapshot_download(local_dir=DUALVITOK_SDXL_DECODER_DIR, 
#                   repo_id='ILLUME-MLLM/dualvitok-sdxl-decoder', 
#                   local_dir_use_symlinks=False, 
#                   resume_download=True)


In [None]:
# Link the downloaded checkpoint directory to the path used by the MLLM config.
import os

checkpoint_dir = os.path.abspath('../checkpoints/illume_plus-qwen2_5-3b')
# abspath() also strips the trailing slash, which matters for lexists() below:
# a path ending in '/' is resolved through the link and would hide a dangling symlink.
stage3_link = os.path.abspath('./logdir/illume_plus_3b/illume_plus-qwen2_5-3b_stage3/')

os.makedirs('./logdir/illume_plus_3b/', exist_ok=True)
# Use lexists() rather than exists(): exists() follows the link and returns False
# for a dangling symlink (e.g. the cell was run before the checkpoint was
# downloaded), which made os.symlink() raise FileExistsError on a re-run.
if not os.path.lexists(stage3_link):
    os.symlink(checkpoint_dir, stage3_link)

In [None]:
import sys
import os
# Append the sibling ``../vision_tokenizer/`` directory (resolved to an absolute
# path relative to the current working directory) to the module search path.
# NOTE(review): presumably required so the vision tokenizer code can be imported
# by the project modules loaded below — confirm.
sys.path.append(os.path.abspath(os.path.join(os.path.abspath('.'), '../vision_tokenizer/')))

import argparse
import traceback
import logging
from functools import partial
import numpy as np

from PIL import Image
import re  # Added for parsing image tokens
from typing import List, Tuple
import matplotlib.pyplot as plt

# --- Add necessary imports from your ILLUME codebase ---
import torch

from transformers import LogitsProcessorList, TextIteratorStreamer

from generation_eval.models.builder import build_eval_model

from illume.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from illume.conversation import conv_templates, default_conversation  # Import Conversation class
from illume.mm_utils import process_images, tokenizer_image_token
from illume.data.data_utils import unpad_and_resize_back

# Logging setup
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

## 2. Helper Functions

In [None]:
def pad_sequence(tokenizer, input_ids_list: List[torch.Tensor], batch_first: bool, padding_value: int) -> torch.Tensor:
    """Pad variable-length token tensors into a single batch tensor.

    The padding side follows ``tokenizer.padding_side``. For ``'left'`` each
    sequence is reversed, right-padded with ``torch.nn.utils.rnn.pad_sequence``
    and the padded result is reversed back along the sequence axis, so the
    padding values end up in front of the tokens.

    Args:
        tokenizer: Any object exposing a ``padding_side`` attribute.
        input_ids_list: Tensors whose first dimension is the sequence length.
        batch_first: If True the result is laid out ``(batch, seq, ...)``,
            otherwise ``(seq, batch, ...)``.
        padding_value: Value used to fill the padded positions.

    Returns:
        The padded tensor.
    """
    pad_left = tokenizer.padding_side == 'left'

    if pad_left:
        # Reverse every sequence so that right-padding below becomes left-padding
        # once the result is reversed back.
        input_ids_list = [torch.flip(ids, [0]) for ids in input_ids_list]

    padded = torch.nn.utils.rnn.pad_sequence(
        input_ids_list, batch_first=batch_first, padding_value=padding_value
    )

    if pad_left:
        # The sequence axis is 1 only when batch_first is True; with
        # batch_first=False it is axis 0 (flipping axis 1 would reorder the batch).
        seq_dim = 1 if batch_first else 0
        padded = torch.flip(padded, [seq_dim])
    return padded


def show_image(image, short_side=256):
    """Display an image inline, scaled so its shorter side equals ``short_side``.

    Args:
        image: A file path (str), a numpy array, or an object exposing
            ``.size`` and ``.resize`` (e.g. a PIL image).
        short_side: Target length in pixels of the shorter image side; the
            other side is scaled by the same factor and truncated to an int.
    """
    if isinstance(image, str):
        image = Image.open(image)
    elif isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    width, height = image.size
    if width < height:
        # Portrait: the width is the shorter side.
        target_size = (short_side, int(height * short_side / width))
    else:
        # Landscape or square: the height is the shorter side.
        target_size = (int(width * short_side / height), short_side)

    plt.imshow(image.resize(target_size))
    plt.axis('off')
    plt.show()
    

def convert_np_to_pil_img(samples, batch_data):
    """Convert decoded image arrays to PIL images, one per sample.

    Relies on the notebook-level ``inference_engine`` being defined before
    this function is called.

    Args:
        samples: Iterable of numpy arrays; each is cast to ``uint8``.
        batch_data: Iterable of per-sample dicts, paired with ``samples`` by
            position. When a dict has an ``"original_sizes"`` entry, its first
            two elements are passed to ``inference_engine.unpad_and_resize_back``.

    Returns:
        A list of PIL images in the same order as ``samples``.
    """
    pil_images = []
    for array, meta in zip(samples, batch_data):
        pil_image = Image.fromarray(array.astype(np.uint8))

        # Editing task: undo the padding and restore the source image size.
        if "original_sizes" in meta:
            size = meta["original_sizes"]
            pil_image = inference_engine.unpad_and_resize_back(pil_image, size[0], size[1])

        pil_images.append(pil_image)
    return pil_images


## 3. Model Loading

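If your machine has 3 or more GPUs, the MLLM, the vision tokenizer, and the diffusion decoder are each placed on a separate GPU. With 2 GPUs, the tokenizer and the diffusion decoder share the second GPU; with a single GPU (or no GPU), all three run on the same device.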

In [None]:
# --- Paths and build options ---
model_name = 'ILLUME'

mllm_config_path = "../configs/example/illume_plus_3b/illume_plus_qwen2_5_3b_stage3.py"
tokenizer_config_path = "../configs/example/dualvitok/dualvitok_anyres_max512.py"
vq_tokenizer_ckpt_path = "../checkpoints/dualvitok/pytorch_model.bin"
diffusion_decoder_path = "../checkpoints/dualvitok-sdxl-decoder/"
torch_dtype = 'fp16'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
local_rank = 0 if 'cuda' in device else -1

# --- Build the inference engine ---
eval_model_cfg = {
    'type': model_name,
    'config': mllm_config_path,
    'tokenizer_config': tokenizer_config_path,
    'diffusion_decoder_path': diffusion_decoder_path,
    'tokenizer_checkpoint': vq_tokenizer_ckpt_path,
    'torch_dtype': torch_dtype,
}
logging.info(f'Building ILLUME model with config: {eval_model_cfg}')
inference_engine = build_eval_model(eval_model_cfg)

# --- Device assignment: spread the three components over the available GPUs ---
num_gpus = torch.cuda.device_count()
if num_gpus >= 3:
    # One GPU per component.
    mllm_device = torch.device('cuda:0')
    vq_device = torch.device('cuda:1')
    diffusion_device = torch.device('cuda:2')
elif num_gpus == 2:
    # The tokenizer and the diffusion decoder share the second GPU.
    mllm_device = torch.device('cuda:0')
    vq_device = torch.device('cuda:1')
    diffusion_device = torch.device('cuda:1')
elif num_gpus == 1:
    # Everything on the single GPU.
    mllm_device = torch.device('cuda:0')
    vq_device = torch.device('cuda:0')
    diffusion_device = torch.device('cuda:0')
else:
    # No GPU visible: fall back to CPU.
    mllm_device = torch.device('cpu')
    vq_device = torch.device('cpu')
    diffusion_device = torch.device('cpu')
logging.info(f'MLLM: {mllm_device}, VQ: {vq_device}, Diffusion: {diffusion_device}')

# Move each component only if the engine has it and it is truthy.
if getattr(inference_engine, 'mllm_model', None):
    inference_engine.mllm_model.to(mllm_device)
if getattr(inference_engine, 'vq_model', None):
    inference_engine.vq_model.to(vq_device)
if getattr(inference_engine, 'diffusion_decoder_pipe', None):
    inference_engine.diffusion_decoder_pipe.to(diffusion_device)

# Record the placement on the engine itself.
inference_engine.device = device  # Overall device
inference_engine.mllm_device = mllm_device
inference_engine.vq_device = vq_device
inference_engine.diffusion_device = diffusion_device

## 4. Image Understanding

In [None]:
# Load the example image and define the question to ask about it.
prompt = 'depict the image in short'
image_path = '../configs/data_configs/test_data_examples/ImageUnderstandingExample/images/0.png'
input_image = Image.open(image_path).convert('RGB')

# Sampling settings for the text output.
inference_config = inference_engine.prepare_inference_config(temperature=1.0, top_k=50, top_p=1.0)

# One dict per request: the prompt plus the images it refers to.
batch_data = [{'prompt': prompt, 'images_data': [input_image]}]

outputs = inference_engine.inference_mllm(
    batch_data,
    inference_config,
    is_img_gen_task=False,  # must be False for image understanding
    do_sample=False,  # further keyword arguments for model.generate can be added here
)

# `outputs` is a list; each element is a dict with the keys
# 'image_embed_inds', 'output_text', 'image_sizes', 'original_sizes'.

show_image(input_image)
print(f"Question: {prompt}")
print(f"Answer:  {outputs[0]['output_text']}")

## 5. Image Generation

In [None]:
# Build the conditional and unconditional prompts for the requested resolution.
content = 'a cat with a hat.'
target_resolution = (512, 512)
resolution_tag = inference_engine.get_resolution_tag_from_resolution(target_resolution)

prompt = inference_engine.default_generation_template.format(
    resolution_tag=resolution_tag,
    content=content,
)
unconditional_prompt = inference_engine.default_generation_unconditional_template.format(
    resolution_tag=resolution_tag,
)

print(f"prompt: {prompt}")
print(f"unconditional prompt: {unconditional_prompt}")

inference_config = inference_engine.prepare_inference_config(
    # sampling settings
    temperature=1.0,
    top_k=128,
    top_p=1.0,
    # guidance scale and image-semantic sampling settings
    llm_cfg_scale=2.0,
    image_semantic_temperature=1.0,
    image_semantic_top_k=2024,
    image_semantic_top_p=1.0,
    # output resolution and the prompt used for the unconditional branch
    resolution=target_resolution,
    unconditional_prompt=unconditional_prompt,
)

# Text-to-image: the request carries only a prompt, no input images.
batch_data = [{'prompt': prompt}]
outputs = inference_engine.inference_mllm(
    batch_data,
    inference_config,
    is_img_gen_task=True,
)

In [None]:
# Rebuild the inference config with the same settings as above, plus the
# diffusion decoder options (diffusion_cfg_scale, diffusion_num_inference_steps).
inference_config = inference_engine.prepare_inference_config(
    # sampling settings
    temperature=1.0,
    top_k=128,
    top_p=1.0,
    # guidance scale and image-semantic sampling settings
    llm_cfg_scale=2.0,
    image_semantic_temperature=1.0,
    image_semantic_top_k=2024,
    image_semantic_top_p=1.0,
    # diffusion decoder settings
    diffusion_cfg_scale=1.5,
    diffusion_num_inference_steps=50,
    # output resolution and the prompt used for the unconditional branch
    resolution=target_resolution,
    unconditional_prompt=unconditional_prompt,
)

In [None]:
# Decode the generated tokens with the VQ tokenizer (no diffusion decoder).
out_images = inference_engine.inference_tokenizer_decoder(
    outputs,
    inference_config,
    use_diffusion_decoder=False,
)
print(f"Image prompt: {content}")
generated_image = convert_np_to_pil_img(out_images, outputs)[0]
print(f'Generated Image Size: {generated_image.size}')
generated_image

In [None]:
# Decode the same generated tokens with the SDXL diffusion decoder.
out_images = inference_engine.inference_tokenizer_decoder(
    outputs,
    inference_config,
    use_diffusion_decoder=True,
)
generated_image = convert_np_to_pil_img(out_images, outputs)[0]
print(f'Generated Image Size: {generated_image.size}')
generated_image

## 6. Image Editing

In [None]:
# Editing instruction and the source image it applies to.
instruction = 'Change the color of the boots to a deep forest green.'
image_path = '../configs/data_configs/test_data_examples/EditingSingleTurnExample/images/0.jpg'

# Build the conditional and unconditional prompts; the resolution tag is left empty.
prompt = inference_engine.default_editing_template.format(
    resolution_tag='',
    content=instruction,
)
unconditional_prompt = inference_engine.default_editing_unconditional_template.format(
    resolution_tag='',
)
print(f"prompt: {prompt}")
print(f"unconditional prompt: {unconditional_prompt}")

input_image = Image.open(image_path).convert('RGB')
original_image_size = input_image.size
show_image(input_image)

inference_config = inference_engine.prepare_inference_config(
    # sampling settings
    temperature=1.0,
    top_k=128,
    top_p=1.0,
    # guidance scale and diffusion decoder settings
    llm_cfg_scale=1.5,
    diffusion_cfg_scale=1.5,
    diffusion_num_inference_steps=50,
    # image-semantic sampling settings
    image_semantic_temperature=0.7,
    image_semantic_top_k=512,
    image_semantic_top_p=0.8,
    unconditional_prompt=unconditional_prompt,
    # resolution=(512, 512)  # not needed: the resolution is obtained from the source image within the code.
)

# One request: the editing prompt together with its source image.
batch_data = [{'prompt': prompt, 'images_data': [input_image]}]

outputs = inference_engine.inference_mllm(
    batch_data,
    inference_config,
    is_img_gen_task=True,
)

In [None]:
# Decode the edited image tokens with the VQ tokenizer (no diffusion decoder).
out_images = inference_engine.inference_tokenizer_decoder(
    outputs,
    inference_config,
    use_diffusion_decoder=False,
)
padded_image = convert_np_to_pil_img(out_images, outputs)[0]
# NOTE(review): convert_np_to_pil_img already calls unpad_and_resize_back when an
# output has "original_sizes", so this second call may be redundant — confirm.
generated_image = inference_engine.unpad_and_resize_back(
    padded_image,
    outputs[0]['original_sizes'][0],
    outputs[0]['original_sizes'][1],
)
generated_image

In [None]:
# Decode the edited image tokens with the SDXL diffusion decoder.
out_images = inference_engine.inference_tokenizer_decoder(
    outputs,
    inference_config,
    use_diffusion_decoder=True,
)
padded_image = convert_np_to_pil_img(out_images, outputs)[0]
# NOTE(review): convert_np_to_pil_img already calls unpad_and_resize_back when an
# output has "original_sizes", so this second call may be redundant — confirm.
generated_image = inference_engine.unpad_and_resize_back(
    padded_image,
    outputs[0]['original_sizes'][0],
    outputs[0]['original_sizes'][1],
)
generated_image