# Train Image Classification Model using ViT and Smart Sifting

---

This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook.

![This us-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-2/training|smart_sifting|Image_Classification_VIT|Train_Image_classification.ipynb)

---

In this notebook, we will train an image classification model using the Vision Transformer (ViT) and the Smart Sifting library. ViT is a transformer encoder model pretrained on a large collection of images from ImageNet at a resolution of 224x224 pixels.

## 1. Introduction to Smart Sifting

Smart Sifting is a framework for speeding up the training of PyTorch models. The framework implements a set of algorithms that filter out inconsequential training examples during training, reducing the computational cost and accelerating the training process. It is configuration-driven and extensible, allowing users to add custom logic to transform their training examples into a filterable format. Smart Sifting provides a generic utility for any DNN model and can reduce training infrastructure cost by up to 35%.

![image](sifting_flow.png)



Smart sifting’s task is to sift through your training data during the training process and feed only the more informative samples to the model. During typical training with PyTorch, data is iteratively sent in batches to the training loop and to accelerator devices (e.g. GPUs or Trainium chips) by the PyTorch data loader. Smart sifting is implemented at this data loading stage and is thus independent of any upstream data preprocessing in your training pipeline. Smart sifting uses your live model and a user-specified loss function to do an evaluative forward pass of each data sample as it is loaded. Samples with high loss will materially impact model training and are therefore included in the training data; samples with relatively low loss are already well represented by the model, so they are set aside and excluded from training. A key input to smart sifting is the proportion of data to exclude: for example, by setting the proportion to 25%, samples in approximately the bottom quartile of loss of each batch will be excluded from training. Once enough high-loss samples have been identified to complete a batch, the data is sent through the full training loop and the model learns and trains normally. Customers don’t need to make any downstream changes to their training loop when smart sifting is enabled.
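
To make the "exclude the bottom quartile" idea concrete, here is a minimal, library-free sketch. The function name `keep_high_loss` and the loss values are hypothetical, and the library's actual selection logic may differ (it is configured through the sifting config shown later); the sketch only illustrates dropping the lowest-loss fraction of a batch.

```python
from typing import List, Sequence


def keep_high_loss(losses: Sequence[float], exclude_fraction: float) -> List[int]:
    """Return indices of the samples to keep, in their original order.

    The lowest-loss `exclude_fraction` of the batch is dropped.
    """
    if not 0.0 <= exclude_fraction < 1.0:
        raise ValueError("exclude_fraction must be in [0, 1)")
    n_drop = int(len(losses) * exclude_fraction)
    # Rank sample indices from lowest to highest loss.
    ranked = sorted(range(len(losses)), key=lambda i: losses[i])
    dropped = set(ranked[:n_drop])
    return [i for i in range(len(losses)) if i not in dropped]


# Hypothetical per-sample losses from an evaluative forward pass.
batch_losses = [0.05, 2.30, 0.90, 0.01, 1.70, 0.40, 3.10, 0.20]
kept = keep_high_loss(batch_losses, exclude_fraction=0.25)
print(kept)  # [1, 2, 4, 5, 6, 7]
```

With eight samples and a 25% exclusion proportion, the two lowest-loss samples (indices 3 and 0) are set aside and the remaining six move on to training.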


## 2. Install Required Dependencies

In [None]:
    ! pip install datasets transformers --quiet
    ! pip install -U sagemaker boto3 --quiet

## 3. Prepare Dataset

For this training, we will use the [Caltech-101 dataset](https://data.caltech.edu/records/mzrjq-6wc02). Caltech-101 consists of pictures of objects belonging to 101 classes. Each class contains roughly 40 to 800 images, totalling around 9k images. Images are of variable sizes, with typical edge lengths of 200-300 pixels.

Let's start by downloading and extracting the dataset.

In [None]:
! aws s3 cp --recursive s3://sagemaker-example-files-prod-us-west-2/datasets/image/caltech-101/ ./caltech

In [None]:
! tar -xf ./caltech/101_ObjectCategories.tar.gz

We will convert the downloaded data into the Hugging Face Datasets Arrow format. Note: This is done for convenience and is not a requirement of the sifting library.

In [None]:
from datasets import load_dataset, DatasetDict

dataset = load_dataset("imagefolder", data_dir="101_ObjectCategories")
ds_train_devtest = dataset["train"].train_test_split(test_size=0.2, seed=42)

In [None]:
ds_splits = DatasetDict(
    {"train": ds_train_devtest["train"], "validation": ds_train_devtest["test"]}
)

After this step, we should have a dataset with train and validation splits. Let's print the dataset to confirm.

In [None]:
print(f"Dataset Splits: \n {ds_splits}")

### Upload Dataset to S3 for Training

Let's upload the dataset to S3. For this, we use the S3 integration of the Datasets API to save the `DatasetDict` object directly to S3.

In [None]:
import sagemaker
import boto3

sess = sagemaker.Session()
# sagemaker session bucket -> used for uploading data, models and logs
# sagemaker will automatically create this bucket if it does not exist
sagemaker_session_bucket = None
if sagemaker_session_bucket is None and sess is not None:
    # set to default bucket if a bucket name is not given
    sagemaker_session_bucket = sess.default_bucket()

try:
    role = sagemaker.get_execution_role()
except ValueError:
    iam = boto3.client("iam")
    role = iam.get_role(RoleName="sagemaker_execution_role")["Role"]["Arn"]

sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sess.default_bucket()}")
print(f"sagemaker session region: {sess.boto_region_name}")

In [None]:
training_input_path = f"s3://{sess.default_bucket()}/dataset/caltech101"
print(f"uploading training dataset to: {training_input_path}")
# save both splits (train and validation) to S3
ds_splits.save_to_disk(training_input_path)

print(f"uploaded data to: {training_input_path}")

## 4. Run Training Job using SageMaker Training


Adding the Smart Sifting library to the image classification code involves the following steps:

1. **Define Loss Function** - For image classification, we use cross-entropy loss, defined as follows:
    ````
    class ImageLoss(Loss):
        """
        This is an implementation of the Loss interface for the model
        required for Smart Sifting.
        """
        def __init__(self):
            # reduction='none' keeps one loss value per sample
            self.celoss = torch.nn.CrossEntropyLoss(reduction='none')

        def loss(
                self,
                model: torch.nn.Module,
                transformed_batch: SiftingBatch,
                original_batch: Any = None,
        ) -> torch.Tensor:
            device = next(model.parameters()).device
            batch = {k: v.to(device) for k, v in original_batch.items()}

            # compute loss
            outputs = model(**batch)
            return self.celoss(outputs.logits, batch["labels"])

    ````
2. **Define Transformation Function** - Convert the input batch to the sifting format and back.

````
class ImageListBatchTransform(SiftingBatchTransform):
    """
    This is an implementation of the data transforms for the model 
    required for Smart Sifting. Transform to and from ListBatch
    """
    def transform(self, batch: Any):
        inputs = []
        labels = []

        for i in range(len(batch["pixel_values"])):
            inputs.append(batch["pixel_values"][i])

        for i in range(len(batch["labels"])):
            labels.append(batch["labels"][i])

        return ListBatch(inputs, labels)
    
    def reverse_transform(self, list_batch: ListBatch):
        a_batch = {}
        a_batch["pixel_values"] = self.stack_tensors(list_batch.inputs)
        a_batch["labels"] = self.stack_tensors(list_batch.labels)
  
        return a_batch

    def stack_tensors(self,list_of_tensors):
        if list_of_tensors:
            t = torch.stack(list_of_tensors)
        else:
            t = torch.tensor([])
        return t
````
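
The round trip performed by `transform` and `reverse_transform` can be illustrated without PyTorch. The sketch below is a plain-Python analogue, not the library's code: the helper names `to_sample_lists` and `to_batch` are hypothetical, and strings stand in for image tensors. It shows why the batch is split into per-sample lists: sifting decides sample by sample, and the survivors are then reassembled into the dict shape the model expects.

```python
from typing import Any, Dict, List, Tuple


def to_sample_lists(batch: Dict[str, List[Any]]) -> Tuple[List[Any], List[Any]]:
    """Split a batch dict into parallel per-sample lists (inputs, labels)."""
    return list(batch["pixel_values"]), list(batch["labels"])


def to_batch(inputs: List[Any], labels: List[Any]) -> Dict[str, List[Any]]:
    """Rebuild a batch dict from per-sample lists."""
    return {"pixel_values": list(inputs), "labels": list(labels)}


batch = {"pixel_values": ["img0", "img1", "img2", "img3"], "labels": [7, 2, 7, 5]}
inputs, labels = to_sample_lists(batch)

# Pretend sifting kept samples 1 and 3 only.
keep = [1, 3]
sifted = to_batch([inputs[i] for i in keep], [labels[i] for i in keep])
print(sifted)  # {'pixel_values': ['img1', 'img3'], 'labels': [2, 5]}
```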


3. **Define Sifting Config** - Define the configuration for sifting.

    - `beta_value` - Controls the proportion of samples to keep; the higher the value, the more samples are sifted.
    - `loss_history_length` - The window of samples to include when evaluating relative loss.

````
sift_config = RelativeProbabilisticSiftConfig(
    beta_value=3,
    loss_history_length=500,
    loss_based_sift_config=LossConfig(
        sift_config=SiftingBaseConfig(sift_delay=10)
    )
)
````
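
To give some intuition for what a loss history window means, the following sketch keeps a rolling window of recent losses and ranks a new loss against it. This is an illustration under assumptions, not the library's implementation: the `LossHistory` class and its methods are hypothetical, and how the library turns relative loss and `beta_value` into a keep-or-drop decision is not shown here.

```python
from collections import deque


class LossHistory:
    """Rolling window of recent per-sample losses (hypothetical helper)."""

    def __init__(self, history_length: int):
        self.window = deque(maxlen=history_length)

    def relative_rank(self, loss: float) -> float:
        """Fraction of losses in the window that are <= `loss` (1.0 if empty)."""
        if not self.window:
            return 1.0
        return sum(1 for past in self.window if past <= loss) / len(self.window)

    def add(self, loss: float) -> None:
        self.window.append(loss)  # the oldest entry is evicted when full


history = LossHistory(history_length=4)
for past_loss in [0.2, 0.4, 0.6, 0.8, 1.0]:  # 0.2 is evicted by 1.0
    history.add(past_loss)

print(history.relative_rank(0.5))  # 0.25: 1 of 4 stored losses is <= 0.5
print(history.relative_rank(0.9))  # 0.75: 3 of 4 stored losses are <= 0.9
```

A longer window gives a more stable reference for "relatively low loss", while a shorter one adapts faster as the model improves.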
4. **Wrap the PyTorch Data Loader with the Sifting Data Loader** - As a last step, we wrap the PyTorch `DataLoader` with `SiftingDataloader`, passing the config, transformation, and loss functions.

```
train_dataloader = SiftingDataloader(
    sift_config=sift_config,
    orig_dataloader=train_dataloader,
    batch_transforms=ImageListBatchTransform(),
    loss_impl=ImageLoss(),
    model=model
)
```
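
Conceptually, a sifting data loader consumes batches from the original loader, keeps only some samples, and emits a batch once enough kept samples have accumulated. The generator below is a minimal sketch of that accumulation pattern only; `sifted_batches` and the threshold-based `keep` function are hypothetical and stand in for the library's loss-based decision.

```python
from typing import Callable, Iterable, Iterator, List


def sifted_batches(
    batches: Iterable[List[float]],
    keep: Callable[[float], bool],
    batch_size: int,
) -> Iterator[List[float]]:
    """Yield full batches built only from samples for which keep(sample) is True."""
    buffer: List[float] = []
    for batch in batches:
        buffer.extend(sample for sample in batch if keep(sample))
        # Emit as many full batches as the buffer currently holds.
        while len(buffer) >= batch_size:
            yield buffer[:batch_size]
            buffer = buffer[batch_size:]
    # Leftover samples that never fill a batch are dropped in this sketch.


# Each "sample" here is just its own loss value, to keep the example small.
incoming = [[0.1, 0.9, 0.8, 0.2], [0.7, 0.05, 0.6, 0.3], [0.95, 0.1, 0.2, 0.85]]
out = list(sifted_batches(incoming, keep=lambda loss: loss >= 0.5, batch_size=4))
print(out)  # [[0.9, 0.8, 0.7, 0.6]]
```

Because the wrapper still yields full-size batches, the training loop that iterates over it does not need to change.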


We define a few metrics to track in order to monitor sifting. These metrics are optional, but they are useful for debugging and understanding sifting performance.

In [None]:
# Regexes are raw strings so that "\-" and "\[" are not treated as (invalid) string escapes.
SAGEMAKER_METRIC_DEFINITIONS = [
    # Sifting Metrics
    {"Name": "num_samples_batch:count", "Regex": r"num_orig_samples_seen: ([0-9.]+)"},
    {"Name": "num_kept_batch:count", "Regex": r"num_orig_samples_kept: ([0-9.]+)"},
    {"Name": "batch_size:count", "Regex": r"Batch size for next sifted batch: ([0-9.]+)"},
    {
        "Name": "sifted_batch_creation:latency",
        "Regex": r"sifted_batch_creation_latency: ([0-9.e+\-]+)",
    },
    {"Name": "sifting_batch_transform:latency", "Regex": r"func:transform latency: ([0-9.e+\-]+)"},
    {
        "Name": "calc_rel_vals:latency",
        "Regex": r"func:calculate_relevance_values latency: ([0-9.e+\-]+)",
    },
    {"Name": "should_sift:latency", "Regex": r"func:should_sift latency: ([0-9.e+\-]+)"},
    {"Name": "sift:latency", "Regex": r"func:sift latency: ([0-9.e+\-]+)"},
    {"Name": "accumulate:latency", "Regex": r"func:accumulate latency: ([0-9.e+\-]+)"},
    {
        "Name": "orig_batch_transform:latency",
        "Regex": r"func:reverse_transform latency: ([0-9.e+\-]+)",
    },
    # The following two metrics (rel_vals:number, kept_rel_vals:number) have multiple data points per log line, but
    # this metric capturing method can only capture one data point per log line.
    # TODO: capture all data points once we have a custom reporting API for sifting.
    {"Name": "rel_vals:number", "Regex": r"relevance_values: \[([0-9.e+\-]+)"},
    {"Name": "kept_rel_vals:number", "Regex": r"relevance_values_kept: \[([0-9.e+\-]+)"},
    # Training Metrics
    {"Name": "train_step:count", "Regex": r"train step count: ([0-9.]+)"},
    {"Name": "train_fp:latency", "Regex": r"train forward pass latency: ([0-9.e+\-]+)"},
    {"Name": "train_bp:latency", "Regex": r"train backprop latency: ([0-9.e+\-]+)"},
    {"Name": "train_optim:latency", "Regex": r"train optimizer step latency: ([0-9.e+\-]+)"},
    {"Name": "train_total_step:latency", "Regex": r"train total step latency: ([0-9.e+\-]+)"},
]
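
Each metric definition pairs a name with a regex whose single capture group extracts the numeric value from a log line. A quick local check of two of the patterns might look like the sketch below; the log lines are hypothetical examples, since the exact wording is produced by the training script and the library.

```python
import re

# Two of the patterns from the metric definitions, written as raw strings.
latency_pattern = re.compile(r"sifted_batch_creation_latency: ([0-9.e+\-]+)")
kept_pattern = re.compile(r"num_orig_samples_kept: ([0-9.]+)")

# Hypothetical log lines; the real wording comes from the training script and library.
log_lines = [
    "INFO sifted_batch_creation_latency: 1.25e-03",
    "INFO num_orig_samples_kept: 48",
]

# group(1) is the captured numeric text, which is what gets recorded as the metric.
latency = float(latency_pattern.search(log_lines[0]).group(1))
kept = float(kept_pattern.search(log_lines[1]).group(1))
print(latency, kept)  # 0.00125 48.0
```

Testing patterns this way before launching a job helps catch a regex that silently matches nothing.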

In [None]:
hyperparameters = {}

# change the model name/path here to switch between resnet: "microsoft/resnet-101" and vit: "google/vit-base-patch16-224-in21k"
# hyperparameters["model_name_or_path"] = "microsoft/resnet-101"
hyperparameters["model_name_or_path"] = "google/vit-base-patch16-224-in21k"

hyperparameters["seed"] = 100
hyperparameters["per_device_train_batch_size"] = 64
hyperparameters["per_device_eval_batch_size"] = 64
hyperparameters["learning_rate"] = 5e-5

hyperparameters["max_train_steps"] = 1000  # use 10000
hyperparameters["num_train_epochs"] = 4

hyperparameters["use_sifting"] = 1  # param for enabling sifting 0 or 1
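
In script mode, SageMaker passes each hyperparameter to the entry point as a `--name value` command-line argument. The contents of `train_images.py` are not shown in this notebook, so the parser below is only a sketch of how such a script could read the values above; the defaults and the `parse_args` helper are assumptions.

```python
import argparse
from typing import List, Optional


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse the hyperparameters defined in the notebook (hypothetical parser)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name_or_path", type=str, required=True)
    parser.add_argument("--seed", type=int, default=100)
    parser.add_argument("--per_device_train_batch_size", type=int, default=64)
    parser.add_argument("--per_device_eval_batch_size", type=int, default=64)
    parser.add_argument("--learning_rate", type=float, default=5e-5)
    parser.add_argument("--max_train_steps", type=int, default=None)
    parser.add_argument("--num_train_epochs", type=int, default=4)
    # 0 disables sifting, 1 enables it.
    parser.add_argument("--use_sifting", type=int, choices=[0, 1], default=0)
    return parser.parse_args(argv)


args = parse_args(
    ["--model_name_or_path", "google/vit-base-patch16-224-in21k", "--use_sifting", "1"]
)
print(args.use_sifting, args.learning_rate)  # 1 5e-05
```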

In [None]:
import sagemaker
from sagemaker.pytorch import PyTorch

We will launch the training job on an `ml.g5.2xlarge` instance. The Smart Sifting library is included in the SageMaker PyTorch Deep Learning Containers starting with version 2.0.1.

In [None]:
import os

base_job_name = "vit-img-classification-sifting"

estimator = PyTorch(
    base_job_name=base_job_name,
    source_dir="scripts",
    entry_point="train_images.py",
    role=role,
    framework_version="2.0.1",
    py_version="py310",
    instance_count=1,
    instance_type="ml.g5.2xlarge",
    hyperparameters=hyperparameters,
    disable_profiler=True,
    metric_definitions=SAGEMAKER_METRIC_DEFINITIONS,
)

Launch the training job with the data in S3.

In [None]:
estimator.fit({"train": training_input_path}, wait=True)
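
Inside the training container, each channel passed to `fit()` is exposed through an `SM_CHANNEL_<NAME>` environment variable, so the `train` channel is available as `SM_CHANNEL_TRAIN`. The sketch below shows one way the entry script could locate and load the dataset; the helper names are hypothetical and the actual `train_images.py` may do this differently.

```python
import os


def resolve_train_dir() -> str:
    """Return the local directory of the "train" channel inside the container."""
    # Fall back to the conventional channel path if the variable is not set.
    return os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train")


def load_splits():
    """Load the DatasetDict that was saved earlier with save_to_disk."""
    from datasets import load_from_disk  # imported lazily; requires `datasets`

    return load_from_disk(resolve_train_dir())


print(resolve_train_dir())
```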

In this notebook, we looked at how to use the Smart Sifting library to train an image classification model. Smart sifting can reduce training time by up to 40% without any reduction in model performance.

## Notebook CI Test Results

This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.


![This us-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-1/training|smart_sifting|Image_Classification_VIT|Train_Image_classification.ipynb)

![This us-east-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-2/training|smart_sifting|Image_Classification_VIT|Train_Image_classification.ipynb)

![This us-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-1/training|smart_sifting|Image_Classification_VIT|Train_Image_classification.ipynb)

![This ca-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ca-central-1/training|smart_sifting|Image_Classification_VIT|Train_Image_classification.ipynb)

![This sa-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/sa-east-1/training|smart_sifting|Image_Classification_VIT|Train_Image_classification.ipynb)

![This eu-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-1/training|smart_sifting|Image_Classification_VIT|Train_Image_classification.ipynb)

![This eu-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-2/training|smart_sifting|Image_Classification_VIT|Train_Image_classification.ipynb)

![This eu-west-3 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-3/training|smart_sifting|Image_Classification_VIT|Train_Image_classification.ipynb)

![This eu-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-central-1/training|smart_sifting|Image_Classification_VIT|Train_Image_classification.ipynb)

![This eu-north-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-north-1/training|smart_sifting|Image_Classification_VIT|Train_Image_classification.ipynb)

![This ap-southeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-1/training|smart_sifting|Image_Classification_VIT|Train_Image_classification.ipynb)

![This ap-southeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-2/training|smart_sifting|Image_Classification_VIT|Train_Image_classification.ipynb)

![This ap-northeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-1/training|smart_sifting|Image_Classification_VIT|Train_Image_classification.ipynb)

![This ap-northeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-2/training|smart_sifting|Image_Classification_VIT|Train_Image_classification.ipynb)

![This ap-south-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-south-1/training|smart_sifting|Image_Classification_VIT|Train_Image_classification.ipynb)
