In [None]:
import os
import torch
import random
import numpy as np
import pandas as pd

from PIL import Image
from torch.utils.data import Dataset, DataLoader
from accelerate import init_empty_weights, load_checkpoint_and_dispatch

from data.data_utils import add_special_tokens, pil_img2rgb
from data.transforms import ImageTransform
from inferencer import InterleaveInferencer
from modeling.autoencoder import load_ae
from modeling.bagel import (
    BagelConfig, Bagel,
    Qwen2Config, Qwen2ForCausalLM,
    SiglipVisionConfig, SiglipVisionModel
)
from modeling.qwen2 import Qwen2Tokenizer


In [None]:
# Load story→image pairs, then take a reproducible random subset for a quick run.
N_SAMPLES = 100  # rows to fine-tune on; set to None to use the full parquet

df_raw = pd.read_parquet("dataset/bloom_vist_story2image.parquet")
print(df_raw.head())

# StoryImageDataset reads these two columns; fail here rather than mid-training.
missing_cols = {"text_prompt", "image_path"} - set(df_raw.columns)
if missing_cols:
    raise KeyError(f"parquet is missing required columns: {sorted(missing_cols)}")

# DataFrame.sample(n=...) raises ValueError when n exceeds the row count, so clamp it.
n_rows = len(df_raw) if N_SAMPLES is None else min(N_SAMPLES, len(df_raw))
df = df_raw.sample(n=n_rows, random_state=42).reset_index(drop=True)
print("Using rows:", len(df))


In [None]:
class StoryImageDataset(Dataset):
    """Map-style dataset yielding ``{"text": prompt, "image": PIL.Image}`` items.

    Each row of ``df`` must provide a ``text_prompt`` column (returned as the
    ``"text"`` entry) and an ``image_path`` column, which is resolved relative
    to ``image_root``.
    """

    def __init__(self, df, image_root):
        """
        Args:
            df: DataFrame with ``text_prompt`` and ``image_path`` columns. Rows
                are accessed positionally via ``iloc``.
            image_root: Directory that each ``image_path`` value is joined onto.
        """
        self.df = df
        self.image_root = image_root

    def __len__(self):
        """Number of rows in the backing DataFrame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load row ``idx`` and return its prompt and an RGB PIL image.

        Raises:
            FileNotFoundError: if the referenced image file does not exist.
        """
        row = self.df.iloc[idx]
        prompt = row["text_prompt"]
        image_path = os.path.join(self.image_root, row["image_path"])

        # Give pil_img2rgb the image in its original mode. Converting to RGB
        # first would drop any alpha channel before pil_img2rgb could act on it.
        # The context manager closes the underlying file handle deterministically.
        with Image.open(image_path) as raw:
            image = pil_img2rgb(raw)
            # Guarantee RGB output regardless of what pil_img2rgb returns.
            if image.mode != "RGB":
                image = image.convert("RGB")
            # If pil_img2rgb handed back the very same object, detach it from the
            # file before the `with` block closes (closed images are unusable).
            if image is raw:
                image = raw.copy()

        return {
            "text": prompt,
            "image": image,
        }


In [None]:
# Run-wide paths and settings.
MODEL_PATH = "models/BAGEL-7B-MoT"
CHECKPOINT = os.path.join(MODEL_PATH, "ema.safetensors")
OUTPUT_DIR = "checkpoints/bagel_finetune"
SEED = 42

# Seed every RNG the notebook may touch. The DataLoader's shuffle=True draws from
# torch's global generator, so an unseeded run reorders batches on every restart.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

os.makedirs(OUTPUT_DIR, exist_ok=True)


In [None]:
# Vision-encoder config: turn off its `rope` flag and use one fewer hidden
# layer than vit_config.json specifies.
vit_config = SiglipVisionConfig.from_json_file(os.path.join(MODEL_PATH, "vit_config.json"))
vit_config.rope = False
vit_config.num_hidden_layers = vit_config.num_hidden_layers - 1

# Language-model config with the overrides this checkpoint is built with:
# qk_norm enabled, untied word embeddings, and the MoT decoder-layer class.
llm_config = Qwen2Config.from_json_file(os.path.join(MODEL_PATH, "llm_config.json"))
llm_config.qk_norm = True
llm_config.tie_word_embeddings = False
llm_config.layer_module = "Qwen2MoTDecoderLayer"

