# Vision Transformer - PyTorch
This notebook shows how to fine-tune a pretrained HuggingFace Vision Transformer (ViT) PyTorch model on AWS Trainium (trn1 instances) using the Neuron SDK.
The original implementation is provided by HuggingFace.

The example has 2 stages:
1. Compile the model for the AWS Trainium device with the `neuron_parallel_compile` utility.
1. Run the fine-tuning script to train the model on an image classification task. The training job uses 2 data-parallel workers to speed up training. On a larger instance (trn1.32xlarge) you can increase the worker count to 8 or 32.

It has been tested on a trn1.2xlarge instance.
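
The worker counts mentioned above can be checked against the instance size before launching a job. The following is a minimal sketch, assuming one worker per NeuronCore and the published core counts of 2 for trn1.2xlarge and 32 for trn1.32xlarge. The table and the helper `valid_worker_counts` are illustrative names, not part of the Neuron SDK.

```python
# Assumed NeuronCore counts per instance type (illustrative lookup table).
NEURON_CORES = {"trn1.2xlarge": 2, "trn1.32xlarge": 32}

# Worker counts mentioned in this notebook.
CANDIDATE_WORKERS = (2, 8, 32)


def valid_worker_counts(instance_type):
    """Return the candidate worker counts that fit on the given instance."""
    cores = NEURON_CORES[instance_type]
    return [n for n in CANDIDATE_WORKERS if n <= cores]


print(valid_worker_counts("trn1.2xlarge"))
print(valid_worker_counts("trn1.32xlarge"))
```

Under these assumptions only 2 workers fit on a trn1.2xlarge, which matches the configuration this notebook was tested with.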

**Reference:** https://huggingface.co/google/vit-base-patch16-224

## 1) Install dependencies

In [None]:
# Set the pip repository to point to the Neuron repository
%pip config set global.extra-index-url https://pip.repos.neuron.amazonaws.com
# now restart the kernel

In [None]:
# Install the Python packages used by the fine-tuning script
%pip install -U "numpy<=1.20.0" "protobuf<4" "transformers==4.27.3" datasets scikit-learn
# use --force-reinstall if you're facing some issues while loading the modules
# now restart the kernel again
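
After restarting the kernel, it can help to confirm which versions were actually installed. This is a minimal sketch using the standard library's `importlib.metadata`. The helper name `installed_version` is hypothetical, and the package list simply mirrors the `%pip install` line above.

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version string, or None if the package is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


# Packages named in the install cell above.
for pkg in ["numpy", "protobuf", "transformers", "datasets", "scikit-learn"]:
    version = installed_version(pkg)
    print(f"{pkg}: {version if version is not None else 'not installed'}")
```

If `transformers` does not report 4.27.3, rerun the install cell (optionally with `--force-reinstall`) and restart the kernel.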

In [None]:
# Clone transformers from GitHub
!git clone https://github.com/huggingface/transformers --branch v4.27.3

## 2) Set the parameters

In [None]:
# Parameters
model_name = "google/vit-base-patch16-224-in21k"
extra_pip_packages = ""
extra_yum_packages = ""
env_var_options = "XLA_USE_BF16=1 NEURON_CC_FLAGS=\"--cache_dir=./compiler_cache --model-type=transformer\""
num_workers = 2
task_name = "image-classification"
dataset_name = "beans"
transformers_version = "4.27.3"
model_base_name = "vit"
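
`env_var_options` is a shell-style string of `NAME=value` assignments that is prepended to the commands in the next sections. As a hedged sketch, the same string can be parsed into a dictionary with `shlex`, which is useful for inspecting the values or for passing them through the `env=` argument of `subprocess.run` instead of relying on the shell. The helper `parse_env_assignments` is a hypothetical name introduced for illustration.

```python
import shlex

# Same value as in the Parameters cell above.
env_var_options = 'XLA_USE_BF16=1 NEURON_CC_FLAGS="--cache_dir=./compiler_cache --model-type=transformer"'


def parse_env_assignments(assignments):
    """Split shell-style NAME=value pairs into a dict, honoring quotes."""
    return dict(token.split("=", 1) for token in shlex.split(assignments))


env_vars = parse_env_assignments(env_var_options)
for name, value in env_vars.items():
    print(f"{name} -> {value}")
```

To use the result with `subprocess.run`, merge it into a copy of `os.environ` (for example `{**os.environ, **env_vars}`) and drop the `{env_var_options}` prefix from the command string.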

## 3) Compile the model with neuron_parallel_compile

Compiling with 2 workers may take about 45 minutes.

In [None]:
%%time
import subprocess
print("Compile model")
COMPILE_CMD = f"""{env_var_options} neuron_parallel_compile torchrun --nproc_per_node={num_workers} \
    transformers/examples/pytorch/image-classification/run_image_classification.py \
    --model_name_or_path {model_name} \
    --dataset_name {dataset_name} \
    --do_train \
    --num_train_epochs 2 \
    --per_device_train_batch_size 8 \
    --learning_rate 2e-5 \
    --save_strategy epoch \
    --save_total_limit 1 \
    --seed 1337 \
    --remove_unused_columns False \
    --overwrite_output_dir \
    --output_dir {model_base_name}-{task_name}"""

print(f'Running command: \n{COMPILE_CMD}')
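# subprocess.call returns the exit code; check_call would raise on failure,
# so the error branch below would never run.
if subprocess.call(COMPILE_CMD, shell=True):
    print("There was an error with the compilation command")
else:
    print("Compilation Success!!!")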
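
A rough way to confirm that the compile stage produced artifacts is to count the files under the directory passed to `--cache_dir` (`./compiler_cache` in this notebook). The internal layout of the cache is not described here, so this sketch only counts files and makes no assumption about their names. `count_cache_files` is a hypothetical helper.

```python
from pathlib import Path


def count_cache_files(cache_dir="./compiler_cache"):
    """Count regular files below cache_dir; return 0 if it does not exist."""
    root = Path(cache_dir)
    if not root.is_dir():
        return 0
    return sum(1 for path in root.rglob("*") if path.is_file())


print(f"Files in compiler cache: {count_cache_files()}")
```

A count of 0 after the compile cell has finished suggests that the cache directory was not written where expected.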

## 4) Fine-tune the model

Fine-tuning takes about 5 minutes for 5 epochs with 2 workers.

In [None]:
%%time
import subprocess  # imported again so this cell also runs after a kernel restart
print("Train model")
RUN_CMD = f"""{env_var_options} torchrun --nproc_per_node={num_workers} \
    transformers/examples/pytorch/image-classification/run_image_classification.py \
    --model_name_or_path {model_name} \
    --dataset_name {dataset_name} \
    --per_device_train_batch_size 8 \
    --per_device_eval_batch_size 8 \
    --do_train \
    --do_eval \
    --remove_unused_columns False \
    --learning_rate 2e-5 \
    --num_train_epochs 5 \
    --logging_strategy steps \
    --logging_steps 10 \
    --evaluation_strategy epoch \
    --save_strategy epoch \
    --load_best_model_at_end True \
    --save_total_limit 3 \
    --seed 1337 \
    --overwrite_output_dir \
    --output_dir {model_base_name}-{task_name}"""

print(f'Running command: \n{RUN_CMD}')
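# subprocess.call returns the exit code; check_call would raise on failure,
# so the error branch below would never run.
if subprocess.call(RUN_CMD, shell=True):
    print("There was an error with the fine-tune command")
else:
    print("Fine-tune Successful!!!")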
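
Once fine-tuning has finished, the evaluation metrics can be read back from the output directory. This is a minimal sketch, assuming the example script saved its evaluation metrics as `eval_results.json` in `--output_dir` (here `vit-image-classification`, from `{model_base_name}-{task_name}`). The helper `read_metrics` is a hypothetical name, and the metric keys in the file are whatever the script wrote.

```python
import json
from pathlib import Path


def read_metrics(output_dir, filename="eval_results.json"):
    """Load a metrics JSON file from output_dir, or return None if absent."""
    path = Path(output_dir) / filename
    if not path.is_file():
        return None
    with path.open() as f:
        return json.load(f)


# Assumed output directory: {model_base_name}-{task_name} from the Parameters cell.
metrics = read_metrics("vit-image-classification")
if metrics is None:
    print("No eval_results.json found; check that fine-tuning ran with --do_eval")
else:
    for key, value in sorted(metrics.items()):
        print(f"{key}: {value}")
```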