diff --git a/library/transformer/src/androidTest/java/com/google/android/exoplayer2/transformer/SsimHelper.java b/library/transformer/src/androidTest/java/com/google/android/exoplayer2/transformer/SsimHelper.java index e4e51b8f7d5..0eb3bce3ef1 100644 --- a/library/transformer/src/androidTest/java/com/google/android/exoplayer2/transformer/SsimHelper.java +++ b/library/transformer/src/androidTest/java/com/google/android/exoplayer2/transformer/SsimHelper.java @@ -80,8 +80,8 @@ public static double calculate( new VideoDecodingWrapper(context, referenceVideoPath, DEFAULT_COMPARISON_INTERVAL); VideoDecodingWrapper distortedDecodingWrapper = new VideoDecodingWrapper(context, distortedVideoPath, DEFAULT_COMPARISON_INTERVAL); - @Nullable int[] referenceLumaBuffer = null; - @Nullable int[] distortedLumaBuffer = null; + @Nullable byte[] referenceLumaBuffer = null; + @Nullable byte[] distortedLumaBuffer = null; double accumulatedSsim = 0.0; int comparedImagesCount = 0; try { @@ -101,10 +101,10 @@ public static double calculate( assertThat(distortedImage.getHeight()).isEqualTo(height); if (referenceLumaBuffer == null || referenceLumaBuffer.length != width * height) { - referenceLumaBuffer = new int[width * height]; + referenceLumaBuffer = new byte[width * height]; } if (distortedLumaBuffer == null || distortedLumaBuffer.length != width * height) { - distortedLumaBuffer = new int[width * height]; + distortedLumaBuffer = new byte[width * height]; } try { accumulatedSsim += @@ -134,7 +134,7 @@ public static double calculate( * @param lumaChannelBuffer The buffer where the extracted luma values are stored. * @return The {@code lumaChannelBuffer} for convenience. */ - private static int[] extractLumaChannelBuffer(Image image, int[] lumaChannelBuffer) { + private static byte[] extractLumaChannelBuffer(Image image, byte[] lumaChannelBuffer) { // This method is invoked on the main thread. // `image` should contain YUV channels. Image.Plane[] imagePlanes = image.getPlanes(); @@ -147,8 +147,7 @@ private static int[] extractLumaChannelBuffer(Image image, int[] lumaChannelBuff ByteBuffer lumaByteBuffer = lumaPlane.getBuffer(); for (int y = 0; y < height; y++) { for (int x = 0; x < width; x++) { - lumaChannelBuffer[y * width + x] = - lumaByteBuffer.get(y * rowStride + x * pixelStride) & 0xFF; + lumaChannelBuffer[y * width + x] = lumaByteBuffer.get(y * rowStride + x * pixelStride); } } return lumaChannelBuffer; @@ -364,7 +363,7 @@ private static final class MssimCalculator { * @return The MSSIM score between the input images. */ public static double calculate( - int[] referenceBuffer, int[] distortedBuffer, int width, int height) { + byte[] referenceBuffer, byte[] distortedBuffer, int width, int height) { double totalSsim = 0; int windowsCount = 0; @@ -451,11 +450,11 @@ private static double getWindowSsim( /** Returns the mean of the pixels in the window. */ private static double getMean( - int[] pixelBuffer, int bufferIndexOffset, int stride, int windowWidth, int windowHeight) { + byte[] pixelBuffer, int bufferIndexOffset, int stride, int windowWidth, int windowHeight) { double total = 0; for (int y = 0; y < windowHeight; y++) { for (int x = 0; x < windowWidth; x++) { - total += pixelBuffer[get1dIndex(x, y, stride, bufferIndexOffset)]; + total += pixelBuffer[get1dIndex(x, y, stride, bufferIndexOffset)] & 0xFF; } } return total / (windowWidth * windowHeight); @@ -463,8 +462,8 @@ private static double getMean( /** Calculates the variances and covariance of the pixels in the window for both buffers. */ private static double[] getVariancesAndCovariance( - int[] referenceBuffer, - int[] distortedBuffer, + byte[] referenceBuffer, + byte[] distortedBuffer, double referenceMean, double distortedMean, int bufferIndexOffset, @@ -477,8 +476,8 @@ private static double[] getVariancesAndCovariance( for (int y = 0; y < windowHeight; y++) { for (int x = 0; x < windowWidth; x++) { int index = get1dIndex(x, y, stride, bufferIndexOffset); - double referencePixelDeviation = referenceBuffer[index] - referenceMean; - double distortedPixelDeviation = distortedBuffer[index] - distortedMean; + double referencePixelDeviation = (referenceBuffer[index] & 0xFF) - referenceMean; + double distortedPixelDeviation = (distortedBuffer[index] & 0xFF) - distortedMean; referenceVariance += referencePixelDeviation * referencePixelDeviation; distortedVariance += distortedPixelDeviation * distortedPixelDeviation; referenceDistortedCovariance += referencePixelDeviation * distortedPixelDeviation;