## SageMaker training with EFS

This sample shows how to:

- Set up a VPC
- Set up EFS
- Set up a security group

**Please make sure the CIDR block in `scripts/cfn-nlp.yaml` does not conflict with your existing VPC.**

In [1]:
import sagemaker, boto3, time, json

# One SageMaker session for the notebook; the next cell reads its region.
sagemaker_session = sagemaker.Session()

# Echo the region so it is obvious which region the session is bound to.
print(f"sagemaker session region: {sagemaker_session.boto_region_name}")

  import scipy.sparse


sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/ec2-user/.config/sagemaker/config.yaml
sagemaker session region: us-west-2


In [2]:
# CF Setup
# Region used for every AWS client below; taken from the SageMaker session.
region = sagemaker_session.boto_region_name  # update this if your region is different

# Clients
cfn_client = boto3.client("cloudformation", region_name=region)
fsx_client = boto3.client("fsx", region_name=region)

# Inputs
region_az = "us-west-2a"  # customize this as needed. Your EFS will be set up in a subnet in this AZ

# Fail fast when the AZ does not belong to the selected region; otherwise the
# mismatch only shows up later as a failed CloudFormation stack.
if not region_az.startswith(region):
    raise ValueError(
        f"Availability Zone {region_az!r} is not in region {region!r}; "
        "set region_az to an Availability Zone of that region."
    )

In [7]:
# Launch the CloudFormation stack

# A timestamp suffix keeps the stack name unique across repeated runs.
timestamp = time.strftime("%Y-%m-%d-%H-%M-%S")
cfn_stack_name = f"vpc-setup-for-torchtune-{timestamp}"  # cloudformation stack name

# Load the infrastructure template from disk
with open("scripts/cfn-nlp.yaml", "r") as template_file:
    template_body = template_file.read()

# The AZ chosen in the previous cell is passed through the template's "AZ" parameter.
stack_parameters = [{"ParameterKey": "AZ", "ParameterValue": region_az}]

create_stack_response = cfn_client.create_stack(
    StackName=cfn_stack_name,
    TemplateBody=template_body,
    Parameters=stack_parameters,
)

create_stack_response

{'StackId': 'arn:aws:cloudformation:us-west-2:015476483300:stack/vpc-setup-for-torchtune-2024-10-15-15-03-12/98f23710-8b06-11ef-8d16-0667c3ef04d1',
 'ResponseMetadata': {'RequestId': '0eb9b55a-d28b-4389-81fb-eb4694a48e15',
  'HTTPStatusCode': 200,
  'HTTPHeaders': {'x-amzn-requestid': '0eb9b55a-d28b-4389-81fb-eb4694a48e15',
   'date': 'Tue, 15 Oct 2024 15:03:12 GMT',
   'content-type': 'text/xml',
   'content-length': '413',
   'connection': 'keep-alive'},
  'RetryAttempts': 0}}

In [8]:
# Recover the stack name from the StackId returned by create_stack:
# it is the second-to-last "/"-separated field of that identifier.

stack_id_parts = create_stack_response["StackId"].split("/")
stack_name = stack_id_parts[-2]
stack_name

'vpc-setup-for-torchtune-2024-10-15-15-03-12'

In [9]:
# Wait for CF stack to complete

def wait_for_stack_completion(stack_name, region, poll_interval=30, timeout=None):
    """
    Poll a CloudFormation stack until it reaches a terminal status.

    A status ending in ``_COMPLETE`` or ``_FAILED`` is treated as terminal;
    anything else is polled again after ``poll_interval`` seconds. Terminal
    statuses that end in ``_FAILED`` or mention ``ROLLBACK`` are reported as
    failures, so a rolled-back create is not announced as a success.

    :param stack_name: Name of the CloudFormation stack to watch
    :param region: AWS region where the stack is deployed
    :param poll_interval: Seconds to sleep between polls (default 30)
    :param timeout: Maximum seconds to wait, or None (default) to wait indefinitely
    :return: The final stack status string (e.g. 'CREATE_COMPLETE')
    :raises TimeoutError: If ``timeout`` is set and elapses before a terminal status
    """
    cf_client = boto3.client('cloudformation', region_name=region)
    deadline = None if timeout is None else time.monotonic() + timeout

    print(f"Waiting for stack {stack_name} to complete...")
    while True:
        response = cf_client.describe_stacks(StackName=stack_name)
        status = response['Stacks'][0]['StackStatus']

        if status.endswith(('_COMPLETE', '_FAILED')):
            # ROLLBACK_COMPLETE and friends end in _COMPLETE but mean the
            # requested operation did not succeed.
            succeeded = status.endswith('_COMPLETE') and 'ROLLBACK' not in status
            if succeeded:
                print(f"Stack {stack_name} completed with status: {status}")
            else:
                print(f"Stack {stack_name} failed with status: {status}")
            return status

        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(
                f"Stack {stack_name} still in status {status} after {timeout} seconds"
            )

        print(f"Current status: {status}. Waiting...")
        time.sleep(poll_interval)

# Poll until the stack reaches a terminal status. stack_name and region come
# from the earlier cells, so nothing needs to be edited here.
wait_for_stack_completion(stack_name, region)

Waiting for stack vpc-setup-for-torchtune-2024-10-15-15-03-12 to complete...
Current status: CREATE_IN_PROGRESS. Waiting...
Current status: CREATE_IN_PROGRESS. Waiting...
Current status: CREATE_IN_PROGRESS. Waiting...
Current status: CREATE_IN_PROGRESS. Waiting...
Stack vpc-setup-for-torchtune-2024-10-15-15-03-12 completed with status: CREATE_COMPLETE


In [17]:
# Get the security group ID, private subnet ID and EFS ID for the next step (fine-tuning)

def get_stack_outputs(stack_name, region='us-east-1'):
    """
    Retrieves all outputs from a CloudFormation stack.

    :param stack_name: Name of the CloudFormation stack
    :param region: AWS region where the stack is deployed (default is 'us-east-1')
    :return: Dictionary mapping each OutputKey to its OutputValue; an empty
             dictionary if the stack description carries no outputs; None if
             the lookup raised (the error is printed)
    """
    cfn_client = boto3.client('cloudformation', region_name=region)

    try:
        response = cfn_client.describe_stacks(StackName=stack_name)
        # The 'Outputs' key may be missing from the stack description; treat
        # that as "no outputs" instead of reporting a retrieval error.
        stack_outputs = response['Stacks'][0].get('Outputs', [])

        # Convert the list of outputs to a dictionary for easier access
        return {output['OutputKey']: output['OutputValue'] for output in stack_outputs}

    except Exception as e:
        # Best-effort contract: report the problem and signal failure with None.
        print(f"Error retrieving stack outputs: {str(e)}")
        return None

# Reuse the stack_name derived from the create_stack response in the earlier
# cell instead of a hard-coded name from a previous run, so a fresh
# Restart & Run All reads the stack that was just created.
outputs = get_stack_outputs(stack_name, region)

# get_stack_outputs returns None on failure, so check before indexing.
if outputs:
    print("Stack Outputs:")
    # Security group, private subnet and EFS IDs needed by the fine-tuning step.
    for output_key in ("sg", "privatesubnet", "outputEFS"):
        print(outputs[output_key])
    print("done")
else:
    print("Failed to retrieve stack outputs.")

sg-045de559bb87ac814
subnet-006e7d64b8f6ce47b
fs-064f90a1bc9c12c76
done
Stack Outputs:
