# In [None]:
#!/usr/bin/env python
# coding: utf-8

import boto3
from datasets import Dataset, DatasetDict
import s3fs 
from sagemaker.pytorch import PyTorch
import sagemaker
import os
import transformers
import logging
from sagemaker.huggingface import HuggingFace
from transformers import (
    AutoTokenizer,
)
from huggingface_hub.hf_api import HfFolder
import functools
from transformers.testing_utils import CaptureLogger


# Save the Hugging Face token locally.
# The token is read from the HF_TOKEN environment variable instead of being
# hard-coded, so the secret never lands in the notebook or its exports.
_hf_token = os.environ.get("HF_TOKEN")
if _hf_token:
    HfFolder.save_token(_hf_token)
else:
    print("HF_TOKEN is not set; skipping HfFolder.save_token().")

# Configuration
PRETRAINED_MODEL = "meta-llama/Llama-3.1-8B-Instruct"
# Alternative: PRETRAINED_MODEL = "EleutherAI/gpt-neox-20b"

# Must be one of the model types handled by the model-size selection further
# down: "gpt_neox" or "llama_v3".
model_type = "llama_v3"
# Alternative: model_type = "gpt_neox"

# Context width for the Llama v3 model (the gpt_neox alternative used 2048).
# NOTE(review): not referenced in the visible script; the value actually sent
# to the training job comes from model_params["max_context_width"].
max_context_width = 8192

# NOTE(review): not referenced in the visible script (AutoTokenizer is imported
# but never called) — keep only if another cell uses it.
tokenizer_kwargs = {
    "cache_dir": "/home/ec2-user/SageMaker/tmp",
}

# S3 bucket plus the key prefixes under which the train/test datasets live.
bucket_name = "b-sagemaker"
train_s3_output_prefix, test_s3_output_prefix = (
    "fsdp-without-LlamaFactory/hf-llama31-default/datasets/train/",
    "fsdp-without-LlamaFactory/hf-llama31-default/datasets/test/",
)

# Echo both prefixes, one per line.
print(train_s3_output_prefix, test_s3_output_prefix, sep="\n")

# Register the S3 dataset prefixes as SageMaker input channels. Both use the
# "S3Prefix" data type with "FullyReplicated" distribution.
train = sagemaker.inputs.TrainingInput(
    f"s3://{bucket_name}/{train_s3_output_prefix}",
    s3_data_type="S3Prefix",
    distribution="FullyReplicated",
)
test = sagemaker.inputs.TrainingInput(
    f"s3://{bucket_name}/{test_s3_output_prefix}",
    s3_data_type="S3Prefix",
    distribution="FullyReplicated",
)
data_channels = {"train": train, "test": test}

# Show the channels between two separator rules for a quick visual check.
print(
    '..................................................................',
    data_channels,
    '..................................................................',
    sep="\n",
)

# SageMaker model-parallel (SMP) and training-loop knobs. They are read by the
# hyperparameters dict and by the estimator's `distribution` config.
fp8 = 0  # Enable FP8 mixed precision: 1 = True, 0 = False.
tensor_parallel_degree = 4  # An integer in [1, world_size].
# An integer in [0, world_size // tensor_parallel_degree]; default 0.
hybrid_shard_degree = 0
# Activation loading horizon: a positive integer; default 2.
activation_loading_horizon = 2
save_steps = 50  # Interval (in steps) for both checkpointing and validation.
max_steps = 100  # Maximum number of training steps.
offload_activations = True  # Enable activation offloading.

# Hyperparameters forwarded to the training entry point.
# The S3 locations are derived from bucket_name, so datasets, checkpoints and
# the saved model move together when the bucket changes.
# NOTE(review): lr_decay_iters (47683) is far larger than max_steps. If it is
# the cosine decay length, the LR will barely decay before training stops.
# Likewise, epochs=5 presumably has no effect once max_steps is hit — confirm.
_s3_run_root = f"s3://{bucket_name}/fsdp-without-LlamaFactory/hf-llama31-default"

