In [None]:
# SPDX-License-Identifier: Apache-2.0
# Load IPython's autoreload extension so edits to imported project modules
# are picked up without restarting the kernel.
%load_ext autoreload
# Mode 2: reload modules automatically before each cell is executed.
%autoreload 2

In [None]:
import os
from copy import deepcopy
from typing import (
    Any,
    AsyncIterable,
    Callable,
    Dict,
    Generator,
    List,
    NamedTuple,
    Optional,
    Tuple,
    Union,
)
import requests
from io import BytesIO

from PIL import Image
import torch
from accelerate import infer_auto_device_map, load_checkpoint_and_dispatch, init_empty_weights

from data.transforms import ImageTransform
from data.data_utils import pil_img2rgb, add_special_tokens
from modeling.bagel import (
    BagelConfig, Bagel, Qwen2Config, Qwen2ForCausalLM, SiglipVisionConfig, SiglipVisionModel
)
from modeling.qwen2 import Qwen2Tokenizer
from modeling.bagel.qwen2_navit import NaiveCache
from modeling.autoencoder import load_ae
from safetensors.torch import load_file

## Model Initialization

In [None]:
# Directory expected to contain llm_config.json, vit_config.json and
# ae.safetensors (all read below). Replace the placeholder with a real path.
model_path = "your_model_path"  
# LLM config preparing
# Load the language-model config from disk, then override three fields on the
# loaded object before any model is constructed from it.
llm_config = Qwen2Config.from_json_file(os.path.join(model_path, "llm_config.json"))
llm_config.qk_norm = True
llm_config.tie_word_embeddings = False
llm_config.layer_module = "Qwen2MoTDecoderLayer"

# ViT config preparing
# Load the vision config, set its `rope` flag to False and reduce the number of
# hidden layers by one relative to the value stored in the JSON file.
vit_config = SiglipVisionConfig.from_json_file(os.path.join(model_path, "vit_config.json"))
vit_config.rope = False
vit_config.num_hidden_layers = vit_config.num_hidden_layers - 1

# VAE loading
# load_ae returns both the autoencoder and its config; the config is passed to
# BagelConfig below, the model itself is handed to the inferencer later on.
vae_model, vae_config = load_ae(local_path=os.path.join(model_path, "ae.safetensors"))

# Bagel config preparing
# Combine the three sub-configs into the top-level config with both
# `visual_gen` and `visual_und` switched on.
# NOTE(review): the numeric values (70, 2, 64) are taken as given; their valid
# ranges are defined by BagelConfig, which is not visible here.
config = BagelConfig(
    visual_gen=True,
    visual_und=True,
    llm_config=llm_config, 
    vit_config=vit_config,
    vae_config=vae_config,
    vit_max_num_patch_per_side=70,
    connector_act='gelu_pytorch_tanh',
    latent_patch_size=2,
    max_latent_size=64,
)

# Mixture-of-experts hyper-parameters. They are attached to `llm_config` below
# so that the language model built from that config can read them.
num_experts = 32
num_shared_experts = 0
top_k = 16

# Fraction of experts that are shared; also passed as `shared_ratio` to the
# dense-to-sparse conversion later in the notebook.
share_ratio = num_shared_experts / (num_shared_experts + num_experts)
# Read the MLP width from the loaded config instead of hard-coding it, so this
# cell agrees with the conversion cell, which also uses
# llm_config.intermediate_size.
intermediate_size = llm_config.intermediate_size
share_size = int(share_ratio * intermediate_size)

# Plain attribute assignment is equivalent to setattr() with a literal name.
llm_config.num_experts = num_experts
llm_config.num_shared_experts = num_shared_experts
llm_config.top_k = top_k

# Build the model structure inside accelerate's init_empty_weights() context;
# the actual checkpoint weights are loaded later by
# load_checkpoint_and_dispatch in the model-loading cell.
with init_empty_weights():
    language_model = Qwen2ForCausalLM(llm_config)
    vit_model      = SiglipVisionModel(vit_config)
    model          = Bagel(language_model, vit_model, config)
    # Swap the ViT patch-embedding conv for a linear layer; `meta=True` is
    # passed because we are still inside the empty-weights context
    # (presumably so no real tensors are allocated -- confirm in the helper).
    model.vit_model.vision_model.embeddings.convert_conv2d_to_linear(vit_config, meta=True)

# Tokenizer Preparing
# Load the tokenizer from the model directory, then register the extra special
# tokens; `new_token_ids` is passed to the inferencer later, the third return
# value is discarded.
tokenizer = Qwen2Tokenizer.from_pretrained(model_path)
tokenizer, new_token_ids, _ = add_special_tokens(tokenizer)

# Image Transform Preparing
# One transform for the VAE path and one for the ViT path.
# NOTE(review): the positional arguments look like (max size, min size, stride)
# -- confirm against ImageTransform's signature.
vae_transform = ImageTransform(1024, 512, 16)
vit_transform = ImageTransform(980, 224, 14)

## Model Loading and Multi GPU Inference Preparing

In [None]:
# Per-GPU memory budget handed to infer_auto_device_map below.
# Modify it according to your GPU setting. On an A100, 80 GiB is sufficient to load on a single GPU.
max_mem_per_gpu = "50GiB"

# Plan where each sub-module lives, giving every visible GPU the same budget
# and never splitting a Bagel / Qwen2MoTDecoderLayer module across devices.
device_map = infer_auto_device_map(
    model,
    max_memory={i: max_mem_per_gpu for i in range(torch.cuda.device_count())},
    no_split_module_classes=["Bagel", "Qwen2MoTDecoderLayer"],
)
print(device_map)

