In [1]:
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0

In [1]:
# Re-import edited project modules (data/, modeling/) before each cell runs,
# so edits to the repo's .py files take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
from copy import deepcopy
from typing import (
    Any,
    AsyncIterable,
    Callable,
    Dict,
    Generator,
    List,
    NamedTuple,
    Optional,
    Tuple,
    Union,
)
import requests
from io import BytesIO

from PIL import Image
import torch
from accelerate import infer_auto_device_map, load_checkpoint_and_dispatch, init_empty_weights

from data.transforms import ImageTransform
from data.data_utils import pil_img2rgb, add_special_tokens
from modeling.bagel import (
    BagelConfig, Bagel, Qwen2Config, Qwen2ForCausalLM, SiglipVisionConfig, SiglipVisionModel
)
from modeling.qwen2 import Qwen2Tokenizer
from modeling.bagel.qwen2_navit import NaiveCache
from modeling.autoencoder import load_ae
from safetensors.torch import load_file

## Model Initialization

In [None]:
# Load the reference Qwen2.5-7B-Instruct model.
# NOTE(review): `qwen` is not used by any cell visible in this notebook section
# (the weight comparison below reads the safetensors shards directly). If no
# later cell needs the live model, this full 7B load can be skipped.
from transformers import AutoModelForCausalLM, AutoTokenizer

qwen_path = "/home/jake0360/BAGEL/Qwen2.5-7B-Instruct"
# device_map="auto" lets accelerate decide placement (possibly splitting across
# GPUs or offloading to CPU). Do not chain `.to("cuda")` afterwards:
# - accelerate warns against moving a dispatched model and raises if any modules
#   were offloaded;
# - on a multi-GPU node it would collapse the split onto a single device.
qwen = AutoModelForCausalLM.from_pretrained(
    qwen_path,
    torch_dtype="auto",
    device_map="auto",
).eval()

print("Loaded qwen2.5")

In [5]:
model_path = "/home/jake0360/projects/def-sreddy/checkpoints/BAGEL-7B-MoT"  # Download from https://huggingface.co/ByteDance-Seed/BAGEL-7B-MoT

# LLM config preparing
llm_config = Qwen2Config.from_json_file(os.path.join(model_path, "llm_config.json"))
llm_config.qk_norm = True
llm_config.tie_word_embeddings = False
llm_config.layer_module = "Qwen2MoTDecoderLayer"

# ViT config preparing
vit_config = SiglipVisionConfig.from_json_file(os.path.join(model_path, "vit_config.json"))
vit_config.rope = False
vit_config.num_hidden_layers = vit_config.num_hidden_layers - 1

# VAE loading
vae_model, vae_config = load_ae(local_path=os.path.join(model_path, "ae.safetensors"))

# Bagel config preparing
config = BagelConfig(
    visual_gen=True,
    visual_und=True,
    llm_config=llm_config, 
    vit_config=vit_config,
    vae_config=vae_config,
    vit_max_num_patch_per_side=70,
    connector_act='gelu_pytorch_tanh',
    latent_patch_size=2,
    max_latent_size=64,
)

with init_empty_weights():
    language_model = Qwen2ForCausalLM(llm_config)
    vit_model      = SiglipVisionModel(vit_config)
    model          = Bagel(language_model, vit_model, config)
    model.vit_model.vision_model.embeddings.convert_conv2d_to_linear(vit_config, meta=True)

# Tokenizer Preparing
tokenizer = Qwen2Tokenizer.from_pretrained(model_path)
tokenizer, new_token_ids, _ = add_special_tokens(tokenizer)

# Image Transform Preparing
vae_transform = ImageTransform(1024, 512, 16)
vit_transform = ImageTransform(980, 224, 14)

## Compare Weights: BAGEL language model vs. Qwen2.5-7B-Instruct

In [None]:
import torch
import torch.nn.functional as F
import json, os
from safetensors import safe_open
from tqdm import tqdm
import re
import io
from PIL import Image, ImageDraw, ImageFont

import contextlib, sys

# Checkpoint locations:
# - BAGEL's EMA weights, a single safetensors file;
# - the sharded Qwen2.5 release, described by its safetensors index.
bagel_path = "/home/jake0360/projects/def-sreddy/checkpoints/BAGEL-7B-MoT/ema.safetensors"
qwen_dir   = "/home/jake0360/BAGEL/Qwen2.5-7B-Instruct"
index_file = os.path.join(qwen_dir, "model.safetensors.index.json")

# Do the numeric comparisons on the GPU when one is available.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# "weight_map" maps each parameter name to the shard file that stores it.
with open(index_file, "r") as f:
    index = json.load(f)["weight_map"]

# Resolve every shard file name to a full path inside qwen_dir.
qwen_shards = dict(
    (param_name, os.path.join(qwen_dir, shard_name))
    for param_name, shard_name in index.items()
)
print(
    f"Indexed {len(qwen_shards)} Qwen2.5 tensors across "
    f"{len(set(qwen_shards.values()))} shards."
)

