# FastVLM Training with Ouro-1.4B (LoopLM)

Train FastVLM using **ByteDance Ouro-1.4B** - a Looped Language Model.

- **LLM Backbone**: Ouro-1.4B (LoopLM architecture)
- **Vision Encoder**: FastViTHD (MobileCLIP)
- **Dataset**: 5CD-AI/Viet-multimodal-open-r1-8k-verified

## Ouro Architecture
- **Hidden Size**: 2048
- **Layers**: 24 (x4 recurrent steps)
- **Effective capacity**: performance comparable to ~4-12B-parameter models

## 1. Install Dependencies

In [None]:
# Install required packages
# IMPORTANT: Ouro requires transformers < 4.56.0
!pip install -q transformers==4.54.1
!pip install -q torch>=2.1.0 torchvision>=0.16.0
!pip install -q accelerate>=0.26.0 peft>=0.10.0
!pip install -q bitsandbytes>=0.43.0
!pip install -q datasets pillow einops timm>=0.9.0
!pip install -q sentencepiece safetensors
!pip install -q huggingface_hub

In [None]:
# Verify transformers version
import re

import transformers

print(f"Transformers version: {transformers.__version__}")


def _release_tuple(version_string):
    """Return the leading ``major.minor[.patch]`` of a version string as ints.

    Versions must be compared numerically: as plain strings "4.6.0" sorts
    after "4.56.0" and "4.100.0" sorts before it.
    """
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version_string)
    if match is None:
        raise ValueError(f"Unrecognised version string: {version_string!r}")
    return tuple(int(part) for part in match.groups(default="0"))


# Raise explicitly instead of using `assert`, which is stripped under `python -O`.
if _release_tuple(transformers.__version__) >= (4, 56, 0):
    raise RuntimeError("Ouro requires transformers < 4.56.0!")

## 2. Configuration

In [None]:
import os
import json
import torch
from pathlib import Path

# ============================================
# CONFIGURATION - OURO-1.4B
# ============================================
# Central settings dict: the later cells read their model, dataset, training
# and upload settings from here.
CONFIG = {
    # Model - OURO
    "llm_model": "ByteDance/Ouro-1.4B",
    # NOTE(review): in the visible cells this name is only written into the
    # exported config.json; the vision cell itself loads the timm model
    # "fastvit_mci2.apple_mclip" -- confirm both refer to the same encoder.
    "vision_tower": "mobileclip_l_384",
    "mm_hidden_size": 3072,       # Vision feature width; the vision cell overwrites it if the tower outputs a different size
    "llm_hidden_size": 2048,      # Ouro hidden size
    
    # Dataset
    "dataset_name": "5CD-AI/Viet-multimodal-open-r1-8k-verified",
    "image_column": "image",
    "question_column": "vi_problem",
    "answer_column": "vi_solution",
    
    # ============================================
    # TRAINING MODE: "steps" or "epochs"
    # ============================================
    # Any value other than "steps" selects the epoch-based branch.
    "training_mode": "steps",  # <-- CHANGE THIS: "steps" or "epochs"
    
    # Step-based settings (used if training_mode="steps")
    "max_steps": 100,              # Total training steps
    "warmup_steps": 10,            # Warmup steps
    "save_steps": 50,              # Save every N steps
    
    # Epoch-based settings (used if training_mode="epochs")
    "num_train_epochs": 2,         # Number of epochs
    "warmup_ratio": 0.03,          # Warmup ratio (3% of total)
    "save_strategy": "epoch",      # Save every epoch
    
    # Common settings
    "output_dir": "./outputs/fastvlm-ouro-1.4b-vietnamese",
    "per_device_train_batch_size": 1,
    "gradient_accumulation_steps": 8,
    "learning_rate": 1e-5,
    "lr_scheduler_type": "cosine",
    "bf16": True,
    "model_max_length": 2048,      # Passed to the tokenizer and used as the dataset padding/truncation length
    "save_total_limit": 2,
    
    # LoRA
    "use_lora": True,              # NOTE(review): not read by the visible cells; LoRA is always applied -- confirm
    "lora_r": 8,
    "lora_alpha": 16,
    "lora_dropout": 0.05,
    
    # HuggingFace
    "hf_repo": "beyoru/Belle-VLM-Ouro",
    # Read from the environment; empty string when HF_TOKEN is not set.
    "hf_token": os.environ.get("HF_TOKEN", ""),
}

os.makedirs(CONFIG["output_dir"], exist_ok=True)

# Summarise the active configuration
banner = "=" * 50
print(banner)
print(f"TRAINING MODE: {CONFIG['training_mode'].upper()}")
print(banner)

# Keys that are never listed: the token (a secret) plus whichever settings
# belong to the training mode that is not selected.
if CONFIG["training_mode"] == "steps":
    hidden_keys = {"hf_token", "num_train_epochs", "warmup_ratio", "save_strategy"}
else:
    hidden_keys = {"hf_token", "max_steps", "warmup_steps", "save_steps"}

print("\nConfiguration (Ouro-1.4B):")
for name, value in CONFIG.items():
    if name not in hidden_keys:
        print(f"  {name}: {value}")

# Samples consumed per optimizer step
effective_batch = CONFIG["gradient_accumulation_steps"] * CONFIG["per_device_train_batch_size"]
print(f"\nEffective batch size: {effective_batch}")

if CONFIG["training_mode"] != "steps":
    print(f"Epochs: {CONFIG['num_train_epochs']}")
    print(f"Warmup: {CONFIG['warmup_ratio']*100:.0f}% of total")
    print("Save: every epoch")
