In [None]:
import os
import torch
import random
import numpy as np
from PIL import Image

from accelerate import (
    init_empty_weights,
    infer_auto_device_map,
    load_checkpoint_and_dispatch,
)

from data.data_utils import add_special_tokens
from data.transforms import ImageTransform
from inferencer import InterleaveInferencer
from modeling.autoencoder import load_ae
from modeling.bagel import (
    BagelConfig, Bagel,
    Qwen2Config, Qwen2ForCausalLM,
    SiglipVisionConfig, SiglipVisionModel
)
from modeling.qwen2 import Qwen2Tokenizer


In [None]:
# Location of the BAGEL checkpoint directory. CHANGE THIS to your local path
# (e.g. "models/BAGEL-7B-MoT" for a local checkout), or set the
# BAGEL_MODEL_PATH environment variable to override it without editing the notebook.
MODEL_PATH = os.environ.get("BAGEL_MODEL_PATH", "/content/models/BAGEL-7B-MoT")  # Colab default

CHECKPOINT = os.path.join(MODEL_PATH, "ema.safetensors")

# Fail fast, naming every missing path, before any expensive loading starts.
# These are the directory plus the weight/config files the later cells load from it.
# (Explicit raise instead of `assert`, which is skipped when Python runs with -O.)
_required_paths = [
    MODEL_PATH,
    CHECKPOINT,
    os.path.join(MODEL_PATH, "llm_config.json"),
    os.path.join(MODEL_PATH, "vit_config.json"),
    os.path.join(MODEL_PATH, "ae.safetensors"),
]
_missing_paths = [p for p in _required_paths if not os.path.exists(p)]
if _missing_paths:
    raise FileNotFoundError(f"Model files not found: {_missing_paths}")

# Inference only: turn autograd off globally for this kernel session.
# The trailing ';' keeps Jupyter from echoing the returned grad-mode object.
torch.set_grad_enabled(False);


In [None]:
# Build the component configs from MODEL_PATH.
# The LLM and ViT are only *configured* here; their weights come from CHECKPOINT
# in the dispatch cell. The VAE, by contrast, is constructed here with
# `local_path` pointing at ae.safetensors, so its weights are loaded right away.
llm_config_path = os.path.join(MODEL_PATH, "llm_config.json")
vit_config_path = os.path.join(MODEL_PATH, "vit_config.json")
vae_weights_path = os.path.join(MODEL_PATH, "ae.safetensors")

# LLM overrides: qk_norm on, tie_word_embeddings off, MoT decoder layer class.
llm_config = Qwen2Config.from_json_file(llm_config_path)
llm_config.qk_norm = True
llm_config.tie_word_embeddings = False
llm_config.layer_module = "Qwen2MoTDecoderLayer"

# ViT overrides: rope off, and one hidden layer fewer than the JSON specifies.
# NOTE(review): presumably this matches the layer count stored in the checkpoint — confirm.
vit_config = SiglipVisionConfig.from_json_file(vit_config_path)
vit_config.rope = False
vit_config.num_hidden_layers = vit_config.num_hidden_layers - 1

vae_model, vae_config = load_ae(local_path=vae_weights_path)


In [None]:
# Top-level BAGEL options (everything except the three sub-configs).
# NOTE(review): visual_und=False while a SiglipVisionModel is still built and
# passed to Bagel below; whether Bagel keeps/uses it in this mode depends on
# modeling.bagel — confirm this text-to-image-only setup is intended.
bagel_options = dict(
    visual_gen=True,
    visual_und=False,
    vit_max_num_patch_per_side=70,
    connector_act="gelu_pytorch_tanh",
    latent_patch_size=2,
    max_latent_size=64,
)
config = BagelConfig(
    llm_config=llm_config,
    vit_config=vit_config,
    vae_config=vae_config,
    **bagel_options,
)

# init_empty_weights() creates parameters without allocating real storage,
# so building the full architecture here costs almost no memory; the weights
# are filled in later by load_checkpoint_and_dispatch.
with init_empty_weights():
    llm = Qwen2ForCausalLM(llm_config)
    vit = SiglipVisionModel(vit_config)
    model = Bagel(llm, vit, config)

print("✅ BAGEL architecture created (empty weights)")


