Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
import static org.junit.jupiter.api.Assertions.assertTrue;

import com.google.protobuf.ByteString;
import com.google.protobuf.TextFormat;
import java.nio.ByteBuffer;
import java.util.Random;
import org.apache.hadoop.hdds.protocol.proto.SCMRatisProtocol;
import org.apache.hadoop.hdds.protocol.proto.testing.Proto2SCMRatisProtocolForTesting;
import org.junit.jupiter.api.Test;
Expand All @@ -29,46 +32,143 @@
* Tests proto2 to proto3 compatibility for SCMRatisProtocol.
*/
public class TestSCMRatisProtocolCompatibility {
static final Random RANDOM = new Random();
static final Class<?>[] TYPES = {String.class, Integer.class, byte[].class};

static <T> ByteString randomValue(Class<T> clazz) {
if (clazz == String.class) {
final int length = RANDOM.nextInt(3);
final StringBuilder builder = new StringBuilder(length);
for (int i = 0; i < length; i++) {
builder.append(RANDOM.nextInt(10));
}
final String string = builder.toString();
assertEquals(length, string.length());
return ByteString.copyFromUtf8(string);
} else if (clazz == Integer.class) {
final ByteBuffer buffer = ByteBuffer.allocate(4);
buffer.putInt(RANDOM.nextInt());
return ByteString.copyFrom(buffer.array());
} else if (clazz == byte[].class) {
final byte[] bytes = new byte[RANDOM.nextInt(3)];
RANDOM.nextBytes(bytes);
return ByteString.copyFrom(bytes);
}
throw new IllegalArgumentException("Unrecognized class " + clazz);
}

static Proto2SCMRatisProtocolForTesting.MethodArgument randomProto2MethodArgument() {
final Class<?> type = TYPES[RANDOM.nextInt(TYPES.length)];
return Proto2SCMRatisProtocolForTesting.MethodArgument.newBuilder()
.setType(type.getName())
.setValue(randomValue(type))
.build();
}

static SCMRatisProtocol.MethodArgument randomProto3MethodArgument() {
final Class<?> type = TYPES[RANDOM.nextInt(TYPES.length)];
return SCMRatisProtocol.MethodArgument.newBuilder()
.setType(type.getName())
.setValue(randomValue(type))
.build();
}

static Proto2SCMRatisProtocolForTesting.SCMRatisResponseProto randomProto2SCMRatisResponseProto() {
final Class<?> type = TYPES[RANDOM.nextInt(TYPES.length)];
return Proto2SCMRatisProtocolForTesting.SCMRatisResponseProto.newBuilder()
.setType(type.getName())
.setValue(randomValue(type))
.build();
}

static SCMRatisProtocol.SCMRatisResponseProto randomProto3SCMRatisResponseProto() {
final Class<?> type = TYPES[RANDOM.nextInt(TYPES.length)];
return SCMRatisProtocol.SCMRatisResponseProto.newBuilder()
.setType(type.getName())
.setValue(randomValue(type))
.build();
}

static Proto2SCMRatisProtocolForTesting.SCMRatisRequestProto proto2Request(
String name, Proto2SCMRatisProtocolForTesting.RequestType type, int numArgs) {
// Build request using proto2 (test-only schema)
final Proto2SCMRatisProtocolForTesting.Method.Builder b = Proto2SCMRatisProtocolForTesting.Method.newBuilder()
.setName(name);
for (int i = 0; i < numArgs; i++) {
b.addArgs(randomProto2MethodArgument());
}

return Proto2SCMRatisProtocolForTesting.SCMRatisRequestProto.newBuilder()
.setType(type)
.setMethod(b)
.build();
}

static SCMRatisProtocol.SCMRatisRequestProto proto3Request(
String name, SCMRatisProtocol.RequestType type, int numArgs) {
final SCMRatisProtocol.Method.Builder b = SCMRatisProtocol.Method.newBuilder()
.setName(name);
for (int i = 0; i < numArgs; i++) {
b.addArgs(randomProto3MethodArgument());
}

return SCMRatisProtocol.SCMRatisRequestProto.newBuilder()
.setType(type)
.setMethod(b)
.build();
}

/**
* Verifies that messages encoded with the legacy proto2 schema
* can be parsed by the current proto3 schema.
*/
@Test
public void testProto2RequestCanBeParsedByProto3() throws Exception {
// Build request using proto2 (test-only schema)
Proto2SCMRatisProtocolForTesting.MethodArgument arg =
Proto2SCMRatisProtocolForTesting.MethodArgument.newBuilder()
.setType("java.lang.String")
.setValue(ByteString.copyFromUtf8("v"))
.build();

Proto2SCMRatisProtocolForTesting.Method method =
Proto2SCMRatisProtocolForTesting.Method.newBuilder()
.setName("testOp")
.addArgs(arg)
.build();

Proto2SCMRatisProtocolForTesting.SCMRatisRequestProto proto2 =
Proto2SCMRatisProtocolForTesting.SCMRatisRequestProto.newBuilder()
.setType(Proto2SCMRatisProtocolForTesting.RequestType.PIPELINE)
.setMethod(method)
.build();
for (Proto2SCMRatisProtocolForTesting.RequestType type : Proto2SCMRatisProtocolForTesting.RequestType.values()) {
for (int numArgs = 0; numArgs < 3; numArgs++) {
final String name = type + "_" + numArgs;
runTestProto2RequestCanBeParsedByProto3(name, type, numArgs);
}
}
}

byte[] bytes = proto2.toByteArray();
static void runTestProto2RequestCanBeParsedByProto3(
String name, Proto2SCMRatisProtocolForTesting.RequestType type, int numArgs) throws Exception {
// Build request using proto2 (test-only schema)
final Proto2SCMRatisProtocolForTesting.SCMRatisRequestProto proto2 = proto2Request(name, type, numArgs);
assertEquals(type, proto2.getType());
assertEquals(name, proto2.getMethod().getName());
assertEquals(numArgs, proto2.getMethod().getArgsCount());

// Parse using proto3
SCMRatisProtocol.SCMRatisRequestProto proto3 =
SCMRatisProtocol.SCMRatisRequestProto.parseFrom(bytes);
final SCMRatisProtocol.SCMRatisRequestProto proto3 = SCMRatisProtocol.SCMRatisRequestProto.parseFrom(
proto2.toByteArray());

// Presence + value checks (proto3 optional fields)
assertTrue(proto3.hasType(), "proto3 should see type presence from proto2 bytes");
assertTrue(proto3.hasMethod(), "proto3 should see method presence from proto2 bytes");
assertEquals(SCMRatisProtocol.RequestType.PIPELINE, proto3.getType());
assertEquals("testOp", proto3.getMethod().getName());
assertEquals(1, proto3.getMethod().getArgsCount());
assertEquals("java.lang.String", proto3.getMethod().getArgs(0).getType());
assertEquals(ByteString.copyFromUtf8("v"), proto3.getMethod().getArgs(0).getValue());
assertEquals(SCMRatisProtocol.RequestType.valueOf(type.name()), proto3.getType());
assertEquals(name, proto3.getMethod().getName());
assertEquals(numArgs, proto3.getMethod().getArgsCount());
for (int i = 0; i < numArgs; i++) {
assertMethodArgument(proto2.getMethod().getArgs(i), proto3.getMethod().getArgs(i));
}

assertEquals(proto2.toString(), proto3.toString());
assertEquals(TextFormat.shortDebugString(proto2), TextFormat.shortDebugString(proto3));
assertEquals(proto2, Proto2SCMRatisProtocolForTesting.SCMRatisRequestProto.parseFrom(proto3.toByteArray()));
}

static void assertMethodArgument(Proto2SCMRatisProtocolForTesting.MethodArgument proto2,
SCMRatisProtocol.MethodArgument proto3) {
assertEquals(proto2.getValue(), proto3.getValue());
assertEquals(proto2.getType(), proto3.getType());
}

static void assertSCMRatisResponseProto(Proto2SCMRatisProtocolForTesting.SCMRatisResponseProto proto2,
SCMRatisProtocol.SCMRatisResponseProto proto3) {
assertEquals(proto2.getValue(), proto3.getValue());
assertEquals(proto2.getType(), proto3.getType());
}

/**
Expand All @@ -77,22 +177,101 @@ public void testProto2RequestCanBeParsedByProto3() throws Exception {
*/
@Test
public void testProto2ResponseCanBeParsedByProto3() throws Exception {
// Build response using proto2
Proto2SCMRatisProtocolForTesting.SCMRatisResponseProto proto2 =
Proto2SCMRatisProtocolForTesting.SCMRatisResponseProto.newBuilder()
.setType("java.lang.String")
.setValue(ByteString.copyFromUtf8("ok"))
.build();
for (int i = 0; i < 10; i++) {
runTestProto2ResponseCanBeParsedByProto3();
}
}

byte[] bytes = proto2.toByteArray();
static void runTestProto2ResponseCanBeParsedByProto3() throws Exception {
// Build response using proto2
final Proto2SCMRatisProtocolForTesting.SCMRatisResponseProto proto2 = randomProto2SCMRatisResponseProto();

// Parse using proto3 (production schema)
SCMRatisProtocol.SCMRatisResponseProto proto3 =
SCMRatisProtocol.SCMRatisResponseProto.parseFrom(bytes);
final SCMRatisProtocol.SCMRatisResponseProto proto3 = SCMRatisProtocol.SCMRatisResponseProto.parseFrom(
proto2.toByteArray());

assertTrue(proto3.hasType(), "proto3 should see type presence from proto2 bytes");
assertTrue(proto3.hasValue(), "proto3 should see value presence from proto2 bytes");
assertEquals("java.lang.String", proto3.getType());
assertEquals(ByteString.copyFromUtf8("ok"), proto3.getValue());
assertSCMRatisResponseProto(proto2, proto3);

assertEquals(proto2.toString(), proto3.toString());
assertEquals(TextFormat.shortDebugString(proto2), TextFormat.shortDebugString(proto3));
assertEquals(proto2, Proto2SCMRatisProtocolForTesting.SCMRatisResponseProto.parseFrom(proto3.toByteArray()));
}

@Test
public void testRequestType() {
for (Proto2SCMRatisProtocolForTesting.RequestType proto2 : Proto2SCMRatisProtocolForTesting.RequestType.values()) {
final SCMRatisProtocol.RequestType proto3 = SCMRatisProtocol.RequestType.valueOf(proto2.name());
assertEquals(proto3.getNumber(), proto2.getNumber());
assertEquals(proto3.toString(), proto2.toString());
}
}

/**
* Verifies that messages encoded with the current proto3 schema
* can be parsed by the legacy proto2 schema.
*/
@Test
public void testProto3RequestCanBeParsedByProto2() throws Exception {
for (SCMRatisProtocol.RequestType type : SCMRatisProtocol.RequestType.values()) {
// Skip default (0) and UNRECOGNIZED values: proto3 may omit default enums
// on the wire, and UNRECOGNIZED cannot be serialized; proto2 requires `type`.
if (type == SCMRatisProtocol.RequestType.UNRECOGNIZED
|| type == SCMRatisProtocol.RequestType.REQUEST_TYPE_UNSPECIFIED) {
continue;
}
for (int numArgs = 0; numArgs < 3; numArgs++) {
final String name = type + "_" + numArgs;
runTestProto3RequestCanBeParsedByProto2(name, type, numArgs);
}
}
}

static void runTestProto3RequestCanBeParsedByProto2(
String name, SCMRatisProtocol.RequestType type, int numArgs) throws Exception {

final SCMRatisProtocol.SCMRatisRequestProto proto3 = proto3Request(name, type, numArgs);
assertEquals(type, proto3.getType());
assertEquals(name, proto3.getMethod().getName());
assertEquals(numArgs, proto3.getMethod().getArgsCount());
// Parse using proto2 (legacy test schema)
final Proto2SCMRatisProtocolForTesting.SCMRatisRequestProto proto2 =
Proto2SCMRatisProtocolForTesting.SCMRatisRequestProto.parseFrom(proto3.toByteArray());

assertEquals(Proto2SCMRatisProtocolForTesting.RequestType.valueOf(type.name()), proto2.getType());
assertTrue(proto2.hasMethod());
assertEquals(name, proto2.getMethod().getName());
assertEquals(numArgs, proto2.getMethod().getArgsCount());

for (int i = 0; i < numArgs; i++) {
assertEquals(proto3.getMethod().getArgs(i).getType(), proto2.getMethod().getArgs(i).getType());
assertEquals(proto3.getMethod().getArgs(i).getValue(), proto2.getMethod().getArgs(i).getValue());
}

assertEquals(proto2.toString(), proto3.toString());
assertEquals(TextFormat.shortDebugString(proto2), TextFormat.shortDebugString(proto3));
assertEquals(proto3, SCMRatisProtocol.SCMRatisRequestProto.parseFrom(proto2.toByteArray()));
}

@Test
public void testProto3ResponseCanBeParsedByProto2() throws Exception {
for (int i = 0; i < 10; i++) {
runTestProto3ResponseCanBeParsedByProto2();
}
}

static void runTestProto3ResponseCanBeParsedByProto2() throws Exception {
final SCMRatisProtocol.SCMRatisResponseProto proto3 = randomProto3SCMRatisResponseProto();

final Proto2SCMRatisProtocolForTesting.SCMRatisResponseProto proto2 =
Proto2SCMRatisProtocolForTesting.SCMRatisResponseProto.parseFrom(proto3.toByteArray());

assertEquals(proto3.getType(), proto2.getType());
assertEquals(proto3.getValue(), proto2.getValue());

assertEquals(proto2.toString(), proto3.toString());
assertEquals(TextFormat.shortDebugString(proto2), TextFormat.shortDebugString(proto3));
assertEquals(proto3, SCMRatisProtocol.SCMRatisResponseProto.parseFrom(proto2.toByteArray()));
}
}