else:
    print(f"Total steps: {CONFIG['max_steps']}")
    print(f"Warmup steps: {CONFIG['warmup_steps']}")
    print(f"Save every: {CONFIG['save_steps']} steps")

In [None]:
# Authenticate against the HuggingFace Hub when a token was configured
from huggingface_hub import login

if not CONFIG["hf_token"]:
    print("HF_TOKEN not set.")
else:
    login(token=CONFIG["hf_token"])
    print("Logged in to HuggingFace!")

## 3. Load Ouro Model

In [None]:
# Report whether CUDA is usable and, if so, which device 0 is
has_cuda = torch.cuda.is_available()
print(f"CUDA available: {has_cuda}")
if has_cuda:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    total_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f"Memory: {total_bytes / 1e9:.1f} GB")

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, BitsAndBytesConfig

print(f"Loading Ouro model: {CONFIG['llm_model']}")

# Tokenizer first: right-padded, capped at the configured context length
tokenizer = AutoTokenizer.from_pretrained(
    CONFIG["llm_model"],
    trust_remote_code=True,
    model_max_length=CONFIG["model_max_length"],
    padding_side="right",
)

# Reuse EOS as the padding token when the tokenizer defines none
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print(f"Tokenizer vocab size: {tokenizer.vocab_size}")

# 4-bit NF4 quantization with double quantization, computing in bfloat16
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Quantized Ouro backbone, placed automatically across available devices
model = AutoModelForCausalLM.from_pretrained(
    CONFIG["llm_model"],
    trust_remote_code=True,
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

print("Model loaded!")
for label, value in (
    ("Model type", model.config.model_type),
    ("Hidden size", model.config.hidden_size),
    ("Layers", model.config.num_hidden_layers),
    ("Total UT steps", getattr(model.config, 'total_ut_steps', 'N/A')),
):
    print(f"{label}: {value}")

In [None]:
# Skip - mm_projector will be created AFTER LoRA
# (If we create it here, it will be lost after get_peft_model wraps the model)
print("mm_projector will be created after LoRA setup...")

## 4. Setup LoRA

In [None]:
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# Make the quantized backbone trainable (with gradient checkpointing enabled)
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)

# Adapt both the attention projections and the MLP projections
attention_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]
mlp_modules = ["gate_proj", "up_proj", "down_proj"]

lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    bias="none",
    target_modules=attention_modules + mlp_modules,
    r=CONFIG["lora_r"],
    lora_alpha=CONFIG["lora_alpha"],
    lora_dropout=CONFIG["lora_dropout"],
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

# ============================================
# The projector is built only now, on the PEFT-wrapped model
# ============================================
import torch.nn as nn

# Two-layer MLP mapping vision features to the LLM embedding width
vision_dim = CONFIG["mm_hidden_size"]
llm_dim = CONFIG["llm_hidden_size"]
mm_projector = nn.Sequential(
    nn.Linear(vision_dim, llm_dim),
    nn.GELU(),
    nn.Linear(llm_dim, llm_dim),
)
mm_projector = mm_projector.to(model.device, dtype=torch.bfloat16)

# Register it on the wrapped model and mark all of its weights trainable
model.mm_projector = mm_projector
model.mm_projector.requires_grad_(True)

print(f"\nAdded mm_projector: {CONFIG['mm_hidden_size']} -> {CONFIG['llm_hidden_size']}")
proj_params = sum(weight.numel() for weight in mm_projector.parameters())
print(f"mm_projector parameters: {proj_params / 1e6:.2f}M (trainable)")

## 5. Load Vision Tower

In [None]:
import timm
from transformers import CLIPImageProcessor

# Load MobileCLIP vision tower
print("Loading MobileCLIP vision tower...")
vision_tower = timm.create_model(
    "fastvit_mci2.apple_mclip",
    pretrained=True,
    num_classes=0,
)
vision_tower.eval()
# The encoder is used as a fixed feature extractor. eval() alone does not
# freeze weights: once attached to `model` below, its parameters would still
# be reported as trainable and handed to the optimizer.
vision_tower.requires_grad_(False)
vision_tower = vision_tower.to(model.device, dtype=torch.bfloat16)

# Image processor
image_processor = CLIPImageProcessor(
    size={"shortest_edge": 384},
    crop_size={"height": 384, "width": 384},
    do_center_crop=True,
    do_normalize=True,
    image_mean=[0.48145466, 0.4578275, 0.40821073],
    image_std=[0.26862954, 0.26130258, 0.27577711],
)

# Probe the encoder with a dummy image to discover its real feature width
dummy_img = torch.randn(1, 3, 384, 384).to(model.device, dtype=torch.bfloat16)
with torch.no_grad():
    features = vision_tower.forward_features(dummy_img)
    if features.dim() == 4:
        # (B, C, H, W) feature map -> (B, H*W, C) token sequence
        features = features.flatten(2).transpose(1, 2)

print(f"Vision tower output: {features.shape}")
actual_hidden = features.shape[-1]
print(f"Actual hidden size: {actual_hidden}")

# Verify mm_projector matches
expected_hidden = CONFIG["mm_hidden_size"]
if actual_hidden != expected_hidden:
    print(f"WARNING: Config mm_hidden_size={expected_hidden} but vision tower outputs {actual_hidden}")
    print(f"Updating CONFIG and recreating mm_projector...")
    CONFIG["mm_hidden_size"] = actual_hidden

    # Recreate mm_projector with the detected input width
    model.mm_projector = nn.Sequential(
        nn.Linear(actual_hidden, CONFIG["llm_hidden_size"]),
        nn.GELU(),
        nn.Linear(CONFIG["llm_hidden_size"], CONFIG["llm_hidden_size"]),
    ).to(model.device, dtype=torch.bfloat16)

    for param in model.mm_projector.parameters():
        param.requires_grad = True
    print(f"Recreated mm_projector: {actual_hidden} -> {CONFIG['llm_hidden_size']}")

