In [None]:
from sagemaker.pytorch import PyTorch
import sagemaker
import boto3
import json

transformers#==4.37.2
accelerate#==0.25.0
datasets
peft#==0.6.2
bitsandbytes#==0.41.0
jiwer
wandb
triton#==2.0.0
einops#>=0.6.1

In [None]:
sagemaker_session = sagemaker.Session()
bucket = 'slip-ml'
role = 'arn:aws:iam::438465160412:role/Sagemaker'
project_name = 'llama-phoneme'

In [None]:
secret_name = "huggingface"
region_name = "us-east-1"
session = boto3.session.Session()
secretsmanager = session.client(service_name='secretsmanager', region_name=region_name)
get_secret_value_response = secretsmanager.get_secret_value(SecretId=secret_name)
secret = get_secret_value_response['SecretString']
api_key = json.loads(secret)["API_KEY"]

In [None]:
training_instances_gpus = {
    "ml.g5.2xlarge": 1,
    "ml.g5.12xlarge": 4,
    "ml.p4d.24xlarge": 8,
    "ml.p5.48xlarge": 8
}

In [None]:
instance_type = "ml.p4d.24xlarge"

In [None]:
sagemaker.image_uris.get_training_image_uri(framework='pytorch',
                            region=sagemaker_session.boto_region_name,
                            framework_version='2.2.0',  # pin the version; matches the estimator below
                            py_version='py310',
                            instance_type=instance_type)

In [None]:
image_uri = sagemaker.image_uris.retrieve(framework='pytorch',
                            region=sagemaker_session.boto_region_name,
                            version='2.2.0',  # pin the version; matches the estimator below
                            py_version='py310',
                            instance_type=instance_type,
                            image_scope='training'
                             )
print(image_uri)

In [None]:
estimator = PyTorch(
    entry_point="finetune_llama.py",
    source_dir="source",
    role=role,
    base_job_name=project_name,
    instance_count=1,  
    instance_type=instance_type,
    framework_version='2.2.0',
    py_version="py310",
    #distribution={"torch_distributed": {"enabled": True}},
    # sharded_data_parallel_degree is a model-parallel library parameter, not an
    # SMDDP option (assumption from the SageMaker docs), so it is dropped here
    distribution={
            'smdistributed': {
                'dataparallel': {
                    'enabled': True
                }
            }
        },
    hyperparameters={
        "batch-size": 2,
        "epochs": 7,
        "lr": 3e-4,
        "seed": 1,
        "project-name": project_name,
        "bucket": bucket,
    },
    environment={
        "HF_TOKEN": api_key,
    },
    sagemaker_session=sagemaker_session,
    volume_size=100,
    output_path=f's3://{bucket}/models/{project_name}',
    code_location=f's3://{bucket}/model-building/{project_name}'
)

In [None]:
estimator.fit({'training': f's3://{bucket}/data/transcriptions/train/',
               'test': f's3://{bucket}/data/transcriptions/test/'})

The error you're encountering in AWS SageMaker, `RuntimeError: [2]: params[0] in this process with sizes [0] appears not to match sizes of the same param in process 0`, indicates a mismatch in the model parameter sizes across different processes during distributed training with PyTorch's Fully Sharded Data Parallel (FSDP). This typically happens when the model parameters are not consistently initialized or synchronized across all processes (GPUs) in a distributed setup. Below, I'll analyze the issue based on your provided code (`finetune_llama.py`) and suggest fixes to resolve this error.

### Root Cause Analysis
The error occurs in the `torch.nn.parallel.DistributedDataParallel` (DDP) or FSDP setup, specifically during parameter verification across processes. The key points from the stack trace and code are:

1. **Parameter Size Mismatch**: The error reports that the first parameter (`params[0]`) in process 2 (or process 6 in another instance) has size `[0]`, while process 0 holds a different size. A size of `[0]` means the tensor is empty on that rank, which suggests the parameter was already sharded, or never fully materialized there, at the moment the shapes were compared. Distributed training requires identical parameter shapes on all ranks at that point.

2. **FSDP and LoRA**: Your code uses FSDP with LoRA (Low-Rank Adaptation) via the `peft` library. LoRA modifies the model by adding adapter layers, which can sometimes cause inconsistencies if not applied uniformly across all processes. The error likely stems from the model initialization or LoRA application phase.

3. **SageMaker and SMDDP Backend**: You're using the `smddp` backend (`smdistributed.dataparallel.torch.torch_smddp`) for distributed training in SageMaker. This backend integrates with PyTorch's distributed training but may have specific requirements for model synchronization.

4. **Parameter Synchronization**: The code attempts to synchronize parameters using `dist.broadcast` for `requires_grad` parameters after applying LoRA and moving the model to the GPU. However, the synchronization might not be covering all parameters or might be failing due to incorrect handling of the model state.

5. **Potential Issues**:
   - **LoRA Application Inconsistency**: The LoRA adapters might not be applied consistently across all ranks, leading to different parameter counts.
   - **FSDP Wrapping**: The model is wrapped with FSDP after LoRA is applied, but the parameter synchronization might not account for FSDP's sharding.
   - **Device Placement**: Moving the model to a specific GPU (`cuda:{device_id}`) before FSDP wrapping and synchronization could cause issues if the model state is not properly aligned.
   - **Dataset or Model Initialization**: If the model or dataset sharding is not deterministic across ranks, it could indirectly affect the training setup.
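
To make the check behind the error message concrete, here is a minimal pure-Python sketch of a shape comparison across ranks. It is not PyTorch's implementation: the function name `find_shape_mismatches` and the example shapes are hypothetical, and the message format only imitates the one in your traceback.

```python
from typing import Dict, List, Sequence, Tuple

Shape = Tuple[int, ...]


def find_shape_mismatches(
    shapes_by_rank: Dict[int, Sequence[Shape]],
    reference_rank: int = 0,
) -> List[str]:
    """Compare every rank's parameter shapes against the reference rank."""
    reference = shapes_by_rank[reference_rank]
    problems = []
    for rank, shapes in sorted(shapes_by_rank.items()):
        if rank == reference_rank:
            continue
        if len(shapes) != len(reference):
            problems.append(
                f"[{rank}]: has {len(shapes)} params, "
                f"rank {reference_rank} has {len(reference)}"
            )
            continue
        for index, (mine, theirs) in enumerate(zip(shapes, reference)):
            if mine != theirs:
                problems.append(
                    f"[{rank}]: params[{index}] has sizes {list(mine)}, "
                    f"rank {reference_rank} has {list(theirs)}"
                )
    return problems


# Hypothetical shapes: rank 2 holds an empty first parameter.
example = {
    0: [(2048, 2048), (16, 2048)],
    1: [(2048, 2048), (16, 2048)],
    2: [(0,), (16, 2048)],
}
for line in find_shape_mismatches(example):
    print(line)
```

Feeding the function the shapes you log on each rank shows which parameter diverges first, which is usually more informative than the single `params[0]` in the error.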

### Suggested Fixes
Here are step-by-step fixes to address the parameter mismatch error:

#### 1. **Ensure Consistent Model Initialization Across Ranks**
The model must be initialized identically on all ranks before applying FSDP or LoRA. The current code loads the model and applies LoRA before synchronization, which is correct, but we need to ensure that no rank-specific operations interfere.

**Fix**: Load the model on the CPU without a `device_map`, apply LoRA the same way on every rank, and make sure the parameters are synchronized across ranks before FSDP wrapping.

**Modified Code** (in the `train` function, replace the model loading and LoRA application section):

```python
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    token=os.environ.get("HF_TOKEN"),
    use_fast=True
)
tokenizer.pad_token = tokenizer.eos_token

# Load model configuration
config = AutoConfig.from_pretrained(model_name)
config.use_cache = False

# Determine compute dtype
compute_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
logger.info(f"Using {compute_dtype} as compute dtype")

# Load model on CPU to ensure consistency
logger.info(f"Loading base model: {model_name}")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    config=config,
    torch_dtype=compute_dtype,
    token=os.environ.get("HF_TOKEN"),
    device_map=None  # Keep on CPU initially
)

# Apply LoRA configuration
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM",
    inference_mode=False
)
logger.info("Applying LoRA adapters to model")
model = get_peft_model(model, lora_config)

# Move the model to this rank's GPU first. Collective backends such as NCCL
# operate on CUDA tensors (SMDDP is assumed to behave the same way), so the
# broadcast below would fail on CPU tensors.
device_id = local_rank
model = model.to(f"cuda:{device_id}")
logger.info(f"Moved model to device cuda:{device_id}")

# Synchronize model parameters across all ranks
if dist.is_initialized():
    logger.info(f"Synchronizing model parameters across processes (rank {rank})")
    for param in model.parameters():
        dist.broadcast(param.data, src=0)  # Broadcast all parameters, not just requires_grad
    dist.barrier()
    logger.info(f"Model parameters synchronized on rank {rank}")

# Apply FSDP wrapping
fsdp_kwargs = {
    "auto_wrap_policy": auto_wrap_policy,
    "mixed_precision": mixed_precision_policy,
    "sharding_strategy": sharding_strategy,
    "limit_all_gathers": True,
    "cpu_offload": cpu_offload,
    "device_id": torch.cuda.current_device(),
    "use_orig_params": True
}
if backward_prefetch is not None:
    fsdp_kwargs["backward_prefetch"] = backward_prefetch
if args.forward_prefetch:
    fsdp_kwargs["forward_prefetch"] = True

model = FSDP(model, **fsdp_kwargs)
logger.info(f"Created FSDP model with configuration: {fsdp_kwargs}")
```

**Changes Made**:
- Load the model on CPU first to avoid GPU-specific initialization issues.
- Move the model to the rank's GPU before synchronizing, because the collective backend works on CUDA tensors.
- Synchronize all parameters (not just `requires_grad` ones) using `dist.broadcast` to ensure consistency.
- Apply FSDP wrapping only after synchronization.

#### 2. **Disable `use_orig_params` in FSDP**
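With `use_orig_params=True`, FSDP keeps the original parameters visible as views into its flattened storage, and once the model is sharded some of those views can be empty on a given rank. That is consistent with the `sizes [0]` in the error if any later code compares parameter shapes across ranks. Setting `use_orig_params=False` exposes only the flattened parameters instead. Be aware that, as far as I know, this mode expects all parameters inside one wrapped unit to share the same `requires_grad` value, which a LoRA model with a frozen base does not satisfy unless the adapters are wrapped separately, so treat this change as an experiment rather than a guaranteed fix.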

**Fix**: Update the FSDP configuration to set `use_orig_params=False`.

**Modified Code** (in the FSDP setup section):

```python
fsdp_kwargs = {
    "auto_wrap_policy": auto_wrap_policy,
    "mixed_precision": mixed_precision_policy,
    "sharding_strategy": sharding_strategy,
    "limit_all_gathers": True,
    "cpu_offload": cpu_offload,
    "device_id": torch.cuda.current_device(),
    "use_orig_params": False  # Changed to False
}
```
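
Before flipping `use_orig_params`, it can help to see which wrapped units would mix frozen and trainable parameters. The sketch below is pure Python and makes two assumptions: each decoder layer becomes one FSDP unit (as with a `transformer_auto_wrap_policy` on `LlamaDecoderLayer`), and parameter names follow the `...layers.<index>...` pattern. The helper names `wrap_unit` and `mixed_requires_grad_units` are hypothetical.

```python
from collections import defaultdict
from typing import Dict, Iterable, List, Set, Tuple


def wrap_unit(param_name: str) -> str:
    """Return the decoder-layer prefix of a parameter name, or 'root'."""
    parts = param_name.split(".")
    if "layers" in parts:
        i = parts.index("layers")
        if i + 1 < len(parts):
            return ".".join(parts[: i + 2])
    return "root"


def mixed_requires_grad_units(named_flags: Iterable[Tuple[str, bool]]) -> List[str]:
    """List the units that contain both frozen and trainable parameters."""
    flags: Dict[str, Set[bool]] = defaultdict(set)
    for name, requires_grad in named_flags:
        flags[wrap_unit(name)].add(requires_grad)
    return sorted(unit for unit, seen in flags.items() if len(seen) > 1)


# Hypothetical parameter names in the style of a PEFT-wrapped Llama model.
example = [
    ("base_model.model.model.embed_tokens.weight", False),
    ("base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight", False),
    ("base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight", True),
    ("base_model.model.model.layers.1.mlp.up_proj.weight", False),
]
print(mixed_requires_grad_units(example))
```

With a real model you could pass `((name, p.requires_grad) for name, p in model.named_parameters())`. Any unit the function reports is a candidate for trouble when `use_orig_params=False`.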

#### 3. **Verify LoRA Configuration**
Ensure that the LoRA configuration is applied identically across all ranks. The current code applies LoRA before synchronization, which is fine, but we need to verify that the LoRA parameters are correctly initialized.

**Fix**: Add logging to check the number of parameters after applying LoRA to ensure consistency.

**Modified Code** (after applying LoRA):

```python
logger.info("Applying LoRA adapters to model")
model = get_peft_model(model, lora_config)
# Log on every rank: a rank-0-only log cannot reveal a mismatch between ranks
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
logger.info(f"Rank {rank}: total parameters: {total_params}, trainable parameters: {trainable_params}")
```

Comparing these log lines across ranks confirms whether the model has the same number of parameters everywhere before synchronization.
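
To know what number to expect in that log, you can estimate the trainable parameter count by hand. The sketch below assumes LoRA adds an `r x in_features` matrix and an `out_features x r` matrix per target module with no trainable biases (`bias="none"`). The layer count and projection sizes are assumptions about the base model's config, not values read from your script, so replace them with the ones from `config`.

```python
from typing import Dict, Tuple


def lora_param_count(rank: int, in_features: int, out_features: int) -> int:
    """Parameters added by one adapter: A is (rank x in), B is (out x rank)."""
    return rank * in_features + out_features * rank


def expected_trainable(
    rank: int,
    num_layers: int,
    targets: Dict[str, Tuple[int, int]],
) -> int:
    """Total LoRA parameters when every layer adapts the same target modules."""
    per_layer = sum(lora_param_count(rank, i, o) for i, o in targets.values())
    return per_layer * num_layers


# Assumed dimensions (in_features, out_features); check them against your config.
ASSUMED_LAYERS = 16
ASSUMED_TARGETS = {"q_proj": (2048, 2048), "v_proj": (2048, 512)}

print(expected_trainable(16, ASSUMED_LAYERS, ASSUMED_TARGETS))
```

If a rank logs a trainable count that differs from this estimate, LoRA was probably not applied the same way on that rank.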

#### 4. **Check SMDDP Backend Compatibility**
The `smddp` backend (SageMaker Data Parallel) might have specific requirements for model synchronization. The current code uses `dist.broadcast` for synchronization, which is correct, but we should ensure that the SMDDP backend is properly initialized.

**Fix**: Verify that the `smddp` backend is correctly set up and that all processes are properly initialized.

**Modified Code** (at the start of the script):

```python
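import torch.distributed as dist
# Importing this module registers the "smddp" backend with torch.distributed
import smdistributed.dataparallel.torch.torch_smddp  # noqa: F401
backend = "smddp"
dist.init_process_group(backend=backend)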
logger.info(f"Initialized distributed training with backend: {backend}, rank: {dist.get_rank()}, world_size: {dist.get_world_size()}")

# Verify SMDDP initialization
if not dist.is_initialized():
    raise RuntimeError("Distributed process group not initialized correctly")
```

#### 5. **Ensure Dataset Consistency**
The dataset sharding in the code looks correct, but ensure that the dataset is loaded and processed identically across all ranks to avoid indirect effects on training.

**Fix**: Add logging to verify dataset sizes after sharding.

**Modified Code** (in the dataset loading section):

```python
if dist.is_initialized():
    dist.barrier()  # Ensure all ranks have loaded the dataset
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    logger.info(f"Rank {rank} has {len(train_dataset)} training samples and {len(eval_dataset)} eval samples")
    # all_gather_object fills a pre-sized list with one entry per rank
    train_sizes = [None] * world_size
    eval_sizes = [None] * world_size
    dist.all_gather_object(train_sizes, len(train_dataset))
    dist.all_gather_object(eval_sizes, len(eval_dataset))
    if rank == 0:
        logger.info(f"Training dataset sizes across ranks: {train_sizes}")
        logger.info(f"Evaluation dataset sizes across ranks: {eval_sizes}")
        if len(set(train_sizes)) > 1 or len(set(eval_sizes)) > 1:
            logger.warning("Inconsistent dataset sizes across ranks detected!")
```

This does not change the sharding; it only reports whether ranks ended up with different dataset sizes, which can indirectly affect training stability.
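
A size difference of one sample between ranks is expected whenever the dataset length is not divisible by the world size. What matters more is whether every rank runs the same number of steps, since collective operations need all ranks to participate. Here is a small sketch of that arithmetic, assuming the shards are as even as possible (the first `n % world_size` ranks get one extra sample); the helper names are hypothetical.

```python
import math
from typing import List


def shard_sizes(num_samples: int, world_size: int) -> List[int]:
    """Per-rank sample counts when samples are split as evenly as possible."""
    base, extra = divmod(num_samples, world_size)
    return [base + (1 if rank < extra else 0) for rank in range(world_size)]


def steps_per_rank(
    num_samples: int,
    world_size: int,
    batch_size: int,
    drop_last: bool = False,
) -> List[int]:
    """Number of batches each rank would iterate over in one epoch."""
    steps = []
    for size in shard_sizes(num_samples, world_size):
        if drop_last:
            steps.append(size // batch_size)
        else:
            steps.append(math.ceil(size / batch_size))
    return steps


# Hypothetical numbers: 1010 samples, 8 GPUs, batch size 2.
print(shard_sizes(1010, 8))
print(steps_per_rank(1010, 8, batch_size=2))
print(steps_per_rank(1010, 8, batch_size=2, drop_last=True))
```

In this example the step counts differ between ranks unless the last partial batch is dropped, so dropping it (or padding the shards) is one way to keep the ranks in lockstep.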

#### 6. **Clear CUDA Cache Before Training**
Each SageMaker training job starts in a fresh container, so leftover GPU memory from a previous run is an unlikely cause of a shape mismatch. Clearing the CUDA cache at the start is harmless, though, and rules it out.

**Fix**: Add a call to `clear_device_cache` before model initialization.

**Modified Code** (at the start of the `train` function):

```python
def train(args, device):
    clear_device_cache()  # Clear CUDA cache at the start
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    rank = int(os.environ.get("RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    logger.info(f"Distributed training initialized. Rank: {rank}, World size: {world_size}, Backend: {backend}")
```
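
Since the function reads `RANK`, `LOCAL_RANK`, and `WORLD_SIZE` from the environment, a quick sanity check against the instance configuration can catch launcher problems early. This is a sketch under assumptions: the GPU counts are the ones from the `training_instances_gpus` dictionary in your notebook, and `check_rank_info` is a hypothetical helper, not part of any library.

```python
import os
from typing import Dict, Mapping

# GPU counts copied from the notebook's training_instances_gpus dictionary.
GPUS_PER_INSTANCE = {
    "ml.g5.2xlarge": 1,
    "ml.g5.12xlarge": 4,
    "ml.p4d.24xlarge": 8,
    "ml.p5.48xlarge": 8,
}


def expected_world_size(instance_type: str, instance_count: int) -> int:
    """One process per GPU across all instances."""
    return GPUS_PER_INSTANCE[instance_type] * instance_count


def read_rank_info(env: Mapping[str, str]) -> Dict[str, int]:
    """Read the rank variables with the same defaults the script uses."""
    return {
        "rank": int(env.get("RANK", 0)),
        "local_rank": int(env.get("LOCAL_RANK", 0)),
        "world_size": int(env.get("WORLD_SIZE", 1)),
    }


def check_rank_info(env: Mapping[str, str], instance_type: str, instance_count: int) -> Dict[str, int]:
    """Raise if the environment does not match the instance configuration."""
    info = read_rank_info(env)
    expected = expected_world_size(instance_type, instance_count)
    if info["world_size"] != expected:
        raise RuntimeError(f"WORLD_SIZE is {info['world_size']}, expected {expected}")
    if not 0 <= info["rank"] < expected:
        raise RuntimeError(f"RANK {info['rank']} is outside 0..{expected - 1}")
    if not 0 <= info["local_rank"] < GPUS_PER_INSTANCE[instance_type]:
        raise RuntimeError(f"LOCAL_RANK {info['local_rank']} has no matching GPU")
    return info


# Simulated environment for one process on a single ml.p4d.24xlarge.
fake_env = {"RANK": "2", "LOCAL_RANK": "2", "WORLD_SIZE": "8"}
print(check_rank_info(fake_env, "ml.p4d.24xlarge", 1))
```

Inside the training script you would pass `os.environ` instead of `fake_env`. If `WORLD_SIZE` silently falls back to 1 on every process, the launcher did not set the variables, and each process believes it is training alone.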

#### 7. **Disable Parameter Flattening (Optional)**
The `flatten_parameters` argument is set to `True`, which can sometimes cause issues with FSDP and LoRA. Try disabling it to see if it resolves the issue.

**Fix**: Set `flatten_parameters=False` in the argument parser.

**Modified Code** (in the argument parser):

```python
parser.add_argument("--flatten_parameters", action="store_true", default=False, 
                    help="Enable parameter flattening for FSDP (improves performance)")
```

Then, comment out or remove the `flatten_parameters` logic in the auto-wrap policy:

```python
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={LlamaDecoderLayer}
)
```

#### 8. **Update SageMaker Environment Variables**
SageMaker sets variables such as `SM_HOSTS`, `SM_CURRENT_HOST`, `SM_MODEL_DIR`, and `SM_NUM_GPUS` inside the training container automatically, so you do not define them yourself. Log them at startup to confirm that the job sees the hosts and GPUs you expect.

**Fix**: Verify your SageMaker estimator configuration. Variables that the script reads but SageMaker does not provide, such as `HF_TOKEN` or a bucket name like `SM_DEFAULT_BUCKET` (which, as far as I know, is not a built-in variable), must be passed explicitly through `environment`:

```python
from sagemaker.pytorch import PyTorch

estimator = PyTorch(
    entry_point="finetune_llama.py",
    role="SageMakerRole",
    instance_count=2,  # Number of instances
    instance_type="ml.p4d.24xlarge",  # Instance with multiple GPUs
    framework_version="2.2.0",  # same version as in your notebook
    py_version="py310",
    distribution={"smdistributed": {"dataparallel": {"enabled": True}}},
    hyperparameters={
        "batch-size": 25,
        "epochs": 7,
        "lr": 3e-4,
        "seed": 1,
        "project-name": "vallr-phoneme-llama",
        "sharding_strategy": "FULL_SHARD",
        "cpu_offload": True,
        "forward_prefetch": True,
        "backward_prefetch": "BACKWARD_PRE",
        "activation_checkpointing": True,
        "min_params_to_wrap": 10000000
    },
    environment={
        "SM_DEFAULT_BUCKET": "your-s3-bucket",
        "HF_TOKEN": "your-hf-token"
    }
)
estimator.fit({"training": "s3://your-s3-bucket/data/"})
```
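
In script mode, SageMaker hands each hyperparameter to the entry point as a `--name value` command-line argument, so a boolean such as `True` arrives as the string `"True"`. The sketch below shows one way to parse that with `argparse`; the argument names mirror the estimator example, while `str2bool` and `build_parser` are hypothetical helpers rather than code from your script.

```python
import argparse


def str2bool(value: str) -> bool:
    """Convert strings such as 'True' or 'false' into a real boolean."""
    lowered = value.strip().lower()
    if lowered in {"true", "1", "yes"}:
        return True
    if lowered in {"false", "0", "no"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch-size", type=int, default=2)
    parser.add_argument("--epochs", type=int, default=7)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--cpu_offload", type=str2bool, default=False)
    parser.add_argument(
        "--sharding_strategy",
        choices=["FULL_SHARD", "SHARD_GRAD_OP"],
        default="FULL_SHARD",
    )
    return parser


# Simulates the argument list produced from the estimator's hyperparameters.
args = build_parser().parse_args(
    ["--batch-size", "2", "--lr", "0.0003", "--cpu_offload", "True"]
)
print(args.batch_size, args.lr, args.cpu_offload, args.sharding_strategy)
```

A flag declared with `action="store_true"` would reject the extra `True` value, which is why the boolean is parsed through a `type` function here.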

### Final Updated Code Snippet
Here’s the consolidated updated `train` function incorporating the key fixes:

```python
def train(args, device):
    clear_device_cache()  # Clear CUDA cache
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    rank = int(os.environ.get("RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    logger.info(f"Distributed training initialized. Rank: {rank}, World size: {world_size}, Backend: {backend}")

    if not dist.is_initialized():
        raise RuntimeError("Distributed process group not initialized correctly")

    torch.manual_seed(args.seed + rank)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed + rank)
        torch.cuda.set_device(local_rank)

    # Load tokenizer
    model_name = "meta-llama/Llama-3.2-1B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        token=os.environ.get("HF_TOKEN"),
        use_fast=True
    )
    tokenizer.pad_token = tokenizer.eos_token

    # Load model configuration
    config = AutoConfig.from_pretrained(model_name)
    config.use_cache = False

    # Determine compute dtype
    compute_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    logger.info(f"Using {compute_dtype} as compute dtype")

    # Load model on CPU
    logger.info(f"Loading base model: {model_name}")
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        config=config,
        torch_dtype=compute_dtype,
        token=os.environ.get("HF_TOKEN"),
        device_map=None
    )

    # Apply LoRA
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=["q_proj", "v_proj"],
        lora_dropout=0.1,
        bias="none",
        task_type="CAUSAL_LM",
        inference_mode=False
    )
    logger.info("Applying LoRA adapters to model")
    model = get_peft_model(model, lora_config)
    # Log on every rank so the counts can be compared across ranks
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"Rank {rank}: total parameters: {total_params}, trainable parameters: {trainable_params}")

    # Move model to GPU first; collective backends such as NCCL operate on
    # CUDA tensors (SMDDP is assumed to behave the same way)
    device_id = local_rank
    model = model.to(f"cuda:{device_id}")
    logger.info(f"Moved model to device cuda:{device_id}")

    # Synchronize model parameters
    if dist.is_initialized():
        logger.info(f"Synchronizing model parameters across processes (rank {rank})")
        for param in model.parameters():
            dist.broadcast(param.data, src=0)
        dist.barrier()
        logger.info(f"Model parameters synchronized on rank {rank}")

    # FSDP configuration
    bf16_ready = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    mixed_precision_policy = MixedPrecision(
        param_dtype=torch.bfloat16 if bf16_ready else torch.float16,
        reduce_dtype=torch.bfloat16 if bf16_ready else torch.float16,
        buffer_dtype=torch.bfloat16 if bf16_ready else torch.float16
    )
    cpu_offload = CPUOffload(offload_params=args.cpu_offload)
    auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={LlamaDecoderLayer}
    )
    sharding_strategy = ShardingStrategy.FULL_SHARD if args.sharding_strategy == "FULL_SHARD" else ShardingStrategy.SHARD_GRAD_OP
    backward_prefetch = {
        "BACKWARD_PRE": BackwardPrefetch.BACKWARD_PRE,
        "BACKWARD_POST": BackwardPrefetch.BACKWARD_POST,
        "NONE": None
    }.get(args.backward_prefetch, None)

    fsdp_kwargs = {
        "auto_wrap_policy": auto_wrap_policy,
        "mixed_precision": mixed_precision_policy,
        "sharding_strategy": sharding_strategy,
        "limit_all_gathers": True,
        "cpu_offload": cpu_offload,
        "device_id": torch.cuda.current_device(),
        "use_orig_params": False
    }
    if backward_prefetch is not None:
        fsdp_kwargs["backward_prefetch"] = backward_prefetch
    if args.forward_prefetch:
        fsdp_kwargs["forward_prefetch"] = True

    model = FSDP(model, **fsdp_kwargs)
    logger.info(f"Created FSDP model with configuration: {fsdp_kwargs}")

    # Apply activation checkpointing
    if args.activation_checkpointing:
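        # offload_to_cpu is omitted: recent PyTorch releases no longer accept it
        # as a checkpoint_wrapper argument (assumption based on the 2.x API)
        non_reentrant_wrapper = functools.partial(
            checkpoint_wrapper,
            checkpoint_impl=CheckpointImpl.NO_REENTRANT
        )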
        check_fn = lambda submodule: isinstance(submodule, LlamaDecoderLayer)
        apply_activation_checkpointing(
            model,
            checkpoint_wrapper_fn=non_reentrant_wrapper,
            check_fn=check_fn
        )
        logger.info(f"Applied activation checkpointing to {model.__class__.__name__}")

    # Log memory usage
    if rank == 0:
        gpu_memory_allocated = torch.cuda.memory_allocated() / (1024 ** 3)
        gpu_memory_reserved = torch.cuda.memory_reserved() / (1024 ** 3)
        logger.info(f"GPU memory allocated: {gpu_memory_allocated:.2f} GB")
        logger.info(f"GPU memory reserved: {gpu_memory_reserved:.2f} GB")

    # Load and shard dataset
    dataset = PhonemeDataset(args.train_data_dir, output_file="phoneme_sentence_pairs.json")
    if len(dataset) == 0:
        raise ValueError("No valid data found in dataset")
    hf_dataset = HFDataset.from_list(dataset.data)
    tokenized_dataset = hf_dataset.map(
        tokenize_and_add_labels,
        batched=True,
        remove_columns=["phonemes", "text", "input_text"]
    )
    # Fix the seed so every rank computes the same train/test split
    train_test_split = tokenized_dataset.train_test_split(test_size=0.1, seed=args.seed)
    train_dataset = train_test_split["train"]
    eval_dataset = train_test_split["test"]

    if dist.is_initialized() and dist.get_world_size() > 1:
        train_dataset = train_dataset.shuffle(seed=args.seed)
        eval_dataset = eval_dataset.shuffle(seed=args.seed)
        train_dataset = train_dataset.shard(
            num_shards=dist.get_world_size(),
            index=dist.get_rank()
        )
        eval_dataset = eval_dataset.shard(
            num_shards=dist.get_world_size(),
            index=dist.get_rank()
        )
        logger.info(f"Rank {dist.get_rank()} has {len(train_dataset)} training samples and {len(eval_dataset)} eval samples")
        dist.barrier()
        # all_gather_object fills a pre-sized list with one entry per rank
        train_sizes = [None] * dist.get_world_size()
        eval_sizes = [None] * dist.get_world_size()
        dist.all_gather_object(train_sizes, len(train_dataset))
        dist.all_gather_object(eval_sizes, len(eval_dataset))
        if rank == 0:
            logger.info(f"Training dataset sizes across ranks: {train_sizes}")
            logger.info(f"Evaluation dataset sizes across ranks: {eval_sizes}")
            if len(set(train_sizes)) > 1 or len(set(eval_sizes)) > 1:
                logger.warning("Inconsistent dataset sizes across ranks detected!")

    # Rest of the training code remains unchanged...
```

### Additional Debugging Steps
If the above fixes don’t resolve the issue, try these debugging steps:

1. **Log Parameter Sizes**: Add logging to print the size of each parameter on every rank, before and after synchronization, to identify which parameter is causing the mismatch. Logging on rank 0 only cannot show what the other ranks hold.

   ```python
   for name, param in model.named_parameters():
       logger.info(f"Rank {rank}: parameter {name}: size {list(param.size())}, dtype {param.dtype}")
   dist.barrier()
   ```

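2. **Run on a Single GPU**: Setting `instance_count=1` alone is not enough, because an `ml.p4d.24xlarge` still exposes 8 GPUs. To rule out distributed training issues, temporarily switch to a single-GPU instance such as `ml.g5.2xlarge` and remove the `distribution` argument from the estimator (the script then needs to skip or guard its process-group initialization).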

3. **Check SageMaker Logs**: Review the full SageMaker training logs for any additional errors or warnings related to model loading, LoRA application, or FSDP initialization.

4. **Update Dependencies**: Ensure that your PyTorch, Transformers, and PEFT versions are compatible. For example, use:
   - PyTorch 2.0.1 or later
   - Transformers 4.36.0 or later
   - PEFT 0.7.0 or later

   Keep the estimator's `framework_version` consistent with the PyTorch version the script targets; your notebook uses 2.2.0:

   ```python
   estimator = PyTorch(
       framework_version="2.2.0",
       py_version="py310",
       # ...remaining arguments unchanged
   )
   ```
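
Logging every parameter on every rank (debugging step 1) produces a lot of output. A lighter option is to log one short fingerprint of all parameter names and shapes per rank and compare those first. The helper below is a hypothetical sketch using only the standard library.

```python
import hashlib
from typing import Iterable, Sequence, Tuple


def shape_fingerprint(named_shapes: Iterable[Tuple[str, Sequence[int]]]) -> str:
    """Hash parameter names and shapes into a short, comparable string."""
    digest = hashlib.sha256()
    for name, shape in named_shapes:
        digest.update(f"{name}:{tuple(shape)}\n".encode("utf-8"))
    return digest.hexdigest()[:16]


# Hypothetical shapes as two ranks might see them.
rank0 = [("q_proj.weight", (2048, 2048)), ("q_proj.lora_A", (16, 2048))]
rank2 = [("q_proj.weight", (0,)), ("q_proj.lora_A", (16, 2048))]

print("rank 0:", shape_fingerprint(rank0))
print("rank 2:", shape_fingerprint(rank2))
print("match:", shape_fingerprint(rank0) == shape_fingerprint(rank2))
```

With a real model you could call `shape_fingerprint((name, p.shape) for name, p in model.named_parameters())` on each rank and log the result together with the rank. Ranks whose fingerprints differ are the ones worth inspecting parameter by parameter.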

### Testing the Fixes
1. Update your `finetune_llama.py` script with the changes above.
2. Redeploy the SageMaker training job with the updated script and verify the environment variables.
3. Monitor the logs for parameter counts, dataset sizes, and memory usage to ensure consistency across ranks.
4. If the error persists, check the logged parameter sizes to pinpoint the mismatched parameter.

### Expected Outcome
These changes should resolve the parameter mismatch error by ensuring consistent model initialization, proper synchronization, and correct FSDP configuration. The model should train successfully across all ranks, and the logs will confirm consistent parameter counts and dataset sizes.

If you encounter further issues or need help with specific log outputs, please share the relevant logs, and I can provide more targeted assistance.

https://github.com/aws/deep-learning-containers/blob/c54da3b9246fe487d2f898f7f0042f25c84ddedf/available_images.md
https://github.com/aws/deep-learning-containers/tree/c54da3b9246fe487d2f898f7f0042f25c84ddedf
https://github.com/aws/deep-learning-containers/blob/c54da3b9246fe487d2f898f7f0042f25c84ddedf/pytorch/training/docker/2.5/py3/cu124/Dockerfile.gpu