## Setup

In [None]:
import os
import sys

# Walk up from the notebook's working directory to the `ardt` project
# directory and make it importable (the next cell imports `utils.helpers`
# and `private_keys` from it).
project_dir = os.path.abspath('')
while not project_dir.endswith('ardt'):
    parent = os.path.dirname(project_dir)
    if parent == project_dir:
        # dirname() of the filesystem root is the root itself, so without this
        # check the loop would spin forever when no `ardt` ancestor exists.
        raise RuntimeError(
            f"No directory ending in 'ardt' found above {os.path.abspath('')!r}")
    project_dir = parent
if project_dir not in sys.path:
    sys.path.append(project_dir)

In [None]:
import requests
import tempfile
import time

from utils.helpers import find_root_dir

from private_keys import LAMBDALABS_API_KEY, LAMBDALABS_SSH_KEY, GITHUB_PRIVATE_KEY


# Root of the Lambda Labs Cloud REST API.
API_BASE = "https://cloud.lambdalabs.com/api/v1"
# SSH key names (as registered with Lambda Labs) to install on the instance.
SSH_KEY_NAMES = ["afonso-mbp"]  # can only be one

# Shared HTTP session: every API helper below sends the bearer token via it.
# `requests.Session()` is the class itself; the lower-case `requests.session()`
# factory is deprecated and kept only for backwards compatibility.
s = requests.Session()
s.headers["Authorization"] = f"Bearer {LAMBDALABS_API_KEY}"

## Helpers

In [None]:
# Local project root; files pushed to the instance via rsync are read from here.
ARDT_DIR = find_root_dir()
# Logical name of this job; also (minus any extension) the launched instance's name.
SCRIPT_TO_RUN = "run_eval"
# Remote directory of cluster run scripts (not referenced elsewhere in this notebook).
SCRIPT_PATH = (
    "~/action-robust-decision-transformer/cluster-scripts/other-run-scripts/"
)

In [None]:
def choose_instance_type():
    """
    Choose an instance time based on what is available across the regions
    """
    instance_types = s.get(f"{API_BASE}/instance-types").json()['data']
    options = []
    print("Choose instance type:")
    for k in sorted(instance_types.keys()):
        if len(instance_types[k]['regions_with_capacity_available']) > 0:
            regions = [r['name'] for r in instance_types[k]['regions_with_capacity_available']]
            print(len(options) + 1, ":", k, 
                  instance_types[k]['instance_type']['price_cents_per_hour'], 
                  regions)
            options.append((k, regions))
    idx = int(input(), )
    name, regions = options[idx-1]
    print("Choose region:")
    for n, r in enumerate(regions):
        print(n+1, ":", r)
    idx = int(input())
    return name, regions[idx-1]


def up(instance_type, region):
    """
    Launch one instance of `instance_type` in `region`.

    The instance gets the keys listed in SSH_KEY_NAMES, no file systems, and
    is named after SCRIPT_TO_RUN with everything from the first '.' dropped.
    The raw API response is printed before the HTTP status is checked.

    Returns:
        The id of the newly launched instance.

    Raises:
        requests.HTTPError: if the launch request is rejected.
    """
    payload = dict(
        region_name=region,
        instance_type_name=instance_type,
        ssh_key_names=SSH_KEY_NAMES,
        file_system_names=[],
        quantity=1,
        name=SCRIPT_TO_RUN.split(".")[0],
    )
    response = s.post(f"{API_BASE}/instance-operations/launch", json=payload)
    print("Spinning up selected instance:\n", response.json())
    response.raise_for_status()
    launched_ids = response.json()['data']['instance_ids']
    return launched_ids[0]


def get_instance_info(instance_id):
    """
    Fetch the API record of `instance_id` (its 'status', 'ip', ...).

    Raises:
        requests.HTTPError: if the request fails.
    """
    response = s.get(f"{API_BASE}/instances/{instance_id}")
    response.raise_for_status()
    payload = response.json()
    return payload['data']
    
    