hyperparameters = {
    "activation_checkpointing": 1,
    "auto_wrap_policy": "transformer_auto_wrap_policy",
    "backward_fetch_policy": "backward_pre",
    "beta1": 0.9,
    "beta2": 0.95,
    "bf16": 1,
    "checkpoint_dir": f"{_s3_run_root}/checkpoints/",
    "checkpoint_freq": save_steps,
    "num_kept_checkpoints": 2,
    "clean_cache": 0,
    "delayed_param": 1,
    "enable_memory_profiling": 0,
    "epochs": 5,
    "fast_validation": 0,
    "forward_prefetch": 1,
    "fp8": fp8,
    "limit_all_gathers": 1,
    "logging_freq": 1,
    "lr": 0.0001,
    "lr_decay_iters": 47683,
    "lr_decay_style": "cosine",
    "max_steps": max_steps,
    "min_lr": 1e-05,
    "model_type": model_type,
    "plateau": 0.0,
    "seed": 12345,
    "sharding_strategy": "hybrid_shard",
    "train_batch_size": 16,
    "use_smp_flash_attn": 1,
    "use_smp_implementation": 1,
    "val_batch_size": 4,
    "validation_freq": save_steps,
    "vocab_size": 128256,
    "warmup": 0.0032,
    "weight_decay": 0.2,
    "zipped_data": 0,
    "dataset_type": "hf",
    "distributed_backend": "nccl",
    "model_dir": f"{_s3_run_root}/saved_model/",
    "save_final_model": 1,
}

# Regex fragment with exactly one capture group that matches a non-negative
# number in plain decimal ("0.0001") or scientific notation ("1e-05").
# The previous "[0-9\\.]+" stopped at the "e", so "1e-05" was captured as "1".
_METRIC_NUMBER = r"([0-9.]+(?:[eE][-+]?[0-9]+)?)"

# CloudWatch metrics SageMaker scrapes from the training job's log lines.
metric_definitions = [
    {"Name": "Training Loss", "Regex": rf".*Training Loss: {_METRIC_NUMBER}.*"},
    {"Name": "Validation Loss", "Regex": rf".*Validation Loss: {_METRIC_NUMBER}.*"},
    {"Name": "Learning Rate", "Regex": rf".*Learning Rate: {_METRIC_NUMBER}.*"},
    {"Name": "Gradient Norm", "Regex": rf".*Gradient Norm: {_METRIC_NUMBER}.*"},
    {"Name": "Perplexity", "Regex": rf".*Perplexity: {_METRIC_NUMBER}.*"},
    {"Name": "Accuracy", "Regex": rf".*Accuracy: {_METRIC_NUMBER}.*"},
    {"Name": "GPU Utilization", "Regex": rf".*GPU Utilization: {_METRIC_NUMBER}%.*"},
    {"Name": "Memory Usage", "Regex": rf".*Memory Usage: {_METRIC_NUMBER}GB.*"},
]
# Tell the training script which Hugging Face model to start from.
hyperparameters.update(hf_pretrained_model_name_or_dir=PRETRAINED_MODEL)
# Optional RoPE override, currently disabled:
# hyperparameters["rope_theta"] = 8.0

# Echo the model id between two banner lines.
print(
    "*********************1111111111111111111*****************************",
    PRETRAINED_MODEL,
    "********************1111111111111111111111****************************************",
    sep="\n",
)


# Select your model size. Sizes available per model_type:
#   gpt_neox: 7b, 20b, 65b
#   llama_v3: 7b, 65b
model_config = "7b"

