# VisionTextDualEncoder and CLIP model training - PyTorch
This notebook shows how to fine-tune a pretrained Hugging Face CLIP PyTorch model on AWS Trainium (trn1 instances) using the Neuron SDK.
The original implementation is provided by Hugging Face.

The example has 2 stages:
1. Compile the model with the `neuron_parallel_compile` utility so that it can run on the AWS Trainium device.
1. Run the fine-tuning script to train the model on the contrastive image-text task. The training job uses 2 data-parallel workers to speed up training. On a larger instance (trn1.32xlarge) you can increase the worker count to 8 or 32.

It has been tested on trn1.2xlarge and trn1.32xlarge instances.

**Reference:** https://huggingface.co/openai/clip-vit-base-patch32

## 1) Install dependencies

In [None]:
# Point pip at the Neuron repository
%pip config set global.extra-index-url https://pip.repos.neuron.amazonaws.com
# now restart the kernel

In [None]:
# Install the Python packages this example depends on
%pip install -U "numpy<=1.20.0" "protobuf<4" "transformers==4.27.3" datasets scikit-learn
# Add --force-reinstall if you run into problems loading the modules
# Now restart the kernel again
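
After restarting the kernel, it can help to confirm which versions were actually installed. The following is a minimal sketch using the standard-library `importlib.metadata`; the helper name `installed_version` is illustrative and not part of the original notebook.

```python
from importlib.metadata import PackageNotFoundError, version

def installed_version(package: str):
    """Return the installed version string of a package, or None if it is absent."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None

# Packages named in the install cell above
for pkg in ("numpy", "protobuf", "transformers", "datasets", "scikit-learn"):
    print(f"{pkg}: {installed_version(pkg) or 'not installed'}")
```

If `transformers` does not report 4.27.3, rerunning the install cell and restarting the kernel is the first thing to try.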

## 2) Set the parameters

In [None]:
# Parameters
model_name = "openai/clip-vit-base-patch32"
text_model_name = "roberta-base"
env_var_options = "MALLOC_ARENA_MAX=64 XLA_USE_BF16=1 NEURON_CC_FLAGS=\"--cache_dir=./compiler_cache\""
num_workers = 2
task_name = "contrastive-image-text"
dataset_name = "ydshieh/coco_dataset_script"
transformers_version = "4.27.3"
model_base_name = "clip"
per_device_train_batch_size = 64
per_device_eval_batch_size = 64
learning_rate = 5e-5
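
Because training is data parallel, each worker processes its own batch, so the number of samples consumed per optimizer step grows with `num_workers`. The sketch below illustrates that arithmetic, assuming no gradient accumulation; the helper `global_batch_size` is hypothetical and is not used by the training script.

```python
def global_batch_size(per_device_batch_size: int, num_workers: int) -> int:
    """Samples consumed per optimizer step when every worker holds one batch."""
    return per_device_batch_size * num_workers

# Worker counts mentioned in this notebook: 2 (default), 8 or 32 (trn1.32xlarge)
for workers in (2, 8, 32):
    print(f"{workers} workers -> global batch size {global_batch_size(64, workers)}")
```

Keep this in mind when raising the worker count: the per-device batch size stays at 64, but the effective batch size does not.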

## 3) Download COCO dataset
This example uses the COCO dataset (2017) through a custom dataset script, which requires you to download the COCO dataset manually before training.

In [None]:
!mkdir -p data
!wget http://images.cocodataset.org/zips/train2017.zip -P data
!wget http://images.cocodataset.org/zips/val2017.zip -P data
!wget http://images.cocodataset.org/zips/test2017.zip -P data
!wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip -P data
!wget http://images.cocodataset.org/annotations/image_info_test2017.zip -P data
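
The downloads are large, and an interrupted `wget` is easy to miss. As a quick sanity check, the sketch below lists any of the five archives that are absent from the `data` directory. It only checks that the files exist, not that they are complete, and the helper name `missing_archives` is illustrative.

```python
from pathlib import Path

# File names taken from the wget commands above
EXPECTED_ARCHIVES = [
    "train2017.zip",
    "val2017.zip",
    "test2017.zip",
    "annotations_trainval2017.zip",
    "image_info_test2017.zip",
]

def missing_archives(data_dir="data", expected=EXPECTED_ARCHIVES):
    """Return the expected archive names that are not present in data_dir."""
    data_path = Path(data_dir)
    return [name for name in expected if not (data_path / name).is_file()]

missing = missing_archives()
print("All archives present" if not missing else f"Missing: {missing}")
```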

## 4) Compile the model with neuron_parallel_compile

In [None]:
%%time
import subprocess
print("Compile model")
COMPILE_CMD = f"""{env_var_options} neuron_parallel_compile torchrun --nproc_per_node={num_workers} \
    run_clip.py \
    --model_name_or_path {model_name} \
    --text_model_name_or_path {text_model_name} \
    --data_dir $PWD/data \
    --dataset_config_name=2017 \
    --dataset_name {dataset_name} \
    --image_column image_path \
    --caption_column caption \
    --do_train \
    --max_steps 10 \
    --num_train_epochs 2 \
    --per_device_train_batch_size {per_device_train_batch_size} \
    --per_device_eval_batch_size {per_device_eval_batch_size} \
    --learning_rate {learning_rate} \
    --warmup_steps 0 \
    --weight_decay 0.1 \
    --save_strategy epoch \
    --save_total_limit 1 \
    --seed 1337 \
    --remove_unused_columns False \
    --overwrite_output_dir \
    --output_dir {model_base_name}-{task_name}"""

print(f'Running command: \n{COMPILE_CMD}')
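# subprocess.call returns the exit code; check_call would raise on failure
# and the error branch below would never run.
if subprocess.call(COMPILE_CMD, shell=True):
    print("There was an error with the compilation command")
else:
    print("Compilation Success!!!")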

## 5) Fine-tune the model

In [None]:
%%time
import subprocess  # repeated here so the cell also runs after a kernel restart
print("Train model")
RUN_CMD = f"""{env_var_options} torchrun --nproc_per_node={num_workers} \
    run_clip.py \
    --model_name_or_path {model_name} \
    --text_model_name_or_path {text_model_name} \
    --data_dir $PWD/data \
    --dataset_config_name=2017 \
    --dataset_name {dataset_name} \
    --image_column image_path \
    --caption_column caption \
    --do_train \
    --do_eval \
    --num_train_epochs 2 \
    --per_device_train_batch_size {per_device_train_batch_size} \
    --per_device_eval_batch_size {per_device_eval_batch_size} \
    --learning_rate {learning_rate} \
    --warmup_steps 0 \
    --weight_decay 0.1 \
    --save_strategy epoch \
    --save_total_limit 1 \
    --seed 1337 \
    --remove_unused_columns False \
    --overwrite_output_dir \
    --output_dir {model_base_name}-{task_name}"""

print(f'Running command: \n{RUN_CMD}')
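# subprocess.call returns the exit code; check_call would raise on failure
# and the error branch below would never run.
if subprocess.call(RUN_CMD, shell=True):
    print("There was an error with the fine-tune command")
else:
    print("Fine-tune Successful!!!")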
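
The compile and fine-tune cells share one pattern: build a shell command, run it, and report whether it succeeded. A minimal sketch of a helper for that pattern follows. `run_shell` is a hypothetical name, not part of the notebook or of the Neuron SDK, and it relies only on `subprocess.run` from the standard library.

```python
import subprocess

def run_shell(cmd: str, label: str) -> bool:
    """Run cmd through the shell and return True if it exited with status 0."""
    result = subprocess.run(cmd, shell=True)
    if result.returncode != 0:
        print(f"{label} failed with exit code {result.returncode}")
        return False
    print(f"{label} succeeded")
    return True

# Example with a trivial command; the notebook's commands are passed the same way
run_shell("echo hello", "Demo command")
```

Returning a boolean rather than raising lets a later cell decide whether to proceed, for example by skipping fine-tuning when compilation failed.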