# Attach to model
model.vision_tower = vision_tower
model.image_processor = image_processor

print("Vision tower ready!")

## 6. Prepare Dataset

In [None]:
from datasets import load_dataset
from PIL import Image
from tqdm import tqdm

# Fetch the training split of the configured dataset and describe it
dataset_id = CONFIG["dataset_name"]
print(f"Loading dataset: {dataset_id}")
dataset = load_dataset(dataset_id, split="train")

num_samples = len(dataset)
print(f"Total samples: {num_samples}")
print(f"Columns: {dataset.column_names}")

In [None]:
# Create LLaVA format data
import json
import os

DATA_DIR = "./data"
IMAGE_FOLDER = os.path.join(DATA_DIR, "images")
os.makedirs(IMAGE_FOLDER, exist_ok=True)

# Column names come from CONFIG so that changing the dataset only requires
# editing the configuration cell.
image_col = CONFIG["image_column"]
question_col = CONFIG["question_column"]
answer_col = CONFIG["answer_column"]

llava_data = []
skipped = 0

for idx, sample in enumerate(tqdm(dataset, desc="Converting")):
    img = sample[image_col]
    question = sample[question_col]
    answer = sample[answer_col]

    # Skip unusable rows. Without this, a row with no decodable image would
    # still be listed with a filename that was never written, and training
    # would later fail when opening it.
    if not isinstance(img, Image.Image) or question is None or answer is None:
        skipped += 1
        continue

    # Save image (JPEG needs RGB)
    image_filename = f"{idx:06d}.jpg"
    image_path = os.path.join(IMAGE_FOLDER, image_filename)
    if img.mode != 'RGB':
        img = img.convert('RGB')
    img.save(image_path, 'JPEG', quality=95)

    # Conversation
    question = question.strip()
    answer = answer.strip()
    if len(answer) > 4096:
        # Cap very long solutions (by characters, not tokens)
        answer = answer[:4096] + "..."

    llava_data.append({
        "id": str(idx),
        "image": image_filename,
        "conversations": [
            {"from": "human", "value": f"<image>\n{question}"},
            {"from": "gpt", "value": answer}
        ]
    })

# Save
json_path = os.path.join(DATA_DIR, "train_data.json")
with open(json_path, 'w', encoding='utf-8') as f:
    json.dump(llava_data, f, ensure_ascii=False, indent=2)

print(f"Dataset converted: {len(llava_data)} samples")
if skipped:
    print(f"Skipped {skipped} samples with a missing image, question or answer")

## 7. Training

In [None]:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

IMAGE_TOKEN_INDEX = -200

class VLMDataset(Dataset):
    """Image + question/answer dataset yielding fixed-length training tensors.

    Each item is a dict with ``input_ids``, ``attention_mask`` and ``labels``
    (all padded/truncated to ``max_length``) plus ``images``, the transformed
    image tensor. Labels are ``-100`` everywhere except on the answer tokens.
    """

    def __init__(self, data, image_folder, tokenizer, image_processor, max_length=2048):
        """Store the records and build the image transform.

        Args:
            data: List of LLaVA-style records with ``image`` and
                ``conversations`` keys.
            image_folder: Directory that the ``image`` filenames are relative to.
            tokenizer: Tokenizer used to encode the prompt.
            image_processor: Stored on the instance; this class preprocesses
                images with its own torchvision transform instead.
            max_length: Padded/truncated token length of every sample.
        """
        self.data = data
        self.image_folder = image_folder
        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.max_length = max_length

        self.transform = transforms.Compose([
            transforms.Resize((384, 384)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.48145466, 0.4578275, 0.40821073],
                std=[0.26862954, 0.26130258, 0.27577711]
            ),
        ])

    def __len__(self):
        """Return the number of records."""
        return len(self.data)

    @staticmethod
    def _build_labels(input_ids, attention_mask, num_prompt_tokens):
        """Return a copy of ``input_ids`` with non-answer positions set to -100.

        Both the leading ``num_prompt_tokens`` positions and every padding
        position (``attention_mask == 0``) are ignored by the loss.
        """
        labels = input_ids.clone()
        labels[:num_prompt_tokens] = -100
        labels[attention_mask == 0] = -100
        return labels

    def __getitem__(self, idx):
        """Build the tensors for record ``idx``."""
        item = self.data[idx]

        # Load image; the context manager closes the underlying file promptly
        image_path = os.path.join(self.image_folder, item['image'])
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        image_tensor = self.transform(image)

        # Build conversation
        conv = item['conversations']
        question = conv[0]['value'].replace('<image>', '').strip()
        answer = conv[1]['value']

        # Keep the prompt prefix as its own string so the answer boundary is
        # exact even when the question itself contains the text "Assistant:".
        prefix = f"User: <image>\n{question}\nAssistant:"
        # End the answer with EOS so the model gets an explicit stop target;
        # padding is masked out below and therefore cannot serve as one.
        eos = self.tokenizer.eos_token or ""
        prompt = f"{prefix} {answer}{eos}"

        # Tokenize
        tokens = self.tokenizer(
            prompt,
            truncation=True,
            max_length=self.max_length,
            padding='max_length',
            return_tensors='pt'
        )

        input_ids = tokens['input_ids'].squeeze(0)
        attention_mask = tokens['attention_mask'].squeeze(0)

        # Labels: train on the answer only.
        # NOTE(review): assumes tokenizing the prefix alone yields the same
        # leading tokens as tokenizing the full prompt -- confirm for this
        # tokenizer.
        num_prompt_tokens = len(self.tokenizer(prefix)['input_ids'])
        labels = self._build_labels(input_ids, attention_mask, num_prompt_tokens)

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels,
            'images': image_tensor,
        }