def wait_for_ready(instance_id, poll_interval=10, timeout=None):
    """
    Block until the instance reports status 'active'.

    Polls get_instance_info() every `poll_interval` seconds and prints the
    id, IP (None until one is assigned) and status on each poll.

    Args:
        instance_id: id returned by up().
        poll_interval: seconds to sleep between polls (10, as before).
        timeout: optional limit in seconds; None (the default) waits forever.

    Raises:
        RuntimeError: if the instance reaches 'terminated'. It can never
            become active from there, so the old loop would spin forever.
        TimeoutError: if `timeout` elapses before the instance is active.
    """
    deadline = None if timeout is None else time.monotonic() + timeout
    while True:
        info = get_instance_info(instance_id)
        status = info['status']
        print("Status:", instance_id, info.get('ip'), status)
        if status == 'active':
            return
        # 'terminated' is a final state in the Lambda Cloud API.
        if status == 'terminated':
            raise RuntimeError(
                f"Instance {instance_id} was terminated before becoming active.")
        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(
                f"Instance {instance_id} not active after {timeout}s "
                f"(last status: {status!r}).")
        time.sleep(poll_interval)
        

def down(instance_id):
    """
    Terminate the instance `instance_id`.

    Returns:
        The 'data' field of the terminate endpoint's JSON response.

    Raises:
        requests.HTTPError: if the request fails.
    """
    response = s.post(
        f"{API_BASE}/instance-operations/terminate",
        json={"instance_ids": [instance_id]},
    )
    response.raise_for_status()
    return response.json()['data']


def _write_ssh_key():
    """Write LAMBDALABS_SSH_KEY to a fresh owner-only temp file and return its path."""
    # mkstemp creates the file readable/writable by the current user only
    # (mode 0o600), which ssh requires for private keys.
    fd, path = tempfile.mkstemp(prefix="lambdalabs-key-")
    with os.fdopen(fd, "wb") as f:
        f.write(LAMBDALABS_SSH_KEY.encode())
    return path


def ssh(ip, command):
    """
    Run `command` on ubuntu@`ip` over SSH.

    The private key is written to a temporary file only for the duration of
    the call and deleted afterwards (previously every call left a copy of the
    key in the temp directory). `command` is handed to ssh as a single
    argument without a local shell, so only the remote shell interprets it;
    embedded quotes or `$` are no longer mangled locally.

    Returns:
        The exit status of the ssh process (the remote command's status, or
        255 if ssh itself failed).
    """
    import subprocess

    key_path = _write_ssh_key()
    try:
        return subprocess.run(
            ["ssh", "-o", "StrictHostKeyChecking=accept-new", "-i", key_path,
             f"ubuntu@{ip}", command],
        ).returncode
    finally:
        os.remove(key_path)


def rsync(ip, src, dst):
    """
    Copy local `src` to `dst` on ubuntu@`ip` using rsync over SSH.

    Uses the same temporary-key handling and host-key policy as ssh(). `src`
    is not passed through a local shell; a leading `~` is still expanded for
    compatibility.

    Returns:
        The exit status of the rsync process.
    """
    import subprocess

    key_path = _write_ssh_key()
    try:
        return subprocess.run(
            ["rsync", "-avz",
             "-e", f"ssh -o StrictHostKeyChecking=accept-new -i {key_path}",
             os.path.expanduser(src), f"ubuntu@{ip}:{dst}"],
        ).returncode
    finally:
        os.remove(key_path)

## Spin up instance and wait for ready

In [None]:
# Pick a type/region interactively, launch it, then block until it is active.
selection = choose_instance_type()
instance_type, region = selection
instance_id = up(*selection)
wait_for_ready(instance_id)

In [None]:
# Public IP of the now-active instance; every later ssh/rsync call targets it.
ip = get_instance_info(instance_id)['ip']

## Set up environment, code, and files

In [None]:
# Smoke-test SSH connectivity before provisioning anything.
ssh(ip, 'uptime')

