Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add predictor Example tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lanking520 committed Jan 11, 2019
1 parent 1910932 commit 70cadec
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 2 deletions.
7 changes: 7 additions & 0 deletions scala-package/examples/pom.xml
Expand Up @@ -15,6 +15,7 @@

<properties>
<skipTests>true</skipTests>
<skipJavaTests>${skipTests}</skipJavaTests>
</properties>

<build>
Expand Down Expand Up @@ -128,5 +129,11 @@
<artifactId>slf4j-simple</artifactId>
<version>1.7.5</version>
</dependency>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>4.11</version>
<scope>test</scope>
</dependency>
</dependencies>
</project>
Expand Up @@ -24,9 +24,9 @@ import org.apache.commons.io.FileUtils

object Util {

def downloadUrl(url: String, filePath: String, maxRetry: Option[Int] = None) : Unit = {
def downloadUrl(url: String, filePath: String, maxRetry: Int = 3) : Unit = {
val tmpFile = new File(filePath)
var retry = maxRetry.getOrElse(3)
var retry = maxRetry
var success = false
if (!tmpFile.exists()) {
while (retry > 0 && !success) {
Expand Down
@@ -0,0 +1,50 @@
package org.apache.mxnetexamples.javaapi.infer.predictor;

import org.junit.BeforeClass;
import org.junit.Test;
import org.apache.mxnetexamples.Util;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;

public class PredictorExampleSuite {

final static Logger logger = LoggerFactory.getLogger(PredictorExampleSuite.class);
private static String modelPathPrefix = "";
private static String inputImagePath = "";

@BeforeClass
public static void downloadFile() {
logger.info("Downloading resnet-18 model");

String tempDirPath = System.getProperty("java.io.tmpdir");
logger.info("tempDirPath: %s".format(tempDirPath));

String baseUrl = "https://s3.us-east-2.amazonaws.com/scala-infer-models";

Util.downloadUrl(baseUrl + "/resnet-18/resnet-18-symbol.json",
tempDirPath + "/resnet18/resnet-18-symbol.json", 3);
Util.downloadUrl(baseUrl + "/resnet-18/resnet-18-0000.params",
tempDirPath + "/resnet18/resnet-18-0000.params", 3);
Util.downloadUrl(baseUrl + "/resnet-18/synset.txt",
tempDirPath + "/resnet18/synset.txt", 3);
Util.downloadUrl("https://s3.amazonaws.com/model-server/inputs/Pug-Cookie.jpg",
tempDirPath + "/inputImages/resnet18/Pug-Cookie.jpg", 3);

modelPathPrefix = tempDirPath + File.separator + "resnet18/resnet-18";
inputImagePath = tempDirPath + File.separator +
"inputImages/resnet18/Pug-Cookie.jpg";
}

@Test
public void testPredictor(){
PredictorExample example = new PredictorExample();
String[] args = new String[]{
"--model-path-prefix", modelPathPrefix,
"--input-image", inputImagePath
};
example.main(args);
}

}

0 comments on commit 70cadec

Please sign in to comment.