# Instantiate the training dataset from the converted records
train_dataset = VLMDataset(
    data=llava_data,
    image_folder=IMAGE_FOLDER,
    tokenizer=tokenizer,
    image_processor=image_processor,
    max_length=CONFIG["model_max_length"],
)

num_train_samples = len(train_dataset)
print(f"Training dataset: {num_train_samples} samples")

In [None]:
from transformers import Trainer, TrainingArguments

# ============================================
# TRAINING ARGUMENTS - mode is taken from CONFIG
# ============================================
# Settings that are identical for step-based and epoch-based training
shared_training_kwargs = dict(
    output_dir=CONFIG["output_dir"],
    per_device_train_batch_size=CONFIG["per_device_train_batch_size"],
    gradient_accumulation_steps=CONFIG["gradient_accumulation_steps"],
    learning_rate=CONFIG["learning_rate"],
    lr_scheduler_type=CONFIG["lr_scheduler_type"],
    bf16=CONFIG["bf16"],
    logging_steps=10,
    save_total_limit=CONFIG["save_total_limit"],
    gradient_checkpointing=True,
    dataloader_num_workers=4,
    report_to="none",
    # Keep the custom 'images' column for the multimodal trainer
    remove_unused_columns=False,
)

if CONFIG["training_mode"] == "steps":
    # Fixed number of optimizer steps
    print("Mode: STEP-BASED training")
    training_args = TrainingArguments(
        max_steps=CONFIG["max_steps"],
        warmup_steps=CONFIG["warmup_steps"],
        save_steps=CONFIG["save_steps"],
        **shared_training_kwargs,
    )
    for setting in ("max_steps", "warmup_steps", "save_steps"):
        print(f"  {setting}: {CONFIG[setting]}")
else:
    # Fixed number of passes over the data
    print("Mode: EPOCH-BASED training")
    training_args = TrainingArguments(
        num_train_epochs=CONFIG["num_train_epochs"],
        warmup_ratio=CONFIG["warmup_ratio"],
        save_strategy=CONFIG["save_strategy"],
        **shared_training_kwargs,
    )
    for setting in ("num_train_epochs", "warmup_ratio", "save_strategy"):
        print(f"  {setting}: {CONFIG[setting]}")

# Batch collation for the multimodal trainer
def collate_fn(batch):
    """Stack the per-sample tensors of ``batch`` into one tensor per field.

    Args:
        batch: Sequence of dicts, each holding equally-shaped tensors under
            'input_ids', 'attention_mask', 'labels' and 'images'.

    Returns:
        Dict with the same four keys, each stacked along a new leading dim.
    """
    fields = ('input_ids', 'attention_mask', 'labels', 'images')
    return {
        field: torch.stack([sample[field] for sample in batch])
        for field in fields
    }

print("\nTraining arguments ready!")

In [None]:
# Custom Trainer for multimodal
class VLMTrainer(Trainer):
    """Trainer that prepends projected image features to the text embeddings."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute the language-modelling loss for an image+text batch.

        Args:
            model: Model exposing ``vision_tower``, ``mm_projector`` and
                ``get_input_embeddings()``.
            inputs: Batch dict with 'images', 'input_ids', 'labels' and
                'attention_mask'. 'images' is removed from the dict.
            return_outputs: When True, return ``(loss, outputs)``.
            **kwargs: Extra arguments passed by ``Trainer``; ignored.

        Returns:
            The loss, or ``(loss, outputs)`` when ``return_outputs`` is True.
        """
        images = inputs.pop('images').to(model.device, dtype=torch.bfloat16)
        input_ids = inputs['input_ids']
        labels = inputs['labels']
        attention_mask = inputs['attention_mask']

        # Trainer switches the whole model to train mode, vision tower
        # included. Force it back to eval so normalisation statistics and
        # dropout stay fixed: no_grad() alone does not prevent them changing.
        model.vision_tower.eval()

        # Encode images
        with torch.no_grad():
            image_features = model.vision_tower.forward_features(images)
            if image_features.dim() == 4:
                # (B, C, H, W) feature map -> (B, H*W, C) token sequence
                image_features = image_features.flatten(2).transpose(1, 2)

        # Project
        image_features = model.mm_projector(image_features)

        # Get text embeddings
        text_embeds = model.get_input_embeddings()(input_ids)

        # Concatenate image + text (image tokens first), in one common dtype
        inputs_embeds = torch.cat(
            [image_features.to(text_embeds.dtype), text_embeds], dim=1
        )

        # Image positions are always attended to
        image_mask = torch.ones(
            images.size(0), image_features.size(1),
            device=attention_mask.device, dtype=attention_mask.dtype
        )
        attention_mask = torch.cat([image_mask, attention_mask], dim=1)

        # Image positions never contribute to the loss
        image_labels = torch.full(
            (images.size(0), image_features.size(1)),
            -100,
            device=labels.device, dtype=labels.dtype
        )
        labels = torch.cat([image_labels, labels], dim=1)

        # Forward
        outputs = model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels,
        )

        return (outputs.loss, outputs) if return_outputs else outputs.loss

# Create trainer
trainer = VLMTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=collate_fn,
)

print("Trainer ready!")

In [None]:
# Announce the schedule, then run the training loop
if CONFIG["training_mode"] != "steps":
    print(f"Starting training for {CONFIG['num_train_epochs']} epochs...")
