# Training a YOLO implementation in PyTorch

## Step 1: Preparation

### Upgrades

In [2]:
# Upgrade the SageMaker Python SDK before it is imported in the next cell.
%pip install --upgrade sagemaker

### Import requirements

In [10]:
import boto3
import sagemaker
from sagemaker.processing import ScriptProcessor, ProcessingInput, ProcessingOutput
from sagemaker.pytorch import PyTorch
from sagemaker.debugger import DebuggerHookConfig, CollectionConfig

In [None]:
import os

# Object name of the Pascal VOC 2012 trainval archive stored under datasets/ in the bucket.
VOC_TAR = "VOCtrainval_11-May-2012.tar"
# Bucket that holds the raw archive and every output of this pipeline. Set the
# SAGEMAKER_BUCKET_NAME environment variable to use another bucket without editing
# the notebook; otherwise the original hard-coded bucket is used.
# TODO: look the bucket up by name prefix instead of relying on a hard-coded default.
BUCKET_NAME = os.environ.get("SAGEMAKER_BUCKET_NAME", "sagemaker-20250212110251")

### Initialize session and role

In [6]:
# Shared SageMaker session (its region is used below to look up container images)
# and the IAM execution role that the processing and training jobs run under.
sagemaker_session, role = sagemaker.Session(), sagemaker.get_execution_role()

### S3 locations

In [None]:
# Key prefix (no trailing slash) under which the processing job writes its output;
# the same location is later fed to training as the 'processed' channel.
processed_prefix = 'datasets/processed/VOC'
processed_data_s3_uri = f's3://{BUCKET_NAME}/{processed_prefix}'

# Raw VOC tarball consumed by the processing job.
raw_data_s3_uri = f's3://{BUCKET_NAME}/datasets/{VOC_TAR}'

# Destination for tensors captured by the SageMaker Debugger hook.
s3_debugger_output_path = f's3://{BUCKET_NAME}/debugger-output'

## Step 2: Preprocessing Job

Create a ScriptProcessor to run a data preprocessing script.
In this example, "preprocess.py" is a script you write that:
- Extracts the VOC tar file, which SageMaker downloads into `/opt/ml/processing/input` before the script starts.
- Splits the dataset into training/validation/test sets.
- Converts the dataset into the format your training script expects (for example, resized images and annotation files; the job passes `--image_size 448`) and writes it to `/opt/ml/processing/output`, which SageMaker uploads to `processed_data_s3_uri`.
Save this script in a local folder (e.g., "preprocessing"); the `code` argument of `script_processor.run` points to it.

In [18]:
# The image lookup and the processor share one instance type so they cannot drift apart.
processing_instance_type = 'ml.m5.xlarge'

# PyTorch 2.1.0 / Python 3.10 training-scope container, resolved for a CPU
# instance type; it is used here to run preprocess.py.
image_uri = sagemaker.image_uris.retrieve(
    framework='pytorch',
    region=sagemaker_session.boto_region_name,
    version='2.1.0',
    py_version='py310',
    instance_type=processing_instance_type,
    image_scope='training',
)

# Runs a single script with `python3` inside the container above on one instance.
script_processor = ScriptProcessor(
    role=role,
    image_uri=image_uri,
    command=['python3'],
    instance_type=processing_instance_type,
    instance_count=1,
    sagemaker_session=sagemaker_session,
)

### Run the processing job

In [None]:
s3 = boto3.client('s3')

# Skip the processing job when a previous run already wrote output under the
# processed prefix. The trailing '/' stops sibling prefixes (e.g.
# 'datasets/processed/VOC_old/') from counting as existing output, and
# MaxKeys=1 suffices because we only need to know whether any object exists.
response = s3.list_objects_v2(
    Bucket=BUCKET_NAME,
    Prefix=processed_prefix.rstrip('/') + '/',
    MaxKeys=1,
)
if response.get('KeyCount', 0) > 0:
    print("Processed data already exists. Skipping processing job.")
else:
    print("Processed data not found. Running processing job.")
    script_processor.run(
        code='preprocessing/preprocess.py',
        inputs=[
            # The raw tarball is staged into this container directory.
            ProcessingInput(
                source=raw_data_s3_uri,
                destination='/opt/ml/processing/input'
            )
        ],
        outputs=[
            # Whatever preprocess.py writes here is uploaded to processed_data_s3_uri.
            ProcessingOutput(
                output_name='processed_data',
                source='/opt/ml/processing/output',
                destination=processed_data_s3_uri
            )
        ],
        arguments=[
            '--image_size', '448'
        ]
    )

    print("Data preprocessing complete. Processed data available at:",
          processed_data_s3_uri)

## Step 3: Launch a Training Job with Debugger

### Define a Debugger hook configuration. This tells SageMaker which collections to capture.

In [None]:
# Tensor collections for the Debugger hook to capture:
# training/validation losses, gradients and weight values.
debugger_collections = [
    CollectionConfig(collection_name)
    for collection_name in ("losses", "gradients", "weights")
]

# NOTE(review): confirm that the PyTorch 2.1 training image supports Debugger hooks.
debugger_hook_config = DebuggerHookConfig(
    s3_output_path=s3_debugger_output_path,
    # Saves debugger tensors every 100 steps (hook parameter values are strings).
    hook_parameters={"save_interval": "100"},
    collection_configs=debugger_collections,
)

### Create a PyTorch estimator. Ensure your training code (e.g., train.py and supporting modules) is available in a source directory.

In [None]:
# PyTorch script-mode estimator for the YOLO training job.
# - No image_uri: `image_uri` from the processing cell was resolved for the CPU
#   instance ml.m5.xlarge, so reusing it would run GPU training on a CPU image.
#   With framework_version/py_version the SDK resolves the image that matches
#   instance_type below.
# - `command=` and `debugger_rule_configs=` are not PyTorch estimator parameters.
#   The SageMaker training toolkit launches `entry_point` itself, and Debugger
#   rules are passed through `rules`. For multi-GPU or multi-node runs, configure
#   `distribution=` instead of a launcher command.
estimator = PyTorch(
    entry_point='training/train.py',
    role=role,
    instance_count=1,
    # Alternative GPU instance: 'ml.p3.2xlarge'.
    instance_type='ml.g4dn.xlarge',
    framework_version='2.1.0',
    py_version='py310',
    # Same session as the processing job, for a consistent region/credentials.
    sagemaker_session=sagemaker_session,
    debugger_hook_config=debugger_hook_config,
    hyperparameters={
        'epochs': 10,
        'batch-size': 32,
        'learning-rate': 0.001,
        # Local mount point of the 'processed' input channel passed to fit().
        'data-dir': '/opt/ml/input/data/processed',
        'load-weights': '',
    },
    # No built-in Debugger rules; the hook above only saves tensors to S3.
    rules=[],
    # Bundled with the training code; presumably installed by the training
    # toolkit before train.py starts (TODO confirm for this image).
    dependencies=['requirements.txt'],
)

### Define channel for preprocessed data.

The processed data is output from the processing job and will be used as input for training.

In [None]:
# Single training input channel named 'processed', backed by the processing-job output;
# inside the container it appears under /opt/ml/input/data/processed.
data_channels = {'processed': processed_data_s3_uri}

### Launch the training job

In [None]:
# Start the training job with the processed-data channel as its input.
estimator.fit(data_channels)