## SageMaker training with FSx

If you are training on a multi-node cluster, it's recommended to use FSx for Lustre for storing and retrieving data and checkpoints. This sample shows how to:

- Set up FSx
- Associate data in S3 with FSx
- Tear down the infrastructure

**Make sure the CIDR block in `setup/cfn-nlp.yaml` does not conflict with your existing VPC. You can also change the FSx storage capacity (currently set to 1.2 TB) depending on the size of your datasets.**
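
One way to check for a CIDR conflict before creating the stack is sketched below. It uses only the standard-library `ipaddress` module; the helper name `find_cidr_conflicts` and the example CIDR values are hypothetical and are not taken from `setup/cfn-nlp.yaml`. The CIDR blocks of existing VPCs could be collected from the `CidrBlock` field of each entry returned by the EC2 `describe_vpcs` call.

```python
import ipaddress


def find_cidr_conflicts(template_cidr, existing_cidrs):
    """Return the CIDR blocks in existing_cidrs that overlap template_cidr."""
    candidate = ipaddress.ip_network(template_cidr, strict=False)
    return [
        cidr
        for cidr in existing_cidrs
        if candidate.overlaps(ipaddress.ip_network(cidr, strict=False))
    ]


# Hypothetical values: replace with the CIDR from the template
# and the CIDR blocks of the VPCs in your account.
conflicts = find_cidr_conflicts("10.0.0.0/16", ["10.0.1.0/24", "172.31.0.0/16"])
print(conflicts)  # ['10.0.1.0/24']
```

An empty list means none of the listed blocks overlaps the template's CIDR.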

In [None]:
# Imports
import time
import boto3

# Inputs
region = "us-west-2"  # update this if your region is different
region_az = "us-west-2d"  # customize this as needed. Your FSx will be set up in a subnet in this AZ
cfn_stack_name = 'fsx-training-trn-stack'  # cloudformation stack name


# Clients
cfn_client = boto3.client("cloudformation", region_name=region)
fsx_client = boto3.client("fsx", region_name=region)


s3_data_bucket = 's3://bucket_name'
s3_data_train_prefix = 'dataset_nemo'  # s3 training data set

s3_data_model_prefix = 'model_dir' # s3 path to save model
s3_data_checkpoint_prefix = 'checkpoint_dir'  # s3 path to save model checkpoints
fsx_file_system_path = 'fsxneuron'  # this is file system path on FSx for the data, can be anything

In [None]:
# Setup infrastructure using CloudFormation
with open("setup/cfn-nlp.yaml", "r") as f:
    template_body = f.read()
    
create_stack_response = cfn_client.create_stack(
    StackName=cfn_stack_name,
    TemplateBody=template_body,
    Parameters=[
        {
            'ParameterKey': 'AZ',
            'ParameterValue': region_az
        }
    ]
)

create_stack_response

In [None]:
# Wait for the stack to be created. This takes about 10 minutes.
stack_id = create_stack_response['StackId']

while True:
    response = cfn_client.describe_stacks(
        StackName=stack_id
    )
    status = response['Stacks'][0]['StackStatus']
    if status == "CREATE_IN_PROGRESS":
        print("Create in progress. Waiting..")
        time.sleep(30)
    elif status == "CREATE_COMPLETE":
        print("Stack created!")
        break
    else:
        # Any other status (for example a rollback) means creation failed
        print("Error creating stack, with status", status, "- check the CFN console")
        break
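
The polling loop above can be generalized into a small helper that also enforces a timeout, so a stuck resource does not block the notebook indefinitely. This is a minimal sketch, not part of the original sample: the name `wait_for_status` and its parameters are hypothetical, and the status getter is passed in as a callable so the same helper could serve both the CloudFormation and the FSx waits.

```python
import time


def wait_for_status(get_status, in_progress, success,
                    poll_seconds=30, timeout_seconds=1800):
    """Poll get_status() until it returns `success`.

    Raises RuntimeError on a status that is neither `success` nor in
    `in_progress`, and TimeoutError once timeout_seconds has elapsed.
    """
    deadline = time.monotonic() + timeout_seconds
    while True:
        status = get_status()
        if status == success:
            return status
        if status not in in_progress:
            raise RuntimeError(f"Unexpected status: {status}")
        if time.monotonic() >= deadline:
            raise TimeoutError(f"Still {status} after {timeout_seconds} seconds")
        time.sleep(poll_seconds)


# Possible usage with the stack created in this notebook (not executed here):
# wait_for_status(
#     lambda: cfn_client.describe_stacks(StackName=stack_id)['Stacks'][0]['StackStatus'],
#     in_progress={"CREATE_IN_PROGRESS"},
#     success="CREATE_COMPLETE",
# )
```

Raising an exception instead of printing stops a "Run All" execution at the failing cell, which keeps later cells from running against a stack that was never created.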

In [None]:
# Get stack outputs
describe_response = cfn_client.describe_stacks(
    StackName=stack_id
)

outputs = describe_response['Stacks'][0]['Outputs']

for output in outputs:
    if output['OutputKey'] == 'sg':
        sec_group = output['OutputValue']
    elif output['OutputKey'] == 'outputfsx':
        fsx_id = output['OutputValue']
    elif output['OutputKey'] == 'privatesubnet':
        private_subnet_id = output['OutputValue']
        
fsx_response = fsx_client.describe_file_systems(
    FileSystemIds=[fsx_id]
)

fsx_mount = fsx_response['FileSystems'][0]['LustreConfiguration']['MountName']

print("FSx ID:", fsx_id)
print("Security Group ID:", sec_group)
print("Private Subnet ID:", private_subnet_id)
print("FSx Mount path:", fsx_mount)
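
If one of the expected outputs is missing, the `if`/`elif` chain above leaves a variable undefined and the error only surfaces later. A possible alternative, sketched here with a hypothetical helper name `stack_outputs_to_dict`, converts the outputs list into a dictionary and fails early. The output keys `sg`, `outputfsx`, and `privatesubnet` are the ones used in this notebook; the example IDs are made up.

```python
def stack_outputs_to_dict(outputs, required_keys=("sg", "outputfsx", "privatesubnet")):
    """Map CloudFormation OutputKey -> OutputValue and verify required keys exist."""
    values = {output["OutputKey"]: output["OutputValue"] for output in outputs}
    missing = [key for key in required_keys if key not in values]
    if missing:
        raise KeyError(f"Missing stack outputs: {missing}")
    return values


# Hypothetical outputs in the shape returned by describe_stacks
example_outputs = [
    {"OutputKey": "sg", "OutputValue": "sg-0123456789abcdef0"},
    {"OutputKey": "outputfsx", "OutputValue": "fs-0123456789abcdef0"},
    {"OutputKey": "privatesubnet", "OutputValue": "subnet-0123456789abcdef0"},
]
values = stack_outputs_to_dict(example_outputs)
print(values["outputfsx"])  # fs-0123456789abcdef0
```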

### Store FSx details

In [None]:
%store fsx_id
%store sec_group
%store private_subnet_id
%store fsx_mount
%store fsx_file_system_path


### Create a Data Repository Association (DRA) for S3 access

In [None]:
# Create a data repository association with S3 to load data
# and persist changes back to S3 to save training artifacts

fsx_s3_response = fsx_client.create_data_repository_association(
    FileSystemId=fsx_id,
    FileSystemPath=f"/{fsx_file_system_path}",
    DataRepositoryPath=s3_data_bucket,
    BatchImportMetaDataOnCreate=True,
    S3={
        "AutoImportPolicy": {
            "Events": ['NEW', 'CHANGED', 'DELETED']
        },
        "AutoExportPolicy": {
            "Events": ['NEW', 'CHANGED', 'DELETED']
        }
    }
)

fsx_s3_response
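
Because the association links the `/{fsx_file_system_path}` directory to the bucket root, an object in the bucket should appear under that directory with the same relative path. The sketch below illustrates that mapping; the helper name `fsx_path_for_s3_key` and the object key `dataset_nemo/train.bin` are hypothetical, and the mapping assumes `DataRepositoryPath` is the bucket root, as in the cell above.

```python
import posixpath


def fsx_path_for_s3_key(s3_key, fsx_file_system_path):
    """Return the path, relative to the file system root, where an S3 object
    is expected to appear when the DRA links the directory to the bucket root."""
    return posixpath.join("/", fsx_file_system_path.strip("/"), s3_key.lstrip("/"))


# 'fsxneuron' is the value of fsx_file_system_path in this notebook;
# the object key is a made-up example.
print(fsx_path_for_s3_key("dataset_nemo/train.bin", "fsxneuron"))
# /fsxneuron/dataset_nemo/train.bin
```

On a training instance this path is further prefixed by the directory where the file system is mounted.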

In [None]:
# Wait for the association to become available
while True:
    fsx_s3_assoc = fsx_client.describe_data_repository_associations(
        AssociationIds=[fsx_s3_response['Association']['AssociationId']]
    )
    fsx_status = fsx_s3_assoc['Associations'][0]['Lifecycle']

    if fsx_status == "CREATING":
        print("Create in progress. Waiting..")
        time.sleep(30)
    elif fsx_status == "AVAILABLE":
        print("FSx - S3 association complete!")
        break
    else:
        # Stop polling on any other lifecycle state to avoid an infinite loop
        print("Error creating the association, with status", fsx_status)
        break

### Clean up resources

To avoid incurring costs, delete the CloudFormation stack when you are done. This removes the VPC and associated resources, and the FSx file system.

In [None]:
# Delete the stack

delete_response = cfn_client.delete_stack(
    StackName=stack_id
)

delete_response
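
Stack deletion is asynchronous, so `delete_stack` returns before the resources are actually removed. A minimal sketch of waiting for completion is shown below, using the CloudFormation `stack_delete_complete` waiter from boto3. The helper name `delete_stack_and_wait` and its default delay and attempt count are assumptions, not part of the original sample.

```python
def delete_stack_and_wait(cfn_client, stack_name, delay=30, max_attempts=120):
    """Delete a CloudFormation stack and block until the deletion finishes.

    cfn_client is a boto3 CloudFormation client. The waiter raises an
    exception if the deletion fails or the attempts are exhausted.
    """
    cfn_client.delete_stack(StackName=stack_name)
    waiter = cfn_client.get_waiter("stack_delete_complete")
    waiter.wait(
        StackName=stack_name,
        WaiterConfig={"Delay": delay, "MaxAttempts": max_attempts},
    )


# Possible usage with the client and stack from this notebook (not executed here):
# delete_stack_and_wait(cfn_client, stack_id)
```

Passing the stack ID rather than the stack name lets the waiter keep resolving the stack after it has been deleted.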