# <center>Import libraries</center>

In [1]:
import os
import logging
from tqdm import tqdm
from utils.ssh_utils import SSH
from utils.sftp_utils import SFTP
from utils.line_utils import send_line_message
from utils.json_utils import combine_configs, read_json
from utils.pipeline_utils import remote_tensorboard_logdir

################################ Logger #######################################
ROOT_DIR = os.getcwd()
LOG_DIR = os.path.join(ROOT_DIR, "logs")
# logging's FileHandler does not create missing parent directories, so make
# sure the log directory exists before basicConfig opens the log file.
os.makedirs(LOG_DIR, exist_ok=True)
logging.basicConfig(
    filename=os.path.join(LOG_DIR, "create_delete_logs.log"),
    level=logging.INFO,
    format='%(asctime)s + %(levelname)s + %(message)s'
)
logger = logging.getLogger('create_delete_logger')

########################### VM settings #######################################
# Fill in the VM credentials (or export VM_USERNAME / VM_PASSWORD so they are
# not stored in the notebook); leave the remaining settings unchanged.
PORT = 22
USERNAME = os.environ.get("VM_USERNAME", "your_vm_username")
PASSWORD = os.environ.get("VM_PASSWORD", "your_vm_password")
VM_CONFIG_PATH = os.path.join(ROOT_DIR, "config", "vm_config.json")
VM_PUBLIC_IPS_PATH = os.path.join(ROOT_DIR, "config", "vm_public_ips.json")
# NOTE(review): vm_public_ips.json is (re)written by the `terraform output`
# cell, which runs after this one. On a fresh checkout this file may not exist
# yet, and otherwise it may hold IPs from a previous deployment. Rebuild this
# object after `terraform output` before relying on the IPs.
vm_details_data = combine_configs(VM_CONFIG_PATH, VM_PUBLIC_IPS_PATH)

###############################################################################

# <center>Terraform</center>

In [None]:
# Initialise the Terraform working directory under terraform/ (downloads the
# providers/modules it declares); needed before `plan`/`apply` can run.
!terraform -chdir=terraform/ init

In [None]:
# Preview the infrastructure changes. The plan is not saved (no -out), so the
# later `apply -auto-approve` computes a fresh plan rather than applying this
# exact one; treat this output as an advisory preview.
!terraform -chdir=terraform/ plan

In [2]:
vm_keys = list(read_json(VM_CONFIG_PATH).get("vm_config", {}).keys())
ignore_var = [logger.info(f"Created VM: {k}") for k in vm_keys]

!terraform -chdir=terraform/ apply -auto-approve

In [None]:
# Export all Terraform outputs as JSON to config/vm_public_ips.json, the file
# referenced as VM_PUBLIC_IPS_PATH and merged by combine_configs. Presumably the
# outputs carry each VM's public IP; verify against terraform/ outputs.
# NOTE(review): `vm_details_data` in the first cell is built before this file
# is written, so that copy may be stale until it is rebuilt.
!terraform -chdir=terraform/ output -json > config/vm_public_ips.json

# <center>Python</center>

**File Upload and Model Training**

In [None]:
# TODO Using threads
import posixpath
import shlex
from contextlib import closing

# Insert your pretrained model's URL here!
pretrained_url = "http://download.tensorflow.org/models/object_detection/tf2/20200711/ssd_mobilenet_v2_320x320_coco17_tpu-8.tar.gz"

# Rebuild the combined config here. vm_public_ips.json is written by the
# `terraform output` cell, which runs after the first cell built its copy,
# so that earlier copy can hold stale or missing public IPs.
vm_details_data = combine_configs(VM_CONFIG_PATH, VM_PUBLIC_IPS_PATH)

# Paths on the VM are POSIX paths. Join them with posixpath so they stay
# correct even when this notebook runs on Windows.
REMOTE_REPO_DIR = posixpath.join("/home", USERNAME, "your_repo")


def _list_entries(directory, predicate):
    """Return names in *directory* whose full local path satisfies *predicate*."""
    return [
        name for name in os.listdir(directory)
        if predicate(os.path.join(directory, name))
    ]


