# Train a Text Classification Model Using BERT and Smart Sifting


---

This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook.

![This us-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-2/training|smart_sifting|Text_Classification_BERT|Train_text_classification.ipynb)

---

In this notebook, we train a text classification model using BERT (Transformers) and the Smart Sifting library. BERT is a transformer encoder model pretrained on a large corpus of English data in a self-supervised fashion. It is primarily intended to be fine-tuned on tasks that use the whole sentence (potentially masked) to make decisions, such as sequence classification, token classification, and question answering. We will build a BERT-based sentiment analysis model using the SST2 dataset.

## 1. Introduction to Smart Sifting

Smart Sifting is a framework for speeding up the training of PyTorch models. It implements a set of algorithms that filter out inconsequential training examples during training, which reduces computational cost and accelerates the training process. It is configuration-driven and extensible, so users can add custom logic to transform their training examples into a filterable format. Smart sifting is a generic utility that works with any DNN model and can reduce training infrastructure cost by up to 35%.

![image](sifting_flow.png)



Smart sifting sifts through your training data during training and feeds only the more informative samples to the model. During typical training with PyTorch, the data loader iteratively sends data in batches to the training loop and to accelerator devices (for example, GPUs or Trainium chips). Smart sifting is implemented at this data loading stage and is therefore independent of any upstream data preprocessing in your training pipeline.

Smart sifting uses your live model and a user-specified loss function to run an evaluative forward pass on each data sample as it is loaded. High-loss samples materially impact model training and are therefore included in the training data; relatively low-loss samples are already well represented by the model, so they are set aside and excluded from training. A key input to smart sifting is the proportion of data to exclude: for example, with the proportion set to 25%, samples in approximately the bottom quartile of loss in each batch are excluded from training. Once enough high-loss samples have been identified to complete a batch, the data is sent through the full training loop and the model learns and trains normally. You don't need to make any downstream changes to your training loop when smart sifting is enabled.
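To make the "bottom quartile" example concrete, here is a minimal, library-independent sketch of dropping the lowest-loss fraction of a batch. The function name `select_high_loss`, the parameter `exclude_fraction`, and the loss values are all hypothetical and chosen for illustration; this is not the Smart Sifting implementation, which evaluates losses with your live model and a configurable sifting strategy.

```python
import math
from typing import List, Sequence


def select_high_loss(losses: Sequence[float], exclude_fraction: float) -> List[int]:
    """Return the indices of the samples to keep, in their original order.

    The lowest-loss `exclude_fraction` of the batch is dropped.
    """
    if not 0.0 <= exclude_fraction < 1.0:
        raise ValueError("exclude_fraction must be in [0, 1)")
    num_to_drop = math.floor(len(losses) * exclude_fraction)
    # Rank sample indices from lowest to highest loss.
    ranked = sorted(range(len(losses)), key=lambda i: losses[i])
    dropped = set(ranked[:num_to_drop])
    return [i for i in range(len(losses)) if i not in dropped]


# Hypothetical per-sample losses for a batch of 8 examples.
batch_losses = [0.05, 1.20, 0.40, 0.02, 2.10, 0.75, 0.10, 0.90]
kept = select_high_loss(batch_losses, exclude_fraction=0.25)
print(kept)  # [1, 2, 4, 5, 6, 7]
```

With eight samples and a 25% exclusion proportion, the two lowest-loss samples (indices 3 and 0) are set aside and the remaining six move on to training.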


## 2. Prepare Dataset

For this training job, we use [SST2](https://huggingface.co/datasets/sst2). SST2 consists of roughly 11k sentences extracted from movie reviews, each labeled with positive or negative sentiment.

Let's start by downloading and extracting the dataset.

In [None]:
! aws s3 cp --recursive s3://sagemaker-example-files-prod-us-west-2/datasets/text/SST2/ ./text

We will convert the dataset into a TSV file, which will be uploaded to S3.

In [None]:
train_data = []
# The context manager closes the file once all lines have been read.
with open("text/sst2.train", "r") as train_file:
    for line in train_file:
        # The first character of each line is the label; the sentence starts at index 2.
        label = line[0]
        text = line[2:].strip()
        train_data.append({"label": label, "sentence": text})

In [None]:
import pandas as pd

trainDF = pd.DataFrame(train_data)

trainDF.to_csv("train.tsv", sep="\t", index=False)

Upload the TSV file to S3

In [None]:
import sagemaker

In [None]:
sess = sagemaker.Session()
# SageMaker session bucket -> used for uploading data, models and logs.
# SageMaker automatically creates this bucket if it does not exist.
sagemaker_session_bucket = sess.default_bucket()

role = sagemaker.get_execution_role()

sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sess.default_bucket()}")
print(f"sagemaker session region: {sess.boto_region_name}")

In [None]:
# upload to s3

train_data_url = sess.upload_data(
    path="train.tsv",
    key_prefix="data",
)

In [None]:
print(f"Training data uploaded at {train_data_url}")

## 3. Run the Training Job Using SageMaker Training


Adding the Smart Sifting library to the text classification code involves the following steps:

1. **Define Loss Function** - For text classification, we use cross-entropy loss, defined as below.
    ````
    class BertLoss(Loss):
        """
        This is an implementation of the Loss interface for the BERT model
        required for Smart Sifting. Use Cross-Entropy loss with 2 classes
        """

        def __init__(self):
            # reduction='none' returns one loss value per sample instead of a batch average
            self.celoss = torch.nn.CrossEntropyLoss(reduction='none')

        def loss(
            self,
            model: torch.nn.Module,
            transformed_batch: SiftingBatch,
            original_batch: Any = None,
        ) -> torch.Tensor:
            # Move the original batch onto the model's device. We assume the model is already on the right device;
            # PyTorch Lightning takes care of this under the hood for the model that is passed in.
            # TODO: ensure batch and model are on the same device in SiftDataloader so that the customer
            #  doesn't have to implement this
            device = next(model.parameters()).device
            batch = [t.to(device) for t in original_batch]

            # compute loss
            outputs = model(batch)
            return self.celoss(outputs.logits, batch[2])
    ````
2. **Define Transformation Function** to convert input batch to sifting format.

````
class BertListBatchTransform(SiftingBatchTransform):
    """
    This is an implementation of the data transforms for the BERT model
    required for Smart Sifting. Transform to and from ListBatch
    """
    def transform(self, batch: Any):
        inputs = []
        for i in range(len(batch[0])):
            inputs.append((batch[0][i], batch[1][i]))

        labels = batch[2].tolist()  # assume the last one is the list of labels
        return ListBatch(inputs, labels)

    def reverse_transform(self, list_batch: ListBatch):
        inputs = list_batch.inputs
        input_ids = [iid for (iid, _) in inputs]
        masks = [mask for (_, mask) in inputs]
        a_batch = [torch.stack(input_ids), torch.stack(masks), torch.tensor(list_batch.labels, dtype=torch.int64)]
        return a_batch
````


3. **Define sifting config** - Define the configuration for sifting.

    `beta_value` controls how many samples are kept: the higher the value, the more samples are sifted out.
    `loss_history_length` sets the window of recent samples to include when evaluating relative loss.

````
sift_config = RelativeProbabilisticSiftConfig(
    beta_value=3,
    loss_history_length=500,
    loss_based_sift_config=LossConfig(
        sift_config=SiftingBaseConfig(sift_delay=10)
    ),
)
````
4. **Wrap the PyTorch data loader with the sifting data loader**.
   As a last step, we wrap the PyTorch `DataLoader` with `SiftingDataloader`, passing the config, transformation, and loss functions.

```
SiftingDataloader(
    sift_config=sift_config,
    orig_dataloader=DataLoader(self.train, self.batch_size, shuffle=True),
    loss_impl=BertLoss(),
    model=self.model,
    batch_transforms=BertListBatchTransform(),
)
```
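The steps above can be pictured with a small, torch-free sketch of a wrapping loader. It shows how a rolling loss history, a warm-up delay, and a beta value might interact to decide which samples are kept, and how kept samples are regrouped into full batches. Everything here is an assumption made for illustration: the names `keep_probability` and `sifted_batches`, the rule "percentile rank raised to the power beta", and the reading of `sift_delay` as a number of warm-up steps are not taken from the Smart Sifting source, and the library's actual algorithm may differ.

```python
import random
from collections import deque
from typing import Callable, Deque, Iterable, Iterator, List, Tuple

Sample = Tuple[str, int]  # (text, label); a stand-in for real tensors


def keep_probability(loss: float, history: Deque[float], beta: float) -> float:
    """Hypothetical rule: percentile rank of the loss in recent history, raised to beta."""
    if not history:
        return 1.0
    rank = sum(1 for past in history if past <= loss) / len(history)
    return rank ** beta


def sifted_batches(
    batches: Iterable[List[Sample]],
    loss_fn: Callable[[Sample], float],
    batch_size: int,
    beta: float = 3.0,
    history_length: int = 500,
    sift_delay: int = 10,
    seed: int = 0,
) -> Iterator[List[Sample]]:
    """Yield batches built only from the samples that survive sifting."""
    rng = random.Random(seed)
    history: Deque[float] = deque(maxlen=history_length)  # rolling loss window
    pending: List[Sample] = []
    for step, batch in enumerate(batches):
        for sample in batch:
            loss = loss_fn(sample)
            # Assumed warm-up: every sample is kept during the first sift_delay steps.
            if step < sift_delay or rng.random() < keep_probability(loss, history, beta):
                pending.append(sample)
            history.append(loss)
        # Emit full batches as soon as enough samples have been kept.
        while len(pending) >= batch_size:
            yield pending[:batch_size]
            pending = pending[batch_size:]
    if pending:
        yield pending  # final, possibly partial, batch


# Toy data: the "loss" is simply the label, so label-1 samples look harder.
toy_batches = [[(f"s{i}", i % 2) for i in range(b * 4, b * 4 + 4)] for b in range(50)]
sifted = list(sifted_batches(toy_batches, lambda s: float(s[1]), batch_size=4, sift_delay=2))
print(f"{sum(len(b) for b in sifted)} of 200 samples kept in {len(sifted)} batches")
```

Under this assumed rule, a higher beta shrinks the keep probability of every sample whose loss is not at the top of the recent history, which matches the description that larger values sift out more samples. The training loop downstream sees ordinary batches of the requested size and does not need to know that sifting took place.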


We define a few metrics to track in order to monitor sifting. These metrics are optional, but they are useful for debugging and understanding sifting performance.

In [None]:
from sagemaker.pytorch import PyTorch


# Raw strings keep the regex backslashes intact and avoid invalid-escape warnings.
SAGEMAKER_METRIC_DEFINITIONS = [
    # Sifting Metrics
    {"Name": "num_samples_batch:count", "Regex": r"num_orig_samples_seen: ([0-9.]+)"},
    {"Name": "num_kept_batch:count", "Regex": r"num_orig_samples_kept: ([0-9.]+)"},
    {"Name": "batch_size:count", "Regex": r"Batch size for next sifted batch: ([0-9.]+)"},
    {
        "Name": "sifted_batch_creation:latency",
        "Regex": r"sifted_batch_creation_latency: ([0-9.e+\-]+)",
    },
    {"Name": "sifting_batch_transform:latency", "Regex": r"func:transform latency: ([0-9.e+\-]+)"},
    {
        "Name": "calc_rel_vals:latency",
        "Regex": r"func:calculate_relevance_values latency: ([0-9.e+\-]+)",
    },
    {"Name": "should_sift:latency", "Regex": r"func:should_sift latency: ([0-9.e+\-]+)"},
    {"Name": "sift:latency", "Regex": r"func:sift latency: ([0-9.e+\-]+)"},
    {"Name": "accumulate:latency", "Regex": r"func:accumulate latency: ([0-9.e+\-]+)"},
    {
        "Name": "orig_batch_transform:latency",
        "Regex": r"func:reverse_transform latency: ([0-9.e+\-]+)",
    },
    # The following two metrics (rel_vals:number, kept_rel_vals:number) have multiple data points per log line, but
    # this metric capturing method can only capture one data point per log line.
    # TODO: capture all data points once we have a custom reporting API for sifting.
    {"Name": "rel_vals:number", "Regex": r"relevance_values: \[([0-9.e+\-]+)"},
    {"Name": "kept_rel_vals:number", "Regex": r"relevance_values_kept: \[([0-9.e+\-]+)"},
    # Training Metrics
    {"Name": "train_step:count", "Regex": r"train step count: ([0-9.]+)"},
    {"Name": "train_fp:latency", "Regex": r"train forward pass latency: ([0-9.e+\-]+)"},
    {"Name": "train_bp:latency", "Regex": r"train backprop latency: ([0-9.e+\-]+)"},
    {"Name": "train_optim:latency", "Regex": r"train optimizer step latency: ([0-9.e+\-]+)"},
    {"Name": "train_total_step:latency", "Regex": r"train total step latency: ([0-9.e+\-]+)"},
]
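Each metric definition pairs a name with a regular expression; SageMaker applies the expression to the training job's log output and records the captured group as the metric value. The following sketch checks two of the patterns locally with Python's `re` module. The log lines are hypothetical examples written for this illustration, and the helper `extract_metric` is not part of any SageMaker API; the real log format is whatever the training script prints.

```python
import re

# Two of the patterns from the metric definitions above.
SEEN_PATTERN = r"num_orig_samples_seen: ([0-9.]+)"
LATENCY_PATTERN = r"func:sift latency: ([0-9.e+\-]+)"

# Hypothetical log lines, made up for this example.
log_lines = [
    "INFO num_orig_samples_seen: 128",
    "INFO func:sift latency: 3.2e-05",
    "INFO unrelated message",
]


def extract_metric(pattern, lines):
    """Return the first captured group of every matching line as a float."""
    values = []
    for line in lines:
        match = re.search(pattern, line)
        if match:
            values.append(float(match.group(1)))
    return values


seen_values = extract_metric(SEEN_PATTERN, log_lines)
latency_values = extract_metric(LATENCY_PATTERN, log_lines)
print(seen_values)     # [128.0]
print(latency_values)  # [3.2e-05]
```

Testing the patterns this way before launching a job is a cheap way to catch a regex that never matches, which would otherwise show up only as an empty metric.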

We will launch the training job on an `ml.g5.2xlarge` instance. The sifting library is included in the SageMaker PyTorch Deep Learning Containers starting with version 2.0.1.

In [None]:
estimator = PyTorch(
    source_dir="./scripts",
    entry_point="train_bert.py",
    role=role,
    framework_version="2.0.1",
    py_version="py310",
    instance_count=1,
    instance_type="ml.g5.2xlarge",
    disable_profiler=True,
    metric_definitions=SAGEMAKER_METRIC_DEFINITIONS,
    hyperparameters={
        "epochs": 2,
        "num_nodes": 1,
        "log_level": 20,  # 10 is debug, 20 is info
        "log_batches": True,
    },
)

Launch the training job with the data in S3.

In [None]:
from datetime import datetime

estimator.fit(
    inputs={"data": train_data_url},
    job_name=f"bert-sst2-sifting-{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}",
)

In this notebook, we looked at how to use the smart sifting library to train a text classification (sentiment analysis) model. Smart sifting can help reduce training time by up to 40% without reducing model performance.

## Notebook CI Test Results

This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.


![This us-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-1/training|smart_sifting|Text_Classification_BERT|Train_text_classification.ipynb)

![This us-east-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-2/training|smart_sifting|Text_Classification_BERT|Train_text_classification.ipynb)

![This us-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-1/training|smart_sifting|Text_Classification_BERT|Train_text_classification.ipynb)

![This ca-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ca-central-1/training|smart_sifting|Text_Classification_BERT|Train_text_classification.ipynb)

![This sa-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/sa-east-1/training|smart_sifting|Text_Classification_BERT|Train_text_classification.ipynb)

![This eu-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-1/training|smart_sifting|Text_Classification_BERT|Train_text_classification.ipynb)

![This eu-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-2/training|smart_sifting|Text_Classification_BERT|Train_text_classification.ipynb)

![This eu-west-3 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-3/training|smart_sifting|Text_Classification_BERT|Train_text_classification.ipynb)

![This eu-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-central-1/training|smart_sifting|Text_Classification_BERT|Train_text_classification.ipynb)

![This eu-north-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-north-1/training|smart_sifting|Text_Classification_BERT|Train_text_classification.ipynb)

![This ap-southeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-1/training|smart_sifting|Text_Classification_BERT|Train_text_classification.ipynb)

![This ap-southeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-2/training|smart_sifting|Text_Classification_BERT|Train_text_classification.ipynb)

![This ap-northeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-1/training|smart_sifting|Text_Classification_BERT|Train_text_classification.ipynb)

![This ap-northeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-2/training|smart_sifting|Text_Classification_BERT|Train_text_classification.ipynb)

![This ap-south-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-south-1/training|smart_sifting|Text_Classification_BERT|Train_text_classification.ipynb)