# Modules that are pinned to one common device: the device assigned to the
# first entry of this list (the token embedding).
same_device_modules = [
    'language_model.model.embed_tokens',
    'time_embedder',
    'latent_pos_embed',
    'vae2llm',
    'llm2vae',
    'connector',
    'vit_pos_embed'
]

if torch.cuda.device_count() == 1:
    # Single GPU: every listed module gets an explicit entry; modules missing
    # from the planned map fall back to "cuda:0".
    first_device = device_map.get(same_device_modules[0], "cuda:0")
    for k in same_device_modules:
        device_map[k] = first_device if k in device_map else "cuda:0"
else:
    # Otherwise: only re-home modules that already have their own entry.
    first_device = device_map.get(same_device_modules[0])
    # If the anchor module has no entry of its own, leave the planned map
    # untouched rather than writing None into it. Compare against None
    # explicitly because device index 0 is a valid (falsy) value.
    if first_device is not None:
        for k in same_device_modules:
            if k in device_map:
                device_map[k] = first_device

# `model_path` is deliberately not re-assigned here: the checkpoint must be
# read from the same directory that was configured in the initialization cell.
model = load_checkpoint_and_dispatch(
    model,
    checkpoint=os.path.join(model_path, "ema.safetensors"),
    device_map=device_map,
    offload_buffers=True,
    dtype=torch.bfloat16,
    force_hooks=True,
    offload_folder="/tmp/offload"
)

model = model.eval()
print('Model loaded')


## Convert dense to sparse

In [None]:
import tqdm as tqdm
import pickle

def read_list_from_file(file_path):
    """Return the object pickled in the file at ``file_path``.

    The file is opened in binary mode and handed to ``pickle.load``; whatever
    object was serialized there is returned unchanged.

    NOTE: ``pickle.load`` can execute arbitrary code while deserializing, so
    only point this at score files you produced yourself or otherwise trust.
    """
    with open(file_path, 'rb') as handle:
        loaded = pickle.load(handle)
    return loaded

num_hidden_layers = llm_config.num_hidden_layers
intermediate_size = llm_config.intermediate_size

# Real importance scores are loaded from pickle files; only load files you trust.
# scores_und = torch.tensor(read_list_from_file(und)) # fill in your local score path
# scores_gen = torch.tensor(read_list_from_file(gen)) # fill in your local score path

# random scores for test: one row of per-neuron scores for each decoder layer.
# NOTE: this runs before the seeding cell further down, so the random scores
# differ between kernel restarts.
scores_und = torch.randn(num_hidden_layers, intermediate_size)
scores_gen = torch.randn(num_hidden_layers, intermediate_size)

mode = "gen"  # one of: "und", "gen", "und_gen"
idx_und = None
idx_gen = None

# Fail loudly on a typo instead of silently leaving every layer dense.
if mode not in ("und", "gen", "und_gen"):
    raise ValueError(f"mode must be 'und', 'gen' or 'und_gen', got {mode!r}")

for i, layer in enumerate(model.language_model.model.layers):

    # The first and the last decoder layer are left dense.
    if i in (0, num_hidden_layers - 1):
        continue

    # Two independent checks (not if/elif) so that "und_gen" converts BOTH
    # branches; with an elif the "gen" conversion was skipped for "und_gen".
    if "und" in mode:
        layer.convert_dense_to_sparse_moe_dual(
            mode="und", importance_scores=scores_und[i], shared_ratio=share_ratio
        )

    if "gen" in mode:
        layer.convert_dense_to_sparse_moe_dual(
            mode="gen", importance_scores=scores_gen[i], shared_ratio=share_ratio
        )

## Inferencer Preparing 

In [None]:
from inferencer import InterleaveInferencer

# Bundle everything prepared in the earlier cells -- the (converted) model, the
# VAE, the tokenizer with its new special-token ids, and both image
# transforms -- into a single callable used by the forward example below.
inferencer = InterleaveInferencer(
    model=model, 
    vae_model=vae_model, 
    tokenizer=tokenizer, 
    vae_transform=vae_transform, 
    vit_transform=vit_transform, 
    new_token_ids=new_token_ids
)

In [None]:
import random
import numpy as np

# Fix one seed for every random-number source used here (torch CPU/CUDA,
# NumPy, Python's `random`) and put cuDNN into its deterministic,
# non-benchmarking configuration.
seed = 42

torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

if torch.cuda.is_available():
    # Seed the current device and all devices.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

## Forward Example

In [None]:
# Keyword arguments forwarded unchanged to the inferencer call below.
inference_hyper = {
    "max_think_token_n": 1000,
    "do_sample": False,
    "cfg_text_scale": 4.0,
    "cfg_img_scale": 1.0,
    "cfg_interval": [0.4, 1.0],
    "timestep_shift": 3.0,
    "num_timesteps": 50,
    "cfg_renorm_min": 0.0,
    "cfg_renorm_type": "global",
}

# Text prompt describing the image to generate.
prompt = 'A realistic wooden bench sits on a flat surface. The bench is crafted from dark oak, with a smooth, polished texture that reflects light subtly. Its seat is long and rectangular, supported by four sturdy legs, each positioned evenly at the corners. The backrest is slightly curved, consisting of vertical slats, evenly spaced and securely attached. The wood grain is visible throughout, with natural variations in tone and pattern. The edges of the bench are neatly rounded, giving it a refined appearance. The overall structure is solid and stable, emphasizing durability in its photographic realism.'

print('-' * 10)

# Run the inferencer with thinking enabled, then show the image it returns
# under the 'image' key.
output_dict = inferencer(text=prompt, think=True, **inference_hyper)
img = output_dict['image']
display(img)