else:
    print(f"Starting training for {CONFIG['max_steps']} steps...")

trainer.train()
print("Training completed!")

In [None]:
# Persist the trained model via the trainer
print("Saving model...")
trainer.save_model(CONFIG["output_dir"])

# The projector is additionally written to its own file
mm_projector_path = os.path.join(CONFIG["output_dir"], "mm_projector.bin")

# Look for the projector on the model itself first, then on its base_model
for holder in (model, getattr(model, 'base_model', None)):
    if holder is not None and hasattr(holder, 'mm_projector'):
        mm_proj = holder.mm_projector
        break
else:
    raise ValueError("Cannot find mm_projector in model!")

torch.save(mm_proj.state_dict(), mm_projector_path)
size_mb = os.path.getsize(mm_projector_path) / 1024 / 1024
print(f"Saved mm_projector to {mm_projector_path}")
print(f"  Size: {size_mb:.2f} MB")

## 8. Merge and Save

In [None]:
from peft import PeftModel

OUTPUT_DIR = CONFIG["output_dir"]
MERGED_DIR = os.path.join(OUTPUT_DIR, "merged")
os.makedirs(MERGED_DIR, exist_ok=True)

# Reload the backbone unquantized on CPU so the adapter can be folded in
print("Loading base model for merging...")
base_model = AutoModelForCausalLM.from_pretrained(
    CONFIG["llm_model"],
    trust_remote_code=True,
    torch_dtype=torch.float16,
    device_map="cpu",
)

print("Loading LoRA adapter...")
merged_model = PeftModel.from_pretrained(base_model, OUTPUT_DIR)

print("Merging weights...")
merged_model = merged_model.merge_and_unload()

# Load mm_projector weights.
# weights_only=True restricts unpickling to tensors and plain containers,
# which is all a state_dict holds, so no arbitrary pickled code can run.
mm_projector_path = os.path.join(OUTPUT_DIR, "mm_projector.bin")
print(f"\nLoading mm_projector from {mm_projector_path}")
mm_projector_weights = torch.load(
    mm_projector_path, map_location="cpu", weights_only=True
)

print("mm_projector weights:")
for k, v in mm_projector_weights.items():
    print(f"  {k}: {v.shape}")

# Get state dict and add mm_projector
state_dict = merged_model.state_dict()

# Add mm_projector with correct key format
for k, v in mm_projector_weights.items():
    # Key format: "model.mm_projector.0.weight" etc
    new_key = f"model.mm_projector.{k}"
    # Match the float16 dtype the backbone was loaded in
    state_dict[new_key] = v.to(torch.float16)
    print(f"Added: {new_key}")

# Save
merged_model.save_pretrained(MERGED_DIR, state_dict=state_dict, safe_serialization=True)
tokenizer.save_pretrained(MERGED_DIR)

print(f"\nModel saved to: {MERGED_DIR}")

In [None]:
# Write the LLaVA-Ouro config.json into the merged directory
import json

# Start from the merged backbone's config and layer the multimodal fields on top
config_data = merged_model.config.to_dict()
config_data.update({
    "model_type": "llava_ouro",
    "architectures": ["LlavaOuroForCausalLM"],
    "mm_vision_tower": CONFIG["vision_tower"],
    "mm_hidden_size": CONFIG["mm_hidden_size"],
    "mm_projector_type": "mlp2x_gelu",
    # Maps the Auto classes to the custom code files
    "auto_map": {
        "AutoConfig": "configuration_llava_ouro.LlavaOuroConfig",
        "AutoModelForCausalLM": "modeling_llava_ouro.LlavaOuroForCausalLM"
    },
})

config_path = os.path.join(MERGED_DIR, "config.json")
with open(config_path, 'w') as f:
    f.write(json.dumps(config_data, indent=2))

print("Config saved!")

In [None]:
# Verify
from safetensors import safe_open

safetensor_path = os.path.join(MERGED_DIR, "model.safetensors")
print(f"Model size: {os.path.getsize(safetensor_path) / 1024 / 1024:.2f} MB")

with safe_open(safetensor_path, framework="pt") as f:
    mm_keys = [k for k in f.keys() if 'mm_projector' in k]
    if mm_keys:
        print("\nmm_projector found:")
        for k in mm_keys:
            print(f"  {k}: {f.get_tensor(k).shape}")
    else:
        # A verification step must not pass silently when the check fails
        print("\nWARNING: no mm_projector tensors found in the merged checkpoint!")

## 9. Upload to HuggingFace

In [None]:
# Create model card
if CONFIG["training_mode"] == "steps":
    training_info = f"Steps: {CONFIG['max_steps']}"
else:
    training_info = f"Epochs: {CONFIG['num_train_epochs']}"

# The projector dimensions are read from CONFIG rather than hard-coded:
# mm_hidden_size is replaced by the auto-detected vision feature width when
# the tower's output differs from the configured value.
model_card = f"""---
license: apache-2.0
language:
- vi
- en
tags:
- vision-language-model
- vlm
- ouro
- looplm
- fastvlm
- vietnamese
base_model: {CONFIG['llm_model']}
datasets:
- {CONFIG['dataset_name']}
---

# Belle-VLM-Ouro: Vietnamese Vision Language Model

Built on **ByteDance Ouro-1.4B** (Looped Language Model).

## Architecture
- **LLM**: Ouro-1.4B (LoopLM, 4 recurrent steps)
- **Vision**: FastViTHD (MobileCLIP)
- **Projector**: MLP {CONFIG['mm_hidden_size']} -> {CONFIG['llm_hidden_size']}

## Usage

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model = AutoModelForCausalLM.from_pretrained(
    "{CONFIG['hf_repo']}",
    trust_remote_code=True,
    torch_dtype=torch.float16,
    device_map="auto"
)
```

## Training
- Dataset: {CONFIG['dataset_name']}
- {training_info}
- LoRA: r={CONFIG['lora_r']}, alpha={CONFIG['lora_alpha']}
"""

