Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SUBMARINE-597. Support for SSH based git sync mode #391

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -19,8 +19,8 @@

package org.apache.submarine.server.submitter.k8s.experiment.codelocalizer;

import java.net.MalformedURLException;
import java.net.URL;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.List;

Expand Down Expand Up @@ -86,16 +86,16 @@ public static CodeLocalizer getGitCodeLocalizer(String url)
throws InvalidSpecException {

try {
URL urlParser = new URL(url);
String protocol = urlParser.getProtocol();
if (protocol.equals(GitCodeLocalizerModes.HTTP.getMode())) {
URI uriParser = new URI(url);
String scheme = uriParser.getScheme();
if (scheme.equals(GitCodeLocalizerModes.HTTP.getMode())) {
return new HTTPGitCodeLocalizer(url);
} else if (protocol.equals(GitCodeLocalizerModes.SSH.getMode())) {
} else if (scheme.equals(GitCodeLocalizerModes.SSH.getMode())) {
return new SSHGitCodeLocalizer(url);
} else {
return new DummyCodeLocalizer(url);
}
} catch (MalformedURLException e) {
} catch (URISyntaxException e) {
throw new InvalidSpecException(
"Invalid Code Spec: URL is malformed. " + url);
}
Expand Down
Expand Up @@ -19,17 +19,52 @@

package org.apache.submarine.server.submitter.k8s.experiment.codelocalizer;

import java.util.List;

import io.kubernetes.client.models.V1Container;
import io.kubernetes.client.models.V1EnvVar;
import io.kubernetes.client.models.V1PodSpec;
import io.kubernetes.client.models.V1SecurityContext;
import io.kubernetes.client.models.V1VolumeMount;

/**
 * Git code localizer for repositories that are addressed with an SSH URL.
 *
 * <p>After the parent class has populated the pod spec, this localizer
 * adjusts the code-localizer init container: it sets the
 * {@value #GIT_SYNC_SSH_NAME} environment variable, mounts the volume named
 * {@value #GIT_SECRET_MOUNT_NAME} read-only at {@value #GIT_SECRET_PATH}, and
 * runs the container as user {@link #GIT_SYNC_USER}.
 */
public class SSHGitCodeLocalizer extends GitCodeLocalizer {

  /** Name of the secret that backs the git secret volume. */
  public static final String GIT_SECRET_NAME = "git-creds";
  /** File mode for the mounted secret; octal 0400 (read-only for owner). */
  public static final int GIT_SECRET_MODE = 0400;
  /** Name shared by the secret volume and its volume mount. */
  public static final String GIT_SECRET_MOUNT_NAME = "git-secret";
  /** Path inside the init container where the secret is mounted. */
  public static final String GIT_SECRET_PATH = "/etc/git-secret";
  /** User id the init container runs as. */
  public static final long GIT_SYNC_USER = 65533L;
  /** Name of the environment variable that switches on SSH mode. */
  public static final String GIT_SYNC_SSH_NAME = "GIT_SYNC_SSH";
  /** Value assigned to {@link #GIT_SYNC_SSH_NAME}. */
  public static final String GIT_SYNC_SSH_VALUE = "true";

  /**
   * Creates a localizer for the given repository URL.
   *
   * @param url the git repository URL, passed to the parent unchanged
   */
  public SSHGitCodeLocalizer(String url) {
    super(url);
  }

  /**
   * Applies the parent's localization and then adds the SSH-specific
   * settings to every init container named
   * {@code CODE_LOCALIZER_INIT_CONTAINER_NAME}.
   *
   * @param podSpec the pod spec to modify in place
   */
  @Override
  public void localize(V1PodSpec podSpec) {
    super.localize(podSpec);

    List<V1Container> initContainers = podSpec.getInitContainers();
    if (initContainers == null) {
      // No init container was registered, so there is nothing to adjust.
      return;
    }

    for (V1Container container : initContainers) {
      // Constant-first comparison tolerates containers without a name.
      if (!CODE_LOCALIZER_INIT_CONTAINER_NAME.equals(container.getName())) {
        continue;
      }

      V1EnvVar sshEnv = new V1EnvVar();
      sshEnv.setName(GIT_SYNC_SSH_NAME);
      sshEnv.setValue(GIT_SYNC_SSH_VALUE);
      // addEnvItem creates the env list if the container has none yet,
      // which avoids a NullPointerException on getEnv().add(...).
      container.addEnvItem(sshEnv);

      V1VolumeMount mount = new V1VolumeMount();
      mount.setName(GIT_SECRET_MOUNT_NAME);
      mount.setMountPath(GIT_SECRET_PATH);
      mount.setReadOnly(true);
      // Same null-safety as above, for the volume mount list.
      container.addVolumeMountsItem(mount);

      V1SecurityContext containerSecurityContext = new V1SecurityContext();
      containerSecurityContext.setRunAsUser(GIT_SYNC_USER);
      container.setSecurityContext(containerSecurityContext);
    }
  }

}
Expand Up @@ -23,9 +23,12 @@
import io.kubernetes.client.models.V1Container;
import io.kubernetes.client.models.V1EnvVar;
import io.kubernetes.client.models.V1ObjectMeta;
import io.kubernetes.client.models.V1PodSecurityContext;
import io.kubernetes.client.models.V1PodSpec;
import io.kubernetes.client.models.V1PodTemplateSpec;
import io.kubernetes.client.models.V1ResourceRequirements;
import io.kubernetes.client.models.V1SecretVolumeSource;
import io.kubernetes.client.models.V1Volume;
import io.kubernetes.client.models.V1VolumeMount;

import org.apache.submarine.commons.utils.SubmarineConfVars;
Expand All @@ -39,6 +42,7 @@
import org.apache.submarine.server.environment.EnvironmentManager;
import org.apache.submarine.server.submitter.k8s.experiment.codelocalizer.AbstractCodeLocalizer;
import org.apache.submarine.server.submitter.k8s.experiment.codelocalizer.CodeLocalizer;
import org.apache.submarine.server.submitter.k8s.experiment.codelocalizer.SSHGitCodeLocalizer;
import org.apache.submarine.server.submitter.k8s.model.MLJob;
import org.apache.submarine.server.submitter.k8s.model.MLJobReplicaSpec;
import org.apache.submarine.server.submitter.k8s.model.MLJobReplicaType;
Expand Down Expand Up @@ -182,18 +186,41 @@ private static V1PodTemplateSpec parseTemplateSpec(

if (podSpec.getInitContainers() != null
&& podSpec.getInitContainers().size() > 0) {
String volumeName = podSpec.getInitContainers().get(0).getVolumeMounts()
.get(0).getName();
String path = podSpec.getInitContainers().get(0).getVolumeMounts()
.get(0).getMountPath();

V1VolumeMount mount = new V1VolumeMount();
mount.setName(volumeName);
mount.setMountPath(path);

List<V1VolumeMount> initContainerVolumeMounts =
podSpec.getInitContainers().get(0).getVolumeMounts();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why get the first item?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are two init containers: one for the git-sync process and one for environment setup. The 0th init container is the one created for git code sync. As part of git sync, some properties of that init container have to be read and also set at the pod-spec level; hence this step.

List<V1VolumeMount> volumeMounts = new ArrayList<V1VolumeMount>();
volumeMounts.add(mount);
container.setVolumeMounts(volumeMounts);

// Populate container volume mounts using Init container info
for (V1VolumeMount initContainerVolumeMount : initContainerVolumeMounts) {
String volumeName = initContainerVolumeMount.getName();
String path = initContainerVolumeMount.getMountPath();
if (volumeName
.equals(AbstractCodeLocalizer.CODE_LOCALIZER_MOUNT_NAME)) {
V1VolumeMount mount = new V1VolumeMount();
mount.setName(volumeName);
mount.setMountPath(path);
volumeMounts.add(mount);
container.setVolumeMounts(volumeMounts);
} else if (volumeName
.equals(SSHGitCodeLocalizer.GIT_SECRET_MOUNT_NAME)) {
V1Volume volume = new V1Volume();
volume.setName(volumeName);

List<V1Volume> existingVolumes = podSpec.getVolumes();
V1SecretVolumeSource secret = new V1SecretVolumeSource();
secret.secretName(SSHGitCodeLocalizer.GIT_SECRET_NAME);
secret.setDefaultMode(SSHGitCodeLocalizer.GIT_SECRET_MODE);
volume.setSecret(secret);
existingVolumes.add(volume);

// Pod level security context
V1PodSecurityContext podSecurityContext =
new V1PodSecurityContext();
podSecurityContext.setFsGroup(SSHGitCodeLocalizer.GIT_SYNC_USER);
podSpec.setSecurityContext(podSecurityContext);
}
}

V1EnvVar codeEnvVar = new V1EnvVar();
codeEnvVar.setName(AbstractCodeLocalizer.CODE_LOCALIZER_PATH_ENV_VAR);
Expand Down
Expand Up @@ -46,6 +46,7 @@
import org.apache.submarine.server.submitter.k8s.parser.ExperimentSpecParser;
import org.apache.submarine.server.submitter.k8s.experiment.codelocalizer.AbstractCodeLocalizer;
import org.apache.submarine.server.submitter.k8s.experiment.codelocalizer.GitCodeLocalizer;
import org.apache.submarine.server.submitter.k8s.experiment.codelocalizer.SSHGitCodeLocalizer;
import org.junit.Assert;
import org.junit.Test;
import io.kubernetes.client.models.V1Container;
Expand Down Expand Up @@ -310,4 +311,55 @@ public void testValidPyTorchJobSpecWithHTTPGitCodeLocalizer()
Assert.assertEquals(AbstractCodeLocalizer.CODE_LOCALIZER_MOUNT_NAME,
V1Volume.getName());
}

/**
 * Parses a PyTorch job whose code URL uses the ssh scheme and verifies the
 * code-localizer init container, its SSH-specific settings, and the
 * pod-level volumes and security context.
 */
@Test
public void testValidPyTorchJobSpecWithSSHGitCodeLocalizer()
    throws IOException, URISyntaxException, InvalidSpecException {
  ExperimentSpec jobSpec =
      (ExperimentSpec) buildFromJsonFile(ExperimentSpec.class,
          pytorchJobWithSSHGitCodeLocalizerFile);
  PyTorchJob pyTorchJob = (PyTorchJob) ExperimentSpecParser.parseJob(jobSpec);

  MLJobReplicaSpec mlJobReplicaSpec = pyTorchJob.getSpec().getReplicaSpecs()
      .get(PyTorchJobReplicaType.Master);
  Assert.assertEquals(1,
      mlJobReplicaSpec.getTemplate().getSpec().getInitContainers().size());
  V1Container initContainer =
      mlJobReplicaSpec.getTemplate().getSpec().getInitContainers().get(0);
  Assert.assertEquals(
      AbstractCodeLocalizer.CODE_LOCALIZER_INIT_CONTAINER_NAME,
      initContainer.getName());
  Assert.assertEquals(GitCodeLocalizer.GIT_SYNC_IMAGE,
      initContainer.getImage());
  Assert.assertEquals(AbstractCodeLocalizer.CODE_LOCALIZER_MOUNT_NAME,
      initContainer.getVolumeMounts().get(0).getName());
  Assert.assertEquals(AbstractCodeLocalizer.CODE_LOCALIZER_PATH,
      initContainer.getVolumeMounts().get(0).getMountPath());

  // The SSH switch must actually be present; a loop that only asserts on a
  // match would pass silently if the variable were never set.
  boolean sshEnvFound = false;
  for (V1EnvVar env : initContainer.getEnv()) {
    if (SSHGitCodeLocalizer.GIT_SYNC_SSH_NAME.equals(env.getName())) {
      Assert.assertEquals(SSHGitCodeLocalizer.GIT_SYNC_SSH_VALUE,
          env.getValue());
      sshEnvFound = true;
    }
  }
  Assert.assertTrue("GIT_SYNC_SSH env var is missing on the init container",
      sshEnvFound);

  // SSH mode mounts the git secret into the init container ...
  Assert.assertTrue("git secret mount is missing on the init container",
      initContainer.getVolumeMounts().stream().anyMatch(
          m -> SSHGitCodeLocalizer.GIT_SECRET_MOUNT_NAME.equals(m.getName())
              && SSHGitCodeLocalizer.GIT_SECRET_PATH
                  .equals(m.getMountPath())));
  // ... and runs it as the git-sync user. Long.valueOf avoids the ambiguous
  // assertEquals(long, Long) overload.
  Assert.assertNotNull(initContainer.getSecurityContext());
  Assert.assertEquals(Long.valueOf(SSHGitCodeLocalizer.GIT_SYNC_USER),
      initContainer.getSecurityContext().getRunAsUser());

  // NOTE(review): this re-reads the init container, so the checks below
  // repeat the ones above. The env var name suggests the main container
  // (getContainers().get(0)) was intended - confirm and switch.
  V1Container container =
      mlJobReplicaSpec.getTemplate().getSpec().getInitContainers().get(0);
  Assert.assertEquals(AbstractCodeLocalizer.CODE_LOCALIZER_MOUNT_NAME,
      container.getVolumeMounts().get(0).getName());
  Assert.assertEquals(AbstractCodeLocalizer.CODE_LOCALIZER_PATH,
      container.getVolumeMounts().get(0).getMountPath());
  for (V1EnvVar env : container.getEnv()) {
    if (env.getName()
        .equals(AbstractCodeLocalizer.CODE_LOCALIZER_PATH_ENV_VAR)) {
      Assert.assertEquals(AbstractCodeLocalizer.CODE_LOCALIZER_PATH,
          env.getValue());
    }
  }

  V1Volume codeVolume =
      mlJobReplicaSpec.getTemplate().getSpec().getVolumes().get(0);
  Assert.assertEquals(new V1EmptyDirVolumeSource(), codeVolume.getEmptyDir());
  Assert.assertEquals(AbstractCodeLocalizer.CODE_LOCALIZER_MOUNT_NAME,
      codeVolume.getName());

  // Pod level: the secret volume backing the mount and the fsGroup.
  Assert.assertTrue("git secret volume is missing on the pod spec",
      mlJobReplicaSpec.getTemplate().getSpec().getVolumes().stream().anyMatch(
          v -> SSHGitCodeLocalizer.GIT_SECRET_MOUNT_NAME.equals(v.getName())
              && v.getSecret() != null
              && SSHGitCodeLocalizer.GIT_SECRET_NAME
                  .equals(v.getSecret().getSecretName())));
  Assert.assertNotNull(
      mlJobReplicaSpec.getTemplate().getSpec().getSecurityContext());
  Assert.assertEquals(Long.valueOf(SSHGitCodeLocalizer.GIT_SYNC_USER),
      mlJobReplicaSpec.getTemplate().getSpec().getSecurityContext()
          .getFsGroup());
}
}
Expand Up @@ -42,6 +42,8 @@ public abstract class SpecBuilder {
protected final String notebookReqFile = "/notebook_req.json";
protected final String pytorchJobWithHTTPGitCodeLocalizerFile =
"/pytorch_job_req_http_git_code_localizer.json";
protected final String pytorchJobWithSSHGitCodeLocalizerFile =
"/pytorch_job_req_ssh_git_code_localizer.json";

protected Object buildFromJsonFile(Object obj, String filePath) throws IOException,
URISyntaxException {
Expand Down
@@ -0,0 +1,30 @@
{
"meta": {
"name": "pytorch-dist-mnist",
"namespace": "submarine",
"framework": "PyTorch",
"cmd": "python /var/mnist.py --backend gloo",
"envVars": {
"ENV_1": "ENV1"
}
},
"environment": {
"image": "apache/submarine:pytorch-dist-mnist-1.0"
},
"spec": {
"Master": {
"name": "master",
"replicas": 1,
"resources": "cpu=2,memory=2048M"
},
"Worker": {
"name": "worker",
"replicas": 2,
"resources": "cpu=1,memory=1024M"
}
},
"code": {
"syncMode": "git",
"url" : "ssh://git@github.com/apache/submarine.git"
}
}
Expand Up @@ -180,6 +180,13 @@ public void testTensorFlowUsingCodeWithJsonSpec() throws Exception {
run(body, patchBody, "application/json");
}

/**
 * Runs the job flow with a TensorFlow JSON spec that uses an SSH git code
 * URL; the same request file supplies both the body and the patch body.
 */
@Test
public void testTensorFlowUsingSSHCodeWithJsonSpec() throws Exception {
  final String requestFile =
      "tensorflow/tf-mnist-with-ssh-git-code-localizer-req.json";
  final String createBody = loadContent(requestFile);
  final String updateBody = loadContent(requestFile);
  run(createBody, updateBody, "application/json");
}

private void run(String body, String patchBody, String contentType) throws Exception {
// create
LOG.info("Create training job by Job REST API");
Expand Down
@@ -0,0 +1,28 @@
{
"meta": {
"name": "tf-mnist-json",
"namespace": "default",
"framework": "TensorFlow",
"cmd": "python /var/tf_mnist/mnist_with_summaries.py --log_dir=/train/log --learning_rate=0.01 --batch_size=150",
"envVars": {
"ENV_1": "ENV1"
}
},
"environment": {
"image": "gcr.io/kubeflow-ci/tf-mnist-with-summaries:1.0"
},
"spec": {
"Ps": {
"replicas": 1,
"resources": "cpu=1,memory=512M"
},
"Worker": {
"replicas": 1,
"resources": "cpu=1,memory=512M"
}
},
"code": {
"syncMode": "git",
"url" : "ssh://git@github.com/apache/submarine.git"
}
}