# Fine-tune Llama 3.2 11B Vision Instruct with HuggingFace + PyTorch on AMD MI300X/MI325X

This tutorial demonstrates how to fine-tune **Llama 3.2 11B Vision Instruct**, a multimodal vision-language model, on the ChartQA dataset using **only HuggingFace Transformers, PEFT, and TRL**.

Unlike the text-only Unsloth version, this notebook trains with **actual chart images** from ChartQA, allowing the model to learn visual chart understanding end-to-end.

**Key Features:**
- Llama 3.2 11B Vision: true multimodal model with vision encoder + cross-attention
- Trains on real chart images (not just text descriptions)
- LoRA fine-tuning targeting language model + cross-attention layers
- ChartQA dataset with synthetic chain-of-thought reasoning
- Optimized for AMD ROCm (MI300X)

**References:**
- [Llama 3.2 Vision Model Card](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct)
- [TRL Multimodal SFT Guide](https://huggingface.co/docs/trl/main/en/training_vlm_sft)

## Prerequisites

### Hardware
- **AMD Instinct MI300X/MI325X**: This tutorial was designed for AMD Instinct™ MI300X GPUs with ROCm support.


## Table of Contents

1. [Environment Setup](#environment-setup)
2. [Data Preparation](#data-preparation)
3. [Model Loading](#model-loading)
4. [LoRA Configuration](#lora-config)
5. [Training](#training)
6. [Saving the Model](#saving-model)
7. [Loading LoRA Adapters & Inference Comparison](#lora-inference)

## 1. Environment Setup <a id="environment-setup"></a>

In [None]:
import os
import sys
import types
import torch

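# Workaround for some ROCm builds: importing torchvision._meta_registrations
# can fail there, which can break later imports that pull in torchvision.
# Registering an empty placeholder module under that name makes Python skip
# the real import.
if "torchvision._meta_registrations" not in sys.modules:
    sys.modules["torchvision._meta_registrations"] = types.ModuleType(
        "torchvision._meta_registrations"
    )
    print("Applied torchvision _meta_registrations workaround for ROCm")

# Verify GPU setup (ROCm builds of PyTorch expose AMD GPUs through the torch.cuda API)
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA/ROCm available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")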

## 2. Data Preparation <a id="data-preparation"></a>

We use the ChartQA dataset with synthetic chain-of-thought reasoning. Since Llama 3.2 Vision is a true multimodal model, we include the **actual chart images** in the training data.

In [None]:
from utils_hf import create_chart_qa_with_reasoning_dataset

create_chart_qa_with_reasoning_dataset(
    "reasoning.parquet",
    "cot_chartqa",
    override=True,
)

In [None]:
from datasets import load_dataset

print("Loading datasets...")

# Load original ChartQA (has images!)
original_chartqa = load_dataset("HuggingFaceM4/ChartQA")

# Load our CoT dataset
reasoning_dataset = load_dataset("cot_chartqa")

print(f"\nDataset Statistics:")
print(f"   Training samples: {len(reasoning_dataset['train'])}")
print(f"   Original ChartQA columns: {original_chartqa['train'].column_names}")

In [None]:
# Compare original vs reasoning versions
sample_idx = 27901

print(f"Sample Comparison (Index {sample_idx}):")
print(f"\nQuery: {original_chartqa['train'][sample_idx]['query']}")
print(f"\nOriginal Answer: {original_chartqa['train'][sample_idx]['label']}")
print(f"\nReasoning Version:\n{reasoning_dataset['train'][sample_idx]['label']}")

In [None]:
from IPython.display import display

# Display the sample chart image
sample = original_chartqa['train'][sample_idx]
if 'image' in sample:
    print("Sample Chart:")
    display(sample['image'])
    sample['image'].save('example_chart.png')
    print("Saved as example_chart.png")

### Format Dataset for Vision-Language Training

For Llama 3.2 Vision, we format data as **multimodal conversations** with `{"type": "image"}` and `{"type": "text"}` content items. The actual PIL images are included alongside the text.
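As a rough illustration of that layout, here is a sketch of one training record. The helper name `build_vision_example`, the prompt wording (borrowed from the inference prompt used in section 7), and the field names are assumptions for illustration; `format_chartqa_for_vision_training` in `utils_hf` may differ in detail. The `{"type": "image"}` entry is only a placeholder where the chat template inserts the `<|image|>` token; the image itself travels in a separate `images` list that the collator hands to the processor.

```python
from typing import Any, Dict, List


def build_vision_example(question: str, answer: str, image: Any) -> Dict[str, Any]:
    """Hypothetical helper: build one multimodal chat record.

    `image` would be a PIL.Image in practice; any object works for this sketch.
    """
    prompt = f"Look at this chart and answer the following question.\n\nQuestion: {question}"
    messages: List[Dict[str, Any]] = [
        {
            "role": "user",
            "content": [
                {"type": "image"},  # placeholder only; pixels live in "images"
                {"type": "text", "text": prompt},
            ],
        },
        {
            "role": "assistant",
            "content": [{"type": "text", "text": answer}],  # CoT target text
        },
    ]
    # One image per {"type": "image"} placeholder, in the same order
    return {"messages": messages, "images": [image]}


example = build_vision_example(
    "What is the highest value?", "Step 1: ... The answer is 42.", image="<PIL image>"
)
for msg in example["messages"]:
    print(msg["role"], [c["type"] for c in msg["content"]])
```

This prints the same role/content-type summary as the preview cell below: `user ['image', 'text']` followed by `assistant ['text']`.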

In [None]:
from datasets import Dataset
from utils_hf import format_chartqa_for_vision_training

MAX_SAMPLES = 1000  # Adjust based on your needs

train_data = format_chartqa_for_vision_training(
    original_chartqa["train"],
    reasoning_dataset["train"],
    max_samples=MAX_SAMPLES,
)

train_dataset = Dataset.from_list(train_data)

print(f"\nDataset columns: {train_dataset.column_names}")
print(f"Total training samples: {len(train_dataset)}")

# Preview a sample
sample = train_dataset[0]
print(f"\nSample messages structure:")
for msg in sample["messages"]:
    print(f"  {msg['role']}: {[c['type'] for c in msg['content']]}")

Note: Only `MAX_SAMPLES` = 1000 examples are formatted here to keep the example small; raise `MAX_SAMPLES` to train on more of the dataset.

## 3. Model Loading <a id="model-loading"></a>

Load Llama 3.2 11B Vision Instruct using `MllamaForConditionalGeneration` and `AutoProcessor`. The processor handles both text tokenization and image preprocessing.
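The "~22GB" figure printed by the next cell is simple arithmetic: bf16 stores each parameter in 2 bytes. A back-of-the-envelope sketch, assuming roughly 10.6B parameters (the cell below prints the exact count) and counting weights only, not activations, gradients, or optimizer state:

```python
BYTES_PER_PARAM = {"fp32": 4, "bf16": 2}


def weight_memory_gb(num_params: float, dtype: str) -> float:
    """Approximate memory for model weights alone, in GB (1 GB = 1e9 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9


approx_params = 10.6e9  # assumption: approximate parameter count of the 11B Vision model
for dtype in ("bf16", "fp32"):
    print(f"{dtype}: ~{weight_memory_gb(approx_params, dtype):.1f} GB")
```

Loading in bf16 halves the footprint relative to fp32, leaving most of an MI300X's 192 GB of HBM free for activations during training.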

In [None]:
from transformers import MllamaForConditionalGeneration, AutoProcessor

model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct"

max_seq_length = 2048
dtype = torch.bfloat16

print(f"Loading {model_name}...")
print("This is a ~22GB model in bf16, please be patient...")

# Load processor (handles both text tokenization and image preprocessing)
processor = AutoProcessor.from_pretrained(
    model_name,
    trust_remote_code=True,
)

# Ensure padding is on the right for training
processor.tokenizer.padding_side = "right"

# Ensure pad token is set
if processor.tokenizer.pad_token is None:
    processor.tokenizer.pad_token = processor.tokenizer.eos_token
    processor.tokenizer.pad_token_id = processor.tokenizer.eos_token_id

# Load model
model = MllamaForConditionalGeneration.from_pretrained(
    model_name,
    dtype=dtype,
    device_map="auto",
    trust_remote_code=True,
)

print(f"\nModel loaded successfully!")
print(f"   Model type: {type(model).__name__}")
print(f"   Dtype: {model.dtype}")

# Show model structure overview
total_params = sum(p.numel() for p in model.parameters())
print(f"   Total parameters: {total_params / 1e9:.2f}B")

## 4. LoRA Configuration <a id="lora-config"></a>

Add LoRA adapters using PEFT. For the vision model, we target the **language model** layers including cross-attention (which bridges vision and language). The vision encoder is kept frozen.
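One subtlety worth checking: when `target_modules` is a list, PEFT selects a module if its name equals an entry or ends with `.<entry>`, and the Mllama vision encoder's attention layers are also named `q_proj`/`k_proj`/`v_proj`/`o_proj`. When `target_modules` is a single string, PEFT treats it as a regular expression that must match the full module name. The sketch below re-implements that matching rule in plain Python on a few illustrative module names (the exact prefixes, e.g. `model.language_model` vs `language_model.model`, depend on the Transformers version, so treat them as assumptions):

```python
import re
from typing import List, Union

# Illustrative module names; real prefixes vary across Transformers versions
MODULE_NAMES = [
    "model.vision_model.transformer.layers.0.self_attn.q_proj",
    "model.vision_model.transformer.layers.0.mlp.fc1",
    "model.multi_modal_projector",
    "model.language_model.layers.0.self_attn.q_proj",
    "model.language_model.layers.3.cross_attn.k_proj",
    "model.language_model.layers.0.mlp.gate_proj",
]

LIST_TARGETS = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
REGEX_TARGET = r".*language_model.*\.(q_proj|k_proj|v_proj|o_proj|gate_proj|up_proj|down_proj)"


def is_targeted(name: str, targets: Union[str, List[str]]) -> bool:
    """Mimic PEFT's rule: full regex match for a string, suffix match for a list."""
    if isinstance(targets, str):
        return re.fullmatch(targets, name) is not None
    return any(name == t or name.endswith("." + t) for t in targets)


for label, targets in (("list", LIST_TARGETS), ("regex", REGEX_TARGET)):
    hits = [n for n in MODULE_NAMES if is_targeted(n, targets)]
    print(label, hits)
```

The list form also picks up the vision encoder's `q_proj`; the regex form keeps only language-model layers (self-attention, cross-attention, and MLP). To inspect the real names, list `[n for n, _ in model.named_modules()]` on the loaded model.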

In [None]:
from peft import LoraConfig, get_peft_model

# Enable gradient checkpointing for memory efficiency
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

# Configure LoRA for the vision-language model.
# A plain list of names is matched by suffix anywhere in the model, and the
# vision encoder's attention layers also use q_proj/k_proj/v_proj/o_proj.
# A regex string (matched against the full module name) restricts LoRA to the
# language model: self-attention, cross-attention to vision, and MLP layers.
# The vision encoder and the multimodal projector stay frozen.
lora_config = LoraConfig(
    r=16,                    # LoRA rank
    lora_alpha=16,           # Scaling factor (update is scaled by lora_alpha / r)
    lora_dropout=0.05,       # Small dropout for regularization
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=r".*language_model.*\.(q_proj|k_proj|v_proj|o_proj|gate_proj|up_proj|down_proj)",
)

# Apply LoRA to the model
print("Adding LoRA adapters...")
model = get_peft_model(model, lora_config)

# Print trainable parameters
model.print_trainable_parameters()

trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
print(f"\nTrainable: {trainable_params:,} ({100*trainable_params/total_params:.2f}%)")
print(f"Total:     {total_params:,}")

## 5. Training <a id="training"></a>

For multimodal training with TRL's SFTTrainer, we need a custom `collate_fn` that:
1. Applies the chat template to format the conversation text
2. Processes images through the vision processor
3. Creates proper labels (masking padding and image tokens)
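Step 3 can be illustrated on toy token IDs. The IDs below are made up for the example (the real values come from `processor.tokenizer`); `-100` is the default `ignore_index` of PyTorch's `CrossEntropyLoss`, which is why masked positions are set to it.

```python
from typing import List

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss

# Hypothetical token IDs for illustration only
PAD_ID = 0
IMAGE_ID = 7


def make_labels(input_ids: List[List[int]], pad_id: int, image_id: int) -> List[List[int]]:
    """Copy input_ids, replacing padding and image-token positions with IGNORE_INDEX."""
    return [
        [IGNORE_INDEX if tok in (pad_id, image_id) else tok for tok in row]
        for row in input_ids
    ]


batch_ids = [
    [1, 7, 15, 16, 17, 2],  # <bos> <|image|> text tokens ... <eot>
    [1, 7, 21, 2, 0, 0],    # shorter sequence, right-padded
]
print(make_labels(batch_ids, PAD_ID, IMAGE_ID))
```

Like the collator below, this keeps the loss on every remaining token, including the user prompt. Masking the prompt too (training only on the assistant reply) is a common variant, but it is not what this notebook does.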

In [None]:
from PIL import Image
def collate_fn(examples):
    """
    Custom collator for multimodal vision-language training.
    
    For each example:
    1. Apply the Llama 3.2 Vision chat template to the messages
    2. Process images + text through the processor
    3. Create labels with proper masking
    """
    # Apply chat template to get formatted text
    texts = []
    for example in examples:
        text = processor.apply_chat_template(
            example["messages"],
            tokenize=False,
            add_generation_prompt=False,
        )
        texts.append(text.strip())
    
    # Collect images - each example has a list of images
    images = []
    for example in examples:
        example_images = example.get("images", [])
        if example_images:
            # Ensure all images are RGB PIL Images
            imgs = [img.convert("RGB") if isinstance(img, Image.Image) else img 
                    for img in example_images]
            images.append(imgs)
        else:
            images.append(None)
    
    # Process through the Mllama processor (tokenizes text + processes images)
    # The processor expects images as a list of lists (one list per batch item)
    batch = processor(
        images=images,
        text=texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_seq_length,
    )
    
    # Create labels: clone input_ids and mask tokens we don't want to compute loss on
    labels = batch["input_ids"].clone()
    
    # Mask padding tokens
    labels[labels == processor.tokenizer.pad_token_id] = -100
    
    # Mask the image token (<|image|>) so we don't compute loss on it
    image_token = processor.tokenizer.convert_tokens_to_ids("<|image|>")
    if image_token is not None and image_token != processor.tokenizer.unk_token_id:
        labels[labels == image_token] = -100
    
    batch["labels"] = labels
    return batch

# Quick test with one sample
print("Testing collate_fn with a single sample...")
test_batch = collate_fn([train_data[0]])
print(f"   input_ids shape: {test_batch['input_ids'].shape}")
print(f"   labels shape: {test_batch['labels'].shape}")
if 'pixel_values' in test_batch:
    print(f"   pixel_values shape: {test_batch['pixel_values'].shape}")
print(f"   Keys: {list(test_batch.keys())}")
print("collate_fn works!")

In [None]:
from trl import SFTConfig, SFTTrainer

print("Setting up trainer...")

# Training configuration optimized for AMD MI300X/MI325X
training_args = SFTConfig(
    # Batch settings
    # Vision models use more memory per sample than text-only models
    per_device_train_batch_size=1,   # Smaller batch due to image memory
    gradient_accumulation_steps=8,    # Effective batch = 1 * 8 = 8
    
    # Training duration
    num_train_epochs=1,
    max_steps=10,                     # Short demo run; when > 0 this overrides num_train_epochs (set -1 for full epochs)
    
    # Learning rate (lower than text-only since vision model is larger)
    learning_rate=5e-5,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    
    # Optimization
    optim="adamw_torch",
    weight_decay=0.01,
    max_grad_norm=1.0,
    
    # Precision
    bf16=True,
    tf32=False,
    
    # Logging
    logging_steps=1,
    logging_dir="./logs",
    
    # Checkpointing
    output_dir="./checkpoints/chartqa_llama_vision",
    save_strategy="epoch",
    save_total_limit=2,
    
    # Memory optimization
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    dataloader_pin_memory=False,   
    dataloader_num_workers=0,       
    
    # IMPORTANT for multimodal training:
    remove_unused_columns=False,    # Keep image columns
    dataset_kwargs={"skip_prepare_dataset": True},  # We handle preprocessing in collate_fn
    
    # Misc
    seed=42,
    report_to="none",
    # report_to="wandb",
)

# Create trainer with custom collate_fn for multimodal data
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=collate_fn,       # Custom collator handles images + text
    processing_class=processor,      # Pass processor instead of tokenizer
)

print("Trainer configured!")
print(f"   Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")
print(f"   Learning rate: {training_args.learning_rate}")
print(f"   Epochs: {training_args.num_train_epochs} (max_steps={training_args.max_steps} takes precedence when > 0)")
print(f"   Max seq length: {max_seq_length}")

In [None]:
# Start training
print("Starting training...")
print("=" * 50)

trainer_stats = trainer.train()

print("=" * 50)
print("Training complete!")

With `max_steps=10`, the cell above runs only a short demonstration fine-tune, so it finishes quickly.
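To see what these settings imply, here is the step arithmetic and the learning-rate curve. The schedule is a plain-Python re-implementation of linear warmup followed by cosine decay to zero (the shape behind `lr_scheduler_type="cosine"`), with warmup steps taken as `ceil(warmup_ratio * total_steps)`; treat it as a sketch for intuition, not the library's exact code. It assumes the 1,000-sample subset and a single GPU.

```python
import math

MAX_SAMPLES = 1000   # training subset size used in this notebook
PER_DEVICE_BATCH = 1
GRAD_ACCUM = 8
NUM_GPUS = 1         # assumption: single-GPU run
MAX_STEPS = 10
WARMUP_RATIO = 0.1
PEAK_LR = 5e-5

effective_batch = PER_DEVICE_BATCH * GRAD_ACCUM * NUM_GPUS
steps_per_epoch = math.ceil(MAX_SAMPLES / effective_batch)
warmup_steps = math.ceil(WARMUP_RATIO * MAX_STEPS)


def cosine_with_warmup(step: int, total: int, warmup: int, peak: float) -> float:
    """Linear warmup to `peak`, then cosine decay to 0 at `total` steps."""
    if step < warmup:
        return peak * step / max(1, warmup)
    progress = (step - warmup) / max(1, total - warmup)
    return peak * 0.5 * (1.0 + math.cos(math.pi * progress))


print(f"effective batch: {effective_batch}, optimizer steps per epoch: {steps_per_epoch}")
print(f"max_steps={MAX_STEPS} covers {MAX_STEPS * effective_batch} samples; warmup steps: {warmup_steps}")
print([f"{cosine_with_warmup(s, MAX_STEPS, warmup_steps, PEAK_LR):.2e}" for s in range(MAX_STEPS)])
```

In other words, the demo run sees only 80 of the 1,000 formatted samples; one full epoch would take about 125 optimizer steps.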

<img src="./assets/Screenshot 2026-02-12 020007.png" alt="20step-11b" width="700" height="450">

A full fine-tuning run (5 epochs) takes about 25 hours. If Weights & Biases logging is enabled (`report_to="wandb"` in `SFTConfig`), you can follow its progress online in the meantime.

<img src="./assets/Screenshot 2026-02-12 010520.png" alt="epoch10-11b" width="700" height="450">

## 6. Saving the Model <a id="saving-model"></a>

Save the fine-tuned model. For vision models, there are two options:
1. **LoRA adapters** (lightweight, recommended; this is what the cell below saves)
2. **Merged model** (full weights, optional)
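For option 2, a minimal sketch using PEFT's `merge_and_unload()`, which folds the LoRA weights into the base layers and returns the underlying Transformers model. The function name `save_merged_model` and the output path are hypothetical. Run it only after the adapters have been saved: merging modifies the wrapped model in place, so the adapters can no longer be saved separately afterwards. The merged checkpoint is full size (~22GB in bf16).

```python
def save_merged_model(peft_model, processor, output_dir: str):
    """Merge LoRA weights into the base model, then save full weights + processor.

    `peft_model` is the PeftModel returned by get_peft_model (or PeftModel.from_pretrained).
    """
    # Fold each LoRA update into its base weight and strip the LoRA wrappers
    merged_model = peft_model.merge_and_unload()
    # Write the config and full weights to output_dir
    merged_model.save_pretrained(output_dir)
    # Save the processor alongside so the folder can be loaded on its own
    processor.save_pretrained(output_dir)
    return merged_model


# Usage (hypothetical path):
# merged = save_merged_model(model, processor, "./models/chartqa_llama_vision_merged")
```

The resulting folder loads directly with `MllamaForConditionalGeneration.from_pretrained`, with no PEFT dependency at inference time.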

In [None]:
# Save LoRA adapters only (a small fraction of the ~22GB base weights)
lora_path = "./models/chartqa_llama_vision_lora"
print(f"Saving LoRA adapters to {lora_path}...")

model.save_pretrained(lora_path)
processor.save_pretrained(lora_path)

print(f"LoRA adapters saved!")

## 7. Loading LoRA Adapters & Inference Comparison <a id="lora-inference"></a>

Now let's demonstrate the full workflow of **loading saved LoRA adapters**, merging them into the base model, and comparing inference outputs against the unmodified base model.

This is the key step for deployment: you train once, save the lightweight LoRA adapters (a small fraction of the ~22GB full model), and later load and merge them for inference.
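How large are the adapters? A LoRA pair on a linear layer mapping `d_in` to `d_out` adds `r * (d_in + d_out)` parameters (A is `r x d_in`, B is `d_out x r`). The sketch below applies that formula with `r=16`, assuming adapters on the language model only and assumed text-backbone dimensions for Llama 3.2 11B Vision (hidden size 4096, MLP size 14336, 8 key/value heads of size 128, 40 decoder layers including the cross-attention layers); check these against `model.config.text_config` before relying on the numbers. The on-disk size also depends on the dtype the adapters are stored in.

```python
R = 16
HIDDEN = 4096         # assumption: text hidden size
INTERMEDIATE = 14336  # assumption: MLP intermediate size
KV_DIM = 8 * 128      # assumption: 8 key/value heads x head_dim 128
NUM_LAYERS = 40       # assumption: 32 self-attention + 8 cross-attention decoder layers


def lora_pair_params(d_in: int, d_out: int, r: int) -> int:
    """Parameters added by one LoRA pair: A is (r x d_in), B is (d_out x r)."""
    return r * (d_in + d_out)


per_layer = (
    lora_pair_params(HIDDEN, HIDDEN, R)          # q_proj
    + lora_pair_params(HIDDEN, KV_DIM, R)        # k_proj
    + lora_pair_params(HIDDEN, KV_DIM, R)        # v_proj
    + lora_pair_params(HIDDEN, HIDDEN, R)        # o_proj
    + lora_pair_params(HIDDEN, INTERMEDIATE, R)  # gate_proj
    + lora_pair_params(HIDDEN, INTERMEDIATE, R)  # up_proj
    + lora_pair_params(INTERMEDIATE, HIDDEN, R)  # down_proj
)
total = per_layer * NUM_LAYERS
print(f"LoRA parameters: {total:,}")
for name, nbytes in (("bf16", 2), ("fp32", 4)):
    print(f"  stored as {name}: ~{total * nbytes / 1e6:.0f} MB")
```

The adapter-loading cell in section 7.3 prints the actual file size of a given checkpoint, which is the figure to trust.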

In [None]:
# First, free memory from the training model if it's still loaded
import gc

try:
    del model, trainer
except NameError:
    pass

gc.collect()
torch.cuda.empty_cache()
print(f"GPU memory freed. Available: {torch.cuda.mem_get_info()[0] / 1e9:.1f} GB")

### 7.1 Load the Base Model (No Fine-Tuning)

We load the original Llama 3.2 11B Vision Instruct model **without** any LoRA adapters to establish a baseline for comparison.

In [None]:
# Load the BASE model (no LoRA) for comparison
from transformers import MllamaForConditionalGeneration, AutoProcessor

base_model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct"

print(f"Loading BASE model: {base_model_name} ...")
base_model = MllamaForConditionalGeneration.from_pretrained(
    base_model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
base_model.eval()

base_processor = AutoProcessor.from_pretrained(
    base_model_name,
    trust_remote_code=True,
)

print(f"Base model loaded successfully!")
print(f"   Parameters: {sum(p.numel() for p in base_model.parameters()):,}")

### 7.2 Run Base Model Inference

Let's run the base model on a set of chart questions to see how it performs **before** fine-tuning.

In [None]:
import re
from PIL import Image
from datasets import load_dataset
from IPython.display import display, HTML

def run_inference(model, processor, question, image, max_new_tokens=512):
    """Run inference on a single question + image pair."""
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": f"Look at this chart and answer the following question.\n\nQuestion: {question}"}
            ]
        }
    ]
    
    input_text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    
    inputs = processor(
        images=[image],
        text=input_text,
        return_tensors="pt",
    ).to(model.device)
    
    input_len = inputs["input_ids"].shape[1]
    
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            use_cache=True,
        )
    
    response = processor.decode(outputs[0][input_len:], skip_special_tokens=True)
    return response.strip()

# Prepare test samples from the ChartQA dataset
original_chartqa = load_dataset("HuggingFaceM4/ChartQA")

test_indices = [0, 100, 500, 27901]
test_samples = []

for idx in test_indices:
    sample = original_chartqa["train"][idx]
    if sample.get("image") is not None:
        test_samples.append({
            "idx": idx,
            "question": sample["query"],
            "ground_truth": sample["label"],
            "image": sample["image"].convert("RGB"),
        })

print(f"Prepared {len(test_samples)} test samples")
for s in test_samples:
    print(f"  Sample #{s['idx']}: {s['question'][:80]}...")

In [None]:
# Run inference with the BASE model
print("Running BASE model inference...")
print("=" * 60)

base_results = []
for s in test_samples:
    answer = run_inference(base_model, base_processor, s["question"], s["image"])
    base_results.append(answer)
    print(f"\nSample #{s['idx']}")
    print(f"  Q: {s['question']}")
    print(f"  Ground Truth: {s['ground_truth']}")
    print(f"  Base Model:   {answer}")

print("\n" + "=" * 60)
print(f"Base model inference complete ({len(base_results)} samples)")

### 7.3 Load LoRA Adapters & Merge with Base Model

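Now we load the LoRA adapters and merge them into the base model. The cell below pulls adapters published on the Hugging Face Hub (`HF_REPO_ID`) rather than the 10-step demo adapters saved in section 6; `PeftModel.from_pretrained` also accepts a local directory such as `./models/chartqa_llama_vision_lora`. Loading the adapter modifies `base_model` in place, so run the base-model inference in 7.2 first. This is the standard deployment workflow:

1. **Start from the base model** loaded in 7.1
2. **Load LoRA adapters** from the saved checkpoint
3. **Merge** adapters into the base model weights
4. **Run inference** with the fine-tuned model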
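What "merging" means numerically: a LoRA layer keeps the frozen weight `W` and computes `W x + (alpha / r) * B (A x)`. Merging precomputes `W' = W + (alpha / r) * B A` once, after which each layer is a single matrix multiply with no extra adapter computation at inference time. A tiny pure-Python check with made-up 2x2 weights and rank 1:

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Naive matrix product a @ b."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def add(a: Matrix, b: Matrix, scale: float = 1.0) -> Matrix:
    """Element-wise a + scale * b."""
    return [[x + scale * y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


r, alpha = 1, 2.0
scale = alpha / r
W = [[1.0, 0.0], [0.0, 1.0]]  # frozen base weight (d_out x d_in)
A = [[0.5, -1.0]]             # LoRA A: r x d_in
B = [[2.0], [1.0]]            # LoRA B: d_out x r
x = [[3.0], [4.0]]            # input column vector

# Unmerged: base path plus scaled low-rank path
unmerged = add(matmul(W, x), matmul(B, matmul(A, x)), scale)
# Merged: fold the update into the weight once, then a single matmul
W_merged = add(W, matmul(B, A), scale)
merged = matmul(W_merged, x)
print(unmerged, merged)
```

Both paths give the same output, which is why a merged model behaves like the base model with adapters attached.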

In [None]:
import gc, os
gc.collect()
torch.cuda.empty_cache()

# ── Configuration ──────────────────────────────────────────────────
HF_REPO_ID = "viani-sus/llama-3.2-11b-vision-chartqa-lora"
BASE_MODEL_NAME = "meta-llama/Llama-3.2-11B-Vision-Instruct"
# ───────────────────────────────────────────────────────────────────

from transformers import MllamaForConditionalGeneration, AutoProcessor
from peft import PeftModel
from huggingface_hub import snapshot_download

# 1. Apply the LoRA adapters from the Hub on top of the base model.
#    PEFT injects the adapter layers into `base_model` in place.
print(f"Applying LoRA adapters from Hub: {HF_REPO_ID} ...")
finetuned_model = PeftModel.from_pretrained(base_model, HF_REPO_ID)

# Show adapter info before merging (LoRA weights live in lora_A / lora_B modules)
num_lora_params = sum(p.numel() for n, p in finetuned_model.named_parameters() if "lora_" in n)
print(f"LoRA adapter parameters: {num_lora_params:,}")

# 2. Merge the adapters into the base weights and remove the LoRA wrappers
finetuned_model = finetuned_model.merge_and_unload()
finetuned_model.eval()

# 3. Load the processor saved alongside the adapters
finetuned_processor = AutoProcessor.from_pretrained(HF_REPO_ID, trust_remote_code=True)
print("Fine-tuned model loaded and merged!")

# Show the adapter file sizes from the cached download
adapter_dir = snapshot_download(HF_REPO_ID, local_files_only=True)
adapter_files = [f for f in os.listdir(adapter_dir) if f.endswith(('.safetensors', '.bin', '.json'))]
total_size = sum(os.path.getsize(os.path.join(adapter_dir, f)) for f in adapter_files)
print(f"Adapter files: {len(adapter_files)} files, {total_size / 1e6:.1f} MB total")

### 7.4 Run Fine-Tuned Model Inference

Now let's run the same questions through the fine-tuned model and compare the outputs.
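To compare the fine-tuned model's chain-of-thought output with the short ground-truth labels, it helps to pull out just the final answer. The answer format of the synthetic CoT labels is not shown in this notebook, so the patterns below (`Final answer: ...`, `The answer is ...`, `Answer: ...`) are assumptions; adjust them to whatever the model actually prints. When no pattern matches, the sketch falls back to the last non-empty line.

```python
import re

# Assumed answer markers, checked in priority order
ANSWER_PATTERNS = [
    r"final answer\s*[:\-]\s*(.+)",
    r"the answer is\s*[:\-]?\s*(.+)",
    r"answer\s*:\s*(.+)",
]


def extract_final_answer(response: str) -> str:
    """Return the text after the last answer marker, else the last non-empty line."""
    text = response.strip()
    for pattern in ANSWER_PATTERNS:
        # Without re.DOTALL, "(.+)" stops at the end of the line
        matches = re.findall(pattern, text, flags=re.IGNORECASE)
        if matches:
            # Drop markdown bold markers and a trailing period
            return matches[-1].strip().strip("*").strip().rstrip(".")
    lines = [line.strip() for line in text.splitlines() if line.strip()]
    return lines[-1] if lines else ""
```

Once the cell below has filled `finetuned_results`, `[extract_final_answer(r) for r in finetuned_results]` gives one short answer per sample.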

In [None]:
# Run inference with the FINE-TUNED model (LoRA merged)
print("Running FINE-TUNED model inference...")
print("=" * 60)

finetuned_results = []       # Full CoT responses

for s in test_samples:
    full_response = run_inference(finetuned_model, finetuned_processor, s["question"], s["image"])
    finetuned_results.append(full_response)
    
    print(f"\nSample #{s['idx']}")
    print(f"  Q: {s['question']}")
    print(f"  Ground Truth:      {s['ground_truth']}")
    print(f"  Full CoT Response: {full_response}")

print("\n" + "=" * 60)
print(f"Fine-tuned model inference complete ({len(finetuned_results)} samples)")

### 7.5 Side-by-Side Comparison

Let's visualize the results side by side to see the effect of fine-tuning. The fine-tuned model should provide **chain-of-thought reasoning** before giving the final answer, while the base model tends to answer without that step-by-step structure, which can make its final answer harder to extract.
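For a quantitative comparison, ChartQA results are commonly reported with *relaxed accuracy*: a numeric prediction counts as correct if it is within 5% of the numeric ground truth, and other answers need an exact (case-insensitive) match. A minimal sketch; the number parsing (stripping `%`, `$`, and thousands separators) is a simplifying assumption, and published evaluation scripts may normalize answers differently.

```python
from typing import List, Optional


def _to_float(text: str) -> Optional[float]:
    """Parse a number, ignoring %, $ and thousands separators; None if not numeric."""
    cleaned = text.strip().replace(",", "").replace("%", "").replace("$", "")
    try:
        return float(cleaned)
    except ValueError:
        return None


def relaxed_match(prediction: str, target: str, tolerance: float = 0.05) -> bool:
    """Numeric answers: relative error <= tolerance. Otherwise: case-insensitive exact match."""
    pred_num, target_num = _to_float(prediction), _to_float(target)
    if pred_num is not None and target_num is not None:
        if target_num == 0:
            return pred_num == 0
        return abs(pred_num - target_num) / abs(target_num) <= tolerance
    return prediction.strip().lower() == target.strip().lower()


def relaxed_accuracy(predictions: List[str], targets: List[str]) -> float:
    """Fraction of predictions that relaxed-match their targets."""
    if not targets:
        return 0.0
    hits = sum(relaxed_match(p, t) for p, t in zip(predictions, targets))
    return hits / len(targets)
```

Pass each model's short answers together with `gt_answers` (built in the comparison cell below). The test samples here come from the `train` split, which the adapters may have seen during training, so scoring a held-out split gives a fairer measurement.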

In [None]:
# Side-by-side comparison with images
from IPython.display import display, HTML

print("=" * 80)
print("COMPARISON: Base Model vs Fine-Tuned Model (LoRA)")
print("=" * 80)

for i, s in enumerate(test_samples):
    print(f"\n{'─' * 80}")
    print(f"SAMPLE #{s['idx']}")
    print(f"{'─' * 80}")
    
    # Display the chart image
    display(s["image"].resize((400, 300)))
    
    print(f"\nQuestion:     {s['question']}")
    print(f"Ground Truth: {s['ground_truth']}")
    print(f"\n--- Base Model (before fine-tuning) ---")
    print(f"{base_results[i]}")
    print(f"\n--- Fine-Tuned Model (with LoRA) ---")
    print(f"Full CoT: {finetuned_results[i]}")

gt_answers = [s["ground_truth"] for s in test_samples]

In [None]:
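# Cleanup: `finetuned_model` wraps (or, once merged, is) the same weights as
# `base_model`, so drop both references before freeing GPU memory
del finetuned_model, base_model
gc.collect()
torch.cuda.empty_cache()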

print("Inference comparison complete!")
print("\nKey Takeaways:")
print("  1. The base model answers chart questions without a consistent structure, so even correct answers can be hard to extract")
print("  2. The fine-tuned model provides chain-of-thought reasoning before the answer")
print("  3. LoRA adapters are a small fraction of the size of the full model weights (~22GB)")
print("  4. Loading and merging LoRA adapters is fast compared with training")

In [None]:
print("🎉 Tutorial complete! Thank you!")