# Train a multiclass CNN model using TensorFlow

Step 1: Import the required packages and set up the SageMaker session and execution role.

In [None]:
import sagemaker

# SageMaker session shared by the rest of the notebook (also used for cleanup).
sess = sagemaker.Session()
# IAM role the training job and endpoint run under. Passing the existing session
# avoids get_execution_role() constructing a second, independent Session.
role = sagemaker.get_execution_role(sagemaker_session=sess)

Step 2: Provide the S3 location where the data is stored. Replace the placeholder below with your actual bucket name.

In [None]:
# S3 bucket holding the dataset -- replace the placeholder with the real bucket name.
bucket = "replace this text with actual bucket name"
prefix = "dataset"   # whatever folder structure you use

# One S3 URI per data channel, both rooted at s3://<bucket>/<prefix>/.
train_path, val_path = (f"s3://{bucket}/{prefix}/{split}" for split in ("train", "validation"))

print("Training path:", train_path)
print("Validation path:", val_path)


Step 3: Display the training script (train-cnn.py)

In [None]:
# Print train-cnn.py with syntax highlighting (IPython "!" shell escape, notebook-only).
!pygmentize train-cnn.py

Step 4: First configure a local-mode estimator that trains for 1 epoch. If everything works, train on a managed SageMaker instance for more epochs.

In [None]:
from sagemaker.tensorflow import TensorFlow

# Local-mode estimator (instance_type="local") for a quick 1-epoch smoke test.
# Apart from epochs, pass the same hyperparameters as the managed run configured
# in Step 6, so the script's batch_size / learning_rate handling is exercised
# here rather than first failing on a paid instance.
tf_estimator = TensorFlow(entry_point="train-cnn.py",
                          role=role,
                          instance_count=1,
                          instance_type="local",
                          framework_version="2.11",
                          py_version="py39",
                          hyperparameters={
                            "epochs": 1,
                            "batch_size": 32,
                            "learning_rate": 0.001
                          }
)



Step 5: Train locally for 1 epoch

In [None]:
# Run the local smoke test. 'training' and 'validation' are the channel names
# train-cnn.py is expected to read (TODO confirm against the script).
tf_estimator.fit(inputs={"training": train_path, "validation": val_path})

Step 6: Once local training works, configure training for more epochs. We use the CPU-based ml.m5.large instance type instead of a GPU instance to control cost; for bigger datasets, use a GPU instance.

In [None]:
# Managed training job: 10 epochs on a single ml.m5.large instance (no GPU, to keep cost down).
tf_estimator = TensorFlow(
    entry_point="train-cnn.py",
    role=role,
    framework_version="2.11",
    py_version="py39",
    instance_type="ml.m5.large",
    instance_count=1,
    hyperparameters=dict(epochs=10, batch_size=32, learning_rate=0.001),
)

Step 7: Train the model

Note: Expect an accuracy of roughly 75%, because we train on a small dataset for only 10 epochs. A larger dataset and more training epochs should improve accuracy.

In [None]:
# Launch the managed training job with both data channels.
tf_estimator.fit(
    inputs={
        "training": train_path,
        "validation": val_path,
    }
)

Step 8: Deploy the model

In [None]:
# Endpoint name; referenced again for inference and cleanup.
tf_endpoint_name = "cnn-endpoint"

# Host the trained model on one ml.c5.large instance.
tf_predictor = tf_estimator.deploy(
    endpoint_name=tf_endpoint_name,
    instance_type='ml.c5.large',
    initial_instance_count=1,
)

Step 9: Make predictions

In [None]:
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import json
import boto3

# SageMaker runtime client used to invoke the deployed endpoint.
# Use the session's region rather than a hard-coded one: the endpoint lives in
# whatever region `sess` deployed it to, and invoking it from another region fails.
runtime = boto3.client('sagemaker-runtime', region_name=sess.boto_region_name)

def predict_image(img_path, endpoint_name, target_size=(128, 128)):
    """Classify one local image with the deployed endpoint, print and plot the result.

    Args:
        img_path: Path to a local image file.
        endpoint_name: Name of the SageMaker endpoint to invoke.
        target_size: (width, height) the image is resized to before it is sent.
            Must match the input size used during training; defaults to (128, 128).
    """
    # Open inside a context manager so the file handle is released. convert() and
    # resize() return new in-memory images that remain valid after the file closes.
    with Image.open(img_path) as src:
        img = src.convert('RGB').resize(target_size)

    # Scale pixels to [0, 1] and add a leading batch dimension of 1.
    # Assumes train-cnn.py applies the same 1/255 rescaling -- TODO confirm.
    x = np.expand_dims(np.asarray(img) / 255.0, axis=0).tolist()  # list for JSON

    # Call SageMaker endpoint
    response = runtime.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType='application/json',
        Body=json.dumps({"instances": x})
    )
    result = json.loads(response['Body'].read())
    # 'predictions' holds one score row per instance sent; we sent exactly one.
    prediction = result['predictions']
    predicted_class = int(np.argmax(prediction, axis=1)[0])

    print(f"{img_path} → Predicted Class: {predicted_class}")

    # Show the resized image that was actually sent to the endpoint.
    plt.imshow(img)
    plt.title(f"Predicted Class: {predicted_class}")
    plt.axis('off')
    plt.show()

# Example usage: update the test image file names to match your own images.
# Reuse tf_endpoint_name from Step 8 so renaming the endpoint cannot silently
# break these calls.
for image_file in ('dog.jpeg', 'elephant.jpeg', 'horse.jpeg'):
    predict_image(image_file, tf_endpoint_name)


Step 10: Clean up the deployment

In [None]:
# Remove every resource created by deploy(). Deleting only the endpoint would
# leave its endpoint configuration and the model behind.
# delete_model() looks up the model through the live endpoint, so it must run
# before the endpoint is deleted.
tf_predictor.delete_model()
tf_predictor.delete_endpoint(delete_endpoint_config=True)