# Training a Llama 3.1 8B Model on Long Context Length (PubMed) Dataset with Context Parallelism and FP8 enabled

## Introduction

In this notebook, we'll use the PubMed dataset to demonstrate how to train a Llama model on long context lengths with distributed training on Amazon SageMaker. We'll compare two approaches: one with context parallelism enabled, and one without it. This comparison highlights the importance of context parallelism when working with large language models and datasets with long sequences.

We also do a comparative run on p5.48xlarge instances with context parallelism enabled, once with FP8 enabled and once with it disabled. This demonstrates the incremental throughput benefit of enabling FP8-based training alongside context parallelism.

You can either launch this notebook from an Amazon SageMaker notebook instance which handles all credentials automatically,
or by running it locally and setting credentials manually.

The notebook is accompanied by the following files:
- `train.py`: The entry point script that'll be passed to the SageMaker PyTorch estimator later in this notebook when launching the training job.
- `arguments.py`: This file has functions for argument parsing (i.e. hyperparameters).
- `checkpoints.py`: This file has functions for saving and loading checkpoints.
- `data_utils`: This file has functions for handling S3 URLs.
- `data`: This directory has scripts for preparing and loading data.
- `fsdp_utils.py`: This file has util functions for fully sharded data parallelism.
- `learning_rates.py`: This file has functions for learning rate schedule.
- `logging_utils.py`: This file has functions to handle logging.
- `memory_tracker.py`: This file has functions to track memory usage.
- `requirements.txt`: This file installs the dependencies, including HuggingFace transformers.
- `train_lib.py`: This file has functions for running an end-to-end training of the GPT-NeoX or Llama-v2 model with SMP FSDP, settings for hybrid sharding applied, and implemented with code lines to save, load, and fine-tune the model.
- `train_utils.py`: This file has utility functions for training.

## Additional Resources
- To learn more about launching a multi-node distributed PyTorch training job, see [Launching a Distributed Training Job](https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html#launching-a-distributed-training-job).
- To learn more about using the SageMaker Python SDK with PyTorch, see [Using PyTorch with the SageMaker Python SDK](https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html).
- To learn more about launching a training job in Amazon SageMaker with your own training image, see [Use Your Own Training Algorithms](https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo.html).

## Prerequisites
You need to create an `S3` bucket to store the input data for training.
This bucket must be located in the same AWS Region in which you launch your training job. To learn how to create an `S3` bucket,
see [Create your first S3 bucket](https://docs.aws.amazon.com/AmazonS3/latest/userguide/creating-bucket.html) in the Amazon S3 documentation.

## Launching Environment

### Amazon SageMaker Notebook
You can run the notebook with an Amazon SageMaker notebook instance without manually setting your AWS credentials.

1. Create a new SageMaker notebook instance and open it.
2. Zip the contents of this folder & upload to the instance with the `Upload` button on the top-right.
3. Open a new terminal with `New -> Terminal`.
4. Within the terminal, enter the correct directory and unzip the file.
    1. `cd SageMaker && unzip <your-zip-name-here>.zip`

### Locally
You can run locally by launching a Jupyter notebook server with `jupyter notebook`.
This requires you to set your AWS credentials in the environment manually.
See [Configure the AWS CLI](https://docs.aws.amazon.com/cli/latest/userguide/cli-chap-configure.html) for more details.

## Amazon SageMaker Initialization
Run the following cells to import SageMaker modules and retrieve information about your current SageMaker work environment,
such as your AWS account ID, the AWS Region, and the ARN of your Amazon SageMaker execution role.
The first cell also upgrades the SageMaker SDK to the latest version.

**NOTE:** This step might require a kernel restart.

## Setup and Imports

First, we'll set up our environment and import the necessary libraries.

In [None]:
%pip install --upgrade "sagemaker>=2.233"
%pip install sagemaker-experiments
%pip install "datasets==2.14.5"
%pip install transformers

import os
import boto3
import sagemaker
from sagemaker import get_execution_role
from sagemaker.pytorch import PyTorch
import threading


## AWS Configuration

Next, we'll configure our AWS environment and SageMaker session.

**NOTE**

You will need a Hugging Face token with access to Llama-3.1-8B-Instruct for this notebook to work.

In [None]:
hf_token = ""

assert hf_token != "", "Please create a Huggingface token and include it above"

In [None]:
role = get_execution_role()
print(f"SageMaker Execution Role: {role}")

client = boto3.client("sts")
account = client.get_caller_identity()["Account"]
print(f"AWS account: `{account}`.")

# Explicitly set region to us-west-2
session = boto3.session.Session(region_name='us-west-2')
region = session.region_name
print(f"AWS region: `{region}`.")

sm_boto_client = boto3.client("sagemaker", region_name='us-west-2')  # Also set region here
sagemaker_session = sagemaker.session.Session(boto_session=session)
default_bucket = sagemaker_session.default_bucket()
print(f"\nDefault bucket for this session: `{default_bucket}`.")

Here, we'll load and prepare the PubMed dataset for our experiment.

## PubMed Dataset Attribution

The PubMed dataset is "Courtesy of the U.S. National Library of Medicine". This does not indicate that the NLM endorses this product or any product AWS builds.

## PubMed Scientific Papers Dataset Overview
The PubMed Scientific Papers dataset is a collection of scientific articles from the biomedical domain, specifically designed for document summarization tasks. Unlike the PubMed abstracts dataset, this contains full scientific papers.

## Key Dataset Characteristics
- **Size**: ~133,215 articles
- **Average Length**: ~6,000 tokens
- **Maximum Length**: Can reach 10,000+ words
- **Content Type**: Full scientific papers with abstract, introduction, methods, results, and discussion sections

## Why It's Ideal for Context Parallelism Demo

### 1. Document Length
```python
# Example token counts
average_tokens_per_paper = 6000
max_tokens = 16384  # Model's context window

# Without CP (single GPU):
tokens_per_gpu = 16384

# With CP (8 GPUs):
tokens_per_gpu = 16384 // 8  # 2048 tokens per GPU
```

### 2. Memory Benefits
- **Without CP**: Each GPU handles the full 16K-token sequence
- **With CP (degree=8)**:
  - Each sequence is split across 8 GPUs
  - Each GPU processes 2K tokens
  - Roughly 8x less activation memory per GPU (model weights and optimizer state are sharded separately by FSDP)
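
To make the split concrete, here is a minimal sketch of dividing one sequence into contiguous shards, one per context-parallel rank. The function name `split_for_context_parallel` is hypothetical, and the contiguous layout is an assumption for illustration; the SMP library may partition tokens differently (for example, to balance causal-attention work across ranks).

```python
def split_for_context_parallel(token_ids, cp_degree):
    """Split one sequence into cp_degree contiguous, equal-sized shards."""
    if len(token_ids) % cp_degree != 0:
        raise ValueError("sequence length must be divisible by cp_degree")
    shard_len = len(token_ids) // cp_degree
    # Rank r receives tokens [r * shard_len, (r + 1) * shard_len)
    return [token_ids[r * shard_len:(r + 1) * shard_len] for r in range(cp_degree)]


sequence = list(range(16384))  # stand-in for 16,384 token IDs
shards = split_for_context_parallel(sequence, cp_degree=8)
print(len(shards), len(shards[0]))  # 8 shards of 2048 tokens each
```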

### 3. Natural Structure
- Papers have clear sections
- Logical sequence breaks
- Each GPU can process semantically coherent chunks
- Scientific content benefits from maintaining long-range context

### 4. Training Characteristics
- Specialized vocabulary
- Complex technical content
- Long-range dependencies
- High information density requiring context preservation

### 5. Practical Applications
- Scientific document understanding
- Long-form content processing
- Real-world sequence lengths
- Demonstrates CP benefits on actual research content

## Memory Requirements Example
```python
# Assuming hidden_size = 4096
hidden_size = 4096
seq_len = 16384

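# Memory per token for a single hidden-state tensor (simplified)
bytes_per_token = 2 * hidden_size  # 2 bytes per BF16 value -> 8192 bytes

# Without CP: one GPU holds the activations for the whole sequence
memory_per_gpu_no_cp = seq_len * bytes_per_token
print(f"Per activation tensor (No CP): {memory_per_gpu_no_cp / 1e6:.2f} MB")  # 134.22 MB

# With CP (degree=8): each GPU holds 1/8 of the sequence
memory_per_gpu_cp = (seq_len // 8) * bytes_per_token
print(f"Per activation tensor (CP=8): {memory_per_gpu_cp / 1e6:.2f} MB")  # 16.78 MB

# A training step keeps many such tensors per layer across 32 layers,
# so the total activation footprint is far larger than one tensor.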
```

This makes the dataset ideal for demonstrating context parallelism's benefits in handling real scientific documents while showing tangible memory savings and processing efficiency.
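
The per-token estimate above covers one hidden-state tensor. Attention can grow much faster, because a fully materialized score matrix scales with the square of the sequence length. The sketch below is an upper-bound estimate under stated assumptions: batch size 1, BF16 scores, 32 heads, and an attention implementation that stores the whole score matrix. Fused or blockwise attention kernels avoid storing it, so treat these numbers as illustrative arithmetic rather than measurements from this notebook's jobs.

```python
def attention_scores_bytes(q_len, kv_len, num_heads, bytes_per_value=2):
    """Bytes for one layer's attention-score matrix if it is fully materialized."""
    return q_len * kv_len * num_heads * bytes_per_value


GIB = 1024 ** 3
seq_len, num_heads, cp_degree = 16384, 32, 8

# No CP: every query attends to every key on a single GPU
no_cp = attention_scores_bytes(seq_len, seq_len, num_heads)
# CP=8: each rank owns 1/8 of the queries but still attends over all keys
with_cp = attention_scores_bytes(seq_len // cp_degree, seq_len, num_heads)

print(f"No CP: {no_cp / GIB:.1f} GiB per layer")
print(f"CP=8:  {with_cp / GIB:.1f} GiB per layer")
```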

In [None]:
import datasets
from datasets import load_dataset, DatasetDict

# Load the PubMed dataset
pubmed_dataset = load_dataset(
    "scientific_papers",
    "pubmed",
    cache_dir="/home/ec2-user/SageMaker/datasets",
    download_mode="force_redownload"
)

# Create a smaller subset of the dataset for our experiment
train_test = pubmed_dataset['train'].shuffle(seed=42).select(range(1000)).train_test_split(
    test_size=0.2,
    seed=42
)

dataset = DatasetDict({
    'train': train_test['train'],
    'validation': train_test['test']
})

print(dataset)

## Tokenization

Now we'll tokenize our dataset using the Llama tokenizer.

In [None]:
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"


from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    use_fast=True,
    token=hf_token,  # `use_auth_token` is deprecated in recent transformers releases
    cache_dir = "tmp"
)

tokenizer.pad_token = tokenizer.eos_token

def tokenize_function(examples):
    return tokenizer(examples['article'])

tokenized_datasets = dataset.map(
    tokenize_function,
    batched=True,
    num_proc=1,
    remove_columns=dataset["train"].column_names,
    desc="Running tokenizer on dataset",
)

block_size = 16384

def group_texts(examples):
    # Concatenate all tokenized documents, then cut them into block_size chunks
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # Drop the remainder so every chunk has exactly block_size tokens
    total_length = (total_length // block_size) * block_size
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result

lm_datasets = tokenized_datasets.map(
    group_texts,
    batched=True,
    desc=f"Grouping texts in chunks of {block_size}",
)
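
To see what the grouping step does, here is a small self-contained sketch that applies the same concatenate-and-chunk logic to toy data with a block size of 4. The name `group_texts_demo` and the toy token IDs are illustrative only. Note that blocks can span document boundaries, and the trailing tokens that do not fill a block are dropped.

```python
def group_texts_demo(examples, block_size):
    """Concatenate tokenized documents and cut them into fixed-size blocks."""
    concatenated = {k: sum(examples[k], []) for k in examples}
    total_length = len(concatenated["input_ids"])
    total_length = (total_length // block_size) * block_size  # drop the remainder
    result = {
        k: [t[i:i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated.items()
    }
    # For causal language modeling, labels are a copy of the inputs
    result["labels"] = [block.copy() for block in result["input_ids"]]
    return result


toy = {"input_ids": [[1, 2, 3], [4, 5, 6, 7], [8, 9, 10]]}
grouped = group_texts_demo(toy, block_size=4)
print(grouped["input_ids"])  # [[1, 2, 3, 4], [5, 6, 7, 8]]
```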

## Preparing Data Channels

We'll now prepare the data channels for our SageMaker training job.

In [None]:
import boto3
s3_client = boto3.client('s3')

if lm_datasets["train"] is not None:
   train_dataset = lm_datasets["train"]
   train_dataset.to_json("./training.json")
   training_dataset_location = f"s3://{default_bucket}/dataset/train/"
   # Extract bucket and key from S3 URI
   bucket = default_bucket
   key = "dataset/train/training.json"
   # Upload file using boto3
   s3_client.upload_file("./training.json", bucket, key)
   # Remove local file
   import os
   os.remove("./training.json")

if lm_datasets["validation"] is not None:
   eval_dataset = lm_datasets["validation"]
   eval_dataset.to_json("./validation.json")
   validation_dataset_location = f"s3://{default_bucket}/dataset/validation/"
   # Extract bucket and key
   bucket = default_bucket  
   key = "dataset/validation/validation.json"
   # Upload file using boto3
   s3_client.upload_file("./validation.json", bucket, key)
   # Remove local file
   os.remove("./validation.json")

%store training_dataset_location
%store validation_dataset_location

s3_train_bucket = training_dataset_location
s3_test_bucket = validation_dataset_location
s3_output_bucket = f"s3://sagemaker-{region}-{account}/smp-fsdp-tp/outputdir/"

data_channels = {}

if s3_train_bucket is not None:
   train = sagemaker.inputs.TrainingInput(
       s3_train_bucket, distribution="FullyReplicated", s3_data_type="S3Prefix"
   )
   data_channels["train"] = train

if s3_test_bucket is not None:
   test = sagemaker.inputs.TrainingInput(
       s3_test_bucket, distribution="FullyReplicated", s3_data_type="S3Prefix"
   )
   data_channels["test"] = test

print(data_channels)
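
The upload cell above hard-codes the bucket and key. If you prefer to derive them from a full S3 URI, a small helper like the following could be used. This is a sketch using only the Python standard library; `split_s3_uri` and the example bucket name are hypothetical and not part of the accompanying scripts.

```python
from urllib.parse import urlparse


def split_s3_uri(uri):
    """Return (bucket, key) for a URI of the form s3://bucket/key."""
    parsed = urlparse(uri)
    if parsed.scheme != "s3" or not parsed.netloc:
        raise ValueError(f"not an S3 URI: {uri!r}")
    return parsed.netloc, parsed.path.lstrip("/")


bucket, key = split_s3_uri("s3://example-bucket/dataset/train/training.json")
print(bucket, key)
```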

## Launch Job without blocking the notebook

In [None]:
def launch_job():
    smp_estimator.fit(inputs=data_channels)
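
`launch_job` looks up the global `smp_estimator` when the thread actually runs, and this notebook reassigns `smp_estimator` for every job. One way to avoid any ambiguity is to pass the estimator to the thread explicitly. The sketch below shows that pattern; `launch_in_background` and `FakeEstimator` are hypothetical names, and the fake estimator exists only so the example runs without SageMaker.

```python
import threading


def launch_in_background(estimator, inputs):
    """Start estimator.fit(inputs=...) on a daemon thread and return the thread."""
    thread = threading.Thread(
        target=estimator.fit, kwargs={"inputs": inputs}, daemon=True
    )
    thread.start()
    return thread


class FakeEstimator:
    """Hypothetical stand-in that records calls instead of starting a job."""

    def __init__(self):
        self.calls = []

    def fit(self, inputs=None):
        self.calls.append(inputs)


fake = FakeEstimator()
worker = launch_in_background(fake, {"train": "s3://example-bucket/dataset/train/"})
worker.join()
print(fake.calls)
```

With a real estimator, the call would be `launch_in_background(smp_estimator, data_channels)`.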

## Setting Up Training Parameters

Here we define the hyperparameters and configuration for our training jobs.

In [None]:
import copy

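# Parallelism settings
context_parallel_degree = 8
tensor_parallel_degree = 2  # Only used in job names below; not passed to the SMP configuration
hybrid_shard_degree = 8
save_steps = 10
max_steps = 15
offload_activations = True  # Only used in job names below; each job sets "offload_activations" itself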

hyperparameters = {
    # Memory and optimization settings
    "activation_checkpointing": 1,
    "auto_wrap_policy": "transformer_auto_wrap_policy",
    "backward_fetch_policy": "backward_pre",
    "clean_cache": 1,
    "delayed_param": 1,
    "enable_memory_profiling": 1,
    
    # Training parameters
    "beta1": 0.9,
    "beta2": 0.95,
    "bf16": 1,
    "epochs": 100,
    "lr": 0.0001,
    "lr_decay_iters": 47683,
    "lr_decay_style": "cosine",
    "min_lr": 1e-05,
    "warmup": 0.0032,
    "weight_decay": 0.2,
    
    # Training settings
    "train_batch_size": 1,
    "val_batch_size": 1,
    "do_train": True,
    "do_eval": False,
    "validation_batches": -1,
    # "validation_freq" and "fast_validation" are set once, under "Performance optimization" below
    "max_steps": max_steps,
    "logging_freq": 1,
    
    # Checkpoint settings
    "checkpoint_dir": "/opt/ml/checkpoints",
    "checkpoint_freq": save_steps,
    "num_kept_checkpoints": 2,
    
    # Model configuration
    "model_type": "llama_v2",
    "vocab_size": 128256,  # Vocab size from the Llama 3.1 config file on Hugging Face
    "num_heads": 32,
    "num_layers": 32,
    "intermediate_size": 14336,
    "hidden_width": 4096,
    "num_key_value_heads": 8,
    "llama_intermediate_size": 14336,
    "hf_pretrained_model_name_or_dir": model_id,
    
    # Performance optimization
    "fast_validation": 0,
    "forward_prefetch": 1,
    "fp8": 0,
    "limit_all_gathers": 1,
    "plateau": 0.0,
    "seed": 12345,
    "sharding_strategy": "hybrid_shard",
    "use_smp_flash_attn": 0,
    "use_smp_implementation": 1,
    "validation_freq": save_steps,
    "zipped_data": 0
}

metric_definitions = [
    {"Name": "base_metric", "Regex": "<><><><><><>"}
]

original_hyperparameters = copy.deepcopy(hyperparameters)
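
The `lr`, `min_lr`, `warmup`, `lr_decay_iters`, and `lr_decay_style` hyperparameters describe a warmup-then-cosine learning rate schedule. The sketch below shows one common form of such a schedule. It assumes `warmup` is a fraction of `lr_decay_iters`; the actual behavior is defined in `learning_rates.py` and may differ, so treat `cosine_lr` as an illustration rather than the script's implementation.

```python
import math


def cosine_lr(step, lr=1e-4, min_lr=1e-5, warmup=0.0032, lr_decay_iters=47683):
    """Linear warmup followed by cosine decay from lr down to min_lr."""
    warmup_steps = int(warmup * lr_decay_iters)  # assumption: warmup is a fraction
    if step < warmup_steps:
        return lr * (step + 1) / warmup_steps
    progress = min(1.0, (step - warmup_steps) / (lr_decay_iters - warmup_steps))
    return min_lr + 0.5 * (lr - min_lr) * (1.0 + math.cos(math.pi * progress))


for step in (0, 14, 152, 47683):
    print(step, f"{cosine_lr(step):.3e}")
```

Under this assumption, warmup lasts 152 steps, so a run with `max_steps = 15` never leaves the warmup phase.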

## Training without Context Parallelism (P4D.24xlarge)

Now, we'll attempt to run the training job without context parallelism to demonstrate why it's necessary for this use case.

In [None]:
hyperparameters = copy.deepcopy(original_hyperparameters)

# Instance Settings

instance_type = "ml.p4d.24xlarge"
instance_count = 1
processes_per_host = 8
instance_type_str = instance_type.split(".")[1] + instance_type.split(".")[2][:3]
base_job_name = f'smp-8b-NON-CP-NOFP8-1000-{instance_type_str}-hs{hybrid_shard_degree}-tp{tensor_parallel_degree}-cp{context_parallel_degree}-ao{offload_activations}-bs{hyperparameters["train_batch_size"]:02d}'
print(f"Base job name: `{base_job_name}`.")
checkpoint_bucket = f"s3://sagemaker-{region}-{account}"
checkpoint_s3_uri = f"{checkpoint_bucket}/experiments/smp-fsdp-tp-llama_v2-checkpoints/{base_job_name}/"

# Parallelism settings
hybrid_shard_degree = 8  

hyperparameters.update({
    "use_smp_implementation": 0,  # Disable SMP/CP
    "max_context_width": 16384,   # Full sequence length
    "activation_checkpointing": 1, # Enable activation checkpointing
    "clean_cache": 0,
    "bf16": 1,                    # Use BF16
    "offload_activations": False,
    "use_smp_flash_attn": 0,       # Disable flash attention
    "train_batch_size": 1
})

smp_estimator = PyTorch(
    entry_point="train.py",
    hyperparameters=hyperparameters,
    source_dir=os.path.join(os.getcwd(), "./shared-scripts"),
    role=role,
    checkpoint_s3_uri=checkpoint_s3_uri,
    instance_type=instance_type,
    volume_size=400,
    instance_count=instance_count,
    sagemaker_session=sagemaker_session,
    distribution={
        "torch_distributed": {
            "enabled": True,
        },
        "smdistributed": {
            "modelparallel": {
                "enabled": True,  # Enable model parallelism but with minimal parameters
                "parameters": {
                    "delayed_parameter_initialization": True,
                    "hybrid_shard_degree": hybrid_shard_degree
                }
            }
        }
    },
    py_version="py311",
    framework_version="2.4.1",
    output_path=s3_output_bucket,
    max_run=86400,
    debugger_hook_config=False,
    base_job_name=base_job_name,
    metric_definitions=metric_definitions,
    keep_alive_period_in_seconds=1800,
    environment={
        "HF_TOKEN": hf_token,
        "NCCL_DEBUG": "INFO",
        "NCCL_MIN_NRINGS": "1",
        "NCCL_IB_TIMEOUT": "22"
    }, 
    wait=False
)

### Launch Non Context Parallelism Job

In [None]:
thread = threading.Thread(target=launch_job)
thread.daemon = True  # Allow the thread to be terminated when notebook closes
thread.start()
print(f"Job launched: {base_job_name}")

## Training with Context Parallelism (P4D.24xlarge)

In this section, we set up and run the training job with context parallelism enabled. This configuration should successfully handle the large model and long sequences.

In [None]:
hyperparameters = copy.deepcopy(original_hyperparameters)

# Instance Settings
instance_type = "ml.p4d.24xlarge"
instance_count = 1
processes_per_host = 8
instance_type_str = instance_type.split(".")[1] + instance_type.split(".")[2][:3]

base_job_name = f'smp-8b-CP-NOFP8-1000-{instance_type_str}-hs{hybrid_shard_degree}-tp{tensor_parallel_degree}-cp{context_parallel_degree}-ao{offload_activations}-bs{hyperparameters["train_batch_size"]:02d}'
print(f"Base job name: `{base_job_name}`.")

# Parallelism settings
context_parallel_degree = 8
hybrid_shard_degree = 8  

hyperparameters.update({
    "use_smp_implementation": 1,  # Enable SMP/CP
    "max_context_width": 16384,   # Full sequence length
    "activation_checkpointing": 1, # Enable activation checkpointing
    "clean_cache": 0,
    "bf16": 1,                    # Use BF16 
    "offload_activations": False,
    "use_smp_flash_attn": 0,       # Disable flash attention
    "train_batch_size": 1
})

checkpoint_bucket = f"s3://sagemaker-{region}-{account}"
checkpoint_s3_uri = f"{checkpoint_bucket}/experiments/smp-fsdp-tp-llama_v2-checkpoints/{base_job_name}/"

smp_estimator = PyTorch(
    entry_point="train.py",
    hyperparameters=hyperparameters,
    source_dir=os.path.join(os.getcwd(), "./shared-scripts"),
    role=role,
    checkpoint_s3_uri=checkpoint_s3_uri,
    instance_type=instance_type,
    volume_size=400,
    instance_count=instance_count,
    sagemaker_session=sagemaker_session,
    distribution={
        "torch_distributed": {
            "enabled": True,
        },
        "smdistributed": {
            "modelparallel": {
                "enabled": True,
                "parameters": {
                    "context_parallel_degree": context_parallel_degree,
                    "hybrid_shard_degree": hybrid_shard_degree,
                    "delayed_parameter_initialization": True,
                }
            }
        }
    },
    py_version="py311",
    framework_version="2.4.1",
    output_path=s3_output_bucket,
    max_run=86400,
    debugger_hook_config=False,
    base_job_name=base_job_name,
    metric_definitions=metric_definitions,
    keep_alive_period_in_seconds=1800,
    environment={
        "HF_TOKEN": hf_token,
        "NCCL_DEBUG": "INFO",
        "NCCL_MIN_NRINGS": "1",
        "NCCL_IB_TIMEOUT": "22"
    }, 
    wait = False
)

### Launch Context Parallelism Enabled Job

In [None]:
thread = threading.Thread(target=launch_job)
thread.daemon = True  # Allow the thread to be terminated when notebook closes
thread.start()
print(f"Job launched: {base_job_name}")

## Showcasing the value of FP8 enabled training using Context Parallelism

FP8 is an 8-bit floating-point data type supported by NVIDIA's H100 and H200 GPUs. Each value occupies 8 bits, half the size of its BF16 or FP16 counterparts, which reduces memory traffic and lowers the cost of operations such as matrix multiplication.

The next few sections show how to increase training speed using FP8-enabled SageMaker training jobs. We compare the throughput of two jobs, both with context parallelism enabled, one with FP8 enabled and one without. Both run on a p5.48xlarge instance.
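
Because FP8 has a narrow dynamic range, tensors are typically multiplied by a scaling factor before being cast to FP8. The sketch below illustrates simple per-tensor scaling based on the largest absolute value (amax). The maximum finite values, 448 for E4M3 and 57344 for E5M2, come from the FP8 format definitions; the function `fp8_scale` itself is hypothetical, and production libraries use more elaborate recipes (for example, tracking amax over several steps), so this is not how the training job necessarily computes its scales.

```python
# Largest finite magnitudes representable in the two common FP8 formats
FP8_MAX = {"e4m3": 448.0, "e5m2": 57344.0}


def fp8_scale(values, fmt="e4m3"):
    """Return a scale so that the largest |value| maps to the FP8 maximum."""
    amax = max(abs(v) for v in values)
    if amax == 0.0:
        return 1.0  # nothing to scale
    return FP8_MAX[fmt] / amax


activations = [0.02, -1.5, 3.0, -0.25]
scale = fp8_scale(activations)
scaled = [v * scale for v in activations]  # these values would then be cast to FP8
print(scale, max(abs(v) for v in scaled))
```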

## Training with Context Parallelism WITH FP8 for increased training throughput (P5.48xlarge)

Modern training of large language models (LLMs) combines two key optimizations:
- **Context Parallelism (CP)**: Distributes long sequences across GPUs
- **FP8 Training**: Uses 8-bit precision for computations and activations


In [None]:
import copy
hyperparameters = copy.deepcopy(original_hyperparameters)


hyperparameters.update({
    "use_smp_implementation": 1,  # Enable SMP/CP
    "train_batch_size": 4,  # Train Batch Size 4
    "max_context_width": 16384,   # Full sequence length
    "activation_checkpointing": 1,  # Enable activation checkpointing
    "clean_cache": 0,
    "offload_activations": False,
    "use_smp_flash_attn": 0,       # Disable flash attention
    "fp8": 1,  # Enable FP8 flag
    "distributed_backend": "nccl"  # Add this line to explicitly use NCCL

})

# Instance Settings
instance_type = "ml.p5.48xlarge"
instance_count = 1
processes_per_host = 8
instance_type_str = instance_type.split(".")[1] + instance_type.split(".")[2][:3]
base_job_name = f'smp-8b-CP-WITH-FP8-TBS-4-1000-{instance_type_str}-hs{hybrid_shard_degree}-tp{tensor_parallel_degree}-cp{context_parallel_degree}-ao{offload_activations}-bs{hyperparameters["train_batch_size"]:02d}'
print(f"Base job name: `{base_job_name}`.")

# Parallelism settings
context_parallel_degree = 8
hybrid_shard_degree = 8  

checkpoint_bucket = f"s3://sagemaker-{region}-{account}"
checkpoint_s3_uri = f"{checkpoint_bucket}/experiments/smp-fsdp-tp-llama_v2-checkpoints/{base_job_name}/"

smp_estimator = PyTorch(
    entry_point="train.py",
    hyperparameters=hyperparameters,
    source_dir=os.path.join(os.getcwd(), "./shared-scripts"),
    role=role,
    checkpoint_s3_uri=checkpoint_s3_uri,
    instance_type=instance_type,
    volume_size=400,
    instance_count=instance_count,
    sagemaker_session=sagemaker_session,
    distribution={
        "torch_distributed": {
            "enabled": True,
        },
        "smdistributed": {
            "modelparallel": {
                "enabled": True,
                "parameters": {
                    "context_parallel_degree": context_parallel_degree,
                    "hybrid_shard_degree": hybrid_shard_degree,
                    "delayed_parameter_initialization": True,
                }
            }
        }
    },
    py_version="py311",
    framework_version="2.4.1",
    output_path=s3_output_bucket,
    max_run=86400,
    debugger_hook_config=False,
    base_job_name=base_job_name,
    metric_definitions=metric_definitions,
    keep_alive_period_in_seconds=1800,
    environment={
        "HF_TOKEN": hf_token,
        "NCCL_DEBUG": "INFO",
        "NCCL_MIN_NRINGS": "1",
        "NCCL_IB_TIMEOUT": "22"
    }, 
    wait = False
)


### Launch Context Parallelism with FP8 Enabled Job (ON P5)

In [None]:
thread = threading.Thread(target=launch_job)
thread.daemon = True  # Allow the thread to be terminated when notebook closes
thread.start()
print(f"Job launched: {base_job_name}")

## Training with Context Parallelism WITHOUT FP8 (P5.48xlarge)

In [None]:
hyperparameters = copy.deepcopy(original_hyperparameters)


hyperparameters.update({
    "use_smp_implementation": 1,  # Enable SMP/CP
    "train_batch_size": 4,  # Train Batch Size 4
    "max_context_width": 16384,   # Full sequence length
    "activation_checkpointing": 1, # Enable activation checkpointing
    "clean_cache": 0,
    "bf16": 1,                    # Use BF16 
    "offload_activations": False,
    "use_smp_flash_attn": 0,       # Disable flash attention
    "distributed_backend": "nccl"  # Add this line to explicitly use NCCL

})

# Parallelism settings
context_parallel_degree = 8
hybrid_shard_degree = 8  

# Instance Settings
instance_type = "ml.p5.48xlarge"
instance_count = 1
processes_per_host = 8

instance_type_str = instance_type.split(".")[1] + instance_type.split(".")[2][:3]
base_job_name = f'smp-8b-CP-NOFP8-TBS-4-1000-{instance_type_str}-hs{hybrid_shard_degree}-tp{tensor_parallel_degree}-cp{context_parallel_degree}-ao{offload_activations}-bs{hyperparameters["train_batch_size"]:02d}'
print(f"Base job name: `{base_job_name}`.")

checkpoint_bucket = f"s3://sagemaker-{region}-{account}"
checkpoint_s3_uri = f"{checkpoint_bucket}/experiments/smp-fsdp-tp-llama_v2-checkpoints/{base_job_name}/"

smp_estimator = PyTorch(
    entry_point="train.py",
    hyperparameters=hyperparameters,
    source_dir=os.path.join(os.getcwd(), "./shared-scripts"),
    role=role,
    checkpoint_s3_uri=checkpoint_s3_uri,
    instance_type=instance_type,
    volume_size=400,
    instance_count=instance_count,
    sagemaker_session=sagemaker_session,
    distribution={
        "torch_distributed": {
            "enabled": True,
        },
        "smdistributed": {
            "modelparallel": {
                "enabled": True,
                "parameters": {
                    "context_parallel_degree": context_parallel_degree,
                    "hybrid_shard_degree": hybrid_shard_degree,
                    "delayed_parameter_initialization": True,
                }
            }
        }
    },
    py_version="py311",
    framework_version="2.4.1",
    output_path=s3_output_bucket,
    max_run=86400,
    debugger_hook_config=False,
    base_job_name=base_job_name,
    metric_definitions=metric_definitions,
    keep_alive_period_in_seconds=1800,
    environment={
        "HF_TOKEN": hf_token,
        "NCCL_DEBUG": "INFO",
        "NCCL_MIN_NRINGS": "1",
        "NCCL_IB_TIMEOUT": "22"
    }, 
    wait = False
)

### Launch Context Parallelism without FP8 Enabled Job (ON P5)

In [None]:
thread = threading.Thread(target=launch_job)
thread.daemon = True  # Allow the thread to be terminated when notebook closes
thread.start()
print(f"Job launched: {base_job_name}")

## Conclusion

In this notebook, we demonstrated the process of setting up and running training jobs for the PubMed dataset using the Llama model, both with and without context parallelism. 

Key observations:

1. With context parallelism enabled, the training job runs successfully. Context parallelism splits each long sequence across multiple GPUs, which reduces the activation memory required per GPU.

![With Context Parallelism](./with_cp.png)

2. Without context parallelism, the training job fails. This is likely due to memory constraints: each GPU must hold the activations for the full 16K-token sequences from the PubMed dataset, which do not fit in a single GPU's memory alongside the model.

![Without Context Parallelism](./without_cp.png)


This experiment highlights the importance of techniques like context parallelism when working with large language models and datasets with long sequences. It allows us to train on sequences that would otherwise be too long to fit in memory, enabling work with longer documents and larger datasets.

### Context Parallelism with FP8 Enabled vs. Not Enabled (on P5)

1. **With FP8 Enabled** :

![With FP8](./With_FP8.png)

2. **Without FP8 Enabled**:

![Without FP8](./WO_FP8_P5.png)

Looking at the TFLOPS/GPU metric, we can see that throughput is higher with FP8 enabled.
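
For intuition about what a TFLOPS/GPU figure represents, the sketch below uses the common approximation of about 6 floating-point operations per parameter per token for a combined forward and backward pass. This is an assumption for illustration: it ignores attention FLOPs (which are significant at a 16K context) and any recomputation from activation checkpointing, and the training script may compute its metric differently. The throughput value is hypothetical, not a measured result from these jobs.

```python
def approx_tflops_per_gpu(num_params, tokens_per_second, num_gpus):
    """Approximate model TFLOPS per GPU using ~6 FLOPs per parameter per token."""
    flops_per_token = 6 * num_params
    return flops_per_token * tokens_per_second / num_gpus / 1e12


# Hypothetical throughput, for illustration only (not a measured result)
estimate = approx_tflops_per_gpu(num_params=8e9, tokens_per_second=20000, num_gpus=8)
print(f"{estimate:.1f} TFLOPS/GPU")
```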