# Architecture settings sent to the training job, keyed by model_type and
# then by model_config.
# NOTE(review): the llama_v3 "7b" entry (hidden 4096, 32 layers, intermediate
# 14336) appears to be sized for PRETRAINED_MODEL (Llama-3.1-8B) — confirm.
_MODEL_PARAMS_BY_TYPE = {
    "gpt_neox": {
        "7b": {
            "max_context_width": 1024,
            "hidden_width": 4096,
            "num_layers": 32,
            "num_heads": 32,
        },
        "20b": {
            "max_context_width": 2048,
            "hidden_width": 6144,
            "num_layers": 44,
            "num_heads": 64,
        },
        "65b": {
            "max_context_width": 1024,
            "hidden_width": 8192,
            "num_layers": 80,
            "num_heads": 64,
        },
    },
    "llama_v3": {
        "7b": {
            "max_context_width": 8192,
            "hidden_width": 4096,
            "num_layers": 32,
            "num_heads": 32,
            "llama_intermediate_size": 14336,
        },
        "65b": {
            "max_context_width": 4096,
            "hidden_width": 8192,
            "num_layers": 80,
            "num_heads": 64,
            "llama_intermediate_size": 22016,
        },
    },
}

# Fail fast with a clear message. Without this check an unsupported
# model_type would leave model_params unbound and crash later with NameError.
if model_type not in _MODEL_PARAMS_BY_TYPE:
    raise RuntimeError(f"Unknown model type: {model_type}")
if model_config not in _MODEL_PARAMS_BY_TYPE[model_type]:
    raise RuntimeError("Unknown model config")

# Copy so later edits to model_params cannot change the lookup table.
model_params = dict(_MODEL_PARAMS_BY_TYPE[model_type][model_config])

hyperparameters.update(model_params)

print(hyperparameters)


instance_type = "ml.p4de.24xlarge"

# Cluster sizing: you need >= 1 p4d for the 7b model and >= 8 p4d for 65b.
instance_count = 1

# One worker process per GPU on the instance type above.
processes_per_host = 8

# get_execution_role() returns the IAM role ARN behind the current AWS
# identity and raises if that identity is not a role. Outside a SageMaker
# notebook/Studio environment, supply a role ARN explicitly instead.
role = sagemaker.get_execution_role()
print(f"SageMaker Execution Role: {role}")

client = boto3.client("sts")
_caller_identity = client.get_caller_identity()
account = _caller_identity["Account"]
print(f"AWS account: {account}")

session = boto3.session.Session()
region = session.region_name
print(f"AWS region: {region}")

# Low-level SageMaker client plus the SDK session bound to `session`.
sm_boto_client = boto3.client("sagemaker")
sagemaker_session = sagemaker.session.Session(boto_session=session)

# To use the session's default bucket instead:
# default_bucket = sagemaker_session.default_bucket()
# print("Default bucket for this session: ", default_bucket)

# Short machine tag, e.g. "p4de24x" for "ml.p4de.24xlarge". It is only used by
# the commented-out descriptive job name below.
_instance_parts = instance_type.split(".")
machine_str = f"{_instance_parts[1]}{_instance_parts[2][:3]}"
# base_job_name = f'smp-{model_config}-{machine_str}-hs{hybrid_shard_degree}-ao{offload_activations}-bs{hyperparameters["train_batch_size"]}'
base_job_name = "baykar-fsdp-without-llamafactory-model-training"

print(base_job_name)

# Earlier experiment's checkpoint location (unused; the estimator's
# checkpoint_s3_uri argument is commented out):
# checkpoint_s3_uri = "s3://tubitak-exp-v1/Exp1-TubitakV0-qa1m-tr-instruction-formatted/checkpoints/"

# S3 prefix passed to the estimator as output_path.
# NOTE(review): this uses the "baykar-sagemaker" bucket, while bucket_name and
# the checkpoint_dir/model_dir hyperparameters use "b-sagemaker" — confirm
# which bucket is intended.
s3_output_bucket = "s3://baykar-sagemaker/fsdp-without-LlamaFactory/hf-llama31-default/saved_model/"

# # Pytorch Estimator

# Extra keyword arguments splatted into the PyTorch estimator.
kwargs = {}

# Job tags must be a list of {"Key": ..., "Value": ...} dicts. Recent SageMaker
# SDK versions read a bare dict as a {tag_key: tag_value} mapping, which would
# create tags literally named "Key" and "Value".
tags = [{"Key": "CostId", "Value": "Baykar_NBS"}]