In [None]:
# Install Python 3.10 from the deadsnakes PPA.
# Refresh package lists before the first install: on a fresh image they can be
# stale, and installing against stale lists can fail with 404s.
# `apt-get` is used because `apt` warns that its CLI is not stable for scripts.
ssh(ip, 'sudo apt-get update')
ssh(ip, 'sudo apt-get install software-properties-common -y')
ssh(ip, 'sudo add-apt-repository ppa:deadsnakes/ppa -y')
ssh(ip, 'sudo apt-get update')
ssh(ip, 'sudo apt-get install python3.10 python3.10-dev python3.10-distutils python3.10-venv -y')

In [None]:
# Create the virtualenv at ~/ardt-env/ardt (`-p` keeps the cell re-runnable).
# The venv is not activated here: each ssh() call runs in a fresh remote
# shell, so activation would not persist. Commands that need the venv
# activate it inline with `source ./ardt-env/ardt/bin/activate && ...`.
ssh(ip, 'mkdir -p ardt-env')
ssh(ip, 'python3.10 -m venv ./ardt-env/ardt')

In [None]:
# Clone the (private) repo over HTTPS, authenticating with the token embedded in the URL.
# NOTE(review): git stores this URL as remote.origin.url, so the token stays in
# the instance's .git/config; acceptable only because the instance is short-lived.
ssh(ip, f'git clone https://{GITHUB_PRIVATE_KEY}@github.com/afonsosamarques/action-robust-decision-transformer.git')

In [None]:
# Upgrade pip, then install the project requirements into the venv.
# Each ssh() call is its own shell, so the venv is activated per command.
venv = 'source ./ardt-env/ardt/bin/activate && '
for pip_cmd in ('pip3 install --upgrade pip',
                'pip3 install -r ./action-robust-decision-transformer/requirements.txt'):
    ssh(ip, venv + pip_cmd)

In [None]:
# Install git-lfs and fetch LFS content. The repo was cloned before git-lfs
# existed on the machine, so LFS-tracked files are still pointer files.
# `git lfs install` only sets up the hooks/filters; `git lfs pull` downloads
# and checks out the real content (a no-op if nothing is LFS-tracked).
ssh(ip, 'sudo apt-get install git-lfs -y')
ssh(ip, 'cd ~/action-robust-decision-transformer/ && git lfs install && git lfs pull')

In [None]:
# Create the output, agent and W&B directories under codebase/ardt in a single
# ssh round-trip. `mkdir -p` makes the cell re-runnable; plain `mkdir` errors
# when a directory already exists.
ssh(ip, 'mkdir -p ' + ' '.join(
    f'~/action-robust-decision-transformer/codebase/ardt/{d}'
    for d in ('eval-outputs', 'eval-outputs-pipeline', 'eval-outputs-test',
              'agents', 'agents-pipeline', 'agents-test',
              'wandb', 'wandb-json')))

In [None]:
# Upload the local access_tokens.py to the remote home directory, then copy it
# into both the ardt and evaluation_protocol code directories.
src = f'{ARDT_DIR}/access_tokens.py'
rsync(ip, src, '~/')
for pkg in ('ardt', 'evaluation_protocol'):
    ssh(ip, f'cp ~/access_tokens.py ~/action-robust-decision-transformer/codebase/{pkg}')

In [None]:
# Optional step, currently disabled: upload the local `datasets-to-push`
# directory into codebase/ardt on the instance and rename it to `datasets`.
# src = f'{ARDT_DIR}/datasets-to-push'
# rsync(ip, src, '~/action-robust-decision-transformer/codebase/ardt')
# ssh(ip, 'mv ~/action-robust-decision-transformer/codebase/ardt/datasets-to-push ~/action-robust-decision-transformer/codebase/ardt/datasets')

## Run and shut down (eventually)

In [None]:
# Run the evaluation inside the venv. The activate path must match the venv
# created during setup (./ardt-env/ardt). If `source` fails, the `&&` chain
# stops and the evaluation never starts.
ssh(ip, 'source ./ardt-env/ardt/bin/activate && cd ~/action-robust-decision-transformer/codebase/ && python3 -m evaluation_protocol.evaluate --config_name evaluation_batch_agentadv')

In [None]:
# Terminate the instance once the run is finished; the returned API data is
# shown as this cell's output.
down(instance_id)