In [None]:
import sagemaker
from sagemaker import get_execution_role
import json
import boto3

# S3 bucket and key prefix used by the cells below.
# Note that the prefix value ends with "/".
bucket = 'packets-bucket-web'
prefix = 'text-classification/'

# SageMaker session whose default bucket is the bucket above, and the
# execution role reported by get_execution_role().
sess = sagemaker.Session(default_bucket=bucket)
role = get_execution_role()

In [None]:
# Map the position of each line in classes.txt (as a string, counting from 0)
# to the whitespace-stripped text of that line.
index_to_label = {}
with open("training_data/classes.txt") as classes_file:
    index_to_label.update(
        (str(position), raw_line.strip())
        for position, raw_line in enumerate(classes_file)
    )
print(index_to_label)


In [None]:
from random import shuffle
import multiprocessing
from multiprocessing import Pool
import csv
import nltk

# Tokenizer data needed by nltk.word_tokenize. Recent NLTK releases (3.9 and
# later) load the "punkt_tab" resource instead of the pickled "punkt" models,
# so fetch both to keep the notebook working across NLTK versions.
nltk.download("punkt")
nltk.download("punkt_tab")

In [None]:
def transform_instance(row):
    """Convert one CSV row into a label token followed by word tokens.

    Args:
        row: Sequence whose element 0 is a key of ``index_to_label`` and whose
            elements 1 and 2 are strings to tokenize.

    Returns:
        A list that starts with ``"__label__" + index_to_label[row[0]]`` and
        continues with the lower-cased ``nltk.word_tokenize`` tokens of
        ``row[1]`` and then ``row[2]``.
    """
    # The label is looked up by the row's first field and tagged with __label__.
    tokens = ["__label__" + index_to_label[row[0]]]
    for column in (1, 2):
        tokens.extend(nltk.word_tokenize(row[column].lower()))
    return tokens

In [None]:
def preprocess(input_file, output_file, keep=1):
    """Shuffle, subsample and tokenize a CSV file into a space-delimited file.

    All rows of ``input_file`` are read and shuffled, the leading ``keep``
    fraction is retained, and each retained row is converted with
    ``transform_instance`` in a pool of worker processes. The converted rows
    are written to ``output_file``, one per line, with fields separated by a
    single space.

    Args:
        input_file: Path of the comma-delimited CSV file to read.
        output_file: Path of the text file to write.
        keep: Fraction of the shuffled rows to retain; the default of 1 keeps
            every row.
    """
    # newline="" lets the csv module do its own newline handling, as its
    # documentation requires for files passed to csv.reader / csv.writer.
    with open(input_file, "r", newline="") as csvinfile:
        all_rows = list(csv.reader(csvinfile, delimiter=","))
    shuffle(all_rows)
    all_rows = all_rows[: int(keep * len(all_rows))]

    # The context manager tears the workers down even when a row fails to
    # transform; close()/join() keep the normal shutdown graceful.
    with Pool(processes=multiprocessing.cpu_count()) as pool:
        transformed_rows = pool.map(transform_instance, all_rows)
        pool.close()
        pool.join()

    with open(output_file, "w", newline="") as csvoutfile:
        csv_writer = csv.writer(csvoutfile, delimiter=" ", lineterminator="\n")
        csv_writer.writerows(transformed_rows)


In [None]:
%%time

# Build the training file.
# Preprocessing the whole dataset might take a couple of minutes, so this demo
# keeps only 20% of the training rows. Pass keep=1 to use the complete dataset.
preprocess(input_file="training_data/training.csv", output_file="packets.train", keep=0.2)

# Build the validation file (keep is left at its default of 1).
preprocess(input_file="training_data/testing.csv", output_file="packets.validation")

In [None]:
%%time

# `prefix` ends with "/", so strip it before appending the channel name;
# otherwise the S3 keys contain an empty "//" path segment.
train_channel = prefix.rstrip("/") + "/train"
validation_channel = prefix.rstrip("/") + "/validation"

# Upload the preprocessed files under their channel prefixes.
sess.upload_data(path="packets.train", bucket=bucket, key_prefix=train_channel)
sess.upload_data(path="packets.validation", bucket=bucket, key_prefix=validation_channel)

# S3 URIs of the two channel prefixes.
s3_train_data = "s3://{}/{}".format(bucket, train_channel)
s3_validation_data = "s3://{}/{}".format(bucket, validation_channel)

In [None]:
# Strip the trailing "/" from `prefix` so the output URI has no empty "//"
# path segment.
s3_output_location = "s3://{}/{}/output".format(bucket, prefix.rstrip("/"))

In [None]:
# Region name reported by a default boto3 session.
region_name = boto3.Session().region_name

In [None]:
# Look up the BlazingText container image for this region.
# sagemaker.image_uris.retrieve() is the SageMaker SDK v2 API; the older
# sagemaker.amazon.amazon_estimator.get_image_uri() helper is deprecated.
container = sagemaker.image_uris.retrieve(
    framework="blazingtext", region=region_name, version="1"
)
print("Using SageMaker BlazingText container: {} ({})".format(container, region_name))

In [None]:
# Estimator for the BlazingText container: one ml.c4.4xlarge instance,
# results written under s3_output_location.
# NOTE(review): "epochs" (1) is lower than "min_epochs" (5) while
# "early_stopping" is enabled - confirm these values are intended together.
bt_model = sagemaker.estimator.Estimator(
    image_uri=container,
    role=role,
    instance_count=1,
    instance_type="ml.c4.4xlarge",
    volume_size=30,
    max_run=360000,
    input_mode="File",
    output_path=s3_output_location,
    hyperparameters=dict(
        mode="supervised",
        epochs=1,
        min_count=2,
        learning_rate=0.05,
        vector_dim=10,
        early_stopping=True,
        patience=4,
        min_epochs=5,
        word_ngrams=2,
    ),
)

In [None]:
# Both channels use the same TrainingInput settings and differ only in their
# S3 location, so build them in one pass keyed by channel name.
data_channels = {
    channel_name: sagemaker.inputs.TrainingInput(
        s3_location,
        distribution="FullyReplicated",
        content_type="text/plain",
        s3_data_type="S3Prefix",
    )
    for channel_name, s3_location in (
        ("train", s3_train_data),
        ("validation", s3_validation_data),
    )
}
train_data = data_channels["train"]
validation_data = data_channels["validation"]