# Pipeline-parallel SMP options and MPI launcher options.
# NOTE(review): neither dict is passed to the estimator below. Its
# `distribution` uses torch_distributed plus tensor-parallel/hybrid-shard
# parameters instead, so these only serve the commented-out alternative there.
smp_options = dict(
    enabled=True,
    parameters=dict(  # Required
        pipeline_parallel_degree=2,  # Required
        microbatches=4,
        placement_strategy="spread",
        pipeline="interleaved",
        optimize="speed",
        ddp=True,
    ),
)

mpi_options = dict(
    enabled=True,  # Required
    processes_per_host=8,  # Required
    # custom_mpi_options="--mca btl_vader_single_copy_mechanism none",
)

# Launch configuration: the torch_distributed launcher plus SMP model-parallel
# parameters (tensor parallelism, hybrid sharding, activation offloading).
# smp_options / mpi_options are not used here.
_smp_v2_distribution = {
    "torch_distributed": {"enabled": True},
    "smdistributed": {
        "modelparallel": {
            "enabled": True,
            "parameters": {
                "tensor_parallel_degree": tensor_parallel_degree,
                "hybrid_shard_degree": hybrid_shard_degree,
                "sm_activation_offloading": offload_activations,
                "activation_loading_horizon": activation_loading_horizon,
                # "smp_rope_enabled": True
            },
        }
    },
}

# NOTE(review): image_uri is pinned to us-east-1, while `region` comes from the
# boto3 session — confirm the job is launched in us-east-1.
# NOTE(review): checkpoint_s3_uri is not passed, so checkpoint_local_path has
# no S3 destination configured here. Checkpoints presumably rely on the
# checkpoint_dir hyperparameter instead — confirm.
smp_estimator = PyTorch(
    # Training code and what it receives.
    entry_point="train.py",
    source_dir=os.path.join(os.getcwd(), "./shared-scripts"),
    hyperparameters=hyperparameters,
    metric_definitions=metric_definitions,
    debugger_hook_config=False,
    # Container and compute. image_uri is given instead of
    # framework_version / py_version.
    image_uri='658645717510.dkr.ecr.us-east-1.amazonaws.com/smdistributed-modelparallel:2.4.1-gpu-py311-cu121',
    instance_type=instance_type,
    instance_count=instance_count,
    volume_size=400,
    distribution=_smp_v2_distribution,
    max_run=86400,
    keep_alive_period_in_seconds=3600,
    # Identity, session and job naming.
    role=role,
    sagemaker_session=sagemaker_session,
    base_job_name=base_job_name,
    tags=tags,
    # Storage locations.
    checkpoint_local_path="/opt/ml/checkpoints",
    output_path=s3_output_bucket,
    **kwargs,
)

# # Make sure data channels are ready

print(data_channels)

import sys
from contextlib import redirect_stdout, redirect_stderr

# Specify the file to save the logs
log_file_path = "exp-v2.txt"


class _Tee:
    """Text stream that forwards each write to several underlying streams.

    Lets fit() output keep appearing in the notebook/console while also being
    copied into ``log_file_path``. Any attribute not defined here (isatty,
    encoding, fileno, ...) is delegated to the first (primary) stream. This
    matters because callers may probe sys.stdout for such attributes.
    """

    def __init__(self, *streams):
        self._streams = streams

    def write(self, text):
        for stream in self._streams:
            stream.write(text)
        return len(text)

    def flush(self):
        for stream in self._streams:
            stream.flush()

    def __getattr__(self, name):
        return getattr(self._streams[0], name)


# Launch the training job. Output written to sys.stdout / sys.stderr during
# fit() is mirrored into log_file_path. The original streams are captured
# before each redirect takes effect.
with open(log_file_path, "w", encoding="utf-8") as _log_file, \
        redirect_stdout(_Tee(sys.stdout, _log_file)), \
        redirect_stderr(_Tee(sys.stderr, _log_file)):
    smp_estimator.fit(inputs=data_channels)
