Skip to content

Commit

Permalink
Rollback of google@eb6c118
Browse files Browse the repository at this point in the history
*** Original commit ***

Handle int instead of byte in SSIM.

The value of pixels are converted to integers at the point of use,
move this logic to the initialisation step.

This is a prerequisite step for testing SSIM calculation, which
will lead on to some SSIM improvements being verifiable.

Tested manually and SSIM values match for the same video
before and after this change.

***

PiperOrigin-RevId: 473259446
  • Loading branch information
Samrobbo authored and marcbaechinger committed Oct 19, 2022
1 parent eb6c118 commit 8f9c9d0
Showing 1 changed file with 13 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ public static double calculate(
new VideoDecodingWrapper(context, referenceVideoPath, DEFAULT_COMPARISON_INTERVAL);
VideoDecodingWrapper distortedDecodingWrapper =
new VideoDecodingWrapper(context, distortedVideoPath, DEFAULT_COMPARISON_INTERVAL);
@Nullable int[] referenceLumaBuffer = null;
@Nullable int[] distortedLumaBuffer = null;
@Nullable byte[] referenceLumaBuffer = null;
@Nullable byte[] distortedLumaBuffer = null;
double accumulatedSsim = 0.0;
int comparedImagesCount = 0;
try {
Expand All @@ -101,10 +101,10 @@ public static double calculate(
assertThat(distortedImage.getHeight()).isEqualTo(height);

if (referenceLumaBuffer == null || referenceLumaBuffer.length != width * height) {
referenceLumaBuffer = new int[width * height];
referenceLumaBuffer = new byte[width * height];
}
if (distortedLumaBuffer == null || distortedLumaBuffer.length != width * height) {
distortedLumaBuffer = new int[width * height];
distortedLumaBuffer = new byte[width * height];
}
try {
accumulatedSsim +=
Expand Down Expand Up @@ -134,7 +134,7 @@ public static double calculate(
* @param lumaChannelBuffer The buffer where the extracted luma values are stored.
* @return The {@code lumaChannelBuffer} for convenience.
*/
private static int[] extractLumaChannelBuffer(Image image, int[] lumaChannelBuffer) {
private static byte[] extractLumaChannelBuffer(Image image, byte[] lumaChannelBuffer) {
// This method is invoked on the main thread.
// `image` should contain YUV channels.
Image.Plane[] imagePlanes = image.getPlanes();
Expand All @@ -147,8 +147,7 @@ private static int[] extractLumaChannelBuffer(Image image, int[] lumaChannelBuff
ByteBuffer lumaByteBuffer = lumaPlane.getBuffer();
for (int y = 0; y < height; y++) {
for (int x = 0; x < width; x++) {
lumaChannelBuffer[y * width + x] =
lumaByteBuffer.get(y * rowStride + x * pixelStride) & 0xFF;
lumaChannelBuffer[y * width + x] = lumaByteBuffer.get(y * rowStride + x * pixelStride);
}
}
return lumaChannelBuffer;
Expand Down Expand Up @@ -364,7 +363,7 @@ private static final class MssimCalculator {
* @return The MSSIM score between the input images.
*/
public static double calculate(
int[] referenceBuffer, int[] distortedBuffer, int width, int height) {
byte[] referenceBuffer, byte[] distortedBuffer, int width, int height) {
double totalSsim = 0;
int windowsCount = 0;

Expand Down Expand Up @@ -451,20 +450,20 @@ private static double getWindowSsim(

/** Returns the mean of the pixels in the window. */
private static double getMean(
int[] pixelBuffer, int bufferIndexOffset, int stride, int windowWidth, int windowHeight) {
byte[] pixelBuffer, int bufferIndexOffset, int stride, int windowWidth, int windowHeight) {
double total = 0;
for (int y = 0; y < windowHeight; y++) {
for (int x = 0; x < windowWidth; x++) {
total += pixelBuffer[get1dIndex(x, y, stride, bufferIndexOffset)];
total += pixelBuffer[get1dIndex(x, y, stride, bufferIndexOffset)] & 0xFF;
}
}
return total / (windowWidth * windowHeight);
}

/** Calculates the variances and covariance of the pixels in the window for both buffers. */
private static double[] getVariancesAndCovariance(
int[] referenceBuffer,
int[] distortedBuffer,
byte[] referenceBuffer,
byte[] distortedBuffer,
double referenceMean,
double distortedMean,
int bufferIndexOffset,
Expand All @@ -477,8 +476,8 @@ private static double[] getVariancesAndCovariance(
for (int y = 0; y < windowHeight; y++) {
for (int x = 0; x < windowWidth; x++) {
int index = get1dIndex(x, y, stride, bufferIndexOffset);
double referencePixelDeviation = referenceBuffer[index] - referenceMean;
double distortedPixelDeviation = distortedBuffer[index] - distortedMean;
double referencePixelDeviation = (referenceBuffer[index] & 0xFF) - referenceMean;
double distortedPixelDeviation = (distortedBuffer[index] & 0xFF) - distortedMean;
referenceVariance += referencePixelDeviation * referencePixelDeviation;
distortedVariance += distortedPixelDeviation * distortedPixelDeviation;
referenceDistortedCovariance += referencePixelDeviation * distortedPixelDeviation;
Expand Down

0 comments on commit 8f9c9d0

Please sign in to comment.