## Conclusion: Root Cause and Fix

### Root Cause Analysis

The dtype mismatch error occurs because:

1. **Stage 2 is wrapped with DataParallel** while Stage 1 is not
   - Evidence: Stage 1 shows `0/136` steps, Stage 2 shows `0/68` steps (half!)
   - The traceback includes `torch/nn/parallel/data_parallel.py`

2. **DataParallel + AMP/autocast incompatibility**
   - DataParallel uses worker threads for each GPU replica
   - AMP/autocast context is **thread-local** - not inherited by worker threads
   - In k-bit training, some layers (norms) run in fp32 for stability
   - Without autocast in worker threads, fp32 activations flow to bf16 `lm_head` → crash

3. **Why Stage 2 got DataParallel**
   - The process sees 2 GPUs (`n_gpu=2`), and Trainer wraps DataParallel unless it detects model-parallelism.
   - Model-parallelism is only detected if `hf_device_map` spans >1 GPU (or `is_parallelizable`+`model_parallel` are set).
   - If `device_map` produces `hf_device_map={'': 0}` (single-GPU), Trainer will choose DataParallel by default.

### The Fix (Implemented in `scripts/train_recipe.py`)

```python
# Fix for two-stage training: When loading from a checkpoint with multi-GPU,
# force Trainer to behave as single-GPU to prevent DataParallel wrapping.
# Note: this only prevents Trainer's DataParallel; it does not force model sharding.
if checkpoint_path and torch.cuda.device_count() > 1:
    training_args_recipe._n_gpu = 1
```

**Why this works:**
- Setting `_n_gpu=1` tells Trainer "don't use DataParallel"
- If the model is sharded (multi-GPU `hf_device_map`), it can still use multiple GPUs via `device_map`
- If the model fits on one GPU (`hf_device_map={'': 0}`), Stage 2 will run single-GPU (other GPU mostly idle)
- This fix is conditional - only applies when loading from checkpoint
- Other recipes (r1, r2, r4) that load from HuggingFace are unaffected

### Verified Working (2025-12-15)

```
Stage 1: 136 steps completed (vision encoder training)
Stage 2: 136 steps completed (LLM LoRA training from checkpoint)
```

# Debug: Dtype Issue in Two-Stage Training

This notebook investigates why loading a trained checkpoint with quantization causes a dtype mismatch.

**The Error:**
```
RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::BFloat16
```

**Hypothesis:** The Stage 1 checkpoint (trained with gradient checkpointing) has some float32 parameters that don't get properly converted when loaded with BitsAndBytes quantization.

In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "4,5"  # Use two GPUs for testing (physical IDs 4,5)

import torch
from transformers import BitsAndBytesConfig, Qwen2VLForConditionalGeneration
from collections import Counter

print(f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES')}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"PyTorch sees {torch.cuda.device_count()} GPU(s):")
    for i in range(torch.cuda.device_count()):
        print(f"  GPU {i}: {torch.cuda.get_device_name(i)}")



CUDA_VISIBLE_DEVICES: 4,5
PyTorch version: 2.4.1+cu121
CUDA available: True
PyTorch sees 2 GPU(s):
  GPU 0: NVIDIA RTX 6000 Ada Generation
  GPU 1: NVIDIA RTX 6000 Ada Generation


In [2]:
# Paths
HUGGINGFACE_MODEL = "Qwen/Qwen2-VL-7B-Instruct"
STAGE1_CHECKPOINT = "/ssd1/zhuoyuan/vlm_outputs/qwen2vl-nutrition-detection-r3-stage1"

def analyze_model_dtypes(model, prefix=""):
    """Analyze and print dtype distribution of model parameters."""
    dtype_counts = {"vision": Counter(), "llm": Counter(), "other": Counter()}
    
    for name, param in model.named_parameters():
        if ".visual." in name:
            dtype_counts["vision"][str(param.dtype)] += 1
        elif "model.layers" in name or "lm_head" in name:
            dtype_counts["llm"][str(param.dtype)] += 1
        else:
            dtype_counts["other"][str(param.dtype)] += 1
    
    print(f"\n{prefix}Parameter dtype distribution:")
    print(f"  Vision encoder: {dict(dtype_counts['vision'])}")
    print(f"  LLM layers:     {dict(dtype_counts['llm'])}")
    print(f"  Other:          {dict(dtype_counts['other'])}")
    
    return dtype_counts

def print_sample_params(model, component="visual", n=5):
    """Print dtype of first n parameters from a component."""
    count = 0
    print(f"\nSample {component} parameters:")
    for name, param in model.named_parameters():
        if component in name:
            print(f"  {name}: {param.dtype}")
            count += 1
            if count >= n:
                break

## Test 1: Load HuggingFace Original (No Quantization)

Baseline: What dtype does the original model have?

In [3]:
print("="*60)
print("TEST 1: HuggingFace Original (bf16, no quantization)")
print("="*60)

model_hf_bf16 = Qwen2VLForConditionalGeneration.from_pretrained(
    HUGGINGFACE_MODEL,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)

analyze_model_dtypes(model_hf_bf16, "[HF Original bf16] ")
print_sample_params(model_hf_bf16, "visual")

# Clean up
del model_hf_bf16
torch.cuda.empty_cache()

TEST 1: HuggingFace Original (bf16, no quantization)


Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]


[HF Original bf16] Parameter dtype distribution:
  Vision encoder: {'torch.bfloat16': 391}
  LLM layers:     {'torch.bfloat16': 337}
  Other:          {'torch.bfloat16': 2}

Sample visual parameters:
  model.visual.patch_embed.proj.weight: torch.bfloat16
  model.visual.blocks.0.norm1.weight: torch.bfloat16
  model.visual.blocks.0.norm1.bias: torch.bfloat16
  model.visual.blocks.0.norm2.weight: torch.bfloat16
  model.visual.blocks.0.norm2.bias: torch.bfloat16


## Test 2: Load HuggingFace Original (With 4-bit Quantization)

This is what r1-llm-only does; it should work fine.

In [4]:
print("="*60)
print("TEST 2: HuggingFace Original (4-bit quantization)")
print("="*60)

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

model_hf_4bit = Qwen2VLForConditionalGeneration.from_pretrained(
    HUGGINGFACE_MODEL,
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)

analyze_model_dtypes(model_hf_4bit, "[HF Original 4-bit] ")
print_sample_params(model_hf_4bit, "visual")

# Clean up
del model_hf_4bit
torch.cuda.empty_cache()

TEST 2: HuggingFace Original (4-bit quantization)


Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]


[HF Original 4-bit] Parameter dtype distribution:
  Vision encoder: {'torch.bfloat16': 261, 'torch.uint8': 130}
  LLM layers:     {'torch.uint8': 196, 'torch.bfloat16': 141}
  Other:          {'torch.bfloat16': 2}

Sample visual parameters:
  model.visual.patch_embed.proj.weight: torch.bfloat16
  model.visual.blocks.0.norm1.weight: torch.bfloat16
  model.visual.blocks.0.norm1.bias: torch.bfloat16
  model.visual.blocks.0.norm2.weight: torch.bfloat16
  model.visual.blocks.0.norm2.bias: torch.bfloat16


## Test 3: Load Stage 1 Checkpoint (No Quantization)

What dtype does the trained checkpoint have when loaded without quantization?

In [5]:
print("="*60)
print("TEST 3: Stage 1 Checkpoint (bf16, no quantization)")
print("="*60)

model_stage1_bf16 = Qwen2VLForConditionalGeneration.from_pretrained(
    STAGE1_CHECKPOINT,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)

analyze_model_dtypes(model_stage1_bf16, "[Stage1 Checkpoint bf16] ")
print_sample_params(model_stage1_bf16, "visual")

# Clean up
del model_stage1_bf16
torch.cuda.empty_cache()

TEST 3: Stage 1 Checkpoint (bf16, no quantization)


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]


[Stage1 Checkpoint bf16] Parameter dtype distribution:
  Vision encoder: {'torch.bfloat16': 391}
  LLM layers:     {'torch.bfloat16': 337}
  Other:          {'torch.bfloat16': 2}

Sample visual parameters:
  model.visual.patch_embed.proj.weight: torch.bfloat16
  model.visual.blocks.0.norm1.weight: torch.bfloat16
  model.visual.blocks.0.norm1.bias: torch.bfloat16
  model.visual.blocks.0.norm2.weight: torch.bfloat16
  model.visual.blocks.0.norm2.bias: torch.bfloat16


## Test 4: Load Stage 1 Checkpoint (With 4-bit Quantization)

**THIS IS THE PROBLEMATIC CASE!** 

This is what r3-Stage2 tries to do and fails.

In [6]:
print("="*60)
print("TEST 4: Stage 1 Checkpoint (4-bit quantization)")
print("This is what causes the dtype mismatch!")
print("="*60)

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

model_stage1_4bit = Qwen2VLForConditionalGeneration.from_pretrained(
    STAGE1_CHECKPOINT,
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)

analyze_model_dtypes(model_stage1_4bit, "[Stage1 Checkpoint 4-bit] ")
print_sample_params(model_stage1_4bit, "visual")

# Check lm_head dtype
print("\nlm_head dtype:")
for name, param in model_stage1_4bit.named_parameters():
    if "lm_head" in name:
        print(f"  {name}: {param.dtype}")
        break

TEST 4: Stage 1 Checkpoint (4-bit quantization)
This is what causes the dtype mismatch!


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]


[Stage1 Checkpoint 4-bit] Parameter dtype distribution:
  Vision encoder: {'torch.bfloat16': 261, 'torch.uint8': 130}
  LLM layers:     {'torch.uint8': 196, 'torch.bfloat16': 141}
  Other:          {'torch.bfloat16': 2}

Sample visual parameters:
  model.visual.patch_embed.proj.weight: torch.bfloat16
  model.visual.blocks.0.norm1.weight: torch.bfloat16
  model.visual.blocks.0.norm1.bias: torch.bfloat16
  model.visual.blocks.0.norm2.weight: torch.bfloat16
  model.visual.blocks.0.norm2.bias: torch.bfloat16

lm_head dtype:
  lm_head.weight: torch.bfloat16


## Test 5: Apply Our Fix

Cast vision encoder parameters to bf16 after loading.

In [7]:
print("="*60)
print("TEST 5: Apply Fix - Cast vision encoder to bf16")
print("="*60)

# Apply fix: cast float32 vision params to bf16
cast_count = 0
for name, param in model_stage1_4bit.named_parameters():
    if ".visual." in name and param.dtype == torch.float32:
        param.data = param.data.to(torch.bfloat16)
        cast_count += 1

print(f"Cast {cast_count} vision encoder parameters to bf16")

analyze_model_dtypes(model_stage1_4bit, "[After Fix] ")
print_sample_params(model_stage1_4bit, "visual")

TEST 5: Apply Fix - Cast vision encoder to bf16
Cast 0 vision encoder parameters to bf16

[After Fix] Parameter dtype distribution:
  Vision encoder: {'torch.bfloat16': 261, 'torch.uint8': 130}
  LLM layers:     {'torch.uint8': 196, 'torch.bfloat16': 141}
  Other:          {'torch.bfloat16': 2}

Sample visual parameters:
  model.visual.patch_embed.proj.weight: torch.bfloat16
  model.visual.blocks.0.norm1.weight: torch.bfloat16
  model.visual.blocks.0.norm1.bias: torch.bfloat16
  model.visual.blocks.0.norm2.weight: torch.bfloat16
  model.visual.blocks.0.norm2.bias: torch.bfloat16


The output shows that this fix is a no-op: zero parameters were cast, because none of the vision encoder parameters are float32. We therefore reverted it. Casting vision params to bf16 is not the right solution.

## Test 6: Verify Forward Pass Works

Since the cast in Test 5 changed nothing, try a simple forward pass on the 4-bit checkpoint model to see whether the dtype mismatch reproduces in inference mode.

In [8]:
print("="*60)
print("TEST 6: Forward Pass Test")
print("="*60)

from transformers import Qwen2VLProcessor
from PIL import Image

# Load processor
processor = Qwen2VLProcessor.from_pretrained(
    HUGGINGFACE_MODEL,
    trust_remote_code=True,
)

# Create a simple test input
# Use a small test image
test_image = Image.new('RGB', (224, 224), color='red')

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": test_image},
            {"type": "text", "text": "What is this?"}
        ]
    }
]

# Apply chat template
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

# Process inputs
inputs = processor(
    text=[text],
    images=[test_image],
    return_tensors="pt",
    padding=True,
)

# Move to GPU
inputs = {k: v.to(model_stage1_4bit.device) if hasattr(v, 'to') else v for k, v in inputs.items()}

# Try forward pass
try:
    with torch.no_grad():
        outputs = model_stage1_4bit(**inputs)
    print("Forward pass SUCCESSFUL!")
    print(f"  Output logits shape: {outputs.logits.shape}")
    print(f"  Output logits dtype: {outputs.logits.dtype}")
except RuntimeError as e:
    print("Forward pass FAILED!")
    print(f"  Error: {e}")

TEST 6: Forward Pass Test


The image processor of type `Qwen2VLImageProcessor` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. Note that this behavior will be extended to all models in a future release.
You have video processor config saved in `preprocessor.json` file which is deprecated. Video processor configs should be saved in their own `video_preprocessor.json` file. You can rename the file or load and save the processor back which renames it automatically. Loading from `preprocessor.json` will be removed in v5.0.


Forward pass SUCCESSFUL!
  Output logits shape: torch.Size([1, 89, 152064])
  Output logits dtype: torch.bfloat16


## Summary

Compare the dtype distributions across all tests to understand the issue.
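
The comparison can also be done mechanically. The sketch below copies the counts printed by Tests 2 and 4 into `Counter` objects, in the same shape that `analyze_model_dtypes` returns, and checks that they are identical. `same_distribution` is a hypothetical helper introduced here for illustration, not part of the notebook.

```python
from collections import Counter

# Counts copied from the Test 2 and Test 4 outputs above.
hf_4bit = {
    "vision": Counter({"torch.bfloat16": 261, "torch.uint8": 130}),
    "llm": Counter({"torch.uint8": 196, "torch.bfloat16": 141}),
    "other": Counter({"torch.bfloat16": 2}),
}
stage1_4bit = {
    "vision": Counter({"torch.bfloat16": 261, "torch.uint8": 130}),
    "llm": Counter({"torch.uint8": 196, "torch.bfloat16": 141}),
    "other": Counter({"torch.bfloat16": 2}),
}


def same_distribution(a, b):
    """True when every component has identical per-dtype counts."""
    return a.keys() == b.keys() and all(a[k] == b[k] for k in a)


# The checkpoint loads with the same dtypes as the HuggingFace original,
# and no component contains float32 parameters.
identical = same_distribution(hf_4bit, stage1_4bit)
any_float32 = any("torch.float32" in counts for counts in stage1_4bit.values())
print(identical, any_float32)  # True False
```

Identical distributions with no float32 entries are what rule out the original parameter-dtype hypothesis.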

In [9]:
print("="*60)
print("ACTUAL FINDINGS (Updated after running)")
print("="*60)
print("""
ACTUAL RESULTS:

1. HF Original (bf16):        Vision = bf16 (391), LLM = bf16 (337)
2. HF Original (4-bit):       Vision = bf16 (261) + uint8 (130), LLM = uint8 (196) + bf16 (141)
3. Stage1 Checkpoint (bf16):  Vision = bf16 (391), LLM = bf16 (337)
4. Stage1 Checkpoint (4-bit): Vision = bf16 (261) + uint8 (130), LLM = uint8 (196) + bf16 (141)
5. After Fix:                 Cast 0 parameters (no float32 found!)
6. Forward Pass:              SUCCESSFUL!

CONCLUSION: 
- The vision encoder parameters are NOT float32!
- My original hypothesis was WRONG.
- The forward pass works fine in inference mode on single GPU.

THE REAL ISSUE might be:
- Multi-GPU + DataParallel interaction (original error shows data_parallel.py)
- PEFT/LoRA wrapping the model during training
- Training mode vs Inference mode differences

The error only happens during TRAINING with:
- 2 GPUs (device_map="balanced")
- SFTTrainer + PEFT/LoRA
- Training mode (gradients enabled)

Next step: Test with PEFT + training mode to reproduce the issue.
""")

# Clean up
del model_stage1_4bit
torch.cuda.empty_cache()
print("\nGPU memory cleaned up.")

ACTUAL FINDINGS (Updated after running)

ACTUAL RESULTS:

1. HF Original (bf16):        Vision = bf16 (391), LLM = bf16 (337)
2. HF Original (4-bit):       Vision = bf16 (261) + uint8 (130), LLM = uint8 (196) + bf16 (141)
3. Stage1 Checkpoint (bf16):  Vision = bf16 (391), LLM = bf16 (337)
4. Stage1 Checkpoint (4-bit): Vision = bf16 (261) + uint8 (130), LLM = uint8 (196) + bf16 (141)
5. After Fix:                 Cast 0 parameters (no float32 found!)
6. Forward Pass:              SUCCESSFUL!

CONCLUSION: 
- The vision encoder parameters are NOT float32!
- My original hypothesis was WRONG.
- The forward pass works fine in inference mode on single GPU.

THE REAL ISSUE might be:
- Multi-GPU + DataParallel interaction (original error shows data_parallel.py)
- PEFT/LoRA wrapping the model during training
- Training mode vs Inference mode differences

The error only happens during TRAINING with:
- 2 GPUs (device_map="balanced")
- SFTTrainer + PEFT/LoRA
- Training mode (gradients enabled)

Next step: Test with PEFT + training mode to reproduce the issue.

## Part 2: DataParallel Investigation

The tests above showed that parameter dtypes are correct, but the error only occurs during training with:
- 2 GPUs with `device_map="balanced"`
- SFTTrainer + PEFT/LoRA
- Training mode (gradients enabled)

**Key observation from error logs:**
- Stage 1: Progress bar shows `0/136` steps (expected)
- Stage 2: Progress bar shows `0/68` steps (half!)
- Stage 2 traceback includes `torch/nn/parallel/data_parallel.py`
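
The halved step count is what DataParallel would produce: Trainer multiplies the per-device batch size by `n_gpu`, so each optimizer step consumes twice as many samples. Below is a minimal arithmetic sketch. The dataset size, batch size, and accumulation steps are hypothetical values chosen only to reproduce the 136 and 68 seen in the logs; the real recipe values are not shown in this notebook, and the exact rounding may differ between Trainer versions.

```python
import math


def steps_per_epoch(num_samples, per_device_batch, n_gpu, grad_accum):
    """Approximate optimizer steps per epoch for a given effective batch size."""
    # Under DataParallel one dataloader batch feeds every visible GPU.
    batch = per_device_batch * max(1, n_gpu)
    batches = math.ceil(num_samples / batch)
    return max(1, batches // grad_accum)


# Hypothetical settings: 1088 samples, batch size 1, 8 accumulation steps.
stage1_steps = steps_per_epoch(1088, 1, n_gpu=1, grad_accum=8)
stage2_steps = steps_per_epoch(1088, 1, n_gpu=2, grad_accum=8)
print(stage1_steps, stage2_steps)  # 136 68
```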

**Hypothesis:** Stage 2 is being wrapped with DataParallel while Stage 1 is not.

HuggingFace Trainer prevents DataParallel when it detects model parallelism via:
- `model.hf_device_map` containing multiple devices
- `model.is_parallelizable=True` and `model.model_parallel=True`

**Important detail:** `hf_device_map` being present is not enough — Trainer only treats it as *model-parallel* if the map spans >1 GPU (or if `is_parallelizable` + `model_parallel` are set).

Also, `device_map="auto"` / `"balanced"` may keep a 4-bit model on a single GPU if it fits. So seeing `hf_device_map={'': 0}` is normal and means *no sharding happened*.

If you need to **force** multi-GPU sharding for debugging, you must add explicit constraints (e.g. `max_memory`) so GPU 0 cannot hold the full model.
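
A sketch of what such a constraint could look like. `max_memory` is an existing `from_pretrained` argument that maps device ids to memory caps. The 4 GiB cap and the helper names `build_max_memory` and `load_sharded_for_debug` are hypothetical, and the cap would need tuning until the 4-bit model no longer fits on a single GPU.

```python
def build_max_memory(n_gpus, per_gpu_gib):
    """Map each visible GPU index to a memory cap string such as '4GiB'."""
    return {gpu: f"{per_gpu_gib}GiB" for gpu in range(n_gpus)}


def load_sharded_for_debug(checkpoint, bnb_config, n_gpus=2, per_gpu_gib=4):
    """Load with a per-GPU cap so device_map should spread layers across GPUs."""
    # Imported lazily so build_max_memory stays usable without a GPU stack.
    import torch
    from transformers import Qwen2VLForConditionalGeneration

    return Qwen2VLForConditionalGeneration.from_pretrained(
        checkpoint,
        quantization_config=bnb_config,
        torch_dtype=torch.bfloat16,
        device_map="balanced",
        max_memory=build_max_memory(n_gpus, per_gpu_gib),
    )


caps = build_max_memory(2, 4)
print(caps)  # {0: '4GiB', 1: '4GiB'}
```

After loading this way, `hf_device_map` should list more than one GPU. If the cap is set too low, layers may be placed on `cpu` or `disk` instead, which is not what this debugging setup is after.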

**Reminder:** inside this notebook, GPUs are renumbered to `0,1` (these correspond to physical GPUs `4,5`).

Let's investigate what attributes differ between HuggingFace loading vs checkpoint loading.

### Clarifying: `device_map` vs DataParallel

There are **two independent ways** a single Python process can "use 2 GPUs":

1. **Model sharding (`device_map` / `hf_device_map`)**
   - Splits *weights* across GPUs (model-parallel).
   - Trainer will avoid DataParallel only if `hf_device_map` spans >1 GPU (or `is_parallelizable`+`model_parallel` are set).

2. **DataParallel (Trainer wrapping)**
   - Replicates the *entire model* onto each visible GPU and splits the batch.
   - **Does not require sharding** → it can happen even when `hf_device_map={'': 0}`.

**What the Stage 2 fix changes:** forcing `_n_gpu=1` disables Trainer's DataParallel wrapping. If the model is not sharded (e.g. `hf_device_map={'': 0}`), training becomes single-GPU and the other GPU will look idle. If the model *is* sharded, you can still use multiple GPUs via model-parallel sharding even with `_n_gpu=1`.

In [10]:
# Helper functions for DataParallel analysis
def analyze_trainer_attributes(model, source_name):
    """Check attributes that determine if Trainer uses DataParallel.
    
    HuggingFace Trainer (trainer.py) checks these attributes to decide
    whether to wrap the model with DataParallel:
    
    1. is_parallelizable + model_parallel: If both True, Trainer sets n_gpu=1
    2. hf_device_map: If present with multiple GPUs, Trainer sets n_gpu=1
    3. is_loaded_in_8bit: If True, Trainer skips DataParallel wrapping
    
    DataParallel is applied when: n_gpu > 1 AND not is_loaded_in_8bit
    """
    print(f"\n[{source_name}]")
    print(f"  hf_device_map: {getattr(model, 'hf_device_map', None)}")
    print(f"  is_loaded_in_4bit: {getattr(model, 'is_loaded_in_4bit', False)}")
    print(f"  is_loaded_in_8bit: {getattr(model, 'is_loaded_in_8bit', False)}")
    print(f"  is_parallelizable: {getattr(model, 'is_parallelizable', False)}")
    print(f"  model_parallel: {getattr(model, 'model_parallel', False)}")

def would_trainer_use_dataparallel(model, n_gpu=2):
    """Simulate Trainer's logic from trainer.py lines 503-624.
    
    Returns (would_use_dp, is_model_parallel, effective_n_gpu)
    """
    is_model_parallel = False

    # Check 1: Direct model parallelism flags
    if getattr(model, "is_parallelizable", False) and getattr(model, "model_parallel", False):
        is_model_parallel = True

    # Check 2: Device map with multiple GPUs
    hf_device_map = getattr(model, "hf_device_map", None)
    if hf_device_map is not None:
        devices = [d for d in set(hf_device_map.values()) if d not in ["cpu", "disk"]]
        if len(devices) > 1:
            is_model_parallel = True

    # Trainer sets n_gpu=1 if model_parallel detected
    effective_n_gpu = 1 if is_model_parallel else n_gpu

    # DataParallel condition: n_gpu > 1 AND not 8-bit
    would_use_dp = effective_n_gpu > 1 and not getattr(model, "is_loaded_in_8bit", False)

    return would_use_dp, is_model_parallel, effective_n_gpu

print("Helper functions defined.")

Helper functions defined.


### Test 7: Compare Model Attributes (HuggingFace vs Checkpoint)

Let's check the model attributes that Trainer uses to decide whether to apply DataParallel.

In [11]:
print("="*60)
print("TEST 7: Compare Model Attributes (HuggingFace vs Checkpoint)")
print("="*60)

# Load HuggingFace model with device_map (simulating what works)
print(f"\nPyTorch sees {torch.cuda.device_count()} GPU(s) in this process.")
print("\nLoading HuggingFace model with device_map='balanced'...")
model_hf = Qwen2VLForConditionalGeneration.from_pretrained(
    HUGGINGFACE_MODEL,
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
    device_map="balanced",
    trust_remote_code=True,
)

analyze_trainer_attributes(model_hf, "HuggingFace (4-bit, device_map='balanced')")
would_dp, is_mp, eff_ngpu = would_trainer_use_dataparallel(model_hf, n_gpu=torch.cuda.device_count())
print(f"\n  Would Trainer use DataParallel? {would_dp}")
print(f"  Is model parallel detected? {is_mp}")
print(f"  Effective n_gpu: {eff_ngpu}")

# Clean up
del model_hf
torch.cuda.empty_cache()

# Load Checkpoint model with device_map (simulating what fails)
print("\n" + "-"*60)
print("\nLoading Checkpoint model with device_map='balanced'...")
model_ckpt = Qwen2VLForConditionalGeneration.from_pretrained(
    STAGE1_CHECKPOINT,
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
    device_map="balanced",
    trust_remote_code=True,
)

analyze_trainer_attributes(model_ckpt, "Checkpoint (4-bit, device_map='balanced')")
would_dp, is_mp, eff_ngpu = would_trainer_use_dataparallel(model_ckpt, n_gpu=torch.cuda.device_count())
print(f"\n  Would Trainer use DataParallel? {would_dp}")
print(f"  Is model parallel detected? {is_mp}")
print(f"  Effective n_gpu: {eff_ngpu}")

TEST 7: Compare Model Attributes (HuggingFace vs Checkpoint)

PyTorch sees 2 GPU(s) in this process.

Loading HuggingFace model with device_map='balanced'...


Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]


[HuggingFace (4-bit, device_map='balanced')]
  hf_device_map: {'': 0}
  is_loaded_in_4bit: True
  is_loaded_in_8bit: False
  is_parallelizable: False
  model_parallel: False

  Would Trainer use DataParallel? True
  Is model parallel detected? False
  Effective n_gpu: 2

------------------------------------------------------------

Loading Checkpoint model with device_map='balanced'...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]


[Checkpoint (4-bit, device_map='balanced')]
  hf_device_map: {'': 0}
  is_loaded_in_4bit: True
  is_loaded_in_8bit: False
  is_parallelizable: False
  model_parallel: False

  Would Trainer use DataParallel? True
  Is model parallel detected? False
  Effective n_gpu: 2


### Test 8: Check Device Map Details

Let's examine what devices are in the device_map to understand if Trainer should detect model parallelism.

In [12]:
print("="*60)
print("TEST 8: Device Map Details")
print("="*60)

# Check device map structure
hf_device_map = getattr(model_ckpt, "hf_device_map", None)

if hf_device_map:
    print(f"\nDevice map has {len(hf_device_map)} entries")
    
    # Count devices
    device_counts = Counter(hf_device_map.values())
    print(f"Device distribution: {dict(device_counts)}")
    
    # Show first/last few entries
    items = list(hf_device_map.items())
    print(f"\nFirst 5 entries:")
    for k, v in items[:5]:
        print(f"  {k}: {v}")
    print(f"\nLast 5 entries:")
    for k, v in items[-5:]:
        print(f"  {k}: {v}")
    
    # Check GPU devices (excluding cpu/disk)
    gpu_devices = [d for d in set(hf_device_map.values()) if d not in ["cpu", "disk"]]
    print(f"\nGPU devices in device_map: {gpu_devices}")
    print(f"Number of GPU devices: {len(gpu_devices)}")
    print(f"\nShould Trainer detect model parallelism? {len(gpu_devices) > 1}")
else:
    print("No hf_device_map attribute found!")

# Clean up
del model_ckpt
torch.cuda.empty_cache()
print("\nGPU memory cleaned up.")

TEST 8: Device Map Details

Device map has 1 entries
Device distribution: {0: 1}

First 5 entries:
  : 0

Last 5 entries:
  : 0

GPU devices in device_map: [0]
Number of GPU devices: 1

Should Trainer detect model parallelism? False

GPU memory cleaned up.


### Interpreting the Test 8 output

- `hf_device_map={'': 0}` means `device_map` did **not** shard the model; everything lives on GPU 0.
- With 2 visible GPUs, Trainer would normally wrap **DataParallel** (replicate the model onto both GPUs) unless model-parallel is detected.
- The Stage 2 fix disables that wrapping, so you should expect **one GPU doing compute** unless you force sharding or switch to DDP.

### Test 9: Check PEFT Wrapping Effect

SFTTrainer internally wraps the model with PEFT. Let's check if PEFT preserves the attributes that Trainer uses for DataParallel detection.

**Hypothesis:** PEFT wrapping might lose `hf_device_map` or `is_parallelizable` attributes, causing Trainer to incorrectly apply DataParallel.

In [13]:
print("="*60)
print("TEST 9: PEFT Wrapping Effect on Model Attributes")
print("="*60)

from peft import LoraConfig, get_peft_model

# Reload checkpoint model
print("\nLoading checkpoint model...")
model_ckpt = Qwen2VLForConditionalGeneration.from_pretrained(
    STAGE1_CHECKPOINT,
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
    device_map="balanced",
    trust_remote_code=True,
)

# Check attributes BEFORE PEFT wrapping
print("\n" + "-"*40)
print("BEFORE PEFT Wrapping:")
print("-"*40)
analyze_trainer_attributes(model_ckpt, "Base Model")
would_dp, is_mp, eff_ngpu = would_trainer_use_dataparallel(model_ckpt, n_gpu=torch.cuda.device_count())
print(f"  Would use DataParallel: {would_dp}")

# Define LoRA config (same as r1-llm-only)
lora_config = LoraConfig(
    r=64,
    lora_alpha=128,
    lora_dropout=0.1,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    bias="none",
    task_type="CAUSAL_LM",
)

# Apply PEFT
print("\nApplying PEFT (LoRA)...")
peft_model = get_peft_model(model_ckpt, lora_config)

# Check attributes AFTER PEFT wrapping
print("\n" + "-"*40)
print("AFTER PEFT Wrapping:")
print("-"*40)
analyze_trainer_attributes(peft_model, "PEFT Model")
would_dp, is_mp, eff_ngpu = would_trainer_use_dataparallel(peft_model, n_gpu=torch.cuda.device_count())
print(f"  Would use DataParallel: {would_dp}")

# Check if hf_device_map is preserved
print("\n" + "-"*40)
print("Summary:")
print("-"*40)
def _count_gpu_devices(device_map):
    if not device_map:
        return 0
    return len([d for d in set(device_map.values()) if d not in ["cpu", "disk"]])

base_device_map = getattr(model_ckpt, "hf_device_map", None)
peft_device_map = getattr(peft_model, "hf_device_map", None)
base_gpu_count = _count_gpu_devices(base_device_map)
peft_gpu_count = _count_gpu_devices(peft_device_map)

print(f"Base model hf_device_map: {'Present' if base_device_map else 'MISSING'} (gpu_devices={base_gpu_count})")
print(f"PEFT model hf_device_map: {'Present' if peft_device_map else 'MISSING'} (gpu_devices={peft_gpu_count})")

if peft_device_map and peft_gpu_count > 1:
    print("\nPEFT preserves a multi-GPU hf_device_map -> Trainer should detect model-parallel and avoid DataParallel")
elif peft_device_map:
    print("\nPEFT preserves hf_device_map, but it is single-GPU -> Trainer will NOT treat this as model-parallel")
else:
    print("\n⚠️ PEFT loses hf_device_map -> This could cause DataParallel issues!")

# Clean up
del model_ckpt, peft_model
torch.cuda.empty_cache()
print("\nGPU memory cleaned up.")

TEST 9: PEFT Wrapping Effect on Model Attributes



Loading checkpoint model...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]


----------------------------------------
BEFORE PEFT Wrapping:
----------------------------------------

[Base Model]
  hf_device_map: {'': 0}
  is_loaded_in_4bit: True
  is_loaded_in_8bit: False
  is_parallelizable: False
  model_parallel: False
  Would use DataParallel: True

Applying PEFT (LoRA)...

----------------------------------------
AFTER PEFT Wrapping:
----------------------------------------

[PEFT Model]
  hf_device_map: {'': 0}
  is_loaded_in_4bit: True
  is_loaded_in_8bit: False
  is_parallelizable: False
  model_parallel: False
  Would use DataParallel: True

----------------------------------------
Summary:
----------------------------------------
Base model hf_device_map: Present (gpu_devices=1)
PEFT model hf_device_map: Present (gpu_devices=1)

PEFT preserves hf_device_map, but it is single-GPU -> Trainer will NOT treat this as model-parallel

GPU memory cleaned up.


## Conclusion: Root Cause and Fix

### Root Cause Analysis

The dtype mismatch error occurs because:

1. **Stage 2 is wrapped with DataParallel** while Stage 1 is not
   - Evidence: Stage 1 shows `0/136` steps, Stage 2 shows `0/68` steps (half!)
   - The traceback includes `torch/nn/parallel/data_parallel.py`

2. **DataParallel + AMP/autocast incompatibility**
   - DataParallel uses worker threads for each GPU replica
   - AMP/autocast context is **thread-local** - not inherited by worker threads
   - In k-bit training, some layers (norms) run in fp32 for stability
   - Without autocast in worker threads, fp32 activations flow to bf16 `lm_head` → crash

3. **Why Stage 2 got DataParallel**
   - The process sees 2 GPUs (`n_gpu=2`), and Trainer wraps DataParallel unless it detects model-parallelism.
   - Model-parallelism is only detected if `hf_device_map` spans >1 GPU (or `is_parallelizable`+`model_parallel` are set).
   - If `device_map` produces `hf_device_map={'': 0}` (single-GPU), Trainer will choose DataParallel by default.
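
The thread-local point in item 2 can be illustrated without PyTorch. The sketch below is a standard-library analogy only: `state.autocast_enabled` is a hypothetical stand-in for the autocast flag, not PyTorch's actual implementation. A flag set in the main thread is not visible in a worker thread unless the worker sets it itself.

```python
import threading

state = threading.local()  # stand-in for thread-local autocast state


def autocast_enabled():
    """Read the flag for the current thread; unset means disabled."""
    return getattr(state, "autocast_enabled", False)


def run_in_worker(fn):
    """Run fn in a new thread and return its result."""
    result = []
    worker = threading.Thread(target=lambda: result.append(fn()))
    worker.start()
    worker.join()
    return result[0]


state.autocast_enabled = True  # "enter autocast" in the main thread
main_sees = autocast_enabled()
worker_sees = run_in_worker(autocast_enabled)  # flag is not inherited
print(main_sees, worker_sees)  # True False
```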

### The Fix (Implemented in `scripts/train_recipe.py`)

```python
# Fix for two-stage training: When loading from a checkpoint with multi-GPU,
# force Trainer to behave as single-GPU to prevent DataParallel wrapping.
# Note: this only prevents Trainer's DataParallel; it does not force model sharding.
if checkpoint_path and torch.cuda.device_count() > 1:
    training_args_recipe._n_gpu = 1
```

**Why this works:**
- Setting `_n_gpu=1` tells Trainer "don't use DataParallel"
- If the model is sharded (multi-GPU `hf_device_map`), it can still use multiple GPUs via `device_map`
- If the model fits on one GPU (`hf_device_map={'': 0}`), Stage 2 will run single-GPU (other GPU mostly idle)
- This fix is conditional - only applies when loading from checkpoint
- Other recipes (r1, r2, r4) that load from HuggingFace are unaffected
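
One way to make the condition more explicit is to fold the `hf_device_map` check into it, so the override is applied only when Trainer would otherwise wrap DataParallel. This is a sketch, not the code in `scripts/train_recipe.py`: `needs_single_gpu_override` is a hypothetical helper that mirrors the notebook's `would_trainer_use_dataparallel` logic, and `"stage1-ckpt"` is a placeholder path.

```python
def needs_single_gpu_override(checkpoint_path, device_count, hf_device_map):
    """Return True when forcing _n_gpu=1 would change Trainer's behavior."""
    if not checkpoint_path or device_count <= 1:
        return False
    gpus = {d for d in (hf_device_map or {}).values() if d not in ("cpu", "disk")}
    # A map spanning several GPUs is already treated as model-parallel.
    return len(gpus) <= 1


single_gpu_map = needs_single_gpu_override("stage1-ckpt", 2, {"": 0})
sharded_map = needs_single_gpu_override("stage1-ckpt", 2, {"a": 0, "b": 1})
no_checkpoint = needs_single_gpu_override(None, 2, {"": 0})
print(single_gpu_map, sharded_map, no_checkpoint)  # True False False
```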

### Verified Working (2025-12-15)

```
Stage 1: 136 steps completed (vision encoder training)
Stage 2: 136 steps completed (LLM LoRA training from checkpoint)
```