In [0]:
%scala

import java.io.File
import java.nio.file.{Files, Path, Paths, StandardCopyOption}
import java.nio.file.attribute.PosixFilePermissions
import java.text.SimpleDateFormat
import java.util.{Date, Properties}

import scala.sys.process._

import org.apache.spark.SparkContext

In [0]:
%scala

// Public key of the `ubuntu` user on the machine running this cell, read by invoking `cat`.
val publicKey = Process("cat /home/ubuntu/.ssh/id_rsa.pub").!!

/**
 * Appends `key` to `/home/ubuntu/.ssh/authorized_keys` on the machine where this is called.
 *
 * A newline is written before the key so that it starts on its own line.
 *
 * @param key public key text to append
 */
def addAuthorizedPublicKey(key: String): Unit = {
  val fw = new java.io.FileWriter("/home/ubuntu/.ssh/authorized_keys", /* append */ true)
  // Close the writer even when the write fails, so the file handle is never leaked.
  try fw.write("\n" + key)
  finally fw.close()
}

// Inject the key into the executors so that the driver can ssh into them.
// The range is split into one single-element partition per entry reported by
// getExecutorMemoryStatus.
// NOTE(review): this presumably relies on the tasks landing one per executor, which is not
// enforced here — confirm.
val numExecutors = sc.getExecutorMemoryStatus.keys.size
sc.parallelize(0 until numExecutors, numExecutors).foreach(_ => addAuthorizedPublicKey(publicKey))

// Also authorize the key on the machine running this cell.
addAuthorizedPublicKey(publicKey)

In [0]:
%scala

// Distinct host addresses taken from the "host:port" keys of getExecutorMemoryStatus.
val workers: List[String] = {
  // Names starting with "ip" lose the "ip-" prefix, keep only their first dot-separated
  // label and have dashes turned into dots, e.g. "ip-10-0-0-1.ec2.internal" -> "10.0.0.1".
  // Anything else is kept unchanged.
  def toAddress(host: String): String =
    if (!host.startsWith("ip")) host
    else host.stripPrefix("ip-").split('.').head.replace("-", ".")

  val hosts = sc.getExecutorMemoryStatus.keys.map(entry => entry.split(":").head)
  val executors = hosts.map(toAddress).toSet
  println("Executors = " + executors)
  executors.toList
}
println(s"Workers = $workers")


/**
 * Writes `contents` to the file at `path`.
 *
 * @param path     file to write to
 * @param contents text to write
 * @param append   when true, add to the end of the file; otherwise overwrite it
 */
def writeFile(path: String, contents: String, append: Boolean = false): Unit = {
  val fw = new java.io.FileWriter(path, append)
  // Close the writer even when the write fails, so the file handle is never leaked.
  try fw.write(contents)
  finally fw.close()
}

/** Appends `key`, followed by a newline, to `/home/ubuntu/.ssh/hostfile`. */
def addHostfile(key: String): Unit =
  writeFile("/home/ubuntu/.ssh/hostfile", s"$key\n", append = true)

// Record every worker address in the hostfile, one per line.
for (ip <- workers) addHostfile(ip)

In [0]:
%scala


/**
 * Runs `ipython profile create --parallel --profile=mpi` with
 * `--profile-dir=~/.ipython/profile_mpi` on the machine executing this call.
 *
 * Standard output of the command is collected and returned; standard error lines are
 * printed as they arrive.
 *
 * @param logStdout when true, the collected standard output is also printed
 * @return the command's standard output, lines joined by "\n"
 * @throws RuntimeException if the command exits with a non-zero code
 */