In [None]:
# Decide where each submodule lives, then stream the bfloat16 checkpoint into the
# empty model built in the previous cell.
#
# `dtype=torch.bfloat16` makes accelerate's size estimate match the dtype the
# weights are actually loaded in below. Without it, the empty parameters are sized
# in their current dtype (float32 unless the default was changed), about twice as
# large, which pushes layers off the GPU unnecessarily.
device_map = infer_auto_device_map(
    model,
    max_memory={0: "24GiB"},  # adjust if needed; add e.g. "cpu": "32GiB" to offload to RAM instead of disk
    no_split_module_classes=["Bagel", "Qwen2MoTDecoderLayer"],
    dtype=torch.bfloat16,
)

# Only a GPU budget is given above, so anything that does not fit is placed on
# "disk"; accelerate refuses to dispatch disk-placed modules without an offload
# folder. Disk offload works but is slow — prefer raising the budgets if possible.
OFFLOAD_DIR = "/tmp/bagel_offload"

model = load_checkpoint_and_dispatch(
    model,
    checkpoint=CHECKPOINT,
    device_map=device_map,
    offload_folder=OFFLOAD_DIR,
    offload_buffers=True,  # offload buffers of offloaded layers too, not just parameters
    dtype=torch.bfloat16,
).eval()

print("✅ Weights loaded")


In [None]:
# Tokenizer with BAGEL's added special tokens; `new_token_ids` is handed to the
# inferencer below.
tokenizer = Qwen2Tokenizer.from_pretrained(MODEL_PATH)
tokenizer, new_token_ids, _ = add_special_tokens(tokenizer)

# Image preprocessing for the VAE and ViT inputs.
# NOTE(review): arguments are presumably (max_size, min_size, stride); 16 vs 14
# presumably matches the latent patch grid vs. the ViT patch size — confirm.
vae_transform = ImageTransform(1024, 512, 16)
vit_transform = ImageTransform(980, 224, 14)

print("✅ Tokenizer & transforms ready")

# generate_image() calls this global `inferencer`, so it must exist before any
# image is generated. It bundles the dispatched model, the VAE, the tokenizer,
# both transforms and the special-token ids.
inferencer = InterleaveInferencer(
    model=model,
    vae_model=vae_model,
    tokenizer=tokenizer,
    vae_transform=vae_transform,
    vit_transform=vit_transform,
    new_token_ids=new_token_ids,
)

print("✅ Inferencer ready")


In [None]:
def set_seed(seed):
    """Seed Python's, NumPy's and PyTorch's RNGs (CPU and all CUDA devices).

    Args:
        seed: Integer seed. ``None`` or any value <= 0 leaves the RNGs untouched,
            so results are not reproducible in that case.
    """
    if seed is None or seed <= 0:
        return
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def generate_image(
    prompt,
    seed=42,
    *,
    think=False,
    image_shapes=(1024, 1024),
    cfg_text_scale=4.0,
    cfg_interval=(0.4, 1.0),
    timestep_shift=3.0,
    num_timesteps=50,
    cfg_renorm_min=0.0,
    cfg_renorm_type="global",
    **inference_kwargs,
):
    """Generate one image from a text prompt using the global ``inferencer``.

    The keyword defaults reproduce this notebook's original settings; override
    them per call instead of editing the function.

    Args:
        prompt: Text prompt.
        seed: Passed to ``set_seed``; ``None`` or <= 0 means unseeded.
        think: Whether to let the model "think" (CoT) before generating; off by default.
        image_shapes: Output size forwarded to the inferencer. Use e.g. (768, 768)
            on low-VRAM GPUs.
        cfg_text_scale, cfg_interval, timestep_shift, num_timesteps,
        cfg_renorm_min, cfg_renorm_type: Sampling/guidance settings forwarded
            unchanged (``cfg_interval`` is passed as a list).
        **inference_kwargs: Any further options forwarded to the inferencer as-is.

    Returns:
        The ``"image"`` entry of the inferencer's result.
    """
    set_seed(seed)

    result = inferencer(
        text=prompt,
        think=think,
        image_shapes=image_shapes,
        cfg_text_scale=cfg_text_scale,
        cfg_interval=list(cfg_interval),  # tuple default avoids a shared mutable default
        timestep_shift=timestep_shift,
        num_timesteps=num_timesteps,
        cfg_renorm_min=cfg_renorm_min,
        cfg_renorm_type=cfg_renorm_type,
        **inference_kwargs,
    )

    return result["image"]


In [None]:
# Example run. seed=123 (> 0) seeds the RNGs, so reruns start from the same random state.
prompt = "A futuristic cyberpunk city at night, neon lights, flying cars, cinematic lighting"
image = generate_image(prompt=prompt, seed=123)

# Bare last expression: Jupyter renders the returned image inline via its rich display.
image
