# Armor U-Net Training on Google Colab

This notebook installs the project dependencies, syncs the repository, configures paths/hyperparameters, and runs `train_armor_detector` using the same codebase as the main repo. Adjust the form fields as needed for your dataset location or run length before executing the training cell.

In [None]:
# Check GPU availability (optional but recommended)
# The leading `!` hands the command to the system shell instead of Python.
# No other visible cell reads this output; it is purely informational.
!nvidia-smi

In [None]:
%%capture
# Install the Python dependencies needed by the project
!pip install --quiet pytorch-lightning>=2.2 albumentations>=1.3 wandb>=0.16 pillow matplotlib

In [None]:
#@title (Optional) Mount Google Drive
# Form toggle: leave USE_DRIVE off to skip this cell's work entirely.
USE_DRIVE = False  # @param {type:"boolean"}
if USE_DRIVE:
    # Imported only when requested, so the import never runs while the toggle is off.
    from google.colab import drive
    # Mount point passed to drive.mount; Drive-backed paths for the later
    # configuration cell would live under this directory.
    drive.mount('/content/drive')

In [None]:
#@title Clone or update the repository
import os
import pathlib
import subprocess
import sys

REPO_URL = "https://github.com/YOUR_USERNAME/torch-lightning-with-ray.git"  # @param {type:"string"}
TARGET_DIR = "/content/torch-lightning-with-ray"  # @param {type:"string"}
target_path = pathlib.Path(TARGET_DIR)

# Decide between "pull" and "clone" by looking for git metadata rather than
# for the directory itself: a directory that exists but is not a checkout
# (e.g. created empty by hand or left by an interrupted run) cannot be pulled.
if (target_path / ".git").exists():
    # Existing checkout: bring it up to date in place.
    subprocess.run(["git", "-C", str(target_path), "pull"], check=True)
else:
    # A fresh clone needs a real URL; the shipped default is only a template.
    if "YOUR_USERNAME" in REPO_URL:
        raise ValueError(
            "REPO_URL still contains the placeholder 'YOUR_USERNAME'; "
            "set it to the URL of your repository before running this cell."
        )
    # check=True turns a failed git command into CalledProcessError so the
    # notebook stops here instead of continuing without the code.
    subprocess.run(["git", "clone", REPO_URL, str(target_path)], check=True)

# Run subsequent cells from inside the checkout and make its top-level
# modules importable (the training cell does `from train import ...`).
os.chdir(target_path)
if str(target_path) not in sys.path:
    sys.path.insert(0, str(target_path))
print(f'Working directory: {target_path}')

In [None]:
#@title Configure data/log locations and key hyperparameters
import os
import pathlib

DATA_ROOT = "/content/Dataset_Robomaster-1"  # @param {type:"string"}
CHECKPOINT_DIR = "/content/checkpoints"  # @param {type:"string"}
LOG_DIR = "/content/logs"  # @param {type:"string"}
BATCH_SIZE = 4  # @param {type:"integer"}
MAX_EPOCHS = 5  # @param {type:"integer"}
BASE_CHANNELS = 32  # @param {type:"integer"}
LEARNING_RATE = 1e-4  # @param {type:"number"}

# One mapping drives all three steps below (create, export, echo), keyed by
# the environment-variable name each location is published under.
run_locations = {
    'DATA_ROOT': DATA_ROOT,
    'CHECKPOINT_DIR': CHECKPOINT_DIR,
    'LOG_DIR': LOG_DIR,
}

# Create every configured directory (including missing parents); directories
# that already exist are left untouched.
for location in run_locations.values():
    pathlib.Path(location).mkdir(parents=True, exist_ok=True)

# Export the locations through the process environment.
os.environ.update(run_locations)

# Echo the effective settings so they are recorded in the cell output.
for env_name, location in run_locations.items():
    print(f'{env_name}: {location}')

In [None]:
#@title Authenticate with Weights & Biases
# Runs the W&B login step ahead of the training cell.
# NOTE(review): presumably prompts for an API key when none is already
# configured for this runtime — confirm against the wandb documentation.
import wandb
wandb.login()

In [None]:
#@title Run training
from train import train_armor_detector

# Collect the values chosen in the configuration cell into one mapping so the
# full set of options handed to the trainer can be inspected before the call.
training_options = {
    "data_root": DATA_ROOT,
    "batch_size": BATCH_SIZE,
    "max_epochs": MAX_EPOCHS,
    "learning_rate": LEARNING_RATE,
    "base_channels": BASE_CHANNELS,
    "checkpoint_dir": CHECKPOINT_DIR,
    "log_dir": LOG_DIR,
}

# The entry point returns three objects, kept under these names for any
# follow-up cells.
model, trainer, datamodule = train_armor_detector(**training_options)