Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add image change
Browse files Browse the repository at this point in the history
  • Loading branch information
lanking520 committed Jan 9, 2019
1 parent d11f887 commit cadd753
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 17 deletions.
Expand Up @@ -31,18 +31,17 @@ object Image {
* to mxnet's default RGB format (instead of opencv's default BGR).
* @return NDArray in HWC format
*/
def imDecode(buf: Array[Byte], flag: Int, toRGB: Boolean, out: NDArray): NDArray = {
org.apache.mxnet.Image.imDecode(buf, flag, toRGB, Some(out))
def imDecode(buf: Array[Byte], flag: Int, toRGB: Boolean): NDArray = {
org.apache.mxnet.Image.imDecode(buf, flag, toRGB, None)
}

/**
* Same imageDecode with InputStream
* @param inputStream the inputStream of the image
* @return NDArray in HWC format
*/
def imDecode(inputStream: InputStream, flag: Int = 1, toRGB: Boolean = true,
out: NDArray): NDArray = {
org.apache.mxnet.Image.imDecode(inputStream, flag, toRGB, Some(out))
def imDecode(inputStream: InputStream, flag: Int = 1, toRGB: Boolean = true): NDArray = {
org.apache.mxnet.Image.imDecode(inputStream, flag, toRGB, None)
}

/**
Expand All @@ -54,8 +53,8 @@ object Image {
* (instead of opencv's default BGR).
* @return org.apache.mxnet.NDArray in HWC format
*/
def imRead(filename: String, flag: Int, toRGB: Boolean = true, out: NDArray): NDArray = {
org.apache.mxnet.Image.imRead(filename, Some(flag), Some(toRGB), Some(out))
def imRead(filename: String, flag: Int, toRGB: Boolean = true): NDArray = {
org.apache.mxnet.Image.imRead(filename, Some(flag), Some(toRGB), None)
}

/**
Expand All @@ -66,9 +65,9 @@ object Image {
* @param interp Interpolation method (default=cv2.INTER_LINEAR).
* @return org.apache.mxnet.NDArray
*/
def imResize(src: NDArray, w: Int, h: Int,
interp: Integer, out: NDArray): NDArray = {
org.apache.mxnet.Image.imResize(src, w, h, Some(interp), Some(out))
def imResize(src: NDArray, w: Int, h: Int, interp: Integer): NDArray = {
val interpVal = if (interp == null) None else Some(interp.intValue())
org.apache.mxnet.Image.imResize(src, w, h, interpVal, None)
}

/**
Expand Down
Expand Up @@ -6,11 +6,13 @@
import java.io.File;
import java.net.URL;

import static org.junit.Assert.assertArrayEquals;

public class ImageTest {

private String imLocation;
private static String imLocation;

private void downloadUrl(String url, String filePath, int maxRetry) throws Exception{
private static void downloadUrl(String url, String filePath, int maxRetry) throws Exception{
File tmpFile = new File(filePath);
Boolean success = false;
if (!tmpFile.exists()) {
Expand All @@ -29,7 +31,7 @@ private void downloadUrl(String url, String filePath, int maxRetry) throws Excep
}

@BeforeClass
public void downloadFile() throws Exception {
public static void downloadFile() throws Exception {
String tempDirPath = System.getProperty("java.io.tmpdir");
imLocation = tempDirPath + "/inputImages/Pug-Cookie.jpg";
try {
Expand All @@ -42,8 +44,10 @@ public void downloadFile() throws Exception {

@Test
public void testImageProcess() {
NDArray nd = Image.imRead(imLocation, 1, true, null);
NDArray nd2 = Image.imResize(nd, 224, 224, null, null);
NDArray nd = Image.imRead(imLocation, 1, true);
assertArrayEquals(nd.shape().toArray(), new int[]{576, 1024, 3});
NDArray nd2 = Image.imResize(nd, 224, 224, null);
assertArrayEquals(nd.shape().toArray(), new int[]{224, 224, 3});
NDArray cropped = Image.fixedCrop(nd, 0, 0, 224, 224);
Image.toImage(cropped);
}
Expand Down
Expand Up @@ -98,8 +98,8 @@ public static void main(String[] args) {
inputDesc.add(new DataDesc("data", inputShape, DType.Float32(), "NCHW"));
Predictor predictor = new Predictor(inst.modelPathPrefix, inputDesc, context,0);
// Prepare data
NDArray img = Image.imRead(inst.inputImagePath, 1, true, null);
img = Image.imResize(img, 224, 224, null, null);
NDArray img = Image.imRead(inst.inputImagePath, 1, true);
img = Image.imResize(img, 224, 224, null);
// predict
float[][] result = predictor.predict(new float[][]{img.toArray()});
try {
Expand Down

0 comments on commit cadd753

Please sign in to comment.