# compare weights
def compare_tensors(t1, t2):
    """Measure how far two same-shaped tensors are from each other.

    Both inputs are upcast to float32 and moved to the module-level `device`
    before any arithmetic.

    Returns:
        (mean_abs_diff, cosine): the mean of |t1 - t2| over all elements, and
        the cosine similarity of the two tensors viewed as flat vectors. Both
        are Python floats.
    """
    lhs = t1.to(device, dtype=torch.float32)
    rhs = t2.to(device, dtype=torch.float32)

    mean_abs = (lhs - rhs).abs().mean().item()
    cos_sim = F.cosine_similarity(lhs.flatten(), rhs.flatten(), dim=0).item()

    # Drop the float32 device copies, then hand cached blocks back to the CUDA
    # allocator so large tensors don't accumulate across calls.
    del lhs, rhs
    torch.cuda.empty_cache()
    return mean_abs, cos_sim

# Per-group list of (mean |Δ|, cosine) pairs. The group is the first three
# dotted components of the Qwen parameter name, e.g. "model.layers.0".
layer_diffs = {}
# BAGEL keys that were NOT compared, mapped to a short reason. This makes gaps
# in the summary visible instead of silently dropping them.
skipped = {}

LM_PREFIX = "language_model."

with safe_open(bagel_path, framework="pt") as f_bagel, contextlib.ExitStack() as shard_stack:
    # Qwen shard handles, opened lazily and at most once each (instead of once
    # per tensor). All of them are closed when the ExitStack exits.
    open_shards = {}
    bagel_keys = [k for k in f_bagel.keys() if k.startswith(LM_PREFIX)]
    for k in tqdm(bagel_keys, desc="Comparing BAGEL vs Qwen2.5"):
        # Strip only the leading prefix; str.replace would also rewrite any
        # later occurrence of the substring.
        layer_name = k[len(LM_PREFIX):]
        if layer_name not in qwen_shards:
            # Presumably BAGEL-only parameters with no Qwen2.5 counterpart.
            skipped[k] = "not in Qwen2.5 index"
            continue
        qwen_file = qwen_shards[layer_name]
        try:
            bagel_w = f_bagel.get_tensor(k)
            f_q = open_shards.get(qwen_file)
            if f_q is None:
                f_q = shard_stack.enter_context(safe_open(qwen_file, framework="pt"))
                open_shards[qwen_file] = f_q
            ref_w = f_q.get_tensor(layer_name)
        except Exception as exc:
            # Best effort: keep going, but remember why this key was dropped.
            skipped[k] = f"load error: {type(exc).__name__}: {exc}"
            continue

        if bagel_w.shape != ref_w.shape:
            skipped[k] = f"shape mismatch: {tuple(bagel_w.shape)} vs {tuple(ref_w.shape)}"
            continue

        md, cs = compare_tensors(bagel_w, ref_w)
        group = ".".join(layer_name.split(".")[:3])
        layer_diffs.setdefault(group, []).append((md, cs))

# Report coverage so silently missing groups can't be mistaken for "identical".
n_compared = sum(len(pairs) for pairs in layer_diffs.values())
print(f"Compared {n_compared} tensors in {len(layer_diffs)} groups; skipped {len(skipped)} BAGEL keys.")
skip_counts = {}
for reason in skipped.values():
    category = reason.split(":", 1)[0]
    skip_counts[category] = skip_counts.get(category, 0) + 1
for category, count in sorted(skip_counts.items()):
    print(f"  skipped ({category}): {count}")


def layer_sort_key(name):
    """Natural-sort key for parameter-group names.

    With this key "model.layers.10" sorts after "model.layers.2". re.split with
    a capturing group returns non-digit runs at even positions and digit runs
    at odd positions. Converting only the odd positions to int means two keys
    never compare an int against a str.
    """
    parts = re.split(r"(\d+)", name)
    return [int(part) if i % 2 else part for i, part in enumerate(parts)]


output_path = "layer_diff_summary.txt"
BAR_WIDTH = 40

with open(output_path, "w", encoding="utf-8") as f:
    f.write("=== Layer Δ Summary ===\n")

    for mod, vals in sorted(layer_diffs.items(), key=lambda item: layer_sort_key(item[0])):
        # Unweighted average over the tensors in this group.
        md = sum(v[0] for v in vals) / len(vals)
        cs = sum(v[1] for v in vals) / len(vals)
        # Filled cells grow with divergence (1 - cos). Clamp the count so the
        # bar is always BAR_WIDTH wide, even when float error pushes cos above
        # 1 or the tensors are anti-correlated (cos < 0).
        filled = min(max(round((1 - cs) * BAR_WIDTH), 0), BAR_WIDTH)
        bar = "█" * filled + "░" * (BAR_WIDTH - filled)
        f.write(f"{mod:25s} | Δmean={md:.3e} | cos={cs:.4f} | {bar}\n")

    f.write("\nDone — compared BAGEL vs Qwen2.5-7B-Instruct.\n")

print(f"Output: {output_path}")

In [None]:
from safetensors import safe_open

bagel_path = "/home/jake0360/projects/def-sreddy/checkpoints/BAGEL-7B-MoT/ema.safetensors"

# Collect the tensor names stored in the BAGEL checkpoint. Only the key list is
# read here; no tensors are fetched.
with safe_open(bagel_path, framework="pt") as f:
    all_keys = list(f.keys())

# Print every name that does not contain "moe" (case-insensitive).
# NOTE(review): presumably these are the shared, non-expert parameters.
for key_name in all_keys:
    if "moe" in key_name.lower():
        continue
    print(key_name)