def createIpyprofile(logStdout: Boolean = true): String = {
  // Resolve the local host name so the status lines say which node they came from.
  val host = scala.util.Try(java.net.InetAddress.getLocalHost.getHostName).getOrElse("unknown")

  val outBuffer = new collection.mutable.ArrayBuffer[String]()
  val logger = ProcessLogger(line => outBuffer += line, println(_))

  // NOTE(review): the command is not run through a shell, so the "~" in --profile-dir is
  // handed to ipython literally — confirm ipython expands it as intended.
  val exitCode =
    Seq("ipython", "profile", "create", "--parallel", "--profile=mpi", "--profile-dir=~/.ipython/profile_mpi") ! logger
  if (logStdout) {
    outBuffer.foreach(println)
  }
  if (exitCode != 0) {
    println(s"FAILED: on host: $host")
    sys.error(s"Command failed with exit code $exitCode on host $host")
  }
  println(s"SUCCESS: on host: $host")
  outBuffer.mkString("\n")
}

// Create the IPython profile on the worker nodes.
// The range is split into one single-element partition per entry reported by
// getExecutorMemoryStatus.
// NOTE(review): this presumably relies on the tasks landing one per executor, which is not
// enforced here — confirm.
val numExecutors = sc.getExecutorMemoryStatus.keys.size
sc.parallelize(0 until numExecutors, numExecutors).foreach(_ => createIpyprofile())

In [0]:
%scala


/**
 * Runs `mkdir -p /root/.ssh` on the machine executing this call.
 *
 * Standard output of the command is collected and returned; standard error lines are
 * printed as they arrive.
 *
 * @param logStdout when true, the collected standard output is also printed
 * @return the command's standard output, lines joined by "\n"
 * @throws RuntimeException if the command exits with a non-zero code
 */
def createSSHfolder(logStdout: Boolean = true): String = {
  // Resolve the local host name so the status lines say which node they came from.
  val host = scala.util.Try(java.net.InetAddress.getLocalHost.getHostName).getOrElse("unknown")

  val outBuffer = new collection.mutable.ArrayBuffer[String]()
  val logger = ProcessLogger(line => outBuffer += line, println(_))

  val exitCode =
    Seq("mkdir", "-p", "/root/.ssh") ! logger
  if (logStdout) {
    outBuffer.foreach(println)
  }
  if (exitCode != 0) {
    println(s"FAILED: on host: $host")
    sys.error(s"Command failed with exit code $exitCode on host $host")
  }
  println(s"SUCCESS: on host: $host")
  outBuffer.mkString("\n")
}

// Create the SSH folder on the worker nodes.
// The range is split into one single-element partition per entry reported by
// getExecutorMemoryStatus.
// NOTE(review): this presumably relies on the tasks landing one per executor, which is not
// enforced here — confirm.
val numExecutors = sc.getExecutorMemoryStatus.keys.size
sc.parallelize(0 until numExecutors, numExecutors).foreach(_ => createSSHfolder())

In [0]:
%scala

// Public key of the `root` user on the machine running this cell, read by invoking `cat`.
val publicKey = Process("cat /root/.ssh/id_rsa.pub").!!

/**
 * Appends `key` to `/root/.ssh/authorized_keys` on the machine where this is called.
 *
 * A newline is written before the key so that it starts on its own line.
 *
 * @param key public key text to append
 */
def addAuthorizedPublicKey(key: String): Unit = {
  val fw = new java.io.FileWriter("/root/.ssh/authorized_keys", /* append */ true)
  // Close the writer even when the write fails, so the file handle is never leaked.
  try fw.write("\n" + key)
  finally fw.close()
}


// Inject the key into the executors so that the driver can ssh into them.
// The range is split into one single-element partition per entry reported by
// getExecutorMemoryStatus.
// NOTE(review): this presumably relies on the tasks landing one per executor, which is not
// enforced here — confirm.
val numExecutors = sc.getExecutorMemoryStatus.keys.size
sc.parallelize(0 until numExecutors, numExecutors).foreach(_ => addAuthorizedPublicKey(publicKey))

// Also authorize the key on the machine running this cell.
addAuthorizedPublicKey(publicKey)

In [0]:
%scala


