Skip to content
This repository has been archived by the owner on Jul 10, 2024. It is now read-only.

Commit

Permalink
Add testList, testUpdate
Browse files Browse the repository at this point in the history
  • Loading branch information
JohnTing committed Aug 19, 2020
1 parent 4b264cb commit 5060482
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
"description": "This is a template to run tf-mnist\n",
"parameters": [
{
"name": "input.train_data",
"name": "training.learning_rate",
"value": 150,
"required": true,
"description": "train data is expected in SVM format, and can be stored in HDFS/S3 \n"
"description": " mnist learning_rate "
},
{
"name": "training.batch_size",
Expand All @@ -17,11 +18,10 @@
],
"experimentSpec": {
"meta": {
"cmd": "python /var/tf_mnist/mnist_with_summaries.py --log_dir=/train/log --learning_rate=0.01 --batch_size={{training.batch_size}}",
"name": "tf-mnist-json",
"cmd": "python /var/tf_mnist/mnist_with_summaries.py --log_dir=/train/log --learning_rate={{training.learning_rate}} --batch_size={{training.batch_size}}",
"name": "tf-mnist-template-test",
"envVars": {
"input_path": "{{input.train_data}}",
"ENV2": "ENV2"
"ENV1": "ENV1"
},
"framework": "TensorFlow",
"namespace": "default"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
"description": "This is a template to run tf-mnist\n",
"parameters": [
{
"name": "input.train_data",
"name": "training.learning_rate",
"value": 150,
"required": true,
"description": "train data is expected in SVM format, and can be stored in HDFS/S3 \n"
"description": " mnist learning_rate "
},
{
"name": "training.batch_size",
Expand All @@ -17,11 +18,10 @@
],
"experimentSpec": {
"meta": {
"cmd": "python /var/tf_mnist/mnist_with_summaries.py --log_dir=/train/log --learning_rate=0.01 --batch_size={{training.batch_size}}",
"name": "tf-mnist-json",
"cmd": "python /var/tf_mnist/mnist_with_summaries.py --log_dir=/train/log --learning_rate={{training.learning_rate}} --batch_size={{training.batch_size}}",
"name": "tf-mnist-template-test2",
"envVars": {
"input_path": "{{input.train_data}}",
"ENV2": "ENV2"
"ENV1": "ENV1"
},
"framework": "TensorFlow",
"namespace": "default"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
package org.apache.submarine.rest;

import java.io.IOException;
import java.util.List;

import javax.ws.rs.core.Response;

Expand All @@ -36,34 +37,34 @@

import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.google.gson.reflect.TypeToken;

@SuppressWarnings("rawtypes")
public class ExperimentTemplateManagerRestApiIT extends AbstractSubmarineServerTest {

protected static String TPL_PATH =
"/api/" + RestConstants.V1 + "/" + RestConstants.EXPERIMENT_TEMPLATES;
protected static String TPL_NAME = "tf-mnist-test2";


protected static String TPL_FILE = "experimenttemplate/test_template_2.json";
protected Gson gson = new GsonBuilder().create();
@BeforeClass
public static void startUp() throws Exception {
Assert.assertTrue(checkIfServerIsRunning());
}

@Test
public void testCreateExperimentTemplate() throws Exception {
String body = loadContent("experimenttemplate/test_template_2.json");
String body = loadContent(TPL_FILE);
run(body, "application/json");
deleteExperimentTemplate();
}

@Test
public void testGetExperimentTemplate() throws Exception {

String body = loadContent("experimenttemplate/test_template_2.json");
String body = loadContent(TPL_FILE);
run(body, "application/json");

Gson gson = new GsonBuilder().create();
GetMethod getMethod = httpGet(TPL_PATH + "/" + TPL_NAME);
Assert.assertEquals(Response.Status.OK.getStatusCode(),
getMethod.getStatusCode());
Expand All @@ -81,13 +82,38 @@ public void testGetExperimentTemplate() throws Exception {


@Test
public void testUpdateExperimentTemplate() throws IOException {
public void testUpdateExperimentTemplate() throws Exception {
String body = loadContent(TPL_FILE);
run(body, "application/json");

ExperimentTemplate tpl =
gson.fromJson(gson.toJson(body), ExperimentTemplate.class);
tpl.getExperimentTemplateSpec().setDescription("new description");
String newBody = gson.toJson(tpl);

httpPatch(TPL_PATH + "/" + TPL_NAME, newBody, "application/json");

GetMethod getMethod = httpGet(TPL_PATH + "/" + TPL_NAME);
Assert.assertEquals(Response.Status.OK.getStatusCode(),
getMethod.getStatusCode());

String json = getMethod.getResponseBodyAsString();
JsonResponse jsonResponse = gson.fromJson(json, JsonResponse.class);
Assert.assertEquals(Response.Status.OK.getStatusCode(),
jsonResponse.getCode());

ExperimentTemplate getExperimentTemplate =
gson.fromJson(gson.toJson(jsonResponse.getResult()), ExperimentTemplate.class);

Assert.assertEquals("new description",
getExperimentTemplate.getExperimentTemplateSpec().getDescription());

deleteExperimentTemplate();
}

@Test
public void testDeleteExperimentTemplate() throws Exception {
String body = loadContent("experimenttemplate/test_template_2.json");
String body = loadContent(TPL_FILE);
run(body, "application/json");
deleteExperimentTemplate();

Expand All @@ -98,12 +124,30 @@ public void testDeleteExperimentTemplate() throws Exception {
}

@Test
public void testListExperimentTemplates() throws IOException {
public void testListExperimentTemplates() throws Exception {
String body = loadContent(TPL_FILE);
run(body, "application/json");

GetMethod getMethod = httpGet(TPL_PATH + "/");
Assert.assertEquals(Response.Status.OK.getStatusCode(),
getMethod.getStatusCode());

String json = getMethod.getResponseBodyAsString();
JsonResponse jsonResponse = gson.fromJson(json, JsonResponse.class);
Assert.assertEquals(Response.Status.OK.getStatusCode(),
jsonResponse.getCode());

List<ExperimentTemplate> getExperimentTemplates =
gson.fromJson(gson.toJson(jsonResponse.getResult()), new TypeToken<List<ExperimentTemplate>>() {
}.getType());

Assert.assertEquals(TPL_NAME, getExperimentTemplates.get(0).getExperimentTemplateSpec().getName());

deleteExperimentTemplate();
}

protected void deleteExperimentTemplate() throws IOException {
Gson gson = new GsonBuilder().create();

DeleteMethod deleteMethod = httpDelete(TPL_PATH + "/" + TPL_NAME);
Assert.assertEquals(Response.Status.OK.getStatusCode(),
deleteMethod.getStatusCode());
Expand All @@ -119,7 +163,6 @@ protected void deleteExperimentTemplate() throws IOException {
}

protected void run(String body, String contentType) throws Exception {
Gson gson = new GsonBuilder().create();

// create
LOG.info("Create ExperimentTemplate using ExperimentTemplate REST API");
Expand Down

0 comments on commit 5060482

Please sign in to comment.