In [18]:
import kfp
import yaml
import os
import logging
import fairing

from kubernetes import client as k8s_client
from kubernetes import config as k8s_config

GCP_PROJECT = fairing.cloud.gcp.guess_project_name()
namespace = fairing_utils.get_current_k8s_namespace()

logging.info(f"Running in project {GCP_PROJECT}")
logging.info(f"Running in namespace {namespace}")

# Notebook to start mpi-operator job

## First thing first,  set up the credentials if needed. 

If you are runing on your kubeflow cluster this should not be an issues and you can contiue to the next step. 

In [20]:
#!export GOOGLE_APPLICATION_CREDENTIALS

## Define mpi job. 

This example is based aroud the tensorflow-benchmark found [here](https://github.com/kubeflow/mpi-operator/blob/master/examples/v1alpha1/tensorflow-benchmarks.yaml). To make this example we have defined the complete job below.

In [21]:
mpi_job = f"""
apiVersion: kubeflow.org/v1alpha2
kind: MPIJob
metadata:
  name: tensorflow-benchmarks
spec:
  slotsPerWorker: 1
  cleanPodPolicy: Running
  mpiReplicaSpecs:
    Launcher:
      replicas: 1
      template:
         spec:
           containers:
           - image: mpioperator/tensorflow-benchmarks:latest
             name: tensorflow-benchmarks
             command:
             - mpirun
             - --allow-run-as-root
             - -np
             - "2"
             - -bind-to
             - none
             - -map-by
             - slot
             - -x
             - NCCL_DEBUG=INFO
             - -x
             - LD_LIBRARY_PATH
             - -x
             - PATH
             - -mca
             - pml
             - ob1
             - -mca
             - btl
             - ^openib
             - python
             - scripts/tf_cnn_benchmarks/tf_cnn_benchmarks.py
             - --model=resnet101
             - --batch_size=64
             - --variable_update=horovod
    Worker:
      replicas: 2
      template:
        spec:
          containers:
          - image: mpioperator/tensorflow-benchmarks:latest
            name: tensorflow-benchmarks
            resources:
              limits:
                nvidia.com/gpu: 1
"""

In [22]:
import retrying
from kubernetes import client
from kubernetes import watch as k8s_watch
from table_logger import TableLogger

#from kubeflow.tfjob.utils import utils

tbl = TableLogger(
  columns='NAME,STATE,TIME',
  colwidth={'NAME': 30, 'STATE':20, 'TIME':30},
  border=False)

@retrying.retry(wait_fixed=1000, stop_max_attempt_number=20)
def watch_k8s(name=None, namespace=None, timeout_seconds=600):
  """Watch the created or patched InferenceService in the specified namespace"""

  #if namespace is None:
    #namespace = utils.get_default_target_namespace()

  stream = k8s_watch.Watch().stream(
    client.CustomObjectsApi().list_namespaced_custom_object,
    MPI_JOB_GROUP,
    MPI_JOB_VERSION,
    namespace,
    MPI_JOB_PLURAL,
    timeout_seconds=timeout_seconds)

  for event in stream:
    mpijob = event['object']
    mpijob_name = mpijob['metadata']['name']
    if name and name != mpijob_name:
      continue
    else:
      status = ''
      update_time = ''
      last_condition = mpijob.get('status', {}).get('conditions', [])[-1]
      status = last_condition.get('type', '')
      update_time = last_condition.get('lastTransitionTime', '')

      tbl(mpijob_name, status, update_time)

      if name == mpijob_name:
        if status == 'Succeeded' or status == 'Failed':
          break

### Setup a python mpijob client

This sectin contains some helper code that is a modified version of the [tfjob client](https://github.com/kubeflow/tf-operator/blob/master/sdk/python/docs/TFJobClient.md). 

In [23]:
from kubernetes import client, config
import time

def is_running_in_k8s():
  return os.path.isdir('/var/run/secrets/kubernetes.io/')

MPI_JOB_GROUP = "kubeflow.org"
MPI_JOB_PLURAL = "mpijobs"
MPI_JOB_NAME_LABEL = "job-mpi"
MPI_JOB_VERSION = "v1alpha2"
APISERVER_TIMEOUT = 120

class MPIJobClient(object):

  def __init__(self, config_file=None, context=None, # pylint: disable=too-many-arguments
               client_configuration=None, persist_config=True):
    """
    TFJob client constructor
    :param config_file: kubeconfig file, defaults to ~/.kube/config
    :param context: kubernetes context
    :param client_configuration: kubernetes configuration object
    :param persist_config:
    """
    if config_file or not is_running_in_k8s():
      config.load_kube_config()
        #config_file=config_file,
        #context=context,
        #client_configuration=client_configuration,
        #persist_config=persist_config)
    else:
      config.load_incluster_config()

    self.custom_api = client.CustomObjectsApi()
    self.core_api = client.CoreV1Api()

  def create(self, mpijob, namespace=None):
    """
    Create the TFJob
    :param tfjob: tfjob object
    :param namespace: defaults to current or default namespace
    :return: created tfjob
    """

    if namespace is None:
      namespace = utils.set_tfjob_namespace(tfjob)

    try:
      outputs = self.custom_api.create_namespaced_custom_object(
        MPI_JOB_GROUP,
        MPI_JOB_VERSION,
        namespace,
        MPI_JOB_PLURAL,
        mpijob)
    except client.rest.ApiException as e:
      raise RuntimeError(
        "Exception when calling CustomObjectsApi->create_namespaced_custom_object:\
         %s\n" % e)

    return outputs


  def is_job_succeeded(self, name, namespace=None):
    """Returns true if the TFJob succeeded; false otherwise.
    :param name: The TFJob name.
    :param namespace: defaults to current or default namespace.
    :return: True or False
    """
    mpijob_status = self.get_job_status(name, namespace=namespace)
    return mpijob_status.lower() == "succeeded"


  def get_job_status(self, name, namespace=None):
    """Returns TFJob status, such as Running, Failed or Succeeded.
    :param name: The TFJob name.
    :param namespace: defaults to current or default namespace.
    :return: Object TFJob status
    """
    if namespace is None:
      namespace = utils.get_default_target_namespace()

    mpijob = self.get(name, namespace=namespace)
    last_condition = mpijob.get("status", {}).get("conditions", [])[-1]
    return last_condition.get("type", "")

  def wait_for_condition(self, name,
                         expected_condition,
                         namespace=None,
                         timeout_seconds=1200,
                         polling_interval=30,
                         status_callback=None):
    """Waits until any of the specified conditions occur.
    :param name: Name of the job.
    :param expected_condition: A list of conditions. Function waits until any of the
           supplied conditions is reached.
    :param namespace: defaults to current or default namespace.
    :param timeout_seconds: How long to wait for the job.
    :param polling_interval: How often to poll for the status of the job.
    :param status_callback: (Optional): Callable. If supplied this callable is
           invoked after we poll the job. Callable takes a single argument which
           is the job.
    :return: Object TFJob status
    """

    #if namespace is None:
    #  namespace = utils.get_default_target_namespace()

    for _ in range(round(timeout_seconds/polling_interval)):

      mpi_job = None
      mpi_job = self.get(name, namespace=namespace)

      if mpi_job:
        if status_callback:
          status_callback(mpi_job)

        # If we poll the CRD quick enough status won't have been set yet.
        conditions = mpi_job.get("status", {}).get("conditions", [])
        # Conditions might have a value of None in status.
        conditions = conditions or []
        for c in conditions:
          if c.get("type", "") in expected_condition:
            return mpi_job

      time.sleep(polling_interval)

    raise RuntimeError(
      "Timeout waiting for TFJob {0} in namespace {1} to enter one of the "
      "conditions {2}.".format(name, namespace, expected_condition), tfjob)

  def get(self, name=None, namespace=None, watch=False, timeout_seconds=600): #pylint: disable=inconsistent-return-statements
    """
    Get the tfjob
    :param name: existing tfjob name, if not defined, the get all tfjobs in the namespace.
    :param namespace: defaults to current or default namespace
    :param watch: Watch the TFJob if `True`.
    :param timeout_seconds: How long to watch the job..
    :return: tfjob
    """
    if namespace is None:
      namespace = utils.get_default_target_namespace()

    if name:
      if watch:
        tfjob_watch(
          name=name,
          namespace=namespace,
          timeout_seconds=timeout_seconds)
      else:
        thread = self.custom_api.get_namespaced_custom_object(
          MPI_JOB_GROUP,
          MPI_JOB_VERSION,
          namespace,
          MPI_JOB_PLURAL,
          name,
          async_req=True)

        mpijob = None
        try:
          mpijob = thread.get(APISERVER_TIMEOUT)
        except multiprocessing.TimeoutError:
          raise RuntimeError("Timeout trying to get TFJob.")
        except client.rest.ApiException as e:
          raise RuntimeError(
            "Exception when calling CustomObjectsApi->get_namespaced_custom_object:\
            %s\n" % e)
        except Exception as e:
          raise RuntimeError(
            "There was a problem to get TFJob {0} in namespace {1}. Exception: \
            {2} ".format(name, namespace, e))
        return mpijob
    else:
      if watch:
        watch_k8s(
            namespace=namespace,
            timeout_seconds=timeout_seconds)
      else:
        thread = self.custom_api.list_namespaced_custom_object(
          MPI_JOB_GROUP,
          MPI_JOB_VERSION,
          namespace,
          MPI_JOB_PLURAL,
          async_req=True)

        tfjobs = None
        try:
          tfjobs = thread.get(constants.APISERVER_TIMEOUT)
        except multiprocessing.TimeoutError:
          raise RuntimeError("Timeout trying to get TFJob.")
        except client.rest.ApiException as e:
          raise RuntimeError(
            "Exception when calling CustomObjectsApi->list_namespaced_custom_object:\
            %s\n" % e)
        except Exception as e:
          raise RuntimeError(
            "There was a problem to list TFJobs in namespace {0}. \
            Exception: {1} ".format(namespace, e))
        return tfjobs

## MPI client

### Define the mpi client

In [24]:
#mpi_job_client = MPIJobClient(config_file="PATH/.kube/config") # if you run locally
mpi_job_client = MPIJobClient()

### Run the job

In [25]:
mpi_job_body = yaml.safe_load(mpi_job)
mpi_job = mpi_job_client.create(mpi_job_body, namespace=namespace)  

logging.info(f"Created job {namespace}")

### Wait for the job to finish

The job takes roughtly 20 minutes to finish. 

In [None]:
train_name="tensorflow-benchmarks"
mpi_job = mpi_job_client.wait_for_condition(train_name, expected_condition=["Succeeded", "Failed"], namespace=namespace, timeout_seconds=1200)

if mpi_job_client.is_job_succeeded(train_name, namespace):
    print("The job succeded")
    logging.info(f"TFJob {namespace}.{train_name} succeeded")
else:
    raise ValueError(f"TFJob {namespace}.{train_name} failed")

### Check the logs from the job

In [None]:
command = f"$(kubectl -n {namespace} get pods -l mpi_job_name=tensorflow-benchmarks-complete,mpi_role_type=launcher -o name)"

In [24]:
!kubectl logs -f {command}

+ POD_NAME=tensorflow-benchmarks-complete-worker-0
+ shift
+ /opt/kube/kubectl exec tensorflow-benchmarks-complete-worker-0 -- /bin/sh -c     PATH=/usr/local/bin:$PATH ; export PATH ; LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH ; export LD_LIBRARY_PATH ; DYLD_LIBRARY_PATH=/usr/local/lib:$DYLD_LIBRARY_PATH ; export DYLD_LIBRARY_PATH ;   /usr/local/bin/orted -mca ess "env" -mca ess_base_jobid "4289331200" -mca ess_base_vpid 1 -mca ess_base_num_procs "3" -mca orte_node_regex "tensorflow-benchmarks-complete-launcher-hfx[2:88],tensorflow-benchmarks-complete-worker-[1:0-1]@0(3)" -mca orte_hnp_uri "4289331200.0;tcp://10.44.1.28:38135" -mca pml "ob1" -mca btl "^openib" -mca plm "rsh" --tree-spawn -mca orte_parent_uri "4289331200.0;tcp://10.44.1.28:38135" -mca plm_rsh_agent "/etc/mpi/kubexec.sh" -mca orte_default_hostfile "/etc/mpi/hostfile" -mca hwloc_base_binding_policy "none" -mca rmaps_base_mapping_policy "slot" -mca pmix "^s1,s2,cray,isolated"
+ POD_NAME=tensorflow-benchmarks-complete-

## That was all! 