In [None]:
import boto3
import sagemaker

# Resolve the region configured for the active boto3 session.
session = boto3.session.Session()
aws_region = session.region_name

# TODO: replace this placeholder with the name of your S3 bucket before running.
s3_bucket = "<your-s3-bucket-name>"

# Best-effort sanity check: look up the bucket's region and report it.
try:
    s3_client = boto3.client("s3")
    response = s3_client.get_bucket_location(Bucket=s3_bucket)
    # GetBucketLocation reports a null LocationConstraint for buckets in us-east-1.
    bucket_region = response.get("LocationConstraint") or "us-east-1"
    print(f"Bucket region: {bucket_region}")
except Exception as error:
    # Catch Exception (not a bare except) so KeyboardInterrupt/SystemExit still
    # propagate, and surface the underlying cause instead of hiding it.
    print(f"Access Error: Check if '{s3_bucket}' S3 bucket is in '{aws_region}' region")
    print(f"Cause: {error!r}")

In [None]:
# Key prefix, relative to the bucket root, used for the model output location.
s3_prefix = "models/blazing-text/classification/dbpedia"

# Full S3 URI assembled from the bucket name and the prefix above.
s3_output_location = f"s3://{s3_bucket}/{s3_prefix}"

print(f"Model output location:{s3_output_location}")

In [None]:
# Look up the image URI of the "blazingtext" algorithm (version "1") for the
# region of the current session.
container = sagemaker.image_uris.retrieve(
    framework="blazingtext",
    region=aws_region,
    version="1",
)
print(f"Using SageMaker BlazingText container: {container} ({aws_region})")

In [None]:
# SageMaker session and the execution role passed to the estimator below.
sess = sagemaker.Session()
role = sagemaker.get_execution_role()
print(role)

# Estimator wrapping the BlazingText container; artifacts are written under
# s3_output_location.
bt_model = sagemaker.estimator.Estimator(
    image_uri=container,
    role=role,
    sagemaker_session=sess,
    output_path=s3_output_location,
    input_mode="File",
    instance_type="ml.c5.4xlarge",
    instance_count=1,
    volume_size=100,
    max_run=360000,
)

In [None]:
# Hyperparameters forwarded to the BlazingText training job, kept in one
# mapping so they can be inspected or tweaked before being applied.
bt_hyperparameters = {
    "mode": "supervised",
    "word_ngrams": 2,
    "early_stopping": True,
    "patience": 4,
    "min_epochs": 15,
    "epochs": 30,
    "min_count": 5,
    "learning_rate": 0.05,
    "vector_dim": 150,
}

bt_model.set_hyperparameters(**bt_hyperparameters)

In [None]:
from sagemaker.inputs import TrainingInput
# S3 prefixes holding the training and validation splits.
s3_train = f"s3://{s3_bucket}/blazing-text/classification/dbpedia/train"
s3_validation = f"s3://{s3_bucket}/blazing-text/classification/dbpedia/validation"

train_input = TrainingInput(
    s3_data=s3_train,
    distribution="FullyReplicated",
    s3_data_type="S3Prefix",
    input_mode="File",
)

# The validation channel must read from the validation prefix. Pointing it at
# s3_train would make validation metrics (and the early stopping configured on
# the estimator) be computed on the training data itself.
validation_input = TrainingInput(
    s3_data=s3_validation,
    distribution="FullyReplicated",
    s3_data_type="S3Prefix",
    input_mode="File",
)

# Channel name -> input definition, as passed to the estimator's fit().
data_channels = {"train": train_input, "validation": validation_input}

In [None]:
# Start training on the channels defined in data_channels, requesting all logs
# and waiting for the call to finish before moving on.
bt_model.fit(
    inputs=data_channels,
    wait=True,
    logs="All",
)

In [None]:
from sagemaker.serializers import JSONSerializer

# Deploy the trained model behind a single-instance endpoint; request payloads
# are serialized with JSONSerializer.
text_classifier = bt_model.deploy(
    instance_type="ml.m5.xlarge",
    initial_instance_count=1,
    serializer=JSONSerializer(),
)


In [None]:
import json

sentences = [
    "Convair was an american aircraft manufacturing company which later expanded into rockets and spacecraft.",
    "Berwick secondary college is situated in the outer melbourne metropolitan suburb of berwick .",
]


def _tokenize(sentence):
    """Return ``sentence`` as a string of tokens separated by single spaces.

    Double and single quotes are removed; periods, commas, hyphens and
    parentheses are padded with spaces so each becomes a separate token.
    """
    spaced = (
        sentence.replace('"', "")
        .replace(".", " . ")
        .replace(",", " , ")
        .replace("-", " - ")
        .replace("'", "")
        .replace("(", " ( ")
        .replace(")", " ) ")
    )
    # Split on whitespace before joining: calling " ".join() directly on the
    # string would iterate over its characters and put a space between every
    # letter instead of between tokens.
    return " ".join(spaced.split())


# NOTE(review): this approximates tokenization with plain string replacements;
# confirm it matches the tokenizer used when the training data was prepared.
tokenized_sentences = [_tokenize(sentence) for sentence in sentences]

# "k": 2 is sent as the request configuration alongside the instances.
payload = {"instances": tokenized_sentences, "configuration": {"k": 2}}

response = text_classifier.predict(payload)

predictions = json.loads(response)
print(json.dumps(predictions, indent=2))

In [None]:
# Delete the model while the endpoint and its configuration still exist: the
# predictor may need to read the endpoint configuration to find the names of
# the models backing it, which fails once that configuration has been deleted.
text_classifier.delete_model()
text_classifier.delete_endpoint(delete_endpoint_config=True)