# Train and Deploy a Transformer Model for Adverse Event Classification in Amazon SageMaker

# Introduction

This notebook shows how to fine-tune a transformer model in Amazon SageMaker for adverse event (AE) classification. We use the Hugging Face [Transformers](https://huggingface.co/transformers/) library and its example code to train and deploy the model in Amazon SageMaker.

The AE dataset used in this demo is Hugging Face's Adverse Drug Reaction dataset: [ade_corpus_v2](https://huggingface.co/datasets/ade_corpus_v2). You can replace it with your own data.

# Environment setup

In [None]:
# install packages
! pip install --upgrade datasets

In [None]:
import os
import numpy as np
import pandas as pd
import sagemaker
import argparse
from datasets import load_dataset
from sklearn.model_selection import train_test_split

# One SageMaker session for the whole notebook, plus its default S3 bucket (data
# upload target) and the IAM execution role passed to the training estimator.
sagemaker_session = sagemaker.Session()
bucket, role = sagemaker_session.default_bucket(), sagemaker.get_execution_role()

# 1. Raw dataset
Here we use Hugging Face's Adverse Drug Reaction dataset to create a raw dataset for model training. You can skip this step if you already have your own raw dataset for AE classification. The raw data should have two columns: a text column, and a class column (with values `Adverse_Event` or `Not_AE`) indicating whether the text mentions an AE or not.

In [None]:
# Download the classification configuration of ade_corpus_v2 from the Hugging Face Hub
dataset = load_dataset("ade_corpus_v2", "Ade_corpus_v2_classification")
train_split = dataset['train']
df_raw = pd.DataFrame({'text': train_split['text'], 'class': train_split['label']})

# The conversion below only knows labels 0 and 1; fail loudly instead of silently
# filing an unexpected label under Not_AE.
assert df_raw['class'].isin([0, 1]).all(), "expected binary 0/1 labels in ade_corpus_v2"

# convert label id to class description for the raw dataset
df_raw['class'] = df_raw['class'].map({1: 'Adverse_Event', 0: 'Not_AE'})

# Save the raw dataset to a local folder ./data/
os.makedirs('data', exist_ok=True)
df_raw.to_csv('./data/raw_data.csv', index=False)

# 2. Process raw data and load it to S3 for model training

In [None]:
def load_data(data_path, label2id=None):
    """
    Load the raw data and convert the class names into integer IDs.

    Args:
        data_path: path (or file-like buffer) of a CSV file that has a ``class``
            column holding class names.
        label2id: optional mapping from class name to integer ID. Defaults to
            ``{'Adverse_Event': 1, 'Not_AE': 0}``.

    Returns:
        The loaded DataFrame with an added integer ``label`` column.

    Raises:
        KeyError: if the ``class`` column is missing, or contains class names
            that are not keys of ``label2id``.
    """
    if label2id is None:
        label2id = {'Adverse_Event': 1, 'Not_AE': 0}

    df = pd.read_csv(data_path)

    # Report every unmapped class at once (NaN included) rather than failing on
    # the first one with a bare KeyError.
    unknown = set(df['class']) - set(label2id)
    if unknown:
        raise KeyError(f"Unmapped class names in {data_path!r}: {sorted(map(str, unknown))}")

    df['label'] = df['class'].map(label2id)
    return df

# Read back the CSV written in step 1 and attach integer label IDs
df = load_data(data_path='./data/raw_data.csv')

In [None]:
# Create train and validation datasets: an 80/20 split stratified on the class
# column so both splits keep (approximately) the same AE / Not_AE ratio.
# random_state makes the split reproducible across re-runs.
train, valid = train_test_split(
    df, test_size=0.20, shuffle=True, random_state=1, stratify=df['class']
)
# Every row lands in exactly one split.
assert len(train) + len(valid) == len(df)
assert train.index.intersection(valid.index).empty

# save prepared train and valid datasets into local for S3 uploading
data_dir = './data/model_input'
os.makedirs(data_dir, exist_ok=True)

train.to_csv(os.path.join(data_dir, 'train.csv'), index=False)
valid.to_csv(os.path.join(data_dir, 'valid.csv'), index=False)

### Upload train/valid data to S3 for SageMaker model training 

In [None]:
# Key prefix (inside the session's default bucket) for the prepared CSV files
task_name = 'AE_bert/data'
s3_prefix = f'HF_models/{task_name}'

# Upload everything under data_dir; the returned value is the S3 URI of the prefix.
inputs_data = sagemaker_session.upload_data(path=data_dir, bucket=bucket, key_prefix=s3_prefix)
print(f'input spec (in this case, just an S3 path): {inputs_data}')

# 3. SageMaker model training

In [None]:
from sagemaker.pytorch import PyTorch

In [None]:
# hyperparameters, which are passed into the training job
# Hyperparameters handed to the estimator below and thus to the training job
# started from src/hf_train_deploy.py.
hyperparameters = dict(
    epochs=4,
    train_batch_size=64,
    max_seq_length=128,
    learning_rate=5e-5,
    model_name='distilbert-base-uncased',
    text_column='text',    # the column name for input text
    label_column='label',  # the column name for label IDs
)

In [None]:
# Amazon SageMaker PyTorch framework
# Amazon SageMaker PyTorch framework estimator running src/hf_train_deploy.py
# as the training script with the hyperparameters defined above.
train_instance_type = 'ml.p3.2xlarge'  # use 'local' for code testing within the notebook instance

bert_estimator = PyTorch(
    entry_point='hf_train_deploy.py',
    source_dir='src',
    hyperparameters=hyperparameters,
    role=role,
    instance_type=train_instance_type,
    instance_count=1,
    framework_version='1.4.0',
    py_version='py3',
)

In [None]:
# Launch training; the 'training' input channel points at the S3 prefix uploaded above.
bert_estimator.fit(inputs={'training': inputs_data})

In [None]:
# the model artifact in S3 after training (shown as this cell's output; the
# deployment section below reads the same attribute)
bert_estimator.model_data

# 4. SageMaker Endpoint Deploy

In [None]:
import sagemaker
from sagemaker.pytorch.model import PyTorchModel
from sagemaker.deserializers import JSONDeserializer
from sagemaker.serializers import JSONSerializer

In [None]:
# IAM execution role for the hosted model. This repeats the lookup from the setup
# cell, so it presumably yields the same role that was used for training.
role = sagemaker.get_execution_role()

In [None]:
# Wrap the trained artifact for hosting. The same source directory / entry point
# used for training is reused here (presumably it also defines the inference
# handlers — confirm in src/hf_train_deploy.py).
model_data = bert_estimator.model_data
src_dir = 'src'

pytorch_model = PyTorchModel(
    entry_point="hf_train_deploy.py",
    source_dir=src_dir,
    model_data=model_data,
    role=role,
    framework_version="1.4.0",
    py_version="py3",
)

In [None]:
# Create a real-time endpoint on one ml.m5.large instance; requests are sent and
# responses parsed as JSON by the returned predictor.
predictor = pytorch_model.deploy(
    endpoint_name='HF-BERT-AE-model',
    instance_type="ml.m5.large",
    initial_instance_count=1,
    serializer=JSONSerializer(),
    deserializer=JSONDeserializer(),
)

# 5. Inference (optional): invoke SageMaker Endpoint
This example shows how to invoke an endpoint for model prediction. You can use AWS Lambda to invoke the endpoint for real-time model predictions.

In [None]:
import boto3
import json
import time
import numpy as np
import pandas as pd

In [None]:
# Same endpoint name that was passed to deploy() in the previous section
endpoint_name = 'HF-BERT-AE-model'
# Low-level SageMaker runtime client used to invoke the endpoint directly
runtime = boto3.client(service_name='runtime.sagemaker')

In [None]:
query = 'This entity is probably related to a combination of high doses of corticosteroids, vecuronium administration and metabolic abnormalities associated with respiratory failure.'

# Send the text as a JSON-encoded string.
response = runtime.invoke_endpoint(EndpointName=endpoint_name,
                                   ContentType='application/json',
                                   Body=json.dumps(query))
# The endpoint replies with JSON (the predictor above was built with a JSON
# deserializer). Parse it with json.loads: eval() would execute whatever the
# response body contains as Python code.
prob = json.loads(response['Body'].read())
print(f"probability for Not_AE: {round(prob[0],3)}, for AE: {round(prob[1],3)}")

In [None]:
# set a classification threshold
threshold = 0.6  # classification cut-off applied to the AE probability

# prob[1] is the Adverse_Event probability; at or above the cut-off counts as an AE.
prd_prob = prob[1]
if prd_prob >= threshold:
    pred_label = "Adverse_Event"
else:
    pred_label = "Not_AE"

In [None]:
# Display the predicted label together with the AE probability it was based on
pred_label, prd_prob

# 6. Cleanup (optional)

If you don't need to keep the deployed endpoint live, please remember to delete the Amazon SageMaker endpoint to avoid charges:

In [None]:
# Set to True to delete the endpoint created above and stop incurring charges.
# Left False so that "Run All" never tears down the endpoint by accident.
DELETE_ENDPOINT = False
if DELETE_ENDPOINT:
    predictor.delete_endpoint()