# Training a YOLOv7 model on SageMaker

## Set up the environment

In [None]:
from sagemaker import Session, get_execution_role
from sagemaker.pytorch.estimator import PyTorch

# IAM role the training job runs under (the notebook's execution role).
role = get_execution_role()

# S3 location of the training dataset, handed to the job as the "training" channel.
# Defaults to `yolo/input` in the SageMaker default bucket of the current account and
# region (`sagemaker-<region>-<account-id>`; `default_bucket()` creates it if missing)
# instead of a hard-coded, account-specific bucket. Replace it with your own S3 URI
# if the dataset lives elsewhere.
TRAINING_DATA_S3_URI = f"s3://{Session().default_bucket()}/yolo/input"
inputs = {"training": TRAINING_DATA_S3_URI}

## Get the source code

We clone the [YOLOv7 GitHub repo](https://github.com/WongKinYiu/yolov7) into a `source_dir` folder and then copy the contents of the `scripts` folder into it.

The `scripts` folder contains the `data.yaml` file for our dataset. It also contains slightly modified versions of `train.py` and `test.py` so that they work smoothly with SageMaker.

In [None]:
# Fetch the YOLOv7 code into `source_dir`, then overlay the SageMaker-adapted files
# from `scripts`. The clone is skipped when the folder already exists, so the cell is
# safe to re-run. Unlike `!git clone`, a failed clone raises instead of letting the
# notebook continue.
# NOTE: this clones the repository's default branch; pin a commit for reproducible runs.
import shutil
import subprocess
from pathlib import Path

repo_dir = Path("source_dir")
if not repo_dir.exists():
    subprocess.run(
        ["git", "clone", "https://github.com/WongKinYiu/yolov7.git", str(repo_dir)],
        check=True,
    )

# Python equivalent of `cp scripts/* source_dir/`: copy regular, non-hidden files
# (matching the shell glob), overwriting the upstream `train.py` / `test.py`.
for script in sorted(Path("scripts").iterdir()):
    if script.is_file() and not script.name.startswith("."):
        shutil.copy2(script, repo_dir / script.name)

## Get pre-trained weights for YOLOv7

In [None]:
# Download the pre-trained YOLOv7 checkpoint into `source_dir`, where `train.py`
# expects the file named by the "weights" hyperparameter.
# The download is skipped when the file already exists. Re-running `wget -P` would
# instead save a duplicate `yolov7_training.pt.1`, which would also be uploaded with
# `source_dir`. The file is written to a `.part` path and renamed only on success, so
# an interrupted download is not mistaken for a complete checkpoint.
import urllib.request
from pathlib import Path

WEIGHTS_URL = "https://github.com/WongKinYiu/yolov7/releases/download/v0.1/yolov7_training.pt"
weights_path = Path("source_dir") / "yolov7_training.pt"

if not weights_path.exists():
    weights_path.parent.mkdir(parents=True, exist_ok=True)
    partial_path = weights_path.with_name(weights_path.name + ".part")
    try:
        urllib.request.urlretrieve(WEIGHTS_URL, partial_path)
        partial_path.replace(weights_path)
    finally:
        # Don't leave a half-downloaded file behind to be uploaded with source_dir.
        partial_path.unlink(missing_ok=True)

## Train

In [None]:
# Hyperparameters for the YOLOv7 `train.py` entry point. In SageMaker script mode each
# entry is passed to the script as a `--<key> <value>` command-line argument, so the
# keys must match the flags that script accepts.
hyperparameters = {
    "name": "yolov7-custom",
    "workers": "8",                   # presumably data-loader worker processes
    "device": "0",                    # GPU index; ml.g5.2xlarge has a single GPU
    "batch-size": "24",
    "epochs": 300,
    "data": "data.yaml",              # dataset config copied in from `scripts`
    "weights": "yolov7_training.pt",  # pre-trained checkpoint downloaded into `source_dir`
    "save_dir": "/opt/ml/model",      # SageMaker uploads this directory as the model artifact
}

# Estimator that runs `source_dir/train.py` in the PyTorch 1.11 / Python 3.8
# framework container on a single ml.g5.2xlarge instance.
estimator = PyTorch(
    entry_point="train.py",
    source_dir="source_dir",
    hyperparameters=hyperparameters,
    framework_version="1.11.0",
    py_version="py38",
    instance_type="ml.g5.2xlarge",
    instance_count=1,
    role=role,
    disable_profiler=True,  # turn off SageMaker Debugger profiling for this job
)

In [None]:
# Launch the training job with the dataset channel(s) defined in the setup cell.
estimator.fit(inputs=inputs)