# Autoencoder weights and config; vae_config feeds BagelConfig in the next cell.
# NOTE(review): vae_model is not moved to the GPU or referenced by the training
# loop in this notebook — confirm whether it should be.
vae_model, vae_config = load_ae(os.path.join(MODEL_PATH, "ae.safetensors"))


In [None]:
# Assemble the BAGEL config: visual generation enabled, visual understanding disabled.
config = BagelConfig(
    visual_gen=True,
    visual_und=False,
    llm_config=llm_config,
    vit_config=vit_config,
    vae_config=vae_config,
)

# Build the module tree on the meta device (no real weight storage is allocated);
# the actual tensors are filled in from the checkpoint below.
with init_empty_weights():
    llm = Qwen2ForCausalLM(llm_config)
    vit = SiglipVisionModel(vit_config)
    # NOTE(review): visual_und=False above — confirm whether Bagel uses the ViT at
    # all in this configuration, or whether constructing it here is unnecessary.
    model = Bagel(llm, vit, config)

# Load weights from ema.safetensors and place the whole model ("" key) on the
# current CUDA device in bfloat16, then switch to training mode.
# NOTE(review): the weights (and therefore the AdamW state built later) stay in
# bfloat16; with lr=1e-5 many per-step updates may round away. Consider fp32
# master weights for the trainable params plus torch.autocast for the forward.
model = load_checkpoint_and_dispatch(
    model,
    checkpoint=CHECKPOINT,
    device_map={"": "cuda"},
    dtype=torch.bfloat16,
).train()


In [None]:
# Freeze the whole model, leaving trainable only parameters whose qualified name
# contains both "language_model" and "attn".
for param_name, tensor in model.named_parameters():
    tensor.requires_grad = ("language_model" in param_name) and ("attn" in param_name)

trainable = sum(t.numel() for t in model.parameters() if t.requires_grad)
print("Trainable params:", trainable)


In [None]:
tokenizer = Qwen2Tokenizer.from_pretrained(MODEL_PATH)
tokenizer, new_token_ids, _ = add_special_tokens(tokenizer)

# NOTE(review): these transforms are built but never applied to dataset images
# or passed to the model in this notebook — confirm where preprocessing happens.
vae_transform = ImageTransform(1024, 512, 16)
vit_transform = ImageTransform(980, 224, 14)


def collate_story_batch(samples):
    """Collate dataset items into ``{"text": [str, ...], "image": [PIL.Image, ...]}``.

    torch's default collate_fn raises TypeError on PIL images, so every field is
    kept as a plain Python list. Strings still end up as a list of str, exactly as
    default collation would produce.
    """
    if not samples:
        return {}
    return {key: [sample[key] for sample in samples] for key in samples[0]}


dataset = StoryImageDataset(df, image_root="dataset/images")
loader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=collate_story_batch)


In [None]:
# AdamW over only the parameters the freezing cell left with requires_grad=True.
# NOTE(review): the model was loaded with dtype=torch.bfloat16, so these params
# and their optimizer state are bf16 — small lr updates may be lost to rounding.
optimizer = torch.optim.AdamW(
    [p for p in model.parameters() if p.requires_grad],
    lr=1e-5,
)


In [None]:
EPOCHS = 3
SAVE_EVERY = 20  # optimizer steps between checkpoints
LOG_EVERY = 5    # optimizer steps between loss prints

# NOTE(review): confirm that Bagel.forward accepts text=/image=/return_loss= and
# returns a dict with a "loss" entry. The tokenizer and transforms built earlier
# are not passed here, so this only works if preprocessing happens in the model.
step = 0
for epoch in range(EPOCHS):
    for batch in loader:
        step += 1
        optimizer.zero_grad()

        result = model(
            text=batch["text"],
            image=batch["image"],
            return_loss=True,
        )

        loss = result["loss"]
        loss.backward()
        optimizer.step()

        if step % LOG_EVERY == 0:
            print(f"Epoch {epoch} | Step {step} | Loss {loss.item():.4f}")

        if step % SAVE_EVERY == 0:
            ckpt_path = os.path.join(OUTPUT_DIR, f"step_{step}.pt")
            # Model + optimizer state plus position counters so a run can resume.
            torch.save(
                {
                    "model": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "step": step,
                    "epoch": epoch,
                },
                ckpt_path,
            )
            print("💾 Saved checkpoint:", ckpt_path)


In [None]:
# Persist the final fine-tuned weights (model state only, no optimizer state).
torch.save(model.state_dict(), os.path.join(OUTPUT_DIR, "final_model.pt"))
print("✅ Training finished")
