# Train an Image Classification Model Using ViT and Smart Sifting

---

This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook.

![This us-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-2/training|smart_sifting|Image_Classification_VIT|Train_Image_classification.ipynb)

---

In this notebook, we will train an image classification model using the Vision Transformer (ViT). ViT is a transformer encoder model pretrained on a large collection of images from ImageNet at a resolution of 224x224 pixels.
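As a rough illustration of what the 224x224 input resolution means for the model, the sketch below computes the length of the token sequence a ViT encoder processes. It assumes a patch size of 16 (suggested by the `patch16` in the checkpoint name `google/vit-base-patch16-224-in21k` used in this notebook) and one extra classification token. The helper name `vit_sequence_length` is made up for this example.

```python
def vit_sequence_length(image_size: int, patch_size: int, extra_tokens: int = 1) -> int:
    """Number of tokens a ViT encoder processes for a square input image."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size
    return patches_per_side * patches_per_side + extra_tokens


# 224 / 16 = 14 patches per side, 14 * 14 = 196 patches, plus one classification token.
print(vit_sequence_length(224, 16))  # 197
```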

## 2. Install Required Dependencies

In [None]:
! pip install datasets transformers --quiet
! pip install -U sagemaker boto3 --quiet

## 3. Prepare Dataset

For this training, we will use the [Caltech-101 dataset](https://data.caltech.edu/records/mzrjq-6wc02). Caltech-101 consists of pictures of objects belonging to 101 classes. Each class contains roughly 40 to 800 images, for a total of around 9,000 images. Image sizes vary, with typical edge lengths of 200 to 300 pixels.

Let's start by downloading and extracting the dataset.

In [None]:
! aws s3 cp --recursive s3://sagemaker-example-files-prod-us-west-2/datasets/image/caltech-101/ ./caltech

In [None]:
! tar -xf ./caltech/101_ObjectCategories.tar.gz --no-same-owner
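Before converting the data, it can help to confirm that the archive extracted as expected. The sketch below counts image files per class folder under `101_ObjectCategories` using only the standard library. The helper name `count_images_per_class` and the set of accepted file suffixes are assumptions made for this example.

```python
from pathlib import Path

# Assumption: images are stored with one of these file extensions.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def count_images_per_class(root: str) -> dict:
    """Map each class sub-directory of `root` to the number of image files in it."""
    counts = {}
    root_path = Path(root)
    if not root_path.is_dir():
        # Nothing extracted yet (or wrong path): return an empty mapping.
        return counts
    for class_dir in sorted(p for p in root_path.iterdir() if p.is_dir()):
        counts[class_dir.name] = sum(
            1
            for f in class_dir.iterdir()
            if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES
        )
    return counts


counts = count_images_per_class("101_ObjectCategories")
print(f"class folders: {len(counts)}, images: {sum(counts.values())}")
```

The total should be in line with the "around 9,000 images" described above; a much smaller number suggests an incomplete download or extraction.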

We will convert the downloaded data into the Hugging Face Datasets Arrow format. Note: this is done for convenience; it is not a requirement of the sifting library.

In [None]:
from datasets import load_dataset, DatasetDict

dataset = load_dataset("imagefolder", data_dir="101_ObjectCategories")
ds_train_devtest = dataset["train"].train_test_split(test_size=0.2, seed=42)

In [None]:
ds_splits = DatasetDict(
    {"train": ds_train_devtest["train"], "validation": ds_train_devtest["test"]}
)

After this step, we should have a dataset with train and validation splits. Let's print the dataset to confirm.

In [None]:
print(f"Dataset Splits: \n {ds_splits}")

### Upload Dataset to S3 for Training

Let's upload the dataset to S3. We will use the S3 integration of the Datasets API to save the `DatasetDict` object directly to S3.

In [None]:
import sagemaker
import boto3

sess = sagemaker.Session()
# SageMaker session bucket: used for uploading data, models, and logs
# SageMaker creates this bucket automatically if it does not exist
sagemaker_session_bucket = None
if sagemaker_session_bucket is None and sess is not None:
    # set to default bucket if a bucket name is not given
    sagemaker_session_bucket = sess.default_bucket()

try:
    role = sagemaker.get_execution_role()
except ValueError:
    iam = boto3.client("iam")
    role = iam.get_role(RoleName="sagemaker_execution_role")["Role"]["Arn"]

sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sess.default_bucket()}")
print(f"sagemaker session region: {sess.boto_region_name}")

In [None]:
training_input_path = f"s3://{sess.default_bucket()}/dataset/caltech101"
print(f"uploading training dataset to: {training_input_path}")
# save both splits of the DatasetDict directly to S3
ds_splits.save_to_disk(training_input_path)

print(f"uploaded data to: {training_input_path}")

## 4. Run a Training Job Using SageMaker Training


We define a few metrics to track in order to monitor sifting. These metrics are optional, but they are useful for debugging and understanding sifting performance.

In [None]:
hyperparameters = {}

# change the model name/path here to switch between resnet: "microsoft/resnet-101" and vit: "google/vit-base-patch16-224-in21k"
# hyperparameters["model_name_or_path"] = "microsoft/resnet-101"
hyperparameters["model_name_or_path"] = "google/vit-base-patch16-224-in21k"

hyperparameters["seed"] = 100
hyperparameters["per_device_train_batch_size"] = 64
hyperparameters["per_device_eval_batch_size"] = 64
hyperparameters["learning_rate"] = 5e-5

hyperparameters["max_train_steps"] = 1000  # use 10000
hyperparameters["num_train_epochs"] = 4
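To see how `max_train_steps` and `num_train_epochs` relate, the back-of-the-envelope sketch below estimates the number of optimizer steps per epoch. It assumes about 9,000 images in total with 80% in the train split, a single GPU, and no gradient accumulation. The helper name `steps_per_epoch` is made up for this example. Whether training stops at the step limit or after the configured number of epochs depends on `scripts/train_images.py`, which is not shown here.

```python
import math


def steps_per_epoch(num_train_examples: int, per_device_batch_size: int, num_devices: int = 1) -> int:
    """Optimizer steps needed to see every training example once (no gradient accumulation)."""
    global_batch_size = per_device_batch_size * num_devices
    return math.ceil(num_train_examples / global_batch_size)


# Assumption: roughly 9,000 images overall, 80% of them in the train split.
num_train_examples = 7200
per_epoch = steps_per_epoch(num_train_examples, per_device_batch_size=64)
print(per_epoch)      # 113
print(per_epoch * 4)  # 452
```

Under these assumptions, four epochs take roughly 450 steps, well below the `max_train_steps` value of 1000, so the two settings are worth checking against each other before launching a longer run.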

In [None]:
from sagemaker.pytorch import PyTorch

We will launch the training job using an ml.p3.2xlarge instance and the PyTorch Deep Learning Container.

In [None]:
import os

base_job_name = "vit-img-classification-sifting"

estimator = PyTorch(
    base_job_name=base_job_name,
    source_dir="scripts",
    entry_point="train_images.py",
    role=role,
    framework_version="2.0.1",
    py_version="py310",
    instance_count=1,
    instance_type="ml.p3.2xlarge",
    hyperparameters=hyperparameters,
    disable_profiler=True,
    disable_output_compression=True
)
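SageMaker forwards the entries of the `hyperparameters` dictionary to the entry point as command-line arguments of the form `--name value`. The sketch below shows one way a training script might read them with `argparse`. It is not the content of `scripts/train_images.py` (which is not shown here); the function name `parse_args` and all default values are assumptions made for this example.

```python
import argparse


def parse_args(argv=None):
    """Parse hyperparameters passed to the script as `--name value` pairs."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name_or_path", type=str, required=True)
    parser.add_argument("--seed", type=int, default=None)
    # The defaults below are hypothetical; the notebook overrides most of them.
    parser.add_argument("--per_device_train_batch_size", type=int, default=8)
    parser.add_argument("--per_device_eval_batch_size", type=int, default=8)
    parser.add_argument("--learning_rate", type=float, default=5e-5)
    parser.add_argument("--max_train_steps", type=int, default=None)
    parser.add_argument("--num_train_epochs", type=int, default=3)
    # Ignore arguments this sketch does not know about instead of failing.
    args, _unknown = parser.parse_known_args(argv)
    return args


# Simulate a command line built from the notebook's `hyperparameters` dictionary.
args = parse_args([
    "--model_name_or_path", "google/vit-base-patch16-224-in21k",
    "--seed", "100",
    "--per_device_train_batch_size", "64",
    "--per_device_eval_batch_size", "64",
    "--learning_rate", "5e-05",
    "--max_train_steps", "1000",
    "--num_train_epochs", "4",
])
print(args.model_name_or_path, args.per_device_train_batch_size, args.learning_rate)
```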

Launch the training job with the data in S3.

In [None]:
estimator.fit({"train": training_input_path}, wait=True)
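The key `"train"` passed to `fit` names an input channel. Inside the training container, SageMaker downloads that channel to `/opt/ml/input/data/train` and exposes the path through the `SM_CHANNEL_TRAIN` environment variable. The sketch below shows how a script could locate the data; the helper name `resolve_channel_dir` is made up for this example, and how `scripts/train_images.py` actually loads the dataset is not shown in this notebook.

```python
import os


def resolve_channel_dir(channel: str = "train", environ=None) -> str:
    """Return the local directory where SageMaker places the data for `channel`."""
    env = os.environ if environ is None else environ
    default_dir = f"/opt/ml/input/data/{channel}"
    return env.get(f"SM_CHANNEL_{channel.upper()}", default_dir)


train_dir = resolve_channel_dir("train")
print(train_dir)
# Inside the container, a script could then call, for example:
#   datasets.load_from_disk(train_dir)
# to get back the DatasetDict with its "train" and "validation" splits.
```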

In this notebook, we looked at how to use the smart sifting library to train an image classification model. Smart sifting can reduce training time by up to 40% without reducing model performance.
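For intuition about where the time saving comes from: loss-based sifting roughly means spending less compute on samples the model already handles well. The sketch below is a deliberately simplified, framework-free illustration of that idea, keeping only the highest-loss fraction of a batch. It is not the smart sifting library's algorithm or API; the function name `sift_batch` and the `keep_fraction` parameter are invented for this example.

```python
import math


def sift_batch(per_sample_losses, keep_fraction=0.5):
    """Return the indices of the highest-loss samples, in their original order."""
    if not 0.0 < keep_fraction <= 1.0:
        raise ValueError("keep_fraction must be in (0, 1]")
    num_keep = max(1, math.ceil(len(per_sample_losses) * keep_fraction))
    # Rank sample indices from highest to lowest loss.
    ranked = sorted(
        range(len(per_sample_losses)),
        key=lambda i: per_sample_losses[i],
        reverse=True,
    )
    return sorted(ranked[:num_keep])


losses = [0.05, 2.30, 0.10, 1.70]
print(sift_batch(losses, keep_fraction=0.5))  # [1, 3]
```

In this toy batch, only the two samples with the largest losses would be used for the weight update; the low-loss samples at indices 0 and 2 are skipped.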

## Notebook CI Test Results

This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.


![This us-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-1/training|smart_sifting|Image_Classification_VIT|Train_Image_classification.ipynb)

![This us-east-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-2/training|smart_sifting|Image_Classification_VIT|Train_Image_classification.ipynb)

![This us-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-1/training|smart_sifting|Image_Classification_VIT|Train_Image_classification.ipynb)

![This ca-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ca-central-1/training|smart_sifting|Image_Classification_VIT|Train_Image_classification.ipynb)

![This sa-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/sa-east-1/training|smart_sifting|Image_Classification_VIT|Train_Image_classification.ipynb)

![This eu-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-1/training|smart_sifting|Image_Classification_VIT|Train_Image_classification.ipynb)

![This eu-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-2/training|smart_sifting|Image_Classification_VIT|Train_Image_classification.ipynb)

![This eu-west-3 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-3/training|smart_sifting|Image_Classification_VIT|Train_Image_classification.ipynb)

![This eu-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-central-1/training|smart_sifting|Image_Classification_VIT|Train_Image_classification.ipynb)

![This eu-north-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-north-1/training|smart_sifting|Image_Classification_VIT|Train_Image_classification.ipynb)

![This ap-southeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-1/training|smart_sifting|Image_Classification_VIT|Train_Image_classification.ipynb)

![This ap-southeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-2/training|smart_sifting|Image_Classification_VIT|Train_Image_classification.ipynb)

![This ap-northeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-1/training|smart_sifting|Image_Classification_VIT|Train_Image_classification.ipynb)

![This ap-northeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-2/training|smart_sifting|Image_Classification_VIT|Train_Image_classification.ipynb)

![This ap-south-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-south-1/training|smart_sifting|Image_Classification_VIT|Train_Image_classification.ipynb)