/**
 * Runs `sudo service ssh start` on the machine executing this call.
 *
 * Standard output of the command is collected and returned; standard error lines are
 * printed as they arrive.
 *
 * @param logStdout when true, the collected standard output is also printed
 * @return the command's standard output, lines joined by "\n"
 * @throws RuntimeException if the command exits with a non-zero code
 */
def startSSHserver(logStdout: Boolean = true): String = {
  // Resolve the local host name so the status lines say which node they came from.
  val host = scala.util.Try(java.net.InetAddress.getLocalHost.getHostName).getOrElse("unknown")

  val outBuffer = new collection.mutable.ArrayBuffer[String]()
  val logger = ProcessLogger(line => outBuffer += line, println(_))

  val exitCode =
    Seq("sudo", "service", "ssh", "start") ! logger
  if (logStdout) {
    outBuffer.foreach(println)
  }
  if (exitCode != 0) {
    println(s"FAILED: on host: $host")
    sys.error(s"Command failed with exit code $exitCode on host $host")
  }
  println(s"SUCCESS: on host: $host")
  outBuffer.mkString("\n")
}

// Start the SSH server on the worker nodes.
// The range is split into one single-element partition per entry reported by
// getExecutorMemoryStatus.
// NOTE(review): this presumably relies on the tasks landing one per executor, which is not
// enforced here — confirm.
val numExecutors = sc.getExecutorMemoryStatus.keys.size
sc.parallelize(0 until numExecutors, numExecutors).foreach(_ => startSSHserver())

In [0]:
%scala

/**
 * Opens an ssh connection to `host` on port 22 as user `ubuntu`, using the identity file
 * `/home/ubuntu/.ssh/id_rsa` and `StrictHostKeyChecking=no`. No remote command is passed.
 *
 * @param host      address to connect to
 * @param logStdout when true, print the captured standard output once ssh has returned
 * @return the captured standard output, lines joined by "\n"
 */
def ssh(host: String, logStdout: Boolean = true): String = {
  println("ssh'ing onto host - " + host)

  // stdout lines are buffered for the return value; stderr lines are printed as they arrive.
  val stdoutLines = new collection.mutable.ArrayBuffer[String]()
  val sink = ProcessLogger(out => stdoutLines += out, err => println(err))

  val command = Seq(
    "ssh",
    "-o", "StrictHostKeyChecking=no",
    "-p", "22",
    "-i", "/home/ubuntu/.ssh/id_rsa",
    s"ubuntu@$host"
  )
  val status = command ! sink

  if (logStdout) {
    for (line <- stdoutLines) println(line)
  }
  if (status == 0) {
    println(s"SUCCESS: on host: $host")
    stdoutLines.mkString("\n")
  } else {
    println(s"FAILED: on host: $host")
    sys.error("Command failed")
  }
}

// Connect to each worker in turn; a failed connection aborts the loop with an error.
for (ip <- workers) ssh(ip)

In [0]:
%sh
# Append root's known_hosts entries to the ubuntu user's known_hosts file.
cat /root/.ssh/known_hosts >> /home/ubuntu/.ssh/known_hosts

In [0]:
%sh
# Recursively give the ubuntu user and group ownership of /home/ubuntu/.ssh/.
chown -R ubuntu:ubuntu /home/ubuntu/.ssh/

In [0]:
%sh
#!/usr/bin/env bash
# Run the here-document below in a login shell as the ubuntu user: print the user name
# and the Python version, activate the conda environment at /databricks/python, then run
# `hostname` under mpiexec with 3 processes using the hostfile /home/ubuntu/.ssh/hostfile.
# NOTE(review): the process count (3) is hard-coded — confirm it matches the hostfile.
sudo -i -u ubuntu bash << EOF
whoami
/databricks/python/bin/python -V
. /databricks/conda/etc/profile.d/conda.sh
conda activate /databricks/python
mpiexec -n 3 -f /home/ubuntu/.ssh/hostfile -prepend-rank hostname
EOF