def _upload_workspace(ssh_obj, sftp_obj, local_wsp_dir, remote_wsp_dir):
    """Upload label map, pipeline config, JSON files and image folders."""
    local_json_dir = os.path.join(local_wsp_dir, "json")
    local_picture_dir = os.path.join(local_wsp_dir, "picture")
    remote_json_dir = posixpath.join(remote_wsp_dir, "json")
    remote_picture_dir = posixpath.join(remote_wsp_dir, "picture")

    ssh_obj.exec(f"mkdir -p {shlex.quote(remote_wsp_dir)}")
    print("Uploading labelmap and pipeline files ...")
    for filename in ("label_map.pbtxt", "pipeline.yml"):
        sftp_obj.putfile(
            os.path.join(local_wsp_dir, filename),
            posixpath.join(remote_wsp_dir, filename),
        )

    ssh_obj.exec(f"mkdir -p {shlex.quote(remote_json_dir)}")
    print("Uploading jsonfiles ...")
    # Only regular files: putfile cannot upload a directory.
    for jsonfilename in tqdm(_list_entries(local_json_dir, os.path.isfile)):
        sftp_obj.putfile(
            os.path.join(local_json_dir, jsonfilename),
            posixpath.join(remote_json_dir, jsonfilename),
        )

    print("Uploading image files ...")
    for img_subdir in _list_entries(local_picture_dir, os.path.isdir):
        local_picture_subdir = os.path.join(local_picture_dir, img_subdir)
        remote_picture_subdir = posixpath.join(remote_picture_dir, img_subdir)
        ssh_obj.exec(f"mkdir -p {shlex.quote(remote_picture_subdir)}")
        for imgfilename in tqdm(
            _list_entries(local_picture_subdir, os.path.isfile)
        ):
            sftp_obj.putfile(
                os.path.join(local_picture_subdir, imgfilename),
                posixpath.join(remote_picture_subdir, imgfilename),
            )


def _download_pretrained_model(ssh_obj, url):
    """Download and unpack the *url* archive into detection/pretrained-models."""
    pretrained_dir = shlex.quote(
        posixpath.join(REMOTE_REPO_DIR, "detection", "pretrained-models")
    )
    archive = shlex.quote(url.split("/")[-1])
    # The trailing `*` also removes duplicates wget saved as <archive>.1 etc.
    ssh_obj.exec(
        f"mkdir -p {pretrained_dir} && "
        f"cd {pretrained_dir} && "
        f"wget {shlex.quote(url)} && "
        f"tar -xzvf {archive} && "
        f"rm {archive}*"
    )


def _start_training(ssh_obj, model, workspace):
    """Run train_<model>.py detached inside the `trainer` container."""
    ssh_obj.exec(
        "docker exec --detach trainer sh -c "
        f"'cd {model} && python3 train_{model}.py -w {workspace}'"
    )


def _expose_tensorboard(ssh_obj, model, workspace):
    """(Re)start a `watcher` container serving TensorBoard on port 9999."""
    tboard_logdir = remote_tensorboard_logdir(workspace, model)
    # Remove a watcher left over from a previous run. Otherwise `docker run
    # --name watcher` fails with a container-name conflict.
    ssh_obj.exec(
        "docker rm -f watcher >/dev/null 2>&1; "
        "docker run --name watcher --rm --detach -p 9999:9999 "
        f"-v {REMOTE_REPO_DIR}:/your_working_space_in_container "
        "your_container_image:your_container_tag "
        f"tensorboard --logdir={tboard_logdir} --host=0.0.0.0 --port=9999"
    )


print(100 * "#")
for vm_key, vm_data in vm_details_data.items():
    print(47 * "#" + f" {vm_key} " + 48 * "#")
    public_ip = vm_data.get("public_ip", None)
    vm_name = vm_data.get("name", "")
    model = vm_data.get("model", "")
    workspace = vm_data.get("workspace", "")

    if not public_ip:
        # Nothing to connect to; skip instead of failing inside SSH/SFTP.
        print(f"No public IP found for {vm_key}; skipping ...")
        logger.warning("No public IP found for VM %s; skipped", vm_key)
        print(100 * "#")
        continue

    remote_wsp_dir = posixpath.join(
        REMOTE_REPO_DIR, model, "workspace", workspace
    )
    local_wsp_dir = os.path.join("workspace", model, workspace)

    # closing() guarantees both connections are released even if an upload or
    # remote command raises (SSH is closed first, then SFTP).
    with closing(SFTP(public_ip, PORT, USERNAME, PASSWORD)) as sftp_obj, \
            closing(SSH(public_ip, PORT, USERNAME, PASSWORD)) as ssh_obj:
        _upload_workspace(ssh_obj, sftp_obj, local_wsp_dir, remote_wsp_dir)

        print("Downloading pretrained model ...")
        if model == "detection":
            _download_pretrained_model(ssh_obj, pretrained_url)

        print("Training model ...")
        _start_training(ssh_obj, model, workspace)

        print("Exposing TensorBoard ...")
        _expose_tensorboard(ssh_obj, model, workspace)

        # Line alert
        send_line_message(
            f"{str(model).capitalize()} training job on VM is now running:\n"
            f" - VM key: {vm_key}\n"
            f" - VM name: {vm_name}\n"
            f" - Workspace: {workspace}\n"
            f" - Jupyter server: http://{public_ip}:8888/\n"
            f" - TensorBoard: http://{public_ip}:9999/"
        )

    print(100 * "#")

**Destroy all VMs (uncomment the cell below)**

In [None]:
# Destroys every resource managed by terraform/ without a confirmation prompt.
# Kept commented out so that running all cells cannot tear down the VMs by
# accident; uncomment deliberately when the training runs are finished.
# !terraform -chdir=terraform/ destroy -auto-approve