# Explicit encoding so the file does not depend on the platform default
with open(os.path.join(MERGED_DIR, "README.md"), "w", encoding="utf-8") as f:
    f.write(model_card)

print("Model card created!")

In [None]:
# ============================================
# CREATE AND UPLOAD CUSTOM CODE FILES
# ============================================
# These files are REQUIRED for trust_remote_code=True

# 1. configuration_llava_ouro.py
# The string below is the full source of the custom config module. It is
# written unchanged to /tmp and later uploaded next to the model, where
# config.json's "auto_map" entry refers to it as
# "configuration_llava_ouro.LlavaOuroConfig".
config_code = '''# Configuration for LLaVA Ouro model
from transformers import PretrainedConfig

class LlavaOuroConfig(PretrainedConfig):
    """Configuration class for LLaVA Ouro model (LoopLM architecture)."""
    model_type = "llava_ouro"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=49152, hidden_size=2048, intermediate_size=5632,
        num_hidden_layers=24, num_attention_heads=16, num_key_value_heads=16,
        hidden_act="silu", max_position_embeddings=65536, initializer_range=0.02,
        rms_norm_eps=1e-6, use_cache=True, tie_word_embeddings=False,
        rope_theta=1000000.0, rope_scaling=None, use_sliding_window=False,
        sliding_window=4096, max_window_layers=28, layer_types=None, attention_dropout=0.0,
        total_ut_steps=4, early_exit_threshold=1.0,
        mm_vision_tower=None, mm_hidden_size=None, mm_projector_type="mlp2x_gelu",
        mm_vision_select_layer=-2, mm_vision_select_feature="patch",
        mm_patch_merge_type="flat", mm_use_im_start_end=False,
        mm_use_im_patch_token=False, image_aspect_ratio="pad", **kwargs
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.use_sliding_window = use_sliding_window
        self.sliding_window = sliding_window if use_sliding_window else None
        self.max_window_layers = max_window_layers
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_dropout = attention_dropout
        self.total_ut_steps = total_ut_steps
        self.early_exit_threshold = early_exit_threshold
        self.layer_types = layer_types or ["full_attention"] * num_hidden_layers
        self.mm_vision_tower = mm_vision_tower
        self.mm_hidden_size = mm_hidden_size
        self.mm_projector_type = mm_projector_type
        self.mm_vision_select_layer = mm_vision_select_layer
        self.mm_vision_select_feature = mm_vision_select_feature
        self.mm_patch_merge_type = mm_patch_merge_type
        self.mm_use_im_start_end = mm_use_im_start_end
        self.mm_use_im_patch_token = mm_use_im_patch_token
        self.image_aspect_ratio = image_aspect_ratio
        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
'''

# Materialise the module on disk; the upload cell reads it from this path.
with open("/tmp/configuration_llava_ouro.py", "w") as f:
    f.write(config_code)

print("Created configuration_llava_ouro.py")

