# Local Docker run for JAX ViT training

This notebook shows how to run JAX ViT training locally in a Docker container.
It uses a workbench instance with TensorFlow 2.11 and 8 V100 GPUs.
Before running it, upload `train_vit_gpu.Dockerfile` to the home directory of the workbench.

In [None]:
# Build training docker

project = "cloud-nas-260507"
image_tag = "jax-vit-train-gpu-lavrai-test:latest"
train_docker_uri = "gcr.io/{}/{}".format(project, image_tag)

!docker build -f train_vit_gpu.Dockerfile . -t {image_tag}

!docker tag {image_tag} {train_docker_uri}

!docker push {train_docker_uri}

In [None]:
# Command-line arguments passed after the image name in the docker run command.
workdir = "tmp"
# Written as (flag, value) pairs for readability and flattened into the flat
# argv-style list that the run command expects.
docker_args_list = [
    token
    for flag_and_value in (
        ("--config", "vit_jax/configs/augreg.py:R_Ti_16"),
        ("--config.dataset", "tf_flowers"),
        ("--config.pp.train", "train[:90%]"),
        ("--config.pp.test", "train[90%:]"),
        ("--config.batch_eval", "120"),
        ("--config.base_lr", "0.01"),
        ("--config.shuffle_buffer", "1000"),
        ("--config.total_steps", "100"),
        ("--config.warmup_steps", "10"),
        ("--config.accum_steps", "0"),  # Not needed with R+Ti/16 model.
        ("--config.pp.crop", "224"),
        ("--workdir", workdir),
    )
    for token in flag_and_value
]
print(docker_args_list)

In [None]:
# Utility functions.

import io
import subprocess
import sys


def run_command_with_stdout(cmd, job_log_file=None, error_message=""):
    """Runs a command and streams its combined stdout/stderr as it is produced.

    Args:
        cmd: Command to execute, passed unchanged to `subprocess.Popen`.
        job_log_file: Writable text stream that receives every output line as
            soon as it is read. Defaults to `sys.stdout`.
        error_message: Text embedded in the exception raised on failure.

    Returns:
        A `(output, return_code)` tuple: the complete captured output as one
        string, and the exit code of the process (always 0 here, because any
        non-zero code raises instead).

    Raises:
        RuntimeError: If the command finishes with a non-zero return code.
    """
    if job_log_file is None:
        job_log_file = sys.stdout
    buf = io.StringIO()

    with subprocess.Popen(
        cmd,
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # Interleave stderr with stdout.
        universal_newlines=False,
    ) as p:
        # The pipe yields bytes; decode it here. newline="" leaves the line
        # endings emitted by the command untranslated.
        with io.TextIOWrapper(p.stdout, newline="") as out:
            for line in out:
                buf.write(line)
                job_log_file.write(line)
                # Flush per line so the output shows up while the job runs.
                job_log_file.flush()

        # The output pipe is drained, but the process may still be alive.
        # Block until it exits; this replaces a poll loop that called
        # time.sleep() without `time` being imported.
        ret_code = p.wait()

    if ret_code:
        raise RuntimeError(
            "Error: {} with return code {}".format(error_message, ret_code)
        )
    return buf.getvalue(), ret_code

In [None]:
# Run local training.
# The image URI must come directly after the docker options; everything after
# it is passed to the container as arguments.
cmd = ["nvidia-docker", "run", "-t", train_docker_uri, *docker_args_list]
# The helper already streams the container output while it runs. Bind its
# return value to names so the notebook does not echo the whole captured log a
# second time as the cell result.
train_output, train_return_code = run_command_with_stdout(
    cmd, error_message="Failed to run docker locally"
)