# UNET training - PyTorch 2.1
This notebook shows how to fine-tune a pretrained UNET PyTorch model on AWS Trainium (trn1 instances) using the AWS Neuron SDK.\
The model implementation is provided by milesial/Pytorch-UNet.



The example has two stages:
1. Compile the model for the AWS Trainium device using the `neuron_parallel_compile` utility.
2. Run the fine-tuning script to train the model on an image segmentation task. The training job uses 32 data-parallel workers to speed up training.

It has been tested on a trn1.32xlarge instance using 256 x 256 input images for binary segmentation with a batch size of 4.

**Reference:** 
milesial, U-Net: Semantic segmentation with PyTorch, GitHub repository
https://github.com/milesial/Pytorch-UNet

## 1) Install dependencies

In [None]:
# Install supporting Python packages (the Neuron compiler and Neuron/XLA packages are assumed to be installed already)
%pip install -U "timm" "tensorboard" torchvision==0.16.*
%pip install -U "Pillow" "glob2" "scikit-learn"
# Add --force-reinstall if you run into problems loading the modules
# Restart the kernel after installation so the new packages are picked up

## 2) Download Carvana dataset
This example uses the Carvana dataset, which must be downloaded manually before training:\
 https://www.kaggle.com/competitions/carvana-image-masking-challenge/data 

1. Download train.zip and train_masks.zip.
2. Unzip both archives.
3. Create a carvana directory and place the extracted folders in it.
4. The resulting directory structure should be:\
carvana/train/\
carvana/train_masks/

dataset_path = \<Path to Carvana directory\>
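
Before starting a long compile or training run, it can help to confirm that the directory layout above is in place. The following is a minimal sketch, not part of the original notebook: `check_carvana_layout` is a hypothetical helper that only checks that the two subdirectories exist and counts the files in each. It makes no assumption about file names or image formats.

```python
from pathlib import Path


def check_carvana_layout(dataset_path):
    """Return the number of files found in train/ and train_masks/.

    Hypothetical helper: raises FileNotFoundError if either
    subdirectory from the layout above is missing.
    """
    root = Path(dataset_path)
    counts = {}
    for sub in ("train", "train_masks"):
        folder = root / sub
        if not folder.is_dir():
            raise FileNotFoundError(f"Missing directory: {folder}")
        # Count regular files only; nested directories are ignored.
        counts[sub] = sum(1 for p in folder.iterdir() if p.is_file())
    return counts
```

Calling `check_carvana_layout(dataset_path)` returns a dictionary with one count per subdirectory. A count of zero, or a large mismatch between the two, suggests the archives were not extracted where expected.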

## 3) Set the parameters

In [None]:
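num_workers = 32             # torchrun processes (--nproc_per_node), one per data-parallel worker
dataloader_num_workers = 2   # passed to train.py as --num_workers
image_dim = 256              # input images are image_dim x image_dim
num_epochs = 20              # number of epochs for the fine-tuning run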

In [None]:
learning_rate = 2e-4
batch_size = 4
# Environment variables prepended to the shell commands in the cells below
env_var_options = (
    "NEURON_RT_ASYNC_EXEC_MAX_INFLIGHT_REQUESTS=3 "
    "NEURON_CC_FLAGS='--cache_dir=./compiler_cache --model-type=cnn-training'"
)
dataset_path = "./carvana/"
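
With data-parallel training, the number of samples consumed per optimizer step depends on both the worker count and the batch size. The sketch below works through that arithmetic for the values used here. It assumes that `--batch_size` is applied per worker, which is typical for `torchrun`-based scripts but is not confirmed by this notebook, since `train.py` is not shown. `global_batch_size` is a hypothetical helper name.

```python
def global_batch_size(num_workers, batch_size):
    """Samples processed per step across all data-parallel workers.

    Assumes batch_size is the per-worker batch size (an assumption;
    train.py is not shown in this notebook).
    """
    return num_workers * batch_size


# Values from the parameter cells above: 32 workers, batch size 4.
print(global_batch_size(num_workers=32, batch_size=4))  # 128
```

If the assumption holds, changing `num_workers` changes the effective batch size, so the learning rate may need to be revisited when running with fewer workers.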

## 4) Compile the model with neuron_parallel_compile

In [None]:
%%time
import subprocess
print("Compile model")
COMPILE_CMD = f"""{env_var_options} neuron_parallel_compile torchrun --nproc_per_node={num_workers} \
   train.py \
    --num_workers {dataloader_num_workers} \
    --image_dim {image_dim} \
    --num_epochs 2 \
    --batch_size {batch_size} \
    --drop_last \
    --data_dir {dataset_path} \
    --lr {learning_rate}"""

print(f'Running command: \n{COMPILE_CMD}')
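# subprocess.call returns the exit code; check_call would raise on failure, so the error branch would never run
if subprocess.call(COMPILE_CMD, shell=True) != 0: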
   print("There was an error with the compilation command")
else:
   print("Compilation Success!!!")
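
The command above relies on `shell=True` so that the leading `VAR=value` assignments are interpreted by the shell. One possible alternative, sketched here under stated assumptions, is to split those assignments off and pass them through the `env` argument of `subprocess.run`, which avoids shell interpretation of the rest of the command. `split_env_and_command` and `run_command` are hypothetical helpers, not part of the notebook or of the Neuron SDK, and the parsing assumes POSIX-style quoting.

```python
import os
import shlex
import subprocess


def split_env_and_command(command):
    """Split 'VAR=value ... prog args' into (env overrides, argv list)."""
    tokens = shlex.split(command)
    env = {}
    while tokens:
        key, sep, value = tokens[0].partition("=")
        # Stop at the first token that is not a NAME=value assignment.
        if not sep or not key.isidentifier():
            break
        env[key] = value
        tokens = tokens[1:]
    return env, tokens


def run_command(command):
    """Run the command without a shell and return its exit code."""
    env, argv = split_env_and_command(command)
    # Merge the overrides into the current environment.
    return subprocess.run(argv, env={**os.environ, **env}).returncode
```

With these helpers, `run_command(COMPILE_CMD)` would take the place of the `shell=True` call. Quoted values such as the `NEURON_CC_FLAGS` string stay intact as a single environment value because `shlex.split` honors the quotes.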

## 5) Compile and Fine-tune the model

In [None]:
%%time
import subprocess
print("Fine-tune model")
TRAIN_CMD = f"""{env_var_options} torchrun --nproc_per_node={num_workers} \
    train.py \
    --num_workers {dataloader_num_workers} \
    --image_dim {image_dim} \
    --num_epochs {num_epochs} \
    --batch_size {batch_size} \
    --do_eval \
    --drop_last \
    --data_dir {dataset_path} \
    --lr {learning_rate}"""

print(f'Running command: \n{TRAIN_CMD}')
# subprocess.call returns the exit code; check_call would raise on failure, so the error branch would never run
if subprocess.call(TRAIN_CMD, shell=True) != 0:
   print("There was an error with the fine-tune command")
else:
   print("Fine-tune Successful!!!")