In [None]:
# 2. modeling_llava_ouro.py (simplified for inference)
# The string below is the full source of the custom modeling module that is
# uploaded next to the weights. Fixes relative to the previous version:
#   * from_pretrained restores the fine-tuned (merged) LLM weights from the
#     checkpoint; it used to keep the stock ByteDance/Ouro-1.4B weights and
#     only load the projector.
#   * the projector is moved to the LLM's device/dtype after loading.
#   * weight or vision-tower loading failures raise instead of printing and
#     continuing with an unusable model.
#   * the model class inherits GenerationMixin explicitly, which is required
#     for generate() on transformers >= 4.50.
modeling_code = '''# LLaVA Ouro Model for HuggingFace
import os
from typing import List, Optional, Tuple, Union
from abc import ABC, abstractmethod
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModelForCausalLM, GenerationMixin, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from .configuration_llava_ouro import LlavaOuroConfig

IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200

def build_vision_projector(config):
    """Build the vision-to-LLM projector described by ``config``."""
    projector_type = getattr(config, "mm_projector_type", "mlp2x_gelu")
    mm_hidden_size = getattr(config, "mm_hidden_size", 3072)
    hidden_size = config.hidden_size
    if projector_type == "mlp2x_gelu":
        return nn.Sequential(nn.Linear(mm_hidden_size, hidden_size), nn.GELU(), nn.Linear(hidden_size, hidden_size))
    return nn.Linear(mm_hidden_size, hidden_size)

class MobileCLIPVisionTower(nn.Module):
    """Lazily loaded MobileCLIP image encoder returning patch features."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.is_loaded = False
        self.image_processor = None
        self.vision_tower = None
        self.hidden_size = getattr(config, "mm_hidden_size", 3072)
        try:
            self.image_size = int(getattr(config, "mm_vision_tower", "mobileclip_l_384").split("_")[-1])
        except (AttributeError, ValueError):
            # Name missing or without a trailing resolution: use the default.
            self.image_size = 384

    def load_model(self, **kwargs):
        """Create the timm encoder and its image processor (idempotent)."""
        if self.is_loaded:
            return
        try:
            import timm
            from transformers import CLIPImageProcessor
            self.vision_tower = timm.create_model("fastvit_mci2.apple_mclip", pretrained=True, num_classes=0)
            self.vision_tower.eval()
            self.image_processor = CLIPImageProcessor(size={"shortest_edge": self.image_size},
                crop_size={"height": self.image_size, "width": self.image_size}, do_normalize=True,
                image_mean=[0.48145466, 0.4578275, 0.40821073], image_std=[0.26862954, 0.26130258, 0.27577711])
        except Exception as e:
            # Do not mark the tower as loaded: forward() would otherwise fail
            # later with an unrelated error on a None encoder.
            raise RuntimeError(f"Could not load MobileCLIP: {e}") from e
        self.is_loaded = True
        print("MobileCLIP loaded!")

    def forward(self, images):
        """Return image features as a (batch, tokens, channels) tensor."""
        if not self.is_loaded:
            self.load_model()
        with torch.no_grad():
            features = self.vision_tower.forward_features(images) if hasattr(self.vision_tower, "forward_features") else self.vision_tower(images)
        if features.dim() == 4:
            # (B, C, H, W) feature map -> (B, H*W, C) token sequence
            features = features.flatten(2).transpose(1, 2)
        elif features.dim() == 2:
            features = features.unsqueeze(1)
        return features

    def to(self, *args, **kwargs):
        super().to(*args, **kwargs)
        if self.vision_tower is not None:
            self.vision_tower = self.vision_tower.to(*args, **kwargs)
        return self

class LlavaOuroForCausalLM(PreTrainedModel, GenerationMixin):
    """Ouro language model with a MobileCLIP vision tower and MLP projector."""
    config_class = LlavaOuroConfig
    _no_split_modules = ["OuroDecoderLayer"]

    def __init__(self, config: LlavaOuroConfig):
        super().__init__(config)
        self.vocab_size = config.vocab_size
        # Filled in by from_pretrained with the Ouro decoder and LM head
        self.model = None
        self.lm_head = None
        # Vision modules
        if hasattr(config, "mm_vision_tower") and config.mm_vision_tower:
            self.vision_tower = MobileCLIPVisionTower(config)
            self.mm_projector = build_vision_projector(config)

    def get_vision_tower(self):
        return getattr(self, "vision_tower", None)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """Build the model and restore LLM and projector weights from the checkpoint."""
        config = kwargs.pop("config", None)
        if config is None:
            config = LlavaOuroConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        # Instantiate the Ouro architecture; its weights are overwritten
        # below with the fine-tuned ones stored in the checkpoint.
        base_model = AutoModelForCausalLM.from_pretrained(
            "ByteDance/Ouro-1.4B", trust_remote_code=True,
            torch_dtype=kwargs.get("torch_dtype", torch.float16),
            device_map=kwargs.get("device_map", "auto"),
        )

        # Create LLaVA wrapper
        model = cls(config)
        model.model = base_model.model
        model.lm_head = base_model.lm_head

        try:
            from safetensors import safe_open
            if os.path.isdir(pretrained_model_name_or_path):
                sf_path = os.path.join(pretrained_model_name_or_path, "model.safetensors")
            else:
                from huggingface_hub import hf_hub_download
                sf_path = hf_hub_download(pretrained_model_name_or_path, "model.safetensors")

            # Split the checkpoint into language-model and projector tensors
            llm_state, projector_state = {}, {}
            with safe_open(sf_path, framework="pt") as f:
                for key in f.keys():
                    if "mm_projector" in key:
                        projector_state[key.split("mm_projector.", 1)[1]] = f.get_tensor(key)
                    else:
                        llm_state[key] = f.get_tensor(key)

            # strict=False tolerates keys that were dropped at save time
            base_model.load_state_dict(llm_state, strict=False)
            print("Loaded fine-tuned LLM weights")

            if hasattr(model, "mm_projector"):
                model.mm_projector.load_state_dict(projector_state)
                # The projector is created on CPU in float32; align it with the LLM
                model.mm_projector.to(device=base_model.device, dtype=base_model.dtype)
                print("Loaded mm_projector weights")
        except Exception as e:
            # Continuing would silently serve the base LLM with a randomly
            # initialised projector.
            raise RuntimeError(f"Could not load fine-tuned weights from {pretrained_model_name_or_path}: {e}") from e

        return model

    def forward(self, input_ids=None, attention_mask=None, position_ids=None, past_key_values=None,
                inputs_embeds=None, labels=None, images=None, image_sizes=None, **kwargs):
        # Simplified forward: image features are merged into inputs_embeds by generate()
        if inputs_embeds is None and input_ids is not None:
            inputs_embeds = self.model.embed_tokens(input_ids)

        outputs = self.model(inputs_embeds=inputs_embeds, attention_mask=attention_mask,
                            position_ids=position_ids, past_key_values=past_key_values, **kwargs)

        hidden_states = outputs[0] if isinstance(outputs, tuple) else outputs.last_hidden_state
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Next-token loss: position t predicts token t+1
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits[..., :-1, :].contiguous().view(-1, self.vocab_size),
                           labels[..., 1:].contiguous().view(-1))

        return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=getattr(outputs, "past_key_values", None))

    @torch.no_grad()
    def generate(self, inputs=None, images=None, **kwargs):
        """Generate text; when ``images`` is given, image features are prepended."""
        if images is not None and self.get_vision_tower() is not None:
            vt = self.get_vision_tower()
            if not vt.is_loaded:
                vt.load_model()
                vt = vt.to(device=images.device, dtype=images.dtype)
            img_features = vt(images)
            # Match the projector's dtype before projecting
            proj_dtype = next(self.mm_projector.parameters()).dtype
            img_features = self.mm_projector(img_features.to(proj_dtype))
            inputs_embeds = self.model.embed_tokens(inputs)
            inputs_embeds = torch.cat([img_features.to(inputs_embeds.dtype), inputs_embeds], dim=1)
            kwargs["inputs_embeds"] = inputs_embeds
            kwargs["attention_mask"] = torch.ones(inputs_embeds.shape[:2], device=inputs_embeds.device)
            inputs = None
        return super().generate(inputs, **kwargs)

AutoConfig.register("llava_ouro", LlavaOuroConfig)
AutoModelForCausalLM.register(LlavaOuroConfig, LlavaOuroForCausalLM)
'''

