In [None]:
import os
import numpy as np
import pandas as pd
import sagemaker
from sagemaker.pytorch import PyTorch
 
# --- S3 locations ---------------------------------------------------------
# Placeholder: replace with the name of an S3 bucket this notebook may use.
bucket = "<your bucket name>"
# Key prefix shared by the uploaded CSV files and the training output.
prefix = "sagemaker/pytorch-bert-financetext"
# Passed to the estimator as its `output_path`.
output_path = f"s3://{bucket}/{prefix}"

# --- SageMaker session and execution role ---------------------------------
sagemaker_session = sagemaker.Session()
role = sagemaker.get_execution_role()

In [None]:
# Upload the train and test CSV files (in that order) under the shared
# bucket/prefix. The values returned by `upload_data` are kept so they can be
# handed to `estimator.fit` as the "training" and "testing" channels.
inputs_train, inputs_test = [
    sagemaker_session.upload_data(local_csv, bucket=bucket, key_prefix=prefix)
    for local_csv in ("./data/train.csv", "./data/test.csv")
]


## Distributed training

In [None]:
# Configure a PyTorch training job that runs `code/train-dis.py` on two
# instances. The hyperparameters below are handed to the estimator unchanged;
# how the entry script interprets them is defined in `train-dis.py`.
estimator = PyTorch(
    # -- what to run --
    source_dir="code",
    entry_point="train-dis.py",
    py_version="py3",
    framework_version="1.6",
    # -- where to run it --
    role=role,
    instance_type="ml.g4dn.12xlarge",  # alternative noted by the author: "ml.p3.2xlarge"
    instance_count=2,
    output_path=output_path,
    # -- settings for the training script --
    hyperparameters={
        "epochs": 10,
        "lr": 5e-5,
        "num_labels": 3,
        "train_file": "train.csv",
        "test_file": "test.csv",
        "MAX_LEN": 315,
        "batch_size": 64,
        "test_batch_size": 10,
        "backend": "nccl",
    },
)

# Launch the job with the uploaded CSVs as the "training" and "testing"
# channels. `logs="None"` is the SDK's string option for not streaming job
# logs into the notebook output.
estimator.fit(
    {"training": inputs_train, "testing": inputs_test},
    logs="None",
)


In [None]:
# Model artifact location reported by the estimator; kept in `model_data`
# because the deployment cell passes it to `PyTorchModel`.
model_data = estimator.model_data
print(model_data)

## Deployment

In [None]:
from sagemaker.pytorch.model import PyTorchModel

# Wrap the trained artifact in a deployable model served by `code/inference.py`.
#
# The serving `framework_version` is kept identical to the one used for
# training ("1.6"). Serving with an older PyTorch (this was previously
# "1.3.1") risks failing at model-load time: PyTorch 1.6 changed the default
# `torch.save` format to a zipfile-based archive that older runtimes may not
# be able to read. If the training version is changed, change this one too.
pytorch_model = PyTorchModel(
    model_data=model_data,
    role=role,
    framework_version="1.6",
    source_dir="code",
    py_version="py3",
    entry_point="inference.py",
)

# Create a real-time endpoint backed by a single CPU instance.
predictor = pytorch_model.deploy(initial_instance_count=1, instance_type="ml.m4.xlarge")

In [None]:
# Exchange JSON with the endpoint: requests are JSON-encoded before sending
# and responses are JSON-decoded on return.
# NOTE(review): presumably `inference.py` expects/produces JSON — confirm.
json_serializer = sagemaker.serializers.JSONSerializer()
json_deserializer = sagemaker.deserializers.JSONDeserializer()

predictor.serializer = json_serializer
predictor.deserializer = json_deserializer

In [None]:
# Smoke-test the endpoint with one sentence and report the index of the
# largest value in the response.
# NOTE(review): assumes the endpoint returns per-class scores — confirm
# against `inference.py`.
sample_text = "The market went up 15% today.  This is better than average"
result = predictor.predict(sample_text)
print("predicted class: ", np.asarray(result).argmax())

In [None]:
# Clean up: delete the endpoint behind `predictor` once predictions are no
# longer needed.
predictor.delete_endpoint()