with open("/tmp/modeling_llava_ouro.py", "w") as f:
    f.write(modeling_code)

print("Created modeling_llava_ouro.py")

In [None]:
# ============================================
# UPLOAD TO HUGGINGFACE (with custom code)
# ============================================
import shutil

from huggingface_hub import HfApi, create_repo

# An empty string would be sent as a (invalid) bearer token; None lets
# huggingface_hub fall back to locally cached credentials instead.
hf_token = CONFIG["hf_token"] or None
api = HfApi(token=hf_token)

# Create repo
create_repo(CONFIG["hf_repo"], exist_ok=True, token=hf_token)

# 1. Put the custom code files next to the weights (REQUIRED for
#    trust_remote_code). config.json's auto_map points at these modules, so
#    they must be in MERGED_DIR for local loading as well as in the Hub repo.
custom_code_files = ["configuration_llava_ouro.py", "modeling_llava_ouro.py"]
for filename in custom_code_files:
    shutil.copy(os.path.join("/tmp", filename), os.path.join(MERGED_DIR, filename))
    print(f"Added {filename} to {MERGED_DIR}")

# 2. Upload weights, config, model card and custom code in a single commit
print(f"\nUploading model to {CONFIG['hf_repo']}...")
api.upload_folder(
    folder_path=MERGED_DIR,
    repo_id=CONFIG["hf_repo"],
    commit_message="Upload Belle-VLM-Ouro model",
)
print("Model uploaded!")
for filename in custom_code_files:
    print(f"Uploaded {filename}")

print(f"\n{'='*50}")
print(f"Model uploaded successfully!")
print(f"https://huggingface.co/{CONFIG['hf_repo']}")
print(f"{'='*50}")

## 10. Test Inference

In [None]:
# ============================================
# TEST INFERENCE
# ============================================
import torch
from PIL import Image
import timm
from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPImageProcessor

print("Loading model for inference...")

# Checkpoint location: the merged directory (CONFIG["hf_repo"] is the alternative).
MODEL_PATH = MERGED_DIR

# Tokenizer and language model are both read from MODEL_PATH.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    trust_remote_code=True,
    torch_dtype=torch.float16,
    device_map="auto",
).eval()

# Vision tower: pretrained timm backbone created with num_classes=0, switched
# to eval mode and moved to the model's device in float16.
vision_tower = (
    timm.create_model("fastvit_mci2.apple_mclip", pretrained=True, num_classes=0)
    .eval()
    .to(model.device, dtype=torch.float16)
)

# Image preprocessing: 384x384 crop with explicit mean/std normalization.
image_processor = CLIPImageProcessor(
    size={"shortest_edge": 384},
    crop_size={"height": 384, "width": 384},
    do_normalize=True,
    image_mean=[0.48145466, 0.4578275, 0.40821073],
    image_std=[0.26862954, 0.26130258, 0.27577711],
)

print("Model loaded!")

In [None]:
# Inference function
@torch.no_grad()
def inference(image_path, question, max_new_tokens=512, *, do_sample=True,
              temperature=0.7, top_p=0.9):
    """Answer ``question`` about the image at ``image_path``.

    The image is preprocessed with the notebook-level ``image_processor``,
    encoded by ``vision_tower`` and projected by ``model.mm_projector``. The
    projected features are prepended to the embedded text prompt and the
    result is passed to ``model.generate`` as ``inputs_embeds``. The whole
    function runs under ``torch.no_grad()``.

    Args:
        image_path: Image location handed to ``PIL.Image.open``.
        question: Question text inserted into the User/Assistant prompt.
        max_new_tokens: Forwarded to ``model.generate`` as ``max_new_tokens``.
        do_sample: Forwarded to ``model.generate``.
        temperature: Forwarded to ``model.generate``.
        top_p: Forwarded to ``model.generate``.

    Returns:
        The first sequence returned by ``model.generate``, decoded with
        special tokens skipped.
    """
    # Open inside a context manager so the source handle is released as soon
    # as the RGB copy has been made.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")
    pixel_values = image_processor(image, return_tensors="pt")["pixel_values"]
    pixel_values = pixel_values.to(model.device, dtype=torch.float16)

    # Encode the image; a 4-D (B, C, H, W) feature map is flattened into a
    # (B, H*W, C) token sequence before projection.
    image_features = vision_tower.forward_features(pixel_values)
    if image_features.dim() == 4:
        image_features = image_features.flatten(2).transpose(1, 2)
    image_features = model.mm_projector(image_features)

    # Embed the text prompt.
    prompt = f"User: {question}\nAssistant:"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    text_embeds = model.model.embed_tokens(inputs["input_ids"])

    # Image tokens come first, followed by the prompt tokens.
    inputs_embeds = torch.cat([image_features, text_embeds], dim=1)
    attention_mask = torch.ones(inputs_embeds.shape[:2], device=model.device)

    # Fall back to EOS for padding when the tokenizer defines no pad token.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    outputs = model.generate(
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        temperature=temperature,
        top_p=top_p,
        pad_token_id=pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)

print("Inference function ready!")